diff --git a/.gitignore b/.gitignore index 1bfea5db..e522da89 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ testing-out/ # Misc .vscode/ +# Tests +.coverage \ No newline at end of file diff --git a/api/pyproject.toml b/api/pyproject.toml index 6d6c73e9..5b329793 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -76,7 +76,8 @@ addopts = [ asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "session" filterwarnings = [ - "ignore:cannot collect test class 'Testing':pytest.PytestCollectionWarning" + "ignore:cannot collect test class 'Testing':pytest.PytestCollectionWarning", + "ignore::DeprecationWarning:botocore.auth", ] # Register markers to easily select unit or integration tests. diff --git a/api/requirements.txt b/api/requirements.txt index 3d8b83b0..f62fe2e7 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -21,4 +21,7 @@ argon2-cffi>=23.1.0 # Async Jobs arq>=0.26.3 aiofiles>=24.1.0 -tenacity>=9.1.2 \ No newline at end of file +tenacity>=9.1.2 + +# Provider SDKs +boto3>=1.38.0 \ No newline at end of file diff --git a/api/src/app/api/v1/users.py b/api/src/app/api/v1/users.py index f86d5c34..222b4821 100644 --- a/api/src/app/api/v1/users.py +++ b/api/src/app/api/v1/users.py @@ -1,25 +1,25 @@ import base64 -from datetime import UTC, datetime from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import JSONResponse from sqlalchemy import select from sqlalchemy.ext.asyncio.session import AsyncSession +from src.app.cloud.creds_factory import CredsFactory + from ...core.auth.auth import get_current_user from ...core.config import settings from ...core.db.database import async_get_db +from ...crud.crud_secrets import get_user_secrets, upsert_user_secrets from ...crud.crud_users import get_user_by_id, update_user_password from ...models.secret_model import SecretModel from ...models.user_model import UserModel from ...schemas.message_schema import ( - AWSUpdateSecretMessageSchema, - AzureUpdateSecretMessageSchema, + MessageSchema, UpdatePasswordMessageSchema, ) from ...schemas.secret_schema import ( - AWSSecrets, - AzureSecrets, + AnySecrets, CloudSecretStatusSchema, UserSecretResponseSchema, ) @@ -116,7 +116,7 @@ async def update_password( @router.get("/me/secrets") -async def get_user_secrets( +async def fetch_user_secrets( current_user: UserModel = Depends(get_current_user), # noqa: B008 db: AsyncSession = Depends(async_get_db), # noqa: B008 ) -> UserSecretResponseSchema: @@ -170,115 +170,67 @@ async def get_user_secrets( ) -@router.post("/me/secrets/aws") -async def update_aws_secrets( - aws_secrets: AWSSecrets, +@router.post("/me/secrets") +async def update_user_secrets( + creds: AnySecrets, current_user: UserModel = Depends(get_current_user), # noqa: B008 db: AsyncSession = Depends(async_get_db), # noqa: B008 -) -> AWSUpdateSecretMessageSchema: - """Update the current user's AWS secrets. +) -> MessageSchema: + """Update the current user's secrets. Args: ---- - aws_secrets (AWSSecrets): The AWS credentials to store. + creds (AnySecrets): The provider credentials to store. current_user (UserModel): The authenticated user. db (Session): Database connection. Returns: ------- - AWSUpdateSecretMessageSchema: Status message. + MessageSchema: Status message of updating user secrets. """ - # Fetch secrets explicitly from the database - stmt = select(SecretModel).where(SecretModel.user_id == current_user.id) - result = await db.execute(stmt) - secrets = result.scalars().first() + # Fetch secrets from the database + secrets = await get_user_secrets(db, current_user.id) if not secrets: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="User secrets record not found", - ) - - # Encrypt the AWS credentials using the user's public key - if not current_user.public_key: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="User encryption keys not set up. Please register a new account.", + detail="User secrets record not found!", ) - # Convert to dictionary for encryption - aws_data = { - "aws_access_key": aws_secrets.aws_access_key, - "aws_secret_key": aws_secrets.aws_secret_key, - } - - # Encrypt with the user's public key - encrypted_data = encrypt_with_public_key(aws_data, current_user.public_key) - - # Update the secrets with encrypted values - secrets.aws_access_key = encrypted_data["aws_access_key"] - secrets.aws_secret_key = encrypted_data["aws_secret_key"] - secrets.aws_created_at = datetime.now(UTC) - await db.commit() - - return AWSUpdateSecretMessageSchema(message="AWS credentials updated successfully") - - -@router.post("/me/secrets/azure") -async def update_azure_secrets( - azure_secrets: AzureSecrets, - current_user: UserModel = Depends(get_current_user), # noqa: B008 - db: AsyncSession = Depends(async_get_db), # noqa: B008 -) -> AzureUpdateSecretMessageSchema: - """Update the current user's Azure secrets. - - Args: - ---- - azure_secrets (AzureSecrets): The Azure credentials to store. - current_user (UserModel): The authenticated user. - db (Session): Database connection. + # Verify credentials are valid before storing + creds_obj = CredsFactory.create_creds_verification(credentials=creds) - Returns: - ------- - AzureUpdateSecretMessageSchema: Success message. + verified, msg = creds_obj.verify_creds() - """ - # Fetch secrets explicitly from the database - stmt = select(SecretModel).where(SecretModel.user_id == current_user.id) - result = await db.execute(stmt) - secrets = result.scalars().first() - if not secrets: + if not verified: raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="User secrets record not found", + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=msg.message, ) - # Encrypt the Azure credentials using the user's public key + # Encrypt the credentials using the user's public key if not current_user.public_key: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="User encryption keys not set up. Please register a new account.", ) - # Convert to dictionary for encryption - azure_data = { - "azure_client_id": azure_secrets.azure_client_id, - "azure_client_secret": azure_secrets.azure_client_secret, - "azure_tenant_id": azure_secrets.azure_tenant_id, - "azure_subscription_id": azure_secrets.azure_subscription_id, - } + # Convert provided credentials to dictionary for encryption + user_creds = creds_obj.get_user_creds() # Encrypt with the user's public key - encrypted_data = encrypt_with_public_key(azure_data, current_user.public_key) + encrypted_data = encrypt_with_public_key( + data=user_creds, public_key_b64=current_user.public_key + ) # Update the secrets with encrypted values - secrets.azure_client_id = encrypted_data["azure_client_id"] - secrets.azure_client_secret = encrypted_data["azure_client_secret"] - secrets.azure_tenant_id = encrypted_data["azure_tenant_id"] - secrets.azure_subscription_id = encrypted_data["azure_subscription_id"] - secrets.azure_created_at = datetime.now(UTC) - await db.commit() - - return AzureUpdateSecretMessageSchema( - message="Azure credentials updated successfully" + secrets = creds_obj.update_secret_schema( + secrets=secrets, encrypted_data=encrypted_data + ) + + # Add new secrets to database + await upsert_user_secrets(db, secrets, current_user.id) + + return MessageSchema( + message=f"{creds.provider.value.upper()} credentials successfully verified and updated." ) diff --git a/api/src/app/cloud/__init__.py b/api/src/app/cloud/__init__.py new file mode 100644 index 00000000..7036f4db --- /dev/null +++ b/api/src/app/cloud/__init__.py @@ -0,0 +1 @@ +"""Cloud controls for the OpenLabs API.""" diff --git a/api/src/app/cloud/aws_creds.py b/api/src/app/cloud/aws_creds.py new file mode 100644 index 00000000..38d6c8e6 --- /dev/null +++ b/api/src/app/cloud/aws_creds.py @@ -0,0 +1,163 @@ +import logging +from datetime import UTC, datetime +from typing import List, Tuple + +import boto3 +from botocore.exceptions import ClientError + +from src.app.schemas.message_schema import MessageSchema +from src.app.schemas.secret_schema import AWSSecrets, SecretSchema + +from .base_creds import AbstractBaseCreds + +# Configure logging +logger = logging.getLogger(__name__) + + +class AWSCreds(AbstractBaseCreds): + """Credential verification for AWS.""" + + credentials: AWSSecrets + + def __init__(self, credentials: AWSSecrets) -> None: + """Initialize AWS credentials verification object.""" + self.credentials = credentials + + def get_user_creds(self) -> dict[str, str]: + """Convert user AWS secrets to dictionary for encryption.""" + return { + "aws_access_key": self.credentials.aws_access_key, + "aws_secret_key": self.credentials.aws_secret_key, + } + + def update_secret_schema( + self, secrets: SecretSchema, encrypted_data: dict[str, str] + ) -> SecretSchema: + """Update user secrets schema with newly encrypted secrets.""" + secrets.aws_access_key = encrypted_data["aws_access_key"] + secrets.aws_secret_key = encrypted_data["aws_secret_key"] + secrets.aws_created_at = datetime.now(UTC) + return secrets + + def verify_creds(self) -> Tuple[bool, MessageSchema]: + """Verify credentials authenticate to an AWS account.""" + try: + # --- Step 1: Basic Authentication with STS --- + # Created shared session for authentication and IAM permission check + session = boto3.Session( + aws_access_key_id=self.credentials.aws_access_key, + aws_secret_access_key=self.credentials.aws_secret_key, + ) + client = session.client("sts") + caller_identity = ( + client.get_caller_identity() + ) # will raise an error if not valid + caller_arn = caller_identity["Arn"] + logger.info( + "AWS credentials successfully authenticated for ARN: %s", caller_arn + ) + + if caller_arn.endswith( + ":root" + ): # If root access key credentials are used, skip permissions check as root user has all permissions + return ( + True, + MessageSchema( + message="AWS credentials authenticated and all required permissions are present." + ), + ) + + # --- Step 2: Simulate permissions for a sample of minimum critical actions --- + iam_client = session.client("iam") + + actions_to_test = [ + # For Instance + "ec2:RunInstances", + "ec2:TerminateInstances", + "ec2:DescribeInstances", + # For Vpc + "ec2:CreateVpc", + "ec2:DeleteVpc", + "ec2:DescribeVpcs", + # For Subnet + "ec2:CreateSubnet", + "ec2:DeleteSubnet", + "ec2:DescribeSubnets", + # For InternetGateway + "ec2:CreateInternetGateway", + "ec2:DeleteInternetGateway", + "ec2:AttachInternetGateway", + "ec2:DetachInternetGateway", + # For Eip and NatGateway + "ec2:AllocateAddress", # Create EIP + "ec2:ReleaseAddress", # Delete EIP + "ec2:AssociateAddress", + "ec2:CreateNatGateway", + "ec2:DeleteNatGateway", + # For KeyPair + "ec2:CreateKeyPair", + "ec2:DeleteKeyPair", + # For SecurityGroup and SecurityGroupRule + "ec2:CreateSecurityGroup", + "ec2:DeleteSecurityGroup", + "ec2:AuthorizeSecurityGroupIngress", + "ec2:RevokeSecurityGroupIngress", + "ec2:AuthorizeSecurityGroupEgress", + "ec2:RevokeSecurityGroupEgress", + # For RouteTable, Route, and RouteTableAssociation + "ec2:CreateRouteTable", + "ec2:DeleteRouteTable", + "ec2:CreateRoute", + "ec2:DeleteRoute", + "ec2:AssociateRouteTable", + "ec2:DisassociateRouteTable", + # For Transit Gateway --- + "ec2:CreateTransitGateway", + "ec2:DeleteTransitGateway", + "ec2:CreateTransitGatewayVpcAttachment", + "ec2:DeleteTransitGatewayVpcAttachment", + "ec2:CreateTransitGatewayRoute", + "ec2:DeleteTransitGatewayRoute", + "ec2:DescribeTransitGateways", + "ec2:DescribeTransitGatewayVpcAttachments", + ] + + simulation_results = iam_client.simulate_principal_policy( + PolicySourceArn=caller_arn, ActionNames=actions_to_test + ) + + # --- Step 3: Evaluate the simulation results --- + denied_actions: List[str] = [] + for result in simulation_results["EvaluationResults"]: + if result["EvalDecision"] != "allowed": + denied_actions.append(result["EvalActionName"]) + + if denied_actions: + error_message = f"Authentication succeeded, but the user/group is missing required permissions. The following actions were denied: {', '.join(denied_actions)}" + logger.error(error_message) + return ( + False, + MessageSchema( + message=f"Insufficient permissions for your AWS account user/group. Please ensure the following permissions are added: {', '.join(denied_actions)}" + ), + ) + logger.info("All simulated actions were allowed for ARN: %s", caller_arn) + return ( + True, + MessageSchema( + message="AWS credentials authenticated and all required permissions are present." + ), + ) + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code in ("InvalidClientTokenId", "SignatureDoesNotMatch"): + message = "AWS credentials could not be authenticated. Please ensure you are providing credentials that are linked to a valid AWS account." + elif error_code == "AccessDenied": + message = "AWS credentials are valid, but lack permissions to perform the permissions verification. Please ensure you give your AWS account user/group has the iam:SimulatePrincipalPolicy permission attached to a policy." + else: + message = e.response["Error"]["Message"] + logger.error("AWS verification failed: %s", message) + return ( + False, + MessageSchema(message=message), + ) diff --git a/api/src/app/cloud/base_creds.py b/api/src/app/cloud/base_creds.py new file mode 100644 index 00000000..d09ffc58 --- /dev/null +++ b/api/src/app/cloud/base_creds.py @@ -0,0 +1,31 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +from src.app.schemas.message_schema import MessageSchema +from src.app.schemas.secret_schema import AnySecrets, SecretSchema + + +class AbstractBaseCreds(ABC): + """Abstract class to enforce common credential verification functionality across range cloud providers.""" + + @abstractmethod + def __init__(self, credentials: AnySecrets) -> None: + """Expected constructor for all credential subclasses.""" # noqa: D401 + pass + + @abstractmethod + def get_user_creds(self) -> dict[str, str]: + """Convert user secrets to dictionary for encryption.""" + pass + + @abstractmethod + def update_secret_schema( + self, secrets: SecretSchema, encrypted_data: dict[str, str] + ) -> SecretSchema: + """Update user secrets schema with newly encrypted secrets.""" + pass + + @abstractmethod + def verify_creds(self) -> Tuple[bool, MessageSchema]: + """Verify that user provided credentials properly authenticate to a provider account.""" + pass diff --git a/api/src/app/cloud/creds_factory.py b/api/src/app/cloud/creds_factory.py new file mode 100644 index 00000000..0efdf92d --- /dev/null +++ b/api/src/app/cloud/creds_factory.py @@ -0,0 +1,45 @@ +import logging +from typing import ClassVar, Type + +from src.app.cloud.aws_creds import AWSCreds +from src.app.cloud.base_creds import AbstractBaseCreds +from src.app.enums.providers import OpenLabsProvider +from src.app.schemas.secret_schema import AnySecrets + +# Configure logging +logger = logging.getLogger(__name__) + + +class CredsFactory: + """Create creds objects.""" + + _registry: ClassVar[dict[OpenLabsProvider, Type[AbstractBaseCreds]]] = { + OpenLabsProvider.AWS: AWSCreds, + } + + @classmethod + def create_creds_verification(cls, credentials: AnySecrets) -> AbstractBaseCreds: + """Create creds object. + + **Note:** This function accepts a creation schema as the OpenLabs resource ID is not required + for terraform. + + Args: + ---- + cls (CredsFactory): The CredsFactory class. + provider (OpenLabsProvider): Cloud provider the credentials to verify are for + credentials (dict[str, Any]): User cloud credentials to verify + + Returns: + ------- + AbstractBaseCreds: Creds object that will be used to verify the user cloud credentials provided + + """ + creds_class = cls._registry.get(credentials.provider) + + if creds_class is None: + msg = f"Failed to build creds object. Non-existent provider given: {credentials.provider}" + logger.error(msg) + raise ValueError(msg) + + return creds_class(credentials=credentials) diff --git a/api/src/app/crud/crud_secrets.py b/api/src/app/crud/crud_secrets.py new file mode 100644 index 00000000..4f3ef855 --- /dev/null +++ b/api/src/app/crud/crud_secrets.py @@ -0,0 +1,61 @@ +import logging + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.app.models.secret_model import SecretModel + +from ..schemas.secret_schema import SecretSchema + +logger = logging.getLogger(__name__) + + +async def get_user_secrets(db: AsyncSession, user_id: int) -> SecretSchema | None: + """Get user provider secrets. + + Args: + ---- + db (Session): Database connection. + user_id (int): ID of the user requesting data. + + Returns: + ------- + Optional[SecretSchema]: User provider secrets if it exists in the database. + + """ + stmt = select(SecretModel).where(SecretModel.user_id == user_id) + result = await db.execute(stmt) + + if not result: + logger.info( + "Failed to fetch secrets for user: %s. Not found in database!", user_id + ) + return None + + secrets = result.scalars().first() + + return SecretSchema.model_validate(secrets) + + +async def upsert_user_secrets( + db: AsyncSession, secrets: SecretSchema, user_id: int +) -> None: + """Update user provider secrets. + + Args: + ---- + db (Session): Database connection. + user_id (int): ID of the user requesting data. + secrets (SecretSchema): User secrets record containing new secrets to add to database + + Returns: + ------- + None + + """ + db_object_to_merge = SecretModel(user_id=user_id, **secrets.model_dump()) + # 2. Merge the instance into the session. + # SQLAlchemy checks if a record with this primary key exists. + # - If yes, it copies the new data onto the existing record. + # - If no, it stages a new record for insertion. + await db.merge(db_object_to_merge) diff --git a/api/src/app/enums/providers.py b/api/src/app/enums/providers.py index 23ae894b..012032a3 100644 --- a/api/src/app/enums/providers.py +++ b/api/src/app/enums/providers.py @@ -1,7 +1,7 @@ from enum import Enum -class OpenLabsProvider(Enum): +class OpenLabsProvider(str, Enum): """OpenLabs supported cloud providers.""" AWS = "aws" diff --git a/api/src/app/schemas/secret_schema.py b/api/src/app/schemas/secret_schema.py index b6fbdccd..37a4f6ae 100644 --- a/api/src/app/schemas/secret_schema.py +++ b/api/src/app/schemas/secret_schema.py @@ -1,6 +1,9 @@ from datetime import datetime, timezone +from typing import Annotated, Literal, Union -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from src.app.enums.providers import OpenLabsProvider class SecretBaseSchema(BaseModel): @@ -54,9 +57,16 @@ class SecretSchema(SecretBaseSchema): model_config = ConfigDict(from_attributes=True) -class AWSSecrets(BaseModel): +class BaseSecrets(BaseModel): + """Base secret object for setting secrets on OpenLabs.""" + + provider: OpenLabsProvider + + +class AWSSecrets(BaseSecrets): """AWS secret object for setting secrets on OpenLabs.""" + provider: Literal[OpenLabsProvider.AWS] = OpenLabsProvider.AWS aws_access_key: str = Field( ..., description="Access key for AWS account", @@ -66,10 +76,60 @@ class AWSSecrets(BaseModel): description="Secret key for AWS account", ) - -class AzureSecrets(BaseModel): + @field_validator("aws_access_key") + @classmethod + def validate_access_key(cls, aws_access_key: str) -> str: + """Check AWS access key is correct length. + + Args: + ---- + cls: AWSSecrets object. + aws_access_key (str): AWS access key. + + Returns: + ------- + str: AWS access key. + + """ + access_key_length = 20 + if len(aws_access_key.strip()) == 0: + msg = "No AWS access key provided." + raise ValueError(msg) + if len(aws_access_key.strip()) != access_key_length: + msg = "Invalid AWS access key format." + raise ValueError(msg) + return aws_access_key + + @field_validator("aws_secret_key") + @classmethod + def validate_secret_key(cls, aws_secret_key: str) -> str: + """Check AWS secret key is correct length. + + Args: + ---- + cls: AWSSecrets object. + aws_secret_key (str): AWS secret key. + + Returns: + ------- + str: AWS access key. + + """ + secret_key_length = 40 + if len(aws_secret_key.strip()) == 0: + msg = "No AWS secret key provided." + raise ValueError(msg) + if len(aws_secret_key.strip()) != secret_key_length: + msg = "Invalid AWS secret key format." + raise ValueError(msg) + return aws_secret_key + + +class AzureSecrets(BaseSecrets): """Azure secret object for setting secrets on OpenLabs.""" + provider: Literal[OpenLabsProvider.AZURE] = OpenLabsProvider.AZURE + azure_client_id: str = Field( ..., description="Client ID for Azure", @@ -88,6 +148,9 @@ class AzureSecrets(BaseModel): ) +AnySecrets = Annotated[Union[AWSSecrets, AzureSecrets], Field(discriminator="provider")] + + class CloudSecretStatusSchema(BaseModel): """General response schema for a single cloud provider.""" diff --git a/api/tests/api_test_utils.py b/api/tests/api_test_utils.py index 0fef733e..797417f5 100644 --- a/api/tests/api_test_utils.py +++ b/api/tests/api_test_utils.py @@ -225,7 +225,7 @@ async def is_logged_in(client: AsyncClient) -> bool: async def add_cloud_credentials( auth_client: AsyncClient, provider: OpenLabsProvider, - credentials_payload: dict[str, Any], + credentials: dict[str, Any], ) -> bool: """Add cloud credentials to the authenticated client's account. @@ -233,7 +233,7 @@ async def add_cloud_credentials( ---- auth_client (AsyncClient): Any authenticated httpx client. NOT THE `auth_client` FIXTURE! provider (OpenLabsProvider): A valid OpenLabs cloud provider to configure credentials for. - credentials_payload (dict[str, Any]): Dictionary representation of corresponding cloud provider's credential schema. + credentials (dict[str, Any]): Dictionary representation of corresponding cloud provider's credential schema. Returns: ------- @@ -242,10 +242,13 @@ async def add_cloud_credentials( """ base_route = get_api_base_route(version=1) - if not credentials_payload: + if not credentials: logger.error("Failed to add cloud credentials. Payload empty!") return False + credentials_payload = credentials + credentials_payload["provider"] = provider + # Verify we are logged in logged_in = await is_logged_in(auth_client) if not logged_in: @@ -255,9 +258,8 @@ async def add_cloud_credentials( return False # Submit credentials - provider_url = provider.value.lower() response = await auth_client.post( - f"{base_route}/users/me/secrets/{provider_url}", json=credentials_payload + f"{base_route}/users/me/secrets", json=credentials_payload ) if response.status_code != status.HTTP_200_OK: logger.error( diff --git a/api/tests/aws_test_utils.py b/api/tests/aws_test_utils.py new file mode 100644 index 00000000..c2b8c017 --- /dev/null +++ b/api/tests/aws_test_utils.py @@ -0,0 +1,26 @@ +import logging +import os + +from src.app.enums.providers import OpenLabsProvider +from tests.deploy_test_utils import get_provider_test_creds + +# Configure logging +logger = logging.getLogger(__name__) + + +def set_test_boto_creds() -> bool: + """Set credentials for the AWS boto client.""" + aws_provider = OpenLabsProvider.AWS + creds = get_provider_test_creds(aws_provider) + if not creds: + logger.error( + "Failed to set test boto creds. Test credentials for %s not set.", + aws_provider.value.upper(), + ) + return False + + # Set environment variables + os.environ["AWS_ACCESS_KEY_ID"] = creds["aws_access_key"] + os.environ["AWS_SECRET_ACCESS_KEY"] = creds["aws_secret_key"] + + return True diff --git a/api/tests/common/api/v1/config.py b/api/tests/common/api/v1/config.py index e6e9821f..64bed83f 100644 --- a/api/tests/common/api/v1/config.py +++ b/api/tests/common/api/v1/config.py @@ -336,13 +336,15 @@ } # Test data for AWS secrets -aws_secrets_payload = { +aws_secrets_payload: dict[str, Any] = { + "provider": "aws", "aws_access_key": "AKIAIOSFODNN7EXAMPLE", "aws_secret_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", } # Test data for Azure secrets azure_secrets_payload = { + "provider": "azure", "azure_client_id": "00000000-0000-0000-0000-000000000000", "azure_client_secret": "example-client-secret-value", "azure_tenant_id": "00000000-0000-0000-0000-000000000000", diff --git a/api/tests/common/api/v1/test_users.py b/api/tests/common/api/v1/test_users.py index 577dc43a..bd00fe6b 100644 --- a/api/tests/common/api/v1/test_users.py +++ b/api/tests/common/api/v1/test_users.py @@ -15,7 +15,6 @@ AUTH_API_CLIENT_PARAMS, BASE_ROUTE, aws_secrets_payload, - azure_secrets_payload, password_update_payload, ) @@ -47,74 +46,39 @@ async def test_update_password_with_incorrect_current( assert update_response.status_code == status.HTTP_400_BAD_REQUEST assert update_response.json()["detail"] == "Current password is incorrect" - async def test_update_aws_credentials(self, auth_api_client: AsyncClient) -> None: - """Test updating AWS credentials.""" - # Add AWS credentials - aws_response = await auth_api_client.post( - f"{BASE_ROUTE}/users/me/secrets/aws", json=aws_secrets_payload - ) - assert aws_response.status_code == status.HTTP_200_OK - assert aws_response.json()["message"] == "AWS credentials updated successfully" - - # Check updated status - updated_status_response = await auth_api_client.get( - f"{BASE_ROUTE}/users/me/secrets" - ) - assert updated_status_response.status_code == status.HTTP_200_OK - - aws_status = updated_status_response.json() - assert aws_status["aws"]["has_credentials"] is True - assert "created_at" in aws_status["aws"] + async def test_update_secrets_with_invalid_payload( + self, auth_api_client: AsyncClient + ) -> None: + """Test updating user secrets with invalid credentials payload.""" + # Try update with invalid secrets format - Use AWS secrets specifically for this test + invalid_payload = copy.deepcopy(aws_secrets_payload) + # Using incorrect credentials to test validation - submit payload without required fields + invalid_payload["aws_secret_key"] = "" - async def test_update_azure_credentials(self, auth_api_client: AsyncClient) -> None: - """Test updating Azure credentials.""" - # Add Azure credentials - azure_response = await auth_api_client.post( - f"{BASE_ROUTE}/users/me/secrets/azure", json=azure_secrets_payload - ) - assert azure_response.status_code == status.HTTP_200_OK - assert ( - azure_response.json()["message"] == "Azure credentials updated successfully" + update_response = await auth_api_client.post( + f"{BASE_ROUTE}/users/me/secrets", json=invalid_payload ) + assert update_response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert "No AWS secret key provided." in str(update_response.json()["detail"]) - # Check updated status - status_response = await auth_api_client.get(f"{BASE_ROUTE}/users/me/secrets") - assert status_response.status_code == status.HTTP_200_OK - - azure_status = status_response.json() - assert azure_status["azure"]["has_credentials"] is True - assert "created_at" in azure_status["azure"] - - async def test_both_provider_credentials_status( + async def test_update_secrets_with_invalid_credentials( self, auth_api_client: AsyncClient ) -> None: - """Test that status shows both provider credentials when set.""" - # Add AWS credentials - aws_response = await auth_api_client.post( - f"{BASE_ROUTE}/users/me/secrets/aws", json=aws_secrets_payload - ) - assert aws_response.status_code == status.HTTP_200_OK - assert aws_response.json()["message"] == "AWS credentials updated successfully" + """Test updating user secrets with invalid credentials that do not authenticate.""" + # Try update with invalid secrets - Use AWS secrets specifically for this test + invalid_payload = copy.deepcopy( + aws_secrets_payload + ) # Example secrets correct format but do not authenticate - # Add Azure credentials - azure_response = await auth_api_client.post( - f"{BASE_ROUTE}/users/me/secrets/azure", json=azure_secrets_payload + update_response = await auth_api_client.post( + f"{BASE_ROUTE}/users/me/secrets", json=invalid_payload ) - assert azure_response.status_code == status.HTTP_200_OK + assert update_response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert ( - azure_response.json()["message"] == "Azure credentials updated successfully" + update_response.json()["detail"] + == "AWS credentials could not be authenticated. Please ensure you are providing credentials that are linked to a valid AWS account." ) - # Check final status with both credentials - status_response = await auth_api_client.get(f"{BASE_ROUTE}/users/me/secrets") - assert status_response.status_code == status.HTTP_200_OK - - provider_status = status_response.json() - assert provider_status["aws"]["has_credentials"] is True - assert provider_status["azure"]["has_credentials"] is True - assert "created_at" in provider_status["aws"] - assert "created_at" in provider_status["azure"] - @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize( @@ -195,13 +159,7 @@ async def test_unauthenticated_access(self, api_client: AsyncClient) -> None: # Try to update AWS secrets without authentication response = await api_client.post( - f"{BASE_ROUTE}/users/me/secrets/aws", json=aws_secrets_payload - ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED - - # Try to update Azure secrets without authentication - response = await api_client.post( - f"{BASE_ROUTE}/users/me/secrets/azure", json=azure_secrets_payload + f"{BASE_ROUTE}/users/me/secrets", json=aws_secrets_payload ) assert response.status_code == status.HTTP_401_UNAUTHORIZED diff --git a/api/tests/integration/cloud/__init__.py b/api/tests/integration/cloud/__init__.py new file mode 100644 index 00000000..26394a91 --- /dev/null +++ b/api/tests/integration/cloud/__init__.py @@ -0,0 +1 @@ +"""Tests for cloud controls of the OpenLabs API.""" diff --git a/api/tests/integration/cloud/aws_config.py b/api/tests/integration/cloud/aws_config.py new file mode 100644 index 00000000..d7337a11 --- /dev/null +++ b/api/tests/integration/cloud/aws_config.py @@ -0,0 +1,55 @@ +import pytest + +# --- IAM --- +SUCCESS_POLICY = { + "Version": "2012-10-17", + "Statement": [ + {"Effect": "Allow", "Action": "iam:SimulatePrincipalPolicy", "Resource": "*"}, + {"Effect": "Allow", "Action": "ec2:*", "Resource": "*"}, + ], +} + +NO_SIMULATE_POLICY = { + "Version": "2012-10-17", + "Statement": [{"Effect": "Allow", "Action": "ec2:*", "Resource": "*"}], +} + +INSUFFICIENT_EC2_POLICY = { + "Version": "2012-10-17", + "Statement": [ + {"Effect": "Allow", "Action": "iam:SimulatePrincipalPolicy", "Resource": "*"}, + { + "Effect": "Allow", + "Action": [ + "ec2:DescribeInstances", + "ec2:DescribeVpcs", + "ec2:DescribeSubnets", + ], + "Resource": "*", + }, + ], +} + +VERIFY_CREDS_TEST_CASES = [ + pytest.param( + SUCCESS_POLICY, + True, + "permissions are present", + id="success-all-permissions", + marks=pytest.mark.aws, + ), + pytest.param( + NO_SIMULATE_POLICY, + False, + "iam:simulateprincipalpolicy", + id="failure-no-simulate-permission", + marks=pytest.mark.aws, + ), + pytest.param( + INSUFFICIENT_EC2_POLICY, + False, + "insufficient permissions", + id="failure-insufficient-ec2-permissions", + marks=pytest.mark.aws, + ), +] diff --git a/api/tests/integration/cloud/test_aws_creds.py b/api/tests/integration/cloud/test_aws_creds.py new file mode 100644 index 00000000..ff6741eb --- /dev/null +++ b/api/tests/integration/cloud/test_aws_creds.py @@ -0,0 +1,129 @@ +import json +import logging +import time +import uuid +from typing import Generator + +import boto3 +import pytest +from botocore.exceptions import ClientError + +from src.app.cloud.aws_creds import AWSCreds +from src.app.schemas.secret_schema import AWSSecrets +from tests.aws_test_utils import set_test_boto_creds + +from .aws_config import VERIFY_CREDS_TEST_CASES + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="session") +def iam_client(load_test_env_file: bool) -> boto3.client: + """Create IAM client, skipping if credentials are not set.""" + if not load_test_env_file or not set_test_boto_creds(): + pytest.skip("Credentials for AWS not set.") + return boto3.client("iam") + + +@pytest.fixture(scope="session") +def all_test_credentials( + iam_client: boto3.client, +) -> Generator[dict[str, dict[str, str]], None, None]: + """Create all necessary IAM users/policies for the entire test session at once.""" + session_id = uuid.uuid4().hex[:8] + created_resources = {} + credentials_by_id = {} + + logger.info("Starting batch creation of IAM test resources...") + + try: + # Batch Create all IAM Resources + for param in VERIFY_CREDS_TEST_CASES: + test_id = str(param.id) + policy_document = param.values[0] + + user_name = f"openlabs-test-{test_id}-{session_id}" + policy_name = f"OpenLabsTest-{test_id}-{session_id}" + + # Create user, policy, attach, and create key + iam_client.create_user(UserName=user_name) + policy_response = iam_client.create_policy( + PolicyName=policy_name, PolicyDocument=json.dumps(policy_document) + ) + policy_arn = policy_response["Policy"]["Arn"] + iam_client.attach_user_policy(UserName=user_name, PolicyArn=policy_arn) + key_response = iam_client.create_access_key(UserName=user_name) + access_key = key_response["AccessKey"] + + # Store credentials for the test to use + credentials_by_id[test_id] = { + "provider": "aws", + "aws_access_key": access_key["AccessKeyId"], + "aws_secret_key": access_key["SecretAccessKey"], + } + # Store all identifiers for later cleanup + created_resources[test_id] = { + "user_name": user_name, + "policy_arn": policy_arn, + "access_key_id": access_key["AccessKeyId"], + } + logger.info("Created resources for test ID: %s", test_id) + + prop_wait = 10 + logger.info("Waiting %d seconds for all IAM changes to propagate...", prop_wait) + time.sleep(prop_wait) + + yield credentials_by_id + + finally: + # 3. --- Batch Cleanup --- + logger.info("Starting cleanup of all IAM test resources.") + for test_id, resources in created_resources.items(): + user_name = resources["user_name"] + logger.info( + "Cleaning up resources for %s (user: %s)...", test_id, user_name + ) + try: + iam_client.detach_user_policy( + UserName=user_name, PolicyArn=resources["policy_arn"] + ) + except ClientError as e: + logger.warning("Failed to detach policy for %s: %s", user_name, e) + try: + iam_client.delete_policy(PolicyArn=resources["policy_arn"]) + except ClientError as e: + logger.warning("Failed to delete policy for %s: %s", user_name, e) + try: + iam_client.delete_access_key( + UserName=user_name, AccessKeyId=resources["access_key_id"] + ) + except ClientError as e: + logger.warning("Failed to delete access key for %s: %s", user_name, e) + try: + iam_client.delete_user(UserName=user_name) + except ClientError as e: + logger.warning("Failed to delete user %s: %s", user_name, e) + logger.info("IAM resource cleanup complete.") + + +@pytest.mark.parametrize( + "test_id, expected_result, expected_message_part", + [(p.id, p.values[1], p.values[2]) for p in VERIFY_CREDS_TEST_CASES], + ids=[str(p.id) for p in VERIFY_CREDS_TEST_CASES], +) +def test_verify_aws_creds( + all_test_credentials: dict[str, dict[str, str]], + test_id: str, + expected_result: bool, + expected_message_part: str, +) -> None: + """Test the AWS verify_creds method with various scenarios.""" + # The test now receives exactly the arguments it needs. No unpacking! + credentials = all_test_credentials[test_id] + aws_creds = AWSSecrets.model_validate(credentials) + + aws_verifier = AWSCreds(credentials=aws_creds) + is_valid, message_schema = aws_verifier.verify_creds() + + assert is_valid is expected_result + assert expected_message_part in message_schema.message.lower() diff --git a/api/tests/unit/api/v1/config.py b/api/tests/unit/api/v1/config.py index ec933477..39768d40 100644 --- a/api/tests/unit/api/v1/config.py +++ b/api/tests/unit/api/v1/config.py @@ -181,19 +181,6 @@ "new_password": "newpassword123", } -# Test data for AWS secrets -aws_secrets_payload = { - "aws_access_key": "AKIAIOSFODNN7EXAMPLE", - "aws_secret_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", -} - -# Test data for Azure secrets -azure_secrets_payload = { - "azure_client_id": "00000000-0000-0000-0000-000000000000", - "azure_client_secret": "example-client-secret-value", - "azure_tenant_id": "00000000-0000-0000-0000-000000000000", - "azure_subscription_id": "00000000-0000-0000-0000-000000000000", -} # ============================== # Job Payloads # ============================== diff --git a/api/tests/unit/api/v1/test_users.py b/api/tests/unit/api/v1/test_users.py new file mode 100644 index 00000000..27ab1c77 --- /dev/null +++ b/api/tests/unit/api/v1/test_users.py @@ -0,0 +1,112 @@ +from datetime import UTC, datetime +from unittest.mock import MagicMock + +import pytest +from fastapi import status +from httpx import AsyncClient +from pytest_mock import MockerFixture + +from src.app.core.auth.auth import get_current_user +from src.app.main import app +from src.app.models.user_model import UserModel +from src.app.schemas.message_schema import MessageSchema +from src.app.schemas.secret_schema import SecretSchema +from tests.common.api.v1.config import BASE_ROUTE, aws_secrets_payload + + +@pytest.fixture +def users_api_v1_endpoints_path() -> str: + """Get the dot path to the v1 API endpoints for users.""" + return "src.app.api.v1.users" + + +@pytest.fixture +def mock_update_secrets_success( + mocker: MockerFixture, users_api_v1_endpoints_path: str +) -> None: + """Bypass provider credentials verification and updating user secrets record to succeed.""" + mock_creds_class = MagicMock() + mock_creds_class.verify_creds.return_value = [True, MessageSchema(message="true")] + mock_creds_class.update_secret_schema.return_value = SecretSchema() + # Patch the functions + mocker.patch( + f"{users_api_v1_endpoints_path}.CredsFactory.create_creds_verification", + return_value=mock_creds_class, + ) + + +@pytest.fixture +def mock_get_secrets_failure( + mocker: MockerFixture, users_api_v1_endpoints_path: str +) -> None: + """Bypass fetching users secrets to fail.""" + # Patch the function + mocker.patch(f"{users_api_v1_endpoints_path}.get_user_secrets", return_value=None) + + +@pytest.fixture +def mock_get_secrets(mocker: MockerFixture, users_api_v1_endpoints_path: str) -> None: + """Bypass fetching users secrets to pass for a fake user.""" + + def override_get_current_user_no_key() -> UserModel: + return UserModel( + name="FakeUser", + email="fakeuser@gmail.com", + hashed_password="faskpasswordhash", # noqa: S106 + created_at=datetime.now(UTC), + last_active=datetime.now(UTC), + is_admin=False, + public_key=None, + ) + + # Temporarily override the dependency + app.dependency_overrides[get_current_user] = override_get_current_user_no_key + # Patch the function + mocker.patch( + f"{users_api_v1_endpoints_path}.get_user_secrets", + return_value=SecretSchema(), + ) + + +async def test_update_aws_secrets_success( + auth_client: AsyncClient, mock_update_secrets_success: None +) -> None: + """Test that attempting to update user AWS provider credentials succeeds.""" + response = await auth_client.post( + f"{BASE_ROUTE}/users/me/secrets", + json=aws_secrets_payload, + ) + assert response.status_code == status.HTTP_200_OK + assert ( + response.json()["message"] + == "AWS credentials successfully verified and updated." + ) + + +async def test_update_user_secrets_database_fetch_failure( + auth_client: AsyncClient, mock_get_secrets_failure: None +) -> None: + """Test that attempting to update user provider credentials fails when user record is not found in database.""" + response = await auth_client.post( + f"{BASE_ROUTE}/users/me/secrets", + json=aws_secrets_payload, + ) + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert response.json()["detail"] == "User secrets record not found!" + + +async def test_update_user_secrets_encryption_failure( + auth_client: AsyncClient, + mock_get_secrets: None, + mock_update_secrets_success: None, +) -> None: + """Test that attempting to update user provider credentials fails when user public key does not exist.""" + response = await auth_client.post( + f"{BASE_ROUTE}/users/me/secrets", + json=aws_secrets_payload, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert ( + response.json()["detail"] + == "User encryption keys not set up. Please register a new account." + ) diff --git a/api/tests/unit/cloud/test_aws_creds.py b/api/tests/unit/cloud/test_aws_creds.py new file mode 100644 index 00000000..256b6907 --- /dev/null +++ b/api/tests/unit/cloud/test_aws_creds.py @@ -0,0 +1,178 @@ +import pytest +from botocore.exceptions import ClientError +from pytest_mock import MockerFixture + +from src.app.cloud.aws_creds import AWSCreds +from src.app.schemas.secret_schema import AWSSecrets, SecretSchema +from src.app.utils.crypto import encrypt_with_public_key, generate_rsa_key_pair +from tests.common.api.v1.config import aws_secrets_payload + + +@pytest.fixture(scope="module") +def aws_creds_class() -> AWSCreds: + """Create a AWS creds class object.""" + return AWSCreds(AWSSecrets.model_validate(aws_secrets_payload)) + + +def test_init(aws_creds_class: AWSCreds) -> None: + """Test that initializing the aws creds object creates and stores and AWS Secrets Schema object.""" + assert isinstance(aws_creds_class.credentials, AWSSecrets) + + +def test_get_user_creds(aws_creds_class: AWSCreds) -> None: + """Test that getting user credentials is returned as a proper dictionary.""" + user_creds = aws_creds_class.get_user_creds() + + # Keys + access_key = "aws_access_key" + secret_key = "aws_secret_key" # noqa: S105 + + assert user_creds[access_key] == aws_secrets_payload[access_key] + assert user_creds[secret_key] == aws_secrets_payload[secret_key] + + +def test_update_secret_schema(aws_creds_class: AWSCreds) -> None: + """Test that Secret Schema is updated with encrypted user credentials.""" + user_creds = aws_creds_class.get_user_creds() + + # Encrypt with the user's public key + _, public_key = generate_rsa_key_pair() + encrypted_data = encrypt_with_public_key(data=user_creds, public_key_b64=public_key) + + # Update the secrets with encrypted values + secrets = SecretSchema() + secrets = aws_creds_class.update_secret_schema( + secrets=secrets, encrypted_data=encrypted_data + ) + + assert secrets.aws_access_key == encrypted_data["aws_access_key"] + assert secrets.aws_secret_key == encrypted_data["aws_secret_key"] + + +def test_verify_creds_success(aws_creds_class: AWSCreds, mocker: MockerFixture) -> None: + """Test successful credential verification with sufficient permissions.""" + # Mock the boto3 session and its clients + mock_session = mocker.patch("boto3.Session").return_value + mock_sts = mock_session.client.return_value + mock_iam = mock_session.client.return_value + + # Configure mock responses + mock_sts.get_caller_identity.return_value = { + "Arn": "arn:aws:iam::123456789012:user/test-user" + } + mock_iam.simulate_principal_policy.return_value = { + "EvaluationResults": [ + {"EvalDecision": "allowed", "EvalActionName": "ec2:RunInstances"} + ] + } + + # Execute the function + verified, msg = aws_creds_class.verify_creds() + + # Assert the results + assert verified is True + assert "authenticated and all required permissions are present" in msg.message + mock_sts.get_caller_identity.assert_called_once() + mock_iam.simulate_principal_policy.assert_called_once() + + +def test_verify_creds_root_user( + aws_creds_class: AWSCreds, mocker: MockerFixture +) -> None: + """Test that the permissions check is skipped for the root user.""" + mock_session = mocker.patch("boto3.Session").return_value + mock_sts = mock_session.client.return_value + mock_iam = mock_session.client.return_value + + mock_sts.get_caller_identity.return_value = { + "Arn": "arn:aws:iam::123456789012:root" + } + + verified, msg = aws_creds_class.verify_creds() + + assert verified is True + assert "authenticated and all required permissions are present" in msg.message + mock_iam.simulate_principal_policy.assert_not_called() # Ensure IAM check was skipped + + +def test_verify_creds_insufficient_permissions( + aws_creds_class: AWSCreds, mocker: MockerFixture +) -> None: + """Test verification failure due to denied actions in the simulation.""" + mock_session = mocker.patch("boto3.Session").return_value + mock_sts = mock_session.client.return_value + mock_iam = mock_session.client.return_value + + mock_sts.get_caller_identity.return_value = { + "Arn": "arn:aws:iam::123456789012:user/limited-user" + } + mock_iam.simulate_principal_policy.return_value = { + "EvaluationResults": [ + {"EvalDecision": "allowed", "EvalActionName": "ec2:DescribeInstances"}, + {"EvalDecision": "implicitDeny", "EvalActionName": "ec2:RunInstances"}, + {"EvalDecision": "implicitDeny", "EvalActionName": "ec2:CreateVpc"}, + ] + } + + verified, msg = aws_creds_class.verify_creds() + + assert verified is False + assert "Insufficient permissions" in msg.message + assert "ec2:RunInstances" in msg.message + assert "ec2:CreateVpc" in msg.message + assert ( + "ec2:DescribeInstances" not in msg.message + ) # Ensure only denied actions are listed + + +def test_verify_creds_invalid_token( + aws_creds_class: AWSCreds, mocker: MockerFixture +) -> None: + """Test verification failure due to invalid credentials.""" + mock_session = mocker.patch("boto3.Session").return_value + mock_sts = mock_session.client.return_value + + # Simulate a ClientError for an invalid token + error_response = { + "Error": { + "Code": "InvalidTokenId", + "Message": "The security token included in the request is invalid.", + } + } + mock_sts.get_caller_identity.side_effect = ClientError( + error_response, "GetCallerIdentity" + ) + + verified, msg = aws_creds_class.verify_creds() + + assert verified is False + assert msg.message == "The security token included in the request is invalid." + + +def test_verify_creds_iam_access_denied( + aws_creds_class: AWSCreds, mocker: MockerFixture +) -> None: + """Test failure when credentials are valid but cannot perform the IAM simulation.""" + mock_session = mocker.patch("boto3.Session").return_value + mock_sts = mock_session.client.return_value + mock_iam = mock_session.client.return_value + + mock_sts.get_caller_identity.return_value = { + "Arn": "arn:aws:iam::123456789012:user/test-user" + } + + # Simulate a ClientError for lack of iam:SimulatePrincipalPolicy permission + error_response = { + "Error": { + "Code": "AccessDenied", + "Message": "User is not authorized to perform iam:SimulatePrincipalPolicy", + } + } + mock_iam.simulate_principal_policy.side_effect = ClientError( + error_response, "SimulatePrincipalPolicy" + ) + + verified, msg = aws_creds_class.verify_creds() + + assert verified is False + assert "iam:SimulatePrincipalPolicy permission" in msg.message diff --git a/api/tests/unit/cloud/test_creds_factory.py b/api/tests/unit/cloud/test_creds_factory.py new file mode 100644 index 00000000..f0366836 --- /dev/null +++ b/api/tests/unit/cloud/test_creds_factory.py @@ -0,0 +1,30 @@ +import copy + +import pytest + +from src.app.cloud.aws_creds import AWSCreds +from src.app.cloud.creds_factory import CredsFactory +from src.app.schemas.secret_schema import AWSSecrets +from tests.common.api.v1.config import aws_secrets_payload + + +def test_creds_factory_non_existent_provider_type() -> None: + """Test that CredsFactory.create_creds_verification() raises a ValueError when invalid provider is provided.""" + # Set provider to non-existent provider + bad_provider_creds_payload = copy.deepcopy(aws_secrets_payload) + aws_secrets = AWSSecrets.model_validate(bad_provider_creds_payload) + aws_secrets.provider = "Fake Provider" # type: ignore + + with pytest.raises(ValueError): + _ = CredsFactory.create_creds_verification( + credentials=aws_secrets, + ) + + +def test_creds_factory_build_aws_creds() -> None: + """Test that CredsFactory can build an AWSCreds.""" + # Use AWS secrets payload with provider already set to AWS + aws_secrets = AWSSecrets.model_validate(aws_secrets_payload) + creds_object = CredsFactory.create_creds_verification(credentials=aws_secrets) + + assert type(creds_object) is AWSCreds diff --git a/api/tests/unit/crud/test_crud_secrets.py b/api/tests/unit/crud/test_crud_secrets.py new file mode 100644 index 00000000..42fb38ed --- /dev/null +++ b/api/tests/unit/crud/test_crud_secrets.py @@ -0,0 +1,75 @@ +import re +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from src.app.crud.crud_secrets import get_user_secrets, upsert_user_secrets +from src.app.models.secret_model import SecretModel +from src.app.schemas.secret_schema import SecretSchema + +from .crud_mocks import DummyDB + + +@pytest.fixture +def crud_secrets_path() -> str: + """Return the dot path of the tested crud secrets file.""" + return "src.app.crud.crud_secrets" + + +async def test_upsert_user_secrets() -> None: + """Test that updating user secrets works correctly.""" + mock_db = DummyDB() + fake_secrets = SecretSchema() + mock_user_id = 1 + + await upsert_user_secrets(db=mock_db, secrets=fake_secrets, user_id=mock_user_id) + + # Check we upsert correctly + mock_db.merge.assert_awaited_once() + + # Check we are upserting the creds for the correct user + called_with_object = mock_db.merge.call_args[0][0] + + # 5. Assert that the object has the correct attributes + assert isinstance(called_with_object, SecretModel) + assert called_with_object.user_id == mock_user_id + + +async def test_get_user_secrets_success(mocker: MockerFixture) -> None: + """Tests successfully retrieving user secrets when they exist.""" + dummy_db = DummyDB() + user_id = 123 + mock_secret_model = MagicMock() + + # Mock the full database call to return mock model + mock_result = dummy_db.execute.return_value + mock_result.scalars = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_secret_model + + # Patch validation of secrets + mock_validate = mocker.patch.object( + SecretSchema, "model_validate", return_value=SecretSchema() + ) + + result = await get_user_secrets(db=dummy_db, user_id=user_id) + fake_secrets = SecretSchema() + assert result == fake_secrets + mock_validate.assert_called_once_with(mock_secret_model) + + # Verify the fetching secrets for correct user + stmt = dummy_db.execute.call_args[0][0] + assert str(SecretModel.user_id == user_id) in str(stmt.whereclause) + + +async def test_get_user_secrets_not_found(caplog: pytest.LogCaptureFixture) -> None: + """Tests that None is returned and a message is logged when secrets are not found.""" + dummy_db = DummyDB() + user_id = 404 + + # Mock the database to return no results + dummy_db.execute.return_value = None + + result = await get_user_secrets(db=dummy_db, user_id=user_id) + assert result is None + assert re.search(f"failed to fetch secrets.*{user_id}", caplog.text, re.IGNORECASE) diff --git a/api/tests/unit/schemas/test_secret_schema.py b/api/tests/unit/schemas/test_secret_schema.py new file mode 100644 index 00000000..7fe133ab --- /dev/null +++ b/api/tests/unit/schemas/test_secret_schema.py @@ -0,0 +1,48 @@ +import copy +import re + +import pytest +from pydantic import ValidationError + +from src.app.schemas.secret_schema import AWSSecrets +from tests.common.api.v1.config import aws_secrets_payload + + +def test_aws_secrets_schema_invalid_access_key_length() -> None: + """Test that the AWS Secrets schema fails when aws_access_key is invalid length.""" + invalid_creds = copy.deepcopy(aws_secrets_payload) + invalid_creds["aws_access_key"] = "string" + + expected_msg = re.compile(r"Invalid AWS access key format.", re.IGNORECASE) + with pytest.raises(ValidationError, match=expected_msg): + AWSSecrets.model_validate(invalid_creds) + + +def test_aws_secrets_schema_missing_access_key() -> None: + """Test that the AWS Secrets schema fails when aws_access_key is missing.""" + invalid_creds = copy.deepcopy(aws_secrets_payload) + invalid_creds["aws_access_key"] = "" # Empty string = missing access key + + expected_msg = re.compile(r"No AWS access key provided.", re.IGNORECASE) + with pytest.raises(ValidationError, match=expected_msg): + AWSSecrets.model_validate(invalid_creds) + + +def test_aws_secrets_schema_invalid_secret_key_length() -> None: + """Test that the AWS Secrets schema fails when aws_secret_key is invalid length.""" + invalid_creds = copy.deepcopy(aws_secrets_payload) + invalid_creds["aws_secret_key"] = "string" # noqa: S105 + + expected_msg = re.compile(r"Invalid AWS secret key format.", re.IGNORECASE) + with pytest.raises(ValidationError, match=expected_msg): + AWSSecrets.model_validate(invalid_creds) + + +def test_aws_secrets_schema_missing_secret_key() -> None: + """Test that the AWS Secrets schema fails when aws_secret_key is missing.""" + invalid_creds = copy.deepcopy(aws_secrets_payload) + invalid_creds["aws_secret_key"] = "" # Empty string = missing secret key + + expected_msg = re.compile(r"No AWS secret key provided.", re.IGNORECASE) + with pytest.raises(ValidationError, match=expected_msg): + AWSSecrets.model_validate(invalid_creds) diff --git a/cli/internal/client/auth.go b/cli/internal/client/auth.go index 7b903809..843f5832 100644 --- a/cli/internal/client/auth.go +++ b/cli/internal/client/auth.go @@ -4,6 +4,8 @@ import ( "fmt" "net/http" "net/url" + "regexp" + "strings" "github.com/OpenLabsHQ/OpenLabs/cli/internal/logger" ) @@ -114,22 +116,64 @@ func (c *Client) GetUserSecrets() (*UserSecretResponse, error) { return &secrets, nil } +type AWSSecretsPayload struct { + Provider string `json:"provider"` + AccessKey string `json:"aws_access_key"` + SecretKey string `json:"aws_secret_key"` +} + +type AzureSecretsPayload struct { + Provider string `json:"provider"` + ClientID string `json:"azure_client_id"` + ClientSecret string `json:"azure_client_secret"` + TenantID string `json:"azure_tenant_id"` + SubscriptionID string `json:"azure_subscription_id"` +} + +// ParseValidationErrors extracts and cleans user-facing messages from a validation error string. +func ParseValidationErrors(errString string) string { + re := regexp.MustCompile(`msg:(.*?) type:`) + matches := re.FindAllStringSubmatch(errString, -1) + + if len(matches) == 0 { + return errString + } + + var errorMessages []string + for _, match := range matches { + if len(match) > 1 { + message := strings.TrimSpace(match[1]) + message = strings.TrimSuffix(message, ",") + + // **This is the new line to remove the prefix** + message = strings.TrimPrefix(message, "Value error, ") + + errorMessages = append(errorMessages, message) + } + } + + return strings.Join(errorMessages, "; ") +} + func (c *Client) UpdateAWSSecrets(accessKey, secretKey string) error { - secrets := AWSSecrets{ + payload := AWSSecretsPayload{ + Provider: "aws", AccessKey: accessKey, SecretKey: secretKey, } var response Message - if err := c.makeRequest("POST", "/api/v1/users/me/secrets/aws", secrets, &response); err != nil { - return fmt.Errorf("failed to update AWS secrets: %w", err) + if err := c.makeRequest("POST", "/api/v1/users/me/secrets", payload, &response); err != nil { + message := ParseValidationErrors(err.Error()) + return fmt.Errorf("failed to update AWS secrets: %s", message) } return nil } func (c *Client) UpdateAzureSecrets(clientID, clientSecret, tenantID, subscriptionID string) error { - secrets := AzureSecrets{ + payload := AzureSecretsPayload{ + Provider: "azure", ClientID: clientID, ClientSecret: clientSecret, TenantID: tenantID, @@ -137,7 +181,7 @@ func (c *Client) UpdateAzureSecrets(clientID, clientSecret, tenantID, subscripti } var response Message - if err := c.makeRequest("POST", "/api/v1/users/me/secrets/azure", secrets, &response); err != nil { + if err := c.makeRequest("POST", "/api/v1/users/me/secrets", payload, &response); err != nil { return fmt.Errorf("failed to update Azure secrets: %w", err) } diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index c4c29f4d..44f86ff9 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -14,10 +14,7 @@ import type { BlueprintRange, PasswordUpdateRequest, PasswordUpdateResponse, - AWSSecretsRequest, - AWSSecretsResponse, - AzureSecretsRequest, - AzureSecretsResponse, + SecretsResponse, DeployRangeRequest } from '$lib/types/api' @@ -162,36 +159,11 @@ export const userApi = { ) }, - // Set AWS secrets - setAwsSecrets: async (accessKey: string, secretKey: string): Promise> => { - const request: AWSSecretsRequest = { - aws_access_key: accessKey, - aws_secret_key: secretKey, - } - return await apiRequest( - '/api/v1/users/me/secrets/aws', - 'POST', - request - ) - }, - - // Set Azure secrets - setAzureSecrets: async ( - clientId: string, - clientSecret: string, - tenantId: string, - subscriptionId: string - ): Promise> => { - const request: AzureSecretsRequest = { - azure_client_id: clientId, - azure_client_secret: clientSecret, - azure_tenant_id: tenantId, - azure_subscription_id: subscriptionId, - } - return await apiRequest( - '/api/v1/users/me/secrets/azure', + updateSecrets: async (payload: any): Promise> => { + return await apiRequest( + '/api/v1/users/me/secrets', 'POST', - request + payload ) }, } diff --git a/frontend/src/lib/types/api.ts b/frontend/src/lib/types/api.ts index b8530f98..9fe4619e 100644 --- a/frontend/src/lib/types/api.ts +++ b/frontend/src/lib/types/api.ts @@ -191,25 +191,8 @@ export interface PasswordUpdateResponse { message: string; } -// AWS secrets types -export interface AWSSecretsRequest { - aws_access_key: string; - aws_secret_key: string; -} - -export interface AWSSecretsResponse { - message: string; -} - -// Azure secrets types -export interface AzureSecretsRequest { - azure_client_id: string; - azure_client_secret: string; - azure_tenant_id: string; - azure_subscription_id: string; -} - -export interface AzureSecretsResponse { +// Secrets verification response +export interface SecretsResponse { message: string; } diff --git a/frontend/src/lib/utils/error.ts b/frontend/src/lib/utils/error.ts index 12389dd1..c09bfd1b 100644 --- a/frontend/src/lib/utils/error.ts +++ b/frontend/src/lib/utils/error.ts @@ -49,6 +49,18 @@ export function formatErrorMessage(error: unknown, fallbackMessage: string = 'An return fallbackMessage } +// Used specifically for pydantic errors thrown back to the user at the endpoint before the endpoint function is executed +export function extractPydanticErrors(error: any): string { + let errorObj = error; + + // 2. Map over the array to get just the 'msg' string from each object. + const messages = errorObj.map((err: any) => err.msg.replace(/^Value error,\s*/i, '')) + + // 3. Join the array of messages into a single string, separated by newlines. + return messages.join('\n'); +} + + /** * Creates a safe error handler function that always returns a string * @param fallbackMessage - Default message for unhandled errors diff --git a/frontend/src/routes/settings/+page.svelte b/frontend/src/routes/settings/+page.svelte index 9880f002..0dbed19a 100644 --- a/frontend/src/routes/settings/+page.svelte +++ b/frontend/src/routes/settings/+page.svelte @@ -6,6 +6,10 @@ import LoadingSpinner from '$lib/components/LoadingSpinner.svelte' import { fade } from 'svelte/transition' import logger from '$lib/utils/logger' + import { extractPydanticErrors } from '$lib/utils/error' + + // Active tab state + let activeTab: 'aws' | 'azure' = 'aws' // Password form let currentPassword = '' @@ -41,18 +45,18 @@ let secretsStatus = { aws: { configured: false, - createdAt: null, + createdAt: null as string | null, }, azure: { configured: false, - createdAt: null, + createdAt: null as string | null, }, } let loadingSecrets = true let loadingUserData = true // Format date for tooltip display - function formatDateForTooltip(dateString) { + function formatDateForTooltip(dateString: string | null) { if (!dateString) return 'Date unavailable' try { return `Configured on ${new Date(dateString).toLocaleString()}` @@ -62,54 +66,49 @@ } } - // Custom tooltip management - let showAwsTooltip = false - let showAzureTooltip = false - - // Position tracking for tooltips - let awsTooltipPosition = { x: 0, y: 0 } - let azureTooltipPosition = { x: 0, y: 0 } - - function handleMouseEnter(event, tooltipType) { - // Calculate position for tooltip - const rect = event.target.getBoundingClientRect() - const position = { - x: rect.left + window.scrollX + rect.width / 2, // Center horizontally - y: rect.top + window.scrollY - 40, // Position higher above the element + // Helper for tab styling + function getTabClass(tabName: 'aws' | 'azure') { + const baseClasses = 'whitespace-nowrap border-b-2 px-1 py-4 text-sm font-medium' + const inactiveClasses = 'border-transparent text-gray-400 hover:border-gray-500 hover:text-gray-300' + + if (activeTab === tabName) { + const activeClasses = tabName === 'aws' + ? 'border-yellow-500 text-yellow-500' + : 'border-blue-500 text-blue-500' + return `${baseClasses} ${activeClasses}` } + + return `${baseClasses} ${inactiveClasses}` + } - // Set position and show appropriate tooltip - if (tooltipType === 'aws') { - awsTooltipPosition = position - showAwsTooltip = true - } else if (tooltipType === 'azure') { - azureTooltipPosition = position - showAzureTooltip = true + // Custom tooltip management + let showTooltip = false + let tooltipPosition = { x: 0, y: 0 } + + function handleMouseEnter(event: MouseEvent) { + const rect = (event.target as HTMLElement).getBoundingClientRect() + tooltipPosition = { + x: rect.left + window.scrollX + rect.width / 2, + y: rect.top + window.scrollY - 10, } + showTooltip = true } - function handleMouseLeave(tooltipType) { - if (tooltipType === 'aws') { - showAwsTooltip = false - } else if (tooltipType === 'azure') { - showAzureTooltip = false - } + function handleMouseLeave() { + showTooltip = false } // Load user data and secrets status onMount(async () => { + loadingUserData = true try { - // Load user data first const { authApi } = await import('$lib/api') const userResponse = await authApi.getCurrentUser() - if (userResponse.data?.user) { userData = { name: userResponse.data.user.name || '', email: userResponse.data.user.email || '', } - - // Update auth store auth.updateUser(userResponse.data.user) } } catch (error) { @@ -118,22 +117,18 @@ loadingUserData = false } + loadingSecrets = true try { - // Then load secrets status const result = await userApi.getUserSecrets() - if (result.data) { - const awsDate = result.data.aws?.created_at - const azureDate = result.data.azure?.created_at - secretsStatus = { aws: { configured: result.data.aws?.has_credentials || false, - createdAt: awsDate, + createdAt: result.data.aws?.created_at || null, }, azure: { configured: result.data.azure?.has_credentials || false, - createdAt: azureDate, + createdAt: result.data.azure?.created_at || null, }, } } @@ -146,42 +141,27 @@ // Handle password update async function handlePasswordUpdate() { - // Reset messages passwordError = '' passwordSuccess = '' - - // Validate input - if (!currentPassword) { - passwordError = 'Current password is required' + if (!currentPassword || !newPassword) { + passwordError = 'All password fields are required' return } - - if (!newPassword) { - passwordError = 'New password is required' - return - } - if (newPassword !== confirmPassword) { passwordError = 'New passwords do not match' return } - if (newPassword.length < 8) { passwordError = 'Password must be at least 8 characters long' return } - isPasswordLoading = true - try { const result = await userApi.updatePassword(currentPassword, newPassword) - if (result.error) { passwordError = result.error return } - - // Success passwordSuccess = 'Password updated successfully' currentPassword = '' newPassword = '' @@ -194,124 +174,112 @@ } } - // Handle AWS secrets update - async function handleAwsSecretsUpdate() { - // Reset messages - awsError = '' - awsSuccess = '' - - // Validate input - if (!awsAccessKey) { - awsError = 'AWS Access Key is required' - return - } - - if (!awsSecretKey) { - awsError = 'AWS Secret Key is required' - return + // Shared function to call the single endpoint + async function updateSecrets(provider: 'aws' | 'azure', payload: any) { + // Set loading state based on provider + if (provider === 'aws') { + isAwsLoading = true + } else { + isAzureLoading = true } - isAwsLoading = true - try { - const result = await userApi.setAwsSecrets(awsAccessKey, awsSecretKey) + // Call single, common endpoint + const result = await userApi.updateSecrets(payload) + let errorMsg = "" if (result.error) { - awsError = result.error + if (typeof result.error === 'object'){ // Pydantic validation error + const formattedError = extractPydanticErrors(result.error) + errorMsg = formattedError + errorMsg = `${errorMsg}\nPlease ensure you are providing proper AWS credentials.` + } + else errorMsg = result.error + if (provider === 'aws') awsError = errorMsg + else azureError = errorMsg return } - // Success - awsSuccess = 'AWS credentials updated successfully' - secretsStatus.aws.configured = true // Update local status - secretsStatus.aws.createdAt = new Date().toISOString() - awsAccessKey = '' - awsSecretKey = '' + // Handle success based on provider + if (provider === 'aws') { + awsSuccess = 'AWS credentials updated successfully' + secretsStatus.aws = { configured: true, createdAt: new Date().toISOString() } + awsAccessKey = '' + awsSecretKey = '' + } else { + azureSuccess = 'Azure credentials updated successfully' + secretsStatus.azure = { configured: true, createdAt: new Date().toISOString() } + azureClientId = '' + azureClientSecret = '' + azureTenantId = '' + azureSubscriptionId = '' + } } catch (error) { - awsError = - error instanceof Error - ? error.message - : 'Failed to update AWS credentials' + const errorMessage = error instanceof Error ? error.message : 'Failed to update credentials' + if (provider === 'aws') awsError = errorMessage + else azureError = errorMessage } finally { - isAwsLoading = false + // Reset loading state based on provider + if (provider === 'aws') isAwsLoading = false + else isAzureLoading = false } } - // Handle Azure secrets update - async function handleAzureSecretsUpdate() { - // Reset messages - azureError = '' - azureSuccess = '' - - // Validate input - if (!azureClientId) { - azureError = 'Azure Client ID is required' + // Handle AWS secrets update + async function handleAwsSecretsUpdate() { + awsError = '' + awsSuccess = '' + if (!awsAccessKey || !awsSecretKey) { + awsError = 'Both AWS Access Key and Secret Key are required' return } - - if (!azureClientSecret) { - azureError = 'Azure Client Secret is required' - return + + // Construct the AWS-specific payload + const payload = { + provider: "aws", + aws_access_key: awsAccessKey, + aws_secret_key: awsSecretKey, } - if (!azureTenantId) { - azureError = 'Azure Tenant ID is required' - return - } + // Call the shared update function + await updateSecrets('aws', payload) + } - if (!azureSubscriptionId) { - azureError = 'Azure Subscription ID is required' + // Handle Azure secrets update + async function handleAzureSecretsUpdate() { + azureError = '' + azureSuccess = '' + if ( + !azureClientId || + !azureClientSecret || + !azureTenantId || + !azureSubscriptionId + ) { + azureError = 'All Azure fields are required' return } - - isAzureLoading = true - - try { - const result = await userApi.setAzureSecrets( - azureClientId, - azureClientSecret, - azureTenantId, - azureSubscriptionId - ) - - if (result.error) { - azureError = result.error - return - } - - // Success - azureSuccess = 'Azure credentials updated successfully' - secretsStatus.azure.configured = true // Update local status - secretsStatus.azure.createdAt = new Date().toISOString() - azureClientId = '' - azureClientSecret = '' - azureTenantId = '' - azureSubscriptionId = '' - } catch (error) { - azureError = - error instanceof Error - ? error.message - : 'Failed to update Azure credentials' - } finally { - isAzureLoading = false + + // Construct the Azure-specific payload + const payload = { + provider: "azure", + azure_client_id: azureClientId, + azure_client_secret: azureClientSecret, + azure_tenant_id: azureTenantId, + azure_subscription_id: azureSubscriptionId, } + + // Call the shared update function + await updateSecrets('azure', payload) } -
+
{:else} -
+
- {userData.name?.[0] || 'U'} + {userData.name?.[0]?.toUpperCase() || 'U'}

{userData.name || 'User'}

@@ -361,7 +329,6 @@

Change Password

-
-
-
- {#if passwordError} -
- {passwordError} -
+
{passwordError}
{/if} - {#if passwordSuccess} -
- {passwordSuccess} -
+
{passwordSuccess}
{/if} -
- -
+ +
-

Cloud Provider Credentials

- - - -
-
- -
-
- - - - End-to-End Encrypted +
+ +
+
+ + End-to-End Encrypted +
+

Your credentials are encrypted before being stored and are only decrypted when needed for a range. We cannot access your cloud provider credentials.

+
-

- Your credentials are encrypted before entering the database and - are only decrypted when needed for a range. Even the person - hosting OpenLabs cannot access your cloud provider credentials. -

-
-
{#if loadingSecrets} @@ -518,255 +422,118 @@
{:else} -
- -
-
-

AWS Credentials

- handleMouseEnter(e, 'aws')} - on:mouseleave={() => handleMouseLeave('aws')} +
+ +
+ +
- {#if showAwsTooltip && secretsStatus.aws.configured} -
+
+ {#if activeTab === 'aws'} +
+

AWS Credentials

+ - {formatDateForTooltip(secretsStatus.aws.createdAt)} -
-
- {/if} -
- -
-
+ {secretsStatus.aws.configured ? 'Configured' : 'Not Configured'} + +
+
- - + +
-
- - + +
- - {#if awsError} -
- {awsError} -
- {/if} - - {#if awsSuccess} -
- {awsSuccess} -
- {/if} -
- -
-
{/if} + {#if awsSuccess}
{awsSuccess}
{/if} +
+ +
+ + {:else if activeTab === 'azure'} +
+

Azure Credentials

+ - {#if isAwsLoading} - - - - Updating... - {:else} - {secretsStatus.aws.configured - ? 'Update AWS Credentials' - : 'Set AWS Credentials'} - {/if} - + {secretsStatus.azure.configured ? 'Configured' : 'Not Configured'} +
- -
- - -
-
-

Azure Credentials

- handleMouseEnter(e, 'azure')} - on:mouseleave={() => handleMouseLeave('azure')} - > - {secretsStatus.azure.configured - ? 'Configured' - : 'Not Configured'} - - - {#if showAzureTooltip && secretsStatus.azure.configured} -
- {formatDateForTooltip(secretsStatus.azure.createdAt)} -
+
+
+ +
- {/if} -
- - -
- - + +
-
- - + +
-
- - + +
- -
- - + {#if azureError}
{azureError}
{/if} + {#if azureSuccess}
{azureSuccess}
{/if} +
+
- - {#if azureError} -
- {azureError} -
- {/if} - - {#if azureSuccess} -
- {azureSuccess} -
- {/if} -
- -
- -
- + + {/if}
{/if}
+ + + {#if showTooltip} + {@const status = activeTab === 'aws' ? secretsStatus.aws : secretsStatus.azure} + {#if status.configured} +
+ {formatDateForTooltip(status.createdAt)} +
+
+ {/if} + {/if}