diff --git a/src/App.hs b/src/App.hs
index 3e60ee8..94029d5 100644
--- a/src/App.hs
+++ b/src/App.hs
@@ -63,6 +63,8 @@ getSettings = do
   primeCacheMode <- (==Just "1") <$> lookupEnv "TASKRUNNER_PRIME_CACHE_MODE"
   mainBranch <- map toText <$> lookupEnv "TASKRUNNER_MAIN_BRANCH"
   quietMode <- (==Just "1") <$> lookupEnv "TASKRUNNER_QUIET"
+  -- An unset or unparsable value falls back to 300 seconds instead of crashing in 'read'
+  githubTokenRefreshThresholdSeconds <- (fromMaybe 300 . (>>= readMaybe)) <$> lookupEnv "TASKRUNNER_GITHUB_TOKEN_REFRESH_THRESHOLD_SECONDS"
   pure Settings
     { stateDirectory
     , rootDirectory
@@ -78,6 +80,7 @@ getSettings = do
     , mainBranch
     , force = False
     , quietMode
+    , githubTokenRefreshThresholdSeconds
     }
 
 main :: IO ()
diff --git a/src/CommitStatus.hs b/src/CommitStatus.hs
index da43fa6..2ce6806 100644
--- a/src/CommitStatus.hs
+++ b/src/CommitStatus.hs
@@ -4,19 +4,24 @@ module CommitStatus where
 
 import Universum
 
-import Data.Aeson (FromJSON(..), ToJSON(..), encode)
+import Data.Aeson (FromJSON(..), ToJSON(..), encode, eitherDecodeFileStrict)
 import Data.Time.Clock.POSIX (getPOSIXTime)
+import Data.Time.Clock (UTCTime, getCurrentTime, diffUTCTime)
+import Data.Time.Format.ISO8601 (iso8601ParseM, iso8601Show)
 import Web.JWT (Algorithm(RS256), JWTClaimsSet(..), encodeSigned, numericDate, stringOrURI, EncodeSigner (..), readRsaSecret, JOSEHeader (..))
 import qualified Data.Text as T
 import qualified Data.Text.Encoding as TE
 import qualified Network.HTTP.Client as HTTP
 import Network.HTTP.Client.TLS (tlsManagerSettings)
-import System.Environment (getEnv, lookupEnv, setEnv)
+import System.Environment (getEnv, lookupEnv)
 import Network.HTTP.Types.Status (Status(..))
 import Data.Aeson.Decoding (eitherDecode)
 import qualified Data.Text as Text
+import qualified Data.ByteString.Lazy as BL
+import System.FileLock (withFileLock, SharedExclusive(..))
+import System.Directory (doesFileExist)
 import Utils (getCurrentCommit, logError, logDebug)
-import Types (AppState(..), GithubClient(..))
+import Types (AppState(..), GithubClient(..), Settings(..))
 
 -- Define the data types for the status update
 data StatusRequest = StatusRequest
@@ -35,23 +40,113 @@ data StatusResponse = StatusResponse
     deriving anyclass (FromJSON)
 
 -- Define the data type for the installation token response
-newtype InstallationTokenResponse = InstallationTokenResponse
+data InstallationTokenResponse = InstallationTokenResponse
   { token :: T.Text
+  , expires_at :: T.Text
   } deriving (Show, Generic)
     deriving anyclass (FromJSON)
 
+-- Cache file for storing credentials across processes
+data CredentialsCache = CredentialsCache
+  { cachedToken :: T.Text
+  , cachedExpiresAt :: T.Text
+  } deriving (Show, Generic)
+    deriving anyclass (FromJSON, ToJSON)
+
+-- Location of the on-disk token cache inside the state directory
+credentialsCacheFile :: Settings -> FilePath
+credentialsCacheFile settings = settings.stateDirectory <> "/.github-token-cache.json"
+
+-- Try to read cache file (no locking - caller should hold lock)
+tryReadCache :: FilePath -> IO (Maybe (T.Text, UTCTime))
+tryReadCache cacheFile = do
+  exists <- doesFileExist cacheFile
+  if exists then do
+    result <- eitherDecodeFileStrict @CredentialsCache cacheFile
+    case result of
+      Left _ -> pure Nothing
+      Right cache ->
+        case iso8601ParseM (toString cache.cachedExpiresAt) of
+          Just expiresAt -> pure $ Just (cache.cachedToken, expiresAt)
+          Nothing -> pure Nothing
+  else
+    pure Nothing
+
+-- Return a client whose token stays valid for at least the refresh threshold,
+-- falling back to the file cache or a fresh token when the in-memory one is stale
 getClient :: AppState -> IO GithubClient
 getClient appState = do
   mClient <- readIORef appState.githubClient
   case mClient of
-    Just client -> pure client
+    Just client -> do
+      -- Fast path: check if cached token is still valid
+      now <- getCurrentTime
+      let threshold = fromIntegral appState.settings.githubTokenRefreshThresholdSeconds
+      if diffUTCTime client.expiresAt now >= threshold
+        then pure client
+        else do
+          -- Token expiring, need to refresh
+          logDebug appState $ "GitHub token expired or expiring soon (in " <> show (floor (diffUTCTime client.expiresAt now) :: Int) <> "s), refreshing..."
+          writeIORef appState.githubClient Nothing
+          loadOrRefreshClient appState
+    Nothing ->
+      loadOrRefreshClient appState
+
+-- Under an exclusive file lock: reuse the cached token if still valid, otherwise
+-- request a new one; the resulting client is stored in appState.githubClient
+loadOrRefreshClient :: AppState -> IO GithubClient
+loadOrRefreshClient appState = do
+  let cacheFile = credentialsCacheFile appState.settings
+  let lockFile = cacheFile <> ".lock"
+  let threshold = fromIntegral appState.settings.githubTokenRefreshThresholdSeconds
+
+  client <- withFileLock lockFile Exclusive \_ -> do
+    -- Under EXCLUSIVE lock: read, check, refresh if needed
+    mCached <- tryReadCache cacheFile
+
+    now <- getCurrentTime
+    case mCached of
+      Just (cachedToken, expiresAt)
+        | diffUTCTime expiresAt now >= threshold -> do
+            -- Valid cached token
+            logDebug appState "Using cached GitHub token from file"
+            buildClientWithToken appState cachedToken expiresAt
+        | otherwise -> do
+            -- Expired token, refresh
+            logDebug appState "Cached token expired, refreshing"
+            refreshToken appState cacheFile
+      Nothing -> do
+        -- No cache, create new token
+        logDebug appState "No cached token, creating new one"
+        refreshToken appState cacheFile
+
+  writeIORef appState.githubClient (Just client)
+  pure client
+
+-- Create new token and write to cache (caller should hold EXCLUSIVE lock)
+refreshToken :: AppState -> FilePath -> IO GithubClient
+refreshToken appState cacheFile = do
+  tokenResponse <- createTokenFromGitHub appState
+
+  expiresAt <- case iso8601ParseM (toString tokenResponse.expires_at) of
+    Just t -> pure t
     Nothing -> do
-      client <- initClient appState
-      writeIORef appState.githubClient $ Just client
-      pure client
+      logError appState $ "CommitStatus: Failed to parse expires_at: " <> tokenResponse.expires_at
+      exitFailure
+
+  -- Write to cache (already under EXCLUSIVE lock, no additional locking needed)
+  let cache = CredentialsCache
+        { cachedToken = tokenResponse.token
+        , cachedExpiresAt = T.pack $ iso8601Show expiresAt
+        }
+  BL.writeFile cacheFile (encode cache)
+
+  -- Build and return client
+  buildClientWithToken appState tokenResponse.token expiresAt
 
-initClient :: AppState -> IO GithubClient
-initClient appState = do
+-- Assemble a client from environment variables and an already obtained token
+buildClientWithToken :: AppState -> T.Text -> UTCTime -> IO GithubClient
+buildClientWithToken _appState accessToken expiresAt = do
   -- Load environment variables
   apiUrl <- fromMaybe "https://api.github.com" <$> lookupEnv "GITHUB_API_URL"
   appId <- getEnv "GITHUB_APP_ID"
@@ -59,62 +154,64 @@ initClient appState = do
   privateKeyStr <- getEnv "GITHUB_APP_PRIVATE_KEY"
   owner <- getEnv "GITHUB_REPOSITORY_OWNER"
   repo <- getEnv "GITHUB_REPOSITORY"
+  manager <- HTTP.newManager tlsManagerSettings
+
+  pure $ GithubClient
+    { apiUrl = T.pack apiUrl
+    , appId = T.pack appId
+    , installationId = T.pack installationId
+    , privateKey = T.pack privateKeyStr
+    , owner = T.pack owner
+    , repo = T.pack repo
+    , manager = manager
+    , accessToken = accessToken
+    , expiresAt = expiresAt
+    }
+
+-- Create a new GitHub App installation token from GitHub API
+createTokenFromGitHub :: AppState -> IO InstallationTokenResponse
+createTokenFromGitHub appState = do
+  -- Load environment variables
+  apiUrl <- fromMaybe "https://api.github.com" <$> lookupEnv "GITHUB_API_URL"
+  appId <- getEnv "GITHUB_APP_ID"
+  installationId <- getEnv "GITHUB_INSTALLATION_ID"
+  privateKeyStr <- getEnv "GITHUB_APP_PRIVATE_KEY"
+
   -- Prepare the HTTP manager
   manager <- HTTP.newManager tlsManagerSettings
 
-  let createToken = do
-        let privateKeyBytes = encodeUtf8 $ Text.replace "|" "\n" $ toText privateKeyStr
-        let privateKey = fromMaybe (error "Invalid github key") $ readRsaSecret privateKeyBytes
-
-        -- Create the JWT token
-        now <- getPOSIXTime
-        let claims = mempty { iss = stringOrURI $ T.pack appId
-                            , iat = numericDate now
-                            , exp = numericDate (now + 5 * 60)
-                            }
-        let jwt = encodeSigned (EncodeRSAPrivateKey privateKey) (mempty { alg = Just RS256 }) claims
-
-        -- Get the installation access token
-        let installUrl = apiUrl <> "/app/installations/" ++ installationId ++ "/access_tokens"
-        initRequest <- HTTP.parseRequest installUrl
-        let request = initRequest
-              { HTTP.method = "POST"
-              , HTTP.requestHeaders =
-                  [ ("Authorization", "Bearer " <> TE.encodeUtf8 jwt)
-                  , ("Accept", "application/vnd.github.v3+json")
-                  , ("User-Agent", "restaumatic-bot")
-                  ]
-              }
-        response <- HTTP.httpLbs request manager
-        let mTokenResponse = eitherDecode @InstallationTokenResponse (HTTP.responseBody response)
-        case mTokenResponse of
-          Left err -> do
-            logError appState $ "CommitStatus: Failed to parse installation token response: " <> show err
-            logError appState $ "CommitStatus: Response: " <> decodeUtf8 response.responseBody
-
-            -- FIXME: handle the error better
-            exitFailure
-          Right tokenResponse ->
-            pure tokenResponse.token
-
-  -- Try to read token from environment variable
-  -- Otherwise generate a new one, and set env for future uses (also in child processes)
-  accessToken <- lookupEnv "_taskrunner_github_access_token" >>= \case
-    Just token -> pure $ T.pack token
-    Nothing -> do
-      token <- createToken
-      setEnv "_taskrunner_github_access_token" $ T.unpack token
-      pure token
-
-  pure $ GithubClient { apiUrl = T.pack apiUrl
-                      , appId = T.pack appId
-                      , installationId = T.pack installationId
-                      , privateKey = T.pack privateKeyStr
-                      , owner = T.pack owner
-                      , repo = T.pack repo
-                      , manager = manager
-                      , accessToken = accessToken
-                      }
+  let privateKeyBytes = encodeUtf8 $ Text.replace "|" "\n" $ toText privateKeyStr
+  let privateKey = fromMaybe (error "Invalid github key") $ readRsaSecret privateKeyBytes
+
+  -- Create the JWT token
+  now <- getPOSIXTime
+  let claims = mempty { iss = stringOrURI $ T.pack appId
+                      , iat = numericDate now
+                      , exp = numericDate (now + 5 * 60)
+                      }
+  let jwt = encodeSigned (EncodeRSAPrivateKey privateKey) (mempty { alg = Just RS256 }) claims
+
+  -- Get the installation access token
+  let installUrl = apiUrl <> "/app/installations/" ++ installationId ++ "/access_tokens"
+  initRequest <- HTTP.parseRequest installUrl
+  let request = initRequest
+        { HTTP.method = "POST"
+        , HTTP.requestHeaders =
+            [ ("Authorization", "Bearer " <> TE.encodeUtf8 jwt)
+            , ("Accept", "application/vnd.github.v3+json")
+            , ("User-Agent", "restaumatic-bot")
+            ]
+        }
+  response <- HTTP.httpLbs request manager
+  let mTokenResponse = eitherDecode @InstallationTokenResponse (HTTP.responseBody response)
+  case mTokenResponse of
+    Left err -> do
+      logError appState $ "CommitStatus: Failed to parse installation token response: " <> show err
+      logError appState $ "CommitStatus: Response: " <> decodeUtf8 response.responseBody
+      -- FIXME: handle the error better
+      exitFailure
+    Right tokenResponse ->
+      pure tokenResponse
 
 updateCommitStatus :: MonadIO m => AppState -> StatusRequest -> m ()
 updateCommitStatus appState statusRequest = liftIO do
diff --git a/src/Types.hs b/src/Types.hs
index 8ee2230..8b6a70d 100644
--- a/src/Types.hs
+++ b/src/Types.hs
@@ -4,6 +4,7 @@ import Universum
 import SnapshotCliArgs (SnapshotCliArgs)
 import Data.Aeson (FromJSON, ToJSON)
 import qualified Network.HTTP.Client as HTTP
+import Data.Time.Clock (UTCTime)
 
 data Settings = Settings
   { stateDirectory :: FilePath
@@ -20,6 +21,7 @@ data Settings = Settings
   , mainBranch :: Maybe Text
   , force :: Bool
   , quietMode :: Bool
+  , githubTokenRefreshThresholdSeconds :: Int
   } deriving (Show)
 
 type JobName = String
@@ -66,4 +68,5 @@ data GithubClient = GithubClient
   , repo :: Text
   , manager :: HTTP.Manager
   , accessToken :: Text
+  , expiresAt :: UTCTime
   }
diff --git a/test/FakeGithubApi.hs b/test/FakeGithubApi.hs
index 929cf43..c1645b4 100644
--- a/test/FakeGithubApi.hs
+++ b/test/FakeGithubApi.hs
@@ -1,7 +1,7 @@
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE RecursiveDo #-}
 
-module FakeGithubApi (Server, start, stop, clearOutput, getOutput) where
+module FakeGithubApi (Server, start, stop, clearOutput, getOutput, setTokenLifetime) where
 
 import Universum
 
@@ -11,6 +11,8 @@ import Network.HTTP.Types (status200, status201, status400, status404, methodPos
 import Data.Aeson (encode, object, (.=), Value)
 import qualified Data.Aeson as Aeson
 import qualified Data.Map.Strict as Map
+import Data.Time.Clock (getCurrentTime, addUTCTime)
+import Data.Time.Format.ISO8601 (iso8601Show)
 
 import Control.Concurrent (forkIO, ThreadId, killThread)
 
@@ -31,9 +33,17 @@ handleAccessTokenRequest :: Server -> Text -> Request -> (Response -> IO Respons
 handleAccessTokenRequest server instId req respond =
   if requestMethod req == methodPost
     then do
+      -- Read token lifetime from server state
+      lifetimeSeconds <- readIORef server.tokenLifetimeSeconds
+      now <- getCurrentTime
+      let expiresAt = addUTCTime (fromIntegral lifetimeSeconds) now
       addOutput server $ "Requested access token for installation " <> instId
       respond $ responseLBS status200 [("Content-Type", "application/json")]
-        (encode $ object ["token" .= ("mock-access-token" :: Text), "installation_id" .= instId])
+        (encode $ object
+          [ "token" .= ("mock-access-token" :: Text)
+          , "expires_at" .= iso8601Show expiresAt
+          , "installation_id" .= instId
+          ])
     else respond $ responseLBS status400 [] "Bad Request"
 
 handleCommitStatusRequest :: Server -> Text -> Text -> Text -> Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived
@@ -61,6 +71,7 @@ data Server = Server
   { tid :: ThreadId
   , output :: IORef [Text]
   , statuses :: IORef (Map Text [Value]) -- Map from commit SHA to list of status objects
+  , tokenLifetimeSeconds :: IORef Int
   }
 
 start :: Int -> IO Server
@@ -68,9 +79,10 @@ start port = do
   started <- newEmptyMVar
   output <- newIORef []
   statuses <- newIORef Map.empty
+  tokenLifetimeSeconds <- newIORef 3600 -- Default: 1 hour
   let settings = Warp.setPort port $ Warp.setBeforeMainLoop (putMVar started ()) Warp.defaultSettings
   rec
-    let server = Server {tid, output, statuses}
+    let server = Server {tid, output, statuses, tokenLifetimeSeconds}
     tid <- forkIO $ Warp.runSettings settings $ app server
   takeMVar started
   pure server
@@ -82,9 +94,10 @@ addOutput :: Server -> Text -> IO ()
 addOutput (Server {output}) msg = modifyIORef output (msg :)
 
 clearOutput :: Server -> IO ()
-clearOutput (Server {output, statuses}) = do
+clearOutput (Server {output, statuses, tokenLifetimeSeconds}) = do
   writeIORef output []
   writeIORef statuses Map.empty
+  writeIORef tokenLifetimeSeconds 3600 -- Reset to default
 
 getOutput :: Server -> IO [Text]
 getOutput (Server {output}) = reverse <$> readIORef output
@@ -100,3 +113,7 @@ getStatuses :: Server -> Text -> IO [Value]
 getStatuses (Server {statuses}) commitSha = do
   statusMap <- readIORef statuses
   pure $ fromMaybe [] $ Map.lookup commitSha statusMap
+
+-- Set the lifetime, in seconds, used for tokens issued by later access-token requests
+setTokenLifetime :: Server -> Int -> IO ()
+setTokenLifetime server seconds = writeIORef server.tokenLifetimeSeconds seconds
diff --git a/test/Spec.hs b/test/Spec.hs
index 357c402..2f8d1d6 100644
--- a/test/Spec.hs
+++ b/test/Spec.hs
@@ -87,6 +87,10 @@ runTest fakeGithubServer source = do
   withSystemTempDirectory "testrunner-test" \dir -> do
     let options = getOptions (toText source)
 
+    -- Set token lifetime if specified in test
+    whenJust options.githubTokenLifetime $ \lifetime ->
+      FakeGithubApi.setTokenLifetime fakeGithubServer lifetime
+
     (pipeRead, pipeWrite) <- createPipe
     path <- getEnv "PATH"
 
@@ -170,6 +174,7 @@ data Options = Options
   -- If github status is disabled, taskrunner should work without them.
   , githubKeys :: Bool
   , quiet :: Bool
+  , githubTokenLifetime :: Maybe Int
   }
 
 instance Default Options where
@@ -179,6 +184,7 @@ instance Default Options where
     , s3 = False
     , githubKeys = False
     , quiet = False
+    , githubTokenLifetime = Nothing
     }
 
 getOptions :: Text -> Options
@@ -198,6 +204,9 @@ getOptions source = flip execState def $ go (lines source)
       ["#", "github", "keys"] -> do
         modify (\s -> s { githubKeys = True })
         go rest
+      ["#", "github", "token", "lifetime", n] -> do
+        modify (\s -> s { githubTokenLifetime = readMaybe (toString n) })
+        go rest
       ["#", "quiet"] -> do
         modify (\s -> (s :: Options) { quiet = True })
         go rest
diff --git a/test/t/github-commit-status-failure-then-success.out b/test/t/github-commit-status-failure-then-success.out
index 2a4af21..e2c044d 100644
--- a/test/t/github-commit-status-failure-then-success.out
+++ b/test/t/github-commit-status-failure-then-success.out
@@ -4,5 +4,4 @@
 -- github:
 Requested access token for installation 123
 Updated commit status for fakeowner/fakerepo to {"context":"mytask","description":null,"state":"failure","target_url":null}
-Requested access token for installation 123
 Updated commit status for fakeowner/fakerepo to {"context":"mytask","description":null,"state":"success","target_url":null}
diff --git a/test/t/slow/github-token-refresh.out b/test/t/slow/github-token-refresh.out
new file mode 100644
index 0000000..7edeed5
--- /dev/null
+++ b/test/t/slow/github-token-refresh.out
@@ -0,0 +1,8 @@
+-- output:
+[mytask] stdout | Task started, pending status posted
+[mytask] stdout | Task finishing (token will be refreshed for final status)
+-- github:
+Requested access token for installation 123
+Updated commit status for fakeowner/fakerepo to {"context":"mytask","description":"not cached","state":"pending","target_url":null}
+Requested access token for installation 123
+Updated commit status for fakeowner/fakerepo to {"context":"mytask","description":null,"state":"success","target_url":null}
diff --git a/test/t/slow/github-token-refresh.txt b/test/t/slow/github-token-refresh.txt
new file mode 100644
index 0000000..42fb61e
--- /dev/null
+++ b/test/t/slow/github-token-refresh.txt
@@ -0,0 +1,17 @@
+# check output github
+# no toplevel
+# github keys
+# github token lifetime 2
+
+export TASKRUNNER_ENABLE_COMMIT_STATUS=1
+export TASKRUNNER_GITHUB_TOKEN_REFRESH_THRESHOLD_SECONDS=1
+
+git init -q
+git commit --allow-empty -q -m "Initial commit"
+
+taskrunner -n mytask bash -e -c '
+  snapshot -n --commit-status
+  echo "Task started, pending status posted"
+  sleep 3
+  echo "Task finishing (token will be refreshed for final status)"
+'