2424import functools
2525import json
2626import traceback
27- from typing import Dict , Generic , Protocol , Union , List , TypeVar , Any , Optional , Sequence
27+ from typing import Dict , Generic , Literal , Protocol , Union , List , TypeVar , Any , Optional , Sequence
2828from opentelemetry import context , propagate
2929
3030import betterproto
4848 BucketNotificationConfig ,
4949 BucketNotificationType ,
5050 NotificationResponseContext ,
51+ WebsocketResponseContext ,
52+ WebsocketEvent ,
53+ WebsocketWorker ,
5154)
5255import grpclib
5356import asyncio
@@ -101,6 +104,10 @@ def bucket_notification(self) -> Union[BucketNotificationContext, None]:
101104 """Return this context as a BucketNotificationContext if it is one, otherwise returns None."""
102105 return None
103106
107+ def websocket (self ) -> Union [WebsocketContext , None ]:
108+ """Return this context as a WebsocketContext if it is one, otherwise returns None."""
109+ return None
110+
104111
105112def _ctx_from_grpc_trigger_request (trigger_request : TriggerRequest , options : Optional [FaasClientOptions ] = None ):
106113 """Return a TriggerContext from a TriggerRequest."""
@@ -114,6 +121,8 @@ def _ctx_from_grpc_trigger_request(trigger_request: TriggerRequest, options: Opt
114121 return FileNotificationContext .from_grpc_trigger_request_and_options (trigger_request , options )
115122 else :
116123 return BucketNotificationContext .from_grpc_trigger_request (trigger_request )
124+ elif context_type == "websocket" :
125+ return WebsocketContext .from_grpc_trigger_request (trigger_request )
117126 else :
118127 print (f"Trigger with unknown context received, context type: { context_type } " )
119128 raise Exception (f"Unknown trigger context, type: { context_type } " )
@@ -154,6 +163,10 @@ def _grpc_response_from_ctx(ctx: TriggerContext) -> TriggerResponse:
154163 if bucket_context is not None :
155164 return TriggerResponse (notification = NotificationResponseContext (success = bucket_context .res .success ))
156165
166+ websocket_context = ctx .websocket ()
167+ if websocket_context is not None :
168+ return TriggerResponse (websocket = WebsocketResponseContext (success = websocket_context .res .success ))
169+
157170 raise Exception ("Unknown Trigger Context type, unable to return valid response" )
158171
159172
@@ -302,6 +315,54 @@ def from_grpc_trigger_request(trigger_request: TriggerRequest):
302315 )
303316
304317
318+ class WebsocketRequest (Request ):
319+ """Represents an incoming websocket event."""
320+
321+ def __init__ (
322+ self , connection_id : str , data : bytes , query : Dict [str , str | List [str ]], trace_context : Dict [str , str ]
323+ ):
324+ """Construct a new WebsocketRequest."""
325+ super ().__init__ (data , trace_context )
326+
327+ self .connection_id = connection_id
328+ self .query = query
329+
330+
331+ class WebsocketResponse (Response ):
332+ """Represents a response to a websocket event."""
333+
334+ def __init__ (self , success : bool = True ):
335+ """Construct a new WebsocketResponse."""
336+ self .success = success
337+
338+
339+ class WebsocketContext (TriggerContext ):
340+ """Represents the full request/response context for a Websocket based trigger."""
341+
342+ def __init__ (self , request : WebsocketRequest , response : Optional [WebsocketResponse ] = None ):
343+ """Construct a new WebsocketContext."""
344+ super ().__init__ ()
345+ self .req = request
346+ self .res = response if response else WebsocketResponse ()
347+
348+ def websocket (self ) -> WebsocketContext :
349+ """Return this WebsocketContext, used when determining the context type of a trigger."""
350+ return self
351+
352+ @staticmethod
353+ def from_grpc_trigger_request (trigger_request : TriggerRequest ) -> WebsocketContext :
354+ """Construct a new WebsocketContext from a Websocket trigger from the Nitric Membrane."""
355+ query : Record = {k : v .value for (k , v ) in trigger_request .websocket .query_params .items ()}
356+ return WebsocketContext (
357+ request = WebsocketRequest (
358+ data = trigger_request .data ,
359+ connection_id = trigger_request .websocket .connection_id ,
360+ query = query ,
361+ trace_context = trigger_request .trace_context .values ,
362+ )
363+ )
364+
365+
305366class BucketNotificationRequest (Request ):
306367 """Represents a translated Event, from a subscribed bucket notification, forwarded from the Nitric Membrane."""
307368
@@ -424,6 +485,26 @@ def _to_grpc_event_type(event_type: str) -> BucketNotificationType:
424485 raise ValueError (f"Event type { event_type } is unsupported" )
425486
426487
488+ class WebsocketWorkerOptions :
489+ """Options for websocket workers."""
490+
491+ def __init__ (self , socket_name : str , event_type : Literal ["connect" , "disconnect" , "message" ]):
492+ """Construct new websocket worker options."""
493+ self .socket_name = socket_name
494+ self .event_type = WebsocketWorkerOptions ._to_grpc_event_type (event_type )
495+
496+ @staticmethod
497+ def _to_grpc_event_type (event_type : Literal ["connect" , "disconnect" , "message" ]) -> WebsocketEvent :
498+ if event_type == "connect" :
499+ return WebsocketEvent .Connect
500+ elif event_type == "disconnect" :
501+ return WebsocketEvent .Disconnect
502+ elif event_type == "message" :
503+ return WebsocketEvent .Message
504+ else :
505+ raise ValueError (f"Event type { event_type } is unsupported" )
506+
507+
427508class FileNotificationWorkerOptions (BucketNotificationWorkerOptions ):
428509 """Options for bucket notification workers with file references."""
429510
@@ -499,13 +580,16 @@ class FaasWorkerOptions:
499580 SubscriptionWorkerOptions ,
500581 BucketNotificationWorkerOptions ,
501582 FileNotificationWorkerOptions ,
583+ WebsocketWorkerOptions ,
502584 FaasWorkerOptions ,
503585]
504586
505587# class Context(Protocol):
506588# ...
507589
508- C = TypeVar ("C" , TriggerContext , HttpContext , EventContext , FileNotificationContext , BucketNotificationContext )
590+ C = TypeVar (
591+ "C" , TriggerContext , HttpContext , EventContext , FileNotificationContext , BucketNotificationContext , WebsocketContext
592+ )
509593
510594
511595class Middleware (Protocol , Generic [C ]):
@@ -528,11 +612,13 @@ async def __call__(self, ctx: C) -> C | None:
528612EventMiddleware = Middleware [EventContext ]
529613BucketNotificationMiddleware = Middleware [BucketNotificationContext ]
530614FileNotificationMiddleware = Middleware [FileNotificationContext ]
615+ WebsocketMiddleware = Middleware [WebsocketContext ]
531616
532617HttpHandler = Handler [HttpContext ]
533618EventHandler = Handler [EventContext ]
534619BucketNotificationHandler = Handler [BucketNotificationContext ]
535620FileNotificationHandler = Handler [FileNotificationContext ]
621+ WebsocketHandler = Handler [WebsocketContext ]
536622
537623
538624def _convert_to_middleware (handler : Handler [C ] | Middleware [C ]) -> Middleware [C ]:
@@ -615,6 +701,7 @@ def __init__(self, opts: FaasClientOptions):
615701 self .__bucket_notification_handler : Optional [
616702 Union [BucketNotificationMiddleware , FileNotificationMiddleware ]
617703 ] = None
704+ self .__websocket_handler : Optional [WebsocketMiddleware ] = None
618705 self ._opts = opts
619706
620707 def http (self , * handlers : HttpMiddleware | HttpHandler ) -> FunctionServer :
@@ -646,9 +733,23 @@ def bucket_notification(
646733 self .__bucket_notification_handler = compose_middleware (* handlers )
647734 return self
648735
736+ def websocket (self , * handlers : WebsocketMiddleware ) -> FunctionServer :
737+ """
738+ Register one or more Websocket Trigger Handlers or Middleware.
739+
740+ When multiple handlers are provided, they will be called in order.
741+ """
742+ self .__websocket_handler = compose_middleware (* handlers )
743+ return self
744+
649745 async def start (self ):
650746 """Start the function server using the previously provided middleware."""
651- if not self ._http_handler and not self ._event_handler and not self .__bucket_notification_handler :
747+ if (
748+ not self ._http_handler
749+ and not self ._event_handler
750+ and not self .__bucket_notification_handler
751+ and not self .__websocket_handler
752+ ):
652753 raise Exception ("At least one handler function must be provided." )
653754
654755 await self ._run ()
@@ -665,6 +766,10 @@ def _event_handler(self):
665766 def _bucket_notification_handler (self ):
666767 return self .__bucket_notification_handler
667768
769+ @property
770+ def _websocket_handler (self ):
771+ return self ._websocket_handler
772+
668773 async def _run (self ):
669774 """Register a new FaaS worker with the Membrane, using the provided function as the handler."""
670775 channel = new_default_channel ()
@@ -697,6 +802,10 @@ async def _run(self):
697802 init_request = InitRequest (
698803 bucket_notification = BucketNotificationWorker (bucket = self ._opts .bucket_name , config = config )
699804 )
805+ elif isinstance (self ._opts , WebsocketWorkerOptions ):
806+ init_request = InitRequest (
807+ websocket = WebsocketWorker (socket = self ._opts .socket_name , event = self ._opts .event_type )
808+ )
700809
701810 # let the membrane server know we're ready to start
702811 await request_channel .send (ClientMessage (init_request = init_request ))
@@ -722,6 +831,8 @@ async def _run(self):
722831 func = self ._event_handler
723832 elif ctx .bucket_notification ():
724833 func = self ._bucket_notification_handler
834+ elif ctx .websocket ():
835+ func = self ._websocket_handler
725836
726837 assert func is not None
727838
0 commit comments