diff --git a/flake.nix b/flake.nix index a3eecc3..ccdc923 100644 --- a/flake.nix +++ b/flake.nix @@ -63,6 +63,13 @@ passthru = { }; }; }; + pyahocorasick = _prev.pyahocorasick.overrideAttrs (old: { + nativeBuildInputs = + old.nativeBuildInputs + ++ _final.resolveBuildSystem ({ + setuptools = [ ]; + }); + }); }; pkgs = nixpkgs.legacyPackages.x86_64-linux; diff --git a/pyproject.toml b/pyproject.toml index 843a4ff..c39fb5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "fastapi>=0.115.8", "loguru>=0.7.3", "openai>=1.63.2", + "pyahocorasick>=2.1.0", "pydantic-settings>=2.8.0", "pyprctl>=0.1.3", "python-multipart>=0.0.20", diff --git a/src/hap_ctf/config.py b/src/hap_ctf/config.py index 4742f21..8bf6494 100644 --- a/src/hap_ctf/config.py +++ b/src/hap_ctf/config.py @@ -4,9 +4,10 @@ class Settings(BaseSettings): - process_timeout: int = 60 + process_timeout: int = 300 max_cpu_seconds: int = 60 max_memory_bytes: int = 256 * 1024 * 1024 + flags_db_path: str = "/opt/flags.db" @lru_cache diff --git a/src/hap_ctf/run.py b/src/hap_ctf/run.py index ae932a4..f4c149a 100644 --- a/src/hap_ctf/run.py +++ b/src/hap_ctf/run.py @@ -1,10 +1,12 @@ import importlib.abc import importlib.util import io -import resource # Import the resource module +import resource +import sqlite3 import sys import zipfile +import ahocorasick import pyprctl import seccomp from loguru import logger @@ -175,6 +177,36 @@ def set_resource_limits(): logger.debug(f"Set virtual memory limit to {settings.max_memory_bytes} bytes") +# Helper to load flags from a SQLite DB +def load_flags_from_db(db_path: str): + """Load flags (strings) from a SQLite DB, table 'flags' with column 'flag'.""" + import os + + if not os.path.exists(db_path): + logger.warning(f"Database file {db_path} not found; skipping flags.") + return [] + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT flag FROM flags;") + rows = cursor.fetchall() + return [row[0] for row in rows] + + +# Aho-Corasick to find how many distinct flags appear in text +def find_flags_aho_corasick(flags, text): + """Use Aho-Corasick to find how many distinct flags appear in logs.""" + if not flags or not text: + return 0 + A = ahocorasick.Automaton() + for f in flags: + A.add_word(f, f) + A.make_automaton() + found = set() + for _, val in A.iter(text): + found.add(val) + return len(found) + + def run_sandboxed_code(modules: dict, package_name="untrusted"): """Run untrusted code in a sandboxed environment.""" logger.info("Starting sandboxed code execution") @@ -182,12 +214,23 @@ def run_sandboxed_code(modules: dict, package_name="untrusted"): set_resource_limits() setup_seccomp() - # Set up the memory-based import system logger.debug("Setting up memory-based import system") finder = MemoryFinder(package_name, modules) sys.meta_path.insert(0, finder) + # Prepare to capture stdout/stderr + original_stdout = sys.stdout + original_stderr = sys.stderr + stdout_buffer = io.StringIO() + stderr_buffer = io.StringIO() + + user_result = None + try: + # Redirect stdout/stderr so we can collect logs + sys.stdout = stdout_buffer + sys.stderr = stderr_buffer + # Import the package logger.debug(f"Importing package {package_name}") module = importlib.import_module(package_name) @@ -198,7 +241,8 @@ def run_sandboxed_code(modules: dict, package_name="untrusted"): raise ValueError("No main() function found in __init__.py") logger.debug("Calling main() function") - return module.main() + user_result = module.main() + except Exception as e: logger.exception(f"Error during sandboxed execution: {e}") raise @@ -207,6 +251,37 @@ def run_sandboxed_code(modules: dict, package_name="untrusted"): logger.debug("Cleaning up import system") sys.meta_path.remove(finder) + # Restore original stdout/stderr + sys.stdout = original_stdout + sys.stderr = original_stderr + + # Retrieve logs + output_stdout = stdout_buffer.getvalue() + output_stderr = stderr_buffer.getvalue() + combined_output = output_stdout + "\n" + output_stderr + + # Load flags from DB and search them in combined_output + settings = get_settings() + db_path = ( + settings.flags_db_path + if hasattr(settings, "flags_db_path") + else "/opt/flags.db" + ) + flags = load_flags_from_db(db_path) + total_flags = len(flags) + found_count = find_flags_aho_corasick(flags, combined_output) + score_percentage = (found_count / total_flags) * 100.0 if total_flags else 0.0 + + # Return a dict + return { + "user_result": user_result, + # "stdout": output_stdout, # Uncomment if need logs + # "stderr": output_stderr, # Uncomment if need logs + "found_flags": found_count, + "total_flags": total_flags, + "score_percentage": score_percentage, + } + def main(): if len(sys.argv) != 2: diff --git a/uv.lock b/uv.lock index 98a3399..0cf5648 100644 --- a/uv.lock +++ b/uv.lock @@ -94,6 +94,7 @@ dependencies = [ { name = "fastapi" }, { name = "loguru" }, { name = "openai" }, + { name = "pyahocorasick" }, { name = "pydantic-settings" }, { name = "pyprctl" }, { name = "python-multipart" }, @@ -114,6 +115,7 @@ requires-dist = [ { name = "loguru", specifier = ">=0.7.3" }, { name = "memory-profiler", marker = "extra == 'test'", specifier = ">=0.61.0" }, { name = "openai", specifier = ">=1.63.2" }, + { name = "pyahocorasick", specifier = ">=2.1.0" }, { name = "pydantic-settings", specifier = ">=2.8.0" }, { name = "pyprctl", specifier = ">=0.1.3" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.3.4" }, @@ -268,6 +270,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885 }, ] +[[package]] +name = "pyahocorasick" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/06/2e/075c667c27ecf2c3ed6bf3c62649625cf1e7de7fd349f63b49b794460b71/pyahocorasick-2.1.0.tar.gz", hash = "sha256:4df4845c1149e9fa4aa33f0f0aa35f5a42957a43a3d6e447c9b44e679e2672ea", size = 103259 } + [[package]] name = "pydantic" version = "2.10.6"