From d532dc66b6a358dc3526a886a8cbff514b424488 Mon Sep 17 00:00:00 2001 From: Chris Doman Date: Fri, 31 Jan 2025 17:50:12 +0000 Subject: [PATCH 1/6] Reuse threatexecutor --- cloudgrep/cloud.py | 37 +++++++++++++------------------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/cloudgrep/cloud.py b/cloudgrep/cloud.py index b8f63ae..da7caa2 100644 --- a/cloudgrep/cloud.py +++ b/cloudgrep/cloud.py @@ -17,6 +17,16 @@ class Cloud: def __init__(self) -> None: self.search = Search() + def _download_and_search_in_parallel(self, files: List[str], worker_func) -> int: + """ Use ThreadPoolExecutorto download every file + Returns number of matched files """ + total_matched = 0 + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(worker_func, key) for key in files] + for fut in concurrent.futures.as_completed(futures): + total_matched += fut.result() + return total_matched + def download_from_s3_multithread( self, bucket: str, @@ -50,18 +60,11 @@ def _download_search_s3(key: str) -> int: return 0 finally: try: - # Cleanup os.remove(tmp_name) except OSError: pass - total_matched = 0 - # Use ThreadPoolExecutor to download and search files - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(_download_search_s3, k) for k in files] - for fut in concurrent.futures.as_completed(futures): - total_matched += fut.result() - return total_matched + return self._download_and_search_in_parallel(files, _download_search_s3) def download_from_azure( self, @@ -117,18 +120,11 @@ def _download_search_azure(key: str) -> int: return 0 finally: try: - import os - os.remove(tmp_name) except OSError: pass - total_matched = 0 - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(_download_search_azure, k) for k in files] - for fut in concurrent.futures.as_completed(futures): - total_matched += fut.result() - return total_matched + return self._download_and_search_in_parallel(files, _download_search_azure) def download_from_google( self, @@ -169,18 +165,11 @@ def _download_and_search_google(key: str) -> int: return 0 finally: try: - import os - os.remove(tmp_name) except OSError: pass - total_matched = 0 - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(_download_and_search_google, k) for k in files] - for fut in concurrent.futures.as_completed(futures): - total_matched += fut.result() - return total_matched + return self._download_and_search_in_parallel(files, _download_and_search_google) def filter_object( self, From 5338b99ee1747bd1c7d8891ae5ad81cd6cde828e Mon Sep 17 00:00:00 2001 From: Chris Doman Date: Fri, 31 Jan 2025 17:56:16 +0000 Subject: [PATCH 2/6] close zip files --- cloudgrep/search.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/cloudgrep/search.py b/cloudgrep/search.py index 4652700..6516adb 100644 --- a/cloudgrep/search.py +++ b/cloudgrep/search.py @@ -136,15 +136,23 @@ def process_lines(lines: Any) -> bool: elif key_name.endswith(".zip"): with tempfile.TemporaryDirectory() as tempdir, zipfile.ZipFile(file_name, "r") as zf: zf.extractall(tempdir) - return any( - # Process the extracted files - process_lines( - json.load(open(os.path.join(tempdir, filename))) - if account_name - else open(os.path.join(tempdir, filename)) - ) - # Search all files in the zip file - for filename in os.listdir(tempdir) - if os.path.isfile(os.path.join(tempdir, filename)) - ) + + matched_any = False + for extracted_filename in os.listdir(tempdir): + extracted_path = os.path.join(tempdir, extracted_filename) + if not os.path.isfile(extracted_path): + continue + + if account_name: + with open(extracted_path, "r", encoding="utf-8") as f: + data = json.load(f) + if process_lines(data): + matched_any = True + else: + with open(extracted_path, "r", encoding="utf-8") as f: + if process_lines(f): + matched_any = True + + return matched_any + return process_lines(self.get_all_strings_line(file_name)) From 023edeb6e7ff080ee2b862f18c28f4185b771fb6 Mon Sep 17 00:00:00 2001 From: Chris Doman Date: Sat, 1 Feb 2025 22:15:57 +0000 Subject: [PATCH 3/6] Split out filters --- cloudgrep/__main__.py | 158 +++++++++----------------- cloudgrep/cloud.py | 246 ++++++++++++++++++++--------------------- cloudgrep/cloudgrep.py | 132 ++++++++++------------ cloudgrep/search.py | 165 +++++++++++++++------------ 4 files changed, 326 insertions(+), 375 deletions(-) diff --git a/cloudgrep/__main__.py b/cloudgrep/__main__.py index 1a5a825..e8b7ba2 100644 --- a/cloudgrep/__main__.py +++ b/cloudgrep/__main__.py @@ -1,137 +1,87 @@ import argparse import logging import sys -from typing import List +from typing import List, Optional +import dateutil.parser +import datetime from cloudgrep.cloudgrep import CloudGrep VERSION = "1.0.5" -# Define a custom argument type for a list of strings def list_of_strings(arg: str) -> List[str]: - return arg.split(",") + """Parse a comma‐separated string into a list of nonempty strings.""" + return [s.strip() for s in arg.split(",") if s.strip()] def main() -> None: parser = argparse.ArgumentParser( - description=f"CloudGrep searches is grep for cloud storage like S3 and Azure Storage. Version: {VERSION}" - ) - parser.add_argument("-b", "--bucket", help="AWS S3 Bucket to search. E.g. my-bucket", required=False) - parser.add_argument("-an", "--account-name", help="Azure Account Name to Search", required=False) - parser.add_argument("-cn", "--container-name", help="Azure Container Name to Search", required=False) - parser.add_argument("-gb", "--google-bucket", help="Google Cloud Bucket to Search", required=False) - parser.add_argument( - "-q", - "--query", - type=list_of_strings, - help="Text to search for. Will be parsed as a Regex. E.g. example.com", - required=False, - ) - parser.add_argument( - "-v", - "--file", - help="File containing a list of words or regular expressions to search for. One per line.", - required=False, - ) - parser.add_argument( - "-y", - "--yara", - help="File containing Yara rules to scan files.", - required=False, - ) - parser.add_argument( - "-p", - "--prefix", - help="Optionally filter on the start of the Object name. E.g. logs/", - required=False, - default="", - ) - parser.add_argument( - "-f", "--filename", help="Optionally filter on Objects that match a keyword. E.g. .log.gz ", required=False - ) - parser.add_argument( - "-s", - "--start_date", - help="Optionally filter on Objects modified after a Date or Time. E.g. 2022-01-01 ", - required=False, - ) - parser.add_argument( - "-e", - "--end_date", - help="Optionally filter on Objects modified before a Date or Time. E.g. 2022-01-01 ", - required=False, + description=f"CloudGrep: grep for cloud storage (S3, Azure, Google Cloud). Version: {VERSION}" ) + parser.add_argument("-b", "--bucket", help="AWS S3 Bucket to search (e.g. my-bucket)") + parser.add_argument("-an", "--account-name", help="Azure Account Name to search") + parser.add_argument("-cn", "--container-name", help="Azure Container Name to search") + parser.add_argument("-gb", "--google-bucket", help="Google Cloud Bucket to search") + parser.add_argument("-q", "--query", type=list_of_strings, help="Comma-separated list of regex patterns to search") + parser.add_argument("-v", "--file", help="File containing queries (one per line)") + parser.add_argument("-y", "--yara", help="File containing Yara rules") + parser.add_argument("-p", "--prefix", default="", help="Filter objects by prefix (e.g. logs/)") + parser.add_argument("-f", "--filename", help="Filter objects whose names contain a keyword (e.g. .log.gz)") + parser.add_argument("-s", "--start_date", help="Filter objects modified after this date (YYYY-MM-DD)") + parser.add_argument("-e", "--end_date", help="Filter objects modified before this date (YYYY-MM-DD)") parser.add_argument( "-fs", "--file_size", - help="Optionally filter on Objects smaller than a file size, in bytes. Defaults to 100 Mb. ", - default=100000000, - required=False, - ) - parser.add_argument( - "-pr", - "--profile", - help="Set an AWS profile to use. E.g. default, dev, prod.", - required=False, - ) - parser.add_argument("-d", "--debug", help="Enable Debug logging. ", action="store_true", required=False) - parser.add_argument( - "-hf", "--hide_filenames", help="Dont show matching filenames. ", action="store_true", required=False + type=int, + default=100_000_000, + help="Max file size in bytes (default: 100MB)", ) + parser.add_argument("-pr", "--profile", help="AWS profile to use (e.g. default, dev, prod)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging") + parser.add_argument("-hf", "--hide_filenames", action="store_true", help="Hide filenames in output") + parser.add_argument("-lt", "--log_type", help="Pre-defined log type (e.g. cloudtrail, azure)") + parser.add_argument("-lf", "--log_format", help="Custom log format (e.g. json, csv)") parser.add_argument( - "-lt", - "--log_type", - help="Return individual matching log entries based on pre-defined log types, otherwise custom log_format and log_properties can be used. E.g. cloudtrail. ", - required=False, + "-lp", "--log_properties", type=list_of_strings, help="Comma-separated list of log properties to extract" ) - parser.add_argument( - "-lf", - "--log_format", - help="Define custom log format of raw file to parse before applying search logic. Used if --log_type is not defined. E.g. json. ", - required=False, - ) - parser.add_argument( - "-lp", - "--log_properties", - type=list_of_strings, - help="Define custom list of properties to traverse to dynamically extract final list of log records. Used if --log_type is not defined. E.g. [" - "Records" - "]. ", - required=False, - ) - parser.add_argument("-jo", "--json_output", help="Output as JSON.", action="store_true") - args = vars(parser.parse_args()) + parser.add_argument("-jo", "--json_output", action="store_true", help="Output results in JSON format") + args = parser.parse_args() if len(sys.argv) == 1: parser.print_help(sys.stderr) sys.exit(1) - if args["debug"]: - logging.basicConfig(format="[%(asctime)s]:[%(levelname)s] - %(message)s", level=logging.INFO) + # Parse dates (if provided) into datetime objects + start_date: Optional["datetime.datetime"] = dateutil.parser.parse(args.start_date) if args.start_date else None + end_date: Optional["datetime.datetime"] = dateutil.parser.parse(args.end_date) if args.end_date else None + + # Configure logging + if args.debug: + logging.basicConfig(format="[%(asctime)s] [%(levelname)s] %(message)s", level=logging.DEBUG) else: - logging.basicConfig(format="[%(asctime)s] - %(message)s", level=logging.WARNING) + logging.basicConfig(format="[%(asctime)s] %(message)s", level=logging.WARNING) logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) CloudGrep().search( - args["bucket"], - args["account_name"], - args["container_name"], - args["google_bucket"], - args["query"], - args["file"], - args["yara"], - int(args["file_size"]), - args["prefix"], - args["filename"], - args["start_date"], - args["end_date"], - args["hide_filenames"], - args["log_type"], - args["log_format"], - args["log_properties"], - args["profile"], - args["json_output"], + bucket=args.bucket, + account_name=args.account_name, + container_name=args.container_name, + google_bucket=args.google_bucket, + query=args.query, + file=args.file, + yara_file=args.yara, + file_size=args.file_size, + prefix=args.prefix, + key_contains=args.filename, + from_date=start_date, + end_date=end_date, + hide_filenames=args.hide_filenames, + log_type=args.log_type, + log_format=args.log_format, + log_properties=args.log_properties, + profile=args.profile, + json_output=args.json_output, ) diff --git a/cloudgrep/cloud.py b/cloudgrep/cloud.py index da7caa2..1fe3393 100644 --- a/cloudgrep/cloud.py +++ b/cloudgrep/cloud.py @@ -1,6 +1,6 @@ import boto3 import os -from azure.storage.blob import BlobServiceClient, BlobProperties +from azure.storage.blob import BlobServiceClient from azure.identity import DefaultAzureCredential from azure.core.exceptions import ResourceNotFoundError from google.cloud import storage # type: ignore @@ -8,25 +8,33 @@ import botocore import concurrent.futures import tempfile -from typing import Iterator, Optional, List, Any +from typing import Iterator, Optional, List, Any, Tuple import logging from cloudgrep.search import Search - class Cloud: def __init__(self) -> None: self.search = Search() - def _download_and_search_in_parallel(self, files: List[str], worker_func) -> int: - """ Use ThreadPoolExecutorto download every file + def _download_and_search_in_parallel(self, files: List[Any], worker_func) -> int: + """ Use ThreadPoolExecutor to download every file Returns number of matched files """ total_matched = 0 with concurrent.futures.ThreadPoolExecutor() as executor: futures = [executor.submit(worker_func, key) for key in files] for fut in concurrent.futures.as_completed(futures): - total_matched += fut.result() + try: + total_matched += fut.result() + except Exception: + logging.exception("Error in worker thread") return total_matched + def _download_to_temp(self) -> str: + """ Return a temporary filename """ + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp.close() + return tmp.name + def download_from_s3_multithread( self, bucket: str, @@ -38,16 +46,13 @@ def download_from_s3_multithread( log_properties: List[str] = [], json_output: Optional[bool] = False, ) -> int: - """Use ThreadPoolExecutor and boto3 to download every file in the bucket from s3 - Returns number of matched files""" - if not log_properties: + """ Download and search files from AWS S3""" + if log_properties is None: log_properties = [] - s3 = boto3.client("s3", config=botocore.config.Config(max_pool_connections=64)) def _download_search_s3(key: str) -> int: - with tempfile.NamedTemporaryFile(delete=False) as tmp: - tmp_name = tmp.name + tmp_name = self._download_to_temp() try: logging.info(f"Downloading s3://{bucket}/{key} to {tmp_name}") s3.download_file(bucket, key, tmp_name) @@ -55,8 +60,8 @@ def _download_search_s3(key: str) -> int: tmp_name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output ) return 1 if matched else 0 - except Exception as e: - logging.error(f"Error downloading or searching {key}: {e}") + except Exception: + logging.exception(f"Error processing {key}") return 0 finally: try: @@ -78,28 +83,22 @@ def download_from_azure( log_properties: Optional[List[str]] = None, json_output: bool = False, ) -> int: - """Download every file in the container from azure - Returns number of matched files""" - if not log_properties: + """Download and search files from Azure Storage """ + if log_properties is None: log_properties = [] - default_credential = DefaultAzureCredential() - blob_service_client = BlobServiceClient.from_connection_string( - f"DefaultEndpointsProtocol=https;AccountName={account_name};EndpointSuffix=core.windows.net", - credential=default_credential, - ) + connection_str = f"DefaultEndpointsProtocol=https;AccountName={account_name};EndpointSuffix=core.windows.net" + blob_service_client = BlobServiceClient.from_connection_string(connection_str, credential=default_credential) container_client = blob_service_client.get_container_client(container_name) def _download_search_azure(key: str) -> int: - with tempfile.NamedTemporaryFile(delete=False) as tmp: - tmp_name = tmp.name + tmp_name = self._download_to_temp() try: logging.info(f"Downloading azure://{account_name}/{container_name}/{key} to {tmp_name}") blob_client = container_client.get_blob_client(key) - with open(tmp_name, "wb") as my_blob: - blob_data = blob_client.download_blob() - blob_data.readinto(my_blob) - + with open(tmp_name, "wb") as out_file: + stream = blob_client.download_blob() + stream.readinto(out_file) matched = self.search.search_file( tmp_name, key, @@ -115,8 +114,8 @@ def _download_search_azure(key: str) -> int: except ResourceNotFoundError: logging.info(f"File {key} not found in {account_name}/{container_name}") return 0 - except Exception as e: - logging.error(f"Error downloading or searching {key}: {e}") + except Exception: + logging.exception(f"Error processing {key}") return 0 finally: try: @@ -129,7 +128,7 @@ def _download_search_azure(key: str) -> int: def download_from_google( self, bucket: str, - files: List[str], + blobs: List[Tuple[str, Any]], query: List[str], hide_filenames: bool, yara_rules: Any, @@ -137,31 +136,22 @@ def download_from_google( log_properties: Optional[List[str]] = None, json_output: bool = False, ) -> int: - """Download every file in the bucket from google - Returns number of matched files""" - if not log_properties: + """Download and search files from Google Cloud Storage""" + if log_properties is None: log_properties = [] - client = storage.Client() - bucket_gcp = client.get_bucket(bucket) - - def _download_and_search_google(key: str) -> int: - with tempfile.NamedTemporaryFile(delete=False) as tmp: - tmp_name = tmp.name + def _download_and_search_google(item: Tuple[str, Any]) -> int: + key, blob = item + tmp_name = self._download_to_temp() try: logging.info(f"Downloading gs://{bucket}/{key} to {tmp_name}") - blob = bucket_gcp.get_blob(key) - if blob is None: - logging.warning(f"Blob not found: {key}") - return 0 blob.download_to_filename(tmp_name) - matched = self.search.search_file( tmp_name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output ) return 1 if matched else 0 - except Exception as e: - logging.error(f"Error downloading or searching {key}: {e}") + except Exception: + logging.exception(f"Error processing {key}") return 0 finally: try: @@ -169,64 +159,7 @@ def _download_and_search_google(key: str) -> int: except OSError: pass - return self._download_and_search_in_parallel(files, _download_and_search_google) - - def filter_object( - self, - obj: dict, - key_contains: Optional[str], - from_date: Optional[datetime], - to_date: Optional[datetime], - file_size: int, - ) -> bool: - """Filter S3 objects by date range, file size, and substring in key.""" - last_modified = obj["LastModified"] - if last_modified and from_date and from_date > last_modified: - return False # Object was modified before the from_date - if last_modified and to_date and last_modified > to_date: - return False # Object was modified after the to_date - if obj["Size"] == 0 or int(obj["Size"]) > file_size: - return False # Object is empty or too large - if key_contains and key_contains not in obj["Key"]: - return False # Object does not contain the key_contains string - return True - - def filter_object_azure( - self, - obj: BlobProperties, - key_contains: Optional[str], - from_date: Optional[datetime], - to_date: Optional[datetime], - file_size: int, - ) -> bool: - """Filter Azure blob objects similarly.""" - last_modified = obj["last_modified"] # type: ignore - if last_modified and from_date and from_date > last_modified: - return False - if last_modified and to_date and last_modified > to_date: - return False - if obj["size"] == 0 or int(obj["size"]) > file_size: - return False - if key_contains and key_contains not in obj["name"]: - return False - return True - - def filter_object_google( - self, - obj: storage.blob.Blob, - key_contains: Optional[str], - from_date: Optional[datetime], - to_date: Optional[datetime], - ) -> bool: - """Filter objects in GCP""" - last_modified = obj.updated - if last_modified and from_date and from_date > last_modified: - return False - if last_modified and to_date and last_modified > to_date: - return False - if key_contains and key_contains not in obj.name: - return False - return True + return self._download_and_search_in_parallel(blobs, _download_and_search_google) def get_objects( self, @@ -237,15 +170,13 @@ def get_objects( end_date: Optional[datetime], file_size: int, ) -> Iterator[str]: - """Get objects in S3""" + """Yield objects that match filter""" s3 = boto3.client("s3") paginator = s3.get_paginator("list_objects_v2") - page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix) - for page in page_iterator: - if "Contents" in page: - for obj in page["Contents"]: - if self.filter_object(obj, key_contains, from_date, end_date, file_size): - yield obj["Key"] + for page in paginator.paginate(Bucket=bucket, Prefix=prefix): + for obj in page.get("Contents", []): + if self.filter_object(obj, key_contains, from_date, end_date, file_size): + yield obj.get("Key") def get_azure_objects( self, @@ -257,16 +188,12 @@ def get_azure_objects( end_date: Optional[datetime], file_size: int, ) -> Iterator[str]: - """Get all objects in Azure storage container with a given prefix""" + """Yield Azure blob names that match the filter""" default_credential = DefaultAzureCredential() - blob_service_client = BlobServiceClient.from_connection_string( - f"DefaultEndpointsProtocol=https;AccountName={account_name};EndpointSuffix=core.windows.net", - credential=default_credential, - ) + connection_str = f"DefaultEndpointsProtocol=https;AccountName={account_name};EndpointSuffix=core.windows.net" + blob_service_client = BlobServiceClient.from_connection_string(connection_str, credential=default_credential) container_client = blob_service_client.get_container_client(container_name) - blobs = container_client.list_blobs(name_starts_with=prefix) - - for blob in blobs: + for blob in container_client.list_blobs(name_starts_with=prefix): if self.filter_object_azure(blob, key_contains, from_date, end_date, file_size): yield blob.name @@ -277,11 +204,80 @@ def get_google_objects( key_contains: Optional[str], from_date: Optional[datetime], end_date: Optional[datetime], - ) -> Iterator[str]: - """Get all objects in a GCP bucket with a given prefix""" + ) -> Iterator[Tuple[str, Any]]: + """Yield (blob_name, blob) for blobs in GCP that match filter""" client = storage.Client() bucket_gcp = client.get_bucket(bucket) - blobs = bucket_gcp.list_blobs(prefix=prefix) - for blob in blobs: + for blob in bucket_gcp.list_blobs(prefix=prefix): if self.filter_object_google(blob, key_contains, from_date, end_date): - yield blob.name + yield blob.name, blob + + def filter_object( + self, + obj: dict, + key_contains: Optional[str], + from_date: Optional[datetime], + to_date: Optional[datetime], + file_size: int, + ) -> bool: + """Filter an S3 object based on modification date, size, and key substring""" + last_modified = obj.get("LastModified") + if last_modified: + if from_date and last_modified < from_date: + return False + if to_date and last_modified > to_date: + return False + # If size is 0 or greater than file_size, skip + if int(obj.get("Size", 0)) == 0 or int(obj.get("Size", 0)) > file_size: + return False + if key_contains and key_contains not in obj.get("Key", ""): + return False + return True + + def filter_object_azure( + self, + obj: Any, + key_contains: Optional[str], + from_date: Optional[datetime], + to_date: Optional[datetime], + file_size: int, + ) -> bool: + """ + Filter an Azure blob object (or dict) based on modification date, size, and name substring. + """ + if isinstance(obj, dict): + last_modified = obj.get("last_modified") + size = int(obj.get("size", 0)) + name = obj.get("name", "") + else: + last_modified = getattr(obj, "last_modified", None) + size = int(getattr(obj, "size", 0)) + name = getattr(obj, "name", "") + if last_modified: + if from_date and last_modified < from_date: + return False + if to_date and last_modified > to_date: + return False + if size == 0 or size > file_size: + return False + if key_contains and key_contains not in name: + return False + return True + + def filter_object_google( + self, + obj: storage.blob.Blob, + key_contains: Optional[str], + from_date: Optional[datetime], + to_date: Optional[datetime], + ) -> bool: + """Filter a GCP blob based on update time and name substring""" + last_modified = getattr(obj, "updated", None) + if last_modified: + if from_date and last_modified < from_date: + return False + if to_date and last_modified > to_date: + return False + if key_contains and key_contains not in getattr(obj, "name", ""): + return False + return True diff --git a/cloudgrep/cloudgrep.py b/cloudgrep/cloudgrep.py index 6250a34..a04b483 100644 --- a/cloudgrep/cloudgrep.py +++ b/cloudgrep/cloudgrep.py @@ -11,9 +11,8 @@ class CloudGrep: def __init__(self) -> None: self.cloud = Cloud() - def load_queries(self, file: str) -> List[str]: - """Load in a list of queries from a file""" - with open(file, "r") as f: + def load_queries(self, file_path: str) -> List[str]: + with open(file_path, "r", encoding="utf-8") as f: return [line.strip() for line in f if line.strip()] def search( @@ -22,7 +21,7 @@ def search( account_name: Optional[str], container_name: Optional[str], google_bucket: Optional[str], - query: List[str], + query: Optional[List[str]], file: Optional[str], yara_file: Optional[str], file_size: int, @@ -37,84 +36,79 @@ def search( profile: Optional[str] = None, json_output: bool = False, ) -> None: - """Search query/queries across cloud storage""" - - # Load queries from a file if given if not query and file: - logging.debug(f"Loading queries in from {file}") + logging.debug(f"Loading queries from {file}") query = self.load_queries(file) + if not query: + logging.error("No query provided. Exiting.") + return - # Compile optional Yara rules yara_rules = None if yara_file: - logging.debug(f"Loading yara rules from {yara_file}") + logging.debug(f"Compiling yara rules from {yara_file}") yara_rules = yara.compile(filepath=yara_file) if profile: - # Set the AWS credentials profile to use boto3.setup_default_session(profile_name=profile) - if log_type is not None: - if log_type == "cloudtrail": + if log_type: + if log_type.lower() == "cloudtrail": log_format = "json" log_properties = ["Records"] - elif log_type == "azure": + elif log_type.lower() == "azure": log_format = "json" log_properties = ["data"] else: - logging.error(f"Invalid log_type: '{log_type}'") + logging.error(f"Invalid log_type: {log_type}") return if log_properties is None: - log_properties = [] # default + log_properties = [] - # Search given cloud storage if bucket: self._search_s3( - bucket=bucket, - query=query, - yara_rules=yara_rules, - file_size=file_size, - prefix=prefix, - key_contains=key_contains, - from_date=from_date, - end_date=end_date, - hide_filenames=hide_filenames, - log_format=log_format, - log_properties=log_properties, - json_output=json_output, + bucket, + query, + yara_rules, + file_size, + prefix, + key_contains, + from_date, + end_date, + hide_filenames, + log_format, + log_properties, + json_output, ) - if account_name and container_name: self._search_azure( - account_name=account_name, - container_name=container_name, - query=query, - yara_rules=yara_rules, - file_size=file_size, - prefix=prefix, - key_contains=key_contains, - from_date=from_date, - end_date=end_date, - hide_filenames=hide_filenames, - log_format=log_format, - log_properties=log_properties, - json_output=json_output, + account_name, + container_name, + query, + yara_rules, + file_size, + prefix, + key_contains, + from_date, + end_date, + hide_filenames, + log_format, + log_properties, + json_output, ) - if google_bucket: self._search_gcs( - google_bucket=google_bucket, - query=query, - yara_rules=yara_rules, - file_size=file_size, - prefix=prefix, - key_contains=key_contains, - from_date=from_date, - end_date=end_date, - hide_filenames=hide_filenames, - log_format=log_format, - log_properties=log_properties, - json_output=json_output, + google_bucket, + query, + yara_rules, + file_size, + prefix, + key_contains, + from_date, + end_date, + hide_filenames, + log_format, + log_properties, + json_output, ) def _search_s3( @@ -132,14 +126,11 @@ def _search_s3( log_properties: List[str], json_output: bool, ) -> None: - """Search S3 bucket for query""" + """ Search S3 bucket for query """ matching_keys = list(self.cloud.get_objects(bucket, prefix, key_contains, from_date, end_date, file_size)) s3_client = boto3.client("s3") - region = s3_client.get_bucket_location(Bucket=bucket) - logging.warning( - f"Bucket is in region: {region.get('LocationConstraint', 'unknown')} : " - "Search from the same region to avoid egress charges." - ) + region = s3_client.get_bucket_location(Bucket=bucket).get("LocationConstraint", "unknown") + logging.warning(f"Bucket region: {region}. (Search from the same region to avoid egress charges.)") logging.warning(f"Searching {len(matching_keys)} files in {bucket} for {query}...") self.cloud.download_from_s3_multithread( bucket, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output @@ -161,13 +152,12 @@ def _search_azure( log_properties: List[str], json_output: bool, ) -> None: - """Search Azure container for query""" matching_keys = list( self.cloud.get_azure_objects( account_name, container_name, prefix, key_contains, from_date, end_date, file_size ) ) - print(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...") + logging.info(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...") self.cloud.download_from_azure( account_name, container_name, @@ -195,15 +185,9 @@ def _search_gcs( log_properties: List[str], json_output: bool, ) -> None: - matching_keys = list(self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date)) - print(f"Searching {len(matching_keys)} files in {google_bucket} for {query}...") + # Get (blob_name, blob) + matching_blobs = list(self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date)) + logging.info(f"Searching {len(matching_blobs)} files in {google_bucket} for {query}...") self.cloud.download_from_google( - google_bucket, - matching_keys, - query, - hide_filenames, - yara_rules, - log_format, - log_properties, - json_output, + google_bucket, matching_blobs, query, hide_filenames, yara_rules, log_format, log_properties, json_output ) diff --git a/cloudgrep/search.py b/cloudgrep/search.py index 6516adb..b7a26c6 100644 --- a/cloudgrep/search.py +++ b/cloudgrep/search.py @@ -11,47 +11,51 @@ class Search: def get_all_strings_line(self, file_path: str) -> List[str]: - """Get all the strings from a file line by line - We do this instead of f.readlines() as this supports binary files too - """ with open(file_path, "rb") as f: - read_bytes = f.read() - return read_bytes.decode("utf-8", "ignore").replace("\n", "\r").split("\r") + content = f.read().decode("utf-8", errors="ignore") + return content.splitlines() - def print_match(self, matched_line_dict: dict, hide_filenames: bool, json_output: Optional[bool]) -> None: - """Print matched line""" + def print_match(self, match_info: dict, hide_filenames: bool, json_output: Optional[bool]) -> None: + output = match_info.copy() + if hide_filenames: + output.pop("key_name", None) if json_output: - matched_line_dict.pop("key_name", None) if hide_filenames else None try: - print(json.dumps(matched_line_dict)) + print(json.dumps(output)) except TypeError: - print(str(matched_line_dict)) + print(str(output)) else: - line = matched_line_dict.get("line", "") - if "match_rule" in matched_line_dict: - line = f"{matched_line_dict['match_rule']}: {matched_line_dict['match_strings']}" - print(line if hide_filenames else f"{matched_line_dict['key_name']}: {line}") + line = output.get("line", "") + if "match_rule" in output: + line = f"{output['match_rule']}: {output.get('match_strings', '')}" + print(f"{output.get('key_name', '')}: {line}" if not hide_filenames else line) def parse_logs(self, line: str, log_format: Optional[str]) -> Any: - """Parse input log line based on format""" - try: - if log_format == "json": + if log_format == "json": + try: return json.loads(line) - elif log_format == "csv": + except json.JSONDecodeError as e: + logging.error(f"JSON decode error in line: {line} ({e})") + elif log_format == "csv": + try: return list(csv.DictReader([line])) - elif log_format: - logging.error(f"Invalid log format: {log_format}") - except (json.JSONDecodeError, csv.Error) as e: - logging.error(f"Invalid {log_format} format in line: {line} ({e})") + except csv.Error as e: + logging.error(f"CSV parse error in line: {line} ({e})") + elif log_format: + logging.error(f"Unsupported log format: {log_format}") return None - def extract_log_entries(self, line_parsed: Any, log_properties: List[str]) -> List[Any]: - """Extract properties in log entries""" - if log_properties: - for log_property in log_properties: - if isinstance(line_parsed, dict): - line_parsed = line_parsed.get(log_property, None) - return line_parsed if isinstance(line_parsed, list) else [line_parsed] + def extract_log_entries(self, parsed: Any, log_properties: List[str]) -> List[Any]: + if log_properties and isinstance(parsed, dict): + for prop in log_properties: + parsed = parsed.get(prop, None) + if parsed is None: + break + if isinstance(parsed, list): + return parsed + elif parsed is not None: + return [parsed] + return [] def search_logs( self, @@ -64,18 +68,18 @@ def search_logs( json_output: Optional[bool] = False, ) -> None: """Search log records in parsed logs""" - line_parsed = self.parse_logs(line, log_format) - if not line_parsed: + parsed = self.parse_logs(line, log_format) + if not parsed: return - - for record in self.extract_log_entries(line_parsed, log_properties): - if re.search(search, json.dumps(record)): - self.print_match({"key_name": key_name, "query": search, "line": record}, hide_filenames, json_output) + for entry in self.extract_log_entries(parsed, log_properties): + entry_str = json.dumps(entry) + if re.search(search, entry_str): + self.print_match({"key_name": key_name, "query": search, "line": entry}, hide_filenames, json_output) def search_line( self, key_name: str, - search: List[str], + patterns: List[str], hide_filenames: bool, line: str, log_format: Optional[str], @@ -83,16 +87,17 @@ def search_line( json_output: Optional[bool] = False, ) -> bool: """Regex search of the line""" - matched = any(re.search(cur_search, line) for cur_search in search) - if matched: - if log_format: - for cur_search in search: - self.search_logs( - line, key_name, cur_search, hide_filenames, log_format, log_properties, json_output + found = False + for pattern in patterns: + if re.search(pattern, line): + if log_format: + self.search_logs(line, key_name, pattern, hide_filenames, log_format, log_properties, json_output) + else: + self.print_match( + {"key_name": key_name, "query": pattern, "line": line}, hide_filenames, json_output ) - else: - self.print_match({"key_name": key_name, "query": search, "line": line}, hide_filenames, json_output) - return matched + found = True + return found def yara_scan_file( self, file_name: str, key_name: str, hide_filenames: bool, yara_rules: Any, json_output: Optional[bool] = False @@ -111,7 +116,7 @@ def search_file( self, file_name: str, key_name: str, - search: List[str], + patterns: List[str], hide_filenames: bool, yara_rules: Any, log_format: Optional[str] = None, @@ -120,39 +125,55 @@ def search_file( account_name: Optional[str] = None, ) -> bool: """Regex search of the file line by line""" - logging.info(f"Searching {file_name} for {search}") + logging.info(f"Searching {file_name} for patterns: {patterns}") if yara_rules: return self.yara_scan_file(file_name, key_name, hide_filenames, yara_rules, json_output) - def process_lines(lines: Any) -> bool: + def process_lines(lines) -> bool: return any( - self.search_line(key_name, search, hide_filenames, line, log_format, log_properties, json_output) + self.search_line(key_name, patterns, hide_filenames, line, log_format, log_properties, json_output) for line in lines ) - if key_name.endswith(".gz"): - with gzip.open(file_name, "rt") as f: - return process_lines(json.load(f) if account_name else f) - elif key_name.endswith(".zip"): - with tempfile.TemporaryDirectory() as tempdir, zipfile.ZipFile(file_name, "r") as zf: - zf.extractall(tempdir) - - matched_any = False - for extracted_filename in os.listdir(tempdir): - extracted_path = os.path.join(tempdir, extracted_filename) - if not os.path.isfile(extracted_path): - continue - + if file_name.endswith(".gz"): + try: + with gzip.open(file_name, "rt", encoding="utf-8", errors="ignore") as f: if account_name: - with open(extracted_path, "r", encoding="utf-8") as f: - data = json.load(f) - if process_lines(data): - matched_any = True + data = json.load(f) + return process_lines(data) else: - with open(extracted_path, "r", encoding="utf-8") as f: - if process_lines(f): - matched_any = True - + return process_lines(f) + except Exception: + logging.exception(f"Error processing gzip file: {file_name}") + return False + elif file_name.endswith(".zip"): + matched_any = False + try: + with tempfile.TemporaryDirectory() as tempdir: + with zipfile.ZipFile(file_name, "r") as zf: + zf.extractall(tempdir) + for extracted in os.listdir(tempdir): + extracted_path = os.path.join(tempdir, extracted) + if not os.path.isfile(extracted_path): + continue + try: + with open(extracted_path, "r", encoding="utf-8", errors="ignore") as f: + if account_name: + data = json.load(f) + if process_lines(data): + matched_any = True + else: + if process_lines(f): + matched_any = True + except Exception: + logging.exception(f"Error processing extracted file: {extracted_path}") return matched_any - - return process_lines(self.get_all_strings_line(file_name)) + except Exception: + logging.exception(f"Error processing zip file: {file_name}") + return False + else: + try: + return process_lines(self.get_all_strings_line(file_name)) + except Exception: + logging.exception(f"Error processing file: {file_name}") + return False From 66aad8a9ca1c787f6cbd972f48d2b7c847359393 Mon Sep 17 00:00:00 2001 From: Chris Doman Date: Sat, 1 Feb 2025 22:40:23 +0000 Subject: [PATCH 4/6] Split out list files and add test --- cloudgrep/cloud.py | 13 +-- cloudgrep/cloudgrep.py | 188 +++++++++++++++++------------------------ cloudgrep/search.py | 2 +- tests/test_unit.py | 67 +++++++++++++++ 4 files changed, 153 insertions(+), 117 deletions(-) diff --git a/cloudgrep/cloud.py b/cloudgrep/cloud.py index 1fe3393..1202858 100644 --- a/cloudgrep/cloud.py +++ b/cloudgrep/cloud.py @@ -12,13 +12,14 @@ import logging from cloudgrep.search import Search + class Cloud: def __init__(self) -> None: self.search = Search() - def _download_and_search_in_parallel(self, files: List[Any], worker_func) -> int: - """ Use ThreadPoolExecutor to download every file - Returns number of matched files """ + def _download_and_search_in_parallel(self, files: List[Any], worker_func: Any) -> int: + """Use ThreadPoolExecutor to download every file + Returns number of matched files""" total_matched = 0 with concurrent.futures.ThreadPoolExecutor() as executor: futures = [executor.submit(worker_func, key) for key in files] @@ -30,7 +31,7 @@ def _download_and_search_in_parallel(self, files: List[Any], worker_func) -> int return total_matched def _download_to_temp(self) -> str: - """ Return a temporary filename """ + """Return a temporary filename""" with tempfile.NamedTemporaryFile(delete=False) as tmp: tmp.close() return tmp.name @@ -46,7 +47,7 @@ def download_from_s3_multithread( log_properties: List[str] = [], json_output: Optional[bool] = False, ) -> int: - """ Download and search files from AWS S3""" + """Download and search files from AWS S3""" if log_properties is None: log_properties = [] s3 = boto3.client("s3", config=botocore.config.Config(max_pool_connections=64)) @@ -83,7 +84,7 @@ def download_from_azure( log_properties: Optional[List[str]] = None, json_output: bool = False, ) -> int: - """Download and search files from Azure Storage """ + """Download and search files from Azure Storage""" if log_properties is None: log_properties = [] default_credential = DefaultAzureCredential() diff --git a/cloudgrep/cloudgrep.py b/cloudgrep/cloudgrep.py index a04b483..276c11f 100644 --- a/cloudgrep/cloudgrep.py +++ b/cloudgrep/cloudgrep.py @@ -1,6 +1,6 @@ import boto3 from datetime import datetime -from typing import Optional, List, Any +from typing import Optional, List, Any, Dict import logging import yara # type: ignore @@ -15,6 +15,39 @@ def load_queries(self, file_path: str) -> List[str]: with open(file_path, "r", encoding="utf-8") as f: return [line.strip() for line in f if line.strip()] + def list_files( + self, + bucket: Optional[str], + account_name: Optional[str], + container_name: Optional[str], + google_bucket: Optional[str], + prefix: Optional[str] = "", + key_contains: Optional[str] = None, + from_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + file_size: int = 100_000_000, # 100MB + ) -> Dict[str, List[Any]]: + """ + Returns a dictionary of matching files for each cloud provider. + + The returned dict has the following keys: + - "s3": a list of S3 object keys that match filters + - "azure": a list of Azure blob names that match filters + - "gcs": a list of tuples (blob name, blob) for Google Cloud Storage that match filters + """ + files = {} + if bucket: + files["s3"] = list(self.cloud.get_objects(bucket, prefix, key_contains, from_date, end_date, file_size)) + if account_name and container_name: + files["azure"] = list( + self.cloud.get_azure_objects( + account_name, container_name, prefix, key_contains, from_date, end_date, file_size + ) + ) + if google_bucket: + files["gcs"] = list(self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date)) + return files + def search( self, bucket: Optional[str], @@ -25,7 +58,7 @@ def search( file: Optional[str], yara_file: Optional[str], file_size: int, - prefix: Optional[str] = None, + prefix: Optional[str] = "", key_contains: Optional[str] = None, from_date: Optional[datetime] = None, end_date: Optional[datetime] = None, @@ -35,7 +68,14 @@ def search( log_properties: Optional[List[str]] = None, profile: Optional[str] = None, json_output: bool = False, + files: Optional[Dict[str, List[Any]]] = None, ) -> None: + """ + Searches the contents of files matching the given queries. + + If the optional `files` parameter is provided (a dict with keys such as "s3", "azure", or "gcs") + then the search will use those file lists instead of applying the filters again. + """ if not query and file: logging.debug(f"Loading queries from {file}") query = self.load_queries(file) @@ -65,129 +105,57 @@ def search( log_properties = [] if bucket: - self._search_s3( - bucket, - query, - yara_rules, - file_size, - prefix, - key_contains, - from_date, - end_date, - hide_filenames, - log_format, - log_properties, - json_output, + if files and "s3" in files: + matching_keys = files["s3"] + else: + matching_keys = list( + self.cloud.get_objects(bucket, prefix, key_contains, from_date, end_date, file_size) + ) + s3_client = boto3.client("s3") + region = s3_client.get_bucket_location(Bucket=bucket).get("LocationConstraint", "unknown") + logging.warning(f"Bucket region: {region}. (Search from the same region to avoid egress charges.)") + logging.warning(f"Searching {len(matching_keys)} files in {bucket} for {query}...") + self.cloud.download_from_s3_multithread( + bucket, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output ) + if account_name and container_name: - self._search_azure( + if files and "azure" in files: + matching_keys = files["azure"] + else: + matching_keys = list( + self.cloud.get_azure_objects( + account_name, container_name, prefix, key_contains, from_date, end_date, file_size + ) + ) + logging.info(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...") + self.cloud.download_from_azure( account_name, container_name, + matching_keys, query, - yara_rules, - file_size, - prefix, - key_contains, - from_date, - end_date, hide_filenames, + yara_rules, log_format, log_properties, json_output, ) + if google_bucket: - self._search_gcs( + if files and "gcs" in files: + matching_blobs = files["gcs"] + else: + matching_blobs = list( + self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date) + ) + logging.info(f"Searching {len(matching_blobs)} files in {google_bucket} for {query}...") + self.cloud.download_from_google( google_bucket, + matching_blobs, query, - yara_rules, - file_size, - prefix, - key_contains, - from_date, - end_date, hide_filenames, + yara_rules, log_format, log_properties, json_output, ) - - def _search_s3( - self, - bucket: str, - query: List[str], - yara_rules: Any, - file_size: int, - prefix: Optional[str], - key_contains: Optional[str], - from_date: Optional[datetime], - end_date: Optional[datetime], - hide_filenames: bool, - log_format: Optional[str], - log_properties: List[str], - json_output: bool, - ) -> None: - """ Search S3 bucket for query """ - matching_keys = list(self.cloud.get_objects(bucket, prefix, key_contains, from_date, end_date, file_size)) - s3_client = boto3.client("s3") - region = s3_client.get_bucket_location(Bucket=bucket).get("LocationConstraint", "unknown") - logging.warning(f"Bucket region: {region}. (Search from the same region to avoid egress charges.)") - logging.warning(f"Searching {len(matching_keys)} files in {bucket} for {query}...") - self.cloud.download_from_s3_multithread( - bucket, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output - ) - - def _search_azure( - self, - account_name: str, - container_name: str, - query: List[str], - yara_rules: Any, - file_size: int, - prefix: Optional[str], - key_contains: Optional[str], - from_date: Optional[datetime], - end_date: Optional[datetime], - hide_filenames: bool, - log_format: Optional[str], - log_properties: List[str], - json_output: bool, - ) -> None: - matching_keys = list( - self.cloud.get_azure_objects( - account_name, container_name, prefix, key_contains, from_date, end_date, file_size - ) - ) - logging.info(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...") - self.cloud.download_from_azure( - account_name, - container_name, - matching_keys, - query, - hide_filenames, - yara_rules, - log_format, - log_properties, - json_output, - ) - - def _search_gcs( - self, - google_bucket: str, - query: List[str], - yara_rules: Any, - file_size: int, - prefix: Optional[str], - key_contains: Optional[str], - from_date: Optional[datetime], - end_date: Optional[datetime], - hide_filenames: bool, - log_format: Optional[str], - log_properties: List[str], - json_output: bool, - ) -> None: - # Get (blob_name, blob) - matching_blobs = list(self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date)) - logging.info(f"Searching {len(matching_blobs)} files in {google_bucket} for {query}...") - self.cloud.download_from_google( - google_bucket, matching_blobs, query, hide_filenames, yara_rules, log_format, log_properties, json_output - ) diff --git a/cloudgrep/search.py b/cloudgrep/search.py index b7a26c6..92a74b0 100644 --- a/cloudgrep/search.py +++ b/cloudgrep/search.py @@ -129,7 +129,7 @@ def search_file( if yara_rules: return self.yara_scan_file(file_name, key_name, hide_filenames, yara_rules, json_output) - def process_lines(lines) -> bool: + def process_lines(lines: Any) -> bool: return any( self.search_line(key_name, patterns, hide_filenames, line, log_format, log_properties, json_output) for line in lines diff --git a/tests/test_unit.py b/tests/test_unit.py index d10bec3..ae840d8 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -401,3 +401,70 @@ def fake_download_to_filename(local_path: str) -> None: output = fake_out.getvalue().strip() self.assertIn("google target", output, "Should match the google target text in the downloaded content") + + @mock_aws + def test_list_files_returns_pre_filtered_files(self): + """ + Test that list_files() returns only the S3 objects that match + the specified filters (e.g. key substring and non‑empty content). + """ + bucket_name = "list-files-test-bucket" + # Create a fake S3 bucket + s3_resource = boto3.resource("s3", region_name="us-east-1") + s3_resource.create_bucket(Bucket=bucket_name) + s3_client = boto3.client("s3", region_name="us-east-1") + + # Upload several objects: + # - Two objects that match + s3_client.put_object(Bucket=bucket_name, Key="log_file1.txt", Body=b"dummy content") + s3_client.put_object(Bucket=bucket_name, Key="log_file2.txt", Body=b"dummy content") + # Onne that doesnt match the key_contains filter + s3_client.put_object(Bucket=bucket_name, Key="not_a_thing.txt", Body=b"dummy content") + # One that doesnt match the file_size filter + s3_client.put_object(Bucket=bucket_name, Key="log_empty.txt", Body=b"") + + # Call list files + cg = CloudGrep() + result = cg.list_files( + bucket=bucket_name, + account_name=None, + container_name=None, + google_bucket=None, + prefix="", + key_contains="log", + from_date=None, + end_date=None, + file_size=1000000 # 1 MB + ) + + # Assert only the matching files are returned + self.assertIn("s3", result) + expected_keys = {"log_file1.txt", "log_file2.txt"} + self.assertEqual(set(result["s3"]), expected_keys) + + # Now search the contents of the files and assert they hit + for key in expected_keys: + with patch("sys.stdout", new=StringIO()) as fake_out: + cg.search( + bucket=bucket_name, + account_name=None, + container_name=None, + google_bucket=None, + query=["dummy content"], + file=None, + yara_file=None, + file_size=1000000, + prefix="", + key_contains=key, + from_date=None, + end_date=None, + hide_filenames=False, + log_type=None, + log_format=None, + log_properties=[], + profile=None, + json_output=False, + files=result, # Pass the pre-filtered files from list_files + ) + output = fake_out.getvalue().strip() + self.assertIn("log_file1.txt", output) \ No newline at end of file From 66dc67026ac21f6f0cb899fe53c3fe63107b84a1 Mon Sep 17 00:00:00 2001 From: Chris Doman Date: Sat, 1 Feb 2025 22:52:49 +0000 Subject: [PATCH 5/6] Cap workers, limit memory usage --- cloudgrep/cloud.py | 12 ++++------ cloudgrep/search.py | 56 ++++++++++++++++++++++----------------------- tests/test_unit.py | 2 +- 3 files changed, 33 insertions(+), 37 deletions(-) diff --git a/cloudgrep/cloud.py b/cloudgrep/cloud.py index 1202858..07c11d8 100644 --- a/cloudgrep/cloud.py +++ b/cloudgrep/cloud.py @@ -12,7 +12,6 @@ import logging from cloudgrep.search import Search - class Cloud: def __init__(self) -> None: self.search = Search() @@ -21,13 +20,10 @@ def _download_and_search_in_parallel(self, files: List[Any], worker_func: Any) - """Use ThreadPoolExecutor to download every file Returns number of matched files""" total_matched = 0 - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(worker_func, key) for key in files] - for fut in concurrent.futures.as_completed(futures): - try: - total_matched += fut.result() - except Exception: - logging.exception("Error in worker thread") + max_workers = 10 # limit cpu/memory pressure + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + for result in executor.map(worker_func, files): + total_matched += result return total_matched def _download_to_temp(self) -> str: diff --git a/cloudgrep/search.py b/cloudgrep/search.py index 92a74b0..9034100 100644 --- a/cloudgrep/search.py +++ b/cloudgrep/search.py @@ -1,19 +1,18 @@ -import tempfile import re -from typing import Optional, List, Any +from typing import Optional, List, Any, Iterator, Iterable import logging import gzip import zipfile -import os import json import csv - +import io class Search: - def get_all_strings_line(self, file_path: str) -> List[str]: - with open(file_path, "rb") as f: - content = f.read().decode("utf-8", errors="ignore") - return content.splitlines() + def get_all_strings_line(self, file_path: str) -> Iterator[str]: + """Yield lines from a file without loading into memory""" + with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + for line in f: + yield line def print_match(self, match_info: dict, hide_filenames: bool, json_output: Optional[bool]) -> None: output = match_info.copy() @@ -79,7 +78,7 @@ def search_logs( def search_line( self, key_name: str, - patterns: List[str], + compiled_patterns: List[re.Pattern], hide_filenames: bool, line: str, log_format: Optional[str], @@ -88,13 +87,13 @@ def search_line( ) -> bool: """Regex search of the line""" found = False - for pattern in patterns: - if re.search(pattern, line): + for regex in compiled_patterns: + if regex.search(line): if log_format: - self.search_logs(line, key_name, pattern, hide_filenames, log_format, log_properties, json_output) + self.search_logs(line, key_name, regex.pattern, hide_filenames, log_format, log_properties, json_output) else: self.print_match( - {"key_name": key_name, "query": pattern, "line": line}, hide_filenames, json_output + {"key_name": key_name, "query": regex.pattern, "line": line}, hide_filenames, json_output ) found = True return found @@ -128,10 +127,12 @@ def search_file( logging.info(f"Searching {file_name} for patterns: {patterns}") if yara_rules: return self.yara_scan_file(file_name, key_name, hide_filenames, yara_rules, json_output) + + compiled_patterns = [re.compile(p) for p in patterns] - def process_lines(lines: Any) -> bool: + def process_lines(lines: Iterable[str]) -> bool: return any( - self.search_line(key_name, patterns, hide_filenames, line, log_format, log_properties, json_output) + self.search_line(key_name, compiled_patterns, hide_filenames, line, log_format, log_properties, json_output) for line in lines ) @@ -149,24 +150,23 @@ def process_lines(lines: Any) -> bool: elif file_name.endswith(".zip"): matched_any = False try: - with tempfile.TemporaryDirectory() as tempdir: - with zipfile.ZipFile(file_name, "r") as zf: - zf.extractall(tempdir) - for extracted in os.listdir(tempdir): - extracted_path = os.path.join(tempdir, extracted) - if not os.path.isfile(extracted_path): + with zipfile.ZipFile(file_name, "r") as zf: + for zip_info in zf.infolist(): + if zip_info.is_dir(): continue - try: - with open(extracted_path, "r", encoding="utf-8", errors="ignore") as f: + with zf.open(zip_info) as file_obj: + # Wrap the binary stream as text + with io.TextIOWrapper(file_obj, encoding="utf-8", errors="ignore") as f: if account_name: - data = json.load(f) - if process_lines(data): - matched_any = True + try: + data = json.load(f) + if process_lines(data): + matched_any = True + except Exception: + logging.exception(f"Error processing json in zip member: {zip_info.filename}") else: if process_lines(f): matched_any = True - except Exception: - logging.exception(f"Error processing extracted file: {extracted_path}") return matched_any except Exception: logging.exception(f"Error processing zip file: {file_name}") diff --git a/tests/test_unit.py b/tests/test_unit.py index ae840d8..c8a6574 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -403,7 +403,7 @@ def fake_download_to_filename(local_path: str) -> None: self.assertIn("google target", output, "Should match the google target text in the downloaded content") @mock_aws - def test_list_files_returns_pre_filtered_files(self): + def test_list_files_returns_pre_filtered_files(self) -> None: """ Test that list_files() returns only the S3 objects that match the specified filters (e.g. key substring and non‑empty content). From dc9dbf4bd701f683d07907354df41af77e1eda77 Mon Sep 17 00:00:00 2001 From: Chris Doman Date: Sat, 1 Feb 2025 22:58:12 +0000 Subject: [PATCH 6/6] fix type --- cloudgrep/cloudgrep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cloudgrep/cloudgrep.py b/cloudgrep/cloudgrep.py index 276c11f..cf47a18 100644 --- a/cloudgrep/cloudgrep.py +++ b/cloudgrep/cloudgrep.py @@ -45,7 +45,7 @@ def list_files( ) ) if google_bucket: - files["gcs"] = list(self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date)) + files["gcs"] = [blob[0] for blob in self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date)] return files def search(