class Cloud:
    """Download and filter objects from AWS S3, Azure Blob Storage and Google Cloud Storage.

    Each ``download_from_*`` method fans out over a thread pool, downloads every
    requested object to a temporary file, searches it with ``Search.search_file``,
    and returns the number of objects that produced at least one match.
    """

    def __init__(self) -> None:
        # One shared Search instance reused by every worker thread;
        # Search.search_file holds no per-call state on self.
        self.search = Search()

    def _sum_parallel(self, worker: Any, files: List[str]) -> int:
        """Run worker(key) -> int over every key in a thread pool; return the sum.

        Shared by all three providers so the fan-out/collect logic lives in one place.
        """
        total_matched = 0
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(worker, key) for key in files]
            for future in concurrent.futures.as_completed(futures):
                total_matched += future.result()
        return total_matched

    def download_from_s3_multithread(
        self,
        bucket: str,
        files: List[str],
        query: List[str],
        hide_filenames: bool,
        yara_rules: Any,
        log_format: Optional[str] = None,
        log_properties: Optional[List[str]] = None,
        json_output: bool = False,
    ) -> int:
        """Use a thread pool and boto3 to download every file in ``files`` from S3.

        Returns the number of matched files.
        """
        if not log_properties:
            log_properties = []
        # 64 pooled connections so the thread pool is not starved of sockets.
        s3 = boto3.client("s3", config=botocore.config.Config(max_pool_connections=64))

        def _download_search_s3(key: str) -> int:
            # delete=False + explicit os.remove below: on Windows an open
            # NamedTemporaryFile cannot be re-opened by the downloader.
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                tmp_name = tmp.name
            try:
                logging.info(f"Downloading s3://{bucket}/{key} to {tmp_name}")
                s3.download_file(bucket, key, tmp_name)
                matched = self.search.search_file(
                    tmp_name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output
                )
                return 1 if matched else 0
            except Exception as e:
                # Best effort: one bad object must not abort the whole run.
                logging.error(f"Error downloading or searching {key}: {e}")
                return 0
            finally:
                try:
                    # Cleanup the temp file whether or not the search succeeded.
                    os.remove(tmp_name)
                except OSError:
                    pass

        return self._sum_parallel(_download_search_s3, files)

    def download_from_azure(
        self,
        account_name: str,
        container_name: str,
        files: List[str],
        query: List[str],
        hide_filenames: bool,
        yara_rules: Any,
        log_format: Optional[str] = None,
        log_properties: Optional[List[str]] = None,
        json_output: bool = False,
    ) -> int:
        """Download every file in the container from Azure and search it.

        Returns the number of matched files.
        """
        if not log_properties:
            log_properties = []
        default_credential = DefaultAzureCredential()
        blob_service_client = BlobServiceClient.from_connection_string(
            f"DefaultEndpointsProtocol=https;AccountName={account_name};EndpointSuffix=core.windows.net",
            credential=default_credential,
        )
        container_client = blob_service_client.get_container_client(container_name)

        def _download_search_azure(key: str) -> int:
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                tmp_name = tmp.name
            try:
                logging.info(f"Downloading azure://{account_name}/{container_name}/{key} to {tmp_name}")
                blob_client = container_client.get_blob_client(key)
                with open(tmp_name, "wb") as my_blob:
                    blob_client.download_blob().readinto(my_blob)
                matched = self.search.search_file(
                    tmp_name,
                    key,
                    query,
                    hide_filenames,
                    yara_rules,
                    log_format,
                    log_properties,
                    json_output,
                    account_name,
                )
                return 1 if matched else 0
            except ResourceNotFoundError:
                # A listed blob can vanish between listing and download.
                logging.info(f"File {key} not found in {account_name}/{container_name}")
                return 0
            except Exception as e:
                logging.error(f"Error downloading or searching {key}: {e}")
                return 0
            finally:
                try:
                    # module-level `import os` suffices; no local re-import needed.
                    os.remove(tmp_name)
                except OSError:
                    pass

        return self._sum_parallel(_download_search_azure, files)

    def download_from_google(
        self,
        bucket: str,
        files: List[str],
        query: List[str],
        hide_filenames: bool,
        yara_rules: Any,
        log_format: Optional[str] = None,
        log_properties: Optional[List[str]] = None,
        json_output: bool = False,
    ) -> int:
        """Download every file in the bucket from Google Cloud Storage and search it.

        Returns the number of matched files.
        """
        if not log_properties:
            log_properties = []
        client = storage.Client()
        bucket_gcp = client.get_bucket(bucket)

        def _download_search_google(key: str) -> int:
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                tmp_name = tmp.name
            try:
                logging.info(f"Downloading gs://{bucket}/{key} to {tmp_name}")
                blob = bucket_gcp.get_blob(key)
                if blob is None:
                    logging.warning(f"Blob not found: {key}")
                    return 0
                blob.download_to_filename(tmp_name)
                matched = self.search.search_file(
                    tmp_name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output
                )
                return 1 if matched else 0
            except Exception as e:
                logging.error(f"Error downloading or searching {key}: {e}")
                return 0
            finally:
                try:
                    os.remove(tmp_name)
                except OSError:
                    pass

        return self._sum_parallel(_download_search_google, files)

    def filter_object(
        self,
        obj: dict,
        key_contains: Optional[str],
        from_date: Optional[datetime],
        to_date: Optional[datetime],
        file_size: int,
    ) -> bool:
        """Filter S3 objects by date range, file size, and substring in key."""
        last_modified = obj["LastModified"]
        if last_modified and from_date and from_date > last_modified:
            return False  # Object was modified before the from_date
        if last_modified and to_date and last_modified > to_date:
            return False  # Object was modified after the to_date
        if obj["Size"] == 0 or int(obj["Size"]) > file_size:
            return False  # Object is empty or too large
        if key_contains and key_contains not in obj["Key"]:
            return False  # Object does not contain the key_contains string
        return True

    def filter_object_azure(
        self,
        obj: Any,
        key_contains: Optional[str],
        from_date: Optional[datetime],
        to_date: Optional[datetime],
        file_size: int,
    ) -> bool:
        """Filter Azure blob objects by date range, file size, and substring in name."""
        last_modified = obj["last_modified"]  # type: ignore
        if last_modified and from_date and from_date > last_modified:
            return False  # Modified before the from_date
        if last_modified and to_date and last_modified > to_date:
            return False  # Modified after the to_date
        if obj["size"] == 0 or int(obj["size"]) > file_size:
            return False  # Empty or too large
        if key_contains and key_contains not in obj["name"]:
            return False  # Name does not contain the key_contains string
        return True

    def filter_object_google(
        self,
        obj: Any,
        key_contains: Optional[str],
        from_date: Optional[datetime],
        to_date: Optional[datetime],
    ) -> bool:
        """Filter objects in GCP by date range and substring in name.

        NOTE(review): only the date checks were visible in the diff context; the
        key_contains check is reconstructed by analogy with the S3/Azure filters
        and get_google_objects' parameters — confirm against the original file.
        """
        last_modified = obj.updated
        if last_modified and from_date and from_date > last_modified:
            return False
        if last_modified and to_date and last_modified > to_date:
            return False
        if key_contains and key_contains not in obj.name:
            return False
        return True

    def get_objects(
        self,
        bucket: str,
        prefix: Optional[str],
        key_contains: Optional[str],
        from_date: Optional[datetime],
        end_date: Optional[datetime],
        file_size: int,
    ) -> Iterator[str]:
        """Yield keys of S3 objects under prefix that pass filter_object."""
        s3 = boto3.client("s3")
        paginator = s3.get_paginator("list_objects_v2")
        page_iterator = paginator.paginate(Bucket=bucket, Prefix=prefix)
        # NOTE(review): the interior of this loop was hidden diff context;
        # reconstructed from the filter_object call pattern — confirm.
        for page in page_iterator:
            for obj in page.get("Contents", []):
                if self.filter_object(obj, key_contains, from_date, end_date, file_size):
                    yield obj["Key"]

    def get_azure_objects(
        self,
        account_name: str,
        container_name: str,
        prefix: Optional[str],
        key_contains: Optional[str],
        from_date: Optional[datetime],
        end_date: Optional[datetime],
        file_size: int,
    ) -> Iterator[str]:
        """Get all objects in Azure storage container with a given prefix."""
        default_credential = DefaultAzureCredential()
        blob_service_client = BlobServiceClient.from_connection_string(
            f"DefaultEndpointsProtocol=https;AccountName={account_name};EndpointSuffix=core.windows.net",
            credential=default_credential,
        )
        container_client = blob_service_client.get_container_client(container_name)
        blobs = container_client.list_blobs(name_starts_with=prefix)
        for blob in blobs:
            if self.filter_object_azure(blob, key_contains, from_date, end_date, file_size):
                yield blob.name

    def get_google_objects(
        self,
        bucket: str,
        prefix: Optional[str],
        key_contains: Optional[str],
        from_date: Optional[datetime],
        end_date: Optional[datetime],
    ) -> Iterator[str]:
        """Get all objects in a GCP bucket with a given prefix."""
        client = storage.Client()
        bucket_gcp = client.get_bucket(bucket)
        blobs = bucket_gcp.list_blobs(prefix=prefix)
        for blob in blobs:
            if self.filter_object_google(blob, key_contains, from_date, end_date):
                yield blob.name
class CloudGrep:
    """Entry point that wires the CLI arguments to the Cloud download/search layer."""

    def __init__(self) -> None:
        # Single Cloud facade reused by all provider-specific helpers.
        self.cloud = Cloud()

    def load_queries(self, file: str) -> List[str]:
        """Load a list of queries from a file, one per line, skipping blank lines."""
        with open(file, "r") as f:
            return [line.strip() for line in f if line.strip()]

    def search(
        self,
        query: List[str],
        file: Optional[str],
        bucket: Optional[str],
        account_name: Optional[str],
        container_name: Optional[str],
        google_bucket: Optional[str],
        yara_file: Optional[str],
        file_size: int,
        prefix: str = "",
        key_contains: Optional[str] = None,
        from_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        hide_filenames: bool = False,
        log_type: Optional[str] = None,
        log_format: Optional[str] = None,
        log_properties: Optional[List[str]] = None,
        profile: Optional[str] = None,
        json_output: bool = False,
    ) -> None:
        """Search query/queries across the configured cloud storage providers.

        NOTE(review): the parameters before ``hide_filenames`` were hidden diff
        context; order mirrors the upstream signature — confirm against callers.
        NOTE(review): the patch dropped dateutil string parsing, so ``from_date``/
        ``end_date`` are now expected to already be datetime objects — verify the
        CLI layer converts them before calling.
        """
        # Load queries from a file only when no direct query was given.
        if not query and file:
            logging.debug(f"Loading queries in from {file}")
            query = self.load_queries(file)

        # Compile optional Yara rules once, up front.
        yara_rules = None
        if yara_file:
            logging.debug(f"Loading yara rules from {yara_file}")
            yara_rules = yara.compile(filepath=yara_file)

        if profile:
            # Set the AWS credentials profile to use
            boto3.setup_default_session(profile_name=profile)

        # A known log_type implies both the format and the record property path.
        if log_type is not None:
            if log_type == "cloudtrail":
                log_format = "json"
                log_properties = ["Records"]
            elif log_type == "azure":
                log_format = "json"
                log_properties = ["data"]
            else:
                logging.error(f"Invalid log_type: '{log_type}'")
                return

        if log_properties is None:
            log_properties = []  # default

        # Search each configured cloud storage backend.
        if bucket:
            self._search_s3(
                bucket=bucket,
                query=query,
                yara_rules=yara_rules,
                file_size=file_size,
                prefix=prefix,
                key_contains=key_contains,
                from_date=from_date,
                end_date=end_date,
                hide_filenames=hide_filenames,
                log_format=log_format,
                log_properties=log_properties,
                json_output=json_output,
            )
        if account_name and container_name:
            self._search_azure(
                account_name=account_name,
                container_name=container_name,
                query=query,
                yara_rules=yara_rules,
                file_size=file_size,
                prefix=prefix,
                key_contains=key_contains,
                from_date=from_date,
                end_date=end_date,
                hide_filenames=hide_filenames,
                log_format=log_format,
                log_properties=log_properties,
                json_output=json_output,
            )
        if google_bucket:
            self._search_gcs(
                google_bucket=google_bucket,
                query=query,
                yara_rules=yara_rules,
                file_size=file_size,
                prefix=prefix,
                key_contains=key_contains,
                from_date=from_date,
                end_date=end_date,
                hide_filenames=hide_filenames,
                log_format=log_format,
                log_properties=log_properties,
                json_output=json_output,
            )

    def _search_s3(
        self,
        bucket: str,
        query: List[str],
        yara_rules: Any,
        file_size: int,
        prefix: Optional[str],
        key_contains: Optional[str],
        from_date: Optional[datetime],
        end_date: Optional[datetime],
        hide_filenames: bool,
        log_format: Optional[str],
        log_properties: List[str],
        json_output: bool,
    ) -> None:
        """List matching objects in an S3 bucket, then download and search them."""
        matching_keys = list(self.cloud.get_objects(bucket, prefix, key_contains, from_date, end_date, file_size))
        s3_client = boto3.client("s3")
        region = s3_client.get_bucket_location(Bucket=bucket)
        # get_bucket_location returns {'LocationConstraint': None} for us-east-1,
        # so dict.get's default never fires — fall back explicitly instead.
        location = region.get("LocationConstraint") or "us-east-1"
        logging.warning(f"Bucket is in region: {location} : Search from the same region to avoid egress charges.")
        logging.warning(f"Searching {len(matching_keys)} files in {bucket} for {query}...")
        self.cloud.download_from_s3_multithread(
            bucket, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output
        )

    def _search_azure(
        self,
        account_name: str,
        container_name: str,
        query: List[str],
        yara_rules: Any,
        file_size: int,
        prefix: Optional[str],
        key_contains: Optional[str],
        from_date: Optional[datetime],
        end_date: Optional[datetime],
        hide_filenames: bool,
        log_format: Optional[str],
        log_properties: List[str],
        json_output: bool,
    ) -> None:
        """List matching blobs in an Azure container, then download and search them."""
        matching_keys = list(
            self.cloud.get_azure_objects(
                account_name, container_name, prefix, key_contains, from_date, end_date, file_size
            )
        )
        print(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...")
        self.cloud.download_from_azure(
            account_name,
            container_name,
            matching_keys,
            query,
            hide_filenames,
            yara_rules,
            log_format,
            log_properties,
            json_output,
        )

    def _search_gcs(
        self,
        google_bucket: str,
        query: List[str],
        yara_rules: Any,
        file_size: int,
        prefix: Optional[str],
        key_contains: Optional[str],
        from_date: Optional[datetime],
        end_date: Optional[datetime],
        hide_filenames: bool,
        log_format: Optional[str],
        log_properties: List[str],
        json_output: bool,
    ) -> None:
        """List matching blobs in a GCS bucket, then download and search them.

        Note: GCS listing applies no file_size filter; the parameter is accepted
        for signature parity with the other providers.
        """
        matching_keys = list(self.cloud.get_google_objects(google_bucket, prefix, key_contains, from_date, end_date))
        print(f"Searching {len(matching_keys)} files in {google_bucket} for {query}...")
        self.cloud.download_from_google(
            google_bucket,
            matching_keys,
            query,
            hide_filenames,
            yara_rules,
            log_format,
            log_properties,
            json_output,
        )
class Search:
    """Search downloaded files (plain, gzip, zip) for regex queries or Yara rules."""

    def get_all_strings_line(self, file_path: str) -> List[str]:
        """Get all the strings from a file line by line.

        Reads raw bytes and decodes permissively so binary files are supported,
        which is why this is used instead of f.readlines().
        """
        with open(file_path, "rb") as f:
            read_bytes = f.read()
        return read_bytes.decode("utf-8", "ignore").replace("\n", "\r").split("\r")

    def print_match(self, matched_line_dict: dict, hide_filenames: bool, json_output: Optional[bool]) -> None:
        """Print one matched line, either as a JSON document or as plain text."""
        if json_output:
            if hide_filenames:
                matched_line_dict.pop("key_name", None)
            try:
                print(json.dumps(matched_line_dict))
            except TypeError:
                # Some payloads (e.g. Yara match strings as bytes) are not JSON
                # serializable; fall back to the repr rather than dropping them.
                print(str(matched_line_dict))
        else:
            line = matched_line_dict.get("line", "")
            if "match_rule" in matched_line_dict:
                line = f"{matched_line_dict['match_rule']}: {matched_line_dict['match_strings']}"
            print(line if hide_filenames else f"{matched_line_dict['key_name']}: {line}")

    def parse_logs(self, line: str, log_format: Optional[str]) -> Any:
        """Parse one raw input line according to log_format ("json" or "csv").

        Returns the parsed structure, or None when the format is unknown or the
        line does not parse.
        """
        try:
            if log_format == "json":
                return json.loads(line)
            elif log_format == "csv":
                # DictReader treats the first row of the chunk as the header.
                return list(csv.DictReader([line]))
            elif log_format:
                logging.error(f"Invalid log format: {log_format}")
        except (json.JSONDecodeError, csv.Error) as e:
            logging.error(f"Invalid {log_format} format in line: {line} ({e})")
        return None

    def extract_log_entries(self, line_parsed: Any, log_properties: List[str]) -> List[Any]:
        """Step into nested log_properties and always return a list of records."""
        if log_properties:
            for log_property in log_properties:
                if isinstance(line_parsed, dict):
                    line_parsed = line_parsed.get(log_property, None)
        return line_parsed if isinstance(line_parsed, list) else [line_parsed]

    def search_logs(
        self,
        line: str,
        key_name: str,
        search: str,
        hide_filenames: bool,
        log_format: Optional[str] = None,
        log_properties: Optional[List[str]] = None,
        json_output: Optional[bool] = False,
    ) -> None:
        """Regex-search every record parsed out of one raw log line."""
        if log_properties is None:
            log_properties = []
        line_parsed = self.parse_logs(line, log_format)
        if not line_parsed:
            return
        for record in self.extract_log_entries(line_parsed, log_properties):
            if re.search(search, json.dumps(record)):
                self.print_match({"key_name": key_name, "query": search, "line": record}, hide_filenames, json_output)

    def search_line(
        self,
        key_name: str,
        search: List[str],
        hide_filenames: bool,
        line: str,
        log_format: Optional[str] = None,
        log_properties: Optional[List[str]] = None,
        json_output: Optional[bool] = False,
    ) -> bool:
        """Regex search of one line against every query; returns True on any match.

        Each matching pattern is reported individually (and search_logs is only
        invoked for patterns that actually matched), so multi-query output stays
        attributable to the query that produced it.
        """
        if log_properties is None:
            log_properties = []
        matched = False
        for cur_search in search:
            if re.search(cur_search, line):
                matched = True
                if log_format:
                    self.search_logs(
                        line, key_name, cur_search, hide_filenames, log_format, log_properties, json_output
                    )
                else:
                    self.print_match(
                        {"key_name": key_name, "query": cur_search, "line": line}, hide_filenames, json_output
                    )
        return matched

    def yara_scan_file(
        self, file_name: str, key_name: str, hide_filenames: bool, yara_rules: Any, json_output: Optional[bool] = False
    ) -> bool:
        """Run the compiled Yara rules against a file; print and count matches."""
        matches = yara_rules.match(file_name)
        for match in matches:
            self.print_match(
                {"key_name": key_name, "match_rule": match.rule, "match_strings": match.strings},
                hide_filenames,
                json_output,
            )
        return bool(matches)

    def search_file(
        self,
        file_name: str,
        key_name: str,
        search: List[str],
        hide_filenames: bool,
        yara_rules: Any,
        log_format: Optional[str] = None,
        log_properties: Optional[List[str]] = None,
        json_output: Optional[bool] = False,
        account_name: Optional[str] = None,
    ) -> bool:
        """Regex (or Yara) search of the file line by line; True if anything matched."""
        if log_properties is None:
            log_properties = []
        logging.info(f"Searching {file_name} for {search}")
        if yara_rules:
            return self.yara_scan_file(file_name, key_name, hide_filenames, yara_rules, json_output)

        def _search_lines(lines: Any, entry_key: str) -> bool:
            # Deliberately no short-circuit: every line is searched so that
            # ALL matches are printed, not just the first.
            matched = False
            for line in lines:
                if self.search_line(entry_key, search, hide_filenames, line, log_format, log_properties, json_output):
                    matched = True
            return matched

        def _json_records(f: Any) -> List[str]:
            # Azure downloads are a single JSON document; search each record as
            # its own serialized line. Non-JSON content is tolerated.
            try:
                return [json.dumps(record) for record in json.load(f)]
            except json.JSONDecodeError:
                logging.info(f"File {file_name} is not JSON")
                return []

        if key_name.endswith(".gz"):
            with gzip.open(file_name, "rt") as f:
                return _search_lines(_json_records(f) if account_name else f, key_name)

        if key_name.endswith(".zip"):
            matched = False
            with tempfile.TemporaryDirectory() as tempdir, zipfile.ZipFile(file_name, "r") as zf:
                zf.extractall(tempdir)
                logging.info(f"Extracted {file_name} to {tempdir}")
                for filename in os.listdir(tempdir):
                    full_path = os.path.join(tempdir, filename)
                    if not os.path.isfile(full_path):
                        continue
                    logging.info(f"Searching {filename} in zip")
                    # NOTE(review): the per-entry key suffix was garbled in the
                    # source; {key_name}/{filename} reconstructed — confirm.
                    with open(full_path) as f:
                        if account_name:
                            hit = _search_lines(_json_records(f), key_name)
                        else:
                            hit = _search_lines(f, f"{key_name}/{filename}")
                    matched = matched or hit
            return matched

        return _search_lines(self.get_all_strings_line(file_name), key_name)
f"{BASE_PATH}/data/000000.zip", "000000.zip", ["Running on machine"], False, None - ) + found = Search().search_file(f"{BASE_PATH}/data/000000.zip", "000000.zip", ["Running on machine"], False, None) self.assertTrue(found) def test_print_match(self) -> None: @@ -182,33 +179,25 @@ def test_search_cloudtrail(self) -> None: self.assertTrue(json.loads(output)) def test_filter_object_s3_empty_file(self) -> None: - obj = { - "LastModified": datetime(2023, 1, 1), - "Size": 0, - "Key": "empty_file.log" - } + obj = {"LastModified": datetime(2023, 1, 1), "Size": 0, "Key": "empty_file.log"} key_contains = "empty" from_date = datetime(2022, 1, 1) to_date = datetime(2024, 1, 1) file_size = 10000 self.assertFalse( Cloud().filter_object(obj, key_contains, from_date, to_date, file_size), - "Empty file should have been filtered out" + "Empty file should have been filtered out", ) def test_filter_object_s3_out_of_date_range(self) -> None: - obj = { - "LastModified": datetime(2021, 1, 1), - "Size": 500, - "Key": "old_file.log" - } + obj = {"LastModified": datetime(2021, 1, 1), "Size": 500, "Key": "old_file.log"} key_contains = "old" from_date = datetime(2022, 1, 1) to_date = datetime(2024, 1, 1) file_size = 10000 self.assertFalse( Cloud().filter_object(obj, key_contains, from_date, to_date, file_size), - "Object older than from_date should not match" + "Object older than from_date should not match", ) def test_search_logs_csv_format(self) -> None: @@ -223,7 +212,7 @@ def test_search_logs_csv_format(self) -> None: hide_filenames=False, log_format="csv", log_properties=[], - json_output=False + json_output=False, ) self.assertIn("val1", fake_out.getvalue()) @@ -238,7 +227,7 @@ def test_search_logs_unknown_format(self) -> None: hide_filenames=False, log_format="not_a_real_format", log_properties=[], - json_output=False + json_output=False, ) mock_log.assert_called_once() @@ -271,7 +260,7 @@ def test_cloudgrep_search_no_query_file(self) -> None: log_format=None, log_properties=[], 
profile=None, - json_output=False + json_output=False, ) output = fake_out.getvalue().strip() self.assertIn("hello direct query", output) @@ -305,12 +294,13 @@ def test_cloudgrep_search_with_profile(self) -> None: log_format=None, log_properties=[], profile="my_aws_profile", - json_output=False + json_output=False, ) mock_setup_session.assert_called_with(profile_name="my_aws_profile") def test_main_no_args_shows_help(self) -> None: from cloudgrep.__main__ import main + with patch.object(sys, "argv", ["prog"]): # Argparse prints help to sys.stderr with patch("sys.stderr", new=StringIO()) as fake_err: @@ -336,7 +326,7 @@ def test_azure_search_mocked(self, mock_service_client: MagicMock) -> None: # Actually written to a local file fake_content = b"Some Azure log entry that mentions azure target" - + def fake_readinto_me(file_obj: BinaryIO) -> None: file_obj.write(fake_content)