-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
executable file
·186 lines (151 loc) · 5.2 KB
/
api.py
File metadata and controls
executable file
·186 lines (151 loc) · 5.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import asyncio
import json
import logging
import os
import secrets
from contextlib import asynccontextmanager
from typing import Dict
import uvicorn
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.security import OAuth2PasswordBearer
from ataraxai import __version__
from ataraxai.gateway.gateway_task_manager import GatewayTaskManager
from ataraxai.gateway.request_manager import RequestManager
from ataraxai.praxis.ataraxai_orchestrator import (
AtaraxAIOrchestrator,
AtaraxAIOrchestratorFactory,
)
from ataraxai.praxis.katalepsis import Katalepsis
from ataraxai.praxis.utils.ataraxai_logger import AtaraxAILogger
from ataraxai.routes.benchmark_route.benchmarker import router_benchmark
from ataraxai.routes.chain_runner_route.chain_runner import router_chain_runner
from ataraxai.routes.chat_route.chat import router_chat
from ataraxai.routes.configs_routes.llama_cpp_config_route.llama_cpp_config import (
router_llama_cpp,
)
from ataraxai.routes.configs_routes.rag_config_route.rag_config_route import (
router_rag_config,
)
from ataraxai.routes.configs_routes.user_preferences_route.user_preferences import (
router_user_preferences,
)
from ataraxai.routes.core_ai_service.core_ai_service import (
router_core_ai_service_config,
)
from ataraxai.routes.benchmark_route.benchmarker import router_benchmark
from ataraxai.routes.dependency_api import (
get_orchestrator,
verify_token,
)
from ataraxai.routes.models_manager_route.models_manager import router_models_manager
from ataraxai.routes.rag_route.rag import router_rag
from ataraxai.routes.status import Status, StatusResponse
from ataraxai.routes.vault_route.vault import router_vault
# Resolve the deployment mode, defaulting to "development". setdefault both
# writes the default into os.environ (when the variable is absent) and returns
# the effective value, so later readers of os.environ see the same mode.
ENVIRONMENT = os.environ.setdefault("ENVIRONMENT", "development")

# Force-disable tokenizer parallelism. NOTE(review): presumably this silences
# the HuggingFace `tokenizers` fork/parallelism warning -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Bearer-token scheme object. It is not referenced elsewhere in this file;
# it may be imported by other modules -- verify before removing.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the application's shared services on startup; release them on exit.

    Startup stores the following on ``app.state``:

    * ``secret_token`` -- a fresh random token from ``secrets.token_hex(16)``;
    * ``orchestrator`` -- built by ``AtaraxAIOrchestratorFactory``;
    * ``logger`` -- the orchestrator's logger;
    * ``katalepsis_monitor`` -- a ``Katalepsis`` instance;
    * ``request_manager`` -- a ``RequestManager``, started here;
    * ``gateway_task_manager`` -- a ``GatewayTaskManager``.

    Teardown runs in reverse order of acquisition and inside ``finally``
    blocks. The request manager is stopped before the orchestrator is shut
    down. The orchestrator is shut down even if a later startup step, the
    served lifetime, or ``request_manager.stop()`` raises.

    Args:
        app: The FastAPI application whose ``state`` is populated.
    """
    app.state.secret_token = secrets.token_hex(16)
    orchestrator = await AtaraxAIOrchestratorFactory.create_orchestrator()
    app.state.orchestrator = orchestrator
    app.state.logger = orchestrator.logger
    try:
        app.state.katalepsis_monitor = Katalepsis()
        app.state.request_manager = RequestManager(logger=app.state.logger)
        app.state.gateway_task_manager = GatewayTaskManager()
        await app.state.request_manager.start()
        try:
            yield
        finally:
            app.state.logger.info(
                "API is shutting down. Closing orchestrator resources."
            )
            # Stop the component started last first (LIFO teardown).
            await app.state.request_manager.stop()
    finally:
        # Always release the orchestrator, even if startup or shutdown of the
        # request manager failed.
        await orchestrator.shutdown()
        app.state.secret_token = None
# Per-environment settings, resolved BEFORE the app is built. FastAPI mounts
# the /docs and /redoc routes inside FastAPI.__init__, so they can only be
# disabled through constructor arguments. Assigning app.docs_url /
# app.redoc_url after construction leaves both pages reachable.
if ENVIRONMENT == "development":
    allowed_hosts = ["localhost", "127.0.0.1", "test"]
    _docs_url: str | None = "/docs"
    _redoc_url: str | None = "/redoc"
else:
    allowed_hosts = ["localhost", "127.0.0.1"]
    _docs_url = None
    _redoc_url = None
    # NOTE(review): /openapi.json is still served in production; pass
    # openapi_url=None as well if the schema should be hidden too.

# NOTE(review): every origin is accepted in production as well as in
# development. Combined with allow_credentials=True, recent Starlette
# versions echo back the caller's Origin. Confirm this is intended, or
# restrict it to the desktop front-end's origin.
allow_origins = ["*"]

app = FastAPI(
    title="AtaraxAI API",
    description="API for the AtaraxAI Local Assistant Engine",
    version=__version__,
    lifespan=lifespan,
    docs_url=_docs_url,
    redoc_url=_redoc_url,
)

# Reject requests whose Host header is not one of the allowed local names.
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed_hosts)
app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/v1/status", response_model=StatusResponse, dependencies=[Depends(verify_token)])
async def get_state(orch: AtaraxAIOrchestrator = Depends(get_orchestrator)) -> StatusResponse:
    """Report the orchestrator's current state by name.

    Guarded by ``verify_token``; the orchestrator comes from ``get_orchestrator``.

    Returns:
        A SUCCESS ``StatusResponse`` whose message embeds ``state.name``.
    """
    current_state = await orch.get_state()
    message = f"AtaraxAI is currently in state: {current_state.name}"
    return StatusResponse(status=Status.SUCCESS, message=message)
@app.get("/v1/health", response_model=StatusResponse, dependencies=[Depends(verify_token)])
async def get_health(orch: AtaraxAIOrchestrator = Depends(get_orchestrator)) -> StatusResponse:
    """Liveness probe: always answers SUCCESS once the request is admitted.

    The request must pass ``verify_token``. Declaring ``orch`` makes FastAPI
    resolve ``get_orchestrator`` for each call, even though the body does not
    use the orchestrator itself.
    """
    healthy = StatusResponse(status=Status.SUCCESS, message="AtaraxAI is healthy.")
    return healthy
# Feature routers, mounted in this order. Each one sits behind the same
# bearer-token check (verify_token), applied as a router-level dependency.
all_routers = [
    router_vault,
    router_chat,
    router_rag,
    router_models_manager,
    router_user_preferences,
    router_llama_cpp,
    router_rag_config,
    router_core_ai_service_config,
    router_chain_runner,
    router_benchmark,
]
for router in all_routers:
    _token_guard = [Depends(verify_token)]  # fresh list per router, as before
    app.include_router(router, dependencies=_token_guard)
def print_connection_info(port: int, token: str | None):
connection_info: Dict[str, int | str | None] = {
"port": port,
"token": token,
"status": "ready",
}
print(json.dumps(connection_info), flush=True)
def find_free_port(host: str = "127.0.0.1") -> int:
    """Return a TCP port number the OS currently reports as unused on ``host``.

    The socket binds to port 0 so the kernel assigns an unused port, then it
    is closed and the number is returned.

    The probe defaults to the loopback address because that is the host
    ``main`` binds uvicorn to. A port that is free on the wildcard address
    ("") is not guaranteed to be free on 127.0.0.1 on platforms that allow
    overlapping wildcard and specific binds (NOTE(review): e.g. Windows
    without SO_EXCLUSIVEADDRUSE).

    The port is released before returning, so another process may grab it
    before the caller binds it. Callers must tolerate that race.

    Args:
        host: IPv4 interface to probe; pass "" to probe all interfaces.

    Returns:
        The port number.
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]
async def main():
    """Serve the API on a free loopback port and announce it once it is up.

    uvicorn's ``Server.startup`` is wrapped so that, after startup succeeds,
    the port and the lifespan-generated ``app.state.secret_token`` are
    printed as a JSON handshake via ``print_connection_info``.
    """
    port = find_free_port()
    config = uvicorn.Config(
        app,
        host="127.0.0.1",
        port=port,
        log_level="info",
        access_log=False,
    )
    server = uvicorn.Server(config)
    original_startup = server.startup

    async def custom_startup(*args, **kwargs):  # type: ignore
        # Forward whatever Server.serve() passes (e.g. ``sockets=``) so the
        # wrapper stays transparent to uvicorn.
        await original_startup(*args, **kwargs)
        # uvicorn sets should_exit when startup is aborted (e.g. the lifespan
        # startup failed). Do not advertise "ready" for a server that is
        # about to stop.
        if server.should_exit:
            return
        print_connection_info(port, app.state.secret_token)

    server.startup = custom_startup  # type: ignore
    await server.serve()


if __name__ == "__main__":
    asyncio.run(main())