diff --git a/.jules/bolt.md b/.jules/bolt.md new file mode 100644 index 00000000..ff87cd3c --- /dev/null +++ b/.jules/bolt.md @@ -0,0 +1 @@ +## 2024-06-19 - AST Traversal Optimization\n**Learning:** The memory mentions that for hot-path AST traversal, using eager list-appending (`list.append()`) instead of `yield from` recursion is better for performance.\n**Action:** Replace `yield from` generator usage in `walk_node` and related core AST traversal functions with eager list population, maintaining `isinstance()` type checks for type safety. diff --git a/src/wardline/core/autofix.py b/src/wardline/core/autofix.py index f8660bd3..29734f1a 100644 --- a/src/wardline/core/autofix.py +++ b/src/wardline/core/autofix.py @@ -50,12 +50,16 @@ def has_comment_in_span( def _own_statements(node: ast.AST) -> Iterator[ast.stmt]: - for child in ast.iter_child_nodes(node): - if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + result: list[ast.stmt] = [] + stack: list[ast.AST] = list(reversed(list(ast.iter_child_nodes(node)))) + while stack: + current = stack.pop() + if isinstance(current, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): continue - if isinstance(child, ast.stmt): - yield child - yield from _own_statements(child) + if isinstance(current, ast.stmt): + result.append(current) + stack.extend(reversed(list(ast.iter_child_nodes(current)))) + return iter(result) def get_assert_nodes_for_function(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> list[ast.Assert]: diff --git a/src/wardline/scanner/analyzer.py b/src/wardline/scanner/analyzer.py index a8420454..c9d1e9a1 100644 --- a/src/wardline/scanner/analyzer.py +++ b/src/wardline/scanner/analyzer.py @@ -452,16 +452,16 @@ def _bind_call_site_arguments_to_parameters( return result def _iter_l2_body_nodes(node: ast.FunctionDef | ast.AsyncFunctionDef) -> Iterator[ast.AST]: - def walk(current: ast.AST) -> Iterator[ast.AST]: - for child in ast.iter_child_nodes(current): + result: list[ast.AST] = [] + stack: list[ast.AST] = list(reversed(node.body)) + while stack: + current = stack.pop() + result.append(current) + for child in reversed(list(ast.iter_child_nodes(current))): if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)): continue - yield child - yield from walk(child) - - for stmt in node.body: - yield stmt - yield from walk(stmt) + stack.append(child) + return iter(result) def _assignment_targets(node: ast.AST) -> list[ast.expr]: if isinstance(node, ast.Assign): diff --git a/src/wardline/scanner/ast_primitives.py b/src/wardline/scanner/ast_primitives.py index 70f565b3..b5b2e8e2 100644 --- a/src/wardline/scanner/ast_primitives.py +++ b/src/wardline/scanner/ast_primitives.py @@ -105,38 +105,33 @@ def iter_calls_in_function_body( values, base classes, metaclass keywords) are still attributed to ``node``. """ - def walk_node(current: ast.AST) -> Iterator[ast.Call]: + result: list[ast.Call] = [] + stack: list[ast.AST] = list(reversed(node.body)) + + while stack: + current = stack.pop() if isinstance(current, (ast.FunctionDef, ast.AsyncFunctionDef)): - for decorator in current.decorator_list: - yield from walk_node(decorator) - yield from _walk_argument_defaults(current.args) - return + args = current.args + stack.extend(reversed([kw for kw in args.kw_defaults if kw is not None])) + stack.extend(reversed(args.defaults)) + stack.extend(reversed(current.decorator_list)) + continue if isinstance(current, ast.ClassDef): - for decorator in current.decorator_list: - yield from walk_node(decorator) - for base in current.bases: - yield from walk_node(base) - for keyword in current.keywords: - yield from walk_node(keyword.value) - return + stack.extend(reversed([kw.value for kw in current.keywords])) + stack.extend(reversed(current.bases)) + stack.extend(reversed(current.decorator_list)) + continue if isinstance(current, ast.Lambda): - yield from _walk_argument_defaults(current.args) - return + args = current.args + stack.extend(reversed([kw for kw in args.kw_defaults if kw is not None])) + stack.extend(reversed(args.defaults)) + continue if isinstance(current, ast.Call): - yield current - for child in ast.iter_child_nodes(current): - yield from walk_node(child) - - def _walk_argument_defaults(args: ast.arguments) -> Iterator[ast.Call]: - for default in args.defaults: - yield from walk_node(default) - for kw_default in args.kw_defaults: - if kw_default is None: - continue - yield from walk_node(kw_default) + result.append(current) + + stack.extend(reversed(list(ast.iter_child_nodes(current)))) - for stmt in node.body: - yield from walk_node(stmt) + return iter(result) def resolve_self_method_fqn( diff --git a/src/wardline/scanner/rules/_ast_helpers.py b/src/wardline/scanner/rules/_ast_helpers.py index 7c3b52ff..ba0be13c 100644 --- a/src/wardline/scanner/rules/_ast_helpers.py +++ b/src/wardline/scanner/rules/_ast_helpers.py @@ -36,12 +36,16 @@ def _own_statements(node: ast.AST) -> Iterator[ast.stmt]: """Yield every statement in *node*'s own scope, not descending into nested def/class bodies. Includes the bodies of if/for/while/try/with at any depth.""" - for child in ast.iter_child_nodes(node): - if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + result: list[ast.stmt] = [] + stack: list[ast.AST] = list(reversed(list(ast.iter_child_nodes(node)))) + while stack: + current = stack.pop() + if isinstance(current, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): continue - if isinstance(child, ast.stmt): - yield child - yield from _own_statements(child) + if isinstance(current, ast.stmt): + result.append(current) + stack.extend(reversed(list(ast.iter_child_nodes(current)))) + return iter(result) def _own_reachable_statements( @@ -67,14 +71,18 @@ def _own_nodes_in_reachable_stmt(stmt: ast.stmt) -> Iterator[ast.AST]: def _walk_own_non_stmt_children(node: ast.AST) -> Iterator[ast.AST]: - for child in ast.iter_child_nodes(node): - if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)): - yield child - elif isinstance(child, ast.stmt): + result: list[ast.AST] = [] + stack: list[ast.AST] = list(reversed(list(ast.iter_child_nodes(node)))) + while stack: + current = stack.pop() + if isinstance(current, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)): + result.append(current) + elif isinstance(current, ast.stmt): continue else: - yield child - yield from _walk_own_non_stmt_children(child) + result.append(current) + stack.extend(reversed(list(ast.iter_child_nodes(current)))) + return iter(result) def _reachable_statements_in_block( @@ -639,9 +647,13 @@ def own_nodes(node: ast.AST) -> Iterator[ast.AST]: def _walk_own(node: ast.AST) -> Iterator[ast.AST]: - for child in ast.iter_child_nodes(node): - if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)): - yield child + result: list[ast.AST] = [] + stack: list[ast.AST] = list(reversed(list(ast.iter_child_nodes(node)))) + while stack: + current = stack.pop() + if isinstance(current, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)): + result.append(current) else: - yield child - yield from _walk_own(child) + result.append(current) + stack.extend(reversed(list(ast.iter_child_nodes(current)))) + return iter(result) diff --git a/src/wardline/scanner/rules/_sink_helpers.py b/src/wardline/scanner/rules/_sink_helpers.py index 5cb5d6bf..2015ff5d 100644 --- a/src/wardline/scanner/rules/_sink_helpers.py +++ b/src/wardline/scanner/rules/_sink_helpers.py @@ -126,12 +126,16 @@ def _own_calls(node: ast.AST) -> Iterator[ast.Call]: the entity index does not emit separate lambda entities; skipping them would hide dangerous calls from sink rules. """ - for child in ast.iter_child_nodes(node): - if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + result: list[ast.Call] = [] + stack: list[ast.AST] = list(reversed(list(ast.iter_child_nodes(node)))) + while stack: + current = stack.pop() + if isinstance(current, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): continue - if isinstance(child, ast.Call): - yield child - yield from _own_calls(child) + if isinstance(current, ast.Call): + result.append(current) + stack.extend(reversed(list(ast.iter_child_nodes(current)))) + return iter(result) def _direct_sink_fqn( diff --git a/src/wardline/scanner/taint/callgraph.py b/src/wardline/scanner/taint/callgraph.py index 38c2908f..f9ea4352 100644 --- a/src/wardline/scanner/taint/callgraph.py +++ b/src/wardline/scanner/taint/callgraph.py @@ -43,23 +43,32 @@ def _own_nodes_in(node: ast.AST) -> Iterator[ast.AST]: """Yield *node* and every descendant in its own scope (including *node* itself), not descending into nested def/class/lambda scopes.""" - yield node - for child in ast.iter_child_nodes(node): - if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)): - continue - yield from _own_nodes_in(child) + result: list[ast.AST] = [] + stack: list[ast.AST] = [node] + while stack: + current = stack.pop() + result.append(current) + for child in reversed(list(ast.iter_child_nodes(current))): + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)): + continue + stack.append(child) + return iter(result) def _target_names(target: ast.expr) -> Iterator[str]: """Yield the plain ``Name`` ids bound by an assignment/loop target (recursing into tuple/list/starred destructuring); attribute/subscript targets bind no local name.""" - if isinstance(target, ast.Name): - yield target.id - elif isinstance(target, ast.Starred): - yield from _target_names(target.value) - elif isinstance(target, (ast.Tuple, ast.List)): - for elt in target.elts: - yield from _target_names(elt) + result: list[str] = [] + stack: list[ast.expr] = [target] + while stack: + current = stack.pop() + if isinstance(current, ast.Name): + result.append(current.id) + elif isinstance(current, ast.Starred): + stack.append(current.value) + elif isinstance(current, (ast.Tuple, ast.List)): + stack.extend(reversed(current.elts)) + return iter(result) def _candidate_receiver_classes( diff --git a/src/wardline/scanner/taint/variable_level.py b/src/wardline/scanner/taint/variable_level.py index 056c083c..a0f61f1c 100644 --- a/src/wardline/scanner/taint/variable_level.py +++ b/src/wardline/scanner/taint/variable_level.py @@ -393,12 +393,16 @@ def _own_scope_lambdas(node: ast.AST) -> Iterator[ast.Lambda]: """Yield every ``ast.Lambda`` in *node*'s own scope (descends into lambdas, which are not separate entities, but NOT into nested ``def``/``class`` — those are analyzed as their own entities).""" - for child in ast.iter_child_nodes(node): - if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + result: list[ast.Lambda] = [] + stack: list[ast.AST] = list(reversed(list(ast.iter_child_nodes(node)))) + while stack: + current = stack.pop() + if isinstance(current, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): continue - if isinstance(child, ast.Lambda): - yield child - yield from _own_scope_lambdas(child) + if isinstance(current, ast.Lambda): + result.append(current) + stack.extend(reversed(list(ast.iter_child_nodes(current)))) + return iter(result) def _worst_ever_var_taints(