Skip to content

Commit 0e048e5

Browse files
authored
feat: Websocket support (#120)
Adds websocket support to the python SDK. Example: ```python from nitric.resources import websocket my_socket = websocket("my-socket") @my_socket.on("connect") async def handle_connect(ctx): print("handling connection") ```
2 parents 2d27d67 + 1fb6363 commit 0e048e5

15 files changed

Lines changed: 600 additions & 4 deletions

File tree

makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ test:
2121
@echo Running Tox tests
2222
@tox -e py
2323

24-
NITRIC_VERSION="v0.27.0"
24+
NITRIC_VERSION="v0.32.0"
2525

2626
download:
2727
@curl -L https://github.com/nitrictech/nitric/releases/download/${NITRIC_VERSION}/contracts.tgz -o contracts.tgz

nitric/api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from nitric.api.storage import Storage
2323
from nitric.api.documents import Documents
2424
from nitric.api.secrets import Secrets
25+
from nitric.api.websocket import Websocket
2526

2627
__all__ = [
2728
"Events",
@@ -33,4 +34,5 @@
3334
"FailedTask",
3435
"TopicRef",
3536
"Secrets",
37+
"Websocket",
3638
]

nitric/api/websocket.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Union
2+
from grpclib.client import Channel
3+
from grpclib import GRPCError
4+
from nitric.exception import exception_from_grpc_error
5+
from nitric.utils import new_default_channel
6+
from nitric.proto.nitric.websocket.v1 import (
7+
WebsocketServiceStub,
8+
WebsocketSendRequest,
9+
)
10+
11+
12+
class Websocket(object):
13+
"""Nitric generic Websocket client."""
14+
15+
def __init__(self):
16+
"""Construct a Nitric Websocket Client."""
17+
self._channel: Union[Channel, None] = new_default_channel()
18+
# Had to make unprotected (publically accessible in order to use as part of bucket reference)
19+
self.websocket_stub = WebsocketServiceStub(channel=self._channel)
20+
21+
async def send(self, socket: str, connection_id: str, data: bytes):
22+
"""Send data to a connection on a socket."""
23+
try:
24+
await self.websocket_stub.send(
25+
websocket_send_request=WebsocketSendRequest(socket=socket, connection_id=connection_id, data=data)
26+
)
27+
except GRPCError as grpc_err:
28+
raise exception_from_grpc_error(grpc_err)

nitric/application.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class Nitric:
4545
"secret": {},
4646
"queue": {},
4747
"collection": {},
48+
"websocket": {},
4849
}
4950

5051
@classmethod

nitric/faas.py

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import functools
2525
import json
2626
import traceback
27-
from typing import Dict, Generic, Protocol, Union, List, TypeVar, Any, Optional, Sequence
27+
from typing import Dict, Generic, Literal, Protocol, Union, List, TypeVar, Any, Optional, Sequence
2828
from opentelemetry import context, propagate
2929

3030
import betterproto
@@ -48,6 +48,9 @@
4848
BucketNotificationConfig,
4949
BucketNotificationType,
5050
NotificationResponseContext,
51+
WebsocketResponseContext,
52+
WebsocketEvent,
53+
WebsocketWorker,
5154
)
5255
import grpclib
5356
import asyncio
@@ -101,6 +104,10 @@ def bucket_notification(self) -> Union[BucketNotificationContext, None]:
101104
"""Return this context as a BucketNotificationContext if it is one, otherwise returns None."""
102105
return None
103106

107+
def websocket(self) -> Union[WebsocketContext, None]:
108+
"""Return this context as a WebsocketContext if it is one, otherwise returns None."""
109+
return None
110+
104111

105112
def _ctx_from_grpc_trigger_request(trigger_request: TriggerRequest, options: Optional[FaasClientOptions] = None):
106113
"""Return a TriggerContext from a TriggerRequest."""
@@ -114,6 +121,8 @@ def _ctx_from_grpc_trigger_request(trigger_request: TriggerRequest, options: Opt
114121
return FileNotificationContext.from_grpc_trigger_request_and_options(trigger_request, options)
115122
else:
116123
return BucketNotificationContext.from_grpc_trigger_request(trigger_request)
124+
elif context_type == "websocket":
125+
return WebsocketContext.from_grpc_trigger_request(trigger_request)
117126
else:
118127
print(f"Trigger with unknown context received, context type: {context_type}")
119128
raise Exception(f"Unknown trigger context, type: {context_type}")
@@ -154,6 +163,10 @@ def _grpc_response_from_ctx(ctx: TriggerContext) -> TriggerResponse:
154163
if bucket_context is not None:
155164
return TriggerResponse(notification=NotificationResponseContext(success=bucket_context.res.success))
156165

166+
websocket_context = ctx.websocket()
167+
if websocket_context is not None:
168+
return TriggerResponse(websocket=WebsocketResponseContext(success=websocket_context.res.success))
169+
157170
raise Exception("Unknown Trigger Context type, unable to return valid response")
158171

159172

@@ -302,6 +315,54 @@ def from_grpc_trigger_request(trigger_request: TriggerRequest):
302315
)
303316

304317

318+
class WebsocketRequest(Request):
319+
"""Represents an incoming websocket event."""
320+
321+
def __init__(
322+
self, connection_id: str, data: bytes, query: Dict[str, str | List[str]], trace_context: Dict[str, str]
323+
):
324+
"""Construct a new WebsocketRequest."""
325+
super().__init__(data, trace_context)
326+
327+
self.connection_id = connection_id
328+
self.query = query
329+
330+
331+
class WebsocketResponse(Response):
332+
"""Represents a response to a websocket event."""
333+
334+
def __init__(self, success: bool = True):
335+
"""Construct a new WebsocketResponse."""
336+
self.success = success
337+
338+
339+
class WebsocketContext(TriggerContext):
340+
"""Represents the full request/response context for a Websocket based trigger."""
341+
342+
def __init__(self, request: WebsocketRequest, response: Optional[WebsocketResponse] = None):
343+
"""Construct a new WebsocketContext."""
344+
super().__init__()
345+
self.req = request
346+
self.res = response if response else WebsocketResponse()
347+
348+
def websocket(self) -> WebsocketContext:
349+
"""Return this WebsocketContext, used when determining the context type of a trigger."""
350+
return self
351+
352+
@staticmethod
353+
def from_grpc_trigger_request(trigger_request: TriggerRequest) -> WebsocketContext:
354+
"""Construct a new WebsocketContext from a Websocket trigger from the Nitric Membrane."""
355+
query: Record = {k: v.value for (k, v) in trigger_request.websocket.query_params.items()}
356+
return WebsocketContext(
357+
request=WebsocketRequest(
358+
data=trigger_request.data,
359+
connection_id=trigger_request.websocket.connection_id,
360+
query=query,
361+
trace_context=trigger_request.trace_context.values,
362+
)
363+
)
364+
365+
305366
class BucketNotificationRequest(Request):
306367
"""Represents a translated Event, from a subscribed bucket notification, forwarded from the Nitric Membrane."""
307368

@@ -424,6 +485,26 @@ def _to_grpc_event_type(event_type: str) -> BucketNotificationType:
424485
raise ValueError(f"Event type {event_type} is unsupported")
425486

426487

488+
class WebsocketWorkerOptions:
489+
"""Options for websocket workers."""
490+
491+
def __init__(self, socket_name: str, event_type: Literal["connect", "disconnect", "message"]):
492+
"""Construct new websocket worker options."""
493+
self.socket_name = socket_name
494+
self.event_type = WebsocketWorkerOptions._to_grpc_event_type(event_type)
495+
496+
@staticmethod
497+
def _to_grpc_event_type(event_type: Literal["connect", "disconnect", "message"]) -> WebsocketEvent:
498+
if event_type == "connect":
499+
return WebsocketEvent.Connect
500+
elif event_type == "disconnect":
501+
return WebsocketEvent.Disconnect
502+
elif event_type == "message":
503+
return WebsocketEvent.Message
504+
else:
505+
raise ValueError(f"Event type {event_type} is unsupported")
506+
507+
427508
class FileNotificationWorkerOptions(BucketNotificationWorkerOptions):
428509
"""Options for bucket notification workers with file references."""
429510

@@ -499,13 +580,16 @@ class FaasWorkerOptions:
499580
SubscriptionWorkerOptions,
500581
BucketNotificationWorkerOptions,
501582
FileNotificationWorkerOptions,
583+
WebsocketWorkerOptions,
502584
FaasWorkerOptions,
503585
]
504586

505587
# class Context(Protocol):
506588
# ...
507589

508-
C = TypeVar("C", TriggerContext, HttpContext, EventContext, FileNotificationContext, BucketNotificationContext)
590+
C = TypeVar(
591+
"C", TriggerContext, HttpContext, EventContext, FileNotificationContext, BucketNotificationContext, WebsocketContext
592+
)
509593

510594

511595
class Middleware(Protocol, Generic[C]):
@@ -528,11 +612,13 @@ async def __call__(self, ctx: C) -> C | None:
528612
EventMiddleware = Middleware[EventContext]
529613
BucketNotificationMiddleware = Middleware[BucketNotificationContext]
530614
FileNotificationMiddleware = Middleware[FileNotificationContext]
615+
WebsocketMiddleware = Middleware[WebsocketContext]
531616

532617
HttpHandler = Handler[HttpContext]
533618
EventHandler = Handler[EventContext]
534619
BucketNotificationHandler = Handler[BucketNotificationContext]
535620
FileNotificationHandler = Handler[FileNotificationContext]
621+
WebsocketHandler = Handler[WebsocketContext]
536622

537623

538624
def _convert_to_middleware(handler: Handler[C] | Middleware[C]) -> Middleware[C]:
@@ -615,6 +701,7 @@ def __init__(self, opts: FaasClientOptions):
615701
self.__bucket_notification_handler: Optional[
616702
Union[BucketNotificationMiddleware, FileNotificationMiddleware]
617703
] = None
704+
self.__websocket_handler: Optional[WebsocketMiddleware] = None
618705
self._opts = opts
619706

620707
def http(self, *handlers: HttpMiddleware | HttpHandler) -> FunctionServer:
@@ -646,9 +733,23 @@ def bucket_notification(
646733
self.__bucket_notification_handler = compose_middleware(*handlers)
647734
return self
648735

736+
def websocket(self, *handlers: WebsocketMiddleware) -> FunctionServer:
737+
"""
738+
Register one or more Websocket Trigger Handlers or Middleware.
739+
740+
When multiple handlers are provided, they will be called in order.
741+
"""
742+
self.__websocket_handler = compose_middleware(*handlers)
743+
return self
744+
649745
async def start(self):
650746
"""Start the function server using the previously provided middleware."""
651-
if not self._http_handler and not self._event_handler and not self.__bucket_notification_handler:
747+
if (
748+
not self._http_handler
749+
and not self._event_handler
750+
and not self.__bucket_notification_handler
751+
and not self.__websocket_handler
752+
):
652753
raise Exception("At least one handler function must be provided.")
653754

654755
await self._run()
@@ -665,6 +766,10 @@ def _event_handler(self):
665766
def _bucket_notification_handler(self):
666767
return self.__bucket_notification_handler
667768

769+
@property
770+
def _websocket_handler(self):
771+
return self._websocket_handler
772+
668773
async def _run(self):
669774
"""Register a new FaaS worker with the Membrane, using the provided function as the handler."""
670775
channel = new_default_channel()
@@ -697,6 +802,10 @@ async def _run(self):
697802
init_request = InitRequest(
698803
bucket_notification=BucketNotificationWorker(bucket=self._opts.bucket_name, config=config)
699804
)
805+
elif isinstance(self._opts, WebsocketWorkerOptions):
806+
init_request = InitRequest(
807+
websocket=WebsocketWorker(socket=self._opts.socket_name, event=self._opts.event_type)
808+
)
700809

701810
# let the membrane server know we're ready to start
702811
await request_channel.send(ClientMessage(init_request=init_request))
@@ -722,6 +831,8 @@ async def _run(self):
722831
func = self._event_handler
723832
elif ctx.bucket_notification():
724833
func = self._bucket_notification_handler
834+
elif ctx.websocket():
835+
func = self._websocket_handler
725836

726837
assert func is not None
727838

nitric/proto/nitric/deploy/v1/__init__.py

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)