diff --git a/api/db/database.py b/api/db/database.py index 7943947..a117252 100644 --- a/api/db/database.py +++ b/api/db/database.py @@ -1,10 +1,12 @@ from sqlmodel import create_engine, Session +import os -DATABASE_URL = "sqlite:///./fireform.db" +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +DATABASE_URL = f"sqlite:///{os.path.join(BASE_DIR, 'fireform.db')}" engine = create_engine( DATABASE_URL, - echo=True, + echo=False, connect_args={"check_same_thread": False}, ) diff --git a/api/db/init_db.py b/api/db/init_db.py index 9ad27ea..c868db4 100644 --- a/api/db/init_db.py +++ b/api/db/init_db.py @@ -1,6 +1,6 @@ from sqlmodel import SQLModel -from api.db.database import engine -from api.db import models +from database import engine +import models def init_db(): SQLModel.metadata.create_all(engine) diff --git a/api/db/models.py b/api/db/models.py index f76c93b..bbbbdff 100644 --- a/api/db/models.py +++ b/api/db/models.py @@ -1,10 +1,11 @@ -from sqlmodel import SQLModel, Field +from sqlmodel import SQLModel, Field, UniqueConstraint from sqlalchemy import Column, JSON from datetime import datetime +from enum import Enum class Template(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) - name: str + name: str = Field(unique=True) fields: dict = Field(sa_column=Column(JSON)) pdf_path: str created_at: datetime = Field(default_factory=datetime.utcnow) @@ -15,4 +16,39 @@ class FormSubmission(SQLModel, table=True): template_id: int input_text: str output_pdf_path: str - created_at: datetime = Field(default_factory=datetime.utcnow) \ No newline at end of file + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class ReportSchema(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str = Field(unique=True) + description: str + use_case: str + created_at: datetime = Field(default_factory=datetime.utcnow) + +class ReportSchemaTemplate(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + template_id: int + report_schema_id: int + field_mapping: dict = Field(default={}, sa_column=Column(JSON)) + + __table_args__ = (UniqueConstraint("template_id", "report_schema_id"),) + +class Datatype(str, Enum): + STRING = "string" + INT = "int" + DATE = "date" + ENUM = 'enum' + + +class SchemaField(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + report_schema_id: int + field_name: str + source_template_id: int + description: str = Field(default="") + data_type: Datatype = Field(default=Datatype.STRING) + word_limit: int | None = Field(default=None) + required: bool = Field(default=False) + allowed_values: dict | None = Field(sa_column=Column(JSON)) + canonical_name: str | None = Field(default=None) diff --git a/api/db/repositories.py b/api/db/repositories.py index 6608718..2a59e21 100644 --- a/api/db/repositories.py +++ b/api/db/repositories.py @@ -1,19 +1,285 @@ +from ast import For +from collections import defaultdict +from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select -from api.db.models import Template, FormSubmission +from api.db.models import ( + Template, + FormSubmission, + ReportSchema, + ReportSchemaTemplate, + SchemaField, +) -# Templates def create_template(session: Session, template: Template) -> Template: + try: + session.add(template) + session.commit() + session.refresh(template) + return template + except IntegrityError: + raise + +def get_template(session: Session, template_id: int) -> Template | None: + return session.get(Template, template_id) + +def update_template(session: Session, template_id: int, updates: dict) -> Template | None: + template = session.get(Template, template_id) + if not template: + return None + for key, value in updates.items(): + setattr(template, key, value) session.add(template) session.commit() session.refresh(template) return template -def get_template(session: Session, template_id: int) -> Template | None: - return session.get(Template, template_id) +def list_templates(session: Session) -> list[Template]: + return session.exec(select(Template)).all() + +def delete_template(session: Session, template_id: int) -> bool: + """Remove template and dependent rows (form submissions, schema links, schema fields).""" + template = session.get(Template, template_id) + if not template: + return False + + for form in session.exec( + select(FormSubmission).where(FormSubmission.template_id == template_id) + ).all(): + session.delete(form) + + for junction in session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.template_id == template_id + ) + ).all(): + for field in session.exec( + select(SchemaField).where( + SchemaField.report_schema_id == junction.report_schema_id, + SchemaField.source_template_id == template_id, + ) + ).all(): + session.delete(field) + session.delete(junction) + + session.delete(template) + session.commit() + return True -# Forms def create_form(session: Session, form: FormSubmission) -> FormSubmission: session.add(form) session.commit() session.refresh(form) - return form \ No newline at end of file + return form + +def get_form(session: Session, form_id: int) -> FormSubmission: + return session.get(FormSubmission, form_id) + +def update_form(session: Session, form_id: int, updates: dict) -> FormSubmission | None: + form = session.get(FormSubmission, form_id) + if not form: + return None + for key, value in updates.items(): + setattr(form, key, value) + session.add(form) + session.commit() + session.refresh(form) + return form + +def delete_form(session: Session, form_id: int) -> FormSubmission: + form_submission = session.get(FormSubmission, form_id) + if form_submission: + session.delete(form_submission) + session.commit() + return True + return False + +def create_report_schema(session: Session, schema: ReportSchema) -> ReportSchema: + try: + session.add(schema) + session.commit() + session.refresh(schema) + return schema + except IntegrityError: + raise + +def get_report_schema(session: Session, schema_id: int) -> ReportSchema | None: + return session.get(ReportSchema, schema_id) + +def list_report_schemas(session: Session) -> list[ReportSchema]: + return session.exec(select(ReportSchema)).all() + +def update_report_schema(session: Session, schema_id: int, updates: dict) -> ReportSchema | None: + schema = session.get(ReportSchema, schema_id) + if not schema: + return None + for key, value in updates.items(): + setattr(schema, key, value) + session.add(schema) + session.commit() + session.refresh(schema) + return schema + +def delete_report_schema(session: Session, schema_id: int) -> bool: + schema = session.get(ReportSchema, schema_id) + if not schema: + return False + + fields = session.exec( + select(SchemaField).where(SchemaField.report_schema_id == schema_id) + ).all() + for field in fields: + session.delete(field) + + junctions = session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.report_schema_id == schema_id + ) + ).all() + for junction in junctions: + session.delete(junction) + + session.delete(schema) + session.commit() + return True + + +def add_template_to_schema( + session: Session, schema_id: int, template_id: int +) -> ReportSchemaTemplate: + """Associate a template with a schema. + + Looks up `template.fields` and auto-creates a SchemaField for each field, + pre-populated with `field_name` and `source_template_id`. + Other metadata is left as defaults for the user to fill in later. + """ + template = session.get(Template, template_id) + if not template: + raise ValueError(f"Template {template_id} not found") + + schema = session.get(ReportSchema, schema_id) + if not schema: + raise ValueError(f"ReportSchema {schema_id} not found") + + # exists = session.exec(select(ReportSchemaTemplate).where(ReportSchemaTemplate.report_schema_id == schema_id, ReportSchemaTemplate.template_id == template_id)).first() + # if exists: + # raise IntegrityError + + # Create the junction record (field_mapping starts empty, populated during canonization) + junction = ReportSchemaTemplate( + report_schema_id=schema_id, + template_id=template_id, + ) + + session.add(junction) + + # Auto-create a SchemaField for each field in the template + for field_name in template.fields: + schema_field = SchemaField( + report_schema_id=schema_id, + field_name=field_name, + source_template_id=template_id, + ) + session.add(schema_field) + + session.commit() + session.refresh(junction) + return junction + +def remove_template_from_schema( + session: Session, schema_id: int, template_id: int +) -> bool: + """Disassociate a template from a schema and remove its SchemaField entries.""" + junction = session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.report_schema_id == schema_id, + ReportSchemaTemplate.template_id == template_id, + ) + ).first() + if not junction: + return False + + fields = session.exec( + select(SchemaField).where( + SchemaField.report_schema_id == schema_id, + SchemaField.source_template_id == template_id, + ) + ).all() + for field in fields: + session.delete(field) + + session.delete(junction) + session.commit() + return True + + +def get_schema_fields(session: Session, schema_id: int) -> list[SchemaField]: + return session.exec( + select(SchemaField).where(SchemaField.report_schema_id == schema_id) + ).all() + +def get_schema_field(session: Session, field_id: int) -> SchemaField: + return session.get(SchemaField, field_id) + +def update_schema_field(session: Session, schema_id: int, field_id: int, updates: dict) -> SchemaField | None: + """Update field metadata: description, data_type, word_limit, required, allowed_values. + + Validates that the field belongs to the given schema before updating, + so the same template field in different schemas can have independent metadata. + """ + field = session.get(SchemaField, field_id) + if not field or field.report_schema_id != schema_id: + return None + for key, value in updates.items(): + setattr(field, key, value) + session.add(field) + session.commit() + session.refresh(field) + return field + + +# ── Template mapping (post-canonization) ───────────────────────────────────── + +def update_template_mapping( + session: Session, schema_id: int, template_id: int +) -> ReportSchemaTemplate | None: + """Auto-generate and store the canonical → PDF field mapping after canonization. + + Builds the mapping by looking up all SchemaFields for this schema+template pair + and mapping each field's canonical_name → field_name. + """ + junction = session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.report_schema_id == schema_id, + ReportSchemaTemplate.template_id == template_id, + ) + ).first() + if not junction: + return None + + # Build mapping from SchemaFields that have been canonized + fields = session.exec( + select(SchemaField).where( + SchemaField.report_schema_id == schema_id, + SchemaField.source_template_id == template_id, + ) + ).all() + + grouped: defaultdict[str, list[str]] = defaultdict(list) + for field in sorted(fields, key=lambda f: f.field_name): + key = field.canonical_name if field.canonical_name else field.field_name + grouped[key].append(field.field_name) + + # One PDF field -> store str; several sharing a canonical -> list (distribute handles both). + field_mapping: dict = {} + for key, names in grouped.items(): + field_mapping[key] = names[0] if len(names) == 1 else names + + junction.field_mapping = field_mapping + session.add(junction) + session.commit() + session.refresh(junction) + return junction + +def get_field_mapping(session: Session, schema_id: int, template_id: int) -> ReportSchemaTemplate: + junction = session.exec(select(ReportSchemaTemplate).where(ReportSchemaTemplate.report_schema_id == schema_id, ReportSchemaTemplate.template_id == template_id)).first() + return junction.field_mapping \ No newline at end of file diff --git a/api/schemas/report_class.py b/api/schemas/report_class.py new file mode 100644 index 0000000..ebc3776 --- /dev/null +++ b/api/schemas/report_class.py @@ -0,0 +1,84 @@ +from pydantic import BaseModel +from datetime import datetime +from api.db.models import Datatype + + +class ReportSchemaCreate(BaseModel): + name: str + description: str + use_case: str + +class ReportSchemaUpdate(BaseModel): + name: str | None = None + description: str | None = None + use_case: str | None = None + +class TemplateAssociation(BaseModel): + template_id: int + +class ReportFill(BaseModel): + input_text: str + +class ReportFillResponse(BaseModel): + schema_id: int + input_text: str + output_pdf_paths: list[str] + +class SchemaFieldUpdate(BaseModel): + description: str | None = None + data_type: Datatype | None = None + word_limit: int | None = None + required: bool | None = None + allowed_values: dict | None = None + canonical_name: str | None = None + + +class SchemaFieldResponse(BaseModel): + id: int + report_schema_id: int + field_name: str + source_template_id: int + description: str + data_type: Datatype + word_limit: int | None + required: bool + allowed_values: dict | None + canonical_name: str | None + + class Config: + from_attributes = True + +class TemplateInSchema(BaseModel): + id: int + template_id: int + report_schema_id: int + field_mapping: dict + + class Config: + from_attributes = True + +class ReportSchemaResponse(BaseModel): + id: int + name: str + description: str + use_case: str + created_at: datetime + templates: list[TemplateInSchema] = [] + fields: list[SchemaFieldResponse] = [] + + class Config: + from_attributes = True + + +class CanonicalFieldEntry(BaseModel): + canonical_name: str + description: str + data_type: Datatype + word_limit: int | None + required: bool + allowed_values: dict | None + source_fields: list[SchemaFieldResponse] + +class CanonicalSchema(BaseModel): + report_schema_id: int + canonical_fields: list[CanonicalFieldEntry] diff --git a/client b/client new file mode 160000 index 0000000..529f7ff --- /dev/null +++ b/client @@ -0,0 +1 @@ +Subproject commit 529f7ffe06ae778ee90d5bb0ad86a6bb91cf94fb diff --git a/tests/unit/test_repositories.py b/tests/unit/test_repositories.py new file mode 100644 index 0000000..13ddb6e --- /dev/null +++ b/tests/unit/test_repositories.py @@ -0,0 +1,346 @@ +import sys +from pathlib import Path + +import pytest +from sqlmodel import SQLModel, Session, create_engine, select + +sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) + +from api.db.models import Datatype, FormSubmission, ReportSchema, ReportSchemaTemplate, SchemaField, Template +from api.db.repositories import ( + add_template_to_schema, + create_form, + create_report_schema, + create_template, + delete_form, + delete_report_schema, + delete_template, + get_form, + get_report_schema, + get_schema_fields, + get_template, + list_report_schemas, + remove_template_from_schema, + update_form, + update_report_schema, + update_schema_field, + update_template, + update_template_mapping, +) + + +test_engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False}) + + +@pytest.fixture(name="session") +def session_fixture(): + SQLModel.metadata.create_all(test_engine) + with Session(test_engine) as session: + yield session + SQLModel.metadata.drop_all(test_engine) + + +def _mk_schema(session: Session, name: str = "schema") -> ReportSchema: + return create_report_schema(session, ReportSchema(name=name, description=f"{name}-desc", use_case=f"{name}-use")) + + +def _mk_template(session: Session, name: str = "template", fields: dict | None = None) -> Template: + return create_template( + session, + Template(name=name, fields=fields if fields is not None else {"f1": "v1"}, pdf_path=f"{name}.pdf"), + ) + + +def test_create_get_update_and_delete_template(session: Session): + created = _mk_template(session, "t-main", {"a": "b"}) + + # test that creation is accurate + assert created.id is not None + assert created.name == "t-main" + assert created.fields == {"a": "b"} + assert created.pdf_path == "t-main.pdf" + + fetched = get_template(session, created.id) + + # test whether the fetched and created templates match + assert fetched is not None + assert fetched.id == created.id + assert fetched.name == "t-main" + assert fetched.fields == {"a": "b"} + assert fetched.pdf_path == "t-main.pdf" + + # test whether updates are persistent and are done correctly + _ = update_template(session, fetched.id, { "name" : "updated-name", "fields" :{"ua" : "ub"}, "pdf_path" : "t-updated.pdf"}) + updated = get_template(session, fetched.id) + assert updated is not None + assert updated.id == created.id + assert updated.name == "updated-name" + assert updated.fields == {"ua": "ub"} + assert updated.pdf_path == "t-updated.pdf" + + # test that deleting works and double deleting does not work + assert delete_template(session, fetched.id) is True + assert delete_template(session, fetched.id) is False + + # test that getting a template that does not exist does not work + assert get_template(session, 999999) is None + + +def test_delete_template_cascades_forms_and_schema_links(session: Session): + schema = _mk_schema(session, "s-cascade") + tpl = _mk_template(session, "t-cascade", {"a": "string", "b": "string"}) + add_template_to_schema(session, schema_id=schema.id, template_id=tpl.id) + assert len(get_schema_fields(session, schema.id)) == 2 + + form = create_form( + session, + FormSubmission( + template_id=tpl.id, + input_text="hi", + output_pdf_path="/out.pdf", + ), + ) + + assert delete_template(session, tpl.id) is True + assert get_template(session, tpl.id) is None + assert get_schema_fields(session, schema.id) == [] + assert get_form(session, form.id) is None + assert ( + session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.template_id == tpl.id + ) + ).first() + is None + ) + + +def test_create_get_update_and_delete_form_submission(session: Session): + form = FormSubmission(template_id=123, input_text="sample input", output_pdf_path="/tmp/out.pdf") + created = create_form(session, form) + + # test creation of form is correct + assert created.id is not None + assert created.template_id == 123 + assert created.input_text == "sample input" + assert created.output_pdf_path == "/tmp/out.pdf" + + fetched = get_form(session, created.id) + + # test whether the fetched and created forms match + assert fetched.id == created.id + assert fetched.template_id == 123 + assert fetched.input_text == "sample input" + assert fetched.output_pdf_path == "/tmp/out.pdf" + + # test whether updates are persistent and are done correctly + _ = update_form(session, fetched.id, { "template_id" : 321, "input_text" : "input sample", "output_pdf_path" : "t-updated.pdf"}) + updated = get_form(session, fetched.id) + assert updated is not None + assert updated.id == created.id + assert updated.template_id == 321 + assert updated.input_text == "input sample" + assert updated.output_pdf_path == "t-updated.pdf" + + # test that deletion works and double deletion does not work + assert delete_form(session, fetched.id) is True + assert delete_form(session, fetched.id) is False + + # test that getting a template that does not exist does not work + assert get_form(session, 999999) is None + + +def test_create_get_list_update_and_delete_report_schema(session: Session): + s1 = _mk_schema(session, "s1") + s2 = _mk_schema(session, "s2") + + # implicitly tests schema creation and directly tests fetching + fetched = get_report_schema(session, s1.id) + assert fetched is not None + assert fetched.id == s1.id + assert fetched.name == "s1" + assert fetched.description == "s1-desc" + assert fetched.use_case == "s1-use" + + # test getting a template that does not exist does not work + assert get_report_schema(session, 999999) is None + + # test listing all schemas works correctly + listed = list_report_schemas(session) + assert {s.name for s in listed} == {"s1", "s2"} + + # test updating a schema works correctly + updated = update_report_schema(session, s1.id, {"name": "s1-new", "use_case": "u-new"}) + assert updated is not None + assert updated.name == "s1-new" + assert updated.description == "s1-desc" + assert updated.use_case == "u-new" + + # test updating a schema that does not exist does not work + assert update_report_schema(session, 999999, {"name": "x"}) is None + + +def test_add_template_to_schema_creates_junction_and_schema_fields(session: Session): + schema = _mk_schema(session) + template = _mk_template(session, fields={"field1": "x", "field2": "y"}) + + _ = add_template_to_schema(session, schema.id, template.id) + junction = session.get(ReportSchemaTemplate, _.id) + + # assert that the junction was created and created with the correct details + assert junction is not None + assert junction.report_schema_id == schema.id + assert junction.template_id == template.id + + fields = get_schema_fields(session, schema.id) + + # assert correct number of fields were created + assert len(fields) == 2 + + # assert all fields were created using correct details + assert {f.field_name for f in fields} == {"field1", "field2"} + assert {f.source_template_id for f in fields} == {template.id} + assert {f.report_schema_id for f in fields} == {schema.id} + +def test_delete_report_schema_deletes_schema_fields_and_junctions(session: Session): + schema = _mk_schema(session, "cascade") + t1 = _mk_template(session, "t1", {"f1": "v1", "f2": "v2"}) + t2 = _mk_template(session, "t2", {"f3": "v3"}) + add_template_to_schema(session, schema.id, t1.id) + add_template_to_schema(session, schema.id, t2.id) + + assert len(get_schema_fields(session, schema.id)) == 3 + assert session.query(ReportSchemaTemplate).count() == 2 + + # test deletion works correctly and fields as well as juncitions are deleted + assert delete_report_schema(session, schema.id) is True + assert get_report_schema(session, schema.id) is None + assert get_schema_fields(session, schema.id) == [] + assert session.query(ReportSchemaTemplate).count() == 0 + + # test double deletion and deleting schemas that do not exist + assert delete_report_schema(session, schema.id) is False + assert delete_report_schema(session, 424242) is False + + +def test_add_template_to_schema_supports_empty_template_fields(session: Session): + schema = _mk_schema(session, "empty-fields") + template = _mk_template(session, "empty-template", fields={}) + + junction = add_template_to_schema(session, schema.id, template.id) + + assert junction.id is not None + assert get_schema_fields(session, schema.id) == [] + +def test_add_template_to_schema_raises_for_missing_template_or_schema(session: Session): + schema = _mk_schema(session, "schema-only") + template = _mk_template(session, "template-only") + + with pytest.raises(ValueError, match="Template .* not found"): + add_template_to_schema(session, schema.id, 999999) + + with pytest.raises(ValueError, match="ReportSchema .* not found"): + add_template_to_schema(session, 999999, template.id) + + +def test_add_template_to_schema_duplicate_association_creates_extra_rows(session: Session): + schema = _mk_schema(session, "dup-schema") + template = _mk_template(session, "dup-template", {"f1": "v1"}) + + add_template_to_schema(session, schema.id, template.id) + add_template_to_schema(session, schema.id, template.id) + + assert session.query(ReportSchemaTemplate).count() == 2 + fields = get_schema_fields(session, schema.id) + assert len(fields) == 2 + assert all(field.field_name == "f1" for field in fields) + + +def test_remove_template_from_schema_removes_only_target_template_rows(session: Session): + schema = _mk_schema(session, "remove") + t1 = _mk_template(session, "t1", {"a": "1"}) + t2 = _mk_template(session, "t2", {"b": "2"}) + add_template_to_schema(session, schema.id, t1.id) + add_template_to_schema(session, schema.id, t2.id) + + assert remove_template_from_schema(session, schema.id, t1.id) is True + + remaining_fields = get_schema_fields(session, schema.id) + assert len(remaining_fields) == 1 + assert remaining_fields[0].field_name == "b" + assert remaining_fields[0].source_template_id == t2.id + assert remove_template_from_schema(session, schema.id, t1.id) is False + assert remove_template_from_schema(session, 101010, 202020) is False + + +def test_get_schema_fields_returns_fields_for_only_given_schema(session: Session): + s1 = _mk_schema(session, "s1") + s2 = _mk_schema(session, "s2") + t1 = _mk_template(session, "t1", {"f1": "v1", "f2": "v2"}) + t2 = _mk_template(session, "t2", {"x": "y"}) + add_template_to_schema(session, s1.id, t1.id) + add_template_to_schema(session, s2.id, t2.id) + + s1_fields = get_schema_fields(session, s1.id) + assert len(s1_fields) == 2 + assert {f.field_name for f in s1_fields} == {"f1", "f2"} + + +def test_update_schema_field_updates_all_supported_metadata(session: Session): + schema = _mk_schema(session, "meta") + template = _mk_template(session, "meta-t", {"status": "draft"}) + add_template_to_schema(session, schema.id, template.id) + field = get_schema_fields(session, schema.id)[0] + + updates = { + "description": "Status of the workflow", + "data_type": Datatype.ENUM, + "word_limit": 3, + "required": True, + "allowed_values": {"values": ["draft", "final"]}, + "canonical_name": "status_canonical", + } + updated = update_schema_field(session, schema.id, field.id, updates) + assert updated is not None + + refreshed = session.get(SchemaField, field.id) + assert refreshed is not None + assert refreshed.description == updates["description"] + assert refreshed.data_type == updates["data_type"] + assert refreshed.word_limit == updates["word_limit"] + assert refreshed.required is True + assert refreshed.allowed_values == updates["allowed_values"] + assert refreshed.canonical_name == updates["canonical_name"] + + +def test_update_schema_field_returns_none_for_missing_or_mismatched_field(session: Session): + s1 = _mk_schema(session, "s1") + s2 = _mk_schema(session, "s2") + t = _mk_template(session, "t", {"f1": "v1"}) + add_template_to_schema(session, s1.id, t.id) + field = get_schema_fields(session, s1.id)[0] + + assert update_schema_field(session, s2.id, field.id, {"description": "x"}) is None + assert update_schema_field(session, s1.id, 999999, {"description": "x"}) is None + + +def test_update_template_mapping_uses_canonical_name_or_fallback_field_name(session: Session): + schema = _mk_schema(session, "mapping") + template = _mk_template(session, "mapping-t", {"f1": "v1", "f2": "v2"}) + add_template_to_schema(session, schema.id, template.id) + fields = sorted(get_schema_fields(session, schema.id), key=lambda f: f.field_name) + + update_schema_field(session, schema.id, fields[0].id, {"canonical_name": "canon_f1"}) + # fields[1] intentionally left without canonical_name to test fallback + + junction = update_template_mapping(session, schema.id, template.id) + assert junction is not None + assert junction.field_mapping == {"canon_f1": "f1", "f2": "f2"} + + +def test_update_template_mapping_returns_none_when_junction_missing(session: Session): + schema = _mk_schema(session, "missing-junction") + template = _mk_template(session, "missing-junction-t") + + # No call to add_template_to_schema, so no junction exists. + assert update_template_mapping(session, schema.id, template.id) is None