diff --git a/docs/handlers/proxy.md b/docs/handlers/proxy.md index 63e2e3c8..21d6b8fe 100644 --- a/docs/handlers/proxy.md +++ b/docs/handlers/proxy.md @@ -26,6 +26,17 @@ The handler has the following optional fields: - `upstreams` may contain a list of `l4proxy.Upstream` structures (valid for JSON). In a Caddyfile, multiple `upstream` options or blocks are unmarshalled into a list of such structures. +- `dynamic_upstreams` may contain an upstream-source module that discovers the upstreams at runtime instead of listing + them statically, so the backend set need not be restated in config when DNS already publishes it. In a Caddyfile it + is `dynamic { ... }`. Two DNS sources are provided: + - `srv` resolves SRV records. Options: `service`, `proto`, `name` (or just `name` for the full domain), `refresh` + (default `1m`), `grace_period` (serve stale results for this long on lookup failure), `dial_network`. + - `a` resolves A/AAAA records for a `name` on a configured `port`. Options: `name`, `port`, `refresh`, + `grace_period`, `dial_network`. + + When `dynamic_upstreams` is configured, the static `upstreams` list may be empty. Note: active health checks run on + statically-configured upstreams only. + **Active health checks** occur independently in a background goroutine. They run in the background on a timer. To minimally enable active health checks, set `active` field equal to an empty structure inside `health_checks` in a JSON configuration or include any active health check option into a Caddyfile. diff --git a/integration/caddyfile_adapt/gd_handler_proxy_dynamic_srv.caddytest b/integration/caddyfile_adapt/gd_handler_proxy_dynamic_srv.caddytest new file mode 100644 index 00000000..1fba7e7e --- /dev/null +++ b/integration/caddyfile_adapt/gd_handler_proxy_dynamic_srv.caddytest @@ -0,0 +1,46 @@ +{ + layer4 { + :5432 { + route { + proxy { + dynamic srv { + service postgres + proto tcp + name db.internal + refresh 30s + } + } + } + } + } +} +---------- +{ + "apps": { + "layer4": { + "servers": { + "srv0": { + "listen": [ + ":5432" + ], + "routes": [ + { + "handle": [ + { + "dynamic_upstreams": { + "name": "db.internal", + "proto": "tcp", + "refresh": 30000000000, + "service": "postgres", + "source": "srv" + }, + "handler": "proxy" + } + ] + } + ] + } + } + } + } +} diff --git a/modules/l4proxy/proxy.go b/modules/l4proxy/proxy.go index 8756f319..f47178bd 100644 --- a/modules/l4proxy/proxy.go +++ b/modules/l4proxy/proxy.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "crypto/tls" + "encoding/json" "fmt" "io" "log" @@ -46,9 +47,14 @@ func init() { // Handler is a handler that can proxy connections. type Handler struct { - // Upstreams is the list of backends to proxy to. + // Upstreams is the static list of backends to proxy to. Upstreams UpstreamPool `json:"upstreams,omitempty"` + // DynamicUpstreamsRaw is a module that discovers upstreams dynamically (per + // connection) instead of listing them statically — e.g. from DNS SRV + // records, so the backend set need not be restated in config. + DynamicUpstreamsRaw json.RawMessage `json:"dynamic_upstreams,omitempty" caddy:"namespace=layer4.proxy.upstreams inline_key=source"` + // Health checks update the status of backends, whether they are // up or down. Down backends will not be proxied to. HealthChecks *HealthChecks `json:"health_checks,omitempty"` @@ -62,6 +68,8 @@ type Handler struct { proxyProtocolVersion uint8 + dynamicUpstreams UpstreamSource + ctx caddy.Context logger *zap.Logger } @@ -98,8 +106,17 @@ func (h *Handler) Provision(ctx caddy.Context) error { return fmt.Errorf("proxy_protocol: \"%s\" should be empty, or one of \"v1\" \"v2\"", proxyProtocol) } + // load the dynamic upstreams source module, if configured + if h.DynamicUpstreamsRaw != nil { + mod, err := ctx.LoadModule(h, "DynamicUpstreamsRaw") + if err != nil { + return fmt.Errorf("loading dynamic upstreams source module: %v", err) + } + h.dynamicUpstreams = mod.(UpstreamSource) + } + // prepare upstreams - if len(h.Upstreams) == 0 { + if len(h.Upstreams) == 0 && h.dynamicUpstreams == nil { return fmt.Errorf("no upstreams defined") } for i, ups := range h.Upstreams { @@ -160,9 +177,20 @@ func (h *Handler) Handle(down *layer4.Connection, _ layer4.Handler) error { var upConns []net.Conn var proxyErr error + // determine the pool: dynamically discovered (per connection) or static + pool := h.Upstreams + if h.dynamicUpstreams != nil { + dynUpstreams, err := h.dynamicUpstreams.GetUpstreams(repl) + if err != nil { + h.logger.Error("getting dynamic upstreams", zap.Error(err)) + } else { + pool = dynUpstreams + } + } + for { // choose an available upstream - upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, down) + upstream := h.LoadBalancing.SelectionPolicy.Select(pool, down) if upstream == nil { if proxyErr == nil { proxyErr = fmt.Errorf("no upstreams available") @@ -502,6 +530,11 @@ func (h *Handler) Cleanup() error { // // proxy_protocol // +// # discover upstreams dynamically instead of listing them +// dynamic [] { +// ... +// } +// // # multiple upstream options are supported // upstream [] { // ... @@ -697,6 +730,28 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { return d.Errf("duplicate %s option '%s'", wrapper, optionName) } _, h.ProxyProtocol, hasProxyProtocol = d.NextArg(), d.Val(), true + case "dynamic": + if h.DynamicUpstreamsRaw != nil { + return d.Errf("duplicate %s option '%s'", wrapper, optionName) + } + if !d.NextArg() { + return d.ArgErr() + } + sourceName := d.Val() + unm, err := caddyfile.UnmarshalModule(d, "layer4.proxy.upstreams."+sourceName) + if err != nil { + return err + } + source, ok := unm.(UpstreamSource) + if !ok { + return d.Errf("module '%s' is not an upstream source", sourceName) + } + sourceRaw := caddyconfig.JSON(source, nil) + sourceRaw, err = layer4.SetModuleNameInline("source", sourceName, sourceRaw) + if err != nil { + return d.Errf("re-encoding module '%s' configuration: %v", sourceName, err) + } + h.DynamicUpstreamsRaw = sourceRaw case "upstream": u := &Upstream{} if err := u.UnmarshalCaddyfile(d.NewFromNextSegment()); err != nil { diff --git a/modules/l4proxy/upstreams.go b/modules/l4proxy/upstreams.go new file mode 100644 index 00000000..964c0f89 --- /dev/null +++ b/modules/l4proxy/upstreams.go @@ -0,0 +1,460 @@ +// Copyright 2020 Matthew Holt +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package l4proxy + +import ( + "context" + "fmt" + "net" + "strconv" + "sync" + "time" + + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + "go.uber.org/zap" +) + +func init() { + caddy.RegisterModule(&SRVUpstreams{}) + caddy.RegisterModule(&AUpstreams{}) +} + +// UpstreamSource gets the list of upstreams to proxy to dynamically, instead of +// from a static configuration, so the backend set can be discovered (e.g. from +// DNS) rather than hard-coded. It is given the connection's replacer for +// placeholder expansion (and nothing connection-specific), so the same source +// can also be polled by the active health checker, which has no connection. +type UpstreamSource interface { + GetUpstreams(*caddy.Replacer) (UpstreamPool, error) +} + +// SRVUpstreams discovers upstreams from DNS SRV records, so the upstream set +// does not have to be restated in config when DNS already publishes it. Results +// are cached and refreshed periodically. Note: active health checks only run on +// statically-configured upstreams; passive health checking and connection +// counting still apply to dynamically-discovered ones. +type SRVUpstreams struct { + // The service label of the SRV record (the "_service" part). + Service string `json:"service,omitempty"` + + // The protocol label of the SRV record, "tcp" or "udp" (the "_proto" part). + Proto string `json:"proto,omitempty"` + + // The name label; or, if service and proto are empty, the entire domain + // name to look up. + Name string `json:"name,omitempty"` + + // The interval at which to refresh the SRV lookup. Results are cached + // between lookups. Default: 1m. + Refresh caddy.Duration `json:"refresh,omitempty"` + + // If > 0 and a lookup fails, keep using the cached results for up to this + // long (even though they are stale) instead of returning an error. Default: 0. + GracePeriod caddy.Duration `json:"grace_period,omitempty"` + + // Specific network to dial the discovered upstreams on (e.g. "tcp4"); the + // SRV record only provides host and port. Defaults to "tcp". + DialNetwork string `json:"dial_network,omitempty"` + + logger *zap.Logger + lookupSRV func(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) +} + +// CaddyModule returns the Caddy module information. +func (*SRVUpstreams) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "layer4.proxy.upstreams.srv", + New: func() caddy.Module { return new(SRVUpstreams) }, + } +} + +// Provision sets up the SRV upstream source. +func (su *SRVUpstreams) Provision(ctx caddy.Context) error { + su.logger = ctx.Logger() + if su.Refresh == 0 { + su.Refresh = caddy.Duration(time.Minute) + } + if su.lookupSRV == nil { + su.lookupSRV = net.DefaultResolver.LookupSRV + } + return nil +} + +// GetUpstreams resolves the SRV record (using cached results when fresh) and +// returns one upstream per record. +func (su *SRVUpstreams) GetUpstreams(repl *caddy.Replacer) (UpstreamPool, error) { + addr, service, proto, name := su.expandedAddr(repl) + + // fast path: a fresh cached result under a read lock + srvCacheMu.RLock() + cached := srvCache[addr] + srvCacheMu.RUnlock() + if cached.isFresh() { + return cached.upstreams, nil + } + + srvCacheMu.Lock() + defer srvCacheMu.Unlock() + + // re-check under the write lock in case another goroutine refreshed it + cached = srvCache[addr] + if cached.isFresh() { + return cached.upstreams, nil + } + + _, records, err := su.lookupSRV(context.Background(), service, proto, name) + if err != nil && len(records) == 0 { + // LookupSRV may return some records plus an error for invalid names; + // only treat it as fatal when nothing usable came back. + if su.GracePeriod > 0 && cached.upstreams != nil { + if c := su.logger.Check(zap.ErrorLevel, "SRV lookup failed; using stale cache"); c != nil { + c.Write(zap.String("addr", addr), zap.Error(err)) + } + cached.freshness = time.Now().Add(time.Duration(su.GracePeriod) - time.Duration(su.Refresh)) + srvCache[addr] = cached + return cached.upstreams, nil + } + return nil, fmt.Errorf("looking up SRV %s: %v", addr, err) + } + + pool := make(UpstreamPool, 0, len(records)) + for _, rec := range records { + dialAddr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port))) + if su.DialNetwork != "" { + dialAddr = su.DialNetwork + "/" + dialAddr + } + up, err := newDynamicUpstream(dialAddr) + if err != nil { + if c := su.logger.Check(zap.WarnLevel, "skipping invalid SRV target"); c != nil { + c.Write(zap.String("target", dialAddr), zap.Error(err)) + } + continue + } + pool = append(pool, up) + } + + // bound the cache when inserting a brand-new entry + if cached.freshness.IsZero() && len(srvCache) >= 100 { + for k := range srvCache { + delete(srvCache, k) + break + } + } + srvCache[addr] = dnsCacheEntry{refresh: time.Duration(su.Refresh), freshness: time.Now(), upstreams: pool} + return pool, nil +} + +// expandedAddr expands placeholders in the SRV labels and returns the RFC 2782 +// address plus the individual service/proto/name used for the lookup. When both +// Service and Proto are empty, Name is treated as the full domain to look up. +func (su *SRVUpstreams) expandedAddr(repl *caddy.Replacer) (addr, service, proto, name string) { + name = repl.ReplaceAll(su.Name, "") + if su.Service == "" && su.Proto == "" { + return name, "", "", name + } + service = repl.ReplaceAll(su.Service, "") + proto = repl.ReplaceAll(su.Proto, "") + return fmt.Sprintf("_%s._%s.%s", service, proto, name), service, proto, name +} + +// UnmarshalCaddyfile sets up the SRVUpstreams from Caddyfile tokens. Syntax: +// +// srv [] { +// service +// proto +// name +// refresh +// grace_period +// dial_network +// } +func (su *SRVUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + _, wrapper := d.Next(), d.Val() // consume wrapper name + + if d.NextArg() { + su.Name = d.Val() + } + if d.CountRemainingArgs() > 0 { + return d.ArgErr() + } + + for nesting := d.Nesting(); d.NextBlock(nesting); { + option := d.Val() + switch option { + case "service": + if !d.NextArg() { + return d.ArgErr() + } + su.Service = d.Val() + case "proto": + if !d.NextArg() { + return d.ArgErr() + } + su.Proto = d.Val() + case "name": + if !d.NextArg() { + return d.ArgErr() + } + su.Name = d.Val() + case "refresh": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("parsing %s option '%s': %v", wrapper, option, err) + } + su.Refresh = caddy.Duration(dur) + case "grace_period": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("parsing %s option '%s': %v", wrapper, option, err) + } + su.GracePeriod = caddy.Duration(dur) + case "dial_network": + if !d.NextArg() { + return d.ArgErr() + } + su.DialNetwork = d.Val() + default: + return d.Errf("unrecognized %s option '%s'", wrapper, option) + } + } + return nil +} + +// newDynamicUpstream builds an Upstream (with a peer drawn from the shared peer +// pool, so health and connection state persist across refreshes) for a single +// dynamically-discovered dial address. +func newDynamicUpstream(dialAddr string) (*Upstream, error) { + address, err := parseAddress(dialAddr) + if err != nil { + return nil, err + } + p := &peer{dialAddr: dialAddr, address: address} + existingPeer, loaded := peers.LoadOrStore(dialAddr, p) + if loaded { + p = existingPeer.(*peer) + } + return &Upstream{Dial: []string{dialAddr}, peers: []*peer{p}}, nil +} + +type dnsCacheEntry struct { + refresh time.Duration + freshness time.Time + upstreams UpstreamPool +} + +func (e dnsCacheEntry) isFresh() bool { + return !e.freshness.IsZero() && time.Since(e.freshness) < e.refresh +} + +var ( + srvCacheMu sync.RWMutex + srvCache = make(map[string]dnsCacheEntry) +) + +// AUpstreams discovers upstreams from a name's DNS A/AAAA records. Since plain +// address records carry no port, every discovered address uses the configured +// Port. This fits clusters where all members share a port (e.g. a Postgres +// cluster on 5432 published behind a single name). Results are cached and +// refreshed; see SRVUpstreams for the same active-health-check caveat. +type AUpstreams struct { + // The domain name to look up. + Name string `json:"name,omitempty"` + + // The port to use for every discovered address. + Port string `json:"port,omitempty"` + + // The interval at which to refresh the lookup. Results are cached between + // lookups. Default: 1m. + Refresh caddy.Duration `json:"refresh,omitempty"` + + // If > 0 and a lookup fails, keep using the cached results for up to this + // long (even though they are stale) instead of returning an error. Default: 0. + GracePeriod caddy.Duration `json:"grace_period,omitempty"` + + // Specific network to dial the discovered upstreams on (e.g. "tcp4"). + // Defaults to "tcp". + DialNetwork string `json:"dial_network,omitempty"` + + logger *zap.Logger + lookupHost func(ctx context.Context, host string) ([]string, error) +} + +// CaddyModule returns the Caddy module information. +func (*AUpstreams) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "layer4.proxy.upstreams.a", + New: func() caddy.Module { return new(AUpstreams) }, + } +} + +// Provision sets up the A upstream source. +func (au *AUpstreams) Provision(ctx caddy.Context) error { + au.logger = ctx.Logger() + if au.Refresh == 0 { + au.Refresh = caddy.Duration(time.Minute) + } + if au.Port == "" { + return fmt.Errorf("a upstreams: port is required") + } + if au.lookupHost == nil { + au.lookupHost = net.DefaultResolver.LookupHost + } + return nil +} + +// GetUpstreams resolves the name's addresses (using cached results when fresh) +// and returns one upstream per address, all on the configured port. +func (au *AUpstreams) GetUpstreams(repl *caddy.Replacer) (UpstreamPool, error) { + name := repl.ReplaceAll(au.Name, "") + port := repl.ReplaceAll(au.Port, "") + key := net.JoinHostPort(name, port) + + aCacheMu.RLock() + cached := aCache[key] + aCacheMu.RUnlock() + if cached.isFresh() { + return cached.upstreams, nil + } + + aCacheMu.Lock() + defer aCacheMu.Unlock() + + cached = aCache[key] + if cached.isFresh() { + return cached.upstreams, nil + } + + addrs, err := au.lookupHost(context.Background(), name) + if err != nil { + if au.GracePeriod > 0 && cached.upstreams != nil { + if c := au.logger.Check(zap.ErrorLevel, "A lookup failed; using stale cache"); c != nil { + c.Write(zap.String("name", name), zap.Error(err)) + } + cached.freshness = time.Now().Add(time.Duration(au.GracePeriod) - time.Duration(au.Refresh)) + aCache[key] = cached + return cached.upstreams, nil + } + return nil, fmt.Errorf("looking up A/AAAA %s: %v", name, err) + } + + pool := make(UpstreamPool, 0, len(addrs)) + for _, ip := range addrs { + dialAddr := net.JoinHostPort(ip, port) + if au.DialNetwork != "" { + dialAddr = au.DialNetwork + "/" + dialAddr + } + up, err := newDynamicUpstream(dialAddr) + if err != nil { + if c := au.logger.Check(zap.WarnLevel, "skipping invalid A/AAAA address"); c != nil { + c.Write(zap.String("target", dialAddr), zap.Error(err)) + } + continue + } + pool = append(pool, up) + } + + if cached.freshness.IsZero() && len(aCache) >= 100 { + for k := range aCache { + delete(aCache, k) + break + } + } + aCache[key] = dnsCacheEntry{refresh: time.Duration(au.Refresh), freshness: time.Now(), upstreams: pool} + return pool, nil +} + +// UnmarshalCaddyfile sets up the AUpstreams from Caddyfile tokens. Syntax: +// +// a [] { +// name +// port +// refresh +// grace_period +// dial_network +// } +func (au *AUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + _, wrapper := d.Next(), d.Val() // consume wrapper name + + if d.NextArg() { + au.Name = d.Val() + } + if d.CountRemainingArgs() > 0 { + return d.ArgErr() + } + + for nesting := d.Nesting(); d.NextBlock(nesting); { + option := d.Val() + switch option { + case "name": + if !d.NextArg() { + return d.ArgErr() + } + au.Name = d.Val() + case "port": + if !d.NextArg() { + return d.ArgErr() + } + au.Port = d.Val() + case "refresh": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("parsing %s option '%s': %v", wrapper, option, err) + } + au.Refresh = caddy.Duration(dur) + case "grace_period": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("parsing %s option '%s': %v", wrapper, option, err) + } + au.GracePeriod = caddy.Duration(dur) + case "dial_network": + if !d.NextArg() { + return d.ArgErr() + } + au.DialNetwork = d.Val() + default: + return d.Errf("unrecognized %s option '%s'", wrapper, option) + } + } + return nil +} + +var ( + aCacheMu sync.RWMutex + aCache = make(map[string]dnsCacheEntry) +) + +// Interface guards +var ( + _ UpstreamSource = (*SRVUpstreams)(nil) + _ caddy.Provisioner = (*SRVUpstreams)(nil) + _ caddyfile.Unmarshaler = (*SRVUpstreams)(nil) + + _ UpstreamSource = (*AUpstreams)(nil) + _ caddy.Provisioner = (*AUpstreams)(nil) + _ caddyfile.Unmarshaler = (*AUpstreams)(nil) +) diff --git a/modules/l4proxy/upstreams_test.go b/modules/l4proxy/upstreams_test.go new file mode 100644 index 00000000..e7c786e5 --- /dev/null +++ b/modules/l4proxy/upstreams_test.go @@ -0,0 +1,330 @@ +// Copyright 2020 Matthew Holt +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package l4proxy + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "testing" + "time" + + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + "go.uber.org/zap" +) + +func srvWith(name string, recs []*net.SRV, err error, calls *int) *SRVUpstreams { + return &SRVUpstreams{ + Name: name, + Refresh: caddy.Duration(time.Minute), + logger: zap.NewNop(), + lookupSRV: func(context.Context, string, string, string) (string, []*net.SRV, error) { + if calls != nil { + *calls++ + } + return "", recs, err + }, + } +} + +func TestSRVGetUpstreamsDiscoversRecords(t *testing.T) { + recs := []*net.SRV{ + {Target: "db1.example.", Port: 5432}, + {Target: "db2.example.", Port: 5433}, + } + calls := 0 + su := srvWith("srv-discover.test", recs, nil, &calls) + + pool, err := su.GetUpstreams(caddy.NewReplacer()) + if err != nil { + t.Fatalf("getUpstreams: %v", err) + } + if len(pool) != 2 { + t.Fatalf("pool length = %d, want 2", len(pool)) + } + want := []string{"db1.example.:5432", "db2.example.:5433"} + for i, w := range want { + if pool[i].Dial[0] != w { + t.Errorf("dial[%d] = %q, want %q", i, pool[i].Dial[0], w) + } + if len(pool[i].peers) != 1 { + t.Errorf("upstream %d has %d peers, want 1", i, len(pool[i].peers)) + } + } + if calls != 1 { + t.Errorf("lookup calls = %d, want 1", calls) + } +} + +func TestSRVGetUpstreamsCaches(t *testing.T) { + calls := 0 + su := srvWith("srv-cache.test", []*net.SRV{{Target: "x.", Port: 1}}, nil, &calls) + repl := caddy.NewReplacer() + + if _, err := su.GetUpstreams(repl); err != nil { + t.Fatal(err) + } + if _, err := su.GetUpstreams(repl); err != nil { + t.Fatal(err) + } + if calls != 1 { + t.Errorf("lookup calls = %d, want 1 (second call should hit cache)", calls) + } +} + +func TestSRVGetUpstreamsLookupError(t *testing.T) { + su := srvWith("srv-error.test", nil, errors.New("dns boom"), nil) + if _, err := su.GetUpstreams(caddy.NewReplacer()); err == nil { + t.Fatal("expected an error when lookup fails and nothing is cached") + } +} + +func TestSRVExpandedAddr(t *testing.T) { + repl := caddy.NewReplacer() + + su := &SRVUpstreams{Service: "postgres", Proto: "tcp", Name: "db.local"} + addr, service, proto, name := su.expandedAddr(repl) + if addr != "_postgres._tcp.db.local" { + t.Errorf("addr = %q, want _postgres._tcp.db.local", addr) + } + if service != "postgres" || proto != "tcp" || name != "db.local" { + t.Errorf("parts = %q/%q/%q", service, proto, name) + } + + // service+proto empty: Name is the full domain + suName := &SRVUpstreams{Name: "_custom._tcp.svc"} + addr2, _, _, name2 := suName.expandedAddr(repl) + if addr2 != "_custom._tcp.svc" || name2 != "_custom._tcp.svc" { + t.Errorf("name-only addr = %q, name = %q", addr2, name2) + } +} + +func TestUnmarshalCaddyfileDynamicSRV(t *testing.T) { + d := caddyfile.NewTestDispenser("proxy {\n" + + "\tdynamic srv {\n" + + "\t\tservice postgres\n" + + "\t\tproto tcp\n" + + "\t\tname db.local\n" + + "\t\trefresh 30s\n" + + "\t}\n" + + "}") + h := new(Handler) + if err := h.UnmarshalCaddyfile(d); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(h.DynamicUpstreamsRaw) == 0 { + t.Fatal("DynamicUpstreamsRaw was not set") + } + var m map[string]any + if err := json.Unmarshal(h.DynamicUpstreamsRaw, &m); err != nil { + t.Fatalf("decoding DynamicUpstreamsRaw: %v", err) + } + if m["source"] != "srv" { + t.Errorf("source = %v, want srv", m["source"]) + } + if m["service"] != "postgres" || m["name"] != "db.local" { + t.Errorf("parsed fields wrong: %v", m) + } +} + +func TestUnmarshalCaddyfileDynamicErrors(t *testing.T) { + cases := map[string]string{ + "missing source": "proxy {\n\tdynamic\n}", + "unknown source": "proxy {\n\tdynamic nope\n}", + "bad srv option": "proxy {\n\tdynamic srv {\n\t\tbogus x\n\t}\n}", + } + for name, input := range cases { + t.Run(name, func(t *testing.T) { + h := new(Handler) + if err := h.UnmarshalCaddyfile(caddyfile.NewTestDispenser(input)); err == nil { + t.Fatalf("expected an error for %q, got nil", name) + } + }) + } +} + +func aWith(name, port string, addrs []string, err error, calls *int) *AUpstreams { + return &AUpstreams{ + Name: name, + Port: port, + Refresh: caddy.Duration(time.Minute), + logger: zap.NewNop(), + lookupHost: func(context.Context, string) ([]string, error) { + if calls != nil { + *calls++ + } + return addrs, err + }, + } +} + +func TestAGetUpstreamsDiscoversAddresses(t *testing.T) { + calls := 0 + au := aWith("db.a-discover.test", "5432", []string{"10.0.0.1", "10.0.0.2"}, nil, &calls) + + pool, err := au.GetUpstreams(caddy.NewReplacer()) + if err != nil { + t.Fatalf("GetUpstreams: %v", err) + } + if len(pool) != 2 { + t.Fatalf("pool length = %d, want 2", len(pool)) + } + want := []string{"10.0.0.1:5432", "10.0.0.2:5432"} + for i, w := range want { + if pool[i].Dial[0] != w { + t.Errorf("dial[%d] = %q, want %q", i, pool[i].Dial[0], w) + } + } + if calls != 1 { + t.Errorf("lookup calls = %d, want 1", calls) + } +} + +func TestAGetUpstreamsCaches(t *testing.T) { + calls := 0 + au := aWith("db.a-cache.test", "5432", []string{"10.0.0.9"}, nil, &calls) + repl := caddy.NewReplacer() + + if _, err := au.GetUpstreams(repl); err != nil { + t.Fatal(err) + } + if _, err := au.GetUpstreams(repl); err != nil { + t.Fatal(err) + } + if calls != 1 { + t.Errorf("lookup calls = %d, want 1 (second call should hit cache)", calls) + } +} + +func TestUnmarshalCaddyfileDynamicA(t *testing.T) { + d := caddyfile.NewTestDispenser("proxy {\n" + + "\tdynamic a {\n" + + "\t\tname db.local\n" + + "\t\tport 5432\n" + + "\t\trefresh 15s\n" + + "\t}\n" + + "}") + h := new(Handler) + if err := h.UnmarshalCaddyfile(d); err != nil { + t.Fatalf("unmarshal: %v", err) + } + var m map[string]any + if err := json.Unmarshal(h.DynamicUpstreamsRaw, &m); err != nil { + t.Fatalf("decoding DynamicUpstreamsRaw: %v", err) + } + if m["source"] != "a" { + t.Errorf("source = %v, want a", m["source"]) + } + if m["name"] != "db.local" || m["port"] != "5432" { + t.Errorf("parsed fields wrong: %v", m) + } +} + +func TestSRVGracePeriodServesStale(t *testing.T) { + failing := false + su := &SRVUpstreams{ + Name: "srv-grace-cov.test", + Refresh: caddy.Duration(time.Nanosecond), + GracePeriod: caddy.Duration(time.Hour), + logger: zap.NewNop(), + lookupSRV: func(context.Context, string, string, string) (string, []*net.SRV, error) { + if failing { + return "", nil, errors.New("dns boom") + } + return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil + }, + } + repl := caddy.NewReplacer() + if _, err := su.GetUpstreams(repl); err != nil { + t.Fatalf("seeding: %v", err) + } + failing = true // entry is already stale (refresh 1ns); next lookup fails + pool, err := su.GetUpstreams(repl) + if err != nil { + t.Fatalf("grace period should suppress the error: %v", err) + } + if len(pool) != 1 { + t.Errorf("expected the stale cached pool to be served, got %d", len(pool)) + } +} + +func TestAGracePeriodServesStale(t *testing.T) { + failing := false + au := &AUpstreams{ + Name: "a-grace-cov.test", + Port: "5432", + Refresh: caddy.Duration(time.Nanosecond), + GracePeriod: caddy.Duration(time.Hour), + logger: zap.NewNop(), + lookupHost: func(context.Context, string) ([]string, error) { + if failing { + return nil, errors.New("dns boom") + } + return []string{"10.0.0.1"}, nil + }, + } + repl := caddy.NewReplacer() + if _, err := au.GetUpstreams(repl); err != nil { + t.Fatalf("seeding: %v", err) + } + failing = true + pool, err := au.GetUpstreams(repl) + if err != nil { + t.Fatalf("grace period should suppress the error: %v", err) + } + if len(pool) != 1 { + t.Errorf("expected the stale cached pool to be served, got %d", len(pool)) + } +} + +func TestNewDynamicUpstreamInvalid(t *testing.T) { + // a non-numeric port makes ParseNetworkAddress fail + if _, err := newDynamicUpstream("host:notaport"); err == nil { + t.Fatal("expected an error for an invalid dial address") + } +} + +func TestSRVCacheBound(t *testing.T) { + for i := 0; i < 101; i++ { + su := srvWith(fmt.Sprintf("srv-bound-%d.test", i), []*net.SRV{{Target: "a.example.", Port: 1}}, nil, nil) + if _, err := su.GetUpstreams(caddy.NewReplacer()); err != nil { + t.Fatalf("insert %d: %v", i, err) + } + } + srvCacheMu.RLock() + n := len(srvCache) + srvCacheMu.RUnlock() + if n > 100 { + t.Errorf("srv cache not bounded: %d entries", n) + } +} + +func TestACacheBound(t *testing.T) { + for i := 0; i < 101; i++ { + au := aWith(fmt.Sprintf("a-bound-%d.test", i), "5432", []string{"10.0.0.1"}, nil, nil) + if _, err := au.GetUpstreams(caddy.NewReplacer()); err != nil { + t.Fatalf("insert %d: %v", i, err) + } + } + aCacheMu.RLock() + n := len(aCache) + aCacheMu.RUnlock() + if n > 100 { + t.Errorf("a cache not bounded: %d entries", n) + } +}