From 1b8b782fec08e33fc524cbaac58cf4c42ed4356d Mon Sep 17 00:00:00 2001 From: Arun Babu Neelicattu Date: Sun, 29 Mar 2026 04:33:22 +0200 Subject: [PATCH 1/5] feat: extend OpenAI adapter for self-hosted Whisper endpoints The OpenAI transcription profile now works with any OpenAI-compatible Whisper server (faster-whisper, whisper.cpp, etc.), not just api.openai.com. Users can set a custom base URL per job in the config dialog. When a custom URL is provided: - The model field becomes a free-text input so any server-specific model name can be used - API key validation is hidden (most self-hosted servers don't require auth) - A configurable timeout is exposed, defaulting to 30 minutes to accommodate slower hardware - Speaker diarization can be enabled, handled as a post-processing step the same way it works for local models The official OpenAI endpoint behaviour is unchanged, including the special response format handling for gpt-4o audio models. --- api-docs/swagger.json | 8 ++ api-docs/swagger.yaml | 6 ++ internal/models/transcription.go | 4 +- internal/transcription/README.md | 12 ++- .../transcription/adapters/openai_adapter.go | 50 ++++++++--- internal/transcription/adapters_test.go | 43 ++++++++++ internal/transcription/unified_service.go | 6 ++ tests/adapter_registration_test.go | 22 ++++- .../TranscriptionConfigDialog.tsx | 84 +++++++++++++++---- web/project-site/public/api/swagger.json | 8 ++ 10 files changed, 212 insertions(+), 31 deletions(-) diff --git a/api-docs/swagger.json b/api-docs/swagger.json index f5ddfdbb3..eb430c1b7 100644 --- a/api-docs/swagger.json +++ b/api-docs/swagger.json @@ -4778,6 +4778,10 @@ "description": "OpenAI settings", "type": "string" }, + "api_url": { + "description": "Custom transcription API base URL (OpenAI adapter only)", + "type": "string" + }, "attention_context_left": { "description": "NVIDIA Parakeet-specific parameters for long-form audio", "type": "integer" @@ -4930,6 +4934,10 @@ "threads": { "type": "integer" }, + "timeout_minutes": { + "description": "HTTP request timeout in minutes (OpenAI adapter with custom base URL)", + "type": "integer" + }, "vad_method": { "description": "VAD (Voice Activity Detection) settings", "type": "string" diff --git a/api-docs/swagger.yaml b/api-docs/swagger.yaml index 4b43a4ca1..39b1a8ecc 100644 --- a/api-docs/swagger.yaml +++ b/api-docs/swagger.yaml @@ -641,6 +641,9 @@ definitions: api_key: description: OpenAI settings type: string + api_url: + description: Custom transcription API base URL (OpenAI adapter only) + type: string attention_context_left: description: NVIDIA Parakeet-specific parameters for long-form audio type: integer @@ -747,6 +750,9 @@ definitions: type: number threads: type: integer + timeout_minutes: + description: HTTP request timeout in minutes (OpenAI adapter with custom base URL) + type: integer vad_method: description: VAD (Voice Activity Detection) settings type: string diff --git a/internal/models/transcription.go b/internal/models/transcription.go index 632c6c3b7..04285b08e 100644 --- a/internal/models/transcription.go +++ b/internal/models/transcription.go @@ -127,7 +127,9 @@ type WhisperXParams struct { CallbackURL *string `json:"callback_url,omitempty" gorm:"type:text"` // OpenAI settings - APIKey *string `json:"api_key,omitempty" gorm:"type:text"` + APIKey *string `json:"api_key,omitempty" gorm:"type:text"` + APIURL *string `json:"api_url,omitempty" gorm:"type:text"` + TimeoutMinutes *int `json:"timeout_minutes,omitempty" gorm:"type:int"` // Voxtral settings MaxNewTokens *int `json:"max_new_tokens,omitempty" gorm:"type:int"` diff --git a/internal/transcription/README.md b/internal/transcription/README.md index e915bc151..30c5b5de3 100644 --- a/internal/transcription/README.md +++ b/internal/transcription/README.md @@ -155,6 +155,7 @@ err := adapter.ValidateParameters(params) | `whisperx` | `whisper` | 90+ languages | Timestamps, Diarization, Translation | | `parakeet` | `nvidia_parakeet` | English only | Timestamps, Long-form, High Quality | | `canary` | `nvidia_canary` | 12 languages | Timestamps, Translation, Multilingual | +| `openai_whisper` | `openai` | 57 languages | Timestamps, Diarization, Translation, Custom Endpoint | ### Diarization Models @@ -221,9 +222,18 @@ params := map[string]interface{}{ // NVIDIA Canary with translation params := map[string]interface{}{ "source_lang": "es", - "target_lang": "en", + "target_lang": "en", "task": "translate", } + +// OpenAI with custom self-hosted endpoint +params := map[string]interface{}{ + "base_url": "http://localhost:8000/v1", + "model": "Systran/faster-whisper-large-v3", + "timeout_minutes": 30, + "diarize": true, + "diarize_model": "pyannote", +} ``` ## Testing diff --git a/internal/transcription/adapters/openai_adapter.go b/internal/transcription/adapters/openai_adapter.go index a9b008efd..4f2824e01 100644 --- a/internal/transcription/adapters/openai_adapter.go +++ b/internal/transcription/adapters/openai_adapter.go @@ -40,7 +40,7 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter { Features: map[string]bool{ "timestamps": true, // Verbose JSON response includes segments "word_level": false, // Not supported by standard API yet (unless using verbose_json with timestamp_granularities which is beta) - "diarization": false, // Not supported by OpenAI API + "diarization": true, // Post-processing via pyannote/sortformer pipeline "translation": true, "language_detection": true, "vad": true, // Implicit @@ -59,13 +59,19 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter { Description: "OpenAI API Key (overrides system default)", Group: "authentication", }, + { + Name: "base_url", + Type: "string", + Required: false, + Description: "Custom transcription API base URL (overrides server default)", + Group: "authentication", + }, { Name: "model", Type: "string", Required: false, Default: "whisper-1", - Options: []string{"whisper-1"}, - Description: "ID of the model to use", + Description: "Model name (e.g. whisper-1, or any model exposed by a custom endpoint)", Group: "basic", }, { @@ -92,6 +98,15 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter { Description: "Sampling temperature", Group: "quality", }, + { + Name: "timeout_minutes", + Type: "int", + Required: false, + Default: 10, + Min: &[]float64{1}[0], + Description: "HTTP request timeout in minutes (increase for large files on self-hosted endpoints)", + Group: "advanced", + }, } baseAdapter := NewBaseAdapter("openai_whisper", "", capabilities, schema) @@ -153,7 +168,14 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn apiKey = key } - if apiKey == "" { + const officialURL = "https://api.openai.com/v1/audio/transcriptions" + endpointURL := officialURL + if url := a.GetStringParameter(params, "base_url"); url != "" { + endpointURL = strings.TrimRight(url, "/") + "/audio/transcriptions" + } + isOfficialEndpoint := endpointURL == officialURL + + if apiKey == "" && isOfficialEndpoint { writeLog("Error: OpenAI API key is required but not provided") return nil, fmt.Errorf("OpenAI API key is required but not provided") } @@ -188,7 +210,7 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn writeLog("Model: %s", model) _ = writer.WriteField("model", model) - if strings.HasPrefix(model, "gpt-4o") { + if isOfficialEndpoint && strings.HasPrefix(model, "gpt-4o") { if strings.Contains(model, "diarize") { _ = writer.WriteField("response_format", "diarized_json") } else { @@ -197,7 +219,6 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn // gpt-4o models don't support timestamp_granularities with these formats } else { _ = writer.WriteField("response_format", "verbose_json") - // timestamp_granularities is only supported for whisper-1 if model == "whisper-1" { _ = writer.WriteField("timestamp_granularities[]", "word") // Request word timestamps _ = writer.WriteField("timestamp_granularities[]", "segment") // Request segment timestamps @@ -224,8 +245,8 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn } // Create request - writeLog("Sending request to OpenAI API...") - req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/audio/transcriptions", body) + writeLog("Sending request to %s...", endpointURL) + req, err := http.NewRequestWithContext(ctx, "POST", endpointURL, body) if err != nil { writeLog("Error: Failed to create request: %v", err) return nil, fmt.Errorf("failed to create request: %w", err) @@ -235,9 +256,14 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn req.Header.Set("Authorization", "Bearer "+apiKey) // Execute request - client := &http.Client{ - Timeout: 10 * time.Minute, // Generous timeout for large files + timeout := 10 * time.Minute + if !isOfficialEndpoint { + timeout = 30 * time.Minute // Default for self-hosted endpoints + } + if t := a.GetIntParameter(params, "timeout_minutes"); t > 0 { + timeout = time.Duration(t) * time.Minute } + client := &http.Client{Timeout: timeout} resp, err := client.Do(req) if err != nil { writeLog("Error: Request failed: %v", err) @@ -247,8 +273,8 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) - writeLog("Error: OpenAI API error (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(respBody)) + writeLog("Error: transcription API error (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("transcription API error (status %d): %s", resp.StatusCode, string(respBody)) } writeLog("Response received. Parsing...") diff --git a/internal/transcription/adapters_test.go b/internal/transcription/adapters_test.go index ca09ca877..4c3a35c0a 100644 --- a/internal/transcription/adapters_test.go +++ b/internal/transcription/adapters_test.go @@ -553,6 +553,49 @@ func BenchmarkModelRegistryLookup(b *testing.B) { } } +func TestOpenAIAdapter(t *testing.T) { + a := adapters.NewOpenAIAdapter("sk-test") + if a == nil { + t.Fatal("NewOpenAIAdapter returned nil") + } + + caps := a.GetCapabilities() + if caps.ModelID != "openai_whisper" { + t.Errorf("expected model ID 'openai_whisper', got %q", caps.ModelID) + } + if caps.ModelFamily != "openai" { + t.Errorf("expected model family 'openai', got %q", caps.ModelFamily) + } + if !caps.Features["diarization"] { + t.Error("diarization capability must be true") + } + + schema := a.GetParameterSchema() + hasBaseURL := false + for _, p := range schema { + if p.Name == "base_url" { + hasBaseURL = true + } + if p.Name == "model" && len(p.Options) > 0 { + t.Errorf("model parameter must not have a fixed Options list, got %v", p.Options) + } + } + if !hasBaseURL { + t.Error("schema must include base_url parameter") + } +} + +func TestOpenAIAdapterWithBaseURL(t *testing.T) { + a := adapters.NewOpenAIAdapter("") + if a == nil { + t.Fatal("NewOpenAIAdapter returned nil") + } + caps := a.GetCapabilities() + if !caps.Features["diarization"] { + t.Error("diarization capability must be true") + } +} + func BenchmarkParameterValidation(b *testing.B) { reg := registry.GetRegistry() adapter, err := reg.GetTranscriptionAdapter("whisperx") diff --git a/internal/transcription/unified_service.go b/internal/transcription/unified_service.go index e17ef8b4b..f38d33062 100644 --- a/internal/transcription/unified_service.go +++ b/internal/transcription/unified_service.go @@ -589,6 +589,12 @@ func (u *UnifiedTranscriptionService) convertToOpenAIParams(params models.Whispe if params.APIKey != nil && *params.APIKey != "" { paramMap["api_key"] = *params.APIKey } + if params.APIURL != nil && *params.APIURL != "" { + paramMap["base_url"] = *params.APIURL + } + if params.TimeoutMinutes != nil && *params.TimeoutMinutes > 0 { + paramMap["timeout_minutes"] = *params.TimeoutMinutes + } return paramMap } diff --git a/tests/adapter_registration_test.go b/tests/adapter_registration_test.go index 3b9203c72..a43520906 100644 --- a/tests/adapter_registration_test.go +++ b/tests/adapter_registration_test.go @@ -47,6 +47,19 @@ func TestAdapterEnvPathInjection(t *testing.T) { } } +// TestOpenAIAdapterConstruction tests the OpenAI adapter constructor +func TestOpenAIAdapterConstruction(t *testing.T) { + a := adapters.NewOpenAIAdapter("") + if a == nil { + t.Fatal("NewOpenAIAdapter returned nil with empty key") + } + + a = adapters.NewOpenAIAdapter("sk-test") + if !a.GetCapabilities().Features["diarization"] { + t.Error("diarization capability must be true") + } +} + // TestRegisterAdapters tests that registerAdapters correctly registers all adapters func TestRegisterAdapters(t *testing.T) { // Clear registry before test @@ -67,6 +80,8 @@ func TestRegisterAdapters(t *testing.T) { adapters.NewParakeetAdapter(nvidiaEnvPath)) registry.RegisterTranscriptionAdapter("canary", adapters.NewCanaryAdapter(nvidiaEnvPath)) + registry.RegisterTranscriptionAdapter("openai_whisper", + adapters.NewOpenAIAdapter("")) registry.RegisterDiarizationAdapter("pyannote", adapters.NewPyAnnoteAdapter(nvidiaEnvPath)) @@ -75,8 +90,8 @@ func TestRegisterAdapters(t *testing.T) { // Verify registrations transcriptionAdapters := registry.GetTranscriptionAdapters() - if len(transcriptionAdapters) != 3 { - t.Errorf("Expected 3 transcription adapters, got %d", len(transcriptionAdapters)) + if len(transcriptionAdapters) != 4 { + t.Errorf("Expected 4 transcription adapters, got %d", len(transcriptionAdapters)) } // Check specific adapters are registered @@ -89,6 +104,9 @@ func TestRegisterAdapters(t *testing.T) { if _, exists := transcriptionAdapters["canary"]; !exists { t.Error("canary adapter not registered") } + if _, exists := transcriptionAdapters["openai_whisper"]; !exists { + t.Error("openai_whisper adapter not registered") + } diarizationAdapters := registry.GetDiarizationAdapters() if len(diarizationAdapters) != 2 { diff --git a/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx b/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx index 4f46aa333..728d5458f 100644 --- a/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx +++ b/web/frontend/src/components/transcription/TranscriptionConfigDialog.tsx @@ -72,6 +72,8 @@ export interface WhisperXParams { attention_context_right: number; is_multi_track_enabled: boolean; api_key?: string; + api_url?: string; + timeout_minutes?: number; max_new_tokens?: number; } @@ -394,9 +396,13 @@ export const TranscriptionConfigDialog = memo(function TranscriptionConfigDialog )} {params.model_family === "openai" && ( )} @@ -632,14 +638,30 @@ interface OpenAIConfigProps extends ConfigProps { } function OpenAIConfig({ - params, updateParam, - isValidating, validationStatus, validationMessage, availableModels, onValidate + params, + updateParam, + isMultiTrack, + isValidating, + validationStatus, + validationMessage, + availableModels, + onValidate }: OpenAIConfigProps) { return (
- + + updateParam('api_url', e.target.value || undefined)} + className={inputClassName} + /> + + +
updateParam('api_key', e.target.value)} className={`${inputClassName} flex-1`} /> - + {!params.api_url && ( + + )}
- {validationStatus !== 'idle' && ( + {validationStatus !== 'idle' && !params.api_url && (
{validationStatus === 'valid' ? : } {validationMessage} @@ -662,12 +686,42 @@ function OpenAIConfig({ )} - updateParam('model', v)} options={availableModels} /> + {params.api_url ? ( + + updateParam('model', e.target.value)} + className={inputClassName} + /> + + ) : ( + updateParam('model', v)} options={availableModels} /> + )} + updateParam('language', v === "auto" ? undefined : v)} options={LANGUAGES} /> + + {params.api_url && ( + + updateParam('timeout_minutes', e.target.value ? parseInt(e.target.value) : undefined)} + className={inputClassName} + /> + + )}
- {params.model && params.model !== "whisper-1" && ( + {!isMultiTrack && ( + + )} + + {params.model && params.model !== "whisper-1" && !params.api_url && ( Word-level timestamps are only supported by whisper-1. Synchronized playback won't be available. diff --git a/web/project-site/public/api/swagger.json b/web/project-site/public/api/swagger.json index f5ddfdbb3..eb430c1b7 100644 --- a/web/project-site/public/api/swagger.json +++ b/web/project-site/public/api/swagger.json @@ -4778,6 +4778,10 @@ "description": "OpenAI settings", "type": "string" }, + "api_url": { + "description": "Custom transcription API base URL (OpenAI adapter only)", + "type": "string" + }, "attention_context_left": { "description": "NVIDIA Parakeet-specific parameters for long-form audio", "type": "integer" @@ -4930,6 +4934,10 @@ "threads": { "type": "integer" }, + "timeout_minutes": { + "description": "HTTP request timeout in minutes (OpenAI adapter with custom base URL)", + "type": "integer" + }, "vad_method": { "description": "VAD (Voice Activity Detection) settings", "type": "string" From 8518efee51ed25d2043457fa8cd6188119161296 Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 4 Jun 2026 12:11:09 -0400 Subject: [PATCH 2/5] fix: honor sortformer speaker limits --- .../adapters/sortformer_adapter.go | 143 +++++++++++++++++- .../adapters/sortformer_adapter_test.go | 36 +++++ .../transcription/sortformer_params_test.go | 25 +++ internal/transcription/unified_service.go | 12 +- 4 files changed, 212 insertions(+), 4 deletions(-) create mode 100644 internal/transcription/adapters/sortformer_adapter_test.go create mode 100644 internal/transcription/sortformer_params_test.go diff --git a/internal/transcription/adapters/sortformer_adapter.go b/internal/transcription/adapters/sortformer_adapter.go index f7b133c0b..47065cb12 100644 --- a/internal/transcription/adapters/sortformer_adapter.go +++ b/internal/transcription/adapters/sortformer_adapter.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "sort" "strconv" "strings" "time" @@ -64,6 +65,16 @@ func NewSortformerAdapter(envPath string) *SortformerAdapter { Description: "Maximum number of speakers (optimized for 4)", Group: "basic", }, + { + Name: "min_speakers", + Type: "int", + Required: false, + Default: 1, + Min: &[]float64{1}[0], + Max: &[]float64{8}[0], + Description: "Minimum number of speakers", + Group: "basic", + }, { Name: "batch_size", Type: "int", @@ -423,10 +434,22 @@ func (s *SortformerAdapter) buildSortformerArgs(input interfaces.AudioInput, par func (s *SortformerAdapter) parseResult(tempDir string, input interfaces.AudioInput, params map[string]interface{}) (*interfaces.DiarizationResult, error) { outputFormat := s.GetStringParameter(params, "output_format") + var ( + result *interfaces.DiarizationResult + err error + ) + if outputFormat == OutputFormatJSON { - return s.parseJSONResult(tempDir) + result, err = s.parseJSONResult(tempDir) + } else { + result, err = s.parseRTTMResult(tempDir, input) + } + + if err != nil { + return nil, err } - return s.parseRTTMResult(tempDir, input) + + return s.enforceSpeakerLimit(result, params), nil } // parseJSONResult parses JSON format output @@ -538,6 +561,122 @@ func (s *SortformerAdapter) parseRTTMResult(tempDir string, input interfaces.Aud return result, nil } +type sortformerSpeakerDuration struct { + speaker string + duration float64 +} + +func (s *SortformerAdapter) enforceSpeakerLimit(result *interfaces.DiarizationResult, params map[string]interface{}) *interfaces.DiarizationResult { + maxSpeakers := s.GetIntParameter(params, "max_speakers") + if result == nil || maxSpeakers <= 0 { + return result + } + + speakerDurations := make(map[string]float64) + for _, segment := range result.Segments { + duration := segment.End - segment.Start + if duration < 0 { + duration = 0 + } + speakerDurations[segment.Speaker] += duration + } + + originalSpeakerCount := len(speakerDurations) + if originalSpeakerCount <= maxSpeakers { + result.SpeakerCount = originalSpeakerCount + result.Speakers = sortedSpeakerList(speakerDurations) + return result + } + + rankedSpeakers := make([]sortformerSpeakerDuration, 0, len(speakerDurations)) + for speaker, duration := range speakerDurations { + rankedSpeakers = append(rankedSpeakers, sortformerSpeakerDuration{ + speaker: speaker, + duration: duration, + }) + } + + sort.Slice(rankedSpeakers, func(i, j int) bool { + if rankedSpeakers[i].duration == rankedSpeakers[j].duration { + return rankedSpeakers[i].speaker < rankedSpeakers[j].speaker + } + return rankedSpeakers[i].duration > rankedSpeakers[j].duration + }) + + keptSpeakers := make(map[string]bool, maxSpeakers) + fallbackSpeaker := rankedSpeakers[0].speaker + for i := 0; i < maxSpeakers && i < len(rankedSpeakers); i++ { + keptSpeakers[rankedSpeakers[i].speaker] = true + } + + originalSegments := append([]interfaces.DiarizationSegment(nil), result.Segments...) + for i := range result.Segments { + if keptSpeakers[result.Segments[i].Speaker] { + continue + } + + result.Segments[i].Speaker = nearestKeptSpeaker(originalSegments, i, keptSpeakers, fallbackSpeaker) + } + + rebuildSpeakerSummary(result) + + logger.Warn("Sortformer returned more speakers than requested; remapped extra speaker labels", + "requested_max_speakers", maxSpeakers, + "original_speakers", originalSpeakerCount, + "final_speakers", result.SpeakerCount) + + return result +} + +func nearestKeptSpeaker(segments []interfaces.DiarizationSegment, targetIndex int, keptSpeakers map[string]bool, fallbackSpeaker string) string { + target := segments[targetIndex] + bestSpeaker := fallbackSpeaker + bestDistance := -1.0 + + for i, segment := range segments { + if i == targetIndex || !keptSpeakers[segment.Speaker] { + continue + } + + distance := segmentDistance(target, segment) + if bestDistance < 0 || distance < bestDistance || (distance == bestDistance && segment.Speaker < bestSpeaker) { + bestSpeaker = segment.Speaker + bestDistance = distance + } + } + + return bestSpeaker +} + +func segmentDistance(a, b interfaces.DiarizationSegment) float64 { + if b.End <= a.Start { + return a.Start - b.End + } + if a.End <= b.Start { + return b.Start - a.End + } + return 0 +} + +func rebuildSpeakerSummary(result *interfaces.DiarizationResult) { + speakers := make(map[string]float64) + for _, segment := range result.Segments { + speakers[segment.Speaker] += segment.End - segment.Start + } + + result.SpeakerCount = len(speakers) + result.Speakers = sortedSpeakerList(speakers) +} + +func sortedSpeakerList(speakers map[string]float64) []string { + speakerList := make([]string, 0, len(speakers)) + for speaker := range speakers { + speakerList = append(speakerList, speaker) + } + sort.Strings(speakerList) + return speakerList +} + // GetEstimatedProcessingTime provides Sortformer-specific time estimation func (s *SortformerAdapter) GetEstimatedProcessingTime(input interfaces.AudioInput) time.Duration { // Sortformer is typically very fast, often faster than real-time diff --git a/internal/transcription/adapters/sortformer_adapter_test.go b/internal/transcription/adapters/sortformer_adapter_test.go new file mode 100644 index 000000000..708bdd01b --- /dev/null +++ b/internal/transcription/adapters/sortformer_adapter_test.go @@ -0,0 +1,36 @@ +package adapters + +import ( + "testing" + + "scriberr/internal/transcription/interfaces" +) + +func TestSortformerEnforceSpeakerLimitRemapsExtraSpeakers(t *testing.T) { + adapter := NewSortformerAdapter("/tmp/sortformer") + result := &interfaces.DiarizationResult{ + Segments: []interfaces.DiarizationSegment{ + {Start: 0, End: 10, Speaker: "speaker_0", Confidence: 1.0}, + {Start: 12, End: 20, Speaker: "speaker_1", Confidence: 1.0}, + {Start: 21, End: 22, Speaker: "speaker_2", Confidence: 1.0}, + }, + SpeakerCount: 3, + Speakers: []string{"speaker_0", "speaker_1", "speaker_2"}, + } + + capped := adapter.enforceSpeakerLimit(result, map[string]interface{}{ + "max_speakers": 2, + }) + + if capped.SpeakerCount != 2 { + t.Fatalf("expected 2 speakers, got %d", capped.SpeakerCount) + } + if capped.Segments[2].Speaker != "speaker_1" { + t.Fatalf("expected speaker_2 segment to map to nearest retained speaker_1, got %s", capped.Segments[2].Speaker) + } + for _, speaker := range capped.Speakers { + if speaker == "speaker_2" { + t.Fatal("speaker_2 should have been removed from the speaker summary") + } + } +} diff --git a/internal/transcription/sortformer_params_test.go b/internal/transcription/sortformer_params_test.go new file mode 100644 index 000000000..adfa2d2e0 --- /dev/null +++ b/internal/transcription/sortformer_params_test.go @@ -0,0 +1,25 @@ +package transcription + +import ( + "testing" + + "scriberr/internal/models" +) + +func TestConvertToSortformerParamsPreservesSpeakerConstraints(t *testing.T) { + minSpeakers := 2 + maxSpeakers := 2 + service := &UnifiedTranscriptionService{} + + params := service.convertToSortformerParams(models.WhisperXParams{ + MinSpeakers: &minSpeakers, + MaxSpeakers: &maxSpeakers, + }) + + if params["min_speakers"] != minSpeakers { + t.Fatalf("expected min_speakers=%d, got %v", minSpeakers, params["min_speakers"]) + } + if params["max_speakers"] != maxSpeakers { + t.Fatalf("expected max_speakers=%d, got %v", maxSpeakers, params["max_speakers"]) + } +} diff --git a/internal/transcription/unified_service.go b/internal/transcription/unified_service.go index f38d33062..1129353a0 100644 --- a/internal/transcription/unified_service.go +++ b/internal/transcription/unified_service.go @@ -745,11 +745,19 @@ func (u *UnifiedTranscriptionService) convertToPyannoteParams(params models.Whis // convertToSortformerParams converts to Sortformer-specific parameters func (u *UnifiedTranscriptionService) convertToSortformerParams(params models.WhisperXParams) map[string]interface{} { - return map[string]interface{}{ + paramMap := map[string]interface{}{ "output_format": OutputFormatJSON, "auto_convert_audio": true, - // Sortformer is optimized for 4 speakers, no additional config needed } + + if params.MinSpeakers != nil { + paramMap["min_speakers"] = *params.MinSpeakers + } + if params.MaxSpeakers != nil { + paramMap["max_speakers"] = *params.MaxSpeakers + } + + return paramMap } func (u *UnifiedTranscriptionService) parametersToMap(params models.WhisperXParams) map[string]interface{} { From 11c179ccae5dca2a6d0a32b55a79e3ecc0c88f67 Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 4 Jun 2026 12:29:17 -0400 Subject: [PATCH 3/5] feat: allow per-job speaker limits from profiles --- .../src/components/TranscribeDDialog.tsx | 93 ++++++++++++++++++- 1 file changed, 88 insertions(+), 5 deletions(-) diff --git a/web/frontend/src/components/TranscribeDDialog.tsx b/web/frontend/src/components/TranscribeDDialog.tsx index e88482ca6..49cc9a75d 100644 --- a/web/frontend/src/components/TranscribeDDialog.tsx +++ b/web/frontend/src/components/TranscribeDDialog.tsx @@ -15,6 +15,7 @@ import { SelectValue, } from "@/components/ui/select"; import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { Loader2 } from "lucide-react"; import type { WhisperXParams } from "./TranscriptionConfigDialog"; @@ -50,6 +51,18 @@ export function TranscribeDDialog({ const [selectedProfileId, setSelectedProfileId] = useState(""); const [profilesLoading, setProfilesLoading] = useState(false); const [defaultProfile, setDefaultProfile] = useState(null); + const [minSpeakersInput, setMinSpeakersInput] = useState(""); + const [maxSpeakersInput, setMaxSpeakersInput] = useState(""); + + const selectedProfile = profiles.find(p => p.id === selectedProfileId); + const showSpeakerLimits = selectedProfile?.parameters.diarize ?? false; + const minSpeakers = parseSpeakerLimit(minSpeakersInput); + const maxSpeakers = parseSpeakerLimit(maxSpeakersInput); + const hasInvalidSpeakerLimits = showSpeakerLimits && ( + (minSpeakersInput.trim() !== "" && minSpeakers === undefined) || + (maxSpeakersInput.trim() !== "" && maxSpeakers === undefined) + ); + const hasSpeakerRangeError = showSpeakerLimits && minSpeakers !== undefined && maxSpeakers !== undefined && minSpeakers > maxSpeakers; const fetchProfiles = useCallback(async () => { try { @@ -101,13 +114,23 @@ export function TranscribeDDialog({ } }, [open, fetchProfiles]); + useEffect(() => { + const profile = profiles.find(p => p.id === selectedProfileId); + + setMinSpeakersInput(formatSpeakerLimit(profile?.parameters.min_speakers)); + setMaxSpeakersInput(formatSpeakerLimit(profile?.parameters.max_speakers)); + }, [profiles, selectedProfileId]); + const handleStartTranscription = () => { - if (!selectedProfileId) return; + if (!selectedProfile || hasInvalidSpeakerLimits || hasSpeakerRangeError) return; - const selectedProfile = profiles.find(p => p.id === selectedProfileId); - if (selectedProfile) { - onStartTranscription(selectedProfile.parameters, selectedProfile.id); + const params: WhisperXParams = { ...selectedProfile.parameters }; + if (showSpeakerLimits) { + params.min_speakers = minSpeakers; + params.max_speakers = maxSpeakers; } + + onStartTranscription(params, selectedProfile.id); }; const handleProfileChange = (value: string) => { @@ -183,6 +206,52 @@ export function TranscribeDDialog({ )}
+ {showSpeakerLimits && ( +
+
+
+ + setMinSpeakersInput(event.target.value)} + aria-invalid={hasInvalidSpeakerLimits || hasSpeakerRangeError} + className="h-11 rounded-[var(--radius-btn)] bg-[var(--bg-main)] border border-[var(--border-subtle)] text-[var(--text-primary)] focus-visible:ring-[var(--brand-light)] focus-visible:border-[var(--brand-solid)] shadow-none" + /> +
+
+ + setMaxSpeakersInput(event.target.value)} + aria-invalid={hasInvalidSpeakerLimits || hasSpeakerRangeError} + className="h-11 rounded-[var(--radius-btn)] bg-[var(--bg-main)] border border-[var(--border-subtle)] text-[var(--text-primary)] focus-visible:ring-[var(--brand-light)] focus-visible:border-[var(--brand-solid)] shadow-none" + /> +
+
+ {hasInvalidSpeakerLimits && ( +

Use whole numbers from 1 to 20.

+ )} + {!hasInvalidSpeakerLimits && hasSpeakerRangeError && ( +

Min speakers cannot exceed max speakers.

+ )} +
+ )} @@ -196,7 +265,7 @@ export function TranscribeDDialog({