From 3409ab40585890d6842c67b80edc0f25d88fbc3a Mon Sep 17 00:00:00 2001 From: Daniel Gomm Date: Mon, 9 Feb 2026 14:35:12 +0100 Subject: [PATCH 1/9] :bug: fix calls to static builder methods (#4) --- src/openbatch/collector.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/openbatch/collector.py b/src/openbatch/collector.py index 1e82da6..31bc482 100644 --- a/src/openbatch/collector.py +++ b/src/openbatch/collector.py @@ -25,6 +25,7 @@ def __init__(self, batch_file_path: Union[str, PathLike]): where the batch requests will be written. """ self.batch_file_path = batch_file_path + self._manager = BatchJobManager() def parse(self, custom_id: str, model: str, text_format: Optional[type[BaseModel]] = None, **kwargs) -> None: """ @@ -56,7 +57,7 @@ def create(self, custom_id: str, model: str, **kwargs) -> None: self._add_request(custom_id, request) def _add_request(self, custom_id: str, request: ResponsesRequest) -> None: - BatchJobManager.add(custom_id, request, self.batch_file_path) + self._manager.add(custom_id, request, self.batch_file_path) class ChatCompletions: """ @@ -74,6 +75,7 @@ def __init__(self, batch_file_path: Union[str, PathLike]): where the batch requests will be written. """ self.batch_file_path = batch_file_path + self._manager = BatchJobManager() def parse(self, custom_id: str, model: str, response_format: Optional[type[BaseModel]] = None, **kwargs) -> None: """ @@ -105,7 +107,7 @@ def create(self, custom_id: str, model: str, **kwargs) -> None: self._add_request(custom_id, request) def _add_request(self, custom_id: str, request: ChatCompletionsRequest) -> None: - BatchJobManager.add(custom_id, request, self.batch_file_path) + self._manager.add(custom_id, request, self.batch_file_path) class Embeddings: """ @@ -123,6 +125,7 @@ def __init__(self, batch_file_path: Union[str, PathLike]): where the batch requests will be written. 
""" self.batch_file_path = batch_file_path + self._manager = BatchJobManager() def create(self, custom_id: str, model: str, inp: Union[str, list[str]], **kwargs) -> None: """ @@ -135,7 +138,7 @@ def create(self, custom_id: str, model: str, inp: Union[str, list[str]], **kwarg **kwargs: Additional parameters for the EmbeddingsRequest. """ request = EmbeddingsRequest.model_validate({"model": model, "input": inp, **kwargs}) - BatchJobManager.add(custom_id, request, self.batch_file_path) + self._manager.add(custom_id, request, self.batch_file_path) class BatchCollector: From 7f11c76769999298925e7f429679aeb643ac5341 Mon Sep 17 00:00:00 2001 From: Daniel Gomm Date: Mon, 9 Feb 2026 14:42:40 +0100 Subject: [PATCH 2/9] :test_tube: introduce tests (#5) --- .github/workflows/test.yml | 39 ++++ README.md | 31 ++- pyproject.toml | 35 ++++ tests/__init__.py | 0 tests/test_collector.py | 371 ++++++++++++++++++++++++++++++++++ tests/test_integration.py | 401 +++++++++++++++++++++++++++++++++++++ tests/test_manager.py | 395 ++++++++++++++++++++++++++++++++++++ tests/test_model.py | 258 ++++++++++++++++++++++++ tests/test_utils.py | 266 ++++++++++++++++++++++++ 9 files changed, 1795 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/test.yml create mode 100644 tests/__init__.py create mode 100644 tests/test_collector.py create mode 100644 tests/test_integration.py create mode 100644 tests/test_manager.py create mode 100644 tests/test_model.py create mode 100644 tests/test_utils.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..292fc91 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,39 @@ +name: Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: 
actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + + - name: Run tests with pytest + run: | + pytest -v --cov=openbatch --cov-report=term-missing --cov-report=xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./coverage.xml + flags: unittests + name: codecov-umbrella + fail_ci_if_error: false diff --git a/README.md b/README.md index 607a59b..6bae5fb 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,13 @@ # OpenBatch: Simplify OpenAI Batch Job Creation -[](https://www.google.com/search?q=https://badge.fury.io/py/openbatch) **OpenBatch** is a lightweight Python utility designed to streamline the creation of JSONL files for the [OpenAI Batch API](https://platform.openai.com/docs/guides/batch). It provides a type-safe and intuitive interface using Pydantic models to construct requests for the `/v1/responses`, `/v1/chat/completions`, and `/v1/embeddings` endpoints. +[![PyPI version](https://badge.fury.io/py/openbatch.svg)](https://badge.fury.io/py/openbatch) +[![Python versions](https://img.shields.io/pypi/pyversions/openbatch.svg)](https://pypi.org/project/openbatch/) +[![Tests](https://github.com/daniel-gomm/openbatch/actions/workflows/test.yml/badge.svg)](https://github.com/daniel-gomm/openbatch/actions/workflows/test.yml) +[![codecov](https://codecov.io/gh/daniel-gomm/openbatch/branch/main/graph/badge.svg)](https://codecov.io/gh/daniel-gomm/openbatch) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +[![GitHub stars](https://img.shields.io/github/stars/daniel-gomm/openbatch.svg?style=social&label=Star)](https://github.com/daniel-gomm/openbatch) + +**OpenBatch** is a lightweight Python utility designed to streamline the creation of JSONL files for the [OpenAI Batch API](https://platform.openai.com/docs/guides/batch). 
It provides a type-safe and intuitive interface using Pydantic models to construct requests for the `/v1/responses`, `/v1/chat/completions`, and `/v1/embeddings` endpoints. For a detailed guide on using OpenBatch, please refer to the **[OpenBatch Documentation](https://openbatch.daniel-gomm.com/)**. @@ -189,3 +196,25 @@ You can also override any common setting on a per-instance basis by using the `i 3. **Retrieve Results**: Monitor the job's status and, once completed, download the output file with the results. For detailed instructions on these steps, please refer to the **[Official OpenAI Batch API Documentation](https://platform.openai.com/docs/api-reference/batch)**. + +----- + +## Testing + +OpenBatch includes a comprehensive test suite. + +```bash +# Install with test dependencies +pip install -e ".[test]" + +# Run tests +pytest + +# Run with coverage report +pytest --cov=openbatch +``` + +The test suite includes: +- Unit tests for all core functionality +- Integration tests for end-to-end workflows +- Tests for structured outputs, reasoning models, and unicode handling \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d771905..0adbf8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,3 +23,38 @@ dependencies = [ Homepage = "https://github.com/daniel-gomm/openbatch" Issues = "https://github.com/daniel-gomm/openbatch/issues" Documentation = "https://daniel-gomm.github.io/openbatch/" + +[project.optional-dependencies] +test = [ + "pytest>=8.0.0", + "pytest-cov>=4.1.0", +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + "--strict-markers", + "--strict-config", + "--cov=openbatch", + "--cov-report=term-missing", + "--cov-report=html", + "--cov-report=xml", +] + +[tool.coverage.run] +source = ["src/openbatch"] +omit = ["*/tests/*", "*/__pycache__/*"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def 
__repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "@abstractmethod", +] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_collector.py b/tests/test_collector.py new file mode 100644 index 0000000..090ab14 --- /dev/null +++ b/tests/test_collector.py @@ -0,0 +1,371 @@ +import json +import pytest +from pathlib import Path +from pydantic import BaseModel, Field +from openbatch.collector import BatchCollector, Responses, ChatCompletions, Embeddings +from openbatch.model import ReasoningConfig + + +@pytest.fixture +def temp_batch_file(tmp_path): + """Provides a temporary file path for batch files.""" + return tmp_path / "test_batch.jsonl" + + +class TestResponses: + def test_responses_create(self, temp_batch_file): + responses = Responses(temp_batch_file) + responses.create( + custom_id="req_1", + model="gpt-4", + input="What is Python?", + instructions="You are a helpful assistant", + max_output_tokens=100, + ) + + assert temp_batch_file.exists() + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "req_1" + assert data["url"] == "/v1/responses" + assert data["body"]["model"] == "gpt-4" + assert data["body"]["input"] == "What is Python?" 
+ assert data["body"]["instructions"] == "You are a helpful assistant" + assert data["body"]["max_output_tokens"] == 100 + + def test_responses_parse_without_format(self, temp_batch_file): + responses = Responses(temp_batch_file) + responses.parse( + custom_id="req_2", + model="gpt-4", + input="Analyze this text", + instructions="Be concise", + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "req_2" + assert data["body"]["model"] == "gpt-4" + assert "text" not in data["body"] # No text format specified + + def test_responses_parse_with_format(self, temp_batch_file): + class Analysis(BaseModel): + summary: str = Field(description="Brief summary") + sentiment: str = Field(description="Sentiment analysis") + + responses = Responses(temp_batch_file) + responses.parse( + custom_id="req_3", + model="gpt-4", + text_format=Analysis, + input="Great product!", + instructions="Analyze sentiment", + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "req_3" + assert "text" in data["body"] + assert data["body"]["text"]["format"]["type"] == "json_schema" + assert data["body"]["text"]["format"]["name"] == "Analysis" + assert data["body"]["text"]["format"]["strict"] is True + + def test_responses_with_reasoning_config(self, temp_batch_file): + responses = Responses(temp_batch_file) + responses.create( + custom_id="req_4", + model="gpt-5-mini", + input="Complex problem", + reasoning=ReasoningConfig(effort="high", summary="detailed"), + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["body"]["reasoning"]["effort"] == "high" + assert data["body"]["reasoning"]["summary"] == "detailed" + + def test_responses_multiple_requests(self, temp_batch_file): + responses = Responses(temp_batch_file) + responses.create(custom_id="req_1", model="gpt-4", input="First") + responses.create(custom_id="req_2", model="gpt-4", 
input="Second") + + with open(temp_batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + assert data1["custom_id"] == "req_1" + assert data2["custom_id"] == "req_2" + + +class TestChatCompletions: + def test_chat_completions_create(self, temp_batch_file): + chat = ChatCompletions(temp_batch_file) + chat.create( + custom_id="chat_1", + model="gpt-4", + messages=[ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + ], + temperature=0.7, + ) + + assert temp_batch_file.exists() + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "chat_1" + assert data["url"] == "/v1/chat/completions" + assert data["body"]["model"] == "gpt-4" + assert len(data["body"]["messages"]) == 2 + assert data["body"]["temperature"] == 0.7 + + def test_chat_completions_parse_without_format(self, temp_batch_file): + chat = ChatCompletions(temp_batch_file) + chat.parse( + custom_id="chat_2", + model="gpt-4", + messages=[{"role": "user", "content": "Hi"}], + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "chat_2" + assert "response_format" not in data["body"] + + def test_chat_completions_parse_with_format(self, temp_batch_file): + class Response(BaseModel): + answer: str + confidence: float + + chat = ChatCompletions(temp_batch_file) + chat.parse( + custom_id="chat_3", + model="gpt-4", + response_format=Response, + messages=[{"role": "user", "content": "What is 2+2?"}], + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "chat_3" + assert "response_format" in data["body"] + assert data["body"]["response_format"]["format"]["name"] == "Response" + assert data["body"]["response_format"]["format"]["strict"] is True + + def test_chat_completions_with_reasoning_effort(self, temp_batch_file): + chat = 
ChatCompletions(temp_batch_file) + chat.create( + custom_id="chat_4", + model="o1-mini", + messages=[{"role": "user", "content": "Complex question"}], + reasoning_effort="high", + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["body"]["reasoning_effort"] == "high" + + def test_chat_completions_multiple_requests(self, temp_batch_file): + chat = ChatCompletions(temp_batch_file) + chat.create( + custom_id="chat_1", model="gpt-4", messages=[{"role": "user", "content": "Hi"}] + ) + chat.create( + custom_id="chat_2", + model="gpt-4", + messages=[{"role": "user", "content": "Bye"}], + ) + + with open(temp_batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 2 + + +class TestEmbeddings: + def test_embeddings_create_single_input(self, temp_batch_file): + embeddings = Embeddings(temp_batch_file) + embeddings.create( + custom_id="emb_1", + model="text-embedding-3-small", + inp="Text to embed", + ) + + assert temp_batch_file.exists() + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "emb_1" + assert data["url"] == "/v1/embeddings" + assert data["body"]["model"] == "text-embedding-3-small" + assert data["body"]["input"] == "Text to embed" + + def test_embeddings_create_list_input(self, temp_batch_file): + embeddings = Embeddings(temp_batch_file) + embeddings.create( + custom_id="emb_2", + model="text-embedding-3-small", + inp=["Text 1", "Text 2", "Text 3"], + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert isinstance(data["body"]["input"], list) + assert len(data["body"]["input"]) == 3 + + def test_embeddings_with_dimensions(self, temp_batch_file): + embeddings = Embeddings(temp_batch_file) + embeddings.create( + custom_id="emb_3", + model="text-embedding-3-small", + inp="Test", + dimensions=512, + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["body"]["dimensions"] == 
512 + + def test_embeddings_multiple_requests(self, temp_batch_file): + embeddings = Embeddings(temp_batch_file) + embeddings.create(custom_id="emb_1", model="text-embedding-3-small", inp="First") + embeddings.create(custom_id="emb_2", model="text-embedding-3-small", inp="Second") + + with open(temp_batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 2 + + +class TestBatchCollector: + def test_batch_collector_initialization(self, temp_batch_file): + collector = BatchCollector(temp_batch_file) + assert isinstance(collector.responses, Responses) + assert isinstance(collector.chat.completions, ChatCompletions) + assert isinstance(collector.embeddings, Embeddings) + + def test_batch_collector_responses_api(self, temp_batch_file): + collector = BatchCollector(temp_batch_file) + collector.responses.create( + custom_id="req_1", + model="gpt-4", + input="Hello", + ) + + assert temp_batch_file.exists() + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["url"] == "/v1/responses" + + def test_batch_collector_chat_completions_api(self, temp_batch_file): + collector = BatchCollector(temp_batch_file) + collector.chat.completions.create( + custom_id="chat_1", + model="gpt-4", + messages=[{"role": "user", "content": "Hi"}], + ) + + assert temp_batch_file.exists() + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["url"] == "/v1/chat/completions" + + def test_batch_collector_embeddings_api(self, temp_batch_file): + collector = BatchCollector(temp_batch_file) + collector.embeddings.create( + custom_id="emb_1", + model="text-embedding-3-small", + inp="Text", + ) + + assert temp_batch_file.exists() + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["url"] == "/v1/embeddings" + + def test_batch_collector_mixed_apis_in_sequence(self, tmp_path): + # Demonstrate that different endpoints need different files + responses_file = tmp_path / "responses.jsonl" 
+ chat_file = tmp_path / "chat.jsonl" + embeddings_file = tmp_path / "embeddings.jsonl" + + responses_collector = BatchCollector(responses_file) + responses_collector.responses.create( + custom_id="req_1", model="gpt-4", input="Test" + ) + + chat_collector = BatchCollector(chat_file) + chat_collector.chat.completions.create( + custom_id="chat_1", + model="gpt-4", + messages=[{"role": "user", "content": "Test"}], + ) + + embeddings_collector = BatchCollector(embeddings_file) + embeddings_collector.embeddings.create( + custom_id="emb_1", model="text-embedding-3-small", inp="Test" + ) + + assert responses_file.exists() + assert chat_file.exists() + assert embeddings_file.exists() + + def test_batch_collector_responses_parse_structured_output(self, temp_batch_file): + class TaskAnalysis(BaseModel): + task_type: str + complexity: str + estimated_time: int + + collector = BatchCollector(temp_batch_file) + collector.responses.parse( + custom_id="analysis_1", + model="gpt-4", + text_format=TaskAnalysis, + input="Analyze this task: Build a web scraper", + instructions="Provide structured analysis", + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert "text" in data["body"] + assert data["body"]["text"]["format"]["name"] == "TaskAnalysis" + assert "task_type" in str(data["body"]["text"]["format"]["schema"]) + + def test_batch_collector_chat_parse_structured_output(self, temp_batch_file): + class CodeReview(BaseModel): + issues: list[str] + suggestions: list[str] + rating: int + + collector = BatchCollector(temp_batch_file) + collector.chat.completions.parse( + custom_id="review_1", + model="gpt-4", + response_format=CodeReview, + messages=[ + {"role": "system", "content": "You are a code reviewer"}, + {"role": "user", "content": "Review this code: def foo(): pass"}, + ], + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert "response_format" in data["body"] + assert 
data["body"]["response_format"]["format"]["name"] == "CodeReview" diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..761d762 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,401 @@ +"""Integration tests that verify end-to-end workflows.""" +import json +import pytest +from pathlib import Path +from pydantic import BaseModel, Field +from openbatch import ( + BatchCollector, + BatchJobManager, + Message, + PromptTemplate, + PromptTemplateInputInstance, + EmbeddingInputInstance, + ResponsesRequest, + ChatCompletionsRequest, + EmbeddingsRequest, + ReasoningConfig, +) + + +@pytest.fixture +def temp_dir(tmp_path): + """Provides a temporary directory for test files.""" + return tmp_path + + +class TestEndToEndBatchCreation: + """Test complete workflows from creation to file output.""" + + def test_batch_collector_complete_workflow(self, temp_dir): + """Test BatchCollector API for all three endpoint types.""" + # Separate files for each API type (as required by OpenAI) + responses_file = temp_dir / "responses.jsonl" + chat_file = temp_dir / "chat.jsonl" + embeddings_file = temp_dir / "embeddings.jsonl" + + # Responses API + responses_collector = BatchCollector(responses_file) + responses_collector.responses.create( + custom_id="resp_1", + model="gpt-4", + input="What is machine learning?", + instructions="Be concise", + max_output_tokens=100, + ) + + # Chat Completions API + chat_collector = BatchCollector(chat_file) + chat_collector.chat.completions.create( + custom_id="chat_1", + model="gpt-4", + messages=[ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Explain quantum computing"}, + ], + temperature=0.7, + ) + + # Embeddings API + embeddings_collector = BatchCollector(embeddings_file) + embeddings_collector.embeddings.create( + custom_id="emb_1", + model="text-embedding-3-small", + inp="Machine learning is a subset of artificial intelligence", + ) + + # Verify all 
files exist and contain valid JSON + assert responses_file.exists() + assert chat_file.exists() + assert embeddings_file.exists() + + with open(responses_file, "r") as f: + resp_data = json.loads(f.readline()) + assert resp_data["url"] == "/v1/responses" + + with open(chat_file, "r") as f: + chat_data = json.loads(f.readline()) + assert chat_data["url"] == "/v1/chat/completions" + + with open(embeddings_file, "r") as f: + emb_data = json.loads(f.readline()) + assert emb_data["url"] == "/v1/embeddings" + + def test_batch_job_manager_templated_workflow(self, temp_dir): + """Test BatchJobManager with prompt templates for bulk generation.""" + batch_file = temp_dir / "marketing_batch.jsonl" + + # Setup + template = PromptTemplate( + messages=[ + Message( + role="system", + content="You are a marketing copywriter. Generate a catchy, two-sentence description.", + ), + Message( + role="user", + content="Product: {product_name}, Features: {features}", + ), + ] + ) + + common_config = ResponsesRequest( + model="gpt-4-mini", temperature=0.8, max_output_tokens=100 + ) + + products = [ + PromptTemplateInputInstance( + id="prod_001", + prompt_value_mapping={ + "product_name": "AeroGlide Drone", + "features": "4K camera, 30-min flight", + }, + ), + PromptTemplateInputInstance( + id="prod_002", + prompt_value_mapping={ + "product_name": "HydroPure Bottle", + "features": "Self-cleaning, insulated steel", + }, + ), + PromptTemplateInputInstance( + id="prod_003", + prompt_value_mapping={ + "product_name": "SmartDesk Pro", + "features": "Height adjustable, USB charging", + }, + instance_request_options={"temperature": 0.5}, # Override for this instance + ), + ] + + # Generate batch + manager = BatchJobManager() + manager.add_templated_instances( + prompt=template, + common_request=common_config, + input_instances=products, + save_file_path=batch_file, + ) + + # Verify + assert batch_file.exists() + with open(batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 3 + 
+ # Check first product + data1 = json.loads(lines[0]) + assert data1["custom_id"] == "prod_001" + assert "AeroGlide Drone" in str(data1["body"]["input"]) + assert data1["body"]["temperature"] == 0.8 + + # Check third product with overridden temperature + data3 = json.loads(lines[2]) + assert data3["custom_id"] == "prod_003" + assert data3["body"]["temperature"] == 0.5 + + def test_batch_job_manager_embeddings_workflow(self, temp_dir): + """Test BatchJobManager for bulk embeddings generation.""" + batch_file = temp_dir / "embeddings_batch.jsonl" + + common_config = EmbeddingsRequest( + model="text-embedding-3-small", dimensions=512 + ) + + documents = [ + EmbeddingInputInstance(id="doc_1", input="The sky is blue."), + EmbeddingInputInstance(id="doc_2", input="Grass is green."), + EmbeddingInputInstance(id="doc_3", input="Water is wet."), + EmbeddingInputInstance( + id="doc_4", + input="Fire is hot.", + instance_request_options={"dimensions": 256}, + ), + ] + + manager = BatchJobManager() + manager.add_embedding_requests( + inputs=documents, common_request=common_config, save_file_path=batch_file + ) + + assert batch_file.exists() + with open(batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 4 + + data1 = json.loads(lines[0]) + assert data1["body"]["dimensions"] == 512 + + data4 = json.loads(lines[3]) + assert data4["body"]["dimensions"] == 256 # Overridden + + +class TestStructuredOutputWorkflows: + """Test workflows with structured JSON output.""" + + def test_responses_api_with_structured_output(self, temp_dir): + """Test Responses API with Pydantic model for structured output.""" + + class SentimentAnalysis(BaseModel): + sentiment: str = Field(description="positive, negative, or neutral") + confidence: float = Field(description="Confidence score 0-1") + key_phrases: list[str] = Field(description="Important phrases") + + batch_file = temp_dir / "sentiment_batch.jsonl" + collector = BatchCollector(batch_file) + + texts_to_analyze = [ + "This 
product exceeded my expectations! Absolutely love it.", + "Terrible experience. Would not recommend to anyone.", + "It's okay, nothing special but does the job.", + ] + + for idx, text in enumerate(texts_to_analyze): + collector.responses.parse( + custom_id=f"sentiment_{idx}", + model="gpt-4", + text_format=SentimentAnalysis, + input=text, + instructions="Analyze the sentiment of the given text", + ) + + assert batch_file.exists() + with open(batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 3 + + # Verify structured output configuration + data = json.loads(lines[0]) + assert "text" in data["body"] + assert data["body"]["text"]["format"]["name"] == "SentimentAnalysis" + assert data["body"]["text"]["format"]["strict"] is True + schema = data["body"]["text"]["format"]["schema"] + assert "sentiment" in schema["properties"] + assert "confidence" in schema["properties"] + assert "key_phrases" in schema["properties"] + + def test_chat_completions_with_structured_output(self, temp_dir): + """Test Chat Completions API with structured output.""" + + class RecipeExtraction(BaseModel): + recipe_name: str + ingredients: list[str] + steps: list[str] + prep_time_minutes: int + difficulty: str + + batch_file = temp_dir / "recipes_batch.jsonl" + collector = BatchCollector(batch_file) + + recipe_texts = [ + "How to make scrambled eggs: Beat 2 eggs, heat pan, cook for 2 minutes. Takes 5 minutes. Easy.", + "Chocolate cake recipe: Mix flour, sugar, cocoa. Bake at 350F for 30 minutes. Medium difficulty. 
Takes 45 minutes.", + ] + + for idx, text in enumerate(recipe_texts): + collector.chat.completions.parse( + custom_id=f"recipe_{idx}", + model="gpt-4", + response_format=RecipeExtraction, + messages=[ + { + "role": "system", + "content": "Extract structured recipe information", + }, + {"role": "user", "content": text}, + ], + ) + + with open(batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 2 + data = json.loads(lines[0]) + assert "response_format" in data["body"] + assert data["body"]["response_format"]["format"]["name"] == "RecipeExtraction" + + +class TestReasoningModelsWorkflow: + """Test workflows with reasoning models.""" + + def test_responses_api_with_reasoning_config(self, temp_dir): + """Test Responses API with reasoning configuration.""" + batch_file = temp_dir / "reasoning_batch.jsonl" + collector = BatchCollector(batch_file) + + complex_problems = [ + { + "id": "logic_1", + "problem": "If all A are B, and all B are C, are all A necessarily C?", + "effort": "high", + }, + { + "id": "logic_2", + "problem": "What is the flaw in this argument: All birds can fly. Penguins are birds. 
Therefore penguins can fly.", + "effort": "medium", + }, + ] + + for item in complex_problems: + collector.responses.create( + custom_id=item["id"], + model="o1-mini", + input=item["problem"], + instructions="Analyze the logical structure carefully", + reasoning=ReasoningConfig(effort=item["effort"], summary="detailed"), + ) + + with open(batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + assert data1["body"]["reasoning"]["effort"] == "high" + assert data1["body"]["reasoning"]["summary"] == "detailed" + + +class TestUnicodeAndSpecialCharacters: + """Test handling of non-ASCII characters.""" + + def test_ensure_ascii_true(self, temp_dir): + """Test that ensure_ascii=True escapes non-ASCII characters.""" + batch_file = temp_dir / "unicode_escaped.jsonl" + manager = BatchJobManager(ensure_ascii=True) + + request = ResponsesRequest(model="gpt-4", input="Hello 世界! Привет! مرحبا") + manager.add("unicode_test", request, batch_file) + + with open(batch_file, "r", encoding="utf-8") as f: + content = f.read() + + # Should contain escaped unicode + assert "\\u" in content + # Raw characters should not be present + assert "世界" not in content + + def test_ensure_ascii_false(self, temp_dir): + """Test that ensure_ascii=False preserves non-ASCII characters.""" + batch_file = temp_dir / "unicode_raw.jsonl" + manager = BatchJobManager(ensure_ascii=False) + + request = ResponsesRequest( + model="gpt-4", input="Hello 世界! Привет! 
مرحبا Emoji: 🚀" + ) + manager.add("unicode_test", request, batch_file) + + with open(batch_file, "r", encoding="utf-8") as f: + content = f.read() + + # Raw characters should be present + assert "世界" in content + assert "Привет" in content + assert "مرحبا" in content + assert "🚀" in content + + +class TestLargeScaleBatchGeneration: + """Test generation of large batch files.""" + + def test_generate_1000_requests(self, temp_dir): + """Test generating a batch file with 1000 requests.""" + batch_file = temp_dir / "large_batch.jsonl" + + template = PromptTemplate( + messages=[Message(role="user", content="Classify: {text}")] + ) + + common_request = ResponsesRequest(model="gpt-4-mini", max_output_tokens=10) + + # Generate 1000 instances + instances = [ + PromptTemplateInputInstance( + id=f"classify_{i:04d}", prompt_value_mapping={"text": f"Sample text {i}"} + ) + for i in range(1000) + ] + + manager = BatchJobManager() + manager.add_templated_instances( + prompt=template, + common_request=common_request, + input_instances=instances, + save_file_path=batch_file, + ) + + # Verify + assert batch_file.exists() + with open(batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 1000 + + # Spot check first and last + first = json.loads(lines[0]) + last = json.loads(lines[999]) + + assert first["custom_id"] == "classify_0000" + assert last["custom_id"] == "classify_0999" + assert "Sample text 0" in str(first["body"]["input"]) + assert "Sample text 999" in str(last["body"]["input"]) diff --git a/tests/test_manager.py b/tests/test_manager.py new file mode 100644 index 0000000..4c335a4 --- /dev/null +++ b/tests/test_manager.py @@ -0,0 +1,395 @@ +import json +import pytest +import warnings +from pathlib import Path +from openbatch.manager import BatchJobManager +from openbatch.model import ( + Message, + PromptTemplate, + ReusablePrompt, + PromptTemplateInputInstance, + EmbeddingInputInstance, + ResponsesRequest, + ChatCompletionsRequest, + EmbeddingsRequest, +) + 
+ +@pytest.fixture +def temp_batch_file(tmp_path): + """Provides a temporary file path for batch files.""" + return tmp_path / "test_batch.jsonl" + + +@pytest.fixture +def manager(): + """Provides a BatchJobManager instance.""" + return BatchJobManager() + + +@pytest.fixture +def manager_no_ascii(): + """Provides a BatchJobManager instance with ensure_ascii=False.""" + return BatchJobManager(ensure_ascii=False) + + +class TestBatchJobManagerAdd: + def test_add_responses_request(self, manager, temp_batch_file): + request = ResponsesRequest(model="gpt-4", input="Hello world") + manager.add("test_id", request, temp_batch_file) + + assert temp_batch_file.exists() + with open(temp_batch_file, "r") as f: + line = f.readline() + data = json.loads(line) + + assert data["custom_id"] == "test_id" + assert data["method"] == "POST" + assert data["url"] == "/v1/responses" + assert data["body"]["model"] == "gpt-4" + assert data["body"]["input"] == "Hello world" + + def test_add_chat_completions_request(self, manager, temp_batch_file): + request = ChatCompletionsRequest( + model="gpt-4", messages=[{"role": "user", "content": "Hi"}] + ) + manager.add("chat_id", request, temp_batch_file) + + assert temp_batch_file.exists() + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "chat_id" + assert data["url"] == "/v1/chat/completions" + assert data["body"]["messages"][0]["role"] == "user" + + def test_add_embeddings_request(self, manager, temp_batch_file): + request = EmbeddingsRequest(model="text-embedding-3-small", input="Text to embed") + manager.add("emb_id", request, temp_batch_file) + + assert temp_batch_file.exists() + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "emb_id" + assert data["url"] == "/v1/embeddings" + assert data["body"]["input"] == "Text to embed" + + def test_add_multiple_requests(self, manager, temp_batch_file): + request1 = 
ResponsesRequest(model="gpt-4", input="First") + request2 = ResponsesRequest(model="gpt-4", input="Second") + + manager.add("id1", request1, temp_batch_file) + manager.add("id2", request2, temp_batch_file) + + with open(temp_batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + assert data1["custom_id"] == "id1" + assert data2["custom_id"] == "id2" + + def test_add_responses_request_without_input_or_prompt_raises( + self, manager, temp_batch_file + ): + request = ResponsesRequest(model="gpt-4") + with pytest.raises(ValueError, match="must define either an input or a prompt"): + manager.add("test_id", request, temp_batch_file) + + def test_add_chat_completions_request_without_messages_raises( + self, manager, temp_batch_file + ): + # messages is required in the Pydantic model, so we create a request + # and then set messages to None manually + request = ChatCompletionsRequest(model="gpt-4", messages=[]) + request.messages = None + with pytest.raises(ValueError, match="must define messages"): + manager.add("test_id", request, temp_batch_file) + + def test_add_embeddings_request_without_input_raises(self, manager, temp_batch_file): + # input is required in the Pydantic model, so we create a request + # and then set input to None manually + request = EmbeddingsRequest(model="text-embedding-3-small", input="dummy") + request.input = None + with pytest.raises(ValueError, match="must define an input"): + manager.add("test_id", request, temp_batch_file) + + def test_add_creates_parent_directory(self, manager, tmp_path): + nested_path = tmp_path / "subdir" / "batch.jsonl" + request = ResponsesRequest(model="gpt-4", input="Test") + manager.add("test_id", request, nested_path) + + assert nested_path.exists() + assert nested_path.parent.exists() + + def test_add_with_ensure_ascii_false(self, manager_no_ascii, temp_batch_file): + request = ResponsesRequest(model="gpt-4", input="Hello 世界") + 
manager_no_ascii.add("test_id", request, temp_batch_file) + + with open(temp_batch_file, "r", encoding="utf-8") as f: + content = f.read() + data = json.loads(content) + + assert "世界" in content # Non-ASCII characters preserved + assert data["body"]["input"] == "Hello 世界" + + def test_add_with_ensure_ascii_true(self, manager, temp_batch_file): + request = ResponsesRequest(model="gpt-4", input="Hello 世界") + manager.add("test_id", request, temp_batch_file) + + with open(temp_batch_file, "r", encoding="utf-8") as f: + raw_content = f.read() + + # ASCII escaped version should not contain the raw unicode characters + assert "\\u" in raw_content + + +class TestBatchJobManagerTemplatedInstances: + def test_add_templated_instances_responses_api(self, manager, temp_batch_file): + template = PromptTemplate( + messages=[Message(role="user", content="Product: {product}, Price: {price}")] + ) + common_request = ResponsesRequest(model="gpt-4", temperature=0.7) + instances = [ + PromptTemplateInputInstance( + id="prod_1", prompt_value_mapping={"product": "Laptop", "price": "$1000"} + ), + PromptTemplateInputInstance( + id="prod_2", prompt_value_mapping={"product": "Mouse", "price": "$20"} + ), + ] + + manager.add_templated_instances(template, common_request, instances, temp_batch_file) + + with open(temp_batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + + assert data1["custom_id"] == "prod_1" + assert "Laptop" in str(data1["body"]["input"]) + assert data2["custom_id"] == "prod_2" + assert "Mouse" in str(data2["body"]["input"]) + + def test_add_templated_instances_chat_completions_api(self, manager, temp_batch_file): + template = PromptTemplate( + messages=[ + Message(role="system", content="You are a {role}"), + Message(role="user", content="{question}"), + ] + ) + common_request = ChatCompletionsRequest(model="gpt-4", temperature=0.5) + instances = [ + PromptTemplateInputInstance( + id="q1", 
+ prompt_value_mapping={"role": "teacher", "question": "What is math?"}, + ), + PromptTemplateInputInstance( + id="q2", prompt_value_mapping={"role": "chef", "question": "How to cook?"} + ), + ] + + manager.add_templated_instances(template, common_request, instances, temp_batch_file) + + with open(temp_batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + + assert data1["body"]["messages"][0]["content"] == "You are a teacher" + assert data1["body"]["messages"][1]["content"] == "What is math?" + + def test_add_templated_instances_with_reusable_prompt(self, manager, temp_batch_file): + reusable_prompt = ReusablePrompt(id="prompt_123", version="v1", variables={}) + common_request = ResponsesRequest(model="gpt-4") + instances = [ + PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={"var": "value"}) + ] + + manager.add_templated_instances( + reusable_prompt, common_request, instances, temp_batch_file + ) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert data["body"]["prompt"]["id"] == "prompt_123" + assert data["body"]["prompt"]["version"] == "v1" + assert data["body"]["prompt"]["variables"]["var"] == "value" + + def test_add_templated_instances_with_instance_options(self, manager, temp_batch_file): + template = PromptTemplate(messages=[Message(role="user", content="{text}")]) + common_request = ResponsesRequest(model="gpt-4", temperature=0.7) + instances = [ + PromptTemplateInputInstance( + id="inst_1", + prompt_value_mapping={"text": "Hello"}, + instance_request_options={"temperature": 0.9}, + ), + PromptTemplateInputInstance( + id="inst_2", prompt_value_mapping={"text": "World"} + ), + ] + + manager.add_templated_instances(template, common_request, instances, temp_batch_file) + + with open(temp_batch_file, "r") as f: + lines = f.readlines() + + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + + # First instance should override temperature + assert 
data1["body"]["temperature"] == 0.9 + # Second instance should use common temperature + assert data2["body"]["temperature"] == 0.7 + + def test_add_templated_instances_with_embeddings_raises(self, manager, temp_batch_file): + template = PromptTemplate(messages=[Message(role="user", content="Test")]) + common_request = EmbeddingsRequest( + model="text-embedding-3-small", input="dummy" + ) + instances = [ + PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={"text": "Test"}) + ] + + with pytest.raises(ValueError, match="Embeddings API is not supported"): + manager.add_templated_instances( + template, common_request, instances, temp_batch_file + ) + + def test_add_templated_instances_reusable_prompt_with_chat_raises( + self, manager, temp_batch_file + ): + reusable_prompt = ReusablePrompt(id="prompt_123", version="v1", variables={}) + common_request = ChatCompletionsRequest(model="gpt-4", messages=[]) + instances = [ + PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={"var": "value"}) + ] + + with pytest.raises(ValueError, match="Reusable prompts can only be used"): + manager.add_templated_instances( + reusable_prompt, common_request, instances, temp_batch_file + ) + + def test_add_templated_instances_appending_warning(self, manager, temp_batch_file): + # Create the file first + temp_batch_file.write_text("existing content\n") + + template = PromptTemplate(messages=[Message(role="user", content="Test")]) + common_request = ResponsesRequest(model="gpt-4") + instances = [ + PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={}) + ] + + # Should warn when appending to existing file + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + manager.add_templated_instances( + template, common_request, instances, temp_batch_file + ) + assert len(w) == 1 + assert "already exists" in str(w[0].message) + + def test_add_templated_instances_suppress_warnings(self, manager, temp_batch_file): + 
temp_batch_file.write_text("existing content\n") + + template = PromptTemplate(messages=[Message(role="user", content="Test")]) + common_request = ResponsesRequest(model="gpt-4") + instances = [ + PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={}) + ] + + # Should not warn when suppress_warnings=True + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + manager.add_templated_instances( + template, + common_request, + instances, + temp_batch_file, + suppress_warnings=True, + ) + assert len(w) == 0 + + +class TestBatchJobManagerEmbeddingRequests: + def test_add_embedding_requests(self, manager, temp_batch_file): + common_request = EmbeddingsRequest( + model="text-embedding-3-small", dimensions=512 + ) + inputs = [ + EmbeddingInputInstance(id="emb_1", input="First text"), + EmbeddingInputInstance(id="emb_2", input="Second text"), + ] + + manager.add_embedding_requests(inputs, common_request, temp_batch_file) + + with open(temp_batch_file, "r") as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + + assert data1["custom_id"] == "emb_1" + assert data1["body"]["input"] == "First text" + assert data1["body"]["dimensions"] == 512 + + assert data2["custom_id"] == "emb_2" + assert data2["body"]["input"] == "Second text" + + def test_add_embedding_requests_with_list_input(self, manager, temp_batch_file): + common_request = EmbeddingsRequest(model="text-embedding-3-small") + inputs = [ + EmbeddingInputInstance(id="emb_1", input=["Text 1", "Text 2"]), + ] + + manager.add_embedding_requests(inputs, common_request, temp_batch_file) + + with open(temp_batch_file, "r") as f: + data = json.loads(f.readline()) + + assert isinstance(data["body"]["input"], list) + assert len(data["body"]["input"]) == 2 + + def test_add_embedding_requests_with_instance_options(self, manager, temp_batch_file): + common_request = EmbeddingsRequest( + model="text-embedding-3-small", 
dimensions=512 + ) + inputs = [ + EmbeddingInputInstance( + id="emb_1", + input="First", + instance_request_options={"dimensions": 256}, + ), + EmbeddingInputInstance(id="emb_2", input="Second"), + ] + + manager.add_embedding_requests(inputs, common_request, temp_batch_file) + + with open(temp_batch_file, "r") as f: + lines = f.readlines() + + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + + # First instance should override dimensions + assert data1["body"]["dimensions"] == 256 + # Second instance should use common dimensions + assert data2["body"]["dimensions"] == 512 + + def test_add_embedding_requests_creates_parent_directory(self, manager, tmp_path): + nested_path = tmp_path / "subdir" / "embeddings.jsonl" + common_request = EmbeddingsRequest(model="text-embedding-3-small") + inputs = [EmbeddingInputInstance(id="emb_1", input="Test")] + + manager.add_embedding_requests(inputs, common_request, nested_path) + + assert nested_path.exists() + assert nested_path.parent.exists() diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..d0f4962 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,258 @@ +import pytest +from pydantic import BaseModel, Field +from openbatch.model import ( + Message, + PromptTemplate, + ReusablePrompt, + ReasoningConfig, + PromptTemplateInputInstance, + MessagesInputInstance, + EmbeddingInputInstance, + ResponsesRequest, + ChatCompletionsRequest, + EmbeddingsRequest, + ResponsesAPIStrategy, + ChatCompletionsAPIStrategy, + EmbeddingsAPIStrategy, +) + + +class TestMessage: + def test_message_creation(self): + msg = Message(role="user", content="Hello") + assert msg.role == "user" + assert msg.content == "Hello" + + def test_message_serialize(self): + msg = Message(role="system", content="You are helpful") + serialized = msg.serialize() + assert serialized == {"role": "system", "content": "You are helpful"} + + +class TestPromptTemplate: + def test_prompt_template_creation(self): + template = 
PromptTemplate( + messages=[ + Message(role="system", content="You are a {role}"), + Message(role="user", content="Help me with {task}"), + ] + ) + assert len(template.messages) == 2 + + def test_prompt_template_format(self): + template = PromptTemplate( + messages=[ + Message(role="system", content="You are a {role}"), + Message(role="user", content="Help me with {task}"), + ] + ) + formatted = template.format(role="assistant", task="coding") + assert len(formatted) == 2 + assert formatted[0].content == "You are a assistant" + assert formatted[1].content == "Help me with coding" + + def test_prompt_template_format_multiple_placeholders(self): + template = PromptTemplate( + messages=[ + Message( + role="user", + content="Product: {product}, Price: {price}, Category: {category}", + ) + ] + ) + formatted = template.format(product="Laptop", price="$1000", category="Electronics") + assert formatted[0].content == "Product: Laptop, Price: $1000, Category: Electronics" + + +class TestReusablePrompt: + def test_reusable_prompt_creation(self): + prompt = ReusablePrompt(id="prompt_123", version="v1", variables={"name": "John"}) + assert prompt.id == "prompt_123" + assert prompt.version == "v1" + assert prompt.variables == {"name": "John"} + + +class TestReasoningConfig: + def test_reasoning_config_default(self): + config = ReasoningConfig() + assert config.effort == "medium" + assert config.summary is None + + def test_reasoning_config_custom(self): + config = ReasoningConfig(effort="high", summary="detailed") + assert config.effort == "high" + assert config.summary == "detailed" + + +class TestInputInstances: + def test_prompt_template_input_instance(self): + instance = PromptTemplateInputInstance( + id="inst_1", + prompt_value_mapping={"name": "Alice", "age": "30"}, + instance_request_options={"temperature": 0.5}, + ) + assert instance.id == "inst_1" + assert instance.prompt_value_mapping == {"name": "Alice", "age": "30"} + assert instance.instance_request_options == 
{"temperature": 0.5} + + def test_messages_input_instance(self): + messages = [Message(role="user", content="Hello")] + instance = MessagesInputInstance(id="inst_2", messages=messages) + assert instance.id == "inst_2" + assert len(instance.messages) == 1 + + def test_embedding_input_instance(self): + instance = EmbeddingInputInstance(id="emb_1", input="Text to embed") + assert instance.id == "emb_1" + assert instance.input == "Text to embed" + + def test_embedding_input_instance_list(self): + instance = EmbeddingInputInstance(id="emb_2", input=["Text 1", "Text 2"]) + assert instance.id == "emb_2" + assert isinstance(instance.input, list) + assert len(instance.input) == 2 + + +class TestAPIStrategies: + def test_responses_api_strategy(self): + strategy = ResponsesAPIStrategy() + assert strategy.url == "/v1/responses" + request = strategy.create_request("test_id", {"model": "gpt-4"}) + assert request["custom_id"] == "test_id" + assert request["method"] == "POST" + assert request["url"] == "/v1/responses" + assert request["body"] == {"model": "gpt-4"} + + def test_chat_completions_api_strategy(self): + strategy = ChatCompletionsAPIStrategy() + assert strategy.url == "/v1/chat/completions" + + def test_embeddings_api_strategy(self): + strategy = EmbeddingsAPIStrategy() + assert strategy.url == "/v1/embeddings" + + +class TestResponsesRequest: + def test_responses_request_minimal(self): + request = ResponsesRequest(model="gpt-4") + assert request.model == "gpt-4" + assert request.input is None + + def test_responses_request_with_input(self): + request = ResponsesRequest(model="gpt-4", input="Hello world") + assert request.input == "Hello world" + + def test_responses_request_to_dict(self): + request = ResponsesRequest(model="gpt-4", input="Hello", temperature=0.7) + result = request.to_dict() + assert result["model"] == "gpt-4" + assert result["input"] == "Hello" + assert result["temperature"] == 0.7 + + def test_responses_request_exclude_none(self): + request = 
ResponsesRequest(model="gpt-4", input="Hello") + result = request.to_dict() + assert "temperature" not in result + assert "max_output_tokens" not in result + + def test_responses_request_set_input_messages(self): + request = ResponsesRequest(model="gpt-4") + messages = [Message(role="user", content="Hello")] + request.set_input_messages(messages) + assert request.input == [{"role": "user", "content": "Hello"}] + + def test_responses_request_set_output_structure(self): + class TestOutput(BaseModel): + name: str + age: int + + request = ResponsesRequest(model="gpt-4") + request.set_output_structure(TestOutput) + assert request.text is not None + assert "format" in request.text + assert request.text["format"]["type"] == "json_schema" + assert request.text["format"]["name"] == "TestOutput" + assert request.text["format"]["strict"] is True + + def test_responses_request_with_reasoning(self): + request = ResponsesRequest( + model="gpt-4", reasoning=ReasoningConfig(effort="high", summary="detailed") + ) + assert request.reasoning.effort == "high" + assert request.reasoning.summary == "detailed" + + +class TestChatCompletionsRequest: + def test_chat_completions_request_minimal(self): + messages = [{"role": "user", "content": "Hello"}] + request = ChatCompletionsRequest(model="gpt-4", messages=messages) + assert request.model == "gpt-4" + assert len(request.messages) == 1 + + def test_chat_completions_request_set_input_messages(self): + request = ChatCompletionsRequest(model="gpt-4", messages=[]) + messages = [Message(role="user", content="Hi")] + request.set_input_messages(messages) + assert request.messages == [{"role": "user", "content": "Hi"}] + + def test_chat_completions_request_set_output_structure(self): + class TestResponse(BaseModel): + answer: str = Field(description="The answer") + + request = ChatCompletionsRequest(model="gpt-4", messages=[]) + request.set_output_structure(TestResponse) + assert request.response_format is not None + assert "format" in 
request.response_format + assert request.response_format["format"]["name"] == "TestResponse" + + def test_chat_completions_request_with_temperature(self): + request = ChatCompletionsRequest( + model="gpt-4", messages=[{"role": "user", "content": "Hi"}], temperature=0.9 + ) + assert request.temperature == 0.9 + + def test_chat_completions_request_to_dict(self): + request = ChatCompletionsRequest( + model="gpt-4", + messages=[{"role": "user", "content": "Hi"}], + temperature=0.5, + max_completion_tokens=100, + ) + result = request.to_dict() + assert result["model"] == "gpt-4" + assert result["temperature"] == 0.5 + assert result["max_completion_tokens"] == 100 + + +class TestEmbeddingsRequest: + def test_embeddings_request_with_string(self): + request = EmbeddingsRequest(model="text-embedding-3-small", input="Hello") + assert request.model == "text-embedding-3-small" + assert request.input == "Hello" + + def test_embeddings_request_with_list(self): + request = EmbeddingsRequest( + model="text-embedding-3-small", input=["Hello", "World"] + ) + assert isinstance(request.input, list) + assert len(request.input) == 2 + + def test_embeddings_request_set_input(self): + request = EmbeddingsRequest(model="text-embedding-3-small", input="test") + request.set_input("New text") + assert request.input == "New text" + + def test_embeddings_request_with_dimensions(self): + request = EmbeddingsRequest( + model="text-embedding-3-small", input="test", dimensions=512 + ) + assert request.dimensions == 512 + + def test_embeddings_request_to_dict(self): + request = EmbeddingsRequest( + model="text-embedding-3-small", input="test", dimensions=256 + ) + result = request.to_dict() + assert result["model"] == "text-embedding-3-small" + assert result["input"] == "test" + assert result["dimensions"] == 256 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..2766e6d --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,266 @@ +import pytest +from pydantic 
import BaseModel, Field +from typing import Optional, List +from openbatch._utils import ( + type_to_json_schema, + _ensure_strict_json_schema, + has_more_than_n_keys, + resolve_ref, +) + + +class TestHasMoreThanNKeys: + def test_empty_dict(self): + assert has_more_than_n_keys({}, 0) is False + assert has_more_than_n_keys({}, 1) is False + + def test_single_key(self): + assert has_more_than_n_keys({"a": 1}, 0) is True + assert has_more_than_n_keys({"a": 1}, 1) is False + assert has_more_than_n_keys({"a": 1}, 2) is False + + def test_multiple_keys(self): + obj = {"a": 1, "b": 2, "c": 3} + assert has_more_than_n_keys(obj, 0) is True + assert has_more_than_n_keys(obj, 1) is True + assert has_more_than_n_keys(obj, 2) is True + assert has_more_than_n_keys(obj, 3) is False + + +class TestResolveRef: + def test_resolve_simple_ref(self): + root = {"definitions": {"Person": {"type": "object"}}} + result = resolve_ref(root=root, ref="#/definitions/Person") + assert result == {"type": "object"} + + def test_resolve_nested_ref(self): + root = { + "definitions": {"Nested": {"properties": {"field": {"type": "string"}}}} + } + result = resolve_ref(root=root, ref="#/definitions/Nested/properties/field") + assert result == {"type": "string"} + + def test_resolve_ref_invalid_format(self): + root = {"definitions": {}} + with pytest.raises(ValueError, match="Does not start with #/"): + resolve_ref(root=root, ref="definitions/Person") + + +class TestEnsureStrictJsonSchema: + def test_object_adds_additional_properties_false(self): + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert result["additionalProperties"] is False + + def test_object_preserves_existing_additional_properties(self): + schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + "additionalProperties": True, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert 
result["additionalProperties"] is True
+
+    def test_object_makes_all_properties_required(self):
+        schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+        }
+        result = _ensure_strict_json_schema(schema, path=(), root=schema)
+        assert "required" in result
+        assert set(result["required"]) == {"name", "age"}
+
+    def test_nested_object_properties(self):
+        schema = {
+            "type": "object",
+            "properties": {
+                "person": {
+                    "type": "object",
+                    "properties": {"name": {"type": "string"}},
+                }
+            },
+        }
+        result = _ensure_strict_json_schema(schema, path=(), root=schema)
+        assert result["properties"]["person"]["additionalProperties"] is False
+        assert result["properties"]["person"]["required"] == ["name"]
+
+    def test_array_items(self):
+        schema = {"type": "array", "items": {"type": "string"}}
+        result = _ensure_strict_json_schema(schema, path=(), root=schema)
+        assert result["items"]["type"] == "string"
+
+    def test_array_with_object_items(self):
+        schema = {
+            "type": "array",
+            "items": {"type": "object", "properties": {"id": {"type": "integer"}}},
+        }
+        result = _ensure_strict_json_schema(schema, path=(), root=schema)
+        assert result["items"]["additionalProperties"] is False
+        assert result["items"]["required"] == ["id"]
+
+    def test_any_of_union(self):
+        schema = {
+            "anyOf": [{"type": "string"}, {"type": "object", "properties": {"a": {"type": "string"}}}]
+        }
+        result = _ensure_strict_json_schema(schema, path=(), root=schema)
+        assert len(result["anyOf"]) == 2
+        assert result["anyOf"][1]["additionalProperties"] is False
+
+    def test_all_of_single_element_processed(self):
+        # Test that each entry of a multi-element allOf is processed.
+        # Note: single-element allOf isn't necessarily unwrapped by the
+        # current implementation, so the multi-element path is exercised here.
+        schema = {
+            "allOf": [
+                {"type": "object", "properties": {"name": {"type": "string"}}},
+                {"type": "object", "properties": {"age": {"type": "integer"}}},
+            ]
+ } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + # Verify allOf entries are processed + assert "allOf" in result + assert len(result["allOf"]) == 2 + # Each entry should have properties from original schema + assert result["allOf"][0]["type"] == "object" + assert result["allOf"][1]["type"] == "object" + + def test_definitions_processed(self): + schema = { + "type": "object", + "properties": {"user": {"$ref": "#/definitions/User"}}, + "definitions": { + "User": {"type": "object", "properties": {"name": {"type": "string"}}} + }, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert result["definitions"]["User"]["additionalProperties"] is False + assert result["definitions"]["User"]["required"] == ["name"] + + def test_defs_processed(self): + schema = { + "type": "object", + "properties": {"user": {"$ref": "#/$defs/User"}}, + "$defs": { + "User": {"type": "object", "properties": {"name": {"type": "string"}}} + }, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert result["$defs"]["User"]["additionalProperties"] is False + assert result["$defs"]["User"]["required"] == ["name"] + + def test_ref_with_additional_properties_unrolled(self): + schema = { + "type": "object", + "properties": { + "user": { + "$ref": "#/definitions/User", + "description": "The user object", + } + }, + "definitions": { + "User": {"type": "object", "properties": {"name": {"type": "string"}}} + }, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + # The $ref should be unrolled when there are additional properties + user_prop = result["properties"]["user"] + assert "$ref" not in user_prop + assert user_prop["type"] == "object" + assert user_prop["description"] == "The user object" + assert user_prop["additionalProperties"] is False + + +class TestTypeToJsonSchema: + def test_simple_model(self): + class SimpleModel(BaseModel): + name: str + age: int + + schema = type_to_json_schema(SimpleModel) + 
assert schema["type"] == "object" + assert "name" in schema["properties"] + assert "age" in schema["properties"] + assert schema["additionalProperties"] is False + assert set(schema["required"]) == {"name", "age"} + + def test_model_with_optional_field(self): + class ModelWithOptional(BaseModel): + name: str + nickname: Optional[str] = None + + schema = type_to_json_schema(ModelWithOptional) + # All properties should be required in strict mode + assert set(schema["required"]) == {"name", "nickname"} + + def test_model_with_nested_object(self): + class Address(BaseModel): + street: str + city: str + + class Person(BaseModel): + name: str + address: Address + + schema = type_to_json_schema(Person) + assert schema["additionalProperties"] is False + assert "address" in schema["properties"] + + # Check nested object is also strict + if "$defs" in schema: + address_schema = schema["$defs"]["Address"] + else: + address_schema = schema["definitions"]["Address"] + + assert address_schema["additionalProperties"] is False + assert set(address_schema["required"]) == {"street", "city"} + + def test_model_with_list_field(self): + class TodoList(BaseModel): + title: str + items: List[str] + + schema = type_to_json_schema(TodoList) + assert schema["properties"]["items"]["type"] == "array" + assert schema["properties"]["items"]["items"]["type"] == "string" + + def test_model_with_field_descriptions(self): + class DescribedModel(BaseModel): + name: str = Field(description="The person's name") + age: int = Field(description="The person's age") + + schema = type_to_json_schema(DescribedModel) + assert schema["properties"]["name"]["description"] == "The person's name" + assert schema["properties"]["age"]["description"] == "The person's age" + + def test_complex_nested_model(self): + class Item(BaseModel): + id: int + name: str + + class Order(BaseModel): + order_id: str + items: List[Item] + total: float + + class Customer(BaseModel): + name: str + orders: List[Order] + + schema = 
type_to_json_schema(Customer) + assert schema["additionalProperties"] is False + assert "orders" in schema["properties"] + + # Verify all nested models have strict schema + defs_key = "$defs" if "$defs" in schema else "definitions" + assert schema[defs_key]["Order"]["additionalProperties"] is False + assert schema[defs_key]["Item"]["additionalProperties"] is False + + def test_model_preserves_constraints(self): + class ConstrainedModel(BaseModel): + age: int = Field(ge=0, le=120) + email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$") + + schema = type_to_json_schema(ConstrainedModel) + assert schema["properties"]["age"]["minimum"] == 0 + assert schema["properties"]["age"]["maximum"] == 120 + assert "pattern" in schema["properties"]["email"] From e465a427e78b470aa7c0ef78d19ab33c3246d77a Mon Sep 17 00:00:00 2001 From: Daniel Gomm Date: Mon, 9 Feb 2026 15:11:43 +0100 Subject: [PATCH 3/9] :sparkles: add possibility to validate batch-job files --- README.md | 14 ++ docs/index.md | 2 +- docs/validation.md | 135 ++++++++++++ mkdocs.yml | 1 + src/openbatch/__init__.py | 2 + src/openbatch/validation.py | 349 +++++++++++++++++++++++++++++ tests/test_collector.py | 1 - tests/test_integration.py | 1 - tests/test_manager.py | 1 - tests/test_validation.py | 427 ++++++++++++++++++++++++++++++++++++ 10 files changed, 929 insertions(+), 4 deletions(-) create mode 100644 docs/validation.md create mode 100644 src/openbatch/validation.py create mode 100644 tests/test_validation.py diff --git a/README.md b/README.md index 6bae5fb..cf2127a 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,20 @@ For detailed instructions on these steps, please refer to the **[Official OpenAI ----- +## Validation + +OpenBatch includes built-in [validation](https://openbatch.daniel-gomm.com/validation/) to catch errors before uploading to OpenAI. + +```python +from openbatch import validate_batch_file + +result = validate_batch_file("my_batch.jsonl") +if result.is_valid: + print(f"Valid! 
{result.stats['total_requests']} requests") +``` + +----- + ## Testing OpenBatch includes a comprehensive test suite. diff --git a/docs/index.md b/docs/index.md index daa63b3..15b7708 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,7 +2,7 @@ **OpenBatch** is a lightweight Python utility designed to streamline the creation of JSONL files for the [OpenAI Batch API](https://platform.openai.com/docs/guides/batch). It provides a type-safe and intuitive interface using Pydantic models to construct requests for various endpoints, including `/v1/chat/completions`, `/v1/embeddings`, and the new `/v1/responses` endpoint. -## Key Features ✨ +## Key Features - **Type-Safe & Modern**: Built with Pydantic for robust, self-documenting, and editor-friendly request models. - **Dual APIs for Flexibility**: diff --git a/docs/validation.md b/docs/validation.md new file mode 100644 index 0000000..b7642a9 --- /dev/null +++ b/docs/validation.md @@ -0,0 +1,135 @@ +# Batch File Validation + +OpenBatch includes comprehensive validation to catch errors before uploading batch files to OpenAI, saving time and preventing failed uploads. 
+ +## Overview + +The validation module checks your batch files for: + +- ✓ Valid JSONL format +- ✓ Unique `custom_id` values +- ✓ Required fields present +- ✓ Valid HTTP methods (POST) +- ✓ Valid endpoint URLs +- ✓ Correct request body structure for each API type +- ✓ File size limits (200 MB) +- ✓ Request count limits (50,000) +- ⚠ Mixed endpoint types warning + +## Simple Validation + +```python +from openbatch import validate_batch_file + +result = validate_batch_file("my_batch.jsonl") + +if result.is_valid: + print(f"✓ Batch file is valid!") + print(f"Total requests: {result.stats['total_requests']}") + # Proceed with upload to OpenAI +else: + print(f"✗ Validation failed:") + print(result) # Shows detailed errors and warnings +``` + +### Quick Boolean Check + +For a fast True/False check: + +```python +from openbatch.validation import quick_validate + +if quick_validate("my_batch.jsonl"): + # File is valid, proceed + upload_to_openai("my_batch.jsonl") +``` + +## Validation Result + +The `ValidationResult` object provides detailed information: + +```python +result = validate_batch_file("batch.jsonl") + +# Check validity +if result.is_valid: + # Access statistics + print(f"Requests: {result.stats['total_requests']}") + print(f"File size: {result.stats['file_size_mb']} MB") + print(f"Endpoints: {result.stats['endpoints_used']}") + print(f"Unique IDs: {result.stats['unique_custom_ids']}") +else: + # Handle errors + for error in result.errors: + print(f"ERROR: {error}") + + # Review warnings + for warning in result.warnings: + print(f"WARNING: {warning}") +``` + +### Human-Readable Output + +The `ValidationResult` has a nice string representation: + +```python +result = validate_batch_file("batch.jsonl") +print(result) +``` + +Output: +``` +Validation: PASSED + +Statistics: + total_requests: 100 + unique_custom_ids: 100 + endpoints_used: ['/v1/responses'] + file_size_mb: 0.05 + +Warnings (1): + • File extension is '.json', expected '.jsonl' +``` + +## Advanced 
Usage + +### Configurable Validation + +Use `BatchFileValidator` for custom validation options: + +```python +from openbatch.validation import BatchFileValidator + +validator = BatchFileValidator( + check_custom_id_uniqueness=True, # Check for duplicate IDs + check_file_size=True, # Check 200 MB limit + check_request_count=True, # Check 50K request limit + allow_mixed_endpoints=False # Warn about mixed endpoints +) + +result = validator.validate_file("batch.jsonl") +``` + +### Disable Specific Checks + +```python +from openbatch import validate_batch_file + +# Skip custom_id uniqueness check (if you want duplicates) +result = validate_batch_file( + "batch.jsonl", + check_custom_id_uniqueness=False +) + +# Allow mixed endpoints without warning +result = validate_batch_file( + "batch.jsonl", + allow_mixed_endpoints=True +) + +# Disable strict checks (file size, request count) +result = validate_batch_file( + "batch.jsonl", + strict=False +) +``` diff --git a/mkdocs.yml b/mkdocs.yml index c201aa9..637afde 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,6 +28,7 @@ plugins: nav: - index.md - Getting Started: getting_started.md + - Validation: validation.md - Reference: - BatchJobManager: reference/batch_job_manager.md - BatchJobCollector: reference/batch_job_collector.md diff --git a/src/openbatch/__init__.py b/src/openbatch/__init__.py index 6e210bc..36f0341 100644 --- a/src/openbatch/__init__.py +++ b/src/openbatch/__init__.py @@ -11,6 +11,7 @@ EmbeddingsRequest, ReasoningConfig, ) +from openbatch.validation import validate_batch_file __all__ = [ "BatchCollector", @@ -24,4 +25,5 @@ "ChatCompletionsRequest", "EmbeddingsRequest", "ReasoningConfig", + "validate_batch_file", ] diff --git a/src/openbatch/validation.py b/src/openbatch/validation.py new file mode 100644 index 0000000..5aecbb8 --- /dev/null +++ b/src/openbatch/validation.py @@ -0,0 +1,349 @@ +""" +Validation utilities for OpenAI batch job files. 
+ +This module provides functions to validate JSONL batch files before uploading them +to the OpenAI Batch API, helping catch errors early and ensure compliance with +API requirements. +""" + +import json +from pathlib import Path +from typing import Union, List, Dict, Any, Set +from dataclasses import dataclass, field + + +@dataclass +class ValidationResult: + """ + Result of batch file validation. + + Attributes: + is_valid (bool): Whether the batch file is valid + errors (List[str]): List of validation errors found + warnings (List[str]): List of non-critical warnings + stats (Dict[str, Any]): Statistics about the batch file + """ + is_valid: bool + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + stats: Dict[str, Any] = field(default_factory=dict) + + def __str__(self) -> str: + """Human-readable summary of validation results.""" + lines = [] + lines.append(f"Validation: {'PASSED' if self.is_valid else 'FAILED'}") + + if self.stats: + lines.append("\nStatistics:") + for key, value in self.stats.items(): + lines.append(f" {key}: {value}") + + if self.errors: + lines.append(f"\nErrors ({len(self.errors)}):") + for error in self.errors: + lines.append(f" • {error}") + + if self.warnings: + lines.append(f"\nWarnings ({len(self.warnings)}):") + for warning in self.warnings: + lines.append(f" • {warning}") + + return "\n".join(lines) + + +class BatchFileValidator: + """ + Validator for OpenAI batch job JSONL files. 
+ + Validates batch files against OpenAI Batch API requirements including: + - Valid JSONL format + - Unique custom_ids + - Required fields present + - Valid endpoint URLs + - File size limits + """ + + # OpenAI Batch API constraints (as of 2026) + MAX_FILE_SIZE_MB = 200 + MAX_REQUESTS = 50000 + VALID_ENDPOINTS = { + "/v1/responses", + "/v1/chat/completions", + "/v1/embeddings" + } + REQUIRED_FIELDS = {"custom_id", "method", "url", "body"} + + def __init__( + self, + check_custom_id_uniqueness: bool = True, + check_file_size: bool = True, + check_request_count: bool = True, + allow_mixed_endpoints: bool = False + ): + """ + Initialize the validator with configuration options. + + Args: + check_custom_id_uniqueness: Check for duplicate custom_ids + check_file_size: Check file size against limits + check_request_count: Check number of requests against limits + allow_mixed_endpoints: Allow multiple endpoint types in one file (not recommended) + """ + self.check_custom_id_uniqueness = check_custom_id_uniqueness + self.check_file_size = check_file_size + self.check_request_count = check_request_count + self.allow_mixed_endpoints = allow_mixed_endpoints + + def validate_file(self, file_path: Union[str, Path]) -> ValidationResult: + """ + Validate a batch file. 
+ + Args: + file_path: Path to the JSONL batch file + + Returns: + ValidationResult with errors, warnings, and statistics + """ + file_path = Path(file_path) + result = ValidationResult(is_valid=True) + + # Check file exists + if not file_path.exists(): + result.errors.append(f"File not found: {file_path}") + result.is_valid = False + return result + + # Check file extension + if file_path.suffix != ".jsonl": + result.warnings.append( + f"File extension is '{file_path.suffix}', expected '.jsonl'" + ) + + # Check file size + if self.check_file_size: + file_size_mb = file_path.stat().st_size / (1024 * 1024) + result.stats["file_size_mb"] = round(file_size_mb, 2) + + if file_size_mb > self.MAX_FILE_SIZE_MB: + result.errors.append( + f"File size ({file_size_mb:.2f} MB) exceeds limit " + f"({self.MAX_FILE_SIZE_MB} MB)" + ) + result.is_valid = False + + # Validate content + try: + with open(file_path, "r", encoding="utf-8") as f: + self._validate_content(f, result) + except Exception as e: + result.errors.append(f"Error reading file: {str(e)}") + result.is_valid = False + + return result + + def _validate_content(self, file_handle, result: ValidationResult) -> None: + """Validate the content of the batch file.""" + custom_ids: Set[str] = set() + endpoints: Set[str] = set() + line_number = 0 + + for line in file_handle: + line_number += 1 + line = line.strip() + + # Skip empty lines + if not line: + result.warnings.append(f"Line {line_number}: Empty line (will be ignored)") + continue + + # Parse JSON + try: + request = json.loads(line) + except json.JSONDecodeError as e: + result.errors.append( + f"Line {line_number}: Invalid JSON - {str(e)}" + ) + result.is_valid = False + continue + + # Validate request structure + self._validate_request(request, line_number, custom_ids, endpoints, result) + + # Update statistics + result.stats["total_requests"] = line_number + result.stats["unique_custom_ids"] = len(custom_ids) + result.stats["endpoints_used"] = list(endpoints) + + # 
Check request count + if self.check_request_count and line_number > self.MAX_REQUESTS: + result.errors.append( + f"Request count ({line_number}) exceeds limit ({self.MAX_REQUESTS})" + ) + result.is_valid = False + + # Check for mixed endpoints + if not self.allow_mixed_endpoints and len(endpoints) > 1: + result.warnings.append( + f"Multiple endpoint types detected: {list(endpoints)}. " + "OpenAI recommends one request type per file." + ) + + def _validate_request( + self, + request: Dict[str, Any], + line_number: int, + custom_ids: Set[str], + endpoints: Set[str], + result: ValidationResult + ) -> None: + """Validate a single request object.""" + + # Check required fields + missing_fields = self.REQUIRED_FIELDS - set(request.keys()) + if missing_fields: + result.errors.append( + f"Line {line_number}: Missing required fields: {missing_fields}" + ) + result.is_valid = False + return + + # Validate custom_id + custom_id = request.get("custom_id") + if not custom_id or not isinstance(custom_id, str): + result.errors.append( + f"Line {line_number}: Invalid custom_id (must be a non-empty string)" + ) + result.is_valid = False + elif self.check_custom_id_uniqueness: + if custom_id in custom_ids: + result.errors.append( + f"Line {line_number}: Duplicate custom_id '{custom_id}'" + ) + result.is_valid = False + else: + custom_ids.add(custom_id) + + # Validate method + method = request.get("method") + if method != "POST": + result.errors.append( + f"Line {line_number}: Invalid method '{method}' (must be 'POST')" + ) + result.is_valid = False + + # Validate URL + url = request.get("url") + if url not in self.VALID_ENDPOINTS: + result.errors.append( + f"Line {line_number}: Invalid endpoint '{url}'. 
" + f"Valid endpoints: {self.VALID_ENDPOINTS}" + ) + result.is_valid = False + else: + endpoints.add(url) + + # Validate body + body = request.get("body") + if not isinstance(body, dict): + result.errors.append( + f"Line {line_number}: 'body' must be a JSON object" + ) + result.is_valid = False + else: + self._validate_body(body, url, line_number, result) + + def _validate_body( + self, + body: Dict[str, Any], + endpoint: str, + line_number: int, + result: ValidationResult + ) -> None: + """Validate the request body based on endpoint type.""" + + # Check for model field (required for all endpoints) + if "model" not in body: + result.errors.append( + f"Line {line_number}: Missing required field 'model' in body" + ) + result.is_valid = False + + # Endpoint-specific validation + if endpoint == "/v1/responses": + if "input" not in body and "prompt" not in body: + result.errors.append( + f"Line {line_number}: Responses API requires either 'input' or 'prompt' in body" + ) + result.is_valid = False + + elif endpoint == "/v1/chat/completions": + if "messages" not in body: + result.errors.append( + f"Line {line_number}: Chat Completions API requires 'messages' in body" + ) + result.is_valid = False + elif not isinstance(body["messages"], list): + result.errors.append( + f"Line {line_number}: 'messages' must be an array" + ) + result.is_valid = False + + elif endpoint == "/v1/embeddings": + if "input" not in body: + result.errors.append( + f"Line {line_number}: Embeddings API requires 'input' in body" + ) + result.is_valid = False + + +def validate_batch_file( + file_path: Union[str, Path], + strict: bool = True, + check_custom_id_uniqueness: bool = True, + allow_mixed_endpoints: bool = False +) -> ValidationResult: + """ + Validate a batch file (convenience function). 
+ + Args: + file_path: Path to the JSONL batch file + strict: Enable all checks (file size, request count) + check_custom_id_uniqueness: Check for duplicate custom_ids + allow_mixed_endpoints: Allow multiple endpoint types in one file + + Returns: + ValidationResult with errors, warnings, and statistics + + Example: + >>> result = validate_batch_file("my_batch.jsonl") + >>> if result.is_valid: + ... print("File is valid!") + ... else: + ... print(result) + """ + validator = BatchFileValidator( + check_custom_id_uniqueness=check_custom_id_uniqueness, + check_file_size=strict, + check_request_count=strict, + allow_mixed_endpoints=allow_mixed_endpoints + ) + return validator.validate_file(file_path) + + +def quick_validate(file_path: Union[str, Path]) -> bool: + """ + Quick validation check (returns True/False). + + Args: + file_path: Path to the JSONL batch file + + Returns: + True if valid, False otherwise + + Example: + >>> if quick_validate("my_batch.jsonl"): + ... # Proceed with upload + ... 
pass + """ + result = validate_batch_file(file_path) + return result.is_valid diff --git a/tests/test_collector.py b/tests/test_collector.py index 090ab14..e4d8588 100644 --- a/tests/test_collector.py +++ b/tests/test_collector.py @@ -1,6 +1,5 @@ import json import pytest -from pathlib import Path from pydantic import BaseModel, Field from openbatch.collector import BatchCollector, Responses, ChatCompletions, Embeddings from openbatch.model import ReasoningConfig diff --git a/tests/test_integration.py b/tests/test_integration.py index 761d762..d342ed2 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,7 +1,6 @@ """Integration tests that verify end-to-end workflows.""" import json import pytest -from pathlib import Path from pydantic import BaseModel, Field from openbatch import ( BatchCollector, diff --git a/tests/test_manager.py b/tests/test_manager.py index 4c335a4..b979366 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,7 +1,6 @@ import json import pytest import warnings -from pathlib import Path from openbatch.manager import BatchJobManager from openbatch.model import ( Message, diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..d7b59c2 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,427 @@ +"""Tests for batch file validation.""" + +import json +import pytest +from openbatch.validation import ( + validate_batch_file, + quick_validate, + ValidationResult, +) + + +@pytest.fixture +def temp_batch_file(tmp_path): + """Provides a temporary file path for batch files.""" + return tmp_path / "test_batch.jsonl" + + +class TestValidationResult: + def test_validation_result_str(self): + result = ValidationResult( + is_valid=False, + errors=["Error 1", "Error 2"], + warnings=["Warning 1"], + stats={"total_requests": 10, "file_size_mb": 0.5} + ) + output = str(result) + assert "FAILED" in output + assert "Error 1" in output + assert "Warning 1" in output + assert 
"total_requests: 10" in output + + def test_validation_result_success(self): + result = ValidationResult(is_valid=True, stats={"total_requests": 5}) + output = str(result) + assert "PASSED" in output + + +class TestBatchFileValidator: + def test_valid_batch_file(self, temp_batch_file): + """Test validation of a valid batch file.""" + requests = [ + { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"} + }, + { + "custom_id": "req_2", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "World"} + } + ] + + with open(temp_batch_file, "w") as f: + for req in requests: + f.write(json.dumps(req) + "\n") + + result = validate_batch_file(temp_batch_file) + assert result.is_valid + assert len(result.errors) == 0 + assert result.stats["total_requests"] == 2 + assert result.stats["unique_custom_ids"] == 2 + + def test_file_not_found(self): + """Test validation of non-existent file.""" + result = validate_batch_file("nonexistent.jsonl") + assert not result.is_valid + assert any("not found" in err.lower() for err in result.errors) + + def test_invalid_json(self, temp_batch_file): + """Test validation of file with invalid JSON.""" + with open(temp_batch_file, "w") as f: + f.write('{"custom_id": "req_1", "invalid json}\n') + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("invalid json" in err.lower() for err in result.errors) + + def test_duplicate_custom_ids(self, temp_batch_file): + """Test detection of duplicate custom_ids.""" + requests = [ + { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"} + }, + { + "custom_id": "req_1", # Duplicate + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "World"} + } + ] + + with open(temp_batch_file, "w") as f: + for req in requests: + f.write(json.dumps(req) + "\n") + + result = 
validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("duplicate" in err.lower() for err in result.errors) + + def test_missing_required_fields(self, temp_batch_file): + """Test detection of missing required fields.""" + request = { + "custom_id": "req_1", + # Missing method, url, body + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("missing required fields" in err.lower() for err in result.errors) + + def test_invalid_method(self, temp_batch_file): + """Test detection of invalid HTTP method.""" + request = { + "custom_id": "req_1", + "method": "GET", # Should be POST + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"} + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("invalid method" in err.lower() for err in result.errors) + + def test_invalid_endpoint(self, temp_batch_file): + """Test detection of invalid endpoint URL.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/invalid", # Invalid endpoint + "body": {"model": "gpt-4"} + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("invalid endpoint" in err.lower() for err in result.errors) + + def test_responses_api_missing_input(self, temp_batch_file): + """Test Responses API validation - missing input/prompt.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4"} # Missing input or prompt + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("input" in err.lower() or "prompt" in err.lower() for 
err in result.errors) + + def test_chat_completions_missing_messages(self, temp_batch_file): + """Test Chat Completions API validation - missing messages.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": "gpt-4"} # Missing messages + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("messages" in err.lower() for err in result.errors) + + def test_chat_completions_invalid_messages(self, temp_batch_file): + """Test Chat Completions API validation - messages not an array.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": "gpt-4", "messages": "not an array"} + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("messages" in err.lower() and "array" in err.lower() for err in result.errors) + + def test_embeddings_missing_input(self, temp_batch_file): + """Test Embeddings API validation - missing input.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/embeddings", + "body": {"model": "text-embedding-3-small"} # Missing input + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("input" in err.lower() for err in result.errors) + + def test_mixed_endpoints_warning(self, temp_batch_file): + """Test warning for mixed endpoint types.""" + requests = [ + { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"} + }, + { + "custom_id": "req_2", + "method": "POST", + "url": "/v1/embeddings", + "body": {"model": "text-embedding-3-small", "input": "World"} + } + ] + + with open(temp_batch_file, "w") as f: + 
for req in requests: + f.write(json.dumps(req) + "\n") + + result = validate_batch_file(temp_batch_file, allow_mixed_endpoints=False) + assert result.is_valid # Valid but with warning + assert any("multiple endpoint" in warn.lower() for warn in result.warnings) + + def test_empty_lines_warning(self, temp_batch_file): + """Test warning for empty lines.""" + with open(temp_batch_file, "w") as f: + f.write('{"custom_id": "req_1", "method": "POST", "url": "/v1/responses", "body": {"model": "gpt-4", "input": "Hi"}}\n') + f.write("\n") # Empty line + f.write('{"custom_id": "req_2", "method": "POST", "url": "/v1/responses", "body": {"model": "gpt-4", "input": "Bye"}}\n') + + result = validate_batch_file(temp_batch_file) + assert any("empty line" in warn.lower() for warn in result.warnings) + + def test_wrong_file_extension_warning(self, tmp_path): + """Test warning for wrong file extension.""" + json_file = tmp_path / "batch.json" # Should be .jsonl + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"} + } + + with open(json_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(json_file) + assert any(".jsonl" in warn.lower() for warn in result.warnings) + + def test_missing_model_in_body(self, temp_batch_file): + """Test detection of missing model field in body.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"input": "Hello"} # Missing model + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("model" in err.lower() for err in result.errors) + + def test_invalid_custom_id_type(self, temp_batch_file): + """Test detection of invalid custom_id type.""" + request = { + "custom_id": 123, # Should be string + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": 
"Hello"} + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("custom_id" in err.lower() for err in result.errors) + + def test_empty_custom_id(self, temp_batch_file): + """Test detection of empty custom_id.""" + request = { + "custom_id": "", # Empty string + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"} + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("custom_id" in err.lower() for err in result.errors) + + def test_body_not_object(self, temp_batch_file): + """Test detection of non-object body.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": "not an object" + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("body" in err.lower() and "object" in err.lower() for err in result.errors) + + def test_skip_custom_id_check(self, temp_batch_file): + """Test disabling custom_id uniqueness check.""" + requests = [ + { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"} + }, + { + "custom_id": "req_1", # Duplicate + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "World"} + } + ] + + with open(temp_batch_file, "w") as f: + for req in requests: + f.write(json.dumps(req) + "\n") + + result = validate_batch_file(temp_batch_file, check_custom_id_uniqueness=False) + # Should be valid when uniqueness check is disabled + assert result.is_valid + + +class TestConvenienceFunctions: + def test_quick_validate_true(self, temp_batch_file): + """Test quick_validate with valid file.""" + request = { + "custom_id": "req_1", + 
"method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"} + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + assert quick_validate(temp_batch_file) is True + + def test_quick_validate_false(self, temp_batch_file): + """Test quick_validate with invalid file.""" + with open(temp_batch_file, "w") as f: + f.write("invalid json\n") + + assert quick_validate(temp_batch_file) is False + + +class TestComplexScenarios: + def test_large_valid_file(self, temp_batch_file): + """Test validation of file with many requests.""" + with open(temp_batch_file, "w") as f: + for i in range(1000): + request = { + "custom_id": f"req_{i}", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": f"Request {i}"} + } + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert result.is_valid + assert result.stats["total_requests"] == 1000 + assert result.stats["unique_custom_ids"] == 1000 + + def test_all_three_endpoints(self, temp_batch_file): + """Test file with all three valid endpoints.""" + requests = [ + { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"} + }, + { + "custom_id": "req_2", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]} + }, + { + "custom_id": "req_3", + "method": "POST", + "url": "/v1/embeddings", + "body": {"model": "text-embedding-3-small", "input": "Text"} + } + ] + + with open(temp_batch_file, "w") as f: + for req in requests: + f.write(json.dumps(req) + "\n") + + result = validate_batch_file(temp_batch_file, allow_mixed_endpoints=True) + assert result.is_valid + assert len(result.stats["endpoints_used"]) == 3 From 643fb22023f19fa226bdf39c5d9d892d75f113d5 Mon Sep 17 00:00:00 2001 From: Daniel Gomm Date: Mon, 9 Feb 2026 15:14:57 +0100 Subject: [PATCH 4/9] :wrench: update to version 0.0.4 
--- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0adbf8c..83e5887 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "openbatch" -version = "0.0.3" +version = "0.0.4" authors = [ { name="Daniel Gomm", email="daniel.gomm@cwi.nl" }, ] From 15388e577cc5c6b23bac9903d27a34f7e379aed1 Mon Sep 17 00:00:00 2001 From: Daniel Gomm Date: Mon, 9 Feb 2026 15:23:50 +0100 Subject: [PATCH 5/9] :wrench: update python requirements to 3.11 --- .github/workflows/test.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 292fc91..df9408b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v4 diff --git a/pyproject.toml b/pyproject.toml index 83e5887..0fbab8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ ] description = "Create batch jobs for the OpenAI API with ease." 
readme = "README.md" -requires-python = ">=3.5" +requires-python = ">=3.11" classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent", From 0f58f8f03ca15cffb4ab94dffc310d5e411f0254 Mon Sep 17 00:00:00 2001 From: Daniel Gomm Date: Tue, 10 Feb 2026 10:15:05 +0100 Subject: [PATCH 6/9] :construction_worker: add and apply linting and static type checking --- .github/workflows/lint.yml | 34 ++++ .pre-commit-config.yaml | 33 ++++ README.md | 16 +- pyproject.toml | 65 +++++++ src/openbatch/__init__.py | 20 +-- src/openbatch/_utils.py | 33 ++-- src/openbatch/collector.py | 49 ++--- src/openbatch/manager.py | 56 +++--- src/openbatch/model.py | 349 ++++++++++++++++++++++++++---------- src/openbatch/validation.py | 103 +++++------ tests/test_collector.py | 50 +++--- tests/test_integration.py | 48 +++-- tests/test_manager.py | 82 ++++----- tests/test_model.py | 32 ++-- tests/test_utils.py | 33 ++-- tests/test_validation.py | 74 ++++---- 16 files changed, 661 insertions(+), 416 deletions(-) create mode 100644 .github/workflows/lint.yml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..4711e8e --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,34 @@ +name: Code Quality + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff mypy + pip install -e . 
+ + - name: Run Ruff linter + run: ruff check src/ tests/ + + - name: Run Ruff formatter check + run: ruff format --check src/ tests/ + + - name: Run mypy type checker + run: mypy src/ + continue-on-error: true # Don't fail build on type errors initially diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f48b756 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,33 @@ +# Pre-commit hooks for code quality +# See https://pre-commit.com for more information + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-json + - id: check-toml + - id: check-merge-conflict + - id: debug-statements + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.5 + hooks: + # Run the linter + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + # Run the formatter + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.13.0 + hooks: + - id: mypy + additional_dependencies: [pydantic>=2.11.9] + args: [--ignore-missing-imports] + files: ^src/ diff --git a/README.md b/README.md index cf2127a..526dda3 100644 --- a/README.md +++ b/README.md @@ -231,4 +231,18 @@ pytest --cov=openbatch The test suite includes: - Unit tests for all core functionality - Integration tests for end-to-end workflows -- Tests for structured outputs, reasoning models, and unicode handling \ No newline at end of file +- Tests for structured outputs, reasoning models, and unicode handling + +----- + +## Contributing + +Contributions are welcome! 
Please see our [Contributing Guide](CONTRIBUTING.md) for details on: + +- Development setup +- Code quality standards +- Branch naming conventions (`feature/`, `fix/`, `documentation/`) +- Commit message conventions (using [gitmoji](https://gitmoji.dev/)) +- Opening issues and pull requests + +For bugs and feature requests, please [open an issue](https://github.com/TiepNguyen2003/OpenAIBatchJobBuilder/issues). diff --git a/pyproject.toml b/pyproject.toml index 0fbab8e..bd65d98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,11 @@ test = [ "pytest>=8.0.0", "pytest-cov>=4.1.0", ] +dev = [ + "ruff>=0.8.0", + "mypy>=1.13.0", + "pre-commit>=4.0.0", +] [tool.pytest.ini_options] testpaths = ["tests"] @@ -58,3 +63,63 @@ exclude_lines = [ "if TYPE_CHECKING:", "@abstractmethod", ] + +# Ruff configuration +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "SIM", # flake8-simplify + "RUF", # ruff-specific rules +] +ignore = [ + "E501", # line too long (handled by formatter) + "B008", # do not perform function calls in argument defaults + "N805", # first argument should be named self (pydantic validators) +] + +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = [ + "N802", # function name should be lowercase + "N803", # argument name should be lowercase +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +line-ending = "lf" + +# Mypy configuration - Static type checking +# Configured for gradual adoption with reasonable strictness +[tool.mypy] +python_version = "3.11" +warn_return_any = false # Too strict for **kwargs patterns +warn_unused_configs = true +disallow_untyped_defs = false # Allow untyped defs (gradual adoption) +disallow_incomplete_defs = false # Allow incomplete defs (gradual adoption) 
+check_untyped_defs = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = false # Pre-commit adds ignores, this would conflict +warn_no_return = true +strict_equality = true +# Allow flexible type checking for union types and **kwargs +disable_error_code = ["assignment", "no-untyped-def"] + +[[tool.mypy.overrides]] +module = "tests.*" +disallow_untyped_defs = false + +[[tool.mypy.overrides]] +module = "pydantic.*" +ignore_missing_imports = true diff --git a/src/openbatch/__init__.py b/src/openbatch/__init__.py index 36f0341..b2c17e5 100644 --- a/src/openbatch/__init__.py +++ b/src/openbatch/__init__.py @@ -1,29 +1,29 @@ from openbatch.collector import BatchCollector from openbatch.manager import BatchJobManager from openbatch.model import ( + ChatCompletionsRequest, + EmbeddingInputInstance, + EmbeddingsRequest, Message, + MessagesInputInstance, PromptTemplate, PromptTemplateInputInstance, - MessagesInputInstance, - EmbeddingInputInstance, - ResponsesRequest, - ChatCompletionsRequest, - EmbeddingsRequest, ReasoningConfig, + ResponsesRequest, ) from openbatch.validation import validate_batch_file __all__ = [ "BatchCollector", "BatchJobManager", + "ChatCompletionsRequest", + "EmbeddingInputInstance", + "EmbeddingsRequest", "Message", + "MessagesInputInstance", "PromptTemplate", "PromptTemplateInputInstance", - "MessagesInputInstance", - "EmbeddingInputInstance", - "ResponsesRequest", - "ChatCompletionsRequest", - "EmbeddingsRequest", "ReasoningConfig", + "ResponsesRequest", "validate_batch_file", ] diff --git a/src/openbatch/_utils.py b/src/openbatch/_utils.py index 6c01376..e750051 100644 --- a/src/openbatch/_utils.py +++ b/src/openbatch/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, TypeVar +from typing import Any, TypeVar from pydantic import BaseModel @@ -6,6 +6,7 @@ # Copied and adapted from the OpenAI library to avoid adding a dependency https://github.com/openai/openai-python/blob/main/src/openai/lib/_pydantic.py + 
def _ensure_strict_json_schema( json_schema: object, *, @@ -26,7 +27,9 @@ def _ensure_strict_json_schema( definitions = json_schema.get("definitions") if isinstance(definitions, dict): for definition_name, definition_schema in definitions.items(): - _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root) + _ensure_strict_json_schema( + definition_schema, path=(*path, "definitions", definition_name), root=root + ) typ = json_schema.get("type") if typ == "object" and "additionalProperties" not in json_schema: @@ -36,7 +39,7 @@ def _ensure_strict_json_schema( # { 'type': 'object', 'properties': { 'a': {...} } } properties = json_schema.get("properties") if isinstance(properties, dict): - json_schema["required"] = [prop for prop in properties.keys()] + json_schema["required"] = list(properties) json_schema["properties"] = { key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root) for key, prop_schema in properties.items() @@ -60,7 +63,9 @@ def _ensure_strict_json_schema( all_of = json_schema.get("allOf") if isinstance(all_of, dict): if len(all_of) == 1: - json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root)) + json_schema.update( + _ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root) + ) json_schema.pop("allOf") else: json_schema["allOf"] = [ @@ -79,7 +84,9 @@ def _ensure_strict_json_schema( resolved = resolve_ref(root=root, ref=ref) if not isinstance(resolved, dict): - raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}") + raise ValueError( + f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}" + ) # properties from the json schema take priority over the ones on the `$ref` json_schema.update({**resolved, **json_schema}) @@ -90,13 +97,10 @@ def _ensure_strict_json_schema( return json_schema + def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool: - i = 0 - 
for _ in obj.keys(): - i += 1 - if i > n: - return True - return False + return any(i > n for i, _ in enumerate(obj, 1)) + def resolve_ref(*, root: dict[str, object], ref: str) -> object: if not ref.startswith("#/"): @@ -106,12 +110,15 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object: resolved = root for key in path: value = resolved[key] - assert isinstance(value, dict), f"encountered non-dictionary entry while resolving {ref} - {resolved}" + assert isinstance( + value, dict + ), f"encountered non-dictionary entry while resolving {ref} - {resolved}" resolved = value return resolved -def type_to_json_schema(output_type: type[T]) -> Dict[str, Any]: + +def type_to_json_schema(output_type: type[T]) -> dict[str, Any]: json_schema = output_type.model_json_schema() schema = _ensure_strict_json_schema(json_schema, path=(), root=json_schema) return schema diff --git a/src/openbatch/collector.py b/src/openbatch/collector.py index 31bc482..f1dc56a 100644 --- a/src/openbatch/collector.py +++ b/src/openbatch/collector.py @@ -1,22 +1,22 @@ from os import PathLike +from pathlib import Path from types import SimpleNamespace -from typing import Union, Optional from pydantic import BaseModel from openbatch.manager import BatchJobManager -from openbatch.model import ResponsesRequest, ChatCompletionsRequest, EmbeddingsRequest +from openbatch.model import ChatCompletionsRequest, EmbeddingsRequest, ResponsesRequest class Responses: """ - A utility class for easily constructing and adding individual - Responses API requests to a batch job file. + A utility class for easily constructing and adding individual + Responses API requests to a batch job file. - It acts as a high-level interface for the '/v1/responses' endpoint. - """ + It acts as a high-level interface for the '/v1/responses' endpoint. + """ - def __init__(self, batch_file_path: Union[str, PathLike]): + def __init__(self, batch_file_path: str | PathLike): """ Initializes the Responses collector. 
@@ -24,10 +24,12 @@ def __init__(self, batch_file_path: Union[str, PathLike]): batch_file_path (Union[str, PathLike]): The path to the JSONL file where the batch requests will be written. """ - self.batch_file_path = batch_file_path + self.batch_file_path = Path(batch_file_path) self._manager = BatchJobManager() - def parse(self, custom_id: str, model: str, text_format: Optional[type[BaseModel]] = None, **kwargs) -> None: + def parse( + self, custom_id: str, model: str, text_format: type[BaseModel] | None = None, **kwargs + ) -> None: """ Creates a ResponsesRequest, optionally enforcing a JSON output structure, and adds it to the batch file. Use it like the `OpenAI().responses.parse()` method. @@ -59,14 +61,16 @@ def create(self, custom_id: str, model: str, **kwargs) -> None: def _add_request(self, custom_id: str, request: ResponsesRequest) -> None: self._manager.add(custom_id, request, self.batch_file_path) + class ChatCompletions: """ - A utility class for easily constructing and adding individual - Chat Completions API requests to a batch job file. + A utility class for easily constructing and adding individual + Chat Completions API requests to a batch job file. - It acts as a high-level interface for the '/v1/chat/completions' endpoint. - """ - def __init__(self, batch_file_path: Union[str, PathLike]): + It acts as a high-level interface for the '/v1/chat/completions' endpoint. + """ + + def __init__(self, batch_file_path: str | PathLike): """ Initializes the ChatCompletions collector. @@ -74,10 +78,12 @@ def __init__(self, batch_file_path: Union[str, PathLike]): batch_file_path (Union[str, PathLike]): The path to the JSONL file where the batch requests will be written. 
""" - self.batch_file_path = batch_file_path + self.batch_file_path = Path(batch_file_path) self._manager = BatchJobManager() - def parse(self, custom_id: str, model: str, response_format: Optional[type[BaseModel]] = None, **kwargs) -> None: + def parse( + self, custom_id: str, model: str, response_format: type[BaseModel] | None = None, **kwargs + ) -> None: """ Creates a ChatCompletionsRequest, optionally enforcing a JSON output structure, and adds it to the batch file. Use it like the `OpenAI().chat.completions.parse()` method. @@ -109,6 +115,7 @@ def create(self, custom_id: str, model: str, **kwargs) -> None: def _add_request(self, custom_id: str, request: ChatCompletionsRequest) -> None: self._manager.add(custom_id, request, self.batch_file_path) + class Embeddings: """ A utility class for easily constructing and adding individual @@ -116,7 +123,8 @@ class Embeddings: It acts as a high-level interface for the '/v1/embeddings' endpoint. """ - def __init__(self, batch_file_path: Union[str, PathLike]): + + def __init__(self, batch_file_path: str | PathLike): """ Initializes the Embeddings collector. @@ -124,10 +132,10 @@ def __init__(self, batch_file_path: Union[str, PathLike]): batch_file_path (Union[str, PathLike]): The path to the JSONL file where the batch requests will be written. """ - self.batch_file_path = batch_file_path + self.batch_file_path = Path(batch_file_path) self._manager = BatchJobManager() - def create(self, custom_id: str, model: str, inp: Union[str, list[str]], **kwargs) -> None: + def create(self, custom_id: str, model: str, inp: str | list[str], **kwargs) -> None: """ Creates an EmbeddingsRequest and adds it to the batch file. Use it like the `OpenAI().embeddings.create()` method. @@ -160,7 +168,8 @@ class BatchCollector: where the batch requests will be written. The file will be created if it doesn't exist and appended to if it does. 
""" - def __init__(self, batch_file_path: Union[str, PathLike]): + + def __init__(self, batch_file_path: str | PathLike): self.responses = Responses(batch_file_path) self.chat = SimpleNamespace() self.chat.completions = ChatCompletions(batch_file_path) diff --git a/src/openbatch/manager.py b/src/openbatch/manager.py index 76a1a10..e6172a4 100644 --- a/src/openbatch/manager.py +++ b/src/openbatch/manager.py @@ -1,25 +1,26 @@ import json import warnings +from collections.abc import Iterable from copy import deepcopy from pathlib import Path -from typing import TypeVar, Iterable, Union +from typing import TypeVar from openbatch.model import ( + BaseRequest, + ChatCompletionsAPIStrategy, + ChatCompletionsRequest, + EmbeddingInputInstance, + EmbeddingsAPIStrategy, + EmbeddingsRequest, PromptTemplate, - ReusablePrompt, PromptTemplateInputInstance, - ResponsesRequest, ResponsesAPIStrategy, - EmbeddingsAPIStrategy, - ChatCompletionsAPIStrategy, - BaseRequest, - EmbeddingsRequest, - EmbeddingInputInstance, - ChatCompletionsRequest, + ResponsesRequest, + ReusablePrompt, ) B = TypeVar("B", bound=BaseRequest) -R = TypeVar("R", bound=Union[ResponsesRequest, ChatCompletionsRequest]) +R = TypeVar("R", bound=ResponsesRequest | ChatCompletionsRequest) class BatchJobManager: @@ -41,7 +42,7 @@ def __init__(self, ensure_ascii: bool = True) -> None: def add_templated_instances( self, - prompt: Union[PromptTemplate, ReusablePrompt], + prompt: PromptTemplate | ReusablePrompt, common_request: R, input_instances: Iterable[PromptTemplateInputInstance], save_file_path: str | Path, @@ -69,9 +70,7 @@ def add_templated_instances( unsupported request type. """ if isinstance(common_request, EmbeddingsRequest): - raise ValueError( - "Embeddings API is not supported with templated instances." 
- ) + raise ValueError("Embeddings API is not supported with templated instances.") elif not isinstance(common_request, ResponsesRequest) and not isinstance( common_request, ChatCompletionsRequest ): @@ -84,6 +83,7 @@ def add_templated_instances( warnings.warn( f"File {save_file_path} already exists. New contents are appended to the file. Make sure that this is intended behavior.", category=RuntimeWarning, + stacklevel=2, ) for instance in input_instances: @@ -92,15 +92,13 @@ def add_templated_instances( request = request.model_copy(update=instance.instance_request_options) request = self._handle_prompt(prompt, request, instance) - self.add( - custom_id=instance.id, request=request, save_file_path=save_file_path - ) + self.add(custom_id=instance.id, request=request, save_file_path=save_file_path) def add_embedding_requests( self, inputs: Iterable[EmbeddingInputInstance], common_request: EmbeddingsRequest, - save_file_path: Union[str, Path], + save_file_path: str | Path, ) -> None: """ Adds multiple embedding request instances to a batch request file. @@ -123,15 +121,13 @@ def add_embedding_requests( request = request.model_copy(update=instance.instance_request_options) request.set_input(instance.input) - self.add( - custom_id=instance.id, request=request, save_file_path=save_file_path - ) + self.add(custom_id=instance.id, request=request, save_file_path=save_file_path) def add( self, custom_id: str, request: B, - save_file_path: Union[str, Path], + save_file_path: str | Path, ) -> None: """ Creates a single batch request object and appends it to the specified file. @@ -153,9 +149,7 @@ def add( if isinstance(request, ResponsesRequest): strategy = ResponsesAPIStrategy() if request.input is None and request.prompt is None: - raise ValueError( - "Responses request must define either an input or a prompt." 
- ) + raise ValueError("Responses request must define either an input or a prompt.") elif isinstance(request, ChatCompletionsRequest): strategy = ChatCompletionsAPIStrategy() if request.messages is None: @@ -170,14 +164,10 @@ def add( save_file_path = Path(save_file_path) save_file_path.parent.mkdir(parents=True, exist_ok=True) - batch_request = strategy.create_request( - custom_id=custom_id, body=request.to_dict() - ) + batch_request = strategy.create_request(custom_id=custom_id, body=request.to_dict()) with open(save_file_path, "a+") as outfile: - outfile.write( - json.dumps(batch_request, ensure_ascii=self.ensure_ascii) + "\n" - ) + outfile.write(json.dumps(batch_request, ensure_ascii=self.ensure_ascii) + "\n") @staticmethod def _handle_prompt( @@ -187,9 +177,7 @@ def _handle_prompt( ) -> R: if isinstance(prompt, ReusablePrompt): if not isinstance(request, ResponsesRequest): - raise ValueError( - "Reusable prompts can only be used with ResponsesOptions." - ) + raise ValueError("Reusable prompts can only be used with ResponsesOptions.") request.prompt = ReusablePrompt( id=prompt.id, version=prompt.version, diff --git a/src/openbatch/model.py b/src/openbatch/model.py index d69ec1e..6a8f632 100644 --- a/src/openbatch/model.py +++ b/src/openbatch/model.py @@ -1,15 +1,15 @@ -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from os import PathLike from pathlib import Path -from typing import List, Dict, Any, Optional, Literal, TypeVar, Union, Self +from typing import Any, Literal, TypeVar from pydantic import BaseModel, Field from openbatch._utils import type_to_json_schema - T = TypeVar("T", bound=BaseModel) + class Message(BaseModel): """ Represents a single message in a conversation or prompt. @@ -18,6 +18,7 @@ class Message(BaseModel): role (str): The role of the message sender (e.g., "user", "assistant", "system"). content (str): The text content of the message. 
""" + role: str content: str @@ -30,6 +31,7 @@ def serialize(self): """ return {"role": self.role, "content": self.content} + class PromptTemplate(BaseModel): """ A template containing a sequence of messages, where the content can contain @@ -38,9 +40,10 @@ class PromptTemplate(BaseModel): Attributes: messages (List[Message]): A list of Message objects that form the template. """ - messages: List[Message] - def format(self, **kwargs) -> List[Message]: + messages: list[Message] + + def format(self, **kwargs) -> list[Message]: """ Formats the content of each message in the template using the provided keyword arguments. @@ -56,6 +59,7 @@ def format(self, **kwargs) -> List[Message]: formatted_messages.append(Message(role=message.role, content=formatted_content)) return formatted_messages + class ReusablePrompt(BaseModel): """ References a reusable prompt template and its associated variables. @@ -66,9 +70,11 @@ class ReusablePrompt(BaseModel): variables (Dict[str, Any]): A dictionary of variable names and their values to be used when formatting the prompt. """ + id: str version: str - variables: Dict[str, Any] + variables: dict[str, Any] + class ReasoningConfig(BaseModel): """ @@ -80,8 +86,14 @@ class ReasoningConfig(BaseModel): summary (Optional[Literal["auto", "concise", "detailed"]]): A summary of the reasoning performed by the model. Optional. """ - effort: Literal["minimal", "low", "medium", "high"] = Field(default="medium", description="Constrains effort on reasoning for reasoning models.") - summary: Optional[Literal["auto", "concise", "detailed"]] = Field(None, description="A summary of the reasoning performed by the model.") + + effort: Literal["minimal", "low", "medium", "high"] = Field( + default="medium", description="Constrains effort on reasoning for reasoning models." + ) + summary: Literal["auto", "concise", "detailed"] | None = Field( + None, description="A summary of the reasoning performed by the model." 
+ ) + class InputInstance(BaseModel): """ @@ -92,8 +104,12 @@ class InputInstance(BaseModel): instance_request_options (Optional[Dict[str, Any]]): Options specific to the input instance that can be set in the API request. Optional. """ + id: str = Field(description="Unique identifier of the input instance.") - instance_request_options: Optional[Dict[str, Any]] = Field(None, description="Options specific to the input instance that to set in the request.") + instance_request_options: dict[str, Any] | None = Field( + None, description="Options specific to the input instance that to set in the request." + ) + class MessagesInputInstance(InputInstance): """ @@ -105,7 +121,9 @@ class MessagesInputInstance(InputInstance): instance_request_options (Optional[Dict[str, Any]]): Options specific to the input instance that can be set in the API request. Optional. """ - messages: List[Message] = Field(description="List of messages to be sent to the model.") + + messages: list[Message] = Field(description="List of messages to be sent to the model.") + class PromptTemplateInputInstance(InputInstance): """ @@ -118,7 +136,11 @@ class PromptTemplateInputInstance(InputInstance): instance_request_options (Optional[Dict[str, Any]]): Options specific to the input instance that can be set in the API request. Optional. """ - prompt_value_mapping: Dict[str, str] = Field(description="Mapping of prompt variable names to their values.") + + prompt_value_mapping: dict[str, str] = Field( + description="Mapping of prompt variable names to their values." + ) + class EmbeddingInputInstance(InputInstance): """ @@ -130,13 +152,16 @@ class EmbeddingInputInstance(InputInstance): instance_request_options (Optional[Dict[str, Any]]): Options specific to the input instance that can be set in the API request. Optional. 
""" - input: Union[str, List[str]] = Field(description="Text(s) to be embedded.") + + input: str | list[str] = Field(description="Text(s) to be embedded.") + class RequestStrategy(ABC): """ Abstract base class defining the strategy for creating a request for a specific API endpoint. """ + @property @abstractmethod def url(self) -> str: @@ -148,7 +173,7 @@ def url(self) -> str: """ pass - def create_request(self, custom_id: str, body: Dict[str, Any]) -> Dict[str, Any]: + def create_request(self, custom_id: str, body: dict[str, Any]) -> dict[str, Any]: """ Creates a structured request dictionary for a batch job. @@ -159,31 +184,33 @@ def create_request(self, custom_id: str, body: Dict[str, Any]) -> Dict[str, Any] Returns: Dict[str, Any]: A dictionary representing the complete request structure. """ - return { - "custom_id": custom_id, - "method": "POST", - "url": self.url, - "body": body - } + return {"custom_id": custom_id, "method": "POST", "url": self.url, "body": body} + class ResponsesAPIStrategy(RequestStrategy): """Strategy for creating requests to the /v1/responses endpoint.""" + @property def url(self) -> str: return "/v1/responses" + class ChatCompletionsAPIStrategy(RequestStrategy): """Strategy for creating requests to the /v1/chat/completions endpoint.""" + @property def url(self) -> str: return "/v1/chat/completions" + class EmbeddingsAPIStrategy(RequestStrategy): """Strategy for creating requests to the /v1/embeddings endpoint.""" + @property def url(self) -> str: return "/v1/embeddings" + class BaseRequest(BaseModel, ABC): """ Abstract base class for API-specific request configurations (job configurations). @@ -191,9 +218,12 @@ class BaseRequest(BaseModel, ABC): Attributes: model (str): Model ID used to generate the response, like "gpt-4.1". Defaults to "gpt-4.1". 
""" - model: str = Field("gpt-4.1", description="Model ID used to generate the response, like gpt-4o or o3.") - def to_dict(self) -> Dict[str, Any]: + model: str = Field( + "gpt-4.1", description="Model ID used to generate the response, like gpt-4o or o3." + ) + + def to_dict(self) -> dict[str, Any]: """ Converts the request configuration object to a dictionary, excluding fields that are None. @@ -202,6 +232,7 @@ def to_dict(self) -> Dict[str, Any]: """ return self.model_dump(exclude_none=True) + class TextGenerationRequest(BaseRequest, ABC): """ Abstract base class for text generation requests, including common parameters @@ -220,68 +251,136 @@ class TextGenerationRequest(BaseRequest, ABC): tool_choice (Optional[str | object]): How the model should select which tool to use. top_logprobs (Optional[int]): Number of most likely tokens to return at each position (0 to 20). """ - tools: Optional[List[object]] = Field(None, description="An array of tools the model may call while generating a response.") - top_p: Optional[float] = Field(None, ge=0, le=1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.") - parallel_tool_calls: Optional[bool] = Field(None, description="Whether to allow the model to run tool calls in parallel.") - prompt_cache_key: Optional[str] = Field(None, description="Used by OpenAI to cache responses for similar requests to optimize your cache hit rates.") - safety_identifier: Optional[str] = Field(None, description="A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies.") - service_tier: Optional[Literal["auto", "default", "flex", "priority"]] = Field(None, description="Specifies the processing type used for serving the request.") - store: Optional[bool] = Field(None, description="Whether to store the generated model response for later retrieval via API.") - temperature: 
Optional[float] = Field(None, ge=0, le=2, description="What sampling temperature to use, between 0 and 2.") - tool_choice: Optional[str | object] = Field(None, description="How the model should select which tool (or tools) to use when generating a response.") - top_logprobs: Optional[int] = Field(None, ge=0, le=20, description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability.") + tools: list[object] | None = Field( + None, description="An array of tools the model may call while generating a response." + ) + top_p: float | None = Field( + None, + ge=0, + le=1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.", + ) + parallel_tool_calls: bool | None = Field( + None, description="Whether to allow the model to run tool calls in parallel." + ) + prompt_cache_key: str | None = Field( + None, + description="Used by OpenAI to cache responses for similar requests to optimize your cache hit rates.", + ) + safety_identifier: str | None = Field( + None, + description="A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies.", + ) + service_tier: Literal["auto", "default", "flex", "priority"] | None = Field( + None, description="Specifies the processing type used for serving the request." + ) + store: bool | None = Field( + None, + description="Whether to store the generated model response for later retrieval via API.", + ) + temperature: float | None = Field( + None, ge=0, le=2, description="What sampling temperature to use, between 0 and 2." 
+ ) + tool_choice: str | object | None = Field( + None, + description="How the model should select which tool (or tools) to use when generating a response.", + ) + top_logprobs: int | None = Field( + None, + ge=0, + le=20, + description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability.", + ) @abstractmethod def set_output_structure(self, output_type: type[T]) -> None: pass @abstractmethod - def set_input_messages(self, messages: List[Message]) -> None: + def set_input_messages(self, messages: list[Message]) -> None: pass class ResponsesRequest(TextGenerationRequest): """ - Configuration for a /v1/responses API request. - - Attributes: - model (str): Model ID used to generate the response, like "gpt-4.1". Defaults to "gpt-4.1". - conversation (Optional[str]): The conversation this response belongs to. - include (Optional[List[Literal[...]]]): Specify additional output data to include. - input (Optional[str | List[Dict[str, str]]]): Text, image, or file inputs to the model. - instructions (Optional[str]): A system or developer message. - max_output_tokens (Optional[int]): Upper bound for generated tokens. - max_tool_calls (Optional[int]): Maximum number of tool calls allowed. - previous_response_id (Optional[str]): ID of the previous response for multi-turn. - prompt (Optional[ReusablePrompt]): Reference to a prompt template and its variables. - reasoning (Optional[ReasoningConfig]): Configuration for reasoning models. - text (Optional[object]): Configuration options for a text response from the model (e.g., JSON schema). - truncation (Optional[Literal["auto", "disabled"]]): The truncation strategy to use. - tools (Optional[List[object]]): An array of tools the model may call. - top_p (Optional[float]): An alternative to sampling with temperature (nucleus sampling). - parallel_tool_calls (Optional[bool]): Whether to allow parallel tool calls. 
- prompt_cache_key (Optional[str]): Used by OpenAI to cache responses. - safety_identifier (Optional[str]): A stable identifier for policy monitoring. - service_tier (Optional[Literal["auto", "default", "flex", "priority"]]): Specifies the processing type. - store (Optional[bool]): Whether to store the generated model response. - temperature (Optional[float]): Sampling temperature to use (0 to 2). - tool_choice (Optional[str | object]): How the model should select which tool to use. - top_logprobs (Optional[int]): Number of most likely tokens to return at each position (0 to 20). - """ - conversation: Optional[str] = Field(None, description="The conversation that this response belongs to.") - include: Optional[List[Literal["code_interpreter_call.outputs", "computer_call_output.output.image_url", "file_search_call.results", "message.input_image.image_url", "message.output_text.logprobs", "reasoning.encrypted_content"]]] = Field(None, description="Specify additional output data to include in the model response.") - input: Optional[str | List[Dict[str, str]]] = Field(None, description="Text, image, or file inputs to the model, used to generate a response.") - instructions: Optional[str] = Field(None, description="A system (or developer) message inserted into the model's context.") - max_output_tokens: Optional[int] = Field(None, gt=0, description="An upper bound for the number of tokens that can be generated for a response, including visible output tokens and reasoning tokens.") - max_tool_calls: Optional[int] = Field(None, gt=0, description="The maximum number of total calls to built-in tools that can be processed in a response.") - previous_response_id: Optional[str] = Field(None, description="The unique ID of the previous response to the model. 
Use this to create multi-turn conversations.") - prompt: Optional[ReusablePrompt] = Field(None, description="Reference to a prompt template and its variables.") - reasoning: Optional[ReasoningConfig] = Field(None, description="Configuration options for reasoning models.") - text: Optional[object] = Field(None, description="Configuration options for a text response from the model.") - truncation: Optional[Literal["auto", "disabled"]] = Field(None, description="The truncation strategy to use for the model response.") - - def set_input_messages(self, messages: List[Message]) -> None: + Configuration for a /v1/responses API request. + + Attributes: + model (str): Model ID used to generate the response, like "gpt-4.1". Defaults to "gpt-4.1". + conversation (Optional[str]): The conversation this response belongs to. + include (Optional[List[Literal[...]]]): Specify additional output data to include. + input (Optional[str | List[Dict[str, str]]]): Text, image, or file inputs to the model. + instructions (Optional[str]): A system or developer message. + max_output_tokens (Optional[int]): Upper bound for generated tokens. + max_tool_calls (Optional[int]): Maximum number of tool calls allowed. + previous_response_id (Optional[str]): ID of the previous response for multi-turn. + prompt (Optional[ReusablePrompt]): Reference to a prompt template and its variables. + reasoning (Optional[ReasoningConfig]): Configuration for reasoning models. + text (Optional[object]): Configuration options for a text response from the model (e.g., JSON schema). + truncation (Optional[Literal["auto", "disabled"]]): The truncation strategy to use. + tools (Optional[List[object]]): An array of tools the model may call. + top_p (Optional[float]): An alternative to sampling with temperature (nucleus sampling). + parallel_tool_calls (Optional[bool]): Whether to allow parallel tool calls. + prompt_cache_key (Optional[str]): Used by OpenAI to cache responses. 
+ safety_identifier (Optional[str]): A stable identifier for policy monitoring. + service_tier (Optional[Literal["auto", "default", "flex", "priority"]]): Specifies the processing type. + store (Optional[bool]): Whether to store the generated model response. + temperature (Optional[float]): Sampling temperature to use (0 to 2). + tool_choice (Optional[str | object]): How the model should select which tool to use. + top_logprobs (Optional[int]): Number of most likely tokens to return at each position (0 to 20). + """ + + conversation: str | None = Field( + None, description="The conversation that this response belongs to." + ) + include: ( + list[ + Literal[ + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", + ] + ] + | None + ) = Field(None, description="Specify additional output data to include in the model response.") + input: str | list[dict[str, str]] | None = Field( + None, description="Text, image, or file inputs to the model, used to generate a response." + ) + instructions: str | None = Field( + None, description="A system (or developer) message inserted into the model's context." + ) + max_output_tokens: int | None = Field( + None, + gt=0, + description="An upper bound for the number of tokens that can be generated for a response, including visible output tokens and reasoning tokens.", + ) + max_tool_calls: int | None = Field( + None, + gt=0, + description="The maximum number of total calls to built-in tools that can be processed in a response.", + ) + previous_response_id: str | None = Field( + None, + description="The unique ID of the previous response to the model. Use this to create multi-turn conversations.", + ) + prompt: ReusablePrompt | None = Field( + None, description="Reference to a prompt template and its variables." 
+ ) + reasoning: ReasoningConfig | None = Field( + None, description="Configuration options for reasoning models." + ) + text: object | None = Field( + None, description="Configuration options for a text response from the model." + ) + truncation: Literal["auto", "disabled"] | None = Field( + None, description="The truncation strategy to use for the model response." + ) + + def set_input_messages(self, messages: list[Message]) -> None: self.input = [m.serialize() for m in messages] def set_output_structure(self, output_type: type[T]) -> None: @@ -291,10 +390,11 @@ def set_output_structure(self, output_type: type[T]) -> None: "type": "json_schema", "name": output_type.__name__, "schema": schema, - "strict": True + "strict": True, } } + class ChatCompletionsRequest(TextGenerationRequest): """ Configuration for a /v1/chat/completions API request. @@ -325,21 +425,57 @@ class ChatCompletionsRequest(TextGenerationRequest): tool_choice (Optional[str | object]): How the model should select which tool to use. top_logprobs (Optional[int]): Number of most likely tokens to return at each position (0 to 20). """ - messages: List[Dict[str, str]] = Field(None, description="A list of messages comprising the conversation so far.") - frequency_penalty: Optional[float] = Field(None, ge=-2, le=2, description="Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.") - logit_bias: Optional[Dict] = Field(None, description="Modify the likelihood of specified tokens appearing in the completion.") - logprobs: Optional[bool] = Field(None, description="Whether to return log probabilities of the output tokens or not.") - max_completion_tokens: Optional[int] = Field(None, gt=0, description="An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.") - modalities: Optional[List[str]] = Field(None, description="Output types that you would like the model to generate.") - n: Optional[int] = Field(None, description="How many chat completion choices to generate for each input message.") - prediction: Optional[object] = Field(None, description="Configuration for a Predicted Output, which can greatly improve response times when large parts of the model response are known ahead of time.") - presence_penalty: Optional[float] = Field(None, ge=-2, le=2, description="Number between -2.0 and 2.0. 
Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.") - reasoning_effort: Optional[Literal["minimal", "low", "medium", "high"]] = Field(None, description="Constrains effort on reasoning for reasoning models.") - response_format: Optional[Dict] = Field(None, description="An object specifying the format that the model must output.") - verbosity: Optional[Literal["low", "medium", "high"]] = Field(None, description="Constrains the verbosity of the model's response.") - web_search_options: Optional[object] = Field(None, description="This tool searches the web for relevant results to use in a response.") - - def set_input_messages(self, messages: List[Message]) -> None: + + messages: list[dict[str, str]] = Field( + [], description="A list of messages comprising the conversation so far." + ) + frequency_penalty: float | None = Field( + None, + ge=-2, + le=2, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.", + ) + logit_bias: dict | None = Field( + None, description="Modify the likelihood of specified tokens appearing in the completion." + ) + logprobs: bool | None = Field( + None, description="Whether to return log probabilities of the output tokens or not." + ) + max_completion_tokens: int | None = Field( + None, + gt=0, + description="An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.", + ) + modalities: list[str] | None = Field( + None, description="Output types that you would like the model to generate." + ) + n: int | None = Field( + None, description="How many chat completion choices to generate for each input message." 
+ ) + prediction: object | None = Field( + None, + description="Configuration for a Predicted Output, which can greatly improve response times when large parts of the model response are known ahead of time.", + ) + presence_penalty: float | None = Field( + None, + ge=-2, + le=2, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.", + ) + reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = Field( + None, description="Constrains effort on reasoning for reasoning models." + ) + response_format: dict | None = Field( + None, description="An object specifying the format that the model must output." + ) + verbosity: Literal["low", "medium", "high"] | None = Field( + None, description="Constrains the verbosity of the model's response." + ) + web_search_options: object | None = Field( + None, description="This tool searches the web for relevant results to use in a response." + ) + + def set_input_messages(self, messages: list[Message]) -> None: self.messages = [m.serialize() for m in messages] def set_output_structure(self, output_type: type[T]) -> None: @@ -349,10 +485,11 @@ def set_output_structure(self, output_type: type[T]) -> None: "type": "json_schema", "name": output_type.__name__, "schema": schema, - "strict": True + "strict": True, } } + class EmbeddingsRequest(BaseRequest): """ Configuration for a /v1/embeddings API request. @@ -364,14 +501,27 @@ class EmbeddingsRequest(BaseRequest): encoding_format (Optional[Literal["base64", "float"]]): The format to return the embeddings in. user (Optional[str]): A unique identifier representing the end-user. """ - input: Union[str | List[str]] = Field(None, description="Input text to embed, encoded as a string or array of tokens.") - dimensions: Optional[int] = Field(None, ge=1, description="The number of dimensions the resulting output embeddings should have. 
Only supported in text-embedding-3 and later models.") - encoding_format: Optional[Literal["base64", "float"]] = Field(None, description="The format to return the embeddings in. Can be either float or base64.") - user: Optional[str] = Field(None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. ") - def set_input(self, inp: Union[str | List[str]]) -> None: + input: str | list[str] = Field( + "", description="Input text to embed, encoded as a string or array of tokens." + ) + dimensions: int | None = Field( + None, + ge=1, + description="The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.", + ) + encoding_format: Literal["base64", "float"] | None = Field( + None, description="The format to return the embeddings in. Can be either float or base64." + ) + user: str | None = Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. ", + ) + + def set_input(self, inp: str | list[str]) -> None: self.input = inp + class RequestTemplate(BaseModel): """ A template defining a batch job, including its name, description, @@ -385,21 +535,22 @@ class RequestTemplate(BaseModel): request (BaseRequest): The API-specific request configuration (e.g., ResponsesRequest). metadata (Optional[Dict[Any, Any]]): Optional metadata associated with the request template. 
""" + name: str description: str - prompt: Union[PromptTemplate, ReusablePrompt] + prompt: PromptTemplate | ReusablePrompt request: BaseRequest - metadata: Optional[Dict[Any, Any]] = None + metadata: dict[Any, Any] | None = None def save(self, path: PathLike) -> None: if not str(path).endswith(".json"): - raise ValueError("RequestTemplate has to be saves as \".json\" file.") + raise ValueError('RequestTemplate has to be saves as ".json" file.') Path(path).parent.mkdir(parents=True, exist_ok=True) - with open(path, 'w+') as f: + with open(path, "w+") as f: f.write(self.model_dump_json(indent=4)) @classmethod - def load(cls, path: PathLike) -> Self: - with open(path, 'r') as f: + def load(cls, path: PathLike) -> "RequestTemplate": + with open(path) as f: return RequestTemplate.model_validate_json(f.read()) diff --git a/src/openbatch/validation.py b/src/openbatch/validation.py index 5aecbb8..382c3be 100644 --- a/src/openbatch/validation.py +++ b/src/openbatch/validation.py @@ -7,9 +7,9 @@ """ import json -from pathlib import Path -from typing import Union, List, Dict, Any, Set from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, ClassVar, TextIO @dataclass @@ -23,10 +23,11 @@ class ValidationResult: warnings (List[str]): List of non-critical warnings stats (Dict[str, Any]): Statistics about the batch file """ + is_valid: bool - errors: List[str] = field(default_factory=list) - warnings: List[str] = field(default_factory=list) - stats: Dict[str, Any] = field(default_factory=dict) + errors: list[str] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + stats: dict[str, Any] = field(default_factory=dict) def __str__(self) -> str: """Human-readable summary of validation results.""" @@ -64,21 +65,21 @@ class BatchFileValidator: """ # OpenAI Batch API constraints (as of 2026) - MAX_FILE_SIZE_MB = 200 - MAX_REQUESTS = 50000 - VALID_ENDPOINTS = { + MAX_FILE_SIZE_MB: int = 200 + MAX_REQUESTS: int = 50000 + 
VALID_ENDPOINTS: ClassVar[set[str]] = { "/v1/responses", "/v1/chat/completions", - "/v1/embeddings" + "/v1/embeddings", } - REQUIRED_FIELDS = {"custom_id", "method", "url", "body"} + REQUIRED_FIELDS: ClassVar[set[str]] = {"custom_id", "method", "url", "body"} def __init__( self, check_custom_id_uniqueness: bool = True, check_file_size: bool = True, check_request_count: bool = True, - allow_mixed_endpoints: bool = False + allow_mixed_endpoints: bool = False, ): """ Initialize the validator with configuration options. @@ -94,7 +95,7 @@ def __init__( self.check_request_count = check_request_count self.allow_mixed_endpoints = allow_mixed_endpoints - def validate_file(self, file_path: Union[str, Path]) -> ValidationResult: + def validate_file(self, file_path: str | Path) -> ValidationResult: """ Validate a batch file. @@ -115,9 +116,7 @@ def validate_file(self, file_path: Union[str, Path]) -> ValidationResult: # Check file extension if file_path.suffix != ".jsonl": - result.warnings.append( - f"File extension is '{file_path.suffix}', expected '.jsonl'" - ) + result.warnings.append(f"File extension is '{file_path.suffix}', expected '.jsonl'") # Check file size if self.check_file_size: @@ -126,25 +125,24 @@ def validate_file(self, file_path: Union[str, Path]) -> ValidationResult: if file_size_mb > self.MAX_FILE_SIZE_MB: result.errors.append( - f"File size ({file_size_mb:.2f} MB) exceeds limit " - f"({self.MAX_FILE_SIZE_MB} MB)" + f"File size ({file_size_mb:.2f} MB) exceeds limit ({self.MAX_FILE_SIZE_MB} MB)" ) result.is_valid = False # Validate content try: - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: self._validate_content(f, result) except Exception as e: - result.errors.append(f"Error reading file: {str(e)}") + result.errors.append(f"Error reading file: {e!s}") result.is_valid = False return result - def _validate_content(self, file_handle, result: ValidationResult) -> None: + def _validate_content(self, 
file_handle: TextIO, result: ValidationResult) -> None: """Validate the content of the batch file.""" - custom_ids: Set[str] = set() - endpoints: Set[str] = set() + custom_ids: set[str] = set() + endpoints: set[str] = set() line_number = 0 for line in file_handle: @@ -160,9 +158,7 @@ def _validate_content(self, file_handle, result: ValidationResult) -> None: try: request = json.loads(line) except json.JSONDecodeError as e: - result.errors.append( - f"Line {line_number}: Invalid JSON - {str(e)}" - ) + result.errors.append(f"Line {line_number}: Invalid JSON - {e!s}") result.is_valid = False continue @@ -190,20 +186,18 @@ def _validate_content(self, file_handle, result: ValidationResult) -> None: def _validate_request( self, - request: Dict[str, Any], + request: dict[str, Any], line_number: int, - custom_ids: Set[str], - endpoints: Set[str], - result: ValidationResult + custom_ids: set[str], + endpoints: set[str], + result: ValidationResult, ) -> None: """Validate a single request object.""" # Check required fields missing_fields = self.REQUIRED_FIELDS - set(request.keys()) if missing_fields: - result.errors.append( - f"Line {line_number}: Missing required fields: {missing_fields}" - ) + result.errors.append(f"Line {line_number}: Missing required fields: {missing_fields}") result.is_valid = False return @@ -216,9 +210,7 @@ def _validate_request( result.is_valid = False elif self.check_custom_id_uniqueness: if custom_id in custom_ids: - result.errors.append( - f"Line {line_number}: Duplicate custom_id '{custom_id}'" - ) + result.errors.append(f"Line {line_number}: Duplicate custom_id '{custom_id}'") result.is_valid = False else: custom_ids.add(custom_id) @@ -226,9 +218,7 @@ def _validate_request( # Validate method method = request.get("method") if method != "POST": - result.errors.append( - f"Line {line_number}: Invalid method '{method}' (must be 'POST')" - ) + result.errors.append(f"Line {line_number}: Invalid method '{method}' (must be 'POST')") result.is_valid = 
False # Validate URL @@ -245,27 +235,19 @@ def _validate_request( # Validate body body = request.get("body") if not isinstance(body, dict): - result.errors.append( - f"Line {line_number}: 'body' must be a JSON object" - ) + result.errors.append(f"Line {line_number}: 'body' must be a JSON object") result.is_valid = False else: - self._validate_body(body, url, line_number, result) + self._validate_body(body, str(url), line_number, result) def _validate_body( - self, - body: Dict[str, Any], - endpoint: str, - line_number: int, - result: ValidationResult + self, body: dict[str, Any], endpoint: str, line_number: int, result: ValidationResult ) -> None: """Validate the request body based on endpoint type.""" # Check for model field (required for all endpoints) if "model" not in body: - result.errors.append( - f"Line {line_number}: Missing required field 'model' in body" - ) + result.errors.append(f"Line {line_number}: Missing required field 'model' in body") result.is_valid = False # Endpoint-specific validation @@ -283,24 +265,19 @@ def _validate_body( ) result.is_valid = False elif not isinstance(body["messages"], list): - result.errors.append( - f"Line {line_number}: 'messages' must be an array" - ) + result.errors.append(f"Line {line_number}: 'messages' must be an array") result.is_valid = False - elif endpoint == "/v1/embeddings": - if "input" not in body: - result.errors.append( - f"Line {line_number}: Embeddings API requires 'input' in body" - ) - result.is_valid = False + elif endpoint == "/v1/embeddings" and "input" not in body: + result.errors.append(f"Line {line_number}: Embeddings API requires 'input' in body") + result.is_valid = False def validate_batch_file( - file_path: Union[str, Path], + file_path: str | Path, strict: bool = True, check_custom_id_uniqueness: bool = True, - allow_mixed_endpoints: bool = False + allow_mixed_endpoints: bool = False, ) -> ValidationResult: """ Validate a batch file (convenience function). 
@@ -325,12 +302,12 @@ def validate_batch_file( check_custom_id_uniqueness=check_custom_id_uniqueness, check_file_size=strict, check_request_count=strict, - allow_mixed_endpoints=allow_mixed_endpoints + allow_mixed_endpoints=allow_mixed_endpoints, ) return validator.validate_file(file_path) -def quick_validate(file_path: Union[str, Path]) -> bool: +def quick_validate(file_path: str | Path) -> bool: """ Quick validation check (returns True/False). diff --git a/tests/test_collector.py b/tests/test_collector.py index e4d8588..ec5ae66 100644 --- a/tests/test_collector.py +++ b/tests/test_collector.py @@ -1,7 +1,9 @@ import json + import pytest from pydantic import BaseModel, Field -from openbatch.collector import BatchCollector, Responses, ChatCompletions, Embeddings + +from openbatch.collector import BatchCollector, ChatCompletions, Embeddings, Responses from openbatch.model import ReasoningConfig @@ -23,7 +25,7 @@ def test_responses_create(self, temp_batch_file): ) assert temp_batch_file.exists() - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["custom_id"] == "req_1" @@ -42,7 +44,7 @@ def test_responses_parse_without_format(self, temp_batch_file): instructions="Be concise", ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["custom_id"] == "req_2" @@ -63,7 +65,7 @@ class Analysis(BaseModel): instructions="Analyze sentiment", ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["custom_id"] == "req_3" @@ -81,7 +83,7 @@ def test_responses_with_reasoning_config(self, temp_batch_file): reasoning=ReasoningConfig(effort="high", summary="detailed"), ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["body"]["reasoning"]["effort"] == "high" @@ -92,7 +94,7 @@ def test_responses_multiple_requests(self, 
temp_batch_file): responses.create(custom_id="req_1", model="gpt-4", input="First") responses.create(custom_id="req_2", model="gpt-4", input="Second") - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: lines = f.readlines() assert len(lines) == 2 @@ -116,7 +118,7 @@ def test_chat_completions_create(self, temp_batch_file): ) assert temp_batch_file.exists() - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["custom_id"] == "chat_1" @@ -133,7 +135,7 @@ def test_chat_completions_parse_without_format(self, temp_batch_file): messages=[{"role": "user", "content": "Hi"}], ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["custom_id"] == "chat_2" @@ -152,7 +154,7 @@ class Response(BaseModel): messages=[{"role": "user", "content": "What is 2+2?"}], ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["custom_id"] == "chat_3" @@ -169,23 +171,21 @@ def test_chat_completions_with_reasoning_effort(self, temp_batch_file): reasoning_effort="high", ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["body"]["reasoning_effort"] == "high" def test_chat_completions_multiple_requests(self, temp_batch_file): chat = ChatCompletions(temp_batch_file) - chat.create( - custom_id="chat_1", model="gpt-4", messages=[{"role": "user", "content": "Hi"}] - ) + chat.create(custom_id="chat_1", model="gpt-4", messages=[{"role": "user", "content": "Hi"}]) chat.create( custom_id="chat_2", model="gpt-4", messages=[{"role": "user", "content": "Bye"}], ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: lines = f.readlines() assert len(lines) == 2 @@ -201,7 +201,7 @@ def test_embeddings_create_single_input(self, temp_batch_file): ) assert temp_batch_file.exists() - with 
open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["custom_id"] == "emb_1" @@ -217,7 +217,7 @@ def test_embeddings_create_list_input(self, temp_batch_file): inp=["Text 1", "Text 2", "Text 3"], ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert isinstance(data["body"]["input"], list) @@ -232,7 +232,7 @@ def test_embeddings_with_dimensions(self, temp_batch_file): dimensions=512, ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["body"]["dimensions"] == 512 @@ -242,7 +242,7 @@ def test_embeddings_multiple_requests(self, temp_batch_file): embeddings.create(custom_id="emb_1", model="text-embedding-3-small", inp="First") embeddings.create(custom_id="emb_2", model="text-embedding-3-small", inp="Second") - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: lines = f.readlines() assert len(lines) == 2 @@ -264,7 +264,7 @@ def test_batch_collector_responses_api(self, temp_batch_file): ) assert temp_batch_file.exists() - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["url"] == "/v1/responses" @@ -278,7 +278,7 @@ def test_batch_collector_chat_completions_api(self, temp_batch_file): ) assert temp_batch_file.exists() - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["url"] == "/v1/chat/completions" @@ -292,7 +292,7 @@ def test_batch_collector_embeddings_api(self, temp_batch_file): ) assert temp_batch_file.exists() - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["url"] == "/v1/embeddings" @@ -304,9 +304,7 @@ def test_batch_collector_mixed_apis_in_sequence(self, tmp_path): embeddings_file = tmp_path / "embeddings.jsonl" responses_collector = 
BatchCollector(responses_file) - responses_collector.responses.create( - custom_id="req_1", model="gpt-4", input="Test" - ) + responses_collector.responses.create(custom_id="req_1", model="gpt-4", input="Test") chat_collector = BatchCollector(chat_file) chat_collector.chat.completions.create( @@ -339,7 +337,7 @@ class TaskAnalysis(BaseModel): instructions="Provide structured analysis", ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert "text" in data["body"] @@ -363,7 +361,7 @@ class CodeReview(BaseModel): ], ) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert "response_format" in data["body"] diff --git a/tests/test_integration.py b/tests/test_integration.py index d342ed2..817fa94 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,18 +1,20 @@ """Integration tests that verify end-to-end workflows.""" + import json + import pytest from pydantic import BaseModel, Field + from openbatch import ( BatchCollector, BatchJobManager, + EmbeddingInputInstance, + EmbeddingsRequest, Message, PromptTemplate, PromptTemplateInputInstance, - EmbeddingInputInstance, - ResponsesRequest, - ChatCompletionsRequest, - EmbeddingsRequest, ReasoningConfig, + ResponsesRequest, ) @@ -67,15 +69,15 @@ def test_batch_collector_complete_workflow(self, temp_dir): assert chat_file.exists() assert embeddings_file.exists() - with open(responses_file, "r") as f: + with open(responses_file) as f: resp_data = json.loads(f.readline()) assert resp_data["url"] == "/v1/responses" - with open(chat_file, "r") as f: + with open(chat_file) as f: chat_data = json.loads(f.readline()) assert chat_data["url"] == "/v1/chat/completions" - with open(embeddings_file, "r") as f: + with open(embeddings_file) as f: emb_data = json.loads(f.readline()) assert emb_data["url"] == "/v1/embeddings" @@ -97,9 +99,7 @@ def test_batch_job_manager_templated_workflow(self, 
temp_dir): ] ) - common_config = ResponsesRequest( - model="gpt-4-mini", temperature=0.8, max_output_tokens=100 - ) + common_config = ResponsesRequest(model="gpt-4-mini", temperature=0.8, max_output_tokens=100) products = [ PromptTemplateInputInstance( @@ -137,7 +137,7 @@ def test_batch_job_manager_templated_workflow(self, temp_dir): # Verify assert batch_file.exists() - with open(batch_file, "r") as f: + with open(batch_file) as f: lines = f.readlines() assert len(lines) == 3 @@ -157,9 +157,7 @@ def test_batch_job_manager_embeddings_workflow(self, temp_dir): """Test BatchJobManager for bulk embeddings generation.""" batch_file = temp_dir / "embeddings_batch.jsonl" - common_config = EmbeddingsRequest( - model="text-embedding-3-small", dimensions=512 - ) + common_config = EmbeddingsRequest(model="text-embedding-3-small", dimensions=512) documents = [ EmbeddingInputInstance(id="doc_1", input="The sky is blue."), @@ -178,7 +176,7 @@ def test_batch_job_manager_embeddings_workflow(self, temp_dir): ) assert batch_file.exists() - with open(batch_file, "r") as f: + with open(batch_file) as f: lines = f.readlines() assert len(lines) == 4 @@ -220,7 +218,7 @@ class SentimentAnalysis(BaseModel): ) assert batch_file.exists() - with open(batch_file, "r") as f: + with open(batch_file) as f: lines = f.readlines() assert len(lines) == 3 @@ -267,7 +265,7 @@ class RecipeExtraction(BaseModel): ], ) - with open(batch_file, "r") as f: + with open(batch_file) as f: lines = f.readlines() assert len(lines) == 2 @@ -306,7 +304,7 @@ def test_responses_api_with_reasoning_config(self, temp_dir): reasoning=ReasoningConfig(effort=item["effort"], summary="detailed"), ) - with open(batch_file, "r") as f: + with open(batch_file) as f: lines = f.readlines() assert len(lines) == 2 @@ -326,7 +324,7 @@ def test_ensure_ascii_true(self, temp_dir): request = ResponsesRequest(model="gpt-4", input="Hello 世界! Привет! 
مرحبا") manager.add("unicode_test", request, batch_file) - with open(batch_file, "r", encoding="utf-8") as f: + with open(batch_file, encoding="utf-8") as f: content = f.read() # Should contain escaped unicode @@ -339,12 +337,10 @@ def test_ensure_ascii_false(self, temp_dir): batch_file = temp_dir / "unicode_raw.jsonl" manager = BatchJobManager(ensure_ascii=False) - request = ResponsesRequest( - model="gpt-4", input="Hello 世界! Привет! مرحبا Emoji: 🚀" - ) + request = ResponsesRequest(model="gpt-4", input="Hello 世界! Привет! مرحبا Emoji: 🚀") manager.add("unicode_test", request, batch_file) - with open(batch_file, "r", encoding="utf-8") as f: + with open(batch_file, encoding="utf-8") as f: content = f.read() # Raw characters should be present @@ -361,9 +357,7 @@ def test_generate_1000_requests(self, temp_dir): """Test generating a batch file with 1000 requests.""" batch_file = temp_dir / "large_batch.jsonl" - template = PromptTemplate( - messages=[Message(role="user", content="Classify: {text}")] - ) + template = PromptTemplate(messages=[Message(role="user", content="Classify: {text}")]) common_request = ResponsesRequest(model="gpt-4-mini", max_output_tokens=10) @@ -385,7 +379,7 @@ def test_generate_1000_requests(self, temp_dir): # Verify assert batch_file.exists() - with open(batch_file, "r") as f: + with open(batch_file) as f: lines = f.readlines() assert len(lines) == 1000 diff --git a/tests/test_manager.py b/tests/test_manager.py index b979366..cbb0b3e 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,16 +1,18 @@ import json -import pytest import warnings + +import pytest + from openbatch.manager import BatchJobManager from openbatch.model import ( + ChatCompletionsRequest, + EmbeddingInputInstance, + EmbeddingsRequest, Message, PromptTemplate, - ReusablePrompt, PromptTemplateInputInstance, - EmbeddingInputInstance, ResponsesRequest, - ChatCompletionsRequest, - EmbeddingsRequest, + ReusablePrompt, ) @@ -38,7 +40,7 @@ def 
test_add_responses_request(self, manager, temp_batch_file): manager.add("test_id", request, temp_batch_file) assert temp_batch_file.exists() - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: line = f.readline() data = json.loads(line) @@ -55,7 +57,7 @@ def test_add_chat_completions_request(self, manager, temp_batch_file): manager.add("chat_id", request, temp_batch_file) assert temp_batch_file.exists() - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["custom_id"] == "chat_id" @@ -67,7 +69,7 @@ def test_add_embeddings_request(self, manager, temp_batch_file): manager.add("emb_id", request, temp_batch_file) assert temp_batch_file.exists() - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["custom_id"] == "emb_id" @@ -81,7 +83,7 @@ def test_add_multiple_requests(self, manager, temp_batch_file): manager.add("id1", request1, temp_batch_file) manager.add("id2", request2, temp_batch_file) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: lines = f.readlines() assert len(lines) == 2 @@ -90,16 +92,12 @@ def test_add_multiple_requests(self, manager, temp_batch_file): assert data1["custom_id"] == "id1" assert data2["custom_id"] == "id2" - def test_add_responses_request_without_input_or_prompt_raises( - self, manager, temp_batch_file - ): + def test_add_responses_request_without_input_or_prompt_raises(self, manager, temp_batch_file): request = ResponsesRequest(model="gpt-4") with pytest.raises(ValueError, match="must define either an input or a prompt"): manager.add("test_id", request, temp_batch_file) - def test_add_chat_completions_request_without_messages_raises( - self, manager, temp_batch_file - ): + def test_add_chat_completions_request_without_messages_raises(self, manager, temp_batch_file): # messages is required in the Pydantic model, so we create a request # and then set messages to 
None manually request = ChatCompletionsRequest(model="gpt-4", messages=[]) @@ -127,7 +125,7 @@ def test_add_with_ensure_ascii_false(self, manager_no_ascii, temp_batch_file): request = ResponsesRequest(model="gpt-4", input="Hello 世界") manager_no_ascii.add("test_id", request, temp_batch_file) - with open(temp_batch_file, "r", encoding="utf-8") as f: + with open(temp_batch_file, encoding="utf-8") as f: content = f.read() data = json.loads(content) @@ -138,7 +136,7 @@ def test_add_with_ensure_ascii_true(self, manager, temp_batch_file): request = ResponsesRequest(model="gpt-4", input="Hello 世界") manager.add("test_id", request, temp_batch_file) - with open(temp_batch_file, "r", encoding="utf-8") as f: + with open(temp_batch_file, encoding="utf-8") as f: raw_content = f.read() # ASCII escaped version should not contain the raw unicode characters @@ -162,7 +160,7 @@ def test_add_templated_instances_responses_api(self, manager, temp_batch_file): manager.add_templated_instances(template, common_request, instances, temp_batch_file) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: lines = f.readlines() assert len(lines) == 2 @@ -194,7 +192,7 @@ def test_add_templated_instances_chat_completions_api(self, manager, temp_batch_ manager.add_templated_instances(template, common_request, instances, temp_batch_file) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: lines = f.readlines() assert len(lines) == 2 @@ -210,11 +208,9 @@ def test_add_templated_instances_with_reusable_prompt(self, manager, temp_batch_ PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={"var": "value"}) ] - manager.add_templated_instances( - reusable_prompt, common_request, instances, temp_batch_file - ) + manager.add_templated_instances(reusable_prompt, common_request, instances, temp_batch_file) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert data["body"]["prompt"]["id"] == 
"prompt_123" @@ -230,14 +226,12 @@ def test_add_templated_instances_with_instance_options(self, manager, temp_batch prompt_value_mapping={"text": "Hello"}, instance_request_options={"temperature": 0.9}, ), - PromptTemplateInputInstance( - id="inst_2", prompt_value_mapping={"text": "World"} - ), + PromptTemplateInputInstance(id="inst_2", prompt_value_mapping={"text": "World"}), ] manager.add_templated_instances(template, common_request, instances, temp_batch_file) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: lines = f.readlines() data1 = json.loads(lines[0]) @@ -250,17 +244,13 @@ def test_add_templated_instances_with_instance_options(self, manager, temp_batch def test_add_templated_instances_with_embeddings_raises(self, manager, temp_batch_file): template = PromptTemplate(messages=[Message(role="user", content="Test")]) - common_request = EmbeddingsRequest( - model="text-embedding-3-small", input="dummy" - ) + common_request = EmbeddingsRequest(model="text-embedding-3-small", input="dummy") instances = [ PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={"text": "Test"}) ] with pytest.raises(ValueError, match="Embeddings API is not supported"): - manager.add_templated_instances( - template, common_request, instances, temp_batch_file - ) + manager.add_templated_instances(template, common_request, instances, temp_batch_file) def test_add_templated_instances_reusable_prompt_with_chat_raises( self, manager, temp_batch_file @@ -282,16 +272,12 @@ def test_add_templated_instances_appending_warning(self, manager, temp_batch_fil template = PromptTemplate(messages=[Message(role="user", content="Test")]) common_request = ResponsesRequest(model="gpt-4") - instances = [ - PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={}) - ] + instances = [PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={})] # Should warn when appending to existing file with warnings.catch_warnings(record=True) as w: 
warnings.simplefilter("always") - manager.add_templated_instances( - template, common_request, instances, temp_batch_file - ) + manager.add_templated_instances(template, common_request, instances, temp_batch_file) assert len(w) == 1 assert "already exists" in str(w[0].message) @@ -300,9 +286,7 @@ def test_add_templated_instances_suppress_warnings(self, manager, temp_batch_fil template = PromptTemplate(messages=[Message(role="user", content="Test")]) common_request = ResponsesRequest(model="gpt-4") - instances = [ - PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={}) - ] + instances = [PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={})] # Should not warn when suppress_warnings=True with warnings.catch_warnings(record=True) as w: @@ -319,9 +303,7 @@ def test_add_templated_instances_suppress_warnings(self, manager, temp_batch_fil class TestBatchJobManagerEmbeddingRequests: def test_add_embedding_requests(self, manager, temp_batch_file): - common_request = EmbeddingsRequest( - model="text-embedding-3-small", dimensions=512 - ) + common_request = EmbeddingsRequest(model="text-embedding-3-small", dimensions=512) inputs = [ EmbeddingInputInstance(id="emb_1", input="First text"), EmbeddingInputInstance(id="emb_2", input="Second text"), @@ -329,7 +311,7 @@ def test_add_embedding_requests(self, manager, temp_batch_file): manager.add_embedding_requests(inputs, common_request, temp_batch_file) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: lines = f.readlines() assert len(lines) == 2 @@ -351,16 +333,14 @@ def test_add_embedding_requests_with_list_input(self, manager, temp_batch_file): manager.add_embedding_requests(inputs, common_request, temp_batch_file) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: data = json.loads(f.readline()) assert isinstance(data["body"]["input"], list) assert len(data["body"]["input"]) == 2 def test_add_embedding_requests_with_instance_options(self, manager, 
temp_batch_file): - common_request = EmbeddingsRequest( - model="text-embedding-3-small", dimensions=512 - ) + common_request = EmbeddingsRequest(model="text-embedding-3-small", dimensions=512) inputs = [ EmbeddingInputInstance( id="emb_1", @@ -372,7 +352,7 @@ def test_add_embedding_requests_with_instance_options(self, manager, temp_batch_ manager.add_embedding_requests(inputs, common_request, temp_batch_file) - with open(temp_batch_file, "r") as f: + with open(temp_batch_file) as f: lines = f.readlines() data1 = json.loads(lines[0]) diff --git a/tests/test_model.py b/tests/test_model.py index d0f4962..16c8a98 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,19 +1,19 @@ -import pytest from pydantic import BaseModel, Field + from openbatch.model import ( + ChatCompletionsAPIStrategy, + ChatCompletionsRequest, + EmbeddingInputInstance, + EmbeddingsAPIStrategy, + EmbeddingsRequest, Message, + MessagesInputInstance, PromptTemplate, - ReusablePrompt, - ReasoningConfig, PromptTemplateInputInstance, - MessagesInputInstance, - EmbeddingInputInstance, - ResponsesRequest, - ChatCompletionsRequest, - EmbeddingsRequest, + ReasoningConfig, ResponsesAPIStrategy, - ChatCompletionsAPIStrategy, - EmbeddingsAPIStrategy, + ResponsesRequest, + ReusablePrompt, ) @@ -231,9 +231,7 @@ def test_embeddings_request_with_string(self): assert request.input == "Hello" def test_embeddings_request_with_list(self): - request = EmbeddingsRequest( - model="text-embedding-3-small", input=["Hello", "World"] - ) + request = EmbeddingsRequest(model="text-embedding-3-small", input=["Hello", "World"]) assert isinstance(request.input, list) assert len(request.input) == 2 @@ -243,15 +241,11 @@ def test_embeddings_request_set_input(self): assert request.input == "New text" def test_embeddings_request_with_dimensions(self): - request = EmbeddingsRequest( - model="text-embedding-3-small", input="test", dimensions=512 - ) + request = EmbeddingsRequest(model="text-embedding-3-small", 
input="test", dimensions=512) assert request.dimensions == 512 def test_embeddings_request_to_dict(self): - request = EmbeddingsRequest( - model="text-embedding-3-small", input="test", dimensions=256 - ) + request = EmbeddingsRequest(model="text-embedding-3-small", input="test", dimensions=256) result = request.to_dict() assert result["model"] == "text-embedding-3-small" assert result["input"] == "test" diff --git a/tests/test_utils.py b/tests/test_utils.py index 2766e6d..4f89994 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,11 @@ import pytest from pydantic import BaseModel, Field -from typing import Optional, List + from openbatch._utils import ( - type_to_json_schema, _ensure_strict_json_schema, has_more_than_n_keys, resolve_ref, + type_to_json_schema, ) @@ -34,9 +34,7 @@ def test_resolve_simple_ref(self): assert result == {"type": "object"} def test_resolve_nested_ref(self): - root = { - "definitions": {"Nested": {"properties": {"field": {"type": "string"}}}} - } + root = {"definitions": {"Nested": {"properties": {"field": {"type": "string"}}}}} result = resolve_ref(root=root, ref="#/definitions/Nested/properties/field") assert result == {"type": "string"} @@ -100,7 +98,10 @@ def test_array_with_object_items(self): def test_any_of_union(self): schema = { - "anyOf": [{"type": "string"}, {"type": "object", "properties": {"a": {"type": "string"}}}] + "anyOf": [ + {"type": "string"}, + {"type": "object", "properties": {"a": {"type": "string"}}}, + ] } result = _ensure_strict_json_schema(schema, path=(), root=schema) assert len(result["anyOf"]) == 2 @@ -128,9 +129,7 @@ def test_definitions_processed(self): schema = { "type": "object", "properties": {"user": {"$ref": "#/definitions/User"}}, - "definitions": { - "User": {"type": "object", "properties": {"name": {"type": "string"}}} - }, + "definitions": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}, } result = _ensure_strict_json_schema(schema, path=(), root=schema) 
assert result["definitions"]["User"]["additionalProperties"] is False @@ -140,9 +139,7 @@ def test_defs_processed(self): schema = { "type": "object", "properties": {"user": {"$ref": "#/$defs/User"}}, - "$defs": { - "User": {"type": "object", "properties": {"name": {"type": "string"}}} - }, + "$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}, } result = _ensure_strict_json_schema(schema, path=(), root=schema) assert result["$defs"]["User"]["additionalProperties"] is False @@ -157,9 +154,7 @@ def test_ref_with_additional_properties_unrolled(self): "description": "The user object", } }, - "definitions": { - "User": {"type": "object", "properties": {"name": {"type": "string"}}} - }, + "definitions": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}, } result = _ensure_strict_json_schema(schema, path=(), root=schema) # The $ref should be unrolled when there are additional properties @@ -186,7 +181,7 @@ class SimpleModel(BaseModel): def test_model_with_optional_field(self): class ModelWithOptional(BaseModel): name: str - nickname: Optional[str] = None + nickname: str | None = None schema = type_to_json_schema(ModelWithOptional) # All properties should be required in strict mode @@ -217,7 +212,7 @@ class Person(BaseModel): def test_model_with_list_field(self): class TodoList(BaseModel): title: str - items: List[str] + items: list[str] schema = type_to_json_schema(TodoList) assert schema["properties"]["items"]["type"] == "array" @@ -239,12 +234,12 @@ class Item(BaseModel): class Order(BaseModel): order_id: str - items: List[Item] + items: list[Item] total: float class Customer(BaseModel): name: str - orders: List[Order] + orders: list[Order] schema = type_to_json_schema(Customer) assert schema["additionalProperties"] is False diff --git a/tests/test_validation.py b/tests/test_validation.py index d7b59c2..27593d3 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,11 +1,13 @@ """Tests for batch 
file validation.""" import json + import pytest + from openbatch.validation import ( - validate_batch_file, - quick_validate, ValidationResult, + quick_validate, + validate_batch_file, ) @@ -21,7 +23,7 @@ def test_validation_result_str(self): is_valid=False, errors=["Error 1", "Error 2"], warnings=["Warning 1"], - stats={"total_requests": 10, "file_size_mb": 0.5} + stats={"total_requests": 10, "file_size_mb": 0.5}, ) output = str(result) assert "FAILED" in output @@ -43,14 +45,14 @@ def test_valid_batch_file(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "Hello"} + "body": {"model": "gpt-4", "input": "Hello"}, }, { "custom_id": "req_2", "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "World"} - } + "body": {"model": "gpt-4", "input": "World"}, + }, ] with open(temp_batch_file, "w") as f: @@ -85,14 +87,14 @@ def test_duplicate_custom_ids(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "Hello"} + "body": {"model": "gpt-4", "input": "Hello"}, }, { "custom_id": "req_1", # Duplicate "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "World"} - } + "body": {"model": "gpt-4", "input": "World"}, + }, ] with open(temp_batch_file, "w") as f: @@ -123,7 +125,7 @@ def test_invalid_method(self, temp_batch_file): "custom_id": "req_1", "method": "GET", # Should be POST "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "Hello"} + "body": {"model": "gpt-4", "input": "Hello"}, } with open(temp_batch_file, "w") as f: @@ -139,7 +141,7 @@ def test_invalid_endpoint(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/invalid", # Invalid endpoint - "body": {"model": "gpt-4"} + "body": {"model": "gpt-4"}, } with open(temp_batch_file, "w") as f: @@ -155,7 +157,7 @@ def test_responses_api_missing_input(self, temp_batch_file): "custom_id": 
"req_1", "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4"} # Missing input or prompt + "body": {"model": "gpt-4"}, # Missing input or prompt } with open(temp_batch_file, "w") as f: @@ -171,7 +173,7 @@ def test_chat_completions_missing_messages(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/chat/completions", - "body": {"model": "gpt-4"} # Missing messages + "body": {"model": "gpt-4"}, # Missing messages } with open(temp_batch_file, "w") as f: @@ -187,7 +189,7 @@ def test_chat_completions_invalid_messages(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/chat/completions", - "body": {"model": "gpt-4", "messages": "not an array"} + "body": {"model": "gpt-4", "messages": "not an array"}, } with open(temp_batch_file, "w") as f: @@ -203,7 +205,7 @@ def test_embeddings_missing_input(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/embeddings", - "body": {"model": "text-embedding-3-small"} # Missing input + "body": {"model": "text-embedding-3-small"}, # Missing input } with open(temp_batch_file, "w") as f: @@ -220,14 +222,14 @@ def test_mixed_endpoints_warning(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "Hello"} + "body": {"model": "gpt-4", "input": "Hello"}, }, { "custom_id": "req_2", "method": "POST", "url": "/v1/embeddings", - "body": {"model": "text-embedding-3-small", "input": "World"} - } + "body": {"model": "text-embedding-3-small", "input": "World"}, + }, ] with open(temp_batch_file, "w") as f: @@ -241,9 +243,13 @@ def test_mixed_endpoints_warning(self, temp_batch_file): def test_empty_lines_warning(self, temp_batch_file): """Test warning for empty lines.""" with open(temp_batch_file, "w") as f: - f.write('{"custom_id": "req_1", "method": "POST", "url": "/v1/responses", "body": {"model": "gpt-4", "input": "Hi"}}\n') + f.write( + '{"custom_id": "req_1", "method": "POST", "url": 
"/v1/responses", "body": {"model": "gpt-4", "input": "Hi"}}\n' + ) f.write("\n") # Empty line - f.write('{"custom_id": "req_2", "method": "POST", "url": "/v1/responses", "body": {"model": "gpt-4", "input": "Bye"}}\n') + f.write( + '{"custom_id": "req_2", "method": "POST", "url": "/v1/responses", "body": {"model": "gpt-4", "input": "Bye"}}\n' + ) result = validate_batch_file(temp_batch_file) assert any("empty line" in warn.lower() for warn in result.warnings) @@ -255,7 +261,7 @@ def test_wrong_file_extension_warning(self, tmp_path): "custom_id": "req_1", "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "Hello"} + "body": {"model": "gpt-4", "input": "Hello"}, } with open(json_file, "w") as f: @@ -270,7 +276,7 @@ def test_missing_model_in_body(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/responses", - "body": {"input": "Hello"} # Missing model + "body": {"input": "Hello"}, # Missing model } with open(temp_batch_file, "w") as f: @@ -286,7 +292,7 @@ def test_invalid_custom_id_type(self, temp_batch_file): "custom_id": 123, # Should be string "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "Hello"} + "body": {"model": "gpt-4", "input": "Hello"}, } with open(temp_batch_file, "w") as f: @@ -302,7 +308,7 @@ def test_empty_custom_id(self, temp_batch_file): "custom_id": "", # Empty string "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "Hello"} + "body": {"model": "gpt-4", "input": "Hello"}, } with open(temp_batch_file, "w") as f: @@ -318,7 +324,7 @@ def test_body_not_object(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/responses", - "body": "not an object" + "body": "not an object", } with open(temp_batch_file, "w") as f: @@ -335,14 +341,14 @@ def test_skip_custom_id_check(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "Hello"} + "body": 
{"model": "gpt-4", "input": "Hello"}, }, { "custom_id": "req_1", # Duplicate "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "World"} - } + "body": {"model": "gpt-4", "input": "World"}, + }, ] with open(temp_batch_file, "w") as f: @@ -361,7 +367,7 @@ def test_quick_validate_true(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "Hello"} + "body": {"model": "gpt-4", "input": "Hello"}, } with open(temp_batch_file, "w") as f: @@ -386,7 +392,7 @@ def test_large_valid_file(self, temp_batch_file): "custom_id": f"req_{i}", "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": f"Request {i}"} + "body": {"model": "gpt-4", "input": f"Request {i}"}, } f.write(json.dumps(request) + "\n") @@ -402,20 +408,20 @@ def test_all_three_endpoints(self, temp_batch_file): "custom_id": "req_1", "method": "POST", "url": "/v1/responses", - "body": {"model": "gpt-4", "input": "Hello"} + "body": {"model": "gpt-4", "input": "Hello"}, }, { "custom_id": "req_2", "method": "POST", "url": "/v1/chat/completions", - "body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]} + "body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}, }, { "custom_id": "req_3", "method": "POST", "url": "/v1/embeddings", - "body": {"model": "text-embedding-3-small", "input": "Text"} - } + "body": {"model": "text-embedding-3-small", "input": "Text"}, + }, ] with open(temp_batch_file, "w") as f: From ee4a5ffdd638b346359042587a85497b074f6433 Mon Sep 17 00:00:00 2001 From: Daniel Gomm Date: Tue, 10 Feb 2026 10:19:23 +0100 Subject: [PATCH 7/9] :memo: add contributing guide --- CONTRIBUTING.md | 433 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 433 insertions(+) create mode 100644 CONTRIBUTING.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..cfe1927 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,433 @@ 
+# Contributing to OpenBatch + +Thank you for your interest in contributing to openbatch! This document provides guidelines and instructions for contributing to the project. + +## Table of Contents + +- [Code of Conduct](#code-of-conduct) +- [Getting Started](#getting-started) +- [Development Setup](#development-setup) +- [Development Workflow](#development-workflow) +- [Branch Naming Convention](#branch-naming-convention) +- [Commit Message Convention](#commit-message-convention) +- [Running Tests](#running-tests) +- [Code Quality](#code-quality) +- [Continuous Integration](#continuous-integration) +- [Opening Issues](#opening-issues) +- [Submitting Pull Requests](#submitting-pull-requests) +- [Documentation](#documentation) + +## Code of Conduct + +Please be respectful and constructive in all interactions. + +## Getting Started + +1. **Fork the repository** on GitHub +2. **Clone your fork** locally: + ```bash + git clone https://github.com/YOUR_USERNAME/openbatch.git + cd openbatch + ``` +3. **Add the upstream repository**: + ```bash + git remote add upstream https://github.com/daniel-gomm/openbatch.git + ``` + +## Development Setup + +### Prerequisites + +- Python 3.11 or higher +- pip +- git + +### Installation + +1. **Create a virtual environment** (recommended): + ```bash + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + ``` + +2. **Install the package in editable mode with development dependencies**: + ```bash + pip install -e ".[dev]" + ``` + +3. **Install pre-commit hooks**: + ```bash + pre-commit install + ``` + + This ensures code quality checks run automatically before each commit. + +### Verify Installation + +```bash +# Run tests +pytest + +# Check linting +ruff check src/ tests/ + +# Check formatting +ruff format --check src/ tests/ + +# Run type checking +mypy src/ +``` + +## Development Workflow + +1. 
**Sync your fork** with the upstream repository: + ```bash + git checkout main + git fetch upstream + git merge upstream/main + git push origin main + ``` + +2. **Create a new branch** for your work (see [Branch Naming Convention](#branch-naming-convention)): + ```bash + git checkout -b feature/your-feature-name + ``` + +3. **Make your changes** and commit them (see [Commit Message Convention](#commit-message-convention)) + +4. **Push your changes** to your fork: + ```bash + git push origin feature/your-feature-name + ``` + +5. **Open a Pull Request** on GitHub + +## Branch Naming Convention + +We use the following prefixes for branch names: + +- `feature/` - New features or enhancements + - Example: `feature/add-retry-mechanism` +- `fix/` - Bug fixes + - Example: `fix/validation-error-handling` +- `documentation/` - Documentation updates + - Example: `documentation/improve-api-examples` + +**Format**: `/` + +## Commit Message Convention + +We use [Gitmoji](https://gitmoji.dev/) to prefix commit messages with relevant emojis that indicate the nature of the change. + +### Common Gitmojis + +| Emoji | Code | Description | +|-------|------|-------------| +| ✨ | `:sparkles:` | Introduce new features | +| 🐛 | `:bug:` | Fix a bug | +| 📝 | `:memo:` | Add or update documentation | +| ✅ | `:white_check_mark:` | Add, update, or pass tests | +| ♻️ | `:recycle:` | Refactor code | +| 🎨 | `:art:` | Improve structure/format of the code | +| ⚡️ | `:zap:` | Improve performance | +| 🔒️ | `:lock:` | Fix security issues | +| ⬆️ | `:arrow_up:` | Upgrade dependencies | +| ⬇️ | `:arrow_down:` | Downgrade dependencies | +| 🔧 | `:wrench:` | Add or update configuration files | +| 🚀 | `:rocket:` | Deploy stuff | +| 💚 | `:green_heart:` | Fix CI build | + +See the full list at [gitmoji.dev](https://gitmoji.dev/). 
+ +### Commit Message Format + +``` + + +[Optional detailed description] + +[Optional footer with issue references] +``` + +### Examples + +```bash +# Adding a new feature +git commit -m "✨ Add batch file validation module" + +# Fixing a bug +git commit -m "🐛 Fix custom_id uniqueness check in validator" + +# Updating documentation +git commit -m "📝 Add validation usage examples to README" + +# Adding tests +git commit -m "✅ Add integration tests for validation module" + +# Refactoring +git commit -m "♻️ Refactor BatchJobManager to use strategy pattern" +``` + +## Running Tests + +### Run All Tests + +```bash +pytest +``` + +### Run Tests with Coverage + +```bash +pytest --cov=src/openbatch --cov-report=html +``` + +View the coverage report by opening `htmlcov/index.html` in your browser. + +### Run Specific Tests + +```bash +# Run a specific test file +pytest tests/test_validation.py + +# Run a specific test class +pytest tests/test_validation.py::TestBatchFileValidator + +# Run a specific test method +pytest tests/test_validation.py::TestBatchFileValidator::test_valid_batch_file + +# Run tests matching a pattern +pytest -k "validation" +``` + +### Run Tests in Verbose Mode + +```bash +pytest -v +``` + +## Code Quality + +### Linting and Formatting + +We use **Ruff** for both linting and formatting: + +```bash +# Check for linting issues +ruff check src/ tests/ + +# Automatically fix linting issues +ruff check --fix src/ tests/ + +# Check code formatting +ruff format --check src/ tests/ + +# Format code +ruff format src/ tests/ +``` + +### Type Checking + +We use **mypy** for static type checking, configured for gradual adoption with reasonable strictness: + +```bash +mypy src/ +``` + +**Note**: Mypy is configured to allow `**kwargs` patterns and flexible union types that are common in this codebase. It focuses on catching real type errors while not being overly strict about edge cases. 
+ +### Pre-commit Hooks + +Pre-commit hooks automatically run quality checks before each commit: + +- Trailing whitespace removal +- End-of-file fixer +- YAML syntax check +- Ruff linting and formatting +- Mypy type checking + +To manually run all pre-commit hooks: + +```bash +pre-commit run --all-files +``` + +## Continuous Integration + +We use GitHub Actions for CI/CD. All pull requests must pass: + +### Test Workflow (`.github/workflows/test.yml`) + +- Runs on Python 3.11, 3.12, 3.13, 3.14 +- Executes all tests with pytest +- Generates coverage report +- Uploads coverage to Codecov + +### Lint Workflow (`.github/workflows/lint.yml`) + +- Runs Ruff linting checks +- Runs Ruff formatting checks +- Runs mypy type checking (informational) + +**All checks must pass before a PR can be merged.** + +## Opening Issues + +### Before Opening an Issue + +1. **Search existing issues** to avoid duplicates +2. **Check the documentation** to see if your question is already answered + +### Issue Types + +- **Bug Report**: Report a problem with the library +- **Feature Request**: Suggest a new feature or enhancement +- **Documentation**: Report issues with documentation +- **Question**: Ask questions about usage + +### Bug Report Template + +When reporting a bug, please include: + +1. **Description**: Clear description of the issue +2. **Steps to Reproduce**: + ```python + # Minimal code example + ``` +3. **Expected Behavior**: What you expected to happen +4. **Actual Behavior**: What actually happened +5. **Environment**: + - OS: (e.g., Ubuntu 22.04, Windows 11, macOS 13) + - Python version: (e.g., 3.11.5) + - OpenBatch version: (e.g., 0.0.4) +6. **Traceback** (if applicable) + +### Feature Request Template + +When requesting a feature: + +1. **Problem**: Describe the problem your feature would solve +2. **Proposed Solution**: Describe your proposed solution +3. **Alternatives**: Alternative solutions you've considered +4. 
**Additional Context**: Any other context or examples + +## Submitting Pull Requests + +### Before Submitting + +1. **Ensure all tests pass**: `pytest` +2. **Run code quality checks**: `ruff check src/ tests/` +3. **Format your code**: `ruff format src/ tests/` +4. **Update documentation** if you changed the API +5. **Add tests** for new features or bug fixes +6. **Update CHANGELOG.md** (if applicable) + +### Pull Request Guidelines + +1. **Title**: Use a clear, descriptive title + - Good: "✨ Add validation for batch file size limits" + - Bad: "Update validation.py" + +2. **Description**: Include: + - What changes you made and why + - Related issue number (e.g., "Closes #123") + - Screenshots/examples if applicable + - Breaking changes (if any) + +3. **Keep PRs focused**: One feature/fix per PR + +4. **Write clear commit messages**: Follow the [Commit Message Convention](#commit-message-convention) + +5. **Respond to feedback**: Be open to suggestions and discussions + +### Pull Request Template + +```markdown +## Description + + +## Related Issue + + +## Type of Change +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] Documentation update + +## Testing + + +## Checklist +- [ ] My code follows the style guidelines of this project +- [ ] I have performed a self-review of my code +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have made corresponding changes to the documentation +- [ ] My changes generate no new warnings +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] New and existing unit tests pass locally with my changes +- [ ] I have followed the guidelines on commit messages +``` + +## Documentation + +### Docstring Style + +We use Google-style docstrings: + +```python +def validate_batch_file( + 
 file_path: str | Path,
+    strict: bool = True,
+) -> ValidationResult:
+    """
+    Validate a batch file (convenience function).
+
+    Args:
+        file_path: Path to the JSONL batch file
+        strict: Enable all checks (file size, request count)
+
+    Returns:
+        ValidationResult with errors, warnings, and statistics
+
+    Raises:
+        FileNotFoundError: If the file does not exist
+
+    Example:
+        >>> result = validate_batch_file("my_batch.jsonl")
+        >>> if result.is_valid:
+        ...     print("File is valid!")
+    """
+```
+
+### Building Documentation
+
+Documentation is built with MkDocs:
+
+```bash
+# Install documentation dependencies (if not already installed)
+pip install mkdocs mkdocs-material
+
+# Serve documentation locally
+mkdocs serve
+
+# Build documentation
+mkdocs build
+```
+
+Visit `http://127.0.0.1:8000` to view the documentation.
+
+### Adding Documentation Pages
+
+1. Create a new markdown file in `docs/`
+2. Add it to `mkdocs.yml` under `nav:`
+
+## Questions?
+
+If you have questions about contributing:
+
+1. Check the [documentation](https://daniel-gomm.github.io/openbatch/)
+2. Search [existing issues](https://github.com/daniel-gomm/openbatch/issues)
+3. Open a new issue with the "Question" label
+
+Thank you for contributing to OpenBatch! 
🚀 From 845e1b3110a615ce4fa5173fb3bf9b5455ab1de7 Mon Sep 17 00:00:00 2001 From: Daniel Gomm Date: Tue, 10 Feb 2026 10:28:25 +0100 Subject: [PATCH 8/9] :rotating_light: fix linting error and update pre-commit-config --- .pre-commit-config.yaml | 2 +- src/openbatch/_utils.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f48b756..709d5c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - id: debug-statements - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.5 + rev: v0.15.0 hooks: # Run the linter - id: ruff diff --git a/src/openbatch/_utils.py b/src/openbatch/_utils.py index e750051..2263cb8 100644 --- a/src/openbatch/_utils.py +++ b/src/openbatch/_utils.py @@ -110,9 +110,9 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object: resolved = root for key in path: value = resolved[key] - assert isinstance( - value, dict - ), f"encountered non-dictionary entry while resolving {ref} - {resolved}" + assert isinstance(value, dict), ( + f"encountered non-dictionary entry while resolving {ref} - {resolved}" + ) resolved = value return resolved From 3f70abc2848e11d2561f0cd0ab830b8004b4859a Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:44:32 +0100 Subject: [PATCH 9/9] :bug: sort endpoints in validation stats for deterministic output * Initial plan * Sort endpoints to make stats output deterministic Co-authored-by: daniel-gomm <63717948+daniel-gomm@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: daniel-gomm <63717948+daniel-gomm@users.noreply.github.com> --- src/openbatch/validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/openbatch/validation.py b/src/openbatch/validation.py index 382c3be..0176d9e 100644 --- a/src/openbatch/validation.py +++ 
b/src/openbatch/validation.py @@ -168,7 +168,7 @@ def _validate_content(self, file_handle: TextIO, result: ValidationResult) -> No # Update statistics result.stats["total_requests"] = line_number result.stats["unique_custom_ids"] = len(custom_ids) - result.stats["endpoints_used"] = list(endpoints) + result.stats["endpoints_used"] = sorted(endpoints) # Check request count if self.check_request_count and line_number > self.MAX_REQUESTS: @@ -180,7 +180,7 @@ def _validate_content(self, file_handle: TextIO, result: ValidationResult) -> No # Check for mixed endpoints if not self.allow_mixed_endpoints and len(endpoints) > 1: result.warnings.append( - f"Multiple endpoint types detected: {list(endpoints)}. " + f"Multiple endpoint types detected: {sorted(endpoints)}. " "OpenAI recommends one request type per file." )