Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions api-docs/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -4778,6 +4778,10 @@
"description": "OpenAI settings",
"type": "string"
},
"api_url": {
"description": "Custom transcription API base URL (OpenAI adapter only)",
"type": "string"
},
"attention_context_left": {
"description": "NVIDIA Parakeet-specific parameters for long-form audio",
"type": "integer"
Expand Down Expand Up @@ -4930,6 +4934,10 @@
"threads": {
"type": "integer"
},
"timeout_minutes": {
"description": "HTTP request timeout in minutes (OpenAI adapter with custom base URL)",
"type": "integer"
},
"vad_method": {
"description": "VAD (Voice Activity Detection) settings",
"type": "string"
Expand Down
6 changes: 6 additions & 0 deletions api-docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,9 @@ definitions:
api_key:
description: OpenAI settings
type: string
api_url:
description: Custom transcription API base URL (OpenAI adapter only)
type: string
attention_context_left:
description: NVIDIA Parakeet-specific parameters for long-form audio
type: integer
Expand Down Expand Up @@ -747,6 +750,9 @@ definitions:
type: number
threads:
type: integer
timeout_minutes:
description: HTTP request timeout in minutes (OpenAI adapter with custom base URL)
type: integer
vad_method:
description: VAD (Voice Activity Detection) settings
type: string
Expand Down
87 changes: 87 additions & 0 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,64 @@ func (h *Handler) StartTranscription(c *gin.Context) {
c.JSON(http.StatusOK, job)
}

// @Summary Start diarization for a completed transcription
// @Description Run speaker diarization against an existing transcript without rerunning transcription
// @Tags transcription
// @Accept json
// @Produce json
// @Param id path string true "Job ID"
// @Param parameters body models.WhisperXParams true "Diarization parameters"
// @Success 200 {object} models.TranscriptionJob
// @Failure 400 {object} map[string]string
// @Failure 404 {object} map[string]string
// @Router /api/v1/transcription/{id}/diarize [post]
// @Security ApiKeyAuth
// @Security BearerAuth
func (h *Handler) StartDiarization(c *gin.Context) {
jobID := c.Param("id")

job, err := h.getJobForDiarization(c, jobID)
if err != nil {
return
}

requestParams, err := h.getValidatedTranscriptionParams(c, job, jobID)
if err != nil {
return
}

requestParams.Diarize = true
requestParams.DiarizationOnly = true
requestParams.IsMultiTrackEnabled = false
if requestParams.DiarizeModel == "" {
requestParams.DiarizeModel = transcription.ModelPyannote
}

job.Parameters = *requestParams
job.Diarization = true
job.Status = models.StatusPending
job.ErrorMessage = nil

if err := h.jobRepo.Update(c.Request.Context(), job); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update job"})
return
}

if err := h.taskQueue.EnqueueJob(jobID); err != nil {
logger.Error("Failed to enqueue diarization job", "job_id", jobID, "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to enqueue diarization job"})
return
}

logger.Info("Diarization job started",
"job_id", jobID,
"diarize_model", requestParams.DiarizeModel,
"min_speakers", requestParams.MinSpeakers,
"max_speakers", requestParams.MaxSpeakers)

c.JSON(http.StatusOK, job)
}

func (h *Handler) getJobForTranscription(c *gin.Context, jobID string) (*models.TranscriptionJob, error) {
job, err := h.jobRepo.FindByID(c.Request.Context(), jobID)
if err != nil {
Expand All @@ -1052,6 +1110,35 @@ func (h *Handler) getJobForTranscription(c *gin.Context, jobID string) (*models.
return job, nil
}

func (h *Handler) getJobForDiarization(c *gin.Context, jobID string) (*models.TranscriptionJob, error) {
job, err := h.jobRepo.FindByID(c.Request.Context(), jobID)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{"error": "Job not found"})
return nil, err
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get job"})
return nil, err
}

if job.Status == models.StatusProcessing || job.Status == models.StatusPending {
c.JSON(http.StatusBadRequest, gin.H{"error": "Cannot start diarization: job is currently processing or pending"})
return nil, fmt.Errorf("invalid job status")
}

if job.IsMultiTrack {
c.JSON(http.StatusBadRequest, gin.H{"error": "Standalone diarization is not supported for multi-track recordings"})
return nil, fmt.Errorf("multi-track diarization unsupported")
}

if job.Transcript == nil || strings.TrimSpace(*job.Transcript) == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Cannot start diarization: job does not have an existing transcript"})
return nil, fmt.Errorf("missing transcript")
}

return job, nil
}

func (h *Handler) getValidatedTranscriptionParams(c *gin.Context, job *models.TranscriptionJob, jobID string) (*models.WhisperXParams, error) {
// Set defaults
requestParams := models.WhisperXParams{
Expand Down
1 change: 1 addition & 0 deletions internal/api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ func SetupRoutes(handler *Handler, authService *auth.AuthService) *gin.Engine {
transcription.POST("/youtube", handler.DownloadFromYouTube)
transcription.POST("/submit", handler.SubmitJob)
transcription.POST("/:id/start", handler.StartTranscription)
transcription.POST("/:id/diarize", handler.StartDiarization)
transcription.POST("/:id/kill", handler.KillJob)
transcription.GET("/:id/logs", handler.GetJobLogs)
transcription.GET("/:id/status", handler.GetJobStatus)
Expand Down
11 changes: 7 additions & 4 deletions internal/models/transcription.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type WhisperXParams struct {

// Diarization settings
Diarize bool `json:"diarize" gorm:"type:boolean;default:false"`
DiarizationOnly bool `json:"diarization_only" gorm:"type:boolean;default:false"`
MinSpeakers *int `json:"min_speakers,omitempty" gorm:"type:int"`
MaxSpeakers *int `json:"max_speakers,omitempty" gorm:"type:int"`
DiarizeModel string `json:"diarize_model" gorm:"type:varchar(50);default:'pyannote'"` // Options: 'pyannote', 'nvidia_sortformer'
Expand Down Expand Up @@ -127,7 +128,9 @@ type WhisperXParams struct {
CallbackURL *string `json:"callback_url,omitempty" gorm:"type:text"`

// OpenAI settings
APIKey *string `json:"api_key,omitempty" gorm:"type:text"`
APIKey *string `json:"api_key,omitempty" gorm:"type:text"`
APIURL *string `json:"api_url,omitempty" gorm:"type:text"`
TimeoutMinutes *int `json:"timeout_minutes,omitempty" gorm:"type:int"`

// Voxtral settings
MaxNewTokens *int `json:"max_new_tokens,omitempty" gorm:"type:int"`
Expand Down Expand Up @@ -207,10 +210,10 @@ func (tp *TranscriptionProfile) BeforeSave(tx *gorm.DB) error {
// LLMConfig represents LLM configuration settings
type LLMConfig struct {
ID uint `json:"id" gorm:"primaryKey"`
Provider string `json:"provider" gorm:"not null;type:varchar(50)"` // "ollama" or "openai"
BaseURL *string `json:"base_url,omitempty" gorm:"type:text"` // For Ollama
Provider string `json:"provider" gorm:"not null;type:varchar(50)"` // "ollama" or "openai"
BaseURL *string `json:"base_url,omitempty" gorm:"type:text"` // For Ollama
OpenAIBaseURL *string `json:"openai_base_url,omitempty" gorm:"type:text"` // For OpenAI custom endpoint
APIKey *string `json:"api_key,omitempty" gorm:"type:text"` // For OpenAI (encrypted)
APIKey *string `json:"api_key,omitempty" gorm:"type:text"` // For OpenAI (encrypted)
IsActive bool `json:"is_active" gorm:"type:boolean;default:false"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
Expand Down
12 changes: 11 additions & 1 deletion internal/transcription/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ err := adapter.ValidateParameters(params)
| `whisperx` | `whisper` | 90+ languages | Timestamps, Diarization, Translation |
| `parakeet` | `nvidia_parakeet` | English only | Timestamps, Long-form, High Quality |
| `canary` | `nvidia_canary` | 12 languages | Timestamps, Translation, Multilingual |
| `openai_whisper` | `openai` | 57 languages | Timestamps, Diarization, Translation, Custom Endpoint |

### Diarization Models

Expand Down Expand Up @@ -221,9 +222,18 @@ params := map[string]interface{}{
// NVIDIA Canary with translation
params := map[string]interface{}{
"source_lang": "es",
"target_lang": "en",
"target_lang": "en",
"task": "translate",
}

// OpenAI with custom self-hosted endpoint
params := map[string]interface{}{
"base_url": "http://localhost:8000/v1",
"model": "Systran/faster-whisper-large-v3",
"timeout_minutes": 30,
"diarize": true,
"diarize_model": "pyannote",
}
```

## Testing
Expand Down
50 changes: 38 additions & 12 deletions internal/transcription/adapters/openai_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter {
Features: map[string]bool{
"timestamps": true, // Verbose JSON response includes segments
"word_level": false, // Not supported by standard API yet (unless using verbose_json with timestamp_granularities which is beta)
"diarization": false, // Not supported by OpenAI API
"diarization": true, // Post-processing via pyannote/sortformer pipeline
"translation": true,
"language_detection": true,
"vad": true, // Implicit
Expand All @@ -59,13 +59,19 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter {
Description: "OpenAI API Key (overrides system default)",
Group: "authentication",
},
{
Name: "base_url",
Type: "string",
Required: false,
Description: "Custom transcription API base URL (overrides server default)",
Group: "authentication",
},
{
Name: "model",
Type: "string",
Required: false,
Default: "whisper-1",
Options: []string{"whisper-1"},
Description: "ID of the model to use",
Description: "Model name (e.g. whisper-1, or any model exposed by a custom endpoint)",
Group: "basic",
},
{
Expand All @@ -92,6 +98,15 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter {
Description: "Sampling temperature",
Group: "quality",
},
{
Name: "timeout_minutes",
Type: "int",
Required: false,
Default: nil,
Min: &[]float64{1}[0],
Description: "HTTP request timeout in minutes (increase for large files on self-hosted endpoints)",
Group: "advanced",
},
}

baseAdapter := NewBaseAdapter("openai_whisper", "", capabilities, schema)
Expand Down Expand Up @@ -153,7 +168,14 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
apiKey = key
}

if apiKey == "" {
const officialURL = "https://api.openai.com/v1/audio/transcriptions"
endpointURL := officialURL
if url := a.GetStringParameter(params, "base_url"); url != "" {
endpointURL = strings.TrimRight(url, "/") + "/audio/transcriptions"
}
isOfficialEndpoint := endpointURL == officialURL

if apiKey == "" && isOfficialEndpoint {
writeLog("Error: OpenAI API key is required but not provided")
return nil, fmt.Errorf("OpenAI API key is required but not provided")
}
Expand Down Expand Up @@ -188,7 +210,7 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
writeLog("Model: %s", model)
_ = writer.WriteField("model", model)

if strings.HasPrefix(model, "gpt-4o") {
if isOfficialEndpoint && strings.HasPrefix(model, "gpt-4o") {
if strings.Contains(model, "diarize") {
_ = writer.WriteField("response_format", "diarized_json")
} else {
Expand All @@ -197,7 +219,6 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
// gpt-4o models don't support timestamp_granularities with these formats
} else {
_ = writer.WriteField("response_format", "verbose_json")
// timestamp_granularities is only supported for whisper-1
if model == "whisper-1" {
_ = writer.WriteField("timestamp_granularities[]", "word") // Request word timestamps
_ = writer.WriteField("timestamp_granularities[]", "segment") // Request segment timestamps
Expand All @@ -224,8 +245,8 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
}

// Create request
writeLog("Sending request to OpenAI API...")
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/audio/transcriptions", body)
writeLog("Sending request to %s...", endpointURL)
req, err := http.NewRequestWithContext(ctx, "POST", endpointURL, body)
if err != nil {
writeLog("Error: Failed to create request: %v", err)
return nil, fmt.Errorf("failed to create request: %w", err)
Expand All @@ -235,9 +256,14 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
req.Header.Set("Authorization", "Bearer "+apiKey)

// Execute request
client := &http.Client{
Timeout: 10 * time.Minute, // Generous timeout for large files
timeout := 10 * time.Minute
if !isOfficialEndpoint {
timeout = 30 * time.Minute // Default for self-hosted endpoints
}
if t := a.GetIntParameter(params, "timeout_minutes"); t > 0 {
timeout = time.Duration(t) * time.Minute
}
client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
writeLog("Error: Request failed: %v", err)
Expand All @@ -247,8 +273,8 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn

if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
writeLog("Error: OpenAI API error (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(respBody))
writeLog("Error: transcription API error (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("transcription API error (status %d): %s", resp.StatusCode, string(respBody))
}

writeLog("Response received. Parsing...")
Expand Down
Loading