diff --git a/lib/utils.py b/lib/utils.py index 02343b1..7189722 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -20,6 +20,7 @@ import functools import gzip import importlib +import importlib.util import inspect import io import lzma @@ -336,18 +337,16 @@ def get_module_members(module_name: str) -> Set[str] | None: return None return {member for member, _ in inspect.getmembers(imported_module)} - spec = importlib.util.find_spec(module_name) - if not spec: + if (spec := importlib.util.find_spec(module_name)) is None: return None - - if not os.path.exists(spec.origin): + if (origin := spec.origin) is None: return None - with open(spec.origin, "r") as f: - try: - tree = ast.parse(f.read(), filename=spec.origin) - except (ValueError, SyntaxError): - return None + try: + with open(origin, "r") as f: + tree = ast.parse(f.read(), filename=origin) + except (IOError, ValueError, SyntaxError, RecursionError): + return None members = set() # This generates the list of methods and classes in the module from diff --git a/saferpickle.py b/saferpickle.py index 27e0d58..f6f321d 100644 --- a/saferpickle.py +++ b/saferpickle.py @@ -22,7 +22,6 @@ import io import lzma import math -import multiprocessing from multiprocessing import shared_memory import os import pickle @@ -36,11 +35,13 @@ import zipfile from absl import logging +from third_party.corrupy import picklemagic + +import multiprocessing from lib import config from lib import constants from lib import exceptions from lib import utils -from third_party.corrupy import picklemagic IllegalArgumentCombinationError = exceptions.IllegalArgumentCombinationError @@ -641,27 +642,41 @@ def strict_security_scan(pickle_bytes: bytes) -> bool: def is_unsafe( - safe_score: int, - unsafe_score: int, - suspicious_score: int, + number_of_safe_results: int, + number_of_unsafe_results: int, + number_of_suspicious_results: int, ) -> bool: """Conditional check for safeness. Args: - safe_score: Safe score from the security scan. 
- unsafe_score: Unsafe score from the security scan. - suspicious_score: Suspicious score from the security scan. + number_of_safe_results: Number of safe results from the security scan. + number_of_unsafe_results: Number of unsafe results from the security scan. + number_of_suspicious_results: Number of suspicious results from the security + scan. Returns: True if the pickle file is dangerous, False otherwise. """ - if unsafe_score == 0 and suspicious_score == 0: + if number_of_unsafe_results == 0 and number_of_suspicious_results == 0: return False - # 0.5 was chosen as the weight for suspicious results - # to guard against unknown results filtering into suspicious results. - sum_of_unsafe_and_suspicious_scores = unsafe_score + 0.5 * suspicious_score - return sum_of_unsafe_and_suspicious_scores >= safe_score + # Suspicious results are halved in the fallback sum below to lower false + # positives from greedy matches of unknown method-like strings (e.g. "google.com"). + if ( + number_of_suspicious_results + number_of_unsafe_results + >= number_of_safe_results + ): + return True + + sum_of_unsafe_and_suspicious_results = ( + number_of_unsafe_results + 0.5 * number_of_suspicious_results + ) + + unsafe = (sum_of_unsafe_and_suspicious_results > number_of_safe_results) or ( + number_of_safe_results == 0 and sum_of_unsafe_and_suspicious_results >= 1 + ) + + return unsafe def picklemagic_scan( @@ -786,10 +801,10 @@ def apply_approach( logging.info(" Unknown results: %s\n", results.unknown_results) ( - safe_score, - unsafe_score, - suspicious_score, - unknown_score, + number_of_safe_results, + number_of_unsafe_results, + number_of_suspicious_results, + number_of_unknown_results, ) = score_results( results.safe_results, results.unsafe_results, @@ -797,14 +812,14 @@ results.suspicious_results, results.unknown_results, ) scores = { - "unsafe": unsafe_score, - "suspicious": suspicious_score, - "unknown": unknown_score, + "unsafe": number_of_unsafe_results, + "suspicious": 
number_of_suspicious_results, + "unknown": number_of_unknown_results, } if results.is_denylisted or is_unsafe( - safe_score, - unsafe_score, - suspicious_score, + number_of_safe_results, + number_of_unsafe_results, + number_of_suspicious_results, ): return scores