-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathneedle_test.py
More file actions
55 lines (42 loc) · 2.07 KB
/
needle_test.py
File metadata and controls
55 lines (42 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
Proprietary / All Rights Reserved - Non-Commercial Use Only
Source-available for portfolio viewing only. Commercial use, unauthorized modification, reproduction, or distribution is strictly prohibited. All rights reserved.
"""
import random
import torch
from hspmn_v3_0 import HSPMNBlock
from utils_v3_0 import HSPMNConfig, setup_env, get_device, setup_logging, seed_everything
# Module-level logger for this script, obtained from the project's
# setup_logging helper and keyed by this module's name.
logger = setup_logging(__name__)
def run_needle_in_haystack(model, device, dtype, seq_len=16384, dim=256, num_tests=10):
    """Plant a constant 'needle' in random input and check that the router selects it.

    Each trial builds a random tensor of shape ``(1, seq_len, dim)``, overwrites
    one randomly chosen position along the sequence axis with the constant
    ``10.0``, runs the model, and counts a success when that position appears
    in ``model.router(x).indices[0]``.

    Args:
        model: Object providing ``eval()``, ``__call__(x)`` and a ``router(x)``
            callable whose result exposes an ``indices`` attribute.
        device: Device on which the input tensor is allocated.
        dtype: Floating-point dtype of the input tensor.
        seq_len: Length of the sequence axis; must be at least 1 when any
            trial is run.
        dim: Size of the feature axis.
        num_tests: Number of independent trials.

    Returns:
        The fraction of trials in which the needle position was selected.
        ``0.0`` is returned when ``num_tests`` is not positive, instead of
        failing with ``ZeroDivisionError``.
    """
    model.eval()

    # With no trials there is nothing to average; report 0.0 rather than
    # dividing by zero below.
    if num_tests <= 0:
        logger.warning("Needle-in-a-Haystack skipped: num_tests must be positive")
        return 0.0

    success_count = 0
    for i in range(num_tests):
        x = torch.randn(1, seq_len, dim, device=device, dtype=dtype)
        # Needle depth comes from the stdlib `random` module, so it follows
        # whatever seed the caller set on that module.
        needle_pos = random.randint(0, seq_len - 1)
        x[:, needle_pos, :] = 10.0

        with torch.inference_mode():
            # The full forward pass is kept even though its output is unused;
            # it may have side effects that are not visible from this file.
            _ = model(x)
            router_out = model.router(x)
            # NOTE(review): indices[0] is assumed to be a tensor/array holding
            # the selected sequence positions for the single batch element --
            # confirm against the router. Comparing on the tensor works for any
            # rank, whereas `int in nested_list` would silently report a miss.
            needle_found = bool((router_out.indices[0] == needle_pos).any())

        if needle_found:
            success_count += 1
            logger.info(f"Test {i+1}/{num_tests}: Needle found at depth {needle_pos}")
        else:
            logger.warning(f"Test {i+1}/{num_tests}: Needle missed at depth {needle_pos}")

    accuracy = success_count / num_tests
    logger.info(f"Needle-in-a-Haystack accuracy: {accuracy * 100:.2f}%")
    return accuracy
def main():
    """Build a compiled HSPMN block with constant router weights and run the needle test.

    Steps, in order: prepare the environment, pick the device, seed the RNGs,
    construct the block in a half-precision dtype, wrap it with
    ``torch.compile``, overwrite the router gate weight and route bias with
    constants, then run 20 needle-in-a-haystack trials.
    """
    # Shared between the model config and the test input so they cannot drift.
    seq_len = 16384
    dim = 256

    setup_env()
    target_device = get_device()
    seed_everything()

    block_config = HSPMNConfig(
        dim=dim,
        num_heads=4,
        num_kv_heads=4,
        max_seq_len=seq_len,
        sparsity_k=0.1,
    )

    # Use bfloat16 only when CUDA is present and reports bf16 support;
    # every other case falls back to float16.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float16

    block = HSPMNBlock(block_config).to(target_device, dtype=compute_dtype)
    block = torch.compile(block, mode="max-autotune", fullgraph=True)

    # In-place constant fill of the router gate weight and route bias; done
    # under no_grad so autograd does not record the writes.
    with torch.no_grad():
        block.router.gate.weight.fill_(0.1)
        block.router.route_bias.fill_(0.0)

    logger.info("Running Needle-in-a-Haystack test...")
    run_needle_in_haystack(
        block,
        target_device,
        compute_dtype,
        seq_len=seq_len,
        dim=dim,
        num_tests=20,
    )


if __name__ == "__main__":
    main()