From 965cffb946605ccdd13610d851ebdd1c905a0624 Mon Sep 17 00:00:00 2001 From: DankerMu Date: Sat, 21 Feb 2026 12:41:17 +0800 Subject: [PATCH 1/3] feat(word-count): scene budget check with expand/compress rewrite (#15) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add word_count service: check_word_budget() + build_rewrite_prompt() - POST /api/generate/word-count-check: budget status endpoint - POST /api/generate/rewrite: SSE streaming expand/compress via LLM - Tolerance ±15%, rejects rewrite if already within budget - Input validation: target_chars 100-50000, mode expand|compress - 16 tests: 8 unit (budget logic) + 2 prompt + 6 API endpoint - Frontend TypeScript types: WordCountCheck, RewriteMode Co-Authored-By: Claude Opus 4.6 --- backend/app/api/ai_schemas.py | 24 +++ backend/app/api/generation.py | 64 +++++++- backend/app/services/word_count.py | 72 +++++++++ backend/tests/test_word_count.py | 238 +++++++++++++++++++++++++++++ frontend/src/lib/types.ts | 12 ++ 5 files changed, 409 insertions(+), 1 deletion(-) create mode 100644 backend/app/services/word_count.py create mode 100644 backend/tests/test_word_count.py diff --git a/backend/app/api/ai_schemas.py b/backend/app/api/ai_schemas.py index 6086cdf..cc2d259 100644 --- a/backend/app/api/ai_schemas.py +++ b/backend/app/api/ai_schemas.py @@ -47,3 +47,27 @@ class ChapterSummaryModel(BaseModel): plot_threads: list[str] = Field( default_factory=list, description="情节线索" ) + + +class RewriteRequest(BaseModel): + scene_id: int + text: str = Field(description="当前场景正文") + target_chars: int = Field( + default=1500, ge=100, le=50000, + description="目标字数", + ) + mode: str = Field( + description="重写模式: expand 或 compress", + pattern="^(expand|compress)$", + ) + + +class WordCountCheck(BaseModel): + status: str = Field(description="within / over / under") + actual_chars: int + target_chars: int + delta: int + deviation: float + suggestion: str | None = Field( + default=None, description="compress / expand / None" + ) diff --git a/backend/app/api/generation.py b/backend/app/api/generation.py index 7424153..db964cb 100644 --- a/backend/app/api/generation.py +++ b/backend/app/api/generation.py @@ -6,7 +6,13 @@ from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession -from app.api.ai_schemas import SceneCard, SceneCardRequest, SceneDraftRequest +from app.api.ai_schemas import ( + RewriteRequest, + SceneCard, + SceneCardRequest, + SceneDraftRequest, + WordCountCheck, +) from app.core.config import settings from app.core.database import get_db from app.core.llm import call_llm_stream, instructor_client @@ -115,3 +121,59 @@ async def event_stream(): return StreamingResponse( event_stream(), media_type="text/event-stream" ) + + +@router.post("/word-count-check", response_model=WordCountCheck) +async def word_count_check(req: RewriteRequest): + """Check if scene text fits within the target char budget.""" + from app.services.word_count import check_word_budget + + return check_word_budget(req.text, req.target_chars) + + +@router.post("/rewrite") +async def rewrite_scene( + req: RewriteRequest, db: AsyncSession = Depends(get_db) +): + """Expand or compress scene text to fit target char budget via SSE.""" + scene = await db.get(Scene, req.scene_id) + if not scene: + raise HTTPException(404, "Scene not found") + + from app.services.word_count import build_rewrite_prompt, check_word_budget + + budget = check_word_budget(req.text, req.target_chars) + if budget["status"] == "within": + raise HTTPException( + 400, + f"Text already within budget " + f"(deviation={budget['deviation']:.1%})", + ) + + prompt = build_rewrite_prompt( + req.text, req.target_chars, req.mode + ) + + async def event_stream(): + total_text = "" + async for chunk in call_llm_stream( + messages=[{"role": "user", "content": prompt}] + ): + total_text += chunk + payload = json.dumps( + {"text": chunk}, ensure_ascii=False + ) + yield f"data: {payload}\n\n" + + result = check_word_budget(total_text, req.target_chars) + done_data = { + "done": True, + "char_count": len(total_text), + "budget": result, + } + payload = json.dumps(done_data, ensure_ascii=False) + yield f"data: {payload}\n\n" + + return StreamingResponse( + event_stream(), media_type="text/event-stream" + ) diff --git a/backend/app/services/word_count.py b/backend/app/services/word_count.py new file mode 100644 index 0000000..59831da --- /dev/null +++ b/backend/app/services/word_count.py @@ -0,0 +1,72 @@ +"""Word count budget checking for scene generation.""" + +from typing import Literal + + +def check_word_budget( + text: str, + target_chars: int, + tolerance: float = 0.15, +) -> dict: + """Check if text length is within the target budget. + + Returns dict with: + status: 'within' | 'over' | 'under' + actual_chars: int + target_chars: int + delta: int (actual - target) + deviation: float (abs ratio) + suggestion: 'compress' | 'expand' | None + """ + actual = len(text) + delta = actual - target_chars + if target_chars <= 0: + deviation = float(actual) if actual > 0 else 0.0 + else: + deviation = abs(delta) / target_chars + + if deviation <= tolerance: + status: Literal["within", "over", "under"] = "within" + suggestion = None + elif delta > 0: + status = "over" + suggestion = "compress" + else: + status = "under" + suggestion = "expand" + + return { + "status": status, + "actual_chars": actual, + "target_chars": target_chars, + "delta": delta, + "deviation": round(deviation, 3), + "suggestion": suggestion, + } + + +def build_rewrite_prompt( + text: str, + target_chars: int, + mode: str, +) -> str: + """Build a prompt for expanding or compressing scene text.""" + actual = len(text) + if mode == "expand": + return ( + f"你是一位专业的中文小说作家。以下场景正文目前有 {actual} 字," + f"目标是约 {target_chars} 字。" + "请在保持原有情节和人物不变的前提下,扩写以下内容," + "增加细节描写、对话或环境描写。\n\n" + f"{text}\n\n" + "请直接输出扩写后的完整场景正文,不要输出任何标记或说明。" + ) + else: + return ( + f"你是一位专业的中文小说编辑。以下场景正文目前有 {actual} 字," + f"目标是约 {target_chars} 字。" + "请在保持核心情节和关键对话不变的前提下,精简以下内容," + "删除冗余描写和不必要的细节。\n\n" + f"{text}\n\n" + "请直接输出精简后的完整场景正文,不要输出任何标记或说明。" + ) diff --git a/backend/tests/test_word_count.py b/backend/tests/test_word_count.py new file mode 100644 index 0000000..cd1cc44 --- /dev/null +++ b/backend/tests/test_word_count.py @@ -0,0 +1,238 @@ +"""Tests for word count budget check and rewrite endpoint.""" + +import pytest + +from app.services.word_count import build_rewrite_prompt, check_word_budget + +# ---------- Helpers ---------- + +async def _setup_project(client) -> int: + resp = await client.post("/api/projects", json={"title": "WC Test"}) + return resp.json()["id"] + + +async def _setup_book(client, project_id: int) -> int: + resp = await client.post( + "/api/books", json={"project_id": project_id, "title": "Book 1"} + ) + return resp.json()["id"] + + +async def _setup_chapter(client, book_id: int) -> int: + resp = await client.post( + "/api/chapters", + json={"book_id": book_id, "title": "Chapter 1", "sort_order": 0}, + ) + return resp.json()["id"] + + +async def _setup_scene(client, chapter_id: int) -> int: + resp = await client.post( + "/api/scenes", + json={"chapter_id": chapter_id, "title": "Scene 1"}, + ) + return resp.json()["id"] + + +# ---------- Unit tests: check_word_budget ---------- + +def test_within_budget(): + """Text within ±15% tolerance returns 'within'.""" + text = "a" * 1000 + result = check_word_budget(text, target_chars=1000) + assert result["status"] == "within" + assert result["actual_chars"] == 1000 + assert result["delta"] == 0 + assert result["deviation"] == 0.0 + assert result["suggestion"] is None + + +def test_within_budget_at_tolerance_boundary(): + """Text at exactly 15% over is still 'within'.""" + text = "a" * 1150 # 15% over 1000 + result = check_word_budget(text, target_chars=1000, tolerance=0.15) + assert result["status"] == "within" + assert result["suggestion"] is None + + +def test_over_budget(): + """Text exceeding tolerance suggests compression.""" + text = "a" * 1200 # 20% over 1000 + result = check_word_budget(text, target_chars=1000) + assert result["status"] == "over" + assert result["delta"] == 200 + assert result["suggestion"] == "compress" + + +def test_under_budget(): + """Text below tolerance suggests expansion.""" + text = "a" * 700 # 30% under 1000 + result = check_word_budget(text, target_chars=1000) + assert result["status"] == "under" + assert result["delta"] == -300 + assert result["suggestion"] == "expand" + + +def test_custom_tolerance(): + """Custom tolerance changes the threshold.""" + text = "a" * 900 # 10% under + # Default tolerance 15% → within + assert check_word_budget(text, 1000)["status"] == "within" + # Strict tolerance 5% → under + assert check_word_budget(text, 1000, tolerance=0.05)["status"] == "under" + + +def test_zero_target(): + """Zero target_chars doesn't crash.""" + result = check_word_budget("hello", target_chars=0) + assert result["status"] == "over" + assert result["actual_chars"] == 5 + + +def test_empty_text(): + """Empty text is under budget.""" + result = check_word_budget("", target_chars=1000) + assert result["status"] == "under" + assert result["actual_chars"] == 0 + assert result["suggestion"] == "expand" + + +def test_deviation_precision(): + """Deviation is rounded to 3 decimal places.""" + text = "a" * 1234 + result = check_word_budget(text, target_chars=1000) + assert isinstance(result["deviation"], float) + assert result["deviation"] == 0.234 + + +# ---------- Unit tests: build_rewrite_prompt ---------- + +def test_expand_prompt_contains_key_info(): + """Expand prompt includes actual chars, target, and the text.""" + prompt = build_rewrite_prompt("Some text.", 2000, "expand") + assert "10 字" in prompt + assert "2000 字" in prompt + assert "Some text." in prompt + assert "扩写" in prompt + + +def test_compress_prompt_contains_key_info(): + """Compress prompt includes actual chars, target, and the text.""" + prompt = build_rewrite_prompt("Some long text.", 500, "compress") + assert "15 字" in prompt + assert "500 字" in prompt + assert "Some long text." in prompt + assert "精简" in prompt + + +# ---------- API tests ---------- + +@pytest.mark.asyncio +async def test_word_count_check_endpoint(client): + """POST /api/generate/word-count-check returns budget status.""" + pid = await _setup_project(client) + bid = await _setup_book(client, pid) + cid = await _setup_chapter(client, bid) + sid = await _setup_scene(client, cid) + + resp = await client.post( + "/api/generate/word-count-check", + json={ + "scene_id": sid, + "text": "a" * 1200, + "target_chars": 1000, + "mode": "compress", + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "over" + assert data["actual_chars"] == 1200 + assert data["suggestion"] == "compress" + + +@pytest.mark.asyncio +async def test_word_count_check_within(client): + """Within-budget text returns status='within'.""" + pid = await _setup_project(client) + bid = await _setup_book(client, pid) + cid = await _setup_chapter(client, bid) + sid = await _setup_scene(client, cid) + + resp = await client.post( + "/api/generate/word-count-check", + json={ + "scene_id": sid, + "text": "a" * 1000, + "target_chars": 1000, + "mode": "expand", + }, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "within" + + +@pytest.mark.asyncio +async def test_rewrite_rejects_within_budget(client): + """Rewrite returns 400 if text is already within budget.""" + pid = await _setup_project(client) + bid = await _setup_book(client, pid) + cid = await _setup_chapter(client, bid) + sid = await _setup_scene(client, cid) + + resp = await client.post( + "/api/generate/rewrite", + json={ + "scene_id": sid, + "text": "a" * 1000, + "target_chars": 1000, + "mode": "compress", + }, + ) + assert resp.status_code == 400 + assert "within budget" in resp.json()["detail"] + + +@pytest.mark.asyncio +async def test_rewrite_invalid_scene(client): + """Rewrite returns 404 for non-existent scene.""" + resp = await client.post( + "/api/generate/rewrite", + json={ + "scene_id": 999999, + "text": "a" * 2000, + "target_chars": 1000, + "mode": "compress", + }, + ) + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_rewrite_invalid_mode(client): + """Rewrite returns 422 for invalid mode.""" + resp = await client.post( + "/api/generate/rewrite", + json={ + "scene_id": 1, + "text": "some text", + "target_chars": 1000, + "mode": "invalid", + }, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_rewrite_target_chars_validation(client): + """target_chars must be between 100 and 50000.""" + resp = await client.post( + "/api/generate/word-count-check", + json={ + "scene_id": 1, + "text": "some text", + "target_chars": 50, + "mode": "expand", + }, + ) + assert resp.status_code == 422 diff --git a/frontend/src/lib/types.ts b/frontend/src/lib/types.ts index 313c36d..44fa843 100644 --- a/frontend/src/lib/types.ts +++ b/frontend/src/lib/types.ts @@ -141,3 +141,15 @@ export interface KGEdge { relation: string properties: Record } + +export type WordCountStatus = 'within' | 'over' | 'under' +export type RewriteMode = 'expand' | 'compress' + +export interface WordCountCheck { + status: WordCountStatus + actual_chars: number + target_chars: number + delta: number + deviation: number + suggestion: RewriteMode | null +} From d39e41fa46537c176eb213f43d0feca73df0f096 Mon Sep 17 00:00:00 2001 From: DankerMu Date: Sat, 21 Feb 2026 12:51:44 +0800 Subject: [PATCH 2/3] fix(word-count): address review findings from agent team - H1: validate mode matches budget direction (reject mismatch with 400) - M1: create dedicated WordCountCheckRequest schema (no scene_id/mode) - M2: add max_length=100000 to text field - L1: move word_count imports to module top level - L3: add mode/direction mismatch test All 17 tests passing. Co-Authored-By: Claude Opus 4.6 --- backend/app/api/ai_schemas.py | 14 ++++++++- backend/app/api/generation.py | 16 ++++++---- backend/tests/test_word_count.py | 50 +++++++++++++++----------------- 3 files changed, 48 insertions(+), 32 deletions(-) diff --git a/backend/app/api/ai_schemas.py b/backend/app/api/ai_schemas.py index cc2d259..10f9715 100644 --- a/backend/app/api/ai_schemas.py +++ b/backend/app/api/ai_schemas.py @@ -51,7 +51,9 @@ class ChapterSummaryModel(BaseModel): class RewriteRequest(BaseModel): scene_id: int - text: str = Field(description="当前场景正文") + text: str = Field( + description="当前场景正文", max_length=100000 + ) target_chars: int = Field( default=1500, ge=100, le=50000, description="目标字数", @@ -62,6 +64,16 @@ class RewriteRequest(BaseModel): ) +class WordCountCheckRequest(BaseModel): + text: str = Field( + description="当前场景正文", max_length=100000 + ) + target_chars: int = Field( + default=1500, ge=100, le=50000, + description="目标字数", + ) + + class WordCountCheck(BaseModel): status: str = Field(description="within / over / under") actual_chars: int diff --git a/backend/app/api/generation.py b/backend/app/api/generation.py index db964cb..aeb94ab 100644 --- a/backend/app/api/generation.py +++ b/backend/app/api/generation.py @@ -12,6 +12,7 @@ SceneCardRequest, SceneDraftRequest, WordCountCheck, + WordCountCheckRequest, ) from app.core.config import settings from app.core.database import get_db @@ -21,6 +22,7 @@ assemble_context_pack, get_scene_project_id, ) +from app.services.word_count import build_rewrite_prompt, check_word_budget router = APIRouter(prefix="/api/generate", tags=["generation"]) @@ -124,10 +126,8 @@ async def event_stream(): @router.post("/word-count-check", response_model=WordCountCheck) -async def word_count_check(req: RewriteRequest): +async def word_count_check(req: WordCountCheckRequest): """Check if scene text fits within the target char budget.""" - from app.services.word_count import check_word_budget - return check_word_budget(req.text, req.target_chars) @@ -140,8 +140,6 @@ async def rewrite_scene( if not scene: raise HTTPException(404, "Scene not found") - from app.services.word_count import build_rewrite_prompt, check_word_budget - budget = check_word_budget(req.text, req.target_chars) if budget["status"] == "within": raise HTTPException( @@ -150,6 +148,14 @@ async def rewrite_scene( f"(deviation={budget['deviation']:.1%})", ) + expected_mode = budget["suggestion"] + if req.mode != expected_mode: + raise HTTPException( + 400, + f"Mode '{req.mode}' conflicts with budget status " + f"'{budget['status']}': expected '{expected_mode}'", + ) + prompt = build_rewrite_prompt( req.text, req.target_chars, req.mode ) diff --git a/backend/tests/test_word_count.py b/backend/tests/test_word_count.py index cd1cc44..e8d6316 100644 --- a/backend/tests/test_word_count.py +++ b/backend/tests/test_word_count.py @@ -130,19 +130,9 @@ def test_compress_prompt_contains_key_info(): @pytest.mark.asyncio async def test_word_count_check_endpoint(client): """POST /api/generate/word-count-check returns budget status.""" - pid = await _setup_project(client) - bid = await _setup_book(client, pid) - cid = await _setup_chapter(client, bid) - sid = await _setup_scene(client, cid) - resp = await client.post( "/api/generate/word-count-check", - json={ - "scene_id": sid, - "text": "a" * 1200, - "target_chars": 1000, - "mode": "compress", - }, + json={"text": "a" * 1200, "target_chars": 1000}, ) assert resp.status_code == 200 data = resp.json() @@ -154,19 +144,9 @@ async def test_word_count_check_endpoint(client): @pytest.mark.asyncio async def test_word_count_check_within(client): """Within-budget text returns status='within'.""" - pid = await _setup_project(client) - bid = await _setup_book(client, pid) - cid = await _setup_chapter(client, bid) - sid = await _setup_scene(client, cid) - resp = await client.post( "/api/generate/word-count-check", - json={ - "scene_id": sid, - "text": "a" * 1000, - "target_chars": 1000, - "mode": "expand", - }, + json={"text": "a" * 1000, "target_chars": 1000}, ) assert resp.status_code == 200 assert resp.json()["status"] == "within" @@ -228,11 +208,29 @@ async def test_rewrite_target_chars_validation(client): """target_chars must be between 100 and 50000.""" resp = await client.post( "/api/generate/word-count-check", + json={"text": "some text", "target_chars": 50}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_rewrite_mode_direction_mismatch(client): + """Rewrite rejects mode that conflicts with budget direction.""" + pid = await _setup_project(client) + bid = await _setup_book(client, pid) + cid = await _setup_chapter(client, bid) + sid = await _setup_scene(client, cid) + + # Text is 20% over budget → suggestion is 'compress' + # but we send mode='expand' → should be rejected + resp = await client.post( + "/api/generate/rewrite", json={ - "scene_id": 1, - "text": "some text", - "target_chars": 50, + "scene_id": sid, + "text": "a" * 1200, + "target_chars": 1000, "mode": "expand", }, ) - assert resp.status_code == 422 + assert resp.status_code == 400 + assert "conflicts" in resp.json()["detail"] From f5919d6ce09b017c4aa51fb3282e2984bfd5dd52 Mon Sep 17 00:00:00 2001 From: DankerMu Date: Sat, 21 Feb 2026 13:02:38 +0800 Subject: [PATCH 3/3] fix(word-count): use Literal type for rewrite mode field Co-Authored-By: Claude Opus 4.6 --- backend/app/api/ai_schemas.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/app/api/ai_schemas.py b/backend/app/api/ai_schemas.py index 10f9715..afac0fa 100644 --- a/backend/app/api/ai_schemas.py +++ b/backend/app/api/ai_schemas.py @@ -1,5 +1,7 @@ """AI generation schemas: SceneCard, SceneDraft, Context Pack.""" +from typing import Literal + from pydantic import BaseModel, Field @@ -58,9 +60,8 @@ class RewriteRequest(BaseModel): default=1500, ge=100, le=50000, description="目标字数", ) - mode: str = Field( + mode: Literal["expand", "compress"] = Field( description="重写模式: expand 或 compress", - pattern="^(expand|compress)$", )