diff --git a/cmd/granted/main.go b/cmd/granted/main.go index 542602a7..ceec6041 100644 --- a/cmd/granted/main.go +++ b/cmd/granted/main.go @@ -10,12 +10,17 @@ import ( "github.com/common-fate/clio" "github.com/common-fate/clio/clierr" + "github.com/common-fate/updatecheck" + "github.com/fwdcloudsec/granted/internal/build" "github.com/fwdcloudsec/granted/pkg/assume" "github.com/fwdcloudsec/granted/pkg/granted" "github.com/urfave/cli/v2" ) func main() { + updatecheck.Check(updatecheck.GrantedCLI, build.Version, !build.IsDev()) + defer updatecheck.Print() + c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) go func() { diff --git a/go.mod b/go.mod index 3c999865..69cc33be 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 + github.com/common-fate/updatecheck v0.3.5 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 github.com/pkg/errors v0.9.1 github.com/segmentio/ksuid v1.0.4 @@ -17,10 +18,12 @@ require ( ) require ( + github.com/Masterminds/sprig/v3 v3.2.3 github.com/alessio/shellescape v1.4.2 + github.com/briandowns/spinner v1.23.0 github.com/common-fate/clio v1.2.3 - github.com/common-fate/grab v1.3.0 github.com/fatih/color v1.16.0 + github.com/google/uuid v1.6.0 github.com/hashicorp/go-version v1.7.0 github.com/schollz/progressbar/v3 v3.13.1 go.uber.org/zap v1.26.0 @@ -30,12 +33,10 @@ require ( require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.0 // indirect - github.com/Masterminds/sprig/v3 v3.2.3 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/google/go-cmp v0.6.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/huandu/xstrings v1.3.3 // indirect github.com/imdario/mergo v0.3.11 // indirect github.com/kr/pretty v0.3.1 // indirect @@ -68,7 +69,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/iam v1.28.7 github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect github.com/aws/smithy-go v1.24.1 - github.com/common-fate/awsconfigfile v0.10.0 + github.com/common-fate/useragent v0.1.0 github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/danieljoos/wincred v1.1.2 // indirect github.com/dvsekhvalnov/jose2go v1.8.0 // indirect @@ -79,11 +80,12 @@ require ( github.com/joho/godotenv v1.4.0 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-isatty v0.0.20 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect github.com/mtibben/percent v0.2.1 // indirect github.com/olekukonko/tablewriter v0.0.5 github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/sethvargo/go-retry v0.2.4 github.com/stretchr/testify v1.10.0 go.uber.org/ratelimit v0.3.0 golang.org/x/sync v0.19.0 @@ -92,3 +94,5 @@ require ( golang.org/x/text v0.34.0 gopkg.in/ini.v1 v1.67.0 ) + +replace github.com/aws/session-manager-plugin => github.com/common-fate/session-manager-plugin v0.0.0-20240723053832-3d311db99016 diff --git a/go.sum b/go.sum index 3751100a..6579687e 100644 --- a/go.sum +++ b/go.sum @@ -46,12 +46,14 @@ github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= -github.com/common-fate/awsconfigfile v0.10.0 h1:9W0JTeO0d3jNLw3Ps9U7IJwLYp4D9zcipq/sqNEWJOg= -github.com/common-fate/awsconfigfile v0.10.0/go.mod h1:znstvN26aO+KUwmdjwZ+WcmitZ7heEJb5iFdCPokAO8= +github.com/briandowns/spinner v1.23.0 h1:alDF2guRWqa/FOZZYWjlMIx2L6H0wyewPxo/CH4Pt2A= +github.com/briandowns/spinner v1.23.0/go.mod h1:rPG4gmXeN3wQV/TsAY4w8lPdIM6RX3yqeBQJSrbXjuE= github.com/common-fate/clio v1.2.3 h1:hHwUYZjn66qGYDpgANl0EB/92hyi/Jsnd07qB09rvn4= github.com/common-fate/clio v1.2.3/go.mod h1:NkozaS15SA+6Y9zb+82eIj1i41aWShorTqA01GKQ7A8= -github.com/common-fate/grab v1.3.0 h1:vGNBMfhAVAWtrLuH1stnhL4LsDb73drhegC/060q+Ok= -github.com/common-fate/grab v1.3.0/go.mod h1:6zH8GckZGFrOKfZzL4Y/2OTvxwFeL6cDtsztM0GGC2Y= +github.com/common-fate/updatecheck v0.3.5 h1:UGIKMnYwuHjbhhCaisLz1pNPg8Z1nXEoWcfqT+4LkAg= +github.com/common-fate/updatecheck v0.3.5/go.mod h1:fru9yoUXmM3QVAUdDDqKQeDoln20Pkji/7EH64gVHMs= +github.com/common-fate/useragent v0.1.0 h1:RLmkIiJXcOUJAUyXWc/zCaGbrGmlCbHBGMx99ztQ3ZU= +github.com/common-fate/useragent v0.1.0/go.mod h1:GjXGR6cDiMboDP04qlfDfA5HTbeoRSoNgQWDAyOdW9o= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -146,6 +148,8 @@ github.com/schollz/progressbar/v3 v3.13.1 h1:o8rySDYiQ59Mwzy2FELeHY5ZARXZTVJC7iH github.com/schollz/progressbar/v3 v3.13.1/go.mod h1:xvrbki8kfT1fzWzBT/UZd9L6GA+jdL7HAgq2RFnO6fQ= github.com/segmentio/ksuid v1.0.4 h1:sBo2BdShXjmcugAMwjugoGUdUV0pcxY5mW4xKRn3v4c= github.com/segmentio/ksuid v1.0.4/go.mod h1:/XUiZBD3kVx5SmUOl55voK5yeAbBNNIed+2O73XgrPE= +github.com/sethvargo/go-retry v0.2.4 h1:T+jHEQy/zKJf5s95UkguisicE0zuF9y7+/vgz08Ocec= +github.com/sethvargo/go-retry v0.2.4/go.mod h1:1afjQuvh7s4gflMObvjLPaWgluLLyhA1wmVZ6KLpICw= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= diff --git a/pkg/accessrequest/role.go b/pkg/accessrequest/role.go new file mode 100644 index 00000000..fbaae170 --- /dev/null +++ b/pkg/accessrequest/role.go @@ -0,0 +1,121 @@ +// Package accessrequest handles +// making requests to roles that a +// user doesn't have access to. +package accessrequest + +import ( + "encoding/json" + "fmt" + "net/url" + "os" + "path/filepath" + + "github.com/common-fate/clio/clierr" + "github.com/fwdcloudsec/granted/pkg/config" +) + +type Role struct { + Account string `json:"account"` + Role string `json:"role"` +} + +func (r Role) URL(dashboardURL string) string { + u, err := url.Parse(dashboardURL) + if err != nil { + return fmt.Sprintf("error building access request URL: %s", err.Error()) + } + u.Path = "access" + q := u.Query() + q.Add("type", "aws-sso") + q.Add("permissionSetArn.label", r.Role) + q.Add("accountId", r.Account) + u.RawQuery = q.Encode() + + return u.String() +} + +func (r Role) Save() error { + roleBytes, err := json.Marshal(r) + if err != nil { + return err + } + + configFolder, err := config.GrantedConfigFolder() + if err != nil { + return err + } + + file := filepath.Join(configFolder, "latest-role") + return os.WriteFile(file, roleBytes, 0644) +} + +func LatestRole() (*Role, error) { + configFolder, err := config.GrantedConfigFolder() + if err != nil { + return nil, err + } + + file := filepath.Join(configFolder, "latest-role") + + if _, err := os.Stat(file); os.IsNotExist(err) { + return nil, clierr.New("no latest role saved", clierr.Info("You can run 'assume' to try and access a role. If the role is inaccessible it will be saved as the latest role.")) + } + + roleBytes, err := os.ReadFile(file) + if err != nil { + return nil, err + } + + var r Role + err = json.Unmarshal(roleBytes, &r) + if err != nil { + return nil, err + } + + return &r, nil +} + +type Profile struct { + Name string +} + +func (p Profile) Save() error { + profileBytes, err := json.Marshal(p) + if err != nil { + return err + } + + configFolder, err := config.GrantedConfigFolder() + if err != nil { + return err + } + + file := filepath.Join(configFolder, "latest-profile") + return os.WriteFile(file, profileBytes, 0644) +} + +func LatestProfile() (*Profile, error) { + configFolder, err := config.GrantedConfigFolder() + if err != nil { + return nil, err + } + + file := filepath.Join(configFolder, "latest-profile") + + if _, err := os.Stat(file); os.IsNotExist(err) { + return nil, clierr.New("no latest profile saved", clierr.Info("You can run 'assume' to try and access a profile. If the profile is inaccessible it will be saved as the latest profile.")) + } + + profileBytes, err := os.ReadFile(file) + if err != nil { + return nil, err + } + + var p Profile + err = json.Unmarshal(profileBytes, &p) + if err != nil { + return nil, err + } + + return &p, nil +} diff --git a/pkg/assume/assume.go b/pkg/assume/assume.go index 09717345..d04bddc7 100644 --- a/pkg/assume/assume.go +++ b/pkg/assume/assume.go @@ -1,6 +1,7 @@ package assume import ( + "context" "errors" "fmt" "net/url" @@ -17,7 +18,7 @@ import ( "github.com/alessio/shellescape" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" - "github.com/common-fate/awsconfigfile" + "github.com/fwdcloudsec/granted/pkg/awsconfigfile" "github.com/common-fate/clio" "github.com/common-fate/clio/ansi" "github.com/common-fate/clio/clierr" @@ -27,11 +28,15 @@ import ( "github.com/fwdcloudsec/granted/pkg/config" "github.com/fwdcloudsec/granted/pkg/console" "github.com/fwdcloudsec/granted/pkg/forkprocess" + "github.com/fwdcloudsec/granted/pkg/hook/accessrequesthook" + "github.com/fwdcloudsec/granted/pkg/hook/httpprovider" + "github.com/fwdcloudsec/granted/pkg/providercfg" "github.com/fwdcloudsec/granted/pkg/launcher" "github.com/fwdcloudsec/granted/pkg/testable" cfflags "github.com/fwdcloudsec/granted/pkg/urfav_overrides" "github.com/fatih/color" "github.com/hako/durafmt" + sethRetry "github.com/sethvargo/go-retry" "github.com/urfave/cli/v2" "gopkg.in/ini.v1" ) @@ -149,7 +154,7 @@ func AssumeCommand(c *cli.Context) error { AccountID: profile.AWSConfig.SSOAccountID, AccountName: profile.AWSConfig.SSOAccountID, RoleName: profile.AWSConfig.SSORoleName, - GeneratedFrom: "commonfate", + GeneratedFrom: "granted-provider", }, }, }) @@ -301,6 +306,8 @@ func AssumeCommand(c *cli.Context) error { configOpts.Duration = d } + reason := assumeFlags.String("reason") + attachments := assumeFlags.StringSlice("attach") cfg, err := config.Load() if err != nil { return err @@ -308,6 +315,12 @@ func AssumeCommand(c *cli.Context) error { configOpts.UseAuthorizationCode = assumeFlags.Bool("use-authorization-code") || cfg.UseAuthorizationCode + wait := assumeFlags.Bool("wait") + retryDuration := time.Minute * 1 + if wait { + retryDuration = time.Minute * 15 + } + // if getConsoleURL is true, we'll use the AWS federated login to retrieve a URL to access the console. // depending on how Granted is configured, this is then printed to the terminal or a browser is launched at the URL automatically. getConsoleURL := !assumeFlags.Bool("env") && ((assumeFlags.Bool("console") || assumeFlags.String("console-destination") != "") || assumeFlags.Bool("active-role") || assumeFlags.String("service") != "" || assumeFlags.Bool("url")) @@ -328,7 +341,68 @@ func AssumeCommand(c *cli.Context) error { creds, err := profile.AssumeConsole(c.Context, configOpts) if err != nil && strings.HasPrefix(err.Error(), "no access") { clio.Debugw("received a No Access error", "error", err) - // TODO: this is where we can add a hook in future to allow users to define a shell script to be executed to automatically request access, etc. + + hook, hookCreateErr := accessrequesthook.NewHookFromProfile(profile, newHTTPProvider) + if hookCreateErr != nil { + return hookCreateErr + } + if hook == nil { + return err + } + + var apiDuration *time.Duration + if duration != "" { + d, err := time.ParseDuration(duration) + if err != nil { + return err + } + apiDuration = &d + } + + noAccessInput := accessrequesthook.NoAccessInput{ + Profile: profile, + Reason: reason, + Attachments: attachments, + Duration: apiDuration, + Confirm: assumeFlags.Bool("confirm"), + Wait: wait, + StartTime: time.Now(), + } + retry, justActivated, hookErr := hook.NoAccess(c.Context, noAccessInput) + if hookErr != nil { + return hookErr + } + + if retry { + // reset the start time for the timer (otherwise it shows 2s, 7s, 12s etc) + noAccessInput.StartTime = time.Now() + + b := sethRetry.NewConstant(5 * time.Second) + b = sethRetry.WithMaxDuration(retryDuration, b) + err = sethRetry.Do(c.Context, b, func(ctx context.Context) (err error) { + + if !justActivated { + err = hook.RetryAccess(ctx, noAccessInput) + if err != nil { + return sethRetry.RetryableError(err) + } + } + + creds, err = profile.AssumeConsole(c.Context, configOpts) + if err != nil { + return sethRetry.RetryableError(err) + } + + // If we successfully got credentials, mark as just activated + justActivated = true + + return nil + }) + if err != nil { + return err + } + + } } if err != nil { @@ -439,7 +513,67 @@ func AssumeCommand(c *cli.Context) error { creds, err := profile.AssumeTerminal(c.Context, configOpts) if err != nil && strings.HasPrefix(err.Error(), "no access") { clio.Debugw("received a No Access error", "error", err) - // TODO: this is where we can add a hook in future to allow users to define a shell script to be executed to automatically request access, etc. + + hook, hookCreateErr := accessrequesthook.NewHookFromProfile(profile, newHTTPProvider) + if hookCreateErr != nil { + return hookCreateErr + } + if hook == nil { + return err + } + + var apiDuration *time.Duration + if duration != "" { + d, err := time.ParseDuration(duration) + if err != nil { + return err + } + apiDuration = &d + } + noAccessInput := accessrequesthook.NoAccessInput{ + Profile: profile, + Reason: reason, + Duration: apiDuration, + Confirm: assumeFlags.Bool("confirm"), + Wait: wait, + StartTime: time.Now(), + } + retry, justActivated, hookErr := hook.NoAccess(c.Context, noAccessInput) + if hookErr != nil { + return hookErr + } + + if retry { + // reset the start time for the timer (otherwise it shows 2s, 7s, 12s etc) + noAccessInput.StartTime = time.Now() + + b := sethRetry.NewConstant(time.Second * 5) + b = sethRetry.WithMaxDuration(retryDuration, b) + err = sethRetry.Do(c.Context, b, func(ctx context.Context) (err error) { + + if !justActivated { + err = hook.RetryAccess(ctx, noAccessInput) + if err != nil { + return sethRetry.RetryableError(err) + } + } + + creds, err = profile.AssumeTerminal(c.Context, configOpts) + if err != nil { + + return sethRetry.RetryableError(err) + } + + // If we successfully got credentials, mark as just activated + justActivated = true + + return nil + }) + if err != nil { + return err + } + + } } if err != nil { @@ -600,6 +734,14 @@ func EnvKeys(creds aws.Credentials, region string) []string { "AWS_REGION=" + region} } +func newHTTPProvider(providerURL string) (accessrequesthook.AccessProvider, error) { + cfg, err := providercfg.LoadFromURL(context.Background(), providerURL) + if err != nil { + return nil, err + } + return httpprovider.New(cfg, providerURL, ""), nil +} + func filterMultiToken(filterValue string, optValue string, optIndex int) bool { optValue = strings.ToLower(optValue) filters := strings.Split(strings.ToLower(filterValue), " ") diff --git a/pkg/assume/entrypoint.go b/pkg/assume/entrypoint.go index 92f939aa..b9829294 100644 --- a/pkg/assume/entrypoint.go +++ b/pkg/assume/entrypoint.go @@ -7,6 +7,7 @@ import ( "github.com/common-fate/clio" "github.com/common-fate/clio/cliolog" + "github.com/common-fate/useragent" "github.com/fwdcloudsec/granted/internal/build" "github.com/fwdcloudsec/granted/pkg/alias" "github.com/fwdcloudsec/granted/pkg/assumeprint" @@ -156,6 +157,9 @@ func GetCliApp() *cli.App { return alias.MustBeConfigured(c.Bool("auto-configure-shell")) } + // set the user agent + c.Context = useragent.NewContext(c.Context, "granted", build.Version) + return nil }, } diff --git a/pkg/autosync/registry_config.go b/pkg/autosync/registry_config.go index 76c3bac9..feb6269c 100644 --- a/pkg/autosync/registry_config.go +++ b/pkg/autosync/registry_config.go @@ -19,7 +19,7 @@ type RegistrySyncConfig struct { LastCheckForSync time.Weekday `json:"lastCheckForSync"` } -// return the absolute path of commonfate/registry-sync file. +// return the absolute path of the registry-sync file. func (rc RegistrySyncConfig) Path() string { return path.Join(rc.dir, FILENAME) } @@ -59,10 +59,10 @@ func loadRegistryConfig() (rc RegistrySyncConfig, ok bool) { return } - rc.dir = path.Join(cd, "commonfate") + rc.dir = path.Join(cd, "granted") err = os.MkdirAll(rc.dir, os.ModePerm) if err != nil { - clio.Debug("error creating commonfate config dir: %s", err.Error()) + clio.Debug("error creating granted config dir: %s", err.Error()) return } diff --git a/pkg/awsconfigfile/awscfg.go b/pkg/awsconfigfile/awscfg.go new file mode 100644 index 00000000..05bb38ad --- /dev/null +++ b/pkg/awsconfigfile/awscfg.go @@ -0,0 +1,163 @@ +// Package awsconfigfile contains logic to template ~/.aws/config files +// based on profile sources (AWS SSO, HTTP registries, etc). +// +// Vendored from github.com/common-fate/awsconfigfile v0.10.0 and updated +// to use provider-agnostic naming. +package awsconfigfile + +import ( + "bytes" + "sort" + "strings" + "text/template" + + "github.com/Masterminds/sprig/v3" + "gopkg.in/ini.v1" +) + +type SSOProfile struct { + SSOStartURL string + SSORegion string + + Region string + AccountID string + AccountName string + RoleName string + ProviderURL string + // GeneratedFrom is the source that the profile was created from, + // such as 'aws-sso' or a named HTTP registry. + GeneratedFrom string +} + +// ToIni converts a profile to a struct with `ini` tags ready to be +// written to an ini config file. +// +// If noCredentialProcess is true, the struct will contain sso_ parameters. +// Otherwise it will contain granted_sso parameters for use with the +// Granted credential process. +func (p SSOProfile) ToIni(profileName string, noCredentialProcess bool) any { + if noCredentialProcess { + return ®ularProfile{ + SSOStartURL: p.SSOStartURL, + SSORegion: p.SSORegion, + SSOAccountID: p.AccountID, + SSORoleName: p.RoleName, + GeneratedFrom: p.GeneratedFrom, + Region: p.Region, + } + } + + credProcess := "granted credential-process --profile " + profileName + + if p.ProviderURL != "" { + credProcess += " --url " + p.ProviderURL + } + + return &credentialProcessProfile{ + SSOStartURL: p.SSOStartURL, + SSORegion: p.SSORegion, + SSOAccountID: p.AccountID, + SSORoleName: p.RoleName, + CredProcess: credProcess, + GeneratedFrom: p.GeneratedFrom, + Region: p.Region, + } +} + +type MergeOpts struct { + Config *ini.File + Prefix string + Profiles []SSOProfile + SectionNameTemplate string + NoCredentialProcess bool + // PruneStartURLs is a slice of AWS SSO start URLs which profiles are being + // generated for. Existing profiles with these start URLs will be removed if + // they aren't found in the Profiles field. + PruneStartURLs []string +} + +func Merge(opts MergeOpts) error { + if opts.SectionNameTemplate == "" { + opts.SectionNameTemplate = "{{ .AccountName }}/{{ .RoleName }}" + } + + sort.SliceStable(opts.Profiles, func(i, j int) bool { + combinedNameI := opts.Profiles[i].AccountName + "/" + opts.Profiles[i].RoleName + combinedNameJ := opts.Profiles[j].AccountName + "/" + opts.Profiles[j].RoleName + return combinedNameI < combinedNameJ + }) + + funcMap := sprig.TxtFuncMap() + sectionNameTempl, err := template.New("").Funcs(funcMap).Parse(opts.SectionNameTemplate) + if err != nil { + return err + } + + // remove any config sections that have 'common_fate_generated_from' as a key + // (legacy) or 'granted_generated_from' (current) + for _, sec := range opts.Config.Sections() { + var startURL string + + if sec.HasKey("granted_sso_start_url") { + startURL = sec.Key("granted_sso_start_url").String() + } else if sec.HasKey("sso_start_url") { + startURL = sec.Key("sso_start_url").String() + } + + for _, pruneURL := range opts.PruneStartURLs { + isGenerated := sec.HasKey("granted_generated_from") || sec.HasKey("common_fate_generated_from") + + if isGenerated && startURL == pruneURL { + opts.Config.DeleteSection(sec.Name()) + } + } + } + + for _, ssoProfile := range opts.Profiles { + ssoProfile.AccountName = normalizeAccountName(ssoProfile.AccountName) + sectionNameBuffer := bytes.NewBufferString("") + err := sectionNameTempl.Execute(sectionNameBuffer, ssoProfile) + if err != nil { + return err + } + profileName := opts.Prefix + sectionNameBuffer.String() + sectionName := "profile " + profileName + + opts.Config.DeleteSection(sectionName) + section, err := opts.Config.NewSection(sectionName) + if err != nil { + return err + } + + entry := ssoProfile.ToIni(profileName, opts.NoCredentialProcess) + err = section.ReflectFrom(entry) + if err != nil { + return err + } + } + + return nil +} + +type credentialProcessProfile struct { + SSOStartURL string `ini:"granted_sso_start_url"` + SSORegion string `ini:"granted_sso_region"` + SSOAccountID string `ini:"granted_sso_account_id"` + SSORoleName string `ini:"granted_sso_role_name"` + GeneratedFrom string `ini:"granted_generated_from"` + CredProcess string `ini:"credential_process"` + Region string `ini:"region,omitempty"` +} + +type regularProfile struct { + SSOStartURL string `ini:"sso_start_url"` + SSORegion string `ini:"sso_region"` + SSOAccountID string `ini:"sso_account_id"` + GeneratedFrom string `ini:"granted_generated_from"` + SSORoleName string `ini:"sso_role_name"` + Region string `ini:"region,omitempty"` +} + +func normalizeAccountName(accountName string) string { + return strings.ReplaceAll(accountName, " ", "-") +} diff --git a/pkg/awsconfigfile/config_path.go b/pkg/awsconfigfile/config_path.go new file mode 100644 index 00000000..e0f091db --- /dev/null +++ b/pkg/awsconfigfile/config_path.go @@ -0,0 +1,17 @@ +package awsconfigfile + +import ( + "os" + "path/filepath" +) + +// DefaultSharedConfigFilename returns the SDK's default file path for +// the shared config file (~/.aws/config). +func DefaultSharedConfigFilename() string { + return filepath.Join(userHomeDir(), ".aws", "config") +} + +func userHomeDir() string { + homedir, _ := os.UserHomeDir() + return homedir +} diff --git a/pkg/awsconfigfile/generator.go b/pkg/awsconfigfile/generator.go new file mode 100644 index 00000000..b3df890d --- /dev/null +++ b/pkg/awsconfigfile/generator.go @@ -0,0 +1,97 @@ +package awsconfigfile + +import ( + "context" + "fmt" + "regexp" + "strings" + "sync" + + "golang.org/x/sync/errgroup" + "gopkg.in/ini.v1" +) + +// Source returns AWS profiles to be combined into an AWS config file. +type Source interface { + GetProfiles(ctx context.Context) ([]SSOProfile, error) +} + +// Generator generates AWS profiles for ~/.aws/config. +// It reads profiles from sources and merges them with +// an existing ini config file. +type Generator struct { + Sources []Source + Config *ini.File + NoCredentialProcess bool + ProfileNameTemplate string + Prefix string + // PruneStartURLs is a slice of AWS SSO start URLs which profiles are being + // generated for. Existing profiles with these start URLs will be removed if + // they aren't found in the Profiles field. + PruneStartURLs []string +} + +// AddSource adds a new source to load profiles from to the generator. +func (g *Generator) AddSource(source Source) { + g.Sources = append(g.Sources, source) +} + +const profileSectionIllegalChars = ` \][;'"` + +// regular expression that matches on the characters \][;'" including whitespace, +// but does not match anything between {{ }} so it does not check inside go templates +var profileSectionIllegalCharsRegex = regexp.MustCompile(`(?s)((?:^|[^\{])[\s\][;'"]|[\][;'"][\s]*(?:$|[^\}]))`) +var matchGoTemplateSection = regexp.MustCompile(`\{\{[\s\S]*?\}\}`) + +var DefaultProfileNameTemplate = "{{ .AccountName }}/{{ .RoleName }}" + +// Generate AWS profiles and merge them with the existing config. +func (g *Generator) Generate(ctx context.Context) error { + var eg errgroup.Group + var mu sync.Mutex + var profiles []SSOProfile + + if strings.ContainsAny(g.Prefix, profileSectionIllegalChars) { + return fmt.Errorf("profile prefix must not contain any of these illegal characters (%s)", profileSectionIllegalChars) + } + + if g.ProfileNameTemplate == "" { + g.ProfileNameTemplate = DefaultProfileNameTemplate + } + + if g.ProfileNameTemplate != DefaultProfileNameTemplate { + cleaned := matchGoTemplateSection.ReplaceAllString(g.ProfileNameTemplate, "") + if profileSectionIllegalCharsRegex.MatchString(cleaned) { + return fmt.Errorf("profile template must not contain any of these illegal characters (%s)", profileSectionIllegalChars) + } + } + + for _, s := range g.Sources { + scopy := s + eg.Go(func() error { + got, err := scopy.GetProfiles(ctx) + if err != nil { + return err + } + mu.Lock() + defer mu.Unlock() + profiles = append(profiles, got...) + return nil + }) + } + + err := eg.Wait() + if err != nil { + return err + } + + err = Merge(MergeOpts{ + Config: g.Config, + SectionNameTemplate: g.ProfileNameTemplate, + Profiles: profiles, + NoCredentialProcess: g.NoCredentialProcess, + Prefix: g.Prefix, + PruneStartURLs: g.PruneStartURLs, + }) + return err +} diff --git a/pkg/config/config.go b/pkg/config/config.go index a4464e52..f7dfbec9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -120,6 +120,7 @@ type Registry struct { PrefixDuplicateProfiles bool `toml:"prefixDuplicateProfiles,omitempty"` PrefixAllProfiles bool `toml:"prefixAllProfiles,omitempty"` Type string `toml:"type,omitempty"` + TenantID string `toml:"tenantID,omitempty"` } type AWSSSOConfiguration struct { diff --git a/pkg/granted/auth/auth.go b/pkg/granted/auth/auth.go new file mode 100644 index 00000000..25e8bd90 --- /dev/null +++ b/pkg/granted/auth/auth.go @@ -0,0 +1,99 @@ +package auth + +import ( + "fmt" + "time" + + "github.com/common-fate/clio" + "github.com/fwdcloudsec/granted/pkg/idclogin" + "github.com/fwdcloudsec/granted/pkg/providercfg" + "github.com/fwdcloudsec/granted/pkg/securestorage" + "github.com/urfave/cli/v2" +) + +var Command = cli.Command{ + Name: "auth", + Usage: "Manage OIDC authentication for Granted", + Flags: []cli.Flag{}, + Subcommands: []*cli.Command{ + &loginCommand, + &logoutCommand, + }, +} + +var loginCommand = cli.Command{ + Name: "login", + Usage: "Authenticate to an access provider", + Flags: []cli.Flag{ + &cli.StringFlag{Name: "url", Usage: "The access provider URL to authenticate with"}, + }, + Action: func(c *cli.Context) error { + providerURL := c.String("url") + if providerURL == "" { + providerURL = c.Args().First() + } + if providerURL == "" { + return fmt.Errorf("please provide a provider URL, e.g. 'granted auth login https://provider.example.com'") + } + + cfg, err := providercfg.LoadFromURL(c.Context, providerURL) + if err != nil { + return fmt.Errorf("failed to load provider config from %s: %w", providerURL, err) + } + + if cfg.Auth.Type != "oidc" { + return fmt.Errorf("unsupported auth type '%s' for provider at %s (expected 'oidc')", cfg.Auth.Type, providerURL) + } + + output, err := idclogin.ProviderLogin(c.Context, idclogin.ProviderLoginInput{ + IssuerURL: cfg.Auth.Issuer, + ClientID: cfg.Auth.ClientID, + Scopes: cfg.Auth.Scopes, + }) + if err != nil { + return fmt.Errorf("authentication failed: %w", err) + } + + tokenStorage := securestorage.NewProviderTokenStorage() + err = tokenStorage.StoreToken(providerURL, securestorage.ProviderToken{ + AccessToken: output.AccessToken, + RefreshToken: output.RefreshToken, + IDToken: output.IDToken, + TokenType: output.TokenType, + Expiry: time.Now().Add(time.Duration(output.ExpiresIn) * time.Second), + ProviderURL: providerURL, + }) + if err != nil { + return fmt.Errorf("failed to store token: %w", err) + } + + clio.Successf("Successfully authenticated to %s (%s)", cfg.Provider, providerURL) + return nil + }, +} + +var logoutCommand = cli.Command{ + Name: "logout", + Usage: "Log out of an access provider", + Flags: []cli.Flag{ + &cli.StringFlag{Name: "url", Usage: "The access provider URL to log out from"}, + }, + Action: func(c *cli.Context) error { + providerURL := c.String("url") + if providerURL == "" { + providerURL = c.Args().First() + } + if providerURL == "" { + return fmt.Errorf("please provide a provider URL, e.g. 'granted auth logout https://provider.example.com'") + } + + tokenStorage := securestorage.NewProviderTokenStorage() + err := tokenStorage.ClearToken(providerURL) + if err != nil { + return fmt.Errorf("failed to clear token: %w", err) + } + + clio.Successf("Logged out from %s", providerURL) + return nil + }, +} diff --git a/pkg/granted/credential_process.go b/pkg/granted/credential_process.go index cf4f9ec9..dcc1e897 100644 --- a/pkg/granted/credential_process.go +++ b/pkg/granted/credential_process.go @@ -1,6 +1,7 @@ package granted import ( + "context" "encoding/json" "fmt" "time" @@ -9,9 +10,11 @@ import ( "github.com/pkg/errors" "github.com/common-fate/clio" + "github.com/fwdcloudsec/granted/pkg/accessrequest" "github.com/fwdcloudsec/granted/pkg/cfaws" "github.com/fwdcloudsec/granted/pkg/config" "github.com/fwdcloudsec/granted/pkg/securestorage" + sethRetry "github.com/sethvargo/go-retry" "github.com/urfave/cli/v2" ) @@ -50,7 +53,6 @@ var CredentialProcess = cli.Command{ useCache := !cfg.DisableCredentialProcessCache && !cliNoCache if useCache { - // try and look up session credentials from the secure storage cache. cachedCreds, err := secureSessionCredentialStorage.GetCredentials(profileName) if err != nil { clio.Debugw("error loading cached credentials", "error", err, "profile", profileName) @@ -59,7 +61,6 @@ var CredentialProcess = cli.Command{ } else if cachedCreds.CanExpire && cachedCreds.Expires.Add(-c.Duration("window")).Before(time.Now()) { clio.Debugw("refreshing credentials", "reason", "credentials are expired") } else { - // if we get here, the cached session credentials are valid clio.Debugw("credentials found in cache", "expires", cachedCreds.Expires.String(), "canExpire", cachedCreds.CanExpire, "timeNow", time.Now().String(), "refreshIfBeforeNow", cachedCreds.Expires.Add(-c.Duration("window")).String()) return printCredentials(*cachedCreds) } @@ -69,7 +70,6 @@ var CredentialProcess = cli.Command{ clio.Debugw("refreshing credentials", "reason", "credential process cache is disabled via config") } - // purge the credentials from the cache err = secureSessionCredentialStorage.SecureStorage.Clear(profileName) if err != nil { clio.Debugw("error clearing cached credentials", "error", err, "profile", profileName) @@ -92,7 +92,26 @@ var CredentialProcess = cli.Command{ credentials, err := profile.AssumeTerminal(c.Context, cfaws.ConfigOpts{Duration: duration, UsingCredentialProcess: true, CredentialProcessAutoLogin: autoLogin, UseAuthorizationCode: cfg.UseAuthorizationCode}) if err != nil { - return err + clio.Debugw("initial assume failed, attempting retry with backoff", "error", err) + + // Retry with exponential backoff in case of transient errors + b := sethRetry.NewFibonacci(time.Second) + b = sethRetry.WithMaxDuration(time.Second*30, b) + retryErr := sethRetry.Do(c.Context, b, func(ctx context.Context) (err error) { + credentials, err = profile.AssumeTerminal(c.Context, cfaws.ConfigOpts{Duration: duration, UsingCredentialProcess: true, CredentialProcessAutoLogin: autoLogin}) + if err != nil { + return sethRetry.RetryableError(err) + } + return nil + }) + if retryErr != nil { + clio.Debugw("could not assume role after retries, notifying user to try requesting access", "error", err) + saveErr := accessrequest.Profile{Name: profileName}.Save() + if saveErr != nil { + return saveErr + } + return errors.New("You don't have access but you can request it with 'granted request latest'") + } } if !cfg.DisableCredentialProcessCache { clio.Debugw("storing refreshed credentials in credential process cache", "expires", credentials.Expires.String(), "canExpire", credentials.CanExpire, "timeNow", time.Now().String()) diff --git a/pkg/granted/entrypoint.go b/pkg/granted/entrypoint.go index 7679f9f2..d629cd6a 100644 --- a/pkg/granted/entrypoint.go +++ b/pkg/granted/entrypoint.go @@ -8,12 +8,15 @@ import ( "github.com/common-fate/clio" "github.com/common-fate/clio/cliolog" + "github.com/common-fate/useragent" "github.com/fwdcloudsec/granted/internal/build" "github.com/fwdcloudsec/granted/pkg/chromemsg" "github.com/fwdcloudsec/granted/pkg/config" + "github.com/fwdcloudsec/granted/pkg/granted/auth" "github.com/fwdcloudsec/granted/pkg/granted/doctor" "github.com/fwdcloudsec/granted/pkg/granted/middleware" "github.com/fwdcloudsec/granted/pkg/granted/registry" + "github.com/fwdcloudsec/granted/pkg/granted/request" "github.com/fwdcloudsec/granted/pkg/granted/settings" "github.com/urfave/cli/v2" "go.uber.org/zap" @@ -48,7 +51,10 @@ func GetCliApp() *cli.App { middleware.WithBeforeFuncs(&CredentialProcess, middleware.WithAutosync()), ®istry.ProfileRegistryCommand, &ConsoleCommand, + &login, &CacheCommand, + &auth.Command, + &request.Command, &doctor.Command, }, // Granted may be invoked via our browser extension, which uses the Native Messaging @@ -90,6 +96,8 @@ func GetCliApp() *cli.App { if err := config.SetupConfigFolder(); err != nil { return err } + // set the user agent + c.Context = useragent.NewContext(c.Context, "granted", build.Version) err = chromemsg.ConfigureHost() if err != nil { @@ -102,3 +110,16 @@ func GetCliApp() *cli.App { return app } + +var login = cli.Command{ + Name: "login", + Usage: "Log in to an access provider [deprecated: use granted auth login]", + Flags: []cli.Flag{ + &cli.BoolFlag{Name: "lazy", Usage: "When the lazy flag is used, a login flow will only be started when the access token is expired"}, + }, + Action: func(c *cli.Context) error { + clio.Warn("this command is deprecated and will be removed in a future release") + clio.Warn("use 'granted auth login ' to authenticate with an access provider") + return nil + }, +} diff --git a/pkg/granted/registry/add.go b/pkg/granted/registry/add.go index e1d8778b..943d2b33 100644 --- a/pkg/granted/registry/add.go +++ b/pkg/granted/registry/add.go @@ -8,6 +8,7 @@ import ( "github.com/common-fate/clio" grantedConfig "github.com/fwdcloudsec/granted/pkg/config" "github.com/fwdcloudsec/granted/pkg/granted/awsmerge" + "github.com/fwdcloudsec/granted/pkg/httpregistry" "github.com/fwdcloudsec/granted/pkg/granted/registry/gitregistry" "github.com/fwdcloudsec/granted/pkg/testable" @@ -34,7 +35,8 @@ var AddCommand = cli.Command{ &cli.BoolFlag{Name: "prefix-duplicate-profiles", Aliases: []string{"pdp"}, Usage: "Provide this flag if you want to append registry name to duplicate profiles"}, &cli.BoolFlag{Name: "write-on-sync-failure", Aliases: []string{"wosf"}, Usage: "Always overwrite AWS config, even if sync fails (DEPRECATED)"}, &cli.StringSliceFlag{Name: "required-key", Aliases: []string{"r", "requiredKey"}, Usage: "Used to bypass the prompt or override user specific values"}, - &cli.StringFlag{Name: "type", Value: "git", Usage: "specify the type of granted registry source you want to set up. Default: git"}}, + &cli.StringFlag{Name: "type", Value: "git", Usage: "specify the type of granted registry source you want to set up. Default: git"}, + &cli.StringFlag{Name: "tenant-id", Usage: "For HTTP registries: the tenant ID for multi-tenant providers"}}, ArgsUsage: "--name --url --type ", Action: func(c *cli.Context) error { @@ -59,13 +61,10 @@ var AddCommand = cli.Command{ requiredKey := c.StringSlice("required-key") priority := c.Int("priority") registryType := c.String("type") + tenantID := c.String("tenant-id") - if registryType == "http" { - return fmt.Errorf("HTTP registries are not longer supported in this version of Granted: if you are impacted by this please raise an issue: https://github.com/fwdcloudsec/granted/issues/new") - } - - if registryType != "git" { - return fmt.Errorf("invalid registry type provided: %s. must be 'git'", c.String("type")) + if registryType != "git" && registryType != "http" { + return fmt.Errorf("invalid registry type provided: %s. must be 'git' or 'http'", c.String("type")) } for _, r := range gConf.ProfileRegistry.Registries { @@ -86,84 +85,164 @@ var AddCommand = cli.Command{ PrefixDuplicateProfiles: prefixDuplicateProfiles, PrefixAllProfiles: prefixAllProfiles, Type: registryType, + TenantID: tenantID, } - registry, err := gitregistry.New(gitregistry.Opts{ - Name: name, - URL: URL, - Path: pathFlag, - Filename: configFileName, - Ref: ref, - RequiredKeys: requiredKey, - Interactive: true, - }) - - if err != nil { - return err - } - src, err := registry.AWSProfiles(ctx, true) - if err != nil { - return err - } + if registryType == "git" { + registry, err := gitregistry.New(gitregistry.Opts{ + Name: name, + URL: URL, + Path: pathFlag, + Filename: configFileName, + Ref: ref, + RequiredKeys: requiredKey, + Interactive: true, + }) - dst, filepath, err := loadAWSConfigFile() - if err != nil { - return err - } + if err != nil { + return err + } + src, err := registry.AWSProfiles(ctx, true) + if err != nil { + return err + } - merged, err := awsmerge.WithRegistry(src, dst, awsmerge.RegistryOpts{ - Name: name, - PrefixAllProfiles: prefixAllProfiles, - PrefixDuplicateProfiles: prefixDuplicateProfiles, - }) - var dpe awsmerge.DuplicateProfileError - if errors.As(err, &dpe) { - clio.Warnf(err.Error()) + dst, filepath, err := loadAWSConfigFile() + if err != nil { + return err + } - const ( - DUPLICATE = "Add registry name as prefix to all duplicate profiles for this registry" - ABORT = "Abort, I will manually fix this" - ) + merged, err := awsmerge.WithRegistry(src, dst, awsmerge.RegistryOpts{ + Name: name, + PrefixAllProfiles: prefixAllProfiles, + PrefixDuplicateProfiles: prefixDuplicateProfiles, + }) + var dpe awsmerge.DuplicateProfileError + if errors.As(err, &dpe) { + clio.Warnf(err.Error()) + + const ( + DUPLICATE = "Add registry name as prefix to all duplicate profiles for this registry" + ABORT = "Abort, I will manually fix this" + ) + + options := []string{DUPLICATE, ABORT} + + in := survey.Select{Message: "Please select which option would you like to choose to resolve: ", Options: options} + var selected string + err = testable.AskOne(&in, &selected) + if err != nil { + return err + } + + if selected == ABORT { + return fmt.Errorf("aborting sync for registry %s", name) + } + + registryConfig.PrefixDuplicateProfiles = true + + // try and merge again + merged, err = awsmerge.WithRegistry(src, dst, awsmerge.RegistryOpts{ + Name: name, + PrefixAllProfiles: prefixAllProfiles, + PrefixDuplicateProfiles: true, + }) + if err != nil { + return fmt.Errorf("error after trying to merge profiles again: %w", err) + } + } - options := []string{DUPLICATE, ABORT} + // we have verified that this registry is a valid one and sync is completed. + // so save the new registry to config file. + gConf.ProfileRegistry.Registries = append(gConf.ProfileRegistry.Registries, registryConfig) + err = gConf.Save() + if err != nil { + return err + } - in := survey.Select{Message: "Please select which option would you like to choose to resolve: ", Options: options} - var selected string - err = testable.AskOne(&in, &selected) + err = merged.SaveTo(filepath) if err != nil { return err } - if selected == ABORT { - return fmt.Errorf("aborting sync for registry %s", name) + return nil + } else { + + registry := httpregistry.New(httpregistry.Opts{ + Name: name, + URL: URL, + TenantID: tenantID, + }) + + if err != nil { + return err + } + src, err := registry.AWSProfiles(ctx, true) + if err != nil { + return err } - registryConfig.PrefixDuplicateProfiles = true + dst, filepath, err := loadAWSConfigFile() + if err != nil { + return err + } - // try and merge again - merged, err = awsmerge.WithRegistry(src, dst, awsmerge.RegistryOpts{ + merged, err := awsmerge.WithRegistry(src, dst, awsmerge.RegistryOpts{ Name: name, PrefixAllProfiles: prefixAllProfiles, - PrefixDuplicateProfiles: true, + PrefixDuplicateProfiles: prefixDuplicateProfiles, }) + var dpe awsmerge.DuplicateProfileError + if errors.As(err, &dpe) { + clio.Warnf(err.Error()) + + const ( + DUPLICATE = "Add registry name as prefix to all duplicate profiles for this registry" + ABORT = "Abort, I will manually fix this" + ) + + options := []string{DUPLICATE, ABORT} + + in := survey.Select{Message: "Please select which option would you like to choose to resolve: ", Options: options} + var selected string + err = testable.AskOne(&in, &selected) + if err != nil { + return err + } + + if selected == ABORT { + return fmt.Errorf("aborting sync for registry %s", name) + } + + registryConfig.PrefixDuplicateProfiles = true + + // try and merge again + merged, err = awsmerge.WithRegistry(src, dst, awsmerge.RegistryOpts{ + Name: name, + PrefixAllProfiles: prefixAllProfiles, + PrefixDuplicateProfiles: true, + }) + if err != nil { + return fmt.Errorf("error after trying to merge profiles again: %w", err) + } + } + + // we have verified that this registry is a valid one and sync is completed. + // so save the new registry to config file. + gConf.ProfileRegistry.Registries = append(gConf.ProfileRegistry.Registries, registryConfig) + err = gConf.Save() if err != nil { - return fmt.Errorf("error after trying to merge profiles again: %w", err) + return err } - } - // we have verified that this registry is a valid one and sync is completed. - // so save the new registry to config file. - gConf.ProfileRegistry.Registries = append(gConf.ProfileRegistry.Registries, registryConfig) - err = gConf.Save() - if err != nil { - return err - } + err = merged.SaveTo(filepath) + if err != nil { + return err + } + + return nil - err = merged.SaveTo(filepath) - if err != nil { - return err } - return nil }, -} \ No newline at end of file +} diff --git a/pkg/granted/registry/registry.go b/pkg/granted/registry/registry.go index 7f431925..c47d3815 100644 --- a/pkg/granted/registry/registry.go +++ b/pkg/granted/registry/registry.go @@ -5,6 +5,7 @@ import ( "sort" grantedConfig "github.com/fwdcloudsec/granted/pkg/config" + "github.com/fwdcloudsec/granted/pkg/httpregistry" "github.com/fwdcloudsec/granted/pkg/granted/registry/gitregistry" "gopkg.in/ini.v1" ) @@ -48,7 +49,18 @@ func GetProfileRegistries(interactive bool) ([]loadedRegistry, error) { Config: r, Registry: reg, }) + } else { + reg := httpregistry.New(httpregistry.Opts{ + Name: r.Name, + URL: r.URL, + TenantID: r.TenantID, + }) + registries = append(registries, loadedRegistry{ + Config: r, + Registry: reg, + }) } + } // this will sort the registry based on priority. diff --git a/pkg/granted/request/request.go b/pkg/granted/request/request.go new file mode 100644 index 00000000..89dcb8f3 --- /dev/null +++ b/pkg/granted/request/request.go @@ -0,0 +1,88 @@ +package request + +import ( + "context" + "time" + + "github.com/common-fate/clio" + "github.com/fwdcloudsec/granted/pkg/accessrequest" + "github.com/fwdcloudsec/granted/pkg/cfaws" + "github.com/fwdcloudsec/granted/pkg/hook/accessrequesthook" + "github.com/fwdcloudsec/granted/pkg/hook/httpprovider" + "github.com/fwdcloudsec/granted/pkg/providercfg" + "github.com/urfave/cli/v2" +) + +var Command = cli.Command{ + Name: "request", + Usage: "Request access to a role", + Subcommands: []*cli.Command{ + &latestCommand, + // TODO: re-enable check and close commands with HTTP provider + }, +} + +func newHTTPProvider(providerURL string) (accessrequesthook.AccessProvider, error) { + cfg, err := providercfg.LoadFromURL(context.Background(), providerURL) + if err != nil { + return nil, err + } + return httpprovider.New(cfg, providerURL, ""), nil +} + +var latestCommand = cli.Command{ + Name: "latest", + Usage: "Request access to the latest AWS role you attempted to use", + Flags: []cli.Flag{ + &cli.StringFlag{Name: "reason", Usage: "A reason for access"}, + &cli.StringSliceFlag{Name: "attach", Usage: "Attach justifications to your request, such as a Jira ticket id or url `--attach=TP-123`"}, + &cli.DurationFlag{Name: "duration", Usage: "Duration of request, defaults to max duration of the access rule."}, + &cli.BoolFlag{Name: "confirm", Aliases: []string{"y"}, Usage: "Skip confirmation prompts for access requests"}, + }, + Action: func(c *cli.Context) error { + latest, err := accessrequest.LatestProfile() + if err != nil { + return err + } + + profiles, err := cfaws.LoadProfiles() + if err != nil { + return err + } + + profile, err := profiles.LoadInitialisedProfile(c.Context, latest.Name) + if err != nil { + return err + } + + hook, err := accessrequesthook.NewHookFromProfile(profile, newHTTPProvider) + if err != nil { + clio.Debugw("failed to create access hook", "error", err) + return err + } + if hook == nil { + clio.Info("No access provider configured for this profile") + return nil + } + + reason := c.String("reason") + duration := c.Duration("duration") + var apiDuration *time.Duration + if duration != 0 { + apiDuration = &duration + } + + _, _, err = hook.NoAccess(c.Context, accessrequesthook.NoAccessInput{ + Profile: profile, + Reason: reason, + Attachments: c.StringSlice("attach"), + Duration: apiDuration, + Confirm: c.Bool("confirm"), + }) + if err != nil { + return err + } + + return nil + }, +} diff --git a/pkg/granted/settings/set.go b/pkg/granted/settings/set.go index 1ddeb170..c5743bf1 100644 --- a/pkg/granted/settings/set.go +++ b/pkg/granted/settings/set.go @@ -7,7 +7,6 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/common-fate/clio" - "github.com/common-fate/grab" "github.com/fwdcloudsec/granted/pkg/config" "github.com/urfave/cli/v2" ) @@ -152,7 +151,10 @@ func (f keyringFields) Set(value any) error { return nil } func (f keyringFields) Value() any { - return grab.Value(grab.Value(f.field)) + if f.field == nil || *f.field == nil { + return "" + } + return **f.field } func (f keyringFields) Kind() reflect.Kind { return reflect.String diff --git a/pkg/granted/settings/set_test.go b/pkg/granted/settings/set_test.go index 9ca7cbb2..c2f58d9b 100644 --- a/pkg/granted/settings/set_test.go +++ b/pkg/granted/settings/set_test.go @@ -4,10 +4,11 @@ import ( "slices" "testing" - "github.com/common-fate/grab" "github.com/stretchr/testify/assert" ) +func ptrString(s string) *string { return &s } + func TestFieldOptions(t *testing.T) { type input struct { A string @@ -41,7 +42,7 @@ func TestFieldOptions(t *testing.T) { D *string }{ C: "C", - D: grab.Ptr("D"), + D: ptrString("D"), }, }, want: []string{"A", "B.C", "B.D"}, diff --git a/pkg/granted/sso.go b/pkg/granted/sso.go index d70ee66e..e8f4b562 100644 --- a/pkg/granted/sso.go +++ b/pkg/granted/sso.go @@ -17,11 +17,12 @@ import ( "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sso" - "github.com/common-fate/awsconfigfile" "github.com/common-fate/clio" "github.com/common-fate/clio/clierr" + "github.com/fwdcloudsec/granted/pkg/awsconfigfile" "github.com/fwdcloudsec/granted/pkg/cfaws" grantedconfig "github.com/fwdcloudsec/granted/pkg/config" + "github.com/fwdcloudsec/granted/pkg/httpregistry" "github.com/fwdcloudsec/granted/pkg/idclogin" "github.com/fwdcloudsec/granted/pkg/securestorage" "github.com/fwdcloudsec/granted/pkg/testable" @@ -53,7 +54,7 @@ var GenerateCommand = cli.Command{ &cli.StringFlag{Name: "config", Usage: "Specify the SSO config section in the Granted config file ([SSO.name])", Value: "default"}, &cli.StringFlag{Name: "prefix", Usage: "Specify a prefix for all generated profile names"}, &cli.StringFlag{Name: "sso-region", Usage: "Specify the SSO region"}, - &cli.StringSliceFlag{Name: "source", Usage: "The sources to load AWS profiles from (valid values are: 'aws-sso')", Value: cli.NewStringSlice("aws-sso")}, + &cli.StringSliceFlag{Name: "source", Usage: "The sources to load AWS profiles from ('aws-sso' or a named profile registry)", Value: cli.NewStringSlice("aws-sso")}, &cli.BoolFlag{Name: "no-credential-process", Usage: "Generate profiles without the Granted credential-process integration"}, &cli.StringFlag{Name: "profile-template", Usage: "Specify profile name template", Value: awsconfigfile.DefaultProfileNameTemplate}, &cli.StringFlag{Name: "sso-browser-profile", Usage: "Use a pre-existing profile in your browser for SSO login", EnvVars: []string{"GRANTED_SSO_BROWSER_PROFILE"}}, @@ -108,10 +109,12 @@ var GenerateCommand = cli.Command{ switch s { case "aws-sso": g.AddSource(AWSSSOSource{SSORegion: ssoRegion, StartURL: startURL, SSOBrowserProfile: ssoBrowserProfile, UseDeviceCode: c.Bool("use-device-code")}) - case "commonfate", "common-fate", "cf": - return fmt.Errorf("the common fate profile source is no longer supported: https://www.commonfate.io/blog/winding-down") default: - return fmt.Errorf("unknown profile source %s: allowed sources are aws-sso", s) + reg, err := registrySourceByName(cfg, s) + if err != nil { + return err + } + g.AddSource(reg) } } @@ -138,8 +141,8 @@ var PopulateCommand = cli.Command{ &cli.StringFlag{Name: "prefix", Usage: "Specify a prefix for all generated profile names"}, &cli.StringFlag{Name: "sso-region", Usage: "Specify the SSO region"}, &cli.StringSliceFlag{Name: "sso-scope", Usage: "Specify the SSO scopes"}, - &cli.StringSliceFlag{Name: "source", Usage: "The sources to load AWS profiles from", Value: cli.NewStringSlice("aws-sso")}, - &cli.BoolFlag{Name: "prune", Usage: "Remove any generated profiles with the 'common_fate_generated_from' key which no longer exist"}, + &cli.StringSliceFlag{Name: "source", Usage: "The sources to load AWS profiles from ('aws-sso' or a named profile registry)", Value: cli.NewStringSlice("aws-sso")}, + &cli.BoolFlag{Name: "prune", Usage: "Remove any generated profiles which no longer exist in the source"}, &cli.StringFlag{Name: "profile-template", Usage: "Specify profile name template", Value: awsconfigfile.DefaultProfileNameTemplate}, &cli.BoolFlag{Name: "no-credential-process", Usage: "Generate profiles without the Granted credential-process integration"}, &cli.StringFlag{Name: "sso-browser-profile", Usage: "Use a pre-existing profile in your browser for SSO login", EnvVars: []string{"GRANTED_SSO_BROWSER_PROFILE"}}, @@ -225,10 +228,12 @@ var PopulateCommand = cli.Command{ switch s { case "aws-sso": g.AddSource(AWSSSOSource{SSORegion: ssoRegion, StartURL: startURL, SSOScopes: c.StringSlice("sso-scope"), SSOBrowserProfile: ssoBrowserProfile, UseDeviceCode: c.Bool("use-device-code")}) - case "commonfate", "common-fate", "cf": - return fmt.Errorf("the common fate profile source is no longer supported: https://www.commonfate.io/blog/winding-down") default: - return fmt.Errorf("unknown profile source %s: allowed sources are aws-sso", s) + reg, err := registrySourceByName(cfg, s) + if err != nil { + return err + } + g.AddSource(reg) } } err = g.Generate(ctx) @@ -337,6 +342,30 @@ var LoginCommand = cli.Command{ }, } +// registrySourceByName looks up a named profile registry from the Granted +// config and returns it as an awsconfigfile.Source for use with sso generate/populate. +func registrySourceByName(cfg *grantedconfig.Config, name string) (awsconfigfile.Source, error) { + var registryNames []string + for _, r := range cfg.ProfileRegistry.Registries { + registryNames = append(registryNames, r.Name) + if r.Name == name { + if r.Type != "http" && r.Type != "" { + return nil, fmt.Errorf("profile registry %q is type %q, only 'http' registries can be used as an SSO profile source", name, r.Type) + } + return httpregistry.New(httpregistry.Opts{ + Name: r.Name, + URL: r.URL, + TenantID: r.TenantID, + }), nil + } + } + + if len(registryNames) == 0 { + return nil, fmt.Errorf("unknown profile source %q: no profile registries are configured.\nAdd one with: granted registry add --name --url --type http", name) + } + return nil, fmt.Errorf("unknown profile source %q: no profile registry found with that name.\nAvailable registries: %s", name, fmt.Sprintf("%v", registryNames)) +} + type AWSSSOSource struct { SSORegion string StartURL string diff --git a/pkg/hook/accessrequesthook/accessrequesthook.go b/pkg/hook/accessrequesthook/accessrequesthook.go new file mode 100644 index 00000000..058d7a0a --- /dev/null +++ b/pkg/hook/accessrequesthook/accessrequesthook.go @@ -0,0 +1,518 @@ +package accessrequesthook + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/AlecAivazis/survey/v2" + "github.com/briandowns/spinner" + "github.com/common-fate/clio" + "github.com/fatih/color" + "github.com/fwdcloudsec/granted/pkg/cfaws" + "github.com/mattn/go-isatty" +) + +type Hook struct { + Provider AccessProvider +} + +// NewHook creates a Hook with the given provider. +// Callers should construct the appropriate AccessProvider (e.g., httpprovider.New) +// and pass it here. Returns nil if no provider is given. +func NewHook(provider AccessProvider) *Hook { + if provider == nil { + return nil + } + return &Hook{Provider: provider} +} + +// NewHookFromProfile creates a Hook configured from the profile's provider URL. +// Returns (nil, nil) if no provider is configured on the profile. +// This is a convenience function that requires the caller to provide a +// factory function to avoid import cycles. +func NewHookFromProfile(profile *cfaws.Profile, factory func(providerURL string) (AccessProvider, error)) (*Hook, error) { + providerURL := getProviderURL(profile) + if providerURL == "" { + return nil, nil + } + provider, err := factory(providerURL) + if err != nil { + return nil, err + } + return &Hook{Provider: provider}, nil +} + +// getProviderURL reads the access provider URL from a profile's raw config. +func getProviderURL(profile *cfaws.Profile) string { + if profile == nil || profile.RawConfig == nil { + return "" + } + for _, key := range []string{"granted_access_provider_url", "common_fate_url"} { + if profile.RawConfig.HasKey(key) { + k, err := profile.RawConfig.GetKey(key) + if err != nil { + continue + } + if k.Value() != "" { + return k.Value() + } + } + } + return "" +} + +type NoAccessInput struct { + Profile *cfaws.Profile + Reason string + Attachments []string + Duration *time.Duration + Confirm bool + Wait bool + StartTime time.Time +} + +func (h Hook) NoAccess(ctx context.Context, input NoAccessInput) (retry bool, justActivated bool, err error) { + if h.Provider == nil { + clio.Debugw("no access provider configured, skipping access request hook") + return false, false, nil + } + + target := fmt.Sprintf("AWS::Account::%s", input.Profile.AWSConfig.SSOAccountID) + role := input.Profile.AWSConfig.SSORoleName + + clio.Infof("You don't currently have access to %s, checking if we can request access...\t[target=%s, role=%s]", input.Profile.Name, target, role) + + retry, _, justActivated, err = h.NoEntitlementAccess(ctx, NoEntitlementAccessInput{ + Target: target, + Role: role, + Reason: input.Reason, + Duration: input.Duration, + Confirm: input.Confirm, + Wait: input.Wait, + StartTime: input.StartTime, + Attachments: input.Attachments, + }) + + return retry, justActivated, err +} + +type NoEntitlementAccessInput struct { + Target string + Role string + Reason string + Attachments []string + Duration *time.Duration + Confirm bool + Wait bool + StartTime time.Time +} + +func (h Hook) NoEntitlementAccess(ctx context.Context, input NoEntitlementAccessInput) (retry bool, result *EnsureResponse, justActivated bool, err error) { + justActivated = false + + req := EnsureRequest{ + Entitlements: []EntitlementInput{ + { + Target: input.Target, + Role: input.Role, + Duration: input.Duration, + }, + }, + Justification: Justification{}, + } + + hasChanges, result, err := h.dryRun(ctx, &req, false, input.Confirm) + if isUnauthorized(err) { + clio.Debugw("prompting user login because token is expired", "error_details", err.Error()) + clio.Infof("You need to log in to your access provider") + + err = h.Provider.Login(ctx) + if err != nil { + return false, nil, justActivated, err + } + + hasChanges, result, err = h.dryRun(ctx, &req, false, input.Confirm) + } + + if err != nil { + return false, nil, justActivated, err + } + if !hasChanges { + if result != nil && len(result.Grants) == 1 && result.Grants[0].Status == GrantStatusActive { + return false, result, justActivated, nil + } + if input.Wait { + return true, result, justActivated, nil + } + return false, nil, justActivated, errors.New("no access changes") + } + + req.DryRun = false + + if input.Reason != "" { + req.Justification.Reason = input.Reason + } else { + if result.Validation != nil && result.Validation.HasReason { + if !IsTerminal(os.Stdin.Fd()) { + return false, nil, justActivated, errors.New("detected a noninteractive terminal: a reason is required to make this access request, to apply the planned changes please re-run with the --reason flag") + } + + var customReason string + msg := "Reason for access (Required)" + reasonPrompt := &survey.Input{ + Message: msg, + Help: "Will be stored in audit trails and associated with your request", + } + withStdio := survey.WithStdio(os.Stdin, os.Stderr, os.Stderr) + err = survey.AskOne(reasonPrompt, &customReason, withStdio, survey.WithValidator(survey.Required)) + if err != nil { + return false, nil, justActivated, err + } + + req.Justification.Reason = customReason + } + } + + if len(input.Attachments) > 0 { + req.Justification.Attachments = input.Attachments + } else { + if result.Validation != nil && result.Validation.HasJiraTicket { + if !IsTerminal(os.Stdin.Fd()) { + return false, nil, justActivated, errors.New("detected a noninteractive terminal: a jira ticket attachment is required to make this access request, to apply the planned changes please re-run with the --attach flag") + } + + var attachment string + msg := "Jira ticket attachment for access (Required)" + reasonPrompt := &survey.Input{ + Message: msg, + Help: "Will be stored in audit trails and associated with your request", + } + withStdio := survey.WithStdio(os.Stdin, os.Stderr, os.Stderr) + err = survey.AskOne(reasonPrompt, &attachment, withStdio, survey.WithValidator(survey.Required)) + if err != nil { + return false, nil, justActivated, err + } + + req.Justification.Attachments = append(req.Justification.Attachments, attachment) + } + } + + si := spinner.New(spinner.CharSets[14], 100*time.Millisecond) + si.Suffix = " ensuring access..." + si.Writer = os.Stderr + si.Start() + + res, err := h.Provider.Ensure(ctx, &req) + if err != nil { + si.Stop() + return false, nil, justActivated, err + } + si.Stop() + + printDiagnostics(res.Diagnostics) + + clio.Debugw("Ensure response", "response", debugJSON(res)) + + for _, g := range res.Grants { + exp := ShortDur(g.Duration) + + switch g.Change { + case GrantChangeActivated: + _, _ = color.New(color.BgHiGreen).Fprintf(os.Stderr, "[ACTIVATED]") + _, _ = color.New(color.FgGreen).Fprintf(os.Stderr, " %s was activated for %s: %s\n", g.Name, exp, h.Provider.RequestURL(g.AccessRequestID)) + retry = true + justActivated = true + continue + + case GrantChangeExtended: + extendedTime := "" + if g.Extension != nil { + extendedTime = ShortDur(g.Extension.ExtensionDuration) + } + _, _ = color.New(color.BgBlue).Fprintf(os.Stderr, "[EXTENDED]") + _, _ = color.New(color.FgBlue).Fprintf(os.Stderr, " %s was extended for another %s: %s\n", g.Name, extendedTime, h.Provider.RequestURL(g.AccessRequestID)) + _, _ = color.New(color.FgGreen).Printf(" %s will now expire in %s\n", g.Name, exp) + retry = true + continue + + case GrantChangeRequested: + _, _ = color.New(color.BgHiYellow, color.FgBlack).Fprintf(os.Stderr, "[REQUESTED]") + _, _ = color.New(color.FgYellow).Fprintf(os.Stderr, " %s requires approval: %s\n", g.Name, h.Provider.RequestURL(g.AccessRequestID)) + if input.Wait { + return true, res, justActivated, nil + } + return false, nil, justActivated, errors.New("applying access was attempted but the resources requested require approval before activation") + + case GrantChangeProvisioningFailed: + _, _ = color.New(color.FgRed).Fprintf(os.Stderr, "[ERROR] %s failed provisioning: %s\n", g.Name, h.Provider.RequestURL(g.AccessRequestID)) + return false, nil, justActivated, errors.New("access provisioning failed") + } + + switch g.Status { + case GrantStatusActive: + if g.ExpiresAt != nil { + exp = ShortDur(time.Until(*g.ExpiresAt)) + } + _, _ = color.New(color.FgGreen).Fprintf(os.Stderr, "[ACTIVE] %s is already active for the next %s: %s\n", g.Name, exp, h.Provider.RequestURL(g.AccessRequestID)) + retry = true + continue + + case GrantStatusPending: + _, _ = color.New(color.FgWhite).Fprintf(os.Stderr, "[PENDING] %s is already pending: %s\n", g.Name, h.Provider.RequestURL(g.AccessRequestID)) + if input.Wait { + return true, res, justActivated, nil + } + return false, nil, justActivated, errors.New("access is pending approval") + + case GrantStatusClosed: + _, _ = color.New(color.FgWhite).Fprintf(os.Stderr, "[CLOSED] %s is closed but was still returned: %s\n. This is most likely due to an error and should be reported.", g.Name, h.Provider.RequestURL(g.AccessRequestID)) + return false, nil, justActivated, errors.New("grant was closed") + + default: + _, _ = color.New(color.FgWhite).Fprintf(os.Stderr, "[UNSPECIFIED] %s is in an unspecified status: %s\n. This is most likely due to an error and should be reported.", g.Name, h.Provider.RequestURL(g.AccessRequestID)) + return false, nil, justActivated, errors.New("grant was in an unspecified state") + } + } + + printDiagnostics(res.Diagnostics) + + return retry, res, justActivated, nil +} + +func (h Hook) RetryAccess(ctx context.Context, input NoAccessInput) error { + if h.Provider == nil { + return nil + } + + target := fmt.Sprintf("AWS::Account::%s", input.Profile.AWSConfig.SSOAccountID) + role := input.Profile.AWSConfig.SSORoleName + _, err := h.RetryNoEntitlementAccess(ctx, NoEntitlementAccessInput{ + Target: target, + Role: role, + Reason: input.Reason, + Duration: input.Duration, + Confirm: input.Confirm, + Wait: input.Wait, + StartTime: input.StartTime, + Attachments: input.Attachments, + }) + return err +} + +func (h Hook) RetryNoEntitlementAccess(ctx context.Context, input NoEntitlementAccessInput) (result *EnsureResponse, err error) { + req := EnsureRequest{ + Entitlements: []EntitlementInput{ + { + Target: input.Target, + Role: input.Role, + Duration: input.Duration, + }, + }, + Justification: Justification{}, + } + + res, err := h.Provider.Ensure(ctx, &req) + if err != nil { + return nil, err + } + + clio.Debugw("ensure response", "res", debugJSON(res)) + + now := time.Now() + elapsed := now.Sub(input.StartTime).Round(time.Second * 10) + + allGrantsApproved := true + allGrantsActivated := true + for _, g := range res.Grants { + if g.Status == GrantStatusActive { + continue + } + if g.Approved && g.Change == GrantChangeUnspecified && g.ProvisioningStatus != "successful" { + clio.Infof("Request was approved but failed to activate, you might not have permission to activate. You can try and activate the access using the web console. [%s elapsed]", elapsed) + printDiagnostics(res.Diagnostics) + } + if !g.Approved { + clio.Infof("Waiting for request to be approved... [%s elapsed]", elapsed) + allGrantsApproved = false + } + if g.ActivatedAt == nil { + allGrantsActivated = false + } + } + + if !allGrantsApproved || !allGrantsActivated { + return res, errors.New("waiting on all grants to be approved and activated") + } + return res, nil +} + +func (h Hook) dryRun(ctx context.Context, req *EnsureRequest, jsonOutput bool, confirm bool) (bool, *EnsureResponse, error) { + req.DryRun = true + + si := spinner.New(spinner.CharSets[14], 100*time.Millisecond) + si.Suffix = " planning access changes..." + si.Writer = os.Stderr + si.Start() + + res, err := h.Provider.Ensure(ctx, req) + if err != nil { + si.Stop() + return false, nil, err + } + + si.Stop() + + clio.Debugw("Ensure response", "response", debugJSON(res)) + + if jsonOutput { + resJSON, err := json.Marshal(res) + if err != nil { + return false, nil, err + } + fmt.Println(string(resJSON)) + return false, nil, errors.New("exiting because --output=json was specified: use --output=text to show an interactive prompt, or use --confirm to proceed with the changes") + } + + var hasChanges bool + + for _, g := range res.Grants { + exp := ShortDur(g.Duration) + + if g.Change != GrantChangeNone && g.Change != GrantChangeUnspecified { + hasChanges = true + } + + switch g.Change { + case GrantChangeActivated: + _, _ = color.New(color.BgHiGreen).Fprintf(os.Stderr, "[WILL ACTIVATE]") + _, _ = color.New(color.FgGreen).Fprintf(os.Stderr, " %s will be activated for %s: %s\n", g.Name, exp, h.Provider.RequestURL(g.AccessRequestID)) + continue + + case GrantChangeExtended: + extendedTime := "" + if g.Extension != nil { + extendedTime = ShortDur(g.Extension.ExtensionDuration) + } + _, _ = color.New(color.BgBlue).Printf("[WILL EXTEND]") + _, _ = color.New(color.FgBlue).Printf(" %s will be extended for another %s: %s\n", g.Name, extendedTime, h.Provider.RequestURL(g.AccessRequestID)) + continue + + case GrantChangeRequested: + _, _ = color.New(color.BgHiYellow, color.FgBlack).Fprintf(os.Stderr, "[WILL REQUEST]") + _, _ = color.New(color.FgYellow).Fprintf(os.Stderr, " %s will require approval\n", g.Name) + continue + + case GrantChangeProvisioningFailed: + _, _ = color.New(color.FgRed).Fprintf(os.Stderr, "[ERROR] %s will fail provisioning\n", g.Name) + continue + } + + switch g.Status { + case GrantStatusActive: + if g.ExpiresAt != nil { + exp = ShortDur(time.Until(*g.ExpiresAt)) + } + _, _ = color.New(color.FgGreen).Fprintf(os.Stderr, "[ACTIVE] %s is already active for the next %s: %s\n", g.Name, exp, h.Provider.RequestURL(g.AccessRequestID)) + continue + case GrantStatusPending: + _, _ = color.New(color.FgWhite).Fprintf(os.Stderr, "[PENDING] %s is already pending: %s\n", g.Name, h.Provider.RequestURL(g.AccessRequestID)) + continue + case GrantStatusClosed: + _, _ = color.New(color.FgWhite).Fprintf(os.Stderr, "[CLOSED] %s is closed but was still returned: %s\n. This is most likely due to an error and should be reported.", g.Name, h.Provider.RequestURL(g.AccessRequestID)) + continue + } + + _, _ = color.New(color.FgWhite).Fprintf(os.Stderr, "[UNSPECIFIED] %s is in an unspecified status: %s\n. This is most likely due to an error and should be reported.", g.Name, h.Provider.RequestURL(g.AccessRequestID)) + } + + printDiagnostics(res.Diagnostics) + + if !hasChanges { + return false, res, nil + } + + if !confirm { + if !IsTerminal(os.Stdin.Fd()) { + return false, nil, errors.New("detected a noninteractive terminal: to apply the planned changes please re-run with the --confirm flag") + } + + withStdio := survey.WithStdio(os.Stdin, os.Stderr, os.Stderr) + confirmPrompt := survey.Confirm{ + Message: "Apply proposed access changes", + } + err = survey.AskOne(&confirmPrompt, &confirm, withStdio) + if err != nil { + return false, nil, err + } + } + + if !confirm { + return false, nil, errors.New("cancelled operation") + } + + clio.Info("Attempting to grant access...") + return confirm, res, nil +} + +func IsTerminal(fd uintptr) bool { + return isatty.IsTerminal(fd) || isatty.IsCygwinTerminal(fd) +} + +func ShortDur(d time.Duration) string { + if d > time.Minute { + d = d.Round(time.Minute) + } else { + d = d.Round(time.Second) + } + + s := d.String() + if strings.HasSuffix(s, "m0s") { + s = s[:len(s)-2] + } + if strings.HasSuffix(s, "h0m") { + s = s[:len(s)-2] + } + return s +} + +func printDiagnostics(diags []Diagnostic) { + for _, d := range diags { + switch d.Level { + case "error": + clio.Errorf("[diagnostic] %s", d.Message) + case "warning": + clio.Warnf("[diagnostic] %s", d.Message) + default: + clio.Infof("[diagnostic] %s", d.Message) + } + } +} + +func isUnauthorized(err error) bool { + if err == nil { + return false + } + var u Unauthorized + if errors.As(err, &u) { + return u.IsUnauthorized() + } + // Fallback: check for common OAuth2 error strings + msg := err.Error() + return strings.Contains(msg, "oauth2: token expired") || + strings.Contains(msg, "oauth2: invalid grant") || + strings.Contains(msg, `oauth2: "token_expired"`) || + strings.Contains(msg, `oauth2: "invalid_grant"`) +} + +func debugJSON(v interface{}) string { + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("(marshal error: %v)", err) + } + return string(b) +} diff --git a/pkg/hook/accessrequesthook/provider.go b/pkg/hook/accessrequesthook/provider.go new file mode 100644 index 00000000..32ea807b --- /dev/null +++ b/pkg/hook/accessrequesthook/provider.go @@ -0,0 +1,90 @@ +package accessrequesthook + +import ( + "context" + "time" +) + +// AccessProvider is the interface that JIT access platforms implement. +type AccessProvider interface { + Ensure(ctx context.Context, req *EnsureRequest) (*EnsureResponse, error) + Login(ctx context.Context) error + RequestURL(accessRequestID string) string +} + +type EnsureRequest struct { + Entitlements []EntitlementInput + Justification Justification + DryRun bool +} + +type EntitlementInput struct { + Target string + Role string + Duration *time.Duration +} + +type Justification struct { + Reason string + Attachments []string +} + +type EnsureResponse struct { + Grants []GrantResult + Validation *ValidationInfo + Diagnostics []Diagnostic +} + +type GrantResult struct { + ID string + Name string + Status GrantStatus + Change GrantChange + Approved bool + Duration time.Duration + ExpiresAt *time.Time + ActivatedAt *time.Time + AccessRequestID string + ProvisioningStatus string + Extension *Extension +} + +type Extension struct { + ExtensionDuration time.Duration +} + +type ValidationInfo struct { + HasReason bool + HasJiraTicket bool +} + +type Diagnostic struct { + Level string + Message string +} + +type GrantStatus string + +const ( + GrantStatusActive GrantStatus = "active" + GrantStatusPending GrantStatus = "pending" + GrantStatusClosed GrantStatus = "closed" + GrantStatusUnspecified GrantStatus = "unspecified" +) + +// Unauthorized is an interface that errors can implement to indicate +// that the user needs to re-authenticate. +type Unauthorized interface { + IsUnauthorized() bool +} + +type GrantChange string + +const ( + GrantChangeNone GrantChange = "none" + GrantChangeActivated GrantChange = "activated" + GrantChangeExtended GrantChange = "extended" + GrantChangeRequested GrantChange = "requested" + GrantChangeProvisioningFailed GrantChange = "provisioning_failed" + GrantChangeUnspecified GrantChange = "" +) diff --git a/pkg/hook/httpprovider/httpprovider.go b/pkg/hook/httpprovider/httpprovider.go new file mode 100644 index 00000000..491ea770 --- /dev/null +++ b/pkg/hook/httpprovider/httpprovider.go @@ -0,0 +1,330 @@ +package httpprovider + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/common-fate/clio" + "github.com/fwdcloudsec/granted/pkg/hook/accessrequesthook" + "github.com/fwdcloudsec/granted/pkg/idclogin" + "github.com/fwdcloudsec/granted/pkg/providercfg" + "github.com/fwdcloudsec/granted/pkg/securestorage" +) + +// HTTPProvider implements AccessProvider using REST/JSON calls. +type HTTPProvider struct { + cfg *providercfg.ProviderConfig + client *http.Client + tokenStorage securestorage.ProviderTokenStorage + providerURL string + tenantID string +} + +// New creates an HTTPProvider from a ProviderConfig. +// If tenantID is empty, falls back to the tenant_id from the provider config +// (auto-populated by single-tenant providers). +func New(cfg *providercfg.ProviderConfig, providerURL string, tenantID string) *HTTPProvider { + if tenantID == "" { + tenantID = cfg.TenantID + } + return &HTTPProvider{ + cfg: cfg, + client: &http.Client{Timeout: 30 * time.Second}, + tokenStorage: securestorage.NewProviderTokenStorage(), + providerURL: providerURL, + tenantID: tenantID, + } +} + +// getToken returns a valid Bearer token, triggering login if needed. +func (p *HTTPProvider) getToken(ctx context.Context, interactive bool) (string, error) { + token := p.tokenStorage.GetValidToken(p.providerURL) + if token != nil { + return token.AccessToken, nil + } + if !interactive { + return "", fmt.Errorf("no valid token for provider %s. Run 'granted auth login --url %s' to authenticate", p.providerURL, p.providerURL) + } + if err := p.Login(ctx); err != nil { + return "", err + } + token = p.tokenStorage.GetValidToken(p.providerURL) + if token == nil { + return "", fmt.Errorf("login succeeded but no token was stored") + } + return token.AccessToken, nil +} + +// setAuthHeaders adds Authorization and optional X-Tenant-ID to a request. +func (p *HTTPProvider) setAuthHeaders(req *http.Request, token string) { + req.Header.Set("Authorization", "Bearer "+token) + if p.tenantID != "" { + req.Header.Set("X-Tenant-ID", p.tenantID) + } +} + +// Ensure calls POST {apiURL}/v1/access/ensure. +func (p *HTTPProvider) Ensure(ctx context.Context, req *accessrequesthook.EnsureRequest) (*accessrequesthook.EnsureResponse, error) { + apiReq := toAPIRequest(req) + + body, err := json.Marshal(apiReq) + if err != nil { + return nil, fmt.Errorf("marshalling ensure request: %w", err) + } + + ensureURL := p.cfg.APIURL + "/v1/access/ensure" + clio.Debugw("calling ensure endpoint", "url", ensureURL, "dry_run", req.DryRun) + + token, err := p.getToken(ctx, true) + if err != nil { + return nil, fmt.Errorf("getting auth token: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, ensureURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + httpReq.Header.Set("Content-Type", "application/json") + p.setAuthHeaders(httpReq, token) + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("ensure request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusUnauthorized { + clio.Debug("received 401, attempting re-authentication") + if err := p.Login(ctx); err != nil { + return nil, fmt.Errorf("re-authentication failed: %w", err) + } + newToken, err := p.getToken(ctx, false) + if err != nil { + return nil, &UnauthorizedError{StatusCode: resp.StatusCode} + } + + httpReq, err = http.NewRequestWithContext(ctx, http.MethodPost, ensureURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + httpReq.Header.Set("Content-Type", "application/json") + p.setAuthHeaders(httpReq, newToken) + + resp, err = p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("ensure retry request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusUnauthorized { + return nil, &UnauthorizedError{StatusCode: resp.StatusCode} + } + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("ensure endpoint returned HTTP %d", resp.StatusCode) + } + + var apiResp apiEnsureResponse + if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil { + return nil, fmt.Errorf("decoding ensure response: %w", err) + } + + return fromAPIResponse(&apiResp), nil +} + +// Login performs OIDC authentication against the provider's identity provider. +func (p *HTTPProvider) Login(ctx context.Context) error { + if p.cfg.Auth.Type != "oidc" { + return fmt.Errorf("unsupported auth type: %s (expected 'oidc')", p.cfg.Auth.Type) + } + + output, err := idclogin.ProviderLogin(ctx, idclogin.ProviderLoginInput{ + IssuerURL: p.cfg.Auth.Issuer, + ClientID: p.cfg.Auth.ClientID, + Scopes: p.cfg.Auth.Scopes, + }) + if err != nil { + return err + } + + token := securestorage.ProviderToken{ + AccessToken: output.AccessToken, + RefreshToken: output.RefreshToken, + IDToken: output.IDToken, + TokenType: output.TokenType, + Expiry: time.Now().Add(time.Duration(output.ExpiresIn) * time.Second), + ProviderURL: p.providerURL, + TenantID: p.tenantID, + } + + return p.tokenStorage.StoreToken(p.providerURL, token) +} + +// RequestURL builds the URL for viewing an access request. +func (p *HTTPProvider) RequestURL(accessRequestID string) string { + u, err := url.Parse(p.cfg.AccessURL) + if err != nil { + return fmt.Sprintf("%s/access/requests/%s", p.cfg.AccessURL, accessRequestID) + } + return u.JoinPath("access", "requests", accessRequestID).String() +} + +// UnauthorizedError indicates the provider returned a 401, meaning the token +// is expired or invalid and the user needs to re-authenticate. +type UnauthorizedError struct { + StatusCode int +} + +func (e *UnauthorizedError) Error() string { + return fmt.Sprintf("unauthorized (HTTP %d): token expired or invalid", e.StatusCode) +} + +func (e *UnauthorizedError) IsUnauthorized() bool { + return true +} + +// IsUnauthorized checks whether an error is an UnauthorizedError. +func IsUnauthorized(err error) bool { + _, ok := err.(*UnauthorizedError) + return ok +} + +// --- API wire types --- + +type apiEnsureRequest struct { + Entitlements []apiEntitlementInput `json:"entitlements"` + Justification apiJustification `json:"justification"` + DryRun bool `json:"dry_run"` +} + +type apiEntitlementInput struct { + Target string `json:"target"` + Role string `json:"role"` + Duration string `json:"duration,omitempty"` +} + +type apiJustification struct { + Reason string `json:"reason,omitempty"` + Attachments []string `json:"attachments,omitempty"` +} + +type apiEnsureResponse struct { + Grants []apiGrantResult `json:"grants"` + Validation *apiValidation `json:"validation,omitempty"` + Diagnostics []apiDiagnostic `json:"diagnostics,omitempty"` +} + +type apiGrantResult struct { + ID string `json:"id"` + Name string `json:"name"` + Status string `json:"status"` + Change string `json:"change"` + Approved bool `json:"approved"` + Duration string `json:"duration"` + ExpiresAt *string `json:"expires_at,omitempty"` + ActivatedAt *string `json:"activated_at,omitempty"` + AccessRequestID string `json:"access_request_id"` + ProvisioningStatus string `json:"provisioning_status,omitempty"` + Extension *apiExtension `json:"extension,omitempty"` +} + +type apiExtension struct { + ExtensionDuration string `json:"extension_duration"` +} + +type apiValidation struct { + HasReason bool `json:"has_reason"` + HasJiraTicket bool `json:"has_jira_ticket"` +} + +type apiDiagnostic struct { + Level string `json:"level"` + Message string `json:"message"` +} + +func toAPIRequest(req *accessrequesthook.EnsureRequest) *apiEnsureRequest { + apiReq := &apiEnsureRequest{ + DryRun: req.DryRun, + Justification: apiJustification{ + Reason: req.Justification.Reason, + Attachments: req.Justification.Attachments, + }, + } + + for _, e := range req.Entitlements { + ent := apiEntitlementInput{ + Target: e.Target, + Role: e.Role, + } + if e.Duration != nil { + ent.Duration = e.Duration.String() + } + apiReq.Entitlements = append(apiReq.Entitlements, ent) + } + + return apiReq +} + +func fromAPIResponse(resp *apiEnsureResponse) *accessrequesthook.EnsureResponse { + result := &accessrequesthook.EnsureResponse{} + + if resp.Validation != nil { + result.Validation = &accessrequesthook.ValidationInfo{ + HasReason: resp.Validation.HasReason, + HasJiraTicket: resp.Validation.HasJiraTicket, + } + } + + for _, d := range resp.Diagnostics { + result.Diagnostics = append(result.Diagnostics, accessrequesthook.Diagnostic{ + Level: d.Level, + Message: d.Message, + }) + } + + for _, g := range resp.Grants { + grant := accessrequesthook.GrantResult{ + ID: g.ID, + Name: g.Name, + Status: accessrequesthook.GrantStatus(g.Status), + Change: accessrequesthook.GrantChange(g.Change), + Approved: g.Approved, + AccessRequestID: g.AccessRequestID, + ProvisioningStatus: g.ProvisioningStatus, + } + + if d, err := time.ParseDuration(g.Duration); err == nil { + grant.Duration = d + } + + if g.ExpiresAt != nil { + if t, err := time.Parse(time.RFC3339, *g.ExpiresAt); err == nil { + grant.ExpiresAt = &t + } + } + + if g.ActivatedAt != nil { + if t, err := time.Parse(time.RFC3339, *g.ActivatedAt); err == nil { + grant.ActivatedAt = &t + } + } + + if g.Extension != nil { + if d, err := time.ParseDuration(g.Extension.ExtensionDuration); err == nil { + grant.Extension = &accessrequesthook.Extension{ + ExtensionDuration: d, + } + } + } + + result.Grants = append(result.Grants, grant) + } + + return result +} diff --git a/pkg/httpregistry/httpregistry.go b/pkg/httpregistry/httpregistry.go new file mode 100644 index 00000000..085e385d --- /dev/null +++ b/pkg/httpregistry/httpregistry.go @@ -0,0 +1,236 @@ +package httpregistry + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/common-fate/clio" + "github.com/fwdcloudsec/granted/pkg/awsconfigfile" + "github.com/fwdcloudsec/granted/pkg/idclogin" + "github.com/fwdcloudsec/granted/pkg/providercfg" + "github.com/fwdcloudsec/granted/pkg/securestorage" + "gopkg.in/ini.v1" +) + +type Registry struct { + opts Opts + mu sync.Mutex + cfg *providercfg.ProviderConfig + tokenStorage securestorage.ProviderTokenStorage +} + +type Opts struct { + Name string + URL string + TenantID string +} + +// getConfig lazily loads the provider configuration. +// This avoids slowing down Granted startup when the registry isn't needed. +func (r *Registry) getConfig(interactive bool) (*providercfg.ProviderConfig, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.cfg != nil { + return r.cfg, nil + } + + cfg, err := providercfg.LoadFromURL(context.Background(), r.opts.URL) + if err != nil { + if interactive { + clio.Warnf("Failed to load provider config from %s: %s", r.opts.URL, err) + } + return nil, err + } + + r.cfg = cfg + return r.cfg, nil +} + +// getToken returns a valid Bearer token for the provider, triggering login if interactive. +func (r *Registry) getToken(ctx context.Context, cfg *providercfg.ProviderConfig, interactive bool) (string, error) { + if cfg.Auth.Type != "oidc" { + return "", nil + } + + token := r.tokenStorage.GetValidToken(r.opts.URL) + if token != nil { + return token.AccessToken, nil + } + + if !interactive { + return "", fmt.Errorf("no valid token for provider %s. Run 'granted auth login --url %s' to authenticate", r.opts.URL, r.opts.URL) + } + + output, err := idclogin.ProviderLogin(ctx, idclogin.ProviderLoginInput{ + IssuerURL: cfg.Auth.Issuer, + ClientID: cfg.Auth.ClientID, + Scopes: cfg.Auth.Scopes, + }) + if err != nil { + return "", err + } + + providerToken := securestorage.ProviderToken{ + AccessToken: output.AccessToken, + RefreshToken: output.RefreshToken, + IDToken: output.IDToken, + TokenType: output.TokenType, + Expiry: time.Now().Add(time.Duration(output.ExpiresIn) * time.Second), + ProviderURL: r.opts.URL, + TenantID: r.opts.TenantID, + } + + if err := r.tokenStorage.StoreToken(r.opts.URL, providerToken); err != nil { + clio.Warnf("failed to store provider token: %s", err) + } + + return output.AccessToken, nil +} + +func New(opts Opts) *Registry { + return &Registry{ + opts: opts, + tokenStorage: securestorage.NewProviderTokenStorage(), + } +} + +type listProfilesResponse struct { + Profiles []profileEntry `json:"profiles"` + NextPageToken string `json:"next_page_token"` +} + +type profileEntry struct { + Name string `json:"name"` + Attributes []profileKeyVal `json:"attributes"` +} + +type profileKeyVal struct { + Key string `json:"key"` + Value string `json:"value"` +} + +// fetchProfiles retrieves raw profile entries from the HTTP registry API. +func (r *Registry) fetchProfiles(ctx context.Context, interactive bool) ([]profileEntry, error) { + cfg, err := r.getConfig(interactive) + if err != nil { + return nil, err + } + + client := &http.Client{Timeout: 30 * time.Second} + + accessToken, err := r.getToken(ctx, cfg, interactive) + if err != nil { + return nil, err + } + + var allProfiles []profileEntry + var pageToken string + + for { + listURL := fmt.Sprintf("%s/v1/registry/profiles", cfg.APIURL) + if pageToken != "" { + listURL += "?page_token=" + pageToken + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, listURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + + if accessToken != "" { + req.Header.Set("Authorization", "Bearer "+accessToken) + } + tenantID := r.opts.TenantID + if tenantID == "" { + tenantID = cfg.TenantID + } + if tenantID != "" { + req.Header.Set("X-Tenant-ID", tenantID) + } + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("fetching profiles from %s: %w", listURL, err) + } + + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + return nil, fmt.Errorf("profile registry returned HTTP %d from %s", resp.StatusCode, listURL) + } + + var listResp listProfilesResponse + if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { + _ = resp.Body.Close() + return nil, fmt.Errorf("decoding profile list from %s: %w", listURL, err) + } + _ = resp.Body.Close() + + allProfiles = append(allProfiles, listResp.Profiles...) + + if listResp.NextPageToken == "" { + break + } + pageToken = listResp.NextPageToken + } + + return allProfiles, nil +} + +func (r *Registry) AWSProfiles(ctx context.Context, interactive bool) (*ini.File, error) { + allProfiles, err := r.fetchProfiles(ctx, interactive) + if err != nil { + return nil, err + } + + result := ini.Empty() + + for _, profile := range allProfiles { + section, err := result.NewSection(profile.Name) + if err != nil { + return nil, err + } + + for _, attr := range profile.Attributes { + _, err := section.NewKey(attr.Key, attr.Value) + if err != nil { + return nil, err + } + } + } + + return result, nil +} + +// GetProfiles implements awsconfigfile.Source, allowing an HTTP registry +// to be used as a profile source in `granted sso generate` and `granted sso populate`. +func (r *Registry) GetProfiles(ctx context.Context) ([]awsconfigfile.SSOProfile, error) { + entries, err := r.fetchProfiles(ctx, true) + if err != nil { + return nil, err + } + + var profiles []awsconfigfile.SSOProfile + for _, entry := range entries { + attrs := make(map[string]string, len(entry.Attributes)) + for _, kv := range entry.Attributes { + attrs[kv.Key] = kv.Value + } + + profiles = append(profiles, awsconfigfile.SSOProfile{ + SSOStartURL: attrs["sso_start_url"], + SSORegion: attrs["sso_region"], + AccountID: attrs["sso_account_id"], + AccountName: attrs["account_name"], + RoleName: attrs["sso_role_name"], + GeneratedFrom: r.opts.Name, + }) + } + + return profiles, nil +} diff --git a/pkg/idclogin/provider_login.go b/pkg/idclogin/provider_login.go new file mode 100644 index 00000000..a5a78aac --- /dev/null +++ b/pkg/idclogin/provider_login.go @@ -0,0 +1,312 @@ +package idclogin + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/common-fate/clio" + "github.com/google/uuid" +) + +type ProviderLoginInput struct { + IssuerURL string + ClientID string + Scopes []string + BrowserProfile string +} + +type ProviderLoginOutput struct { + AccessToken string + RefreshToken string + IDToken string + TokenType string + ExpiresIn int +} + +// OIDCDiscovery holds endpoints from a standard OpenID Connect discovery document. +type OIDCDiscovery struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` +} + +// DiscoverOIDC fetches the OpenID Connect discovery document from the issuer's +// well-known configuration endpoint. +func DiscoverOIDC(ctx context.Context, issuerURL string) (*OIDCDiscovery, error) { + discoveryURL := strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration" + clio.Debugw("fetching OIDC discovery document", "url", discoveryURL) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil) + if err != nil { + return nil, fmt.Errorf("building OIDC discovery request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("fetching OIDC discovery from %s: %w", discoveryURL, err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("OIDC discovery returned HTTP %d from %s", resp.StatusCode, discoveryURL) + } + + var discovery OIDCDiscovery + if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil { + return nil, fmt.Errorf("decoding OIDC discovery from %s: %w", discoveryURL, err) + } + + if discovery.AuthorizationEndpoint == "" { + return nil, fmt.Errorf("OIDC discovery from %s missing authorization_endpoint", discoveryURL) + } + if discovery.TokenEndpoint == "" { + return nil, fmt.Errorf("OIDC discovery from %s missing token_endpoint", discoveryURL) + } + + return &discovery, nil +} + +// ProviderLogin performs an OAuth Authorization Code + PKCE flow against a +// generic OIDC provider (not AWS SSO specific). It opens the user's browser, +// waits for the authorization callback, and exchanges the code for tokens. +func ProviderLogin(ctx context.Context, input ProviderLoginInput) (*ProviderLoginOutput, error) { + discovery, err := DiscoverOIDC(ctx, input.IssuerURL) + if err != nil { + return nil, err + } + + callbackResult := make(chan providerCallbackResult, 1) + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("failed to start local OAuth callback server: %w", err) + } + + port := listener.Addr().(*net.TCPAddr).Port + callbackURL := fmt.Sprintf("http://127.0.0.1:%d/callback", port) + + codeVerifier, err := generateCodeVerifier() + if err != nil { + _ = listener.Close() + return nil, fmt.Errorf("failed to generate PKCE code verifier: %w", err) + } + codeChallenge := computeCodeChallenge(codeVerifier) + + state := uuid.New().String() + + srv := &http.Server{ + Handler: newProviderCallbackHandler(state, callbackResult), + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 30 * time.Second, + } + go func() { + if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + clio.Debugf("OAuth callback server error: %s", err) + } + }() + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = srv.Shutdown(shutdownCtx) + }() + + authorizeURL, err := buildProviderAuthorizeURL(discovery.AuthorizationEndpoint, input.ClientID, callbackURL, state, codeChallenge, input.Scopes) + if err != nil { + return nil, fmt.Errorf("failed to build authorize URL: %w", err) + } + + if err := OpenBrowserWithFallbackMessage(authorizeURL, input.BrowserProfile); err != nil { + return nil, err + } + + clio.Info("Awaiting authentication in the browser") + clio.Info("You will be prompted to authenticate and approve access") + + var result providerCallbackResult + select { + case result = <-callbackResult: + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(authorizationCallbackTimeout): + return nil, errors.New("timed out waiting for authorization callback") + } + + if result.err != nil { + return nil, fmt.Errorf("authorization failed: %w", result.err) + } + + output, err := exchangeCodeForToken(ctx, discovery.TokenEndpoint, tokenExchangeInput{ + Code: result.code, + ClientID: input.ClientID, + RedirectURI: callbackURL, + CodeVerifier: codeVerifier, + }) + if err != nil { + return nil, err + } + + return output, nil +} + +type providerCallbackResult struct { + code string + err error +} + +const providerCallbackSuccessHTML = ` + +Granted - Authentication Successful + +
+

Authentication Successful

+

You have successfully authenticated with your access provider.

+

You can close this window and return to your terminal.

+
+ +` + +func newProviderCallbackHandler(expectedState string, result chan<- providerCallbackResult) http.Handler { + var once sync.Once + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var handled bool + once.Do(func() { + handled = true + query := r.URL.Query() + + if errParam := query.Get("error"); errParam != "" { + errDesc := query.Get("error_description") + writeErrorPage(w, errParam, errDesc) + result <- providerCallbackResult{err: fmt.Errorf("%s: %s", errParam, errDesc)} + return + } + + code := query.Get("code") + st := query.Get("state") + + if st != expectedState { + writeErrorPage(w, "state_mismatch", "The state parameter did not match. This may indicate a CSRF attack.") + result <- providerCallbackResult{err: errors.New("OAuth state parameter mismatch")} + return + } + + if code == "" { + writeErrorPage(w, "missing_code", "No authorization code was received.") + result <- providerCallbackResult{err: errors.New("no authorization code received")} + return + } + + setSecurityHeaders(w) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(providerCallbackSuccessHTML)) + result <- providerCallbackResult{code: code} + }) + + if !handled { + http.Error(w, "Authorization already processed", http.StatusConflict) + } + }) + return mux +} + +func buildProviderAuthorizeURL(authorizationEndpoint, clientID, redirectURI, state, codeChallenge string, scopes []string) (string, error) { + u, err := url.Parse(authorizationEndpoint) + if err != nil { + return "", err + } + + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", clientID) + q.Set("redirect_uri", redirectURI) + q.Set("state", state) + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + q.Set("scope", strings.Join(scopes, " ")) + u.RawQuery = q.Encode() + + return u.String(), nil +} + +type tokenExchangeInput struct { + Code string + ClientID string + RedirectURI string + CodeVerifier string +} + +type tokenExchangeResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Error string `json:"error"` + ErrorDesc string `json:"error_description"` +} + +func exchangeCodeForToken(ctx context.Context, tokenEndpoint string, input tokenExchangeInput) (*ProviderLoginOutput, error) { + data := url.Values{ + "grant_type": {"authorization_code"}, + "code": {input.Code}, + "client_id": {input.ClientID}, + "redirect_uri": {input.RedirectURI}, + "code_verifier": {input.CodeVerifier}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("building token exchange request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("token exchange request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading token exchange response: %w", err) + } + + var tokenResp tokenExchangeResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("decoding token exchange response: %w", err) + } + + if tokenResp.Error != "" { + return nil, fmt.Errorf("token exchange error: %s: %s", tokenResp.Error, tokenResp.ErrorDesc) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token endpoint returned HTTP %d", resp.StatusCode) + } + + if tokenResp.AccessToken == "" { + return nil, errors.New("token exchange returned empty access_token") + } + + return &ProviderLoginOutput{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + IDToken: tokenResp.IDToken, + TokenType: tokenResp.TokenType, + ExpiresIn: tokenResp.ExpiresIn, + }, nil +} diff --git a/pkg/providercfg/providercfg.go b/pkg/providercfg/providercfg.go new file mode 100644 index 00000000..a93561bf --- /dev/null +++ b/pkg/providercfg/providercfg.go @@ -0,0 +1,95 @@ +package providercfg + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "github.com/common-fate/clio" + "github.com/fwdcloudsec/granted/pkg/cfaws" +) + +type ProviderConfig struct { + Provider string `json:"provider"` + Version string `json:"version"` + APIURL string `json:"api_url"` + AccessURL string `json:"access_url"` + TenantID string `json:"tenant_id,omitempty"` + Auth AuthConfig `json:"auth"` +} + +type AuthConfig struct { + Type string `json:"type"` + Issuer string `json:"issuer"` + ClientID string `json:"client_id"` + Scopes []string `json:"scopes"` +} + +// LoadFromURL fetches the provider configuration from {providerURL}/granted/config.json. +func LoadFromURL(ctx context.Context, providerURL string) (*ProviderConfig, error) { + u, err := url.Parse(providerURL) + if err != nil { + return nil, fmt.Errorf("invalid provider URL (%s): %w", providerURL, err) + } + + configURL := u.JoinPath("granted", "config.json").String() + clio.Debugw("loading provider config", "url", configURL) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, configURL, nil) + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("fetching provider config from %s: %w", configURL, err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("provider config returned HTTP %d from %s", resp.StatusCode, configURL) + } + + var cfg ProviderConfig + if err := json.NewDecoder(resp.Body).Decode(&cfg); err != nil { + return nil, fmt.Errorf("decoding provider config from %s: %w", configURL, err) + } + + return &cfg, nil +} + +// GetProviderURL reads the access provider URL from a profile's raw config. +// It checks granted_access_provider_url first, then common_fate_url as a legacy alias. +// Returns an empty string if neither key is set. +func GetProviderURL(profile *cfaws.Profile) string { + if profile == nil || profile.RawConfig == nil { + return "" + } + + for _, key := range []string{"granted_access_provider_url", "common_fate_url"} { + if profile.RawConfig.HasKey(key) { + k, err := profile.RawConfig.GetKey(key) + if err != nil { + clio.Debugw("error reading profile key", "key", key, "error", err) + continue + } + if k.Value() != "" { + return k.Value() + } + } + } + + return "" +} + +// GenerateRequestURL builds a URL to view an access request in the provider UI. +func GenerateRequestURL(accessURL string, requestID string) (string, error) { + u, err := url.Parse(accessURL) + if err != nil { + return "", err + } + p := u.JoinPath("access", "requests", requestID) + return p.String(), nil +} diff --git a/pkg/securestorage/provider_token_storage.go b/pkg/securestorage/provider_token_storage.go new file mode 100644 index 00000000..ac4e0888 --- /dev/null +++ b/pkg/securestorage/provider_token_storage.go @@ -0,0 +1,47 @@ +package securestorage + +import ( + "time" +) + +type ProviderTokenStorage struct { + SecureStorage SecureStorage +} + +type ProviderToken struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + IDToken string `json:"id_token,omitempty"` + TokenType string `json:"token_type"` + Expiry time.Time `json:"expiry"` + ProviderURL string `json:"provider_url"` + TenantID string `json:"tenant_id,omitempty"` +} + +func NewProviderTokenStorage() ProviderTokenStorage { + return ProviderTokenStorage{ + SecureStorage: SecureStorage{StorageSuffix: "granted-provider-tokens"}, + } +} + +// GetValidToken returns a stored token if it exists and is not expired. +// Returns nil if no valid token exists. +func (s *ProviderTokenStorage) GetValidToken(providerURL string) *ProviderToken { + var token ProviderToken + err := s.SecureStorage.Retrieve(providerURL, &token) + if err != nil { + return nil + } + if time.Now().After(token.Expiry) { + return nil + } + return &token +} + +func (s *ProviderTokenStorage) StoreToken(providerURL string, token ProviderToken) error { + return s.SecureStorage.Store(providerURL, token) +} + +func (s *ProviderTokenStorage) ClearToken(providerURL string) error { + return s.SecureStorage.Clear(providerURL) +}