diff --git a/cloudgrep/__main__.py b/cloudgrep/__main__.py index 1a5a825..e8b7ba2 100644 --- a/cloudgrep/__main__.py +++ b/cloudgrep/__main__.py @@ -1,137 +1,87 @@ import argparse import logging import sys -from typing import List +from typing import List, Optional +import dateutil.parser +import datetime from cloudgrep.cloudgrep import CloudGrep VERSION = "1.0.5" -# Define a custom argument type for a list of strings def list_of_strings(arg: str) -> List[str]: - return arg.split(",") + """Parse a comma‐separated string into a list of nonempty strings.""" + return [s.strip() for s in arg.split(",") if s.strip()] def main() -> None: parser = argparse.ArgumentParser( - description=f"CloudGrep searches is grep for cloud storage like S3 and Azure Storage. Version: {VERSION}" - ) - parser.add_argument("-b", "--bucket", help="AWS S3 Bucket to search. E.g. my-bucket", required=False) - parser.add_argument("-an", "--account-name", help="Azure Account Name to Search", required=False) - parser.add_argument("-cn", "--container-name", help="Azure Container Name to Search", required=False) - parser.add_argument("-gb", "--google-bucket", help="Google Cloud Bucket to Search", required=False) - parser.add_argument( - "-q", - "--query", - type=list_of_strings, - help="Text to search for. Will be parsed as a Regex. E.g. example.com", - required=False, - ) - parser.add_argument( - "-v", - "--file", - help="File containing a list of words or regular expressions to search for. One per line.", - required=False, - ) - parser.add_argument( - "-y", - "--yara", - help="File containing Yara rules to scan files.", - required=False, - ) - parser.add_argument( - "-p", - "--prefix", - help="Optionally filter on the start of the Object name. E.g. logs/", - required=False, - default="", - ) - parser.add_argument( - "-f", "--filename", help="Optionally filter on Objects that match a keyword. E.g. 
.log.gz ", required=False - ) - parser.add_argument( - "-s", - "--start_date", - help="Optionally filter on Objects modified after a Date or Time. E.g. 2022-01-01 ", - required=False, - ) - parser.add_argument( - "-e", - "--end_date", - help="Optionally filter on Objects modified before a Date or Time. E.g. 2022-01-01 ", - required=False, + description=f"CloudGrep: grep for cloud storage (S3, Azure, Google Cloud). Version: {VERSION}" ) + parser.add_argument("-b", "--bucket", help="AWS S3 Bucket to search (e.g. my-bucket)") + parser.add_argument("-an", "--account-name", help="Azure Account Name to search") + parser.add_argument("-cn", "--container-name", help="Azure Container Name to search") + parser.add_argument("-gb", "--google-bucket", help="Google Cloud Bucket to search") + parser.add_argument("-q", "--query", type=list_of_strings, help="Comma-separated list of regex patterns to search") + parser.add_argument("-v", "--file", help="File containing queries (one per line)") + parser.add_argument("-y", "--yara", help="File containing Yara rules") + parser.add_argument("-p", "--prefix", default="", help="Filter objects by prefix (e.g. logs/)") + parser.add_argument("-f", "--filename", help="Filter objects whose names contain a keyword (e.g. .log.gz)") + parser.add_argument("-s", "--start_date", help="Filter objects modified after this date (YYYY-MM-DD)") + parser.add_argument("-e", "--end_date", help="Filter objects modified before this date (YYYY-MM-DD)") parser.add_argument( "-fs", "--file_size", - help="Optionally filter on Objects smaller than a file size, in bytes. Defaults to 100 Mb. ", - default=100000000, - required=False, - ) - parser.add_argument( - "-pr", - "--profile", - help="Set an AWS profile to use. E.g. default, dev, prod.", - required=False, - ) - parser.add_argument("-d", "--debug", help="Enable Debug logging. ", action="store_true", required=False) - parser.add_argument( - "-hf", "--hide_filenames", help="Dont show matching filenames. 
", action="store_true", required=False + type=int, + default=100_000_000, + help="Max file size in bytes (default: 100MB)", ) + parser.add_argument("-pr", "--profile", help="AWS profile to use (e.g. default, dev, prod)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging") + parser.add_argument("-hf", "--hide_filenames", action="store_true", help="Hide filenames in output") + parser.add_argument("-lt", "--log_type", help="Pre-defined log type (e.g. cloudtrail, azure)") + parser.add_argument("-lf", "--log_format", help="Custom log format (e.g. json, csv)") parser.add_argument( - "-lt", - "--log_type", - help="Return individual matching log entries based on pre-defined log types, otherwise custom log_format and log_properties can be used. E.g. cloudtrail. ", - required=False, + "-lp", "--log_properties", type=list_of_strings, help="Comma-separated list of log properties to extract" ) - parser.add_argument( - "-lf", - "--log_format", - help="Define custom log format of raw file to parse before applying search logic. Used if --log_type is not defined. E.g. json. ", - required=False, - ) - parser.add_argument( - "-lp", - "--log_properties", - type=list_of_strings, - help="Define custom list of properties to traverse to dynamically extract final list of log records. Used if --log_type is not defined. E.g. [" - "Records" - "]. 
", - required=False, - ) - parser.add_argument("-jo", "--json_output", help="Output as JSON.", action="store_true") - args = vars(parser.parse_args()) + parser.add_argument("-jo", "--json_output", action="store_true", help="Output results in JSON format") + args = parser.parse_args() if len(sys.argv) == 1: parser.print_help(sys.stderr) sys.exit(1) - if args["debug"]: - logging.basicConfig(format="[%(asctime)s]:[%(levelname)s] - %(message)s", level=logging.INFO) + # Parse dates (if provided) into datetime objects + start_date: Optional["datetime.datetime"] = dateutil.parser.parse(args.start_date) if args.start_date else None + end_date: Optional["datetime.datetime"] = dateutil.parser.parse(args.end_date) if args.end_date else None + + # Configure logging + if args.debug: + logging.basicConfig(format="[%(asctime)s] [%(levelname)s] %(message)s", level=logging.DEBUG) else: - logging.basicConfig(format="[%(asctime)s] - %(message)s", level=logging.WARNING) + logging.basicConfig(format="[%(asctime)s] %(message)s", level=logging.WARNING) logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) CloudGrep().search( - args["bucket"], - args["account_name"], - args["container_name"], - args["google_bucket"], - args["query"], - args["file"], - args["yara"], - int(args["file_size"]), - args["prefix"], - args["filename"], - args["start_date"], - args["end_date"], - args["hide_filenames"], - args["log_type"], - args["log_format"], - args["log_properties"], - args["profile"], - args["json_output"], + bucket=args.bucket, + account_name=args.account_name, + container_name=args.container_name, + google_bucket=args.google_bucket, + query=args.query, + file=args.file, + yara_file=args.yara, + file_size=args.file_size, + prefix=args.prefix, + key_contains=args.filename, + from_date=start_date, + end_date=end_date, + hide_filenames=args.hide_filenames, + log_type=args.log_type, + log_format=args.log_format, + log_properties=args.log_properties, + profile=args.profile, + 
json_output=args.json_output, ) diff --git a/cloudgrep/cloud.py b/cloudgrep/cloud.py index b8f63ae..07c11d8 100644 --- a/cloudgrep/cloud.py +++ b/cloudgrep/cloud.py @@ -1,6 +1,6 @@ import boto3 import os -from azure.storage.blob import BlobServiceClient, BlobProperties +from azure.storage.blob import BlobServiceClient from azure.identity import DefaultAzureCredential from azure.core.exceptions import ResourceNotFoundError from google.cloud import storage # type: ignore @@ -8,15 +8,30 @@ import botocore import concurrent.futures import tempfile -from typing import Iterator, Optional, List, Any +from typing import Iterator, Optional, List, Any, Tuple import logging from cloudgrep.search import Search - class Cloud: def __init__(self) -> None: self.search = Search() + def _download_and_search_in_parallel(self, files: List[Any], worker_func: Any) -> int: + """Use ThreadPoolExecutor to download every file + Returns number of matched files""" + total_matched = 0 + max_workers = 10 # limit cpu/memory pressure + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + for result in executor.map(worker_func, files): + total_matched += result + return total_matched + + def _download_to_temp(self) -> str: + """Return a temporary filename""" + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp.close() + return tmp.name + def download_from_s3_multithread( self, bucket: str, @@ -28,16 +43,13 @@ def download_from_s3_multithread( log_properties: List[str] = [], json_output: Optional[bool] = False, ) -> int: - """Use ThreadPoolExecutor and boto3 to download every file in the bucket from s3 - Returns number of matched files""" - if not log_properties: + """Download and search files from AWS S3""" + if log_properties is None: log_properties = [] - s3 = boto3.client("s3", config=botocore.config.Config(max_pool_connections=64)) def _download_search_s3(key: str) -> int: - with tempfile.NamedTemporaryFile(delete=False) as tmp: - tmp_name = tmp.name + 
tmp_name = self._download_to_temp()
             try:
                 logging.info(f"Downloading s3://{bucket}/{key} to {tmp_name}")
                 s3.download_file(bucket, key, tmp_name)
@@ -45,23 +57,16 @@ def _download_search_s3(key: str) -> int:
             matched = self.search.search_file(
                 tmp_name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output
             )
             return 1 if matched else 0
-        except Exception as e:
-            logging.error(f"Error downloading or searching {key}: {e}")
+        except Exception:
+            logging.exception(f"Error processing {key}")
             return 0
         finally:
             try:
-                # Cleanup
                 os.remove(tmp_name)
             except OSError:
                 pass
-        total_matched = 0
-        # Use ThreadPoolExecutor to download and search files
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            futures = [executor.submit(_download_search_s3, k) for k in files]
-            for fut in concurrent.futures.as_completed(futures):
-                total_matched += fut.result()
-        return total_matched
+        return self._download_and_search_in_parallel(files, _download_search_s3)

     def download_from_azure(
         self,
@@ -75,28 +80,23 @@
         log_properties: Optional[List[str]] = None,
         json_output: bool = False,
     ) -> int:
-        """Download every file in the container from azure
-        Returns number of matched files"""
-        if not log_properties:
+        """Download and search files from Azure Storage"""
+        if log_properties is None:
             log_properties = []
         default_credential = DefaultAzureCredential()
-        blob_service_client = BlobServiceClient.from_connection_string(
-            f"DefaultEndpointsProtocol=https;AccountName={account_name};EndpointSuffix=core.windows.net",
-            credential=default_credential,
-        )
+        connection_str = f"DefaultEndpointsProtocol=https;AccountName={account_name};EndpointSuffix=core.windows.net"
+        blob_service_client = BlobServiceClient.from_connection_string(connection_str, credential=default_credential)
         container_client = blob_service_client.get_container_client(container_name)

         def _download_search_azure(key: str) -> int:
-            with tempfile.NamedTemporaryFile(delete=False) as tmp:
-                tmp_name = tmp.name
+
tmp_name = self._download_to_temp() try: logging.info(f"Downloading azure://{account_name}/{container_name}/{key} to {tmp_name}") blob_client = container_client.get_blob_client(key) - with open(tmp_name, "wb") as my_blob: - blob_data = blob_client.download_blob() - blob_data.readinto(my_blob) - + with open(tmp_name, "wb") as out_file: + stream = blob_client.download_blob() + stream.readinto(out_file) matched = self.search.search_file( tmp_name, key, @@ -112,28 +111,21 @@ def _download_search_azure(key: str) -> int: except ResourceNotFoundError: logging.info(f"File {key} not found in {account_name}/{container_name}") return 0 - except Exception as e: - logging.error(f"Error downloading or searching {key}: {e}") + except Exception: + logging.exception(f"Error processing {key}") return 0 finally: try: - import os - os.remove(tmp_name) except OSError: pass - total_matched = 0 - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(_download_search_azure, k) for k in files] - for fut in concurrent.futures.as_completed(futures): - total_matched += fut.result() - return total_matched + return self._download_and_search_in_parallel(files, _download_search_azure) def download_from_google( self, bucket: str, - files: List[str], + blobs: List[Tuple[str, Any]], query: List[str], hide_filenames: bool, yara_rules: Any, @@ -141,103 +133,30 @@ def download_from_google( log_properties: Optional[List[str]] = None, json_output: bool = False, ) -> int: - """Download every file in the bucket from google - Returns number of matched files""" - if not log_properties: + """Download and search files from Google Cloud Storage""" + if log_properties is None: log_properties = [] - client = storage.Client() - bucket_gcp = client.get_bucket(bucket) - - def _download_and_search_google(key: str) -> int: - with tempfile.NamedTemporaryFile(delete=False) as tmp: - tmp_name = tmp.name + def _download_and_search_google(item: Tuple[str, Any]) -> int: + key, blob = item + 
tmp_name = self._download_to_temp() try: logging.info(f"Downloading gs://{bucket}/{key} to {tmp_name}") - blob = bucket_gcp.get_blob(key) - if blob is None: - logging.warning(f"Blob not found: {key}") - return 0 blob.download_to_filename(tmp_name) - matched = self.search.search_file( tmp_name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output ) return 1 if matched else 0 - except Exception as e: - logging.error(f"Error downloading or searching {key}: {e}") + except Exception: + logging.exception(f"Error processing {key}") return 0 finally: try: - import os - os.remove(tmp_name) except OSError: pass - total_matched = 0 - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(_download_and_search_google, k) for k in files] - for fut in concurrent.futures.as_completed(futures): - total_matched += fut.result() - return total_matched - - def filter_object( - self, - obj: dict, - key_contains: Optional[str], - from_date: Optional[datetime], - to_date: Optional[datetime], - file_size: int, - ) -> bool: - """Filter S3 objects by date range, file size, and substring in key.""" - last_modified = obj["LastModified"] - if last_modified and from_date and from_date > last_modified: - return False # Object was modified before the from_date - if last_modified and to_date and last_modified > to_date: - return False # Object was modified after the to_date - if obj["Size"] == 0 or int(obj["Size"]) > file_size: - return False # Object is empty or too large - if key_contains and key_contains not in obj["Key"]: - return False # Object does not contain the key_contains string - return True - - def filter_object_azure( - self, - obj: BlobProperties, - key_contains: Optional[str], - from_date: Optional[datetime], - to_date: Optional[datetime], - file_size: int, - ) -> bool: - """Filter Azure blob objects similarly.""" - last_modified = obj["last_modified"] # type: ignore - if last_modified and from_date and from_date > 
last_modified: - return False - if last_modified and to_date and last_modified > to_date: - return False - if obj["size"] == 0 or int(obj["size"]) > file_size: - return False - if key_contains and key_contains not in obj["name"]: - return False - return True - - def filter_object_google( - self, - obj: storage.blob.Blob, - key_contains: Optional[str], - from_date: Optional[datetime], - to_date: Optional[datetime], - ) -> bool: - """Filter objects in GCP""" - last_modified = obj.updated - if last_modified and from_date and from_date > last_modified: - return False - if last_modified and to_date and last_modified > to_date: - return False - if key_contains and key_contains not in obj.name: - return False - return True + return self._download_and_search_in_parallel(blobs, _download_and_search_google) def get_objects( self, @@ -248,15 +167,13 @@ def get_objects( end_date: Optional[datetime], file_size: int, ) -> Iterator[str]: - """Get objects in S3""" + """Yield objects that match filter""" s3 = boto3.client("s3") paginator = s3.get_paginator("list_objects_v2") - page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix) - for page in page_iterator: - if "Contents" in page: - for obj in page["Contents"]: - if self.filter_object(obj, key_contains, from_date, end_date, file_size): - yield obj["Key"] + for page in paginator.paginate(Bucket=bucket, Prefix=prefix): + for obj in page.get("Contents", []): + if self.filter_object(obj, key_contains, from_date, end_date, file_size): + yield obj.get("Key") def get_azure_objects( self, @@ -268,16 +185,12 @@ def get_azure_objects( end_date: Optional[datetime], file_size: int, ) -> Iterator[str]: - """Get all objects in Azure storage container with a given prefix""" + """Yield Azure blob names that match the filter""" default_credential = DefaultAzureCredential() - blob_service_client = BlobServiceClient.from_connection_string( - f"DefaultEndpointsProtocol=https;AccountName={account_name};EndpointSuffix=core.windows.net", - 
credential=default_credential, - ) + connection_str = f"DefaultEndpointsProtocol=https;AccountName={account_name};EndpointSuffix=core.windows.net" + blob_service_client = BlobServiceClient.from_connection_string(connection_str, credential=default_credential) container_client = blob_service_client.get_container_client(container_name) - blobs = container_client.list_blobs(name_starts_with=prefix) - - for blob in blobs: + for blob in container_client.list_blobs(name_starts_with=prefix): if self.filter_object_azure(blob, key_contains, from_date, end_date, file_size): yield blob.name @@ -288,11 +201,80 @@ def get_google_objects( key_contains: Optional[str], from_date: Optional[datetime], end_date: Optional[datetime], - ) -> Iterator[str]: - """Get all objects in a GCP bucket with a given prefix""" + ) -> Iterator[Tuple[str, Any]]: + """Yield (blob_name, blob) for blobs in GCP that match filter""" client = storage.Client() bucket_gcp = client.get_bucket(bucket) - blobs = bucket_gcp.list_blobs(prefix=prefix) - for blob in blobs: + for blob in bucket_gcp.list_blobs(prefix=prefix): if self.filter_object_google(blob, key_contains, from_date, end_date): - yield blob.name + yield blob.name, blob + + def filter_object( + self, + obj: dict, + key_contains: Optional[str], + from_date: Optional[datetime], + to_date: Optional[datetime], + file_size: int, + ) -> bool: + """Filter an S3 object based on modification date, size, and key substring""" + last_modified = obj.get("LastModified") + if last_modified: + if from_date and last_modified < from_date: + return False + if to_date and last_modified > to_date: + return False + # If size is 0 or greater than file_size, skip + if int(obj.get("Size", 0)) == 0 or int(obj.get("Size", 0)) > file_size: + return False + if key_contains and key_contains not in obj.get("Key", ""): + return False + return True + + def filter_object_azure( + self, + obj: Any, + key_contains: Optional[str], + from_date: Optional[datetime], + to_date: 
Optional[datetime], + file_size: int, + ) -> bool: + """ + Filter an Azure blob object (or dict) based on modification date, size, and name substring. + """ + if isinstance(obj, dict): + last_modified = obj.get("last_modified") + size = int(obj.get("size", 0)) + name = obj.get("name", "") + else: + last_modified = getattr(obj, "last_modified", None) + size = int(getattr(obj, "size", 0)) + name = getattr(obj, "name", "") + if last_modified: + if from_date and last_modified < from_date: + return False + if to_date and last_modified > to_date: + return False + if size == 0 or size > file_size: + return False + if key_contains and key_contains not in name: + return False + return True + + def filter_object_google( + self, + obj: storage.blob.Blob, + key_contains: Optional[str], + from_date: Optional[datetime], + to_date: Optional[datetime], + ) -> bool: + """Filter a GCP blob based on update time and name substring""" + last_modified = getattr(obj, "updated", None) + if last_modified: + if from_date and last_modified < from_date: + return False + if to_date and last_modified > to_date: + return False + if key_contains and key_contains not in getattr(obj, "name", ""): + return False + return True diff --git a/cloudgrep/cloudgrep.py b/cloudgrep/cloudgrep.py index 6250a34..cf47a18 100644 --- a/cloudgrep/cloudgrep.py +++ b/cloudgrep/cloudgrep.py @@ -1,6 +1,6 @@ import boto3 from datetime import datetime -from typing import Optional, List, Any +from typing import Optional, List, Any, Dict import logging import yara # type: ignore @@ -11,22 +11,54 @@ class CloudGrep: def __init__(self) -> None: self.cloud = Cloud() - def load_queries(self, file: str) -> List[str]: - """Load in a list of queries from a file""" - with open(file, "r") as f: + def load_queries(self, file_path: str) -> List[str]: + with open(file_path, "r", encoding="utf-8") as f: return [line.strip() for line in f if line.strip()] + def list_files( + self, + bucket: Optional[str], + account_name: Optional[str], 
+        container_name: Optional[str],
+        google_bucket: Optional[str],
+        prefix: Optional[str] = "",
+        key_contains: Optional[str] = None,
+        from_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+        file_size: int = 100_000_000,  # 100MB
+    ) -> Dict[str, List[Any]]:
+        """
+        Returns a dictionary of matching files for each cloud provider.
+
+        The returned dict has the following keys:
+        - "s3": a list of S3 object keys that match filters
+        - "azure": a list of Azure blob names that match filters
+        - "gcs": a list of tuples (blob name, blob) for Google Cloud Storage that match filters
+        """
+        files = {}
+        if bucket:
+            files["s3"] = list(self.cloud.get_objects(bucket, prefix, key_contains, from_date, end_date, file_size))
+        if account_name and container_name:
+            files["azure"] = list(
+                self.cloud.get_azure_objects(
+                    account_name, container_name, prefix, key_contains, from_date, end_date, file_size
+                )
+            )
+        if google_bucket:
+            files["gcs"] = list(self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date))
+        return files
+
     def search(
         self,
         bucket: Optional[str],
         account_name: Optional[str],
         container_name: Optional[str],
         google_bucket: Optional[str],
-        query: List[str],
+        query: Optional[List[str]],
         file: Optional[str],
         yara_file: Optional[str],
         file_size: int,
-        prefix: Optional[str] = None,
+        prefix: Optional[str] = "",
         key_contains: Optional[str] = None,
         from_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
@@ -36,174 +68,94 @@
         log_properties: Optional[List[str]] = None,
         profile: Optional[str] = None,
         json_output: bool = False,
+        files: Optional[Dict[str, List[Any]]] = None,
     ) -> None:
-        """Search query/queries across cloud storage"""
+        """
+        Searches the contents of files matching the given queries.
- # Load queries from a file if given + If the optional `files` parameter is provided (a dict with keys such as "s3", "azure", or "gcs") + then the search will use those file lists instead of applying the filters again. + """ if not query and file: - logging.debug(f"Loading queries in from {file}") + logging.debug(f"Loading queries from {file}") query = self.load_queries(file) + if not query: + logging.error("No query provided. Exiting.") + return - # Compile optional Yara rules yara_rules = None if yara_file: - logging.debug(f"Loading yara rules from {yara_file}") + logging.debug(f"Compiling yara rules from {yara_file}") yara_rules = yara.compile(filepath=yara_file) if profile: - # Set the AWS credentials profile to use boto3.setup_default_session(profile_name=profile) - if log_type is not None: - if log_type == "cloudtrail": + if log_type: + if log_type.lower() == "cloudtrail": log_format = "json" log_properties = ["Records"] - elif log_type == "azure": + elif log_type.lower() == "azure": log_format = "json" log_properties = ["data"] else: - logging.error(f"Invalid log_type: '{log_type}'") + logging.error(f"Invalid log_type: {log_type}") return if log_properties is None: - log_properties = [] # default + log_properties = [] - # Search given cloud storage if bucket: - self._search_s3( - bucket=bucket, - query=query, - yara_rules=yara_rules, - file_size=file_size, - prefix=prefix, - key_contains=key_contains, - from_date=from_date, - end_date=end_date, - hide_filenames=hide_filenames, - log_format=log_format, - log_properties=log_properties, - json_output=json_output, + if files and "s3" in files: + matching_keys = files["s3"] + else: + matching_keys = list( + self.cloud.get_objects(bucket, prefix, key_contains, from_date, end_date, file_size) + ) + s3_client = boto3.client("s3") + region = s3_client.get_bucket_location(Bucket=bucket).get("LocationConstraint", "unknown") + logging.warning(f"Bucket region: {region}. 
(Search from the same region to avoid egress charges.)") + logging.warning(f"Searching {len(matching_keys)} files in {bucket} for {query}...") + self.cloud.download_from_s3_multithread( + bucket, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output ) if account_name and container_name: - self._search_azure( - account_name=account_name, - container_name=container_name, - query=query, - yara_rules=yara_rules, - file_size=file_size, - prefix=prefix, - key_contains=key_contains, - from_date=from_date, - end_date=end_date, - hide_filenames=hide_filenames, - log_format=log_format, - log_properties=log_properties, - json_output=json_output, + if files and "azure" in files: + matching_keys = files["azure"] + else: + matching_keys = list( + self.cloud.get_azure_objects( + account_name, container_name, prefix, key_contains, from_date, end_date, file_size + ) + ) + logging.info(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...") + self.cloud.download_from_azure( + account_name, + container_name, + matching_keys, + query, + hide_filenames, + yara_rules, + log_format, + log_properties, + json_output, ) if google_bucket: - self._search_gcs( - google_bucket=google_bucket, - query=query, - yara_rules=yara_rules, - file_size=file_size, - prefix=prefix, - key_contains=key_contains, - from_date=from_date, - end_date=end_date, - hide_filenames=hide_filenames, - log_format=log_format, - log_properties=log_properties, - json_output=json_output, - ) - - def _search_s3( - self, - bucket: str, - query: List[str], - yara_rules: Any, - file_size: int, - prefix: Optional[str], - key_contains: Optional[str], - from_date: Optional[datetime], - end_date: Optional[datetime], - hide_filenames: bool, - log_format: Optional[str], - log_properties: List[str], - json_output: bool, - ) -> None: - """Search S3 bucket for query""" - matching_keys = list(self.cloud.get_objects(bucket, prefix, key_contains, from_date, end_date, 
file_size)) - s3_client = boto3.client("s3") - region = s3_client.get_bucket_location(Bucket=bucket) - logging.warning( - f"Bucket is in region: {region.get('LocationConstraint', 'unknown')} : " - "Search from the same region to avoid egress charges." - ) - logging.warning(f"Searching {len(matching_keys)} files in {bucket} for {query}...") - self.cloud.download_from_s3_multithread( - bucket, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output - ) - - def _search_azure( - self, - account_name: str, - container_name: str, - query: List[str], - yara_rules: Any, - file_size: int, - prefix: Optional[str], - key_contains: Optional[str], - from_date: Optional[datetime], - end_date: Optional[datetime], - hide_filenames: bool, - log_format: Optional[str], - log_properties: List[str], - json_output: bool, - ) -> None: - """Search Azure container for query""" - matching_keys = list( - self.cloud.get_azure_objects( - account_name, container_name, prefix, key_contains, from_date, end_date, file_size + if files and "gcs" in files: + matching_blobs = files["gcs"] + else: + matching_blobs = list( + self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date) + ) + logging.info(f"Searching {len(matching_blobs)} files in {google_bucket} for {query}...") + self.cloud.download_from_google( + google_bucket, + matching_blobs, + query, + hide_filenames, + yara_rules, + log_format, + log_properties, + json_output, ) - ) - print(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...") - self.cloud.download_from_azure( - account_name, - container_name, - matching_keys, - query, - hide_filenames, - yara_rules, - log_format, - log_properties, - json_output, - ) - - def _search_gcs( - self, - google_bucket: str, - query: List[str], - yara_rules: Any, - file_size: int, - prefix: Optional[str], - key_contains: Optional[str], - from_date: Optional[datetime], - end_date: Optional[datetime], - 
hide_filenames: bool, - log_format: Optional[str], - log_properties: List[str], - json_output: bool, - ) -> None: - matching_keys = list(self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date)) - print(f"Searching {len(matching_keys)} files in {google_bucket} for {query}...") - self.cloud.download_from_google( - google_bucket, - matching_keys, - query, - hide_filenames, - yara_rules, - log_format, - log_properties, - json_output, - ) diff --git a/cloudgrep/search.py b/cloudgrep/search.py index 4652700..9034100 100644 --- a/cloudgrep/search.py +++ b/cloudgrep/search.py @@ -1,57 +1,60 @@ -import tempfile import re -from typing import Optional, List, Any +from typing import Optional, List, Any, Iterator, Iterable import logging import gzip import zipfile -import os import json import csv - +import io class Search: - def get_all_strings_line(self, file_path: str) -> List[str]: - """Get all the strings from a file line by line - We do this instead of f.readlines() as this supports binary files too - """ - with open(file_path, "rb") as f: - read_bytes = f.read() - return read_bytes.decode("utf-8", "ignore").replace("\n", "\r").split("\r") + def get_all_strings_line(self, file_path: str) -> Iterator[str]: + """Yield lines from a file without loading into memory""" + with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + for line in f: + yield line - def print_match(self, matched_line_dict: dict, hide_filenames: bool, json_output: Optional[bool]) -> None: - """Print matched line""" + def print_match(self, match_info: dict, hide_filenames: bool, json_output: Optional[bool]) -> None: + output = match_info.copy() + if hide_filenames: + output.pop("key_name", None) if json_output: - matched_line_dict.pop("key_name", None) if hide_filenames else None try: - print(json.dumps(matched_line_dict)) + print(json.dumps(output)) except TypeError: - print(str(matched_line_dict)) + print(str(output)) else: - line = 
matched_line_dict.get("line", "") - if "match_rule" in matched_line_dict: - line = f"{matched_line_dict['match_rule']}: {matched_line_dict['match_strings']}" - print(line if hide_filenames else f"{matched_line_dict['key_name']}: {line}") + line = output.get("line", "") + if "match_rule" in output: + line = f"{output['match_rule']}: {output.get('match_strings', '')}" + print(f"{output.get('key_name', '')}: {line}" if not hide_filenames else line) def parse_logs(self, line: str, log_format: Optional[str]) -> Any: - """Parse input log line based on format""" - try: - if log_format == "json": + if log_format == "json": + try: return json.loads(line) - elif log_format == "csv": + except json.JSONDecodeError as e: + logging.error(f"JSON decode error in line: {line} ({e})") + elif log_format == "csv": + try: return list(csv.DictReader([line])) - elif log_format: - logging.error(f"Invalid log format: {log_format}") - except (json.JSONDecodeError, csv.Error) as e: - logging.error(f"Invalid {log_format} format in line: {line} ({e})") + except csv.Error as e: + logging.error(f"CSV parse error in line: {line} ({e})") + elif log_format: + logging.error(f"Unsupported log format: {log_format}") return None - def extract_log_entries(self, line_parsed: Any, log_properties: List[str]) -> List[Any]: - """Extract properties in log entries""" - if log_properties: - for log_property in log_properties: - if isinstance(line_parsed, dict): - line_parsed = line_parsed.get(log_property, None) - return line_parsed if isinstance(line_parsed, list) else [line_parsed] + def extract_log_entries(self, parsed: Any, log_properties: List[str]) -> List[Any]: + if log_properties and isinstance(parsed, dict): + for prop in log_properties: + parsed = parsed.get(prop, None) + if parsed is None: + break + if isinstance(parsed, list): + return parsed + elif parsed is not None: + return [parsed] + return [] def search_logs( self, @@ -64,18 +67,18 @@ def search_logs( json_output: Optional[bool] = False, ) -> 
None: """Search log records in parsed logs""" - line_parsed = self.parse_logs(line, log_format) - if not line_parsed: + parsed = self.parse_logs(line, log_format) + if not parsed: return - - for record in self.extract_log_entries(line_parsed, log_properties): - if re.search(search, json.dumps(record)): - self.print_match({"key_name": key_name, "query": search, "line": record}, hide_filenames, json_output) + for entry in self.extract_log_entries(parsed, log_properties): + entry_str = json.dumps(entry) + if re.search(search, entry_str): + self.print_match({"key_name": key_name, "query": search, "line": entry}, hide_filenames, json_output) def search_line( self, key_name: str, - search: List[str], + compiled_patterns: List[re.Pattern], hide_filenames: bool, line: str, log_format: Optional[str], @@ -83,16 +86,17 @@ def search_line( json_output: Optional[bool] = False, ) -> bool: """Regex search of the line""" - matched = any(re.search(cur_search, line) for cur_search in search) - if matched: - if log_format: - for cur_search in search: - self.search_logs( - line, key_name, cur_search, hide_filenames, log_format, log_properties, json_output + found = False + for regex in compiled_patterns: + if regex.search(line): + if log_format: + self.search_logs(line, key_name, regex.pattern, hide_filenames, log_format, log_properties, json_output) + else: + self.print_match( + {"key_name": key_name, "query": regex.pattern, "line": line}, hide_filenames, json_output ) - else: - self.print_match({"key_name": key_name, "query": search, "line": line}, hide_filenames, json_output) - return matched + found = True + return found def yara_scan_file( self, file_name: str, key_name: str, hide_filenames: bool, yara_rules: Any, json_output: Optional[bool] = False @@ -111,7 +115,7 @@ def search_file( self, file_name: str, key_name: str, - search: List[str], + patterns: List[str], hide_filenames: bool, yara_rules: Any, log_format: Optional[str] = None, @@ -120,31 +124,56 @@ def search_file( 
account_name: Optional[str] = None, ) -> bool: """Regex search of the file line by line""" - logging.info(f"Searching {file_name} for {search}") + logging.info(f"Searching {file_name} for patterns: {patterns}") if yara_rules: return self.yara_scan_file(file_name, key_name, hide_filenames, yara_rules, json_output) + + compiled_patterns = [re.compile(p) for p in patterns] - def process_lines(lines: Any) -> bool: + def process_lines(lines: Iterable[str]) -> bool: return any( - self.search_line(key_name, search, hide_filenames, line, log_format, log_properties, json_output) + self.search_line(key_name, compiled_patterns, hide_filenames, line, log_format, log_properties, json_output) for line in lines ) - if key_name.endswith(".gz"): - with gzip.open(file_name, "rt") as f: - return process_lines(json.load(f) if account_name else f) - elif key_name.endswith(".zip"): - with tempfile.TemporaryDirectory() as tempdir, zipfile.ZipFile(file_name, "r") as zf: - zf.extractall(tempdir) - return any( - # Process the extracted files - process_lines( - json.load(open(os.path.join(tempdir, filename))) - if account_name - else open(os.path.join(tempdir, filename)) - ) - # Search all files in the zip file - for filename in os.listdir(tempdir) - if os.path.isfile(os.path.join(tempdir, filename)) - ) - return process_lines(self.get_all_strings_line(file_name)) + if file_name.endswith(".gz"): + try: + with gzip.open(file_name, "rt", encoding="utf-8", errors="ignore") as f: + if account_name: + data = json.load(f) + return process_lines(data) + else: + return process_lines(f) + except Exception: + logging.exception(f"Error processing gzip file: {file_name}") + return False + elif file_name.endswith(".zip"): + matched_any = False + try: + with zipfile.ZipFile(file_name, "r") as zf: + for zip_info in zf.infolist(): + if zip_info.is_dir(): + continue + with zf.open(zip_info) as file_obj: + # Wrap the binary stream as text + with io.TextIOWrapper(file_obj, encoding="utf-8", errors="ignore") as 
f: + if account_name: + try: + data = json.load(f) + if process_lines(data): + matched_any = True + except Exception: + logging.exception(f"Error processing json in zip member: {zip_info.filename}") + else: + if process_lines(f): + matched_any = True + return matched_any + except Exception: + logging.exception(f"Error processing zip file: {file_name}") + return False + else: + try: + return process_lines(self.get_all_strings_line(file_name)) + except Exception: + logging.exception(f"Error processing file: {file_name}") + return False diff --git a/tests/test_unit.py b/tests/test_unit.py index d10bec3..c8a6574 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -401,3 +401,70 @@ def fake_download_to_filename(local_path: str) -> None: output = fake_out.getvalue().strip() self.assertIn("google target", output, "Should match the google target text in the downloaded content") + + @mock_aws + def test_list_files_returns_pre_filtered_files(self) -> None: + """ + Test that list_files() returns only the S3 objects that match + the specified filters (e.g. key substring and non‑empty content). 
+ """
+ bucket_name = "list-files-test-bucket"
+ # Create a fake S3 bucket
+ s3_resource = boto3.resource("s3", region_name="us-east-1")
+ s3_resource.create_bucket(Bucket=bucket_name)
+ s3_client = boto3.client("s3", region_name="us-east-1")
+
+ # Upload several objects:
+ # - Two objects that match
+ s3_client.put_object(Bucket=bucket_name, Key="log_file1.txt", Body=b"dummy content")
+ s3_client.put_object(Bucket=bucket_name, Key="log_file2.txt", Body=b"dummy content")
+ # One that doesn't match the key_contains filter
+ s3_client.put_object(Bucket=bucket_name, Key="not_a_thing.txt", Body=b"dummy content")
+ # One that doesn't match the file_size filter
+ s3_client.put_object(Bucket=bucket_name, Key="log_empty.txt", Body=b"")
+
+ # Call list_files() with the filters under test
+ cg = CloudGrep()
+ result = cg.list_files(
+ bucket=bucket_name,
+ account_name=None,
+ container_name=None,
+ google_bucket=None,
+ prefix="",
+ key_contains="log",
+ from_date=None,
+ end_date=None,
+ file_size=1000000 # 1 MB
+ )
+
+ # Assert only the matching files are returned
+ self.assertIn("s3", result)
+ expected_keys = {"log_file1.txt", "log_file2.txt"}
+ self.assertEqual(set(result["s3"]), expected_keys)
+
+ # Now search the contents of the files and assert they hit
+ for key in expected_keys:
+ with patch("sys.stdout", new=StringIO()) as fake_out:
+ cg.search(
+ bucket=bucket_name,
+ account_name=None,
+ container_name=None,
+ google_bucket=None,
+ query=["dummy content"],
+ file=None,
+ yara_file=None,
+ file_size=1000000,
+ prefix="",
+ key_contains=key,
+ from_date=None,
+ end_date=None,
+ hide_filenames=False,
+ log_type=None,
+ log_format=None,
+ log_properties=[],
+ profile=None,
+ json_output=False,
+ files=result, # Pass the pre-filtered files from list_files
+ )
+ output = fake_out.getvalue().strip()
+ self.assertIn("log_file1.txt", output)
\ No newline at end of file