Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 64 additions & 20 deletions needs_config_writer/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ def write_needscfg_file(
srcdir: Optional source directory (defaults to app.srcdir)
"""

def get_safe_config(obj: Any, path: str = "", outpath: Path | None = None) -> Any:
def get_safe_config(
obj: Any,
path: str = "",
outpath: Path | None = None,
visited: set[int] | None = None,
) -> Any:
"""
Recursively walk needs config and make it TOML serialisable.

Expand All @@ -45,21 +50,50 @@ def get_safe_config(obj: Any, path: str = "", outpath: Path | None = None) -> An

Special handling:
- PosixPath objects are converted to strings (with optional relativization)
- Circular references are detected and filtered out to prevent infinite recursion

Args:
obj: The object to convert
path: The current path for debugging (e.g., "needs.types[0].directive")
outpath: The output file path for relativizing absolute paths
visited: Set of object IDs already visited (for circular reference detection)

Returns:
The converted object if serializable, or None if the value should be filtered out
"""
from datetime import date, datetime, time

# Initialize visited set on first call
if visited is None:
visited = set()

# Filter out None - TOML doesn't support null values
if obj is None:
return None

# Check for circular references (only for mutable objects that can contain references)
# Skip this check for immutable types and simple types
# We track objects in the current traversal path to detect true circular refs (A -> B -> A)
# but allow the same object to be referenced from different paths (A -> C, B -> C)
if isinstance(obj, (dict, list, tuple, set)):
obj_id = id(obj)
if obj_id in visited:
log_warning(
LOGGER,
f"Circular reference detected at '{path}' - filtering out to prevent infinite recursion",
"circular_reference",
location=None,
)
return None
# Add to visited set for this traversal path
visited.add(obj_id)
# We'll remove it after processing to allow the same object from different paths
should_remove_from_visited = True
visited_obj_id = obj_id
else:
should_remove_from_visited = False
visited_obj_id = None

# Check if this path should be relativized based on allowlist
should_relativize = False
path_prefix = None
Expand Down Expand Up @@ -177,27 +211,37 @@ def get_safe_config(obj: Any, path: str = "", outpath: Path | None = None) -> An
return obj

if isinstance(obj, dict):
result = {}
for key, value in obj.items():
item_path = f"{path}.{key}" if path else str(key)
safe_value = get_safe_config(value, item_path, outpath)
if safe_value is not None:
result[key] = safe_value
return result
try:
result = {}
for key, value in obj.items():
item_path = f"{path}.{key}" if path else str(key)
safe_value = get_safe_config(value, item_path, outpath, visited)
if safe_value is not None:
result[key] = safe_value
return result
finally:
# Remove from visited to allow same object from different paths
if should_remove_from_visited and visited_obj_id is not None:
visited.discard(visited_obj_id)

if isinstance(obj, (list, tuple, set)):
items = []
for idx, item in enumerate(obj):
item_path = f"{path}[{idx}]"
safe_value = get_safe_config(item, item_path, outpath)
if safe_value is not None:
items.append(safe_value)

if isinstance(obj, tuple):
return tuple(items)
if isinstance(obj, set):
return set(items)
return items
try:
items = []
for idx, item in enumerate(obj):
item_path = f"{path}[{idx}]"
safe_value = get_safe_config(item, item_path, outpath, visited)
if safe_value is not None:
items.append(safe_value)

if isinstance(obj, tuple):
return tuple(items)
if isinstance(obj, set):
return set(items)
return items
finally:
# Remove from visited to allow same object from different paths
if should_remove_from_visited and visited_obj_id is not None:
visited.discard(visited_obj_id)

# If it's not a TOML-serializable type, warn and filter it out
log_warning(
Expand Down
41 changes: 41 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,3 +1402,44 @@ def test_relative_path_with_prefix(

assert content == snapshot
app.cleanup()


def test_no_recursion_error_on_rebuild(
tmpdir: Path,
make_app: Callable[[], SphinxTestApp],
write_fixture_files: Callable[[Path, dict[str, Any]], None],
) -> None:
"""Test that rebuilding doesn't cause recursion errors (simulates hot reload)."""
conf_py = textwrap.dedent(
"""
extensions = [
"sphinx_needs",
"needs_config_writer",
]
needscfg_add_header = False
needscfg_overwrite = True
"""
)
index_rst = textwrap.dedent(
"""
Headline
========
"""
)
file_contents: dict[str, str] = {
"conf": conf_py,
"rst": index_rst,
}
write_fixture_files(tmpdir, file_contents)

# First build
app: SphinxTestApp = make_app(srcdir=Path(tmpdir), freshenv=True)
app.build()
assert app.statuscode == 0
app.cleanup()

# Second build (simulates hot reload) - should not cause recursion error
app2: SphinxTestApp = make_app(srcdir=Path(tmpdir), freshenv=False)
app2.build()
assert app2.statuscode == 0
app2.cleanup()
Loading