Skip to content

Commit 3b31ec5

Browse files
authored
Apply multiple logic fixes and prevent recursive hooking conflicts
- Added an early return for zip slip detection. - Run _custom_chunked_genops on one chunk range at a time. - Fixed a bug in the classification fallback. - Replaced thread-local storage with a global lock to avoid hooking conflicts between threads. - Added error handling for file seeking.
2 parents 7b7b6b0 + c779f38 commit 3b31ec5

2 files changed

Lines changed: 57 additions & 63 deletions

File tree

lib/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@
8585
"cProfile",
8686
"cloudpickle.load",
8787
"cloudpickle.loads",
88+
"code.interact",
89+
"code.InteractiveConsole",
8890
"code.InteractiveInterpreter",
8991
"codecs.decode",
9092
"codeop.compile_command",
@@ -101,6 +103,7 @@
101103
"eval",
102104
"exec",
103105
"execfile",
106+
"fileinput",
104107
"get_type_hints",
105108
"gzip",
106109
"hashlib",
@@ -186,6 +189,7 @@
186189
"reconstruct",
187190
"scipy",
188191
"set",
192+
"shutil.disk_usage",
189193
"sklearn",
190194
"spacy",
191195
"str",

saferpickle.py

Lines changed: 53 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,15 @@ def _process_chunk_for_generate_ops(
247247
chunked_operands.add(str(operand))
248248
else:
249249
with open(pickle_data_source, "rb") as f:
250-
for _, operand in _custom_chunked_genops(f, chunk_range):
251-
if operand is None:
252-
continue
253-
chunked_operands.add(str(operand))
250+
f.seek(chunk_range[0])
251+
chunk_data = f.read(chunk_range[1] - chunk_range[0])
252+
with io.BytesIO(chunk_data) as memory_f:
253+
for _, operand in _custom_chunked_genops(
254+
memory_f, (0, len(chunk_data))
255+
):
256+
if operand is None:
257+
continue
258+
chunked_operands.add(str(operand))
254259
except StopIteration:
255260
pass
256261
return chunked_operands
@@ -565,10 +570,8 @@ def categorize_strings(
565570
elif classification == Classification.SUSPICIOUS:
566571
new_suspicious_results.add(result)
567572
elif classification == Classification.UNKNOWN:
568-
new_unknown_results.add(result)
569-
else:
570573
# Fallback: Check against original categories if
571-
# classify_class_name returns None.
574+
# classify_class_name returns UNKNOWN.
572575
if result in unsafe_results:
573576
new_unsafe_results.add(result)
574577
elif result in suspicious_results:
@@ -867,9 +870,12 @@ def _extract_and_scan_archive(
867870
for name in zf.namelist():
868871
if ".." in name or name.startswith("/"):
869872
# Zip slip detection
870-
all_scores["unsafe"] += 1337 # High score for zip slip
871873
logging.warning("Zip slip detected: %s", name)
872-
continue
874+
return {
875+
"unsafe": 1337,
876+
"suspicious": 0,
877+
"unknown": 0,
878+
} # Return early
873879

874880
with zf.open(name) as f:
875881
content = f.read()
@@ -894,9 +900,8 @@ def _extract_and_scan_archive(
894900
elif archive_type == "tar":
895901
for name, content in utils.extract_tar_contents(data):
896902
if ".." in name or name.startswith("/"):
897-
all_scores["unsafe"] += 1337
898903
logging.warning("Tar slip detected: %s", name)
899-
continue
904+
return {"unsafe": 1337, "suspicious": 0, "unknown": 0} # Return early
900905
scores = security_scan(content, recursion_depth=recursion_depth + 1)
901906
_merge_scores(all_scores, scores)
902907

@@ -996,7 +1001,8 @@ def _security_scan_internal(
9961001
os.remove(pickle_file_path)
9971002

9981003

999-
_thread_local_storage_for_hooking = threading.local()
1004+
_ORIG_METHODS_BEFORE_HOOKING = {}
1005+
_HOOKING_LOCK = threading.Lock()
10001006

10011007

10021008
def _report_or_raise(
@@ -1047,9 +1053,9 @@ def _scan_and_load(
10471053
raise TypeError("pickle_file_or_bytes must be IOBase when is_load=True")
10481054
pickle_file = pickle_file_or_bytes
10491055
data_bytes = pickle_file.read()
1050-
if hasattr(pickle_file, "seek"):
1056+
try:
10511057
pickle_file.seek(0)
1052-
else:
1058+
except (OSError, AttributeError, io.UnsupportedOperation):
10531059
pickle_file = io.BytesIO(data_bytes)
10541060
else:
10551061
if not isinstance(pickle_file_or_bytes, bytes):
@@ -1163,10 +1169,6 @@ def hook_pickle(
11631169
) -> None:
11641170
"""This implements the hooking of pickle-like libraries."""
11651171
config.set_config_path(config_path)
1166-
if not hasattr(
1167-
_thread_local_storage_for_hooking, "orig_methods_before_hooking"
1168-
):
1169-
_thread_local_storage_for_hooking.orig_methods_before_hooking = {}
11701172

11711173
def custom_loads(
11721174
pickle_bytes: bytes,
@@ -1295,32 +1297,27 @@ def custom_load(
12951297
logging.debug("Failed to import %s", hookable_mod)
12961298
continue
12971299

1298-
if (
1299-
hookable_mod
1300-
not in _thread_local_storage_for_hooking.orig_methods_before_hooking
1301-
):
1302-
_thread_local_storage_for_hooking.orig_methods_before_hooking[
1303-
hookable_mod
1304-
] = {}
1305-
1306-
methods_to_patch = {
1307-
"load": functools.partial(custom_load, hooked_mod_name=hookable_mod),
1308-
"_load": functools.partial(custom_load, hooked_mod_name=hookable_mod),
1309-
"loads": functools.partial(custom_loads, hooked_mod_name=hookable_mod),
1310-
"_loads": functools.partial(custom_loads, hooked_mod_name=hookable_mod),
1311-
}
1312-
for method_name, custom_func in methods_to_patch.items():
1313-
if hasattr(module, method_name):
1314-
if (
1315-
method_name
1316-
not in _thread_local_storage_for_hooking.orig_methods_before_hooking[
1317-
hookable_mod
1318-
]
1319-
):
1320-
_thread_local_storage_for_hooking.orig_methods_before_hooking[
1321-
hookable_mod
1322-
][method_name] = getattr(module, method_name)
1323-
setattr(module, method_name, custom_func)
1300+
with _HOOKING_LOCK:
1301+
if hookable_mod not in _ORIG_METHODS_BEFORE_HOOKING:
1302+
_ORIG_METHODS_BEFORE_HOOKING[hookable_mod] = {}
1303+
1304+
methods_to_patch = {
1305+
"load": functools.partial(custom_load, hooked_mod_name=hookable_mod),
1306+
"_load": functools.partial(custom_load, hooked_mod_name=hookable_mod),
1307+
"loads": functools.partial(
1308+
custom_loads, hooked_mod_name=hookable_mod
1309+
),
1310+
"_loads": functools.partial(
1311+
custom_loads, hooked_mod_name=hookable_mod
1312+
),
1313+
}
1314+
for method_name, custom_func in methods_to_patch.items():
1315+
if hasattr(module, method_name):
1316+
if method_name not in _ORIG_METHODS_BEFORE_HOOKING[hookable_mod]:
1317+
_ORIG_METHODS_BEFORE_HOOKING[hookable_mod][method_name] = getattr(
1318+
module, method_name
1319+
)
1320+
setattr(module, method_name, custom_func)
13241321

13251322

13261323
@contextlib.contextmanager
@@ -1348,25 +1345,18 @@ def hook_pickle_libs(
13481345

13491346
def unhook_pickle() -> None:
13501347
"""Unhooks the pickle-like libraries."""
1351-
if not hasattr(
1352-
_thread_local_storage_for_hooking, "orig_methods_before_hooking"
1353-
):
1354-
return
1355-
for (
1356-
module_name,
1357-
methods,
1358-
) in _thread_local_storage_for_hooking.orig_methods_before_hooking.items():
1359-
try:
1360-
module = importlib.import_module(module_name)
1361-
for method_name, original_method in methods.items():
1362-
if hasattr(module, method_name):
1363-
setattr(module, method_name, original_method)
1364-
except (ImportError, ModuleNotFoundError):
1365-
logging.debug("Failed to import %s for unhooking", module_name)
1366-
continue
1367-
# Empty stored methods to avoid re-unhooking on a second unhook call
1368-
# items() would be empty after clearing the dictionary
1369-
_thread_local_storage_for_hooking.orig_methods_before_hooking.clear()
1348+
with _HOOKING_LOCK:
1349+
for module_name, methods in _ORIG_METHODS_BEFORE_HOOKING.items():
1350+
try:
1351+
module = importlib.import_module(module_name)
1352+
for method_name, original_method in methods.items():
1353+
if hasattr(module, method_name):
1354+
setattr(module, method_name, original_method)
1355+
except (ImportError, ModuleNotFoundError):
1356+
logging.debug("Failed to import %s for unhooking", module_name)
1357+
continue
1358+
# Empty stored methods to avoid re-unhooking on a second unhook call
1359+
_ORIG_METHODS_BEFORE_HOOKING.clear()
13701360

13711361

13721362
# To avoid creating __pycache__ files

0 commit comments

Comments
 (0)