diff --git a/alembic/manual_migrations/migrate_jsonb_ranges_to_table_rows.py b/alembic/manual_migrations/migrate_jsonb_ranges_to_table_rows.py
new file mode 100644
index 00000000..f5219369
--- /dev/null
+++ b/alembic/manual_migrations/migrate_jsonb_ranges_to_table_rows.py
@@ -0,0 +1,374 @@
+"""
+Migration script to convert JSONB functional_ranges to the new row-based implementation.
+
+This script migrates data from ScoreCalibration.functional_ranges (JSONB column)
+to the new ScoreCalibrationFunctionalClassification table with proper foreign key relationships.
+"""
+from typing import Any, Dict
+
+import sqlalchemy as sa
+from sqlalchemy.orm import Session, configure_mappers
+
+from mavedb.models import *
+from mavedb.db.session import SessionLocal
+from mavedb.models.acmg_classification import ACMGClassification
+from mavedb.models.enums.acmg_criterion import ACMGCriterion
+from mavedb.models.enums.functional_classification import FunctionalClassification
+from mavedb.models.enums.strength_of_evidence import StrengthOfEvidenceProvided
+from mavedb.models.score_calibration import ScoreCalibration
+from mavedb.models.score_calibration_functional_classification import ScoreCalibrationFunctionalClassification
+from mavedb.models.score_calibration_functional_classification_variant_association import (
+    score_calibration_functional_classification_variants_association_table
+)
+from mavedb.models.variant import Variant
+from mavedb.view_models.acmg_classification import ACMGClassificationCreate
+
+configure_mappers()
+
+
+def populate_variant_associations(
+    db: Session,
+    functional_classification: ScoreCalibrationFunctionalClassification,
+    calibration: ScoreCalibration,
+) -> int:
+    """Populate the association table with variants that fall within this functional range."""
+    # Nothing to associate if the classification has no numeric range to match scores against
+    if not functional_classification or not functional_classification.range:
+        print("  Skipping variant association - no valid range defined")
+        return 0
+
+    print(f"  Finding variants within range {functional_classification.range} (lower_inclusive={functional_classification.inclusive_lower_bound}, upper_inclusive={functional_classification.inclusive_upper_bound})")
+
+    # Get all variants for this score set and their scores
+    variants_query = db.execute(sa.select(Variant).where(
+        Variant.score_set_id == calibration.score_set_id,
+    )).scalars().all()
+
+    variants_in_range = []
+    total_variants = 0
+
+    for variant in variants_query:
+        total_variants += 1
+
+        # Extract score from JSONB data
+        try:
+            score_data = variant.data.get("score_data", {}).get("score") if variant.data else None
+            if score_data is not None:
+                variant_score = float(score_data)
+
+                # Use the classification model's existing range-checking logic
+                if functional_classification.score_is_contained_in_range(variant_score):
+                    variants_in_range.append(variant)
+
+        except (ValueError, TypeError) as e:
+            print(f"  Warning: Could not parse score for variant {variant.id}: {e}")
+            continue
+
+    print(f"  Found {len(variants_in_range)} variants in range out of {total_variants} total variants")
+
+    # Bulk insert associations
+    if variants_in_range:
+        associations = [
+            {
+                "functional_classification_id": functional_classification.id,
+                "variant_id": variant.id
+            }
+            for variant in variants_in_range
+        ]
+
+        db.execute(
+            score_calibration_functional_classification_variants_association_table.insert(),
+            associations
+        )
+
+    return len(variants_in_range)
+
+
+def migrate_functional_range_to_row(
+    db:
Session, + calibration: ScoreCalibration, + functional_range: Dict[str, Any], + acmg_classification_cache: Dict[str, ACMGClassification] +) -> ScoreCalibrationFunctionalClassification: + """Convert a single functional range from JSONB to table row.""" + + # Handle ACMG classification if present + acmg_classification_id = None + acmg_data = functional_range.get("acmg_classification") + if acmg_data: + # Create a cache key for the ACMG classification + criterion = acmg_data.get("criterion").upper() if acmg_data.get("criterion") else None + evidence_strength = acmg_data.get("evidence_strength").upper() if acmg_data.get("evidence_strength") else None + points = acmg_data.get("points") + + classification = ACMGClassificationCreate( + criterion=ACMGCriterion(criterion) if criterion else None, + evidence_strength=StrengthOfEvidenceProvided(evidence_strength) if evidence_strength else None, + points=points + ) + + cache_key = f"{classification.criterion}_{classification.evidence_strength}_{classification.points}" + + if cache_key not in acmg_classification_cache: + # Create new ACMG classification + acmg_classification = ACMGClassification( + criterion=classification.criterion, + evidence_strength=classification.evidence_strength, + points=classification.points + ) + db.add(acmg_classification) + db.flush() # Get the ID + acmg_classification_cache[cache_key] = acmg_classification + + acmg_classification_id = acmg_classification_cache[cache_key].id + + # Create the functional classification row + functional_classification = ScoreCalibrationFunctionalClassification( + calibration_id=calibration.id, + label=functional_range.get("label", ""), + description=functional_range.get("description"), + classification=FunctionalClassification(functional_range.get("classification", "not_specified")), + range=functional_range.get("range"), + inclusive_lower_bound=functional_range.get("inclusive_lower_bound"), + inclusive_upper_bound=functional_range.get("inclusive_upper_bound"), + oddspaths_ratio=functional_range.get("oddspaths_ratio"), + positive_likelihood_ratio=functional_range.get("positive_likelihood_ratio"), + acmg_classification_id=acmg_classification_id + ) + + return functional_classification + + +def do_migration(db: Session): + """Main migration function.""" + print("Starting migration of JSONB functional_ranges to table rows...") + + # Find all calibrations with functional_ranges + calibrations_with_ranges = db.scalars( + sa.select(ScoreCalibration).where(ScoreCalibration.functional_ranges_deprecated_json.isnot(None)) + ).all() + + print(f"Found {len(calibrations_with_ranges)} calibrations with functional ranges to migrate.") + + # Cache for ACMG classifications to avoid duplicates + acmg_classification_cache: Dict[str, ACMGClassification] = {} + + migrated_count = 0 + error_count = 0 + + for calibration in calibrations_with_ranges: + try: + print(f"Migrating calibration {calibration.id} (URN: {calibration.urn})...") + + functional_ranges_data = calibration.functional_ranges_deprecated_json + if not functional_ranges_data or not isinstance(functional_ranges_data, list): + print(f" Skipping calibration {calibration.id} - no valid functional ranges data") + continue + + # Create functional classification rows for each range + functional_classifications = [] + for i, functional_range in enumerate(functional_ranges_data): + try: + functional_classification = migrate_functional_range_to_row( + db, calibration, functional_range, acmg_classification_cache + ) + db.add(functional_classification) + 
functional_classifications.append(functional_classification) + print(f" Created functional classification row {i+1}/{len(functional_ranges_data)}") + + except Exception as e: + print(f" Error migrating functional range {i+1} for calibration {calibration.id}: {e}") + error_count += 1 + continue + + # Flush to get IDs for the functional classifications + db.flush() + + # Populate variant associations for each functional classification + total_associations = 0 + for functional_classification in functional_classifications: + try: + associations_count = populate_variant_associations( + db, functional_classification, calibration + ) + total_associations += associations_count + + except Exception as e: + print(f" Error populating variant associations for functional classification {functional_classification.id}: {e}") + error_count += 1 + continue + + print(f" Created {total_associations} variant associations") + + # Commit the changes for this calibration + db.commit() + migrated_count += 1 + print(f" Successfully migrated calibration {calibration.id}") + + except Exception as e: + print(f"Error migrating calibration {calibration.id}: {e}") + db.rollback() + error_count += 1 + continue + + # Final statistics + total_functional_classifications = db.scalar( + sa.select(sa.func.count(ScoreCalibrationFunctionalClassification.id)) + ) + + total_associations = db.scalar( + sa.select(sa.func.count()).select_from( + score_calibration_functional_classification_variants_association_table + ) + ) or 0 + + print(f"\nMigration completed:") + print(f" Successfully migrated: {migrated_count} calibrations") + print(f" Functional classification rows created: {total_functional_classifications}") + print(f" Variant associations created: {total_associations}") + print(f" ACMG classifications created: {len(acmg_classification_cache)}") + print(f" Errors encountered: {error_count}") + + +def verify_migration(db: Session): + """Verify that the migration was successful.""" + print("\nVerifying migration...") + + # Count original calibrations with functional ranges + original_count = db.scalar( + sa.select(sa.func.count(ScoreCalibration.id)).where( + ScoreCalibration.functional_ranges_deprecated_json.isnot(None) + ) + ) + + # Count migrated functional classifications + migrated_count = db.scalar( + sa.select(sa.func.count(ScoreCalibrationFunctionalClassification.id)) + ) + + # Count ACMG classifications + acmg_count = db.scalar( + sa.select(sa.func.count(ACMGClassification.id)) + ) + + # Count variant associations + association_count = db.scalar( + sa.select(sa.func.count()).select_from( + score_calibration_functional_classification_variants_association_table + ) + ) + + print(f"Original calibrations with functional ranges: {original_count}") + print(f"Migrated functional classification rows: {migrated_count}") + print(f"ACMG classification records: {acmg_count}") + print(f"Variant associations created: {association_count}") + + # Sample verification - check that relationships work + sample_classification = db.scalar( + sa.select(ScoreCalibrationFunctionalClassification).limit(1) + ) + + if sample_classification: + print(f"\nSample verification:") + print(f" Functional classification ID: {sample_classification.id}") + print(f" Label: {sample_classification.label}") + print(f" Classification: {sample_classification.classification}") + print(f" Range: {sample_classification.range}") + print(f" Calibration ID: {sample_classification.calibration_id}") + print(f" ACMG classification ID: 
{sample_classification.acmg_classification_id}") + + # Count variants associated with this classification + variant_count = db.scalar( + sa.select(sa.func.count()).select_from( + score_calibration_functional_classification_variants_association_table + ).where( + score_calibration_functional_classification_variants_association_table.c.functional_classification_id == sample_classification.id + ) + ) + print(f" Associated variants: {variant_count}") + + # Functional classifications by type + classification_stats = db.execute( + sa.select( + ScoreCalibrationFunctionalClassification.classification, + sa.func.count().label('count') + ).group_by(ScoreCalibrationFunctionalClassification.classification) + ).all() + + for classification, count in classification_stats: + print(f"{classification}: {count} ranges") + + + +def rollback_migration(db: Session): + """Rollback the migration by deleting all migrated data.""" + print("Rolling back migration...") + + # Count records before deletion + functional_count = db.scalar( + sa.select(sa.func.count(ScoreCalibrationFunctionalClassification.id)) + ) + + acmg_count = db.scalar( + sa.select(sa.func.count(ACMGClassification.id)) + ) + + association_count = db.scalar( + sa.select(sa.func.count()).select_from( + score_calibration_functional_classification_variants_association_table + ) + ) + + # Delete in correct order (associations first, then functional classifications, then ACMG) + db.execute(sa.delete(score_calibration_functional_classification_variants_association_table)) + db.execute(sa.delete(ScoreCalibrationFunctionalClassification)) + db.execute(sa.delete(ACMGClassification)) + db.commit() + + print(f"Deleted {association_count} variant associations") + print(f"Deleted {functional_count} functional classification rows") + print(f"Deleted {acmg_count} ACMG classification rows") + + +def show_usage(): + """Show usage information.""" + print(""" +Usage: python migrate_jsonb_ranges_to_table_rows.py [command] + +Commands: + migrate (default) - Migrate JSONB functional_ranges to table rows + verify - Verify migration without running it + rollback - Remove all migrated data (destructive!) + +Examples: + python migrate_jsonb_ranges_to_table_rows.py # Run migration + python migrate_jsonb_ranges_to_table_rows.py verify # Check status + python migrate_jsonb_ranges_to_table_rows.py rollback # Undo migration +""") + + +if __name__ == "__main__": + import sys + + command = sys.argv[1] if len(sys.argv) > 1 else "migrate" + + if command == "help" or command == "--help" or command == "-h": + show_usage() + elif command == "rollback": + print("WARNING: This will delete all migrated functional classification data!") + response = input("Are you sure you want to continue? 
(y/N): ") + if response.lower() == 'y': + with SessionLocal() as db: + rollback_migration(db) + else: + print("Rollback cancelled.") + elif command == "verify": + with SessionLocal() as db: + verify_migration(db) + elif command == "migrate": + with SessionLocal() as db: + do_migration(db) + verify_migration(db) + else: + print(f"Unknown command: {command}") + show_usage() diff --git a/alembic/versions/00dab0f5f498_add_external_links_property_to_.py b/alembic/versions/00dab0f5f498_add_external_links_property_to_.py new file mode 100644 index 00000000..01c25ad4 --- /dev/null +++ b/alembic/versions/00dab0f5f498_add_external_links_property_to_.py @@ -0,0 +1,33 @@ +"""add external links property to experiments + +Revision ID: 00dab0f5f498 +Revises: b22b450d409c +Create Date: 2025-12-16 12:06:15.265947 + +""" + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "00dab0f5f498" +down_revision = "b22b450d409c" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "experiments", + sa.Column("external_links", postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default="{}"), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("experiments", "external_links") + # ### end Alembic commands ### diff --git a/alembic/versions/0520dfa9f2db_rename_functional_ranges_to_functional_.py b/alembic/versions/0520dfa9f2db_rename_functional_ranges_to_functional_.py new file mode 100644 index 00000000..7b66d976 --- /dev/null +++ b/alembic/versions/0520dfa9f2db_rename_functional_ranges_to_functional_.py @@ -0,0 +1,45 @@ +"""rename functional ranges to functional classifications, add class_ to model, rename classification to functional_classification + +Revision ID: 0520dfa9f2db +Revises: c770fa9e6e58 +Create Date: 2025-11-18 18:51:33.107952 + +""" + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "0520dfa9f2db" +down_revision = "c770fa9e6e58" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "score_calibration_functional_classifications", + "classification", + new_column_name="functional_classification", + type_=sa.Enum( + "normal", "abnormal", "not_specified", name="functionalclassification", native_enum=False, length=32 + ), + nullable=False, + ) + op.add_column("score_calibration_functional_classifications", sa.Column("class_", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column( + "score_calibration_functional_classifications", + "functional_classification", + new_column_name="classification", + type_=sa.VARCHAR(length=32), + nullable=False, + ) + op.drop_column("score_calibration_functional_classifications", "class_") + # ### end Alembic commands ### diff --git a/alembic/versions/16beeb593513_add_acmg_classification_and_functional_.py b/alembic/versions/16beeb593513_add_acmg_classification_and_functional_.py new file mode 100644 index 00000000..41e86383 --- /dev/null +++ b/alembic/versions/16beeb593513_add_acmg_classification_and_functional_.py @@ -0,0 +1,141 @@ +"""add acmg classification and functional classification tables + +Revision ID: 16beeb593513 +Revises: b22b450d409c +Create Date: 2025-11-17 11:46:38.276980 + +""" + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "16beeb593513" +down_revision = "b22b450d409c" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "acmg_classifications", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "criterion", + sa.Enum( + "PVS1", + "PS1", + "PS2", + "PS3", + "PS4", + "PM1", + "PM2", + "PM3", + "PM4", + "PM5", + "PM6", + "PP1", + "PP2", + "PP3", + "PP4", + "PP5", + "BA1", + "BS1", + "BS2", + "BS3", + "BS4", + "BP1", + "BP2", + "BP3", + "BP4", + "BP5", + "BP6", + "BP7", + name="acmgcriterion", + native_enum=False, + length=32, + ), + nullable=True, + ), + sa.Column( + "evidence_strength", + sa.Enum( + "VERY_STRONG", + "STRONG", + "MODERATE_PLUS", + "MODERATE", + "SUPPORTING", + name="strengthofevidenceprovided", + native_enum=False, + length=32, + ), + nullable=True, + ), + sa.Column("points", sa.Integer(), nullable=True), + sa.Column("creation_date", sa.Date(), nullable=False), + sa.Column("modification_date", sa.Date(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "score_calibration_functional_classifications", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("calibration_id", sa.Integer(), nullable=False), + sa.Column("label", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column( + "classification", + sa.Enum( + "normal", "abnormal", "not_specified", name="functionalclassification", native_enum=False, length=32 + ), + nullable=False, + ), + sa.Column("range", postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), nullable=True), + sa.Column("inclusive_lower_bound", sa.Boolean(), nullable=True), + sa.Column("inclusive_upper_bound", sa.Boolean(), nullable=True), + sa.Column("oddspaths_ratio", sa.Float(), nullable=True), + sa.Column("positive_likelihood_ratio", sa.Float(), nullable=True), + sa.Column("acmg_classification_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["acmg_classification_id"], + ["acmg_classifications.id"], + ), + sa.ForeignKeyConstraint( + ["calibration_id"], + ["score_calibrations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "score_calibration_functional_classification_variants", + sa.Column("functional_classification_id", sa.Integer(), nullable=False), + sa.Column("variant_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["functional_classification_id"], + ["score_calibration_functional_classifications.id"], + ), + sa.ForeignKeyConstraint( + ["variant_id"], + ["variants.id"], + ), + 
sa.PrimaryKeyConstraint("functional_classification_id", "variant_id"), + ) + op.alter_column("score_calibrations", "functional_ranges", new_column_name="functional_ranges_deprecated_json") + op.create_index( + op.f("ix_score_calibrations_modified_by_id"), "score_calibrations", ["modified_by_id"], unique=False + ) + op.create_index(op.f("ix_score_calibrations_urn"), "score_calibrations", ["urn"], unique=True) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_score_calibrations_modified_by_id"), table_name="score_calibrations") + op.drop_index(op.f("ix_score_calibrations_created_by_id"), table_name="score_calibrations") + op.drop_table("score_calibration_functional_classification_variants") + op.drop_table("score_calibration_functional_classifications") + op.drop_table("acmg_classifications") + op.alter_column("score_calibrations", "functional_ranges_deprecated_json", new_column_name="functional_ranges") + # ### end Alembic commands ### diff --git a/alembic/versions/c770fa9e6e58_drop_functional_range_jsonb.py b/alembic/versions/c770fa9e6e58_drop_functional_range_jsonb.py new file mode 100644 index 00000000..3b1e7998 --- /dev/null +++ b/alembic/versions/c770fa9e6e58_drop_functional_range_jsonb.py @@ -0,0 +1,38 @@ +"""drop functional range jsonb + +Revision ID: c770fa9e6e58 +Revises: 16beeb593513 +Create Date: 2025-11-17 22:19:22.440742 + +""" + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c770fa9e6e58" +down_revision = "16beeb593513" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("score_calibrations", "functional_ranges_deprecated_json") + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column( + "score_calibrations", + sa.Column( + "functional_ranges_deprecated_json", + postgresql.JSONB(astext_type=sa.Text()), + autoincrement=False, + nullable=True, + ), + ) + # ### end Alembic commands ### diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 5c674d96..d9d430af 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -49,6 +49,7 @@ services: dcd-mapping: build: ../dcd_mapping + platform: linux/amd64 image: dcd-mapping:dev command: bash -c "uvicorn api.server_main:app --host 0.0.0.0 --port 8000 --reload" depends_on: @@ -59,6 +60,7 @@ services: ports: - "8004:8000" volumes: + - ../dcd_mapping:/usr/src/app - mavedb-seqrepo-dev:/usr/local/share/seqrepo cdot-rest: diff --git a/poetry.lock b/poetry.lock index 18ecdd5e..c50bfea6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -165,7 +165,6 @@ files = [ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"}, {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, ] -markers = {main = "extra == \"server\""} [package.extras] benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] @@ -263,7 +262,6 @@ description = "miscellaneous simple bioinformatics utilities and lookup tables" optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "bioutils-0.6.1-py3-none-any.whl", hash = "sha256:9928297331b9fc0a4fd4235afdef9a80a0916d8b5c2811ab781bded0dad4b9b6"}, {file = "bioutils-0.6.1.tar.gz", hash = "sha256:6ad7a9b6da73beea798a935499339d8b60a434edc37dfc803474d2e93e0e64aa"}, @@ -769,7 +767,6 @@ description = "Canonical JSON" optional = false python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "canonicaljson-2.0.0-py3-none-any.whl", hash = "sha256:c38a315de3b5a0532f1ec1f9153cd3d716abfc565a558d00a4835428a34fca5b"}, {file = "canonicaljson-2.0.0.tar.gz", hash = "sha256:e2fdaef1d7fadc5d9cb59bd3d0d41b064ddda697809ac4325dced721d12f113f"}, @@ -1501,7 +1498,6 @@ description = "GA4GH Categorical Variation Representation (Cat-VRS) reference im optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "ga4gh_cat_vrs-0.7.1-py3-none-any.whl", hash = "sha256:549e726182d9fdc28d049b9adc6a8c65189bbade06b2ceed8cb20a35cbdefc45"}, {file = "ga4gh_cat_vrs-0.7.1.tar.gz", hash = "sha256:ac8d11ea5f474e8a9745107673d4e8b6949819ccdc9debe2ab8ad8e5f853f87c"}, @@ -1523,7 +1519,6 @@ description = "GA4GH Variant Annotation (VA) reference implementation" optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "ga4gh_va_spec-0.4.2-py3-none-any.whl", hash = "sha256:c165a96dfa225845b5d63740d3ad40c9f2dcb26808cf759b73bc122a68a9a60e"}, {file = "ga4gh_va_spec-0.4.2.tar.gz", hash = "sha256:13eda6a8cfc7a2baa395e33d17e3296c2ec1c63ec85fe38085751c112cf1c902"}, @@ -1546,7 +1541,6 @@ description = "GA4GH Variation Representation Specification (VRS) reference impl optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"server\"" files = [ {file = "ga4gh_vrs-2.1.3-py3-none-any.whl", hash = 
"sha256:15b20363d9d4a4604be0930b41b14c9b4e6dc15a6e8be813544f0775b873bc5b"}, {file = "ga4gh_vrs-2.1.3.tar.gz", hash = "sha256:48af6de1eb40e00aa68ed5a935061917b4017468ef366e8e68bbbc17ffaa60f3"}, @@ -3815,7 +3809,6 @@ files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] -markers = {main = "extra == \"server\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -4794,9 +4787,9 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_it type = ["pytest-mypy"] [extras] -server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "ga4gh-va-spec", "hgvs", "orcid", "psycopg2", "pyathena", "python-jose", "python-multipart", "requests", "slack-sdk", "starlette", "starlette-context", "uvicorn", "watchtower"] +server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "orcid", "psycopg2", "pyathena", "python-jose", "python-multipart", "requests", "slack-sdk", "starlette", "starlette-context", "uvicorn", "watchtower"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "cb94d5f7faedc07aa0e3457fdb0735b6526b2f40f02c6d438cab46b733123fd6" +content-hash = "83fa85dbfeb224b9f3f68539182b9ccabca4b05c13182da12e1bf12c50eafbc4" diff --git a/pyproject.toml b/pyproject.toml index ca00ecf0..0ac106d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ SQLAlchemy = { extras = ["mypy"], version = "~2.0.0" } [tool.poetry.extras] -server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "ga4gh-va-spec", "orcid", "psycopg2", "python-jose", "python-multipart", "pyathena", "requests", "starlette", "starlette-context", "slack-sdk", "uvicorn", "watchtower"] +server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "orcid", "psycopg2", "python-jose", "python-multipart", "pyathena", "requests", "starlette", "starlette-context", "slack-sdk", "uvicorn", "watchtower"] [tool.mypy] @@ -100,7 +100,7 @@ plugins = [ mypy_path = "mypy_stubs" [tool.pytest.ini_options] -addopts = "-v -rP --import-mode=importlib --disable-socket --allow-unix-socket --allow-hosts localhost,::1,127.0.0.1" +addopts = "-v --import-mode=importlib --disable-socket --allow-unix-socket --allow-hosts localhost,::1,127.0.0.1" asyncio_mode = 'strict' testpaths = "tests/" pythonpath = "." 
diff --git a/src/mavedb/__init__.py b/src/mavedb/__init__.py index 60558b4a..9041300b 100644 --- a/src/mavedb/__init__.py +++ b/src/mavedb/__init__.py @@ -9,3 +9,6 @@ __version__ = "2025.5.0" logger.info(f"MaveDB {__version__}") + +# Import the model rebuild module to ensure all view model forward references are resolved +from mavedb.view_models import model_rebuild # noqa: F401, E402 diff --git a/src/mavedb/lib/acmg.py b/src/mavedb/lib/acmg.py index 971923c2..d7de860e 100644 --- a/src/mavedb/lib/acmg.py +++ b/src/mavedb/lib/acmg.py @@ -1,58 +1,11 @@ -from enum import Enum from typing import Optional +from sqlalchemy import select +from sqlalchemy.orm import Session -class ACMGCriterion(str, Enum): - """Enum for ACMG criteria codes.""" - - PVS1 = "PVS1" - PS1 = "PS1" - PS2 = "PS2" - PS3 = "PS3" - PS4 = "PS4" - PM1 = "PM1" - PM2 = "PM2" - PM3 = "PM3" - PM4 = "PM4" - PM5 = "PM5" - PM6 = "PM6" - PP1 = "PP1" - PP2 = "PP2" - PP3 = "PP3" - PP4 = "PP4" - PP5 = "PP5" - BA1 = "BA1" - BS1 = "BS1" - BS2 = "BS2" - BS3 = "BS3" - BS4 = "BS4" - BP1 = "BP1" - BP2 = "BP2" - BP3 = "BP3" - BP4 = "BP4" - BP5 = "BP5" - BP6 = "BP6" - BP7 = "BP7" - - @property - def is_pathogenic(self) -> bool: - """Return True if the criterion is pathogenic, False if benign.""" - return self.name.startswith("P") # PVS, PS, PM, PP are pathogenic criteria - - @property - def is_benign(self) -> bool: - """Return True if the criterion is benign, False if pathogenic.""" - return self.name.startswith("B") # BA, BS, BP are benign criteria - - -class StrengthOfEvidenceProvided(str, Enum): - """Enum for strength of evidence provided.""" - - VERY_STRONG = "very_strong" - STRONG = "strong" - MODERATE_PLUS = "moderate_plus" - MODERATE = "moderate" - SUPPORTING = "supporting" +from mavedb.models.acmg_classification import ACMGClassification +from mavedb.models.enums.acmg_criterion import ACMGCriterion +from mavedb.models.enums.strength_of_evidence import StrengthOfEvidenceProvided def points_evidence_strength_equivalent( @@ -121,3 +74,61 @@ def points_evidence_strength_equivalent( return (ACMGCriterion.BS3, StrengthOfEvidenceProvided.STRONG) else: # points <= -8 return (ACMGCriterion.BS3, StrengthOfEvidenceProvided.VERY_STRONG) + + +def find_or_create_acmg_classification( + db: Session, + criterion: Optional[ACMGCriterion], + evidence_strength: Optional[StrengthOfEvidenceProvided], + points: Optional[int], +): + """Create or find an ACMG classification based on criterion, evidence strength, and points. + + Parameters + ---------- + db : Session + The database session to use for querying and creating the ACMG classification. + criterion : Optional[ACMGCriterion] + The ACMG criterion for the classification. + evidence_strength : Optional[StrengthOfEvidenceProvided] + The strength of evidence provided for the classification. + points : Optional[int] + The point value associated with the classification. + + Returns + ------- + ACMGClassification + The existing or newly created ACMG classification instance. + + Raises + ------ + ValueError + If the combination of criterion, evidence strength, and points does not correspond to a valid ACMG classification. + + Notes + ----- + - This function does not commit the new entry to the database; the caller is responsible for committing the session. 
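+    - If criterion, evidence_strength, and points are all None, there is nothing to look up
+      or infer, and the function returns None rather than an ACMGClassification.
+
+    Examples
+    --------
+    A minimal usage sketch (assumes an open SQLAlchemy Session named db). When only points
+    is supplied, the criterion and evidence strength are inferred via
+    points_evidence_strength_equivalent:
+
+    >>> classification = find_or_create_acmg_classification(
+    ...     db, criterion=None, evidence_strength=None, points=4
+    ... )
+    >>> db.commit()  # the caller is responsible for committing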
+ """ + if (criterion is None) != (evidence_strength is None): + raise ValueError("Both criterion and evidence_strength must be provided together or both be None, with points.") + elif criterion is None and evidence_strength is None and points is not None: + criterion, evidence_strength = points_evidence_strength_equivalent(points) + + # If we cannot infer a classification, return None + if criterion is None and evidence_strength is None: + return None + + acmg_classification = db.execute( + select(ACMGClassification) + .where(ACMGClassification.criterion == criterion) + .where(ACMGClassification.evidence_strength == evidence_strength) + .where(ACMGClassification.points == points) + ).scalar_one_or_none() + + if not acmg_classification: + acmg_classification = ACMGClassification( + criterion=criterion, evidence_strength=evidence_strength, points=points + ) + db.add(acmg_classification) + + return acmg_classification diff --git a/src/mavedb/lib/annotation/classification.py b/src/mavedb/lib/annotation/classification.py index 9bf7526b..19dd13a5 100644 --- a/src/mavedb/lib/annotation/classification.py +++ b/src/mavedb/lib/annotation/classification.py @@ -5,8 +5,9 @@ from ga4gh.va_spec.acmg_2015 import VariantPathogenicityEvidenceLine from ga4gh.va_spec.base.enums import StrengthOfEvidenceProvided +from mavedb.models.enums.functional_classification import FunctionalClassification as FunctionalClassificationOptions from mavedb.models.mapped_variant import MappedVariant -from mavedb.view_models.score_calibration import FunctionalRange +from mavedb.view_models.score_calibration import FunctionalClassification logger = logging.getLogger(__name__) @@ -43,7 +44,7 @@ def functional_classification_of_variant( " Unable to classify functional impact." ) - if not primary_calibration.functional_ranges: + if not primary_calibration.functional_classifications: raise ValueError( f"Variant {mapped_variant.variant.urn} does not have ranges defined in its primary score calibration." " Unable to classify functional impact." @@ -57,14 +58,14 @@ def functional_classification_of_variant( " Unable to classify functional impact." ) - for functional_range in primary_calibration.functional_ranges: + for functional_range in primary_calibration.functional_classifications: # It's easier to reason with the view model objects for functional ranges than the JSONB fields in the raw database object. - functional_range_view = FunctionalRange.model_validate(functional_range) + functional_range_view = FunctionalClassification.model_validate(functional_range) if functional_range_view.is_contained_by_range(functional_score): - if functional_range_view.classification == "normal": + if functional_range_view.functional_classification is FunctionalClassificationOptions.normal: return ExperimentalVariantFunctionalImpactClassification.NORMAL - elif functional_range_view.classification == "abnormal": + elif functional_range_view.functional_classification is FunctionalClassificationOptions.abnormal: return ExperimentalVariantFunctionalImpactClassification.ABNORMAL else: return ExperimentalVariantFunctionalImpactClassification.INDETERMINATE @@ -96,7 +97,7 @@ def pathogenicity_classification_of_variant( " Unable to classify clinical impact." ) - if not primary_calibration.functional_ranges: + if not primary_calibration.functional_classifications: raise ValueError( f"Variant {mapped_variant.variant.urn} does not have ranges defined in its primary score calibration." " Unable to classify clinical impact." 
@@ -110,9 +111,9 @@ def pathogenicity_classification_of_variant( " Unable to classify clinical impact." ) - for pathogenicity_range in primary_calibration.functional_ranges: + for pathogenicity_range in primary_calibration.functional_classifications: # It's easier to reason with the view model objects for functional ranges than the JSONB fields in the raw database object. - pathogenicity_range_view = FunctionalRange.model_validate(pathogenicity_range) + pathogenicity_range_view = FunctionalClassification.model_validate(pathogenicity_range) if pathogenicity_range_view.is_contained_by_range(functional_score): if pathogenicity_range_view.acmg_classification is None: @@ -123,7 +124,7 @@ def pathogenicity_classification_of_variant( if ( pathogenicity_range_view.acmg_classification.evidence_strength is None or pathogenicity_range_view.acmg_classification.criterion is None - ): # pragma: no cover - enforced by model validators in FunctionalRange view model + ): # pragma: no cover - enforced by model validators in FunctionalClassification view model return (VariantPathogenicityEvidenceLine.Criterion.PS3, None) # TODO#540: Handle moderate+ @@ -139,7 +140,7 @@ def pathogenicity_classification_of_variant( if ( pathogenicity_range_view.acmg_classification.criterion.name not in VariantPathogenicityEvidenceLine.Criterion._member_names_ - ): # pragma: no cover - enforced by model validators in FunctionalRange view model + ): # pragma: no cover - enforced by model validators in FunctionalClassification view model raise ValueError( f"Variant {mapped_variant.variant.urn} is contained in a clinical calibration range with an invalid criterion." " Unable to classify clinical impact." diff --git a/src/mavedb/lib/annotation/util.py b/src/mavedb/lib/annotation/util.py index 0baab474..0b6274ad 100644 --- a/src/mavedb/lib/annotation/util.py +++ b/src/mavedb/lib/annotation/util.py @@ -1,16 +1,18 @@ from typing import Literal + from ga4gh.core.models import Extension from ga4gh.vrs.models import ( - MolecularVariation, Allele, CisPhasedBlock, - SequenceLocation, - SequenceReference, Expression, LiteralSequenceExpression, + MolecularVariation, + SequenceLocation, + SequenceReference, ) -from mavedb.models.mapped_variant import MappedVariant + from mavedb.lib.annotation.exceptions import MappingDataDoesntExistException +from mavedb.models.mapped_variant import MappedVariant from mavedb.view_models.score_calibration import SavedScoreCalibration @@ -190,13 +192,16 @@ def _variant_score_calibrations_have_required_calibrations_and_ranges_for_annota saved_calibration = SavedScoreCalibration.model_validate(primary_calibration) if annotation_type == "pathogenicity": return ( - saved_calibration.functional_ranges is not None - and len(saved_calibration.functional_ranges) > 0 - and any(fr.acmg_classification is not None for fr in saved_calibration.functional_ranges) + saved_calibration.functional_classifications is not None + and len(saved_calibration.functional_classifications) > 0 + and any(fr.acmg_classification is not None for fr in saved_calibration.functional_classifications) ) if annotation_type == "functional": - return saved_calibration.functional_ranges is not None and len(saved_calibration.functional_ranges) > 0 + return ( + saved_calibration.functional_classifications is not None + and len(saved_calibration.functional_classifications) > 0 + ) return True diff --git a/src/mavedb/lib/authentication.py b/src/mavedb/lib/authentication.py index b82faf3b..4ff59272 100644 --- a/src/mavedb/lib/authentication.py +++ 
b/src/mavedb/lib/authentication.py @@ -1,6 +1,5 @@ import logging import os -from dataclasses import dataclass from datetime import datetime from enum import Enum from typing import Optional @@ -19,6 +18,7 @@ from mavedb import deps from mavedb.lib.logging.context import format_raised_exception_info_as_dict, logging_context, save_to_logging_context from mavedb.lib.orcid import fetch_orcid_user_email +from mavedb.lib.types.authentication import UserData from mavedb.models.access_key import AccessKey from mavedb.models.enums.user_role import UserRole from mavedb.models.user import User @@ -45,12 +45,6 @@ class AuthenticationMethod(str, Enum): jwt = "jwt" -@dataclass -class UserData: - user: User - active_roles: list[UserRole] - - #################################################################################################### # JWT authentication #################################################################################################### diff --git a/src/mavedb/lib/authorization.py b/src/mavedb/lib/authorization.py index c9b2ab81..94f011c9 100644 --- a/src/mavedb/lib/authorization.py +++ b/src/mavedb/lib/authorization.py @@ -3,8 +3,9 @@ from fastapi import Depends, HTTPException -from mavedb.lib.authentication import UserData, get_current_user +from mavedb.lib.authentication import get_current_user from mavedb.lib.logging.context import logging_context, save_to_logging_context +from mavedb.lib.types.authentication import UserData from mavedb.models.enums.user_role import UserRole logger = logging.getLogger(__name__) diff --git a/src/mavedb/lib/clingen/services.py b/src/mavedb/lib/clingen/services.py index 1bcb7778..0450d61d 100644 --- a/src/mavedb/lib/clingen/services.py +++ b/src/mavedb/lib/clingen/services.py @@ -1,19 +1,17 @@ import hashlib import logging -import requests import os import time from datetime import datetime -from typing import Optional +from typing import Optional, Union from urllib import parse - +import requests from jose import jwt -from mavedb.lib.logging.context import logging_context, save_to_logging_context, format_raised_exception_info_as_dict from mavedb.lib.clingen.constants import GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD, LDH_MAVE_ACCESS_ENDPOINT - -from mavedb.lib.types.clingen import LdhSubmission, ClinGenAllele +from mavedb.lib.logging.context import format_raised_exception_info_as_dict, logging_context, save_to_logging_context +from mavedb.lib.types.clingen import ClinGenAllele, ClinGenSubmissionError, LdhSubmission from mavedb.lib.utils import batched logger = logging.getLogger(__name__) @@ -71,7 +69,9 @@ def construct_auth_url(self, url: str) -> str: token = hashlib.sha1((url + identity + gbTime).encode("utf-8")).hexdigest() return url + "&gbLogin=" + GENBOREE_ACCOUNT_NAME + "&gbTime=" + gbTime + "&gbToken=" + token - def dispatch_submissions(self, content_submissions: list[str]) -> list[ClinGenAllele]: + def dispatch_submissions( + self, content_submissions: list[str] + ) -> list[Union[ClinGenAllele, ClinGenSubmissionError]]: save_to_logging_context({"car_submission_count": len(content_submissions)}) try: @@ -89,7 +89,7 @@ def dispatch_submissions(self, content_submissions: list[str]) -> list[ClinGenAl logger.error(msg="Failed to dispatch CAR submission.", exc_info=exc, extra=logging_context()) return [] - response_data: list[ClinGenAllele] = response.json() + response_data: list[Union[ClinGenAllele, ClinGenSubmissionError]] = response.json() save_to_logging_context({"car_submission_response_count": len(response_data)}) 
logger.info(msg="Successfully dispatched CAR submission.", extra=logging_context()) @@ -324,7 +324,7 @@ def clingen_allele_id_from_ldh_variation(variation: Optional[dict]) -> Optional[ def get_allele_registry_associations( - content_submissions: list[str], submission_response: list[ClinGenAllele] + content_submissions: list[str], submission_response: list[Union[ClinGenAllele, ClinGenSubmissionError]] ) -> dict[str, str]: """ Links HGVS strings and ClinGen Canonoical Allele IDs (CAIDs) given a list of both. @@ -360,9 +360,20 @@ def get_allele_registry_associations( allele_registry_associations: dict[str, str] = {} for registration in submission_response: + if "errorType" in registration: + logger.warning( + msg=f"Skipping errored ClinGen Allele Registry HGVS {registration.get('hgvs', 'unknown')} ({registration.get('errorType', 'unknown')}): {registration.get('message', 'unknown error message')}", + extra=logging_context(), + ) + continue + # Extract the CAID from the URL (e.g., "http://reg.test.genome.network/allele/CA2513066" -> "CA2513066") caid = registration["@id"].split("/")[-1] - alleles = registration.get("genomicAlleles", []) + registration.get("transcriptAlleles", []) + alleles = ( + registration.get("genomicAlleles", []) + + registration.get("transcriptAlleles", []) + + registration.get("aminoAcidAlleles", []) + ) for allele in alleles: for hgvs_string in content_submissions: diff --git a/src/mavedb/lib/experiments.py b/src/mavedb/lib/experiments.py index ed02b701..a200e93b 100644 --- a/src/mavedb/lib/experiments.py +++ b/src/mavedb/lib/experiments.py @@ -1,13 +1,13 @@ import logging from typing import Optional -from sqlalchemy import func, or_, not_ +from sqlalchemy import func, not_, or_ from sqlalchemy.orm import Session -from mavedb.lib.authentication import UserData from mavedb.lib.logging.context import logging_context, save_to_logging_context from mavedb.lib.permissions import Action from mavedb.lib.score_sets import find_superseded_score_set_tail +from mavedb.lib.types.authentication import UserData from mavedb.models.contributor import Contributor from mavedb.models.controlled_keyword import ControlledKeyword from mavedb.models.experiment import Experiment diff --git a/src/mavedb/lib/flexible_model_loader.py b/src/mavedb/lib/flexible_model_loader.py new file mode 100644 index 00000000..b3041e54 --- /dev/null +++ b/src/mavedb/lib/flexible_model_loader.py @@ -0,0 +1,210 @@ +"""Generic dependency for loading Pydantic models from either JSON body or multipart form data.""" + +from typing import Awaitable, Callable, Type, TypeVar + +from fastapi import Form, HTTPException, Request +from fastapi.exceptions import RequestValidationError +from pydantic import BaseModel, ValidationError + +T = TypeVar("T", bound=BaseModel) + + +def create_flexible_model_loader( + model_class: Type[T], form_field_name: str = "item", error_detail_prefix: str = "Invalid request" +) -> Callable[..., Awaitable[T]]: + """Create a flexible FastAPI dependency that can load a Pydantic model from either + JSON request body or multipart form data containing JSON. + + This factory function creates a dependency that enables FastAPI routes to accept + data in two formats: + 1. Standard JSON request body (Content-Type: application/json) + 2. Multipart form data with JSON string in a specified field + + This is particularly useful for endpoints that need to handle both pure JSON + requests and file uploads with accompanying metadata, allowing clients to + choose the most appropriate format for their use case. 
+ + Args: + model_class (Type[T]): The Pydantic model class to instantiate from the JSON data. + Must be a subclass of BaseModel with proper field definitions and validation. + form_field_name (str, optional): Name of the form field containing JSON data + when using multipart/form-data requests. This parameter is primarily for + documentation purposes - the actual form field in OpenAPI docs will be + named 'item'. Defaults to "item". + error_detail_prefix (str, optional): Prefix text for error messages to provide + context about which operation failed. Defaults to "Invalid request". + + Returns: + Callable[..., Awaitable[T]]: An async dependency function that can be used + with FastAPI's Depends(). The returned function accepts a Request object + and optional form data, returning an instance of the specified model_class. + + Raises: + RequestValidationError: When the JSON data doesn't match the Pydantic model schema. + This preserves FastAPI's standard validation error format for consistent + client error handling. + HTTPException: For other parsing errors like invalid JSON syntax, missing data, + or unexpected exceptions during processing. + + Example: + Basic usage with a simple model: + + >>> from pydantic import BaseModel + >>> class UserModel(BaseModel): + ... name: str + ... email: str + + >>> user_loader = create_flexible_model_loader(UserModel) + + >>> @app.post("/users") + ... async def create_user(user: UserModel = Depends(user_loader)): + ... return {"user": user} + + Advanced usage with file uploads: + + >>> calibration_loader = create_flexible_model_loader( + ... ScoreCalibrationCreate, + ... form_field_name="calibration_metadata", + ... error_detail_prefix="Invalid calibration data" + ... ) + + >>> @app.post("/calibrations") + ... async def create_calibration( + ... calibration: ScoreCalibrationCreate = Depends(calibration_loader), + ... file: UploadFile = File(...) + ... ): + ... # Process both calibration metadata and uploaded file + ... return process_calibration(calibration, file) + + Client Usage Examples: + JSON request: + ```bash + curl -X POST "http://api/users" \\ + -H "Content-Type: application/json" \\ + -d '{"name": "John", "email": "john@example.com"}' + ``` + + Multipart form request: + ```bash + curl -X POST "http://api/calibrations" \\ + -F 'item={"name": "Test", "description": "Example"}' \\ + -F 'file=@data.csv' + ``` + + Note: + The dependency prioritizes form data over JSON body - if both are provided, + the form field data will be used. This ensures predictable behavior when + clients mix content types. + + OpenAPI Documentation Enhancement: + Without manual definition, OpenAPI docs will show the form field as 'item' for + multipart requests, regardless of the form_field_name parameter. 
To customize the + OpenAPI documentation and show both JSON and multipart form options clearly, use + the `openapi_extra` parameter on your route decorator: + + ```python + @router.post( + "/example-endpoint", + response_model=ExampleResponseModel, + summary="Example endpoint using flexible model loader", + description="Example endpoint description", + openapi_extra={ + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/YourModelName"}, + "example": { + "example_field": "example_value", + "another_field": 123 + } + }, + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "item": { + "type": "string", + "description": "JSON string containing the model data", + "example": '{"example_field":"example_value","another_field":123}' + }, + "file_upload": { + "type": "string", + "format": "binary", + "description": "Optional file upload" + } + } + } + } + }, + "description": "Data can be sent as JSON body or multipart form data" + } + } + ) + async def example_endpoint( + model_data: YourModel = Depends(your_loader), + file_upload: UploadFile = File(None) + ): + return process_data(model_data, file_upload) + ``` + + This configuration will display both content types clearly in the OpenAPI/Swagger UI, + allowing users to choose between JSON and multipart form submission methods. + """ + + async def flexible_loader( + request: Request, + item: str = Form(None, description="JSON data for the request", alias=form_field_name), + ) -> T: + """Load Pydantic model from either JSON body or form field.""" + try: + # Prefer form field if provided + if item is not None: + model_instance = model_class.model_validate_json(item) + # Fall back to JSON body + else: + body = await request.body() + if not body: + raise HTTPException( + status_code=422, detail=f"{error_detail_prefix}: No data provided in form field or request body" + ) + model_instance = model_class.model_validate_json(body) + + return model_instance + + # Raise validation errors in FastAPI's expected format + except ValidationError as e: + raise RequestValidationError(e.errors()) + # Any other parsing errors + except Exception as e: + raise HTTPException(status_code=422, detail=f"{error_detail_prefix}: {str(e)}") + + return flexible_loader + + +# Convenience factory for common use cases +def json_or_form_loader(model_class: Type[T], field_name: str = "item") -> Callable[..., Awaitable[T]]: + """Simplified factory function for creating flexible model loaders with sensible defaults. + + This is a convenience wrapper around create_flexible_model_loader() that provides + a quick way to create loaders without specifying all parameters. It automatically + generates an appropriate error message prefix based on the model class name. + + Args: + model_class (Type[T]): The Pydantic model class to load from JSON data. + field_name (str, optional): Name of the form field for documentation purposes. + Defaults to "item". + + Returns: + Callable[..., Awaitable[T]]: A flexible dependency function ready to use with Depends(). + + Example: + Quick setup for simple cases: + + >>> user_loader = json_or_form_loader(UserModel) + >>> @app.post("/users") + ... async def create_user(user: UserModel = Depends(user_loader)): + ... 
return user + """ + return create_flexible_model_loader( + model_class=model_class, form_field_name=field_name, error_detail_prefix=f"Invalid {model_class.__name__} data" + ) diff --git a/src/mavedb/lib/permissions.py b/src/mavedb/lib/permissions.py deleted file mode 100644 index 99b2ada0..00000000 --- a/src/mavedb/lib/permissions.py +++ /dev/null @@ -1,506 +0,0 @@ -import logging -from enum import Enum -from typing import Optional - -from mavedb.db.base import Base -from mavedb.lib.authentication import UserData -from mavedb.lib.logging.context import logging_context, save_to_logging_context -from mavedb.models.collection import Collection -from mavedb.models.enums.contribution_role import ContributionRole -from mavedb.models.enums.user_role import UserRole -from mavedb.models.experiment import Experiment -from mavedb.models.experiment_set import ExperimentSet -from mavedb.models.score_calibration import ScoreCalibration -from mavedb.models.score_set import ScoreSet -from mavedb.models.user import User - -logger = logging.getLogger(__name__) - - -class Action(Enum): - LOOKUP = "lookup" - READ = "read" - UPDATE = "update" - DELETE = "delete" - ADD_EXPERIMENT = "add_experiment" - ADD_SCORE_SET = "add_score_set" - SET_SCORES = "set_scores" - ADD_ROLE = "add_role" - PUBLISH = "publish" - ADD_BADGE = "add_badge" - CHANGE_RANK = "change_rank" - - -class PermissionResponse: - def __init__(self, permitted: bool, http_code: int = 403, message: Optional[str] = None): - self.permitted = permitted - self.http_code = http_code if not permitted else None - self.message = message if not permitted else None - - save_to_logging_context({"permission_message": self.message, "access_permitted": self.permitted}) - if self.permitted: - logger.debug( - msg="Access to the requested resource is permitted.", - extra=logging_context(), - ) - else: - logger.debug( - msg="Access to the requested resource is not permitted.", - extra=logging_context(), - ) - - -class PermissionException(Exception): - def __init__(self, http_code: int, message: str): - self.http_code = http_code - self.message = message - - -def roles_permitted(user_roles: list[UserRole], permitted_roles: list[UserRole]) -> bool: - save_to_logging_context({"permitted_roles": [role.name for role in permitted_roles]}) - - if not user_roles: - logger.debug(msg="User has no associated roles.", extra=logging_context()) - return False - - return any(role in permitted_roles for role in user_roles) - - -def has_permission(user_data: Optional[UserData], item: Base, action: Action) -> PermissionResponse: - private = False - user_is_owner = False - user_is_self = False - user_may_edit = False - user_may_view_private = False - active_roles = user_data.active_roles if user_data else [] - - if isinstance(item, ExperimentSet) or isinstance(item, Experiment) or isinstance(item, ScoreSet): - assert item.private is not None - private = item.private - published = item.published_date is not None - user_is_owner = item.created_by_id == user_data.user.id if user_data is not None else False - user_may_edit = user_is_owner or ( - user_data is not None and user_data.user.username in [c.orcid_id for c in item.contributors] - ) - - save_to_logging_context({"resource_is_published": published}) - - if isinstance(item, Collection): - assert item.private is not None - private = item.private - published = item.private is False - user_is_owner = item.created_by_id == user_data.user.id if user_data is not None else False - admin_user_ids = set() - editor_user_ids = set() - 
viewer_user_ids = set() - for user_association in item.user_associations: - if user_association.contribution_role == ContributionRole.admin: - admin_user_ids.add(user_association.user_id) - elif user_association.contribution_role == ContributionRole.editor: - editor_user_ids.add(user_association.user_id) - elif user_association.contribution_role == ContributionRole.viewer: - viewer_user_ids.add(user_association.user_id) - user_is_admin = user_is_owner or (user_data is not None and user_data.user.id in admin_user_ids) - user_may_edit = user_is_admin or (user_data is not None and user_data.user.id in editor_user_ids) - user_may_view_private = user_may_edit or (user_data is not None and (user_data.user.id in viewer_user_ids)) - - save_to_logging_context({"resource_is_published": published}) - - if isinstance(item, ScoreCalibration): - assert item.private is not None - private = item.private - published = item.private is False - user_is_owner = item.created_by_id == user_data.user.id if user_data is not None else False - - # If the calibration is investigator provided, treat permissions like score set permissions where contributors - # may also make changes to the calibration. Otherwise, only allow the calibration owner to edit the calibration. - if item.investigator_provided: - user_may_edit = user_is_owner or ( - user_data is not None and user_data.user.username in [c.orcid_id for c in item.score_set.contributors] - ) - else: - user_may_edit = user_is_owner - - if isinstance(item, User): - user_is_self = item.id == user_data.user.id if user_data is not None else False - user_may_edit = user_is_self - - save_to_logging_context( - { - "resource_is_private": private, - "user_is_owner_of_resource": user_is_owner, - "user_is_may_edit_resource": user_may_edit, - "user_is_self": user_is_self, - } - ) - - if isinstance(item, ExperimentSet): - if action == Action.READ: - if user_may_edit or not private: - return PermissionResponse(True) - # Roles which may perform this operation. - elif roles_permitted(active_roles, [UserRole.admin, UserRole.mapper]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"experiment set with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.UPDATE: - if user_may_edit: - return PermissionResponse(True) - # Roles which may perform this operation. - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"experiment set with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.DELETE: - # Owner may only delete an experiment set if it has not already been published. - if user_may_edit: - return PermissionResponse(not published, 403, f"insufficient permissions for URN '{item.urn}'") - # Roles which may perform this operation. 
- elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"experiment set with URN '{item.urn}' not found") - else: - return PermissionResponse(False) - elif action == Action.ADD_EXPERIMENT: - # Only permitted users can add an experiment to an existing experiment set. - return PermissionResponse( - user_may_edit or roles_permitted(active_roles, [UserRole.admin]), - 404 if private else 403, - ( - f"experiment set with URN '{item.urn}' not found" - if private - else f"insufficient permissions for URN '{item.urn}'" - ), - ) - else: - raise NotImplementedError(f"has_permission(User, ExperimentSet, {action}, Role)") - - elif isinstance(item, Experiment): - if action == Action.READ: - if user_may_edit or not private: - return PermissionResponse(True) - # Roles which may perform this operation. - elif roles_permitted(active_roles, [UserRole.admin, UserRole.mapper]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"experiment with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.UPDATE: - if user_may_edit: - return PermissionResponse(True) - # Roles which may perform this operation. - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"experiment with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.DELETE: - # Owner may only delete an experiment if it has not already been published. - if user_may_edit: - return PermissionResponse(not published, 403, f"insufficient permissions for URN '{item.urn}'") - # Roles which may perform this operation. - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"experiment set with URN '{item.urn}' not found") - else: - return PermissionResponse(False) - elif action == Action.ADD_SCORE_SET: - # Only permitted users can add a score set to a private experiment. - if user_may_edit or roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - return PermissionResponse(False, 404, f"experiment with URN '{item.urn}' not found") - # Any signed in user has permissions to add a score set to a public experiment - elif user_data is not None: - return PermissionResponse(True) - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - else: - raise NotImplementedError(f"has_permission(User, Experiment, {action}, Role)") - - elif isinstance(item, ScoreSet): - if action == Action.READ: - if user_may_edit or not private: - return PermissionResponse(True) - # Roles which may perform this operation. 
- elif roles_permitted(active_roles, [UserRole.admin, UserRole.mapper]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"score set with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.UPDATE: - if user_may_edit: - return PermissionResponse(True) - # Roles which may perform this operation. - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"score set with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.DELETE: - # Owner may only delete a score set if it has not already been published. - if user_may_edit: - return PermissionResponse(not published, 403, f"insufficient permissions for URN '{item.urn}'") - # Roles which may perform this operation. - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"experiment set with URN '{item.urn}' not found") - else: - return PermissionResponse(False) - # Only the owner may publish a private score set. - elif action == Action.PUBLISH: - if user_may_edit: - return PermissionResponse(True) - elif roles_permitted(active_roles, []): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"score set with URN '{item.urn}' not found") - else: - return PermissionResponse(False) - elif action == Action.SET_SCORES: - return PermissionResponse( - (user_may_edit or roles_permitted(active_roles, [UserRole.admin])), - 404 if private else 403, - ( - f"score set with URN '{item.urn}' not found" - if private - else f"insufficient permissions for URN '{item.urn}'" - ), - ) - else: - raise NotImplementedError(f"has_permission(User, ScoreSet, {action}, Role)") - - elif isinstance(item, Collection): - if action == Action.READ: - if user_may_view_private or not private: - return PermissionResponse(True) - # Roles which may perform this operation. - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"collection with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.UPDATE: - if user_may_edit: - return PermissionResponse(True) - # Roles which may perform this operation. - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private and not user_may_view_private: - # Do not acknowledge the existence of a private entity. 
- return PermissionResponse(False, 404, f"score set with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.DELETE: - # A collection may be deleted even if it has been published, as long as it is not an official collection. - if user_is_owner: - return PermissionResponse( - not item.badge_name, - 403, - f"insufficient permissions for URN '{item.urn}'", - ) - # MaveDB admins may delete official collections. - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private and not user_may_view_private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"collection with URN '{item.urn}' not found") - else: - return PermissionResponse(False) - elif action == Action.PUBLISH: - if user_is_admin: - return PermissionResponse(True) - elif roles_permitted(active_roles, []): - return PermissionResponse(True) - elif private and not user_may_view_private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"score set with URN '{item.urn}' not found") - else: - return PermissionResponse(False) - elif action == Action.ADD_SCORE_SET: - # Whether the collection is private or public, only permitted users can add a score set to a collection. - if user_may_edit or roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private and not user_may_view_private: - return PermissionResponse(False, 404, f"collection with URN '{item.urn}' not found") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.ADD_EXPERIMENT: - # Only permitted users can add an experiment to an existing collection. - return PermissionResponse( - user_may_edit or roles_permitted(active_roles, [UserRole.admin]), - 404 if private and not user_may_view_private else 403, - ( - f"collection with URN '{item.urn}' not found" - if private and not user_may_view_private - else f"insufficient permissions for URN '{item.urn}'" - ), - ) - elif action == Action.ADD_ROLE: - # Both collection admins and MaveDB admins can add a user to a collection role - if user_is_admin or roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - else: - return PermissionResponse(False, 403, "Insufficient permissions to add user role.") - # only MaveDB admins may add a badge name to a collection, which makes the collection considered "official" - elif action == Action.ADD_BADGE: - # Roles which may perform this operation. - if roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"collection with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - else: - raise NotImplementedError(f"has_permission(User, ScoreSet, {action}, Role)") - elif isinstance(item, ScoreCalibration): - if action == Action.READ: - if user_may_edit or not private: - return PermissionResponse(True) - # Roles which may perform this operation. 
- elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"score calibration with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.UPDATE: - if roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - # TODO#549: Allow editing of certain fields even if published. For now, - # Owner may only edit if a calibration is not published. - elif user_may_edit: - return PermissionResponse(not published, 403, f"insufficient permissions for URN '{item.urn}'") - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"score calibration with URN '{item.urn}' not found") - elif user_data is None or user_data.user is None: - return PermissionResponse(False, 401, f"insufficient permissions for URN '{item.urn}'") - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - elif action == Action.DELETE: - # Roles which may perform this operation. - if roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - # Owner may only delete a calibration if it has not already been published. - elif user_may_edit: - return PermissionResponse(not published, 403, f"insufficient permissions for URN '{item.urn}'") - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"score calibration with URN '{item.urn}' not found") - else: - return PermissionResponse(False) - # Only the owner may publish a private calibration. - elif action == Action.PUBLISH: - if user_may_edit: - return PermissionResponse(True) - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - elif private: - # Do not acknowledge the existence of a private entity. - return PermissionResponse(False, 404, f"score calibration with URN '{item.urn}' not found") - else: - return PermissionResponse(False) - elif action == Action.CHANGE_RANK: - if user_may_edit: - return PermissionResponse(True) - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - else: - return PermissionResponse(False, 403, f"insufficient permissions for URN '{item.urn}'") - - else: - raise NotImplementedError(f"has_permission(User, ScoreCalibration, {action}, Role)") - - elif isinstance(item, User): - if action == Action.LOOKUP: - # any existing user can look up any mavedb user by Orcid ID - # lookup differs from read because lookup means getting the first name, last name, and orcid ID of the user, - # while read means getting an admin view of the user's details - if user_data is not None and user_data.user is not None: - return PermissionResponse(True) - else: - # TODO is this inappropriately acknowledging the existence of the user? 
- return PermissionResponse(False, 401, "Insufficient permissions for user lookup.") - if action == Action.READ: - if user_is_self: - return PermissionResponse(True) - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - else: - return PermissionResponse(False, 403, "Insufficient permissions for user update.") - elif action == Action.UPDATE: - if user_is_self: - return PermissionResponse(True) - elif roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - else: - return PermissionResponse(False, 403, "Insufficient permissions for user update.") - elif action == Action.ADD_ROLE: - if roles_permitted(active_roles, [UserRole.admin]): - return PermissionResponse(True) - else: - return PermissionResponse(False, 403, "Insufficient permissions to add user role.") - elif action == Action.DELETE: - raise NotImplementedError(f"has_permission(User, ScoreSet, {action}, Role)") - else: - raise NotImplementedError(f"has_permission(User, ScoreSet, {action}, Role)") - - else: - raise NotImplementedError(f"has_permission(User, {item.__class__}, {action}, Role)") - - -def assert_permission(user_data: Optional[UserData], item: Base, action: Action) -> PermissionResponse: - save_to_logging_context({"permission_boundary": action.name}) - permission = has_permission(user_data, item, action) - - if not permission.permitted: - assert permission.http_code and permission.message - raise PermissionException(http_code=permission.http_code, message=permission.message) - - return permission diff --git a/src/mavedb/lib/permissions/__init__.py b/src/mavedb/lib/permissions/__init__.py new file mode 100644 index 00000000..2f226cef --- /dev/null +++ b/src/mavedb/lib/permissions/__init__.py @@ -0,0 +1,27 @@ +""" +Permission system for MaveDB entities. + +This module provides a comprehensive permission system for checking user access +to various entity types including ScoreSets, Experiments, Collections, etc. + +Main Functions: + has_permission: Check if a user has permission for an action on an entity + assert_permission: Assert permission or raise exception + +Usage: + >>> from mavedb.lib.permissions import Action, has_permission, assert_permission + >>> + >>> # Check permission and handle response + >>> result = has_permission(user_data, score_set, Action.READ) + >>> if result.permitted: + ... # User has access + ... 
pass + >>> + >>> # Assert permission (raises exception if denied) + >>> assert_permission(user_data, score_set, Action.UPDATE) +""" + +from .actions import Action +from .core import assert_permission, has_permission + +__all__ = ["has_permission", "assert_permission", "Action"] diff --git a/src/mavedb/lib/permissions/actions.py b/src/mavedb/lib/permissions/actions.py new file mode 100644 index 00000000..cc3a9559 --- /dev/null +++ b/src/mavedb/lib/permissions/actions.py @@ -0,0 +1,15 @@ +from enum import Enum + + +class Action(Enum): + LOOKUP = "lookup" + READ = "read" + UPDATE = "update" + DELETE = "delete" + ADD_EXPERIMENT = "add_experiment" + ADD_SCORE_SET = "add_score_set" + SET_SCORES = "set_scores" + ADD_ROLE = "add_role" + PUBLISH = "publish" + ADD_BADGE = "add_badge" + CHANGE_RANK = "change_rank" diff --git a/src/mavedb/lib/permissions/collection.py b/src/mavedb/lib/permissions/collection.py new file mode 100644 index 00000000..a629a45e --- /dev/null +++ b/src/mavedb/lib/permissions/collection.py @@ -0,0 +1,405 @@ +from typing import Optional + +from mavedb.lib.logging.context import save_to_logging_context +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.models import PermissionResponse +from mavedb.lib.permissions.utils import deny_action_for_entity, roles_permitted +from mavedb.lib.types.authentication import UserData +from mavedb.models.collection import Collection +from mavedb.models.enums.contribution_role import ContributionRole +from mavedb.models.enums.user_role import UserRole + + +def has_permission(user_data: Optional[UserData], entity: Collection, action: Action) -> PermissionResponse: + """ + Check if a user has permission to perform an action on a Collection entity. + + This function evaluates user permissions based on Collection role associations, + ownership, and user roles. Collections use a special permission model with + role-based user associations. + + Args: + user_data: The user's authentication data and roles. None for anonymous users. + entity: The Collection entity to check permissions for. + action: The action to be performed (READ, UPDATE, DELETE, ADD_EXPERIMENT, ADD_SCORE_SET, ADD_ROLE, ADD_BADGE). + + Returns: + PermissionResponse: Contains permission result, HTTP status code, and message. + + Raises: + ValueError: If the entity's private attribute is not set. + NotImplementedError: If the action is not supported for Collection entities. + + Note: + Collections use CollectionUserAssociation objects to define user roles + (admin, editor, viewer) rather than simple contributor lists. + """ + if entity.private is None: + raise ValueError("Collection entity must have 'private' attribute set for permission checks.") + + user_is_owner = False + collection_roles = [] + active_roles = [] + + if user_data is not None: + user_is_owner = entity.created_by_id == user_data.user.id + + # Find the user's collection roles in this collection through user_associations. 
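+            # A single user may hold more than one association with a collection; gather every
+            # contribution role they hold (admin, editor, and/or viewer).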
+ user_associations = [assoc for assoc in entity.user_associations if assoc.user_id == user_data.user.id] + if user_associations: + collection_roles = [assoc.contribution_role for assoc in user_associations] + + active_roles = user_data.active_roles + + save_to_logging_context( + { + "resource_is_private": entity.private, + "user_is_owner": user_is_owner, + "collection_roles": [role.value for role in collection_roles] if collection_roles else None, + } + ) + + handlers = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.DELETE: _handle_delete_action, + Action.PUBLISH: _handle_publish_action, + Action.ADD_EXPERIMENT: _handle_add_experiment_action, + Action.ADD_SCORE_SET: _handle_add_score_set_action, + Action.ADD_ROLE: _handle_add_role_action, + Action.ADD_BADGE: _handle_add_badge_action, + } + + if action not in handlers: + supported_actions = ", ".join(a.value for a in handlers.keys()) + raise NotImplementedError( + f"Action '{action.value}' is not supported for collection entities. " + f"Supported actions: {supported_actions}" + ) + + return handlers[action]( + user_data, + entity, + entity.private, + entity.badge_name is not None, + user_is_owner, + collection_roles, + active_roles, + ) + + +def _handle_read_action( + user_data: Optional[UserData], + entity: Collection, + private: bool, + official_collection: bool, + user_is_owner: bool, + collection_roles: list[ContributionRole], + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle READ action permission check for Collection entities. + + Public Collections are readable by anyone. Private Collections are only readable + by users with Collection roles, owners, admins, and mappers. + + Args: + user_data: The user's authentication data. + entity: The Collection entity being accessed. + private: Whether the Collection is private. + official_collection: Whether the Collection is an official collection. + user_is_owner: Whether the user owns the Collection. + collection_roles: The user's roles in this Collection (admin/editor/viewer). + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow read access under the following conditions: + # Any user may read a non-private collection. + if not private: + return PermissionResponse(True) + # The owner may read a private collection. + if user_is_owner: + return PermissionResponse(True) + # Collection role holders may read a private collection. + if roles_permitted(collection_roles, [ContributionRole.admin, ContributionRole.editor, ContributionRole.viewer]): + return PermissionResponse(True) + # Users with these specific roles may read a private collection. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, bool(collection_roles) or user_is_owner, "collection") + + +def _handle_update_action( + user_data: Optional[UserData], + entity: Collection, + private: bool, + official_collection: bool, + user_is_owner: bool, + collection_roles: list[ContributionRole], + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle UPDATE action permission check for Collection entities. + + Only owners, Collection admins/editors, and system admins can update Collections. + + Args: + user_data: The user's authentication data. + entity: The Collection entity being updated. + private: Whether the Collection is private. 
+ official_collection: Whether the Collection is an official collection. + user_is_owner: Whether the user owns the Collection. + collection_roles: The user's roles in this Collection (admin/editor/viewer). + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow update access under the following conditions: + # The owner may update the collection. + if user_is_owner: + return PermissionResponse(True) + # Collection admins and editors may update the collection. + if roles_permitted(collection_roles, [ContributionRole.admin, ContributionRole.editor]): + return PermissionResponse(True) + # Users with these specific roles may update the collection. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, bool(collection_roles) or user_is_owner, "collection") + + +def _handle_delete_action( + user_data: Optional[UserData], + entity: Collection, + private: bool, + official_collection: bool, + user_is_owner: bool, + collection_roles: list[ContributionRole], + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle DELETE action permission check for Collection entities. + + System admins can delete any Collection. Owners and Collection admins can only + delete unpublished Collections. + + Args: + user_data: The user's authentication data. + entity: The Collection entity being deleted. + private: Whether the Collection is private. + official_collection: Whether the Collection is official. + user_is_owner: Whether the user owns the Collection. + collection_roles: The user's roles in this Collection (admin/editor/viewer). + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow delete access under the following conditions: + # System admins may delete any collection. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + # Other users may only delete non-official collections. + if not official_collection: + # Owners may delete a collection only if it is still private. + # Collection admins/editors/viewers may not delete collections. + if user_is_owner and private: + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, bool(collection_roles) or user_is_owner, "collection") + + +def _handle_publish_action( + user_data: Optional[UserData], + entity: Collection, + private: bool, + official_collection: bool, + user_is_owner: bool, + collection_roles: list[ContributionRole], + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle PUBLISH action permission check for Collection entities. + + Only owners, Collection admins, and system admins can publish Collections. + + Args: + user_data: The user's authentication data. + entity: The Collection entity being published. + private: Whether the Collection is private. + official_collection: Whether the Collection is official. + user_is_owner: Whether the user owns the Collection. + collection_roles: The user's roles in this Collection (admin/editor/viewer). + active_roles: List of the user's active roles. + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow publish access under the following conditions: + # The owner may publish a collection. 
+ if user_is_owner: + return PermissionResponse(True) + # Collection admins may publish the collection. + if roles_permitted(collection_roles, [ContributionRole.admin]): + return PermissionResponse(True) + # Users with these specific roles may publish the collection. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, bool(collection_roles) or user_is_owner, "collection") + + +def _handle_add_experiment_action( + user_data: Optional[UserData], + entity: Collection, + private: bool, + official_collection: bool, + user_is_owner: bool, + collection_roles: list[ContributionRole], + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle ADD_EXPERIMENT action permission check for Collection entities. + + Only owners, Collection admins/editors, and system admins can add experiment sets + to private Collections. Any authenticated user can add to public Collections. + + Args: + user_data: The user's authentication data. + entity: The Collection entity to add an experiment to. + private: Whether the Collection is private. + official_collection: Whether the Collection is official. + user_is_owner: Whether the user owns the Collection. + collection_roles: The user's roles in this Collection (admin/editor/viewer). + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow add experiment add access under the following conditions: + # The owner may add an experiment to a private collection. + if user_is_owner: + return PermissionResponse(True) + # Collection admins/editors may add an experiment to the collection. + if roles_permitted(collection_roles, [ContributionRole.admin, ContributionRole.editor]): + return PermissionResponse(True) + # Users with these specific roles may add an experiment to the collection. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, bool(collection_roles) or user_is_owner, "collection") + + +def _handle_add_score_set_action( + user_data: Optional[UserData], + entity: Collection, + private: bool, + official_collection: bool, + user_is_owner: bool, + collection_roles: list[ContributionRole], + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle ADD_SCORE_SET action permission check for Collection entities. + + Only owners, Collection admins/editors, and system admins can add score sets + to private Collections. Any authenticated user can add to public Collections. + + Args: + user_data: The user's authentication data. + entity: The Collection entity to add a score set to. + private: Whether the Collection is private. + official_collection: Whether the Collection is official. + user_is_owner: Whether the user owns the Collection. + collection_roles: The user's roles in this Collection (admin/editor/viewer). + active_roles: List of the user's active roles. + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow add score set access under the following conditions: + # The owner may add a score set to a private collection. + if user_is_owner: + return PermissionResponse(True) + # Collection admins/editors may add a score set to the collection. 
+ if roles_permitted(collection_roles, [ContributionRole.admin, ContributionRole.editor]): + return PermissionResponse(True) + # Users with these specific roles may add a score set to the collection. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, bool(collection_roles) or user_is_owner, "collection") + + +def _handle_add_role_action( + user_data: Optional[UserData], + entity: Collection, + private: bool, + official_collection: bool, + user_is_owner: bool, + collection_roles: list[ContributionRole], + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle ADD_ROLE action permission check for Collection entities. + + Only owners and Collection admins can add roles to Collections. + + Args: + user_data: The user's authentication data. + entity: The Collection entity to add a role to. + private: Whether the Collection is private. + official_collection: Whether the Collection is official. + user_is_owner: Whether the user owns the Collection. + collection_roles: The user's roles in this Collection (admin/editor/viewer). + active_roles: List of the user's active roles. + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow add role access under the following conditions: + # The owner may add a role. + if user_is_owner: + return PermissionResponse(True) + # Collection admins may add a role to the collection. + if roles_permitted(collection_roles, [ContributionRole.admin]): + return PermissionResponse(True) + # Users with these specific roles may add a role to the collection. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, bool(collection_roles) or user_is_owner, "collection") + + +def _handle_add_badge_action( + user_data: Optional[UserData], + entity: Collection, + private: bool, + official_collection: bool, + user_is_owner: bool, + collection_roles: list[ContributionRole], + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle ADD_BADGE action permission check for Collection entities. + + Only system admins can add badges to Collections. + + Args: + user_data: The user's authentication data. + entity: The Collection entity to add a badge to. + private: Whether the Collection is private. + official_collection: Whether the Collection is official. + user_is_owner: Whether the user owns the Collection. + collection_roles: The user's roles in this Collection (admin/editor/viewer). + active_roles: List of the user's active roles. + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow add badge access under the following conditions: + # Users with these specific roles may add a badge to the collection. 
+ if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, bool(collection_roles) or user_is_owner, "collection") diff --git a/src/mavedb/lib/permissions/core.py b/src/mavedb/lib/permissions/core.py new file mode 100644 index 00000000..df6facc8 --- /dev/null +++ b/src/mavedb/lib/permissions/core.py @@ -0,0 +1,114 @@ +from typing import Any, Callable, Optional + +from mavedb.lib.logging.context import save_to_logging_context +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.exceptions import PermissionException +from mavedb.lib.permissions.models import PermissionResponse +from mavedb.lib.types.authentication import UserData +from mavedb.lib.types.permissions import EntityType +from mavedb.models.collection import Collection +from mavedb.models.experiment import Experiment +from mavedb.models.experiment_set import ExperimentSet +from mavedb.models.score_calibration import ScoreCalibration +from mavedb.models.score_set import ScoreSet +from mavedb.models.user import User + +# Import entity-specific permission modules +from . import ( + collection, + experiment, + experiment_set, + score_calibration, + score_set, + user, +) + + +def has_permission(user_data: Optional[UserData], entity: EntityType, action: Action) -> PermissionResponse: + """ + Main dispatcher function for permission checks across all entity types. + + This function routes permission checks to the appropriate entity-specific + module based on the type of the entity provided. Each entity type has + its own permission logic and supported actions. + + Args: + user_data: The user's authentication data and roles. None for anonymous users. + entity: The entity to check permissions for. Must be one of the supported types. + action: The action to be performed on the entity. + + Returns: + PermissionResponse: Contains permission result, HTTP status code, and message. + + Raises: + NotImplementedError: If the entity type is not supported. + + Example: + >>> from mavedb.lib.permissions.core import has_permission + >>> from mavedb.lib.permissions.actions import Action + >>> result = has_permission(user_data, score_set, Action.READ) + >>> if result.permitted: + ... # User has permission + ... pass + + Note: + This is the main entry point for all permission checks in the application. + Each entity type delegates to its own module for specific permission logic. + """ + # Dictionary mapping entity types to their corresponding permission modules + entity_handlers: dict[type, Callable[[Optional[UserData], Any, Action], PermissionResponse]] = { + Collection: collection.has_permission, + Experiment: experiment.has_permission, + ExperimentSet: experiment_set.has_permission, + ScoreCalibration: score_calibration.has_permission, + ScoreSet: score_set.has_permission, + User: user.has_permission, + } + + entity_type = type(entity) + + if entity_type not in entity_handlers: + supported_types = ", ".join(cls.__name__ for cls in entity_handlers.keys()) + raise NotImplementedError( + f"Permission checks are not implemented for entity type '{entity_type.__name__}'. " + f"Supported entity types: {supported_types}" + ) + + handler = entity_handlers[entity_type] + return handler(user_data, entity, action) + + +def assert_permission(user_data: Optional[UserData], entity: EntityType, action: Action) -> PermissionResponse: + """ + Assert that a user has permission to perform an action on an entity. 
+ + This function checks permissions and raises an exception if the user lacks + the necessary permissions. It's a convenience wrapper around has_permission + for cases where you want to fail fast on permission denials. + + Args: + user_data: The user's authentication data and roles. None for anonymous users. + entity: The entity to check permissions for. + action: The action to be performed on the entity. + + Returns: + PermissionResponse: The permission result if access is granted. + + Raises: + PermissionException: If the user lacks sufficient permissions. + + Example: + >>> from mavedb.lib.permissions.core import assert_permission + >>> from mavedb.lib.permissions.actions import Action + >>> # This will raise PermissionException if user can't update + >>> assert_permission(user_data, score_set, Action.UPDATE) + """ + save_to_logging_context({"permission_boundary": action.name}) + permission = has_permission(user_data, entity, action) + + if not permission.permitted: + http_code = permission.http_code if permission.http_code is not None else 403 + message = permission.message if permission.message is not None else "Permission denied" + raise PermissionException(http_code=http_code, message=message) + + return permission diff --git a/src/mavedb/lib/permissions/exceptions.py b/src/mavedb/lib/permissions/exceptions.py new file mode 100644 index 00000000..d3ebf87e --- /dev/null +++ b/src/mavedb/lib/permissions/exceptions.py @@ -0,0 +1,4 @@ +class PermissionException(Exception): + def __init__(self, http_code: int, message: str): + self.http_code = http_code + self.message = message diff --git a/src/mavedb/lib/permissions/experiment.py b/src/mavedb/lib/permissions/experiment.py new file mode 100644 index 00000000..91f8f617 --- /dev/null +++ b/src/mavedb/lib/permissions/experiment.py @@ -0,0 +1,221 @@ +from typing import Optional + +from mavedb.lib.logging.context import save_to_logging_context +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.models import PermissionResponse +from mavedb.lib.permissions.utils import deny_action_for_entity, roles_permitted +from mavedb.lib.types.authentication import UserData +from mavedb.models.enums.user_role import UserRole +from mavedb.models.experiment import Experiment + + +def has_permission(user_data: Optional[UserData], entity: Experiment, action: Action) -> PermissionResponse: + """ + Check if a user has permission to perform an action on an Experiment entity. + + This function evaluates user permissions based on ownership, contributor status, + and user roles. It handles both private and public Experiments with different + access control rules. + + Args: + user_data: The user's authentication data and roles. None for anonymous users. + entity: The Experiment entity to check permissions for. + action: The action to be performed (READ, UPDATE, DELETE, ADD_SCORE_SET). + + Returns: + PermissionResponse: Contains permission result, HTTP status code, and message. + + Raises: + ValueError: If the entity's private attribute is not set. + NotImplementedError: If the action is not supported for Experiment entities. 
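+
+    Example:
+        A minimal usage sketch (illustrative only, not part of the original module); it assumes
+        an Experiment already loaded from the database and UserData for the requesting user
+        (or None for an anonymous request).
+
+        >>> response = has_permission(user_data, experiment, Action.READ)
+        >>> if response.permitted:
+        ...     pass  # the caller may return the experiment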
+ """ + if entity.private is None: + raise ValueError("Experiment entity must have 'private' attribute set for permission checks.") + + user_is_owner = False + user_is_contributor = False + active_roles = [] + if user_data is not None: + user_is_owner = entity.created_by_id == user_data.user.id + user_is_contributor = user_data.user.username in [c.orcid_id for c in entity.contributors] + active_roles = user_data.active_roles + + save_to_logging_context( + { + "resource_is_private": entity.private, + "user_is_owner": user_is_owner, + "user_is_contributor": user_is_contributor, + } + ) + + handlers = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.DELETE: _handle_delete_action, + Action.ADD_SCORE_SET: _handle_add_score_set_action, + } + + if action not in handlers: + supported_actions = ", ".join(a.value for a in handlers.keys()) + raise NotImplementedError( + f"Action '{action.value}' is not supported for experiment entities. " + f"Supported actions: {supported_actions}" + ) + + return handlers[action]( + user_data, + entity, + entity.private, + user_is_owner, + user_is_contributor, + active_roles, + ) + + +def _handle_read_action( + user_data: Optional[UserData], + entity: Experiment, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle READ action permission check for Experiment entities. + + Public Experiments are readable by anyone. Private Experiments are only readable + by owners, contributors, admins, and mappers. + + Args: + user_data: The user's authentication data. + entity: The Experiment entity being accessed. + private: Whether the Experiment is private. + user_is_owner: Whether the user owns the Experiment. + user_is_contributor: Whether the user is a contributor to the Experiment. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow read access under the following conditions: + # Any user may read a non-private experiment. + if not private: + return PermissionResponse(True) + # The owner or contributors may read a private experiment. + if user_is_owner or user_is_contributor: + return PermissionResponse(True) + # Users with these specific roles may read a private experiment. + if roles_permitted(active_roles, [UserRole.admin, UserRole.mapper]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "experiment") + + +def _handle_update_action( + user_data: Optional[UserData], + entity: Experiment, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle UPDATE action permission check for Experiment entities. + + Only owners, contributors, and admins can update Experiments. + + Args: + user_data: The user's authentication data. + entity: The Experiment entity being updated. + private: Whether the Experiment is private. + user_is_owner: Whether the user owns the Experiment. + user_is_contributor: Whether the user is a contributor to the Experiment. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow update access under the following conditions: + # The owner or contributors may update the experiment. 
+ if user_is_owner or user_is_contributor: + return PermissionResponse(True) + # Users with these specific roles may update the experiment. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "experiment") + + +def _handle_delete_action( + user_data: Optional[UserData], + entity: Experiment, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle DELETE action permission check for Experiment entities. + + Admins can delete any Experiment. Owners can only delete unpublished Experiments. + Contributors cannot delete Experiments. + + Args: + user_data: The user's authentication data. + entity: The Experiment entity being deleted. + private: Whether the Experiment is private. + user_is_owner: Whether the user owns the Experiment. + user_is_contributor: Whether the user is a contributor to the Experiment. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow delete access under the following conditions: + # Admins may delete any experiment. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + # Owners may delete an experiment only if it is still private. Contributors may not delete an experiment. + if user_is_owner and private: + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "experiment") + + +def _handle_add_score_set_action( + user_data: Optional[UserData], + entity: Experiment, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle ADD_SCORE_SET action permission check for Experiment entities. + + Only permitted users can add a score set to a private experiment. + Any authenticated user can add a score set to a public experiment. + + Args: + user_data: The user's authentication data. + entity: The Experiment entity to add a score set to. + private: Whether the Experiment is private. + user_is_owner: Whether the user owns the Experiment. + user_is_contributor: Whether the user is a contributor to the Experiment. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow add score set access under the following conditions: + # Owners or contributors may add a score set. + if user_is_owner or user_is_contributor: + return PermissionResponse(True) + # Users with these specific roles may update the experiment. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + # Any authenticated user may add a score set to a non-private experiment. 
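+    # (Anonymous requests fall through to the generic denial handling below.)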
+ if not private and user_data is not None: + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "experiment") diff --git a/src/mavedb/lib/permissions/experiment_set.py b/src/mavedb/lib/permissions/experiment_set.py new file mode 100644 index 00000000..d8cbecc3 --- /dev/null +++ b/src/mavedb/lib/permissions/experiment_set.py @@ -0,0 +1,218 @@ +from typing import Optional + +from mavedb.lib.logging.context import save_to_logging_context +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.models import PermissionResponse +from mavedb.lib.permissions.utils import deny_action_for_entity, roles_permitted +from mavedb.lib.types.authentication import UserData +from mavedb.models.enums.user_role import UserRole +from mavedb.models.experiment_set import ExperimentSet + + +def has_permission(user_data: Optional[UserData], entity: ExperimentSet, action: Action) -> PermissionResponse: + """ + Check if a user has permission to perform an action on an ExperimentSet entity. + + This function evaluates user permissions based on ownership, contributor status, + and user roles. It handles both private and public ExperimentSets with different + access control rules. + + Args: + user_data: The user's authentication data and roles. None for anonymous users. + entity: The ExperimentSet entity to check permissions for. + action: The action to be performed (READ, UPDATE, DELETE, ADD_EXPERIMENT). + + Returns: + PermissionResponse: Contains permission result, HTTP status code, and message. + + Raises: + ValueError: If the entity's private attribute is not set. + NotImplementedError: If the action is not supported for ExperimentSet entities. + """ + if entity.private is None: + raise ValueError("ExperimentSet entity must have 'private' attribute set for permission checks.") + + user_is_owner = False + user_is_contributor = False + active_roles = [] + if user_data is not None: + user_is_owner = entity.created_by_id == user_data.user.id + user_is_contributor = user_data.user.username in [c.orcid_id for c in entity.contributors] + active_roles = user_data.active_roles + + save_to_logging_context( + { + "resource_is_private": entity.private, + "user_is_owner": user_is_owner, + "user_is_contributor": user_is_contributor, + } + ) + + handlers = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.DELETE: _handle_delete_action, + Action.ADD_EXPERIMENT: _handle_add_experiment_action, + } + + if action not in handlers: + supported_actions = ", ".join(a.value for a in handlers.keys()) + raise NotImplementedError( + f"Action '{action.value}' is not supported for experiment set entities. " + f"Supported actions: {supported_actions}" + ) + + return handlers[action]( + user_data, + entity, + entity.private, + user_is_owner, + user_is_contributor, + active_roles, + ) + + +def _handle_read_action( + user_data: Optional[UserData], + entity: ExperimentSet, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle READ action permission check for ExperimentSet entities. + + Public ExperimentSets are readable by anyone. Private ExperimentSets are only readable + by owners, contributors, admins, and mappers. + + Args: + user_data: The user's authentication data. + entity: The ExperimentSet entity being accessed. + private: Whether the ExperimentSet is private. 
+ user_is_owner: Whether the user owns the ExperimentSet. + user_is_contributor: Whether the user is a contributor to the ExperimentSet. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow read access under the following conditions: + # Any user may read a non-private experiment set. + if not private: + return PermissionResponse(True) + # The owner or contributors may read a private experiment set. + if user_is_owner or user_is_contributor: + return PermissionResponse(True) + # Users with these specific roles may read a private experiment set. + if roles_permitted(active_roles, [UserRole.admin, UserRole.mapper]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "experiment set") + + +def _handle_update_action( + user_data: Optional[UserData], + entity: ExperimentSet, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle UPDATE action permission check for ExperimentSet entities. + + Only owners, contributors, and admins can update ExperimentSets. + + Args: + user_data: The user's authentication data. + entity: The ExperimentSet entity being updated. + private: Whether the ExperimentSet is private. + user_is_owner: Whether the user owns the ExperimentSet. + user_is_contributor: Whether the user is a contributor to the ExperimentSet. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow update access under the following conditions: + # The owner or contributors may update the experiment set. + if user_is_owner or user_is_contributor: + return PermissionResponse(True) + # Users with these specific roles may update the experiment set. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "experiment set") + + +def _handle_delete_action( + user_data: Optional[UserData], + entity: ExperimentSet, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle DELETE action permission check for ExperimentSet entities. + + Admins can delete any ExperimentSet. Owners can only delete unpublished ExperimentSets. + Contributors cannot delete ExperimentSets. + + Args: + user_data: The user's authentication data. + entity: The ExperimentSet entity being deleted. + private: Whether the ExperimentSet is private. + user_is_owner: Whether the user owns the ExperimentSet. + user_is_contributor: Whether the user is a contributor to the ExperimentSet. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow delete access under the following conditions: + # Admins may delete any experiment set. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + # Owners may delete an experiment set only if it is still private. Contributors may not delete an experiment set. 
+ if user_is_owner and private: + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "experiment set") + + +def _handle_add_experiment_action( + user_data: Optional[UserData], + entity: ExperimentSet, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle ADD_EXPERIMENT action permission check for ExperimentSet entities. + + Only permitted users can add an experiment to a private experiment set. + Any authenticated user can add an experiment to a public experiment set. + + Args: + user_data: The user's authentication data. + entity: The ExperimentSet entity to add an experiment to. + private: Whether the ExperimentSet is private. + user_is_owner: Whether the user owns the ExperimentSet. + user_is_contributor: Whether the user is a contributor to the ExperimentSet. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow add experiment access under the following conditions: + # Owners or contributors may add an experiment. + if user_is_owner or user_is_contributor: + return PermissionResponse(True) + # Users with these specific roles may add an experiment to the experiment set. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "experiment set") diff --git a/src/mavedb/lib/permissions/models.py b/src/mavedb/lib/permissions/models.py new file mode 100644 index 00000000..0145fc08 --- /dev/null +++ b/src/mavedb/lib/permissions/models.py @@ -0,0 +1,25 @@ +import logging +from typing import Optional + +from mavedb.lib.logging.context import logging_context, save_to_logging_context + +logger = logging.getLogger(__name__) + + +class PermissionResponse: + def __init__(self, permitted: bool, http_code: int = 403, message: Optional[str] = None): + self.permitted = permitted + self.http_code = http_code if not permitted else None + self.message = message if not permitted else None + + save_to_logging_context({"permission_message": self.message, "access_permitted": self.permitted}) + if self.permitted: + logger.debug( + msg="Access to the requested resource is permitted.", + extra=logging_context(), + ) + else: + logger.debug( + msg="Access to the requested resource is not permitted.", + extra=logging_context(), + ) diff --git a/src/mavedb/lib/permissions/score_calibration.py b/src/mavedb/lib/permissions/score_calibration.py new file mode 100644 index 00000000..4c068c7c --- /dev/null +++ b/src/mavedb/lib/permissions/score_calibration.py @@ -0,0 +1,277 @@ +from typing import Optional + +from mavedb.lib.logging.context import save_to_logging_context +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.models import PermissionResponse +from mavedb.lib.permissions.utils import deny_action_for_entity, roles_permitted +from mavedb.lib.types.authentication import UserData +from mavedb.models.enums.user_role import UserRole +from mavedb.models.score_calibration import ScoreCalibration + + +def has_permission(user_data: Optional[UserData], entity: ScoreCalibration, action: Action) -> PermissionResponse: + """ + Check if a user has permission to perform an action on a ScoreCalibration entity. 
+ + This function evaluates user permissions for ScoreCalibration entities, which are + typically administrative objects that require special permissions to modify. + ScoreCalibrations don't have traditional ownership but are tied to ScoreSets. + + Args: + user_data: The user's authentication data and roles. None for anonymous users. + entity: The ScoreCalibration entity to check permissions for. + action: The action to be performed (READ, UPDATE, DELETE, CREATE). + + Returns: + PermissionResponse: Contains permission result, HTTP status code, and message. + + Raises: + NotImplementedError: If the action is not supported for ScoreCalibration entities. + """ + if entity.private is None: + raise ValueError("ScoreCalibration entity must have 'private' attribute set for permission checks.") + + user_is_owner = False + user_is_contributor_to_score_set = False + active_roles = [] + if user_data is not None: + user_is_owner = entity.created_by_id == user_data.user.id + # Contributor status is determined by matching the user's username (ORCID ID) against the contributors' ORCID IDs, + # as well as by matching the user's ID against the created_by_id and modified_by_id fields of the ScoreSet. + user_is_contributor_to_score_set = ( + user_data.user.username in [c.orcid_id for c in entity.score_set.contributors] + or user_data.user.id == entity.score_set.created_by_id + or user_data.user.id == entity.score_set.modified_by_id + ) + active_roles = user_data.active_roles + + save_to_logging_context( + { + "user_is_owner": user_is_owner, + "user_is_contributor_to_score_set": user_is_contributor_to_score_set, + "score_calibration_id": entity.id, + } + ) + + handlers = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.DELETE: _handle_delete_action, + Action.PUBLISH: _handle_publish_action, + Action.CHANGE_RANK: _handle_change_rank_action, + } + + if action not in handlers: + supported_actions = ", ".join(a.value for a in handlers.keys()) + raise NotImplementedError( + f"Action '{action.value}' is not supported for score calibration entities. " + f"Supported actions: {supported_actions}" + ) + + return handlers[action]( + user_data, + entity, + user_is_owner, + user_is_contributor_to_score_set, + entity.private, + active_roles, + ) + + +def _handle_read_action( + user_data: Optional[UserData], + entity: ScoreCalibration, + user_is_owner: bool, + user_is_contributor_to_score_set: bool, + private: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle READ action permission check for ScoreCalibration entities. + + ScoreCalibrations are generally readable by anyone who can access the + associated ScoreSet, as they provide important contextual information + about the score data. + + Args: + user_data: The user's authentication data. + entity: The ScoreCalibration entity being accessed. + user_is_owner: Whether the user created the ScoreCalibration. + user_is_contributor_to_score_set: Whether the user is a contributor to the associated ScoreSet. + private: Whether the ScoreCalibration is private. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow read access under the following conditions: + # Any user may read a ScoreCalibration if it is not private. + if not private: + return PermissionResponse(True) + # Owners of the ScoreCalibration may read it. 
+    if user_is_owner:
+        return PermissionResponse(True)
+    # If the calibration is investigator provided, contributors to the ScoreSet may read it.
+    if entity.investigator_provided and user_is_contributor_to_score_set:
+        return PermissionResponse(True)
+    # System admins may read any ScoreCalibration.
+    if roles_permitted(active_roles, [UserRole.admin]):
+        return PermissionResponse(True)
+
+    user_may_view_private = user_is_owner or (entity.investigator_provided and user_is_contributor_to_score_set)
+    return deny_action_for_entity(entity, private, user_data, user_may_view_private, "score calibration")
+
+
+def _handle_update_action(
+    user_data: Optional[UserData],
+    entity: ScoreCalibration,
+    user_is_owner: bool,
+    user_is_contributor_to_score_set: bool,
+    private: bool,
+    active_roles: list[UserRole],
+) -> PermissionResponse:
+    """
+    Handle UPDATE action permission check for ScoreCalibration entities.
+
+    Updating ScoreCalibrations is typically restricted to administrators
+    or the original creators, as changes can significantly impact
+    the interpretation of score data.
+
+    Args:
+        user_data: The user's authentication data.
+        entity: The ScoreCalibration entity being accessed.
+        user_is_owner: Whether the user created the ScoreCalibration.
+        user_is_contributor_to_score_set: Whether the user is a contributor to the associated ScoreSet.
+        private: Whether the ScoreCalibration is private.
+        active_roles: List of the user's active roles.
+
+    Returns:
+        PermissionResponse: Permission result with appropriate HTTP status.
+    """
+    ## Allow update access under the following conditions:
+    # System admins may update any ScoreCalibration.
+    if roles_permitted(active_roles, [UserRole.admin]):
+        return PermissionResponse(True)
+    # TODO#549: Allow editing of certain fields if the calibration is published.
+    # For now, published calibrations cannot be updated.
+    if entity.private:
+        # Owners may update their own ScoreCalibration if it is not published.
+        if user_is_owner:
+            return PermissionResponse(True)
+        # If the calibration is investigator provided, contributors to the ScoreSet may update it if not published.
+        if entity.investigator_provided and user_is_contributor_to_score_set:
+            return PermissionResponse(True)
+
+    user_may_view_private = user_is_owner or (entity.investigator_provided and user_is_contributor_to_score_set)
+    return deny_action_for_entity(entity, private, user_data, user_may_view_private, "score calibration")
+
+
+def _handle_delete_action(
+    user_data: Optional[UserData],
+    entity: ScoreCalibration,
+    user_is_owner: bool,
+    user_is_contributor_to_score_set: bool,
+    private: bool,
+    active_roles: list[UserRole],
+) -> PermissionResponse:
+    """
+    Handle DELETE action permission check for ScoreCalibration entities.
+
+    Deleting ScoreCalibrations is a sensitive operation typically reserved
+    for administrators or the original creators, as it can affect data integrity.
+
+    Args:
+        user_data: The user's authentication data.
+        entity: The ScoreCalibration entity being accessed.
+        user_is_owner: Whether the user created the ScoreCalibration.
+        user_is_contributor_to_score_set: Whether the user is a contributor to the associated ScoreSet.
+        private: Whether the ScoreCalibration is private.
+        active_roles: List of the user's active roles.
+
+    Returns:
+        PermissionResponse: Permission result with appropriate HTTP status.
+    """
+    ## Allow delete access under the following conditions:
+    # System admins may delete any ScoreCalibration.
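The UPDATE branch above combines several flags, so here is an editor's distillation of that rule as a pure function plus the expected outcomes. The boolean parameters mirror the flags computed in `has_permission`; this is an illustration only, not MaveDB code.

```python
def may_update_calibration(
    is_admin: bool,
    is_owner: bool,
    investigator_provided: bool,
    is_contributor_to_score_set: bool,
    private: bool,
) -> bool:
    # Admins may always update.
    if is_admin:
        return True
    # Published (non-private) calibrations are frozen for non-admins (TODO#549).
    if not private:
        return False
    # Private calibrations: the creator, or a ScoreSet contributor when the
    # calibration is investigator provided.
    return is_owner or (investigator_provided and is_contributor_to_score_set)


# Expected outcomes under the rules documented above:
assert may_update_calibration(False, True, False, False, private=True) is True
assert may_update_calibration(False, True, False, False, private=False) is False
assert may_update_calibration(False, False, True, True, private=True) is True
assert may_update_calibration(True, False, False, False, private=False) is True
```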
+ if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + # Owners may delete their own ScoreCalibration if it is still private. Contributors may not delete ScoreCalibrations. + if user_is_owner and private: + return PermissionResponse(True) + + user_may_view_private = user_is_owner or (entity.investigator_provided and user_is_contributor_to_score_set) + return deny_action_for_entity(entity, private, user_data, user_may_view_private, "score calibration") + + +def _handle_publish_action( + user_data: Optional[UserData], + entity: ScoreCalibration, + user_is_owner: bool, + user_is_contributor_to_score_set: bool, + private: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle PUBLISH action permission check for ScoreCalibration entities. + + Publishing ScoreCalibrations is typically restricted to administrators + or the original creators, as it signifies that the calibration is + finalized and ready for public use. + + Args: + user_data: The user's authentication data. + entity: The ScoreCalibration entity being accessed. + user_is_owner: Whether the user created the ScoreCalibration. + user_is_contributor_to_score_set: Whether the user is a contributor to the associated ScoreSet. + private: Whether the ScoreCalibration is private. + active_roles: List of the user's active roles. + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow publish access under the following conditions: + # System admins may publish any ScoreCalibration. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + # Owners may publish their own ScoreCalibration. + if user_is_owner: + return PermissionResponse(True) + + user_may_view_private = user_is_owner or (entity.investigator_provided and user_is_contributor_to_score_set) + return deny_action_for_entity(entity, private, user_data, user_may_view_private, "score calibration") + + +def _handle_change_rank_action( + user_data: Optional[UserData], + entity: ScoreCalibration, + user_is_owner: bool, + user_is_contributor_to_score_set: bool, + private: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle CHANGE_RANK action permission check for ScoreCalibration entities. + + Changing the rank of ScoreCalibrations is typically restricted to administrators + or the original creators, as it affects the order in which calibrations are applied. + + Args: + user_data: The user's authentication data. + entity: The ScoreCalibration entity being accessed. + user_is_owner: Whether the user created the ScoreCalibration. + user_is_contributor_to_score_set: Whether the user is a contributor to the associated ScoreSet. + private: Whether the ScoreCalibration is private. + active_roles: List of the user's active roles. + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow change rank access under the following conditions: + # System admins may change the rank of any ScoreCalibration. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + # Owners may change the rank of their own ScoreCalibration. + if user_is_owner: + return PermissionResponse(True) + # If the calibration is investigator provided, contributors to the ScoreSet may change its rank. 
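As an illustration of how the PUBLISH rule above could be exercised, here is a pytest-style sketch. `may_publish` and the parametrized cases are hypothetical test scaffolding written for this example, not part of this diff.

```python
import pytest


def may_publish(is_admin: bool, is_owner: bool) -> bool:
    # Mirrors _handle_publish_action: admins or the calibration's creator only.
    return is_admin or is_owner


@pytest.mark.parametrize(
    "is_admin,is_owner,expected",
    [
        (True, False, True),    # admins may publish any calibration
        (False, True, True),    # owners may publish their own calibration
        (False, False, False),  # everyone else is denied
    ],
)
def test_publish_rule(is_admin, is_owner, expected):
    assert may_publish(is_admin, is_owner) is expected
```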
+ if entity.investigator_provided and user_is_contributor_to_score_set: + return PermissionResponse(True) + + user_may_view_private = user_is_owner or (entity.investigator_provided and user_is_contributor_to_score_set) + return deny_action_for_entity(entity, private, user_data, user_may_view_private, "score calibration") diff --git a/src/mavedb/lib/permissions/score_set.py b/src/mavedb/lib/permissions/score_set.py new file mode 100644 index 00000000..6e580669 --- /dev/null +++ b/src/mavedb/lib/permissions/score_set.py @@ -0,0 +1,255 @@ +from typing import Optional + +from mavedb.lib.logging.context import save_to_logging_context +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.models import PermissionResponse +from mavedb.lib.permissions.utils import deny_action_for_entity, roles_permitted +from mavedb.lib.types.authentication import UserData +from mavedb.models.enums.user_role import UserRole +from mavedb.models.score_set import ScoreSet + + +def has_permission(user_data: Optional[UserData], entity: ScoreSet, action: Action) -> PermissionResponse: + """ + Check if a user has permission to perform an action on a ScoreSet entity. + + This function evaluates user permissions based on ownership, contributor status, + and user roles. It handles both private and public ScoreSets with different + access control rules. + + Args: + user_data: The user's authentication data and roles. None for anonymous users. + entity: The ScoreSet entity to check permissions for. + action: The action to be performed (READ, UPDATE, DELETE, PUBLISH, SET_SCORES). + + Returns: + PermissionResponse: Contains permission result, HTTP status code, and message. + + Raises: + ValueError: If the entity's private attribute is not set. + NotImplementedError: If the action is not supported for ScoreSet entities. + """ + if entity.private is None: + raise ValueError("ScoreSet entity must have 'private' attribute set for permission checks.") + + user_is_owner = False + user_is_contributor = False + active_roles = [] + if user_data is not None: + user_is_owner = entity.created_by_id == user_data.user.id + user_is_contributor = user_data.user.username in [c.orcid_id for c in entity.contributors] + active_roles = user_data.active_roles + + save_to_logging_context( + { + "resource_is_private": entity.private, + "user_is_owner": user_is_owner, + "user_is_contributor": user_is_contributor, + } + ) + + handlers = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.DELETE: _handle_delete_action, + Action.PUBLISH: _handle_publish_action, + Action.SET_SCORES: _handle_set_scores_action, + } + + if action not in handlers: + supported_actions = ", ".join(a.value for a in handlers.keys()) + raise NotImplementedError( + f"Action '{action.value}' is not supported for score set entities. " + f"Supported actions: {supported_actions}" + ) + + return handlers[action]( + user_data, + entity, + entity.private, + user_is_owner, + user_is_contributor, + active_roles, + ) + + +def _handle_read_action( + user_data: Optional[UserData], + entity: ScoreSet, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle READ action permission check for ScoreSet entities. + + Public ScoreSets are readable by anyone. Private ScoreSets are only readable + by owners, contributors, admins, and mappers. + + Args: + user_data: The user's authentication data. + entity: The ScoreSet entity being accessed. 
+ private: Whether the ScoreSet is private. + user_is_owner: Whether the user owns the ScoreSet. + user_is_contributor: Whether the user is a contributor to the ScoreSet. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow read access under the following conditions: + # Any user may read a non-private score set. + if not private: + return PermissionResponse(True) + # The owner or contributors may read a private score set. + if user_is_owner or user_is_contributor: + return PermissionResponse(True) + # Users with these specific roles may read a private score set. + if roles_permitted(active_roles, [UserRole.admin, UserRole.mapper]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "score set") + + +def _handle_update_action( + user_data: Optional[UserData], + entity: ScoreSet, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle UPDATE action permission check for ScoreSet entities. + + Only owners, contributors, and admins can update ScoreSets. + + Args: + user_data: The user's authentication data. + entity: The ScoreSet entity being updated. + private: Whether the ScoreSet is private. + user_is_owner: Whether the user owns the ScoreSet. + user_is_contributor: Whether the user is a contributor to the ScoreSet. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow update access under the following conditions: + # The owner or contributors may update the score set. + if user_is_owner or user_is_contributor: + return PermissionResponse(True) + # Users with these specific roles may update the score set. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "score set") + + +def _handle_delete_action( + user_data: Optional[UserData], + entity: ScoreSet, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle DELETE action permission check for ScoreSet entities. + + Admins can delete any ScoreSet. Owners can only delete unpublished ScoreSets. + Contributors cannot delete ScoreSets. + + Args: + user_data: The user's authentication data. + entity: The ScoreSet entity being deleted. + private: Whether the ScoreSet is private. + user_is_owner: Whether the user owns the ScoreSet. + user_is_contributor: Whether the user is a contributor to the ScoreSet. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow delete access under the following conditions: + # Admins may delete any score set. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + # Owners may delete a score set only if it is still private. Contributors may not delete a score set. 
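The routers that consume these checks are not part of this diff, so the following is a hypothetical sketch of how a FastAPI endpoint might translate a `PermissionResponse` denial into an HTTP error using the `http_code` and `message` it carries.

```python
from fastapi import HTTPException


def assert_permission(response) -> None:
    """Raise an HTTPException mirroring a failed PermissionResponse; otherwise return."""
    if not response.permitted:
        raise HTTPException(status_code=response.http_code or 403, detail=response.message)


# Illustrative usage (names as defined elsewhere in this diff):
#   response = has_permission(user_data, score_set, Action.DELETE)
#   assert_permission(response)
#   ...proceed with the deletion...
```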
+ if user_is_owner and private: + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "score set") + + +def _handle_publish_action( + user_data: Optional[UserData], + entity: ScoreSet, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle PUBLISH action permission check for ScoreSet entities. + + Owners, and admins can publish private ScoreSets to make them + publicly accessible. + + Args: + user_data: The user's authentication data. + entity: The ScoreSet entity being published. + private: Whether the ScoreSet is private. + user_is_owner: Whether the user owns the ScoreSet. + user_is_contributor: Whether the user is a contributor to the ScoreSet. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow publish access under the following conditions: + # The owner may publish the score set. + if user_is_owner: + return PermissionResponse(True) + # Users with these specific roles may publish the score set. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "score set") + + +def _handle_set_scores_action( + user_data: Optional[UserData], + entity: ScoreSet, + private: bool, + user_is_owner: bool, + user_is_contributor: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle SET_SCORES action permission check for ScoreSet entities. + + Only owners, contributors, and admins can modify the scores data within + a ScoreSet. This is a critical operation that affects the scientific data. + + Args: + user_data: The user's authentication data. + entity: The ScoreSet entity whose scores are being modified. + private: Whether the ScoreSet is private. + user_is_owner: Whether the user owns the ScoreSet. + user_is_contributor: Whether the user is a contributor to the ScoreSet. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow set scores access under the following conditions: + # The owner or contributors may set scores. + if user_is_owner or user_is_contributor: + return PermissionResponse(True) + # Users with these specific roles may set scores. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, private, user_data, user_is_contributor or user_is_owner, "score set") diff --git a/src/mavedb/lib/permissions/user.py b/src/mavedb/lib/permissions/user.py new file mode 100644 index 00000000..cee817ca --- /dev/null +++ b/src/mavedb/lib/permissions/user.py @@ -0,0 +1,192 @@ +from typing import Optional + +from mavedb.lib.logging.context import save_to_logging_context +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.models import PermissionResponse +from mavedb.lib.permissions.utils import deny_action_for_entity, roles_permitted +from mavedb.lib.types.authentication import UserData +from mavedb.models.enums.user_role import UserRole +from mavedb.models.user import User + + +def has_permission(user_data: Optional[UserData], entity: User, action: Action) -> PermissionResponse: + """ + Check if a user has permission to perform an action on a User entity. 
+ + This function evaluates user permissions based on user identity and roles. + User entities have different access patterns since they don't have public/private + states or ownership in the traditional sense. + + Args: + user_data: The user's authentication data and roles. None for anonymous users. + entity: The User entity to check permissions for. + action: The action to be performed (READ, UPDATE, LOOKUP, ADD_ROLE). + + Returns: + PermissionResponse: Contains permission result, HTTP status code, and message. + + Raises: + NotImplementedError: If the action is not supported for User entities. + + Note: + User entities do not have private/public states or traditional ownership models. + Permissions are based on user identity and administrative roles. + """ + user_is_self = False + active_roles = [] + + if user_data is not None: + user_is_self = entity.id == user_data.user.id + active_roles = user_data.active_roles + + save_to_logging_context( + { + "user_is_self": user_is_self, + "target_user_id": entity.id, + } + ) + + handlers = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.LOOKUP: _handle_lookup_action, + Action.ADD_ROLE: _handle_add_role_action, + } + + if action not in handlers: + supported_actions = ", ".join(a.value for a in handlers.keys()) + raise NotImplementedError( + f"Action '{action.value}' is not supported for user profile entities. " + f"Supported actions: {supported_actions}" + ) + + return handlers[action]( + user_data, + entity, + user_is_self, + active_roles, + ) + + +def _handle_read_action( + user_data: Optional[UserData], + entity: User, + user_is_self: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle READ action permission check for User entities. + + Users can read their own profile. Admins can read any user profile. + READ access to profiles refers to admin level properties. Basic user info + is handled by the LOOKUP action. + + Args: + user_data: The user's authentication data. + entity: The User entity being accessed. + user_is_self: Whether the user is viewing their own profile. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + + Note: + Basic user information (username, display name) is typically public, + but sensitive information requires appropriate permissions. + """ + ## Allow read access under the following conditions: + # Users can always read their own profile. + if user_is_self: + return PermissionResponse(True) + # Admins can read any user profile. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, False, user_data, False, "user profile") + + +def _handle_lookup_action( + user_data: Optional[UserData], + entity: User, + user_is_self: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle LOOKUP action permission check for User entities. + + Any authenticated user can look up basic information about other users. + Anonymous users cannot perform LOOKUP actions. + + Args: + user_data: The user's authentication data. + entity: The User entity being looked up. + user_is_self: Whether the user is looking up their own profile. + active_roles: List of the user's active roles. + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow lookup access under the following conditions: + # Any authenticated user can look up basic user information. 
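A small, self-contained sketch of the LOOKUP rule stated above: any authenticated user may look up another user, while anonymous requests are denied. The dataclasses below are stand-ins for the real `User` and `UserData` types.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeUser:
    id: int


@dataclass
class FakeUserData:
    user: Optional[FakeUser]


def may_lookup(user_data: Optional[FakeUserData]) -> bool:
    # LOOKUP only requires an authenticated user; no role or ownership checks.
    return user_data is not None and user_data.user is not None


assert may_lookup(FakeUserData(FakeUser(1))) is True
assert may_lookup(None) is False
```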
+ if user_data is not None and user_data.user is not None: + return PermissionResponse(True) + + return deny_action_for_entity(entity, False, user_data, False, "user profile") + + +def _handle_update_action( + user_data: Optional[UserData], + entity: User, + user_is_self: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle UPDATE action permission check for User entities. + + Users can update their own profile. Admins can update any user profile. + + Args: + user_data: The user's authentication data. + entity: The User entity being updated. + user_is_self: Whether the user is updating their own profile. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow update access under the following conditions: + # Users can update their own profile. + if user_is_self: + return PermissionResponse(True) + # Admins can update any user profile. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, False, user_data, False, "user profile") + + +def _handle_add_role_action( + user_data: Optional[UserData], + entity: User, + user_is_self: bool, + active_roles: list[UserRole], +) -> PermissionResponse: + """ + Handle ADD_ROLE action permission check for User entities. + + Only admins can add roles to users. + + Args: + user_data: The user's authentication data. + entity: The User entity being modified. + user_is_self: Whether the user is modifying their own profile. + active_roles: List of the user's active roles. + + Returns: + PermissionResponse: Permission result with appropriate HTTP status. + """ + ## Allow add role access under the following conditions: + # Only admins can add roles to users. + if roles_permitted(active_roles, [UserRole.admin]): + return PermissionResponse(True) + + return deny_action_for_entity(entity, False, user_data, False, "user profile") diff --git a/src/mavedb/lib/permissions/utils.py b/src/mavedb/lib/permissions/utils.py new file mode 100644 index 00000000..3d92ce1d --- /dev/null +++ b/src/mavedb/lib/permissions/utils.py @@ -0,0 +1,132 @@ +import logging +from typing import Optional, Union, overload + +from mavedb.lib.logging.context import logging_context, save_to_logging_context +from mavedb.lib.permissions.models import PermissionResponse +from mavedb.lib.types.authentication import UserData +from mavedb.lib.types.permissions import EntityType +from mavedb.models.enums.contribution_role import ContributionRole +from mavedb.models.enums.user_role import UserRole + +logger = logging.getLogger(__name__) + + +@overload +def roles_permitted( + user_roles: list[UserRole], + permitted_roles: list[UserRole], +) -> bool: ... + + +@overload +def roles_permitted( + user_roles: list[ContributionRole], + permitted_roles: list[ContributionRole], +) -> bool: ... + + +def roles_permitted( + user_roles: Union[list[UserRole], list[ContributionRole]], + permitted_roles: Union[list[UserRole], list[ContributionRole]], +) -> bool: + """ + Check if any user role is permitted based on a list of allowed roles. + + This function validates that both user_roles and permitted_roles are lists of the same enum type + (either all UserRole or all ContributionRole), and checks if any user role is present in the permitted roles. + Raises ValueError if either list contains mixed role types or if the lists are of different types. 
+ + Args: + user_roles: List of roles assigned to the user (UserRole or ContributionRole). + permitted_roles: List of roles that are permitted for the action (UserRole or ContributionRole). + + Returns: + bool: True if any user role is permitted, False otherwise. + + Raises: + ValueError: If user_roles or permitted_roles contain mixed role types, or if the lists are of different types. + + Example: + >>> roles_permitted([UserRole.admin], [UserRole.admin, UserRole.editor]) + True + >>> roles_permitted([ContributionRole.admin], [ContributionRole.editor]) + False + + Note: + This function is used to enforce type safety and prevent mixing of role enums in permission checks. + """ + save_to_logging_context({"permitted_roles": [role.name for role in permitted_roles]}) + + if not user_roles: + logger.debug(msg="User has no associated roles.", extra=logging_context()) + return False + + # Validate that both lists contain the same enum type + if user_roles and permitted_roles: + user_role_types = {type(role) for role in user_roles} + permitted_role_types = {type(role) for role in permitted_roles} + + # Check if either list has mixed types + if len(user_role_types) > 1: + raise ValueError("user_roles list cannot contain mixed role types (UserRole and ContributionRole)") + if len(permitted_role_types) > 1: + raise ValueError("permitted_roles list cannot contain mixed role types (UserRole and ContributionRole)") + + # Check if the lists have different role types + if user_role_types != permitted_role_types: + raise ValueError( + "user_roles and permitted_roles must contain the same role type (both UserRole or both ContributionRole)" + ) + + return any(role in permitted_roles for role in user_roles) + + +def deny_action_for_entity( + entity: EntityType, + private: bool, + user_data: Optional[UserData], + user_may_view_private: bool, + user_facing_model_name: str = "entity", +) -> PermissionResponse: + """ + Generate appropriate denial response for entity permission checks. + + This helper function determines the correct HTTP status code and message + when denying access to an entity based on its privacy and user authentication. + + Args: + entity: The entity being accessed. + private: Whether the entity is private. + user_data: The user's authentication data (None for anonymous). + user_may_view_private: Whether the user has permission to view private entities. + + Returns: + PermissionResponse: Denial response with appropriate HTTP status and message. + + Note: + Returns 404 for private entities to avoid information disclosure, + 401 for unauthenticated users, and 403 for insufficient permissions. + """ + + def _identifier_for_entity(entity: EntityType) -> tuple[str, str]: + if hasattr(entity, "urn") and entity.urn is not None: + return "URN", entity.urn + elif hasattr(entity, "id") and entity.id is not None: + return "ID", str(entity.id) + else: + return "unknown", "unknown" + + field, identifier = _identifier_for_entity(entity) + # Do not acknowledge the existence of a private score set. + if private and not user_may_view_private: + return PermissionResponse(False, 404, f"{user_facing_model_name} with {field} '{identifier}' not found") + # No authenticated user is present. + if user_data is None or user_data.user is None: + return PermissionResponse( + False, 401, f"authentication required to access {user_facing_model_name} with {field} '{identifier}'" + ) + + # The authenticated user lacks sufficient permissions. 
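The 404/401/403 ordering documented above is easy to misread, so here is an editor's distillation of the decision order as a pure function with the expected status codes. It is illustrative only, not the MaveDB helper itself.

```python
def denial_status(private: bool, authenticated: bool, may_view_private: bool) -> int:
    # Order matters: hide private entities first, then require authentication,
    # then fall back to a plain permission failure.
    if private and not may_view_private:
        return 404  # do not acknowledge that the private entity exists
    if not authenticated:
        return 401  # authentication required
    return 403      # authenticated but lacking permission


assert denial_status(private=True, authenticated=True, may_view_private=False) == 404
assert denial_status(private=False, authenticated=False, may_view_private=False) == 401
assert denial_status(private=False, authenticated=True, may_view_private=False) == 403
```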
+ return PermissionResponse( + False, 403, f"insufficient permissions on {user_facing_model_name} with {field} '{identifier}'" + ) diff --git a/src/mavedb/lib/score_calibrations.py b/src/mavedb/lib/score_calibrations.py index cc67673a..98c7708c 100644 --- a/src/mavedb/lib/score_calibrations.py +++ b/src/mavedb/lib/score_calibrations.py @@ -1,18 +1,105 @@ """Utilities for building and mutating score calibration ORM objects.""" +import math +from typing import Optional, Union + +import pandas as pd +from sqlalchemy import Float, and_, select from sqlalchemy.orm import Session +from mavedb.lib.acmg import find_or_create_acmg_classification from mavedb.lib.identifiers import find_or_create_publication_identifier +from mavedb.lib.types.score_calibrations import ClassificationDict +from mavedb.lib.validation.constants.general import ( + calibration_class_column_name, + calibration_variant_column_name, + hgvs_nt_column, + hgvs_pro_column, +) +from mavedb.lib.validation.utilities import inf_or_float from mavedb.models.enums.score_calibration_relation import ScoreCalibrationRelation from mavedb.models.score_calibration import ScoreCalibration -from mavedb.models.score_set import ScoreSet +from mavedb.models.score_calibration_functional_classification import ScoreCalibrationFunctionalClassification from mavedb.models.score_calibration_publication_identifier import ScoreCalibrationPublicationIdentifierAssociation +from mavedb.models.score_set import ScoreSet from mavedb.models.user import User +from mavedb.models.variant import Variant from mavedb.view_models import score_calibration +def create_functional_classification( + db: Session, + functional_range_create: Union[ + score_calibration.FunctionalClassificationCreate, score_calibration.FunctionalClassificationModify + ], + containing_calibration: ScoreCalibration, + variant_classes: Optional[ClassificationDict] = None, +) -> ScoreCalibrationFunctionalClassification: + """ + Create a functional classification entity for score calibration. + This function creates a new ScoreCalibrationFunctionalClassification object + based on the provided functional range data. It optionally creates or finds + an associated ACMG classification if one is specified in the input data. + + Args: + db (Session): Database session for performing database operations. + functional_range_create (score_calibration.FunctionalClassificationCreate): + Input data containing the functional range parameters including label, + description, range bounds, inclusivity flags, and optional ACMG + classification information. + containing_calibration (ScoreCalibration): The ScoreCalibration instance. + variant_classes (Optional[ClassificationDict]): Optional dictionary mapping variant classes + to their corresponding variant identifiers. + + Returns: + ScoreCalibrationFunctionalClassification: The newly created functional + classification entity that has been added to the database session. + + Note: + The function adds the created functional classification to the database + session but does not commit the transaction. The caller is responsible + for committing the changes. 
+ """ + acmg_classification = None + if functional_range_create.acmg_classification: + acmg_classification = find_or_create_acmg_classification( + db, + criterion=functional_range_create.acmg_classification.criterion, + evidence_strength=functional_range_create.acmg_classification.evidence_strength, + points=functional_range_create.acmg_classification.points, + ) + else: + acmg_classification = None + + functional_classification = ScoreCalibrationFunctionalClassification( + label=functional_range_create.label, + description=functional_range_create.description, + range=functional_range_create.range, + class_=functional_range_create.class_, + inclusive_lower_bound=functional_range_create.inclusive_lower_bound, + inclusive_upper_bound=functional_range_create.inclusive_upper_bound, + acmg_classification=acmg_classification, + functional_classification=functional_range_create.functional_classification, + oddspaths_ratio=functional_range_create.oddspaths_ratio, # type: ignore[arg-type] + positive_likelihood_ratio=functional_range_create.positive_likelihood_ratio, # type: ignore[arg-type] + acmg_classification_id=acmg_classification.id if acmg_classification else None, + calibration=containing_calibration, + ) + + contained_variants = variants_for_functional_classification( + db, functional_classification, variant_classes=variant_classes, use_sql=True + ) + functional_classification.variants = contained_variants + + return functional_classification + + async def _create_score_calibration( - db: Session, calibration_create: score_calibration.ScoreCalibrationCreate, user: User + db: Session, + calibration_create: score_calibration.ScoreCalibrationCreate, + user: User, + variant_classes: Optional[ClassificationDict] = None, + containing_score_set: Optional[ScoreSet] = None, ) -> ScoreCalibration: """ Create a ScoreCalibration ORM instance (not yet persisted) together with its @@ -44,6 +131,10 @@ async def _create_score_calibration( optional lists of publication source identifiers grouped by relation type. user : User Authenticated user context; the user to be recorded for audit + variant_classes (Optional[ClassificationDict]): + Optional dictionary mapping variant classes to their corresponding variant identifiers. + containing_score_set : Optional[ScoreSet] + If provided, the ScoreSet instance to which the new calibration will belong. 
Returns ------- @@ -89,6 +180,7 @@ async def _create_score_calibration( **calibration_create.model_dump( by_alias=False, exclude={ + "functional_classifications", "threshold_sources", "classification_sources", "method_sources", @@ -96,15 +188,30 @@ async def _create_score_calibration( }, ), publication_identifier_associations=calibration_pub_assocs, + functional_classifications=[], created_by=user, modified_by=user, ) # type: ignore[call-arg] + if containing_score_set: + calibration.score_set = containing_score_set + calibration.score_set_id = containing_score_set.id + + for functional_range_create in calibration_create.functional_classifications or []: + persisted_functional_range = create_functional_classification( + db, functional_range_create, containing_calibration=calibration, variant_classes=variant_classes + ) + db.add(persisted_functional_range) + calibration.functional_classifications.append(persisted_functional_range) + return calibration async def create_score_calibration_in_score_set( - db: Session, calibration_create: score_calibration.ScoreCalibrationCreate, user: User + db: Session, + calibration_create: score_calibration.ScoreCalibrationCreate, + user: User, + variant_classes: Optional[ClassificationDict] = None, ) -> ScoreCalibration: """ Create a new score calibration and associate it with an existing score set. @@ -120,6 +227,8 @@ async def create_score_calibration_in_score_set( object containing the fields required to create a score calibration. Must include a non-empty score_set_urn. user (User): Authenticated user information used for auditing + variant_classes (Optional[ClassificationDict]): Optional dictionary mapping variant classes + to their corresponding variant identifiers. Returns: ScoreCalibration: The newly created and persisted score calibration object with its @@ -142,8 +251,7 @@ async def create_score_calibration_in_score_set( raise ValueError("score_set_urn must be provided to create a score calibration within a score set.") containing_score_set = db.query(ScoreSet).where(ScoreSet.urn == calibration_create.score_set_urn).one() - calibration = await _create_score_calibration(db, calibration_create, user) - calibration.score_set = containing_score_set + calibration = await _create_score_calibration(db, calibration_create, user, variant_classes, containing_score_set) if user.username in [contributor.orcid_id for contributor in containing_score_set.contributors] + [ containing_score_set.created_by.username, @@ -158,7 +266,10 @@ async def create_score_calibration_in_score_set( async def create_score_calibration( - db: Session, calibration_create: score_calibration.ScoreCalibrationCreate, user: User + db: Session, + calibration_create: score_calibration.ScoreCalibrationCreate, + user: User, + variant_classes: Optional[ClassificationDict] = None, ) -> ScoreCalibration: """ Asynchronously create and persist a new ScoreCalibration record. @@ -176,6 +287,8 @@ async def create_score_calibration( score set identifiers). user : User Authenticated user context; the user to be recorded for audit + variant_classes (Optional[ClassificationDict]): Optional dictionary mapping variant classes + to their corresponding variant identifiers. 
Returns ------- @@ -207,7 +320,9 @@ async def create_score_calibration( if calibration_create.score_set_urn: raise ValueError("score_set_urn must not be provided to create a score calibration outside a score set.") - created_calibration = await _create_score_calibration(db, calibration_create, user) + created_calibration = await _create_score_calibration( + db, calibration_create, user, variant_classes, containing_score_set=None + ) db.add(created_calibration) return created_calibration @@ -218,76 +333,79 @@ async def modify_score_calibration( calibration: ScoreCalibration, calibration_update: score_calibration.ScoreCalibrationModify, user: User, + variant_classes: Optional[ClassificationDict] = None, ) -> ScoreCalibration: """ - Asynchronously modify an existing ScoreCalibration record and its related publication - identifier associations. - - This function: - 1. Validates that a score_set_urn is provided in the update model (raises ValueError if absent). - 2. Loads (via SELECT ... WHERE urn = :score_set_urn) the ScoreSet that will contain the calibration. - 3. Reconciles publication identifier associations for three relation categories: - - threshold_sources -> ScoreCalibrationRelation.threshold - - classification_sources -> ScoreCalibrationRelation.classification - - method_sources -> ScoreCalibrationRelation.method - For each provided source identifier: - * Calls find_or_create_publication_identifier to obtain (or persist) the identifier row. - * Preserves an existing association if already present. - * Creates a new association if missing. - Any previously existing associations not referenced in the update are deleted from the session. - 4. Updates mutable scalar fields on the calibration instance from calibration_update, excluding: - threshold_sources, classification_sources, method_sources, created_at, created_by, - modified_at, modified_by. - 5. Reassigns the calibration to the resolved ScoreSet, replaces its association collection, - and stamps modified_by with the requesting user. - 6. Adds the modified calibration back into the SQLAlchemy session and returns it (no commit). - - Parameters - ---------- - db : Session - An active SQLAlchemy session (synchronous engine session used within an async context). - calibration : ScoreCalibration - The existing calibration ORM instance to be modified (must be persistent or pending). - del carrying updated field values plus source identifier lists: - - score_set_urn (required) - - threshold_sources, classification_sources, method_sources (iterables of identifier objects) - - Additional mutable calibration attributes. - user : User - Context for the authenticated user; the user to be recorded for audit. - - Returns - ------- - ScoreCalibration - The in-memory (and session-added) updated calibration instance. Changes are not committed. - - Raises - ------ - ValueError - If score_set_urn is missing in the update model. - sqlalchemy.orm.exc.NoResultFound - If no ScoreSet exists with the provided URN. - sqlalchemy.orm.exc.MultipleResultsFound - If more than one ScoreSet matches the provided URN. - Any exception raised by find_or_create_publication_identifier - If identifier resolution/creation fails. - - Side Effects - ------------ - - Issues SELECT statements for the ScoreSet and publication identifiers. - - May INSERT new publication identifiers and association rows. - - May DELETE association rows no longer referenced. - - Mutates the provided calibration object in-place. 
- - Concurrency / Consistency Notes - ------------------------------- - The reconciliation of associations assumes no concurrent modification of the same calibration's - association set within the active transaction. To prevent races leading to duplicate associations, - enforce appropriate transaction isolation or unique constraints at the database level. - - Commit Responsibility - --------------------- - This function does NOT call commit or flush explicitly; the caller is responsible for committing - the session to persist changes. + Asynchronously modify an existing ScoreCalibration record and its related publication + identifier associations. + + This function: + 1. Validates that a score_set_urn is provided in the update model (raises ValueError if absent). + 2. Loads (via SELECT ... WHERE urn = :score_set_urn) the ScoreSet that will contain the calibration. + 3. Reconciles publication identifier associations for three relation categories: + - threshold_sources -> ScoreCalibrationRelation.threshold + - classification_sources -> ScoreCalibrationRelation.classification + - method_sources -> ScoreCalibrationRelation.method + For each provided source identifier: + * Calls find_or_create_publication_identifier to obtain (or persist) the identifier row. + * Preserves an existing association if already present. + * Creates a new association if missing. + Any previously existing associations not referenced in the update are deleted from the session. + 4. Updates mutable scalar fields on the calibration instance from calibration_update, excluding: + threshold_sources, classification_sources, method_sources, created_at, created_by, + modified_at, modified_by. + 5. Reassigns the calibration to the resolved ScoreSet, replaces its association collection, + and stamps modified_by with the requesting user. + 6. Adds the modified calibration back into the SQLAlchemy session and returns it (no commit). + + Parameters + ---------- + db : Session + An active SQLAlchemy session (synchronous engine session used within an async context). + calibration : ScoreCalibration + The existing calibration ORM instance to be modified (must be persistent or pending). + calibration_update : score_calibration.ScoreCalibrationModify + - score_set_urn (required) + - threshold_sources, classification_sources, method_sources (iterables of identifier objects) + - Additional mutable calibration attributes. + user : User + Context for the authenticated user; the user to be recorded for audit. + variant_classes (Optional[ClassificationDict]): Optional dictionary mapping variant classes + to their corresponding variant identifiers. + + Returns + ------- + ScoreCalibration + The in-memory (and session-added) updated calibration instance. Changes are not committed. + + Raises + ------ + ValueError + If score_set_urn is missing in the update model. + sqlalchemy.orm.exc.NoResultFound + If no ScoreSet exists with the provided URN. + sqlalchemy.orm.exc.MultipleResultsFound + If more than one ScoreSet matches the provided URN. + Any exception raised by find_or_create_publication_identifier + If identifier resolution/creation fails. + + Side Effects + ------------ + - Issues SELECT statements for the ScoreSet and publication identifiers. + - May INSERT new publication identifiers and association rows. + - May DELETE association rows no longer referenced. + - Mutates the provided calibration object in-place. 
+ + Concurrency / Consistency Notes + ------------------------------- + The reconciliation of associations assumes no concurrent modification of the same calibration's + association set within the active transaction. To prevent races leading to duplicate associations, + enforce appropriate transaction isolation or unique constraints at the database level. + + Commit Responsibility + --------------------- + This function does NOT call commit or flush explicitly; the caller is responsible for committing + the session to persist changes. """ if not calibration_update.score_set_urn: @@ -328,12 +446,19 @@ async def modify_score_calibration( db.add(pub) db.flush() - # Remove associations that are no longer present + # Remove associations and calibrations that are no longer present for assoc in existing_assocs_map.values(): db.delete(assoc) + for functional_classification in calibration.functional_classifications: + db.delete(functional_classification) + calibration.functional_classifications.clear() + + db.flush() + db.refresh(calibration) for attr, value in calibration_update.model_dump().items(): if attr not in { + "functional_classifications", "threshold_sources", "classification_sources", "method_sources", @@ -346,9 +471,17 @@ async def modify_score_calibration( setattr(calibration, attr, value) calibration.score_set = containing_score_set + calibration.score_set_id = containing_score_set.id calibration.publication_identifier_associations = updated_assocs calibration.modified_by = user + for functional_range_update in calibration_update.functional_classifications or []: + persisted_functional_range = create_functional_classification( + db, functional_range_update, variant_classes=variant_classes, containing_calibration=calibration + ) + db.add(persisted_functional_range) + calibration.functional_classifications.append(persisted_functional_range) + db.add(calibration) return calibration @@ -517,3 +650,182 @@ def delete_score_calibration(db: Session, calibration: ScoreCalibration) -> None db.delete(calibration) return None + + +def variants_for_functional_classification( + db: Session, + functional_classification: ScoreCalibrationFunctionalClassification, + variant_classes: Optional[ClassificationDict] = None, + use_sql: bool = False, +) -> list[Variant]: + """ + Return variants in the parent score set whose numeric score falls inside the + functional classification's range. + + The variant score is extracted from the JSONB ``Variant.data`` field using + ``score_json_path`` (default: ("score_data", "score") meaning + ``variant.data['score_data']['score']``). The classification's existing + ``score_is_contained_in_range`` method is used for interval logic, including + inclusive/exclusive behaviors. + + Parameters + ---------- + db : Session + Active SQLAlchemy session. + functional_classification : ScoreCalibrationFunctionalClassification + The ORM row defining the interval to test against. + variant_classes : Optional[ClassificationDict] + If provided, a dictionary mapping variant classes to their corresponding variant identifiers + to use for classification rather than the range property of the functional_classification. + use_sql : bool + When True, perform filtering in the database using JSONB extraction and + range predicates; falls back to Python filtering if an error occurs. + + Returns + ------- + list[Variant] + Variants whose score falls within the specified range. Empty list if + classification has no usable range. 
+ + Notes + ----- + * If use_sql=False (default) filtering occurs in Python after loading all + variants for the score set. For large sets set use_sql=True to push + comparison into Postgres. + * Variants lacking a score or with non-numeric scores are skipped. + * If ``functional_classification.range`` is ``None`` an empty list is + returned immediately. + """ + # Resolve score set id from attached calibration (relationship may be lazy) + score_set_id = functional_classification.calibration.score_set_id # type: ignore[attr-defined] + + if variant_classes and variant_classes["indexed_by"] not in [ + hgvs_nt_column, + hgvs_pro_column, + calibration_variant_column_name, + ]: + raise ValueError(f"Unsupported index column `{variant_classes['indexed_by']}` for variant classification.") + + if use_sql: + try: + # Build score extraction expression: data['score_data']['score']::text::float + score_expr = Variant.data["score_data"]["score"].astext.cast(Float) + + conditions = [Variant.score_set_id == score_set_id] + if variant_classes is not None and functional_classification.class_ is not None: + index_element = variant_classes["classifications"].get(functional_classification.class_, set()) + + if variant_classes["indexed_by"] == hgvs_nt_column: + conditions.append(Variant.hgvs_nt.in_(index_element)) + elif variant_classes["indexed_by"] == hgvs_pro_column: + conditions.append(Variant.hgvs_pro.in_(index_element)) + elif variant_classes["indexed_by"] == calibration_variant_column_name: + conditions.append(Variant.urn.in_(index_element)) + else: # pragma: no cover + return [] + + elif functional_classification.range is not None and len(functional_classification.range) == 2: + lower_raw, upper_raw = functional_classification.range + + # Convert 'inf' sentinels (or None) to float infinities for condition omission. + lower_bound = inf_or_float(lower_raw, lower=True) + upper_bound = inf_or_float(upper_raw, lower=False) + + if not math.isinf(lower_bound): + if functional_classification.inclusive_lower_bound: + conditions.append(score_expr >= lower_bound) + else: + conditions.append(score_expr > lower_bound) + if not math.isinf(upper_bound): + if functional_classification.inclusive_upper_bound: + conditions.append(score_expr <= upper_bound) + else: + conditions.append(score_expr < upper_bound) + + else: + # No usable classification mechanism; return empty list. + return [] + + stmt = select(Variant).where(and_(*conditions)) + return list(db.execute(stmt).scalars()) + + except Exception: # noqa: BLE001 + # Fall back to Python filtering if casting/JSON path errors occur. 
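For readers unfamiliar with the JSONB extraction used in the SQL branch above, here is a standalone sketch against a toy model. The table and column names are invented, but the expression mirrors `Variant.data['score_data']['score']` cast to float followed by ordinary range predicates.

```python
from sqlalchemy import Column, Float, Integer, and_, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class ToyVariant(Base):
    __tablename__ = "toy_variant"
    id = Column(Integer, primary_key=True)
    data = Column(JSONB)


# data['score_data']['score'] -> text -> float, then ordinary range predicates.
score_expr = ToyVariant.data["score_data"]["score"].astext.cast(Float)
stmt = select(ToyVariant).where(and_(score_expr > -1.5, score_expr <= 0.5))

# Compile against the PostgreSQL dialect to see the JSONB operators and CAST.
print(stmt.compile(dialect=postgresql.dialect()))
```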
+ pass + + # Python filtering fallback / default path + variants = db.execute(select(Variant).where(Variant.score_set_id == score_set_id)).scalars().all() + matches: list[Variant] = [] + for v in variants: + if variant_classes is not None and functional_classification.class_ is not None: + index_element = variant_classes["classifications"].get(functional_classification.class_, set()) + + if variant_classes["indexed_by"] == hgvs_nt_column: + if v.hgvs_nt in index_element: + matches.append(v) + elif variant_classes["indexed_by"] == hgvs_pro_column: + if v.hgvs_pro in index_element: + matches.append(v) + elif variant_classes["indexed_by"] == calibration_variant_column_name: + if v.urn in index_element: + matches.append(v) + else: # pragma: no cover + continue + + elif functional_classification.range is not None and len(functional_classification.range) == 2: + try: + container = v.data.get("score_data") if isinstance(v.data, dict) else None + if not container or not isinstance(container, dict): + continue + + raw = container.get("score") + if raw is None: + continue + + score = float(raw) + + except Exception: # noqa: BLE001 + continue + + if functional_classification.score_is_contained_in_range(score): + matches.append(v) + + return matches + + +def variant_classification_df_to_dict( + df: pd.DataFrame, + index_column: str, +) -> ClassificationDict: + """ + Convert a DataFrame of variant classifications into a dictionary mapping + functional class labels to lists of distinct variant URNs. + + The input DataFrame is expected to have at least two columns: + - The unique identifier for each variant (given by calibration_variant_column_name). + - The functional classification label for each variant (given by calibration_class_column_name). + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing variant classifications with 'variant_urn' and + 'functional_class' columns. + + Returns + ------- + ClassificationDict + A dictionary with two keys: 'indexed_by' indicating the index column name, + and 'classifications' mapping each functional class label to a list of + distinct variant URNs. 
+ """ + classifications: dict[str, set[str]] = {} + for _, row in df.iterrows(): + index_element = row[index_column] + functional_class = row[calibration_class_column_name] + + if functional_class not in classifications: + classifications[functional_class] = set() + + classifications[functional_class].add(index_element) + + return {"indexed_by": index_column, "classifications": classifications} diff --git a/src/mavedb/lib/score_sets.py b/src/mavedb/lib/score_sets.py index 190d7b42..3d785129 100644 --- a/src/mavedb/lib/score_sets.py +++ b/src/mavedb/lib/score_sets.py @@ -23,6 +23,8 @@ VARIANT_SCORE_DATA, ) from mavedb.lib.mave.utils import is_csv_null +from mavedb.lib.permissions import Action, has_permission +from mavedb.lib.types.authentication import UserData from mavedb.lib.validation.constants.general import null_values_list from mavedb.lib.validation.utilities import is_null as validate_is_null from mavedb.lib.variants import get_digest_from_post_mapped, get_hgvs_from_post_mapped, is_hgvs_g, is_hgvs_p @@ -55,7 +57,6 @@ from mavedb.view_models.search import ScoreSetsSearch if TYPE_CHECKING: - from mavedb.lib.authentication import UserData from mavedb.lib.permissions import Action VariantData = dict[str, Optional[dict[str, dict]]] @@ -298,21 +299,40 @@ def score_set_search_filter_options_from_counter(counter: Counter): return [{"value": value, "count": count} for value, count in counter.items()] -def fetch_score_set_search_filter_options(db: Session, owner_or_contributor: Optional[User], search: ScoreSetsSearch): +def fetch_score_set_search_filter_options( + db: Session, requester: Optional[UserData], owner_or_contributor: Optional[User], search: ScoreSetsSearch +): save_to_logging_context({"score_set_search_criteria": search.model_dump()}) query = db.query(ScoreSet) query = build_search_score_sets_query_filter(db, query, owner_or_contributor, search) - score_sets: list[ScoreSet] = query.all() if not score_sets: score_sets = [] + # Target related counters target_category_counter: Counter[str] = Counter() target_name_counter: Counter[str] = Counter() target_organism_name_counter: Counter[str] = Counter() target_accession_counter: Counter[str] = Counter() + # Publication related counters + publication_author_name_counter: Counter[str] = Counter() + publication_db_name_counter: Counter[str] = Counter() + publication_journal_counter: Counter[str] = Counter() + + # --- PERFORMANCE NOTE --- + # The following counter construction loop is a bottleneck for large score set queries. 
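To show the shape of the `ClassificationDict` produced by `variant_classification_df_to_dict`, here is a tiny in-memory example. The column names follow the constants added in this diff (`variant_urn`, `class_name`), while the URNs and class labels are made up; the grouping shown is an equivalent pandas idiom rather than the function itself.

```python
import pandas as pd

df = pd.DataFrame(
    {
        "variant_urn": ["urn:a#1", "urn:a#2", "urn:a#2", "urn:a#3"],
        "class_name": ["normal", "abnormal", "abnormal", "normal"],
    }
)

# Equivalent grouping logic; sets deduplicate the repeated 'urn:a#2' row.
classifications = {cls: set(sub["variant_urn"]) for cls, sub in df.groupby("class_name")}
result = {"indexed_by": "variant_urn", "classifications": classifications}

assert result["classifications"]["abnormal"] == {"urn:a#2"}
assert result["classifications"]["normal"] == {"urn:a#1", "urn:a#3"}
```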
+ # Practical future optimizations might include: + # - Batch permission checks and attribute access outside the loop if possible + # - Use parallelization (e.g., multiprocessing or concurrent.futures) for large datasets + # - Pre-fetch or denormalize target/publication data in the DB query + # - Profile and refactor nested attribute lookups to minimize Python overhead for score_set in score_sets: + # Check read permission for each score set, skip if no permission + if not has_permission(requester, score_set, Action.READ).permitted: + continue + + # Target related options for target in getattr(score_set, "target_genes", []): category = getattr(target, "category", None) if category: @@ -335,10 +355,7 @@ def fetch_score_set_search_filter_options(db: Session, owner_or_contributor: Opt if accession: target_accession_counter[accession] += 1 - publication_author_name_counter: Counter[str] = Counter() - publication_db_name_counter: Counter[str] = Counter() - publication_journal_counter: Counter[str] = Counter() - for score_set in score_sets: + # Publication related options for publication_association in getattr(score_set, "publication_identifier_associations", []): publication = getattr(publication_association, "publication", None) @@ -443,8 +460,6 @@ def find_meta_analyses_for_experiment_sets(db: Session, urns: list[str]) -> list def find_superseded_score_set_tail( score_set: ScoreSet, action: Optional["Action"] = None, user_data: Optional["UserData"] = None ) -> Optional[ScoreSet]: - from mavedb.lib.permissions import has_permission - while score_set.superseding_score_set is not None: next_score_set_in_chain = score_set.superseding_score_set @@ -502,7 +517,7 @@ def find_publish_or_private_superseded_score_set_tail( def get_score_set_variants_as_csv( db: Session, score_set: ScoreSet, - namespaces: List[Literal["scores", "counts", "vep", "gnomad"]], + namespaces: List[Literal["scores", "counts", "vep", "gnomad", "clingen"]], namespaced: Optional[bool] = None, start: Optional[int] = None, limit: Optional[int] = None, @@ -519,8 +534,8 @@ def get_score_set_variants_as_csv( The database session to use. score_set : ScoreSet The score set to get the variants from. - namespaces : List[Literal["scores", "counts", "vep", "gnomad"]] - The namespaces for data. Now there are only scores, counts, VEP, and gnomAD. ClinVar will be added in the future. + namespaces : List[Literal["scores", "counts", "vep", "gnomad", "clingen"]] + The namespaces for data. Now there are only scores, counts, VEP, gnomAD, and ClinGen. ClinVar will be added in the future. namespaced: Optional[bool] = None Whether namespace the columns or not. 
start : int, optional @@ -569,6 +584,8 @@ def get_score_set_variants_as_csv( namespaced_score_set_columns["vep"].append("vep_functional_consequence") if "gnomad" in namespaced_score_set_columns: namespaced_score_set_columns["gnomad"].append("gnomad_af") + if "clingen" in namespaced_score_set_columns: + namespaced_score_set_columns["clingen"].append("clingen_allele_id") variants: Sequence[Variant] = [] mappings: Optional[list[Optional[MappedVariant]]] = None gnomad_data: Optional[list[Optional[GnomADVariant]]] = None @@ -841,6 +858,15 @@ def variant_to_csv_row( value = na_rep key = f"gnomad.{column_key}" if namespaced else column_key row[key] = value + for column_key in columns.get("clingen", []): + if column_key == "clingen_allele_id": + clingen_allele_id = mapping.clingen_allele_id if mapping else None + if clingen_allele_id is not None: + value = str(clingen_allele_id) + else: + value = na_rep + key = f"clingen.{column_key}" if namespaced else column_key + row[key] = value return row @@ -1100,7 +1126,7 @@ def bulk_create_urns(n, score_set, reset_counter=False) -> list[str]: return child_urns -def csv_data_to_df(file_data: BinaryIO) -> pd.DataFrame: +def csv_data_to_df(file_data: BinaryIO, induce_hgvs_cols: bool = True) -> pd.DataFrame: extra_na_values = list( set( list(null_values_list) @@ -1121,9 +1147,10 @@ def csv_data_to_df(file_data: BinaryIO) -> pd.DataFrame: dtype={**{col: str for col in HGVSColumns.options()}, "scores": float}, ) - for c in HGVSColumns.options(): - if c not in ingested_df.columns: - ingested_df[c] = np.NaN + if induce_hgvs_cols: + for c in HGVSColumns.options(): + if c not in ingested_df.columns: + ingested_df[c] = np.NaN return ingested_df diff --git a/src/mavedb/lib/types/authentication.py b/src/mavedb/lib/types/authentication.py new file mode 100644 index 00000000..748f6f90 --- /dev/null +++ b/src/mavedb/lib/types/authentication.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from mavedb.models.enums.user_role import UserRole + from mavedb.models.user import User + + +@dataclass +class UserData: + user: "User" + active_roles: list["UserRole"] diff --git a/src/mavedb/lib/types/clingen.py b/src/mavedb/lib/types/clingen.py index 9085a9da..708b6c17 100644 --- a/src/mavedb/lib/types/clingen.py +++ b/src/mavedb/lib/types/clingen.py @@ -1,6 +1,6 @@ -from typing import Any, Optional, TypedDict, Literal -from typing_extensions import NotRequired +from typing import Any, Literal, Optional, TypedDict +from typing_extensions import NotRequired # See: https://ldh.genome.network/docs/ldh/submit.html#content-submission-body @@ -152,3 +152,15 @@ class ClinGenAlleleDefinition(TypedDict): "aminoAcidAlleles": NotRequired[list[ClinGenAlleleDefinition]], }, ) + +ClinGenSubmissionError = TypedDict( + "ClinGenSubmissionError", + { + "description": str, + "errorType": str, + "hgvs": str, + "inputLine": str, + "message": str, + "position": str, + }, +) diff --git a/src/mavedb/lib/types/permissions.py b/src/mavedb/lib/types/permissions.py new file mode 100644 index 00000000..aa9628c7 --- /dev/null +++ b/src/mavedb/lib/types/permissions.py @@ -0,0 +1,18 @@ +from typing import Union + +from mavedb.models.collection import Collection +from mavedb.models.experiment import Experiment +from mavedb.models.experiment_set import ExperimentSet +from mavedb.models.score_calibration import ScoreCalibration +from mavedb.models.score_set import ScoreSet +from mavedb.models.user import User + +# Define the supported entity types 
+EntityType = Union[ + Collection, + Experiment, + ExperimentSet, + ScoreCalibration, + ScoreSet, + User, +] diff --git a/src/mavedb/lib/types/score_calibrations.py b/src/mavedb/lib/types/score_calibrations.py new file mode 100644 index 00000000..d40edaf2 --- /dev/null +++ b/src/mavedb/lib/types/score_calibrations.py @@ -0,0 +1,6 @@ +from typing import TypedDict + + +class ClassificationDict(TypedDict): + indexed_by: str + classifications: dict[str, set[str]] diff --git a/src/mavedb/lib/validation/constants/general.py b/src/mavedb/lib/validation/constants/general.py index 92b4fd5b..22ca4cbf 100644 --- a/src/mavedb/lib/validation/constants/general.py +++ b/src/mavedb/lib/validation/constants/general.py @@ -44,6 +44,9 @@ variant_count_data = "count_data" required_score_column = "score" +calibration_variant_column_name = "variant_urn" +calibration_class_column_name = "class_name" + valid_dataset_columns = [score_columns, count_columns] valid_variant_columns = [variant_score_data, variant_count_data] diff --git a/src/mavedb/lib/validation/dataframe/calibration.py b/src/mavedb/lib/validation/dataframe/calibration.py new file mode 100644 index 00000000..1c46be46 --- /dev/null +++ b/src/mavedb/lib/validation/dataframe/calibration.py @@ -0,0 +1,248 @@ +import pandas as pd +from sqlalchemy import select +from sqlalchemy.orm import Session + +from mavedb.lib.validation.constants.general import ( + calibration_class_column_name, + calibration_variant_column_name, + hgvs_nt_column, + hgvs_pro_column, +) +from mavedb.lib.validation.dataframe.column import validate_data_column, validate_variant_column +from mavedb.lib.validation.dataframe.dataframe import standardize_dataframe, validate_no_null_rows +from mavedb.lib.validation.exceptions import ValidationError +from mavedb.models.score_set import ScoreSet +from mavedb.models.variant import Variant +from mavedb.view_models import score_calibration + +STANDARD_CALIBRATION_COLUMNS = ( + calibration_variant_column_name, + calibration_class_column_name, + hgvs_nt_column, + hgvs_pro_column, +) + + +def validate_and_standardize_calibration_classes_dataframe( + db: Session, + score_set: ScoreSet, + calibration: score_calibration.ScoreCalibrationCreate | score_calibration.ScoreCalibrationModify, + classes_df: pd.DataFrame, +) -> tuple[pd.DataFrame, str]: + """ + Validate and standardize a calibration classes dataframe for functional classification calibrations. + + This function performs comprehensive validation of a calibration classes dataframe, ensuring + it meets the requirements for functional classification calibrations. It standardizes column + names, validates data integrity, and checks that variants and classes are properly formatted. + + Args: + db (Session): Database session for validation queries. + score_set (ScoreSet): The score set associated with the calibration. + calibration (ScoreCalibrationCreate | ScoreCalibrationModify): The calibration object + containing configuration details. Must be class-based. + classes_df (pd.DataFrame): The input dataframe containing calibration classes data. + + Returns: + pd.DataFrame: The standardized and validated calibration classes dataframe. + + Raises: + ValueError: If the calibration is not class-based. + ValidationError: If the dataframe contains invalid data, unexpected columns, + invalid variant URNs, or improperly formatted classes. 
+ + Note: + The function expects the dataframe to contain specific columns for variants and + calibration classes, and performs strict validation on both column structure + and data content. + """ + if not calibration.class_based: + raise ValidationError( + "Calibration classes file can only be provided for functional classification calibrations." + ) + + standardized_classes_df = standardize_dataframe(classes_df, STANDARD_CALIBRATION_COLUMNS) + validate_calibration_df_column_names(standardized_classes_df) + validate_no_null_rows(standardized_classes_df) + + column_mapping = {c.lower(): c for c in standardized_classes_df.columns} + index_column = choose_calibration_index_column(standardized_classes_df) + + # Drop rows where the calibration class column is NA + standardized_classes_df = standardized_classes_df.dropna( + subset=[column_mapping[calibration_class_column_name]] + ).reset_index(drop=True) + + for c in column_mapping: + if c in {calibration_variant_column_name, hgvs_nt_column, hgvs_pro_column}: + validate_variant_column(standardized_classes_df[c], column_mapping[c] == index_column) + elif c == calibration_class_column_name: + validate_data_column(standardized_classes_df[c], force_numeric=False) + validate_calibration_classes(calibration, standardized_classes_df[c]) + + if column_mapping[c] == index_column: + validate_index_existence_in_score_set( + db, score_set, standardized_classes_df[column_mapping[c]], column_mapping[c] + ) + + return standardized_classes_df, index_column + + +def validate_calibration_df_column_names(df: pd.DataFrame) -> None: + """ + Validate the column names of a calibration DataFrame. + + This function performs comprehensive validation of DataFrame column names to ensure + they meet the required format and structure for calibration data processing. + + Args: + df (pd.DataFrame): The DataFrame whose columns need to be validated. + + Raises: + ValidationError: If any of the following validation checks fail: + - Column names are not strings + - Column names are empty or contain only whitespace + - Required calibration variant column is missing + - Required calibration class column is missing + - DataFrame contains unexpected columns (must match STANDARD_CALIBRATION_COLUMNS exactly) + + Returns: + None: This function performs validation only and returns nothing on success. + + Note: + Column name comparison is case-insensitive. The function converts all column + names to lowercase before performing validation checks. 
+ """ + if any(type(c) is not str for c in df.columns): + raise ValidationError("column names must be strings") + + if any(c.isspace() for c in df.columns) or any(len(c) == 0 for c in df.columns): + raise ValidationError("column names cannot be empty or whitespace") + + if len(df.columns) != len(set(c.lower() for c in df.columns)): + raise ValidationError("duplicate column names are not allowed (case-insensitive)") + + columns = [c.lower() for c in df.columns] + + if calibration_class_column_name not in columns: + raise ValidationError(f"missing required column: '{calibration_class_column_name}'") + + if set(columns).isdisjoint({hgvs_nt_column, hgvs_pro_column, calibration_variant_column_name}): + raise ValidationError( + f"at least one of {', '.join({hgvs_nt_column, hgvs_pro_column, calibration_variant_column_name})} must be present" + ) + + +def validate_index_existence_in_score_set( + db: Session, score_set: ScoreSet, index_column: pd.Series, index_column_name: str +) -> None: + """ + Validate that all provided resources in the index column exist in the given score set. + + Args: + db (Session): Database session for querying variants. + score_set (ScoreSet): The score set to validate variants against. + variant_urns (pd.Series): Series of variant URNs to validate. + + Raises: + ValidationError: If any variant URNs do not exist in the score set. + + Returns: + None: Function returns nothing if validation passes. + """ + if index_column_name.lower() == calibration_variant_column_name: + existing_resources = set( + db.scalars( + select(Variant.urn).where(Variant.score_set_id == score_set.id, Variant.urn.in_(index_column.tolist())) + ).all() + ) + elif index_column_name.lower() == hgvs_nt_column: + existing_resources = set( + db.scalars( + select(Variant.hgvs_nt).where( + Variant.score_set_id == score_set.id, Variant.hgvs_nt.in_(index_column.tolist()) + ) + ).all() + ) + elif index_column_name.lower() == hgvs_pro_column: + existing_resources = set( + db.scalars( + select(Variant.hgvs_pro).where( + Variant.score_set_id == score_set.id, Variant.hgvs_pro.in_(index_column.tolist()) + ) + ).all() + ) + + missing_resources = set(index_column.tolist()) - existing_resources + if missing_resources: + raise ValidationError( + f"The following resources do not exist in the score set: {', '.join(sorted(missing_resources))}" + ) + + +def choose_calibration_index_column(df: pd.DataFrame) -> str: + """ + Choose the appropriate index column for a calibration DataFrame. + + This function selects the index column based on the presence of specific columns + in the DataFrame. It prioritizes the calibration variant column, followed by + HGVS notation columns. + + Args: + df (pd.DataFrame): The DataFrame from which to choose the index column. + + Returns: + str: The name of the chosen index column. + + Raises: + ValidationError: If no valid index column is found in the DataFrame. 
+ """ + column_mapping = {c.lower(): c for c in df.columns if not df[c].isna().all()} + + if calibration_variant_column_name in column_mapping: + return column_mapping[calibration_variant_column_name] + elif hgvs_nt_column in column_mapping: + return column_mapping[hgvs_nt_column] + elif hgvs_pro_column in column_mapping: + return column_mapping[hgvs_pro_column] + else: + raise ValidationError("failed to find valid calibration index column") + + +def validate_calibration_classes( + calibration: score_calibration.ScoreCalibrationCreate | score_calibration.ScoreCalibrationModify, classes: pd.Series +) -> None: + """ + Validate that the functional classifications in a calibration match the provided classes. + + This function ensures that: + 1. The calibration has functional classifications defined + 2. All classes in the provided series are defined in the calibration + 3. All classes defined in the calibration are present in the provided series + + Args: + calibration: A ScoreCalibrationCreate or ScoreCalibrationModify object containing + functional classifications to validate against. + classes: A pandas Series containing class labels to validate. + + Raises: + ValueError: If the calibration does not have functional classifications defined. + ValidationError: If there are classes in the series that are not defined in the + calibration, or if there are classes defined in the calibration + that are missing from the series. + """ + if not calibration.functional_classifications: + raise ValidationError("Calibration must have functional classifications defined for class validation.") + + defined_classes = {c.class_ for c in calibration.functional_classifications} + provided_classes = set(classes.tolist()) + + undefined_classes = provided_classes - defined_classes + if undefined_classes: + raise ValidationError( + f"The following classes are not defined in the calibration: {', '.join(sorted(undefined_classes))}" + ) + + unprovided_classes = defined_classes - provided_classes + if unprovided_classes: + raise ValidationError("Some defined classes in the calibration are missing from the classes file.") diff --git a/src/mavedb/lib/validation/dataframe/dataframe.py b/src/mavedb/lib/validation/dataframe/dataframe.py index 75a07db6..dcf6c8e5 100644 --- a/src/mavedb/lib/validation/dataframe/dataframe.py +++ b/src/mavedb/lib/validation/dataframe/dataframe.py @@ -82,8 +82,8 @@ def validate_and_standardize_dataframe_pair( if not targets: raise ValueError("Can't validate provided file with no targets.") - standardized_scores_df = standardize_dataframe(scores_df) - standardized_counts_df = standardize_dataframe(counts_df) if counts_df is not None else None + standardized_scores_df = standardize_dataframe(scores_df, STANDARD_COLUMNS) + standardized_counts_df = standardize_dataframe(counts_df, STANDARD_COLUMNS) if counts_df is not None else None validate_dataframe(standardized_scores_df, "scores", targets, hdp) @@ -224,7 +224,7 @@ def standardize_dict_keys(d: dict[str, Any]) -> dict[str, Any]: return {clean_col_name(k): v for k, v in d.items()} -def standardize_dataframe(df: pd.DataFrame) -> pd.DataFrame: +def standardize_dataframe(df: pd.DataFrame, standard_columns: tuple[str, ...]) -> pd.DataFrame: """Standardize a dataframe by sorting the columns and changing the standard column names to lowercase. Also strips leading and trailing whitespace from column names and removes any quoted strings from column names. 
@@ -250,7 +250,7 @@ def standardize_dataframe(df: pd.DataFrame) -> pd.DataFrame: cleaned_columns = {c: clean_col_name(c) for c in df.columns} df.rename(columns=cleaned_columns, inplace=True) - column_mapper = {x: x.lower() for x in df.columns if x.lower() in STANDARD_COLUMNS} + column_mapper = {x: x.lower() for x in df.columns if x.lower() in standard_columns} df.rename(columns=column_mapper, inplace=True) return sort_dataframe_columns(df) diff --git a/src/mavedb/models/__init__.py b/src/mavedb/models/__init__.py index 684b3c98..1a20b792 100644 --- a/src/mavedb/models/__init__.py +++ b/src/mavedb/models/__init__.py @@ -1,5 +1,6 @@ __all__ = [ "access_key", + "acmg_classification", "collection", "clinical_control", "controlled_keyword", @@ -19,8 +20,11 @@ "refseq_identifier", "refseq_offset", "role", - "score_set", + "score_calibration_functional_classification_variant_association", + "score_calibration_functional_classification", + "score_calibration_publication_identifier", "score_calibration", + "score_set", "target_gene", "target_sequence", "taxonomy", diff --git a/src/mavedb/models/acmg_classification.py b/src/mavedb/models/acmg_classification.py new file mode 100644 index 00000000..027a2caa --- /dev/null +++ b/src/mavedb/models/acmg_classification.py @@ -0,0 +1,26 @@ +"""SQLAlchemy model for ACMG classification entities.""" + +from datetime import date + +from sqlalchemy import Column, Date, Enum, Integer + +from mavedb.db.base import Base +from mavedb.models.enums.acmg_criterion import ACMGCriterion +from mavedb.models.enums.strength_of_evidence import StrengthOfEvidenceProvided + + +class ACMGClassification(Base): + """ACMG classification model for storing ACMG criteria, evidence strength, and points.""" + + __tablename__ = "acmg_classifications" + + id = Column(Integer, primary_key=True) + + criterion = Column(Enum(ACMGCriterion, native_enum=False, validate_strings=True, length=32), nullable=True) + evidence_strength = Column( + Enum(StrengthOfEvidenceProvided, native_enum=False, validate_strings=True, length=32), nullable=True + ) + points = Column(Integer, nullable=True) + + creation_date = Column(Date, nullable=False, default=date.today) + modification_date = Column(Date, nullable=False, default=date.today, onupdate=date.today) diff --git a/src/mavedb/models/enums/acmg_criterion.py b/src/mavedb/models/enums/acmg_criterion.py new file mode 100644 index 00000000..1c5435eb --- /dev/null +++ b/src/mavedb/models/enums/acmg_criterion.py @@ -0,0 +1,44 @@ +import enum + + +class ACMGCriterion(enum.Enum): + """Enum for ACMG criteria codes.""" + + PVS1 = "PVS1" + PS1 = "PS1" + PS2 = "PS2" + PS3 = "PS3" + PS4 = "PS4" + PM1 = "PM1" + PM2 = "PM2" + PM3 = "PM3" + PM4 = "PM4" + PM5 = "PM5" + PM6 = "PM6" + PP1 = "PP1" + PP2 = "PP2" + PP3 = "PP3" + PP4 = "PP4" + PP5 = "PP5" + BA1 = "BA1" + BS1 = "BS1" + BS2 = "BS2" + BS3 = "BS3" + BS4 = "BS4" + BP1 = "BP1" + BP2 = "BP2" + BP3 = "BP3" + BP4 = "BP4" + BP5 = "BP5" + BP6 = "BP6" + BP7 = "BP7" + + @property + def is_pathogenic(self) -> bool: + """Return True if the criterion is pathogenic, False if benign.""" + return self.name.startswith("P") # PVS, PS, PM, PP are pathogenic criteria + + @property + def is_benign(self) -> bool: + """Return True if the criterion is benign, False if pathogenic.""" + return self.name.startswith("B") # BA, BS, BP are benign criteria diff --git a/src/mavedb/models/enums/functional_classification.py b/src/mavedb/models/enums/functional_classification.py new file mode 100644 index 00000000..2a472a65 --- /dev/null +++ 
b/src/mavedb/models/enums/functional_classification.py @@ -0,0 +1,7 @@ +import enum + + +class FunctionalClassification(enum.Enum): + normal = "normal" + abnormal = "abnormal" + not_specified = "not_specified" diff --git a/src/mavedb/models/enums/strength_of_evidence.py b/src/mavedb/models/enums/strength_of_evidence.py new file mode 100644 index 00000000..58c3c26d --- /dev/null +++ b/src/mavedb/models/enums/strength_of_evidence.py @@ -0,0 +1,11 @@ +import enum + + +class StrengthOfEvidenceProvided(enum.Enum): + """Enum for strength of evidence provided.""" + + VERY_STRONG = "VERY_STRONG" + STRONG = "STRONG" + MODERATE_PLUS = "MODERATE_PLUS" + MODERATE = "MODERATE" + SUPPORTING = "SUPPORTING" diff --git a/src/mavedb/models/experiment.py b/src/mavedb/models/experiment.py index 846ab00a..22014a59 100644 --- a/src/mavedb/models/experiment.py +++ b/src/mavedb/models/experiment.py @@ -73,6 +73,7 @@ class Experiment(Base): abstract_text = Column(String, nullable=False) method_text = Column(String, nullable=False) extra_metadata = Column(JSONB, nullable=False) + external_links = Column(JSONB, nullable=False, default={}) private = Column(Boolean, nullable=False, default=True) approved = Column(Boolean, nullable=False, default=False) diff --git a/src/mavedb/models/score_calibration.py b/src/mavedb/models/score_calibration.py index 988d4d04..38ce1f28 100644 --- a/src/mavedb/models/score_calibration.py +++ b/src/mavedb/models/score_calibration.py @@ -12,6 +12,7 @@ from mavedb.db.base import Base from mavedb.lib.urns import generate_calibration_urn +from mavedb.models.score_calibration_functional_classification import ScoreCalibrationFunctionalClassification from mavedb.models.score_calibration_publication_identifier import ScoreCalibrationPublicationIdentifierAssociation if TYPE_CHECKING: @@ -33,16 +34,18 @@ class ScoreCalibration(Base): title = Column(String, nullable=False) research_use_only = Column(Boolean, nullable=False, default=False) primary = Column(Boolean, nullable=False, default=False) - investigator_provided = Column(Boolean, nullable=False, default=False) + investigator_provided: Mapped[bool] = Column(Boolean, nullable=False, default=False) private = Column(Boolean, nullable=False, default=True) notes = Column(String, nullable=True) baseline_score = Column(Float, nullable=True) baseline_score_description = Column(String, nullable=True) - # Ranges and sources are stored as JSONB (intersection structure) to avoid complex joins for now. 
- # ranges: list[ { label, description?, classification, range:[lower,upper], inclusive_lower_bound, inclusive_upper_bound } ] - functional_ranges = Column(JSONB(none_as_null=True), nullable=True) + functional_classifications: Mapped[list["ScoreCalibrationFunctionalClassification"]] = relationship( + "ScoreCalibrationFunctionalClassification", + back_populates="calibration", + cascade="all, delete-orphan", + ) publication_identifier_associations: Mapped[list[ScoreCalibrationPublicationIdentifierAssociation]] = relationship( "ScoreCalibrationPublicationIdentifierAssociation", diff --git a/src/mavedb/models/score_calibration_functional_classification.py b/src/mavedb/models/score_calibration_functional_classification.py new file mode 100644 index 00000000..1975310a --- /dev/null +++ b/src/mavedb/models/score_calibration_functional_classification.py @@ -0,0 +1,81 @@ +"""SQLAlchemy model for variant score calibration functional classifications.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sqlalchemy import Boolean, Column, Enum, Float, ForeignKey, Integer, String +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, relationship + +from mavedb.db.base import Base +from mavedb.lib.validation.utilities import inf_or_float +from mavedb.models.acmg_classification import ACMGClassification +from mavedb.models.enums.functional_classification import FunctionalClassification as FunctionalClassificationOptions +from mavedb.models.score_calibration_functional_classification_variant_association import ( + score_calibration_functional_classification_variants_association_table, +) + +if TYPE_CHECKING: + from mavedb.models.score_calibration import ScoreCalibration + from mavedb.models.variant import Variant + + +class ScoreCalibrationFunctionalClassification(Base): + __tablename__ = "score_calibration_functional_classifications" + + id = Column(Integer, primary_key=True) + + calibration_id = Column(Integer, ForeignKey("score_calibrations.id"), nullable=False) + calibration: Mapped["ScoreCalibration"] = relationship("ScoreCalibration", foreign_keys=[calibration_id]) + + label = Column(String, nullable=False) + description = Column(String, nullable=True) + + functional_classification = Column( + Enum(FunctionalClassificationOptions, native_enum=False, validate_strings=True, length=32), + nullable=False, + default=FunctionalClassificationOptions.not_specified, + ) + + range = Column(JSONB(none_as_null=True), nullable=True) # (lower_bound, upper_bound) + class_ = Column(String, nullable=True) + + inclusive_lower_bound = Column(Boolean, nullable=True) + inclusive_upper_bound = Column(Boolean, nullable=True) + + oddspaths_ratio = Column(Float, nullable=True) + positive_likelihood_ratio = Column(Float, nullable=True) + + acmg_classification_id = Column(Integer, ForeignKey("acmg_classifications.id"), nullable=True) + acmg_classification: Mapped[ACMGClassification] = relationship( + "ACMGClassification", foreign_keys=[acmg_classification_id] + ) + + # Many-to-many relationship with variants + variants: Mapped[list["Variant"]] = relationship( + "Variant", + secondary=score_calibration_functional_classification_variants_association_table, + ) + + def score_is_contained_in_range(self, score: float) -> bool: + """Check if a given score falls within the defined range.""" + if self.range is None or not isinstance(self.range, list) or len(self.range) != 2: + return False + + lower_bound, upper_bound = inf_or_float(self.range[0], lower=True), 
inf_or_float(self.range[1], lower=False) + if self.inclusive_lower_bound: + if score < lower_bound: + return False + else: + if score <= lower_bound: + return False + + if self.inclusive_upper_bound: + if score > upper_bound: + return False + else: + if score >= upper_bound: + return False + + return True diff --git a/src/mavedb/models/score_calibration_functional_classification_variant_association.py b/src/mavedb/models/score_calibration_functional_classification_variant_association.py new file mode 100644 index 00000000..61f074bd --- /dev/null +++ b/src/mavedb/models/score_calibration_functional_classification_variant_association.py @@ -0,0 +1,14 @@ +"""SQLAlchemy association table for variants belonging to functional classifications.""" + +from sqlalchemy import Column, ForeignKey, Table + +from mavedb.db.base import Base + +score_calibration_functional_classification_variants_association_table = Table( + "score_calibration_functional_classification_variants", + Base.metadata, + Column( + "functional_classification_id", ForeignKey("score_calibration_functional_classifications.id"), primary_key=True + ), + Column("variant_id", ForeignKey("variants.id"), primary_key=True), +) diff --git a/src/mavedb/models/variant.py b/src/mavedb/models/variant.py index b038c1ea..59b6e729 100644 --- a/src/mavedb/models/variant.py +++ b/src/mavedb/models/variant.py @@ -34,3 +34,6 @@ class Variant(Base): mapped_variants: Mapped[List["MappedVariant"]] = relationship( back_populates="variant", cascade="all, delete-orphan" ) + + # Bidirectional relationship with ScoreCalibrationFunctionalClassification is left + # purposefully undefined for performance reasons. diff --git a/src/mavedb/routers/access_keys.py b/src/mavedb/routers/access_keys.py index c584dcb2..275fc437 100644 --- a/src/mavedb/routers/access_keys.py +++ b/src/mavedb/routers/access_keys.py @@ -12,10 +12,10 @@ from sqlalchemy.orm import Session from mavedb import deps -from mavedb.lib.authentication import UserData from mavedb.lib.authorization import require_current_user from mavedb.lib.logging import LoggedRoute from mavedb.lib.logging.context import logging_context, save_to_logging_context +from mavedb.lib.types.authentication import UserData from mavedb.models.access_key import AccessKey from mavedb.models.enums.user_role import UserRole from mavedb.routers.shared import ACCESS_CONTROL_ERROR_RESPONSES, PUBLIC_ERROR_RESPONSES, ROUTER_BASE_PREFIX diff --git a/src/mavedb/routers/collections.py b/src/mavedb/routers/collections.py index cf215a69..48f38964 100644 --- a/src/mavedb/routers/collections.py +++ b/src/mavedb/routers/collections.py @@ -9,7 +9,7 @@ from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound from mavedb import deps -from mavedb.lib.authentication import UserData, get_current_user +from mavedb.lib.authentication import get_current_user from mavedb.lib.authorization import require_current_user, require_current_user_with_email from mavedb.lib.logging import LoggedRoute from mavedb.lib.logging.context import ( @@ -18,6 +18,7 @@ save_to_logging_context, ) from mavedb.lib.permissions import Action, assert_permission, has_permission +from mavedb.lib.types.authentication import UserData from mavedb.models.collection import Collection from mavedb.models.collection_user_association import CollectionUserAssociation from mavedb.models.enums.contribution_role import ContributionRole diff --git a/src/mavedb/routers/experiment_sets.py b/src/mavedb/routers/experiment_sets.py index 386da37b..6bc5214c 100644 --- 
a/src/mavedb/routers/experiment_sets.py +++ b/src/mavedb/routers/experiment_sets.py @@ -6,11 +6,12 @@ from sqlalchemy.orm import Session from mavedb import deps -from mavedb.lib.authentication import UserData, get_current_user +from mavedb.lib.authentication import get_current_user from mavedb.lib.experiments import enrich_experiment_with_num_score_sets from mavedb.lib.logging import LoggedRoute from mavedb.lib.logging.context import logging_context, save_to_logging_context from mavedb.lib.permissions import Action, assert_permission, has_permission +from mavedb.lib.types.authentication import UserData from mavedb.models.experiment_set import ExperimentSet from mavedb.routers.shared import ACCESS_CONTROL_ERROR_RESPONSES, PUBLIC_ERROR_RESPONSES, ROUTER_BASE_PREFIX from mavedb.view_models import experiment_set @@ -57,7 +58,7 @@ def fetch_experiment_set( # the exception is raised, not returned - you will get a validation # error otherwise. logger.debug(msg="The requested resources does not exist.", extra=logging_context()) - raise HTTPException(status_code=404, detail=f"Experiment set with URN {urn} not found") + raise HTTPException(status_code=404, detail=f"experiment set with URN {urn} not found") else: item.experiments.sort(key=attrgetter("urn")) diff --git a/src/mavedb/routers/experiments.py b/src/mavedb/routers/experiments.py index 5d37ecb3..2777f1f6 100644 --- a/src/mavedb/routers/experiments.py +++ b/src/mavedb/routers/experiments.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from mavedb import deps -from mavedb.lib.authentication import UserData, get_current_user +from mavedb.lib.authentication import get_current_user from mavedb.lib.authorization import require_current_user, require_current_user_with_email from mavedb.lib.contributors import find_or_create_contributor from mavedb.lib.exceptions import NonexistentOrcidUserError @@ -25,6 +25,7 @@ from mavedb.lib.logging.context import logging_context, save_to_logging_context from mavedb.lib.permissions import Action, assert_permission, has_permission from mavedb.lib.score_sets import find_superseded_score_set_tail +from mavedb.lib.types.authentication import UserData from mavedb.lib.validation.exceptions import ValidationError from mavedb.lib.validation.keywords import validate_keyword_list from mavedb.models.contributor import Contributor @@ -155,7 +156,7 @@ def fetch_experiment( if not item: logger.debug(msg="The requested experiment does not exist.", extra=logging_context()) - raise HTTPException(status_code=404, detail=f"Experiment with URN {urn} not found") + raise HTTPException(status_code=404, detail=f"experiment with URN {urn} not found") assert_permission(user_data, item, Action.READ) return enrich_experiment_with_num_score_sets(item, user_data) @@ -459,6 +460,7 @@ async def update_experiment( item.raw_read_identifiers = raw_read_identifiers if item_update.keywords: + keywords: list[ExperimentControlledKeywordAssociation] = [] all_labels_none = all(k.keyword.label is None for k in item_update.keywords) if all_labels_none is False: # Users may choose part of keywords from dropdown menu. Remove not chosen keywords from the list. 
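Referring back to `ScoreCalibrationFunctionalClassification.score_is_contained_in_range` a few hunks above: the containment check combines optional bounds with per-bound inclusivity flags. The sketch below restates that rule in isolation so the edge cases are easy to test; the function name is hypothetical, and it assumes (as the model's usage suggests) that `inf_or_float` maps a missing bound to negative or positive infinity.

```python
import math


def contained(score: float, rng, inclusive_lower: bool, inclusive_upper: bool) -> bool:
    """Toy restatement of score_is_contained_in_range with explicit bound handling."""
    if rng is None or len(rng) != 2:
        return False
    # Missing bounds are treated as unbounded, matching how inf_or_float is used above.
    lower = -math.inf if rng[0] is None else float(rng[0])
    upper = math.inf if rng[1] is None else float(rng[1])
    below = score < lower if inclusive_lower else score <= lower
    if below:
        return False
    above = score > upper if inclusive_upper else score >= upper
    return not above


# A [0.5, None) range with an inclusive lower bound: 0.5 is in, 0.49 is out.
print(contained(0.5, [0.5, None], inclusive_lower=True, inclusive_upper=False))   # True
print(contained(0.49, [0.5, None], inclusive_lower=True, inclusive_upper=False))  # False
```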
@@ -467,10 +469,18 @@ async def update_experiment( validate_keyword_list(filtered_keywords) except ValidationError as e: raise HTTPException(status_code=422, detail=str(e)) - try: - await item.set_keywords(db, filtered_keywords) - except Exception as e: - raise HTTPException(status_code=500, detail=f"Invalid keywords: {str(e)}") + for upload_keyword in filtered_keywords: + try: + description = upload_keyword.description + controlled_keyword = search_keyword(db, upload_keyword.keyword.key, upload_keyword.keyword.label) + experiment_controlled_keyword = ExperimentControlledKeywordAssociation( + controlled_keyword=controlled_keyword, + description=description, + ) + keywords.append(experiment_controlled_keyword) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + item.keyword_objs = keywords item.modified_by = user_data.user diff --git a/src/mavedb/routers/mapped_variant.py b/src/mavedb/routers/mapped_variant.py index 5657fd3a..52830f92 100644 --- a/src/mavedb/routers/mapped_variant.py +++ b/src/mavedb/routers/mapped_variant.py @@ -17,7 +17,6 @@ variant_study_result, ) from mavedb.lib.annotation.exceptions import MappingDataDoesntExistException -from mavedb.lib.authentication import UserData from mavedb.lib.authorization import get_current_user from mavedb.lib.logging import LoggedRoute from mavedb.lib.logging.context import ( @@ -25,6 +24,7 @@ save_to_logging_context, ) from mavedb.lib.permissions import Action, assert_permission, has_permission +from mavedb.lib.types.authentication import UserData from mavedb.models.mapped_variant import MappedVariant from mavedb.models.variant import Variant from mavedb.routers.shared import ACCESS_CONTROL_ERROR_RESPONSES, PUBLIC_ERROR_RESPONSES, ROUTER_BASE_PREFIX diff --git a/src/mavedb/routers/permissions.py b/src/mavedb/routers/permissions.py index c100cfa2..39833ec7 100644 --- a/src/mavedb/routers/permissions.py +++ b/src/mavedb/routers/permissions.py @@ -6,10 +6,11 @@ from sqlalchemy.orm import Session from mavedb import deps -from mavedb.lib.authentication import UserData, get_current_user +from mavedb.lib.authentication import get_current_user from mavedb.lib.logging import LoggedRoute from mavedb.lib.logging.context import logging_context, save_to_logging_context from mavedb.lib.permissions import Action, has_permission +from mavedb.lib.types.authentication import UserData from mavedb.models.collection import Collection from mavedb.models.experiment import Experiment from mavedb.models.experiment_set import ExperimentSet diff --git a/src/mavedb/routers/score_calibrations.py b/src/mavedb/routers/score_calibrations.py index daac1950..8a413677 100644 --- a/src/mavedb/routers/score_calibrations.py +++ b/src/mavedb/routers/score_calibrations.py @@ -1,40 +1,57 @@ import logging - -from fastapi import APIRouter, Depends, HTTPException, Query from typing import Optional -from sqlalchemy.orm import Session + +from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile +from sqlalchemy.orm import Session, selectinload from mavedb import deps +from mavedb.lib.authentication import get_current_user +from mavedb.lib.authorization import require_current_user +from mavedb.lib.flexible_model_loader import json_or_form_loader from mavedb.lib.logging import LoggedRoute from mavedb.lib.logging.context import ( logging_context, save_to_logging_context, ) -from mavedb.lib.authentication import get_current_user, UserData -from mavedb.lib.authorization import require_current_user from mavedb.lib.permissions import Action, 
assert_permission, has_permission from mavedb.lib.score_calibrations import ( create_score_calibration_in_score_set, - modify_score_calibration, delete_score_calibration, demote_score_calibration_from_primary, + modify_score_calibration, promote_score_calibration_to_primary, publish_score_calibration, + variant_classification_df_to_dict, ) +from mavedb.lib.score_sets import csv_data_to_df +from mavedb.lib.types.authentication import UserData +from mavedb.lib.validation.constants.general import calibration_class_column_name, calibration_variant_column_name +from mavedb.lib.validation.dataframe.calibration import validate_and_standardize_calibration_classes_dataframe +from mavedb.lib.validation.exceptions import ValidationError from mavedb.models.score_calibration import ScoreCalibration -from mavedb.routers.score_sets import fetch_score_set_by_urn +from mavedb.models.score_set import ScoreSet from mavedb.view_models import score_calibration - logger = logging.getLogger(__name__) router = APIRouter( prefix="/api/v1/score-calibrations", - tags=["score-calibrations"], + tags=["Score Calibrations"], responses={404: {"description": "Not found"}}, route_class=LoggedRoute, ) +# Create dependency loaders for flexible JSON/form parsing +calibration_create_loader = json_or_form_loader( + score_calibration.ScoreCalibrationCreate, + field_name="calibration_json", +) + +calibration_modify_loader = json_or_form_loader( + score_calibration.ScoreCalibrationModify, + field_name="calibration_json", +) + @router.get( "/{urn}", @@ -52,7 +69,12 @@ def get_score_calibration( """ save_to_logging_context({"requested_resource": urn}) - item = db.query(ScoreCalibration).where(ScoreCalibration.urn == urn).one_or_none() + item = ( + db.query(ScoreCalibration) + .options(selectinload(ScoreCalibration.score_set).selectinload(ScoreSet.contributors)) + .where(ScoreCalibration.urn == urn) + .one_or_none() + ) if not item: logger.debug("The requested score calibration does not exist", extra=logging_context()) raise HTTPException(status_code=404, detail="The requested score calibration does not exist") @@ -76,12 +98,23 @@ async def get_score_calibrations_for_score_set( Retrieve all score calibrations for a given score set URN. """ save_to_logging_context({"requested_resource": score_set_urn, "resource_property": "calibrations"}) - score_set = await fetch_score_set_by_urn(db, score_set_urn, user_data, None, False) + score_set = db.query(ScoreSet).filter(ScoreSet.urn == score_set_urn).one_or_none() + + if not score_set: + logger.debug("ScoreSet not found", extra=logging_context()) + raise HTTPException(status_code=404, detail=f"score set with URN '{score_set_urn}' not found") + + assert_permission(user_data, score_set, Action.READ) + + calibrations = ( + db.query(ScoreCalibration) + .filter(ScoreCalibration.score_set_id == score_set.id) + .options(selectinload(ScoreCalibration.score_set).selectinload(ScoreSet.contributors)) + .all() + ) permitted_calibrations = [ - calibration - for calibration in score_set.score_calibrations - if has_permission(user_data, calibration, Action.READ).permitted + calibration for calibration in calibrations if has_permission(user_data, calibration, Action.READ).permitted ] if not permitted_calibrations: logger.debug("No score calibrations found for the requested score set", extra=logging_context()) @@ -105,12 +138,23 @@ async def get_primary_score_calibrations_for_score_set( Retrieve the primary score calibration for a given score set URN. 
""" save_to_logging_context({"requested_resource": score_set_urn, "resource_property": "calibrations"}) - score_set = await fetch_score_set_by_urn(db, score_set_urn, user_data, None, False) + + score_set = db.query(ScoreSet).filter(ScoreSet.urn == score_set_urn).one_or_none() + if not score_set: + logger.debug("ScoreSet not found", extra=logging_context()) + raise HTTPException(status_code=404, detail=f"score set with URN '{score_set_urn}' not found") + + assert_permission(user_data, score_set, Action.READ) + + calibrations = ( + db.query(ScoreCalibration) + .filter(ScoreCalibration.score_set_id == score_set.id) + .options(selectinload(ScoreCalibration.score_set).selectinload(ScoreSet.contributors)) + .all() + ) permitted_calibrations = [ - calibration - for calibration in score_set.score_calibrations - if has_permission(user_data, calibration, Action.READ) + calibration for calibration in calibrations if has_permission(user_data, calibration, Action.READ).permitted ] if not permitted_calibrations: logger.debug("No score calibrations found for the requested score set", extra=logging_context()) @@ -136,31 +180,144 @@ async def get_primary_score_calibrations_for_score_set( @router.post( "/", response_model=score_calibration.ScoreCalibrationWithScoreSetUrn, - responses={404: {}}, + responses={404: {}, 422: {"description": "Validation Error"}}, + openapi_extra={ + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/ScoreCalibrationCreate"}, + }, + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "calibration_json": { + "type": "string", + "description": "JSON string containing the calibration data", + "example": '{"score_set_urn":"urn:mavedb:0000000X-X-X","title":"My Calibration","description":"Functional score calibration","baseline_score":1.0}', + }, + "classes_file": { + "type": "string", + "format": "binary", + "description": "CSV file containing variant classifications", + }, + }, + } + }, + }, + "description": "Score calibration data. Can be sent as JSON body or multipart form data", + } + }, ) async def create_score_calibration_route( *, - calibration: score_calibration.ScoreCalibrationCreate, + calibration: score_calibration.ScoreCalibrationCreate = Depends(calibration_create_loader), + classes_file: Optional[UploadFile] = File( + None, + description=f"CSV file containing variant classifications. This file must contain two columns: '{calibration_variant_column_name}' and '{calibration_class_column_name}'.", + ), db: Session = Depends(deps.get_db), user_data: UserData = Depends(require_current_user), ) -> ScoreCalibration: """ Create a new score calibration. - The score set URN must be provided to associate the calibration with an existing score set. - The user must have write permission on the associated score set. + This endpoint supports two different request formats to accommodate various client needs: + + ## Method 1: JSON Request Body (application/json) + Send calibration data as a standard JSON request body. This method is ideal for + creating calibrations without file uploads. + + **Content-Type**: `application/json` + + **Example**: + ```json + { + "score_set_urn": "urn:mavedb:0000000X-X-X", + "title": "My Calibration", + "description": "Functional score calibration", + "baseline_score": 1.0 + } + ``` + + ## Method 2: Multipart Form Data (multipart/form-data) + Send calibration data as JSON in a form field, optionally with file uploads. + This method is required when uploading classification files. 
+ + **Content-Type**: `multipart/form-data` + + **Form Fields**: + - `calibration_json` (string, required): JSON string containing the calibration data + - `classes_file` (file, optional): CSV file containing variant classifications + + **Example**: + ```bash + curl -X POST "/api/v1/score-calibrations/" \\ + -H "Authorization: Bearer your-token" \\ + -F 'calibration_json={"score_set_urn":"urn:mavedb:0000000X-X-X","title":"My Calibration","description":"Functional score calibration","baseline_score":"1.0"}' \\ + -F 'classes_file=@variant_classes.csv' + ``` + + ## Requirements + - The score set URN must be provided to associate the calibration with an existing score set + - User must have write permission on the associated score set + - If uploading a classes_file, it must be a valid CSV with variant classification data + + ## File Upload Details + The `classes_file` parameter accepts CSV files containing variant classification data. + The file should have appropriate headers and contain columns for variant urns and class names. + + ## Response + Returns the created score calibration with its generated URN and associated score set information. """ if not calibration.score_set_urn: raise HTTPException(status_code=422, detail="score_set_urn must be provided to create a score calibration.") save_to_logging_context({"requested_resource": calibration.score_set_urn, "resource_property": "calibrations"}) - score_set = await fetch_score_set_by_urn(db, calibration.score_set_urn, user_data, None, False) + score_set = db.query(ScoreSet).filter(ScoreSet.urn == calibration.score_set_urn).one_or_none() + if not score_set: + logger.debug("ScoreSet not found", extra=logging_context()) + raise HTTPException(status_code=404, detail=f"score set with URN '{calibration.score_set_urn}' not found") + # TODO#539: Allow any authenticated user to upload a score calibration for a score set, not just those with # permission to update the score set itself. assert_permission(user_data, score_set, Action.UPDATE) - created_calibration = await create_score_calibration_in_score_set(db, calibration, user_data.user) + if calibration.class_based and not classes_file: + raise HTTPException( + status_code=422, + detail="A classes_file must be provided when creating a class-based calibration.", + ) + + if classes_file: + if calibration.range_based: + raise HTTPException( + status_code=422, + detail="A classes_file should not be provided when creating a range-based calibration.", + ) + + try: + classes_df = csv_data_to_df(classes_file.file, induce_hgvs_cols=False) + except UnicodeDecodeError as e: + raise HTTPException( + status_code=400, detail=f"Error decoding file: {e}. Ensure the file has correct values." 
+ ) + + try: + standardized_classes_df, index_column = validate_and_standardize_calibration_classes_dataframe( + db, score_set, calibration, classes_df + ) + variant_classes = variant_classification_df_to_dict(standardized_classes_df, index_column) + except ValidationError as e: + raise HTTPException( + status_code=422, + detail=[{"loc": [e.custom_loc or "classesFile"], "msg": str(e), "type": "value_error"}], + ) + + created_calibration = await create_score_calibration_in_score_set( + db, calibration, user_data.user, variant_classes if classes_file else None + ) db.commit() db.refresh(created_calibration) @@ -171,36 +328,165 @@ async def create_score_calibration_route( @router.put( "/{urn}", response_model=score_calibration.ScoreCalibrationWithScoreSetUrn, - responses={404: {}}, + responses={404: {}, 422: {"description": "Validation Error"}}, + openapi_extra={ + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/ScoreCalibrationModify"}, + }, + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "calibration_json": { + "type": "string", + "description": "JSON string containing the calibration update data", + "example": '{"title":"Updated Calibration","description":"Updated description","baseline_score":2.0}', + }, + "classes_file": { + "type": "string", + "format": "binary", + "description": "CSV file containing updated variant classifications", + }, + }, + } + }, + }, + "description": "Score calibration update data. Can be sent as JSON body or multipart form data", + } + }, ) async def modify_score_calibration_route( *, urn: str, - calibration_update: score_calibration.ScoreCalibrationModify, + calibration_update: score_calibration.ScoreCalibrationModify = Depends(calibration_modify_loader), + classes_file: Optional[UploadFile] = File( + None, + description=f"CSV file containing variant classifications. This file must contain two columns: '{calibration_variant_column_name}' and '{calibration_class_column_name}'.", + ), db: Session = Depends(deps.get_db), user_data: UserData = Depends(require_current_user), ) -> ScoreCalibration: """ Modify an existing score calibration by its URN. + + This endpoint supports two different request formats to accommodate various client needs: + + ## Method 1: JSON Request Body (application/json) + Send calibration update data as a standard JSON request body. This method is ideal for + modifying calibrations without file uploads. + + **Content-Type**: `application/json` + + **Example**: + ```json + { + "score_set_urn": "urn:mavedb:0000000X-X-X", + "title": "Updated Calibration Title", + "description": "Updated functional score calibration", + "baseline_score": 1.0 + } + ``` + + ## Method 2: Multipart Form Data (multipart/form-data) + Send calibration update data as JSON in a form field, optionally with file uploads. + This method is required when uploading new classification files. 
+ + **Content-Type**: `multipart/form-data` + + **Form Fields**: + - `calibration_json` (string, required): JSON string containing the calibration update data + - `classes_file` (file, optional): CSV file containing updated variant classifications + + **Example**: + ```bash + curl -X PUT "/api/v1/score-calibrations/{urn}" \\ + -H "Authorization: Bearer your-token" \\ + -F 'calibration_json={"score_set_urn":"urn:mavedb:0000000X-X-X","title":"My Calibration","description":"Functional score calibration","baseline_score":"1.0"}' \\ + -F 'classes_file=@updated_variant_classes.csv' + ``` + + ## Requirements + - User must have update permission on the calibration + - If changing the score_set_urn, user must have permission on the new score set + - All fields in the update are optional - only provided fields will be modified + + ## File Upload Details + The `classes_file` parameter accepts CSV files containing updated variant classification data. + If provided, this will replace the existing classification data for the calibration. + The file should have appropriate headers and follow the expected format for variant + classifications within the associated score set. + + ## Response + Returns the updated score calibration with all modifications applied and any new + classification data from the uploaded file. """ save_to_logging_context({"requested_resource": urn}) # If the user supplies a new score_set_urn, validate it exists and the user has permission to use it. if calibration_update.score_set_urn is not None: - score_set = await fetch_score_set_by_urn(db, calibration_update.score_set_urn, user_data, None, False) + score_set_update = db.query(ScoreSet).filter(ScoreSet.urn == calibration_update.score_set_urn).one_or_none() + + if not score_set_update: + logger.debug("ScoreSet not found", extra=logging_context()) + raise HTTPException( + status_code=404, detail=f"score set with URN '{calibration_update.score_set_urn}' not found" + ) # TODO#539: Allow any authenticated user to upload a score calibration for a score set, not just those with # permission to update the score set itself. - assert_permission(user_data, score_set, Action.UPDATE) - - item = db.query(ScoreCalibration).where(ScoreCalibration.urn == urn).one_or_none() + assert_permission(user_data, score_set_update, Action.UPDATE) + else: + score_set_update = None + + item = ( + db.query(ScoreCalibration) + .options(selectinload(ScoreCalibration.score_set).selectinload(ScoreSet.contributors)) + .where(ScoreCalibration.urn == urn) + .one_or_none() + ) if not item: logger.debug("The requested score calibration does not exist", extra=logging_context()) raise HTTPException(status_code=404, detail="The requested score calibration does not exist") assert_permission(user_data, item, Action.UPDATE) + score_set = score_set_update or item.score_set + + if calibration_update.class_based and not classes_file: + raise HTTPException( + status_code=422, + detail="A classes_file must be provided when modifying a class-based calibration.", + ) - updated_calibration = await modify_score_calibration(db, item, calibration_update, user_data.user) + if classes_file: + if calibration_update.range_based: + raise HTTPException( + status_code=422, + detail="A classes_file should not be provided when modifying a range-based calibration.", + ) + + try: + classes_df = csv_data_to_df(classes_file.file, induce_hgvs_cols=False) + except UnicodeDecodeError as e: + raise HTTPException( + status_code=400, detail=f"Error decoding file: {e}. Ensure the file has correct values." 
+ ) + + try: + standardized_classes_df, index_column = validate_and_standardize_calibration_classes_dataframe( + db, score_set, calibration_update, classes_df + ) + variant_classes = variant_classification_df_to_dict(standardized_classes_df, index_column) + except ValidationError as e: + raise HTTPException( + status_code=422, + detail=[{"loc": [e.custom_loc or "classesFile"], "msg": str(e), "type": "value_error"}], + ) + + updated_calibration = await modify_score_calibration( + db, item, calibration_update, user_data.user, variant_classes if classes_file else None + ) db.commit() db.refresh(updated_calibration) @@ -225,7 +511,12 @@ async def delete_score_calibration_route( """ save_to_logging_context({"requested_resource": urn}) - item = db.query(ScoreCalibration).where(ScoreCalibration.urn == urn).one_or_none() + item = ( + db.query(ScoreCalibration) + .options(selectinload(ScoreCalibration.score_set).selectinload(ScoreSet.contributors)) + .where(ScoreCalibration.urn == urn) + .one_or_none() + ) if not item: logger.debug("The requested score calibration does not exist", extra=logging_context()) raise HTTPException(status_code=404, detail="The requested score calibration does not exist") @@ -259,7 +550,12 @@ async def promote_score_calibration_to_primary_route( {"requested_resource": urn, "resource_property": "primary", "demote_existing_primary": demote_existing_primary} ) - item = db.query(ScoreCalibration).where(ScoreCalibration.urn == urn).one_or_none() + item = ( + db.query(ScoreCalibration) + .options(selectinload(ScoreCalibration.score_set).selectinload(ScoreSet.contributors)) + .where(ScoreCalibration.urn == urn) + .one_or_none() + ) if not item: logger.debug("The requested score calibration does not exist", extra=logging_context()) raise HTTPException(status_code=404, detail="The requested score calibration does not exist") @@ -318,7 +614,12 @@ def demote_score_calibration_from_primary_route( """ save_to_logging_context({"requested_resource": urn, "resource_property": "primary"}) - item = db.query(ScoreCalibration).where(ScoreCalibration.urn == urn).one_or_none() + item = ( + db.query(ScoreCalibration) + .options(selectinload(ScoreCalibration.score_set).selectinload(ScoreSet.contributors)) + .where(ScoreCalibration.urn == urn) + .one_or_none() + ) if not item: logger.debug("The requested score calibration does not exist", extra=logging_context()) raise HTTPException(status_code=404, detail="The requested score calibration does not exist") @@ -352,7 +653,12 @@ def publish_score_calibration_route( """ save_to_logging_context({"requested_resource": urn, "resource_property": "private"}) - item = db.query(ScoreCalibration).where(ScoreCalibration.urn == urn).one_or_none() + item = ( + db.query(ScoreCalibration) + .options(selectinload(ScoreCalibration.score_set).selectinload(ScoreSet.contributors)) + .where(ScoreCalibration.urn == urn) + .one_or_none() + ) if not item: logger.debug("The requested score calibration does not exist", extra=logging_context()) raise HTTPException(status_code=404, detail="The requested score calibration does not exist") diff --git a/src/mavedb/routers/score_sets.py b/src/mavedb/routers/score_sets.py index 959f9133..694860d2 100644 --- a/src/mavedb/routers/score_sets.py +++ b/src/mavedb/routers/score_sets.py @@ -26,7 +26,6 @@ variant_study_result, ) from mavedb.lib.annotation.exceptions import MappingDataDoesntExistException -from mavedb.lib.authentication import UserData from mavedb.lib.authorization import ( get_current_user, require_current_user, @@ -61,6 
+60,7 @@ ) from mavedb.lib.target_genes import find_or_create_target_gene_by_accession, find_or_create_target_gene_by_sequence from mavedb.lib.taxonomies import find_or_create_taxonomy +from mavedb.lib.types.authentication import UserData from mavedb.lib.urns import ( generate_experiment_set_urn, generate_experiment_urn, @@ -598,8 +598,9 @@ def search_score_sets( def get_filter_options_for_search( search: ScoreSetsSearch, db: Session = Depends(deps.get_db), + user_data: Optional[UserData] = Depends(get_current_user), ) -> Any: - return fetch_score_set_search_filter_options(db, None, search) + return fetch_score_set_search_filter_options(db, user_data, None, search) @router.get( @@ -706,8 +707,8 @@ def get_score_set_variants_csv( urn: str, start: int = Query(default=None, description="Start index for pagination"), limit: int = Query(default=None, description="Maximum number of variants to return"), - namespaces: List[Literal["scores", "counts", "vep", "gnomad"]] = Query( - default=["scores"], description="One or more data types to include: scores, counts, clinVar, gnomAD, VEP" + namespaces: List[Literal["scores", "counts", "vep", "gnomad", "clingen"]] = Query( + default=["scores"], description="One or more data types to include: scores, counts, ClinGen, gnomAD, VEP" ), drop_na_columns: Optional[bool] = None, include_custom_columns: Optional[bool] = None, @@ -732,7 +733,7 @@ def get_score_set_variants_csv( The index to start from. If None, starts from the beginning. limit : Optional[int] The maximum number of variants to return. If None, returns all variants. - namespaces: List[Literal["scores", "counts", "vep", "gnomad"]] + namespaces: List[Literal["scores", "counts", "vep", "gnomad", "clingen"]] The namespaces of all columns except for accession, hgvs_nt, hgvs_pro, and hgvs_splice. We may add ClinVar in the future. drop_na_columns : bool, optional @@ -1551,7 +1552,20 @@ async def create_score_set( score_calibrations: list[ScoreCalibration] = [] if item_create.score_calibrations: for calibration_create in item_create.score_calibrations: - created_calibration_item = await create_score_calibration(db, calibration_create, user_data.user) + # TODO#592: Support for class-based calibrations on score set creation + if calibration_create.class_based: + logger.info( + msg="Failed to create score set; Class-based calibrations are not supported on score set creation.", + extra=logging_context(), + ) + raise HTTPException( + status_code=409, + detail="Class-based calibrations are not supported on score set creation. Please create class-based calibrations after creating the score set.", + ) + + created_calibration_item = await create_score_calibration( + db, calibration_create, user_data.user, variant_classes=None + ) created_calibration_item.investigator_provided = True # necessarily true on score set creation score_calibrations.append(created_calibration_item) @@ -1691,6 +1705,41 @@ async def create_score_set( response_model_exclude_none=True, responses={**BASE_400_RESPONSE, **ACCESS_CONTROL_ERROR_RESPONSES}, summary="Upload score and variant count files for a score set", + openapi_extra={ + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "scores_file": { + "type": "string", + "format": "binary", + "description": "CSV file containing variant scores. This file is required, and should have at least one score column.", + }, + "counts_file": { + "type": "string", + "format": "binary", + "description": "CSV file containing variant counts. 
If provided, this file should have the same index and variant columns as the scores file.", + }, + "score_columns_metadata": { + "type": "string", + "format": "binary", + "description": "JSON file containing metadata for score columns. If provided, this file should have metadata for one or more score columns in the scores file. This JSON file should provide a dictionary mapping column names to metadata objects. Metadata objects should follow the DatasetColumnMetadata schema: `{'description': string, 'details': string}`.", + }, + "count_columns_metadata": { + "type": "string", + "format": "binary", + "description": "JSON file containing metadata for count columns. If provided, this file should have metadata for one or more count columns in the counts file. This JSON file should provide a dictionary mapping column names to metadata objects. Metadata objects should follow the DatasetColumnMetadata schema: `{'description': string, 'details': string}`.", + }, + }, + "required": ["scores_file"], + } + }, + }, + "description": "Score files, to be uploaded as multipart form data. The `scores_file` is required, while the `counts_file`, `score_columns_metadata`, and `count_columns_metadata` are optional.", + } + }, ) async def upload_score_set_variant_data( *, @@ -1763,6 +1812,41 @@ async def upload_score_set_variant_data( response_model_exclude_none=True, responses={**BASE_400_RESPONSE, **ACCESS_CONTROL_ERROR_RESPONSES}, summary="Update score ranges / calibrations for a score set", + openapi_extra={ + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + **score_set.ScoreSetUpdateAllOptional.model_json_schema(by_alias=False)["properties"], + "scores_file": { + "type": "string", + "format": "binary", + "description": "CSV file containing variant scores. If provided, this file should have at least one score column.", + }, + "counts_file": { + "type": "string", + "format": "binary", + "description": "CSV file containing variant counts. If provided, this file should have the same index and variant columns as the scores file.", + }, + "score_columns_metadata": { + "type": "string", + "format": "binary", + "description": "JSON file containing metadata for score columns. If provided, this file should have metadata for one or more score columns in the scores file. This JSON file should provide a dictionary mapping column names to metadata objects. Metadata objects should follow the DatasetColumnMetadata schema: `{'description': string, 'details': string}`.", + }, + "count_columns_metadata": { + "type": "string", + "format": "binary", + "description": "JSON file containing metadata for count columns. If provided, this file should have metadata for one or more count columns in the counts file. This JSON file should provide a dictionary mapping column names to metadata objects. Metadata objects should follow the DatasetColumnMetadata schema: `{'description': string, 'details': string}`.", + }, + }, + } + }, + }, + "description": "Score set properties and score files, to be uploaded as multipart form data. All fields here are optional, and only those provided will be updated.", + } + }, ) async def update_score_set_with_variants( *, @@ -1780,6 +1864,13 @@ async def update_score_set_with_variants( """ logger.info(msg="Began score set with variants update.", extra=logging_context()) + # TODO#629: Use `flexible_model_loader` utility here to support both form data and JSON body. 
+ # See: https://github.com/VariantEffect/mavedb-api/pull/589/changes/d1641de7e4bee43e8a0c9f9283e022c5b56830ff + # Currently, only form data is supported, but this would allow us to also support JSON bodies + # in cases where no files are being uploaded. My view is that accepting score set calibration + # information via a single form field is also more straightforward than handling all the score + # set update fields as separate form fields and parsing them into an object. Doing so will also + # simplify the OpenAPI schema for this endpoint. try: # Get all form data from the request form_data = await request.form() diff --git a/src/mavedb/routers/target_genes.py b/src/mavedb/routers/target_genes.py index 29f91c5e..a304ee89 100644 --- a/src/mavedb/routers/target_genes.py +++ b/src/mavedb/routers/target_genes.py @@ -4,13 +4,14 @@ from sqlalchemy.orm import Session, selectinload from mavedb import deps -from mavedb.lib.authentication import UserData, get_current_user +from mavedb.lib.authentication import get_current_user from mavedb.lib.authorization import require_current_user from mavedb.lib.permissions import Action, has_permission from mavedb.lib.score_sets import find_superseded_score_set_tail from mavedb.lib.target_genes import ( search_target_genes as _search_target_genes, ) +from mavedb.lib.types.authentication import UserData from mavedb.models.score_set import ScoreSet from mavedb.models.target_gene import TargetGene from mavedb.routers.shared import ACCESS_CONTROL_ERROR_RESPONSES, PUBLIC_ERROR_RESPONSES, ROUTER_BASE_PREFIX diff --git a/src/mavedb/routers/users.py b/src/mavedb/routers/users.py index fd3a4d95..7aafc0ab 100644 --- a/src/mavedb/routers/users.py +++ b/src/mavedb/routers/users.py @@ -5,11 +5,11 @@ from starlette.convertors import Convertor, register_url_convertor from mavedb import deps -from mavedb.lib.authentication import UserData from mavedb.lib.authorization import RoleRequirer, require_current_user from mavedb.lib.logging import LoggedRoute from mavedb.lib.logging.context import logging_context, save_to_logging_context from mavedb.lib.permissions import Action, assert_permission +from mavedb.lib.types.authentication import UserData from mavedb.models.enums.user_role import UserRole from mavedb.models.user import User from mavedb.routers.shared import ACCESS_CONTROL_ERROR_RESPONSES, PUBLIC_ERROR_RESPONSES, ROUTER_BASE_PREFIX @@ -104,7 +104,7 @@ async def show_user_admin( msg="Could not show user; Requested user does not exist.", extra=logging_context(), ) - raise HTTPException(status_code=404, detail=f"User with ID {id} not found") + raise HTTPException(status_code=404, detail=f"User profile with ID {id} not found") # moving toward always accessing permissions module, even though this function does already require admin role to access assert_permission(user_data, item, Action.READ) @@ -135,7 +135,7 @@ async def show_user( msg="Could not show user; Requested user does not exist.", extra=logging_context(), ) - raise HTTPException(status_code=404, detail=f"User with ID {orcid_id} not found") + raise HTTPException(status_code=404, detail=f"User profile with ID {orcid_id} not found") # moving toward always accessing permissions module, even though this function does already require existing user in order to access assert_permission(user_data, item, Action.LOOKUP) @@ -217,7 +217,7 @@ async def update_user( msg="Could not update user; Requested user does not exist.", extra=logging_context(), ) - raise HTTPException(status_code=404, detail=f"User with id {id} not found.") +
raise HTTPException(status_code=404, detail=f"User profile with id {id} not found.") assert_permission(user_data, item, Action.UPDATE) assert_permission(user_data, item, Action.ADD_ROLE) diff --git a/src/mavedb/routers/variants.py b/src/mavedb/routers/variants.py index 4de1de1d..c195f903 100644 --- a/src/mavedb/routers/variants.py +++ b/src/mavedb/routers/variants.py @@ -10,10 +10,11 @@ from sqlalchemy.sql import or_ from mavedb import deps -from mavedb.lib.authentication import UserData, get_current_user +from mavedb.lib.authentication import get_current_user from mavedb.lib.logging import LoggedRoute from mavedb.lib.logging.context import logging_context, save_to_logging_context from mavedb.lib.permissions import Action, assert_permission, has_permission +from mavedb.lib.types.authentication import UserData from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant diff --git a/src/mavedb/scripts/calibrated_variant_effects.py b/src/mavedb/scripts/calibrated_variant_effects.py new file mode 100644 index 00000000..a8f47088 --- /dev/null +++ b/src/mavedb/scripts/calibrated_variant_effects.py @@ -0,0 +1,148 @@ +""" +Count unique variant effect measurements within ACMG-classified functional ranges. + +This script analyzes MaveDB score sets to count how many variant effect measurements +have functional scores that fall within score calibration ranges associated with +ACMG (American College of Medical Genetics and Genomics) classifications. The analysis provides +insights into how many variants can be clinically interpreted using established +evidence strength frameworks. + +Usage: + # Show help and available options + with_mavedb_local poetry run python3 -m mavedb.scripts.calibrated_variant_effects --help + + # Run in dry-run mode (default, no database changes, shows results) + with_mavedb_local poetry run python3 -m mavedb.scripts.calibrated_variant_effects --dry-run + + # Run and commit results (this script is read-only, so --commit changes nothing) + with_mavedb_local poetry run python3 -m mavedb.scripts.calibrated_variant_effects --commit + +Behavior: + 1. Queries all non-superseded score sets that have score calibrations + 2. Identifies calibrations with functional ranges that have ACMG classifications + 3. For each qualifying score set, queries its variants with non-null scores + 4. Counts variants whose scores fall within ACMG-classified ranges + 5. Reports statistics on classification coverage + +Key Filters: + - Excludes superseded score sets (where superseding_score_set is not None) + - Only processes score sets that have at least one score calibration + - Only considers functional ranges with ACMG classification data + - Only counts variants that have non-null functional scores + - Each variant is counted only once per score set, even if it matches multiple ranges + +ACMG Classification Detection: + A functional range is considered to have an ACMG classification if its + acmg_classification field contains any of: + - criterion (PS3, BS3, etc.)
+ - evidence_strength (Supporting, Moderate, Strong, Very Strong) + - points (numeric evidence points) + +Performance Notes: + - Uses optimized queries to avoid loading unnecessary data + - Loads score sets and calibrations first, then queries variants separately + - Filters variants at the database level for better performance + - Memory usage scales with the number of score sets with ACMG ranges + +Output: + - Progress updates for each score set with classified variants + - Summary statistics including: + * Number of score sets with ACMG classifications + * Total unique variants processed + * Number of variants within ACMG-classified ranges + * Overall classification rate percentage + +Caveats: + - This is a read-only analysis script (makes no database changes) + - Variants with null/missing scores are included in the analysis +""" + +import logging +from typing import Set + +import click +from sqlalchemy import select +from sqlalchemy.orm import Session, joinedload + +from mavedb.models.score_calibration_functional_classification import ScoreCalibrationFunctionalClassification +from mavedb.models.score_set import ScoreSet +from mavedb.scripts.environment import with_database_session + +logger = logging.getLogger(__name__) + + +@click.command() +@with_database_session +def main(db: Session) -> None: + """Count unique variant effect measurements with ACMG-classified functional ranges.""" + + query = ( + select(ScoreSet) + .options(joinedload(ScoreSet.score_calibrations)) + .where(ScoreSet.private.is_(False)) # Public score sets only + .where(ScoreSet.superseded_score_set_id.is_(None)) # Not superseded + .where(ScoreSet.score_calibrations.any()) # Has calibrations + ) + + score_sets = db.scalars(query).unique().all() + + total_variants_count = 0 + classified_variants_count = 0 + score_sets_with_acmg_count = 0 + gene_list: Set[str] = set() + + click.echo(f"Found {len(score_sets)} non-superseded score sets with calibrations") + + for score_set in score_sets: + # Collect all ACMG-classified ranges from this score set's calibrations + acmg_ranges: list[ScoreCalibrationFunctionalClassification] = [] + for calibration in score_set.score_calibrations: + if calibration.functional_classifications: + for func_classification in calibration.functional_classifications: + if func_classification.acmg_classification_id is not None: + acmg_ranges.append(func_classification) + + if not acmg_ranges: + continue + + score_sets_with_acmg_count += 1 + + # Retain a list of unique target genes for reporting + for target in score_set.target_genes: + target_name = target.name + if not target_name: + continue + + gene_list.add(target_name.strip().upper()) + + score_set_classified_variants: set[int] = set() + for classified_range in acmg_ranges: + variants_classified_by_range: list[int] = [ + variant.id for variant in classified_range.variants if variant.id is not None + ] + score_set_classified_variants.update(variants_classified_by_range) + + total_variants_count += score_set.num_variants or 0 + classified_variants_count += len(score_set_classified_variants) + if score_set_classified_variants: + click.echo( + f"Score set {score_set.urn}: {len(score_set_classified_variants)} classified variants ({score_set.num_variants} total variants)" + ) + + click.echo("\n" + "=" * 60) + click.echo("SUMMARY") + click.echo("=" * 60) + click.echo(f"Score sets with ACMG classifications: {score_sets_with_acmg_count}") + click.echo(f"Total unique variants processed: {total_variants_count}") + click.echo(f"Variants within 
ACMG-classified ranges: {classified_variants_count}") + click.echo(f"Unique target genes covered ({len(gene_list)}):") + for gene in sorted(gene_list): + click.echo(f" - {gene}") + + if total_variants_count > 0: + percentage = (classified_variants_count / total_variants_count) * 100 + click.echo(f"Classification rate: {percentage:.1f}%") + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/src/mavedb/scripts/clingen_car_submission.py b/src/mavedb/scripts/clingen_car_submission.py index 29ea5fd8..0c0e7bc4 100644 --- a/src/mavedb/scripts/clingen_car_submission.py +++ b/src/mavedb/scripts/clingen_car_submission.py @@ -1,16 +1,17 @@ -import click import logging from typing import Sequence + +import click from sqlalchemy import select from sqlalchemy.orm import Session +from mavedb.lib.clingen.constants import CAR_SUBMISSION_ENDPOINT +from mavedb.lib.clingen.services import ClinGenAlleleRegistryService, get_allele_registry_associations +from mavedb.lib.variants import get_hgvs_from_post_mapped +from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet from mavedb.models.variant import Variant -from mavedb.models.mapped_variant import MappedVariant from mavedb.scripts.environment import with_database_session -from mavedb.lib.clingen.services import ClinGenAlleleRegistryService, get_allele_registry_associations -from mavedb.lib.clingen.constants import CAR_SUBMISSION_ENDPOINT -from mavedb.lib.variants import get_hgvs_from_post_mapped logger = logging.getLogger(__name__) diff --git a/src/mavedb/scripts/load_calibration_csv.py b/src/mavedb/scripts/load_calibration_csv.py index 5c3b2bba..904f51e9 100644 --- a/src/mavedb/scripts/load_calibration_csv.py +++ b/src/mavedb/scripts/load_calibration_csv.py @@ -92,7 +92,7 @@ import csv import re from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import click from sqlalchemy.orm import Session @@ -101,12 +101,13 @@ from mavedb.lib.oddspaths import oddspaths_evidence_strength_equivalent from mavedb.lib.score_calibrations import create_score_calibration_in_score_set from mavedb.models import score_calibration +from mavedb.models.enums.functional_classification import FunctionalClassification as FunctionalClassifcationOptions from mavedb.models.score_set import ScoreSet from mavedb.models.user import User from mavedb.scripts.environment import with_database_session from mavedb.view_models.acmg_classification import ACMGClassificationCreate from mavedb.view_models.publication_identifier import PublicationIdentifierCreate -from mavedb.view_models.score_calibration import FunctionalRangeCreate, ScoreCalibrationCreate +from mavedb.view_models.score_calibration import FunctionalClassificationCreate, ScoreCalibrationCreate BRNICH_PMID = "31892348" RANGE_PATTERN = re.compile(r"^\s*([\[(])\s*([^,]+)\s*,\s*([^\])]+)\s*([])])\s*$", re.IGNORECASE) @@ -152,23 +153,21 @@ def parse_interval(text: str) -> Tuple[Optional[float], Optional[float], bool, b return lower, upper, inclusive_lower, inclusive_upper -def normalize_classification( - raw: Optional[str], strength: Optional[str] -) -> Literal["normal", "abnormal", "not_specified"]: +def normalize_classification(raw: Optional[str], strength: Optional[str]) -> FunctionalClassifcationOptions: if raw: r = raw.strip().lower() if r in {"normal", "abnormal", "not_specified"}: - return r # type: ignore[return-value] + return FunctionalClassifcationOptions[r] if r in {"indeterminate", 
"uncertain", "unknown"}: - return "not_specified" + return FunctionalClassifcationOptions.not_specified if strength: if strength.upper().startswith("PS"): - return "abnormal" + return FunctionalClassifcationOptions.abnormal if strength.upper().startswith("BS"): - return "normal" + return FunctionalClassifcationOptions.normal - return "not_specified" + return FunctionalClassifcationOptions.not_specified def build_publications( @@ -274,7 +273,7 @@ def build_ranges(row: Dict[str, str], infer_strengths: bool = True) -> Tuple[Lis label = row.get(f"class_{i}_name", "").strip() ranges.append( - FunctionalRangeCreate( + FunctionalClassificationCreate( label=label, classification=classification, range=(lower, upper), @@ -366,7 +365,7 @@ def main(db: Session, csv_path: str, delimiter: str, overwrite: bool, purge_publ method_sources=method_publications, classification_sources=calculation_publications, research_use_only=False, - functional_ranges=ranges, + functional_classifications=ranges, notes=calibration_notes, ) except Exception as e: # broad to keep import running diff --git a/src/mavedb/scripts/load_pp_style_calibration.py b/src/mavedb/scripts/load_pp_style_calibration.py index 83abd1c4..5bcf6e46 100644 --- a/src/mavedb/scripts/load_pp_style_calibration.py +++ b/src/mavedb/scripts/load_pp_style_calibration.py @@ -84,6 +84,7 @@ from sqlalchemy.orm import Session from mavedb.lib.score_calibrations import create_score_calibration_in_score_set +from mavedb.models.enums.functional_classification import FunctionalClassification as FunctionalClassifcationOptions from mavedb.models.score_calibration import ScoreCalibration from mavedb.models.score_set import ScoreSet from mavedb.models.user import User @@ -92,9 +93,9 @@ POINT_LABEL_MAPPINGS: Dict[int, str] = { 8: "Very Strong", - 7: "Very Strong", - 6: "Very Strong", - 5: "Very Strong", + 7: "Strong", + 6: "Strong", + 5: "Strong", 4: "Strong", 3: "Moderate+", 2: "Moderate", @@ -183,7 +184,7 @@ def main(db: Session, archive_path: str, dataset_map: str, overwrite: bool) -> N click.echo(f" Overwriting existing '{calibration_name}' in Score Set {score_set.urn}") benign_has_lower_functional_scores = calibration_data.get("scoreset_flipped", False) - functional_ranges: List[score_calibration.FunctionalRangeCreate] = [] + functional_classifications: List[score_calibration.FunctionalClassificationCreate] = [] for points, range_data in calibration_data.get("point_ranges", {}).items(): if not range_data: continue @@ -212,9 +213,11 @@ def main(db: Session, archive_path: str, dataset_map: str, overwrite: bool) -> N inclusive_lower = False inclusive_upper = True if upper_bound is not None else False - functional_range = score_calibration.FunctionalRangeCreate( + functional_range = score_calibration.FunctionalClassificationCreate( label=f"{ps_or_bs} {strength_label} ({points})", - classification="abnormal" if points > 0 else "normal", + classification=FunctionalClassifcationOptions.abnormal + if points > 0 + else FunctionalClassifcationOptions.normal, range=range_data, acmg_classification=acmg_classification.ACMGClassificationCreate( points=int(points), @@ -222,15 +225,17 @@ def main(db: Session, archive_path: str, dataset_map: str, overwrite: bool) -> N inclusive_lower_bound=inclusive_lower, inclusive_upper_bound=inclusive_upper, ) - functional_ranges.append(functional_range) + functional_classifications.append(functional_range) score_calibration_create = score_calibration.ScoreCalibrationCreate( title=calibration_name, - functional_ranges=functional_ranges, + 
functional_classifications=functional_classifications, research_use_only=True, score_set_urn=score_set.urn, calibration_metadata={"prior_probability_pathogenicity": calibration_data.get("prior", None)}, method_sources=[ZEIBERG_CALIBRATION_CITATION], + threshold_sources=[], + classification_sources=[], ) new_calibration_object = asyncio.run( diff --git a/src/mavedb/server_main.py b/src/mavedb/server_main.py index 80db5403..23717e43 100644 --- a/src/mavedb/server_main.py +++ b/src/mavedb/server_main.py @@ -31,11 +31,12 @@ logging_context, save_to_logging_context, ) -from mavedb.lib.permissions import PermissionException +from mavedb.lib.permissions.exceptions import PermissionException from mavedb.lib.slack import send_slack_error from mavedb.models import * # noqa: F403 from mavedb.routers import ( access_keys, + alphafold, api_information, collections, controlled_keywords, @@ -59,7 +60,6 @@ taxonomies, users, variants, - alphafold, ) logger = logging.getLogger(__name__) diff --git a/src/mavedb/view_models/acmg_classification.py b/src/mavedb/view_models/acmg_classification.py index 05757442..5bb8832f 100644 --- a/src/mavedb/view_models/acmg_classification.py +++ b/src/mavedb/view_models/acmg_classification.py @@ -4,16 +4,17 @@ classifications, and associated odds path ratios. """ +from datetime import date from typing import Optional + from pydantic import model_validator -from mavedb.lib.exceptions import ValidationError from mavedb.lib.acmg import ( - StrengthOfEvidenceProvided, ACMGCriterion, + StrengthOfEvidenceProvided, points_evidence_strength_equivalent, ) - +from mavedb.lib.exceptions import ValidationError from mavedb.view_models import record_type_validator, set_record_type from mavedb.view_models.base.base import BaseModel @@ -76,6 +77,15 @@ class SavedACMGClassification(ACMGClassificationBase): record_type: str = None # type: ignore _record_type_factory = record_type_validator()(set_record_type) + creation_date: date + modification_date: date + + class Config: + """Pydantic configuration (ORM mode).""" + + from_attributes = True + arbitrary_types_allowed = True + class ACMGClassification(SavedACMGClassification): """Complete ACMG classification model returned by the API.""" diff --git a/src/mavedb/view_models/clinical_control.py b/src/mavedb/view_models/clinical_control.py index f098dd12..85cd7834 100644 --- a/src/mavedb/view_models/clinical_control.py +++ b/src/mavedb/view_models/clinical_control.py @@ -2,11 +2,14 @@ from __future__ import annotations from datetime import date -from typing import Optional, Sequence +from typing import TYPE_CHECKING, Optional, Sequence from mavedb.view_models import record_type_validator, set_record_type from mavedb.view_models.base.base import BaseModel +if TYPE_CHECKING: + from mavedb.view_models.mapped_variant import MappedVariantCreate, MappedVariantForClinicalControl + class ClinicalControlBase(BaseModel): db_identifier: str @@ -54,11 +57,3 @@ class ClinicalControlWithMappedVariants(SavedClinicalControlWithMappedVariants): class ClinicalControlOptions(BaseModel): db_name: str available_versions: list[str] - - -# ruff: noqa: E402 -from mavedb.view_models.mapped_variant import MappedVariantCreate, MappedVariantForClinicalControl - -# ClinicalControlUpdate.model_rebuild() -SavedClinicalControlWithMappedVariants.model_rebuild() -ClinicalControlWithMappedVariants.model_rebuild() diff --git a/src/mavedb/view_models/collection.py b/src/mavedb/view_models/collection.py index 9761686d..08d2c0a3 100644 --- a/src/mavedb/view_models/collection.py +++ 
b/src/mavedb/view_models/collection.py @@ -1,5 +1,5 @@ from datetime import date -from typing import Any, Sequence, Optional +from typing import Any, Optional, Sequence from pydantic import Field, model_validator @@ -84,36 +84,36 @@ class Config: from_attributes = True # These 'synthetic' fields are generated from other model properties. Transform data from other properties as needed, setting - # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. + # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. Only perform these + # transformations if the relevant attributes are present on the input data (i.e., when creating from an ORM object). @model_validator(mode="before") def generate_contribution_role_user_relationships(cls, data: Any): - try: - user_associations = transform_contribution_role_associations_to_roles(data.user_associations) - for k, v in user_associations.items(): - data.__setattr__(k, v) - - except AttributeError as exc: - raise ValidationError( - f"Unable to create {cls.__name__} without attribute: {exc}." # type: ignore - ) + if hasattr(data, "user_associations"): + try: + user_associations = transform_contribution_role_associations_to_roles(data.user_associations) + for k, v in user_associations.items(): + data.__setattr__(k, v) + + except (AttributeError, KeyError) as exc: + raise ValidationError(f"Unable to coerce user associations for {cls.__name__}: {exc}.") return data @model_validator(mode="before") def generate_score_set_urn_list(cls, data: Any): - if not hasattr(data, "score_set_urns"): + if hasattr(data, "score_sets"): try: data.__setattr__("score_set_urns", transform_score_set_list_to_urn_list(data.score_sets)) - except AttributeError as exc: - raise ValidationError(f"Unable to create {cls.__name__} without attribute: {exc}.") # type: ignore + except (AttributeError, KeyError) as exc: + raise ValidationError(f"Unable to coerce score set urns for {cls.__name__}: {exc}.") return data @model_validator(mode="before") def generate_experiment_urn_list(cls, data: Any): - if not hasattr(data, "experiment_urns"): + if hasattr(data, "experiments"): try: data.__setattr__("experiment_urns", transform_experiment_list_to_urn_list(data.experiments)) - except AttributeError as exc: - raise ValidationError(f"Unable to create {cls.__name__} without attribute: {exc}.") # type: ignore + except (AttributeError, KeyError) as exc: + raise ValidationError(f"Unable to coerce experiment urns for {cls.__name__}: {exc}.") return data @@ -132,3 +132,14 @@ class Collection(SavedCollection): # NOTE: Coupled to ContributionRole enum class AdminCollection(Collection): pass + + +# Properties to return for official collections +class OfficialCollection(BaseModel): + badge_name: str + name: str + urn: str + + class Config: + arbitrary_types_allowed = True + from_attributes = True diff --git a/src/mavedb/view_models/components/external_link.py b/src/mavedb/view_models/components/external_link.py new file mode 100644 index 00000000..43c5d28e --- /dev/null +++ b/src/mavedb/view_models/components/external_link.py @@ -0,0 +1,15 @@ +from typing import Optional + +from mavedb.view_models.base.base import BaseModel + + +class ExternalLink(BaseModel): + """ + Represents an external hyperlink for view models. + + Attributes: + url (Optional[str]): Fully qualified URL for the external resource. + May be None if no link is available or applicable. 
+ """ + + url: Optional[str] = None diff --git a/src/mavedb/view_models/experiment.py b/src/mavedb/view_models/experiment.py index b05766ff..e9387bb7 100644 --- a/src/mavedb/view_models/experiment.py +++ b/src/mavedb/view_models/experiment.py @@ -1,18 +1,20 @@ from datetime import date -from typing import Any, Collection, Optional, Sequence +from typing import TYPE_CHECKING, Any, Collection, Optional, Sequence -from pydantic import field_validator, model_validator, ValidationInfo +from pydantic import ValidationInfo, field_validator, model_validator +from mavedb.lib.validation import urn_re from mavedb.lib.validation.exceptions import ValidationError from mavedb.lib.validation.transform import ( transform_experiment_set_to_urn, - transform_score_set_list_to_urn_list, transform_record_publication_identifiers, + transform_score_set_list_to_urn_list, ) -from mavedb.lib.validation import urn_re from mavedb.lib.validation.utilities import is_null from mavedb.view_models import record_type_validator, set_record_type from mavedb.view_models.base.base import BaseModel +from mavedb.view_models.collection import OfficialCollection +from mavedb.view_models.components.external_link import ExternalLink from mavedb.view_models.contributor import Contributor, ContributorCreate from mavedb.view_models.doi_identifier import ( DoiIdentifier, @@ -36,15 +38,8 @@ ) from mavedb.view_models.user import SavedUser, User - -class OfficialCollection(BaseModel): - badge_name: str - name: str - urn: str - - class Config: - arbitrary_types_allowed = True - from_attributes = True +if TYPE_CHECKING: + from mavedb.view_models.score_set import ScoreSetPublicDump class ExperimentBase(BaseModel): @@ -115,6 +110,7 @@ class SavedExperiment(ExperimentBase): contributors: list[Contributor] keywords: Sequence[SavedExperimentControlledKeyword] score_set_urns: list[str] + external_links: dict[str, ExternalLink] _record_type_factory = record_type_validator()(set_record_type) @@ -129,12 +125,11 @@ def publication_identifiers_validator(cls, v: Any, info: ValidationInfo) -> list return list(v) # Re-cast into proper list-like type # These 'synthetic' fields are generated from other model properties. Transform data from other properties as needed, setting - # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. + # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. Only perform these + # transformations if the relevant attributes are present on the input data (i.e., when creating from an ORM object). @model_validator(mode="before") def generate_primary_and_secondary_publications(cls, data: Any): - if not hasattr(data, "primary_publication_identifiers") or not hasattr( - data, "secondary_publication_identifiers" - ): + if hasattr(data, "publication_identifier_associations"): try: publication_identifiers = transform_record_publication_identifiers( data.publication_identifier_associations @@ -145,28 +140,30 @@ def generate_primary_and_secondary_publications(cls, data: Any): data.__setattr__( "secondary_publication_identifiers", publication_identifiers["secondary_publication_identifiers"] ) - except AttributeError as exc: + except (KeyError, AttributeError) as exc: raise ValidationError( - f"Unable to create {cls.__name__} without attribute: {exc}." # type: ignore + f"Unable to coerce publication identifier attributes from ORM for {cls.__name__}: {exc}." 
# type: ignore ) return data @model_validator(mode="before") def generate_score_set_urn_list(cls, data: Any): - if not hasattr(data, "score_set_urns"): + if hasattr(data, "score_sets"): try: data.__setattr__("score_set_urns", transform_score_set_list_to_urn_list(data.score_sets)) - except AttributeError as exc: - raise ValidationError(f"Unable to create {cls.__name__} without attribute: {exc}.") # type: ignore + except (KeyError, AttributeError) as exc: + raise ValidationError(f"Unable to coerce associated score set URNs from ORM for {cls.__name__}: {exc}.") # type: ignore return data @model_validator(mode="before") def generate_experiment_set_urn(cls, data: Any): - if not hasattr(data, "experiment_set_urn"): + if hasattr(data, "experiment_set"): try: data.__setattr__("experiment_set_urn", transform_experiment_set_to_urn(data.experiment_set)) - except AttributeError as exc: - raise ValidationError(f"Unable to create {cls.__name__} without attribute: {exc}.") # type: ignore + except (KeyError, AttributeError) as exc: + raise ValidationError( + f"Unable to coerce associated experiment set URN from ORM for {cls.__name__}: {exc}." + ) # type: ignore return data @@ -198,9 +195,3 @@ class AdminExperiment(Experiment): # Properties to include in a dump of all published data. class ExperimentPublicDump(SavedExperiment): score_sets: "Sequence[ScoreSetPublicDump]" - - -# ruff: noqa: E402 -from mavedb.view_models.score_set import ScoreSetPublicDump - -ExperimentPublicDump.model_rebuild() diff --git a/src/mavedb/view_models/experiment_set.py b/src/mavedb/view_models/experiment_set.py index de414a4b..c65d1ba8 100644 --- a/src/mavedb/view_models/experiment_set.py +++ b/src/mavedb/view_models/experiment_set.py @@ -1,11 +1,14 @@ from datetime import date -from typing import Sequence, Optional +from typing import TYPE_CHECKING, Optional, Sequence from mavedb.view_models import record_type_validator, set_record_type from mavedb.view_models.base.base import BaseModel from mavedb.view_models.contributor import Contributor from mavedb.view_models.user import SavedUser, User +if TYPE_CHECKING: + from mavedb.view_models.experiment import Experiment, ExperimentPublicDump, SavedExperiment + class ExperimentSetBase(BaseModel): urn: str @@ -60,12 +63,3 @@ class ExperimentSetPublicDump(SavedExperimentSet): experiments: "Sequence[ExperimentPublicDump]" created_by: Optional[User] = None modified_by: Optional[User] = None - - -# ruff: noqa: E402 -from mavedb.view_models.experiment import Experiment, ExperimentPublicDump, SavedExperiment - -SavedExperimentSet.model_rebuild() -ExperimentSet.model_rebuild() -AdminExperimentSet.model_rebuild() -ExperimentSetPublicDump.model_rebuild() diff --git a/src/mavedb/view_models/gnomad_variant.py b/src/mavedb/view_models/gnomad_variant.py index 1171cb49..97dd675e 100644 --- a/src/mavedb/view_models/gnomad_variant.py +++ b/src/mavedb/view_models/gnomad_variant.py @@ -2,11 +2,14 @@ from __future__ import annotations from datetime import date -from typing import Optional, Sequence +from typing import TYPE_CHECKING, Optional, Sequence from mavedb.view_models import record_type_validator, set_record_type from mavedb.view_models.base.base import BaseModel +if TYPE_CHECKING: + from mavedb.view_models.mapped_variant import MappedVariant, MappedVariantCreate, SavedMappedVariant + class GnomADVariantBase(BaseModel): """Base class for GnomAD variant view models.""" @@ -67,11 +70,3 @@ class GnomADVariantWithMappedVariants(SavedGnomADVariantWithMappedVariants): """GnomAD variant view model with 
mapped variants for non-admin clients.""" mapped_variants: Sequence["MappedVariant"] - - -# ruff: noqa: E402 -from mavedb.view_models.mapped_variant import MappedVariant, SavedMappedVariant, MappedVariantCreate - -GnomADVariantUpdate.model_rebuild() -SavedGnomADVariantWithMappedVariants.model_rebuild() -GnomADVariantWithMappedVariants.model_rebuild() diff --git a/src/mavedb/view_models/mapped_variant.py b/src/mavedb/view_models/mapped_variant.py index 13aec65d..bcb02ecf 100644 --- a/src/mavedb/view_models/mapped_variant.py +++ b/src/mavedb/view_models/mapped_variant.py @@ -2,7 +2,7 @@ from __future__ import annotations from datetime import date -from typing import Any, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional, Sequence from pydantic import model_validator @@ -10,6 +10,10 @@ from mavedb.view_models import record_type_validator, set_record_type from mavedb.view_models.base.base import BaseModel +if TYPE_CHECKING: + from mavedb.view_models.clinical_control import ClinicalControl, ClinicalControlBase, SavedClinicalControl + from mavedb.view_models.gnomad_variant import GnomADVariant, GnomADVariantBase, SavedGnomADVariant + class MappedVariantBase(BaseModel): pre_mapped: Optional[Any] = None @@ -55,13 +59,16 @@ class SavedMappedVariant(MappedVariantBase): class Config: from_attributes = True + # These 'synthetic' fields are generated from other model properties. Transform data from other properties as needed, setting + # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. Only perform these + # transformations if the relevant attributes are present on the input data (i.e., when creating from an ORM object). @model_validator(mode="before") def generate_score_set_urn_list(cls, data: Any): - if not hasattr(data, "variant_urn") and hasattr(data, "variant"): + if hasattr(data, "variant"): try: data.__setattr__("variant_urn", None if not data.variant else data.variant.urn) - except AttributeError as exc: - raise ValidationError(f"Unable to create {cls.__name__} without attribute: {exc}.") # type: ignore + except (AttributeError, KeyError) as exc: + raise ValidationError(f"Unable to coerce variant urn for {cls.__name__}: {exc}.") # type: ignore return data @@ -97,8 +104,8 @@ def generate_score_set_urn_list(cls, data: Any): # ruff: noqa: E402 -from mavedb.view_models.clinical_control import ClinicalControlBase, ClinicalControl, SavedClinicalControl -from mavedb.view_models.gnomad_variant import GnomADVariantBase, GnomADVariant, SavedGnomADVariant +from mavedb.view_models.clinical_control import ClinicalControl, ClinicalControlBase, SavedClinicalControl +from mavedb.view_models.gnomad_variant import GnomADVariant, GnomADVariantBase, SavedGnomADVariant MappedVariantUpdate.model_rebuild() SavedMappedVariantWithControls.model_rebuild() diff --git a/src/mavedb/view_models/model_rebuild.py b/src/mavedb/view_models/model_rebuild.py new file mode 100644 index 00000000..ce738143 --- /dev/null +++ b/src/mavedb/view_models/model_rebuild.py @@ -0,0 +1,133 @@ +""" +Centralized model rebuilding for view models with circular dependencies. + +This module handles the rebuilding of all Pydantic models that have forward references +to other models, resolving circular import issues by performing the rebuilds after +all modules have been imported. 
+""" + +from __future__ import annotations + +import importlib +import inspect +import logging +import pkgutil +from pathlib import Path + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +def _discover_and_sort_models(): + """Discover all Pydantic models and sort them by dependencies.""" + import mavedb.view_models + + view_models_path = Path(mavedb.view_models.__file__).parent + models_by_module = {} + + # Discover all models grouped by module + for module_info in pkgutil.walk_packages([str(view_models_path)], "mavedb.view_models."): + module_name = module_info.name + if module_name.endswith(".model_rebuild"): + continue + + try: + module = importlib.import_module(module_name) + module_models = [] + + for name, obj in inspect.getmembers(module, inspect.isclass): + if issubclass(obj, BaseModel) and obj.__module__ == module_name and hasattr(obj, "model_rebuild"): + module_models.append((f"{module_name}.{name}", obj)) + + if module_models: + models_by_module[module_name] = module_models + + except ImportError as e: + logger.warning("Could not import %s: %s", module_name, e) + + # Sort models within each module by dependency (base classes first) + sorted_models = [] + for module_name, module_models in models_by_module.items(): + + def dependency_count(item): + _, model_class = item + # Count base classes within the same module + count = 0 + for base in model_class.__bases__: + if issubclass(base, BaseModel) and base != BaseModel and any(base == mc for _, mc in module_models): + count += 1 + return count + + module_models.sort(key=dependency_count) + sorted_models.extend(module_models) + + return sorted_models + + +def rebuild_all_models(): + """ + Rebuild all Pydantic models in the view_models package. + + Discovers models, sorts by dependencies, and rebuilds with multi-pass + approach to achieve 0 circular dependencies. 
+ """ + # Discover and sort models by dependencies + models_to_rebuild = _discover_and_sort_models() + + # Create registry for forward reference resolution + model_registry = {name.split(".")[-1]: cls for name, cls in models_to_rebuild} + + logger.debug("Rebuilding %d Pydantic models...", len(models_to_rebuild)) + + successful_rebuilds = 0 + remaining_models = models_to_rebuild[:] + + # Multi-pass rebuild to handle complex dependencies + for pass_num in range(3): + if not remaining_models: + break + + models_for_next_pass = [] + + for model_name, model_class in remaining_models: + try: + # Temporarily inject all models into the module for forward references + module = importlib.import_module(model_class.__module__) + injected = {} + + for simple_name, ref_class in model_registry.items(): + if simple_name not in module.__dict__: + injected[simple_name] = module.__dict__.get(simple_name) + module.__dict__[simple_name] = ref_class + + try: + model_class.model_rebuild() + successful_rebuilds += 1 + logger.debug("Rebuilt %s", model_name) + finally: + # Restore original module state + for name, original in injected.items(): + if original is None: + module.__dict__.pop(name, None) + else: + module.__dict__[name] = original + + except Exception as e: + if "is not defined" in str(e) and pass_num < 2: + models_for_next_pass.append((model_name, model_class)) + logger.debug("Deferring %s to next pass", model_name) + else: + logger.error("Failed to rebuild %s: %s", model_name, e) + + remaining_models = models_for_next_pass + + logger.debug( + "Rebuilt %d Pydantic models successfully, %d models with circular dependencies remain.", + successful_rebuilds, + len(remaining_models), + ) + + +# Automatically rebuild all models when this module is imported +rebuild_all_models() diff --git a/src/mavedb/view_models/score_calibration.py b/src/mavedb/view_models/score_calibration.py index 00d5d692..985d5949 100644 --- a/src/mavedb/view_models/score_calibration.py +++ b/src/mavedb/view_models/score_calibration.py @@ -5,9 +5,9 @@ """ from datetime import date -from typing import Any, Collection, Literal, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Collection, Optional, Sequence, Union -from pydantic import field_validator, model_validator +from pydantic import Field, field_validator, model_validator from mavedb.lib.oddspaths import oddspaths_evidence_strength_equivalent from mavedb.lib.validation.exceptions import ValidationError @@ -16,6 +16,7 @@ transform_score_set_to_urn, ) from mavedb.lib.validation.utilities import inf_or_float +from mavedb.models.enums.functional_classification import FunctionalClassification as FunctionalClassifcationOptions from mavedb.view_models import record_type_validator, set_record_type from mavedb.view_models.acmg_classification import ( ACMGClassification, @@ -33,10 +34,16 @@ ) from mavedb.view_models.user import SavedUser, User +if TYPE_CHECKING: + from mavedb.view_models.variant import ( + SavedVariantEffectMeasurement, + VariantEffectMeasurement, + ) + ### Functional range models -class FunctionalRangeBase(BaseModel): +class FunctionalClassificationBase(BaseModel): """Base functional range model. Represents a labeled numeric score interval with optional evidence metadata. 
@@ -46,22 +53,30 @@ class FunctionalRangeBase(BaseModel): label: str description: Optional[str] = None - classification: Literal["normal", "abnormal", "not_specified"] = "not_specified" + functional_classification: FunctionalClassifcationOptions = FunctionalClassifcationOptions.not_specified + + range: Optional[tuple[Union[float, None], Union[float, None]]] = None # (lower_bound, upper_bound) + class_: Optional[str] = Field(None, alias="class", serialization_alias="class") - range: tuple[Union[float, None], Union[float, None]] - inclusive_lower_bound: bool = True - inclusive_upper_bound: bool = False + inclusive_lower_bound: Optional[bool] = None + inclusive_upper_bound: Optional[bool] = None acmg_classification: Optional[ACMGClassificationBase] = None oddspaths_ratio: Optional[float] = None positive_likelihood_ratio: Optional[float] = None + class Config: + populate_by_name = True + @field_validator("range") def ranges_are_not_backwards( - cls, field_value: tuple[Union[float, None], Union[float, None]] - ) -> tuple[Union[float, None], Union[float, None]]: + cls, field_value: Optional[tuple[Union[float, None], Union[float, None]]] + ) -> Optional[tuple[Union[float, None], Union[float, None]]]: """Reject reversed or zero-width intervals.""" + if field_value is None: + return None + lower = inf_or_float(field_value[0], True) upper = inf_or_float(field_value[1], False) if lower > upper: @@ -78,36 +93,95 @@ def ratios_must_be_positive(cls, field_value: Optional[float]) -> Optional[float return field_value + @field_validator("class_", "label", mode="before") + def labels_and_class_strip_whitespace_and_validate_not_empty(cls, field_value: Optional[str]) -> Optional[str]: + """Strip leading/trailing whitespace from labels and class names, and reject empty values.""" + if field_value is None: + return None + + field_value = field_value.strip() + if not field_value: + raise ValidationError("This field may not be empty or contain only whitespace.") + + return field_value + + @model_validator(mode="after") + def at_least_one_of_range_or_class_must_be_provided( + self: "FunctionalClassificationBase", + ) -> "FunctionalClassificationBase": + """Either a range or a class must be provided.""" + if self.range is None and self.class_ is None: + raise ValidationError("A functional range must specify either a numeric range or a class.") + + return self + @model_validator(mode="after") - def inclusive_bounds_do_not_include_infinity(self: "FunctionalRangeBase") -> "FunctionalRangeBase": + def class_and_range_mutually_exclusive( + self: "FunctionalClassificationBase", + ) -> "FunctionalClassificationBase": + """Either a range or a class may be provided, but not both.""" + if self.range is not None and self.class_ is not None: + raise ValidationError("A functional range may not specify both a numeric range and a class.") + + return self + + @model_validator(mode="after") + def inclusive_bounds_require_range(self: "FunctionalClassificationBase") -> "FunctionalClassificationBase": + """Inclusive bounds may only be set if a range is provided. If they are unset, default them.""" + if self.class_ is not None: + if self.inclusive_lower_bound is not None: + raise ValidationError( + "An inclusive lower bound may not be set on a class based functional classification." + ) + if self.inclusive_upper_bound is not None: + raise ValidationError( + "An inclusive upper bound may not be set on a class based functional classification."
+ ) + + if self.range is not None: + if self.inclusive_lower_bound is None: + self.inclusive_lower_bound = True + if self.inclusive_upper_bound is None: + self.inclusive_upper_bound = False + + return self + + @model_validator(mode="after") + def inclusive_bounds_do_not_include_infinity( + self: "FunctionalClassificationBase", + ) -> "FunctionalClassificationBase": """Disallow inclusive bounds on unbounded (infinite) ends.""" - if self.inclusive_lower_bound and self.range[0] is None: + if self.inclusive_lower_bound and self.range is not None and self.range[0] is None: raise ValidationError("An inclusive lower bound may not include negative infinity.") - if self.inclusive_upper_bound and self.range[1] is None: + if self.inclusive_upper_bound and self.range is not None and self.range[1] is None: raise ValidationError("An inclusive upper bound may not include positive infinity.") return self @model_validator(mode="after") - def acmg_classification_evidence_agrees_with_classification(self: "FunctionalRangeBase") -> "FunctionalRangeBase": + def acmg_classification_evidence_agrees_with_classification( + self: "FunctionalClassificationBase", + ) -> "FunctionalClassificationBase": """If oddspaths is provided, ensure its evidence agrees with the classification.""" if self.acmg_classification is None or self.acmg_classification.criterion is None: return self if ( - self.classification == "normal" + self.functional_classification is FunctionalClassifcationOptions.normal and self.acmg_classification.criterion.is_pathogenic - or self.classification == "abnormal" + or self.functional_classification is FunctionalClassifcationOptions.abnormal and self.acmg_classification.criterion.is_benign ): raise ValidationError( - f"The ACMG classification criterion ({self.acmg_classification.criterion}) must agree with the functional range classification ({self.classification})." + f"The ACMG classification criterion ({self.acmg_classification.criterion}) must agree with the functional range classification ({self.functional_classification})." 
) return self @model_validator(mode="after") - def oddspaths_ratio_agrees_with_acmg_classification(self: "FunctionalRangeBase") -> "FunctionalRangeBase": + def oddspaths_ratio_agrees_with_acmg_classification( + self: "FunctionalClassificationBase", + ) -> "FunctionalClassificationBase": """If both oddspaths and acmg_classification are provided, ensure they agree.""" if self.oddspaths_ratio is None or self.acmg_classification is None: return self @@ -129,42 +203,63 @@ def oddspaths_ratio_agrees_with_acmg_classification(self: "FunctionalRangeBase") def is_contained_by_range(self, score: float) -> bool: """Determine if a given score falls within this functional range.""" + if not self.range: + return False + lower_bound, upper_bound = ( inf_or_float(self.range[0], lower=True), inf_or_float(self.range[1], lower=False), ) - lower_check = score > lower_bound or (self.inclusive_lower_bound and score == lower_bound) - upper_check = score < upper_bound or (self.inclusive_upper_bound and score == upper_bound) + lower_check = score > lower_bound or (self.inclusive_lower_bound is True and score == lower_bound) + upper_check = score < upper_bound or (self.inclusive_upper_bound is True and score == upper_bound) return lower_check and upper_check + @property + def class_based(self) -> bool: + """Determine if this functional classification is class-based.""" + return self.class_ is not None -class FunctionalRangeModify(FunctionalRangeBase): + @property + def range_based(self) -> bool: + """Determine if this functional classification is range-based.""" + return self.range is not None + + +class FunctionalClassificationModify(FunctionalClassificationBase): """Model used to modify an existing functional range.""" acmg_classification: Optional[ACMGClassificationModify] = None -class FunctionalRangeCreate(FunctionalRangeModify): +class FunctionalClassificationCreate(FunctionalClassificationModify): """Model used to create a new functional range.""" acmg_classification: Optional[ACMGClassificationCreate] = None -class SavedFunctionalRange(FunctionalRangeBase): +class SavedFunctionalClassification(FunctionalClassificationBase): """Persisted functional range model (includes record type metadata).""" record_type: str = None # type: ignore acmg_classification: Optional[SavedACMGClassification] = None + variants: Sequence["SavedVariantEffectMeasurement"] = [] _record_type_factory = record_type_validator()(set_record_type) + class Config: + """Pydantic configuration (ORM mode).""" + + from_attributes = True + arbitrary_types_allowed = True + -class FunctionalRange(SavedFunctionalRange): +class FunctionalClassification(SavedFunctionalClassification): """Complete functional range model returned by the API.""" acmg_classification: Optional[ACMGClassification] = None + variants: Sequence["VariantEffectMeasurement"] = [] ### Score calibration models @@ -183,21 +278,26 @@ class ScoreCalibrationBase(BaseModel): baseline_score_description: Optional[str] = None notes: Optional[str] = None - functional_ranges: Optional[Sequence[FunctionalRangeBase]] = None - threshold_sources: Optional[Sequence[PublicationIdentifierBase]] = None - classification_sources: Optional[Sequence[PublicationIdentifierBase]] = None - method_sources: Optional[Sequence[PublicationIdentifierBase]] = None + functional_classifications: Optional[Sequence[FunctionalClassificationBase]] = None + threshold_sources: Sequence[PublicationIdentifierBase] + classification_sources: Sequence[PublicationIdentifierBase] + method_sources: 
Sequence[PublicationIdentifierBase] calibration_metadata: Optional[dict] = None - @field_validator("functional_ranges") + @field_validator("functional_classifications") def ranges_do_not_overlap( - cls, field_value: Optional[Sequence[FunctionalRangeBase]] - ) -> Optional[Sequence[FunctionalRangeBase]]: + cls, field_value: Optional[Sequence[FunctionalClassificationBase]] + ) -> Optional[Sequence[FunctionalClassificationBase]]: """Ensure that no two functional ranges overlap (respecting inclusivity).""" - def test_overlap(range_test: FunctionalRangeBase, range_check: FunctionalRangeBase) -> bool: + def test_overlap(range_test: FunctionalClassificationBase, range_check: FunctionalClassificationBase) -> bool: # Allow 'not_specified' classifications to overlap with anything. - if range_test.classification == "not_specified" or range_check.classification == "not_specified": + if ( + range_test.functional_classification is FunctionalClassifcationOptions.not_specified + or range_check.functional_classification is FunctionalClassifcationOptions.not_specified + or range_test.range is None + or range_check.range is None + ): return False if min(inf_or_float(range_test.range[0], True), inf_or_float(range_check.range[0], True)) == inf_or_float( @@ -207,14 +307,15 @@ def test_overlap(range_test: FunctionalRangeBase, range_check: FunctionalRangeBa else: first, second = range_check, range_test + # The range types below that mypy complains about are verified by the earlier checks for None. touching_and_inclusive = ( first.inclusive_upper_bound and second.inclusive_lower_bound - and inf_or_float(first.range[1], False) == inf_or_float(second.range[0], True) + and inf_or_float(first.range[1], False) == inf_or_float(second.range[0], True) # type: ignore ) if touching_and_inclusive: return True - if inf_or_float(first.range[1], False) > inf_or_float(second.range[0], True): + if inf_or_float(first.range[1], False) > inf_or_float(second.range[0], True): # type: ignore return True return False @@ -232,23 +333,34 @@ def test_overlap(range_test: FunctionalRangeBase, range_check: FunctionalRangeBa return field_value @model_validator(mode="after") - def functional_range_labels_must_be_unique(self: "ScoreCalibrationBase") -> "ScoreCalibrationBase": - """Enforce uniqueness (post-strip) of functional range labels.""" - if not self.functional_ranges: + def functional_range_labels_classes_must_be_unique(self: "ScoreCalibrationBase") -> "ScoreCalibrationBase": + """Enforce uniqueness (post-strip) of functional range labels and classes.""" + if not self.functional_classifications: return self - seen, dupes = set(), set() - for i, fr in enumerate(self.functional_ranges): - fr.label = fr.label.strip() - if fr.label in seen: - dupes.add((fr.label, i)) + seen_l, dupes_l = set(), set() + seen_c, dupes_c = set(), set() + for i, fr in enumerate(self.functional_classifications): + if fr.label in seen_l: + dupes_l.add((fr.label, i)) else: - seen.add(fr.label) + seen_l.add(fr.label) + + if fr.class_ is not None: + if fr.class_ in seen_c: + dupes_c.add((fr.class_, i)) + else: + seen_c.add(fr.class_) - if dupes: + if dupes_l: + raise ValidationError( + f"Detected repeated label(s): {', '.join(label for label, _ in dupes_l)}. Functional range labels must be unique.", + custom_loc=["body", "functionalClassifications", dupes_l.pop()[1], "label"], + ) + if dupes_c: raise ValidationError( - f"Detected repeated label(s): {', '.join(label for label, _ in dupes)}. 
Functional range labels must be unique.", - custom_loc=["body", "functionalRanges", dupes.pop()[1], "label"], + f"Detected repeated class name(s): {', '.join(class_name for class_name, _ in dupes_c)}. Functional range class names must be unique.", + custom_loc=["body", "functionalClassifications", dupes_c.pop()[1], "class"], ) return self @@ -256,14 +368,17 @@ def functional_range_labels_must_be_unique(self: "ScoreCalibrationBase") -> "Sco @model_validator(mode="after") def validate_baseline_score(self: "ScoreCalibrationBase") -> "ScoreCalibrationBase": """If a baseline score is provided and it falls within a functional range, it may only be contained in a normal range.""" - if not self.functional_ranges: + if not self.functional_classifications: return self if self.baseline_score is None: return self - for fr in self.functional_ranges: - if fr.is_contained_by_range(self.baseline_score) and fr.classification != "normal": + for fr in self.functional_classifications: + if ( + fr.is_contained_by_range(self.baseline_score) + and fr.functional_classification is not FunctionalClassifcationOptions.normal + ): raise ValidationError( f"The provided baseline score of {self.baseline_score} falls within a non-normal range ({fr.label}). Baseline scores may not fall within non-normal ranges.", custom_loc=["body", "baselineScore"], @@ -271,25 +386,60 @@ def validate_baseline_score(self: "ScoreCalibrationBase") -> "ScoreCalibrationBa return self + @model_validator(mode="after") + def functional_classifications_must_be_of_same_type( + self: "ScoreCalibrationBase", + ) -> "ScoreCalibrationBase": + """All functional classifications must be either range-based or class-based.""" + if not self.functional_classifications: + return self + + range_based_count = sum(1 for fc in self.functional_classifications if fc.range_based) + class_based_count = sum(1 for fc in self.functional_classifications if fc.class_based) + + if range_based_count > 0 and class_based_count > 0: + raise ValidationError( + "All functional classifications within a score calibration must be of the same type (either all range-based or all class-based).", + custom_loc=["body", "functionalClassifications"], + ) + + return self + + @property + def range_based(self) -> bool: + """Determine if this score calibration is range-based.""" + if not self.functional_classifications: + return False + + return self.functional_classifications[0].range_based + + @property + def class_based(self) -> bool: + """Determine if this score calibration is class-based.""" + if not self.functional_classifications: + return False + + return self.functional_classifications[0].class_based + class ScoreCalibrationModify(ScoreCalibrationBase): """Model used to modify an existing score calibration.""" score_set_urn: Optional[str] = None - functional_ranges: Optional[Sequence[FunctionalRangeModify]] = None - threshold_sources: Optional[Sequence[PublicationIdentifierCreate]] = None - classification_sources: Optional[Sequence[PublicationIdentifierCreate]] = None - method_sources: Optional[Sequence[PublicationIdentifierCreate]] = None + functional_classifications: Optional[Sequence[FunctionalClassificationModify]] = None + threshold_sources: Sequence[PublicationIdentifierCreate] + classification_sources: Sequence[PublicationIdentifierCreate] + method_sources: Sequence[PublicationIdentifierCreate] class ScoreCalibrationCreate(ScoreCalibrationModify): """Model used to create a new score calibration.""" - functional_ranges: Optional[Sequence[FunctionalRangeCreate]] = None - 
threshold_sources: Optional[Sequence[PublicationIdentifierCreate]] = None - classification_sources: Optional[Sequence[PublicationIdentifierCreate]] = None - method_sources: Optional[Sequence[PublicationIdentifierCreate]] = None + functional_classifications: Optional[Sequence[FunctionalClassificationCreate]] = None + threshold_sources: Sequence[PublicationIdentifierCreate] + classification_sources: Sequence[PublicationIdentifierCreate] + method_sources: Sequence[PublicationIdentifierCreate] class SavedScoreCalibration(ScoreCalibrationBase): @@ -306,10 +456,10 @@ class SavedScoreCalibration(ScoreCalibrationBase): primary: bool = False private: bool = True - functional_ranges: Optional[Sequence[SavedFunctionalRange]] = None - threshold_sources: Optional[Sequence[SavedPublicationIdentifier]] = None - classification_sources: Optional[Sequence[SavedPublicationIdentifier]] = None - method_sources: Optional[Sequence[SavedPublicationIdentifier]] = None + functional_classifications: Optional[Sequence[SavedFunctionalClassification]] = None + threshold_sources: Sequence[SavedPublicationIdentifier] + classification_sources: Sequence[SavedPublicationIdentifier] + method_sources: Sequence[SavedPublicationIdentifier] created_by: Optional[SavedUser] = None modified_by: Optional[SavedUser] = None @@ -327,9 +477,6 @@ class Config: @field_validator("threshold_sources", "classification_sources", "method_sources", mode="before") def publication_identifiers_validator(cls, value: Any) -> Optional[list[PublicationIdentifier]]: """Coerce association proxy collections to plain lists.""" - if value is None: - return None - assert isinstance(value, Collection), "Publication identifier lists must be a collection" return list(value) @@ -354,19 +501,13 @@ def primary_calibrations_may_not_be_private(self: "SavedScoreCalibration") -> "S return self + # These 'synthetic' fields are generated from other model properties. Transform data from other properties as needed, setting + # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. Only perform these + # transformations if the relevant attributes are present on the input data (i.e., when creating from an ORM object). @model_validator(mode="before") def generate_threshold_classification_and_method_sources(cls, data: Any): # type: ignore[override] """Populate threshold/classification/method source fields from association objects if missing.""" - association_keys = { - "threshold_sources", - "thresholdSources", - "classification_sources", - "classificationSources", - "method_sources", - "methodSources", - } - - if not any(hasattr(data, key) for key in association_keys): + if hasattr(data, "publication_identifier_associations"): try: publication_identifiers = transform_score_calibration_publication_identifiers( data.publication_identifier_associations @@ -374,9 +515,9 @@ def generate_threshold_classification_and_method_sources(cls, data: Any): # typ data.__setattr__("threshold_sources", publication_identifiers["threshold_sources"]) data.__setattr__("classification_sources", publication_identifiers["classification_sources"]) data.__setattr__("method_sources", publication_identifiers["method_sources"]) - except AttributeError as exc: + except (AttributeError, KeyError) as exc: raise ValidationError( - f"Unable to create {cls.__name__} without attribute: {exc}." # type: ignore + f"Unable to coerce publication associations for {cls.__name__}: {exc}." 
# type: ignore ) return data @@ -384,10 +525,10 @@ def generate_threshold_classification_and_method_sources(cls, data: Any): # typ class ScoreCalibration(SavedScoreCalibration): """Complete score calibration model returned by the API.""" - functional_ranges: Optional[Sequence[FunctionalRange]] = None - threshold_sources: Optional[Sequence[PublicationIdentifier]] = None - classification_sources: Optional[Sequence[PublicationIdentifier]] = None - method_sources: Optional[Sequence[PublicationIdentifier]] = None + functional_classifications: Optional[Sequence[FunctionalClassification]] = None + threshold_sources: Sequence[PublicationIdentifier] + classification_sources: Sequence[PublicationIdentifier] + method_sources: Sequence[PublicationIdentifier] created_by: Optional[User] = None modified_by: Optional[User] = None @@ -399,11 +540,11 @@ class ScoreCalibrationWithScoreSetUrn(SavedScoreCalibration): @model_validator(mode="before") def generate_score_set_urn(cls, data: Any): - if not hasattr(data, "score_set_urn"): + if hasattr(data, "score_set"): try: data.__setattr__("score_set_urn", transform_score_set_to_urn(data.score_set)) - except AttributeError as exc: + except (AttributeError, KeyError) as exc: raise ValidationError( - f"Unable to create {cls.__name__} without attribute: {exc}." # type: ignore + f"Unable to coerce score set urn for {cls.__name__}: {exc}." # type: ignore ) return data diff --git a/src/mavedb/view_models/score_set.py b/src/mavedb/view_models/score_set.py index 9f53cf64..84c445ee 100644 --- a/src/mavedb/view_models/score_set.py +++ b/src/mavedb/view_models/score_set.py @@ -3,7 +3,7 @@ import json from datetime import date -from typing import Any, Collection, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Collection, Optional, Sequence, Union from pydantic import field_validator, model_validator from typing_extensions import Self @@ -19,6 +19,8 @@ from mavedb.models.enums.processing_state import ProcessingState from mavedb.view_models import record_type_validator, set_record_type from mavedb.view_models.base.base import BaseModel +from mavedb.view_models.collection import OfficialCollection +from mavedb.view_models.components.external_link import ExternalLink from mavedb.view_models.contributor import Contributor, ContributorCreate from mavedb.view_models.doi_identifier import ( DoiIdentifier, @@ -46,21 +48,11 @@ from mavedb.view_models.user import SavedUser, User from mavedb.view_models.utils import all_fields_optional_model -UnboundedRange = tuple[Union[float, None], Union[float, None]] - - -class ExternalLink(BaseModel): - url: Optional[str] = None - - -class OfficialCollection(BaseModel): - badge_name: str - name: str - urn: str +if TYPE_CHECKING: + from mavedb.view_models.experiment import Experiment + from mavedb.view_models.variant import SavedVariantEffectMeasurement - class Config: - arbitrary_types_allowed = True - from_attributes = True +UnboundedRange = tuple[Union[float, None], Union[float, None]] class ScoreSetBase(BaseModel): @@ -109,7 +101,7 @@ def targets_need_labels_when_multiple_targets_exist(self) -> Self: "Target sequence labels cannot be empty when multiple targets are defined.", custom_loc=[ "body", - "targetGene", + "targetGenes", idx, "targetSequence", "label", @@ -134,7 +126,7 @@ def target_labels_are_unique(self) -> Self: "Target sequence labels cannot be duplicated.", custom_loc=[ "body", - "targetGene", + "targetGenes", dup_indices[-1], "targetSequence", "label", @@ -161,7 +153,7 @@ def 
target_accession_base_editor_targets_are_consistent(cls, field_value, values "All target accessions must be of the same base editor type.", custom_loc=[ "body", - "targetGene", + "targetGenes", 0, "targetAccession", "isBaseEditor", @@ -311,12 +303,11 @@ class Config: arbitrary_types_allowed = True # These 'synthetic' fields are generated from other model properties. Transform data from other properties as needed, setting - # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. + # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. Only perform these + # transformations if the relevant attributes are present on the input data (i.e., when creating from an ORM object). @model_validator(mode="before") def generate_primary_and_secondary_publications(cls, data: Any): - if not hasattr(data, "primary_publication_identifiers") or not hasattr( - data, "secondary_publication_identifiers" - ): + if hasattr(data, "publication_identifier_associations"): try: publication_identifiers = transform_record_publication_identifiers( data.publication_identifier_associations @@ -327,9 +318,9 @@ def generate_primary_and_secondary_publications(cls, data: Any): data.__setattr__( "secondary_publication_identifiers", publication_identifiers["secondary_publication_identifiers"] ) - except AttributeError as exc: + except (AttributeError, KeyError) as exc: raise ValidationError( - f"Unable to create {cls.__name__} without attribute: {exc}." # type: ignore + f"Unable to coerce publication identifier attributes for {cls.__name__}: {exc}." # type: ignore ) return data @@ -384,12 +375,11 @@ def publication_identifiers_validator(cls, value: Any) -> list[PublicationIdenti return list(value) # Re-cast into proper list-like type # These 'synthetic' fields are generated from other model properties. Transform data from other properties as needed, setting - # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. + # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. Only perform these + # transformations if the relevant attributes are present on the input data (i.e., when creating from an ORM object). @model_validator(mode="before") def generate_primary_and_secondary_publications(cls, data: Any): - if not hasattr(data, "primary_publication_identifiers") or not hasattr( - data, "secondary_publication_identifiers" - ): + if hasattr(data, "publication_identifier_associations"): try: publication_identifiers = transform_record_publication_identifiers( data.publication_identifier_associations @@ -400,33 +390,35 @@ def generate_primary_and_secondary_publications(cls, data: Any): data.__setattr__( "secondary_publication_identifiers", publication_identifiers["secondary_publication_identifiers"] ) - except AttributeError as exc: - raise ValidationError( - f"Unable to create {cls.__name__} without attribute: {exc}." 
# type: ignore - ) + except (AttributeError, KeyError) as exc: + raise ValidationError(f"Unable to coerce publication identifier attributes for {cls.__name__}: {exc}.") return data @model_validator(mode="before") def transform_meta_analysis_objects_to_urns(cls, data: Any): - if not hasattr(data, "meta_analyzes_score_set_urns"): + if hasattr(data, "meta_analyzes_score_sets"): try: data.__setattr__( "meta_analyzes_score_set_urns", transform_score_set_list_to_urn_list(data.meta_analyzes_score_sets) ) - except AttributeError as exc: - raise ValidationError(f"Unable to create {cls.__name__} without attribute: {exc}.") # type: ignore + except (AttributeError, KeyError) as exc: + raise ValidationError( + f"Unable to coerce meta analyzes score set urn attribute for {cls.__name__}: {exc}." + ) return data @model_validator(mode="before") def transform_meta_analyzed_objects_to_urns(cls, data: Any): - if not hasattr(data, "meta_analyzed_by_score_set_urns"): + if hasattr(data, "meta_analyzed_by_score_sets"): try: data.__setattr__( "meta_analyzed_by_score_set_urns", transform_score_set_list_to_urn_list(data.meta_analyzed_by_score_sets), ) - except AttributeError as exc: - raise ValidationError(f"Unable to create {cls.__name__} without attribute: {exc}.") # type: ignore + except (AttributeError, KeyError) as exc: + raise ValidationError( + f"Unable to coerce meta analyzed by score set urn attribute for {cls.__name__}: {exc}." + ) return data @@ -456,7 +448,7 @@ class ScoreSetWithVariants(ScoreSet): are requested. """ - variants: list[SavedVariantEffectMeasurement] + variants: list["SavedVariantEffectMeasurement"] class AdminScoreSet(ScoreSet): @@ -482,13 +474,3 @@ class ScoreSetPublicDump(SavedScoreSet): mapping_state: Optional[MappingState] = None mapping_errors: Optional[dict] = None score_calibrations: Optional[Sequence[ScoreCalibration]] = None # type: ignore[assignment] - - -# ruff: noqa: E402 -from mavedb.view_models.experiment import Experiment -from mavedb.view_models.variant import SavedVariantEffectMeasurement - -ScoreSetWithVariants.model_rebuild() -ShortScoreSet.model_rebuild() -ScoreSet.model_rebuild() -ScoreSetWithVariants.model_rebuild() diff --git a/src/mavedb/view_models/target_gene.py b/src/mavedb/view_models/target_gene.py index 48396a98..02ae0cbc 100644 --- a/src/mavedb/view_models/target_gene.py +++ b/src/mavedb/view_models/target_gene.py @@ -69,15 +69,16 @@ class Config: arbitrary_types_allowed = True # These 'synthetic' fields are generated from other model properties. Transform data from other properties as needed, setting - # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. + # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. Only perform these + # transformations if the relevant attributes are present on the input data (i.e., when creating from an ORM object). @model_validator(mode="before") def generate_external_identifiers_list(cls, data: Any): - if not hasattr(data, "external_identifiers"): + if hasattr(data, "ensembl_offset") or hasattr(data, "refseq_offset") or hasattr(data, "uniprot_offset"): try: data.__setattr__("external_identifiers", transform_external_identifier_offsets_to_list(data)) - except AttributeError as exc: + except (AttributeError, KeyError) as exc: raise ValidationError( - f"Unable to create {cls.__name__} without attribute: {exc}." # type: ignore + f"Unable to coerce external identifiers for {cls.__name__}: {exc}." 
# type: ignore ) return data @@ -108,15 +109,16 @@ class TargetGeneWithScoreSetUrn(TargetGene): score_set_urn: str # These 'synthetic' fields are generated from other model properties. Transform data from other properties as needed, setting - # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. + # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. Only perform these + # transformations if the relevant attributes are present on the input data (i.e., when creating from an ORM object). @model_validator(mode="before") def generate_score_set_urn(cls, data: Any): - if not hasattr(data, "score_set_urn"): + if hasattr(data, "score_set"): try: data.__setattr__("score_set_urn", transform_score_set_to_urn(data.score_set)) - except AttributeError as exc: + except (AttributeError, KeyError) as exc: raise ValidationError( - f"Unable to create {cls.__name__} without attribute: {exc}." # type: ignore + f"Unable to coerce score set urn for {cls.__name__}: {exc}." # type: ignore ) return data diff --git a/src/mavedb/view_models/variant.py b/src/mavedb/view_models/variant.py index 2fc62d7f..b01b0183 100644 --- a/src/mavedb/view_models/variant.py +++ b/src/mavedb/view_models/variant.py @@ -1,12 +1,15 @@ from datetime import date -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from pydantic import model_validator from mavedb.lib.validation.exceptions import ValidationError -from mavedb.view_models.mapped_variant import MappedVariant, SavedMappedVariant from mavedb.view_models import record_type_validator, set_record_type from mavedb.view_models.base.base import BaseModel +from mavedb.view_models.mapped_variant import MappedVariant, SavedMappedVariant + +if TYPE_CHECKING: + from mavedb.view_models.score_set import ScoreSet, ShortScoreSet class VariantEffectMeasurementBase(BaseModel): @@ -51,18 +54,19 @@ class SavedVariantEffectMeasurementWithMappedVariant(SavedVariantEffectMeasureme mapped_variant: Optional[SavedMappedVariant] = None + # These 'synthetic' fields are generated from other model properties. Transform data from other properties as needed, setting + # the appropriate field on the model itself. Then, proceed with Pydantic ingestion once fields are created. Only perform these + # transformations if the relevant attributes are present on the input data (i.e., when creating from an ORM object). 
@model_validator(mode="before") - def generate_score_set_urn_list(cls, data: Any): - if not hasattr(data, "mapped_variant"): + def generate_associated_mapped_variant(cls, data: Any): + if hasattr(data, "mapped_variants"): try: - mapped_variant = None - if data.mapped_variants: - mapped_variant = next( - mapped_variant for mapped_variant in data.mapped_variants if mapped_variant.current - ) + mapped_variant = next( + (mapped_variant for mapped_variant in data.mapped_variants if mapped_variant.current), None + ) data.__setattr__("mapped_variant", mapped_variant) - except AttributeError as exc: - raise ValidationError(f"Unable to create {cls.__name__} without attribute: {exc}.") # type: ignore + except (AttributeError, KeyError) as exc: + raise ValidationError(f"Unable to coerce mapped variant for {cls.__name__}: {exc}.") # type: ignore return data @@ -106,10 +110,3 @@ class ClingenAlleleIdVariantLookupResponse(BaseModel): exact_match: Optional[Variant] = None equivalent_nt: list[Variant] = [] equivalent_aa: list[Variant] = [] - - -# ruff: noqa: E402 -from mavedb.view_models.score_set import ScoreSet, ShortScoreSet - -VariantEffectMeasurementWithScoreSet.update_forward_refs() -VariantEffectMeasurementWithShortScoreSet.update_forward_refs() diff --git a/tests/conftest.py b/tests/conftest.py index c79c033e..33e709e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import logging # noqa: F401 +import sys from datetime import datetime from unittest import mock -import sys import email_validator import pytest @@ -11,35 +11,33 @@ from sqlalchemy.pool import NullPool from mavedb.db.base import Base +from mavedb.models import * # noqa: F403 +from mavedb.models.experiment import Experiment from mavedb.models.experiment_set import ExperimentSet -from mavedb.models.score_set_publication_identifier import ScoreSetPublicationIdentifierAssociation -from mavedb.models.user import User, UserRole, Role from mavedb.models.license import License -from mavedb.models.taxonomy import Taxonomy -from mavedb.models.publication_identifier import PublicationIdentifier -from mavedb.models.experiment import Experiment -from mavedb.models.variant import Variant from mavedb.models.mapped_variant import MappedVariant +from mavedb.models.publication_identifier import PublicationIdentifier from mavedb.models.score_set import ScoreSet - -from mavedb.models import * # noqa: F403 - +from mavedb.models.score_set_publication_identifier import ScoreSetPublicationIdentifierAssociation +from mavedb.models.taxonomy import Taxonomy +from mavedb.models.user import Role, User, UserRole +from mavedb.models.variant import Variant from tests.helpers.constants import ( ADMIN_USER, EXTRA_USER, - TEST_LICENSE, + TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_INACTIVE_LICENSE, + TEST_LICENSE, + TEST_PATHOGENICITY_SCORE_CALIBRATION, + TEST_PUBMED_IDENTIFIER, TEST_SAVED_TAXONOMY, TEST_USER, - VALID_VARIANT_URN, - VALID_SCORE_SET_URN, - VALID_EXPERIMENT_URN, - VALID_EXPERIMENT_SET_URN, - TEST_PUBMED_IDENTIFIER, TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, - TEST_BRNICH_SCORE_CALIBRATION, - TEST_PATHOGENICITY_SCORE_CALIBRATION, + VALID_EXPERIMENT_SET_URN, + VALID_EXPERIMENT_URN, + VALID_SCORE_SET_URN, + VALID_VARIANT_URN, ) sys.path.append(".") @@ -56,7 +54,7 @@ assert pytest_postgresql.factories # Allow the @test domain name through our email validator. 
-email_validator.SPECIAL_USE_DOMAIN_NAMES.remove("test") +email_validator.TEST_ENVIRONMENT = True @pytest.fixture() @@ -145,7 +143,7 @@ def mock_experiment(): def mock_score_set(mock_user, mock_experiment, mock_publication_associations): score_set = mock.Mock(spec=ScoreSet) score_set.urn = VALID_SCORE_SET_URN - score_set.score_calibrations = [TEST_BRNICH_SCORE_CALIBRATION, TEST_PATHOGENICITY_SCORE_CALIBRATION] + score_set.score_calibrations = [TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_PATHOGENICITY_SCORE_CALIBRATION] score_set.license.short_name = "MIT" score_set.created_by = mock_user score_set.modified_by = mock_user diff --git a/tests/conftest_optional.py b/tests/conftest_optional.py index 8597c4f9..a07607a7 100644 --- a/tests/conftest_optional.py +++ b/tests/conftest_optional.py @@ -1,9 +1,10 @@ import os +import shutil +import tempfile from concurrent import futures from inspect import getsourcefile from posixpath import abspath -import shutil -import tempfile +from unittest.mock import patch import cdot.hgvs.dataproviders import pytest @@ -12,15 +13,14 @@ from biocommons.seqrepo import SeqRepo from fastapi.testclient import TestClient from httpx import AsyncClient -from unittest.mock import patch -from mavedb.lib.authentication import UserData, get_current_user +from mavedb.deps import get_db, get_seqrepo, get_worker, hgvs_data_provider +from mavedb.lib.authentication import get_current_user from mavedb.lib.authorization import require_current_user +from mavedb.lib.types.authentication import UserData from mavedb.models.user import User from mavedb.server_main import app -from mavedb.deps import get_db, get_worker, hgvs_data_provider, get_seqrepo -from mavedb.worker.settings import BACKGROUND_FUNCTIONS, BACKGROUND_CRONJOBS - +from mavedb.worker.settings import BACKGROUND_CRONJOBS, BACKGROUND_FUNCTIONS from tests.helpers.constants import ADMIN_USER, EXTRA_USER, TEST_SEQREPO_INITIAL_STATE, TEST_USER #################################################################################################### diff --git a/tests/helpers/constants.py b/tests/helpers/constants.py index 1a219f17..821439c5 100644 --- a/tests/helpers/constants.py +++ b/tests/helpers/constants.py @@ -2,6 +2,7 @@ from humps import camelize +from mavedb.models.enums.functional_classification import FunctionalClassification as FunctionalClassificationOptions from mavedb.models.enums.processing_state import ProcessingState VALID_EXPERIMENT_SET_URN = "urn:mavedb:01234567" @@ -427,7 +428,7 @@ "special": False, "description": "Description", }, - {"key": "Delivery method", "label": "Other", "special": False, "description": "Description"}, + {"key": "Delivery Method", "label": "Other", "special": False, "description": "Description"}, { "key": "Phenotypic Assay Mechanism", "label": "Other", @@ -442,6 +443,13 @@ "special": False, "description": "Description", }, + { + "key": "Phenotypic Assay Profiling Strategy", + "label": "Shotgun sequencing", + "code": None, + "special": False, + "description": "Description", + }, ] TEST_KEYWORDS = [ @@ -470,7 +478,7 @@ }, }, { - "keyword": {"key": "Delivery method", "label": "Other", "special": False, "description": "Description"}, + "keyword": {"key": "Delivery Method", "label": "Other", "special": False, "description": "Description"}, "description": "Details of delivery method", }, ] @@ -492,7 +500,7 @@ "methodText": "Methods", "keywords": [ { - "keyword": {"key": "Delivery method", "label": "Other", "special": False, "description": "Description"}, + "keyword": {"key": "Delivery 
Method", "label": "Other", "special": False, "description": "Description"}, "description": "Details of delivery method", }, ], @@ -540,6 +548,7 @@ "primaryPublicationIdentifiers": [], "secondaryPublicationIdentifiers": [], "rawReadIdentifiers": [], + "externalLinks": {}, # keys to be set after receiving response "urn": None, "experimentSetUrn": None, @@ -572,7 +581,7 @@ "keywords": [ { "recordType": "ExperimentControlledKeyword", - "keyword": {"key": "Delivery method", "label": "Other", "special": False, "description": "Description"}, + "keyword": {"key": "Delivery Method", "label": "Other", "special": False, "description": "Description"}, "description": "Details of delivery method", }, ], @@ -580,6 +589,53 @@ "primaryPublicationIdentifiers": [], "secondaryPublicationIdentifiers": [], "rawReadIdentifiers": [], + "externalLinks": {}, + # keys to be set after receiving response + "urn": None, + "experimentSetUrn": None, + "officialCollections": [], + "numScoreSets": 0, # NOTE: This is context-dependent and may need overriding per test +} + +TEST_EXPERIMENT_WITH_UPDATE_KEYWORD_RESPONSE = { + "recordType": "Experiment", + "title": "Test Experiment Title", + "shortDescription": "Test experiment", + "abstractText": "Abstract", + "methodText": "Methods", + "createdBy": { + "recordType": "User", + "firstName": TEST_USER["first_name"], + "lastName": TEST_USER["last_name"], + "orcidId": TEST_USER["username"], + }, + "modifiedBy": { + "recordType": "User", + "firstName": TEST_USER["first_name"], + "lastName": TEST_USER["last_name"], + "orcidId": TEST_USER["username"], + }, + "creationDate": date.today().isoformat(), + "modificationDate": date.today().isoformat(), + "scoreSetUrns": [], + "contributors": [], + "keywords": [ + { + "recordType": "ExperimentControlledKeyword", + "keyword": { + "key": "Phenotypic Assay Profiling Strategy", + "label": "Shotgun sequencing", + "special": False, + "description": "Description", + }, + "description": "Details of phenotypic assay profiling strategy", + }, + ], + "doiIdentifiers": [], + "primaryPublicationIdentifiers": [], + "secondaryPublicationIdentifiers": [], + "rawReadIdentifiers": [], + "externalLinks": {}, # keys to be set after receiving response "urn": None, "experimentSetUrn": None, @@ -622,7 +678,7 @@ }, { "recordType": "ExperimentControlledKeyword", - "keyword": {"key": "Delivery method", "label": "Other", "special": False, "description": "Description"}, + "keyword": {"key": "Delivery Method", "label": "Other", "special": False, "description": "Description"}, "description": "Description", }, ], @@ -630,6 +686,7 @@ "primaryPublicationIdentifiers": [], "secondaryPublicationIdentifiers": [], "rawReadIdentifiers": [], + "externalLinks": {}, # keys to be set after receiving response "urn": None, "experimentSetUrn": None, @@ -1355,44 +1412,52 @@ TEST_ACMG_BS3_STRONG_CLASSIFICATION = { "criterion": "BS3", - "evidence_strength": "strong", + "evidence_strength": "STRONG", } TEST_SAVED_ACMG_BS3_STRONG_CLASSIFICATION = { "recordType": "ACMGClassification", + "creationDate": date.today().isoformat(), + "modificationDate": date.today().isoformat(), **{camelize(k): v for k, v in TEST_ACMG_BS3_STRONG_CLASSIFICATION.items()}, } TEST_ACMG_PS3_STRONG_CLASSIFICATION = { "criterion": "PS3", - "evidence_strength": "strong", + "evidence_strength": "STRONG", } TEST_SAVED_ACMG_PS3_STRONG_CLASSIFICATION = { "recordType": "ACMGClassification", + "creationDate": date.today().isoformat(), + "modificationDate": date.today().isoformat(), **{camelize(k): v for k, v in 
TEST_ACMG_PS3_STRONG_CLASSIFICATION.items()}, } TEST_ACMG_BS3_STRONG_CLASSIFICATION_WITH_POINTS = { "criterion": "BS3", - "evidence_strength": "strong", + "evidence_strength": "STRONG", "points": -4, } TEST_SAVED_ACMG_BS3_STRONG_CLASSIFICATION_WITH_POINTS = { "recordType": "ACMGClassification", + "creationDate": date.today().isoformat(), + "modificationDate": date.today().isoformat(), **{camelize(k): v for k, v in TEST_ACMG_BS3_STRONG_CLASSIFICATION_WITH_POINTS.items()}, } TEST_ACMG_PS3_STRONG_CLASSIFICATION_WITH_POINTS = { "criterion": "PS3", - "evidence_strength": "strong", + "evidence_strength": "STRONG", "points": 4, } TEST_SAVED_ACMG_PS3_STRONG_CLASSIFICATION_WITH_POINTS = { "recordType": "ACMGClassification", + "creationDate": date.today().isoformat(), + "modificationDate": date.today().isoformat(), **{camelize(k): v for k, v in TEST_ACMG_PS3_STRONG_CLASSIFICATION_WITH_POINTS.items()}, } @@ -1403,7 +1468,7 @@ TEST_FUNCTIONAL_RANGE_NORMAL = { "label": "test normal functional range", "description": "A normal functional range", - "classification": "normal", + "functional_classification": FunctionalClassificationOptions.normal.value, "range": [1.0, 5.0], "acmg_classification": TEST_ACMG_BS3_STRONG_CLASSIFICATION, "oddspaths_ratio": TEST_BS3_STRONG_ODDS_PATH_RATIO, @@ -1413,16 +1478,17 @@ TEST_SAVED_FUNCTIONAL_RANGE_NORMAL = { - "recordType": "FunctionalRange", + "recordType": "FunctionalClassification", **{camelize(k): v for k, v in TEST_FUNCTIONAL_RANGE_NORMAL.items() if k not in ("acmg_classification",)}, "acmgClassification": TEST_SAVED_ACMG_BS3_STRONG_CLASSIFICATION, + "variants": [], } TEST_FUNCTIONAL_RANGE_ABNORMAL = { "label": "test abnormal functional range", "description": "An abnormal functional range", - "classification": "abnormal", + "functional_classification": FunctionalClassificationOptions.abnormal.value, "range": [-5.0, -1.0], "acmg_classification": TEST_ACMG_PS3_STRONG_CLASSIFICATION, "oddspaths_ratio": TEST_PS3_STRONG_ODDS_PATH_RATIO, @@ -1432,15 +1498,16 @@ TEST_SAVED_FUNCTIONAL_RANGE_ABNORMAL = { - "recordType": "FunctionalRange", + "recordType": "FunctionalClassification", **{camelize(k): v for k, v in TEST_FUNCTIONAL_RANGE_ABNORMAL.items() if k not in ("acmg_classification",)}, "acmgClassification": TEST_SAVED_ACMG_PS3_STRONG_CLASSIFICATION, + "variants": [], } TEST_FUNCTIONAL_RANGE_NOT_SPECIFIED = { "label": "test not specified functional range", - "classification": "not_specified", + "functional_classification": FunctionalClassificationOptions.not_specified.value, "range": [-1.0, 1.0], "inclusive_lower_bound": True, "inclusive_upper_bound": False, @@ -1448,15 +1515,66 @@ TEST_SAVED_FUNCTIONAL_RANGE_NOT_SPECIFIED = { - "recordType": "FunctionalRange", + "recordType": "FunctionalClassification", **{camelize(k): v for k, v in TEST_FUNCTIONAL_RANGE_NOT_SPECIFIED.items()}, + "variants": [], +} + + +TEST_FUNCTIONAL_CLASSIFICATION_NORMAL = { + "label": "test normal functional class", + "description": "A normal functional class", + "functional_classification": FunctionalClassificationOptions.normal.value, + "class": "normal_class", + "acmg_classification": TEST_ACMG_BS3_STRONG_CLASSIFICATION, + "oddspaths_ratio": TEST_BS3_STRONG_ODDS_PATH_RATIO, +} + + +TEST_SAVED_FUNCTIONAL_CLASSIFICATION_NORMAL = { + "recordType": "FunctionalClassification", + **{camelize(k): v for k, v in TEST_FUNCTIONAL_CLASSIFICATION_NORMAL.items() if k not in ("acmg_classification",)}, + "acmgClassification": TEST_SAVED_ACMG_BS3_STRONG_CLASSIFICATION, + "variants": [], +} + + 
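# --------------------------------------------------------------------------------------------
# Editorial sketch (illustrative only, not part of this patch): the class-based constants in
# this hunk exercise the new `functional_classifications_must_be_of_same_type` validator added
# to src/mavedb/view_models/score_calibration.py earlier in the diff. Assuming the patch is
# applied and that ScoreCalibrationCreate accepts these test dictionaries as-is (it is
# constructed this way in tests/helpers/util/score_calibration.py), mixing range-based and
# class-based entries should be rejected:
#
#     from mavedb.view_models.score_calibration import ScoreCalibrationCreate
#
#     # Homogeneous (all range-based) classifications validate.
#     ScoreCalibrationCreate(
#         **TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, score_set_urn=VALID_SCORE_SET_URN
#     )
#
#     # Mixing a range-based entry with a class-based entry is expected to raise
#     # ValidationError ("All functional classifications within a score calibration must be
#     # of the same type ...").
#     mixed = {
#         **TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED,
#         "functional_classifications": [
#             TEST_FUNCTIONAL_RANGE_NORMAL,
#             TEST_FUNCTIONAL_CLASSIFICATION_ABNORMAL,
#         ],
#     }
#     ScoreCalibrationCreate(**mixed, score_set_urn=VALID_SCORE_SET_URN)
# --------------------------------------------------------------------------------------------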
+TEST_FUNCTIONAL_CLASSIFICATION_ABNORMAL = { + "label": "test abnormal functional class", + "description": "An abnormal functional class", + "functional_classification": FunctionalClassificationOptions.abnormal.value, + "class": "abnormal_class", + "acmg_classification": TEST_ACMG_PS3_STRONG_CLASSIFICATION, + "oddspaths_ratio": TEST_PS3_STRONG_ODDS_PATH_RATIO, +} + + +TEST_SAVED_FUNCTIONAL_CLASSIFICATION_ABNORMAL = { + "recordType": "FunctionalClassification", + **{camelize(k): v for k, v in TEST_FUNCTIONAL_CLASSIFICATION_ABNORMAL.items() if k not in ("acmg_classification",)}, + "acmgClassification": TEST_SAVED_ACMG_PS3_STRONG_CLASSIFICATION, + "variants": [], +} + + +TEST_FUNCTIONAL_CLASSIFICATION_NOT_SPECIFIED = { + "label": "test not specified functional class", + "functional_classification": FunctionalClassificationOptions.not_specified.value, + "class": "not_specified_class", +} + + +TEST_SAVED_FUNCTIONAL_CLASSIFICATION_NOT_SPECIFIED = { + "recordType": "FunctionalClassification", + **{camelize(k): v for k, v in TEST_FUNCTIONAL_CLASSIFICATION_NOT_SPECIFIED.items()}, + "variants": [], } TEST_FUNCTIONAL_RANGE_INCLUDING_NEGATIVE_INFINITY = { "label": "test functional range including negative infinity", "description": "A functional range including negative infinity", - "classification": "not_specified", + "functional_classification": FunctionalClassificationOptions.not_specified.value, "range": [None, 0.0], "inclusive_lower_bound": False, "inclusive_upper_bound": False, @@ -1464,7 +1582,7 @@ TEST_SAVED_FUNCTIONAL_RANGE_INCLUDING_NEGATIVE_INFINITY = { - "recordType": "FunctionalRange", + "recordType": "FunctionalClassification", **{camelize(k): v for k, v in TEST_FUNCTIONAL_RANGE_INCLUDING_NEGATIVE_INFINITY.items()}, } @@ -1472,7 +1590,7 @@ TEST_FUNCTIONAL_RANGE_INCLUDING_POSITIVE_INFINITY = { "label": "test functional range including positive infinity", "description": "A functional range including positive infinity", - "classification": "not_specified", + "functional_classification": FunctionalClassificationOptions.not_specified.value, "range": [0.0, None], "inclusive_lower_bound": False, "inclusive_upper_bound": False, @@ -1483,7 +1601,7 @@ "title": "Test BRNICH Score Calibration", "research_use_only": False, "investigator_provided": False, - "functional_ranges": [ + "functional_classifications": [ TEST_FUNCTIONAL_RANGE_NORMAL, TEST_FUNCTIONAL_RANGE_ABNORMAL, TEST_FUNCTIONAL_RANGE_NOT_SPECIFIED, @@ -1495,12 +1613,12 @@ } -TEST_BRNICH_SCORE_CALIBRATION = { +TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED = { "title": "Test BRNICH Score Calibration", "research_use_only": False, "baseline_score": TEST_BASELINE_SCORE, "baseline_score_description": "Test baseline score description", - "functional_ranges": [ + "functional_classifications": [ TEST_FUNCTIONAL_RANGE_NORMAL, TEST_FUNCTIONAL_RANGE_ABNORMAL, TEST_FUNCTIONAL_RANGE_NOT_SPECIFIED, @@ -1514,14 +1632,14 @@ "calibration_metadata": {}, } -TEST_SAVED_BRNICH_SCORE_CALIBRATION = { +TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED = { "recordType": "ScoreCalibration", **{ camelize(k): v - for k, v in TEST_BRNICH_SCORE_CALIBRATION.items() - if k not in ("functional_ranges", "classification_sources", "threshold_sources", "method_sources") + for k, v in TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED.items() + if k not in ("functional_classifications", "classification_sources", "threshold_sources", "method_sources") }, - "functionalRanges": [ + "functionalClassifications": [ TEST_SAVED_FUNCTIONAL_RANGE_NORMAL, TEST_SAVED_FUNCTIONAL_RANGE_ABNORMAL, 
TEST_SAVED_FUNCTIONAL_RANGE_NOT_SPECIFIED, @@ -1551,18 +1669,37 @@ "modificationDate": date.today().isoformat(), } + +TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED = { + **TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, + "functional_classifications": [ + TEST_FUNCTIONAL_CLASSIFICATION_NORMAL, + TEST_FUNCTIONAL_CLASSIFICATION_ABNORMAL, + TEST_FUNCTIONAL_CLASSIFICATION_NOT_SPECIFIED, + ], +} + +TEST_SAVED_BRNICH_SCORE_CALIBRATION_CLASS_BASED = { + **TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED, + "functionalClassifications": [ + TEST_SAVED_FUNCTIONAL_CLASSIFICATION_NORMAL, + TEST_SAVED_FUNCTIONAL_CLASSIFICATION_ABNORMAL, + TEST_SAVED_FUNCTIONAL_CLASSIFICATION_NOT_SPECIFIED, + ], +} + TEST_PATHOGENICITY_SCORE_CALIBRATION = { "title": "Test Pathogenicity Score Calibration", "research_use_only": False, "baseline_score": TEST_BASELINE_SCORE, "baseline_score_description": "Test baseline score description", - "functional_ranges": [ + "functional_classifications": [ TEST_FUNCTIONAL_RANGE_NORMAL, TEST_FUNCTIONAL_RANGE_ABNORMAL, ], "threshold_sources": [{"identifier": TEST_PUBMED_IDENTIFIER, "db_name": "PubMed"}], - "classification_sources": None, - "method_sources": None, + "classification_sources": [], + "method_sources": [], "calibration_metadata": {}, } @@ -1571,15 +1708,15 @@ **{ camelize(k): v for k, v in TEST_PATHOGENICITY_SCORE_CALIBRATION.items() - if k not in ("functional_ranges", "classification_sources", "threshold_sources", "method_sources") + if k not in ("functional_classifications", "classification_sources", "threshold_sources", "method_sources") }, - "functionalRanges": [ + "functionalClassifications": [ TEST_SAVED_FUNCTIONAL_RANGE_NORMAL, TEST_SAVED_FUNCTIONAL_RANGE_ABNORMAL, ], "thresholdSources": [SAVED_PUBMED_PUBLICATION], - "classificationSources": None, - "methodSources": None, + "classificationSources": [], + "methodSources": [], "id": 2, "investigatorProvided": True, "primary": False, diff --git a/tests/helpers/util/score_calibration.py b/tests/helpers/util/score_calibration.py index 8c432e8f..a535096c 100644 --- a/tests/helpers/util/score_calibration.py +++ b/tests/helpers/util/score_calibration.py @@ -6,16 +6,19 @@ from mavedb.models.score_calibration import ScoreCalibration from mavedb.models.user import User from mavedb.view_models.score_calibration import ScoreCalibrationCreate, ScoreCalibrationWithScoreSetUrn - -from tests.helpers.constants import TEST_BRNICH_SCORE_CALIBRATION +from tests.helpers.constants import TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED if TYPE_CHECKING: - from sqlalchemy.orm import Session from fastapi.testclient import TestClient + from sqlalchemy.orm import Session -async def create_test_score_calibration_in_score_set(db: "Session", score_set_urn: str, user: User) -> ScoreCalibration: - calibration_create = ScoreCalibrationCreate(**TEST_BRNICH_SCORE_CALIBRATION, score_set_urn=score_set_urn) +async def create_test_range_based_score_calibration_in_score_set( + db: "Session", score_set_urn: str, user: User +) -> ScoreCalibration: + calibration_create = ScoreCalibrationCreate( + **TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, score_set_urn=score_set_urn + ) created_score_calibration = await create_score_calibration_in_score_set(db, calibration_create, user) assert created_score_calibration is not None diff --git a/tests/lib/annotation/test_annotate.py b/tests/lib/annotation/test_annotate.py index 9c1846cb..3a664d7e 100644 --- a/tests/lib/annotation/test_annotate.py +++ b/tests/lib/annotation/test_annotate.py @@ -1,8 +1,10 @@ from copy import deepcopy -from 
mavedb.lib.annotation.annotate import variant_study_result -from mavedb.lib.annotation.annotate import variant_functional_impact_statement -from mavedb.lib.annotation.annotate import variant_pathogenicity_evidence +from mavedb.lib.annotation.annotate import ( + variant_functional_impact_statement, + variant_pathogenicity_evidence, + variant_study_result, +) # The contents of these results are tested elsewhere. These tests focus on object structure. @@ -81,8 +83,8 @@ def test_variant_pathogenicity_evidence_with_no_acmg_classifications( for ( calibration ) in mock_mapped_variant_with_pathogenicity_calibration_score_set.variant.score_set.score_calibrations: - calibration.functional_ranges = [ - {**deepcopy(r), "acmgClassification": None} for r in calibration.functional_ranges + calibration.functional_classifications = [ + {**deepcopy(r), "acmgClassification": None} for r in calibration.functional_classifications ] result = variant_pathogenicity_evidence(mock_mapped_variant_with_pathogenicity_calibration_score_set) diff --git a/tests/lib/annotation/test_classification.py b/tests/lib/annotation/test_classification.py index 83f2388d..bab685e7 100644 --- a/tests/lib/annotation/test_classification.py +++ b/tests/lib/annotation/test_classification.py @@ -87,7 +87,7 @@ def test_functional_classification_of_variant_without_ranges_in_primary_calibrat None, ) assert primary_cal is not None - primary_cal.functional_ranges = None + primary_cal.functional_classifications = None with pytest.raises(ValueError) as exc: functional_classification_of_variant(mock_mapped_variant_with_functional_calibration_score_set) @@ -171,7 +171,7 @@ def test_pathogenicity_classification_of_variant_without_ranges_in_primary_calib None, ) assert primary_cal is not None - primary_cal.functional_ranges = None + primary_cal.functional_classifications = None with pytest.raises(ValueError) as exc: pathogenicity_classification_of_variant(mock_mapped_variant_with_pathogenicity_calibration_score_set) @@ -194,7 +194,7 @@ def test_pathogenicity_classification_of_variant_without_acmg_classification_in_ None, ) assert primary_cal is not None - for r in primary_cal.functional_ranges: + for r in primary_cal.functional_classifications: r["acmgClassification"] = None criterion, strength = pathogenicity_classification_of_variant( @@ -217,8 +217,8 @@ def test_pathogenicity_classification_of_variant_with_invalid_evidence_strength_ None, ) assert primary_cal is not None - for r in primary_cal.functional_ranges: - r["acmgClassification"]["evidenceStrength"] = "moderate_plus" + for r in primary_cal.functional_classifications: + r["acmgClassification"]["evidenceStrength"] = "MODERATE_PLUS" r["oddspathsRatio"] = None with pytest.raises(ValueError) as exc: diff --git a/tests/lib/annotation/test_util.py b/tests/lib/annotation/test_util.py index afb19cbe..572a0489 100644 --- a/tests/lib/annotation/test_util.py +++ b/tests/lib/annotation/test_util.py @@ -1,17 +1,17 @@ from copy import deepcopy +from unittest.mock import patch + import pytest from mavedb.lib.annotation.exceptions import MappingDataDoesntExistException from mavedb.lib.annotation.util import ( - variation_from_mapped_variant, _can_annotate_variant_base_assumptions, _variant_score_calibrations_have_required_calibrations_and_ranges_for_annotation, can_annotate_variant_for_functional_statement, can_annotate_variant_for_pathogenicity_evidence, + variation_from_mapped_variant, ) - -from tests.helpers.constants import TEST_VALID_POST_MAPPED_VRS_ALLELE, TEST_SEQUENCE_LOCATION_ACCESSION -from 
unittest.mock import patch +from tests.helpers.constants import TEST_SEQUENCE_LOCATION_ACCESSION, TEST_VALID_POST_MAPPED_VRS_ALLELE @pytest.mark.parametrize( @@ -87,7 +87,7 @@ def test_score_range_check_returns_false_when_calibrations_present_with_empty_ra mock_mapped_variant = request.getfixturevalue(variant_fixture) for calibration in mock_mapped_variant.variant.score_set.score_calibrations: - calibration.functional_ranges = None + calibration.functional_classifications = None assert ( _variant_score_calibrations_have_required_calibrations_and_ranges_for_annotation(mock_mapped_variant, kind) @@ -101,11 +101,11 @@ def test_pathogenicity_range_check_returns_false_when_no_acmg_calibration( for ( calibration ) in mock_mapped_variant_with_pathogenicity_calibration_score_set.variant.score_set.score_calibrations: - acmg_classification_removed = [deepcopy(r) for r in calibration.functional_ranges] + acmg_classification_removed = [deepcopy(r) for r in calibration.functional_classifications] for fr in acmg_classification_removed: fr["acmgClassification"] = None - calibration.functional_ranges = acmg_classification_removed + calibration.functional_classifications = acmg_classification_removed assert ( _variant_score_calibrations_have_required_calibrations_and_ranges_for_annotation( @@ -121,10 +121,10 @@ def test_pathogenicity_range_check_returns_true_when_some_acmg_calibration( for ( calibration ) in mock_mapped_variant_with_pathogenicity_calibration_score_set.variant.score_set.score_calibrations: - acmg_classification_removed = [deepcopy(r) for r in calibration.functional_ranges] + acmg_classification_removed = [deepcopy(r) for r in calibration.functional_classifications] acmg_classification_removed[0]["acmgClassification"] = None - calibration.functional_ranges = acmg_classification_removed + calibration.functional_classifications = acmg_classification_removed assert ( _variant_score_calibrations_have_required_calibrations_and_ranges_for_annotation( @@ -193,7 +193,7 @@ def test_functional_range_check_returns_false_when_base_assumptions_fail(mock_ma assert result is False -def test_functional_range_check_returns_false_when_functional_ranges_check_fails(mock_mapped_variant): +def test_functional_range_check_returns_false_when_functional_classifications_check_fails(mock_mapped_variant): with patch( "mavedb.lib.annotation.util._variant_score_calibrations_have_required_calibrations_and_ranges_for_annotation", return_value=False, diff --git a/tests/lib/clingen/test_services.py b/tests/lib/clingen/test_services.py index 34828649..481c16d8 100644 --- a/tests/lib/clingen/test_services.py +++ b/tests/lib/clingen/test_services.py @@ -1,26 +1,26 @@ # ruff: noqa: E402 import os -import pytest -import requests from datetime import datetime -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch from urllib import parse +import pytest +import requests + arq = pytest.importorskip("arq") cdot = pytest.importorskip("cdot") fastapi = pytest.importorskip("fastapi") -from mavedb.lib.clingen.constants import LDH_MAVE_ACCESS_ENDPOINT, GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD -from mavedb.lib.utils import batched +from mavedb.lib.clingen.constants import GENBOREE_ACCOUNT_NAME, GENBOREE_ACCOUNT_PASSWORD, LDH_MAVE_ACCESS_ENDPOINT from mavedb.lib.clingen.services import ( ClinGenAlleleRegistryService, ClinGenLdhService, - get_clingen_variation, clingen_allele_id_from_ldh_variation, get_allele_registry_associations, + get_clingen_variation, ) - +from mavedb.lib.utils import batched 
from tests.helpers.constants import VALID_CLINGEN_CA_ID TEST_CLINGEN_URL = "https://pytest.clingen.com" @@ -332,7 +332,7 @@ def test_dispatch_submissions_failure(self, mock_auth_url, mock_put, car_service def test_get_allele_registry_associations_success(): - content_submissions = ["NM_0001:c.1A>G", "NM_0002:c.2T>C"] + content_submissions = ["NM_0001:c.1A>G", "NM_0002:c.2T>C", "NM_0003:c.3G>A"] submission_response = [ { "@id": "http://reg.test.genome.network/allele/CA123", @@ -344,9 +344,15 @@ def test_get_allele_registry_associations_success(): "genomicAlleles": [], "transcriptAlleles": [{"hgvs": "NM_0002:c.2T>C"}], }, + { + "@id": "http://reg.test.genome.network/allele/CA789", + "genomicAlleles": [], + "transcriptAlleles": [], + "aminoAcidAlleles": [{"hgvs": "NM_0003:c.3G>A"}], + }, ] result = get_allele_registry_associations(content_submissions, submission_response) - assert result == {"NM_0001:c.1A>G": "CA123", "NM_0002:c.2T>C": "CA456"} + assert result == {"NM_0001:c.1A>G": "CA123", "NM_0002:c.2T>C": "CA456", "NM_0003:c.3G>A": "CA789"} def test_get_allele_registry_associations_empty(): @@ -365,3 +371,27 @@ def test_get_allele_registry_associations_no_match(): ] result = get_allele_registry_associations(content_submissions, submission_response) assert result == {} + + +def test_get_allele_registry_associations_mixed(): + content_submissions = ["NM_0001:c.1A>G", "NM_0002:c.2T>C", "NM_0003:c.3G>A"] + submission_response = [ + { + "@id": "http://reg.test.genome.network/allele/CA123", + "genomicAlleles": [{"hgvs": "NM_0001:c.1A>G"}], + "transcriptAlleles": [], + }, + { + "errorType": "InvalidHGVS", + "hgvs": "NM_0002:c.2T>C", + "message": "The HGVS string is invalid.", + }, + { + "@id": "http://reg.test.genome.network/allele/CA789", + "genomicAlleles": [], + "transcriptAlleles": [{"hgvs": "NM_0003:c.3G>A"}], + }, + ] + + result = get_allele_registry_associations(content_submissions, submission_response) + assert result == {"NM_0001:c.1A>G": "CA123", "NM_0003:c.3G>A": "CA789"} diff --git a/tests/lib/conftest.py b/tests/lib/conftest.py index 5cffa374..c281f5eb 100644 --- a/tests/lib/conftest.py +++ b/tests/lib/conftest.py @@ -1,45 +1,51 @@ -from humps import decamelize from copy import deepcopy from datetime import datetime from pathlib import Path -import pytest from shutil import copytree from unittest import mock +import pytest +from humps import decamelize + +from mavedb.models.acmg_classification import ACMGClassification from mavedb.models.enums.user_role import UserRole -from mavedb.models.score_calibration import ScoreCalibration -from mavedb.models.experiment_set import ExperimentSet from mavedb.models.experiment import Experiment +from mavedb.models.experiment_set import ExperimentSet from mavedb.models.license import License +from mavedb.models.mapped_variant import MappedVariant from mavedb.models.publication_identifier import PublicationIdentifier -from mavedb.models.score_set_publication_identifier import ScoreSetPublicationIdentifierAssociation from mavedb.models.role import Role -from mavedb.models.taxonomy import Taxonomy +from mavedb.models.score_calibration import ScoreCalibration from mavedb.models.score_set import ScoreSet +from mavedb.models.score_set_publication_identifier import ScoreSetPublicationIdentifierAssociation +from mavedb.models.taxonomy import Taxonomy from mavedb.models.user import User from mavedb.models.variant import Variant -from mavedb.models.mapped_variant import MappedVariant from tests.helpers.constants import ( ADMIN_USER, EXTRA_USER, + 
TEST_ACMG_BS3_STRONG_CLASSIFICATION, + TEST_ACMG_BS3_STRONG_CLASSIFICATION_WITH_POINTS, + TEST_ACMG_PS3_STRONG_CLASSIFICATION, + TEST_ACMG_PS3_STRONG_CLASSIFICATION_WITH_POINTS, TEST_EXPERIMENT, TEST_EXPERIMENT_SET, - TEST_LICENSE, TEST_INACTIVE_LICENSE, + TEST_LICENSE, TEST_MAVEDB_ATHENA_ROW, TEST_MINIMAL_MAPPED_VARIANT, TEST_MINIMAL_VARIANT, + TEST_PUBMED_IDENTIFIER, + TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED, + TEST_SAVED_PATHOGENICITY_SCORE_CALIBRATION, TEST_SAVED_TAXONOMY, TEST_SEQ_SCORESET, TEST_USER, - TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, TEST_VALID_POST_MAPPED_VRS_ALLELE_VRS2_X, - VALID_SCORE_SET_URN, - VALID_EXPERIMENT_URN, + TEST_VALID_PRE_MAPPED_VRS_ALLELE_VRS2_X, VALID_EXPERIMENT_SET_URN, - TEST_SAVED_BRNICH_SCORE_CALIBRATION, - TEST_SAVED_PATHOGENICITY_SCORE_CALIBRATION, - TEST_PUBMED_IDENTIFIER, + VALID_EXPERIMENT_URN, + VALID_SCORE_SET_URN, ) @@ -56,6 +62,10 @@ def setup_lib_db(session): db.add(Taxonomy(**TEST_SAVED_TAXONOMY)) db.add(License(**TEST_LICENSE)) db.add(License(**TEST_INACTIVE_LICENSE)) + db.add(ACMGClassification(**TEST_ACMG_PS3_STRONG_CLASSIFICATION)) + db.add(ACMGClassification(**TEST_ACMG_BS3_STRONG_CLASSIFICATION)) + db.add(ACMGClassification(**TEST_ACMG_BS3_STRONG_CLASSIFICATION_WITH_POINTS)) + db.add(ACMGClassification(**TEST_ACMG_PS3_STRONG_CLASSIFICATION_WITH_POINTS)) db.commit() @@ -177,7 +187,7 @@ def mock_experiment(): def mock_functional_calibration(mock_user): calibration = mock.Mock(spec=ScoreCalibration) - for key, value in TEST_SAVED_BRNICH_SCORE_CALIBRATION.items(): + for key, value in TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED.items(): setattr(calibration, decamelize(key), deepcopy(value)) calibration.primary = True # Ensure functional calibration is primary for tests diff --git a/tests/lib/permissions/__init__.py b/tests/lib/permissions/__init__.py new file mode 100644 index 00000000..78b319a5 --- /dev/null +++ b/tests/lib/permissions/__init__.py @@ -0,0 +1 @@ +"""Tests for the modular permissions system.""" diff --git a/tests/lib/permissions/conftest.py b/tests/lib/permissions/conftest.py new file mode 100644 index 00000000..302159f5 --- /dev/null +++ b/tests/lib/permissions/conftest.py @@ -0,0 +1,196 @@ +"""Shared fixtures and helpers for permissions tests.""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union +from unittest.mock import Mock + +import pytest + +from mavedb.models.enums.contribution_role import ContributionRole +from mavedb.models.enums.user_role import UserRole + +if TYPE_CHECKING: + from mavedb.lib.permissions.actions import Action + + +@dataclass +class PermissionTest: + """Represents a single permission test case for action handler testing. + + Used for parametrized testing of individual action handlers (_handle_read_action, etc.) + rather than comprehensive end-to-end permission testing. + + Args: + entity_type: Entity type name for context (not used in handler tests) + entity_state: "private" or "published" (None for stateless entities like User) + user_type: "admin", "owner", "contributor", "other_user", "anonymous", "self" + action: Action enum value (for documentation, handlers test specific actions) + should_be_permitted: True/False for normal cases, "NotImplementedError" for unsupported + expected_code: HTTP error code when denied (403, 404, 401, etc.) 
+ description: Human-readable test description + collection_role: For Collection tests: "collection_admin", "collection_editor", "collection_viewer" + investigator_provided: For ScoreCalibration tests: True=investigator, False=community + """ + + entity_type: str + entity_state: Optional[str] + user_type: str + action: "Action" + should_be_permitted: Union[bool, str] + expected_code: Optional[int] = None + description: Optional[str] = None + collection_role: Optional[str] = None + collection_badge: Optional[str] = None + investigator_provided: Optional[bool] = None + + +class EntityTestHelper: + """Helper class to create test entities and user data with consistent properties.""" + + @staticmethod + def create_user_data(user_type: str): + """Create UserData mock for different user types. + + Args: + user_type: "admin", "owner", "contributor", "other_user", "anonymous", "self", "mapper" + + Returns: + Mock UserData object or None for anonymous users + """ + user_configs = { + "admin": (1, "1111-1111-1111-111X", [UserRole.admin]), + "owner": (2, "2222-2222-2222-222X", []), + "contributor": (3, "3333-3333-3333-333X", []), + "other_user": (4, "4444-4444-4444-444X", []), + "self": (5, "5555-5555-5555-555X", []), + "mapper": (6, "6666-6666-6666-666X", [UserRole.mapper]), + } + + if user_type == "anonymous": + return None + + if user_type not in user_configs: + raise ValueError(f"Unknown user type: {user_type}") + + user_id, username, roles = user_configs[user_type] + return Mock(user=Mock(id=user_id, username=username), active_roles=roles) + + @staticmethod + def create_score_set(entity_state: str = "private", owner_id: int = 2): + """Create a ScoreSet mock for testing.""" + private = entity_state == "private" + published_date = None if private else "2023-01-01" + contributors = [Mock(orcid_id="3333-3333-3333-333X")] + + return Mock( + id=1, + urn="urn:mavedb:00000001-a-1", + private=private, + created_by_id=owner_id, + published_date=published_date, + contributors=contributors, + ) + + @staticmethod + def create_experiment(entity_state: str = "private", owner_id: int = 2): + """Create an Experiment mock for testing.""" + private = entity_state == "private" + published_date = None if private else "2023-01-01" + contributors = [Mock(orcid_id="3333-3333-3333-333X")] + + return Mock( + id=1, + urn="urn:mavedb:00000001-a", + private=private, + created_by_id=owner_id, + published_date=published_date, + contributors=contributors, + ) + + @staticmethod + def create_experiment_set(entity_state: str = "private", owner_id: int = 2): + """Create an ExperimentSet mock for testing.""" + private = entity_state == "private" + published_date = None if private else "2023-01-01" + contributors = [Mock(orcid_id="3333-3333-3333-333X")] + + return Mock( + id=1, + urn="urn:mavedb:00000001", + private=private, + created_by_id=owner_id, + published_date=published_date, + contributors=contributors, + ) + + @staticmethod + def create_collection( + entity_state: str = "private", + owner_id: int = 2, + collection_role: Optional[str] = None, + badge_name: Optional[str] = None, + ): + """Create a Collection mock for testing. 
+ + Args: + entity_state: "private" or "published" + owner_id: ID of the collection owner + collection_role: "collection_admin", "collection_editor", or "collection_viewer" + to create user association for contributor user (ID=3) + """ + private = entity_state == "private" + published_date = None if private else "2023-01-01" + + user_associations = [] + if collection_role: + role_map = { + "collection_admin": ContributionRole.admin, + "collection_editor": ContributionRole.editor, + "collection_viewer": ContributionRole.viewer, + } + user_associations.append(Mock(user_id=3, contribution_role=role_map[collection_role])) + + return Mock( + id=1, + urn="urn:mavedb:collection-001", + private=private, + created_by_id=owner_id, + published_date=published_date, + user_associations=user_associations, + badge_name=badge_name, + ) + + @staticmethod + def create_user(user_id: int = 5): + """Create a User mock for testing.""" + return Mock( + id=user_id, + username=f"{user_id}{user_id}{user_id}{user_id}-{user_id}{user_id}{user_id}{user_id}-{user_id}{user_id}{user_id}{user_id}-{user_id}{user_id}{user_id}X", + ) + + @staticmethod + def create_score_calibration(entity_state: str = "private", investigator_provided: bool = False): + """Create a ScoreCalibration mock for testing. + + Args: + entity_state: "private" or "published" (affects score_set and private property) + investigator_provided: True if investigator-provided, False if community-provided + """ + private = entity_state == "private" + score_set = EntityTestHelper.create_score_set(entity_state) + + # ScoreCalibrations have their own private property plus associated ScoreSet + return Mock( + id=1, + private=private, + score_set=score_set, + investigator_provided=investigator_provided, + created_by_id=2, # owner + modified_by_id=2, # owner + ) + + +@pytest.fixture +def entity_helper(): + """Fixture providing EntityTestHelper instance.""" + return EntityTestHelper() diff --git a/tests/lib/permissions/test_collection.py b/tests/lib/permissions/test_collection.py new file mode 100644 index 00000000..ab0593bb --- /dev/null +++ b/tests/lib/permissions/test_collection.py @@ -0,0 +1,732 @@ +# ruff: noqa: E402 + +"""Tests for Collection permissions module.""" + +import pytest + +pytest.importorskip("fastapi", reason="Skipping permissions tests; FastAPI is required but not installed.") + +from typing import Callable, List +from unittest import mock + +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.collection import ( + _handle_add_badge_action, + _handle_add_experiment_action, + _handle_add_role_action, + _handle_add_score_set_action, + _handle_delete_action, + _handle_publish_action, + _handle_read_action, + _handle_update_action, + has_permission, +) +from mavedb.models.enums.contribution_role import ContributionRole +from mavedb.models.enums.user_role import UserRole +from tests.lib.permissions.conftest import EntityTestHelper, PermissionTest + +COLLECTION_SUPPORTED_ACTIONS: dict[Action, Callable] = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.DELETE: _handle_delete_action, + Action.PUBLISH: _handle_publish_action, + Action.ADD_EXPERIMENT: _handle_add_experiment_action, + Action.ADD_SCORE_SET: _handle_add_score_set_action, + Action.ADD_ROLE: _handle_add_role_action, + Action.ADD_BADGE: _handle_add_badge_action, +} + +COLLECTION_UNSUPPORTED_ACTIONS: List[Action] = [ + Action.LOOKUP, + Action.CHANGE_RANK, + Action.SET_SCORES, +] + +COLLECTION_ROLE_MAP = { + "collection_admin": 
ContributionRole.admin, + "collection_editor": ContributionRole.editor, + "collection_viewer": ContributionRole.viewer, +} + + +def test_collection_handles_all_actions() -> None: + """Test that all Collection actions are either supported or explicitly unsupported.""" + all_actions = set(action for action in Action) + supported = set(COLLECTION_SUPPORTED_ACTIONS) + unsupported = set(COLLECTION_UNSUPPORTED_ACTIONS) + + assert ( + supported.union(unsupported) == all_actions + ), "Some actions are not categorized as supported or unsupported for collections." + + +class TestCollectionHasPermission: + """Test the main has_permission dispatcher function for Collection entities.""" + + @pytest.mark.parametrize("action, handler", COLLECTION_SUPPORTED_ACTIONS.items()) + def test_supported_actions_route_to_correct_action_handler( + self, entity_helper: EntityTestHelper, action: Action, handler: Callable + ) -> None: + """Test that has_permission routes supported actions to their handlers.""" + collection = entity_helper.create_collection() + admin_user = entity_helper.create_user_data("admin") + + with mock.patch("mavedb.lib.permissions.collection." + handler.__name__, wraps=handler) as mock_handler: + has_permission(admin_user, collection, action) + mock_handler.assert_called_once_with( + admin_user, + collection, + collection.private, + collection.badge_name is not None, + False, # admin is not the owner + [], # admin has no collection roles + [UserRole.admin], + ) + + def test_has_permission_calls_helper_with_collection_roles_when_present(self, entity_helper: EntityTestHelper): + """Test that has_permission passes collection roles to action handlers.""" + collection = entity_helper.create_collection(collection_role="collection_editor") + contributor_user = entity_helper.create_user_data("contributor") + + with mock.patch( + "mavedb.lib.permissions.collection._handle_read_action", wraps=_handle_read_action + ) as mock_handler: + has_permission(contributor_user, collection, Action.READ) + mock_handler.assert_called_once_with( + contributor_user, + collection, + collection.private, + collection.badge_name is not None, + False, # contributor is not the owner + [ContributionRole.editor], # collection role + [], # user has no active roles + ) + + @pytest.mark.parametrize("action", COLLECTION_UNSUPPORTED_ACTIONS) + def test_raises_for_unsupported_actions(self, entity_helper: EntityTestHelper, action: Action) -> None: + """Test that unsupported actions raise NotImplementedError with descriptive message.""" + collection = entity_helper.create_collection() + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(NotImplementedError) as exc_info: + has_permission(admin_user, collection, action) + + error_msg = str(exc_info.value) + assert action.value in error_msg + assert all(a.value in error_msg for a in COLLECTION_SUPPORTED_ACTIONS) + + def test_requires_private_attribute(self, entity_helper: EntityTestHelper) -> None: + """Test that ValueError is raised if Collection.private is None.""" + collection = entity_helper.create_collection() + collection.private = None + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(ValueError) as exc_info: + has_permission(admin_user, collection, Action.READ) + + assert "private" in str(exc_info.value) + + +class TestCollectionReadActionHandler: + """Test the _handle_read_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins can read any Collection + PermissionTest("Collection", 
"published", "admin", Action.READ, True), + PermissionTest("Collection", "private", "admin", Action.READ, True), + # Owners can read any Collection they own + PermissionTest("Collection", "published", "owner", Action.READ, True), + PermissionTest("Collection", "private", "owner", Action.READ, True), + # Collection admins can read any Collection they have admin role for + PermissionTest( + "Collection", "published", "contributor", Action.READ, True, collection_role="collection_admin" + ), + PermissionTest( + "Collection", "private", "contributor", Action.READ, True, collection_role="collection_admin" + ), + # Collection editors can read any Collection they have editor role for + PermissionTest( + "Collection", "published", "contributor", Action.READ, True, collection_role="collection_editor" + ), + PermissionTest( + "Collection", "private", "contributor", Action.READ, True, collection_role="collection_editor" + ), + # Collection viewers can read any Collection they have viewer role for + PermissionTest( + "Collection", "published", "contributor", Action.READ, True, collection_role="collection_viewer" + ), + PermissionTest( + "Collection", "private", "contributor", Action.READ, True, collection_role="collection_viewer" + ), + # Other users can only read published Collections + PermissionTest("Collection", "published", "other_user", Action.READ, True), + PermissionTest("Collection", "private", "other_user", Action.READ, False, 404), + # Anonymous users can only read published Collections + PermissionTest("Collection", "published", "anonymous", Action.READ, True), + PermissionTest("Collection", "private", "anonymous", Action.READ, False, 404), + ], + ids=lambda tc: f"{tc.user_type}_{tc.collection_role if tc.collection_role else 'no_role'}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_read_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_read_action helper function directly.""" + assert test_case.entity_state is not None, "Collection tests must have entity_state" + collection = entity_helper.create_collection(test_case.entity_state, collection_role=test_case.collection_role) + user_data = entity_helper.create_user_data(test_case.user_type) + + # Determine user relationship to entity + private = test_case.entity_state == "private" + official_collection = collection.badge_name is not None + user_is_owner = test_case.user_type == "owner" + collection_roles = [COLLECTION_ROLE_MAP[test_case.collection_role]] if test_case.collection_role else [] + active_roles = user_data.active_roles if user_data else [] + + # Test the helper function directly + result = _handle_read_action( + user_data, collection, private, official_collection, user_is_owner, collection_roles, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestCollectionUpdateActionHandler: + """Test the _handle_update_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins can update any Collection + PermissionTest("Collection", "private", "admin", Action.UPDATE, True), + PermissionTest("Collection", "published", "admin", Action.UPDATE, True), + # Owners can update any Collection they own + PermissionTest("Collection", "private", "owner", Action.UPDATE, True), + PermissionTest("Collection", "published", "owner", 
Action.UPDATE, True), + # Collection admins can update any Collection they have admin role for + PermissionTest( + "Collection", "private", "contributor", Action.UPDATE, True, collection_role="collection_admin" + ), + PermissionTest( + "Collection", "published", "contributor", Action.UPDATE, True, collection_role="collection_admin" + ), + # Collection editors can update any Collection they have editor role for + PermissionTest( + "Collection", "private", "contributor", Action.UPDATE, True, collection_role="collection_editor" + ), + PermissionTest( + "Collection", "published", "contributor", Action.UPDATE, True, collection_role="collection_editor" + ), + # Collection viewers cannot update Collections + PermissionTest( + "Collection", "private", "contributor", Action.UPDATE, False, 403, collection_role="collection_viewer" + ), + PermissionTest( + "Collection", "published", "contributor", Action.UPDATE, False, 403, collection_role="collection_viewer" + ), + # Other users cannot update Collections + PermissionTest("Collection", "private", "other_user", Action.UPDATE, False, 404), + PermissionTest("Collection", "published", "other_user", Action.UPDATE, False, 403), + # Anonymous users cannot update Collections + PermissionTest("Collection", "private", "anonymous", Action.UPDATE, False, 404), + PermissionTest("Collection", "published", "anonymous", Action.UPDATE, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.collection_role if tc.collection_role else 'no_role'}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_update_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_update_action helper function directly.""" + assert test_case.entity_state is not None, "Collection tests must have entity_state" + collection = entity_helper.create_collection(test_case.entity_state, collection_role=test_case.collection_role) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + official_collection = collection.badge_name is not None + user_is_owner = test_case.user_type == "owner" + collection_roles = [COLLECTION_ROLE_MAP[test_case.collection_role]] if test_case.collection_role else [] + active_roles = user_data.active_roles if user_data else [] + + result = _handle_update_action( + user_data, collection, private, official_collection, user_is_owner, collection_roles, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestCollectionDeleteActionHandler: + """Test the _handle_delete_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins can delete any Collection + PermissionTest("Collection", "private", "admin", Action.DELETE, True), + PermissionTest("Collection", "published", "admin", Action.DELETE, True), + PermissionTest("Collection", "private", "admin", Action.DELETE, True, collection_badge="official"), + PermissionTest("Collection", "published", "admin", Action.DELETE, True, collection_badge="official"), + # Owners can only delete unpublished, unofficial Collections + PermissionTest("Collection", "private", "owner", Action.DELETE, True), + PermissionTest("Collection", "published", "owner", Action.DELETE, False, 403), + PermissionTest("Collection", "private", "owner", Action.DELETE, False, 403, collection_badge="official"), + 
PermissionTest("Collection", "published", "owner", Action.DELETE, False, 403, collection_badge="official"), + # Collection admins cannot delete Collections + PermissionTest( + "Collection", "private", "contributor", Action.DELETE, False, 403, collection_role="collection_admin" + ), + PermissionTest( + "Collection", "published", "contributor", Action.DELETE, False, 403, collection_role="collection_admin" + ), + # Collection editors cannot delete Collections + PermissionTest( + "Collection", "private", "contributor", Action.DELETE, False, 403, collection_role="collection_editor" + ), + PermissionTest( + "Collection", "published", "contributor", Action.DELETE, False, 403, collection_role="collection_editor" + ), + # Collection viewers cannot delete Collections + PermissionTest( + "Collection", "private", "contributor", Action.DELETE, False, 403, collection_role="collection_viewer" + ), + PermissionTest( + "Collection", "published", "contributor", Action.DELETE, False, 403, collection_role="collection_viewer" + ), + # Other users cannot delete Collections + PermissionTest("Collection", "private", "other_user", Action.DELETE, False, 404), + PermissionTest("Collection", "published", "other_user", Action.DELETE, False, 403), + # Anonymous users cannot delete Collections + PermissionTest("Collection", "private", "anonymous", Action.DELETE, False, 404), + PermissionTest("Collection", "published", "anonymous", Action.DELETE, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.collection_role if tc.collection_role else 'no_role'}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_delete_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_delete_action helper function directly.""" + assert test_case.entity_state is not None, "Collection tests must have entity_state" + collection = entity_helper.create_collection( + test_case.entity_state, collection_role=test_case.collection_role, badge_name=test_case.collection_badge + ) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + official_collection = collection.badge_name is not None + user_is_owner = test_case.user_type == "owner" + collection_roles = [COLLECTION_ROLE_MAP[test_case.collection_role]] if test_case.collection_role else [] + active_roles = user_data.active_roles if user_data else [] + + result = _handle_delete_action( + user_data, collection, private, official_collection, user_is_owner, collection_roles, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestCollectionPublishActionHandler: + """Test the _handle_publish_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins can publish any Collection + PermissionTest("Collection", "private", "admin", Action.PUBLISH, True), + PermissionTest("Collection", "published", "admin", Action.PUBLISH, True), + # Owners can publish any Collection they own + PermissionTest("Collection", "private", "owner", Action.PUBLISH, True), + PermissionTest("Collection", "published", "owner", Action.PUBLISH, True), + # Collection admins can publish any Collection they have admin role for + PermissionTest( + "Collection", "private", "contributor", Action.PUBLISH, True, collection_role="collection_admin" + ), + PermissionTest( + 
"Collection", "published", "contributor", Action.PUBLISH, True, collection_role="collection_admin" + ), + # Collection editors cannot publish Collections + PermissionTest( + "Collection", "private", "contributor", Action.PUBLISH, False, 403, collection_role="collection_editor" + ), + PermissionTest( + "Collection", + "published", + "contributor", + Action.PUBLISH, + False, + 403, + collection_role="collection_editor", + ), + # Collection viewers cannot publish Collections + PermissionTest( + "Collection", "private", "contributor", Action.PUBLISH, False, 403, collection_role="collection_viewer" + ), + PermissionTest( + "Collection", + "published", + "contributor", + Action.PUBLISH, + False, + 403, + collection_role="collection_viewer", + ), + # Other users cannot publish Collections + PermissionTest("Collection", "private", "other_user", Action.PUBLISH, False, 404), + PermissionTest("Collection", "published", "other_user", Action.PUBLISH, False, 403), + # Anonymous users cannot publish Collections + PermissionTest("Collection", "private", "anonymous", Action.PUBLISH, False, 404), + PermissionTest("Collection", "published", "anonymous", Action.PUBLISH, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.collection_role if tc.collection_role else 'no_role'}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_publish_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_publish_action helper function directly.""" + assert test_case.entity_state is not None, "Collection tests must have entity_state" + collection = entity_helper.create_collection(test_case.entity_state, collection_role=test_case.collection_role) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + official_collection = collection.badge_name is not None + user_is_owner = test_case.user_type == "owner" + collection_roles = [COLLECTION_ROLE_MAP[test_case.collection_role]] if test_case.collection_role else [] + active_roles = user_data.active_roles if user_data else [] + + result = _handle_publish_action( + user_data, collection, private, official_collection, user_is_owner, collection_roles, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestCollectionAddExperimentActionHandler: + """Test the _handle_add_experiment_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins can add experiments to any Collection + PermissionTest("Collection", "private", "admin", Action.ADD_EXPERIMENT, True), + PermissionTest("Collection", "published", "admin", Action.ADD_EXPERIMENT, True), + # Owners can add experiments to any Collection they own + PermissionTest("Collection", "private", "owner", Action.ADD_EXPERIMENT, True), + PermissionTest("Collection", "published", "owner", Action.ADD_EXPERIMENT, True), + # Collection admins can add experiments to any Collection they have admin role for + PermissionTest( + "Collection", "private", "contributor", Action.ADD_EXPERIMENT, True, collection_role="collection_admin" + ), + PermissionTest( + "Collection", + "published", + "contributor", + Action.ADD_EXPERIMENT, + True, + collection_role="collection_admin", + ), + # Collection editors can add experiments to any Collection they have editor role for + PermissionTest( + "Collection", 
"private", "contributor", Action.ADD_EXPERIMENT, True, collection_role="collection_editor" + ), + PermissionTest( + "Collection", + "published", + "contributor", + Action.ADD_EXPERIMENT, + True, + collection_role="collection_editor", + ), + # Collection viewers cannot add experiments to Collections + PermissionTest( + "Collection", + "private", + "contributor", + Action.ADD_EXPERIMENT, + False, + 403, + collection_role="collection_viewer", + ), + PermissionTest( + "Collection", + "published", + "contributor", + Action.ADD_EXPERIMENT, + False, + 403, + collection_role="collection_viewer", + ), + # Other users cannot add experiments to Collections + PermissionTest("Collection", "private", "other_user", Action.ADD_EXPERIMENT, False, 404), + PermissionTest("Collection", "published", "other_user", Action.ADD_EXPERIMENT, False, 403), + # Anonymous users cannot add experiments to Collections + PermissionTest("Collection", "private", "anonymous", Action.ADD_EXPERIMENT, False, 404), + PermissionTest("Collection", "published", "anonymous", Action.ADD_EXPERIMENT, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.collection_role if tc.collection_role else 'no_role'}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_add_experiment_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_add_experiment_action helper function directly.""" + assert test_case.entity_state is not None, "Collection tests must have entity_state" + collection = entity_helper.create_collection(test_case.entity_state, collection_role=test_case.collection_role) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + official_collection = collection.badge_name is not None + user_is_owner = test_case.user_type == "owner" + collection_roles = [COLLECTION_ROLE_MAP[test_case.collection_role]] if test_case.collection_role else [] + active_roles = user_data.active_roles if user_data else [] + + result = _handle_add_experiment_action( + user_data, collection, private, official_collection, user_is_owner, collection_roles, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestCollectionAddScoreSetActionHandler: + """Test the _handle_add_score_set_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins can add score sets to any Collection + PermissionTest("Collection", "private", "admin", Action.ADD_SCORE_SET, True), + PermissionTest("Collection", "published", "admin", Action.ADD_SCORE_SET, True), + # Owners can add score sets to any Collection they own + PermissionTest("Collection", "private", "owner", Action.ADD_SCORE_SET, True), + PermissionTest("Collection", "published", "owner", Action.ADD_SCORE_SET, True), + # Collection admins can add score sets to any Collection they have admin role for + PermissionTest( + "Collection", "private", "contributor", Action.ADD_SCORE_SET, True, collection_role="collection_admin" + ), + PermissionTest( + "Collection", "published", "contributor", Action.ADD_SCORE_SET, True, collection_role="collection_admin" + ), + # Collection editors can add score sets to any Collection they have editor role for + PermissionTest( + "Collection", "private", "contributor", Action.ADD_SCORE_SET, True, collection_role="collection_editor" + ), + 
PermissionTest( + "Collection", + "published", + "contributor", + Action.ADD_SCORE_SET, + True, + collection_role="collection_editor", + ), + # Collection viewers cannot add score sets to Collections + PermissionTest( + "Collection", + "private", + "contributor", + Action.ADD_SCORE_SET, + False, + 403, + collection_role="collection_viewer", + ), + PermissionTest( + "Collection", + "published", + "contributor", + Action.ADD_SCORE_SET, + False, + 403, + collection_role="collection_viewer", + ), + # Other users cannot add score sets to Collections + PermissionTest("Collection", "private", "other_user", Action.ADD_SCORE_SET, False, 404), + PermissionTest("Collection", "published", "other_user", Action.ADD_SCORE_SET, False, 403), + # Anonymous users cannot add score sets to Collections + PermissionTest("Collection", "private", "anonymous", Action.ADD_SCORE_SET, False, 404), + PermissionTest("Collection", "published", "anonymous", Action.ADD_SCORE_SET, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.collection_role if tc.collection_role else 'no_role'}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_add_score_set_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_add_score_set_action helper function directly.""" + assert test_case.entity_state is not None, "Collection tests must have entity_state" + collection = entity_helper.create_collection(test_case.entity_state, collection_role=test_case.collection_role) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + official_collection = collection.badge_name is not None + user_is_owner = test_case.user_type == "owner" + collection_roles = [COLLECTION_ROLE_MAP[test_case.collection_role]] if test_case.collection_role else [] + active_roles = user_data.active_roles if user_data else [] + + result = _handle_add_score_set_action( + user_data, collection, private, official_collection, user_is_owner, collection_roles, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestCollectionAddRoleActionHandler: + """Test the _handle_add_role_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins can add roles to any Collection + PermissionTest("Collection", "private", "admin", Action.ADD_ROLE, True), + PermissionTest("Collection", "published", "admin", Action.ADD_ROLE, True), + # Owners can add roles to any Collection they own + PermissionTest("Collection", "private", "owner", Action.ADD_ROLE, True), + PermissionTest("Collection", "published", "owner", Action.ADD_ROLE, True), + # Collection admins can add roles to any Collection they have admin role for + PermissionTest( + "Collection", "private", "contributor", Action.ADD_ROLE, True, collection_role="collection_admin" + ), + PermissionTest( + "Collection", "published", "contributor", Action.ADD_ROLE, True, collection_role="collection_admin" + ), + # Collection editors cannot add roles to Collections + PermissionTest( + "Collection", "private", "contributor", Action.ADD_ROLE, False, 403, collection_role="collection_editor" + ), + PermissionTest( + "Collection", + "published", + "contributor", + Action.ADD_ROLE, + False, + 403, + collection_role="collection_editor", + ), + # Collection viewers cannot add roles to Collections + 
PermissionTest( + "Collection", "private", "contributor", Action.ADD_ROLE, False, 403, collection_role="collection_viewer" + ), + PermissionTest( + "Collection", + "published", + "contributor", + Action.ADD_ROLE, + False, + 403, + collection_role="collection_viewer", + ), + # Other users cannot add roles to Collections + PermissionTest("Collection", "private", "other_user", Action.ADD_ROLE, False, 404), + PermissionTest("Collection", "published", "other_user", Action.ADD_ROLE, False, 403), + # Anonymous users cannot add roles to Collections + PermissionTest("Collection", "private", "anonymous", Action.ADD_ROLE, False, 404), + PermissionTest("Collection", "published", "anonymous", Action.ADD_ROLE, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.collection_role if tc.collection_role else 'no_role'}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_add_role_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_add_role_action helper function directly.""" + assert test_case.entity_state is not None, "Collection tests must have entity_state" + collection = entity_helper.create_collection(test_case.entity_state, collection_role=test_case.collection_role) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + official_collection = collection.badge_name is not None + user_is_owner = test_case.user_type == "owner" + collection_roles = [COLLECTION_ROLE_MAP[test_case.collection_role]] if test_case.collection_role else [] + active_roles = user_data.active_roles if user_data else [] + + result = _handle_add_role_action( + user_data, collection, private, official_collection, user_is_owner, collection_roles, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestCollectionAddBadgeActionHandler: + """Test the _handle_add_badge_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins can add badges to any Collection + PermissionTest("Collection", "private", "admin", Action.ADD_BADGE, True), + PermissionTest("Collection", "published", "admin", Action.ADD_BADGE, True), + # Owners cannot add badges to Collections (admin-only operation) + PermissionTest("Collection", "private", "owner", Action.ADD_BADGE, False, 403), + PermissionTest("Collection", "published", "owner", Action.ADD_BADGE, False, 403), + # Collection admins cannot add badges to Collections (system admin-only) + PermissionTest( + "Collection", "private", "contributor", Action.ADD_BADGE, False, 403, collection_role="collection_admin" + ), + PermissionTest( + "Collection", + "published", + "contributor", + Action.ADD_BADGE, + False, + 403, + collection_role="collection_admin", + ), + # Collection editors cannot add badges to Collections + PermissionTest( + "Collection", + "private", + "contributor", + Action.ADD_BADGE, + False, + 403, + collection_role="collection_editor", + ), + PermissionTest( + "Collection", + "published", + "contributor", + Action.ADD_BADGE, + False, + 403, + collection_role="collection_editor", + ), + # Collection viewers cannot add badges to Collections + PermissionTest( + "Collection", + "private", + "contributor", + Action.ADD_BADGE, + False, + 403, + collection_role="collection_viewer", + ), + PermissionTest( + "Collection", + "published", + 
"contributor", + Action.ADD_BADGE, + False, + 403, + collection_role="collection_viewer", + ), + # Other users cannot add badges to Collections + PermissionTest("Collection", "private", "other_user", Action.ADD_BADGE, False, 404), + PermissionTest("Collection", "published", "other_user", Action.ADD_BADGE, False, 403), + # Anonymous users cannot add badges to Collections + PermissionTest("Collection", "private", "anonymous", Action.ADD_BADGE, False, 404), + PermissionTest("Collection", "published", "anonymous", Action.ADD_BADGE, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.collection_role if tc.collection_role else 'no_role'}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_add_badge_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_add_badge_action helper function directly.""" + assert test_case.entity_state is not None, "Collection tests must have entity_state" + collection = entity_helper.create_collection(test_case.entity_state, collection_role=test_case.collection_role) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + official_collection = collection.badge_name is not None + user_is_owner = test_case.user_type == "owner" + collection_roles = [COLLECTION_ROLE_MAP[test_case.collection_role]] if test_case.collection_role else [] + active_roles = user_data.active_roles if user_data else [] + + result = _handle_add_badge_action( + user_data, collection, private, official_collection, user_is_owner, collection_roles, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code diff --git a/tests/lib/permissions/test_core.py b/tests/lib/permissions/test_core.py new file mode 100644 index 00000000..55a99107 --- /dev/null +++ b/tests/lib/permissions/test_core.py @@ -0,0 +1,132 @@ +# ruff: noqa: E402 + +"""Tests for core permissions functionality.""" + +import pytest + +pytest.importorskip("fastapi", reason="Skipping permissions tests; FastAPI is required but not installed.") + +from unittest.mock import Mock, patch + +from mavedb.lib.permissions import ( + assert_permission, + collection, + experiment, + experiment_set, + score_calibration, + score_set, + user, +) +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.core import has_permission as core_has_permission +from mavedb.lib.permissions.exceptions import PermissionException +from mavedb.lib.permissions.models import PermissionResponse +from mavedb.models.collection import Collection +from mavedb.models.experiment import Experiment +from mavedb.models.experiment_set import ExperimentSet +from mavedb.models.score_calibration import ScoreCalibration +from mavedb.models.score_set import ScoreSet +from mavedb.models.user import User + +SUPPORTED_ENTITY_TYPES = { + ScoreSet: score_set.has_permission, + Experiment: experiment.has_permission, + ExperimentSet: experiment_set.has_permission, + Collection: collection.has_permission, + User: user.has_permission, + ScoreCalibration: score_calibration.has_permission, +} + + +class TestCoreDispatcher: + """Test the core permission dispatcher functionality.""" + + @pytest.mark.parametrize("entity, handler", SUPPORTED_ENTITY_TYPES.items()) + def test_dispatcher_routes_to_correct_entity_handler(self, entity_helper, entity, handler): + """Test that the 
dispatcher routes requests to the correct entity-specific handler.""" + admin_user = entity_helper.create_user_data("admin") + + with ( + patch("mavedb.lib.permissions.core.type", return_value=entity), + patch( + f"mavedb.lib.permissions.core.{handler.__module__.split('.')[-1]}.{handler.__name__}", + return_value=PermissionResponse(True), + ) as mocked_handler, + ): + core_has_permission(admin_user, entity, Action.READ) + mocked_handler.assert_called_once_with(admin_user, entity, Action.READ) + + def test_dispatcher_raises_for_unsupported_entity_type(self, entity_helper): + """Test that unsupported entity types raise NotImplementedError.""" + admin_user = entity_helper.create_user_data("admin") + unsupported_entity = Mock() # An arbitrary object with no registered permission handler + + with pytest.raises(NotImplementedError) as exc_info: + core_has_permission(admin_user, unsupported_entity, Action.READ) + + error_msg = str(exc_info.value) + assert "not implemented" in error_msg.lower() + assert "Mock" in error_msg # Should mention the actual type + assert "Supported entity types" in error_msg + + +class TestAssertPermission: + """Test the assert_permission function.""" + + def test_assert_permission_returns_result_when_permitted(self, entity_helper): + """Test that assert_permission returns the PermissionResponse when access is granted.""" + + with patch("mavedb.lib.permissions.core.has_permission", return_value=PermissionResponse(True)): + user_data = entity_helper.create_user_data("admin") + score_set = entity_helper.create_score_set("published") + + result = assert_permission(user_data, score_set, Action.READ) + + assert isinstance(result, PermissionResponse) + assert result.permitted is True + + def test_assert_permission_raises_when_denied(self, entity_helper): + """Test that assert_permission raises PermissionException when access is denied.""" + + with ( + patch( + "mavedb.lib.permissions.core.has_permission", + return_value=PermissionResponse(False, http_code=404, message="Not found"), + ), + pytest.raises(PermissionException) as exc_info, + ): + user_data = entity_helper.create_user_data("admin") + score_set = entity_helper.create_score_set("published") + + assert_permission(user_data, score_set, Action.READ) + + exception = exc_info.value + assert hasattr(exception, "http_code") + assert hasattr(exception, "message") + assert exception.http_code == 404 + assert "not found" in exception.message.lower() + + @pytest.mark.parametrize( + "http_code,message", + [ + (403, "Forbidden"), + (401, "Unauthorized"), + (404, "Not Found"), + ], + ) + def test_assert_permission_preserves_error_details(self, entity_helper, http_code, message): + """Test that assert_permission preserves HTTP codes and messages from the permission check.""" + + with ( + patch( + "mavedb.lib.permissions.core.has_permission", + return_value=PermissionResponse(False, http_code=http_code, message=message), + ), + pytest.raises(PermissionException) as exc_info, + ): + user_data = entity_helper.create_user_data("admin") + score_set = entity_helper.create_score_set("published") + + assert_permission(user_data, score_set, Action.READ) + + assert exc_info.value.http_code == http_code, f"Expected HTTP code {http_code} ('{message}') to be preserved" diff --git a/tests/lib/permissions/test_experiment.py b/tests/lib/permissions/test_experiment.py new file mode 100644 index 00000000..b4e5dc24 --- /dev/null +++ b/tests/lib/permissions/test_experiment.py @@ -0,0 +1,280 @@ +# ruff: noqa: E402 + +"""Tests for Experiment permissions module.""" + +import pytest + 
+pytest.importorskip("fastapi", reason="Skipping permissions tests; FastAPI is required but not installed.") + +from typing import Callable, List +from unittest import mock + +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.experiment import ( + _handle_add_score_set_action, + _handle_delete_action, + _handle_read_action, + _handle_update_action, + has_permission, +) +from mavedb.models.enums.user_role import UserRole +from tests.lib.permissions.conftest import EntityTestHelper, PermissionTest + +EXPERIMENT_SUPPORTED_ACTIONS: dict[Action, Callable] = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.DELETE: _handle_delete_action, + Action.ADD_SCORE_SET: _handle_add_score_set_action, +} + +EXPERIMENT_UNSUPPORTED_ACTIONS: List[Action] = [ + Action.ADD_EXPERIMENT, + Action.ADD_ROLE, + Action.LOOKUP, + Action.ADD_BADGE, + Action.CHANGE_RANK, + Action.SET_SCORES, + Action.PUBLISH, +] + + +def test_experiment_handles_all_actions() -> None: + """Test that all Experiment actions are either supported or explicitly unsupported.""" + all_actions = set(action for action in Action) + supported = set(EXPERIMENT_SUPPORTED_ACTIONS) + unsupported = set(EXPERIMENT_UNSUPPORTED_ACTIONS) + + assert ( + supported.union(unsupported) == all_actions + ), "Some actions are not categorized as supported or unsupported for experiments." + + +class TestExperimentHasPermission: + """Test the main has_permission dispatcher function for Experiment entities.""" + + @pytest.mark.parametrize("action, handler", EXPERIMENT_SUPPORTED_ACTIONS.items()) + def test_supported_actions_route_to_correct_action_handler( + self, entity_helper: EntityTestHelper, action: Action, handler: Callable + ) -> None: + """Test that has_permission routes supported actions to their handlers.""" + experiment = entity_helper.create_experiment() + admin_user = entity_helper.create_user_data("admin") + + with mock.patch("mavedb.lib.permissions.experiment." 
+ handler.__name__, wraps=handler) as mock_handler: + has_permission(admin_user, experiment, action) + mock_handler.assert_called_once_with( + admin_user, + experiment, + experiment.private, + False, # admin is not the owner + False, # admin is not a contributor + [UserRole.admin], + ) + + @pytest.mark.parametrize("action", EXPERIMENT_UNSUPPORTED_ACTIONS) + def test_raises_for_unsupported_actions(self, entity_helper: EntityTestHelper, action: Action) -> None: + """Test that unsupported actions raise NotImplementedError with descriptive message.""" + experiment = entity_helper.create_experiment() + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(NotImplementedError) as exc_info: + has_permission(admin_user, experiment, action) + + error_msg = str(exc_info.value) + assert action.value in error_msg + assert all(a.value in error_msg for a in EXPERIMENT_SUPPORTED_ACTIONS) + + def test_requires_private_attribute(self, entity_helper: EntityTestHelper) -> None: + """Test that ValueError is raised if Experiment.private is None.""" + experiment = entity_helper.create_experiment() + experiment.private = None + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(ValueError) as exc_info: + has_permission(admin_user, experiment, Action.READ) + + assert "private" in str(exc_info.value) + + +class TestExperimentReadActionHandler: + """Test the _handle_read_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can read any Experiment + PermissionTest("Experiment", "published", "admin", Action.READ, True), + PermissionTest("Experiment", "private", "admin", Action.READ, True), + # Owners can read any Experiment they own + PermissionTest("Experiment", "published", "owner", Action.READ, True), + PermissionTest("Experiment", "private", "owner", Action.READ, True), + # Contributors can read any Experiment they contribute to + PermissionTest("Experiment", "published", "contributor", Action.READ, True), + PermissionTest("Experiment", "private", "contributor", Action.READ, True), + # Mappers can read any Experiment (including private) + PermissionTest("Experiment", "published", "mapper", Action.READ, True), + PermissionTest("Experiment", "private", "mapper", Action.READ, True), + # Other users can only read published Experiments + PermissionTest("Experiment", "published", "other_user", Action.READ, True), + PermissionTest("Experiment", "private", "other_user", Action.READ, False, 404), + # Anonymous users can only read published Experiments + PermissionTest("Experiment", "published", "anonymous", Action.READ, True), + PermissionTest("Experiment", "private", "anonymous", Action.READ, False, 404), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_read_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_read_action helper function directly.""" + assert test_case.entity_state is not None, "Experiment tests must have entity_state" + experiment = entity_helper.create_experiment(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + # Determine user relationship to entity + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + # Test the helper function directly + result = 
_handle_read_action(user_data, experiment, private, user_is_owner, user_is_contributor, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestExperimentUpdateActionHandler: + """Test the _handle_update_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can update any Experiment + PermissionTest("Experiment", "private", "admin", Action.UPDATE, True), + PermissionTest("Experiment", "published", "admin", Action.UPDATE, True), + # Owners can update any Experiment they own + PermissionTest("Experiment", "private", "owner", Action.UPDATE, True), + PermissionTest("Experiment", "published", "owner", Action.UPDATE, True), + # Contributors can update any Experiment they contribute to + PermissionTest("Experiment", "private", "contributor", Action.UPDATE, True), + PermissionTest("Experiment", "published", "contributor", Action.UPDATE, True), + # Mappers cannot update Experiments + PermissionTest("Experiment", "private", "mapper", Action.UPDATE, False, 404), + PermissionTest("Experiment", "published", "mapper", Action.UPDATE, False, 403), + # Other users cannot update Experiments + PermissionTest("Experiment", "private", "other_user", Action.UPDATE, False, 404), + PermissionTest("Experiment", "published", "other_user", Action.UPDATE, False, 403), + # Anonymous users cannot update Experiments + PermissionTest("Experiment", "private", "anonymous", Action.UPDATE, False, 404), + PermissionTest("Experiment", "published", "anonymous", Action.UPDATE, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_update_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_update_action helper function directly.""" + assert test_case.entity_state is not None, "Experiment tests must have entity_state" + experiment = entity_helper.create_experiment(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_update_action(user_data, experiment, private, user_is_owner, user_is_contributor, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestExperimentDeleteActionHandler: + """Test the _handle_delete_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can delete any Experiment + PermissionTest("Experiment", "private", "admin", Action.DELETE, True), + PermissionTest("Experiment", "published", "admin", Action.DELETE, True), + # Owners can only delete unpublished Experiments + PermissionTest("Experiment", "private", "owner", Action.DELETE, True), + PermissionTest("Experiment", "published", "owner", Action.DELETE, False, 403), + # Contributors cannot delete + PermissionTest("Experiment", "private", "contributor", Action.DELETE, False, 403), + PermissionTest("Experiment", "published", "contributor", Action.DELETE, False, 403), + # Other users cannot delete + PermissionTest("Experiment", "private", 
"other_user", Action.DELETE, False, 404), + PermissionTest("Experiment", "published", "other_user", Action.DELETE, False, 403), + # Anonymous users cannot delete + PermissionTest("Experiment", "private", "anonymous", Action.DELETE, False, 404), + PermissionTest("Experiment", "published", "anonymous", Action.DELETE, False, 401), + # Mappers cannot delete + PermissionTest("Experiment", "private", "mapper", Action.DELETE, False, 404), + PermissionTest("Experiment", "published", "mapper", Action.DELETE, False, 403), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_delete_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_delete_action helper function directly.""" + assert test_case.entity_state is not None, "Experiment tests must have entity_state" + experiment = entity_helper.create_experiment(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_delete_action(user_data, experiment, private, user_is_owner, user_is_contributor, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestExperimentAddScoreSetActionHandler: + """Test the _handle_add_score_set_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can add score sets to any Experiment + PermissionTest("Experiment", "private", "admin", Action.ADD_SCORE_SET, True), + PermissionTest("Experiment", "published", "admin", Action.ADD_SCORE_SET, True), + # Owners can add score sets to any Experiment they own + PermissionTest("Experiment", "private", "owner", Action.ADD_SCORE_SET, True), + PermissionTest("Experiment", "published", "owner", Action.ADD_SCORE_SET, True), + # Contributors can add score sets to any Experiment they contribute to + PermissionTest("Experiment", "private", "contributor", Action.ADD_SCORE_SET, True), + PermissionTest("Experiment", "published", "contributor", Action.ADD_SCORE_SET, True), + # Mappers can add score sets to public Experiments + PermissionTest("Experiment", "private", "mapper", Action.ADD_SCORE_SET, False, 404), + PermissionTest("Experiment", "published", "mapper", Action.ADD_SCORE_SET, True), + # Other users can add score sets to public Experiments + PermissionTest("Experiment", "private", "other_user", Action.ADD_SCORE_SET, False, 404), + PermissionTest("Experiment", "published", "other_user", Action.ADD_SCORE_SET, True), + # Anonymous users cannot add score sets to Experiments + PermissionTest("Experiment", "private", "anonymous", Action.ADD_SCORE_SET, False, 404), + PermissionTest("Experiment", "published", "anonymous", Action.ADD_SCORE_SET, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_add_score_set_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_add_score_set_action helper function directly.""" + assert test_case.entity_state is not None, "Experiment tests must have entity_state" + experiment = 
entity_helper.create_experiment(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_add_score_set_action( + user_data, experiment, private, user_is_owner, user_is_contributor, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code diff --git a/tests/lib/permissions/test_experiment_set.py b/tests/lib/permissions/test_experiment_set.py new file mode 100644 index 00000000..adf109fb --- /dev/null +++ b/tests/lib/permissions/test_experiment_set.py @@ -0,0 +1,286 @@ +# ruff: noqa: E402 + +"""Tests for ExperimentSet permissions module.""" + +import pytest + +pytest.importorskip("fastapi", reason="Skipping permissions tests; FastAPI is required but not installed.") + +from typing import Callable, List +from unittest import mock + +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.experiment_set import ( + _handle_add_experiment_action, + _handle_delete_action, + _handle_read_action, + _handle_update_action, + has_permission, +) +from mavedb.models.enums.user_role import UserRole +from tests.lib.permissions.conftest import EntityTestHelper, PermissionTest + +EXPERIMENT_SET_SUPPORTED_ACTIONS: dict[Action, Callable] = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.DELETE: _handle_delete_action, + Action.ADD_EXPERIMENT: _handle_add_experiment_action, +} + +EXPERIMENT_SET_UNSUPPORTED_ACTIONS: List[Action] = [ + Action.ADD_SCORE_SET, + Action.ADD_ROLE, + Action.LOOKUP, + Action.ADD_BADGE, + Action.CHANGE_RANK, + Action.SET_SCORES, + Action.PUBLISH, +] + + +def test_experiment_set_handles_all_actions() -> None: + """Test that all ExperimentSet actions are either supported or explicitly unsupported.""" + all_actions = set(action for action in Action) + supported = set(EXPERIMENT_SET_SUPPORTED_ACTIONS) + unsupported = set(EXPERIMENT_SET_UNSUPPORTED_ACTIONS) + + assert ( + supported.union(unsupported) == all_actions + ), "Some actions are not categorized as supported or unsupported for experiment sets." + + +class TestExperimentSetHasPermission: + """Test the main has_permission dispatcher function for ExperimentSet entities.""" + + @pytest.mark.parametrize("action, handler", EXPERIMENT_SET_SUPPORTED_ACTIONS.items()) + def test_supported_actions_route_to_correct_action_handler( + self, entity_helper: EntityTestHelper, action: Action, handler: Callable + ) -> None: + """Test that has_permission routes supported actions to their handlers.""" + experiment_set = entity_helper.create_experiment_set() + admin_user = entity_helper.create_user_data("admin") + + with mock.patch("mavedb.lib.permissions.experiment_set." 
+ handler.__name__, wraps=handler) as mock_handler: + has_permission(admin_user, experiment_set, action) + mock_handler.assert_called_once_with( + admin_user, + experiment_set, + experiment_set.private, + False, # admin is not the owner + False, # admin is not a contributor + [UserRole.admin], + ) + + @pytest.mark.parametrize("action", EXPERIMENT_SET_UNSUPPORTED_ACTIONS) + def test_raises_for_unsupported_actions(self, entity_helper: EntityTestHelper, action: Action) -> None: + """Test that unsupported actions raise NotImplementedError with descriptive message.""" + experiment_set = entity_helper.create_experiment_set() + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(NotImplementedError) as exc_info: + has_permission(admin_user, experiment_set, action) + + error_msg = str(exc_info.value) + assert action.value in error_msg + assert all(a.value in error_msg for a in EXPERIMENT_SET_SUPPORTED_ACTIONS) + + def test_requires_private_attribute(self, entity_helper: EntityTestHelper) -> None: + """Test that ValueError is raised if ExperimentSet.private is None.""" + experiment_set = entity_helper.create_experiment_set() + experiment_set.private = None + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(ValueError) as exc_info: + has_permission(admin_user, experiment_set, Action.READ) + + assert "private" in str(exc_info.value) + + +class TestExperimentSetReadActionHandler: + """Test the _handle_read_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can read any ExperimentSet + PermissionTest("ExperimentSet", "published", "admin", Action.READ, True), + PermissionTest("ExperimentSet", "private", "admin", Action.READ, True), + # Owners can read any ExperimentSet they own + PermissionTest("ExperimentSet", "published", "owner", Action.READ, True), + PermissionTest("ExperimentSet", "private", "owner", Action.READ, True), + # Contributors can read any ExperimentSet they contribute to + PermissionTest("ExperimentSet", "published", "contributor", Action.READ, True), + PermissionTest("ExperimentSet", "private", "contributor", Action.READ, True), + # Mappers can read any ExperimentSet (including private) + PermissionTest("ExperimentSet", "published", "mapper", Action.READ, True), + PermissionTest("ExperimentSet", "private", "mapper", Action.READ, True), + # Other users can only read published ExperimentSets + PermissionTest("ExperimentSet", "published", "other_user", Action.READ, True), + PermissionTest("ExperimentSet", "private", "other_user", Action.READ, False, 404), + # Anonymous users can only read published ExperimentSets + PermissionTest("ExperimentSet", "published", "anonymous", Action.READ, True), + PermissionTest("ExperimentSet", "private", "anonymous", Action.READ, False, 404), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_read_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_read_action helper function directly.""" + assert test_case.entity_state is not None, "ExperimentSet tests must have entity_state" + experiment_set = entity_helper.create_experiment_set(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + # Determine user relationship to entity + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" 
+ active_roles = user_data.active_roles if user_data else [] + + # Test the helper function directly + result = _handle_read_action( + user_data, experiment_set, private, user_is_owner, user_is_contributor, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestExperimentSetUpdateActionHandler: + """Test the _handle_update_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can update any ExperimentSet + PermissionTest("ExperimentSet", "private", "admin", Action.UPDATE, True), + PermissionTest("ExperimentSet", "published", "admin", Action.UPDATE, True), + # Owners can update any ExperimentSet they own + PermissionTest("ExperimentSet", "private", "owner", Action.UPDATE, True), + PermissionTest("ExperimentSet", "published", "owner", Action.UPDATE, True), + # Contributors can update any ExperimentSet they contribute to + PermissionTest("ExperimentSet", "private", "contributor", Action.UPDATE, True), + PermissionTest("ExperimentSet", "published", "contributor", Action.UPDATE, True), + # Mappers cannot update ExperimentSets + PermissionTest("ExperimentSet", "private", "mapper", Action.UPDATE, False, 404), + PermissionTest("ExperimentSet", "published", "mapper", Action.UPDATE, False, 403), + # Other users cannot update ExperimentSets + PermissionTest("ExperimentSet", "private", "other_user", Action.UPDATE, False, 404), + PermissionTest("ExperimentSet", "published", "other_user", Action.UPDATE, False, 403), + # Anonymous users cannot update ExperimentSets + PermissionTest("ExperimentSet", "private", "anonymous", Action.UPDATE, False, 404), + PermissionTest("ExperimentSet", "published", "anonymous", Action.UPDATE, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_update_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_update_action helper function directly.""" + assert test_case.entity_state is not None, "ExperimentSet tests must have entity_state" + experiment_set = entity_helper.create_experiment_set(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_update_action( + user_data, experiment_set, private, user_is_owner, user_is_contributor, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestExperimentSetDeleteActionHandler: + """Test the _handle_delete_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can delete any ExperimentSet + PermissionTest("ExperimentSet", "private", "admin", Action.DELETE, True), + PermissionTest("ExperimentSet", "published", "admin", Action.DELETE, True), + # Owners can only delete unpublished ExperimentSets + PermissionTest("ExperimentSet", "private", "owner", Action.DELETE, True), + PermissionTest("ExperimentSet", "published", "owner", Action.DELETE, False, 403), + # Contributors cannot delete + PermissionTest("ExperimentSet", 
"private", "contributor", Action.DELETE, False, 403), + PermissionTest("ExperimentSet", "published", "contributor", Action.DELETE, False, 403), + # Other users cannot delete + PermissionTest("ExperimentSet", "private", "other_user", Action.DELETE, False, 404), + PermissionTest("ExperimentSet", "published", "other_user", Action.DELETE, False, 403), + # Anonymous users cannot delete + PermissionTest("ExperimentSet", "private", "anonymous", Action.DELETE, False, 404), + PermissionTest("ExperimentSet", "published", "anonymous", Action.DELETE, False, 401), + # Mappers cannot delete + PermissionTest("ExperimentSet", "private", "mapper", Action.DELETE, False, 404), + PermissionTest("ExperimentSet", "published", "mapper", Action.DELETE, False, 403), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_delete_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_delete_action helper function directly.""" + assert test_case.entity_state is not None, "ExperimentSet tests must have entity_state" + experiment_set = entity_helper.create_experiment_set(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_delete_action( + user_data, experiment_set, private, user_is_owner, user_is_contributor, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestExperimentSetAddExperimentActionHandler: + """Test the _handle_add_experiment_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can add experiments to any ExperimentSet + PermissionTest("ExperimentSet", "private", "admin", Action.ADD_EXPERIMENT, True), + PermissionTest("ExperimentSet", "published", "admin", Action.ADD_EXPERIMENT, True), + # Owners can add experiments to any ExperimentSet they own + PermissionTest("ExperimentSet", "private", "owner", Action.ADD_EXPERIMENT, True), + PermissionTest("ExperimentSet", "published", "owner", Action.ADD_EXPERIMENT, True), + # Contributors can add experiments to any ExperimentSet they contribute to + PermissionTest("ExperimentSet", "private", "contributor", Action.ADD_EXPERIMENT, True), + PermissionTest("ExperimentSet", "published", "contributor", Action.ADD_EXPERIMENT, True), + # Mappers cannot add experiments to ExperimentSets + PermissionTest("ExperimentSet", "private", "mapper", Action.ADD_EXPERIMENT, False, 404), + PermissionTest("ExperimentSet", "published", "mapper", Action.ADD_EXPERIMENT, False, 403), + # Other users cannot add experiments to ExperimentSets + PermissionTest("ExperimentSet", "private", "other_user", Action.ADD_EXPERIMENT, False, 404), + PermissionTest("ExperimentSet", "published", "other_user", Action.ADD_EXPERIMENT, False, 403), + # Anonymous users cannot add experiments to ExperimentSets + PermissionTest("ExperimentSet", "private", "anonymous", Action.ADD_EXPERIMENT, False, 404), + PermissionTest("ExperimentSet", "published", "anonymous", Action.ADD_EXPERIMENT, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if 
tc.should_be_permitted else 'denied'}", + ) + def test_handle_add_experiment_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_add_experiment_action helper function directly.""" + assert test_case.entity_state is not None, "ExperimentSet tests must have entity_state" + experiment_set = entity_helper.create_experiment_set(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_add_experiment_action( + user_data, experiment_set, private, user_is_owner, user_is_contributor, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code diff --git a/tests/lib/permissions/test_models.py b/tests/lib/permissions/test_models.py new file mode 100644 index 00000000..7627d56a --- /dev/null +++ b/tests/lib/permissions/test_models.py @@ -0,0 +1,45 @@ +# ruff: noqa: E402 + +"""Tests for permissions models module.""" + +import pytest + +pytest.importorskip("fastapi", reason="Skipping permissions tests; FastAPI is required but not installed.") + +from mavedb.lib.permissions.models import PermissionResponse + + +class TestPermissionResponse: + """Test the PermissionResponse class.""" + + def test_permitted_response_creation(self): + """Test creating a PermissionResponse for permitted access.""" + response = PermissionResponse(permitted=True) + + assert response.permitted is True + assert response.http_code is None + assert response.message is None + + def test_denied_response_creation_with_defaults(self): + """Test creating a PermissionResponse for denied access with default values.""" + response = PermissionResponse(permitted=False) + + assert response.permitted is False + assert response.http_code == 403 + assert response.message is None + + def test_denied_response_creation_with_custom_values(self): + """Test creating a PermissionResponse for denied access with custom values.""" + response = PermissionResponse(permitted=False, http_code=404, message="Resource not found") + + assert response.permitted is False + assert response.http_code == 404 + assert response.message == "Resource not found" + + def test_permitted_response_ignores_error_parameters(self): + """Test that permitted responses ignore http_code and message parameters.""" + response = PermissionResponse(permitted=True, http_code=404, message="This should be ignored") + + assert response.permitted is True + assert response.http_code is None + assert response.message is None diff --git a/tests/lib/permissions/test_score_calibration.py b/tests/lib/permissions/test_score_calibration.py new file mode 100644 index 00000000..a3384368 --- /dev/null +++ b/tests/lib/permissions/test_score_calibration.py @@ -0,0 +1,554 @@ +# ruff: noqa: E402 + +"""Tests for ScoreCalibration permissions module.""" + +import pytest + +pytest.importorskip("fastapi", reason="Skipping permissions tests; FastAPI is required but not installed.") + +from typing import Callable, List +from unittest import mock + +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.score_calibration import ( + _handle_change_rank_action, + _handle_delete_action, + _handle_publish_action, + _handle_read_action, + 
_handle_update_action, + has_permission, +) +from mavedb.models.enums.user_role import UserRole +from tests.lib.permissions.conftest import EntityTestHelper, PermissionTest + +SCORE_CALIBRATION_SUPPORTED_ACTIONS: dict[Action, Callable] = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.DELETE: _handle_delete_action, + Action.PUBLISH: _handle_publish_action, + Action.CHANGE_RANK: _handle_change_rank_action, +} + +SCORE_CALIBRATION_UNSUPPORTED_ACTIONS: List[Action] = [ + Action.ADD_EXPERIMENT, + Action.ADD_SCORE_SET, + Action.ADD_ROLE, + Action.LOOKUP, + Action.ADD_BADGE, + Action.SET_SCORES, +] + + +def test_score_calibration_handles_all_actions() -> None: + """Test that all ScoreCalibration actions are either supported or explicitly unsupported.""" + all_actions = set(action for action in Action) + supported = set(SCORE_CALIBRATION_SUPPORTED_ACTIONS) + unsupported = set(SCORE_CALIBRATION_UNSUPPORTED_ACTIONS) + + assert ( + supported.union(unsupported) == all_actions + ), "Some actions are not categorized as supported or unsupported for score calibrations." + + +class TestScoreCalibrationHasPermission: + """Test the main has_permission dispatcher function for ScoreCalibration entities.""" + + @pytest.mark.parametrize("action, handler", SCORE_CALIBRATION_SUPPORTED_ACTIONS.items()) + def test_supported_actions_route_to_correct_action_handler( + self, entity_helper: EntityTestHelper, action: Action, handler: Callable + ) -> None: + """Test that has_permission routes supported actions to their handlers.""" + score_calibration = entity_helper.create_score_calibration() + admin_user = entity_helper.create_user_data("admin") + + with mock.patch("mavedb.lib.permissions.score_calibration." + handler.__name__, wraps=handler) as mock_handler: + has_permission(admin_user, score_calibration, action) + mock_handler.assert_called_once_with( + admin_user, + score_calibration, + False, # admin is not the owner + False, # admin is not a contributor to score set + score_calibration.private, + [UserRole.admin], + ) + + @pytest.mark.parametrize("action", SCORE_CALIBRATION_UNSUPPORTED_ACTIONS) + def test_raises_for_unsupported_actions(self, entity_helper: EntityTestHelper, action: Action) -> None: + """Test that unsupported actions raise NotImplementedError with descriptive message.""" + score_calibration = entity_helper.create_score_calibration() + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(NotImplementedError) as exc_info: + has_permission(admin_user, score_calibration, action) + + error_msg = str(exc_info.value) + assert action.value in error_msg + assert all(a.value in error_msg for a in SCORE_CALIBRATION_SUPPORTED_ACTIONS) + + def test_requires_private_attribute(self, entity_helper: EntityTestHelper) -> None: + """Test that ValueError is raised if ScoreCalibration.private is None.""" + score_calibration = entity_helper.create_score_calibration() + score_calibration.private = None + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(ValueError) as exc_info: + has_permission(admin_user, score_calibration, Action.READ) + + assert "private" in str(exc_info.value) + + +class TestScoreCalibrationReadActionHandler: + """Test the _handle_read_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins: Can read any ScoreCalibration regardless of state or investigator_provided flag + PermissionTest("ScoreCalibration", "published", "admin", Action.READ, True, 
investigator_provided=True), + PermissionTest("ScoreCalibration", "published", "admin", Action.READ, True, investigator_provided=False), + PermissionTest("ScoreCalibration", "private", "admin", Action.READ, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "private", "admin", Action.READ, True, investigator_provided=False), + # Owners: Can read any ScoreCalibration they created regardless of state or investigator_provided flag + PermissionTest("ScoreCalibration", "published", "owner", Action.READ, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "published", "owner", Action.READ, True, investigator_provided=False), + PermissionTest("ScoreCalibration", "private", "owner", Action.READ, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "private", "owner", Action.READ, True, investigator_provided=False), + # Contributors to associated ScoreSet: Can read published ScoreCalibrations (any type) and private investigator-provided ScoreCalibrations, but NOT private community-provided ones + PermissionTest( + "ScoreCalibration", "published", "contributor", Action.READ, True, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "contributor", Action.READ, True, investigator_provided=False + ), + PermissionTest("ScoreCalibration", "private", "contributor", Action.READ, True, investigator_provided=True), + PermissionTest( + "ScoreCalibration", "private", "contributor", Action.READ, False, 404, investigator_provided=False + ), + # Other users: Can only read published ScoreCalibrations, cannot access any private ones + PermissionTest( + "ScoreCalibration", "published", "other_user", Action.READ, True, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "other_user", Action.READ, True, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "private", "other_user", Action.READ, False, 404, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "other_user", Action.READ, False, 404, investigator_provided=False + ), + # Anonymous users: Can only read published ScoreCalibrations, cannot access any private ones + PermissionTest("ScoreCalibration", "published", "anonymous", Action.READ, True, investigator_provided=True), + PermissionTest( + "ScoreCalibration", "published", "anonymous", Action.READ, True, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "private", "anonymous", Action.READ, False, 404, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "anonymous", Action.READ, False, 404, investigator_provided=False + ), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{'investigator' if tc.investigator_provided else 'community'}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_read_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_read_action helper function directly.""" + assert test_case.entity_state is not None, "ScoreCalibration tests must have entity_state" + assert test_case.investigator_provided is not None, "ScoreCalibration tests must have investigator_provided" + score_calibration = entity_helper.create_score_calibration( + test_case.entity_state, test_case.investigator_provided + ) + user_data = entity_helper.create_user_data(test_case.user_type) + + # Determine user relationship to entity + private = test_case.entity_state == "private" + user_is_owner = 
test_case.user_type == "owner" + user_is_contributor_to_score_set = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + # Test the helper function directly + result = _handle_read_action( + user_data, score_calibration, user_is_owner, user_is_contributor_to_score_set, private, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestScoreCalibrationUpdateActionHandler: + """Test the _handle_update_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins: Can update any ScoreCalibration regardless of state or investigator_provided flag + PermissionTest("ScoreCalibration", "private", "admin", Action.UPDATE, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "private", "admin", Action.UPDATE, True, investigator_provided=False), + PermissionTest("ScoreCalibration", "published", "admin", Action.UPDATE, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "published", "admin", Action.UPDATE, True, investigator_provided=False), + # Owners: Can update only their own private ScoreCalibrations, cannot update published ones (even their own) + PermissionTest("ScoreCalibration", "private", "owner", Action.UPDATE, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "private", "owner", Action.UPDATE, True, investigator_provided=False), + PermissionTest( + "ScoreCalibration", "published", "owner", Action.UPDATE, False, 403, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "owner", Action.UPDATE, False, 403, investigator_provided=False + ), + # Contributors to associated ScoreSet: Can update only private investigator-provided ScoreCalibrations, cannot update community-provided or published ones + PermissionTest( + "ScoreCalibration", "private", "contributor", Action.UPDATE, True, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "contributor", Action.UPDATE, False, 404, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "contributor", Action.UPDATE, False, 403, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "contributor", Action.UPDATE, False, 403, investigator_provided=False + ), + # Other users: Cannot update any ScoreCalibrations + PermissionTest( + "ScoreCalibration", "private", "other_user", Action.UPDATE, False, 404, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "other_user", Action.UPDATE, False, 404, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "other_user", Action.UPDATE, False, 403, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "other_user", Action.UPDATE, False, 403, investigator_provided=False + ), + # Anonymous users: Cannot update any ScoreCalibrations + PermissionTest( + "ScoreCalibration", "private", "anonymous", Action.UPDATE, False, 404, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "anonymous", Action.UPDATE, False, 404, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "anonymous", Action.UPDATE, False, 401, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "anonymous", Action.UPDATE, False, 401, 
investigator_provided=False + ), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{'investigator' if tc.investigator_provided else 'community'}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_update_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_update_action helper function directly.""" + assert test_case.entity_state is not None, "ScoreCalibration tests must have entity_state" + assert test_case.investigator_provided is not None, "ScoreCalibration tests must have investigator_provided" + score_calibration = entity_helper.create_score_calibration( + test_case.entity_state, test_case.investigator_provided + ) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor_to_score_set = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_update_action( + user_data, score_calibration, user_is_owner, user_is_contributor_to_score_set, private, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestScoreCalibrationDeleteActionHandler: + """Test the _handle_delete_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins: Can delete any ScoreCalibration regardless of state or investigator_provided flag + PermissionTest("ScoreCalibration", "private", "admin", Action.DELETE, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "private", "admin", Action.DELETE, True, investigator_provided=False), + PermissionTest("ScoreCalibration", "published", "admin", Action.DELETE, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "published", "admin", Action.DELETE, True, investigator_provided=False), + # Owners: Can delete only their own private ScoreCalibrations, cannot delete published ones (even their own) + PermissionTest("ScoreCalibration", "private", "owner", Action.DELETE, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "private", "owner", Action.DELETE, True, investigator_provided=False), + PermissionTest( + "ScoreCalibration", "published", "owner", Action.DELETE, False, 403, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "owner", Action.DELETE, False, 403, investigator_provided=False + ), + # Contributors to associated ScoreSet: Cannot delete any ScoreCalibrations (even investigator-provided ones they can read/update) + PermissionTest( + "ScoreCalibration", "private", "contributor", Action.DELETE, False, 403, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "contributor", Action.DELETE, False, 404, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "contributor", Action.DELETE, False, 403, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "contributor", Action.DELETE, False, 403, investigator_provided=False + ), + # Other users: Cannot delete any ScoreCalibrations + PermissionTest( + "ScoreCalibration", "private", "other_user", Action.DELETE, False, 404, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "other_user", Action.DELETE, False, 404, 
investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "other_user", Action.DELETE, False, 403, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "other_user", Action.DELETE, False, 403, investigator_provided=False + ), + # Anonymous users: Cannot delete any ScoreCalibrations + PermissionTest( + "ScoreCalibration", "private", "anonymous", Action.DELETE, False, 404, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "anonymous", Action.DELETE, False, 404, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "anonymous", Action.DELETE, False, 401, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "anonymous", Action.DELETE, False, 401, investigator_provided=False + ), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{'investigator' if tc.investigator_provided else 'community'}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_delete_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_delete_action helper function directly.""" + assert test_case.entity_state is not None, "ScoreCalibration tests must have entity_state" + assert test_case.investigator_provided is not None, "ScoreCalibration tests must have investigator_provided" + score_calibration = entity_helper.create_score_calibration( + test_case.entity_state, test_case.investigator_provided + ) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor_to_score_set = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_delete_action( + user_data, score_calibration, user_is_owner, user_is_contributor_to_score_set, private, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestScoreCalibrationPublishActionHandler: + """Test the _handle_publish_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins: Can publish any ScoreCalibration regardless of state or investigator_provided flag + PermissionTest("ScoreCalibration", "private", "admin", Action.PUBLISH, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "private", "admin", Action.PUBLISH, True, investigator_provided=False), + PermissionTest("ScoreCalibration", "published", "admin", Action.PUBLISH, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "published", "admin", Action.PUBLISH, True, investigator_provided=False), + # Owners: Can publish their own ScoreCalibrations regardless of state or investigator_provided flag + PermissionTest("ScoreCalibration", "private", "owner", Action.PUBLISH, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "private", "owner", Action.PUBLISH, True, investigator_provided=False), + PermissionTest("ScoreCalibration", "published", "owner", Action.PUBLISH, True, investigator_provided=True), + PermissionTest("ScoreCalibration", "published", "owner", Action.PUBLISH, True, investigator_provided=False), + # Contributors to associated ScoreSet: Cannot publish any ScoreCalibrations (even investigator-provided ones they can read/update) + 
PermissionTest( + "ScoreCalibration", "private", "contributor", Action.PUBLISH, False, 403, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "contributor", Action.PUBLISH, False, 404, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "contributor", Action.PUBLISH, False, 403, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "contributor", Action.PUBLISH, False, 403, investigator_provided=False + ), + # Other users: Cannot publish any ScoreCalibrations + PermissionTest( + "ScoreCalibration", "private", "other_user", Action.PUBLISH, False, 404, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "other_user", Action.PUBLISH, False, 404, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "other_user", Action.PUBLISH, False, 403, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "other_user", Action.PUBLISH, False, 403, investigator_provided=False + ), + # Anonymous users: Cannot publish any ScoreCalibrations + PermissionTest( + "ScoreCalibration", "private", "anonymous", Action.PUBLISH, False, 404, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "anonymous", Action.PUBLISH, False, 404, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "anonymous", Action.PUBLISH, False, 401, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "anonymous", Action.PUBLISH, False, 401, investigator_provided=False + ), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{'investigator' if tc.investigator_provided else 'community'}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_publish_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_publish_action helper function directly.""" + assert test_case.entity_state is not None, "ScoreCalibration tests must have entity_state" + assert test_case.investigator_provided is not None, "ScoreCalibration tests must have investigator_provided" + score_calibration = entity_helper.create_score_calibration( + test_case.entity_state, test_case.investigator_provided + ) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor_to_score_set = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_publish_action( + user_data, score_calibration, user_is_owner, user_is_contributor_to_score_set, private, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestScoreCalibrationChangeRankActionHandler: + """Test the _handle_change_rank_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # System admins: Can change rank of any ScoreCalibration regardless of state or investigator_provided flag + PermissionTest( + "ScoreCalibration", "private", "admin", Action.CHANGE_RANK, True, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "admin", Action.CHANGE_RANK, True, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "admin", Action.CHANGE_RANK, 
True, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "admin", Action.CHANGE_RANK, True, investigator_provided=False + ), + # Owners: Can change rank of their own ScoreCalibrations regardless of state or investigator_provided flag + PermissionTest( + "ScoreCalibration", "private", "owner", Action.CHANGE_RANK, True, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "owner", Action.CHANGE_RANK, True, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "owner", Action.CHANGE_RANK, True, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "published", "owner", Action.CHANGE_RANK, True, investigator_provided=False + ), + # Contributors to associated ScoreSet: Can change rank of investigator-provided ScoreCalibrations (private or published), but cannot change rank of community-provided ones + PermissionTest( + "ScoreCalibration", "private", "contributor", Action.CHANGE_RANK, True, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", + "private", + "contributor", + Action.CHANGE_RANK, + False, + 404, + investigator_provided=False, + ), + PermissionTest( + "ScoreCalibration", "published", "contributor", Action.CHANGE_RANK, True, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", + "published", + "contributor", + Action.CHANGE_RANK, + False, + 403, + investigator_provided=False, + ), + # Other users: Cannot change rank of any ScoreCalibrations + PermissionTest( + "ScoreCalibration", "private", "other_user", Action.CHANGE_RANK, False, 404, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "other_user", Action.CHANGE_RANK, False, 404, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", + "published", + "other_user", + Action.CHANGE_RANK, + False, + 403, + investigator_provided=True, + ), + PermissionTest( + "ScoreCalibration", + "published", + "other_user", + Action.CHANGE_RANK, + False, + 403, + investigator_provided=False, + ), + # Anonymous users: Cannot change rank of any ScoreCalibrations + PermissionTest( + "ScoreCalibration", "private", "anonymous", Action.CHANGE_RANK, False, 404, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", "private", "anonymous", Action.CHANGE_RANK, False, 404, investigator_provided=False + ), + PermissionTest( + "ScoreCalibration", "published", "anonymous", Action.CHANGE_RANK, False, 401, investigator_provided=True + ), + PermissionTest( + "ScoreCalibration", + "published", + "anonymous", + Action.CHANGE_RANK, + False, + 401, + investigator_provided=False, + ), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{'investigator' if tc.investigator_provided else 'community'}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_change_rank_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_change_rank_action helper function directly.""" + assert test_case.entity_state is not None, "ScoreCalibration tests must have entity_state" + assert test_case.investigator_provided is not None, "ScoreCalibration tests must have investigator_provided" + score_calibration = entity_helper.create_score_calibration( + test_case.entity_state, test_case.investigator_provided + ) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + 
user_is_contributor_to_score_set = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_change_rank_action( + user_data, score_calibration, user_is_owner, user_is_contributor_to_score_set, private, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code diff --git a/tests/lib/permissions/test_score_set.py b/tests/lib/permissions/test_score_set.py new file mode 100644 index 00000000..2349359f --- /dev/null +++ b/tests/lib/permissions/test_score_set.py @@ -0,0 +1,326 @@ +# ruff: noqa: E402 + +"""Tests for ScoreSet permissions module.""" + +import pytest + +pytest.importorskip("fastapi", reason="Skipping permissions tests; FastAPI is required but not installed.") + +from typing import Callable, List +from unittest import mock + +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.score_set import ( + _handle_delete_action, + _handle_publish_action, + _handle_read_action, + _handle_set_scores_action, + _handle_update_action, + has_permission, +) +from mavedb.models.enums.user_role import UserRole +from tests.lib.permissions.conftest import EntityTestHelper, PermissionTest + +SCORE_SET_SUPPORTED_ACTIONS: dict[Action, Callable] = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.DELETE: _handle_delete_action, + Action.SET_SCORES: _handle_set_scores_action, + Action.PUBLISH: _handle_publish_action, +} + +SCORE_SET_UNSUPPORTED_ACTIONS: List[Action] = [ + Action.ADD_EXPERIMENT, + Action.ADD_SCORE_SET, + Action.ADD_ROLE, + Action.LOOKUP, + Action.ADD_BADGE, + Action.CHANGE_RANK, +] + + +def test_score_set_handles_all_actions() -> None: + """Test that all ScoreSet actions are either supported or explicitly unsupported.""" + all_actions = set(action for action in Action) + supported = set(SCORE_SET_SUPPORTED_ACTIONS) + unsupported = set(SCORE_SET_UNSUPPORTED_ACTIONS) + + assert ( + supported.union(unsupported) == all_actions + ), "Some actions are not categorized as supported or unsupported for score sets." + + +class TestScoreSetHasPermission: + """Test the main has_permission dispatcher function for ScoreSet entities.""" + + @pytest.mark.parametrize("action, handler", SCORE_SET_SUPPORTED_ACTIONS.items()) + def test_supported_actions_route_to_correct_action_handler( + self, entity_helper: EntityTestHelper, action: Action, handler: Callable + ) -> None: + """Test that has_permission routes supported actions to their handlers.""" + score_set = entity_helper.create_score_set() + admin_user = entity_helper.create_user_data("admin") + + with mock.patch("mavedb.lib.permissions.score_set." 
+ handler.__name__, wraps=handler) as mock_handler: + has_permission(admin_user, score_set, action) + mock_handler.assert_called_once_with( + admin_user, + score_set, + score_set.private, + False, # admin is not the owner + False, # admin is not a contributor + [UserRole.admin], + ) + + @pytest.mark.parametrize("action", SCORE_SET_UNSUPPORTED_ACTIONS) + def test_raises_for_unsupported_actions(self, entity_helper: EntityTestHelper, action: Action) -> None: + """Test that unsupported actions raise NotImplementedError with descriptive message.""" + score_set = entity_helper.create_score_set() + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(NotImplementedError) as exc_info: + has_permission(admin_user, score_set, action) + + error_msg = str(exc_info.value) + assert action.value in error_msg + assert all(a.value in error_msg for a in SCORE_SET_SUPPORTED_ACTIONS) + + def test_requires_private_attribute(self, entity_helper: EntityTestHelper) -> None: + """Test that ValueError is raised if ScoreSet.private is None.""" + score_set = entity_helper.create_score_set() + score_set.private = None + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(ValueError) as exc_info: + has_permission(admin_user, score_set, Action.READ) + + assert "private" in str(exc_info.value) + + +class TestScoreSetReadActionHandler: + """Test the _handle_read_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can read any ScoreSet + PermissionTest("ScoreSet", "published", "admin", Action.READ, True), + PermissionTest("ScoreSet", "private", "admin", Action.READ, True), + # Owners can read any ScoreSet they own + PermissionTest("ScoreSet", "published", "owner", Action.READ, True), + PermissionTest("ScoreSet", "private", "owner", Action.READ, True), + # Contributors can read any ScoreSet they contribute to + PermissionTest("ScoreSet", "published", "contributor", Action.READ, True), + PermissionTest("ScoreSet", "private", "contributor", Action.READ, True), + # Mappers can read any ScoreSet (including private) + PermissionTest("ScoreSet", "published", "mapper", Action.READ, True), + PermissionTest("ScoreSet", "private", "mapper", Action.READ, True), + # Other users can only read published ScoreSets + PermissionTest("ScoreSet", "published", "other_user", Action.READ, True), + PermissionTest("ScoreSet", "private", "other_user", Action.READ, False, 404), + # Anonymous users can only read published ScoreSets + PermissionTest("ScoreSet", "published", "anonymous", Action.READ, True), + PermissionTest("ScoreSet", "private", "anonymous", Action.READ, False, 404), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_read_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_read_action helper function directly.""" + assert test_case.entity_state is not None, "ScoreSet tests must have entity_state" + score_set = entity_helper.create_score_set(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + # Determine user relationship to entity + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + # Test the helper function directly + result = _handle_read_action(user_data, score_set, private, 
user_is_owner, user_is_contributor, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestScoreSetUpdateActionHandler: + """Test the _handle_update_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can update any ScoreSet + PermissionTest("ScoreSet", "private", "admin", Action.UPDATE, True), + PermissionTest("ScoreSet", "published", "admin", Action.UPDATE, True), + # Owners can update any ScoreSet they own + PermissionTest("ScoreSet", "private", "owner", Action.UPDATE, True), + PermissionTest("ScoreSet", "published", "owner", Action.UPDATE, True), + # Contributors can update any ScoreSet they contribute to + PermissionTest("ScoreSet", "private", "contributor", Action.UPDATE, True), + PermissionTest("ScoreSet", "published", "contributor", Action.UPDATE, True), + # Mappers cannot update ScoreSets + PermissionTest("ScoreSet", "private", "mapper", Action.UPDATE, False, 404), + PermissionTest("ScoreSet", "published", "mapper", Action.UPDATE, False, 403), + # Other users cannot update ScoreSets + PermissionTest("ScoreSet", "private", "other_user", Action.UPDATE, False, 404), + PermissionTest("ScoreSet", "published", "other_user", Action.UPDATE, False, 403), + # Anonymous users cannot update ScoreSets + PermissionTest("ScoreSet", "private", "anonymous", Action.UPDATE, False, 404), + PermissionTest("ScoreSet", "published", "anonymous", Action.UPDATE, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_update_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_update_action helper function directly.""" + assert test_case.entity_state is not None, "ScoreSet tests must have entity_state" + score_set = entity_helper.create_score_set(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_update_action(user_data, score_set, private, user_is_owner, user_is_contributor, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestScoreSetDeleteActionHandler: + """Test the _handle_delete_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can delete any ScoreSet + PermissionTest("ScoreSet", "private", "admin", Action.DELETE, True), + PermissionTest("ScoreSet", "published", "admin", Action.DELETE, True), + # Owners can only delete unpublished ScoreSets + PermissionTest("ScoreSet", "private", "owner", Action.DELETE, True), + PermissionTest("ScoreSet", "published", "owner", Action.DELETE, False, 403), + # Contributors cannot delete + PermissionTest("ScoreSet", "private", "contributor", Action.DELETE, False, 403), + PermissionTest("ScoreSet", "published", "contributor", Action.DELETE, False, 403), + # Other users cannot delete + PermissionTest("ScoreSet", "private", "other_user", Action.DELETE, False, 404), + PermissionTest("ScoreSet", "published", "other_user", Action.DELETE, False, 403), 
+ # Anonymous users cannot delete + PermissionTest("ScoreSet", "private", "anonymous", Action.DELETE, False, 404), + PermissionTest("ScoreSet", "published", "anonymous", Action.DELETE, False, 401), + # Mappers cannot delete + PermissionTest("ScoreSet", "private", "mapper", Action.DELETE, False, 404), + PermissionTest("ScoreSet", "published", "mapper", Action.DELETE, False, 403), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_delete_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_delete_action helper function directly.""" + assert test_case.entity_state is not None, "ScoreSet tests must have entity_state" + score_set = entity_helper.create_score_set(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_delete_action(user_data, score_set, private, user_is_owner, user_is_contributor, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestScoreSetSetScoresActionHandler: + """Test the _handle_set_scores_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can set scores on any ScoreSet + PermissionTest("ScoreSet", "private", "admin", Action.SET_SCORES, True), + PermissionTest("ScoreSet", "published", "admin", Action.SET_SCORES, True), + # Owners can set scores on any ScoreSet they own + PermissionTest("ScoreSet", "private", "owner", Action.SET_SCORES, True), + PermissionTest("ScoreSet", "published", "owner", Action.SET_SCORES, True), + # Contributors can set scores on any ScoreSet they contribute to + PermissionTest("ScoreSet", "private", "contributor", Action.SET_SCORES, True), + PermissionTest("ScoreSet", "published", "contributor", Action.SET_SCORES, True), + # Mappers cannot set scores on ScoreSets + PermissionTest("ScoreSet", "private", "mapper", Action.SET_SCORES, False, 404), + PermissionTest("ScoreSet", "published", "mapper", Action.SET_SCORES, False, 403), + # Other users cannot set scores on ScoreSets + PermissionTest("ScoreSet", "private", "other_user", Action.SET_SCORES, False, 404), + PermissionTest("ScoreSet", "published", "other_user", Action.SET_SCORES, False, 403), + # Anonymous users cannot set scores on ScoreSets + PermissionTest("ScoreSet", "private", "anonymous", Action.SET_SCORES, False, 404), + PermissionTest("ScoreSet", "published", "anonymous", Action.SET_SCORES, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_set_scores_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_set_scores_action helper function directly.""" + assert test_case.entity_state is not None, "ScoreSet tests must have entity_state" + score_set = entity_helper.create_score_set(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == 
"contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_set_scores_action( + user_data, score_set, private, user_is_owner, user_is_contributor, active_roles + ) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestScoreSetPublishActionHandler: + """Test the _handle_publish_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can publish any ScoreSet + PermissionTest("ScoreSet", "private", "admin", Action.PUBLISH, True), + PermissionTest("ScoreSet", "published", "admin", Action.PUBLISH, True), + # Owners can publish any ScoreSet they own + PermissionTest("ScoreSet", "private", "owner", Action.PUBLISH, True), + PermissionTest("ScoreSet", "published", "owner", Action.PUBLISH, True), + # Contributors cannot publish ScoreSets they contribute to + PermissionTest("ScoreSet", "private", "contributor", Action.PUBLISH, False, 403), + PermissionTest("ScoreSet", "published", "contributor", Action.PUBLISH, False, 403), + # Mappers cannot publish ScoreSets + PermissionTest("ScoreSet", "private", "mapper", Action.PUBLISH, False, 404), + PermissionTest("ScoreSet", "published", "mapper", Action.PUBLISH, False, 403), + # Other users cannot publish ScoreSets + PermissionTest("ScoreSet", "private", "other_user", Action.PUBLISH, False, 404), + PermissionTest("ScoreSet", "published", "other_user", Action.PUBLISH, False, 403), + # Anonymous users cannot publish ScoreSets + PermissionTest("ScoreSet", "private", "anonymous", Action.PUBLISH, False, 404), + PermissionTest("ScoreSet", "published", "anonymous", Action.PUBLISH, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.entity_state}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_publish_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_publish_action helper function directly.""" + assert test_case.entity_state is not None, "ScoreSet tests must have entity_state" + score_set = entity_helper.create_score_set(test_case.entity_state) + user_data = entity_helper.create_user_data(test_case.user_type) + + private = test_case.entity_state == "private" + user_is_owner = test_case.user_type == "owner" + user_is_contributor = test_case.user_type == "contributor" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_publish_action(user_data, score_set, private, user_is_owner, user_is_contributor, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code diff --git a/tests/lib/permissions/test_user.py b/tests/lib/permissions/test_user.py new file mode 100644 index 00000000..b4efa876 --- /dev/null +++ b/tests/lib/permissions/test_user.py @@ -0,0 +1,237 @@ +# ruff: noqa: E402 + +"""Tests for User permissions module.""" + +import pytest + +pytest.importorskip("fastapi", reason="Skipping permissions tests; FastAPI is required but not installed.") + +from typing import Callable, List +from unittest import mock + +from mavedb.lib.permissions.actions import Action +from mavedb.lib.permissions.user import ( + _handle_add_role_action, + _handle_lookup_action, + _handle_read_action, + _handle_update_action, + has_permission, +) +from mavedb.models.enums.user_role import UserRole +from 
tests.lib.permissions.conftest import EntityTestHelper, PermissionTest + +USER_SUPPORTED_ACTIONS: dict[Action, Callable] = { + Action.READ: _handle_read_action, + Action.UPDATE: _handle_update_action, + Action.LOOKUP: _handle_lookup_action, + Action.ADD_ROLE: _handle_add_role_action, +} + +USER_UNSUPPORTED_ACTIONS: List[Action] = [ + Action.DELETE, + Action.ADD_EXPERIMENT, + Action.ADD_SCORE_SET, + Action.ADD_BADGE, + Action.CHANGE_RANK, + Action.SET_SCORES, + Action.PUBLISH, +] + + +def test_user_handles_all_actions() -> None: + """Test that all User actions are either supported or explicitly unsupported.""" + all_actions = set(action for action in Action) + supported = set(USER_SUPPORTED_ACTIONS) + unsupported = set(USER_UNSUPPORTED_ACTIONS) + + assert ( + supported.union(unsupported) == all_actions + ), "Some actions are not categorized as supported or unsupported for users." + + +class TestUserHasPermission: + """Test the main has_permission dispatcher function for User entities.""" + + @pytest.mark.parametrize("action, handler", USER_SUPPORTED_ACTIONS.items()) + def test_supported_actions_route_to_correct_action_handler( + self, entity_helper: EntityTestHelper, action: Action, handler: Callable + ) -> None: + """Test that has_permission routes supported actions to their handlers.""" + user = entity_helper.create_user() + admin_user = entity_helper.create_user_data("admin") + + with mock.patch("mavedb.lib.permissions.user." + handler.__name__, wraps=handler) as mock_handler: + has_permission(admin_user, user, action) + mock_handler.assert_called_once_with( + admin_user, + user, + False, # admin is not viewing self + [UserRole.admin], + ) + + @pytest.mark.parametrize("action", USER_UNSUPPORTED_ACTIONS) + def test_raises_for_unsupported_actions(self, entity_helper: EntityTestHelper, action: Action) -> None: + """Test that unsupported actions raise NotImplementedError with descriptive message.""" + user = entity_helper.create_user() + admin_user = entity_helper.create_user_data("admin") + + with pytest.raises(NotImplementedError) as exc_info: + has_permission(admin_user, user, action) + + error_msg = str(exc_info.value) + assert action.value in error_msg + assert all(a.value in error_msg for a in USER_SUPPORTED_ACTIONS) + + +class TestUserReadActionHandler: + """Test the _handle_read_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can read any User profile + PermissionTest("User", None, "admin", Action.READ, True), + # Users can read their own profile + PermissionTest("User", None, "self", Action.READ, True), + # Owners cannot read other user profiles (no special privilege) + PermissionTest("User", None, "owner", Action.READ, False, 403), + # Contributors cannot read other user profiles + PermissionTest("User", None, "contributor", Action.READ, False, 403), + # Mappers cannot read other user profiles + PermissionTest("User", None, "mapper", Action.READ, False, 403), + # Other users cannot read other user profiles + PermissionTest("User", None, "other_user", Action.READ, False, 403), + # Anonymous users cannot read user profiles + PermissionTest("User", None, "anonymous", Action.READ, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_read_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_read_action helper function directly.""" + user = entity_helper.create_user() + user_data = 
entity_helper.create_user_data(test_case.user_type) + + # Determine user relationship to entity + user_is_self = test_case.user_type == "self" + active_roles = user_data.active_roles if user_data else [] + + # Test the helper function directly + result = _handle_read_action(user_data, user, user_is_self, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestUserUpdateActionHandler: + """Test the _handle_update_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can update any User profile + PermissionTest("User", None, "admin", Action.UPDATE, True), + # Users can update their own profile + PermissionTest("User", None, "self", Action.UPDATE, True), + # Owners cannot update other user profiles (no special privilege) + PermissionTest("User", None, "owner", Action.UPDATE, False, 403), + # Contributors cannot update other user profiles + PermissionTest("User", None, "contributor", Action.UPDATE, False, 403), + # Mappers cannot update other user profiles + PermissionTest("User", None, "mapper", Action.UPDATE, False, 403), + # Other users cannot update other user profiles + PermissionTest("User", None, "other_user", Action.UPDATE, False, 403), + # Anonymous users cannot update user profiles + PermissionTest("User", None, "anonymous", Action.UPDATE, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_update_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_update_action helper function directly.""" + user = entity_helper.create_user() + user_data = entity_helper.create_user_data(test_case.user_type) + + user_is_self = test_case.user_type == "self" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_update_action(user_data, user, user_is_self, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestUserLookupActionHandler: + """Test the _handle_lookup_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can lookup any User + PermissionTest("User", None, "admin", Action.LOOKUP, True), + # Users can lookup themselves + PermissionTest("User", None, "self", Action.LOOKUP, True), + # Owners can lookup other users (authenticated user privilege) + PermissionTest("User", None, "owner", Action.LOOKUP, True), + # Contributors can lookup other users (authenticated user privilege) + PermissionTest("User", None, "contributor", Action.LOOKUP, True), + # Mappers can lookup other users (authenticated user privilege) + PermissionTest("User", None, "mapper", Action.LOOKUP, True), + # Other authenticated users can lookup other users + PermissionTest("User", None, "other_user", Action.LOOKUP, True), + # Anonymous users cannot lookup users + PermissionTest("User", None, "anonymous", Action.LOOKUP, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_lookup_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_lookup_action helper function directly.""" + user = entity_helper.create_user() + user_data = 
entity_helper.create_user_data(test_case.user_type) + + user_is_self = test_case.user_type == "self" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_lookup_action(user_data, user, user_is_self, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code + + +class TestUserAddRoleActionHandler: + """Test the _handle_add_role_action helper function directly.""" + + @pytest.mark.parametrize( + "test_case", + [ + # Admins can add roles to any User + PermissionTest("User", None, "admin", Action.ADD_ROLE, True), + # Users cannot add roles to themselves + PermissionTest("User", None, "self", Action.ADD_ROLE, False, 403), + # Owners cannot add roles to other users + PermissionTest("User", None, "owner", Action.ADD_ROLE, False, 403), + # Contributors cannot add roles to other users + PermissionTest("User", None, "contributor", Action.ADD_ROLE, False, 403), + # Mappers cannot add roles to other users + PermissionTest("User", None, "mapper", Action.ADD_ROLE, False, 403), + # Other users cannot add roles to other users + PermissionTest("User", None, "other_user", Action.ADD_ROLE, False, 403), + # Anonymous users cannot add roles to users + PermissionTest("User", None, "anonymous", Action.ADD_ROLE, False, 401), + ], + ids=lambda tc: f"{tc.user_type}_{tc.action.value}_{'permitted' if tc.should_be_permitted else 'denied'}", + ) + def test_handle_add_role_action(self, test_case: PermissionTest, entity_helper: EntityTestHelper) -> None: + """Test _handle_add_role_action helper function directly.""" + user = entity_helper.create_user() + user_data = entity_helper.create_user_data(test_case.user_type) + + user_is_self = test_case.user_type == "self" + active_roles = user_data.active_roles if user_data else [] + + result = _handle_add_role_action(user_data, user, user_is_self, active_roles) + + assert result.permitted == test_case.should_be_permitted + if not test_case.should_be_permitted and test_case.expected_code: + assert result.http_code == test_case.expected_code diff --git a/tests/lib/permissions/test_utils.py b/tests/lib/permissions/test_utils.py new file mode 100644 index 00000000..0cc8d76a --- /dev/null +++ b/tests/lib/permissions/test_utils.py @@ -0,0 +1,223 @@ +# ruff: noqa: E402 + +"""Tests for permissions utils module.""" + +import pytest + +pytest.importorskip("fastapi", reason="Skipping permissions tests; FastAPI is required but not installed.") + +from unittest.mock import Mock + +from mavedb.lib.permissions.utils import deny_action_for_entity, roles_permitted +from mavedb.models.enums.contribution_role import ContributionRole +from mavedb.models.enums.user_role import UserRole + + +class TestRolesPermitted: + """Test the roles_permitted utility function.""" + + def test_user_role_permission_granted(self): + """Test that permission is granted when user has a permitted role.""" + user_roles = [UserRole.admin, UserRole.mapper] + permitted_roles = [UserRole.admin] + + result = roles_permitted(user_roles, permitted_roles) + assert result is True + + def test_user_role_permission_denied(self): + """Test that permission is denied when user lacks permitted roles.""" + user_roles = [UserRole.mapper] + permitted_roles = [UserRole.admin] + + result = roles_permitted(user_roles, permitted_roles) + assert result is False + + def test_contribution_role_permission_granted(self): + """Test that permission is granted for contribution 
roles.""" + user_roles = [ContributionRole.admin, ContributionRole.editor] + permitted_roles = [ContributionRole.admin] + + result = roles_permitted(user_roles, permitted_roles) + assert result is True + + def test_contribution_role_permission_denied(self): + """Test that permission is denied for contribution roles.""" + user_roles = [ContributionRole.viewer] + permitted_roles = [ContributionRole.admin, ContributionRole.editor] + + result = roles_permitted(user_roles, permitted_roles) + assert result is False + + def test_empty_user_roles_permission_denied(self): + """Test that permission is denied when user has no roles.""" + user_roles = [] + permitted_roles = [UserRole.admin] + + result = roles_permitted(user_roles, permitted_roles) + assert result is False + + def test_multiple_matching_roles(self): + """Test permission when user has multiple permitted roles.""" + user_roles = [UserRole.admin, UserRole.mapper] + permitted_roles = [UserRole.admin, UserRole.mapper] + + result = roles_permitted(user_roles, permitted_roles) + assert result is True + + def test_partial_role_match(self): + """Test permission when user has some but not all permitted roles.""" + user_roles = [UserRole.mapper] + permitted_roles = [UserRole.admin, UserRole.mapper] + + result = roles_permitted(user_roles, permitted_roles) + assert result is True + + def test_no_role_overlap(self): + """Test permission when user roles don't overlap with permitted roles.""" + user_roles = [ContributionRole.viewer] + permitted_roles = [ContributionRole.admin, ContributionRole.editor] + + result = roles_permitted(user_roles, permitted_roles) + assert result is False + + def test_empty_permitted_roles(self): + """Test behavior when no roles are permitted.""" + user_roles = [UserRole.admin] + permitted_roles = [] + + result = roles_permitted(user_roles, permitted_roles) + assert result is False + + def test_both_empty_roles(self): + """Test behavior when both user and permitted roles are empty.""" + user_roles = [] + permitted_roles = [] + + result = roles_permitted(user_roles, permitted_roles) + assert result is False + + def test_consistent_role_types_allowed(self): + """Test behavior with consistent role types (should work fine).""" + user_roles = [UserRole.admin] + permitted_roles = [UserRole.admin, UserRole.mapper] + assert roles_permitted(user_roles, permitted_roles) is True + + user_roles = [ContributionRole.editor] + permitted_roles = [ContributionRole.admin, ContributionRole.editor, ContributionRole.viewer] + assert roles_permitted(user_roles, permitted_roles) is True + + def test_mixed_user_role_types_raises_error(self): + """Test that mixed role types in user_roles list raises ValueError.""" + permitted_roles = [UserRole.admin] + mixed_user_roles = [UserRole.admin, ContributionRole.editor] + + with pytest.raises(ValueError) as exc_info: + roles_permitted(mixed_user_roles, permitted_roles) + + assert "user_roles list cannot contain mixed role types" in str(exc_info.value) + + def test_mixed_permitted_role_types_raises_error(self): + """Test that mixed role types in permitted_roles list raises ValueError.""" + user_roles = [UserRole.admin] + mixed_permitted_roles = [UserRole.admin, ContributionRole.editor] + + with pytest.raises(ValueError) as exc_info: + roles_permitted(user_roles, mixed_permitted_roles) + + assert "permitted_roles list cannot contain mixed role types" in str(exc_info.value) + + def test_different_role_types_between_lists_raises_error(self): + """Test that different role types between lists raises 
ValueError.""" + user_roles = [UserRole.admin] + permitted_roles = [ContributionRole.admin] + + with pytest.raises(ValueError) as exc_info: + roles_permitted(user_roles, permitted_roles) + + assert "user_roles and permitted_roles must contain the same role type" in str(exc_info.value) + + def test_single_role_lists(self): + """Test with single-item role lists.""" + user_roles = [UserRole.admin] + permitted_roles = [UserRole.admin] + assert roles_permitted(user_roles, permitted_roles) is True + + user_roles = [UserRole.mapper] + permitted_roles = [UserRole.admin] + assert roles_permitted(user_roles, permitted_roles) is False + + +class TestDenyActionForEntity: + """Test the deny_action_for_entity utility function.""" + + @pytest.mark.parametrize( + "entity_is_private, user_data, user_can_view_private, expected_status", + [ + # Private entity, anonymous user + (True, None, False, 404), + # Private entity, authenticated user without permissions + (True, Mock(user=Mock(id=1)), False, 404), + # Private entity, authenticated user with permissions + (True, Mock(user=Mock(id=1)), True, 403), + # Public entity, anonymous user + (False, None, False, 401), + # Public entity, authenticated user + (False, Mock(user=Mock(id=1)), False, 403), + ], + ids=[ + "private_anonymous_not-viewer", + "private_authenticated_not-viewer", + "private_authenticated_viewer", + "public_anonymous", + "public_authenticated", + ], + ) + def test_deny_action(self, entity_is_private, user_data, user_can_view_private, expected_status): + """Test denial for various user and entity privacy scenarios.""" + + entity = Mock(urn="entity:1234") + response = deny_action_for_entity(entity, entity_is_private, user_data, user_can_view_private) + + assert response.permitted is False + assert response.http_code == expected_status + + def test_deny_action_urn_available(self): + """Test denial message includes URN when available.""" + entity = Mock(urn="entity:5678") + response = deny_action_for_entity(entity, True, None, False) + + assert "URN 'entity:5678'" in response.message + + def test_deny_action_id_available(self): + """Test denial message includes ID when URN is not available.""" + entity = Mock(urn=None, id=42) + response = deny_action_for_entity(entity, True, None, False) + + assert "ID '42'" in response.message + + def test_deny_action_no_identifier(self): + """Test denial message when neither URN nor ID is available.""" + entity = Mock(urn=None, id=None) + response = deny_action_for_entity(entity, True, None, False) + + assert "unknown" in response.message + + def test_deny_handles_undefined_attributes(self): + """Test denial message when identifier attributes are undefined.""" + entity = Mock() + del entity.urn # Remove urn attribute + del entity.id # Remove id attribute + response = deny_action_for_entity(entity, True, None, False) + + assert "unknown" in response.message + + def test_deny_action_entity_name_in_message(self): + """Test denial message includes entity class name.""" + + class CustomEntity: + pass + + entity = CustomEntity() + response = deny_action_for_entity(entity, True, None, False, "custom entity") + + assert "custom entity" in response.message diff --git a/tests/lib/test_acmg.py b/tests/lib/test_acmg.py index db458439..cc5dfac0 100644 --- a/tests/lib/test_acmg.py +++ b/tests/lib/test_acmg.py @@ -1,10 +1,21 @@ +# ruff: noqa: E402 + import pytest +from sqlalchemy import select + +pytest.importorskip("psycopg2") from mavedb.lib.acmg import ( ACMGCriterion, StrengthOfEvidenceProvided, +
find_or_create_acmg_classification, points_evidence_strength_equivalent, ) +from mavedb.models.acmg_classification import ACMGClassification + +############################################################################### +# Tests for points_evidence_strength_equivalent +############################################################################### @pytest.mark.parametrize( @@ -79,3 +90,154 @@ def test_all_strength_categories_covered(): assert StrengthOfEvidenceProvided.MODERATE_PLUS in seen assert StrengthOfEvidenceProvided.MODERATE in seen assert StrengthOfEvidenceProvided.SUPPORTING in seen + + +############################################################################### +# Tests for find_or_create_acmg_classification +############################################################################### + + +@pytest.mark.parametrize( + "criterion,evidence_strength,points", + [ + # Valid combinations + (ACMGCriterion.PS3, StrengthOfEvidenceProvided.STRONG, 4), + (ACMGCriterion.BS3, StrengthOfEvidenceProvided.MODERATE, -2), + (None, None, None), # Should return None + (None, None, 5), # Should derive from points + ], +) +def test_find_or_create_acmg_classification_validation_does_not_raise_on_valid_combinations( + session, criterion, evidence_strength, points +): + """Test input validation for find_or_create_acmg_classification valid values.""" + result = find_or_create_acmg_classification(session, criterion, evidence_strength, points) + + if criterion is None and evidence_strength is None and points is None: + assert result is None + else: + assert result is not None + + +@pytest.mark.parametrize( + "criterion,evidence_strength,points", + [ + # Invalid combinations - only one is None + (ACMGCriterion.PS3, None, 4), + (None, StrengthOfEvidenceProvided.STRONG, 4), + ], +) +def test_find_or_create_acmg_classification_validation_raises_on_invalid_combinations( + session, criterion, evidence_strength, points +): + """Test input validation for find_or_create_acmg_classification invalid values.""" + with pytest.raises( + ValueError, + match="Both criterion and evidence_strength must be provided together or both be None, with points.", + ): + find_or_create_acmg_classification(session, criterion, evidence_strength, points) + + +def test_find_or_create_acmg_classification_returns_none_for_all_none(session): + """Test that function returns None when all parameters are None.""" + + result = find_or_create_acmg_classification(session, None, None, None) + assert result is None + + +def test_find_or_create_acmg_classification_derives_from_points(session): + """Test that function derives criterion and evidence_strength from points when they are None.""" + + result = find_or_create_acmg_classification(session, None, None, 4) + + assert result is not None + assert result.criterion == ACMGCriterion.PS3 + assert result.evidence_strength == StrengthOfEvidenceProvided.STRONG + assert result.points == 4 + + +def test_find_or_create_acmg_classification_creates_new_entry(session): + """Test that function creates a new ACMGClassification when one doesn't exist.""" + + # Verify no existing entry + existing = session.execute( + select(ACMGClassification) + .where(ACMGClassification.criterion == ACMGCriterion.PS3) + .where(ACMGClassification.evidence_strength == StrengthOfEvidenceProvided.MODERATE) + .where(ACMGClassification.points == 2) + ).scalar_one_or_none() + assert existing is None + + result = find_or_create_acmg_classification(session, ACMGCriterion.PS3, StrengthOfEvidenceProvided.MODERATE, 2) + + assert 
result is not None + assert result.criterion == ACMGCriterion.PS3 + assert result.evidence_strength == StrengthOfEvidenceProvided.MODERATE + assert result.points == 2 + + # Verify it was added to the session + session_objects = [obj for obj in session.new if isinstance(obj, ACMGClassification)] + assert len(session_objects) == 1 + assert session_objects[0] == result + + +def test_find_or_create_acmg_classification_finds_existing_entry(session): + """Test that function finds and returns existing ACMGClassification.""" + + # Create an existing entry + existing_classification = ACMGClassification( + criterion=ACMGCriterion.BS3, evidence_strength=StrengthOfEvidenceProvided.STRONG, points=-5 + ) + session.add(existing_classification) + session.commit() + + result = find_or_create_acmg_classification(session, ACMGCriterion.BS3, StrengthOfEvidenceProvided.STRONG, -5) + + assert result is not None + assert result == existing_classification + assert result.criterion == ACMGCriterion.BS3 + assert result.evidence_strength == StrengthOfEvidenceProvided.STRONG + assert result.points == -5 + + # Verify no new objects were added to the session + assert len(session.new) == 0 + + +def test_find_or_create_acmg_classification_with_zero_points(session): + """Test function behavior with zero points.""" + + result = find_or_create_acmg_classification(session, None, None, 0) + assert result is None + + +@pytest.mark.parametrize("points", [-8, -4, -1, 1, 3, 8]) +def test_find_or_create_acmg_classification_points_integration(session, points): + """Test that function works correctly with various point values.""" + + result = find_or_create_acmg_classification(session, None, None, points) + + expected_criterion, expected_strength = points_evidence_strength_equivalent(points) + + assert result is not None + assert result.criterion == expected_criterion + assert result.evidence_strength == expected_strength + assert result.points == points + + +def test_find_or_create_acmg_classification_does_not_commit(session): + """Test that function does not commit the session.""" + + find_or_create_acmg_classification(session, ACMGCriterion.PS3, StrengthOfEvidenceProvided.SUPPORTING, 1) + + # Rollback the session + session.rollback() + + # Verify the object is no longer in the database + existing = session.execute( + select(ACMGClassification) + .where(ACMGClassification.criterion == ACMGCriterion.PS3) + .where(ACMGClassification.evidence_strength == StrengthOfEvidenceProvided.SUPPORTING) + .where(ACMGClassification.points == 1) + ).scalar_one_or_none() + + assert existing is None diff --git a/tests/lib/test_flexible_model_loader.py b/tests/lib/test_flexible_model_loader.py new file mode 100644 index 00000000..e9f578d7 --- /dev/null +++ b/tests/lib/test_flexible_model_loader.py @@ -0,0 +1,342 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("fastapi") + +import json +from typing import Optional +from unittest.mock import AsyncMock, Mock, patch + +from fastapi import HTTPException, Request +from fastapi.exceptions import RequestValidationError + +from mavedb.lib.flexible_model_loader import create_flexible_model_loader, json_or_form_loader +from mavedb.view_models.base.base import BaseModel + + +class SampleModel(BaseModel): + """Sample model for flexible model loader tests.""" + + name: str + age: int + email: Optional[str] = None + + +class ComplexSampleModel(BaseModel): + """More complex sample model with validation.""" + + id: int + title: str + tags: list[str] = [] + metadata: dict = {} + + +@pytest.fixture +def 
test_model_loader(): + """Create a flexible model loader for SampleModel.""" + return create_flexible_model_loader(SampleModel) + + +@pytest.fixture +def custom_loader(): + """Create a flexible model loader with custom parameters.""" + return create_flexible_model_loader(SampleModel, form_field_name="custom_field", error_detail_prefix="Custom error") + + +@pytest.fixture +def mock_request(): + """Create a mock FastAPI Request object.""" + request = Mock(spec=Request) + request.body = AsyncMock() + return request + + +class TestCreateFlexibleModelLoader: + """Test suite for create_flexible_model_loader function.""" + + @pytest.mark.asyncio + async def test_load_from_form_field_valid_data(self, test_model_loader, mock_request): + """Test loading valid data from form field.""" + test_data = {"name": "John", "age": 30, "email": "john@example.com"} + json_data = json.dumps(test_data) + + result = await test_model_loader(mock_request, item=json_data) + + assert isinstance(result, SampleModel) + assert result.name == "John" + assert result.age == 30 + assert result.email == "john@example.com" + + @pytest.mark.asyncio + async def test_load_from_form_field_minimal_data(self, test_model_loader, mock_request): + """Test loading minimal valid data from form field.""" + test_data = {"name": "Jane", "age": 25} + json_data = json.dumps(test_data) + + result = await test_model_loader(mock_request, item=json_data) + + assert isinstance(result, SampleModel) + assert result.name == "Jane" + assert result.age == 25 + assert result.email is None + + @pytest.mark.asyncio + async def test_load_from_json_body_valid_data(self, test_model_loader, mock_request): + """Test loading valid data from JSON body.""" + test_data = {"name": "Bob", "age": 35, "email": "bob@example.com"} + json_bytes = json.dumps(test_data).encode("utf-8") + mock_request.body.return_value = json_bytes + + result = await test_model_loader(mock_request, item=None) + + assert isinstance(result, SampleModel) + assert result.name == "Bob" + assert result.age == 35 + assert result.email == "bob@example.com" + + @pytest.mark.asyncio + async def test_form_field_takes_priority_over_json_body(self, test_model_loader, mock_request): + """Test that form field data takes priority over JSON body.""" + form_data = {"name": "FormUser", "age": 25} + body_data = {"name": "BodyUser", "age": 30} + + json_form = json.dumps(form_data) + json_body = json.dumps(body_data).encode("utf-8") + mock_request.body.return_value = json_body + + result = await test_model_loader(mock_request, item=json_form) + + assert result.name == "FormUser" + assert result.age == 25 + + @pytest.mark.asyncio + async def test_validation_error_from_form_field(self, test_model_loader, mock_request): + """Test ValidationError handling for invalid form field data.""" + invalid_data = {"name": "John"} # Missing required 'age' field + json_data = json.dumps(invalid_data) + + with pytest.raises(RequestValidationError) as exc_info: + await test_model_loader(mock_request, item=json_data) + + errors = exc_info.value.errors() + assert len(errors) > 0 + assert any(error["loc"] == ("age",) for error in errors) + + @pytest.mark.asyncio + async def test_validation_error_from_json_body(self, test_model_loader, mock_request): + """Test ValidationError handling for invalid JSON body data.""" + invalid_data = {"age": 25} # Missing required 'name' field + json_bytes = json.dumps(invalid_data).encode("utf-8") + mock_request.body.return_value = json_bytes + + with pytest.raises(RequestValidationError) as exc_info: 
+ await test_model_loader(mock_request, item=None) + + errors = exc_info.value.errors() + assert len(errors) > 0 + assert any(error["loc"] == ("name",) for error in errors) + + @pytest.mark.asyncio + async def test_invalid_json_syntax_form_field(self, test_model_loader, mock_request): + """Test handling of invalid JSON syntax in form field.""" + invalid_json = '{"name": "John", "age":}' # Invalid JSON + + with pytest.raises(RequestValidationError) as exc_info: + await test_model_loader(mock_request, item=invalid_json) + + assert exc_info.value.errors() + assert "json_invalid" in exc_info.value.errors()[0]["type"] + + @pytest.mark.asyncio + async def test_invalid_json_syntax_body(self, test_model_loader, mock_request): + """Test handling of invalid JSON syntax in request body.""" + invalid_json = b'{"name": "John", "age":}' # Invalid JSON + mock_request.body.return_value = invalid_json + + with pytest.raises(RequestValidationError) as exc_info: + await test_model_loader(mock_request, item=None) + + assert exc_info.value.errors() + assert "json_invalid" in exc_info.value.errors()[0]["type"] + + @pytest.mark.asyncio + async def test_empty_request_body_and_no_form_field(self, test_model_loader, mock_request): + """Test handling when no data is provided in either form field or body.""" + mock_request.body.return_value = b"" + + with pytest.raises(HTTPException) as exc_info: + await test_model_loader(mock_request, item=None) + + assert exc_info.value.status_code == 422 + assert "No data provided in form field or request body" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_custom_error_detail_prefix(self, custom_loader, mock_request): + """Test custom error detail prefix is used in error messages.""" + mock_request.body.return_value = b"" + + with pytest.raises(HTTPException) as exc_info: + await custom_loader(mock_request, item=None) + + assert exc_info.value.status_code == 422 + assert "Custom error" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_complex_model_with_nested_data(self, mock_request): + """Test loading complex model with nested data structures.""" + complex_loader = create_flexible_model_loader(ComplexSampleModel) + test_data = { + "id": 1, + "title": "Test Item", + "tags": ["tag1", "tag2", "tag3"], + "metadata": {"key1": "value1", "key2": {"nested": "value"}}, + } + json_data = json.dumps(test_data) + + result = await complex_loader(mock_request, item=json_data) + + assert isinstance(result, ComplexSampleModel) + assert result.id == 1 + assert result.title == "Test Item" + assert result.tags == ["tag1", "tag2", "tag3"] + assert result.metadata == {"key1": "value1", "key2": {"nested": "value"}} + + @pytest.mark.asyncio + async def test_form_field_name_parameter_documentation_only(self, mock_request): + """Test that form_field_name parameter doesn't affect functionality.""" + # Create loaders with different form_field_name values + loader1 = create_flexible_model_loader(SampleModel, form_field_name="item") + loader2 = create_flexible_model_loader(SampleModel, form_field_name="custom_name") + + test_data = {"name": "Test", "age": 30} + json_data = json.dumps(test_data) + + # Both should work the same way since form_field_name is for docs only + result1 = await loader1(mock_request, item=json_data) + result2 = await loader2(mock_request, item=json_data) + + assert result1.name == result2.name == "Test" + assert result1.age == result2.age == 30 + + @pytest.mark.asyncio + async def test_exception_handling_for_unexpected_errors(self, 
test_model_loader, mock_request): + """Test handling of unexpected exceptions during processing.""" + # Mock an exception during model validation + with patch.object(SampleModel, "model_validate_json", side_effect=RuntimeError("Unexpected error")): + test_data = {"name": "John", "age": 30} + json_data = json.dumps(test_data) + + with pytest.raises(HTTPException) as exc_info: + await test_model_loader(mock_request, item=json_data) + + assert exc_info.value.status_code == 422 + assert "Unexpected error" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_unicode_data_handling(self, test_model_loader, mock_request): + """Test handling of unicode characters in data.""" + test_data = {"name": "José María", "age": 25, "email": "josé@example.com"} + json_data = json.dumps(test_data, ensure_ascii=False) + + result = await test_model_loader(mock_request, item=json_data) + + assert result.name == "José María" + assert result.email == "josé@example.com" + + +class TestJsonOrFormLoader: + """Test suite for json_or_form_loader convenience function.""" + + @pytest.mark.asyncio + async def test_convenience_function_basic_usage(self, mock_request): + """Test the convenience function with basic usage.""" + loader = json_or_form_loader(SampleModel) + test_data = {"name": "Alice", "age": 28} + json_data = json.dumps(test_data) + + result = await loader(mock_request, item=json_data) + + assert isinstance(result, SampleModel) + assert result.name == "Alice" + assert result.age == 28 + + @pytest.mark.asyncio + async def test_convenience_function_custom_field_name(self, mock_request): + """Test the convenience function with custom field name.""" + loader = json_or_form_loader(SampleModel, field_name="custom_field") + test_data = {"name": "Charlie", "age": 35} + json_data = json.dumps(test_data) + + result = await loader(mock_request, item=json_data) + + assert isinstance(result, SampleModel) + assert result.name == "Charlie" + assert result.age == 35 + + @pytest.mark.asyncio + async def test_convenience_function_error_message_format(self, mock_request): + """Test that convenience function generates appropriate error messages.""" + loader = json_or_form_loader(SampleModel) + mock_request.body.return_value = b"" + + with pytest.raises(HTTPException) as exc_info: + await loader(mock_request, item=None) + + assert exc_info.value.status_code == 422 + assert "Invalid SampleModel data" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_convenience_function_with_complex_model(self, mock_request): + """Test convenience function with more complex model.""" + loader = json_or_form_loader(ComplexSampleModel) + test_data = {"id": 42, "title": "Complex Test", "tags": ["test", "complex"], "metadata": {"source": "test"}} + json_data = json.dumps(test_data) + + result = await loader(mock_request, item=json_data) + + assert isinstance(result, ComplexSampleModel) + assert result.id == 42 + assert result.title == "Complex Test" + assert result.tags == ["test", "complex"] + assert result.metadata == {"source": "test"} + + +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + @pytest.mark.asyncio + async def test_empty_string_form_field(self, test_model_loader, mock_request): + """Test handling of empty string in form field.""" + with pytest.raises(RequestValidationError) as exc_info: + await test_model_loader(mock_request, item="") + + assert exc_info.value.errors() + assert "json_invalid" in exc_info.value.errors()[0]["type"] + + @pytest.mark.asyncio + async def 
test_whitespace_only_form_field(self, test_model_loader, mock_request): + """Test handling of whitespace-only form field.""" + with pytest.raises(RequestValidationError) as exc_info: + await test_model_loader(mock_request, item=" ") + + assert exc_info.value.errors() + assert "json_invalid" in exc_info.value.errors()[0]["type"] + + @pytest.mark.asyncio + async def test_null_json_value(self, test_model_loader, mock_request): + """Test handling of null JSON value.""" + with pytest.raises(RequestValidationError) as exc_info: + await test_model_loader(mock_request, item="null") + + assert exc_info.value.errors() + assert "model_type" in exc_info.value.errors()[0]["type"] + + @pytest.mark.asyncio + async def test_array_json_value(self, test_model_loader, mock_request): + """Test handling of array JSON value instead of object.""" + with pytest.raises(RequestValidationError) as exc_info: + await test_model_loader(mock_request, item='["not", "an", "object"]') + + assert exc_info.value.errors() + assert "model_type" in exc_info.value.errors()[0]["type"] diff --git a/tests/lib/test_score_calibrations.py b/tests/lib/test_score_calibrations.py index 9ca1b010..ad6bb0ea 100644 --- a/tests/lib/test_score_calibrations.py +++ b/tests/lib/test_score_calibrations.py @@ -6,11 +6,13 @@ from unittest import mock +import pandas as pd from pydantic import create_model from sqlalchemy import select from sqlalchemy.exc import NoResultFound from mavedb.lib.score_calibrations import ( + create_functional_classification, create_score_calibration, create_score_calibration_in_score_set, delete_score_calibration, @@ -18,29 +20,292 @@ modify_score_calibration, promote_score_calibration_to_primary, publish_score_calibration, + variant_classification_df_to_dict, + variants_for_functional_classification, +) +from mavedb.lib.validation.constants.general import ( + calibration_class_column_name, + calibration_variant_column_name, + hgvs_nt_column, + hgvs_pro_column, ) from mavedb.models.enums.score_calibration_relation import ScoreCalibrationRelation from mavedb.models.score_calibration import ScoreCalibration +from mavedb.models.score_calibration_functional_classification import ScoreCalibrationFunctionalClassification from mavedb.models.score_set import ScoreSet from mavedb.models.user import User +from mavedb.models.variant import Variant from mavedb.view_models.score_calibration import ScoreCalibrationCreate, ScoreCalibrationModify - from tests.helpers.constants import ( + EXTRA_USER, TEST_BIORXIV_IDENTIFIER, - TEST_BRNICH_SCORE_CALIBRATION, + TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED, + TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_CROSSREF_IDENTIFIER, TEST_LICENSE, TEST_PATHOGENICITY_SCORE_CALIBRATION, TEST_PUBMED_IDENTIFIER, TEST_SEQ_SCORESET, VALID_SCORE_SET_URN, - EXTRA_USER, ) from tests.helpers.util.contributor import add_contributor -from tests.helpers.util.score_calibration import create_test_score_calibration_in_score_set +from tests.helpers.util.score_calibration import create_test_range_based_score_calibration_in_score_set ################################################################################ -# Tests for create_score_calibration +# Tests for create_functional_classification +################################################################################ + + +def test_create_functional_classification_without_acmg_classification(setup_lib_db, session): + # Create a mock calibration + calibration = ScoreCalibration() + + # Create mock functional range without ACMG classification + 
MockFunctionalClassificationCreate = create_model( + "MockFunctionalClassificationCreate", + label=(str, "Test Label"), + description=(str, "Test Description"), + range=(list, [0.0, 1.0]), + class_=(type(None), None), + inclusive_lower_bound=(bool, True), + inclusive_upper_bound=(bool, False), + functional_classification=(str, "pathogenic"), + oddspaths_ratio=(float, 1.5), + positive_likelihood_ratio=(float, 2.0), + acmg_classification=(type(None), None), + ) + + result = create_functional_classification(session, MockFunctionalClassificationCreate(), calibration) + + assert result.description == "Test Description" + assert result.range == [0.0, 1.0] + assert result.inclusive_lower_bound is True + assert result.inclusive_upper_bound is False + assert result.functional_classification == "pathogenic" + assert result.oddspaths_ratio == 1.5 + assert result.positive_likelihood_ratio == 2.0 + assert result.acmg_classification is None + assert result.acmg_classification_id is None + assert result.calibration == calibration + + +def test_create_functional_classification_with_acmg_classification(setup_lib_db, session): + # Create a mock calibration + calibration = ScoreCalibration() + + # Create mock ACMG classification + mock_criterion = "PS1" + mock_evidence_strength = "STRONG" + mock_points = 4 + MockAcmgClassification = create_model( + "MockAcmgClassification", + criterion=(str, mock_criterion), + evidence_strength=(str, mock_evidence_strength), + points=(int, mock_points), + ) + + # Create mock functional range with ACMG classification + MockFunctionalClassificationCreate = create_model( + "MockFunctionalClassificationCreate", + label=(str, "Test Label"), + description=(str, "Test Description"), + range=(list, [0.0, 1.0]), + class_=(type(None), None), + inclusive_lower_bound=(bool, True), + inclusive_upper_bound=(bool, False), + functional_classification=(str, "pathogenic"), + oddspaths_ratio=(float, 1.5), + positive_likelihood_ratio=(float, 2.0), + acmg_classification=(MockAcmgClassification, MockAcmgClassification()), + ) + + functional_range_create = MockFunctionalClassificationCreate() + + with mock.patch("mavedb.lib.score_calibrations.find_or_create_acmg_classification") as mock_find_or_create: + # Mock the ACMG classification with an ID + MockPersistedAcmgClassification = create_model( + "MockPersistedAcmgClassification", + id=(int, 123), + ) + + mocked_persisted_acmg_classification = MockPersistedAcmgClassification() + mock_find_or_create.return_value = mocked_persisted_acmg_classification + result = create_functional_classification(session, functional_range_create, calibration) + + # Verify find_or_create_acmg_classification was called with correct parameters + mock_find_or_create.assert_called_once_with( + session, + criterion=mock_criterion, + evidence_strength=mock_evidence_strength, + points=mock_points, + ) + + # Verify the result + assert result.label == "Test Label" + assert result.description == "Test Description" + assert result.range == [0.0, 1.0] + assert result.inclusive_lower_bound is True + assert result.inclusive_upper_bound is False + assert result.functional_classification == "pathogenic" + assert result.oddspaths_ratio == 1.5 + assert result.positive_likelihood_ratio == 2.0 + assert result.acmg_classification == mocked_persisted_acmg_classification + assert result.acmg_classification_id == 123 + assert result.calibration == calibration + + +def test_create_functional_classification_with_variant_classes(setup_lib_db, session): + # Create a mock calibration + calibration 
= ScoreCalibration() + + # Create mock functional range with variant classes + MockFunctionalClassificationCreate = create_model( + "MockFunctionalClassificationCreate", + label=(str, "Test Label"), + description=(str, "Test Description"), + range=(type(None), None), + class_=(str, "test_class"), + inclusive_lower_bound=(type(None), None), + inclusive_upper_bound=(type(None), None), + functional_classification=(str, "pathogenic"), + oddspaths_ratio=(float, 1.5), + positive_likelihood_ratio=(float, 2.0), + acmg_classification=(type(None), None), + ) + + functional_range_create = MockFunctionalClassificationCreate() + + with mock.patch("mavedb.lib.score_calibrations.variants_for_functional_classification") as mock_classified_variants: + MockedClassifiedVariant = create_model( + "MockedVariant", + urn=(str, "variant_urn_3"), + ) + mock_classified_variants.return_value = [MockedClassifiedVariant()] + + result = create_functional_classification( + session, + functional_range_create, + calibration, + variant_classes={ + "indexed_by": calibration_variant_column_name, + "classifications": { + "pathogenic": ["variant_urn_1", "variant_urn_2"], + "benign": ["variant_urn_3"], + }, + }, + ) + + mock_classified_variants.assert_called() + + assert result.description == "Test Description" + assert result.range is None + assert result.inclusive_lower_bound is None + assert result.inclusive_upper_bound is None + assert result.functional_classification == "pathogenic" + assert result.oddspaths_ratio == 1.5 + assert result.positive_likelihood_ratio == 2.0 + assert result.acmg_classification is None + assert result.acmg_classification_id is None + assert result.calibration == calibration + assert result.variants == [MockedClassifiedVariant()] + + +def test_create_functional_classification_propagates_acmg_errors(setup_lib_db, session): + # Create a mock calibration + calibration = ScoreCalibration() + + # Create mock ACMG classification + MockAcmgClassification = create_model( + "MockAcmgClassification", + criterion=(str, "PS1"), + evidence_strength=(str, "strong"), + points=(int, 4), + ) + + # Create mock functional range with ACMG classification + MockFunctionalClassificationCreate = create_model( + "MockFunctionalClassificationCreate", + label=(str, "Test Label"), + description=(str, "Test Description"), + range=(list, [0.0, 1.0]), + class_=(type(None), None), + inclusive_lower_bound=(bool, True), + inclusive_upper_bound=(bool, False), + functional_classification=(str, "pathogenic"), + oddspaths_ratio=(float, 1.5), + positive_likelihood_ratio=(float, 2.0), + acmg_classification=(MockAcmgClassification, MockAcmgClassification()), + ) + + functional_range_create = MockFunctionalClassificationCreate() + + with ( + pytest.raises(ValueError, match="ACMG error"), + mock.patch( + "mavedb.lib.score_calibrations.find_or_create_acmg_classification", + side_effect=ValueError("ACMG error"), + ), + ): + create_functional_classification(session, functional_range_create, calibration) + + +def test_create_functional_classification_propagates_functional_classification_errors(setup_lib_db, session): + # Create a mock calibration + calibration = ScoreCalibration() + + # Create mock functional range + MockFunctionalClassificationCreate = create_model( + "MockFunctionalClassificationCreate", + label=(str, "Test Label"), + description=(str, "Test Description"), + range=(list, [0.0, 1.0]), + class_=(type(None), None), + inclusive_lower_bound=(bool, True), + inclusive_upper_bound=(bool, False), + functional_classification=(str, 
"pathogenic"), + oddspaths_ratio=(float, 1.5), + positive_likelihood_ratio=(float, 2.0), + acmg_classification=(type(None), None), + ) + + functional_range_create = MockFunctionalClassificationCreate() + + with ( + pytest.raises(ValueError, match="Functional classification error"), + mock.patch( + "mavedb.lib.score_calibrations.ScoreCalibrationFunctionalClassification", + side_effect=ValueError("Functional classification error"), + ), + ): + create_functional_classification(session, functional_range_create, calibration) + + +def test_create_functional_classification_does_not_commit_transaction(setup_lib_db, session): + # Create a mock calibration + calibration = ScoreCalibration() + + # Create mock functional range without ACMG classification + MockFunctionalClassificationCreate = create_model( + "MockFunctionalClassificationCreate", + label=(str, "Test Label"), + description=(str, "Test Description"), + range=(list, [0.0, 1.0]), + class_=(type(None), None), + inclusive_lower_bound=(bool, True), + inclusive_upper_bound=(bool, False), + functional_classification=(str, "pathogenic"), + oddspaths_ratio=(float, 1.5), + positive_likelihood_ratio=(float, 2.0), + acmg_classification=(type(None), None), + ) + + with mock.patch.object(session, "commit") as mock_commit: + create_functional_classification(session, MockFunctionalClassificationCreate(), calibration) + mock_commit.assert_not_called() + + +################################################################################ +# Tests for _create_score_calibration (tested indirectly via the following tests to its callers) ################################################################################ @@ -83,6 +348,7 @@ async def test_create_score_calibration_in_score_set_creates_score_calibration_w threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) calibration = await create_score_calibration_in_score_set(session, MockCalibrationCreate(), test_user) @@ -102,6 +368,7 @@ async def test_create_score_calibration_in_score_set_investigator_provided_set_w threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) calibration = await create_score_calibration_in_score_set(session, MockCalibrationCreate(), test_user) @@ -133,6 +400,7 @@ async def test_create_score_calibration_in_score_set_investigator_provided_set_w threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) calibration = await create_score_calibration_in_score_set(session, MockCalibrationCreate(), extra_user) @@ -153,6 +421,7 @@ async def test_create_score_calibration_in_score_set_investigator_provided_not_s threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) # invoke from a different user context @@ -191,6 +460,7 @@ async def test_create_score_calibration_creates_score_calibration_when_score_set threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) calibration = await create_score_calibration(session, MockCalibrationCreate(), test_user) @@ -225,6 +495,7 @@ async def test_create_score_calibration_propagates_errors_from_publication_find_ ), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) with ( pytest.raises( @@ -277,6 +548,7 @@ async def 
test_create_score_calibration_publication_identifier_associations_crea threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) test_user = session.execute(select(User)).scalars().first() @@ -312,6 +584,7 @@ async def test_create_score_calibration_user_is_set_as_creator_and_modifier( threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) test_user = session.execute(select(User)).scalars().first() @@ -339,20 +612,69 @@ async def test_create_score_calibration_user_is_set_as_creator_and_modifier( ], indirect=["mock_publication_fetch"], ) +@pytest.mark.parametrize( + "valid_score_calibration_data", + [ + TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, + TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED, + ], +) async def test_create_score_calibration_fully_valid_calibration( - setup_lib_db_with_score_set, session, create_function_to_call, score_set_urn, mock_publication_fetch + setup_lib_db_with_score_set, + session, + create_function_to_call, + score_set_urn, + mock_publication_fetch, + valid_score_calibration_data, ): - calibration_create = ScoreCalibrationCreate(**TEST_BRNICH_SCORE_CALIBRATION, score_set_urn=score_set_urn) + calibration_create = ScoreCalibrationCreate(**valid_score_calibration_data, score_set_urn=score_set_urn) test_user = session.execute(select(User)).scalars().first() calibration = await create_function_to_call(session, calibration_create, test_user) - for field in TEST_BRNICH_SCORE_CALIBRATION: - # Sources are tested elsewhere - # XXX: Ranges are a pain to compare between JSONB and dict input, so are assumed correct - if "sources" not in field and "functional_ranges" not in field: - assert getattr(calibration, field) == TEST_BRNICH_SCORE_CALIBRATION[field] + for field in valid_score_calibration_data: + # Sources are tested elsewhere. + if "sources" not in field and "functional_classifications" not in field: + assert getattr(calibration, field) == valid_score_calibration_data[field] + + # Verify functional classifications length. Assume the returned value of created classifications is correct, + # and test the content elsewhere. 
+ if field == "functional_classifications": + assert len(calibration.functional_classifications) == len( + valid_score_calibration_data["functional_classifications"] + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "create_function_to_call,score_set_urn", + [ + (create_score_calibration_in_score_set, VALID_SCORE_SET_URN), + (create_score_calibration, None), + ], +) +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ], + ], + indirect=["mock_publication_fetch"], +) +async def test_create_score_calibration_does_not_commit_transaction( + setup_lib_db_with_score_set, session, mock_user, create_function_to_call, score_set_urn, mock_publication_fetch +): + calibration_create = ScoreCalibrationCreate( + **TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, score_set_urn=score_set_urn + ) + test_user = session.execute(select(User)).scalars().first() + + with mock.patch.object(session, "commit") as mock_commit: + await create_function_to_call(session, calibration_create, test_user) + mock_commit.assert_not_called() ################################################################################ @@ -400,7 +722,7 @@ async def test_modify_score_calibration_modifies_score_calibration_when_score_se ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) @@ -411,6 +733,7 @@ async def test_modify_score_calibration_modifies_score_calibration_when_score_se threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) modified_calibration = await modify_score_calibration( @@ -437,7 +760,7 @@ async def test_modify_score_calibration_clears_existing_publication_identifier_a ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) @@ -447,6 +770,7 @@ async def test_modify_score_calibration_clears_existing_publication_identifier_a threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) mocked_calibration = MockCalibrationModify() @@ -483,7 +807,7 @@ async def test_modify_score_calibration_publication_identifier_associations_crea ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) @@ -493,6 +817,7 @@ async def test_modify_score_calibration_publication_identifier_associations_crea threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) mocked_calibration = MockCalibrationModify() @@ -525,7 +850,7 @@ async def test_modify_score_calibration_retains_existing_publication_relationshi ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await 
create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) calibration_publication_relations = existing_calibration.publication_identifier_associations.copy() @@ -541,7 +866,7 @@ async def test_modify_score_calibration_retains_existing_publication_relationshi db_name=(str, pub_dict["db_name"]), identifier=(str, pub_dict["identifier"]), )() - for pub_dict in TEST_BRNICH_SCORE_CALIBRATION["threshold_sources"] + for pub_dict in TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED["threshold_sources"] ], ), classification_sources=( @@ -552,7 +877,7 @@ async def test_modify_score_calibration_retains_existing_publication_relationshi db_name=(str, pub_dict["db_name"]), identifier=(str, pub_dict["identifier"]), )() - for pub_dict in TEST_BRNICH_SCORE_CALIBRATION["classification_sources"] + for pub_dict in TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED["classification_sources"] ], ), method_sources=( @@ -563,9 +888,10 @@ async def test_modify_score_calibration_retains_existing_publication_relationshi db_name=(str, pub_dict["db_name"]), identifier=(str, pub_dict["identifier"]), )() - for pub_dict in TEST_BRNICH_SCORE_CALIBRATION["method_sources"] + for pub_dict in TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED["method_sources"] ], ), + functional_classifications=(list, []), ) modified_calibration = await modify_score_calibration( @@ -592,7 +918,7 @@ async def test_modify_score_calibration_adds_new_publication_association( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) @@ -611,6 +937,7 @@ async def test_modify_score_calibration_adds_new_publication_association( ), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) modified_calibration = await modify_score_calibration( @@ -641,7 +968,7 @@ async def test_modify_score_calibration_user_is_set_as_modifier( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) @@ -651,6 +978,7 @@ async def test_modify_score_calibration_user_is_set_as_modifier( threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) modify_user = session.execute(select(User).where(User.id != test_user.id)).scalars().first() @@ -690,7 +1018,7 @@ async def test_modify_score_calibration_new_score_set(setup_lib_db_with_score_se session.refresh(new_containing_score_set) test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, new_containing_score_set.urn, test_user ) @@ -700,6 +1028,7 @@ async def test_modify_score_calibration_new_score_set(setup_lib_db_with_score_se threshold_sources=(list, []), classification_sources=(list, []), method_sources=(list, []), + functional_classifications=(list, []), ) modified_calibration = await modify_score_calibration( @@ -709,6 +1038,42 @@ async def test_modify_score_calibration_new_score_set(setup_lib_db_with_score_se assert modified_calibration.score_set == 
new_containing_score_set +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ], + ], + indirect=["mock_publication_fetch"], +) +async def test_modify_score_calibration_clears_functional_classifications( + setup_lib_db_with_score_set, session, mock_publication_fetch +): + test_user = session.execute(select(User)).scalars().first() + + existing_calibration = await create_test_range_based_score_calibration_in_score_set( + session, setup_lib_db_with_score_set.urn, test_user + ) + + MockCalibrationModify = create_model( + "MockCalibrationModify", + score_set_urn=(str | None, setup_lib_db_with_score_set.urn), + threshold_sources=(list, []), + classification_sources=(list, []), + method_sources=(list, []), + functional_classifications=(list, []), + ) + + modified_calibration = await modify_score_calibration( + session, existing_calibration, MockCalibrationModify(), test_user + ) + assert modified_calibration is not None + assert len(modified_calibration.functional_classifications) == 0 + + @pytest.mark.asyncio @pytest.mark.parametrize( "mock_publication_fetch", @@ -725,7 +1090,7 @@ async def test_modify_score_calibration_fully_valid_calibration( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) @@ -735,11 +1100,46 @@ async def test_modify_score_calibration_fully_valid_calibration( modified_calibration = await modify_score_calibration(session, existing_calibration, modify_calibration, test_user) for field in TEST_PATHOGENICITY_SCORE_CALIBRATION: - # Sources are tested elsewhere - # XXX: Ranges are a pain to compare between JSONB and dict input, so are assumed correct - if "sources" not in field and "functional_ranges" not in field: + # Sources are tested elsewhere. + if "sources" not in field and "functional_classifications" not in field: assert getattr(modified_calibration, field) == TEST_PATHOGENICITY_SCORE_CALIBRATION[field] + # Verify functional classifications length. Assume the returned value of created classifications is correct, + # and test the content elsewhere. 
+ if field == "functional_classifications": + assert len(modified_calibration.functional_classifications) == len( + TEST_PATHOGENICITY_SCORE_CALIBRATION["functional_classifications"] + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ], + ], + indirect=["mock_publication_fetch"], +) +async def test_modify_score_calibration_does_not_commit_transaction( + setup_lib_db_with_score_set, session, mock_publication_fetch +): + test_user = session.execute(select(User)).scalars().first() + + existing_calibration = await create_test_range_based_score_calibration_in_score_set( + session, setup_lib_db_with_score_set.urn, test_user + ) + + modify_calibration = ScoreCalibrationModify( + **TEST_PATHOGENICITY_SCORE_CALIBRATION, score_set_urn=setup_lib_db_with_score_set.urn + ) + + with mock.patch.object(session, "commit") as mock_commit: + await modify_score_calibration(session, existing_calibration, modify_calibration, test_user) + mock_commit.assert_not_called() + ################################################################################ # Tests for publish_score_calibration @@ -762,7 +1162,7 @@ async def test_cannot_publish_already_published_calibration( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_calibration.private = False @@ -790,7 +1190,7 @@ async def test_publish_score_calibration_marks_calibration_public( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) assert existing_calibration.private is True @@ -815,7 +1215,7 @@ async def test_publish_score_calibration_user_is_set_as_modifier( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) @@ -845,7 +1245,7 @@ async def test_cannot_promote_already_primary_calibration(setup_lib_db_with_score_set, session, mock_publication_fetch): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_calibration.primary = True @@ -873,7 +1273,7 @@ async def test_cannot_promote_calibration_when_calibration_is_research_use_only( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_calibration.research_use_only = True @@ -901,7 +1301,7 @@ async def test_cannot_promote_calibration_when_calibration_is_private( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await
create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_calibration.private = True @@ -929,10 +1329,10 @@ async def test_cannot_promote_calibration_when_another_primary_exists( ): test_user = session.execute(select(User)).scalars().first() - existing_primary_calibration = await create_test_score_calibration_in_score_set( + existing_primary_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_primary_calibration.private = False @@ -966,7 +1366,7 @@ async def test_promote_score_calibration_to_primary_marks_calibration_primary( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_calibration.private = False @@ -995,10 +1395,10 @@ async def test_promote_score_calibration_to_primary_demotes_existing_primary_whe ): test_user = session.execute(select(User)).scalars().first() - existing_primary_calibration = await create_test_score_calibration_in_score_set( + existing_primary_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_primary_calibration.private = False @@ -1038,7 +1438,7 @@ async def test_promote_score_calibration_to_primary_user_is_set_as_modifier( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_calibration.private = False @@ -1072,10 +1472,10 @@ async def test_promote_score_calibration_to_primary_demoted_existing_primary_use ): test_user = session.execute(select(User)).scalars().first() - existing_primary_calibration = await create_test_score_calibration_in_score_set( + existing_primary_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_primary_calibration.private = False @@ -1102,6 +1502,36 @@ async def test_promote_score_calibration_to_primary_demoted_existing_primary_use assert promoted_calibration.created_by == test_user +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ], + ], + indirect=["mock_publication_fetch"], +) +async def test_promote_score_calibration_to_primary_does_not_commit_transaction( + setup_lib_db_with_score_set, 
session, mock_publication_fetch +): + test_user = session.execute(select(User)).scalars().first() + + existing_calibration = await create_test_range_based_score_calibration_in_score_set( + session, setup_lib_db_with_score_set.urn, test_user + ) + existing_calibration.private = False + existing_calibration.primary = False + session.add(existing_calibration) + session.commit() + session.refresh(existing_calibration) + + with mock.patch.object(session, "commit") as mock_commit: + promote_score_calibration_to_primary(session, existing_calibration, test_user, force=False) + mock_commit.assert_not_called() + + ################################################################################ # Test demote_score_calibration_from_primary ################################################################################ @@ -1121,7 +1551,7 @@ async def test_promote_score_calibration_to_primary_demoted_existing_primary_use async def test_cannot_demote_non_primary_calibration(setup_lib_db_with_score_set, session, mock_publication_fetch): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_calibration.primary = False @@ -1149,7 +1579,7 @@ async def test_demote_score_calibration_from_primary_marks_calibration_non_prima ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_calibration.primary = True @@ -1178,7 +1608,7 @@ async def test_demote_score_calibration_from_primary_user_is_set_as_modifier( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_calibration.primary = True @@ -1193,6 +1623,35 @@ async def test_demote_score_calibration_from_primary_user_is_set_as_modifier( assert demoted_calibration.created_by == test_user +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ], + ], + indirect=["mock_publication_fetch"], +) +async def test_demote_score_calibration_from_primary_does_not_commit_transaction( + setup_lib_db_with_score_set, session, mock_publication_fetch +): + test_user = session.execute(select(User)).scalars().first() + + existing_calibration = await create_test_range_based_score_calibration_in_score_set( + session, setup_lib_db_with_score_set.urn, test_user + ) + existing_calibration.primary = True + session.add(existing_calibration) + session.commit() + session.refresh(existing_calibration) + + with mock.patch.object(session, "commit") as mock_commit: + demote_score_calibration_from_primary(session, existing_calibration, test_user) + mock_commit.assert_not_called() + + ################################################################################ # Test delete_score_calibration ################################################################################ @@ -1212,7 +1671,7 @@ async def 
test_demote_score_calibration_from_primary_user_is_set_as_modifier( async def test_cannot_delete_primary_calibration(setup_lib_db_with_score_set, session, mock_publication_fetch): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) existing_calibration.primary = True @@ -1240,7 +1699,7 @@ async def test_delete_score_calibration_deletes_calibration( ): test_user = session.execute(select(User)).scalars().first() - existing_calibration = await create_test_score_calibration_in_score_set( + existing_calibration = await create_test_range_based_score_calibration_in_score_set( session, setup_lib_db_with_score_set.urn, test_user ) calibration_id = existing_calibration.id @@ -1250,3 +1709,1010 @@ async def test_delete_score_calibration_deletes_calibration( with pytest.raises(NoResultFound, match="No row was found when one was required"): session.execute(select(ScoreCalibration).where(ScoreCalibration.id == calibration_id)).scalars().one() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ], + ], + indirect=["mock_publication_fetch"], +) +async def test_delete_score_calibration_does_not_commit_transaction( + setup_lib_db_with_score_set, session, mock_publication_fetch +): + test_user = session.execute(select(User)).scalars().first() + + existing_calibration = await create_test_range_based_score_calibration_in_score_set( + session, setup_lib_db_with_score_set.urn, test_user + ) + + with mock.patch.object(session, "commit") as mock_commit: + delete_score_calibration(session, existing_calibration) + mock_commit.assert_not_called() + + +################################################################################ +# Tests for variants_for_functional_classification +################################################################################ + + +def test_variants_for_functional_classification_returns_empty_list_when_range_and_classes_is_none( + setup_lib_db, session +): + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = 1 + mock_functional_calibration = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_calibration.range = None + mock_functional_calibration.class_ = None + mock_functional_calibration.calibration = mock_calibration + + result = variants_for_functional_classification( + session, mock_functional_calibration, variant_classes=None, use_sql=False + ) + + assert result == [] + + +def test_variants_for_functional_classification_returns_empty_list_when_range_is_empty_list_and_classes_is_none( + setup_lib_db, session +): + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = 1 + mock_functional_calibration = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_calibration.range = [] + mock_functional_calibration.class_ = None + mock_functional_calibration.calibration = mock_calibration + + result = variants_for_functional_classification( + session, mock_functional_calibration, variant_classes=None, use_sql=False + ) + + assert result == [] + + +@pytest.mark.parametrize( + "use_sql", + [True, False], +) +def test_variants_for_functional_classification_raises_error_when_index_column_not_found( + 
setup_lib_db, session, use_sql +): + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = 1 + mock_functional_calibration = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_calibration.range = None + mock_functional_calibration.class_ = "benign" + mock_functional_calibration.calibration = mock_calibration + + variant_classes = pd.DataFrame( + { + "some_other_column": [ + "urn:mavedb:variant-1", + "urn:mavedb:variant-2", + "urn:mavedb:variant-3", + ], + calibration_class_column_name: [ + "pathogenic", + "benign", + "pathogenic", + ], + } + ) + + with pytest.raises(ValueError, match="Unsupported index column `some_other_column` for variant classification."): + variants_for_functional_classification( + session, + mock_functional_calibration, + variant_classes=variant_classification_df_to_dict(variant_classes, index_column="some_other_column"), + use_sql=use_sql, + ) + + +@pytest.mark.parametrize( + "range_,class_,variant_classes,index_column", + [ + ([1.0, 2.0], None, None, None), + ( + None, + "benign", + pd.DataFrame( + { + calibration_variant_column_name: [ + "urn:mavedb:variant-1", + "urn:mavedb:variant-2", + "urn:mavedb:variant-3", + ], + calibration_class_column_name: [ + "pathogenic", + "benign", + "pathogenic", + ], + } + ), + calibration_variant_column_name, + ), + ( + None, + "benign", + pd.DataFrame( + { + hgvs_nt_column: [ + "NC_000001.11:g.1000A>T", + "NC_000001.11:g.1001G>C", + "NC_000001.11:g.1002T>A", + ], + calibration_class_column_name: [ + "pathogenic", + "benign", + "pathogenic", + ], + } + ), + hgvs_nt_column, + ), + ( + None, + "benign", + pd.DataFrame( + { + hgvs_pro_column: [ + "NP_000000.1:p.Lys100Asn", + "NP_000000.1:p.Gly101Arg", + "NP_000000.1:p.Ser102Thr", + ], + calibration_class_column_name: [ + "pathogenic", + "benign", + "pathogenic", + ], + } + ), + hgvs_pro_column, + ), + ], +) +def test_variants_for_functional_classification_python_filtering_with_valid_variants( + setup_lib_db_with_score_set, session, range_, class_, variant_classes, index_column +): + variant_1 = Variant( + data={"score_data": {"score": 0.5}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-1", + hgvs_nt="NC_000001.11:g.1000A>T", + hgvs_pro="NP_000000.1:p.Lys100Asn", + ) + variant_2 = Variant( + data={"score_data": {"score": 1.5}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-2", + hgvs_nt="NC_000001.11:g.1001G>C", + hgvs_pro="NP_000000.1:p.Gly101Arg", + ) + variant_3 = Variant( + data={"score_data": {"score": 2.5}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-3", + hgvs_nt="NC_000001.11:g.1002T>A", + hgvs_pro="NP_000000.1:p.Ser102Thr", + ) + + session.add_all([variant_1, variant_2, variant_3]) + session.commit() + + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = range_ + mock_functional_classification.class_ = class_ + mock_functional_classification.calibration = mock_calibration + mock_functional_classification.score_is_contained_in_range = mock.Mock(side_effect=lambda x: 1.0 <= x <= 2.0) + + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes, index_column) + if variant_classes is not None + else None, + use_sql=False, + ) + + assert 
len(result) == 1 + assert result[0].data["score_data"]["score"] == 1.5 + + +@pytest.mark.parametrize( + "range_,class_,variant_classes", + [ + ([1.0, 2.0], None, None), + # not applicable when filtering by class + ], +) +def test_variants_for_functional_classification_python_filtering_skips_variants_without_score_data( + setup_lib_db_with_score_set, session, range_, class_, variant_classes +): + # Create variant without score_data + variant_without_score_data = Variant( + data={"other_data": {"value": 1.0}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-1", + ) + + # Create variant with valid score + variant_with_score = Variant( + data={"score_data": {"score": 1.5}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-2", + ) + + session.add_all([variant_without_score_data, variant_with_score]) + session.commit() + + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = range_ + mock_functional_classification.class_ = class_ + mock_functional_classification.calibration = mock_calibration + mock_functional_classification.score_is_contained_in_range = mock.Mock(side_effect=lambda x: 1.0 <= x <= 2.0) + + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes) if variant_classes is not None else None, + use_sql=False, + ) + + assert len(result) == 1 + assert result[0].data["score_data"]["score"] == 1.5 + + +@pytest.mark.parametrize( + "range_,class_,variant_classes", + [ + ([1.0, 2.0], None, None), + # not applicable when filtering by class + ], +) +def test_variants_for_functional_classification_python_filtering_skips_variants_with_non_dict_score_data( + setup_lib_db_with_score_set, session, range_, class_, variant_classes +): + # Create variant with non-dict score_data + variant_invalid_score_data = Variant( + data={"score_data": "not_a_dict"}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-1", + ) + + # Create variant with valid score + variant_with_score = Variant( + data={"score_data": {"score": 1.5}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-2", + ) + + session.add_all([variant_invalid_score_data, variant_with_score]) + session.commit() + + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = range_ + mock_functional_classification.class_ = class_ + mock_functional_classification.calibration = mock_calibration + mock_functional_classification.score_is_contained_in_range = mock.Mock(side_effect=lambda x: 1.0 <= x <= 2.0) + + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes) if variant_classes is not None else None, + use_sql=False, + ) + assert len(result) == 1 + assert result[0].data["score_data"]["score"] == 1.5 + + +@pytest.mark.parametrize( + "range_,class_,variant_classes", + [ + ([1.0, 2.0], None, None), + # not applicable when filtering by class + ], +) +def test_variants_for_functional_classification_python_filtering_skips_variants_with_none_score( + 
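+    # NOTE: unlike the sibling "skips variants" tests, this test sets range directly to [1.0, 2.0] on the mock
+    # below instead of the parametrized range_, and leaves class_ unset; the parametrization is otherwise identical.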
setup_lib_db_with_score_set, session, range_, class_, variant_classes +): + # Create variant with None score + variant_none_score = Variant( + data={"score_data": {"score": None}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-1", + ) + + # Create variant with valid score + variant_with_score = Variant( + data={"score_data": {"score": 1.5}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-2", + ) + + session.add_all([variant_none_score, variant_with_score]) + session.commit() + + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = [1.0, 2.0] + mock_functional_classification.calibration = mock_calibration + mock_functional_classification.score_is_contained_in_range = mock.Mock(side_effect=lambda x: 1.0 <= x <= 2.0) + + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes) if variant_classes is not None else None, + use_sql=False, + ) + + assert len(result) == 1 + assert result[0].data["score_data"]["score"] == 1.5 + + +@pytest.mark.parametrize( + "range_,class_,variant_classes", + [ + ([1.0, 2.0], None, None), + # not applicable when filtering by class + ], +) +def test_variants_for_functional_classification_python_filtering_skips_variants_with_non_numeric_score( + setup_lib_db_with_score_set, session, range_, class_, variant_classes +): + # Create variant with non-numeric score + variant_string_score = Variant( + data={"score_data": {"score": "not_a_number"}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-1", + ) + + # Create variant with valid score + variant_with_score = Variant( + data={"score_data": {"score": 1.5}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-2", + ) + + session.add_all([variant_string_score, variant_with_score]) + session.commit() + + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = range_ + mock_functional_classification.class_ = class_ + mock_functional_classification.calibration = mock_calibration + mock_functional_classification.score_is_contained_in_range = mock.Mock(side_effect=lambda x: 1.0 <= x <= 2.0) + + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes) if variant_classes is not None else None, + use_sql=False, + ) + assert len(result) == 1 + assert result[0].data["score_data"]["score"] == 1.5 + + +@pytest.mark.parametrize( + "range_,class_,variant_classes", + [ + ([1.0, 2.0], None, None), + # not applicable when filtering by class + ], +) +def test_variants_for_functional_classification_python_filtering_skips_variants_with_non_dict_data( + setup_lib_db_with_score_set, session, range_, class_, variant_classes +): + # Create variant with non-dict data + variant_invalid_data = Variant( + data="not_a_dict", score_set_id=setup_lib_db_with_score_set.id, urn="urn:mavedb:variant-1" + ) + + # Create variant with valid score + variant_with_score = Variant( + data={"score_data": {"score": 1.5}}, + score_set_id=setup_lib_db_with_score_set.id, + 
urn="urn:mavedb:variant-2", + ) + + session.add_all([variant_invalid_data, variant_with_score]) + session.commit() + + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = range_ + mock_functional_classification.class_ = class_ + mock_functional_classification.calibration = mock_calibration + mock_functional_classification.score_is_contained_in_range = mock.Mock(side_effect=lambda x: 1.0 <= x <= 2.0) + + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes) if variant_classes is not None else None, + use_sql=False, + ) + assert len(result) == 1 + assert result[0].data["score_data"]["score"] == 1.5 + + +@pytest.mark.parametrize( + "use_sql", + [True, False], +) +@pytest.mark.parametrize( + "range_,class_,variant_classes,index_column", + [ + ([1.0, 2.0], None, None, None), + ( + None, + "benign", + pd.DataFrame( + { + calibration_variant_column_name: [ + "urn:mavedb:variant-1", + "urn:mavedb:variant-2", + "urn:mavedb:variant-3", + "urn:mavedb:variant-4", + "urn:mavedb:variant-5", + ], + calibration_class_column_name: [ + "pathogenic", + "benign", + "benign", + "benign", + "pathogenic", + ], + } + ), + calibration_variant_column_name, + ), + ( + None, + "benign", + pd.DataFrame( + { + hgvs_nt_column: [ + "NC_000001.11:g.1000A>T", + "NC_000001.11:g.1001G>C", + "NC_000001.11:g.1002T>A", + "NC_000001.11:g.1003C>G", + "NC_000001.11:g.1004G>A", + ], + calibration_class_column_name: [ + "pathogenic", + "benign", + "benign", + "benign", + "pathogenic", + ], + } + ), + hgvs_nt_column, + ), + ( + None, + "benign", + pd.DataFrame( + { + hgvs_pro_column: [ + "NP_000000.1:p.Lys100Asn", + "NP_000000.1:p.Gly101Arg", + "NP_000000.1:p.Ser102Thr", + "NP_000000.1:p.Ala103Pro", + "NP_000000.1:p.Val104Met", + ], + calibration_class_column_name: [ + "pathogenic", + "benign", + "benign", + "benign", + "pathogenic", + ], + } + ), + hgvs_pro_column, + ), + ], +) +def test_variants_for_functional_classification_filters_by_conditions( + setup_lib_db_with_score_set, session, use_sql, range_, class_, variant_classes, index_column +): + # Create variants with different scores + variants = [] + scores = [0.5, 1.0, 1.5, 2.0, 2.5] + hgvs_nts = [ + "NC_000001.11:g.1000A>T", + "NC_000001.11:g.1001G>C", + "NC_000001.11:g.1002T>A", + "NC_000001.11:g.1003C>G", + "NC_000001.11:g.1004G>A", + ] + hgvs_pros = [ + "NP_000000.1:p.Lys100Asn", + "NP_000000.1:p.Gly101Arg", + "NP_000000.1:p.Ser102Thr", + "NP_000000.1:p.Ala103Pro", + "NP_000000.1:p.Val104Met", + ] + for i, score in enumerate(scores, 1): + variant = Variant( + data={"score_data": {"score": score}}, + score_set_id=setup_lib_db_with_score_set.id, + urn=f"urn:mavedb:variant-{i}", + hgvs_nt=hgvs_nts[i - 1], + hgvs_pro=hgvs_pros[i - 1], + ) + variants.append(variant) + + session.add_all(variants) + session.commit() + + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = range_ + mock_functional_classification.class_ = class_ + mock_functional_classification.inclusive_lower_bound = True + mock_functional_classification.inclusive_upper_bound = True + 
mock_functional_classification.calibration = mock_calibration + mock_functional_classification.score_is_contained_in_range = mock.Mock(side_effect=lambda x: 1.0 <= x <= 2.0) + + with mock.patch("mavedb.lib.score_calibrations.inf_or_float", side_effect=lambda x, lower: float(x)): + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes, index_column) + if variant_classes is not None + else None, + use_sql=use_sql, + ) + + # Should return variants with scores 1.0, 1.5, 2.0 + result_scores = [v.data["score_data"]["score"] for v in result] + expected_scores = [1.0, 1.5, 2.0] + assert sorted(result_scores) == sorted(expected_scores) + + +@pytest.mark.parametrize( + "range_,class_,variant_classes,index_column", + [ + ([1.0, 2.0], None, None, None), + ( + None, + "benign", + pd.DataFrame( + { + calibration_variant_column_name: [ + "urn:mavedb:variant-1", + "urn:mavedb:variant-2", + "urn:mavedb:variant-3", + ], + calibration_class_column_name: [ + "benign", + "pathogenic", + "pathogenic", + ], + } + ), + calibration_variant_column_name, + ), + ( + None, + "benign", + pd.DataFrame( + { + hgvs_nt_column: [ + "NC_000001.11:g.1000A>T", + "NC_000001.11:g.1001G>C", + "NC_000001.11:g.1002T>A", + ], + calibration_class_column_name: [ + "benign", + "pathogenic", + "pathogenic", + ], + } + ), + hgvs_nt_column, + ), + ( + None, + "benign", + pd.DataFrame( + { + hgvs_pro_column: [ + "NP_000000.1:p.Lys100Asn", + "NP_000000.1:p.Gly101Arg", + "NP_000000.1:p.Ser102Thr", + ], + calibration_class_column_name: [ + "benign", + "pathogenic", + "pathogenic", + ], + } + ), + hgvs_pro_column, + ), + ], +) +def test_variants_for_functional_classification_sql_fallback_on_exception( + setup_lib_db_with_score_set, session, range_, class_, variant_classes, index_column +): + # Create a variant + variant = Variant( + data={"score_data": {"score": 1.5}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-1", + hgvs_nt="NC_000001.11:g.1000A>T", + hgvs_pro="NP_000000.1:p.Lys100Asn", + ) + session.add(variant) + session.commit() + + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = range_ + mock_functional_classification.class_ = class_ + mock_functional_classification.calibration = mock_calibration + mock_functional_classification.score_is_contained_in_range = mock.Mock(side_effect=lambda x: 1.0 <= x <= 2.0) + + # Mock db.execute to raise an exception during SQL execution + with mock.patch.object( + session, + "execute", + side_effect=[ + Exception("SQL error"), + session.execute(select(Variant).where(Variant.score_set_id == setup_lib_db_with_score_set.id)), + ], + ) as mocked_execute: + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes, index_column) + if variant_classes is not None + else None, + use_sql=True, + ) + mocked_execute.assert_called() + + # Should fall back to Python filtering and return the matching variant + assert len(result) == 1 + assert result[0].data["score_data"]["score"] == 1.5 + + +@pytest.mark.parametrize( + "range_,class_,variant_classes", + [ + ([1.0, float("inf")], None, None), + # not applicable when filtering by class + ], +) +def 
test_variants_for_functional_classification_sql_with_infinite_bound( + setup_lib_db_with_score_set, session, range_, class_, variant_classes +): + # Create variants with different scores + variants = [] + scores = [0.5, 1.5, 2.5] + for i, score in enumerate(scores): + variant = Variant( + data={"score_data": {"score": score}}, + score_set_id=setup_lib_db_with_score_set.id, + urn=f"urn:mavedb:variant-{i}", + ) + variants.append(variant) + + session.add_all(variants) + session.commit() + + # Mock functional classification with infinite upper bound + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = range_ + mock_functional_classification.class_ = class_ + mock_functional_classification.calibration = mock_calibration + mock_functional_classification.inclusive_lower_bound = True + mock_functional_classification.inclusive_upper_bound = False + + with mock.patch( + "mavedb.lib.score_calibrations.inf_or_float", + side_effect=lambda x, lower: float("inf") if x == float("inf") else float(x), + ): + with mock.patch("math.isinf", side_effect=lambda x: x == float("inf")): + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes) + if variant_classes is not None + else None, + use_sql=True, + ) + + # Should return variants with scores >= 1.0 + result_scores = [v.data["score_data"]["score"] for v in result] + expected_scores = [1.5, 2.5] + assert sorted(result_scores) == sorted(expected_scores) + + +@pytest.mark.parametrize( + "range_,class_,variant_classes", + [ + ([1.0, 2.0], None, None), + # not applicable when filtering by class + ], +) +def test_variants_for_functional_classification_sql_with_exclusive_bounds( + setup_lib_db_with_score_set, session, range_, class_, variant_classes +): + # Create variants with boundary scores + variants = [] + scores = [1.0, 1.5, 2.0] + for i, score in enumerate(scores): + variant = Variant( + data={"score_data": {"score": score}}, + score_set_id=setup_lib_db_with_score_set.id, + urn=f"urn:mavedb:variant-{i}", + ) + variants.append(variant) + + session.add_all(variants) + session.commit() + + # Mock functional classification with exclusive bounds + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = range_ + mock_functional_classification.class_ = class_ + mock_functional_classification.calibration = mock_calibration + mock_functional_classification.inclusive_lower_bound = False + mock_functional_classification.inclusive_upper_bound = False + + with mock.patch("mavedb.lib.score_calibrations.inf_or_float", side_effect=lambda x, lower: float(x)): + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes) if variant_classes is not None else None, + use_sql=True, + ) + + # Should return only variant with score 1.5 (exclusive bounds) + result_scores = [v.data["score_data"]["score"] for v in result] + assert result_scores == [1.5] + + +@pytest.mark.parametrize( + "range_,class_,variant_classes", + [ + ([1.0, 2.0], None, None), + # not applicable when filtering by class + 
], +) +def test_variants_for_functional_classification_only_returns_variants_from_correct_score_set( + setup_lib_db_with_score_set, session, range_, class_, variant_classes +): + # Create another score set + other_score_set = ScoreSet( + urn="urn:mavedb:00000000-B-0", + experiment_id=setup_lib_db_with_score_set.experiment_id, + licence_id=TEST_LICENSE["id"], + title="Other Score Set", + method_text="Other method", + abstract_text="Other abstract", + short_description="Other description", + created_by=setup_lib_db_with_score_set.created_by, + modified_by=setup_lib_db_with_score_set.modified_by, + extra_metadata={}, + ) + session.add(other_score_set) + session.commit() + + # Create variants in both score sets + variant_in_target_set = Variant( + data={"score_data": {"score": 1.5}}, + score_set_id=setup_lib_db_with_score_set.id, + urn="urn:mavedb:variant-target", + ) + variant_in_other_set = Variant( + data={"score_data": {"score": 1.5}}, score_set_id=other_score_set.id, urn="urn:mavedb:variant-other" + ) + + session.add_all([variant_in_target_set, variant_in_other_set]) + session.commit() + + mock_calibration = mock.Mock(spec=ScoreCalibration) + mock_calibration.score_set_id = setup_lib_db_with_score_set.id + mock_functional_classification = mock.Mock(spec=ScoreCalibrationFunctionalClassification) + mock_functional_classification.range = range_ + mock_functional_classification.class_ = class_ + mock_functional_classification.calibration = mock_calibration + mock_functional_classification.score_is_contained_in_range = mock.Mock(side_effect=lambda x: 1.0 <= x <= 2.0) + + result = variants_for_functional_classification( + session, + mock_functional_classification, + variant_classes=variant_classification_df_to_dict(variant_classes) if variant_classes is not None else None, + use_sql=False, + ) + # Should only return variant from the target score set + assert len(result) == 1 + assert result[0].score_set_id == setup_lib_db_with_score_set.id + assert result[0].urn == "urn:mavedb:variant-target" + + +################################################################################ +# Tests for variant_classification_df_to_dict +################################################################################ + + +def test_variant_classification_df_to_dict_with_single_class(): + """Test conversion with DataFrame containing variants of a single functional class.""" + df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var2", "var3"], + calibration_class_column_name: ["pathogenic", "pathogenic", "pathogenic"], + } + ) + + result = variant_classification_df_to_dict(df, calibration_variant_column_name) + + expected = { + "indexed_by": calibration_variant_column_name, + "classifications": {"pathogenic": set(["var1", "var2", "var3"])}, + } + assert result["classifications"] == expected["classifications"] + assert result["indexed_by"] == expected["indexed_by"] + + +def test_variant_classification_df_to_dict_with_multiple_classes(): + """Test conversion with DataFrame containing variants of multiple functional classes.""" + df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var2", "var3", "var4", "var5"], + calibration_class_column_name: ["pathogenic", "benign", "pathogenic", "uncertain", "benign"], + } + ) + + result = variant_classification_df_to_dict(df, calibration_variant_column_name) + + expected = { + "indexed_by": calibration_variant_column_name, + "classifications": { + "pathogenic": set(["var1", "var3"]), + "benign": set(["var2", "var5"]), + "uncertain": 
set(["var4"]), + }, + } + assert result["classifications"] == expected["classifications"] + assert result["indexed_by"] == expected["indexed_by"] + + +def test_variant_classification_df_to_dict_with_empty_dataframe(): + """Test conversion with empty DataFrame.""" + df = pd.DataFrame(columns=[calibration_variant_column_name, calibration_class_column_name]) + + result = variant_classification_df_to_dict(df, calibration_variant_column_name) + + assert result["classifications"] == {} + assert result["indexed_by"] == calibration_variant_column_name + + +def test_variant_classification_df_to_dict_with_single_row(): + """Test conversion with DataFrame containing single row.""" + df = pd.DataFrame({calibration_variant_column_name: ["var1"], calibration_class_column_name: ["pathogenic"]}) + + result = variant_classification_df_to_dict(df, calibration_variant_column_name) + + expected = {"indexed_by": calibration_variant_column_name, "classifications": {"pathogenic": set(["var1"])}} + assert result["classifications"] == expected["classifications"] + assert result["indexed_by"] == expected["indexed_by"] + + +def test_variant_classification_df_to_dict_with_extra_columns(): + """Test conversion ignores extra columns in DataFrame.""" + df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var2"], + calibration_class_column_name: ["pathogenic", "benign"], + "extra_column": ["value1", "value2"], + "another_column": [1, 2], + } + ) + + result = variant_classification_df_to_dict(df, calibration_variant_column_name) + + expected = { + "indexed_by": calibration_variant_column_name, + "classifications": {"pathogenic": set(["var1"]), "benign": set(["var2"])}, + } + assert result["classifications"] == expected["classifications"] + assert result["indexed_by"] == expected["indexed_by"] + + +def test_variant_classification_df_to_dict_with_duplicate_variants_in_same_class(): + """Test handling of duplicate variant URNs in the same functional class.""" + df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var1", "var2"], + calibration_class_column_name: ["pathogenic", "pathogenic", "benign"], + } + ) + + result = variant_classification_df_to_dict(df, calibration_variant_column_name) + + expected = { + "indexed_by": calibration_variant_column_name, + "classifications": {"pathogenic": set(["var1"]), "benign": set(["var2"])}, + } + assert result["classifications"] == expected["classifications"] + assert result["indexed_by"] == expected["indexed_by"] + + +def test_variant_classification_df_to_dict_with_none_values(): + """Test handling of None values in functional class column.""" + df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var2", "var3"], + calibration_class_column_name: ["pathogenic", None, "benign"], + } + ) + + result = variant_classification_df_to_dict(df, calibration_variant_column_name) + + expected = { + "indexed_by": calibration_variant_column_name, + "classifications": {"pathogenic": set(["var1"]), None: set(["var2"]), "benign": set(["var3"])}, + } + assert result["classifications"] == expected["classifications"] + assert result["indexed_by"] == expected["indexed_by"] + + +def test_variant_classification_df_to_dict_with_numeric_classes(): + """Test handling of numeric functional class labels.""" + df = pd.DataFrame( + {calibration_variant_column_name: ["var1", "var2", "var3"], calibration_class_column_name: [1, 2, 1]} + ) + + result = variant_classification_df_to_dict(df, calibration_variant_column_name) + + expected = { + "indexed_by": 
calibration_variant_column_name, + "classifications": {1: set(["var1", "var3"]), 2: set(["var2"])}, + } + assert result["classifications"] == expected["classifications"] + assert result["indexed_by"] == expected["indexed_by"] + + +def test_variant_classification_df_to_dict_with_mixed_type_classes(): + """Test handling of mixed data types in functional class column.""" + df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var2", "var3", "var4"], + calibration_class_column_name: ["pathogenic", 1, "benign", 1], + } + ) + + result = variant_classification_df_to_dict(df, calibration_variant_column_name) + + expected = { + "indexed_by": calibration_variant_column_name, + "classifications": {"pathogenic": set(["var1"]), 1: set(["var2", "var4"]), "benign": set(["var3"])}, + } + assert result["classifications"] == expected["classifications"] + assert result["indexed_by"] == expected["indexed_by"] diff --git a/tests/lib/test_score_set.py b/tests/lib/test_score_set.py index a260599a..d9f7fa39 100644 --- a/tests/lib/test_score_set.py +++ b/tests/lib/test_score_set.py @@ -7,6 +7,10 @@ import pytest from sqlalchemy import select +from mavedb.models.enums.target_category import TargetCategory +from mavedb.models.user import User +from mavedb.view_models.search import ScoreSetsSearch + arq = pytest.importorskip("arq") cdot = pytest.importorskip("cdot") fastapi = pytest.importorskip("fastapi") @@ -17,7 +21,9 @@ create_variants, create_variants_data, csv_data_to_df, + fetch_score_set_search_filter_options, ) +from mavedb.lib.types.authentication import UserData from mavedb.lib.validation.constants.general import ( hgvs_nt_column, hgvs_pro_column, @@ -33,7 +39,7 @@ from mavedb.models.target_sequence import TargetSequence from mavedb.models.taxonomy import Taxonomy from mavedb.models.variant import Variant -from tests.helpers.constants import TEST_EXPERIMENT, TEST_ACC_SCORESET, TEST_SEQ_SCORESET +from tests.helpers.constants import TEST_ACC_SCORESET, TEST_EXPERIMENT, TEST_SEQ_SCORESET, TEST_USER from tests.helpers.util.experiment import create_experiment from tests.helpers.util.score_set import create_seq_score_set @@ -377,3 +383,170 @@ def test_create_null_score_range(setup_lib_db, client, session): assert not score_set.score_calibrations assert score_set is not None + + +def test_fetch_score_set_search_filter_options_no_score_sets(setup_lib_db, session): + score_set_search = ScoreSetsSearch() + filter_options = fetch_score_set_search_filter_options(session, None, None, score_set_search) + + assert filter_options == { + "target_gene_categories": [], + "target_gene_names": [], + "target_organism_names": [], + "target_accessions": [], + "publication_author_names": [], + "publication_db_names": [], + "publication_journals": [], + } + + +def test_fetch_score_set_search_filter_options_with_score_set(setup_lib_db, session): + requesting_user = session.query(User).filter(User.username == TEST_USER["username"]).first() + user_data = UserData(user=requesting_user, active_roles=[]) + + experiment = Experiment(**TEST_EXPERIMENT) + session.add(experiment) + session.commit() + session.refresh(experiment) + + target_accessions = [TargetAccession(**seq["target_accession"]) for seq in TEST_ACC_SCORESET["target_genes"]] + target_genes = [ + TargetGene(**{**gene, **{"target_accession": target_accessions[idx]}}) + for idx, gene in enumerate(TEST_ACC_SCORESET["target_genes"]) + ] + + score_set = ScoreSet( + **{ + **TEST_ACC_SCORESET, + **{ + "experiment_id": experiment.id, + "target_genes": target_genes, + 
"extra_metadata": {}, + "license": session.scalars(select(License)).first(), + }, + "created_by_id": requesting_user.id, + "modified_by_id": requesting_user.id, + } + ) + session.add(score_set) + session.commit() + session.refresh(score_set) + + score_set_search = ScoreSetsSearch() + filter_options = fetch_score_set_search_filter_options(session, user_data, None, score_set_search) + + assert filter_options == { + "target_gene_categories": [{"value": TargetCategory.protein_coding, "count": 1}], + "target_gene_names": [{"value": "TEST2", "count": 1}], + "target_organism_names": [], + "target_accessions": [{"value": "NM_001637.3", "count": 1}], + "publication_author_names": [], + "publication_db_names": [], + "publication_journals": [], + } + + +def test_fetch_score_set_search_filter_options_with_partial_filtered_score_sets(setup_lib_db, session): + requesting_user = session.query(User).filter(User.username == TEST_USER["username"]).first() + user_data = UserData(user=requesting_user, active_roles=[]) + + experiment = Experiment(**TEST_EXPERIMENT) + session.add(experiment) + session.commit() + session.refresh(experiment) + + target_sequences = [ + TargetSequence(**{**seq["target_sequence"], **{"taxonomy": session.scalars(select(Taxonomy)).first()}}) + for seq in TEST_SEQ_SCORESET["target_genes"] + ] + target_genes = [ + TargetGene(**{**gene, **{"target_sequence": target_sequences[idx]}}) + for idx, gene in enumerate(TEST_SEQ_SCORESET["target_genes"]) + ] + + score_set = ScoreSet( + **{ + **TEST_SEQ_SCORESET, + **{ + "experiment_id": experiment.id, + "target_genes": target_genes, + "extra_metadata": {}, + "license": session.scalars(select(License)).first(), + }, + "created_by_id": requesting_user.id, + "modified_by_id": requesting_user.id, + } + ) + session.add(score_set) + session.commit() + session.refresh(score_set) + + target_accessions = [TargetAccession(**seq["target_accession"]) for seq in TEST_ACC_SCORESET["target_genes"]] + target_genes = [ + TargetGene(**{**gene, **{"target_accession": target_accessions[idx]}}) + for idx, gene in enumerate(TEST_ACC_SCORESET["target_genes"]) + ] + + score_set = ScoreSet( + **{ + **TEST_ACC_SCORESET, + **{ + "experiment_id": experiment.id, + "target_genes": target_genes, + "extra_metadata": {}, + "license": session.scalars(select(License)).first(), + }, + "created_by_id": requesting_user.id, + "modified_by_id": requesting_user.id, + } + ) + session.add(score_set) + session.commit() + + session.refresh(score_set) + + score_set_search = ScoreSetsSearch(targets=["TEST1"]) + requesting_user = session.query(User).filter(User.username == TEST_USER["username"]).first() + user_data = UserData(user=requesting_user, active_roles=[]) + filter_options = fetch_score_set_search_filter_options(session, user_data, None, score_set_search) + assert filter_options == { + "target_gene_categories": [{"value": TargetCategory.protein_coding, "count": 1}], + "target_gene_names": [{"value": "TEST1", "count": 1}], + "target_organism_names": [{"count": 1, "value": "Organism name"}], + "target_accessions": [], + "publication_author_names": [], + "publication_db_names": [], + "publication_journals": [], + } + + +def test_fetch_score_set_search_filter_options_with_no_matching_score_sets(setup_lib_db, session): + score_set_search = ScoreSetsSearch(publication_journals=["Non Existent Journal"]) + requesting_user = session.query(User).filter(User.username == TEST_USER["username"]).first() + user_data = UserData(user=requesting_user, active_roles=[]) + filter_options = 
fetch_score_set_search_filter_options(session, user_data, None, score_set_search) + + assert filter_options == { + "target_gene_categories": [], + "target_gene_names": [], + "target_organism_names": [], + "target_accessions": [], + "publication_author_names": [], + "publication_db_names": [], + "publication_journals": [], + } + + +def test_fetch_score_set_search_filter_options_with_no_permitted_score_sets(setup_lib_db, session): + score_set_search = ScoreSetsSearch() + filter_options = fetch_score_set_search_filter_options(session, None, None, score_set_search) + + assert filter_options == { + "target_gene_categories": [], + "target_gene_names": [], + "target_organism_names": [], + "target_accessions": [], + "publication_author_names": [], + "publication_db_names": [], + "publication_journals": [], + } diff --git a/tests/routers/data/calibration_classes_by_hgvs_nt.csv b/tests/routers/data/calibration_classes_by_hgvs_nt.csv new file mode 100644 index 00000000..07025f44 --- /dev/null +++ b/tests/routers/data/calibration_classes_by_hgvs_nt.csv @@ -0,0 +1,4 @@ +hgvs_nt,class_name +c.1A>T,normal_class +c.2C>T,abnormal_class +c.6T>A,not_specified_class \ No newline at end of file diff --git a/tests/routers/data/calibration_classes_by_hgvs_prot.csv b/tests/routers/data/calibration_classes_by_hgvs_prot.csv new file mode 100644 index 00000000..0a948cb8 --- /dev/null +++ b/tests/routers/data/calibration_classes_by_hgvs_prot.csv @@ -0,0 +1,4 @@ +hgvs_pro,class_name +p.Thr1Ser,normal_class +p.Thr1Met,abnormal_class +p.Phe2Leu,not_specified_class \ No newline at end of file diff --git a/tests/routers/data/calibration_classes_by_urn.csv b/tests/routers/data/calibration_classes_by_urn.csv new file mode 100644 index 00000000..d7654e67 --- /dev/null +++ b/tests/routers/data/calibration_classes_by_urn.csv @@ -0,0 +1,4 @@ +variant_urn,class_name +urn:mavedb:00000001-a-1#1,normal_class +urn:mavedb:00000001-a-1#2,abnormal_class +urn:mavedb:00000001-a-1#3,not_specified_class \ No newline at end of file diff --git a/tests/routers/test_collections.py b/tests/routers/test_collections.py index 3b3bec65..f7103a9b 100644 --- a/tests/routers/test_collections.py +++ b/tests/routers/test_collections.py @@ -14,12 +14,11 @@ from mavedb.lib.validation.urn_re import MAVEDB_COLLECTION_URN_RE from mavedb.models.enums.contribution_role import ContributionRole from mavedb.view_models.collection import Collection - from tests.helpers.constants import ( EXTRA_USER, - TEST_USER, TEST_COLLECTION, TEST_COLLECTION_RESPONSE, + TEST_USER, ) from tests.helpers.dependency_overrider import DependencyOverrider from tests.helpers.util.collection import create_collection @@ -198,7 +197,7 @@ def test_unauthorized_user_cannot_read_private_collection(session, client, setup response = client.get(f"/api/v1/collections/{collection['urn']}") assert response.status_code == 404 - assert f"collection with URN '{collection['urn']}' not found" in response.json()["detail"] + assert f"collection with URN '{collection['urn']}'" in response.json()["detail"] def test_anonymous_cannot_read_private_collection(session, client, setup_router_db, anonymous_app_overrides): @@ -208,7 +207,7 @@ def test_anonymous_cannot_read_private_collection(session, client, setup_router_ response = client.get(f"/api/v1/collections/{collection['urn']}") assert response.status_code == 404 - assert f"collection with URN '{collection['urn']}' not found" in response.json()["detail"] + assert f"collection with URN '{collection['urn']}'" in response.json()["detail"] def 
test_anonymous_can_read_public_collection(session, client, setup_router_db, anonymous_app_overrides): @@ -360,7 +359,7 @@ def test_viewer_cannot_add_experiment_to_collection( assert response.status_code == 403 response_data = response.json() - assert f"insufficient permissions for URN '{collection['urn']}'" in response_data["detail"] + assert f"insufficient permissions on collection with URN '{collection['urn']}'" in response_data["detail"] def test_unauthorized_user_cannot_add_experiment_to_collection( @@ -544,7 +543,7 @@ def test_viewer_cannot_add_score_set_to_collection( assert response.status_code == 403 response_data = response.json() - assert f"insufficient permissions for URN '{collection['urn']}'" in response_data["detail"] + assert f"insufficient permissions on collection with URN '{collection['urn']}'" in response_data["detail"] def test_unauthorized_user_cannot_add_score_set_to_collection( diff --git a/tests/routers/test_experiments.py b/tests/routers/test_experiments.py index 9767c125..1a04ed6a 100644 --- a/tests/routers/test_experiments.py +++ b/tests/routers/test_experiments.py @@ -28,6 +28,7 @@ TEST_EXPERIMENT_WITH_KEYWORD, TEST_EXPERIMENT_WITH_KEYWORD_HAS_DUPLICATE_OTHERS_RESPONSE, TEST_EXPERIMENT_WITH_KEYWORD_RESPONSE, + TEST_EXPERIMENT_WITH_UPDATE_KEYWORD_RESPONSE, TEST_MEDRXIV_IDENTIFIER, TEST_MINIMAL_EXPERIMENT, TEST_MINIMAL_EXPERIMENT_RESPONSE, @@ -292,6 +293,236 @@ def test_cannot_create_experiment_that_keywords_has_wrong_combination4(client, s ) +# Test the validator of Endogenous locus keywords +def test_create_experiment_that_keywords_has_endogenous(client, setup_router_db): + """ + Test src/mavedb/lib/validation/keywords.validate_keyword_keys function + if users choose endogenous locus library method in Variant Library Creation Method + """ + keywords = { + "keywords": [ + { + "keyword": { + "key": "Variant Library Creation Method", + "label": "Endogenous locus library method", + "special": False, + "description": "Description", + }, + }, + { + "keyword": { + "key": "Endogenous Locus Library Method System", + "label": "SaCas9", + "special": False, + "description": "Description", + }, + }, + { + "keyword": { + "key": "Endogenous Locus Library Method Mechanism", + "label": "Base editor", + "special": False, + "description": "Description", + }, + }, + ] + } + experiment = {**TEST_MINIMAL_EXPERIMENT, **keywords} + response = client.post("/api/v1/experiments/", json=experiment) + assert response.status_code == 200 + + +def test_cannot_create_experiment_that_keywords_has_endogenous_without_method_mechanism(client, setup_router_db): + """ + Test src/mavedb/lib/validation/keywords.validate_keyword_keys function + Choose endogenous locus library method in Variant Library Creation Method, + but miss the endogenous locus library method mechanism + """ + incomplete_keywords = { + "keywords": [ + { + "keyword": { + "key": "Variant Library Creation Method", + "label": "Endogenous locus library method", + "special": False, + "description": "Description", + }, + }, + { + "keyword": { + "key": "Endogenous Locus Library Method System", + "label": "SaCas9", + "special": False, + "description": "Description", + }, + }, + ] + } + experiment = {**TEST_MINIMAL_EXPERIMENT, **incomplete_keywords} + response = client.post("/api/v1/experiments/", json=experiment) + assert response.status_code == 422 + response_data = response.json() + assert ( + response_data["detail"] + == "If 'Variant Library Creation Method' is 'Endogenous locus library method', " + "both 'Endogenous Locus Library Method 
System' and 'Endogenous Locus Library Method Mechanism' " + "must be present." + ) + + +def test_cannot_create_experiment_that_keywords_has_endogenous_without_method_system(client, setup_router_db): + """ + Test src/mavedb/lib/validation/keywords.validate_keyword_keys function + Choose endogenous locus library method in Variant Library Creation Method, + but miss the endogenous locus library method system + """ + incomplete_keywords = { + "keywords": [ + { + "keyword": { + "key": "Variant Library Creation Method", + "label": "Endogenous locus library method", + "special": False, + "description": "Description", + }, + }, + { + "keyword": { + "key": "Endogenous Locus Library Method Mechanism", + "label": "Base editor", + "special": False, + "description": "Description", + }, + }, + ] + } + experiment = {**TEST_MINIMAL_EXPERIMENT, **incomplete_keywords} + response = client.post("/api/v1/experiments/", json=experiment) + assert response.status_code == 422 + response_data = response.json() + assert ( + response_data["detail"] + == "If 'Variant Library Creation Method' is 'Endogenous locus library method', " + "both 'Endogenous Locus Library Method System' and 'Endogenous Locus Library Method Mechanism' " + "must be present." + ) + + +# Test the validator of in vitro keywords +def test_create_experiment_that_keywords_has_in_vitro(client, setup_router_db): + """ + Test src/mavedb/lib/validation/keywords.validate_keyword_keys function + if users choose in vitro construct library method in Variant Library Creation Method + """ + keywords = { + "keywords": [ + { + "keyword": { + "key": "Variant Library Creation Method", + "label": "In vitro construct library method", + "special": False, + "description": "Description", + }, + }, + { + "keyword": { + "key": "In Vitro Construct Library Method System", + "label": "Oligo-directed mutagenic PCR", + "special": False, + "description": "Description", + }, + }, + { + "keyword": { + "key": "In Vitro Construct Library Method Mechanism", + "label": "Native locus replacement", + "special": False, + "description": "Description", + }, + }, + ] + } + experiment = {**TEST_MINIMAL_EXPERIMENT, **keywords} + response = client.post("/api/v1/experiments/", json=experiment) + assert response.status_code == 200 + + +def test_cannot_create_experiment_that_keywords_has_in_vitro_without_method_system(client, setup_router_db): + """ + Test src/mavedb/lib/validation/keywords.validate_keyword_keys function + Choose in vitro construct library method in Variant Library Creation Method, + but miss the in vitro construct library method system + """ + incomplete_keywords = { + "keywords": [ + { + "keyword": { + "key": "Variant Library Creation Method", + "label": "In vitro construct library method", + "special": False, + "description": "Description", + }, + }, + { + "keyword": { + "key": "In Vitro Construct Library Method Mechanism", + "label": "Native locus replacement", + "special": False, + "description": "Description", + }, + }, + ] + } + experiment = {**TEST_MINIMAL_EXPERIMENT, **incomplete_keywords} + response = client.post("/api/v1/experiments/", json=experiment) + assert response.status_code == 422 + response_data = response.json() + assert ( + response_data["detail"] + == "If 'Variant Library Creation Method' is 'In vitro construct library method', " + "both 'In Vitro Construct Library Method System' and 'In Vitro Construct Library Method Mechanism' " + "must be present." 
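+        # As with the endogenous-locus checks above, the keyword validator requires the System and Mechanism
+        # keywords as a pair whenever the corresponding Variant Library Creation Method label is selected.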
+ ) + + +def test_cannot_create_experiment_that_keywords_has_in_vitro_without_method_mechanism(client, setup_router_db): + """ + Test src/mavedb/lib/validation/keywords.validate_keyword_keys function + Choose in vitro construct library method in Variant Library Creation Method, + but miss the in vitro construct library method mechanism + """ + incomplete_keywords = { + "keywords": [ + { + "keyword": { + "key": "Variant Library Creation Method", + "label": "In vitro construct library method", + "special": False, + "description": "Description", + }, + }, + { + "keyword": { + "key": "In Vitro Construct Library Method System", + "label": "Oligo-directed mutagenic PCR", + "special": False, + "description": "Description", + }, + }, + ] + } + experiment = {**TEST_MINIMAL_EXPERIMENT, **incomplete_keywords} + response = client.post("/api/v1/experiments/", json=experiment) + assert response.status_code == 422 + response_data = response.json() + assert ( + response_data["detail"] + == "If 'Variant Library Creation Method' is 'In vitro construct library method', " + "both 'In Vitro Construct Library Method System' and 'In Vitro Construct Library Method Mechanism' " + "must be present." + ) + + def test_create_experiment_that_keyword_gene_ontology_has_valid_code(client, setup_router_db): valid_keyword = { "keywords": [ @@ -422,7 +653,7 @@ def test_cannot_create_experiment_that_keywords_have_duplicate_labels(client, se "keywords": [ { "keyword": { - "key": "Delivery method", + "key": "Delivery Method", "label": "In vitro construct library method", "special": False, "description": "Description", @@ -462,7 +693,7 @@ def test_create_experiment_that_keywords_have_duplicate_others(client, setup_rou "description": "Description", }, { - "keyword": {"key": "Delivery method", "label": "Other", "special": False, "description": "Description"}, + "keyword": {"key": "Delivery Method", "label": "Other", "special": False, "description": "Description"}, "description": "Description", }, ] @@ -481,6 +712,54 @@ def test_create_experiment_that_keywords_have_duplicate_others(client, setup_rou assert (key, expected_response[key]) == (key, response_data[key]) +def test_update_experiment_keywords(session, client, setup_router_db): + response = client.post("/api/v1/experiments/", json=TEST_EXPERIMENT_WITH_KEYWORD) + assert response.status_code == 200 + experiment = response.json() + experiment_post_payload = experiment.copy() + experiment_post_payload.update({"keywords": [ + { + "keyword": { + "key": "Phenotypic Assay Profiling Strategy", + "label": "Shotgun sequencing", + "special": False, + "description": "Description" + }, + "description": "Details of phenotypic assay profiling strategy", + }, + + ]}) + updated_response = client.put(f"/api/v1/experiments/{experiment['urn']}", json=experiment_post_payload) + assert updated_response.status_code == 200 + updated_experiment = updated_response.json() + updated_expected_response = deepcopy(TEST_EXPERIMENT_WITH_UPDATE_KEYWORD_RESPONSE) + updated_expected_response.update({"urn": updated_experiment["urn"], "experimentSetUrn": updated_experiment["experimentSetUrn"]}) + assert sorted(updated_expected_response.keys()) == sorted(updated_experiment.keys()) + for key in updated_experiment: + assert (key, updated_expected_response[key]) == (key, updated_experiment[key]) + for kw in updated_experiment["keywords"]: + assert "Delivery Method" not in kw["keyword"]["key"] + + +def test_update_experiment_keywords_case_insensitive(session, client, setup_router_db): + experiment = 
create_experiment(client) + experiment_post_payload = experiment.copy() + # Test database has Delivery Method. The updating keyword's key is delivery method. + experiment_post_payload.update({"keywords": [ + { + "keyword": {"key": "delivery method", "label": "Other", "special": False, "description": "Description"}, + "description": "Details of delivery method", + }, + ]}) + response = client.put(f"/api/v1/experiments/{experiment['urn']}", json=experiment_post_payload) + response_data = response.json() + expected_response = deepcopy(TEST_EXPERIMENT_WITH_KEYWORD_RESPONSE) + expected_response.update({"urn": response_data["urn"], "experimentSetUrn": response_data["experimentSetUrn"]}) + assert sorted(expected_response.keys()) == sorted(response_data.keys()) + for key in expected_response: + assert (key, expected_response[key]) == (key, response_data[key]) + + def test_can_delete_experiment(client, setup_router_db): experiment = create_experiment(client) response = client.delete(f"api/v1/experiments/{experiment['urn']}") @@ -621,7 +900,10 @@ def test_cannot_update_other_users_public_experiment_set(session, data_provider, response = client.post("/api/v1/experiments/", json=experiment_post_payload) assert response.status_code == 403 response_data = response.json() - assert f"insufficient permissions for URN '{published_experiment_set_urn}'" in response_data["detail"] + assert ( + f"insufficient permissions on experiment set with URN '{published_experiment_set_urn}'" + in response_data["detail"] + ) def test_anonymous_cannot_update_others_user_public_experiment_set( @@ -1651,10 +1933,12 @@ def test_cannot_delete_own_published_experiment(session, data_provider, client, assert del_response.status_code == 403 del_response_data = del_response.json() - assert f"insufficient permissions for URN '{experiment_urn}'" in del_response_data["detail"] + assert f"insufficient permissions on experiment with URN '{experiment_urn}'" in del_response_data["detail"] -def test_contributor_can_delete_other_users_private_experiment(session, client, setup_router_db, admin_app_overrides): +def test_contributor_cannot_delete_other_users_private_experiment( + session, client, setup_router_db, admin_app_overrides +): experiment = create_experiment(client) change_ownership(session, experiment["urn"], ExperimentDbModel) add_contributor( @@ -1667,7 +1951,8 @@ def test_contributor_can_delete_other_users_private_experiment(session, client, ) response = client.delete(f"/api/v1/experiments/{experiment['urn']}") - assert response.status_code == 200 + assert response.status_code == 403 + assert f"insufficient permissions on experiment with URN '{experiment['urn']}'" in response.json()["detail"] def test_admin_can_delete_other_users_private_experiment(session, client, setup_router_db, admin_app_overrides): @@ -1833,4 +2118,4 @@ def test_cannot_add_experiment_to_others_public_experiment_set( response = client.post("/api/v1/experiments/", json=test_experiment) assert response.status_code == 403 response_data = response.json() - assert f"insufficient permissions for URN '{experiment_set_urn}'" in response_data["detail"] + assert f"insufficient permissions on experiment set with URN '{experiment_set_urn}'" in response_data["detail"] diff --git a/tests/routers/test_hgvs.py b/tests/routers/test_hgvs.py index b931d859..6011953f 100644 --- a/tests/routers/test_hgvs.py +++ b/tests/routers/test_hgvs.py @@ -10,7 +10,9 @@ fastapi = pytest.importorskip("fastapi") hgvs = pytest.importorskip("hgvs") -from tests.helpers.constants import 
TEST_NT_CDOT_TRANSCRIPT, VALID_NT_ACCESSION, VALID_GENE +from hgvs.dataproviders.uta import UTABase + +from tests.helpers.constants import TEST_NT_CDOT_TRANSCRIPT, VALID_GENE, VALID_NT_ACCESSION VALID_MAJOR_ASSEMBLY = "GRCh38" VALID_MINOR_ASSEMBLY = "GRCh38.p3" @@ -85,9 +87,12 @@ def test_hgvs_accessions_invalid(client, setup_router_db): def test_hgvs_genes(client, setup_router_db): - response = client.get("/api/v1/hgvs/genes") - assert response.status_code == 200 - assert VALID_GENE in response.json() + with patch.object(UTABase, "_fetchall") as mock_fetchall: + mock_fetchall.return_value = (("BRCA1",), ("TP53",), (VALID_GENE,)) + + response = client.get("/api/v1/hgvs/genes") + assert response.status_code == 200 + assert VALID_GENE in response.json() def test_hgvs_gene_info_valid(client, setup_router_db): diff --git a/tests/routers/test_mapped_variants.py b/tests/routers/test_mapped_variants.py index 81bd62e1..b071dcfd 100644 --- a/tests/routers/test_mapped_variants.py +++ b/tests/routers/test_mapped_variants.py @@ -21,7 +21,11 @@ from mavedb.models.score_set import ScoreSet as ScoreSetDbModel from mavedb.models.variant import Variant from mavedb.view_models.mapped_variant import SavedMappedVariant -from tests.helpers.constants import TEST_BIORXIV_IDENTIFIER, TEST_BRNICH_SCORE_CALIBRATION, TEST_PUBMED_IDENTIFIER +from tests.helpers.constants import ( + TEST_BIORXIV_IDENTIFIER, + TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, + TEST_PUBMED_IDENTIFIER, +) from tests.helpers.util.common import deepcamelize from tests.helpers.util.experiment import create_experiment from tests.helpers.util.score_calibration import create_publish_and_promote_score_calibration @@ -209,7 +213,9 @@ def test_show_mapped_variant_functional_impact_statement( experiment["urn"], data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) response = client.get(f"/api/v1/mapped-variants/{quote_plus(score_set['urn'] + '#1')}/va/functional-impact") response_data = response.json() @@ -288,7 +294,9 @@ def test_cannot_show_mapped_variant_functional_impact_statement_when_no_mapping_ experiment["urn"], data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) item = session.scalar(select(MappedVariant).join(Variant).where(Variant.urn == f'{score_set["urn"]}#1')) assert item is not None @@ -352,7 +360,9 @@ def test_show_mapped_variant_clinical_evidence_line( experiment["urn"], data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) response = client.get(f"/api/v1/mapped-variants/{quote_plus(score_set['urn'] + '#2')}/va/clinical-evidence") response_data = response.json() @@ -431,7 +441,9 @@ def test_cannot_show_mapped_variant_clinical_evidence_line_when_no_mapping_data_ experiment["urn"], data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + 
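+    # The shared Brnich fixture now comes in two flavors; these mapped-variant tests use the range-based constant,
+    # while the class-based constant (imported in test_score_calibrations.py below) presumably covers class-style
+    # calibrations.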
create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) item = session.scalar(select(MappedVariant).join(Variant).where(Variant.urn == f'{score_set["urn"]}#1')) assert item is not None diff --git a/tests/routers/test_permissions.py b/tests/routers/test_permissions.py index 74405a47..b60a924e 100644 --- a/tests/routers/test_permissions.py +++ b/tests/routers/test_permissions.py @@ -131,7 +131,7 @@ def test_contributor_gets_true_permission_from_others_experiment_update_check(se assert response.json() -def test_contributor_gets_true_permission_from_others_experiment_delete_check(session, client, setup_router_db): +def test_contributor_gets_false_permission_from_others_experiment_delete_check(session, client, setup_router_db): experiment = create_experiment(client) change_ownership(session, experiment["urn"], ExperimentDbModel) add_contributor( @@ -145,7 +145,7 @@ def test_contributor_gets_true_permission_from_others_experiment_delete_check(se response = client.get(f"/api/v1/permissions/user-is-permitted/experiment/{experiment['urn']}/delete") assert response.status_code == 200 - assert response.json() + assert not response.json() def test_contributor_gets_true_permission_from_others_private_experiment_add_score_set_check( @@ -282,7 +282,7 @@ def test_contributor_gets_true_permission_from_others_score_set_update_check(ses assert response.json() -def test_contributor_gets_true_permission_from_others_score_set_delete_check(session, client, setup_router_db): +def test_contributor_gets_false_permission_from_others_score_set_delete_check(session, client, setup_router_db): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) change_ownership(session, score_set["urn"], ScoreSetDbModel) @@ -297,10 +297,10 @@ def test_contributor_gets_true_permission_from_others_score_set_delete_check(ses response = client.get(f"/api/v1/permissions/user-is-permitted/score-set/{score_set['urn']}/delete") assert response.status_code == 200 - assert response.json() + assert not response.json() -def test_contributor_gets_true_permission_from_others_score_set_publish_check(session, client, setup_router_db): +def test_contributor_gets_false_permission_from_others_score_set_publish_check(session, client, setup_router_db): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) change_ownership(session, score_set["urn"], ScoreSetDbModel) @@ -315,7 +315,7 @@ def test_contributor_gets_true_permission_from_others_score_set_publish_check(se response = client.get(f"/api/v1/permissions/user-is-permitted/score-set/{score_set['urn']}/publish") assert response.status_code == 200 - assert response.json() + assert not response.json() def test_get_false_permission_from_others_score_set_delete_check(session, client, setup_router_db): @@ -423,7 +423,7 @@ def test_contributor_gets_true_permission_from_others_investigator_provided_scor assert response.json() -def test_contributor_gets_true_permission_from_others_investigator_provided_score_calibration_delete_check( +def test_contributor_gets_false_permission_from_others_investigator_provided_score_calibration_delete_check( session, client, setup_router_db, extra_user_app_overrides ): experiment = create_experiment(client) @@ -445,10 +445,12 @@ def test_contributor_gets_true_permission_from_others_investigator_provided_scor ) assert response.status_code == 200 - assert response.json() + assert not response.json() -def 
test_get_false_permission_from_others_score_calibration_update_check(session, client, setup_router_db): +def test_get_true_permission_as_score_set_owner_on_others_investigator_provided_score_calibration_update_check( + session, client, setup_router_db +): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) score_calibration = create_test_score_calibration_in_score_set_via_client( @@ -458,6 +460,23 @@ def test_get_false_permission_from_others_score_calibration_update_check(session response = client.get(f"/api/v1/permissions/user-is-permitted/score-calibration/{score_calibration['urn']}/update") + assert response.status_code == 200 + assert response.json() + + +def test_get_false_permission_as_score_set_owner_on_others_community_score_calibration_update_check( + session, client, setup_router_db, admin_app_overrides +): + experiment = create_experiment(client) + score_set = create_seq_score_set(client, experiment["urn"]) + + with DependencyOverrider(admin_app_overrides): + score_calibration = create_test_score_calibration_in_score_set_via_client( + client, score_set["urn"], deepcamelize(TEST_MINIMAL_CALIBRATION) + ) + + response = client.get(f"/api/v1/permissions/user-is-permitted/score-calibration/{score_calibration['urn']}/update") + assert response.status_code == 200 assert not response.json() diff --git a/tests/routers/test_score_calibrations.py b/tests/routers/test_score_calibrations.py index 307394ec..8cdbeefe 100644 --- a/tests/routers/test_score_calibrations.py +++ b/tests/routers/test_score_calibrations.py @@ -2,10 +2,13 @@ import pytest +from mavedb.lib.validation.exceptions import ValidationError + arq = pytest.importorskip("arq") cdot = pytest.importorskip("cdot") fastapi = pytest.importorskip("fastapi") +import json from unittest.mock import patch from arq import ArqRedis @@ -13,6 +16,15 @@ from mavedb.models.score_calibration import ScoreCalibration as CalibrationDbModel from mavedb.models.score_set import ScoreSet as ScoreSetDbModel +from tests.helpers.constants import ( + EXTRA_USER, + TEST_BIORXIV_IDENTIFIER, + TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED, + TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, + TEST_PATHOGENICITY_SCORE_CALIBRATION, + TEST_PUBMED_IDENTIFIER, + VALID_CALIBRATION_URN, +) from tests.helpers.dependency_overrider import DependencyOverrider from tests.helpers.util.common import deepcamelize from tests.helpers.util.contributor import add_contributor @@ -24,15 +36,6 @@ ) from tests.helpers.util.score_set import create_seq_score_set_with_mapped_variants, publish_score_set -from tests.helpers.constants import ( - EXTRA_USER, - TEST_BIORXIV_IDENTIFIER, - TEST_BRNICH_SCORE_CALIBRATION, - TEST_PATHOGENICITY_SCORE_CALIBRATION, - TEST_PUBMED_IDENTIFIER, - VALID_CALIBRATION_URN, -) - ########################################################### # GET /score-calibrations/{calibration_urn} ########################################################### @@ -68,7 +71,7 @@ def test_anonymous_user_cannot_get_score_calibration_when_private( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(anonymous_app_overrides): @@ -101,7 +104,7 @@ def test_other_user_cannot_get_score_calibration_when_private( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, 
score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(extra_user_app_overrides): @@ -134,7 +137,7 @@ def test_creating_user_can_get_score_calibration_when_private( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.get(f"/api/v1/score-calibrations/{calibration['urn']}") @@ -167,7 +170,7 @@ def test_contributing_user_can_get_score_calibration_when_private_and_investigat data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) add_contributor( @@ -219,7 +222,7 @@ def test_contributing_user_cannot_get_score_calibration_when_private_and_not_inv with DependencyOverrider(admin_app_overrides): calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) add_contributor( @@ -261,7 +264,7 @@ def test_admin_user_can_get_score_calibration_when_private( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(admin_app_overrides): @@ -295,7 +298,7 @@ def test_anonymous_user_can_get_score_calibration_when_public( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) calibration = publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -330,7 +333,7 @@ def test_other_user_can_get_score_calibration_when_public( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) calibration = publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -365,7 +368,7 @@ def test_creating_user_can_get_score_calibration_when_public( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) calibration = publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -399,7 +402,7 @@ def test_contributing_user_can_get_score_calibration_when_public( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) calibration = publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -443,7 +446,7 @@ def test_admin_user_can_get_score_calibration_when_public( data_files / 
"scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) calibration = publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -512,7 +515,7 @@ def test_anonymous_user_cannot_get_score_calibrations_for_score_set_when_private data_files / "scores.csv", ) create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(anonymous_app_overrides): @@ -545,7 +548,7 @@ def test_other_user_cannot_get_score_calibrations_for_score_set_when_private( data_files / "scores.csv", ) create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(extra_user_app_overrides): @@ -578,7 +581,7 @@ def test_anonymous_user_cannot_get_score_calibrations_for_score_set_when_publish data_files / "scores.csv", ) create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with patch.object(ArqRedis, "enqueue_job", return_value=None): @@ -614,7 +617,7 @@ def test_other_user_cannot_get_score_calibrations_for_score_set_when_published_b data_files / "scores.csv", ) create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with patch.object(ArqRedis, "enqueue_job", return_value=None): @@ -650,7 +653,7 @@ def test_creating_user_can_get_score_calibrations_for_score_set_when_private( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.get(f"/api/v1/score-calibrations/score-set/{score_set['urn']}") @@ -693,11 +696,11 @@ def test_contributing_user_can_get_investigator_provided_score_calibrations_for_ with DependencyOverrider(admin_app_overrides): create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) investigator_calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) add_contributor( @@ -741,7 +744,7 @@ def test_admin_user_can_get_score_calibrations_for_score_set_when_private( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(admin_app_overrides): @@ -776,12 +779,12 @@ def test_anonymous_user_can_get_score_calibrations_for_score_set_when_public( data_files / "scores.csv", ) calibration = 
create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) # add another calibration that will remain private. The anonymous user should not see this one calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -821,12 +824,12 @@ def test_other_user_can_get_score_calibrations_for_score_set_when_public( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) # add another calibration that will remain private. The other user should not see this one create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -866,12 +869,12 @@ def test_anonymous_user_cannot_get_score_calibrations_for_score_set_when_calibra data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) # add another calibration that will remain private. The anonymous user should not see this one calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -906,12 +909,12 @@ def test_other_user_cannot_get_score_calibrations_for_score_set_when_calibration data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) # add another calibration that will remain private. The other user should not see this one create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -946,13 +949,13 @@ def test_creating_user_can_get_score_calibrations_for_score_set_when_public( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) # add another calibration that is private. 
The creating user should see this one too create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.get(f"/api/v1/score-calibrations/score-set/{score_set['urn']}") @@ -986,13 +989,13 @@ def test_contributing_user_can_get_score_calibrations_for_score_set_when_public( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) # add another calibration that is private. The contributing user should see this one too create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) add_contributor( @@ -1036,13 +1039,13 @@ def test_admin_user_can_get_score_calibrations_for_score_set_when_public( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) # add another calibration that is private. The admin user should see this one too create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(admin_app_overrides): @@ -1111,7 +1114,7 @@ def test_cannot_get_primary_score_calibration_for_score_set_when_none_exist( data_files / "scores.csv", ) create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.get(f"/api/v1/score-calibrations/score-set/{score_set['urn']}/primary") @@ -1146,7 +1149,7 @@ def test_get_primary_score_calibration_for_score_set_when_exists( data_files / "scores.csv", ) calibration = create_publish_and_promote_score_calibration( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.get(f"/api/v1/score-calibrations/score-set/{score_set['urn']}/primary") @@ -1181,9 +1184,11 @@ def test_get_primary_score_calibration_for_score_set_when_multiple_exist( data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) calibration2 = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration2["urn"]) @@ -1209,7 +1214,7 @@ def test_get_primary_score_calibration_for_score_set_when_multiple_exist( def 
test_cannot_create_score_calibration_when_missing_score_set_urn(client, setup_router_db): response = client.post( "/api/v1/score-calibrations", - json={**deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)}, + json={**deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED)}, ) assert response.status_code == 422 @@ -1222,7 +1227,7 @@ def test_cannot_create_score_calibration_when_score_set_does_not_exist(client, s "/api/v1/score-calibrations", json={ "scoreSetUrn": "urn:ngs:score-set:nonexistent", - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), }, ) @@ -1231,6 +1236,96 @@ def test_cannot_create_score_calibration_when_score_set_does_not_exist(client, s assert "score set with URN 'urn:ngs:score-set:nonexistent' not found" in error["detail"] +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +def test_cannot_create_score_calibration_when_csv_file_fails_decoding( + client, setup_router_db, session, data_provider, data_files, mock_publication_fetch +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + + calibration_csv_path = data_files / "calibration_classes_by_urn.csv" + with ( + open(calibration_csv_path, "rb") as class_file, + patch( + "mavedb.routers.score_calibrations.csv_data_to_df", + side_effect=UnicodeDecodeError("utf-8", b"", 0, 1, "invalid start byte"), + ), + ): + response = client.post( + "/api/v1/score-calibrations", + files={"classes_file": (calibration_csv_path.name, class_file, "text/csv")}, + data={ + "calibration_json": json.dumps( + {"scoreSetUrn": score_set["urn"], **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED)} + ), + }, + ) + + assert response.status_code == 400 + error = response.json() + assert "Error decoding file:" in str(error["detail"]) + + +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +def test_cannot_create_score_calibration_when_validation_error_is_raised_from_score_calibration_file_standardization( + client, setup_router_db, session, data_provider, data_files, mock_publication_fetch +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + + calibration_csv_path = data_files / "calibration_classes_by_urn.csv" + with ( + open(calibration_csv_path, "rb") as class_file, + patch( + "mavedb.routers.score_calibrations.validate_and_standardize_calibration_classes_dataframe", + side_effect=ValidationError("Test validation error"), + ), + ): + response = client.post( + "/api/v1/score-calibrations", + files={"classes_file": (calibration_csv_path.name, class_file, "text/csv")}, + data={ + "calibration_json": json.dumps( + {"scoreSetUrn": score_set["urn"], **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED)} + ), + }, + ) + + assert response.status_code == 422 + error = response.json() + assert "Test validation error" in str(error["detail"][0]["msg"]) + + @pytest.mark.parametrize( "mock_publication_fetch", [ @@ -1258,7 +1353,7 @@ def 
test_cannot_create_score_calibration_when_score_set_not_owned_by_user( "/api/v1/score-calibrations", json={ "scoreSetUrn": score_set["urn"], - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), }, ) @@ -1297,13 +1392,91 @@ def test_cannot_create_score_calibration_in_public_score_set_when_score_set_not_ "/api/v1/score-calibrations", json={ "scoreSetUrn": score_set["urn"], - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), }, ) assert response.status_code == 403 error = response.json() - assert f"insufficient permissions for URN '{score_set['urn']}'" in error["detail"] + assert f"insufficient permissions on score set with URN '{score_set['urn']}'" in error["detail"] + + +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +def test_cannot_create_class_based_score_calibration_without_classes_file( + client, setup_router_db, mock_publication_fetch, session, data_provider, data_files +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + + response = client.post( + "/api/v1/score-calibrations", + json={ + "scoreSetUrn": score_set["urn"], + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED), + }, + ) + + assert response.status_code == 422 + error = response.json() + assert "A classes_file must be provided when creating a class-based calibration" in str(error["detail"]) + + +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +@pytest.mark.parametrize( + "calibration_csv_path", + ["calibration_classes_by_urn.csv", "calibration_classes_by_hgvs_nt.csv", "calibration_classes_by_hgvs_prot.csv"], +) +def test_cannot_create_range_based_score_calibration_with_classes_file( + client, setup_router_db, mock_publication_fetch, session, data_provider, data_files, calibration_csv_path +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + + classification_csv_path = data_files / calibration_csv_path + with open(classification_csv_path, "rb") as class_file: + response = client.post( + "/api/v1/score-calibrations", + files={"classes_file": (classification_csv_path.name, class_file, "text/csv")}, + data={ + "calibration_json": json.dumps( + {"scoreSetUrn": score_set["urn"], **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED)} + ), + }, + ) + + assert response.status_code == 422 + error = response.json() + assert "A classes_file should not be provided when creating a range-based calibration" in str(error["detail"]) @pytest.mark.parametrize( @@ -1333,7 +1506,7 @@ def test_cannot_create_score_calibration_as_anonymous_user( "/api/v1/score-calibrations", json={ "scoreSetUrn": score_set["urn"], - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), }, ) @@ -1368,7 +1541,44 @@ def test_can_create_score_calibration_as_score_set_owner( "/api/v1/score-calibrations", json={ "scoreSetUrn": 
score_set["urn"], - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), + }, + ) + + assert response.status_code == 200 + calibration_response = response.json() + assert calibration_response["scoreSetUrn"] == score_set["urn"] + assert calibration_response["private"] is True + + +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +def test_can_create_score_calibration_as_score_set_owner_form( + client, setup_router_db, mock_publication_fetch, session, data_provider, data_files +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + + response = client.post( + "/api/v1/score-calibrations", + data={ + "calibration_json": json.dumps( + {"scoreSetUrn": score_set["urn"], **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED)} + ), }, ) @@ -1414,7 +1624,7 @@ def test_can_create_score_calibration_as_score_set_contributor( "/api/v1/score-calibrations", json={ "scoreSetUrn": score_set["urn"], - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), }, ) @@ -1451,7 +1661,53 @@ def test_can_create_score_calibration_as_admin_user( "/api/v1/score-calibrations", json={ "scoreSetUrn": score_set["urn"], - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), + }, + ) + + assert response.status_code == 200 + calibration_response = response.json() + assert calibration_response["scoreSetUrn"] == score_set["urn"] + assert calibration_response["private"] is True + + +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +@pytest.mark.parametrize( + "calibration_csv_path", + ["calibration_classes_by_urn.csv", "calibration_classes_by_hgvs_nt.csv", "calibration_classes_by_hgvs_prot.csv"], +) +def test_can_create_class_based_score_calibration_form( + client, setup_router_db, mock_publication_fetch, session, data_provider, data_files, calibration_csv_path +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + with patch.object(ArqRedis, "enqueue_job", return_value=None): + score_set = publish_score_set(client, score_set["urn"]) + + classification_csv_path = data_files / calibration_csv_path + with open(classification_csv_path, "rb") as class_file: + response = client.post( + "/api/v1/score-calibrations", + files={"classes_file": (classification_csv_path.name, class_file, "text/csv")}, + data={ + "calibration_json": json.dumps( + {"scoreSetUrn": score_set["urn"], **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED)} + ), }, ) @@ -1459,6 +1715,9 @@ def test_can_create_score_calibration_as_admin_user( calibration_response = response.json() assert calibration_response["scoreSetUrn"] == score_set["urn"] assert calibration_response["private"] is True + assert all( + len(classification["variants"]) == 1 for classification in calibration_response["functionalClassifications"] + ) 
########################################################### @@ -1488,7 +1747,7 @@ def test_cannot_update_score_calibration_when_score_set_not_exists( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.put( @@ -1539,6 +1798,195 @@ def test_cannot_update_score_calibration_when_calibration_not_exists( assert "The requested score calibration does not exist" in error["detail"] +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +def test_cannot_update_score_calibration_when_csv_file_fails_decoding( + client, setup_router_db, session, data_provider, data_files, mock_publication_fetch +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + calibration = create_test_score_calibration_in_score_set_via_client( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) + + calibration_csv_path = data_files / "calibration_classes_by_urn.csv" + with ( + open(calibration_csv_path, "rb") as class_file, + patch( + "mavedb.routers.score_calibrations.csv_data_to_df", + side_effect=UnicodeDecodeError("utf-8", b"", 0, 1, "invalid start byte"), + ), + ): + response = client.put( + f"/api/v1/score-calibrations/{calibration['urn']}", + files={"classes_file": (calibration_csv_path.name, class_file, "text/csv")}, + data={ + "calibration_json": json.dumps( + { + "scoreSetUrn": score_set["urn"], + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED), + } + ), + }, + ) + + assert response.status_code == 400 + error = response.json() + assert "Error decoding file:" in str(error["detail"]) + + +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +def test_cannot_update_score_calibration_when_validation_error_is_raised_from_score_calibration_file_standardization( + client, setup_router_db, session, data_provider, data_files, mock_publication_fetch +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + calibration = create_test_score_calibration_in_score_set_via_client( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) + + calibration_csv_path = data_files / "calibration_classes_by_urn.csv" + with ( + open(calibration_csv_path, "rb") as class_file, + patch( + "mavedb.routers.score_calibrations.validate_and_standardize_calibration_classes_dataframe", + side_effect=ValidationError("Test validation error"), + ), + ): + response = client.put( + f"/api/v1/score-calibrations/{calibration['urn']}", + files={"classes_file": (calibration_csv_path.name, class_file, "text/csv")}, + data={ + "calibration_json": json.dumps( + { + "scoreSetUrn": score_set["urn"], + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED), + } + ), + }, + ) + + assert response.status_code == 422 + error = 
response.json() + assert "Test validation error" in str(error["detail"][0]["msg"]) + + +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +def test_cannot_update_class_based_score_calibration_without_class_file( + client, setup_router_db, mock_publication_fetch, session, data_provider, data_files +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + calibration = create_test_score_calibration_in_score_set_via_client( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) + + response = client.put( + f"/api/v1/score-calibrations/{calibration['urn']}", + json={ + "scoreSetUrn": score_set["urn"], + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED), + }, + ) + + assert response.status_code == 422 + error = response.json() + assert "A classes_file must be provided when modifying a class-based calibration" in str(error["detail"]) + + +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +@pytest.mark.parametrize( + "calibration_csv_path", + ["calibration_classes_by_urn.csv", "calibration_classes_by_hgvs_nt.csv", "calibration_classes_by_hgvs_prot.csv"], +) +def test_cannot_update_range_based_score_calibration_with_class_file( + client, setup_router_db, mock_publication_fetch, session, data_provider, data_files, calibration_csv_path +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + calibration = create_test_score_calibration_in_score_set_via_client( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) + + classification_csv_path = data_files / calibration_csv_path + with open(classification_csv_path, "rb") as class_file: + response = client.put( + f"/api/v1/score-calibrations/{calibration['urn']}", + files={"classes_file": (classification_csv_path.name, class_file, "text/csv")}, + data={ + "calibration_json": json.dumps( + { + "scoreSetUrn": score_set["urn"], + **deepcamelize(TEST_PATHOGENICITY_SCORE_CALIBRATION), + } + ), + }, + ) + + assert response.status_code == 422 + error = response.json() + assert "A classes_file should not be provided when modifying a range-based calibration" in str(error["detail"]) + + @pytest.mark.parametrize( "mock_publication_fetch", [ @@ -1561,7 +2009,7 @@ def test_cannot_update_score_calibration_as_anonymous_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(anonymous_app_overrides): @@ -1600,7 +2048,7 @@ def test_cannot_update_score_calibration_when_score_set_not_owned_by_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], 
deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(extra_user_app_overrides): @@ -1639,7 +2087,7 @@ def test_cannot_update_score_calibration_in_published_score_set_when_score_set_n data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with patch.object(ArqRedis, "enqueue_job", return_value=None): @@ -1656,7 +2104,7 @@ def test_cannot_update_score_calibration_in_published_score_set_when_score_set_n assert response.status_code == 403 error = response.json() - assert f"insufficient permissions for URN '{score_set['urn']}'" in error["detail"] + assert f"insufficient permissions on score set with URN '{score_set['urn']}'" in error["detail"] @pytest.mark.parametrize( @@ -1681,7 +2129,7 @@ def test_can_update_score_calibration_as_score_set_owner( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.put( @@ -1699,6 +2147,47 @@ def test_can_update_score_calibration_as_score_set_owner( assert calibration_response["private"] is True +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +def test_can_update_score_calibration_as_score_set_owner_form( + client, setup_router_db, mock_publication_fetch, session, data_provider, data_files +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + calibration = create_test_score_calibration_in_score_set_via_client( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) + + response = client.put( + f"/api/v1/score-calibrations/{calibration['urn']}", + data={ + "calibration_json": json.dumps( + {"scoreSetUrn": score_set["urn"], **deepcamelize(TEST_PATHOGENICITY_SCORE_CALIBRATION)} + ), + }, + ) + + assert response.status_code == 200 + calibration_response = response.json() + assert calibration_response["urn"] == calibration["urn"] + assert calibration_response["scoreSetUrn"] == score_set["urn"] + assert calibration_response["private"] is True + + @pytest.mark.parametrize( "mock_publication_fetch", [ @@ -1721,7 +2210,7 @@ def test_cannot_update_published_score_calibration_as_score_set_owner( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -1736,7 +2225,7 @@ def test_cannot_update_published_score_calibration_as_score_set_owner( assert response.status_code == 403 error = response.json() - assert f"insufficient permissions for URN '{calibration['urn']}'" in error["detail"] + assert f"insufficient permissions on score calibration with URN '{calibration['urn']}'" in error["detail"] @pytest.mark.parametrize( @@ -1761,7 +2250,7 @@ def 
test_can_update_investigator_provided_score_calibration_as_score_set_contrib data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) add_contributor( @@ -1820,7 +2309,7 @@ def test_cannot_update_non_investigator_score_calibration_as_score_set_contribut with DependencyOverrider(admin_app_overrides): calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) add_contributor( @@ -1868,7 +2357,7 @@ def test_can_update_score_calibration_as_admin_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(admin_app_overrides): @@ -1909,7 +2398,7 @@ def test_can_update_published_score_calibration_as_admin_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -1959,7 +2448,7 @@ def test_anonymous_user_may_not_move_calibration_to_another_score_set( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set1["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set1["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(anonymous_app_overrides): @@ -1967,7 +2456,7 @@ def test_anonymous_user_may_not_move_calibration_to_another_score_set( f"/api/v1/score-calibrations/{calibration['urn']}", json={ "scoreSetUrn": score_set2["urn"], - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), }, ) @@ -2005,7 +2494,7 @@ def test_user_may_not_move_investigator_calibration_when_lacking_permissions_on_ data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set1["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set1["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) # Give user permissions on the first score set only @@ -2023,7 +2512,7 @@ def test_user_may_not_move_investigator_calibration_when_lacking_permissions_on_ f"/api/v1/score-calibrations/{calibration['urn']}", json={ "scoreSetUrn": score_set2["urn"], - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), }, ) @@ -2061,7 +2550,7 @@ def test_user_may_move_investigator_calibration_when_has_permissions_on_destinat data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set1["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set1["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) # Give user permissions on both score sets @@ -2091,7 +2580,7 @@ def test_user_may_move_investigator_calibration_when_has_permissions_on_destinat f"/api/v1/score-calibrations/{calibration['urn']}", json={ "scoreSetUrn": 
score_set2["urn"], - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), }, ) @@ -2130,7 +2619,7 @@ def test_admin_user_may_move_calibration_to_another_score_set( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set1["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set1["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(admin_app_overrides): @@ -2138,7 +2627,7 @@ def test_admin_user_may_move_calibration_to_another_score_set( f"/api/v1/score-calibrations/{calibration['urn']}", json={ "scoreSetUrn": score_set2["urn"], - **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + **deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), }, ) @@ -2148,6 +2637,58 @@ def test_admin_user_may_move_calibration_to_another_score_set( assert calibration_response["scoreSetUrn"] == score_set2["urn"] +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + [ + {"dbName": "PubMed", "identifier": TEST_PUBMED_IDENTIFIER}, + {"dbName": "bioRxiv", "identifier": TEST_BIORXIV_IDENTIFIER}, + ] + ], + indirect=["mock_publication_fetch"], +) +@pytest.mark.parametrize( + "calibration_csv_path", + ["calibration_classes_by_urn.csv", "calibration_classes_by_hgvs_nt.csv", "calibration_classes_by_hgvs_prot.csv"], +) +def test_can_modify_score_calibration_to_class_based( + client, setup_router_db, mock_publication_fetch, session, data_provider, data_files, calibration_csv_path +): + experiment = create_experiment(client) + score_set = create_seq_score_set_with_mapped_variants( + client, + session, + data_provider, + experiment["urn"], + data_files / "scores.csv", + ) + with patch.object(ArqRedis, "enqueue_job", return_value=None): + score_set = publish_score_set(client, score_set["urn"]) + + calibration = create_test_score_calibration_in_score_set_via_client( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) + + classification_csv_path = data_files / calibration_csv_path + updated_calibration_data = deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED) + + with open(classification_csv_path, "rb") as class_file: + response = client.put( + f"/api/v1/score-calibrations/{calibration['urn']}", + files={"classes_file": (classification_csv_path.name, class_file, "text/csv")}, + data={ + "calibration_json": json.dumps({"scoreSetUrn": score_set["urn"], **updated_calibration_data}), + }, + ) + + assert response.status_code == 200 + calibration_response = response.json() + assert calibration_response["urn"] == calibration["urn"] + assert all( + len(classification["variants"]) == 1 for classification in calibration_response["functionalClassifications"] + ) + + ########################################################### # DELETE /score-calibrations/{calibration_urn} ########################################################### @@ -2183,7 +2724,7 @@ def test_cannot_delete_score_calibration_as_anonymous_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(anonymous_app_overrides): @@ -2216,7 +2757,7 @@ def test_cannot_delete_score_calibration_when_score_set_not_owned_by_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, 
score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(extra_user_app_overrides): @@ -2249,7 +2790,7 @@ def test_can_delete_score_calibration_as_score_set_owner( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.delete(f"/api/v1/score-calibrations/{calibration['urn']}") @@ -2283,7 +2824,7 @@ def test_cannot_delete_published_score_calibration_as_owner( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -2291,7 +2832,7 @@ def test_cannot_delete_published_score_calibration_as_owner( assert response.status_code == 403 error = response.json() - assert f"insufficient permissions for URN '{calibration['urn']}'" in error["detail"] + assert f"insufficient permissions on score calibration with URN '{calibration['urn']}'" in error["detail"] @pytest.mark.parametrize( @@ -2304,7 +2845,7 @@ def test_cannot_delete_published_score_calibration_as_owner( ], indirect=["mock_publication_fetch"], ) -def test_can_delete_investigator_score_calibration_as_score_set_contributor( +def test_cannot_delete_investigator_score_calibration_as_score_set_contributor( client, setup_router_db, mock_publication_fetch, session, data_provider, data_files, extra_user_app_overrides ): experiment = create_experiment(client) @@ -2316,7 +2857,7 @@ def test_can_delete_investigator_score_calibration_as_score_set_contributor( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) add_contributor( @@ -2331,11 +2872,9 @@ def test_can_delete_investigator_score_calibration_as_score_set_contributor( with DependencyOverrider(extra_user_app_overrides): response = client.delete(f"/api/v1/score-calibrations/{calibration['urn']}") - assert response.status_code == 204 - - # verify it's deleted - get_response = client.get(f"/api/v1/score-calibrations/{calibration['urn']}") - assert get_response.status_code == 404 + error = response.json() + assert response.status_code == 403 + assert f"insufficient permissions on score calibration with URN '{calibration['urn']}'" in error["detail"] @pytest.mark.parametrize( @@ -2369,7 +2908,7 @@ def test_cannot_delete_non_investigator_calibration_as_score_set_contributor( with DependencyOverrider(admin_app_overrides): calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) add_contributor( @@ -2409,7 +2948,7 @@ def test_can_delete_score_calibration_as_admin_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with 
DependencyOverrider(admin_app_overrides): @@ -2444,7 +2983,7 @@ def test_can_delete_published_score_calibration_as_admin_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -2480,14 +3019,14 @@ def test_cannot_delete_primary_score_calibration( data_files / "scores.csv", ) calibration = create_publish_and_promote_score_calibration( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.delete(f"/api/v1/score-calibrations/{calibration['urn']}") assert response.status_code == 403 error = response.json() - assert f"insufficient permissions for URN '{calibration['urn']}'" in error["detail"] + assert f"insufficient permissions on score calibration with URN '{calibration['urn']}'" in error["detail"] ########################################################### @@ -2528,7 +3067,7 @@ def test_cannot_promote_score_calibration_as_anonymous_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -2562,7 +3101,7 @@ def test_cannot_promote_score_calibration_when_score_calibration_not_owned_by_us data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -2573,7 +3112,7 @@ def test_cannot_promote_score_calibration_when_score_calibration_not_owned_by_us assert response.status_code == 403 error = response.json() - assert f"insufficient permissions for URN '{calibration['urn']}'" in error["detail"] + assert f"insufficient permissions on score calibration with URN '{calibration['urn']}'" in error["detail"] @pytest.mark.parametrize( @@ -2598,7 +3137,7 @@ def test_can_promote_score_calibration_as_score_set_owner( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) response = client.post(f"/api/v1/score-calibrations/{calibration['urn']}/promote-to-primary") @@ -2632,7 +3171,7 @@ def test_can_promote_score_calibration_as_score_set_contributor( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -2677,7 +3216,7 @@ def test_can_promote_score_calibration_as_admin_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + 
client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -2713,7 +3252,7 @@ def test_can_promote_existing_primary_to_primary( data_files / "scores.csv", ) primary_calibration = create_publish_and_promote_score_calibration( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.post(f"/api/v1/score-calibrations/{primary_calibration['urn']}/promote-to-primary") @@ -2749,7 +3288,7 @@ def test_cannot_promote_research_use_only_to_primary( calibration = create_test_score_calibration_in_score_set_via_client( client, score_set["urn"], - deepcamelize({**TEST_BRNICH_SCORE_CALIBRATION, "researchUseOnly": True}), + deepcamelize({**TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, "researchUseOnly": True}), ) publish_test_score_calibration_via_client(client, calibration["urn"]) @@ -2784,7 +3323,7 @@ def test_cannot_promote_private_calibration_to_primary( calibration = create_test_score_calibration_in_score_set_via_client( client, score_set["urn"], - deepcamelize({**TEST_BRNICH_SCORE_CALIBRATION, "private": True}), + deepcamelize({**TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, "private": True}), ) response = client.post(f"/api/v1/score-calibrations/{calibration['urn']}/promote-to-primary") @@ -2815,7 +3354,9 @@ def test_cannot_promote_to_primary_if_primary_exists( experiment["urn"], data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) secondary_calibration = create_test_score_calibration_in_score_set_via_client( client, score_set["urn"], deepcamelize(TEST_PATHOGENICITY_SCORE_CALIBRATION) ) @@ -2850,7 +3391,7 @@ def test_can_promote_to_primary_if_primary_exists_when_demote_existing_is_true( data_files / "scores.csv", ) primary_calibration = create_publish_and_promote_score_calibration( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) secondary_calibration = create_test_score_calibration_in_score_set_via_client( client, score_set["urn"], deepcamelize(TEST_PATHOGENICITY_SCORE_CALIBRATION) @@ -2897,7 +3438,7 @@ def test_cannot_promote_to_primary_with_demote_existing_flag_if_user_does_not_ha ) with DependencyOverrider(admin_app_overrides): primary_calibration = create_publish_and_promote_score_calibration( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) secondary_calibration = create_test_score_calibration_in_score_set_via_client( client, score_set["urn"], deepcamelize(TEST_PATHOGENICITY_SCORE_CALIBRATION) @@ -2910,7 +3451,7 @@ def test_cannot_promote_to_primary_with_demote_existing_flag_if_user_does_not_ha assert response.status_code == 403 promotion_response = response.json() - assert "insufficient permissions for URN" in promotion_response["detail"] + assert "insufficient permissions on score calibration with URN" in promotion_response["detail"] # verify the previous primary is still primary @@ -2957,7 +3498,7 @@ def test_cannot_demote_score_calibration_as_anonymous_user( data_files / "scores.csv", ) calibration = 
create_publish_and_promote_score_calibration( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(anonymous_app_overrides): @@ -2992,7 +3533,7 @@ def test_cannot_demote_score_calibration_when_score_calibration_not_owned_by_use data_files / "scores.csv", ) calibration = create_publish_and_promote_score_calibration( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(extra_user_app_overrides): @@ -3002,7 +3543,7 @@ def test_cannot_demote_score_calibration_when_score_calibration_not_owned_by_use assert response.status_code == 403 error = response.json() - assert f"insufficient permissions for URN '{calibration['urn']}'" in error["detail"] + assert f"insufficient permissions on score calibration with URN '{calibration['urn']}'" in error["detail"] @pytest.mark.parametrize( @@ -3027,7 +3568,7 @@ def test_can_demote_score_calibration_as_score_set_contributor( data_files / "scores.csv", ) calibration = create_publish_and_promote_score_calibration( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) add_contributor( @@ -3073,7 +3614,7 @@ def test_can_demote_score_calibration_as_score_set_owner( data_files / "scores.csv", ) calibration = create_publish_and_promote_score_calibration( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.post( @@ -3109,7 +3650,7 @@ def test_can_demote_score_calibration_as_admin_user( data_files / "scores.csv", ) calibration = create_publish_and_promote_score_calibration( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(admin_app_overrides): @@ -3145,7 +3686,9 @@ def test_can_demote_non_primary_score_calibration( experiment["urn"], data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) secondary_calibration = create_test_score_calibration_in_score_set_via_client( client, score_set["urn"], deepcamelize(TEST_PATHOGENICITY_SCORE_CALIBRATION) ) @@ -3210,7 +3753,7 @@ def test_cannot_publish_score_calibration_as_anonymous_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(anonymous_app_overrides): @@ -3245,7 +3788,7 @@ def test_cannot_publish_score_calibration_when_score_calibration_not_owned_by_us data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(extra_user_app_overrides): @@ -3280,7 +3823,7 @@ def test_can_publish_score_calibration_as_score_set_owner( data_files / "scores.csv", ) 
calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) response = client.post( @@ -3316,7 +3859,7 @@ def test_can_publish_score_calibration_as_admin_user( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with DependencyOverrider(admin_app_overrides): @@ -3353,7 +3896,7 @@ def test_can_publish_already_published_calibration( data_files / "scores.csv", ) calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) # publish it first diff --git a/tests/routers/test_score_set.py b/tests/routers/test_score_set.py index 86234392..13bd7ce7 100644 --- a/tests/routers/test_score_set.py +++ b/tests/routers/test_score_set.py @@ -35,7 +35,8 @@ SAVED_PUBMED_PUBLICATION, SAVED_SHORT_EXTRA_LICENSE, TEST_BIORXIV_IDENTIFIER, - TEST_BRNICH_SCORE_CALIBRATION, + TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED, + TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_CROSSREF_IDENTIFIER, TEST_GNOMAD_DATA_VERSION, TEST_INACTIVE_LICENSE, @@ -48,11 +49,12 @@ TEST_ORCID_ID, TEST_PATHOGENICITY_SCORE_CALIBRATION, TEST_PUBMED_IDENTIFIER, - TEST_SAVED_BRNICH_SCORE_CALIBRATION, + TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_SAVED_CLINVAR_CONTROL, TEST_SAVED_GENERIC_CLINICAL_CONTROL, TEST_SAVED_GNOMAD_VARIANT, TEST_USER, + VALID_CLINGEN_CA_ID, ) from tests.helpers.dependency_overrider import DependencyOverrider from tests.helpers.util.common import ( @@ -204,7 +206,7 @@ def test_create_score_set_with_score_calibration(client, mock_publication_fetch, score_set["experimentUrn"] = experiment["urn"] score_set.update( { - "scoreCalibrations": [deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)], + "scoreCalibrations": [deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED)], } ) @@ -219,7 +221,7 @@ def test_create_score_set_with_score_calibration(client, mock_publication_fetch, deepcopy(TEST_MINIMAL_SEQ_SCORESET_RESPONSE), experiment, response_data ) expected_response["experiment"].update({"numScoreSets": 1}) - expected_calibration = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION) + expected_calibration = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED) expected_calibration["urn"] = response_data["scoreCalibrations"][0]["urn"] expected_calibration["private"] = True expected_calibration["primary"] = False @@ -234,6 +236,34 @@ def test_create_score_set_with_score_calibration(client, mock_publication_fetch, assert response.status_code == 200 +@pytest.mark.parametrize( + "mock_publication_fetch", + [ + ( + [ + {"dbName": "PubMed", "identifier": f"{TEST_PUBMED_IDENTIFIER}"}, + {"dbName": "bioRxiv", "identifier": f"{TEST_BIORXIV_IDENTIFIER}"}, + ] + ) + ], + indirect=["mock_publication_fetch"], +) +def test_cannot_create_score_set_with_class_based_calibration(client, mock_publication_fetch, setup_router_db): + experiment = create_experiment(client) + score_set = deepcopy(TEST_MINIMAL_SEQ_SCORESET) + score_set["experimentUrn"] = experiment["urn"] + score_set.update( + { + "scoreCalibrations": [deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED)], + } + ) + + response = 
client.post("/api/v1/score-sets/", json=score_set) + assert response.status_code == 409 + response_data = response.json() + assert "Class-based calibrations are not supported on score set creation" in response_data["detail"] + + @pytest.mark.parametrize( "mock_publication_fetch", [ @@ -815,12 +845,12 @@ def test_extra_user_can_only_view_published_score_calibrations_in_score_set( worker_queue.assert_called_once() create_test_score_calibration_in_score_set_via_client( - client, published_score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, published_score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) public_calibration = create_publish_and_promote_score_calibration( client, published_score_set["urn"], - deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), ) with DependencyOverrider(extra_user_app_overrides): @@ -848,12 +878,12 @@ def test_creating_user_can_view_all_score_calibrations_in_score_set(client, setu experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) private_calibration = create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) public_calibration = create_publish_and_promote_score_calibration( client, score_set["urn"], - deepcamelize(TEST_BRNICH_SCORE_CALIBRATION), + deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), ) response = client.get(f"/api/v1/score-sets/{score_set['urn']}") @@ -1346,7 +1376,7 @@ def test_score_calibrations_remain_private_when_score_set_is_published( ) score_set = mock_worker_variant_insertion(client, session, data_provider, score_set, data_files / "scores.csv") create_test_score_calibration_in_score_set_via_client( - client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION) + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) ) with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue: @@ -1408,7 +1438,7 @@ def test_anonymous_cannot_publish_user_private_score_set( assert "Could not validate credentials" in response_data["detail"] -def test_contributor_can_publish_other_users_score_set(session, data_provider, client, setup_router_db, data_files): +def test_contributor_cannot_publish_other_users_score_set(session, data_provider, client, setup_router_db, data_files): experiment = create_experiment(client) score_set = create_seq_score_set(client, experiment["urn"]) score_set = mock_worker_variant_insertion(client, session, data_provider, score_set, data_files / "scores.csv") @@ -1423,60 +1453,15 @@ def test_contributor_can_publish_other_users_score_set(session, data_provider, c ) with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue: - published_score_set = publish_score_set(client, score_set["urn"]) - worker_queue.assert_called_once() - - assert published_score_set["urn"] == "urn:mavedb:00000001-a-1" - assert published_score_set["experiment"]["urn"] == "urn:mavedb:00000001-a" - - expected_response = update_expected_response_for_created_resources( - deepcopy(TEST_MINIMAL_SEQ_SCORESET_RESPONSE), published_score_set["experiment"], published_score_set - ) - expected_response["experiment"].update({"publishedDate": date.today().isoformat(), "numScoreSets": 1}) - expected_response.update( - { - "urn": published_score_set["urn"], - "publishedDate": date.today().isoformat(), - 
"numVariants": 3, - "private": False, - "datasetColumns": SAVED_MINIMAL_DATASET_COLUMNS, - "processingState": ProcessingState.success.name, - } - ) - expected_response["contributors"] = [ - { - "recordType": "Contributor", - "orcidId": TEST_USER["username"], - "givenName": TEST_USER["first_name"], - "familyName": TEST_USER["last_name"], - } - ] - expected_response["createdBy"] = { - "recordType": "User", - "orcidId": EXTRA_USER["username"], - "firstName": EXTRA_USER["first_name"], - "lastName": EXTRA_USER["last_name"], - } - expected_response["modifiedBy"] = { - "recordType": "User", - "orcidId": EXTRA_USER["username"], - "firstName": EXTRA_USER["first_name"], - "lastName": EXTRA_USER["last_name"], - } - assert sorted(expected_response.keys()) == sorted(published_score_set.keys()) - - # refresh score set to post worker state - score_set = (client.get(f"/api/v1/score-sets/{published_score_set['urn']}")).json() - for key in expected_response: - assert (key, expected_response[key]) == (key, score_set[key]) + response = client.post(f"/api/v1/score-sets/{score_set['urn']}/publish") + assert response.status_code == 403 + worker_queue.assert_not_called() + response_data = response.json() - score_set_variants = session.execute( - select(VariantDbModel).join(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set["urn"]) - ).scalars() - assert all([variant.urn.startswith("urn:mavedb:") for variant in score_set_variants]) + assert f"insufficient permissions on score set with URN '{score_set['urn']}'" in response_data["detail"] -def test_admin_cannot_publish_other_user_private_score_set( +def test_admin_can_publish_other_user_private_score_set( session, data_provider, client, admin_app_overrides, setup_router_db, data_files ): experiment = create_experiment(client) @@ -1488,11 +1473,8 @@ def test_admin_cannot_publish_other_user_private_score_set( patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as queue, ): response = client.post(f"/api/v1/score-sets/{score_set['urn']}/publish") - assert response.status_code == 404 - queue.assert_not_called() - response_data = response.json() - - assert f"score set with URN '{score_set['urn']}' not found" in response_data["detail"] + assert response.status_code == 200 + queue.assert_called_once() ######################################################################################################################## @@ -2334,7 +2316,9 @@ def test_cannot_delete_own_published_scoreset(session, data_provider, client, se assert del_response.status_code == 403 del_response_data = del_response.json() - assert f"insufficient permissions for URN '{published_score_set['urn']}'" in del_response_data["detail"] + assert ( + f"insufficient permissions on score set with URN '{published_score_set['urn']}'" in del_response_data["detail"] + ) def test_contributor_can_delete_other_users_private_scoreset( @@ -2355,7 +2339,9 @@ def test_contributor_can_delete_other_users_private_scoreset( response = client.delete(f"/api/v1/score-sets/{score_set['urn']}") - assert response.status_code == 200 + assert response.status_code == 403 + response_data = response.json() + assert f"insufficient permissions on score set with URN '{score_set['urn']}'" in response_data["detail"] def test_admin_can_delete_other_users_private_scoreset( @@ -2897,6 +2883,83 @@ def test_download_scores_counts_and_post_mapped_variants_file( ) +# Additional namespace export tests: VEP, ClinGen, gnomAD +def test_download_vep_file_in_variant_data_path(session, data_provider, client, setup_router_db, data_files): + 
experiment = create_experiment(client) + score_set = create_seq_score_set(client, experiment["urn"]) + score_set = mock_worker_variant_insertion( + client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv" + ) + # Create mapped variants with VEP consequence populated + create_mapped_variants_for_score_set(session, score_set["urn"], TEST_MAPPED_VARIANT_WITH_HGVS_G_EXPRESSION) + + with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue: + published_score_set = publish_score_set(client, score_set["urn"]) + worker_queue.assert_called_once() + + response = client.get( + f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?namespaces=vep&include_post_mapped_hgvs=true&drop_na_columns=true" + ) + assert response.status_code == 200 + reader = csv.DictReader(StringIO(response.text)) + assert "vep.vep_functional_consequence" in reader.fieldnames + # At least one row should contain the test consequence value + rows = list(reader) + assert any(row.get("vep.vep_functional_consequence") == "missense_variant" for row in rows) + + +def test_download_clingen_file_in_variant_data_path(session, data_provider, client, setup_router_db, data_files): + experiment = create_experiment(client) + score_set = create_seq_score_set(client, experiment["urn"]) + score_set = mock_worker_variant_insertion( + client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv" + ) + # Create mapped variants then set ClinGen allele id for first mapped variant + create_mapped_variants_for_score_set(session, score_set["urn"], TEST_MAPPED_VARIANT_WITH_HGVS_G_EXPRESSION) + db_score_set = session.query(ScoreSetDbModel).filter(ScoreSetDbModel.urn == score_set["urn"]).one() + first_mapped_variant = db_score_set.variants[0].mapped_variants[0] + first_mapped_variant.clingen_allele_id = VALID_CLINGEN_CA_ID + session.add(first_mapped_variant) + session.commit() + + with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue: + published_score_set = publish_score_set(client, score_set["urn"]) + worker_queue.assert_called_once() + + response = client.get( + f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?namespaces=clingen&include_post_mapped_hgvs=true&drop_na_columns=true" + ) + assert response.status_code == 200 + reader = csv.DictReader(StringIO(response.text)) + assert "clingen.clingen_allele_id" in reader.fieldnames + rows = list(reader) + assert rows[0].get("clingen.clingen_allele_id") == VALID_CLINGEN_CA_ID + + +def test_download_gnomad_file_in_variant_data_path(session, data_provider, client, setup_router_db, data_files): + experiment = create_experiment(client) + score_set = create_seq_score_set(client, experiment["urn"]) + score_set = mock_worker_variant_insertion( + client, session, data_provider, score_set, data_files / "scores.csv", data_files / "counts.csv" + ) + # Link a gnomAD variant to the first mapped variant (version may not match export filter) + score_set = create_seq_score_set_with_mapped_variants( + client, session, data_provider, experiment["urn"], data_files / "scores.csv" + ) + link_gnomad_variants_to_mapped_variants(session, score_set) + + with patch.object(arq.ArqRedis, "enqueue_job", return_value=None) as worker_queue: + published_score_set = publish_score_set(client, score_set["urn"]) + worker_queue.assert_called_once() + + response = client.get( + f"/api/v1/score-sets/{published_score_set['urn']}/variants/data?namespaces=gnomad&drop_na_columns=true" + ) + assert 
response.status_code == 200 + reader = csv.DictReader(StringIO(response.text)) + assert "gnomad.gnomad_af" in reader.fieldnames + + ######################################################################################################################## # Fetching clinical controls and control options for a score set ######################################################################################################################## @@ -3098,7 +3161,9 @@ def test_get_annotated_pathogenicity_evidence_lines_for_score_set( experiment["urn"], data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) # The contents of the annotated variants objects should be tested in more detail elsewhere. response = client.get(f"/api/v1/score-sets/{score_set['urn']}/annotated-variants/pathogenicity-evidence-line") @@ -3185,7 +3250,9 @@ def test_get_annotated_pathogenicity_evidence_lines_for_score_set_when_some_vari experiment["urn"], data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) first_var = clear_first_mapped_variant_post_mapped(session, score_set["urn"]) @@ -3225,7 +3292,9 @@ def test_get_annotated_functional_impact_statement_for_score_set( experiment["urn"], data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) response = client.get(f"/api/v1/score-sets/{score_set['urn']}/annotated-variants/functional-impact-statement") response_data = parse_ndjson_response(response) @@ -3255,7 +3324,7 @@ def test_nonetype_annotated_functional_impact_statement_for_score_set_when_calib data_files / "scores.csv", update={ "secondaryPublicationIdentifiers": [{"dbName": "PubMed", "identifier": f"{TEST_PUBMED_IDENTIFIER}"}], - "scoreRanges": camelize([TEST_BRNICH_SCORE_CALIBRATION, TEST_PATHOGENICITY_SCORE_CALIBRATION]), + "scoreRanges": camelize([TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_PATHOGENICITY_SCORE_CALIBRATION]), }, ) @@ -3314,7 +3383,9 @@ def test_get_annotated_functional_impact_statement_for_score_set_when_some_varia experiment["urn"], data_files / "scores.csv", ) - create_publish_and_promote_score_calibration(client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION)) + create_publish_and_promote_score_calibration( + client, score_set["urn"], deepcamelize(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + ) first_var = clear_first_mapped_variant_post_mapped(session, score_set["urn"]) @@ -3378,7 +3449,7 @@ def test_annotated_functional_study_result_exists_for_score_set_when_thresholds_ data_files / "scores.csv", update={ "secondaryPublicationIdentifiers": [{"dbName": "PubMed", "identifier": f"{TEST_PUBMED_IDENTIFIER}"}], - "scoreRanges": camelize([TEST_BRNICH_SCORE_CALIBRATION, TEST_PATHOGENICITY_SCORE_CALIBRATION]), + "scoreRanges": camelize([TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_PATHOGENICITY_SCORE_CALIBRATION]), }, ) @@ -3410,7 +3481,7 @@ def 
test_annotated_functional_study_result_exists_for_score_set_when_ranges_not_ data_files / "scores.csv", update={ "secondaryPublicationIdentifiers": [{"dbName": "PubMed", "identifier": f"{TEST_PUBMED_IDENTIFIER}"}], - "scoreRanges": camelize([TEST_BRNICH_SCORE_CALIBRATION, TEST_PATHOGENICITY_SCORE_CALIBRATION]), + "scoreRanges": camelize([TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_PATHOGENICITY_SCORE_CALIBRATION]), }, ) @@ -3465,7 +3536,7 @@ def test_annotated_functional_study_result_exists_for_score_set_when_some_varian data_files / "scores.csv", update={ "secondaryPublicationIdentifiers": [{"dbName": "PubMed", "identifier": f"{TEST_PUBMED_IDENTIFIER}"}], - "scoreRanges": camelize([TEST_BRNICH_SCORE_CALIBRATION, TEST_PATHOGENICITY_SCORE_CALIBRATION]), + "scoreRanges": camelize([TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_PATHOGENICITY_SCORE_CALIBRATION]), }, ) diff --git a/tests/routers/test_users.py b/tests/routers/test_users.py index 03b57c0b..68fa382d 100644 --- a/tests/routers/test_users.py +++ b/tests/routers/test_users.py @@ -21,14 +21,14 @@ def test_cannot_list_users_as_anonymous_user(client, setup_router_db, anonymous_ assert response.status_code == 401 response_value = response.json() - assert response_value["detail"] in "Could not validate credentials" + assert "Could not validate credentials" in response_value["detail"] def test_cannot_list_users_as_normal_user(client, setup_router_db): response = client.get("/api/v1/users/") assert response.status_code == 403 response_value = response.json() - assert response_value["detail"] in "You are not authorized to use this feature" + assert "You are not authorized to use this feature" in response_value["detail"] def test_can_list_users_as_admin_user(admin_app_overrides, setup_router_db, client): @@ -50,10 +50,7 @@ def test_cannot_get_anonymous_user(client, setup_router_db, session, anonymous_a assert response.status_code == 401 response_value = response.json() - assert response_value["detail"] in "Could not validate credentials" - - # Some lingering db transaction holds this test open unless it is explicitly closed. - session.commit() + assert "Could not validate credentials" in response_value["detail"] def test_get_current_user(client, setup_router_db, session): @@ -62,9 +59,6 @@ def test_get_current_user(client, setup_router_db, session): response_value = response.json() assert response_value["orcidId"] == TEST_USER["username"] - # Some lingering db transaction holds this test open unless it is explicitly closed. - session.commit() - def test_get_current_admin_user(client, admin_app_overrides, setup_router_db, session): with DependencyOverrider(admin_app_overrides): @@ -75,9 +69,6 @@ def test_get_current_admin_user(client, admin_app_overrides, setup_router_db, se assert response_value["orcidId"] == ADMIN_USER["username"] assert response_value["roles"] == ["admin"] - # Some lingering db transaction holds this test open unless it is explicitly closed. - session.commit() - def test_cannot_impersonate_admin_user_as_default_user(client, setup_router_db, session): # NOTE: We can't mock JWTBearer directly because the object is created when the `get_current_user` function is called. @@ -100,9 +91,6 @@ def test_cannot_impersonate_admin_user_as_default_user(client, setup_router_db, assert response.status_code == 403 assert response.json()["detail"] in "This user is not a member of the requested acting role." - # Some lingering db transaction holds this test open unless it is explicitly closed. 
- session.commit() - def test_cannot_fetch_single_user_as_anonymous_user(client, setup_router_db, session, anonymous_app_overrides): with DependencyOverrider(anonymous_app_overrides): @@ -111,18 +99,12 @@ def test_cannot_fetch_single_user_as_anonymous_user(client, setup_router_db, ses assert response.status_code == 401 assert response.json()["detail"] in "Could not validate credentials" - # Some lingering db transaction holds this test open unless it is explicitly closed. - session.commit() - def test_cannot_fetch_single_user_as_normal_user(client, setup_router_db, session): response = client.get("/api/v1/users/2") assert response.status_code == 403 assert response.json()["detail"] in "You are not authorized to use this feature" - # Some lingering db transaction holds this test open unless it is explicitly closed. - session.commit() - def test_can_fetch_single_user_as_admin_user(client, setup_router_db, session, admin_app_overrides): with DependencyOverrider(admin_app_overrides): @@ -132,9 +114,6 @@ def test_can_fetch_single_user_as_admin_user(client, setup_router_db, session, a response_value = response.json() assert response_value["orcidId"] == EXTRA_USER["username"] - # Some lingering db transaction holds this test open unless it is explicitly closed. - session.commit() - def test_fetching_nonexistent_user_as_admin_raises_exception(client, setup_router_db, session, admin_app_overrides): with DependencyOverrider(admin_app_overrides): @@ -142,10 +121,7 @@ def test_fetching_nonexistent_user_as_admin_raises_exception(client, setup_route assert response.status_code == 404 response_value = response.json() - assert "User with ID 0 not found" in response_value["detail"] - - # Some lingering db transaction holds this test open unless it is explicitly closed. - session.commit() + assert "user profile with ID 0 not found" in response_value["detail"] def test_anonymous_user_cannot_update_self(client, setup_router_db, anonymous_app_overrides): @@ -209,7 +185,7 @@ def test_admin_can_set_logged_in_property_on_self(client, setup_router_db, admin [ ("email", "updated@test.com"), ("first_name", "Updated"), - ("last_name", "User"), + ("last_name", "user profile"), ("roles", ["admin"]), ], ) @@ -223,7 +199,7 @@ def test_anonymous_user_cannot_update_other_users( assert response.status_code == 401 response_value = response.json() - assert response_value["detail"] in "Could not validate credentials" + assert "Could not validate credentials" in response_value["detail"] @pytest.mark.parametrize( @@ -231,7 +207,7 @@ def test_anonymous_user_cannot_update_other_users( [ ("email", "updated@test.com"), ("first_name", "Updated"), - ("last_name", "User"), + ("last_name", "user profile"), ("roles", ["admin"]), ], ) @@ -241,7 +217,7 @@ def test_user_cannot_update_other_users(client, setup_router_db, field_name, fie response = client.put("/api/v1/users//2", json=user_update) assert response.status_code == 403 response_value = response.json() - assert response_value["detail"] in "Insufficient permissions for user update." 
+ assert "insufficient permissions on user profile with ID '2'" in response_value["detail"] @pytest.mark.parametrize( @@ -249,7 +225,7 @@ def test_user_cannot_update_other_users(client, setup_router_db, field_name, fie [ ("email", "updated@test.com"), ("first_name", "Updated"), - ("last_name", "User"), + ("last_name", "user profile"), ("roles", ["admin"]), ], ) diff --git a/tests/validation/dataframe/test_calibration.py b/tests/validation/dataframe/test_calibration.py new file mode 100644 index 00000000..57c7d22b --- /dev/null +++ b/tests/validation/dataframe/test_calibration.py @@ -0,0 +1,1070 @@ +# ruff: noqa: E402 + +import pytest + +pytest.importorskip("psycopg2") + +from unittest.mock import Mock, patch + +import pandas as pd + +from mavedb.lib.validation.constants.general import ( + calibration_class_column_name, + calibration_variant_column_name, + hgvs_nt_column, + hgvs_pro_column, +) +from mavedb.lib.validation.dataframe.calibration import ( + choose_calibration_index_column, + validate_and_standardize_calibration_classes_dataframe, + validate_calibration_classes, + validate_calibration_df_column_names, + validate_index_existence_in_score_set, +) +from mavedb.lib.validation.exceptions import ValidationError +from mavedb.view_models import score_calibration + + +class TestValidateAndStandardizeCalibrationClassesDataframe: + """Test suite for validate_and_standardize_calibration_classes_dataframe function.""" + + @pytest.fixture + def mock_dependencies(self): + """Mock all external dependencies for the function.""" + with ( + patch("mavedb.lib.validation.dataframe.calibration.standardize_dataframe") as mock_standardize, + patch("mavedb.lib.validation.dataframe.calibration.validate_no_null_rows") as mock_validate_no_null, + patch("mavedb.lib.validation.dataframe.calibration.validate_variant_column") as mock_validate_variant, + patch("mavedb.lib.validation.dataframe.calibration.validate_data_column") as mock_validate_data, + patch( + "mavedb.lib.validation.dataframe.calibration.validate_index_existence_in_score_set" + ) as mock_validate_index_existence, + ): + yield { + "standardize_dataframe": mock_standardize, + "validate_no_null_rows": mock_validate_no_null, + "validate_variant_column": mock_validate_variant, + "validate_data_column": mock_validate_data, + "validate_index_existence_in_score_set": mock_validate_index_existence, + } + + def test_validate_and_standardize_calibration_classes_dataframe_success(self, mock_dependencies): + """Test successful validation and standardization.""" + mock_db = Mock() + mock_score_set = Mock() + mock_score_set.id = 123 + + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame( + { + calibration_variant_column_name.upper(): ["var1", "var2"], + calibration_class_column_name.upper(): ["A", "B"], + } + ) + standardized_df = pd.DataFrame( + {calibration_variant_column_name: ["var1", "var2"], calibration_class_column_name: ["A", "B"]} + ) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + + mock_scalars = Mock() + mock_scalars.all.return_value = ["var1", "var2"] + mock_db.scalars.return_value = mock_scalars + + mock_classification1 = Mock() + mock_classification1.class_ = "A" + mock_classification2 = Mock() + mock_classification2.class_ = "B" + mock_calibration.functional_classifications = [mock_classification1, mock_classification2] + + result, index_column = validate_and_standardize_calibration_classes_dataframe( + mock_db, mock_score_set, mock_calibration, input_df + ) + + assert 
result.equals(standardized_df) + mock_dependencies["standardize_dataframe"].assert_called_once() + mock_dependencies["validate_no_null_rows"].assert_called_once_with(standardized_df) + mock_dependencies["validate_variant_column"].assert_called_once() + mock_dependencies["validate_data_column"].assert_called_once() + + def test_validate_and_standardize_calibration_classes_dataframe_not_class_based(self): + """Test ValidationError when calibration is not class-based.""" + mock_db = Mock() + mock_score_set = Mock() + mock_calibration = Mock() + mock_calibration.class_based = False + input_df = pd.DataFrame({"variant": ["var1"], "class": ["A"]}) + + with pytest.raises( + ValidationError, + match="Calibration classes file can only be provided for functional classification calibrations.", + ): + validate_and_standardize_calibration_classes_dataframe(mock_db, mock_score_set, mock_calibration, input_df) + + def test_validate_and_standardize_calibration_classes_dataframe_missing_index_columns(self, mock_dependencies): + """Test ValidationError when column validation fails.""" + mock_db = Mock() + mock_score_set = Mock() + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame({calibration_class_column_name: ["c"], "invalid": ["A"]}) + standardized_df = pd.DataFrame({calibration_class_column_name: ["c"], "invalid": ["A"]}) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + + with pytest.raises( + ValidationError, + match=f"at least one of {', '.join({hgvs_nt_column, hgvs_pro_column, calibration_variant_column_name})} must be present", + ): + validate_and_standardize_calibration_classes_dataframe(mock_db, mock_score_set, mock_calibration, input_df) + + def test_validate_and_standardize_calibration_classes_dataframe_null_rows(self, mock_dependencies): + """Test ValidationError when null rows validation fails.""" + mock_db = Mock() + mock_score_set = Mock() + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame({calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]}) + standardized_df = pd.DataFrame( + {calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]} + ) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + mock_dependencies["validate_no_null_rows"].side_effect = ValidationError("null rows detected") + + with pytest.raises(ValidationError, match="null rows detected"): + validate_and_standardize_calibration_classes_dataframe(mock_db, mock_score_set, mock_calibration, input_df) + + def test_validate_and_standardize_calibration_classes_dataframe_drops_null_class_rows(self, mock_dependencies): + """Test that rows with null calibration class are dropped.""" + mock_db = Mock() + mock_score_set = Mock() + mock_score_set.id = 123 + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var2", "var3", "var4"], + calibration_class_column_name: ["A", None, "B", pd.NA], + } + ) + standardized_df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var2", "var3", "var4"], + calibration_class_column_name: ["A", None, "B", pd.NA], + } + ) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + + mock_scalars = Mock() + mock_scalars.all.return_value = ["var1", "var3"] + mock_db.scalars.return_value = mock_scalars + + mock_classification1 = Mock() + mock_classification1.class_ = "A" + mock_classification2 = 
Mock() + mock_classification2.class_ = "B" + mock_calibration.functional_classifications = [mock_classification1, mock_classification2] + + result, index_column = validate_and_standardize_calibration_classes_dataframe( + mock_db, mock_score_set, mock_calibration, input_df + ) + + expected_df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var3"], + calibration_class_column_name: ["A", "B"], + } + ) + + assert result.equals(expected_df) + + def test_validate_and_standardize_calibration_classes_dataframe_propagates_nonexistent_variants( + self, mock_dependencies + ): + """Test ValidationError when variant URN validation fails.""" + mock_db = Mock() + mock_score_set = Mock() + mock_score_set.id = 123 + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame({calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]}) + standardized_df = pd.DataFrame( + {calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]} + ) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + + mock_scalars = Mock() + mock_scalars.all.return_value = [] + mock_db.scalars.return_value = mock_scalars + + mock_classification1 = Mock() + mock_classification1.class_ = "A" + mock_calibration.functional_classifications = [mock_classification1] + + mock_dependencies["validate_index_existence_in_score_set"].side_effect = ValidationError( + "The following resources do not exist in the score set: var1" + ) + + with pytest.raises(ValidationError, match="The following resources do not exist in the score set: var1"): + validate_and_standardize_calibration_classes_dataframe(mock_db, mock_score_set, mock_calibration, input_df) + + def test_validate_and_standardize_calibration_classes_dataframe_invalid_classes(self, mock_dependencies): + """Test ValidationError when class validation fails.""" + mock_db = Mock() + mock_score_set = Mock() + mock_score_set.id = 123 + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame({calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]}) + standardized_df = pd.DataFrame( + {calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]} + ) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + + mock_scalars = Mock() + mock_scalars.all.return_value = ["var1"] + mock_db.scalars.return_value = mock_scalars + + mock_calibration.functional_classifications = None + + with pytest.raises( + ValidationError, match="Calibration must have functional classifications defined for class validation." 
+ ): + validate_and_standardize_calibration_classes_dataframe(mock_db, mock_score_set, mock_calibration, input_df) + + def test_validate_and_standardize_calibration_classes_dataframe_variant_column_validation_fails( + self, mock_dependencies + ): + """Test ValidationError when variant column validation fails.""" + mock_db = Mock() + mock_score_set = Mock() + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame({calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]}) + standardized_df = pd.DataFrame( + {calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]} + ) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + mock_dependencies["validate_variant_column"].side_effect = ValidationError("invalid variant column") + + with pytest.raises(ValidationError, match="invalid variant column"): + validate_and_standardize_calibration_classes_dataframe(mock_db, mock_score_set, mock_calibration, input_df) + + def test_validate_and_standardize_calibration_classes_dataframe_data_column_validation_fails( + self, mock_dependencies + ): + """Test ValidationError when data column validation fails.""" + mock_db = Mock() + mock_score_set = Mock() + mock_score_set.id = 123 + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame({calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]}) + standardized_df = pd.DataFrame( + {calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]} + ) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + mock_dependencies["validate_data_column"].side_effect = ValidationError("invalid data column") + + mock_scalars = Mock() + mock_scalars.all.return_value = ["var1"] + mock_db.scalars.return_value = mock_scalars + + with pytest.raises(ValidationError, match="invalid data column"): + validate_and_standardize_calibration_classes_dataframe(mock_db, mock_score_set, mock_calibration, input_df) + + def test_validate_and_standardize_calibration_classes_dataframe_mixed_case_columns(self, mock_dependencies): + """Test successful validation with mixed case column names.""" + mock_db = Mock() + mock_score_set = Mock() + mock_score_set.id = 123 + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame( + {calibration_variant_column_name: ["var1"], calibration_class_column_name.upper(): ["A"]} + ) + standardized_df = pd.DataFrame( + {calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]} + ) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + + mock_scalars = Mock() + mock_scalars.all.return_value = ["var1"] + mock_db.scalars.return_value = mock_scalars + + mock_classification = Mock() + mock_classification.class_ = "A" + mock_calibration.functional_classifications = [mock_classification] + + result, index_column = validate_and_standardize_calibration_classes_dataframe( + mock_db, mock_score_set, mock_calibration, input_df + ) + + assert result.equals(standardized_df) + + def test_validate_and_standardize_calibration_classes_dataframe_with_score_calibration_modify( + self, mock_dependencies + ): + """Test function works with ScoreCalibrationModify object.""" + mock_db = Mock() + mock_score_set = Mock() + mock_score_set.id = 123 + mock_calibration = Mock(spec=score_calibration.ScoreCalibrationModify) + mock_calibration.class_based = True + + input_df = 
pd.DataFrame({calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]}) + standardized_df = pd.DataFrame( + {calibration_variant_column_name: ["var1"], calibration_class_column_name: ["A"]} + ) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + + mock_scalars = Mock() + mock_scalars.all.return_value = ["var1"] + mock_db.scalars.return_value = mock_scalars + + mock_classification = Mock() + mock_classification.class_ = "A" + mock_calibration.functional_classifications = [mock_classification] + + result, index_column = validate_and_standardize_calibration_classes_dataframe( + mock_db, mock_score_set, mock_calibration, input_df + ) + + assert result.equals(standardized_df) + + def test_validate_and_standardize_calibration_classes_dataframe_empty_dataframe(self, mock_dependencies): + """Test ValidationError with empty dataframe.""" + mock_db = Mock() + mock_score_set = Mock() + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame() + standardized_df = pd.DataFrame() + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + + with pytest.raises(ValidationError, match=f"missing required column: '{calibration_class_column_name}'"): + validate_and_standardize_calibration_classes_dataframe(mock_db, mock_score_set, mock_calibration, input_df) + + def test_validate_and_standardize_calibration_classes_dataframe_multiple_candidate_index_columns( + self, mock_dependencies + ): + """Test successful validation when multiple candidate index columns are present.""" + mock_db = Mock() + mock_score_set = Mock() + mock_score_set.id = 123 + mock_calibration = Mock() + mock_calibration.class_based = True + + input_df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var2"], + hgvs_nt_column: ["NM_000546.5:c.215C>G", "NM_000546.5:c.743G>A"], + calibration_class_column_name: ["A", "B"], + } + ) + standardized_df = pd.DataFrame( + { + calibration_variant_column_name: ["var1", "var2"], + hgvs_nt_column: ["NM_000546.5:c.215C>G", "NM_000546.5:c.743G>A"], + calibration_class_column_name: ["A", "B"], + } + ) + + mock_dependencies["standardize_dataframe"].return_value = standardized_df + mock_dependencies["validate_index_existence_in_score_set"].return_value = None + + mock_scalars = Mock() + mock_scalars.all.return_value = ["var1", "var2"] + mock_db.scalars.return_value = mock_scalars + + mock_classification1 = Mock() + mock_classification1.class_ = "A" + mock_classification2 = Mock() + mock_classification2.class_ = "B" + mock_calibration.functional_classifications = [mock_classification1, mock_classification2] + + result, index_column = validate_and_standardize_calibration_classes_dataframe( + mock_db, mock_score_set, mock_calibration, input_df + ) + + assert result.equals(standardized_df) + assert index_column == calibration_variant_column_name + mock_dependencies["validate_index_existence_in_score_set"].assert_called_once() + + +class TestValidateCalibrationDfColumnNames: + """Test suite for validate_calibration_df_column_names function.""" + + def test_validate_calibration_df_column_names_success(self): + """Test successful validation with correct column names.""" + df = pd.DataFrame( + {calibration_variant_column_name: ["var1", "var2"], calibration_class_column_name: ["A", "B"]} + ) + + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_case_insensitive(self): + """Test successful validation with different case column names.""" + df = pd.DataFrame( 
+ { + calibration_variant_column_name.upper(): ["var1", "var2"], + calibration_class_column_name.upper(): ["A", "B"], + } + ) + + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_mixed_case(self): + """Test successful validation with mixed case column names.""" + df = pd.DataFrame( + { + calibration_variant_column_name.capitalize(): ["var1", "var2"], + calibration_class_column_name.capitalize(): ["A", "B"], + } + ) + + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_non_string_columns(self): + """Test ValidationError when column names are not strings.""" + df = pd.DataFrame({123: ["var1", "var2"], calibration_class_column_name: ["A", "B"]}) + + # Act & Assert + with pytest.raises(ValidationError, match="column names must be strings"): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_empty_column_name(self): + """Test ValidationError when column names are empty.""" + df = pd.DataFrame(columns=["", calibration_variant_column_name]) + + # Act & Assert + with pytest.raises(ValidationError, match="column names cannot be empty or whitespace"): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_whitespace_column_name(self): + """Test ValidationError when column names contain only whitespace.""" + df = pd.DataFrame(columns=[" ", calibration_class_column_name]) + + # Act & Assert + with pytest.raises(ValidationError, match="column names cannot be empty or whitespace"): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_tab_whitespace(self): + """Test ValidationError when column names contain only tab characters.""" + df = pd.DataFrame(columns=["\t\t", calibration_class_column_name]) + + # Act & Assert + with pytest.raises(ValidationError, match="column names cannot be empty or whitespace"): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_missing_variant_column(self): + """Test ValidationError when variant column is missing.""" + df = pd.DataFrame({calibration_class_column_name: ["A", "B"], "other": ["X", "Y"]}) + + # Act & Assert + with pytest.raises( + ValidationError, + match="at least one of {} must be present".format( + ", ".join({hgvs_nt_column, hgvs_pro_column, calibration_variant_column_name}) + ), + ): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_missing_class_column(self): + """Test ValidationError when class column is missing.""" + df = pd.DataFrame({calibration_variant_column_name: ["var1", "var2"], "other": ["X", "Y"]}) + + # Act & Assert + with pytest.raises(ValidationError, match=f"missing required column: '{calibration_class_column_name}'"): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_missing_both_required_columns(self): + """Test ValidationError when both required columns are missing.""" + df = pd.DataFrame({"other1": ["X", "Y"], "other2": ["A", "B"]}) + + # Act & Assert + with pytest.raises(ValidationError, match=f"missing required column: '{calibration_class_column_name}'"): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_fewer_than_expected_columns(self): + """Test ValidationError when fewer columns than expected are present.""" + df = pd.DataFrame({calibration_variant_column_name: ["var1", "var2"]}) + + # Act & Assert + with pytest.raises(ValidationError, match=f"missing required 
column: '{calibration_class_column_name}'"): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_duplicate_columns_different_case(self): + """Test ValidationError when duplicate columns exist with different cases.""" + df = pd.DataFrame( + columns=[ + calibration_variant_column_name, + calibration_variant_column_name.upper(), + calibration_class_column_name, + ] + ) + + # Act & Assert + with pytest.raises(ValidationError, match="duplicate column names are not allowed \(case-insensitive\)"): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_empty_dataframe(self): + """Test ValidationError when dataframe has no columns.""" + df = pd.DataFrame() + + # Act & Assert + with pytest.raises(ValidationError, match=f"missing required column: '{calibration_class_column_name}'"): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_with_numeric_and_string_mix(self): + """Test ValidationError when columns mix numeric and string types.""" + df = pd.DataFrame(columns=["variant", 42.5]) + + # Act & Assert + with pytest.raises(ValidationError, match="column names must be strings"): + validate_calibration_df_column_names(df) + + def test_validate_calibration_df_column_names_newline_in_whitespace(self): + """Test ValidationError when column names contain newline characters.""" + df = pd.DataFrame(columns=["\n\n", "class"]) + + # Act & Assert + with pytest.raises(ValidationError, match="column names cannot be empty or whitespace"): + validate_calibration_df_column_names(df) + + +class TestValidateIndexExistenceInScoreSet: + """Test suite for validate_index_existence_in_score_set function.""" + + @pytest.mark.parametrize( + "index_column_name,existing_resources_return_value,index_values", + [ + ( + calibration_variant_column_name, + ["urn:variant:1", "urn:variant:2", "urn:variant:3"], + ["urn:variant:1", "urn:variant:2", "urn:variant:3"], + ), + ( + hgvs_nt_column, + ["NM_000546.5:c.215C>G", "NM_000546.5:c.743G>A"], + ["NM_000546.5:c.215C>G", "NM_000546.5:c.743G>A"], + ), + ( + hgvs_pro_column, + ["NP_000537.3:p.Arg72Pro", "NP_000537.3:p.Gly248Trp"], + ["NP_000537.3:p.Arg72Pro", "NP_000537.3:p.Gly248Trp"], + ), + ], + ) + def test_validate_index_existence_in_score_set_success( + self, index_column_name, existing_resources_return_value, index_values + ): + """Test successful validation when all variant URNs exist in score set.""" + mock_db = Mock() + mock_scalars = Mock() + mock_scalars.all.return_value = existing_resources_return_value + mock_db.scalars.return_value = mock_scalars + + mock_score_set = Mock() + mock_score_set.id = 123 + + variant_urns = pd.Series(index_values) + + validate_index_existence_in_score_set(mock_db, mock_score_set, variant_urns, index_column_name) + + mock_db.scalars.assert_called_once() + + @pytest.mark.parametrize( + "index_column_name,existing_resources_return_value,index_values", + [ + ( + calibration_variant_column_name, + ["urn:variant:1", "urn:variant:2"], + ["urn:variant:1", "urn:variant:2", "urn:variant:3"], + ), + ( + hgvs_nt_column, + ["NM_000546.5:c.215C>G"], + ["NM_000546.5:c.215C>G", "NM_000546.5:c.743G>A"], + ), + ( + hgvs_pro_column, + ["NP_000537.3:p.Arg72Pro"], + ["NP_000537.3:p.Arg72Pro", "NP_000537.3:p.Gly248Trp"], + ), + ], + ) + def test_validate_index_existence_in_score_set_missing_variants( + self, index_column_name, existing_resources_return_value, index_values + ): + """Test ValidationError when some variant URNs don't exist in score 
set.""" + mock_db = Mock() + mock_scalars = Mock() + mock_scalars.all.return_value = existing_resources_return_value + mock_db.scalars.return_value = mock_scalars + + mock_score_set = Mock() + mock_score_set.id = 123 + + variant_urns = pd.Series(index_values) + + # Act & Assert + with pytest.raises( + ValidationError, + match="The following resources do not exist in the score set: {}".format( + ", ".join(sorted(set(index_values) - set(existing_resources_return_value))) + ), + ): + validate_index_existence_in_score_set( + mock_db, + mock_score_set, + variant_urns, + index_column_name, + ) + + @pytest.mark.parametrize( + "index_column_name,existing_resources_return_value,index_values", + [ + ( + calibration_variant_column_name, + ["urn:variant:1"], + ["urn:variant:1", "urn:variant:2", "urn:variant:3"], + ), + ( + hgvs_nt_column, + ["NM_000546.5:c.215C>G"], + ["NM_000546.5:c.215C>G", "NM_000546.5:c.743G>A", "NM_000546.5:c.999A>T"], + ), + ( + hgvs_pro_column, + ["NP_000537.3:p.Arg72Pro"], + ["NP_000537.3:p.Arg72Pro", "NP_000537.3:p.Gly248Trp", "NP_000537.3:p.Ser215Ile"], + ), + ], + ) + def test_validate_index_existence_in_score_set_multiple_missing_variants( + self, index_column_name, existing_resources_return_value, index_values + ): + """Test ValidationError when multiple variant resources don't exist in score set.""" + mock_db = Mock() + mock_scalars = Mock() + mock_scalars.all.return_value = existing_resources_return_value + mock_db.scalars.return_value = mock_scalars + + mock_score_set = Mock() + mock_score_set.id = 456 + + variant_urns = pd.Series(index_values) + + # Act & Assert + with pytest.raises( + ValidationError, + match="The following resources do not exist in the score set: {}".format( + ", ".join(sorted(set(index_values) - set(existing_resources_return_value))) + ), + ): + validate_index_existence_in_score_set(mock_db, mock_score_set, variant_urns, index_column_name) + + @pytest.mark.parametrize( + "index_column_name,existing_resources_return_value,index_values", + [ + (calibration_variant_column_name, [], ["urn:variant:1", "urn:variant:2", "urn:variant:3"]), + (hgvs_nt_column, [], ["NM_000546.5:c.215C>G", "NM_000546.5:c.743G>A"]), + (hgvs_pro_column, [], ["NP_000537.3:p.Arg72Pro", "NP_000537.3:p.Gly248Trp"]), + ], + ) + def test_validate_index_existence_in_score_set_all_missing( + self, index_column_name, existing_resources_return_value, index_values + ): + """Test ValidationError when all variant URNs are missing from score set.""" + mock_db = Mock() + mock_scalars = Mock() + mock_scalars.all.return_value = existing_resources_return_value + mock_db.scalars.return_value = mock_scalars + + mock_score_set = Mock() + mock_score_set.id = 789 + + variant_urns = pd.Series(index_values) + + # Act & Assert + with pytest.raises( + ValidationError, + match="The following resources do not exist in the score set: {}".format( + ", ".join(sorted(set(index_values) - set(existing_resources_return_value))) + ), + ): + validate_index_existence_in_score_set(mock_db, mock_score_set, variant_urns, index_column_name) + + @pytest.mark.parametrize( + "index_column_name,existing_resources_return_value,index_values", + [ + (calibration_variant_column_name, [], []), + (hgvs_nt_column, [], []), + (hgvs_pro_column, [], []), + ], + ) + def test_validate_index_existence_in_score_set_empty_series( + self, index_column_name, existing_resources_return_value, index_values + ): + """Test successful validation with empty index resources series.""" + mock_db = Mock() + mock_scalars = Mock() + 
mock_scalars.all.return_value = existing_resources_return_value + mock_db.scalars.return_value = mock_scalars + + mock_score_set = Mock() + mock_score_set.id = 123 + + variant_urns = pd.Series(index_values, dtype=object) + + # Act & Assert - should not raise any exception + validate_index_existence_in_score_set(mock_db, mock_score_set, variant_urns, index_column_name) + + @pytest.mark.parametrize( + "index_column_name,existing_resources_return_value,index_values", + [ + (calibration_variant_column_name, ["urn:variant:1"], ["urn:variant:1"]), + (hgvs_nt_column, ["NM_000546.5:c.215C>G"], ["NM_000546.5:c.215C>G"]), + (hgvs_pro_column, ["NP_000537.3:p.Arg72Pro"], ["NP_000537.3:p.Arg72Pro"]), + ], + ) + def test_validate_calibration_index_existence_single_variant( + self, index_column_name, existing_resources_return_value, index_values + ): + """Test successful validation with single index value URN.""" + mock_db = Mock() + mock_scalars = Mock() + mock_scalars.all.return_value = existing_resources_return_value + mock_db.scalars.return_value = mock_scalars + + mock_score_set = Mock() + mock_score_set.id = 123 + + variant_urns = pd.Series(index_values) + + # Act & Assert - should not raise any exception + validate_index_existence_in_score_set(mock_db, mock_score_set, variant_urns, index_column_name) + + @pytest.mark.parametrize( + "index_column_name,existing_resources_return_value,index_values", + [ + ( + calibration_variant_column_name, + ["urn:variant:1", "urn:variant:2"], + [ + "urn:variant:1", + "urn:variant:2", + "urn:variant:1", + "urn:variant:2", + ], + ), + ( + hgvs_nt_column, + ["NM_000546.5:c.215C>G", "NM_000546.5:c.743G>A"], + [ + "NM_000546.5:c.215C>G", + "NM_000546.5:c.743G>A", + "NM_000546.5:c.215C>G", + "NM_000546.5:c.743G>A", + ], + ), + ( + hgvs_pro_column, + ["NP_000537.3:p.Arg72Pro", "NP_000537.3:p.Gly248Trp"], + [ + "NP_000537.3:p.Arg72Pro", + "NP_000537.3:p.Gly248Trp", + "NP_000537.3:p.Arg72Pro", + "NP_000537.3:p.Gly248Trp", + ], + ), + ], + ) + def test_validate_calibration_index_existence_duplicate_values_in_series( + self, index_column_name, existing_resources_return_value, index_values + ): + """Test validation with duplicate index values in input series.""" + mock_db = Mock() + mock_scalars = Mock() + mock_scalars.all.return_value = existing_resources_return_value + mock_db.scalars.return_value = mock_scalars + + mock_score_set = Mock() + mock_score_set.id = 123 + + variant_urns = pd.Series(index_values) + + # Act & Assert - should not raise any exception + validate_index_existence_in_score_set(mock_db, mock_score_set, variant_urns, index_column_name) + + @pytest.mark.parametrize( + "index_column_name,existing_resources_return_value,index_values", + [ + ( + calibration_variant_column_name, + ["urn:variant:1", "urn:variant:2", "urn:variant:3"], + ["urn:variant:1", "urn:variant:2", "urn:variant:3"], + ), + ( + hgvs_nt_column, + ["NM_000546.5:c.215C>G", "NM_000546.5:c.743G>A", "NM_000546.5:c.999A>T"], + ["NM_000546.5:c.215C>G", "NM_000546.5:c.743G>A", "NM_000546.5:c.999A>T"], + ), + ( + hgvs_pro_column, + ["NP_000537.3:p.Arg72Pro", "NP_000537.3:p.Gly248Trp", "NP_000537.3:p.Ser215Ile"], + ["NP_000537.3:p.Arg72Pro", "NP_000537.3:p.Gly248Trp", "NP_000537.3:p.Ser215Ile"], + ), + ], + ) + def test_validate_calibration_index_existence_database_query_parameters( + self, index_column_name, existing_resources_return_value, index_values + ): + """Test that database query is constructed with correct parameters.""" + mock_db = Mock() + mock_scalars = Mock() + 
mock_scalars.all.return_value = existing_resources_return_value + mock_db.scalars.return_value = mock_scalars + + mock_score_set = Mock() + mock_score_set.id = 999 + + variant_urns = pd.Series(index_values) + + validate_index_existence_in_score_set(mock_db, mock_score_set, variant_urns, index_column_name) + + mock_db.scalars.assert_called_once() + + +class TestValidateCalibrationClasses: + """Test suite for validate_calibration_classes function.""" + + def test_validate_calibration_classes_success(self): + """Test successful validation when all classes match.""" + mock_classification1 = Mock() + mock_classification1.class_ = "class_a" + mock_classification2 = Mock() + mock_classification2.class_ = "class_b" + + calibration = Mock(spec=score_calibration.ScoreCalibrationCreate) + calibration.functional_classifications = [mock_classification1, mock_classification2] + + classes = pd.Series(["class_a", "class_b", "class_a"]) + + validate_calibration_classes(calibration, classes) + + def test_validate_calibration_classes_no_functional_classifications(self): + """Test ValidationError when calibration has no functional classifications.""" + calibration = Mock(spec=score_calibration.ScoreCalibrationCreate) + calibration.functional_classifications = None + classes = pd.Series(["class_a", "class_b"]) + + with pytest.raises( + ValidationError, match="Calibration must have functional classifications defined for class validation." + ): + validate_calibration_classes(calibration, classes) + + def test_validate_calibration_classes_empty_functional_classifications(self): + """Test ValidationError when calibration has empty functional classifications.""" + calibration = Mock(spec=score_calibration.ScoreCalibrationCreate) + calibration.functional_classifications = [] + classes = pd.Series(["class_a", "class_b"]) + + with pytest.raises( + ValidationError, match="Calibration must have functional classifications defined for class validation." + ): + validate_calibration_classes(calibration, classes) + + def test_validate_calibration_classes_undefined_classes_in_series(self): + """Test ValidationError when series contains undefined classes.""" + mock_classification = Mock() + mock_classification.class_ = "class_a" + + calibration = Mock(spec=score_calibration.ScoreCalibrationCreate) + calibration.functional_classifications = [mock_classification] + + classes = pd.Series(["class_a", "class_b", "class_c"]) + + with pytest.raises( + ValidationError, match="The following classes are not defined in the calibration: class_b, class_c" + ): + validate_calibration_classes(calibration, classes) + + def test_validate_calibration_classes_missing_defined_classes(self): + """Test ValidationError when defined classes are missing from series.""" + mock_classification1 = Mock() + mock_classification1.class_ = "class_a" + mock_classification2 = Mock() + mock_classification2.class_ = "class_b" + mock_classification3 = Mock() + mock_classification3.class_ = "class_c" + + calibration = Mock(spec=score_calibration.ScoreCalibrationCreate) + calibration.functional_classifications = [mock_classification1, mock_classification2, mock_classification3] + + classes = pd.Series(["class_a", "class_b"]) + + with pytest.raises( + ValidationError, match="Some defined classes in the calibration are missing from the classes file." 
+ ): + validate_calibration_classes(calibration, classes) + + def test_validate_calibration_classes_with_modify_object(self): + """Test function works with ScoreCalibrationModify object.""" + mock_classification = Mock() + mock_classification.class_ = "class_a" + + calibration = Mock(spec=score_calibration.ScoreCalibrationModify) + calibration.functional_classifications = [mock_classification] + + classes = pd.Series(["class_a"]) + + validate_calibration_classes(calibration, classes) + + def test_validate_calibration_classes_empty_series(self): + """Test ValidationError when classes series is empty but calibration has classifications.""" + mock_classification = Mock() + mock_classification.class_ = "class_a" + + calibration = Mock(spec=score_calibration.ScoreCalibrationCreate) + calibration.functional_classifications = [mock_classification] + + classes = pd.Series([], dtype=object) + + with pytest.raises( + ValidationError, match="Some defined classes in the calibration are missing from the classes file." + ): + validate_calibration_classes(calibration, classes) + + def test_validate_calibration_classes_duplicate_classes_in_series(self): + """Test successful validation with duplicate classes in series.""" + mock_classification1 = Mock() + mock_classification1.class_ = "class_a" + mock_classification2 = Mock() + mock_classification2.class_ = "class_b" + + calibration = Mock(spec=score_calibration.ScoreCalibrationCreate) + calibration.functional_classifications = [mock_classification1, mock_classification2] + + classes = pd.Series(["class_a", "class_a", "class_b", "class_b", "class_a"]) + + validate_calibration_classes(calibration, classes) + + def test_validate_calibration_classes_single_class(self): + """Test successful validation with single class.""" + mock_classification = Mock() + mock_classification.class_ = "single_class" + + calibration = Mock(spec=score_calibration.ScoreCalibrationCreate) + calibration.functional_classifications = [mock_classification] + + classes = pd.Series(["single_class", "single_class"]) + + validate_calibration_classes(calibration, classes) + + +class TestChooseCalibrationIndexColumn: + """Test suite for choose_calibration_index_column function.""" + + def test_choose_variant_column_priority(self): + """Should return the variant column if present and not all NaN.""" + df = pd.DataFrame( + { + calibration_variant_column_name: ["v1", "v2"], + calibration_class_column_name: ["A", "B"], + hgvs_nt_column: [None, None], + hgvs_pro_column: [None, None], + } + ) + result = choose_calibration_index_column(df) + assert result == calibration_variant_column_name + + def test_choose_hgvs_nt_column_if_variant_missing(self): + """Should return hgvs_nt_column if variant column is missing or all NaN.""" + df = pd.DataFrame( + { + hgvs_nt_column: ["c.1A>G", "c.2T>C"], + calibration_class_column_name: ["A", "B"], + } + ) + result = choose_calibration_index_column(df) + assert result == hgvs_nt_column + + def test_choose_hgvs_pro_column_if_variant_and_nt_missing(self): + """Should return hgvs_pro_column if variant and hgvs_nt columns are missing or all NaN.""" + df = pd.DataFrame( + { + hgvs_pro_column: ["p.A1G", "p.T2C"], + calibration_class_column_name: ["A", "B"], + } + ) + result = choose_calibration_index_column(df) + assert result == hgvs_pro_column + + def test_ignores_all_nan_columns(self): + """Should ignore columns that are all NaN when choosing index column.""" + df = pd.DataFrame( + { + calibration_variant_column_name: [float("nan"), float("nan")], + hgvs_nt_column: 
["c.1A>G", "c.2T>C"], + calibration_class_column_name: ["A", "B"], + } + ) + result = choose_calibration_index_column(df) + assert result == hgvs_nt_column + + def test_case_insensitive_column_names(self): + """Should handle column names in different cases.""" + df = pd.DataFrame( + { + calibration_variant_column_name.upper(): ["v1", "v2"], + calibration_class_column_name.capitalize(): ["A", "B"], + } + ) + result = choose_calibration_index_column(df) + assert result == calibration_variant_column_name.upper() + + def test_raises_if_no_valid_index_column(self): + """Should raise ValidationError if no valid index column is found.""" + df = pd.DataFrame( + { + calibration_class_column_name: ["A", "B"], + "other": ["x", "y"], + } + ) + with pytest.raises(ValidationError, match="failed to find valid calibration index column"): + choose_calibration_index_column(df) + + def test_raises_if_all_index_columns_are_nan(self): + """Should raise ValidationError if all possible index columns are all NaN.""" + df = pd.DataFrame( + { + calibration_variant_column_name: [float("nan"), float("nan")], + hgvs_nt_column: [float("nan"), float("nan")], + hgvs_pro_column: [float("nan"), float("nan")], + calibration_class_column_name: ["A", "B"], + } + ) + with pytest.raises(ValidationError, match="failed to find valid calibration index column"): + choose_calibration_index_column(df) diff --git a/tests/validation/dataframe/test_dataframe.py b/tests/validation/dataframe/test_dataframe.py index 4c8334de..daf3fd63 100644 --- a/tests/validation/dataframe/test_dataframe.py +++ b/tests/validation/dataframe/test_dataframe.py @@ -13,6 +13,7 @@ required_score_column, ) from mavedb.lib.validation.dataframe.dataframe import ( + STANDARD_COLUMNS, choose_dataframe_index_column, sort_dataframe_columns, standardize_dataframe, @@ -93,32 +94,36 @@ def test_sort_dataframe_preserves_extras_order(self): class TestStandardizeDataframe(DfTestCase): def test_preserve_standardized(self): - standardized_df = standardize_dataframe(self.dataframe) + standardized_df = standardize_dataframe(self.dataframe, STANDARD_COLUMNS) pd.testing.assert_frame_equal(self.dataframe, standardized_df) def test_standardize_changes_case_variants(self): - standardized_df = standardize_dataframe(self.dataframe.rename(columns={hgvs_nt_column: hgvs_nt_column.upper()})) + standardized_df = standardize_dataframe( + self.dataframe.rename(columns={hgvs_nt_column: hgvs_nt_column.upper()}), STANDARD_COLUMNS + ) pd.testing.assert_frame_equal(self.dataframe, standardized_df) def test_standardize_changes_case_scores(self): standardized_df = standardize_dataframe( - self.dataframe.rename(columns={required_score_column: required_score_column.title()}) + self.dataframe.rename(columns={required_score_column: required_score_column.title()}), STANDARD_COLUMNS ) pd.testing.assert_frame_equal(self.dataframe, standardized_df) def test_standardize_preserves_extras_case(self): - standardized_df = standardize_dataframe(self.dataframe.rename(columns={"extra": "extra".upper()})) + standardized_df = standardize_dataframe( + self.dataframe.rename(columns={"extra": "extra".upper()}), STANDARD_COLUMNS + ) pd.testing.assert_frame_equal(self.dataframe.rename(columns={"extra": "extra".upper()}), standardized_df) def test_standardize_removes_quotes(self): standardized_df = standardize_dataframe( - self.dataframe.rename(columns={"extra": "'extra'", "extra2": '"extra2"'}) + self.dataframe.rename(columns={"extra": "'extra'", "extra2": '"extra2"'}), STANDARD_COLUMNS ) 
pd.testing.assert_frame_equal(self.dataframe, standardized_df) def test_standardize_removes_whitespace(self): standardized_df = standardize_dataframe( - self.dataframe.rename(columns={"extra": " extra ", "extra2": " extra2"}) + self.dataframe.rename(columns={"extra": " extra ", "extra2": " extra2"}), STANDARD_COLUMNS ) pd.testing.assert_frame_equal(self.dataframe, standardized_df) @@ -135,7 +140,8 @@ def test_standardize_sorts_columns(self): "count1", "extra", ], - ] + ], + STANDARD_COLUMNS, ) pd.testing.assert_frame_equal( self.dataframe[ diff --git a/tests/view_models/test_acmg_classification.py b/tests/view_models/test_acmg_classification.py index f7b68149..4640196c 100644 --- a/tests/view_models/test_acmg_classification.py +++ b/tests/view_models/test_acmg_classification.py @@ -1,21 +1,22 @@ -import pytest from copy import deepcopy -from mavedb.lib.exceptions import ValidationError -from mavedb.view_models.acmg_classification import ACMGClassificationCreate, ACMGClassification +import pytest +from mavedb.lib.exceptions import ValidationError +from mavedb.models.enums.acmg_criterion import ACMGCriterion +from mavedb.models.enums.strength_of_evidence import StrengthOfEvidenceProvided +from mavedb.view_models.acmg_classification import ACMGClassification, ACMGClassificationCreate from tests.helpers.constants import ( TEST_ACMG_BS3_STRONG_CLASSIFICATION, - TEST_ACMG_PS3_STRONG_CLASSIFICATION, TEST_ACMG_BS3_STRONG_CLASSIFICATION_WITH_POINTS, + TEST_ACMG_PS3_STRONG_CLASSIFICATION, TEST_ACMG_PS3_STRONG_CLASSIFICATION_WITH_POINTS, TEST_SAVED_ACMG_BS3_STRONG_CLASSIFICATION, - TEST_SAVED_ACMG_PS3_STRONG_CLASSIFICATION, TEST_SAVED_ACMG_BS3_STRONG_CLASSIFICATION_WITH_POINTS, + TEST_SAVED_ACMG_PS3_STRONG_CLASSIFICATION, TEST_SAVED_ACMG_PS3_STRONG_CLASSIFICATION_WITH_POINTS, ) - ### ACMG Classification Creation Tests ### @@ -33,8 +34,8 @@ def test_can_create_acmg_classification(valid_acmg_classification): acmg = ACMGClassificationCreate(**valid_acmg_classification) assert isinstance(acmg, ACMGClassificationCreate) - assert acmg.criterion == valid_acmg_classification.get("criterion") - assert acmg.evidence_strength == valid_acmg_classification.get("evidence_strength") + assert acmg.criterion.value == valid_acmg_classification.get("criterion") + assert acmg.evidence_strength.value == valid_acmg_classification.get("evidence_strength") assert acmg.points == valid_acmg_classification.get("points") @@ -78,8 +79,8 @@ def test_can_create_acmg_classification_from_points(): acmg = ACMGClassificationCreate(points=-4) # BS3 Strong assert isinstance(acmg, ACMGClassificationCreate) - assert acmg.criterion == "BS3" - assert acmg.evidence_strength == "strong" + assert acmg.criterion == ACMGCriterion.BS3 + assert acmg.evidence_strength == StrengthOfEvidenceProvided.STRONG assert acmg.points == -4 @@ -100,6 +101,6 @@ def test_can_create_acmg_classification_from_saved_data(valid_saved_classificati acmg = ACMGClassification(**valid_saved_classification) assert isinstance(acmg, ACMGClassification) - assert acmg.criterion == valid_saved_classification.get("criterion") - assert acmg.evidence_strength == valid_saved_classification.get("evidenceStrength") + assert acmg.criterion.value == valid_saved_classification.get("criterion") + assert acmg.evidence_strength.value == valid_saved_classification.get("evidenceStrength") assert acmg.points == valid_saved_classification.get("points") diff --git a/tests/view_models/test_collection.py b/tests/view_models/test_collection.py new file mode 100644 index 00000000..b22cee2a --- 
/dev/null +++ b/tests/view_models/test_collection.py @@ -0,0 +1,120 @@ +import pytest +from pydantic import ValidationError + +from mavedb.models.enums.contribution_role import ContributionRole +from mavedb.view_models.collection import Collection, SavedCollection +from tests.helpers.constants import TEST_COLLECTION_RESPONSE +from tests.helpers.util.common import dummy_attributed_object_from_dict + + +@pytest.mark.parametrize( + "exclude,expected_missing_fields", + [ + ("user_associations", ["admins", "editors", "viewers"]), + ("score_sets", ["scoreSetUrns"]), + ("experiments", ["experimentUrns"]), + ], +) +def test_cannot_create_saved_collection_without_all_attributed_properties(exclude, expected_missing_fields): + collection = TEST_COLLECTION_RESPONSE.copy() + collection["urn"] = "urn:mavedb:collection-xxx" + + # Remove pre-existing synthetic properties + collection.pop("experimentUrns", None) + collection.pop("scoreSetUrns", None) + collection.pop("admins", None) + collection.pop("editors", None) + collection.pop("viewers", None) + + # Set synthetic properties with dummy attributed objects to mock SQLAlchemy model objects. + collection["experiments"] = [dummy_attributed_object_from_dict({"urn": "urn:mavedb:experiment-xxx"})] + collection["score_sets"] = [ + dummy_attributed_object_from_dict({"urn": "urn:mavedb:score_set-xxx", "superseding_score_set": None}) + ] + collection["user_associations"] = [ + dummy_attributed_object_from_dict( + { + "contribution_role": ContributionRole.admin, + "user": {"id": 1, "username": "test_user", "email": "test_user@example.com"}, + } + ), + dummy_attributed_object_from_dict( + { + "contribution_role": ContributionRole.editor, + "user": {"id": 1, "username": "test_user", "email": "test_user@example.com"}, + } + ), + dummy_attributed_object_from_dict( + { + "contribution_role": ContributionRole.viewer, + "user": {"id": 1, "username": "test_user", "email": "test_user@example.com"}, + } + ), + ] + + collection.pop(exclude) + collection_attributed_object = dummy_attributed_object_from_dict(collection) + with pytest.raises(ValidationError) as exc_info: + SavedCollection.model_validate(collection_attributed_object) + + # Should fail with missing fields coerced from missing attributed properties + msg = str(exc_info.value) + assert "Field required" in msg + for field in expected_missing_fields: + assert field in msg + + +def test_saved_collection_can_be_created_with_all_attributed_properties(): + collection = TEST_COLLECTION_RESPONSE.copy() + urn = "urn:mavedb:collection-xxx" + collection["urn"] = urn + + # Remove pre-existing synthetic properties + collection.pop("experimentUrns", None) + collection.pop("scoreSetUrns", None) + collection.pop("admins", None) + collection.pop("editors", None) + collection.pop("viewers", None) + + # Set synthetic properties with dummy attributed objects to mock SQLAlchemy model objects.
+ collection["experiments"] = [dummy_attributed_object_from_dict({"urn": "urn:mavedb:experiment-xxx"})] + collection["score_sets"] = [ + dummy_attributed_object_from_dict({"urn": "urn:mavedb:score_set-xxx", "superseding_score_set": None}) + ] + collection["user_associations"] = [ + dummy_attributed_object_from_dict( + { + "contribution_role": ContributionRole.admin, + "user": {"id": 1, "username": "test_user", "email": "test_user@example.com"}, + } + ), + dummy_attributed_object_from_dict( + { + "contribution_role": ContributionRole.editor, + "user": {"id": 1, "username": "test_user", "email": "test_user@example.com"}, + } + ), + dummy_attributed_object_from_dict( + { + "contribution_role": ContributionRole.viewer, + "user": {"id": 1, "username": "test_user", "email": "test_user@example.com"}, + } + ), + ] + + collection_attributed_object = dummy_attributed_object_from_dict(collection) + model = SavedCollection.model_validate(collection_attributed_object) + assert model.name == TEST_COLLECTION_RESPONSE["name"] + assert model.urn == urn + assert len(model.admins) == 1 + assert len(model.editors) == 1 + assert len(model.viewers) == 1 + assert len(model.experiment_urns) == 1 + assert len(model.score_set_urns) == 1 + + +def test_collection_can_be_created_from_non_orm_context(): + data = dict(TEST_COLLECTION_RESPONSE) + data["urn"] = "urn:mavedb:collection-xxx" + model = Collection.model_validate(data) + assert model.urn == data["urn"] diff --git a/tests/view_models/test_experiment.py b/tests/view_models/test_experiment.py index 9f0c3e67..aab3c85f 100644 --- a/tests/view_models/test_experiment.py +++ b/tests/view_models/test_experiment.py @@ -1,17 +1,18 @@ import pytest +from pydantic import ValidationError -from mavedb.view_models.experiment import ExperimentCreate, SavedExperiment +from mavedb.view_models.experiment import Experiment, ExperimentCreate, SavedExperiment from mavedb.view_models.publication_identifier import PublicationIdentifier - from tests.helpers.constants import ( - VALID_EXPERIMENT_URN, - VALID_SCORE_SET_URN, - VALID_EXPERIMENT_SET_URN, + SAVED_BIORXIV_PUBLICATION, + SAVED_PUBMED_PUBLICATION, + TEST_BIORXIV_IDENTIFIER, TEST_MINIMAL_EXPERIMENT, TEST_MINIMAL_EXPERIMENT_RESPONSE, - SAVED_PUBMED_PUBLICATION, TEST_PUBMED_IDENTIFIER, - TEST_BIORXIV_IDENTIFIER, + VALID_EXPERIMENT_SET_URN, + VALID_EXPERIMENT_URN, + VALID_SCORE_SET_URN, ) from tests.helpers.util.common import dummy_attributed_object_from_dict @@ -237,8 +238,15 @@ def test_saved_experiment_synthetic_properties(): ) -@pytest.mark.parametrize("exclude", ["publication_identifier_associations", "score_sets", "experiment_set"]) -def test_cannot_create_saved_experiment_without_all_attributed_properties(exclude): +@pytest.mark.parametrize( + "exclude,expected_missing_fields", + [ + ("publication_identifier_associations", ["primaryPublicationIdentifiers", "secondaryPublicationIdentifiers"]), + ("score_sets", ["scoreSetUrns"]), + ("experiment_set", ["experimentSetUrn"]), + ], +) +def test_cannot_create_saved_experiment_without_all_attributed_properties(exclude, expected_missing_fields): experiment = TEST_MINIMAL_EXPERIMENT_RESPONSE.copy() experiment["urn"] = VALID_EXPERIMENT_URN @@ -280,11 +288,14 @@ def test_cannot_create_saved_experiment_without_all_attributed_properties(exclud experiment.pop(exclude) experiment_attributed_object = dummy_attributed_object_from_dict(experiment) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValidationError) as exc_info: 
SavedExperiment.model_validate(experiment_attributed_object) - assert "Unable to create SavedExperiment without attribute" in str(exc_info.value) - assert exclude in str(exc_info.value) + # Should fail with missing fields coerced from missing attributed properties + msg = str(exc_info.value) + assert "Field required" in msg + for field in expected_missing_fields: + assert field in msg def test_can_create_experiment_with_nonetype_experiment_set_urn(): @@ -303,3 +314,20 @@ def test_cant_create_experiment_with_invalid_experiment_set_urn(): ExperimentCreate(**experiment_test) assert f"'{experiment_test['experiment_set_urn']}' is not a valid experiment set URN" in str(exc_info.value) + + +def test_can_create_experiment_from_non_orm_context(): + experiment = TEST_MINIMAL_EXPERIMENT_RESPONSE.copy() + experiment["urn"] = VALID_EXPERIMENT_URN + experiment["experimentSetUrn"] = VALID_EXPERIMENT_SET_URN + experiment["scoreSetUrns"] = [VALID_SCORE_SET_URN] + experiment["primaryPublicationIdentifiers"] = [SAVED_PUBMED_PUBLICATION] + experiment["secondaryPublicationIdentifiers"] = [SAVED_PUBMED_PUBLICATION, SAVED_BIORXIV_PUBLICATION] + + # Should not require any ORM attributes + saved_experiment = Experiment.model_validate(experiment) + assert saved_experiment.urn == VALID_EXPERIMENT_URN + assert saved_experiment.experiment_set_urn == VALID_EXPERIMENT_SET_URN + assert saved_experiment.score_set_urns == [VALID_SCORE_SET_URN] + assert len(saved_experiment.primary_publication_identifiers) == 1 + assert len(saved_experiment.secondary_publication_identifiers) == 2 diff --git a/tests/view_models/test_mapped_variant.py b/tests/view_models/test_mapped_variant.py index 09866219..1a41a7ef 100644 --- a/tests/view_models/test_mapped_variant.py +++ b/tests/view_models/test_mapped_variant.py @@ -1,10 +1,9 @@ import pytest from pydantic import ValidationError -from mavedb.view_models.mapped_variant import MappedVariantCreate, MappedVariant - -from tests.helpers.util.common import dummy_attributed_object_from_dict +from mavedb.view_models.mapped_variant import MappedVariant, MappedVariantCreate from tests.helpers.constants import TEST_MINIMAL_MAPPED_VARIANT, TEST_MINIMAL_MAPPED_VARIANT_CREATE, VALID_VARIANT_URN +from tests.helpers.util.common import dummy_attributed_object_from_dict def test_minimal_mapped_variant_create(): @@ -72,10 +71,32 @@ def test_cannot_create_mapped_variant_without_variant(): MappedVariantCreate(**mapped_variant_create) +def test_can_create_saved_mapped_variant_with_variant_object(): + mapped_variant = TEST_MINIMAL_MAPPED_VARIANT.copy() + mapped_variant["id"] = 1 + + saved_mapped_variant = MappedVariant.model_validate( + dummy_attributed_object_from_dict( + {**mapped_variant, "variant": dummy_attributed_object_from_dict({"urn": VALID_VARIANT_URN})} + ) + ) + + assert all(saved_mapped_variant.__getattribute__(k) == v for k, v in mapped_variant.items()) + assert saved_mapped_variant.variant_urn == VALID_VARIANT_URN + + def test_cannot_save_mapped_variant_without_variant(): mapped_variant = TEST_MINIMAL_MAPPED_VARIANT.copy() mapped_variant["id"] = 1 - mapped_variant["variant"] = dummy_attributed_object_from_dict({"urn": None}) + mapped_variant["variant"] = None with pytest.raises(ValidationError): MappedVariant.model_validate(dummy_attributed_object_from_dict({**mapped_variant})) + + +def test_can_create_mapped_variant_from_non_orm_context(): + mapped_variant_create = TEST_MINIMAL_MAPPED_VARIANT_CREATE.copy() + mapped_variant_create["variant_urn"] = VALID_VARIANT_URN + created_mapped_variant = 
MappedVariantCreate.model_validate(mapped_variant_create) + + assert all(created_mapped_variant.__getattribute__(k) == v for k, v in mapped_variant_create.items()) diff --git a/tests/view_models/test_score_calibration.py b/tests/view_models/test_score_calibration.py index bf89aec4..e96ab010 100644 --- a/tests/view_models/test_score_calibration.py +++ b/tests/view_models/test_score_calibration.py @@ -4,28 +4,34 @@ from pydantic import ValidationError from mavedb.lib.acmg import ACMGCriterion +from mavedb.models.enums.functional_classification import FunctionalClassification as FunctionalClassificationOptions from mavedb.models.enums.score_calibration_relation import ScoreCalibrationRelation from mavedb.view_models.score_calibration import ( - FunctionalRangeCreate, + FunctionalClassificationCreate, ScoreCalibration, ScoreCalibrationCreate, ScoreCalibrationWithScoreSetUrn, ) from tests.helpers.constants import ( - TEST_BRNICH_SCORE_CALIBRATION, + TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED, + TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, + TEST_FUNCTIONAL_CLASSIFICATION_ABNORMAL, + TEST_FUNCTIONAL_CLASSIFICATION_NORMAL, + TEST_FUNCTIONAL_CLASSIFICATION_NOT_SPECIFIED, TEST_FUNCTIONAL_RANGE_ABNORMAL, TEST_FUNCTIONAL_RANGE_INCLUDING_NEGATIVE_INFINITY, TEST_FUNCTIONAL_RANGE_INCLUDING_POSITIVE_INFINITY, TEST_FUNCTIONAL_RANGE_NORMAL, TEST_FUNCTIONAL_RANGE_NOT_SPECIFIED, TEST_PATHOGENICITY_SCORE_CALIBRATION, - TEST_SAVED_BRNICH_SCORE_CALIBRATION, + TEST_SAVED_BRNICH_SCORE_CALIBRATION_CLASS_BASED, + TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_SAVED_PATHOGENICITY_SCORE_CALIBRATION, ) from tests.helpers.util.common import dummy_attributed_object_from_dict ############################################################################## -# Tests for FunctionalRange view models +# Tests for FunctionalClassification view models ############################################################################## @@ -33,38 +39,78 @@ @pytest.mark.parametrize( - "functional_range", + "functional_classification", [ TEST_FUNCTIONAL_RANGE_NORMAL, TEST_FUNCTIONAL_RANGE_ABNORMAL, TEST_FUNCTIONAL_RANGE_NOT_SPECIFIED, + TEST_FUNCTIONAL_CLASSIFICATION_NORMAL, + TEST_FUNCTIONAL_CLASSIFICATION_ABNORMAL, + TEST_FUNCTIONAL_CLASSIFICATION_NOT_SPECIFIED, TEST_FUNCTIONAL_RANGE_INCLUDING_POSITIVE_INFINITY, TEST_FUNCTIONAL_RANGE_INCLUDING_NEGATIVE_INFINITY, ], ) -def test_can_create_valid_functional_range(functional_range): - fr = FunctionalRangeCreate.model_validate(functional_range) +def test_can_create_valid_functional_classification(functional_classification): + fr = FunctionalClassificationCreate.model_validate(functional_classification) - assert fr.label == functional_range["label"] - assert fr.description == functional_range.get("description") - assert fr.classification == functional_range["classification"] - assert fr.range == tuple(functional_range["range"]) - assert fr.inclusive_lower_bound == functional_range.get("inclusive_lower_bound", True) - assert fr.inclusive_upper_bound == functional_range.get("inclusive_upper_bound", False) + assert fr.label == functional_classification["label"] + assert fr.description == functional_classification.get("description") + assert fr.functional_classification.value == functional_classification["functional_classification"] + assert fr.inclusive_lower_bound == functional_classification.get("inclusive_lower_bound") + assert fr.inclusive_upper_bound == functional_classification.get("inclusive_upper_bound") + if "range" in functional_classification: + assert fr.range == 
tuple(functional_classification["range"]) + assert fr.range_based is True + assert fr.class_based is False + elif "class" in functional_classification: + assert fr.class_ == functional_classification["class"] + assert fr.range_based is False + assert fr.class_based is True -def test_cannot_create_functional_range_with_reversed_range(): + +@pytest.mark.parametrize( + "property_name", + [ + "label", + "class", + ], +) +def test_cannot_create_functional_classification_when_string_fields_empty(property_name): + invalid_data = deepcopy(TEST_FUNCTIONAL_CLASSIFICATION_NORMAL) + invalid_data[property_name] = " " + with pytest.raises(ValidationError, match="This field may not be empty or contain only whitespace."): + FunctionalClassificationCreate.model_validate(invalid_data) + + +def test_cannot_create_functional_classification_without_range_or_class(): + invalid_data = deepcopy(TEST_FUNCTIONAL_RANGE_NORMAL) + invalid_data["range"] = None + invalid_data["class"] = None + with pytest.raises(ValidationError, match="A functional range must specify either a numeric range or a class."): + FunctionalClassificationCreate.model_validate(invalid_data) + + +def test_cannot_create_functional_classification_with_both_range_and_class(): + invalid_data = deepcopy(TEST_FUNCTIONAL_RANGE_NORMAL) + invalid_data["class"] = "some_class" + with pytest.raises(ValidationError, match="A functional range may not specify both a numeric range and a class."): + FunctionalClassificationCreate.model_validate(invalid_data) + + +def test_cannot_create_functional_classification_with_reversed_range(): invalid_data = deepcopy(TEST_FUNCTIONAL_RANGE_NORMAL) invalid_data["range"] = (2, 1) with pytest.raises(ValidationError, match="The lower bound cannot exceed the upper bound."): - FunctionalRangeCreate.model_validate(invalid_data) + FunctionalClassificationCreate.model_validate(invalid_data) -def test_cannot_create_functional_range_with_equal_bounds(): +def test_cannot_create_functional_classification_with_equal_bounds(): invalid_data = deepcopy(TEST_FUNCTIONAL_RANGE_NORMAL) invalid_data["range"] = (1, 1) with pytest.raises(ValidationError, match="The lower and upper bounds cannot be identical."): - FunctionalRangeCreate.model_validate(invalid_data) + FunctionalClassificationCreate.model_validate(invalid_data) def test_can_create_range_with_infinity_bounds(): @@ -73,71 +119,71 @@ def test_can_create_range_with_infinity_bounds(): valid_data["inclusive_upper_bound"] = False valid_data["range"] = (None, None) - fr = FunctionalRangeCreate.model_validate(valid_data) + fr = FunctionalClassificationCreate.model_validate(valid_data) assert fr.range == (None, None) @pytest.mark.parametrize("ratio_property", ["oddspaths_ratio", "positive_likelihood_ratio"]) -def test_cannot_create_functional_range_with_negative_ratios(ratio_property): +def test_cannot_create_functional_classification_with_negative_ratios(ratio_property): invalid_data = deepcopy(TEST_FUNCTIONAL_RANGE_NORMAL) invalid_data[ratio_property] = -1.0 with pytest.raises(ValidationError, match="The ratio must be greater than or equal to 0."): - FunctionalRangeCreate.model_validate(invalid_data) + FunctionalClassificationCreate.model_validate(invalid_data) -def test_cannot_create_functional_range_with_inclusive_bounds_at_infinity(): +def test_cannot_create_functional_classification_with_inclusive_bounds_at_infinity(): invalid_data = deepcopy(TEST_FUNCTIONAL_RANGE_INCLUDING_POSITIVE_INFINITY) invalid_data["inclusive_upper_bound"] = True with pytest.raises(ValidationError, match="An 
inclusive upper bound may not include positive infinity."): - FunctionalRangeCreate.model_validate(invalid_data) + FunctionalClassificationCreate.model_validate(invalid_data) invalid_data = deepcopy(TEST_FUNCTIONAL_RANGE_INCLUDING_NEGATIVE_INFINITY) invalid_data["inclusive_lower_bound"] = True with pytest.raises(ValidationError, match="An inclusive lower bound may not include negative infinity."): - FunctionalRangeCreate.model_validate(invalid_data) + FunctionalClassificationCreate.model_validate(invalid_data) @pytest.mark.parametrize( - "functional_range, opposite_criterion", + "functional_classification, opposite_criterion", [(TEST_FUNCTIONAL_RANGE_NORMAL, ACMGCriterion.PS3), (TEST_FUNCTIONAL_RANGE_ABNORMAL, ACMGCriterion.BS3)], ) -def test_cannot_create_functional_range_when_classification_disagrees_with_acmg_criterion( - functional_range, opposite_criterion +def test_cannot_create_functional_classification_when_classification_disagrees_with_acmg_criterion( + functional_classification, opposite_criterion ): - invalid_data = deepcopy(functional_range) + invalid_data = deepcopy(functional_classification) invalid_data["acmg_classification"]["criterion"] = opposite_criterion.value with pytest.raises(ValidationError, match="must agree with the functional range classification"): - FunctionalRangeCreate.model_validate(invalid_data) + FunctionalClassificationCreate.model_validate(invalid_data) def test_none_type_classification_and_evidence_strength_count_as_agreement(): valid_data = deepcopy(TEST_FUNCTIONAL_RANGE_NORMAL) valid_data["acmg_classification"] = {"criterion": None, "evidence_strength": None} - fr = FunctionalRangeCreate.model_validate(valid_data) + fr = FunctionalClassificationCreate.model_validate(valid_data) assert fr.acmg_classification.criterion is None assert fr.acmg_classification.evidence_strength is None -def test_cannot_create_functional_range_when_oddspaths_evidence_disagrees_with_classification(): +def test_cannot_create_functional_classification_when_oddspaths_evidence_disagrees_with_classification(): invalid_data = deepcopy(TEST_FUNCTIONAL_RANGE_NORMAL) # Abnormal evidence strength for a normal range invalid_data["oddspaths_ratio"] = 350 with pytest.raises(ValidationError, match="implies criterion"): - FunctionalRangeCreate.model_validate(invalid_data) + FunctionalClassificationCreate.model_validate(invalid_data) invalid_data = deepcopy(TEST_FUNCTIONAL_RANGE_ABNORMAL) # Normal evidence strength for an abnormal range invalid_data["oddspaths_ratio"] = 0.1 with pytest.raises(ValidationError, match="implies criterion"): - FunctionalRangeCreate.model_validate(invalid_data) + FunctionalClassificationCreate.model_validate(invalid_data) def test_is_contained_by_range(): - fr = FunctionalRangeCreate.model_validate( + fr = FunctionalClassificationCreate.model_validate( { "label": "test range", - "classification": "abnormal", + "functional_classification": FunctionalClassificationOptions.abnormal, "range": (0.0, 1.0), "inclusive_lower_bound": True, "inclusive_upper_bound": True, @@ -156,6 +202,58 @@ def test_is_contained_by_range(): assert not fr.is_contained_by_range(0.0), "0.0 (exclusive lower bound) should not be contained in the range" +def test_inclusive_bounds_get_default_when_unset_and_range_exists(): + fr = FunctionalClassificationCreate.model_validate( + { + "label": "test range", + "functional_classification": FunctionalClassificationOptions.abnormal, + "range": (0.0, 1.0), + } + ) + + assert fr.inclusive_lower_bound is True, "inclusive_lower_bound should default to True" 
+ assert fr.inclusive_upper_bound is False, "inclusive_upper_bound should default to False" + + +def test_inclusive_bounds_remain_none_when_range_is_none(): + fr = FunctionalClassificationCreate.model_validate( + { + "label": "test range", + "functional_classification": FunctionalClassificationOptions.abnormal, + "class": "some_class", + } + ) + + assert fr.inclusive_lower_bound is None, "inclusive_lower_bound should remain None" + assert fr.inclusive_upper_bound is None, "inclusive_upper_bound should remain None" + + +@pytest.mark.parametrize( + "bound_property, bound_value, match_text", + [ + ( + "inclusive_lower_bound", + True, + "An inclusive lower bound may not be set on a class based functional classification.", + ), + ( + "inclusive_upper_bound", + True, + "An inclusive upper bound may not be set on a class based functional classification.", + ), + ], +) +def test_cant_set_inclusive_bounds_when_range_is_none(bound_property, bound_value, match_text): + invalid_data = { + "label": "test range", + "functional_classification": FunctionalClassificationOptions.abnormal, + "class": "some_class", + bound_property: bound_value, + } + with pytest.raises(ValidationError, match=match_text): + FunctionalClassificationCreate.model_validate(invalid_data) + + ############################################################################## # Tests for ScoreCalibration view models ############################################################################## @@ -165,7 +263,11 @@ def test_is_contained_by_range(): @pytest.mark.parametrize( "valid_calibration", - [TEST_BRNICH_SCORE_CALIBRATION, TEST_PATHOGENICITY_SCORE_CALIBRATION], + [ + TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, + TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED, + TEST_PATHOGENICITY_SCORE_CALIBRATION, + ], ) def test_can_create_valid_score_calibration(valid_calibration): sc = ScoreCalibrationCreate.model_validate(valid_calibration) @@ -175,11 +277,11 @@ def test_can_create_valid_score_calibration(valid_calibration): assert sc.baseline_score == valid_calibration.get("baseline_score") assert sc.baseline_score_description == valid_calibration.get("baseline_score_description") - if valid_calibration.get("functional_ranges") is not None: - assert len(sc.functional_ranges) == len(valid_calibration["functional_ranges"]) + if valid_calibration.get("functional_classifications") is not None: + assert len(sc.functional_classifications) == len(valid_calibration["functional_classifications"]) # functional range validation is presumed to be well tested separately. else: - assert sc.functional_ranges is None + assert sc.functional_classifications is None if valid_calibration.get("threshold_sources") is not None: assert len(sc.threshold_sources) == len(valid_calibration["threshold_sources"]) @@ -212,11 +314,11 @@ def test_can_create_valid_score_calibration(valid_calibration): # because of the large number of model validators that need to play nice with this case. 
@pytest.mark.parametrize( "valid_calibration", - [TEST_BRNICH_SCORE_CALIBRATION, TEST_PATHOGENICITY_SCORE_CALIBRATION], + [TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED, TEST_PATHOGENICITY_SCORE_CALIBRATION], ) -def test_can_create_valid_score_calibration_without_functional_ranges(valid_calibration): +def test_can_create_valid_score_calibration_without_functional_classifications(valid_calibration): valid_calibration = deepcopy(valid_calibration) - valid_calibration["functional_ranges"] = None + valid_calibration["functional_classifications"] = None sc = ScoreCalibrationCreate.model_validate(valid_calibration) @@ -225,11 +327,11 @@ def test_can_create_valid_score_calibration_without_functional_ranges(valid_cali assert sc.baseline_score == valid_calibration.get("baseline_score") assert sc.baseline_score_description == valid_calibration.get("baseline_score_description") - if valid_calibration.get("functional_ranges") is not None: - assert len(sc.functional_ranges) == len(valid_calibration["functional_ranges"]) + if valid_calibration.get("functional_classifications") is not None: + assert len(sc.functional_classifications) == len(valid_calibration["functional_classifications"]) # functional range validation is presumed to be well tested separately. else: - assert sc.functional_ranges is None + assert sc.functional_classifications is None if valid_calibration.get("threshold_sources") is not None: assert len(sc.threshold_sources) == len(valid_calibration["threshold_sources"]) @@ -259,50 +361,65 @@ def test_can_create_valid_score_calibration_without_functional_ranges(valid_cali def test_cannot_create_score_calibration_when_classification_ranges_overlap(): - invalid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION) + invalid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) # Make the first two ranges overlap - invalid_data["functional_ranges"][0]["range"] = [1.0, 3.0] - invalid_data["functional_ranges"][1]["range"] = [2.0, 4.0] + invalid_data["functional_classifications"][0]["range"] = [1.0, 3.0] + invalid_data["functional_classifications"][1]["range"] = [2.0, 4.0] with pytest.raises(ValidationError, match="Classified score ranges may not overlap; `"): ScoreCalibrationCreate.model_validate(invalid_data) def test_can_create_score_calibration_when_unclassified_ranges_overlap_with_classified_ranges(): - valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION) + valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) # Make the first two ranges overlap, one being 'not_specified' - valid_data["functional_ranges"][0]["range"] = [1.5, 3.0] - valid_data["functional_ranges"][1]["range"] = [2.0, 4.0] - valid_data["functional_ranges"][0]["classification"] = "not_specified" + valid_data["functional_classifications"][0]["range"] = [1.5, 3.0] + valid_data["functional_classifications"][1]["range"] = [2.0, 4.0] + valid_data["functional_classifications"][0]["functional_classification"] = ( + FunctionalClassificationOptions.not_specified + ) sc = ScoreCalibrationCreate.model_validate(valid_data) - assert len(sc.functional_ranges) == len(valid_data["functional_ranges"]) + assert len(sc.functional_classifications) == len(valid_data["functional_classifications"]) def test_can_create_score_calibration_when_unclassified_ranges_overlap_with_each_other(): - valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION) + valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) # Make the first two ranges overlap, both being 'not_specified' - valid_data["functional_ranges"][0]["range"] = [1.5, 3.0] - 
valid_data["functional_ranges"][1]["range"] = [2.0, 4.0] - valid_data["functional_ranges"][0]["classification"] = "not_specified" - valid_data["functional_ranges"][1]["classification"] = "not_specified" + valid_data["functional_classifications"][0]["range"] = [1.5, 3.0] + valid_data["functional_classifications"][1]["range"] = [2.0, 4.0] + valid_data["functional_classifications"][0]["functional_classification"] = ( + FunctionalClassificationOptions.not_specified + ) + valid_data["functional_classifications"][1]["functional_classification"] = ( + FunctionalClassificationOptions.not_specified + ) sc = ScoreCalibrationCreate.model_validate(valid_data) - assert len(sc.functional_ranges) == len(valid_data["functional_ranges"]) + assert len(sc.functional_classifications) == len(valid_data["functional_classifications"]) def test_cannot_create_score_calibration_when_ranges_touch_with_inclusive_ranges(): - invalid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION) + invalid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) # Make the first two ranges touch - invalid_data["functional_ranges"][0]["range"] = [1.0, 2.0] - invalid_data["functional_ranges"][1]["range"] = [2.0, 4.0] - invalid_data["functional_ranges"][0]["inclusive_upper_bound"] = True + invalid_data["functional_classifications"][0]["range"] = [1.0, 2.0] + invalid_data["functional_classifications"][1]["range"] = [2.0, 4.0] + invalid_data["functional_classifications"][0]["inclusive_upper_bound"] = True with pytest.raises(ValidationError, match="Classified score ranges may not overlap; `"): ScoreCalibrationCreate.model_validate(invalid_data) def test_cannot_create_score_calibration_with_duplicate_range_labels(): - invalid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION) + invalid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + # Make the first two ranges have the same label + invalid_data["functional_classifications"][0]["label"] = "duplicate label" + invalid_data["functional_classifications"][1]["label"] = "duplicate label" + with pytest.raises(ValidationError, match="Functional range labels must be unique"): + ScoreCalibrationCreate.model_validate(invalid_data) + + +def test_cannot_create_score_calibration_with_duplicate_range_classes(): + invalid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED) # Make the first two ranges have the same label - invalid_data["functional_ranges"][0]["label"] = "duplicate label" - invalid_data["functional_ranges"][1]["label"] = "duplicate label" + invalid_data["functional_classifications"][0]["label"] = "duplicate label" + invalid_data["functional_classifications"][1]["label"] = "duplicate label" with pytest.raises(ValidationError, match="Functional range labels must be unique"): ScoreCalibrationCreate.model_validate(invalid_data) @@ -310,7 +427,7 @@ def test_cannot_create_score_calibration_with_duplicate_range_labels(): # Making an exception to usually not testing the ability to create models without optional fields, # since model validators sometimes rely on their absence. 
def test_can_create_score_calibration_without_baseline_score(): - valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION) + valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) valid_data["baseline_score"] = None sc = ScoreCalibrationCreate.model_validate(valid_data) @@ -318,7 +435,7 @@ def test_can_create_score_calibration_without_baseline_score(): def test_can_create_score_calibration_with_baseline_score_when_outside_all_ranges(): - valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION) + valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) valid_data["baseline_score"] = 10.0 sc = ScoreCalibrationCreate.model_validate(valid_data) @@ -326,7 +443,7 @@ def test_can_create_score_calibration_with_baseline_score_when_outside_all_range def test_can_create_score_calibration_with_baseline_score_when_inside_normal_range(): - valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION) + valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) valid_data["baseline_score"] = 3.0 sc = ScoreCalibrationCreate.model_validate(valid_data) @@ -334,7 +451,7 @@ def test_can_create_score_calibration_with_baseline_score_when_inside_normal_ran def test_cannot_create_score_calibration_with_baseline_score_when_inside_non_normal_range(): - invalid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION) + invalid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED) invalid_data["baseline_score"] = -3.0 with pytest.raises(ValueError, match="Baseline scores may not fall within non-normal ranges"): ScoreCalibrationCreate.model_validate(invalid_data) @@ -345,7 +462,11 @@ def test_cannot_create_score_calibration_with_baseline_score_when_inside_non_nor @pytest.mark.parametrize( "valid_calibration", - [TEST_SAVED_BRNICH_SCORE_CALIBRATION, TEST_SAVED_PATHOGENICITY_SCORE_CALIBRATION], + [ + TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED, + TEST_SAVED_BRNICH_SCORE_CALIBRATION_CLASS_BASED, + TEST_SAVED_PATHOGENICITY_SCORE_CALIBRATION, + ], ) def test_can_create_valid_score_calibration_from_attributed_object(valid_calibration): sc = ScoreCalibration.model_validate(dummy_attributed_object_from_dict(valid_calibration)) @@ -357,11 +478,11 @@ def test_can_create_valid_score_calibration_from_attributed_object(valid_calibra assert sc.baseline_score == valid_calibration.get("baselineScore") assert sc.baseline_score_description == valid_calibration.get("baselineScoreDescription") - if valid_calibration.get("functionalRanges") is not None: - assert len(sc.functional_ranges) == len(valid_calibration["functionalRanges"]) + if valid_calibration.get("functionalClassifications") is not None: + assert len(sc.functional_classifications) == len(valid_calibration["functionalClassifications"]) # functional range validation is presumed to be well tested separately. 
else: - assert sc.functional_ranges is None + assert sc.functional_classifications is None if valid_calibration.get("thresholdSources") is not None: assert len(sc.threshold_sources) == len(valid_calibration["thresholdSources"]) @@ -391,17 +512,24 @@ def test_can_create_valid_score_calibration_from_attributed_object(valid_calibra def test_cannot_create_score_calibration_when_publication_information_is_missing(): - invalid_data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION) + invalid_data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + # Remove publication identifier sources to simulate missing information invalid_data.pop("thresholdSources", None) invalid_data.pop("classificationSources", None) invalid_data.pop("methodSources", None) - with pytest.raises(ValidationError, match="Unable to create ScoreCalibration without attribute"): + + with pytest.raises(ValidationError) as exc_info: ScoreCalibration.model_validate(dummy_attributed_object_from_dict(invalid_data)) + assert "Field required" in str(exc_info.value) + assert "thresholdSources" in str(exc_info.value) + assert "classificationSources" in str(exc_info.value) + assert "methodSources" in str(exc_info.value) + def test_can_create_score_calibration_from_association_style_publication_identifiers_against_attributed_object(): - orig_data = TEST_SAVED_BRNICH_SCORE_CALIBRATION + orig_data = TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED data = deepcopy(orig_data) threshold_sources = [ @@ -431,11 +559,11 @@ def test_can_create_score_calibration_from_association_style_publication_identif assert sc.baseline_score == orig_data.get("baselineScore") assert sc.baseline_score_description == orig_data.get("baselineScoreDescription") - if orig_data.get("functionalRanges") is not None: - assert len(sc.functional_ranges) == len(orig_data["functionalRanges"]) + if orig_data.get("functionalClassifications") is not None: + assert len(sc.functional_classifications) == len(orig_data["functionalClassifications"]) # functional range validation is presumed to be well tested separately.
else: - assert sc.functional_ranges is None + assert sc.functional_classifications is None if orig_data.get("thresholdSources") is not None: assert len(sc.threshold_sources) == len(orig_data["thresholdSources"]) @@ -465,7 +593,7 @@ def test_can_create_score_calibration_from_association_style_publication_identif def test_primary_score_calibration_cannot_be_research_use_only(): - invalid_data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION) + invalid_data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED) invalid_data["primary"] = True invalid_data["researchUseOnly"] = True with pytest.raises(ValidationError, match="Primary score calibrations may not be marked as research use only"): @@ -473,15 +601,33 @@ def test_primary_score_calibration_cannot_be_research_use_only(): def test_primary_score_calibration_cannot_be_private(): - invalid_data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION) + invalid_data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED) invalid_data["primary"] = True invalid_data["private"] = True with pytest.raises(ValidationError, match="Primary score calibrations may not be marked as private"): ScoreCalibration.model_validate(dummy_attributed_object_from_dict(invalid_data)) +def test_can_create_score_calibration_from_non_orm_context(): + data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED) + + sc = ScoreCalibration.model_validate(data) + + assert sc.title == data["title"] + assert sc.research_use_only == data.get("researchUseOnly", False) + assert sc.primary == data.get("primary", False) + assert sc.investigator_provided == data.get("investigatorProvided", False) + assert sc.baseline_score == data.get("baselineScore") + assert sc.baseline_score_description == data.get("baselineScoreDescription") + assert len(sc.functional_classifications) == len(data["functionalClassifications"]) + assert len(sc.threshold_sources) == len(data["thresholdSources"]) + assert len(sc.classification_sources) == len(data["classificationSources"]) + assert len(sc.method_sources) == len(data["methodSources"]) + assert sc.calibration_metadata == data.get("calibrationMetadata") + + def test_score_calibration_with_score_set_urn_can_be_created_from_attributed_object(): - data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION) + data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED) data["score_set"] = dummy_attributed_object_from_dict({"urn": "urn:mavedb:00000000-0000-0000-0000-000000000001"}) sc = ScoreCalibrationWithScoreSetUrn.model_validate(dummy_attributed_object_from_dict(data)) @@ -491,7 +637,61 @@ def test_score_calibration_with_score_set_urn_can_be_created_from_attributed_obj def test_score_calibration_with_score_set_urn_cannot_be_created_without_score_set_urn(): - invalid_data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION) + invalid_data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED) invalid_data["score_set"] = dummy_attributed_object_from_dict({}) - with pytest.raises(ValidationError, match="Unable to create ScoreCalibrationWithScoreSetUrn without attribute"): + with pytest.raises(ValidationError, match="Unable to coerce score set urn for ScoreCalibrationWithScoreSetUrn"): ScoreCalibrationWithScoreSetUrn.model_validate(dummy_attributed_object_from_dict(invalid_data)) + + +def test_cannot_create_score_calibration_with_mixed_range_and_class_based_functional_classifications(): + """Test that score calibrations cannot have both range-based and class-based functional classifications.""" + invalid_data = 
deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED)
+    # Add a class-based functional classification to a range-based calibration
+    invalid_data["functional_classifications"].append(
+        {
+            "label": "class based classification",
+            "functional_classification": FunctionalClassificationOptions.abnormal,
+            "class": "some_class",
+        }
+    )
+
+    with pytest.raises(
+        ValidationError, match="All functional classifications within a score calibration must be of the same type"
+    ):
+        ScoreCalibrationCreate.model_validate(invalid_data)
+
+
+def test_score_calibration_range_based_property():
+    """Test the range_based property works correctly."""
+    range_based_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED)
+    sc = ScoreCalibrationCreate.model_validate(range_based_data)
+    assert sc.range_based is True
+    assert sc.class_based is False
+
+
+def test_score_calibration_class_based_property():
+    """Test the class_based property works correctly."""
+    class_based_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_CLASS_BASED)
+    sc = ScoreCalibrationCreate.model_validate(class_based_data)
+    assert sc.class_based is True
+    assert sc.range_based is False
+
+
+def test_score_calibration_properties_when_no_functional_classifications():
+    """Test that properties return False when no functional classifications exist."""
+    valid_data = deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED)
+    valid_data["functional_classifications"] = None
+
+    sc = ScoreCalibrationCreate.model_validate(valid_data)
+    assert sc.range_based is False
+    assert sc.class_based is False
+
+
+def test_score_calibration_with_score_set_urn_can_be_created_from_non_orm_context():
+    data = deepcopy(TEST_SAVED_BRNICH_SCORE_CALIBRATION_RANGE_BASED)
+    data["score_set_urn"] = "urn:mavedb:00000000-0000-0000-0000-000000000001"
+
+    sc = ScoreCalibrationWithScoreSetUrn.model_validate(data)
+
+    assert sc.title == data["title"]
+    assert sc.score_set_urn == data["score_set_urn"]
diff --git a/tests/view_models/test_score_set.py b/tests/view_models/test_score_set.py
index 754b8657..03698974 100644
--- a/tests/view_models/test_score_set.py
+++ b/tests/view_models/test_score_set.py
@@ -3,20 +3,27 @@
 import pytest
 
 from mavedb.view_models.publication_identifier import PublicationIdentifier, PublicationIdentifierCreate
-from mavedb.view_models.score_set import SavedScoreSet, ScoreSetCreate, ScoreSetModify, ScoreSetUpdateAllOptional
+from mavedb.view_models.score_set import (
+    SavedScoreSet,
+    ScoreSet,
+    ScoreSetCreate,
+    ScoreSetModify,
+    ScoreSetUpdateAllOptional,
+)
 from mavedb.view_models.target_gene import SavedTargetGene, TargetGeneCreate
 from tests.helpers.constants import (
     EXTRA_LICENSE,
     EXTRA_USER,
     SAVED_PUBMED_PUBLICATION,
     TEST_BIORXIV_IDENTIFIER,
-    TEST_BRNICH_SCORE_CALIBRATION,
+    TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED,
     TEST_CROSSREF_IDENTIFIER,
     TEST_MINIMAL_ACC_SCORESET,
     TEST_MINIMAL_SEQ_SCORESET,
     TEST_MINIMAL_SEQ_SCORESET_RESPONSE,
     TEST_PATHOGENICITY_SCORE_CALIBRATION,
     TEST_PUBMED_IDENTIFIER,
+    VALID_EXPERIMENT_SET_URN,
     VALID_EXPERIMENT_URN,
     VALID_SCORE_SET_URN,
     VALID_TMP_URN,
@@ -231,7 +238,7 @@ def test_cannot_create_score_set_with_an_empty_method():
 
 
 @pytest.mark.parametrize(
-    "calibration", [deepcopy(TEST_BRNICH_SCORE_CALIBRATION), deepcopy(TEST_PATHOGENICITY_SCORE_CALIBRATION)]
+    "calibration", [deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED), deepcopy(TEST_PATHOGENICITY_SCORE_CALIBRATION)]
 )
 def test_can_create_score_set_with_complete_and_valid_provided_calibrations(calibration):
     score_set_test = TEST_MINIMAL_SEQ_SCORESET.copy()
@@ -247,8 +254,8 @@ def test_can_create_score_set_with_multiple_valid_calibrations():
     score_set_test = TEST_MINIMAL_SEQ_SCORESET.copy()
     score_set_test["experiment_urn"] = VALID_EXPERIMENT_URN
     score_set_test["score_calibrations"] = [
-        deepcopy(TEST_BRNICH_SCORE_CALIBRATION),
-        deepcopy(TEST_BRNICH_SCORE_CALIBRATION),
+        deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED),
+        deepcopy(TEST_BRNICH_SCORE_CALIBRATION_RANGE_BASED),
         deepcopy(TEST_PATHOGENICITY_SCORE_CALIBRATION),
     ]
 
@@ -372,10 +379,14 @@ def test_score_set_update_all_optional(attribute, updated_data):
 
 
 @pytest.mark.parametrize(
-    "exclude",
-    ["publication_identifier_associations", "meta_analyzes_score_sets", "meta_analyzed_by_score_sets"],
+    "exclude,expected_missing_fields",
+    [
+        ("publication_identifier_associations", ["primaryPublicationIdentifiers", "secondaryPublicationIdentifiers"]),
+        ("meta_analyzes_score_sets", ["metaAnalyzesScoreSetUrns"]),
+        ("meta_analyzed_by_score_sets", ["metaAnalyzedByScoreSetUrns"]),
+    ],
 )
-def test_cannot_create_saved_score_set_without_all_attributed_properties(exclude):
+def test_cannot_create_saved_score_set_without_all_attributed_properties(exclude, expected_missing_fields):
     score_set = TEST_MINIMAL_SEQ_SCORESET_RESPONSE.copy()
     score_set["urn"] = "urn:score-set-xxx"
 
@@ -429,8 +440,9 @@ def test_cannot_create_saved_score_set_without_all_attributed_properties(exclude
     with pytest.raises(ValueError) as exc_info:
         SavedScoreSet.model_validate(score_set_attributed_object)
 
-    assert "Unable to create SavedScoreSet without attribute" in str(exc_info.value)
-    assert exclude in str(exc_info.value)
+    assert "Field required" in str(exc_info.value)
+    for exclude_field in expected_missing_fields:
+        assert exclude_field in str(exc_info.value)
 
 
 def test_can_create_score_set_with_none_type_superseded_score_set_urn():
@@ -543,3 +555,14 @@ def test_cant_create_score_set_without_experiment_urn_if_not_meta_analysis():
         ScoreSetCreate(**score_set_test)
 
     assert "experiment URN is required unless your score set is a meta-analysis" in str(exc_info.value)
+
+
+def test_can_create_score_set_from_non_orm_context():
+    score_set_test = TEST_MINIMAL_SEQ_SCORESET_RESPONSE.copy()
+    score_set_test["urn"] = "urn:score-set-xxx"
+    score_set_test["experiment"]["urn"] = VALID_EXPERIMENT_URN
+    score_set_test["experiment"]["experimentSetUrn"] = VALID_EXPERIMENT_SET_URN
+
+    saved_score_set = ScoreSet.model_validate(score_set_test)
+
+    assert saved_score_set.urn == "urn:score-set-xxx"
diff --git a/tests/view_models/test_target_gene.py b/tests/view_models/test_target_gene.py
index 32ae4f30..71b497b9 100644
--- a/tests/view_models/test_target_gene.py
+++ b/tests/view_models/test_target_gene.py
@@ -1,12 +1,12 @@
 import pytest
 
-from mavedb.view_models.target_gene import TargetGeneCreate, SavedTargetGene
+from mavedb.view_models.target_gene import SavedTargetGene, TargetGene, TargetGeneCreate, TargetGeneWithScoreSetUrn
 from tests.helpers.constants import (
     SEQUENCE,
+    TEST_ENSEMBLE_EXTERNAL_IDENTIFIER,
     TEST_POPULATED_TAXONOMY,
-    TEST_SAVED_TAXONOMY,
     TEST_REFSEQ_EXTERNAL_IDENTIFIER,
-    TEST_ENSEMBLE_EXTERNAL_IDENTIFIER,
+    TEST_SAVED_TAXONOMY,
     TEST_UNIPROT_EXTERNAL_IDENTIFIER,
 )
 from tests.helpers.util.common import dummy_attributed_object_from_dict
@@ -200,3 +200,107 @@ def test_cannot_create_saved_target_without_seq_or_acc():
         SavedTargetGene.model_validate(target_gene)
 
     assert "Either a `target_sequence` or `target_accession` is required" in str(exc_info.value)
+
+
+def test_saved_target_gene_can_be_created_from_orm():
+    orm_obj = dummy_attributed_object_from_dict(
+        {
+            "id": 1,
+            "name": "UBE2I",
+            "category": "regulatory",
+            "ensembl_offset": dummy_attributed_object_from_dict(
+                {"offset": 1, "identifier": dummy_attributed_object_from_dict(TEST_ENSEMBLE_EXTERNAL_IDENTIFIER)}
+            ),
+            "refseq_offset": None,
+            "uniprot_offset": None,
+            "target_sequence": dummy_attributed_object_from_dict(
+                {
+                    "sequenceType": "dna",
+                    "sequence": SEQUENCE,
+                    "taxonomy": TEST_SAVED_TAXONOMY,
+                }
+            ),
+            "target_accession": None,
+            "record_type": "target_gene",
+            "uniprot_id_from_mapped_metadata": None,
+        }
+    )
+    model = SavedTargetGene.model_validate(orm_obj)
+    assert model.name == "UBE2I"
+    assert model.external_identifiers[0].identifier.identifier == "ENSG00000103275"
+
+
+def test_target_gene_with_score_set_urn_can_be_created_from_orm():
+    orm_obj = dummy_attributed_object_from_dict(
+        {
+            "id": 1,
+            "name": "UBE2I",
+            "category": "regulatory",
+            "ensembl_offset": dummy_attributed_object_from_dict(
+                {
+                    "offset": 1,
+                    "identifier": dummy_attributed_object_from_dict(TEST_ENSEMBLE_EXTERNAL_IDENTIFIER),
+                }
+            ),
+            "refseq_offset": None,
+            "uniprot_offset": None,
+            "target_sequence": dummy_attributed_object_from_dict(
+                {
+                    "sequenceType": "dna",
+                    "sequence": SEQUENCE,
+                    "taxonomy": TEST_SAVED_TAXONOMY,
+                }
+            ),
+            "target_accession": None,
+            "record_type": "target_gene",
+            "uniprot_id_from_mapped_metadata": None,
+            "score_set": dummy_attributed_object_from_dict({"urn": "urn:mavedb:01234567-a-1"}),
+        }
+    )
+    model = TargetGeneWithScoreSetUrn.model_validate(orm_obj)
+    assert model.name == "UBE2I"
+    assert model.score_set_urn == "urn:mavedb:01234567-a-1"
+
+
+def test_target_gene_can_be_created_from_non_orm_context():
+    # Minimal valid dict for TargetGene (must have target_sequence or target_accession)
+    data = {
+        "id": 1,
+        "name": "UBE2I",
+        "category": "regulatory",
+        "external_identifiers": [{"identifier": TEST_ENSEMBLE_EXTERNAL_IDENTIFIER, "offset": 1}],
+        "target_sequence": {
+            "sequenceType": "dna",
+            "sequence": SEQUENCE,
+            "taxonomy": TEST_SAVED_TAXONOMY,
+        },
+        "target_accession": None,
+        "record_type": "target_gene",
+        "uniprot_id_from_mapped_metadata": None,
+    }
+    model = TargetGene.model_validate(data)
+    assert model.name == data["name"]
+    assert model.category == data["category"]
+    assert model.external_identifiers[0].identifier.identifier == "ENSG00000103275"
+
+
+def test_target_gene_with_score_set_urn_can_be_created_from_dict():
+    # Minimal valid dict for TargetGeneWithScoreSetUrn (must have target_sequence or target_accession)
+    data = {
+        "id": 1,
+        "name": "UBE2I",
+        "category": "regulatory",
+        "external_identifiers": [{"identifier": TEST_ENSEMBLE_EXTERNAL_IDENTIFIER, "offset": 1}],
+        "target_sequence": {
+            "sequenceType": "dna",
+            "sequence": SEQUENCE,
+            "taxonomy": TEST_SAVED_TAXONOMY,
+        },
+        "target_accession": None,
+        "record_type": "target_gene",
+        "uniprot_id_from_mapped_metadata": None,
+        "score_set_urn": "urn:mavedb:01234567-a-1",
+    }
+    model = TargetGeneWithScoreSetUrn.model_validate(data)
+    assert model.name == data["name"]
+    assert model.score_set_urn == data["score_set_urn"]
diff --git a/tests/view_models/test_variant.py b/tests/view_models/test_variant.py
index 9ec2d2f3..200eca9f 100644
--- a/tests/view_models/test_variant.py
+++ b/tests/view_models/test_variant.py
@@ -1,7 +1,15 @@
-from mavedb.view_models.variant import VariantEffectMeasurementCreate, VariantEffectMeasurement
-
+from mavedb.view_models.variant import (
+    SavedVariantEffectMeasurementWithMappedVariant,
+    VariantEffectMeasurement,
+    VariantEffectMeasurementCreate,
+)
+from tests.helpers.constants import (
+    TEST_MINIMAL_MAPPED_VARIANT,
+    TEST_MINIMAL_VARIANT,
+    TEST_POPULATED_VARIANT,
+    TEST_SAVED_VARIANT,
+)
 from tests.helpers.util.common import dummy_attributed_object_from_dict
-from tests.helpers.constants import TEST_MINIMAL_VARIANT, TEST_POPULATED_VARIANT, TEST_SAVED_VARIANT
 
 
 def test_minimal_variant_create():
@@ -19,3 +27,51 @@ def test_saved_variant():
         dummy_attributed_object_from_dict({**TEST_SAVED_VARIANT, "score_set_id": 1})
     )
     assert all(variant.__getattribute__(k) == v for k, v in TEST_SAVED_VARIANT.items())
+
+
+def test_can_create_saved_variant_with_mapping_with_all_attributed_properties():
+    variant = TEST_SAVED_VARIANT.copy()
+    variant["score_set_id"] = 1
+    variant["mapped_variants"] = [
+        dummy_attributed_object_from_dict(
+            {
+                **TEST_MINIMAL_MAPPED_VARIANT,
+                "id": 1,
+                "variant": dummy_attributed_object_from_dict({"urn": "urn:mavedb:variant-xxx"}),
+            }
+        )
+    ]
+    variant_attributed_object = dummy_attributed_object_from_dict(variant)
+    saved_variant = SavedVariantEffectMeasurementWithMappedVariant.model_validate(variant_attributed_object)
+    assert saved_variant.mapped_variant is not None
+    assert saved_variant.mapped_variant.variant_urn == "urn:mavedb:variant-xxx"
+    for k, v in TEST_SAVED_VARIANT.items():
+        assert saved_variant.__getattribute__(k) == v
+
+
+# Missing attributed properties here are unproblematic, as they are optional on the view model.
+def test_can_create_saved_variant_with_mapping_with_missing_attributed_properties():
+    variant = TEST_SAVED_VARIANT.copy()
+    variant.pop("mapped_variants", None)
+    variant["score_set_id"] = 1
+
+    variant_attributed_object = dummy_attributed_object_from_dict(variant)
+    saved_variant = SavedVariantEffectMeasurementWithMappedVariant.model_validate(variant_attributed_object)
+    for k, v in TEST_SAVED_VARIANT.items():
+        assert saved_variant.__getattribute__(k) == v
+
+
+def test_can_create_saved_variant_with_mapping_from_non_orm_context():
+    variant = TEST_SAVED_VARIANT.copy()
+    variant["score_set_id"] = 1
+    variant["mapped_variant"] = {
+        **TEST_MINIMAL_MAPPED_VARIANT,
+        "id": 1,
+        "variant_urn": "urn:mavedb:variant-xxx",
+    }
+
+    saved_variant = SavedVariantEffectMeasurementWithMappedVariant.model_validate(variant)
+    assert saved_variant.mapped_variant is not None
+    assert saved_variant.mapped_variant.variant_urn == "urn:mavedb:variant-xxx"
+    for k, v in TEST_SAVED_VARIANT.items():
+        assert saved_variant.__getattribute__(k) == v