---
title: import_utils
keywords: fastai
sidebar: home_sidebar
summary: "API details."
description: "API details."
nb_path: "nbs/07_import_utils.ipynb"
---
# Check whether the optional ONNX-related packages are available
_ONNX = is_installed("onnx")
_ONNXMLTOOLS = is_installed("onnxmltools")
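`is_installed` is used here without an explicit import, so it is presumably the helper defined in this `import_utils` module for detecting optional dependencies such as `onnx` and `onnxmltools` without importing them eagerly. As a rough sketch of what such a check can look like (the actual implementation may differ), it can be built on `importlib.util.find_spec`:

```python
import importlib.util

def is_installed(package_name: str) -> bool:
    """Return True if `package_name` is importable in the current environment."""
    # find_spec returns None when the top-level package cannot be found
    return importlib.util.find_spec(package_name) is not None
```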
import torch.onnx
import timm
# Build a ResNet-18 from timm (randomly initialized weights by default) and switch it to eval mode
torch_model = timm.create_model("resnet18")
torch_model.eval()
# Dummy input used both to trace the export and to compare outputs later
x = torch.randn(1, 3, 224, 224, requires_grad=True)
torch_out = torch_model(x)
# Export the model
torch.onnx.export(torch_model,              # model being run
                  x,                        # model input (or a tuple for multiple inputs)
                  "temp.onnx",              # where to save the model (can be a file or file-like object)
                  export_params=True,       # store the trained parameter weights inside the model file
                  opset_version=10,         # the ONNX version to export the model to
                  do_constant_folding=True, # whether to execute constant folding for optimization
                  input_names=['input'],    # the model's input names
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input':  {0: 'batch_size'},   # variable length axes
                                'output': {0: 'batch_size'}})
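Before converting the file back to PyTorch, the export can optionally be sanity-checked by running it with ONNX Runtime (an extra dependency, not used elsewhere on this page). A minimal sketch, assuming `onnxruntime` is installed and the model was written to `temp.onnx` as above:

```python
import numpy as np
import onnxruntime as ort

# Run the exported graph on the same dummy input and compare with the eager PyTorch output
sess = ort.InferenceSession("temp.onnx")
input_name = sess.get_inputs()[0].name
ort_out = sess.run(None, {input_name: x.detach().numpy()})[0]
assert np.allclose(ort_out, torch_out.detach().numpy(), rtol=1e-3, atol=1e-5)
```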
import onnx
from onnx2pytorch import ConvertModel
# Load the exported graph, validate it, and convert it back into a PyTorch module
onnx_model = onnx.load('temp.onnx')
onnx.checker.check_model(onnx_model)
pytorch_model = ConvertModel(onnx_model)
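`ConvertModel` returns a regular `torch.nn.Module`, so the converted network can be called on the same dummy input as the original model, which is what the comparison below does.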
import numpy as np
# Check that the round-tripped model reproduces the original PyTorch outputs
np.allclose(pytorch_model(x).detach().numpy(), torch_out.detach().numpy(), rtol=1e-4)
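A result of `True` confirms that the PyTorch → ONNX → PyTorch round trip reproduces the original outputs to within the given relative tolerance.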