forked from bytedance/byteir
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path add2.py
More file actions
33 lines (26 loc) · 1.14 KB
/
add2.py
File metadata and controls
33 lines (26 loc) · 1.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import brt
import torch
from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete
import numpy as np
import os
def main():
    """Run the ``add2.mlir`` example through the ByteIR runtime and check it.

    Creates a ``brt.Session`` that is given PyTorch's CUDA caching allocator
    functions as its ``alloc_func`` / ``free_func`` callbacks, then loads
    ``add2.mlir`` from the directory containing this script. A request context
    is created on PyTorch's current CUDA stream. Every input argument is then
    bound to a float32 CUDA tensor filled with ``np.random.random`` values, and
    every output argument to an empty float32 CUDA tensor. Both use the static
    shape reported by the session.

    After ``finish_io_binding()``, ``run()`` and ``sync()``, the outputs are
    compared against PyTorch references:

    * ``outputs[0]`` vs ``inputs[0] + inputs[1]``
    * ``outputs[1]`` vs ``inputs[0] + inputs[1] + inputs[1]``

    The checks index the first two inputs and outputs, so the model is assumed
    to expose at least two of each.

    Raises:
        AssertionError: from ``torch.testing.assert_close`` when an output
            does not match its reference.
    """
    session = brt.Session(
        alloc_func=caching_allocator_alloc,
        free_func=caching_allocator_delete,
    )
    session.load(os.path.join(os.path.dirname(__file__), "add2.mlir"))

    # Raw handle of the stream PyTorch is currently using; presumably this lets
    # the runtime order its work with PyTorch's on that stream.
    stream_handle = torch.cuda.current_stream()._as_parameter_.value
    req = session.new_request_context(stream_handle)

    def random_input(shape):
        # Host data is float64 from NumPy; it is converted to float32 on copy.
        return torch.tensor(
            np.random.random(size=shape), dtype=torch.float32, device="cuda"
        )

    def empty_output(shape):
        return torch.empty(shape, dtype=torch.float32, device="cuda")

    def bind_all(offsets, make_tensor):
        # Only raw device pointers are handed to the runtime, so the returned
        # list is what keeps each tensor (and its memory) alive until main ends.
        bound = []
        for arg_offset in offsets:
            tensor = make_tensor(session.get_static_shape(arg_offset))
            bound.append(tensor)
            req.bind_arg(arg_offset, tensor.data_ptr())
        return bound

    inputs = bind_all(session.get_input_arg_offsets(), random_input)
    outputs = bind_all(session.get_output_arg_offsets(), empty_output)

    req.finish_io_binding()
    req.run()
    req.sync()

    lhs, rhs = inputs[0], inputs[1]
    pair_sum = lhs + rhs
    torch.testing.assert_close(pair_sum, outputs[0])
    torch.testing.assert_close(pair_sum + rhs, outputs[1])
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()