Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Werk24Client,
get_test_drawing,
)
from werk24.utils.exceptions import InvalidLicenseException, UnauthorizedException

FILE_PATH = files("werk24") / "assets/DRAWING_SUCCESS.png"

Expand Down Expand Up @@ -71,3 +72,23 @@ async def test_read_drawing_with_callback(
drawing_bytes, [AskMetaData()], callback_url
)
assert request_id is not None


@pytest.mark.asyncio
async def test_invalid_token():
"""
Test that an empty token raises an UnauthorizedException.
"""
with pytest.raises(UnauthorizedException):
async with Werk24Client(token="", region="eu-central-1") as client:
...


@pytest.mark.asyncio
async def test_invalid_region():
"""
Test that an empty region raises an InvalidLicenseException.
"""
with pytest.raises(InvalidLicenseException):
async with Werk24Client(token="", region=None) as client:
...
8 changes: 4 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import pytest

from werk24.utils.exceptions import InvalidLicenseException
from werk24.utils.license import (
License,
LicenseInvalid,
find_license,
find_license_in_envs,
find_license_in_paths,
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_parse_license_text(valid_license):

def test_parse_license_text_invalid():
"""Test parsing invalid license text raises an exception."""
with pytest.raises(LicenseInvalid):
with pytest.raises(InvalidLicenseException):
parse_license_text(INVALID_LICENSE_TEXT)


Expand All @@ -70,7 +70,7 @@ def test_parse_license_file(valid_license, mock_search_paths):
def test_parse_license_file_invalid(mock_search_paths):
"""Test parsing an invalid license file raises an exception."""
with patch("builtins.open", mock_open(read_data=INVALID_LICENSE_TEXT)):
with pytest.raises(LicenseInvalid):
with pytest.raises(InvalidLicenseException):
parse_license_file("./mock_license.txt")


Expand Down Expand Up @@ -125,5 +125,5 @@ def test_find_license_no_valid_license(mock_search_paths):
with patch("builtins.open", side_effect=FileNotFoundError), patch.dict(
os.environ, {}, clear=True
):
with pytest.raises(LicenseInvalid):
with pytest.raises(InvalidLicenseException):
find_license()
5 changes: 3 additions & 2 deletions werk24/cli/commands/health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from werk24._version import __version__
from werk24.techread import Werk24Client
from werk24.utils.defaults import Settings
from werk24.utils.license import LicenseInvalid, find_license
from werk24.utils.exceptions import InvalidLicenseException
from werk24.utils.license import find_license

# Initialize Typer app and Rich console
app = typer.Typer()
Expand Down Expand Up @@ -62,7 +63,7 @@ def license_information():
try:
find_license()
license_status = "[green]Found[/green]"
except LicenseInvalid:
except InvalidLicenseException:
license_status = (
"[red]Not Found[/red] - Run [bold]werk24 init[/bold] to configure."
)
Expand Down
6 changes: 3 additions & 3 deletions werk24/cli/commands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from rich.text import Text

from werk24.utils.defaults import Settings
from werk24.utils.exceptions import InvalidLicenseException
from werk24.utils.license import (
LicenseInvalid,
find_license,
parse_license_text,
save_license_file,
Expand All @@ -29,7 +29,7 @@ def init():
)
)
return
except LicenseInvalid:
except InvalidLicenseException:
pass # Continue to ask user to create a license file
ask_user_to_create_license()

Expand Down Expand Up @@ -83,7 +83,7 @@ def accept_license_from_terminal():
license = parse_license_text(license_text)
save_license_file(license)
console.print(Panel("[bold green]License successfully saved![/bold green]"))
except LicenseInvalid:
except InvalidLicenseException:
console.print("[red]Invalid license text. Please try again.[/red]")
accept_license_from_terminal() # Retry on failure

Expand Down
19 changes: 19 additions & 0 deletions werk24/techread.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,25 @@ def _create_websocket_session(
async def __aenter__(self):
try:
self._wss_session = await self._create_websocket_session()

# -----------------------------------------------------------
# Handle the error codes
# -----------------------------------------------------------
except websockets.exceptions.InvalidStatus as exc:
match exc.response.status_code:
case 403:
raise UnauthorizedException(
"Invalid status when connecting to the server"
) from exc

case _:
raise ServerException(
f"Invalid status when connecting to the server: {exc.response.status_code}"
) from exc

# -----------------------------------------------------------
# Handle remaining exceptions
# -----------------------------------------------------------
except Exception as exc:
logger.error("Failed to establish a connection with the server: %s", exc)
raise ServerException(details=str(exc)) from exc
Expand Down
10 changes: 10 additions & 0 deletions werk24/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,13 @@ class UserInputError(TechreadException):
cli_message_body: str = (
"The input provided is invalid. Please verify your input and try again."
)


class InvalidLicenseException(TechreadException):
"""Exception raised when the provided license is invalid."""

cli_message_header: str = "Invalid License"
cli_message_body: str = (
"The provided license is invalid or has expired.\n\n"
"Please ensure that you provide a token AND a region."
)
58 changes: 32 additions & 26 deletions werk24/utils/license.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import dotenv
from pydantic import BaseModel

from werk24.utils.exceptions import InvalidLicenseException

from .logger import get_logger

# Define constants
Expand All @@ -25,11 +27,6 @@ class License(BaseModel):
region: str


# Custom exception for invalid licenses
class LicenseInvalid(Exception):
pass


def find_license(token: Optional[str] = None, region: Optional[str] = None) -> License:
"""
Find a valid license by searching predefined paths or environment variables.
Expand All @@ -44,24 +41,33 @@ def find_license(token: Optional[str] = None, region: Optional[str] = None) -> L

Raises:
------
- LicenseInvalid: If no valid license is found.
- InvalidLicenseException: If no valid license is found.
"""
if token is not None and region is not None:
logger.info("Using provided license token.")
return License(token=token, region=region)

# -----------------------------------------------------------
# Check if token and region are provided
# -----------------------------------------------------------
if token is not None:
try:
return License(token=token, region=region)
except ValueError as e:
raise InvalidLicenseException(
"The license requires a token and a region"
) from e

# -----------------------------------------------------------
# If not provided, search for a valid license
# -----------------------------------------------------------
logger.info("Searching for a valid license...")
license = find_license_in_paths() or find_license_in_envs()
if not license:
logger.error("No valid license found.")
raise LicenseInvalid("No valid license could be found.")
if license:
return license

# Overwriting the region if provided
logger.info("Valid license found.")
if region is not None:
logger.warning("Overwriting region with provided value.")
license.region = region
return license
# -----------------------------------------------------------
# If no valid license is found, raise an exception
# -----------------------------------------------------------
logger.error("No valid license found.")
raise InvalidLicenseException("No valid license could be found.")


def find_license_in_paths() -> Optional[License]:
Expand All @@ -79,7 +85,7 @@ def find_license_in_paths() -> Optional[License]:
if os.path.exists(expanded_path):
try:
return parse_license_file(expanded_path)
except LicenseInvalid:
except InvalidLicenseException:
logger.debug(f"Invalid license at {expanded_path}")
else:
logger.debug(f"No license file found at {expanded_path}")
Expand Down Expand Up @@ -119,7 +125,7 @@ def parse_license_file(path: str) -> License:

Raises:
------
- LicenseInvalid: If the license file is invalid or cannot be read.
- InvalidLicenseException: If the license file is invalid or cannot be read.
"""
logger.debug(f"Attempting to parse license file at {path}")
try:
Expand All @@ -128,10 +134,10 @@ def parse_license_file(path: str) -> License:
return parse_license_text(content)
except FileNotFoundError as e:
logger.error(f"License file not found at {path}")
raise LicenseInvalid("License file not found.") from e
raise InvalidLicenseException("License file not found.") from e
except Exception as e:
logger.error(f"Error parsing license file at {path}: {e}")
raise LicenseInvalid("Invalid license file.") from e
raise InvalidLicenseException("Invalid license file.") from e


def parse_license_text(text: str) -> License:
Expand All @@ -148,7 +154,7 @@ def parse_license_text(text: str) -> License:

Raises:
------
- LicenseInvalid: If the license text is invalid.
- InvalidLicenseException: If the license text is invalid.
"""
logger.debug("Parsing license text...")
try:
Expand All @@ -159,9 +165,9 @@ def parse_license_text(text: str) -> License:
region = vars["W24TECHREAD_AUTH_REGION"]
logger.debug("License text parsed successfully.")
return License(token=token, region=region)
except KeyError as e:
except (ValueError, KeyError) as e:
logger.error(f"Missing key in license text: {e}")
raise LicenseInvalid("Invalid license text format.") from e
raise InvalidLicenseException("Invalid license text format.") from e


def save_license_file(license: License):
Expand All @@ -180,4 +186,4 @@ def save_license_file(license: License):
logger.info(f"License saved successfully at {license_path}")
except Exception as e:
logger.error(f"Error saving license file: {e}")
raise LicenseInvalid("Could not save the license file.") from e
raise InvalidLicenseException("Could not save the license file.") from e
Loading