---
title: Model Interconversion
keywords: fastai
sidebar: home_sidebar
summary: "Convert models between PyTorch, ONNX and TensorFlow 2."
description: "Convert models between PyTorch, ONNX and TensorFlow 2."
nb_path: "nbs/08_model_converter.ipynb"
---
pytorch_to_onnx[source]
pytorch_to_onnx(model,tensor,export_path='temp.onnx')
onnx_to_pytorch[source]
onnx_to_pytorch(onnx_model)
tf2_to_onnx[source]
tf2_to_onnx(model,opset=None,output_path=None, **kwargs)
tf2_to_pytorch[source]
tf2_to_pytorch(model,opset=None, **kwargs)
import numpy as np
import timm
import torch  # needed here: torch is used below before any later cell imports it

# Round-trip a timm ResNet-18 through ONNX and back into PyTorch, then check
# that the original and the reconverted model agree on a random input.
model1 = timm.create_model("resnet18")
model1.eval()  # inference mode so BatchNorm/dropout behave deterministically
model_inter_path = pytorch_to_onnx(model1, torch.randn(1, 3, 224, 224))
model2 = onnx_to_pytorch(model_inter_path)
model2.eval()  # compare both models under the same (inference) mode
x = torch.randn(1, 3, 224, 224)
# No autograd graph is needed for a pure forward comparison.
with torch.no_grad():
    out1 = model1(x).numpy()
    out2 = model2(x).numpy()
np.allclose(out1, out2, rtol=1e-4)
True
import torch
import tensorflow as tf

# Show which TensorFlow version this page was generated with.
tf.__version__
'2.3.0'
# Build the Keras MobileNetV2 (ImageNet weights by default) and convert it to
# PyTorch through ONNX. With inputs_as_nchw=None, tf2onnx keeps TensorFlow's
# NHWC input layout, so the converted model must be fed (N, H, W, C) tensors.
tf_model = tf.keras.applications.MobileNetV2()
my_model = tf2_to_pytorch(tf_model, inputs_as_nchw=None, opset=13).eval()
import numpy as np
from chitra.image import Chitra

# Load a sample image and resize it to MobileNetV2's 224x224 input size.
image = Chitra("https://c.files.bbci.co.uk/957C/production/_111686283_pic1.png")
image.image = image.image.resize((224, 224)).convert("RGB")
image.imshow()

# TensorFlow input: batch of one float32 image scaled to [-1, 1], the range
# Keras' MobileNetV2 preprocessing uses.
x1 = tf.cast(image.to_tensor("tf"), tf.float32) / 127.5 - 1.0
x1 = tf.expand_dims(x1, 0)

# PyTorch input: a copy of the very same tensor. The ONNX graph exported from
# TensorFlow keeps the NHWC layout and the [-1, 1] scaling, so the input must
# not be rescaled to [0, 1] or permuted to NCHW (permuting made the first
# convolution see 224 channels instead of 3).
x2 = torch.tensor(x1.numpy())
x2.shape
torch.Size([1, 3, 224, 224])
# Undo the [-1, 1] scaling of the first image in x1 and display it as uint8
# pixels, to check visually that the preprocessing kept the picture intact.
preview = (x1[0] + 1) * 127.5
Chitra(preview.numpy().astype("uint8")).imshow()
from chitra.core import IMAGENET_LABELS

# Keras' MobileNetV2 classifier head already applies softmax
# (classifier_activation="softmax" by default), so predict() returns class
# probabilities. Applying tf.math.softmax again would flatten them and break
# any numerical comparison with the converted PyTorch model's output.
res1 = tf_model.predict(x1)
IMAGENET_LABELS[tf.argmax(res1, 1).numpy()[0]]
'pinwheel'
# Run the converted PyTorch model on the prepared input and decode its top
# class; it should agree with the TensorFlow label computed above.
with torch.no_grad():  # inference only, no autograd graph needed
    res2 = my_model(x2)
IMAGENET_LABELS[torch.argmax(res2, 1).item()]
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-252-d9aab2a98c5d> in <module> ----> 1 res2 = my_model(x2) 2 # IMAGENET_LABELS[torch.argmax(res2).item()] ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 887 result = self._slow_forward(*input, **kwargs) 888 else: --> 889 result = self.forward(*input, **kwargs) 890 for hook in itertools.chain( 891 _global_forward_hooks.values(), ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input) 117 def forward(self, input): 118 for module in self: --> 119 input = module(input) 120 return input 121 ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 887 result = self._slow_forward(*input, **kwargs) 888 else: --> 889 result = self.forward(*input, **kwargs) 890 for hook in itertools.chain( 891 _global_forward_hooks.values(), ~/miniconda3/envs/torch/lib/python3.8/site-packages/onnx2pytorch/convert/model.py in forward(self, *input) 132 activations[out_op_id] = op(in_activations[0]) 133 else: --> 134 activations[out_op_id] = op(*in_activations) 135 136 if self.debug: ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 887 result = self._slow_forward(*input, **kwargs) 888 else: --> 889 result = self.forward(*input, **kwargs) 890 for hook in itertools.chain( 891 _global_forward_hooks.values(), ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input) 117 def forward(self, input): 118 for module in self: --> 119 input = module(input) 120 return input 121 ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 887 result = self._slow_forward(*input, **kwargs) 888 else: --> 889 result = 
self.forward(*input, **kwargs) 890 for hook in itertools.chain( 891 _global_forward_hooks.values(), ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input) 397 398 def forward(self, input: Tensor) -> Tensor: --> 399 return self._conv_forward(input, self.weight, self.bias) 400 401 class Conv3d(_ConvNd): ~/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias) 393 weight, bias, self.stride, 394 _pair(0), self.dilation, self.groups) --> 395 return F.conv2d(input, weight, bias, self.stride, 396 self.padding, self.dilation, self.groups) 397 RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[1, 224, 4, 225] to have 3 channels, but got 224 channels instead
# Display the converted PyTorch model's module structure (notebook repr).
my_model
Sequential(
(0): ConvertModel(
(Conv_mobilenetv2_1.00_224/bn_Conv1/FusedBatchNormV3:0): Sequential(
(0): ConstantPad2d(padding=[0, 1, 0, 1], value=0)
(1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
)
(Clip_mobilenetv2_1.00_224/Conv1_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/expanded_conv_depthwise_BN/FusedBatchNormV3:0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
(Clip_mobilenetv2_1.00_224/expanded_conv_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/expanded_conv_project_BN/FusedBatchNormV3:0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
(Conv_mobilenetv2_1.00_224/block_1_expand/Conv2D:0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(BatchNormalization_mobilenetv2_1.00_224/block_1_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(96, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(Clip_mobilenetv2_1.00_224/block_1_expand_relu/Relu6:0): clamp()
(Split_Split__8143:0): Split()
(Pad_mobilenetv2_1.00_224/block_1_pad/Pad:0): Pad()
(Conv_mobilenetv2_1.00_224/block_1_depthwise_BN/FusedBatchNormV3:0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96)
(Clip_mobilenetv2_1.00_224/block_1_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_1_project_BN/FusedBatchNormV3:0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))
(Conv_mobilenetv2_1.00_224/block_2_expand/Conv2D:0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
(BatchNormalization_mobilenetv2_1.00_224/block_2_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(144, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(Clip_mobilenetv2_1.00_224/block_2_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_2_depthwise_BN/FusedBatchNormV3:0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144)
(Clip_mobilenetv2_1.00_224/block_2_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_2_project_BN/FusedBatchNormV3:0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1))
(Add_mobilenetv2_1.00_224/block_2_add/add:0): Add()
(Conv_mobilenetv2_1.00_224/block_3_expand_BN/FusedBatchNormV3:0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1))
(Clip_mobilenetv2_1.00_224/block_3_expand_relu/Relu6:0): clamp()
(Pad_mobilenetv2_1.00_224/block_3_pad/Pad:0): Pad()
(Conv_mobilenetv2_1.00_224/block_3_depthwise_BN/FusedBatchNormV3:0): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), groups=144)
(Clip_mobilenetv2_1.00_224/block_3_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_3_project_BN/FusedBatchNormV3:0): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1))
(Conv_mobilenetv2_1.00_224/block_4_expand/Conv2D:0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(BatchNormalization_mobilenetv2_1.00_224/block_4_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(192, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(Clip_mobilenetv2_1.00_224/block_4_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_4_depthwise_BN/FusedBatchNormV3:0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192)
(Clip_mobilenetv2_1.00_224/block_4_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_4_project_BN/FusedBatchNormV3:0): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
(Add_mobilenetv2_1.00_224/block_4_add/add:0): Add()
(Conv_mobilenetv2_1.00_224/block_5_expand_BN/FusedBatchNormV3:0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1))
(Clip_mobilenetv2_1.00_224/block_5_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_5_depthwise_BN/FusedBatchNormV3:0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192)
(Clip_mobilenetv2_1.00_224/block_5_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_5_project_BN/FusedBatchNormV3:0): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
(Add_mobilenetv2_1.00_224/block_5_add/add:0): Add()
(Conv_mobilenetv2_1.00_224/block_6_expand_BN/FusedBatchNormV3:0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1))
(Clip_mobilenetv2_1.00_224/block_6_expand_relu/Relu6:0): clamp()
(Pad_mobilenetv2_1.00_224/block_6_pad/Pad:0): Pad()
(Conv_mobilenetv2_1.00_224/block_6_depthwise_BN/FusedBatchNormV3:0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), groups=192)
(Clip_mobilenetv2_1.00_224/block_6_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_6_project_BN/FusedBatchNormV3:0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
(Conv_mobilenetv2_1.00_224/block_7_expand/Conv2D:0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(BatchNormalization_mobilenetv2_1.00_224/block_7_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(384, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(Clip_mobilenetv2_1.00_224/block_7_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_7_depthwise_BN/FusedBatchNormV3:0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)
(Clip_mobilenetv2_1.00_224/block_7_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_7_project_BN/FusedBatchNormV3:0): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1))
(Add_mobilenetv2_1.00_224/block_7_add/add:0): Add()
(Conv_mobilenetv2_1.00_224/block_8_expand_BN/FusedBatchNormV3:0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1))
(Clip_mobilenetv2_1.00_224/block_8_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_8_depthwise_BN/FusedBatchNormV3:0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)
(Clip_mobilenetv2_1.00_224/block_8_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_8_project_BN/FusedBatchNormV3:0): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1))
(Add_mobilenetv2_1.00_224/block_8_add/add:0): Add()
(Conv_mobilenetv2_1.00_224/block_9_expand_BN/FusedBatchNormV3:0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1))
(Clip_mobilenetv2_1.00_224/block_9_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_9_depthwise_BN/FusedBatchNormV3:0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)
(Clip_mobilenetv2_1.00_224/block_9_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_9_project_BN/FusedBatchNormV3:0): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1))
(Add_mobilenetv2_1.00_224/block_9_add/add:0): Add()
(Conv_mobilenetv2_1.00_224/block_10_expand_BN/FusedBatchNormV3:0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1))
(Clip_mobilenetv2_1.00_224/block_10_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_10_depthwise_BN/FusedBatchNormV3:0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)
(Clip_mobilenetv2_1.00_224/block_10_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_10_project_BN/FusedBatchNormV3:0): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))
(Conv_mobilenetv2_1.00_224/block_11_expand/Conv2D:0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(BatchNormalization_mobilenetv2_1.00_224/block_11_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(576, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(Clip_mobilenetv2_1.00_224/block_11_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_11_depthwise_BN/FusedBatchNormV3:0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576)
(Clip_mobilenetv2_1.00_224/block_11_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_11_project_BN/FusedBatchNormV3:0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
(Add_mobilenetv2_1.00_224/block_11_add/add:0): Add()
(Conv_mobilenetv2_1.00_224/block_12_expand_BN/FusedBatchNormV3:0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1))
(Clip_mobilenetv2_1.00_224/block_12_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_12_depthwise_BN/FusedBatchNormV3:0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576)
(Clip_mobilenetv2_1.00_224/block_12_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_12_project_BN/FusedBatchNormV3:0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
(Add_mobilenetv2_1.00_224/block_12_add/add:0): Add()
(Conv_mobilenetv2_1.00_224/block_13_expand_BN/FusedBatchNormV3:0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1))
(Clip_mobilenetv2_1.00_224/block_13_expand_relu/Relu6:0): clamp()
(Pad_mobilenetv2_1.00_224/block_13_pad/Pad:0): Pad()
(Conv_mobilenetv2_1.00_224/block_13_depthwise_BN/FusedBatchNormV3:0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), groups=576)
(Clip_mobilenetv2_1.00_224/block_13_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_13_project_BN/FusedBatchNormV3:0): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1))
(Conv_mobilenetv2_1.00_224/block_14_expand/Conv2D:0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(BatchNormalization_mobilenetv2_1.00_224/block_14_expand_BN/FusedBatchNormV3:0): BatchNormUnsafe(960, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(Clip_mobilenetv2_1.00_224/block_14_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_14_depthwise_BN/FusedBatchNormV3:0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960)
(Clip_mobilenetv2_1.00_224/block_14_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_14_project_BN/FusedBatchNormV3:0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1))
(Add_mobilenetv2_1.00_224/block_14_add/add:0): Add()
(Conv_mobilenetv2_1.00_224/block_15_expand_BN/FusedBatchNormV3:0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1))
(Clip_mobilenetv2_1.00_224/block_15_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_15_depthwise_BN/FusedBatchNormV3:0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960)
(Clip_mobilenetv2_1.00_224/block_15_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_15_project_BN/FusedBatchNormV3:0): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1))
(Add_mobilenetv2_1.00_224/block_15_add/add:0): Add()
(Conv_mobilenetv2_1.00_224/block_16_expand_BN/FusedBatchNormV3:0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1))
(Clip_mobilenetv2_1.00_224/block_16_expand_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_16_depthwise_BN/FusedBatchNormV3:0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960)
(Clip_mobilenetv2_1.00_224/block_16_depthwise_relu/Relu6:0): clamp()
(Conv_mobilenetv2_1.00_224/block_16_project_BN/FusedBatchNormV3:0): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1))
(Conv_mobilenetv2_1.00_224/Conv_1/Conv2D:0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
(BatchNormalization_mobilenetv2_1.00_224/Conv_1_bn/FusedBatchNormV3:0): BatchNormUnsafe(1280, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(Clip_mobilenetv2_1.00_224/out_relu/Relu6:0): clamp()
(GlobalAveragePool_mobilenetv2_1.00_224/global_average_pooling2d_9/Mean:0): GlobalAveragePool()
(Squeeze_mobilenetv2_1.00_224/global_average_pooling2d_9/Mean_Squeeze__8183:0): Squeeze()
(MatMul_mobilenetv2_1.00_224/predictions/BiasAdd:0): Linear(in_features=1280, out_features=1000, bias=True)
(Softmax_predictions): Softmax(dim=None)
)
(1): Sequential(
(0): ConstantPad2d(padding=[0, 1, 0, 1], value=0)
(1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
)
(2): ConstantPad2d(padding=[0, 1, 0, 1], value=0)
(3): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
(4): clamp()
(5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
(6): clamp()
(7): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
(8): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(9): BatchNormUnsafe(96, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(10): clamp()
(11): Split()
(12): Pad()
(13): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96)
(14): clamp()
(15): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))
(16): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
(17): BatchNormUnsafe(144, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(18): clamp()
(19): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144)
(20): clamp()
(21): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1))
(22): Add()
(23): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1))
(24): clamp()
(25): Pad()
(26): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), groups=144)
(27): clamp()
(28): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1))
(29): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(30): BatchNormUnsafe(192, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(31): clamp()
(32): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192)
(33): clamp()
(34): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
(35): Add()
(36): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1))
(37): clamp()
(38): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192)
(39): clamp()
(40): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
(41): Add()
(42): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1))
(43): clamp()
(44): Pad()
(45): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), groups=192)
(46): clamp()
(47): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
(48): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(49): BatchNormUnsafe(384, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(50): clamp()
(51): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)
(52): clamp()
(53): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1))
(54): Add()
(55): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1))
(56): clamp()
(57): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)
(58): clamp()
(59): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1))
(60): Add()
(61): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1))
(62): clamp()
(63): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)
(64): clamp()
(65): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1))
(66): Add()
(67): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1))
(68): clamp()
(69): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)
(70): clamp()
(71): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))
(72): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(73): BatchNormUnsafe(576, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(74): clamp()
(75): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576)
(76): clamp()
(77): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
(78): Add()
(79): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1))
(80): clamp()
(81): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576)
(82): clamp()
(83): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
(84): Add()
(85): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1))
(86): clamp()
(87): Pad()
(88): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), groups=576)
(89): clamp()
(90): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1))
(91): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(92): BatchNormUnsafe(960, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(93): clamp()
(94): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960)
(95): clamp()
(96): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1))
(97): Add()
(98): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1))
(99): clamp()
(100): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960)
(101): clamp()
(102): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1))
(103): Add()
(104): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1))
(105): clamp()
(106): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960)
(107): clamp()
(108): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1))
(109): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
(110): BatchNormUnsafe(1280, eps=0.0010000000474974513, momentum=0.1, affine=True, track_running_stats=True)
(111): clamp()
(112): GlobalAveragePool()
(113): Squeeze()
(114): Linear(in_features=1280, out_features=1000, bias=True)
)
# Show the PyTorch input shape next to the PyTorch output shape.
(x2.shape, res2.shape)
(torch.Size([1, 224, 224, 3]), torch.Size([9, 1000]))