diff --git a/.github/workflows/tests-macos.yml b/.github/workflows/tests-macos.yml new file mode 100644 index 0000000..1f35015 --- /dev/null +++ b/.github/workflows/tests-macos.yml @@ -0,0 +1,43 @@ +name: Tests (macOS) + +on: + push: + branches: + - main + pull_request: + branches: + - main + +permissions: + contents: read + +jobs: + test: + runs-on: macos-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + echo "$HOME/.cargo/bin" >> $GITHUB_PATH + shell: bash + + - name: Configure Git for private repos + run: | + git config --global url."https://x-access-token:${{ secrets.PRIVATE_REPO_TOKEN }}@github.com/".insteadOf "https://github.com/" + + - name: Install dependencies + run: | + uv sync + + - name: Run tests + run: | + make run-tests diff --git a/.github/workflows/tests.yml b/.github/workflows/tests-ubuntu.yml similarity index 53% rename from .github/workflows/tests.yml rename to .github/workflows/tests-ubuntu.yml index aa9bcf6..e60562c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests-ubuntu.yml @@ -1,4 +1,4 @@ -name: Tests +name: Tests (Ubuntu) on: push: @@ -13,48 +13,41 @@ permissions: jobs: test: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.11'] - + runs-on: ubuntu-latest + steps: - name: Checkout code uses: actions/checkout@v4 - - - name: Set up Python ${{ matrix.python-version }} + + - name: Set up Python 3.11 uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} - + python-version: '3.11' + - name: Install uv run: | curl -LsSf https://astral.sh/uv/install.sh | sh echo "$HOME/.cargo/bin" >> $GITHUB_PATH shell: bash - + - name: Configure Git for private repos run: | git config --global url."https://x-access-token:${{ secrets.PRIVATE_REPO_TOKEN }}@github.com/".insteadOf "https://github.com/" - + - name: Install dependencies run: | uv sync - + - name: Run tests run: | - uv run pytest src/voxkit/storage/test/ -v --tb=short - - - name: Run linting (Ubuntu only) - if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' + make run-tests + + - name: Run linting run: | - uv run ruff check src/ + make lint-check continue-on-error: true - - - name: Run type checking (Ubuntu only) - if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' + + - name: Run type checking run: | - uv run mypy src/ + make mypy-check continue-on-error: true diff --git a/.github/workflows/tests-windows.yml b/.github/workflows/tests-windows.yml new file mode 100644 index 0000000..9fca071 --- /dev/null +++ b/.github/workflows/tests-windows.yml @@ -0,0 +1,43 @@ +name: Tests (Windows) + +on: + push: + branches: + - main + pull_request: + branches: + - main + +permissions: + contents: read + +jobs: + test: + runs-on: windows-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + echo "$HOME/.cargo/bin" >> $GITHUB_PATH + shell: bash + + - name: Configure Git for private repos + run: | + git config --global url."https://x-access-token:${{ secrets.PRIVATE_REPO_TOKEN }}@github.com/".insteadOf "https://github.com/" + + - name: Install dependencies + run: | + uv sync + + - name: Run tests + run: | + make run-tests diff --git a/README.md b/README.md index 915280b..2abab69 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,9 @@
diff --git a/_frozen_patch.py b/_frozen_patch.py index 009ec10..e7d5742 100644 --- a/_frozen_patch.py +++ b/_frozen_patch.py @@ -7,11 +7,11 @@ if getattr(sys, "frozen", False): # We're running in a PyInstaller bundle print("[PATCH] Applying frozen app patches...") - + # Patch 1: Disable typeguard runtime checking try: import typeguard - + # Replace typechecked decorator with a no-op def _noop_decorator(func=None, **kwargs): """No-op decorator that just returns the function unchanged""" @@ -20,12 +20,12 @@ def _noop_decorator(func=None, **kwargs): return lambda f: f # Called without arguments: @typechecked return func - - typeguard.typechecked = _noop_decorator + + typeguard.typechecked = _noop_decorator # type: ignore[assignment] print("[PATCH] Disabled typeguard runtime checking") except ImportError: pass - + # Patch 2: Fix inspect.getsource to not fail in frozen apps import inspect @@ -70,7 +70,7 @@ def _patched_getsourcefile(object): try: result = _original_getsourcefile(object) # Check if the file actually exists - if result and not __import__('os').path.exists(result): + if result and not __import__("os").path.exists(result): return None return result except (OSError, TypeError): diff --git a/hooks/hook-nltk.py b/hooks/hook-nltk.py index dc6350f..2e0cf5f 100644 --- a/hooks/hook-nltk.py +++ b/hooks/hook-nltk.py @@ -5,7 +5,7 @@ """ # Don't collect any NLTK data files - they will be downloaded at runtime -datas = [] +datas: list[tuple[str, str, str]] = [] # Exclude data collection -excludedimports = [] +excludedimports: list[str] = [] diff --git a/pyproject.toml b/pyproject.toml index cde8ec2..7722cb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dev = [ "pre-commit>=4.3.0", "ruff>=0.14.0", "mypy>=1.18.2", + "types-PyYAML>=6.0.0", "genbadge[coverage]>=1.1.3", ] docs = [ @@ -109,11 +110,15 @@ indent-style = "space" [tool.mypy] python_version = "3.11" ignore_missing_imports = true +check_untyped_defs = false +warn_return_any = true +warn_unused_configs = true exclude = [ "main.py", "build.py", "example_startup_script.py", - "test_imports.py" + "test_imports.py", + "tests/" ] diff --git a/src/voxkit/analyzers/default_analyzer.py b/src/voxkit/analyzers/default_analyzer.py index 4582d91..d9b7383 100644 --- a/src/voxkit/analyzers/default_analyzer.py +++ b/src/voxkit/analyzers/default_analyzer.py @@ -125,6 +125,9 @@ def paintEvent(self, event): color = QColor("#3498db") h, s, lightness, a = color.getHsl() + assert ( + h is not None and s is not None and lightness is not None and a is not None + ) ratio = count / max_count new_l = int(lightness + (220 - lightness) * (1 - ratio)) color.setHsl(h, s, min(new_l, 240), a) diff --git a/src/voxkit/config.py b/src/voxkit/config.py index 43dce6b..4b099be 100644 --- a/src/voxkit/config.py +++ b/src/voxkit/config.py @@ -4,7 +4,6 @@ from voxkit.services.mfa import download_acoustic_model from voxkit.storage import models from voxkit.storage.config import MODELS_ROOT -from voxkit.storage.models import download_and_copy_huggingface_model from voxkit.storage.utils import get_storage_root AppName = "VoxKit" @@ -94,10 +93,10 @@ def startup_routine(): # else: # print("[STARTUP] Failed to download W2TG model.") - try: import nltk - nltk.download('averaged_perceptron_tagger_eng') + + nltk.download("averaged_perceptron_tagger_eng") except Exception as e: print(f"[STARTUP] Failed to download NLTK resources. Error: {e}") diff --git a/src/voxkit/config/startup_config.py b/src/voxkit/config/startup_config.py index 405bd33..6bd2491 100644 --- a/src/voxkit/config/startup_config.py +++ b/src/voxkit/config/startup_config.py @@ -49,6 +49,7 @@ def startup_routine(): if not success: print(f"[STARTUP] Failed to create model metadata for {model}. {metadata}") continue + assert not isinstance(metadata, str) model_dest = metadata.get("model_path") if not model_dest: print(f"[STARTUP] Model path not found in metadata for {model}.") @@ -81,6 +82,7 @@ def startup_routine(): if not success: print(f"[STARTUP] Failed to create model metadata. {metadata}") return + assert not isinstance(metadata, str) model_dest = metadata.get("model_path") if not model_dest: print("[STARTUP] Model path not found in metadata.") @@ -94,10 +96,10 @@ def startup_routine(): else: print("[STARTUP] Failed to download W2TG model.") - try: import nltk - nltk.download('averaged_perceptron_tagger_eng') + + nltk.download("averaged_perceptron_tagger_eng") except Exception as e: print(f"[STARTUP] Failed to download NLTK resources. Error: {e}") diff --git a/src/voxkit/engines/__init__.py b/src/voxkit/engines/__init__.py index 0290be2..16421ab 100644 --- a/src/voxkit/engines/__init__.py +++ b/src/voxkit/engines/__init__.py @@ -97,7 +97,7 @@ def get_tool_providers(self, tool: ToolType) -> dict[str, AlignmentEngine]: w2tg = W2TGEngine(id="W2TGENGINE") mfa = MFAEngine(id="MFAENGINE") faster_whisper = FasterWhisperEngine(id="FASTERWHISPERENGINE") -engines = EngineManager({ mfa.id: mfa, faster_whisper.id: faster_whisper, w2tg.id: w2tg }) +engines = EngineManager({mfa.id: mfa, faster_whisper.id: faster_whisper, w2tg.id: w2tg}) __all__ = [ "engines", diff --git a/src/voxkit/engines/base.py b/src/voxkit/engines/base.py index 4502872..ae34871 100644 --- a/src/voxkit/engines/base.py +++ b/src/voxkit/engines/base.py @@ -203,7 +203,8 @@ def _load_json(self, path: Path | str) -> dict: return {} with open(path, "r", encoding="utf-8") as f: - return json.load(f) + result: dict = json.load(f) + return result def _get_default_settings(self, cfg: Any) -> dict: """ diff --git a/src/voxkit/engines/mfa_engine.py b/src/voxkit/engines/mfa_engine.py index 36dffd8..3035779 100644 --- a/src/voxkit/engines/mfa_engine.py +++ b/src/voxkit/engines/mfa_engine.py @@ -106,15 +106,16 @@ def __init__(self, id: str | None = None): id=id, ) - def align(self, dataset_id: str, model_id: str) -> None: print(f"Aligning with MFA using model: {model_id}") model_metadata = models.get_model_metadata(self.id, model_id) dataset_metadata = datasets.get_dataset_metadata(dataset_id) + if dataset_metadata is None: + raise ValueError(f"Dataset '{dataset_id}' not found.") - corpus_path = None + corpus_path: Path | None = None if bool(dataset_metadata["cached"]): corpus_path = datasets._get_dataset_root(dataset_id) @@ -130,6 +131,10 @@ def align(self, dataset_id: str, model_id: str) -> None: dataset_id=dataset_id, ) + if not result: + raise ValueError(f"Alignment creation failed: {msg}") + + assert not isinstance(msg, str) alignment_output_path = msg["tg_path"] print( @@ -172,6 +177,11 @@ def train_aligner( model_name=new_model_id, ) + if not success: + raise ValueError(f"Failed to create model entry: {msg}") + + assert not isinstance(msg, str) + # ========= TEMP FIX FOR MFA MODEL EXTENSION ======== # MFA models use .model extension, so we need to adjust the model path accordingly # This should ideally be handled in the storage/models.py create_model function @@ -182,7 +192,7 @@ def train_aligner( new_metadata = msg new_model_path = new_metadata["model_path"] if str(new_model_path).endswith(".model"): - new_model_path = str(new_model_path).split(".model")[0] + ".zip" + new_model_path = Path(str(new_model_path).split(".model")[0] + ".zip") new_metadata["model_path"] = new_model_path # ================================================== @@ -190,13 +200,14 @@ def train_aligner( model_metadata_path = Path(new_metadata["model_path"]).parent / "voxkit_model.json" # Make metadata dict serializable - for key in new_metadata: - if isinstance(new_metadata[key], Path): - new_metadata[key] = str(new_metadata[key]) + serializable: dict[str, object] = dict(new_metadata) + for key in serializable: + if isinstance(serializable[key], Path): + serializable[key] = str(serializable[key]) with open(model_metadata_path, "w") as f: import json - json.dump(new_metadata, f, indent=4) + json.dump(serializable, f, indent=4) # ================================================== @@ -209,9 +220,6 @@ def train_aligner( # ================================================== - if not success: - raise ValueError(f"Failed to create model entry: {msg}") - new_model_path = new_metadata["model_path"] print( diff --git a/src/voxkit/engines/w2tg_engine.py b/src/voxkit/engines/w2tg_engine.py index 28898c4..8134edb 100644 --- a/src/voxkit/engines/w2tg_engine.py +++ b/src/voxkit/engines/w2tg_engine.py @@ -137,10 +137,11 @@ def align(self, dataset_id: str, model_id: str) -> None: ) print(f"Alignment creation result: {result}, message: {msg}") - if result is False: + if not result: print(f"Alignment creation failed: {msg}") return + assert not isinstance(msg, str) alignment_meta = msg dataset_meta = datasets.get_dataset_metadata(dataset_id) model_meta = models.get_model_metadata(self.id, model_id) @@ -191,6 +192,7 @@ def train_aligner( if not success: raise ValueError(f"Failed to create model entry: {message}") + assert not isinstance(message, str) model_meta = message model_path = Path(model_meta["model_path"]) data_path = Path(model_meta["data_path"]) @@ -235,8 +237,9 @@ def train_aligner( ) except Exception as e: print(f"Training failed: {e}") - # CLean up model entry on failure - models.delete_model(engine_id=self.id, model_id=new_model_actual_id) + # Clean up model entry on failure + if new_model_actual_id is not None: + models.delete_model(engine_id=self.id, model_id=new_model_actual_id) raise e def _validate_align_settings(self, settings: dict) -> bool: diff --git a/src/voxkit/gui/__init__.py b/src/voxkit/gui/__init__.py index 557a38f..d5d0eea 100644 --- a/src/voxkit/gui/__init__.py +++ b/src/voxkit/gui/__init__.py @@ -155,6 +155,7 @@ } """ + class AlignmentGUI(QMainWindow): def __init__( self, diff --git a/src/voxkit/gui/components/column_dropdown.py b/src/voxkit/gui/components/column_dropdown.py index 198e93a..3967cec 100644 --- a/src/voxkit/gui/components/column_dropdown.py +++ b/src/voxkit/gui/components/column_dropdown.py @@ -22,14 +22,20 @@ def __init__(self, parent=None): table_view = QTableView() table_view.setSelectionBehavior(QTableView.SelectionBehavior.SelectRows) table_view.setSelectionMode(QTableView.SelectionMode.SingleSelection) - table_view.verticalHeader().hide() - table_view.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Stretch) + vertical_header = table_view.verticalHeader() + if vertical_header is not None: + vertical_header.hide() + horizontal_header = table_view.horizontalHeader() + if horizontal_header is not None: + horizontal_header.setSectionResizeMode(QHeaderView.ResizeMode.Stretch) table_view.setShowGrid(False) self.setView(table_view) # Optional: make the popup wider - self.view().setMinimumWidth(400) + view = self.view() + if view is not None: + view.setMinimumWidth(400) def set_data(self, rows, headers=None, placeholder=None): """ @@ -71,12 +77,15 @@ def set_data(self, rows, headers=None, placeholder=None): def current_id(self): """Get the ID of the currently selected row.""" - index = self.model().index(self.currentIndex(), 0) + model = self.model() + if model is None: + return None + index = model.index(self.currentIndex(), 0) if not index.isValid(): return None - return self.model().data(index, Qt.ItemDataRole.UserRole) + return model.data(index, Qt.ItemDataRole.UserRole) if __name__ == "__main__": diff --git a/src/voxkit/gui/components/csv_viewer_dialog.py b/src/voxkit/gui/components/csv_viewer_dialog.py index 3bcc9de..e193491 100644 --- a/src/voxkit/gui/components/csv_viewer_dialog.py +++ b/src/voxkit/gui/components/csv_viewer_dialog.py @@ -49,7 +49,7 @@ def __init__(self, csv_path: str, parent=None, visualization=None): self.blur_effect.setBlurRadius(10) parent.setGraphicsEffect(self.blur_effect) - self.parent = parent + self._parent_widget = parent self._init_ui() if not self.visualization: self._load_csv() @@ -119,12 +119,15 @@ def _load_csv(self): # Auto-resize columns header = self.table.horizontalHeader() - for i in range(len(headers)): - header.setSectionResizeMode(i, QHeaderView.ResizeMode.ResizeToContents) + if header is not None: + for i in range(len(headers)): + header.setSectionResizeMode(i, QHeaderView.ResizeMode.ResizeToContents) - # Make last column stretch - if len(headers) > 0: - header.setSectionResizeMode(len(headers) - 1, QHeaderView.ResizeMode.Stretch) + # Make last column stretch + if len(headers) > 0: + header.setSectionResizeMode( + len(headers) - 1, QHeaderView.ResizeMode.Stretch + ) # Update stats self.stats_label.setText(f"✅ {len(data)} rows × {len(headers)} columns") @@ -135,14 +138,14 @@ def _load_csv(self): def closeEvent(self, event): """Handle dialog close event to remove blur effect.""" print("Dialog closed, removing blur effect from parent") - if self.parent: + if self._parent_widget: print("Removing blur effect from parent") - self.parent.setGraphicsEffect(None) + self._parent_widget.setGraphicsEffect(None) event.accept() def reject(self): """Handle dialog rejection to remove blur effect.""" - if self.parent: + if self._parent_widget: print("Removing blur effect from parent") - self.parent.setGraphicsEffect(None) + self._parent_widget.setGraphicsEffect(None) super().reject() diff --git a/src/voxkit/gui/components/grip_splitter.py b/src/voxkit/gui/components/grip_splitter.py index 565748f..7470099 100644 --- a/src/voxkit/gui/components/grip_splitter.py +++ b/src/voxkit/gui/components/grip_splitter.py @@ -67,9 +67,7 @@ def paintEvent(self, event): else: x = center_x y = center_y + (i * dot_spacing) - painter.drawEllipse( - x - dot_radius, y - dot_radius, dot_radius * 2, dot_radius * 2 - ) + painter.drawEllipse(x - dot_radius, y - dot_radius, dot_radius * 2, dot_radius * 2) painter.end() diff --git a/src/voxkit/gui/components/huggingface_button.py b/src/voxkit/gui/components/huggingface_button.py index b992bbb..665cc62 100644 --- a/src/voxkit/gui/components/huggingface_button.py +++ b/src/voxkit/gui/components/huggingface_button.py @@ -9,7 +9,8 @@ from PyQt6.QtCore import Qt from PyQt6.QtWidgets import QHBoxLayout, QLabel, QPushButton, QWidget -from voxkit.gui.styles import Labels, Buttons + +from voxkit.gui.styles import Buttons, Labels class HuggingFaceButton(QPushButton): @@ -43,7 +44,6 @@ def _setup_ui(self, title): self.setCursor(Qt.CursorShape.PointingHandCursor) - # Example usage if __name__ == "__main__": import sys diff --git a/src/voxkit/gui/components/loading_dialog.py b/src/voxkit/gui/components/loading_dialog.py index 5a39743..87e1c61 100644 --- a/src/voxkit/gui/components/loading_dialog.py +++ b/src/voxkit/gui/components/loading_dialog.py @@ -114,11 +114,13 @@ def center_on_screen(self): """Center the dialog on the primary screen.""" from PyQt6.QtGui import QGuiApplication - screen = QGuiApplication.primaryScreen().geometry() - dialog_rect = self.frameGeometry() - center_point = screen.center() - dialog_rect.moveCenter(center_point) - self.move(dialog_rect.topLeft()) + screen = QGuiApplication.primaryScreen() + if screen is not None: + screen_geometry = screen.geometry() + dialog_rect = self.frameGeometry() + center_point = screen_geometry.center() + dialog_rect.moveCenter(center_point) + self.move(dialog_rect.topLeft()) def _update_spinner(self): """Update the spinner animation.""" @@ -136,10 +138,13 @@ def update_message(self, message: str): Args: message: The new message to display """ - if self.layout() and self.layout().itemAt(0): - label = self.layout().itemAt(0).widget() - if isinstance(label, QLabel): - label.setText(message) + layout = self.layout() + if layout and layout.itemAt(0): + item = layout.itemAt(0) + if item: + label = item.widget() + if isinstance(label, QLabel): + label.setText(message) def close_gracefully(self): """Close the dialog with a fade-out animation.""" diff --git a/src/voxkit/gui/components/model_selection_panel.py b/src/voxkit/gui/components/model_selection_panel.py index df244b6..034de41 100644 --- a/src/voxkit/gui/components/model_selection_panel.py +++ b/src/voxkit/gui/components/model_selection_panel.py @@ -18,8 +18,8 @@ ) from voxkit.gui.components import MultiColumnComboBox +from voxkit.gui.styles import Containers, Labels from voxkit.storage import models -from voxkit.gui.styles import Labels, Containers class ModelSelectionPanel(QGroupBox): @@ -52,10 +52,10 @@ def __init__( self.engines_dict = engines_dict self.info_text = info_text self.placeholder = placeholder - self.selected_engine = None + self.selected_engine: str | None = None self.engine_dropdowns: dict[str, MultiColumnComboBox] = {} self.engine_radios: dict[str, QRadioButton] = {} - self.mode_button_group = None + self.mode_button_group: QButtonGroup if title: self.setTitle(title) @@ -200,8 +200,8 @@ def get_selected_model_id(self) -> str | None: """ if self.selected_engine is not None and self.selected_engine in self.engine_dropdowns: - - return self.engine_dropdowns[self.selected_engine].current_id() + model_id = self.engine_dropdowns[self.selected_engine].current_id() + return model_id if isinstance(model_id, str) or model_id is None else None return None def reload_models(self): diff --git a/src/voxkit/gui/frameworks/categorical_table/categorical_table.py b/src/voxkit/gui/frameworks/categorical_table/categorical_table.py index fb47469..6f7a588 100644 --- a/src/voxkit/gui/frameworks/categorical_table/categorical_table.py +++ b/src/voxkit/gui/frameworks/categorical_table/categorical_table.py @@ -25,7 +25,7 @@ ) from voxkit.gui.components import HuggingFaceButton -from voxkit.gui.styles import Buttons, Colors, Containers, Labels +from voxkit.gui.styles import Buttons, Containers, Labels class CategoricalTableWidget(QWidget): @@ -259,7 +259,7 @@ def update_display(self): # Determine columns to show if not self.columns_shown: # Auto-detect columns from first few items - all_keys = set() + all_keys: set[str] = set() for item in category_data[:5]: # Sample first 5 items if isinstance(item, dict): all_keys.update(item.keys()) @@ -299,16 +299,17 @@ def update_display(self): # Configure column widths for optimal stretching header = self.table_widget.horizontalHeader() - # Make data columns resize to contents or stretch - for i in range(len(self.columns_shown)): - if i == 0: - # First column: resize to contents - header.setSectionResizeMode(i, QHeaderView.ResizeMode.ResizeToContents) - else: - # Other data columns: stretch - header.setSectionResizeMode(i, QHeaderView.ResizeMode.Stretch) - # Actions column: fixed width - header.setSectionResizeMode(len(self.columns_shown), QHeaderView.ResizeMode.Fixed) + if header is not None: + # Make data columns resize to contents or stretch + for i in range(len(self.columns_shown)): + if i == 0: + # First column: resize to contents + header.setSectionResizeMode(i, QHeaderView.ResizeMode.ResizeToContents) + else: + # Other data columns: stretch + header.setSectionResizeMode(i, QHeaderView.ResizeMode.Stretch) + # Actions column: fixed width + header.setSectionResizeMode(len(self.columns_shown), QHeaderView.ResizeMode.Fixed) self.table_widget.setColumnWidth(len(self.columns_shown), 80) def view_item_details(self, row_index): @@ -406,7 +407,7 @@ def deselect_all(self): def get_selected_items(self): """Get list of selected items""" - selected = [] + selected: list[dict] = [] if not self.category_keys: return selected @@ -414,11 +415,13 @@ def get_selected_items(self): category_data = self.data[current_category] # Get selected rows from table - selected_rows = self.table_widget.selectionModel().selectedRows() - for index in selected_rows: - row = index.row() - if row < len(category_data): - selected.append(category_data[row]) + selection_model = self.table_widget.selectionModel() + if selection_model is not None: + selected_rows = selection_model.selectedRows() + for index in selected_rows: + row = index.row() + if row < len(category_data): + selected.append(category_data[row]) return selected diff --git a/src/voxkit/gui/frameworks/settings_modal/generic.py b/src/voxkit/gui/frameworks/settings_modal/generic.py index bde4ce6..d8a65bc 100644 --- a/src/voxkit/gui/frameworks/settings_modal/generic.py +++ b/src/voxkit/gui/frameworks/settings_modal/generic.py @@ -121,7 +121,7 @@ def _setup_overlay(self, parent) -> None: parent: Parent widget to apply blur effect to. """ try: - main_window = parent.parent + main_window = parent.parent() if main_window is None: return @@ -132,7 +132,7 @@ def _setup_overlay(self, parent) -> None: # Apply blur effect blur_effect = QGraphicsBlurEffect() blur_effect.setBlurRadius(5) - parent.parent.setGraphicsEffect(blur_effect) + parent.parent().setGraphicsEffect(blur_effect) overlay.deleteLater() except (AttributeError, ImportError): @@ -281,7 +281,7 @@ def _create_field_widget(self, config: FieldConfig) -> QWidget: Raises: ValueError: If field_type is not recognized. """ - widget: Union[QSpinBox, QDoubleSpinBox, QLineEdit, QComboBox] + widget: Union[QSpinBox, QDoubleSpinBox, QLineEdit, QComboBox, ToggleSwitch] if config.field_type == FieldType.SPINBOX: widget = self._create_spinbox(config) elif config.field_type == FieldType.DOUBLE_SPINBOX: diff --git a/src/voxkit/gui/pages/datasets/datasets_page.py b/src/voxkit/gui/pages/datasets/datasets_page.py index 0355feb..0cd3537 100644 --- a/src/voxkit/gui/pages/datasets/datasets_page.py +++ b/src/voxkit/gui/pages/datasets/datasets_page.py @@ -34,6 +34,8 @@ from voxkit.gui.styles import Buttons, Containers, Labels from voxkit.gui.workers import DatasetRegistrationWorker from voxkit.storage import alignments, datasets +from voxkit.storage.alignments import AlignmentMetadata +from voxkit.storage.datasets import DatasetMetadata class DatasetsPage(QWidget): @@ -42,8 +44,8 @@ class DatasetsPage(QWidget): def __init__(self, parent: QWidget | None = None): super().__init__(parent) self.parent_window = parent - self.registration_worker = None - self.selected_dataset: dict | None = None + self.registration_worker: DatasetRegistrationWorker | None = None + self.selected_dataset: str | None = None self.init_ui() self.refresh_datasets() @@ -179,7 +181,7 @@ def on_import(self): QMessageBox.warning(self, "No Destination Selected", "Please select a destination.") return - success, message = datasets.import_dataset(dir_path) + success, message = datasets.import_dataset(Path(dir_path)) if success: QMessageBox.information(self, "Success", message) @@ -275,11 +277,12 @@ def _create_list_section(self): # Configure table header = self.dataset_table.horizontalHeader() - header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) - header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) - for i in range(2, 6): - header.setSectionResizeMode(i, QHeaderView.ResizeMode.ResizeToContents) - header.setSectionResizeMode(6, QHeaderView.ResizeMode.Fixed) + if header is not None: + header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + for i in range(2, 6): + header.setSectionResizeMode(i, QHeaderView.ResizeMode.ResizeToContents) + header.setSectionResizeMode(6, QHeaderView.ResizeMode.Fixed) self.dataset_table.setColumnWidth(6, 100) self.dataset_table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows) @@ -365,11 +368,12 @@ def _create_alignments_panel(self): # Configure alignments table align_header = self.alignments_table.horizontalHeader() - align_header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) - align_header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) - align_header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) - align_header.setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) - align_header.setSectionResizeMode(4, QHeaderView.ResizeMode.Fixed) + if align_header is not None: + align_header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + align_header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + align_header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + align_header.setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) + align_header.setSectionResizeMode(4, QHeaderView.ResizeMode.Fixed) self.alignments_table.setColumnWidth(4, 150) # Disable selection and editing @@ -429,7 +433,7 @@ def _set_alignments_blur(self, blurred: bool): self.alignments_content.setGraphicsEffect(None) self.alignments_content.setEnabled(True) - def _load_alignments(self, dataset_id: datasets.DatasetMetadata["id"]): + def _load_alignments(self, dataset_id: str) -> None: """Load alignments for the selected dataset""" print(f"Loading alignments for dataset ID: {dataset_id}") @@ -525,7 +529,7 @@ def _on_alignment_cell_clicked(self, row: int, col: int): f"Model clicked: {model_name}\n\nModel details will be shown here.", ) - def _create_alignment_action_buttons(self, alignment: dict): + def _create_alignment_action_buttons(self, alignment: alignments.AlignmentMetadata) -> QWidget: """Create action buttons for an alignment row""" widget = QWidget() layout = QHBoxLayout(widget) @@ -545,6 +549,7 @@ def _create_alignment_action_buttons(self, alignment: dict): view_btn = QPushButton("View") view_btn.setStyleSheet(button_style) view_btn.clicked.connect(lambda: self._view_alignment(alignment)) + layout.addWidget(view_btn) return widget @@ -553,15 +558,15 @@ def convert_alignments(self, dataset_name: str) -> list: data_set_alignments = alignments.list_alignments(dataset_name) return data_set_alignments - def _view_alignment(self, alignment: dict): + def _view_alignment(self, alignment: AlignmentMetadata): """View alignment details""" # TODO: Implement alignment details view QMessageBox.information( self, "Alignment Details", - f"Engine: {alignment['engine']}\n" - f"Model: {alignment['model']}\n" - f"Date: {alignment['date_aligned']}\n" + f"Engine: {alignment['engine_id']}\n" + f"Model: {alignment['model_metadata']['id']}\n" + f"Date: {alignment['alignment_date']}\n" f"Status: {alignment['status']}", ) @@ -574,7 +579,7 @@ def _export_alignment(self, alignment: dict): f"Export functionality for alignment '{alignment['model']}' will be implemented.", ) - def _delete_alignment(self, alignment: dict): + def _delete_alignment(self, alignment: AlignmentMetadata): """Delete an alignment""" reply = QMessageBox.question( self, @@ -585,6 +590,8 @@ def _delete_alignment(self, alignment: dict): ) if reply == QMessageBox.StandardButton.Yes: + if self.selected_dataset is None: + return success, msg = alignments.delete_alignment( dataset_id=self.selected_dataset, alignment_id=alignment["id"] ) @@ -657,7 +664,7 @@ def open_registration_dialog(self): label="De-identified", field_type=FieldType.CHECKBOX, default_value=False, - tooltip="Has the dataset been de-identified to remove personally identifiable information?", + tooltip="Has personally identifiable information been removed?", ), FieldConfig( name="transcribed", @@ -741,17 +748,16 @@ def registration_complete(self, success, message): def refresh_datasets(self): """Refresh the dataset list""" - # Show empty label if no datasets, otherwise show table - if not datasets: + metadata_list = datasets.list_datasets_metadata() + + if not metadata_list: self.dataset_table.hide() self.empty_label.show() self.empty_label.raise_() return - else: - self.empty_label.hide() - self.dataset_table.show() - metadata_list = datasets.list_datasets_metadata() + self.empty_label.hide() + self.dataset_table.show() self.dataset_table.setRowCount(len(metadata_list)) @@ -790,7 +796,7 @@ def refresh_datasets(self): actions_widget = self._create_dataset_action_buttons(meta) self.dataset_table.setCellWidget(index, 6, actions_widget) - def _create_dataset_action_buttons(self, dataset_meta: dict): + def _create_dataset_action_buttons(self, dataset_meta: DatasetMetadata): """Create action buttons for a dataset row. Args: @@ -815,7 +821,7 @@ def _create_dataset_action_buttons(self, dataset_meta: dict): return widget - def _view_dataset_details(self, dataset_meta: dict): + def _view_dataset_details(self, dataset_meta: DatasetMetadata | dict): """View dataset analysis CSV details. Args: diff --git a/src/voxkit/gui/pages/datasets/utils.py b/src/voxkit/gui/pages/datasets/utils.py index 696cc23..a14beeb 100644 --- a/src/voxkit/gui/pages/datasets/utils.py +++ b/src/voxkit/gui/pages/datasets/utils.py @@ -8,6 +8,8 @@ - **on_delete**: Handle delete button click for selected dataset """ +from pathlib import Path + from PyQt6.QtWidgets import QFileDialog, QMessageBox from voxkit.storage import datasets @@ -30,7 +32,7 @@ def on_export(self): QMessageBox.warning(self, "No Destination Selected", "Please select a destination.") return - success, message = datasets.export_dataset(self.selected_dataset["id"], dir_path) + success, message = datasets.export_dataset(self.selected_dataset["id"], Path(dir_path)) if success: QMessageBox.information(self, "Success", message) else: diff --git a/src/voxkit/gui/pages/models/import_dialog.py b/src/voxkit/gui/pages/models/import_dialog.py index e7ebf12..a325942 100644 --- a/src/voxkit/gui/pages/models/import_dialog.py +++ b/src/voxkit/gui/pages/models/import_dialog.py @@ -41,7 +41,7 @@ def __init__( """ self.on_import_callback = on_import or self._placeholder_import self.engine_id = engine_id - self.parent = parent + self._parent_widget = parent # Define fields fields = [ @@ -70,7 +70,7 @@ def __init__( ) super().__init__( - parent=self.parent, + parent=self._parent_widget, config=config, ) @@ -98,8 +98,11 @@ def _placeholder_import(self, model_path: str): ) if not success: QMessageBox.critical(self, "Import Failed", f"Failed to create model entry: {message}") + return + if isinstance(message, dict): + dest_model_path = str(message["model_path"]) else: - dest_model_path = message["model_path"] + return result = download_and_copy_huggingface_model(model_path, destination=dest_model_path) if result is None: diff --git a/src/voxkit/gui/pages/models/models_page.py b/src/voxkit/gui/pages/models/models_page.py index a4a1420..83dd377 100644 --- a/src/voxkit/gui/pages/models/models_page.py +++ b/src/voxkit/gui/pages/models/models_page.py @@ -21,6 +21,7 @@ ) from voxkit.gui.workers import ModelRegistrationWorker from voxkit.storage import models +from voxkit.storage.models import ModelMetadata from .import_dialog import ImportModelDialog from .utils import handle_delete, handle_export, handle_import @@ -34,11 +35,11 @@ class ManageAlignersWidget(CategoricalTableWidget): """ def __init__(self, parent=None): - self.parent = parent + self._parent_widget = parent self.data = {} self.registration_worker = None - def refresh_models_function() -> dict[str, list[dict[Any, Any]]]: + def refresh_models_function() -> dict[str, list[ModelMetadata]]: try: model_dict = {} for engine in self.get_engines(): @@ -70,7 +71,7 @@ def delete_models_function(category: str, items: list[dict[Any, Any]]) -> tuple[ delete_function=delete_models_function, columns_shown=["name", "download_date", "id"], huggingface_callback=self.on_huggingface_browse, - parent=self.parent, + parent=self._parent_widget, ) self.setWindowTitle("Model Manager") @@ -98,6 +99,7 @@ def get_engines(self) -> list: if engine.has_tool("align") or engine.has_tool("train"): filtered_engines.append(engine_id) return filtered_engines + def showEvent(self, event): """Refresh models when the widget is shown. @@ -173,9 +175,8 @@ def scrub_training_runs(self, mode, items: dict): def reload_models(self): """Reload models in the dropdowns""" - - w2tg_models = models.list_models("W2TGENGINE") - self.set_items("W2TGENGINE", w2tg_models) + for engine in self.get_engines(): + self.set_items(engine, models.list_models(engine)) def open_import_dialog(self, category): print(f"Opening import dialog for category: {category}") @@ -185,7 +186,7 @@ def open_import_dialog(self, category): print(f"Importing {category} Model from {path}") self.reload_models() # Clean up - self.parent.setGraphicsEffect(None) + self._parent_widget.setGraphicsEffect(None) def open_registration_dialog(self): """Open the model registration settings dialog""" diff --git a/src/voxkit/gui/pages/models/utils.py b/src/voxkit/gui/pages/models/utils.py index ebad374..1e91a60 100644 --- a/src/voxkit/gui/pages/models/utils.py +++ b/src/voxkit/gui/pages/models/utils.py @@ -20,7 +20,7 @@ from voxkit.storage import models -def handle_import(parent_widget, current_category: str): +def handle_import(parent_widget, current_category: str) -> tuple[bool, str]: """ Handle importing models into the storage. @@ -110,9 +110,10 @@ def handle_export( source_path = models._get_model_root(current_category, item["id"]) else: failed_items.append(f"{item} (invalid item format)") + continue - if not source_path.exists(): - failed_items.append(f"{str(source_path)} (source path does not exist)") + if source_path is None or not source_path.exists(): + failed_items.append(f"{source_path} (source path does not exist)") continue # Determine destination @@ -193,9 +194,7 @@ def create_export_handler(widget, data): """ return lambda folder_name, selected_items: handle_export( widget, - folder_name, selected_items, - data, widget.category_keys[widget.current_category_index], ) diff --git a/src/voxkit/gui/pages/pipeline/base_stacker.py b/src/voxkit/gui/pages/pipeline/base_stacker.py index eb4e181..e1227ef 100644 --- a/src/voxkit/gui/pages/pipeline/base_stacker.py +++ b/src/voxkit/gui/pages/pipeline/base_stacker.py @@ -15,13 +15,13 @@ class BaseStacker(QWidget): """Base class for pipeline stacker widgets. - + This class provides common functionality for all stackers including: - Standard layout and margins - Header with title and optional settings button - Status label management - Reload methods for datasets and models - + Subclasses should implement: - build_ui(): Create the main content of the stacker - get_title(): Return the stacker title (optional) @@ -30,57 +30,57 @@ class BaseStacker(QWidget): - reload_models(): Reload model data (optional) - reload_datasets(): Reload dataset data (optional) """ - - def __init__(self, parent=None): + + def __init__(self, parent: QWidget | None = None) -> None: """Initialize the base stacker. - + Args: parent: Parent widget, typically the main window """ super().__init__(parent) - self.parent = parent - self.status_label = None - self.main_layout = None - self.content_layout = None + self._parent_widget = parent + self.status_label: QLabel | None = None + self.main_layout: QVBoxLayout + self.content_layout: QVBoxLayout self.init_ui() - + def init_ui(self): """Initialize the standard UI structure.""" self.setMinimumWidth(600) self.main_layout = QVBoxLayout(self) self.main_layout.setSpacing(15) self.main_layout.setContentsMargins(30, 30, 30, 30) - + # Create header if title is provided if self.get_title(): self._create_header() self.main_layout.addSpacing(20) - + # Create content layout for subclass to populate self.content_layout = QVBoxLayout() self.content_layout.setSpacing(10) self.main_layout.addLayout(self.content_layout) - + # Call subclass to build their specific UI self.build_ui() - + # Add status label at the bottom if stacker wants it if self.has_status_label(): self._create_status_label() - + # Add stretch at the end self.main_layout.addStretch() - + def _create_header(self): """Create the standard header with title and optional settings button.""" header_layout = QHBoxLayout() - + # Title title = QLabel(self.get_title()) title.setStyleSheet(Labels.PAGE_TITLE) header_layout.addWidget(title) header_layout.addStretch() - + # Settings button (if needed) if self.has_settings(): settings_btn = QPushButton("⚙️") @@ -88,96 +88,96 @@ def _create_header(self): settings_btn.setStyleSheet(Buttons.ICON) settings_btn.clicked.connect(self.on_settings) header_layout.addWidget(settings_btn) - + self.main_layout.addLayout(header_layout) - + def _create_status_label(self): """Create the standard status label.""" self.status_label = QLabel("Ready") self.status_label.setStyleSheet(Labels.STATUS_READY) self.status_label.setAlignment(Qt.AlignmentFlag.AlignCenter) self.main_layout.addWidget(self.status_label) - + def set_status(self, message: str, status_type: str = "ready"): """Set the status label text and styling. - + Args: message: Status message to display status_type: One of "ready", "working", "success", "error" """ if not self.status_label: return - + status_styles = { "ready": Labels.STATUS_READY, "working": Labels.STATUS_WORKING, "success": Labels.STATUS_SUCCESS, "error": Labels.STATUS_ERROR, } - + self.status_label.setText(message) self.status_label.setStyleSheet(status_styles.get(status_type, Labels.STATUS_READY)) - + # Methods for subclasses to override - + def build_ui(self): """Build the stacker-specific UI components. - + This method is called during init_ui() and should populate self.content_layout with the stacker's specific widgets. - + Subclasses MUST override this method. """ raise NotImplementedError("Subclasses must implement build_ui()") - + def get_title(self) -> str: """Return the stacker's title for the header. - + Returns: Title string, or empty string for no header """ return "" - + def has_settings(self) -> bool: """Whether this stacker has a settings button. - + Returns: True if settings button should be shown """ return False - + def has_status_label(self) -> bool: """Whether this stacker has a status label. - + Returns: True if status label should be shown """ return True - + def on_settings(self): """Handle settings button click. - + Override this method if has_settings() returns True. """ pass - + def reload_models(self): """Reload model data in the stacker. - + Override this method if the stacker displays models. """ pass - + def reload_datasets(self): """Reload dataset data in the stacker. - + Override this method if the stacker displays datasets. """ pass - + def reload(self): """Reload all data in the stacker. - + This is called by the parent when data needs to be refreshed. Default implementation calls reload_models() and reload_datasets(). """ diff --git a/src/voxkit/gui/pages/pipeline/comparison_stacker.py b/src/voxkit/gui/pages/pipeline/comparison_stacker.py index 86878a9..ee1f254 100644 --- a/src/voxkit/gui/pages/pipeline/comparison_stacker.py +++ b/src/voxkit/gui/pages/pipeline/comparison_stacker.py @@ -9,6 +9,8 @@ - **ComparisonStacker**: Alignment comparison workflow UI """ +from __future__ import annotations + import glob from pathlib import Path @@ -32,9 +34,11 @@ from voxkit.gui.pages.pipeline.base_stacker import BaseStacker from voxkit.gui.styles import Buttons, Colors, Containers, Labels from voxkit.storage import alignments, datasets +from voxkit.storage.alignments import AlignmentMetadata +from voxkit.storage.datasets import DatasetMetadata -def _get_tg_paths(alignment_meta: dict) -> list[str]: +def _get_tg_paths(alignment_meta: AlignmentMetadata) -> list[str]: """Glob all TextGrid files under an alignment's tg_path directory.""" tg_root = Path(alignment_meta["tg_path"]) return glob.glob(str(tg_root / "**" / "*.TextGrid"), recursive=True) @@ -64,37 +68,37 @@ class ComparisonStacker(BaseStacker): as PNGs to a folder you choose. """ - def __init__(self, parent=None): - # Shared dataset state - self._dataset_dropdown: MultiColumnComboBox | None = None - self._dataset_meta: dict | None = None + def __init__(self, parent: QWidget | None = None) -> None: + # Widgets set in build_ui() (called by super().__init__) + self._dataset_dropdown: MultiColumnComboBox + self._dataset_meta: DatasetMetadata | None = None # A-side alignment state - self._a_alignment_dropdown: MultiColumnComboBox | None = None - self._a_alignment_meta: dict | None = None + self._a_alignment_dropdown: MultiColumnComboBox + self._a_alignment_meta: AlignmentMetadata | None = None # B-side alignment state - self._b_alignment_dropdown: MultiColumnComboBox | None = None - self._b_alignment_meta: dict | None = None + self._b_alignment_dropdown: MultiColumnComboBox + self._b_alignment_meta: AlignmentMetadata | None = None # Options - self._tier_input: QLineEdit | None = None - self._aggregate_cb: QCheckBox | None = None - self._threshold_spin: QDoubleSpinBox | None = None - self._compare_btn: QPushButton | None = None + self._tier_input: QLineEdit + self._aggregate_cb: QCheckBox + self._threshold_spin: QDoubleSpinBox + self._compare_btn: QPushButton # Results - self._results_widget: QWidget | None = None - self._tab_widget: QTabWidget | None = None + self._results_widget: QWidget + self._tab_widget: QTabWidget # Download self._dl_folder: str = "" - self._dl_folder_label: QLabel | None = None - self._dl_counts_cb: QCheckBox | None = None - self._dl_overlap_cb: QCheckBox | None = None - self._dl_rate_cb: QCheckBox | None = None - self._dl_scatter_cb: QCheckBox | None = None - self._download_btn: QPushButton | None = None + self._dl_folder_label: QLabel + self._dl_counts_cb: QCheckBox + self._dl_overlap_cb: QCheckBox + self._dl_rate_cb: QCheckBox + self._dl_scatter_cb: QCheckBox + self._download_btn: QPushButton # Last comparison parameters (populated on successful compare) self._last_comparison: dict | None = None @@ -167,7 +171,8 @@ def build_ui(self): self._threshold_spin.setFixedWidth(70) self._threshold_spin.setStyleSheet( "QDoubleSpinBox { border: 1px solid #d0d0d0; border-radius: 4px; " - "padding: 4px; font-size: 12px; color: black; background: white; selection-color: black; selection-background-color: #cce5ff; }" + "padding: 4px; font-size: 12px; color: black; background: white; " + "selection-color: black; selection-background-color: #cce5ff; }" ) opts.addWidget(self._threshold_spin) @@ -188,9 +193,7 @@ def build_ui(self): self._tab_widget = QTabWidget() self._tab_widget.setFixedHeight(520) - self._tab_widget.setSizePolicy( - QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed - ) + self._tab_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed) results_col.addWidget(self._tab_widget) # Download group ────────────────────────────────────────────────────── @@ -272,11 +275,8 @@ def _make_alignment_box(self, side: str, *, is_a: bool) -> QGroupBox: # ── Reload hook ────────────────────────────────────────────────────────── - def reload_datasets(self): + def reload_datasets(self) -> None: """Refresh the dataset dropdown from storage.""" - if self._dataset_dropdown is None: - return - self._dataset_meta = None self._a_alignment_meta = None self._b_alignment_meta = None @@ -381,9 +381,7 @@ def _on_b_alignment_changed(self): def _update_compare_btn(self) -> None: if self._compare_btn: - self._compare_btn.setEnabled( - bool(self._a_alignment_meta and self._b_alignment_meta) - ) + self._compare_btn.setEnabled(bool(self._a_alignment_meta and self._b_alignment_meta)) # ── Comparison ─────────────────────────────────────────────────────────── @@ -436,9 +434,7 @@ def _run_comparison(self) -> None: try: counts_a = count_phonemes(paths_a, tier_name=tier, normalize=aggregate) counts_b = count_phonemes(paths_b, tier_name=tier, normalize=aggregate) - overlap = compute_phoneme_overlap( - paths_a, paths_b, tier_name=tier, normalize=aggregate - ) + overlap = compute_phoneme_overlap(paths_a, paths_b, tier_name=tier, normalize=aggregate) rates = compute_phoneme_overlap_rate( paths_a, paths_b, tier_name=tier, normalize=aggregate, threshold=threshold ) @@ -549,9 +545,7 @@ def _download_plots(self) -> None: errors.append(f"{filename}: {exc}") if errors: - self.set_status( - f"Saved {len(saved)} plot(s); errors: {'; '.join(errors)}", "error" - ) + self.set_status(f"Saved {len(saved)} plot(s); errors: {'; '.join(errors)}", "error") elif saved: self.set_status(f"Saved {len(saved)} plot(s) to {self._dl_folder}", "success") else: diff --git a/src/voxkit/gui/pages/pipeline/markdown_stacker.py b/src/voxkit/gui/pages/pipeline/markdown_stacker.py index 6894a09..6084a84 100644 --- a/src/voxkit/gui/pages/pipeline/markdown_stacker.py +++ b/src/voxkit/gui/pages/pipeline/markdown_stacker.py @@ -9,9 +9,10 @@ from PyQt6.QtWidgets import QSizePolicy, QTextBrowser -from .base_stacker import BaseStacker from voxkit.gui.styles import Containers +from .base_stacker import BaseStacker + class MarkdownStacker(BaseStacker): """A stacker widget that displays markdown content.""" @@ -24,9 +25,9 @@ def __init__(self, parent=None, markdown_content: str = ""): markdown_content: Markdown text to display """ self.markdown_content = markdown_content - self.text_browser = None + self.text_browser: QTextBrowser | None = None super().__init__(parent) - + # Remove the stretch at the end added by BaseStacker to allow # the text browser to expand and fill all available vertical space if self.main_layout.count() > 0: @@ -34,19 +35,16 @@ def __init__(self, parent=None, markdown_content: str = ""): if last_item and last_item.spacerItem(): self.main_layout.removeItem(last_item) - def build_ui(self): + def build_ui(self) -> None: """Build the markdown display UI.""" # Create text browser for markdown rendering self.text_browser = QTextBrowser() self.text_browser.setObjectName("markdownDisplay") self.text_browser.setOpenExternalLinks(True) # Allow clickable links self.text_browser.setStyleSheet(Containers.MARKDOWN_DISPLAY) - + # Set size policy to expand and fill available space - self.text_browser.setSizePolicy( - QSizePolicy.Policy.Expanding, - QSizePolicy.Policy.Expanding - ) + self.text_browser.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) # Set markdown content self.set_markdown(self.markdown_content) diff --git a/src/voxkit/gui/pages/pipeline/pllr_stacker.py b/src/voxkit/gui/pages/pipeline/pllr_stacker.py index b599cbc..27d2642 100644 --- a/src/voxkit/gui/pages/pipeline/pllr_stacker.py +++ b/src/voxkit/gui/pages/pipeline/pllr_stacker.py @@ -38,11 +38,11 @@ GenericDialog, SettingsConfig, ) +from voxkit.gui.styles import Buttons, Containers, Labels from voxkit.gui.utils import validate_path, validate_paths from voxkit.gui.workers.worker_thread import WorkerThread from voxkit.storage import alignments, datasets from voxkit.storage.utils import get_storage_root -from voxkit.gui.styles import Buttons, Containers, Labels FIELDS: list[FieldConfig] = [ FieldConfig( @@ -114,11 +114,11 @@ class PLLRStacker(QWidget): existing alignments using the PLLR (Probabilistic Linear Likelihood Ratio) method. """ - def __init__(self, parent=None): + def __init__(self, parent: QWidget | None = None) -> None: super().__init__() - self.parent = parent - self.pllr_dataset_dropdown = None - self.pllr_alignment_dropdown = None + self._parent_widget = parent + self.pllr_dataset_dropdown: MultiColumnComboBox + self.pllr_alignment_dropdown: MultiColumnComboBox self.init_ui() def on_extract_settings(self): @@ -126,8 +126,9 @@ def on_extract_settings(self): result = settings_dialog.exec() - # Clean up - self.parent.setGraphicsEffect(None) + # Clean up blur applied by GenericDialog to self.parent() + if self.parent(): + self.parent().setGraphicsEffect(None) if result == QDialog.DialogCode.Accepted: settings_dialog.save() @@ -375,13 +376,18 @@ def on_extract_pllr(self): # Get dataset path dataset_meta = datasets.get_dataset_metadata(selected_dataset_id) + if not dataset_meta: + QMessageBox.warning(self, "Invalid Dataset", "Could not find dataset metadata.") + return - wavlab_path = None + wavlab_path: Path | str | None = None if not (dataset_meta["cached"] == "True" or dataset_meta["cached"] is True): wavlab_path = dataset_meta["original_path"] else: - wavlab_path = datasets._get_dataset_root(selected_dataset_id) / "cache" + dataset_root = datasets._get_dataset_root(selected_dataset_id) + if dataset_root: + wavlab_path = dataset_root / "cache" print(f"[DEBUG] Dataset root path: {wavlab_path}") diff --git a/src/voxkit/gui/pages/pipeline/prediction_stacker.py b/src/voxkit/gui/pages/pipeline/prediction_stacker.py index c0217cd..89f0fca 100644 --- a/src/voxkit/gui/pages/pipeline/prediction_stacker.py +++ b/src/voxkit/gui/pages/pipeline/prediction_stacker.py @@ -57,8 +57,8 @@ def on_settings(self): if settings_dialog.result() == QDialog.DialogCode.Accepted: settings_dialog.save() - if self.parent: - self.parent.setGraphicsEffect(None) + if self.parent(): + self.parent().setGraphicsEffect(None) def reload_models(self): """Reload models in all engine dropdowns.""" diff --git a/src/voxkit/gui/pages/pipeline/training_stacker.py b/src/voxkit/gui/pages/pipeline/training_stacker.py index ba3dfcc..de2bb30 100644 --- a/src/voxkit/gui/pages/pipeline/training_stacker.py +++ b/src/voxkit/gui/pages/pipeline/training_stacker.py @@ -9,7 +9,6 @@ from pathlib import Path -from PyQt6.QtCore import Qt from PyQt6.QtWidgets import ( QDialog, QFileDialog, @@ -22,10 +21,10 @@ from voxkit.config import Defaults from voxkit.gui.components import ModelSelectionPanel, MultiColumnComboBox from voxkit.gui.frameworks.settings_modal import GenericDialog +from voxkit.gui.styles import Buttons, Containers, Labels from voxkit.gui.utils import validate_path, validate_paths from voxkit.gui.workers.worker_thread import WorkerThread from voxkit.storage import alignments, datasets, models -from voxkit.gui.styles import Buttons, Containers, Labels from .base_stacker import BaseStacker @@ -36,8 +35,10 @@ class TrainingStacker(BaseStacker): Allows users to train custom alignment models using a dataset and existing alignment as training data. """ + def __init__(self, parent): from voxkit.engines import engines + self.engines = engines self.train_dataset_dropdown = None self.train_alignment_dropdown = None @@ -51,18 +52,18 @@ def __init__(self, parent): def get_title(self) -> str: """Return the stacker's title.""" return "Train Aligners" - + def has_settings(self) -> bool: """This stacker has settings.""" return True - + def on_settings(self): """Handle settings button click on training page.""" self.settings_dialog = GenericDialog( parent=self, - config=self.engines.get_tool_providers("train")[self.model_panel.get_selected_engine()].get_settings_config( - "train" - ), + config=self.engines.get_tool_providers("train")[ + self.model_panel.get_selected_engine() + ].get_settings_config("train"), ) result = self.settings_dialog.exec() @@ -72,8 +73,8 @@ def on_settings(self): except Exception as e: print("Error syncing training settings:", e) # Clean up - if self.parent: - self.parent.setGraphicsEffect(None) + if self.parent(): + self.parent().setGraphicsEffect(None) def on_dataset_selected(self): """Handle dataset selection change and load corresponding alignments""" @@ -288,7 +289,10 @@ def reload_datasets(self): def build_ui(self): """Build the training UI.""" # Model Selection Panel - engines_dict = {engine_id: engine for engine_id, engine in self.engines.get_tool_providers("train").items()} + engines_dict = { + engine_id: engine + for engine_id, engine in self.engines.get_tool_providers("train").items() + } self.model_panel = ModelSelectionPanel(engines_dict) self.content_layout.addWidget(self.model_panel) diff --git a/src/voxkit/gui/pages/pipeline/transcription_stacker.py b/src/voxkit/gui/pages/pipeline/transcription_stacker.py index 78592e2..70c24ee 100644 --- a/src/voxkit/gui/pages/pipeline/transcription_stacker.py +++ b/src/voxkit/gui/pages/pipeline/transcription_stacker.py @@ -61,8 +61,8 @@ def on_settings(self): if settings_dialog.result() == QDialog.DialogCode.Accepted: settings_dialog.save() - if self.parent: - self.parent.setGraphicsEffect(None) + if self.parent(): + self.parent().setGraphicsEffect(None) def reload_datasets(self): """Reload datasets in the dropdown.""" @@ -147,7 +147,8 @@ def _get_selected_engine_id(self) -> str | None: idx = self.engine_dropdown.currentIndex() if idx < 0: return None - return self.engine_dropdown.itemData(idx) + engine_id = self.engine_dropdown.itemData(idx) + return engine_id if isinstance(engine_id, str) or engine_id is None else None def on_transcribe(self): """Handle Transcribe button click.""" diff --git a/src/voxkit/gui/pages/pipeline/viewer_stacker.py b/src/voxkit/gui/pages/pipeline/viewer_stacker.py index cdc7c53..2f8078a 100644 --- a/src/voxkit/gui/pages/pipeline/viewer_stacker.py +++ b/src/voxkit/gui/pages/pipeline/viewer_stacker.py @@ -11,12 +11,16 @@ - **ViewerStacker**: Alignment viewer workflow UI """ +from __future__ import annotations + +import os import re import subprocess import sys from pathlib import Path +from typing import TYPE_CHECKING -from PyQt6.QtCore import Qt, QPoint, QUrl, pyqtSignal +from PyQt6.QtCore import QPoint, Qt, QUrl, pyqtSignal from PyQt6.QtGui import QColor, QFont, QPainter, QPen, QPolygon from PyQt6.QtWidgets import ( QComboBox, @@ -38,9 +42,12 @@ from voxkit.storage import alignments, datasets from voxkit.storage.datasets import _get_dataset_root -try: +if TYPE_CHECKING: from PyQt6.QtMultimedia import QAudioOutput, QMediaPlayer +try: + from PyQt6.QtMultimedia import QAudioOutput, QMediaPlayer # noqa: F811 + MULTIMEDIA_AVAILABLE = True except ImportError: MULTIMEDIA_AVAILABLE = False @@ -99,9 +106,7 @@ def _parse_textgrid(filepath: str) -> list[dict]: time = re.search(r"time\s*=\s*([0-9.e+\-]+)", pb) mark = re.search(r'mark\s*=\s*"([^"]*)"', pb) if time and mark: - tier["intervals"].append( - {"time": float(time.group(1)), "label": mark.group(1)} - ) + tier["intervals"].append({"time": float(time.group(1)), "label": mark.group(1)}) tiers.append(tier) @@ -113,7 +118,7 @@ def _parse_textgrid(filepath: str) -> list[dict]: # --------------------------------------------------------------------------- -def _dataset_data_path(meta: dict) -> Path: +def _dataset_data_path(meta: datasets.DatasetMetadata) -> Path: """Return the directory containing speaker subdirs (audio + .lab files).""" if meta.get("cached"): root = _get_dataset_root(meta["id"]) @@ -159,12 +164,12 @@ class TextGridTimeline(QWidget): TIER_HEIGHT = 36 RULER_HEIGHT = 26 - LEFT_MARGIN = 92 # space reserved for tier name labels + LEFT_MARGIN = 92 # space reserved for tier name labels RIGHT_MARGIN = 8 # Fixed colors for well-known tier names (case-insensitive match) _TIER_COLOR_MAP: dict[str, QColor] = { - "words": QColor("#3498db"), # blue + "words": QColor("#3498db"), # blue "phones": QColor("#27ae60"), # green } @@ -178,7 +183,7 @@ class TextGridTimeline(QWidget): QColor("#8e44ad"), # dark purple ] - def __init__(self, parent=None): + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) self._tiers: list[dict] = [] self._duration: float = 0.0 @@ -259,11 +264,16 @@ def paintEvent(self, _event): while t <= self._duration + step * 0.01: x = self._time_to_x(t) painter.drawLine(x, self.RULER_HEIGHT - 5, x, self.RULER_HEIGHT) - lbl = f"{t:.2f}s" if t < 1 else ( - f"{t:.1f}s" if t < 60 else f"{int(t // 60)}:{int(t % 60):02d}" + lbl = ( + f"{t:.2f}s" + if t < 1 + else (f"{t:.1f}s" if t < 60 else f"{int(t // 60)}:{int(t % 60):02d}") ) painter.drawText( - x - 26, 1, 52, self.RULER_HEIGHT - 6, + x - 26, + 1, + 52, + self.RULER_HEIGHT - 6, Qt.AlignmentFlag.AlignHCenter | Qt.AlignmentFlag.AlignVCenter, lbl, ) @@ -273,8 +283,10 @@ def paintEvent(self, _event): # Left margin background painter.fillRect( - 0, self.RULER_HEIGHT, - self.LEFT_MARGIN, h - self.RULER_HEIGHT, + 0, + self.RULER_HEIGHT, + self.LEFT_MARGIN, + h - self.RULER_HEIGHT, QColor("#ecf0f1"), ) painter.setPen(QPen(QColor("#bdc3c7"), 1)) @@ -306,7 +318,10 @@ def paintEvent(self, _event): painter.setFont(name_font) painter.setPen(color.darker(150)) painter.drawText( - 4, y, self.LEFT_MARGIN - 6, self.TIER_HEIGHT, + 4, + y, + self.LEFT_MARGIN - 6, + self.TIER_HEIGHT, Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignLeft, tier["name"], ) @@ -344,13 +359,13 @@ def paintEvent(self, _event): # Label inside block if bw > 10 and iv_label: - text_color = ( - QColor("white") if (active or not silent) - else color.darker(140) - ) + text_color = QColor("white") if (active or not silent) else color.darker(140) painter.setPen(text_color) painter.drawText( - x1 + 2, y + pad, bw - 4, self.TIER_HEIGHT - pad * 2, + x1 + 2, + y + pad, + bw - 4, + self.TIER_HEIGHT - pad * 2, Qt.AlignmentFlag.AlignCenter, iv_label, ) @@ -396,29 +411,30 @@ class ViewerStacker(BaseStacker): at once. """ - def __init__(self, parent=None): + def __init__(self, parent: QWidget | None = None) -> None: # Pre-declare all attributes so build_ui() (called by super().__init__) # can reference them safely. - self._dataset_dropdown: MultiColumnComboBox | None = None - self._alignment_dropdown: MultiColumnComboBox | None = None - self._speaker_dropdown: QComboBox | None = None - self._file_list: QListWidget | None = None - self._file_search = None # QLineEdit, set in build_ui + # Widgets set in build_ui() (called by super().__init__) + self._dataset_dropdown: MultiColumnComboBox + self._alignment_dropdown: MultiColumnComboBox + self._speaker_dropdown: QComboBox + self._file_list: QListWidget + self._file_search: QLineEdit self._all_audio_files: list[str] = [] - self._selection_section: QWidget | None = None - self._viewer_section: QWidget | None = None - self._timeline: TextGridTimeline | None = None - self._active_label: QLabel | None = None - self._transcript_edit: QTextEdit | None = None - self._audio_path_label: QLabel | None = None - self._current_dataset_meta: dict | None = None - self._current_alignment_meta: dict | None = None + self._selection_section: QWidget + self._viewer_section: QWidget + self._timeline: TextGridTimeline + self._active_label: QLabel + self._transcript_edit: QTextEdit + self._audio_path_label: QLabel + self._current_dataset_meta: datasets.DatasetMetadata | None = None + self._current_alignment_meta: alignments.AlignmentMetadata | None = None self._current_data_path: Path | None = None self._current_audio_path: Path | None = None self._loaded_tiers: list[dict] = [] # Multimedia (may remain None if QtMultimedia is unavailable) - self._player = None - self._audio_output = None + self._player: QMediaPlayer | None = None + self._audio_output: QAudioOutput | None = None self._play_btn: QPushButton | None = None self._seek_slider: QSlider | None = None self._time_label: QLabel | None = None @@ -574,9 +590,7 @@ def build_ui(self): self._timeline = TextGridTimeline() self._timeline.seek_requested.connect(self._seek_to_seconds) - self._timeline.setStyleSheet( - f"border: 1px solid {Colors.BORDER}; border-radius: 4px;" - ) + self._timeline.setStyleSheet(f"border: 1px solid {Colors.BORDER}; border-radius: 4px;") view_col.addWidget(self._timeline) # Active-segment indicator ──────────────────────────────────────────── @@ -717,9 +731,7 @@ def _on_alignment_changed(self): if not alignment_id or not self._current_dataset_meta: return - meta = alignments.get_alignment_metadata( - self._current_dataset_meta["id"], alignment_id - ) + meta = alignments.get_alignment_metadata(self._current_dataset_meta["id"], alignment_id) if not meta: return @@ -764,7 +776,9 @@ def _filter_file_list(self, query: str): """Show only files whose names contain the search query (case-insensitive).""" self._file_list.clear() q = query.strip().lower() - matches = [f for f in self._all_audio_files if q in f.lower()] if q else self._all_audio_files + matches = ( + [f for f in self._all_audio_files if q in f.lower()] if q else self._all_audio_files + ) self._file_list.addItems(matches) # Hide viewer if the previously selected file is no longer visible if self._viewer_section and self._viewer_section.isVisible(): @@ -847,6 +861,7 @@ def _load_viewer( f"{self._audio_path_label.text()} [TextGrid parse error: {exc}]" ) else: + assert self._current_alignment_meta is not None self._audio_path_label.setText( self._audio_path_label.text() + f" [TextGrid not found in {Path(self._current_alignment_meta['tg_path'])}]" @@ -936,12 +951,15 @@ def _open_audio_externally(self): path = self._current_audio_path if not path: return - if sys.platform == "darwin": - subprocess.Popen(["open", str(path)]) - elif sys.platform == "win32": - subprocess.Popen(["start", "", str(path)], shell=True) - else: - subprocess.Popen(["xdg-open", str(path)]) + try: + if sys.platform == "darwin": + subprocess.run(["open", str(path)], check=False) # noqa: S603,S607 + elif sys.platform == "win32": + os.startfile(str(path)) # noqa: S606 + else: + subprocess.run(["xdg-open", str(path)], check=False) # noqa: S603,S607 + except (OSError, subprocess.SubprocessError): + pass # ── Helpers ─────────────────────────────────────────────────────────────── diff --git a/src/voxkit/gui/styles/__init__.py b/src/voxkit/gui/styles/__init__.py index 9c2a8ef..383e5ec 100644 --- a/src/voxkit/gui/styles/__init__.py +++ b/src/voxkit/gui/styles/__init__.py @@ -348,7 +348,6 @@ class Buttons: border: 2px solid #e0e0e0; } """ - class Inputs: @@ -778,6 +777,7 @@ class Containers: } """ + __all__ = [ "Colors", "Buttons", diff --git a/src/voxkit/storage/alignments.py b/src/voxkit/storage/alignments.py index e7124e1..e90502e 100644 --- a/src/voxkit/storage/alignments.py +++ b/src/voxkit/storage/alignments.py @@ -221,10 +221,12 @@ def get_alignment_metadata(dataset_id: str, alignment_id: str) -> AlignmentMetad try: with open(metadata_path, "r") as f: - metadata = json.load(f) + metadata: AlignmentMetadata = json.load(f) # Normalize status to lowercase for consistency if "status" in metadata: - metadata["status"] = metadata["status"].lower() + status_lower = metadata["status"].lower() + # Cast to the correct literal type + metadata["status"] = status_lower # type: ignore[typeddict-item] return metadata except Exception as e: print(f"Failed to load alignment metadata from '{metadata_path}': {str(e)}") diff --git a/src/voxkit/storage/datasets.py b/src/voxkit/storage/datasets.py index c06ff39..88a619b 100644 --- a/src/voxkit/storage/datasets.py +++ b/src/voxkit/storage/datasets.py @@ -115,7 +115,8 @@ def _get_dataset_metadata(dataset_root: Path) -> DatasetMetadata | None: if not metadata_path.exists(): return None with open(metadata_path, "r") as f: - return json.load(f) + result: DatasetMetadata = json.load(f) + return result except Exception: return None @@ -313,7 +314,6 @@ def update_dataset_metadata( Tuple of (True, success_message) on success or (False, error_message) on failure Raises: - KeyError: If an invalid metadata field is specified FileNotFoundError: If the dataset is not found Exception: If metadata file cannot be written """ @@ -323,14 +323,9 @@ def update_dataset_metadata( if not metadata: return False, f"Dataset {dataset_id} not found" - if updates["description"] is not None: - metadata["description"] = updates["description"] - if updates["cached"] is not None: - metadata["cached"] = updates["cached"] - if updates["anonymize"] is not None: - metadata["anonymize"] = updates["anonymize"] - if updates["transcribed"] is not None: - metadata["transcribed"] = updates["transcribed"] + for field in ("description", "cached", "anonymize", "transcribed"): + if field in updates and updates[field] is not None: + metadata[field] = updates[field] # Save the updated metadata metadata_path = _get_datasets_root() / dataset_id / "voxkit_dataset.json" diff --git a/src/voxkit/storage/models.py b/src/voxkit/storage/models.py index 2020a52..eb5089d 100644 --- a/src/voxkit/storage/models.py +++ b/src/voxkit/storage/models.py @@ -305,7 +305,7 @@ def get_model_metadata(engine_id: str, model_id: str) -> ModelMetadata: if not metadata_path.exists(): raise FileNotFoundError(f"Metadata file not found for model '{model_id}'") with open(metadata_path, "r") as f: - metadata = json.load(f) + metadata: ModelMetadata = json.load(f) return metadata @@ -336,7 +336,6 @@ def download_and_copy_huggingface_model( # Download to HF cache (returns path to snapshot with symlinks) cache_snapshot_path = snapshot_download( repo_id=model_path, - resume_download=True, ) print(f"Downloaded to cache: {cache_snapshot_path}") diff --git a/tests/gui/test_csv_viewer_dialog.py b/tests/gui/test_csv_viewer_dialog.py index f03c8a8..84d4aed 100644 --- a/tests/gui/test_csv_viewer_dialog.py +++ b/tests/gui/test_csv_viewer_dialog.py @@ -13,18 +13,26 @@ def test_headers_match_csv(self, qtbot, sample_csv): dialog = CSVViewerDialog(csv_path=sample_csv) qtbot.addWidget(dialog) - headers = [ - dialog.table.horizontalHeaderItem(i).text() for i in range(dialog.table.columnCount()) - ] + headers = [] + for i in range(dialog.table.columnCount()): + item = dialog.table.horizontalHeaderItem(i) + if item is not None: + headers.append(item.text()) assert headers == ["name", "age", "city"] def test_cell_content(self, qtbot, sample_csv): dialog = CSVViewerDialog(csv_path=sample_csv) qtbot.addWidget(dialog) - assert dialog.table.item(0, 0).text() == "Alice" - assert dialog.table.item(1, 1).text() == "25" - assert dialog.table.item(2, 2).text() == "Chicago" + item1 = dialog.table.item(0, 0) + assert item1 is not None + assert item1.text() == "Alice" + item2 = dialog.table.item(1, 1) + assert item2 is not None + assert item2.text() == "25" + item3 = dialog.table.item(2, 2) + assert item3 is not None + assert item3.text() == "Chicago" def test_stats_label_shows_dimensions(self, qtbot, sample_csv): dialog = CSVViewerDialog(csv_path=sample_csv) @@ -52,6 +60,7 @@ def test_cells_are_read_only(self, qtbot, sample_csv): qtbot.addWidget(dialog) item = dialog.table.item(0, 0) + assert item is not None assert not (item.flags() & Qt.ItemFlag.ItemIsEditable) def test_window_title(self, qtbot, sample_csv): diff --git a/tests/gui/test_datasets_page.py b/tests/gui/test_datasets_page.py new file mode 100644 index 0000000..3ea5d12 --- /dev/null +++ b/tests/gui/test_datasets_page.py @@ -0,0 +1,50 @@ +from unittest.mock import patch + +import pytest + +from voxkit.gui.pages.datasets.datasets_page import DatasetsPage + + +@pytest.fixture +def datasets_page(qtbot): + """DatasetsPage instance with no real datasets loaded.""" + with patch( + "voxkit.gui.pages.datasets.datasets_page.datasets.list_datasets_metadata", return_value=[] + ): + page = DatasetsPage() + qtbot.addWidget(page) + return page + + +class TestRefreshDatasets: + def test_empty_state_shows_label_and_hides_table(self, qtbot, datasets_page): + with patch( + "voxkit.gui.pages.datasets.datasets_page.datasets.list_datasets_metadata", + return_value=[], + ): + datasets_page.refresh_datasets() + + assert not datasets_page.empty_label.isHidden() + assert datasets_page.dataset_table.isHidden() + + def test_populated_state_shows_table_and_hides_label(self, qtbot, datasets_page): + sample_metadata = [ + { + "id": "ds-1", + "name": "Test Dataset", + "description": "A test dataset", + "cached": False, + "anonymize": False, + "transcribed": False, + "registration_date": "2024-01-01T00:00:00", + } + ] + with patch( + "voxkit.gui.pages.datasets.datasets_page.datasets.list_datasets_metadata", + return_value=sample_metadata, + ): + datasets_page.refresh_datasets() + + assert datasets_page.empty_label.isHidden() + assert not datasets_page.dataset_table.isHidden() + assert datasets_page.dataset_table.rowCount() == 1 diff --git a/tests/gui/test_loading_dialog.py b/tests/gui/test_loading_dialog.py index 4f69203..5d4b288 100644 --- a/tests/gui/test_loading_dialog.py +++ b/tests/gui/test_loading_dialog.py @@ -8,7 +8,11 @@ def test_default_message(self, qtbot): dialog = LoadingDialog() qtbot.addWidget(dialog) - label = dialog.layout().itemAt(0).widget() + layout = dialog.layout() + assert layout is not None + item = layout.itemAt(0) + assert item is not None + label = item.widget() assert isinstance(label, QLabel) assert label.text() == "Loading..." @@ -16,7 +20,12 @@ def test_custom_message(self, qtbot): dialog = LoadingDialog(message="Please wait...") qtbot.addWidget(dialog) - label = dialog.layout().itemAt(0).widget() + layout = dialog.layout() + assert layout is not None + item = layout.itemAt(0) + assert item is not None + label = item.widget() + assert isinstance(label, QLabel) assert label.text() == "Please wait..." def test_update_message(self, qtbot): @@ -25,7 +34,12 @@ def test_update_message(self, qtbot): dialog.update_message("Step 2") - label = dialog.layout().itemAt(0).widget() + layout = dialog.layout() + assert layout is not None + item = layout.itemAt(0) + assert item is not None + label = item.widget() + assert isinstance(label, QLabel) assert label.text() == "Step 2" def test_spinner_frames_cycle(self, qtbot): diff --git a/tests/gui/test_models_page.py b/tests/gui/test_models_page.py new file mode 100644 index 0000000..42029cb --- /dev/null +++ b/tests/gui/test_models_page.py @@ -0,0 +1,65 @@ +"""Tests for ManageAlignersWidget in models_page.""" + +from unittest.mock import MagicMock, call, patch + + +class TestReloadModels: + """reload_models must refresh every engine, not just W2TGENGINE.""" + + def _make_widget(self, engines): + """Return a minimal ManageAlignersWidget stand-in with the given engines.""" + widget = MagicMock() + widget.get_engines.return_value = engines + return widget + + def test_calls_set_items_for_every_engine(self): + engines = ["W2TGENGINE", "MFAENGINE", "CUSTOMENGINE"] + widget = self._make_widget(engines) + + model_data = { + "W2TGENGINE": [{"id": "w1", "name": "w2tg_model"}], + "MFAENGINE": [{"id": "m1", "name": "mfa_model"}], + "CUSTOMENGINE": [], + } + + with patch( + "voxkit.gui.pages.models.models_page.models.list_models", + side_effect=lambda e: model_data[e], + ): + from voxkit.gui.pages.models.models_page import ManageAlignersWidget + + ManageAlignersWidget.reload_models(widget) + + widget.set_items.assert_has_calls( + [ + call("W2TGENGINE", model_data["W2TGENGINE"]), + call("MFAENGINE", model_data["MFAENGINE"]), + call("CUSTOMENGINE", model_data["CUSTOMENGINE"]), + ], + any_order=False, + ) + assert widget.set_items.call_count == len(engines) + + def test_no_set_items_calls_when_no_engines(self): + widget = self._make_widget([]) + + with patch("voxkit.gui.pages.models.models_page.models.list_models"): + from voxkit.gui.pages.models.models_page import ManageAlignersWidget + + ManageAlignersWidget.reload_models(widget) + + widget.set_items.assert_not_called() + + def test_single_engine_still_calls_set_items(self): + widget = self._make_widget(["W2TGENGINE"]) + model_list = [{"id": "w1", "name": "w2tg_model"}] + + with patch( + "voxkit.gui.pages.models.models_page.models.list_models", + return_value=model_list, + ): + from voxkit.gui.pages.models.models_page import ManageAlignersWidget + + ManageAlignersWidget.reload_models(widget) + + widget.set_items.assert_called_once_with("W2TGENGINE", model_list) diff --git a/tests/gui/test_toggle_switch.py b/tests/gui/test_toggle_switch.py index 54b09a8..d60070c 100644 --- a/tests/gui/test_toggle_switch.py +++ b/tests/gui/test_toggle_switch.py @@ -1,3 +1,5 @@ +from typing import cast + from PyQt6.QtCore import Qt from PyQt6.QtTest import QTest @@ -20,10 +22,12 @@ def test_click_toggles_state(self, qtbot): qtbot.addWidget(switch) switch.show() - QTest.mouseClick(switch, Qt.MouseButton.LeftButton) + from PyQt6.QtWidgets import QWidget as QWidgetType + + QTest.mouseClick(cast(QWidgetType, switch), Qt.MouseButton.LeftButton) assert switch.isChecked() is True - QTest.mouseClick(switch, Qt.MouseButton.LeftButton) + QTest.mouseClick(cast(QWidgetType, switch), Qt.MouseButton.LeftButton) assert switch.isChecked() is False def test_set_checked_programmatically(self, qtbot): diff --git a/tests/storage/test_alignments.py b/tests/storage/test_alignments.py index 109dc97..f65e79e 100644 --- a/tests/storage/test_alignments.py +++ b/tests/storage/test_alignments.py @@ -62,6 +62,7 @@ def sample_dataset(monkeypatch): ) assert success is True + assert isinstance(dataset_metadata, dict) return dataset_metadata @@ -80,6 +81,7 @@ def sample_model(monkeypatch): ) assert success is True + assert isinstance(model_metadata, dict) return model_metadata @@ -151,6 +153,7 @@ def test_create_alignment_invalid_dataset(self, monkeypatch, sample_model): ) assert msg is False + assert isinstance(result, str) assert "Dataset" in result def test_create_alignment_non_cached_dataset(self, monkeypatch, sample_model): @@ -173,6 +176,7 @@ def test_create_alignment_non_cached_dataset(self, monkeypatch, sample_model): ) assert success is True + assert isinstance(dataset_metadata, dict) dataset_id = dataset_metadata["id"] engine_id = sample_model["engine_id"] @@ -221,6 +225,7 @@ def test_get_alignment_metadata_success(self, monkeypatch, sample_dataset, sampl ) assert success is True + assert isinstance(alignment_metadata, dict) alignment_id = alignment_metadata["id"] @@ -247,6 +252,7 @@ def test_get_alignment_metadata_invalid_id(self, monkeypatch, sample_dataset, sa ) assert success is True + assert isinstance(alignment_metadata, dict) invalid_alignment_id = "NON_EXISTENT_ALIGNMENT" @@ -305,6 +311,7 @@ def test_list_alignments_success(self, monkeypatch, sample_dataset, sample_model model_id=model_id, ) assert success is True + assert isinstance(alignment_metadata, dict) created_alignment_ids.add(alignment_metadata["id"]) alignments_list = list_alignments(dataset_id=dataset_id) @@ -341,6 +348,7 @@ def test_delete_alignment_success(self, sample_dataset, sample_model): ) assert success is True + assert isinstance(alignment_metadata, dict) alignment_id = alignment_metadata["id"] @@ -410,6 +418,7 @@ def test_update_alignment_success(self, sample_dataset, sample_model): ) assert success is True + assert isinstance(alignment_metadata, dict) alignment_id = alignment_metadata["id"] @@ -450,6 +459,7 @@ def test_update_alignment_status_case_insensitive(self, sample_dataset, sample_m ) assert success is True + assert isinstance(alignment_metadata, dict) alignment_id = alignment_metadata["id"] @@ -508,6 +518,7 @@ def test_list_alignments_normalizes_status(self, monkeypatch, sample_dataset, sa ) assert success is True + assert isinstance(alignment_metadata, dict) alignment_id = alignment_metadata["id"] @@ -554,6 +565,7 @@ def test_get_alignment_metadata_normalizes_status( ) assert success is True + assert isinstance(alignment_metadata, dict) alignment_id = alignment_metadata["id"] diff --git a/tests/storage/test_datasets.py b/tests/storage/test_datasets.py index 3ebc32d..1f5c621 100644 --- a/tests/storage/test_datasets.py +++ b/tests/storage/test_datasets.py @@ -92,6 +92,7 @@ def test_create_dataset_success_no_cache(self, monkeypatch): ) assert success is True + assert isinstance(message, dict) for key in DatasetMetadata.__annotations__.keys(): assert key in message @@ -117,6 +118,7 @@ def test_create_dataset_success_with_cache(self, monkeypatch): ) assert success is True + assert isinstance(message, dict) for key in DatasetMetadata.__annotations__.keys(): assert key in message @@ -141,6 +143,7 @@ def test_create_dataset_invalid_path(self, monkeypatch): ) assert success is False + assert isinstance(message, str) assert "No label files found" in message assert invalid_dataset_path.exists() is True @@ -169,10 +172,10 @@ def test_list_datasets_metadata(self, monkeypatch): transcribed=False, ) - datasets = list_datasets_metadata() - assert len(datasets) >= 2 # At least the two we just created + dataset_list = list_datasets_metadata() + assert len(dataset_list) >= 2 # At least the two we just created - names = [ds["name"] for ds in datasets] + names = [ds["name"] for ds in dataset_list] assert "dataset_one" in names assert "dataset_two" in names @@ -192,12 +195,12 @@ def test_list_datasets_output_format(self, monkeypatch): transcribed=True, ) - datasets = list_datasets_metadata() + dataset_list = list_datasets_metadata() # Check that each dataset has all required fields - for i in range(len(datasets)): + for i in range(len(dataset_list)): for key in DatasetMetadata.__annotations__.keys(): - assert key in datasets[i].keys() + assert key in dataset_list[i].keys() def test_list_datasets_empty(self, monkeypatch): from voxkit.storage import datasets @@ -209,9 +212,9 @@ def test_list_datasets_empty(self, monkeypatch): deactivate_test_environment(mock_get_storage_root()) activate_test_environment(mock_get_storage_root()) - datasets = list_datasets_metadata() - assert isinstance(datasets, list) - assert len(datasets) == 0 + dataset_list = list_datasets_metadata() + assert isinstance(dataset_list, list) + assert len(dataset_list) == 0 class TestGetDatasetMetadata: def test_get_dataset_metadata_success(self, monkeypatch): @@ -229,10 +232,12 @@ def test_get_dataset_metadata_success(self, monkeypatch): anonymize=False, transcribed=True, ) + assert success is True + assert isinstance(message, dict) dataset_id = message["id"] metadata = get_dataset_metadata(dataset_id) - assert success is not None + assert metadata is not None for key in DatasetMetadata.__annotations__.keys(): assert key in metadata @@ -267,7 +272,7 @@ def test_delete_dataset_success(self, monkeypatch): monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) # Create a dataset - _, message = create_dataset( + success, message = create_dataset( name="dataset_delete_test", description="Testing delete_dataset", original_path=valid_dataset_path, @@ -275,6 +280,8 @@ def test_delete_dataset_success(self, monkeypatch): anonymize=False, transcribed=True, ) + assert success is True + assert isinstance(message, dict) dataset_id = message["id"] # Delete the dataset @@ -314,7 +321,7 @@ def test_delete_dataset_twice(self, monkeypatch): monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) # Create a dataset - _, message = create_dataset( + success, message = create_dataset( name="dataset_delete_twice_test", description="Testing delete_dataset twice", original_path=valid_dataset_path, @@ -322,6 +329,8 @@ def test_delete_dataset_twice(self, monkeypatch): anonymize=False, transcribed=True, ) + assert success is True + assert isinstance(message, dict) dataset_id = message["id"] # First deletion @@ -352,7 +361,7 @@ def test_export_dataset_success(self, monkeypatch, tmp_path): monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) # Create a dataset - _, message = create_dataset( + success, message = create_dataset( name="dataset_export_test", description="Testing export_dataset", original_path=valid_dataset_path, @@ -360,6 +369,8 @@ def test_export_dataset_success(self, monkeypatch, tmp_path): anonymize=False, transcribed=True, ) + assert success is True + assert isinstance(message, dict) dataset_id = message["id"] export_path = mock_get_storage_root() @@ -375,7 +386,7 @@ def test_export_equal(self, monkeypatch): monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) # Create a dataset - _, message = create_dataset( + success, message = create_dataset( name="dataset_export_equal_test", description="Testing export_dataset equality", original_path=valid_dataset_path, @@ -383,6 +394,8 @@ def test_export_equal(self, monkeypatch): anonymize=False, transcribed=True, ) + assert success is True + assert isinstance(message, dict) dataset_id = message["id"] export_path = mock_get_storage_root() @@ -393,6 +406,7 @@ def test_export_equal(self, monkeypatch): # Dataset metadata original_metadata = datasets.get_dataset_metadata(dataset_id) + assert original_metadata is not None # Verify exported files match original files original_dataset_path = _get_datasets_root() / dataset_id @@ -415,7 +429,7 @@ def test_import_dataset_success(self, monkeypatch): monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) # Create a dataset to export and then import - _, message = create_dataset( + success, message = create_dataset( name="dataset_import_test", description="Testing import_dataset", original_path=valid_dataset_path, @@ -423,6 +437,8 @@ def test_import_dataset_success(self, monkeypatch): anonymize=False, transcribed=True, ) + assert success is True + assert isinstance(message, dict) dataset_id = message["id"] export_path = mock_get_storage_root() @@ -468,7 +484,7 @@ def test_import_dataset_empty_cache_true(self, monkeypatch): monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) # Create a dataset to export and then import - _, message = create_dataset( + success, message = create_dataset( name="dataset_import_test", description="Testing import_dataset", original_path=valid_dataset_path, @@ -476,6 +492,8 @@ def test_import_dataset_empty_cache_true(self, monkeypatch): anonymize=False, transcribed=True, ) + assert success is True + assert isinstance(message, dict) # empty cache directory dataset_id = message["id"] @@ -516,6 +534,7 @@ def test_update_dataset_metadata_success(self, monkeypatch): transcribed=False, ) assert success is True + assert isinstance(message, dict) dataset_id = message["id"] # Update the dataset metadata @@ -532,6 +551,7 @@ def test_update_dataset_metadata_success(self, monkeypatch): # Verify the updates metadata = get_dataset_metadata(dataset_id) + assert metadata is not None assert metadata["description"] == "Updated description" assert metadata["anonymize"] is True assert metadata["transcribed"] is True @@ -575,6 +595,7 @@ def test_update_dataset_metadata_partial(self, monkeypatch): transcribed=True, ) assert success is True + assert isinstance(message, dict) dataset_id = message["id"] # Update only description @@ -589,11 +610,49 @@ def test_update_dataset_metadata_partial(self, monkeypatch): # Verify only description changed metadata = get_dataset_metadata(dataset_id) + assert metadata is not None assert metadata["description"] == "Only description updated" assert metadata["cached"] is True assert metadata["anonymize"] is True assert metadata["transcribed"] is True + def test_update_dataset_metadata_missing_keys(self, monkeypatch): + from voxkit.storage import datasets + from voxkit.storage.datasets import ( + create_dataset, + get_dataset_metadata, + update_dataset_metadata, + ) + + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) + + # Create a dataset + success, message = create_dataset( + name="dataset_missing_keys_test", + description="Original description", + original_path=valid_dataset_path, + cached=True, + anonymize=True, + transcribed=True, + ) + assert success is True + assert isinstance(message, dict) + dataset_id = message["id"] + + # Pass a partial dict with only one key — no KeyError should be raised + update_success, update_msg = update_dataset_metadata( + dataset_id, {"description": "Partial update"} + ) + assert update_success is True + + # Only description should change; other fields stay untouched + metadata = get_dataset_metadata(dataset_id) + assert metadata is not None + assert metadata["description"] == "Partial update" + assert metadata["cached"] is True + assert metadata["anonymize"] is True + assert metadata["transcribed"] is True + class TestCreateDatasetWithAnalysis: def test_create_dataset_with_analysis_data(self, monkeypatch): from voxkit.storage import datasets @@ -618,6 +677,7 @@ def test_create_dataset_with_analysis_data(self, monkeypatch): ) assert success is True + assert isinstance(message, dict) dataset_id = message["id"] # Verify CSV file was created @@ -651,6 +711,7 @@ def test_create_dataset_without_analysis_data(self, monkeypatch): ) assert success is True + assert isinstance(message, dict) dataset_id = message["id"] # Verify no CSV file was created diff --git a/tests/storage/test_models.py b/tests/storage/test_models.py index 125b751..d975fb7 100644 --- a/tests/storage/test_models.py +++ b/tests/storage/test_models.py @@ -34,6 +34,7 @@ def test_create_model_success(self, monkeypatch): model_name="test_model", ) assert success is True + assert not isinstance(message, str) assert message is not None required_keys = set(ModelMetadata.__annotations__.keys()) missing = required_keys - set(message.keys()) @@ -71,6 +72,7 @@ def test_create_multiple_models(self, monkeypatch): model_name=name, ) assert success is True + assert not isinstance(message, str) assert message["name"] == name assert message["id"] not in created_ids, "Duplicate model ID generated" created_ids.add(message["id"]) @@ -87,6 +89,7 @@ def test_model_paths_created(self, monkeypatch): model_name="path_test_model", ) assert success is True + assert not isinstance(message, str) model_path = Path(message["model_path"]) data_path = Path(message["data_path"]) @@ -111,6 +114,7 @@ def test_model_fits_modelmetadata(self, monkeypatch): model_name="metadata_test_model", ) assert success is True + assert not isinstance(message, str) # Check that the returned message fits the ModelMetadata TypedDict required_keys = set(ModelMetadata.__annotations__.keys()) @@ -202,6 +206,7 @@ def test_delete_model_success(self, monkeypatch): model_name="delete_test_model", ) assert success is True + assert not isinstance(message, str) model_id = message["id"] # Now delete the model @@ -250,6 +255,7 @@ def test_delete_model_multiple(self, monkeypatch): model_name=f"multi_delete_model_{i}", ) assert success is True + assert not isinstance(message, str) model_ids.append(message["id"]) # Delete the models one by one @@ -277,6 +283,7 @@ def test_delete_model_invalid_engine(self, monkeypatch): model_name="invalid_engine_delete_model", ) assert success is True + assert not isinstance(message, str) model_id = message["id"] # Attempt to delete with invalid engine_id @@ -304,6 +311,7 @@ def test_get_model_metadata_success(self, monkeypatch): model_name="metadata_test_model", ) assert success is True + assert not isinstance(message, str) model_id = message["id"] # Retrieve metadata @@ -346,6 +354,7 @@ def test_get_model_metadata_multiple(self, monkeypatch): model_name=f"multi_metadata_model_{i}", ) assert success is True + assert not isinstance(message, str) model_ids.append(message["id"]) # Retrieve and verify metadata for each model @@ -370,6 +379,7 @@ def test_get_model_metadata_invalid_engine(self, monkeypatch): model_name="invalid_engine_metadata_model", ) assert success is True + assert not isinstance(message, str) model_id = message["id"] # Attempt to get metadata with invalid engine_id @@ -393,6 +403,7 @@ def test_get_model_metadata_output_format(self, monkeypatch): model_name="format_metadata_model", ) assert success is True + assert not isinstance(message, str) model_id = message["id"] # Retrieve metadata @@ -422,6 +433,7 @@ def test_update_model_metadata_success(self, monkeypatch): model_name="update_test_model", ) assert success is True + assert not isinstance(message, str) model_id = message["id"] # Update the model metadata @@ -466,6 +478,7 @@ def test_update_model_metadata_invalid_engine(self, monkeypatch): model_name="invalid_engine_update_model", ) assert success is True + assert not isinstance(message, str) model_id = message["id"] # Try to update with invalid engine @@ -494,6 +507,7 @@ def test_update_model_metadata_ignores_unknown_fields(self, monkeypatch): model_name="unknown_fields_model", ) assert success is True + assert not isinstance(message, str) model_id = message["id"] # Update with unknown field @@ -532,6 +546,7 @@ def test_create_model_with_directory_source(self, monkeypatch): ) assert success is True + assert not isinstance(message, str) assert message["name"] == "model_from_source" # Verify source files were copied @@ -562,6 +577,7 @@ def test_create_model_with_zip_source(self, monkeypatch): ) assert success is True + assert not isinstance(message, str) assert message["name"] == "model_from_zip" # Verify zip was copied as entrypoint.zip diff --git a/uv.lock b/uv.lock index 78c3ae6..f5e8ee3 100644 --- a/uv.lock +++ b/uv.lock @@ -2985,6 +2985,7 @@ dev = [ { name = "pytest-cov" }, { name = "pytest-qt" }, { name = "ruff" }, + { name = "types-pyyaml" }, ] docs = [ { name = "pdoc" }, @@ -3023,6 +3024,7 @@ dev = [ { name = "pytest-cov", specifier = ">=7.0.0" }, { name = "pytest-qt", specifier = ">=4.4.0" }, { name = "ruff", specifier = ">=0.14.0" }, + { name = "types-pyyaml", specifier = ">=6.0.0" }, ] docs = [{ name = "pdoc", specifier = ">=16.0.0" }] installation = [{ name = "pyinstaller", specifier = ">=5.10.1" }] @@ -4074,6 +4076,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" }, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20260408" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/74/73/b759b1e413c31034cc01ecdfb96b38115d0ab4db55a752a3929f0cd449fd/types_pyyaml-6.0.12.20260408.tar.gz", hash = "sha256:92a73f2b8d7f39ef392a38131f76b970f8c66e4c42b3125ae872b7c93b556307", size = 17735, upload-time = "2026-04-08T04:30:50.974Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/f0/c391068b86abb708882c6d75a08cd7d25b2c7227dab527b3a3685a3c635b/types_pyyaml-6.0.12.20260408-py3-none-any.whl", hash = "sha256:fbc42037d12159d9c801ebfcc79ebd28335a7c13b08a4cfbc6916df78fee9384", size = 20339, upload-time = "2026-04-08T04:30:50.113Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0"