@@ -217,16 +217,33 @@ def compose_middleware(*middlewares: Union[Middleware, List[Middleware]]) -> Mid
217217 The resulting middleware will effectively be a chain of the provided middleware,
218218 where each calls the next in the chain when they're successful.
219219 """
220+ middlewares = list (middlewares )
220221 if len (middlewares ) == 1 and not isinstance (middlewares [0 ], list ):
221222 return middlewares [0 ]
222223
223224 middlewares = [compose_middleware (m ) if isinstance (m , list ) else m for m in middlewares ]
224225
225226 async def handler (ctx , next_middleware = lambda ctx : ctx ):
226- middleware_chain = functools .reduce (
227- lambda acc_next , cur : lambda context : cur (context , acc_next ), reversed (middlewares + (next_middleware ,))
228- )
229- return middleware_chain (ctx )
227+ def reduceChain (acc_next , cur ):
228+ async def chainedMiddleware (context ):
229+ # Count the positional arguments to determine if the function is a handler or middleware.
230+ all_args = cur .__code__ .co_argcount
231+ kwargs = len (cur .__defaults__ ) if cur .__defaults__ is not None else 0
232+ pos_args = all_args - kwargs
233+ if pos_args == 2 :
234+ # Call the middleware with next and return the result
235+ return (
236+ (await cur (context , acc_next )) if asyncio .iscoroutinefunction (cur ) else cur (context , acc_next )
237+ )
238+ else :
239+ # Call the handler with ctx only, then call the remainder of the middleware chain
240+ result = (await cur (context )) if asyncio .iscoroutinefunction (cur ) else cur (context )
241+ return (await acc_next (result )) if asyncio .iscoroutinefunction (acc_next ) else acc_next (result )
242+
243+ return chainedMiddleware
244+
245+ middleware_chain = functools .reduce (reduceChain , reversed (middlewares + [next_middleware ]))
246+ return await middleware_chain (ctx )
230247
231248 return handler
232249
0 commit comments