Skip to content

Commit 9a9ab83

Browse files
committed
fix: adding header and query
1 parent 5aa20a0 commit 9a9ab83

3 files changed

Lines changed: 168 additions & 0 deletions

File tree

.vscode/settings.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"python.testing.pytestArgs": [
3+
"py"
4+
],
5+
"python.testing.unittestEnabled": false,
6+
"python.testing.pytestEnabled": true
7+
}

py/autoevals/oai.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import json
23
import os
34
import sys
45
import textwrap
@@ -145,6 +146,51 @@ def prepare_openai(client: Optional[LLMClient] = None, is_async=False, api_key=N
145146
# This is the new v1 API
146147
is_v1 = True
147148

149+
default_headers = {}
150+
default_query = {}
151+
152+
# Get headers from environment variables
153+
if os.environ.get("OPENAI_DEFAULT_HEADERS"):
154+
try:
155+
default_headers = json.loads(os.environ.get("OPENAI_DEFAULT_HEADERS"))
156+
except json.JSONDecodeError as e:
157+
print(f"Error parsing OPENAI_DEFAULT_HEADERS: {e}")
158+
default_headers = {}
159+
160+
# Get query params from environment variables
161+
if os.environ.get("OPENAI_DEFAULT_QUERY"):
162+
try:
163+
default_query = json.loads(os.environ.get("OPENAI_DEFAULT_QUERY"))
164+
except json.JSONDecodeError as e:
165+
print(f"Error parsing OPENAI_DEFAULT_QUERY: {e}")
166+
default_query = {}
167+
168+
# Add request source tracking header
169+
default_headers["X-Request-Source"] = "autoevals"
170+
171+
print(f"default_headers: {default_headers}")
172+
print(f"default_query: {default_query}")
173+
174+
if is_async:
175+
openai_obj = openai.AsyncOpenAI(
176+
api_key=api_key,
177+
base_url=base_url,
178+
default_headers=default_headers,
179+
default_query=default_query
180+
)
181+
else:
182+
openai_obj = openai.OpenAI(
183+
api_key=api_key,
184+
base_url=base_url,
185+
default_headers=default_headers,
186+
default_query=default_query
187+
)
188+
else:
189+
if api_key:
190+
openai.api_key = api_key
191+
openai.api_base = base_url
192+
# For v0 API, headers and query params need to be set per-request
193+
148194
if client is None:
149195
# prepare the default openai sdk, if not provided
150196
if api_key is None:

py/autoevals/test_oai.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import pytest
2+
from unittest.mock import Mock, patch
3+
import json
4+
import os
5+
6+
from . import oai
7+
from .oai import LLMClient, prepare_openai, post_process_response, run_cached_request, arun_cached_request
8+
9+
class MockOpenAIResponse:
10+
def dict(self):
11+
return {"response": "test"}
12+
13+
class MockRateLimitError(Exception):
14+
pass
15+
16+
class MockCompletions:
17+
def create(self, **kwargs):
18+
return MockOpenAIResponse()
19+
20+
class MockChat:
21+
def __init__(self):
22+
self.completions = MockCompletions()
23+
24+
class MockEmbeddings:
25+
def create(self, **kwargs):
26+
return MockOpenAIResponse()
27+
28+
class MockModerations:
29+
def create(self, **kwargs):
30+
return MockOpenAIResponse()
31+
32+
class MockOpenAI:
33+
def __init__(self, **kwargs):
34+
self.default_headers = kwargs.get('default_headers', {})
35+
self.default_query = kwargs.get('default_query', {})
36+
self.chat = MockChat()
37+
self.embeddings = MockEmbeddings()
38+
self.moderations = MockModerations()
39+
self.RateLimitError = MockRateLimitError
40+
41+
def test_openai_sync():
42+
"""Test basic OpenAI client functionality with a simple completion request"""
43+
mock_openai = MockOpenAI()
44+
client = LLMClient(
45+
openai=mock_openai,
46+
complete=mock_openai.chat.completions.create,
47+
embed=mock_openai.embeddings.create,
48+
moderation=mock_openai.moderations.create,
49+
RateLimitError=MockRateLimitError
50+
)
51+
52+
response = run_cached_request(
53+
client=client,
54+
request_type="complete",
55+
messages=[
56+
{
57+
"role": "system",
58+
"content": "You are a helpful assistant."
59+
},
60+
{
61+
"role": "user",
62+
"content": "What is 2+2?"
63+
}
64+
],
65+
model="gpt-3.5-turbo",
66+
max_tokens=50
67+
)
68+
69+
assert response == {"response": "test"}
70+
71+
@patch('openai.OpenAI')
72+
@patch.dict(os.environ, {'OPENAI_API_KEY': 'test-key'})
73+
def test_openai_headers(mock_openai):
74+
"""Test OpenAI client with custom headers"""
75+
mock_instance = MockOpenAI(default_headers={"X-Custom-Header": "test", "X-Request-Source": "autoevals"})
76+
mock_openai.return_value = mock_instance
77+
with patch.dict(os.environ, {'OPENAI_DEFAULT_HEADERS': json.dumps({"X-Custom-Header": "test"})}):
78+
client, wrapped = prepare_openai()
79+
assert isinstance(client, LLMClient)
80+
assert mock_instance.default_headers["X-Custom-Header"] == "test"
81+
assert mock_instance.default_headers["X-Request-Source"] == "autoevals"
82+
83+
@patch('openai.OpenAI')
84+
@patch.dict(os.environ, {'OPENAI_API_KEY': 'test-key'})
85+
def test_openai_query_params(mock_openai):
86+
"""Test OpenAI client with custom query parameters"""
87+
mock_instance = MockOpenAI(default_query={"custom_param": "test"})
88+
mock_openai.return_value = mock_instance
89+
with patch.dict(os.environ, {'OPENAI_DEFAULT_QUERY': json.dumps({"custom_param": "test"})}):
90+
client, wrapped = prepare_openai()
91+
assert isinstance(client, LLMClient)
92+
assert mock_instance.default_query["custom_param"] == "test"
93+
94+
@patch('openai.OpenAI')
95+
@patch.dict(os.environ, {'OPENAI_API_KEY': 'test-key'})
96+
def test_invalid_header_json(mock_openai):
97+
"""Test handling of invalid header JSON"""
98+
mock_instance = MockOpenAI(default_headers={"X-Request-Source": "autoevals"})
99+
mock_openai.return_value = mock_instance
100+
with patch.dict(os.environ, {'OPENAI_DEFAULT_HEADERS': 'invalid json'}):
101+
client, wrapped = prepare_openai()
102+
assert isinstance(client, LLMClient)
103+
assert mock_instance.default_headers["X-Request-Source"] == "autoevals"
104+
assert len(mock_instance.default_headers) == 1
105+
106+
@patch('openai.OpenAI')
107+
@patch.dict(os.environ, {'OPENAI_API_KEY': 'test-key'})
108+
def test_invalid_query_json(mock_openai):
109+
"""Test handling of invalid query JSON"""
110+
mock_instance = MockOpenAI()
111+
mock_openai.return_value = mock_instance
112+
with patch.dict(os.environ, {'OPENAI_DEFAULT_QUERY': 'invalid json'}):
113+
client, wrapped = prepare_openai()
114+
assert isinstance(client, LLMClient)
115+
assert len(mock_instance.default_query) == 0

0 commit comments

Comments
 (0)