From 01f626905d0d643f71cbcff8f55148bb74fe07b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=B6hmwalder?= Date: Thu, 21 Dec 2023 13:08:36 +0100 Subject: [PATCH] add context to public functions Almost all public functions of this library make one or more HTTP requests behind the scenes. Since this is an operation that can potentially block indefinitely, it is good practice to pass a context into these functions. That way, the operation can be cancelled "from above" if necessary. Add ...Context() variants of all public functions that end up making HTTP requests. Maintain external API compatibility by keeping the old functions as wrappers, calling the new ...Context() functions with context.Background(). An exception to this is the Client interface in nexus.go and its implementation. There, the context parameter is added directly. This is technically a breaking change, but I suspect not many people will be using this low-level interface anyway. --- iq/application.go | 53 +++-- iq/application_test.go | 19 +- iq/componentDetails.go | 45 +++-- iq/componentDetails_test.go | 21 +- iq/componentLabels.go | 85 +++++--- iq/componentLabels_test.go | 7 +- iq/componentVersions.go | 11 +- iq/componentVersions_test.go | 3 +- iq/componentsRemediation.go | 47 +++-- iq/componentsRemediation_test.go | 7 +- iq/dataRetentionPolicies.go | 25 ++- iq/dataRetentionPolicies_test.go | 7 +- iq/evaluation.go | 13 +- iq/evaluation_test.go | 3 +- iq/organization.go | 31 ++- iq/organization_test.go | 11 +- iq/policies.go | 21 +- iq/policies_test.go | 5 +- iq/policyViolations.go | 25 ++- iq/policyViolations_test.go | 5 +- iq/reportMetrics.go | 19 +- iq/reportMetrics_test.go | 7 +- iq/reports.go | 145 +++++++++----- iq/reports_test.go | 15 +- iq/roleMemberships.go | 319 ++++++++++++++++++++----------- iq/roleMemberships_test.go | 83 ++++---- iq/roles.go | 33 +++- iq/roles_test.go | 5 +- iq/search.go | 11 +- iq/search_test.go | 5 +- iq/sourceControl.go | 83 +++++--- iq/sourceControl_test.go | 27 +-- iq/users.go | 35 ++-- iq/users_test.go | 9 +- nexus.go | 35 ++-- rm/anonymous.go | 17 +- rm/assets.go | 31 ++- rm/assets_test.go | 11 +- rm/components.go | 43 +++-- rm/components_test.go | 17 +- rm/email.go | 25 ++- rm/groovyBlobStore.go | 23 ++- rm/groovyBlobStore_test.go | 11 +- rm/groovyRepository.go | 31 ++- rm/groovyRepository_test.go | 6 +- rm/maintenance.go | 21 +- rm/maintenance_test.go | 5 +- rm/readOnly.go | 31 ++- rm/repositories.go | 45 +++-- rm/repositories_test.go | 5 +- rm/roles.go | 17 +- rm/scripts.go | 75 +++++--- rm/scripts_test.go | 29 +-- rm/search.go | 25 ++- rm/search_test.go | 7 +- rm/staging.go | 25 ++- rm/status.go | 15 +- rm/support.go | 11 +- rm/tagging.go | 51 +++-- rm/tagging_test.go | 9 +- 60 files changed, 1214 insertions(+), 647 deletions(-) diff --git a/iq/application.go b/iq/application.go index 21a0dce..769ad97 100644 --- a/iq/application.go +++ b/iq/application.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -44,13 +45,12 @@ type Application struct { } `json:"applicationTags,omitempty"` } -// GetApplicationByPublicID returns details on the named IQ application -func GetApplicationByPublicID(iq IQ, applicationPublicID string) (*Application, error) { +func GetApplicationByPublicIDContext(ctx context.Context, iq IQ, applicationPublicID string) (*Application, error) { doError := func(err error) error { return fmt.Errorf("application '%s' not found: %v", applicationPublicID, err) } endpoint := fmt.Sprintf(restApplicationByPublic, applicationPublicID) - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return nil, doError(err) } @@ -67,8 +67,12 @@ func GetApplicationByPublicID(iq IQ, applicationPublicID string) (*Application, return &resp.Applications[0], nil } -// CreateApplication creates an application in IQ with the given name and identifier -func CreateApplication(iq IQ, name, id, organizationID string) (string, error) { +// GetApplicationByPublicID returns details on the named IQ application +func GetApplicationByPublicID(iq IQ, applicationPublicID string) (*Application, error) { + return GetApplicationByPublicIDContext(context.Background(), iq, applicationPublicID) +} + +func CreateApplicationContext(ctx context.Context, iq IQ, name, id, organizationID string) (string, error) { if name == "" || id == "" || organizationID == "" { return "", fmt.Errorf("cannot create application with empty values") } @@ -82,7 +86,7 @@ func CreateApplication(iq IQ, name, id, organizationID string) (string, error) { return doError(err) } - body, _, err := iq.Post(restApplication, bytes.NewBuffer(request)) + body, _, err := iq.Post(ctx, restApplication, bytes.NewBuffer(request)) if err != nil { return doError(err) } @@ -95,17 +99,25 @@ func CreateApplication(iq IQ, name, id, organizationID string) (string, error) { return resp.ID, nil } -// DeleteApplication deletes an application in IQ with the given id -func DeleteApplication(iq IQ, applicationID string) error { - if resp, err := iq.Del(fmt.Sprintf("%s/%s", restApplication, applicationID)); err != nil && resp.StatusCode != http.StatusNoContent { +// CreateApplication creates an application in IQ with the given name and identifier +func CreateApplication(iq IQ, name, id, organizationID string) (string, error) { + return CreateApplicationContext(context.Background(), iq, name, id, organizationID) +} + +func DeleteApplicationContext(ctx context.Context, iq IQ, applicationID string) error { + if resp, err := iq.Del(ctx, fmt.Sprintf("%s/%s", restApplication, applicationID)); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("application '%s' not deleted: %v", applicationID, err) } return nil } -// GetAllApplications returns a slice of all of the applications in an IQ instance -func GetAllApplications(iq IQ) ([]Application, error) { - body, _, err := iq.Get(restApplication) +// DeleteApplication deletes an application in IQ with the given id +func DeleteApplication(iq IQ, applicationID string) error { + return DeleteApplicationContext(context.Background(), iq, applicationID) +} + +func GetAllApplicationsContext(ctx context.Context, iq IQ) ([]Application, error) { + body, _, err := iq.Get(ctx, restApplication) if err != nil { return nil, fmt.Errorf("applications not found: %v", err) } @@ -118,14 +130,18 @@ func GetAllApplications(iq IQ) ([]Application, error) { return resp.Applications, nil } -// GetApplicationsByOrganization returns all applications under a given organization -func GetApplicationsByOrganization(iq IQ, organizationName string) ([]Application, error) { - org, err := GetOrganizationByName(iq, organizationName) +// GetAllApplications returns a slice of all of the applications in an IQ instance +func GetAllApplications(iq IQ) ([]Application, error) { + return GetAllApplicationsContext(context.Background(), iq) +} + +func GetApplicationsByOrganizationContext(ctx context.Context, iq IQ, organizationName string) ([]Application, error) { + org, err := GetOrganizationByNameContext(ctx, iq, organizationName) if err != nil { return nil, fmt.Errorf("organization not found: %v", err) } - apps, err := GetAllApplications(iq) + apps, err := GetAllApplicationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not get applications list: %v", err) } @@ -139,3 +155,8 @@ func GetApplicationsByOrganization(iq IQ, organizationName string) ([]Applicatio return orgApps, nil } + +// GetApplicationsByOrganization returns all applications under a given organization +func GetApplicationsByOrganization(iq IQ, organizationName string) ([]Application, error) { + return GetApplicationsByOrganizationContext(context.Background(), iq, organizationName) +} diff --git a/iq/application_test.go b/iq/application_test.go index 3f7ee71..b89c774 100644 --- a/iq/application_test.go +++ b/iq/application_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -109,7 +110,7 @@ func TestGetAllApplications(t *testing.T) { iq, mock := applicationTestIQ(t) defer mock.Close() - applications, err := GetAllApplications(iq) + applications, err := GetAllApplicationsContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -133,7 +134,7 @@ func TestGetApplicationByPublicID(t *testing.T) { dummyAppsIdx := 2 - got, err := GetApplicationByPublicID(iq, dummyApps[dummyAppsIdx].PublicID) + got, err := GetApplicationByPublicIDContext(context.Background(), iq, dummyApps[dummyAppsIdx].PublicID) if err != nil { t.Error(err) } @@ -197,7 +198,7 @@ func TestCreateApplication(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := CreateApplication(tt.args.iq, tt.args.name, tt.args.id, tt.args.organizationID) + got, err := CreateApplicationContext(context.Background(), tt.args.iq, tt.args.name, tt.args.id, tt.args.organizationID) if (err != nil) != tt.wantErr { t.Errorf("CreateApplication() error = %v, wantErr %v", err, tt.wantErr) return @@ -216,16 +217,16 @@ func TestDeleteApplication(t *testing.T) { deleteMeApp := Application{PublicID: "deleteMeApp", Name: "deleteMeApp", OrganizationID: "deleteMeAppOrgId"} var err error - deleteMeApp.ID, err = CreateApplication(iq, deleteMeApp.Name, deleteMeApp.PublicID, deleteMeApp.OrganizationID) + deleteMeApp.ID, err = CreateApplicationContext(context.Background(), iq, deleteMeApp.Name, deleteMeApp.PublicID, deleteMeApp.OrganizationID) if err != nil { t.Fatal(err) } - if err := DeleteApplication(iq, deleteMeApp.PublicID); err != nil { + if err := DeleteApplicationContext(context.Background(), iq, deleteMeApp.PublicID); err != nil { t.Fatal(err) } - if _, err := GetApplicationByPublicID(iq, deleteMeApp.PublicID); err == nil { + if _, err := GetApplicationByPublicIDContext(context.Background(), iq, deleteMeApp.PublicID); err == nil { t.Fatal("App was not deleted") } } @@ -254,7 +255,7 @@ func TestGetApplicationsByOrganization(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetApplicationsByOrganization(tt.args.iq, tt.args.organizationName) + got, err := GetApplicationsByOrganizationContext(context.Background(), tt.args.iq, tt.args.organizationName) if (err != nil) != tt.wantErr { t.Errorf("GetApplicationsByOrganization() error = %v, wantErr %v", err, tt.wantErr) return @@ -272,7 +273,7 @@ func ExampleGetAllApplications() { panic(err) } - applications, err := GetAllApplications(iq) + applications, err := GetAllApplicationsContext(context.Background(), iq) if err != nil { panic(err) } @@ -286,7 +287,7 @@ func ExampleCreateApplication() { panic(err) } - appID, err := CreateApplication(iq, "name", "id", "organization") + appID, err := CreateApplicationContext(context.Background(), iq, "name", "id", "organization") if err != nil { panic(err) } diff --git a/iq/componentDetails.go b/iq/componentDetails.go index 0885ed1..907947b 100644 --- a/iq/componentDetails.go +++ b/iq/componentDetails.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" ) @@ -42,17 +43,20 @@ type ComponentDetail struct { } `json:"securityData"` } -// GetComponent returns information on a named component -func GetComponent(iq IQ, component Component) (ComponentDetail, error) { - deets, err := GetComponents(iq, []Component{component}) +func GetComponentContext(ctx context.Context, iq IQ, component Component) (ComponentDetail, error) { + deets, err := GetComponentsContext(ctx, iq, []Component{component}) if deets == nil || len(deets) == 0 { return ComponentDetail{}, err } return deets[0], err } -// GetComponents returns information on the named components -func GetComponents(iq IQ, components []Component) ([]ComponentDetail, error) { +// GetComponent returns information on a named component +func GetComponent(iq IQ, component Component) (ComponentDetail, error) { + return GetComponentContext(context.Background(), iq, component) +} + +func GetComponentsContext(ctx context.Context, iq IQ, components []Component) ([]ComponentDetail, error) { reqComponents := detailsRequest{Components: make([]componentRequested, len(components))} for i, c := range components { reqComponents.Components[i] = componentRequestedFromComponent(c) @@ -63,7 +67,7 @@ func GetComponents(iq IQ, components []Component) ([]ComponentDetail, error) { return nil, fmt.Errorf("could not generate request: %v", err) } - body, _, err := iq.Post(restComponentDetails, bytes.NewBuffer(req)) + body, _, err := iq.Post(ctx, restComponentDetails, bytes.NewBuffer(req)) if err != nil { return nil, fmt.Errorf("could not find component details: %v", err) } @@ -76,13 +80,17 @@ func GetComponents(iq IQ, components []Component) ([]ComponentDetail, error) { return resp.ComponentDetails, nil } -// GetComponentsByApplication returns an array with all components along with their -func GetComponentsByApplication(iq IQ, appPublicID string) ([]ComponentDetail, error) { +// GetComponents returns information on the named components +func GetComponents(iq IQ, components []Component) ([]ComponentDetail, error) { + return GetComponentsContext(context.Background(), iq, components) +} + +func GetComponentsByApplicationContext(ctx context.Context, iq IQ, appPublicID string) ([]ComponentDetail, error) { componentHashes := make(map[string]struct{}) components := make([]Component, 0) stages := []Stage{StageBuild, StageStageRelease, StageRelease, StageOperate} for _, stage := range stages { - if report, err := GetRawReportByAppID(iq, appPublicID, string(stage)); err == nil { + if report, err := GetRawReportByAppIDContext(ctx, iq, appPublicID, string(stage)); err == nil { for _, c := range report.Components { if _, ok := componentHashes[c.Hash]; !ok { componentHashes[c.Hash] = struct{}{} @@ -92,12 +100,16 @@ func GetComponentsByApplication(iq IQ, appPublicID string) ([]ComponentDetail, e } } - return GetComponents(iq, components) + return GetComponentsContext(ctx, iq, components) } -// GetAllComponents returns an array with all components along with their -func GetAllComponents(iq IQ) ([]ComponentDetail, error) { - apps, err := GetAllApplications(iq) +// GetComponentsByApplication returns an array with all components along with their +func GetComponentsByApplication(iq IQ, appPublicID string) ([]ComponentDetail, error) { + return GetComponentsByApplicationContext(context.Background(), iq, appPublicID) +} + +func GetAllComponentsContext(ctx context.Context, iq IQ) ([]ComponentDetail, error) { + apps, err := GetAllApplicationsContext(ctx, iq) if err != nil { return nil, err } @@ -106,7 +118,7 @@ func GetAllComponents(iq IQ) ([]ComponentDetail, error) { components := make([]ComponentDetail, 0) for _, app := range apps { - appComponents, err := GetComponentsByApplication(iq, app.PublicID) + appComponents, err := GetComponentsByApplicationContext(ctx, iq, app.PublicID) // TODO: catcher if err != nil { return nil, err @@ -122,3 +134,8 @@ func GetAllComponents(iq IQ) ([]ComponentDetail, error) { return components, nil } + +// GetAllComponents returns an array with all components +func GetAllComponents(iq IQ) ([]ComponentDetail, error) { + return GetAllComponentsContext(context.Background(), iq) +} diff --git a/iq/componentDetails_test.go b/iq/componentDetails_test.go index e316fc6..8701509 100644 --- a/iq/componentDetails_test.go +++ b/iq/componentDetails_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -88,13 +89,13 @@ func TestGetComponent(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetComponent(tt.args.iq, tt.args.component) + got, err := GetComponentContext(context.Background(), tt.args.iq, tt.args.component) if (err != nil) != tt.wantErr { - t.Errorf("GetComponent() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("GetComponentContext() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetComponent() = %v, want %v", got, tt.want) + t.Errorf("GetComponentContext() = %v, want %v", got, tt.want) } }) } @@ -106,7 +107,7 @@ func TestGetComponents(t *testing.T) { expected := dummyComponentDetails[0] - details, err := GetComponents(iq, []Component{expected.Component}) + details, err := GetComponentsContext(context.Background(), iq, []Component{expected.Component}) if err != nil { t.Error(err) } @@ -145,13 +146,13 @@ func TestGetComponentsByApplication(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetComponentsByApplication(tt.args.iq, tt.args.appPublicID) + got, err := GetComponentsByApplicationContext(context.Background(), tt.args.iq, tt.args.appPublicID) if (err != nil) != tt.wantErr { - t.Errorf("GetComponentsByApplication() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("GetComponentsByApplicationContext() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetComponentsByApplication() = %v, want %v", got, tt.want) + t.Errorf("GetComponentsByApplicationContext() = %v, want %v", got, tt.want) } }) } @@ -180,13 +181,13 @@ func TestGetAllComponents(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := GetAllComponents(tt.args.iq) + got, err := GetAllComponentsContext(context.Background(), tt.args.iq) if (err != nil) != tt.wantErr { - t.Errorf("GetAllComponents() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("GetAllComponentsContext() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetAllComponents() = %v, want %v", got, tt.want) + t.Errorf("GetAllComponentsContext() = %v, want %v", got, tt.want) } }) } diff --git a/iq/componentLabels.go b/iq/componentLabels.go index bc9315e..8c3d7d3 100644 --- a/iq/componentLabels.go +++ b/iq/componentLabels.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -26,15 +27,14 @@ type IqComponentLabel struct { Color string `json:"color"` } -// ComponentLabelApply adds an existing label to a component for a given application -func ComponentLabelApply(iq IQ, comp Component, appID, label string) error { - app, err := GetApplicationByPublicID(iq, appID) +func ComponentLabelApplyContext(ctx context.Context, iq IQ, comp Component, appID, label string) error { + app, err := GetApplicationByPublicIDContext(ctx, iq, appID) if err != nil { return fmt.Errorf("could not retrieve application with ID %s: %v", appID, err) } endpoint := fmt.Sprintf(restLabelComponent, comp.Hash, url.PathEscape(label), app.ID) - _, resp, err := iq.Post(endpoint, nil) + _, resp, err := iq.Post(ctx, endpoint, nil) if err != nil { if resp == nil || resp.StatusCode != http.StatusNoContent { return fmt.Errorf("could not apply label: %v", err) @@ -44,15 +44,19 @@ func ComponentLabelApply(iq IQ, comp Component, appID, label string) error { return nil } -// ComponentLabelUnapply removes an existing association between a label and a component -func ComponentLabelUnapply(iq IQ, comp Component, appID, label string) error { - app, err := GetApplicationByPublicID(iq, appID) +// ComponentLabelApply adds an existing label to a component for a given application +func ComponentLabelApply(iq IQ, comp Component, appID, label string) error { + return ComponentLabelApplyContext(context.Background(), iq, comp, appID, label) +} + +func ComponentLabelUnapplyContext(ctx context.Context, iq IQ, comp Component, appID, label string) error { + app, err := GetApplicationByPublicIDContext(ctx, iq, appID) if err != nil { return fmt.Errorf("could not retrieve application with ID %s: %v", appID, err) } endpoint := fmt.Sprintf(restLabelComponent, comp.Hash, url.PathEscape(label), app.ID) - resp, err := iq.Del(endpoint) + resp, err := iq.Del(ctx, endpoint) if err != nil { if resp == nil || resp.StatusCode != http.StatusNoContent { return fmt.Errorf("could not unapply label: %v", err) @@ -62,8 +66,13 @@ func ComponentLabelUnapply(iq IQ, comp Component, appID, label string) error { return nil } -func getComponentLabels(iq IQ, endpoint string) ([]IqComponentLabel, error) { - body, _, err := iq.Get(endpoint) +// ComponentLabelUnapply removes an existing association between a label and a component +func ComponentLabelUnapply(iq IQ, comp Component, appID, label string) error { + return ComponentLabelUnapplyContext(context.Background(), iq, comp, appID, label) +} + +func getComponentLabels(ctx context.Context, iq IQ, endpoint string) ([]IqComponentLabel, error) { + body, _, err := iq.Get(ctx, endpoint) if err != nil { return nil, err } @@ -77,26 +86,34 @@ func getComponentLabels(iq IQ, endpoint string) ([]IqComponentLabel, error) { return labels, nil } +func GetComponentLabelsByOrganizationContext(ctx context.Context, iq IQ, organization string) ([]IqComponentLabel, error) { + endpoint := fmt.Sprintf(restLabelComponentByOrg, organization) + return getComponentLabels(ctx, iq, endpoint) +} + // GetComponentLabelsByOrganization retrieves an array of an organization's component label func GetComponentLabelsByOrganization(iq IQ, organization string) ([]IqComponentLabel, error) { - endpoint := fmt.Sprintf(restLabelComponentByOrg, organization) - return getComponentLabels(iq, endpoint) + return GetComponentLabelsByOrganizationContext(context.Background(), iq, organization) +} + +func GetComponentLabelsByAppIDContext(ctx context.Context, iq IQ, appID string) ([]IqComponentLabel, error) { + endpoint := fmt.Sprintf(restLabelComponentByApp, appID) + return getComponentLabels(ctx, iq, endpoint) } // GetComponentLabelsByAppID retrieves an array of an organization's component label func GetComponentLabelsByAppID(iq IQ, appID string) ([]IqComponentLabel, error) { - endpoint := fmt.Sprintf(restLabelComponentByApp, appID) - return getComponentLabels(iq, endpoint) + return GetComponentLabelsByAppIDContext(context.Background(), iq, appID) } -func createLabel(iq IQ, endpoint, label, description, color string) (IqComponentLabel, error) { +func createLabel(ctx context.Context, iq IQ, endpoint, label, description, color string) (IqComponentLabel, error) { var labelResponse IqComponentLabel request, err := json.Marshal(IqComponentLabel{Label: label, Description: description, Color: color}) if err != nil { return labelResponse, fmt.Errorf("could not marshal label: %v", err) } - body, resp, err := iq.Post(endpoint, bytes.NewBuffer(request)) + body, resp, err := iq.Post(ctx, endpoint, bytes.NewBuffer(request)) if resp.StatusCode != http.StatusOK { return labelResponse, fmt.Errorf("did not succeeed in creating label: %v", err) } @@ -109,22 +126,29 @@ func createLabel(iq IQ, endpoint, label, description, color string) (IqComponent return labelResponse, nil } +func CreateComponentLabelForOrganizationContext(ctx context.Context, iq IQ, organization, label, description, color string) (IqComponentLabel, error) { + endpoint := fmt.Sprintf(restLabelComponentByOrg, organization) + return createLabel(ctx, iq, endpoint, label, description, color) +} + // CreateComponentLabelForOrganization creates a label for an organization func CreateComponentLabelForOrganization(iq IQ, organization, label, description, color string) (IqComponentLabel, error) { - endpoint := fmt.Sprintf(restLabelComponentByOrg, organization) - return createLabel(iq, endpoint, label, description, color) + return CreateComponentLabelForOrganizationContext(context.Background(), iq, organization, label, description, color) +} + +func CreateComponentLabelForApplicationContext(ctx context.Context, iq IQ, appID, label, description, color string) (IqComponentLabel, error) { + endpoint := fmt.Sprintf(restLabelComponentByApp, appID) + return createLabel(ctx, iq, endpoint, label, description, color) } // CreateComponentLabelForApplication creates a label for an application func CreateComponentLabelForApplication(iq IQ, appID, label, description, color string) (IqComponentLabel, error) { - endpoint := fmt.Sprintf(restLabelComponentByApp, appID) - return createLabel(iq, endpoint, label, description, color) + return CreateComponentLabelForApplicationContext(context.Background(), iq, appID, label, description, color) } -// DeleteComponentLabelForOrganization deletes a label from an organization -func DeleteComponentLabelForOrganization(iq IQ, organization, label string) error { +func DeleteComponentLabelForOrganizationContext(ctx context.Context, iq IQ, organization, label string) error { endpoint := fmt.Sprintf(restLabelComponentByOrgDel, organization, label) - resp, err := iq.Del(endpoint) + resp, err := iq.Del(ctx, endpoint) if resp.StatusCode != http.StatusOK { return fmt.Errorf("did not succeeed in deleting label: %v", err) } @@ -133,10 +157,14 @@ func DeleteComponentLabelForOrganization(iq IQ, organization, label string) erro return nil } -// DeleteComponentLabelForApplication deletes a label from an application -func DeleteComponentLabelForApplication(iq IQ, appID, label string) error { +// DeleteComponentLabelForOrganization deletes a label from an organization +func DeleteComponentLabelForOrganization(iq IQ, organization, label string) error { + return DeleteComponentLabelForOrganizationContext(context.Background(), iq, organization, label) +} + +func DeleteComponentLabelForApplicationContext(ctx context.Context, iq IQ, appID, label string) error { endpoint := fmt.Sprintf(restLabelComponentByAppDel, appID, label) - resp, err := iq.Del(endpoint) + resp, err := iq.Del(ctx, endpoint) if resp.StatusCode != http.StatusOK { return fmt.Errorf("did not succeeed in deleting label: %v", err) } @@ -144,3 +172,8 @@ func DeleteComponentLabelForApplication(iq IQ, appID, label string) error { return nil } + +// DeleteComponentLabelForApplication deletes a label from an application +func DeleteComponentLabelForApplication(iq IQ, appID, label string) error { + return DeleteComponentLabelForApplicationContext(context.Background(), iq, appID, label) +} diff --git a/iq/componentLabels_test.go b/iq/componentLabels_test.go index 14e7e2d..416de29 100644 --- a/iq/componentLabels_test.go +++ b/iq/componentLabels_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -55,7 +56,7 @@ func TestComponentLabelApply(t *testing.T) { label, component, appID := dummyLabels[0], dummyComponent, dummyApps[0].PublicID - if err := ComponentLabelApply(iq, component, appID, label); err != nil { + if err := ComponentLabelApplyContext(context.Background(), iq, component, appID, label); err != nil { t.Error(err) } } @@ -66,11 +67,11 @@ func TestComponentLabelUnapply(t *testing.T) { label, component, appID := dummyLabels[0], dummyComponent, dummyApps[0].PublicID - if err := ComponentLabelApply(iq, component, appID, label); err != nil { + if err := ComponentLabelApplyContext(context.Background(), iq, component, appID, label); err != nil { t.Fatal(err) } - if err := ComponentLabelUnapply(iq, component, appID, label); err != nil { + if err := ComponentLabelUnapplyContext(context.Background(), iq, component, appID, label); err != nil { t.Error(err) } } diff --git a/iq/componentVersions.go b/iq/componentVersions.go index 0021d01..08bc44e 100644 --- a/iq/componentVersions.go +++ b/iq/componentVersions.go @@ -2,20 +2,20 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" ) const restComponentVersions = "api/v2/components/versions" -// ComponentVersions returns all known versions of a given component -func ComponentVersions(iq IQ, comp Component) (versions []string, err error) { +func ComponentVersionsContext(ctx context.Context, iq IQ, comp Component) (versions []string, err error) { str, err := json.Marshal(comp) if err != nil { return nil, fmt.Errorf("could not process component: %v", err) } - body, _, err := iq.Post(restComponentVersions, bytes.NewBuffer(str)) + body, _, err := iq.Post(ctx, restComponentVersions, bytes.NewBuffer(str)) if err != nil { return nil, fmt.Errorf("could not request component: %v", err) } @@ -26,3 +26,8 @@ func ComponentVersions(iq IQ, comp Component) (versions []string, err error) { return } + +// ComponentVersions returns all known versions of a given component +func ComponentVersions(iq IQ, comp Component) (versions []string, err error) { + return ComponentVersionsContext(context.Background(), iq, comp) +} diff --git a/iq/componentVersions_test.go b/iq/componentVersions_test.go index b543fb6..1f86b71 100644 --- a/iq/componentVersions_test.go +++ b/iq/componentVersions_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -62,7 +63,7 @@ func TestComponentVersions(t *testing.T) { iq, mock := componentVersionsTestIQ(t) defer mock.Close() - versions, err := ComponentVersions(iq, dummyComponent) + versions, err := ComponentVersionsContext(context.Background(), iq, dummyComponent) if err != nil { t.Error(err) } diff --git a/iq/componentsRemediation.go b/iq/componentsRemediation.go index b43e4e7..ace12b7 100644 --- a/iq/componentsRemediation.go +++ b/iq/componentsRemediation.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "sync" @@ -60,13 +61,13 @@ func createRemediationEndpoint(base, id, stage string) string { return buf.String() } -func getRemediation(iq IQ, component Component, endpoint string) (Remediation, error) { +func getRemediation(ctx context.Context, iq IQ, component Component, endpoint string) (Remediation, error) { request, err := json.Marshal(component) if err != nil { return Remediation{}, fmt.Errorf("could not build the request: %v", err) } - body, _, err := iq.Post(endpoint, bytes.NewBuffer(request)) + body, _, err := iq.Post(ctx, endpoint, bytes.NewBuffer(request)) if err != nil { return Remediation{}, fmt.Errorf("could not get remediation: %v", err) } @@ -80,40 +81,47 @@ func getRemediation(iq IQ, component Component, endpoint string) (Remediation, e return results.Remediation, nil } -func getRemediationByAppInternalID(iq IQ, component Component, stage, appInternalID string) (Remediation, error) { - return getRemediation(iq, component, createRemediationEndpoint(restRemediationByApp, appInternalID, stage)) +func getRemediationByAppInternalID(ctx context.Context, iq IQ, component Component, stage, appInternalID string) (Remediation, error) { + return getRemediation(ctx, iq, component, createRemediationEndpoint(restRemediationByApp, appInternalID, stage)) } -// GetRemediationByApp retrieves the remediation information on a component based on an application's policies -func GetRemediationByApp(iq IQ, component Component, stage, applicationID string) (Remediation, error) { - app, err := GetApplicationByPublicID(iq, applicationID) +func GetRemediationByAppContext(ctx context.Context, iq IQ, component Component, stage, applicationID string) (Remediation, error) { + app, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return Remediation{}, fmt.Errorf("could not get application: %v", err) } - return getRemediationByAppInternalID(iq, component, stage, app.ID) + return getRemediationByAppInternalID(ctx, iq, component, stage, app.ID) } -// GetRemediationByOrg retrieves the remediation information on a component based on an organization's policies -func GetRemediationByOrg(iq IQ, component Component, stage, organizationName string) (Remediation, error) { - org, err := GetOrganizationByName(iq, organizationName) +// GetRemediationByApp retrieves the remediation information on a component based on an application's policies +func GetRemediationByApp(iq IQ, component Component, stage, applicationID string) (Remediation, error) { + return GetRemediationByAppContext(context.Background(), iq, component, stage, applicationID) +} + +func GetRemediationByOrgContext(ctx context.Context, iq IQ, component Component, stage, organizationName string) (Remediation, error) { + org, err := GetOrganizationByNameContext(ctx, iq, organizationName) if err != nil { return Remediation{}, fmt.Errorf("could not get organization: %v", err) } endpoint := createRemediationEndpoint(restRemediationByOrg, org.ID, stage) - return getRemediation(iq, component, endpoint) + return getRemediation(ctx, iq, component, endpoint) } -// GetRemediationsByAppReport retrieves the remediation information on each component of a report -func GetRemediationsByAppReport(iq IQ, applicationID, reportID string) (remediations []Remediation, err error) { - report, err := getRawReportByAppReportID(iq, applicationID, reportID) +// GetRemediationByOrg retrieves the remediation information on a component based on an organization's policies +func GetRemediationByOrg(iq IQ, component Component, stage, organizationName string) (Remediation, error) { + return GetRemediationByOrgContext(context.Background(), iq, component, stage, organizationName) +} + +func GetRemediationsByAppReportContext(ctx context.Context, iq IQ, applicationID, reportID string) (remediations []Remediation, err error) { + report, err := getRawReportByAppReportID(ctx, iq, applicationID, reportID) if err != nil { return nil, fmt.Errorf("could not get report %s for app %s: %v", reportID, applicationID, err) } - app, err := GetApplicationByPublicID(iq, applicationID) + app, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return nil, fmt.Errorf("could not get application: %v", err) } @@ -131,7 +139,7 @@ func GetRemediationsByAppReport(iq IQ, applicationID, reportID string) (remediat PackageURL: c.PackageURL, } var remediation Remediation - remediation, err = getRemediationByAppInternalID(iq, purl, report.ReportInfo.Stage, app.ID) + remediation, err = getRemediationByAppInternalID(ctx, iq, purl, report.ReportInfo.Stage, app.ID) if err != nil { err = fmt.Errorf("did not find remediation for '%v': %v", c, err) break @@ -157,3 +165,8 @@ func GetRemediationsByAppReport(iq IQ, applicationID, reportID string) (remediat return } + +// GetRemediationsByAppReport retrieves the remediation information on each component of a report +func GetRemediationsByAppReport(iq IQ, applicationID, reportID string) (remediations []Remediation, err error) { + return GetRemediationsByAppReportContext(context.Background(), iq, applicationID, reportID) +} diff --git a/iq/componentsRemediation_test.go b/iq/componentsRemediation_test.go index 6c7061e..83f36cb 100644 --- a/iq/componentsRemediation_test.go +++ b/iq/componentsRemediation_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -103,7 +104,7 @@ func TestRemediationByApp(t *testing.T) { id, stage := dummyApps[0].PublicID, "build" - remediation, err := GetRemediationByApp(iq, dummyComponent, stage, id) + remediation, err := GetRemediationByAppContext(context.Background(), iq, dummyComponent, stage, id) if err != nil { t.Error(err) } @@ -120,7 +121,7 @@ func TestRemediationByOrg(t *testing.T) { id, stage := dummyOrgs[0].Name, "build" - remediation, err := GetRemediationByOrg(iq, dummyComponent, stage, id) + remediation, err := GetRemediationByOrgContext(context.Background(), iq, dummyComponent, stage, id) if err != nil { t.Error(err) } @@ -138,7 +139,7 @@ func TestRemediationByAppReport(t *testing.T) { appIdx, reportID := 0, "0" - got, err := GetRemediationsByAppReport(iq, dummyApps[appIdx].PublicID, reportID) + got, err := GetRemediationsByAppReportContext(context.Background(), iq, dummyApps[appIdx].PublicID, reportID) if err != nil { t.Error(err) } diff --git a/iq/dataRetentionPolicies.go b/iq/dataRetentionPolicies.go index a7f38c6..fe6fa1c 100644 --- a/iq/dataRetentionPolicies.go +++ b/iq/dataRetentionPolicies.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" ) @@ -26,16 +27,15 @@ type DataRetentionPolicy struct { MaxAge string `json:"maxAge"` } -// GetRetentionPolicies returns the current retention policies -func GetRetentionPolicies(iq IQ, orgName string) (policies DataRetentionPolicies, err error) { - org, err := GetOrganizationByName(iq, orgName) +func GetRetentionPoliciesContext(ctx context.Context, iq IQ, orgName string) (policies DataRetentionPolicies, err error) { + org, err := GetOrganizationByNameContext(ctx, iq, orgName) if err != nil { return policies, fmt.Errorf("could not retrieve organization named %s: %v", orgName, err) } endpoint := fmt.Sprintf(restDataRetentionPolicies, org.ID) - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return policies, fmt.Errorf("did not retrieve retention policies for organization %s: %v", orgName, err) } @@ -45,9 +45,13 @@ func GetRetentionPolicies(iq IQ, orgName string) (policies DataRetentionPolicies return } -// SetRetentionPolicies updates the retention policies -func SetRetentionPolicies(iq IQ, orgName string, policies DataRetentionPolicies) error { - org, err := GetOrganizationByName(iq, orgName) +// GetRetentionPolicies returns the current retention policies +func GetRetentionPolicies(iq IQ, orgName string) (policies DataRetentionPolicies, err error) { + return GetRetentionPoliciesContext(context.Background(), iq, orgName) +} + +func SetRetentionPoliciesContext(ctx context.Context, iq IQ, orgName string, policies DataRetentionPolicies) error { + org, err := GetOrganizationByNameContext(ctx, iq, orgName) if err != nil { return fmt.Errorf("could not retrieve organization named %s: %v", orgName, err) } @@ -59,10 +63,15 @@ func SetRetentionPolicies(iq IQ, orgName string, policies DataRetentionPolicies) endpoint := fmt.Sprintf(restDataRetentionPolicies, org.ID) - _, _, err = iq.Put(endpoint, bytes.NewBuffer(request)) + _, _, err = iq.Put(ctx, endpoint, bytes.NewBuffer(request)) if err != nil { return fmt.Errorf("did not set retention policies for organization %s: %v", orgName, err) } return nil } + +// SetRetentionPolicies updates the retention policies +func SetRetentionPolicies(iq IQ, orgName string, policies DataRetentionPolicies) error { + return SetRetentionPoliciesContext(context.Background(), iq, orgName, policies) +} diff --git a/iq/dataRetentionPolicies_test.go b/iq/dataRetentionPolicies_test.go index 60b15c0..b993119 100644 --- a/iq/dataRetentionPolicies_test.go +++ b/iq/dataRetentionPolicies_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -98,7 +99,7 @@ func TestGetRetentionPolicies(t *testing.T) { iq, mock := dataRetentionPoliciesTestIQ(t) defer mock.Close() - policies, err := GetRetentionPolicies(iq, dummyOrgs[0].Name) + policies, err := GetRetentionPoliciesContext(context.Background(), iq, dummyOrgs[0].Name) if err != nil { t.Error(err) } @@ -135,12 +136,12 @@ func TestSetRetentionPolicies(t *testing.T) { SuccessMetrics: expected.SuccessMetrics, } - err := SetRetentionPolicies(iq, dummyOrgs[0].Name, retentionRequest) + err := SetRetentionPoliciesContext(context.Background(), iq, dummyOrgs[0].Name, retentionRequest) if err != nil { t.Error(err) } - got, err := GetRetentionPolicies(iq, dummyOrgs[0].Name) + got, err := GetRetentionPoliciesContext(context.Background(), iq, dummyOrgs[0].Name) if err != nil { t.Error(err) } diff --git a/iq/evaluation.go b/iq/evaluation.go index 93a543c..f2df081 100644 --- a/iq/evaluation.go +++ b/iq/evaluation.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -230,15 +231,14 @@ type iqEvaluationRequest struct { Components []Component `json:"components"` } -// EvaluateComponents evaluates the list of components -func EvaluateComponents(iq IQ, components []Component, applicationID string) (*Evaluation, error) { +func EvaluateComponentsContext(ctx context.Context, iq IQ, components []Component, applicationID string) (*Evaluation, error) { request, err := json.Marshal(iqEvaluationRequest{Components: components}) if err != nil { return nil, fmt.Errorf("could not build the request: %v", err) } requestEndpoint := fmt.Sprintf(restEvaluation, applicationID) - body, _, err := iq.Post(requestEndpoint, bytes.NewBuffer(request)) + body, _, err := iq.Post(ctx, requestEndpoint, bytes.NewBuffer(request)) if err != nil { return nil, fmt.Errorf("components not evaluated: %v", err) } @@ -249,7 +249,7 @@ func EvaluateComponents(iq IQ, components []Component, applicationID string) (*E } getEvaluationResults := func() (*Evaluation, error) { - body, resp, e := iq.Get(results.ResultsURL) + body, resp, e := iq.Get(ctx, results.ResultsURL) if e != nil { if resp.StatusCode != http.StatusNotFound { return nil, fmt.Errorf("could not retrieve evaluation results: %v", err) @@ -280,3 +280,8 @@ func EvaluateComponents(iq IQ, components []Component, applicationID string) (*E } } } + +// EvaluateComponents evaluates the list of components +func EvaluateComponents(iq IQ, components []Component, applicationID string) (*Evaluation, error) { + return EvaluateComponentsContext(context.Background(), iq, components, applicationID) +} diff --git a/iq/evaluation_test.go b/iq/evaluation_test.go index 1bd65d2..d6ebbcb 100644 --- a/iq/evaluation_test.go +++ b/iq/evaluation_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -98,7 +99,7 @@ func TestEvaluateComponents(t *testing.T) { appID := "dummyAppId" - report, err := EvaluateComponents(iq, []Component{dummyComponent}, appID) + report, err := EvaluateComponentsContext(context.Background(), iq, []Component{dummyComponent}, appID) if err != nil { t.Error(err) } diff --git a/iq/organization.go b/iq/organization.go index 16de0f5..96c4abb 100644 --- a/iq/organization.go +++ b/iq/organization.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" ) @@ -35,9 +36,8 @@ type Organization struct { Tags []IQCategory `json:"tags,omitempty"` } -// GetOrganizationByName returns details on the named IQ organization -func GetOrganizationByName(iq IQ, organizationName string) (*Organization, error) { - orgs, err := GetAllOrganizations(iq) +func GetOrganizationByNameContext(ctx context.Context, iq IQ, organizationName string) (*Organization, error) { + orgs, err := GetAllOrganizationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("organization '%s' not found: %v", organizationName, err) } @@ -50,8 +50,12 @@ func GetOrganizationByName(iq IQ, organizationName string) (*Organization, error return nil, fmt.Errorf("organization '%s' not found", organizationName) } -// CreateOrganization creates an organization in IQ with the given name -func CreateOrganization(iq IQ, name string) (string, error) { +// GetOrganizationByName returns details on the named IQ organization +func GetOrganizationByName(iq IQ, organizationName string) (*Organization, error) { + return GetOrganizationByNameContext(context.Background(), iq, organizationName) +} + +func CreateOrganizationContext(ctx context.Context, iq IQ, name string) (string, error) { doError := func(err error) error { return fmt.Errorf("organization '%s' not created: %v", name, err) } @@ -61,7 +65,7 @@ func CreateOrganization(iq IQ, name string) (string, error) { return "", doError(err) } - body, _, err := iq.Post(restOrganization, bytes.NewBuffer(request)) + body, _, err := iq.Post(ctx, restOrganization, bytes.NewBuffer(request)) if err != nil { return "", doError(err) } @@ -74,13 +78,17 @@ func CreateOrganization(iq IQ, name string) (string, error) { return org.ID, nil } -// GetAllOrganizations returns a slice of all of the organizations in an IQ instance -func GetAllOrganizations(iq IQ) ([]Organization, error) { +// CreateOrganization creates an organization in IQ with the given name +func CreateOrganization(iq IQ, name string) (string, error) { + return CreateOrganizationContext(context.Background(), iq, name) +} + +func GetAllOrganizationsContext(ctx context.Context, iq IQ) ([]Organization, error) { doError := func(err error) error { return fmt.Errorf("organizations not found: %v", err) } - body, _, err := iq.Get(restOrganization) + body, _, err := iq.Get(ctx, restOrganization) if err != nil { return nil, doError(err) } @@ -92,3 +100,8 @@ func GetAllOrganizations(iq IQ) ([]Organization, error) { return resp.Organizations, nil } + +// GetAllOrganizations returns a slice of all of the organizations in an IQ instance +func GetAllOrganizations(iq IQ) ([]Organization, error) { + return GetAllOrganizationsContext(context.Background(), iq) +} diff --git a/iq/organization_test.go b/iq/organization_test.go index 1c5169d..2ce3ab8 100644 --- a/iq/organization_test.go +++ b/iq/organization_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -68,7 +69,7 @@ func TestGetOranizationByName(t *testing.T) { dummyOrgsIdx := 2 - org, err := GetOrganizationByName(iq, dummyOrgs[dummyOrgsIdx].Name) + org, err := GetOrganizationByNameContext(context.Background(), iq, dummyOrgs[dummyOrgsIdx].Name) if err != nil { t.Error(err) } @@ -90,12 +91,12 @@ func TestCreateOrganization(t *testing.T) { createdOrg := Organization{Name: "createdOrg"} var err error - createdOrg.ID, err = CreateOrganization(iq, createdOrg.Name) + createdOrg.ID, err = CreateOrganizationContext(context.Background(), iq, createdOrg.Name) if err != nil { t.Fatal(err) } - org, err := GetOrganizationByName(iq, createdOrg.Name) + org, err := GetOrganizationByNameContext(context.Background(), iq, createdOrg.Name) if err != nil { t.Fatal(err) } @@ -111,7 +112,7 @@ func TestGetAllOrganizations(t *testing.T) { iq, mock := organizationTestIQ(t) defer mock.Close() - organizations, err := GetAllOrganizations(iq) + organizations, err := GetAllOrganizationsContext(context.Background(), iq) if err != nil { panic(err) } @@ -125,7 +126,7 @@ func ExampleCreateOrganization() { panic(err) } - orgID, err := CreateOrganization(iq, "DatabaseTeam") + orgID, err := CreateOrganizationContext(context.Background(), iq, "DatabaseTeam") if err != nil { panic(err) } diff --git a/iq/policies.go b/iq/policies.go index 91821a1..364a1b7 100644 --- a/iq/policies.go +++ b/iq/policies.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" ) @@ -21,9 +22,8 @@ type policiesList struct { Policies []PolicyInfo `json:"policies"` } -// GetPolicies returns a list of all of the policies in IQ -func GetPolicies(iq IQ) ([]PolicyInfo, error) { - body, _, err := iq.Get(restPolicies) +func GetPoliciesContext(ctx context.Context, iq IQ) ([]PolicyInfo, error) { + body, _, err := iq.Get(ctx, restPolicies) if err != nil { return nil, fmt.Errorf("could not get list of policies: %v", err) } @@ -36,9 +36,13 @@ func GetPolicies(iq IQ) ([]PolicyInfo, error) { return resp.Policies, nil } -// GetPolicyInfoByName returns an information object for the named policy -func GetPolicyInfoByName(iq IQ, policyName string) (PolicyInfo, error) { - policies, err := GetPolicies(iq) +// GetPolicies returns a list of all of the policies in IQ +func GetPolicies(iq IQ) ([]PolicyInfo, error) { + return GetPoliciesContext(context.Background(), iq) +} + +func GetPolicyInfoByNameContext(ctx context.Context, iq IQ, policyName string) (PolicyInfo, error) { + policies, err := GetPoliciesContext(ctx, iq) if err != nil { return PolicyInfo{}, fmt.Errorf("did not find policy with name %s: %v", policyName, err) } @@ -51,3 +55,8 @@ func GetPolicyInfoByName(iq IQ, policyName string) (PolicyInfo, error) { return PolicyInfo{}, fmt.Errorf("did not find policy with name %s", policyName) } + +// GetPolicyInfoByName returns an information object for the named policy +func GetPolicyInfoByName(iq IQ, policyName string) (PolicyInfo, error) { + return GetPolicyInfoByNameContext(context.Background(), iq, policyName) +} diff --git a/iq/policies_test.go b/iq/policies_test.go index 7232eeb..1f7cece 100644 --- a/iq/policies_test.go +++ b/iq/policies_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -50,7 +51,7 @@ func TestGetPolicies(t *testing.T) { iq, mock := policiesTestIQ(t) defer mock.Close() - infos, err := GetPolicies(iq) + infos, err := GetPoliciesContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -72,7 +73,7 @@ func TestGetPolicyInfoByName(t *testing.T) { expected := dummyPolicyInfos[0] - info, err := GetPolicyInfoByName(iq, expected.Name) + info, err := GetPolicyInfoByNameContext(context.Background(), iq, expected.Name) if err != nil { t.Error(err) } diff --git a/iq/policyViolations.go b/iq/policyViolations.go index e8f68e3..86e3ae5 100644 --- a/iq/policyViolations.go +++ b/iq/policyViolations.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" ) @@ -18,9 +19,8 @@ type violationResponse struct { ApplicationViolations []ApplicationViolation `json:"applicationViolations"` } -// GetAllPolicyViolations returns all policy violations -func GetAllPolicyViolations(iq IQ) ([]ApplicationViolation, error) { - policyInfos, err := GetPolicies(iq) +func GetAllPolicyViolationsContext(ctx context.Context, iq IQ) ([]ApplicationViolation, error) { + policyInfos, err := GetPoliciesContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not get policies: %v", err) } @@ -33,7 +33,7 @@ func GetAllPolicyViolations(iq IQ) ([]ApplicationViolation, error) { endpoint.WriteString(i.ID) } - body, _, err := iq.Get(endpoint.String()) + body, _, err := iq.Get(ctx, endpoint.String()) if err != nil { return nil, fmt.Errorf("could not get policy violations: %v", err) } @@ -47,9 +47,13 @@ func GetAllPolicyViolations(iq IQ) ([]ApplicationViolation, error) { return resp.ApplicationViolations, nil } -// GetPolicyViolationsByName returns the policy violations by policy name -func GetPolicyViolationsByName(iq IQ, policyNames ...string) ([]ApplicationViolation, error) { - policies, err := GetPolicies(iq) +// GetAllPolicyViolations returns all policy violations +func GetAllPolicyViolations(iq IQ) ([]ApplicationViolation, error) { + return GetAllPolicyViolationsContext(context.Background(), iq) +} + +func GetPolicyViolationsByNameContext(ctx context.Context, iq IQ, policyNames ...string) ([]ApplicationViolation, error) { + policies, err := GetPoliciesContext(ctx, iq) if err != nil { return nil, fmt.Errorf("did not find policy: %v", err) } @@ -67,7 +71,7 @@ func GetPolicyViolationsByName(iq IQ, policyNames ...string) ([]ApplicationViola } } - body, _, err := iq.Get(endpoint.String()) + body, _, err := iq.Get(ctx, endpoint.String()) if err != nil { return nil, fmt.Errorf("could not get policy violations: %v", err) } @@ -80,3 +84,8 @@ func GetPolicyViolationsByName(iq IQ, policyNames ...string) ([]ApplicationViola return resp.ApplicationViolations, nil } + +// GetPolicyViolationsByName returns the policy violations by policy name +func GetPolicyViolationsByName(iq IQ, policyNames ...string) ([]ApplicationViolation, error) { + return GetPolicyViolationsByNameContext(context.Background(), iq, policyNames...) +} diff --git a/iq/policyViolations_test.go b/iq/policyViolations_test.go index 9fe65a8..4cc53b2 100644 --- a/iq/policyViolations_test.go +++ b/iq/policyViolations_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -80,7 +81,7 @@ func TestGetAllPolicyViolations(t *testing.T) { iq, mock := policyViolationsTestIQ(t) defer mock.Close() - violations, err := GetAllPolicyViolations(iq) + violations, err := GetAllPolicyViolationsContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -102,7 +103,7 @@ func TestGetPolicyViolationsByName(t *testing.T) { expected := dummyPolicyViolations[0] - violations, err := GetPolicyViolationsByName(iq, expected.PolicyViolations[0].PolicyName) + violations, err := GetPolicyViolationsByNameContext(context.Background(), iq, expected.PolicyViolations[0].PolicyName) if err != nil { t.Error(err) } diff --git a/iq/reportMetrics.go b/iq/reportMetrics.go index b671d29..009b3ec 100644 --- a/iq/reportMetrics.go +++ b/iq/reportMetrics.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -126,7 +127,7 @@ func (b *MetricsRequestBuilder) WithOrganization(v string) *MetricsRequestBuilde return b } -func (b *MetricsRequestBuilder) build(iq IQ) (req metricRequest, err error) { +func (b *MetricsRequestBuilder) build(ctx context.Context, iq IQ) (req metricRequest, err error) { // If timePeriod is MONTH - an ISO 8601 year and month without timezone. // If timePeriod is WEEK - an ISO 8601 week year and week (e.g. week of 29 December 2008 is "2009-W01") formatTime := func(t time.Time) string { @@ -163,7 +164,7 @@ func (b *MetricsRequestBuilder) build(iq IQ) (req metricRequest, err error) { if b.apps != nil { req.ApplicationIDS = make([]string, len(b.apps)) for i, a := range b.apps { - app, er := GetApplicationByPublicID(iq, a) + app, er := GetApplicationByPublicIDContext(ctx, iq, a) if er != nil { return req, fmt.Errorf("could not find application with public id %s: %v", a, er) } @@ -174,7 +175,7 @@ func (b *MetricsRequestBuilder) build(iq IQ) (req metricRequest, err error) { if b.orgs != nil { req.OrganizationIDS = make([]string, len(b.orgs)) for i, o := range b.orgs { - org, er := GetOrganizationByName(iq, o) + org, er := GetOrganizationByNameContext(ctx, iq, o) if er != nil { return req, fmt.Errorf("could not find organization with name %s: %v", o, er) } @@ -190,11 +191,10 @@ func NewMetricsRequestBuilder() *MetricsRequestBuilder { return new(MetricsRequestBuilder) } -// GenerateMetrics creates metrics from the given qualifiers -func GenerateMetrics(iq IQ, builder *MetricsRequestBuilder) ([]Metrics, error) { +func GenerateMetricsContext(ctx context.Context, iq IQ, builder *MetricsRequestBuilder) ([]Metrics, error) { // TODO: Accept header: application/json or text/csv - req, err := builder.build(iq) + req, err := builder.build(ctx, iq) if err != nil { return nil, fmt.Errorf("could not build request: %v", err) } @@ -204,7 +204,7 @@ func GenerateMetrics(iq IQ, builder *MetricsRequestBuilder) ([]Metrics, error) { return nil, fmt.Errorf("could not marshal request: %v", err) } - body, _, err := iq.Post(restMetrics, bytes.NewBuffer(buf)) + body, _, err := iq.Post(ctx, restMetrics, bytes.NewBuffer(buf)) if err != nil { return nil, fmt.Errorf("could not issue request to IQ: %v", err) } @@ -217,3 +217,8 @@ func GenerateMetrics(iq IQ, builder *MetricsRequestBuilder) ([]Metrics, error) { return metrics, nil } + +// GenerateMetrics creates metrics from the given qualifiers +func GenerateMetrics(iq IQ, builder *MetricsRequestBuilder) ([]Metrics, error) { + return GenerateMetricsContext(context.Background(), iq, builder) +} diff --git a/iq/reportMetrics_test.go b/iq/reportMetrics_test.go index c932304..61b8a1e 100644 --- a/iq/reportMetrics_test.go +++ b/iq/reportMetrics_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -197,7 +198,7 @@ func TestMetricsRequestBuilder(t *testing.T) { defer mock.Close() for _, test := range tests { - got, err := test.input.build(iq) + got, err := test.input.build(context.Background(), iq) if err != nil { t.Errorf("Unexpected error building metrics request: %v", err) t.Error("input", test.input) @@ -230,7 +231,7 @@ func TestGenerateMetrics(t *testing.T) { } for _, test := range tests { - got, err := GenerateMetrics(iq, test.input) + got, err := GenerateMetricsContext(context.Background(), iq, test.input) if err != nil { t.Error(err) } @@ -251,7 +252,7 @@ func ExampleGenerateMetrics() { reqLastYear := NewMetricsRequestBuilder().Monthly().StartingOn(time.Now().Add(-(24 * time.Hour) * 365)).WithApplication("WebGoat") - metrics, err := GenerateMetrics(iq, reqLastYear) + metrics, err := GenerateMetricsContext(context.Background(), iq, reqLastYear) if err != nil { panic(err) } diff --git a/iq/reports.go b/iq/reports.go index 277db76..f9dac15 100644 --- a/iq/reports.go +++ b/iq/reports.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "log" @@ -124,9 +125,8 @@ type Report struct { Raw ReportRaw `json:"rawReport"` } -// GetAllReportInfos returns all report infos -func GetAllReportInfos(iq IQ) ([]ReportInfo, error) { - body, _, err := iq.Get(restReports) +func GetAllReportInfosContext(ctx context.Context, iq IQ) ([]ReportInfo, error) { + body, _, err := iq.Get(ctx, restReports) if err != nil { return nil, fmt.Errorf("could not get report info: %v", err) } @@ -137,9 +137,13 @@ func GetAllReportInfos(iq IQ) ([]ReportInfo, error) { return infos, err } -// GetAllReports returns all policy and raw reports -func GetAllReports(iq IQ) ([]Report, error) { - infos, err := GetAllReportInfos(iq) +// GetAllReportInfos returns all report infos +func GetAllReportInfos(iq IQ) ([]ReportInfo, error) { + return GetAllReportInfosContext(context.Background(), iq) +} + +func GetAllReportsContext(ctx context.Context, iq IQ) ([]Report, error) { + infos, err := GetAllReportInfosContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not get report infos: %v", err) } @@ -147,8 +151,8 @@ func GetAllReports(iq IQ) ([]Report, error) { reports := make([]Report, 0) for _, info := range infos { - raw, _ := getRawReportByURL(iq, info.ReportDataURL) - policy, _ := getPolicyReportByURL(iq, strings.Replace(info.ReportDataURL, "/raw", "/policy", 1)) + raw, _ := getRawReportByURL(ctx, iq, info.ReportDataURL) + policy, _ := getPolicyReportByURL(ctx, iq, strings.Replace(info.ReportDataURL, "/raw", "/policy", 1)) raw.ReportInfo = info policy.ReportInfo = info @@ -163,15 +167,19 @@ func GetAllReports(iq IQ) ([]Report, error) { return reports, err } -// GetReportInfosByAppID returns report information by application public ID -func GetReportInfosByAppID(iq IQ, appID string) (infos []ReportInfo, err error) { - app, err := GetApplicationByPublicID(iq, appID) +// GetAllReports returns all policy and raw reports +func GetAllReports(iq IQ) ([]Report, error) { + return GetAllReportsContext(context.Background(), iq) +} + +func GetReportInfosByAppIDContext(ctx context.Context, iq IQ, appID string) (infos []ReportInfo, err error) { + app, err := GetApplicationByPublicIDContext(ctx, iq, appID) if err != nil { return nil, fmt.Errorf("could not get info for application: %v", err) } endpoint := fmt.Sprintf("%s/%s", restReports, app.ID) - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return nil, fmt.Errorf("could not get report infos: %v", err) } @@ -184,9 +192,13 @@ func GetReportInfosByAppID(iq IQ, appID string) (infos []ReportInfo, err error) return } -// GetReportInfoByAppIDStage returns report information by application public ID and stage -func GetReportInfoByAppIDStage(iq IQ, appID, stage string) (ReportInfo, error) { - if infos, err := GetReportInfosByAppID(iq, appID); err == nil { +// GetReportInfosByAppID returns report information by application public ID +func GetReportInfosByAppID(iq IQ, appID string) (infos []ReportInfo, err error) { + return GetReportInfosByAppIDContext(context.Background(), iq, appID) +} + +func GetReportInfoByAppIDStageContext(ctx context.Context, iq IQ, appID, stage string) (ReportInfo, error) { + if infos, err := GetReportInfosByAppIDContext(ctx, iq, appID); err == nil { for _, info := range infos { if info.Stage == stage { return info, nil @@ -197,8 +209,13 @@ func GetReportInfoByAppIDStage(iq IQ, appID, stage string) (ReportInfo, error) { return ReportInfo{}, fmt.Errorf("did not find report for '%s'", appID) } -func getRawReportByURL(iq IQ, URL string) (ReportRaw, error) { - body, resp, err := iq.Get(URL) +// GetReportInfoByAppIDStage returns report information by application public ID and stage +func GetReportInfoByAppIDStage(iq IQ, appID, stage string) (ReportInfo, error) { + return GetReportInfoByAppIDStageContext(context.Background(), iq, appID, stage) +} + +func getRawReportByURL(ctx context.Context, iq IQ, URL string) (ReportRaw, error) { + body, resp, err := iq.Get(ctx, URL) if err != nil { log.Printf("error: could not retrieve raw report: %v\n", err) dump, _ := httputil.DumpRequest(resp.Request, true) @@ -213,20 +230,19 @@ func getRawReportByURL(iq IQ, URL string) (ReportRaw, error) { return report, nil } -func getRawReportByAppReportID(iq IQ, appID, reportID string) (ReportRaw, error) { - return getRawReportByURL(iq, fmt.Sprintf(restReportsRaw, appID, reportID)) +func getRawReportByAppReportID(ctx context.Context, iq IQ, appID, reportID string) (ReportRaw, error) { + return getRawReportByURL(ctx, iq, fmt.Sprintf(restReportsRaw, appID, reportID)) } -// GetRawReportByAppID returns report information by application public ID -func GetRawReportByAppID(iq IQ, appID, stage string) (ReportRaw, error) { - infos, err := GetReportInfosByAppID(iq, appID) +func GetRawReportByAppIDContext(ctx context.Context, iq IQ, appID, stage string) (ReportRaw, error) { + infos, err := GetReportInfosByAppIDContext(ctx, iq, appID) if err != nil { return ReportRaw{}, fmt.Errorf("could not get report info for app '%s': %v", appID, err) } for _, info := range infos { if info.Stage == stage { - report, err := getRawReportByURL(iq, info.ReportDataURL) + report, err := getRawReportByURL(ctx, iq, info.ReportDataURL) report.ReportInfo = info return report, err } @@ -235,8 +251,13 @@ func GetRawReportByAppID(iq IQ, appID, stage string) (ReportRaw, error) { return ReportRaw{}, fmt.Errorf("could not find raw report for stage %s", stage) } -func getPolicyReportByURL(iq IQ, URL string) (ReportPolicy, error) { - body, _, err := iq.Get(URL) +// GetRawReportByAppID returns report information by application public ID +func GetRawReportByAppID(iq IQ, appID, stage string) (ReportRaw, error) { + return GetRawReportByAppIDContext(context.Background(), iq, appID, stage) +} + +func getPolicyReportByURL(ctx context.Context, iq IQ, URL string) (ReportPolicy, error) { + body, _, err := iq.Get(ctx, URL) if err != nil { return ReportPolicy{}, fmt.Errorf("could not get policy report at URL %s: %v", URL, err) } @@ -248,16 +269,15 @@ func getPolicyReportByURL(iq IQ, URL string) (ReportPolicy, error) { return report, nil } -// GetPolicyReportByAppID returns report information by application public ID -func GetPolicyReportByAppID(iq IQ, appID, stage string) (ReportPolicy, error) { - infos, err := GetReportInfosByAppID(iq, appID) +func GetPolicyReportByAppIDContext(ctx context.Context, iq IQ, appID, stage string) (ReportPolicy, error) { + infos, err := GetReportInfosByAppIDContext(ctx, iq, appID) if err != nil { return ReportPolicy{}, fmt.Errorf("could not get report info for app '%s': %v", appID, err) } for _, info := range infos { if info.Stage == stage { - report, err := getPolicyReportByURL(iq, strings.Replace(infos[0].ReportDataURL, "/raw", "/policy", 1)) + report, err := getPolicyReportByURL(ctx, iq, strings.Replace(infos[0].ReportDataURL, "/raw", "/policy", 1)) report.ReportInfo = info return report, err } @@ -266,14 +286,18 @@ func GetPolicyReportByAppID(iq IQ, appID, stage string) (ReportPolicy, error) { return ReportPolicy{}, fmt.Errorf("could not find policy report for stage %s", stage) } -// GetReportByAppID returns report information by application public ID -func GetReportByAppID(iq IQ, appID, stage string) (report Report, err error) { - report.Policy, err = GetPolicyReportByAppID(iq, appID, stage) +// GetPolicyReportByAppID returns report information by application public ID +func GetPolicyReportByAppID(iq IQ, appID, stage string) (ReportPolicy, error) { + return GetPolicyReportByAppIDContext(context.Background(), iq, appID, stage) +} + +func GetReportByAppIDContext(ctx context.Context, iq IQ, appID, stage string) (report Report, err error) { + report.Policy, err = GetPolicyReportByAppIDContext(ctx, iq, appID, stage) if err != nil { return report, fmt.Errorf("could not retrieve policy report: %v", err) } - report.Raw, err = GetRawReportByAppID(iq, appID, stage) + report.Raw, err = GetRawReportByAppIDContext(ctx, iq, appID, stage) if err != nil { return report, fmt.Errorf("could not retrieve raw report: %v", err) } @@ -281,19 +305,23 @@ func GetReportByAppID(iq IQ, appID, stage string) (report Report, err error) { return report, nil } -// GetReportByAppReportID returns raw and policy report information for a given report ID -func GetReportByAppReportID(iq IQ, appID, reportID string) (report Report, err error) { - report.Policy, err = getPolicyReportByURL(iq, fmt.Sprintf(restReportsPolicy, appID, reportID)) +// GetReportByAppID returns report information by application public ID +func GetReportByAppID(iq IQ, appID, stage string) (report Report, err error) { + return GetReportByAppIDContext(context.Background(), iq, appID, stage) +} + +func GetReportByAppReportIDContext(ctx context.Context, iq IQ, appID, reportID string) (report Report, err error) { + report.Policy, err = getPolicyReportByURL(ctx, iq, fmt.Sprintf(restReportsPolicy, appID, reportID)) if err != nil { return report, fmt.Errorf("could not retrieve policy report: %v", err) } - report.Raw, err = getRawReportByURL(iq, fmt.Sprintf(restReportsRaw, appID, reportID)) + report.Raw, err = getRawReportByURL(ctx, iq, fmt.Sprintf(restReportsRaw, appID, reportID)) if err != nil { return report, fmt.Errorf("could not retrieve raw report: %v", err) } - infos, err := GetReportInfosByAppID(iq, appID) + infos, err := GetReportInfosByAppIDContext(ctx, iq, appID) if err != nil { return report, fmt.Errorf("could not retrieve report infos: %v", err) } @@ -307,16 +335,20 @@ func GetReportByAppReportID(iq IQ, appID, reportID string) (report Report, err e return report, nil } -// GetReportInfosByOrganization returns report information by organization name -func GetReportInfosByOrganization(iq IQ, organizationName string) (infos []ReportInfo, err error) { - apps, err := GetApplicationsByOrganization(iq, organizationName) +// GetReportByAppReportID returns raw and policy report information for a given report ID +func GetReportByAppReportID(iq IQ, appID, reportID string) (report Report, err error) { + return GetReportByAppReportIDContext(context.Background(), iq, appID, reportID) +} + +func GetReportInfosByOrganizationContext(ctx context.Context, iq IQ, organizationName string) (infos []ReportInfo, err error) { + apps, err := GetApplicationsByOrganizationContext(ctx, iq, organizationName) if err != nil { return nil, fmt.Errorf("could not get applications for organization '%s': %v", organizationName, err) } infos = make([]ReportInfo, 0) for _, app := range apps { - if appInfos, err := GetReportInfosByAppID(iq, app.PublicID); err == nil { + if appInfos, err := GetReportInfosByAppIDContext(ctx, iq, app.PublicID); err == nil { infos = append(infos, appInfos...) } } @@ -324,9 +356,13 @@ func GetReportInfosByOrganization(iq IQ, organizationName string) (infos []Repor return infos, nil } -// GetReportsByOrganization returns all reports for an given organization -func GetReportsByOrganization(iq IQ, organizationName string) (reports []Report, err error) { - apps, err := GetApplicationsByOrganization(iq, organizationName) +// GetReportInfosByOrganization returns report information by organization name +func GetReportInfosByOrganization(iq IQ, organizationName string) (infos []ReportInfo, err error) { + return GetReportInfosByOrganizationContext(context.Background(), iq, organizationName) +} + +func GetReportsByOrganizationContext(ctx context.Context, iq IQ, organizationName string) (reports []Report, err error) { + apps, err := GetApplicationsByOrganizationContext(ctx, iq, organizationName) if err != nil { return nil, fmt.Errorf("could not get applications for organization '%s': %v", organizationName, err) } @@ -336,7 +372,7 @@ func GetReportsByOrganization(iq IQ, organizationName string) (reports []Report, reports = make([]Report, 0) for _, app := range apps { for _, s := range stages { - if appReport, err := GetReportByAppID(iq, app.PublicID, string(s)); err == nil { + if appReport, err := GetReportByAppIDContext(ctx, iq, app.PublicID, string(s)); err == nil { reports = append(reports, appReport) } } @@ -345,6 +381,11 @@ func GetReportsByOrganization(iq IQ, organizationName string) (reports []Report, return reports, nil } +// GetReportsByOrganization returns all reports for an given organization +func GetReportsByOrganization(iq IQ, organizationName string) (reports []Report, err error) { + return GetReportsByOrganizationContext(context.Background(), iq, organizationName) +} + // ReportDiff encapsulates the differences between reports type ReportDiff struct { Reports []Report `json:"reports"` @@ -352,16 +393,15 @@ type ReportDiff struct { Fixed []PolicyReportComponent `json:"fixed,omitempty"` } -// ReportsDiff returns a structure describing various differences between two reports -func ReportsDiff(iq IQ, appID, report1ID, report2ID string) (ReportDiff, error) { +func ReportsDiffContext(ctx context.Context, iq IQ, appID, report1ID, report2ID string) (ReportDiff, error) { var ( report1, report2 Report err error ) - report1, err = GetReportByAppReportID(iq, appID, report1ID) + report1, err = GetReportByAppReportIDContext(ctx, iq, appID, report1ID) if err == nil { - report2, err = GetReportByAppReportID(iq, appID, report2ID) + report2, err = GetReportByAppReportIDContext(ctx, iq, appID, report2ID) } if err != nil { return ReportDiff{}, fmt.Errorf("could not retrieve raw reports: %v", err) @@ -414,3 +454,8 @@ func ReportsDiff(iq IQ, appID, report1ID, report2ID string) (ReportDiff, error) return diff(iq, report2, report1) } + +// ReportsDiff returns a structure describing various differences between two reports +func ReportsDiff(iq IQ, appID, report1ID, report2ID string) (ReportDiff, error) { + return ReportsDiffContext(context.Background(), iq, appID, report1ID, report2ID) +} diff --git a/iq/reports_test.go b/iq/reports_test.go index be1841f..00b81c0 100644 --- a/iq/reports_test.go +++ b/iq/reports_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -162,7 +163,7 @@ func TestGetAllReportInfos(t *testing.T) { iq, mock := reportsTestIQ(t) defer mock.Close() - infos, err := GetAllReportInfos(iq) + infos, err := GetAllReportInfosContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -184,7 +185,7 @@ func TestGetReportInfosByAppID(t *testing.T) { testIdx := 0 - infos, err := GetReportInfosByAppID(iq, dummyApps[testIdx].PublicID) + infos, err := GetReportInfosByAppIDContext(context.Background(), iq, dummyApps[testIdx].PublicID) if err != nil { t.Error(err) } @@ -204,7 +205,7 @@ func Test_getRawReportByAppReportID(t *testing.T) { testIdx := 0 - report, err := getRawReportByAppReportID(iq, dummyApps[testIdx].PublicID, fmt.Sprintf("%d", testIdx)) + report, err := getRawReportByAppReportID(context.Background(), iq, dummyApps[testIdx].PublicID, fmt.Sprintf("%d", testIdx)) if err != nil { t.Fatal(err) } @@ -221,7 +222,7 @@ func TestGetRawReportByAppID(t *testing.T) { testIdx := 0 - report, err := GetRawReportByAppID(iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) + report, err := GetRawReportByAppIDContext(context.Background(), iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) if err != nil { t.Fatal(err) } @@ -238,7 +239,7 @@ func TestGetPolicyReportByAppID(t *testing.T) { testIdx := 0 - report, err := GetPolicyReportByAppID(iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) + report, err := GetPolicyReportByAppIDContext(context.Background(), iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) if err != nil { t.Fatal(err) } @@ -255,7 +256,7 @@ func TestGetReportByAppID(t *testing.T) { testIdx := 0 - report, err := GetReportByAppID(iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) + report, err := GetReportByAppIDContext(context.Background(), iq, dummyApps[testIdx].PublicID, dummyReportInfos[testIdx].Stage) if err != nil { t.Fatal(err) } @@ -294,7 +295,7 @@ func TestGetReportInfosByOrganization(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotInfos, err := GetReportInfosByOrganization(tt.args.iq, tt.args.organizationName) + gotInfos, err := GetReportInfosByOrganizationContext(context.Background(), tt.args.iq, tt.args.organizationName) if (err != nil) != tt.wantErr { t.Errorf("GetReportInfosByOrganization() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/iq/roleMemberships.go b/iq/roleMemberships.go index c6c76d1..d2c485c 100644 --- a/iq/roleMemberships.go +++ b/iq/roleMemberships.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -53,9 +54,9 @@ type Member struct { UserOrGroupName string `json:"userOrGroupName"` } -func hasRev70API(iq IQ) bool { +func hasRev70API(ctx context.Context, iq IQ) bool { api := fmt.Sprintf(restRoleMembersOrgGet, RootOrganization) - request, _ := iq.NewRequest("HEAD", api, nil) + request, _ := iq.NewRequest(ctx, "HEAD", api, nil) _, resp, _ := iq.Do(request) return resp.StatusCode != http.StatusNotFound } @@ -78,15 +79,15 @@ func newMappings(roleID, memberType, memberName string) memberMappings { } } -func organizationAuthorizationsByID(iq IQ, orgID string) ([]MemberMapping, error) { +func organizationAuthorizationsByID(ctx context.Context, iq IQ, orgID string) ([]MemberMapping, error) { var endpoint string - if hasRev70API(iq) { + if hasRev70API(ctx, iq) { endpoint = fmt.Sprintf(restRoleMembersOrgGet, orgID) } else { endpoint = fmt.Sprintf(restRoleMembersOrgDeprecated, orgID) } - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return nil, fmt.Errorf("could not retrieve role mapping for organization %s: %v", orgID, err) } @@ -97,15 +98,15 @@ func organizationAuthorizationsByID(iq IQ, orgID string) ([]MemberMapping, error return mappings.MemberMappings, err } -func organizationAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, error) { - orgs, err := GetAllOrganizations(iq) +func organizationAuthorizationsByRoleID(ctx context.Context, iq IQ, roleID string) ([]MemberMapping, error) { + orgs, err := GetAllOrganizationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not find organizations: %v", err) } mappings := make([]MemberMapping, 0) for _, org := range orgs { - orgMaps, _ := organizationAuthorizationsByID(iq, org.ID) + orgMaps, _ := organizationAuthorizationsByID(ctx, iq, org.ID) for _, m := range orgMaps { if m.RoleID == roleID { mappings = append(mappings, m) @@ -116,40 +117,48 @@ func organizationAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, return mappings, nil } -// OrganizationAuthorizations returns the member mappings of an organization -func OrganizationAuthorizations(iq IQ, name string) ([]MemberMapping, error) { - org, err := GetOrganizationByName(iq, name) +func OrganizationAuthorizationsContext(ctx context.Context, iq IQ, name string) ([]MemberMapping, error) { + org, err := GetOrganizationByNameContext(ctx, iq, name) if err != nil { return nil, fmt.Errorf("could not find organization with name %s: %v", name, err) } - return organizationAuthorizationsByID(iq, org.ID) + return organizationAuthorizationsByID(ctx, iq, org.ID) } -// OrganizationAuthorizationsByRole returns the member mappings of all organizations which match the given role -func OrganizationAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { - role, err := RoleByName(iq, roleName) +// OrganizationAuthorizations returns the member mappings of an organization +func OrganizationAuthorizations(iq IQ, name string) ([]MemberMapping, error) { + return OrganizationAuthorizationsContext(context.Background(), iq, name) +} + +func OrganizationAuthorizationsByRoleContext(ctx context.Context, iq IQ, roleName string) ([]MemberMapping, error) { + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return nil, fmt.Errorf("could not find role with name %s: %v", roleName, err) } - return organizationAuthorizationsByRoleID(iq, role.ID) + return organizationAuthorizationsByRoleID(ctx, iq, role.ID) } -func setOrganizationAuth(iq IQ, name, roleName, member, memberType string) error { - org, err := GetOrganizationByName(iq, name) +// OrganizationAuthorizationsByRole returns the member mappings of all organizations which match the given role +func OrganizationAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { + return OrganizationAuthorizationsByRoleContext(context.Background(), iq, roleName) +} + +func setOrganizationAuth(ctx context.Context, iq IQ, name, roleName, member, memberType string) error { + org, err := GetOrganizationByNameContext(ctx, iq, name) if err != nil { return fmt.Errorf("could not find organization with name %s: %v", name, err) } - role, err := RoleByName(iq, roleName) + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } var endpoint string var payload io.Reader - if hasRev70API(iq) { + if hasRev70API(ctx, iq) { switch memberType { case MemberTypeUser: endpoint = fmt.Sprintf(restRoleMembersOrgUser, org.ID, role.ID, member) @@ -158,7 +167,7 @@ func setOrganizationAuth(iq IQ, name, roleName, member, memberType string) error } } else { endpoint = fmt.Sprintf(restRoleMembersOrgDeprecated, org.ID) - current, err := OrganizationAuthorizations(iq, name) + current, err := OrganizationAuthorizationsContext(ctx, iq, name) if err != nil && current == nil { current = make([]MemberMapping, 0) } @@ -171,7 +180,7 @@ func setOrganizationAuth(iq IQ, name, roleName, member, memberType string) error payload = bytes.NewBuffer(buf) } - _, _, err = iq.Put(endpoint, payload) + _, _, err = iq.Put(ctx, endpoint, payload) if err != nil { return fmt.Errorf("could not update organization role mapping: %v", err) } @@ -179,25 +188,33 @@ func setOrganizationAuth(iq IQ, name, roleName, member, memberType string) error return nil } +func SetOrganizationUserContext(ctx context.Context, iq IQ, name, roleName, user string) error { + return setOrganizationAuth(ctx, iq, name, roleName, user, MemberTypeUser) +} + // SetOrganizationUser sets the role and user that can have access to an organization func SetOrganizationUser(iq IQ, name, roleName, user string) error { - return setOrganizationAuth(iq, name, roleName, user, MemberTypeUser) + return SetOrganizationUserContext(context.Background(), iq, name, roleName, user) +} + +func SetOrganizationGroupContext(ctx context.Context, iq IQ, name, roleName, group string) error { + return setOrganizationAuth(ctx, iq, name, roleName, group, MemberTypeGroup) } // SetOrganizationGroup sets the role and group that can have access to an organization func SetOrganizationGroup(iq IQ, name, roleName, group string) error { - return setOrganizationAuth(iq, name, roleName, group, MemberTypeGroup) + return SetOrganizationGroupContext(context.Background(), iq, name, roleName, group) } -func applicationAuthorizationsByID(iq IQ, appID string) ([]MemberMapping, error) { +func applicationAuthorizationsByID(ctx context.Context, iq IQ, appID string) ([]MemberMapping, error) { var endpoint string - if hasRev70API(iq) { + if hasRev70API(ctx, iq) { endpoint = fmt.Sprintf(restRoleMembersAppGet, appID) } else { endpoint = fmt.Sprintf(restRoleMembersAppDeprecated, appID) } - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return nil, fmt.Errorf("could not retrieve role mapping for application %s: %v", appID, err) } @@ -208,15 +225,15 @@ func applicationAuthorizationsByID(iq IQ, appID string) ([]MemberMapping, error) return mappings.MemberMappings, err } -func applicationAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, error) { - apps, err := GetAllApplications(iq) +func applicationAuthorizationsByRoleID(ctx context.Context, iq IQ, roleID string) ([]MemberMapping, error) { + apps, err := GetAllApplicationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not find applications: %v", err) } mappings := make([]MemberMapping, 0) for _, app := range apps { - appMaps, _ := applicationAuthorizationsByID(iq, app.ID) + appMaps, _ := applicationAuthorizationsByID(ctx, iq, app.ID) for _, m := range appMaps { if m.RoleID == roleID { mappings = append(mappings, m) @@ -227,40 +244,48 @@ func applicationAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, e return mappings, nil } -// ApplicationAuthorizations returns the member mappings of an application -func ApplicationAuthorizations(iq IQ, name string) ([]MemberMapping, error) { - app, err := GetApplicationByPublicID(iq, name) +func ApplicationAuthorizationsContext(ctx context.Context, iq IQ, name string) ([]MemberMapping, error) { + app, err := GetApplicationByPublicIDContext(ctx, iq, name) if err != nil { return nil, fmt.Errorf("could not find application with name %s: %v", name, err) } - return applicationAuthorizationsByID(iq, app.ID) + return applicationAuthorizationsByID(ctx, iq, app.ID) } -// ApplicationAuthorizationsByRole returns the member mappings of all applications which match the given role -func ApplicationAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { - role, err := RoleByName(iq, roleName) +// ApplicationAuthorizations returns the member mappings of an application +func ApplicationAuthorizations(iq IQ, name string) ([]MemberMapping, error) { + return ApplicationAuthorizationsContext(context.Background(), iq, name) +} + +func ApplicationAuthorizationsByRoleContext(ctx context.Context, iq IQ, roleName string) ([]MemberMapping, error) { + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return nil, fmt.Errorf("could not find role with name %s: %v", roleName, err) } - return applicationAuthorizationsByRoleID(iq, role.ID) + return applicationAuthorizationsByRoleID(ctx, iq, role.ID) } -func setApplicationAuth(iq IQ, name, roleName, member, memberType string) error { - app, err := GetApplicationByPublicID(iq, name) +// ApplicationAuthorizationsByRole returns the member mappings of all applications which match the given role +func ApplicationAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { + return ApplicationAuthorizationsByRoleContext(context.Background(), iq, roleName) +} + +func setApplicationAuth(ctx context.Context, iq IQ, name, roleName, member, memberType string) error { + app, err := GetApplicationByPublicIDContext(ctx, iq, name) if err != nil { return fmt.Errorf("could not find application with name %s: %v", name, err) } - role, err := RoleByName(iq, roleName) + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } var endpoint string var payload io.Reader - if hasRev70API(iq) { + if hasRev70API(ctx, iq) { switch memberType { case MemberTypeUser: endpoint = fmt.Sprintf(restRoleMembersAppUser, app.ID, role.ID, member) @@ -269,7 +294,7 @@ func setApplicationAuth(iq IQ, name, roleName, member, memberType string) error } } else { endpoint = fmt.Sprintf(restRoleMembersAppDeprecated, app.ID) - current, err := ApplicationAuthorizations(iq, name) + current, err := ApplicationAuthorizationsContext(ctx, iq, name) if err != nil && current == nil { current = make([]MemberMapping, 0) } @@ -282,7 +307,7 @@ func setApplicationAuth(iq IQ, name, roleName, member, memberType string) error payload = bytes.NewBuffer(buf) } - _, _, err = iq.Put(endpoint, payload) + _, _, err = iq.Put(ctx, endpoint, payload) if err != nil { return fmt.Errorf("could not update organization role mapping: %v", err) } @@ -290,19 +315,27 @@ func setApplicationAuth(iq IQ, name, roleName, member, memberType string) error return nil } +func SetApplicationUserContext(ctx context.Context, iq IQ, name, roleName, user string) error { + return setApplicationAuth(ctx, iq, name, roleName, user, MemberTypeUser) +} + // SetApplicationUser sets the role and user that can have access to an application func SetApplicationUser(iq IQ, name, roleName, user string) error { - return setApplicationAuth(iq, name, roleName, user, MemberTypeUser) + return SetApplicationUserContext(context.Background(), iq, name, roleName, user) +} + +func SetApplicationGroupContext(ctx context.Context, iq IQ, name, roleName, group string) error { + return setApplicationAuth(ctx, iq, name, roleName, group, MemberTypeGroup) } // SetApplicationGroup sets the role and group that can have access to an application func SetApplicationGroup(iq IQ, name, roleName, group string) error { - return setApplicationAuth(iq, name, roleName, group, MemberTypeGroup) + return SetApplicationGroupContext(context.Background(), iq, name, roleName, group) } -func revokeLT70(iq IQ, authType, authName, roleName, memberType, memberName string) error { +func revokeLT70(ctx context.Context, iq IQ, authType, authName, roleName, memberType, memberName string) error { var err error - role, err := RoleByName(iq, roleName) + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } @@ -313,18 +346,18 @@ func revokeLT70(iq IQ, authType, authName, roleName, memberType, memberName stri ) switch authType { case "organization": - org, err := GetOrganizationByName(iq, authName) + org, err := GetOrganizationByNameContext(ctx, iq, authName) if err == nil { authID = org.ID baseEndpoint = restRoleMembersOrgDeprecated - mapping, err = OrganizationAuthorizations(iq, authName) + mapping, err = OrganizationAuthorizationsContext(ctx, iq, authName) } case "application": - app, err := GetApplicationByPublicID(iq, authName) + app, err := GetApplicationByPublicIDContext(ctx, iq, authName) if err == nil { authID = app.ID baseEndpoint = restRoleMembersAppDeprecated - mapping, err = ApplicationAuthorizations(iq, authName) + mapping, err = ApplicationAuthorizationsContext(ctx, iq, authName) } } if err != nil && mapping != nil { @@ -349,7 +382,7 @@ func revokeLT70(iq IQ, authType, authName, roleName, memberType, memberName stri } endpoint := fmt.Sprintf(baseEndpoint, authID) - _, _, err = iq.Put(endpoint, bytes.NewBuffer(buf)) + _, _, err = iq.Put(ctx, endpoint, bytes.NewBuffer(buf)) if err != nil { return fmt.Errorf("could not remove role mapping: %v", err) } @@ -357,8 +390,8 @@ func revokeLT70(iq IQ, authType, authName, roleName, memberType, memberName stri return nil } -func revoke(iq IQ, authType, authName, roleName, memberType, memberName string) error { - role, err := RoleByName(iq, roleName) +func revoke(ctx context.Context, iq IQ, authType, authName, roleName, memberType, memberName string) error { + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } @@ -368,7 +401,7 @@ func revoke(iq IQ, authType, authName, roleName, memberType, memberName string) ) switch authType { case "organization": - org, err := GetOrganizationByName(iq, authName) + org, err := GetOrganizationByNameContext(ctx, iq, authName) if err == nil { authID = org.ID switch memberType { @@ -379,7 +412,7 @@ func revoke(iq IQ, authType, authName, roleName, memberType, memberName string) } } case "application": - app, err := GetApplicationByPublicID(iq, authName) + app, err := GetApplicationByPublicIDContext(ctx, iq, authName) if err == nil { authID = app.ID switch memberType { @@ -392,48 +425,64 @@ func revoke(iq IQ, authType, authName, roleName, memberType, memberName string) } endpoint := fmt.Sprintf(baseEndpoint, authID, role.ID, memberName) - _, err = iq.Del(endpoint) + _, err = iq.Del(ctx, endpoint) return err } -// RevokeOrganizationUser removes a user and role from the named organization +func RevokeOrganizationUserContext(ctx context.Context, iq IQ, name, roleName, user string) error { + if !hasRev70API(ctx, iq) { + return revokeLT70(ctx, iq, "organization", name, roleName, MemberTypeUser, user) + } + return revoke(ctx, iq, "organization", name, roleName, MemberTypeUser, user) +} + func RevokeOrganizationUser(iq IQ, name, roleName, user string) error { - if !hasRev70API(iq) { - return revokeLT70(iq, "organization", name, roleName, MemberTypeUser, user) + return RevokeOrganizationUserContext(context.Background(), iq, name, roleName, user) +} + +func RevokeOrganizationGroupContext(ctx context.Context, iq IQ, name, roleName, group string) error { + if !hasRev70API(ctx, iq) { + return revokeLT70(ctx, iq, "organization", name, roleName, MemberTypeGroup, group) } - return revoke(iq, "organization", name, roleName, MemberTypeUser, user) + return revoke(ctx, iq, "organization", name, roleName, MemberTypeGroup, group) } // RevokeOrganizationGroup removes a group and role from the named organization func RevokeOrganizationGroup(iq IQ, name, roleName, group string) error { - if !hasRev70API(iq) { - return revokeLT70(iq, "organization", name, roleName, MemberTypeGroup, group) + return RevokeOrganizationGroupContext(context.Background(), iq, name, roleName, group) +} + +func RevokeApplicationUserContext(ctx context.Context, iq IQ, name, roleName, user string) error { + if !hasRev70API(ctx, iq) { + return revokeLT70(ctx, iq, "application", name, roleName, MemberTypeUser, user) } - return revoke(iq, "organization", name, roleName, MemberTypeGroup, group) + return revoke(ctx, iq, "application", name, roleName, MemberTypeUser, user) } // RevokeApplicationUser removes a user and role from the named application func RevokeApplicationUser(iq IQ, name, roleName, user string) error { - if !hasRev70API(iq) { - return revokeLT70(iq, "application", name, roleName, MemberTypeUser, user) + return RevokeApplicationUserContext(context.Background(), iq, name, roleName, user) + +} + +func RevokeApplicationGroupContext(ctx context.Context, iq IQ, name, roleName, group string) error { + if !hasRev70API(ctx, iq) { + return revokeLT70(ctx, iq, "application", name, roleName, MemberTypeGroup, group) } - return revoke(iq, "application", name, roleName, MemberTypeUser, user) + return revoke(ctx, iq, "application", name, roleName, MemberTypeGroup, group) } // RevokeApplicationGroup removes a group and role from the named application func RevokeApplicationGroup(iq IQ, name, roleName, group string) error { - if !hasRev70API(iq) { - return revokeLT70(iq, "application", name, roleName, MemberTypeGroup, group) - } - return revoke(iq, "application", name, roleName, MemberTypeGroup, group) + return RevokeApplicationGroupContext(context.Background(), iq, name, roleName, group) } -func repositoriesAuth(iq IQ, method, roleName, memberType, member string) error { - if !hasRev70API(iq) { +func repositoriesAuth(ctx context.Context, iq IQ, method, roleName, memberType, member string) error { + if !hasRev70API(ctx, iq) { return fmt.Errorf("did not find revision 70 API") } - role, err := RoleByName(iq, roleName) + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } @@ -448,9 +497,9 @@ func repositoriesAuth(iq IQ, method, roleName, memberType, member string) error switch method { case http.MethodPut: - _, _, err = iq.Put(endpoint, nil) + _, _, err = iq.Put(ctx, endpoint, nil) case http.MethodDelete: - _, err = iq.Del(endpoint) + _, err = iq.Del(ctx, endpoint) } if err != nil { return fmt.Errorf("could not affect repositories role mapping: %v", err) @@ -459,8 +508,8 @@ func repositoriesAuth(iq IQ, method, roleName, memberType, member string) error return nil } -func repositoriesAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, error) { - auths, err := RepositoriesAuthorizations(iq) +func repositoriesAuthorizationsByRoleID(ctx context.Context, iq IQ, roleID string) ([]MemberMapping, error) { + auths, err := RepositoriesAuthorizationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("could not find authorization mappings for repositories: %v", err) } @@ -475,9 +524,8 @@ func repositoriesAuthorizationsByRoleID(iq IQ, roleID string) ([]MemberMapping, return mappings, nil } -// RepositoriesAuthorizations returns the member mappings of all repositories -func RepositoriesAuthorizations(iq IQ) ([]MemberMapping, error) { - body, _, err := iq.Get(restRoleMembersReposGet) +func RepositoriesAuthorizationsContext(ctx context.Context, iq IQ) ([]MemberMapping, error) { + body, _, err := iq.Get(ctx, restRoleMembersReposGet) if err != nil { return nil, fmt.Errorf("could not get repositories mappings: %v", err) } @@ -491,49 +539,74 @@ func RepositoriesAuthorizations(iq IQ) ([]MemberMapping, error) { return mappings.MemberMappings, nil } -// RepositoriesAuthorizationsByRole returns the member mappings of all repositories which match the given role -func RepositoriesAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { - role, err := RoleByName(iq, roleName) +// RepositoriesAuthorizations returns the member mappings of all repositories +func RepositoriesAuthorizations(iq IQ) ([]MemberMapping, error) { + return RepositoriesAuthorizationsContext(context.Background(), iq) +} + +func RepositoriesAuthorizationsByRoleContext(ctx context.Context, iq IQ, roleName string) ([]MemberMapping, error) { + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return nil, fmt.Errorf("could not find role with name %s: %v", roleName, err) } - return repositoriesAuthorizationsByRoleID(iq, role.ID) + return repositoriesAuthorizationsByRoleID(ctx, iq, role.ID) +} + +// RepositoriesAuthorizationsByRole returns the member mappings of all repositories which match the given role +func RepositoriesAuthorizationsByRole(iq IQ, roleName string) ([]MemberMapping, error) { + return RepositoriesAuthorizationsByRoleContext(context.Background(), iq, roleName) +} + +func SetRepositoriesUserContext(ctx context.Context, iq IQ, roleName, user string) error { + return repositoriesAuth(ctx, iq, http.MethodPut, roleName, MemberTypeUser, user) } // SetRepositoriesUser sets the role and user that can have access to the repositories func SetRepositoriesUser(iq IQ, roleName, user string) error { - return repositoriesAuth(iq, http.MethodPut, roleName, MemberTypeUser, user) + return SetRepositoriesUserContext(context.Background(), iq, roleName, user) +} + +func SetRepositoriesGroupContext(ctx context.Context, iq IQ, roleName, group string) error { + return repositoriesAuth(ctx, iq, http.MethodPut, roleName, MemberTypeGroup, group) } // SetRepositoriesGroup sets the role and group that can have access to the repositories func SetRepositoriesGroup(iq IQ, roleName, group string) error { - return repositoriesAuth(iq, http.MethodPut, roleName, MemberTypeGroup, group) + return SetRepositoriesGroupContext(context.Background(), iq, roleName, group) +} + +func RevokeRepositoriesUserContext(ctx context.Context, iq IQ, roleName, user string) error { + return repositoriesAuth(ctx, iq, http.MethodDelete, roleName, MemberTypeUser, user) } // RevokeRepositoriesUser revoke the role and user that can have access to the repositories func RevokeRepositoriesUser(iq IQ, roleName, user string) error { - return repositoriesAuth(iq, http.MethodDelete, roleName, MemberTypeUser, user) + return RevokeRepositoriesUserContext(context.Background(), iq, roleName, user) +} + +func RevokeRepositoriesGroupContext(ctx context.Context, iq IQ, roleName, group string) error { + return repositoriesAuth(ctx, iq, http.MethodDelete, roleName, MemberTypeGroup, group) } // RevokeRepositoriesGroup revoke the role and group that can have access to the repositories func RevokeRepositoriesGroup(iq IQ, roleName, group string) error { - return repositoriesAuth(iq, http.MethodDelete, roleName, MemberTypeGroup, group) + return RevokeRepositoriesGroupContext(context.Background(), iq, roleName, group) } -func membersByRoleID(iq IQ, roleID string) ([]MemberMapping, error) { +func membersByRoleID(ctx context.Context, iq IQ, roleID string) ([]MemberMapping, error) { members := make([]MemberMapping, 0) - if m, err := organizationAuthorizationsByRoleID(iq, roleID); err == nil && len(m) > 0 { + if m, err := organizationAuthorizationsByRoleID(ctx, iq, roleID); err == nil && len(m) > 0 { members = append(members, m...) } - if m, err := applicationAuthorizationsByRoleID(iq, roleID); err == nil && len(m) > 0 { + if m, err := applicationAuthorizationsByRoleID(ctx, iq, roleID); err == nil && len(m) > 0 { members = append(members, m...) } - if hasRev70API(iq) { - if m, err := repositoriesAuthorizationsByRoleID(iq, roleID); err == nil && len(m) > 0 { + if hasRev70API(ctx, iq) { + if m, err := repositoriesAuthorizationsByRoleID(ctx, iq, roleID); err == nil && len(m) > 0 { members = append(members, m...) } } @@ -541,18 +614,21 @@ func membersByRoleID(iq IQ, roleID string) ([]MemberMapping, error) { return members, nil } -// MembersByRole returns all users and groups by role name -func MembersByRole(iq IQ, roleName string) ([]MemberMapping, error) { - role, err := RoleByName(iq, roleName) +func MembersByRoleContext(ctx context.Context, iq IQ, roleName string) ([]MemberMapping, error) { + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return nil, fmt.Errorf("could not find role with name %s: %v", roleName, err) } - return membersByRoleID(iq, role.ID) + return membersByRoleID(ctx, iq, role.ID) } -// GlobalAuthorizations returns all of the users and roles who have the administrator role across all of IQ -func GlobalAuthorizations(iq IQ) ([]MemberMapping, error) { - body, _, err := iq.Get(restRoleMembersGlobalGet) +// MembersByRole returns all users and groups by role name +func MembersByRole(iq IQ, roleName string) ([]MemberMapping, error) { + return MembersByRoleContext(context.Background(), iq, roleName) +} + +func GlobalAuthorizationsContext(ctx context.Context, iq IQ) ([]MemberMapping, error) { + body, _, err := iq.Get(ctx, restRoleMembersGlobalGet) if err != nil { return nil, fmt.Errorf("could not get global members: %v", err) } @@ -566,12 +642,17 @@ func GlobalAuthorizations(iq IQ) ([]MemberMapping, error) { return mappings.MemberMappings, nil } -func globalAuth(iq IQ, method, roleName, memberType, member string) error { - if !hasRev70API(iq) { +// GlobalAuthorizations returns all of the users and roles who have the administrator role across all of IQ +func GlobalAuthorizations(iq IQ) ([]MemberMapping, error) { + return GlobalAuthorizationsContext(context.Background(), iq) +} + +func globalAuth(ctx context.Context, iq IQ, method, roleName, memberType, member string) error { + if !hasRev70API(ctx, iq) { return fmt.Errorf("did not find revision 70 API") } - role, err := RoleByName(iq, roleName) + role, err := RoleByNameContext(ctx, iq, roleName) if err != nil { return fmt.Errorf("could not find role with name %s: %v", roleName, err) } @@ -586,9 +667,9 @@ func globalAuth(iq IQ, method, roleName, memberType, member string) error { switch method { case http.MethodPut: - _, _, err = iq.Put(endpoint, nil) + _, _, err = iq.Put(ctx, endpoint, nil) case http.MethodDelete: - _, err = iq.Del(endpoint) + _, err = iq.Del(ctx, endpoint) } if err != nil { return fmt.Errorf("could not affect global role mapping: %v", err) @@ -597,22 +678,38 @@ func globalAuth(iq IQ, method, roleName, memberType, member string) error { return nil } +func SetGlobalUserContext(ctx context.Context, iq IQ, roleName, user string) error { + return globalAuth(ctx, iq, http.MethodPut, roleName, MemberTypeUser, user) +} + // SetGlobalUser sets the role and user that can have access to the repositories func SetGlobalUser(iq IQ, roleName, user string) error { - return globalAuth(iq, http.MethodPut, roleName, MemberTypeUser, user) + return SetGlobalUserContext(context.Background(), iq, roleName, user) +} + +func SetGlobalGroupContext(ctx context.Context, iq IQ, roleName, group string) error { + return globalAuth(ctx, iq, http.MethodPut, roleName, MemberTypeGroup, group) } // SetGlobalGroup sets the role and group that can have access to the global func SetGlobalGroup(iq IQ, roleName, group string) error { - return globalAuth(iq, http.MethodPut, roleName, MemberTypeGroup, group) + return SetGlobalGroupContext(context.Background(), iq, roleName, group) +} + +func RevokeGlobalUserContext(ctx context.Context, iq IQ, roleName, user string) error { + return globalAuth(ctx, iq, http.MethodDelete, roleName, MemberTypeUser, user) } // RevokeGlobalUser revoke the role and user that can have access to the global func RevokeGlobalUser(iq IQ, roleName, user string) error { - return globalAuth(iq, http.MethodDelete, roleName, MemberTypeUser, user) + return RevokeGlobalUserContext(context.Background(), iq, roleName, user) +} + +func RevokeGlobalGroupContext(ctx context.Context, iq IQ, roleName, group string) error { + return globalAuth(ctx, iq, http.MethodDelete, roleName, MemberTypeGroup, group) } // RevokeGlobalGroup revoke the role and group that can have access to the global func RevokeGlobalGroup(iq IQ, roleName, group string) error { - return globalAuth(iq, http.MethodDelete, roleName, MemberTypeGroup, group) + return RevokeGlobalGroupContext(context.Background(), iq, roleName, group) } diff --git a/iq/roleMemberships_test.go b/iq/roleMemberships_test.go index 7bf8a36..89a0814 100644 --- a/iq/roleMemberships_test.go +++ b/iq/roleMemberships_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -397,7 +398,7 @@ func testGetOrganizationAuthorizations(t *testing.T, iq IQ) { t.Helper() dummyIdx := 0 - got, err := OrganizationAuthorizations(iq, dummyOrgs[dummyIdx].Name) + got, err := OrganizationAuthorizationsContext(context.Background(), iq, dummyOrgs[dummyIdx].Name) if err != nil { t.Error(err) } @@ -424,7 +425,7 @@ func testGetOrganizationAuthorizationsByRole(t *testing.T, iq IQ) { } } - got, err := OrganizationAuthorizationsByRole(iq, role.Name) + got, err := OrganizationAuthorizationsByRoleContext(context.Background(), iq, role.Name) if err != nil { t.Error(err) } @@ -448,7 +449,7 @@ func testGetApplicationAuthorizations(t *testing.T, iq IQ) { t.Helper() dummyIdx := 0 - got, err := ApplicationAuthorizations(iq, dummyApps[dummyIdx].PublicID) + got, err := ApplicationAuthorizationsContext(context.Background(), iq, dummyApps[dummyIdx].PublicID) if err != nil { t.Error(err) } @@ -474,7 +475,7 @@ func testGetApplicationAuthorizationsByRole(t *testing.T, iq IQ) { } } - got, err := ApplicationAuthorizationsByRole(iq, role.Name) + got, err := ApplicationAuthorizationsByRoleContext(context.Background(), iq, role.Name) if err != nil { t.Error(err) } @@ -517,22 +518,22 @@ func testSetAuth(t *testing.T, iq IQ, authTarget string, memberType string) { case "organization": switch memberType { case MemberTypeUser: - err = SetOrganizationUser(iq, dummyOrgs[dummyIdx].Name, dummyRoles[role].Name, memberName) + err = SetOrganizationUserContext(context.Background(), iq, dummyOrgs[dummyIdx].Name, dummyRoles[role].Name, memberName) case MemberTypeGroup: - err = SetOrganizationGroup(iq, dummyOrgs[dummyIdx].Name, dummyRoles[role].Name, memberName) + err = SetOrganizationGroupContext(context.Background(), iq, dummyOrgs[dummyIdx].Name, dummyRoles[role].Name, memberName) } if err == nil { - got, err = OrganizationAuthorizations(iq, dummyOrgs[dummyIdx].Name) + got, err = OrganizationAuthorizationsContext(context.Background(), iq, dummyOrgs[dummyIdx].Name) } case "application": switch memberType { case MemberTypeUser: - err = SetApplicationUser(iq, dummyApps[dummyIdx].PublicID, dummyRoles[role].Name, memberName) + err = SetApplicationUserContext(context.Background(), iq, dummyApps[dummyIdx].PublicID, dummyRoles[role].Name, memberName) case MemberTypeGroup: - err = SetApplicationGroup(iq, dummyApps[dummyIdx].PublicID, dummyRoles[role].Name, memberName) + err = SetApplicationGroupContext(context.Background(), iq, dummyApps[dummyIdx].PublicID, dummyRoles[role].Name, memberName) } if err == nil { - got, err = ApplicationAuthorizations(iq, dummyApps[dummyIdx].PublicID) + got, err = ApplicationAuthorizationsContext(context.Background(), iq, dummyApps[dummyIdx].PublicID) } } if err != nil { @@ -613,68 +614,68 @@ func testRevoke(t *testing.T, iq IQ, authType, memberType string) { dummyOrgName := dummyOrgs[0].Name switch memberType { case MemberTypeUser: - err = SetOrganizationUser(iq, dummyOrgName, role.Name, name) + err = SetOrganizationUserContext(context.Background(), iq, dummyOrgName, role.Name, name) if err == nil { - err = RevokeOrganizationUser(iq, dummyOrgName, role.Name, name) + err = RevokeOrganizationUserContext(context.Background(), iq, dummyOrgName, role.Name, name) } case MemberTypeGroup: - err = SetOrganizationGroup(iq, dummyOrgName, role.Name, name) + err = SetOrganizationGroupContext(context.Background(), iq, dummyOrgName, role.Name, name) if err == nil { t.Log("HERE1") - err = RevokeOrganizationGroup(iq, dummyOrgName, role.Name, name) + err = RevokeOrganizationGroupContext(context.Background(), iq, dummyOrgName, role.Name, name) } } if err == nil { - mappings, err = OrganizationAuthorizations(iq, dummyOrgName) + mappings, err = OrganizationAuthorizationsContext(context.Background(), iq, dummyOrgName) } case "application": dummyAppName := dummyApps[0].PublicID switch memberType { case MemberTypeUser: - err = SetApplicationUser(iq, dummyAppName, role.Name, name) + err = SetApplicationUserContext(context.Background(), iq, dummyAppName, role.Name, name) if err == nil { - err = RevokeApplicationUser(iq, dummyAppName, role.Name, name) + err = RevokeApplicationUserContext(context.Background(), iq, dummyAppName, role.Name, name) } case MemberTypeGroup: - err = SetApplicationGroup(iq, dummyAppName, role.Name, name) + err = SetApplicationGroupContext(context.Background(), iq, dummyAppName, role.Name, name) if err == nil { - err = RevokeApplicationGroup(iq, dummyAppName, role.Name, name) + err = RevokeApplicationGroupContext(context.Background(), iq, dummyAppName, role.Name, name) } } if err == nil { - mappings, err = ApplicationAuthorizations(iq, dummyAppName) + mappings, err = ApplicationAuthorizationsContext(context.Background(), iq, dummyAppName) } case "repository_container": switch memberType { case MemberTypeUser: - err = SetRepositoriesUser(iq, role.Name, name) + err = SetRepositoriesUserContext(context.Background(), iq, role.Name, name) if err == nil { - err = RevokeRepositoriesUser(iq, role.Name, name) + err = RevokeRepositoriesUserContext(context.Background(), iq, role.Name, name) } case MemberTypeGroup: - err = SetRepositoriesGroup(iq, role.Name, name) + err = SetRepositoriesGroupContext(context.Background(), iq, role.Name, name) if err == nil { - err = RevokeRepositoriesGroup(iq, role.Name, name) + err = RevokeRepositoriesGroupContext(context.Background(), iq, role.Name, name) } } if err == nil { - mappings, err = RepositoriesAuthorizations(iq) + mappings, err = RepositoriesAuthorizationsContext(context.Background(), iq) } case "global": switch memberType { case MemberTypeUser: - err = SetGlobalUser(iq, role.Name, name) + err = SetGlobalUserContext(context.Background(), iq, role.Name, name) if err == nil { - err = RevokeGlobalUser(iq, role.Name, name) + err = RevokeGlobalUserContext(context.Background(), iq, role.Name, name) } case MemberTypeGroup: - err = SetGlobalGroup(iq, role.Name, name) + err = SetGlobalGroupContext(context.Background(), iq, role.Name, name) if err == nil { - err = RevokeGlobalGroup(iq, role.Name, name) + err = RevokeGlobalGroupContext(context.Background(), iq, role.Name, name) } } if err == nil { - mappings, err = GlobalAuthorizations(iq) + mappings, err = GlobalAuthorizationsContext(context.Background(), iq) } } if err != nil { @@ -732,7 +733,7 @@ func TestRepositoriesAuthorizations(t *testing.T) { iq, mock := roleMembershipsTestIQ(t, false) defer mock.Close() - got, err := RepositoriesAuthorizations(iq) + got, err := RepositoriesAuthorizationsContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -758,7 +759,7 @@ func TestGetApplicationAuthorizationsByRole(t *testing.T) { } } - got, err := RepositoriesAuthorizationsByRole(iq, role.Name) + got, err := RepositoriesAuthorizationsByRoleContext(context.Background(), iq, role.Name) if err != nil { t.Error(err) } @@ -791,15 +792,15 @@ func testSetRepositories(t *testing.T, memberType string) { var err error switch memberType { case MemberTypeUser: - err = SetRepositoriesUser(iq, role.Name, memberName) + err = SetRepositoriesUserContext(context.Background(), iq, role.Name, memberName) case MemberTypeGroup: - err = SetRepositoriesGroup(iq, role.Name, memberName) + err = SetRepositoriesGroupContext(context.Background(), iq, role.Name, memberName) } if err != nil { t.Error(err) } - got, err := RepositoriesAuthorizations(iq) + got, err := RepositoriesAuthorizationsContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -860,7 +861,7 @@ func testMembersByRole(t *testing.T, iq IQ) { } } } - if hasRev70API(iq) { + if hasRev70API(context.Background(), iq) { for _, m := range dummyRoleMappingsRepos { if m.RoleID == role.ID { want = append(want, m) @@ -868,7 +869,7 @@ func testMembersByRole(t *testing.T, iq IQ) { } } - got, err := MembersByRole(iq, role.Name) + got, err := MembersByRoleContext(context.Background(), iq, role.Name) if err != nil { t.Error(err) } @@ -888,7 +889,7 @@ func TestGlobalAuthorizations(t *testing.T) { iq, mock := roleMembershipsTestIQ(t, false) defer mock.Close() - got, err := GlobalAuthorizations(iq) + got, err := GlobalAuthorizationsContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -924,15 +925,15 @@ func testSetGlobal(t *testing.T, memberType string) { var err error switch memberType { case MemberTypeUser: - err = SetGlobalUser(iq, role.Name, memberName) + err = SetGlobalUserContext(context.Background(), iq, role.Name, memberName) case MemberTypeGroup: - err = SetGlobalGroup(iq, role.Name, memberName) + err = SetGlobalGroupContext(context.Background(), iq, role.Name, memberName) } if err != nil { t.Error(err) } - got, err := GlobalAuthorizations(iq) + got, err := GlobalAuthorizationsContext(context.Background(), iq) if err != nil { t.Error(err) } diff --git a/iq/roles.go b/iq/roles.go index 830bddd..2fc9ff2 100644 --- a/iq/roles.go +++ b/iq/roles.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -22,11 +23,10 @@ type Role struct { Description string `json:"description"` } -// Roles returns a slice of all the roles in the IQ instance -func Roles(iq IQ) ([]Role, error) { - body, resp, err := iq.Get(restRoles) +func RolesContext(ctx context.Context, iq IQ) ([]Role, error) { + body, resp, err := iq.Get(ctx, restRoles) if resp.StatusCode == http.StatusNotFound { - body, _, err = iq.Get(restRolesDeprecated) + body, _, err = iq.Get(ctx, restRolesDeprecated) } if err != nil { return nil, fmt.Errorf("could not retrieve roles: %v", err) @@ -40,9 +40,13 @@ func Roles(iq IQ) ([]Role, error) { return list.Roles, nil } -// RoleByName returns the named role -func RoleByName(iq IQ, name string) (Role, error) { - roles, err := Roles(iq) +// Roles returns a slice of all the roles in the IQ instance +func Roles(iq IQ) ([]Role, error) { + return RolesContext(context.Background(), iq) +} + +func RoleByNameContext(ctx context.Context, iq IQ, name string) (Role, error) { + roles, err := RolesContext(ctx, iq) if err != nil { return Role{}, fmt.Errorf("did not find role with name %s: %v", name, err) } @@ -56,12 +60,21 @@ func RoleByName(iq IQ, name string) (Role, error) { return Role{}, fmt.Errorf("did not find role with name %s", name) } -// GetSystemAdminID returns the identifier of the System Administrator role -func GetSystemAdminID(iq IQ) (string, error) { - role, err := RoleByName(iq, "System Administrator") +// RoleByName returns the named role +func RoleByName(iq IQ, name string) (Role, error) { + return RoleByNameContext(context.Background(), iq, name) +} + +func GetSystemAdminIDContext(ctx context.Context, iq IQ) (string, error) { + role, err := RoleByNameContext(ctx, iq, "System Administrator") if err != nil { return "", fmt.Errorf("did not get admin role: %v", err) } return role.ID, nil } + +// GetSystemAdminID returns the identifier of the System Administrator role +func GetSystemAdminID(iq IQ) (string, error) { + return GetSystemAdminIDContext(context.Background(), iq) +} diff --git a/iq/roles_test.go b/iq/roles_test.go index 73bdaa2..0d524d2 100644 --- a/iq/roles_test.go +++ b/iq/roles_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -80,7 +81,7 @@ func TestRoles(t *testing.T) { iq, mock := rolesTestIQ(t) defer mock.Close() - got, err := Roles(iq) + got, err := RolesContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -96,7 +97,7 @@ func TestRoleByName(t *testing.T) { want := dummyRoles[0] - got, err := RoleByName(iq, want.Name) + got, err := RoleByNameContext(context.Background(), iq, want.Name) if err != nil { t.Error(err) } diff --git a/iq/search.go b/iq/search.go index df65041..2757857 100644 --- a/iq/search.go +++ b/iq/search.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -117,10 +118,9 @@ func NewSearchQueryBuilder() *SearchQueryBuilder { return b } -// SearchComponents allows searching the indicated IQ instance for specific components -func SearchComponents(iq IQ, query nexus.SearchQueryBuilder) ([]SearchResult, error) { +func SearchComponentsContext(ctx context.Context, iq IQ, query nexus.SearchQueryBuilder) ([]SearchResult, error) { endpoint := restSearchComponent + "?" + query.Build() - body, resp, err := iq.Get(endpoint) + body, resp, err := iq.Get(ctx, endpoint) if err != nil || resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("could not find component: %v", err) } @@ -132,3 +132,8 @@ func SearchComponents(iq IQ, query nexus.SearchQueryBuilder) ([]SearchResult, er return searchResp.Results, nil } + +// SearchComponents allows searching the indicated IQ instance for specific components +func SearchComponents(iq IQ, query nexus.SearchQueryBuilder) ([]SearchResult, error) { + return SearchComponentsContext(context.Background(), iq, query) +} diff --git a/iq/search_test.go b/iq/search_test.go index 9747b00..7c5b339 100644 --- a/iq/search_test.go +++ b/iq/search_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "net/http" @@ -57,7 +58,7 @@ func TestSearchComponent(t *testing.T) { defer mock.Close() query := NewSearchQueryBuilder().Coordinates(dummyComponent.ComponentID.Coordinates) - components, err := SearchComponents(iq, query) + components, err := SearchComponentsContext(context.Background(), iq, query) if err != nil { t.Fatalf("Did not complete search: %v", err) } @@ -88,7 +89,7 @@ func ExampleSearchComponents() { query = query.Stage(StageBuild) query = query.PackageURL("pkg:maven/commons-collections/commons-collections@3.2") - components, err := SearchComponents(iq, query) + components, err := SearchComponentsContext(context.Background(), iq, query) if err != nil { panic(fmt.Sprintf("Did not complete search: %v", err)) } diff --git a/iq/sourceControl.go b/iq/sourceControl.go index 4514246..5c01bd4 100644 --- a/iq/sourceControl.go +++ b/iq/sourceControl.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -20,10 +21,10 @@ type SourceControlEntry struct { Token string `json:"token"` } -func getSourceControlEntryByInternalID(iq IQ, applicationID string) (entry SourceControlEntry, err error) { +func getSourceControlEntryByInternalID(ctx context.Context, iq IQ, applicationID string) (entry SourceControlEntry, err error) { endpoint := fmt.Sprintf(restSourceControl, applicationID) - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return } @@ -33,26 +34,29 @@ func getSourceControlEntryByInternalID(iq IQ, applicationID string) (entry Sourc return } -// GetSourceControlEntry lists of all of the Source Control entries for the given application -func GetSourceControlEntry(iq IQ, applicationID string) (SourceControlEntry, error) { - appInfo, err := GetApplicationByPublicID(iq, applicationID) +func GetSourceControlEntryContext(ctx context.Context, iq IQ, applicationID string) (SourceControlEntry, error) { + appInfo, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return SourceControlEntry{}, fmt.Errorf("no source control entry for '%s': %v", applicationID, err) } - return getSourceControlEntryByInternalID(iq, appInfo.ID) + return getSourceControlEntryByInternalID(ctx, iq, appInfo.ID) } -// GetAllSourceControlEntries lists of all of the Source Control entries in the IQ instance -func GetAllSourceControlEntries(iq IQ) ([]SourceControlEntry, error) { - apps, err := GetAllApplications(iq) +// GetSourceControlEntry lists of all of the Source Control entries for the given application +func GetSourceControlEntry(iq IQ, applicationID string) (SourceControlEntry, error) { + return GetSourceControlEntryContext(context.Background(), iq, applicationID) +} + +func GetAllSourceControlEntriesContext(ctx context.Context, iq IQ) ([]SourceControlEntry, error) { + apps, err := GetAllApplicationsContext(ctx, iq) if err != nil { return nil, fmt.Errorf("no source control entries: %v", err) } entries := make([]SourceControlEntry, 0) for _, app := range apps { - if entry, err := getSourceControlEntryByInternalID(iq, app.ID); err == nil { + if entry, err := getSourceControlEntryByInternalID(ctx, iq, app.ID); err == nil { entries = append(entries, entry) } } @@ -60,13 +64,17 @@ func GetAllSourceControlEntries(iq IQ) ([]SourceControlEntry, error) { return entries, nil } -// CreateSourceControlEntry creates a source control entry in IQ -func CreateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) error { +// GetAllSourceControlEntries lists of all of the Source Control entries in the IQ instance +func GetAllSourceControlEntries(iq IQ) ([]SourceControlEntry, error) { + return GetAllSourceControlEntriesContext(context.Background(), iq) +} + +func CreateSourceControlEntryContext(ctx context.Context, iq IQ, applicationID, repositoryURL, token string) error { doError := func(err error) error { return fmt.Errorf("source control entry not created for '%s': %v", applicationID, err) } - appInfo, err := GetApplicationByPublicID(iq, applicationID) + appInfo, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return doError(err) } @@ -77,20 +85,24 @@ func CreateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) } endpoint := fmt.Sprintf(restSourceControl, appInfo.ID) - if _, _, err = iq.Post(endpoint, bytes.NewBuffer(request)); err != nil { + if _, _, err = iq.Post(ctx, endpoint, bytes.NewBuffer(request)); err != nil { return doError(err) } return nil } -// UpdateSourceControlEntry updates a source control entry in IQ -func UpdateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) error { +// CreateSourceControlEntry creates a source control entry in IQ +func CreateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) error { + return CreateSourceControlEntryContext(context.Background(), iq, applicationID, repositoryURL, token) +} + +func UpdateSourceControlEntryContext(ctx context.Context, iq IQ, applicationID, repositoryURL, token string) error { doError := func(err error) error { return fmt.Errorf("source control entry not updated for '%s': %v", applicationID, err) } - appInfo, err := GetApplicationByPublicID(iq, applicationID) + appInfo, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return doError(err) } @@ -101,17 +113,22 @@ func UpdateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) } endpoint := fmt.Sprintf(restSourceControl, appInfo.ID) - if _, _, err = iq.Put(endpoint, bytes.NewBuffer(request)); err != nil { + if _, _, err = iq.Put(ctx, endpoint, bytes.NewBuffer(request)); err != nil { return doError(err) } return nil } -func deleteSourceControlEntry(iq IQ, appInternalID, sourceControlID string) error { +// UpdateSourceControlEntry updates a source control entry in IQ +func UpdateSourceControlEntry(iq IQ, applicationID, repositoryURL, token string) error { + return UpdateSourceControlEntryContext(context.Background(), iq, applicationID, repositoryURL, token) +} + +func deleteSourceControlEntry(ctx context.Context, iq IQ, appInternalID, sourceControlID string) error { endpoint := fmt.Sprintf(restSourceControlDelete, appInternalID, sourceControlID) - resp, err := iq.Del(endpoint) + resp, err := iq.Del(ctx, endpoint) if err != nil && resp.StatusCode != http.StatusNoContent { return err } @@ -119,33 +136,41 @@ func deleteSourceControlEntry(iq IQ, appInternalID, sourceControlID string) erro return nil } -// DeleteSourceControlEntry deletes a source control entry in IQ -func DeleteSourceControlEntry(iq IQ, applicationID, sourceControlID string) error { - appInfo, err := GetApplicationByPublicID(iq, applicationID) +func DeleteSourceControlEntryContext(ctx context.Context, iq IQ, applicationID, sourceControlID string) error { + appInfo, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return fmt.Errorf("source control entry not deleted from '%s': %v", applicationID, err) } - return deleteSourceControlEntry(iq, appInfo.ID, sourceControlID) + return deleteSourceControlEntry(ctx, iq, appInfo.ID, sourceControlID) } -// DeleteSourceControlEntryByApp deletes a source control entry in IQ for the given application -func DeleteSourceControlEntryByApp(iq IQ, applicationID string) error { +// DeleteSourceControlEntry deletes a source control entry in IQ +func DeleteSourceControlEntry(iq IQ, applicationID, sourceControlID string) error { + return DeleteSourceControlEntryContext(context.Background(), iq, applicationID, sourceControlID) +} + +func DeleteSourceControlEntryByAppContext(ctx context.Context, iq IQ, applicationID string) error { doError := func(err error) error { return fmt.Errorf("source control entry not deleted from '%s': %v", applicationID, err) } - appInfo, err := GetApplicationByPublicID(iq, applicationID) + appInfo, err := GetApplicationByPublicIDContext(ctx, iq, applicationID) if err != nil { return doError(err) } - entry, err := getSourceControlEntryByInternalID(iq, appInfo.ID) + entry, err := getSourceControlEntryByInternalID(ctx, iq, appInfo.ID) if err != nil { return doError(err) } - return deleteSourceControlEntry(iq, appInfo.ID, entry.ID) + return deleteSourceControlEntry(ctx, iq, appInfo.ID, entry.ID) +} + +// DeleteSourceControlEntryByApp deletes a source control entry in IQ for the given application +func DeleteSourceControlEntryByApp(iq IQ, applicationID string) error { + return DeleteSourceControlEntryByAppContext(context.Background(), iq, applicationID) } // DeleteSourceControlEntryByEntry deletes a source control entry in IQ for the given entry ID diff --git a/iq/sourceControl_test.go b/iq/sourceControl_test.go index 7af2c93..a96c6e5 100644 --- a/iq/sourceControl_test.go +++ b/iq/sourceControl_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -115,7 +116,7 @@ func TestGetSourceControlEntryByInternalID(t *testing.T) { dummyEntryIdx := 2 - entry, err := getSourceControlEntryByInternalID(iq, dummyEntries[dummyEntryIdx].ApplicationID) + entry, err := getSourceControlEntryByInternalID(context.Background(), iq, dummyEntries[dummyEntryIdx].ApplicationID) if err != nil { t.Error(err) } @@ -131,7 +132,7 @@ func TestGetAllSourceControlEntries(t *testing.T) { iq, mock := sourceControlTestIQ(t) defer mock.Close() - entries, err := GetAllSourceControlEntries(iq) + entries, err := GetAllSourceControlEntriesContext(context.Background(), iq) if err != nil { t.Error(err) } @@ -161,7 +162,7 @@ func TestGetSourceControlEntry(t *testing.T) { dummyEntryIdx := 0 - entry, err := GetSourceControlEntry(iq, dummyApps[dummyEntryIdx].PublicID) + entry, err := GetSourceControlEntryContext(context.Background(), iq, dummyApps[dummyEntryIdx].PublicID) if err != nil { t.Error(err) } @@ -179,12 +180,12 @@ func TestCreateSourceControlEntry(t *testing.T) { createdEntry := SourceControlEntry{newEntryID, dummyApps[len(dummyApps)-1].ID, "createdEntryURL", "createEntryToken"} - err := CreateSourceControlEntry(iq, dummyApps[len(dummyApps)-1].PublicID, createdEntry.RepositoryURL, createdEntry.Token) + err := CreateSourceControlEntryContext(context.Background(), iq, dummyApps[len(dummyApps)-1].PublicID, createdEntry.RepositoryURL, createdEntry.Token) if err != nil { t.Error(err) } - entry, err := GetSourceControlEntry(iq, dummyApps[len(dummyApps)-1].PublicID) + entry, err := GetSourceControlEntryContext(context.Background(), iq, dummyApps[len(dummyApps)-1].PublicID) if err != nil { t.Error(err) } @@ -202,12 +203,12 @@ func TestUpdateSourceControlEntry(t *testing.T) { updatedEntryRepositoryURL := "updatedRepoURL" updatedEntryToken := "updatedToken" - err := UpdateSourceControlEntry(iq, dummyApps[len(dummyApps)-2].PublicID, updatedEntryRepositoryURL, updatedEntryToken) + err := UpdateSourceControlEntryContext(context.Background(), iq, dummyApps[len(dummyApps)-2].PublicID, updatedEntryRepositoryURL, updatedEntryToken) if err != nil { t.Error(err) } - entry, err := GetSourceControlEntry(iq, dummyApps[len(dummyApps)-2].PublicID) + entry, err := GetSourceControlEntryContext(context.Background(), iq, dummyApps[len(dummyApps)-2].PublicID) if err != nil { t.Error(err) } @@ -229,15 +230,15 @@ func TestDeleteSourceControlEntry(t *testing.T) { app := dummyApps[len(dummyApps)-1] deleteMe := SourceControlEntry{newEntryID, app.ID, "deleteMeURL", "deleteMeToken"} - if err := CreateSourceControlEntry(iq, app.PublicID, deleteMe.RepositoryURL, deleteMe.Token); err != nil { + if err := CreateSourceControlEntryContext(context.Background(), iq, app.PublicID, deleteMe.RepositoryURL, deleteMe.Token); err != nil { t.Error(err) } - if err := DeleteSourceControlEntry(iq, app.PublicID, newEntryID); err != nil { + if err := DeleteSourceControlEntryContext(context.Background(), iq, app.PublicID, newEntryID); err != nil { t.Error(err) } - if _, err := GetSourceControlEntry(iq, app.PublicID); err == nil { + if _, err := GetSourceControlEntryContext(context.Background(), iq, app.PublicID); err == nil { t.Error("Unexpectedly found entry which should have been deleted") } } @@ -249,15 +250,15 @@ func TestDeleteSourceControlEntryByApp(t *testing.T) { app := dummyApps[len(dummyApps)-1] deleteMe := SourceControlEntry{newEntryID, app.ID, "deleteMeURL", "deleteMeToken"} - if err := CreateSourceControlEntry(iq, app.PublicID, deleteMe.RepositoryURL, deleteMe.Token); err != nil { + if err := CreateSourceControlEntryContext(context.Background(), iq, app.PublicID, deleteMe.RepositoryURL, deleteMe.Token); err != nil { t.Error(err) } - if err := DeleteSourceControlEntryByApp(iq, app.PublicID); err != nil { + if err := DeleteSourceControlEntryByAppContext(context.Background(), iq, app.PublicID); err != nil { t.Error(err) } - if _, err := GetSourceControlEntry(iq, app.PublicID); err == nil { + if _, err := GetSourceControlEntryContext(context.Background(), iq, app.PublicID); err == nil { t.Error("Unexpectedly found entry which should have been deleted") } } diff --git a/iq/users.go b/iq/users.go index 332d813..2550710 100644 --- a/iq/users.go +++ b/iq/users.go @@ -2,6 +2,7 @@ package nexusiq import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -21,10 +22,9 @@ type User struct { Password string `json:"password,omitempty"` } -// GetUser returns user details for the given name -func GetUser(iq IQ, username string) (user User, err error) { +func GetUserContext(ctx context.Context, iq IQ, username string) (user User, err error) { endpoint := fmt.Sprintf(restUsers, username) - body, _, err := iq.Get(endpoint) + body, _, err := iq.Get(ctx, endpoint) if err != nil { return user, fmt.Errorf("could not retrieve details on username %s: %v", username, err) } @@ -34,32 +34,45 @@ func GetUser(iq IQ, username string) (user User, err error) { return user, err } -// SetUser creates a new user -func SetUser(iq IQ, user User) (err error) { +// GetUser returns user details for the given name +func GetUser(iq IQ, username string) (user User, err error) { + return GetUserContext(context.Background(), iq, username) +} + +func SetUserContext(ctx context.Context, iq IQ, user User) (err error) { buf, err := json.Marshal(user) if err != nil { return fmt.Errorf("could not read user details: %v", err) } str := bytes.NewBuffer(buf) - if _, er := GetUser(iq, user.Username); er != nil { - _, resp, er := iq.Post(restUsersPost, str) + if _, er := GetUserContext(ctx, iq, user.Username); er != nil { + _, resp, er := iq.Post(ctx, restUsersPost, str) if er != nil && resp.StatusCode != http.StatusNoContent { return er } } else { endpoint := fmt.Sprintf(restUsers, user.Username) - _, _, err = iq.Put(endpoint, str) + _, _, err = iq.Put(ctx, endpoint, str) } return err } -// DeleteUser removes the named user -func DeleteUser(iq IQ, username string) error { +// SetUser creates a new user +func SetUser(iq IQ, user User) (err error) { + return SetUserContext(context.Background(), iq, user) +} + +func DeleteUserContext(ctx context.Context, iq IQ, username string) error { endpoint := fmt.Sprintf(restUsers, username) - if resp, err := iq.Del(endpoint); err != nil && resp.StatusCode != http.StatusNoContent { + if resp, err := iq.Del(ctx, endpoint); err != nil && resp.StatusCode != http.StatusNoContent { return err } return nil } + +// DeleteUser removes the named user +func DeleteUser(iq IQ, username string) error { + return DeleteUserContext(context.Background(), iq, username) +} diff --git a/iq/users_test.go b/iq/users_test.go index a644797..35c9f8f 100644 --- a/iq/users_test.go +++ b/iq/users_test.go @@ -1,6 +1,7 @@ package nexusiq import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -145,7 +146,7 @@ func usersTestIQ(t *testing.T, useDeprecated bool) (IQ, *httptest.Server) { func checkExists(t *testing.T, iq IQ, want User) { t.Helper() - got, err := GetUser(iq, want.Username) + got, err := GetUserContext(context.Background(), iq, want.Username) if err != nil { t.Error(err) } @@ -165,7 +166,7 @@ func TestGetUser(t *testing.T) { } func setUser(t *testing.T, iq IQ, want User) { - err := SetUser(iq, want) + err := SetUserContext(context.Background(), iq, want) if err != nil { t.Error(err) } @@ -234,12 +235,12 @@ func TestDeleteUser(t *testing.T) { // Create new dummy user setUser(t, iq, want) - err := DeleteUser(iq, want.Username) + err := DeleteUserContext(context.Background(), iq, want.Username) if err != nil { t.Error(err) } - if _, err := GetUser(iq, want.Username); err == nil { + if _, err := GetUserContext(context.Background(), iq, want.Username); err == nil { t.Error("Found user which I tried to delete") } } diff --git a/nexus.go b/nexus.go index 7ab00f9..9059d2a 100644 --- a/nexus.go +++ b/nexus.go @@ -1,6 +1,7 @@ package nexus import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -20,12 +21,12 @@ type ServerInfo struct { // Client is the interface which allows interacting with an IQ server type Client interface { - NewRequest(method, endpoint string, payload io.Reader) (*http.Request, error) + NewRequest(ctx context.Context, method, endpoint string, payload io.Reader) (*http.Request, error) Do(request *http.Request) ([]byte, *http.Response, error) - Get(endpoint string) ([]byte, *http.Response, error) - Post(endpoint string, payload io.Reader) ([]byte, *http.Response, error) - Put(endpoint string, payload io.Reader) ([]byte, *http.Response, error) - Del(endpoint string) (*http.Response, error) + Get(ctx context.Context, endpoint string) ([]byte, *http.Response, error) + Post(ctx context.Context, endpoint string, payload io.Reader) ([]byte, *http.Response, error) + Put(ctx context.Context, endpoint string, payload io.Reader) ([]byte, *http.Response, error) + Del(ctx context.Context, endpoint string) (*http.Response, error) Info() ServerInfo SetDebug(enable bool) SetCertFile(certFile string) @@ -38,9 +39,9 @@ type DefaultClient struct { } // NewRequest created an http.Request object based on an endpoint and fills in basic auth -func (s *DefaultClient) NewRequest(method, endpoint string, payload io.Reader) (request *http.Request, err error) { +func (s *DefaultClient) NewRequest(ctx context.Context, method, endpoint string, payload io.Reader) (request *http.Request, err error) { url := fmt.Sprintf("%s/%s", s.Host, endpoint) - request, err = http.NewRequest(method, url, payload) + request, err = http.NewRequestWithContext(ctx, method, url, payload) if err != nil { return } @@ -104,8 +105,8 @@ func (s *DefaultClient) Do(request *http.Request) (body []byte, resp *http.Respo return } -func (s *DefaultClient) http(method, endpoint string, payload io.Reader) ([]byte, *http.Response, error) { - request, err := s.NewRequest(method, endpoint, payload) +func (s *DefaultClient) http(ctx context.Context, method, endpoint string, payload io.Reader) ([]byte, *http.Response, error) { + request, err := s.NewRequest(ctx, method, endpoint, payload) if err != nil { return nil, nil, err } @@ -114,23 +115,23 @@ func (s *DefaultClient) http(method, endpoint string, payload io.Reader) ([]byte } // Get performs an HTTP GET against the indicated endpoint -func (s *DefaultClient) Get(endpoint string) ([]byte, *http.Response, error) { - return s.http(http.MethodGet, endpoint, nil) +func (s *DefaultClient) Get(ctx context.Context, endpoint string) ([]byte, *http.Response, error) { + return s.http(ctx, http.MethodGet, endpoint, nil) } // Post performs an HTTP POST against the indicated endpoint -func (s *DefaultClient) Post(endpoint string, payload io.Reader) ([]byte, *http.Response, error) { - return s.http(http.MethodPost, endpoint, payload) +func (s *DefaultClient) Post(ctx context.Context, endpoint string, payload io.Reader) ([]byte, *http.Response, error) { + return s.http(ctx, http.MethodPost, endpoint, payload) } // Put performs an HTTP PUT against the indicated endpoint -func (s *DefaultClient) Put(endpoint string, payload io.Reader) ([]byte, *http.Response, error) { - return s.http(http.MethodPut, endpoint, payload) +func (s *DefaultClient) Put(ctx context.Context, endpoint string, payload io.Reader) ([]byte, *http.Response, error) { + return s.http(ctx, http.MethodPut, endpoint, payload) } // Del performs an HTTP DELETE against the indicated endpoint -func (s *DefaultClient) Del(endpoint string) (resp *http.Response, err error) { - _, resp, err = s.http(http.MethodDelete, endpoint, nil) +func (s *DefaultClient) Del(ctx context.Context, endpoint string) (resp *http.Response, err error) { + _, resp, err = s.http(ctx, http.MethodDelete, endpoint, nil) return } diff --git a/rm/anonymous.go b/rm/anonymous.go index 36e6a0d..29bd9e9 100644 --- a/rm/anonymous.go +++ b/rm/anonymous.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -15,10 +16,10 @@ type SettingsAnonAccess struct { RealmName string `json:"realmName"` } -func GetAnonAccess(rm RM) (SettingsAnonAccess, error) { +func GetAnonAccessContext(ctx context.Context, rm RM) (SettingsAnonAccess, error) { var settings SettingsAnonAccess - body, resp, err := rm.Get(restAnonymous) + body, resp, err := rm.Get(ctx, restAnonymous) if err != nil && resp.StatusCode != http.StatusNoContent { return SettingsAnonAccess{}, fmt.Errorf("anonymous access settings can't getting: %v", err) } @@ -30,15 +31,23 @@ func GetAnonAccess(rm RM) (SettingsAnonAccess, error) { return settings, nil } -func SetAnonAccess(rm RM, settings SettingsAnonAccess) error { +func GetAnonAccess(rm RM) (SettingsAnonAccess, error) { + return GetAnonAccessContext(context.Background(), rm) +} + +func SetAnonAccessContext(ctx context.Context, rm RM, settings SettingsAnonAccess) error { json, err := json.Marshal(settings) if err != nil { return err } - if _, resp, err := rm.Put(restAnonymous, bytes.NewBuffer(json)); err != nil && resp.StatusCode != http.StatusNoContent { + if _, resp, err := rm.Put(ctx, restAnonymous, bytes.NewBuffer(json)); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("email config not set: %v", err) } return nil } + +func SetAnonAccess(rm RM, settings SettingsAnonAccess) error { + return SetAnonAccessContext(context.Background(), rm, settings) +} diff --git a/rm/assets.go b/rm/assets.go index 25b3cdb..fd9c16f 100644 --- a/rm/assets.go +++ b/rm/assets.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -31,8 +32,7 @@ type listAssetsResponse struct { ContinuationToken string `json:"continuationToken"` } -// GetAssets returns a list of assets in the indicated repository -func GetAssets(rm RM, repo string) (items []RepositoryItemAsset, err error) { +func GetAssetsContext(ctx context.Context, rm RM, repo string) (items []RepositoryItemAsset, err error) { continuation := "" get := func() (listResp listAssetsResponse, err error) { @@ -42,7 +42,7 @@ func GetAssets(rm RM, repo string) (items []RepositoryItemAsset, err error) { url += "&continuationToken=" + continuation } - body, resp, err := rm.Get(url) + body, resp, err := rm.Get(ctx, url) if err != nil || resp.StatusCode != http.StatusOK { return } @@ -71,8 +71,12 @@ func GetAssets(rm RM, repo string) (items []RepositoryItemAsset, err error) { return items, nil } -// GetAssetByID returns an asset by ID -func GetAssetByID(rm RM, id string) (items RepositoryItemAsset, err error) { +// GetAssets returns a list of assets in the indicated repository +func GetAssets(rm RM, repo string) (items []RepositoryItemAsset, err error) { + return GetAssetsContext(context.Background(), rm, repo) +} + +func GetAssetByIDContext(ctx context.Context, rm RM, id string) (items RepositoryItemAsset, err error) { doError := func(err error) error { return fmt.Errorf("no asset with id '%s': %v", id, err) } @@ -80,7 +84,7 @@ func GetAssetByID(rm RM, id string) (items RepositoryItemAsset, err error) { var item RepositoryItemAsset url := fmt.Sprintf("%s/%s", restAssets, id) - body, resp, err := rm.Get(url) + body, resp, err := rm.Get(ctx, url) if err != nil || resp.StatusCode != http.StatusOK { return item, doError(err) } @@ -92,13 +96,22 @@ func GetAssetByID(rm RM, id string) (items RepositoryItemAsset, err error) { return item, nil } -// DeleteAssetByID deletes the asset indicated by ID -func DeleteAssetByID(rm RM, id string) error { +// GetAssetByID returns an asset by ID +func GetAssetByID(rm RM, id string) (items RepositoryItemAsset, err error) { + return GetAssetByIDContext(context.Background(), rm, id) +} + +func DeleteAssetByIDContext(ctx context.Context, rm RM, id string) error { url := fmt.Sprintf("%s/%s", restAssets, id) - if resp, err := rm.Del(url); err != nil && resp.StatusCode != http.StatusNoContent { + if resp, err := rm.Del(ctx, url); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("asset not deleted '%s': %v", id, err) } return nil } + +// DeleteAssetByID deletes the asset indicated by ID +func DeleteAssetByID(rm RM, id string) error { + return DeleteAssetByIDContext(context.Background(), rm, id) +} diff --git a/rm/assets_test.go b/rm/assets_test.go index cedcfc9..1c26ec2 100644 --- a/rm/assets_test.go +++ b/rm/assets_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -130,7 +131,7 @@ func getAssetsTester(t *testing.T, repo string) { rm, mock := assetsTestRM(t) defer mock.Close() - assets, err := GetAssets(rm, repo) + assets, err := GetAssetsContext(context.Background(), rm, repo) if err != nil { panic(err) } @@ -162,7 +163,7 @@ func TestGetAssetByID(t *testing.T) { expectedAsset := dummyAssets["repo-maven"][0] - asset, err := GetAssetByID(rm, expectedAsset.ID) + asset, err := GetAssetByIDContext(context.Background(), rm, expectedAsset.ID) if err != nil { t.Error(err) } @@ -189,15 +190,15 @@ func TestDeleteAssetByID(t *testing.T) { dummyAssets[deleteMe.Repository] = append(dummyAssets[deleteMe.Repository], deleteMe) - if _, err := GetAssetByID(rm, deleteMe.ID); err != nil { + if _, err := GetAssetByIDContext(context.Background(), rm, deleteMe.ID); err != nil { t.Errorf("Error getting component: %v\n", err) } - if err := DeleteAssetByID(rm, deleteMe.ID); err != nil { + if err := DeleteAssetByIDContext(context.Background(), rm, deleteMe.ID); err != nil { t.Fatal(err) } - if _, err := GetAssetByID(rm, deleteMe.ID); err == nil { + if _, err := GetAssetByIDContext(context.Background(), rm, deleteMe.ID); err == nil { t.Errorf("Asset not deleted: %v\n", err) } } diff --git a/rm/components.go b/rm/components.go index 65f41b1..5d5e9eb 100644 --- a/rm/components.go +++ b/rm/components.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "io" @@ -286,8 +287,7 @@ func (a UploadComponentApt) write(w *multipart.Writer) error { return nil } -// GetComponents returns a list of components in the indicated repository -func GetComponents(rm RM, repo string) ([]RepositoryItem, error) { +func GetComponentsContext(ctx context.Context, rm RM, repo string) ([]RepositoryItem, error) { continuation := "" getComponents := func() (listResp listComponentsResponse, err error) { @@ -297,7 +297,7 @@ func GetComponents(rm RM, repo string) ([]RepositoryItem, error) { url += "&continuationToken=" + continuation } - body, resp, err := rm.Get(url) + body, resp, err := rm.Get(ctx, url) if err != nil || resp.StatusCode != http.StatusOK { return } @@ -326,8 +326,12 @@ func GetComponents(rm RM, repo string) ([]RepositoryItem, error) { return items, nil } -// GetComponentByID returns a component by ID -func GetComponentByID(rm RM, id string) (RepositoryItem, error) { +// GetComponents returns a list of components in the indicated repository +func GetComponents(rm RM, repo string) ([]RepositoryItem, error) { + return GetComponentsContext(context.Background(), rm, repo) +} + +func GetComponentByIDContext(ctx context.Context, rm RM, id string) (RepositoryItem, error) { doError := func(err error) error { return fmt.Errorf("no component with id '%s': %v", id, err) } @@ -335,7 +339,7 @@ func GetComponentByID(rm RM, id string) (RepositoryItem, error) { var item RepositoryItem url := fmt.Sprintf("%s/%s", restComponents, id) - body, resp, err := rm.Get(url) + body, resp, err := rm.Get(ctx, url) if err != nil || resp.StatusCode != http.StatusOK { return item, doError(err) } @@ -347,20 +351,28 @@ func GetComponentByID(rm RM, id string) (RepositoryItem, error) { return item, nil } -// DeleteComponentByID deletes the indicated component -func DeleteComponentByID(rm RM, id string) error { +// GetComponentByID returns a component by ID +func GetComponentByID(rm RM, id string) (RepositoryItem, error) { + return GetComponentByIDContext(context.Background(), rm, id) +} + +func DeleteComponentByIDContext(ctx context.Context, rm RM, id string) error { url := fmt.Sprintf("%s/%s", restComponents, id) - if resp, err := rm.Del(url); err != nil && resp.StatusCode != http.StatusNoContent { + if resp, err := rm.Del(ctx, url); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("component not deleted '%s': %v", id, err) } return nil } -// UploadComponent uploads a component to repository manager -func UploadComponent(rm RM, repo string, component UploadComponentWriter) error { - if _, err := GetRepositoryByName(rm, repo); err != nil { +// DeleteComponentByID deletes the indicated component +func DeleteComponentByID(rm RM, id string) error { + return DeleteComponentByIDContext(context.Background(), rm, id) +} + +func UploadComponentContext(ctx context.Context, rm RM, repo string, component UploadComponentWriter) error { + if _, err := GetRepositoryByNameContext(ctx, rm, repo); err != nil { return fmt.Errorf("could not find repository: %v", err) } @@ -379,7 +391,7 @@ func UploadComponent(rm RM, repo string, component UploadComponentWriter) error }() url := fmt.Sprintf(restListComponentsByRepo, repo) - req, err := rm.NewRequest("POST", url, b) + req, err := rm.NewRequest(ctx, "POST", url, b) req.Header.Set("Content-Type", m.FormDataContentType()) if err != nil { return doError(err) @@ -391,3 +403,8 @@ func UploadComponent(rm RM, repo string, component UploadComponentWriter) error return nil } + +// UploadComponent uploads a component to repository manager +func UploadComponent(rm RM, repo string, component UploadComponentWriter) error { + return UploadComponentContext(context.Background(), rm, repo, component) +} diff --git a/rm/components_test.go b/rm/components_test.go index 7fca5c8..58b045c 100644 --- a/rm/components_test.go +++ b/rm/components_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -168,7 +169,7 @@ func getComponentsTester(t *testing.T, repo string) { rm, mock := componentsTestRM(t) defer mock.Close() - components, err := GetComponents(rm, repo) + components, err := GetComponentsContext(context.Background(), rm, repo) if err != nil { panic(err) } @@ -200,7 +201,7 @@ func TestGetComponentByID(t *testing.T) { expectedComponent := dummyComponents["repo-maven"][0] - component, err := GetComponentByID(rm, expectedComponent.ID) + component, err := GetComponentByIDContext(context.Background(), rm, expectedComponent.ID) if err != nil { t.Error(err) } @@ -218,13 +219,13 @@ func componentUploader(t *testing.T, expected RepositoryItem, upload UploadCompo defer mock.Close() // if err := UploadComponent(rm, expected.Repository, coordinate, file); err != nil { - if err := UploadComponent(rm, expected.Repository, upload); err != nil { + if err := UploadComponentContext(context.Background(), rm, expected.Repository, upload); err != nil { t.Error(err) } expected.ID = dummyNewComponentID - component, err := GetComponentByID(rm, expected.ID) + component, err := GetComponentByIDContext(context.Background(), rm, expected.ID) if err != nil { t.Error(err) } @@ -314,17 +315,17 @@ func TestDeleteComponentByID(t *testing.T) { } // if err = UploadComponent(rm, deleteMe.Repository, coord, nil); err != nil { - if err = UploadComponent(rm, deleteMe.Repository, upload); err != nil { + if err = UploadComponentContext(context.Background(), rm, deleteMe.Repository, upload); err != nil { t.Error(err) } deleteMe.ID = dummyNewComponentID - if err = DeleteComponentByID(rm, deleteMe.ID); err != nil { + if err = DeleteComponentByIDContext(context.Background(), rm, deleteMe.ID); err != nil { t.Fatal(err) } - if _, err := GetComponentByID(rm, deleteMe.ID); err == nil { + if _, err := GetComponentByIDContext(context.Background(), rm, deleteMe.ID); err == nil { t.Errorf("Component not deleted: %v\n", err) } } @@ -335,7 +336,7 @@ func ExampleGetComponents() { panic(err) } - items, err := GetComponents(rm, "maven-central") + items, err := GetComponentsContext(context.Background(), rm, "maven-central") if err != nil { panic(err) } diff --git a/rm/email.go b/rm/email.go index fc7622b..0992bbb 100644 --- a/rm/email.go +++ b/rm/email.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -24,24 +25,27 @@ type EmailConfig struct { NexusTrustStoreEnabled bool `json:"nexusTrustStoreEnabled"` } -func SetEmailConfig(rm RM, config EmailConfig) error { - +func SetEmailConfigContext(ctx context.Context, rm RM, config EmailConfig) error { json, err := json.Marshal(config) if err != nil { return err } - if _, resp, err := rm.Put(restEmail, bytes.NewBuffer(json)); err != nil && resp.StatusCode != http.StatusNoContent { + if _, resp, err := rm.Put(ctx, restEmail, bytes.NewBuffer(json)); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("email config not set: %v", err) } return nil } -func GetEmailConfig(rm RM) (EmailConfig, error) { +func SetEmailConfig(rm RM, config EmailConfig) error { + return SetEmailConfigContext(context.Background(), rm, config) +} + +func GetEmailConfigContext(ctx context.Context, rm RM) (EmailConfig, error) { var config EmailConfig - body, resp, err := rm.Get(restEmail) + body, resp, err := rm.Get(ctx, restEmail) if err != nil && resp.StatusCode != http.StatusNoContent { return EmailConfig{}, fmt.Errorf("email config can't getting: %v", err) } @@ -53,11 +57,18 @@ func GetEmailConfig(rm RM) (EmailConfig, error) { return config, nil } -func DeleteEmailConfig(rm RM) error { +func GetEmailConfig(rm RM) (EmailConfig, error) { + return GetEmailConfigContext(context.Background(), rm) +} - if resp, err := rm.Del(restEmail); err != nil && resp.StatusCode != http.StatusNoContent { +func DeleteEmailConfigContext(ctx context.Context, rm RM) error { + if resp, err := rm.Del(ctx, restEmail); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("email config not deleted: %v", err) } return nil } + +func DeleteEmailConfig(rm RM) error { + return DeleteEmailConfigContext(context.Background(), rm) +} diff --git a/rm/groovyBlobStore.go b/rm/groovyBlobStore.go index 3ddc1ef..a41649e 100644 --- a/rm/groovyBlobStore.go +++ b/rm/groovyBlobStore.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "fmt" "text/template" ) @@ -54,13 +55,12 @@ func DeleteBlobStore(rm RM, name string) error { return err } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(rm, newAnonGroovyScript(buf.String()), nil) return err } */ -// CreateFileBlobStore creates a blobstore -func CreateFileBlobStore(rm RM, name, path string) error { +func CreateFileBlobStoreContext(ctx context.Context, rm RM, name, path string) error { tmpl, err := template.New("fbs").Parse(groovyCreateFileBlobStore) if err != nil { return fmt.Errorf("could not parse template: %v", err) @@ -72,12 +72,16 @@ func CreateFileBlobStore(rm RM, name, path string) error { return fmt.Errorf("could not create file blobstore from template: %v", err) } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(ctx, rm, newAnonGroovyScript(buf.String()), nil) return fmt.Errorf("could not create file blobstore: %v", err) } -// CreateBlobStoreGroup creates a blobstore -func CreateBlobStoreGroup(rm RM, name string, blobStores []string) error { +// CreateFileBlobStore creates a blobstore +func CreateFileBlobStore(rm RM, name, path string) error { + return CreateFileBlobStoreContext(context.Background(), rm, name, path) +} + +func CreateBlobStoreGroupContext(ctx context.Context, rm RM, name string, blobStores []string) error { tmpl, err := template.New("group").Parse(groovyCreateBlobStoreGroup) if err != nil { return fmt.Errorf("could not parse template: %v", err) @@ -89,6 +93,11 @@ func CreateBlobStoreGroup(rm RM, name string, blobStores []string) error { return fmt.Errorf("could not create group blobstore from template: %v", err) } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(ctx, rm, newAnonGroovyScript(buf.String()), nil) return fmt.Errorf("could not create group blobstore: %v", err) } + +// CreateBlobStoreGroup creates a blobstore group +func CreateBlobStoreGroup(rm RM, name string, blobStores []string) error { + return CreateBlobStoreGroupContext(context.Background(), rm, name, blobStores) +} diff --git a/rm/groovyBlobStore_test.go b/rm/groovyBlobStore_test.go index 22415df..da959e3 100644 --- a/rm/groovyBlobStore_test.go +++ b/rm/groovyBlobStore_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "testing" ) @@ -9,7 +10,7 @@ func TestCreateFileBlobStore(t *testing.T) { rm, mock := repositoriesTestRM(t) defer mock.Close() - err := CreateFileBlobStore(rm, "testname", "testpath") + err := CreateFileBlobStoreContext(context.Background(), rm, "testname", "testpath") if err != nil { t.Error(err) } @@ -22,11 +23,11 @@ func TestCreateBlobStoreGroup(t *testing.T) { rm, mock := repositoriesTestRM(t) defer mock.Close() - CreateFileBlobStore(rm, "f1", "pathf1") - CreateFileBlobStore(rm, "f2", "pathf2") - CreateFileBlobStore(rm, "f3", "pathf3") + CreateFileBlobStoreContext(context.Background(), rm, "f1", "pathf1") + CreateFileBlobStoreContext(context.Background(), rm, "f2", "pathf2") + CreateFileBlobStoreContext(context.Background(), rm, "f3", "pathf3") - err := CreateBlobStoreGroup(rm, "grpname", []string{"f1", "f2", "f3"}) + err := CreateBlobStoreGroupContext(context.Background(), rm, "grpname", []string{"f1", "f2", "f3"}) if err != nil { t.Error(err) } diff --git a/rm/groovyRepository.go b/rm/groovyRepository.go index 50b6d94..af08934 100644 --- a/rm/groovyRepository.go +++ b/rm/groovyRepository.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "fmt" "text/template" ) @@ -71,8 +72,7 @@ type repositoryGroup struct { Members []string } -// CreateHostedRepository creates a hosted repository of the indicated format -func CreateHostedRepository(rm RM, format repositoryFormat, config repositoryHosted) error { +func CreateHostedRepositoryContext(ctx context.Context, rm RM, format repositoryFormat, config repositoryHosted) error { var groovyTmpl string switch format { case Maven: @@ -112,12 +112,16 @@ func CreateHostedRepository(rm RM, format repositoryFormat, config repositoryHos return fmt.Errorf("could not create hosted repository from template: %v", err) } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(ctx, rm, newAnonGroovyScript(buf.String()), nil) return fmt.Errorf("could not create hosted repository: %v", err) } -// CreateProxyRepository creates a proxy repository of the indicated format -func CreateProxyRepository(rm RM, format repositoryFormat, config repositoryProxy) error { +// CreateHostedRepository creates a hosted repository of the indicated format +func CreateHostedRepository(rm RM, format repositoryFormat, config repositoryHosted) error { + return CreateHostedRepositoryContext(context.Background(), rm, format, config) +} + +func CreateProxyRepositoryContext(ctx context.Context, rm RM, format repositoryFormat, config repositoryProxy) error { var groovyTmpl string switch format { case Maven: @@ -157,12 +161,16 @@ func CreateProxyRepository(rm RM, format repositoryFormat, config repositoryProx return fmt.Errorf("could not create proxy repository from template: %v", err) } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(ctx, rm, newAnonGroovyScript(buf.String()), nil) return fmt.Errorf("could not create proxy repository: %v", err) } -// CreateGroupRepository creates a group repository of the indicated format -func CreateGroupRepository(rm RM, format repositoryFormat, config repositoryGroup) error { +// CreateProxyRepository creates a proxy repository of the indicated format +func CreateProxyRepository(rm RM, format repositoryFormat, config repositoryProxy) error { + return CreateProxyRepositoryContext(context.Background(), rm, format, config) +} + +func CreateGroupRepositoryContext(ctx context.Context, rm RM, format repositoryFormat, config repositoryGroup) error { var groovyTmpl string switch format { case Maven: @@ -202,6 +210,11 @@ func CreateGroupRepository(rm RM, format repositoryFormat, config repositoryGrou return fmt.Errorf("could not create group repository from template: %v", err) } - _, err = ScriptRunOnce(rm, newAnonGroovyScript(buf.String()), nil) + _, err = ScriptRunOnceContext(ctx, rm, newAnonGroovyScript(buf.String()), nil) return fmt.Errorf("could not create group repository: %v", err) } + +// CreateGroupRepository creates a group repository of the indicated format +func CreateGroupRepository(rm RM, format repositoryFormat, config repositoryGroup) error { + return CreateGroupRepositoryContext(context.Background(), rm, format, config) +} diff --git a/rm/groovyRepository_test.go b/rm/groovyRepository_test.go index c817be5..74233b8 100644 --- a/rm/groovyRepository_test.go +++ b/rm/groovyRepository_test.go @@ -1,16 +1,12 @@ package nexusrm -import ( -// "testing" -) - /* func TestCreateFileBlobStore(t *testing.T) { t.Skip("Needs new framework") rm, mock := repositoriesTestRM(t) defer mock.Close() - err := CreateFileBlobStore(rm, "testname", "testpath") + err := CreateFileBlobStoreContext(rm, "testname", "testpath") if err != nil { t.Error(err) } diff --git a/rm/maintenance.go b/rm/maintenance.go index 105ee94..0f751c7 100644 --- a/rm/maintenance.go +++ b/rm/maintenance.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -22,8 +23,7 @@ type DatabaseState struct { IndexErrors int `json:"indexErrors"` } -// CheckDatabase returns the state of the named database -func CheckDatabase(rm RM, dbName string) (DatabaseState, error) { +func CheckDatabaseContext(ctx context.Context, rm RM, dbName string) (DatabaseState, error) { doError := func(err error) error { return fmt.Errorf("error checking status of database '%s': %v", dbName, err) } @@ -31,7 +31,7 @@ func CheckDatabase(rm RM, dbName string) (DatabaseState, error) { var state DatabaseState url := fmt.Sprintf(restMaintenanceDBCheck, dbName) - body, resp, err := rm.Put(url, nil) + body, resp, err := rm.Put(ctx, url, nil) if err != nil || resp.StatusCode != http.StatusOK { return state, doError(err) } @@ -43,8 +43,12 @@ func CheckDatabase(rm RM, dbName string) (DatabaseState, error) { return state, nil } -// CheckAllDatabases returns state on all of the databases -func CheckAllDatabases(rm RM) (states map[string]DatabaseState, err error) { +// CheckDatabase returns the state of the named database +func CheckDatabase(rm RM, dbName string) (DatabaseState, error) { + return CheckDatabaseContext(context.Background(), rm, dbName) +} + +func CheckAllDatabasesContext(ctx context.Context, rm RM) (states map[string]DatabaseState, err error) { states = make(map[string]DatabaseState) check := func(dbName string) { @@ -52,7 +56,7 @@ func CheckAllDatabases(rm RM) (states map[string]DatabaseState, err error) { return } - if state, er := CheckDatabase(rm, dbName); er != nil { + if state, er := CheckDatabaseContext(ctx, rm, dbName); er != nil { err = fmt.Errorf("error with '%s' database when all states: %v", dbName, er) } else { states[dbName] = state @@ -66,3 +70,8 @@ func CheckAllDatabases(rm RM) (states map[string]DatabaseState, err error) { return } + +// CheckAllDatabases returns state on all of the databases +func CheckAllDatabases(rm RM) (states map[string]DatabaseState, err error) { + return CheckAllDatabasesContext(context.Background(), rm) +} diff --git a/rm/maintenance_test.go b/rm/maintenance_test.go index c125193..9420de9 100644 --- a/rm/maintenance_test.go +++ b/rm/maintenance_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -50,7 +51,7 @@ func TestCheckDatabase(t *testing.T) { db := ComponentDB - state, err := CheckDatabase(rm, db) + state, err := CheckDatabaseContext(context.Background(), rm, db) if err != nil { panic(err) } @@ -68,7 +69,7 @@ func TestCheckAllDatabases(t *testing.T) { rm, mock := maintenanceTestRM(t) defer mock.Close() - states, err := CheckAllDatabases(rm) + states, err := CheckAllDatabasesContext(context.Background(), rm) if err != nil { panic(err) } diff --git a/rm/readOnly.go b/rm/readOnly.go index 103d057..49035c3 100644 --- a/rm/readOnly.go +++ b/rm/readOnly.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -39,9 +40,8 @@ func (s ReadOnlyState) String() string { return buf.String() } -// GetReadOnlyState returns the read-only state of the RM instance -func GetReadOnlyState(rm RM) (state ReadOnlyState, err error) { - body, resp, err := rm.Get(restReadOnly) +func GetReadOnlyStateContext(ctx context.Context, rm RM) (state ReadOnlyState, err error) { + body, resp, err := rm.Get(ctx, restReadOnly) if err != nil { return state, fmt.Errorf("could not get read-only state: %v", err) } @@ -55,9 +55,13 @@ func GetReadOnlyState(rm RM) (state ReadOnlyState, err error) { return } -// ReadOnlyEnable enables read-only mode for the RM instance -func ReadOnlyEnable(rm RM) (state ReadOnlyState, err error) { - body, resp, err := rm.Post(restReadOnlyFreeze, nil) +// GetReadOnlyState returns the read-only state of the RM instance +func GetReadOnlyState(rm RM) (state ReadOnlyState, err error) { + return GetReadOnlyStateContext(context.Background(), rm) +} + +func ReadOnlyEnableContext(ctx context.Context, rm RM) (state ReadOnlyState, err error) { + body, resp, err := rm.Post(ctx, restReadOnlyFreeze, nil) if err != nil && resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusNotFound { return } @@ -67,14 +71,18 @@ func ReadOnlyEnable(rm RM) (state ReadOnlyState, err error) { return } -// ReadOnlyRelease disables read-only mode for the RM instance -func ReadOnlyRelease(rm RM, force bool) (state ReadOnlyState, err error) { +// ReadOnlyEnable enables read-only mode for the RM instance +func ReadOnlyEnable(rm RM) (state ReadOnlyState, err error) { + return ReadOnlyEnableContext(context.Background(), rm) +} + +func ReadOnlyReleaseContext(ctx context.Context, rm RM, force bool) (state ReadOnlyState, err error) { endpoint := restReadOnlyRelease if force { endpoint = restReadOnlyForceRelease } - body, resp, err := rm.Post(endpoint, nil) + body, resp, err := rm.Post(ctx, endpoint, nil) if err != nil && resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusNotFound { return } @@ -83,3 +91,8 @@ func ReadOnlyRelease(rm RM, force bool) (state ReadOnlyState, err error) { return } + +// ReadOnlyRelease disables read-only mode for the RM instance +func ReadOnlyRelease(rm RM, force bool) (state ReadOnlyState, err error) { + return ReadOnlyReleaseContext(context.Background(), rm, force) +} diff --git a/rm/repositories.go b/rm/repositories.go index ff0678f..353ff45 100644 --- a/rm/repositories.go +++ b/rm/repositories.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -161,7 +162,7 @@ type RepositoryRawHosted struct { Raw AttributesRaw `json:"raw"` } -func CreateRepositoryHosted(rm RM, format repositoryFormat, r interface{}) error { +func CreateRepositoryHostedContext(ctx context.Context, rm RM, format repositoryFormat, r interface{}) error { buf, err := json.Marshal(r) if err != nil { return fmt.Errorf("could not marshal: %v", err) @@ -196,7 +197,7 @@ func CreateRepositoryHosted(rm RM, format repositoryFormat, r interface{}) error case Yum: restEndpointRepository = restRepositoriesHostedYum } - _, resp, err := rm.Post(restEndpointRepository, bytes.NewBuffer(buf)) + _, resp, err := rm.Post(ctx, restEndpointRepository, bytes.NewBuffer(buf)) if err != nil && resp == nil { return fmt.Errorf("could not create repository: %v", err) } @@ -204,7 +205,11 @@ func CreateRepositoryHosted(rm RM, format repositoryFormat, r interface{}) error return nil } -func CreateRepositoryProxy(rm RM, format repositoryFormat, r interface{}) error { +func CreateRepositoryHosted(rm RM, format repositoryFormat, r interface{}) error { + return CreateRepositoryHostedContext(context.Background(), rm, format, r) +} + +func CreateRepositoryProxyContext(ctx context.Context, rm RM, format repositoryFormat, r interface{}) error { buf, err := json.Marshal(r) if err != nil { return fmt.Errorf("could not marshal: %v", err) @@ -247,7 +252,7 @@ func CreateRepositoryProxy(rm RM, format repositoryFormat, r interface{}) error case Yum: restEndpointRepository = restRepositoriesProxyYum } - _, resp, err := rm.Post(restEndpointRepository, bytes.NewBuffer(buf)) + _, resp, err := rm.Post(ctx, restEndpointRepository, bytes.NewBuffer(buf)) if err != nil && resp == nil { return fmt.Errorf("could not create repository: %v", err) } @@ -255,23 +260,30 @@ func CreateRepositoryProxy(rm RM, format repositoryFormat, r interface{}) error return nil } -func DeleteRepositoryByName(rm RM, name string) error { +func CreateRepositoryProxy(rm RM, format repositoryFormat, r interface{}) error { + return CreateRepositoryProxyContext(context.Background(), rm, format, r) +} + +func DeleteRepositoryByNameContext(ctx context.Context, rm RM, name string) error { url := fmt.Sprintf("%s/%s", restRepositories, name) - if resp, err := rm.Del(url); err != nil && resp.StatusCode != http.StatusNoContent { + if resp, err := rm.Del(ctx, url); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("repository not deleted '%s': %v", name, err) } return nil } -// GetRepositories returns a list of components in the indicated repository -func GetRepositories(rm RM) ([]Repository, error) { +func DeleteRepositoryByName(rm RM, name string) error { + return DeleteRepositoryByNameContext(context.Background(), rm, name) +} + +func GetRepositoriesContext(ctx context.Context, rm RM) ([]Repository, error) { doError := func(err error) error { return fmt.Errorf("could not find repositories: %v", err) } - body, resp, err := rm.Get(restRepositories) + body, resp, err := rm.Get(ctx, restRepositories) if err != nil || resp.StatusCode != http.StatusOK { return nil, doError(err) } @@ -284,9 +296,13 @@ func GetRepositories(rm RM) ([]Repository, error) { return repos, nil } -// GetRepositoryByName returns information on a named repository -func GetRepositoryByName(rm RM, name string) (repo Repository, err error) { - repos, err := GetRepositories(rm) +// GetRepositories returns a list of components in the indicated repository +func GetRepositories(rm RM) ([]Repository, error) { + return GetRepositoriesContext(context.Background(), rm) +} + +func GetRepositoryByNameContext(ctx context.Context, rm RM, name string) (repo Repository, err error) { + repos, err := GetRepositoriesContext(ctx, rm) if err != nil { return repo, fmt.Errorf("could not get list of repositories: %v", err) } @@ -299,3 +315,8 @@ func GetRepositoryByName(rm RM, name string) (repo Repository, err error) { return repo, fmt.Errorf("did not find repository '%s': %v", name, err) } + +// GetRepositoryByName returns information on a named repository +func GetRepositoryByName(rm RM, name string) (repo Repository, err error) { + return GetRepositoryByNameContext(context.Background(), rm, name) +} diff --git a/rm/repositories_test.go b/rm/repositories_test.go index ad97184..0f1dea0 100644 --- a/rm/repositories_test.go +++ b/rm/repositories_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -38,7 +39,7 @@ func TestGetRepositories(t *testing.T) { rm, mock := repositoriesTestRM(t) defer mock.Close() - repos, err := GetRepositories(rm) + repos, err := GetRepositoriesContext(context.Background(), rm) if err != nil { t.Error(err) } @@ -57,7 +58,7 @@ func TestGetRepositoryByName(t *testing.T) { dummyRepoIdx := 0 - repo, err := GetRepositoryByName(rm, dummyRepos[dummyRepoIdx].Name) + repo, err := GetRepositoryByNameContext(context.Background(), rm, dummyRepos[dummyRepoIdx].Name) if err != nil { t.Error(err) } diff --git a/rm/roles.go b/rm/roles.go index 47b8dd6..8ecc650 100644 --- a/rm/roles.go +++ b/rm/roles.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -17,13 +18,13 @@ type Role struct { Roles []string `json:"roles"` } -func CreateRole(rm RM, role Role) error { +func CreateRoleContext(ctx context.Context, rm RM, role Role) error { json, err := json.Marshal(role) if err != nil { return err } - _, resp, err := rm.Post(restRole, bytes.NewBuffer(json)) + _, resp, err := rm.Post(ctx, restRole, bytes.NewBuffer(json)) if err != nil && resp.StatusCode != http.StatusNoContent { return err } @@ -31,12 +32,20 @@ func CreateRole(rm RM, role Role) error { return nil } -func DeleteRoleById(rm RM, id string) error { +func CreateRole(rm RM, role Role) error { + return CreateRoleContext(context.Background(), rm, role) +} + +func DeleteRoleByIdContext(ctx context.Context, rm RM, id string) error { url := fmt.Sprintf("%s/%s", restRole, id) - if resp, err := rm.Del(url); err != nil && resp.StatusCode != http.StatusNoContent { + if resp, err := rm.Del(ctx, url); err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("role not deleted '%s': %v", id, err) } return nil } + +func DeleteRoleById(rm RM, id string) error { + return DeleteRoleByIdContext(context.Background(), rm, id) +} diff --git a/rm/scripts.go b/rm/scripts.go index b685176..a041c95 100644 --- a/rm/scripts.go +++ b/rm/scripts.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -24,13 +25,12 @@ type runResponse struct { Result string `json:"result"` } -// ScriptList lists all of the uploaded scripts in Repository Manager -func ScriptList(rm RM) ([]Script, error) { +func ScriptListContext(ctx context.Context, rm RM) ([]Script, error) { doError := func(err error) error { return fmt.Errorf("could not list scripts: %v", err) } - body, _, err := rm.Get(restScript) + body, _, err := rm.Get(ctx, restScript) if err != nil { return nil, doError(err) } @@ -43,8 +43,12 @@ func ScriptList(rm RM) ([]Script, error) { return scripts, nil } -// ScriptGet returns the named script -func ScriptGet(rm RM, name string) (Script, error) { +// ScriptList lists all of the uploaded scripts in Repository Manager +func ScriptList(rm RM) ([]Script, error) { + return ScriptListContext(context.Background(), rm) +} + +func ScriptGetContext(ctx context.Context, rm RM, name string) (Script, error) { doError := func(err error) error { return fmt.Errorf("could not find script '%s': %v", name, err) } @@ -52,7 +56,7 @@ func ScriptGet(rm RM, name string) (Script, error) { var script Script endpoint := fmt.Sprintf("%s/%s", restScript, name) - body, _, err := rm.Get(endpoint) + body, _, err := rm.Get(ctx, endpoint) if err != nil { return script, doError(err) } @@ -64,8 +68,12 @@ func ScriptGet(rm RM, name string) (Script, error) { return script, nil } -// ScriptUpload uploads the given Script to Repository Manager -func ScriptUpload(rm RM, script Script) error { +// ScriptGet returns the named script +func ScriptGet(rm RM, name string) (Script, error) { + return ScriptGetContext(context.Background(), rm, name) +} + +func ScriptUploadContext(ctx context.Context, rm RM, script Script) error { doError := func(err error) error { return fmt.Errorf("could not upload script '%s': %v", script.Name, err) } @@ -75,7 +83,7 @@ func ScriptUpload(rm RM, script Script) error { return doError(err) } - _, resp, err := rm.Post(restScript, bytes.NewBuffer(json)) + _, resp, err := rm.Post(ctx, restScript, bytes.NewBuffer(json)) if err != nil && resp.StatusCode != http.StatusNoContent { return doError(err) } @@ -83,8 +91,12 @@ func ScriptUpload(rm RM, script Script) error { return nil } -// ScriptUpdate update the contents of the given script -func ScriptUpdate(rm RM, script Script) error { +// ScriptUpload uploads the given Script to Repository Manager +func ScriptUpload(rm RM, script Script) error { + return ScriptUploadContext(context.Background(), rm, script) +} + +func ScriptUpdateContext(ctx context.Context, rm RM, script Script) error { doError := func(err error) error { return fmt.Errorf("could not update script '%s': %v", script.Name, err) } @@ -95,7 +107,7 @@ func ScriptUpdate(rm RM, script Script) error { } endpoint := fmt.Sprintf("%s/%s", restScript, script.Name) - _, resp, err := rm.Put(endpoint, bytes.NewBuffer(json)) + _, resp, err := rm.Put(ctx, endpoint, bytes.NewBuffer(json)) if err != nil && resp.StatusCode != http.StatusNoContent { return doError(err) } @@ -103,14 +115,18 @@ func ScriptUpdate(rm RM, script Script) error { return nil } -// ScriptRun executes the named Script -func ScriptRun(rm RM, name string, arguments []byte) (string, error) { +// ScriptUpdate update the contents of the given script +func ScriptUpdate(rm RM, script Script) error { + return ScriptUpdateContext(context.Background(), rm, script) +} + +func ScriptRunContext(ctx context.Context, rm RM, name string, arguments []byte) (string, error) { doError := func(err error) error { return fmt.Errorf("could not run script '%s': %v", name, err) } endpoint := fmt.Sprintf(restScriptRun, name) - body, _, err := rm.Post(endpoint, bytes.NewBuffer(arguments)) // TODO: Better response handling + body, _, err := rm.Post(ctx, endpoint, bytes.NewBuffer(arguments)) // TODO: Better response handling if err != nil { return "", doError(err) } @@ -123,22 +139,35 @@ func ScriptRun(rm RM, name string, arguments []byte) (string, error) { return resp.Result, nil } -// ScriptRunOnce takes the given Script, uploads it, executes it, and deletes it -func ScriptRunOnce(rm RM, script Script, arguments []byte) (string, error) { - if err := ScriptUpload(rm, script); err != nil { +// ScriptRun executes the named Script +func ScriptRun(rm RM, name string, arguments []byte) (string, error) { + return ScriptRunContext(context.Background(), rm, name, arguments) +} + +func ScriptRunOnceContext(ctx context.Context, rm RM, script Script, arguments []byte) (string, error) { + if err := ScriptUploadContext(ctx, rm, script); err != nil { return "", err } - defer ScriptDelete(rm, script.Name) + defer ScriptDeleteContext(ctx, rm, script.Name) - return ScriptRun(rm, script.Name, arguments) + return ScriptRunContext(ctx, rm, script.Name, arguments) } -// ScriptDelete removes the name, uploaded script -func ScriptDelete(rm RM, name string) error { +// ScriptRunOnce takes the given Script, uploads it, executes it, and deletes it +func ScriptRunOnce(rm RM, script Script, arguments []byte) (string, error) { + return ScriptRunOnceContext(context.Background(), rm, script, arguments) +} + +func ScriptDeleteContext(ctx context.Context, rm RM, name string) error { endpoint := fmt.Sprintf("%s/%s", restScript, name) - resp, err := rm.Del(endpoint) + resp, err := rm.Del(ctx, endpoint) if err != nil && resp.StatusCode != http.StatusNoContent { return fmt.Errorf("could not delete '%s': %v", name, err) } return nil } + +// ScriptDelete removes the name, uploaded script +func ScriptDelete(rm RM, name string) error { + return ScriptDeleteContext(context.Background(), rm, name) +} diff --git a/rm/scripts_test.go b/rm/scripts_test.go index f02ffea..896ba89 100644 --- a/rm/scripts_test.go +++ b/rm/scripts_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -128,7 +129,7 @@ func TestScriptList(t *testing.T) { rm, mock := scriptsTestRM(t) defer mock.Close() - scripts, err := ScriptList(rm) + scripts, err := ScriptListContext(context.Background(), rm) if err != nil { t.Error(err) } @@ -151,7 +152,7 @@ func TestScriptGet(t *testing.T) { dummyScriptsIdx := 1 - script, err := ScriptGet(rm, dummyScripts[dummyScriptsIdx].Name) + script, err := ScriptGetContext(context.Background(), rm, dummyScripts[dummyScriptsIdx].Name) if err != nil { t.Error(err) } @@ -168,11 +169,11 @@ func TestScriptUpload(t *testing.T) { newScript := Script{Name: "newScript", Content: "log.info('I am new!')", Type: "groovy"} - if err := ScriptUpload(rm, newScript); err != nil { + if err := ScriptUploadContext(context.Background(), rm, newScript); err != nil { t.Error(err) } - script, err := ScriptGet(rm, newScript.Name) + script, err := ScriptGetContext(context.Background(), rm, newScript.Name) if err != nil { t.Error(err) } @@ -197,11 +198,11 @@ func TestScriptUpdate(t *testing.T) { t.Fatal("I am an idiot") } - if err := ScriptUpdate(rm, updatedScript); err != nil { + if err := ScriptUpdateContext(context.Background(), rm, updatedScript); err != nil { t.Error(err) } - script, err := ScriptGet(rm, updatedScript.Name) + script, err := ScriptGetContext(context.Background(), rm, updatedScript.Name) if err != nil { t.Error(err) } @@ -218,15 +219,15 @@ func TestScriptDelete(t *testing.T) { deleteMe := Script{Name: "deleteMe", Content: "log.info('Existence is pain!')", Type: "groovy"} - if err := ScriptUpload(rm, deleteMe); err != nil { + if err := ScriptUploadContext(context.Background(), rm, deleteMe); err != nil { t.Error(err) } - if err := ScriptDelete(rm, deleteMe.Name); err != nil { + if err := ScriptDeleteContext(context.Background(), rm, deleteMe.Name); err != nil { t.Error(err) } - if _, err := ScriptGet(rm, deleteMe.Name); err == nil { + if _, err := ScriptGetContext(context.Background(), rm, deleteMe.Name); err == nil { t.Error("Found script which should have been deleted") } } @@ -238,11 +239,11 @@ func TestScriptRun(t *testing.T) { script := Script{Name: "scriptArgsTest", Content: "return args", Type: "groovy"} input := "this is a test" - if err := ScriptUpload(rm, script); err != nil { + if err := ScriptUploadContext(context.Background(), rm, script); err != nil { t.Error(err) } - ret, err := ScriptRun(rm, script.Name, []byte(input)) + ret, err := ScriptRunContext(context.Background(), rm, script.Name, []byte(input)) if err != nil { t.Error(err) } @@ -251,7 +252,7 @@ func TestScriptRun(t *testing.T) { t.Errorf("Did not get expected script output: %s\n", ret) } - if err = ScriptDelete(rm, script.Name); err != nil { + if err = ScriptDeleteContext(context.Background(), rm, script.Name); err != nil { t.Error(err) } } @@ -263,7 +264,7 @@ func TestScriptRunOnce(t *testing.T) { script := Script{Name: "scriptArgsTest", Content: "return args", Type: "groovy"} input := "this is a test" - ret, err := ScriptRunOnce(rm, script, []byte(input)) + ret, err := ScriptRunOnceContext(context.Background(), rm, script, []byte(input)) if err != nil { t.Error(err) } @@ -272,7 +273,7 @@ func TestScriptRunOnce(t *testing.T) { t.Errorf("Did not get expected script output: %s\n", ret) } - if _, err = ScriptGet(rm, script.Name); err == nil { + if _, err = ScriptGetContext(context.Background(), rm, script.Name); err == nil { t.Error("Found script which should have been deleted") } } diff --git a/rm/search.go b/rm/search.go index c8b7c94..da67166 100644 --- a/rm/search.go +++ b/rm/search.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -104,7 +105,7 @@ func NewSearchQueryBuilder() *SearchQueryBuilder { return b } -func search(rm RM, endpoint string, queryBuilder nexus.SearchQueryBuilder, responseHandler func([]byte) (string, error)) error { +func search(ctx context.Context, rm RM, endpoint string, queryBuilder nexus.SearchQueryBuilder, responseHandler func([]byte) (string, error)) error { continuation := "" queryEndpoint := fmt.Sprintf("%s?%s", endpoint, queryBuilder.Build()) @@ -115,7 +116,7 @@ func search(rm RM, endpoint string, queryBuilder nexus.SearchQueryBuilder, respo url += "&continuationToken=" + continuation } - body, resp, err := rm.Get(url) + body, resp, err := rm.Get(ctx, url) if err != nil || resp.StatusCode != http.StatusOK { return } @@ -142,11 +143,10 @@ func search(rm RM, endpoint string, queryBuilder nexus.SearchQueryBuilder, respo return nil } -// SearchComponents allows searching the indicated RM instance for specific components -func SearchComponents(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItem, error) { +func SearchComponentsContext(ctx context.Context, rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItem, error) { items := make([]RepositoryItem, 0) - err := search(rm, restSearchComponents, query, func(body []byte) (string, error) { + err := search(ctx, rm, restSearchComponents, query, func(body []byte) (string, error) { var resp searchComponentsResponse if er := json.Unmarshal(body, &resp); er != nil { return "", er @@ -160,11 +160,15 @@ func SearchComponents(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItem, return items, err } -// SearchAssets allows searching the indicated RM instance for specific assets -func SearchAssets(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItemAsset, error) { +// SearchComponents allows searching the indicated RM instance for specific components +func SearchComponents(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItem, error) { + return SearchComponentsContext(context.Background(), rm, query) +} + +func SearchAssetsContext(ctx context.Context, rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItemAsset, error) { items := make([]RepositoryItemAsset, 0) - err := search(rm, restSearchAssets, query, func(body []byte) (string, error) { + err := search(ctx, rm, restSearchAssets, query, func(body []byte) (string, error) { var resp searchAssetsResponse if er := json.Unmarshal(body, &resp); er != nil { return "", er @@ -177,3 +181,8 @@ func SearchAssets(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItemAsset, return items, err } + +// SearchAssets allows searching the indicated RM instance for specific assets +func SearchAssets(rm RM, query nexus.SearchQueryBuilder) ([]RepositoryItemAsset, error) { + return SearchAssetsContext(context.Background(), rm, query) +} diff --git a/rm/search_test.go b/rm/search_test.go index a2cd3de..9613827 100644 --- a/rm/search_test.go +++ b/rm/search_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "net/http" @@ -78,7 +79,7 @@ func TestSearchComponents(t *testing.T) { repo := "repo-maven" query := NewSearchQueryBuilder().Repository(repo) - components, err := SearchComponents(rm, query) + components, err := SearchComponentsContext(context.Background(), rm, query) if err != nil { t.Fatalf("Did not complete search: %v", err) } @@ -103,7 +104,7 @@ func TestSearchAssets(t *testing.T) { repo := "repo-maven" query := NewSearchQueryBuilder().Repository(repo) - assets, err := SearchAssets(rm, query) + assets, err := SearchAssetsContext(context.Background(), rm, query) if err != nil { t.Error(err) } @@ -128,7 +129,7 @@ func ExampleSearchComponents() { } query := NewSearchQueryBuilder().Repository("maven-releases") - components, err := SearchComponents(rm, query) + components, err := SearchComponentsContext(context.Background(), rm, query) if err != nil { panic(err) } diff --git a/rm/staging.go b/rm/staging.go index 50eef0c..a1c237c 100644 --- a/rm/staging.go +++ b/rm/staging.go @@ -1,6 +1,9 @@ package nexusrm -import "fmt" +import ( + "context" + "fmt" +) // service/rest/v1/staging/move/{repository} const ( @@ -38,19 +41,27 @@ type componentsDeleted struct { Version string `json:"version"` } -// StagingMove promotes components which match a set of criteria -func StagingMove(rm RM, query QueryBuilder) error { +func StagingMoveContext(ctx context.Context, rm RM, query QueryBuilder) error { endpoint := fmt.Sprintf("%s?%s", restStaging, query.Build()) // TODO: handle response - _, _, err := rm.Post(endpoint, nil) + _, _, err := rm.Post(ctx, endpoint, nil) return err } -// StagingDelete removes components which have been staged -func StagingDelete(rm RM, query QueryBuilder) error { +// StagingMove promotes components which match a set of criteria +func StagingMove(rm RM, query QueryBuilder) error { + return StagingMoveContext(context.Background(), rm, query) +} + +func StagingDeleteContext(ctx context.Context, rm RM, query QueryBuilder) error { endpoint := fmt.Sprintf("%s?%s", restStaging, query.Build()) - _, err := rm.Del(endpoint) + _, err := rm.Del(ctx, endpoint) return err } + +// StagingDelete removes components which have been staged +func StagingDelete(rm RM, query QueryBuilder) error { + return StagingDeleteContext(context.Background(), rm, query) +} diff --git a/rm/status.go b/rm/status.go index 6db09c3..c8a7306 100644 --- a/rm/status.go +++ b/rm/status.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "net/http" ) @@ -9,14 +10,22 @@ const ( restStatusWritable = "service/rest/v1/status/writable" ) +func StatusReadableContext(ctx context.Context, rm RM) (_ bool) { + _, resp, err := rm.Get(ctx, restStatusReadable) + return err == nil && resp.StatusCode == http.StatusOK +} + // StatusReadable returns true if the RM instance can serve read requests func StatusReadable(rm RM) (_ bool) { - _, resp, err := rm.Get(restStatusReadable) + return StatusReadableContext(context.Background(), rm) +} + +func StatusWritableContext(ctx context.Context, rm RM) (_ bool) { + _, resp, err := rm.Get(ctx, restStatusWritable) return err == nil && resp.StatusCode == http.StatusOK } // StatusWritable returns true if the RM instance can serve read requests func StatusWritable(rm RM) (_ bool) { - _, resp, err := rm.Get(restStatusWritable) - return err == nil && resp.StatusCode == http.StatusOK + return StatusWritableContext(context.Background(), rm) } diff --git a/rm/support.go b/rm/support.go index 88e6b3e..355a977 100644 --- a/rm/support.go +++ b/rm/support.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" "mime" @@ -41,14 +42,13 @@ func NewSupportZipOptions() (o SupportZipOptions) { return } -// GetSupportZip generates a support zip with the given options -func GetSupportZip(rm RM, options SupportZipOptions) ([]byte, string, error) { +func GetSupportZipContext(ctx context.Context, rm RM, options SupportZipOptions) ([]byte, string, error) { request, err := json.Marshal(options) if err != nil { return nil, "", fmt.Errorf("error retrieving support zip: %v", err) } - body, resp, err := rm.Post(restSupportZip, bytes.NewBuffer(request)) + body, resp, err := rm.Post(ctx, restSupportZip, bytes.NewBuffer(request)) if err != nil { return nil, "", fmt.Errorf("error retrieving support zip: %v", err) } @@ -63,3 +63,8 @@ func GetSupportZip(rm RM, options SupportZipOptions) ([]byte, string, error) { return body, params["filename"], nil } + +// GetSupportZip generates a support zip with the given options +func GetSupportZip(rm RM, options SupportZipOptions) ([]byte, string, error) { + return GetSupportZipContext(context.Background(), rm, options) +} diff --git a/rm/tagging.go b/rm/tagging.go index dc53353..65f4945 100644 --- a/rm/tagging.go +++ b/rm/tagging.go @@ -2,6 +2,7 @@ package nexusrm import ( "bytes" + "context" "encoding/json" "fmt" ) @@ -35,8 +36,7 @@ type componentsAssociated struct { Version string `json:"version"` } -// TagsList returns a list of tags in the given RM instance -func TagsList(rm RM) ([]Tag, error) { +func TagsListContext(ctx context.Context, rm RM) ([]Tag, error) { continuation := "" tags := make([]Tag, 0) @@ -47,7 +47,7 @@ func TagsList(rm RM) ([]Tag, error) { url += "&continuationToken=" + continuation } - body, _, err := rm.Get(url) + body, _, err := rm.Get(ctx, url) if err != nil { return fmt.Errorf("could not get list of tags: %v", err) } @@ -76,8 +76,12 @@ func TagsList(rm RM) ([]Tag, error) { return tags, nil } -// AddTag adds a tag to the given instance -func AddTag(rm RM, tagName string, attributes map[string]string) (Tag, error) { +// TagsList returns a list of tags in the given RM instance +func TagsList(rm RM) ([]Tag, error) { + return TagsListContext(context.Background(), rm) +} + +func AddTagContext(ctx context.Context, rm RM, tagName string, attributes map[string]string) (Tag, error) { tag := Tag{Name: tagName} //TODO: attributes @@ -86,7 +90,7 @@ func AddTag(rm RM, tagName string, attributes map[string]string) (Tag, error) { return Tag{}, fmt.Errorf("could not marshal tag: %v", err) } - body, _, err := rm.Post(restTagging, bytes.NewBuffer(buf)) + body, _, err := rm.Post(ctx, restTagging, bytes.NewBuffer(buf)) if err != nil { return Tag{}, fmt.Errorf("could not create tag %s: %v", tagName, err) } @@ -99,11 +103,15 @@ func AddTag(rm RM, tagName string, attributes map[string]string) (Tag, error) { return createdTag, nil } -// GetTag retrieve the named tag -func GetTag(rm RM, tagName string) (Tag, error) { +// AddTag adds a tag to the given instance +func AddTag(rm RM, tagName string, attributes map[string]string) (Tag, error) { + return AddTagContext(context.Background(), rm, tagName, attributes) +} + +func GetTagContext(ctx context.Context, rm RM, tagName string) (Tag, error) { endpoint := fmt.Sprintf("%s/%s", restTagging, tagName) - body, _, err := rm.Get(endpoint) + body, _, err := rm.Get(ctx, endpoint) if err != nil { return Tag{}, fmt.Errorf("could not find tag %s: %v", tagName, err) } @@ -116,19 +124,32 @@ func GetTag(rm RM, tagName string) (Tag, error) { return tag, nil } -// AssociateTag associates a tag to any component which matches the search criteria -func AssociateTag(rm RM, query QueryBuilder) error { +// GetTag retrieve the named tag +func GetTag(rm RM, tagName string) (Tag, error) { + return GetTagContext(context.Background(), rm, tagName) +} + +func AssociateTagContext(ctx context.Context, rm RM, query QueryBuilder) error { endpoint := fmt.Sprintf("%s?%s", restTagging, query.Build()) // TODO: handle response - _, _, err := rm.Post(endpoint, nil) + _, _, err := rm.Post(ctx, endpoint, nil) return err } -// DisassociateTag associates a tag to any component which matches the search criteria -func DisassociateTag(rm RM, query QueryBuilder) error { +// AssociateTag associates a tag to any component which matches the search criteria +func AssociateTag(rm RM, query QueryBuilder) error { + return AssociateTagContext(context.Background(), rm, query) +} + +func DisassociateTagContext(ctx context.Context, rm RM, query QueryBuilder) error { endpoint := fmt.Sprintf("%s?%s", restTagging, query.Build()) - _, err := rm.Del(endpoint) + _, err := rm.Del(ctx, endpoint) return err } + +// DisassociateTag associates a tag to any component which matches the search criteria +func DisassociateTag(rm RM, query QueryBuilder) error { + return DisassociateTagContext(context.Background(), rm, query) +} diff --git a/rm/tagging_test.go b/rm/tagging_test.go index 6c3fd47..40d0ce1 100644 --- a/rm/tagging_test.go +++ b/rm/tagging_test.go @@ -1,6 +1,7 @@ package nexusrm import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -99,7 +100,7 @@ func TestTagsList(t *testing.T) { rm, mock := taggingTestRM(t) defer mock.Close() - tags, err := TagsList(rm) + tags, err := TagsListContext(context.Background(), rm) if err != nil { t.Error(err) } @@ -121,7 +122,7 @@ func TestGetTag(t *testing.T) { want := dummyTags[0] - got, err := GetTag(rm, want.Name) + got, err := GetTagContext(context.Background(), rm, want.Name) if err != nil { t.Error(err) } @@ -139,7 +140,7 @@ func TestAddTag(t *testing.T) { newName := "newTestTag" - got, err := AddTag(rm, newName, nil) + got, err := AddTagContext(context.Background(), rm, newName, nil) if err != nil { t.Error(err) } @@ -148,7 +149,7 @@ func TestAddTag(t *testing.T) { t.Error("Did not get tag with expected name") } - gotAgain, err := GetTag(rm, newName) + gotAgain, err := GetTagContext(context.Background(), rm, newName) if err != nil { t.Error(err) }