Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/llm/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ func NewLLMClient(ep ResolvedEndpoint) LLMClient {
APIKey: ep.Token,
Model: ep.Model,
AuthHeader: ep.AuthHeader,
Timeout: ep.Timeout,
ExtraBody: ep.ExtraBody,
ExtraHeaders: ep.ExtraHeaders,
}
Expand Down
69 changes: 68 additions & 1 deletion internal/llm/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package llm
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
)

// ResolvedEndpoint holds the resolved LLM endpoint configuration.
Expand All @@ -19,6 +22,7 @@ type ResolvedEndpoint struct {
Source string // human-readable config source label
ExtraBody map[string]any // vendor-specific request body fields
ExtraHeaders map[string]string // extra HTTP headers for the LLM request
Timeout time.Duration // per-request HTTP timeout; 0 means use default (5 min)
}

// Environment variable names for OCR-specific configuration.
Expand All @@ -28,6 +32,7 @@ const (
envOCRLLMModel = "OCR_LLM_MODEL"
envOCRLLMAuthHeader = "OCR_LLM_AUTH_HEADER"
envOCRLLMExtraHeaders = "OCR_LLM_EXTRA_HEADERS"
envOCRLLMTimeout = "OCR_LLM_TIMEOUT"
envOCRUseAnthropic = "OCR_USE_ANTHROPIC"
)

Expand Down Expand Up @@ -71,13 +76,59 @@ func ResolveEndpointWithModelOverride(configPath, modelOverride string) (Resolve
ep.Source = s.name
}
ep.Model = stripModelSuffix(ep.Model)
// OCR_LLM_TIMEOUT is a global override: applies regardless of
// which strategy resolved the endpoint, and takes precedence
// over config-file values when set.
if envTimeout, ok := parseTimeoutEnv(); ok {
ep.Timeout = envTimeout
}
return ep, nil
}
}

return ResolvedEndpoint{}, fmt.Errorf("no valid LLM endpoint configured; one of OCR_LLM_URL/OCR_LLM_TOKEN/OCR_LLM_MODEL, ~/.opencodereview/config.json, or ANTHROPIC_BASE_URL/ANTHROPIC_AUTH_TOKEN/ANTHROPIC_MODEL must be set")
}

// parseTimeoutEnv reads and validates the OCR_LLM_TIMEOUT environment variable.
// Returns the parsed duration and true if set, or 0 and false if unset/empty.
// Rejects negative values and values that would overflow time.Duration.
func parseTimeoutEnv() (time.Duration, bool) {
raw := strings.TrimSpace(os.Getenv(envOCRLLMTimeout))
if raw == "" {
return 0, false
}
sec, err := strconv.Atoi(raw)
if err != nil {
return 0, false
}
if sec < 0 {
return 0, false
}
// Guard against overflow: time.Duration is int64 nanoseconds.
maxSec := int64(math.MaxInt64 / int64(time.Second))
if int64(sec) > maxSec {
return 0, false
}
return time.Duration(sec) * time.Second, true
}

// validateTimeoutSec converts a config-file timeout (in seconds) to time.Duration.
// Returns 0 for zero input (use default). Rejects negative values and overflow.
func validateTimeoutSec(sec int) (time.Duration, error) {
if sec == 0 {
return 0, nil
}
if sec < 0 {
return 0, fmt.Errorf("timeout_sec must be non-negative, got %d", sec)
}
// Guard against overflow: time.Duration is int64 nanoseconds.
maxSec := int64(math.MaxInt64 / int64(time.Second))
if int64(sec) > maxSec {
return 0, fmt.Errorf("timeout_sec %d overflows time.Duration (max %d)", sec, maxSec)
}
return time.Duration(sec) * time.Second, nil
}

// tryOCREnv reads OCR-specific environment variables.
func tryOCREnv(modelOverride string) (ResolvedEndpoint, bool, error) {
url := os.Getenv(envOCRLLMURL)
Expand Down Expand Up @@ -122,6 +173,9 @@ func tryOCREnv(modelOverride string) (ResolvedEndpoint, bool, error) {
}
}

// Note: OCR_LLM_TIMEOUT is applied globally in ResolveEndpointWithModelOverride,
// not here, so it works regardless of which strategy resolves the endpoint.

return ResolvedEndpoint{URL: url, Token: token, Model: model, Protocol: protocol, AuthHeader: authHeader, Source: "OCR environment", ExtraHeaders: extraHeaders}, true, nil
}

Expand All @@ -132,6 +186,7 @@ type llmFileConfig struct {
AuthHeader string `json:"auth_header,omitempty"`
Model string `json:"model,omitempty"`
UseAnthropic *bool `json:"use_anthropic,omitempty"` // pointer to distinguish unset from false
TimeoutSec int `json:"timeout_sec,omitempty"` // per-request HTTP timeout in seconds
ExtraBody map[string]any `json:"extra_body,omitempty"`
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
}
Comment on lines +189 to 192
Expand All @@ -144,6 +199,7 @@ type providerEntryConfig struct {
Model string `json:"model,omitempty"`
Models []string `json:"models,omitempty"`
AuthHeader string `json:"auth_header,omitempty"`
TimeoutSec int `json:"timeout_sec,omitempty"` // per-request HTTP timeout in seconds
ExtraBody map[string]any `json:"extra_body,omitempty"`
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
}
Expand Down Expand Up @@ -288,6 +344,11 @@ func tryProviderConfig(cfg configFile, modelOverride string) (ResolvedEndpoint,
extraBody = entry.ExtraBody
extraHeaders := entry.ExtraHeaders

timeout, err := validateTimeoutSec(entry.TimeoutSec)
if err != nil {
return ResolvedEndpoint{}, false, fmt.Errorf("provider %q: %w", cfg.Provider, err)
}

if protocol == "anthropic" {
url = ensureMessagesSuffix(url)
}
Expand All @@ -301,6 +362,7 @@ func tryProviderConfig(cfg configFile, modelOverride string) (ResolvedEndpoint,
Source: "provider:" + cfg.Provider,
ExtraBody: extraBody,
ExtraHeaders: extraHeaders,
Timeout: timeout,
}, true, nil
}

Expand Down Expand Up @@ -336,7 +398,12 @@ func tryLegacyLlmConfig(cfg configFile, modelOverride string) (ResolvedEndpoint,
}
}

return ResolvedEndpoint{URL: cfg.Llm.URL, Token: cfg.Llm.AuthToken, Model: model, Protocol: protocol, AuthHeader: authHeader, Source: "OCR config file", ExtraBody: cfg.Llm.ExtraBody, ExtraHeaders: cfg.Llm.ExtraHeaders}, true, nil
timeout, err := validateTimeoutSec(cfg.Llm.TimeoutSec)
if err != nil {
return ResolvedEndpoint{}, false, fmt.Errorf("OCR config file: %w", err)
}

return ResolvedEndpoint{URL: cfg.Llm.URL, Token: cfg.Llm.AuthToken, Model: model, Protocol: protocol, AuthHeader: authHeader, Source: "OCR config file", ExtraBody: cfg.Llm.ExtraBody, ExtraHeaders: cfg.Llm.ExtraHeaders, Timeout: timeout}, true, nil
}

// tryCCEnv reads Claude Code environment variables.
Expand Down
165 changes: 165 additions & 0 deletions internal/llm/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"path/filepath"
"strings"
"testing"
"time"
)

func TestStripModelSuffix(t *testing.T) {
Expand Down Expand Up @@ -1123,3 +1124,167 @@ func TestResolveEndpoint_LegacyLlmExtraHeaders(t *testing.T) {
t.Errorf("ExtraHeaders[\"X-Legacy\"] = %q, want %q", v, "yes")
}
}

func TestParseTimeoutEnv(t *testing.T) {
tests := []struct {
name string
value string
want time.Duration
wantOK bool
}{
{"empty", "", 0, false},
{"valid 120", "120", 120 * time.Second, true},
{"valid 60", "60", 60 * time.Second, true},
{"zero", "0", 0, true},
{"negative", "-5", 0, false},
{"non-integer", "abc", 0, false},
{"with spaces", " 90 ", 90 * time.Second, true},
{"overflow", "99999999999999", 0, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("OCR_LLM_TIMEOUT", tt.value)
got, ok := parseTimeoutEnv()
if ok != tt.wantOK {
t.Errorf("parseTimeoutEnv() ok = %v, want %v", ok, tt.wantOK)
}
if got != tt.want {
t.Errorf("parseTimeoutEnv() = %v, want %v", got, tt.want)
}
})
}
}

func TestValidateTimeoutSec(t *testing.T) {
tests := []struct {
name string
input int
want time.Duration
wantErr bool
}{
{"zero", 0, 0, false},
{"positive 60", 60, 60 * time.Second, false},
{"positive 300", 300, 300 * time.Second, false},
{"negative -1", -1, 0, true},
{"negative -100", -100, 0, true},
{"max safe", 9223372036, 9223372036 * time.Second, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := validateTimeoutSec(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("validateTimeoutSec(%d) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
if got != tt.want {
t.Errorf("validateTimeoutSec(%d) = %v, want %v", tt.input, got, tt.want)
}
})
}
}

func TestResolveEndpoint_EnvTimeoutGlobalOverride(t *testing.T) {
clearAllEnv(t)
t.Setenv("OCR_LLM_URL", "https://api.example.com/v1")
t.Setenv("OCR_LLM_TOKEN", "test-token")
t.Setenv("OCR_LLM_MODEL", "mimo-v2.5-pro")
t.Setenv("OCR_USE_ANTHROPIC", "false")
t.Setenv("OCR_LLM_TIMEOUT", "90")

ep, err := ResolveEndpoint(filepath.Join(t.TempDir(), "nonexistent.json"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ep.Timeout != 90*time.Second {
t.Errorf("Timeout = %v, want %v", ep.Timeout, 90*time.Second)
}
}

func TestResolveEndpoint_ConfigTimeoutSec(t *testing.T) {
clearAllEnv(t)

cfg := configFile{
Llm: llmFileConfig{
URL: "https://api.example.com/v1/messages",
AuthToken: "test-token",
Model: "claude-opus-4-6",
TimeoutSec: 120,
},
}
data, _ := json.Marshal(cfg)
cfgPath := filepath.Join(t.TempDir(), "config.json")
os.WriteFile(cfgPath, data, 0644)

ep, err := ResolveEndpoint(cfgPath)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ep.Timeout != 120*time.Second {
t.Errorf("Timeout = %v, want %v", ep.Timeout, 120*time.Second)
}
}

func TestResolveEndpoint_NegativeConfigTimeoutSec(t *testing.T) {
clearAllEnv(t)

cfg := configFile{
Llm: llmFileConfig{
URL: "https://api.example.com/v1/messages",
AuthToken: "test-token",
Model: "claude-opus-4-6",
TimeoutSec: -5,
},
}
data, _ := json.Marshal(cfg)
cfgPath := filepath.Join(t.TempDir(), "config.json")
os.WriteFile(cfgPath, data, 0644)

_, err := ResolveEndpoint(cfgPath)
if err == nil {
t.Fatal("expected error for negative timeout_sec, got nil")
}
if !strings.Contains(err.Error(), "non-negative") {
t.Errorf("error %q should mention non-negative", err.Error())
}
}

func TestNewLLMClient_TimeoutForwarded(t *testing.T) {
ep := ResolvedEndpoint{
URL: "https://api.example.com/v1",
Token: "test-token",
Model: "test-model",
Timeout: 2 * time.Minute,
}

client := NewLLMClient(ep)
if client == nil {
t.Fatal("NewLLMClient returned nil")
}

// Verify the client was created (we can't easily inspect the internal timeout,
// but we can verify the client is functional and was constructed without error).
if oc, ok := client.(*OpenAIClient); ok {
if oc.cfg.Timeout != 2*time.Minute {
t.Errorf("OpenAIClient cfg.Timeout = %v, want %v", oc.cfg.Timeout, 2*time.Minute)
}
} else {
t.Errorf("expected *OpenAIClient, got %T", client)
}
}

func TestNewLLMClient_DefaultTimeout(t *testing.T) {
ep := ResolvedEndpoint{
URL: "https://api.example.com/v1",
Token: "test-token",
Model: "test-model",
// Timeout not set — should default to 5 minutes
}

client := NewLLMClient(ep)
if oc, ok := client.(*OpenAIClient); ok {
if oc.cfg.Timeout != 5*time.Minute {
t.Errorf("OpenAIClient cfg.Timeout = %v, want default %v", oc.cfg.Timeout, 5*time.Minute)
}
}
}