# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeepLabv3."""

import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
    DepthwiseConv2dNative, SpaceToBatch, BatchToSpace


def _atrous_padding(size, rate):
    """
    Return the trailing padding that makes ``size + 2 * rate`` a multiple of ``rate``.

    The atrous branches pad a feature map by ``rate`` on both sides plus this
    extra amount at the end, so the padded extent is divisible by ``rate``.
    This closed form equals the smallest non-negative value satisfying that
    condition; it replaces a bounded search that silently returned 0 when no
    multiple was found within 100 steps.

    Args:
        size (float): Spatial extent; expected to hold an integral value (e.g. from np.ceil).
        rate (int): Atrous rate, must be positive.

    Returns:
        int, the padding, in the range [0, rate).

    Raises:
        ValueError: If ``rate`` is not positive.
    """
    if rate <= 0:
        raise ValueError("atrous rate must be positive, got {}".format(rate))
    return int((-int(size)) % rate)


class ASPPSampleBlock(nn.Cell):
    """
    ASPP sample block.

    Bilinearly resizes its input (align_corners=True) to
    ``int((shape * scale_size + 1) / output_stride + 1)`` along each of H and W.
    """
    def __init__(self, feature_shape, scale_size, output_stride):
        super(ASPPSampleBlock, self).__init__()
        sample_h = (feature_shape[0] * scale_size + 1) / output_stride + 1
        sample_w = (feature_shape[1] * scale_size + 1) / output_stride + 1
        self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)

    def construct(self, x):
        return self.sample(x)


class ASPP(nn.Cell):
    """
    ASPP model for DeepLabv3.

    Args:
        channel (int): Input channel.
        depth (int): Output channel of every branch.
        feature_shape (list): Input image spatial shape, [H, W] (callers pass
            the H and W of the network input, not of the backbone feature map).
        scale_sizes (list): Input scales for multi-scale feature extraction.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling; may be None.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.

    Returns:
        Tensor, the channel-wise concatenation of the image-pooling branch, the
        1x1 branch and one branch per atrous rate, i.e. ``depth * (2 + len(atrous_rates))`` channels.

    Examples:
        >>> ASPP(channel=2048, depth=256, feature_shape=[224, 224], scale_sizes=[1.0],
        ...      atrous_rates=[6], output_stride=16)
    """
    def __init__(self, channel, depth, feature_shape, scale_sizes,
                 atrous_rates, output_stride, fine_tune_batch_norm=False):
        super(ASPP, self).__init__()
        # 1x1 convolution branch.
        self.aspp0 = _conv_bn_relu(channel,
                                   depth,
                                   ksize=1,
                                   stride=1,
                                   use_batch_statistics=fine_tune_batch_norm)
        self.atrous_rates = []
        if atrous_rates is not None:
            self.atrous_rates = atrous_rates

        # Layers shared by every atrous branch:
        # SpaceToBatch -> depthwise 3x3 -> BatchToSpace -> BN -> ReLU -> pointwise 1x1.
        self.aspp_pointwise = _conv_bn_relu(channel,
                                            depth,
                                            ksize=1,
                                            stride=1,
                                            use_batch_statistics=fine_tune_batch_norm)
        self.aspp_depth_depthwiseconv = DepthwiseConv2dNative(channel,
                                                              channel_multiplier=1,
                                                              kernel_size=3,
                                                              stride=1,
                                                              dilation=1,
                                                              pad_mode="valid")
        self.aspp_depth_bn = nn.BatchNorm2d(1 * channel, use_batch_statistics=fine_tune_batch_norm)
        self.aspp_depth_relu = nn.ReLU()
        self.aspp_depths = []
        self.aspp_depth_spacetobatchs = []
        self.aspp_depth_batchtospaces = []

        # One SpaceToBatch/BatchToSpace pair per (scale, rate), laid out scale-major
        # so construct() can index them as ``i + scale_index * len(atrous_rates)``.
        for scale_size in scale_sizes:
            if atrous_rates is None:
                break
            # NOTE(review): divides by a hard-coded 16 rather than output_stride;
            # presumably matches the backbone's fixed stride - confirm before changing.
            depth_size_h = np.ceil((feature_shape[0] * scale_size) / 16)
            depth_size_w = np.ceil((feature_shape[1] * scale_size) / 16)
            for rate in atrous_rates:
                padding_h = _atrous_padding(depth_size_h, rate)
                padding_w = _atrous_padding(depth_size_w, rate)
                paddings = [[rate, rate + padding_h],
                            [rate, rate + padding_w]]
                self.aspp_depth_spacetobatchs.append(SpaceToBatch(rate, paddings))
                # Crop away exactly the extra trailing padding added above.
                crops = [[0, padding_h], [0, padding_w]]
                self.aspp_depth_batchtospaces.append(BatchToSpace(rate, crops))
        self.aspp_depths = nn.CellList(self.aspp_depths)
        self.aspp_depth_spacetobatchs = nn.CellList(self.aspp_depth_spacetobatchs)
        self.aspp_depth_batchtospaces = nn.CellList(self.aspp_depth_batchtospaces)

        # Not used by construct(); retained so the attribute still exists for callers.
        self.global_pooling = nn.AvgPool2d(kernel_size=(int(feature_shape[0]), int(feature_shape[1])))
        # Image-pooling branch: one average pool per scale whose kernel spans the
        # whole feature map at that scale (height and width computed separately).
        self.global_poolings = []
        for scale_size in scale_sizes:
            pooling_h = np.ceil((feature_shape[0] * scale_size) / output_stride)
            pooling_w = np.ceil((feature_shape[1] * scale_size) / output_stride)
            self.global_poolings.append(nn.AvgPool2d(kernel_size=(int(pooling_h), int(pooling_w))))
        self.global_poolings = nn.CellList(self.global_poolings)
        self.conv_bn = _conv_bn_relu(channel,
                                     depth,
                                     ksize=1,
                                     stride=1,
                                     use_batch_statistics=fine_tune_batch_norm)
        self.samples = []
        for scale_size in scale_sizes:
            self.samples.append(ASPPSampleBlock(feature_shape, scale_size, output_stride))
        self.samples = nn.CellList(self.samples)
        self.feature_shape = feature_shape
        self.concat = P.Concat(axis=1)

    def construct(self, x, scale_index=0):
        aspp0 = self.aspp0(x)
        # Image-pooling branch, resized back to the feature-map resolution.
        aspp1 = self.global_poolings[scale_index](x)
        aspp1 = self.conv_bn(aspp1)
        aspp1 = self.samples[scale_index](aspp1)
        output = self.concat((aspp1, aspp0))

        for i in range(len(self.atrous_rates)):
            aspp_i = self.aspp_depth_spacetobatchs[i + scale_index * len(self.atrous_rates)](x)
            aspp_i = self.aspp_depth_depthwiseconv(aspp_i)
            aspp_i = self.aspp_depth_batchtospaces[i + scale_index * len(self.atrous_rates)](aspp_i)
            aspp_i = self.aspp_depth_bn(aspp_i)
            aspp_i = self.aspp_depth_relu(aspp_i)
            aspp_i = self.aspp_pointwise(aspp_i)
            output = self.concat((output, aspp_i))
        return output


class DecoderSampleBlock(nn.Cell):
    """
    Decoder sample block.

    Bilinearly resizes its input (align_corners=True) to
    ``int((shape * scale_size + 1) / decoder_output_stride + 1)`` along each of H and W.
    """
    def __init__(self, feature_shape, scale_size=1.0, decoder_output_stride=4):
        super(DecoderSampleBlock, self).__init__()
        sample_h = (feature_shape[0] * scale_size + 1) / decoder_output_stride + 1
        sample_w = (feature_shape[1] * scale_size + 1) / decoder_output_stride + 1
        self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)

    def construct(self, x):
        return self.sample(x)


class Decoder(nn.Cell):
    """
    Decode module for DeepLabv3.

    Args:
        low_level_channel (int): Low level input channel.
        channel (int): Input channel of the high level feature.
        depth (int): Output channel.
        feature_shape (list): Input image spatial shape, [H, W].
        scale_sizes (list): Input scales for multi-scale feature extraction.
        decoder_output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.

    Returns:
        Tensor, output tensor with ``depth`` channels.

    Examples:
        >>> Decoder(low_level_channel=256, channel=256, depth=256, feature_shape=[224, 224],
        ...         scale_sizes=[1.0], decoder_output_stride=4, fine_tune_batch_norm=False)
    """
    def __init__(self,
                 low_level_channel,
                 channel,
                 depth,
                 feature_shape,
                 scale_sizes,
                 decoder_output_stride,
                 fine_tune_batch_norm):
        super(Decoder, self).__init__()
        # Project the low level feature to 48 channels before concatenation,
        # hence the ``channel + 48`` input width of the first decoder layers.
        self.feature_projection = _conv_bn_relu(low_level_channel, 48, ksize=1, stride=1,
                                                pad_mode="same", use_batch_statistics=fine_tune_batch_norm)
        self.decoder_depth0 = _deep_conv_bn_relu(channel + 48,
                                                 channel_multiplier=1,
                                                 ksize=3,
                                                 stride=1,
                                                 pad_mode="same",
                                                 dilation=1,
                                                 use_batch_statistics=fine_tune_batch_norm)
        self.decoder_pointwise0 = _conv_bn_relu(channel + 48,
                                                depth,
                                                ksize=1,
                                                stride=1,
                                                use_batch_statistics=fine_tune_batch_norm)
        self.decoder_depth1 = _deep_conv_bn_relu(depth,
                                                 channel_multiplier=1,
                                                 ksize=3,
                                                 stride=1,
                                                 pad_mode="same",
                                                 dilation=1,
                                                 use_batch_statistics=fine_tune_batch_norm)
        self.decoder_pointwise1 = _conv_bn_relu(depth,
                                                depth,
                                                ksize=1,
                                                stride=1,
                                                use_batch_statistics=fine_tune_batch_norm)
        self.depth = depth
        self.concat = P.Concat(axis=1)
        self.samples = []
        for scale_size in scale_sizes:
            self.samples.append(DecoderSampleBlock(feature_shape, scale_size, decoder_output_stride))
        self.samples = nn.CellList(self.samples)

    def construct(self, x, low_level_feature, scale_index):
        low_level_feature = self.feature_projection(low_level_feature)
        # Bring both inputs to the decoder resolution before concatenating them.
        low_level_feature = self.samples[scale_index](low_level_feature)
        x = self.samples[scale_index](x)
        output = self.concat((x, low_level_feature))
        output = self.decoder_depth0(output)
        output = self.decoder_pointwise0(output)
        output = self.decoder_depth1(output)
        output = self.decoder_pointwise1(output)
        return output


class SingleDeepLabV3(nn.Cell):
    """
    DeepLabv3 Network.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N,C,H,W].
        backbone (Cell): Backbone Network.
        channel (int): Resnet output channel.
        depth (int): ASPP block depth.
        scale_sizes (list): Input scales for multi-scale feature extraction.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
        decoder_output_stride (int): The ratio of input to output spatial resolution
            in the decoder; None disables the decoder.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> SingleDeepLabV3(num_classes=10,
        ...                 feature_shape=[1, 3, 224, 224],
        ...                 backbone=resnet50_dl(),
        ...                 channel=2048,
        ...                 depth=256,
        ...                 scale_sizes=[1.0],
        ...                 atrous_rates=[6],
        ...                 decoder_output_stride=4,
        ...                 output_stride=16)
    """

    def __init__(self,
                 num_classes,
                 feature_shape,
                 backbone,
                 channel,
                 depth,
                 scale_sizes,
                 atrous_rates,
                 decoder_output_stride,
                 output_stride,
                 fine_tune_batch_norm=False):
        super(SingleDeepLabV3, self).__init__()
        self.num_classes = num_classes
        self.channel = channel
        self.depth = depth
        # Ascending order: construct() picks the first scale large enough for the input.
        self.scale_sizes = []
        for scale_size in np.sort(scale_sizes):
            self.scale_sizes.append(scale_size)
        self.net = backbone
        self.aspp = ASPP(channel=self.channel,
                         depth=self.depth,
                         feature_shape=[feature_shape[2],
                                        feature_shape[3]],
                         scale_sizes=self.scale_sizes,
                         atrous_rates=atrous_rates,
                         output_stride=output_stride,
                         fine_tune_batch_norm=fine_tune_batch_norm)

        atrous_rates_len = 0
        if atrous_rates is not None:
            atrous_rates_len = len(atrous_rates)
        # ASPP concatenates the pooling branch, the 1x1 branch and one branch per rate.
        self.fc1 = _conv_bn_relu(depth * (2 + atrous_rates_len), depth,
                                 ksize=1,
                                 stride=1,
                                 use_batch_statistics=fine_tune_batch_norm)
        self.fc2 = nn.Conv2d(depth,
                             num_classes,
                             kernel_size=1,
                             stride=1,
                             has_bias=True)
        self.upsample = P.ResizeBilinear((int(feature_shape[2]),
                                          int(feature_shape[3])),
                                         align_corners=True)
        self.samples = []
        for scale_size in self.scale_sizes:
            self.samples.append(SampleBlock(feature_shape, scale_size))
        self.samples = nn.CellList(self.samples)
        self.feature_shape = [float(feature_shape[0]), float(feature_shape[1]), float(feature_shape[2]),
                              float(feature_shape[3])]

        # Pads H and W by one pixel on each side.
        self.pad = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1)))
        self.dropout = nn.Dropout(keep_prob=0.9)
        self.shape = P.Shape()
        self.decoder_output_stride = decoder_output_stride
        if decoder_output_stride is not None:
            self.decoder = Decoder(low_level_channel=depth,
                                   channel=depth,
                                   depth=depth,
                                   feature_shape=[feature_shape[2],
                                                  feature_shape[3]],
                                   scale_sizes=self.scale_sizes,
                                   decoder_output_stride=decoder_output_stride,
                                   fine_tune_batch_norm=fine_tune_batch_norm)

    def construct(self, x, scale_index=0):
        # Linear map taking 0 -> -1 and 255 -> 1.
        x = (2.0 / 255.0) * x - 1.0
        x = self.pad(x)
        low_level_feature, feature_map = self.net(x)
        # Select the first (smallest) configured scale whose height covers the
        # unpadded input height; scale_index counts how many scales were skipped.
        for scale_size in self.scale_sizes:
            if scale_size * self.feature_shape[2] + 1.0 >= self.shape(x)[2] - 2:
                output = self.aspp(feature_map, scale_index)
                output = self.fc1(output)
                if self.decoder_output_stride is not None:
                    output = self.decoder(output, low_level_feature, scale_index)
                output = self.fc2(output)
                output = self.samples[scale_index](output)
                return output
            scale_index += 1
        return feature_map


class SampleBlock(nn.Cell):
    """
    Sample block.

    Bilinearly resizes its input (align_corners=True) to
    ``ceil(H * scale_size) x ceil(W * scale_size)``, where H and W are
    ``feature_shape[2]`` and ``feature_shape[3]``.
    """
    def __init__(self,
                 feature_shape,
                 scale_size=1.0):
        super(SampleBlock, self).__init__()
        sample_h = np.ceil(float(feature_shape[2]) * scale_size)
        sample_w = np.ceil(float(feature_shape[3]) * scale_size)
        self.sample = P.ResizeBilinear((int(sample_h), int(sample_w)), align_corners=True)

    def construct(self, x):
        return self.sample(x)


class DeepLabV3(nn.Cell):
    """
    DeepLabV3 model with multi-scale training and inference.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N,C,H,W].
        backbone (Cell): Backbone network. If None, ``resnet50_dl(fine_tune_batch_norm)`` is built.
        channel (int): Backbone output channel.
        depth (int): ASPP block depth.
        infer_scale_sizes (list): Scales used at inference; None means no inference scales.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
        decoder_output_stride (int): The ratio of input to output spatial resolution
            in the decoder; None disables the decoder.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.
        image_pyramid (list): Scales used in training; None means [1.0].

    Returns:
        Tensor. In training mode, the mean of the logits over ``image_pyramid``;
        otherwise the mean of the softmax probabilities over ``infer_scale_sizes``
        (an empty tuple if there are no inference scales).
    """
    def __init__(self, num_classes, feature_shape, backbone, channel, depth, infer_scale_sizes, atrous_rates,
                 decoder_output_stride, output_stride, fine_tune_batch_norm, image_pyramid):
        super(DeepLabV3, self).__init__()
        # Normalize None to an empty list *before* storing it: construct() takes
        # len() of this attribute and the scale list below iterates it.
        if infer_scale_sizes is None:
            infer_scale_sizes = []
        self.infer_scale_sizes = infer_scale_sizes
        if image_pyramid is None:
            image_pyramid = [1.0]
        self.image_pyramid = image_pyramid

        # Training scales first, then inference scales; construct() relies on
        # this ordering when indexing self.samples.
        scale_sizes = list(image_pyramid) + list(infer_scale_sizes)
        self.samples = nn.CellList([SampleBlock(feature_shape, scale_size) for scale_size in scale_sizes])
        # Honour the caller's backbone instead of always building a second one.
        if backbone is None:
            backbone = resnet50_dl(fine_tune_batch_norm)
        self.deeplabv3 = SingleDeepLabV3(num_classes=num_classes,
                                         feature_shape=feature_shape,
                                         backbone=backbone,
                                         channel=channel,
                                         depth=depth,
                                         scale_sizes=scale_sizes,
                                         atrous_rates=atrous_rates,
                                         decoder_output_stride=decoder_output_stride,
                                         output_stride=output_stride,
                                         fine_tune_batch_norm=fine_tune_batch_norm)
        self.softmax = P.Softmax(axis=1)
        self.concat = P.Concat(axis=2)
        self.expand_dims = P.ExpandDims()
        self.reduce_mean = P.ReduceMean()
        self.sample_common = P.ResizeBilinear((int(feature_shape[2]),
                                               int(feature_shape[3])),
                                              align_corners=True)

    def construct(self, x):
        logits = ()
        if self.training:
            if len(self.image_pyramid) >= 1:
                if self.image_pyramid[0] == 1:
                    logits = self.deeplabv3(x)
                else:
                    x1 = self.samples[0](x)
                    logits = self.deeplabv3(x1)
                logits = self.sample_common(logits)
                # Stack per-scale logits on a new axis 2 and average over it.
                logits = self.expand_dims(logits, 2)
                for i in range(len(self.image_pyramid) - 1):
                    x_i = self.samples[i + 1](x)
                    logits_i = self.deeplabv3(x_i)
                    logits_i = self.sample_common(logits_i)
                    logits_i = self.expand_dims(logits_i, 2)
                    logits = self.concat((logits, logits_i))
                logits = self.reduce_mean(logits, 2)
            return logits
        if len(self.infer_scale_sizes) >= 1:
            # Inference samplers follow the training ones in self.samples.
            infer_index = len(self.image_pyramid)
            x1 = self.samples[infer_index](x)
            logits = self.deeplabv3(x1)
            logits = self.sample_common(logits)
            logits = self.softmax(logits)
            logits = self.expand_dims(logits, 2)
            for i in range(len(self.infer_scale_sizes) - 1):
                x_i = self.samples[i + 1 + infer_index](x)
                logits_i = self.deeplabv3(x_i)
                logits_i = self.sample_common(logits_i)
                logits_i = self.softmax(logits_i)
                logits_i = self.expand_dims(logits_i, 2)
                logits = self.concat((logits, logits_i))
            logits = self.reduce_mean(logits, 2)
        return logits


def deeplabv3_resnet50(num_classes, feature_shape, image_pyramid,
                       infer_scale_sizes, atrous_rates=None, decoder_output_stride=None,
                       output_stride=16, fine_tune_batch_norm=False):
    """
    ResNet50 based DeepLabv3 network.

    Args:
        num_classes (int): Class number.
        feature_shape (list): Input image shape, [N,C,H,W].
        image_pyramid (list): Input scales used in training; None means [1.0].
        infer_scale_sizes (list): The scales to resize images to for inference.
        atrous_rates (list): Atrous rates for atrous spatial pyramid pooling.
        decoder_output_stride (int): The ratio of input to output spatial resolution
            in the decoder; None disables the decoder.
        output_stride (int): The ratio of input to output spatial resolution.
        fine_tune_batch_norm (bool): Fine tune the batch norm parameters or not.

    Returns:
        Cell, cell instance of ResNet50 based DeepLabv3 neural network.

    Examples:
        >>> deeplabv3_resnet50(100, [1, 3, 224, 224], [1.0], [1.0])
    """
    return DeepLabV3(num_classes=num_classes,
                     feature_shape=feature_shape,
                     backbone=resnet50_dl(fine_tune_batch_norm),
                     channel=2048,
                     depth=256,
                     infer_scale_sizes=infer_scale_sizes,
                     atrous_rates=atrous_rates,
                     decoder_output_stride=decoder_output_stride,
                     output_stride=output_stride,
                     fine_tune_batch_norm=fine_tune_batch_norm,
                     image_pyramid=image_pyramid)