Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions cosmos_framework/utils/checkpoint_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ def _hf_download(cmd_args: list[str]) -> str:

Uses a newer Hugging Face CLI version to download checkpoint. The dependency
version is very old and not robust.

--format=json is intentionally omitted: it silences tqdm progress bars.
Without it, hf download streams tqdm to stderr (visible on rank0) and
prints the local snapshot path as the last stdout line.

stdout is captured via Popen so we can guarantee the child is killed on
KeyboardInterrupt, preventing stale file-lock deadlocks on the next run.
"""
is_rank0 = os.environ.get("RANK", "0") == "0"
cmd = [
Expand All @@ -152,18 +159,33 @@ def _hf_download(cmd_args: list[str]) -> str:
"click",
f"hf@{HF_VERSION}",
"download",
"--format=json",
*cmd_args,
]
log.info(f"{shlex.join(cmd)}")
output = subprocess.run(
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=None if is_rank0 else subprocess.PIPE,
text=True,
check=True,
)
return json.loads(output.stdout)["path"]
stdout_lines: list[str] = []
try:
assert proc.stdout is not None
for line in proc.stdout:
stdout_lines.append(line)
proc.wait()
except BaseException:
# Ensure the child is always reaped (covers KeyboardInterrupt too),
# so it cannot keep holding HuggingFace file-system locks.
proc.kill()
proc.wait()
raise
if proc.returncode != 0:
raise subprocess.CalledProcessError(proc.returncode, cmd)
# Without --format=json, hf download prints the local path as the last
# (and typically only) stdout line.
last_line = next((ln for ln in reversed(stdout_lines) if ln.strip()), "")
return last_line.strip()


class RepositoryType(StrEnum):
Expand Down
Loading