From cc0e058a23726176d033a88b23e05da8cdced268 Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Tue, 19 May 2026 18:03:19 +0530 Subject: [PATCH 1/2] feat(spanner): add dynamic channel pool --- spanner/batch.go | 2 +- spanner/client.go | 85 +- spanner/dynamic_channel_pool.go | 1326 ++++++++++++++++++++++++++ spanner/dynamic_channel_pool_test.go | 594 ++++++++++++ spanner/location_aware_client.go | 16 + spanner/read.go | 20 +- spanner/session.go | 23 +- spanner/sessionclient.go | 8 + spanner/transaction.go | 4 +- 9 files changed, 2034 insertions(+), 44 deletions(-) create mode 100644 spanner/dynamic_channel_pool.go create mode 100644 spanner/dynamic_channel_pool_test.go diff --git a/spanner/batch.go b/spanner/batch.go index ecf002f14004..d2d6f2d66561 100644 --- a/spanner/batch.go +++ b/spanner/batch.go @@ -370,7 +370,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R nil, t.setTimestamp, t.release, - asGRPCSpannerClient(client), + requestIDHeaderProviderFromSpannerClient(client), true, false, ) diff --git a/spanner/client.go b/spanner/client.go index 4d0ae2065848..4e391e406e6d 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -285,6 +285,9 @@ type ClientConfig struct { // SessionPoolConfig is the configuration for session pool. SessionPoolConfig + // DynamicChannelPoolConfig is the opt-in configuration for dynamic gRPC channel pooling. + DynamicChannelPoolConfig DynamicChannelPoolConfig + // SessionLabels for the sessions created by this client. // See https://cloud.google.com/spanner/docs/reference/rpc/google.spanner.v1#session // for more info. @@ -527,6 +530,7 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf var pool gtransport.ConnPool var endpointClientOpts []option.ClientOption + var sc *sessionClient isFallbackEnabled := true if val, ok := os.LookupEnv("GOOGLE_SPANNER_ENABLE_GCP_FALLBACK"); ok { @@ -534,7 +538,55 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf isFallbackEnabled = b } } - if gme != nil { + + // TODO(loite): Remove as the original map cannot be changed by the user + // anyways, and the client library is also not changing it. + // Make a copy of labels. + sessionLabels := make(map[string]string) + for k, v := range config.SessionLabels { + sessionLabels[k] = v + } + + md := metadata.Pairs(resourcePrefixHeader, database) + if config.Compression == gzip.Name { + md.Append(requestsCompressionHeader, gzip.Name) + } + // Append end to end tracing header if SPANNER_ENABLE_END_TO_END_TRACING + // environment variable has been set or client has passed the opt-in + // option in ClientConfig. + endToEndTracingEnvironmentVariable := os.Getenv("SPANNER_ENABLE_END_TO_END_TRACING") + if config.EnableEndToEndTracing || endToEndTracingEnvironmentVariable == "true" { + md.Append(endToEndTracingHeader, "true") + } + + if isAFEBuiltInMetricEnabled { + md.Append(afeMetricHeader, "true") + } + if config.BatchTimeout == 0 { + config.BatchTimeout = time.Minute + } + + dcpEnabled := config.DynamicChannelPoolConfig.DCPEnabled && gme == nil && !isExperimentalLocationAPIEnabledForConfig(config) && os.Getenv("SPANNER_EMULATOR_HOST") == "" + if dcpEnabled { + reqIDInjector := new(requestIDHeaderInjector) + dcpOpts := append([]option.ClientOption{}, opts...) + dcpOpts = append(dcpOpts, + option.WithGRPCDialOption(grpc.WithChainStreamInterceptor(reqIDInjector.interceptStream)), + option.WithGRPCDialOption(grpc.WithChainUnaryInterceptor(reqIDInjector.interceptUnary)), + ) + sc = newSessionClient(nil, database, config.UserAgent, sessionLabels, config.DatabaseRole, config.DisableRouteToLeader, md, config.BatchTimeout, config.Logger, config.CallOptions) + sc.metricsTracerFactory = metricsTracerFactory + dial := func(dialCtx context.Context) (gtransport.ConnPool, error) { + return gtransport.DialPool(dialCtx, allClientOpts(1, config.Compression, config.EnableDirectAccess, dcpOpts...)...) + } + dcp, err := newDynamicChannelPool(ctx, sc, config.DynamicChannelPoolConfig, 0, dial) + if err != nil { + return nil, err + } + pool = dcp + sc.connPool = pool + sc.dynamicPool = dcp + } else if gme != nil { // Use GCPMultiEndpoint if provided. pool = &gmeWrapper{gme} endpointClientOpts = append(endpointClientOpts, opts...) @@ -608,14 +660,6 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf } } - // TODO(loite): Remove as the original map cannot be changed by the user - // anyways, and the client library is also not changing it. - // Make a copy of labels. - sessionLabels := make(map[string]string) - for k, v := range config.SessionLabels { - sessionLabels[k] = v - } - // Default configs for session pool. if config.MaxOpened == 0 { config.MaxOpened = uint64(pool.Num() * 100) @@ -623,30 +667,13 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf if config.MaxBurst == 0 { config.MaxBurst = DefaultSessionPoolConfig.MaxBurst } - if config.BatchTimeout == 0 { - config.BatchTimeout = time.Minute - } - - md := metadata.Pairs(resourcePrefixHeader, database) - if config.Compression == gzip.Name { - md.Append(requestsCompressionHeader, gzip.Name) - } - // Append end to end tracing header if SPANNER_ENABLE_END_TO_END_TRACING - // environment variable has been set or client has passed the opt-in - // option in ClientConfig. - endToEndTracingEnvironmentVariable := os.Getenv("SPANNER_ENABLE_END_TO_END_TRACING") - if config.EnableEndToEndTracing || endToEndTracingEnvironmentVariable == "true" { - md.Append(endToEndTracingHeader, "true") - } - - if isAFEBuiltInMetricEnabled { - md.Append(afeMetricHeader, "true") - } // Multiplexed sessions are always enabled as the session pool has been removed. // Create a session client. - sc := newSessionClient(pool, database, config.UserAgent, sessionLabels, config.DatabaseRole, config.DisableRouteToLeader, md, config.BatchTimeout, config.Logger, config.CallOptions) + if sc == nil { + sc = newSessionClient(pool, database, config.UserAgent, sessionLabels, config.DatabaseRole, config.DisableRouteToLeader, md, config.BatchTimeout, config.Logger, config.CallOptions) + } // Create an OpenTelemetry configuration otConfig, err := createOpenTelemetryConfig(ctx, config.OpenTelemetryMeterProvider, config.Logger, sc.id, database) diff --git a/spanner/dynamic_channel_pool.go b/spanner/dynamic_channel_pool.go new file mode 100644 index 000000000000..db481b9ec3e6 --- /dev/null +++ b/spanner/dynamic_channel_pool.go @@ -0,0 +1,1326 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spanner + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "math/rand/v2" + "sort" + "sync" + "sync/atomic" + "time" + + vkit "cloud.google.com/go/spanner/apiv1" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/gax-go/v2" + gtransport "google.golang.org/api/transport/grpc" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +const ( + // dcpStateActive means the entry is eligible for new picks. + dcpStateActive int32 = iota + // dcpStateDraining means the entry was removed from the active slice and is + // only serving operations that already hold a reference to it. + dcpStateDraining + // dcpStateClosed means the entry has been closed and its metric slot returned. + dcpStateClosed +) + +// DynamicChannelSelectionStrategy controls how DCP chooses an active channel. +type DynamicChannelSelectionStrategy int + +const ( + // DCPPowerOfTwoLeastBusy compares two random active channels and returns the + // lower weighted-load channel. It falls back to a full scan if random picks + // only find draining entries. + DCPPowerOfTwoLeastBusy DynamicChannelSelectionStrategy = iota + // DCPRoundRobin cycles through active channels and skips draining entries. + DCPRoundRobin +) + +// DynamicChannelPoolConfig holds the knobs for Spanner dynamic channel pool. +// Zero values use DefaultDynamicChannelPoolConfig unless noted otherwise. +type DynamicChannelPoolConfig struct { + DCPEnabled bool // DCPEnabled opts the client into dynamic channel pool. + DCPInitialChannels int // DCPInitialChannels is the number of channels created at client startup. + DCPMinChannels int // DCPMinChannels is the lower bound retained during scale-down. + DCPMaxChannels int // DCPMaxChannels is the upper bound created during scale-up. + + // DCPMaxRPCPerChannel triggers event-driven scale-up when per-channel load or + // average load exceeds this value. + DCPMaxRPCPerChannel float64 + // DCPMinRPCPerChannel is the low-load threshold used by scale-down checks. + DCPMinRPCPerChannel float64 + + DCPScaleDownCheckInterval time.Duration // DCPScaleDownCheckInterval controls periodic downscale evaluation. + DCPScaleUpCooldown time.Duration // DCPScaleUpCooldown prevents repeated scale-up bursts. + DCPDownscaleConsecutiveLowLoadChecks int // DCPDownscaleConsecutiveLowLoadChecks debounces scale-down. + DCPMaxScaleUpPercent int // DCPMaxScaleUpPercent caps channels added per scale-up event. + DCPMaxRemoveChannels int // DCPMaxRemoveChannels caps channels marked draining per scale-down. + DCPDrainIdleGrace time.Duration // DCPDrainIdleGrace keeps an idle drained entry briefly before close. + DCPPrimeTimeout time.Duration // DCPPrimeTimeout bounds the SELECT 1 priming attempt for scaled-up channels. + DCPPrimeMaxAttempts int // DCPPrimeMaxAttempts bounds scaled-up channel priming retries. + DCPSelectionStrategy DynamicChannelSelectionStrategy +} + +// DefaultDynamicChannelPoolConfig returns the default DCP settings. +func DefaultDynamicChannelPoolConfig() DynamicChannelPoolConfig { + return DynamicChannelPoolConfig{ + DCPInitialChannels: 4, + DCPMinChannels: 4, + DCPMaxChannels: 256, + DCPMaxRPCPerChannel: 50, + DCPMinRPCPerChannel: 5, + DCPScaleDownCheckInterval: 30 * time.Second, + DCPScaleUpCooldown: 10 * time.Second, + DCPDownscaleConsecutiveLowLoadChecks: 3, + DCPMaxScaleUpPercent: 30, + DCPMaxRemoveChannels: 2, + DCPDrainIdleGrace: time.Minute, + DCPPrimeTimeout: 10 * time.Second, + DCPPrimeMaxAttempts: 3, + DCPSelectionStrategy: DCPPowerOfTwoLeastBusy, + } +} + +// normalizeDCPConfig fills zero-value knobs and validates internal consistency. +func normalizeDCPConfig(cfg DynamicChannelPoolConfig) (DynamicChannelPoolConfig, error) { + def := DefaultDynamicChannelPoolConfig() + initialChannelsSet := cfg.DCPInitialChannels != 0 + if cfg.DCPMinChannels == 0 { + cfg.DCPMinChannels = def.DCPMinChannels + } + if cfg.DCPInitialChannels == 0 { + cfg.DCPInitialChannels = def.DCPInitialChannels + if cfg.DCPInitialChannels < cfg.DCPMinChannels { + cfg.DCPInitialChannels = cfg.DCPMinChannels + } + } + if cfg.DCPMaxChannels == 0 { + cfg.DCPMaxChannels = def.DCPMaxChannels + } + if cfg.DCPMaxRPCPerChannel == 0 { + cfg.DCPMaxRPCPerChannel = def.DCPMaxRPCPerChannel + } + if cfg.DCPMinRPCPerChannel == 0 { + cfg.DCPMinRPCPerChannel = def.DCPMinRPCPerChannel + } + if cfg.DCPScaleDownCheckInterval == 0 { + cfg.DCPScaleDownCheckInterval = def.DCPScaleDownCheckInterval + } + if cfg.DCPScaleUpCooldown == 0 { + cfg.DCPScaleUpCooldown = def.DCPScaleUpCooldown + } + if cfg.DCPDownscaleConsecutiveLowLoadChecks == 0 { + cfg.DCPDownscaleConsecutiveLowLoadChecks = def.DCPDownscaleConsecutiveLowLoadChecks + } + if cfg.DCPMaxScaleUpPercent == 0 { + cfg.DCPMaxScaleUpPercent = def.DCPMaxScaleUpPercent + } + if cfg.DCPMaxRemoveChannels == 0 { + cfg.DCPMaxRemoveChannels = def.DCPMaxRemoveChannels + } + if cfg.DCPDrainIdleGrace == 0 { + cfg.DCPDrainIdleGrace = def.DCPDrainIdleGrace + } + if cfg.DCPPrimeTimeout == 0 { + cfg.DCPPrimeTimeout = def.DCPPrimeTimeout + } + if cfg.DCPPrimeMaxAttempts == 0 { + cfg.DCPPrimeMaxAttempts = def.DCPPrimeMaxAttempts + } + switch { + case cfg.DCPInitialChannels <= 0: + return cfg, fmt.Errorf("DCPInitialChannels must be positive") + case cfg.DCPMinChannels <= 0: + return cfg, fmt.Errorf("DCPMinChannels must be positive") + case cfg.DCPMaxChannels < cfg.DCPMinChannels: + return cfg, fmt.Errorf("DCPMaxChannels must be >= DCPMinChannels") + case initialChannelsSet && cfg.DCPInitialChannels < cfg.DCPMinChannels: + return cfg, fmt.Errorf("DCPInitialChannels must be >= DCPMinChannels when explicitly set") + case cfg.DCPInitialChannels > cfg.DCPMaxChannels: + return cfg, fmt.Errorf("DCPInitialChannels must be <= DCPMaxChannels") + case cfg.DCPMinRPCPerChannel >= cfg.DCPMaxRPCPerChannel: + return cfg, fmt.Errorf("DCPMinRPCPerChannel must be less than DCPMaxRPCPerChannel") + case cfg.DCPMaxScaleUpPercent <= 0 || cfg.DCPMaxScaleUpPercent > 100: + return cfg, fmt.Errorf("DCPMaxScaleUpPercent must be in (0,100]") + case cfg.DCPMaxRemoveChannels <= 0: + return cfg, fmt.Errorf("DCPMaxRemoveChannels must be positive") + case cfg.DCPSelectionStrategy != DCPPowerOfTwoLeastBusy && cfg.DCPSelectionStrategy != DCPRoundRobin: + return cfg, fmt.Errorf("DCPSelectionStrategy must be DCPPowerOfTwoLeastBusy or DCPRoundRobin") + } + return cfg, nil +} + +// dynamicChannelPool owns the copy-on-write slice of DCP entries and the +// background scaling/draining loops. +type dynamicChannelPool struct { + entries atomic.Pointer[[]*dcpEntry] + cfg DynamicChannelPoolConfig + targetRPCPerChannel float64 + + ctx context.Context + cancel context.CancelFunc + sc *sessionClient + database string + disableRouteToLeader bool + + dial func(context.Context) (gtransport.ConnPool, error) + rrIndex atomic.Uint64 + nextID atomic.Uint64 + totalRPCLoad atomic.Int32 + dialMu sync.Mutex + metricSlotMu sync.Mutex + freeMetricSlot []int64 + nextMetricSlot int64 + lastScaleUp atomic.Int64 + scaleUpSignal chan struct{} + done chan struct{} + stopOnce sync.Once + lowLoadRuns int + monitorMu sync.Mutex + primeSession atomic.Value // string + + drainingCount atomic.Int64 +} + +// dcpEntry represents one logical DCP slot. +type dcpEntry struct { + id uint64 + metricSlot int64 // bounded slot id used for metric cardinality + pool gtransport.ConnPool + delegate spannerClient + client spannerClient + parent *dynamicChannelPool + unaryLoad atomic.Int32 + streamLoad atomic.Int32 + errorCount atomic.Int64 // errors since process start; used for debug/diagnostics + state atomic.Int32 // dcpState* + createdAt atomic.Int64 // UnixNano creation time + lastActivity atomic.Int64 // UnixNano last pick/RPC/release time +} + +// newDynamicChannelPool creates the initial channel set and starts scale workers. +func newDynamicChannelPool(ctx context.Context, sc *sessionClient, cfg DynamicChannelPoolConfig, initial int, dial func(context.Context) (gtransport.ConnPool, error)) (*dynamicChannelPool, error) { + cfg, err := normalizeDCPConfig(cfg) + if err != nil { + return nil, err + } + if initial > 0 { + cfg.DCPInitialChannels = initial + } + poolCtx, cancel := context.WithCancel(ctx) + p := &dynamicChannelPool{ + cfg: cfg, + targetRPCPerChannel: math.Max(1, math.Floor((cfg.DCPMinRPCPerChannel+cfg.DCPMaxRPCPerChannel)/2)), + ctx: poolCtx, + cancel: cancel, + sc: sc, + database: sc.database, + disableRouteToLeader: sc.disableRouteToLeader, + dial: dial, + scaleUpSignal: make(chan struct{}, 1), + done: make(chan struct{}), + } + entries := make([]*dcpEntry, 0, cfg.DCPInitialChannels) + for i := 0; i < cfg.DCPInitialChannels; i++ { + e, err := p.newEntry(ctx, false) + if err != nil { + for _, entry := range entries { + entry.close() + } + cancel() + return nil, err + } + entries = append(entries, e) + } + p.entries.Store(&entries) + go p.scaleUpWorker() + go p.scaleDownMonitor() + return p, nil +} + +func (p *dynamicChannelPool) Num() int { return len(p.getEntries()) } +func (p *dynamicChannelPool) Conn() *grpc.ClientConn { + entries := p.getEntries() + if len(entries) == 0 { + return nil + } + return entries[0].pool.Conn() +} + +func (p *dynamicChannelPool) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + e, err := p.pick(ctx) + if err != nil { + return err + } + e.unaryLoad.Add(1) + p.totalRPCLoad.Add(1) + p.maybeSignalScaleUp(e) + e.lastActivity.Store(time.Now().UnixNano()) + defer func() { + e.unaryLoad.Add(-1) + p.totalRPCLoad.Add(-1) + e.lastActivity.Store(time.Now().UnixNano()) + }() + err = e.pool.Invoke(ctx, method, args, reply, opts...) + if err != nil { + e.errorCount.Add(1) + } + return err +} + +func (p *dynamicChannelPool) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + e, err := p.pick(ctx) + if err != nil { + return nil, err + } + e.streamLoad.Add(1) + p.totalRPCLoad.Add(1) + p.maybeSignalScaleUp(e) + e.lastActivity.Store(time.Now().UnixNano()) + stream, err := e.pool.NewStream(ctx, desc, method, opts...) + if err != nil { + e.streamLoad.Add(-1) + p.totalRPCLoad.Add(-1) + e.errorCount.Add(1) + return nil, err + } + return &dcpConnPoolTrackedStream{ClientStream: stream, entry: e}, nil +} + +func (p *dynamicChannelPool) Close() error { + p.stopOnce.Do(func() { p.cancel(); close(p.done) }) + p.dialMu.Lock() + defer p.dialMu.Unlock() + entries := p.getEntries() + p.entries.Store(&[]*dcpEntry{}) + var errs []error + for _, e := range entries { + if err := e.close(); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +func (p *dynamicChannelPool) getEntries() []*dcpEntry { + ptr := p.entries.Load() + if ptr == nil { + return nil + } + return *ptr +} + +// setPrimeSession records the multiplexed session used for scaled-up channel +// priming. Initial channels are created during client startup and are not +// primed through this path. +func (p *dynamicChannelPool) setPrimeSession(id string) { + if id != "" { + p.primeSession.Store(id) + select { + case p.scaleUpSignal <- struct{}{}: + default: + } + } +} + +// hasPrimeSession reports whether a scaled-up channel can be primed. +func (p *dynamicChannelPool) hasPrimeSession() bool { + v := p.primeSession.Load() + if v == nil { + return false + } + sid, _ := v.(string) + return sid != "" +} + +// allocateMetricSlot returns a bounded per-entry metric slot. Slots are reused +// after channel close to avoid unbounded OTel attribute cardinality. +func (p *dynamicChannelPool) allocateMetricSlot() (int64, error) { + p.metricSlotMu.Lock() + defer p.metricSlotMu.Unlock() + n := len(p.freeMetricSlot) + if n > 0 { + slot := p.freeMetricSlot[n-1] + p.freeMetricSlot = p.freeMetricSlot[:n-1] + return slot, nil + } + if p.nextMetricSlot < int64(p.cfg.DCPMaxChannels) { + p.nextMetricSlot++ + return p.nextMetricSlot, nil + } + return 0, spannerErrorf(codes.ResourceExhausted, "spanner_dcp: no metric slots available") +} + +// releaseMetricSlot returns a bounded metric slot for later reuse. +func (p *dynamicChannelPool) releaseMetricSlot(slot int64) { + if slot <= 0 { + return + } + p.metricSlotMu.Lock() + p.freeMetricSlot = append(p.freeMetricSlot, slot) + p.metricSlotMu.Unlock() +} + +// newEntry dials one DCP entry. +func (p *dynamicChannelPool) newEntry(ctx context.Context, prime bool) (*dcpEntry, error) { + id := p.nextID.Add(1) + metricSlot, err := p.allocateMetricSlot() + if err != nil { + return nil, err + } + entryPool, err := p.dial(ctx) + if err != nil { + p.releaseMetricSlot(metricSlot) + return nil, err + } + e := &dcpEntry{id: id, metricSlot: metricSlot, pool: entryPool, parent: p} + now := time.Now().UnixNano() + e.createdAt.Store(now) + e.lastActivity.Store(now) + client, err := newGRPCSpannerClient(ctx, p.sc, id, gtransport.WithConnPool(e)) + if err != nil { + entryPool.Close() + p.releaseMetricSlot(metricSlot) + return nil, err + } + e.delegate = client + e.client = &dcpSpannerClient{entry: e, delegate: client} + if prime { + if err := p.prime(ctx, e); err != nil { + e.close() + return nil, err + } + } + return e, nil +} + +// prime verifies a scaled-up channel before publishing it to the active slice. +// It uses SELECT 1 through the new entry's delegate so failed channels are never +// visible to normal request picking. +func (p *dynamicChannelPool) prime(ctx context.Context, e *dcpEntry) error { + v := p.primeSession.Load() + if v == nil { + return spannerErrorf(codes.FailedPrecondition, "spanner_dcp: cannot prime channel before multiplexed session is available") + } + sid, _ := v.(string) + if sid == "" { + return spannerErrorf(codes.FailedPrecondition, "spanner_dcp: cannot prime channel before multiplexed session is available") + } + stmt := &spannerpb.ExecuteSqlRequest{Session: sid, Sql: "SELECT 1"} + var last error + for i := 0; i < p.cfg.DCPPrimeMaxAttempts; i++ { + primeCtx, cancel := context.WithTimeout(ctx, p.cfg.DCPPrimeTimeout) + _, last = e.delegate.ExecuteSql(contextWithOutgoingMetadata(primeCtx, p.sc.md, p.disableRouteToLeader), stmt) + cancel() + if last == nil { + p.recordPrimeSuccess() + return nil + } + if i < p.cfg.DCPPrimeMaxAttempts-1 { + timer := time.NewTimer(time.Duration(100*(1< 0 { + avg = float64(p.totalRPCLoad.Load()) / float64(active) + } + if float64(e.rpcLoad()) <= p.cfg.DCPMaxRPCPerChannel && avg <= p.cfg.DCPMaxRPCPerChannel { + return + } + select { + case p.scaleUpSignal <- struct{}{}: + default: + } +} + +// scaleUpWorker serializes event-driven scale-up requests. +func (p *dynamicChannelPool) scaleUpWorker() { + for { + select { + case <-p.done: + return + case <-p.scaleUpSignal: + p.scaleUp() + } + } +} + +// scaleUp adds and primes channels based on current total load. The new entries +// are published only after successful dial and SELECT 1 priming. +func (p *dynamicChannelPool) scaleUp() { + select { + case <-p.done: + return + default: + } + now := time.Now() + last := time.Unix(0, p.lastScaleUp.Load()) + if !last.IsZero() && now.Sub(last) < p.cfg.DCPScaleUpCooldown { + return + } + p.dialMu.Lock() + defer p.dialMu.Unlock() + if p.ctx.Err() != nil { + return + } + if !p.hasPrimeSession() { + return + } + entries := p.getEntries() + active := 0 + var load int32 + for _, e := range entries { + if !e.isDraining() { + active++ + load += e.rpcLoad() + } + } + if active == 0 { + return + } + desired := int(math.Ceil(float64(load) / p.targetRPCPerChannel)) + add := desired - active + capPct := int(math.Ceil(float64(active) * float64(p.cfg.DCPMaxScaleUpPercent) / 100)) + if add > capPct { + add = capPct + } + if maxAdd := p.cfg.DCPMaxChannels - len(entries); add > maxAdd { + add = maxAdd + } + if add <= 0 { + return + } + newEntries := make([]*dcpEntry, 0, add) + for i := 0; i < add; i++ { + e, err := p.newEntry(p.ctx, true) + if err == nil { + newEntries = append(newEntries, e) + } else { + logf(p.sc.logger, "spanner_dcp: failed to create or prime scaled-up channel: %v", err) + } + } + if len(newEntries) == 0 { + return + } + combined := make([]*dcpEntry, 0, len(entries)+len(newEntries)) + combined = append(combined, entries...) + combined = append(combined, newEntries...) + p.entries.Store(&combined) + p.lastScaleUp.Store(now.UnixNano()) + p.recordScaleUp(len(newEntries)) +} + +// scaleDownMonitor periodically evaluates whether sustained low load can drain +// channels. +func (p *dynamicChannelPool) scaleDownMonitor() { + t := time.NewTicker(p.cfg.DCPScaleDownCheckInterval) + defer t.Stop() + for { + select { + case <-p.done: + return + case <-t.C: + p.evaluateScaleDown() + } + } +} + +// evaluateScaleDown debounces low-load observations before removing channels. +func (p *dynamicChannelPool) evaluateScaleDown() { + p.monitorMu.Lock() + defer p.monitorMu.Unlock() + entries := p.getEntries() + active := 0 + var load int32 + for _, e := range entries { + if !e.isDraining() { + active++ + load += e.rpcLoad() + } + } + if active == 0 { + return + } + avg := float64(load) / float64(active) + if avg > p.cfg.DCPMinRPCPerChannel { + p.lowLoadRuns = 0 + return + } + p.lowLoadRuns++ + if p.lowLoadRuns < p.cfg.DCPDownscaleConsecutiveLowLoadChecks { + return + } + p.lowLoadRuns = 0 + desired := int(math.Ceil(float64(load) / p.targetRPCPerChannel)) + if desired < p.cfg.DCPMinChannels { + desired = p.cfg.DCPMinChannels + } + remove := active - desired + if remove <= 0 { + return + } + if remove > p.cfg.DCPMaxRemoveChannels { + remove = p.cfg.DCPMaxRemoveChannels + } + p.removeEntries(remove) +} + +// removeEntries revalidates low load under dialMu, removes selected entries from +// the active slice, and starts graceful drain goroutines. +func (p *dynamicChannelPool) removeEntries(count int) { + p.dialMu.Lock() + entries := p.getEntries() + active := 0 + var load int32 + type candidate struct { + e *dcpEntry + created int64 + load int32 + } + candidates := make([]candidate, 0, len(entries)) + for _, e := range entries { + if !e.isDraining() { + active++ + load += e.rpcLoad() + candidates = append(candidates, candidate{e, e.createdAt.Load(), e.weightedLoad()}) + } + } + if active == 0 { + p.dialMu.Unlock() + return + } + avg := float64(load) / float64(active) + if avg > p.cfg.DCPMinRPCPerChannel { + p.dialMu.Unlock() + return + } + desired := int(math.Ceil(float64(load) / p.targetRPCPerChannel)) + if desired < p.cfg.DCPMinChannels { + desired = p.cfg.DCPMinChannels + } + recomputed := active - desired + if recomputed <= 0 { + p.dialMu.Unlock() + return + } + if count > recomputed { + count = recomputed + } + if count > active-p.cfg.DCPMinChannels { + count = active - p.cfg.DCPMinChannels + } + if count <= 0 { + p.dialMu.Unlock() + return + } + sort.Slice(candidates, func(i, j int) bool { + if candidates[i].load != candidates[j].load { + return candidates[i].load < candidates[j].load + } + return candidates[i].created < candidates[j].created + }) + toDrain := make(map[*dcpEntry]bool) + for i := 0; i < count && i < len(candidates); i++ { + candidates[i].e.state.Store(dcpStateDraining) + toDrain[candidates[i].e] = true + } + keep := make([]*dcpEntry, 0, len(entries)-len(toDrain)) + for _, e := range entries { + if !toDrain[e] { + keep = append(keep, e) + } + } + p.entries.Store(&keep) + p.dialMu.Unlock() + p.drainingCount.Add(int64(len(toDrain))) + p.recordScaleDown(len(toDrain)) + for e := range toDrain { + go p.waitForDrainAndClose(e) + } +} + +// waitForDrainAndClose waits until a draining entry has no RPC load and has +// been idle for DCPDrainIdleGrace. +func (p *dynamicChannelPool) waitForDrainAndClose(e *dcpEntry) { + start := time.Now() + t := time.NewTicker(250 * time.Millisecond) + defer t.Stop() + for { + select { + case <-t.C: + if e.rpcLoad() == 0 && time.Since(time.Unix(0, e.lastActivity.Load())) >= p.cfg.DCPDrainIdleGrace { + e.close() + p.drainingCount.Add(-1) + p.recordDrainWait(time.Since(start)) + return + } + case <-p.ctx.Done(): + if e.client != nil { + e.close() + } else if e.pool != nil { + e.pool.Close() + } + p.drainingCount.Add(-1) + p.recordDrainWait(time.Since(start)) + return + } + } +} + +func (e *dcpEntry) Conn() *grpc.ClientConn { return e.pool.Conn() } +func (e *dcpEntry) Num() int { return 1 } +func (e *dcpEntry) Close() error { return e.close() } + +func (e *dcpEntry) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + return e.pool.Invoke(ctx, method, args, reply, opts...) +} + +func (e *dcpEntry) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return e.pool.NewStream(ctx, desc, method, opts...) +} + +func (e *dcpEntry) close() error { + if !e.state.CompareAndSwap(dcpStateActive, dcpStateClosed) && !e.state.CompareAndSwap(dcpStateDraining, dcpStateClosed) { + return nil + } + var errs []error + if e.client != nil { + errs = append(errs, e.client.Close()) + } + if e.pool != nil { + errs = append(errs, e.pool.Close()) + } + e.parent.releaseMetricSlot(e.metricSlot) + return errors.Join(errs...) +} + +// isDraining atomically checks whether the entry has been removed from normal +// selection and is waiting for in-flight operations to finish. +func (e *dcpEntry) isDraining() bool { return e.state.Load() == dcpStateDraining } + +// rpcLoad returns the current in-flight RPC load for this entry. +func (e *dcpEntry) rpcLoad() int32 { return e.unaryLoad.Load() + e.streamLoad.Load() } + +// weightedLoad returns the current in-flight RPC load for this entry. +func (e *dcpEntry) weightedLoad() int32 { return e.rpcLoad() } + +func (e *dcpEntry) applyPenalty(ctx context.Context, err error) {} + +func (p *dynamicChannelPool) recordScaleUp(added int) {} + +func (p *dynamicChannelPool) recordScaleDown(draining int) {} + +func (p *dynamicChannelPool) recordDrainWait(d time.Duration) {} + +func (p *dynamicChannelPool) recordSelection(ctx context.Context, e *dcpEntry) {} + +func (p *dynamicChannelPool) recordErrorPenalty(ctx context.Context) {} + +func (p *dynamicChannelPool) recordPrimeSuccess() {} + +func (p *dynamicChannelPool) recordPrimeFailure() {} + +type dcpSpannerClient struct { + entry *dcpEntry + delegate spannerClient +} + +func (c *dcpSpannerClient) CallOptions() *vkit.CallOptions { return c.delegate.CallOptions() } +func (c *dcpSpannerClient) Close() error { return c.delegate.Close() } +func (c *dcpSpannerClient) Connection() *grpc.ClientConn { return c.delegate.Connection() } + +func (c *dcpSpannerClient) startUnary(ctx context.Context) func(error) { + c.entry.unaryLoad.Add(1) + c.entry.parent.totalRPCLoad.Add(1) + c.entry.parent.maybeSignalScaleUp(c.entry) + c.entry.lastActivity.Store(time.Now().UnixNano()) + return func(err error) { + c.entry.unaryLoad.Add(-1) + c.entry.parent.totalRPCLoad.Add(-1) + c.entry.lastActivity.Store(time.Now().UnixNano()) + if err != nil { + c.entry.errorCount.Add(1) + } + } +} + +type dcpStreamRef struct { + once sync.Once + finish func(error) + closed chan struct{} +} + +func (r *dcpStreamRef) done(err error) { + r.once.Do(func() { + r.finish(err) + close(r.closed) + }) +} + +func (c *dcpSpannerClient) startStream(ctx context.Context) *dcpStreamRef { + c.entry.streamLoad.Add(1) + c.entry.parent.totalRPCLoad.Add(1) + c.entry.parent.maybeSignalScaleUp(c.entry) + c.entry.lastActivity.Store(time.Now().UnixNano()) + ref := &dcpStreamRef{closed: make(chan struct{}), finish: func(err error) { + c.entry.streamLoad.Add(-1) + c.entry.parent.totalRPCLoad.Add(-1) + c.entry.lastActivity.Store(time.Now().UnixNano()) + if err != nil && !errors.Is(err, io.EOF) { + c.entry.errorCount.Add(1) + } + }} + if ctx != nil && ctx.Done() != nil { + go func() { + select { + case <-ctx.Done(): + ref.done(ctx.Err()) + case <-ref.closed: + } + }() + } + return ref +} + +func (c *dcpSpannerClient) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest, opts ...gax.CallOption) (*spannerpb.Session, error) { + done := c.startUnary(ctx) + resp, err := c.delegate.CreateSession(ctx, req, opts...) + done(err) + return resp, err +} + +func (c *dcpSpannerClient) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest, opts ...gax.CallOption) (*spannerpb.BatchCreateSessionsResponse, error) { + done := c.startUnary(ctx) + resp, err := c.delegate.BatchCreateSessions(ctx, req, opts...) + done(err) + return resp, err +} + +func (c *dcpSpannerClient) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest, opts ...gax.CallOption) (*spannerpb.Session, error) { + done := c.startUnary(ctx) + resp, err := c.delegate.GetSession(ctx, req, opts...) + done(err) + return resp, err +} + +func (c *dcpSpannerClient) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest, opts ...gax.CallOption) *vkit.SessionIterator { + iter := c.delegate.ListSessions(ctx, req, opts...) + if iter != nil && iter.InternalFetch != nil { + fetch := iter.InternalFetch + iter.InternalFetch = func(pageSize int, pageToken string) ([]*spannerpb.Session, string, error) { + done := c.startUnary(ctx) + results, nextPageToken, err := fetch(pageSize, pageToken) + done(err) + return results, nextPageToken, err + } + } + return iter +} + +func (c *dcpSpannerClient) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest, opts ...gax.CallOption) error { + done := c.startUnary(ctx) + err := c.delegate.DeleteSession(ctx, req, opts...) + done(err) + return err +} + +func (c *dcpSpannerClient) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (*spannerpb.ResultSet, error) { + done := c.startUnary(ctx) + resp, err := c.delegate.ExecuteSql(ctx, req, opts...) + done(err) + return resp, err +} + +func (c *dcpSpannerClient) ExecuteStreamingSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (spannerpb.Spanner_ExecuteStreamingSqlClient, error) { + ref := c.startStream(ctx) + stream, err := c.delegate.ExecuteStreamingSql(ctx, req, opts...) + if err != nil { + ref.done(err) + return nil, err + } + return &dcpExecuteStreamingSqlClient{Spanner_ExecuteStreamingSqlClient: stream, ref: ref}, nil +} + +func (c *dcpSpannerClient) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest, opts ...gax.CallOption) (*spannerpb.ExecuteBatchDmlResponse, error) { + done := c.startUnary(ctx) + resp, err := c.delegate.ExecuteBatchDml(ctx, req, opts...) + done(err) + return resp, err +} + +func (c *dcpSpannerClient) Read(ctx context.Context, req *spannerpb.ReadRequest, opts ...gax.CallOption) (*spannerpb.ResultSet, error) { + done := c.startUnary(ctx) + resp, err := c.delegate.Read(ctx, req, opts...) + done(err) + return resp, err +} + +func (c *dcpSpannerClient) StreamingRead(ctx context.Context, req *spannerpb.ReadRequest, opts ...gax.CallOption) (spannerpb.Spanner_StreamingReadClient, error) { + ref := c.startStream(ctx) + stream, err := c.delegate.StreamingRead(ctx, req, opts...) + if err != nil { + ref.done(err) + return nil, err + } + return &dcpStreamingReadClient{Spanner_StreamingReadClient: stream, ref: ref}, nil +} + +func (c *dcpSpannerClient) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest, opts ...gax.CallOption) (*spannerpb.Transaction, error) { + done := c.startUnary(ctx) + resp, err := c.delegate.BeginTransaction(ctx, req, opts...) + done(err) + return resp, err +} + +func (c *dcpSpannerClient) Commit(ctx context.Context, req *spannerpb.CommitRequest, opts ...gax.CallOption) (*spannerpb.CommitResponse, error) { + done := c.startUnary(ctx) + resp, err := c.delegate.Commit(ctx, req, opts...) + done(err) + return resp, err +} + +func (c *dcpSpannerClient) Rollback(ctx context.Context, req *spannerpb.RollbackRequest, opts ...gax.CallOption) error { + done := c.startUnary(ctx) + err := c.delegate.Rollback(ctx, req, opts...) + done(err) + return err +} + +func (c *dcpSpannerClient) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest, opts ...gax.CallOption) (*spannerpb.PartitionResponse, error) { + done := c.startUnary(ctx) + resp, err := c.delegate.PartitionQuery(ctx, req, opts...) + done(err) + return resp, err +} + +func (c *dcpSpannerClient) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest, opts ...gax.CallOption) (*spannerpb.PartitionResponse, error) { + done := c.startUnary(ctx) + resp, err := c.delegate.PartitionRead(ctx, req, opts...) + done(err) + return resp, err +} + +func (c *dcpSpannerClient) BatchWrite(ctx context.Context, req *spannerpb.BatchWriteRequest, opts ...gax.CallOption) (spannerpb.Spanner_BatchWriteClient, error) { + ref := c.startStream(ctx) + stream, err := c.delegate.BatchWrite(ctx, req, opts...) + if err != nil { + ref.done(err) + return nil, err + } + return &dcpBatchWriteClient{Spanner_BatchWriteClient: stream, ref: ref}, nil +} + +type dcpExecuteStreamingSqlClient struct { + spannerpb.Spanner_ExecuteStreamingSqlClient + ref *dcpStreamRef +} + +func (c *dcpExecuteStreamingSqlClient) Recv() (*spannerpb.PartialResultSet, error) { + resp, err := c.Spanner_ExecuteStreamingSqlClient.Recv() + if err != nil { + c.ref.done(err) + } + return resp, err +} + +func (c *dcpExecuteStreamingSqlClient) CloseSend() error { + err := c.Spanner_ExecuteStreamingSqlClient.CloseSend() + if err != nil { + c.ref.done(err) + } + return err +} + +type dcpStreamingReadClient struct { + spannerpb.Spanner_StreamingReadClient + ref *dcpStreamRef +} + +func (c *dcpStreamingReadClient) Recv() (*spannerpb.PartialResultSet, error) { + resp, err := c.Spanner_StreamingReadClient.Recv() + if err != nil { + c.ref.done(err) + } + return resp, err +} + +func (c *dcpStreamingReadClient) CloseSend() error { + err := c.Spanner_StreamingReadClient.CloseSend() + if err != nil { + c.ref.done(err) + } + return err +} + +type dcpBatchWriteClient struct { + spannerpb.Spanner_BatchWriteClient + ref *dcpStreamRef +} + +func (c *dcpBatchWriteClient) Recv() (*spannerpb.BatchWriteResponse, error) { + resp, err := c.Spanner_BatchWriteClient.Recv() + if err != nil { + c.ref.done(err) + } + return resp, err +} + +func (c *dcpBatchWriteClient) CloseSend() error { + err := c.Spanner_BatchWriteClient.CloseSend() + if err != nil { + c.ref.done(err) + } + return err +} + +type dcpResolvingSpannerClient struct { + pool *dynamicChannelPool + entryID atomic.Uint64 +} + +func newDCPResolvingSpannerClient(pool *dynamicChannelPool, entryID uint64) *dcpResolvingSpannerClient { + c := &dcpResolvingSpannerClient{pool: pool} + c.entryID.Store(entryID) + return c +} + +func (c *dcpResolvingSpannerClient) resolve(ctx context.Context) (spannerClient, error) { + if c == nil || c.pool == nil { + return nil, errDCPNoEntries + } + if e := c.pool.lookupActive(c.entryID.Load()); e != nil { + c.pool.recordSelection(ctx, e) + e.lastActivity.Store(time.Now().UnixNano()) + return e.client, nil + } + e, err := c.pool.pick(ctx) + if err != nil { + return nil, err + } + c.entryID.Store(e.id) + return e.client, nil +} + +func (c *dcpResolvingSpannerClient) CallOptions() *vkit.CallOptions { + client, err := c.resolve(context.Background()) + if err != nil || client == nil { + return &vkit.CallOptions{} + } + return client.CallOptions() +} + +func (c *dcpResolvingSpannerClient) Close() error { return nil } + +func (c *dcpResolvingSpannerClient) Connection() *grpc.ClientConn { + client, err := c.resolve(context.Background()) + if err != nil || client == nil { + return nil + } + return client.Connection() +} + +func (c *dcpResolvingSpannerClient) generateRequestIDHeaderInjector() *requestIDWrap { + client, err := c.resolve(context.Background()) + if err != nil || client == nil { + return nil + } + gsc := asGRPCSpannerClient(client) + if gsc == nil { + return nil + } + return gsc.generateRequestIDHeaderInjector() +} + +func (c *dcpResolvingSpannerClient) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest, opts ...gax.CallOption) (*spannerpb.Session, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.CreateSession(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest, opts ...gax.CallOption) (*spannerpb.BatchCreateSessionsResponse, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.BatchCreateSessions(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest, opts ...gax.CallOption) (*spannerpb.Session, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.GetSession(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest, opts ...gax.CallOption) *vkit.SessionIterator { + client, err := c.resolve(ctx) + if err != nil { + return &vkit.SessionIterator{} + } + return client.ListSessions(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest, opts ...gax.CallOption) error { + client, err := c.resolve(ctx) + if err != nil { + return err + } + return client.DeleteSession(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (*spannerpb.ResultSet, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.ExecuteSql(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) ExecuteStreamingSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (spannerpb.Spanner_ExecuteStreamingSqlClient, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.ExecuteStreamingSql(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest, opts ...gax.CallOption) (*spannerpb.ExecuteBatchDmlResponse, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.ExecuteBatchDml(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) Read(ctx context.Context, req *spannerpb.ReadRequest, opts ...gax.CallOption) (*spannerpb.ResultSet, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.Read(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) StreamingRead(ctx context.Context, req *spannerpb.ReadRequest, opts ...gax.CallOption) (spannerpb.Spanner_StreamingReadClient, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.StreamingRead(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest, opts ...gax.CallOption) (*spannerpb.Transaction, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.BeginTransaction(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) Commit(ctx context.Context, req *spannerpb.CommitRequest, opts ...gax.CallOption) (*spannerpb.CommitResponse, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.Commit(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) Rollback(ctx context.Context, req *spannerpb.RollbackRequest, opts ...gax.CallOption) error { + client, err := c.resolve(ctx) + if err != nil { + return err + } + return client.Rollback(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest, opts ...gax.CallOption) (*spannerpb.PartitionResponse, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.PartitionQuery(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest, opts ...gax.CallOption) (*spannerpb.PartitionResponse, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.PartitionRead(ctx, req, opts...) +} + +func (c *dcpResolvingSpannerClient) BatchWrite(ctx context.Context, req *spannerpb.BatchWriteRequest, opts ...gax.CallOption) (spannerpb.Spanner_BatchWriteClient, error) { + client, err := c.resolve(ctx) + if err != nil { + return nil, err + } + return client.BatchWrite(ctx, req, opts...) +} diff --git a/spanner/dynamic_channel_pool_test.go b/spanner/dynamic_channel_pool_test.go new file mode 100644 index 000000000000..d3c6ce0affdc --- /dev/null +++ b/spanner/dynamic_channel_pool_test.go @@ -0,0 +1,594 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spanner + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "golang.org/x/sync/errgroup" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + gtransport "google.golang.org/api/transport/grpc" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + . "cloud.google.com/go/spanner/internal/testutil" +) + +func testDCPConfig(initial, min, max int) DynamicChannelPoolConfig { + return DynamicChannelPoolConfig{ + DCPEnabled: true, + DCPInitialChannels: initial, + DCPMinChannels: min, + DCPMaxChannels: max, + DCPMaxRPCPerChannel: 1, + DCPMinRPCPerChannel: 0.5, + DCPScaleDownCheckInterval: 20 * time.Millisecond, + DCPScaleUpCooldown: time.Millisecond, + DCPDownscaleConsecutiveLowLoadChecks: 1, + DCPMaxScaleUpPercent: 100, + DCPMaxRemoveChannels: max, + DCPDrainIdleGrace: 10 * time.Millisecond, + DCPPrimeTimeout: time.Second, + DCPPrimeMaxAttempts: 3, + } +} + +func setupDCPMockedTestServer(t *testing.T, dcp DynamicChannelPoolConfig) (*MockedSpannerInMemTestServer, *Client, func()) { + t.Helper() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + DisableNativeMetrics: true, + DynamicChannelPoolConfig: dcp, + }) + addSelect1Result(server) + if client.sc.dynamicPool == nil { + teardown() + t.Fatal("dynamic channel pool not enabled") + } + return server, client, teardown +} + +func drainDCPQuery(ctx context.Context, client *Client) error { + iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + return nil + } + if err != nil { + return err + } + } +} + +func TestDynamicChannelPoolOptInCreatesInitialChannels(t *testing.T) { + _, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(2, 1, 4)) + defer teardown() + + if got, want := client.sc.dynamicPool.Num(), 2; got != want { + t.Fatalf("DCP initial channel count = %d, want %d", got, want) + } +} + +func TestDynamicChannelPoolScaleUpPrimesNewChannels(t *testing.T) { + server, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(1, 1, 4)) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, SimulatedExecutionTime{MinimumExecutionTime: 2 * time.Second}) + if got := len(server.TestSpanner.DumpPings()); got != 0 { + t.Fatalf("initial DCP channel priming count = %d, want 0", got) + } + + ctx := context.Background() + var g errgroup.Group + for i := 0; i < 3; i++ { + g.Go(func() error { return drainDCPQuery(ctx, client) }) + } + + waitFor(t, func() error { + if got := client.sc.dynamicPool.Num(); got <= 1 { + return fmt.Errorf("DCP channel count = %d, want > 1", got) + } + if got := len(server.TestSpanner.DumpPings()); got == 0 { + return fmt.Errorf("DCP scale-up priming SELECT 1 count = %d, want > 0", got) + } + return nil + }) + if err := g.Wait(); err != nil { + t.Fatalf("query workload failed: %v", err) + } +} + +func TestDynamicChannelPoolScaleDownRemovesIdleChannelsToMin(t *testing.T) { + cfg := testDCPConfig(3, 1, 3) + cfg.DCPDrainIdleGrace = 200 * time.Millisecond + _, client, teardown := setupDCPMockedTestServer(t, cfg) + defer teardown() + + if err := drainDCPQuery(context.Background(), client); err != nil { + t.Fatalf("query failed: %v", err) + } + waitFor(t, func() error { + if got, want := client.sc.dynamicPool.Num(), 1; got != want { + return fmt.Errorf("DCP channel count after scale-down = %d, want %d", got, want) + } + if got := client.sc.dynamicPool.drainingCount.Load(); got == 0 { + return fmt.Errorf("DCP draining channel count = %d, want > 0 during drain grace", got) + } + return nil + }) + waitFor(t, func() error { + if got := client.sc.dynamicPool.drainingCount.Load(); got != 0 { + return fmt.Errorf("DCP draining channel count after grace = %d, want 0", got) + } + return nil + }) +} + +func TestDynamicChannelPoolScaleDownRequiresRepeatedLowLoad(t *testing.T) { + _, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(3, 1, 3)) + defer teardown() + p := client.sc.dynamicPool + p.cfg.DCPDownscaleConsecutiveLowLoadChecks = 2 + + p.evaluateScaleDown() + if got, want := p.Num(), 3; got != want { + t.Fatalf("DCP channel count after first low-load check = %d, want %d", got, want) + } + p.evaluateScaleDown() + waitFor(t, func() error { + if got, want := p.Num(), 1; got != want { + return fmt.Errorf("DCP channel count after repeated low-load checks = %d, want %d", got, want) + } + return nil + }) +} + +func TestDynamicChannelPoolPickerSkipsDrainingEntries(t *testing.T) { + _, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(3, 3, 3)) + defer teardown() + p := client.sc.dynamicPool + entries := p.getEntries() + for _, e := range entries[:2] { + e.state.Store(dcpStateDraining) + } + for i := 0; i < 20; i++ { + e, err := p.pick(context.Background()) + if err != nil { + t.Fatalf("pick failed: %v", err) + } + if e != entries[2] { + t.Fatalf("picker returned draining entry %d, want active entry %d", e.id, entries[2].id) + } + } +} + +func TestDynamicChannelPoolRoundRobinSkipsDrainingEntries(t *testing.T) { + _, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(3, 3, 3)) + defer teardown() + p := client.sc.dynamicPool + p.cfg.DCPSelectionStrategy = DCPRoundRobin + entries := p.getEntries() + entries[1].state.Store(dcpStateDraining) + + var got []uint64 + for i := 0; i < 4; i++ { + e, err := p.pick(context.Background()) + if err != nil { + t.Fatalf("pick failed: %v", err) + } + got = append(got, e.id) + } + want := []uint64{entries[0].id, entries[2].id, entries[0].id, entries[2].id} + for i := range want { + if got[i] != want[i] { + t.Fatalf("round-robin sequence = %v, want %v", got, want) + } + } +} + +func TestDynamicChannelPoolMaxChannelsCapsScaleUp(t *testing.T) { + server, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(1, 1, 2)) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, SimulatedExecutionTime{MinimumExecutionTime: 300 * time.Millisecond}) + + var g errgroup.Group + for i := 0; i < 8; i++ { + g.Go(func() error { return drainDCPQuery(context.Background(), client) }) + } + waitFor(t, func() error { + if got, want := client.sc.dynamicPool.Num(), 2; got != want { + return fmt.Errorf("DCP channel count under load = %d, want %d", got, want) + } + return nil + }) + if err := g.Wait(); err != nil { + t.Fatalf("query workload failed: %v", err) + } + if got, max := client.sc.dynamicPool.Num(), 2; got > max { + t.Fatalf("DCP channel count = %d, want <= %d", got, max) + } +} + +func TestDynamicChannelPoolLocationAwareDisablesDCP(t *testing.T) { + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + DisableNativeMetrics: true, + IsExperimentalHost: true, + DynamicChannelPoolConfig: testDCPConfig(1, 1, 2), + }) + defer teardown() + if client.sc.dynamicPool != nil { + t.Fatal("DCP enabled with location-aware routing, want disabled") + } +} + +type fakeDCPConnPool struct { + invokeErr error + invokeCount int + closed bool +} + +func (f *fakeDCPConnPool) Conn() *grpc.ClientConn { return nil } +func (f *fakeDCPConnPool) Num() int { return 1 } +func (f *fakeDCPConnPool) Close() error { + f.closed = true + return nil +} +func (f *fakeDCPConnPool) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + f.invokeCount++ + return f.invokeErr +} +func (f *fakeDCPConnPool) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return nil, f.invokeErr +} + +func TestDynamicChannelPoolScaleUpPrimeFailureDoesNotPublishEntry(t *testing.T) { + server, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(1, 1, 2)) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, SimulatedExecutionTime{MinimumExecutionTime: 300 * time.Millisecond}) + server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{ + Errors: []error{status.Error(codes.Internal, "prime failed")}, + KeepError: true, + }) + + var g errgroup.Group + for i := 0; i < 3; i++ { + g.Go(func() error { return drainDCPQuery(context.Background(), client) }) + } + waitFor(t, func() error { + if got := client.sc.dynamicPool.totalRPCLoad.Load(); got == 0 { + return fmt.Errorf("DCP total RPC load = %d, want in-flight workload", got) + } + return nil + }) + client.sc.dynamicPool.scaleUp() + if got, want := client.sc.dynamicPool.Num(), 1; got != want { + t.Fatalf("DCP channel count after failed prime = %d, want %d", got, want) + } + for _, e := range client.sc.dynamicPool.getEntries() { + if e.state.Load() != dcpStateActive { + t.Fatalf("active slice contains non-active entry state=%d", e.state.Load()) + } + } + if _, err := client.sc.dynamicPool.pick(context.Background()); err != nil { + t.Fatalf("pick after failed scale-up failed: %v", err) + } + if err := g.Wait(); err != nil { + t.Fatalf("query workload failed: %v", err) + } +} + +func TestDynamicChannelPoolPowerOfTwoPrefersLeastLoadedEntry(t *testing.T) { + _, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(3, 3, 3)) + defer teardown() + p := client.sc.dynamicPool + entries := p.getEntries() + entries[1].unaryLoad.Store(100) + entries[2].streamLoad.Store(100) + + counts := map[uint64]int{} + for i := 0; i < 2000; i++ { + e, err := p.pick(context.Background()) + if err != nil { + t.Fatalf("pick failed: %v", err) + } + counts[e.id]++ + } + low := counts[entries[0].id] + high := counts[entries[1].id] + counts[entries[2].id] + if low <= high { + t.Fatalf("least-loaded entry picked %d times, higher-load entries picked %d times; want least-loaded preference", low, high) + } +} + +func TestDynamicChannelPoolCloseClosesActiveAndDrainingEntries(t *testing.T) { + _, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(3, 3, 3)) + defer teardown() + p := client.sc.dynamicPool + entries := append([]*dcpEntry(nil), p.getEntries()...) + entries[1].state.Store(dcpStateDraining) + p.drainingCount.Add(1) + + client.Close() + if got := p.Num(); got != 0 { + t.Fatalf("DCP pool entries after close = %d, want 0", got) + } + for _, e := range entries { + if got := e.state.Load(); got != dcpStateClosed { + t.Fatalf("entry %d state after close = %d, want closed", e.id, got) + } + } +} + +func TestDynamicChannelPoolRequestIDUsesEntryChannelID(t *testing.T) { + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + dcpConfig := testDCPConfig(1, 1, 3) + dcpConfig.DCPSelectionStrategy = DCPRoundRobin + server, client, teardown := setupMockedTestServerWithConfigAndClientOptions(t, ClientConfig{ + DisableNativeMetrics: true, + DynamicChannelPoolConfig: dcpConfig, + }, clientOpts) + defer teardown() + addSelect1Result(server) + server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, SimulatedExecutionTime{MinimumExecutionTime: 300 * time.Millisecond}) + + var g errgroup.Group + for i := 0; i < 4; i++ { + g.Go(func() error { return drainDCPQuery(context.Background(), client) }) + } + waitFor(t, func() error { + if got := client.sc.dynamicPool.Num(); got <= 1 { + return fmt.Errorf("DCP channel count = %d, want scale-up", got) + } + return nil + }) + // Run enough post-scale-up public queries to cycle through the active entries + // and observe the newly added DCP channel id. + for i := 0; i < client.sc.dynamicPool.Num(); i++ { + if err := drainDCPQuery(context.Background(), client); err != nil { + t.Fatalf("post-scale-up query failed: %v", err) + } + } + if err := g.Wait(); err != nil { + t.Fatalf("query workload failed: %v", err) + } + + observedChannelIDs := map[uint32]bool{} + for _, segments := range interceptorTracker.streamClientRequestIDSegments { + if segments.ChannelID == 0 { + t.Fatal("request id channel id is zero") + } + observedChannelIDs[segments.ChannelID] = true + } + if len(observedChannelIDs) <= 1 { + t.Fatalf("distinct DCP request-id channel ids = %v, want cardinality growth after scale-up", observedChannelIDs) + } + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestDynamicChannelPoolFullScanFallbackFindsOnlyActiveEntry(t *testing.T) { + _, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(4, 4, 4)) + defer teardown() + p := client.sc.dynamicPool + entries := p.getEntries() + for _, e := range entries[:3] { + e.state.Store(dcpStateDraining) + } + entries[3].unaryLoad.Store(7) + + e, err := p.pickLeastLoaded() + if err != nil { + t.Fatalf("pickLeastLoaded failed: %v", err) + } + if e != entries[3] { + t.Fatalf("full-scan fallback returned entry %d, want only active entry %d", e.id, entries[3].id) + } + picked, err := p.pick(context.Background()) + if err != nil { + t.Fatalf("pick fallback failed: %v", err) + } + if picked != entries[3] { + t.Fatalf("power-of-two fallback returned entry %d, want only active entry %d", picked.id, entries[3].id) + } +} + +func TestDCPResolvingClientRebindsDrainingEntry(t *testing.T) { + p := &dynamicChannelPool{cfg: testDCPConfig(2, 1, 2)} + entry1 := &dcpEntry{id: 1, client: &mockSpannerClient{}, parent: p} + entry2 := &dcpEntry{id: 2, client: &mockSpannerClient{}, parent: p} + entry1.state.Store(dcpStateActive) + entry2.state.Store(dcpStateActive) + entries := []*dcpEntry{entry1, entry2} + p.entries.Store(&entries) + + resolver := newDCPResolvingSpannerClient(p, entry1.id) + entry1.state.Store(dcpStateDraining) + + client, err := resolver.resolve(context.Background()) + if err != nil { + t.Fatalf("resolve failed: %v", err) + } + if client != entry2.client { + t.Fatalf("resolved client = %p, want entry2 client %p", client, entry2.client) + } + if got, want := resolver.entryID.Load(), entry2.id; got != want { + t.Fatalf("resolver entry id = %d, want %d", got, want) + } +} + +func TestDynamicChannelPoolDrainWaitsForActiveStreamLoad(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + p := &dynamicChannelPool{cfg: testDCPConfig(1, 1, 1), ctx: ctx} + p.cfg.DCPDrainIdleGrace = 10 * time.Millisecond + entry := &dcpEntry{id: 1, pool: &fakeDCPConnPool{}, client: &mockSpannerClient{}, parent: p} + entry.state.Store(dcpStateDraining) + entry.streamLoad.Store(1) + entry.lastActivity.Store(time.Now().Add(-time.Second).UnixNano()) + p.drainingCount.Store(1) + + done := make(chan struct{}) + go func() { + p.waitForDrainAndClose(entry) + close(done) + }() + + select { + case <-done: + t.Fatal("drain closed entry with active stream load") + case <-time.After(50 * time.Millisecond): + } + entry.streamLoad.Store(0) + entry.lastActivity.Store(time.Now().Add(-time.Second).UnixNano()) + waitFor(t, func() error { + select { + case <-done: + return nil + default: + return fmt.Errorf("drain did not close after stream load reached zero") + } + }) + if got := entry.state.Load(); got != dcpStateClosed { + t.Fatalf("entry state = %d, want closed", got) + } +} + +func TestDynamicChannelPoolPowerOfTwoSpreadDoesNotHerd(t *testing.T) { + _, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(4, 4, 4)) + defer teardown() + p := client.sc.dynamicPool + entries := p.getEntries() + overloaded := entries[0] + overloaded.unaryLoad.Store(200) + + const workers = 400 + start := make(chan struct{}) + picked := make(chan *dcpEntry, workers) + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + e, err := p.pick(context.Background()) + if err != nil { + picked <- nil + return + } + picked <- e + }() + } + close(start) + + counts := map[uint64]int{} + for i := 0; i < workers; i++ { + e := <-picked + if e == nil { + t.Fatalf("worker pick failed") + } + counts[e.id]++ + } + if got := counts[overloaded.id]; got > 60 { + t.Fatalf("overloaded entry picked %d times, want <= 60; counts=%v", got, counts) + } + for _, e := range entries[1:] { + if got := counts[e.id]; got < 70 { + t.Fatalf("entry %d picked %d times, want spread across low-load entries; counts=%v", e.id, got, counts) + } + } + var maxLow int + for _, e := range entries[1:] { + if got := counts[e.id]; got > maxLow { + maxLow = got + } + } + if maxLow > 190 { + t.Fatalf("parallel power-of-two picks herded onto one low-load entry: maxLow=%d counts=%v", maxLow, counts) + } + wg.Wait() +} + +func TestDynamicChannelPoolScaleUpHonorsMaxScaleUpPercent(t *testing.T) { + cfg := testDCPConfig(4, 1, 10) + cfg.DCPMaxScaleUpPercent = 25 + _, client, teardown := setupDCPMockedTestServer(t, cfg) + defer teardown() + p := client.sc.dynamicPool + p.setPrimeSession(client.sm.multiplexedSession.id) + for _, e := range p.getEntries() { + e.unaryLoad.Store(10) + } + + p.scaleUp() + if got, want := p.Num(), 5; got != want { + t.Fatalf("DCP channel count after percent-capped scale-up = %d, want %d", got, want) + } +} + +func TestDynamicChannelPoolScaleUpDialFailureDoesNotPublishEntry(t *testing.T) { + _, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(1, 1, 2)) + defer teardown() + p := client.sc.dynamicPool + p.setPrimeSession(client.sm.multiplexedSession.id) + p.dial = func(context.Context) (gtransport.ConnPool, error) { + return nil, status.Error(codes.Unavailable, "dial failed") + } + initialEntries := append([]*dcpEntry(nil), p.getEntries()...) + p.getEntries()[0].unaryLoad.Store(10) + + p.scaleUp() + if got, want := p.Num(), 1; got != want { + t.Fatalf("DCP channel count after failed dial = %d, want %d", got, want) + } + if got := p.getEntries()[0]; got != initialEntries[0] { + t.Fatalf("active entry pointer changed after failed dial") + } + if got := p.lastScaleUp.Load(); got != 0 { + t.Fatalf("lastScaleUp after failed dial = %d, want 0", got) + } + for _, e := range p.getEntries() { + if e.state.Load() != dcpStateActive { + t.Fatalf("active slice contains non-active entry state=%d", e.state.Load()) + } + } +} + +func TestDynamicChannelPoolConfigDefaultsInitialChannelsToMinWhenInitialUnset(t *testing.T) { + cfg, err := normalizeDCPConfig(DynamicChannelPoolConfig{DCPEnabled: true, DCPMinChannels: 8, DCPMaxChannels: 10}) + if err != nil { + t.Fatalf("normalizeDCPConfig failed: %v", err) + } + if got, want := cfg.DCPInitialChannels, 8; got != want { + t.Fatalf("DCPInitialChannels = %d, want min channels %d", got, want) + } +} + +func TestDynamicChannelPoolConfigRejectsExplicitInitialBelowMin(t *testing.T) { + _, err := normalizeDCPConfig(DynamicChannelPoolConfig{DCPEnabled: true, DCPInitialChannels: 4, DCPMinChannels: 8, DCPMaxChannels: 10}) + if err == nil { + t.Fatal("normalizeDCPConfig succeeded, want error") + } +} diff --git a/spanner/location_aware_client.go b/spanner/location_aware_client.go index 29305e7844e5..17fe8fe0b974 100644 --- a/spanner/location_aware_client.go +++ b/spanner/location_aware_client.go @@ -96,9 +96,25 @@ func asGRPCSpannerClient(c spannerClient) *grpcSpannerClient { if lac, ok := c.(*locationAwareSpannerClient); ok { return asGRPCSpannerClient(lac.defaultClient) } + if dcp, ok := c.(*dcpSpannerClient); ok { + return asGRPCSpannerClient(dcp.delegate) + } + if dcp, ok := c.(*dcpResolvingSpannerClient); ok { + client, err := dcp.resolve(context.Background()) + if err == nil { + return asGRPCSpannerClient(client) + } + } return nil } +func requestIDHeaderProviderFromSpannerClient(c spannerClient) requestIDHeaderProvider { + if dcp, ok := c.(*dcpResolvingSpannerClient); ok { + return dcp + } + return asGRPCSpannerClient(c) +} + func newLocationAwareSpannerClient(defaultClient spannerClient, router *locationRouter, endpointCache channelEndpointCache) *locationAwareSpannerClient { return newIndexedLocationAwareSpannerClient( newLocationAwareState([]spannerClient{defaultClient}, router, endpointCache, nil), diff --git a/spanner/read.go b/spanner/read.go index 80a991a8af7f..64f1c5a70a48 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -48,6 +48,10 @@ type streamingFinalizer interface { finish() } +type requestIDHeaderProvider interface { + generateRequestIDHeaderInjector() *requestIDWrap +} + func shouldRetryResourceExhaustedInStreaming(_ spannerClient) bool { return true } @@ -72,7 +76,7 @@ func stream( rpc func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error), setTimestamp func(time.Time), release func(error), - gsc *grpcSpannerClient, + reqIDProvider requestIDHeaderProvider, ) *RowIterator { return streamWithTransactionCallbacks( ctx, @@ -86,7 +90,7 @@ func stream( nil, setTimestamp, release, - gsc, + reqIDProvider, true, false, ) @@ -104,7 +108,7 @@ func streamWithTransactionCallbacks( updatePrecommitToken func(token *sppb.MultiplexedSessionPrecommitToken), setTimestamp func(time.Time), release func(error), - gsc *grpcSpannerClient, + reqIDProvider requestIDHeaderProvider, retryResourceExhausted bool, allowRetryResourceExhaustedWithoutDelay bool, ) *RowIterator { @@ -112,7 +116,7 @@ func streamWithTransactionCallbacks( ctx, _ = startSpan(ctx, "RowIterator") return &RowIterator{ meterTracerFactory: meterTracerFactory, - streamd: newResumableStreamDecoder(ctx, cancel, logger, rpc, gsc, retryResourceExhausted, allowRetryResourceExhaustedWithoutDelay), + streamd: newResumableStreamDecoder(ctx, cancel, logger, rpc, reqIDProvider, retryResourceExhausted, allowRetryResourceExhaustedWithoutDelay), rowd: &partialResultSetDecoder{}, setTransactionID: setTransactionID, updatePrecommitToken: updatePrecommitToken, @@ -467,7 +471,7 @@ type resumableStreamDecoder struct { // backoff is used for the retry settings backoff gax.Backoff - gsc *grpcSpannerClient + reqIDProvider requestIDHeaderProvider retryResourceExhausted bool allowRetryResourceExhaustedWithoutDelay bool @@ -483,7 +487,7 @@ type resumableStreamDecoder struct { // newResumableStreamDecoder creates a new resumeableStreamDecoder instance. // Parameter rpc should be a function that creates a new stream beginning at the // restartToken if non-nil. -func newResumableStreamDecoder(ctx context.Context, cancel func(), logger *log.Logger, rpc func(ct context.Context, restartToken []byte, opts ...gax.CallOption) (streamingReceiver, error), gsc *grpcSpannerClient, retryResourceExhausted bool, allowRetryResourceExhaustedWithoutDelay bool) *resumableStreamDecoder { +func newResumableStreamDecoder(ctx context.Context, cancel func(), logger *log.Logger, rpc func(ct context.Context, restartToken []byte, opts ...gax.CallOption) (streamingReceiver, error), reqIDProvider requestIDHeaderProvider, retryResourceExhausted bool, allowRetryResourceExhaustedWithoutDelay bool) *resumableStreamDecoder { return &resumableStreamDecoder{ ctx: ctx, cancel: cancel, @@ -491,7 +495,7 @@ func newResumableStreamDecoder(ctx context.Context, cancel func(), logger *log.L rpc: rpc, maxBytesBetweenResumeTokens: atomic.LoadInt32(&maxBytesBetweenResumeTokens), backoff: DefaultRetryBackoff, - gsc: gsc, + reqIDProvider: reqIDProvider, retryResourceExhausted: retryResourceExhausted, allowRetryResourceExhaustedWithoutDelay: allowRetryResourceExhaustedWithoutDelay, } @@ -499,7 +503,7 @@ func newResumableStreamDecoder(ctx context.Context, cancel func(), logger *log.L func (d *resumableStreamDecoder) reqIDInjectorOrNew() *requestIDWrap { if d.reqIDInjector == nil { - d.reqIDInjector = d.gsc.generateRequestIDHeaderInjector() + d.reqIDInjector = d.reqIDProvider.generateRequestIDHeaderInjector() d.retryAttempt = 0 } return d.reqIDInjector diff --git a/spanner/session.go b/spanner/session.go index c74fda2249ae..872a21aa8194 100644 --- a/spanner/session.go +++ b/spanner/session.go @@ -371,6 +371,9 @@ func (p *sessionManager) finishMultiplexedSessionCreation(creation *multiplexedS if p.valid && s != nil { s.sm = p p.multiplexedSession = s + if p.sc.dynamicPool != nil { + p.sc.dynamicPool.setPrimeSession(s.id) + } p.recordStat(context.Background(), OpenSessionCount, int64(1), tag.Tag{Key: tagKeyIsMultiplexed, Value: "true"}) p.recordStat(context.Background(), SessionsCount, 1, tagNumSessions, tag.Tag{Key: tagKeyIsMultiplexed, Value: "true"}) } @@ -437,14 +440,22 @@ var errInvalidSession = spannerErrorf(codes.InvalidArgument, "invalid session") // newSessionHandleLocked creates a new session handle for the given session. // The caller must hold p.mu. -func (p *sessionManager) newSessionHandleLocked(s *session) (sh *sessionHandle) { - sh = &sessionHandle{session: s} +func (p *sessionManager) newSessionHandleLocked(ctx context.Context, s *session) (*sessionHandle, error) { + sh := &sessionHandle{session: s} + if p.sc.dynamicPool != nil && p.locationAwareState == nil { + entry, err := p.sc.dynamicPool.pick(ctx) + if err != nil { + return nil, err + } + sh.client = newDCPResolvingSpannerClient(p.sc.dynamicPool, entry.id) + return sh, nil + } client, idx := p.getRoundRobinClientLocked() if p.locationAwareState != nil && p.locationAwareState.endpointCache != nil { client = newIndexedLocationAwareSpannerClient(p.locationAwareState, idx) } sh.client = client - return sh + return sh, nil } func (p *sessionManager) getRoundRobinClientLocked() (spannerClient, int) { @@ -510,7 +521,11 @@ func (p *sessionManager) takeMultiplexed(ctx context.Context) (*sessionHandle, e s = p.multiplexedSession trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()}, "Acquired multiplexed session") - sh := p.newSessionHandleLocked(s) + sh, err := p.newSessionHandleLocked(ctx, s) + if err != nil { + p.mu.Unlock() + return nil, err + } p.mu.Unlock() p.incNumMultiplexedInUse(ctx) return sh, nil diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go index 632164abef41..ef52009ab4a8 100644 --- a/spanner/sessionclient.go +++ b/spanner/sessionclient.go @@ -99,6 +99,7 @@ type sessionClient struct { disableRouteToLeader bool connPool gtransport.ConnPool + dynamicPool *dynamicChannelPool database string id string userAgent string @@ -280,6 +281,13 @@ func (sc *sessionClient) sessionWithID(id string) (*session, error) { // session. Using the same channel for all gRPC calls for a session ensures the // optimal usage of server side caches. func (sc *sessionClient) nextClient() (spannerClient, error) { + if sc.dynamicPool != nil { + entry, err := sc.dynamicPool.pick(context.Background()) + if err != nil { + return nil, err + } + return entry.client, nil + } var clientOpt option.ClientOption var channelID uint64 if _, ok := sc.connPool.(*gmeWrapper); ok { diff --git a/spanner/transaction.go b/spanner/transaction.go index 030c4b185e0a..f35f343a20dd 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -411,7 +411,7 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key t.updatePrecommitToken, t.setTimestamp, t.release, - asGRPCSpannerClient(client), + requestIDHeaderProviderFromSpannerClient(client), retryResourceExhausted, allowRetryResourceExhaustedWithoutDelay, ) @@ -766,7 +766,7 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que t.updatePrecommitToken, t.setTimestamp, t.release, - asGRPCSpannerClient(client), + requestIDHeaderProviderFromSpannerClient(client), retryResourceExhausted, allowRetryResourceExhaustedWithoutDelay) } From 63cac5ed9fd639143718c16afb585ef8a472837a Mon Sep 17 00:00:00 2001 From: Rahul Yadav Date: Tue, 19 May 2026 18:05:04 +0530 Subject: [PATCH 2/2] feat(spanner): add DCP DirectPath fallback --- spanner/client.go | 20 +++- spanner/dynamic_channel_pool.go | 171 ++++++++++++++++++++++++++- spanner/dynamic_channel_pool_test.go | 71 +++++++++++ 3 files changed, 253 insertions(+), 9 deletions(-) diff --git a/spanner/client.go b/spanner/client.go index 4e391e406e6d..59fb9f5abee7 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -82,6 +82,12 @@ const ( // MinSessions for Experimental Host connection experimentalHostMinSessions = 0 + + // DirectPath fallback policy used by both non-DCP grpc-gcp fallback and the + // DCP DirectPath/CloudPath wrapper. + directPathFallbackErrorRateThreshold = float32(1) + directPathFallbackMinFailedCalls = 1 + directPathFallbackPeriod = time.Minute * 3 ) const ( @@ -579,7 +585,13 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf dial := func(dialCtx context.Context) (gtransport.ConnPool, error) { return gtransport.DialPool(dialCtx, allClientOpts(1, config.Compression, config.EnableDirectAccess, dcpOpts...)...) } - dcp, err := newDynamicChannelPool(ctx, sc, config.DynamicChannelPoolConfig, 0, dial) + var fallbackDial func(context.Context) (gtransport.ConnPool, error) + if isFallbackEnabled && isDirectPathEnabled { + fallbackDial = func(dialCtx context.Context) (gtransport.ConnPool, error) { + return gtransport.DialPool(dialCtx, append(allClientOpts(1, config.Compression, config.EnableDirectAccess, dcpOpts...), internaloption.EnableDirectPath(false))...) + } + } + dcp, err := newDynamicChannelPool(ctx, sc, config.DynamicChannelPoolConfig, 0, dial, fallbackDial) if err != nil { return nil, err } @@ -620,9 +632,9 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf fbOpts := grpcgcp.NewGCPFallbackOptions() fbOpts.EnableFallback = true - fbOpts.ErrorRateThreshold = 1 - fbOpts.MinFailedCalls = 1 - fbOpts.Period = time.Minute * 3 + fbOpts.ErrorRateThreshold = directPathFallbackErrorRateThreshold + fbOpts.MinFailedCalls = directPathFallbackMinFailedCalls + fbOpts.Period = directPathFallbackPeriod if metricsTracerFactory != nil && metricsTracerFactory.meterProvider != nil { fbOpts.MeterProvider = metricsTracerFactory.meterProvider diff --git a/spanner/dynamic_channel_pool.go b/spanner/dynamic_channel_pool.go index db481b9ec3e6..05e8cee2ed39 100644 --- a/spanner/dynamic_channel_pool.go +++ b/spanner/dynamic_channel_pool.go @@ -32,6 +32,7 @@ import ( gtransport "google.golang.org/api/transport/grpc" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) const ( @@ -183,7 +184,10 @@ type dynamicChannelPool struct { database string disableRouteToLeader bool - dial func(context.Context) (gtransport.ConnPool, error) + dial func(context.Context) (gtransport.ConnPool, error) + fallbackDial func(context.Context) (gtransport.ConnPool, error) + fallbackState *dcpFallbackState + rrIndex atomic.Uint64 nextID atomic.Uint64 totalRPCLoad atomic.Int32 @@ -202,7 +206,9 @@ type dynamicChannelPool struct { drainingCount atomic.Int64 } -// dcpEntry represents one logical DCP slot. +// dcpEntry represents one logical DCP slot. In DirectPath fallback mode the +// entry pool is a wrapper containing one DirectPath channel and one CloudPath +// fallback channel. type dcpEntry struct { id uint64 metricSlot int64 // bounded slot id used for metric cardinality @@ -219,7 +225,7 @@ type dcpEntry struct { } // newDynamicChannelPool creates the initial channel set and starts scale workers. -func newDynamicChannelPool(ctx context.Context, sc *sessionClient, cfg DynamicChannelPoolConfig, initial int, dial func(context.Context) (gtransport.ConnPool, error)) (*dynamicChannelPool, error) { +func newDynamicChannelPool(ctx context.Context, sc *sessionClient, cfg DynamicChannelPoolConfig, initial int, dial func(context.Context) (gtransport.ConnPool, error), fallbackDial func(context.Context) (gtransport.ConnPool, error)) (*dynamicChannelPool, error) { cfg, err := normalizeDCPConfig(cfg) if err != nil { return nil, err @@ -237,6 +243,8 @@ func newDynamicChannelPool(ctx context.Context, sc *sessionClient, cfg DynamicCh database: sc.database, disableRouteToLeader: sc.disableRouteToLeader, dial: dial, + fallbackDial: fallbackDial, + fallbackState: &dcpFallbackState{}, scaleUpSignal: make(chan struct{}, 1), done: make(chan struct{}), } @@ -381,18 +389,29 @@ func (p *dynamicChannelPool) releaseMetricSlot(slot int64) { p.metricSlotMu.Unlock() } -// newEntry dials one DCP entry. +// newEntry dials one DCP entry. When fallbackDial is set, the entry uses a +// DirectPath/CloudPath wrapper but still appears as one logical DCP slot. func (p *dynamicChannelPool) newEntry(ctx context.Context, prime bool) (*dcpEntry, error) { id := p.nextID.Add(1) metricSlot, err := p.allocateMetricSlot() if err != nil { return nil, err } - entryPool, err := p.dial(ctx) + primary, err := p.dial(ctx) if err != nil { p.releaseMetricSlot(metricSlot) return nil, err } + var entryPool gtransport.ConnPool = primary + if p.fallbackDial != nil { + fallback, err := p.fallbackDial(ctx) + if err != nil { + primary.Close() + p.releaseMetricSlot(metricSlot) + return nil, err + } + entryPool = &dcpFallbackSlot{id: id, direct: primary, cloud: fallback, state: p.fallbackState} + } e := &dcpEntry{id: id, metricSlot: metricSlot, pool: entryPool, parent: p} now := time.Now().UnixNano() e.createdAt.Store(now) @@ -872,6 +891,148 @@ func (e *dcpEntry) weightedLoad() int32 { return e.rpcLoad() } func (e *dcpEntry) applyPenalty(ctx context.Context, err error) {} +// dcpFallbackState is shared by all DirectPath fallback slots in the pool so a +// primary DirectPath outage can move the whole DCP wrapper pool to CloudPath. +type dcpFallbackState struct { + fallbackActive atomic.Bool + primarySuccesses atomic.Uint64 + primaryFailures atomic.Uint64 + lastPrimaryReset atomic.Int64 +} + +// dcpFallbackSlot is one logical DCP slot backed by two physical channels: +// DirectPath for the primary path and CloudPath for fallback. +type dcpFallbackSlot struct { + id uint64 + direct gtransport.ConnPool + cloud gtransport.ConnPool + state *dcpFallbackState +} + +func (s *dcpFallbackSlot) Conn() *grpc.ClientConn { + if s.state.fallbackActive.Load() { + return s.cloud.Conn() + } + return s.direct.Conn() +} + +func (s *dcpFallbackSlot) Num() int { return 1 } + +func (s *dcpFallbackSlot) Close() error { + e1 := s.direct.Close() + e2 := s.cloud.Close() + return errors.Join(e1, e2) +} + +func (s *dcpFallbackSlot) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + if s.state.fallbackActive.Load() { + err := s.cloud.Invoke(ctx, method, args, reply, opts...) + s.recordFallback(err) + return err + } + err := s.direct.Invoke(ctx, method, args, reply, opts...) + s.recordPrimary(err) + return err +} + +func (s *dcpFallbackSlot) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if s.state.fallbackActive.Load() { + st, err := s.cloud.NewStream(ctx, desc, method, opts...) + if err != nil { + s.recordFallback(err) + return st, err + } + return &dcpFallbackMonitoredStream{ClientStream: st, record: s.recordFallback}, nil + } + st, err := s.direct.NewStream(ctx, desc, method, opts...) + if err != nil { + s.recordPrimary(err) + return st, err + } + return &dcpFallbackMonitoredStream{ClientStream: st, record: s.recordPrimary}, nil +} + +// recordPrimary updates shared DirectPath health counters. +func (s *dcpFallbackSlot) recordPrimary(err error) { + s.resetPrimaryFallbackWindowIfNeeded() + if isDCPFallbackFailure(err) { + s.state.primaryFailures.Add(1) + s.maybeActivateFallback() + } else { + s.state.primarySuccesses.Add(1) + } +} + +// resetPrimaryFallbackWindowIfNeeded keeps DCP fallback activation counters on +// the same time window as the non-DCP grpc-gcp fallback Period. +func (s *dcpFallbackSlot) resetPrimaryFallbackWindowIfNeeded() { + now := time.Now().UnixNano() + last := s.state.lastPrimaryReset.Load() + if last == 0 { + s.state.lastPrimaryReset.CompareAndSwap(0, now) + return + } + if time.Duration(now-last) < directPathFallbackPeriod { + return + } + if s.state.lastPrimaryReset.CompareAndSwap(last, now) { + s.state.primaryFailures.Store(0) + s.state.primarySuccesses.Store(0) + } +} + +// recordFallback is intentionally a no-op. DCP DirectPath fallback is sticky +// once activated, matching non-DCP grpc-gcp fallback behavior. +func (s *dcpFallbackSlot) recordFallback(err error) { +} + +// maybeActivateFallback enables CloudPath after enough DirectPath samples show a +// sustained Unavailable rate. The activation threshold, minimum failed calls, +// and counter window intentionally mirror the non-DCP grpc-gcp fallback config. +func (s *dcpFallbackSlot) maybeActivateFallback() { + failures := s.state.primaryFailures.Load() + successes := s.state.primarySuccesses.Load() + total := failures + successes + if total == 0 || failures < uint64(directPathFallbackMinFailedCalls) { + return + } + if float32(failures)/float32(total) < directPathFallbackErrorRateThreshold { + return + } + s.state.fallbackActive.Store(true) +} + +// isDCPFallbackFailure returns true for errors that should move DirectPath +// traffic to CloudPath fallback. +func isDCPFallbackFailure(err error) bool { + c := status.Code(err) + return c == codes.Unavailable +} + +type dcpFallbackMonitoredStream struct { + grpc.ClientStream + once sync.Once + record func(error) +} + +func (s *dcpFallbackMonitoredStream) RecvMsg(m interface{}) error { + err := s.ClientStream.RecvMsg(m) + if err != nil { + s.once.Do(func() { + if errors.Is(err, io.EOF) { + s.record(nil) + } else { + s.record(err) + } + }) + } + return err +} + +func (s *dcpFallbackMonitoredStream) CloseSend() error { + return s.ClientStream.CloseSend() +} + func (p *dynamicChannelPool) recordScaleUp(added int) {} func (p *dynamicChannelPool) recordScaleDown(draining int) {} diff --git a/spanner/dynamic_channel_pool_test.go b/spanner/dynamic_channel_pool_test.go index d3c6ce0affdc..6418e837420d 100644 --- a/spanner/dynamic_channel_pool_test.go +++ b/spanner/dynamic_channel_pool_test.go @@ -259,6 +259,30 @@ func (f *fakeDCPConnPool) NewStream(ctx context.Context, desc *grpc.StreamDesc, return nil, f.invokeErr } +func TestDynamicChannelPoolDirectPathFallbackUsesSharedState(t *testing.T) { + state := &dcpFallbackState{} + primary1 := &fakeDCPConnPool{invokeErr: status.Error(codes.Unavailable, "directpath unavailable")} + cloud1 := &fakeDCPConnPool{} + primary2 := &fakeDCPConnPool{} + cloud2 := &fakeDCPConnPool{} + slot1 := &dcpFallbackSlot{id: 1, direct: primary1, cloud: cloud1, state: state} + slot2 := &dcpFallbackSlot{id: 2, direct: primary2, cloud: cloud2, state: state} + + _ = slot1.Invoke(context.Background(), "/test", nil, nil) + if !state.fallbackActive.Load() { + t.Fatal("shared fallback state inactive after DirectPath failure threshold") + } + if err := slot2.Invoke(context.Background(), "/test", nil, nil); err != nil { + t.Fatalf("fallback slot invoke failed: %v", err) + } + if got := primary2.invokeCount; got != 0 { + t.Fatalf("slot2 primary invoke count = %d, want 0 after shared fallback", got) + } + if got := cloud2.invokeCount; got != 1 { + t.Fatalf("slot2 cloud invoke count = %d, want 1 after shared fallback", got) + } +} + func TestDynamicChannelPoolScaleUpPrimeFailureDoesNotPublishEntry(t *testing.T) { server, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(1, 1, 2)) defer teardown() @@ -576,6 +600,53 @@ func TestDynamicChannelPoolScaleUpDialFailureDoesNotPublishEntry(t *testing.T) { } } +func TestDynamicChannelPoolDirectPathFallbackSlotStaysPinnedAcrossFallback(t *testing.T) { + state := &dcpFallbackState{} + direct1 := &fakeDCPConnPool{} + cloud1 := &fakeDCPConnPool{} + direct2 := &fakeDCPConnPool{} + cloud2 := &fakeDCPConnPool{} + slot1 := &dcpFallbackSlot{id: 7, direct: direct1, cloud: cloud1, state: state} + slot2 := &dcpFallbackSlot{id: 8, direct: direct2, cloud: cloud2, state: state} + p := &dynamicChannelPool{cfg: testDCPConfig(2, 1, 2)} + entry1 := &dcpEntry{id: slot1.id, pool: slot1, parent: p} + entry2 := &dcpEntry{id: slot2.id, pool: slot2, parent: p} + entry1.state.Store(dcpStateActive) + entry2.state.Store(dcpStateActive) + entries := []*dcpEntry{entry1, entry2} + p.entries.Store(&entries) + + picked, err := p.pick(context.Background()) + if err != nil { + t.Fatalf("pick failed: %v", err) + } + if err := picked.pool.Invoke(context.Background(), "/test", nil, nil); err != nil { + t.Fatalf("direct invoke failed: %v", err) + } + state.fallbackActive.Store(true) + if err := picked.pool.Invoke(context.Background(), "/test", nil, nil); err != nil { + t.Fatalf("fallback invoke failed: %v", err) + } + + var pickedDirect, pickedCloud, otherDirect, otherCloud *fakeDCPConnPool + if picked.id == slot1.id { + pickedDirect, pickedCloud, otherDirect, otherCloud = direct1, cloud1, direct2, cloud2 + } else if picked.id == slot2.id { + pickedDirect, pickedCloud, otherDirect, otherCloud = direct2, cloud2, direct1, cloud1 + } else { + t.Fatalf("picked unexpected slot id %d", picked.id) + } + if got, want := pickedDirect.invokeCount, 1; got != want { + t.Fatalf("picked direct invoke count = %d, want %d", got, want) + } + if got, want := pickedCloud.invokeCount, 1; got != want { + t.Fatalf("picked cloud invoke count = %d, want %d", got, want) + } + if got := otherDirect.invokeCount + otherCloud.invokeCount; got != 0 { + t.Fatalf("other slot invoke count = %d, want 0", got) + } +} + func TestDynamicChannelPoolConfigDefaultsInitialChannelsToMinWhenInitialUnset(t *testing.T) { cfg, err := normalizeDCPConfig(DynamicChannelPoolConfig{DCPEnabled: true, DCPMinChannels: 8, DCPMaxChannels: 10}) if err != nil {