diff --git a/internal/llm/client.go b/internal/llm/client.go index afbec68..685c717 100644 --- a/internal/llm/client.go +++ b/internal/llm/client.go @@ -198,6 +198,7 @@ func NewLLMClient(ep ResolvedEndpoint) LLMClient { APIKey: ep.Token, Model: ep.Model, AuthHeader: ep.AuthHeader, + Timeout: ep.Timeout, ExtraBody: ep.ExtraBody, ExtraHeaders: ep.ExtraHeaders, } diff --git a/internal/llm/resolver.go b/internal/llm/resolver.go index 8ed011d..a1cfe16 100644 --- a/internal/llm/resolver.go +++ b/internal/llm/resolver.go @@ -3,10 +3,13 @@ package llm import ( "encoding/json" "fmt" + "math" "os" "path/filepath" "regexp" + "strconv" "strings" + "time" ) // ResolvedEndpoint holds the resolved LLM endpoint configuration. @@ -19,6 +22,7 @@ type ResolvedEndpoint struct { Source string // human-readable config source label ExtraBody map[string]any // vendor-specific request body fields ExtraHeaders map[string]string // extra HTTP headers for the LLM request + Timeout time.Duration // per-request HTTP timeout; 0 means use default (5 min) } // Environment variable names for OCR-specific configuration. @@ -28,6 +32,7 @@ const ( envOCRLLMModel = "OCR_LLM_MODEL" envOCRLLMAuthHeader = "OCR_LLM_AUTH_HEADER" envOCRLLMExtraHeaders = "OCR_LLM_EXTRA_HEADERS" + envOCRLLMTimeout = "OCR_LLM_TIMEOUT" envOCRUseAnthropic = "OCR_USE_ANTHROPIC" ) @@ -71,6 +76,12 @@ func ResolveEndpointWithModelOverride(configPath, modelOverride string) (Resolve ep.Source = s.name } ep.Model = stripModelSuffix(ep.Model) + // OCR_LLM_TIMEOUT is a global override: applies regardless of + // which strategy resolved the endpoint, and takes precedence + // over config-file values when set. + if envTimeout, ok := parseTimeoutEnv(); ok { + ep.Timeout = envTimeout + } return ep, nil } } @@ -78,6 +89,46 @@ func ResolveEndpointWithModelOverride(configPath, modelOverride string) (Resolve return ResolvedEndpoint{}, fmt.Errorf("no valid LLM endpoint configured; one of OCR_LLM_URL/OCR_LLM_TOKEN/OCR_LLM_MODEL, ~/.opencodereview/config.json, or ANTHROPIC_BASE_URL/ANTHROPIC_AUTH_TOKEN/ANTHROPIC_MODEL must be set") } +// parseTimeoutEnv reads and validates the OCR_LLM_TIMEOUT environment variable. +// Returns the parsed duration and true if set, or 0 and false if unset/empty. +// Rejects negative values and values that would overflow time.Duration. +func parseTimeoutEnv() (time.Duration, bool) { + raw := strings.TrimSpace(os.Getenv(envOCRLLMTimeout)) + if raw == "" { + return 0, false + } + sec, err := strconv.Atoi(raw) + if err != nil { + return 0, false + } + if sec < 0 { + return 0, false + } + // Guard against overflow: time.Duration is int64 nanoseconds. + maxSec := int64(math.MaxInt64 / int64(time.Second)) + if int64(sec) > maxSec { + return 0, false + } + return time.Duration(sec) * time.Second, true +} + +// validateTimeoutSec converts a config-file timeout (in seconds) to time.Duration. +// Returns 0 for zero input (use default). Rejects negative values and overflow. +func validateTimeoutSec(sec int) (time.Duration, error) { + if sec == 0 { + return 0, nil + } + if sec < 0 { + return 0, fmt.Errorf("timeout_sec must be non-negative, got %d", sec) + } + // Guard against overflow: time.Duration is int64 nanoseconds. + maxSec := int64(math.MaxInt64 / int64(time.Second)) + if int64(sec) > maxSec { + return 0, fmt.Errorf("timeout_sec %d overflows time.Duration (max %d)", sec, maxSec) + } + return time.Duration(sec) * time.Second, nil +} + // tryOCREnv reads OCR-specific environment variables. func tryOCREnv(modelOverride string) (ResolvedEndpoint, bool, error) { url := os.Getenv(envOCRLLMURL) @@ -122,6 +173,9 @@ func tryOCREnv(modelOverride string) (ResolvedEndpoint, bool, error) { } } + // Note: OCR_LLM_TIMEOUT is applied globally in ResolveEndpointWithModelOverride, + // not here, so it works regardless of which strategy resolves the endpoint. + return ResolvedEndpoint{URL: url, Token: token, Model: model, Protocol: protocol, AuthHeader: authHeader, Source: "OCR environment", ExtraHeaders: extraHeaders}, true, nil } @@ -132,6 +186,7 @@ type llmFileConfig struct { AuthHeader string `json:"auth_header,omitempty"` Model string `json:"model,omitempty"` UseAnthropic *bool `json:"use_anthropic,omitempty"` // pointer to distinguish unset from false + TimeoutSec int `json:"timeout_sec,omitempty"` // per-request HTTP timeout in seconds ExtraBody map[string]any `json:"extra_body,omitempty"` ExtraHeaders map[string]string `json:"extra_headers,omitempty"` } @@ -144,6 +199,7 @@ type providerEntryConfig struct { Model string `json:"model,omitempty"` Models []string `json:"models,omitempty"` AuthHeader string `json:"auth_header,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` // per-request HTTP timeout in seconds ExtraBody map[string]any `json:"extra_body,omitempty"` ExtraHeaders map[string]string `json:"extra_headers,omitempty"` } @@ -288,6 +344,11 @@ func tryProviderConfig(cfg configFile, modelOverride string) (ResolvedEndpoint, extraBody = entry.ExtraBody extraHeaders := entry.ExtraHeaders + timeout, err := validateTimeoutSec(entry.TimeoutSec) + if err != nil { + return ResolvedEndpoint{}, false, fmt.Errorf("provider %q: %w", cfg.Provider, err) + } + if protocol == "anthropic" { url = ensureMessagesSuffix(url) } @@ -301,6 +362,7 @@ func tryProviderConfig(cfg configFile, modelOverride string) (ResolvedEndpoint, Source: "provider:" + cfg.Provider, ExtraBody: extraBody, ExtraHeaders: extraHeaders, + Timeout: timeout, }, true, nil } @@ -336,7 +398,12 @@ func tryLegacyLlmConfig(cfg configFile, modelOverride string) (ResolvedEndpoint, } } - return ResolvedEndpoint{URL: cfg.Llm.URL, Token: cfg.Llm.AuthToken, Model: model, Protocol: protocol, AuthHeader: authHeader, Source: "OCR config file", ExtraBody: cfg.Llm.ExtraBody, ExtraHeaders: cfg.Llm.ExtraHeaders}, true, nil + timeout, err := validateTimeoutSec(cfg.Llm.TimeoutSec) + if err != nil { + return ResolvedEndpoint{}, false, fmt.Errorf("OCR config file: %w", err) + } + + return ResolvedEndpoint{URL: cfg.Llm.URL, Token: cfg.Llm.AuthToken, Model: model, Protocol: protocol, AuthHeader: authHeader, Source: "OCR config file", ExtraBody: cfg.Llm.ExtraBody, ExtraHeaders: cfg.Llm.ExtraHeaders, Timeout: timeout}, true, nil } // tryCCEnv reads Claude Code environment variables. diff --git a/internal/llm/resolver_test.go b/internal/llm/resolver_test.go index 7e5d6af..48841c8 100644 --- a/internal/llm/resolver_test.go +++ b/internal/llm/resolver_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "strings" "testing" + "time" ) func TestStripModelSuffix(t *testing.T) { @@ -1123,3 +1124,167 @@ func TestResolveEndpoint_LegacyLlmExtraHeaders(t *testing.T) { t.Errorf("ExtraHeaders[\"X-Legacy\"] = %q, want %q", v, "yes") } } + +func TestParseTimeoutEnv(t *testing.T) { + tests := []struct { + name string + value string + want time.Duration + wantOK bool + }{ + {"empty", "", 0, false}, + {"valid 120", "120", 120 * time.Second, true}, + {"valid 60", "60", 60 * time.Second, true}, + {"zero", "0", 0, true}, + {"negative", "-5", 0, false}, + {"non-integer", "abc", 0, false}, + {"with spaces", " 90 ", 90 * time.Second, true}, + {"overflow", "99999999999999", 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("OCR_LLM_TIMEOUT", tt.value) + got, ok := parseTimeoutEnv() + if ok != tt.wantOK { + t.Errorf("parseTimeoutEnv() ok = %v, want %v", ok, tt.wantOK) + } + if got != tt.want { + t.Errorf("parseTimeoutEnv() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestValidateTimeoutSec(t *testing.T) { + tests := []struct { + name string + input int + want time.Duration + wantErr bool + }{ + {"zero", 0, 0, false}, + {"positive 60", 60, 60 * time.Second, false}, + {"positive 300", 300, 300 * time.Second, false}, + {"negative -1", -1, 0, true}, + {"negative -100", -100, 0, true}, + {"max safe", 9223372036, 9223372036 * time.Second, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := validateTimeoutSec(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("validateTimeoutSec(%d) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + if got != tt.want { + t.Errorf("validateTimeoutSec(%d) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestResolveEndpoint_EnvTimeoutGlobalOverride(t *testing.T) { + clearAllEnv(t) + t.Setenv("OCR_LLM_URL", "https://api.example.com/v1") + t.Setenv("OCR_LLM_TOKEN", "test-token") + t.Setenv("OCR_LLM_MODEL", "mimo-v2.5-pro") + t.Setenv("OCR_USE_ANTHROPIC", "false") + t.Setenv("OCR_LLM_TIMEOUT", "90") + + ep, err := ResolveEndpoint(filepath.Join(t.TempDir(), "nonexistent.json")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ep.Timeout != 90*time.Second { + t.Errorf("Timeout = %v, want %v", ep.Timeout, 90*time.Second) + } +} + +func TestResolveEndpoint_ConfigTimeoutSec(t *testing.T) { + clearAllEnv(t) + + cfg := configFile{ + Llm: llmFileConfig{ + URL: "https://api.example.com/v1/messages", + AuthToken: "test-token", + Model: "claude-opus-4-6", + TimeoutSec: 120, + }, + } + data, _ := json.Marshal(cfg) + cfgPath := filepath.Join(t.TempDir(), "config.json") + os.WriteFile(cfgPath, data, 0644) + + ep, err := ResolveEndpoint(cfgPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ep.Timeout != 120*time.Second { + t.Errorf("Timeout = %v, want %v", ep.Timeout, 120*time.Second) + } +} + +func TestResolveEndpoint_NegativeConfigTimeoutSec(t *testing.T) { + clearAllEnv(t) + + cfg := configFile{ + Llm: llmFileConfig{ + URL: "https://api.example.com/v1/messages", + AuthToken: "test-token", + Model: "claude-opus-4-6", + TimeoutSec: -5, + }, + } + data, _ := json.Marshal(cfg) + cfgPath := filepath.Join(t.TempDir(), "config.json") + os.WriteFile(cfgPath, data, 0644) + + _, err := ResolveEndpoint(cfgPath) + if err == nil { + t.Fatal("expected error for negative timeout_sec, got nil") + } + if !strings.Contains(err.Error(), "non-negative") { + t.Errorf("error %q should mention non-negative", err.Error()) + } +} + +func TestNewLLMClient_TimeoutForwarded(t *testing.T) { + ep := ResolvedEndpoint{ + URL: "https://api.example.com/v1", + Token: "test-token", + Model: "test-model", + Timeout: 2 * time.Minute, + } + + client := NewLLMClient(ep) + if client == nil { + t.Fatal("NewLLMClient returned nil") + } + + // Verify the client was created (we can't easily inspect the internal timeout, + // but we can verify the client is functional and was constructed without error). + if oc, ok := client.(*OpenAIClient); ok { + if oc.cfg.Timeout != 2*time.Minute { + t.Errorf("OpenAIClient cfg.Timeout = %v, want %v", oc.cfg.Timeout, 2*time.Minute) + } + } else { + t.Errorf("expected *OpenAIClient, got %T", client) + } +} + +func TestNewLLMClient_DefaultTimeout(t *testing.T) { + ep := ResolvedEndpoint{ + URL: "https://api.example.com/v1", + Token: "test-token", + Model: "test-model", + // Timeout not set — should default to 5 minutes + } + + client := NewLLMClient(ep) + if oc, ok := client.(*OpenAIClient); ok { + if oc.cfg.Timeout != 5*time.Minute { + t.Errorf("OpenAIClient cfg.Timeout = %v, want default %v", oc.cfg.Timeout, 5*time.Minute) + } + } +}