diff --git a/src/dcc-cli.ts b/src/dcc-cli.ts index 53d677f..b5ca306 100644 --- a/src/dcc-cli.ts +++ b/src/dcc-cli.ts @@ -176,8 +176,11 @@ async function submit() { return } - const checks = await githubOps.getChecks(pr?.number) - const isRequired = (c: Check) => pr.requiredChecks.includes(c.name) + const [checks, required] = await Promise.all([ + githubOps.getChecks(pr.number), + gitOps.mainBranch().then(b => githubOps.getAllRequiredChecks(b, pr)), + ]) + const isRequired = (c: Check) => required.has(c.name) const printChecksWithRequired = (list: Check[]) => { for (const c of list) { printCheck(c, isRequired(c)) @@ -261,7 +264,10 @@ async function status() { if (!pr) { print('No PR was created for this branch') } else { - const checks: Check[] = await githubOps.getChecks(pr?.number) + const [checks, required] = await Promise.all([ + githubOps.getChecks(pr.number), + gitOps.mainBranch().then(b => githubOps.getAllRequiredChecks(b, pr)), + ]) print(`PR #${pr.number}: ${pr.title}`) print(pr.url) if (pr.lastCommit) { @@ -280,9 +286,9 @@ async function status() { const orderOfCheck = (c: Check) => c.tag === 'PASSING' ? 0 : c.tag === 'PENDING' ? 1 : c.tag === 'FAILING' ? 2 : shouldNeverHappen(c) for (const c of checks.sort((a, b) => orderOfCheck(a) - orderOfCheck(b))) { - printCheck(c, pr.requiredChecks.includes(c.name)) + printCheck(c, required.has(c.name)) } - const numRequired = checks.filter(c => pr.requiredChecks.includes(c.name)).length + const numRequired = checks.filter(c => required.has(c.name)).length print(` (${numRequired}/${checks.length} are required 🔒)`) print() } diff --git a/src/github-ops.ts b/src/github-ops.ts index 0a1cec9..9788a4c 100644 --- a/src/github-ops.ts +++ b/src/github-ops.ts @@ -1,7 +1,22 @@ import { GitOps } from './git-ops.js' import { Octokit } from '@octokit/rest' +import { z } from 'zod' import { logger } from './logger.js' +const BranchRulesSchema = z.array( + z + .object({ + type: z.string(), + parameters: z + .object({ + required_status_checks: z.array(z.object({ context: z.string() }).passthrough()).optional(), + }) + .passthrough() + .optional(), + }) + .passthrough(), +) + export type Check = | { tag: 'FAILING' @@ -44,6 +59,33 @@ export class GithubOps { return [...pending, ...passing, ...failing] } + async getAllRequiredChecks(branch: string, pr: { protectionRequiredChecks: string[] }): Promise> { + const fromRulesets = await this.getRulesetRequiredChecks(branch) + return new Set([...pr.protectionRequiredChecks, ...fromRulesets]) + } + + async getRulesetRequiredChecks(branch: string): Promise { + const r = await this.gitOps.getRepo() + try { + const resp = await this.kit.request('GET /repos/{owner}/{repo}/rules/branches/{branch}', { + owner: r.owner, + repo: r.name, + branch, + }) + const rules = BranchRulesSchema.parse(resp.data) + return rules + .filter(rule => rule.type === 'required_status_checks') + .flatMap(rule => rule.parameters?.required_status_checks ?? []) + .map(c => c.context) + } catch (err) { + const status = (err as { status?: number })?.status + if (status !== 404) { + logger.warn(`Failed to fetch ruleset required checks (status=${status ?? 'unknown'}): ${err}`) + } + return [] + } + } + async merge(prNumber: number): Promise { const r = await this.gitOps.getRepo() await this.kit.pulls.merge({ owner: r.owner, repo: r.name, pull_number: prNumber, merge_method: 'squash' }) diff --git a/src/gql.ts b/src/gql.ts index 0dd3110..fbdb866 100644 --- a/src/gql.ts +++ b/src/gql.ts @@ -12,7 +12,7 @@ export interface CurrentPrInfo { mergeabilityStatus: MergeabilityStatus url: string openUrl: string - requiredChecks: string[] + protectionRequiredChecks: string[] lastCommit?: { message: string abbreviatedOid?: string @@ -201,7 +201,7 @@ export class GraphqlOps { const mainBranch = await this.gitOps.mainBranch() const protectionRules = repository?.branchProtectionRules?.nodes ?? [] - const requiredChecks = protectionRules + const protectionRequiredChecks = protectionRules .filter(rule => rule.requiresStatusChecks && rule.matchingRefs.nodes.some(ref => ref.name === mainBranch)) .flatMap(rule => rule.requiredStatusCheckContexts) @@ -222,7 +222,7 @@ export class GraphqlOps { mergeabilityStatus, url, openUrl, - requiredChecks, + protectionRequiredChecks, lastCommit: commit && { message: commit?.message, abbreviatedOid: commit?.abbreviatedOid,