---
title: Model Interconversion
keywords: fastai
sidebar: home_sidebar
summary: "API details."
description: "API details."
nb_path: "nbs/08_model_converter.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}

pytorch_to_onnx[source]

pytorch_to_onnx(model, tensor, export_path='temp.onnx')

Export a PyTorch model to ONNX, using `tensor` as the example input for tracing; returns the path of the saved ONNX file.
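A minimal usage sketch, assuming torch and timm are imported (the model and export path are illustrative):

model = timm.create_model("resnet18").eval()
onnx_path = pytorch_to_onnx(model, torch.randn(1, 3, 224, 224), export_path="resnet18.onnx")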

{% endraw %} {% raw %}
{% endraw %} {% raw %}

onnx_to_pytorch[source]

onnx_to_pytorch(onnx_model)

Convert an ONNX model (or a path to a saved one) into a callable PyTorch module via onnx2pytorch.

{% endraw %} {% raw %}

tf2_to_onnx[source]

tf2_to_onnx(model, opset=None, output_path=None, **kwargs)

Convert a TensorFlow 2 (Keras) model to ONNX; extra keyword arguments (for example inputs_as_nchw) are passed through to the underlying tf2onnx converter.
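A minimal sketch, assuming a Keras model (the opset and output path are illustrative; the return value is not shown in this notebook):

import tensorflow as tf
tf2_to_onnx(tf.keras.applications.MobileNetV2(), opset=13, output_path="mobilenetv2.onnx")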

{% endraw %} {% raw %}

tf2_to_pytorch[source]

tf2_to_pytorch(model, opset=None, **kwargs)

Convert a TensorFlow 2 model straight to a PyTorch module, in effect chaining tf2_to_onnx and onnx_to_pytorch.

{% endraw %} {% raw %}
{% endraw %}

## Example

{% raw %}
import numpy as np
import timm
import torch

model1 = timm.create_model("resnet18")
model1.eval()

model_inter_path = pytorch_to_onnx(model1, torch.randn(1, 3, 224, 224))
model2 = onnx_to_pytorch(model_inter_path)

x = torch.randn(1, 3, 224, 224)
np.allclose(model1(x).detach().numpy(), model2(x).detach().numpy(), rtol=1e-4)
True
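The round trip through ONNX reproduces the original model's outputs to within the requested relative tolerance; by default pytorch_to_onnx writes the intermediate file to temp.onnx.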
{% endraw %} {% raw %}
import tensorflow as tf
import torch
{% endraw %} {% raw %}
tf.__version__
'2.3.0'
{% endraw %} {% raw %}
# model_test = tf2_to_pytorch(tf_model, inputs_as_nchw=None, opset=13).eval()
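The cells below use tf_model and my_model, which are not defined anywhere in this notebook. Judging from the module names in the converted model printed further down, a setup along these lines is assumed:

tf_model = tf.keras.applications.MobileNetV2(weights="imagenet")  # assumption: a Keras MobileNetV2
my_model = tf2_to_pytorch(tf_model, opset=13).eval()              # assumed conversion, mirroring the commented call above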
{% endraw %} {% raw %}
import numpy as np
from chitra.image import Chitra

image = Chitra("https://c.files.bbci.co.uk/957C/production/_111686283_pic1.png")
image.image = image.image.resize((224, 224)).convert("RGB")
image.imshow()
{% endraw %} {% raw %}
x1 = tf.cast(image.to_tensor("tf"), tf.float32) / 127.5 - 1.0
x1 = tf.expand_dims(x1, 0)

x2 = image.numpy().astype(np.float32) / 255
x2 = np.expand_dims(x2, 0)
x2 = torch.from_numpy(x2)
x2 = x2.permute(0, 3, 1, 2)
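Note that the two tensors are normalised differently: x1 is scaled to [-1, 1] (the preprocessing MobileNetV2 expects), while x2 is only scaled to [0, 1]. For the two models' predictions to be comparable, x2 would need the same scaling, e.g.

x2 = x2 * 2.0 - 1.0  # hypothetical extra step: match the [-1, 1] scaling used for x1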
{% endraw %} {% raw %}
x2.shape
torch.Size([1, 3, 224, 224])
{% endraw %} {% raw %}
Chitra(((x1[0] + 1) * 127.5).numpy().astype("uint8")).imshow()
{% endraw %} {% raw %}
from chitra.core import IMAGENET_LABELS

res1 = tf.math.softmax(tf_model.predict(x1), 1)
IMAGENET_LABELS[tf.argmax(res1, 1).numpy()[0]]
'pinwheel'
{% endraw %} {% raw %}
res2 = my_model(x2)
# IMAGENET_LABELS[torch.argmax(res2).item()]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-252-d9aab2a98c5d> in <module>
----> 1 res2 = my_model(x2)
      2 # IMAGENET_LABELS[torch.argmax(res2).item()]

... (repeated torch.nn Module/Sequential dispatch frames omitted) ...

~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    393                             weight, bias, self.stride,
    394                             _pair(0), self.dilation, self.groups)
--> 395         return F.conv2d(input, weight, bias, self.stride,
    396                         self.padding, self.dilation, self.groups)
    397 

RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[1, 224, 4, 225] to have 3 channels, but got 224 channels instead
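The first convolution expects 3 input channels in NCHW layout (its weight is [32, 3, 3, 3]). The [1, 224, 4, 225] shape in the error is what the leading ConstantPad2d produces from a [1, 224, 3, 224] tensor, i.e. a correctly permuted [1, 3, 224, 224] input that was permuted a second time, so re-running the permute cell scrambles the layout. A minimal sketch of a safe re-run, rebuilding x2 so the permute is applied exactly once:

x2 = torch.from_numpy(np.expand_dims(image.numpy().astype(np.float32) / 255, 0))
x2 = x2.permute(0, 3, 1, 2)  # NHWC -> NCHW, applied exactly once
res2 = my_model(x2)
# IMAGENET_LABELS[torch.argmax(res2).item()]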
{% endraw %} {% raw %}
my_model
Sequential(
  (0): ConvertModel(
    (Conv_mobilenetv2_1.00_224/bn_Conv1/FusedBatchNormV3:0): Sequential(
      (0): ConstantPad2d(padding=[0, 1, 0, 1], value=0)
      (1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
    )
    (Clip_mobilenetv2_1.00_224/Conv1_relu/Relu6:0): clamp()
    (Conv_mobilenetv2_1.00_224/expanded_conv_depthwise_BN/FusedBatchNormV3:0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
    (Clip_mobilenetv2_1.00_224/expanded_conv_depthwise_relu/Relu6:0): clamp()
    (Conv_mobilenetv2_1.00_224/expanded_conv_project_BN/FusedBatchNormV3:0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
    ... (blocks 1-16 omitted: the same expand/depthwise/project pattern of Conv2d, BatchNormUnsafe, clamp, Pad, and Add modules repeats) ...
    (Conv_mobilenetv2_1.00_224/Conv_1/Conv2D:0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (BatchNormalization_mobilenetv2_1.00_224/Conv_1_bn/FusedBatchNormV3:0): BatchNormUnsafe(1280, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
    (Clip_mobilenetv2_1.00_224/out_relu/Relu6:0): clamp()
    (GlobalAveragePool_mobilenetv2_1.00_224/global_average_pooling2d_9/Mean:0): GlobalAveragePool()
    (Squeeze_mobilenetv2_1.00_224/global_average_pooling2d_9/Mean_Squeeze__8183:0): Squeeze()
    (MatMul_mobilenetv2_1.00_224/predictions/BiasAdd:0): Linear(in_features=1280, out_features=1000, bias=True)
    (Softmax_predictions): Softmax(dim=None)
  )
  (1): Sequential(
    (0): ConstantPad2d(padding=[0, 1, 0, 1], value=0)
    (1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
  )
  ... (entries (2) through (113) list the same converted layers again in flattened form) ...
  (114): Linear(in_features=1280, out_features=1000, bias=True)
)
{% endraw %} {% raw %}
x2.shape, res2.shape
(torch.Size([1, 224, 224, 3]), torch.Size([9, 1000]))

Here x2 is channels-last again, the layout pitfall behind the RuntimeError above: the converted model only accepts NCHW input, so the NHWC to NCHW permute must be applied exactly once before inference.
{% endraw %}