From 67604175945b3dfedab9f825eddb268938803960 Mon Sep 17 00:00:00 2001 From: Caleb Muthama Date: Mon, 30 Mar 2026 13:32:44 +0300 Subject: [PATCH] feat:robust report structure management class --- .gitignore | 3 +- api/db/database.py | 6 +- api/db/init_db.py | 4 +- api/db/models.py | 42 +++++- api/db/repositories.py | 278 ++++++++++++++++++++++++++++++++++- api/main.py | 22 ++- api/routes/report_schemas.py | 149 +++++++++++++++++++ api/routes/templates.py | 129 +++++++++++++++- api/schemas/report_class.py | 84 +++++++++++ api/schemas/templates.py | 10 +- requirements.txt | 1 + src/controller.py | 24 ++- src/file_manipulator.py | 73 ++++++++- src/filler.py | 47 +++++- src/pdf_utils.py | 17 +++ src/report_schema.py | 129 ++++++++++++++++ 16 files changed, 982 insertions(+), 36 deletions(-) create mode 100644 api/routes/report_schemas.py create mode 100644 api/schemas/report_class.py create mode 100644 src/pdf_utils.py create mode 100644 src/report_schema.py diff --git a/.gitignore b/.gitignore index 7fa2022..359e827 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .idea venv .venv -*.db \ No newline at end of file +*.db +template_files/ \ No newline at end of file diff --git a/api/db/database.py b/api/db/database.py index 7943947..a117252 100644 --- a/api/db/database.py +++ b/api/db/database.py @@ -1,10 +1,12 @@ from sqlmodel import create_engine, Session +import os -DATABASE_URL = "sqlite:///./fireform.db" +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +DATABASE_URL = f"sqlite:///{os.path.join(BASE_DIR, 'fireform.db')}" engine = create_engine( DATABASE_URL, - echo=True, + echo=False, connect_args={"check_same_thread": False}, ) diff --git a/api/db/init_db.py b/api/db/init_db.py index 9ad27ea..c868db4 100644 --- a/api/db/init_db.py +++ b/api/db/init_db.py @@ -1,6 +1,6 @@ from sqlmodel import SQLModel -from api.db.database import engine -from api.db import models +from database import engine +import models def init_db(): SQLModel.metadata.create_all(engine) diff --git a/api/db/models.py b/api/db/models.py index f76c93b..bbbbdff 100644 --- a/api/db/models.py +++ b/api/db/models.py @@ -1,10 +1,11 @@ -from sqlmodel import SQLModel, Field +from sqlmodel import SQLModel, Field, UniqueConstraint from sqlalchemy import Column, JSON from datetime import datetime +from enum import Enum class Template(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) - name: str + name: str = Field(unique=True) fields: dict = Field(sa_column=Column(JSON)) pdf_path: str created_at: datetime = Field(default_factory=datetime.utcnow) @@ -15,4 +16,39 @@ class FormSubmission(SQLModel, table=True): template_id: int input_text: str output_pdf_path: str - created_at: datetime = Field(default_factory=datetime.utcnow) \ No newline at end of file + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class ReportSchema(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str = Field(unique=True) + description: str + use_case: str + created_at: datetime = Field(default_factory=datetime.utcnow) + +class ReportSchemaTemplate(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + template_id: int + report_schema_id: int + field_mapping: dict = Field(default={}, sa_column=Column(JSON)) + + __table_args__ = (UniqueConstraint("template_id", "report_schema_id"),) + +class Datatype(str, Enum): + STRING = "string" + INT = "int" + DATE = "date" + ENUM = 'enum' + + +class SchemaField(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + report_schema_id: int + field_name: str + source_template_id: int + description: str = Field(default="") + data_type: Datatype = Field(default=Datatype.STRING) + word_limit: int | None = Field(default=None) + required: bool = Field(default=False) + allowed_values: dict | None = Field(sa_column=Column(JSON)) + canonical_name: str | None = Field(default=None) diff --git a/api/db/repositories.py b/api/db/repositories.py index 6608718..2a59e21 100644 --- a/api/db/repositories.py +++ b/api/db/repositories.py @@ -1,19 +1,285 @@ +from ast import For +from collections import defaultdict +from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select -from api.db.models import Template, FormSubmission +from api.db.models import ( + Template, + FormSubmission, + ReportSchema, + ReportSchemaTemplate, + SchemaField, +) -# Templates def create_template(session: Session, template: Template) -> Template: + try: + session.add(template) + session.commit() + session.refresh(template) + return template + except IntegrityError: + raise + +def get_template(session: Session, template_id: int) -> Template | None: + return session.get(Template, template_id) + +def update_template(session: Session, template_id: int, updates: dict) -> Template | None: + template = session.get(Template, template_id) + if not template: + return None + for key, value in updates.items(): + setattr(template, key, value) session.add(template) session.commit() session.refresh(template) return template -def get_template(session: Session, template_id: int) -> Template | None: - return session.get(Template, template_id) +def list_templates(session: Session) -> list[Template]: + return session.exec(select(Template)).all() + +def delete_template(session: Session, template_id: int) -> bool: + """Remove template and dependent rows (form submissions, schema links, schema fields).""" + template = session.get(Template, template_id) + if not template: + return False + + for form in session.exec( + select(FormSubmission).where(FormSubmission.template_id == template_id) + ).all(): + session.delete(form) + + for junction in session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.template_id == template_id + ) + ).all(): + for field in session.exec( + select(SchemaField).where( + SchemaField.report_schema_id == junction.report_schema_id, + SchemaField.source_template_id == template_id, + ) + ).all(): + session.delete(field) + session.delete(junction) + + session.delete(template) + session.commit() + return True -# Forms def create_form(session: Session, form: FormSubmission) -> FormSubmission: session.add(form) session.commit() session.refresh(form) - return form \ No newline at end of file + return form + +def get_form(session: Session, form_id: int) -> FormSubmission: + return session.get(FormSubmission, form_id) + +def update_form(session: Session, form_id: int, updates: dict) -> FormSubmission | None: + form = session.get(FormSubmission, form_id) + if not form: + return None + for key, value in updates.items(): + setattr(form, key, value) + session.add(form) + session.commit() + session.refresh(form) + return form + +def delete_form(session: Session, form_id: int) -> FormSubmission: + form_submission = session.get(FormSubmission, form_id) + if form_submission: + session.delete(form_submission) + session.commit() + return True + return False + +def create_report_schema(session: Session, schema: ReportSchema) -> ReportSchema: + try: + session.add(schema) + session.commit() + session.refresh(schema) + return schema + except IntegrityError: + raise + +def get_report_schema(session: Session, schema_id: int) -> ReportSchema | None: + return session.get(ReportSchema, schema_id) + +def list_report_schemas(session: Session) -> list[ReportSchema]: + return session.exec(select(ReportSchema)).all() + +def update_report_schema(session: Session, schema_id: int, updates: dict) -> ReportSchema | None: + schema = session.get(ReportSchema, schema_id) + if not schema: + return None + for key, value in updates.items(): + setattr(schema, key, value) + session.add(schema) + session.commit() + session.refresh(schema) + return schema + +def delete_report_schema(session: Session, schema_id: int) -> bool: + schema = session.get(ReportSchema, schema_id) + if not schema: + return False + + fields = session.exec( + select(SchemaField).where(SchemaField.report_schema_id == schema_id) + ).all() + for field in fields: + session.delete(field) + + junctions = session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.report_schema_id == schema_id + ) + ).all() + for junction in junctions: + session.delete(junction) + + session.delete(schema) + session.commit() + return True + + +def add_template_to_schema( + session: Session, schema_id: int, template_id: int +) -> ReportSchemaTemplate: + """Associate a template with a schema. + + Looks up `template.fields` and auto-creates a SchemaField for each field, + pre-populated with `field_name` and `source_template_id`. + Other metadata is left as defaults for the user to fill in later. + """ + template = session.get(Template, template_id) + if not template: + raise ValueError(f"Template {template_id} not found") + + schema = session.get(ReportSchema, schema_id) + if not schema: + raise ValueError(f"ReportSchema {schema_id} not found") + + # exists = session.exec(select(ReportSchemaTemplate).where(ReportSchemaTemplate.report_schema_id == schema_id, ReportSchemaTemplate.template_id == template_id)).first() + # if exists: + # raise IntegrityError + + # Create the junction record (field_mapping starts empty, populated during canonization) + junction = ReportSchemaTemplate( + report_schema_id=schema_id, + template_id=template_id, + ) + + session.add(junction) + + # Auto-create a SchemaField for each field in the template + for field_name in template.fields: + schema_field = SchemaField( + report_schema_id=schema_id, + field_name=field_name, + source_template_id=template_id, + ) + session.add(schema_field) + + session.commit() + session.refresh(junction) + return junction + +def remove_template_from_schema( + session: Session, schema_id: int, template_id: int +) -> bool: + """Disassociate a template from a schema and remove its SchemaField entries.""" + junction = session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.report_schema_id == schema_id, + ReportSchemaTemplate.template_id == template_id, + ) + ).first() + if not junction: + return False + + fields = session.exec( + select(SchemaField).where( + SchemaField.report_schema_id == schema_id, + SchemaField.source_template_id == template_id, + ) + ).all() + for field in fields: + session.delete(field) + + session.delete(junction) + session.commit() + return True + + +def get_schema_fields(session: Session, schema_id: int) -> list[SchemaField]: + return session.exec( + select(SchemaField).where(SchemaField.report_schema_id == schema_id) + ).all() + +def get_schema_field(session: Session, field_id: int) -> SchemaField: + return session.get(SchemaField, field_id) + +def update_schema_field(session: Session, schema_id: int, field_id: int, updates: dict) -> SchemaField | None: + """Update field metadata: description, data_type, word_limit, required, allowed_values. + + Validates that the field belongs to the given schema before updating, + so the same template field in different schemas can have independent metadata. + """ + field = session.get(SchemaField, field_id) + if not field or field.report_schema_id != schema_id: + return None + for key, value in updates.items(): + setattr(field, key, value) + session.add(field) + session.commit() + session.refresh(field) + return field + + +# ── Template mapping (post-canonization) ───────────────────────────────────── + +def update_template_mapping( + session: Session, schema_id: int, template_id: int +) -> ReportSchemaTemplate | None: + """Auto-generate and store the canonical → PDF field mapping after canonization. + + Builds the mapping by looking up all SchemaFields for this schema+template pair + and mapping each field's canonical_name → field_name. + """ + junction = session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.report_schema_id == schema_id, + ReportSchemaTemplate.template_id == template_id, + ) + ).first() + if not junction: + return None + + # Build mapping from SchemaFields that have been canonized + fields = session.exec( + select(SchemaField).where( + SchemaField.report_schema_id == schema_id, + SchemaField.source_template_id == template_id, + ) + ).all() + + grouped: defaultdict[str, list[str]] = defaultdict(list) + for field in sorted(fields, key=lambda f: f.field_name): + key = field.canonical_name if field.canonical_name else field.field_name + grouped[key].append(field.field_name) + + # One PDF field -> store str; several sharing a canonical -> list (distribute handles both). + field_mapping: dict = {} + for key, names in grouped.items(): + field_mapping[key] = names[0] if len(names) == 1 else names + + junction.field_mapping = field_mapping + session.add(junction) + session.commit() + session.refresh(junction) + return junction + +def get_field_mapping(session: Session, schema_id: int, template_id: int) -> ReportSchemaTemplate: + junction = session.exec(select(ReportSchemaTemplate).where(ReportSchemaTemplate.report_schema_id == schema_id, ReportSchemaTemplate.template_id == template_id)).first() + return junction.field_mapping \ No newline at end of file diff --git a/api/main.py b/api/main.py index d0b8c79..61e63dc 100644 --- a/api/main.py +++ b/api/main.py @@ -1,7 +1,25 @@ from fastapi import FastAPI -from api.routes import templates, forms +from fastapi.middleware.cors import CORSMiddleware + +from api.routes import templates, forms, report_schemas +from api.errors.handlers import register_exception_handlers app = FastAPI() +register_exception_handlers(app) + +app.add_middleware( + CORSMiddleware, + allow_origins=[ + "http://127.0.0.1:5173", + "http://localhost:5173", + "http://127.0.0.1:4173", + "http://localhost:4173", + ], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) app.include_router(templates.router) -app.include_router(forms.router) \ No newline at end of file +app.include_router(forms.router) +app.include_router(report_schemas.router) \ No newline at end of file diff --git a/api/routes/report_schemas.py b/api/routes/report_schemas.py new file mode 100644 index 0000000..5b9ff65 --- /dev/null +++ b/api/routes/report_schemas.py @@ -0,0 +1,149 @@ +from sqlite3 import IntegrityError +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import Session, select +from api.deps import get_db +from api.schemas.report_class import ( + ReportSchemaCreate, + ReportSchemaUpdate, + ReportSchemaResponse, + TemplateAssociation, + SchemaFieldResponse, + SchemaFieldUpdate, + CanonicalSchema, + ReportFill, + ReportFillResponse, +) +from api.db import repositories as repo +from api.db.models import ReportSchema +from src.report_schema import ReportSchemaProcessor +from src.controller import Controller +from api.db.models import FormSubmission, ReportSchemaTemplate +from sqlalchemy.exc import IntegrityError + +router = APIRouter(prefix="/schemas", tags=["schemas"]) + + +@router.post("/create", response_model=ReportSchemaResponse) +def create_schema(data: ReportSchemaCreate, db: Session = Depends(get_db)): + schema = ReportSchema(**data.model_dump()) + try: + return repo.create_report_schema(db, schema) + except IntegrityError: + raise HTTPException( + status_code=409, + detail="A schema with this name already exists" + ) + +@router.get("/", response_model=list[ReportSchemaResponse]) +def list_schemas(db: Session = Depends(get_db)): + return repo.list_report_schemas(db) + +@router.get("/{schema_id}", response_model=ReportSchemaResponse) +def get_schema(schema_id: int, db: Session = Depends(get_db)): + schema = repo.get_report_schema(db, schema_id) + if not schema: + raise HTTPException(status_code=404, detail="Schema not found") + return schema + +@router.put("/{schema_id}", response_model=ReportSchemaResponse) +def update_schema(schema_id: int, data: ReportSchemaUpdate, db: Session = Depends(get_db)): + updates = data.model_dump(exclude_none=True) + schema = repo.update_report_schema(db, schema_id, updates) + if not schema: + raise HTTPException(status_code=404, detail="Schema not found") + return schema + +@router.delete("/{schema_id}") +def delete_schema(schema_id: int, db: Session = Depends(get_db)): + deleted = repo.delete_report_schema(db, schema_id) + if not deleted: + raise HTTPException(status_code=404, detail="Schema not found") + return {"detail": "Schema deleted"} + + +@router.post("/{schema_id}/templates", response_model=list[SchemaFieldResponse]) +def add_template(schema_id: int, data: TemplateAssociation, db: Session = Depends(get_db)): + """Associate a template with a schema. + + Auto-creates SchemaField entries from template.fields and returns them. + """ + try: + repo.add_template_to_schema(db, schema_id, data.template_id) + except IntegrityError: + raise HTTPException(status_code=409, detail="Template is already added to schema") + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + return repo.get_schema_fields(db, schema_id) + +@router.delete("/{schema_id}/templates/{template_id}") +def remove_template(schema_id: int, template_id: int, db: Session = Depends(get_db)): + removed = repo.remove_template_from_schema(db, schema_id, template_id) + if not removed: + raise HTTPException(status_code=404, detail="Template association not found") + return {"detail": "Template disassociated"} + + + +@router.get("/{schema_id}/fields", response_model=list[SchemaFieldResponse]) +def list_fields(schema_id: int, db: Session = Depends(get_db)): + return repo.get_schema_fields(db, schema_id) + +@router.put("/{schema_id}/fields/{field_id}", response_model=SchemaFieldResponse) +def update_field(schema_id: int, field_id: int, data: SchemaFieldUpdate, db: Session = Depends(get_db)): + updates = data.model_dump(exclude_none=True) + field = repo.update_schema_field(db, schema_id, field_id, updates) + if not field: + raise HTTPException(status_code=404, detail="Field not found or does not belong to this schema") + return field + + +@router.post("/{schema_id}/canonize", response_model=CanonicalSchema) +def canonize_schema(schema_id: int, db: Session = Depends(get_db)): + """Trigger canonization: group fields, assign canonical names, generate field mappings.""" + schema = repo.get_report_schema(db, schema_id) + if not schema: + raise HTTPException(status_code=404, detail="Schema not found") + + return ReportSchemaProcessor.canonize(db, schema_id) + +@router.get("/mapping/{schema_id}/{template_id}") +def get_schema_template_mapping(schema_id: int, template_id: int, db: Session = Depends(get_db)): + return repo.get_field_mapping(db, schema_id, template_id) + +@router.post("/{schema_id}/fill", response_model=ReportFillResponse) +def fill_schema(schema_id: int, data: ReportFill, db: Session = Depends(get_db)): + """ + End-to-end report generation. + Takes a single transcript, extracts canonical fields, distributes to + all schema templates, fills them, and logs the submissions. + """ + schema = repo.get_report_schema(db, schema_id) + if not schema: + raise HTTPException(status_code=404, detail="Schema not found") + + controller = Controller() + + output_paths = controller.fill_report(db, data.input_text, schema_id) + + # 2. Log submissions + junctions = db.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.report_schema_id == schema_id + ) + ).all() + + for template_id, path in output_paths.items(): + submission = FormSubmission( + template_id=template_id, + input_text=data.input_text, + output_pdf_path=path + ) + db.add(submission) + + db.commit() + + return ReportFillResponse( + schema_id=schema_id, + input_text=data.input_text, + output_pdf_paths=list(output_paths.values()) + ) diff --git a/api/routes/templates.py b/api/routes/templates.py index 5c2281b..dcacecc 100644 --- a/api/routes/templates.py +++ b/api/routes/templates.py @@ -1,16 +1,131 @@ -from fastapi import APIRouter, Depends +import re +from sqlalchemy.exc import IntegrityError +import uuid +from pathlib import Path + +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile +from fastapi.responses import FileResponse from sqlmodel import Session + from api.deps import get_db -from api.schemas.templates import TemplateCreate, TemplateResponse -from api.db.repositories import create_template from api.db.models import Template +from api.db.repositories import ( + create_template, + delete_template, + get_template, + update_template, + list_templates +) +from api.schemas.templates import TemplateResponse, TemplateUpdate from src.controller import Controller router = APIRouter(prefix="/templates", tags=["templates"]) +INPUT_FILES_DIR = Path(__file__).resolve().parents[2] / "template_files" + + +def _safe_name_fragment(name: str) -> str: + base = Path(name).name + s = re.sub(r"[^\w\-.]+", "_", base.strip(), flags=re.UNICODE) + s = s.strip("._-") or "template" + return s[:120] + + @router.post("/create", response_model=TemplateResponse) -def create(template: TemplateCreate, db: Session = Depends(get_db)): +def create( + name: str = Form(...), + file: UploadFile = File(...), + db: Session = Depends(get_db), +): + filename = (file.filename or "").lower() + if not filename.endswith(".pdf"): + raise HTTPException(status_code=400, detail="File must be a .pdf") + + frag = _safe_name_fragment(name) + uid = uuid.uuid4().hex + INPUT_FILES_DIR.mkdir(parents=True, exist_ok=True) + dest = INPUT_FILES_DIR / f"{frag}_{uid}.pdf" + + raw = file.file.read() + if not raw: + raise HTTPException(status_code=400, detail="Empty file") + dest.write_bytes(raw) + controller = Controller() - template_path = controller.create_template(template.pdf_path) - tpl = Template(**template.model_dump(exclude={"pdf_path"}), pdf_path=template_path) - return create_template(db, tpl) \ No newline at end of file + try: + template_path = controller.create_template(str(dest)) + except Exception as e: + dest.unlink(missing_ok=True) + print(e) + raise HTTPException( + status_code=500, detail=f"Failed to prepare PDF template: {e}" + ) from e + + fields = controller.extract_template_fields(template_path) + tpl = Template(name=name.strip(), fields=fields, pdf_path=template_path) + + try: + return create_template(db, tpl) + except IntegrityError: + raise HTTPException( + status_code=409, + detail="A template with the same name already exists" + ) + +@router.get("/", response_model=list[Template]) +def list(db: Session = Depends(get_db)): + return list_templates(db) + + +@router.get("/{template_id}/pdf") +def get_template_pdf(template_id: int, db: Session = Depends(get_db)): + """Serve the stored PDF for preview in the schema wizard.""" + tpl = get_template(db, template_id) + if not tpl: + raise HTTPException(status_code=404, detail="Template not found") + root = INPUT_FILES_DIR.resolve() + path = Path(tpl.pdf_path).resolve() + try: + path.relative_to(root) + except ValueError: + raise HTTPException(status_code=403, detail="Invalid template file location") + if not path.is_file(): + raise HTTPException(status_code=404, detail="PDF file missing on disk") + return FileResponse( + path, + media_type="application/pdf", + filename=f"{tpl.name}.pdf", + ) + + +@router.get("/{template_id}", response_model=TemplateResponse) +def get_one(template_id: int, db: Session = Depends(get_db)): + tpl = get_template(db, template_id) + if not tpl: + raise HTTPException(status_code=404, detail="Template not found") + return tpl + + +@router.put("/{template_id}", response_model=TemplateResponse) +def update_one( + template_id: int, + data: TemplateUpdate, + db: Session = Depends(get_db), +): + updates = data.model_dump(exclude_none=True) + if not updates: + tpl = get_template(db, template_id) + if not tpl: + raise HTTPException(status_code=404, detail="Template not found") + return tpl + tpl = update_template(db, template_id, updates) + if not tpl: + raise HTTPException(status_code=404, detail="Template not found") + return tpl + + +@router.delete("/{template_id}") +def delete_one(template_id: int, db: Session = Depends(get_db)): + if not delete_template(db, template_id): + raise HTTPException(status_code=404, detail="Template not found") + return {"detail": "Template deleted"} diff --git a/api/schemas/report_class.py b/api/schemas/report_class.py new file mode 100644 index 0000000..ebc3776 --- /dev/null +++ b/api/schemas/report_class.py @@ -0,0 +1,84 @@ +from pydantic import BaseModel +from datetime import datetime +from api.db.models import Datatype + + +class ReportSchemaCreate(BaseModel): + name: str + description: str + use_case: str + +class ReportSchemaUpdate(BaseModel): + name: str | None = None + description: str | None = None + use_case: str | None = None + +class TemplateAssociation(BaseModel): + template_id: int + +class ReportFill(BaseModel): + input_text: str + +class ReportFillResponse(BaseModel): + schema_id: int + input_text: str + output_pdf_paths: list[str] + +class SchemaFieldUpdate(BaseModel): + description: str | None = None + data_type: Datatype | None = None + word_limit: int | None = None + required: bool | None = None + allowed_values: dict | None = None + canonical_name: str | None = None + + +class SchemaFieldResponse(BaseModel): + id: int + report_schema_id: int + field_name: str + source_template_id: int + description: str + data_type: Datatype + word_limit: int | None + required: bool + allowed_values: dict | None + canonical_name: str | None + + class Config: + from_attributes = True + +class TemplateInSchema(BaseModel): + id: int + template_id: int + report_schema_id: int + field_mapping: dict + + class Config: + from_attributes = True + +class ReportSchemaResponse(BaseModel): + id: int + name: str + description: str + use_case: str + created_at: datetime + templates: list[TemplateInSchema] = [] + fields: list[SchemaFieldResponse] = [] + + class Config: + from_attributes = True + + +class CanonicalFieldEntry(BaseModel): + canonical_name: str + description: str + data_type: Datatype + word_limit: int | None + required: bool + allowed_values: dict | None + source_fields: list[SchemaFieldResponse] + +class CanonicalSchema(BaseModel): + report_schema_id: int + canonical_fields: list[CanonicalFieldEntry] diff --git a/api/schemas/templates.py b/api/schemas/templates.py index 961f219..8ef630e 100644 --- a/api/schemas/templates.py +++ b/api/schemas/templates.py @@ -1,9 +1,11 @@ from pydantic import BaseModel -class TemplateCreate(BaseModel): - name: str - pdf_path: str - fields: dict + +class TemplateUpdate(BaseModel): + name: str | None = None + fields: dict | None = None + pdf_path: str | None = None + class TemplateResponse(BaseModel): id: int diff --git a/requirements.txt b/requirements.txt index eaa6c81..558ac18 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ pdfrw flask commonforms fastapi +python-multipart uvicorn pydantic sqlmodel diff --git a/src/controller.py b/src/controller.py index d31ec9c..a6bab12 100644 --- a/src/controller.py +++ b/src/controller.py @@ -1,4 +1,6 @@ from src.file_manipulator import FileManipulator +from sqlmodel import Session +from src.report_schema import ReportSchemaProcessor class Controller: def __init__(self): @@ -8,4 +10,24 @@ def fill_form(self, user_input: str, fields: list, pdf_form_path: str): return self.file_manipulator.fill_form(user_input, fields, pdf_form_path) def create_template(self, pdf_path: str): - return self.file_manipulator.create_template(pdf_path) \ No newline at end of file + return self.file_manipulator.create_template(pdf_path) + + def extract_template_fields(self, pdf_path: str) -> dict[str, str]: + return self.file_manipulator.extract_template_field_map(pdf_path) + + def fill_report(self, session: Session, user_input: str, schema_id: int) -> dict[int, str]: + """ + Main pipeline entry point for filling a multi-template report schema. + 1. Triggers canonization to get the latest schema definition. + 2. Builds the JSON Schema extraction target for the LLM. + 3. Hands off to FileManipulator for actual processing. + """ + canonical_schema = ReportSchemaProcessor.canonize(session, schema_id) + extraction_target = ReportSchemaProcessor.build_extraction_target(canonical_schema) + + return self.file_manipulator.fill_report( + session=session, + user_input=user_input, + schema_id=schema_id, + canonical_target=extraction_target + ) \ No newline at end of file diff --git a/src/file_manipulator.py b/src/file_manipulator.py index b7815cc..cc54606 100644 --- a/src/file_manipulator.py +++ b/src/file_manipulator.py @@ -1,7 +1,13 @@ import os +from pdfrw import PdfReader from src.filler import Filler from src.llm import LLM +from src.pdf_utils import decode_pdf_name from commonforms import prepare_form +from sqlmodel import Session +from src.report_schema import ReportSchemaProcessor +from api.db.models import Template + class FileManipulator: @@ -13,9 +19,23 @@ def create_template(self, pdf_path: str): """ By using commonforms, we create an editable .pdf template and we store it. """ - template_path = pdf_path[:-4] + "_template.pdf" - prepare_form(pdf_path, template_path) - return template_path + prepare_form(pdf_path, pdf_path) + return pdf_path + + def extract_template_field_map(self, pdf_path: str) -> dict[str, str]: + """AcroForm widget names from a PDF, each mapped to type ``string`` (Template.fields shape).""" + pdf = PdfReader(pdf_path) + names: list[str] = [] + for page in pdf.pages: + if not getattr(page, "Annots", None): + continue + for annot in page.Annots: + if getattr(annot, "Subtype", None) != "/Widget" or not getattr(annot, "T", None): + continue + raw = decode_pdf_name(str(annot.T).strip("() /")) + if raw and raw not in names: + names.append(raw) + return {n: "string" for n in names} def fill_form(self, user_input: str, fields: list, pdf_form_path: str): """ @@ -45,3 +65,50 @@ def fill_form(self, user_input: str, fields: list, pdf_form_path: str): print(f"An error occurred during PDF generation: {e}") # Re-raise the exception so the frontend can handle it raise e + + def fill_report(self, session: Session, user_input: str, schema_id: int, canonical_target: dict) -> dict[int, str]: + """ + Extracts data using a canonical schema target, distributes the results + to all associated templates, and fills them by name. + """ + print(f"[1] Received report fill request for schema {schema_id}.") + print("[2] Starting canonical extraction process...") + + try: + # 1. Extract against the canonical target + self.llm._target_fields = canonical_target + self.llm._transcript_text = user_input + + t2j = self.llm.main_loop() + canonical_data = t2j.get_data() + + print(f"[3] Canonical extraction complete. Distributing to templates...") + + # 2. Distribute to per-template dictionaries + distribution = ReportSchemaProcessor.distribute(session, schema_id, canonical_data) + + # 3. Fill each template + output_paths: dict[int, str] = {} + + for template_id, template_data in distribution.items(): + template = session.get(Template, template_id) + if not template or not os.path.exists(template.pdf_path): + print(f" -> Skipping template {template_id} (not found or missing PDF)") + continue + + print(f" -> Filling template {template_id} ({template.name})...") + output_name = self.filler.fill_form_by_name( + pdf_form=template.pdf_path, + field_values=template_data + ) + output_paths[template_id] = output_name + + print("\n----------------------------------") + print("✅ Report generation complete.") + print(f"Outputs saved to: {list(output_paths.values())}") + + return output_paths + + except Exception as e: + print(f"An error occurred during report generation: {e}") + raise e diff --git a/src/filler.py b/src/filler.py index e31e535..74f8b98 100644 --- a/src/filler.py +++ b/src/filler.py @@ -1,6 +1,8 @@ -from pdfrw import PdfReader, PdfWriter +from pdfrw import PdfReader, PdfWriter, PdfDict, PdfObject from src.llm import LLM +from src.pdf_utils import decode_pdf_name from datetime import datetime +import uuid class Filler: @@ -15,7 +17,7 @@ def fill_form(self, pdf_form: str, llm: LLM): output_pdf = ( pdf_form[:-4] + "_" - + datetime.now().strftime("%Y%m%d_%H%M%S") + + str(uuid.uuid4()) + "_filled.pdf" ) @@ -28,14 +30,14 @@ def fill_form(self, pdf_form: str, llm: LLM): # Read PDF pdf = PdfReader(pdf_form) - # Loop through pages + # Global index across all pages (visual order is per page, pages in document order). + i = 0 for page in pdf.pages: if page.Annots: sorted_annots = sorted( page.Annots, key=lambda a: (-float(a.Rect[1]), float(a.Rect[0])) ) - i = 0 for annot in sorted_annots: if annot.Subtype == "/Widget" and annot.T: if i < len(answers_list): @@ -43,10 +45,45 @@ def fill_form(self, pdf_form: str, llm: LLM): annot.AP = None i += 1 else: - # Stop if we run out of answers break PdfWriter().write(output_pdf, pdf) # Your main.py expects this function to return the path return output_pdf + + + def fill_form_by_name(self, pdf_form: str, field_values: dict[str, str]) -> str: + """ + Fill a PDF form with values from a dictionary mapped by field name. + Unlike `fill_form`, this does not rely on visual ordering, it relies on + the exact field name defined in the PDF template matching a key in `field_values`. + """ + output_pdf = ( + pdf_form[:-4] + + "_" + + str(uuid.uuid4()) + + "_filled.pdf" + ) + + # Read PDF + pdf = PdfReader(pdf_form) + + # Force generation of Appearance Streams so text is visible in standard viewers + if pdf.Root.AcroForm: + pdf.Root.AcroForm.update(PdfDict(NeedAppearances=PdfObject('true'))) + + # Loop through pages + for page in pdf.pages: + if page.Annots: + for annot in page.Annots: + if annot.Subtype == "/Widget" and annot.T: + field_name = decode_pdf_name(str(annot.T).strip("() /")) + + if field_name in field_values: + # Update the PDF annotation + annot.V = f"{field_values[field_name]}" + annot.AP = None + + PdfWriter().write(output_pdf, pdf) + return output_pdf diff --git a/src/pdf_utils.py b/src/pdf_utils.py new file mode 100644 index 0000000..c3ef2f3 --- /dev/null +++ b/src/pdf_utils.py @@ -0,0 +1,17 @@ +import re + +_PDF_ESCAPE_RE = re.compile(r'\\(\d{1,3}|[nrtbf()\\])') +_NAMED_ESCAPES = { + 'n': '\n', 'r': '\r', 't': '\t', 'b': '\b', 'f': '\f', + '(': '(', ')': ')', '\\': '\\', +} + + +def decode_pdf_name(raw: str) -> str: + """Decode all PDF literal-string escape sequences (ISO 32000 §7.3.4.2).""" + def _replace(m: re.Match) -> str: + s = m.group(1) + if s[0].isdigit(): + return chr(int(s, 8)) + return _NAMED_ESCAPES.get(s, s) + return _PDF_ESCAPE_RE.sub(_replace, raw) diff --git a/src/report_schema.py b/src/report_schema.py new file mode 100644 index 0000000..785792f --- /dev/null +++ b/src/report_schema.py @@ -0,0 +1,129 @@ +from sqlmodel import Session, select +from typing import Any +from api.db.models import SchemaField, ReportSchemaTemplate, Datatype +from api.schemas.report_class import CanonicalSchema, CanonicalFieldEntry, SchemaFieldResponse +from api.db.repositories import update_template_mapping, get_report_schema + + +class ReportSchemaProcessor: + @staticmethod + def canonize(session: Session, schema_id: int) -> CanonicalSchema: + """Group fields by their canonical names (falling back to original names).""" + schema = get_report_schema(session, schema_id) + if not schema: + raise ValueError(f"ReportSchema {schema_id} not found") + + # 1. Fetch all fields for this schema + fields = session.exec( + select(SchemaField).where(SchemaField.report_schema_id == schema_id) + ).all() + + # 2. Group fields by their effective canonical name + groups: dict[str, list[SchemaField]] = {} + + for field in fields: + # The manual override rule: If no canonical name is set, use the raw field name + effective_name = field.canonical_name if field.canonical_name else field.field_name + + if effective_name not in groups: + groups[effective_name] = [] + groups[effective_name].append(field) + + # 3. Build the CanonicalSchema representation + canonical_fields = [] + for effective_name, source_fields in groups.items(): + # Use metadata from the first field in the group as the canonical metadata + # (In a more complex system, we might merge these or let the user elect a "primary" field) + primary = source_fields[0] + + canonical_fields.append( + CanonicalFieldEntry( + canonical_name=effective_name, + description=primary.description, + data_type=primary.data_type, + word_limit=primary.word_limit, + required=primary.required, + allowed_values=primary.allowed_values, + source_fields=[SchemaFieldResponse.model_validate(f) for f in source_fields] + ) + ) + + # 4. Update the junction tables so they know how to map back + # We need to do this per-template + template_ids = {f.source_template_id for f in fields} + for t_id in template_ids: + update_template_mapping(session, schema_id, t_id) + + return CanonicalSchema( + report_schema_id=schema_id, + canonical_fields=canonical_fields + ) + + @staticmethod + def build_extraction_target(canonical_schema: CanonicalSchema) -> dict[str, Any]: + """Convert the CanonicalSchema into a JSON schema dict for LLM function calling.""" + properties = {} + required = [] + + type_mapping = { + Datatype.STRING: "string", + Datatype.INT: "integer", + Datatype.DATE: "string", # Represent dates as strings for LLM + Datatype.ENUM: "string" # Enums are strings restricted by allowed_values + } + + for field in canonical_schema.canonical_fields: + field_def = { + "type": type_mapping.get(field.data_type, "string"), + "description": field.description + } + + if field.data_type == Datatype.ENUM and field.allowed_values and "values" in field.allowed_values: + field_def["enum"] = field.allowed_values["values"] + + if field.word_limit: + field_def["description"] += f" (Maximum {field.word_limit} words)" + + properties[field.canonical_name] = field_def + + if field.required: + required.append(field.canonical_name) + + return { + "type": "object", + "properties": properties, + "required": required + } + + @staticmethod + def distribute( + session: Session, schema_id: int, canonical_data: dict[str, Any] + ) -> dict[int, dict[str, Any]]: + """Map canonical extraction output back to individual template fields.""" + junctions = session.exec( + select(ReportSchemaTemplate).where( + ReportSchemaTemplate.report_schema_id == schema_id + ) + ).all() + + distribution = {} + + for junction in junctions: + template_id = junction.template_id + mapping = junction.field_mapping or {} + + template_data = {} + for canonical_name, pdf_targets in mapping.items(): + if canonical_name not in canonical_data: + continue + names = ( + pdf_targets + if isinstance(pdf_targets, list) + else [pdf_targets] + ) + for pdf_field_name in names: + template_data[pdf_field_name] = canonical_data[canonical_name] + + distribution[template_id] = template_data + + return distribution