diff --git a/src/main/oauth2/oauth-state.ts b/src/main/oauth2/oauth-state.ts new file mode 100644 index 000000000..ea46bbc61 --- /dev/null +++ b/src/main/oauth2/oauth-state.ts @@ -0,0 +1,23 @@ +import { randomBytes } from "crypto"; + +export const OAUTH_STATE_SESSION_KEY = "oauthState"; + +export const generateOAuthState = () => randomBytes(32).toString("hex"); + +export const setOAuthState = (req, state = generateOAuthState()) => { + if (!req.session) { + req.session = {}; + } + + req.session[OAUTH_STATE_SESSION_KEY] = state; + + return state; +}; + +export const getOAuthState = (req) => req.session && req.session[OAUTH_STATE_SESSION_KEY]; + +export const clearOAuthState = (req) => { + if (req.session) { + delete req.session[OAUTH_STATE_SESSION_KEY]; + } +}; diff --git a/src/main/routes/oauth2redirect.ts b/src/main/routes/oauth2redirect.ts index 7de8888a4..0e601359d 100644 --- a/src/main/routes/oauth2redirect.ts +++ b/src/main/routes/oauth2redirect.ts @@ -1,11 +1,25 @@ import * as express from "express"; import * as config from "config"; import { accessTokenRequest } from "../oauth2/access-token-request"; +import { clearOAuthState, getOAuthState } from "../oauth2/oauth-state"; export const COOKIE_ACCESS_TOKEN = "accessToken"; const router = express.Router(); export const oauth2redirect = (req, res, next) => { + const oauthState = getOAuthState(req); + + if (!req.query.state || req.query.state !== oauthState) { + clearOAuthState(req); + + const error: any = new Error("Invalid OAuth2 state parameter"); + error.status = 400; + + return next(error); + } + + clearOAuthState(req); + if (req.query.code) { // On successfully obtaining a token, the redirect should go back to ourselves. // Note: This *must not* include any query string. diff --git a/src/main/user/auth-checker-user-only-filter.ts b/src/main/user/auth-checker-user-only-filter.ts index bea0cd7bd..f61d56157 100644 --- a/src/main/user/auth-checker-user-only-filter.ts +++ b/src/main/user/auth-checker-user-only-filter.ts @@ -1,6 +1,7 @@ import { authorize } from "./user-request-authorizer"; import { get } from "config"; import { Logger } from "@hmcts/nodejs-logging"; +import { setOAuthState } from "../oauth2/oauth-state"; export const authCheckerUserOnlyFilter = (req, res, next) => { @@ -19,8 +20,9 @@ export const authCheckerUserOnlyFilter = (req, res, next) => { if (error.status === 403) { next(error); } else { + const state = setOAuthState(req); res.redirect(302, `${get("adminWeb.login_url")}?response_type=code&client_id=` + - `${get("idam.oauth2.client_id")}&redirect_uri=${REDIRECT_URI}`); + `${get("idam.oauth2.client_id")}&redirect_uri=${REDIRECT_URI}&state=${encodeURIComponent(state)}`); } }); }; diff --git a/src/test/routes/oauth2redirect.spec.ts b/src/test/routes/oauth2redirect.spec.ts index 41609d9bb..60ef2fae6 100644 --- a/src/test/routes/oauth2redirect.spec.ts +++ b/src/test/routes/oauth2redirect.spec.ts @@ -15,13 +15,28 @@ chai.use(sinonChai); describe("oauth2redirect", () => { const token = "ey123.ey456"; + const initialiseOAuthSession = () => { + const agent = request.agent(app); + + return agent + .get("/") + .then((res) => { + const state = new URL(res.headers.location).searchParams.get("state"); + + expect(state).to.be.a("string"); + + return { agent, state }; + }); + }; describe("when OAuth2 code is present", () => { it("should set an accessToken cookie and redirect to /", () => { - idamServiceMock.resolveExchangeCode(token); + return initialiseOAuthSession() + .then(({ agent, state }) => { + idamServiceMock.resolveExchangeCode(token); - return request(app) - .get("/oauth2redirect?code=abc123") + return agent.get(`/oauth2redirect?code=abc123&state=${state}`); + }) .then((res) => { const cookies = res.get("Set-Cookie").map((_) => cookie.parse(_)); expect(cookies.some((c) => c[`${COOKIE_ACCESS_TOKEN}`] === token)).to.be.true; @@ -32,10 +47,8 @@ describe("oauth2redirect", () => { describe("when OAuth2 code is not present", () => { it("should not set an accessToken cookie", () => { - idamServiceMock.resolveExchangeCode(token); - - return request(app) - .get("/oauth2redirect") + return initialiseOAuthSession() + .then(({ agent, state }) => agent.get(`/oauth2redirect?state=${state}`)) .then((res) => { const cookies = res.get("Set-Cookie").map((_) => cookie.parse(_)); expect(cookies.some((c) => c[`${COOKIE_ACCESS_TOKEN}`] === token)).to.be.false; @@ -45,6 +58,18 @@ describe("oauth2redirect", () => { }); }); + describe("when OAuth2 state is invalid", () => { + it("should reject the request before exchanging the code", () => { + return initialiseOAuthSession() + .then(({ agent }) => agent.get("/oauth2redirect?code=abc123&state=invalid-state")) + .then((res) => { + const cookies = res.get("Set-Cookie").map((_) => cookie.parse(_)); + expect(cookies.some((c) => c[`${COOKIE_ACCESS_TOKEN}`] === token)).to.be.false; + expect(res.status).to.equal(400); + }); + }); + }); + describe("OAuth2 redirect with secure flag", () => { const TOKEN = { @@ -57,6 +82,7 @@ describe("oauth2redirect", () => { let next; let config; let accessTokenRequest; + let oauthState; let oauth2redirect; beforeEach(() => { @@ -65,14 +91,19 @@ describe("oauth2redirect", () => { }; req = sinonExpressMock.mockReq(); - req.query = {code: "code", redirect_uri: "https://localhost:5000"}; + req.query = {code: "code", redirect_uri: "https://localhost:5000", state: "valid-state"}; res = sinonExpressMock.mockRes(); next = sinon.stub(); accessTokenRequest = sinon.stub(); accessTokenRequest.withArgs(req).returns(Promise.resolve(TOKEN)); + oauthState = { + clearOAuthState: sinon.stub(), + getOAuthState: sinon.stub().withArgs(req).returns("valid-state"), + }; oauth2redirect = proxyquire("../../main/routes/oauth2redirect", { "../oauth2/access-token-request": accessTokenRequest, + "../oauth2/oauth-state": oauthState, "config": config, }).oauth2redirect; }); @@ -95,6 +126,20 @@ describe("oauth2redirect", () => { oauth2redirect(req, res, next); }); + it("should reject requests with a mismatched state", () => { + req.query.state = "unexpected-state"; + + oauth2redirect(req, res, next); + + expect(oauthState.clearOAuthState).to.have.been.calledWith(req); + expect(accessTokenRequest).not.to.have.been.called; + expect(next).to.have.been.called; + + const error = next.firstCall.args[0]; + expect(error.message).to.equal("Invalid OAuth2 state parameter"); + expect(error.status).to.equal(400); + }); + }); }); diff --git a/src/test/user/auth-checker-user-only-filter.spec.ts b/src/test/user/auth-checker-user-only-filter.spec.ts index 0eca941a5..6d33ecb96 100644 --- a/src/test/user/auth-checker-user-only-filter.spec.ts +++ b/src/test/user/auth-checker-user-only-filter.spec.ts @@ -15,10 +15,13 @@ describe("authCheckerUserOnlyFilter", () => { const loginUrl = "http://idam.login"; const clientId = "ccd_admin"; const redirectUri = encodeURIComponent("http://localhost/oauth2redirect"); - const completeUrl = `${loginUrl}?response_type=code&client_id=${clientId}&redirect_uri=${redirectUri}`; + const state = "generated-state"; + const completeUrl = `${loginUrl}?response_type=code&client_id=${clientId}` + + `&redirect_uri=${redirectUri}&state=${encodeURIComponent(state)}`; let req; let res; + let oauthState; let userRequestAuthorizer; let filter; @@ -26,6 +29,7 @@ describe("authCheckerUserOnlyFilter", () => { req = { get: sinon.stub(), protocol: "http", + session: {}, }; req.get.withArgs("host").returns("localhost"); res = {}; @@ -34,6 +38,14 @@ describe("authCheckerUserOnlyFilter", () => { authorize: sinon.stub(), }; + oauthState = { + setOAuthState: sinon.stub().callsFake((request) => { + request.session.oauthState = state; + + return state; + }), + }; + const config = { get: sinon.stub(), }; @@ -41,6 +53,7 @@ describe("authCheckerUserOnlyFilter", () => { config.get.withArgs("idam.oauth2.client_id").returns(clientId); filter = proxyquire("../../main/user/auth-checker-user-only-filter", { + "../oauth2/oauth-state": oauthState, "./user-request-authorizer": userRequestAuthorizer, config, }).authCheckerUserOnlyFilter; @@ -88,20 +101,20 @@ describe("authCheckerUserOnlyFilter", () => { it("should redirect to the IdAM login URL", (done) => { res = { redirect: (code, url) => { - assert.equal(code, 302); - assert.equal(url, completeUrl); - done(); + try { + assert.equal(code, 302); + assert.equal(url, completeUrl); + expect(req.session.oauthState).to.equal(state); + expect(oauthState.setOAuthState).to.have.been.calledWith(req); + done(); + } catch (e) { + done(e); + } }, }; filter(req, res, (err) => { - try { - expect(err).to.equal(error); - expect(res.redirect).to.be.calledWith(302, completeUrl); - done(); - } catch (e) { - done(e); - } + expect(err).to.equal(error); }); }); });