Skip to content

Commit 3b3e8b3

Browse files
authored
Merge pull request #52 from nitrictech/fix/middleware-composition
Fix/middleware composition
2 parents 069471d + f46dd9d commit 3b3e8b3

2 files changed

Lines changed: 49 additions & 17 deletions

File tree

nitric/faas.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,16 +217,33 @@ def compose_middleware(*middlewares: Union[Middleware, List[Middleware]]) -> Mid
217217
The resulting middleware will effectively be a chain of the provided middleware,
218218
where each calls the next in the chain when they're successful.
219219
"""
220+
middlewares = list(middlewares)
220221
if len(middlewares) == 1 and not isinstance(middlewares[0], list):
221222
return middlewares[0]
222223

223224
middlewares = [compose_middleware(m) if isinstance(m, list) else m for m in middlewares]
224225

225226
async def handler(ctx, next_middleware=lambda ctx: ctx):
226-
middleware_chain = functools.reduce(
227-
lambda acc_next, cur: lambda context: cur(context, acc_next), reversed(middlewares + (next_middleware,))
228-
)
229-
return middleware_chain(ctx)
227+
def reduce_chain(acc_next, cur):
228+
async def chained_middleware(context):
229+
# Count the positional arguments to determine if the function is a handler or middleware.
230+
all_args = cur.__code__.co_argcount
231+
kwargs = len(cur.__defaults__) if cur.__defaults__ is not None else 0
232+
pos_args = all_args - kwargs
233+
if pos_args == 2:
234+
# Call the middleware with next and return the result
235+
return (
236+
(await cur(context, acc_next)) if asyncio.iscoroutinefunction(cur) else cur(context, acc_next)
237+
)
238+
else:
239+
# Call the handler with ctx only, then call the remainder of the middleware chain
240+
result = (await cur(context)) if asyncio.iscoroutinefunction(cur) else cur(context)
241+
return (await acc_next(result)) if asyncio.iscoroutinefunction(acc_next) else acc_next(result)
242+
243+
return chained_middleware
244+
245+
middleware_chain = functools.reduce(reduce_chain, reversed(middlewares + [next_middleware]))
246+
return await middleware_chain(ctx)
230247

231248
return handler
232249

@@ -279,7 +296,7 @@ def start(self, *handlers: Union[Middleware, List[Middleware]]):
279296
if not self._any_handler and not self._http_handler and not self._event_handler:
280297
raise Exception("At least one handler function must be provided.")
281298

282-
asyncio.run(self.run())
299+
asyncio.run(self._run())
283300

284301
@property
285302
def _http_handler(self):
@@ -289,7 +306,7 @@ def _http_handler(self):
289306
def _event_handler(self):
290307
return self.__event_handler if self.__event_handler else self._any_handler
291308

292-
async def run(self):
309+
async def _run(self):
293310
"""Register a new FaaS worker with the Membrane, using the provided function as the handler."""
294311
channel = new_default_channel()
295312
client = FaasServiceStub(channel)

tests/test_faas.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import pytest
2323

24-
from nitric.faas import start, FunctionServer, EventContext, HttpContext
24+
from nitric.faas import start, FunctionServer, HttpContext, compose_middleware, HttpResponse
2525

2626
from nitricapi.nitric.faas.v1 import (
2727
ServerMessage,
@@ -46,6 +46,21 @@ def __init__(self):
4646

4747

4848
class EventClientTest(IsolatedAsyncioTestCase):
49+
async def test_compose_middleware(self):
50+
async def middleware(ctx: HttpContext, next) -> HttpContext:
51+
ctx.res.status = 401
52+
return await next(ctx)
53+
54+
async def handler(ctx: HttpContext) -> HttpContext:
55+
ctx.res.body = "some text"
56+
return ctx
57+
58+
composed = compose_middleware(middleware, handler)
59+
60+
ctx = HttpContext(response=HttpResponse(), request=None)
61+
result = await composed(ctx)
62+
assert result.res.status == 401
63+
4964
def test_start_with_one_handler(self):
5065
mock_server_constructor = Mock()
5166
mock_server = Object()
@@ -94,7 +109,7 @@ def test_start_starts_event_loop(self):
94109
mock_run.return_value = mock_run_coroutine
95110

96111
with patch("nitric.faas.compose_middleware", mock_compose):
97-
with patch("nitric.faas.FunctionServer.run", mock_run):
112+
with patch("nitric.faas.FunctionServer._run", mock_run):
98113
with patch("asyncio.run", mock_asyncio_run):
99114
FunctionServer().start(mock_handler)
100115

@@ -124,7 +139,7 @@ async def mock_stream(self, request_iterator):
124139
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
125140
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
126141
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
127-
await FunctionServer().http(mock_handler).run()
142+
await FunctionServer().http(mock_handler)._run()
128143

129144
# gRPC channel created
130145
mock_grpc_channel.assert_called_once()
@@ -165,7 +180,7 @@ async def mock_stream(self, request_iterator):
165180
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
166181
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
167182
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
168-
await FunctionServer().http(mock_http_handler).event(mock_event_handler).run()
183+
await FunctionServer().http(mock_http_handler).event(mock_event_handler)._run()
169184

170185
# accept the init response from server
171186
assert 1 == stream_calls
@@ -200,7 +215,7 @@ async def mock_stream(self, request_iterator):
200215
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
201216
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
202217
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
203-
await FunctionServer().http(mock_http_handler).event(mock_event_handler).run()
218+
await FunctionServer().http(mock_http_handler).event(mock_event_handler)._run()
204219

205220
# accept the init response from server
206221
assert 1 == stream_calls
@@ -235,7 +250,7 @@ async def mock_stream(self, request_iterator):
235250
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
236251
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
237252
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
238-
await FunctionServer().http(mock_http_handler).event(mock_event_handler).run()
253+
await FunctionServer().http(mock_http_handler).event(mock_event_handler)._run()
239254

240255
# accept the init response from server
241256
assert 1 == stream_calls
@@ -270,7 +285,7 @@ async def mock_stream(self, request_iterator):
270285
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
271286
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
272287
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
273-
await FunctionServer().http(mock_http_handler).event(mock_event_handler).run()
288+
await FunctionServer().http(mock_http_handler).event(mock_event_handler)._run()
274289

275290
# accept the init response from server
276291
assert 1 == stream_calls
@@ -305,7 +320,7 @@ async def mock_stream(self, request_iterator):
305320
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
306321
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
307322
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
308-
await FunctionServer().http(mock_http_handler).event(mock_event_handler).run()
323+
await FunctionServer().http(mock_http_handler).event(mock_event_handler)._run()
309324

310325
# accept the init response from server
311326
assert 1 == stream_calls
@@ -334,7 +349,7 @@ async def mock_stream(self, request_iterator):
334349
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
335350
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
336351
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
337-
await FunctionServer().event(mock_handler).run()
352+
await FunctionServer().event(mock_handler)._run()
338353

339354
# accept the trigger response from server
340355
assert 1 == stream_calls
@@ -373,7 +388,7 @@ async def mock_stream(self, request_iterator):
373388
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
374389
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
375390
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
376-
await FunctionServer().http(mock_handler).run()
391+
await FunctionServer().http(mock_handler)._run()
377392

378393
# accept the trigger response from server
379394
assert 1 == stream_calls
@@ -414,7 +429,7 @@ async def mock_stream(self, request_iterator):
414429
with patch("nitric.faas.AsyncChannel", mock_async_channel_init), patch(
415430
"nitricapi.nitric.faas.v1.FaasServiceStub.trigger_stream", mock_stream
416431
), patch("nitric.faas.new_default_channel", mock_grpc_channel):
417-
await FunctionServer().http(mock_handler).run()
432+
await FunctionServer().http(mock_handler)._run()
418433

419434
# accept the trigger response from server
420435
assert 1 == stream_calls

0 commit comments

Comments
 (0)