From ef931afb04ea04342f9af54bf0c4218c2d4109f9 Mon Sep 17 00:00:00 2001 From: zhy Date: Wed, 20 May 2026 11:15:32 +0800 Subject: [PATCH 1/2] feat: add token usage tracking and Prometheus metrics endpoint Parse token usage from upstream API responses (OpenAI, Anthropic, Gemini) and store in request_logs for cost analytics. Expose aggregated metrics at GET /metrics in Prometheus exposition format. Co-Authored-By: Claude Sonnet 4.6 --- internal/db/migrations/migration.go | 7 +- .../db/migrations/v1_2_0_AddTokenColumns.go | 37 +++++++ internal/handler/metrics_handler.go | 95 +++++++++++++++++ internal/models/types.go | 44 ++++---- internal/proxy/response_handlers.go | 31 +++++- internal/proxy/server.go | 23 ++-- internal/proxy/usage.go | 100 ++++++++++++++++++ internal/router/router.go | 1 + 8 files changed, 308 insertions(+), 30 deletions(-) create mode 100644 internal/db/migrations/v1_2_0_AddTokenColumns.go create mode 100644 internal/handler/metrics_handler.go create mode 100644 internal/proxy/usage.go diff --git a/internal/db/migrations/migration.go b/internal/db/migrations/migration.go index d8886a48f..3253cc6ed 100644 --- a/internal/db/migrations/migration.go +++ b/internal/db/migrations/migration.go @@ -11,7 +11,12 @@ func MigrateDatabase(db *gorm.DB) error { } // Run v1.1.0 migration - return V1_1_0_AddKeyHashColumn(db) + if err := V1_1_0_AddKeyHashColumn(db); err != nil { + return err + } + + // Run v1.2.0 migration + return V1_2_0_AddTokenColumns(db) } // HandleLegacyIndexes removes old indexes from previous versions to prevent migration errors diff --git a/internal/db/migrations/v1_2_0_AddTokenColumns.go b/internal/db/migrations/v1_2_0_AddTokenColumns.go new file mode 100644 index 000000000..43461288f --- /dev/null +++ b/internal/db/migrations/v1_2_0_AddTokenColumns.go @@ -0,0 +1,37 @@ +package db + +import ( + "gpt-load/internal/models" + + "github.com/sirupsen/logrus" + "gorm.io/gorm" +) + +// V1_2_0_AddTokenColumns adds token usage columns to request_logs table +func V1_2_0_AddTokenColumns(db *gorm.DB) error { + if !db.Migrator().HasColumn(&models.RequestLog{}, "prompt_tokens") { + if err := db.Migrator().AddColumn(&models.RequestLog{}, "prompt_tokens"); err != nil { + return err + } + logrus.Info("Added column prompt_tokens to request_logs") + } + if !db.Migrator().HasColumn(&models.RequestLog{}, "completion_tokens") { + if err := db.Migrator().AddColumn(&models.RequestLog{}, "completion_tokens"); err != nil { + return err + } + logrus.Info("Added column completion_tokens to request_logs") + } + if !db.Migrator().HasColumn(&models.RequestLog{}, "total_tokens") { + if err := db.Migrator().AddColumn(&models.RequestLog{}, "total_tokens"); err != nil { + return err + } + logrus.Info("Added column total_tokens to request_logs") + } + if !db.Migrator().HasColumn(&models.RequestLog{}, "token_cost_usd") { + if err := db.Migrator().AddColumn(&models.RequestLog{}, "token_cost_usd"); err != nil { + return err + } + logrus.Info("Added column token_cost_usd to request_logs") + } + return nil +} diff --git a/internal/handler/metrics_handler.go b/internal/handler/metrics_handler.go new file mode 100644 index 000000000..4134ad655 --- /dev/null +++ b/internal/handler/metrics_handler.go @@ -0,0 +1,95 @@ +package handler + +import ( + "fmt" + "strings" + + "gpt-load/internal/models" + + "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" +) + +// Metrics returns a Prometheus-text /metrics endpoint exposing token usage +// and request counts aggregated from request_logs. +// +// This is deliberately kept minimal — a full Prometheus client library is not +// introduced. The output format follows the Prometheus exposition format so +// operators can scrape it with any standard Prometheus server and build +// dashboards (e.g. Grafana) on top. +func (s *Server) Metrics(c *gin.Context) { + var results []struct { + GroupName string + Model string + TotalRequests int64 + TotalTokens int64 + TotalCost float64 + TotalPrompt int64 + TotalCompletion int64 + } + + // Aggregate token usage and request count from successful non-streaming + // requests (those are the ones where we can extract usage data). + if err := s.DB.Model(&models.RequestLog{}). + Select(`COALESCE(group_name, '') as group_name, + COALESCE(model, 'unknown') as model, + COUNT(*) as total_requests, + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(SUM(token_cost_usd), 0) as total_cost, + COALESCE(SUM(prompt_tokens), 0) as total_prompt, + COALESCE(SUM(completion_tokens), 0) as total_completion`). + Where("is_success = ?", true). + Group("group_name, model"). + Scan(&results).Error; err != nil { + logrus.WithError(err).Error("Failed to query metrics") + c.String(500, "internal error\n") + return + } + + var sb strings.Builder + sb.WriteString("# HELP gpt_load_requests_total Total number of successful proxy requests by group and model\n") + sb.WriteString("# TYPE gpt_load_requests_total counter\n") + for _, r := range results { + sb.WriteString(fmt.Sprintf( + `gpt_load_requests_total{group=%q,model=%q} %d`+"\n", + r.GroupName, r.Model, r.TotalRequests, + )) + } + + sb.WriteString("\n# HELP gpt_load_tokens_total Total token count by type, group, and model\n") + sb.WriteString("# TYPE gpt_load_tokens_total counter\n") + for _, r := range results { + if r.TotalPrompt > 0 { + sb.WriteString(fmt.Sprintf( + `gpt_load_tokens_total{type="prompt",group=%q,model=%q} %d`+"\n", + r.GroupName, r.Model, r.TotalPrompt, + )) + } + if r.TotalCompletion > 0 { + sb.WriteString(fmt.Sprintf( + `gpt_load_tokens_total{type="completion",group=%q,model=%q} %d`+"\n", + r.GroupName, r.Model, r.TotalCompletion, + )) + } + if r.TotalTokens > 0 { + sb.WriteString(fmt.Sprintf( + `gpt_load_tokens_total{type="total",group=%q,model=%q} %d`+"\n", + r.GroupName, r.Model, r.TotalTokens, + )) + } + } + + sb.WriteString("\n# HELP gpt_load_cost_total Total cost in USD by group and model\n") + sb.WriteString("# TYPE gpt_load_cost_total counter\n") + for _, r := range results { + if r.TotalCost > 0 { + sb.WriteString(fmt.Sprintf( + `gpt_load_cost_total{group=%q,model=%q} %.6f`+"\n", + r.GroupName, r.Model, r.TotalCost, + )) + } + } + + c.Header("Content-Type", "text/plain; charset=utf-8") + c.String(200, sb.String()) +} diff --git a/internal/models/types.go b/internal/models/types.go index 27089f856..9d7cbe34f 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -135,26 +135,30 @@ const ( // RequestLog 对应 request_logs 表 type RequestLog struct { - ID string `gorm:"type:varchar(36);primaryKey" json:"id"` - Timestamp time.Time `gorm:"not null;index" json:"timestamp"` - GroupID uint `gorm:"not null;index" json:"group_id"` - GroupName string `gorm:"type:varchar(255);index" json:"group_name"` - ParentGroupID uint `gorm:"index" json:"parent_group_id"` - ParentGroupName string `gorm:"type:varchar(255);index" json:"parent_group_name"` - KeyValue string `gorm:"type:text" json:"key_value"` - KeyHash string `gorm:"type:varchar(128);index" json:"key_hash"` - Model string `gorm:"type:varchar(255);index" json:"model"` - IsSuccess bool `gorm:"not null" json:"is_success"` - SourceIP string `gorm:"type:varchar(64)" json:"source_ip"` - StatusCode int `gorm:"not null" json:"status_code"` - RequestPath string `gorm:"type:varchar(500)" json:"request_path"` - Duration int64 `gorm:"not null" json:"duration_ms"` - ErrorMessage string `gorm:"type:text" json:"error_message"` - UserAgent string `gorm:"type:varchar(512)" json:"user_agent"` - RequestType string `gorm:"type:varchar(20);not null;default:'final';index" json:"request_type"` - UpstreamAddr string `gorm:"type:varchar(500)" json:"upstream_addr"` - IsStream bool `gorm:"not null" json:"is_stream"` - RequestBody string `gorm:"type:text" json:"request_body"` + ID string `gorm:"type:varchar(36);primaryKey" json:"id"` + Timestamp time.Time `gorm:"not null;index" json:"timestamp"` + GroupID uint `gorm:"not null;index" json:"group_id"` + GroupName string `gorm:"type:varchar(255);index" json:"group_name"` + ParentGroupID uint `gorm:"index" json:"parent_group_id"` + ParentGroupName string `gorm:"type:varchar(255);index" json:"parent_group_name"` + KeyValue string `gorm:"type:text" json:"key_value"` + KeyHash string `gorm:"type:varchar(128);index" json:"key_hash"` + Model string `gorm:"type:varchar(255);index" json:"model"` + IsSuccess bool `gorm:"not null" json:"is_success"` + SourceIP string `gorm:"type:varchar(64)" json:"source_ip"` + StatusCode int `gorm:"not null" json:"status_code"` + RequestPath string `gorm:"type:varchar(500)" json:"request_path"` + Duration int64 `gorm:"not null" json:"duration_ms"` + ErrorMessage string `gorm:"type:text" json:"error_message"` + UserAgent string `gorm:"type:varchar(512)" json:"user_agent"` + RequestType string `gorm:"type:varchar(20);not null;default:'final';index" json:"request_type"` + UpstreamAddr string `gorm:"type:varchar(500)" json:"upstream_addr"` + IsStream bool `gorm:"not null" json:"is_stream"` + RequestBody string `gorm:"type:text" json:"request_body"` + PromptTokens int64 `gorm:"not null;default:0" json:"prompt_tokens"` + CompletionTokens int64 `gorm:"not null;default:0" json:"completion_tokens"` + TotalTokens int64 `gorm:"not null;default:0" json:"total_tokens"` + TokenCostUSD float64 `gorm:"not null;default:0" json:"token_cost_usd"` } // StatCard 用于仪表盘的单个统计卡片数据 diff --git a/internal/proxy/response_handlers.go b/internal/proxy/response_handlers.go index dc146cabf..8767ccbb6 100644 --- a/internal/proxy/response_handlers.go +++ b/internal/proxy/response_handlers.go @@ -41,8 +41,33 @@ func (ps *ProxyServer) handleStreamingResponse(c *gin.Context, resp *http.Respon } } -func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) { - if _, err := io.Copy(c.Writer, resp.Body); err != nil { - logUpstreamError("copying response body", err) +// handleNormalResponse buffers the upstream response body, attempts to extract +// token usage metadata, writes the original body to the client, and returns +// the parsed usage (or nil when the body is not a JSON chat-completion response). +func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) *TokenUsage { + body, err := io.ReadAll(resp.Body) + if err != nil { + logUpstreamError("reading response body", err) + if _, writeErr := c.Writer.Write(body); writeErr != nil { + logUpstreamError("writing buffered body to client", writeErr) + } + return nil + } + + // Check for gzip encoding + body = handleGzipCompression(resp, body) + + if _, writeErr := c.Writer.Write(body); writeErr != nil { + logUpstreamError("writing buffered body to client", writeErr) + return nil } + + // Only parse usage for successful chat-completion-like responses. + if resp.StatusCode < 400 && isChatCompletionPath(c.Request.URL.Path) { + if usage := extractTokenUsage(body); usage != nil { + return usage + } + } + + return nil } diff --git a/internal/proxy/server.go b/internal/proxy/server.go index d443f1d92..52a38ad9a 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -130,7 +130,7 @@ func (ps *ProxyServer) executeRequestWithRetry( if err != nil { logrus.Errorf("Failed to select a key for group %s on attempt %d: %v", group.Name, retryCount+1, err) response.Error(c, app_errors.NewAPIError(app_errors.ErrNoKeysAvailable, err.Error())) - ps.logRequest(c, originalGroup, group, nil, startTime, http.StatusServiceUnavailable, err, isStream, "", channelHandler, bodyBytes, models.RequestTypeFinal) + ps.logRequest(c, originalGroup, group, nil, startTime, http.StatusServiceUnavailable, err, isStream, "", channelHandler, bodyBytes, models.RequestTypeFinal, nil) return } @@ -169,7 +169,7 @@ func (ps *ProxyServer) executeRequestWithRetry( finalBodyBytes, err := channelHandler.ApplyModelRedirect(req, bodyBytes, group) if err != nil { response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error())) - ps.logRequest(c, originalGroup, group, apiKey, startTime, http.StatusBadRequest, err, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal) + ps.logRequest(c, originalGroup, group, apiKey, startTime, http.StatusBadRequest, err, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal, nil) return } @@ -206,7 +206,7 @@ func (ps *ProxyServer) executeRequestWithRetry( if err != nil || shouldRetryByStatus { if err != nil && app_errors.IsIgnorableError(err) { logrus.Debugf("Client-side ignorable error for key %s, aborting retries: %v", utils.MaskAPIKey(apiKey.KeyValue), err) - ps.logRequest(c, originalGroup, group, apiKey, startTime, 499, err, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal) + ps.logRequest(c, originalGroup, group, apiKey, startTime, 499, err, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal, nil) return } @@ -244,7 +244,7 @@ func (ps *ProxyServer) executeRequestWithRetry( requestType = models.RequestTypeFinal } - ps.logRequest(c, originalGroup, group, apiKey, startTime, statusCode, errors.New(parsedError), isStream, upstreamURL, channelHandler, bodyBytes, requestType) + ps.logRequest(c, originalGroup, group, apiKey, startTime, statusCode, errors.New(parsedError), isStream, upstreamURL, channelHandler, bodyBytes, requestType, nil) // 如果是最后一次尝试,直接返回错误,不再递归 if isLastAttempt { @@ -264,6 +264,8 @@ func (ps *ProxyServer) executeRequestWithRetry( // ps.keyProvider.UpdateStatus(apiKey, group, true) // 请求成功不再重置成功次数,减少IO消耗 logrus.Debugf("Request for group %s succeeded on attempt %d with key %s", group.Name, retryCount+1, utils.MaskAPIKey(apiKey.KeyValue)) + var usage *TokenUsage + // Check if this is a model list request (needs special handling) if shouldInterceptModelList(c.Request.URL.Path, c.Request.Method) { ps.handleModelListResponse(c, resp, group, channelHandler) @@ -278,11 +280,11 @@ func (ps *ProxyServer) executeRequestWithRetry( if isStream { ps.handleStreamingResponse(c, resp) } else { - ps.handleNormalResponse(c, resp) + usage = ps.handleNormalResponse(c, resp) } } - ps.logRequest(c, originalGroup, group, apiKey, startTime, resp.StatusCode, nil, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal) + ps.logRequest(c, originalGroup, group, apiKey, startTime, resp.StatusCode, nil, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal, usage) } func shouldFailoverOnStatusCode(statusCode int, group *models.Group) bool { @@ -293,6 +295,7 @@ func shouldFailoverOnStatusCode(statusCode int, group *models.Group) bool { } // logRequest is a helper function to create and record a request log. +// usage may be nil for streaming requests or failed requests. func (ps *ProxyServer) logRequest( c *gin.Context, originalGroup *models.Group, @@ -306,6 +309,7 @@ func (ps *ProxyServer) logRequest( channelHandler channel.ChannelProxy, bodyBytes []byte, requestType string, + usage *TokenUsage, ) { if ps.requestLogService == nil { return @@ -335,6 +339,13 @@ func (ps *ProxyServer) logRequest( RequestBody: requestBodyToLog, } + // Set token usage data extracted from the response body. + if usage != nil { + logEntry.PromptTokens = usage.PromptTokens + logEntry.CompletionTokens = usage.CompletionTokens + logEntry.TotalTokens = usage.TotalTokens + } + // Set parent group if originalGroup != nil && originalGroup.GroupType == "aggregate" && originalGroup.ID != group.ID { logEntry.ParentGroupID = originalGroup.ID diff --git a/internal/proxy/usage.go b/internal/proxy/usage.go new file mode 100644 index 000000000..ef7586c3f --- /dev/null +++ b/internal/proxy/usage.go @@ -0,0 +1,100 @@ +package proxy + +import ( + "encoding/json" + "strings" +) + +// TokenUsage holds token counts parsed from an API response. +type TokenUsage struct { + PromptTokens int64 + CompletionTokens int64 + TotalTokens int64 +} + +// openaiUsage maps the "usage" object in OpenAI /v1/chat/completions responses. +type openaiUsage struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +// anthropicUsage maps the "usage" object in Anthropic /v1/messages responses. +type anthropicUsage struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` +} + +// geminiUsageMetadata maps the "usageMetadata" object in Gemini responses. +type geminiUsageMetadata struct { + PromptTokenCount int64 `json:"promptTokenCount"` + CandidatesTokenCount int64 `json:"candidatesTokenCount"` + TotalTokenCount int64 `json:"totalTokenCount"` +} + +// extractTokenUsage attempts to parse token usage from a provider response body. +// Returns nil when no usage information can be extracted (e.g. streaming, errors). +func extractTokenUsage(body []byte) *TokenUsage { + if len(body) == 0 { + return nil + } + + // Try a top-level "usage" object (OpenAI, Anthropic, OpenRouter). + var topLevel struct { + Usage json.RawMessage `json:"usage"` + } + if err := json.Unmarshal(body, &topLevel); err == nil && len(topLevel.Usage) > 0 { + if u := tryOpenAIUsage(topLevel.Usage); u != nil { + return u + } + } + + // Try "usageMetadata" (Gemini). + var gemini struct { + UsageMetadata *geminiUsageMetadata `json:"usageMetadata"` + } + if err := json.Unmarshal(body, &gemini); err == nil && gemini.UsageMetadata != nil { + return &TokenUsage{ + PromptTokens: gemini.UsageMetadata.PromptTokenCount, + CompletionTokens: gemini.UsageMetadata.CandidatesTokenCount, + TotalTokens: gemini.UsageMetadata.TotalTokenCount, + } + } + + return nil +} + +func tryOpenAIUsage(raw json.RawMessage) *TokenUsage { + // Standard OpenAI format: {"prompt_tokens":N,"completion_tokens":N,"total_tokens":N} + var oai openaiUsage + if err := json.Unmarshal(raw, &oai); err == nil && oai.TotalTokens > 0 { + return &TokenUsage{ + PromptTokens: oai.PromptTokens, + CompletionTokens: oai.CompletionTokens, + TotalTokens: oai.TotalTokens, + } + } + + // Anthropic format: {"input_tokens":N,"output_tokens":N} + // (Anthropic also nests under "usage") + var anthro anthropicUsage + if err := json.Unmarshal(raw, &anthro); err == nil && anthro.InputTokens > 0 { + return &TokenUsage{ + PromptTokens: anthro.InputTokens, + CompletionTokens: anthro.OutputTokens, + TotalTokens: anthro.InputTokens + anthro.OutputTokens, + } + } + + return nil +} + +// isChatCompletionPath returns true when the request path looks like a chat +// or text-generation endpoint where usage information is expected. +func isChatCompletionPath(path string) bool { + p := strings.ToLower(path) + return strings.Contains(p, "/chat/completions") || + strings.Contains(p, "/messages") || + strings.Contains(p, "/completions") || + strings.Contains(p, "/generate") +} diff --git a/internal/router/router.go b/internal/router/router.go index 22edd3dc6..bed9b194b 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -75,6 +75,7 @@ func NewRouter( // registerSystemRoutes 注册系统级路由 func registerSystemRoutes(router *gin.Engine, serverHandler *handler.Server) { router.GET("/health", serverHandler.Health) + router.GET("/metrics", serverHandler.Metrics) } // registerAPIRoutes 注册API路由 From 77f11ffc8d8e9e3a5f8bd6ff7929f74d88b3996a Mon Sep 17 00:00:00 2001 From: zhy Date: Wed, 20 May 2026 12:37:23 +0800 Subject: [PATCH 2/2] fix: address CodeRabbit review comments for PR #416 - Filter out streaming requests in /metrics query (is_stream = false) - Stream non-usage responses via io.Copy instead of buffering - Add CostUSD field to TokenUsage and assign in logRequest - Accept usage when any token field is non-zero, not just total - Add optional METRICS_TOKEN env var for /metrics endpoint auth Co-Authored-By: Claude Opus 4.7 --- internal/handler/metrics_handler.go | 2 +- internal/proxy/response_handlers.go | 32 +++++++++++++++-------------- internal/proxy/server.go | 1 + internal/proxy/usage.go | 15 ++++++++++---- internal/router/router.go | 31 +++++++++++++++++++++++++++- 5 files changed, 60 insertions(+), 21 deletions(-) diff --git a/internal/handler/metrics_handler.go b/internal/handler/metrics_handler.go index 4134ad655..00634ab51 100644 --- a/internal/handler/metrics_handler.go +++ b/internal/handler/metrics_handler.go @@ -38,7 +38,7 @@ func (s *Server) Metrics(c *gin.Context) { COALESCE(SUM(token_cost_usd), 0) as total_cost, COALESCE(SUM(prompt_tokens), 0) as total_prompt, COALESCE(SUM(completion_tokens), 0) as total_completion`). - Where("is_success = ?", true). + Where("is_success = ? AND is_stream = ?", true, false). Group("group_name, model"). Scan(&results).Error; err != nil { logrus.WithError(err).Error("Failed to query metrics") diff --git a/internal/proxy/response_handlers.go b/internal/proxy/response_handlers.go index 8767ccbb6..6391b673c 100644 --- a/internal/proxy/response_handlers.go +++ b/internal/proxy/response_handlers.go @@ -41,20 +41,29 @@ func (ps *ProxyServer) handleStreamingResponse(c *gin.Context, resp *http.Respon } } -// handleNormalResponse buffers the upstream response body, attempts to extract -// token usage metadata, writes the original body to the client, and returns -// the parsed usage (or nil when the body is not a JSON chat-completion response). +// handleNormalResponse buffers the upstream response body only when token usage +// extraction is needed (successful chat-completion responses). For all other +// non-stream responses it streams directly to the client via io.Copy to avoid +// buffering large payloads into memory. +// Note: response headers and status code are already set by the caller (HandleProxy). func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) *TokenUsage { + needsUsage := resp.StatusCode < 400 && isChatCompletionPath(c.Request.URL.Path) + + if !needsUsage { + // Stream directly to client — no buffering. + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + logUpstreamError("streaming response to client", err) + } + return nil + } + + // Buffer the body for usage extraction. body, err := io.ReadAll(resp.Body) if err != nil { logUpstreamError("reading response body", err) - if _, writeErr := c.Writer.Write(body); writeErr != nil { - logUpstreamError("writing buffered body to client", writeErr) - } return nil } - // Check for gzip encoding body = handleGzipCompression(resp, body) if _, writeErr := c.Writer.Write(body); writeErr != nil { @@ -62,12 +71,5 @@ func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) return nil } - // Only parse usage for successful chat-completion-like responses. - if resp.StatusCode < 400 && isChatCompletionPath(c.Request.URL.Path) { - if usage := extractTokenUsage(body); usage != nil { - return usage - } - } - - return nil + return extractTokenUsage(body) } diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 52a38ad9a..45239daf2 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -344,6 +344,7 @@ func (ps *ProxyServer) logRequest( logEntry.PromptTokens = usage.PromptTokens logEntry.CompletionTokens = usage.CompletionTokens logEntry.TotalTokens = usage.TotalTokens + logEntry.TokenCostUSD = usage.CostUSD } // Set parent group diff --git a/internal/proxy/usage.go b/internal/proxy/usage.go index ef7586c3f..b2a1be335 100644 --- a/internal/proxy/usage.go +++ b/internal/proxy/usage.go @@ -5,11 +5,12 @@ import ( "strings" ) -// TokenUsage holds token counts parsed from an API response. +// TokenUsage holds token counts and estimated cost parsed from an API response. type TokenUsage struct { PromptTokens int64 CompletionTokens int64 TotalTokens int64 + CostUSD float64 } // openaiUsage maps the "usage" object in OpenAI /v1/chat/completions responses. @@ -67,18 +68,24 @@ func extractTokenUsage(body []byte) *TokenUsage { func tryOpenAIUsage(raw json.RawMessage) *TokenUsage { // Standard OpenAI format: {"prompt_tokens":N,"completion_tokens":N,"total_tokens":N} var oai openaiUsage - if err := json.Unmarshal(raw, &oai); err == nil && oai.TotalTokens > 0 { + if err := json.Unmarshal(raw, &oai); err == nil && + (oai.TotalTokens > 0 || oai.PromptTokens > 0 || oai.CompletionTokens > 0) { + total := oai.TotalTokens + if total == 0 { + total = oai.PromptTokens + oai.CompletionTokens + } return &TokenUsage{ PromptTokens: oai.PromptTokens, CompletionTokens: oai.CompletionTokens, - TotalTokens: oai.TotalTokens, + TotalTokens: total, } } // Anthropic format: {"input_tokens":N,"output_tokens":N} // (Anthropic also nests under "usage") var anthro anthropicUsage - if err := json.Unmarshal(raw, &anthro); err == nil && anthro.InputTokens > 0 { + if err := json.Unmarshal(raw, &anthro); err == nil && + (anthro.InputTokens > 0 || anthro.OutputTokens > 0) { return &TokenUsage{ PromptTokens: anthro.InputTokens, CompletionTokens: anthro.OutputTokens, diff --git a/internal/router/router.go b/internal/router/router.go index bed9b194b..f05a7eeb3 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -1,6 +1,7 @@ package router import ( + "crypto/subtle" "embed" "gpt-load/internal/handler" "gpt-load/internal/i18n" @@ -10,6 +11,7 @@ import ( "gpt-load/internal/types" "io/fs" "net/http" + "os" "strings" "time" @@ -75,7 +77,34 @@ func NewRouter( // registerSystemRoutes 注册系统级路由 func registerSystemRoutes(router *gin.Engine, serverHandler *handler.Server) { router.GET("/health", serverHandler.Health) - router.GET("/metrics", serverHandler.Metrics) + + // /metrics is optionally protected by METRICS_TOKEN env var. + // When set, callers must provide Authorization: Bearer . + metricsHandler := serverHandler.Metrics + if token := os.Getenv("METRICS_TOKEN"); token != "" { + metricsHandler = metricsAuthMiddleware(token, metricsHandler) + } + router.GET("/metrics", metricsHandler) +} + +// metricsAuthMiddleware wraps a handler to require a Bearer token matching +// the configured METRICS_TOKEN. Returns 401 on missing or mismatched token. +func metricsAuthMiddleware(expectedToken string, next gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + auth := c.GetHeader("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + c.String(http.StatusUnauthorized, "missing or invalid authorization header\n") + c.Abort() + return + } + provided := strings.TrimPrefix(auth, "Bearer ") + if subtle.ConstantTimeCompare([]byte(provided), []byte(expectedToken)) != 1 { + c.String(http.StatusUnauthorized, "invalid metrics token\n") + c.Abort() + return + } + next(c) + } } // registerAPIRoutes 注册API路由