-
Notifications
You must be signed in to change notification settings - Fork 85
Expand file tree
/
Copy pathconftest.py
More file actions
159 lines (119 loc) · 4 KB
/
conftest.py
File metadata and controls
159 lines (119 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import functools
import os
import subprocess
import sys
import tempfile
import types
from pathlib import Path
from typing import Any, Dict, List

import pytest
import torch
# Put the per-call snapshot/replay files created via tempfile.mkstemp in
# /dev/shm (typically a tmpfs on Linux) when it exists and is writable;
# otherwise keep tempfile's default directory instead of failing on systems
# without /dev/shm.
if os.path.isdir("/dev/shm") and os.access("/dev/shm", os.W_OK):
    tempfile.tempdir = "/dev/shm"
def save_data(file_name, module_name, func_name, ret, args, kwargs):
    """Persist one recorded function call to ``file_name`` via ``torch.save``.

    The stored object is a dict with keys ``module_name``, ``func_name``,
    ``ret``, ``args`` and ``kwargs`` (in that order), each holding the
    corresponding argument unchanged.
    """
    record = dict(
        module_name=module_name,
        func_name=func_name,
        ret=ret,
        args=args,
        kwargs=kwargs,
    )
    torch.save(record, file_name)
def dump_test_py(file_name, test_before_file, test_after_file, pypath):
    """Write a standalone Python script that replays one recorded call.

    The generated script prepends ``pypath`` to ``sys.path``, imports ``hpc``,
    reloads the call recorded in ``test_before_file`` (keys ``func_name``,
    ``args``, ``kwargs``), invokes the same function again, and asserts that
    the return value and the possibly mutated ``args``/``kwargs`` are bitwise
    identical to the snapshot in ``test_after_file`` (keys ``ret``, ``args``,
    ``kwargs``). A mismatch raises ``AssertionError`` in the script, so it
    exits non-zero.

    Args:
        file_name: Path of the script to (over)write.
        test_before_file: ``torch.save`` file holding the pre-call snapshot.
        test_after_file: ``torch.save`` file holding the post-call snapshot.
        pypath: Directory from which ``hpc`` should be imported.
    """
    # Rendered with str.format, so the template must contain no other braces.
    # Paths are embedded with !r so quotes/backslashes in them stay valid Python.
    text = """
import sys
sys.path.insert(0, {pypath!r})
import torch
import hpc


def _load(path):
    # Snapshots are written by the test harness itself, so a full unpickle is
    # acceptable; torch releases before 1.13 do not accept ``weights_only``.
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        return torch.load(path)


din = _load({before!r})
dout = _load({after!r})
func = getattr(hpc, din['func_name'])
args = din['args']
kwargs = din['kwargs']
gt = dout['ret']
s = func(*args, **kwargs)


def _raw_bytes(t):
    # Reinterpret the storage as uint8 (a bitwise view, not a value cast):
    # ``.byte()`` would truncate e.g. 0.1 and 0.9 to the same value 0.
    return t.detach().contiguous().reshape(-1).view(torch.uint8)


def assert_equal(my, gt):
    if isinstance(my, torch.Tensor):
        assert isinstance(gt, torch.Tensor), type(gt)
        assert my.dtype == gt.dtype, (my.dtype, gt.dtype)
        assert my.shape == gt.shape, (my.shape, gt.shape)
        assert torch.equal(_raw_bytes(my), _raw_bytes(gt))
    elif isinstance(my, (tuple, list)):
        assert len(my) == len(gt), (len(my), len(gt))
        for a, b in zip(my, gt):
            assert_equal(a, b)
    elif isinstance(my, dict):
        assert my.keys() == gt.keys(), (list(my), list(gt))
        for k in my:
            assert_equal(my[k], gt[k])
    else:
        assert my == gt, (my, gt)


# test output
assert_equal(s, gt)
assert_equal(args, dout['args'])
assert_equal(kwargs, dout['kwargs'])
""".format(
        pypath=os.fspath(pypath),
        before=os.fspath(test_before_file),
        after=os.fspath(test_after_file),
    )
    with open(file_name, "w", encoding="utf-8") as fp:
        fp.write(text)
def sanitizer_check(file_name, check):
    """Run ``file_name`` under ``compute-sanitizer --tool=<check>`` and print the report.

    Only kernels whose names match the regex ``hpc.+`` are instrumented. The
    script runs with ``sys.executable`` (the interpreter running this
    process) rather than whatever ``python3`` resolves to on ``PATH``. The
    command is passed as an argument list without a shell, so characters in
    ``file_name`` need no quoting.

    ``--error-exitcode=1`` is passed because compute-sanitizer otherwise keeps
    the application's exit status even when it reports errors. Without it,
    a detected hazard would only be printed and the check could not fail.

    Args:
        file_name: Path of the Python script to execute.
        check: compute-sanitizer tool name, e.g. ``memcheck`` or ``racecheck``.

    Raises:
        subprocess.CalledProcessError: the sanitizer reported an error or the
            script itself failed. The captured stdout is printed first so the
            report is not lost.
        FileNotFoundError: ``compute-sanitizer`` is not on ``PATH``.
    """
    cmd = [
        "compute-sanitizer",
        f"--tool={check}",
        "--require-cuda-init=no",
        "--error-exitcode=1",
        "--kernel-name",
        "regex=hpc.+",
        sys.executable,
        os.fspath(file_name),
    ]
    print(" ".join(cmd))
    try:
        output = subprocess.check_output(cmd)
    except subprocess.CalledProcessError as e:
        # check_output captured stdout; surface it before propagating.
        if e.output:
            print(e.output.decode("utf-8", errors="replace"))
        raise
    print(output.decode("utf-8", errors="replace"))
class TraceHook(object):
    """Instrument a module's public functions with compute-sanitizer replays.

    Each call to a wrapped function does the following:
      * snapshots its inputs;
      * runs the real function;
      * snapshots the return value together with the (possibly mutated) inputs;
      * writes a standalone replay script with ``dump_test_py``;
      * runs that script once per configured check via ``sanitizer_check``.
    The real return value is handed back to the caller.

    Args:
        checks: compute-sanitizer tool names such as ``["memcheck"]``. When
            empty, no function is replaced.
        module_name: Name of the module to import and instrument.
    """

    def __init__(self, checks, module_name):
        self.checks_ = checks
        self.module_name = module_name

    @staticmethod
    def _build_lib_dir():
        """Return the resolved ``build/lib.*`` directory next to this file.

        Raises:
            FileNotFoundError: no such directory exists (extension not built).
        """
        parent = Path(__file__).parent
        # sorted() makes the choice deterministic if several build dirs exist.
        candidates = sorted(parent.glob("./build/lib.*/"))
        if not candidates:
            raise FileNotFoundError(
                f"no build/lib.* directory under {parent}; build the extension first"
            )
        return os.path.realpath(candidates[0])

    @staticmethod
    def _sidecar_path(py_file, suffix):
        """Return ``py_file`` with only its trailing ``.py`` replaced by ``suffix``."""
        # str.replace(".py", ...) would also rewrite ".py" inside directory names.
        stem = py_file[: -len(".py")] if py_file.endswith(".py") else py_file
        return stem + suffix

    @staticmethod
    def _remove_files(*paths):
        """Delete each of ``paths``, ignoring ones that do not exist."""
        for path in paths:
            Path(path).unlink(missing_ok=True)

    def _wrap_func(self, module, func_name):
        """Replace ``module.<func_name>`` with a sanitizer-tracing wrapper.

        Returns:
            False if ``module`` has no attribute ``func_name``, otherwise True.
            The attribute is only replaced when at least one check is set.
        """
        if not hasattr(module, func_name):
            return False
        org_func = getattr(module, func_name)

        @functools.wraps(org_func)
        def wrapped(*args, **kwargs):
            fd, tmp_py_file = tempfile.mkstemp(prefix="tmp_hpc_" + func_name + "_", suffix=".py")
            os.close(fd)
            tmp_before_invoke_file = self._sidecar_path(tmp_py_file, "_before_invoke.pth")
            tmp_after_invoke_file = self._sidecar_path(tmp_py_file, "_after_invoke.pth")
            try:
                # Snapshot inputs before the call: in-place ops may mutate them.
                save_data(tmp_before_invoke_file, self.module_name, func_name, None, args, kwargs)
                ret = org_func(*args, **kwargs)
            except BaseException:
                # Nothing to replay: drop the partial files and re-raise unchanged.
                self._remove_files(tmp_py_file, tmp_before_invoke_file)
                raise
            save_data(tmp_after_invoke_file, self.module_name, func_name, ret, args, kwargs)
            dump_test_py(tmp_py_file, tmp_before_invoke_file, tmp_after_invoke_file, self._build_lib_dir())
            print(tmp_py_file)
            for check in self.checks_:
                print(f"{check}...")
                sanitizer_check(tmp_py_file, check)
            # Reached only when every check passed. If a check raised, the
            # printed replay script and its snapshots stay on disk.
            self._remove_files(tmp_before_invoke_file, tmp_after_invoke_file, tmp_py_file)
            return ret

        if len(self.checks_) > 0:
            setattr(module, func_name, wrapped)
        return True

    def hook(self):
        """Import ``self.module_name`` from the build dir and wrap its public functions.

        Prepends the ``build/lib.*`` directory to ``sys.path`` before importing.
        Then wraps every attribute that is a plain Python function, skipping
        names that start with ``_`` or end with ``fake``.
        """
        sys.path.insert(0, self._build_lib_dir())
        module = __import__(self.module_name)
        for name in dir(module):
            if name.endswith("fake") or name.startswith("_"):
                continue
            if isinstance(getattr(module, name), types.FunctionType):
                self._wrap_func(module, name)
def get_checks():
    """Return the compute-sanitizer tools requested via ``SANITIZER_CHECK``.

    The variable holds a comma-separated list such as
    ``memcheck,synccheck,racecheck``. Whitespace around entries is stripped.
    Empty entries (from stray or trailing commas, or a whitespace-only value)
    are dropped, because each entry becomes a ``--tool=`` argument.

    Returns:
        The tool names in order; an empty list when the variable is unset or
        empty.
    """
    raw = os.getenv("SANITIZER_CHECK", "")
    return [tool for tool in (part.strip() for part in raw.split(",")) if tool]
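# Enable compute-sanitizer by setting the SANITIZER_CHECK environment variable
# to a comma-separated list of tools, e.g.
#   export SANITIZER_CHECK=memcheck,synccheck,racecheck
# to run every traced hpc call under memcheck, synccheck and racecheck.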
def pytest_configure(config):
    """pytest hook: import ``hpc`` from the build dir and optionally trace it.

    ``config`` is not used. ``TraceHook.hook`` always runs: it prepends the
    build directory to ``sys.path`` and imports ``hpc``. Its functions are
    only wrapped when ``SANITIZER_CHECK`` names at least one tool.
    """
    TraceHook(get_checks(), "hpc").hook()