diff --git a/py/noxfile.py b/py/noxfile.py index c92395da..7626d106 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -320,7 +320,7 @@ def test_cli(session): _install_test_deps(session) session.install(".[cli]") session.install("httpx") # Required for starlette.testclient - _run_tests(session, "braintrust/devserver/test_server_integration.py") + _run_tests(session, DEVSERVER_DIR) @nox.session() diff --git a/py/src/braintrust/devserver/cors.py b/py/src/braintrust/devserver/cors.py index 9f920fbf..f1b3e340 100644 --- a/py/src/braintrust/devserver/cors.py +++ b/py/src/braintrust/devserver/cors.py @@ -58,7 +58,7 @@ def check_origin(origin: str) -> bool: for allowed in ALLOWED_ORIGINS: if isinstance(allowed, str) and origin == allowed: return True - elif isinstance(allowed, re.Pattern) and allowed.match(origin): + elif isinstance(allowed, re.Pattern) and allowed.fullmatch(origin): return True return False diff --git a/py/src/braintrust/devserver/test_cors.py b/py/src/braintrust/devserver/test_cors.py new file mode 100644 index 00000000..68e12ba5 --- /dev/null +++ b/py/src/braintrust/devserver/test_cors.py @@ -0,0 +1,45 @@ +import asyncio +import unittest + +from braintrust.devserver.cors import check_origin, create_cors_middleware + + +class TestCorsOriginValidation(unittest.TestCase): + def test_check_origin_allows_legitimate_preview_origin(self): + self.assertTrue(check_origin("https://legit.preview.braintrust.dev")) + + def test_check_origin_rejects_preview_suffix_bypass(self): + self.assertFalse(check_origin("https://evil.preview.braintrust.dev.attacker.com")) + + def test_options_response_does_not_reflect_disallowed_origin(self): + async def app(scope, receive, send): + raise AssertionError("OPTIONS requests should be handled by the CORS middleware") + + middleware = create_cors_middleware()(app) + messages = [] + + async def receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message): + messages.append(message) + + scope = { + "type": "http", + "method": "OPTIONS", + "headers": [ + (b"origin", b"https://evil.preview.braintrust.dev.attacker.com"), + (b"access-control-request-method", b"POST"), + ], + } + + asyncio.run(middleware(scope, receive, send)) + + response_start = next(message for message in messages if message["type"] == "http.response.start") + headers = dict(response_start["headers"]) + self.assertNotIn(b"access-control-allow-origin", headers) + self.assertNotIn(b"access-control-allow-credentials", headers) + + +if __name__ == "__main__": + unittest.main()