diff --git a/tests/test_utils.py b/tests/test_utils.py index 1e3a5f27..0087d1d2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ import pytest +from werk24.utils.defaults import Settings from werk24.utils.exceptions import InvalidLicenseException from werk24.utils.license import ( License, @@ -127,3 +128,9 @@ def test_find_license_no_valid_license(mock_search_paths): ): with pytest.raises(InvalidLicenseException): find_license() + + +def test_settings_invalid_log_level(): + """Ensure invalid log levels raise an error.""" + with pytest.raises(ValueError): + Settings(log_level="BOGUS") diff --git a/werk24/utils/defaults.py b/werk24/utils/defaults.py index adc570ea..ab506ba3 100644 --- a/werk24/utils/defaults.py +++ b/werk24/utils/defaults.py @@ -1,7 +1,7 @@ from typing import Set from packaging.version import Version -from pydantic import AnyUrl, Field, HttpUrl +from pydantic import AnyUrl, Field, HttpUrl, field_validator from pydantic_settings import BaseSettings @@ -61,24 +61,15 @@ class Settings(BaseSettings): max_https_retries: int = Field(3, ge=0) """Maximum retries for HTTPS requests. Must be greater than or equal to 0.""" - @staticmethod - def validate_log_level(cls, values): - """ - Validates the `log_level` attribute. - - Ensures that the log level is one of the accepted values. - - Raises: - ------ - ValueError: - If the `log_level` is not in the accepted set. + @field_validator("log_level") + @classmethod + def validate_log_level(cls, v: str) -> str: + """Validate the ``log_level`` attribute. - Returns: - ------- - dict: - The validated values. + Ensures that the provided log level is one of the accepted values and + returns the validated value. """ valid_log_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} - if values.get("log_level") not in valid_log_levels: + if v not in valid_log_levels: raise ValueError(f"log_level must be one of {valid_log_levels}") - return values + return v