diff --git a/werk24/cli/commands/health_check.py b/werk24/cli/commands/health_check.py index 0fa8b81..3b5837a 100644 --- a/werk24/cli/commands/health_check.py +++ b/werk24/cli/commands/health_check.py @@ -3,6 +3,7 @@ import sys import typer +from typing import Optional from packaging.version import Version from rich.console import Console from rich.panel import Panel @@ -22,13 +23,17 @@ @app.command() -def health_check(): +def health_check( + custom_cafile: str = typer.Option( + None, help="Path to an additional CA bundle for TLS verification" + ) +): """Run a comprehensive health check for the CLI.""" console.print(Panel(f"[blue]Werk24 CLI Health Check v{__version__}[/blue]")) system_information() license_information() - asyncio.run(network_information()) - asyncio.run(status_information()) + asyncio.run(network_information(custom_cafile)) + asyncio.run(status_information(custom_cafile)) def system_information(): @@ -73,7 +78,7 @@ def license_information(): print_panel("License Information", license_info) -async def network_information(): +async def network_information(custom_cafile: Optional[str]): """ Display network information and test WebSocket connections. """ @@ -81,7 +86,7 @@ async def network_information(): network_info = [] try: - async with Werk24Client() as _: + async with Werk24Client(custom_cafile=custom_cafile) as _: network_info.append( (f"WebSocket Connection ({server_uri})", "[green]Successful[/green]") ) @@ -129,11 +134,14 @@ async def network_information(): print_panel("Network Information", network_info) -async def status_information(): +async def status_information(custom_cafile: Optional[str]): """Display the system status information.""" status_info = [] try: - status = await Werk24Client.get_system_status() + if custom_cafile: + status = await Werk24Client.get_system_status(custom_cafile=custom_cafile) + else: + status = await Werk24Client.get_system_status() status_info.append(("Indicator", status.status_indicator)) if status.status_description: status_info.append(("Description", status.status_description)) diff --git a/werk24/cli/commands/status.py b/werk24/cli/commands/status.py index 2dc9314..0aeb096 100644 --- a/werk24/cli/commands/status.py +++ b/werk24/cli/commands/status.py @@ -10,8 +10,17 @@ @app.command() -def status(): +def status( + custom_cafile: str = typer.Option( + None, help="Path to an additional CA bundle for TLS verification" + ) +): """Fetch and display the Werk24 system status.""" - system_status = asyncio.run(Werk24Client.get_system_status()) + if custom_cafile: + system_status = asyncio.run( + Werk24Client.get_system_status(custom_cafile=custom_cafile) + ) + else: + system_status = asyncio.run(Werk24Client.get_system_status()) console.print_json(data=system_status.model_dump(mode="json")) diff --git a/werk24/cli/commands/techread.py b/werk24/cli/commands/techread.py index 5346393..59eb34e 100644 --- a/werk24/cli/commands/techread.py +++ b/werk24/cli/commands/techread.py @@ -45,6 +45,9 @@ def techread( ), ask_sheet_images: bool = typer.Option(False, help="Ask for sheet images"), ask_view_images: bool = typer.Option(False, help="Ask for view image"), + custom_cafile: str = typer.Option( + None, help="Path to an additional CA bundle for TLS verification" + ), ): """Read a drawing file and extract information.""" @@ -76,11 +79,13 @@ def techread( raise UserInputError("No hooks selected. At least one hook must be enabled.") with open(file_path, "rb") as fid: - asyncio.run(run(server, fid, hooks, max_pages)) + asyncio.run(run(server, fid, hooks, max_pages, custom_cafile)) -async def run(server: str, fh: str, hooks: list[Hook], max_pages: int): - async with Werk24Client(server) as client: +async def run( + server: str, fh: str, hooks: list[Hook], max_pages: int, custom_cafile: Optional[str] +): + async with Werk24Client(server, custom_cafile=custom_cafile) as client: await client.read_drawing_with_hooks(fh, hooks, max_pages) diff --git a/werk24/techread.py b/werk24/techread.py index 82d7f1a..81d3ee5 100644 --- a/werk24/techread.py +++ b/werk24/techread.py @@ -88,15 +88,20 @@ def __init__( https_server=settings.http_server, token: Optional[str] = None, region: Optional[str] = None, + custom_cafile: Optional[str] = None, ): self.license = find_license(token, region) self._wss_server = str(wss_server) self._https_server = str(https_server) self._wss_session = None + self._custom_cafile = custom_cafile # Reuse a single SSL context configured with the certifi CA bundle # to avoid recreating it for each connection and to ensure that the - # certificate chain is properly verified. + # certificate chain is properly verified. Load an additional custom + # CA file if provided. self._ssl_context = ssl.create_default_context(cafile=certifi.where()) + if self._custom_cafile: + self._ssl_context.load_verify_locations(cafile=self._custom_cafile) def _get_auth_headers(self): """ @@ -529,7 +534,7 @@ async def _upload_associated_file( try: logger.debug("Uploading file to the server: %s", str(presigned_post.url)) - async with self._make_https_session() as session: + async with self._make_https_session(cafile=self._custom_cafile) as session: response = await session.post(str(presigned_post.url), data=form) self._raise_for_status(str(presigned_post.url), response.status) logger.info("File uploaded successfully.") @@ -754,7 +759,7 @@ async def read_drawing_with_callback( # send the request headers = self._get_auth_headers() url = self._make_https_url("/techread/read-with-callback") - async with self._make_https_session() as session: + async with self._make_https_session(cafile=self._custom_cafile) as session: response = await session.post(url, data=data, headers=headers) self._raise_for_status(url, response.status) response_json = await response.json(content_type=None) @@ -765,11 +770,13 @@ async def read_drawing_with_callback( raise BadRequestException(f"Request failed: {response_json}") from e @staticmethod - async def get_system_status() -> SystemStatus: + async def get_system_status(custom_cafile: str | None = None) -> SystemStatus: """Fetch the current system status from the API.""" url = urljoin(str(settings.http_server), "/status") ssl_context = ssl.create_default_context(cafile=certifi.where()) + if custom_cafile: + ssl_context.load_verify_locations(cafile=custom_cafile) connector = aiohttp.TCPConnector(ssl=ssl_context) timeout = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=30) async with aiohttp.ClientSession( @@ -814,9 +821,12 @@ def _make_https_session( - aiohttp.ClientSession: A configured HTTP client session. """ try: - # Use the provided CA file or the default certifi CA bundle - cafile = cafile or certifi.where() - ssl_context = ssl.create_default_context(cafile=cafile) + # Always start with the certifi CA bundle and load any custom + # certificate file on top so the default trust store remains + # available. + ssl_context = ssl.create_default_context(cafile=certifi.where()) + if cafile: + ssl_context.load_verify_locations(cafile=cafile) connector = aiohttp.TCPConnector(ssl=ssl_context) # Configure timeouts @@ -1006,7 +1016,7 @@ async def download_payload( # Attempt to download the payload try: - async with self._make_https_session() as session: + async with self._make_https_session(cafile=self._custom_cafile) as session: logger.debug("Sending GET request to %s", payload_url) response = await session.get(str(payload_url))