diff --git a/_frozen_patch.py b/_frozen_patch.py index 9729a1d..b3b3148 100644 --- a/_frozen_patch.py +++ b/_frozen_patch.py @@ -1,33 +1,34 @@ """ Runtime patch to disable problematic inspect operations in frozen apps """ + import sys -if getattr(sys, 'frozen', False): +if getattr(sys, "frozen", False): # We're running in a PyInstaller bundle import inspect - + # Patch inspect.getsource to return empty string instead of raising _original_getsource = inspect.getsource - + def _patched_getsource(object): try: return _original_getsource(object) except (OSError, TypeError): # Return a dummy source code return "# Source code not available in frozen application\npass\n" - + inspect.getsource = _patched_getsource - + # Patch getsourcelines similarly _original_getsourcelines = inspect.getsourcelines - + def _patched_getsourcelines(object): try: return _original_getsourcelines(object) except (OSError, TypeError): return (["# Source code not available\n", "pass\n"], 0) - + inspect.getsourcelines = _patched_getsourcelines - + print("[PATCH] Disabled source code inspection for frozen app") diff --git a/build.py b/build.py index 43272dc..d9b2da7 100644 --- a/build.py +++ b/build.py @@ -1,8 +1,8 @@ import argparse import os -import sys import shutil import subprocess +import sys from pathlib import Path """ @@ -17,200 +17,236 @@ def codesign_macos_app(app_path): """Ad-hoc code sign the macOS app bundle""" - print(f"[macOS] Code signing app bundle...") - + print("[macOS] Code signing app bundle...") + # Find all dylibs and executables to sign files_to_sign = [] - + internal_dir = app_path / "_internal" if internal_dir.exists(): for dylib in internal_dir.glob("*.dylib"): files_to_sign.append(dylib) for so in internal_dir.glob("**/*.so"): files_to_sign.append(so) - + executable = app_path / "VoxKit" if executable.exists(): files_to_sign.append(executable) - + # Sign each file with ad-hoc signature for file_path in files_to_sign: print(f"[macOS] Signing {file_path.name}...") - result = subprocess.run([ - "codesign", - "--force", - "--deep", - "--sign", "-", # Ad-hoc signature - str(file_path) - ], capture_output=True, text=True) - + result = subprocess.run( + [ + "codesign", + "--force", + "--deep", + "--sign", + "-", # Ad-hoc signature + str(file_path), + ], + capture_output=True, + text=True, + ) + if result.returncode != 0: print(f"[WARNING] Failed to sign {file_path.name}: {result.stderr}") - - print(f"[macOS] Code signing complete") + + print("[macOS] Code signing complete") return True def fix_macos_dylib_paths(app_path, python_lib_source): """Fix dylib paths for macOS .app bundles""" - print(f"[macOS] Fixing dynamic library paths...") - + print("[macOS] Fixing dynamic library paths...") + internal_dir = app_path / "_internal" executable = app_path / "VoxKit" - + if not internal_dir.exists(): print(f"[ERROR] _internal directory not found at {internal_dir}") return False - + # Copy Python shared library if it exists python_lib_dest = internal_dir / "libpython3.11.dylib" - + if python_lib_source.exists() and not python_lib_dest.exists(): print(f"[macOS] Copying Python library from {python_lib_source}") shutil.copy2(python_lib_source, python_lib_dest) - + if python_lib_dest.exists(): print(f"[macOS] Fixing library ID for {python_lib_dest.name}") # Make writable os.chmod(python_lib_dest, 0o755) # Change library ID to use @loader_path - subprocess.run([ - "install_name_tool", - "-id", "@loader_path/libpython3.11.dylib", - str(python_lib_dest) - ], check=False) - + subprocess.run( + ["install_name_tool", "-id", "@loader_path/libpython3.11.dylib", str(python_lib_dest)], + check=False, + ) + if executable.exists(): - print(f"[macOS] Updating executable to reference bundled Python library") + print("[macOS] Updating executable to reference bundled Python library") os.chmod(executable, 0o755) # Update executable to look for library relative to itself - subprocess.run([ - "install_name_tool", - "-change", str(python_lib_source), - "@loader_path/../_internal/libpython3.11.dylib", - str(executable) - ], check=False) - - print(f"[macOS] Dylib path fixing complete") + subprocess.run( + [ + "install_name_tool", + "-change", + str(python_lib_source), + "@loader_path/../_internal/libpython3.11.dylib", + str(executable), + ], + check=False, + ) + + print("[macOS] Dylib path fixing complete") return True def build(args): - try: - import PyInstaller.__main__ as pyi_main - except Exception: - print("PyInstaller not found. Install it with: pip install pyinstaller") - sys.exit(1) - - opts = [] - - # Basic options - if args.name: - opts.append(f'--name={args.name}') - - # macOS: Always use onedir mode for .app bundles - if sys.platform == 'darwin' and args.windowed: - print("[macOS] Using onedir mode (required for .app bundles)") - # Don't add --onefile - elif args.onefile: - opts.append('--onefile') - - if args.windowed: - opts.append('--windowed') - if args.clean: - opts.append('--clean') - if args.distpath: - opts.append(f'--distpath={args.distpath}') - if args.workpath: - opts.append(f'--workpath={args.workpath}') - if args.specpath: - opts.append(f'--specpath={args.specpath}') - if args.icon: - opts.append(f'--icon={args.icon}') - - # Hidden imports - default_hidden = [ - 'typeguard', - 'inflect', - 'g2p_en', - 'speechbrain', - 'speechbrain.utils', - # Engine modules that need to be explicitly included - 'voxkit.engines._w2tg_engine', - 'voxkit.engines._whisperx_engine', - 'voxkit.engines.mfa_engine', - 'PyQt6.QtCore', - 'PyQt6.QtGui', - 'PyQt6.QtWidgets', - 'PyQt6.QtSvg', - 'PyQt6.QtSvgWidgets' - ] - for hi in default_hidden + args.hidden_import: - opts.append(f'--hidden-import={hi}') - - # Add hooks directory if it exists - hooks_dir = Path(__file__).parent / "hooks" - if hooks_dir.exists(): - opts.append(f'--additional-hooks-dir={hooks_dir}') - - # Add data - sep = ';' if os.name == 'nt' else ':' - for ad in args.add_data: - if sep in ad: - opts.append(f'--add-data={ad}') - else: - other = ';' if sep == ':' else ':' - if other in ad: - src, dest = ad.split(other, 1) - opts.append(f'--add-data={src}{sep}{dest}') - else: - opts.append(f'--add-data={ad}{sep}{os.path.basename(ad)}') - - # Entry script is last - opts.append(args.entry) - - print("Running PyInstaller with options:", opts) - pyi_main.run(opts) - - # Post-build: Fix macOS dylib paths and code sign - if sys.platform == 'darwin': - print("\n[macOS] Running post-build fixes...") - dist_path = Path(args.distpath) if args.distpath else Path("dist") - app_path = dist_path / args.name - - # Find Python shared library - python_lib = Path(sys.base_prefix) / "lib" / "libpython3.11.dylib" - - if app_path.exists(): - fix_macos_dylib_paths(app_path, python_lib) - codesign_macos_app(app_path) - print(f"\n✅ Build complete: {app_path}") - print(f" Run with: ./dist/{args.name}/{args.name}") - else: - print(f"\n⚠️ Expected build output not found at {app_path}") + try: + import PyInstaller.__main__ as pyi_main + except Exception: + print("PyInstaller not found. Install it with: pip install pyinstaller") + sys.exit(1) + + opts = [] + + # Basic options + if args.name: + opts.append(f"--name={args.name}") + + # macOS: Always use onedir mode for .app bundles + if sys.platform == "darwin" and args.windowed: + print("[macOS] Using onedir mode (required for .app bundles)") + # Don't add --onefile + elif args.onefile: + opts.append("--onefile") + + if args.windowed: + opts.append("--windowed") + if args.clean: + opts.append("--clean") + if args.distpath: + opts.append(f"--distpath={args.distpath}") + if args.workpath: + opts.append(f"--workpath={args.workpath}") + if args.specpath: + opts.append(f"--specpath={args.specpath}") + if args.icon: + opts.append(f"--icon={args.icon}") + + # Hidden imports + default_hidden = [ + "typeguard", + "inflect", + "g2p_en", + "speechbrain", + "speechbrain.utils", + # Engine modules that need to be explicitly included + "voxkit.engines._w2tg_engine", + "voxkit.engines._whisperx_engine", + "voxkit.engines.mfa_engine", + "PyQt6.QtCore", + "PyQt6.QtGui", + "PyQt6.QtWidgets", + "PyQt6.QtSvg", + "PyQt6.QtSvgWidgets", + ] + for hi in default_hidden + args.hidden_import: + opts.append(f"--hidden-import={hi}") + + # Add hooks directory if it exists + hooks_dir = Path(__file__).parent / "hooks" + if hooks_dir.exists(): + opts.append(f"--additional-hooks-dir={hooks_dir}") + + # Add data + sep = ";" if os.name == "nt" else ":" + for ad in args.add_data: + if sep in ad: + opts.append(f"--add-data={ad}") + else: + other = ";" if sep == ":" else ":" + if other in ad: + src, dest = ad.split(other, 1) + opts.append(f"--add-data={src}{sep}{dest}") + else: + opts.append(f"--add-data={ad}{sep}{os.path.basename(ad)}") + + # Entry script is last + opts.append(args.entry) + + print("Running PyInstaller with options:", opts) + pyi_main.run(opts) + + # Post-build: Fix macOS dylib paths and code sign + if sys.platform == "darwin": + print("\n[macOS] Running post-build fixes...") + dist_path = Path(args.distpath) if args.distpath else Path("dist") + app_path = dist_path / args.name + + # Find Python shared library + python_lib = Path(sys.base_prefix) / "lib" / "libpython3.11.dylib" + + if app_path.exists(): + fix_macos_dylib_paths(app_path, python_lib) + codesign_macos_app(app_path) + print(f"\n✅ Build complete: {app_path}") + print(f" Run with: ./dist/{args.name}/{args.name}") + else: + print(f"\n⚠️ Expected build output not found at {app_path}") + def main(): - parser = argparse.ArgumentParser(prog="build.py", description="Build a standalone executable using PyInstaller") - sub = parser.add_subparsers(dest="command", required=True) - - build_p = sub.add_parser("build", help="Create executable with PyInstaller") - build_p.add_argument("--entry", "-e", required=True, help="Path to the entry-point python file") - build_p.add_argument("--name", "-n", default="VoxKit", help="Name of the generated executable") - build_p.add_argument("--onefile", action="store_true", default=True, help="Produce a single-file executable (default enabled)") - build_p.add_argument("--no-onefile", action="store_false", dest="onefile", help="Disable onefile mode") - build_p.add_argument("--windowed", "-w", action="store_true", help="Windowed/GUI app (no console)") - build_p.add_argument("--icon", help="Path to icon file (.ico/.icns)") - build_p.add_argument("--distpath", help="Where to put the bundled app") - build_p.add_argument("--workpath", help="Where to put build files (PyInstaller work path)") - build_p.add_argument("--specpath", help="Where to put the generated .spec file") - build_p.add_argument("--clean", action="store_true", help="Clean PyInstaller cache and remove temporary files") - build_p.add_argument("--add-data", "-a", action="append", default=[], help="Additional data to bundle. Format src:dest (POSIX) or src;dest (Windows). Can be passed multiple times.") - build_p.add_argument("--hidden-import", "-i", action="append", default=[], help="Hidden imports to pass to PyInstaller") - - args = parser.parse_args() - - if args.command == "build": - build(args) + parser = argparse.ArgumentParser( + prog="build.py", description="Build a standalone executable using PyInstaller" + ) + sub = parser.add_subparsers(dest="command", required=True) + + build_p = sub.add_parser("build", help="Create executable with PyInstaller") + build_p.add_argument("--entry", "-e", required=True, help="Path to the entry-point python file") + build_p.add_argument("--name", "-n", default="VoxKit", help="Name of the generated executable") + build_p.add_argument( + "--onefile", + action="store_true", + default=True, + help="Produce a single-file executable (default enabled)", + ) + build_p.add_argument( + "--no-onefile", action="store_false", dest="onefile", help="Disable onefile mode" + ) + build_p.add_argument( + "--windowed", "-w", action="store_true", help="Windowed/GUI app (no console)" + ) + build_p.add_argument("--icon", help="Path to icon file (.ico/.icns)") + build_p.add_argument("--distpath", help="Where to put the bundled app") + build_p.add_argument("--workpath", help="Where to put build files (PyInstaller work path)") + build_p.add_argument("--specpath", help="Where to put the generated .spec file") + build_p.add_argument( + "--clean", action="store_true", help="Clean PyInstaller cache and remove temporary files" + ) + build_p.add_argument( + "--add-data", + "-a", + action="append", + default=[], + help="Additional data to bundle. Format src:dest (POSIX) or src;dest (Windows). " + "Can be passed multiple times.", + ) + build_p.add_argument( + "--hidden-import", + "-i", + action="append", + default=[], + help="Hidden imports to pass to PyInstaller", + ) + + args = parser.parse_args() + + if args.command == "build": + build(args) + if __name__ == "__main__": - main() + main() diff --git a/gui.py b/gui.py index 245562f..1b74f1b 100644 --- a/gui.py +++ b/gui.py @@ -126,7 +126,6 @@ def update_active_tab_style(self, active_button): pipeline_widget.setStyleSheet(inactive_style) if manage_widget: manage_widget.setStyleSheet(active_style) - def open_datasets(self): """Switch to Datasets view""" diff --git a/hooks/hook-g2p_en.py b/hooks/hook-g2p_en.py index 50f8dc1..aee4626 100644 --- a/hooks/hook-g2p_en.py +++ b/hooks/hook-g2p_en.py @@ -2,7 +2,8 @@ g2p_en includes data files (checkpoint20.npz) that need to be bundled. """ + from PyInstaller.utils.hooks import collect_data_files # Collect all data files from g2p_en -datas = collect_data_files('g2p_en') +datas = collect_data_files("g2p_en") diff --git a/hooks/hook-pypllrcomputer.py b/hooks/hook-pypllrcomputer.py index ea6f5d9..45e2607 100644 --- a/hooks/hook-pypllrcomputer.py +++ b/hooks/hook-pypllrcomputer.py @@ -1,9 +1,10 @@ """PyInstaller hook for pypllrcomputer package. -pypllrcomputer includes data files in the models directory (*.pkl files) +pypllrcomputer includes data files in the models directory (*.pkl files) that need to be bundled. """ + from PyInstaller.utils.hooks import collect_data_files # Collect all data files from pypllrcomputer (includes models/*.pkl) -datas = collect_data_files('pypllrcomputer') +datas = collect_data_files("pypllrcomputer") diff --git a/hooks/hook-speechbrain.py b/hooks/hook-speechbrain.py index 39bba66..af16f7d 100644 --- a/hooks/hook-speechbrain.py +++ b/hooks/hook-speechbrain.py @@ -3,18 +3,19 @@ speechbrain needs to be extracted to the filesystem because code tries to access the package directory directly. We force extraction by collecting as data files. """ + import os -from pathlib import Path -from PyInstaller.utils.hooks import get_package_paths, collect_data_files + +from PyInstaller.utils.hooks import get_package_paths # Get speechbrain package location -pkg_base, pkg_dir = get_package_paths('speechbrain') +pkg_base, pkg_dir = get_package_paths("speechbrain") # Collect all Python files and data files as "datas" to force filesystem extraction datas = [] for root, dirs, files in os.walk(pkg_dir): for file in files: - if file.endswith(('.py', '.yaml', '.txt', '.json')): + if file.endswith((".py", ".yaml", ".txt", ".json")): src = os.path.join(root, file) # Calculate relative path from package base rel_path = os.path.relpath(root, pkg_base) @@ -22,4 +23,5 @@ # Also collect submodules as hidden imports so they're available for import from PyInstaller.utils.hooks import collect_submodules -hiddenimports = collect_submodules('speechbrain') + +hiddenimports = collect_submodules("speechbrain") diff --git a/hooks/hook-typeguard.py b/hooks/hook-typeguard.py index 375dfbc..258dd4b 100644 --- a/hooks/hook-typeguard.py +++ b/hooks/hook-typeguard.py @@ -2,10 +2,11 @@ PyInstaller hook for typeguard Disables runtime type checking in frozen applications """ + from PyInstaller.utils.hooks import collect_all # Collect everything from typeguard -datas, binaries, hiddenimports = collect_all('typeguard') +datas, binaries, hiddenimports = collect_all("typeguard") # Disable typeguard's runtime checking in frozen apps -excludedimports = [] +excludedimports: list[str] = [] diff --git a/main.py b/main.py index 5bdaa3f..e9badc2 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ -import sys import faulthandler -import os import multiprocessing +import os +import sys # CRITICAL: Must be at the top for frozen apps using multiprocessing if __name__ == "__main__": @@ -11,39 +11,37 @@ faulthandler.enable() # Apply patches for frozen (PyInstaller) environment -if getattr(sys, 'frozen', False): - import _frozen_patch - +if getattr(sys, "frozen", False): # Define the minimal required environment minimal_env = { - 'HOME': os.environ.get('HOME') or os.path.expanduser('~'), - 'USER': os.environ.get('USER') or os.getlogin(), - 'TMPDIR': os.environ.get('TMPDIR') or '/tmp', + "HOME": os.environ.get("HOME") or os.path.expanduser("~"), + "USER": os.environ.get("USER") or os.getlogin(), + "TMPDIR": os.environ.get("TMPDIR") or "/tmp", } - + # PyInstaller-specific: Add Qt plugin paths - if getattr(sys, '_MEIPASS', None): - bundle_dir = sys._MEIPASS - qt_plugins = os.path.join(bundle_dir, 'PyQt6', 'Qt6', 'plugins') + if getattr(sys, "_MEIPASS", None): + bundle_dir = sys._MEIPASS # type: ignore[attr-defined] + qt_plugins = os.path.join(bundle_dir, "PyQt6", "Qt6", "plugins") if os.path.exists(qt_plugins): - minimal_env['QT_PLUGIN_PATH'] = qt_plugins - - platform_plugins = os.path.join(bundle_dir, 'PyQt6', 'Qt6', 'plugins', 'platforms') + minimal_env["QT_PLUGIN_PATH"] = qt_plugins + + platform_plugins = os.path.join(bundle_dir, "PyQt6", "Qt6", "plugins", "platforms") if os.path.exists(platform_plugins): - minimal_env['QT_QPA_PLATFORM_PLUGIN_PATH'] = platform_plugins - + minimal_env["QT_QPA_PLATFORM_PLUGIN_PATH"] = platform_plugins + # Clear all environment variables os.environ.clear() - + # Set the minimal required ones for key, value in minimal_env.items(): if value: os.environ[key] = value - from gui import AlignmentGUI from PyQt6.QtWidgets import QApplication + from voxkit.storage.utils import get_storage_root @@ -52,14 +50,14 @@ def main(): try: storage_root = get_storage_root() print(f"[INFO] Storage root: {storage_root}") - + (storage_root / "computed-likelihoods").mkdir(parents=True, exist_ok=True) (storage_root / "custom-likelihoods").mkdir(parents=True, exist_ok=True) print("[INFO] Created required directories") except Exception as e: print(f"[WARNING] Could not create directories: {e}") # Continue anyway - directories will be created on-demand if needed - + app = QApplication(sys.argv) app.setStyle("Fusion") window = AlignmentGUI() @@ -70,5 +68,5 @@ def main(): if __name__ == "__main__": # Prevent multiprocessing from spawning new app windows in frozen builds multiprocessing.freeze_support() - multiprocessing.set_start_method('spawn', force=True) + multiprocessing.set_start_method("spawn", force=True) main() diff --git a/pyproject.toml b/pyproject.toml index a484dd5..ab42859 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,19 +4,19 @@ build-backend = "setuptools.build_meta" [project] name = "pypllr-gui" version = "0.1.0" -description = "Add your description here" +description = "AI/ML Research -> Clinical Applications (Speech Pathology)" readme = "README.md" requires-python = ">=3.11" license = {text = "MIT"} authors = [ - {name = "Your Name", email = "your.email@example.com"} + {name = "Beckett Frey", email = "bfrey6@wisc.edu"} ] dependencies = [ "pyqt6>=6.9.1", "torch==2.8.0", "torchaudio==2.8.0", - "pypllrcomputer @ git+https://github.com/BrainBehaviorAnalyticsLab/PyPLLRComputer.git", + "pypllrcomputer @ git+https://github.com/BrainBehaviorAnalyticsLab/PyPLLRComputer.git@48a0c934e75f73235ba5a002538b8588fd6697e7", "wav2textgrid @ git+https://github.com/pkadambi/Wav2TextGrid.git@8db3afd", "datasets>=4.3.0", "accelerate>=1.11.0", @@ -28,9 +28,6 @@ dependencies = [ ] [dependency-groups] -security = [ - "safety>=3.6.2", -] dev = [ "pre-commit>=4.3.0", "ruff>=0.14.0", @@ -84,6 +81,11 @@ fixable = ["ALL"] [tool.ruff.lint.per-file-ignores] "tests/**/*.py" = ["S101"] "scripts/**/*.py" = ["S101","S603"] # Allow subprocess in scripts +"src/voxkit/services/**/*.py" = ["S603", "S108"] # Allow subprocess and temp dirs in services +"build.py" = ["S103", "S603", "S607"] # Allow chmod and subprocess in build script +"hooks/**/*.py" = ["E402"] # Allow imports after code in hooks +"main.py" = ["E402", "S108"] # Allow imports after code and temp dir in main +"src/voxkit/gui/pages/pipeline/evaluation_stacker.py" = ["S603", "S607"] # Allow subprocess in eval [tool.ruff.lint.isort] known-first-party = ["Wav2TextGrid"] diff --git a/src/voxkit/config.py b/src/voxkit/config.py index 093ab45..69a8670 100644 --- a/src/voxkit/config.py +++ b/src/voxkit/config.py @@ -12,4 +12,3 @@ Mode = Literal["MFAENGINE", "W2TGENGINE"] HELP_URL = "http://localhost:3000/help" - diff --git a/src/voxkit/engines/__init__.py b/src/voxkit/engines/__init__.py index 9cb6398..6d7d184 100644 --- a/src/voxkit/engines/__init__.py +++ b/src/voxkit/engines/__init__.py @@ -57,9 +57,9 @@ from typing import List from .base import AlignmentEngine, ToolType - from .w2tg_engine import W2TGEngine + class EngineManager: """ Manager class for registered engines. @@ -98,7 +98,7 @@ def get_tool_providers(self, tool: ToolType) -> dict[str, AlignmentEngine]: # Singleton instance for unified export/interface -w2tg = W2TGEngine(id ="W2TGENGINE") +w2tg = W2TGEngine(id="W2TGENGINE") engines = EngineManager({w2tg.id: w2tg}) __all__ = ["engines"] diff --git a/src/voxkit/engines/base.py b/src/voxkit/engines/base.py index 319be14..854f96e 100644 --- a/src/voxkit/engines/base.py +++ b/src/voxkit/engines/base.py @@ -70,7 +70,7 @@ def __init__( ) self.human_readable_name = human_readable_name or self.__class__.__name__ self.id = id or self.__class__.__name__ - + @abstractmethod def align(self, dataset_id: str, model_id: str) -> None: """ @@ -85,7 +85,7 @@ def align(self, dataset_id: str, model_id: str) -> None: model_id: Identifier of the alignment model to use. """ raise NotImplementedError() - + @abstractmethod def train_aligner( self, audio_root: Path, textgrid_root: Path, base_model_id: str | None, new_model_id: str @@ -166,8 +166,8 @@ def _load_json(self, path: Path | str) -> dict: path = Path(path) if not path.exists(): - return None - + return {} + with open(path, "r", encoding="utf-8") as f: return json.load(f) @@ -201,7 +201,7 @@ def get_settings(self, tool_type: ToolType) -> dict: f"Settings path not given for tool type '{tool_type}' in this engine." ) settings = self._load_json(Path(get_storage_root() / cfg.store_file)) - + if tool_type == "train": if settings is None: # Get default settings if none exist diff --git a/src/voxkit/engines/mfa_engine.py b/src/voxkit/engines/mfa_engine.py index 49c9c8a..6804d21 100644 --- a/src/voxkit/engines/mfa_engine.py +++ b/src/voxkit/engines/mfa_engine.py @@ -45,13 +45,14 @@ def __init__(self, id: str | None = None): reference_url="https://montreal-forced-aligner.readthedocs.io/en/latest/", description=( "The Montreal Forced Aligner (MFA) is a tool for " - "aligning audio files to their corresponding transcripts with a focus on speech pathology. " + "aligning audio files to their corresponding transcripts with a focus on " + "speech pathology. " ), human_readable_name="MFA", id=id, ) - def align(self, audio_root: Path, output_root: Path, model_id: str) -> None: + def align(self, dataset_id: str, model_id: str) -> None: print(f"Aligning with MFA using model: {model_id}") pass # Implement the alignment logic using MFA here diff --git a/src/voxkit/engines/w2tg_engine.py b/src/voxkit/engines/w2tg_engine.py index 2825588..d7ec56b 100644 --- a/src/voxkit/engines/w2tg_engine.py +++ b/src/voxkit/engines/w2tg_engine.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from voxkit.gui.frameworks.settings_modal import ( @@ -6,7 +5,7 @@ FieldType, SettingsConfig, ) -from voxkit.storage import alignments, models, datasets +from voxkit.storage import alignments, datasets, models from Wav2TextGrid.wav2textgrid import align_dirs from Wav2TextGrid.wav2textgrid_train import train_aligner @@ -114,11 +113,11 @@ def align(self, dataset_id: str, model_id: str) -> None: dataset_id=dataset_id, ) print(f"Alignment creation result: {result}, message: {msg}") - + if result is False: print(f"Alignment creation failed: {msg}") return - + alignment_meta = msg dataset_meta = datasets.get_dataset_metadata(dataset_id) model_meta = models.get_model_metadata(self.id, model_id) @@ -141,16 +140,23 @@ def align(self, dataset_id: str, model_id: str) -> None: filetype=settings.get("file_type", "wav"), use_speaker_adaptation=settings.get("use_speaker_adaptation", False), ) - alignments.update_alignment(dataset_id=dataset_id, alignment_id=alignment_meta["id"], updates={"status": "completed"}) + alignments.update_alignment( + dataset_id=dataset_id, + alignment_id=alignment_meta["id"], + updates={"status": "completed"}, + ) except Exception as e: print(f"Alignment failed: {e}") - alignments.update_alignment(dataset_id=dataset_id, alignment_id=alignment_meta["id"], updates={"status": "failed"}) + alignments.update_alignment( + dataset_id=dataset_id, + alignment_id=alignment_meta["id"], + updates={"status": "failed"}, + ) def train_aligner( self, audio_root: Path, textgrid_root: Path, base_model_id: str | None, new_model_id: str ) -> None: - new_model_actual_id = None try: successs, message = models.create_model( @@ -160,7 +166,7 @@ def train_aligner( if not successs: raise ValueError(f"Failed to create model entry: {message}") - + model_meta = message model_path = Path(model_meta["model_path"]) data_path = Path(model_meta["data_path"]) @@ -169,19 +175,22 @@ def train_aligner( new_model_actual_id = model_meta["id"] settings = self.get_settings("train") - base_model_path = models.get_model_metadata(engine_id=self.id, model_id=base_model_id)["model_path"] if base_model_id else None + base_model_path = ( + models.get_model_metadata(engine_id=self.id, model_id=base_model_id)["model_path"] + if base_model_id + else None + ) if base_model_path is None: - raise ValueError( - f"Invalid base model specified: {base_model_id}. " - ) - print(f'Args received for train_aligner: ' - f'audio_root={audio_root}, textgrid_root={textgrid_root}, ' - f'base_model_path={base_model_path}, model_path={model_path}, ' - f'eval_path={eval_path}, new_model_id={new_model_id}, ' - f'ntrain_epochs={settings.get("epochs", 50)}') - - + raise ValueError(f"Invalid base model specified: {base_model_id}. ") + print( + f"Args received for train_aligner: " + f"audio_root={audio_root}, textgrid_root={textgrid_root}, " + f"base_model_path={base_model_path}, model_path={model_path}, " + f"eval_path={eval_path}, new_model_id={new_model_id}, " + f"ntrain_epochs={settings.get('epochs', 50)}" + ) + print(f"Training aligner with settings: {settings}") print(f"Using base model path: {base_model_path}") train_aligner( diff --git a/src/voxkit/gui/components/column_dropdown.py b/src/voxkit/gui/components/column_dropdown.py index d20714a..e77e089 100644 --- a/src/voxkit/gui/components/column_dropdown.py +++ b/src/voxkit/gui/components/column_dropdown.py @@ -2,19 +2,19 @@ A custom ComboBox widget that supports multiple columns in its dropdown list. Single selection only. -""" +""" import sys -from PyQt6.QtWidgets import ( - QApplication, QWidget, QVBoxLayout, QComboBox, QTableView, QHeaderView -) -from PyQt6.QtGui import QStandardItemModel, QStandardItem + from PyQt6.QtCore import Qt +from PyQt6.QtGui import QStandardItem, QStandardItemModel +from PyQt6.QtWidgets import QApplication, QComboBox, QHeaderView, QTableView, QVBoxLayout, QWidget + class MultiColumnComboBox(QComboBox): def __init__(self, parent=None): super().__init__(parent) - + # Use a table view for the dropdown popup to show multiple columns table_view = QTableView() table_view.setSelectionBehavior(QTableView.SelectionBehavior.SelectRows) @@ -22,31 +22,32 @@ def __init__(self, parent=None): table_view.verticalHeader().hide() table_view.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Stretch) table_view.setShowGrid(False) - + self.setView(table_view) - + # Optional: make the popup wider self.view().setMinimumWidth(400) def set_data(self, rows, headers=None, placeholder=None): """ Populate the combo box with multi-column data. - - :param rows: List of dicts with 'id' key and 'data' key containing tuple/list of column values + + :param rows: List of dicts with 'id' key and 'data' key containing + tuple/list of column values e.g., [{'id': 1, 'data': ("Alice", 30, "New York")}, ...] :param headers: Optional list of column headers :param placeholder: Optional placeholder text to show when no item is selected """ - num_cols = len(rows[0]['data']) if rows else 0 + num_cols = len(rows[0]["data"]) if rows else 0 model = QStandardItemModel(len(rows), num_cols) - + if headers: model.setHorizontalHeaderLabels(headers) - + for row_idx, row in enumerate(rows): - row_id = row['id'] - row_data = row['data'] - + row_id = row["id"] + row_data = row["data"] + for col_idx, value in enumerate(row_data): item = QStandardItem(str(value)) item.setFlags(Qt.ItemFlag.ItemIsEnabled | Qt.ItemFlag.ItemIsSelectable) @@ -54,45 +55,46 @@ def set_data(self, rows, headers=None, placeholder=None): if col_idx == 0: item.setData(row_id, Qt.ItemDataRole.UserRole) model.setItem(row_idx, col_idx, item) - + self.setModel(model) - + # The line edit (current visible item) will show the first column by default self.setModelColumn(0) - + # Set placeholder and clear selection if provided if placeholder: self.setPlaceholderText(placeholder) self.setCurrentIndex(-1) - + def current_id(self): """Get the ID of the currently selected row.""" index = self.model().index(self.currentIndex(), 0) if not index.isValid(): return None - + return self.model().data(index, Qt.ItemDataRole.UserRole) - + + if __name__ == "__main__": app = QApplication(sys.argv) - + window = QWidget() layout = QVBoxLayout(window) - + combo = MultiColumnComboBox() data = [ - {'id': 1, 'data': ("Alice", 30, "New York")}, - {'id': 2, 'data': ("Bob", 25, "Los Angeles")}, - {'id': 3, 'data': ("Charlie", 35, "Chicago")}, + {"id": 1, "data": ("Alice", 30, "New York")}, + {"id": 2, "data": ("Bob", 25, "Los Angeles")}, + {"id": 3, "data": ("Charlie", 35, "Chicago")}, ] headers = ["Name", "Age", "City"] combo.set_data(data, headers) - + layout.addWidget(combo) - + window.setWindowTitle("Multi-Column ComboBox Example") window.resize(500, 200) window.show() - + sys.exit(app.exec()) diff --git a/src/voxkit/gui/components/csv_visual.py b/src/voxkit/gui/components/csv_visual.py index 6ff7f3e..5cf7098 100644 --- a/src/voxkit/gui/components/csv_visual.py +++ b/src/voxkit/gui/components/csv_visual.py @@ -8,7 +8,7 @@ import csv import os from pathlib import Path -from typing import Optional +from typing import Any, Optional from PyQt6.QtCore import Qt from PyQt6.QtWidgets import ( @@ -30,33 +30,33 @@ class CSVVisualizationWidget(QWidget): """ A themed widget for visualizing CSV files with search and export functionality. - + Args: csv_path: Optional path to CSV file to load on initialization parent: Parent widget - + Example: >>> csv_widget = CSVVisualizationWidget("/path/to/data.csv") >>> layout.addWidget(csv_widget) """ - + def __init__(self, csv_path: Optional[str] = None, parent=None): super().__init__(parent) self.csv_path = csv_path - self.data = [] - self.headers = [] - self.filtered_data = [] - + self.data: list[Any] = [] + self.headers: list[str] = [] + self.filtered_data: list[Any] = [] + self._init_ui() - + if csv_path: self.load_csv(csv_path) - + def _init_ui(self): """Initialize the user interface""" layout = QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) - + # Main container container = QGroupBox("CSV Viewer") container.setStyleSheet(""" @@ -75,23 +75,23 @@ def _init_ui(self): padding: 0 5px; } """) - + container_layout = QVBoxLayout() - + # Top bar with file info and controls top_bar = self._create_top_bar() container_layout.addLayout(top_bar) - + # Search bar search_bar = self._create_search_bar() container_layout.addLayout(search_bar) - + # Table widget self.table = QTableWidget() self.table.setAlternatingRowColors(True) self._style_table() container_layout.addWidget(self.table) - + # Stats bar self.stats_label = QLabel("No data loaded") self.stats_label.setStyleSheet(""" @@ -103,14 +103,14 @@ def _init_ui(self): } """) container_layout.addWidget(self.stats_label) - + container.setLayout(container_layout) layout.addWidget(container) - + def _create_top_bar(self) -> QHBoxLayout: """Create the top bar with file info and action buttons""" top_layout = QHBoxLayout() - + # File path label self.file_label = QLabel("No file loaded") self.file_label.setStyleSheet(""" @@ -121,16 +121,16 @@ def _create_top_bar(self) -> QHBoxLayout: } """) top_layout.addWidget(self.file_label) - + top_layout.addStretch() - + # Load button load_btn = QPushButton("Load CSV") load_btn.setFixedWidth(100) load_btn.setStyleSheet(self._button_style()) load_btn.clicked.connect(self.browse_csv) top_layout.addWidget(load_btn) - + # Export button self.export_btn = QPushButton("Export") self.export_btn.setFixedWidth(80) @@ -138,7 +138,7 @@ def _create_top_bar(self) -> QHBoxLayout: self.export_btn.setStyleSheet(self._button_style()) self.export_btn.clicked.connect(self.export_filtered) top_layout.addWidget(self.export_btn) - + # Refresh button self.refresh_btn = QPushButton("Refresh") self.refresh_btn.setFixedWidth(80) @@ -146,17 +146,17 @@ def _create_top_bar(self) -> QHBoxLayout: self.refresh_btn.setStyleSheet(self._button_style()) self.refresh_btn.clicked.connect(self.refresh) top_layout.addWidget(self.refresh_btn) - + return top_layout - + def _create_search_bar(self) -> QHBoxLayout: """Create the search/filter bar""" search_layout = QHBoxLayout() - + search_label = QLabel("Search:") search_label.setStyleSheet("color: #2c3e50; font-weight: 500;") search_layout.addWidget(search_label) - + self.search_input = QLineEdit() self.search_input.setPlaceholderText("Filter rows by any column value...") self.search_input.textChanged.connect(self.filter_data) @@ -172,16 +172,16 @@ def _create_search_bar(self) -> QHBoxLayout: } """) search_layout.addWidget(self.search_input, stretch=1) - + # Clear search button clear_btn = QPushButton("Clear") clear_btn.setFixedWidth(60) clear_btn.setStyleSheet(self._button_style()) clear_btn.clicked.connect(self.clear_search) search_layout.addWidget(clear_btn) - + return search_layout - + def _style_table(self): """Apply theme-matching styles to the table""" self.table.setStyleSheet(""" @@ -216,7 +216,7 @@ def _style_table(self): background-color: #f8f9fa; } """) - + def _button_style(self) -> str: """Return consistent button styling""" return """ @@ -241,82 +241,67 @@ def _button_style(self) -> str: border-color: #e0e0e0; } """ - + def load_csv(self, csv_path: str) -> bool: """ Load a CSV file and display it in the table. - + Args: csv_path: Path to the CSV file - + Returns: True if successful, False otherwise """ if not os.path.exists(csv_path): - QMessageBox.warning( - self, - "File Not Found", - f"The file '{csv_path}' does not exist." - ) + QMessageBox.warning(self, "File Not Found", f"The file '{csv_path}' does not exist.") return False - + try: - with open(csv_path, 'r', encoding='utf-8') as f: + with open(csv_path, "r", encoding="utf-8") as f: reader = csv.reader(f) rows = list(reader) - + if not rows: - QMessageBox.warning( - self, - "Empty File", - "The CSV file is empty." - ) + QMessageBox.warning(self, "Empty File", "The CSV file is empty.") return False - + self.headers = rows[0] self.data = rows[1:] self.filtered_data = self.data.copy() self.csv_path = csv_path - + self._populate_table() self._update_ui_state() - + return True - + except Exception as e: - QMessageBox.critical( - self, - "Error Loading CSV", - f"Failed to load CSV file:\n{str(e)}" - ) + QMessageBox.critical(self, "Error Loading CSV", f"Failed to load CSV file:\n{str(e)}") return False - + def _populate_table(self): """Populate the table with current filtered data""" self.table.clear() self.table.setRowCount(len(self.filtered_data)) self.table.setColumnCount(len(self.headers)) self.table.setHorizontalHeaderLabels(self.headers) - + # Populate cells for row_idx, row_data in enumerate(self.filtered_data): for col_idx, cell_data in enumerate(row_data): item = QTableWidgetItem(str(cell_data)) item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable) # Read-only self.table.setItem(row_idx, col_idx, item) - + # Auto-resize columns header = self.table.horizontalHeader() for i in range(len(self.headers)): header.setSectionResizeMode(i, QHeaderView.ResizeMode.ResizeToContents) - + # Make last column stretch if len(self.headers) > 0: - header.setSectionResizeMode( - len(self.headers) - 1, - QHeaderView.ResizeMode.Stretch - ) - + header.setSectionResizeMode(len(self.headers) - 1, QHeaderView.ResizeMode.Stretch) + def _update_ui_state(self): """Update UI elements based on current state""" if self.csv_path: @@ -324,26 +309,23 @@ def _update_ui_state(self): self.file_label.setText(f"File: {filename}") self.export_btn.setEnabled(True) self.refresh_btn.setEnabled(True) - + total_rows = len(self.data) filtered_rows = len(self.filtered_data) cols = len(self.headers) - + if filtered_rows < total_rows: self.stats_label.setText( - f"Showing {filtered_rows} of {total_rows} rows | " - f"{cols} columns" + f"Showing {filtered_rows} of {total_rows} rows | {cols} columns" ) else: - self.stats_label.setText( - f"{total_rows} rows × {cols} columns" - ) + self.stats_label.setText(f"{total_rows} rows × {cols} columns") else: self.file_label.setText("No file loaded") self.export_btn.setEnabled(False) self.refresh_btn.setEnabled(False) self.stats_label.setText("No data loaded") - + def filter_data(self, search_text: str): """Filter table data based on search text""" if not search_text: @@ -351,95 +333,76 @@ def filter_data(self, search_text: str): else: search_lower = search_text.lower() self.filtered_data = [ - row for row in self.data - if any(search_lower in str(cell).lower() for cell in row) + row for row in self.data if any(search_lower in str(cell).lower() for cell in row) ] - + self._populate_table() self._update_ui_state() - + def clear_search(self): """Clear the search filter""" self.search_input.clear() - + def browse_csv(self): """Open file dialog to select a CSV file""" file_path, _ = QFileDialog.getOpenFileName( - self, - "Select CSV File", - "", - "CSV Files (*.csv);;All Files (*)" + self, "Select CSV File", "", "CSV Files (*.csv);;All Files (*)" ) - + if file_path: self.load_csv(file_path) - + def refresh(self): """Reload the current CSV file""" if self.csv_path: self.load_csv(self.csv_path) self.search_input.clear() - + def export_filtered(self): """Export the currently filtered data to a new CSV file""" if not self.filtered_data: - QMessageBox.information( - self, - "No Data", - "No data to export." - ) + QMessageBox.information(self, "No Data", "No data to export.") return - + file_path, _ = QFileDialog.getSaveFileName( - self, - "Export Filtered Data", - "filtered_data.csv", - "CSV Files (*.csv)" + self, "Export Filtered Data", "filtered_data.csv", "CSV Files (*.csv)" ) - + if not file_path: return - + try: - with open(file_path, 'w', newline='', encoding='utf-8') as f: + with open(file_path, "w", newline="", encoding="utf-8") as f: writer = csv.writer(f) writer.writerow(self.headers) writer.writerows(self.filtered_data) - - QMessageBox.information( - self, - "Export Successful", - f"Data exported to:\n{file_path}" - ) + + QMessageBox.information(self, "Export Successful", f"Data exported to:\n{file_path}") except Exception as e: - QMessageBox.critical( - self, - "Export Failed", - f"Failed to export data:\n{str(e)}" - ) - + QMessageBox.critical(self, "Export Failed", f"Failed to export data:\n{str(e)}") + def get_data(self) -> list: """ Get the current filtered data. - + Returns: List of rows (each row is a list of values) """ return self.filtered_data.copy() - + def get_headers(self) -> list: """ Get the CSV headers. - + Returns: List of column header names """ return self.headers.copy() - + def set_data(self, headers: list, data: list): """ Programmatically set data without loading from file. - + Args: headers: List of column header names data: List of rows (each row is a list of values) @@ -448,6 +411,6 @@ def set_data(self, headers: list, data: list): self.data = data self.filtered_data = data.copy() self.csv_path = None - + self._populate_table() self._update_ui_state() diff --git a/src/voxkit/gui/components/horizontal_button_selector.py b/src/voxkit/gui/components/horizontal_button_selector.py index 728be7c..62d24da 100644 --- a/src/voxkit/gui/components/horizontal_button_selector.py +++ b/src/voxkit/gui/components/horizontal_button_selector.py @@ -100,7 +100,9 @@ def _force_selected_state(self, is_selected: bool): padding: 6px 12px; }} QPushButton:hover {{ - background-color: rgba({gray_value + 20}, {gray_value + 20}, {gray_value + 20}, {opacity + 0.2}); + background-color: rgba( + {gray_value + 20}, {gray_value + 20}, {gray_value + 20}, {opacity + 0.2} + ); }} """) @@ -352,7 +354,6 @@ def select_button(self, index: int): if not (0 <= index < len(self.buttons)) or self.is_animating: return - old_index = self.current_index self.current_index = index # Scroll to center @@ -465,7 +466,6 @@ def _update_button_scales(self): # Check if the closest button changed - if so, update selection # BUT only if we're not currently animating (to prevent double-triggering on clicks) if closest_index != -1 and closest_index != self.current_index and not self.is_animating: - old_index = self.current_index self.current_index = closest_index # Emit signal diff --git a/src/voxkit/gui/components/huggingface_button.py b/src/voxkit/gui/components/huggingface_button.py index 6f07cf1..1b168b0 100644 --- a/src/voxkit/gui/components/huggingface_button.py +++ b/src/voxkit/gui/components/huggingface_button.py @@ -2,10 +2,8 @@ HuggingFace branded button component with logo """ -from PyQt6.QtWidgets import QPushButton, QHBoxLayout, QWidget, QLabel -from PyQt6.QtCore import Qt, QSize -from PyQt6.QtGui import QPixmap, QIcon -from PyQt6.QtSvgWidgets import QSvgWidget +from PyQt6.QtCore import Qt +from PyQt6.QtWidgets import QHBoxLayout, QLabel, QPushButton, QWidget class HuggingFaceButton(QPushButton): @@ -91,7 +89,7 @@ def _setup_ui(self): """Set up the icon button UI""" self.setText("🤗") self.setFixedSize(38, 38) - + self.setStyleSheet(""" QPushButton { background: qlineargradient(x1:0, y1:0, x2:0, y2:1, @@ -125,6 +123,7 @@ def _setup_ui(self): # Example usage if __name__ == "__main__": import sys + from PyQt6.QtWidgets import QApplication, QVBoxLayout, QWidget app = QApplication(sys.argv) diff --git a/src/voxkit/gui/components/toggle_switch.py b/src/voxkit/gui/components/toggle_switch.py index 1a43232..e4d7bb8 100644 --- a/src/voxkit/gui/components/toggle_switch.py +++ b/src/voxkit/gui/components/toggle_switch.py @@ -1,4 +1,10 @@ -from PyQt6.QtCore import QEasingCurve, QPropertyAnimation, QRect, Qt, pyqtProperty +from PyQt6.QtCore import ( # type: ignore[attr-defined] + QEasingCurve, + QPropertyAnimation, + QRect, + Qt, + pyqtProperty, +) from PyQt6.QtGui import QBrush, QColor, QPainter, QPen from PyQt6.QtWidgets import QWidget @@ -18,11 +24,11 @@ def __init__(self, checked=False, parent=None): self._thumb_pos = self.width() - self.height() if self._checked else 0 # --- Expose to Qt's meta-system --- - @pyqtProperty(float) + @pyqtProperty(float) # type: ignore[no-redef] def thumb_pos(self): return self._thumb_pos - @thumb_pos.setter + @thumb_pos.setter # type: ignore[no-redef] def thumb_pos(self, pos): self._thumb_pos = pos self.update() diff --git a/src/voxkit/gui/frameworks/_______/categorical_list.py b/src/voxkit/gui/frameworks/_______/categorical_list.py index 7dc20a7..d04dea2 100644 --- a/src/voxkit/gui/frameworks/_______/categorical_list.py +++ b/src/voxkit/gui/frameworks/_______/categorical_list.py @@ -15,6 +15,7 @@ from .styles import Buttons, Colors, Labels + class CategoryListItem(QWidget): """Custom widget for each list item with checkbox, date, and info button""" @@ -374,7 +375,7 @@ def init_ui(self): """) self.delete_btn.clicked.connect(self.on_delete) action_layout.addWidget(self.delete_btn) - + group.setLayout(action_layout) main_layout.addWidget(group) @@ -465,7 +466,7 @@ def get_selected_items(self): if widget and isinstance(widget, CategoryListItem) and widget.is_checked(): selected[widget.item_key] = widget.item_data return selected - + def set_items(self, mode, items): """Set the items for a specific category""" if mode not in self.data: @@ -501,7 +502,7 @@ def on_export(self): f"Exporting {len(selected_items)} item(s) to '{folder_name}'", QMessageBox.StandardButton.Ok, ) - + def on_delete(self): """Handle delete button click""" selected_items = self.get_selected_items() @@ -519,7 +520,8 @@ def on_delete(self): reply = QMessageBox.question( self, "Confirm Deletion", - f"Are you sure you want to delete {len(selected_items.keys())} item(s)?\n\nThis action cannot be undone.", + f"Are you sure you want to delete {len(selected_items.keys())} item(s)?\n\n" + "This action cannot be undone.", QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.No, ) diff --git a/src/voxkit/gui/frameworks/categorical_table/__init__.py b/src/voxkit/gui/frameworks/categorical_table/__init__.py index c003fd9..7ffdf45 100644 --- a/src/voxkit/gui/frameworks/categorical_table/__init__.py +++ b/src/voxkit/gui/frameworks/categorical_table/__init__.py @@ -1,4 +1,3 @@ from .categorical_table import CategoricalTableWidget __all__ = ["CategoricalTableWidget"] - diff --git a/src/voxkit/gui/frameworks/categorical_table/categorical_table.py b/src/voxkit/gui/frameworks/categorical_table/categorical_table.py index d649dab..6d2b337 100644 --- a/src/voxkit/gui/frameworks/categorical_table/categorical_table.py +++ b/src/voxkit/gui/frameworks/categorical_table/categorical_table.py @@ -1,23 +1,20 @@ +from PyQt6.QtCore import Qt from PyQt6.QtWidgets import ( - QCheckBox, + QDialog, + QFormLayout, QGroupBox, QHBoxLayout, QHeaderView, - QInputDialog, QLabel, QMessageBox, QPushButton, + QScrollArea, QTableWidget, QTableWidgetItem, QVBoxLayout, QWidget, - QDialog, - QFormLayout, - QScrollArea, ) -from PyQt6.QtCore import Qt - from voxkit.gui.components import HuggingFaceButton from voxkit.gui.frameworks._______.styles import Buttons, Colors, Labels @@ -38,14 +35,17 @@ def __init__( ): """ Initialize the CategoricalTableWidget. - + Args: refresh_data_function: Callable that returns dict of categorical data - export_function: Callable(category: str, items: list[dict]) -> (success: bool, message: str) + export_function: Callable(category: str, items: list[dict]) + -> (success: bool, message: str) import_function: Callable(category: str) -> (success: bool, message: str) - delete_function: Callable(category: str, items: list[dict]) -> (success: bool, message: str) + delete_function: Callable(category: str, items: list[dict]) + -> (success: bool, message: str) columns_shown: Optional list of column names to display - single_selection_flag: If True, only one item can be selected at a time (default: False) + single_selection_flag: If True, only one item can be selected at a time + (default: False) huggingface_callback: Optional callback for HuggingFace button click parent: Parent widget """ @@ -58,7 +58,7 @@ def __init__( self.single_selection_flag = single_selection_flag self.huggingface_callback = huggingface_callback self.current_category_index = 0 - + # Initialize data self.data = {} self.category_keys = [] @@ -74,18 +74,18 @@ def init_ui(self): # Title header with optional HuggingFace button header_layout = QHBoxLayout() - + title = QLabel("Model Management") title.setStyleSheet(Labels.TITLE) header_layout.addWidget(title) - + # Add HuggingFace button if callback provided if self.huggingface_callback: header_layout.addStretch() self.hf_button = HuggingFaceButton(title="Browse Models") self.hf_button.clicked.connect(self.huggingface_callback) header_layout.addWidget(self.hf_button) - + main_layout.addLayout(header_layout) main_layout.addSpacing(10) @@ -222,16 +222,15 @@ def init_ui(self): } """ ) - + # Configure table - header = self.table_widget.horizontalHeader() self.table_widget.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows) # Set selection mode based on single_selection_flag if self.single_selection_flag: self.table_widget.setSelectionMode(QTableWidget.SelectionMode.SingleSelection) else: self.table_widget.setSelectionMode(QTableWidget.SelectionMode.MultiSelection) - + table_container_layout.addWidget(self.table_widget) # Add empty state label @@ -323,7 +322,7 @@ def init_ui(self): """) self.delete_btn.clicked.connect(self.on_delete) action_layout.addWidget(self.delete_btn) - + group.setLayout(action_layout) main_layout.addWidget(group) @@ -357,7 +356,7 @@ def set_data(self, data, columns_shown=None): def update_display(self): """Update the display for the current category""" self.table_widget.clear() - + if not self.category_keys: self.category_label.setText("No Categories") self.prev_btn.setEnabled(False) @@ -379,7 +378,7 @@ def update_display(self): # Get category data category_data = self.data[current_category] - + if not category_data: self.table_widget.hide() self.empty_label.show() @@ -397,7 +396,7 @@ def update_display(self): if isinstance(item, dict): all_keys.update(item.keys()) self.columns_shown = sorted(list(all_keys)) - + # Set up table display_columns = self.columns_shown + ["Actions"] self.table_widget.setRowCount(len(category_data)) @@ -408,17 +407,21 @@ def update_display(self): for row_idx, item_data in enumerate(category_data): # Data columns for col_idx, column_name in enumerate(self.columns_shown): - value = item_data.get(column_name, "Unknown") if isinstance(item_data, dict) else "Unknown" + value = ( + item_data.get(column_name, "Unknown") + if isinstance(item_data, dict) + else "Unknown" + ) table_item = QTableWidgetItem(str(value)) table_item.setTextAlignment(Qt.AlignmentFlag.AlignCenter) self.table_widget.setItem(row_idx, col_idx, table_item) - + # View button in centered container button_container = QWidget() button_layout = QHBoxLayout(button_container) button_layout.setContentsMargins(0, 0, 0, 0) button_layout.setAlignment(Qt.AlignmentFlag.AlignCenter) - + view_btn = QPushButton("View") view_btn.setFixedSize(60, 24) view_btn.setStyleSheet(f""" @@ -439,7 +442,7 @@ def update_display(self): view_btn.clicked.connect(lambda checked, idx=row_idx: self.view_item_details(idx)) button_layout.addWidget(view_btn) self.table_widget.setCellWidget(row_idx, len(self.columns_shown), button_container) - + # Configure column widths for optimal stretching header = self.table_widget.horizontalHeader() # Make data columns resize to contents or stretch @@ -458,10 +461,10 @@ def view_item_details(self, row_index): """Show all details for an item""" if not self.category_keys: return - + current_category = self.category_keys[self.current_category_index] category_data = self.data[current_category] - + if 0 <= row_index < len(category_data): item_data = category_data[row_index] self.show_detail_dialog(item_data, row_index) @@ -472,9 +475,9 @@ def show_detail_dialog(self, item_data, row_index): dialog.setWindowTitle(f"Item Details - Row {row_index + 1}") dialog.setMinimumWidth(450) dialog.setMinimumHeight(350) - + layout = QVBoxLayout(dialog) - + # Title title = QLabel(f"All Fields for Row {row_index + 1}") title.setStyleSheet(f""" @@ -488,7 +491,7 @@ def show_detail_dialog(self, item_data, row_index): }} """) layout.addWidget(title) - + # Create scrollable area for fields scroll = QScrollArea() scroll.setWidgetResizable(True) @@ -499,20 +502,20 @@ def show_detail_dialog(self, item_data, row_index): background-color: white; }} """) - + # Container for form layout container = QWidget() form_layout = QFormLayout(container) form_layout.setSpacing(10) form_layout.setContentsMargins(10, 10, 10, 10) - + # Add all fields if isinstance(item_data, dict): sorted_keys = sorted(item_data.keys()) - + for key in sorted_keys: value = item_data[key] - + # Key label key_label = QLabel(f"{key}:") key_label.setStyleSheet(f""" @@ -522,7 +525,7 @@ def show_detail_dialog(self, item_data, row_index): min-width: 120px; }} """) - + # Value label value_label = QLabel(str(value)) value_label.setWordWrap(True) @@ -535,22 +538,22 @@ def show_detail_dialog(self, item_data, row_index): border: 1px solid {Colors.BORDER}; }} """) - + form_layout.addRow(key_label, value_label) - + scroll.setWidget(container) layout.addWidget(scroll) - + # Close button close_btn = QPushButton("Close") close_btn.setStyleSheet(Buttons.SECONDARY) close_btn.clicked.connect(dialog.close) - + button_layout = QHBoxLayout() button_layout.addStretch() button_layout.addWidget(close_btn) layout.addLayout(button_layout) - + dialog.exec() def prev_category(self): @@ -569,7 +572,7 @@ def select_all(self): """Select all items in the current table""" if self.single_selection_flag: return # Not applicable in single selection mode - + self.table_widget.selectAll() def deselect_all(self): @@ -581,19 +584,19 @@ def get_selected_items(self): selected = [] if not self.category_keys: return selected - + current_category = self.category_keys[self.current_category_index] category_data = self.data[current_category] - + # Get selected rows from table selected_rows = self.table_widget.selectionModel().selectedRows() for index in selected_rows: row = index.row() if row < len(category_data): selected.append(category_data[row]) - + return selected - + def set_items(self, mode, items): """Set the items for a specific category""" if mode not in self.data: @@ -615,7 +618,7 @@ def on_export(self): return current_category = self.category_keys[self.current_category_index] - + try: success, message = self.export_function(current_category, selected_items) if not message: @@ -641,7 +644,7 @@ def on_export(self): f"An error occurred during export: {str(e)}", QMessageBox.StandardButton.Ok, ) - + def on_delete(self): """Handle delete button click""" selected_items = self.get_selected_items() @@ -659,25 +662,26 @@ def on_delete(self): reply = QMessageBox.question( self, "Confirm Deletion", - f"Are you sure you want to delete {len(selected_items)} item(s)?\n\nThis action cannot be undone.", + f"Are you sure you want to delete {len(selected_items)} item(s)?\n\n" + "This action cannot be undone.", QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.No, ) if reply == QMessageBox.StandardButton.Yes: current_category = self.category_keys[self.current_category_index] - + try: success, message = self.delete_function(current_category, selected_items) - + if not message: return - + if success: # Refresh data after successful deletion self.refresh_data() self.update_display() - + QMessageBox.information( self, "Deletion Successful", @@ -711,17 +715,17 @@ def on_import(self): return current_category = self.category_keys[self.current_category_index] - + try: success, message = self.import_function(current_category) - + if not message: return if success: # Refresh data after successful import self.refresh_data() self.update_display() - + QMessageBox.information( self, "Import Successful", @@ -747,6 +751,7 @@ def on_import(self): # Example usage if __name__ == "__main__": import sys + from PyQt6.QtWidgets import QApplication app = QApplication(sys.argv) @@ -754,18 +759,58 @@ def on_import(self): # Sample data - now each category contains a list of dictionaries sample_data = { "MFA Models": [ - {"name": "english_us_arpa", "model_path": "/models/mfa/english_us_arpa", "download_date": "2025-01-15", "size": "125MB"}, - {"name": "german_mfa", "model_path": "/models/mfa/german_mfa", "download_date": "2025-02-20", "size": "98MB"}, - {"name": "french_prosodylab", "model_path": "/models/mfa/french_prosodylab", "download_date": "2025-03-10", "size": "110MB"}, + { + "name": "english_us_arpa", + "model_path": "/models/mfa/english_us_arpa", + "download_date": "2025-01-15", + "size": "125MB", + }, + { + "name": "german_mfa", + "model_path": "/models/mfa/german_mfa", + "download_date": "2025-02-20", + "size": "98MB", + }, + { + "name": "french_prosodylab", + "model_path": "/models/mfa/french_prosodylab", + "download_date": "2025-03-10", + "size": "110MB", + }, ], "W2TG Models": [ - {"name": "charsiu_en", "model_path": "/models/w2tg/charsiu_en", "download_date": "2025-04-05", "version": "1.0"}, - {"name": "custom_model_v1", "model_path": "/models/w2tg/custom_v1", "download_date": "2025-05-12", "version": "1.1"}, - {"name": "custom_model_v2", "model_path": "/models/w2tg/custom_v2", "download_date": "2025-06-18", "version": "2.0"}, + { + "name": "charsiu_en", + "model_path": "/models/w2tg/charsiu_en", + "download_date": "2025-04-05", + "version": "1.0", + }, + { + "name": "custom_model_v1", + "model_path": "/models/w2tg/custom_v1", + "download_date": "2025-05-12", + "version": "1.1", + }, + { + "name": "custom_model_v2", + "model_path": "/models/w2tg/custom_v2", + "download_date": "2025-06-18", + "version": "2.0", + }, ], "Dictionaries": [ - {"name": "english_us", "path": "/dicts/en_us.dict", "download_date": "2025-01-01", "entries": "50000"}, - {"name": "german", "path": "/dicts/de.dict", "download_date": "2025-02-01", "entries": "45000"}, + { + "name": "english_us", + "path": "/dicts/en_us.dict", + "download_date": "2025-01-01", + "entries": "50000", + }, + { + "name": "german", + "path": "/dicts/de.dict", + "download_date": "2025-02-01", + "entries": "45000", + }, ], } @@ -798,7 +843,7 @@ def delete_function(category, items): # Specify which columns to show (or leave None to auto-detect) columns_shown = ["name", "download_date"] - + # Create widget with single selection mode disabled (multi-select) widget = CategoricalTableWidget( refresh_data_function=refresh_data, @@ -813,4 +858,4 @@ def delete_function(category, items): widget.resize(800, 600) widget.show() - sys.exit(app.exec()) \ No newline at end of file + sys.exit(app.exec()) diff --git a/src/voxkit/gui/frameworks/settings_modal/api.py b/src/voxkit/gui/frameworks/settings_modal/api.py index 573a496..fda9148 100644 --- a/src/voxkit/gui/frameworks/settings_modal/api.py +++ b/src/voxkit/gui/frameworks/settings_modal/api.py @@ -22,6 +22,7 @@ class FieldType(Enum): LINEEDIT = "lineedit" COMBOBOX = "combobox" + @dataclass class FieldConfig: """ @@ -78,6 +79,7 @@ class FieldConfig: # Validation validator: Optional[Callable] = None + @dataclass class SettingsConfig: """Configuration container for creating a settings dialog. @@ -110,4 +112,4 @@ class SettingsConfig: dimensions: tuple[int, int] # (width, height) apply_blur: bool fields: list[FieldConfig] - store_file: str \ No newline at end of file + store_file: str diff --git a/src/voxkit/gui/frameworks/settings_modal/generic.py b/src/voxkit/gui/frameworks/settings_modal/generic.py index 1d8ef3d..e83e3f6 100644 --- a/src/voxkit/gui/frameworks/settings_modal/generic.py +++ b/src/voxkit/gui/frameworks/settings_modal/generic.py @@ -1,7 +1,7 @@ import json import os from pathlib import Path -from typing import Any +from typing import Any, Union from PyQt6.QtCore import QEasingCurve, QPropertyAnimation, Qt from PyQt6.QtWidgets import ( @@ -87,13 +87,13 @@ def __init__( raise ValueError("File path must be within the storage root directory.") self.field_configs = config.fields or [] - self.field_widgets = {} + self.field_widgets: dict[str, Any] = {} self._apply_blur = config.apply_blur # Setup overlay and blur if parent exists if parent and config.apply_blur: self._setup_overlay(parent) - + self._save_defaults() self._setup_ui(config.title, config.dimensions) self._create_fields() @@ -121,7 +121,7 @@ def _save_defaults(self): json.dump(defaults, f, indent=4) print(f"Default settings saved to {self.store_values_path}") - def _setup_overlay(self, parent): + def _setup_overlay(self, parent) -> None: """ Setup overlay widget and apply blur effect to parent window. @@ -133,8 +133,17 @@ def _setup_overlay(self, parent): parent: Parent widget to apply blur effect to. """ try: + if not hasattr(parent, "parent") or parent.parent is None: + return + + parent_func = parent.parent + if not callable(parent_func): + return + + main_window = parent_func() + if main_window is None: + return - main_window = parent.parent overlay = OverlayWidget(main_window) overlay.resize(main_window.size()) overlay.show() @@ -234,7 +243,7 @@ def _setup_ui(self, title: str, dims: tuple[int, int]): main_layout.addWidget(container) # Center dialog on parent - if self.parent: + if self.parent is not None and hasattr(self.parent, "parent"): try: main_window = self.parent.parent self.move( @@ -291,6 +300,7 @@ def _create_field_widget(self, config: FieldConfig) -> QWidget: Raises: ValueError: If field_type is not recognized. """ + widget: Union[QSpinBox, QDoubleSpinBox, QLineEdit, QComboBox] if config.field_type == FieldType.SPINBOX: widget = self._create_spinbox(config) elif config.field_type == FieldType.DOUBLE_SPINBOX: @@ -453,7 +463,8 @@ def get_values(self) -> dict[str, Any]: >>> print(values) {'batch_size': 32, 'learning_rate': 0.001, 'use_gpu': True} """ - values = {} + + values: dict[str, Union[int, float, str]] = {} for name, widget in self.field_widgets.items(): if isinstance(widget, QSpinBox) or isinstance(widget, QDoubleSpinBox): values[name] = widget.value() diff --git a/src/voxkit/gui/pages/datasets/datasets_page.py b/src/voxkit/gui/pages/datasets/datasets_page.py index 8565afb..7771bce 100644 --- a/src/voxkit/gui/pages/datasets/datasets_page.py +++ b/src/voxkit/gui/pages/datasets/datasets_page.py @@ -2,6 +2,7 @@ Datasets management page for registering, validating, and managing speech datasets. """ +from pathlib import Path from PyQt6.QtCore import Qt from PyQt6.QtWidgets import ( @@ -20,9 +21,8 @@ QWidget, ) -from pathlib import Path - from voxkit.analyzers import ManageAnalyzers +from voxkit.engines import engines from voxkit.gui.components import HuggingFaceButton from voxkit.gui.frameworks.settings_modal import ( FieldConfig, @@ -31,20 +31,21 @@ SettingsConfig, ) from voxkit.gui.workers import DatasetRegistrationWorker -from voxkit.storage import datasets, alignments +from voxkit.storage import alignments, datasets + from .styles import Colors -from voxkit.engines import engines ENGINE_IDS = engines.list_engines() + class DatasetsPage(QWidget): """Main datasets management page""" - def __init__(self, parent=None): + def __init__(self, parent: QWidget | None = None): super().__init__(parent) self.parent_window = parent self.registration_worker = None - self.selected_dataset: datasets.DatasetMetadata | None = None + self.selected_dataset: dict | None = None self.init_ui() self.refresh_datasets() @@ -56,18 +57,18 @@ def init_ui(self): # Title header with HuggingFace button header_layout = QHBoxLayout() - + title = QLabel("Dataset Management") title.setStyleSheet("font-size: 24px; font-weight: bold; color: #2c3e50;") header_layout.addWidget(title) - + header_layout.addStretch() - + # HuggingFace button in top right self.hf_button = HuggingFaceButton("Browse Datasets") self.hf_button.clicked.connect(self.on_huggingface_browse) header_layout.addWidget(self.hf_button) - + main_layout.addLayout(header_layout) # Load available analysis methods @@ -75,10 +76,10 @@ def init_ui(self): # Add sections main_layout.addWidget(self._create_list_section()) - + # Add alignments panel main_layout.addWidget(self._create_alignments_panel()) - + # Add register section at bottom main_layout.addWidget(self._create_register_section()) @@ -95,7 +96,7 @@ def refresh_page(self): """Refresh the entire page content""" self.dataset_table.clearSelection() self.refresh_datasets() - + def _create_register_section(self): """Create the dataset registration section""" group = QGroupBox() @@ -177,7 +178,7 @@ def _create_register_section(self): group.setLayout(layout) return group - + def on_huggingface_browse(self): """Handle HuggingFace button click""" # TODO: Implement HuggingFace dataset browsing/import @@ -185,9 +186,9 @@ def on_huggingface_browse(self): self, "HuggingFace Integration", "HuggingFace dataset browsing will be available soon!\n\n" - "This will allow you to browse and import datasets directly from HuggingFace Hub." + "This will allow you to browse and import datasets directly from HuggingFace Hub.", ) - + def on_import(self): """Handle import button click to open registration dialog""" dir_path = QFileDialog.getExistingDirectory( @@ -209,13 +210,13 @@ def on_import(self): else: QMessageBox.critical(self, "Import Failed", message) - + def on_export(self): - """Handle export button click for selected dataset""" + """Handle export button click for selected dataset""" if not self.selected_dataset: QMessageBox.warning(self, "No Dataset Selected", "Please select a dataset to export.") - return - + return + dir_path = QFileDialog.getExistingDirectory( self, "Select Directory to Save Exported Dataset", @@ -233,7 +234,6 @@ def on_export(self): else: QMessageBox.critical(self, "Export Failed", message) - def on_delete(self): """Handle delete button click for selected dataset""" if not self.selected_dataset: @@ -248,10 +248,10 @@ def on_delete(self): else: QMessageBox.critical(self, "Delete Failed", message) - + def _create_list_section(self): """Create the dataset list section""" - + group = QGroupBox("Datasets") group.setStyleSheet(""" QGroupBox { @@ -269,7 +269,7 @@ def _create_list_section(self): } """) layout = QVBoxLayout() - + # Add plus button at the top button_container = QWidget() button_container.setStyleSheet("background-color: transparent;") @@ -297,12 +297,12 @@ def _create_list_section(self): plus_btn.clicked.connect(self.open_registration_dialog) button_layout.addWidget(plus_btn) button_layout.addStretch() - + layout.addWidget(button_container) - + # Helper text helper_label = QLabel("💡 Select a dataset to view its alignments below") - + helper_label.setStyleSheet(""" QLabel { color: #3498db; @@ -325,14 +325,7 @@ def _create_list_section(self): self.dataset_table = QTableWidget() self.dataset_table.setColumnCount(6) self.dataset_table.setHorizontalHeaderLabels( - [ - "Name", - "Description", - "Cached", - "De-identified", - "Transcribed", - "Registration Date" - ] + ["Name", "Description", "Cached", "De-identified", "Transcribed", "Registration Date"] ) # Configure table @@ -347,7 +340,7 @@ def _create_list_section(self): self.dataset_table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows) self.dataset_table.setSelectionMode(QTableWidget.SelectionMode.SingleSelection) self.dataset_table.itemSelectionChanged.connect(self._on_dataset_selected) - + self.dataset_table.setAlternatingRowColors(True) self.dataset_table.setStyleSheet( """ @@ -381,8 +374,7 @@ def _create_list_section(self): # Add empty state label self.empty_label = QLabel( - "No datasets registered yet.\n" - "Use the form above to register your first dataset." + "No datasets registered yet.\nUse the form above to register your first dataset." ) self.empty_label.setAlignment(Qt.AlignmentFlag.AlignCenter) self.empty_label.setStyleSheet(""" @@ -400,7 +392,7 @@ def _create_list_section(self): group.setLayout(layout) return group - + def _create_alignments_panel(self): """Create the alignments panel for selected dataset""" group = QGroupBox("↓ Alignments for Selected Dataset") @@ -420,15 +412,15 @@ def _create_alignments_panel(self): color: #2c3e50; } """) - + layout = QVBoxLayout() - + # Engine filter filter_layout = QHBoxLayout() filter_label = QLabel("Filter by Engine:") filter_label.setStyleSheet("color: #2c3e50; font-weight: 500;") filter_layout.addWidget(filter_label) - + self.engine_filter_combo = QComboBox() self.engine_filter_combo.addItem("All Engines") @@ -448,14 +440,14 @@ def _create_alignments_panel(self): filter_layout.addWidget(self.engine_filter_combo) filter_layout.addStretch() layout.addLayout(filter_layout) - + # Alignments table self.alignments_table = QTableWidget() self.alignments_table.setColumnCount(5) self.alignments_table.setHorizontalHeaderLabels( ["Engine", "Model", "Date Aligned", "Status", "Actions"] ) - + # Configure alignments table align_header = self.alignments_table.horizontalHeader() align_header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) @@ -464,12 +456,12 @@ def _create_alignments_panel(self): align_header.setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) align_header.setSectionResizeMode(4, QHeaderView.ResizeMode.Fixed) self.alignments_table.setColumnWidth(4, 150) - + # Disable selection and editing self.alignments_table.setSelectionMode(QTableWidget.SelectionMode.NoSelection) self.alignments_table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers) self.alignments_table.setFocusPolicy(Qt.FocusPolicy.NoFocus) - + self.alignments_table.setAlternatingRowColors(True) self.alignments_table.setMaximumHeight(300) self.alignments_table.setStyleSheet(""" @@ -490,12 +482,12 @@ def _create_alignments_panel(self): border: none; } """) - + layout.addWidget(self.alignments_table) - + group.setLayout(layout) return group - + def _on_dataset_selected(self): """Handle dataset selection change""" selected_items = self.dataset_table.selectedItems() @@ -504,7 +496,7 @@ def _on_dataset_selected(self): self.selected_dataset = None self.alignments_table.setRowCount(0) return - + # Get dataset name from first column of selected row row = selected_items[0].row() @@ -512,20 +504,20 @@ def _on_dataset_selected(self): item = self.dataset_table.item(row, 0) # The item we stored the ID on print(item) - + if item: dataset_id = item.data(Qt.ItemDataRole.UserRole) print(f"Selected dataset ID: {dataset_id}") self.selected_dataset = dataset_id self._load_alignments(dataset_id) - + def _load_alignments(self, dataset_id: datasets.DatasetMetadata["id"]): """Load alignments for the selected dataset""" print(f"Loading alignments for dataset ID: {dataset_id}") alignments_metadata: list = alignments.list_alignments(dataset_id) - + # Populate engine filter (block signals to prevent recursion) current_filter = self.engine_filter_combo.currentText() self.engine_filter_combo.blockSignals(True) @@ -534,15 +526,15 @@ def _load_alignments(self, dataset_id: datasets.DatasetMetadata["id"]): self.engine_filter_combo.addItems(sorted(ENGINE_IDS)) self.engine_filter_combo.setCurrentText(current_filter) self.engine_filter_combo.blockSignals(False) - + self._display_alignments(alignments_metadata) - + def _filter_alignments(self): """Filter alignments by selected engine""" if not self.selected_dataset: return self._load_alignments(self.selected_dataset) - + def _display_alignments(self, alignments: list[alignments.AlignmentMetadata]): """Display alignments in the table""" # Filter by engine if selected @@ -552,19 +544,19 @@ def _display_alignments(self, alignments: list[alignments.AlignmentMetadata]): engine_filter = self.engine_filter_combo.currentText() if engine_filter != "All Engines": alignments = [a for a in alignments if a["engine_id"] == engine_filter] - + self.alignments_table.setRowCount(len(alignments)) - + if not alignments: return - + for row, alignment in enumerate(alignments): # Engine print(alignment) engine_item = QTableWidgetItem(alignment["engine_id"]) engine_item.setFlags(engine_item.flags() & ~Qt.ItemFlag.ItemIsEditable) self.alignments_table.setItem(row, 0, engine_item) - + # Model (clickable) model_item = QTableWidgetItem(alignment["model_metadata"]["name"]) model_item.setFlags(model_item.flags() & ~Qt.ItemFlag.ItemIsEditable) @@ -574,12 +566,12 @@ def _display_alignments(self, alignments: list[alignments.AlignmentMetadata]): font.setUnderline(True) model_item.setFont(font) self.alignments_table.setItem(row, 1, model_item) - + # Date Aligned date_item = QTableWidgetItem(alignment["alignment_date"]) date_item.setFlags(date_item.flags() & ~Qt.ItemFlag.ItemIsEditable) self.alignments_table.setItem(row, 2, date_item) - + # Status status_item = QTableWidgetItem(alignment.get("status", "Unknown")) status_item.setFlags(status_item.flags() & ~Qt.ItemFlag.ItemIsEditable) @@ -591,11 +583,11 @@ def _display_alignments(self, alignments: list[alignments.AlignmentMetadata]): elif status == "failed": status_item.setForeground(Qt.GlobalColor.red) self.alignments_table.setItem(row, 3, status_item) - + # Actions actions_widget = self._create_alignment_action_buttons(alignment) self.alignments_table.setCellWidget(row, 4, actions_widget) - + # Disconnect old connections and connect cell click for model column try: self.alignments_table.cellClicked.disconnect() @@ -603,7 +595,7 @@ def _display_alignments(self, alignments: list[alignments.AlignmentMetadata]): # No connections exist yet pass self.alignments_table.cellClicked.connect(self._on_alignment_cell_clicked) - + def _on_alignment_cell_clicked(self, row: int, col: int): """Handle click on alignment table cell""" # If model column (col 1) is clicked @@ -613,16 +605,16 @@ def _on_alignment_cell_clicked(self, row: int, col: int): QMessageBox.information( self, "Model Details", - f"Model clicked: {model_name}\n\nModel details will be shown here." + f"Model clicked: {model_name}\n\nModel details will be shown here.", ) - + def _create_alignment_action_buttons(self, alignment: dict): """Create action buttons for an alignment row""" widget = QWidget() layout = QHBoxLayout(widget) layout.setContentsMargins(5, 2, 5, 2) layout.setSpacing(5) - + button_style = f""" QPushButton {{ background-color: {Colors.SUCCESS}; @@ -664,14 +656,14 @@ def _create_alignment_action_buttons(self, alignment: dict): view_btn = QPushButton("View") view_btn.setStyleSheet(button_style) view_btn.clicked.connect(lambda: self._view_alignment(alignment)) - + return widget - + def convert_alignments(self, dataset_name: str) -> list: """Generate mock alignment data (to be replaced with actual data loading)""" data_set_alignments = alignments.list_alignments(dataset_name) return data_set_alignments - + def _view_alignment(self, alignment: dict): """View alignment details""" # TODO: Implement alignment details view @@ -681,42 +673,37 @@ def _view_alignment(self, alignment: dict): f"Engine: {alignment['engine']}\n" f"Model: {alignment['model']}\n" f"Date: {alignment['date_aligned']}\n" - f"Status: {alignment['status']}" + f"Status: {alignment['status']}", ) - + def _export_alignment(self, alignment: dict): """Export alignment results""" # TODO: Implement alignment export QMessageBox.information( self, "Export Alignment", - f"Export functionality for alignment '{alignment['model']}' will be implemented." + f"Export functionality for alignment '{alignment['model']}' will be implemented.", ) - + def _delete_alignment(self, alignment: dict): """Delete an alignment""" reply = QMessageBox.question( self, "Confirm Delete", - f"Are you sure you want to delete the alignment with model '{alignment['model_metadata']['id']}'?", + f"Are you sure you want to delete the alignment with model " + f"'{alignment['model_metadata']['id']}'?", QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, ) - + if reply == QMessageBox.StandardButton.Yes: - success, msg = alignments.delete_alignment(dataset_id=self.selected_dataset, alignment_id=alignment["id"]) + success, msg = alignments.delete_alignment( + dataset_id=self.selected_dataset, alignment_id=alignment["id"] + ) if not success: - QMessageBox.critical( - self, - "Delete Failed", - f"Failed to delete alignments: {msg}" - ) + QMessageBox.critical(self, "Delete Failed", f"Failed to delete alignments: {msg}") return - QMessageBox.information( - self, - "Deleted", - f"Alignments deleted successfully." - ) + QMessageBox.information(self, "Deleted", "Alignments deleted successfully.") # Refresh alignments list if self.selected_dataset: self._load_alignments(self.selected_dataset) @@ -785,11 +772,11 @@ def open_registration_dialog(self): ), ], ) - + # Create and show dialog - pass self as parent dialog = GenericDialog(self, config=config) result = dialog.exec() - + if result == QDialog.DialogCode.Accepted: # Get values from dialog values = dialog.get_values() @@ -805,26 +792,20 @@ def process_registration(self, values: dict): cache = values.get("cache", False) anonymize = values.get("anonymize", False) transcribed = values.get("transcribed", False) - + # Validation if not dataset_path: - QMessageBox.warning( - self, "Input Error", "Please provide a dataset path." - ) + QMessageBox.warning(self, "Input Error", "Please provide a dataset path.") return - + if not dataset_name: - QMessageBox.warning( - self, "Input Error", "Please enter a dataset name." - ) + QMessageBox.warning(self, "Input Error", "Please enter a dataset name.") return - + if not description: - QMessageBox.warning( - self, "Input Error", "Please enter a description." - ) + QMessageBox.warning(self, "Input Error", "Please enter a description.") return - + # Start registration in worker thread print( f"Starting dataset registration with params: {dataset_path}, {dataset_name}, " @@ -834,15 +815,16 @@ def process_registration(self, values: dict): self.registration_worker = DatasetRegistrationWorker( dataset_path, dataset_name, description, cache, anonymize, transcribed, analysis_method ) + if self.registration_worker is None: + return + self.registration_worker.progress.connect(self.show_progress) self.registration_worker.finished.connect(self.registration_complete) self.registration_worker.start() def browse_dataset_path(self): """Open directory picker for dataset path""" - directory = QFileDialog.getExistingDirectory( - self, "Select Dataset Root Directory" - ) + directory = QFileDialog.getExistingDirectory(self, "Select Dataset Root Directory") if directory: return directory return None @@ -881,7 +863,7 @@ def refresh_datasets(self): # Name — store the dataset ID on this item name_item = QTableWidgetItem(meta["name"]) name_item.setData(Qt.ItemDataRole.UserRole, meta["id"]) # ← Store ID here - self.dataset_table.setItem(index, 0, name_item) # ← Use the same item! + self.dataset_table.setItem(index, 0, name_item) # ← Use the same item! # Description desc_item = QTableWidgetItem(meta["description"]) @@ -935,14 +917,14 @@ def _create_action_buttons(self, dataset): background-color: #e0e0e0; } """ - + # Export button export_btn = QPushButton("Export") export_btn.setMaximumWidth(60) export_btn.setStyleSheet(button_style) export_btn.clicked.connect(lambda: self._export_dataset(dataset)) layout.addWidget(export_btn) - + # Delete button delete_btn = QPushButton("Delete") delete_btn.setMaximumWidth(60) @@ -983,7 +965,7 @@ def view_dataset(self, dataset): def transcribe_dataset(self, dataset): """Transcribe a dataset using available transcription engines""" # Check if already transcribed - if dataset.get('transcribed', False): + if dataset.get("transcribed", False): reply = QMessageBox.question( self, "Already Transcribed", @@ -993,7 +975,7 @@ def transcribe_dataset(self, dataset): ) if reply != QMessageBox.StandardButton.Yes: return - + # Placeholder for transcription functionality QMessageBox.information( self, @@ -1002,7 +984,7 @@ def transcribe_dataset(self, dataset): "This will use WhisperX or similar transcription engines to generate " "transcriptions for all audio files in the dataset.", ) - + # TODO: Implement actual transcription logic here # - Check for available transcription engines (WhisperX, etc.) # - Show transcription settings dialog @@ -1011,7 +993,7 @@ def transcribe_dataset(self, dataset): def _export_dataset(self, dataset): """Export dataset configuration""" - # Prompt for directory path to save exported dataset + # Prompt for directory path to save exported dataset dir_path = QFileDialog.getExistingDirectory( self, diff --git a/src/voxkit/gui/pages/datasets/utils.py b/src/voxkit/gui/pages/datasets/utils.py index 3128460..6ff3052 100644 --- a/src/voxkit/gui/pages/datasets/utils.py +++ b/src/voxkit/gui/pages/datasets/utils.py @@ -1,14 +1,14 @@ from PyQt6.QtWidgets import QFileDialog, QMessageBox -from voxkit.storage import datasets +from voxkit.storage import datasets def on_export(self): - """Handle export button click for selected dataset""" + """Handle export button click for selected dataset""" if not self.selected_dataset: QMessageBox.warning(self, "No Dataset Selected", "Please select a dataset to export.") - return - + return + dir_path = QFileDialog.getExistingDirectory( self, "Select Directory to Save Exported Dataset", @@ -26,6 +26,7 @@ def on_export(self): else: QMessageBox.critical(self, "Export Failed", message) + def on_delete(self): """Handle delete button click for selected dataset""" if not self.selected_dataset: @@ -39,4 +40,3 @@ def on_delete(self): else: QMessageBox.critical(self, "Delete Failed", message) - \ No newline at end of file diff --git a/src/voxkit/gui/pages/models/import_dialog.py b/src/voxkit/gui/pages/models/import_dialog.py index f03f243..db3b8a6 100644 --- a/src/voxkit/gui/pages/models/import_dialog.py +++ b/src/voxkit/gui/pages/models/import_dialog.py @@ -1,4 +1,5 @@ from typing import Callable, Optional + from PyQt6.QtWidgets import QMessageBox from voxkit.gui.frameworks.settings_modal import ( @@ -53,7 +54,7 @@ def __init__( dimensions=(450, 250), apply_blur=True, fields=fields, - store_file=f"{self.engine_id}/imported_models_dialog.json", + store_file=f"{self.engine_id}/imported_models_dialog.json", ) super().__init__( @@ -76,7 +77,7 @@ def accept(self): super().accept() def _placeholder_import(self, model_path: str): - parts = model_path.split('/') if model_path else [] + parts = model_path.split("/") if model_path else [] model_name = parts[-1] if parts else (model_path or "NONE") print(f"Creating destination for engine: {self.engine_id}, key: {model_name}") success, message = models.create_model( @@ -84,11 +85,7 @@ def _placeholder_import(self, model_path: str): model_name=model_name, ) if not success: - QMessageBox.critical( - self, - "Import Failed", - f"Failed to create model entry: {message}" - ) + QMessageBox.critical(self, "Import Failed", f"Failed to create model entry: {message}") else: dest_model_path = message["model_path"] result = download_and_copy_huggingface_model(model_path, destination=dest_model_path) @@ -97,7 +94,7 @@ def _placeholder_import(self, model_path: str): print("Failed to download model") else: print(f"Model imported successfully to: {result}") - + def main(): # Example usage: diff --git a/src/voxkit/gui/pages/models/models_page.py b/src/voxkit/gui/pages/models/models_page.py index 3dd2bf0..024f4e9 100644 --- a/src/voxkit/gui/pages/models/models_page.py +++ b/src/voxkit/gui/pages/models/models_page.py @@ -1,61 +1,63 @@ +from typing import Any + from PyQt6.QtCore import Qt from PyQt6.QtWidgets import QLabel, QMessageBox -from voxkit.gui.components import HuggingFaceButton +from voxkit.engines import engines from voxkit.gui.frameworks.categorical_table.categorical_table import CategoricalTableWidget from voxkit.storage import models from .import_dialog import ImportModelDialog from .utils import handle_delete, handle_export, handle_import -from voxkit.engines import engines + ENGINE_IDS = engines.list_engines() -# TODO : Implement Aligner managment logic by see frameworks/widget/categorical_list/api.py | __init__.py +# TODO : Implement Aligner managment logic by see +# frameworks/widget/categorical_list/api.py | __init__.py ENGINE_IDS = engines.list_engines() + class ManageAlignersWidget(CategoricalTableWidget): """Widget to manage and display alignment models.""" def __init__(self, parent=None): - self.parent = parent self.data = {} - def refresh_models_function() -> list[dict]: + def refresh_models_function() -> dict[str, list[dict[Any, Any]]]: try: - - model_dict = {} + model_dict = {} for engine in ENGINE_IDS: engine_models = models.list_models(engine) model_dict[engine] = engine_models return model_dict - + except Exception as e: print(f"Error refreshing models: {e}") return {} - - def export_models_function(category: str, items: dict) -> tuple[bool, str]: + + def export_models_function(category: str, items: list[dict[Any, Any]]) -> tuple[bool, str]: print(f"Export requested for category: {category}") return handle_export(self, items, category) - + def import_model_function(category: str) -> tuple[bool, str]: print(f"Import requested for category: {category}") return handle_import(self, category) - - def delete_models_function(category: str, items: dict) -> None: + + def delete_models_function(category: str, items: list[dict[Any, Any]]) -> tuple[bool, str]: print(f"Delete requested for category: {category}, items: {items}") return handle_delete(self, items, category) - + super().__init__( - refresh_data_function=refresh_models_function, - export_function=export_models_function, - import_function=import_model_function, - delete_function=delete_models_function, - columns_shown=["name","download_date", "id"], - huggingface_callback=self.on_huggingface_browse, - parent=self.parent - ) + refresh_data_function=refresh_models_function, + export_function=export_models_function, + import_function=import_model_function, + delete_function=delete_models_function, + columns_shown=["name", "download_date", "id"], + huggingface_callback=self.on_huggingface_browse, + parent=self.parent, + ) self.setWindowTitle("Model Manager") @@ -75,7 +77,7 @@ def on_huggingface_browse(self): self, "HuggingFace Integration", "HuggingFace model browsing will be available soon!\n\n" - "This will allow you to browse and import models directly from HuggingFace Hub." + "This will allow you to browse and import models directly from HuggingFace Hub.", ) def scrub_training_runs(self, mode, items: dict): @@ -86,7 +88,7 @@ def scrub_training_runs(self, mode, items: dict): mode = "W2TGENGINE" else: raise ValueError("Invalid mode") - + print(f"Scrubbing training run for model: {items}") models.delete_model(mode, items[model]["id"]) diff --git a/src/voxkit/gui/pages/models/utils.py b/src/voxkit/gui/pages/models/utils.py index d2cd4fd..9f426b2 100644 --- a/src/voxkit/gui/pages/models/utils.py +++ b/src/voxkit/gui/pages/models/utils.py @@ -3,13 +3,13 @@ """ import shutil -from pathlib import Path from datetime import datetime -import json +from pathlib import Path -from voxkit.storage import models from PyQt6.QtWidgets import QFileDialog, QMessageBox +from voxkit.storage import models + def handle_import(parent_widget, current_category: str): """ @@ -29,15 +29,13 @@ def handle_import(parent_widget, current_category: str): if not import_path: return False, "" - - return models.import_models(current_category, Path(import_path)) def handle_delete(parent_widget, selected_items: list, current_category: str) -> tuple[bool, str]: """ - Handle deleting selected models from storage. + Handle deleting selected models from storage. Args: parent_widget: Parent widget for dialogs @@ -51,7 +49,7 @@ def handle_delete(parent_widget, selected_items: list, current_category: str) -> success, msg = models.delete_model(current_category, item["id"]) if not success: return False, f"Failed to delete model {item['id']}: {msg}" - + return True, "Selected models deleted successfully." except Exception as e: return False, f"Error deleting models: {str(e)}" @@ -81,21 +79,19 @@ def handle_export( if not export_base_dir: return False, "" - - folder_name = f'voxkit_models_{datetime.now().strftime("%Y%m%d_%H%M%S")}' + + folder_name = f"voxkit_models_{datetime.now().strftime('%Y%m%d_%H%M%S')}" # Create the export folder export_folder = Path(export_base_dir) / folder_name export_folder.mkdir(parents=True, exist_ok=True) - # Track success/failure copied_count = 0 failed_items = [] # Copy each selected item for item in selected_items: - if not item: failed_items.append(f"{item} (not found in data)") continue @@ -104,7 +100,7 @@ def handle_export( if isinstance(item, dict): source_path = models._get_model_root(current_category, item["id"]) else: - failed_items.append(f"{item} (invalid item format)") + failed_items.append(f"{item} (invalid item format)") if not source_path.exists(): failed_items.append(f"{str(source_path)} (source path does not exist)") @@ -119,8 +115,6 @@ def handle_export( else: failed_items.append(f"{item} (unknown type)") - - # Build result message if copied_count == len(selected_items): message = f"Successfully exported {copied_count} item(s) to:\n{export_folder}" @@ -169,9 +163,7 @@ def _handler(folder_name: str, selected_items: list): current_category = widget.category_keys[widget.current_category_index] # Call the main export function - success, message = handle_export( - widget, folder_name, selected_items, data, current_category - ) + success, message = handle_export(widget, selected_items, current_category) # Show result to user if success: diff --git a/src/voxkit/gui/pages/pipeline/__init__.py b/src/voxkit/gui/pages/pipeline/__init__.py index 3400ab0..427dc2a 100644 --- a/src/voxkit/gui/pages/pipeline/__init__.py +++ b/src/voxkit/gui/pages/pipeline/__init__.py @@ -1,7 +1,6 @@ # TODO : Add module docstring -from PyQt6.QtCore import Qt -from PyQt6.QtWidgets import QHBoxLayout, QLabel, QListWidget, QSizePolicy, QVBoxLayout, QWidget +from PyQt6.QtWidgets import QHBoxLayout, QListWidget, QVBoxLayout, QWidget from voxkit.config import Dimensions from voxkit.gui.components import AnimatedStackedWidget @@ -60,7 +59,7 @@ def reload(self): if isinstance(training_page, TrainingStacker): training_page.reload_models() training_page.reload_datasets() - + predicting_page = self.stacked_widget.widget(1) if isinstance(predicting_page, PredictionStacker): predicting_page.reload_models() @@ -69,7 +68,7 @@ def reload(self): pllr_page = self.stacked_widget.widget(2) if isinstance(pllr_page, PLLRStacker): pllr_page.reload_datasets() - + def change_page(self, index): """Change the displayed page based on menu selection with animation""" if index >= 0: # Valid index diff --git a/src/voxkit/gui/pages/pipeline/evaluation_stacker.py b/src/voxkit/gui/pages/pipeline/evaluation_stacker.py index ee8d669..8c1f413 100644 --- a/src/voxkit/gui/pages/pipeline/evaluation_stacker.py +++ b/src/voxkit/gui/pages/pipeline/evaluation_stacker.py @@ -14,9 +14,9 @@ QVBoxLayout, QWidget, ) +from voxkit.storage.validation import validate_paths from voxkit.gui.workers.worker_thread import WorkerThread -from voxkit.storage.validation import validate_paths from .styles import BrowseButtonStyle @@ -179,9 +179,18 @@ def evaluate_logic(self, method, ref_path, pred_path, out_path): try: # Example evaluation logic (replace with real evaluation later) subprocess.run( - f'python scripts/evaluate_aligner.py --method "{method}" ' - f'--reference "{ref_path}" --predicted "{pred_path}" --output "{out_path}"', - shell=True, + [ + "python", + "scripts/evaluate_aligner.py", + "--method", + method, + "--reference", + ref_path, + "--predicted", + pred_path, + "--output", + out_path, + ], check=True, ) return f"Evaluation completed successfully using {method}" diff --git a/src/voxkit/gui/pages/pipeline/pllr_stacker.py b/src/voxkit/gui/pages/pipeline/pllr_stacker.py index 779ec9f..50986e0 100644 --- a/src/voxkit/gui/pages/pipeline/pllr_stacker.py +++ b/src/voxkit/gui/pages/pipeline/pllr_stacker.py @@ -3,7 +3,6 @@ from pypllrcomputer import compute_pllr from PyQt6.QtCore import Qt from PyQt6.QtWidgets import ( - QComboBox, QDialog, QFileDialog, QHBoxLayout, @@ -16,17 +15,17 @@ ) from voxkit.config import Defaults +from voxkit.gui.components import MultiColumnComboBox from voxkit.gui.frameworks.settings_modal import ( FieldConfig, FieldType, GenericDialog, SettingsConfig, ) +from voxkit.gui.utils import validate_path, validate_paths from voxkit.gui.workers.worker_thread import WorkerThread from voxkit.storage import alignments, datasets from voxkit.storage.utils import get_storage_root -from voxkit.gui.utils import validate_path, validate_paths -from voxkit.gui.components import MultiColumnComboBox from .styles import BrowseButtonStyle @@ -80,8 +79,10 @@ def get_pllr_settings_config() -> SettingsConfig: fields = FIELDS.copy() for field in fields: if field.name == "likelihood_dct" and not field.default_value: - field.default_value = str(get_storage_root() / "computed-likelihoods" / "likelihood_dict.pkl") - + field.default_value = str( + get_storage_root() / "computed-likelihoods" / "likelihood_dict.pkl" + ) + return SettingsConfig( title="PLLR Extraction Settings", dimensions=(400, 400), @@ -100,9 +101,8 @@ def __init__(self, parent=None): self.init_ui() def on_extract_settings(self): - settings_dialog = GenericDialog(self, get_pllr_settings_config()) - + result = settings_dialog.exec() # Clean up @@ -110,35 +110,53 @@ def on_extract_settings(self): if result == QDialog.DialogCode.Accepted: settings_dialog.save() - + def on_dataset_selected(self): """Handle dataset selection change and load corresponding alignments""" selected_index = self.pllr_dataset_dropdown.currentIndex() selected_dataset_id = self.pllr_dataset_dropdown.itemData(selected_index) - + # Clear alignment dropdown self.pllr_alignment_dropdown.clear() - + if selected_dataset_id and selected_dataset_id != "No datasets registered": # Load alignments for this dataset alignments_list = alignments.list_alignments(selected_dataset_id) - + if alignments_list: rows = [] for alignment in alignments_list: # Display format: "EngineID - ModelName (Date)" - rows.append({"id": alignment["id"], "data": (alignment["engine_id"], alignment["model_metadata"]["name"], alignment["alignment_date"], alignment["status"])}) + rows.append( + { + "id": alignment["id"], + "data": ( + alignment["engine_id"], + alignment["model_metadata"]["name"], + alignment["alignment_date"], + alignment["status"], + ), + } + ) + + self.pllr_alignment_dropdown.set_data( + rows, + ["Engine ID", "Model Name", "Date Registered", "Status"], + placeholder="Click to select an alignment", + ) - self.pllr_alignment_dropdown.set_data(rows, ["Engine ID", "Model Name", "Date Registered", "Status"], placeholder="Click to select an alignment") - self.pllr_alignment_dropdown.setEnabled(True) else: - self.pllr_alignment_dropdown.set_data([{"id": None, "data": ("No alignments registered", "", "")}], ["Method", "Model", "Date", "Status"], placeholder="No alignments registered") + self.pllr_alignment_dropdown.set_data( + [{"id": None, "data": ("No alignments registered", "", "")}], + ["Method", "Model", "Date", "Status"], + placeholder="No alignments registered", + ) self.pllr_alignment_dropdown.setEnabled(False) else: self.pllr_alignment_dropdown.set_data([], []) self.pllr_alignment_dropdown.setEnabled(False) - + def reload_datasets(self): """Reload datasets in the dropdown""" self.pllr_dataset_dropdown.clear() @@ -147,18 +165,21 @@ def reload_datasets(self): headers = ["Name", "Date", "Description"] if datasets_meta: for d in datasets_meta: - name = d["name"] date_registered = d["registration_date"] id = d["id"] description = d["description"] - data.append({"id": id, "data": (name,date_registered,description)}) + data.append({"id": id, "data": (name, date_registered, description)}) self.pllr_dataset_dropdown.set_data(data, headers, placeholder="Select a dataset") self.pllr_dataset_dropdown.setEnabled(True) else: self.pllr_dataset_dropdown.set_data([], [], placeholder="No datasets registered") self.pllr_dataset_dropdown.setEnabled(False) - self.pllr_alignment_dropdown.set_data([{"id": None, "data": ("Select a dataset first", "", "")}], ["Method", "Model", "Date", "Status"], placeholder="Select a dataset first") # Line 151 + self.pllr_alignment_dropdown.set_data( + [{"id": None, "data": ("Select a dataset first", "", "")}], + ["Method", "Model", "Date", "Status"], + placeholder="Select a dataset first", + ) # Line 151 self.pllr_alignment_dropdown.setEnabled(False) def init_ui(self): @@ -208,7 +229,7 @@ def init_ui(self): dataset_label = QLabel("① Choose a PLLR Dataset") dataset_label.setStyleSheet("font-weight: bold; color: #2c3e50;") layout.addWidget(dataset_label) - + self.pllr_dataset_dropdown = MultiColumnComboBox() self.pllr_dataset_dropdown.setStyleSheet(""" QComboBox { @@ -223,20 +244,19 @@ def init_ui(self): color: #999; } """) - # Connect to selection handler self.pllr_dataset_dropdown.currentIndexChanged.connect(self.on_dataset_selected) - + layout.addWidget(self.pllr_dataset_dropdown) - + # Alignment selection dropdown (initially disabled) alignment_label = QLabel("② Choose an Alignment") alignment_label.setStyleSheet("font-weight: bold; color: #2c3e50;") layout.addWidget(alignment_label) - + self.pllr_alignment_dropdown = MultiColumnComboBox() - self.pllr_alignment_dropdown.setStyleSheet( """ + self.pllr_alignment_dropdown.setStyleSheet(""" QComboBox { padding: 0px 8px; border: 1px solid #bdc3c7; @@ -251,9 +271,13 @@ def init_ui(self): """) - self.pllr_alignment_dropdown.set_data([{"id": None, "data": ("Select a dataset first", "", "")}], ["Method", "Model", "Date", "Status"], placeholder="Select a dataset first") + self.pllr_alignment_dropdown.set_data( + [{"id": None, "data": ("Select a dataset first", "", "")}], + ["Method", "Model", "Date", "Status"], + placeholder="Select a dataset first", + ) self.pllr_alignment_dropdown.setEnabled(False) - + layout.addWidget(self.pllr_alignment_dropdown) # Populate with registered datasets @@ -326,91 +350,89 @@ def browse_directory(self, line_edit): def on_extract_pllr(self): """Handle Extract PLLR button click""" print("\n=== PLLR EXTRACTION STARTED ===") - + # Get selected dataset selected_index = self.pllr_dataset_dropdown.currentIndex() selected_dataset_id = self.pllr_dataset_dropdown.itemData(selected_index) print(f"[DEBUG] Dataset index: {selected_index}, Dataset ID: {selected_dataset_id}") - + if not selected_dataset_id or selected_dataset_id == "No datasets registered": print("[ERROR] No dataset selected or invalid dataset ID") QMessageBox.warning( - self, "No Dataset Selected", - "Please select a dataset from the dropdown." + self, "No Dataset Selected", "Please select a dataset from the dropdown." ) return - + print(f"[DEBUG] Dataset validated: {selected_dataset_id}") - + # Get selected alignment alignment_index = self.pllr_alignment_dropdown.currentIndex() selected_alignment_id = self.pllr_alignment_dropdown.itemData(alignment_index) print(f"[DEBUG] Alignment index: {alignment_index}, Alignment ID: {selected_alignment_id}") - + if not selected_alignment_id: print("[ERROR] No alignment selected or invalid alignment ID") QMessageBox.warning( - self, "No Alignment Selected", - "Please select an alignment from the dropdown." + self, "No Alignment Selected", "Please select an alignment from the dropdown." ) return - + print(f"[DEBUG] Alignment validated: {selected_alignment_id}") - print(f"[DEBUG] Fetching alignment metadata...") - + print("[DEBUG] Fetching alignment metadata...") + # Get alignment path - alignment_data = alignments.get_alignment_metadata(selected_dataset_id, selected_alignment_id) + alignment_data = alignments.get_alignment_metadata( + selected_dataset_id, selected_alignment_id + ) print(f"[DEBUG] Alignment data retrieved: {alignment_data}") if not alignment_data: print(f"[ERROR] Could not find alignment data for ID: {selected_alignment_id}") QMessageBox.warning( - self, "Invalid Alignment", - f"Could not find alignment data for ID '{selected_alignment_id}'." + self, + "Invalid Alignment", + f"Could not find alignment data for ID '{selected_alignment_id}'.", ) return - + textgrid_path = alignment_data["tg_path"] + "/cache" print(f"[DEBUG] TextGrid path from alignment: {textgrid_path}") - print(f"[DEBUG] Checking if TextGrid path exists...") - + print("[DEBUG] Checking if TextGrid path exists...") + if not textgrid_path or not Path(textgrid_path).exists(): print(f"[ERROR] TextGrid path does not exist: {textgrid_path}") QMessageBox.warning( - self, "Invalid Alignment Path", - f"Alignment output path does not exist: {textgrid_path}" + self, + "Invalid Alignment Path", + f"Alignment output path does not exist: {textgrid_path}", ) return - + print(f"[DEBUG] TextGrid path validated: {textgrid_path}") - print(f"[DEBUG] Fetching dataset root path...") - + print("[DEBUG] Fetching dataset root path...") + # Get dataset path dataset_meta = datasets.get_dataset_metadata(selected_dataset_id) wavlab_path = None - if not (dataset_meta['cached'] == "True" or dataset_meta['cached'] is True): - wavlab_path = dataset_meta['original_path'] + if not (dataset_meta["cached"] == "True" or dataset_meta["cached"] is True): + wavlab_path = dataset_meta["original_path"] else: wavlab_path = datasets._get_dataset_root(selected_dataset_id) / "cache" - print(f"[DEBUG] Dataset root path: {wavlab_path}") - + if not wavlab_path: print(f"[ERROR] Could not find dataset path for ID: {selected_dataset_id}") - QMessageBox.warning( - self, "Invalid Dataset", - f"Could not find path for dataset." - ) + QMessageBox.warning(self, "Invalid Dataset", "Could not find path for dataset.") return - + print(f"[DEBUG] Dataset path validated: {wavlab_path}") - print(f"[DEBUG] Validating all paths...") - + print("[DEBUG] Validating all paths...") + # Validate inputs - print(f"[DEBUG] Preparing paths for validation...") + print("[DEBUG] Preparing paths for validation...") paths = { "TextGrid Path": textgrid_path, @@ -425,7 +447,7 @@ def on_extract_pllr(self): return print("[DEBUG] All paths validated successfully") - + # Get current settings output_path = self.extract_output_path.text() @@ -442,7 +464,7 @@ def on_extract_pllr(self): self.extract_btn.setEnabled(False) print("[DEBUG] Starting worker thread...") - + # Start worker thread self.worker = WorkerThread( lambda: self.extract_pllr_logic(textgrid_path, wavlab_path, output_path) @@ -455,19 +477,19 @@ def extract_pllr_logic(self, textgrid_path, wavlab_path, output_path): """Actual PLLR extraction logic""" print("\n=== EXTRACT PLLR LOGIC ===") - print(f"[LOGIC] Starting PLLR extraction...") + print("[LOGIC] Starting PLLR extraction...") print(f"[LOGIC] TextGrid Path: {textgrid_path}") print(f"[LOGIC] Wav/lab Path: {wavlab_path}") print(f"[LOGIC] Output Path: {output_path}") - + phonewise_path = str(Path(output_path) / "phonewise_proba.csv") framewise_path = str(Path(output_path) / "framewise_proba.csv") - + print(f"[LOGIC] Phonewise output file: {phonewise_path}") print(f"[LOGIC] Framewise output file: {framewise_path}") - + # Check what files exist - print(f"[LOGIC] Checking TextGrid directory contents...") + print("[LOGIC] Checking TextGrid directory contents...") tg_path_obj = Path(textgrid_path) if tg_path_obj.is_dir(): tg_files = list(tg_path_obj.glob("*.TextGrid")) @@ -476,8 +498,8 @@ def extract_pllr_logic(self, textgrid_path, wavlab_path, output_path): print(f"[LOGIC] First few TextGrid files: {[f.name for f in tg_files[:5]]}") else: print(f"[LOGIC] TextGrid path is not a directory: {textgrid_path}") - - print(f"[LOGIC] Checking wav/lab directory contents...") + + print("[LOGIC] Checking wav/lab directory contents...") wav_path_obj = Path(wavlab_path) if wav_path_obj.is_dir(): wav_files = list(wav_path_obj.glob("*.wav")) @@ -486,32 +508,33 @@ def extract_pllr_logic(self, textgrid_path, wavlab_path, output_path): print(f"[LOGIC] First few wav files: {[f.name for f in wav_files[:5]]}") else: print(f"[LOGIC] Wav/lab path is not a directory: {wavlab_path}") - + # READ THE SETTINGS FROM THE FILE - print(f"[LOGIC] Reading settings from file...") + print("[LOGIC] Reading settings from file...") path_to_pllr_settings = get_storage_root() / "pllr_settings.json" pllr_settings = {} if path_to_pllr_settings.exists(): print(f"[LOGIC] Settings file found at: {path_to_pllr_settings}") from json import load as json_load + with open(path_to_pllr_settings, "r") as f: pllr_settings = json_load(f) print(f"[LOGIC] Settings loaded: {pllr_settings}") else: print(f"[LOGIC] Settings file not found at: {path_to_pllr_settings}") - for key in PLLR_SETTINGS_CONFIG.fields: + config = get_pllr_settings_config() + for key in config.fields: pllr_settings[key.name] = key.default_value print(f"[LOGIC] Default settings loaded: {pllr_settings}") - print(f"[LOGIC] Calling compute_pllr()...") - print(f"[LOGIC] Parameters:") + print("[LOGIC] Calling compute_pllr()...") + print("[LOGIC] Parameters:") print(f" tg_files_path={textgrid_path}") print(f" wav_files_path={wavlab_path}") - print(f" phone_key='phones'") + print(" phone_key='phones'") print(f" phonewise_proba_df={phonewise_path}") print(f" framewise_proba_df={framewise_path}") - print(f" likelihood_dct=None") - - + print(" likelihood_dct=None") + try: compute_pllr( tg_files_path=textgrid_path, @@ -521,27 +544,30 @@ def extract_pllr_logic(self, textgrid_path, wavlab_path, output_path): framewise_proba_df=framewise_path, recompute_probas=pllr_settings.get("recompute_probas", True), likelihood_dct=pllr_settings.get("likelihood_dct", None), - # aggregation_function=pllr_settings.get("aggregation_function", "aggregate_by_phoneme_occurrence"), + # aggregation_function=pllr_settings.get( + # "aggregation_function", "aggregate_by_phoneme_occurrence" + # ), ) - print(f"[LOGIC] compute_pllr() completed successfully") + print("[LOGIC] compute_pllr() completed successfully") return "PLLR extracted successfully" except Exception as e: print(f"[ERROR] Exception in compute_pllr(): {type(e).__name__}") print(f"[ERROR] Exception message: {str(e)}") import traceback + print(f"[ERROR] Traceback:\n{traceback.format_exc()}") raise def on_extract_finished(self, success, message): """Handle completion of extract PLLR operation""" - print(f"\n=== PLLR EXTRACTION FINISHED ===") + print("\n=== PLLR EXTRACTION FINISHED ===") print(f"[FINISHED] Success: {success}") print(f"[FINISHED] Message: {message}") - + self.extract_btn.setEnabled(True) if success: - print(f"[FINISHED] Extraction completed successfully") + print("[FINISHED] Extraction completed successfully") self.extract_status.setText("✓ " + message) self.extract_status.setStyleSheet("color: #27ae60; font-size: 12px; margin-top: 5px;") QMessageBox.information(self, "Success", message) diff --git a/src/voxkit/gui/pages/pipeline/prediction_stacker.py b/src/voxkit/gui/pages/pipeline/prediction_stacker.py index 1e5187c..ff498e3 100644 --- a/src/voxkit/gui/pages/pipeline/prediction_stacker.py +++ b/src/voxkit/gui/pages/pipeline/prediction_stacker.py @@ -1,15 +1,10 @@ -from pathlib import Path - from PyQt6.QtCore import Qt from PyQt6.QtWidgets import ( QButtonGroup, - QComboBox, QDialog, - QFileDialog, QGroupBox, QHBoxLayout, QLabel, - QLineEdit, QMessageBox, QPushButton, QRadioButton, @@ -19,13 +14,11 @@ from voxkit.config import Defaults from voxkit.engines import engines +from voxkit.gui.components import MultiColumnComboBox from voxkit.gui.frameworks.settings_modal import GenericDialog from voxkit.gui.workers.worker_thread import WorkerThread from voxkit.storage import datasets, models -from voxkit.gui.utils import validate_path, validate_paths -from voxkit.gui.components import MultiColumnComboBox -from .styles import BrowseButtonStyle class PredictionStacker(QWidget): def __init__(self, parent): @@ -44,11 +37,11 @@ def on_mode_changed(self): if radio.isChecked(): self.selected_engine = engine_id break - + # Show/hide appropriate dropdowns for engine_id, dropdown in self.engine_dropdowns.items(): dropdown.setVisible(engine_id == self.selected_engine) - + print(f"Engine changed to: {self.selected_engine}") def reload_models(self): @@ -56,36 +49,52 @@ def reload_models(self): for engine_id, dropdown in self.engine_dropdowns.items(): dropdown.clear() model_list = models.list_models(engine_id) - + if model_list: # Handle different model list formats data = [] for m in model_list: if isinstance(m, dict): - data.append({"id": m["id"], "data": (m["name"], m["download_date"], m['id'])}) + data.append( + {"id": m["id"], "data": (m["name"], m["download_date"], m["id"])} + ) else: raise ValueError("Model list item is not a dict") - dropdown.set_data(data, ["Name", "Download Date", "ID"], placeholder="➁ Click to select a model") + dropdown.set_data( + data, ["Name", "Download Date", "ID"], placeholder="➁ Click to select a model" + ) dropdown.setEnabled(True) else: - dropdown.set_data([{"id": None, "data": ("No models registered", "", "")}], ["Name", "Download Date", "ID"], placeholder="No models registered") + dropdown.set_data( + [{"id": None, "data": ("No models registered", "", "")}], + ["Name", "Download Date", "ID"], + placeholder="No models registered", + ) dropdown.setEnabled(False) - + def reload_datasets(self): """Reload datasets in the dropdown""" self.predict_dataset_dropdown.clear() dataset_list = datasets.list_datasets_metadata() columns = ["Name", "Date", "Description"] - + if dataset_list: data = [] for d in dataset_list: - data.append({"id": d["id"], "data": (d["name"], d["registration_date"], d["description"])}) - self.predict_dataset_dropdown.set_data(data, columns, placeholder="Click to select a dataset") + data.append( + {"id": d["id"], "data": (d["name"], d["registration_date"], d["description"])} + ) + self.predict_dataset_dropdown.set_data( + data, columns, placeholder="Click to select a dataset" + ) self.predict_dataset_dropdown.setEnabled(True) else: - self.predict_dataset_dropdown.set_data([{"id": None, "data": ("No datasets registered", "", "")}], columns, placeholder="No datasets registered") + self.predict_dataset_dropdown.set_data( + [{"id": None, "data": ("No datasets registered", "", "")}], + columns, + placeholder="No datasets registered", + ) self.predict_dataset_dropdown.setEnabled(False) def init_ui(self): @@ -140,24 +149,24 @@ def init_ui(self): # Dynamically create engine options available_engines = engines.list_engines() - + for idx, engine_id in enumerate(available_engines): engine_obj = engines.get_engine(engine_id) engine_name = engine_obj.name() - engine_description = engine_obj.description - + engine_description = engine_obj.description + # Create engine layout engine_layout = QHBoxLayout() engine_layout.setSpacing(0) - # Set right side spacing + # Set right side spacing engine_layout.setContentsMargins(0, 0, 0, 0) # Radio button radio = QRadioButton(engine_name) radio.setChecked(idx == 0) # Check first one by default radio.toggled.connect(self.on_mode_changed) - + self.engine_radios[engine_id] = radio self.mode_button_group.addButton(radio) @@ -172,14 +181,14 @@ def init_ui(self): radio_widget.setFixedWidth(160) radio_widget.setStyleSheet("background-color: white;") engine_layout.addWidget(radio_widget) - + # Add spacing to align dropdown with description box engine_layout.addSpacing(25) # Model dropdown dropdown = MultiColumnComboBox() dropdown.setStyleSheet("color: #95a5a6;") - + # Populate models model_list = models.list_models(engine_id) if model_list: @@ -187,17 +196,25 @@ def init_ui(self): for m in model_list: print(m) if isinstance(m, dict): - data.append({"id": m["id"], "data": (m["name"], m["download_date"], m['id'])}) + data.append( + {"id": m["id"], "data": (m["name"], m["download_date"], m["id"])} + ) else: raise ValueError("Model list item is not a dict") - dropdown.set_data(data, ["Name", "Download Date", "ID"], placeholder="➁ Click to select a model") + dropdown.set_data( + data, ["Name", "Download Date", "ID"], placeholder="➁ Click to select a model" + ) dropdown.setEnabled(True) else: - dropdown.set_data([{"id": None, "data": ("No models registered", "", "")}], ["Name", "Download Date", "ID"], placeholder="No models registered") + dropdown.set_data( + [{"id": None, "data": ("No models registered", "", "")}], + ["Name", "Download Date", "ID"], + placeholder="No models registered", + ) dropdown.setEnabled(False) - + dropdown.setFixedWidth(300) - + self.engine_dropdowns[engine_id] = dropdown engine_layout.addWidget(dropdown) engine_layout.addStretch() @@ -218,14 +235,16 @@ def init_ui(self): """) desc_layout = QHBoxLayout(desc_container) desc_layout.setContentsMargins(8, 6, 8, 6) - + info = QLabel(engine_description) - info.setStyleSheet("color: #7f8c8d; font-size: 11px; background: transparent; border: none;") + info.setStyleSheet( + "color: #7f8c8d; font-size: 11px; background: transparent; border: none;" + ) info.setWordWrap(True) desc_layout.addWidget(info) - + model_layout.addWidget(desc_container) - + model_layout.addSpacing(5) # Set initial selected engine @@ -240,7 +259,7 @@ def init_ui(self): dataset_label = QLabel("③ Choose a Speech Dataset") dataset_label.setStyleSheet("font-weight: bold; color: #2c3e50;") layout.addWidget(dataset_label) - + self.predict_dataset_dropdown = MultiColumnComboBox() self.predict_dataset_dropdown.setStyleSheet(""" QComboBox { @@ -261,14 +280,22 @@ def init_ui(self): if dataset_list: data = [] for d in dataset_list: - data.append({"id": d["id"], "data": (d["name"], d["registration_date"], d["description"])}) - self.predict_dataset_dropdown.set_data(data, columns, placeholder="Click to select a dataset") + data.append( + {"id": d["id"], "data": (d["name"], d["registration_date"], d["description"])} + ) + self.predict_dataset_dropdown.set_data( + data, columns, placeholder="Click to select a dataset" + ) self.predict_dataset_dropdown.setEnabled(True) else: # Add dummy ID so itemData() is predictable - self.predict_dataset_dropdown.set_data([{"id": None, "data": ("No datasets registered", "", "")}], columns, placeholder="No datasets registered") + self.predict_dataset_dropdown.set_data( + [{"id": None, "data": ("No datasets registered", "", "")}], + columns, + placeholder="No datasets registered", + ) self.predict_dataset_dropdown.setEnabled(False) - + layout.addWidget(self.predict_dataset_dropdown) layout.addSpacing(10) @@ -311,15 +338,12 @@ def on_predict_settings(self): """Open settings dialog for selected engine""" engine = engines.get_engine(self.selected_engine) if engine: - settings_dialog = GenericDialog( - self, - config=engine.get_settings_config("align") - ) + settings_dialog = GenericDialog(self, config=engine.get_settings_config("align")) settings_dialog.exec() - + if settings_dialog.result() == QDialog.DialogCode.Accepted: settings_dialog.save() - + self.parent.setGraphicsEffect(None) def on_predict_alignments(self): @@ -328,11 +352,10 @@ def on_predict_alignments(self): if not selected_dataset_id: QMessageBox.warning( - self, "No Dataset Selected", - "Please select a dataset from the dropdown." + self, "No Dataset Selected", "Please select a dataset from the dropdown." ) return - + print("Predict Alignments clicked!") print(f"Engine: {self.selected_engine}") @@ -340,9 +363,7 @@ def on_predict_alignments(self): self.predict_status.setStyleSheet("color: #f39c12; font-size: 12px; margin-top: 5px;") self.predict_btn.setEnabled(False) - self.worker = WorkerThread( - lambda: self.predict_alignments_logic(selected_dataset_id) - ) + self.worker = WorkerThread(lambda: self.predict_alignments_logic(selected_dataset_id)) self.worker.finished.connect(self.on_predict_finished) self.worker.start() @@ -352,14 +373,14 @@ def predict_alignments_logic(self, dataset_id: str) -> str: selected_model_id = self.engine_dropdowns[self.selected_engine].current_id() print(f"Selected model ID: {selected_model_id}") - + # Get engine and call align method engine = engines.get_engine(self.selected_engine) engine.align( dataset_id=dataset_id, model_id=selected_model_id, ) - + return "Alignments predicted successfully" def on_predict_finished(self, success, message): diff --git a/src/voxkit/gui/pages/pipeline/training.py b/src/voxkit/gui/pages/pipeline/training.py index 6992980..c3a1bd1 100644 --- a/src/voxkit/gui/pages/pipeline/training.py +++ b/src/voxkit/gui/pages/pipeline/training.py @@ -18,6 +18,7 @@ QWidget, ) from voxkit.gui.pages.pipeline.model_eval import QComboBox +from voxkit.storage.validation import validate_path, validate_paths from voxkit.config import Defaults from voxkit.engines import ManageEngines @@ -25,12 +26,12 @@ from voxkit.gui.workers.worker_thread import WorkerThread from voxkit.storage.datasets import get_dataset_path, list_datasets from voxkit.storage.models import list_models -from voxkit.storage.validation import validate_path, validate_paths from .styles import BrowseButtonStyle TrainingTools = ManageEngines.get_tool_providers("train") + class TrainingPage(QWidget): def __init__(self, parent): super().__init__() @@ -73,16 +74,15 @@ def on_train_model(self): self, "No Dataset Selected", "Please select a dataset from the dropdown." ) return - + # Get dataset path audio_path = get_dataset_path(selected_dataset) if not audio_path: QMessageBox.warning( - self, "Invalid Dataset", - f"Could not find path for dataset '{selected_dataset}'." + self, "Invalid Dataset", f"Could not find path for dataset '{selected_dataset}'." ) return - + # Validate inputs paths = { "Training Audio Directory": audio_path, @@ -147,7 +147,6 @@ def train_model_logic(self, audio_path, textgrid_path, model_name, model): new_model_id=model_name, ) - return "Model training completed successfully" def on_train_finished(self, success, message): @@ -167,8 +166,7 @@ def on_training_settings(self): """Handle settings button click on training page""" self.settings_dialog = GenericDialog( - parent=self, - config=TrainingTools[self.selected_engine].get_settings_config("train") + parent=self, config=TrainingTools[self.selected_engine].get_settings_config("train") ) result = self.settings_dialog.exec() @@ -189,7 +187,7 @@ def reload_models(self): self.engine_panel_dropdowns[self.selected_engine].addItems( list(model_names) if model_names else [] ) - + def reload_datasets(self): """Reload datasets in the dropdown""" self.train_dataset_dropdown.clear() @@ -200,7 +198,7 @@ def reload_datasets(self): else: self.train_dataset_dropdown.addItem("No datasets registered") self.train_dataset_dropdown.setEnabled(False) - + def init_ui(self): self.setMinimumWidth(600) layout = QVBoxLayout(self) @@ -237,7 +235,7 @@ def init_ui(self): self.mode_group.setExclusive(True) # Maps - self.engine_panel_radios = {} + self.engine_panel_radios = {} self.engine_panel_dropdowns = {} # ------------------------------------------------------------------ @@ -301,7 +299,7 @@ def init_ui(self): dataset_label = QLabel("Training Dataset") dataset_label.setStyleSheet("font-weight: bold; color: #2c3e50;") layout.addWidget(dataset_label) - + self.train_dataset_dropdown = QComboBox() self.train_dataset_dropdown.setPlaceholderText("Select Dataset") self.train_dataset_dropdown.setStyleSheet(""" @@ -325,7 +323,7 @@ def init_ui(self): height: 12px; } """) - + # Populate with registered datasets datasets = list_datasets() if datasets: @@ -333,7 +331,7 @@ def init_ui(self): else: self.train_dataset_dropdown.addItem("No datasets registered") self.train_dataset_dropdown.setEnabled(False) - + layout.addWidget(self.train_dataset_dropdown) # Training Text Grid Directory diff --git a/src/voxkit/gui/pages/pipeline/training_stacker.py b/src/voxkit/gui/pages/pipeline/training_stacker.py index 466e18d..d793b1e 100644 --- a/src/voxkit/gui/pages/pipeline/training_stacker.py +++ b/src/voxkit/gui/pages/pipeline/training_stacker.py @@ -5,7 +5,6 @@ # Add these imports at the top of your file from PyQt6.QtWidgets import ( QButtonGroup, - QComboBox, QDialog, QFileDialog, QGroupBox, @@ -21,16 +20,15 @@ from voxkit.config import Defaults from voxkit.engines import engines +from voxkit.gui.components import MultiColumnComboBox from voxkit.gui.frameworks.settings_modal import GenericDialog -from voxkit.gui.workers.worker_thread import WorkerThread -from voxkit.storage import models, datasets, alignments from voxkit.gui.utils import validate_path, validate_paths -from voxkit.gui.components import MultiColumnComboBox - -from .styles import BrowseButtonStyle +from voxkit.gui.workers.worker_thread import WorkerThread +from voxkit.storage import alignments, datasets, models TrainingTools = engines.get_tool_providers("train") + class TrainingStacker(QWidget): def __init__(self, parent): super().__init__() @@ -50,29 +48,47 @@ def on_mode_changed(self): if radio.isChecked(): self.selected_engine = engine_id break - + # Show/hide appropriate dropdowns for engine_id, dropdown in self.engine_panel_dropdowns.items(): dropdown.setVisible(engine_id == self.selected_engine) - + print(f"Engine changed to: {self.selected_engine}") def on_dataset_selected(self): """Handle dataset selection change and load corresponding alignments""" selected_dataset_id = self.train_dataset_dropdown.current_id() - + # Load alignments for the selected dataset alignments_meta = alignments.list_alignments(selected_dataset_id) if alignments_meta: data = [] for a in alignments_meta: - data.append({"id": a["id"], "data": (a["engine_id"], a["model_metadata"]["name"], a["alignment_date"], a['status'])}) - self.train_alignment_dropdown.set_data(data, ["Method", "Model", "Date", "Status"], placeholder="Click to select an alignment") + data.append( + { + "id": a["id"], + "data": ( + a["engine_id"], + a["model_metadata"]["name"], + a["alignment_date"], + a["status"], + ), + } + ) + self.train_alignment_dropdown.set_data( + data, + ["Method", "Model", "Date", "Status"], + placeholder="Click to select an alignment", + ) self.train_alignment_dropdown.setEnabled(True) else: - self.train_alignment_dropdown.set_data([{"id": None, "data": ("No alignments registered", "", "")}], ["Method", "Model", "Date", "Status"], placeholder="No alignments registered") + self.train_alignment_dropdown.set_data( + [{"id": None, "data": ("No alignments registered", "", "")}], + ["Method", "Model", "Date", "Status"], + placeholder="No alignments registered", + ) self.train_alignment_dropdown.setEnabled(False) def browse_directory(self, line_edit): @@ -97,16 +113,15 @@ def on_train_model(self): self, "No Dataset Selected", "Please select a dataset from the dropdown." ) return - + # Get dataset path audio_path = datasets._get_dataset_root(selected_dataset_id) if not audio_path: QMessageBox.warning( - self, "Invalid Dataset", - f"Could not find path for dataset '{selected_dataset_id}'." + self, "Invalid Dataset", f"Could not find path for dataset '{selected_dataset_id}'." ) return - + # Get the alignments metadata align_index = self.train_alignment_dropdown.currentIndex() selected_alignment_id = self.train_alignment_dropdown.itemData(align_index) @@ -115,15 +130,18 @@ def on_train_model(self): self, "No Alignment Selected", "Please select an alignment from the dropdown." ) return - - alignment_meta = alignments.get_alignment_metadata(selected_dataset_id, selected_alignment_id) + + alignment_meta = alignments.get_alignment_metadata( + selected_dataset_id, selected_alignment_id + ) if not alignment_meta: QMessageBox.warning( - self, "Invalid Alignment", - f"Could not find metadata for alignment '{selected_alignment_id}'." + self, + "Invalid Alignment", + f"Could not find metadata for alignment '{selected_alignment_id}'.", ) return - + # Validate inputs paths = { "Training Audio Directory": audio_path, @@ -142,7 +160,7 @@ def on_train_model(self): if not model_name: QMessageBox.warning(self, "Invalid Model Name", "Please enter a valid model name.") return - + print(f"Checking if model name '{model_name}' is already taken in {mode}...") names_taken = models.list_models(mode) names_taken = [m["name"] for m in names_taken] @@ -184,7 +202,9 @@ def train_model_logic(self, audio_path, textgrid_path, model_name, model): ) selected_model_index = self.engine_panel_dropdowns[self.selected_engine].currentIndex() - base_model_id = self.engine_panel_dropdowns[self.selected_engine].itemData(selected_model_index) + base_model_id = self.engine_panel_dropdowns[self.selected_engine].itemData( + selected_model_index + ) TrainingTools[self.selected_engine].train_aligner( audio_root=Path(audio_path), @@ -193,7 +213,6 @@ def train_model_logic(self, audio_path, textgrid_path, model_name, model): new_model_id=model_name, ) - return "Model training completed successfully" def on_train_finished(self, success, message): @@ -213,8 +232,7 @@ def on_training_settings(self): """Handle settings button click on training page""" self.settings_dialog = GenericDialog( - parent=self, - config=TrainingTools[self.selected_engine].get_settings_config("train") + parent=self, config=TrainingTools[self.selected_engine].get_settings_config("train") ) result = self.settings_dialog.exec() @@ -234,30 +252,44 @@ def reload_models(self): if models_meta: data = [] for m in models_meta: - data.append({"id": m["id"], "data": (m["name"], m["download_date"], m['id'])}) + data.append( + {"id": m["id"], "data": (m["name"], m["download_date"], m["id"])} + ) combo.set_data(data, ["Name", "Download Date", "ID"]) else: - combo.set_data([{"id": None, "data": ("No models registered", "", "")}], ["Name", "Download Date", "ID"]) + combo.set_data( + [{"id": None, "data": ("No models registered", "", "")}], + ["Name", "Download Date", "ID"], + ) except Exception as e: print("Error reloading models:", e) - + def reload_datasets(self): """Reload datasets in the dropdown""" datasets_meta = datasets.list_datasets_metadata() if datasets_meta: data = [] for d in datasets_meta: - data.append({"id": d["id"], "data": (d["name"], d["description"], d['id'])}) - self.train_dataset_dropdown.set_data(data, ["Name", "Description", "ID"], placeholder="Click to select a dataset") + data.append({"id": d["id"], "data": (d["name"], d["description"], d["id"])}) + self.train_dataset_dropdown.set_data( + data, ["Name", "Description", "ID"], placeholder="Click to select a dataset" + ) self.train_dataset_dropdown.setEnabled(True) else: - self.train_dataset_dropdown.set_data([{"id": None, "data": ("No datasets registered", "", "")}], ["Name", "Description", "ID"], placeholder="No datasets registered") + self.train_dataset_dropdown.set_data( + [{"id": None, "data": ("No datasets registered", "", "")}], + ["Name", "Description", "ID"], + placeholder="No datasets registered", + ) self.train_dataset_dropdown.setEnabled(False) - self.train_alignment_dropdown.set_data([{"id": None, "data": ("Select a dataset first", "", "")}], ["Method", "Model", "Date", "Status"], placeholder="Select a dataset first") + self.train_alignment_dropdown.set_data( + [{"id": None, "data": ("Select a dataset first", "", "")}], + ["Method", "Model", "Date", "Status"], + placeholder="Select a dataset first", + ) self.train_alignment_dropdown.setEnabled(False) - - + def init_ui(self): self.setMinimumWidth(600) layout = QVBoxLayout(self) @@ -309,7 +341,7 @@ def init_ui(self): self.mode_group.setExclusive(True) # Maps - self.engine_panel_radios = {} + self.engine_panel_radios = {} self.engine_panel_dropdowns = {} # ------------------------------------------------------------------ @@ -334,16 +366,21 @@ def init_ui(self): combo.setFixedWidth(300) models_meta = models.list_models(engine_id) - + if models_meta: data = [] for m in models_meta: - data.append({"id": m["id"], "data": (m["name"], m["download_date"], m['id'])}) - combo.set_data(data, ["Name", "Download Date", "ID"], placeholder="➁ Click to select a model") + data.append({"id": m["id"], "data": (m["name"], m["download_date"], m["id"])}) + combo.set_data( + data, ["Name", "Download Date", "ID"], placeholder="➁ Click to select a model" + ) else: - combo.set_data([{"id": None, "data": ("No models registered", "", "")}], ["Name", "Download Date", "ID"], placeholder="No models registered") - - + combo.set_data( + [{"id": None, "data": ("No models registered", "", "")}], + ["Name", "Download Date", "ID"], + placeholder="No models registered", + ) + self.engine_panel_dropdowns[engine_id] = combo # ---------- layout ---------- @@ -362,7 +399,7 @@ def init_ui(self): rw.setFixedWidth(160) rw.setStyleSheet("background-color: white;") hbox.addWidget(rw) - + # Add spacing to align dropdown with description box hbox.addSpacing(25) @@ -384,12 +421,14 @@ def init_ui(self): """) desc_layout = QHBoxLayout(desc_container) desc_layout.setContentsMargins(8, 6, 8, 6) - + desc = QLabel(engine.description or "No description") - desc.setStyleSheet("color: #7f8c8d; font-size: 11px; background: transparent; border: none;") + desc.setStyleSheet( + "color: #7f8c8d; font-size: 11px; background: transparent; border: none;" + ) desc.setWordWrap(True) desc_layout.addWidget(desc) - + engine_vbox.addWidget(desc_container) engine_vbox.addSpacing(5) @@ -401,7 +440,7 @@ def init_ui(self): dataset_label = QLabel("③ Choose a Training Dataset") dataset_label.setStyleSheet("font-weight: bold; color: #2c3e50;") layout.addWidget(dataset_label) - + self.train_dataset_dropdown = MultiColumnComboBox() self.train_dataset_dropdown.setStyleSheet(""" QComboBox { @@ -416,29 +455,37 @@ def init_ui(self): color: #999; } """) - + # Populate with registered datasets datasets_meta = datasets.list_datasets_metadata() if datasets_meta: data = [] for d in datasets_meta: - data.append({"id": d["id"], "data": (d["name"], d["registration_date"], d["description"])}) - self.train_dataset_dropdown.set_data(data, ["Name", "Date", "Description"], placeholder="Click to select a dataset") + data.append( + {"id": d["id"], "data": (d["name"], d["registration_date"], d["description"])} + ) + self.train_dataset_dropdown.set_data( + data, ["Name", "Date", "Description"], placeholder="Click to select a dataset" + ) self.train_dataset_dropdown.setEnabled(True) else: - self.train_dataset_dropdown.set_data([{"id": None, "data": ("No datasets registered", "", "")}], ["Name", "Date", "Description"], placeholder="No datasets registered") + self.train_dataset_dropdown.set_data( + [{"id": None, "data": ("No datasets registered", "", "")}], + ["Name", "Date", "Description"], + placeholder="No datasets registered", + ) self.train_dataset_dropdown.setEnabled(False) - + # Connect to selection handler self.train_dataset_dropdown.currentIndexChanged.connect(self.on_dataset_selected) - + layout.addWidget(self.train_dataset_dropdown) - + # Alignment selection dropdown (initially disabled) alignment_label = QLabel("④ Choose an Alignment") alignment_label.setStyleSheet("font-weight: bold; color: #2c3e50;") layout.addWidget(alignment_label) - + self.train_alignment_dropdown = MultiColumnComboBox() self.train_alignment_dropdown.setStyleSheet(""" QComboBox { @@ -454,8 +501,11 @@ def init_ui(self): } """) - - self.train_alignment_dropdown.set_data([{"id": None, "data": ("Select a dataset first", "", "")}], ["Method", "Model", "Date", "Status"], placeholder="Select a dataset first") + self.train_alignment_dropdown.set_data( + [{"id": None, "data": ("Select a dataset first", "", "")}], + ["Method", "Model", "Date", "Status"], + placeholder="Select a dataset first", + ) self.train_alignment_dropdown.setEnabled(False) layout.addWidget(self.train_alignment_dropdown) diff --git a/src/voxkit/gui/workers/__init__.py b/src/voxkit/gui/workers/__init__.py index 8418bde..afb6146 100644 --- a/src/voxkit/gui/workers/__init__.py +++ b/src/voxkit/gui/workers/__init__.py @@ -3,4 +3,4 @@ from .datasets_thread import DatasetRegistrationWorker from .worker_thread import WorkerThread -__all__ = ["WorkerThread", "DatasetRegistrationWorker"] \ No newline at end of file +__all__ = ["WorkerThread", "DatasetRegistrationWorker"] diff --git a/src/voxkit/gui/workers/datasets_thread.py b/src/voxkit/gui/workers/datasets_thread.py index 06be5a9..908d464 100644 --- a/src/voxkit/gui/workers/datasets_thread.py +++ b/src/voxkit/gui/workers/datasets_thread.py @@ -33,15 +33,18 @@ def __init__( def run(self): self.progress.emit("Validating dataset structure...") - + # First validate the dataset success = datasets.validate_dataset(self.dataset_path) - + if not success: - self.finished.emit(False, " Dataset validation failed. " \ - "Please ensure the dataset follows the Kaldi organization pattern.") + self.finished.emit( + False, + " Dataset validation failed. " + "Please ensure the dataset follows the Kaldi organization pattern.", + ) return - + success, message = datasets.create_dataset( name=self.dataset_name, description=self.description, @@ -55,34 +58,31 @@ def run(self): if not success: self.finished.emit(False, message) - return + return else: self.progress.emit("Dataset metadata created successfully.") - + # Determine output path for CSV file dataset_dir = os.path.join(datasets._get_datasets_root(), message["id"]) csv_path = os.path.join(dataset_dir, f"{self.analysis_method.lower()}_summary.csv") - - csv_data = ManageAnalyzers.get_analyzers()[self.analysis_method].analyze( - self.dataset_path - ) + + csv_data = ManageAnalyzers.get_analyzers()[self.analysis_method].analyze(self.dataset_path) csv_success, csv_message = self._save_csv(csv_data, csv_path) - + if not csv_success: self.finished.emit(True, f"Warning: {csv_message}") else: self.finished.emit(True, csv_message) - def _save_csv(self, data: list[dict], path: str) -> tuple[bool, str]: """ Save the analysis data to a CSV file. Args: data: List of dictionaries where each dictionary represents a row. - path: Output path for the CSV file. + path: Output path for the CSV file. Returns: Tuple of (success, message) where success is True if the file was saved successfully. """ @@ -103,4 +103,4 @@ def _save_csv(self, data: list[dict], path: str) -> tuple[bool, str]: writer.writerow(row) return True, f"CSV saved successfully to {path}." except Exception as e: - return False, f"Failed to save CSV: {e}" \ No newline at end of file + return False, f"Failed to save CSV: {e}" diff --git a/src/voxkit/services/mfa.py b/src/voxkit/services/mfa.py index ece4a32..4cc8a52 100755 --- a/src/voxkit/services/mfa.py +++ b/src/voxkit/services/mfa.py @@ -1,5 +1,5 @@ """ -Until compatible exports are avalible through the MFA package this will serve as the alterative +Until compatible exports are avalible through the MFA package this will serve as the alterative entrypoint for MFA logic bootstrapping the cli """ @@ -69,5 +69,4 @@ def run_mfa_evaluate( """ Run the Montreal Forced Aligner 'evaluate' subcommand as a subprocess. """ - pass - + raise NotImplementedError("MFA evaluate is not yet implemented") diff --git a/src/voxkit/storage/__init__.py b/src/voxkit/storage/__init__.py index f4431b8..2bcb72f 100644 --- a/src/voxkit/storage/__init__.py +++ b/src/voxkit/storage/__init__.py @@ -2,8 +2,9 @@ VoxKit Storage Module ----------- - This package contains modules for managing persistent storage of datasets, models, and alignments within the VoxKit framework. - + This package contains modules for managing persistent storage of + datasets, models, and alignments within the VoxKit framework. + Imports ------- - datasets: CRUD operations for managing datasets. @@ -20,11 +21,16 @@ __email__ = "beckett.frey@gmail.com" __version__ = "0.0.1" + +# Import utils but don't call get_storage_root() at module import time +from . import alignments, datasets, models, utils + + def _ensure_storage_root(): """Ensure storage root directory exists. Called lazily when needed.""" try: - from . import utils from pathlib import Path + storage_root = Path(utils.get_storage_root()) if not storage_root.exists(): storage_root.mkdir(parents=True, exist_ok=True) @@ -33,15 +39,10 @@ def _ensure_storage_root(): print(f"Error initializing storage root: {e}") raise e -# Import utils but don't call get_storage_root() at module import time -from . import utils - -from . import alignments -from . import datasets -from . import models __all__ = [ "alignments", "datasets", "models", + "utils", ] diff --git a/src/voxkit/storage/alignments.py b/src/voxkit/storage/alignments.py index 7a719b6..43b15b4 100644 --- a/src/voxkit/storage/alignments.py +++ b/src/voxkit/storage/alignments.py @@ -33,16 +33,16 @@ - Error handling only exposes messages. - Raises FileNotFoundError if alignment or metadata not found. """ + import json import os import shutil from pathlib import Path -from typing import List, Tuple, TypedDict +from typing import List, Literal, Tuple, TypedDict from .config import ALIGNMENTS_ROOT -from .datasets import DatasetMetadata, get_dataset_metadata +from .datasets import _get_dataset_root, get_dataset_metadata from .models import ModelMetadata, get_model_metadata -from .datasets import DatasetMetadata, _get_dataset_root from .utils import generate_unique_id, readable_from_unique_id @@ -56,12 +56,12 @@ class AlignmentMetadata(TypedDict): tg_path: str -def _get_alignments_root(dataset_id: DatasetMetadata["id"]) -> Path | None: +def _get_alignments_root(dataset_id: str) -> Path | None: """Get the root directory for storing alignments for a given dataset. - + Args: dataset_id: Identifier of the dataset - + Returns: Path to the alignments root directory or None if dataset not found """ @@ -70,17 +70,17 @@ def _get_alignments_root(dataset_id: DatasetMetadata["id"]) -> Path | None: alignments_root = dataset_root / ALIGNMENTS_ROOT alignments_root.mkdir(parents=False, exist_ok=True) return alignments_root - + return None -def _get_alignment_root(dataset_id: DatasetMetadata["id"], alignment_id: AlignmentMetadata["id"]) -> Path | None: +def _get_alignment_root(dataset_id: str, alignment_id: str) -> Path | None: """Get the root directory for a specific alignment by ID. - + Args: dataset_id: Identifier of the dataset containing the alignment alignment_id: Identifier of the alignment - + Returns: Path to the alignment root directory or None if not found. """ @@ -93,8 +93,8 @@ def _get_alignment_root(dataset_id: DatasetMetadata["id"], alignment_id: Alignme def create_alignment( - dataset_id: DatasetMetadata["id"], engine_id: str, model_id: ModelMetadata["id"] -) -> tuple[True, AlignmentMetadata] | tuple[False, str]: + dataset_id: str, engine_id: str, model_id: str +) -> tuple[Literal[True], AlignmentMetadata] | tuple[Literal[False], str]: """Create a new alignment entry in the storage. Args: @@ -109,17 +109,17 @@ def create_alignment( model_metadata = get_model_metadata(engine_id, model_id) if not model_metadata: return False, f"Model '{model_id}' for engine '{engine_id}' not found" - + # Fetch dataset metadata dataset_metadata = get_dataset_metadata(dataset_id) if not dataset_metadata: return False, f"Dataset '{dataset_id}' not found" - + # Fetch alignment root alignments_root = _get_alignments_root(dataset_id) if not alignments_root: return False, f"Dataset '{dataset_id}' not found" - + # Create alignment directory now = generate_unique_id() alignment_date = readable_from_unique_id(now) @@ -145,17 +145,17 @@ def create_alignment( local=local, tg_path=str(tg_path), alignment_date=alignment_date, - status="Pending" + status="Pending", ) # Fetch model metadata metadata_path = alignment_root / "voxkit_alignment.json" - with open(metadata_path, 'w') as f: + with open(metadata_path, "w") as f: json.dump(metadata, f, indent=4) - + return True, metadata - + except Exception as e: # Clean up partially created directory if os.path.exists(alignment_root): @@ -163,16 +163,16 @@ def create_alignment( return False, f"Failed to create alignment metadata: {str(e)}" -def get_alignment_metadata(dataset_id: DatasetMetadata["id"], alignment_id: AlignmentMetadata["id"]) -> AlignmentMetadata: +def get_alignment_metadata(dataset_id: str, alignment_id: str) -> AlignmentMetadata | None: """Get the metadata for a specific alignment by ID.""" alignment_root = _get_alignment_root(dataset_id, alignment_id) if not alignment_root: return None - + metadata_path = alignment_root / "voxkit_alignment.json" try: - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: metadata = json.load(f) return metadata except Exception as e: @@ -180,44 +180,44 @@ def get_alignment_metadata(dataset_id: DatasetMetadata["id"], alignment_id: Alig raise e -def update_alignment(dataset_id: DatasetMetadata["id"], alignment_id: AlignmentMetadata["id"], updates: dict) -> Tuple[bool, str]: +def update_alignment(dataset_id: str, alignment_id: str, updates: dict) -> Tuple[bool, str]: """Update the status of an alignment. - + Args: dataset_id: Identifier of the dataset containing the alignment alignment_id: Identifier of the alignment to update updates: Dictionary of updates to apply to the alignment metadata - + Returns: Tuple of (success, message) """ alignment_root = _get_alignment_root(dataset_id, alignment_id) if not alignment_root: return False, f"Alignment '{alignment_id}' for dataset '{dataset_id}' not found" - + metadata_path = alignment_root / "voxkit_alignment.json" try: - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: metadata = json.load(f) - + # Update fields for key, value in updates.items(): if key in metadata: metadata[key] = value - - with open(metadata_path, 'w') as f: + + with open(metadata_path, "w") as f: json.dump(metadata, f, indent=4) - + return True, "Alignment metadata updated successfully." - + except Exception as e: return False, f"Failed to update alignment metadata: {str(e)}" -def list_alignments(dataset_id: DatasetMetadata["id"]) -> List[AlignmentMetadata]: +def list_alignments(dataset_id: str) -> List[AlignmentMetadata]: """List all alignment metadata for a given dataset. - + Args: dataset_id: Identifier of the dataset to list alignments @@ -227,7 +227,7 @@ def list_alignments(dataset_id: DatasetMetadata["id"]) -> List[AlignmentMetadata alignments_root = _get_alignments_root(dataset_id) if not alignments_root: return [] - + alignments_found = [] for dir in alignments_root.iterdir(): if dir.is_dir(): @@ -239,11 +239,11 @@ def list_alignments(dataset_id: DatasetMetadata["id"]) -> List[AlignmentMetadata alignments_found.append(metadata) except Exception as e: print(f"Failed to load alignment metadata from '{metadata_path}': {str(e)}") - + return alignments_found - -def delete_alignment(dataset_id: DatasetMetadata["id"], alignment_id: AlignmentMetadata["id"]) -> Tuple[bool, str]: + +def delete_alignment(dataset_id: str, alignment_id: str) -> Tuple[bool, str]: """ Delete an alignment given its dataset ID and alignment ID. @@ -257,10 +257,9 @@ def delete_alignment(dataset_id: DatasetMetadata["id"], alignment_id: AlignmentM alignment_root = _get_alignment_root(dataset_id, alignment_id) if not alignment_root: return False, f"Alignment '{alignment_id}' for dataset '{dataset_id}' not found" - + try: shutil.rmtree(alignment_root) return True, f"Alignment '{alignment_id}' deleted successfully." except Exception as e: return False, f"Failed to delete alignment '{alignment_id}': {str(e)}" - \ No newline at end of file diff --git a/src/voxkit/storage/config.py b/src/voxkit/storage/config.py index a0f53b0..ab313d6 100644 --- a/src/voxkit/storage/config.py +++ b/src/voxkit/storage/config.py @@ -3,11 +3,7 @@ --------------------- """ -STORAGE_ROOT = "~/.voxkit" # Root directory for all storage -MODELS_ROOT = "train" # Path from STORAGE_ROOT to models -DATASETS_ROOT = "datasets" # Path from STORAGE_ROOT to datasets -ALIGNMENTS_ROOT = "alignments" # Path from STORAGE_ROOT/DATASETS_ROOT to alignments - - - - +STORAGE_ROOT = "~/.voxkit" # Root directory for all storage +MODELS_ROOT = "train" # Path from STORAGE_ROOT to models +DATASETS_ROOT = "datasets" # Path from STORAGE_ROOT to datasets +ALIGNMENTS_ROOT = "alignments" # Path from STORAGE_ROOT/DATASETS_ROOT to alignments diff --git a/src/voxkit/storage/datasets.py b/src/voxkit/storage/datasets.py index 72abe50..7ef722a 100644 --- a/src/voxkit/storage/datasets.py +++ b/src/voxkit/storage/datasets.py @@ -7,7 +7,7 @@ Directory Structure (Many per Environment) ------------------------------- Each dataset follows a hierarchical structure: - + my_dataset/ ├── voxkit_dataset.json # Dataset metadata ├── alignments/ # Alignment outputs storage @@ -38,14 +38,15 @@ - transcribed flag indicates presence of transcriptions - Importing datasets adjusts metadata and validates structure """ + import json import os import shutil from pathlib import Path -from typing import List, Tuple, TypedDict +from typing import Any, List, Literal, Tuple, TypedDict -from .config import ALIGNMENTS_ROOT, DATASETS_ROOT -from .utils import generate_unique_id, get_storage_root, readable_from_unique_id +from voxkit.storage.config import ALIGNMENTS_ROOT, DATASETS_ROOT +from voxkit.storage.utils import generate_unique_id, get_storage_root, readable_from_unique_id class DatasetMetadata(TypedDict): @@ -61,7 +62,7 @@ class DatasetMetadata(TypedDict): def _get_datasets_root() -> Path: """Get the root directory for storage relative to voxkit storage root. - + Returns: Path to datasets storage root """ @@ -69,12 +70,13 @@ def _get_datasets_root() -> Path: root.mkdir(parents=False, exist_ok=True) return root -def _get_dataset_root(dataset_id: DatasetMetadata["id"]) -> Path | None: + +def _get_dataset_root(dataset_id: str) -> Path | None: """Get the root directory for a specific dataset by ID. - + Args: dataset_id: Identifier of the dataset - + Returns:""" datasets_root = _get_datasets_root() if datasets_root and dataset_id: @@ -83,6 +85,7 @@ def _get_dataset_root(dataset_id: DatasetMetadata["id"]) -> Path | None: return dataset_root return None + def _get_dataset_metadata(dataset_root: Path) -> DatasetMetadata | None: """Load dataset metadata from the given dataset root directory. @@ -93,11 +96,12 @@ def _get_dataset_metadata(dataset_root: Path) -> DatasetMetadata | None: metadata_path = dataset_root / "voxkit_dataset.json" if not metadata_path.exists(): return None - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: return json.load(f) except Exception: return None - + + def create_dataset( name: str, description: str, @@ -105,7 +109,7 @@ def create_dataset( cached: bool, anonymize: bool, transcribed: bool = False, -) -> tuple[bool, str]: +) -> tuple[Literal[True], DatasetMetadata] | tuple[Literal[False], str]: """Create a dataset metadata dictionary and create necessary directories. Args: @@ -122,7 +126,7 @@ def create_dataset( valid, msg = validate_dataset(Path(original_path)) if not valid: return False, msg - + now = generate_unique_id() try: @@ -137,7 +141,7 @@ def create_dataset( transcribed=transcribed, registration_date=humannow, ) - + # Create dataset directory dataset_dir = _get_datasets_root() / metadata["id"] if dataset_dir.exists(): @@ -148,7 +152,7 @@ def create_dataset( alignments_dir = dataset_dir / ALIGNMENTS_ROOT alignments_dir.mkdir(parents=False, exist_ok=False) metadata_path = dataset_dir / "voxkit_dataset.json" - with open(metadata_path, 'w') as f: + with open(metadata_path, "w") as f: json.dump(metadata, f, indent=2) # Cache the dataset if requested @@ -158,7 +162,7 @@ def create_dataset( shutil.copytree(original_path, cache_dir, dirs_exist_ok=True) return True, metadata - + except Exception as e: # Clean up on failure dataset_dir = _get_datasets_root() / now @@ -169,14 +173,14 @@ def create_dataset( return False, f"Failed to create dataset metadata: {str(e)}" -def get_dataset_metadata(dataset_id: DatasetMetadata["id"]) -> DatasetMetadata | None: +def get_dataset_metadata(dataset_id: str) -> DatasetMetadata | None: """Get the metadata for a specific dataset. - + Args: dataset_id: ID of the dataset to retrieve Returns: - Dataset metadata dictionary or None if not found + Dataset metadata dictionary or None if not found """ try: dataset_dir = _get_datasets_root() / dataset_id @@ -184,7 +188,7 @@ def get_dataset_metadata(dataset_id: DatasetMetadata["id"]) -> DatasetMetadata | if metadata is None: raise FileNotFoundError(f"Metadata for dataset '{dataset_id}' not found.") return metadata - + except Exception as e: print(f"Error retrieving dataset metadata: {str(e)}") return None @@ -192,7 +196,7 @@ def get_dataset_metadata(dataset_id: DatasetMetadata["id"]) -> DatasetMetadata | def list_datasets_metadata() -> List[DatasetMetadata]: """List all existing datasets. - + Returns: List of dataset metadata dictionaries """ @@ -204,11 +208,11 @@ def list_datasets_metadata() -> List[DatasetMetadata]: if entry.is_dir(): metadata_path = os.path.join(entry.path, "voxkit_dataset.json") if os.path.exists(metadata_path): - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: metadata = json.load(f) datasets.append(metadata) return datasets - + except Exception as e: print(f"Error listing datasets: {str(e)}") return [] @@ -219,7 +223,7 @@ def update_dataset_metadata( updates: dict, ) -> Tuple[bool, str]: """Update the metadata for a specific dataset. - + Args: dataset_id: ID of the dataset to update updates: Dictionary of metadata fields to update @@ -229,7 +233,10 @@ def update_dataset_metadata( """ try: metadata = get_dataset_metadata(dataset_id) - + + if not metadata: + return False, f"Dataset {dataset_id} not found" + if updates["description"] is not None: metadata["description"] = updates["description"] if updates["cached"] is not None: @@ -238,14 +245,14 @@ def update_dataset_metadata( metadata["anonymize"] = updates["anonymize"] if updates["transcribed"] is not None: metadata["transcribed"] = updates["transcribed"] - + # Save the updated metadata metadata_path = _get_datasets_root() / dataset_id / "voxkit_dataset.json" - with open(metadata_path, 'w') as f: + with open(metadata_path, "w") as f: json.dump(metadata, f, indent=2) - + return True, "Dataset metadata updated successfully" - + except KeyError as e: return False, f"Invalid metadata key: {str(e)}" except FileNotFoundError as e: @@ -254,35 +261,35 @@ def update_dataset_metadata( return False, f"Failed to update dataset metadata: {str(e)}" -def delete_dataset(dataset_id: DatasetMetadata["id"]) -> Tuple[bool, str]: +def delete_dataset(dataset_id: str) -> Tuple[bool, str]: """Delete a registered dataset. - + Args: dataset_id: ID of the dataset to delete - + Returns: Tuple of (success, message) """ if not dataset_id: return False, "Dataset ID cannot be empty." - + dataset_path = _get_datasets_root() / dataset_id if dataset_path is None: return False, f"Dataset '{dataset_id}' not found" - + if not dataset_path.exists(): return False, f"Dataset '{dataset_id}' not found" - + try: shutil.rmtree(dataset_path) return True, f"Dataset '{dataset_id}' metadata deleted successfully" - + except Exception as e: return False, f"Failed to delete dataset: {str(e)}" -def export_dataset(dataset_id: DatasetMetadata["id"], output_root: Path) -> Tuple[bool, str]: +def export_dataset(dataset_id: str, output_root: Path) -> Tuple[bool, str]: """Export an existing dataset to a specified output path. Args: @@ -292,30 +299,30 @@ def export_dataset(dataset_id: DatasetMetadata["id"], output_root: Path) -> Tupl Returns: Tuple of (success, message) """ - + if not output_root.exists(): return False, f"Output path '{output_root}' does not exist." else: dataset_path = _get_datasets_root() / dataset_id - + if not dataset_path.exists(): return False, f"Dataset '{dataset_id}' not found." - + dataset_meta = get_dataset_metadata(dataset_id) if not dataset_meta: return False, f"Metadata for dataset '{dataset_id}' not found." - + dest_path = output_root / (dataset_meta["name"] + "_" + dataset_id) try: shutil.copytree(dataset_path, dest_path, dirs_exist_ok=False) return True, f"Dataset '{dataset_id}' exported successfully to '{dest_path}'." except Exception as e: return False, f"Failed to export dataset: {str(e)}" - + def import_dataset(dataset_path: Path) -> Tuple[bool, str]: """Import an existing dataset into VoxKit storage. - + Args: dataset_path: Path to the dataset to import. @@ -331,18 +338,18 @@ def import_dataset(dataset_path: Path) -> Tuple[bool, str]: if not dataset_path.is_dir(): return False, f"Dataset path '{dataset_path}' is not a directory." valid, valid_msg = validate_dataset(dataset_path / "cache") - + now = generate_unique_id() print(now) dataset_dest = _get_datasets_root() / now try: - dataset_metadata = _get_dataset_metadata(dataset_path) - if dataset_metadata is None: + dataset_metadata_typed = _get_dataset_metadata(dataset_path) + if dataset_metadata_typed is None: return False, "Dataset metadata file not found in the provided dataset path." - + # Change metadata accordingly - dataset_metadata = dict(dataset_metadata) # Make a copy to modify + dataset_metadata: dict[str, Any] = dict(dataset_metadata_typed) # Make a copy to modify dataset_metadata["id"] = now humannow = readable_from_unique_id(now) dataset_metadata["registration_date"] = humannow @@ -351,24 +358,28 @@ def import_dataset(dataset_path: Path) -> Tuple[bool, str]: if not dataset_metadata["cached"]: original_location_exists = Path(dataset_metadata["original_path"]).exists() if not original_location_exists: - return False, f"Original dataset path {dataset_metadata['original_path']} does not exist; cannot import non-cached dataset." - - # Validate dataset + return ( + False, + f"Original dataset path {dataset_metadata['original_path']} " + "does not exist; cannot import non-cached dataset.", + ) + + # Validate dataset elif not valid: return False, f"Dataset validation failed: {valid_msg}" - + metadata_path = dataset_dest / "voxkit_dataset.json" - + if not dataset_dest.exists(): dataset_dest.mkdir(parents=False, exist_ok=False) shutil.copytree(dataset_path, dataset_dest, dirs_exist_ok=True) - with open(metadata_path, 'w') as f: + with open(metadata_path, "w") as f: json.dump(dataset_metadata, f, indent=2) - + return True, "Dataset imported successfully." - + except Exception as e: # Cleanup on failure if dataset_dest.exists(): @@ -377,11 +388,9 @@ def import_dataset(dataset_path: Path) -> Tuple[bool, str]: return False, f"Failed to import dataset: {str(e)}" -def validate_dataset( - dataset_path: Path -) -> Tuple[bool, str]: +def validate_dataset(dataset_path: Path) -> Tuple[bool, str]: """Validate if a dataset follows the organization pattern. - + Expected structure: dataset_path/ @@ -393,10 +402,10 @@ def validate_dataset( └── speaker_002/ ├── audio_001.wav └── audio_001.lab - + Args: dataset_path: Path to dataset root directory - + Returns: Tuple of (is_valid, message) where: is_valid: True if dataset is valid, False otherwise @@ -411,31 +420,51 @@ def validate_dataset( if not os.listdir(dataset_path): return False, f"Dataset path '{dataset_path}' is empty." for subdir in os.listdir(dataset_path): - if subdir.startswith('.'): + if subdir.startswith("."): continue # Skip hidden files/directories subdir_path = os.path.join(dataset_path, subdir) if not os.path.isdir(subdir_path): - return False, f"Expected speaker directories in dataset path '{dataset_path}', found file '{subdir_path}'." + return ( + False, + f"Expected speaker directories in dataset path '{dataset_path}', " + f"found file '{subdir_path}'.", + ) if not os.listdir(subdir_path): return False, f"Speaker directory '{subdir_path}' is empty." - - speaker_dirs = [d for d in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, d))] + + speaker_dirs = [ + d for d in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, d)) + ] if not speaker_dirs: return False, "No speaker directories found in the dataset path." - + for speaker in speaker_dirs: speaker_path = os.path.join(dataset_path, speaker) - audio_files = [f for f in os.listdir(speaker_path) if f.endswith('.wav') or f.endswith('.flac') or f.endswith('.mp3') or f.endswith('.ogg') or f.endswith('.m4a')] - label_files = [f for f in os.listdir(speaker_path) if f.endswith('.lab') or f.endswith('.txt')] - + audio_files = [ + f + for f in os.listdir(speaker_path) + if f.endswith(".wav") + or f.endswith(".flac") + or f.endswith(".mp3") + or f.endswith(".ogg") + or f.endswith(".m4a") + ] + label_files = [ + f for f in os.listdir(speaker_path) if f.endswith(".lab") or f.endswith(".txt") + ] + if not audio_files: return False, f"No audio files found in speaker directory '{speaker_path}'." - + if not label_files: return False, f"No label files found in speaker directory '{speaker_path}'." - + if len(audio_files) != len(label_files): - return False, f"Mismatch between number of audio and label files in speaker directory '{speaker_path}'." - + return ( + False, + f"Mismatch between number of audio and label files in speaker " + f"directory '{speaker_path}'.", + ) + return True, "Dataset is valid." diff --git a/src/voxkit/storage/models.py b/src/voxkit/storage/models.py index 5f4e661..b7e39e8 100644 --- a/src/voxkit/storage/models.py +++ b/src/voxkit/storage/models.py @@ -4,7 +4,7 @@ Specialized CRUD operations for managing models within the VoxKit storage system. -Directory Structure (None or Many per Engine) +Directory Structure (None or Many per Engine) ------------------- Each model follows a hierarchical structure: @@ -35,14 +35,16 @@ - Error handling only exposes messages. - Raises FileNotFoundError if model or metadata not found. """ + import json import shutil from pathlib import Path -from typing import Tuple, TypedDict +from typing import Literal, Tuple, TypedDict -from .config import MODELS_ROOT from voxkit.storage.utils import generate_unique_id, get_storage_root, readable_from_unique_id +from .config import MODELS_ROOT + class ModelMetadata(TypedDict): name: str @@ -55,9 +57,9 @@ class ModelMetadata(TypedDict): id: str -def _get_model_root(engine_id: str, model_id: ModelMetadata["id"]) -> Path | None: +def _get_model_root(engine_id: str, model_id: str) -> Path | None: """Get the root directory for storing models for a given engine. - + Args: engine_id: Identifier of the engine the model belongs to. model_id: Identifier of the model. @@ -67,13 +69,13 @@ def _get_model_root(engine_id: str, model_id: ModelMetadata["id"]) -> Path | Non """ model_root = Path(f"{get_storage_root()}/{engine_id}/{MODELS_ROOT}/{model_id}") if model_root.exists(): - return model_root + return model_root return None def _get_models_root(engine_id: str) -> Path | None: """Get the root directory for storing models for a given engine. - + Args: engine_id: Identifier of the engine. """ @@ -82,12 +84,12 @@ def _get_models_root(engine_id: str) -> Path | None: return models_root return None + def create_model( - engine_id: str, - model_name: str -) -> Tuple[True, ModelMetadata] | Tuple[False, str]: + engine_id: str, model_name: str +) -> tuple[Literal[True], ModelMetadata] | tuple[Literal[False], str]: """Create a new model entry in the storage. - + Args: engine_id: Identifier of the engine the model belongs to. model_name: Human-readable name for the model. @@ -95,11 +97,11 @@ def create_model( Returns: Tuple of (success, ModelMetadata) or (failure, error message) """ - + engine_models_root = Path(f"{get_storage_root()}/{engine_id}/{MODELS_ROOT}") if not engine_models_root.exists(): return False, f"Unsupported engine_id: {engine_id}" - + now = generate_unique_id() model_root = Path(f"{engine_models_root}/{now}") print(f"Creating model at: {model_root}") @@ -114,12 +116,12 @@ def create_model( model_metadata = ModelMetadata( name=model_name or f"Model_{now}", engine_id=engine_id, - model_path=str(model_path), - data_path=str(data_path), - eval_path=str(eval_path), - train_path=str(train_path), + model_path=model_path, + data_path=data_path, + eval_path=eval_path, + train_path=train_path, download_date=humandate, - id=now + id=now, ) # Create model directories @@ -129,12 +131,15 @@ def create_model( train_path.mkdir(parents=True, exist_ok=False) metadata_path = model_root / "voxkit_model.json" - # Create metadata file and write metadata + # Convert Path objects to strings for JSON serialization + json_metadata = {k: str(v) if isinstance(v, Path) else v for k, v in model_metadata.items()} + + # Create metadata file and write metadata with open(metadata_path, "w") as f: - json.dump(model_metadata, f, indent=4) + json.dump(json_metadata, f, indent=4) return True, model_metadata - + except Exception as e: print(f"Exception occurred during model creation: {e}") # Clean up partially created model directory @@ -144,11 +149,7 @@ def create_model( return False, "Failed to create model metadata." -def update_model_metadata( - engine_id: str, - model_id: str, - updates: dict -) -> Tuple[bool, str]: +def update_model_metadata(engine_id: str, model_id: str, updates: dict) -> Tuple[bool, str]: """Update metadata for an existing model. Args: @@ -162,30 +163,30 @@ def update_model_metadata( model_root = _get_model_root(engine_id, model_id) if not model_root: return False, f"Model '{model_id}' for engine '{engine_id}' not found" - + metadata_path = Path(model_root) / "voxkit_model.json" try: with open(metadata_path, "r") as f: metadata = json.load(f) - + # Update fields for key, value in updates.items(): if key in metadata: metadata[key] = str(value) - + with open(metadata_path, "w") as f: json.dump(metadata, f, indent=4) - + return True, "Model metadata updated successfully." - + except Exception as e: print(f"Exception occurred during model metadata update: {e}") - return False, f"Failed to update model metadata." + return False, "Failed to update model metadata." def list_models(engine_id: str) -> list[ModelMetadata]: """List available model names for the given engine. - + Args: engine_id: Identifier of the engine to list models for. @@ -196,7 +197,7 @@ def list_models(engine_id: str) -> list[ModelMetadata]: models_root = Path(f"{get_storage_root()}/{engine_id}/{MODELS_ROOT}") if not models_root.exists(): raise FileNotFoundError(f"Models root does not exist: {models_root}") - + models_found = [] for dir in models_root.iterdir(): if dir.is_dir(): @@ -206,11 +207,11 @@ def list_models(engine_id: str) -> list[ModelMetadata]: metadata = json.load(f) models_found.append(metadata) return models_found - + except Exception as e: print(f"Error listing models: {e}") return [] - + def get_model_metadata(engine_id: str, model_id: str) -> ModelMetadata: """Get metadata for a specific model by its ID. @@ -234,21 +235,24 @@ def get_model_metadata(engine_id: str, model_id: str) -> ModelMetadata: with open(metadata_path, "r") as f: metadata = json.load(f) return metadata - + def delete_model(engine_id: str, model_id: str) -> Tuple[bool, str]: """Delete a model given its engine ID and model ID. - + Args: engine_id: Identifier of the engine the model belongs to. model_id: Identifier of the model to delete. - + Returns: Tuple of (success, message)""" - + print(f"Attempting to delete model: engine_id={engine_id}, model_id={model_id}") model_path = _get_model_root(engine_id, model_id) + if not model_path: + return False, f"Model {model_id} not found" + print(f"Deleting model at path: {model_path}") shutil.rmtree(model_path) return True, "Model deleted successfully." @@ -256,7 +260,7 @@ def delete_model(engine_id: str, model_id: str) -> Tuple[bool, str]: def import_models(engine_id, new_models_root: Path) -> Tuple[bool, str]: """Import a model into the storage system. - + Args: model_id: Identifier of the model to import. new_models_root: Destination root path for the imported model. @@ -272,50 +276,58 @@ def import_models(engine_id, new_models_root: Path) -> Tuple[bool, str]: metadata_path = Path(new_model_path / "voxkit_model.json") if not metadata_path.exists(): return False, f"{new_model_path.name} (missing metadata file)" - + metadata = None # Read json metadata with open(metadata_path, "r") as f: metadata = json.load(f) - + if metadata is None: return False, f"{new_model_path.name} (invalid metadata file)" - + engine_models_root = get_storage_root() / engine_id if not engine_models_root.exists(): engine_models_root.mkdir(parents=True, exist_ok=False) - + model_id = generate_unique_id() if engine_id != metadata["engine_id"]: return False, f"{new_model_path.name} (engine_id mismatch)" - + new_metadata = ModelMetadata( name=metadata["name"], engine_id=metadata["engine_id"], - model_path=str(Path(engine_models_root / MODELS_ROOT / model_id / "entrypoint.model")), - data_path=str(Path(engine_models_root / MODELS_ROOT / model_id / "data")), - eval_path=str(Path(engine_models_root / MODELS_ROOT / model_id / "eval")), - train_path=str(Path(engine_models_root / MODELS_ROOT / model_id / "train")), + model_path=Path( + engine_models_root / MODELS_ROOT / model_id / "entrypoint.model" + ), + data_path=Path(engine_models_root / MODELS_ROOT / model_id / "data"), + eval_path=Path(engine_models_root / MODELS_ROOT / model_id / "eval"), + train_path=Path(engine_models_root / MODELS_ROOT / model_id / "train"), download_date=readable_from_unique_id(model_id), - id=model_id + id=model_id, ) # Copy model directory to storage dest_path = engine_models_root / MODELS_ROOT / model_id - + shutil.copytree(new_model_path, dest_path, dirs_exist_ok=True) + # Convert Path objects to strings for JSON serialization + json_metadata = { + k: str(v) if isinstance(v, Path) else v for k, v in new_metadata.items() + } + # Overwrite metadata file with new IDs and paths new_metadata_path = dest_path / "voxkit_model.json" + with open(new_metadata_path, "w") as f: - json.dump(new_metadata, f, indent=4) - + json.dump(json_metadata, f, indent=4) + except Exception as e: return False, f"{new_model_path.name} (error: {str(e)})" - + return True, f"Models imported successfully from: {new_models_root}" - + except Exception as e: - return False, f"Failed to import model: {str(e)}" \ No newline at end of file + return False, f"Failed to import model: {str(e)}" diff --git a/src/voxkit/storage/test/__init__.py b/src/voxkit/storage/test/__init__.py index 5487636..d11d9ad 100644 --- a/src/voxkit/storage/test/__init__.py +++ b/src/voxkit/storage/test/__init__.py @@ -1,2 +1 @@ - -# Pytest \ No newline at end of file +# Pytest diff --git a/src/voxkit/storage/test/test_alignments.py b/src/voxkit/storage/test/test_alignments.py index d6e96c7..e0a8243 100644 --- a/src/voxkit/storage/test/test_alignments.py +++ b/src/voxkit/storage/test/test_alignments.py @@ -1,6 +1,13 @@ from pathlib import Path + import pytest -from .test_setup import activate_test_environment, deactivate_test_environment, mock_get_storage_root, ENGINE_IDS + +from .test_setup import ( + ENGINE_IDS, + activate_test_environment, + deactivate_test_environment, + mock_get_storage_root, +) def generate_fake_datasets(): @@ -8,18 +15,18 @@ def generate_fake_datasets(): # Create wav/lab file pairs dataset_path = mock_get_storage_root() / "fake_datasets" / "valid" dataset_path.mkdir(parents=True, exist_ok=True) - + participant_names = ["participant_1", "participant_2"] for participant in participant_names: wavlab_path = dataset_path / participant wavlab_path.mkdir(parents=True, exist_ok=True) - + for i in range(3): wav_file = wavlab_path / f"sample_{i}.wav" lab_file = wavlab_path / f"sample_{i}.lab" wav_file.touch() lab_file.touch() - + return dataset_path @@ -27,7 +34,7 @@ def generate_fake_datasets(): def manage_test_environment(): # Setup before each test activate_test_environment(mock_get_storage_root(), ENGINE_IDS) - + # Generate fake datasets following setup generate_fake_datasets() yield @@ -38,12 +45,13 @@ def manage_test_environment(): @pytest.fixture def sample_dataset(monkeypatch): """Create a sample dataset for testing alignments.""" - from ..datasets import create_dataset from .. import datasets + from ..datasets import create_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + dataset_path = mock_get_storage_root() / "fake_datasets" / "valid" - + success, dataset_metadata = create_dataset( name="test_dataset", description="A test dataset for alignments", @@ -52,7 +60,7 @@ def sample_dataset(monkeypatch): anonymize=False, transcribed=True, ) - + assert success is True return dataset_metadata @@ -60,16 +68,17 @@ def sample_dataset(monkeypatch): @pytest.fixture def sample_model(monkeypatch): """Create a sample model for testing alignments.""" - from ..models import create_model from .. import models + from ..models import create_model + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) - + engine_id = ENGINE_IDS[0] success, model_metadata = create_model( engine_id=engine_id, model_name="test_model", ) - + assert success is True return model_metadata @@ -77,27 +86,28 @@ def sample_model(monkeypatch): class TestAlignments: class TestCreateAlignment: def test_create_alignment_success(self, monkeypatch, sample_dataset, sample_model): - from ..alignments import AlignmentMetadata, create_alignment from .. import models + from ..alignments import AlignmentMetadata, create_alignment + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) - + dataset_id = sample_dataset["id"] engine_id = sample_model["engine_id"] model_id = sample_model["id"] - + success, result = create_alignment( dataset_id=dataset_id, engine_id=engine_id, model_id=model_id, ) - + assert success is True assert isinstance(result, dict) - + # Verify all required keys are present required_keys = set(AlignmentMetadata.__annotations__.keys()) assert required_keys.issubset(set(result.keys())) - + # Verify field values assert result["engine_id"] == engine_id assert result["status"] == "Pending" @@ -106,14 +116,15 @@ def test_create_alignment_success(self, monkeypatch, sample_dataset, sample_mode assert "tg_path" in result def test_create_alignment_invalid_model(self, monkeypatch, sample_dataset): - from ..alignments import create_alignment from .. import models + from ..alignments import create_alignment + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) - + dataset_id = sample_dataset["id"] engine_id = "INVALID_ENGINE" model_id = "NON_EXISTENT_MODEL" - + # Assert error raised for invalid model with pytest.raises(FileNotFoundError) as _: create_alignment( @@ -123,14 +134,15 @@ def test_create_alignment_invalid_model(self, monkeypatch, sample_dataset): ) def test_create_alignment_invalid_dataset(self, monkeypatch, sample_model): + from .. import datasets from ..alignments import create_alignment - from .. import models - monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) - + + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) + dataset_id = "NON_EXISTENT_DATASET" engine_id = sample_model["engine_id"] model_id = sample_model["id"] - + # Assert error raised for invalid dataset msg, result = create_alignment( dataset_id=dataset_id, @@ -142,15 +154,15 @@ def test_create_alignment_invalid_dataset(self, monkeypatch, sample_model): assert "Dataset" in result def test_create_alignment_non_cached_dataset(self, monkeypatch, sample_model): + from .. import datasets + from ..alignments import AlignmentMetadata, create_alignment from ..datasets import create_dataset - from ..alignments import create_alignment, AlignmentMetadata - from .. import datasets, models + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) - + # Create a non-cached dataset dataset_path = mock_get_storage_root() / "fake_datasets" / "valid" - + success, dataset_metadata = create_dataset( name="test_dataset_non_cached", description="A test non-cached dataset for alignments", @@ -159,26 +171,26 @@ def test_create_alignment_non_cached_dataset(self, monkeypatch, sample_model): anonymize=False, transcribed=True, ) - + assert success is True - + dataset_id = dataset_metadata["id"] engine_id = sample_model["engine_id"] model_id = sample_model["id"] - + success, result = create_alignment( dataset_id=dataset_id, engine_id=engine_id, model_id=model_id, ) - + assert success is True assert isinstance(result, dict) - + # Verify all required keys are present required_keys = set(AlignmentMetadata.__annotations__.keys()) assert required_keys.issubset(set(result.keys())) - + # Verify field values assert result["engine_id"] == engine_id assert result["status"] == "Pending" @@ -191,95 +203,98 @@ def test_create_alignment_non_cached_dataset(self, monkeypatch, sample_model): alignment_root = tg_path.parent.parent assert alignment_root in tg_path.parents - - class TestGetAlignmentMetadata: - + class TestGetAlignmentMetadata: def test_get_alignment_metadata_success(self, monkeypatch, sample_dataset, sample_model): - from ..alignments import create_alignment, get_alignment_metadata from .. import models + from ..alignments import create_alignment, get_alignment_metadata + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) - + dataset_id = sample_dataset["id"] engine_id = sample_model["engine_id"] model_id = sample_model["id"] - + success, alignment_metadata = create_alignment( dataset_id=dataset_id, engine_id=engine_id, model_id=model_id, ) - + assert success is True - + alignment_id = alignment_metadata["id"] - + fetched_metadata = get_alignment_metadata( dataset_id=dataset_id, alignment_id=alignment_id, ) - + assert fetched_metadata is not None assert fetched_metadata["id"] == alignment_id assert fetched_metadata["engine_id"] == engine_id def test_get_alignment_metadata_invalid_id(self, monkeypatch, sample_dataset, sample_model): from ..alignments import create_alignment, get_alignment_metadata - + dataset_id = sample_dataset["id"] engine_id = sample_model["engine_id"] model_id = sample_model["id"] - + success, alignment_metadata = create_alignment( dataset_id=dataset_id, engine_id=engine_id, model_id=model_id, ) - + assert success is True - + invalid_alignment_id = "NON_EXISTENT_ALIGNMENT" - + fetched_metadata = get_alignment_metadata( dataset_id=dataset_id, alignment_id=invalid_alignment_id, ) - - assert fetched_metadata is None + assert fetched_metadata is None def test_get_alignment_metadata_invalid_dataset(self, monkeypatch, sample_model): + from .. import datasets from ..alignments import get_alignment_metadata - + + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) + invalid_dataset_id = "NON_EXISTENT_DATASET" alignment_id = "ANY_ALIGNMENT_ID" - + fetched_metadata = get_alignment_metadata( dataset_id=invalid_dataset_id, alignment_id=alignment_id, ) - + assert fetched_metadata is None class TestListAlignments: - def test_list_alignments_invalid_dataset(self, monkeypatch): + from .. import datasets from ..alignments import list_alignments - + + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) invalid_dataset_id = "NON_EXISTENT_DATASET" - + alignments_list = list_alignments(dataset_id=invalid_dataset_id) - + assert alignments_list == [] def test_list_alignments_success(self, monkeypatch, sample_dataset, sample_model): + from .. import datasets from ..alignments import create_alignment, list_alignments - from .. import models - monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) - + + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) + dataset_id = sample_dataset["id"] engine_id = sample_model["engine_id"] model_id = sample_model["id"] - + # Create multiple alignments num_alignments = 3 created_alignment_ids = set() @@ -291,81 +306,83 @@ def test_list_alignments_success(self, monkeypatch, sample_dataset, sample_model ) assert success is True created_alignment_ids.add(alignment_metadata["id"]) - + alignments_list = list_alignments(dataset_id=dataset_id) - + assert len(alignments_list) == num_alignments fetched_alignment_ids = {alignment["id"] for alignment in alignments_list} assert created_alignment_ids == fetched_alignment_ids - + def test_list_alignments_empty(self, monkeypatch, sample_dataset): from ..alignments import list_alignments - + dataset_id = sample_dataset["id"] - + alignments_list = list_alignments(dataset_id=dataset_id) - + assert alignments_list == [] class TestDeleteAlignment: - - def test_delete_alignment_success(self,sample_dataset, sample_model): + def test_delete_alignment_success(self, sample_dataset, sample_model): from ..alignments import create_alignment, delete_alignment, get_alignment_metadata - + dataset_id = sample_dataset["id"] engine_id = sample_model["engine_id"] model_id = sample_model["id"] - + success, alignment_metadata = create_alignment( dataset_id=dataset_id, engine_id=engine_id, model_id=model_id, ) - + assert success is True - + alignment_id = alignment_metadata["id"] - + delete_success, delete_msh = delete_alignment( dataset_id=dataset_id, alignment_id=alignment_id, ) - + assert delete_success is True assert "deleted successfully" in delete_msh - + # Verify alignment metadata no longer exists fetched_metadata = get_alignment_metadata( dataset_id=dataset_id, alignment_id=alignment_id, ) - + assert fetched_metadata is None def test_delete_alignment_invalid_id(self, sample_dataset): from ..alignments import delete_alignment - + dataset_id = sample_dataset["id"] invalid_alignment_id = "NON_EXISTENT_ALIGNMENT" - + delete_success, delete_msg = delete_alignment( dataset_id=dataset_id, alignment_id=invalid_alignment_id, ) - + assert delete_success is False assert "not found" in delete_msg - def test_delete_alignment_invalid_dataset(self): + def test_delete_alignment_invalid_dataset(self, monkeypatch): + from .. import datasets from ..alignments import delete_alignment - + + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) + invalid_dataset_id = "NON_EXISTENT_DATASET" alignment_id = "ANY_ALIGNMENT_ID" - + delete_success, delete_msg = delete_alignment( dataset_id=invalid_dataset_id, alignment_id=alignment_id, ) - + assert delete_success is False assert "not found" in delete_msg diff --git a/src/voxkit/storage/test/test_datasets.py b/src/voxkit/storage/test/test_datasets.py index bf147f0..f4edd09 100644 --- a/src/voxkit/storage/test/test_datasets.py +++ b/src/voxkit/storage/test/test_datasets.py @@ -1,21 +1,28 @@ +import os +import shutil from pathlib import Path + import pytest -import shutil -from .test_setup import activate_test_environment, deactivate_test_environment, mock_get_storage_root + from ..datasets import DatasetMetadata +from .test_setup import ( + activate_test_environment, + deactivate_test_environment, + mock_get_storage_root, +) + -import os def generate_fake_datasets(): """Generate a fake dataset for testing purposes.""" # Create wav/lab file pairs dataset_path = mock_get_storage_root() / "fake_datasets" / "valid" dataset_path.mkdir(parents=True, exist_ok=True) - + participant_names = ["participant_1", "participant_2", "participant_3"] for participant in participant_names: wavlab_path = dataset_path / participant wavlab_path.mkdir(parents=True, exist_ok=True) - + for i in range(5): wav_file = wavlab_path / f"sample_{i}.wav" lab_file = wavlab_path / f"sample_{i}.lab" @@ -28,7 +35,7 @@ def generate_fake_datasets(): for participant in participant_names: wavlab_path = dataset_path / participant wavlab_path.mkdir(parents=True, exist_ok=True) - + for i in range(5): wav_file = wavlab_path / f"sample_{i}.wav" wav_file.touch() @@ -37,10 +44,6 @@ def generate_fake_datasets(): # Create entirely empty dataset path dataset_path = mock_get_storage_root() / "fake_datasets" / "empty" dataset_path.mkdir(parents=True, exist_ok=True) - - -def mock_get_storage_root(): - return Path("./temp_test_storage_datasets") @pytest.fixture(autouse=True) @@ -53,29 +56,31 @@ def manage_test_environment(): yield # Cleanup after each test deactivate_test_environment(mock_get_storage_root()) - + + valid_dataset_path = mock_get_storage_root() / "fake_datasets" / "valid" invalid_dataset_path = mock_get_storage_root() / "fake_datasets" / "invalid" empty_dataset_path = mock_get_storage_root() / "fake_datasets" / "empty" -class TestDatasets: +class TestDatasets: class TestValidateDataset: def test_validate_dataset_valid(self, monkeypatch): - from ..datasets import validate_dataset from .. import models + from ..datasets import validate_dataset + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) - + is_valid, _ = validate_dataset(valid_dataset_path) assert is_valid is True - class TestCreateDataset: def test_create_dataset_success_no_cache(self, monkeypatch): + from .. import datasets from ..datasets import DatasetMetadata, create_dataset - from .. import models - monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) - + + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) + success, message = create_dataset( name="test_dataset", description="A test dataset", @@ -96,10 +101,11 @@ def test_create_dataset_success_no_cache(self, monkeypatch): assert message["description"] == "A test dataset" def test_create_dataset_success_with_cache(self, monkeypatch): + from .. import datasets from ..datasets import DatasetMetadata, create_dataset - from .. import models - monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) - + + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) + success, message = create_dataset( name="test_dataset_cached", description="A test dataset with caching", @@ -119,10 +125,11 @@ def test_create_dataset_success_with_cache(self, monkeypatch): assert message["description"] == "A test dataset with caching" def test_create_dataset_invalid_path(self, monkeypatch): - from ..datasets import create_dataset from .. import datasets + from ..datasets import create_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + success, message = create_dataset( name="test_dataset_invalid", description="A test dataset with invalid path", @@ -136,13 +143,13 @@ def test_create_dataset_invalid_path(self, monkeypatch): assert "No label files found" in message assert invalid_dataset_path.exists() is True - class TestListDatasets: def test_list_datasets_metadata(self, monkeypatch): - from ..datasets import create_dataset, list_datasets_metadata from .. import datasets + from ..datasets import create_dataset, list_datasets_metadata + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Create two datasets create_dataset( name="dataset_one", @@ -169,10 +176,11 @@ def test_list_datasets_metadata(self, monkeypatch): assert "dataset_two" in names def test_list_datasets_output_format(self, monkeypatch): - from ..datasets import create_dataset, list_datasets_metadata from .. import datasets + from ..datasets import create_dataset, list_datasets_metadata + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Create a dataset create_dataset( name="dataset_format_test", @@ -184,17 +192,18 @@ def test_list_datasets_output_format(self, monkeypatch): ) datasets = list_datasets_metadata() - + # Check that each dataset has all required fields for i in range(len(datasets)): for key in DatasetMetadata.__annotations__.keys(): assert key in datasets[i].keys() def test_list_datasets_empty(self, monkeypatch): - from ..datasets import list_datasets_metadata from .. import datasets + from ..datasets import list_datasets_metadata + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Ensure no datasets exist deactivate_test_environment(mock_get_storage_root()) activate_test_environment(mock_get_storage_root()) @@ -203,13 +212,13 @@ def test_list_datasets_empty(self, monkeypatch): assert isinstance(datasets, list) assert len(datasets) == 0 - class TestGetDatasetMetadata: def test_get_dataset_metadata_success(self, monkeypatch): - from ..datasets import create_dataset, get_dataset_metadata from .. import datasets + from ..datasets import create_dataset, get_dataset_metadata + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Create a dataset success, message = create_dataset( name="dataset_metadata_test", @@ -230,29 +239,32 @@ def test_get_dataset_metadata_success(self, monkeypatch): assert metadata == message def test_get_dataset_metadata_nonexistent(self, monkeypatch): - from ..datasets import get_dataset_metadata from .. import datasets + from ..datasets import get_dataset_metadata + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Attempt to retrieve metadata for a non-existent dataset metadata = get_dataset_metadata("nonexistent_id_12345") assert metadata is None def test_get_dataset_metadata_invalid_id(self, monkeypatch): - from ..datasets import get_dataset_metadata from .. import datasets + from ..datasets import get_dataset_metadata + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Attempt to retrieve metadata with an invalid ID format metadata = get_dataset_metadata("") - assert metadata is None + assert metadata is None class TestDeleteDataset: def test_delete_dataset_success(self, monkeypatch): - from ..datasets import create_dataset, delete_dataset, get_dataset_metadata from .. import datasets + from ..datasets import create_dataset, delete_dataset, get_dataset_metadata + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Create a dataset _, message = create_dataset( name="dataset_delete_test", @@ -273,31 +285,33 @@ def test_delete_dataset_success(self, monkeypatch): assert metadata is None def test_delete_dataset_nonexistent(self, monkeypatch): - from ..datasets import delete_dataset from .. import datasets + from ..datasets import delete_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Attempt to delete a non-existent dataset del_success, del_message = delete_dataset("nonexistent_id_12345") assert del_success is False assert "not found" in del_message def test_delete_dataset_invalid_id(self, monkeypatch): - from ..datasets import delete_dataset from .. import datasets + from ..datasets import delete_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Attempt to delete with an invalid ID format del_success, del_message = delete_dataset("") assert del_success is False assert "cannot be empty" in del_message - def test_delete_dataset_twice(self, monkeypatch): - from ..datasets import create_dataset, delete_dataset from .. import datasets + from ..datasets import create_dataset, delete_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Create a dataset _, message = create_dataset( name="dataset_delete_twice_test", @@ -319,10 +333,11 @@ def test_delete_dataset_twice(self, monkeypatch): assert "not found" in del_message def test_delete_dataset_invalid_id_format(self, monkeypatch): - from ..datasets import delete_dataset from .. import datasets + from ..datasets import delete_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Attempt to delete with an invalid ID format del_success, del_message = delete_dataset("!@#$%^&*()") assert del_success is False @@ -330,10 +345,11 @@ def test_delete_dataset_invalid_id_format(self, monkeypatch): class TestExportDataset: def test_export_dataset_success(self, monkeypatch, tmp_path): - from ..datasets import create_dataset, export_dataset, _get_datasets_root from .. import datasets + from ..datasets import create_dataset, export_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Create a dataset _, message = create_dataset( name="dataset_export_test", @@ -352,10 +368,11 @@ def test_export_dataset_success(self, monkeypatch, tmp_path): assert "exported successfully" in exp_message def test_export_equal(self, monkeypatch): - from ..datasets import create_dataset, export_dataset, _get_datasets_root from .. import datasets + from ..datasets import _get_datasets_root, create_dataset, export_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Create a dataset _, message = create_dataset( name="dataset_export_equal_test", @@ -375,11 +392,11 @@ def test_export_equal(self, monkeypatch): # Dataset metadata original_metadata = datasets.get_dataset_metadata(dataset_id) - + # Verify exported files match original files original_dataset_path = _get_datasets_root() / dataset_id destination_path = export_path / Path(original_metadata["name"] + "_" + str(dataset_id)) - + # Check that all files exist in the exported location for root, _, files in os.walk(original_dataset_path): rel_root = Path(root).relative_to(original_dataset_path) @@ -389,13 +406,13 @@ def test_export_equal(self, monkeypatch): assert exported_file.exists() is True assert original_file.stat().st_size == exported_file.stat().st_size - class TestImportDataset: def test_import_dataset_success(self, monkeypatch): - from ..datasets import create_dataset, import_dataset, validate_dataset from .. import datasets + from ..datasets import create_dataset, import_dataset, validate_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Create a dataset to export and then import _, message = create_dataset( name="dataset_import_test", @@ -410,9 +427,17 @@ def test_import_dataset_success(self, monkeypatch): export_path = mock_get_storage_root() datasets.export_dataset(dataset_id, export_path) - assert Path(export_path / Path(message["name"] + "_" + str(dataset_id)) / "cache").exists() is True - - assert validate_dataset(export_path / Path(message["name"] + "_" + str(dataset_id)) / "cache")[0] is True + assert ( + Path(export_path / Path(message["name"] + "_" + str(dataset_id)) / "cache").exists() + is True + ) + + assert ( + validate_dataset( + export_path / Path(message["name"] + "_" + str(dataset_id)) / "cache" + )[0] + is True + ) imp_success, imp_message = import_dataset( export_path / Path(message["name"] + "_" + str(dataset_id)), @@ -422,10 +447,11 @@ def test_import_dataset_success(self, monkeypatch): assert "imported successfully" in imp_message def test_import_dataset_nonexistent(self, monkeypatch): - from ..datasets import import_dataset from .. import datasets + from ..datasets import import_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) - + # Attempt to import from a non-existent path imp_success, imp_message = import_dataset( mock_get_storage_root() / "nonexistent_path_12345", @@ -435,8 +461,9 @@ def test_import_dataset_nonexistent(self, monkeypatch): assert "does not exist" in imp_message def test_import_dataset_empty_cache_true(self, monkeypatch): - from ..datasets import import_dataset, create_dataset, _get_datasets_root from .. import datasets + from ..datasets import _get_datasets_root, create_dataset, import_dataset + monkeypatch.setattr(datasets, "get_storage_root", mock_get_storage_root) # Create a dataset to export and then import @@ -459,26 +486,10 @@ def test_import_dataset_empty_cache_true(self, monkeypatch): export_path = mock_get_storage_root() datasets.export_dataset(dataset_id, export_path) - + imp_success, msg = import_dataset( empty_dataset_path, ) assert imp_success is False assert "file not found" in msg - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/src/voxkit/storage/test/test_models.py b/src/voxkit/storage/test/test_models.py index fda78a0..51bbbaf 100644 --- a/src/voxkit/storage/test/test_models.py +++ b/src/voxkit/storage/test/test_models.py @@ -1,7 +1,14 @@ from pathlib import Path + import pytest -from .test_setup import activate_test_environment, deactivate_test_environment, mock_get_storage_root, ENGINE_IDS + from ..utils import get_storage_root +from .test_setup import ( + ENGINE_IDS, + activate_test_environment, + deactivate_test_environment, + mock_get_storage_root, +) @pytest.fixture(autouse=True) @@ -16,8 +23,9 @@ def manage_test_environment(): class TestModels: class TestCreateModel: def test_create_model_success(self, monkeypatch): - from ..models import ModelMetadata, create_model from .. import models + from ..models import ModelMetadata, create_model + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) for engine_id in ENGINE_IDS: success, message = create_model( @@ -32,10 +40,10 @@ def test_create_model_success(self, monkeypatch): assert message["name"] == "test_model" assert message["engine_id"] == engine_id - def test_create_model_invalid_engine(self, monkeypatch): - from ..models import create_model from .. import models + from ..models import create_model + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) success, message = create_model( @@ -46,10 +54,10 @@ def test_create_model_invalid_engine(self, monkeypatch): assert "Unsupported engine_id" in message assert Path(get_storage_root() / "INVALID_ENGINE").exists() is False - def test_create_multiple_models(self, monkeypatch): - from ..models import create_model from .. import models + from ..models import create_model + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -67,8 +75,9 @@ def test_create_multiple_models(self, monkeypatch): created_ids.add(message["id"]) def test_model_paths_created(self, monkeypatch): - from ..models import create_model from .. import models + from ..models import create_model + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -87,11 +96,12 @@ def test_model_paths_created(self, monkeypatch): assert model_path.parent.exists(), "Model directory does not exist" assert data_path.exists() assert eval_path.exists() - assert train_path.exists() + assert train_path.exists() def test_model_fits_modelmetadata(self, monkeypatch): - from ..models import ModelMetadata, create_model from .. import models + from ..models import ModelMetadata, create_model + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -108,8 +118,9 @@ def test_model_fits_modelmetadata(self, monkeypatch): class TestListModels: def test_list_models_empty(self, monkeypatch): - from ..models import list_models from .. import models + from ..models import list_models + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -118,8 +129,9 @@ def test_list_models_empty(self, monkeypatch): assert len(models_list) == 0 def test_list_models_non_empty(self, monkeypatch): - from ..models import create_model, list_models from .. import models + from ..models import create_model, list_models + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -135,8 +147,9 @@ def test_list_models_non_empty(self, monkeypatch): assert len(models_list) == 3 def test_list_models_output_format(self, monkeypatch): - from ..models import create_model, list_models from .. import models + from ..models import create_model, list_models + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -156,8 +169,9 @@ def test_list_models_output_format(self, monkeypatch): assert not missing, f"Missing keys in model metadata: {missing}" def test_list_models_multiple_engines(self, monkeypatch): - from ..models import create_model, list_models from .. import models + from ..models import create_model, list_models + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) # Create models for different engines @@ -175,8 +189,9 @@ def test_list_models_multiple_engines(self, monkeypatch): class TestDeleteModel: def test_delete_model_success(self, monkeypatch): - from ..models import create_model, delete_model, list_models from .. import models + from ..models import create_model, delete_model, list_models + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) for engine_id in ENGINE_IDS: @@ -202,22 +217,26 @@ def test_delete_model_success(self, monkeypatch): assert model_id not in model_ids def test_delete_model_nonexistent(self, monkeypatch): - from ..models import delete_model from .. import models + from ..models import delete_model + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] fake_model_id = "nonexistent_model_id" - with pytest.raises(TypeError) as _: - delete_model( - engine_id=engine_id, - model_id=fake_model_id, - ) + success, msg = delete_model( + engine_id=engine_id, + model_id=fake_model_id, + ) + + assert success is False + assert "not found" in msg def test_delete_model_multiple(self, monkeypatch): - from ..models import create_model, delete_model, list_models from .. import models + from ..models import create_model, delete_model, list_models + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -245,8 +264,9 @@ def test_delete_model_multiple(self, monkeypatch): assert len(models_list) == 0 def test_delete_model_invalid_engine(self, monkeypatch): - from ..models import create_model, delete_model from .. import models + from ..models import create_model, delete_model + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -260,16 +280,20 @@ def test_delete_model_invalid_engine(self, monkeypatch): # Attempt to delete with invalid engine_id invalid_engine_id = "INVALID_ENGINE" - with pytest.raises(TypeError) as _: - delete_model( - engine_id=invalid_engine_id, - model_id=model_id, - ) + + success, msg = delete_model( + engine_id=invalid_engine_id, + model_id=model_id, + ) + + assert success is False + assert "not found" in msg class TestGetModelMetadata: def test_get_model_metadata_success(self, monkeypatch): - from ..models import create_model, get_model_metadata from .. import models + from ..models import create_model, get_model_metadata + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) for engine_id in ENGINE_IDS: @@ -291,8 +315,9 @@ def test_get_model_metadata_success(self, monkeypatch): assert metadata["engine_id"] == engine_id def test_get_model_metadata_nonexistent(self, monkeypatch): - from ..models import get_model_metadata from .. import models + from ..models import get_model_metadata + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -303,10 +328,11 @@ def test_get_model_metadata_nonexistent(self, monkeypatch): engine_id=engine_id, model_id=fake_model_id, ) - + def test_get_model_metadata_multiple(self, monkeypatch): - from ..models import create_model, get_model_metadata from .. import models + from ..models import create_model, get_model_metadata + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -330,10 +356,10 @@ def test_get_model_metadata_multiple(self, monkeypatch): assert metadata["id"] == model_id assert metadata["engine_id"] == engine_id - def test_get_model_metadata_invalid_engine(self, monkeypatch): - from ..models import create_model, get_model_metadata from .. import models + from ..models import create_model, get_model_metadata + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -354,8 +380,9 @@ def test_get_model_metadata_invalid_engine(self, monkeypatch): ) def test_get_model_metadata_output_format(self, monkeypatch): - from ..models import create_model, get_model_metadata from .. import models + from ..models import create_model, get_model_metadata + monkeypatch.setattr(models, "get_storage_root", mock_get_storage_root) engine_id = ENGINE_IDS[0] @@ -376,5 +403,3 @@ def test_get_model_metadata_output_format(self, monkeypatch): required_keys = set(models.ModelMetadata.__annotations__.keys()) missing = required_keys - set(metadata.keys()) assert not missing, f"Missing keys in model metadata: {missing}" - - \ No newline at end of file diff --git a/src/voxkit/storage/test/test_setup.py b/src/voxkit/storage/test/test_setup.py index 6495927..cd2b876 100644 --- a/src/voxkit/storage/test/test_setup.py +++ b/src/voxkit/storage/test/test_setup.py @@ -1,21 +1,23 @@ - import shutil -from ..config import MODELS_ROOT from pathlib import Path +from ..config import MODELS_ROOT + ENGINE_IDS = ["ENGINE_A", "ENGINE_B", "ENGINE_C"] + def activate_test_environment(storage_root, engine_ids=ENGINE_IDS) -> None: """Activate the test environment by overriding storage paths.""" for engine_id in engine_ids: engine_root = storage_root / engine_id / MODELS_ROOT engine_root.mkdir(parents=True, exist_ok=False) + def deactivate_test_environment(storage_root) -> None: """Deactivate the test environment by resetting storage paths.""" if storage_root.exists(): shutil.rmtree(storage_root) + def mock_get_storage_root(): return Path("./temp_test_storage_models") - diff --git a/src/voxkit/storage/utils.py b/src/voxkit/storage/utils.py index abddbbd..3d105bb 100644 --- a/src/voxkit/storage/utils.py +++ b/src/voxkit/storage/utils.py @@ -16,14 +16,16 @@ The storage root must be configured as a path starting with '~' to ensure it references the user's home directory. """ + from datetime import datetime from pathlib import Path + from .config import STORAGE_ROOT def get_storage_root() -> Path: """Get the root directory for storing VoxKit data. - + This uses ~ (home directory) so it works regardless of how the app is launched. """ if STORAGE_ROOT.startswith("~"): @@ -32,8 +34,10 @@ def get_storage_root() -> Path: raise ValueError("STORAGE_ROOT must be a valid path starting with '~'") -def generate_unique_id(prefix: str = None) -> str: - """Generate a unique identifier with the given prefix and current timestamp including microseconds.""" +def generate_unique_id(prefix: str | None = None) -> str: + """Generate a unique identifier with the given prefix and current timestamp + including microseconds. + """ now = datetime.now().strftime("%Y%m%d_%H%M%S_%f") if prefix: return f"{prefix}_{now}"