Skip to content

Commit 4916c1d

Browse files
committed
sync aload_prompt with load_prompt
1 parent 9e39793 commit 4916c1d

2 files changed

Lines changed: 87 additions & 43 deletions

File tree

py/setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
long_description = f.read()
1414

1515
install_requires = [
16-
"aiohttp",
1716
"GitPython",
1817
"requests",
1918
"chevron",
@@ -34,7 +33,7 @@
3433
"openai-agents": ["openai-agents"],
3534
"otel": ["opentelemetry-api", "opentelemetry-sdk", "opentelemetry-exporter-otlp-proto-http"],
3635
# orjson is not compatible with PyPy, so we exclude it for that platform
37-
"performance": ["orjson; platform_python_implementation != 'PyPy'"],
36+
"performance": ["orjson; platform_python_implementation != 'PyPy'", "aiohttp"],
3837
"temporal": ["temporalio>=1.19.0; python_version>='3.10'"],
3938
}
4039

py/src/braintrust/logger.py

Lines changed: 86 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -817,11 +817,14 @@ def patch_json(self, object_type: str, args: Mapping[str, Any] | None = None) ->
817817

818818
async def aget_json(
819819
self, object_type: str, args: Optional[Mapping[str, Any]] = None, retries: int = 0
820-
) -> Mapping[str, Any]:
820+
) -> Mapping[str, Any] | None:
821821
"""
822822
Async version of get_json. Makes a true async HTTP GET request and returns JSON response.
823823
"""
824+
from importlib.util import find_spec
825+
824826
tries = retries + 1
827+
use_urllib = find_spec("aiohttp") is None
825828

826829
for i in range(tries):
827830
try:
@@ -831,8 +834,7 @@ async def aget_json(
831834
url += "?" + urlencode(_strip_nones(args))
832835

833836
# check if aiohttp is available, otherwise fall back to asyncio approach
834-
from importlib.util import find_spec
835-
if find_spec("aiohttp") is None:
837+
if use_urllib:
836838
# Fall back to asyncio + urllib approach
837839
return await self._make_asyncio_request(url)
838840
return await self._make_aiohttp_request(url)
@@ -864,22 +866,21 @@ async def _make_aiohttp_request(self, url: str) -> Mapping[str, Any]:
864866

865867
async def _make_asyncio_request(self, url: str) -> Mapping[str, Any]:
866868
"""Make async HTTP request using asyncio and urllib (fallback)"""
867-
loop = asyncio.get_event_loop()
869+
loop = asyncio.get_running_loop()
870+
timeout_secs = parse_env_var_float("BRAINTRUST_HTTP_TIMEOUT", 60.0)
868871

869872
def sync_request():
870873
request = Request(url)
871874
if self.token:
872875
request.add_header("Authorization", f"Bearer {self.token}")
873876

874877
try:
875-
response_obj = urlopen(request)
878+
response_obj = urlopen(request, timeout=timeout_secs)
876879
response_data = response_obj.read()
877880
return json.loads(response_data.decode("utf-8"))
878881
except HTTPError as e:
879-
if e.code >= 400:
880-
error_body = e.read().decode("utf-8") if hasattr(e, "read") else str(e)
881-
raise Exception(f"HTTP {e.code}: {error_body}")
882-
raise
882+
error_body = e.read().decode("utf-8") if hasattr(e, "read") else str(e)
883+
raise Exception(f"HTTP {e.code}: {error_body}")
883884
except URLError as e:
884885
raise Exception(f"URL Error: {e}")
885886

@@ -2244,8 +2245,10 @@ async def aload_prompt(
22442245
slug: Optional[str] = None,
22452246
version: Optional[Union[str, int]] = None,
22462247
project_id: Optional[str] = None,
2248+
prompt_id: str | None = None,
22472249
defaults: Optional[Mapping[str, Any]] = None,
22482250
no_trace: bool = False,
2251+
environment: str | None = None,
22492252
app_url: Optional[str] = None,
22502253
api_key: Optional[str] = None,
22512254
org_name: Optional[str] = None,
@@ -2257,81 +2260,123 @@ async def aload_prompt(
22572260
:param slug: The slug of the prompt to load.
22582261
:param version: An optional version of the prompt (to read). If not specified, the latest version will be used.
22592262
:param project_id: The id of the project to load the prompt from. This takes precedence over `project` if specified.
2263+
:param prompt_id: The id of a specific prompt to load. If specified, this takes precedence over all other parameters (project, slug, version).
22602264
:param defaults: (Optional) A dictionary of default values to use when rendering the prompt. Prompt values will override these defaults.
22612265
:param no_trace: If true, do not include logging metadata for this prompt when build() is called.
2266+
:param environment: The environment to load the prompt from. If both `version` and `environment` are provided, `version` takes precedence.
22622267
:param app_url: The URL of the Braintrust App. Defaults to https://www.braintrust.dev.
22632268
:param api_key: The API key to use. If the parameter is not specified, will try to use the `BRAINTRUST_API_KEY` environment variable. If no API
22642269
key is specified, will prompt the user to login.
22652270
:param org_name: (Optional) The name of a specific organization to connect to. This is useful if you belong to multiple.
22662271
:returns: The prompt object.
22672272
"""
2273+
effective_environment = None
2274+
if version is None:
2275+
effective_environment = environment
22682276

2269-
if not project and not project_id:
2277+
if prompt_id:
2278+
pass
2279+
elif not project and not project_id:
22702280
raise ValueError("Must specify at least one of project or project_id")
2271-
if not slug:
2281+
elif not slug:
22722282
raise ValueError("Must specify slug")
22732283

2274-
loop = asyncio.get_event_loop()
2284+
loop = asyncio.get_running_loop()
2285+
response = None
22752286

22762287
try:
22772288
# Run login in thread pool since it's synchronous
22782289
await loop.run_in_executor(HTTP_REQUEST_THREAD_POOL, login, app_url, api_key, org_name)
2290+
if prompt_id:
2291+
args = _populate_args({}, version=version, environment=effective_environment)
22792292

2280-
# Make async HTTP request
2281-
args = _populate_args(
2282-
{
2283-
"project_name": project,
2284-
"project_id": project_id,
2285-
"slug": slug,
2286-
"version": version,
2287-
},
2288-
)
2293+
response = await _state.api_conn().aget_json(f"/v1/prompt/{prompt_id}", args)
22892294

2290-
response = await _state.api_conn().aget_json("/v1/prompt", args)
2295+
if response:
2296+
response = {"objects": [response]}
2297+
2298+
else:
2299+
args = _populate_args(
2300+
{},
2301+
project_name=project,
2302+
project_id=project_id,
2303+
slug=slug,
2304+
version=version,
2305+
environment=effective_environment,
2306+
)
2307+
2308+
response = await _state.api_conn().aget_json("/v1/prompt", args)
22912309

22922310
except Exception as server_error:
2311+
# If environment or version was specified, don't fall back to cache
2312+
if effective_environment is not None or version is not None:
2313+
raise ValueError(f"Prompt not found with specified parameters") from server_error
2314+
22932315
eprint(f"Failed to load prompt, attempting to fall back to cache: {server_error}")
22942316
try:
2295-
cache_result = await loop.run_in_executor(
2296-
HTTP_REQUEST_THREAD_POOL,
2297-
lambda: _state._prompt_cache.get(
2298-
slug,
2299-
version=str(version) if version else "latest",
2300-
project_id=project_id,
2301-
project_name=project,
2302-
),
2303-
)
2317+
if prompt_id:
2318+
cache_result = await loop.run_in_executor(
2319+
HTTP_REQUEST_THREAD_POOL,
2320+
lambda: _state._prompt_cache.get(id=prompt_id),
2321+
)
2322+
else:
2323+
cache_result = await loop.run_in_executor(
2324+
HTTP_REQUEST_THREAD_POOL,
2325+
lambda: _state._prompt_cache.get(
2326+
slug,
2327+
version=str(version) if version else "latest",
2328+
project_id=project_id,
2329+
project_name=project,
2330+
),
2331+
)
23042332
# Return Prompt with pre-computed metadata from cache
23052333
return Prompt(
23062334
lazy_metadata=LazyValue(lambda: cache_result, use_mutex=True),
23072335
defaults=defaults or {},
23082336
no_trace=no_trace,
23092337
)
23102338
except Exception as cache_error:
2339+
if prompt_id:
2340+
raise ValueError(
2341+
f"Prompt with id {prompt_id} not found (not found on server or in local cache): {cache_error}"
2342+
) from server_error
23112343
raise ValueError(
23122344
f"Prompt {slug} (version {version or 'latest'}) not found in {project or project_id} (not found on server or in local cache): {cache_error}"
23132345
) from server_error
23142346

23152347
if response is None or "objects" not in response or len(response["objects"]) == 0:
2348+
if prompt_id:
2349+
raise ValueError(f"Prompt with id {prompt_id} not found.")
2350+
23162351
raise ValueError(f"Prompt {slug} not found in project {project or project_id}.")
23172352
elif len(response["objects"]) > 1:
2353+
if prompt_id:
2354+
raise ValueError(f"Multiple prompts found with id {prompt_id}. This should never happen.")
2355+
23182356
raise ValueError(
23192357
f"Multiple prompts found with slug {slug} in project {project or project_id}. This should never happen."
23202358
)
23212359

23222360
resp_prompt = response["objects"][0]
23232361
prompt_metadata = PromptSchema.from_dict_deep(resp_prompt)
23242362
try:
2325-
await loop.run_in_executor(
2326-
HTTP_REQUEST_THREAD_POOL,
2327-
lambda: _state._prompt_cache.set(
2328-
slug,
2329-
str(version) if version else "latest",
2330-
prompt_metadata,
2331-
project_id=project_id,
2332-
project_name=project,
2333-
),
2334-
)
2363+
# save prompt to cache
2364+
if prompt_id:
2365+
await loop.run_in_executor(
2366+
HTTP_REQUEST_THREAD_POOL,
2367+
lambda: _state._prompt_cache.set(prompt_metadata, id=prompt_id),
2368+
)
2369+
else:
2370+
await loop.run_in_executor(
2371+
HTTP_REQUEST_THREAD_POOL,
2372+
lambda: _state._prompt_cache.set(
2373+
prompt_metadata,
2374+
slug=slug,
2375+
version=str(version) if version else "latest",
2376+
project_id=project_id,
2377+
project_name=project,
2378+
),
2379+
)
23352380
except Exception as e:
23362381
eprint(f"Failed to store prompt in cache: {e}")
23372382

0 commit comments

Comments
 (0)