diff --git a/api/client.go b/api/client.go index f14c6c5..678f589 100644 --- a/api/client.go +++ b/api/client.go @@ -12,7 +12,9 @@ import ( "time" ) -const graphqlEndpoint = "https://developer.api.autodesk.com/mfg/graphql" +// graphqlEndpoint is a var (not const) so tests can point it at an +// httptest.Server. Production code never reassigns it. +var graphqlEndpoint = "https://developer.api.autodesk.com/mfg/graphql" // region is the X-Ads-Region header value sent with every request. // Empty means no header is sent (defaults to US on the server side). diff --git a/api/client_test.go b/api/client_test.go index a8e7b6f..0541818 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -1,6 +1,14 @@ package api -import "testing" +import ( + "context" + "encoding/json" + "strings" + "sync/atomic" + "testing" + + "github.com/schneik80/FusionDataCLI/internal/testutil" +) func TestSetRegion(t *testing.T) { orig := region @@ -26,3 +34,139 @@ func TestSetRegion(t *testing.T) { }) } } + +// swapEndpoint redirects the package-level graphqlEndpoint to url and +// schedules restoration via t.Cleanup. Tests use this to point the GraphQL +// client at an httptest.Server. +func swapEndpoint(t *testing.T, url string) { + t.Helper() + prev := graphqlEndpoint + t.Cleanup(func() { graphqlEndpoint = prev }) + graphqlEndpoint = url +} + +func TestGqlQuery_HappyPath(t *testing.T) { + var sawAuth, sawQuery, sawFoo bool + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + if req.AuthHeader == "Bearer test-token" { + sawAuth = true + } else { + t.Errorf("AuthHeader = %q, want %q", req.AuthHeader, "Bearer test-token") + } + if strings.Contains(req.Query, "Marker") { + sawQuery = true + } else { + t.Errorf("Query missing marker: %q", req.Query) + } + if v, ok := req.Variables["foo"].(string); ok && v == "bar" { + sawFoo = true + } else { + t.Errorf("Variables[foo] = %v, want \"bar\"", req.Variables["foo"]) + } + return testutil.GraphQLResponse{Data: map[string]any{"x": 1}} + }) + swapEndpoint(t, srv.URL) + + ctx := context.Background() + raw, err := gqlQuery(ctx, "test-token", "query Marker { hubs { id } }", map[string]any{"foo": "bar"}) + if err != nil { + t.Fatalf("gqlQuery returned error: %v", err) + } + + var got map[string]int + if err := json.Unmarshal(raw, &got); err != nil { + t.Fatalf("decoding raw data: %v (raw=%s)", err, raw) + } + if got["x"] != 1 { + t.Errorf("decoded data = %v, want {x:1}", got) + } + if !sawAuth || !sawQuery || !sawFoo { + t.Errorf("handler missed an assertion: auth=%v query=%v foo=%v", sawAuth, sawQuery, sawFoo) + } +} + +func TestGqlQuery_401_Wraps(t *testing.T) { + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + return testutil.GraphQLResponse{Status: 401} + }) + swapEndpoint(t, srv.URL) + + _, err := gqlQuery(context.Background(), "tok", "query Q {}", nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(strings.ToLower(err.Error()), "unauthorized") { + t.Errorf("error = %q, want substring \"unauthorized\"", err.Error()) + } +} + +func TestGqlQuery_GraphQLErrors_Joined(t *testing.T) { + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + return testutil.GraphQLResponse{Errors: []string{"first failure", "second failure"}} + }) + swapEndpoint(t, srv.URL) + + _, err := gqlQuery(context.Background(), "tok", "query Q {}", nil) + if err == nil { + t.Fatal("expected error, got nil") + } + msg := err.Error() + if !strings.Contains(msg, "first failure; second failure") { + t.Errorf("error = %q, want both messages joined by \"; \"", msg) + } +} + +func TestGqlQuery_EmptyData_Errors(t *testing.T) { + // A response with no "data" field at all leaves gr.Data as a zero-length + // json.RawMessage, which trips the production code's len(gr.Data) == 0 + // guard. (Strings like `""` decode to a 2-byte RawMessage and would + // pass that check.) + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + return testutil.GraphQLResponse{RawBody: `{}`} + }) + swapEndpoint(t, srv.URL) + + _, err := gqlQuery(context.Background(), "tok", "query Q {}", nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(strings.ToLower(err.Error()), "empty") { + t.Errorf("error = %q, want substring \"empty\"", err.Error()) + } +} + +func TestGqlQuery_RegionHeader(t *testing.T) { + // Region is shared global state — back up & restore around the whole test. + origRegion := region + t.Cleanup(func() { region = origRegion }) + + cases := []struct { + name string + setRegion string + wantRegion string + }{ + {name: "with_region", setRegion: "EMEA", wantRegion: "EMEA"}, + {name: "without_region", setRegion: "", wantRegion: ""}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var seen atomic.Value // string + seen.Store("") + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + seen.Store(req.Region) + return testutil.GraphQLResponse{Data: map[string]any{"ok": true}} + }) + swapEndpoint(t, srv.URL) + + SetRegion(tc.setRegion) + + if _, err := gqlQuery(context.Background(), "tok", "query Q {}", nil); err != nil { + t.Fatalf("gqlQuery: %v", err) + } + if got := seen.Load().(string); got != tc.wantRegion { + t.Errorf("X-Ads-Region = %q, want %q", got, tc.wantRegion) + } + }) + } +} diff --git a/api/details_test.go b/api/details_test.go index 9f7c078..7e056ca 100644 --- a/api/details_test.go +++ b/api/details_test.go @@ -1,8 +1,13 @@ package api import ( + "context" + "encoding/json" + "os" "testing" "time" + + "github.com/schneik80/FusionDataCLI/internal/testutil" ) func TestParseTime(t *testing.T) { @@ -64,6 +69,162 @@ func TestParseTime(t *testing.T) { } } +func TestGetItemDetails_AllFields(t *testing.T) { + raw, err := os.ReadFile("testdata/details_design.json") + if err != nil { + t.Fatalf("reading fixture: %v", err) + } + var data map[string]any + if err := json.Unmarshal(raw, &data); err != nil { + t.Fatalf("unmarshaling fixture: %v", err) + } + + var sawHubID, sawItemID bool + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + if v, ok := req.Variables["hubId"].(string); ok && v == "h1" { + sawHubID = true + } else { + t.Errorf("Variables[hubId] = %v, want \"h1\"", req.Variables["hubId"]) + } + if v, ok := req.Variables["itemId"].(string); ok && v == "item-1" { + sawItemID = true + } else { + t.Errorf("Variables[itemId] = %v, want \"item-1\"", req.Variables["itemId"]) + } + return testutil.GraphQLResponse{Data: data} + }) + swapEndpoint(t, srv.URL) + + got, err := GetItemDetails(context.Background(), "tok", "h1", "item-1") + if err != nil { + t.Fatalf("GetItemDetails: %v", err) + } + if !sawHubID || !sawItemID { + t.Errorf("handler missed an assertion: hubId=%v itemId=%v", sawHubID, sawItemID) + } + + if got.ID != "urn:item:abc" { + t.Errorf("ID = %q, want %q", got.ID, "urn:item:abc") + } + if got.Name != "Widget A" { + t.Errorf("Name = %q, want %q", got.Name, "Widget A") + } + if got.Typename != "DesignItem" { + t.Errorf("Typename = %q, want %q", got.Typename, "DesignItem") + } + if got.Size != "12345678" { + t.Errorf("Size = %q, want %q", got.Size, "12345678") + } + if got.MimeType != "application/vnd.autodesk.fusion360" { + t.Errorf("MimeType = %q, want %q", got.MimeType, "application/vnd.autodesk.fusion360") + } + if got.ExtensionType != "Fusion360" { + t.Errorf("ExtensionType = %q, want %q", got.ExtensionType, "Fusion360") + } + if got.FusionWebURL != "https://fusion.example/widget-a" { + t.Errorf("FusionWebURL = %q, want %q", got.FusionWebURL, "https://fusion.example/widget-a") + } + if got.VersionNumber != 3 { + t.Errorf("VersionNumber = %d, want 3", got.VersionNumber) + } + + wantCreated := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC) + if !got.CreatedOn.Equal(wantCreated) { + t.Errorf("CreatedOn = %v, want %v", got.CreatedOn, wantCreated) + } + wantModified := time.Date(2024, 2, 20, 14, 0, 0, 0, time.UTC) + if !got.ModifiedOn.Equal(wantModified) { + t.Errorf("ModifiedOn = %v, want %v", got.ModifiedOn, wantModified) + } + + if got.CreatedBy != "Ada Lovelace" { + t.Errorf("CreatedBy = %q, want %q", got.CreatedBy, "Ada Lovelace") + } + if got.ModifiedBy != "Grace Hopper" { + t.Errorf("ModifiedBy = %q, want %q", got.ModifiedBy, "Grace Hopper") + } + if got.PartNumber != "WGT-001" { + t.Errorf("PartNumber = %q, want %q", got.PartNumber, "WGT-001") + } + if got.PartDesc != "The widget A" { + t.Errorf("PartDesc = %q, want %q", got.PartDesc, "The widget A") + } + if got.Material != "Aluminum 6061" { + t.Errorf("Material = %q, want %q", got.Material, "Aluminum 6061") + } + if !got.IsMilestone { + t.Errorf("IsMilestone = false, want true") + } + if got.RootComponentVersionID != "urn:cv:xyz" { + t.Errorf("RootComponentVersionID = %q, want %q", got.RootComponentVersionID, "urn:cv:xyz") + } + + if len(got.Versions) != 3 { + t.Fatalf("len(Versions) = %d, want 3", len(got.Versions)) + } + // Reversed: most-recent first. + if got.Versions[0].Number != 3 { + t.Errorf("Versions[0].Number = %d, want 3", got.Versions[0].Number) + } + if got.Versions[1].Number != 2 { + t.Errorf("Versions[1].Number = %d, want 2", got.Versions[1].Number) + } + if got.Versions[2].Number != 1 { + t.Errorf("Versions[2].Number = %d, want 1", got.Versions[2].Number) + } + if got.Versions[0].Comment != "third edit" { + t.Errorf("Versions[0].Comment = %q, want %q", got.Versions[0].Comment, "third edit") + } + if got.Versions[0].CreatedBy != "Grace Hopper" { + t.Errorf("Versions[0].CreatedBy = %q, want %q", got.Versions[0].CreatedBy, "Grace Hopper") + } +} + +func TestGetItemDetails_DrawingItem_NoComponentVersion(t *testing.T) { + data := map[string]any{ + "item": map[string]any{ + "__typename": "DrawingItem", + "id": "urn:item:dwg", + "name": "Sheet 1", + "size": "0", + "mimeType": "application/dwg", + "extensionType": "DrawingItem", + "createdOn": "2024-03-01T09:00:00Z", + "createdBy": map[string]any{"firstName": "X", "lastName": "Y"}, + "lastModifiedOn": "2024-03-02T09:00:00Z", + "lastModifiedBy": map[string]any{"firstName": "X", "lastName": "Y"}, + "fusionWebUrl": "https://example/dwg", + "tipVersion": map[string]any{"versionNumber": 1}, + }, + "itemVersions": map[string]any{"results": []any{}}, + } + + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + return testutil.GraphQLResponse{Data: data} + }) + swapEndpoint(t, srv.URL) + + got, err := GetItemDetails(context.Background(), "tok", "h1", "item-dwg") + if err != nil { + t.Fatalf("GetItemDetails: %v", err) + } + if got.Typename != "DrawingItem" { + t.Errorf("Typename = %q, want %q", got.Typename, "DrawingItem") + } + if got.RootComponentVersionID != "" { + t.Errorf("RootComponentVersionID = %q, want empty", got.RootComponentVersionID) + } + if got.PartNumber != "" { + t.Errorf("PartNumber = %q, want empty", got.PartNumber) + } + if got.Material != "" { + t.Errorf("Material = %q, want empty", got.Material) + } + if got.IsMilestone { + t.Errorf("IsMilestone = true, want false") + } +} + func TestApiUser_FullName(t *testing.T) { cases := []struct { name string diff --git a/api/download.go b/api/download.go index 64a5a77..6dcde11 100644 --- a/api/download.go +++ b/api/download.go @@ -99,6 +99,14 @@ func DownloadFile(ctx context.Context, url, destPath string) error { return nil } +// userHomeDir and nowFunc are package vars so tests can inject a temp +// directory and a fixed time for deterministic StepDownloadPath output. +// Production code uses the stdlib defaults. +var ( + userHomeDir = os.UserHomeDir + nowFunc = time.Now +) + // StepDownloadPath returns a sensible local destination for a STEP file // derived from name. Prefers ~/Downloads, falling back to the OS temp dir // if the home directory cannot be determined. A timestamp suffix avoids @@ -108,8 +116,8 @@ func StepDownloadPath(name string) string { if safe == "" { safe = "design" } - fname := fmt.Sprintf("%s-%s.stp", safe, time.Now().Format("20060102-150405")) - if home, err := os.UserHomeDir(); err == nil && home != "" { + fname := fmt.Sprintf("%s-%s.stp", safe, nowFunc().Format("20060102-150405")) + if home, err := userHomeDir(); err == nil && home != "" { return filepath.Join(home, "Downloads", fname) } return filepath.Join(os.TempDir(), fname) diff --git a/api/download_test.go b/api/download_test.go index 4b20b29..2888f2b 100644 --- a/api/download_test.go +++ b/api/download_test.go @@ -1,6 +1,19 @@ package api -import "testing" +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/schneik80/FusionDataCLI/internal/testutil" +) func TestSanitizeFilename(t *testing.T) { cases := []struct { @@ -59,3 +72,193 @@ func TestSanitizeFilename(t *testing.T) { }) } } + +func TestRequestSTEPDerivative_Statuses(t *testing.T) { + cases := []struct { + name string + status string + want string + }{ + {name: "pending", status: "PENDING", want: "PENDING"}, + {name: "success", status: "SUCCESS", want: "SUCCESS"}, + {name: "failed", status: "FAILED", want: "FAILED"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var sawCV bool + status := tc.status + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + if v, ok := req.Variables["componentVersionId"].(string); ok && v == "cv-1" { + sawCV = true + } else { + t.Errorf("Variables[componentVersionId] = %v, want \"cv-1\"", req.Variables["componentVersionId"]) + } + return testutil.GraphQLResponse{Data: map[string]any{ + "componentVersion": map[string]any{ + "derivatives": []any{ + map[string]any{ + "status": status, + "signedUrl": "https://signed.example/file.stp", + "outputFormat": "STEP", + }, + }, + }, + }} + }) + swapEndpoint(t, srv.URL) + + gotStatus, gotURL, err := RequestSTEPDerivative(context.Background(), "tok", "cv-1") + if err != nil { + t.Fatalf("RequestSTEPDerivative: %v", err) + } + if !sawCV { + t.Errorf("handler did not see componentVersionId variable") + } + if gotStatus != tc.want { + t.Errorf("status = %q, want %q", gotStatus, tc.want) + } + if gotURL != "https://signed.example/file.stp" { + t.Errorf("signedURL = %q, want %q", gotURL, "https://signed.example/file.stp") + } + }) + } +} + +func TestRequestSTEPDerivative_NoDerivative(t *testing.T) { + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + return testutil.GraphQLResponse{Data: map[string]any{ + "componentVersion": map[string]any{ + "derivatives": []any{}, + }, + }} + }) + swapEndpoint(t, srv.URL) + + _, _, err := RequestSTEPDerivative(context.Background(), "tok", "cv-1") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "no STEP derivative") { + t.Errorf("error = %q, want substring \"no STEP derivative\"", err.Error()) + } +} + +func TestDownloadFile_Streams(t *testing.T) { + body := bytes.Repeat([]byte("x"), 100*1024) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // H2 regression guardrail — APS signed URLs are self-authenticated; + // sending the bearer token would leak credentials to any host the + // signed URL points at. + if got := r.Header.Get("Authorization"); got != "" { + t.Errorf("Authorization header leaked to download endpoint: %q (must be empty — H2 regression!)", got) + } + w.Header().Set("Content-Type", "application/octet-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(body) + })) + t.Cleanup(srv.Close) + + dest := filepath.Join(t.TempDir(), "out.stp") + if err := DownloadFile(context.Background(), srv.URL+"/some/file.stp", dest); err != nil { + t.Fatalf("DownloadFile: %v", err) + } + + got, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("reading destination: %v", err) + } + if !bytes.Equal(got, body) { + t.Errorf("downloaded bytes (%d) do not match served body (%d)", len(got), len(body)) + } +} + +func TestDownloadFile_Non2xx(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte("forbidden by policy")) + })) + t.Cleanup(srv.Close) + + dest := filepath.Join(t.TempDir(), "should-not-exist.stp") + err := DownloadFile(context.Background(), srv.URL+"/blocked", dest) + if err == nil { + t.Fatal("expected error, got nil") + } + msg := err.Error() + if !strings.Contains(msg, "403") { + t.Errorf("error = %q, want substring \"403\"", msg) + } + if !strings.Contains(msg, "forbidden by policy") { + t.Errorf("error = %q, want substring \"forbidden by policy\"", msg) + } + if _, statErr := os.Stat(dest); !os.IsNotExist(statErr) { + t.Errorf("destination file should not exist; stat err = %v", statErr) + } +} + +func TestStepDownloadPath_Sanitizes(t *testing.T) { + prevHome, prevNow := userHomeDir, nowFunc + t.Cleanup(func() { userHomeDir, nowFunc = prevHome, prevNow }) + + fixedNow := time.Date(2030, 1, 2, 15, 4, 5, 0, time.UTC) + nowFunc = func() time.Time { return fixedNow } + + cases := []struct { + name string + // homeFn is the userHomeDir override for this case. + homeFn func() (string, error) + input string + // wantPath, when non-empty, is the exact expected output. + wantPath string + // fallbackTemp marks the home-error case; we only assert prefix/suffix. + fallbackTemp bool + }{ + { + name: "plain name", + homeFn: func() (string, error) { return "/home/test", nil }, + input: "My Design", + wantPath: filepath.Join("/home/test", "Downloads", "My Design-20300102-150405.stp"), + }, + { + name: "slashes replaced", + homeFn: func() (string, error) { return "/home/test", nil }, + input: "design/with/slashes", + wantPath: filepath.Join("/home/test", "Downloads", "design_with_slashes-20300102-150405.stp"), + }, + { + name: "empty falls back to design", + homeFn: func() (string, error) { return "/home/test", nil }, + input: "", + wantPath: filepath.Join("/home/test", "Downloads", "design-20300102-150405.stp"), + }, + { + name: "home error falls back to TempDir", + homeFn: func() (string, error) { return "", errors.New("no home") }, + input: "Widget", + fallbackTemp: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + userHomeDir = tc.homeFn + got := StepDownloadPath(tc.input) + if tc.fallbackTemp { + wantPrefix := os.TempDir() + string(os.PathSeparator) + wantSuffix := "-20300102-150405.stp" + if !strings.HasPrefix(got, wantPrefix) { + t.Errorf("StepDownloadPath(%q) = %q, want prefix %q", tc.input, got, wantPrefix) + } + if !strings.HasSuffix(got, wantSuffix) { + t.Errorf("StepDownloadPath(%q) = %q, want suffix %q", tc.input, got, wantSuffix) + } + return + } + if got != tc.wantPath { + t.Errorf("StepDownloadPath(%q) = %q, want %q", tc.input, got, tc.wantPath) + } + }) + } +} diff --git a/api/queries_test.go b/api/queries_test.go index 7b3e330..155d1a8 100644 --- a/api/queries_test.go +++ b/api/queries_test.go @@ -1,6 +1,12 @@ package api -import "testing" +import ( + "context" + "sync/atomic" + "testing" + + "github.com/schneik80/FusionDataCLI/internal/testutil" +) func TestNavItemFromTypename(t *testing.T) { cases := []struct { @@ -40,3 +46,220 @@ func TestNavItemFromTypename(t *testing.T) { }) } } + +func TestGetHubs_Pagination(t *testing.T) { + var calls atomic.Int32 + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + n := calls.Add(1) + switch n { + case 1: + if _, ok := req.Variables["cursor"]; ok { + t.Errorf("first call should not include cursor variable, got %v", req.Variables) + } + return testutil.GraphQLResponse{Data: map[string]any{ + "hubs": map[string]any{ + "pagination": map[string]any{"cursor": "PAGE2"}, + "results": []map[string]any{ + { + "id": "h1", + "name": "Hub1", + "fusionWebUrl": "https://example/h1", + "alternativeIdentifiers": map[string]any{ + "dataManagementAPIHubId": "ah1", + }, + }, + { + "id": "h2", + "name": "Hub2", + "fusionWebUrl": "https://example/h2", + "alternativeIdentifiers": map[string]any{ + "dataManagementAPIHubId": "ah2", + }, + }, + }, + }, + }} + case 2: + if got, _ := req.Variables["cursor"].(string); got != "PAGE2" { + t.Errorf("second call cursor = %v, want \"PAGE2\"", req.Variables["cursor"]) + } + return testutil.GraphQLResponse{Data: map[string]any{ + "hubs": map[string]any{ + "pagination": map[string]any{"cursor": ""}, + "results": []map[string]any{ + { + "id": "h3", + "name": "Hub3", + "fusionWebUrl": "https://example/h3", + "alternativeIdentifiers": map[string]any{ + "dataManagementAPIHubId": "ah3", + }, + }, + }, + }, + }} + default: + t.Errorf("unexpected extra call #%d", n) + return testutil.GraphQLResponse{Data: map[string]any{ + "hubs": map[string]any{ + "pagination": map[string]any{"cursor": ""}, + "results": []map[string]any{}, + }, + }} + } + }) + swapEndpoint(t, srv.URL) + + got, err := GetHubs(context.Background(), "tok") + if err != nil { + t.Fatalf("GetHubs: %v", err) + } + if got, want := calls.Load(), int32(2); got != want { + t.Errorf("call count = %d, want %d", got, want) + } + + wantIDs := []string{"h1", "h2", "h3"} + if len(got) != len(wantIDs) { + t.Fatalf("len = %d, want %d (items=%+v)", len(got), len(wantIDs), got) + } + for i, want := range wantIDs { + if got[i].ID != want { + t.Errorf("hubs[%d].ID = %q, want %q", i, got[i].ID, want) + } + if got[i].Kind != "hub" { + t.Errorf("hubs[%d].Kind = %q, want \"hub\"", i, got[i].Kind) + } + if !got[i].IsContainer { + t.Errorf("hubs[%d].IsContainer = false, want true", i) + } + } +} + +func TestGetProjects_FiltersInactive(t *testing.T) { + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + if got, _ := req.Variables["hubId"].(string); got != "h1" { + t.Errorf("hubId = %v, want \"h1\"", req.Variables["hubId"]) + } + return testutil.GraphQLResponse{Data: map[string]any{ + "hub": map[string]any{ + "projects": map[string]any{ + "pagination": map[string]any{"cursor": ""}, + "results": []map[string]any{ + { + "id": "p-active-lower", + "name": "ActiveLower", + "projectStatus": "active", + "projectType": "FUSION", + "alternativeIdentifiers": map[string]any{ + "dataManagementAPIProjectId": "ap1", + }, + }, + { + "id": "p-inactive-upper", + "name": "InactiveUpper", + "projectStatus": "INACTIVE", + "projectType": "FUSION", + "alternativeIdentifiers": map[string]any{ + "dataManagementAPIProjectId": "ap2", + }, + }, + { + "id": "p-inactive-mixed", + "name": "InactiveMixed", + "projectStatus": "Inactive", + "projectType": "FUSION", + "alternativeIdentifiers": map[string]any{ + "dataManagementAPIProjectId": "ap3", + }, + }, + { + "id": "p-active-cap", + "name": "ActiveCap", + "projectStatus": "Active", + "projectType": "FUSION", + "alternativeIdentifiers": map[string]any{ + "dataManagementAPIProjectId": "ap4", + }, + }, + }, + }, + }, + }} + }) + swapEndpoint(t, srv.URL) + + got, err := GetProjects(context.Background(), "tok", "h1") + if err != nil { + t.Fatalf("GetProjects: %v", err) + } + wantNames := []string{"ActiveLower", "ActiveCap"} + if len(got) != len(wantNames) { + t.Fatalf("len = %d, want %d (items=%+v)", len(got), len(wantNames), got) + } + for i, want := range wantNames { + if got[i].Name != want { + t.Errorf("projects[%d].Name = %q, want %q", i, got[i].Name, want) + } + if got[i].Kind != "project" { + t.Errorf("projects[%d].Kind = %q, want \"project\"", i, got[i].Kind) + } + if !got[i].IsContainer { + t.Errorf("projects[%d].IsContainer = false, want true", i) + } + } +} + +func TestGetItems_TypenameMapping(t *testing.T) { + srv := testutil.GraphQLServer(t, func(req testutil.GraphQLRequest) testutil.GraphQLResponse { + if got, _ := req.Variables["hubId"].(string); got != "h1" { + t.Errorf("hubId = %v, want \"h1\"", req.Variables["hubId"]) + } + if got, _ := req.Variables["folderId"].(string); got != "f1" { + t.Errorf("folderId = %v, want \"f1\"", req.Variables["folderId"]) + } + return testutil.GraphQLResponse{Data: map[string]any{ + "itemsByFolder": map[string]any{ + "pagination": map[string]any{"cursor": ""}, + "results": []map[string]any{ + {"__typename": "DesignItem", "id": "i1", "name": "Design"}, + {"__typename": "ConfiguredDesignItem", "id": "i2", "name": "Configured"}, + {"__typename": "DrawingItem", "id": "i3", "name": "Drawing"}, + {"__typename": "Folder", "id": "i4", "name": "SubFolder"}, + {"__typename": "MysteryItem", "id": "i5", "name": "Mystery"}, + }, + }, + }} + }) + swapEndpoint(t, srv.URL) + + got, err := GetItems(context.Background(), "tok", "h1", "f1") + if err != nil { + t.Fatalf("GetItems: %v", err) + } + + want := []struct { + id string + kind string + isContainer bool + }{ + {"i1", "design", false}, + {"i2", "configured", false}, + {"i3", "drawing", false}, + {"i4", "folder", true}, + {"i5", "unknown", false}, + } + if len(got) != len(want) { + t.Fatalf("len = %d, want %d (items=%+v)", len(got), len(want), got) + } + for i, w := range want { + if got[i].ID != w.id { + t.Errorf("items[%d].ID = %q, want %q", i, got[i].ID, w.id) + } + if got[i].Kind != w.kind { + t.Errorf("items[%d].Kind = %q, want %q", i, got[i].Kind, w.kind) + } + if got[i].IsContainer != w.isContainer { + t.Errorf("items[%d].IsContainer = %v, want %v", i, got[i].IsContainer, w.isContainer) + } + } +} diff --git a/api/testdata/details_design.json b/api/testdata/details_design.json new file mode 100644 index 0000000..5c8e90d --- /dev/null +++ b/api/testdata/details_design.json @@ -0,0 +1,30 @@ +{ + "item": { + "__typename": "DesignItem", + "id": "urn:item:abc", + "name": "Widget A", + "size": "12345678", + "mimeType": "application/vnd.autodesk.fusion360", + "extensionType": "Fusion360", + "createdOn": "2024-01-15T10:30:45Z", + "createdBy": {"firstName": "Ada", "lastName": "Lovelace"}, + "lastModifiedOn": "2024-02-20T14:00:00Z", + "lastModifiedBy": {"firstName": "Grace", "lastName": "Hopper"}, + "fusionWebUrl": "https://fusion.example/widget-a", + "tipVersion": {"versionNumber": 3}, + "tipRootComponentVersion": { + "id": "urn:cv:xyz", + "partNumber": "WGT-001", + "partDescription": "The widget A", + "materialName": "Aluminum 6061", + "isMilestone": true + } + }, + "itemVersions": { + "results": [ + {"versionNumber": 1, "name": "first", "createdOn": "2024-01-15T10:30:45Z", "createdBy": {"firstName": "Ada", "lastName": "Lovelace"}}, + {"versionNumber": 2, "name": "second save", "createdOn": "2024-01-20T11:00:00Z", "createdBy": {"firstName": "Ada", "lastName": "Lovelace"}}, + {"versionNumber": 3, "name": "third edit", "createdOn": "2024-02-20T14:00:00Z", "createdBy": {"firstName": "Grace", "lastName": "Hopper"}} + ] + } +} diff --git a/auth/callback.go b/auth/callback.go index dc60e8f..16e84cf 100644 --- a/auth/callback.go +++ b/auth/callback.go @@ -9,7 +9,10 @@ import ( "time" ) -const ( +// callbackPort and CallbackURL are vars (not consts) so tests can bind +// the listener on an ephemeral port (`:0`) and rewrite CallbackURL to the +// resolved address. Production code never reassigns them. +var ( callbackPort = 7879 // CallbackURL is the redirect URI that must be registered in your APS app settings. CallbackURL = "http://localhost:7879/callback" diff --git a/auth/callback_test.go b/auth/callback_test.go new file mode 100644 index 0000000..fe38861 --- /dev/null +++ b/auth/callback_test.go @@ -0,0 +1,221 @@ +package auth + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "testing" + "time" +) + +// freePort binds an ephemeral listener on 127.0.0.1, reads the chosen port, +// and closes the listener. There's a brief race window between Close and the +// callback server's rebind, but for a single-process test it's fine. +func freePort(t *testing.T) int { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen: %v", err) + } + port := ln.Addr().(*net.TCPAddr).Port + if err := ln.Close(); err != nil { + t.Fatalf("ln.Close: %v", err) + } + return port +} + +// waitForServer polls 127.0.0.1: with a short interval until a TCP +// connection succeeds, or fails the test after ~1s. +func waitForServer(t *testing.T, port int) { + t.Helper() + for i := 0; i < 100; i++ { + c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err == nil { + _ = c.Close() + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("callback server didn't come up within 1s") +} + +// useFreeCallbackPort swaps callbackPort and CallbackURL to a freshly chosen +// ephemeral port, restoring both on cleanup. Returns the chosen port. +func useFreeCallbackPort(t *testing.T) int { + t.Helper() + port := freePort(t) + + prevPort := callbackPort + prevURL := CallbackURL + t.Cleanup(func() { + callbackPort = prevPort + CallbackURL = prevURL + }) + callbackPort = port + CallbackURL = fmt.Sprintf("http://127.0.0.1:%d/callback", port) + return port +} + +func TestWaitForCallback_HappyPath(t *testing.T) { + port := useFreeCallbackPort(t) + + type result struct { + code string + err error + } + resCh := make(chan result, 1) + go func() { + code, err := WaitForCallback(context.Background()) + resCh <- result{code: code, err: err} + }() + + waitForServer(t, port) + + resp, err := http.Get(CallbackURL + "?code=abc") + if err != nil { + t.Fatalf("http.Get: %v", err) + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + t.Fatalf("read body: %v", err) + } + if !strings.Contains(string(body), "Authentication successful") { + t.Errorf("body = %q, want substring %q", string(body), "Authentication successful") + } + + select { + case r := <-resCh: + if r.err != nil { + t.Errorf("WaitForCallback returned error: %v", r.err) + } + if r.code != "abc" { + t.Errorf("code = %q, want %q", r.code, "abc") + } + case <-time.After(2 * time.Second): + t.Fatal("WaitForCallback did not return within 2s") + } +} + +func TestWaitForCallback_OAuthError(t *testing.T) { + port := useFreeCallbackPort(t) + + type result struct { + code string + err error + } + resCh := make(chan result, 1) + go func() { + code, err := WaitForCallback(context.Background()) + resCh <- result{code: code, err: err} + }() + + waitForServer(t, port) + + // url.QueryEscape("") — spelled out so the + // security intent is obvious in the source. + const encodedDesc = "%3Cscript%3Ealert%281%29%3C%2Fscript%3E" + resp, err := http.Get(CallbackURL + "?error=access_denied&error_description=" + encodedDesc) + if err != nil { + t.Fatalf("http.Get: %v", err) + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + t.Fatalf("read body: %v", err) + } + + bodyStr := string(body) + if !strings.Contains(bodyStr, "<script>") { + t.Errorf("body does not contain HTML-escaped %q; body=%q", "<script>", bodyStr) + } + if strings.Contains(bodyStr, "") { + t.Errorf("body contains unescaped script tag (XSS regression!); body=%q", bodyStr) + } + + select { + case r := <-resCh: + if r.err == nil { + t.Fatalf("WaitForCallback returned nil error; code=%q", r.code) + } + if !strings.Contains(r.err.Error(), "access_denied") { + t.Errorf("error %q does not contain %q", r.err.Error(), "access_denied") + } + if !strings.Contains(r.err.Error(), "") { + // The error message itself isn't HTML-rendered, so the raw + // description is fine — what matters is the HTTP body escaped it. + t.Errorf("error %q does not contain raw description %q", r.err.Error(), "") + } + case <-time.After(2 * time.Second): + t.Fatal("WaitForCallback did not return within 2s") + } +} + +func TestWaitForCallback_NoCode(t *testing.T) { + port := useFreeCallbackPort(t) + + type result struct { + code string + err error + } + resCh := make(chan result, 1) + go func() { + code, err := WaitForCallback(context.Background()) + resCh <- result{code: code, err: err} + }() + + waitForServer(t, port) + + resp, err := http.Get(CallbackURL + "?") + if err != nil { + t.Fatalf("http.Get: %v", err) + } + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + select { + case r := <-resCh: + if r.err == nil { + t.Fatalf("WaitForCallback returned nil error; code=%q", r.code) + } + if !strings.Contains(r.err.Error(), "no authorization code") { + t.Errorf("error %q does not contain %q", r.err.Error(), "no authorization code") + } + case <-time.After(2 * time.Second): + t.Fatal("WaitForCallback did not return within 2s") + } +} + +func TestWaitForCallback_ContextCancel(t *testing.T) { + port := useFreeCallbackPort(t) + + ctx, cancel := context.WithCancel(context.Background()) + type result struct { + code string + err error + } + resCh := make(chan result, 1) + go func() { + code, err := WaitForCallback(ctx) + resCh <- result{code: code, err: err} + }() + + waitForServer(t, port) + cancel() + + select { + case r := <-resCh: + if r.err == nil { + t.Fatalf("WaitForCallback returned nil error after cancel; code=%q", r.code) + } + if !errors.Is(r.err, context.Canceled) { + t.Errorf("error = %v, want wraps context.Canceled", r.err) + } + case <-time.After(2 * time.Second): + t.Fatal("WaitForCallback did not return within 2s after cancel") + } +} diff --git a/auth/oauth.go b/auth/oauth.go index 836a101..1d4c58c 100644 --- a/auth/oauth.go +++ b/auth/oauth.go @@ -16,7 +16,9 @@ import ( "time" ) -const ( +// Endpoints are vars (not consts) so tests can swap them for an +// httptest.Server URL. Production code never reassigns them. +var ( authEndpoint = "https://developer.api.autodesk.com/authentication/v2/authorize" tokenEndpoint = "https://developer.api.autodesk.com/authentication/v2/token" authScope = "data:read user-profile:read" diff --git a/auth/oauth_test.go b/auth/oauth_test.go index 0409329..4c46425 100644 --- a/auth/oauth_test.go +++ b/auth/oauth_test.go @@ -1,8 +1,13 @@ package auth import ( + "context" + "net/http" + "net/http/httptest" "net/url" + "strings" "testing" + "time" ) func TestNewVerifier_Length(t *testing.T) { @@ -71,3 +76,218 @@ func TestBuildAuthURL_Shape(t *testing.T) { } } } + +// swapTokenEndpoint replaces the package-level tokenEndpoint var for the +// duration of the test, restoring it on cleanup. +func swapTokenEndpoint(t *testing.T, url string) { + t.Helper() + prev := tokenEndpoint + t.Cleanup(func() { tokenEndpoint = prev }) + tokenEndpoint = url +} + +func TestExchangeCode_PublicClient_PutsClientIDInBody(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("method = %q, want POST", r.Method) + } + if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { + t.Errorf("Content-Type = %q, want %q", got, want) + } + if got := r.Header.Get("Authorization"); got != "" { + t.Errorf("Authorization header = %q, want empty (public client)", got) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm: %v", err) + } + wantParams := map[string]string{ + "client_id": "pub-client", + "grant_type": "authorization_code", + "code": "the-code", + "redirect_uri": CallbackURL, + "code_verifier": "the-verifier", + } + for k, want := range wantParams { + if got := r.PostForm.Get(k); got != want { + t.Errorf("PostForm[%q] = %q, want %q", k, got, want) + } + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"AT","refresh_token":"RT","expires_in":3600}`)) + })) + t.Cleanup(srv.Close) + swapTokenEndpoint(t, srv.URL) + + td, err := exchangeCode(context.Background(), "pub-client", "", "the-code", "the-verifier") + if err != nil { + t.Fatalf("exchangeCode returned error: %v", err) + } + if td.AccessToken != "AT" { + t.Errorf("AccessToken = %q, want %q", td.AccessToken, "AT") + } + if td.RefreshToken != "RT" { + t.Errorf("RefreshToken = %q, want %q", td.RefreshToken, "RT") + } + delta := time.Until(td.ExpiresAt) - time.Hour + if delta < -5*time.Second || delta > 5*time.Second { + t.Errorf("ExpiresAt offset from now = %v, want ~1h (±5s)", time.Until(td.ExpiresAt)) + } +} + +func TestExchangeCode_ConfidentialClient_UsesBasicAuth(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, pass, ok := r.BasicAuth() + if !ok { + t.Errorf("expected Basic auth header to be present") + } + if user != "conf-client" || pass != "secret" { + t.Errorf("BasicAuth = (%q, %q), want (%q, %q)", user, pass, "conf-client", "secret") + } + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm: %v", err) + } + if got := r.PostForm.Get("client_id"); got != "" { + t.Errorf("PostForm[client_id] = %q, want empty (confidential clients use Basic)", got) + } + if got, want := r.PostForm.Get("grant_type"), "authorization_code"; got != want { + t.Errorf("PostForm[grant_type] = %q, want %q", got, want) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"AT","refresh_token":"RT","expires_in":3600}`)) + })) + t.Cleanup(srv.Close) + swapTokenEndpoint(t, srv.URL) + + td, err := exchangeCode(context.Background(), "conf-client", "secret", "the-code", "the-verifier") + if err != nil { + t.Fatalf("exchangeCode returned error: %v", err) + } + if td.AccessToken != "AT" { + t.Errorf("AccessToken = %q, want %q", td.AccessToken, "AT") + } + if td.RefreshToken != "RT" { + t.Errorf("RefreshToken = %q, want %q", td.RefreshToken, "RT") + } +} + +func TestRefresh_RefreshesAndSaves(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm: %v", err) + } + if got, want := r.PostForm.Get("grant_type"), "refresh_token"; got != want { + t.Errorf("PostForm[grant_type] = %q, want %q", got, want) + } + if got, want := r.PostForm.Get("refresh_token"), "old-rt"; got != want { + t.Errorf("PostForm[refresh_token] = %q, want %q", got, want) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"new-AT","refresh_token":"new-RT","expires_in":600}`)) + })) + t.Cleanup(srv.Close) + swapTokenEndpoint(t, srv.URL) + + td, err := Refresh(context.Background(), "pub", "", "old-rt") + if err != nil { + t.Fatalf("Refresh returned error: %v", err) + } + if td.AccessToken != "new-AT" { + t.Errorf("AccessToken = %q, want %q", td.AccessToken, "new-AT") + } + if td.RefreshToken != "new-RT" { + t.Errorf("RefreshToken = %q, want %q", td.RefreshToken, "new-RT") + } + + loaded, err := LoadTokens() + if err != nil { + t.Fatalf("LoadTokens returned error: %v", err) + } + if loaded == nil { + t.Fatal("LoadTokens returned nil; expected SaveTokens to have persisted") + } + if loaded.AccessToken != td.AccessToken { + t.Errorf("loaded.AccessToken = %q, want %q", loaded.AccessToken, td.AccessToken) + } + if loaded.RefreshToken != td.RefreshToken { + t.Errorf("loaded.RefreshToken = %q, want %q", loaded.RefreshToken, td.RefreshToken) + } + if !loaded.ExpiresAt.Equal(td.ExpiresAt) { + t.Errorf("loaded.ExpiresAt = %v, want %v", loaded.ExpiresAt, td.ExpiresAt) + } +} + +func TestDoTokenRequest_ErrorPaths(t *testing.T) { + tests := []struct { + name string + handler http.HandlerFunc + useBadURL bool + wantSubstrs []string + }{ + { + name: "oauth_error_400", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"bad code"}`)) + }, + wantSubstrs: []string{"invalid_grant", "bad code"}, + }, + { + name: "non_json_500", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("Internal Server Error")) + }, + wantSubstrs: []string{"500", "Internal Server Error"}, + }, + { + name: "empty_body_200", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + // no body + }, + wantSubstrs: []string{"parsing"}, + }, + { + name: "network_error", + useBadURL: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.useBadURL { + prev := tokenEndpoint + t.Cleanup(func() { tokenEndpoint = prev }) + tokenEndpoint = "http://127.0.0.1:1" + } else { + srv := httptest.NewServer(tc.handler) + t.Cleanup(srv.Close) + swapTokenEndpoint(t, srv.URL) + } + + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("refresh_token", "x") + + td, err := doTokenRequest(context.Background(), "id", "", form) + if err == nil { + t.Fatalf("doTokenRequest returned nil error; td=%+v", td) + } + for _, sub := range tc.wantSubstrs { + if !strings.Contains(err.Error(), sub) { + t.Errorf("error %q does not contain %q", err.Error(), sub) + } + } + }) + } +} + diff --git a/fusion/mcp_test.go b/fusion/mcp_test.go index 2f13997..36d89d9 100644 --- a/fusion/mcp_test.go +++ b/fusion/mcp_test.go @@ -5,7 +5,11 @@ import ( "encoding/base64" "encoding/json" "strings" + "sync" "testing" + "time" + + "github.com/schneik80/FusionDataCLI/internal/testutil" ) func TestNormalizeProjectID(t *testing.T) { @@ -267,3 +271,284 @@ func TestInsertDocument_ValidatesInput(t *testing.T) { }) } } + +// ---------------------------------------------------------------------------- +// Phase 2 (L2) — httptest-mocked MCP server tests. +// +// These exercise the JSON-RPC + session-cache machinery in invoke / session / +// callTool against a faked MCP server (testutil.NewMCPServer). Each test +// constructs a fresh *Client so that cached session state from one test +// cannot leak into another. +// ---------------------------------------------------------------------------- + +// testCtx returns a 5-second-bounded context, freed via t.Cleanup. +func testCtx(t *testing.T) context.Context { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + return ctx +} + +func TestActiveHubProjects_Success(t *testing.T) { + var capturedArgs map[string]any + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-success", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": func(args map[string]any) testutil.MCPResponse { + capturedArgs = args + return testutil.MCPResponse{ + ContentText: `{"success": true, "projects": [{"id": "P1", "name": "Project One"}, {"id": "P2", "name": "Project Two"}]}`, + } + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + projects, err := client.ActiveHubProjects(testCtx(t)) + if err != nil { + t.Fatalf("ActiveHubProjects: unexpected error: %v", err) + } + if len(projects) != 2 { + t.Fatalf("ActiveHubProjects: got %d projects, want 2 (%+v)", len(projects), projects) + } + if projects[0].Name != "Project One" || projects[1].Name != "Project Two" { + t.Errorf("ActiveHubProjects: names = [%q,%q], want [\"Project One\",\"Project Two\"]", + projects[0].Name, projects[1].Name) + } + if got := srv.InitCount(); got != 1 { + t.Errorf("InitCount = %d, want 1", got) + } + if got := srv.CallCount("fusion_mcp_read"); got != 1 { + t.Errorf("CallCount(fusion_mcp_read) = %d, want 1", got) + } + if capturedArgs == nil { + t.Fatal("handler was not invoked (capturedArgs is nil)") + } + if got := capturedArgs["queryType"]; got != "projects" { + t.Errorf("args[queryType] = %v, want \"projects\"", got) + } +} + +func TestActiveHubProjects_SuccessFalse(t *testing.T) { + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-fail", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": func(args map[string]any) testutil.MCPResponse { + return testutil.MCPResponse{ + ContentText: `{"success": false, "projects": []}`, + } + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + _, err := client.ActiveHubProjects(testCtx(t)) + if err == nil { + t.Fatal("ActiveHubProjects: expected error for success:false, got nil") + } + // Subtlety: callTool's parseToolErrorText branch fires first on success:false + // payloads, surfacing them as "tool reported failure" before the per-method + // "projects query failed" branch in ActiveHubProjects ever runs. Either + // message indicates the failure was caught — assert on either to remain + // resilient to whichever guard short-circuits first. + msg := err.Error() + if !strings.Contains(msg, "projects query failed") && !strings.Contains(msg, "tool reported failure") { + t.Errorf("ActiveHubProjects: error %q does not contain \"projects query failed\" or \"tool reported failure\"", msg) + } +} + +func TestActiveHubProjects_SuccessFalse_WithErrorMsg(t *testing.T) { + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-autherr", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": func(args map[string]any) testutil.MCPResponse { + return testutil.MCPResponse{ + ContentText: `{"success": false, "error": "auth failed"}`, + } + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + _, err := client.ActiveHubProjects(testCtx(t)) + if err == nil { + t.Fatal("ActiveHubProjects: expected error for success:false with error, got nil") + } + if !strings.Contains(err.Error(), "auth failed") { + t.Errorf("ActiveHubProjects: error %q does not contain \"auth failed\"", err.Error()) + } +} + +func TestOpenDocument_Roundtrip(t *testing.T) { + var capturedArgs map[string]any + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-open", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_execute": func(args map[string]any) testutil.MCPResponse { + capturedArgs = args + // Plain non-JSON text — parseToolErrorText returns "" because + // it doesn't start with '{' or '['. Production treats as success. + return testutil.MCPResponse{ContentText: "opened"} + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + if err := client.OpenDocument(testCtx(t), "validfileid"); err != nil { + t.Fatalf("OpenDocument: unexpected error: %v", err) + } + if capturedArgs == nil { + t.Fatal("handler was not invoked (capturedArgs is nil)") + } + if got := capturedArgs["featureType"]; got != "document" { + t.Errorf("args[featureType] = %v, want \"document\"", got) + } + obj, ok := capturedArgs["object"].(map[string]any) + if !ok { + t.Fatalf("args[object] is %T, want map[string]any", capturedArgs["object"]) + } + if got := obj["operation"]; got != "open" { + t.Errorf("args.object[operation] = %v, want \"open\"", got) + } + if got := obj["fileId"]; got != "validfileid" { + t.Errorf("args.object[fileId] = %v, want \"validfileid\"", got) + } +} + +func TestInsertDocument_ScriptShape(t *testing.T) { + var capturedArgs map[string]any + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-insert", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_execute": func(args map[string]any) testutil.MCPResponse { + capturedArgs = args + return testutil.MCPResponse{ContentText: "inserted"} + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + const fileID = "urn:adsk.wipprod:dm.lineage:abc" + if err := client.InsertDocument(testCtx(t), fileID); err != nil { + t.Fatalf("InsertDocument: unexpected error: %v", err) + } + if capturedArgs == nil { + t.Fatal("handler was not invoked (capturedArgs is nil)") + } + if got := capturedArgs["featureType"]; got != "script" { + t.Errorf("args[featureType] = %v, want \"script\"", got) + } + obj, ok := capturedArgs["object"].(map[string]any) + if !ok { + t.Fatalf("args[object] is %T, want map[string]any", capturedArgs["object"]) + } + script, ok := obj["script"].(string) + if !ok { + t.Fatalf("args.object[script] is %T, want string", obj["script"]) + } + if !strings.Contains(script, "file_id") { + t.Errorf("script does not contain \"file_id\" variable name:\n%s", script) + } + // JSON-encoded fileId is the original URN wrapped in double quotes. + wantQuoted := `"` + fileID + `"` + if !strings.Contains(script, wantQuoted) { + t.Errorf("script does not contain JSON-encoded fileId %s:\n%s", wantQuoted, script) + } +} + +func TestInvoke_SessionRetryOn404(t *testing.T) { + var ( + mu sync.Mutex + calls int + ) + handler := func(args map[string]any) testutil.MCPResponse { + mu.Lock() + defer mu.Unlock() + calls++ + if calls == 1 { + return testutil.MCPResponse{SessionExpired: true} + } + return testutil.MCPResponse{ContentText: `{"success":true,"projects":[]}`} + } + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-1", + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": handler, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + projects, err := client.ActiveHubProjects(testCtx(t)) + if err != nil { + t.Fatalf("ActiveHubProjects: unexpected error: %v", err) + } + if len(projects) != 0 { + t.Errorf("expected 0 projects after retry, got %d (%+v)", len(projects), projects) + } + if got := srv.InitCount(); got != 2 { + t.Errorf("InitCount = %d, want 2 (initial + re-handshake after 404)", got) + } + if got := srv.CallCount("fusion_mcp_read"); got != 2 { + t.Errorf("CallCount(fusion_mcp_read) = %d, want 2 (one 404 + one success)", got) + } + sids := srv.SessionIDsSeen() + if len(sids) != 2 { + t.Fatalf("SessionIDsSeen len = %d, want 2 (got %v)", len(sids), sids) + } + for i, s := range sids { + if s != "sid-1" { + t.Errorf("SessionIDsSeen[%d] = %q, want \"sid-1\"", i, s) + } + } +} + +func TestInvoke_SSEResponse(t *testing.T) { + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "sid-sse", + SSEMode: true, + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": func(args map[string]any) testutil.MCPResponse { + return testutil.MCPResponse{ + ContentText: `{"success":true,"projects":[{"id":"P1","name":"Solo"}]}`, + } + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + projects, err := client.ActiveHubProjects(testCtx(t)) + if err != nil { + t.Fatalf("ActiveHubProjects (SSE): unexpected error: %v", err) + } + if len(projects) != 1 { + t.Fatalf("ActiveHubProjects (SSE): got %d projects, want 1 (%+v)", len(projects), projects) + } + if projects[0].Name != "Solo" { + t.Errorf("ActiveHubProjects (SSE): name = %q, want \"Solo\"", projects[0].Name) + } +} + +func TestSession_StatelessMode(t *testing.T) { + srv := testutil.NewMCPServer(t, testutil.MCPScenario{ + SessionID: "", // stateless: server emits no Mcp-Session-Id header on init + Tools: map[string]testutil.MCPHandler{ + "fusion_mcp_read": func(args map[string]any) testutil.MCPResponse { + return testutil.MCPResponse{ + ContentText: `{"success":true,"projects":[]}`, + } + }, + }, + }) + + client := &Client{Endpoint: srv.URL, HTTP: srv.Client()} + if _, err := client.ActiveHubProjects(testCtx(t)); err != nil { + t.Fatalf("ActiveHubProjects (stateless): unexpected error: %v", err) + } + sids := srv.SessionIDsSeen() + if len(sids) != 1 { + t.Fatalf("SessionIDsSeen len = %d, want 1 (got %v)", len(sids), sids) + } + if sids[0] != "" { + t.Errorf("SessionIDsSeen[0] = %q, want \"\" (stateless: client must omit header)", sids[0]) + } +} diff --git a/internal/testutil/graphql.go b/internal/testutil/graphql.go new file mode 100644 index 0000000..594f798 --- /dev/null +++ b/internal/testutil/graphql.go @@ -0,0 +1,105 @@ +// Package testutil provides shared test helpers for spinning up in-process +// httptest.Server instances that emulate the upstream services FusionDataCLI +// talks to: the APS GraphQL endpoint and the Fusion desktop MCP JSON-RPC +// server. Tests in `auth`, `api`, and `fusion` packages all need to issue +// real HTTP requests against fakes; this package keeps that boilerplate in +// one place so per-package test files can stay focused on the behaviour +// under test. +package testutil + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +// GraphQLRequest is the decoded view of an APS GraphQL POST that the +// handler sees. AuthHeader is the raw Authorization header value (typically +// "Bearer "); Region is the X-Ads-Region header (empty if unset). +type GraphQLRequest struct { + Query string + Variables map[string]any + AuthHeader string + Region string +} + +// GraphQLResponse is what a GraphQLServer handler returns for a given +// request. Either Data (marshaled into the response's "data" field) or +// Errors (each becomes {"message": ...} in the response's "errors" array) +// or both may be set. Status defaults to 200; RawBody, if non-empty, is +// sent verbatim and Data/Errors/Status are ignored. +type GraphQLResponse struct { + Data any + Errors []string + Status int + RawBody string +} + +// GraphQLServer starts an httptest.Server that decodes APS GraphQL +// requests and feeds them to handler, replying with whatever +// GraphQLResponse the handler returns. The server is closed via +// t.Cleanup, so callers don't have to defer Close(). +// +// The fake doesn't enforce method or Content-Type — those are the +// caller's job to assert on the captured request when relevant. +func GraphQLServer(t *testing.T, handler func(req GraphQLRequest) GraphQLResponse) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("testutil: reading request body: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + var req struct { + Query string `json:"query"` + Variables map[string]any `json:"variables"` + } + if err := json.Unmarshal(body, &req); err != nil { + t.Errorf("testutil: decoding GraphQL request body %q: %v", body, err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + resp := handler(GraphQLRequest{ + Query: req.Query, + Variables: req.Variables, + AuthHeader: r.Header.Get("Authorization"), + Region: r.Header.Get("X-Ads-Region"), + }) + + if resp.Status == 0 { + resp.Status = http.StatusOK + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(resp.Status) + + if resp.RawBody != "" { + _, _ = io.WriteString(w, resp.RawBody) + return + } + + envelope := map[string]any{} + if resp.Data != nil { + envelope["data"] = resp.Data + } + if len(resp.Errors) > 0 { + errs := make([]map[string]string, len(resp.Errors)) + for i, m := range resp.Errors { + errs[i] = map[string]string{"message": m} + } + envelope["errors"] = errs + } + // Always emit a "data" key even when nil, so the client's empty-data + // branch is reachable: the production gqlQuery returns an error if + // the JSON has zero bytes in `data`. Encoding `nil` produces "null" + // which still has length, so handlers that want to trigger the + // "empty data" path should set Data to json.RawMessage("") explicitly + // (or use RawBody for that). + _ = json.NewEncoder(w).Encode(envelope) + })) + t.Cleanup(srv.Close) + return srv +} diff --git a/internal/testutil/mcp.go b/internal/testutil/mcp.go new file mode 100644 index 0000000..3b6349c --- /dev/null +++ b/internal/testutil/mcp.go @@ -0,0 +1,206 @@ +package testutil + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +// MCPHandler is invoked for each tools/call request the fake server +// receives. args is the decoded "arguments" map from the JSON-RPC params. +// The returned MCPResponse controls what the fake replies with (or whether +// it rejects the call as session-expired). +type MCPHandler func(args map[string]any) MCPResponse + +// MCPResponse is one tools/call reply. +// +// - ContentText: payload returned as result.content[0].text (the JSON +// string the production client further decodes via parseToolErrorText). +// - IsError: sets result.isError on the JSON-RPC response. +// - SessionExpired: when true, the server returns 404 (production retries +// with a fresh handshake). +// - Status: overrides the default 200 (rare; use for non-404 error +// responses like 500). +type MCPResponse struct { + ContentText string + IsError bool + SessionExpired bool + Status int +} + +// MCPScenario configures the fake MCP server. +// +// - SessionID: returned in the Mcp-Session-Id response header on +// initialize. Empty string emulates stateless mode (no header). +// - Tools: handlers indexed by tool name. A request for an unmapped tool +// fails the test. +// - SSEMode: when true, tools/call responses are wrapped as +// "data: \n\n" SSE events instead of plain JSON. +type MCPScenario struct { + SessionID string + Tools map[string]MCPHandler + SSEMode bool +} + +// MCPServer wraps an httptest.Server with per-tool call counts so tests +// can assert on retry / session-cache behaviour. +type MCPServer struct { + *httptest.Server + mu sync.Mutex + calls map[string]int + inits int + sidSeen []string // session IDs received on tools/call requests +} + +// InitCount reports how many times the initialize handshake ran. +func (s *MCPServer) InitCount() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.inits +} + +// CallCount reports how many tools/call requests targeted toolName. +func (s *MCPServer) CallCount(toolName string) int { + s.mu.Lock() + defer s.mu.Unlock() + return s.calls[toolName] +} + +// SessionIDsSeen returns the Mcp-Session-Id header values received on +// tools/call requests in arrival order. Useful for asserting that the +// production client sent the cached SID after init, and re-handshake +// happened after a session-expired response. +func (s *MCPServer) SessionIDsSeen() []string { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]string, len(s.sidSeen)) + copy(out, s.sidSeen) + return out +} + +// NewMCPServer starts a fake Fusion MCP JSON-RPC server scripted by +// scenario. The server handles `initialize`, the +// `notifications/initialized` notification (HTTP 204), and `tools/call` +// dispatched via scenario.Tools. Auto-closed via t.Cleanup. +func NewMCPServer(t *testing.T, scenario MCPScenario) *MCPServer { + t.Helper() + s := &MCPServer{calls: map[string]int{}} + s.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("testutil: reading MCP request body: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var req struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` + } + if err := json.Unmarshal(body, &req); err != nil { + t.Errorf("testutil: decoding MCP request %q: %v", body, err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + switch req.Method { + case "initialize": + s.mu.Lock() + s.inits++ + s.mu.Unlock() + if scenario.SessionID != "" { + w.Header().Set("Mcp-Session-Id", scenario.SessionID) + } + writeJSONRPC(w, req.ID, map[string]any{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]any{}, + "serverInfo": map[string]any{"name": "fake-mcp", "version": "0"}, + }) + + case "notifications/initialized": + w.WriteHeader(http.StatusNoContent) + + case "tools/call": + var p struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + } + if err := json.Unmarshal(req.Params, &p); err != nil { + t.Errorf("testutil: decoding tools/call params: %v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + s.mu.Lock() + s.calls[p.Name]++ + s.sidSeen = append(s.sidSeen, r.Header.Get("Mcp-Session-Id")) + s.mu.Unlock() + + handler, ok := scenario.Tools[p.Name] + if !ok { + t.Errorf("testutil: no scenario handler for tool %q", p.Name) + http.Error(w, "no handler", http.StatusInternalServerError) + return + } + resp := handler(p.Arguments) + if resp.SessionExpired { + http.Error(w, "session expired", http.StatusNotFound) + return + } + if resp.Status != 0 && resp.Status != http.StatusOK { + http.Error(w, "scripted non-OK", resp.Status) + return + } + + payload := map[string]any{ + "content": []map[string]any{{"type": "text", "text": resp.ContentText}}, + "isError": resp.IsError, + } + if scenario.SSEMode { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + var buf bytes.Buffer + _ = json.NewEncoder(&buf).Encode(jsonRPCResponse{ + JSONRPC: "2.0", ID: req.ID, Result: mustMarshal(payload), + }) + _, _ = io.WriteString(w, "data: "+buf.String()+"\n") + return + } + writeJSONRPC(w, req.ID, payload) + + default: + t.Errorf("testutil: unexpected MCP method %q", req.Method) + http.Error(w, "unexpected method", http.StatusBadRequest) + } + })) + t.Cleanup(s.Close) + return s +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Result json.RawMessage `json:"result,omitempty"` +} + +func writeJSONRPC(w http.ResponseWriter, id int, result any) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(jsonRPCResponse{ + JSONRPC: "2.0", + ID: id, + Result: mustMarshal(result), + }) +} + +func mustMarshal(v any) json.RawMessage { + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + return b +}