-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathseeding.py
More file actions
122 lines (87 loc) · 3.91 KB
/
seeding.py
File metadata and controls
122 lines (87 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Deterministic seeding utilities.
This module hashes a master seed together with component, run and stream
identifiers to derive 64-bit sub-seeds. The message is encoded as
``"{master_seed}|{component_id}|{run_id}|{stream_id}"`` and hashed with
SHA-256, taking the first eight bytes as a big-endian integer. The resulting
sub-seed can be fed to independent random number generators.
Philox, a counter-based RNG available in NumPy, is chosen because it allows
reproducible, stateless streams that can be advanced independently across
parallel processes.
"""
from __future__ import annotations
import argparse
import hashlib
import random
import numpy as np
try: # Optional PyTorch integration
import torch
except Exception: # pragma: no cover - environment without torch
torch = None # type: ignore
def make_subseed(
    master_seed: int | str, component_id: str, run_id: str, stream_id: int | str = 0
) -> int:
    """Hash a master seed and three identifiers into a 64-bit sub-seed.

    The four arguments are formatted as text, joined with ``|`` and hashed
    with SHA-256. The fields are joined without any escaping, so identifiers
    that themselves contain ``|`` can yield the same message as a different
    combination of arguments.

    Args:
        master_seed: Global seed, given as an integer or a string.
        component_id: Name of the consuming component (e.g., "dataloader").
        run_id: Name of the current experiment/run.
        stream_id: Optional sub-stream identifier; defaults to ``0``.

    Returns:
        The leading eight bytes of the SHA-256 digest, read as an unsigned
        big-endian integer (so the value lies in ``[0, 2**64)``).
    """
    fields = (master_seed, component_id, run_id, stream_id)
    # ``format(x)`` applies the same empty format spec an f-string would.
    payload = "|".join(map(format, fields)).encode("utf-8")
    hasher = hashlib.sha256()
    hasher.update(payload)
    leading_bytes = hasher.digest()[:8]
    return int.from_bytes(leading_bytes, byteorder="big")
def philox_rng(subseed: int) -> np.random.Generator:
    """Build a NumPy ``Generator`` backed by a Philox bit generator.

    Args:
        subseed: Seed value handed to ``np.random.Philox``.

    Returns:
        A ``np.random.Generator`` wrapping the freshly seeded Philox instance.
    """
    bit_generator = np.random.Philox(seed=subseed)
    return np.random.Generator(bit_generator)
def python_rng(subseed: int) -> random.Random:
    """Create a standard-library ``random.Random`` seeded with ``subseed``.

    Args:
        subseed: Seed value for the generator.

    Returns:
        A new, independent ``random.Random`` instance.
    """
    # Passing the seed to the constructor seeds the instance on creation.
    return random.Random(subseed)
def torch_rng(subseed: int, device: str | torch.device = "cpu"):
    """Create a torch ``Generator`` on ``device`` seeded with ``subseed``.

    Args:
        subseed: Seed value; it is converted with ``int()`` before seeding.
        device: Torch device string or ``torch.device``. Defaults to ``"cpu"``.

    Returns:
        The newly created, seeded ``torch.Generator``.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if torch is None:  # pragma: no cover - only hit when torch missing
        raise ImportError("PyTorch is not installed")

    # ``manual_seed`` hands back the generator it was called on.
    return torch.Generator(device).manual_seed(int(subseed))
def set_torch_deterministic(enabled: bool = True) -> None:
    """Switch PyTorch's deterministic-algorithm and cuDNN settings together.

    Sets ``torch.backends.cudnn.deterministic`` to ``enabled``, sets
    ``torch.backends.cudnn.benchmark`` to the opposite value, and forwards
    ``enabled`` to ``torch.use_deterministic_algorithms``. Note that passing
    ``False`` therefore turns cuDNN benchmarking on.

    Args:
        enabled: Whether to enable deterministic behaviour.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if torch is None:  # pragma: no cover - only hit when torch missing
        raise ImportError("PyTorch is not installed")

    cudnn = torch.backends.cudnn
    cudnn.benchmark = not enabled
    cudnn.deterministic = enabled
    torch.use_deterministic_algorithms(enabled)
if __name__ == "__main__":
    # Command-line demo: derive a sub-seed from the given identifiers and
    # print a few draws from each available generator seeded with it.
    cli = argparse.ArgumentParser(description="Deterministic seeding utility")
    cli.add_argument("--master-seed", required=True, help="Master seed (int or str)")
    cli.add_argument("--component", required=True, help="Component identifier")
    cli.add_argument("--run-id", required=True, help="Run identifier")
    cli.add_argument("--stream-id", type=int, default=0, help="Stream identifier")
    cli.add_argument("--n", type=int, default=5, help="How many numbers to draw")
    opts = cli.parse_args()
    draw_count = opts.n

    subseed = make_subseed(
        opts.master_seed, opts.component, opts.run_id, opts.stream_id
    )
    print(f"Subseed: {subseed}")

    numpy_draws = philox_rng(subseed).random(draw_count)
    print("NumPy Philox:", numpy_draws)

    stdlib_stream = python_rng(subseed)
    stdlib_draws = [stdlib_stream.random() for _ in range(draw_count)]
    print("Python random:", stdlib_draws)

    # The torch sample is only printed when the optional import succeeded.
    if torch is not None:
        torch_draws = torch.rand(draw_count, generator=torch_rng(subseed))
        print("PyTorch:", torch_draws.tolist())