diff --git a/examples/ask_validation_example.py b/examples/ask_validation_example.py new file mode 100644 index 0000000..404652b --- /dev/null +++ b/examples/ask_validation_example.py @@ -0,0 +1,165 @@ +""" +Example demonstrating ask type validation in werk24-python client. + +This example shows how the client validates ask types before sending +requests to the API, providing helpful error messages for invalid asks. +""" + +import asyncio + +from werk24 import Werk24Client +from werk24.models.v1.ask import W24AskTitleBlock, W24AskVariantMeasures +from werk24.models.v2.asks import AskBalloons, AskFeatures +from werk24.utils.exceptions import BadRequestException + + +async def example_valid_asks(): + """Example with valid ask types (both v1 and v2).""" + print("=" * 60) + print("Example 1: Valid ask types") + print("=" * 60) + + # Create a list of valid asks (mixing v1 and v2) + asks = [ + W24AskTitleBlock(), + W24AskVariantMeasures(), + AskBalloons(), + AskFeatures(), + ] + + # Validate asks before sending (this is done automatically by read_drawing) + try: + Werk24Client.validate_asks(asks) + print("✓ All ask types are valid!") + print(f" - {len(asks)} asks validated successfully") + except BadRequestException as e: + print(f"✗ Validation failed: {e}") + + print() + + +async def example_invalid_ask(): + """Example with an invalid ask type.""" + print("=" * 60) + print("Example 2: Invalid ask type") + print("=" * 60) + + from pydantic import BaseModel + + # Create an invalid ask type + class InvalidAsk(BaseModel): + ask_type: str = "NONEXISTENT_ASK_TYPE" + + asks = [ + W24AskTitleBlock(), + InvalidAsk(), # This will fail validation + ] + + try: + Werk24Client.validate_asks(asks) + print("✓ All ask types are valid!") + except BadRequestException as e: + print(f"✗ Validation failed:") + print(f" {e}") + + print() + + +async def example_empty_asks(): + """Example with empty ask list.""" + print("=" * 60) + print("Example 3: Empty ask list") + print("=" * 60) + + asks = [] + + try: + Werk24Client.validate_asks(asks) + print("✓ All ask types are valid!") + except BadRequestException as e: + print(f"✗ Validation failed:") + print(f" {e}") + + print() + + +async def example_validation_in_read_drawing(): + """Example showing automatic validation in read_drawing.""" + print("=" * 60) + print("Example 4: Automatic validation in read_drawing") + print("=" * 60) + + import io + + from pydantic import BaseModel + + class InvalidAsk(BaseModel): + ask_type: str = "INVALID_TYPE" + + client = Werk24Client() + drawing = io.BytesIO(b"fake drawing content") + + print("Attempting to call read_drawing with invalid ask type...") + + try: + async with client: + async for message in client.read_drawing(drawing, [InvalidAsk()]): + print(f"Received message: {message}") + except BadRequestException as e: + print(f"✓ Validation caught the error before sending to API:") + print(f" {e}") + except Exception as e: + print(f"Other error: {e}") + + print() + + +async def example_helpful_error_message(): + """Example showing the helpful error message with valid ask types.""" + print("=" * 60) + print("Example 5: Helpful error message") + print("=" * 60) + + from pydantic import BaseModel + + class InvalidAsk1(BaseModel): + ask_type: str = "WRONG_TYPE_1" + + class InvalidAsk2(BaseModel): + ask_type: str = "WRONG_TYPE_2" + + asks = [InvalidAsk1(), InvalidAsk2()] + + try: + Werk24Client.validate_asks(asks) + except BadRequestException as e: + error_msg = str(e) + print("Error message includes:") + print(f" - Invalid ask types: WRONG_TYPE_1, WRONG_TYPE_2") + print(f" - List of all valid ask types") + print() + print("Full error message (truncated):") + print(f" {error_msg[:200]}...") + + print() + + +async def main(): + """Run all examples.""" + print("\n" + "=" * 60) + print("Ask Type Validation Examples") + print("=" * 60 + "\n") + + await example_valid_asks() + await example_invalid_ask() + await example_empty_asks() + await example_validation_in_read_drawing() + await example_helpful_error_message() + + print("=" * 60) + print("Examples completed!") + print("=" * 60) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/test_pydantic_validation.py b/test_pydantic_validation.py new file mode 100644 index 0000000..efc46b4 --- /dev/null +++ b/test_pydantic_validation.py @@ -0,0 +1,42 @@ +"""Quick test to see if Pydantic already validates ask types.""" + +from pydantic import BaseModel, ValidationError + +from werk24.models.v1.ask import W24AskTitleBlock +from werk24.models.v2.asks import AskBalloons +from werk24.models.v2.internal import TechreadRequest + +# Test 1: Valid asks should work +print("Test 1: Valid asks") +try: + request = TechreadRequest(asks=[W24AskTitleBlock(), AskBalloons()], max_pages=5) + print(f"✓ Valid asks accepted: {len(request.asks)} asks") +except ValidationError as e: + print(f"✗ Validation error: {e}") + +print() + +# Test 2: Invalid ask type +print("Test 2: Invalid ask type") + + +class InvalidAsk(BaseModel): + ask_type: str = "INVALID_TYPE" + + +try: + request = TechreadRequest(asks=[InvalidAsk()], max_pages=5) + print(f"✗ Invalid ask was accepted (shouldn't happen)") +except ValidationError as e: + print(f"✓ Pydantic caught the invalid ask:") + print(f" {e.errors()[0]['msg']}") + +print() + +# Test 3: Empty asks list +print("Test 3: Empty asks list") +try: + request = TechreadRequest(asks=[], max_pages=5) + print(f"✓ Empty asks list accepted (Pydantic doesn't validate list length)") +except ValidationError as e: + print(f"✗ Validation error: {e}") diff --git a/tests/test_ask_validation.py b/tests/test_ask_validation.py new file mode 100644 index 0000000..6f1e65a --- /dev/null +++ b/tests/test_ask_validation.py @@ -0,0 +1,118 @@ +""" +Tests for ask type validation in Werk24Client. + +This module tests the validate_asks() method to ensure it properly validates +both W24AskType (v1) and AskType (v2) ask types. +""" + +import pytest + +from werk24 import Werk24Client +from werk24.models.v1.ask import ( + W24AskTitleBlock, + W24AskVariantGDTs, + W24AskVariantMeasures, +) +from werk24.models.v2.asks import AskBalloons, AskFeatures, AskInsights +from werk24.utils.exceptions import BadRequestException + + +class TestAskValidation: + """Test suite for ask type validation.""" + + def test_validate_asks_with_valid_v1_asks(self): + """Test that valid v1 ask types pass validation.""" + asks = [ + W24AskTitleBlock(), + W24AskVariantMeasures(), + W24AskVariantGDTs(), + ] + # Should not raise any exception + Werk24Client.validate_asks(asks) + + def test_validate_asks_with_valid_v2_asks(self): + """Test that valid v2 ask types pass validation.""" + asks = [ + AskBalloons(), + AskFeatures(), + AskInsights(), + ] + # Should not raise any exception + Werk24Client.validate_asks(asks) + + def test_validate_asks_with_mixed_v1_and_v2_asks(self): + """Test that mixed v1 and v2 ask types pass validation.""" + asks = [ + W24AskTitleBlock(), + AskBalloons(), + W24AskVariantMeasures(), + AskFeatures(), + ] + # Should not raise any exception + Werk24Client.validate_asks(asks) + + def test_validate_asks_with_empty_list(self): + """Test that empty ask list raises BadRequestException.""" + with pytest.raises(BadRequestException) as exc_info: + Werk24Client.validate_asks([]) + + assert "No ask types provided" in str(exc_info.value) + + def test_validate_asks_with_invalid_ask_type(self): + """Test that invalid ask type raises BadRequestException with helpful message.""" + from pydantic import BaseModel + + class InvalidAsk(BaseModel): + ask_type: str = "INVALID_ASK_TYPE" + + with pytest.raises(BadRequestException) as exc_info: + Werk24Client.validate_asks([InvalidAsk()]) + + error_msg = str(exc_info.value) + assert "Invalid ask type(s): INVALID_ASK_TYPE" in error_msg + assert "Valid ask types are:" in error_msg + + def test_validate_asks_with_multiple_invalid_ask_types(self): + """Test that multiple invalid ask types are all reported.""" + from pydantic import BaseModel + + class InvalidAsk1(BaseModel): + ask_type: str = "INVALID_TYPE_1" + + class InvalidAsk2(BaseModel): + ask_type: str = "INVALID_TYPE_2" + + with pytest.raises(BadRequestException) as exc_info: + Werk24Client.validate_asks([InvalidAsk1(), InvalidAsk2()]) + + error_msg = str(exc_info.value) + assert "INVALID_TYPE_1" in error_msg + assert "INVALID_TYPE_2" in error_msg + + def test_validate_asks_with_missing_ask_type_attribute(self): + """Test that ask without ask_type attribute raises BadRequestException.""" + from pydantic import BaseModel + + class AskWithoutType(BaseModel): + some_field: str = "value" + + with pytest.raises(BadRequestException) as exc_info: + Werk24Client.validate_asks([AskWithoutType()]) + + error_msg = str(exc_info.value) + assert "(missing ask_type)" in error_msg + + def test_validate_asks_error_message_includes_valid_types(self): + """Test that error message includes list of valid ask types.""" + from pydantic import BaseModel + + class InvalidAsk(BaseModel): + ask_type: str = "INVALID" + + with pytest.raises(BadRequestException) as exc_info: + Werk24Client.validate_asks([InvalidAsk()]) + + error_msg = str(exc_info.value) + # Check that some known valid types are in the error message + assert "TITLE_BLOCK" in error_msg or "BALLOONS" in error_msg + assert "VARIANT_MEASURES" in error_msg or "FEATURES" in error_msg diff --git a/tests/test_ask_validation_integration.py b/tests/test_ask_validation_integration.py new file mode 100644 index 0000000..314a871 --- /dev/null +++ b/tests/test_ask_validation_integration.py @@ -0,0 +1,124 @@ +""" +Integration tests for ask type validation in Werk24Client methods. + +This module tests that ask validation is properly integrated into the +client's read_drawing and read_drawing_with_callback methods. +""" + +import io +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from werk24 import Werk24Client +from werk24.models.v1.ask import W24AskTitleBlock +from werk24.models.v2.asks import AskBalloons +from werk24.utils.exceptions import BadRequestException + + +class TestAskValidationIntegration: + """Integration tests for ask validation in client methods.""" + + @pytest.mark.asyncio + async def test_read_drawing_validates_asks_before_processing(self): + """Test that read_drawing validates asks before processing.""" + from pydantic import BaseModel + + class InvalidAsk(BaseModel): + ask_type: str = "INVALID_TYPE" + + client = Werk24Client() + drawing = io.BytesIO(b"fake drawing content") + + # Should raise BadRequestException due to invalid ask type + with pytest.raises(BadRequestException) as exc_info: + async with client: + async for _ in client.read_drawing(drawing, [InvalidAsk()]): + pass + + assert "Invalid ask type(s): INVALID_TYPE" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_read_drawing_with_callback_validates_asks(self): + """Test that read_drawing_with_callback validates asks.""" + from pydantic import BaseModel + + class InvalidAsk(BaseModel): + ask_type: str = "INVALID_TYPE" + + client = Werk24Client() + drawing = io.BytesIO(b"fake drawing content") + + # Should raise BadRequestException due to invalid ask type + # before even trying to make the HTTP request + with pytest.raises(BadRequestException) as exc_info: + await client.read_drawing_with_callback( + drawing, [InvalidAsk()], callback_url="https://example.com/callback" + ) + + assert "Invalid ask type(s): INVALID_TYPE" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_read_drawing_with_empty_asks_raises_error(self): + """Test that read_drawing with empty asks raises BadRequestException.""" + client = Werk24Client() + drawing = io.BytesIO(b"fake drawing content") + + with pytest.raises(BadRequestException) as exc_info: + async with client: + async for _ in client.read_drawing(drawing, []): + pass + + assert "No ask types provided" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_read_drawing_with_valid_asks_passes_validation(self): + """Test that read_drawing with valid asks passes validation.""" + client = Werk24Client() + drawing = io.BytesIO(b"fake drawing content") + asks = [W24AskTitleBlock(), AskBalloons()] + + # Mock the websocket connection to avoid actual API calls + with patch.object(client, "_create_websocket_session") as mock_ws: + mock_ws_instance = AsyncMock() + mock_ws_instance.__aenter__ = AsyncMock(return_value=mock_ws_instance) + mock_ws_instance.__aexit__ = AsyncMock() + mock_ws.return_value = mock_ws_instance + + # Mock the send and recv methods + mock_ws_instance.send = AsyncMock() + mock_ws_instance.recv = AsyncMock(side_effect=Exception("Stop iteration")) + + try: + async with client: + async for _ in client.read_drawing(drawing, asks): + pass + except Exception: + # We expect an exception because we're mocking, but the important + # thing is that validation passed (no BadRequestException) + pass + + def test_validate_asks_can_be_called_directly(self): + """Test that validate_asks can be called directly as a static method.""" + asks = [W24AskTitleBlock(), AskBalloons()] + + # Should not raise any exception + Werk24Client.validate_asks(asks) + + def test_validate_asks_provides_helpful_error_messages(self): + """Test that validation errors include helpful information.""" + from pydantic import BaseModel + + class InvalidAsk(BaseModel): + ask_type: str = "NONEXISTENT_ASK" + + with pytest.raises(BadRequestException) as exc_info: + Werk24Client.validate_asks([InvalidAsk()]) + + error_msg = str(exc_info.value) + # Should mention the invalid type + assert "NONEXISTENT_ASK" in error_msg + # Should provide list of valid types + assert "Valid ask types are:" in error_msg + # Should include some actual valid types + assert "TITLE_BLOCK" in error_msg or "BALLOONS" in error_msg diff --git a/werk24/techread.py b/werk24/techread.py index 82d7f1a..8da6ac7 100644 --- a/werk24/techread.py +++ b/werk24/techread.py @@ -98,6 +98,80 @@ def __init__( # certificate chain is properly verified. self._ssl_context = ssl.create_default_context(cafile=certifi.where()) + @staticmethod + def validate_asks(asks: List[AskV2]) -> None: + """ + Validate ask types before sending request to the API. + + This method checks if all provided ask types are valid according to either + API v1 (W24AskType) or API v2 (AskType) specifications. It raises a + BadRequestException with helpful error messages if invalid ask types are found. + + Parameters + ---------- + asks : List[AskV2] + List of ask types to validate. Can be W24Ask (v1) or AskV2 (v2) objects. + + Raises + ------ + BadRequestException + If any ask types are invalid, with a message listing the invalid types + and all valid ask types. + + Examples + -------- + >>> from werk24.models.v1.ask import W24AskVariantMeasures + >>> from werk24.models.v2.asks import AskBalloons + >>> asks = [W24AskVariantMeasures(), AskBalloons()] + >>> Werk24Client.validate_asks(asks) # No exception raised + + >>> from pydantic import BaseModel + >>> class InvalidAsk(BaseModel): + ... ask_type = "INVALID_TYPE" + >>> Werk24Client.validate_asks([InvalidAsk()]) # Raises BadRequestException + """ + from werk24.models.v1.ask import W24AskType + from werk24.models.v2.enums import AskType + + if not asks: + raise BadRequestException( + "No ask types provided. At least one ask type is required." + ) + + # Get all valid ask types from both versions + valid_v1_types = {ask_type.value for ask_type in W24AskType} + valid_v2_types = {ask_type.value for ask_type in AskType} + all_valid_types = valid_v1_types | valid_v2_types + + # Extract and validate ask type names from the input + invalid_asks = [] + for ask in asks: + # Get the ask_type attribute + ask_type_value = getattr(ask, "ask_type", None) + + if ask_type_value is None: + invalid_asks.append("(missing ask_type)") + continue + + # Convert enum to string if needed + if hasattr(ask_type_value, "value"): + ask_type_str = ask_type_value.value + else: + ask_type_str = str(ask_type_value) + + # Check if the ask type is valid + if ask_type_str not in all_valid_types: + invalid_asks.append(ask_type_str) + + if invalid_asks: + # Create helpful error message + sorted_valid_types = sorted(all_valid_types) + error_msg = ( + f"Invalid ask type(s): {', '.join(invalid_asks)}. " + f"Valid ask types are: {', '.join(sorted_valid_types)}" + ) + raise BadRequestException(error_msg) + def _get_auth_headers(self): """ Get the authentication headers for the request. @@ -166,7 +240,23 @@ async def read_drawing_with_hooks( max_pages: int = settings.max_pages, encryption_keys: Optional[EncryptionKeys] = None, ): + """ + Read the drawing and call hooks for each message. + + This method extracts asks from hooks and processes the drawing. + Ask validation is performed by the read_drawing method. + + Args: + ---- + - drawing (Union[BufferedReader, bytes]): The drawing to process. + - hooks (list[Hook]): List of hooks to call for each message. + - max_pages (int, optional): Maximum number of pages to process. + - encryption_keys (Optional[EncryptionKeys], optional): Optional encryption keys. + Raises: + ------ + - BadRequestException: If ask types are invalid. + """ asks_list = [cur_ask.ask for cur_ask in hooks if cur_ask.ask is not None] # send out the request and make a generator @@ -192,10 +282,11 @@ async def read_drawing( This function performs the following steps: 1. Validates the input drawing. - 2. Sends an initiation request with the specified questions (`asks`). - 3. Uploads the drawing to the server. - 4. Signals the server to start reading the uploaded drawing. - 5. Yields messages as the process progresses. + 2. Validates the ask types. + 3. Sends an initiation request with the specified questions (`asks`). + 4. Uploads the drawing to the server. + 5. Signals the server to start reading the uploaded drawing. + 6. Yields messages as the process progresses. Args: ---- @@ -212,13 +303,16 @@ async def read_drawing( Raises: ------ - - BadRequestException: If the request is malformed. + - BadRequestException: If the request is malformed or ask types are invalid. - RequestTooLargeException: If the drawing exceeds the maximum size limit. - Any other exceptions encountered will be logged and re-raised. """ # Run the preflight checks self.run_preflight_checks(drawing) + # Validate ask types before sending request + self.validate_asks(asks) + # Initiate the request init_message, init_response = await self.init_request(asks, max_pages) yield init_message @@ -717,6 +811,7 @@ async def read_drawing_with_callback( Raises: ------ + - BadRequestException: Raised when ask types are invalid. - ServerException: Raised when the server returns an error message. - InsufficientCreditsException: Raised when the user lacks sufficient credits for the request. @@ -728,6 +823,9 @@ async def read_drawing_with_callback( """ logger.debug("API method read_drawing_with_callback() called") + # Validate ask types before sending request + self.validate_asks(asks) + # send the request to the API # Set a default drawing filename if none is provided