Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion packages/mcp-hmr/mcp_hmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,52 @@
__all__ = "mcp_server", "run_with_hmr"


def _resolve_watch_path(module_or_path: str) -> str:
"""Resolve a stable watch directory for hot reload.

Why this exists:
`AsyncReloader` expects a non-empty, valid path to watch. In mcp-hmr 0.0.3.2
the watcher was initialized with an empty string, which can result in no
files being watched (or errors) depending on the platform / watcher backend.

Strategy:
- If the target is a file path, watch its parent directory.
- If the target is an importable module, watch the directory containing the
module's file (or package __init__.py).
- Fall back to the current working directory.
"""
# path:attr target
p = Path(module_or_path)
if p.is_file():
return str(p.resolve().parent)

# Some callers may provide a directory path directly; watch it as-is.
if p.is_dir():
return str(p.resolve())

# module:attr target
try:
spec = find_spec(module_or_path)
except (ImportError, ModuleNotFoundError, TypeError, ValueError):
spec = None

if spec is not None:
# For built-in / frozen modules, spec.origin can be strings like "built-in".
# Use has_location to ensure origin is a real filesystem location.
if getattr(spec, "has_location", False) and spec.origin:
origin_path = Path(spec.origin)
if origin_path.is_file():
# For packages, origin points at __init__.py; for modules, origin is the .py file.
return str(origin_path.resolve().parent)

# Namespace packages may have no origin; use their search location.
locations = list(spec.submodule_search_locations or ())
if locations:
return str(Path(locations[0]).resolve())

return str(Path.cwd())


def mcp_server(target: str):
module, attr = target.rsplit(":", 1)

Expand Down Expand Up @@ -80,7 +126,7 @@ async def main():

class Reloader(AsyncReloader):
def __init__(self):
super().__init__("")
super().__init__(_resolve_watch_path(module))
self.error_filter.exclude_filenames.add(__file__)

async def __aenter__(self):
Expand Down