From 8a52963d9d8eed011f6ac2409dc7a0123c5499bf Mon Sep 17 00:00:00 2001 From: Kevin Schneider Date: Sat, 2 May 2026 08:34:31 -0700 Subject: [PATCH] Tests Phase 2: httptest-mocked HTTP integration + H1/H2 regression guards Second of three planned phases. Phase 2 covers the L2 layer per the test plan: in-process httptest.Server fakes for the upstream services FusionDataCLI talks to (APS GraphQL, APS OAuth, Fusion MCP JSON-RPC), exercising every network-bound code path without ever leaving the process. Coverage jumps: auth 27.6% -> 73.9% (+46pp; OAuth exchange/refresh, callback) api 13.2% -> 70.9% (+58pp; gqlQuery, pagination, details, STEP derivative, file download) fusion 27.9% -> 83.8% (+56pp; MCP init+tools/call, session retry, SSE response parsing, stateless mode) config 90.6% -> 90.6% (unchanged; already covered in Phase 1) ui 24.5% -> 24.5% (unchanged; UI flow tests are Phase 3) total 23.9% -> 38.5% New shared test scaffolding: internal/testutil/graphql.go GraphQLServer helper internal/testutil/mcp.go NewMCPServer helper with InitCount, CallCount, SessionIDsSeen probes Production refactors (minimal, all const->var so tests can swap): api/client.go graphqlEndpoint auth/oauth.go authEndpoint, tokenEndpoint, authScope auth/callback.go callbackPort, CallbackURL api/download.go userHomeDir, nowFunc (for StepDownloadPath) Security regression coverage added: H1 TestWaitForCallback_OAuthError asserts the response body contains `<script>` and not the literal `") — spelled out so the + // security intent is obvious in the source. + const encodedDesc = "%3Cscript%3Ealert%281%29%3C%2Fscript%3E" + resp, err := http.Get(CallbackURL + "?error=access_denied&error_description=" + encodedDesc) + if err != nil { + t.Fatalf("http.Get: %v", err) + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + t.Fatalf("read body: %v", err) + } + + bodyStr := string(body) + if !strings.Contains(bodyStr, "<script>") { + t.Errorf("body does not contain HTML-escaped %q; body=%q", "<script>", bodyStr) + } + if strings.Contains(bodyStr, "") { + t.Errorf("body contains unescaped script tag (XSS regression!); body=%q", bodyStr) + } + + select { + case r := <-resCh: + if r.err == nil { + t.Fatalf("WaitForCallback returned nil error; code=%q", r.code) + } + if !strings.Contains(r.err.Error(), "access_denied") { + t.Errorf("error %q does not contain %q", r.err.Error(), "access_denied") + } + if !strings.Contains(r.err.Error(), "") { + // The error message itself isn't HTML-rendered, so the raw + // description is fine — what matters is the HTTP body escaped it. + t.Errorf("error %q does not contain raw description %q", r.err.Error(), "") + } + case <-time.After(2 * time.Second): + t.Fatal("WaitForCallback did not return within 2s") + } +} + +func TestWaitForCallback_NoCode(t *testing.T) { + port := useFreeCallbackPort(t) + + type result struct { + code string + err error + } + resCh := make(chan result, 1) + go func() { + code, err := WaitForCallback(context.Background()) + resCh <- result{code: code, err: err} + }() + + waitForServer(t, port) + + resp, err := http.Get(CallbackURL + "?") + if err != nil { + t.Fatalf("http.Get: %v", err) + } + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + select { + case r := <-resCh: + if r.err == nil { + t.Fatalf("WaitForCallback returned nil error; code=%q", r.code) + } + if !strings.Contains(r.err.Error(), "no authorization code") { + t.Errorf("error %q does not contain %q", r.err.Error(), "no authorization code") + } + case <-time.After(2 * time.Second): + t.Fatal("WaitForCallback did not return within 2s") + } +} + +func TestWaitForCallback_ContextCancel(t *testing.T) { + port := useFreeCallbackPort(t) + + ctx, cancel := context.WithCancel(context.Background()) + type result struct { + code string + err error + } + resCh := make(chan result, 1) + go func() { + code, err := WaitForCallback(ctx) + resCh <- result{code: code, err: err} + }() + + waitForServer(t, port) + cancel() + + select { + case r := <-resCh: + if r.err == nil { + t.Fatalf("WaitForCallback returned nil error after cancel; code=%q", r.code) + } + if !errors.Is(r.err, context.Canceled) { + t.Errorf("error = %v, want wraps context.Canceled", r.err) + } + case <-time.After(2 * time.Second): + t.Fatal("WaitForCallback did not return within 2s after cancel") + } +} diff --git a/auth/oauth.go b/auth/oauth.go index 836a101..1d4c58c 100644 --- a/auth/oauth.go +++ b/auth/oauth.go @@ -16,7 +16,9 @@ import ( "time" ) -const ( +// Endpoints are vars (not consts) so tests can swap them for an +// httptest.Server URL. Production code never reassigns them. +var ( authEndpoint = "https://developer.api.autodesk.com/authentication/v2/authorize" tokenEndpoint = "https://developer.api.autodesk.com/authentication/v2/token" authScope = "data:read user-profile:read" diff --git a/auth/oauth_test.go b/auth/oauth_test.go index 0409329..4c46425 100644 --- a/auth/oauth_test.go +++ b/auth/oauth_test.go @@ -1,8 +1,13 @@ package auth import ( + "context" + "net/http" + "net/http/httptest" "net/url" + "strings" "testing" + "time" ) func TestNewVerifier_Length(t *testing.T) { @@ -71,3 +76,218 @@ func TestBuildAuthURL_Shape(t *testing.T) { } } } + +// swapTokenEndpoint replaces the package-level tokenEndpoint var for the +// duration of the test, restoring it on cleanup. +func swapTokenEndpoint(t *testing.T, url string) { + t.Helper() + prev := tokenEndpoint + t.Cleanup(func() { tokenEndpoint = prev }) + tokenEndpoint = url +} + +func TestExchangeCode_PublicClient_PutsClientIDInBody(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method = %q, want POST", r.Method) + } + if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { + t.Errorf("Content-Type = %q, want %q", got, want) + } + if got := r.Header.Get("Authorization"); got != "" { + t.Errorf("Authorization header = %q, want empty (public client)", got) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm: %v", err) + } + wantParams := map[string]string{ + "client_id": "pub-client", + "grant_type": "authorization_code", + "code": "the-code", + "redirect_uri": CallbackURL, + "code_verifier": "the-verifier", + } + for k, want := range wantParams { + if got := r.PostForm.Get(k); got != want { + t.Errorf("PostForm[%q] = %q, want %q", k, got, want) + } + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"AT","refresh_token":"RT","expires_in":3600}`)) + })) + t.Cleanup(srv.Close) + swapTokenEndpoint(t, srv.URL) + + td, err := exchangeCode(context.Background(), "pub-client", "", "the-code", "the-verifier") + if err != nil { + t.Fatalf("exchangeCode returned error: %v", err) + } + if td.AccessToken != "AT" { + t.Errorf("AccessToken = %q, want %q", td.AccessToken, "AT") + } + if td.RefreshToken != "RT" { + t.Errorf("RefreshToken = %q, want %q", td.RefreshToken, "RT") + } + delta := time.Until(td.ExpiresAt) - time.Hour + if delta < -5*time.Second || delta > 5*time.Second { + t.Errorf("ExpiresAt offset from now = %v, want ~1h (±5s)", time.Until(td.ExpiresAt)) + } +} + +func TestExchangeCode_ConfidentialClient_UsesBasicAuth(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, pass, ok := r.BasicAuth() + if !ok { + t.Errorf("expected Basic auth header to be present") + } + if user != "conf-client" || pass != "secret" { + t.Errorf("BasicAuth = (%q, %q), want (%q, %q)", user, pass, "conf-client", "secret") + } + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm: %v", err) + } + if got := r.PostForm.Get("client_id"); got != "" { + t.Errorf("PostForm[client_id] = %q, want empty (confidential clients use Basic)", got) + } + if got, want := r.PostForm.Get("grant_type"), "authorization_code"; got != want { + t.Errorf("PostForm[grant_type] = %q, want %q", got, want) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"AT","refresh_token":"RT","expires_in":3600}`)) + })) + t.Cleanup(srv.Close) + swapTokenEndpoint(t, srv.URL) + + td, err := exchangeCode(context.Background(), "conf-client", "secret", "the-code", "the-verifier") + if err != nil { + t.Fatalf("exchangeCode returned error: %v", err) + } + if td.AccessToken != "AT" { + t.Errorf("AccessToken = %q, want %q", td.AccessToken, "AT") + } + if td.RefreshToken != "RT" { + t.Errorf("RefreshToken = %q, want %q", td.RefreshToken, "RT") + } +} + +func TestRefresh_RefreshesAndSaves(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm: %v", err) + } + if got, want := r.PostForm.Get("grant_type"), "refresh_token"; got != want { + t.Errorf("PostForm[grant_type] = %q, want %q", got, want) + } + if got, want := r.PostForm.Get("refresh_token"), "old-rt"; got != want { + t.Errorf("PostForm[refresh_token] = %q, want %q", got, want) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"new-AT","refresh_token":"new-RT","expires_in":600}`)) + })) + t.Cleanup(srv.Close) + swapTokenEndpoint(t, srv.URL) + + td, err := Refresh(context.Background(), "pub", "", "old-rt") + if err != nil { + t.Fatalf("Refresh returned error: %v", err) + } + if td.AccessToken != "new-AT" { + t.Errorf("AccessToken = %q, want %q", td.AccessToken, "new-AT") + } + if td.RefreshToken != "new-RT" { + t.Errorf("RefreshToken = %q, want %q", td.RefreshToken, "new-RT") + } + + loaded, err := LoadTokens() + if err != nil { + t.Fatalf("LoadTokens returned error: %v", err) + } + if loaded == nil { + t.Fatal("LoadTokens returned nil; expected SaveTokens to have persisted") + } + if loaded.AccessToken != td.AccessToken { + t.Errorf("loaded.AccessToken = %q, want %q", loaded.AccessToken, td.AccessToken) + } + if loaded.RefreshToken != td.RefreshToken { + t.Errorf("loaded.RefreshToken = %q, want %q", loaded.RefreshToken, td.RefreshToken) + } + if !loaded.ExpiresAt.Equal(td.ExpiresAt) { + t.Errorf("loaded.ExpiresAt = %v, want %v", loaded.ExpiresAt, td.ExpiresAt) + } +} + +func TestDoTokenRequest_ErrorPaths(t *testing.T) { + tests := []struct { + name string + handler http.HandlerFunc + useBadURL bool + wantSubstrs []string + }{ + { + name: "oauth_error_400", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"bad code"}`)) + }, + wantSubstrs: []string{"invalid_grant", "bad code"}, + }, + { + name: "non_json_500", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("Internal Server Error")) + }, + wantSubstrs: []string{"500", "Internal Server Error"}, + }, + { + name: "empty_body_200", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + // no body + }, + wantSubstrs: []string{"parsing"}, + }, + { + name: "network_error", + useBadURL: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.useBadURL { + prev := tokenEndpoint + t.Cleanup(func() { tokenEndpoint = prev }) + tokenEndpoint = "http://127.0.0.1:1" + } else { + srv := httptest.NewServer(tc.handler) + t.Cleanup(srv.Close) + swapTokenEndpoint(t, srv.URL) + } + + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("refresh_token", "x") + + td, err := doTokenRequest(context.Background(), "id", "", form) + if err == nil { + t.Fatalf("doTokenRequest returned nil error; td=%+v", td) + } + for _, sub := range tc.wantSubstrs { + if !strings.Contains(err.Error(), sub) { + t.Errorf("error %q does not contain %q", err.Error(), sub) + } + } + }) + } +} + diff --git a/fusion/mcp_test.go b/fusion/mcp_test.go index 2f13997..36d89d9 100644 --- a/fusion/mcp_test.go +++ b/fusion/mcp_test.go @@ -5,7 +5,11 @@ import ( "encoding/base64" "encoding/json" "strings" + "sync" "testing" + "time" + + "github.com/schneik80/FusionDataCLI/internal/testutil" ) func TestNormalizeProjectID(t *testing.T) { @@ -267,3 +271,284 @@ func TestInsertDocument_ValidatesInput(t *testing.T) { }) } } + +// ---------------------------------------------------------------------------- +// Phase 2 (L2) — httptest-mocked MCP server tests. +// +// These exercise the JSON-RPC + session-cache machinery in invoke / session / +// callTool against a faked MCP server (testutil.NewMCPServer). Each test +// constructs a fresh *Client so that cached session state from one test +// cannot leak into another. +// ---------------------------------------------------------------------------- + +// testCtx returns a 5-second-bounded context, freed via t.Cleanup. +func testCtx(t *testing.T) context.Context { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + return ctx +} + +func TestActiveHubProjects_Success(t *testing.T) { + var capturedArgs map[string]any + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-success", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": func(args map[string]any) testutil.MCPResponse { + capturedArgs = args + return testutil.MCPResponse{ + ContentText: `{"success": true, "projects": [{"id": "P1", "name": "Project One"}, {"id": "P2", "name": "Project Two"}]}`, + } + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + projects, err := client.ActiveHubProjects(testCtx(t)) + if err != nil { + t.Fatalf("ActiveHubProjects: unexpected error: %v", err) + } + if len(projects) != 2 { + t.Fatalf("ActiveHubProjects: got %d projects, want 2 (%+v)", len(projects), projects) + } + if projects[0].Name != "Project One" || projects[1].Name != "Project Two" { + t.Errorf("ActiveHubProjects: names = [%q,%q], want [\"Project One\",\"Project Two\"]", + projects[0].Name, projects[1].Name) + } + if got := srv.InitCount(); got != 1 { + t.Errorf("InitCount = %d, want 1", got) + } + if got := srv.CallCount("fusion_mcp_read"); got != 1 { + t.Errorf("CallCount(fusion_mcp_read) = %d, want 1", got) + } + if capturedArgs == nil { + t.Fatal("handler was not invoked (capturedArgs is nil)") + } + if got := capturedArgs["queryType"]; got != "projects" { + t.Errorf("args[queryType] = %v, want \"projects\"", got) + } +} + +func TestActiveHubProjects_SuccessFalse(t *testing.T) { + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-fail", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": func(args map[string]any) testutil.MCPResponse { + return testutil.MCPResponse{ + ContentText: `{"success": false, "projects": []}`, + } + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + _, err := client.ActiveHubProjects(testCtx(t)) + if err == nil { + t.Fatal("ActiveHubProjects: expected error for success:false, got nil") + } + // Subtlety: callTool's parseToolErrorText branch fires first on success:false + // payloads, surfacing them as "tool reported failure" before the per-method + // "projects query failed" branch in ActiveHubProjects ever runs. Either + // message indicates the failure was caught — assert on either to remain + // resilient to whichever guard short-circuits first. + msg := err.Error() + if !strings.Contains(msg, "projects query failed") && !strings.Contains(msg, "tool reported failure") { + t.Errorf("ActiveHubProjects: error %q does not contain \"projects query failed\" or \"tool reported failure\"", msg) + } +} + +func TestActiveHubProjects_SuccessFalse_WithErrorMsg(t *testing.T) { + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-autherr", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": func(args map[string]any) testutil.MCPResponse { + return testutil.MCPResponse{ + ContentText: `{"success": false, "error": "auth failed"}`, + } + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + _, err := client.ActiveHubProjects(testCtx(t)) + if err == nil { + t.Fatal("ActiveHubProjects: expected error for success:false with error, got nil") + } + if !strings.Contains(err.Error(), "auth failed") { + t.Errorf("ActiveHubProjects: error %q does not contain \"auth failed\"", err.Error()) + } +} + +func TestOpenDocument_Roundtrip(t *testing.T) { + var capturedArgs map[string]any + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-open", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_execute": func(args map[string]any) testutil.MCPResponse { + capturedArgs = args + // Plain non-JSON text — parseToolErrorText returns "" because + // it doesn't start with '{' or '['. Production treats as success. + return testutil.MCPResponse{ContentText: "opened"} + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + if err := client.OpenDocument(testCtx(t), "validfileid"); err != nil { + t.Fatalf("OpenDocument: unexpected error: %v", err) + } + if capturedArgs == nil { + t.Fatal("handler was not invoked (capturedArgs is nil)") + } + if got := capturedArgs["featureType"]; got != "document" { + t.Errorf("args[featureType] = %v, want \"document\"", got) + } + obj, ok := capturedArgs["object"].(map[string]any) + if !ok { + t.Fatalf("args[object] is %T, want map[string]any", capturedArgs["object"]) + } + if got := obj["operation"]; got != "open" { + t.Errorf("args.object[operation] = %v, want \"open\"", got) + } + if got := obj["fileId"]; got != "validfileid" { + t.Errorf("args.object[fileId] = %v, want \"validfileid\"", got) + } +} + +func TestInsertDocument_ScriptShape(t *testing.T) { + var capturedArgs map[string]any + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-insert", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_execute": func(args map[string]any) testutil.MCPResponse { + capturedArgs = args + return testutil.MCPResponse{ContentText: "inserted"} + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + const fileID = "urn:adsk.wipprod:dm.lineage:abc" + if err := client.InsertDocument(testCtx(t), fileID); err != nil { + t.Fatalf("InsertDocument: unexpected error: %v", err) + } + if capturedArgs == nil { + t.Fatal("handler was not invoked (capturedArgs is nil)") + } + if got := capturedArgs["featureType"]; got != "script" { + t.Errorf("args[featureType] = %v, want \"script\"", got) + } + obj, ok := capturedArgs["object"].(map[string]any) + if !ok { + t.Fatalf("args[object] is %T, want map[string]any", capturedArgs["object"]) + } + script, ok := obj["script"].(string) + if !ok { + t.Fatalf("args.object[script] is %T, want string", obj["script"]) + } + if !strings.Contains(script, "file_id") { + t.Errorf("script does not contain \"file_id\" variable name:\n%s", script) + } + // JSON-encoded fileId is the original URN wrapped in double quotes. + wantQuoted := `"` + fileID + `"` + if !strings.Contains(script, wantQuoted) { + t.Errorf("script does not contain JSON-encoded fileId %s:\n%s", wantQuoted, script) + } +} + +func TestInvoke_SessionRetryOn404(t *testing.T) { + var ( + mu sync.Mutex + calls int + ) + handler := func(args map[string]any) testutil.MCPResponse { + mu.Lock() + defer mu.Unlock() + calls++ + if calls == 1 { + return testutil.MCPResponse{SessionExpired: true} + } + return testutil.MCPResponse{ContentText: `{"success":true,"projects":[]}`} + } + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-1", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": handler, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + projects, err := client.ActiveHubProjects(testCtx(t)) + if err != nil { + t.Fatalf("ActiveHubProjects: unexpected error: %v", err) + } + if len(projects) != 0 { + t.Errorf("expected 0 projects after retry, got %d (%+v)", len(projects), projects) + } + if got := srv.InitCount(); got != 2 { + t.Errorf("InitCount = %d, want 2 (initial + re-handshake after 404)", got) + } + if got := srv.CallCount("fusion_mcp_read"); got != 2 { + t.Errorf("CallCount(fusion_mcp_read) = %d, want 2 (one 404 + one success)", got) + } + sids := srv.SessionIDsSeen() + if len(sids) != 2 { + t.Fatalf("SessionIDsSeen len = %d, want 2 (got %v)", len(sids), sids) + } + for i, s := range sids { + if s != "sid-1" { + t.Errorf("SessionIDsSeen[%d] = %q, want \"sid-1\"", i, s) + } + } +} + +func TestInvoke_SSEResponse(t *testing.T) { + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-sse", + SSEMode: true, + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": func(args map[string]any) testutil.MCPResponse { + return testutil.MCPResponse{ + ContentText: `{"success":true,"projects":[{"id":"P1","name":"Solo"}]}`, + } + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + projects, err := client.ActiveHubProjects(testCtx(t)) + if err != nil { + t.Fatalf("ActiveHubProjects (SSE): unexpected error: %v", err) + } + if len(projects) != 1 { + t.Fatalf("ActiveHubProjects (SSE): got %d projects, want 1 (%+v)", len(projects), projects) + } + if projects[0].Name != "Solo" { + t.Errorf("ActiveHubProjects (SSE): name = %q, want \"Solo\"", projects[0].Name) + } +} + +func TestSession_StatelessMode(t *testing.T) { + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "", // stateless: server emits no Mcp-Session-Id header on init + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": func(args map[string]any) testutil.MCPResponse { + return testutil.MCPResponse{ + ContentText: `{"success":true,"projects":[]}`, + } + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + if _, err := client.ActiveHubProjects(testCtx(t)); err != nil { + t.Fatalf("ActiveHubProjects (stateless): unexpected error: %v", err) + } + sids := srv.SessionIDsSeen() + if len(sids) != 1 { + t.Fatalf("SessionIDsSeen len = %d, want 1 (got %v)", len(sids), sids) + } + if sids[0] != "" { + t.Errorf("SessionIDsSeen[0] = %q, want \"\" (stateless: client must omit header)", sids[0]) + } +} diff --git a/internal/testutil/graphql.go b/internal/testutil/graphql.go new file mode 100644 index 0000000..594f798 --- /dev/null +++ b/internal/testutil/graphql.go @@ -0,0 +1,105 @@ +// Package testutil provides shared test helpers for spinning up in-process +// httptest.Server instances that emulate the upstream services FusionDataCLI +// talks to: the APS GraphQL endpoint and the Fusion desktop MCP JSON-RPC +// server. Tests in `auth`, `api`, and `fusion` packages all need to issue +// real HTTP requests against fakes; this package keeps that boilerplate in +// one place so per-package test files can stay focused on the behaviour +// under test. +package testutil + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +// GraphQLRequest is the decoded view of an APS GraphQL POST that the +// handler sees. AuthHeader is the raw Authorization header value (typically +// "Bearer "); Region is the X-Ads-Region header (empty if unset). +type GraphQLRequest struct { + Query string + Variables map[string]any + AuthHeader string + Region string +} + +// GraphQLResponse is what a GraphQLServer handler returns for a given +// request. Either Data (marshaled into the response's "data" field) or +// Errors (each becomes {"message": ...} in the response's "errors" array) +// or both may be set. Status defaults to 200; RawBody, if non-empty, is +// sent verbatim and Data/Errors/Status are ignored. +type GraphQLResponse struct { + Data any + Errors []string + Status int + RawBody string +} + +// GraphQLServer starts an httptest.Server that decodes APS GraphQL +// requests and feeds them to handler, replying with whatever +// GraphQLResponse the handler returns. The server is closed via +// t.Cleanup, so callers don't have to defer Close(). +// +// The fake doesn't enforce method or Content-Type — those are the +// caller's job to assert on the captured request when relevant. +func GraphQLServer(t *testing.T, handler func(req GraphQLRequest) GraphQLResponse) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("testutil: reading request body: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + var req struct { + Query string `json:"query"` + Variables map[string]any `json:"variables"` + } + if err := json.Unmarshal(body, &req); err != nil { + t.Errorf("testutil: decoding GraphQL request body %q: %v", body, err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + resp := handler(GraphQLRequest{ + Query: req.Query, + Variables: req.Variables, + AuthHeader: r.Header.Get("Authorization"), + Region: r.Header.Get("X-Ads-Region"), + }) + + if resp.Status == 0 { + resp.Status = http.StatusOK + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(resp.Status) + + if resp.RawBody != "" { + _, _ = io.WriteString(w, resp.RawBody) + return + } + + envelope := map[string]any{} + if resp.Data != nil { + envelope["data"] = resp.Data + } + if len(resp.Errors) > 0 { + errs := make([]map[string]string, len(resp.Errors)) + for i, m := range resp.Errors { + errs[i] = map[string]string{"message": m} + } + envelope["errors"] = errs + } + // Always emit a "data" key even when nil, so the client's empty-data + // branch is reachable: the production gqlQuery returns an error if + // the JSON has zero bytes in `data`. Encoding `nil` produces "null" + // which still has length, so handlers that want to trigger the + // "empty data" path should set Data to json.RawMessage("") explicitly + // (or use RawBody for that). + _ = json.NewEncoder(w).Encode(envelope) + })) + t.Cleanup(srv.Close) + return srv +} diff --git a/internal/testutil/mcp.go b/internal/testutil/mcp.go new file mode 100644 index 0000000..3b6349c --- /dev/null +++ b/internal/testutil/mcp.go @@ -0,0 +1,206 @@ +package testutil + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +// MCPHandler is invoked for each tools/call request the fake server +// receives. args is the decoded "arguments" map from the JSON-RPC params. +// The returned MCPResponse controls what the fake replies with (or whether +// it rejects the call as session-expired). +type MCPHandler func(args map[string]any) MCPResponse + +// MCPResponse is one tools/call reply. +// +// - ContentText: payload returned as result.content[0].text (the JSON +// string the production client further decodes via parseToolErrorText). +// - IsError: sets result.isError on the JSON-RPC response. +// - SessionExpired: when true, the server returns 404 (production retries +// with a fresh handshake). +// - Status: overrides the default 200 (rare; use for non-404 error +// responses like 500). +type MCPResponse struct { + ContentText string + IsError bool + SessionExpired bool + Status int +} + +// MCPScenario configures the fake MCP server. +// +// - SessionID: returned in the Mcp-Session-Id response header on +// initialize. Empty string emulates stateless mode (no header). +// - Tools: handlers indexed by tool name. A request for an unmapped tool +// fails the test. +// - SSEMode: when true, tools/call responses are wrapped as +// "data: \n\n" SSE events instead of plain JSON. +type MCPScenario struct { + SessionID string + Tools map[string]MCPHandler + SSEMode bool +} + +// MCPServer wraps an httptest.Server with per-tool call counts so tests +// can assert on retry / session-cache behaviour. +type MCPServer struct { + *httptest.Server + mu sync.Mutex + calls map[string]int + inits int + sidSeen []string // session IDs received on tools/call requests +} + +// InitCount reports how many times the initialize handshake ran. +func (s *MCPServer) InitCount() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.inits +} + +// CallCount reports how many tools/call requests targeted toolName. +func (s *MCPServer) CallCount(toolName string) int { + s.mu.Lock() + defer s.mu.Unlock() + return s.calls[toolName] +} + +// SessionIDsSeen returns the Mcp-Session-Id header values received on +// tools/call requests in arrival order. Useful for asserting that the +// production client sent the cached SID after init, and re-handshake +// happened after a session-expired response. +func (s *MCPServer) SessionIDsSeen() []string { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]string, len(s.sidSeen)) + copy(out, s.sidSeen) + return out +} + +// NewMCPServer starts a fake Fusion MCP JSON-RPC server scripted by +// scenario. The server handles `initialize`, the +// `notifications/initialized` notification (HTTP 204), and `tools/call` +// dispatched via scenario.Tools. Auto-closed via t.Cleanup. +func NewMCPServer(t *testing.T, scenario MCPScenario) *MCPServer { + t.Helper() + s := &MCPServer{calls: map[string]int{}} + s.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("testutil: reading MCP request body: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var req struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` + } + if err := json.Unmarshal(body, &req); err != nil { + t.Errorf("testutil: decoding MCP request %q: %v", body, err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + switch req.Method { + case "initialize": + s.mu.Lock() + s.inits++ + s.mu.Unlock() + if scenario.SessionID != "" { + w.Header().Set("Mcp-Session-Id", scenario.SessionID) + } + writeJSONRPC(w, req.ID, map[string]any{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]any{}, + "serverInfo": map[string]any{"name": "fake-mcp", "version": "0"}, + }) + + case "notifications/initialized": + w.WriteHeader(http.StatusNoContent) + + case "tools/call": + var p struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + } + if err := json.Unmarshal(req.Params, &p); err != nil { + t.Errorf("testutil: decoding tools/call params: %v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + s.mu.Lock() + s.calls[p.Name]++ + s.sidSeen = append(s.sidSeen, r.Header.Get("Mcp-Session-Id")) + s.mu.Unlock() + + handler, ok := scenario.Tools[p.Name] + if !ok { + t.Errorf("testutil: no scenario handler for tool %q", p.Name) + http.Error(w, "no handler", http.StatusInternalServerError) + return + } + resp := handler(p.Arguments) + if resp.SessionExpired { + http.Error(w, "session expired", http.StatusNotFound) + return + } + if resp.Status != 0 && resp.Status != http.StatusOK { + http.Error(w, "scripted non-OK", resp.Status) + return + } + + payload := map[string]any{ + "content": []map[string]any{{"type": "text", "text": resp.ContentText}}, + "isError": resp.IsError, + } + if scenario.SSEMode { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + var buf bytes.Buffer + _ = json.NewEncoder(&buf).Encode(jsonRPCResponse{ + JSONRPC: "2.0", ID: req.ID, Result: mustMarshal(payload), + }) + _, _ = io.WriteString(w, "data: "+buf.String()+"\n") + return + } + writeJSONRPC(w, req.ID, payload) + + default: + t.Errorf("testutil: unexpected MCP method %q", req.Method) + http.Error(w, "unexpected method", http.StatusBadRequest) + } + })) + t.Cleanup(s.Close) + return s +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Result json.RawMessage `json:"result,omitempty"` +} + +func writeJSONRPC(w http.ResponseWriter, id int, result any) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(jsonRPCResponse{ + JSONRPC: "2.0", + ID: id, + Result: mustMarshal(result), + }) +} + +func mustMarshal(v any) json.RawMessage { + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + return b +}