-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathproofreading_service.py
More file actions
96 lines (79 loc) · 3.58 KB
/
proofreading_service.py
File metadata and controls
96 lines (79 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import logging
import os
from typing import Optional
import anthropic
import asyncio
from models import ProofreadPrompt, SystemSetting, Transcription
from database import db
from openai import AsyncOpenAI, OpenAI
logger = logging.getLogger(__name__)


class ProofreadingService:
    """Proofread a transcription's text file with a DeepSeek chat model.

    The text is split into word-bounded parts, every part is sent to the
    model concurrently together with the active proofread prompt, and the
    responses are joined (in the original order) into the output file.
    """

    # Maximum number of whitespace-separated words sent in a single request.
    MAX_WORDS_PER_PART = 500

    def __init__(self):
        """Read the DeepSeek settings from the environment and build the client.

        ``DEEPSEEK_API_KEY`` and ``DEEPSEEK_BASE_URL`` are read with
        ``os.environ.get``, so either attribute is ``None`` when unset.
        """
        self.deepseek_api_key = os.environ.get('DEEPSEEK_API_KEY')
        self.deepseek_base_url = os.environ.get('DEEPSEEK_BASE_URL')
        # NOTE(review): this single async client is reused by every
        # ``asyncio.run`` call made in ``proofread``; each of those runs on a
        # fresh event loop. Confirm that repeated ``proofread`` calls on one
        # service instance work with the installed openai/httpx versions.
        self.async_deepseek = AsyncOpenAI(
            api_key=self.deepseek_api_key,
            base_url=self.deepseek_base_url,
        )

    async def process_part(self, part: str, prompt: str) -> str:
        """Process a single part of text asynchronously.

        Args:
            part: The chunk of text, sent as the user message.
            prompt: The proofreading instructions, sent as the system message.

        Returns:
            The message content of the first choice in the response.

        Raises:
            ValueError: If the response carries no message content.
        """
        response = await self.async_deepseek.chat.completions.create(
            model="deepseek-reasoner",
            max_tokens=8192,
            temperature=0,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": part},
            ],
            stream=False,
        )
        content = response.choices[0].message.content
        if content is None:
            # Fail with a clear message here instead of letting a None reach
            # the final " ".join(...) and surface as an opaque TypeError.
            raise ValueError("Model returned no content for a text part")
        return content

    async def process_all_parts(self, parts: list[str], prompt: str) -> list[str]:
        """Process all parts concurrently while maintaining order.

        Args:
            parts: Text chunks to proofread.
            prompt: System prompt applied to every chunk.

        Returns:
            One processed string per input part, in the same order as
            ``parts`` (``asyncio.gather`` preserves argument order).
        """
        tasks = [self.process_part(part, prompt) for part in parts]
        results = await asyncio.gather(*tasks)
        return results

    @staticmethod
    def _split_into_parts(content: str, max_words: int) -> list[str]:
        """Split ``content`` into space-joined chunks of at most ``max_words`` words.

        Words are whatever ``str.split()`` yields, so runs of whitespace
        (including newlines) collapse into single spaces. Empty or
        whitespace-only content yields an empty list.

        Raises:
            ValueError: If ``max_words`` is smaller than 1.
        """
        if max_words < 1:
            raise ValueError("max_words must be at least 1")
        words = content.split()
        return [
            ' '.join(words[start:start + max_words])
            for start in range(0, len(words), max_words)
        ]

    def proofread(
        self, transcription: "Transcription", output_path: str
    ) -> tuple[bool, Optional[str], Optional[str]]:
        """Proofread the transcription's text document and write the result.

        Steps: read ``transcription.txt_document_path``, split it into parts,
        look up the active proofread prompt, store that prompt text on the
        transcription and commit, run all parts through the model, then write
        the space-joined output to ``output_path``.

        Args:
            transcription: Record whose ``txt_document_path`` is read and
                whose ``proofread_prompt`` is set to the prompt text used.
            output_path: Destination file for the proofread text.

        Returns:
            ``(True, output_path, None)`` on success, or
            ``(False, None, error_message)`` when any step raises.
        """
        try:
            with open(transcription.txt_document_path, "r", encoding='utf-8') as f:
                content = f.read()

            parts = self._split_into_parts(content, self.MAX_WORDS_PER_PART)

            setting = SystemSetting.query.filter_by(
                setting_key='active_proofread_prompt_id').first()
            if not setting:
                raise ValueError("No active proofread prompt set")

            proofread_prompt = ProofreadPrompt.query.get(setting.setting_value)
            if not proofread_prompt:
                raise ValueError("Active proofread prompt not found")

            # Record which prompt text was used before calling the model.
            transcription.proofread_prompt = proofread_prompt.prompt
            db.session.commit()

            # Process all parts asynchronously
            processed_parts = asyncio.run(
                self.process_all_parts(parts, proofread_prompt.prompt))

            # Combine all processed parts
            combined_output = " ".join(processed_parts)
            with open(output_path, 'w', encoding='utf-8') as file:
                file.write(combined_output)

            return True, output_path, None
        except Exception as e:
            # Boundary handler: callers receive the error as a string, so keep
            # the traceback in the log before it is discarded.
            logger.exception("Proofreading failed (output_path=%s)", output_path)
            return False, None, str(e)