From 2b7e13c5e05efd40e295a24e97d2a3d761d86d7d Mon Sep 17 00:00:00 2001 From: lb1192176991-lab Date: Fri, 5 Jun 2026 00:39:45 +0800 Subject: [PATCH] feat: concurrent wallet scanning with progress bar and ThreadPoolExecutor (Fixes #29) --- scripts/rugguard.py | 144 ++++++++++++++++++++++++++++++++++++++++++- tests/test_checks.py | 28 ++++++++- 2 files changed, 170 insertions(+), 2 deletions(-) diff --git a/scripts/rugguard.py b/scripts/rugguard.py index dcd51cc..6821f6a 100644 --- a/scripts/rugguard.py +++ b/scripts/rugguard.py @@ -1282,6 +1282,141 @@ def rug_check_wallet(address: str) -> dict: # ── Watch / History / Webhook Support ────────────────────────────────────── +# Concurrent wallet scanning +WALLET_SCAN_WORKERS = int(os.environ.get("WALLET_SCAN_WORKERS", "4")) + + +def _rug_check_token_safe(mint: str) -> tuple[str, float, str, list[str]]: + """Thread-safe wrapper for rug_check_token. Returns (mint, score, level, warnings).""" + try: + report = rug_check_token(mint) + return (mint, report.safety_score, report.risk_level, report.warnings) + except Exception as e: + return (mint, 0, "ERROR", [str(e)[:80]]) + + +def rug_check_wallet_concurrent(address: str) -> dict: + """Scan a wallet using concurrent workers with a progress bar. + + Falls back to synchronous mode when WALLET_SCAN_WORKERS=1. + """ + from concurrent.futures import ThreadPoolExecutor, as_completed + import time as _time + + result = _rpc_call("getTokenAccountsByOwner", [ + address, + {"programId": "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA"}, + {"encoding": "jsonParsed"}, + ]) + if not result or "value" not in result: + return { + "address": address, + "error": "Could not fetch wallet token accounts", + "total_tokens": 0, + "risky_tokens": [], + "summary": "Wallet scan failed — RPC error.", + } + + tokens = result["value"] + total = len(tokens) + is_tty = sys.stdout.isatty() + workers = max(1, min(WALLET_SCAN_WORKERS, 10)) + + # Collect mints that have meaningful balance + mints_to_scan = [] + for token_acct in tokens: + acct_data = token_acct.get("account", {}).get("data", {}).get("parsed", {}).get("info", {}) + mint = acct_data.get("mint", "") + amount = int(acct_data.get("tokenAmount", {}).get("amount", "0")) + decimals = acct_data.get("tokenAmount", {}).get("decimals", 0) + if amount == 0: + continue + if decimals > 0 and amount > 10 ** (decimals - 4): + mints_to_scan.append(mint) + + scan_total = len(mints_to_scan) + risky_tokens = [] + completed = 0 + start_time = _time.time() + + def _progress(n: int, risks: int, done: bool = False) -> None: + if not is_tty: + return + elapsed = _time.time() - start_time + pct = n / scan_total * 100 if scan_total > 0 else 100 + bar_w = 20 + fill = int(bar_w * n / scan_total) if scan_total > 0 else bar_w + bar = "\u2588" * fill + "\u2591" * (bar_w - fill) + if n > 0 and elapsed > 0: + eta_secs = (elapsed / n) * (scan_total - n) + eta = f"{int(eta_secs // 60)}m{int(eta_secs % 60)}s" if eta_secs > 60 else f"{int(eta_secs)}s" + else: + eta = "?" + if done: + end = "\n" + else: + end = "" + print(f"\r [{bar}] {n}/{scan_total} ({pct:.0f}%) | {risks} risks | {int(elapsed)}s elapsed ETA:{eta}", end=end, flush=True) + + _progress(0, 0) + + if workers <= 1 or scan_total <= 2: + # Sequential mode + for mint in mints_to_scan: + score, level, warnings = 0, "UNKNOWN", [] + try: + report = rug_check_token(mint.strip()) + score = report.safety_score + level = report.risk_level + warnings = report.warnings + except Exception: + pass + completed += 1 + if score < 60: + risky_tokens.append({ + "mint": mint, + "safety_score": score, + "risk_level": level, + "top_warnings": warnings[:3], + }) + _progress(completed, len(risky_tokens)) + else: + # Parallel mode + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = {executor.submit(_rug_check_token_safe, m): m for m in mints_to_scan} + for future in as_completed(futures): + mint, score, level, warnings = future.result() + completed += 1 + if score < 60: + risky_tokens.append({ + "mint": mint, + "safety_score": score, + "risk_level": level, + "top_warnings": warnings[:3], + }) + _progress(completed, len(risky_tokens)) + + _progress(completed, len(risky_tokens), done=True) + + risky_tokens.sort(key=lambda t: t["safety_score"]) + total_tok = len(tokens) + risky_count = len(risky_tokens) + + if risky_count == 0: + summary = f"Concurrent scan: no high-risk tokens detected among {total_tok} accounts." + else: + summary = f"Concurrent scan: {risky_count} risky / {total_tok} total accounts." + + return { + "address": address, + "total_tokens": total_tok, + "risky_count": risky_count, + "risky_tokens": risky_tokens, + "summary": summary, + "scan_time_seconds": round(_time.time() - start_time, 1), + } + + DEFAULT_HISTORY_DB = os.environ.get("SOLANA_RUG_HISTORY_DB", os.path.expanduser("~/.solana-rug/history.sqlite3")) HISTORY_RETENTION_DAYS = int(os.environ.get("SOLANA_RUG_HISTORY_RETENTION_DAYS", "90")) WEBHOOK_COOLDOWN_SECONDS = 3600 # at most 1 alert per token per hour @@ -1655,7 +1790,11 @@ def cli_wallet(args: list[str]) -> None: print('Usage: python rugguard.py wallet
', file=sys.stderr) sys.exit(1) - result = rug_check_wallet(address.strip()) + use_concurrent = "--sequential" not in args + if use_concurrent: + result = rug_check_wallet_concurrent(address.strip()) + else: + result = rug_check_wallet(address.strip()) print(json.dumps(result, indent=2, default=str)) if result.get("risky_count", 0) > 0: @@ -1673,6 +1812,7 @@ def cli_help() -> None: --json Output as JSON (default for token) --markdown Output as Markdown report --md Alias for --markdown + --sequential Disable concurrent wallet scanning (uses original sequential mode) WATCH OPTIONS: --interval Seconds between checks (default: 60) @@ -1689,6 +1829,8 @@ def cli_help() -> None: ENVIRONMENT: SOLANA_RPC_URL Override RPC endpoint (default: api.mainnet-beta.solana.com) + WALLET_SCAN_WORKERS + Concurrent wallet scan workers (default: 4, max: 10) EXIT CODES: 0 No critical risks detected diff --git a/tests/test_checks.py b/tests/test_checks.py index 27ac3ce..0e31799 100644 --- a/tests/test_checks.py +++ b/tests/test_checks.py @@ -450,4 +450,30 @@ def test_wallet_scan() -> None: assert "address" in result assert result["address"] == TEST_WALLET assert "total_tokens" in result - assert isinstance(result["total_tokens"], int) + assert isinstance(result['total_tokens'], int) + + +# ── Concurrent Scan Tests ───────────────────────────────────────────────── + +class TestConcurrentWallet: + def test_rug_check_token_safe(self): + from rugguard import _rug_check_token_safe + import sys + sys.path.insert(0, str(Path(__file__).resolve().parent.parent / 'scripts')) + # Test with invalid mint — should not crash + mint, score, level, warnings = _rug_check_token_safe("InvalidMint123") + assert isinstance(mint, str) + assert isinstance(score, (int, float)) + assert isinstance(level, str) + assert isinstance(warnings, list) + + def test_sequential_fallback(self): + """WALLET_SCAN_WORKERS=1 should fall back to sequential mode.""" + import os + os.environ['WALLET_SCAN_WORKERS'] = '1' + # Import after setting env + from rugguard import rug_check_wallet_concurrent + # Should complete without error (RPC call may fail gracefully) + result = rug_check_wallet_concurrent("9WzDXwBbmkg8ZTbNMqUxvQRAyrZzDsGYdLVL9zYtAWWM") + assert 'address' in result + assert isinstance(result.get('total_tokens', 0), int)