11import abc
2+ import uuid
23import typing
34import fastapi
45import sqlmodel
6+ import sqlalchemy
57import importlib
68import os
79import tomllib
@@ -95,9 +97,11 @@ class ExtensionManager:
9597
9698 @classmethod
9799 def start_enabled (cls , app : fastapi .FastAPI ):
98- """Start all enabled entensions."""
100+ """Start all extensions enabled for the current client."""
101+ from app .business .client import ClientManager
102+
99103 cls .FASTAPI_APP = app
100- for extension in cls .get_installed ():
104+ for extension in cls .get_installed (enabled_only = True ):
101105 cls .start (extension = extension )
102106
103107 @classmethod
@@ -305,7 +309,7 @@ def download(cls, extid: str, version: Opt[str] = None) -> ExtensionModel:
305309 version = version or "0.0.0" ,
306310 nickname = nickname ,
307311 config = {},
308- disabled = True ,
312+ enabled = [] ,
309313 )
310314
311315 @classmethod
@@ -329,8 +333,15 @@ def install(cls, extid: ExtensionID, version: Opt[str] = None) -> ExtensionModel
329333 return extension
330334
331335 @classmethod
332- async def set_disabled (cls , extid : ExtensionID , disabled : bool ) -> ExtensionModel :
333- """Enable or disable an extension."""
336+ async def enable (cls , extid : ExtensionID ) -> ExtensionModel :
337+ """Enable an extension for the current client.
338+
339+ Adds the current client ID to the extension's enabled list and starts it.
340+ """
341+ from app .business .client import ClientManager
342+
343+ client_id = ClientManager .get_current_client_id ()
344+
334345 with SessionLocal () as db :
335346 extension = db .exec (
336347 sqlmodel .select (ExtensionModel ).where (ExtensionModel .id == extid )
@@ -339,32 +350,76 @@ async def set_disabled(cls, extid: ExtensionID, disabled: bool) -> ExtensionMode
339350 if not extension :
340351 raise ValueError (f"Extension with id { extid } not found." )
341352
342- extension .disabled = disabled
343- db .add (extension )
344- db .commit ()
345- db .refresh (extension )
353+ # Add client to enabled list if not already present
354+ current_enabled = set (extension .enabled )
355+ if client_id not in current_enabled :
356+ current_enabled .add (client_id )
357+ extension .enabled = list (current_enabled )
358+ db .add (extension )
359+ db .commit ()
360+ db .refresh (extension )
361+
362+ # Start extension if not already running
363+ if extid not in cls .RUNNING_EXTENSIONS :
364+ cls .start (extid )
365+
366+ return extension
367+
368+ @classmethod
369+ async def disable (cls , extid : ExtensionID ) -> ExtensionModel :
370+ """Disable an extension for the current client.
371+
372+ Removes the current client ID from the extension's enabled list and stops it.
373+ """
374+ from app .business .client import ClientManager
375+
376+ client_id = ClientManager .get_current_client_id ()
346377
347- if disabled :
378+ with SessionLocal () as db :
379+ extension = db .exec (
380+ sqlmodel .select (ExtensionModel ).where (ExtensionModel .id == extid )
381+ ).first ()
382+
383+ if not extension :
384+ raise ValueError (f"Extension with id { extid } not found." )
385+
386+ # Remove client from enabled list
387+ current_enabled = set (extension .enabled )
388+ if client_id in current_enabled :
389+ current_enabled .discard (client_id )
390+ extension .enabled = list (current_enabled )
391+ db .add (extension )
392+ db .commit ()
393+ db .refresh (extension )
394+
395+ # Close extension if running
396+ if extid in cls .RUNNING_EXTENSIONS :
348397 await cls .close (extid )
349- else :
350- cls .start (extid )
351398
352399 return extension
353400
354401 @classmethod
355- def get_installed (cls , disabled : Opt [bool ] = False ) -> tuple [ExtensionModel , ...]:
402+ def get_installed (
403+ cls ,
404+ enabled_only : bool = False ,
405+ ) -> tuple [ExtensionModel , ...]:
356406 """Get installed extensions.
357407
358- :param disabled : If True, include disabled extensions; otherwise, only enabled ones .
408+ :param enabled_only : If True, only return extensions enabled for the current client .
359409 """
410+ from app .business .client import ClientManager
411+
360412 with SessionLocal () as db :
361- return tuple (
362- db .exec (
363- sqlmodel .select (ExtensionModel ).where (
364- disabled is None or ExtensionModel .disabled == disabled
365- )
366- ).all ()
367- )
413+ query = sqlmodel .select (ExtensionModel )
414+
415+ if enabled_only :
416+ client_id = ClientManager .get_current_client_id ()
417+ # Filter: client_id must be in the enabled array
418+ query = query .where (
419+ ExtensionModel .enabled .any (client_id , operator = sqlalchemy .sql .operators .eq )
420+ )
421+
422+ return tuple (db .exec (query ).all ())
368423
369424 @classmethod
370425 def get (cls , extid : ExtensionID ) -> Opt [ExtensionModel ]:
@@ -460,7 +515,7 @@ def sync(cls):
460515 id = ext_id ,
461516 version = version ,
462517 nickname = nickname ,
463- disabled = True ,
518+ enabled = [] ,
464519 )
465520 db .add (new_ext )
466521 new_count += 1
0 commit comments