-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
156 lines (131 loc) · 5.25 KB
/
main.py
File metadata and controls
156 lines (131 loc) · 5.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import json
from pydantic import BaseModel
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from starlette.concurrency import iterate_in_threadpool
from src.services.audio_io import temp_audio_file
from src.services.speech_pipeline import analyze_speech_stream
from src.models.stt_whisper import get_whisperx_models
from src.models.phoneme import load_phoneme_models
from src.services.conversation_pipeline import conversation_stream
# Application instance that the route and event decorators below attach to.
app = FastAPI(title="Speech Analysis API")
# Module-level globals: assigned by startup_event() and read by the /analyze
# endpoint. They stay None until the startup handler has run.
loaded_models = None
phoneme_models = None
@app.on_event("startup")
async def startup_event():
    """Populate the module-level model globals when the application starts.

    Calls ``get_whisperx_models`` and ``load_phoneme_models`` once and stores
    their results in ``loaded_models`` and ``phoneme_models`` so request
    handlers can reuse them. Progress is reported with ``print``.
    """
    global loaded_models, phoneme_models

    whisperx_options = {"model_name": "small.en", "vad_method": "silero"}
    # NOTE(review): a local checkpoint directory ("src/fine_tuned_model") was
    # listed as an alternative source here -- confirm which one is intended.
    phoneme_source = "wishkim/wav2vec2-l2arctic-phoneme"

    print("⏳ 모델 로딩 중...")
    loaded_models = get_whisperx_models(**whisperx_options)
    phoneme_models = load_phoneme_models(phoneme_source)
    print("✅ 모델 로딩 완료!")
@app.on_event("shutdown")
async def shutdown_event():
    """Drop the references to the loaded models when the server shuts down.

    Both globals assigned by the startup handler are reset to ``None`` so
    neither model bundle is kept alive by this module after shutdown.
    """
    global loaded_models, phoneme_models
    loaded_models = None
    # The startup handler also fills phoneme_models; release it as well.
    phoneme_models = None
    print("🛑 서버 종료")
def _ndjson_event(event_type, task_id, status, error=None, result=None):
    """Serialize one response envelope as a newline-terminated JSON line."""
    return json.dumps({
        "type": event_type,
        "taskId": task_id,
        "status": status,
        "error": error,
        "analysisResult": result,
    }, ensure_ascii=False) + "\n"


def _wrap_analysis_chunk(chunk, task_id):
    """Convert one raw pipeline chunk (a JSON string) into an envelope line.

    Raises whatever ``json.loads`` or the payload access raises; the caller
    turns such failures into a generic FAIL event.
    """
    payload = json.loads(chunk)
    event_type = payload.get("type")

    if event_type == "error":
        task = payload.get("task")
        if task in ("pron", "inton", "feedback"):
            # Task-level errors keep their task as the type ("feedback" is
            # reported as "llm") and are marked FAIL.
            failed_type = "llm" if task == "feedback" else task
        else:
            # Errors without a known task (early-stage errors, including
            # WhisperX) are reported with the generic "error" type.
            failed_type = "error"
        return _ndjson_event(
            failed_type, task_id, "FAIL", error=payload.get("message")
        )

    # Successful result: "feedback" is exposed to clients as "llm".
    if event_type == "feedback":
        event_type = "llm"
    return _ndjson_event(
        event_type, task_id, "SUCCESS", result=payload.get("data")
    )


@app.post("/analyze")
async def analyze(
    file: UploadFile = File(...),
    taskId: str = Form(...),
    analysisRequest: str = Form(...)
):
    """Run the speech analysis pipeline on an uploaded file and stream results.

    Args:
        file: Uploaded audio; its bytes are written to a temporary ``.wav``
            file for the duration of the stream.
        taskId: Identifier echoed back in every streamed event.
        analysisRequest: JSON-encoded string passed to the pipeline as
            ``analysis_request`` after decoding.

    Returns:
        A ``StreamingResponse`` (``application/x-ndjson``) with one JSON
        object per line, each carrying ``type``, ``taskId``, ``status``,
        ``error`` and ``analysisResult``. If the upload has no filename or
        ``analysisRequest`` is not valid JSON, a plain ``{"error": ...}``
        dict is returned instead.
    """
    if not file.filename:
        return {"error": "파일 이름이 없습니다"}

    # Reject malformed JSON up front instead of failing with an unhandled
    # exception; the message reads "analysisRequest is not valid JSON".
    try:
        analysis_request = json.loads(analysisRequest)
    except json.JSONDecodeError:
        return {"error": "analysisRequest가 올바른 JSON이 아닙니다"}

    # 1. Read the whole upload into memory (bytes).
    audio_bytes = await file.read()

    # 2. Generator wrapper: keeps the temporary file alive while streaming.
    def stream_with_cleanup():
        try:
            with temp_audio_file(audio_bytes, suffix=".wav") as audio_path:
                for chunk in analyze_speech_stream(
                    audio_path=audio_path,
                    loaded_models=loaded_models,
                    phoneme_models=phoneme_models,
                    analysis_request=analysis_request,
                    mode="all"
                ):
                    try:
                        line = _wrap_analysis_chunk(chunk, taskId)
                    except Exception as e:
                        # Parsing/wrapping failures are reported as FAIL and
                        # the stream continues with the next chunk.
                        line = _ndjson_event(
                            "error", taskId, "FAIL", error=str(e)
                        )
                    yield line
        except Exception as e:
            # The pipeline (or the temp-file handling) itself failed: emit a
            # final FAIL event rather than cutting the response off silently.
            yield _ndjson_event("error", taskId, "FAIL", error=str(e))

    # 3. Return the streaming response.
    return StreamingResponse(
        stream_with_cleanup(),
        media_type="application/x-ndjson"
    )
# Pydantic schema with task id, file path, status, error and result fields.
# NOTE(review): no endpoint visible in this file references this model -- the
# /conversation handler builds a plain dict instead; confirm it is still needed.
class ConversationRequest(BaseModel):
    taskId: str
    filePath: str
    status: str | None = "SUCCESS"  # defaults to "SUCCESS" when omitted
    error: str | None = None
    analysisResult: dict | None = None
@app.post("/conversation")
async def conversation(
    file: UploadFile = File(...),
    taskId: str = Form(...),
    analysisRequest: str | None = Form(None),
):
    """Stream the conversation pipeline's output for an uploaded audio file.

    Args:
        file: Uploaded audio; its bytes are written to a temporary ``.wav``
            file whose path is handed to the pipeline.
        taskId: Identifier included in the payload passed to the pipeline.
        analysisRequest: Optional JSON-encoded string; when missing or empty
            an empty dict is used.

    Returns:
        A ``StreamingResponse`` (``application/x-ndjson``) that forwards the
        chunks produced by ``conversation_stream`` unchanged. If the upload
        has no filename or ``analysisRequest`` is not valid JSON, a plain
        ``{"error": ...}`` dict is returned instead.
    """
    if not file.filename:
        return {"error": "파일 이름이 없습니다"}

    # Reject malformed JSON up front instead of failing with an unhandled
    # exception; the message reads "analysisRequest is not valid JSON".
    try:
        analysis_req = json.loads(analysisRequest) if analysisRequest else {}
    except json.JSONDecodeError:
        return {"error": "analysisRequest가 올바른 JSON이 아닙니다"}

    audio_bytes = await file.read()

    def stream_with_cleanup():
        # The temporary file must outlive the whole stream, so the context
        # manager is entered inside the generator.
        with temp_audio_file(audio_bytes, suffix=".wav") as audio_path:
            payload = {
                "taskId": taskId,
                "filePath": audio_path,  # path of the temporary wav file
                "status": "SUCCESS",
                "analysisRequest": analysis_req
            }
            yield from conversation_stream(payload)

    return StreamingResponse(
        stream_with_cleanup(),
        media_type="application/x-ndjson"
    )
# Script entry point: serve the app with uvicorn when this file is run directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn
    # Listens on all interfaces ("0.0.0.0") on port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)