diff --git a/api/src/app/schemas/host_schemas.py b/api/src/app/schemas/host_schemas.py index 2fb218b5..888a6721 100644 --- a/api/src/app/schemas/host_schemas.py +++ b/api/src/app/schemas/host_schemas.py @@ -1,11 +1,12 @@ from ipaddress import IPv4Address +from typing import Self from pydantic import ( BaseModel, ConfigDict, Field, - ValidationInfo, field_validator, + model_validator, ) from src.app.enums.operating_systems import OpenLabsOS @@ -82,20 +83,14 @@ def validate_hostname(cls, hostname: str) -> str: raise ValueError(msg) return hostname - @field_validator("size") - @classmethod - def validate_size(cls, size: int, info: ValidationInfo) -> int: + @model_validator(mode="after") + def validate_size(self) -> Self: """Check VM disk size is sufficient.""" - os: OpenLabsOS | None = info.data.get("os") - - if os is None: - msg = "OS field not set to OpenLabsOS type." + if not is_valid_disk_size(self.os, self.size): + msg = f"Disk size {self.size}GB too small for OS: {self.os.value}. Minimum disk size: {OS_SIZE_THRESHOLD[self.os]}GB." raise ValueError(msg) - if not is_valid_disk_size(os, size): - msg = f"Disk size {size}GB too small for OS: {os.value}. Minimum disk size: {OS_SIZE_THRESHOLD[os]}GB." - raise ValueError(msg) - return size + return self # ==================== Blueprints ===================== diff --git a/api/src/app/schemas/range_schemas.py b/api/src/app/schemas/range_schemas.py index c6172072..e00e94e5 100644 --- a/api/src/app/schemas/range_schemas.py +++ b/api/src/app/schemas/range_schemas.py @@ -1,8 +1,13 @@ from datetime import datetime, timezone from ipaddress import IPv4Address -from typing import Any +from typing import Any, Self, Sequence -from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + model_validator, +) from ..enums.providers import OpenLabsProvider from ..enums.range_states import RangeState @@ -14,6 +19,7 @@ BlueprintVPCSchema, DeployedVPCCreateSchema, DeployedVPCSchema, + VPCCommonSchema, ) @@ -32,6 +38,36 @@ class RangeCommonSchema(BaseModel): vpn: bool = Field(default=False, description="Automatic VPN configuration.") +class RangeCreateValidationMixin(BaseModel): + """Mixin class with common validation for all range creation schemas.""" + + vpcs: Sequence[VPCCommonSchema] + + @model_validator(mode="after") + def validate_unique_vpc_names(self) -> Self: + """Check VPC names are unique.""" + if not self.vpcs: + return self + + vpc_names = [vpc.name for vpc in self.vpcs] + if len(vpc_names) != len(set(vpc_names)): + msg = "All VPCs in the range must have unique names." + raise ValueError(msg) + return self + + @model_validator(mode="after") + def validate_mutually_exclusive_vpcs(self) -> Self: + """Check that VPCs do not overlap.""" + if not self.vpcs: + return self + + vpc_cidrs = [vpc.cidr for vpc in self.vpcs] + if not mutually_exclusive_networks_v4(vpc_cidrs): + msg = "All VPCs in the range must be mutually exclusive (not overlap)." + raise ValueError(msg) + return self + + # ==================== Blueprints ===================== @@ -48,42 +84,13 @@ class BlueprintRangeBaseSchema(RangeCommonSchema): pass -class BlueprintRangeCreateSchema(BlueprintRangeBaseSchema): +class BlueprintRangeCreateSchema(BlueprintRangeBaseSchema, RangeCreateValidationMixin): """Schema to create blueprint range objects.""" vpcs: list[BlueprintVPCCreateSchema] = Field( ..., description="All blueprint VPCs in range." ) - @field_validator("vpcs") - @classmethod - def validate_unique_vpc_names( - cls, vpcs: list[BlueprintVPCCreateSchema], info: ValidationInfo - ) -> list[BlueprintVPCCreateSchema]: - """Check VPC names are unique.""" - vpc_names = [vpc.name for vpc in vpcs] - - if len(vpc_names) != len(set(vpc_names)): - msg = "All VPCs in the range must have unique names." - raise ValueError(msg) - - return vpcs - - @field_validator("vpcs") - @classmethod - def validate_mutually_exclusive_vpcs( - cls, vpcs: list[BlueprintVPCCreateSchema], info: ValidationInfo - ) -> list[BlueprintVPCCreateSchema]: - """Check that VPCs do not overlap.""" - vpc_cidrs = [vpc.cidr for vpc in vpcs] - - if not mutually_exclusive_networks_v4(vpc_cidrs): - - msg = "All VPCs in range should be mutually exclusive (not overlap)." - raise ValueError(msg) - - return vpcs - class BlueprintRangeSchema(BlueprintRangeBaseSchema): """Blueprint range object.""" @@ -152,42 +159,13 @@ class DeployedRangeBaseSchema(RangeCommonSchema): ) -class DeployedRangeCreateSchema(DeployedRangeBaseSchema): +class DeployedRangeCreateSchema(DeployedRangeBaseSchema, RangeCreateValidationMixin): """Schema to create deployed range object.""" vpcs: list[DeployedVPCCreateSchema] = Field( ..., description="Deployed VPCs in the range." ) - @field_validator("vpcs") - @classmethod - def validate_unique_vpc_names( - cls, vpcs: list[DeployedVPCCreateSchema], info: ValidationInfo - ) -> list[DeployedVPCCreateSchema]: - """Check VPC names are unique.""" - vpc_names = [vpc.name for vpc in vpcs] - - if len(vpc_names) != len(set(vpc_names)): - msg = "All VPCs in the range must have unique names." - raise ValueError(msg) - - return vpcs - - @field_validator("vpcs") - @classmethod - def validate_mutually_exclusive_vpcs( - cls, vpcs: list[DeployedVPCCreateSchema], info: ValidationInfo - ) -> list[DeployedVPCCreateSchema]: - """Check that VPCs do not overlap.""" - vpc_cidrs = [vpc.cidr for vpc in vpcs] - - if not mutually_exclusive_networks_v4(vpc_cidrs): - - msg = "All VPCs in range should be mutually exclusive (not overlap)." - raise ValueError(msg) - - return vpcs - class DeployedRangeSchema(DeployedRangeBaseSchema): """Deployed range object.""" diff --git a/api/src/app/schemas/subnet_schemas.py b/api/src/app/schemas/subnet_schemas.py index db5b5366..0244f51f 100644 --- a/api/src/app/schemas/subnet_schemas.py +++ b/api/src/app/schemas/subnet_schemas.py @@ -1,6 +1,13 @@ from ipaddress import IPv4Network - -from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator +from typing import Self, Sequence + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, +) from ..validators.names import OPENLABS_NAME_REGEX from ..validators.network import max_num_hosts_in_subnet @@ -9,6 +16,7 @@ BlueprintHostSchema, DeployedHostCreateSchema, DeployedHostSchema, + HostCommonSchema, ) @@ -27,21 +35,13 @@ class SubnetCommonSchema(BaseModel): ) -# ==================== Blueprints ===================== - - -class BlueprintSubnetBaseSchema(SubnetCommonSchema): - """Base pydantic class for all blueprint subnet objects.""" +class SubnetCreateValidationMixin(BaseModel): + """Mixin class with common validation for all subnet creation schemas.""" - pass - - -class BlueprintSubnetCreateSchema(BlueprintSubnetBaseSchema): - """Schema to create blueprint subnet objects.""" - - hosts: list[BlueprintHostCreateSchema] = Field( - ..., description="All blueprint hosts in the subnet." - ) + # Forward references + name: str + cidr: IPv4Network + hosts: Sequence[HostCommonSchema] @field_validator("cidr") @classmethod @@ -52,38 +52,48 @@ def validate_subnet_private_cidr_range(cls, cidr: IPv4Network) -> IPv4Network: raise ValueError(msg) return cidr - @field_validator("hosts") - @classmethod - def validate_unique_hostnames( - cls, hosts: list[BlueprintHostCreateSchema] - ) -> list[BlueprintHostCreateSchema]: + @model_validator(mode="after") + def validate_unique_hostnames(self) -> Self: """Check hostnames are unique.""" - hostnames = [host.hostname for host in hosts] + if not self.hosts: + return self + + hostnames = [host.hostname for host in self.hosts] if len(hostnames) != len(set(hostnames)): - msg = "All hostnames must be unique." + msg = f"All hostnames in subnet: {self.name} must be unique." raise ValueError(msg) - return hosts - @field_validator("hosts") - @classmethod - def validate_max_number_hosts( - cls, hosts: list[BlueprintHostCreateSchema], info: ValidationInfo - ) -> list[BlueprintHostCreateSchema]: + return self + + @model_validator(mode="after") + def validate_max_number_hosts(self) -> Self: """Check that the number of hosts does not exceed subnet CIDR.""" - subnet_cidr = info.data.get("cidr") + max_num_hosts = max_num_hosts_in_subnet(self.cidr) - if not subnet_cidr: - msg = "Subnet missing CIDR." + if len(self.hosts) > max_num_hosts: + msg = f"Too many hosts in subnet: {self.name}! Max: {max_num_hosts}, Requested: {len(self.hosts)}" raise ValueError(msg) - max_num_hosts = max_num_hosts_in_subnet(subnet_cidr) - num_requested_hosts = len(hosts) + return self - if num_requested_hosts > max_num_hosts: - msg = f"Too many hosts in subnet! Max: {max_num_hosts}, Requested: {num_requested_hosts}" - raise ValueError(msg) - return hosts +# ==================== Blueprints ===================== + + +class BlueprintSubnetBaseSchema(SubnetCommonSchema): + """Base pydantic class for all blueprint subnet objects.""" + + pass + + +class BlueprintSubnetCreateSchema( + BlueprintSubnetBaseSchema, SubnetCreateValidationMixin +): + """Schema to create blueprint subnet objects.""" + + hosts: list[BlueprintHostCreateSchema] = Field( + ..., description="All blueprint hosts in the subnet." + ) class BlueprintSubnetSchema(BlueprintSubnetBaseSchema): @@ -119,55 +129,13 @@ class DeployedSubnetBaseSchema(SubnetCommonSchema): ) -class DeployedSubnetCreateSchema(DeployedSubnetBaseSchema): +class DeployedSubnetCreateSchema(DeployedSubnetBaseSchema, SubnetCreateValidationMixin): """Schema to create deployed subnet objects.""" hosts: list[DeployedHostCreateSchema] = Field( ..., description="Deployed hosts within subnet." ) - @field_validator("cidr") - @classmethod - def validate_subnet_private_cidr_range(cls, cidr: IPv4Network) -> IPv4Network: - """Check subnet CIDR ranges are private.""" - if not cidr.is_private: - msg = "Subnets should only use private CIDR ranges." - raise ValueError(msg) - return cidr - - @field_validator("hosts") - @classmethod - def validate_unique_hostnames( - cls, hosts: list[DeployedHostCreateSchema] - ) -> list[DeployedHostCreateSchema]: - """Check hostnames are unique.""" - hostnames = [host.hostname for host in hosts] - if len(hostnames) != len(set(hostnames)): - msg = "All hostnames must be unique." - raise ValueError(msg) - return hosts - - @field_validator("hosts") - @classmethod - def validate_max_number_hosts( - cls, hosts: list[DeployedHostCreateSchema], info: ValidationInfo - ) -> list[DeployedHostCreateSchema]: - """Check that the number of hosts does not exceed subnet CIDR.""" - subnet_cidr = info.data.get("cidr") - - if not subnet_cidr: - msg = "Subnet missing CIDR." - raise ValueError(msg) - - max_num_hosts = max_num_hosts_in_subnet(subnet_cidr) - num_requested_hosts = len(hosts) - - if num_requested_hosts > max_num_hosts: - msg = f"Too many hosts in subnet! Max: {max_num_hosts}, Requested: {num_requested_hosts}" - raise ValueError(msg) - - return hosts - model_config = ConfigDict(from_attributes=True) diff --git a/api/src/app/schemas/vpc_schemas.py b/api/src/app/schemas/vpc_schemas.py index caa2530b..d1cf080d 100644 --- a/api/src/app/schemas/vpc_schemas.py +++ b/api/src/app/schemas/vpc_schemas.py @@ -1,6 +1,13 @@ from ipaddress import IPv4Network - -from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator +from typing import Self, Sequence + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, +) from ..validators.names import OPENLABS_NAME_REGEX from ..validators.network import all_subnets_contained, mutually_exclusive_networks_v4 @@ -9,6 +16,7 @@ BlueprintSubnetSchema, DeployedSubnetCreateSchema, DeployedSubnetSchema, + SubnetCommonSchema, ) @@ -26,21 +34,13 @@ class VPCCommonSchema(BaseModel): ) -# ==================== Blueprints ===================== - - -class BlueprintVPCBaseSchema(VPCCommonSchema): - """Base pydantic class for all blueprint VPC objects.""" - - pass - +class VPCCreateValidationMixin(BaseModel): + """Mixin class with common validation for all VPC creation schemas.""" -class BlueprintVPCCreateSchema(BlueprintVPCBaseSchema): - """Schema to create blueprint VPC objects.""" - - subnets: list[BlueprintSubnetCreateSchema] = Field( - ..., description="All blueprint subnets in VPC." - ) + # Forward references + name: str + cidr: IPv4Network + subnets: Sequence[SubnetCommonSchema] @field_validator("cidr") @classmethod @@ -51,68 +51,55 @@ def validate_vpc_private_cidr_range(cls, cidr: IPv4Network) -> IPv4Network: raise ValueError(msg) return cidr - @field_validator("subnets") - @classmethod - def validate_unique_subnet_names( - cls, subnets: list[BlueprintSubnetCreateSchema], info: ValidationInfo - ) -> list[BlueprintSubnetCreateSchema]: + @model_validator(mode="after") + def validate_unique_subnet_names(self) -> Self: """Check subnet names are unique.""" - subnet_names = [subnet.name for subnet in subnets] + if not self.subnets: + return self + subnet_names = [subnet.name for subnet in self.subnets] if len(subnet_names) != len(set(subnet_names)): - vpc_name = info.data.get("name") - if not vpc_name: - msg = "VPC is missing a name." - raise ValueError(msg) - - msg = f"All subnet in VPC: {vpc_name} must have unique names." + msg = f"All subnet in VPC: {self.name} must have unique names." raise ValueError(msg) - return subnets + return self - @field_validator("subnets") - @classmethod - def validate_subnets_contained( - cls, subnets: list[BlueprintSubnetCreateSchema], info: ValidationInfo - ) -> list[BlueprintSubnetCreateSchema]: + @model_validator(mode="after") + def validate_subnets_contained(self) -> Self: """Check that the VPC CIDR contains all subnet CIDRs.""" - vpc_name = info.data.get("name") - if not vpc_name: - msg = "VPC is missing a name." + subnet_cidrs = [subnet.cidr for subnet in self.subnets] + if not all_subnets_contained(self.cidr, subnet_cidrs): + msg = f"All subnets in VPC: {self.name} should be contained within: {self.cidr}" raise ValueError(msg) - vpc_cidr = info.data.get("cidr") - if not vpc_cidr: - msg = f"VPC: {vpc_name} missing CIDR." - raise ValueError(msg) + return self - subnet_cidrs = [subnet.cidr for subnet in subnets] - if not all_subnets_contained(vpc_cidr, subnet_cidrs): - msg = ( - f"All subnets in VPC: {vpc_name} should be contained within: {vpc_cidr}" - ) + @model_validator(mode="after") + def validate_mutually_exclusive_subnets(self) -> Self: + """Check that subnets do not overlap.""" + subnet_cidrs = [subnet.cidr for subnet in self.subnets] + if not mutually_exclusive_networks_v4(subnet_cidrs): + msg = f"All subnets in VPC: {self.name} should be mutually exclusive (not overlap)." raise ValueError(msg) - return subnets + return self - @field_validator("subnets") - @classmethod - def validate_mutually_exclusive_subnets( - cls, subnets: list[BlueprintSubnetCreateSchema], info: ValidationInfo - ) -> list[BlueprintSubnetCreateSchema]: - """Check that subnets do not overlap.""" - subnet_cidrs = [subnet.cidr for subnet in subnets] - if not mutually_exclusive_networks_v4(subnet_cidrs): - vpc_name = info.data.get("name") - if not vpc_name: - msg = "VPC is missing a name." - raise ValueError(msg) +# ==================== Blueprints ===================== - msg = f"All subnets in VPC: {vpc_name} should be mutually exclusive (not overlap)." - raise ValueError(msg) - return subnets +class BlueprintVPCBaseSchema(VPCCommonSchema): + """Base pydantic class for all blueprint VPC objects.""" + + pass + + +class BlueprintVPCCreateSchema(BlueprintVPCBaseSchema, VPCCreateValidationMixin): + """Schema to create blueprint VPC objects.""" + + subnets: list[BlueprintSubnetCreateSchema] = Field( + ..., description="All blueprint subnets in VPC." + ) class BlueprintVPCSchema(BlueprintVPCBaseSchema): @@ -148,85 +135,13 @@ class DeployedVPCBaseSchema(VPCCommonSchema): ) -class DeployedVPCCreateSchema(DeployedVPCBaseSchema): +class DeployedVPCCreateSchema(DeployedVPCBaseSchema, VPCCreateValidationMixin): """Schema to create deployed VPC objects.""" subnets: list[DeployedSubnetCreateSchema] = Field( ..., description="Deployed subnets within VPC." ) - @field_validator("cidr") - @classmethod - def validate_vpc_private_cidr_range(cls, cidr: IPv4Network) -> IPv4Network: - """Check VPC CIDR ranges are private.""" - if not cidr.is_private: - msg = "VPCs should only use private CIDR ranges." - raise ValueError(msg) - return cidr - - @field_validator("subnets") - @classmethod - def validate_unique_subnet_names( - cls, subnets: list[DeployedSubnetCreateSchema], info: ValidationInfo - ) -> list[DeployedSubnetCreateSchema]: - """Check subnet names are unique.""" - subnet_names = [subnet.name for subnet in subnets] - - if len(subnet_names) != len(set(subnet_names)): - vpc_name = info.data.get("name") - if not vpc_name: - msg = "VPC is missing a name." - raise ValueError(msg) - - msg = f"All subnet in VPC: {vpc_name} must have unique names." - raise ValueError(msg) - - return subnets - - @field_validator("subnets") - @classmethod - def validate_subnets_contained( - cls, subnets: list[DeployedSubnetCreateSchema], info: ValidationInfo - ) -> list[DeployedSubnetCreateSchema]: - """Check that the VPC CIDR contains all subnet CIDRs.""" - vpc_name = info.data.get("name") - if not vpc_name: - msg = "VPC is missing a name." - raise ValueError(msg) - - vpc_cidr = info.data.get("cidr") - if not vpc_cidr: - msg = f"VPC: {vpc_name} missing CIDR." - raise ValueError(msg) - - subnet_cidrs = [subnet.cidr for subnet in subnets] - if not all_subnets_contained(vpc_cidr, subnet_cidrs): - msg = ( - f"All subnets in VPC: {vpc_name} should be contained within: {vpc_cidr}" - ) - raise ValueError(msg) - - return subnets - - @field_validator("subnets") - @classmethod - def validate_mutually_exclusive_subnets( - cls, subnets: list[DeployedSubnetCreateSchema], info: ValidationInfo - ) -> list[DeployedSubnetCreateSchema]: - """Check that subnets do not overlap.""" - subnet_cidrs = [subnet.cidr for subnet in subnets] - - if not mutually_exclusive_networks_v4(subnet_cidrs): - vpc_name = info.data.get("name") - if not vpc_name: - msg = "VPC is missing a name." - raise ValueError(msg) - - msg = f"All subnets in VPC: {vpc_name} should be mutually exclusive (not overlap)." - raise ValueError(msg) - - return subnets - class DeployedVPCSchema(DeployedVPCBaseSchema): """Deployed VPC object."""