-
Notifications
You must be signed in to change notification settings - Fork 0
Add Pydantic validation model for config.json #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| """Config schema — Pydantic models for hydra/config.json validation. | ||
|
|
||
| Validates the full framework configuration at startup so misconfigurations | ||
| are caught immediately rather than at runtime when a .get() path is first hit. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Annotated, Literal, Union | ||
|
|
||
| from pydantic import BaseModel, Field, field_validator | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Sub-section models | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| class FrameworkConfig(BaseModel): | ||
| name: str = "Hydra" | ||
| version: str = "0.1.0" | ||
| debug: bool = False | ||
| log_level: str = "INFO" | ||
|
|
||
| @field_validator("log_level") | ||
| @classmethod | ||
| def validate_log_level(cls, v: str) -> str: | ||
| valid = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} | ||
| upper = v.upper() | ||
| if upper not in valid: | ||
| raise ValueError(f"log_level must be one of {valid}, got '{v}'") | ||
| return upper | ||
|
|
||
|
|
||
| class DatabaseConfig(BaseModel): | ||
| url: str = "sqlite+aiosqlite:///hydra.db" | ||
|
|
||
|
|
||
| class DashboardConfig(BaseModel): | ||
| host: str = "0.0.0.0" | ||
| port: int = Field(default=8080, ge=1, le=65535) | ||
| enabled: bool = True | ||
|
|
||
|
|
||
| class ContextBusConfig(BaseModel): | ||
| backend: str = "memory" | ||
| max_queue_size: int = Field(default=10000, ge=1) | ||
| overflow_policy: str = "drop_oldest" | ||
| default_ttl: float | None = None | ||
| redis_url: str = "redis://localhost:6379/0" | ||
|
|
||
|
|
||
| class ProfileConfig(BaseModel): | ||
| id: str | ||
| name: str | ||
| avatar_url: str = "" | ||
| banner_url: str = "" | ||
| status: str = "" | ||
| theme: str = "default" | ||
| adapters: list[str] = [] | ||
| plugins: list[str] = [] | ||
| workers: list[str] = [] | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Adapter config models (discriminated union on "type") | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| class MockAdapterConfig(BaseModel): | ||
| type: Literal["mock"] | ||
| interval: float = 15 | ||
|
|
||
|
|
||
| class DiscordAdapterConfig(BaseModel): | ||
| type: Literal["discord"] | ||
| token: str = "" | ||
| intents: int = 513 | ||
|
|
||
|
|
||
| class OpenAIAdapterConfig(BaseModel): | ||
| type: Literal["openai"] | ||
| base_url: str = "https://api.openai.com/v1" | ||
| api_key: str = "" | ||
| model: str = "gpt-4" | ||
| temperature: float = Field(default=0.7, ge=0.0, le=2.0) # OpenAI API accepted range | ||
|
|
||
|
|
||
| class MCPAdapterConfig(BaseModel): | ||
| type: Literal["mcp"] | ||
| command: str = "" | ||
| args: list[str] = [] | ||
| env: dict[str, str] = {} | ||
|
|
||
|
|
||
| AdapterConfig = Annotated[ | ||
| Union[ | ||
| MockAdapterConfig, | ||
| DiscordAdapterConfig, | ||
| OpenAIAdapterConfig, | ||
| MCPAdapterConfig, | ||
| ], | ||
| Field(discriminator="type"), | ||
| ] | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Top-level config model | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| class HydraConfig(BaseModel): | ||
| framework: FrameworkConfig = FrameworkConfig() | ||
| database: DatabaseConfig = DatabaseConfig() | ||
| dashboard: DashboardConfig = DashboardConfig() | ||
| context_bus: ContextBusConfig = ContextBusConfig() | ||
| profiles: list[ProfileConfig] = [] | ||
| adapters: dict[str, AdapterConfig] = {} | ||
|
Comment on lines
+109
to
+115
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,10 @@ | |
| if _project_root not in sys.path: | ||
| sys.path.insert(0, _project_root) | ||
|
|
||
| from pydantic import ValidationError | ||
|
|
||
| from hydra.config_schema import HydraConfig | ||
|
|
||
|
|
||
| def setup_logging(level: str = "INFO") -> None: | ||
| """Configure structured logging for the framework.""" | ||
|
|
@@ -28,17 +32,22 @@ def setup_logging(level: str = "INFO") -> None: | |
| logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING) | ||
|
|
||
|
|
||
| def load_config(config_path: str) -> dict: | ||
| """Load framework configuration from JSON file.""" | ||
| def load_config(config_path: str) -> HydraConfig: | ||
| """Load and validate framework configuration from JSON file.""" | ||
| path = Path(config_path) | ||
| if not path.exists(): | ||
| logging.error("Config file not found: %s", config_path) | ||
| sys.exit(1) | ||
| with open(path) as f: | ||
| return json.load(f) | ||
| raw = json.load(f) | ||
| try: | ||
| return HydraConfig.model_validate(raw) | ||
| except ValidationError as exc: | ||
| logging.error("Invalid configuration: %s", exc) | ||
| sys.exit(1) | ||
|
Comment on lines
41
to
+47
|
||
|
|
||
|
|
||
| async def run(config: dict, dashboard_host: str, dashboard_port: int) -> None: | ||
| async def run(config: HydraConfig, dashboard_host: str, dashboard_port: int) -> None: | ||
| """Main async entry point — boots runtime and dashboard.""" | ||
| from hydra.core.runtime import HydraRuntime | ||
| from hydra.web_dashboard.app import create_app | ||
|
|
@@ -55,7 +64,7 @@ async def run(config: dict, dashboard_host: str, dashboard_port: int) -> None: | |
|
|
||
| # Create and start dashboard | ||
| dashboard_task = None | ||
| if config.get("dashboard", {}).get("enabled", True): | ||
| if config.dashboard.enabled: | ||
| app = create_app(runtime) | ||
| import uvicorn | ||
| uvi_config = uvicorn.Config( | ||
|
|
@@ -161,13 +170,12 @@ def main() -> None: | |
| config = load_config(args.config) | ||
|
|
||
| # Setup logging | ||
| log_level = "DEBUG" if args.debug else config.get("framework", {}).get("log_level", "INFO") | ||
| log_level = "DEBUG" if args.debug else config.framework.log_level | ||
| setup_logging(log_level) | ||
|
|
||
| # Resolve dashboard host/port | ||
| dashboard_config = config.get("dashboard", {}) | ||
| host = args.host or dashboard_config.get("host", "0.0.0.0") | ||
| port = args.port or dashboard_config.get("port", 8080) | ||
| host = args.host or config.dashboard.host | ||
| port = args.port or config.dashboard.port | ||
|
|
||
| # Run | ||
| try: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FrameworkConfig.validate_log_level()hard-codes a limited set of level names. This will reject validlogginglevel aliases likeWARN/FATAL(andNOTSET), which were previously accepted bysetup_logging()and are part of Python’s logging level name mapping. Consider validating againstlogging._nameToLevel(or a curated list derived from it) and make the error message deterministic (avoid embedding aset, whose ordering is not stable).