diff --git a/api-docs/swagger.json b/api-docs/swagger.json index f5ddfdbb3..eb430c1b7 100644 --- a/api-docs/swagger.json +++ b/api-docs/swagger.json @@ -4778,6 +4778,10 @@ "description": "OpenAI settings", "type": "string" }, + "api_url": { + "description": "Custom transcription API base URL (OpenAI adapter only)", + "type": "string" + }, "attention_context_left": { "description": "NVIDIA Parakeet-specific parameters for long-form audio", "type": "integer" @@ -4930,6 +4934,10 @@ "threads": { "type": "integer" }, + "timeout_minutes": { + "description": "HTTP request timeout in minutes (OpenAI adapter with custom base URL)", + "type": "integer" + }, "vad_method": { "description": "VAD (Voice Activity Detection) settings", "type": "string" diff --git a/api-docs/swagger.yaml b/api-docs/swagger.yaml index 4b43a4ca1..39b1a8ecc 100644 --- a/api-docs/swagger.yaml +++ b/api-docs/swagger.yaml @@ -641,6 +641,9 @@ definitions: api_key: description: OpenAI settings type: string + api_url: + description: Custom transcription API base URL (OpenAI adapter only) + type: string attention_context_left: description: NVIDIA Parakeet-specific parameters for long-form audio type: integer @@ -747,6 +750,9 @@ definitions: type: number threads: type: integer + timeout_minutes: + description: HTTP request timeout in minutes (OpenAI adapter with custom base URL) + type: integer vad_method: description: VAD (Voice Activity Detection) settings type: string diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 1b94beef4..3637605ef 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -1033,6 +1033,64 @@ func (h *Handler) StartTranscription(c *gin.Context) { c.JSON(http.StatusOK, job) } +// @Summary Start diarization for a completed transcription +// @Description Run speaker diarization against an existing transcript without rerunning transcription +// @Tags transcription +// @Accept json +// @Produce json +// @Param id path string true "Job ID" +// @Param parameters body models.WhisperXParams true "Diarization parameters" +// @Success 200 {object} models.TranscriptionJob +// @Failure 400 {object} map[string]string +// @Failure 404 {object} map[string]string +// @Router /api/v1/transcription/{id}/diarize [post] +// @Security ApiKeyAuth +// @Security BearerAuth +func (h *Handler) StartDiarization(c *gin.Context) { + jobID := c.Param("id") + + job, err := h.getJobForDiarization(c, jobID) + if err != nil { + return + } + + requestParams, err := h.getValidatedTranscriptionParams(c, job, jobID) + if err != nil { + return + } + + requestParams.Diarize = true + requestParams.DiarizationOnly = true + requestParams.IsMultiTrackEnabled = false + if requestParams.DiarizeModel == "" { + requestParams.DiarizeModel = transcription.ModelPyannote + } + + job.Parameters = *requestParams + job.Diarization = true + job.Status = models.StatusPending + job.ErrorMessage = nil + + if err := h.jobRepo.Update(c.Request.Context(), job); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update job"}) + return + } + + if err := h.taskQueue.EnqueueJob(jobID); err != nil { + logger.Error("Failed to enqueue diarization job", "job_id", jobID, "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to enqueue diarization job"}) + return + } + + logger.Info("Diarization job started", + "job_id", jobID, + "diarize_model", requestParams.DiarizeModel, + "min_speakers", requestParams.MinSpeakers, + "max_speakers", requestParams.MaxSpeakers) + + c.JSON(http.StatusOK, job) +} + func (h *Handler) getJobForTranscription(c *gin.Context, jobID string) (*models.TranscriptionJob, error) { job, err := h.jobRepo.FindByID(c.Request.Context(), jobID) if err != nil { @@ -1052,6 +1110,35 @@ func (h *Handler) getJobForTranscription(c *gin.Context, jobID string) (*models. return job, nil } +func (h *Handler) getJobForDiarization(c *gin.Context, jobID string) (*models.TranscriptionJob, error) { + job, err := h.jobRepo.FindByID(c.Request.Context(), jobID) + if err != nil { + if err == gorm.ErrRecordNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "Job not found"}) + return nil, err + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get job"}) + return nil, err + } + + if job.Status == models.StatusProcessing || job.Status == models.StatusPending { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cannot start diarization: job is currently processing or pending"}) + return nil, fmt.Errorf("invalid job status") + } + + if job.IsMultiTrack { + c.JSON(http.StatusBadRequest, gin.H{"error": "Standalone diarization is not supported for multi-track recordings"}) + return nil, fmt.Errorf("multi-track diarization unsupported") + } + + if job.Transcript == nil || strings.TrimSpace(*job.Transcript) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cannot start diarization: job does not have an existing transcript"}) + return nil, fmt.Errorf("missing transcript") + } + + return job, nil +} + func (h *Handler) getValidatedTranscriptionParams(c *gin.Context, job *models.TranscriptionJob, jobID string) (*models.WhisperXParams, error) { // Set defaults requestParams := models.WhisperXParams{ diff --git a/internal/api/router.go b/internal/api/router.go index b5bec3224..33061fdf7 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -132,6 +132,7 @@ func SetupRoutes(handler *Handler, authService *auth.AuthService) *gin.Engine { transcription.POST("/youtube", handler.DownloadFromYouTube) transcription.POST("/submit", handler.SubmitJob) transcription.POST("/:id/start", handler.StartTranscription) + transcription.POST("/:id/diarize", handler.StartDiarization) transcription.POST("/:id/kill", handler.KillJob) transcription.GET("/:id/logs", handler.GetJobLogs) transcription.GET("/:id/status", handler.GetJobStatus) diff --git a/internal/models/transcription.go b/internal/models/transcription.go index 632c6c3b7..b325cb74f 100644 --- a/internal/models/transcription.go +++ b/internal/models/transcription.go @@ -85,6 +85,7 @@ type WhisperXParams struct { // Diarization settings Diarize bool `json:"diarize" gorm:"type:boolean;default:false"` + DiarizationOnly bool `json:"diarization_only" gorm:"type:boolean;default:false"` MinSpeakers *int `json:"min_speakers,omitempty" gorm:"type:int"` MaxSpeakers *int `json:"max_speakers,omitempty" gorm:"type:int"` DiarizeModel string `json:"diarize_model" gorm:"type:varchar(50);default:'pyannote'"` // Options: 'pyannote', 'nvidia_sortformer' @@ -127,7 +128,9 @@ type WhisperXParams struct { CallbackURL *string `json:"callback_url,omitempty" gorm:"type:text"` // OpenAI settings - APIKey *string `json:"api_key,omitempty" gorm:"type:text"` + APIKey *string `json:"api_key,omitempty" gorm:"type:text"` + APIURL *string `json:"api_url,omitempty" gorm:"type:text"` + TimeoutMinutes *int `json:"timeout_minutes,omitempty" gorm:"type:int"` // Voxtral settings MaxNewTokens *int `json:"max_new_tokens,omitempty" gorm:"type:int"` @@ -207,10 +210,10 @@ func (tp *TranscriptionProfile) BeforeSave(tx *gorm.DB) error { // LLMConfig represents LLM configuration settings type LLMConfig struct { ID uint `json:"id" gorm:"primaryKey"` - Provider string `json:"provider" gorm:"not null;type:varchar(50)"` // "ollama" or "openai" - BaseURL *string `json:"base_url,omitempty" gorm:"type:text"` // For Ollama + Provider string `json:"provider" gorm:"not null;type:varchar(50)"` // "ollama" or "openai" + BaseURL *string `json:"base_url,omitempty" gorm:"type:text"` // For Ollama OpenAIBaseURL *string `json:"openai_base_url,omitempty" gorm:"type:text"` // For OpenAI custom endpoint - APIKey *string `json:"api_key,omitempty" gorm:"type:text"` // For OpenAI (encrypted) + APIKey *string `json:"api_key,omitempty" gorm:"type:text"` // For OpenAI (encrypted) IsActive bool `json:"is_active" gorm:"type:boolean;default:false"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` diff --git a/internal/transcription/README.md b/internal/transcription/README.md index e915bc151..30c5b5de3 100644 --- a/internal/transcription/README.md +++ b/internal/transcription/README.md @@ -155,6 +155,7 @@ err := adapter.ValidateParameters(params) | `whisperx` | `whisper` | 90+ languages | Timestamps, Diarization, Translation | | `parakeet` | `nvidia_parakeet` | English only | Timestamps, Long-form, High Quality | | `canary` | `nvidia_canary` | 12 languages | Timestamps, Translation, Multilingual | +| `openai_whisper` | `openai` | 57 languages | Timestamps, Diarization, Translation, Custom Endpoint | ### Diarization Models @@ -221,9 +222,18 @@ params := map[string]interface{}{ // NVIDIA Canary with translation params := map[string]interface{}{ "source_lang": "es", - "target_lang": "en", + "target_lang": "en", "task": "translate", } + +// OpenAI with custom self-hosted endpoint +params := map[string]interface{}{ + "base_url": "http://localhost:8000/v1", + "model": "Systran/faster-whisper-large-v3", + "timeout_minutes": 30, + "diarize": true, + "diarize_model": "pyannote", +} ``` ## Testing diff --git a/internal/transcription/adapters/openai_adapter.go b/internal/transcription/adapters/openai_adapter.go index a9b008efd..55f2f77be 100644 --- a/internal/transcription/adapters/openai_adapter.go +++ b/internal/transcription/adapters/openai_adapter.go @@ -40,7 +40,7 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter { Features: map[string]bool{ "timestamps": true, // Verbose JSON response includes segments "word_level": false, // Not supported by standard API yet (unless using verbose_json with timestamp_granularities which is beta) - "diarization": false, // Not supported by OpenAI API + "diarization": true, // Post-processing via pyannote/sortformer pipeline "translation": true, "language_detection": true, "vad": true, // Implicit @@ -59,13 +59,19 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter { Description: "OpenAI API Key (overrides system default)", Group: "authentication", }, + { + Name: "base_url", + Type: "string", + Required: false, + Description: "Custom transcription API base URL (overrides server default)", + Group: "authentication", + }, { Name: "model", Type: "string", Required: false, Default: "whisper-1", - Options: []string{"whisper-1"}, - Description: "ID of the model to use", + Description: "Model name (e.g. whisper-1, or any model exposed by a custom endpoint)", Group: "basic", }, { @@ -92,6 +98,15 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter { Description: "Sampling temperature", Group: "quality", }, + { + Name: "timeout_minutes", + Type: "int", + Required: false, + Default: nil, + Min: &[]float64{1}[0], + Description: "HTTP request timeout in minutes (increase for large files on self-hosted endpoints)", + Group: "advanced", + }, } baseAdapter := NewBaseAdapter("openai_whisper", "", capabilities, schema) @@ -153,7 +168,14 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn apiKey = key } - if apiKey == "" { + const officialURL = "https://api.openai.com/v1/audio/transcriptions" + endpointURL := officialURL + if url := a.GetStringParameter(params, "base_url"); url != "" { + endpointURL = strings.TrimRight(url, "/") + "/audio/transcriptions" + } + isOfficialEndpoint := endpointURL == officialURL + + if apiKey == "" && isOfficialEndpoint { writeLog("Error: OpenAI API key is required but not provided") return nil, fmt.Errorf("OpenAI API key is required but not provided") } @@ -188,7 +210,7 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn writeLog("Model: %s", model) _ = writer.WriteField("model", model) - if strings.HasPrefix(model, "gpt-4o") { + if isOfficialEndpoint && strings.HasPrefix(model, "gpt-4o") { if strings.Contains(model, "diarize") { _ = writer.WriteField("response_format", "diarized_json") } else { @@ -197,7 +219,6 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn // gpt-4o models don't support timestamp_granularities with these formats } else { _ = writer.WriteField("response_format", "verbose_json") - // timestamp_granularities is only supported for whisper-1 if model == "whisper-1" { _ = writer.WriteField("timestamp_granularities[]", "word") // Request word timestamps _ = writer.WriteField("timestamp_granularities[]", "segment") // Request segment timestamps @@ -224,8 +245,8 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn } // Create request - writeLog("Sending request to OpenAI API...") - req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/audio/transcriptions", body) + writeLog("Sending request to %s...", endpointURL) + req, err := http.NewRequestWithContext(ctx, "POST", endpointURL, body) if err != nil { writeLog("Error: Failed to create request: %v", err) return nil, fmt.Errorf("failed to create request: %w", err) @@ -235,9 +256,14 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn req.Header.Set("Authorization", "Bearer "+apiKey) // Execute request - client := &http.Client{ - Timeout: 10 * time.Minute, // Generous timeout for large files + timeout := 10 * time.Minute + if !isOfficialEndpoint { + timeout = 30 * time.Minute // Default for self-hosted endpoints + } + if t := a.GetIntParameter(params, "timeout_minutes"); t > 0 { + timeout = time.Duration(t) * time.Minute } + client := &http.Client{Timeout: timeout} resp, err := client.Do(req) if err != nil { writeLog("Error: Request failed: %v", err) @@ -247,8 +273,8 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) - writeLog("Error: OpenAI API error (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(respBody)) + writeLog("Error: transcription API error (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("transcription API error (status %d): %s", resp.StatusCode, string(respBody)) } writeLog("Response received. Parsing...") diff --git a/internal/transcription/adapters/sortformer_adapter.go b/internal/transcription/adapters/sortformer_adapter.go index f7b133c0b..4d8f5b1fd 100644 --- a/internal/transcription/adapters/sortformer_adapter.go +++ b/internal/transcription/adapters/sortformer_adapter.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "sort" "strconv" "strings" "time" @@ -64,6 +65,16 @@ func NewSortformerAdapter(envPath string) *SortformerAdapter { Description: "Maximum number of speakers (optimized for 4)", Group: "basic", }, + { + Name: "min_speakers", + Type: "int", + Required: false, + Default: 1, + Min: &[]float64{1}[0], + Max: &[]float64{8}[0], + Description: "Minimum number of speakers", + Group: "basic", + }, { Name: "batch_size", Type: "int", @@ -423,10 +434,22 @@ func (s *SortformerAdapter) buildSortformerArgs(input interfaces.AudioInput, par func (s *SortformerAdapter) parseResult(tempDir string, input interfaces.AudioInput, params map[string]interface{}) (*interfaces.DiarizationResult, error) { outputFormat := s.GetStringParameter(params, "output_format") + var ( + result *interfaces.DiarizationResult + err error + ) + if outputFormat == OutputFormatJSON { - return s.parseJSONResult(tempDir) + result, err = s.parseJSONResult(tempDir) + } else { + result, err = s.parseRTTMResult(tempDir, input) + } + + if err != nil { + return nil, err } - return s.parseRTTMResult(tempDir, input) + + return s.enforceSpeakerLimit(result, params), nil } // parseJSONResult parses JSON format output @@ -538,6 +561,126 @@ func (s *SortformerAdapter) parseRTTMResult(tempDir string, input interfaces.Aud return result, nil } +type sortformerSpeakerDuration struct { + speaker string + duration float64 +} + +func (s *SortformerAdapter) enforceSpeakerLimit(result *interfaces.DiarizationResult, params map[string]interface{}) *interfaces.DiarizationResult { + maxSpeakers := s.GetIntParameter(params, "max_speakers") + if result == nil || maxSpeakers <= 0 { + return result + } + + speakerDurations := make(map[string]float64) + for _, segment := range result.Segments { + duration := segment.End - segment.Start + if duration < 0 { + duration = 0 + } + speakerDurations[segment.Speaker] += duration + } + + originalSpeakerCount := len(speakerDurations) + if originalSpeakerCount <= maxSpeakers { + result.SpeakerCount = originalSpeakerCount + result.Speakers = sortedSpeakerList(speakerDurations) + return result + } + + rankedSpeakers := make([]sortformerSpeakerDuration, 0, len(speakerDurations)) + for speaker, duration := range speakerDurations { + rankedSpeakers = append(rankedSpeakers, sortformerSpeakerDuration{ + speaker: speaker, + duration: duration, + }) + } + + sort.Slice(rankedSpeakers, func(i, j int) bool { + if rankedSpeakers[i].duration == rankedSpeakers[j].duration { + return rankedSpeakers[i].speaker < rankedSpeakers[j].speaker + } + return rankedSpeakers[i].duration > rankedSpeakers[j].duration + }) + + keptSpeakers := make(map[string]bool, maxSpeakers) + fallbackSpeaker := rankedSpeakers[0].speaker + for i := 0; i < maxSpeakers && i < len(rankedSpeakers); i++ { + keptSpeakers[rankedSpeakers[i].speaker] = true + } + + originalSegments := append([]interfaces.DiarizationSegment(nil), result.Segments...) + for i := range result.Segments { + if keptSpeakers[result.Segments[i].Speaker] { + continue + } + + result.Segments[i].Speaker = nearestKeptSpeaker(originalSegments, i, keptSpeakers, fallbackSpeaker) + } + + rebuildSpeakerSummary(result) + + logger.Warn("Sortformer returned more speakers than requested; remapped extra speaker labels", + "requested_max_speakers", maxSpeakers, + "original_speakers", originalSpeakerCount, + "final_speakers", result.SpeakerCount) + + return result +} + +func nearestKeptSpeaker(segments []interfaces.DiarizationSegment, targetIndex int, keptSpeakers map[string]bool, fallbackSpeaker string) string { + target := segments[targetIndex] + bestSpeaker := fallbackSpeaker + bestDistance := -1.0 + + for i, segment := range segments { + if i == targetIndex || !keptSpeakers[segment.Speaker] { + continue + } + + distance := segmentDistance(target, segment) + if bestDistance < 0 || distance < bestDistance || (distance == bestDistance && segment.Speaker < bestSpeaker) { + bestSpeaker = segment.Speaker + bestDistance = distance + } + } + + return bestSpeaker +} + +func segmentDistance(a, b interfaces.DiarizationSegment) float64 { + if b.End <= a.Start { + return a.Start - b.End + } + if a.End <= b.Start { + return b.Start - a.End + } + return 0 +} + +func rebuildSpeakerSummary(result *interfaces.DiarizationResult) { + speakers := make(map[string]float64) + for _, segment := range result.Segments { + duration := segment.End - segment.Start + if duration < 0 { + duration = 0 + } + speakers[segment.Speaker] += duration + } + + result.SpeakerCount = len(speakers) + result.Speakers = sortedSpeakerList(speakers) +} + +func sortedSpeakerList(speakers map[string]float64) []string { + speakerList := make([]string, 0, len(speakers)) + for speaker := range speakers { + speakerList = append(speakerList, speaker) + } + sort.Strings(speakerList) + return speakerList +} + // GetEstimatedProcessingTime provides Sortformer-specific time estimation func (s *SortformerAdapter) GetEstimatedProcessingTime(input interfaces.AudioInput) time.Duration { // Sortformer is typically very fast, often faster than real-time diff --git a/internal/transcription/adapters/sortformer_adapter_test.go b/internal/transcription/adapters/sortformer_adapter_test.go new file mode 100644 index 000000000..708bdd01b --- /dev/null +++ b/internal/transcription/adapters/sortformer_adapter_test.go @@ -0,0 +1,36 @@ +package adapters + +import ( + "testing" + + "scriberr/internal/transcription/interfaces" +) + +func TestSortformerEnforceSpeakerLimitRemapsExtraSpeakers(t *testing.T) { + adapter := NewSortformerAdapter("/tmp/sortformer") + result := &interfaces.DiarizationResult{ + Segments: []interfaces.DiarizationSegment{ + {Start: 0, End: 10, Speaker: "speaker_0", Confidence: 1.0}, + {Start: 12, End: 20, Speaker: "speaker_1", Confidence: 1.0}, + {Start: 21, End: 22, Speaker: "speaker_2", Confidence: 1.0}, + }, + SpeakerCount: 3, + Speakers: []string{"speaker_0", "speaker_1", "speaker_2"}, + } + + capped := adapter.enforceSpeakerLimit(result, map[string]interface{}{ + "max_speakers": 2, + }) + + if capped.SpeakerCount != 2 { + t.Fatalf("expected 2 speakers, got %d", capped.SpeakerCount) + } + if capped.Segments[2].Speaker != "speaker_1" { + t.Fatalf("expected speaker_2 segment to map to nearest retained speaker_1, got %s", capped.Segments[2].Speaker) + } + for _, speaker := range capped.Speakers { + if speaker == "speaker_2" { + t.Fatal("speaker_2 should have been removed from the speaker summary") + } + } +} diff --git a/internal/transcription/adapters_test.go b/internal/transcription/adapters_test.go index ca09ca877..4c3a35c0a 100644 --- a/internal/transcription/adapters_test.go +++ b/internal/transcription/adapters_test.go @@ -553,6 +553,49 @@ func BenchmarkModelRegistryLookup(b *testing.B) { } } +func TestOpenAIAdapter(t *testing.T) { + a := adapters.NewOpenAIAdapter("sk-test") + if a == nil { + t.Fatal("NewOpenAIAdapter returned nil") + } + + caps := a.GetCapabilities() + if caps.ModelID != "openai_whisper" { + t.Errorf("expected model ID 'openai_whisper', got %q", caps.ModelID) + } + if caps.ModelFamily != "openai" { + t.Errorf("expected model family 'openai', got %q", caps.ModelFamily) + } + if !caps.Features["diarization"] { + t.Error("diarization capability must be true") + } + + schema := a.GetParameterSchema() + hasBaseURL := false + for _, p := range schema { + if p.Name == "base_url" { + hasBaseURL = true + } + if p.Name == "model" && len(p.Options) > 0 { + t.Errorf("model parameter must not have a fixed Options list, got %v", p.Options) + } + } + if !hasBaseURL { + t.Error("schema must include base_url parameter") + } +} + +func TestOpenAIAdapterWithBaseURL(t *testing.T) { + a := adapters.NewOpenAIAdapter("") + if a == nil { + t.Fatal("NewOpenAIAdapter returned nil") + } + caps := a.GetCapabilities() + if !caps.Features["diarization"] { + t.Error("diarization capability must be true") + } +} + func BenchmarkParameterValidation(b *testing.B) { reg := registry.GetRegistry() adapter, err := reg.GetTranscriptionAdapter("whisperx") diff --git a/internal/transcription/sortformer_params_test.go b/internal/transcription/sortformer_params_test.go new file mode 100644 index 000000000..adfa2d2e0 --- /dev/null +++ b/internal/transcription/sortformer_params_test.go @@ -0,0 +1,25 @@ +package transcription + +import ( + "testing" + + "scriberr/internal/models" +) + +func TestConvertToSortformerParamsPreservesSpeakerConstraints(t *testing.T) { + minSpeakers := 2 + maxSpeakers := 2 + service := &UnifiedTranscriptionService{} + + params := service.convertToSortformerParams(models.WhisperXParams{ + MinSpeakers: &minSpeakers, + MaxSpeakers: &maxSpeakers, + }) + + if params["min_speakers"] != minSpeakers { + t.Fatalf("expected min_speakers=%d, got %v", minSpeakers, params["min_speakers"]) + } + if params["max_speakers"] != maxSpeakers { + t.Fatalf("expected max_speakers=%d, got %v", maxSpeakers, params["max_speakers"]) + } +} diff --git a/internal/transcription/unified_service.go b/internal/transcription/unified_service.go index e17ef8b4b..bbfe39bf8 100644 --- a/internal/transcription/unified_service.go +++ b/internal/transcription/unified_service.go @@ -287,6 +287,17 @@ func (u *UnifiedTranscriptionService) processSingleTrackJob(ctx context.Context, var transcriptResult *interfaces.TranscriptResult var diarizationResult *interfaces.DiarizationResult + if job.Parameters.DiarizationOnly { + if !job.Parameters.Diarize { + return fmt.Errorf("diarization-only job requires diarization to be enabled") + } + + transcriptResult, err = u.loadExistingTranscriptResult(job) + if err != nil { + return fmt.Errorf("failed to load existing transcript for diarization: %w", err) + } + } + // Perform transcription using the preprocessed audio if transcriptionModelID != "" { logger.Info("Running transcription", "model_id", transcriptionModelID) @@ -376,19 +387,21 @@ func (u *UnifiedTranscriptionService) IsMultiTrackJob(jobID string) bool { // selectModels determines which models to use based on job parameters func (u *UnifiedTranscriptionService) selectModels(params models.WhisperXParams) (transcriptionModelID, diarizationModelID string, err error) { // Determine transcription model - switch params.ModelFamily { - case FamilyNvidiaParakeet: - transcriptionModelID = ModelParakeet - case FamilyNvidiaCanary: - transcriptionModelID = ModelCanary - case FamilyWhisper: - transcriptionModelID = ModelWhisperX - case FamilyOpenAI: - transcriptionModelID = ModelOpenAI - case FamilyMistralVoxtral: - transcriptionModelID = ModelVoxtral - default: - transcriptionModelID = ModelWhisperX // Default fallback + if !params.DiarizationOnly { + switch params.ModelFamily { + case FamilyNvidiaParakeet: + transcriptionModelID = ModelParakeet + case FamilyNvidiaCanary: + transcriptionModelID = ModelCanary + case FamilyWhisper: + transcriptionModelID = ModelWhisperX + case FamilyOpenAI: + transcriptionModelID = ModelOpenAI + case FamilyMistralVoxtral: + transcriptionModelID = ModelVoxtral + default: + transcriptionModelID = ModelWhisperX // Default fallback + } } // Determine diarization model if needed @@ -589,6 +602,12 @@ func (u *UnifiedTranscriptionService) convertToOpenAIParams(params models.Whispe if params.APIKey != nil && *params.APIKey != "" { paramMap["api_key"] = *params.APIKey } + if params.APIURL != nil && *params.APIURL != "" { + paramMap["base_url"] = *params.APIURL + } + if params.TimeoutMinutes != nil && *params.TimeoutMinutes > 0 { + paramMap["timeout_minutes"] = *params.TimeoutMinutes + } return paramMap } @@ -739,11 +758,19 @@ func (u *UnifiedTranscriptionService) convertToPyannoteParams(params models.Whis // convertToSortformerParams converts to Sortformer-specific parameters func (u *UnifiedTranscriptionService) convertToSortformerParams(params models.WhisperXParams) map[string]interface{} { - return map[string]interface{}{ + paramMap := map[string]interface{}{ "output_format": OutputFormatJSON, "auto_convert_audio": true, - // Sortformer is optimized for 4 speakers, no additional config needed } + + if params.MinSpeakers != nil { + paramMap["min_speakers"] = *params.MinSpeakers + } + if params.MaxSpeakers != nil { + paramMap["max_speakers"] = *params.MaxSpeakers + } + + return paramMap } func (u *UnifiedTranscriptionService) parametersToMap(params models.WhisperXParams) map[string]interface{} { @@ -836,6 +863,7 @@ func (u *UnifiedTranscriptionService) mergeDiarizationWithTranscription(transcri // Assign speakers to transcript segments based on timing overlap for i := range mergedTranscript.Segments { segment := &mergedTranscript.Segments[i] + segment.Speaker = nil bestSpeaker := u.findBestSpeakerForSegment(segment.Start, segment.End, diarization.Segments) if bestSpeaker != "" { segment.Speaker = &bestSpeaker @@ -849,6 +877,7 @@ func (u *UnifiedTranscriptionService) mergeDiarizationWithTranscription(transcri for i := range mergedTranscript.WordSegments { word := &mergedTranscript.WordSegments[i] + word.Speaker = nil bestSpeaker := u.findBestSpeakerForSegment(word.Start, word.End, diarization.Segments) if bestSpeaker != "" { word.Speaker = &bestSpeaker @@ -879,6 +908,23 @@ func (u *UnifiedTranscriptionService) findBestSpeakerForSegment(start, end float return bestSpeaker } +func (u *UnifiedTranscriptionService) loadExistingTranscriptResult(job *models.TranscriptionJob) (*interfaces.TranscriptResult, error) { + if job.Transcript == nil || strings.TrimSpace(*job.Transcript) == "" { + return nil, fmt.Errorf("job has no transcript") + } + + var result interfaces.TranscriptResult + if err := json.Unmarshal([]byte(*job.Transcript), &result); err != nil { + return nil, fmt.Errorf("failed to parse transcript JSON: %w", err) + } + + if len(result.Segments) == 0 { + return nil, fmt.Errorf("existing transcript has no timed segments") + } + + return &result, nil +} + // saveTranscriptionResults saves the transcription results to the database func (u *UnifiedTranscriptionService) saveTranscriptionResults(jobID string, result *interfaces.TranscriptResult) error { // Convert result to JSON string for database storage diff --git a/tests/adapter_registration_test.go b/tests/adapter_registration_test.go index 3b9203c72..a43520906 100644 --- a/tests/adapter_registration_test.go +++ b/tests/adapter_registration_test.go @@ -47,6 +47,19 @@ func TestAdapterEnvPathInjection(t *testing.T) { } } +// TestOpenAIAdapterConstruction tests the OpenAI adapter constructor +func TestOpenAIAdapterConstruction(t *testing.T) { + a := adapters.NewOpenAIAdapter("") + if a == nil { + t.Fatal("NewOpenAIAdapter returned nil with empty key") + } + + a = adapters.NewOpenAIAdapter("sk-test") + if !a.GetCapabilities().Features["diarization"] { + t.Error("diarization capability must be true") + } +} + // TestRegisterAdapters tests that registerAdapters correctly registers all adapters func TestRegisterAdapters(t *testing.T) { // Clear registry before test @@ -67,6 +80,8 @@ func TestRegisterAdapters(t *testing.T) { adapters.NewParakeetAdapter(nvidiaEnvPath)) registry.RegisterTranscriptionAdapter("canary", adapters.NewCanaryAdapter(nvidiaEnvPath)) + registry.RegisterTranscriptionAdapter("openai_whisper", + adapters.NewOpenAIAdapter("")) registry.RegisterDiarizationAdapter("pyannote", adapters.NewPyAnnoteAdapter(nvidiaEnvPath)) @@ -75,8 +90,8 @@ func TestRegisterAdapters(t *testing.T) { // Verify registrations transcriptionAdapters := registry.GetTranscriptionAdapters() - if len(transcriptionAdapters) != 3 { - t.Errorf("Expected 3 transcription adapters, got %d", len(transcriptionAdapters)) + if len(transcriptionAdapters) != 4 { + t.Errorf("Expected 4 transcription adapters, got %d", len(transcriptionAdapters)) } // Check specific adapters are registered @@ -89,6 +104,9 @@ func TestRegisterAdapters(t *testing.T) { if _, exists := transcriptionAdapters["canary"]; !exists { t.Error("canary adapter not registered") } + if _, exists := transcriptionAdapters["openai_whisper"]; !exists { + t.Error("openai_whisper adapter not registered") + } diarizationAdapters := registry.GetDiarizationAdapters() if len(diarizationAdapters) != 2 { diff --git a/web/frontend/src/components/TranscribeDDialog.tsx b/web/frontend/src/components/TranscribeDDialog.tsx index e88482ca6..739dc3623 100644 --- a/web/frontend/src/components/TranscribeDDialog.tsx +++ b/web/frontend/src/components/TranscribeDDialog.tsx @@ -15,6 +15,7 @@ import { SelectValue, } from "@/components/ui/select"; import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { Loader2 } from "lucide-react"; import type { WhisperXParams } from "./TranscriptionConfigDialog"; @@ -36,20 +37,43 @@ interface TranscribeDDialogProps { onStartTranscription: (params: WhisperXParams, profileId?: string) => void; loading?: boolean; title?: string; + description?: string; + submitLabel?: string; + loadingLabel?: string; + forceDiarization?: boolean; } +type DiarizationModel = "pyannote" | "nvidia_sortformer"; + export function TranscribeDDialog({ open, onOpenChange, onStartTranscription, loading = false, title, + description, + submitLabel = "Start Transcription", + loadingLabel = "Starting...", + forceDiarization = false, }: TranscribeDDialogProps) { const { getAuthHeaders } = useAuth(); const [profiles, setProfiles] = useState([]); const [selectedProfileId, setSelectedProfileId] = useState(""); const [profilesLoading, setProfilesLoading] = useState(false); const [defaultProfile, setDefaultProfile] = useState(null); + const [minSpeakersInput, setMinSpeakersInput] = useState(""); + const [maxSpeakersInput, setMaxSpeakersInput] = useState(""); + const [diarizationModel, setDiarizationModel] = useState("pyannote"); + + const selectedProfile = profiles.find(p => p.id === selectedProfileId); + const showDiarizationSettings = forceDiarization || (selectedProfile?.parameters.diarize ?? false); + const minSpeakers = parseSpeakerLimit(minSpeakersInput); + const maxSpeakers = parseSpeakerLimit(maxSpeakersInput); + const hasInvalidSpeakerLimits = showDiarizationSettings && ( + (minSpeakersInput.trim() !== "" && minSpeakers === undefined) || + (maxSpeakersInput.trim() !== "" && maxSpeakers === undefined) + ); + const hasSpeakerRangeError = showDiarizationSettings && minSpeakers !== undefined && maxSpeakers !== undefined && minSpeakers > maxSpeakers; const fetchProfiles = useCallback(async () => { try { @@ -101,13 +125,31 @@ export function TranscribeDDialog({ } }, [open, fetchProfiles]); + useEffect(() => { + const profile = profiles.find(p => p.id === selectedProfileId); + + setMinSpeakersInput(formatSpeakerLimit(profile?.parameters.min_speakers)); + setMaxSpeakersInput(formatSpeakerLimit(profile?.parameters.max_speakers)); + setDiarizationModel(normalizeDiarizationModel(profile?.parameters.diarize_model)); + }, [profiles, selectedProfileId]); + const handleStartTranscription = () => { - if (!selectedProfileId) return; + if (!selectedProfile || hasInvalidSpeakerLimits || hasSpeakerRangeError) return; - const selectedProfile = profiles.find(p => p.id === selectedProfileId); - if (selectedProfile) { - onStartTranscription(selectedProfile.parameters, selectedProfile.id); + const params: WhisperXParams = { ...selectedProfile.parameters }; + if (showDiarizationSettings) { + params.diarize = true; + params.diarize_model = diarizationModel; + params.min_speakers = minSpeakers; + params.max_speakers = maxSpeakers; + + if (diarizationModel === "nvidia_sortformer") { + onStartTranscription(omitPyannoteOnlyParams(params), selectedProfile.id); + return; + } } + + onStartTranscription(params, selectedProfile.id); }; const handleProfileChange = (value: string) => { @@ -124,7 +166,7 @@ export function TranscribeDDialog({ {title || "Transcribe with Profile"} - Choose a saved profile to start transcription with your preferred settings. + {description || "Choose a saved profile to start transcription with your preferred settings."} @@ -183,6 +225,83 @@ export function TranscribeDDialog({ )} + {showDiarizationSettings && ( +
+
+ + +
+ +
+
+ + setMinSpeakersInput(event.target.value)} + aria-invalid={hasInvalidSpeakerLimits || hasSpeakerRangeError} + className="h-11 rounded-[var(--radius-btn)] bg-[var(--bg-main)] border border-[var(--border-subtle)] text-[var(--text-primary)] focus-visible:ring-[var(--brand-light)] focus-visible:border-[var(--brand-solid)] shadow-none" + /> +
+
+ + setMaxSpeakersInput(event.target.value)} + aria-invalid={hasInvalidSpeakerLimits || hasSpeakerRangeError} + className="h-11 rounded-[var(--radius-btn)] bg-[var(--bg-main)] border border-[var(--border-subtle)] text-[var(--text-primary)] focus-visible:ring-[var(--brand-light)] focus-visible:border-[var(--brand-solid)] shadow-none" + /> +
+
+ {hasInvalidSpeakerLimits && ( +

Use whole numbers from 1 to 20.

+ )} + {!hasInvalidSpeakerLimits && hasSpeakerRangeError && ( +

Min speakers cannot exceed max speakers.

+ )} +
+ )} @@ -196,16 +315,16 @@ export function TranscribeDDialog({ @@ -213,3 +332,33 @@ export function TranscribeDDialog({ ); } + +function formatSpeakerLimit(value?: number): string { + return value ? String(value) : ""; +} + +function parseSpeakerLimit(value: string): number | undefined { + const trimmedValue = value.trim(); + if (!trimmedValue) return undefined; + + const parsedValue = Number(trimmedValue); + if (!Number.isInteger(parsedValue) || parsedValue < 1 || parsedValue > 20) return undefined; + + return parsedValue; +} + +function normalizeDiarizationModel(value?: string): DiarizationModel { + return value === "nvidia_sortformer" ? "nvidia_sortformer" : "pyannote"; +} + +function omitPyannoteOnlyParams(params: WhisperXParams): WhisperXParams { + const sortformerParams = { + ...params, + } as Omit & Partial>; + + delete sortformerParams.hf_token; + delete sortformerParams.vad_onset; + delete sortformerParams.vad_offset; + + return sortformerParams as WhisperXParams; +} diff --git a/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx b/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx index 4f46aa333..cb42f7e41 100644 --- a/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx +++ b/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx @@ -44,6 +44,7 @@ export interface WhisperXParams { vad_offset: number; chunk_size: number; diarize: boolean; + diarization_only?: boolean; min_speakers?: number; max_speakers?: number; diarize_model: string; @@ -72,6 +73,8 @@ export interface WhisperXParams { attention_context_right: number; is_multi_track_enabled: boolean; api_key?: string; + api_url?: string; + timeout_minutes?: number; max_new_tokens?: number; } @@ -394,9 +397,13 @@ export const TranscriptionConfigDialog = memo(function TranscriptionConfigDialog )} {params.model_family === "openai" && ( )} @@ -632,14 +639,30 @@ interface OpenAIConfigProps extends ConfigProps { } function OpenAIConfig({ - params, updateParam, - isValidating, validationStatus, validationMessage, availableModels, onValidate + params, + updateParam, + isMultiTrack, + isValidating, + validationStatus, + validationMessage, + availableModels, + onValidate }: OpenAIConfigProps) { return (
- + + updateParam('api_url', e.target.value || undefined)} + className={inputClassName} + /> + + +
updateParam('api_key', e.target.value)} className={`${inputClassName} flex-1`} /> - + {!params.api_url && ( + + )}
- {validationStatus !== 'idle' && ( + {validationStatus !== 'idle' && !params.api_url && (
{validationStatus === 'valid' ? : } {validationMessage} @@ -662,12 +687,42 @@ function OpenAIConfig({ )} - updateParam('model', v)} options={availableModels} /> + {params.api_url ? ( + + updateParam('model', e.target.value)} + className={inputClassName} + /> + + ) : ( + updateParam('model', v)} options={availableModels} /> + )} + updateParam('language', v === "auto" ? undefined : v)} options={LANGUAGES} /> + + {params.api_url && ( + + updateParam('timeout_minutes', e.target.value ? parseInt(e.target.value) : undefined)} + className={inputClassName} + /> + + )}
- {params.model && params.model !== "whisper-1" && ( + {!isMultiTrack && ( + + )} + + {params.model && params.model !== "whisper-1" && !params.api_url && ( Word-level timestamps are only supported by whisper-1. Synchronized playback won't be available. diff --git a/web/frontend/src/components/ui/swipeable-item.tsx b/web/frontend/src/components/ui/swipeable-item.tsx index 37942669f..ace6a862c 100644 --- a/web/frontend/src/components/ui/swipeable-item.tsx +++ b/web/frontend/src/components/ui/swipeable-item.tsx @@ -1,6 +1,6 @@ import { useEffect, useRef, useState, type ReactNode } from "react"; import { motion, useAnimation, type PanInfo } from "framer-motion"; -import { Trash2, Wand2, StopCircle } from "lucide-react"; +import { Trash2, Wand2, StopCircle, Users } from "lucide-react"; import { WandAdvancedIcon } from "@/components/icons/WandAdvancedIcon"; import { useIsMobile } from "@/hooks/use-mobile"; @@ -8,8 +8,10 @@ interface SwipeableItemProps { children: ReactNode; onTranscribe: () => void; onTranscribeAdvanced: () => void; + onDiarize?: () => void; onDelete: () => void; onStop?: () => void; + canDiarize?: boolean; isProcessing?: boolean; isSelectionMode?: boolean; shouldShowHint?: boolean; @@ -33,8 +35,10 @@ export function SwipeableItem({ children, onTranscribe, onTranscribeAdvanced, + onDiarize, onDelete, onStop, + canDiarize = false, isProcessing = false, isSelectionMode = false, shouldShowHint = false, @@ -51,8 +55,9 @@ export function SwipeableItem({ const hasMovedRef = useRef(false); const suppressClickUntilRef = useRef(0); - // Width of the action buttons area (3 buttons × 44px + gaps + padding) - const OPEN_WIDTH = -160; + const actionCount = canDiarize && onDiarize && !isProcessing ? 4 : 3; + const actionAreaWidth = actionCount * 44 + (actionCount - 1) * 8 + 16; + const OPEN_WIDTH = -actionAreaWidth; const handleDragStart = (_event: MouseEvent | TouchEvent | PointerEvent, info: PanInfo) => { isDraggingRef.current = true; @@ -168,7 +173,10 @@ export function SwipeableItem({ data-has-moved={hasMoved()} > {/* --- LAYER 0: The Action Buttons (Hidden underneath, mobile only) --- */} -
+
{/* Transcribe (Primary) */} + {canDiarize && onDiarize && !isProcessing && ( + + )} + {/* Delete or Stop (Destructive - furthest right) */} {isProcessing && onStop ? ( + + Diarize + + )} )} @@ -1023,6 +1101,17 @@ export const AudioFilesTable = memo(function AudioFilesTable({ onStartTranscription={onStartTranscribeWithProfile} loading={transcriptionLoading} /> + {/* Stop Transcription Dialog */} diff --git a/web/project-site/public/api/swagger.json b/web/project-site/public/api/swagger.json index f5ddfdbb3..eb430c1b7 100644 --- a/web/project-site/public/api/swagger.json +++ b/web/project-site/public/api/swagger.json @@ -4778,6 +4778,10 @@ "description": "OpenAI settings", "type": "string" }, + "api_url": { + "description": "Custom transcription API base URL (OpenAI adapter only)", + "type": "string" + }, "attention_context_left": { "description": "NVIDIA Parakeet-specific parameters for long-form audio", "type": "integer" @@ -4930,6 +4934,10 @@ "threads": { "type": "integer" }, + "timeout_minutes": { + "description": "HTTP request timeout in minutes (OpenAI adapter with custom base URL)", + "type": "integer" + }, "vad_method": { "description": "VAD (Voice Activity Detection) settings", "type": "string"