---
title: Model Interconversion
keywords: fastai
sidebar: home_sidebar
summary: "Convert models between PyTorch, ONNX and TensorFlow 2."
description: "Convert models between PyTorch, ONNX and TensorFlow 2."
nb_path: "nbs/08_model_converter.ipynb"
---
pytorch_to_onnx [source]

`pytorch_to_onnx(model, tensor, export_path='temp.onnx')`
onnx_to_pytorch [source]

`onnx_to_pytorch(onnx_model)`
tf2_to_onnx [source]

`tf2_to_onnx(model, opset=None, output_path=None, **kwargs)`
tf2_to_pytorch [source]

`tf2_to_pytorch(model, opset=None, **kwargs)`
import numpy as np
import timm
import torch  # needed here; previously only imported in a later cell

# Round trip PyTorch -> ONNX -> PyTorch and check that both models agree.
model1 = timm.create_model("resnet18")
model1.eval()
model_inter_path = pytorch_to_onnx(model1, torch.randn(1, 3, 224, 224))
# Put the converted module in inference mode too, so any normalization or
# dropout layers behave the same way as in model1.
model2 = onnx_to_pytorch(model_inter_path).eval()
x = torch.randn(1, 3, 224, 224)
# Inference only: no autograd graph is needed, so .numpy() works directly.
with torch.no_grad():
    out1 = model1(x).numpy()
    out2 = model2(x).numpy()
np.allclose(out1, out2, rtol=1e-4)
True
import tensorflow as tf
import torch
# Show the TensorFlow version this page was executed with.
tf.__version__
'2.3.0'
# `tf_model` and `my_model` are used by the cells below, so they must be defined
# here. The converted graph printed further down is named
# "mobilenetv2_1.00_224" and ends in a "predictions" softmax, i.e. the default
# Keras MobileNetV2 with its ImageNet classification head.
tf_model = tf.keras.applications.MobileNetV2()
# With inputs_as_nchw=None no input is switched to NCHW, so my_model presumably
# expects channels-last (batch, height, width, channels) input like tf_model.
my_model = tf2_to_pytorch(tf_model, inputs_as_nchw=None, opset=13).eval()
import numpy as np
from chitra.image import Chitra

# Load the sample picture from its URL and bring it to 224x224 RGB, the input
# size used by the examples on this page.
image = Chitra("https://c.files.bbci.co.uk/957C/production/_111686283_pic1.png")
resized = image.image.resize((224, 224))
image.image = resized.convert("RGB")
image.imshow()
# Scale both inputs to [-1, 1] with the same `/ 127.5 - 1.0` rescaling, so the
# TensorFlow and PyTorch models are fed identical pixel values (x2 used to be
# divided by 255, i.e. scaled to [0, 1], which made the comparison unfair).
x1 = tf.cast(image.to_tensor("tf"), tf.float32) / 127.5 - 1.0
x1 = tf.expand_dims(x1, 0)  # add the batch dimension
x2 = image.numpy().astype(np.float32) / 127.5 - 1.0
x2 = np.expand_dims(x2, 0)  # add the batch dimension
x2 = torch.from_numpy(x2)
# Keep x2 channels-last (NHWC) like x1. The Keras-converted graph transposes its
# NHWC input to NCHW itself; feeding it an already-permuted [1, 3, 224, 224]
# tensor yields [1, 224, 3, 224] -> padded [1, 224, 4, 225], which is the
# "expected input[1, 224, 4, 225] to have 3 channels" RuntimeError on this page.
x2.shape  # expected: torch.Size([1, 224, 224, 3])
torch.Size([1, 3, 224, 224])
# Invert the `/ 127.5 - 1.0` scaling of x1 to recover displayable uint8 pixels.
restored_pixels = ((x1[0] + 1) * 127.5).numpy().astype("uint8")
Chitra(restored_pixels).imshow()
from chitra.core import IMAGENET_LABELS

# The Keras model already ends in a softmax (the converted graph shows the same
# `Softmax_predictions` op), so its output is used as-is. Applying a second
# softmax would flatten the probabilities compared with the PyTorch output;
# the arg-max label is the same either way.
res1 = tf_model.predict(x1)
IMAGENET_LABELS[int(tf.argmax(res1, 1).numpy()[0])]
'pinwheel'
# Inference only: skip building an autograd graph.
with torch.no_grad():
    res2 = my_model(x2)
# Top-1 label of the first batch row, mirroring the TensorFlow cell above.
IMAGENET_LABELS[torch.argmax(res2, 1)[0].item()]
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-252-d9aab2a98c5d> in <module> ----> 1 res2 = my_model(x2) 2 # IMAGENET_LABELS[torch.argmax(res2).item()] ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 887 result = self._slow_forward(*input, **kwargs) 888 else: --> 889 result = self.forward(*input, **kwargs) 890 for hook in itertools.chain( 891 _global_forward_hooks.values(), ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input) 117 def forward(self, input): 118 for module in self: --> 119 input = module(input) 120 return input 121 ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 887 result = self._slow_forward(*input, **kwargs) 888 else: --> 889 result = self.forward(*input, **kwargs) 890 for hook in itertools.chain( 891 _global_forward_hooks.values(), ~/miniconda3/envs/torch/lib/python3.8/site-packages/onnx2pytorch/convert/model.py in forward(self, *input) 132 activations[out_op_id] = op(in_activations[0]) 133 else: --> 134 activations[out_op_id] = op(*in_activations) 135 136 if self.debug: ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 887 result = self._slow_forward(*input, **kwargs) 888 else: --> 889 result = self.forward(*input, **kwargs) 890 for hook in itertools.chain( 891 _global_forward_hooks.values(), ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input) 117 def forward(self, input): 118 for module in self: --> 119 input = module(input) 120 return input 121 ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 887 result = self._slow_forward(*input, **kwargs) 888 else: --> 889 result = 
self.forward(*input, **kwargs) 890 for hook in itertools.chain( 891 _global_forward_hooks.values(), ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input) 397 398 def forward(self, input: Tensor) -> Tensor: --> 399 return self._conv_forward(input, self.weight, self.bias) 400 401 class Conv3d(_ConvNd): ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias) 393 weight, bias, self.stride, 394 _pair(0), self.dilation, self.groups) --> 395 return F.conv2d(input, weight, bias, self.stride, 396 self.padding, self.dilation, self.groups) 397 RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[1, 224, 4, 225] to have 3 channels, but got 224 channels instead
# Display the module structure of the converted PyTorch model.
# NOTE(review): in the recorded output the ConvertModel at index (0) is followed
# by loose copies of its own layers at (1)..(114); that wrapping Sequential looks
# unintended — confirm how my_model was built.
my_model
Sequential( (0): ConvertModel( (Conv_mobilenetv2_1.00_224/bn_Conv1/FusedBatchNormV3:0): Sequential( (0): ConstantPad2d(padding=[0, 1, 0, 1], value=0) (1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2)) ) (Clip_mobilenetv2_1.00_224/Conv1_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/expanded_conv_depthwise_BN/FusedBatchNormV3:0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32) (Clip_mobilenetv2_1.00_224/expanded_conv_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/expanded_conv_project_BN/FusedBatchNormV3:0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1)) (Conv_mobilenetv2_1.00_224/block_1_expand/Conv2D:0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (BatchNormalization_mobilenetv2_1.00_224/block_1_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(96, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (Clip_mobilenetv2_1.00_224/block_1_expand_relu/Relu6:0): clamp() (Split_Split__8143:0): Split() (Pad_mobilenetv2_1.00_224/block_1_pad/Pad:0): Pad() (Conv_mobilenetv2_1.00_224/block_1_depthwise_BN/FusedBatchNormV3:0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96) (Clip_mobilenetv2_1.00_224/block_1_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_1_project_BN/FusedBatchNormV3:0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1)) (Conv_mobilenetv2_1.00_224/block_2_expand/Conv2D:0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False) (BatchNormalization_mobilenetv2_1.00_224/block_2_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(144, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (Clip_mobilenetv2_1.00_224/block_2_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_2_depthwise_BN/FusedBatchNormV3:0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144) (Clip_mobilenetv2_1.00_224/block_2_depthwise_relu/Relu6:0): clamp() 
(Conv_mobilenetv2_1.00_224/block_2_project_BN/FusedBatchNormV3:0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1)) (Add_mobilenetv2_1.00_224/block_2_add/add:0): Add() (Conv_mobilenetv2_1.00_224/block_3_expand_BN/FusedBatchNormV3:0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1)) (Clip_mobilenetv2_1.00_224/block_3_expand_relu/Relu6:0): clamp() (Pad_mobilenetv2_1.00_224/block_3_pad/Pad:0): Pad() (Conv_mobilenetv2_1.00_224/block_3_depthwise_BN/FusedBatchNormV3:0): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), groups=144) (Clip_mobilenetv2_1.00_224/block_3_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_3_project_BN/FusedBatchNormV3:0): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1)) (Conv_mobilenetv2_1.00_224/block_4_expand/Conv2D:0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (BatchNormalization_mobilenetv2_1.00_224/block_4_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(192, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (Clip_mobilenetv2_1.00_224/block_4_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_4_depthwise_BN/FusedBatchNormV3:0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192) (Clip_mobilenetv2_1.00_224/block_4_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_4_project_BN/FusedBatchNormV3:0): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) (Add_mobilenetv2_1.00_224/block_4_add/add:0): Add() (Conv_mobilenetv2_1.00_224/block_5_expand_BN/FusedBatchNormV3:0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1)) (Clip_mobilenetv2_1.00_224/block_5_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_5_depthwise_BN/FusedBatchNormV3:0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192) (Clip_mobilenetv2_1.00_224/block_5_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_5_project_BN/FusedBatchNormV3:0): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) 
(Add_mobilenetv2_1.00_224/block_5_add/add:0): Add() (Conv_mobilenetv2_1.00_224/block_6_expand_BN/FusedBatchNormV3:0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1)) (Clip_mobilenetv2_1.00_224/block_6_expand_relu/Relu6:0): clamp() (Pad_mobilenetv2_1.00_224/block_6_pad/Pad:0): Pad() (Conv_mobilenetv2_1.00_224/block_6_depthwise_BN/FusedBatchNormV3:0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), groups=192) (Clip_mobilenetv2_1.00_224/block_6_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_6_project_BN/FusedBatchNormV3:0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) (Conv_mobilenetv2_1.00_224/block_7_expand/Conv2D:0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False) (BatchNormalization_mobilenetv2_1.00_224/block_7_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(384, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (Clip_mobilenetv2_1.00_224/block_7_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_7_depthwise_BN/FusedBatchNormV3:0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384) (Clip_mobilenetv2_1.00_224/block_7_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_7_project_BN/FusedBatchNormV3:0): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1)) (Add_mobilenetv2_1.00_224/block_7_add/add:0): Add() (Conv_mobilenetv2_1.00_224/block_8_expand_BN/FusedBatchNormV3:0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1)) (Clip_mobilenetv2_1.00_224/block_8_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_8_depthwise_BN/FusedBatchNormV3:0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384) (Clip_mobilenetv2_1.00_224/block_8_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_8_project_BN/FusedBatchNormV3:0): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1)) (Add_mobilenetv2_1.00_224/block_8_add/add:0): Add() (Conv_mobilenetv2_1.00_224/block_9_expand_BN/FusedBatchNormV3:0): Conv2d(64, 
384, kernel_size=(1, 1), stride=(1, 1)) (Clip_mobilenetv2_1.00_224/block_9_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_9_depthwise_BN/FusedBatchNormV3:0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384) (Clip_mobilenetv2_1.00_224/block_9_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_9_project_BN/FusedBatchNormV3:0): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1)) (Add_mobilenetv2_1.00_224/block_9_add/add:0): Add() (Conv_mobilenetv2_1.00_224/block_10_expand_BN/FusedBatchNormV3:0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1)) (Clip_mobilenetv2_1.00_224/block_10_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_10_depthwise_BN/FusedBatchNormV3:0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384) (Clip_mobilenetv2_1.00_224/block_10_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_10_project_BN/FusedBatchNormV3:0): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1)) (Conv_mobilenetv2_1.00_224/block_11_expand/Conv2D:0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (BatchNormalization_mobilenetv2_1.00_224/block_11_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(576, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (Clip_mobilenetv2_1.00_224/block_11_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_11_depthwise_BN/FusedBatchNormV3:0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576) (Clip_mobilenetv2_1.00_224/block_11_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_11_project_BN/FusedBatchNormV3:0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) (Add_mobilenetv2_1.00_224/block_11_add/add:0): Add() (Conv_mobilenetv2_1.00_224/block_12_expand_BN/FusedBatchNormV3:0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1)) (Clip_mobilenetv2_1.00_224/block_12_expand_relu/Relu6:0): clamp() 
(Conv_mobilenetv2_1.00_224/block_12_depthwise_BN/FusedBatchNormV3:0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576) (Clip_mobilenetv2_1.00_224/block_12_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_12_project_BN/FusedBatchNormV3:0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) (Add_mobilenetv2_1.00_224/block_12_add/add:0): Add() (Conv_mobilenetv2_1.00_224/block_13_expand_BN/FusedBatchNormV3:0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1)) (Clip_mobilenetv2_1.00_224/block_13_expand_relu/Relu6:0): clamp() (Pad_mobilenetv2_1.00_224/block_13_pad/Pad:0): Pad() (Conv_mobilenetv2_1.00_224/block_13_depthwise_BN/FusedBatchNormV3:0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), groups=576) (Clip_mobilenetv2_1.00_224/block_13_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_13_project_BN/FusedBatchNormV3:0): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1)) (Conv_mobilenetv2_1.00_224/block_14_expand/Conv2D:0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False) (BatchNormalization_mobilenetv2_1.00_224/block_14_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(960, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (Clip_mobilenetv2_1.00_224/block_14_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_14_depthwise_BN/FusedBatchNormV3:0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960) (Clip_mobilenetv2_1.00_224/block_14_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_14_project_BN/FusedBatchNormV3:0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1)) (Add_mobilenetv2_1.00_224/block_14_add/add:0): Add() (Conv_mobilenetv2_1.00_224/block_15_expand_BN/FusedBatchNormV3:0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1)) (Clip_mobilenetv2_1.00_224/block_15_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_15_depthwise_BN/FusedBatchNormV3:0): Conv2d(960, 960, kernel_size=(3, 
3), stride=(1, 1), padding=(1, 1), groups=960) (Clip_mobilenetv2_1.00_224/block_15_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_15_project_BN/FusedBatchNormV3:0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1)) (Add_mobilenetv2_1.00_224/block_15_add/add:0): Add() (Conv_mobilenetv2_1.00_224/block_16_expand_BN/FusedBatchNormV3:0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1)) (Clip_mobilenetv2_1.00_224/block_16_expand_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_16_depthwise_BN/FusedBatchNormV3:0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960) (Clip_mobilenetv2_1.00_224/block_16_depthwise_relu/Relu6:0): clamp() (Conv_mobilenetv2_1.00_224/block_16_project_BN/FusedBatchNormV3:0): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1)) (Conv_mobilenetv2_1.00_224/Conv_1/Conv2D:0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False) (BatchNormalization_mobilenetv2_1.00_224/Conv_1_bn/FusedBatchNormV3:0): BatchNormUnsafe(1280, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (Clip_mobilenetv2_1.00_224/out_relu/Relu6:0): clamp() (GlobalAveragePool_mobilenetv2_1.00_224/global_average_pooling2d_9/Mean:0): GlobalAveragePool() (Squeeze_mobilenetv2_1.00_224/global_average_pooling2d_9/Mean_Squeeze__8183:0): Squeeze() (MatMul_mobilenetv2_1.00_224/predictions/BiasAdd:0): Linear(in_features=1280, out_features=1000, bias=True) (Softmax_predictions): Softmax(dim=None) ) (1): Sequential( (0): ConstantPad2d(padding=[0, 1, 0, 1], value=0) (1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2)) ) (2): ConstantPad2d(padding=[0, 1, 0, 1], value=0) (3): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2)) (4): clamp() (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32) (6): clamp() (7): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1)) (8): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (9): BatchNormUnsafe(96, eps=0.0010000000474974513, 
momentum=0.1, affine=True, track_running_stats=True) (10): clamp() (11): Split() (12): Pad() (13): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96) (14): clamp() (15): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1)) (16): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False) (17): BatchNormUnsafe(144, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (18): clamp() (19): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144) (20): clamp() (21): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1)) (22): Add() (23): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1)) (24): clamp() (25): Pad() (26): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), groups=144) (27): clamp() (28): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1)) (29): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (30): BatchNormUnsafe(192, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (31): clamp() (32): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192) (33): clamp() (34): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) (35): Add() (36): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1)) (37): clamp() (38): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192) (39): clamp() (40): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) (41): Add() (42): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1)) (43): clamp() (44): Pad() (45): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), groups=192) (46): clamp() (47): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) (48): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False) (49): BatchNormUnsafe(384, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (50): clamp() (51): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384) (52): clamp() (53): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1)) (54): Add() 
(55): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1)) (56): clamp() (57): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384) (58): clamp() (59): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1)) (60): Add() (61): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1)) (62): clamp() (63): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384) (64): clamp() (65): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1)) (66): Add() (67): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1)) (68): clamp() (69): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384) (70): clamp() (71): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1)) (72): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (73): BatchNormUnsafe(576, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (74): clamp() (75): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576) (76): clamp() (77): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) (78): Add() (79): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1)) (80): clamp() (81): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576) (82): clamp() (83): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) (84): Add() (85): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1)) (86): clamp() (87): Pad() (88): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), groups=576) (89): clamp() (90): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1)) (91): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False) (92): BatchNormUnsafe(960, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (93): clamp() (94): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960) (95): clamp() (96): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1)) (97): Add() (98): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1)) (99): clamp() (100): Conv2d(960, 960, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960) (101): clamp() (102): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1)) (103): Add() (104): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1)) (105): clamp() (106): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960) (107): clamp() (108): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1)) (109): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False) (110): BatchNormUnsafe(1280, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True) (111): clamp() (112): GlobalAveragePool() (113): Squeeze() (114): Linear(in_features=1280, out_features=1000, bias=True) )
# Compare the shape of the input fed to my_model with the shape of its output.
# NOTE(review): the recorded output shows 9 output rows for a single image,
# which does not match a batch of 1 — confirm against a fresh run.
x2.shape, res2.shape
(torch.Size([1, 224, 224, 3]), torch.Size([9, 1000]))