Skip to content

Commit 0e181ba

Browse files
committed
feat(client): add client management and registration functionality
- Introduced ClientManager for handling client registration and lookup. - Added ClientModel schema for client representation. - Updated settings to include client-specific configurations. - Modified extension management to enable/disable extensions per client. - Created migration for clients table and updated extensions schema. - Enhanced API endpoints for enabling/disabling extensions for clients.
1 parent 949849d commit 0e181ba

13 files changed

Lines changed: 355 additions & 28 deletions

File tree

.env.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Server Configuration
22
HOST=0.0.0.0
33
PORT=8000
4+
CLIENT_ID= # optional
5+
CLIENT_NAME= # optional
6+
CLIENT_BASE_URL= # required
47

58
# JWT Authentication (Required)
69
JWT_SECRET=your_jwt_secret_here

app/business/client/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .main import ClientManager
2+
3+
__all__ = [
4+
"ClientManager",
5+
]

app/business/client/main.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
__all__ = [
2+
"ClientManager",
3+
]
4+
5+
import sqlalchemy.dialects.postgresql
6+
import sqlmodel
7+
from app.engine import SessionLocal
8+
from app.settings import settings
9+
from app.schemas.client.main import ClientModel, ClientID
10+
from libs.obsrv.main import get_logger
11+
12+
13+
LOGGER = get_logger().getChild(__name__)
14+
15+
16+
class ClientManager:
17+
"""Manages client registration and lookup."""
18+
19+
@classmethod
20+
def register_self(cls) -> ClientModel:
21+
"""Register or update the current client in the database.
22+
23+
Uses upsert to handle both initial registration and updates.
24+
Called during application startup.
25+
"""
26+
with SessionLocal() as db:
27+
stmt = sqlalchemy.dialects.postgresql.insert(ClientModel).values(
28+
id=settings.client_id,
29+
name=settings.client_name,
30+
rest_api_url=settings.client_base_url,
31+
)
32+
stmt = stmt.on_conflict_do_update(
33+
index_elements=["id"],
34+
set_=dict(
35+
name=stmt.excluded.name,
36+
rest_api_url=stmt.excluded.rest_api_url,
37+
),
38+
)
39+
db.exec(stmt) # type: ignore
40+
db.commit()
41+
42+
# Fetch and return the registered client
43+
client = db.exec(
44+
sqlmodel.select(ClientModel).where(ClientModel.id == settings.client_id)
45+
).one()
46+
47+
LOGGER.info(f"Client registered: {client.name} ({client.id})")
48+
return client
49+
50+
@classmethod
51+
def get_current_client_id(cls) -> ClientID:
52+
"""Get the current client's ID."""
53+
return settings.client_id
54+
55+
@classmethod
56+
def get(cls, client_id: ClientID) -> ClientModel | None:
57+
"""Get a client by ID."""
58+
with SessionLocal() as db:
59+
return db.exec(
60+
sqlmodel.select(ClientModel).where(ClientModel.id == client_id)
61+
).first()
62+
63+
@classmethod
64+
def get_all(cls) -> tuple[ClientModel, ...]:
65+
"""Get all registered clients."""
66+
with SessionLocal() as db:
67+
return tuple(db.exec(sqlmodel.select(ClientModel)).all())

app/business/extension/main.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import abc
2+
import uuid
23
import typing
34
import fastapi
45
import sqlmodel
6+
import sqlalchemy
57
import importlib
68
import os
79
import tomllib
@@ -95,9 +97,11 @@ class ExtensionManager:
9597

9698
@classmethod
9799
def start_enabled(cls, app: fastapi.FastAPI):
98-
"""Start all enabled entensions."""
100+
"""Start all extensions enabled for the current client."""
101+
from app.business.client import ClientManager
102+
99103
cls.FASTAPI_APP = app
100-
for extension in cls.get_installed():
104+
for extension in cls.get_installed(enabled_only=True):
101105
cls.start(extension=extension)
102106

103107
@classmethod
@@ -305,7 +309,7 @@ def download(cls, extid: str, version: Opt[str] = None) -> ExtensionModel:
305309
version=version or "0.0.0",
306310
nickname=nickname,
307311
config={},
308-
disabled=True,
312+
enabled=[],
309313
)
310314

311315
@classmethod
@@ -329,8 +333,15 @@ def install(cls, extid: ExtensionID, version: Opt[str] = None) -> ExtensionModel
329333
return extension
330334

331335
@classmethod
332-
async def set_disabled(cls, extid: ExtensionID, disabled: bool) -> ExtensionModel:
333-
"""Enable or disable an extension."""
336+
async def enable(cls, extid: ExtensionID) -> ExtensionModel:
337+
"""Enable an extension for the current client.
338+
339+
Adds the current client ID to the extension's enabled list and starts it.
340+
"""
341+
from app.business.client import ClientManager
342+
343+
client_id = ClientManager.get_current_client_id()
344+
334345
with SessionLocal() as db:
335346
extension = db.exec(
336347
sqlmodel.select(ExtensionModel).where(ExtensionModel.id == extid)
@@ -339,32 +350,76 @@ async def set_disabled(cls, extid: ExtensionID, disabled: bool) -> ExtensionMode
339350
if not extension:
340351
raise ValueError(f"Extension with id {extid} not found.")
341352

342-
extension.disabled = disabled
343-
db.add(extension)
344-
db.commit()
345-
db.refresh(extension)
353+
# Add client to enabled list if not already present
354+
current_enabled = set(extension.enabled)
355+
if client_id not in current_enabled:
356+
current_enabled.add(client_id)
357+
extension.enabled = list(current_enabled)
358+
db.add(extension)
359+
db.commit()
360+
db.refresh(extension)
361+
362+
# Start extension if not already running
363+
if extid not in cls.RUNNING_EXTENSIONS:
364+
cls.start(extid)
365+
366+
return extension
367+
368+
@classmethod
369+
async def disable(cls, extid: ExtensionID) -> ExtensionModel:
370+
"""Disable an extension for the current client.
371+
372+
Removes the current client ID from the extension's enabled list and stops it.
373+
"""
374+
from app.business.client import ClientManager
375+
376+
client_id = ClientManager.get_current_client_id()
346377

347-
if disabled:
378+
with SessionLocal() as db:
379+
extension = db.exec(
380+
sqlmodel.select(ExtensionModel).where(ExtensionModel.id == extid)
381+
).first()
382+
383+
if not extension:
384+
raise ValueError(f"Extension with id {extid} not found.")
385+
386+
# Remove client from enabled list
387+
current_enabled = set(extension.enabled)
388+
if client_id in current_enabled:
389+
current_enabled.discard(client_id)
390+
extension.enabled = list(current_enabled)
391+
db.add(extension)
392+
db.commit()
393+
db.refresh(extension)
394+
395+
# Close extension if running
396+
if extid in cls.RUNNING_EXTENSIONS:
348397
await cls.close(extid)
349-
else:
350-
cls.start(extid)
351398

352399
return extension
353400

354401
@classmethod
355-
def get_installed(cls, disabled: Opt[bool] = False) -> tuple[ExtensionModel, ...]:
402+
def get_installed(
403+
cls,
404+
enabled_only: bool = False,
405+
) -> tuple[ExtensionModel, ...]:
356406
"""Get installed extensions.
357407
358-
:param disabled: If True, include disabled extensions; otherwise, only enabled ones.
408+
:param enabled_only: If True, only return extensions enabled for the current client.
359409
"""
410+
from app.business.client import ClientManager
411+
360412
with SessionLocal() as db:
361-
return tuple(
362-
db.exec(
363-
sqlmodel.select(ExtensionModel).where(
364-
disabled is None or ExtensionModel.disabled == disabled
365-
)
366-
).all()
367-
)
413+
query = sqlmodel.select(ExtensionModel)
414+
415+
if enabled_only:
416+
client_id = ClientManager.get_current_client_id()
417+
# Filter: client_id must be in the enabled array
418+
query = query.where(
419+
ExtensionModel.enabled.any(client_id, operator=sqlalchemy.sql.operators.eq)
420+
)
421+
422+
return tuple(db.exec(query).all())
368423

369424
@classmethod
370425
def get(cls, extid: ExtensionID) -> Opt[ExtensionModel]:
@@ -460,7 +515,7 @@ def sync(cls):
460515
id=ext_id,
461516
version=version,
462517
nickname=nickname,
463-
disabled=True,
518+
enabled=[],
464519
)
465520
db.add(new_ext)
466521
new_count += 1

app/routes/extension.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
@ROUTER.get("")
1717
def get_extensions() -> tuple[ExtensionModel, ...]:
1818
"""List all installed extensions"""
19-
return ExtensionManager.get_installed(disabled=None)
19+
return ExtensionManager.get_installed()
2020

2121

2222
@ROUTER.post("/{extid}")
@@ -29,11 +29,23 @@ def install_extension(
2929
return ExtensionManager.install(extid, version=version)
3030

3131

32-
@ROUTER.put("/{extid}/disabled/{disabled}")
33-
async def toggle_extension(extid: ExtensionID, disabled: bool) -> ExtensionModel:
34-
"""启用/禁用插件 (Enable/Disable extension)"""
32+
@ROUTER.post("/{extid}/enable")
33+
async def enable_extension(extid: ExtensionID) -> ExtensionModel:
34+
"""启用插件 (Enable extension for current client)"""
3535
try:
36-
return await ExtensionManager.set_disabled(extid, disabled)
36+
return await ExtensionManager.enable(extid)
37+
except ValueError:
38+
raise fastapi.HTTPException(
39+
status_code=fastapi.status.HTTP_404_NOT_FOUND,
40+
detail=f"Extension with id {extid} not found.",
41+
)
42+
43+
44+
@ROUTER.post("/{extid}/disable")
45+
async def disable_extension(extid: ExtensionID) -> ExtensionModel:
46+
"""禁用插件 (Disable extension for current client)"""
47+
try:
48+
return await ExtensionManager.disable(extid)
3749
except ValueError:
3850
raise fastapi.HTTPException(
3951
status_code=fastapi.status.HTTP_404_NOT_FOUND,

app/schemas/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"ExtensionModel",
1010
"RelationEmbeddingModel",
1111
"BlockEmbeddingModel",
12+
"ClientModel",
1213
]
1314

1415
import sqlalchemy.orm
@@ -23,3 +24,4 @@
2324
from .source import SourceModel, SourceCollectJobModel
2425
from .extension.main import ExtensionModel
2526
from .sink import RelationEmbeddingModel, BlockEmbeddingModel
27+
from .client.main import ClientModel

app/schemas/client/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .main import ClientModel, ClientID
2+
3+
__all__ = [
4+
"ClientModel",
5+
"ClientID",
6+
]

app/schemas/client/main.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import datetime
2+
import uuid
3+
import sqlalchemy
4+
import sqlalchemy.dialects.postgresql
5+
import sqlmodel
6+
import typing
7+
from typing import Optional as Opt
8+
9+
10+
ClientID: typing.TypeAlias = uuid.UUID
11+
12+
13+
class ClientModel(sqlmodel.SQLModel, table=True):
14+
"""Client registration model.
15+
16+
Represents a client instance that connects to this InKCre deployment.
17+
All clients are equal peers in the network.
18+
"""
19+
20+
__tablename__: str = "clients" # type: ignore
21+
22+
id: ClientID = sqlmodel.Field(
23+
sa_column=sqlalchemy.Column(
24+
sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
25+
primary_key=True,
26+
default=uuid.uuid4,
27+
)
28+
)
29+
name: str = sqlmodel.Field(
30+
sa_column=sqlalchemy.Column(sqlalchemy.Text, nullable=False)
31+
)
32+
labels: list[str] = sqlmodel.Field(
33+
default_factory=list,
34+
sa_column=sqlalchemy.Column(
35+
sqlalchemy.dialects.postgresql.ARRAY(sqlalchemy.Text),
36+
server_default=sqlalchemy.text("'{}'::text[]"),
37+
),
38+
)
39+
rest_api_url: Opt[str] = sqlmodel.Field(
40+
default=None,
41+
sa_column=sqlalchemy.Column(sqlalchemy.Text, nullable=True),
42+
)
43+
"""REST API base URL. Nullable since not all clients are reachable (e.g., client-web)."""
44+
created_at: datetime.datetime = sqlmodel.Field(
45+
default_factory=datetime.datetime.now,
46+
sa_column=sqlalchemy.Column(
47+
sqlalchemy.TIMESTAMP(timezone=True),
48+
server_default=sqlalchemy.text("CURRENT_TIMESTAMP"),
49+
),
50+
)

app/schemas/extension/main.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import uuid
12
import sqlalchemy
23
import sqlalchemy.dialects.postgresql
34
import sqlmodel
@@ -24,7 +25,20 @@ class ExtensionModel(sqlmodel.SQLModel, table=True):
2425
2526
format: `major.minor.patch`.
2627
"""
27-
disabled: bool = sqlmodel.Field(default=False)
28+
enabled: list[uuid.UUID] = sqlmodel.Field(
29+
default_factory=list,
30+
sa_column=sqlmodel.Column(
31+
sqlalchemy.dialects.postgresql.ARRAY(
32+
sqlalchemy.dialects.postgresql.UUID(as_uuid=True)
33+
),
34+
server_default=sqlalchemy.text("'{}'::uuid[]"),
35+
nullable=False,
36+
),
37+
)
38+
"""List of client IDs for which this extension is enabled.
39+
40+
Empty array means disabled for all clients.
41+
"""
2842
nickname: Opt[str] = sqlmodel.Field(default=None)
2943
config: dict = sqlmodel.Field(
3044
default_factory=dict,

0 commit comments

Comments
 (0)