diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..4711e8e --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,34 @@ +name: Code Quality + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff mypy + pip install -e . + + - name: Run Ruff linter + run: ruff check src/ tests/ + + - name: Run Ruff formatter check + run: ruff format --check src/ tests/ + + - name: Run mypy type checker + run: mypy src/ + continue-on-error: true # Don't fail build on type errors initially diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..df9408b --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,39 @@ +name: Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11", "3.12", "3.13", "3.14"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + + - name: Run tests with pytest + run: | + pytest -v --cov=openbatch --cov-report=term-missing --cov-report=xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./coverage.xml + flags: unittests + name: codecov-umbrella + fail_ci_if_error: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..709d5c1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,33 @@ +# Pre-commit hooks for code quality +# See 
https://pre-commit.com for more information + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + args: ['--maxkb=1000'] + - id: check-json + - id: check-toml + - id: check-merge-conflict + - id: debug-statements + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.0 + hooks: + # Run the linter + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + # Run the formatter + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.13.0 + hooks: + - id: mypy + additional_dependencies: [pydantic>=2.11.9] + args: [--ignore-missing-imports] + files: ^src/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..cfe1927 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,433 @@ +# Contributing to OpenBatch + +Thank you for your interest in contributing to openbatch! This document provides guidelines and instructions for contributing to the project. + +## Table of Contents + +- [Code of Conduct](#code-of-conduct) +- [Getting Started](#getting-started) +- [Development Setup](#development-setup) +- [Development Workflow](#development-workflow) +- [Branch Naming Convention](#branch-naming-convention) +- [Commit Message Convention](#commit-message-convention) +- [Running Tests](#running-tests) +- [Code Quality](#code-quality) +- [Continuous Integration](#continuous-integration) +- [Opening Issues](#opening-issues) +- [Submitting Pull Requests](#submitting-pull-requests) +- [Documentation](#documentation) + +## Code of Conduct + +Please be respectful and constructive in all interactions. + +## Getting Started + +1. **Fork the repository** on GitHub +2. **Clone your fork** locally: + ```bash + git clone https://github.com/YOUR_USERNAME/openbatch.git + cd openbatch + ``` +3. 
**Add the upstream repository**: + ```bash + git remote add upstream https://github.com/daniel-gomm/openbatch.git + ``` + +## Development Setup + +### Prerequisites + +- Python 3.11 or higher +- pip +- git + +### Installation + +1. **Create a virtual environment** (recommended): + ```bash + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + ``` + +2. **Install the package in editable mode with development dependencies**: + ```bash + pip install -e ".[dev]" + ``` + +3. **Install pre-commit hooks**: + ```bash + pre-commit install + ``` + + This ensures code quality checks run automatically before each commit. + +### Verify Installation + +```bash +# Run tests +pytest + +# Check linting +ruff check src/ tests/ + +# Check formatting +ruff format --check src/ tests/ + +# Run type checking +mypy src/ +``` + +## Development Workflow + +1. **Sync your fork** with the upstream repository: + ```bash + git checkout main + git fetch upstream + git merge upstream/main + git push origin main + ``` + +2. **Create a new branch** for your work (see [Branch Naming Convention](#branch-naming-convention)): + ```bash + git checkout -b feature/your-feature-name + ``` + +3. **Make your changes** and commit them (see [Commit Message Convention](#commit-message-convention)) + +4. **Push your changes** to your fork: + ```bash + git push origin feature/your-feature-name + ``` + +5. **Open a Pull Request** on GitHub + +## Branch Naming Convention + +We use the following prefixes for branch names: + +- `feature/` - New features or enhancements + - Example: `feature/add-retry-mechanism` +- `fix/` - Bug fixes + - Example: `fix/validation-error-handling` +- `documentation/` - Documentation updates + - Example: `documentation/improve-api-examples` + +**Format**: `/` + +## Commit Message Convention + +We use [Gitmoji](https://gitmoji.dev/) to prefix commit messages with relevant emojis that indicate the nature of the change. 
+ +### Common Gitmojis + +| Emoji | Code | Description | +|-------|------|-------------| +| ✨ | `:sparkles:` | Introduce new features | +| πŸ› | `:bug:` | Fix a bug | +| πŸ“ | `:memo:` | Add or update documentation | +| βœ… | `:white_check_mark:` | Add, update, or pass tests | +| ♻️ | `:recycle:` | Refactor code | +| 🎨 | `:art:` | Improve structure/format of the code | +| ⚑️ | `:zap:` | Improve performance | +| πŸ”’οΈ | `:lock:` | Fix security issues | +| ⬆️ | `:arrow_up:` | Upgrade dependencies | +| ⬇️ | `:arrow_down:` | Downgrade dependencies | +| πŸ”§ | `:wrench:` | Add or update configuration files | +| πŸš€ | `:rocket:` | Deploy stuff | +| πŸ’š | `:green_heart:` | Fix CI build | + +See the full list at [gitmoji.dev](https://gitmoji.dev/). + +### Commit Message Format + +``` + + +[Optional detailed description] + +[Optional footer with issue references] +``` + +### Examples + +```bash +# Adding a new feature +git commit -m "✨ Add batch file validation module" + +# Fixing a bug +git commit -m "πŸ› Fix custom_id uniqueness check in validator" + +# Updating documentation +git commit -m "πŸ“ Add validation usage examples to README" + +# Adding tests +git commit -m "βœ… Add integration tests for validation module" + +# Refactoring +git commit -m "♻️ Refactor BatchJobManager to use strategy pattern" +``` + +## Running Tests + +### Run All Tests + +```bash +pytest +``` + +### Run Tests with Coverage + +```bash +pytest --cov=src/openbatch --cov-report=html +``` + +View the coverage report by opening `htmlcov/index.html` in your browser. 
+ +### Run Specific Tests + +```bash +# Run a specific test file +pytest tests/test_validation.py + +# Run a specific test class +pytest tests/test_validation.py::TestBatchFileValidator + +# Run a specific test method +pytest tests/test_validation.py::TestBatchFileValidator::test_valid_batch_file + +# Run tests matching a pattern +pytest -k "validation" +``` + +### Run Tests in Verbose Mode + +```bash +pytest -v +``` + +## Code Quality + +### Linting and Formatting + +We use **Ruff** for both linting and formatting: + +```bash +# Check for linting issues +ruff check src/ tests/ + +# Automatically fix linting issues +ruff check --fix src/ tests/ + +# Check code formatting +ruff format --check src/ tests/ + +# Format code +ruff format src/ tests/ +``` + +### Type Checking + +We use **mypy** for static type checking, configured for gradual adoption with reasonable strictness: + +```bash +mypy src/ +``` + +**Note**: Mypy is configured to allow `**kwargs` patterns and flexible union types that are common in this codebase. It focuses on catching real type errors while not being overly strict about edge cases. + +### Pre-commit Hooks + +Pre-commit hooks automatically run quality checks before each commit: + +- Trailing whitespace removal +- End-of-file fixer +- YAML syntax check +- Ruff linting and formatting +- Mypy type checking + +To manually run all pre-commit hooks: + +```bash +pre-commit run --all-files +``` + +## Continuous Integration + +We use GitHub Actions for CI/CD. All pull requests must pass: + +### Test Workflow (`.github/workflows/test.yml`) + +- Runs on Python 3.11, 3.12, 3.13, 3.14 +- Executes all tests with pytest +- Generates coverage report +- Uploads coverage to Codecov + +### Lint Workflow (`.github/workflows/lint.yml`) + +- Runs Ruff linting checks +- Runs Ruff formatting checks +- Runs mypy type checking (informational) + +**All checks must pass before a PR can be merged.** + +## Opening Issues + +### Before Opening an Issue + +1. 
**Search existing issues** to avoid duplicates +2. **Check the documentation** to see if your question is already answered + +### Issue Types + +- **Bug Report**: Report a problem with the library +- **Feature Request**: Suggest a new feature or enhancement +- **Documentation**: Report issues with documentation +- **Question**: Ask questions about usage + +### Bug Report Template + +When reporting a bug, please include: + +1. **Description**: Clear description of the issue +2. **Steps to Reproduce**: + ```python + # Minimal code example + ``` +3. **Expected Behavior**: What you expected to happen +4. **Actual Behavior**: What actually happened +5. **Environment**: + - OS: (e.g., Ubuntu 22.04, Windows 11, macOS 13) + - Python version: (e.g., 3.11.5) + - OpenBatch version: (e.g., 0.0.4) +6. **Traceback** (if applicable) + +### Feature Request Template + +When requesting a feature: + +1. **Problem**: Describe the problem your feature would solve +2. **Proposed Solution**: Describe your proposed solution +3. **Alternatives**: Alternative solutions you've considered +4. **Additional Context**: Any other context or examples + +## Submitting Pull Requests + +### Before Submitting + +1. **Ensure all tests pass**: `pytest` +2. **Run code quality checks**: `ruff check src/ tests/` +3. **Format your code**: `ruff format src/ tests/` +4. **Update documentation** if you changed the API +5. **Add tests** for new features or bug fixes +6. **Update CHANGELOG.md** (if applicable) + +### Pull Request Guidelines + +1. **Title**: Use a clear, descriptive title + - Good: "✨ Add validation for batch file size limits" + - Bad: "Update validation.py" + +2. **Description**: Include: + - What changes you made and why + - Related issue number (e.g., "Closes #123") + - Screenshots/examples if applicable + - Breaking changes (if any) + +3. **Keep PRs focused**: One feature/fix per PR + +4. **Write clear commit messages**: Follow the [Commit Message Convention](#commit-message-convention) + +5. 
**Respond to feedback**: Be open to suggestions and discussions + +### Pull Request Template + +```markdown +## Description + + +## Related Issue + + +## Type of Change +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] Documentation update + +## Testing + + +## Checklist +- [ ] My code follows the style guidelines of this project +- [ ] I have performed a self-review of my code +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have made corresponding changes to the documentation +- [ ] My changes generate no new warnings +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] New and existing unit tests pass locally with my changes +- [ ] I have followed the guidelines on commit messages +``` + +## Documentation + +### Docstring Style + +We use Google-style docstrings: + +```python +def validate_batch_file( + file_path: str | Path, + strict: bool = True, +) -> ValidationResult: + """ + Validate a batch file (convenience function). + + Args: + file_path: Path to the JSONL batch file + strict: Enable all checks (file size, request count) + + Returns: + ValidationResult with errors, warnings, and statistics + + Raises: + FileNotFoundError: If the file does not exist + + Example: + >>> result = validate_batch_file("my_batch.jsonl") + >>> if result.is_valid: + ... print("File is valid!") + """ +``` + +### Building Documentation + +Documentation is built with MkDocs: + +```bash +# Install documentation dependencies (if not already installed) +pip install mkdocs mkdocs-material + +# Serve documentation locally +mkdocs serve + +# Build documentation +mkdocs build +``` + +Visit `http://127.0.0.1:8000` to view the documentation. + +### Adding Documentation Pages + +1. Create a new markdown file in `docs/` +2. 
Add it to `mkdocs.yml` under `nav:` + +## Questions? + +If you have questions about contributing: + +1. Check the [documentation](https://openbatch.daniel-gomm.com/) +2. Search [existing issues](https://github.com/daniel-gomm/openbatch/issues) +3. Open a new issue with the "Question" label + +Thank you for contributing to OpenBatch! πŸš€ diff --git a/README.md b/README.md index 607a59b..526dda3 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,13 @@ # OpenBatch: Simplify OpenAI Batch Job Creation -[](https://www.google.com/search?q=https://badge.fury.io/py/openbatch) **OpenBatch** is a lightweight Python utility designed to streamline the creation of JSONL files for the [OpenAI Batch API](https://platform.openai.com/docs/guides/batch). It provides a type-safe and intuitive interface using Pydantic models to construct requests for the `/v1/responses`, `/v1/chat/completions`, and `/v1/embeddings` endpoints. +[![PyPI version](https://badge.fury.io/py/openbatch.svg)](https://badge.fury.io/py/openbatch) +[![Python versions](https://img.shields.io/pypi/pyversions/openbatch.svg)](https://pypi.org/project/openbatch/) +[![Tests](https://github.com/daniel-gomm/openbatch/actions/workflows/test.yml/badge.svg)](https://github.com/daniel-gomm/openbatch/actions/workflows/test.yml) +[![codecov](https://codecov.io/gh/daniel-gomm/openbatch/branch/main/graph/badge.svg)](https://codecov.io/gh/daniel-gomm/openbatch) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +[![GitHub stars](https://img.shields.io/github/stars/daniel-gomm/openbatch.svg?style=social&label=Star)](https://github.com/daniel-gomm/openbatch) + +**OpenBatch** is a lightweight Python utility designed to streamline the creation of JSONL files for the [OpenAI Batch API](https://platform.openai.com/docs/guides/batch). 
It provides a type-safe and intuitive interface using Pydantic models to construct requests for the `/v1/responses`, `/v1/chat/completions`, and `/v1/embeddings` endpoints. For a detailed guide on using OpenBatch, please refer to the **[OpenBatch Documentation](https://openbatch.daniel-gomm.com/)**. @@ -189,3 +196,53 @@ You can also override any common setting on a per-instance basis by using the `i 3. **Retrieve Results**: Monitor the job's status and, once completed, download the output file with the results. For detailed instructions on these steps, please refer to the **[Official OpenAI Batch API Documentation](https://platform.openai.com/docs/api-reference/batch)**. + +----- + +## Validation + +OpenBatch includes built-in [validation](https://openbatch.daniel-gomm.com/validation/) to catch errors before uploading to OpenAI. + +```python +from openbatch import validate_batch_file + +result = validate_batch_file("my_batch.jsonl") +if result.is_valid: + print(f"Valid! {result.stats['total_requests']} requests") +``` + +----- + +## Testing + +OpenBatch includes a comprehensive test suite. + +```bash +# Install with test dependencies +pip install -e ".[test]" + +# Run tests +pytest + +# Run with coverage report +pytest --cov=openbatch +``` + +The test suite includes: +- Unit tests for all core functionality +- Integration tests for end-to-end workflows +- Tests for structured outputs, reasoning models, and unicode handling + +----- + +## Contributing + +Contributions are welcome! Please see our [Contributing Guide](CONTRIBUTING.md) for details on: + +- Development setup +- Code quality standards +- Branch naming conventions (`feature/`, `fix/`, `documentation/`) +- Commit message conventions (using [gitmoji](https://gitmoji.dev/)) +- Opening issues and pull requests + +For bugs and feature requests, please [open an issue](https://github.com/daniel-gomm/openbatch/issues). 
diff --git a/docs/index.md b/docs/index.md index daa63b3..15b7708 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,7 +2,7 @@ **OpenBatch** is a lightweight Python utility designed to streamline the creation of JSONL files for the [OpenAI Batch API](https://platform.openai.com/docs/guides/batch). It provides a type-safe and intuitive interface using Pydantic models to construct requests for various endpoints, including `/v1/chat/completions`, `/v1/embeddings`, and the new `/v1/responses` endpoint. -## Key Features ✨ +## Key Features - **Type-Safe & Modern**: Built with Pydantic for robust, self-documenting, and editor-friendly request models. - **Dual APIs for Flexibility**: diff --git a/docs/validation.md b/docs/validation.md new file mode 100644 index 0000000..b7642a9 --- /dev/null +++ b/docs/validation.md @@ -0,0 +1,135 @@ +# Batch File Validation + +OpenBatch includes comprehensive validation to catch errors before uploading batch files to OpenAI, saving time and preventing failed uploads. 
+ +## Overview + +The validation module checks your batch files for: + +- βœ“ Valid JSONL format +- βœ“ Unique `custom_id` values +- βœ“ Required fields present +- βœ“ Valid HTTP methods (POST) +- βœ“ Valid endpoint URLs +- βœ“ Correct request body structure for each API type +- βœ“ File size limits (200 MB) +- βœ“ Request count limits (50,000) +- ⚠ Mixed endpoint types warning + +## Simple Validation + +```python +from openbatch import validate_batch_file + +result = validate_batch_file("my_batch.jsonl") + +if result.is_valid: + print(f"βœ“ Batch file is valid!") + print(f"Total requests: {result.stats['total_requests']}") + # Proceed with upload to OpenAI +else: + print(f"βœ— Validation failed:") + print(result) # Shows detailed errors and warnings +``` + +### Quick Boolean Check + +For a fast True/False check: + +```python +from openbatch.validation import quick_validate + +if quick_validate("my_batch.jsonl"): + # File is valid, proceed + upload_to_openai("my_batch.jsonl") +``` + +## Validation Result + +The `ValidationResult` object provides detailed information: + +```python +result = validate_batch_file("batch.jsonl") + +# Check validity +if result.is_valid: + # Access statistics + print(f"Requests: {result.stats['total_requests']}") + print(f"File size: {result.stats['file_size_mb']} MB") + print(f"Endpoints: {result.stats['endpoints_used']}") + print(f"Unique IDs: {result.stats['unique_custom_ids']}") +else: + # Handle errors + for error in result.errors: + print(f"ERROR: {error}") + + # Review warnings + for warning in result.warnings: + print(f"WARNING: {warning}") +``` + +### Human-Readable Output + +The `ValidationResult` has a nice string representation: + +```python +result = validate_batch_file("batch.jsonl") +print(result) +``` + +Output: +``` +Validation: PASSED + +Statistics: + total_requests: 100 + unique_custom_ids: 100 + endpoints_used: ['/v1/responses'] + file_size_mb: 0.05 + +Warnings (1): + β€’ File extension is '.json', expected '.jsonl' 
+``` + +## Advanced Usage + +### Configurable Validation + +Use `BatchFileValidator` for custom validation options: + +```python +from openbatch.validation import BatchFileValidator + +validator = BatchFileValidator( + check_custom_id_uniqueness=True, # Check for duplicate IDs + check_file_size=True, # Check 200 MB limit + check_request_count=True, # Check 50K request limit + allow_mixed_endpoints=False # Warn about mixed endpoints +) + +result = validator.validate_file("batch.jsonl") +``` + +### Disable Specific Checks + +```python +from openbatch import validate_batch_file + +# Skip custom_id uniqueness check (if you want duplicates) +result = validate_batch_file( + "batch.jsonl", + check_custom_id_uniqueness=False +) + +# Allow mixed endpoints without warning +result = validate_batch_file( + "batch.jsonl", + allow_mixed_endpoints=True +) + +# Disable strict checks (file size, request count) +result = validate_batch_file( + "batch.jsonl", + strict=False +) +``` diff --git a/mkdocs.yml b/mkdocs.yml index c201aa9..637afde 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,6 +28,7 @@ plugins: nav: - index.md - Getting Started: getting_started.md + - Validation: validation.md - Reference: - BatchJobManager: reference/batch_job_manager.md - BatchJobCollector: reference/batch_job_collector.md diff --git a/pyproject.toml b/pyproject.toml index d771905..bd65d98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,13 +4,13 @@ build-backend = "setuptools.build_meta" [project] name = "openbatch" -version = "0.0.3" +version = "0.0.4" authors = [ { name="Daniel Gomm", email="daniel.gomm@cwi.nl" }, ] description = "Create batch jobs for the OpenAI API with ease." 
readme = "README.md" -requires-python = ">=3.5" +requires-python = ">=3.11" classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent", @@ -23,3 +23,103 @@ dependencies = [ Homepage = "https://github.com/daniel-gomm/openbatch" Issues = "https://github.com/daniel-gomm/openbatch/issues" Documentation = "https://daniel-gomm.github.io/openbatch/" + +[project.optional-dependencies] +test = [ + "pytest>=8.0.0", + "pytest-cov>=4.1.0", +] +dev = [ + "ruff>=0.8.0", + "mypy>=1.13.0", + "pre-commit>=4.0.0", +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + "--strict-markers", + "--strict-config", + "--cov=openbatch", + "--cov-report=term-missing", + "--cov-report=html", + "--cov-report=xml", +] + +[tool.coverage.run] +source = ["src/openbatch"] +omit = ["*/tests/*", "*/__pycache__/*"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "@abstractmethod", +] + +# Ruff configuration +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "SIM", # flake8-simplify + "RUF", # ruff-specific rules +] +ignore = [ + "E501", # line too long (handled by formatter) + "B008", # do not perform function calls in argument defaults + "N805", # first argument should be named self (pydantic validators) +] + +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = [ + "N802", # function name should be lowercase + "N803", # argument name should be lowercase +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +line-ending = "lf" + +# Mypy configuration - Static type checking 
+# Configured for gradual adoption with reasonable strictness +[tool.mypy] +python_version = "3.11" +warn_return_any = false # Too strict for **kwargs patterns +warn_unused_configs = true +disallow_untyped_defs = false # Allow untyped defs (gradual adoption) +disallow_incomplete_defs = false # Allow incomplete defs (gradual adoption) +check_untyped_defs = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = false # Pre-commit adds ignores, this would conflict +warn_no_return = true +strict_equality = true +# Allow flexible type checking for union types and **kwargs +disable_error_code = ["assignment", "no-untyped-def"] + +[[tool.mypy.overrides]] +module = "tests.*" +disallow_untyped_defs = false + +[[tool.mypy.overrides]] +module = "pydantic.*" +ignore_missing_imports = true diff --git a/src/openbatch/__init__.py b/src/openbatch/__init__.py index 6e210bc..b2c17e5 100644 --- a/src/openbatch/__init__.py +++ b/src/openbatch/__init__.py @@ -1,27 +1,29 @@ from openbatch.collector import BatchCollector from openbatch.manager import BatchJobManager from openbatch.model import ( + ChatCompletionsRequest, + EmbeddingInputInstance, + EmbeddingsRequest, Message, + MessagesInputInstance, PromptTemplate, PromptTemplateInputInstance, - MessagesInputInstance, - EmbeddingInputInstance, - ResponsesRequest, - ChatCompletionsRequest, - EmbeddingsRequest, ReasoningConfig, + ResponsesRequest, ) +from openbatch.validation import validate_batch_file __all__ = [ "BatchCollector", "BatchJobManager", + "ChatCompletionsRequest", + "EmbeddingInputInstance", + "EmbeddingsRequest", "Message", + "MessagesInputInstance", "PromptTemplate", "PromptTemplateInputInstance", - "MessagesInputInstance", - "EmbeddingInputInstance", - "ResponsesRequest", - "ChatCompletionsRequest", - "EmbeddingsRequest", "ReasoningConfig", + "ResponsesRequest", + "validate_batch_file", ] diff --git a/src/openbatch/_utils.py b/src/openbatch/_utils.py index 6c01376..2263cb8 100644 --- 
a/src/openbatch/_utils.py +++ b/src/openbatch/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, TypeVar +from typing import Any, TypeVar from pydantic import BaseModel @@ -6,6 +6,7 @@ # Copied and adapted from the OpenAI library to avoid adding a dependency https://github.com/openai/openai-python/blob/main/src/openai/lib/_pydantic.py + def _ensure_strict_json_schema( json_schema: object, *, @@ -26,7 +27,9 @@ def _ensure_strict_json_schema( definitions = json_schema.get("definitions") if isinstance(definitions, dict): for definition_name, definition_schema in definitions.items(): - _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root) + _ensure_strict_json_schema( + definition_schema, path=(*path, "definitions", definition_name), root=root + ) typ = json_schema.get("type") if typ == "object" and "additionalProperties" not in json_schema: @@ -36,7 +39,7 @@ def _ensure_strict_json_schema( # { 'type': 'object', 'properties': { 'a': {...} } } properties = json_schema.get("properties") if isinstance(properties, dict): - json_schema["required"] = [prop for prop in properties.keys()] + json_schema["required"] = list(properties) json_schema["properties"] = { key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root) for key, prop_schema in properties.items() @@ -60,7 +63,9 @@ def _ensure_strict_json_schema( all_of = json_schema.get("allOf") if isinstance(all_of, dict): if len(all_of) == 1: - json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root)) + json_schema.update( + _ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root) + ) json_schema.pop("allOf") else: json_schema["allOf"] = [ @@ -79,7 +84,9 @@ def _ensure_strict_json_schema( resolved = resolve_ref(root=root, ref=ref) if not isinstance(resolved, dict): - raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}") + raise ValueError( + 
f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}" + ) # properties from the json schema take priority over the ones on the `$ref` json_schema.update({**resolved, **json_schema}) @@ -90,13 +97,10 @@ def _ensure_strict_json_schema( return json_schema + def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool: - i = 0 - for _ in obj.keys(): - i += 1 - if i > n: - return True - return False + return any(i > n for i, _ in enumerate(obj, 1)) + def resolve_ref(*, root: dict[str, object], ref: str) -> object: if not ref.startswith("#/"): @@ -106,12 +110,15 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object: resolved = root for key in path: value = resolved[key] - assert isinstance(value, dict), f"encountered non-dictionary entry while resolving {ref} - {resolved}" + assert isinstance(value, dict), ( + f"encountered non-dictionary entry while resolving {ref} - {resolved}" + ) resolved = value return resolved -def type_to_json_schema(output_type: type[T]) -> Dict[str, Any]: + +def type_to_json_schema(output_type: type[T]) -> dict[str, Any]: json_schema = output_type.model_json_schema() schema = _ensure_strict_json_schema(json_schema, path=(), root=json_schema) return schema diff --git a/src/openbatch/collector.py b/src/openbatch/collector.py index 1e82da6..f1dc56a 100644 --- a/src/openbatch/collector.py +++ b/src/openbatch/collector.py @@ -1,22 +1,22 @@ from os import PathLike +from pathlib import Path from types import SimpleNamespace -from typing import Union, Optional from pydantic import BaseModel from openbatch.manager import BatchJobManager -from openbatch.model import ResponsesRequest, ChatCompletionsRequest, EmbeddingsRequest +from openbatch.model import ChatCompletionsRequest, EmbeddingsRequest, ResponsesRequest class Responses: """ - A utility class for easily constructing and adding individual - Responses API requests to a batch job file. 
+ A utility class for easily constructing and adding individual + Responses API requests to a batch job file. - It acts as a high-level interface for the '/v1/responses' endpoint. - """ + It acts as a high-level interface for the '/v1/responses' endpoint. + """ - def __init__(self, batch_file_path: Union[str, PathLike]): + def __init__(self, batch_file_path: str | PathLike): """ Initializes the Responses collector. @@ -24,9 +24,12 @@ def __init__(self, batch_file_path: Union[str, PathLike]): batch_file_path (Union[str, PathLike]): The path to the JSONL file where the batch requests will be written. """ - self.batch_file_path = batch_file_path + self.batch_file_path = Path(batch_file_path) + self._manager = BatchJobManager() - def parse(self, custom_id: str, model: str, text_format: Optional[type[BaseModel]] = None, **kwargs) -> None: + def parse( + self, custom_id: str, model: str, text_format: type[BaseModel] | None = None, **kwargs + ) -> None: """ Creates a ResponsesRequest, optionally enforcing a JSON output structure, and adds it to the batch file. Use it like the `OpenAI().responses.parse()` method. @@ -56,16 +59,18 @@ def create(self, custom_id: str, model: str, **kwargs) -> None: self._add_request(custom_id, request) def _add_request(self, custom_id: str, request: ResponsesRequest) -> None: - BatchJobManager.add(custom_id, request, self.batch_file_path) + self._manager.add(custom_id, request, self.batch_file_path) + class ChatCompletions: """ - A utility class for easily constructing and adding individual - Chat Completions API requests to a batch job file. + A utility class for easily constructing and adding individual + Chat Completions API requests to a batch job file. - It acts as a high-level interface for the '/v1/chat/completions' endpoint. - """ - def __init__(self, batch_file_path: Union[str, PathLike]): + It acts as a high-level interface for the '/v1/chat/completions' endpoint. 
+ """ + + def __init__(self, batch_file_path: str | PathLike): """ Initializes the ChatCompletions collector. @@ -73,9 +78,12 @@ def __init__(self, batch_file_path: Union[str, PathLike]): batch_file_path (Union[str, PathLike]): The path to the JSONL file where the batch requests will be written. """ - self.batch_file_path = batch_file_path + self.batch_file_path = Path(batch_file_path) + self._manager = BatchJobManager() - def parse(self, custom_id: str, model: str, response_format: Optional[type[BaseModel]] = None, **kwargs) -> None: + def parse( + self, custom_id: str, model: str, response_format: type[BaseModel] | None = None, **kwargs + ) -> None: """ Creates a ChatCompletionsRequest, optionally enforcing a JSON output structure, and adds it to the batch file. Use it like the `OpenAI().chat.completions.parse()` method. @@ -105,7 +113,8 @@ def create(self, custom_id: str, model: str, **kwargs) -> None: self._add_request(custom_id, request) def _add_request(self, custom_id: str, request: ChatCompletionsRequest) -> None: - BatchJobManager.add(custom_id, request, self.batch_file_path) + self._manager.add(custom_id, request, self.batch_file_path) + class Embeddings: """ @@ -114,7 +123,8 @@ class Embeddings: It acts as a high-level interface for the '/v1/embeddings' endpoint. """ - def __init__(self, batch_file_path: Union[str, PathLike]): + + def __init__(self, batch_file_path: str | PathLike): """ Initializes the Embeddings collector. @@ -122,9 +132,10 @@ def __init__(self, batch_file_path: Union[str, PathLike]): batch_file_path (Union[str, PathLike]): The path to the JSONL file where the batch requests will be written. 
""" - self.batch_file_path = batch_file_path + self.batch_file_path = Path(batch_file_path) + self._manager = BatchJobManager() - def create(self, custom_id: str, model: str, inp: Union[str, list[str]], **kwargs) -> None: + def create(self, custom_id: str, model: str, inp: str | list[str], **kwargs) -> None: """ Creates an EmbeddingsRequest and adds it to the batch file. Use it like the `OpenAI().embeddings.create()` method. @@ -135,7 +146,7 @@ def create(self, custom_id: str, model: str, inp: Union[str, list[str]], **kwarg **kwargs: Additional parameters for the EmbeddingsRequest. """ request = EmbeddingsRequest.model_validate({"model": model, "input": inp, **kwargs}) - BatchJobManager.add(custom_id, request, self.batch_file_path) + self._manager.add(custom_id, request, self.batch_file_path) class BatchCollector: @@ -157,7 +168,8 @@ class BatchCollector: where the batch requests will be written. The file will be created if it doesn't exist and appended to if it does. """ - def __init__(self, batch_file_path: Union[str, PathLike]): + + def __init__(self, batch_file_path: str | PathLike): self.responses = Responses(batch_file_path) self.chat = SimpleNamespace() self.chat.completions = ChatCompletions(batch_file_path) diff --git a/src/openbatch/manager.py b/src/openbatch/manager.py index 76a1a10..e6172a4 100644 --- a/src/openbatch/manager.py +++ b/src/openbatch/manager.py @@ -1,25 +1,26 @@ import json import warnings +from collections.abc import Iterable from copy import deepcopy from pathlib import Path -from typing import TypeVar, Iterable, Union +from typing import TypeVar from openbatch.model import ( + BaseRequest, + ChatCompletionsAPIStrategy, + ChatCompletionsRequest, + EmbeddingInputInstance, + EmbeddingsAPIStrategy, + EmbeddingsRequest, PromptTemplate, - ReusablePrompt, PromptTemplateInputInstance, - ResponsesRequest, ResponsesAPIStrategy, - EmbeddingsAPIStrategy, - ChatCompletionsAPIStrategy, - BaseRequest, - EmbeddingsRequest, - EmbeddingInputInstance, - 
ChatCompletionsRequest, + ResponsesRequest, + ReusablePrompt, ) B = TypeVar("B", bound=BaseRequest) -R = TypeVar("R", bound=Union[ResponsesRequest, ChatCompletionsRequest]) +R = TypeVar("R", bound=ResponsesRequest | ChatCompletionsRequest) class BatchJobManager: @@ -41,7 +42,7 @@ def __init__(self, ensure_ascii: bool = True) -> None: def add_templated_instances( self, - prompt: Union[PromptTemplate, ReusablePrompt], + prompt: PromptTemplate | ReusablePrompt, common_request: R, input_instances: Iterable[PromptTemplateInputInstance], save_file_path: str | Path, @@ -69,9 +70,7 @@ def add_templated_instances( unsupported request type. """ if isinstance(common_request, EmbeddingsRequest): - raise ValueError( - "Embeddings API is not supported with templated instances." - ) + raise ValueError("Embeddings API is not supported with templated instances.") elif not isinstance(common_request, ResponsesRequest) and not isinstance( common_request, ChatCompletionsRequest ): @@ -84,6 +83,7 @@ def add_templated_instances( warnings.warn( f"File {save_file_path} already exists. New contents are appended to the file. Make sure that this is intended behavior.", category=RuntimeWarning, + stacklevel=2, ) for instance in input_instances: @@ -92,15 +92,13 @@ def add_templated_instances( request = request.model_copy(update=instance.instance_request_options) request = self._handle_prompt(prompt, request, instance) - self.add( - custom_id=instance.id, request=request, save_file_path=save_file_path - ) + self.add(custom_id=instance.id, request=request, save_file_path=save_file_path) def add_embedding_requests( self, inputs: Iterable[EmbeddingInputInstance], common_request: EmbeddingsRequest, - save_file_path: Union[str, Path], + save_file_path: str | Path, ) -> None: """ Adds multiple embedding request instances to a batch request file. 
@@ -123,15 +121,13 @@ def add_embedding_requests( request = request.model_copy(update=instance.instance_request_options) request.set_input(instance.input) - self.add( - custom_id=instance.id, request=request, save_file_path=save_file_path - ) + self.add(custom_id=instance.id, request=request, save_file_path=save_file_path) def add( self, custom_id: str, request: B, - save_file_path: Union[str, Path], + save_file_path: str | Path, ) -> None: """ Creates a single batch request object and appends it to the specified file. @@ -153,9 +149,7 @@ def add( if isinstance(request, ResponsesRequest): strategy = ResponsesAPIStrategy() if request.input is None and request.prompt is None: - raise ValueError( - "Responses request must define either an input or a prompt." - ) + raise ValueError("Responses request must define either an input or a prompt.") elif isinstance(request, ChatCompletionsRequest): strategy = ChatCompletionsAPIStrategy() if request.messages is None: @@ -170,14 +164,10 @@ def add( save_file_path = Path(save_file_path) save_file_path.parent.mkdir(parents=True, exist_ok=True) - batch_request = strategy.create_request( - custom_id=custom_id, body=request.to_dict() - ) + batch_request = strategy.create_request(custom_id=custom_id, body=request.to_dict()) with open(save_file_path, "a+") as outfile: - outfile.write( - json.dumps(batch_request, ensure_ascii=self.ensure_ascii) + "\n" - ) + outfile.write(json.dumps(batch_request, ensure_ascii=self.ensure_ascii) + "\n") @staticmethod def _handle_prompt( @@ -187,9 +177,7 @@ def _handle_prompt( ) -> R: if isinstance(prompt, ReusablePrompt): if not isinstance(request, ResponsesRequest): - raise ValueError( - "Reusable prompts can only be used with ResponsesOptions." 
- ) + raise ValueError("Reusable prompts can only be used with ResponsesOptions.") request.prompt = ReusablePrompt( id=prompt.id, version=prompt.version, diff --git a/src/openbatch/model.py b/src/openbatch/model.py index d69ec1e..6a8f632 100644 --- a/src/openbatch/model.py +++ b/src/openbatch/model.py @@ -1,15 +1,15 @@ -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from os import PathLike from pathlib import Path -from typing import List, Dict, Any, Optional, Literal, TypeVar, Union, Self +from typing import Any, Literal, TypeVar from pydantic import BaseModel, Field from openbatch._utils import type_to_json_schema - T = TypeVar("T", bound=BaseModel) + class Message(BaseModel): """ Represents a single message in a conversation or prompt. @@ -18,6 +18,7 @@ class Message(BaseModel): role (str): The role of the message sender (e.g., "user", "assistant", "system"). content (str): The text content of the message. """ + role: str content: str @@ -30,6 +31,7 @@ def serialize(self): """ return {"role": self.role, "content": self.content} + class PromptTemplate(BaseModel): """ A template containing a sequence of messages, where the content can contain @@ -38,9 +40,10 @@ class PromptTemplate(BaseModel): Attributes: messages (List[Message]): A list of Message objects that form the template. """ - messages: List[Message] - def format(self, **kwargs) -> List[Message]: + messages: list[Message] + + def format(self, **kwargs) -> list[Message]: """ Formats the content of each message in the template using the provided keyword arguments. @@ -56,6 +59,7 @@ def format(self, **kwargs) -> List[Message]: formatted_messages.append(Message(role=message.role, content=formatted_content)) return formatted_messages + class ReusablePrompt(BaseModel): """ References a reusable prompt template and its associated variables. 
@@ -66,9 +70,11 @@ class ReusablePrompt(BaseModel): variables (Dict[str, Any]): A dictionary of variable names and their values to be used when formatting the prompt. """ + id: str version: str - variables: Dict[str, Any] + variables: dict[str, Any] + class ReasoningConfig(BaseModel): """ @@ -80,8 +86,14 @@ class ReasoningConfig(BaseModel): summary (Optional[Literal["auto", "concise", "detailed"]]): A summary of the reasoning performed by the model. Optional. """ - effort: Literal["minimal", "low", "medium", "high"] = Field(default="medium", description="Constrains effort on reasoning for reasoning models.") - summary: Optional[Literal["auto", "concise", "detailed"]] = Field(None, description="A summary of the reasoning performed by the model.") + + effort: Literal["minimal", "low", "medium", "high"] = Field( + default="medium", description="Constrains effort on reasoning for reasoning models." + ) + summary: Literal["auto", "concise", "detailed"] | None = Field( + None, description="A summary of the reasoning performed by the model." + ) + class InputInstance(BaseModel): """ @@ -92,8 +104,12 @@ class InputInstance(BaseModel): instance_request_options (Optional[Dict[str, Any]]): Options specific to the input instance that can be set in the API request. Optional. """ + id: str = Field(description="Unique identifier of the input instance.") - instance_request_options: Optional[Dict[str, Any]] = Field(None, description="Options specific to the input instance that to set in the request.") + instance_request_options: dict[str, Any] | None = Field( + None, description="Options specific to the input instance that to set in the request." + ) + class MessagesInputInstance(InputInstance): """ @@ -105,7 +121,9 @@ class MessagesInputInstance(InputInstance): instance_request_options (Optional[Dict[str, Any]]): Options specific to the input instance that can be set in the API request. Optional. 
""" - messages: List[Message] = Field(description="List of messages to be sent to the model.") + + messages: list[Message] = Field(description="List of messages to be sent to the model.") + class PromptTemplateInputInstance(InputInstance): """ @@ -118,7 +136,11 @@ class PromptTemplateInputInstance(InputInstance): instance_request_options (Optional[Dict[str, Any]]): Options specific to the input instance that can be set in the API request. Optional. """ - prompt_value_mapping: Dict[str, str] = Field(description="Mapping of prompt variable names to their values.") + + prompt_value_mapping: dict[str, str] = Field( + description="Mapping of prompt variable names to their values." + ) + class EmbeddingInputInstance(InputInstance): """ @@ -130,13 +152,16 @@ class EmbeddingInputInstance(InputInstance): instance_request_options (Optional[Dict[str, Any]]): Options specific to the input instance that can be set in the API request. Optional. """ - input: Union[str, List[str]] = Field(description="Text(s) to be embedded.") + + input: str | list[str] = Field(description="Text(s) to be embedded.") + class RequestStrategy(ABC): """ Abstract base class defining the strategy for creating a request for a specific API endpoint. """ + @property @abstractmethod def url(self) -> str: @@ -148,7 +173,7 @@ def url(self) -> str: """ pass - def create_request(self, custom_id: str, body: Dict[str, Any]) -> Dict[str, Any]: + def create_request(self, custom_id: str, body: dict[str, Any]) -> dict[str, Any]: """ Creates a structured request dictionary for a batch job. @@ -159,31 +184,33 @@ def create_request(self, custom_id: str, body: Dict[str, Any]) -> Dict[str, Any] Returns: Dict[str, Any]: A dictionary representing the complete request structure. 
""" - return { - "custom_id": custom_id, - "method": "POST", - "url": self.url, - "body": body - } + return {"custom_id": custom_id, "method": "POST", "url": self.url, "body": body} + class ResponsesAPIStrategy(RequestStrategy): """Strategy for creating requests to the /v1/responses endpoint.""" + @property def url(self) -> str: return "/v1/responses" + class ChatCompletionsAPIStrategy(RequestStrategy): """Strategy for creating requests to the /v1/chat/completions endpoint.""" + @property def url(self) -> str: return "/v1/chat/completions" + class EmbeddingsAPIStrategy(RequestStrategy): """Strategy for creating requests to the /v1/embeddings endpoint.""" + @property def url(self) -> str: return "/v1/embeddings" + class BaseRequest(BaseModel, ABC): """ Abstract base class for API-specific request configurations (job configurations). @@ -191,9 +218,12 @@ class BaseRequest(BaseModel, ABC): Attributes: model (str): Model ID used to generate the response, like "gpt-4.1". Defaults to "gpt-4.1". """ - model: str = Field("gpt-4.1", description="Model ID used to generate the response, like gpt-4o or o3.") - def to_dict(self) -> Dict[str, Any]: + model: str = Field( + "gpt-4.1", description="Model ID used to generate the response, like gpt-4o or o3." + ) + + def to_dict(self) -> dict[str, Any]: """ Converts the request configuration object to a dictionary, excluding fields that are None. @@ -202,6 +232,7 @@ def to_dict(self) -> Dict[str, Any]: """ return self.model_dump(exclude_none=True) + class TextGenerationRequest(BaseRequest, ABC): """ Abstract base class for text generation requests, including common parameters @@ -220,68 +251,136 @@ class TextGenerationRequest(BaseRequest, ABC): tool_choice (Optional[str | object]): How the model should select which tool to use. top_logprobs (Optional[int]): Number of most likely tokens to return at each position (0 to 20). 
""" - tools: Optional[List[object]] = Field(None, description="An array of tools the model may call while generating a response.") - top_p: Optional[float] = Field(None, ge=0, le=1, description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.") - parallel_tool_calls: Optional[bool] = Field(None, description="Whether to allow the model to run tool calls in parallel.") - prompt_cache_key: Optional[str] = Field(None, description="Used by OpenAI to cache responses for similar requests to optimize your cache hit rates.") - safety_identifier: Optional[str] = Field(None, description="A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies.") - service_tier: Optional[Literal["auto", "default", "flex", "priority"]] = Field(None, description="Specifies the processing type used for serving the request.") - store: Optional[bool] = Field(None, description="Whether to store the generated model response for later retrieval via API.") - temperature: Optional[float] = Field(None, ge=0, le=2, description="What sampling temperature to use, between 0 and 2.") - tool_choice: Optional[str | object] = Field(None, description="How the model should select which tool (or tools) to use when generating a response.") - top_logprobs: Optional[int] = Field(None, ge=0, le=20, description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability.") + tools: list[object] | None = Field( + None, description="An array of tools the model may call while generating a response." 
+ ) + top_p: float | None = Field( + None, + ge=0, + le=1, + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.", + ) + parallel_tool_calls: bool | None = Field( + None, description="Whether to allow the model to run tool calls in parallel." + ) + prompt_cache_key: str | None = Field( + None, + description="Used by OpenAI to cache responses for similar requests to optimize your cache hit rates.", + ) + safety_identifier: str | None = Field( + None, + description="A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies.", + ) + service_tier: Literal["auto", "default", "flex", "priority"] | None = Field( + None, description="Specifies the processing type used for serving the request." + ) + store: bool | None = Field( + None, + description="Whether to store the generated model response for later retrieval via API.", + ) + temperature: float | None = Field( + None, ge=0, le=2, description="What sampling temperature to use, between 0 and 2." + ) + tool_choice: str | object | None = Field( + None, + description="How the model should select which tool (or tools) to use when generating a response.", + ) + top_logprobs: int | None = Field( + None, + ge=0, + le=20, + description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability.", + ) @abstractmethod def set_output_structure(self, output_type: type[T]) -> None: pass @abstractmethod - def set_input_messages(self, messages: List[Message]) -> None: + def set_input_messages(self, messages: list[Message]) -> None: pass class ResponsesRequest(TextGenerationRequest): """ - Configuration for a /v1/responses API request. - - Attributes: - model (str): Model ID used to generate the response, like "gpt-4.1". Defaults to "gpt-4.1". 
- conversation (Optional[str]): The conversation this response belongs to. - include (Optional[List[Literal[...]]]): Specify additional output data to include. - input (Optional[str | List[Dict[str, str]]]): Text, image, or file inputs to the model. - instructions (Optional[str]): A system or developer message. - max_output_tokens (Optional[int]): Upper bound for generated tokens. - max_tool_calls (Optional[int]): Maximum number of tool calls allowed. - previous_response_id (Optional[str]): ID of the previous response for multi-turn. - prompt (Optional[ReusablePrompt]): Reference to a prompt template and its variables. - reasoning (Optional[ReasoningConfig]): Configuration for reasoning models. - text (Optional[object]): Configuration options for a text response from the model (e.g., JSON schema). - truncation (Optional[Literal["auto", "disabled"]]): The truncation strategy to use. - tools (Optional[List[object]]): An array of tools the model may call. - top_p (Optional[float]): An alternative to sampling with temperature (nucleus sampling). - parallel_tool_calls (Optional[bool]): Whether to allow parallel tool calls. - prompt_cache_key (Optional[str]): Used by OpenAI to cache responses. - safety_identifier (Optional[str]): A stable identifier for policy monitoring. - service_tier (Optional[Literal["auto", "default", "flex", "priority"]]): Specifies the processing type. - store (Optional[bool]): Whether to store the generated model response. - temperature (Optional[float]): Sampling temperature to use (0 to 2). - tool_choice (Optional[str | object]): How the model should select which tool to use. - top_logprobs (Optional[int]): Number of most likely tokens to return at each position (0 to 20). 
- """ - conversation: Optional[str] = Field(None, description="The conversation that this response belongs to.") - include: Optional[List[Literal["code_interpreter_call.outputs", "computer_call_output.output.image_url", "file_search_call.results", "message.input_image.image_url", "message.output_text.logprobs", "reasoning.encrypted_content"]]] = Field(None, description="Specify additional output data to include in the model response.") - input: Optional[str | List[Dict[str, str]]] = Field(None, description="Text, image, or file inputs to the model, used to generate a response.") - instructions: Optional[str] = Field(None, description="A system (or developer) message inserted into the model's context.") - max_output_tokens: Optional[int] = Field(None, gt=0, description="An upper bound for the number of tokens that can be generated for a response, including visible output tokens and reasoning tokens.") - max_tool_calls: Optional[int] = Field(None, gt=0, description="The maximum number of total calls to built-in tools that can be processed in a response.") - previous_response_id: Optional[str] = Field(None, description="The unique ID of the previous response to the model. Use this to create multi-turn conversations.") - prompt: Optional[ReusablePrompt] = Field(None, description="Reference to a prompt template and its variables.") - reasoning: Optional[ReasoningConfig] = Field(None, description="Configuration options for reasoning models.") - text: Optional[object] = Field(None, description="Configuration options for a text response from the model.") - truncation: Optional[Literal["auto", "disabled"]] = Field(None, description="The truncation strategy to use for the model response.") - - def set_input_messages(self, messages: List[Message]) -> None: + Configuration for a /v1/responses API request. + + Attributes: + model (str): Model ID used to generate the response, like "gpt-4.1". Defaults to "gpt-4.1". 
+ conversation (Optional[str]): The conversation this response belongs to. + include (Optional[List[Literal[...]]]): Specify additional output data to include. + input (Optional[str | List[Dict[str, str]]]): Text, image, or file inputs to the model. + instructions (Optional[str]): A system or developer message. + max_output_tokens (Optional[int]): Upper bound for generated tokens. + max_tool_calls (Optional[int]): Maximum number of tool calls allowed. + previous_response_id (Optional[str]): ID of the previous response for multi-turn. + prompt (Optional[ReusablePrompt]): Reference to a prompt template and its variables. + reasoning (Optional[ReasoningConfig]): Configuration for reasoning models. + text (Optional[object]): Configuration options for a text response from the model (e.g., JSON schema). + truncation (Optional[Literal["auto", "disabled"]]): The truncation strategy to use. + tools (Optional[List[object]]): An array of tools the model may call. + top_p (Optional[float]): An alternative to sampling with temperature (nucleus sampling). + parallel_tool_calls (Optional[bool]): Whether to allow parallel tool calls. + prompt_cache_key (Optional[str]): Used by OpenAI to cache responses. + safety_identifier (Optional[str]): A stable identifier for policy monitoring. + service_tier (Optional[Literal["auto", "default", "flex", "priority"]]): Specifies the processing type. + store (Optional[bool]): Whether to store the generated model response. + temperature (Optional[float]): Sampling temperature to use (0 to 2). + tool_choice (Optional[str | object]): How the model should select which tool to use. + top_logprobs (Optional[int]): Number of most likely tokens to return at each position (0 to 20). + """ + + conversation: str | None = Field( + None, description="The conversation that this response belongs to." 
+ ) + include: ( + list[ + Literal[ + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", + ] + ] + | None + ) = Field(None, description="Specify additional output data to include in the model response.") + input: str | list[dict[str, str]] | None = Field( + None, description="Text, image, or file inputs to the model, used to generate a response." + ) + instructions: str | None = Field( + None, description="A system (or developer) message inserted into the model's context." + ) + max_output_tokens: int | None = Field( + None, + gt=0, + description="An upper bound for the number of tokens that can be generated for a response, including visible output tokens and reasoning tokens.", + ) + max_tool_calls: int | None = Field( + None, + gt=0, + description="The maximum number of total calls to built-in tools that can be processed in a response.", + ) + previous_response_id: str | None = Field( + None, + description="The unique ID of the previous response to the model. Use this to create multi-turn conversations.", + ) + prompt: ReusablePrompt | None = Field( + None, description="Reference to a prompt template and its variables." + ) + reasoning: ReasoningConfig | None = Field( + None, description="Configuration options for reasoning models." + ) + text: object | None = Field( + None, description="Configuration options for a text response from the model." + ) + truncation: Literal["auto", "disabled"] | None = Field( + None, description="The truncation strategy to use for the model response." 
+ ) + + def set_input_messages(self, messages: list[Message]) -> None: self.input = [m.serialize() for m in messages] def set_output_structure(self, output_type: type[T]) -> None: @@ -291,10 +390,11 @@ def set_output_structure(self, output_type: type[T]) -> None: "type": "json_schema", "name": output_type.__name__, "schema": schema, - "strict": True + "strict": True, } } + class ChatCompletionsRequest(TextGenerationRequest): """ Configuration for a /v1/chat/completions API request. @@ -325,21 +425,57 @@ class ChatCompletionsRequest(TextGenerationRequest): tool_choice (Optional[str | object]): How the model should select which tool to use. top_logprobs (Optional[int]): Number of most likely tokens to return at each position (0 to 20). """ - messages: List[Dict[str, str]] = Field(None, description="A list of messages comprising the conversation so far.") - frequency_penalty: Optional[float] = Field(None, ge=-2, le=2, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.") - logit_bias: Optional[Dict] = Field(None, description="Modify the likelihood of specified tokens appearing in the completion.") - logprobs: Optional[bool] = Field(None, description="Whether to return log probabilities of the output tokens or not.") - max_completion_tokens: Optional[int] = Field(None, gt=0, description="An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.") - modalities: Optional[List[str]] = Field(None, description="Output types that you would like the model to generate.") - n: Optional[int] = Field(None, description="How many chat completion choices to generate for each input message.") - prediction: Optional[object] = Field(None, description="Configuration for a Predicted Output, which can greatly improve response times when large parts of the model 
response are known ahead of time.") - presence_penalty: Optional[float] = Field(None, ge=-2, le=2, description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.") - reasoning_effort: Optional[Literal["minimal", "low", "medium", "high"]] = Field(None, description="Constrains effort on reasoning for reasoning models.") - response_format: Optional[Dict] = Field(None, description="An object specifying the format that the model must output.") - verbosity: Optional[Literal["low", "medium", "high"]] = Field(None, description="Constrains the verbosity of the model's response.") - web_search_options: Optional[object] = Field(None, description="This tool searches the web for relevant results to use in a response.") - - def set_input_messages(self, messages: List[Message]) -> None: + + messages: list[dict[str, str]] = Field( + [], description="A list of messages comprising the conversation so far." + ) + frequency_penalty: float | None = Field( + None, + ge=-2, + le=2, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.", + ) + logit_bias: dict | None = Field( + None, description="Modify the likelihood of specified tokens appearing in the completion." + ) + logprobs: bool | None = Field( + None, description="Whether to return log probabilities of the output tokens or not." + ) + max_completion_tokens: int | None = Field( + None, + gt=0, + description="An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.", + ) + modalities: list[str] | None = Field( + None, description="Output types that you would like the model to generate." 
+ ) + n: int | None = Field( + None, description="How many chat completion choices to generate for each input message." + ) + prediction: object | None = Field( + None, + description="Configuration for a Predicted Output, which can greatly improve response times when large parts of the model response are known ahead of time.", + ) + presence_penalty: float | None = Field( + None, + ge=-2, + le=2, + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.", + ) + reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = Field( + None, description="Constrains effort on reasoning for reasoning models." + ) + response_format: dict | None = Field( + None, description="An object specifying the format that the model must output." + ) + verbosity: Literal["low", "medium", "high"] | None = Field( + None, description="Constrains the verbosity of the model's response." + ) + web_search_options: object | None = Field( + None, description="This tool searches the web for relevant results to use in a response." + ) + + def set_input_messages(self, messages: list[Message]) -> None: self.messages = [m.serialize() for m in messages] def set_output_structure(self, output_type: type[T]) -> None: @@ -349,10 +485,11 @@ def set_output_structure(self, output_type: type[T]) -> None: "type": "json_schema", "name": output_type.__name__, "schema": schema, - "strict": True + "strict": True, } } + class EmbeddingsRequest(BaseRequest): """ Configuration for a /v1/embeddings API request. @@ -364,14 +501,27 @@ class EmbeddingsRequest(BaseRequest): encoding_format (Optional[Literal["base64", "float"]]): The format to return the embeddings in. user (Optional[str]): A unique identifier representing the end-user. 
""" - input: Union[str | List[str]] = Field(None, description="Input text to embed, encoded as a string or array of tokens.") - dimensions: Optional[int] = Field(None, ge=1, description="The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.") - encoding_format: Optional[Literal["base64", "float"]] = Field(None, description="The format to return the embeddings in. Can be either float or base64.") - user: Optional[str] = Field(None, description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. ") - def set_input(self, inp: Union[str | List[str]]) -> None: + input: str | list[str] = Field( + "", description="Input text to embed, encoded as a string or array of tokens." + ) + dimensions: int | None = Field( + None, + ge=1, + description="The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.", + ) + encoding_format: Literal["base64", "float"] | None = Field( + None, description="The format to return the embeddings in. Can be either float or base64." + ) + user: str | None = Field( + None, + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. ", + ) + + def set_input(self, inp: str | list[str]) -> None: self.input = inp + class RequestTemplate(BaseModel): """ A template defining a batch job, including its name, description, @@ -385,21 +535,22 @@ class RequestTemplate(BaseModel): request (BaseRequest): The API-specific request configuration (e.g., ResponsesRequest). metadata (Optional[Dict[Any, Any]]): Optional metadata associated with the request template. 
""" + name: str description: str - prompt: Union[PromptTemplate, ReusablePrompt] + prompt: PromptTemplate | ReusablePrompt request: BaseRequest - metadata: Optional[Dict[Any, Any]] = None + metadata: dict[Any, Any] | None = None def save(self, path: PathLike) -> None: if not str(path).endswith(".json"): - raise ValueError("RequestTemplate has to be saves as \".json\" file.") + raise ValueError('RequestTemplate has to be saves as ".json" file.') Path(path).parent.mkdir(parents=True, exist_ok=True) - with open(path, 'w+') as f: + with open(path, "w+") as f: f.write(self.model_dump_json(indent=4)) @classmethod - def load(cls, path: PathLike) -> Self: - with open(path, 'r') as f: + def load(cls, path: PathLike) -> "RequestTemplate": + with open(path) as f: return RequestTemplate.model_validate_json(f.read()) diff --git a/src/openbatch/validation.py b/src/openbatch/validation.py new file mode 100644 index 0000000..0176d9e --- /dev/null +++ b/src/openbatch/validation.py @@ -0,0 +1,326 @@ +""" +Validation utilities for OpenAI batch job files. + +This module provides functions to validate JSONL batch files before uploading them +to the OpenAI Batch API, helping catch errors early and ensure compliance with +API requirements. +""" + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, ClassVar, TextIO + + +@dataclass +class ValidationResult: + """ + Result of batch file validation. 
+ + Attributes: + is_valid (bool): Whether the batch file is valid + errors (List[str]): List of validation errors found + warnings (List[str]): List of non-critical warnings + stats (Dict[str, Any]): Statistics about the batch file + """ + + is_valid: bool + errors: list[str] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + stats: dict[str, Any] = field(default_factory=dict) + + def __str__(self) -> str: + """Human-readable summary of validation results.""" + lines = [] + lines.append(f"Validation: {'PASSED' if self.is_valid else 'FAILED'}") + + if self.stats: + lines.append("\nStatistics:") + for key, value in self.stats.items(): + lines.append(f" {key}: {value}") + + if self.errors: + lines.append(f"\nErrors ({len(self.errors)}):") + for error in self.errors: + lines.append(f" β€’ {error}") + + if self.warnings: + lines.append(f"\nWarnings ({len(self.warnings)}):") + for warning in self.warnings: + lines.append(f" β€’ {warning}") + + return "\n".join(lines) + + +class BatchFileValidator: + """ + Validator for OpenAI batch job JSONL files. + + Validates batch files against OpenAI Batch API requirements including: + - Valid JSONL format + - Unique custom_ids + - Required fields present + - Valid endpoint URLs + - File size limits + """ + + # OpenAI Batch API constraints (as of 2026) + MAX_FILE_SIZE_MB: int = 200 + MAX_REQUESTS: int = 50000 + VALID_ENDPOINTS: ClassVar[set[str]] = { + "/v1/responses", + "/v1/chat/completions", + "/v1/embeddings", + } + REQUIRED_FIELDS: ClassVar[set[str]] = {"custom_id", "method", "url", "body"} + + def __init__( + self, + check_custom_id_uniqueness: bool = True, + check_file_size: bool = True, + check_request_count: bool = True, + allow_mixed_endpoints: bool = False, + ): + """ + Initialize the validator with configuration options. 
+ + Args: + check_custom_id_uniqueness: Check for duplicate custom_ids + check_file_size: Check file size against limits + check_request_count: Check number of requests against limits + allow_mixed_endpoints: Allow multiple endpoint types in one file (not recommended) + """ + self.check_custom_id_uniqueness = check_custom_id_uniqueness + self.check_file_size = check_file_size + self.check_request_count = check_request_count + self.allow_mixed_endpoints = allow_mixed_endpoints + + def validate_file(self, file_path: str | Path) -> ValidationResult: + """ + Validate a batch file. + + Args: + file_path: Path to the JSONL batch file + + Returns: + ValidationResult with errors, warnings, and statistics + """ + file_path = Path(file_path) + result = ValidationResult(is_valid=True) + + # Check file exists + if not file_path.exists(): + result.errors.append(f"File not found: {file_path}") + result.is_valid = False + return result + + # Check file extension + if file_path.suffix != ".jsonl": + result.warnings.append(f"File extension is '{file_path.suffix}', expected '.jsonl'") + + # Check file size + if self.check_file_size: + file_size_mb = file_path.stat().st_size / (1024 * 1024) + result.stats["file_size_mb"] = round(file_size_mb, 2) + + if file_size_mb > self.MAX_FILE_SIZE_MB: + result.errors.append( + f"File size ({file_size_mb:.2f} MB) exceeds limit ({self.MAX_FILE_SIZE_MB} MB)" + ) + result.is_valid = False + + # Validate content + try: + with open(file_path, encoding="utf-8") as f: + self._validate_content(f, result) + except Exception as e: + result.errors.append(f"Error reading file: {e!s}") + result.is_valid = False + + return result + + def _validate_content(self, file_handle: TextIO, result: ValidationResult) -> None: + """Validate the content of the batch file.""" + custom_ids: set[str] = set() + endpoints: set[str] = set() + line_number = 0 + + for line in file_handle: + line_number += 1 + line = line.strip() + + # Skip empty lines + if not line: + 
result.warnings.append(f"Line {line_number}: Empty line (will be ignored)") + continue + + # Parse JSON + try: + request = json.loads(line) + except json.JSONDecodeError as e: + result.errors.append(f"Line {line_number}: Invalid JSON - {e!s}") + result.is_valid = False + continue + + # Validate request structure + self._validate_request(request, line_number, custom_ids, endpoints, result) + + # Update statistics + result.stats["total_requests"] = line_number + result.stats["unique_custom_ids"] = len(custom_ids) + result.stats["endpoints_used"] = sorted(endpoints) + + # Check request count + if self.check_request_count and line_number > self.MAX_REQUESTS: + result.errors.append( + f"Request count ({line_number}) exceeds limit ({self.MAX_REQUESTS})" + ) + result.is_valid = False + + # Check for mixed endpoints + if not self.allow_mixed_endpoints and len(endpoints) > 1: + result.warnings.append( + f"Multiple endpoint types detected: {sorted(endpoints)}. " + "OpenAI recommends one request type per file." 
+ ) + + def _validate_request( + self, + request: dict[str, Any], + line_number: int, + custom_ids: set[str], + endpoints: set[str], + result: ValidationResult, + ) -> None: + """Validate a single request object.""" + + # Check required fields + missing_fields = self.REQUIRED_FIELDS - set(request.keys()) + if missing_fields: + result.errors.append(f"Line {line_number}: Missing required fields: {missing_fields}") + result.is_valid = False + return + + # Validate custom_id + custom_id = request.get("custom_id") + if not custom_id or not isinstance(custom_id, str): + result.errors.append( + f"Line {line_number}: Invalid custom_id (must be a non-empty string)" + ) + result.is_valid = False + elif self.check_custom_id_uniqueness: + if custom_id in custom_ids: + result.errors.append(f"Line {line_number}: Duplicate custom_id '{custom_id}'") + result.is_valid = False + else: + custom_ids.add(custom_id) + + # Validate method + method = request.get("method") + if method != "POST": + result.errors.append(f"Line {line_number}: Invalid method '{method}' (must be 'POST')") + result.is_valid = False + + # Validate URL + url = request.get("url") + if url not in self.VALID_ENDPOINTS: + result.errors.append( + f"Line {line_number}: Invalid endpoint '{url}'. 
" + f"Valid endpoints: {self.VALID_ENDPOINTS}" + ) + result.is_valid = False + else: + endpoints.add(url) + + # Validate body + body = request.get("body") + if not isinstance(body, dict): + result.errors.append(f"Line {line_number}: 'body' must be a JSON object") + result.is_valid = False + else: + self._validate_body(body, str(url), line_number, result) + + def _validate_body( + self, body: dict[str, Any], endpoint: str, line_number: int, result: ValidationResult + ) -> None: + """Validate the request body based on endpoint type.""" + + # Check for model field (required for all endpoints) + if "model" not in body: + result.errors.append(f"Line {line_number}: Missing required field 'model' in body") + result.is_valid = False + + # Endpoint-specific validation + if endpoint == "/v1/responses": + if "input" not in body and "prompt" not in body: + result.errors.append( + f"Line {line_number}: Responses API requires either 'input' or 'prompt' in body" + ) + result.is_valid = False + + elif endpoint == "/v1/chat/completions": + if "messages" not in body: + result.errors.append( + f"Line {line_number}: Chat Completions API requires 'messages' in body" + ) + result.is_valid = False + elif not isinstance(body["messages"], list): + result.errors.append(f"Line {line_number}: 'messages' must be an array") + result.is_valid = False + + elif endpoint == "/v1/embeddings" and "input" not in body: + result.errors.append(f"Line {line_number}: Embeddings API requires 'input' in body") + result.is_valid = False + + +def validate_batch_file( + file_path: str | Path, + strict: bool = True, + check_custom_id_uniqueness: bool = True, + allow_mixed_endpoints: bool = False, +) -> ValidationResult: + """ + Validate a batch file (convenience function). 
+ + Args: + file_path: Path to the JSONL batch file + strict: Enable all checks (file size, request count) + check_custom_id_uniqueness: Check for duplicate custom_ids + allow_mixed_endpoints: Allow multiple endpoint types in one file + + Returns: + ValidationResult with errors, warnings, and statistics + + Example: + >>> result = validate_batch_file("my_batch.jsonl") + >>> if result.is_valid: + ... print("File is valid!") + ... else: + ... print(result) + """ + validator = BatchFileValidator( + check_custom_id_uniqueness=check_custom_id_uniqueness, + check_file_size=strict, + check_request_count=strict, + allow_mixed_endpoints=allow_mixed_endpoints, + ) + return validator.validate_file(file_path) + + +def quick_validate(file_path: str | Path) -> bool: + """ + Quick validation check (returns True/False). + + Args: + file_path: Path to the JSONL batch file + + Returns: + True if valid, False otherwise + + Example: + >>> if quick_validate("my_batch.jsonl"): + ... # Proceed with upload + ... 
pass + """ + result = validate_batch_file(file_path) + return result.is_valid diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_collector.py b/tests/test_collector.py new file mode 100644 index 0000000..ec5ae66 --- /dev/null +++ b/tests/test_collector.py @@ -0,0 +1,368 @@ +import json + +import pytest +from pydantic import BaseModel, Field + +from openbatch.collector import BatchCollector, ChatCompletions, Embeddings, Responses +from openbatch.model import ReasoningConfig + + +@pytest.fixture +def temp_batch_file(tmp_path): + """Provides a temporary file path for batch files.""" + return tmp_path / "test_batch.jsonl" + + +class TestResponses: + def test_responses_create(self, temp_batch_file): + responses = Responses(temp_batch_file) + responses.create( + custom_id="req_1", + model="gpt-4", + input="What is Python?", + instructions="You are a helpful assistant", + max_output_tokens=100, + ) + + assert temp_batch_file.exists() + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "req_1" + assert data["url"] == "/v1/responses" + assert data["body"]["model"] == "gpt-4" + assert data["body"]["input"] == "What is Python?" 
+ assert data["body"]["instructions"] == "You are a helpful assistant" + assert data["body"]["max_output_tokens"] == 100 + + def test_responses_parse_without_format(self, temp_batch_file): + responses = Responses(temp_batch_file) + responses.parse( + custom_id="req_2", + model="gpt-4", + input="Analyze this text", + instructions="Be concise", + ) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "req_2" + assert data["body"]["model"] == "gpt-4" + assert "text" not in data["body"] # No text format specified + + def test_responses_parse_with_format(self, temp_batch_file): + class Analysis(BaseModel): + summary: str = Field(description="Brief summary") + sentiment: str = Field(description="Sentiment analysis") + + responses = Responses(temp_batch_file) + responses.parse( + custom_id="req_3", + model="gpt-4", + text_format=Analysis, + input="Great product!", + instructions="Analyze sentiment", + ) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "req_3" + assert "text" in data["body"] + assert data["body"]["text"]["format"]["type"] == "json_schema" + assert data["body"]["text"]["format"]["name"] == "Analysis" + assert data["body"]["text"]["format"]["strict"] is True + + def test_responses_with_reasoning_config(self, temp_batch_file): + responses = Responses(temp_batch_file) + responses.create( + custom_id="req_4", + model="gpt-5-mini", + input="Complex problem", + reasoning=ReasoningConfig(effort="high", summary="detailed"), + ) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["body"]["reasoning"]["effort"] == "high" + assert data["body"]["reasoning"]["summary"] == "detailed" + + def test_responses_multiple_requests(self, temp_batch_file): + responses = Responses(temp_batch_file) + responses.create(custom_id="req_1", model="gpt-4", input="First") + responses.create(custom_id="req_2", model="gpt-4", input="Second") + + with 
open(temp_batch_file) as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + assert data1["custom_id"] == "req_1" + assert data2["custom_id"] == "req_2" + + +class TestChatCompletions: + def test_chat_completions_create(self, temp_batch_file): + chat = ChatCompletions(temp_batch_file) + chat.create( + custom_id="chat_1", + model="gpt-4", + messages=[ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + ], + temperature=0.7, + ) + + assert temp_batch_file.exists() + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "chat_1" + assert data["url"] == "/v1/chat/completions" + assert data["body"]["model"] == "gpt-4" + assert len(data["body"]["messages"]) == 2 + assert data["body"]["temperature"] == 0.7 + + def test_chat_completions_parse_without_format(self, temp_batch_file): + chat = ChatCompletions(temp_batch_file) + chat.parse( + custom_id="chat_2", + model="gpt-4", + messages=[{"role": "user", "content": "Hi"}], + ) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "chat_2" + assert "response_format" not in data["body"] + + def test_chat_completions_parse_with_format(self, temp_batch_file): + class Response(BaseModel): + answer: str + confidence: float + + chat = ChatCompletions(temp_batch_file) + chat.parse( + custom_id="chat_3", + model="gpt-4", + response_format=Response, + messages=[{"role": "user", "content": "What is 2+2?"}], + ) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "chat_3" + assert "response_format" in data["body"] + assert data["body"]["response_format"]["format"]["name"] == "Response" + assert data["body"]["response_format"]["format"]["strict"] is True + + def test_chat_completions_with_reasoning_effort(self, temp_batch_file): + chat = ChatCompletions(temp_batch_file) + chat.create( + 
custom_id="chat_4", + model="o1-mini", + messages=[{"role": "user", "content": "Complex question"}], + reasoning_effort="high", + ) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["body"]["reasoning_effort"] == "high" + + def test_chat_completions_multiple_requests(self, temp_batch_file): + chat = ChatCompletions(temp_batch_file) + chat.create(custom_id="chat_1", model="gpt-4", messages=[{"role": "user", "content": "Hi"}]) + chat.create( + custom_id="chat_2", + model="gpt-4", + messages=[{"role": "user", "content": "Bye"}], + ) + + with open(temp_batch_file) as f: + lines = f.readlines() + + assert len(lines) == 2 + + +class TestEmbeddings: + def test_embeddings_create_single_input(self, temp_batch_file): + embeddings = Embeddings(temp_batch_file) + embeddings.create( + custom_id="emb_1", + model="text-embedding-3-small", + inp="Text to embed", + ) + + assert temp_batch_file.exists() + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "emb_1" + assert data["url"] == "/v1/embeddings" + assert data["body"]["model"] == "text-embedding-3-small" + assert data["body"]["input"] == "Text to embed" + + def test_embeddings_create_list_input(self, temp_batch_file): + embeddings = Embeddings(temp_batch_file) + embeddings.create( + custom_id="emb_2", + model="text-embedding-3-small", + inp=["Text 1", "Text 2", "Text 3"], + ) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert isinstance(data["body"]["input"], list) + assert len(data["body"]["input"]) == 3 + + def test_embeddings_with_dimensions(self, temp_batch_file): + embeddings = Embeddings(temp_batch_file) + embeddings.create( + custom_id="emb_3", + model="text-embedding-3-small", + inp="Test", + dimensions=512, + ) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["body"]["dimensions"] == 512 + + def test_embeddings_multiple_requests(self, temp_batch_file): + embeddings 
= Embeddings(temp_batch_file) + embeddings.create(custom_id="emb_1", model="text-embedding-3-small", inp="First") + embeddings.create(custom_id="emb_2", model="text-embedding-3-small", inp="Second") + + with open(temp_batch_file) as f: + lines = f.readlines() + + assert len(lines) == 2 + + +class TestBatchCollector: + def test_batch_collector_initialization(self, temp_batch_file): + collector = BatchCollector(temp_batch_file) + assert isinstance(collector.responses, Responses) + assert isinstance(collector.chat.completions, ChatCompletions) + assert isinstance(collector.embeddings, Embeddings) + + def test_batch_collector_responses_api(self, temp_batch_file): + collector = BatchCollector(temp_batch_file) + collector.responses.create( + custom_id="req_1", + model="gpt-4", + input="Hello", + ) + + assert temp_batch_file.exists() + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["url"] == "/v1/responses" + + def test_batch_collector_chat_completions_api(self, temp_batch_file): + collector = BatchCollector(temp_batch_file) + collector.chat.completions.create( + custom_id="chat_1", + model="gpt-4", + messages=[{"role": "user", "content": "Hi"}], + ) + + assert temp_batch_file.exists() + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["url"] == "/v1/chat/completions" + + def test_batch_collector_embeddings_api(self, temp_batch_file): + collector = BatchCollector(temp_batch_file) + collector.embeddings.create( + custom_id="emb_1", + model="text-embedding-3-small", + inp="Text", + ) + + assert temp_batch_file.exists() + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["url"] == "/v1/embeddings" + + def test_batch_collector_mixed_apis_in_sequence(self, tmp_path): + # Demonstrate that different endpoints need different files + responses_file = tmp_path / "responses.jsonl" + chat_file = tmp_path / "chat.jsonl" + embeddings_file = tmp_path / "embeddings.jsonl" + + 
responses_collector = BatchCollector(responses_file) + responses_collector.responses.create(custom_id="req_1", model="gpt-4", input="Test") + + chat_collector = BatchCollector(chat_file) + chat_collector.chat.completions.create( + custom_id="chat_1", + model="gpt-4", + messages=[{"role": "user", "content": "Test"}], + ) + + embeddings_collector = BatchCollector(embeddings_file) + embeddings_collector.embeddings.create( + custom_id="emb_1", model="text-embedding-3-small", inp="Test" + ) + + assert responses_file.exists() + assert chat_file.exists() + assert embeddings_file.exists() + + def test_batch_collector_responses_parse_structured_output(self, temp_batch_file): + class TaskAnalysis(BaseModel): + task_type: str + complexity: str + estimated_time: int + + collector = BatchCollector(temp_batch_file) + collector.responses.parse( + custom_id="analysis_1", + model="gpt-4", + text_format=TaskAnalysis, + input="Analyze this task: Build a web scraper", + instructions="Provide structured analysis", + ) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert "text" in data["body"] + assert data["body"]["text"]["format"]["name"] == "TaskAnalysis" + assert "task_type" in str(data["body"]["text"]["format"]["schema"]) + + def test_batch_collector_chat_parse_structured_output(self, temp_batch_file): + class CodeReview(BaseModel): + issues: list[str] + suggestions: list[str] + rating: int + + collector = BatchCollector(temp_batch_file) + collector.chat.completions.parse( + custom_id="review_1", + model="gpt-4", + response_format=CodeReview, + messages=[ + {"role": "system", "content": "You are a code reviewer"}, + {"role": "user", "content": "Review this code: def foo(): pass"}, + ], + ) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert "response_format" in data["body"] + assert data["body"]["response_format"]["format"]["name"] == "CodeReview" diff --git a/tests/test_integration.py b/tests/test_integration.py new file 
mode 100644 index 0000000..817fa94 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,394 @@ +"""Integration tests that verify end-to-end workflows.""" + +import json + +import pytest +from pydantic import BaseModel, Field + +from openbatch import ( + BatchCollector, + BatchJobManager, + EmbeddingInputInstance, + EmbeddingsRequest, + Message, + PromptTemplate, + PromptTemplateInputInstance, + ReasoningConfig, + ResponsesRequest, +) + + +@pytest.fixture +def temp_dir(tmp_path): + """Provides a temporary directory for test files.""" + return tmp_path + + +class TestEndToEndBatchCreation: + """Test complete workflows from creation to file output.""" + + def test_batch_collector_complete_workflow(self, temp_dir): + """Test BatchCollector API for all three endpoint types.""" + # Separate files for each API type (as required by OpenAI) + responses_file = temp_dir / "responses.jsonl" + chat_file = temp_dir / "chat.jsonl" + embeddings_file = temp_dir / "embeddings.jsonl" + + # Responses API + responses_collector = BatchCollector(responses_file) + responses_collector.responses.create( + custom_id="resp_1", + model="gpt-4", + input="What is machine learning?", + instructions="Be concise", + max_output_tokens=100, + ) + + # Chat Completions API + chat_collector = BatchCollector(chat_file) + chat_collector.chat.completions.create( + custom_id="chat_1", + model="gpt-4", + messages=[ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Explain quantum computing"}, + ], + temperature=0.7, + ) + + # Embeddings API + embeddings_collector = BatchCollector(embeddings_file) + embeddings_collector.embeddings.create( + custom_id="emb_1", + model="text-embedding-3-small", + inp="Machine learning is a subset of artificial intelligence", + ) + + # Verify all files exist and contain valid JSON + assert responses_file.exists() + assert chat_file.exists() + assert embeddings_file.exists() + + with open(responses_file) as f: + resp_data = 
json.loads(f.readline()) + assert resp_data["url"] == "/v1/responses" + + with open(chat_file) as f: + chat_data = json.loads(f.readline()) + assert chat_data["url"] == "/v1/chat/completions" + + with open(embeddings_file) as f: + emb_data = json.loads(f.readline()) + assert emb_data["url"] == "/v1/embeddings" + + def test_batch_job_manager_templated_workflow(self, temp_dir): + """Test BatchJobManager with prompt templates for bulk generation.""" + batch_file = temp_dir / "marketing_batch.jsonl" + + # Setup + template = PromptTemplate( + messages=[ + Message( + role="system", + content="You are a marketing copywriter. Generate a catchy, two-sentence description.", + ), + Message( + role="user", + content="Product: {product_name}, Features: {features}", + ), + ] + ) + + common_config = ResponsesRequest(model="gpt-4-mini", temperature=0.8, max_output_tokens=100) + + products = [ + PromptTemplateInputInstance( + id="prod_001", + prompt_value_mapping={ + "product_name": "AeroGlide Drone", + "features": "4K camera, 30-min flight", + }, + ), + PromptTemplateInputInstance( + id="prod_002", + prompt_value_mapping={ + "product_name": "HydroPure Bottle", + "features": "Self-cleaning, insulated steel", + }, + ), + PromptTemplateInputInstance( + id="prod_003", + prompt_value_mapping={ + "product_name": "SmartDesk Pro", + "features": "Height adjustable, USB charging", + }, + instance_request_options={"temperature": 0.5}, # Override for this instance + ), + ] + + # Generate batch + manager = BatchJobManager() + manager.add_templated_instances( + prompt=template, + common_request=common_config, + input_instances=products, + save_file_path=batch_file, + ) + + # Verify + assert batch_file.exists() + with open(batch_file) as f: + lines = f.readlines() + + assert len(lines) == 3 + + # Check first product + data1 = json.loads(lines[0]) + assert data1["custom_id"] == "prod_001" + assert "AeroGlide Drone" in str(data1["body"]["input"]) + assert data1["body"]["temperature"] == 0.8 + + # 
Check third product with overridden temperature + data3 = json.loads(lines[2]) + assert data3["custom_id"] == "prod_003" + assert data3["body"]["temperature"] == 0.5 + + def test_batch_job_manager_embeddings_workflow(self, temp_dir): + """Test BatchJobManager for bulk embeddings generation.""" + batch_file = temp_dir / "embeddings_batch.jsonl" + + common_config = EmbeddingsRequest(model="text-embedding-3-small", dimensions=512) + + documents = [ + EmbeddingInputInstance(id="doc_1", input="The sky is blue."), + EmbeddingInputInstance(id="doc_2", input="Grass is green."), + EmbeddingInputInstance(id="doc_3", input="Water is wet."), + EmbeddingInputInstance( + id="doc_4", + input="Fire is hot.", + instance_request_options={"dimensions": 256}, + ), + ] + + manager = BatchJobManager() + manager.add_embedding_requests( + inputs=documents, common_request=common_config, save_file_path=batch_file + ) + + assert batch_file.exists() + with open(batch_file) as f: + lines = f.readlines() + + assert len(lines) == 4 + + data1 = json.loads(lines[0]) + assert data1["body"]["dimensions"] == 512 + + data4 = json.loads(lines[3]) + assert data4["body"]["dimensions"] == 256 # Overridden + + +class TestStructuredOutputWorkflows: + """Test workflows with structured JSON output.""" + + def test_responses_api_with_structured_output(self, temp_dir): + """Test Responses API with Pydantic model for structured output.""" + + class SentimentAnalysis(BaseModel): + sentiment: str = Field(description="positive, negative, or neutral") + confidence: float = Field(description="Confidence score 0-1") + key_phrases: list[str] = Field(description="Important phrases") + + batch_file = temp_dir / "sentiment_batch.jsonl" + collector = BatchCollector(batch_file) + + texts_to_analyze = [ + "This product exceeded my expectations! Absolutely love it.", + "Terrible experience. 
Would not recommend to anyone.", + "It's okay, nothing special but does the job.", + ] + + for idx, text in enumerate(texts_to_analyze): + collector.responses.parse( + custom_id=f"sentiment_{idx}", + model="gpt-4", + text_format=SentimentAnalysis, + input=text, + instructions="Analyze the sentiment of the given text", + ) + + assert batch_file.exists() + with open(batch_file) as f: + lines = f.readlines() + + assert len(lines) == 3 + + # Verify structured output configuration + data = json.loads(lines[0]) + assert "text" in data["body"] + assert data["body"]["text"]["format"]["name"] == "SentimentAnalysis" + assert data["body"]["text"]["format"]["strict"] is True + schema = data["body"]["text"]["format"]["schema"] + assert "sentiment" in schema["properties"] + assert "confidence" in schema["properties"] + assert "key_phrases" in schema["properties"] + + def test_chat_completions_with_structured_output(self, temp_dir): + """Test Chat Completions API with structured output.""" + + class RecipeExtraction(BaseModel): + recipe_name: str + ingredients: list[str] + steps: list[str] + prep_time_minutes: int + difficulty: str + + batch_file = temp_dir / "recipes_batch.jsonl" + collector = BatchCollector(batch_file) + + recipe_texts = [ + "How to make scrambled eggs: Beat 2 eggs, heat pan, cook for 2 minutes. Takes 5 minutes. Easy.", + "Chocolate cake recipe: Mix flour, sugar, cocoa. Bake at 350F for 30 minutes. Medium difficulty. 
Takes 45 minutes.", + ] + + for idx, text in enumerate(recipe_texts): + collector.chat.completions.parse( + custom_id=f"recipe_{idx}", + model="gpt-4", + response_format=RecipeExtraction, + messages=[ + { + "role": "system", + "content": "Extract structured recipe information", + }, + {"role": "user", "content": text}, + ], + ) + + with open(batch_file) as f: + lines = f.readlines() + + assert len(lines) == 2 + data = json.loads(lines[0]) + assert "response_format" in data["body"] + assert data["body"]["response_format"]["format"]["name"] == "RecipeExtraction" + + +class TestReasoningModelsWorkflow: + """Test workflows with reasoning models.""" + + def test_responses_api_with_reasoning_config(self, temp_dir): + """Test Responses API with reasoning configuration.""" + batch_file = temp_dir / "reasoning_batch.jsonl" + collector = BatchCollector(batch_file) + + complex_problems = [ + { + "id": "logic_1", + "problem": "If all A are B, and all B are C, are all A necessarily C?", + "effort": "high", + }, + { + "id": "logic_2", + "problem": "What is the flaw in this argument: All birds can fly. Penguins are birds. 
Therefore penguins can fly.", + "effort": "medium", + }, + ] + + for item in complex_problems: + collector.responses.create( + custom_id=item["id"], + model="o1-mini", + input=item["problem"], + instructions="Analyze the logical structure carefully", + reasoning=ReasoningConfig(effort=item["effort"], summary="detailed"), + ) + + with open(batch_file) as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + assert data1["body"]["reasoning"]["effort"] == "high" + assert data1["body"]["reasoning"]["summary"] == "detailed" + + +class TestUnicodeAndSpecialCharacters: + """Test handling of non-ASCII characters.""" + + def test_ensure_ascii_true(self, temp_dir): + """Test that ensure_ascii=True escapes non-ASCII characters.""" + batch_file = temp_dir / "unicode_escaped.jsonl" + manager = BatchJobManager(ensure_ascii=True) + + request = ResponsesRequest(model="gpt-4", input="Hello δΈ–η•Œ! ΠŸΡ€ΠΈΠ²Π΅Ρ‚! Ω…Ψ±Ψ­Ψ¨Ψ§") + manager.add("unicode_test", request, batch_file) + + with open(batch_file, encoding="utf-8") as f: + content = f.read() + + # Should contain escaped unicode + assert "\\u" in content + # Raw characters should not be present + assert "δΈ–η•Œ" not in content + + def test_ensure_ascii_false(self, temp_dir): + """Test that ensure_ascii=False preserves non-ASCII characters.""" + batch_file = temp_dir / "unicode_raw.jsonl" + manager = BatchJobManager(ensure_ascii=False) + + request = ResponsesRequest(model="gpt-4", input="Hello δΈ–η•Œ! ΠŸΡ€ΠΈΠ²Π΅Ρ‚! 
Ω…Ψ±Ψ­Ψ¨Ψ§ Emoji: πŸš€") + manager.add("unicode_test", request, batch_file) + + with open(batch_file, encoding="utf-8") as f: + content = f.read() + + # Raw characters should be present + assert "δΈ–η•Œ" in content + assert "ΠŸΡ€ΠΈΠ²Π΅Ρ‚" in content + assert "Ω…Ψ±Ψ­Ψ¨Ψ§" in content + assert "πŸš€" in content + + +class TestLargeScaleBatchGeneration: + """Test generation of large batch files.""" + + def test_generate_1000_requests(self, temp_dir): + """Test generating a batch file with 1000 requests.""" + batch_file = temp_dir / "large_batch.jsonl" + + template = PromptTemplate(messages=[Message(role="user", content="Classify: {text}")]) + + common_request = ResponsesRequest(model="gpt-4-mini", max_output_tokens=10) + + # Generate 1000 instances + instances = [ + PromptTemplateInputInstance( + id=f"classify_{i:04d}", prompt_value_mapping={"text": f"Sample text {i}"} + ) + for i in range(1000) + ] + + manager = BatchJobManager() + manager.add_templated_instances( + prompt=template, + common_request=common_request, + input_instances=instances, + save_file_path=batch_file, + ) + + # Verify + assert batch_file.exists() + with open(batch_file) as f: + lines = f.readlines() + + assert len(lines) == 1000 + + # Spot check first and last + first = json.loads(lines[0]) + last = json.loads(lines[999]) + + assert first["custom_id"] == "classify_0000" + assert last["custom_id"] == "classify_0999" + assert "Sample text 0" in str(first["body"]["input"]) + assert "Sample text 999" in str(last["body"]["input"]) diff --git a/tests/test_manager.py b/tests/test_manager.py new file mode 100644 index 0000000..cbb0b3e --- /dev/null +++ b/tests/test_manager.py @@ -0,0 +1,374 @@ +import json +import warnings + +import pytest + +from openbatch.manager import BatchJobManager +from openbatch.model import ( + ChatCompletionsRequest, + EmbeddingInputInstance, + EmbeddingsRequest, + Message, + PromptTemplate, + PromptTemplateInputInstance, + ResponsesRequest, + ReusablePrompt, +) + + 
+@pytest.fixture +def temp_batch_file(tmp_path): + """Provides a temporary file path for batch files.""" + return tmp_path / "test_batch.jsonl" + + +@pytest.fixture +def manager(): + """Provides a BatchJobManager instance.""" + return BatchJobManager() + + +@pytest.fixture +def manager_no_ascii(): + """Provides a BatchJobManager instance with ensure_ascii=False.""" + return BatchJobManager(ensure_ascii=False) + + +class TestBatchJobManagerAdd: + def test_add_responses_request(self, manager, temp_batch_file): + request = ResponsesRequest(model="gpt-4", input="Hello world") + manager.add("test_id", request, temp_batch_file) + + assert temp_batch_file.exists() + with open(temp_batch_file) as f: + line = f.readline() + data = json.loads(line) + + assert data["custom_id"] == "test_id" + assert data["method"] == "POST" + assert data["url"] == "/v1/responses" + assert data["body"]["model"] == "gpt-4" + assert data["body"]["input"] == "Hello world" + + def test_add_chat_completions_request(self, manager, temp_batch_file): + request = ChatCompletionsRequest( + model="gpt-4", messages=[{"role": "user", "content": "Hi"}] + ) + manager.add("chat_id", request, temp_batch_file) + + assert temp_batch_file.exists() + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "chat_id" + assert data["url"] == "/v1/chat/completions" + assert data["body"]["messages"][0]["role"] == "user" + + def test_add_embeddings_request(self, manager, temp_batch_file): + request = EmbeddingsRequest(model="text-embedding-3-small", input="Text to embed") + manager.add("emb_id", request, temp_batch_file) + + assert temp_batch_file.exists() + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["custom_id"] == "emb_id" + assert data["url"] == "/v1/embeddings" + assert data["body"]["input"] == "Text to embed" + + def test_add_multiple_requests(self, manager, temp_batch_file): + request1 = ResponsesRequest(model="gpt-4", 
input="First") + request2 = ResponsesRequest(model="gpt-4", input="Second") + + manager.add("id1", request1, temp_batch_file) + manager.add("id2", request2, temp_batch_file) + + with open(temp_batch_file) as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + assert data1["custom_id"] == "id1" + assert data2["custom_id"] == "id2" + + def test_add_responses_request_without_input_or_prompt_raises(self, manager, temp_batch_file): + request = ResponsesRequest(model="gpt-4") + with pytest.raises(ValueError, match="must define either an input or a prompt"): + manager.add("test_id", request, temp_batch_file) + + def test_add_chat_completions_request_without_messages_raises(self, manager, temp_batch_file): + # messages is required in the Pydantic model, so we create a request + # and then set messages to None manually + request = ChatCompletionsRequest(model="gpt-4", messages=[]) + request.messages = None + with pytest.raises(ValueError, match="must define messages"): + manager.add("test_id", request, temp_batch_file) + + def test_add_embeddings_request_without_input_raises(self, manager, temp_batch_file): + # input is required in the Pydantic model, so we create a request + # and then set input to None manually + request = EmbeddingsRequest(model="text-embedding-3-small", input="dummy") + request.input = None + with pytest.raises(ValueError, match="must define an input"): + manager.add("test_id", request, temp_batch_file) + + def test_add_creates_parent_directory(self, manager, tmp_path): + nested_path = tmp_path / "subdir" / "batch.jsonl" + request = ResponsesRequest(model="gpt-4", input="Test") + manager.add("test_id", request, nested_path) + + assert nested_path.exists() + assert nested_path.parent.exists() + + def test_add_with_ensure_ascii_false(self, manager_no_ascii, temp_batch_file): + request = ResponsesRequest(model="gpt-4", input="Hello δΈ–η•Œ") + manager_no_ascii.add("test_id", request, 
temp_batch_file) + + with open(temp_batch_file, encoding="utf-8") as f: + content = f.read() + data = json.loads(content) + + assert "δΈ–η•Œ" in content # Non-ASCII characters preserved + assert data["body"]["input"] == "Hello δΈ–η•Œ" + + def test_add_with_ensure_ascii_true(self, manager, temp_batch_file): + request = ResponsesRequest(model="gpt-4", input="Hello δΈ–η•Œ") + manager.add("test_id", request, temp_batch_file) + + with open(temp_batch_file, encoding="utf-8") as f: + raw_content = f.read() + + # ASCII escaped version should not contain the raw unicode characters + assert "\\u" in raw_content + + +class TestBatchJobManagerTemplatedInstances: + def test_add_templated_instances_responses_api(self, manager, temp_batch_file): + template = PromptTemplate( + messages=[Message(role="user", content="Product: {product}, Price: {price}")] + ) + common_request = ResponsesRequest(model="gpt-4", temperature=0.7) + instances = [ + PromptTemplateInputInstance( + id="prod_1", prompt_value_mapping={"product": "Laptop", "price": "$1000"} + ), + PromptTemplateInputInstance( + id="prod_2", prompt_value_mapping={"product": "Mouse", "price": "$20"} + ), + ] + + manager.add_templated_instances(template, common_request, instances, temp_batch_file) + + with open(temp_batch_file) as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + + assert data1["custom_id"] == "prod_1" + assert "Laptop" in str(data1["body"]["input"]) + assert data2["custom_id"] == "prod_2" + assert "Mouse" in str(data2["body"]["input"]) + + def test_add_templated_instances_chat_completions_api(self, manager, temp_batch_file): + template = PromptTemplate( + messages=[ + Message(role="system", content="You are a {role}"), + Message(role="user", content="{question}"), + ] + ) + common_request = ChatCompletionsRequest(model="gpt-4", temperature=0.5) + instances = [ + PromptTemplateInputInstance( + id="q1", + prompt_value_mapping={"role": "teacher", 
"question": "What is math?"}, + ), + PromptTemplateInputInstance( + id="q2", prompt_value_mapping={"role": "chef", "question": "How to cook?"} + ), + ] + + manager.add_templated_instances(template, common_request, instances, temp_batch_file) + + with open(temp_batch_file) as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + + assert data1["body"]["messages"][0]["content"] == "You are a teacher" + assert data1["body"]["messages"][1]["content"] == "What is math?" + + def test_add_templated_instances_with_reusable_prompt(self, manager, temp_batch_file): + reusable_prompt = ReusablePrompt(id="prompt_123", version="v1", variables={}) + common_request = ResponsesRequest(model="gpt-4") + instances = [ + PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={"var": "value"}) + ] + + manager.add_templated_instances(reusable_prompt, common_request, instances, temp_batch_file) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert data["body"]["prompt"]["id"] == "prompt_123" + assert data["body"]["prompt"]["version"] == "v1" + assert data["body"]["prompt"]["variables"]["var"] == "value" + + def test_add_templated_instances_with_instance_options(self, manager, temp_batch_file): + template = PromptTemplate(messages=[Message(role="user", content="{text}")]) + common_request = ResponsesRequest(model="gpt-4", temperature=0.7) + instances = [ + PromptTemplateInputInstance( + id="inst_1", + prompt_value_mapping={"text": "Hello"}, + instance_request_options={"temperature": 0.9}, + ), + PromptTemplateInputInstance(id="inst_2", prompt_value_mapping={"text": "World"}), + ] + + manager.add_templated_instances(template, common_request, instances, temp_batch_file) + + with open(temp_batch_file) as f: + lines = f.readlines() + + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + + # First instance should override temperature + assert data1["body"]["temperature"] == 0.9 + # Second instance should use common 
temperature + assert data2["body"]["temperature"] == 0.7 + + def test_add_templated_instances_with_embeddings_raises(self, manager, temp_batch_file): + template = PromptTemplate(messages=[Message(role="user", content="Test")]) + common_request = EmbeddingsRequest(model="text-embedding-3-small", input="dummy") + instances = [ + PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={"text": "Test"}) + ] + + with pytest.raises(ValueError, match="Embeddings API is not supported"): + manager.add_templated_instances(template, common_request, instances, temp_batch_file) + + def test_add_templated_instances_reusable_prompt_with_chat_raises( + self, manager, temp_batch_file + ): + reusable_prompt = ReusablePrompt(id="prompt_123", version="v1", variables={}) + common_request = ChatCompletionsRequest(model="gpt-4", messages=[]) + instances = [ + PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={"var": "value"}) + ] + + with pytest.raises(ValueError, match="Reusable prompts can only be used"): + manager.add_templated_instances( + reusable_prompt, common_request, instances, temp_batch_file + ) + + def test_add_templated_instances_appending_warning(self, manager, temp_batch_file): + # Create the file first + temp_batch_file.write_text("existing content\n") + + template = PromptTemplate(messages=[Message(role="user", content="Test")]) + common_request = ResponsesRequest(model="gpt-4") + instances = [PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={})] + + # Should warn when appending to existing file + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + manager.add_templated_instances(template, common_request, instances, temp_batch_file) + assert len(w) == 1 + assert "already exists" in str(w[0].message) + + def test_add_templated_instances_suppress_warnings(self, manager, temp_batch_file): + temp_batch_file.write_text("existing content\n") + + template = PromptTemplate(messages=[Message(role="user", 
content="Test")]) + common_request = ResponsesRequest(model="gpt-4") + instances = [PromptTemplateInputInstance(id="inst_1", prompt_value_mapping={})] + + # Should not warn when suppress_warnings=True + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + manager.add_templated_instances( + template, + common_request, + instances, + temp_batch_file, + suppress_warnings=True, + ) + assert len(w) == 0 + + +class TestBatchJobManagerEmbeddingRequests: + def test_add_embedding_requests(self, manager, temp_batch_file): + common_request = EmbeddingsRequest(model="text-embedding-3-small", dimensions=512) + inputs = [ + EmbeddingInputInstance(id="emb_1", input="First text"), + EmbeddingInputInstance(id="emb_2", input="Second text"), + ] + + manager.add_embedding_requests(inputs, common_request, temp_batch_file) + + with open(temp_batch_file) as f: + lines = f.readlines() + + assert len(lines) == 2 + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + + assert data1["custom_id"] == "emb_1" + assert data1["body"]["input"] == "First text" + assert data1["body"]["dimensions"] == 512 + + assert data2["custom_id"] == "emb_2" + assert data2["body"]["input"] == "Second text" + + def test_add_embedding_requests_with_list_input(self, manager, temp_batch_file): + common_request = EmbeddingsRequest(model="text-embedding-3-small") + inputs = [ + EmbeddingInputInstance(id="emb_1", input=["Text 1", "Text 2"]), + ] + + manager.add_embedding_requests(inputs, common_request, temp_batch_file) + + with open(temp_batch_file) as f: + data = json.loads(f.readline()) + + assert isinstance(data["body"]["input"], list) + assert len(data["body"]["input"]) == 2 + + def test_add_embedding_requests_with_instance_options(self, manager, temp_batch_file): + common_request = EmbeddingsRequest(model="text-embedding-3-small", dimensions=512) + inputs = [ + EmbeddingInputInstance( + id="emb_1", + input="First", + instance_request_options={"dimensions": 256}, + ), + 
EmbeddingInputInstance(id="emb_2", input="Second"), + ] + + manager.add_embedding_requests(inputs, common_request, temp_batch_file) + + with open(temp_batch_file) as f: + lines = f.readlines() + + data1 = json.loads(lines[0]) + data2 = json.loads(lines[1]) + + # First instance should override dimensions + assert data1["body"]["dimensions"] == 256 + # Second instance should use common dimensions + assert data2["body"]["dimensions"] == 512 + + def test_add_embedding_requests_creates_parent_directory(self, manager, tmp_path): + nested_path = tmp_path / "subdir" / "embeddings.jsonl" + common_request = EmbeddingsRequest(model="text-embedding-3-small") + inputs = [EmbeddingInputInstance(id="emb_1", input="Test")] + + manager.add_embedding_requests(inputs, common_request, nested_path) + + assert nested_path.exists() + assert nested_path.parent.exists() diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..16c8a98 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,252 @@ +from pydantic import BaseModel, Field + +from openbatch.model import ( + ChatCompletionsAPIStrategy, + ChatCompletionsRequest, + EmbeddingInputInstance, + EmbeddingsAPIStrategy, + EmbeddingsRequest, + Message, + MessagesInputInstance, + PromptTemplate, + PromptTemplateInputInstance, + ReasoningConfig, + ResponsesAPIStrategy, + ResponsesRequest, + ReusablePrompt, +) + + +class TestMessage: + def test_message_creation(self): + msg = Message(role="user", content="Hello") + assert msg.role == "user" + assert msg.content == "Hello" + + def test_message_serialize(self): + msg = Message(role="system", content="You are helpful") + serialized = msg.serialize() + assert serialized == {"role": "system", "content": "You are helpful"} + + +class TestPromptTemplate: + def test_prompt_template_creation(self): + template = PromptTemplate( + messages=[ + Message(role="system", content="You are a {role}"), + Message(role="user", content="Help me with {task}"), + ] + ) + assert 
len(template.messages) == 2 + + def test_prompt_template_format(self): + template = PromptTemplate( + messages=[ + Message(role="system", content="You are a {role}"), + Message(role="user", content="Help me with {task}"), + ] + ) + formatted = template.format(role="assistant", task="coding") + assert len(formatted) == 2 + assert formatted[0].content == "You are a assistant" + assert formatted[1].content == "Help me with coding" + + def test_prompt_template_format_multiple_placeholders(self): + template = PromptTemplate( + messages=[ + Message( + role="user", + content="Product: {product}, Price: {price}, Category: {category}", + ) + ] + ) + formatted = template.format(product="Laptop", price="$1000", category="Electronics") + assert formatted[0].content == "Product: Laptop, Price: $1000, Category: Electronics" + + +class TestReusablePrompt: + def test_reusable_prompt_creation(self): + prompt = ReusablePrompt(id="prompt_123", version="v1", variables={"name": "John"}) + assert prompt.id == "prompt_123" + assert prompt.version == "v1" + assert prompt.variables == {"name": "John"} + + +class TestReasoningConfig: + def test_reasoning_config_default(self): + config = ReasoningConfig() + assert config.effort == "medium" + assert config.summary is None + + def test_reasoning_config_custom(self): + config = ReasoningConfig(effort="high", summary="detailed") + assert config.effort == "high" + assert config.summary == "detailed" + + +class TestInputInstances: + def test_prompt_template_input_instance(self): + instance = PromptTemplateInputInstance( + id="inst_1", + prompt_value_mapping={"name": "Alice", "age": "30"}, + instance_request_options={"temperature": 0.5}, + ) + assert instance.id == "inst_1" + assert instance.prompt_value_mapping == {"name": "Alice", "age": "30"} + assert instance.instance_request_options == {"temperature": 0.5} + + def test_messages_input_instance(self): + messages = [Message(role="user", content="Hello")] + instance = 
MessagesInputInstance(id="inst_2", messages=messages) + assert instance.id == "inst_2" + assert len(instance.messages) == 1 + + def test_embedding_input_instance(self): + instance = EmbeddingInputInstance(id="emb_1", input="Text to embed") + assert instance.id == "emb_1" + assert instance.input == "Text to embed" + + def test_embedding_input_instance_list(self): + instance = EmbeddingInputInstance(id="emb_2", input=["Text 1", "Text 2"]) + assert instance.id == "emb_2" + assert isinstance(instance.input, list) + assert len(instance.input) == 2 + + +class TestAPIStrategies: + def test_responses_api_strategy(self): + strategy = ResponsesAPIStrategy() + assert strategy.url == "/v1/responses" + request = strategy.create_request("test_id", {"model": "gpt-4"}) + assert request["custom_id"] == "test_id" + assert request["method"] == "POST" + assert request["url"] == "/v1/responses" + assert request["body"] == {"model": "gpt-4"} + + def test_chat_completions_api_strategy(self): + strategy = ChatCompletionsAPIStrategy() + assert strategy.url == "/v1/chat/completions" + + def test_embeddings_api_strategy(self): + strategy = EmbeddingsAPIStrategy() + assert strategy.url == "/v1/embeddings" + + +class TestResponsesRequest: + def test_responses_request_minimal(self): + request = ResponsesRequest(model="gpt-4") + assert request.model == "gpt-4" + assert request.input is None + + def test_responses_request_with_input(self): + request = ResponsesRequest(model="gpt-4", input="Hello world") + assert request.input == "Hello world" + + def test_responses_request_to_dict(self): + request = ResponsesRequest(model="gpt-4", input="Hello", temperature=0.7) + result = request.to_dict() + assert result["model"] == "gpt-4" + assert result["input"] == "Hello" + assert result["temperature"] == 0.7 + + def test_responses_request_exclude_none(self): + request = ResponsesRequest(model="gpt-4", input="Hello") + result = request.to_dict() + assert "temperature" not in result + assert 
"max_output_tokens" not in result + + def test_responses_request_set_input_messages(self): + request = ResponsesRequest(model="gpt-4") + messages = [Message(role="user", content="Hello")] + request.set_input_messages(messages) + assert request.input == [{"role": "user", "content": "Hello"}] + + def test_responses_request_set_output_structure(self): + class TestOutput(BaseModel): + name: str + age: int + + request = ResponsesRequest(model="gpt-4") + request.set_output_structure(TestOutput) + assert request.text is not None + assert "format" in request.text + assert request.text["format"]["type"] == "json_schema" + assert request.text["format"]["name"] == "TestOutput" + assert request.text["format"]["strict"] is True + + def test_responses_request_with_reasoning(self): + request = ResponsesRequest( + model="gpt-4", reasoning=ReasoningConfig(effort="high", summary="detailed") + ) + assert request.reasoning.effort == "high" + assert request.reasoning.summary == "detailed" + + +class TestChatCompletionsRequest: + def test_chat_completions_request_minimal(self): + messages = [{"role": "user", "content": "Hello"}] + request = ChatCompletionsRequest(model="gpt-4", messages=messages) + assert request.model == "gpt-4" + assert len(request.messages) == 1 + + def test_chat_completions_request_set_input_messages(self): + request = ChatCompletionsRequest(model="gpt-4", messages=[]) + messages = [Message(role="user", content="Hi")] + request.set_input_messages(messages) + assert request.messages == [{"role": "user", "content": "Hi"}] + + def test_chat_completions_request_set_output_structure(self): + class TestResponse(BaseModel): + answer: str = Field(description="The answer") + + request = ChatCompletionsRequest(model="gpt-4", messages=[]) + request.set_output_structure(TestResponse) + assert request.response_format is not None + assert "format" in request.response_format + assert request.response_format["format"]["name"] == "TestResponse" + + def 
test_chat_completions_request_with_temperature(self): + request = ChatCompletionsRequest( + model="gpt-4", messages=[{"role": "user", "content": "Hi"}], temperature=0.9 + ) + assert request.temperature == 0.9 + + def test_chat_completions_request_to_dict(self): + request = ChatCompletionsRequest( + model="gpt-4", + messages=[{"role": "user", "content": "Hi"}], + temperature=0.5, + max_completion_tokens=100, + ) + result = request.to_dict() + assert result["model"] == "gpt-4" + assert result["temperature"] == 0.5 + assert result["max_completion_tokens"] == 100 + + +class TestEmbeddingsRequest: + def test_embeddings_request_with_string(self): + request = EmbeddingsRequest(model="text-embedding-3-small", input="Hello") + assert request.model == "text-embedding-3-small" + assert request.input == "Hello" + + def test_embeddings_request_with_list(self): + request = EmbeddingsRequest(model="text-embedding-3-small", input=["Hello", "World"]) + assert isinstance(request.input, list) + assert len(request.input) == 2 + + def test_embeddings_request_set_input(self): + request = EmbeddingsRequest(model="text-embedding-3-small", input="test") + request.set_input("New text") + assert request.input == "New text" + + def test_embeddings_request_with_dimensions(self): + request = EmbeddingsRequest(model="text-embedding-3-small", input="test", dimensions=512) + assert request.dimensions == 512 + + def test_embeddings_request_to_dict(self): + request = EmbeddingsRequest(model="text-embedding-3-small", input="test", dimensions=256) + result = request.to_dict() + assert result["model"] == "text-embedding-3-small" + assert result["input"] == "test" + assert result["dimensions"] == 256 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..4f89994 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,261 @@ +import pytest +from pydantic import BaseModel, Field + +from openbatch._utils import ( + _ensure_strict_json_schema, + has_more_than_n_keys, + 
resolve_ref, + type_to_json_schema, +) + + +class TestHasMoreThanNKeys: + def test_empty_dict(self): + assert has_more_than_n_keys({}, 0) is False + assert has_more_than_n_keys({}, 1) is False + + def test_single_key(self): + assert has_more_than_n_keys({"a": 1}, 0) is True + assert has_more_than_n_keys({"a": 1}, 1) is False + assert has_more_than_n_keys({"a": 1}, 2) is False + + def test_multiple_keys(self): + obj = {"a": 1, "b": 2, "c": 3} + assert has_more_than_n_keys(obj, 0) is True + assert has_more_than_n_keys(obj, 1) is True + assert has_more_than_n_keys(obj, 2) is True + assert has_more_than_n_keys(obj, 3) is False + + +class TestResolveRef: + def test_resolve_simple_ref(self): + root = {"definitions": {"Person": {"type": "object"}}} + result = resolve_ref(root=root, ref="#/definitions/Person") + assert result == {"type": "object"} + + def test_resolve_nested_ref(self): + root = {"definitions": {"Nested": {"properties": {"field": {"type": "string"}}}}} + result = resolve_ref(root=root, ref="#/definitions/Nested/properties/field") + assert result == {"type": "string"} + + def test_resolve_ref_invalid_format(self): + root = {"definitions": {}} + with pytest.raises(ValueError, match="Does not start with #/"): + resolve_ref(root=root, ref="definitions/Person") + + +class TestEnsureStrictJsonSchema: + def test_object_adds_additional_properties_false(self): + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert result["additionalProperties"] is False + + def test_object_preserves_existing_additional_properties(self): + schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + "additionalProperties": True, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert result["additionalProperties"] is True + + def test_object_makes_all_properties_required(self): + schema = { + "type": "object", + "properties": {"name": {"type": 
"string"}, "age": {"type": "integer"}}, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert "required" in result + assert set(result["required"]) == {"name", "age"} + + def test_nested_object_properties(self): + schema = { + "type": "object", + "properties": { + "person": { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + }, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert result["properties"]["person"]["additionalProperties"] is False + assert result["properties"]["person"]["required"] == ["name"] + + def test_array_items(self): + schema = {"type": "array", "items": {"type": "string"}} + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert result["items"]["type"] == "string" + + def test_array_with_object_items(self): + schema = { + "type": "array", + "items": {"type": "object", "properties": {"id": {"type": "integer"}}}, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert result["items"]["additionalProperties"] is False + assert result["items"]["required"] == ["id"] + + def test_any_of_union(self): + schema = { + "anyOf": [ + {"type": "string"}, + {"type": "object", "properties": {"a": {"type": "string"}}}, + ] + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert len(result["anyOf"]) == 2 + assert result["anyOf"][1]["additionalProperties"] is False + + def test_all_of_single_element_processed(self): + # Test that allOf with single element is processed + # Note: The implementation handles multi-element allOf, but single-element + # allOf isn't necessarily unwrapped by the current implementation + schema = { + "allOf": [ + {"type": "object", "properties": {"name": {"type": "string"}}}, + {"type": "object", "properties": {"age": {"type": "integer"}}}, + ] + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + # Verify allOf entries are processed + assert "allOf" in result + assert 
len(result["allOf"]) == 2 + # Each entry should have properties from original schema + assert result["allOf"][0]["type"] == "object" + assert result["allOf"][1]["type"] == "object" + + def test_definitions_processed(self): + schema = { + "type": "object", + "properties": {"user": {"$ref": "#/definitions/User"}}, + "definitions": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert result["definitions"]["User"]["additionalProperties"] is False + assert result["definitions"]["User"]["required"] == ["name"] + + def test_defs_processed(self): + schema = { + "type": "object", + "properties": {"user": {"$ref": "#/$defs/User"}}, + "$defs": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + assert result["$defs"]["User"]["additionalProperties"] is False + assert result["$defs"]["User"]["required"] == ["name"] + + def test_ref_with_additional_properties_unrolled(self): + schema = { + "type": "object", + "properties": { + "user": { + "$ref": "#/definitions/User", + "description": "The user object", + } + }, + "definitions": {"User": {"type": "object", "properties": {"name": {"type": "string"}}}}, + } + result = _ensure_strict_json_schema(schema, path=(), root=schema) + # The $ref should be unrolled when there are additional properties + user_prop = result["properties"]["user"] + assert "$ref" not in user_prop + assert user_prop["type"] == "object" + assert user_prop["description"] == "The user object" + assert user_prop["additionalProperties"] is False + + +class TestTypeToJsonSchema: + def test_simple_model(self): + class SimpleModel(BaseModel): + name: str + age: int + + schema = type_to_json_schema(SimpleModel) + assert schema["type"] == "object" + assert "name" in schema["properties"] + assert "age" in schema["properties"] + assert schema["additionalProperties"] is False + 
assert set(schema["required"]) == {"name", "age"} + + def test_model_with_optional_field(self): + class ModelWithOptional(BaseModel): + name: str + nickname: str | None = None + + schema = type_to_json_schema(ModelWithOptional) + # All properties should be required in strict mode + assert set(schema["required"]) == {"name", "nickname"} + + def test_model_with_nested_object(self): + class Address(BaseModel): + street: str + city: str + + class Person(BaseModel): + name: str + address: Address + + schema = type_to_json_schema(Person) + assert schema["additionalProperties"] is False + assert "address" in schema["properties"] + + # Check nested object is also strict + if "$defs" in schema: + address_schema = schema["$defs"]["Address"] + else: + address_schema = schema["definitions"]["Address"] + + assert address_schema["additionalProperties"] is False + assert set(address_schema["required"]) == {"street", "city"} + + def test_model_with_list_field(self): + class TodoList(BaseModel): + title: str + items: list[str] + + schema = type_to_json_schema(TodoList) + assert schema["properties"]["items"]["type"] == "array" + assert schema["properties"]["items"]["items"]["type"] == "string" + + def test_model_with_field_descriptions(self): + class DescribedModel(BaseModel): + name: str = Field(description="The person's name") + age: int = Field(description="The person's age") + + schema = type_to_json_schema(DescribedModel) + assert schema["properties"]["name"]["description"] == "The person's name" + assert schema["properties"]["age"]["description"] == "The person's age" + + def test_complex_nested_model(self): + class Item(BaseModel): + id: int + name: str + + class Order(BaseModel): + order_id: str + items: list[Item] + total: float + + class Customer(BaseModel): + name: str + orders: list[Order] + + schema = type_to_json_schema(Customer) + assert schema["additionalProperties"] is False + assert "orders" in schema["properties"] + + # Verify all nested models have strict schema 
+ defs_key = "$defs" if "$defs" in schema else "definitions" + assert schema[defs_key]["Order"]["additionalProperties"] is False + assert schema[defs_key]["Item"]["additionalProperties"] is False + + def test_model_preserves_constraints(self): + class ConstrainedModel(BaseModel): + age: int = Field(ge=0, le=120) + email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$") + + schema = type_to_json_schema(ConstrainedModel) + assert schema["properties"]["age"]["minimum"] == 0 + assert schema["properties"]["age"]["maximum"] == 120 + assert "pattern" in schema["properties"]["email"] diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..27593d3 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,433 @@ +"""Tests for batch file validation.""" + +import json + +import pytest + +from openbatch.validation import ( + ValidationResult, + quick_validate, + validate_batch_file, +) + + +@pytest.fixture +def temp_batch_file(tmp_path): + """Provides a temporary file path for batch files.""" + return tmp_path / "test_batch.jsonl" + + +class TestValidationResult: + def test_validation_result_str(self): + result = ValidationResult( + is_valid=False, + errors=["Error 1", "Error 2"], + warnings=["Warning 1"], + stats={"total_requests": 10, "file_size_mb": 0.5}, + ) + output = str(result) + assert "FAILED" in output + assert "Error 1" in output + assert "Warning 1" in output + assert "total_requests: 10" in output + + def test_validation_result_success(self): + result = ValidationResult(is_valid=True, stats={"total_requests": 5}) + output = str(result) + assert "PASSED" in output + + +class TestBatchFileValidator: + def test_valid_batch_file(self, temp_batch_file): + """Test validation of a valid batch file.""" + requests = [ + { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"}, + }, + { + "custom_id": "req_2", + "method": "POST", + "url": "/v1/responses", + "body": 
{"model": "gpt-4", "input": "World"}, + }, + ] + + with open(temp_batch_file, "w") as f: + for req in requests: + f.write(json.dumps(req) + "\n") + + result = validate_batch_file(temp_batch_file) + assert result.is_valid + assert len(result.errors) == 0 + assert result.stats["total_requests"] == 2 + assert result.stats["unique_custom_ids"] == 2 + + def test_file_not_found(self): + """Test validation of non-existent file.""" + result = validate_batch_file("nonexistent.jsonl") + assert not result.is_valid + assert any("not found" in err.lower() for err in result.errors) + + def test_invalid_json(self, temp_batch_file): + """Test validation of file with invalid JSON.""" + with open(temp_batch_file, "w") as f: + f.write('{"custom_id": "req_1", "invalid json}\n') + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("invalid json" in err.lower() for err in result.errors) + + def test_duplicate_custom_ids(self, temp_batch_file): + """Test detection of duplicate custom_ids.""" + requests = [ + { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"}, + }, + { + "custom_id": "req_1", # Duplicate + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "World"}, + }, + ] + + with open(temp_batch_file, "w") as f: + for req in requests: + f.write(json.dumps(req) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("duplicate" in err.lower() for err in result.errors) + + def test_missing_required_fields(self, temp_batch_file): + """Test detection of missing required fields.""" + request = { + "custom_id": "req_1", + # Missing method, url, body + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("missing required fields" in err.lower() for err in result.errors) + + def 
test_invalid_method(self, temp_batch_file): + """Test detection of invalid HTTP method.""" + request = { + "custom_id": "req_1", + "method": "GET", # Should be POST + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"}, + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("invalid method" in err.lower() for err in result.errors) + + def test_invalid_endpoint(self, temp_batch_file): + """Test detection of invalid endpoint URL.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/invalid", # Invalid endpoint + "body": {"model": "gpt-4"}, + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("invalid endpoint" in err.lower() for err in result.errors) + + def test_responses_api_missing_input(self, temp_batch_file): + """Test Responses API validation - missing input/prompt.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4"}, # Missing input or prompt + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("input" in err.lower() or "prompt" in err.lower() for err in result.errors) + + def test_chat_completions_missing_messages(self, temp_batch_file): + """Test Chat Completions API validation - missing messages.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": "gpt-4"}, # Missing messages + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("messages" in err.lower() for err in result.errors) + + def 
test_chat_completions_invalid_messages(self, temp_batch_file): + """Test Chat Completions API validation - messages not an array.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": "gpt-4", "messages": "not an array"}, + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("messages" in err.lower() and "array" in err.lower() for err in result.errors) + + def test_embeddings_missing_input(self, temp_batch_file): + """Test Embeddings API validation - missing input.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/embeddings", + "body": {"model": "text-embedding-3-small"}, # Missing input + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("input" in err.lower() for err in result.errors) + + def test_mixed_endpoints_warning(self, temp_batch_file): + """Test warning for mixed endpoint types.""" + requests = [ + { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"}, + }, + { + "custom_id": "req_2", + "method": "POST", + "url": "/v1/embeddings", + "body": {"model": "text-embedding-3-small", "input": "World"}, + }, + ] + + with open(temp_batch_file, "w") as f: + for req in requests: + f.write(json.dumps(req) + "\n") + + result = validate_batch_file(temp_batch_file, allow_mixed_endpoints=False) + assert result.is_valid # Valid but with warning + assert any("multiple endpoint" in warn.lower() for warn in result.warnings) + + def test_empty_lines_warning(self, temp_batch_file): + """Test warning for empty lines.""" + with open(temp_batch_file, "w") as f: + f.write( + '{"custom_id": "req_1", "method": "POST", "url": "/v1/responses", "body": {"model": "gpt-4", "input": "Hi"}}\n' + ) + 
f.write("\n") # Empty line + f.write( + '{"custom_id": "req_2", "method": "POST", "url": "/v1/responses", "body": {"model": "gpt-4", "input": "Bye"}}\n' + ) + + result = validate_batch_file(temp_batch_file) + assert any("empty line" in warn.lower() for warn in result.warnings) + + def test_wrong_file_extension_warning(self, tmp_path): + """Test warning for wrong file extension.""" + json_file = tmp_path / "batch.json" # Should be .jsonl + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"}, + } + + with open(json_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(json_file) + assert any(".jsonl" in warn.lower() for warn in result.warnings) + + def test_missing_model_in_body(self, temp_batch_file): + """Test detection of missing model field in body.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"input": "Hello"}, # Missing model + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("model" in err.lower() for err in result.errors) + + def test_invalid_custom_id_type(self, temp_batch_file): + """Test detection of invalid custom_id type.""" + request = { + "custom_id": 123, # Should be string + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"}, + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("custom_id" in err.lower() for err in result.errors) + + def test_empty_custom_id(self, temp_batch_file): + """Test detection of empty custom_id.""" + request = { + "custom_id": "", # Empty string + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"}, + } + + with open(temp_batch_file, "w") as 
f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("custom_id" in err.lower() for err in result.errors) + + def test_body_not_object(self, temp_batch_file): + """Test detection of non-object body.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": "not an object", + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert not result.is_valid + assert any("body" in err.lower() and "object" in err.lower() for err in result.errors) + + def test_skip_custom_id_check(self, temp_batch_file): + """Test disabling custom_id uniqueness check.""" + requests = [ + { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"}, + }, + { + "custom_id": "req_1", # Duplicate + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "World"}, + }, + ] + + with open(temp_batch_file, "w") as f: + for req in requests: + f.write(json.dumps(req) + "\n") + + result = validate_batch_file(temp_batch_file, check_custom_id_uniqueness=False) + # Should be valid when uniqueness check is disabled + assert result.is_valid + + +class TestConvenienceFunctions: + def test_quick_validate_true(self, temp_batch_file): + """Test quick_validate with valid file.""" + request = { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"}, + } + + with open(temp_batch_file, "w") as f: + f.write(json.dumps(request) + "\n") + + assert quick_validate(temp_batch_file) is True + + def test_quick_validate_false(self, temp_batch_file): + """Test quick_validate with invalid file.""" + with open(temp_batch_file, "w") as f: + f.write("invalid json\n") + + assert quick_validate(temp_batch_file) is False + + +class TestComplexScenarios: + def test_large_valid_file(self, 
temp_batch_file): + """Test validation of file with many requests.""" + with open(temp_batch_file, "w") as f: + for i in range(1000): + request = { + "custom_id": f"req_{i}", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": f"Request {i}"}, + } + f.write(json.dumps(request) + "\n") + + result = validate_batch_file(temp_batch_file) + assert result.is_valid + assert result.stats["total_requests"] == 1000 + assert result.stats["unique_custom_ids"] == 1000 + + def test_all_three_endpoints(self, temp_batch_file): + """Test file with all three valid endpoints.""" + requests = [ + { + "custom_id": "req_1", + "method": "POST", + "url": "/v1/responses", + "body": {"model": "gpt-4", "input": "Hello"}, + }, + { + "custom_id": "req_2", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}, + }, + { + "custom_id": "req_3", + "method": "POST", + "url": "/v1/embeddings", + "body": {"model": "text-embedding-3-small", "input": "Text"}, + }, + ] + + with open(temp_batch_file, "w") as f: + for req in requests: + f.write(json.dumps(req) + "\n") + + result = validate_batch_file(temp_batch_file, allow_mixed_endpoints=True) + assert result.is_valid + assert len(result.stats["endpoints_used"]) == 3