Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions providers/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,17 @@ func (g languageModel) prepareParams(call fantasy.Call) (*genai.GenerateContentC
if providerOptions.CachedContent != "" {
config.CachedContent = providerOptions.CachedContent
}
if providerOptions.ServiceTier != "" {
if providerOptions.ServiceTier == ServiceTierFlex && !supportsFlexProcessing(g.modelID) {
warnings = append(warnings, fantasy.CallWarning{
Type: fantasy.CallWarningTypeUnsupportedSetting,
Setting: "ServiceTier",
Details: "flex service tier is only available for Gemini 2.5 and Gemini 3 family models",
})
} else {
config.ServiceTier = genai.ServiceTier(providerOptions.ServiceTier)
}
}

if len(call.Tools) > 0 {
tools, toolChoice, toolWarnings := toGoogleTools(call.Tools, call.ToolChoice)
Expand Down Expand Up @@ -574,6 +585,13 @@ func (g *languageModel) Model() string {
return g.modelID
}

// supportsFlexProcessing reports whether the model supports the Flex service
// tier. Per Google's docs, Flex is available for the Gemini 2.5 and Gemini 3
// model families.
func supportsFlexProcessing(modelID string) bool {
return strings.Contains(modelID, "gemini-2.5") || strings.Contains(modelID, "gemini-3")
}

// Provider implements fantasy.LanguageModel.
func (g *languageModel) Provider() string {
return g.provider
Expand Down
20 changes: 20 additions & 0 deletions providers/google/provider_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ const (
ThinkingLevelMinimal ThinkingLevel = "MINIMAL"
)

// ServiceTier selects the pricing/performance tier for a request.
type ServiceTier = string

// Predefined service tiers for the Google provider. An empty value defaults to
// the standard tier.
const (
// ServiceTierFlex selects Flex processing: lower cost in exchange for
// variable latency and best-effort, sheddable availability.
ServiceTierFlex ServiceTier = "flex"
// ServiceTierStandard selects the default standard tier.
ServiceTierStandard ServiceTier = "standard"
// ServiceTierPriority selects the priority tier.
ServiceTierPriority ServiceTier = "priority"
)

// ThinkingConfig represents thinking configuration for the Google provider.
type ThinkingConfig struct {
ThinkingBudget *int64 `json:"thinking_budget,omitempty"`
Expand Down Expand Up @@ -107,6 +122,11 @@ type ProviderOptions struct {

// Optional. A list of unique safety settings for blocking unsafe content.
SafetySettings []SafetySetting `json:"safety_settings"`

// Optional. The service tier to use for the request, e.g. ServiceTierFlex.
// An empty value defaults to the standard tier.
ServiceTier string `json:"service_tier,omitempty"`

// 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
// 'BLOCK_LOW_AND_ABOVE',
// 'BLOCK_MEDIUM_AND_ABOVE',
Expand Down
95 changes: 95 additions & 0 deletions providers/google/service_tier_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package google

import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"

"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestServiceTier(t *testing.T) {
t.Parallel()

prompt := fantasy.Prompt{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "Hi"}},
},
}

// call runs a single Generate against a stub server and returns the
// decoded request body plus any warnings surfaced on the response.
call := func(t *testing.T, model string, opts *ProviderOptions) (map[string]any, []fantasy.CallWarning) {
t.Helper()

var body map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
raw, _ := io.ReadAll(r.Body)
_ = json.Unmarshal(raw, &body)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"candidates": []map[string]any{
{
"content": map[string]any{
"role": "model",
"parts": []map[string]any{{"text": "Hello"}},
},
"finishReason": "STOP",
},
},
"usageMetadata": map[string]any{
"promptTokenCount": 5,
"candidatesTokenCount": 2,
"totalTokenCount": 7,
},
})
}))
t.Cleanup(server.Close)

p, err := New(
WithVertex("test-project", "us-central1"),
WithBaseURL(server.URL),
WithSkipAuth(true),
)
require.NoError(t, err)
lm, err := p.LanguageModel(t.Context(), model)
require.NoError(t, err)

c := fantasy.Call{Prompt: prompt}
if opts != nil {
c.ProviderOptions = fantasy.ProviderOptions{Name: opts}
}
resp, err := lm.Generate(t.Context(), c)
require.NoError(t, err)
return body, resp.Warnings
}

t.Run("flex sent for supported model", func(t *testing.T) {
t.Parallel()
body, warnings := call(t, "gemini-2.5-flash", &ProviderOptions{ServiceTier: ServiceTierFlex})
assert.Equal(t, "flex", body["serviceTier"])
assert.Empty(t, warnings)
})

t.Run("omitted by default", func(t *testing.T) {
t.Parallel()
body, _ := call(t, "gemini-2.5-flash", nil)
_, ok := body["serviceTier"]
assert.False(t, ok)
})

t.Run("flex warns and is dropped for unsupported model", func(t *testing.T) {
t.Parallel()
body, warnings := call(t, "gemini-1.5-flash", &ProviderOptions{ServiceTier: ServiceTierFlex})
_, ok := body["serviceTier"]
assert.False(t, ok)
require.NotEmpty(t, warnings)
assert.Equal(t, fantasy.CallWarningTypeUnsupportedSetting, warnings[0].Type)
assert.Equal(t, "ServiceTier", warnings[0].Setting)
})
}
Loading