diff --git a/README.md b/README.md index 044386655c3..54936686431 100644 --- a/README.md +++ b/README.md @@ -230,7 +230,7 @@ If you press `shift+w` on a commit (or branch/ref) a menu will open that allows ### Show GitHub pull requests -In the branches panel, lazygit can show which of your branches have an associated GitHub pull request by showing a GitHub icon next to the branch name; its color shows the state of the PR (open, merged, etc.). For those that have one, you can press `shift-G` to open the PR in the browser. There is no configuration needed to enable this, but it requires the [`gh`](https://cli.github.com/) tool to be installed, and you need to do `gh auth login` once to allow lazygit to access GitHub. +In the branches panel, lazygit can show which of your branches have an associated GitHub pull request by showing a GitHub icon next to the branch name; its color shows the state of the PR (open, merged, etc.). For those that have one, you can press `shift-G` to open the PR in the browser. There is no configuration needed to enable this for github.com, but it requires the [`gh`](https://cli.github.com/) tool to be installed, and you need to do `gh auth login` once to allow lazygit to access GitHub. For GitHub Enterprise, also run `gh auth login --hostname ` and add a [`services` entry](docs/Config.md#custom-pull-request-urls) for the host with the `github` provider. ## Tutorials diff --git a/docs-master/Config.md b/docs-master/Config.md index 42eff4666a1..05d3b6d20f5 100644 --- a/docs-master/Config.md +++ b/docs-master/Config.md @@ -1117,6 +1117,8 @@ Where: - `provider` is one of `github`, `bitbucket`, `bitbucketServer`, `azuredevops`, `gitlab`, `gitea` or `codeberg` - `webDomain` is the URL where your git service exposes a web interface and APIs, e.g. `gitservice.work.com` +For the `github` provider, configuring an entry here also enables the pull-request icons in the branches panel for that host (e.g. a GitHub Enterprise Server instance). Lazygit picks up the auth token via the same mechanisms as the `gh` CLI: the `GH_ENTERPRISE_TOKEN` / `GITHUB_ENTERPRISE_TOKEN` environment variables, or `gh auth login --hostname `. + ## Predefined commit message prefix In situations where certain naming pattern is used for branches and commits, pattern can be used to populate commit message with prefix that is parsed from the branch name. diff --git a/pkg/commands/git_commands/github.go b/pkg/commands/git_commands/github.go index 85893615dcb..e05472ef10f 100644 --- a/pkg/commands/git_commands/github.go +++ b/pkg/commands/git_commands/github.go @@ -138,19 +138,16 @@ func fetchPullRequestsQuery(branches []string, owner string, repo string) (strin return queryString, variables } -func (self *GitHubCommands) GetAuthToken() string { - defaultHost, _ := auth.DefaultHost() - token, _ := auth.TokenForHost(defaultHost) +func (self *GitHubCommands) GetAuthToken(host string) string { + token, _ := auth.TokenForHost(host) return token } -// FetchRecentPRs fetches recent pull requests using GraphQL. -func (self *GitHubCommands) FetchRecentPRs(branches []string, baseRemote *models.Remote, token string) ([]*models.GithubPullRequest, error) { - repoOwner, repoName, err := self.GetBaseRepoOwnerAndName(baseRemote) - if err != nil { - return nil, err - } - +// FetchRecentPRs fetches recent pull requests using GraphQL. serviceInfo +// identifies the GitHub instance (github.com or a GitHub Enterprise Server) +// and the owner/repo to query against. +func (self *GitHubCommands) FetchRecentPRs(branches []string, serviceInfo *hosting_service.ServiceInfo, token string) ([]*models.GithubPullRequest, error) { + endpoint := graphQLEndpoint(serviceInfo.WebDomain) t := time.Now() var g errgroup.Group @@ -171,7 +168,7 @@ func (self *GitHubCommands) FetchRecentPRs(branches []string, baseRemote *models // Launch a goroutine for each chunk of branches g.Go(func() error { - prs, err := self.fetchRecentPRsAux(repoOwner, repoName, branchChunk, token) + prs, err := self.fetchRecentPRsAux(endpoint, serviceInfo.Owner, serviceInfo.Repository, branchChunk, token) if err != nil { return err } @@ -181,7 +178,7 @@ func (self *GitHubCommands) FetchRecentPRs(branches []string, baseRemote *models } // Wait for all goroutines, then close the channel so the range loop exits - err = g.Wait() + err := g.Wait() close(results) if err != nil { return nil, err @@ -198,14 +195,14 @@ func (self *GitHubCommands) FetchRecentPRs(branches []string, baseRemote *models return allPRs, nil } -func (self *GitHubCommands) fetchRecentPRsAux(repoOwner string, repoName string, branches []string, token string) ([]*models.GithubPullRequest, error) { +func (self *GitHubCommands) fetchRecentPRsAux(endpoint string, repoOwner string, repoName string, branches []string, token string) ([]*models.GithubPullRequest, error) { queryString, variables := fetchPullRequestsQuery(branches, repoOwner, repoName) bodyBytes, err := json.Marshal(graphQLRequest{Query: queryString, Variables: variables}) if err != nil { return nil, err } - req, err := http.NewRequest("POST", "https://api.github.com/graphql", bytes.NewBuffer(bodyBytes)) + req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(bodyBytes)) if err != nil { return nil, err } @@ -336,45 +333,12 @@ func getRemotesToOwnersMap(remotes []*models.Remote) map[string]string { return res } -func (self *GitHubCommands) InGithubRepo(remotes []*models.Remote) bool { - if len(remotes) == 0 { - return false - } - - remote := getMainRemote(remotes) - - if len(remote.Urls) == 0 { - return false - } - - url := remote.Urls[0] - return strings.Contains(strings.ToLower(url), "github.com") -} - -func getMainRemote(remotes []*models.Remote) *models.Remote { - for _, remote := range remotes { - if remote.Name == "origin" { - return remote - } - } - - // need to sort remotes by name so that this is deterministic - return lo.MinBy(remotes, func(a, b *models.Remote) bool { - return a.Name < b.Name - }) -} - -func (self *GitHubCommands) GetBaseRepoOwnerAndName(baseRemote *models.Remote) (string, string, error) { - if len(baseRemote.Urls) == 0 { - return "", "", fmt.Errorf("No URLs found for remote") +// graphQLEndpoint returns the GraphQL API URL for a GitHub host. github.com +// uses a dedicated api. subdomain; GitHub Enterprise Server hangs the API off +// the web host under /api/graphql. +func graphQLEndpoint(host string) string { + if auth.NormalizeHostname(host) == "github.com" { + return "https://api.github.com/graphql" } - - url := baseRemote.Urls[0] - - repoInfo, err := hosting_service.GetRepoInfoFromURL(url) - if err != nil { - return "", "", err - } - - return repoInfo.Owner, repoInfo.Repository, nil + return "https://" + host + "/api/graphql" } diff --git a/pkg/commands/git_commands/github_test.go b/pkg/commands/git_commands/github_test.go index d9d55ffd1ae..b332ba12a21 100644 --- a/pkg/commands/git_commands/github_test.go +++ b/pkg/commands/git_commands/github_test.go @@ -57,6 +57,25 @@ func TestGetRepoInfoFromURL(t *testing.T) { } } +func TestGraphQLEndpoint(t *testing.T) { + cases := []struct { + host string + expected string + }{ + {"github.com", "https://api.github.com/graphql"}, + {"www.github.com", "https://api.github.com/graphql"}, + {"GITHUB.com", "https://api.github.com/graphql"}, + {"ghe.example.com", "https://ghe.example.com/api/graphql"}, + {"ghe.example.com:8443", "https://ghe.example.com:8443/api/graphql"}, + } + + for _, c := range cases { + t.Run(c.host, func(t *testing.T) { + assert.Equal(t, c.expected, graphQLEndpoint(c.host)) + }) + } +} + func TestGenerateGithubPullRequestMap(t *testing.T) { cases := []struct { name string diff --git a/pkg/commands/git_commands/hosting_service.go b/pkg/commands/git_commands/hosting_service.go index 7d977212703..f43b93e9037 100644 --- a/pkg/commands/git_commands/hosting_service.go +++ b/pkg/commands/git_commands/hosting_service.go @@ -21,8 +21,8 @@ func (self *HostingService) GetCommitURL(commitSha string) (string, error) { return self.getHostingServiceMgr(self.config.GetRemoteURL()).GetCommitURL(commitSha) } -func (self *HostingService) GetRepoNameFromRemoteURL(remoteURL string) (string, error) { - return self.getHostingServiceMgr(remoteURL).GetRepoName() +func (self *HostingService) GetServiceInfo(remoteURL string) (hosting_service.ServiceInfo, error) { + return self.getHostingServiceMgr(remoteURL).GetServiceInfo() } // getting this on every request rather than storing it in state in case our remoteURL changes diff --git a/pkg/commands/hosting_service/definitions.go b/pkg/commands/hosting_service/definitions.go index 130bf04811f..09fa191c832 100644 --- a/pkg/commands/hosting_service/definitions.go +++ b/pkg/commands/hosting_service/definitions.go @@ -1,10 +1,12 @@ package hosting_service +import "regexp" + // if you want to make a custom regex for a given service feel free to test it out // at https://regex101.com using the flavor Golang -var defaultUrlRegexStrings = []string{ - `^(?:https?|ssh)://[^/]+/(?P.*)/(?P.*?)(?:\.git)?$`, - `^(.*?@)?.*:/*(?P.*)/(?P.*?)(?:\.git)?$`, +var defaultUrlRegexps = []*regexp.Regexp{ + regexp.MustCompile(`^(?:https?|ssh)://[^/]+/(?P.*)/(?P.*?)(?:\.git)?$`), + regexp.MustCompile(`^(.*?@)?.*:/*(?P.*)/(?P.*?)(?:\.git)?$`), } var ( @@ -19,7 +21,7 @@ var githubServiceDef = ServiceDefinition{ pullRequestURLIntoDefaultBranch: "/compare/{{.From}}?expand=1", pullRequestURLIntoTargetBranch: "/compare/{{.To}}...{{.From}}?expand=1", commitURL: "/commit/{{.CommitHash}}", - regexStrings: defaultUrlRegexStrings, + urlRegexps: defaultUrlRegexps, repoURLTemplate: defaultRepoURLTemplate, repoNameTemplate: defaultRepoNameTemplate, } @@ -29,9 +31,9 @@ var bitbucketServiceDef = ServiceDefinition{ pullRequestURLIntoDefaultBranch: "/pull-requests/new?source={{.From}}&t=1", pullRequestURLIntoTargetBranch: "/pull-requests/new?source={{.From}}&dest={{.To}}&t=1", commitURL: "/commits/{{.CommitHash}}", - regexStrings: []string{ - `^(?:https?|ssh)://.*/(?P.*)/(?P.*?)(?:\.git)?$`, - `^.*@.*:/*(?P.*)/(?P.*?)(?:\.git)?$`, + urlRegexps: []*regexp.Regexp{ + regexp.MustCompile(`^(?:https?|ssh)://.*/(?P.*)/(?P.*?)(?:\.git)?$`), + regexp.MustCompile(`^.*@.*:/*(?P.*)/(?P.*?)(?:\.git)?$`), }, repoURLTemplate: defaultRepoURLTemplate, repoNameTemplate: defaultRepoNameTemplate, @@ -42,7 +44,7 @@ var gitLabServiceDef = ServiceDefinition{ pullRequestURLIntoDefaultBranch: "/-/merge_requests/new?merge_request%5Bsource_branch%5D={{.From}}", pullRequestURLIntoTargetBranch: "/-/merge_requests/new?merge_request%5Bsource_branch%5D={{.From}}&merge_request%5Btarget_branch%5D={{.To}}", commitURL: "/-/commit/{{.CommitHash}}", - regexStrings: defaultUrlRegexStrings, + urlRegexps: defaultUrlRegexps, repoURLTemplate: defaultRepoURLTemplate, repoNameTemplate: defaultRepoNameTemplate, } @@ -52,11 +54,11 @@ var azdoServiceDef = ServiceDefinition{ pullRequestURLIntoDefaultBranch: "/pullrequestcreate?sourceRef={{.From}}", pullRequestURLIntoTargetBranch: "/pullrequestcreate?sourceRef={{.From}}&targetRef={{.To}}", commitURL: "/commit/{{.CommitHash}}", - regexStrings: []string{ - `^.+@vs-ssh\.visualstudio\.com[:/](?:v3/)?(?P[^/]+)/(?P[^/]+)/(?P[^/]+?)(?:\.git)?$`, - `^git@ssh.dev.azure.com.*/(?P.*)/(?P.*)/(?P.*?)(?:\.git)?$`, - `^https://.*@dev.azure.com/(?P.*?)/(?P.*?)/_git/(?P.*?)(?:\.git)?$`, - `^https://.*/(?P.*?)/(?P.*?)/_git/(?P.*?)(?:\.git)?$`, + urlRegexps: []*regexp.Regexp{ + regexp.MustCompile(`^.+@vs-ssh\.visualstudio\.com[:/](?:v3/)?(?P[^/]+)/(?P[^/]+)/(?P[^/]+?)(?:\.git)?$`), + regexp.MustCompile(`^git@ssh.dev.azure.com.*/(?P.*)/(?P.*)/(?P.*?)(?:\.git)?$`), + regexp.MustCompile(`^https://.*@dev.azure.com/(?P.*?)/(?P.*?)/_git/(?P.*?)(?:\.git)?$`), + regexp.MustCompile(`^https://.*/(?P.*?)/(?P.*?)/_git/(?P.*?)(?:\.git)?$`), }, repoURLTemplate: "https://{{.webDomain}}/{{.org}}/{{.project}}/_git/{{.repo}}", repoNameTemplate: "{{.org}}/{{.project}}/{{.repo}}", @@ -67,9 +69,9 @@ var bitbucketServerServiceDef = ServiceDefinition{ pullRequestURLIntoDefaultBranch: "/pull-requests?create&sourceBranch={{.From}}", pullRequestURLIntoTargetBranch: "/pull-requests?create&targetBranch={{.To}}&sourceBranch={{.From}}", commitURL: "/commits/{{.CommitHash}}", - regexStrings: []string{ - `^ssh://git@.*/(?P.*)/(?P.*?)(?:\.git)?$`, - `^https://.*/scm/(?P.*)/(?P.*?)(?:\.git)?$`, + urlRegexps: []*regexp.Regexp{ + regexp.MustCompile(`^ssh://git@.*/(?P.*)/(?P.*?)(?:\.git)?$`), + regexp.MustCompile(`^https://.*/scm/(?P.*)/(?P.*?)(?:\.git)?$`), }, repoURLTemplate: "https://{{.webDomain}}/projects/{{.project}}/repos/{{.repo}}", repoNameTemplate: "{{.project}}/{{.repo}}", @@ -80,7 +82,7 @@ var giteaServiceDef = ServiceDefinition{ pullRequestURLIntoDefaultBranch: "/compare/{{.From}}", pullRequestURLIntoTargetBranch: "/compare/{{.To}}...{{.From}}", commitURL: "/commit/{{.CommitHash}}", - regexStrings: defaultUrlRegexStrings, + urlRegexps: defaultUrlRegexps, repoURLTemplate: defaultRepoURLTemplate, } @@ -89,7 +91,7 @@ var codebergServiceDef = ServiceDefinition{ pullRequestURLIntoDefaultBranch: "/compare/{{.From}}", pullRequestURLIntoTargetBranch: "/compare/{{.To}}...{{.From}}", commitURL: "/commit/{{.CommitHash}}", - regexStrings: defaultUrlRegexStrings, + urlRegexps: defaultUrlRegexps, repoURLTemplate: defaultRepoURLTemplate, } diff --git a/pkg/commands/hosting_service/hosting_service.go b/pkg/commands/hosting_service/hosting_service.go index 620d0d0a788..ff2641441e9 100644 --- a/pkg/commands/hosting_service/hosting_service.go +++ b/pkg/commands/hosting_service/hosting_service.go @@ -73,6 +73,42 @@ func (self *HostingServiceMgr) GetRepoName() (string, error) { return repoName, nil } +// ServiceInfo holds the resolved hosting service for a remote URL. Owner +// comes from the "owner" named regex capture, which only exists for +// owner/repo-shaped providers (github, gitlab, bitbucket, gitea, codeberg); +// it's empty for azuredevops and bitbucketServer, whose URLs are organised +// differently. Repository is populated for every provider, but RepoName may +// have more than two segments (e.g. "org/project/repo" for azuredevops). +type ServiceInfo struct { + Provider string // e.g. "github" + WebDomain string // e.g. "github.com", or "git.acme.com" for an on-prem instance + Owner string // e.g. "jesseduffield" + Repository string // e.g. "lazygit" + RepoName string // e.g. "jesseduffield/lazygit" +} + +// GetServiceInfo identifies which hosting service the configured remote URL +// belongs to and returns enough information to talk to its web/API host. +func (self *HostingServiceMgr) GetServiceInfo() (ServiceInfo, error) { + serviceDomain, err := self.getServiceDomain(self.remoteURL) + if err != nil { + return ServiceInfo{}, err + } + + matches, err := serviceDomain.serviceDefinition.parseRemoteUrl(self.remoteURL) + if err != nil { + return ServiceInfo{}, err + } + + return ServiceInfo{ + Provider: serviceDomain.serviceDefinition.provider, + WebDomain: serviceDomain.webDomain, + Owner: matches["owner"], + Repository: matches["repo"], + RepoName: utils.ResolvePlaceholderString(serviceDomain.serviceDefinition.repoNameTemplate, matches), + }, nil +} + func (self *HostingServiceMgr) getService() (*Service, error) { serviceDomain, err := self.getServiceDomain(self.remoteURL) if err != nil { @@ -159,7 +195,7 @@ type ServiceDefinition struct { pullRequestURLIntoDefaultBranch string pullRequestURLIntoTargetBranch string commitURL string - regexStrings []string + urlRegexps []*regexp.Regexp // can expect 'webdomain' to be passed in. Otherwise, you get to pick what we match in the regex repoURLTemplate string @@ -186,8 +222,7 @@ func (self ServiceDefinition) getRepoNameFromRemoteURL(url string) (string, erro } func (self ServiceDefinition) parseRemoteUrl(url string) (map[string]string, error) { - for _, regexStr := range self.regexStrings { - re := regexp.MustCompile(regexStr) + for _, re := range self.urlRegexps { matches := utils.FindNamedMatches(re, url) if matches != nil { return matches, nil @@ -206,8 +241,7 @@ type RepoInformation struct { // GetRepoInfoFromURL parses a remote URL (SSH or HTTPS) and extracts the // owner and repository name using the default URL regex patterns. func GetRepoInfoFromURL(url string) (RepoInformation, error) { - for _, regexStr := range defaultUrlRegexStrings { - re := regexp.MustCompile(regexStr) + for _, re := range defaultUrlRegexps { matches := utils.FindNamedMatches(re, url) if matches != nil { return RepoInformation{ diff --git a/pkg/commands/hosting_service/hosting_service_test.go b/pkg/commands/hosting_service/hosting_service_test.go index c2fabcd0de6..f150f22eb4d 100644 --- a/pkg/commands/hosting_service/hosting_service_test.go +++ b/pkg/commands/hosting_service/hosting_service_test.go @@ -577,3 +577,107 @@ func TestGetPullRequestURL(t *testing.T) { }) } } + +func TestGetServiceInfo(t *testing.T) { + scenarios := []struct { + name string + remoteURL string + configServiceDomains map[string]string + expected ServiceInfo + }{ + { + name: "github.com SSH", + remoteURL: "git@github.com:jesseduffield/lazygit.git", + expected: ServiceInfo{ + Provider: "github", + WebDomain: "github.com", + Owner: "jesseduffield", + Repository: "lazygit", + RepoName: "jesseduffield/lazygit", + }, + }, + { + name: "github enterprise with same git and web host", + remoteURL: "git@github.example.com:my-org/my-repo.git", + configServiceDomains: map[string]string{ + "github.example.com": "github:github.example.com", + }, + expected: ServiceInfo{ + Provider: "github", + WebDomain: "github.example.com", + Owner: "my-org", + Repository: "my-repo", + RepoName: "my-org/my-repo", + }, + }, + { + name: "github enterprise with distinct git and web hosts", + remoteURL: "git@git.example.com:my-org/my-repo.git", + configServiceDomains: map[string]string{ + "git.example.com": "github:ghe.example.com", + }, + expected: ServiceInfo{ + Provider: "github", + WebDomain: "ghe.example.com", + Owner: "my-org", + Repository: "my-repo", + RepoName: "my-org/my-repo", + }, + }, + { + name: "github enterprise with web host port", + remoteURL: "git@git.example.com:my-org/my-repo.git", + configServiceDomains: map[string]string{ + "git.example.com": "github:ghe.example.com:8443", + }, + expected: ServiceInfo{ + Provider: "github", + WebDomain: "ghe.example.com:8443", + Owner: "my-org", + Repository: "my-repo", + RepoName: "my-org/my-repo", + }, + }, + { + // azuredevops uses org/project/repo named captures rather than + // owner/repo, so Owner is unpopulated and RepoName has three + // segments rather than the usual two. + name: "azuredevops", + remoteURL: "https://myorg@dev.azure.com/myorg/myproject/_git/myrepo", + expected: ServiceInfo{ + Provider: "azuredevops", + WebDomain: "dev.azure.com", + Repository: "myrepo", + RepoName: "myorg/myproject/myrepo", + }, + }, + { + // bitbucketServer uses project/repo named captures, so Owner is + // unpopulated and RepoName is project/repo rather than owner/repo. + name: "bitbucketServer", + remoteURL: "https://mycompany.bitbucket.com/scm/myproject/myrepo.git", + configServiceDomains: map[string]string{ + "mycompany.bitbucket.com": "bitbucketServer:mycompany.bitbucket.com", + }, + expected: ServiceInfo{ + Provider: "bitbucketServer", + WebDomain: "mycompany.bitbucket.com", + Repository: "myrepo", + RepoName: "myproject/myrepo", + }, + }, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + tr := i18n.EnglishTranslationSet() + log := &fakes.FakeFieldLogger{} + mgr := NewHostingServiceMgr(log, tr, s.remoteURL, s.configServiceDomains) + + info, err := mgr.GetServiceInfo() + + assert.NoError(t, err) + assert.Equal(t, s.expected, info) + }) + } +} diff --git a/pkg/gui/controllers/helpers/refresh_helper.go b/pkg/gui/controllers/helpers/refresh_helper.go index 77de9ca4a0f..6c554deff70 100644 --- a/pkg/gui/controllers/helpers/refresh_helper.go +++ b/pkg/gui/controllers/helpers/refresh_helper.go @@ -8,6 +8,7 @@ import ( "github.com/jesseduffield/generics/set" "github.com/jesseduffield/lazygit/pkg/commands/git_commands" + "github.com/jesseduffield/lazygit/pkg/commands/hosting_service" "github.com/jesseduffield/lazygit/pkg/commands/models" "github.com/jesseduffield/lazygit/pkg/config" "github.com/jesseduffield/lazygit/pkg/gocui" @@ -812,39 +813,33 @@ func (self *RefreshHelper) refreshGithubPullRequests() { self.c.Mutexes().RefreshingPullRequestsMutex.Lock() defer self.c.Mutexes().RefreshingPullRequestsMutex.Unlock() - if !self.c.Git().GitHub.InGithubRepo(self.c.Model().Remotes) { + githubRemotes := getAuthenticatedGithubRemotes(self.getGithubRemotes(), self.c.Git().GitHub.GetAuthToken) + if len(githubRemotes) == 0 { self.c.Model().PullRequests = nil self.c.Model().PullRequestsMap = nil return } - authToken := self.c.Git().GitHub.GetAuthToken() - if authToken == "" { - self.c.Model().PullRequests = nil - self.c.Model().PullRequestsMap = nil - return - } - - githubRemotes := self.getGithubRemotes() - baseRemote := getGithubBaseRemote(githubRemotes, self.c.Git().GitHub.ConfiguredBaseRemoteName()) - if baseRemote == nil { + baseInfo := getGithubBaseRemote(githubRemotes, self.c.Git().GitHub.ConfiguredBaseRemoteName()) + if baseInfo == nil { self.c.Model().PullRequests = nil self.c.Model().PullRequestsMap = nil - if len(githubRemotes) > 0 && !self.githubBaseRemotePromptDismissed[self.c.Git().RepoPaths.RepoPath()] { - self.promptForBaseGithubRepo(authToken, githubRemotes) + if !self.githubBaseRemotePromptDismissed[self.c.Git().RepoPaths.RepoPath()] { + self.promptForBaseGithubRepo(githubRemotes) } return } - if err := self.setGithubPullRequests(authToken, baseRemote); err != nil { + if err := self.setGithubPullRequests(baseInfo); err != nil { self.c.LogAction(fmt.Sprintf("Error fetching pull requests from GitHub: %s", err.Error())) } } type githubRemoteInfo struct { - remote *models.Remote - repoName string + remote *models.Remote + serviceInfo hosting_service.ServiceInfo + authToken string } func (self *RefreshHelper) getGithubRemotes() []githubRemoteInfo { @@ -852,23 +847,44 @@ func (self *RefreshHelper) getGithubRemotes() []githubRemoteInfo { if len(remote.Urls) == 0 { return githubRemoteInfo{}, false } - repoName, err := self.c.Git().HostingService.GetRepoNameFromRemoteURL(remote.Urls[0]) - if err != nil { + serviceInfo, err := self.c.Git().HostingService.GetServiceInfo(remote.Urls[0]) + if err != nil || serviceInfo.Provider != "github" { + return githubRemoteInfo{}, false + } + return githubRemoteInfo{remote: remote, serviceInfo: serviceInfo}, true + }) +} + +// getAuthenticatedGithubRemotes drops remotes for which no auth token is +// available and attaches the resolved token to the rest. Token lookups are +// cached by host so that multiple remotes pointing at the same instance +// (e.g. origin + a fork on github.com) only trigger one lookup. +func getAuthenticatedGithubRemotes(githubRemotes []githubRemoteInfo, getAuthToken func(host string) string) []githubRemoteInfo { + tokensByHost := map[string]string{} + return lo.FilterMap(githubRemotes, func(info githubRemoteInfo, _ int) (githubRemoteInfo, bool) { + host := info.serviceInfo.WebDomain + token, cached := tokensByHost[host] + if !cached { + token = getAuthToken(host) + tokensByHost[host] = token + } + if token == "" { return githubRemoteInfo{}, false } - return githubRemoteInfo{remote: remote, repoName: repoName}, true + info.authToken = token + return info, true }) } -func getGithubBaseRemote(githubRemotes []githubRemoteInfo, configuredRemoteName string) *models.Remote { - findRemoteByName := func(name string) *models.Remote { +func getGithubBaseRemote(githubRemotes []githubRemoteInfo, configuredRemoteName string) *githubRemoteInfo { + findRemoteByName := func(name string) *githubRemoteInfo { info, ok := lo.Find(githubRemotes, func(info githubRemoteInfo) bool { return info.remote.Name == name }) if !ok { return nil } - return info.remote + return &info } if configuredRemoteName != "" { @@ -876,29 +892,29 @@ func getGithubBaseRemote(githubRemotes []githubRemoteInfo, configuredRemoteName } if len(githubRemotes) == 1 { - return githubRemotes[0].remote + return &githubRemotes[0] } // Not sure if "upstream" is really a common convention for the name of the remote that PRs are // made against, but if it exists it's pretty likely to be the one we want. - if remote := findRemoteByName("upstream"); remote != nil { - return remote + if info := findRemoteByName("upstream"); info != nil { + return info } return nil } -func (self *RefreshHelper) promptForBaseGithubRepo(authToken string, githubRemotes []githubRemoteInfo) { +func (self *RefreshHelper) promptForBaseGithubRepo(githubRemotes []githubRemoteInfo) { menuItems := lo.Map(githubRemotes, func(info githubRemoteInfo, _ int) *types.MenuItem { return &types.MenuItem{ - LabelColumns: []string{info.remote.Name, style.FgCyan.Sprint(info.repoName)}, + LabelColumns: []string{info.remote.Name, style.FgCyan.Sprint(info.serviceInfo.RepoName)}, OnPress: func() error { return self.c.WithWaitingStatus(self.c.Tr.FetchingPullRequests, func(gocui.Task) error { if err := self.c.Git().GitHub.SetConfiguredBaseRemoteName(info.remote.Name); err != nil { self.c.Log.Error(err) } - if err := self.setGithubPullRequests(authToken, info.remote); err != nil { + if err := self.setGithubPullRequests(&info); err != nil { self.c.LogAction(fmt.Sprintf("Error fetching pull requests from GitHub: %s", err.Error())) } return nil @@ -928,7 +944,7 @@ func (self *RefreshHelper) rebuildPullRequestsMap() { ) } -func (self *RefreshHelper) setGithubPullRequests(authToken string, baseRemote *models.Remote) error { +func (self *RefreshHelper) setGithubPullRequests(baseInfo *githubRemoteInfo) error { if len(self.c.Model().Branches) == 0 { return nil } @@ -940,7 +956,7 @@ func (self *RefreshHelper) setGithubPullRequests(authToken string, baseRemote *m return branch.UpstreamBranch }) - prs, err := self.c.Git().GitHub.FetchRecentPRs(branchNames, baseRemote, authToken) + prs, err := self.c.Git().GitHub.FetchRecentPRs(branchNames, &baseInfo.serviceInfo, baseInfo.authToken) if err != nil { return err } diff --git a/pkg/gui/controllers/helpers/refresh_helper_test.go b/pkg/gui/controllers/helpers/refresh_helper_test.go index dea3b8b81d2..cebd044c4ed 100644 --- a/pkg/gui/controllers/helpers/refresh_helper_test.go +++ b/pkg/gui/controllers/helpers/refresh_helper_test.go @@ -3,6 +3,7 @@ package helpers import ( "testing" + "github.com/jesseduffield/lazygit/pkg/commands/hosting_service" "github.com/jesseduffield/lazygit/pkg/commands/models" "github.com/samber/lo" "github.com/stretchr/testify/assert" @@ -60,14 +61,64 @@ func TestGetGithubBaseRemote(t *testing.T) { assert.Nil(t, result) } else { assert.NotNil(t, result) - assert.Equal(t, c.expected, result.Name) + assert.Equal(t, c.expected, result.remote.Name) } }) } } +func TestGetAuthenticatedGithubRemotes(t *testing.T) { + githubRemotes := []githubRemoteInfo{ + makeGithubRemoteInfo("origin", "github.com"), + makeGithubRemoteInfo("fork", "github.com"), + makeGithubRemoteInfo("enterprise", "ghe.example.com"), + makeGithubRemoteInfo("missing-auth", "no-token.example.com"), + } + + callsByHost := map[string]int{} + result := getAuthenticatedGithubRemotes(githubRemotes, func(host string) string { + callsByHost[host]++ + switch host { + case "github.com": + return "github-token" + case "ghe.example.com": + return "ghe-token" + default: + return "" + } + }) + + assert.Equal(t, []githubRemoteInfo{ + makeAuthenticatedGithubRemoteInfo("origin", "github.com", "github-token"), + makeAuthenticatedGithubRemoteInfo("fork", "github.com", "github-token"), + makeAuthenticatedGithubRemoteInfo("enterprise", "ghe.example.com", "ghe-token"), + }, result) + // Two remotes share github.com; the lookup runs only once. + assert.Equal(t, map[string]int{ + "github.com": 1, + "ghe.example.com": 1, + "no-token.example.com": 1, + }, callsByHost) +} + func makeGithubRemoteInfoList(names ...string) []githubRemoteInfo { return lo.Map(names, func(name string, _ int) githubRemoteInfo { - return githubRemoteInfo{remote: &models.Remote{Name: name}, repoName: name} + return makeGithubRemoteInfo(name, name) }) } + +func makeGithubRemoteInfo(name string, webDomain string) githubRemoteInfo { + return githubRemoteInfo{ + remote: &models.Remote{Name: name}, + serviceInfo: hosting_service.ServiceInfo{ + RepoName: name, + WebDomain: webDomain, + }, + } +} + +func makeAuthenticatedGithubRemoteInfo(name string, webDomain string, authToken string) githubRemoteInfo { + info := makeGithubRemoteInfo(name, webDomain) + info.authToken = authToken + return info +}