From b7a9fb83854ed277ed117cb01d26007e6a776bbe Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 4 Jun 2026 16:38:27 -0400 Subject: [PATCH 1/4] Add dashboard diarize action --- internal/api/handlers.go | 87 ++++++++++++++++++ internal/api/router.go | 1 + internal/models/transcription.go | 7 +- internal/transcription/unified_service.go | 58 +++++++++--- .../src/components/TranscribeDDialog.tsx | 17 +++- .../TranscriptionConfigDialog.tsx | 1 + .../src/components/ui/swipeable-item.tsx | 26 +++++- .../components/AudioFilesTable.tsx | 89 +++++++++++++++++++ 8 files changed, 262 insertions(+), 24 deletions(-) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 1b94beef4..3637605ef 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -1033,6 +1033,64 @@ func (h *Handler) StartTranscription(c *gin.Context) { c.JSON(http.StatusOK, job) } +// @Summary Start diarization for a completed transcription +// @Description Run speaker diarization against an existing transcript without rerunning transcription +// @Tags transcription +// @Accept json +// @Produce json +// @Param id path string true "Job ID" +// @Param parameters body models.WhisperXParams true "Diarization parameters" +// @Success 200 {object} models.TranscriptionJob +// @Failure 400 {object} map[string]string +// @Failure 404 {object} map[string]string +// @Router /api/v1/transcription/{id}/diarize [post] +// @Security ApiKeyAuth +// @Security BearerAuth +func (h *Handler) StartDiarization(c *gin.Context) { + jobID := c.Param("id") + + job, err := h.getJobForDiarization(c, jobID) + if err != nil { + return + } + + requestParams, err := h.getValidatedTranscriptionParams(c, job, jobID) + if err != nil { + return + } + + requestParams.Diarize = true + requestParams.DiarizationOnly = true + requestParams.IsMultiTrackEnabled = false + if requestParams.DiarizeModel == "" { + requestParams.DiarizeModel = transcription.ModelPyannote + } + + job.Parameters = *requestParams + job.Diarization = true + job.Status = models.StatusPending + job.ErrorMessage = nil + + if err := h.jobRepo.Update(c.Request.Context(), job); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update job"}) + return + } + + if err := h.taskQueue.EnqueueJob(jobID); err != nil { + logger.Error("Failed to enqueue diarization job", "job_id", jobID, "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to enqueue diarization job"}) + return + } + + logger.Info("Diarization job started", + "job_id", jobID, + "diarize_model", requestParams.DiarizeModel, + "min_speakers", requestParams.MinSpeakers, + "max_speakers", requestParams.MaxSpeakers) + + c.JSON(http.StatusOK, job) +} + func (h *Handler) getJobForTranscription(c *gin.Context, jobID string) (*models.TranscriptionJob, error) { job, err := h.jobRepo.FindByID(c.Request.Context(), jobID) if err != nil { @@ -1052,6 +1110,35 @@ func (h *Handler) getJobForTranscription(c *gin.Context, jobID string) (*models. return job, nil } +func (h *Handler) getJobForDiarization(c *gin.Context, jobID string) (*models.TranscriptionJob, error) { + job, err := h.jobRepo.FindByID(c.Request.Context(), jobID) + if err != nil { + if err == gorm.ErrRecordNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "Job not found"}) + return nil, err + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get job"}) + return nil, err + } + + if job.Status == models.StatusProcessing || job.Status == models.StatusPending { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cannot start diarization: job is currently processing or pending"}) + return nil, fmt.Errorf("invalid job status") + } + + if job.IsMultiTrack { + c.JSON(http.StatusBadRequest, gin.H{"error": "Standalone diarization is not supported for multi-track recordings"}) + return nil, fmt.Errorf("multi-track diarization unsupported") + } + + if job.Transcript == nil || strings.TrimSpace(*job.Transcript) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cannot start diarization: job does not have an existing transcript"}) + return nil, fmt.Errorf("missing transcript") + } + + return job, nil +} + func (h *Handler) getValidatedTranscriptionParams(c *gin.Context, job *models.TranscriptionJob, jobID string) (*models.WhisperXParams, error) { // Set defaults requestParams := models.WhisperXParams{ diff --git a/internal/api/router.go b/internal/api/router.go index b5bec3224..33061fdf7 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -132,6 +132,7 @@ func SetupRoutes(handler *Handler, authService *auth.AuthService) *gin.Engine { transcription.POST("/youtube", handler.DownloadFromYouTube) transcription.POST("/submit", handler.SubmitJob) transcription.POST("/:id/start", handler.StartTranscription) + transcription.POST("/:id/diarize", handler.StartDiarization) transcription.POST("/:id/kill", handler.KillJob) transcription.GET("/:id/logs", handler.GetJobLogs) transcription.GET("/:id/status", handler.GetJobStatus) diff --git a/internal/models/transcription.go b/internal/models/transcription.go index 04285b08e..b325cb74f 100644 --- a/internal/models/transcription.go +++ b/internal/models/transcription.go @@ -85,6 +85,7 @@ type WhisperXParams struct { // Diarization settings Diarize bool `json:"diarize" gorm:"type:boolean;default:false"` + DiarizationOnly bool `json:"diarization_only" gorm:"type:boolean;default:false"` MinSpeakers *int `json:"min_speakers,omitempty" gorm:"type:int"` MaxSpeakers *int `json:"max_speakers,omitempty" gorm:"type:int"` DiarizeModel string `json:"diarize_model" gorm:"type:varchar(50);default:'pyannote'"` // Options: 'pyannote', 'nvidia_sortformer' @@ -209,10 +210,10 @@ func (tp *TranscriptionProfile) BeforeSave(tx *gorm.DB) error { // LLMConfig represents LLM configuration settings type LLMConfig struct { ID uint `json:"id" gorm:"primaryKey"` - Provider string `json:"provider" gorm:"not null;type:varchar(50)"` // "ollama" or "openai" - BaseURL *string `json:"base_url,omitempty" gorm:"type:text"` // For Ollama + Provider string `json:"provider" gorm:"not null;type:varchar(50)"` // "ollama" or "openai" + BaseURL *string `json:"base_url,omitempty" gorm:"type:text"` // For Ollama OpenAIBaseURL *string `json:"openai_base_url,omitempty" gorm:"type:text"` // For OpenAI custom endpoint - APIKey *string `json:"api_key,omitempty" gorm:"type:text"` // For OpenAI (encrypted) + APIKey *string `json:"api_key,omitempty" gorm:"type:text"` // For OpenAI (encrypted) IsActive bool `json:"is_active" gorm:"type:boolean;default:false"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` diff --git a/internal/transcription/unified_service.go b/internal/transcription/unified_service.go index 1129353a0..bbfe39bf8 100644 --- a/internal/transcription/unified_service.go +++ b/internal/transcription/unified_service.go @@ -287,6 +287,17 @@ func (u *UnifiedTranscriptionService) processSingleTrackJob(ctx context.Context, var transcriptResult *interfaces.TranscriptResult var diarizationResult *interfaces.DiarizationResult + if job.Parameters.DiarizationOnly { + if !job.Parameters.Diarize { + return fmt.Errorf("diarization-only job requires diarization to be enabled") + } + + transcriptResult, err = u.loadExistingTranscriptResult(job) + if err != nil { + return fmt.Errorf("failed to load existing transcript for diarization: %w", err) + } + } + // Perform transcription using the preprocessed audio if transcriptionModelID != "" { logger.Info("Running transcription", "model_id", transcriptionModelID) @@ -376,19 +387,21 @@ func (u *UnifiedTranscriptionService) IsMultiTrackJob(jobID string) bool { // selectModels determines which models to use based on job parameters func (u *UnifiedTranscriptionService) selectModels(params models.WhisperXParams) (transcriptionModelID, diarizationModelID string, err error) { // Determine transcription model - switch params.ModelFamily { - case FamilyNvidiaParakeet: - transcriptionModelID = ModelParakeet - case FamilyNvidiaCanary: - transcriptionModelID = ModelCanary - case FamilyWhisper: - transcriptionModelID = ModelWhisperX - case FamilyOpenAI: - transcriptionModelID = ModelOpenAI - case FamilyMistralVoxtral: - transcriptionModelID = ModelVoxtral - default: - transcriptionModelID = ModelWhisperX // Default fallback + if !params.DiarizationOnly { + switch params.ModelFamily { + case FamilyNvidiaParakeet: + transcriptionModelID = ModelParakeet + case FamilyNvidiaCanary: + transcriptionModelID = ModelCanary + case FamilyWhisper: + transcriptionModelID = ModelWhisperX + case FamilyOpenAI: + transcriptionModelID = ModelOpenAI + case FamilyMistralVoxtral: + transcriptionModelID = ModelVoxtral + default: + transcriptionModelID = ModelWhisperX // Default fallback + } } // Determine diarization model if needed @@ -850,6 +863,7 @@ func (u *UnifiedTranscriptionService) mergeDiarizationWithTranscription(transcri // Assign speakers to transcript segments based on timing overlap for i := range mergedTranscript.Segments { segment := &mergedTranscript.Segments[i] + segment.Speaker = nil bestSpeaker := u.findBestSpeakerForSegment(segment.Start, segment.End, diarization.Segments) if bestSpeaker != "" { segment.Speaker = &bestSpeaker @@ -863,6 +877,7 @@ func (u *UnifiedTranscriptionService) mergeDiarizationWithTranscription(transcri for i := range mergedTranscript.WordSegments { word := &mergedTranscript.WordSegments[i] + word.Speaker = nil bestSpeaker := u.findBestSpeakerForSegment(word.Start, word.End, diarization.Segments) if bestSpeaker != "" { word.Speaker = &bestSpeaker @@ -893,6 +908,23 @@ func (u *UnifiedTranscriptionService) findBestSpeakerForSegment(start, end float return bestSpeaker } +func (u *UnifiedTranscriptionService) loadExistingTranscriptResult(job *models.TranscriptionJob) (*interfaces.TranscriptResult, error) { + if job.Transcript == nil || strings.TrimSpace(*job.Transcript) == "" { + return nil, fmt.Errorf("job has no transcript") + } + + var result interfaces.TranscriptResult + if err := json.Unmarshal([]byte(*job.Transcript), &result); err != nil { + return nil, fmt.Errorf("failed to parse transcript JSON: %w", err) + } + + if len(result.Segments) == 0 { + return nil, fmt.Errorf("existing transcript has no timed segments") + } + + return &result, nil +} + // saveTranscriptionResults saves the transcription results to the database func (u *UnifiedTranscriptionService) saveTranscriptionResults(jobID string, result *interfaces.TranscriptResult) error { // Convert result to JSON string for database storage diff --git a/web/frontend/src/components/TranscribeDDialog.tsx b/web/frontend/src/components/TranscribeDDialog.tsx index e3115b382..739dc3623 100644 --- a/web/frontend/src/components/TranscribeDDialog.tsx +++ b/web/frontend/src/components/TranscribeDDialog.tsx @@ -37,6 +37,10 @@ interface TranscribeDDialogProps { onStartTranscription: (params: WhisperXParams, profileId?: string) => void; loading?: boolean; title?: string; + description?: string; + submitLabel?: string; + loadingLabel?: string; + forceDiarization?: boolean; } type DiarizationModel = "pyannote" | "nvidia_sortformer"; @@ -47,6 +51,10 @@ export function TranscribeDDialog({ onStartTranscription, loading = false, title, + description, + submitLabel = "Start Transcription", + loadingLabel = "Starting...", + forceDiarization = false, }: TranscribeDDialogProps) { const { getAuthHeaders } = useAuth(); const [profiles, setProfiles] = useState([]); @@ -58,7 +66,7 @@ export function TranscribeDDialog({ const [diarizationModel, setDiarizationModel] = useState("pyannote"); const selectedProfile = profiles.find(p => p.id === selectedProfileId); - const showDiarizationSettings = selectedProfile?.parameters.diarize ?? false; + const showDiarizationSettings = forceDiarization || (selectedProfile?.parameters.diarize ?? false); const minSpeakers = parseSpeakerLimit(minSpeakersInput); const maxSpeakers = parseSpeakerLimit(maxSpeakersInput); const hasInvalidSpeakerLimits = showDiarizationSettings && ( @@ -130,6 +138,7 @@ export function TranscribeDDialog({ const params: WhisperXParams = { ...selectedProfile.parameters }; if (showDiarizationSettings) { + params.diarize = true; params.diarize_model = diarizationModel; params.min_speakers = minSpeakers; params.max_speakers = maxSpeakers; @@ -157,7 +166,7 @@ export function TranscribeDDialog({ {title || "Transcribe with Profile"} - Choose a saved profile to start transcription with your preferred settings. + {description || "Choose a saved profile to start transcription with your preferred settings."} @@ -312,10 +321,10 @@ export function TranscribeDDialog({ {loading ? ( <> - Starting... + {loadingLabel} ) : ( - "Start Transcription" + submitLabel )} diff --git a/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx b/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx index 728d5458f..cb42f7e41 100644 --- a/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx +++ b/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx @@ -44,6 +44,7 @@ export interface WhisperXParams { vad_offset: number; chunk_size: number; diarize: boolean; + diarization_only?: boolean; min_speakers?: number; max_speakers?: number; diarize_model: string; diff --git a/web/frontend/src/components/ui/swipeable-item.tsx b/web/frontend/src/components/ui/swipeable-item.tsx index 37942669f..ace6a862c 100644 --- a/web/frontend/src/components/ui/swipeable-item.tsx +++ b/web/frontend/src/components/ui/swipeable-item.tsx @@ -1,6 +1,6 @@ import { useEffect, useRef, useState, type ReactNode } from "react"; import { motion, useAnimation, type PanInfo } from "framer-motion"; -import { Trash2, Wand2, StopCircle } from "lucide-react"; +import { Trash2, Wand2, StopCircle, Users } from "lucide-react"; import { WandAdvancedIcon } from "@/components/icons/WandAdvancedIcon"; import { useIsMobile } from "@/hooks/use-mobile"; @@ -8,8 +8,10 @@ interface SwipeableItemProps { children: ReactNode; onTranscribe: () => void; onTranscribeAdvanced: () => void; + onDiarize?: () => void; onDelete: () => void; onStop?: () => void; + canDiarize?: boolean; isProcessing?: boolean; isSelectionMode?: boolean; shouldShowHint?: boolean; @@ -33,8 +35,10 @@ export function SwipeableItem({ children, onTranscribe, onTranscribeAdvanced, + onDiarize, onDelete, onStop, + canDiarize = false, isProcessing = false, isSelectionMode = false, shouldShowHint = false, @@ -51,8 +55,9 @@ export function SwipeableItem({ const hasMovedRef = useRef(false); const suppressClickUntilRef = useRef(0); - // Width of the action buttons area (3 buttons × 44px + gaps + padding) - const OPEN_WIDTH = -160; + const actionCount = canDiarize && onDiarize && !isProcessing ? 4 : 3; + const actionAreaWidth = actionCount * 44 + (actionCount - 1) * 8 + 16; + const OPEN_WIDTH = -actionAreaWidth; const handleDragStart = (_event: MouseEvent | TouchEvent | PointerEvent, info: PanInfo) => { isDraggingRef.current = true; @@ -168,7 +173,10 @@ export function SwipeableItem({ data-has-moved={hasMoved()} > {/* --- LAYER 0: The Action Buttons (Hidden underneath, mobile only) --- */} -
+
{/* Transcribe (Primary) */} + {canDiarize && onDiarize && !isProcessing && ( + + )} + {/* Delete or Stop (Destructive - furthest right) */} {isProcessing && onStop ? ( + + Diarize + + )} )} @@ -1023,6 +1101,17 @@ export const AudioFilesTable = memo(function AudioFilesTable({ onStartTranscription={onStartTranscribeWithProfile} loading={transcriptionLoading} /> + {/* Stop Transcription Dialog */} From 2213e18b7faadf42eee4df1fb3832d969aa00a37 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 26 Jun 2026 10:27:27 -0400 Subject: [PATCH 2/4] Add persistent diarization worker controls --- cmd/server/main.go | 31 + internal/api/handlers.go | 121 +++- internal/api/router.go | 9 + internal/models/transcription.go | 1 + .../persistent_diarization_manager.go | 546 ++++++++++++++++++ .../adapters/py/nvidia/sortformer_worker.py | 124 ++++ .../adapters/py/pyannote/pyannote_worker.py | 145 +++++ .../adapters/pyannote_adapter.go | 80 ++- .../adapters/sortformer_adapter.go | 71 ++- internal/transcription/queue_integration.go | 16 + internal/transcription/unified_service.go | 28 +- web/frontend/src/components/Header.tsx | 219 ++++++- .../settings/components/ProfileSettings.tsx | 87 ++- 13 files changed, 1441 insertions(+), 37 deletions(-) create mode 100644 internal/transcription/adapters/persistent_diarization_manager.go create mode 100644 internal/transcription/adapters/py/nvidia/sortformer_worker.py create mode 100644 internal/transcription/adapters/py/pyannote/pyannote_worker.py diff --git a/cmd/server/main.go b/cmd/server/main.go index acdd3beb3..5a6e81dba 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -8,6 +8,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "syscall" "time" @@ -125,6 +126,7 @@ func main() { logger.Error("Failed to prepare Python environment", "error", err) os.Exit(1) } + loadStartupDiarizationModel(context.Background(), userRepo, unifiedProcessor) // Initialize quick transcription service logger.Startup("quick-transcription", "Initializing quick transcription service") @@ -248,3 +250,32 @@ func registerAdapters(cfg *config.Config) { logger.Info("Adapter registration complete") } + +func loadStartupDiarizationModel(ctx context.Context, userRepo repository.UserRepository, unifiedProcessor *transcription.UnifiedJobProcessor) { + users, _, err := userRepo.List(ctx, 0, 100) + if err != nil { + logger.Warn("Failed to load startup diarization preference", "error", err) + return + } + + for _, user := range users { + modelID := strings.TrimSpace(strings.ToLower(user.StartupDiarizationModel)) + if modelID == "" || modelID == "none" { + continue + } + + logger.Startup("diarization", "Loading persistent diarization model configured for startup") + loadCtx, cancel := context.WithTimeout(ctx, 30*time.Minute) + status, err := unifiedProcessor.LoadPersistentDiarizationModel(loadCtx, modelID, map[string]interface{}{ + "device": "auto", + }) + cancel() + if err != nil { + logger.Warn("Failed to load startup diarization model", "model_id", modelID, "user_id", user.ID, "error", err) + return + } + + logger.Info("Startup diarization model loaded", "model_id", status.ModelID, "pid", status.PID) + return + } +} diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 3637605ef..1a2e0874f 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "context" "crypto/rand" "crypto/sha256" "encoding/hex" @@ -106,6 +107,13 @@ type SubmitJobRequest struct { Parameters models.WhisperXParams `json:"parameters"` } +// DiarizationWorkerLoadRequest represents a persistent diarization model load request. +type DiarizationWorkerLoadRequest struct { + Model string `json:"model" binding:"required"` + Device string `json:"device,omitempty"` + HfToken string `json:"hf_token,omitempty"` +} + // LoginRequest represents the login request type LoginRequest struct { Username string `json:"username" binding:"required"` @@ -2251,6 +2259,82 @@ func (h *Handler) GetSupportedModels(c *gin.Context) { }) } +// @Summary Get persistent diarization worker status +// @Description Get the currently loaded resident diarization model, if any +// @Tags transcription +// @Produce json +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/diarization-worker/status [get] +// @Security ApiKeyAuth +// @Security BearerAuth +func (h *Handler) GetDiarizationWorkerStatus(c *gin.Context) { + c.JSON(http.StatusOK, h.unifiedProcessor.GetPersistentDiarizationStatus()) +} + +// @Summary Load persistent diarization model +// @Description Load PyAnnote or NVIDIA Sortformer into a resident worker process +// @Tags transcription +// @Accept json +// @Produce json +// @Param request body DiarizationWorkerLoadRequest true "Model load request" +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/diarization-worker/load [post] +// @Security ApiKeyAuth +// @Security BearerAuth +func (h *Handler) LoadDiarizationWorker(c *gin.Context) { + var req DiarizationWorkerLoadRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + params := map[string]interface{}{} + if req.Device != "" { + params["device"] = req.Device + } + if req.HfToken != "" { + params["hf_token"] = req.HfToken + } + + ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Minute) + defer cancel() + + status, err := h.unifiedProcessor.LoadPersistentDiarizationModel(ctx, req.Model, params) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": err.Error(), + "status": status, + }) + return + } + + c.JSON(http.StatusOK, status) +} + +// @Summary Unload persistent diarization model +// @Description Stop the resident diarization worker and release its VRAM +// @Tags transcription +// @Produce json +// @Success 200 {object} map[string]interface{} +// @Router /api/v1/diarization-worker/unload [post] +// @Security ApiKeyAuth +// @Security BearerAuth +func (h *Handler) UnloadDiarizationWorker(c *gin.Context) { + ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second) + defer cancel() + + status, err := h.unifiedProcessor.UnloadPersistentDiarizationModel(ctx) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": err.Error(), + "status": status, + }) + return + } + + c.JSON(http.StatusOK, status) +} + // Health check endpoint // @Summary Health check // @Description Check if the API is healthy @@ -2890,11 +2974,13 @@ func (h *Handler) SetUserDefaultProfile(c *gin.Context) { type UserSettingsResponse struct { AutoTranscriptionEnabled bool `json:"auto_transcription_enabled"` DefaultProfileID *string `json:"default_profile_id,omitempty"` + StartupDiarizationModel string `json:"startup_diarization_model"` } // UpdateUserSettingsRequest represents the request to update user settings type UpdateUserSettingsRequest struct { - AutoTranscriptionEnabled *bool `json:"auto_transcription_enabled,omitempty"` + AutoTranscriptionEnabled *bool `json:"auto_transcription_enabled,omitempty"` + StartupDiarizationModel *string `json:"startup_diarization_model,omitempty"` } // @Summary Get user settings @@ -2922,6 +3008,7 @@ func (h *Handler) GetUserSettings(c *gin.Context) { response := UserSettingsResponse{ AutoTranscriptionEnabled: user.AutoTranscriptionEnabled, DefaultProfileID: user.DefaultProfileID, + StartupDiarizationModel: normalizeStartupDiarizationSetting(user.StartupDiarizationModel), } c.JSON(http.StatusOK, response) @@ -2962,6 +3049,14 @@ func (h *Handler) UpdateUserSettings(c *gin.Context) { if req.AutoTranscriptionEnabled != nil { user.AutoTranscriptionEnabled = *req.AutoTranscriptionEnabled } + if req.StartupDiarizationModel != nil { + normalized, ok := validateStartupDiarizationSetting(*req.StartupDiarizationModel) + if !ok { + c.JSON(http.StatusBadRequest, gin.H{"error": "startup_diarization_model must be one of: none, pyannote, sortformer"}) + return + } + user.StartupDiarizationModel = normalized + } // Save updated user if err := h.userRepo.Update(c.Request.Context(), user); err != nil { @@ -2972,11 +3067,35 @@ func (h *Handler) UpdateUserSettings(c *gin.Context) { response := UserSettingsResponse{ AutoTranscriptionEnabled: user.AutoTranscriptionEnabled, DefaultProfileID: user.DefaultProfileID, + StartupDiarizationModel: normalizeStartupDiarizationSetting(user.StartupDiarizationModel), } c.JSON(http.StatusOK, response) } +func validateStartupDiarizationSetting(value string) (string, bool) { + normalized := normalizeStartupDiarizationSetting(value) + switch normalized { + case "none", "pyannote", "sortformer": + return normalized, true + default: + return "", false + } +} + +func normalizeStartupDiarizationSetting(value string) string { + switch strings.TrimSpace(strings.ToLower(value)) { + case "", "none": + return "none" + case "pyannote", "pyannote/speaker-diarization-3.1", "pyannote/speaker-diarization-community-1": + return "pyannote" + case "sortformer", "nvidia_sortformer", "nvidia/diar_streaming_sortformer_4spk-v2": + return "sortformer" + default: + return strings.TrimSpace(strings.ToLower(value)) + } +} + // @Summary SSE Events // @Description Subscribe to server-sent events // @Tags events diff --git a/internal/api/router.go b/internal/api/router.go index 33061fdf7..149647e73 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -159,6 +159,15 @@ func SetupRoutes(handler *Handler, authService *auth.AuthService) *gin.Engine { transcription.GET("/quick/:id", handler.GetQuickTranscriptionStatus) } + // Persistent diarization worker routes (require authentication) + diarizationWorker := v1.Group("/diarization-worker") + diarizationWorker.Use(middleware.AuthMiddleware(authService)) + { + diarizationWorker.GET("/status", handler.GetDiarizationWorkerStatus) + diarizationWorker.POST("/load", handler.LoadDiarizationWorker) + diarizationWorker.POST("/unload", handler.UnloadDiarizationWorker) + } + // Profile routes (require authentication) profiles := v1.Group("/profiles") profiles.Use(middleware.AuthMiddleware(authService)) diff --git a/internal/models/transcription.go b/internal/models/transcription.go index b325cb74f..0162ca0cf 100644 --- a/internal/models/transcription.go +++ b/internal/models/transcription.go @@ -151,6 +151,7 @@ type User struct { Password string `json:"-" gorm:"not null;type:varchar(255)"` DefaultProfileID *string `json:"default_profile_id,omitempty" gorm:"type:varchar(36)"` AutoTranscriptionEnabled bool `json:"auto_transcription_enabled" gorm:"not null;default:false"` + StartupDiarizationModel string `json:"startup_diarization_model" gorm:"type:varchar(20);not null;default:'none'"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } diff --git a/internal/transcription/adapters/persistent_diarization_manager.go b/internal/transcription/adapters/persistent_diarization_manager.go new file mode 100644 index 000000000..79a5cdbe4 --- /dev/null +++ b/internal/transcription/adapters/persistent_diarization_manager.go @@ -0,0 +1,546 @@ +package adapters + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" + + "scriberr/pkg/logger" +) + +const ( + PersistentDiarizationStateUnloaded = "unloaded" + PersistentDiarizationStateLoading = "loading" + PersistentDiarizationStateLoaded = "loaded" + PersistentDiarizationStateUnloading = "unloading" + PersistentDiarizationStateFailed = "failed" + + PersistentDiarizationModelPyAnnote = "pyannote" + PersistentDiarizationModelSortformer = "sortformer" +) + +var errPersistentWorkerStopped = errors.New("persistent diarization worker stopped") + +// PersistentDiarizationStatus describes the resident diarization model, if any. +type PersistentDiarizationStatus struct { + State string `json:"state"` + Loaded bool `json:"loaded"` + ModelID string `json:"model_id,omitempty"` + DisplayName string `json:"display_name,omitempty"` + Error string `json:"error,omitempty"` + StartedAt *time.Time `json:"started_at,omitempty"` + PID int `json:"pid,omitempty"` +} + +// PersistentDiarizationManager owns the single resident diarization process. +type PersistentDiarizationManager struct { + mu sync.RWMutex + envPaths map[string]string + worker *persistentDiarizationWorker + status PersistentDiarizationStatus +} + +var persistentDiarizationManager = NewPersistentDiarizationManager() + +// NewPersistentDiarizationManager creates a manager with no loaded model. +func NewPersistentDiarizationManager() *PersistentDiarizationManager { + return &PersistentDiarizationManager{ + envPaths: make(map[string]string), + status: PersistentDiarizationStatus{ + State: PersistentDiarizationStateUnloaded, + }, + } +} + +// GetPersistentDiarizationManager returns the process-wide resident model manager. +func GetPersistentDiarizationManager() *PersistentDiarizationManager { + return persistentDiarizationManager +} + +// SetEnvironment registers the UV project path used by a diarization model. +func (m *PersistentDiarizationManager) SetEnvironment(modelID, envPath string) { + modelID = normalizePersistentDiarizationModel(modelID) + if modelID == "" || envPath == "" { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + m.envPaths[modelID] = envPath +} + +// Status returns a copy of the current persistent worker state. +func (m *PersistentDiarizationManager) Status() PersistentDiarizationStatus { + m.mu.RLock() + defer m.mu.RUnlock() + return m.status +} + +// IsModelLoaded reports whether the requested model is currently resident. +func (m *PersistentDiarizationManager) IsModelLoaded(modelID string) bool { + modelID = normalizePersistentDiarizationModel(modelID) + + m.mu.RLock() + defer m.mu.RUnlock() + return m.worker != nil && m.status.Loaded && m.status.ModelID == modelID +} + +// Load starts a resident diarization worker and waits until the model is loaded. +func (m *PersistentDiarizationManager) Load(ctx context.Context, modelID string, params map[string]interface{}) (PersistentDiarizationStatus, error) { + modelID = normalizePersistentDiarizationModel(modelID) + if modelID == "" { + return m.Status(), fmt.Errorf("unsupported diarization model") + } + + m.mu.Lock() + if m.status.State == PersistentDiarizationStateLoading || m.status.State == PersistentDiarizationStateUnloading { + status := m.status + m.mu.Unlock() + return status, fmt.Errorf("diarization worker is currently %s", status.State) + } + if m.worker != nil && m.status.Loaded { + status := m.status + m.mu.Unlock() + if status.ModelID == modelID { + return status, nil + } + return status, fmt.Errorf("%s is already loaded; unload it before loading %s", status.DisplayName, persistentDiarizationDisplayName(modelID)) + } + + envPath := m.envPaths[modelID] + startedAt := time.Now() + m.status = PersistentDiarizationStatus{ + State: PersistentDiarizationStateLoading, + ModelID: modelID, + DisplayName: persistentDiarizationDisplayName(modelID), + StartedAt: &startedAt, + } + m.mu.Unlock() + + worker, err := startPersistentDiarizationWorker(ctx, modelID, envPath, params) + if err != nil { + m.setFailed(modelID, startedAt, err) + return m.Status(), err + } + + m.mu.Lock() + m.worker = worker + m.status = PersistentDiarizationStatus{ + State: PersistentDiarizationStateLoaded, + Loaded: true, + ModelID: modelID, + DisplayName: persistentDiarizationDisplayName(modelID), + StartedAt: &startedAt, + PID: worker.pid(), + } + status := m.status + m.mu.Unlock() + + logger.Info("Persistent diarization model loaded", "model_id", modelID, "pid", worker.pid()) + return status, nil +} + +// Unload stops the resident diarization worker, if one is running. +func (m *PersistentDiarizationManager) Unload(ctx context.Context) (PersistentDiarizationStatus, error) { + m.mu.Lock() + if m.worker == nil { + m.status = PersistentDiarizationStatus{State: PersistentDiarizationStateUnloaded} + status := m.status + m.mu.Unlock() + return status, nil + } + + worker := m.worker + modelID := m.status.ModelID + startedAt := m.status.StartedAt + m.status.State = PersistentDiarizationStateUnloading + m.status.Loaded = false + status := m.status + m.mu.Unlock() + + if err := worker.Stop(ctx); err != nil { + m.setFailed(modelID, derefTime(startedAt), err) + return m.Status(), err + } + + m.mu.Lock() + if m.worker == worker { + m.worker = nil + m.status = PersistentDiarizationStatus{State: PersistentDiarizationStateUnloaded} + status = m.status + } + m.mu.Unlock() + + logger.Info("Persistent diarization model unloaded", "model_id", modelID) + return status, nil +} + +// Diarize runs a request through the resident model. The caller still owns output parsing. +func (m *PersistentDiarizationManager) Diarize(ctx context.Context, modelID string, payload map[string]interface{}) error { + modelID = normalizePersistentDiarizationModel(modelID) + + m.mu.RLock() + worker := m.worker + loaded := worker != nil && m.status.Loaded && m.status.ModelID == modelID + m.mu.RUnlock() + + if !loaded { + return fmt.Errorf("persistent %s diarization model is not loaded", persistentDiarizationDisplayName(modelID)) + } + + err := worker.Request(ctx, payload) + if errors.Is(err, errPersistentWorkerStopped) { + m.mu.Lock() + if m.worker == worker { + m.worker = nil + m.status = PersistentDiarizationStatus{ + State: PersistentDiarizationStateFailed, + Loaded: false, + ModelID: modelID, + DisplayName: persistentDiarizationDisplayName(modelID), + Error: err.Error(), + } + } + m.mu.Unlock() + } + return err +} + +func (m *PersistentDiarizationManager) setFailed(modelID string, startedAt time.Time, err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.worker = nil + m.status = PersistentDiarizationStatus{ + State: PersistentDiarizationStateFailed, + ModelID: modelID, + DisplayName: persistentDiarizationDisplayName(modelID), + Error: err.Error(), + StartedAt: &startedAt, + } +} + +type persistentDiarizationWorker struct { + modelID string + cmd *exec.Cmd + stdin io.WriteCloser + responses chan workerProtocolMessage + ready chan workerProtocolMessage + done chan error + requestMu sync.Mutex + stopMu sync.Mutex + stopped bool + protocolEnc *json.Encoder +} + +type workerProtocolMessage struct { + Type string `json:"type,omitempty"` + ID string `json:"id,omitempty"` + OK bool `json:"ok,omitempty"` + Error string `json:"error,omitempty"` + ModelID string `json:"model_id,omitempty"` +} + +func startPersistentDiarizationWorker(ctx context.Context, modelID, envPath string, params map[string]interface{}) (*persistentDiarizationWorker, error) { + if envPath == "" { + return nil, fmt.Errorf("environment path for %s is not registered", persistentDiarizationDisplayName(modelID)) + } + + scriptPath := filepath.Join(envPath, persistentDiarizationWorkerScript(modelID)) + if _, err := os.Stat(scriptPath); err != nil { + return nil, fmt.Errorf("persistent worker script not found at %s: %w", scriptPath, err) + } + + args, err := persistentDiarizationWorkerArgs(modelID, envPath, scriptPath, params) + if err != nil { + return nil, err + } + + cmd := exec.Command("uv", args...) + cmd.Env = append(os.Environ(), "PYTHONUNBUFFERED=1") + + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("failed to open worker stdin: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to open worker stdout: %w", err) + } + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, fmt.Errorf("failed to open worker stderr: %w", err) + } + + worker := &persistentDiarizationWorker{ + modelID: modelID, + cmd: cmd, + stdin: stdin, + responses: make(chan workerProtocolMessage, 4), + ready: make(chan workerProtocolMessage, 1), + done: make(chan error, 1), + protocolEnc: json.NewEncoder(stdin), + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start persistent diarization worker: %w", err) + } + + go worker.readStdout(stdout) + go worker.readStderr(stderr) + go func() { + worker.done <- cmd.Wait() + close(worker.done) + }() + + select { + case msg := <-worker.ready: + if msg.Type == "ready" && msg.OK { + return worker, nil + } + _ = worker.Stop(context.Background()) + if msg.Error != "" { + return nil, fmt.Errorf("persistent diarization worker failed to load: %s", msg.Error) + } + return nil, fmt.Errorf("persistent diarization worker failed to load") + case err := <-worker.done: + if err != nil { + return nil, fmt.Errorf("persistent diarization worker exited during load: %w", err) + } + return nil, fmt.Errorf("persistent diarization worker exited during load") + case <-ctx.Done(): + _ = worker.Stop(context.Background()) + return nil, fmt.Errorf("loading persistent diarization worker cancelled: %w", ctx.Err()) + } +} + +func persistentDiarizationWorkerArgs(modelID, envPath, scriptPath string, params map[string]interface{}) ([]string, error) { + args := []string{"run", "--native-tls", "--project", envPath, "python", scriptPath} + device := stringParam(params, "device", "auto") + + switch modelID { + case PersistentDiarizationModelPyAnnote: + hfToken := stringParam(params, "hf_token", "") + if hfToken == "" { + hfToken = os.Getenv("HF_TOKEN") + } + if hfToken == "" { + return nil, fmt.Errorf("HuggingFace token is required to load PyAnnote. Set HF_TOKEN before loading the resident model") + } + args = append(args, + "--hf-token", hfToken, + "--model", stringParam(params, "model", "pyannote/speaker-diarization-community-1"), + "--device", device, + ) + case PersistentDiarizationModelSortformer: + args = append(args, "--device", device) + default: + return nil, fmt.Errorf("unsupported persistent diarization model: %s", modelID) + } + + return args, nil +} + +func (w *persistentDiarizationWorker) Request(ctx context.Context, payload map[string]interface{}) error { + w.requestMu.Lock() + defer w.requestMu.Unlock() + + requestID := fmt.Sprintf("%d", time.Now().UnixNano()) + request := make(map[string]interface{}, len(payload)+2) + for key, value := range payload { + request[key] = value + } + request["id"] = requestID + request["action"] = "diarize" + + if err := w.protocolEnc.Encode(request); err != nil { + return fmt.Errorf("failed to send diarization request: %w", err) + } + + for { + select { + case msg := <-w.responses: + if msg.ID != requestID { + logger.Warn("Ignoring stale persistent diarization response", "expected_id", requestID, "response_id", msg.ID) + continue + } + if !msg.OK { + if msg.Error == "" { + msg.Error = "worker returned an unknown error" + } + return fmt.Errorf("persistent diarization failed: %s", msg.Error) + } + return nil + case err := <-w.done: + if err != nil { + return fmt.Errorf("%w: %v", errPersistentWorkerStopped, err) + } + return fmt.Errorf("%w: process exited", errPersistentWorkerStopped) + case <-ctx.Done(): + w.forceStop() + return fmt.Errorf("%w: request cancelled: %v", errPersistentWorkerStopped, ctx.Err()) + } + } +} + +func (w *persistentDiarizationWorker) Stop(ctx context.Context) error { + w.stopMu.Lock() + if w.stopped { + w.stopMu.Unlock() + return nil + } + w.stopped = true + w.stopMu.Unlock() + + _ = w.protocolEnc.Encode(map[string]interface{}{ + "id": fmt.Sprintf("shutdown-%d", time.Now().UnixNano()), + "action": "shutdown", + }) + + killed := false + select { + case err := <-w.done: + return err + case <-time.After(5 * time.Second): + w.forceStop() + killed = true + case <-ctx.Done(): + w.forceStop() + return ctx.Err() + } + + select { + case err := <-w.done: + if killed { + return nil + } + return err + case <-time.After(2 * time.Second): + return fmt.Errorf("persistent diarization worker did not exit after kill") + } +} + +func (w *persistentDiarizationWorker) forceStop() { + w.stopMu.Lock() + w.stopped = true + w.stopMu.Unlock() + + if w.cmd != nil && w.cmd.Process != nil { + _ = w.cmd.Process.Kill() + } +} + +func (w *persistentDiarizationWorker) readStdout(stdout io.Reader) { + scanner := bufio.NewScanner(stdout) + scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024) + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + var msg workerProtocolMessage + if err := json.Unmarshal([]byte(line), &msg); err != nil { + logger.Warn("Ignoring non-protocol persistent diarization output", "model_id", w.modelID, "line", line) + continue + } + + if msg.Type == "ready" || (msg.Type == "error" && msg.ID == "") { + select { + case w.ready <- msg: + default: + logger.Warn("Dropping duplicate persistent diarization ready message", "model_id", w.modelID) + } + continue + } + + select { + case w.responses <- msg: + default: + logger.Warn("Dropping persistent diarization response because response channel is full", "model_id", w.modelID, "id", msg.ID) + } + } + + if err := scanner.Err(); err != nil { + logger.Warn("Persistent diarization stdout scanner failed", "model_id", w.modelID, "error", err) + } +} + +func (w *persistentDiarizationWorker) readStderr(stderr io.Reader) { + scanner := bufio.NewScanner(stderr) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for scanner.Scan() { + logger.Info("Persistent diarization worker", "model_id", w.modelID, "message", scanner.Text()) + } + if err := scanner.Err(); err != nil { + logger.Warn("Persistent diarization stderr scanner failed", "model_id", w.modelID, "error", err) + } +} + +func (w *persistentDiarizationWorker) pid() int { + if w == nil || w.cmd == nil || w.cmd.Process == nil { + return 0 + } + return w.cmd.Process.Pid +} + +func normalizePersistentDiarizationModel(modelID string) string { + switch strings.TrimSpace(strings.ToLower(modelID)) { + case "pyannote", "pyannote/speaker-diarization-3.1", "pyannote/speaker-diarization-community-1": + return PersistentDiarizationModelPyAnnote + case "sortformer", "nvidia_sortformer", "nvidia/diar_streaming_sortformer_4spk-v2": + return PersistentDiarizationModelSortformer + default: + return "" + } +} + +func persistentDiarizationDisplayName(modelID string) string { + switch normalizePersistentDiarizationModel(modelID) { + case PersistentDiarizationModelPyAnnote: + return "PyAnnote" + case PersistentDiarizationModelSortformer: + return "NVIDIA Sortformer" + default: + return modelID + } +} + +func persistentDiarizationWorkerScript(modelID string) string { + switch modelID { + case PersistentDiarizationModelPyAnnote: + return "pyannote_worker.py" + case PersistentDiarizationModelSortformer: + return "sortformer_worker.py" + default: + return "" + } +} + +func stringParam(params map[string]interface{}, key, fallback string) string { + if params == nil { + return fallback + } + if value, ok := params[key]; ok { + if str, ok := value.(string); ok && str != "" { + return str + } + } + return fallback +} + +func derefTime(value *time.Time) time.Time { + if value == nil { + return time.Now() + } + return *value +} diff --git a/internal/transcription/adapters/py/nvidia/sortformer_worker.py b/internal/transcription/adapters/py/nvidia/sortformer_worker.py new file mode 100644 index 000000000..6578cd1bf --- /dev/null +++ b/internal/transcription/adapters/py/nvidia/sortformer_worker.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +""" +Persistent NVIDIA Sortformer diarization worker. +Loads the Sortformer .nemo model once, then accepts newline-delimited JSON +requests on stdin and writes newline-delimited JSON responses on stdout. +""" + +import argparse +import contextlib +import json +import os +from pathlib import Path +import sys +import traceback + +import torch + +from nemo.collections.asr.models import SortformerEncLabelModel +from sortformer_diarize import save_results + + +def send(message): + sys.stdout.write(json.dumps(message, separators=(",", ":")) + "\n") + sys.stdout.flush() + + +def resolve_device(device): + if device == "auto": + return "cuda" if torch.cuda.is_available() else "cpu" + if device == "cuda" and not torch.cuda.is_available(): + print("CUDA requested but not available, using CPU", file=sys.stderr) + return "cpu" + return device + + +def model_path_from_virtual_env(): + virtual_env = os.environ.get("VIRTUAL_ENV") + if not virtual_env: + raise RuntimeError("VIRTUAL_ENV is not set. Run this worker with 'uv run'.") + + project_root = os.path.dirname(virtual_env) + return os.path.join(project_root, "diar_streaming_sortformer_4spk-v2.nemo") + + +def load_model(device): + model_path = model_path_from_virtual_env() + if not os.path.exists(model_path): + raise FileNotFoundError(f"Sortformer model file not found: {model_path}") + + print(f"Loading persistent Sortformer model from: {model_path}", file=sys.stderr) + diar_model = SortformerEncLabelModel.restore_from( + restore_path=model_path, + map_location=device, + strict=False, + ) + diar_model.eval() + print(f"Persistent Sortformer model loaded on {device}", file=sys.stderr) + return diar_model + + +def run_diarization(diar_model, request): + audio_file = request["audio_file"] + output_file = request["output_file"] + output_format = request.get("output_format", "json") + batch_size = int(request.get("batch_size") or 1) + + if not os.path.exists(audio_file): + raise FileNotFoundError(f"Audio file not found: {audio_file}") + + Path(output_file).parent.mkdir(parents=True, exist_ok=True) + predicted_segments = diar_model.diarize(audio=audio_file, batch_size=batch_size) + save_results(predicted_segments, output_file, audio_file, output_format) + + +def main(): + parser = argparse.ArgumentParser(description="Persistent Sortformer diarization worker") + parser.add_argument("--device", choices=["cpu", "cuda", "auto"], default="auto") + args = parser.parse_args() + + try: + device = resolve_device(args.device) + with contextlib.redirect_stdout(sys.stderr): + diar_model = load_model(device) + send({"type": "ready", "ok": True, "model_id": "sortformer", "device": device}) + except Exception as exc: + traceback.print_exc(file=sys.stderr) + send({"type": "error", "ok": False, "model_id": "sortformer", "error": str(exc)}) + return 1 + + for line in sys.stdin: + line = line.strip() + if not line: + continue + + request = {} + try: + request = json.loads(line) + request_id = request.get("id") + action = request.get("action") + + if action == "shutdown": + send({"type": "response", "id": request_id, "ok": True}) + return 0 + + if action != "diarize": + raise ValueError(f"Unsupported action: {action}") + + with contextlib.redirect_stdout(sys.stderr): + run_diarization(diar_model, request) + send({"type": "response", "id": request_id, "ok": True}) + except Exception as exc: + traceback.print_exc(file=sys.stderr) + send({ + "type": "response", + "id": request.get("id"), + "ok": False, + "error": str(exc), + }) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/internal/transcription/adapters/py/pyannote/pyannote_worker.py b/internal/transcription/adapters/py/pyannote/pyannote_worker.py new file mode 100644 index 000000000..7f70d1d6f --- /dev/null +++ b/internal/transcription/adapters/py/pyannote/pyannote_worker.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +""" +Persistent PyAnnote diarization worker. +Loads the PyAnnote pipeline once, then accepts newline-delimited JSON requests +on stdin and writes newline-delimited JSON responses on stdout. +""" + +import argparse +import contextlib +import copy +import json +import os +from pathlib import Path +import sys +import traceback + +from pyannote.audio import Pipeline +import torch + +from pyannote_diarize import save_json_format + + +def send(message): + sys.stdout.write(json.dumps(message, separators=(",", ":")) + "\n") + sys.stdout.flush() + + +def load_pipeline(args): + print(f"Loading persistent PyAnnote pipeline: {args.model}", file=sys.stderr) + pipeline = Pipeline.from_pretrained(args.model, token=args.hf_token) + + if args.device == "cuda" or (args.device == "auto" and torch.cuda.is_available()): + if torch.cuda.is_available(): + pipeline = pipeline.to(torch.device("cuda")) + print("Using CUDA for persistent PyAnnote diarization", file=sys.stderr) + else: + print("CUDA requested but not available, using CPU", file=sys.stderr) + else: + print("Using CPU for persistent PyAnnote diarization", file=sys.stderr) + + return pipeline + + +def instantiate_for_request(pipeline, base_params, request): + onset = request.get("segmentation_onset") + offset = request.get("segmentation_offset") + if onset is None and offset is None: + return + + try: + params = copy.deepcopy(base_params) + if "segmentation" not in params: + print("Warning: segmentation parameters not found in PyAnnote pipeline", file=sys.stderr) + return + + if onset is not None: + params["segmentation"]["threshold"] = float(onset) + if offset is not None: + params["segmentation"]["min_duration_off"] = float(offset) + pipeline.instantiate(params) + except Exception as exc: + print(f"Warning: could not apply PyAnnote request thresholds: {exc}", file=sys.stderr) + + +def run_diarization(pipeline, base_params, request): + audio_file = request["audio_file"] + output_file = request["output_file"] + output_format = request.get("output_format", "json") + + if not os.path.exists(audio_file): + raise FileNotFoundError(f"Audio file not found: {audio_file}") + + Path(output_file).parent.mkdir(parents=True, exist_ok=True) + instantiate_for_request(pipeline, base_params, request) + + diarization_params = {} + if request.get("min_speakers"): + diarization_params["min_speakers"] = int(request["min_speakers"]) + if request.get("max_speakers"): + diarization_params["max_speakers"] = int(request["max_speakers"]) + + if diarization_params: + diarization = pipeline(audio_file, **diarization_params) + else: + diarization = pipeline(audio_file) + + if output_format == "rttm": + with open(output_file, "w") as rttm: + diarization.write_rttm(rttm) + else: + save_json_format(diarization, output_file, audio_file) + + +def main(): + parser = argparse.ArgumentParser(description="Persistent PyAnnote diarization worker") + parser.add_argument("--hf-token", required=True) + parser.add_argument("--model", default="pyannote/speaker-diarization-community-1") + parser.add_argument("--device", choices=["cpu", "cuda", "auto"], default="auto") + args = parser.parse_args() + + try: + with contextlib.redirect_stdout(sys.stderr): + pipeline = load_pipeline(args) + base_params = copy.deepcopy(pipeline.parameters(instantiated=True)) + send({"type": "ready", "ok": True, "model_id": "pyannote", "model": args.model}) + except Exception as exc: + traceback.print_exc(file=sys.stderr) + send({"type": "error", "ok": False, "model_id": "pyannote", "error": str(exc)}) + return 1 + + for line in sys.stdin: + line = line.strip() + if not line: + continue + + request = {} + try: + request = json.loads(line) + request_id = request.get("id") + action = request.get("action") + + if action == "shutdown": + send({"type": "response", "id": request_id, "ok": True}) + return 0 + + if action != "diarize": + raise ValueError(f"Unsupported action: {action}") + + with contextlib.redirect_stdout(sys.stderr): + run_diarization(pipeline, base_params, request) + send({"type": "response", "id": request_id, "ok": True}) + except Exception as exc: + traceback.print_exc(file=sys.stderr) + send({ + "type": "response", + "id": request.get("id"), + "ok": False, + "error": str(exc), + }) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/internal/transcription/adapters/pyannote_adapter.go b/internal/transcription/adapters/pyannote_adapter.go index b4168e0bf..1f869bff0 100644 --- a/internal/transcription/adapters/pyannote_adapter.go +++ b/internal/transcription/adapters/pyannote_adapter.go @@ -29,6 +29,8 @@ type PyAnnoteAdapter struct { // NewPyAnnoteAdapter creates a new PyAnnote diarization adapter func NewPyAnnoteAdapter(envPath string) *PyAnnoteAdapter { + GetPersistentDiarizationManager().SetEnvironment(PersistentDiarizationModelPyAnnote, envPath) + capabilities := interfaces.ModelCapabilities{ ModelID: "pyannote", ModelFamily: "pyannote", @@ -181,6 +183,7 @@ func (p *PyAnnoteAdapter) GetMinSpeakers() int { // PrepareEnvironment sets up the dedicated PyAnnote environment func (p *PyAnnoteAdapter) PrepareEnvironment(ctx context.Context) error { logger.Info("Preparing PyAnnote environment", "env_path", p.envPath) + GetPersistentDiarizationManager().SetEnvironment(PersistentDiarizationModelPyAnnote, p.envPath) // Always ensure diarization script exists if err := p.copyDiarizationScript(); err != nil { @@ -259,14 +262,17 @@ func (p *PyAnnoteAdapter) copyDiarizationScript() error { return fmt.Errorf("failed to create pyannote directory: %w", err) } - scriptContent, err := pyannoteScripts.ReadFile("py/pyannote/pyannote_diarize.py") - if err != nil { - return fmt.Errorf("failed to read embedded pyannote_diarize.py: %w", err) - } + scripts := []string{"pyannote_diarize.py", "pyannote_worker.py"} + for _, scriptName := range scripts { + scriptContent, err := pyannoteScripts.ReadFile(filepath.Join("py/pyannote", scriptName)) + if err != nil { + return fmt.Errorf("failed to read embedded %s: %w", scriptName, err) + } - scriptPath := filepath.Join(p.envPath, "pyannote_diarize.py") - if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil { - return fmt.Errorf("failed to write diarization script: %w", err) + scriptPath := filepath.Join(p.envPath, scriptName) + if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil { + return fmt.Errorf("failed to write %s: %w", scriptName, err) + } } return nil @@ -295,11 +301,13 @@ func (p *PyAnnoteAdapter) Diarize(ctx context.Context, input interfaces.AudioInp if hfToken == "" { hfToken = os.Getenv("HF_TOKEN") } - if hfToken == "" { + if hfToken == "" && !GetPersistentDiarizationManager().IsModelLoaded(PersistentDiarizationModelPyAnnote) { return nil, fmt.Errorf("HuggingFace token is required for PyAnnote diarization. Set HF_TOKEN environment variable or provide it in the UI") } - // Store resolved token in params for buildPyAnnoteArgs - params["hf_token"] = hfToken + if hfToken != "" { + // Store resolved token in params for buildPyAnnoteArgs + params["hf_token"] = hfToken + } // Create temporary directory tempDir, err := p.CreateTempDirectory(procCtx) @@ -308,6 +316,19 @@ func (p *PyAnnoteAdapter) Diarize(ctx context.Context, input interfaces.AudioInp } defer p.CleanupTempDirectory(tempDir) + if GetPersistentDiarizationManager().IsModelLoaded(PersistentDiarizationModelPyAnnote) { + result, err := p.diarizeWithPersistentWorker(ctx, input, params, tempDir) + if err != nil { + return nil, err + } + + result.ProcessingTime = time.Since(startTime) + result.ModelUsed = p.GetStringParameter(params, "model") + result.Metadata = p.CreateDefaultMetadata(params) + result.Metadata["execution_mode"] = "persistent_worker" + return result, nil + } + // Build command arguments args, err := p.buildPyAnnoteArgs(input, params, tempDir) if err != nil { @@ -364,6 +385,45 @@ func (p *PyAnnoteAdapter) Diarize(ctx context.Context, input interfaces.AudioInp return result, nil } +func (p *PyAnnoteAdapter) diarizeWithPersistentWorker(ctx context.Context, input interfaces.AudioInput, params map[string]interface{}, tempDir string) (*interfaces.DiarizationResult, error) { + outputFormat := p.GetStringParameter(params, "output_format") + outputFile := filepath.Join(tempDir, "result.rttm") + if outputFormat == OutputFormatJSON { + outputFile = filepath.Join(tempDir, "result.json") + } + + request := map[string]interface{}{ + "audio_file": input.FilePath, + "output_file": outputFile, + "output_format": outputFormat, + } + + if minSpeakers := p.GetIntParameter(params, "min_speakers"); minSpeakers > 0 { + request["min_speakers"] = minSpeakers + } + if maxSpeakers := p.GetIntParameter(params, "max_speakers"); maxSpeakers > 0 { + request["max_speakers"] = maxSpeakers + } + if onset := p.GetFloatParameter(params, "segmentation_onset"); onset > 0 { + request["segmentation_onset"] = onset + } + if offset := p.GetFloatParameter(params, "segmentation_offset"); offset > 0 { + request["segmentation_offset"] = offset + } + + logger.Info("Executing PyAnnote via persistent worker", "audio_file", input.FilePath, "output_file", outputFile) + if err := GetPersistentDiarizationManager().Diarize(ctx, PersistentDiarizationModelPyAnnote, request); err != nil { + return nil, fmt.Errorf("persistent PyAnnote execution failed: %w", err) + } + + result, err := p.parseResult(tempDir, input, params) + if err != nil { + return nil, fmt.Errorf("failed to parse persistent PyAnnote result: %w", err) + } + + return result, nil +} + // buildPyAnnoteArgs builds the command arguments for PyAnnote func (p *PyAnnoteAdapter) buildPyAnnoteArgs(input interfaces.AudioInput, params map[string]interface{}, tempDir string) ([]string, error) { outputFormat := p.GetStringParameter(params, "output_format") diff --git a/internal/transcription/adapters/sortformer_adapter.go b/internal/transcription/adapters/sortformer_adapter.go index 4d8f5b1fd..3dbbdb5c8 100644 --- a/internal/transcription/adapters/sortformer_adapter.go +++ b/internal/transcription/adapters/sortformer_adapter.go @@ -25,6 +25,8 @@ type SortformerAdapter struct { // NewSortformerAdapter creates a new NVIDIA Sortformer diarization adapter func NewSortformerAdapter(envPath string) *SortformerAdapter { + GetPersistentDiarizationManager().SetEnvironment(PersistentDiarizationModelSortformer, envPath) + capabilities := interfaces.ModelCapabilities{ ModelID: "sortformer", ModelFamily: "nvidia_sortformer", @@ -162,6 +164,7 @@ func (s *SortformerAdapter) GetMinSpeakers() int { // PrepareEnvironment sets up the Sortformer environment (shared with NVIDIA models) func (s *SortformerAdapter) PrepareEnvironment(ctx context.Context) error { logger.Info("Preparing NVIDIA Sortformer environment", "env_path", s.envPath) + GetPersistentDiarizationManager().SetEnvironment(PersistentDiarizationModelSortformer, s.envPath) // Copy diarization script if err := s.copyDiarizationScript(); err != nil { @@ -279,14 +282,17 @@ func (s *SortformerAdapter) copyDiarizationScript() error { return fmt.Errorf("failed to create directory: %w", err) } - scriptContent, err := nvidiaScripts.ReadFile("py/nvidia/sortformer_diarize.py") - if err != nil { - return fmt.Errorf("failed to read embedded sortformer_diarize.py: %w", err) - } + scripts := []string{"sortformer_diarize.py", "sortformer_worker.py"} + for _, scriptName := range scripts { + scriptContent, err := nvidiaScripts.ReadFile(filepath.Join("py/nvidia", scriptName)) + if err != nil { + return fmt.Errorf("failed to read embedded %s: %w", scriptName, err) + } - scriptPath := filepath.Join(s.envPath, "sortformer_diarize.py") - if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil { - return fmt.Errorf("failed to write diarization script: %w", err) + scriptPath := filepath.Join(s.envPath, scriptName) + if err := os.WriteFile(scriptPath, scriptContent, 0755); err != nil { + return fmt.Errorf("failed to write %s: %w", scriptName, err) + } } return nil @@ -328,6 +334,19 @@ func (s *SortformerAdapter) Diarize(ctx context.Context, input interfaces.AudioI } } + if GetPersistentDiarizationManager().IsModelLoaded(PersistentDiarizationModelSortformer) { + result, err := s.diarizeWithPersistentWorker(ctx, audioInput, params, tempDir) + if err != nil { + return nil, err + } + + result.ProcessingTime = time.Since(startTime) + result.ModelUsed = "diar_streaming_sortformer_4spk-v2" + result.Metadata = s.CreateDefaultMetadata(params) + result.Metadata["execution_mode"] = "persistent_worker" + return result, nil + } + // Build command arguments args, err := s.buildSortformerArgs(audioInput, params, tempDir) if err != nil { @@ -384,6 +403,44 @@ func (s *SortformerAdapter) Diarize(ctx context.Context, input interfaces.AudioI return result, nil } +func (s *SortformerAdapter) diarizeWithPersistentWorker(ctx context.Context, input interfaces.AudioInput, params map[string]interface{}, tempDir string) (*interfaces.DiarizationResult, error) { + outputFormat := s.GetStringParameter(params, "output_format") + outputFile := filepath.Join(tempDir, "result.rttm") + if outputFormat == OutputFormatJSON { + outputFile = filepath.Join(tempDir, "result.json") + } + + request := map[string]interface{}{ + "audio_file": input.FilePath, + "output_file": outputFile, + "output_format": outputFormat, + } + if batchSize := s.GetIntParameter(params, "batch_size"); batchSize > 0 { + request["batch_size"] = batchSize + } + if maxSpeakers := s.GetIntParameter(params, "max_speakers"); maxSpeakers > 0 { + request["max_speakers"] = maxSpeakers + } + if s.GetBoolParameter(params, "streaming_mode") { + request["streaming_mode"] = true + if chunkLength := s.GetFloatParameter(params, "chunk_length_s"); chunkLength > 0 { + request["chunk_length_s"] = chunkLength + } + } + + logger.Info("Executing Sortformer via persistent worker", "audio_file", input.FilePath, "output_file", outputFile) + if err := GetPersistentDiarizationManager().Diarize(ctx, PersistentDiarizationModelSortformer, request); err != nil { + return nil, fmt.Errorf("persistent Sortformer execution failed: %w", err) + } + + result, err := s.parseResult(tempDir, input, params) + if err != nil { + return nil, fmt.Errorf("failed to parse persistent Sortformer result: %w", err) + } + + return result, nil +} + // buildSortformerArgs builds the command arguments for Sortformer func (s *SortformerAdapter) buildSortformerArgs(input interfaces.AudioInput, params map[string]interface{}, tempDir string) ([]string, error) { outputFormat := s.GetStringParameter(params, "output_format") diff --git a/internal/transcription/queue_integration.go b/internal/transcription/queue_integration.go index f746cc83b..be217b75d 100644 --- a/internal/transcription/queue_integration.go +++ b/internal/transcription/queue_integration.go @@ -5,6 +5,7 @@ import ( "os/exec" "scriberr/internal/repository" + "scriberr/internal/transcription/adapters" "scriberr/pkg/logger" ) @@ -79,6 +80,21 @@ func (u *UnifiedJobProcessor) GetModelStatus(ctx context.Context) map[string]boo return u.unifiedService.GetModelStatus(ctx) } +// GetPersistentDiarizationStatus returns the resident diarization worker state. +func (u *UnifiedJobProcessor) GetPersistentDiarizationStatus() adapters.PersistentDiarizationStatus { + return u.unifiedService.GetPersistentDiarizationStatus() +} + +// LoadPersistentDiarizationModel loads a diarization model into VRAM through a resident worker. +func (u *UnifiedJobProcessor) LoadPersistentDiarizationModel(ctx context.Context, modelID string, params map[string]interface{}) (adapters.PersistentDiarizationStatus, error) { + return u.unifiedService.LoadPersistentDiarizationModel(ctx, modelID, params) +} + +// UnloadPersistentDiarizationModel unloads the resident diarization worker. +func (u *UnifiedJobProcessor) UnloadPersistentDiarizationModel(ctx context.Context) (adapters.PersistentDiarizationStatus, error) { + return u.unifiedService.UnloadPersistentDiarizationModel(ctx) +} + // ValidateModelParameters validates parameters for a specific model func (u *UnifiedJobProcessor) ValidateModelParameters(modelID string, params map[string]interface{}) error { return u.unifiedService.ValidateModelParameters(modelID, params) diff --git a/internal/transcription/unified_service.go b/internal/transcription/unified_service.go index bbfe39bf8..74972fae7 100644 --- a/internal/transcription/unified_service.go +++ b/internal/transcription/unified_service.go @@ -14,6 +14,7 @@ import ( "scriberr/internal/models" "scriberr/internal/repository" "scriberr/internal/sse" + "scriberr/internal/transcription/adapters" "scriberr/internal/transcription/interfaces" "scriberr/internal/transcription/pipeline" "scriberr/internal/transcription/registry" @@ -286,6 +287,10 @@ func (u *UnifiedTranscriptionService) processSingleTrackJob(ctx context.Context, var transcriptResult *interfaces.TranscriptResult var diarizationResult *interfaces.DiarizationResult + useSeparateDiarization := job.Parameters.Diarize && + diarizationModelID != "" && + (!u.transcriptionIncludesDiarization(transcriptionModelID, job.Parameters) || + adapters.GetPersistentDiarizationManager().IsModelLoaded(diarizationModelID)) if job.Parameters.DiarizationOnly { if !job.Parameters.Diarize { @@ -307,7 +312,11 @@ func (u *UnifiedTranscriptionService) processSingleTrackJob(ctx context.Context, } // Convert parameters for this specific model - params := u.convertParametersForModel(job.Parameters, transcriptionModelID) + transcriptionParams := job.Parameters + if useSeparateDiarization && transcriptionModelID == ModelWhisperX { + transcriptionParams.Diarize = false + } + params := u.convertParametersForModel(transcriptionParams, transcriptionModelID) transcriptResult, err = transcriptionAdapter.Transcribe(ctx, preprocessedInput, params, procCtx) if err != nil { @@ -320,7 +329,7 @@ func (u *UnifiedTranscriptionService) processSingleTrackJob(ctx context.Context, // Convert parameters for diarization model diarizationParams := u.convertParametersForModel(job.Parameters, diarizationModelID) - if !u.transcriptionIncludesDiarization(transcriptionModelID, job.Parameters) { + if useSeparateDiarization { logger.Info("Running separate diarization", "model_id", diarizationModelID) diarizationAdapter, err := u.registry.GetDiarizationAdapter(diarizationModelID) if err != nil { @@ -963,6 +972,21 @@ func (u *UnifiedTranscriptionService) GetModelStatus(ctx context.Context) map[st return u.registry.GetModelStatus(ctx) } +// GetPersistentDiarizationStatus returns the resident diarization worker state. +func (u *UnifiedTranscriptionService) GetPersistentDiarizationStatus() adapters.PersistentDiarizationStatus { + return adapters.GetPersistentDiarizationManager().Status() +} + +// LoadPersistentDiarizationModel loads a diarization model into a resident worker. +func (u *UnifiedTranscriptionService) LoadPersistentDiarizationModel(ctx context.Context, modelID string, params map[string]interface{}) (adapters.PersistentDiarizationStatus, error) { + return adapters.GetPersistentDiarizationManager().Load(ctx, modelID, params) +} + +// UnloadPersistentDiarizationModel unloads the resident diarization worker. +func (u *UnifiedTranscriptionService) UnloadPersistentDiarizationModel(ctx context.Context) (adapters.PersistentDiarizationStatus, error) { + return adapters.GetPersistentDiarizationManager().Unload(ctx) +} + // ValidateModelParameters validates parameters for a specific model func (u *UnifiedTranscriptionService) ValidateModelParameters(modelID string, params map[string]interface{}) error { return u.registry.ValidateModelParameters(modelID, params) diff --git a/web/frontend/src/components/Header.tsx b/web/frontend/src/components/Header.tsx index 4a481be6e..0725929bc 100644 --- a/web/frontend/src/components/Header.tsx +++ b/web/frontend/src/components/Header.tsx @@ -1,12 +1,21 @@ -import { useRef, useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; -import { Upload, Mic, Settings, LogOut, Home, Plus, Grip, Zap, Youtube, Video, Users, MonitorSpeaker } from "lucide-react"; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { Upload, Mic, Settings, LogOut, Home, Plus, Grip, Zap, Youtube, Video, Users, MonitorSpeaker, BrainCircuit, Loader2 } from "lucide-react"; import { ScriberrLogo } from "./ScriberrLogo"; import { ThemeSwitcher } from "./ThemeSwitcher"; import { AudioRecorder } from "./AudioRecorder"; @@ -15,6 +24,7 @@ import { QuickTranscriptionDialog } from "@/features/transcription/components/Qu import { YouTubeDownloadDialog } from "@/features/transcription/components/YouTubeDownloadDialog"; import { useNavigate } from "react-router-dom"; import { useAuth } from "@/features/auth/hooks/useAuth"; +import { useToast } from "@/components/ui/toast"; import { isVideoFile, isAudioFile } from "../utils/fileProcessor"; import { useGlobalUpload } from "@/contexts/GlobalUploadContext"; @@ -29,15 +39,28 @@ interface HeaderProps { onDownloadComplete?: () => void; } +interface DiarizationWorkerStatus { + state: string; + loaded: boolean; + model_id?: string; + display_name?: string; + error?: string; +} + export function Header({ onFileSelect, onMultiTrackClick, onDownloadComplete }: HeaderProps) { const navigate = useNavigate(); - const { logout } = useAuth(); + const { logout, getAuthHeaders } = useAuth(); + const { toast } = useToast(); const fileInputRef = useRef(null); const videoFileInputRef = useRef(null); const [isRecorderOpen, setIsRecorderOpen] = useState(false); const [isSystemRecorderOpen, setIsSystemRecorderOpen] = useState(false); const [isQuickTranscriptionOpen, setIsQuickTranscriptionOpen] = useState(false); const [isYouTubeDialogOpen, setIsYouTubeDialogOpen] = useState(false); + const [isDiarizationDialogOpen, setIsDiarizationDialogOpen] = useState(false); + const [diarizationStatus, setDiarizationStatus] = useState({ state: "unloaded", loaded: false }); + const [selectedDiarizationModel, setSelectedDiarizationModel] = useState("pyannote"); + const [isDiarizationActionRunning, setIsDiarizationActionRunning] = useState(false); // Use global upload context as fallback when props are not provided const globalUpload = useGlobalUpload(); @@ -47,6 +70,29 @@ export function Header({ onFileSelect, onMultiTrackClick, onDownloadComplete }: const effectiveMultiTrackClick = onMultiTrackClick ?? globalUpload.openMultiTrackDialog; const effectiveRecordingComplete = globalUpload.handleRecordingComplete; + const refreshDiarizationStatus = useCallback(async () => { + try { + const response = await fetch("/api/v1/diarization-worker/status", { + headers: getAuthHeaders(), + }); + if (!response.ok) { + return null; + } + + const status = await response.json(); + setDiarizationStatus(status); + return status as DiarizationWorkerStatus; + } catch { + return null; + } + }, [getAuthHeaders]); + + useEffect(() => { + refreshDiarizationStatus(); + const interval = window.setInterval(refreshDiarizationStatus, 15000); + return () => window.clearInterval(interval); + }, [refreshDiarizationStatus]); + const handleUploadClick = () => { fileInputRef.current?.click(); }; @@ -75,6 +121,72 @@ export function Header({ onFileSelect, onMultiTrackClick, onDownloadComplete }: effectiveMultiTrackClick(); }; + const handleDiarizationWorkerClick = async () => { + await refreshDiarizationStatus(); + setIsDiarizationDialogOpen(true); + }; + + const handleLoadDiarizationModel = async () => { + setIsDiarizationActionRunning(true); + try { + const response = await fetch("/api/v1/diarization-worker/load", { + method: "POST", + headers: { + "Content-Type": "application/json", + ...getAuthHeaders(), + }, + body: JSON.stringify({ + model: selectedDiarizationModel, + device: "auto", + }), + }); + const body = await response.json().catch(() => ({})); + + if (!response.ok) { + if (body.status) setDiarizationStatus(body.status); + throw new Error(body.error || "Failed to load diarization model"); + } + + setDiarizationStatus(body); + setIsDiarizationDialogOpen(false); + toast({ title: `${body.display_name || "Diarization model"} loaded` }); + } catch (error) { + toast({ + title: "Could not load diarization model", + description: error instanceof Error ? error.message : "Unknown error", + }); + } finally { + setIsDiarizationActionRunning(false); + } + }; + + const handleUnloadDiarizationModel = async () => { + setIsDiarizationActionRunning(true); + try { + const response = await fetch("/api/v1/diarization-worker/unload", { + method: "POST", + headers: getAuthHeaders(), + }); + const body = await response.json().catch(() => ({})); + + if (!response.ok) { + if (body.status) setDiarizationStatus(body.status); + throw new Error(body.error || "Failed to unload diarization model"); + } + + setDiarizationStatus(body); + setIsDiarizationDialogOpen(false); + toast({ title: "Diarization model unloaded" }); + } catch (error) { + toast({ + title: "Could not unload diarization model", + description: error instanceof Error ? error.message : "Unknown error", + }); + } finally { + setIsDiarizationActionRunning(false); + } + }; + const handleSettingsClick = () => { navigate("/settings"); }; @@ -134,6 +246,11 @@ export function Header({ onFileSelect, onMultiTrackClick, onDownloadComplete }: await effectiveRecordingComplete(blob, title); }; + const isDiarizationBusy = + isDiarizationActionRunning || + diarizationStatus.state === "loading" || + diarizationStatus.state === "unloading"; + const diarizationLabel = diarizationStatus.display_name || "Diarization model"; return (
@@ -260,6 +377,31 @@ export function Header({ onFileSelect, onMultiTrackClick, onDownloadComplete }: + + {/* Main Menu (Grip) */} @@ -340,6 +482,77 @@ export function Header({ onFileSelect, onMultiTrackClick, onDownloadComplete }: onDownloadComplete={onDownloadComplete} /> + + + {diarizationStatus.loaded ? ( + <> + + Unload diarization model? + + {diarizationLabel} is currently loaded in GPU memory. + + + + + + + + ) : ( + <> + + Load diarization model + + Choose a diarization model to keep loaded in GPU memory. + + +
+ + {diarizationStatus.state === "failed" && diarizationStatus.error && ( +

{diarizationStatus.error}

+ )} +
+ + + + + + )} +
+
+
); } diff --git a/web/frontend/src/features/settings/components/ProfileSettings.tsx b/web/frontend/src/features/settings/components/ProfileSettings.tsx index d967baa87..b716fd36a 100644 --- a/web/frontend/src/features/settings/components/ProfileSettings.tsx +++ b/web/frontend/src/features/settings/components/ProfileSettings.tsx @@ -21,6 +21,7 @@ interface TranscriptionProfile { interface UserSettings { auto_transcription_enabled: boolean; default_profile_id?: string; + startup_diarization_model: string; } export function ProfileSettings() { @@ -157,6 +158,37 @@ export function ProfileSettings() { } }; + const handleStartupDiarizationModelChange = async (model: string) => { + setError(""); + setSuccess(""); + + try { + const response = await fetch("/api/v1/user/settings", { + method: "PUT", + headers: { + "Content-Type": "application/json", + ...getAuthHeaders(), + }, + body: JSON.stringify({ + startup_diarization_model: model, + }), + }); + + if (response.ok) { + const updatedSettings = await response.json(); + setUserSettings(updatedSettings); + const label = model === "none" ? "No startup diarization model" : `${model === "pyannote" ? "PyAnnote" : "Sortformer"} at startup`; + setSuccess(`Startup diarization setting updated: ${label}.`); + } else { + const errorData = await response.json(); + setError(errorData.error || "Failed to update startup diarization setting"); + } + } catch (error) { + console.error("Error updating startup diarization setting:", error); + setError("Network error. Please try again."); + } + }; + const handleCreateProfile = useCallback(() => { setEditingProfile(null); setProfileDialogOpen(true); @@ -256,21 +288,48 @@ export function ProfileSettings() { Loading settings...
) : ( -
-
- -

- When enabled, uploaded audio files will automatically be queued for transcription using your default profile. -

+
+
+
+ +

+ When enabled, uploaded audio files will automatically be queued for transcription using your default profile. +

+
+ +
+ +
+
+ +

+ Keep a diarization model resident in GPU memory after Scriberr starts. +

+
+
-
)}
From baef915a1f6b453943a75418caee8792ad85e622 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 26 Jun 2026 10:51:08 -0400 Subject: [PATCH 3/4] Use profile HF token for diarization worker --- cmd/server/main.go | 45 +++++++++++++++++++++++++++++++---- internal/api/handlers.go | 51 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 89 insertions(+), 7 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 5a6e81dba..e9882fa8c 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -126,7 +126,7 @@ func main() { logger.Error("Failed to prepare Python environment", "error", err) os.Exit(1) } - loadStartupDiarizationModel(context.Background(), userRepo, unifiedProcessor) + loadStartupDiarizationModel(context.Background(), userRepo, profileRepo, unifiedProcessor) // Initialize quick transcription service logger.Startup("quick-transcription", "Initializing quick transcription service") @@ -251,7 +251,7 @@ func registerAdapters(cfg *config.Config) { logger.Info("Adapter registration complete") } -func loadStartupDiarizationModel(ctx context.Context, userRepo repository.UserRepository, unifiedProcessor *transcription.UnifiedJobProcessor) { +func loadStartupDiarizationModel(ctx context.Context, userRepo repository.UserRepository, profileRepo repository.ProfileRepository, unifiedProcessor *transcription.UnifiedJobProcessor) { users, _, err := userRepo.List(ctx, 0, 100) if err != nil { logger.Warn("Failed to load startup diarization preference", "error", err) @@ -265,10 +265,15 @@ func loadStartupDiarizationModel(ctx context.Context, userRepo repository.UserRe } logger.Startup("diarization", "Loading persistent diarization model configured for startup") - loadCtx, cancel := context.WithTimeout(ctx, 30*time.Minute) - status, err := unifiedProcessor.LoadPersistentDiarizationModel(loadCtx, modelID, map[string]interface{}{ + params := map[string]interface{}{ "device": "auto", - }) + } + if modelID == transcription.ModelPyannote { + applyStartupProfileHFToken(ctx, user.ID, user.DefaultProfileID, profileRepo, params) + } + + loadCtx, cancel := context.WithTimeout(ctx, 30*time.Minute) + status, err := unifiedProcessor.LoadPersistentDiarizationModel(loadCtx, modelID, params) cancel() if err != nil { logger.Warn("Failed to load startup diarization model", "model_id", modelID, "user_id", user.ID, "error", err) @@ -279,3 +284,33 @@ func loadStartupDiarizationModel(ctx context.Context, userRepo repository.UserRe return } } + +func applyStartupProfileHFToken(ctx context.Context, userID uint, defaultProfileID *string, profileRepo repository.ProfileRepository, params map[string]interface{}) { + if strings.TrimSpace(os.Getenv("HF_TOKEN")) != "" { + return + } + if defaultProfileID == nil || strings.TrimSpace(*defaultProfileID) == "" { + logger.Warn("Cannot resolve HF_TOKEN for startup diarization because the user has no default transcription profile", "user_id", userID) + return + } + + profileID := strings.TrimSpace(*defaultProfileID) + profile, err := profileRepo.FindByID(ctx, profileID) + if err != nil { + logger.Warn("Failed to load default transcription profile for startup diarization", "user_id", userID, "profile_id", profileID, "error", err) + return + } + if profile.Parameters.HfToken == nil || strings.TrimSpace(*profile.Parameters.HfToken) == "" { + logger.Warn("Default transcription profile has no HF token for startup diarization", "user_id", userID, "profile_id", profileID) + return + } + + hfToken := strings.TrimSpace(*profile.Parameters.HfToken) + params["hf_token"] = hfToken + if err := os.Setenv("HF_TOKEN", hfToken); err != nil { + logger.Warn("Failed to set HF_TOKEN from default transcription profile", "user_id", userID, "profile_id", profileID, "error", err) + return + } + + logger.Info("Resolved HF_TOKEN from default transcription profile for startup diarization", "user_id", userID, "profile_id", profileID) +} diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 1a2e0874f..2f76a99e7 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -2292,8 +2292,12 @@ func (h *Handler) LoadDiarizationWorker(c *gin.Context) { if req.Device != "" { params["device"] = req.Device } - if req.HfToken != "" { - params["hf_token"] = req.HfToken + if hfToken := strings.TrimSpace(req.HfToken); hfToken != "" { + params["hf_token"] = hfToken + } else if isPyAnnoteWorkerModel(req.Model) { + if hfToken := h.resolveDiarizationWorkerHFToken(c); hfToken != "" { + params["hf_token"] = hfToken + } } ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Minute) @@ -2311,6 +2315,49 @@ func (h *Handler) LoadDiarizationWorker(c *gin.Context) { c.JSON(http.StatusOK, status) } +func isPyAnnoteWorkerModel(model string) bool { + switch strings.TrimSpace(strings.ToLower(model)) { + case "pyannote", "pyannote/speaker-diarization-3.1", "pyannote/speaker-diarization-community-1": + return true + default: + return false + } +} + +func (h *Handler) resolveDiarizationWorkerHFToken(c *gin.Context) string { + profile := h.resolveRequestDefaultProfile(c) + if profile != nil && profile.Parameters.HfToken != nil { + hfToken := strings.TrimSpace(*profile.Parameters.HfToken) + if hfToken != "" { + logger.Info("Resolved HF token from transcription profile for diarization worker load", "profile_id", profile.ID) + return hfToken + } + } + + return strings.TrimSpace(os.Getenv("HF_TOKEN")) +} + +func (h *Handler) resolveRequestDefaultProfile(c *gin.Context) *models.TranscriptionProfile { + if userID, exists := c.Get("user_id"); exists { + if user, err := h.userRepo.FindByID(c.Request.Context(), userID.(uint)); err == nil && user.DefaultProfileID != nil { + if profile, err := h.profileRepo.FindByID(c.Request.Context(), *user.DefaultProfileID); err == nil { + return profile + } + } + } + + if profile, err := h.profileRepo.FindDefault(c.Request.Context()); err == nil { + return profile + } + + profiles, _, err := h.profileRepo.List(c.Request.Context(), 0, 1) + if err != nil || len(profiles) == 0 { + return nil + } + + return &profiles[0] +} + // @Summary Unload persistent diarization model // @Description Stop the resident diarization worker and release its VRAM // @Tags transcription From 002a8d9ff0a8a7326340f560ee8079f660f30c88 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 26 Jun 2026 12:18:49 -0400 Subject: [PATCH 4/4] Reserve VRAM for persistent diarization workers --- .../persistent_diarization_manager.go | 30 +++++++++- .../persistent_diarization_manager_test.go | 59 +++++++++++++++++++ .../adapters/py/nvidia/sortformer_worker.py | 34 +++++++++++ .../adapters/py/pyannote/pyannote_worker.py | 57 +++++++++++++++--- 4 files changed, 171 insertions(+), 9 deletions(-) create mode 100644 internal/transcription/adapters/persistent_diarization_manager_test.go diff --git a/internal/transcription/adapters/persistent_diarization_manager.go b/internal/transcription/adapters/persistent_diarization_manager.go index 79a5cdbe4..f3400db77 100644 --- a/internal/transcription/adapters/persistent_diarization_manager.go +++ b/internal/transcription/adapters/persistent_diarization_manager.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "path/filepath" + "strconv" "strings" "sync" "time" @@ -26,6 +27,8 @@ const ( PersistentDiarizationModelPyAnnote = "pyannote" PersistentDiarizationModelSortformer = "sortformer" + + PersistentDiarizationDefaultVRAMReserveMB = 3000 ) var errPersistentWorkerStopped = errors.New("persistent diarization worker stopped") @@ -325,6 +328,7 @@ func startPersistentDiarizationWorker(ctx context.Context, modelID, envPath stri func persistentDiarizationWorkerArgs(modelID, envPath, scriptPath string, params map[string]interface{}) ([]string, error) { args := []string{"run", "--native-tls", "--project", envPath, "python", scriptPath} device := stringParam(params, "device", "auto") + reserveVRAMMB := intParam(params, "reserve_vram_mb", PersistentDiarizationDefaultVRAMReserveMB) switch modelID { case PersistentDiarizationModelPyAnnote: @@ -339,9 +343,13 @@ func persistentDiarizationWorkerArgs(modelID, envPath, scriptPath string, params "--hf-token", hfToken, "--model", stringParam(params, "model", "pyannote/speaker-diarization-community-1"), "--device", device, + "--reserve-vram-mb", strconv.Itoa(reserveVRAMMB), ) case PersistentDiarizationModelSortformer: - args = append(args, "--device", device) + args = append(args, + "--device", device, + "--reserve-vram-mb", strconv.Itoa(reserveVRAMMB), + ) default: return nil, fmt.Errorf("unsupported persistent diarization model: %s", modelID) } @@ -538,6 +546,26 @@ func stringParam(params map[string]interface{}, key, fallback string) string { return fallback } +func intParam(params map[string]interface{}, key string, fallback int) int { + if params == nil { + return fallback + } + switch value := params[key].(type) { + case int: + return value + case int64: + return int(value) + case float64: + return int(value) + case string: + parsed, err := strconv.Atoi(strings.TrimSpace(value)) + if err == nil { + return parsed + } + } + return fallback +} + func derefTime(value *time.Time) time.Time { if value == nil { return time.Now() diff --git a/internal/transcription/adapters/persistent_diarization_manager_test.go b/internal/transcription/adapters/persistent_diarization_manager_test.go new file mode 100644 index 000000000..31fab240a --- /dev/null +++ b/internal/transcription/adapters/persistent_diarization_manager_test.go @@ -0,0 +1,59 @@ +package adapters + +import "testing" + +func TestPersistentDiarizationWorkerArgsDefaultVRAMReservation(t *testing.T) { + tests := []struct { + name string + modelID string + params map[string]interface{} + }{ + { + name: "pyannote", + modelID: PersistentDiarizationModelPyAnnote, + params: map[string]interface{}{ + "hf_token": "test-token", + }, + }, + { + name: "sortformer", + modelID: PersistentDiarizationModelSortformer, + params: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args, err := persistentDiarizationWorkerArgs(tt.modelID, "/tmp/env", "/tmp/worker.py", tt.params) + if err != nil { + t.Fatalf("expected worker args, got error: %v", err) + } + + if !argsHaveFlagValue(args, "--reserve-vram-mb", "3000") { + t.Fatalf("expected default 3000 MiB reservation in args: %#v", args) + } + }) + } +} + +func TestPersistentDiarizationWorkerArgsOverrideVRAMReservation(t *testing.T) { + args, err := persistentDiarizationWorkerArgs(PersistentDiarizationModelSortformer, "/tmp/env", "/tmp/worker.py", map[string]interface{}{ + "reserve_vram_mb": 1024, + }) + if err != nil { + t.Fatalf("expected worker args, got error: %v", err) + } + + if !argsHaveFlagValue(args, "--reserve-vram-mb", "1024") { + t.Fatalf("expected overridden 1024 MiB reservation in args: %#v", args) + } +} + +func argsHaveFlagValue(args []string, flag string, value string) bool { + for i := 0; i < len(args)-1; i++ { + if args[i] == flag && args[i+1] == value { + return true + } + } + return false +} diff --git a/internal/transcription/adapters/py/nvidia/sortformer_worker.py b/internal/transcription/adapters/py/nvidia/sortformer_worker.py index 6578cd1bf..1e866a7d0 100644 --- a/internal/transcription/adapters/py/nvidia/sortformer_worker.py +++ b/internal/transcription/adapters/py/nvidia/sortformer_worker.py @@ -7,6 +7,7 @@ import argparse import contextlib +import gc import json import os from pathlib import Path @@ -33,6 +34,29 @@ def resolve_device(device): return device +def reserve_cuda_vram(reserve_mb, device): + if device != "cuda" or reserve_mb <= 0: + return None + + reserve_bytes = int(reserve_mb) * 1024 * 1024 + reservation = torch.empty((reserve_bytes,), dtype=torch.uint8, device=torch.device("cuda")) + torch.cuda.synchronize() + print(f"Reserved {reserve_mb} MiB of CUDA VRAM for persistent Sortformer", file=sys.stderr) + return reservation + + +def release_cuda_vram(reservation): + if reservation is None: + return None + + del reservation + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + print("Released persistent Sortformer CUDA VRAM reservation", file=sys.stderr) + return None + + def model_path_from_virtual_env(): virtual_env = os.environ.get("VIRTUAL_ENV") if not virtual_env: @@ -75,12 +99,15 @@ def run_diarization(diar_model, request): def main(): parser = argparse.ArgumentParser(description="Persistent Sortformer diarization worker") parser.add_argument("--device", choices=["cpu", "cuda", "auto"], default="auto") + parser.add_argument("--reserve-vram-mb", type=int, default=0) args = parser.parse_args() + reservation = None try: device = resolve_device(args.device) with contextlib.redirect_stdout(sys.stderr): diar_model = load_model(device) + reservation = reserve_cuda_vram(args.reserve_vram_mb, device) send({"type": "ready", "ok": True, "model_id": "sortformer", "device": device}) except Exception as exc: traceback.print_exc(file=sys.stderr) @@ -106,10 +133,17 @@ def main(): raise ValueError(f"Unsupported action: {action}") with contextlib.redirect_stdout(sys.stderr): + reservation = release_cuda_vram(reservation) run_diarization(diar_model, request) + reservation = reserve_cuda_vram(args.reserve_vram_mb, device) send({"type": "response", "id": request_id, "ok": True}) except Exception as exc: traceback.print_exc(file=sys.stderr) + try: + with contextlib.redirect_stdout(sys.stderr): + reservation = reserve_cuda_vram(args.reserve_vram_mb, device) + except Exception as reserve_exc: + print(f"Warning: could not reacquire Sortformer VRAM reservation: {reserve_exc}", file=sys.stderr) send({ "type": "response", "id": request.get("id"), diff --git a/internal/transcription/adapters/py/pyannote/pyannote_worker.py b/internal/transcription/adapters/py/pyannote/pyannote_worker.py index 7f70d1d6f..d81fe6401 100644 --- a/internal/transcription/adapters/py/pyannote/pyannote_worker.py +++ b/internal/transcription/adapters/py/pyannote/pyannote_worker.py @@ -8,6 +8,7 @@ import argparse import contextlib import copy +import gc import json import os from pathlib import Path @@ -25,20 +26,50 @@ def send(message): sys.stdout.flush() +def resolve_device(device): + if device == "auto": + return "cuda" if torch.cuda.is_available() else "cpu" + if device == "cuda" and not torch.cuda.is_available(): + print("CUDA requested but not available, using CPU", file=sys.stderr) + return "cpu" + return device + + +def reserve_cuda_vram(reserve_mb, device): + if device != "cuda" or reserve_mb <= 0: + return None + + reserve_bytes = int(reserve_mb) * 1024 * 1024 + reservation = torch.empty((reserve_bytes,), dtype=torch.uint8, device=torch.device("cuda")) + torch.cuda.synchronize() + print(f"Reserved {reserve_mb} MiB of CUDA VRAM for persistent PyAnnote", file=sys.stderr) + return reservation + + +def release_cuda_vram(reservation): + if reservation is None: + return None + + del reservation + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + print("Released persistent PyAnnote CUDA VRAM reservation", file=sys.stderr) + return None + + def load_pipeline(args): print(f"Loading persistent PyAnnote pipeline: {args.model}", file=sys.stderr) pipeline = Pipeline.from_pretrained(args.model, token=args.hf_token) + device = resolve_device(args.device) - if args.device == "cuda" or (args.device == "auto" and torch.cuda.is_available()): - if torch.cuda.is_available(): - pipeline = pipeline.to(torch.device("cuda")) - print("Using CUDA for persistent PyAnnote diarization", file=sys.stderr) - else: - print("CUDA requested but not available, using CPU", file=sys.stderr) + if device == "cuda": + pipeline = pipeline.to(torch.device("cuda")) + print("Using CUDA for persistent PyAnnote diarization", file=sys.stderr) else: print("Using CPU for persistent PyAnnote diarization", file=sys.stderr) - return pipeline + return pipeline, device def instantiate_for_request(pipeline, base_params, request): @@ -96,12 +127,15 @@ def main(): parser.add_argument("--hf-token", required=True) parser.add_argument("--model", default="pyannote/speaker-diarization-community-1") parser.add_argument("--device", choices=["cpu", "cuda", "auto"], default="auto") + parser.add_argument("--reserve-vram-mb", type=int, default=0) args = parser.parse_args() + reservation = None try: with contextlib.redirect_stdout(sys.stderr): - pipeline = load_pipeline(args) + pipeline, device = load_pipeline(args) base_params = copy.deepcopy(pipeline.parameters(instantiated=True)) + reservation = reserve_cuda_vram(args.reserve_vram_mb, device) send({"type": "ready", "ok": True, "model_id": "pyannote", "model": args.model}) except Exception as exc: traceback.print_exc(file=sys.stderr) @@ -127,10 +161,17 @@ def main(): raise ValueError(f"Unsupported action: {action}") with contextlib.redirect_stdout(sys.stderr): + reservation = release_cuda_vram(reservation) run_diarization(pipeline, base_params, request) + reservation = reserve_cuda_vram(args.reserve_vram_mb, device) send({"type": "response", "id": request_id, "ok": True}) except Exception as exc: traceback.print_exc(file=sys.stderr) + try: + with contextlib.redirect_stdout(sys.stderr): + reservation = reserve_cuda_vram(args.reserve_vram_mb, device) + except Exception as reserve_exc: + print(f"Warning: could not reacquire PyAnnote VRAM reservation: {reserve_exc}", file=sys.stderr) send({ "type": "response", "id": request.get("id"),