Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ getSettings = do
primeCacheMode <- (==Just "1") <$> lookupEnv "TASKRUNNER_PRIME_CACHE_MODE"
mainBranch <- map toText <$> lookupEnv "TASKRUNNER_MAIN_BRANCH"
quietMode <- (==Just "1") <$> lookupEnv "TASKRUNNER_QUIET"
githubTokenRefreshThresholdSeconds <- maybe 300 read <$> lookupEnv "TASKRUNNER_GITHUB_TOKEN_REFRESH_THRESHOLD_SECONDS"
pure Settings
{ stateDirectory
, rootDirectory
Expand All @@ -78,6 +79,7 @@ getSettings = do
, mainBranch
, force = False
, quietMode
, githubTokenRefreshThresholdSeconds
}

main :: IO ()
Expand Down
217 changes: 154 additions & 63 deletions src/CommitStatus.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@ module CommitStatus where

import Universum

import Data.Aeson (FromJSON(..), ToJSON(..), encode)
import Data.Aeson (FromJSON(..), ToJSON(..), encode, eitherDecodeFileStrict)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.Time.Clock (UTCTime, getCurrentTime, diffUTCTime)
import Data.Time.Format.ISO8601 (iso8601ParseM, iso8601Show)
import Web.JWT (Algorithm(RS256), JWTClaimsSet(..), encodeSigned, numericDate, stringOrURI, EncodeSigner (..), readRsaSecret, JOSEHeader (..))
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Network.HTTP.Client as HTTP
import Network.HTTP.Client.TLS (tlsManagerSettings)
import System.Environment (getEnv, lookupEnv, setEnv)
import System.Environment (getEnv, lookupEnv)
import Network.HTTP.Types.Status (Status(..))
import Data.Aeson.Decoding (eitherDecode)
import qualified Data.Text as Text
import qualified Data.ByteString.Lazy as BL
import System.FileLock (withFileLock, SharedExclusive(..))
import System.Directory (doesFileExist)
import Utils (getCurrentCommit, logError, logDebug)
import Types (AppState(..), GithubClient(..))
import Types (AppState(..), GithubClient(..), Settings(..))

-- Define the data types for the status update
data StatusRequest = StatusRequest
Expand All @@ -35,86 +40,172 @@ data StatusResponse = StatusResponse
deriving anyclass (FromJSON)

-- Define the data type for the installation token response
newtype InstallationTokenResponse = InstallationTokenResponse
data InstallationTokenResponse = InstallationTokenResponse
{ token :: T.Text
, expires_at :: T.Text
} deriving (Show, Generic)
deriving anyclass (FromJSON)

-- Cache file for storing credentials across processes
data CredentialsCache = CredentialsCache
{ cachedToken :: T.Text
, cachedExpiresAt :: T.Text
} deriving (Show, Generic)
deriving anyclass (FromJSON, ToJSON)

credentialsCacheFile :: Settings -> FilePath
credentialsCacheFile settings = settings.stateDirectory <> "/.github-token-cache.json"

-- Try to read cache file (no locking - caller should hold lock)
tryReadCache :: FilePath -> IO (Maybe (T.Text, UTCTime))
tryReadCache cacheFile = do
exists <- doesFileExist cacheFile
if exists then do
result <- eitherDecodeFileStrict @CredentialsCache cacheFile
case result of
Left _ -> pure Nothing
Right cache ->
case iso8601ParseM (toString cache.cachedExpiresAt) of
Just expiresAt -> pure $ Just (cache.cachedToken, expiresAt)
Nothing -> pure Nothing
else
pure Nothing

getClient :: AppState -> IO GithubClient
getClient appState = do
mClient <- readIORef appState.githubClient
case mClient of
Just client -> pure client
Just client -> do
-- Fast path: check if cached token is still valid
now <- getCurrentTime
let threshold = fromIntegral appState.settings.githubTokenRefreshThresholdSeconds
if diffUTCTime client.expiresAt now >= threshold
then pure client
else do
-- Token expiring, need to refresh
logDebug appState $ "GitHub token expired or expiring soon (in " <> show (floor (diffUTCTime client.expiresAt now) :: Int) <> "s), refreshing..."
writeIORef appState.githubClient Nothing
loadOrRefreshClient appState
Nothing ->
loadOrRefreshClient appState

loadOrRefreshClient :: AppState -> IO GithubClient
loadOrRefreshClient appState = do
let cacheFile = credentialsCacheFile appState.settings
let lockFile = cacheFile <> ".lock"
let threshold = fromIntegral appState.settings.githubTokenRefreshThresholdSeconds

client <- withFileLock lockFile Exclusive \_ -> do
-- Under EXCLUSIVE lock: read, check, refresh if needed
mCached <- tryReadCache cacheFile

now <- getCurrentTime
case mCached of
Just (cachedToken, expiresAt)
| diffUTCTime expiresAt now >= threshold -> do
-- Valid cached token
logDebug appState "Using cached GitHub token from file"
buildClientWithToken appState cachedToken expiresAt
| otherwise -> do
-- Expired token, refresh
logDebug appState "Cached token expired, refreshing"
refreshToken appState cacheFile
Nothing -> do
-- No cache, create new token
logDebug appState "No cached token, creating new one"
refreshToken appState cacheFile

writeIORef appState.githubClient (Just client)
pure client

-- Create new token and write to cache (caller should hold EXCLUSIVE lock)
refreshToken :: AppState -> FilePath -> IO GithubClient
refreshToken appState cacheFile = do
tokenResponse <- createTokenFromGitHub appState

expiresAt <- case iso8601ParseM (toString tokenResponse.expires_at) of
Just t -> pure t
Nothing -> do
client <- initClient appState
writeIORef appState.githubClient $ Just client
pure client
logError appState $ "CommitStatus: Failed to parse expires_at: " <> tokenResponse.expires_at
exitFailure

-- Write to cache (already under EXCLUSIVE lock, no additional locking needed)
let cache = CredentialsCache
{ cachedToken = tokenResponse.token
, cachedExpiresAt = T.pack $ iso8601Show expiresAt
}
BL.writeFile cacheFile (encode cache)

-- Build and return client
buildClientWithToken appState tokenResponse.token expiresAt

initClient :: AppState -> IO GithubClient
initClient appState = do
buildClientWithToken :: AppState -> T.Text -> UTCTime -> IO GithubClient
buildClientWithToken _appState accessToken expiresAt = do
-- Load environment variables
apiUrl <- fromMaybe "https://api.github.com" <$> lookupEnv "GITHUB_API_URL"
appId <- getEnv "GITHUB_APP_ID"
installationId <- getEnv "GITHUB_INSTALLATION_ID"
privateKeyStr <- getEnv "GITHUB_APP_PRIVATE_KEY"
owner <- getEnv "GITHUB_REPOSITORY_OWNER"
repo <- getEnv "GITHUB_REPOSITORY"
manager <- HTTP.newManager tlsManagerSettings

pure $ GithubClient
{ apiUrl = T.pack apiUrl
, appId = T.pack appId
, installationId = T.pack installationId
, privateKey = T.pack privateKeyStr
, owner = T.pack owner
, repo = T.pack repo
, manager = manager
, accessToken = accessToken
, expiresAt = expiresAt
}

-- Create a new GitHub App installation token from GitHub API
createTokenFromGitHub :: AppState -> IO InstallationTokenResponse
createTokenFromGitHub appState = do
-- Load environment variables
apiUrl <- fromMaybe "https://api.github.com" <$> lookupEnv "GITHUB_API_URL"
appId <- getEnv "GITHUB_APP_ID"
installationId <- getEnv "GITHUB_INSTALLATION_ID"
privateKeyStr <- getEnv "GITHUB_APP_PRIVATE_KEY"

-- Prepare the HTTP manager
manager <- HTTP.newManager tlsManagerSettings

let createToken = do
let privateKeyBytes = encodeUtf8 $ Text.replace "|" "\n" $ toText privateKeyStr
let privateKey = fromMaybe (error "Invalid github key") $ readRsaSecret privateKeyBytes

-- Create the JWT token
now <- getPOSIXTime
let claims = mempty { iss = stringOrURI $ T.pack appId
, iat = numericDate now
, exp = numericDate (now + 5 * 60)
}
let jwt = encodeSigned (EncodeRSAPrivateKey privateKey) (mempty { alg = Just RS256 }) claims

-- Get the installation access token
let installUrl = apiUrl <> "/app/installations/" ++ installationId ++ "/access_tokens"
initRequest <- HTTP.parseRequest installUrl
let request = initRequest
{ HTTP.method = "POST"
, HTTP.requestHeaders =
[ ("Authorization", "Bearer " <> TE.encodeUtf8 jwt)
, ("Accept", "application/vnd.github.v3+json")
, ("User-Agent", "restaumatic-bot")
]
}
response <- HTTP.httpLbs request manager
let mTokenResponse = eitherDecode @InstallationTokenResponse (HTTP.responseBody response)
case mTokenResponse of
Left err -> do
logError appState $ "CommitStatus: Failed to parse installation token response: " <> show err
logError appState $ "CommitStatus: Response: " <> decodeUtf8 response.responseBody

-- FIXME: handle the error better
exitFailure
Right tokenResponse ->
pure tokenResponse.token

-- Try to read token from environment variable
-- Otherwise generate a new one, and set env for future uses (also in child processes)
accessToken <- lookupEnv "_taskrunner_github_access_token" >>= \case
Just token -> pure $ T.pack token
Nothing -> do
token <- createToken
setEnv "_taskrunner_github_access_token" $ T.unpack token
pure token

pure $ GithubClient { apiUrl = T.pack apiUrl
, appId = T.pack appId
, installationId = T.pack installationId
, privateKey = T.pack privateKeyStr
, owner = T.pack owner
, repo = T.pack repo
, manager = manager
, accessToken = accessToken
}
let privateKeyBytes = encodeUtf8 $ Text.replace "|" "\n" $ toText privateKeyStr
let privateKey = fromMaybe (error "Invalid github key") $ readRsaSecret privateKeyBytes

-- Create the JWT token
now <- getPOSIXTime
let claims = mempty { iss = stringOrURI $ T.pack appId
, iat = numericDate now
, exp = numericDate (now + 5 * 60)
}
let jwt = encodeSigned (EncodeRSAPrivateKey privateKey) (mempty { alg = Just RS256 }) claims

-- Get the installation access token
let installUrl = apiUrl <> "/app/installations/" ++ installationId ++ "/access_tokens"
initRequest <- HTTP.parseRequest installUrl
let request = initRequest
{ HTTP.method = "POST"
, HTTP.requestHeaders =
[ ("Authorization", "Bearer " <> TE.encodeUtf8 jwt)
, ("Accept", "application/vnd.github.v3+json")
, ("User-Agent", "restaumatic-bot")
]
}
response <- HTTP.httpLbs request manager
let mTokenResponse = eitherDecode @InstallationTokenResponse (HTTP.responseBody response)
case mTokenResponse of
Left err -> do
logError appState $ "CommitStatus: Failed to parse installation token response: " <> show err
logError appState $ "CommitStatus: Response: " <> decodeUtf8 response.responseBody
-- FIXME: handle the error better
exitFailure
Right tokenResponse ->
pure tokenResponse

updateCommitStatus :: MonadIO m => AppState -> StatusRequest -> m ()
updateCommitStatus appState statusRequest = liftIO do
Expand Down
3 changes: 3 additions & 0 deletions src/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Universum
import SnapshotCliArgs (SnapshotCliArgs)
import Data.Aeson (FromJSON, ToJSON)
import qualified Network.HTTP.Client as HTTP
import Data.Time.Clock (UTCTime)

data Settings = Settings
{ stateDirectory :: FilePath
Expand All @@ -20,6 +21,7 @@ data Settings = Settings
, mainBranch :: Maybe Text
, force :: Bool
, quietMode :: Bool
, githubTokenRefreshThresholdSeconds :: Int
} deriving (Show)

type JobName = String
Expand Down Expand Up @@ -66,4 +68,5 @@ data GithubClient = GithubClient
, repo :: Text
, manager :: HTTP.Manager
, accessToken :: Text
, expiresAt :: UTCTime
}
24 changes: 20 additions & 4 deletions test/FakeGithubApi.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecursiveDo #-}

module FakeGithubApi (Server, start, stop, clearOutput, getOutput) where
module FakeGithubApi (Server, start, stop, clearOutput, getOutput, setTokenLifetime) where

import Universum

Expand All @@ -11,6 +11,8 @@ import Network.HTTP.Types (status200, status201, status400, status404, methodPos
import Data.Aeson (encode, object, (.=), Value)
import qualified Data.Aeson as Aeson
import qualified Data.Map.Strict as Map
import Data.Time.Clock (getCurrentTime, addUTCTime)
import Data.Time.Format.ISO8601 (iso8601Show)

import Control.Concurrent (forkIO, ThreadId, killThread)

Expand All @@ -31,9 +33,17 @@ handleAccessTokenRequest :: Server -> Text -> Request -> (Response -> IO Respons
handleAccessTokenRequest server instId req respond =
if requestMethod req == methodPost
then do
-- Read token lifetime from server state
lifetimeSeconds <- readIORef server.tokenLifetimeSeconds
now <- getCurrentTime
let expiresAt = addUTCTime (fromIntegral lifetimeSeconds) now
addOutput server $ "Requested access token for installation " <> instId
respond $ responseLBS status200 [("Content-Type", "application/json")]
(encode $ object ["token" .= ("mock-access-token" :: Text), "installation_id" .= instId])
(encode $ object
[ "token" .= ("mock-access-token" :: Text)
, "expires_at" .= iso8601Show expiresAt
, "installation_id" .= instId
])
else respond $ responseLBS status400 [] "Bad Request"

handleCommitStatusRequest :: Server -> Text -> Text -> Text -> Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived
Expand Down Expand Up @@ -61,16 +71,18 @@ data Server = Server
{ tid :: ThreadId
, output :: IORef [Text]
, statuses :: IORef (Map Text [Value]) -- Map from commit SHA to list of status objects
, tokenLifetimeSeconds :: IORef Int
}

start :: Int -> IO Server
start port = do
started <- newEmptyMVar
output <- newIORef []
statuses <- newIORef Map.empty
tokenLifetimeSeconds <- newIORef 3600 -- Default: 1 hour
let settings = Warp.setPort port $ Warp.setBeforeMainLoop (putMVar started ()) Warp.defaultSettings
rec
let server = Server {tid, output, statuses}
let server = Server {tid, output, statuses, tokenLifetimeSeconds}
tid <- forkIO $ Warp.runSettings settings $ app server
takeMVar started
pure server
Expand All @@ -82,9 +94,10 @@ addOutput :: Server -> Text -> IO ()
addOutput (Server {output}) msg = modifyIORef output (msg :)

clearOutput :: Server -> IO ()
clearOutput (Server {output, statuses}) = do
clearOutput (Server {output, statuses, tokenLifetimeSeconds}) = do
writeIORef output []
writeIORef statuses Map.empty
writeIORef tokenLifetimeSeconds 3600 -- Reset to default

getOutput :: Server -> IO [Text]
getOutput (Server {output}) = reverse <$> readIORef output
Expand All @@ -100,3 +113,6 @@ getStatuses :: Server -> Text -> IO [Value]
getStatuses (Server {statuses}) commitSha = do
statusMap <- readIORef statuses
pure $ fromMaybe [] $ Map.lookup commitSha statusMap

setTokenLifetime :: Server -> Int -> IO ()
setTokenLifetime server seconds = writeIORef server.tokenLifetimeSeconds seconds
Loading