diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..b6b28ea --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,35 @@ +name: ci + +on: + push: + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Run tests + run: python -m unittest cmus_status_scrobbler.py tests.py test_cmus_status_scrobbler.py + typecheck: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install typecheck dependencies + run: python -m pip install -r requirements-dev.txt + - name: Run typecheck + run: make typecheck PYTHON=python diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml deleted file mode 100644 index 11f7ce3..0000000 --- a/.github/workflows/run-tests.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: tests - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -jobs: - Run-Tests: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.x"] - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Run tests - run: python tests.py - - - name: Concurrency tests - run: python test_cmus_status_scrobbler.py diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1227ed6 --- /dev/null +++ b/Makefile @@ -0,0 +1,22 @@ +PYTHON ?= ./venv/bin/python +VENV ?= ./venv +VENV_PYTHON := $(VENV)/bin/python +FILES := cmus_status_scrobbler.py tests.py test_cmus_status_scrobbler.py + +.PHONY: test typecheck format check venv + +test: + $(PYTHON) -m unittest $(FILES) + +typecheck: + $(PYTHON) -m mypy --strict --check-untyped-defs $(FILES) + $(PYTHON) -m pyright $(FILES) + +format: + $(PYTHON) -m yapf -i $(FILES) + +check: typecheck test + +venv: + python3 -m venv $(VENV) + $(VENV_PYTHON) -m pip install -r requirements-dev.txt diff --git a/README.md b/README.md index 8566668..1b51327 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # cmus-status-scrobbler -![tests passing status](https://github.com/vjeranc/cmus-status-scrobbler/actions/workflows/run-tests.yml/badge.svg?branch=main) +![tests passing status](https://github.com/vjeranc/cmus-status-scrobbler/actions/workflows/ci.yml/badge.svg?branch=main) -Works with [cmus](https://cmus.github.io/). Requires Python 3 and has no +Works with [cmus](https://cmus.github.io/). Requires Python 3.9+ and has no additional dependencies. **Features:** diff --git a/cmus_status_scrobbler.py b/cmus_status_scrobbler.py index 406a82e..9ac034b 100755 --- a/cmus_status_scrobbler.py +++ b/cmus_status_scrobbler.py @@ -1,6 +1,17 @@ #!/usr/bin/env python3 +""" +cmus_status_scrobbler entry point and core logic. + +Design: a ReaderT-like pattern where effectful operations are captured in +explicit env objects (HTTP/DB) built from fully-resolved config. Pure +functions (parsing, scrobble decision logic) sit at the top level, while +effectful programs (auth, update scrobble state) take envs first to keep +dependencies explicit and tests configurable. +""" +from __future__ import annotations import argparse +import dataclasses import configparser import datetime import hashlib @@ -13,306 +24,602 @@ import time import urllib.parse as up import urllib.request as ur -from collections import namedtuple -from functools import reduce +from collections.abc import Callable, Sequence +from dataclasses import dataclass from operator import attrgetter -CONFIG_PATH = '~/.config/cmus/cmus_status_scrobbler.ini' -DB_CONNECT_TIMEOUT = 300 -DB_PATH = '~/.config/cmus/cmus_status_scrobbler.sqlite3' -SCROBBLE_BATCH_SIZE = 50 - -parser = argparse.ArgumentParser(description="Scrobbling.") -parser.add_argument('--auth', - action='store_true', - help="Add if you're missing session_key in .ini file.") -parser.add_argument('--ini', - type=str, - default=os.path.expanduser(CONFIG_PATH), - help='Path to .ini configuration file.') -parser.add_argument('--db-path', - type=str, - default=os.path.expanduser(DB_PATH), - help='Path to sqlite3 database') -parser.add_argument( - '--log-path', - type=str, - required=False, - help='If given logging will be saved to desired path (default: no logging)' +from typing import ( + Annotated, + Dict, + List, + NamedTuple, + Optional, + TYPE_CHECKING, + Union, + get_args, + get_origin, + get_type_hints, ) -parser.add_argument('--log-db', - action='store_true', - default=False, - help='If given, SQL queries are logged') - - -class StatusDB: +if TYPE_CHECKING: + from typing_extensions import TypeAlias, TypeGuard +STATUS_STOPPED = 'stopped' +STATUS_PLAYING = 'playing' +STATUS_PAUSED = 'paused' - def __init__(self, connection, table_name): - self.con = connection - self.table_name = f'status_updates_{table_name}' - self.create() +SCROBBLER_GET_TOKEN = 'auth.gettoken' +SCROBBLER_GET_SESSION = 'auth.getsession' +SCROBBLER_NOW_PLAYING = 'track.updateNowPlaying' +SCROBBLER_SCROBBLE = 'track.scrobble' - def create(self): - self.con.execute( - f"CREATE TABLE IF NOT EXISTS {self.table_name} (pickle BLOB)") +KEYS_TO_REDACT = [b'api_key', b'sk', b'api_sig', b'token', b'session_key'] - def get_status_updates(self): - cur = self.con.cursor() - cur.execute(f"SELECT * FROM {self.table_name}") - status_updates = [] +JSONValue: 'TypeAlias' = Union[str, int, float, bool, None, List['JSONValue'], + Dict[str, 'JSONValue']] + + +@dataclass(frozen=True) +class ArgDef: + flags: tuple[str, ...] + help: str + action: Optional[str] = None + required: bool = False + + +@dataclass(frozen=True) +class Args: + auth: Annotated[ + bool, + ArgDef( + flags=('--auth', ), + help="Add if you're missing session_key in .ini file.", + action='store_true', + ), + ] + ini: Annotated[ + str, + ArgDef( + flags=('--ini', ), + help='Path to .ini configuration file.', + ), + ] + db_path: Annotated[ + str, + ArgDef( + flags=('--db-path', ), + help='Path to sqlite3 database', + ), + ] + log_path: Annotated[ + Optional[str], + ArgDef( + flags=('--log-path', ), + help= + 'If given logging will be saved to desired path (default: no logging)', + ), + ] + log_db: Annotated[ + bool, + ArgDef( + flags=('--log-db', ), + help='If given, SQL queries are logged', + action='store_true', + ), + ] + cur_time: Annotated[ + Optional[float], + ArgDef( + flags=('--cur-time', ), + help='Override current time for status update (unix timestamp).', + ), + ] + + +@dataclass(frozen=True) +class AppDefaults: + config_path: str + db_path: str + db_connect_timeout: int + db_connect_retry_attempts: int + db_connect_retry_sleep_secs: int + scrobble_batch_size: int + http_user_agent: str + http_default_timeout_secs: float + http_scrobble_timeout_secs: float + + +@dataclass(frozen=True) +class GlobalConfig: + api_key: str + shared_secret: str + db_path: str + log_db: bool + now_playing: bool + format_xml: bool + log_path: Optional[str] + + +@dataclass(frozen=True) +class ServiceConfig: + name: str + api_url: str + auth_url: str + api_key: str + shared_secret: str + session_key: Optional[str] + now_playing: bool + format_xml: bool + + +@dataclass(frozen=True) +class AppConfig: + global_config: GlobalConfig + services: list[ServiceConfig] + + +class Status(NamedTuple): + status: str + file: str + artist: Optional[str] + albumartist: Optional[str] + album: Optional[str] + discnumber: Optional[Union[str, int]] + tracknumber: Optional[str] + title: Optional[str] + date: Optional[str] + duration: Optional[Union[str, int]] + musicbrainz_trackid: Optional[str] + cur_time: float + + +def build_parser(args_type: type[Args], + defaults: AppDefaults) -> argparse.ArgumentParser: + arg_defaults = args_type( + auth=False, + ini=os.path.expanduser(defaults.config_path), + db_path=os.path.expanduser(defaults.db_path), + log_path=None, + log_db=False, + cur_time=None, + ) + parser = argparse.ArgumentParser(description='Scrobbling.') + type_hints = get_type_hints(args_type, include_extras=True) + for field in dataclasses.fields(args_type): + hint = type_hints[field.name] + arg_def: Optional[ArgDef] = None + base_type = hint + if get_origin(hint) is Annotated: + args = get_args(hint) + base_type = args[0] + for meta in args[1:]: + if isinstance(meta, ArgDef): + arg_def = meta + if arg_def is None: + raise ValueError(f'Missing ArgDef for {field.name}') + arg_type = base_type + origin = get_origin(base_type) + if origin is Union: + args = get_args(base_type) + if len(args)!=2 or type(None) not in args: + raise ValueError( + f'Unsupported union type for {field.name}: {base_type}') + arg_type = args[0] if args[1] is type(None) else args[1] + if arg_def.action in {'store_true', 'store_false'}: + if arg_type is not bool: + raise ValueError( + f'Action {arg_def.action} requires bool for {field.name}') + elif arg_type not in {str, int, float}: + raise ValueError(f'Unsupported type for {field.name}: {arg_type}') + default_value = getattr(arg_defaults, field.name) + if arg_def.action is not None: + parser.add_argument( + *arg_def.flags, + action=arg_def.action, + default=default_value, + required=arg_def.required, + help=arg_def.help, + ) + else: + parser.add_argument( + *arg_def.flags, + type=arg_type, + default=default_value, + required=arg_def.required, + help=arg_def.help, + ) + return parser + + +@dataclass(frozen=True) +class DBEnv: + create: Callable[[], None] + get_status_updates: Callable[[], list[Status]] + clear: Callable[[], None] + save_status_updates: Callable[[list[Status]], None] + + +def make_db_env( + *, + con: sqlite3.Connection, + table_name: str, +) -> DBEnv: + + def status_db_table() -> str: + return f'status_updates_{table_name}' + + def create() -> None: + con.execute( + f'CREATE TABLE IF NOT EXISTS {status_db_table()} (pickle BLOB)') + + def get_status_updates() -> list[Status]: + cur = con.cursor() + cur.execute(f'SELECT * FROM {status_db_table()}') + status_updates: list[Status] = [] for row in cur: - status_updates.append(pickle.loads(row[0])) - su = status_updates[-1] - if isinstance(su.cur_time, datetime.datetime): + loaded = pickle.loads(row[0]) + if not isinstance(loaded, Status): + raise TypeError('Unexpected status update payload.') + status_update = loaded + if isinstance(status_update.cur_time, datetime.datetime): # FIXME remove in 3 years, assuming everyone is on latest - status_updates[-1] = su._replace( - cur_time=su.cur_time.timestamp()) + status_update = status_update._replace( + cur_time=status_update.cur_time.timestamp()) + status_updates.append(status_update) return status_updates - def clear(self): - self.con.execute(f"DELETE FROM {self.table_name}") + def clear() -> None: + con.execute(f'DELETE FROM {status_db_table()}') - def save_status_updates(self, status_updates): + def save_status_updates(status_updates: list[Status]) -> None: if not status_updates: return - self.con.executemany( - f"INSERT INTO {self.table_name}(pickle) values (?)", - [(pickle.dumps(su), ) for su in status_updates]) - - -class CmusStatus: - stopped = "stopped" - playing = "playing" - paused = "paused" - - -Status = namedtuple('Status', [ - 'status', 'file', 'artist', 'albumartist', 'album', 'discnumber', - 'tracknumber', 'title', 'date', 'duration', 'musicbrainz_trackid', - 'cur_time' -]) - - -def safe_utf8_encode(text): - try: - return text.encode('utf-8') - except UnicodeEncodeError: - return text.encode('utf-8', errors='ignore') - - -def get_api_sig(params, secret): - m = hashlib.md5() - for k in sorted(params): - m.update(k) - m.update(params[k]) - m.update(secret.encode('utf-8')) - return m.hexdigest() - - -KEYS_TO_REDACT = [b'api_key', b'sk', b'api_sig', b'token', b'session_key'] + con.executemany( + f'INSERT INTO {status_db_table()}(pickle) values (?)', + [(pickle.dumps(su), ) for su in status_updates], + ) + + return DBEnv( + create=create, + get_status_updates=get_status_updates, + clear=clear, + save_status_updates=save_status_updates, + ) -def redact_dict(d): - if not isinstance(d, dict): - return d - d = d.copy() - for k, v in d.items(): - if k in KEYS_TO_REDACT: - d[k] = '' - return d - - -def send_req(api_url, - api_key, - ignore_request_fail=False, - shared_secret=None, - method=None, - xml=False, - timeout_secs=10., - **params): - params = dict(**params) - params['api_key'] = api_key - params['method'] = method - params = { - safe_utf8_encode(k): safe_utf8_encode(v) - for k, v in params.items() if v is not None - } - if shared_secret: - params['api_sig'] = get_api_sig(params, shared_secret) - if not xml: - params['format'] = 'json' - logging.info(redact_dict(params)) - api_req = ur.Request(api_url, headers={"User-Agent": "Mozilla/5.0"}) - try: - with ur.urlopen(api_req, - up.urlencode(params, encoding='utf-8', - errors='ignore').encode(), - timeout=timeout_secs) as f: - res = f.read().decode('utf-8') - logging.info(res) - if not res: - return None - if not xml: - return json.loads(res) - return res - except Exception as e: - if not ignore_request_fail: - raise e - logging.exception('Ignoring error.') - return None - - -class Scrobbler: - - def __init__(self, - name, - api_url, - api_key, - shared_secret, - session_key, - now_playing, - xml=False): - self.name = name - self.api_url = api_url - self.api_key = api_key - self.shared_secret = shared_secret - self.sk = session_key - self.now_playing = now_playing - self.xml = xml - - @staticmethod - def auth(auth_url, api_url, api_key, shared_secret, xml=False): - # fetching token that is used to ask for access - token = send_req(api_url, - api_key, - method=ScrobblerMethod.GET_TOKEN, - xml=xml) - if xml: - token = token.split("")[1].split("")[0] +@dataclass(frozen=True) +class HttpEnv: + auth: Callable[[], dict[str, str]] + scrobble: Callable[[list[Status]], None] + send_now_playing: Callable[[Status], None] + + +def make_http_env( + *, + service_config: ServiceConfig, + defaults: AppDefaults, + session_key: Optional[str], + logger: logging.LoggerAdapter[logging.Logger], +) -> HttpEnv: + + def is_json_value(value: JSONValue) -> TypeGuard[JSONValue]: + if value is None or isinstance(value, (str, int, float, bool)): + return True + if isinstance(value, list): + return all(is_json_value(item) for item in value) + if isinstance(value, dict): + return all( + isinstance(key, str) and is_json_value(item) + for key, item in value.items()) + return False + + def send_req( + *, + ignore_request_fail: bool, + method: str, + timeout_secs: Optional[float], + params: dict[str, Optional[str]], + ) -> JSONValue: + + def safe_utf8_encode(text: str) -> bytes: + try: + return text.encode('utf-8') + except UnicodeEncodeError: + return text.encode('utf-8', errors='ignore') + + def get_api_sig(encoded_params: dict[bytes, bytes], + secret: str) -> str: + sig = hashlib.md5() + for key in sorted(encoded_params): + sig.update(key) + sig.update(encoded_params[key]) + sig.update(secret.encode('utf-8')) + return sig.hexdigest() + + def redact_dict(data: dict[bytes, bytes]) -> dict[bytes, bytes]: + return { + key: (b'' if key in KEYS_TO_REDACT else value) + for key, value in data.items() + } + + merged: dict[str, str] = { + 'api_key': service_config.api_key, + 'method': method, + } + for key, value in params.items(): + if value is None: + continue + merged[key] = value + + encoded_params = { + safe_utf8_encode(key): safe_utf8_encode(value) + for key, value in merged.items() + } + if method!=SCROBBLER_GET_TOKEN: + encoded_params[b'api_sig'] = safe_utf8_encode( + get_api_sig(encoded_params, service_config.shared_secret)) + if not service_config.format_xml: + encoded_params[b'format'] = b'json' + logger.info(redact_dict(encoded_params)) + api_req = ur.Request(service_config.api_url, + headers={'User-Agent': defaults.http_user_agent}) + timeout = (defaults.http_default_timeout_secs + if timeout_secs is None else timeout_secs) + try: + with ur.urlopen( + api_req, + up.urlencode(encoded_params).encode(), + timeout=timeout, + ) as response: + payload: str = response.read().decode('utf-8') + logger.info(payload) + if not payload: + return None + if not service_config.format_xml: + loaded = json.loads(payload) + if not is_json_value(loaded): + raise ValueError('Unexpected JSON response.') + return loaded + return payload + except Exception: + if not ignore_request_fail: + raise + logger.exception('Ignoring error.') + return None + + def auth() -> dict[str, str]: + + def require_text(value: JSONValue, label: str) -> str: + if not isinstance(value, str): + raise ValueError(f'Missing {label} in response.') + return value + + def require_dict(value: JSONValue, label: str) -> dict[str, JSONValue]: + if not isinstance(value, dict): + raise ValueError(f'Missing {label} in response.') + return value + + token_response = send_req( + ignore_request_fail=False, + method=SCROBBLER_GET_TOKEN, + timeout_secs=None, + params={}, + ) + if service_config.format_xml: + token_payload = require_text(token_response, 'token') + token = token_payload.split('')[1].split('')[0] else: - token = token['token'] - print(f'{auth_url}?'+up.urlencode(dict(token=token, api_key=api_key))) + token_dict = require_dict(token_response, 'token') + token_value = token_dict.get('token') + token = require_text(token_value, 'token') + print(f'{service_config.auth_url}?'+ + up.urlencode(dict(token=token, api_key=service_config.api_key))) input('Press after visiting the link and allowing access...') - # fetching session with infinite lifetime that is used to scrobble - session = send_req(api_url, - api_key, - shared_secret=shared_secret, - method=ScrobblerMethod.GET_SESSION, - token=token, - xml=xml) - if xml: - session = dict(key=session.split("")[1].split("")[0], - name=session.split("")[1].split("")[0]) + session_response = send_req( + ignore_request_fail=False, + method=SCROBBLER_GET_SESSION, + timeout_secs=None, + params={'token': token}, + ) + if service_config.format_xml: + session_payload = require_text(session_response, 'session') + key = session_payload.split('')[1].split('')[0] + name = session_payload.split('')[1].split('')[0] + session: dict[str, JSONValue] = {'key': key, 'name': name} else: - session = session['session'] - return dict(session_key=session['key'], username=session['name']) - - @staticmethod - def make_scrobble(i, su): - return { - f'artist[{i}]': su.artist, - f'track[{i}]': su.title, - f'timestamp[{i}]': str(int(su.cur_time)), - f'album[{i}]': su.album, - f'trackNumber[{i}]': su.tracknumber, - f'mbid[{i}]': su.musicbrainz_trackid, - f'albumArtist[{i}]': - su.albumartist if su.artist!=su.albumartist else None, - f'duration[{i}]': su.duration, - } + session_dict = require_dict(session_response, 'session') + session_value = session_dict.get('session') + session = require_dict(session_value, 'session') + session_key = require_text(session.get('key'), 'session key') + username = require_text(session.get('name'), 'username') + return {'session_key': session_key, 'username': username} + + def scrobble(status_updates: list[Status]) -> None: + + def make_scrobble(i: int, + status_update: Status) -> dict[str, Optional[str]]: + return { + f'artist[{i}]': + status_update.artist, + f'track[{i}]': + status_update.title, + f'timestamp[{i}]': + str(int(status_update.cur_time)), + f'album[{i}]': + status_update.album, + f'trackNumber[{i}]': + status_update.tracknumber, + f'mbid[{i}]': + status_update.musicbrainz_trackid, + f'albumArtist[{i}]': + status_update.albumartist + if status_update.artist!=status_update.albumartist else None, + f'duration[{i}]': + None if status_update.duration is None else str( + status_update.duration), + } - def scrobble(self, status_updates): if not status_updates: return - logging.info(f'Scrobbling previous tracks for {self.name}') - # ignoring status updates with status other than playing - playing_sus = filter(lambda x: x.status==CmusStatus.playing, - status_updates) - batch_scrobble_request = reduce(lambda a, b: { - **a, - **b - }, [ - Scrobbler.make_scrobble(i, su) - for (i, su) in enumerate(playing_sus) - ], dict(sk=self.sk)) - if not batch_scrobble_request: + logger.info('Scrobbling previous tracks') + playing_updates = [ + update for update in status_updates + if update.status==STATUS_PLAYING + ] + if not playing_updates: return + batch_scrobble_request: dict[str, Optional[str]] = {'sk': session_key} + for i, status_update in enumerate(playing_updates): + batch_scrobble_request.update(make_scrobble(i, status_update)) send_req( - self.api_url, - self.api_key, - shared_secret=self.shared_secret, - method=ScrobblerMethod.SCROBBLE, - xml=self.xml, - timeout_secs=5., # scrobbling is not critical (saved in db) - **batch_scrobble_request) - - def send_now_playing(self, cur): - if not self.now_playing or cur.status!=CmusStatus.playing: + ignore_request_fail=False, + method=SCROBBLER_SCROBBLE, + timeout_secs=defaults.http_scrobble_timeout_secs, + params=batch_scrobble_request, + ) + + def send_now_playing(cur: Status) -> None: + if not service_config.now_playing or cur.status!=STATUS_PLAYING: return - - logging.info(f'Sending now playing for {self.name}') - params = dict(artist=cur.artist, - track=cur.title, - album=cur.album, - trackNumber=cur.tracknumber, - duration=cur.duration, - albumArtist=cur.albumartist - if cur.artist!=cur.albumartist else None, - mbid=cur.musicbrainz_trackid, - sk=self.sk) - send_req(self.api_url, - self.api_key, - ignore_request_fail=True, - shared_secret=self.shared_secret, - method=ScrobblerMethod.NOW_PLAYING, - xml=self.xml, - **params) - - -def parse_cmus_status_line(ls): - logging.info(ls) - r = dict( - cur_time=datetime.datetime.now(datetime.timezone.utc).timestamp(), - musicbrainz_trackid=None, - discnumber=1, - tracknumber=None, - date=None, - album=None, - albumartist=None, - artist=None, + logger.info('Sending now playing') + params = dict( + artist=cur.artist, + track=cur.title, + album=cur.album, + trackNumber=cur.tracknumber, + duration=None if cur.duration is None else str(cur.duration), + albumArtist=cur.albumartist + if cur.artist!=cur.albumartist else None, + mbid=cur.musicbrainz_trackid, + sk=session_key, + ) + send_req( + ignore_request_fail=True, + method=SCROBBLER_NOW_PLAYING, + timeout_secs=None, + params=params, + ) + + return HttpEnv( + auth=auth, + scrobble=scrobble, + send_now_playing=send_now_playing, ) - r.update((k, v) for k, v in zip(ls[::2], ls[1::2])) - return Status(**r) - -def has_played_enough(start_ts, - end_ts, - duration, - perc_thresh, - secs_thresh, - ptbp=0): - duration = int(duration) - total = end_ts-start_ts+ptbp - return total/duration>=perc_thresh or total>=secs_thresh - -def equal_tracks(a, b): - return a.file==b.file +@dataclass(frozen=True) +class ScrobblingEnv: + http: HttpEnv + db: DBEnv + logger: logging.LoggerAdapter[logging.Logger] + + +def make_scrobbling_env( + *, + con: sqlite3.Connection, + http_env: HttpEnv, + table_name: str, + logger: logging.LoggerAdapter[logging.Logger], +) -> ScrobblingEnv: + return ScrobblingEnv( + http=http_env, + db=make_db_env(con=con, table_name=table_name), + logger=logger, + ) -def get_prefix_end_exclusive_idx(status_updates): - r_su = list(reversed(status_updates)) - for i, (cur, prv) in enumerate(zip(r_su, r_su[1:])): - if (cur.status==CmusStatus.stopped or not equal_tracks(cur, prv) - or cur.status==prv.status or prv.status==CmusStatus.stopped): - return len(r_su)-i - return 0 # all statuses do not result in a scrobble +def parse_cmus_status_line( + parts: Sequence[str], + logger: logging.LoggerAdapter[logging.Logger], +) -> Status: + logger.info(parts) + cur_time = datetime.datetime.now(datetime.timezone.utc).timestamp() + musicbrainz_trackid = None + discnumber: Optional[Union[str, int]] = 1 + tracknumber = None + date = None + album = None + albumartist = None + artist = None + status = '' + file = '' + title = None + duration: Optional[Union[str, int]] = None + for key, value in zip(parts[::2], parts[1::2]): + if key=='cur_time': + try: + cur_time = float(value) + except ValueError: + cur_time = datetime.datetime.now( + datetime.timezone.utc).timestamp() + elif key=='musicbrainz_trackid': + musicbrainz_trackid = value + elif key=='discnumber': + discnumber = value + elif key=='tracknumber': + tracknumber = value + elif key=='date': + date = value + elif key=='album': + album = value + elif key=='albumartist': + albumartist = value + elif key=='artist': + artist = value + elif key=='status': + status = value + elif key=='file': + file = value + elif key=='title': + title = value + elif key=='duration': + duration = value + return Status( + status=status, + file=file, + artist=artist, + albumartist=albumartist, + album=album, + discnumber=discnumber, + tracknumber=tracknumber, + title=title, + date=date, + duration=duration, + musicbrainz_trackid=musicbrainz_trackid, + cur_time=cur_time, + ) -def calculate_scrobbles(status_updates, perc_thresh=0.5, secs_thresh=4*60): - scrobbles, leftovers = [], [] +def calculate_scrobbles( + status_updates: Sequence[Status], + perc_thresh: float = 0.5, + secs_thresh: int = 4*60, +) -> tuple[list[Status], list[Status]]: + + def has_played_enough( + start_ts: float, + end_ts: float, + duration_value: Optional[Union[str, int]], + played_before_pause: float = 0.0, + ) -> bool: + if duration_value is None: + return False + duration = int(duration_value) + total = end_ts-start_ts+played_before_pause + return total/duration>=perc_thresh or total>=secs_thresh + + def equal_tracks(first: Status, second: Status) -> bool: + return first.file==second.file + + def get_prefix_end_exclusive_idx(sus: Sequence[Status]) -> int: + r_su = list(reversed(sus)) + for i, (cur, prv) in enumerate(zip(r_su, r_su[1:])): + if (cur.status==STATUS_STOPPED or not equal_tracks(cur, prv) + or cur.status==prv.status or prv.status==STATUS_STOPPED): + return len(r_su)-i + return 0 # all statuses do not result in a scrobble + + scrobbles: list[Status] = [] + leftovers: list[Status] = [] if not status_updates or len(status_updates)==1: - return scrobbles, status_updates or leftovers + return scrobbles, list(status_updates) # if status updates array has a suffix of playing/paused updates with same # track, then these tracks need to be immediatelly leftovers @@ -321,28 +628,29 @@ def calculate_scrobbles(status_updates, perc_thresh=0.5, secs_thresh=4*60): lsus = sus[:prefix_end] # I am incapable of having simple thoughts. The pause is messing me up. # I use these two variables to scrobble paused tracks. - ptbp = 0 # played time before pausing - ptbp_status = None + played_before_pause = 0.0 + played_before_pause_status: Optional[Status] = None for cur, nxt, nxt2 in it.zip_longest(lsus, lsus[1:], lsus[2:]): - if cur.status in [CmusStatus.stopped, CmusStatus.paused]: + if cur.status in [STATUS_STOPPED, STATUS_PAUSED]: continue if nxt is None: leftovers.append(cur) break - hpe = has_played_enough( + played_enough = has_played_enough( cur.cur_time, nxt.cur_time, cur.duration, - perc_thresh, - secs_thresh, - ptbp=ptbp if ptbp_status and equal_tracks(ptbp_status, cur) else 0) + played_before_pause=played_before_pause + if played_before_pause_status + and equal_tracks(played_before_pause_status, cur) else 0.0, + ) if (not equal_tracks(cur, nxt) - or nxt.status in [CmusStatus.stopped, CmusStatus.playing]): - if hpe: + or nxt.status in [STATUS_STOPPED, STATUS_PLAYING]): + if played_enough: scrobbles.append(cur) - ptbp = 0 - ptbp_status = None + played_before_pause = 0.0 + played_before_pause_status = None continue # files are equal and nxt status paused @@ -351,81 +659,152 @@ def calculate_scrobbles(status_updates, perc_thresh=0.5, secs_thresh=4*60): leftovers.append(nxt) continue - if equal_tracks(cur, nxt2) and nxt2.status==CmusStatus.playing: + if equal_tracks(cur, nxt2) and nxt2.status==STATUS_PLAYING: # playing continued, keeping already played time for next - ptbp += nxt.cur_time-cur.cur_time - ptbp_status = cur if not ptbp_status else ptbp_status + played_before_pause += nxt.cur_time-cur.cur_time + played_before_pause_status = cur if not played_before_pause_status else played_before_pause_status continue # playing did not continue, nxt2 file is not None and it's either a # different file or it's the same file but status is not playing # in this case we just check if played enough otherwise no scrobble - if hpe: - scrobbles.append(ptbp_status or cur) + if played_enough: + scrobbles.append(played_before_pause_status or cur) return scrobbles, leftovers+sus[prefix_end:] -class ScrobblerMethod: - GET_TOKEN = 'auth.gettoken' - GET_SESSION = 'auth.getsession' - NOW_PLAYING = 'track.updateNowPlaying' - SCROBBLE = 'track.scrobble' - - -def update_scrobble_state(db, scrobbler, new_status_update): - sus = db.get_status_updates() - sus.append(new_status_update) - db.save_status_updates([new_status_update]) - scrobbles, leftovers = calculate_scrobbles(sus) - failed_scrobbles = [] - for i in range(0, len(scrobbles), SCROBBLE_BATCH_SIZE): +def run_update_scrobble_state( + env: ScrobblingEnv, + new_status_update: Status, + scrobble_batch_size: int, +) -> None: + env.db.create() + status_updates = env.db.get_status_updates() + status_updates.append(new_status_update) + env.db.save_status_updates([new_status_update]) + scrobbles, leftovers = calculate_scrobbles(status_updates) + failed_scrobbles: list[Status] = [] + for i in range(0, len(scrobbles), scrobble_batch_size): try: - scrobbler.scrobble(scrobbles[i:i+SCROBBLE_BATCH_SIZE]) + env.http.scrobble(scrobbles[i:i+scrobble_batch_size], ) except Exception: - logging.exception('Scrobbling failed') + env.logger.exception('Scrobbling failed') # tracks need to be scrobbled in correct order. If the first # batch fails then other batches need to be left for later too. failed_scrobbles.extend(scrobbles[i:]) break - db.clear() - db.save_status_updates(failed_scrobbles+leftovers) - + env.db.clear() + env.db.save_status_updates(failed_scrobbles+leftovers) -def get_tmp_dir(): - for d in ['TMPDIR', 'TEMP', 'TEMPDIR', 'TMP']: - c = os.environ.get(d) - if c: - return c - return '/tmp' - -TMP_DIR = get_tmp_dir() - - -def setup_logging(log_path): +def setup_logging(log_path: Optional[str]) -> None: + tmp_dir = '/tmp' + for name in ['TMPDIR', 'TEMP', 'TEMPDIR', 'TMP']: + value = os.environ.get(name) + if value is not None: + tmp_dir = value + break logging.basicConfig( - filename=log_path or os.path.join(TMP_DIR, 'cmus_scrobbler.log'), + filename=log_path or os.path.join(tmp_dir, 'cmus_scrobbler.log'), datefmt='%Y-%m-%d %H:%M:%S', - format='%(process)d %(asctime)s %(levelname)s %(name)s %(message)s', - level=logging.DEBUG) + format= + '%(process)d %(asctime)s %(levelname)s %(name)s %(service)s %(message)s', + level=logging.DEBUG, + ) -def get_conf(conf_path): +def get_conf(conf_path: str) -> configparser.ConfigParser: if not os.path.exists(conf_path): raise FileNotFoundError(f'{conf_path} does not exist.') conf = configparser.ConfigParser() - with open(conf_path, 'r') as f: - conf.read_file(f) + with open(conf_path, 'r') as handle: + conf.read_file(handle) return conf -DB_CONNECT_RETRY_ATTEMPTS = 10 -DB_CONNECT_RETRY_SLEEP_SECS = 10 +def read_global_config( + conf: configparser.ConfigParser, + *, + default_db_path: str, + default_log_db: bool, +) -> GlobalConfig: + api_key = conf['global'].get('api_key') + shared_secret = conf['global'].get('shared_secret') + if api_key is None or shared_secret is None: + raise KeyError('Missing api_key/shared_secret in global config.') + return GlobalConfig( + api_key=api_key, + shared_secret=shared_secret, + db_path=conf['global'].get('db_path', default_db_path), + log_db=conf['global'].getboolean('log_db', fallback=default_log_db), + now_playing=conf['global'].getboolean('now_playing', fallback=False), + format_xml=conf['global'].getboolean('format_xml', fallback=False), + log_path=conf['global'].get('log_path'), + ) -def db_connect(db_path, log_db=False): - con = sqlite3.connect(db_path, timeout=DB_CONNECT_TIMEOUT) +def read_service_config( + conf: configparser.ConfigParser, + global_config: GlobalConfig, + section: str, +) -> ServiceConfig: + api_url = conf[section].get('api_url') + auth_url = conf[section].get('auth_url') + if api_url is None or auth_url is None: + raise KeyError(f'Missing api_url/auth_url for {section}.') + api_key = conf[section].get('api_key', global_config.api_key) + shared_secret = conf[section].get('shared_secret', + global_config.shared_secret) + if api_key is None or shared_secret is None: + raise KeyError(f'Missing credentials for {section}.') + return ServiceConfig( + name=section, + api_url=api_url, + auth_url=auth_url, + api_key=api_key, + shared_secret=shared_secret, + session_key=conf[section].get('session_key'), + now_playing=conf[section].getboolean( + 'now_playing', + global_config.now_playing, + ), + format_xml=conf[section].getboolean( + 'format_xml', + global_config.format_xml, + ), + ) + + +def build_app_config( + conf: configparser.ConfigParser, + *, + default_db_path: str, + default_log_db: bool, +) -> AppConfig: + global_config = read_global_config( + conf, + default_db_path=default_db_path, + default_log_db=default_log_db, + ) + services: list[ServiceConfig] = [] + for section in conf.sections(): + if section=='global': + continue + services.append(read_service_config(conf, global_config, section)) + return AppConfig(global_config=global_config, services=services) + + +def db_connect( + db_path: str, + *, + log_db: bool = False, + connect_timeout: int, + retry_attempts: int, + retry_sleep_secs: int, + logger: logging.LoggerAdapter[logging.Logger], +) -> sqlite3.Connection: + con = sqlite3.connect(db_path, timeout=connect_timeout) if log_db: - con.set_trace_callback(logging.debug) + con.set_trace_callback(logger.debug) # BEGIN IMMEDIATE can return SQLITE_BUSY. After it succeeds, no other query # will return SQLITE_BUSY. # Retrying opens the possibility of incorrect event order but it should not @@ -434,12 +813,12 @@ def db_connect(db_path, log_db=False): # running process. That was not the point of this simple script. # 10 retries leaves enough room for scrobble ops to finish and release the # db lock. - for _ in range(DB_CONNECT_RETRY_ATTEMPTS): + for _ in range(retry_attempts): try: con.execute('BEGIN IMMEDIATE') break except sqlite3.OperationalError: - time.sleep(DB_CONNECT_RETRY_SLEEP_SECS) + time.sleep(retry_sleep_secs) else: raise Exception('Could not connect to db.') # when multiple status updates arrive one after another, then @@ -449,71 +828,98 @@ def db_connect(db_path, log_db=False): return con -def get_scrobblers(conf): - api_key = conf['global'].get('api_key') - shared_secret = conf['global'].get('shared_secret') - scrs = [] - for section in conf.sections(): - if section=='global': - continue - scrs.append( - Scrobbler( - section, conf[section]['api_url'], - conf[section].get('api_key', api_key), - conf[section].get('shared_secret', shared_secret), - conf[section].get('session_key'), conf[section].getboolean( - 'now_playing', conf['global'].getboolean('now_playing')), - conf[section].getboolean( - 'format_xml', conf['global'].getboolean('format_xml')))) - return scrs - - -def auth(conf): - api_key = conf['global'].get('api_key') - shared_secret = conf['global'].get('shared_secret') - format_xml = conf['global'].getboolean('format_xml') - for section in conf.sections(): - if section=='global': - continue - if 'session_key' in conf[section]: - print(f'Session key already active for {section}. Skipping...') - continue - try: - conf[section].update( - Scrobbler.auth( - conf[section]['auth_url'], - conf[section]['api_url'], - conf[section].get('api_key', api_key), - conf[section].get('shared_secret', shared_secret), - conf[section].getboolean('format_xml', format_xml), - )) - except Exception: - logging.exception('Authentication failed.') +def run_auth( + http_env: HttpEnv, + conf: configparser.ConfigParser, + service_name: str, + logger: logging.LoggerAdapter[logging.Logger], +) -> configparser.ConfigParser: + try: + conf[service_name].update(http_env.auth()) + except Exception: + logger.exception('Authentication failed.') return conf -def main(): - args, rest = parser.parse_known_args() +def main() -> None: + defaults = AppDefaults( + config_path='~/.config/cmus/cmus_status_scrobbler.ini', + db_path='~/.config/cmus/cmus_status_scrobbler.sqlite3', + db_connect_timeout=300, + db_connect_retry_attempts=10, + db_connect_retry_sleep_secs=10, + scrobble_batch_size=50, + http_user_agent='Mozilla/5.0', + http_default_timeout_secs=10.0, + http_scrobble_timeout_secs=5.0, + ) + parser = build_parser(Args, defaults) + parsed_args, rest = parser.parse_known_args() + args = Args(**vars(parsed_args)) conf_path = args.ini conf = get_conf(conf_path) - setup_logging(args.log_path or conf['global'].get('log_path')) + app_config = build_app_config( + conf, + default_db_path=args.db_path, + default_log_db=args.log_db, + ) + setup_logging(args.log_path or app_config.global_config.log_path) + logger = logging.getLogger('cmus_status_scrobbler') + base_logger = logging.LoggerAdapter(logger, {'service': '-'}) if args.auth: - with open(conf_path, 'w') as f: - auth(conf).write(f) - exit() - status = parse_cmus_status_line(rest) - scrobblers = get_scrobblers(conf) - with db_connect(conf['global'].get('db_path', args.db_path), - log_db=conf['global'].get('log_db', args.log_db)) as con: - logging.info(repr(status)) - for scr in scrobblers: - update_scrobble_state(StatusDB(con, scr.name), scr, status) - for scr in scrobblers: - scr.send_now_playing(status) - - -if __name__=="__main__": - try: - main() - except Exception: - logging.exception('Error happened') + with open(conf_path, 'w') as handle: + for service_config in app_config.services: + service_logger = logging.LoggerAdapter( + logger, + {'service': service_config.name}, + ) + if service_config.session_key is not None: + print( + f'Session key already active for {service_config.name}. Skipping...' + ) + continue + http_env = make_http_env( + service_config=service_config, + defaults=defaults, + session_key=None, + logger=service_logger, + ) + run_auth(http_env, conf, service_config.name, service_logger) + conf.write(handle) + return + status = parse_cmus_status_line(rest, base_logger) + if args.cur_time is not None: + status = status._replace(cur_time=args.cur_time) + with db_connect( + app_config.global_config.db_path, + log_db=app_config.global_config.log_db, + connect_timeout=defaults.db_connect_timeout, + retry_attempts=defaults.db_connect_retry_attempts, + retry_sleep_secs=defaults.db_connect_retry_sleep_secs, + logger=base_logger, + ) as con: + base_logger.info(repr(status)) + for service_config in app_config.services: + service_logger = logging.LoggerAdapter( + logger, + {'service': service_config.name}, + ) + http_env = make_http_env( + service_config=service_config, + defaults=defaults, + session_key=service_config.session_key, + logger=service_logger, + ) + env = make_scrobbling_env( + con=con, + http_env=http_env, + table_name=service_config.name, + logger=service_logger, + ) + run_update_scrobble_state(env, status, + defaults.scrobble_batch_size) + http_env.send_now_playing(status) + + +if __name__=='__main__': + main() diff --git a/requirements-dev.txt b/requirements-dev.txt index 1418124..06fc3d7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,3 +2,6 @@ yapf pycodestyle python-language-server flake8 +mypy +pyright +typing_extensions diff --git a/test_cmus_status_scrobbler.py b/test_cmus_status_scrobbler.py index b7104c2..0532768 100644 --- a/test_cmus_status_scrobbler.py +++ b/test_cmus_status_scrobbler.py @@ -1,57 +1,737 @@ """ -This script tests concurrent invocations of the scrobbler. +End-to-end tests for cmus_status_scrobbler.py using a stub HTTP server. """ +from __future__ import annotations +import io +import json import os +import pickle +import sqlite3 import subprocess +import sys +import tempfile +import threading import unittest -from multiprocessing import Process +import urllib.parse as up +from dataclasses import dataclass +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Callable, Optional -# Assuming cmus_status_scrobbler.py is in the same directory -PYTHON_EXECUTABLE = 'python' # or 'python3' if needed +from cmus_status_scrobbler import STATUS_PLAYING, Status +PYTHON_EXECUTABLE = sys.executable CMUS_STATUS_SCROBBLER_PATH = './cmus_status_scrobbler.py' -INI_PATH = './test.ini' -DB_PATH = './test.sqlite3' -def run_scrobbler(): - subprocess.run([ - PYTHON_EXECUTABLE, CMUS_STATUS_SCROBBLER_PATH, '--ini', INI_PATH, - 'status', 'playing', 'file', '/home/user/Music/song1.mp3', 'artist', - 'Artist A', 'album', 'Album X', 'title', 'Song 1', 'duration', '240' - ]) +@dataclass(frozen=True) +class RequestRecord: + path: str + params: dict[str, list[str]] -class TestCmusStatusScrobblerIntegration(unittest.TestCase): +@dataclass(frozen=True) +class ServerState: + xml: bool + requests: list[RequestRecord] + lock: threading.Lock + make_response: Callable[[str], str] + fail_methods: set[str] - def setUp(self): - # Create a dummy .ini file - with open(INI_PATH, 'w') as f: + +class StubScrobblerServer: + + def __init__(self, + xml: bool = False, + fail_methods: Optional[set[str]] = None) -> None: + state = ServerState( + xml=xml, + requests=[], + lock=threading.Lock(), + make_response=self._make_response, + fail_methods=fail_methods or set(), + ) + + class StubRequestHandler(BaseHTTPRequestHandler): + + def do_POST(self) -> None: + length = int(self.headers.get('Content-Length', '0')) + body = self.rfile.read(length) + params = up.parse_qs(body.decode('utf-8'), + keep_blank_values=True) + record = RequestRecord(path=self.path, params=params) + with state.lock: + state.requests.append(record) + method = params.get('method', [''])[0] + if method in state.fail_methods: + self.send_response(500) + self.end_headers() + return + response = state.make_response(method) + self.send_response(200) + content_type = 'text/xml' if state.xml else 'application/json' + self.send_header('Content-Type', content_type) + self.end_headers() + self.wfile.write(response.encode('utf-8')) + + def log_message(self, format: str, *args: str) -> None: + return + + self._state = state + self._server = ThreadingHTTPServer(('127.0.0.1', 0), + StubRequestHandler) + self.base_url = f'http://127.0.0.1:{self._server.server_address[1]}/' + self._thread = threading.Thread(target=self._server.serve_forever) + self._thread.daemon = True + self._thread.start() + + def _make_response(self, method: str) -> str: + if method=='auth.gettoken': + if self._state.xml: + return 'TEST_TOKEN' + return json.dumps({'token': 'TEST_TOKEN'}) + if method=='auth.getsession': + if self._state.xml: + return ('TEST_SK' + 'tester') + return json.dumps( + {'session': { + 'key': 'TEST_SK', + 'name': 'tester' + }}) + if self._state.xml: + return '' + return json.dumps({}) + + def reset(self) -> None: + with self._state.lock: + self._state.requests.clear() + + def get_requests(self) -> list[RequestRecord]: + with self._state.lock: + return list(self._state.requests) + + def stop(self) -> None: + self._server.shutdown() + self._server.server_close() + self._thread.join(timeout=2) + + +class E2ETestBase(unittest.TestCase): + server: StubScrobblerServer + + def setUp(self) -> None: + self.temp_dir = tempfile.TemporaryDirectory() + self.ini_path = os.path.join(self.temp_dir.name, 'test.ini') + self.db_path = os.path.join(self.temp_dir.name, 'test.sqlite3') + + def tearDown(self) -> None: + self.temp_dir.cleanup() + + def write_ini(self, + base_url: str, + session_key: Optional[str] = 'TEST_SK', + format_xml: bool = False, + now_playing: bool = False) -> None: + with open(self.ini_path, 'w') as f: f.write('[global]\n') - f.write(f'db_path = {DB_PATH}\n') - # Ensure the database file is removed before each test - if os.path.exists(DB_PATH): - os.remove(DB_PATH) - - def tearDown(self): - # Clean up the dummy .ini and database files after each test - if os.path.exists(INI_PATH): - os.remove(INI_PATH) - if os.path.exists(DB_PATH): - os.remove(DB_PATH) - - def test_multiple_invocations_with_empty_status(self): - # Define a function to run cmus_status_scrobbler.py - # Create and start multiple processes - processes = [Process(target=run_scrobbler) for _ in range(5)] - for p in processes: - p.start() - - # Wait for all processes to complete - for p in processes: - p.join() - - # TODO: Add assertions to verify the expected behavior - # For now, we just check that the processes ran without raising exceptions + f.write(f'api_key = TEST_API_KEY\n') + f.write(f'shared_secret = TEST_SHARED_SECRET\n') + f.write(f'db_path = {self.db_path}\n') + f.write(f'now_playing = {str(now_playing).lower()}\n') + f.write(f'format_xml = {str(format_xml).lower()}\n') + f.write('\n') + f.write('[stub]\n') + f.write(f'api_url = {base_url}\n') + f.write(f'auth_url = {base_url}auth\n') + if session_key is not None: + f.write(f'session_key = {session_key}\n') + + def run_scrobbler(self, *status_args: str) -> None: + cmd = [ + PYTHON_EXECUTABLE, + CMUS_STATUS_SCROBBLER_PATH, + '--ini', + self.ini_path, + '--db-path', + self.db_path, + ] + cmd.extend(status_args) + subprocess.run(cmd, check=True) + + def run_status( + self, + cur_time: int, + status: str, + file_name: str, + duration: int, + title: Optional[str] = None, + ) -> None: + self.run_scrobbler( + '--cur-time', + str(cur_time), + 'status', + status, + 'file', + file_name, + 'artist', + 'Artist', + 'title', + title or file_name, + 'duration', + str(duration), + ) + + def read_db_updates(self) -> list[Status]: + if not os.path.exists(self.db_path): + return [] + con = sqlite3.connect(self.db_path) + try: + cur = con.cursor() + cur.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [r[0] for r in cur.fetchall()] + if not tables: + return [] + updates = [] + for table in tables: + cur.execute(f'SELECT * FROM {table}') + for row in cur.fetchall(): + updates.append(self._unpickle_status(row[0])) + return updates + finally: + con.close() + + def _unpickle_status(self, payload: bytes) -> Status: + + class StatusUnpickler(pickle.Unpickler): + + def find_class(self, module: str, name: str) -> type: + if module=='__main__' and name=='Status': + from cmus_status_scrobbler import Status as StatusClass + return StatusClass + loaded = super().find_class(module, name) + if not isinstance(loaded, type): + raise TypeError('Unexpected pickle class.') + return loaded + + loaded = StatusUnpickler(io.BytesIO(payload)).load() + if not isinstance(loaded, Status): + raise TypeError('Unexpected status payload.') + return loaded + + def get_scrobble_tracks(self, requests: list[RequestRecord]) -> list[str]: + tracks = [] + for req in requests: + params = req.params + if params.get('method', [''])[0]!='track.scrobble': + continue + indices = [] + for key in params.keys(): + if key.startswith('track[') and key.endswith(']'): + idx = int(key[6:-1]) + indices.append(idx) + for idx in sorted(indices): + tracks.append(params.get(f'track[{idx}]', [''])[0]) + return tracks + + def get_requests_by_method(self, method: str) -> list[RequestRecord]: + return [ + req for req in self.server.get_requests() + if req.params.get('method', [''])[0]==method + ] + + def assert_param(self, params: dict[str, list[str]], key: str, + expected: str) -> None: + self.assertIn(key, params) + self.assertEqual(expected, params[key][0]) + + def assert_param_present(self, params: dict[str, list[str]], + key: str) -> None: + self.assertIn(key, params) + self.assertTrue(params[key][0]) + + def get_scrobble_items( + self, params: dict[str, list[str]]) -> list[tuple[str, str]]: + indices = [] + for key in params.keys(): + if key.startswith('track[') and key.endswith(']'): + indices.append(int(key[6:-1])) + items = [] + for idx in sorted(indices): + track = params.get(f'track[{idx}]', [''])[0] + timestamp = params.get(f'timestamp[{idx}]', [''])[0] + items.append((track, timestamp)) + return items + + +class TestAuthFlow(E2ETestBase): + + def test_auth_json(self) -> None: + server = StubScrobblerServer(xml=False) + try: + self.write_ini(server.base_url, session_key=None, format_xml=False) + cmd = [ + PYTHON_EXECUTABLE, + CMUS_STATUS_SCROBBLER_PATH, + '--ini', + self.ini_path, + '--auth', + ] + process = subprocess.run( + cmd, + input='\n', + text=True, + check=True, + ) + self.assertEqual(0, process.returncode) + with open(self.ini_path, 'r') as f: + content = f.read() + self.assertIn('session_key = TEST_SK', content) + self.assertIn('username = tester', content) + requests = server.get_requests() + self.assertEqual(2, len(requests)) + token_reqs = [ + req for req in requests + if req.params.get('method', [''])[0]=='auth.gettoken' + ] + session_reqs = [ + req for req in requests + if req.params.get('method', [''])[0]=='auth.getsession' + ] + self.assertEqual(1, len(token_reqs)) + self.assertEqual(1, len(session_reqs)) + token_params = token_reqs[0].params + session_params = session_reqs[0].params + self.assert_param(token_params, 'api_key', 'TEST_API_KEY') + self.assert_param(token_params, 'format', 'json') + self.assert_param(session_params, 'api_key', 'TEST_API_KEY') + self.assert_param(session_params, 'token', 'TEST_TOKEN') + self.assert_param(session_params, 'format', 'json') + self.assert_param_present(session_params, 'api_sig') + finally: + server.stop() + + def test_auth_xml(self) -> None: + server = StubScrobblerServer(xml=True) + try: + self.write_ini(server.base_url, session_key=None, format_xml=True) + cmd = [ + PYTHON_EXECUTABLE, + CMUS_STATUS_SCROBBLER_PATH, + '--ini', + self.ini_path, + '--auth', + ] + process = subprocess.run( + cmd, + input='\n', + text=True, + check=True, + ) + self.assertEqual(0, process.returncode) + with open(self.ini_path, 'r') as f: + content = f.read() + self.assertIn('session_key = TEST_SK', content) + self.assertIn('username = tester', content) + requests = server.get_requests() + self.assertEqual(2, len(requests)) + token_reqs = [ + req for req in requests + if req.params.get('method', [''])[0]=='auth.gettoken' + ] + session_reqs = [ + req for req in requests + if req.params.get('method', [''])[0]=='auth.getsession' + ] + self.assertEqual(1, len(token_reqs)) + self.assertEqual(1, len(session_reqs)) + token_params = token_reqs[0].params + session_params = session_reqs[0].params + self.assert_param(token_params, 'api_key', 'TEST_API_KEY') + self.assertNotIn('format', token_params) + self.assert_param(session_params, 'api_key', 'TEST_API_KEY') + self.assert_param(session_params, 'token', 'TEST_TOKEN') + self.assertNotIn('format', session_params) + self.assert_param_present(session_params, 'api_sig') + finally: + server.stop() + + +class TestScrobbleE2E(E2ETestBase): + + def setUp(self) -> None: + super().setUp() + self.server = StubScrobblerServer(xml=False) + self.write_ini(self.server.base_url, session_key='TEST_SK') + + def tearDown(self) -> None: + self.server.stop() + super().tearDown() + + def test_simple_play_stop(self) -> None: + base = 1000 + self.run_status(base, 'playing', 'A', 5) + self.run_status(base+4, 'stopped', 'A', 5) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual(['A'], tracks) + requests = self.get_requests_by_method('track.scrobble') + self.assertEqual(1, len(requests)) + params = requests[0].params + self.assert_param(params, 'api_key', 'TEST_API_KEY') + self.assert_param(params, 'sk', 'TEST_SK') + self.assert_param(params, 'format', 'json') + self.assert_param_present(params, 'api_sig') + self.assertEqual([('A', str(base))], self.get_scrobble_items(params)) + self.assertEqual([], + self.get_requests_by_method('track.updateNowPlaying')) + self.assertEqual([], self.read_db_updates()) + + def test_repeat(self) -> None: + base = 2000 + self.run_status(base, 'playing', 'A', 5) + self.run_status(base+4, 'playing', 'A', 5) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual(['A'], tracks) + requests = self.get_requests_by_method('track.scrobble') + self.assertEqual(1, len(requests)) + params = requests[0].params + self.assert_param(params, 'api_key', 'TEST_API_KEY') + self.assert_param(params, 'sk', 'TEST_SK') + self.assert_param(params, 'format', 'json') + self.assert_param_present(params, 'api_sig') + self.assertEqual([('A', str(base))], self.get_scrobble_items(params)) + self.assertEqual([], + self.get_requests_by_method('track.updateNowPlaying')) + updates = self.read_db_updates() + self.assertEqual(1, len(updates)) + + def test_play_pause(self) -> None: + base = 3000 + self.run_status(base, 'playing', 'A', 5) + self.run_status(base+4, 'paused', 'A', 5) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual([], tracks) + self.assertEqual([], self.get_requests_by_method('track.scrobble')) + self.assertEqual([], + self.get_requests_by_method('track.updateNowPlaying')) + updates = self.read_db_updates() + self.assertEqual(2, len(updates)) + + def test_play_pause_stopped(self) -> None: + base = 4000 + self.run_status(base, 'playing', 'A', 5) + self.run_status(base+1, 'paused', 'A', 5) + self.run_status(base+20, 'stopped', 'A', 5) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual([], tracks) + self.assertEqual([], self.get_requests_by_method('track.scrobble')) + self.assertEqual([], + self.get_requests_by_method('track.updateNowPlaying')) + self.assertEqual([], self.read_db_updates()) + + def test_play_pause_play_pause_dotdotdot_stopped(self) -> None: + base = 5000 + self.run_status(base, 'playing', 'A', 10) + self.run_status(base+1, 'paused', 'A', 10) + self.run_status(base+100, 'playing', 'A', 10) + self.run_status(base+101, 'paused', 'A', 10) + self.run_status(base+200, 'playing', 'A', 10) + self.run_status(base+201, 'paused', 'A', 10) + self.run_status(base+300, 'playing', 'A', 10) + self.run_status(base+301, 'paused', 'A', 10) + self.run_status(base+400, 'playing', 'A', 10) + self.run_status(base+401, 'paused', 'A', 10) + self.run_status(base+402, 'stopped', 'A', 10) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual(['A'], tracks) + requests = self.get_requests_by_method('track.scrobble') + self.assertEqual(1, len(requests)) + params = requests[0].params + self.assert_param(params, 'api_key', 'TEST_API_KEY') + self.assert_param(params, 'sk', 'TEST_SK') + self.assert_param(params, 'format', 'json') + self.assert_param_present(params, 'api_sig') + self.assertEqual([('A', str(base))], self.get_scrobble_items(params)) + self.assertEqual([], + self.get_requests_by_method('track.updateNowPlaying')) + self.assertEqual([], self.read_db_updates()) + + def test_play_pause_stopped_enough_time_played(self) -> None: + base = 6000 + self.run_status(base, 'playing', 'A', 5) + self.run_status(base+3, 'paused', 'A', 5) + self.run_status(base+20, 'stopped', 'A', 5) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual(['A'], tracks) + requests = self.get_requests_by_method('track.scrobble') + self.assertEqual(1, len(requests)) + params = requests[0].params + self.assert_param(params, 'api_key', 'TEST_API_KEY') + self.assert_param(params, 'sk', 'TEST_SK') + self.assert_param(params, 'format', 'json') + self.assert_param_present(params, 'api_sig') + self.assertEqual([('A', str(base))], self.get_scrobble_items(params)) + self.assertEqual([], + self.get_requests_by_method('track.updateNowPlaying')) + self.assertEqual([], self.read_db_updates()) + + def test_normal_player_status(self) -> None: + base = 7000 + self.run_status(base, 'playing', 'A', 1) + self.run_status(base+2, 'playing', 'B', 1) + self.run_status(base+3, 'playing', 'C', 1) + self.run_status(base+5, 'playing', 'D', 1) + self.run_status(base+7, 'playing', 'E', 1) + self.run_status(base+9, 'playing', 'F', 1) + self.run_status(base+11, 'stopped', 'F', 1) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual(['A', 'B', 'C', 'D', 'E', 'F'], tracks) + requests = self.get_requests_by_method('track.scrobble') + items = [] + for req in requests: + params = req.params + self.assert_param(params, 'api_key', 'TEST_API_KEY') + self.assert_param(params, 'sk', 'TEST_SK') + self.assert_param(params, 'format', 'json') + self.assert_param_present(params, 'api_sig') + items.extend(self.get_scrobble_items(params)) + self.assertEqual([ + ('A', str(base)), + ('B', str(base+2)), + ('C', str(base+3)), + ('D', str(base+5)), + ('E', str(base+7)), + ('F', str(base+9)), + ], items) + self.assertEqual([], + self.get_requests_by_method('track.updateNowPlaying')) + self.assertEqual([], self.read_db_updates()) + + def test_pause_play_suffix_leftovers(self) -> None: + base = 8000 + self.run_status(base, 'playing', 'A', 1) + self.run_status(base+2, 'playing', 'B', 1) + self.run_status(base+3, 'playing', 'C', 1) + self.run_status(base+5, 'playing', 'D', 1) + self.run_status(base+7, 'playing', 'E', 1) + self.run_status(base+9, 'playing', 'F', 1) + self.run_status(base+11, 'stopped', 'F', 1) + self.run_status(base+13, 'playing', '*', 10) + self.run_status(base+15, 'paused', '*', 10) + self.run_status(base+17, 'playing', '*', 10) + self.run_status(base+21, 'paused', '*', 10) + self.run_status(base+23, 'playing', '*', 10) + self.run_status(base+25, 'paused', '*', 10) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual(['A', 'B', 'C', 'D', 'E', 'F'], tracks) + requests = self.get_requests_by_method('track.scrobble') + items = [] + for req in requests: + params = req.params + self.assert_param(params, 'api_key', 'TEST_API_KEY') + self.assert_param(params, 'sk', 'TEST_SK') + self.assert_param(params, 'format', 'json') + self.assert_param_present(params, 'api_sig') + items.extend(self.get_scrobble_items(params)) + self.assertEqual([ + ('A', str(base)), + ('B', str(base+2)), + ('C', str(base+3)), + ('D', str(base+5)), + ('E', str(base+7)), + ('F', str(base+9)), + ], items) + self.assertEqual([], + self.get_requests_by_method('track.updateNowPlaying')) + updates = self.read_db_updates() + self.assertEqual(6, len(updates)) + + def test_scrobble_criteria(self) -> None: + base = 9000 + for idx, stop_status in enumerate(['playing', 'stopped']): + self.server.reset() + if os.path.exists(self.db_path): + os.remove(self.db_path) + offset = idx*100 + self.run_status(base+offset, 'playing', 'A', 10) + if stop_status=='playing': + self.run_status(base+offset+10, 'playing', 'A', 10) + else: + self.run_status(base+offset+10, 'stopped', 'A', 10) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual(['A'], tracks) + requests = self.get_requests_by_method('track.scrobble') + self.assertEqual(1, len(requests)) + params = requests[0].params + self.assert_param(params, 'api_key', 'TEST_API_KEY') + self.assert_param(params, 'sk', 'TEST_SK') + self.assert_param(params, 'format', 'json') + self.assert_param_present(params, 'api_sig') + self.assertEqual([('A', str(base+offset))], + self.get_scrobble_items(params)) + self.assertEqual( + [], self.get_requests_by_method('track.updateNowPlaying')) + if stop_status=='playing': + self.assertEqual(1, len(self.read_db_updates())) + else: + self.assertEqual([], self.read_db_updates()) + self.server.reset() + if os.path.exists(self.db_path): + os.remove(self.db_path) + self.run_status(base+200, 'playing', 'A', 10) + self.run_status(base+210, 'playing', 'B', 10) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual(['A'], tracks) + requests = self.get_requests_by_method('track.scrobble') + self.assertEqual(1, len(requests)) + params = requests[0].params + self.assert_param(params, 'api_key', 'TEST_API_KEY') + self.assert_param(params, 'sk', 'TEST_SK') + self.assert_param(params, 'format', 'json') + self.assert_param_present(params, 'api_sig') + self.assertEqual([('A', str(base+200))], + self.get_scrobble_items(params)) + self.assertEqual([], + self.get_requests_by_method('track.updateNowPlaying')) + self.assertEqual(1, len(self.read_db_updates())) + + def test_xml_scrobble(self) -> None: + self.server.stop() + self.server = StubScrobblerServer(xml=True) + self.write_ini(self.server.base_url, + session_key='TEST_SK', + format_xml=True) + base = 10000 + self.run_status(base, 'playing', 'A', 5) + self.run_status(base+4, 'stopped', 'A', 5) + tracks = self.get_scrobble_tracks(self.server.get_requests()) + self.assertEqual(['A'], tracks) + requests = self.get_requests_by_method('track.scrobble') + self.assertEqual(1, len(requests)) + params = requests[0].params + self.assert_param(params, 'api_key', 'TEST_API_KEY') + self.assert_param(params, 'sk', 'TEST_SK') + self.assertNotIn('format', params) + self.assert_param_present(params, 'api_sig') + self.assertEqual([('A', str(base))], self.get_scrobble_items(params)) + self.assertEqual([], + self.get_requests_by_method('track.updateNowPlaying')) + + +class TestNowPlayingFailure(E2ETestBase): + + def test_now_playing_failure_keeps_db(self) -> None: + server = StubScrobblerServer( + xml=False, + fail_methods={'track.updateNowPlaying'}, + ) + try: + self.write_ini(server.base_url, + session_key='TEST_SK', + format_xml=False, + now_playing=True) + base = 11000 + self.run_status(base, 'playing', 'A', 5) + requests = server.get_requests() + update_reqs = [ + req for req in requests + if req.params.get('method', [''])[0]=='track.updateNowPlaying' + ] + scrobble_reqs = [ + req for req in requests + if req.params.get('method', [''])[0]=='track.scrobble' + ] + self.assertEqual(1, len(update_reqs)) + self.assertEqual([], scrobble_reqs) + updates = self.read_db_updates() + self.assertEqual(1, len(updates)) + self.assertEqual(STATUS_PLAYING, updates[0].status) + self.assertEqual('A', updates[0].file) + finally: + server.stop() + + +class TestMultiServiceFailure(E2ETestBase): + + def test_middle_service_scrobble_failure_isolated(self) -> None: + server1 = StubScrobblerServer(xml=False) + server2 = StubScrobblerServer( + xml=False, + fail_methods={'track.scrobble'}, + ) + server3 = StubScrobblerServer(xml=False) + servers = [server1, server2, server3] + try: + + def write_ini_multi() -> None: + with open(self.ini_path, 'w') as handle: + handle.write('[global]\n') + handle.write('api_key = TEST_API_KEY\n') + handle.write('shared_secret = TEST_SHARED_SECRET\n') + handle.write(f'db_path = {self.db_path}\n') + handle.write('now_playing = true\n') + handle.write('format_xml = false\n\n') + for idx, server in enumerate(servers, start=1): + name = f'svc{idx}' + handle.write(f'[{name}]\n') + handle.write(f'api_url = {server.base_url}\n') + handle.write(f'auth_url = {server.base_url}auth\n') + handle.write('session_key = TEST_SK\n') + handle.write('now_playing = true\n\n') + + def get_requests_by_method(server: StubScrobblerServer, + method: str) -> list[RequestRecord]: + return [ + req for req in server.get_requests() + if req.params.get('method', [''])[0]==method + ] + + def read_table_updates(table_name: str) -> list[Status]: + if not os.path.exists(self.db_path): + return [] + con = sqlite3.connect(self.db_path) + try: + cur = con.cursor() + table = f'status_updates_{table_name}' + cur.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table, ), + ) + if cur.fetchone() is None: + return [] + cur.execute(f'SELECT * FROM {table}') + return [ + self._unpickle_status(row[0]) + for row in cur.fetchall() + ] + finally: + con.close() + + write_ini_multi() + base = 12000 + self.run_status(base, 'playing', 'A', 5) + self.run_status(base+4, 'stopped', 'A', 5) + + self.assertEqual( + 1, len(get_requests_by_method(server1, 'track.scrobble'))) + self.assertEqual( + 1, len(get_requests_by_method(server2, 'track.scrobble'))) + self.assertEqual( + 1, len(get_requests_by_method(server3, 'track.scrobble'))) + self.assertEqual( + 1, + len(get_requests_by_method(server1, 'track.updateNowPlaying'))) + self.assertEqual( + 1, + len(get_requests_by_method(server2, 'track.updateNowPlaying'))) + self.assertEqual( + 1, + len(get_requests_by_method(server3, 'track.updateNowPlaying'))) + + self.assertEqual([], read_table_updates('svc1')) + self.assertEqual(1, len(read_table_updates('svc2'))) + self.assertEqual([], read_table_updates('svc3')) + finally: + for server in servers: + server.stop() if __name__=='__main__': diff --git a/tests.py b/tests.py index bd36f44..2fde7ae 100644 --- a/tests.py +++ b/tests.py @@ -1,82 +1,102 @@ """ This script is a test suite for the cmus_status_scrobbler.py script. """ +from __future__ import annotations import datetime import itertools as it +import logging import os import sqlite3 import unittest -from collections import namedtuple +from collections.abc import Callable, Iterable from cmus_status_scrobbler import ( - CmusStatus, - StatusDB, + STATUS_PAUSED, + STATUS_PLAYING, + STATUS_STOPPED, + HttpEnv, + ScrobblingEnv, + Status, calculate_scrobbles, - update_scrobble_state, + make_db_env, + run_update_scrobble_state, ) -def secs(n): +def secs(n: int) -> datetime.timedelta: return datetime.timedelta(seconds=n) -_SS = namedtuple('_SS', 'cur_time duration file status') +def SS(*, cur_time: datetime.datetime, duration: int, file: str, + status: str) -> Status: + return make_status(cur_time=cur_time, + duration=duration, + file=file, + status=status) -def SS(*, cur_time, duration, file, status): - return _SS(cur_time=cur_time.timestamp(), - duration=duration, - file=file, - status=status) +def utcnow() -> datetime.datetime: + return datetime.datetime.now(datetime.timezone.utc) -def utcnow(): - return datetime.datetime.now(datetime.timezone.utc) +def make_status( + *, + cur_time: datetime.datetime, + duration: int, + file: str, + status: str, +) -> Status: + return Status( + status=status, + file=file, + artist=None, + albumartist=None, + album=None, + discnumber=1, + tracknumber=None, + title=None, + date=None, + duration=duration, + musicbrainz_trackid=None, + cur_time=cur_time.timestamp(), + ) class TestCalculateScrobbles(unittest.TestCase): - def assertArrayEqual(self, ar1, ar2): + def assertArrayEqual(self, ar1: Iterable[Status], + ar2: Iterable[Status]) -> None: for expected, actual in it.zip_longest(ar1, ar2): self.assertEqual(expected, actual) - def test_simple_play_stop(self): + def test_simple_play_stop(self) -> None: d = utcnow() ss = [ - SS(cur_time=d, duration=5, file='A', status=CmusStatus.playing), - SS(cur_time=d+secs(4), - duration=5, - file='A', - status=CmusStatus.stopped) + SS(cur_time=d, duration=5, file='A', status=STATUS_PLAYING), + SS(cur_time=d+secs(4), duration=5, file='A', status=STATUS_STOPPED) ] scrobbles, leftovers = calculate_scrobbles(ss) # track when started playing - self.assertEqual(CmusStatus.playing, scrobbles[0].status) + self.assertEqual(STATUS_PLAYING, scrobbles[0].status) self.assertEqual(ss[0], scrobbles[0]) - def test_repeat(self): + def test_repeat(self) -> None: d = utcnow() ss = [ - SS(cur_time=d, duration=5, file='A', status=CmusStatus.playing), - SS(cur_time=d+secs(4), - duration=5, - file='A', - status=CmusStatus.playing) + SS(cur_time=d, duration=5, file='A', status=STATUS_PLAYING), + SS(cur_time=d+secs(4), duration=5, file='A', status=STATUS_PLAYING) ] scrobbles, leftovers = calculate_scrobbles(ss) # track when started playing - self.assertEqual(CmusStatus.playing, scrobbles[0].status) + self.assertEqual(STATUS_PLAYING, scrobbles[0].status) self.assertEqual(ss[0], scrobbles[0]) self.assertEqual(ss[1], leftovers[0]) - def test_play_pause(self): + def test_play_pause(self) -> None: d = utcnow() ss = [ - SS(cur_time=d, duration=5, file='A', status=CmusStatus.playing), - SS(cur_time=d+secs(4), - duration=5, - file='A', - status=CmusStatus.paused) + SS(cur_time=d, duration=5, file='A', status=STATUS_PLAYING), + SS(cur_time=d+secs(4), duration=5, file='A', status=STATUS_PAUSED) ] scrobbles, leftovers = calculate_scrobbles(ss) self.assertEqual([], scrobbles) @@ -84,68 +104,66 @@ def test_play_pause(self): self.assertEqual(ss[0], leftovers[0]) self.assertEqual(ss[1], leftovers[1]) - def test_play_pause_stopped(self): + def test_play_pause_stopped(self) -> None: d = utcnow() ss = [ - SS(cur_time=d, duration=5, file='A', status=CmusStatus.playing), + SS(cur_time=d, duration=5, file='A', status=STATUS_PLAYING), SS( cur_time=d+secs(1), # not enough time duration=5, file='A', - status=CmusStatus.paused), + status=STATUS_PAUSED), SS(cur_time=d+secs(20), duration=5, file='A', - status=CmusStatus.stopped) + status=STATUS_STOPPED) ] scrobbles, leftovers = calculate_scrobbles(ss) self.assertEqual([], scrobbles) self.assertEqual([], leftovers) - def test_play_pause_play_pause_dotdotdot_stopped(self): + def test_play_pause_play_pause_dotdotdot_stopped(self) -> None: d = utcnow() ss = [ - SS(cur_time=d, duration=10, file='A', status=CmusStatus.playing), - SS(cur_time=d+secs(1), - duration=10, - file='A', - status=CmusStatus.paused), + SS(cur_time=d, duration=10, file='A', status=STATUS_PLAYING), + SS(cur_time=d+secs(1), duration=10, file='A', + status=STATUS_PAUSED), SS(cur_time=d+secs(100), duration=10, file='A', - status=CmusStatus.playing), + status=STATUS_PLAYING), SS(cur_time=d+secs(101), duration=10, file='A', - status=CmusStatus.paused), + status=STATUS_PAUSED), SS(cur_time=d+secs(200), duration=10, file='A', - status=CmusStatus.playing), + status=STATUS_PLAYING), SS(cur_time=d+secs(201), duration=10, file='A', - status=CmusStatus.paused), + status=STATUS_PAUSED), SS(cur_time=d+secs(300), duration=10, file='A', - status=CmusStatus.playing), + status=STATUS_PLAYING), SS(cur_time=d+secs(301), duration=10, file='A', - status=CmusStatus.paused), + status=STATUS_PAUSED), SS(cur_time=d+secs(400), duration=10, file='A', - status=CmusStatus.playing), + status=STATUS_PLAYING), SS(cur_time=d+secs(401), duration=10, file='A', - status=CmusStatus.paused), + status=STATUS_PAUSED), SS(cur_time=d+secs(402), duration=10, file='A', - status=CmusStatus.stopped) + status=STATUS_STOPPED) ] scrobbles, leftovers = calculate_scrobbles(ss[:6]) self.assertEqual([], scrobbles) @@ -160,110 +178,90 @@ def test_play_pause_play_pause_dotdotdot_stopped(self): self.assertEqual(1, len(scrobbles)) self.assertEqual(ss[0], scrobbles[0]) - def test_play_pause_stopped_enough_time_played(self): + def test_play_pause_stopped_enough_time_played(self) -> None: d = utcnow() ss = [ - SS(cur_time=d, duration=5, file='A', status=CmusStatus.playing), + SS(cur_time=d, duration=5, file='A', status=STATUS_PLAYING), SS( cur_time=d+secs(3), # enough time played duration=5, file='A', - status=CmusStatus.paused), + status=STATUS_PAUSED), SS(cur_time=d+secs(20), duration=5, file='A', - status=CmusStatus.stopped) + status=STATUS_STOPPED) ] scrobbles, leftovers = calculate_scrobbles(ss) self.assertEqual([], leftovers) self.assertEqual(ss[0], scrobbles[0]) - def test_normal_player_status(self): + def test_normal_player_status(self) -> None: d = utcnow() ss = [ - SS(cur_time=d, duration=1, file='A', status=CmusStatus.playing), - SS(cur_time=d+secs(2), - duration=1, - file='B', - status=CmusStatus.playing), - SS(cur_time=d+secs(3), - duration=1, - file='C', - status=CmusStatus.playing), - SS(cur_time=d+secs(5), - duration=1, - file='D', - status=CmusStatus.playing), - SS(cur_time=d+secs(7), - duration=1, - file='E', - status=CmusStatus.playing), - SS(cur_time=d+secs(9), - duration=1, - file='F', - status=CmusStatus.playing), + SS(cur_time=d, duration=1, file='A', status=STATUS_PLAYING), + SS(cur_time=d+secs(2), duration=1, file='B', + status=STATUS_PLAYING), + SS(cur_time=d+secs(3), duration=1, file='C', + status=STATUS_PLAYING), + SS(cur_time=d+secs(5), duration=1, file='D', + status=STATUS_PLAYING), + SS(cur_time=d+secs(7), duration=1, file='E', + status=STATUS_PLAYING), + SS(cur_time=d+secs(9), duration=1, file='F', + status=STATUS_PLAYING), SS(cur_time=d+secs(11), duration=1, file='F', - status=CmusStatus.stopped), + status=STATUS_STOPPED), ] scrobbles, leftovers = calculate_scrobbles(ss) self.assertEqual(6, len(scrobbles)) self.assertEqual([], leftovers) self.assertArrayEqual(ss[:-1], scrobbles) - def test_pause_play_suffix_leftovers(self): + def test_pause_play_suffix_leftovers(self) -> None: d = utcnow() ss = [ - SS(cur_time=d, duration=1, file='A', status=CmusStatus.playing), - SS(cur_time=d+secs(2), - duration=1, - file='B', - status=CmusStatus.playing), - SS(cur_time=d+secs(3), - duration=1, - file='C', - status=CmusStatus.playing), - SS(cur_time=d+secs(5), - duration=1, - file='D', - status=CmusStatus.playing), - SS(cur_time=d+secs(7), - duration=1, - file='E', - status=CmusStatus.playing), - SS(cur_time=d+secs(9), - duration=1, - file='F', - status=CmusStatus.playing), + SS(cur_time=d, duration=1, file='A', status=STATUS_PLAYING), + SS(cur_time=d+secs(2), duration=1, file='B', + status=STATUS_PLAYING), + SS(cur_time=d+secs(3), duration=1, file='C', + status=STATUS_PLAYING), + SS(cur_time=d+secs(5), duration=1, file='D', + status=STATUS_PLAYING), + SS(cur_time=d+secs(7), duration=1, file='E', + status=STATUS_PLAYING), + SS(cur_time=d+secs(9), duration=1, file='F', + status=STATUS_PLAYING), SS(cur_time=d+secs(11), duration=1, file='F', - status=CmusStatus.stopped), + status=STATUS_STOPPED), SS(cur_time=d+secs(13), duration=10, file='*', - status=CmusStatus.playing), + status=STATUS_PLAYING), SS(cur_time=d+secs(15), duration=10, file='*', - status=CmusStatus.paused), + status=STATUS_PAUSED), SS(cur_time=d+secs(17), duration=10, file='*', - status=CmusStatus.playing), + status=STATUS_PLAYING), SS(cur_time=d+secs(21), duration=10, file='*', - status=CmusStatus.paused), + status=STATUS_PAUSED), SS(cur_time=d+secs(23), duration=10, file='*', - status=CmusStatus.playing), + status=STATUS_PLAYING), SS(cur_time=d+secs(25), duration=10, file='*', - status=CmusStatus.paused), + status=STATUS_PAUSED), ] scrobbles, leftovers = calculate_scrobbles(ss) self.assertEqual(6, len(leftovers)) @@ -272,85 +270,213 @@ def test_pause_play_suffix_leftovers(self): # stopped will not be in leftovers self.assertArrayEqual(ss[7:], leftovers) - def test_scrobble_criteria(self): + def test_scrobble_criteria(self) -> None: # Should stop when: # 1. stopped # 2. playing again # 3. different file d = utcnow() - a = dict(cur_time=d, duration=10, file='A', status=CmusStatus.playing) for stop in [ dict(file='B'), - dict(status=CmusStatus.playing), - dict(status=CmusStatus.stopped) + dict(status=STATUS_PLAYING), + dict(status=STATUS_STOPPED), ]: - ss = [SS(**a), SS(**{**a, **stop, 'cur_time': d+secs(10)})] + file_name = stop.get('file', 'A') + status_value = stop.get('status', STATUS_PLAYING) + ss = [ + SS(cur_time=d, duration=10, file='A', status=STATUS_PLAYING), + SS(cur_time=d+secs(10), + duration=10, + file=file_name, + status=status_value), + ] scrobbles, leftovers = calculate_scrobbles(ss) self.assertEqual(1, len(scrobbles)) self.assertEqual(ss[0], scrobbles[0]) DB_FILE = 'test.sqlite3' +DB_TABLE_NAME = 'test_table_name' class TestStatusDB(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.con = sqlite3.connect(DB_FILE) + self.db_env = make_db_env(con=self.con, table_name=DB_TABLE_NAME) + self.db_env.create() - def tearDown(self): + def tearDown(self) -> None: self.con.close() os.remove(DB_FILE) - def build_db(self): - return StatusDB(self.con, 'test_table_name') - - def assertArrayEqual(self, ar1, ar2): + def assertArrayEqual(self, ar1: Iterable[Status], + ar2: Iterable[Status]) -> None: for expected, actual in it.zip_longest(ar1, ar2): self.assertEqual(expected, actual) - def update_scrobble_state(self, db, new_su): - sc = namedtuple('SC', 'scrobble')(scrobble=lambda x: None) - update_scrobble_state(db, sc, new_su) - - def test_update(self): + def update_scrobble_state(self, new_su: Status) -> None: + + def noop_scrobble(_status_updates: list[Status]) -> None: + return None + + def noop_send_now_playing(_status: Status) -> None: + return None + + def noop_auth() -> dict[str, str]: + raise AssertionError('auth should not be called in DB tests.') + + http_env = HttpEnv( + auth=noop_auth, + scrobble=noop_scrobble, + send_now_playing=noop_send_now_playing, + ) + env = ScrobblingEnv( + http=http_env, + db=self.db_env, + logger=logging.LoggerAdapter(logging.getLogger('test'), + {'service': 'test'}), + ) + run_update_scrobble_state(env, new_su, 50) + + def update_scrobble_state_with_scrobble( + self, + new_su: Status, + scrobble: Callable[[list[Status]], None], + batch_size: int, + ) -> None: + + def noop_send_now_playing(_status: Status) -> None: + return None + + def noop_auth() -> dict[str, str]: + raise AssertionError('auth should not be called in DB tests.') + + http_env = HttpEnv( + auth=noop_auth, + scrobble=scrobble, + send_now_playing=noop_send_now_playing, + ) + env = ScrobblingEnv( + http=http_env, + db=self.db_env, + logger=logging.LoggerAdapter(logging.getLogger('test'), + {'service': 'test'}), + ) + run_update_scrobble_state(env, new_su, batch_size) + + def test_update(self) -> None: d = datetime.datetime.now() sus = [ - SS(cur_time=d, duration=5, file='A', status=CmusStatus.playing), - SS(cur_time=d+secs(1), - duration=5, - file='A', - status=CmusStatus.paused) + make_status(cur_time=d, + duration=5, + file='A', + status=STATUS_PLAYING), + make_status(cur_time=d+secs(1), + duration=5, + file='A', + status=STATUS_PAUSED), ] - new_su = SS(cur_time=d+secs(3), - duration=5, - file='A', - status=CmusStatus.playing) + new_su = make_status(cur_time=d+secs(3), + duration=5, + file='A', + status=STATUS_PLAYING) with self.con: - db = self.build_db() - db.save_status_updates(sus) - self.assertArrayEqual(sus, db.get_status_updates()) - self.update_scrobble_state(db, new_su) - n_sus = db.get_status_updates() + self.db_env.save_status_updates(sus) + self.assertArrayEqual(sus, self.db_env.get_status_updates()) + self.update_scrobble_state(new_su) + n_sus = self.db_env.get_status_updates() self.assertArrayEqual(sus+[new_su], n_sus) - def test_scrobble_update(self): + def test_scrobble_update(self) -> None: # some tracks will scrobble and will no longer be stored d = datetime.datetime.now() sus = [ - SS(cur_time=d, duration=10, file='B', status=CmusStatus.playing) + make_status(cur_time=d, + duration=10, + file='B', + status=STATUS_PLAYING), + ] + new_su = make_status(cur_time=d+secs(10), + duration=5, + file='A', + status=STATUS_PLAYING) + with self.con: + self.db_env.save_status_updates(sus) + self.update_scrobble_state(new_su) + n_sus = self.db_env.get_status_updates() + self.assertArrayEqual([new_su], n_sus) + + def test_scrobble_batching_calls(self) -> None: + d = datetime.datetime.now() + sus = [ + make_status(cur_time=d, + duration=1, + file='A', + status=STATUS_PLAYING), + make_status(cur_time=d+secs(2), + duration=1, + file='B', + status=STATUS_PLAYING), + make_status(cur_time=d+secs(4), + duration=1, + file='C', + status=STATUS_PLAYING), ] - new_su = SS(cur_time=d+secs(10), - duration=5, - file='A', - status=CmusStatus.playing) + new_su = make_status(cur_time=d+secs(6), + duration=1, + file='D', + status=STATUS_PLAYING) + batches: list[list[Status]] = [] + + def record_scrobble(status_updates: list[Status]) -> None: + batches.append(status_updates) + with self.con: - db = self.build_db() - db.save_status_updates(sus) - self.update_scrobble_state(db, new_su) - n_sus = db.get_status_updates() + self.db_env.save_status_updates(sus) + self.update_scrobble_state_with_scrobble(new_su, + record_scrobble, + batch_size=2) + self.assertEqual([2, 1], [len(batch) for batch in batches]) + n_sus = self.db_env.get_status_updates() self.assertArrayEqual([new_su], n_sus) + def test_scrobble_partial_batch_failure_keeps_remaining(self) -> None: + d = datetime.datetime.now() + sus = [ + make_status(cur_time=d, + duration=1, + file='A', + status=STATUS_PLAYING), + make_status(cur_time=d+secs(2), + duration=1, + file='B', + status=STATUS_PLAYING), + make_status(cur_time=d+secs(4), + duration=1, + file='C', + status=STATUS_PLAYING), + ] + new_su = make_status(cur_time=d+secs(6), + duration=1, + file='D', + status=STATUS_PLAYING) + call_count = 0 + + def scrobble_with_failure(status_updates: list[Status]) -> None: + nonlocal call_count + call_count += 1 + if call_count==2: + raise RuntimeError('Batch failure') + + with self.con: + self.db_env.save_status_updates(sus) + self.update_scrobble_state_with_scrobble(new_su, + scrobble_with_failure, + batch_size=2) + n_sus = self.db_env.get_status_updates() + self.assertArrayEqual([sus[2], new_su], n_sus) + if __name__=='__main__': unittest.main()