# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Resnet examples."""

# pylint: disable=missing-docstring, arguments-differ

import mindspore.nn as nn
from mindspore.ops import operations as P


def conv3x3(in_channels, out_channels, stride=1, padding=1, pad_mode='pad'):
    """Build a 2-D convolution cell with a 3x3 kernel.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride, forwarded to ``nn.Conv2d``.
        padding: Padding amount, forwarded to ``nn.Conv2d``.
        pad_mode: Padding mode, forwarded to ``nn.Conv2d``.

    Returns:
        The configured ``nn.Conv2d`` cell.
    """
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode)


def conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='pad'):
    """Build a 2-D convolution cell with a 1x1 kernel.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride, forwarded to ``nn.Conv2d``.
        padding: Padding amount, forwarded to ``nn.Conv2d``.
        pad_mode: Padding mode, forwarded to ``nn.Conv2d``.

    Returns:
        The configured ``nn.Conv2d`` cell.
    """
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode)


class ResidualBlock(nn.Cell):
    """Residual block: 1x1 conv -> 3x3 conv -> 1x1 conv plus a skip connection.

    The two inner convolutions work on ``out_channels // expansion`` channels;
    the last 1x1 convolution restores ``out_channels``. Each convolution is
    followed by batch normalization.

    Args:
        in_channels: Number of channels of the block input.
        out_channels: Number of channels of the block output.
        stride: Stride of the 3x3 convolution and of the down-sample
            convolution on the skip path.
        down_sample: When True, the skip path is passed through a 1x1
            convolution and batch normalization before the addition;
            otherwise the input is added unchanged.
    """
    # Ratio between the block's output channels and its inner channels.
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()
        self.downsample = down_sample

        # NOTE: the down-sample layers are constructed even when
        # ``down_sample`` is False (they are simply not called in
        # ``construct``); this is kept as-is so the set of cells owned by the
        # block does not change.
        self.conv_down_sample = conv1x1(in_channels, out_channels,
                                        stride=stride, padding=0)
        self.bn_down_sample = nn.BatchNorm2d(out_channels)
        self.add = P.Add()

    def construct(self, x):
        """Run the block on ``x``.

        Args:
            x: Input to the block; it is fed to ``conv1`` and, on the skip
                path, either used directly or passed through the down-sample
                layers.

        Returns:
            ``relu(main_path(x) + skip_path(x))``.
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResNet50(nn.Cell):
    """ResNet network built from four stages of 3, 4, 6 and 3 blocks.

    Args:
        block: Block class used for every stage. It is called as
            ``block(in_channels, out_channels, stride=..., down_sample=...)``
            and must expose an ``expansion`` attribute.
        num_classes: Number of outputs of the final dense layer.
    """

    def __init__(self, block, num_classes=100):
        super(ResNet50, self).__init__()

        # Stem: 7x7 stride-2 convolution on a 3-channel input, then max pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')

        self.layer1 = self.MakeLayer(
            block, 3, in_channels=64, out_channels=256, stride=1)
        self.layer2 = self.MakeLayer(
            block, 4, in_channels=256, out_channels=512, stride=2)
        self.layer3 = self.MakeLayer(
            block, 6, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = self.MakeLayer(
            block, 3, in_channels=1024, out_channels=2048, stride=2)

        # NOTE(review): the fixed 7x7 pooling window presumably assumes a
        # 7x7 feature map after layer4 (e.g. 224x224 input) -- confirm.
        self.avgpool = nn.AvgPool2d(7, 1)
        self.flatten = P.Flatten()
        # 512 * expansion equals layer4's 2048 output channels only when
        # ``block.expansion`` is 4.
        self.fc = nn.Dense(512 * block.expansion, num_classes)

    def MakeLayer(self, block, layer_num, in_channels, out_channels, stride):
        """Build one stage as a sequence of ``layer_num`` blocks.

        The first block maps ``in_channels`` to ``out_channels`` with the
        given ``stride`` and ``down_sample=True``. The remaining
        ``layer_num - 1`` blocks map ``out_channels`` to ``out_channels``
        with stride 1 and the block's default ``down_sample``.

        Args:
            block: Block class to instantiate.
            layer_num: Total number of blocks in the stage.
            in_channels: Input channels of the first block.
            out_channels: Output channels of every block.
            stride: Stride passed to the first block.

        Returns:
            An ``nn.SequentialCell`` wrapping the blocks in order.
        """
        layers = []
        resblk = block(in_channels, out_channels,
                       stride=stride, down_sample=True)
        layers.append(resblk)

        for _ in range(1, layer_num):
            resblk = block(out_channels, out_channels, stride=1)
            layers.append(resblk)

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Run the network on ``x``.

        Applies the stem (conv, batch norm, ReLU, max pool), the four
        stages, average pooling, flatten and the final dense layer.

        Args:
            x: Network input; fed to a convolution with 3 input channels.

        Returns:
            Output of the final dense layer (``num_classes`` values per
            sample).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x


def resnet50(num_classes=10):
    """Create a ``ResNet50`` built from ``ResidualBlock``.

    Args:
        num_classes: Number of outputs of the final dense layer. Defaults
            to 10, the value this factory previously hard-coded.

    Returns:
        A new ``ResNet50`` instance.
    """
    return ResNet50(ResidualBlock, num_classes)