Skip to content

Commit 6a90316

Browse files
committed
fix: tool invocation callback host
1 parent 77fc3ab commit 6a90316

File tree

4 files changed

+544
-16
lines changed

4 files changed

+544
-16
lines changed

src/lightrace/client.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
_current_trace_id,
1919
_get_tool_registry,
2020
_set_client_defaults,
21+
_set_on_tool_registered,
2122
_set_otel_exporter,
2223
_tool_registry,
2324
)
@@ -58,6 +59,7 @@ def __init__(
5859
session_id: str | None = None,
5960
dev_server: bool = True,
6061
dev_server_port: int = 0,
62+
dev_server_host: str | None = None,
6163
):
6264
self._public_key = public_key or os.environ.get("LIGHTRACE_PUBLIC_KEY", "")
6365
self._secret_key = secret_key or os.environ.get("LIGHTRACE_SECRET_KEY", "")
@@ -69,6 +71,9 @@ def __init__(
6971
self._dev_server: DevServer | None = None
7072
self._dev_server_enabled = dev_server
7173
self._dev_server_port = dev_server_port
74+
self._dev_server_host = dev_server_host or os.environ.get(
75+
"LIGHTRACE_DEV_SERVER_HOST", "127.0.0.1"
76+
)
7277

7378
if not enabled:
7479
logger.info("Lightrace disabled — no events will be sent")
@@ -120,16 +125,30 @@ def _start_dev_server(self) -> None:
120125
self._dev_server = DevServer(
121126
port=self._dev_server_port,
122127
public_key=self._public_key,
128+
callback_host=self._dev_server_host,
123129
)
130+
self._pending_registration: threading.Timer | None = None
124131
try:
125132
port = self._dev_server.start()
126133
logger.info("Dev server listening on http://127.0.0.1:%d", port)
127134
self._register_tools_http()
135+
136+
# Re-register when new tools are added after init (debounced)
137+
def _on_new_tool(name: str) -> None:
138+
logger.debug("New tool registered: %s — scheduling re-registration", name)
139+
if self._pending_registration is not None:
140+
self._pending_registration.cancel()
141+
timer = threading.Timer(0.2, self._register_tools_http)
142+
timer.daemon = True
143+
timer.start()
144+
self._pending_registration = timer
145+
146+
_set_on_tool_registered(_on_new_tool)
128147
except Exception as e:
129148
logger.error("Failed to start dev server: %s", e)
130149

131150
def _register_tools_http(self) -> None:
132-
"""Register tool definitions with the Lightrace backend via HTTP (fire-and-forget)."""
151+
"""Register tool definitions with the Lightrace backend via HTTP with retry."""
133152
registry = _get_tool_registry()
134153
if not registry:
135154
return
@@ -149,19 +168,42 @@ def _register_tools_http(self) -> None:
149168

150169
auth = base64.b64encode(f"{self._public_key}:{self._secret_key}".encode()).decode()
151170
host = self._host
171+
max_retries = 3
152172

153173
def _do_register() -> None:
154-
try:
155-
resp = httpx.post(
156-
f"{host}/api/public/tools/register",
157-
json={"callbackUrl": callback_url, "tools": tools},
158-
headers={"Authorization": f"Basic {auth}"},
159-
timeout=5.0,
160-
)
161-
if resp.status_code >= 400:
162-
logger.warning("Tool registration returned %d", resp.status_code)
163-
except Exception as e:
164-
logger.error("Failed to register tools: %s", e)
174+
import time
175+
176+
for attempt in range(max_retries):
177+
try:
178+
resp = httpx.post(
179+
f"{host}/api/public/tools/register",
180+
json={"callbackUrl": callback_url, "tools": tools},
181+
headers={"Authorization": f"Basic {auth}"},
182+
timeout=5.0,
183+
)
184+
if resp.status_code < 400:
185+
logger.info(
186+
"Registered %d tool(s): %s",
187+
len(tools),
188+
", ".join(t["name"] for t in tools),
189+
)
190+
return
191+
logger.warning(
192+
"Tool registration returned %d (attempt %d/%d)",
193+
resp.status_code,
194+
attempt + 1,
195+
max_retries,
196+
)
197+
except Exception as e:
198+
logger.warning(
199+
"Tool registration failed (attempt %d/%d): %s",
200+
attempt + 1,
201+
max_retries,
202+
e,
203+
)
204+
if attempt < max_retries - 1:
205+
time.sleep(2**attempt)
206+
logger.error("Tool registration failed after %d attempts", max_retries)
165207

166208
threading.Thread(target=_do_register, daemon=True, name="lightrace-register").start()
167209

@@ -225,6 +267,10 @@ def flush(self) -> None:
225267

226268
def shutdown(self) -> None:
227269
"""Flush and shut down the client."""
270+
_set_on_tool_registered(None)
271+
if hasattr(self, "_pending_registration") and self._pending_registration is not None:
272+
self._pending_registration.cancel()
273+
self._pending_registration = None
228274
if self._dev_server is not None:
229275
self._dev_server.stop()
230276
self._dev_server = None

src/lightrace/dev_server.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from __future__ import annotations
1010

1111
import asyncio
12+
import inspect
1213
import logging
1314
import threading
1415
import time
@@ -85,16 +86,32 @@ async def invoke(req: InvokeRequest, request: Request) -> JSONResponse:
8586
input_data = req.input
8687
start = time.monotonic()
8788

89+
# Smart dispatch: spread kwargs when input keys match function parameter names
90+
spread = False
91+
if isinstance(input_data, dict):
92+
try:
93+
sig = inspect.signature(func)
94+
param_names = {
95+
p.name
96+
for p in sig.parameters.values()
97+
if p.kind
98+
not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
99+
}
100+
# Spread if the input dict keys are a subset of the function's param names
101+
spread = len(param_names) > 0 and set(input_data.keys()).issubset(param_names)
102+
except (ValueError, TypeError):
103+
spread = False
104+
88105
try:
89106
if asyncio.iscoroutinefunction(func):
90-
if isinstance(input_data, dict):
107+
if spread:
91108
result = await func(**input_data)
92109
elif input_data is not None:
93110
result = await func(input_data)
94111
else:
95112
result = await func()
96113
else:
97-
if isinstance(input_data, dict):
114+
if spread:
98115
result = await asyncio.to_thread(func, **input_data)
99116
elif input_data is not None:
100117
result = await asyncio.to_thread(func, input_data)
@@ -128,9 +145,10 @@ async def invoke(req: InvokeRequest, request: Request) -> JSONResponse:
128145
class DevServer:
129146
"""Lightweight HTTP server that accepts tool invocation from the dashboard."""
130147

131-
def __init__(self, port: int = 0, public_key: str = ""):
148+
def __init__(self, port: int = 0, public_key: str = "", callback_host: str = "127.0.0.1"):
132149
self._port = port
133150
self._public_key = public_key
151+
self._callback_host = callback_host
134152
self._thread: threading.Thread | None = None
135153
self._assigned_port: int | None = None
136154
self._server: Any = None
@@ -205,4 +223,4 @@ def port(self) -> int | None:
205223
def callback_url(self) -> str | None:
206224
if self._assigned_port is None:
207225
return None
208-
return f"http://127.0.0.1:{self._assigned_port}"
226+
return f"http://{self._callback_host}:{self._assigned_port}"

src/lightrace/trace.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
# Global references (set by Client on init)
2929
_otel_exporter: Any = None # LightraceOtelExporter instance
3030
_tool_registry: dict[str, dict[str, Any]] = {}
31+
_on_tool_registered: Callable[[str], None] | None = None
3132

3233
# Client defaults
3334
_client_defaults: dict[str, str | None] = {"user_id": None, "session_id": None}
@@ -47,6 +48,11 @@ def _get_tool_registry() -> dict[str, dict[str, Any]]:
4748
return _tool_registry
4849

4950

51+
def _set_on_tool_registered(callback: Callable[[str], None] | None) -> None:
52+
global _on_tool_registered
53+
_on_tool_registered = callback
54+
55+
5056
VALID_TYPES = {None, "span", "generation", "event", "tool", "chain"}
5157

5258

@@ -86,6 +92,8 @@ def decorator(func: F) -> F:
8692
"input_schema": build_json_schema(func),
8793
"description": None,
8894
}
95+
if _on_tool_registered is not None:
96+
_on_tool_registered(obs_name)
8997

9098
if asyncio.iscoroutinefunction(func):
9199

0 commit comments

Comments
 (0)