diff --git a/cmd/api/config/config.go b/cmd/api/config/config.go index 6abfe952..4e0bdadf 100644 --- a/cmd/api/config/config.go +++ b/cmd/api/config/config.go @@ -106,12 +106,13 @@ type MetricsConfig struct { // OtelConfig holds OpenTelemetry settings. type OtelConfig struct { - Enabled bool `koanf:"enabled"` - Endpoint string `koanf:"endpoint"` - ServiceName string `koanf:"service_name"` - ServiceInstanceID string `koanf:"service_instance_id"` - Insecure bool `koanf:"insecure"` - MetricExportInterval string `koanf:"metric_export_interval"` + Enabled bool `koanf:"enabled"` + Endpoint string `koanf:"endpoint"` + ServiceName string `koanf:"service_name"` + ServiceInstanceID string `koanf:"service_instance_id"` + Insecure bool `koanf:"insecure"` + MetricExportInterval string `koanf:"metric_export_interval"` + SuccessfulGetSampleRatio float64 `koanf:"successful_get_sample_ratio"` } // LoggingConfig holds log rotation and level settings. @@ -302,12 +303,13 @@ func defaultConfig() *Config { }, Otel: OtelConfig{ - Enabled: false, - Endpoint: "127.0.0.1:4317", - ServiceName: "hypeman", - ServiceInstanceID: getHostname(), - Insecure: true, - MetricExportInterval: "60s", + Enabled: false, + Endpoint: "127.0.0.1:4317", + ServiceName: "hypeman", + ServiceInstanceID: getHostname(), + Insecure: true, + MetricExportInterval: "60s", + SuccessfulGetSampleRatio: 0.1, }, Logging: LoggingConfig{ @@ -479,6 +481,9 @@ func (c *Config) Validate() error { return fmt.Errorf("otel.metric_export_interval must be a valid duration, got %q: %w", c.Otel.MetricExportInterval, err) } } + if c.Otel.SuccessfulGetSampleRatio < 0 || c.Otel.SuccessfulGetSampleRatio > 1 { + return fmt.Errorf("otel.successful_get_sample_ratio must be between 0 and 1, got %v", c.Otel.SuccessfulGetSampleRatio) + } if c.Oversubscription.CPU <= 0 { return fmt.Errorf("oversubscription.cpu must be positive, got %v", c.Oversubscription.CPU) } diff --git a/cmd/api/config/config_test.go b/cmd/api/config/config_test.go index 2923c973..6ae361a1 100644 --- a/cmd/api/config/config_test.go +++ b/cmd/api/config/config_test.go @@ -28,6 +28,9 @@ func TestDefaultConfigIncludesMetricsSettings(t *testing.T) { if cfg.Otel.MetricExportInterval != "60s" { t.Fatalf("expected default otel.metric_export_interval to be 60s, got %q", cfg.Otel.MetricExportInterval) } + if cfg.Otel.SuccessfulGetSampleRatio != 0.1 { + t.Fatalf("expected default otel.successful_get_sample_ratio to be 0.1, got %v", cfg.Otel.SuccessfulGetSampleRatio) + } } func TestLoadEnvOverridesMetricsAndOtelInterval(t *testing.T) { @@ -36,6 +39,7 @@ func TestLoadEnvOverridesMetricsAndOtelInterval(t *testing.T) { t.Setenv("METRICS__VM_LABEL_BUDGET", "350") t.Setenv("METRICS__RESOURCE_REFRESH_INTERVAL", "30s") t.Setenv("OTEL__METRIC_EXPORT_INTERVAL", "15s") + t.Setenv("OTEL__SUCCESSFUL_GET_SAMPLE_RATIO", "0.25") tmp := t.TempDir() cfgPath := filepath.Join(tmp, "config.yaml") @@ -63,6 +67,9 @@ func TestLoadEnvOverridesMetricsAndOtelInterval(t *testing.T) { if cfg.Otel.MetricExportInterval != "15s" { t.Fatalf("expected otel.metric_export_interval override, got %q", cfg.Otel.MetricExportInterval) } + if cfg.Otel.SuccessfulGetSampleRatio != 0.25 { + t.Fatalf("expected otel.successful_get_sample_ratio override, got %v", cfg.Otel.SuccessfulGetSampleRatio) + } } func TestValidateRejectsInvalidMetricsPort(t *testing.T) { @@ -85,6 +92,16 @@ func TestValidateRejectsInvalidMetricExportInterval(t *testing.T) { } } +func TestValidateRejectsInvalidSuccessfulGetSampleRatio(t *testing.T) { + cfg := defaultConfig() + cfg.Otel.SuccessfulGetSampleRatio = 1.1 + + err := cfg.Validate() + if err == nil { + t.Fatalf("expected validation error for invalid successful get sample ratio") + } +} + func TestValidateRejectsInvalidVMLabelBudget(t *testing.T) { cfg := defaultConfig() cfg.Metrics.VMLabelBudget = 0 diff --git a/cmd/api/main.go b/cmd/api/main.go index 21e6514d..01b17f57 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -78,14 +78,15 @@ func run() error { // Initialize OpenTelemetry (before wire initialization) otelCfg := otel.Config{ - Enabled: cfg.Otel.Enabled, - Endpoint: cfg.Otel.Endpoint, - ServiceName: cfg.Otel.ServiceName, - ServiceInstanceID: cfg.Otel.ServiceInstanceID, - Insecure: cfg.Otel.Insecure, - MetricExportInterval: cfg.Otel.MetricExportInterval, - Version: cfg.Version, - Env: cfg.Env, + Enabled: cfg.Otel.Enabled, + Endpoint: cfg.Otel.Endpoint, + ServiceName: cfg.Otel.ServiceName, + ServiceInstanceID: cfg.Otel.ServiceInstanceID, + Insecure: cfg.Otel.Insecure, + MetricExportInterval: cfg.Otel.MetricExportInterval, + SuccessfulGetSampleRatio: cfg.Otel.SuccessfulGetSampleRatio, + Version: cfg.Version, + Env: cfg.Env, } otelProvider, otelShutdown, err := otel.Init(context.Background(), otelCfg) @@ -149,7 +150,7 @@ func run() error { // Log OTel status if cfg.Otel.Enabled { - logger.Info("OpenTelemetry push enabled", "endpoint", cfg.Otel.Endpoint, "service", cfg.Otel.ServiceName, "metric_export_interval", cfg.Otel.MetricExportInterval) + logger.Info("OpenTelemetry push enabled", "endpoint", cfg.Otel.Endpoint, "service", cfg.Otel.ServiceName, "metric_export_interval", cfg.Otel.MetricExportInterval, "successful_get_sample_ratio", cfg.Otel.SuccessfulGetSampleRatio) } else { logger.Info("OpenTelemetry push disabled; Prometheus pull metrics remain available") } diff --git a/lib/hypervisor/firecracker/firecracker.go b/lib/hypervisor/firecracker/firecracker.go index 85e2a40c..e22e42f3 100644 --- a/lib/hypervisor/firecracker/firecracker.go +++ b/lib/hypervisor/firecracker/firecracker.go @@ -283,15 +283,30 @@ func (f *Firecracker) do(ctx context.Context, method, path string, reqBody any, attribute.String("http.method", method), attribute.String("http.route", path), ) - ctx, span := otel.Tracer("hypeman/hypervisor/firecracker").Start(ctx, "hypervisor.http "+method+" "+path, trace.WithAttributes(attrs...)) - defer span.End() + tracer := otel.Tracer("hypeman/hypervisor/firecracker") + spanName := "hypervisor.http " + method + " " + path + shouldTrace := hypervisor.ShouldTraceHypervisorHTTPSpan(method, path) + + var span trace.Span + if shouldTrace { + var spanCtx context.Context + spanCtx, span = tracer.Start(ctx, spanName, trace.WithAttributes(attrs...)) + ctx = spanCtx + defer span.End() + } + + recordError := func(err error) { + if shouldTrace { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } + } var bodyReader io.Reader if reqBody != nil { data, err := json.Marshal(reqBody) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) + recordError(err) return nil, fmt.Errorf("marshal request body: %w", err) } bodyReader = bytes.NewReader(data) @@ -299,8 +314,7 @@ func (f *Firecracker) do(ctx context.Context, method, path string, reqBody any, req, err := http.NewRequestWithContext(ctx, method, "http://localhost"+path, bodyReader) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) + recordError(err) return nil, fmt.Errorf("create request: %w", err) } req.Header.Set("Accept", "application/json") @@ -310,23 +324,25 @@ func (f *Firecracker) do(ctx context.Context, method, path string, reqBody any, resp, err := f.client.Do(req) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) + recordError(err) return nil, fmt.Errorf("request %s %s: %w", method, path, err) } defer resp.Body.Close() - span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode)) + if shouldTrace { + span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode)) + } data, err := io.ReadAll(resp.Body) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) + recordError(err) return nil, fmt.Errorf("read response body: %w", err) } for _, status := range expectedStatus { if resp.StatusCode == status { - span.SetStatus(codes.Ok, "") + if shouldTrace { + span.SetStatus(codes.Ok, "") + } return data, nil } } @@ -334,11 +350,17 @@ func (f *Firecracker) do(ctx context.Context, method, path string, reqBody any, if len(data) > 0 { var apiErr apiError if err := json.Unmarshal(data, &apiErr); err == nil && apiErr.FaultMessage != "" { - span.SetStatus(codes.Error, apiErr.FaultMessage) + if shouldTrace { + span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode)) + span.SetStatus(codes.Error, apiErr.FaultMessage) + } return nil, fmt.Errorf("status %d: %s", resp.StatusCode, apiErr.FaultMessage) } } - span.SetStatus(codes.Error, resp.Status) + if shouldTrace { + span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode)) + span.SetStatus(codes.Error, resp.Status) + } return nil, fmt.Errorf("status %d: %s", resp.StatusCode, string(data)) } diff --git a/lib/hypervisor/tracing.go b/lib/hypervisor/tracing.go index 63908822..79888302 100644 --- a/lib/hypervisor/tracing.go +++ b/lib/hypervisor/tracing.go @@ -2,6 +2,7 @@ package hypervisor import ( "context" + "net/http" "time" "github.com/kernel/hypeman/lib/paths" @@ -52,6 +53,18 @@ func TraceAttributesFromContext(ctx context.Context) []attribute.KeyValue { return out } +func ShouldTraceHypervisorHTTPSpan(method, path string) bool { + if method != http.MethodGet { + return true + } + switch path { + case "/", "/api/v1/vm.info": + return false + default: + return true + } +} + func WrapHypervisor(hvType Type, hv Hypervisor) Hypervisor { if hv == nil { return nil @@ -115,6 +128,21 @@ func FinishTraceSpan(span trace.Span, err error) { finishTraceSpan(span, err) } +func StartDetachedTraceSpan(ctx context.Context, tracer trace.Tracer, name string, attrs ...attribute.KeyValue) (context.Context, trace.Span) { + allAttrs := TraceAttributesFromContext(ctx) + if len(attrs) > 0 { + allAttrs = append(allAttrs, attrs...) + } + + spanOpts := []trace.SpanStartOption{ + trace.WithNewRoot(), + } + if len(allAttrs) > 0 { + spanOpts = append(spanOpts, trace.WithAttributes(allAttrs...)) + } + return tracer.Start(context.Background(), name, spanOpts...) +} + func startTraceSpan(ctx context.Context, tracer trace.Tracer, name string, attrs ...attribute.KeyValue) (context.Context, trace.Span) { allAttrs := TraceAttributesFromContext(ctx) if len(attrs) > 0 { @@ -164,9 +192,17 @@ func (h *tracingHypervisor) Shutdown(ctx context.Context) (err error) { } func (h *tracingHypervisor) GetVMInfo(ctx context.Context) (_ *VMInfo, err error) { - ctx, span := startTraceSpan(ctx, h.tracer, "hypervisor.get_vm_info", h.spanAttrs(attribute.String("operation", "get_vm_info"))...) - defer func() { finishTraceSpan(span, err) }() - return h.next.GetVMInfo(ctx) + info, err := h.next.GetVMInfo(ctx) + if err != nil { + _, span := StartDetachedTraceSpan(ctx, h.tracer, "hypervisor.get_vm_info", + h.spanAttrs( + attribute.String("operation", "get_vm_info"), + attribute.String("sampled_from", "error_only"), + )..., + ) + finishTraceSpan(span, err) + } + return info, err } func (h *tracingHypervisor) Pause(ctx context.Context) (err error) { diff --git a/lib/hypervisor/tracing_test.go b/lib/hypervisor/tracing_test.go index d203af4a..8d5147c6 100644 --- a/lib/hypervisor/tracing_test.go +++ b/lib/hypervisor/tracing_test.go @@ -2,6 +2,7 @@ package hypervisor import ( "context" + "errors" "testing" "time" @@ -16,6 +17,7 @@ import ( ) type fakeHypervisor struct{} +type fakeHypervisorGetVMInfoError struct{} func (fakeHypervisor) DeleteVM(context.Context) error { return nil } func (fakeHypervisor) Shutdown(context.Context) error { return nil } @@ -33,7 +35,28 @@ func (fakeHypervisor) SetTargetGuestMemoryBytes(context.Context, int64) error { func (fakeHypervisor) GetTargetGuestMemoryBytes(context.Context) (int64, error) { return 0, nil } -func (fakeHypervisor) Capabilities() Capabilities { return Capabilities{} } +func (fakeHypervisor) Capabilities() Capabilities { return Capabilities{} } +func (fakeHypervisorGetVMInfoError) DeleteVM(context.Context) error { return nil } +func (fakeHypervisorGetVMInfoError) Shutdown(context.Context) error { return nil } +func (fakeHypervisorGetVMInfoError) GetVMInfo(context.Context) (*VMInfo, error) { + return nil, errors.New("vm info failed") +} +func (fakeHypervisorGetVMInfoError) Pause(context.Context) error { return nil } +func (fakeHypervisorGetVMInfoError) Resume(context.Context) error { return nil } +func (fakeHypervisorGetVMInfoError) Snapshot(context.Context, string) error { return nil } +func (fakeHypervisorGetVMInfoError) ResizeMemory(context.Context, int64) error { + return nil +} +func (fakeHypervisorGetVMInfoError) ResizeMemoryAndWait(context.Context, int64, time.Duration) error { + return nil +} +func (fakeHypervisorGetVMInfoError) SetTargetGuestMemoryBytes(context.Context, int64) error { + return nil +} +func (fakeHypervisorGetVMInfoError) GetTargetGuestMemoryBytes(context.Context) (int64, error) { + return 0, nil +} +func (fakeHypervisorGetVMInfoError) Capabilities() Capabilities { return Capabilities{} } type fakeStarter struct { returned Hypervisor @@ -100,6 +123,37 @@ func TestWrapVMStarterWrapsReturnedHypervisor(t *testing.T) { assert.Equal(t, string(TypeCloudHypervisor), attrs["hypervisor"]) } +func TestWrapHypervisorSkipsGetVMInfoTraceByDefault(t *testing.T) { + recorder, _ := newTestTracerProvider(t) + + hv := WrapHypervisor(TypeQEMU, fakeHypervisor{}) + _, err := hv.GetVMInfo(context.Background()) + require.NoError(t, err) + + for _, span := range recorder.Ended() { + if span.Name() == "hypervisor.get_vm_info" { + t.Fatalf("expected get vm info to be skipped by default") + } + } +} + +func TestWrapHypervisorCreatesDetachedErrorSpanForGetVMInfoFailures(t *testing.T) { + recorder, _ := newTestTracerProvider(t) + + ctx := WithTraceAttributes(context.Background(), attribute.String("instance_id", "inst_999")) + hv := WrapHypervisor(TypeQEMU, fakeHypervisorGetVMInfoError{}) + _, err := hv.GetVMInfo(ctx) + require.Error(t, err) + + span := findSpanByName(t, recorder.Ended(), "hypervisor.get_vm_info") + require.False(t, span.Parent().IsValid()) + + attrs := attrsToMap(span.Attributes()) + assert.Equal(t, "inst_999", attrs["instance_id"]) + assert.Equal(t, string(TypeQEMU), attrs["hypervisor"]) + assert.Equal(t, "error_only", attrs["sampled_from"]) +} + func newTestTracerProvider(t *testing.T) (*tracetest.SpanRecorder, *sdktrace.TracerProvider) { t.Helper() diff --git a/lib/hypervisor/vz/client.go b/lib/hypervisor/vz/client.go index d1331c56..c3132c1b 100644 --- a/lib/hypervisor/vz/client.go +++ b/lib/hypervisor/vz/client.go @@ -145,33 +145,52 @@ func (c *Client) doGet(ctx context.Context, path string) ([]byte, error) { attribute.String("http.method", http.MethodGet), attribute.String("http.route", path), ) - ctx, span := otel.Tracer("hypeman/hypervisor/vz").Start(ctx, "hypervisor.http GET "+path, trace.WithAttributes(attrs...)) - defer span.End() + tracer := otel.Tracer("hypeman/hypervisor/vz") + spanName := "hypervisor.http GET " + path + shouldTrace := hypervisor.ShouldTraceHypervisorHTTPSpan(http.MethodGet, path) + + var span trace.Span + if shouldTrace { + var spanCtx context.Context + spanCtx, span = tracer.Start(ctx, spanName, trace.WithAttributes(attrs...)) + ctx = spanCtx + defer span.End() + } + + recordError := func(err error) { + if shouldTrace { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } + } req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://vz-shim"+path, nil) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) + recordError(err) return nil, err } resp, err := c.httpClient.Do(req) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) + recordError(err) return nil, err } defer resp.Body.Close() - span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode)) + if shouldTrace { + span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode)) + } body, err := io.ReadAll(resp.Body) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) + recordError(err) return nil, err } if resp.StatusCode >= http.StatusBadRequest { - span.SetStatus(codes.Error, resp.Status) + if shouldTrace { + span.SetStatus(codes.Error, resp.Status) + } } else { - span.SetStatus(codes.Ok, "") + if shouldTrace { + span.SetStatus(codes.Ok, "") + } } return body, nil } diff --git a/lib/otel/README.md b/lib/otel/README.md index afff2054..05c8af6b 100644 --- a/lib/otel/README.md +++ b/lib/otel/README.md @@ -29,6 +29,7 @@ This keeps pull and push views aligned because both are sourced from the same OT | `OTEL_SERVICE_INSTANCE_ID` | Instance ID (`service.instance.id` attribute) | hostname | | `OTEL_INSECURE` | Disable TLS for OTLP | `true` | | `OTEL__METRIC_EXPORT_INTERVAL` | OTLP metric push interval (when enabled) | `60s` | +| `OTEL__SUCCESSFUL_GET_SAMPLE_RATIO` | Sample rate for successful root HTTP `GET` traces | `0.1` | | `METRICS__LISTEN_ADDRESS` | Bind address for `/metrics` listener | `127.0.0.1` | | `METRICS__PORT` | Port for `/metrics` listener | `9464` | | `METRICS__VM_LABEL_BUDGET` | Warning threshold for observed per-VM labeled VM metrics | `200` | diff --git a/lib/otel/http_sampling.go b/lib/otel/http_sampling.go new file mode 100644 index 00000000..179bc51f --- /dev/null +++ b/lib/otel/http_sampling.go @@ -0,0 +1,43 @@ +package otel + +import ( + "fmt" + "net/http" + + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" +) + +type successfulGETSampler struct { + getSampler sdktrace.Sampler + ratio float64 +} + +func newSuccessfulGETSampler(ratio float64) sdktrace.Sampler { + return sdktrace.ParentBased(&successfulGETSampler{ + getSampler: sdktrace.TraceIDRatioBased(ratio), + ratio: ratio, + }) +} + +func (s *successfulGETSampler) ShouldSample(params sdktrace.SamplingParameters) sdktrace.SamplingResult { + if params.Kind == trace.SpanKindServer && httpMethodFromAttrs(params.Attributes) == http.MethodGet { + return s.getSampler.ShouldSample(params) + } + return sdktrace.SamplingResult{Decision: sdktrace.RecordAndSample} +} + +func (s *successfulGETSampler) Description() string { + return fmt.Sprintf("successfulGETSampler{ratio=%.4f}", s.ratio) +} + +func httpMethodFromAttrs(attrs []attribute.KeyValue) string { + for _, attr := range attrs { + switch string(attr.Key) { + case "http.method", "http.request.method": + return attr.Value.AsString() + } + } + return "" +} diff --git a/lib/otel/http_sampling_test.go b/lib/otel/http_sampling_test.go new file mode 100644 index 00000000..9f7a710d --- /dev/null +++ b/lib/otel/http_sampling_test.go @@ -0,0 +1,88 @@ +package otel + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/riandyrn/otelchi" + otelapi "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +func TestSuccessfulGETSamplerDropsSuccessfulGETRequests(t *testing.T) { + recorder, router, shutdown := newHTTPTraceTestHarness(t, 0) + defer shutdown() + + router.Get("/instances", func(w http.ResponseWriter, r *http.Request) {}) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/instances", nil) + router.ServeHTTP(rr, req) + + if got := len(recorder.Ended()); got != 0 { + t.Fatalf("expected no spans for sampled-out successful GET, got %d", got) + } +} + +func TestSuccessfulGETSamplerKeepsSuccessfulPOSTRequests(t *testing.T) { + recorder, router, shutdown := newHTTPTraceTestHarness(t, 0) + defer shutdown() + + router.Post("/instances", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + }) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/instances", nil) + router.ServeHTTP(rr, req) + + span := findEndedSpanByName(t, recorder.Ended(), "/instances") + if got := attrValue(span.Attributes(), "http.method"); got != http.MethodPost { + t.Fatalf("expected POST span, got attrs %v", span.Attributes()) + } +} + +func newHTTPTraceTestHarness(t *testing.T, getRatio float64) (*tracetest.SpanRecorder, *chi.Mux, func()) { + t.Helper() + + recorder := tracetest.NewSpanRecorder() + provider := sdktrace.NewTracerProvider( + sdktrace.WithSampler(newSuccessfulGETSampler(getRatio)), + sdktrace.WithSpanProcessor(recorder), + ) + previous := otelapi.GetTracerProvider() + otelapi.SetTracerProvider(provider) + + router := chi.NewRouter() + router.Use(otelchi.Middleware("hypeman-test", otelchi.WithChiRoutes(router))) + + return recorder, router, func() { + otelapi.SetTracerProvider(previous) + _ = provider.Shutdown(context.Background()) + } +} + +func findEndedSpanByName(t *testing.T, spans []sdktrace.ReadOnlySpan, name string) sdktrace.ReadOnlySpan { + t.Helper() + for _, span := range spans { + if span.Name() == name { + return span + } + } + t.Fatalf("span %q not found", name) + return nil +} + +func attrValue(attrs []attribute.KeyValue, key string) string { + for _, attr := range attrs { + if string(attr.Key) == key { + return attr.Value.Emit() + } + } + return "" +} diff --git a/lib/otel/otel.go b/lib/otel/otel.go index 74911d5a..3784b7cf 100644 --- a/lib/otel/otel.go +++ b/lib/otel/otel.go @@ -31,14 +31,15 @@ import ( // Config holds OpenTelemetry configuration. type Config struct { - Enabled bool - Endpoint string - ServiceName string - ServiceInstanceID string - Insecure bool - MetricExportInterval string - Version string - Env string + Enabled bool + Endpoint string + ServiceName string + ServiceInstanceID string + Insecure bool + MetricExportInterval string + SuccessfulGetSampleRatio float64 + Version string + Env string } // Provider holds initialized OTel providers. @@ -145,11 +146,13 @@ func Init(ctx context.Context, cfg Config) (*Provider, func(context.Context) err if traceErr != nil { slog.Warn("failed to initialize OTLP trace exporter; continuing without trace export", "error", traceErr) tracerProvider = sdktrace.NewTracerProvider( + sdktrace.WithSampler(newSuccessfulGETSampler(cfg.SuccessfulGetSampleRatio)), sdktrace.WithResource(res), ) } else { tracerProvider = sdktrace.NewTracerProvider( sdktrace.WithBatcher(traceExporter), + sdktrace.WithSampler(newSuccessfulGETSampler(cfg.SuccessfulGetSampleRatio)), sdktrace.WithResource(res), ) } diff --git a/lib/vmm/client.go b/lib/vmm/client.go index 70c1e849..c2d0495c 100644 --- a/lib/vmm/client.go +++ b/lib/vmm/client.go @@ -42,21 +42,31 @@ func (m *metricsRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro attribute.String("http.method", req.Method), attribute.String("http.route", req.URL.Path), ) - ctx, span := m.tracer.Start(req.Context(), "hypervisor.http "+req.Method+" "+req.URL.Path, trace.WithAttributes(attrs...)) - req = req.WithContext(ctx) + spanName := "hypervisor.http " + req.Method + " " + req.URL.Path + shouldTrace := hypervisor.ShouldTraceHypervisorHTTPSpan(req.Method, req.URL.Path) + + var span trace.Span + if shouldTrace { + ctx, startedSpan := m.tracer.Start(req.Context(), spanName, trace.WithAttributes(attrs...)) + span = startedSpan + req = req.WithContext(ctx) + } resp, err := m.base.RoundTrip(req) - if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - } else { - span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode)) - if resp.StatusCode >= 400 { - span.SetStatus(codes.Error, resp.Status) + switch { + case shouldTrace: + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) } else { - span.SetStatus(codes.Ok, "") + span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode)) + if resp.StatusCode >= 400 { + span.SetStatus(codes.Error, resp.Status) + } else { + span.SetStatus(codes.Ok, "") + } } + span.End() } - span.End() // Record metrics using global VMMMetrics if VMMMetrics != nil { diff --git a/lib/vmm/client_tracing_test.go b/lib/vmm/client_tracing_test.go index 35f8bb11..6982ec58 100644 --- a/lib/vmm/client_tracing_test.go +++ b/lib/vmm/client_tracing_test.go @@ -72,3 +72,71 @@ func TestMetricsRoundTripperCreatesTraceSpan(t *testing.T) { assert.Equal(t, "PUT /api/v1/vm.resume", attrs["operation"]) assert.Equal(t, "204", attrs["http.status_code"]) } + +func TestMetricsRoundTripperSkipsVMInfoTraceByDefault(t *testing.T) { + recorder := tracetest.NewSpanRecorder() + provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder)) + previous := otel.GetTracerProvider() + otel.SetTracerProvider(provider) + t.Cleanup(func() { + otel.SetTracerProvider(previous) + _ = provider.Shutdown(context.Background()) + }) + + rt := &metricsRoundTripper{ + base: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Body: io.NopCloser(strings.NewReader(`{}`)), + }, nil + }), + tracer: otel.Tracer("hypeman/vmm"), + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost/api/v1/vm.info", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + for _, span := range recorder.Ended() { + if span.Name() == "hypervisor.http GET /api/v1/vm.info" { + t.Fatalf("expected vm.info trace span to be suppressed by default") + } + } +} + +func TestMetricsRoundTripperSkipsVMInfoTraceOnErrors(t *testing.T) { + recorder := tracetest.NewSpanRecorder() + provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder)) + previous := otel.GetTracerProvider() + otel.SetTracerProvider(provider) + t.Cleanup(func() { + otel.SetTracerProvider(previous) + _ = provider.Shutdown(context.Background()) + }) + + rt := &metricsRoundTripper{ + base: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Status: "500 Internal Server Error", + Body: io.NopCloser(strings.NewReader("boom")), + }, nil + }), + tracer: otel.Tracer("hypeman/vmm"), + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost/api/v1/vm.info", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + for _, span := range recorder.Ended() { + if span.Name() == "hypervisor.http GET /api/v1/vm.info" { + t.Fatalf("expected vm.info trace span to stay suppressed even on errors") + } + } +}