Skip to content

Commit a76d357

Browse files
committed
fix: allow for async and sync handlers/middleware
1 parent 069471d commit a76d357

2 files changed

Lines changed: 37 additions & 5 deletions

File tree

nitric/faas.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -217,16 +217,33 @@ def compose_middleware(*middlewares: Union[Middleware, List[Middleware]]) -> Mid
217217
The resulting middleware will effectively be a chain of the provided middleware,
218218
where each calls the next in the chain when they're successful.
219219
"""
220+
middlewares = list(middlewares)
220221
if len(middlewares) == 1 and not isinstance(middlewares[0], list):
221222
return middlewares[0]
222223

223224
middlewares = [compose_middleware(m) if isinstance(m, list) else m for m in middlewares]
224225

225226
async def handler(ctx, next_middleware=lambda ctx: ctx):
226-
middleware_chain = functools.reduce(
227-
lambda acc_next, cur: lambda context: cur(context, acc_next), reversed(middlewares + (next_middleware,))
228-
)
229-
return middleware_chain(ctx)
227+
def reduceChain(acc_next, cur):
228+
async def chainedMiddleware(context):
229+
# Count the positional arguments to determine if the function is a handler or middleware.
230+
all_args = cur.__code__.co_argcount
231+
kwargs = len(cur.__defaults__) if cur.__defaults__ is not None else 0
232+
pos_args = all_args - kwargs
233+
if pos_args == 2:
234+
# Call the middleware with next and return the result
235+
return (
236+
(await cur(context, acc_next)) if asyncio.iscoroutinefunction(cur) else cur(context, acc_next)
237+
)
238+
else:
239+
# Call the handler with ctx only, then call the remainder of the middleware chain
240+
result = (await cur(context)) if asyncio.iscoroutinefunction(cur) else cur(context)
241+
return (await acc_next(result)) if asyncio.iscoroutinefunction(acc_next) else acc_next(result)
242+
243+
return chainedMiddleware
244+
245+
middleware_chain = functools.reduce(reduceChain, reversed(middlewares + [next_middleware]))
246+
return await middleware_chain(ctx)
230247

231248
return handler
232249

tests/test_faas.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import pytest
2323

24-
from nitric.faas import start, FunctionServer, EventContext, HttpContext
24+
from nitric.faas import start, FunctionServer, HttpContext, compose_middleware, HttpResponse
2525

2626
from nitricapi.nitric.faas.v1 import (
2727
ServerMessage,
@@ -46,6 +46,21 @@ def __init__(self):
4646

4747

4848
class EventClientTest(IsolatedAsyncioTestCase):
49+
async def test_compose_middleware(self):
50+
async def middleware(ctx: HttpContext, next) -> HttpContext:
51+
ctx.res.status = 401
52+
return await next(ctx)
53+
54+
async def handler(ctx: HttpContext) -> HttpContext:
55+
ctx.res.body = "some text"
56+
return ctx
57+
58+
composed = compose_middleware(middleware, handler)
59+
60+
ctx = HttpContext(response=HttpResponse(), request=None)
61+
result = await composed(ctx)
62+
assert result.res.status == 401
63+
4964
def test_start_with_one_handler(self):
5065
mock_server_constructor = Mock()
5166
mock_server = Object()

0 commit comments

Comments
 (0)