I'm hoping to try out your model with my custom data, but I need to get it converted to ONNX eventually, so I thought I'd try converting the simple examples first, as a test.
For a quick training pass, I'm just running:
python main.py pretrain --train_path data/example/train.txt --val_path data/example/val.txt
Then I try to load/convert the checkpoint with:
import torch
from collections import OrderedDict
from bert.train.model.bert import build_model
from bert.preprocess.preprocess import add_preprocess_parser
from bert.train.train import add_pretrain_parser, add_finetune_parser
trained_pth = 'checkpoints/BERT-BERT-{phase}-layers_count={layers_count}-hidden_size={hidden_size}-heads_count={heads_count}-{timestamp}-layers_count=1-hidden_size=128-heads_count=2-2019_07_14_17_44_38/epoch=010-val_loss=6.14-val_metrics=0.0-0.331.pth'
state_dict = torch.load(trained_pth, map_location='cpu')['state_dict'] # NOTE: This is an OrderedDict()
ordered_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[13:]  # drop the first 13 characters (the key prefix) from each parameter name
    ordered_dict[name] = v
model = build_model(1, 128, 1, 128, 0.1, 512, 151)
model.load_state_dict(ordered_dict)
dummy_input = (torch.randn(1, 128).long(), torch.randn(1, 128).long())
input_names = ["input_sequence", "segment"]
output_names = ["predictions"]
torch.onnx.export(model, dummy_input,"bert.onnx", verbose=True, input_names=input_names, output_names=output_names)
Obviously something's wrong, because I'm hitting the following error:
  File "/home/james/src/BERT-pytorch/basic_test_to_onnx.py", line 24, in <module>
    torch.onnx.export(model, dummy_input,"bert.onnx", verbose=True, input_names=input_names, output_names=output_names)
  File "/home/james/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/onnx/__init__.py", line 25, in export
    return utils.export(*args, **kwargs)
  File "/home/james/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/onnx/utils.py", line 131, in export
    strip_doc_string=strip_doc_string)
  File "/home/james/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/onnx/utils.py", line 363, in _export
    _retain_param_name, do_constant_folding)
  File "/home/james/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/onnx/utils.py", line 266, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/home/james/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/onnx/utils.py", line 225, in _trace_and_get_graph_from_model
    trace, torch_out = torch.jit.get_trace_graph(model, args, _force_outplace=True)
  File "/home/james/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/jit/__init__.py", line 231, in get_trace_graph
    return LegacyTracedModule(f, _force_outplace, return_inputs)(*args, **kwargs)
  File "/home/james/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/james/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/jit/__init__.py", line 294, in forward
    out = self.inner(*trace_inputs)
  File "/home/james/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/home/james/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 481, in _slow_forward
    result = self.forward(*input, **kwargs)
builtins.TypeError: forward() takes 2 positional arguments but 3 were given
But it doesn't look like I have direct access to the caller of this, so I'm really not sure where the extra argument is coming from, or how I might fix it. Do you know, offhand, whether this model can be converted successfully to ONNX?
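One possible explanation for the extra argument, sketched here without torch so it runs anywhere: torch.onnx.export documents its args parameter as a tuple such that model(*args) is a valid call, so a two-element dummy_input is unpacked into two positional arguments. If the model's forward() takes a single tuple and unpacks it internally (an assumption here, not something confirmed from the code above), that unpacking would give exactly "takes 2 positional arguments but 3 were given". ToyModel and call_like_exporter below are hypothetical stand-ins, not part of the repository or of torch.

```python
class ToyModel:
    """Hypothetical stand-in; ASSUMES forward() takes ONE tuple argument."""

    def forward(self, inputs):
        # The single argument is unpacked inside forward().
        sequence, segment = inputs
        return [s + g for s, g in zip(sequence, segment)]

    def __call__(self, *args):
        # Like nn.Module.__call__, pass positional arguments straight through.
        return self.forward(*args)


def call_like_exporter(model, args):
    """Hypothetical helper mimicking the documented contract: model(*args)."""
    if not isinstance(args, tuple):
        args = (args,)
    return model(*args)


sequence, segment = [1, 2, 3], [0, 0, 1]

# Flat tuple: unpacked into two positional arguments, so forward() receives
# self + 2 arguments and raises TypeError.
try:
    call_like_exporter(ToyModel(), (sequence, segment))
    flat_error = None
except TypeError as exc:
    flat_error = str(exc)

# Nested tuple: the outer tuple is unpacked, the inner one reaches forward() intact.
nested_result = call_like_exporter(ToyModel(), ((sequence, segment),))

print(flat_error)
print(nested_result)
```

If that assumption holds for the real model, wrapping the inputs once more, as in dummy_input = ((sequence, segment),), may be worth trying in the export call. Separately, torch.randn(1, 128).long() truncates normally distributed floats and will usually produce some negative values, which are not valid embedding indices; torch.randint with the vocabulary size as the upper bound is probably a safer way to build the dummy tensors.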