from typing import Literal

from pydantic import BaseModel, Field


# Request body for the rewrite endpoint: the scene to rewrite, its current
# text, the character budget to aim for, and the rewrite direction.
# (Kept as a comment rather than a docstring so the generated JSON schema
# description is not altered.)
class RewriteRequest(BaseModel):
    scene_id: int
    text: str = Field(
        description="当前场景正文", max_length=100000
    )
    target_chars: int = Field(
        default=1500, ge=100, le=50000,
        description="目标字数",
    )
    # No default: the caller must state the rewrite direction explicitly.
    mode: Literal["expand", "compress"] = Field(
        description="重写模式: expand 或 compress",
    )


# Request body for the budget check: text plus the character budget.
class WordCountCheckRequest(BaseModel):
    text: str = Field(
        description="当前场景正文", max_length=100000
    )
    target_chars: int = Field(
        default=1500, ge=100, le=50000,
        description="目标字数",
    )


# Result of a budget check.  `status` and `suggestion` are closed sets, so
# they are typed as Literal instead of free-form str: the response is then
# validated against the allowed values and the schema exposes them as enums.
class WordCountCheck(BaseModel):
    status: Literal["within", "over", "under"] = Field(
        description="within / over / under"
    )
    actual_chars: int
    target_chars: int
    delta: int
    deviation: float
    suggestion: Literal["compress", "expand"] | None = Field(
        default=None, description="compress / expand / None"
    )
from app.services.word_count import build_rewrite_prompt, check_word_budget

router = APIRouter(prefix="/api/generate", tags=["generation"])


@router.post("/word-count-check", response_model=WordCountCheck)
async def word_count_check(req: WordCountCheckRequest):
    """Check if scene text fits within the target char budget."""
    # The dict returned by check_word_budget is serialized through the
    # WordCountCheck response model declared on the route.
    return check_word_budget(req.text, req.target_chars)


@router.post("/rewrite")
async def rewrite_scene(
    req: RewriteRequest, db: AsyncSession = Depends(get_db)
):
    """Expand or compress scene text to fit target char budget via SSE."""
    # Validation order:
    #   404 - the scene id does not exist
    #   400 - the text is already within budget (nothing to rewrite)
    #   400 - the requested mode contradicts the budget direction
    # Only then is the LLM called and its output streamed back.
    scene = await db.get(Scene, req.scene_id)
    if scene is None:
        raise HTTPException(404, "Scene not found")

    budget = check_word_budget(req.text, req.target_chars)
    if budget["status"] == "within":
        raise HTTPException(
            400,
            f"Text already within budget "
            f"(deviation={budget['deviation']:.1%})",
        )

    expected_mode = budget["suggestion"]
    if req.mode != expected_mode:
        raise HTTPException(
            400,
            f"Mode '{req.mode}' conflicts with budget status "
            f"'{budget['status']}': expected '{expected_mode}'",
        )

    prompt = build_rewrite_prompt(
        req.text, req.target_chars, req.mode
    )

    async def event_stream():
        # Collect chunks in a list and join once at the end; repeated
        # `str +=` can degrade to quadratic copying on long generations.
        pieces: list[str] = []
        async for chunk in call_llm_stream(
            messages=[{"role": "user", "content": prompt}]
        ):
            pieces.append(chunk)
            payload = json.dumps(
                {"text": chunk}, ensure_ascii=False
            )
            yield f"data: {payload}\n\n"

        # Final event: re-check the generated text against the same budget
        # so the client can see whether the rewrite landed within range.
        total_text = "".join(pieces)
        result = check_word_budget(total_text, req.target_chars)
        done_data = {
            "done": True,
            "char_count": len(total_text),
            "budget": result,
        }
        payload = json.dumps(done_data, ensure_ascii=False)
        yield f"data: {payload}\n\n"

    return StreamingResponse(
        event_stream(), media_type="text/event-stream"
    )
b/backend/app/services/word_count.py @@ -0,0 +1,72 @@ +"""Word count budget checking for scene generation.""" + +from typing import Literal + + +def check_word_budget( + text: str, + target_chars: int, + tolerance: float = 0.15, +) -> dict: + """Check if text length is within the target budget. + + Returns dict with: + status: 'within' | 'over' | 'under' + actual_chars: int + target_chars: int + delta: int (actual - target) + deviation: float (abs ratio) + suggestion: 'compress' | 'expand' | None + """ + actual = len(text) + delta = actual - target_chars + if target_chars <= 0: + deviation = float(actual) if actual > 0 else 0.0 + else: + deviation = abs(delta) / target_chars + + if deviation <= tolerance: + status: Literal["within", "over", "under"] = "within" + suggestion = None + elif delta > 0: + status = "over" + suggestion = "compress" + else: + status = "under" + suggestion = "expand" + + return { + "status": status, + "actual_chars": actual, + "target_chars": target_chars, + "delta": delta, + "deviation": round(deviation, 3), + "suggestion": suggestion, + } + + +def build_rewrite_prompt( + text: str, + target_chars: int, + mode: str, +) -> str: + """Build a prompt for expanding or compressing scene text.""" + actual = len(text) + if mode == "expand": + return ( + f"你是一位专业的中文小说作家。以下场景正文目前有 {actual} 字," + f"目标是约 {target_chars} 字。" + "请在保持原有情节和人物不变的前提下,扩写以下内容," + "增加细节描写、对话或环境描写。\n\n" + f"{text}\n\n" + "请直接输出扩写后的完整场景正文,不要输出任何标记或说明。" + ) + else: + return ( + f"你是一位专业的中文小说编辑。以下场景正文目前有 {actual} 字," + f"目标是约 {target_chars} 字。" + "请在保持核心情节和关键对话不变的前提下,精简以下内容," + "删除冗余描写和不必要的细节。\n\n" + f"{text}\n\n" + "请直接输出精简后的完整场景正文,不要输出任何标记或说明。" + ) diff --git a/backend/tests/test_word_count.py b/backend/tests/test_word_count.py new file mode 100644 index 0000000..e8d6316 --- /dev/null +++ b/backend/tests/test_word_count.py @@ -0,0 +1,236 @@ +"""Tests for word count budget check and rewrite endpoint.""" + +import pytest + +from app.services.word_count import build_rewrite_prompt, 
check_word_budget + +# ---------- Helpers ---------- + +async def _setup_project(client) -> int: + resp = await client.post("/api/projects", json={"title": "WC Test"}) + return resp.json()["id"] + + +async def _setup_book(client, project_id: int) -> int: + resp = await client.post( + "/api/books", json={"project_id": project_id, "title": "Book 1"} + ) + return resp.json()["id"] + + +async def _setup_chapter(client, book_id: int) -> int: + resp = await client.post( + "/api/chapters", + json={"book_id": book_id, "title": "Chapter 1", "sort_order": 0}, + ) + return resp.json()["id"] + + +async def _setup_scene(client, chapter_id: int) -> int: + resp = await client.post( + "/api/scenes", + json={"chapter_id": chapter_id, "title": "Scene 1"}, + ) + return resp.json()["id"] + + +# ---------- Unit tests: check_word_budget ---------- + +def test_within_budget(): + """Text within ±15% tolerance returns 'within'.""" + text = "a" * 1000 + result = check_word_budget(text, target_chars=1000) + assert result["status"] == "within" + assert result["actual_chars"] == 1000 + assert result["delta"] == 0 + assert result["deviation"] == 0.0 + assert result["suggestion"] is None + + +def test_within_budget_at_tolerance_boundary(): + """Text at exactly 15% over is still 'within'.""" + text = "a" * 1150 # 15% over 1000 + result = check_word_budget(text, target_chars=1000, tolerance=0.15) + assert result["status"] == "within" + assert result["suggestion"] is None + + +def test_over_budget(): + """Text exceeding tolerance suggests compression.""" + text = "a" * 1200 # 20% over 1000 + result = check_word_budget(text, target_chars=1000) + assert result["status"] == "over" + assert result["delta"] == 200 + assert result["suggestion"] == "compress" + + +def test_under_budget(): + """Text below tolerance suggests expansion.""" + text = "a" * 700 # 30% under 1000 + result = check_word_budget(text, target_chars=1000) + assert result["status"] == "under" + assert result["delta"] == -300 + assert 
result["suggestion"] == "expand" + + +def test_custom_tolerance(): + """Custom tolerance changes the threshold.""" + text = "a" * 900 # 10% under + # Default tolerance 15% → within + assert check_word_budget(text, 1000)["status"] == "within" + # Strict tolerance 5% → under + assert check_word_budget(text, 1000, tolerance=0.05)["status"] == "under" + + +def test_zero_target(): + """Zero target_chars doesn't crash.""" + result = check_word_budget("hello", target_chars=0) + assert result["status"] == "over" + assert result["actual_chars"] == 5 + + +def test_empty_text(): + """Empty text is under budget.""" + result = check_word_budget("", target_chars=1000) + assert result["status"] == "under" + assert result["actual_chars"] == 0 + assert result["suggestion"] == "expand" + + +def test_deviation_precision(): + """Deviation is rounded to 3 decimal places.""" + text = "a" * 1234 + result = check_word_budget(text, target_chars=1000) + assert isinstance(result["deviation"], float) + assert result["deviation"] == 0.234 + + +# ---------- Unit tests: build_rewrite_prompt ---------- + +def test_expand_prompt_contains_key_info(): + """Expand prompt includes actual chars, target, and the text.""" + prompt = build_rewrite_prompt("Some text.", 2000, "expand") + assert "10 字" in prompt + assert "2000 字" in prompt + assert "Some text." in prompt + assert "扩写" in prompt + + +def test_compress_prompt_contains_key_info(): + """Compress prompt includes actual chars, target, and the text.""" + prompt = build_rewrite_prompt("Some long text.", 500, "compress") + assert "15 字" in prompt + assert "500 字" in prompt + assert "Some long text." 
in prompt + assert "精简" in prompt + + +# ---------- API tests ---------- + +@pytest.mark.asyncio +async def test_word_count_check_endpoint(client): + """POST /api/generate/word-count-check returns budget status.""" + resp = await client.post( + "/api/generate/word-count-check", + json={"text": "a" * 1200, "target_chars": 1000}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "over" + assert data["actual_chars"] == 1200 + assert data["suggestion"] == "compress" + + +@pytest.mark.asyncio +async def test_word_count_check_within(client): + """Within-budget text returns status='within'.""" + resp = await client.post( + "/api/generate/word-count-check", + json={"text": "a" * 1000, "target_chars": 1000}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "within" + + +@pytest.mark.asyncio +async def test_rewrite_rejects_within_budget(client): + """Rewrite returns 400 if text is already within budget.""" + pid = await _setup_project(client) + bid = await _setup_book(client, pid) + cid = await _setup_chapter(client, bid) + sid = await _setup_scene(client, cid) + + resp = await client.post( + "/api/generate/rewrite", + json={ + "scene_id": sid, + "text": "a" * 1000, + "target_chars": 1000, + "mode": "compress", + }, + ) + assert resp.status_code == 400 + assert "within budget" in resp.json()["detail"] + + +@pytest.mark.asyncio +async def test_rewrite_invalid_scene(client): + """Rewrite returns 404 for non-existent scene.""" + resp = await client.post( + "/api/generate/rewrite", + json={ + "scene_id": 999999, + "text": "a" * 2000, + "target_chars": 1000, + "mode": "compress", + }, + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_rewrite_invalid_mode(client): + """Rewrite returns 422 for invalid mode.""" + resp = await client.post( + "/api/generate/rewrite", + json={ + "scene_id": 1, + "text": "some text", + "target_chars": 1000, + "mode": "invalid", + }, + ) + assert resp.status_code == 422 
+ + +@pytest.mark.asyncio +async def test_rewrite_target_chars_validation(client): + """target_chars must be between 100 and 50000.""" + resp = await client.post( + "/api/generate/word-count-check", + json={"text": "some text", "target_chars": 50}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_rewrite_mode_direction_mismatch(client): + """Rewrite rejects mode that conflicts with budget direction.""" + pid = await _setup_project(client) + bid = await _setup_book(client, pid) + cid = await _setup_chapter(client, bid) + sid = await _setup_scene(client, cid) + + # Text is 20% over budget → suggestion is 'compress' + # but we send mode='expand' → should be rejected + resp = await client.post( + "/api/generate/rewrite", + json={ + "scene_id": sid, + "text": "a" * 1200, + "target_chars": 1000, + "mode": "expand", + }, + ) + assert resp.status_code == 400 + assert "conflicts" in resp.json()["detail"] diff --git a/frontend/src/lib/types.ts b/frontend/src/lib/types.ts index 313c36d..44fa843 100644 --- a/frontend/src/lib/types.ts +++ b/frontend/src/lib/types.ts @@ -141,3 +141,15 @@ export interface KGEdge { relation: string properties: Record } + +export type WordCountStatus = 'within' | 'over' | 'under' +export type RewriteMode = 'expand' | 'compress' + +export interface WordCountCheck { + status: WordCountStatus + actual_chars: number + target_chars: number + delta: number + deviation: number + suggestion: RewriteMode | null +}