1# Copyright 2020-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""image""" 16from __future__ import absolute_import 17from __future__ import division 18 19import numbers 20import numpy as np 21 22import mindspore.common.dtype as mstype 23import mindspore.ops as ops 24from mindspore.common.tensor import Tensor 25from mindspore.ops import operations as P 26from mindspore.ops.operations import _inner_ops as inner 27from mindspore.ops import functional as F 28from mindspore.ops.primitive import constexpr, _primexpr 29from mindspore import _checkparam as validator 30from mindspore.nn.layer.conv import Conv2d 31from mindspore.nn.layer.container import CellList 32from mindspore.nn.layer.pooling import AvgPool2d 33from mindspore.nn.layer.activation import ReLU 34from mindspore.nn.cell import Cell 35 36__all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop', 'PixelShuffle', 'PixelUnshuffle'] 37 38 39class ImageGradients(Cell): 40 r""" 41 Returns two tensors, the first is along the height dimension and the second is along the width dimension. 42 43 Assume an image shape is :math:`h*w`, the gradients along the height and the width are :math:`dy` and :math:`dx`, 44 respectively. 45 46 .. math:: 47 dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr 48 0, &if\ i==h-1\end{cases} 49 50 dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr 51 0, &if\ i==w-1\end{cases} 52 53 Inputs: 54 - **images** (Tensor) - The input image data, with format 'NCHW'. 55 56 Outputs: 57 - **dy** (Tensor) - vertical image gradients, the same type and shape as input. 58 - **dx** (Tensor) - horizontal image gradients, the same type and shape as input. 59 60 Raises: 61 ValueError: If length of shape of `images` is not equal to 4. 62 63 Supported Platforms: 64 ``Ascend`` ``GPU`` ``CPU`` 65 66 Examples: 67 >>> import mindspore as ms 68 >>> import numpy as np 69 >>> net = ms.nn.ImageGradients() 70 >>> image = ms.Tensor(np.array([[[[1, 2], [3, 4]]]]), dtype=ms.int32) 71 >>> output = net(image) 72 >>> print(output) 73 (Tensor(shape=[1, 1, 2, 2], dtype=Int32, value= 74 [[[[2, 2], 75 [0, 0]]]]), Tensor(shape=[1, 1, 2, 2], dtype=Int32, value= 76 [[[[1, 0], 77 [1, 0]]]])) 78 """ 79 def __init__(self): 80 super(ImageGradients, self).__init__() 81 82 def construct(self, images): 83 _check_input_4d(F.shape(images), "images", self.cls_name) 84 batch_size, depth, height, width = P.Shape()(images) 85 if height == 1: 86 dy = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0) 87 else: 88 dy = images[:, :, 1:, :] - images[:, :, :height - 1, :] 89 dy_last = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0) 90 dy = P.Concat(2)((dy, dy_last)) 91 92 if width == 1: 93 dx = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0) 94 else: 95 dx = images[:, :, :, 1:] - images[:, :, :, :width - 1] 96 dx_last = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0) 97 dx = P.Concat(3)((dx, dx_last)) 98 return dy, dx 99 100 101def _convert_img_dtype_to_float32(img, max_val): 102 """convert img dtype to float32""" 103 # Usually max_val is 1.0 or 255, we will do the scaling if max_val > 1. 104 # We will scale img pixel value if max_val > 1. and just cast otherwise. 105 ret = F.cast(img, mstype.float32) 106 max_val = F.scalar_cast(max_val, mstype.float32) 107 if max_val > 1.: 108 scale = 1. / max_val 109 ret = ret * scale 110 return ret 111 112 113@constexpr 114def _get_dtype_max(dtype): 115 """get max of the dtype""" 116 np_type = mstype.dtype_to_nptype(dtype) 117 if issubclass(np_type, numbers.Integral): 118 dtype_max = np.float64(np.iinfo(np_type).max).item() 119 else: 120 dtype_max = 1.0 121 return dtype_max 122 123 124@_primexpr 125def _check_input_4d(input_shape, param_name, func_name): 126 if len(input_shape) != 4: 127 raise ValueError(f"For '{func_name}', the dimension of '{param_name}' must be 4d, " 128 f"but got {len(input_shape)}.") 129 130 131@_primexpr 132def _check_input_filter_size(input_shape, param_name, filter_size, func_name): 133 _check_input_4d(input_shape, param_name, func_name) 134 validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, validator.GE, func_name) 135 validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, validator.GE, func_name) 136 137 138@constexpr 139def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name): 140 validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name) 141 142 143def _conv2d(in_channels, out_channels, kernel_size, weight, stride=1, padding=0): 144 return Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 145 weight_init=weight, padding=padding, pad_mode="valid") 146 147 148def _create_window(size, sigma): 149 x_data, y_data = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] 150 x_data = np.expand_dims(x_data, axis=-1).astype(np.float32) 151 x_data = np.expand_dims(x_data, axis=-1) ** 2 152 y_data = np.expand_dims(y_data, axis=-1).astype(np.float32) 153 y_data = np.expand_dims(y_data, axis=-1) ** 2 154 sigma = 2 * sigma ** 2 155 g = np.exp(-(x_data + y_data) / sigma) 156 return np.transpose(g / np.sum(g), (2, 3, 0, 1)) 157 158 159def _split_img(x): 160 _, c, _, _ = F.shape(x) 161 img_split = P.Split(1, c) 162 output = img_split(x) 163 return output, c 164 165 166def _compute_per_channel_loss(c1, c2, img1, img2, conv): 167 """computes ssim index between img1 and img2 per single channel""" 168 dot_img = img1 * img2 169 mu1 = conv(img1) 170 mu2 = conv(img2) 171 mu1_sq = mu1 * mu1 172 mu2_sq = mu2 * mu2 173 mu1_mu2 = mu1 * mu2 174 sigma1_tmp = conv(img1 * img1) 175 sigma1_sq = sigma1_tmp - mu1_sq 176 sigma2_tmp = conv(img2 * img2) 177 sigma2_sq = sigma2_tmp - mu2_sq 178 sigma12_tmp = conv(dot_img) 179 sigma12 = sigma12_tmp - mu1_mu2 180 a = (2 * mu1_mu2 + c1) 181 b = (mu1_sq + mu2_sq + c1) 182 v1 = 2 * sigma12 + c2 183 v2 = sigma1_sq + sigma2_sq + c2 184 ssim = (a * v1) / (b * v2) 185 cs = v1 / v2 186 return ssim, cs 187 188 189def _compute_multi_channel_loss(c1, c2, img1, img2, conv, concat, mean): 190 """computes ssim index between img1 and img2 per color channel""" 191 split_img1, c = _split_img(img1) 192 split_img2, _ = _split_img(img2) 193 multi_ssim = () 194 multi_cs = () 195 for i in range(c): 196 ssim_per_channel, cs_per_channel = _compute_per_channel_loss(c1, c2, split_img1[i], split_img2[i], conv) 197 multi_ssim += (ssim_per_channel,) 198 multi_cs += (cs_per_channel,) 199 200 multi_ssim = concat(multi_ssim) 201 multi_cs = concat(multi_cs) 202 203 ssim = mean(multi_ssim, (2, 3)) 204 cs = mean(multi_cs, (2, 3)) 205 return ssim, cs 206 207 208class SSIM(Cell): 209 r""" 210 Returns SSIM index between two images. 211 212 Its implementation is based on Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004) `Image quality 213 assessment: from error visibility to structural similarity <https://ieeexplore.ieee.org/document/1284395>`_ . 214 215 SSIM is a measure of the similarity of two pictures. 216 Like PSNR, 217 SSIM is often used as an evaluation of image quality. 218 SSIM is a number between 0 and 1, and the larger it is, 219 the smaller the gap between the output image and the undistorted image, that is, the better the image quality. 220 When the two images are exactly the same, SSIM=1. 221 222 .. math:: 223 224 l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\ 225 c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\ 226 s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\ 227 SSIM(x,y)&=l*c*s\\&=\frac{(2\mu_x\mu_y+C_1)(2\sigma_{xy}+C_2)} 228 {(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)}. 229 230 Args: 231 max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images). 232 Default: ``1.0`` . 233 filter_size (int): The size of the Gaussian filter. Default: 11. The value must be greater than or equal to 1. 234 filter_sigma (float): The standard deviation of Gaussian kernel. Default: ``1.5`` . 235 The value must be greater than 0. 236 k1 (float): The constant used to generate c1 in the luminance comparison function. Default: ``0.01`` . 237 k2 (float): The constant used to generate c2 in the contrast comparison function. Default: ``0.03`` . 238 239 Inputs: 240 - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2. 241 - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1. 242 243 Outputs: 244 Tensor, has the same dtype as img1. It is a 1-D tensor with shape N, where N is the batch num of img1. 245 246 Raises: 247 TypeError: If `max_val` is neither int nor float. 248 TypeError: If `k1`, `k2` or `filter_sigma` is not a float. 249 TypeError: If `filter_size` is not an int. 250 ValueError: If `max_val` or `filter_sigma` is less than or equal to 0. 251 ValueError: If `filter_size` is less than 0. 252 253 Supported Platforms: 254 ``Ascend`` ``GPU`` ``CPU`` 255 256 Examples: 257 >>> import mindspore as ms 258 >>> import numpy as np 259 >>> net = ms.nn.SSIM() 260 >>> img1 = ms.Tensor(np.ones([1, 3, 16, 16]).astype(np.float32)) 261 >>> img2 = ms.Tensor(np.ones([1, 3, 16, 16]).astype(np.float32)) 262 >>> output = net(img1, img2) 263 >>> print(output) 264 [1.] 265 """ 266 def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): 267 super(SSIM, self).__init__() 268 validator.check_value_type('max_val', max_val, [int, float], self.cls_name) 269 validator.check_number('max_val', max_val, 0.0, validator.GT, self.cls_name) 270 self.max_val = max_val 271 self.filter_size = validator.check_int(filter_size, 1, validator.GE, 'filter_size', self.cls_name) 272 self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name) 273 self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) 274 self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) 275 window = _create_window(filter_size, filter_sigma) 276 self.conv = _conv2d(1, 1, filter_size, Tensor(window)) 277 self.conv.weight.requires_grad = False 278 self.reduce_mean = P.ReduceMean() 279 self.concat = P.Concat(axis=1) 280 281 def construct(self, img1, img2): 282 _check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name) 283 _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name) 284 inner.SameTypeShape()(img1, img2) 285 dtype_max_val = _get_dtype_max(F.dtype(img1)) 286 max_val = _convert_img_dtype_to_float32(self.max_val, dtype_max_val) 287 img1 = _convert_img_dtype_to_float32(img1, dtype_max_val) 288 img2 = _convert_img_dtype_to_float32(img2, dtype_max_val) 289 290 c1 = (self.k1 * max_val) ** 2 291 c2 = (self.k2 * max_val) ** 2 292 293 ssim_ave_channel, _ = _compute_multi_channel_loss(c1, c2, img1, img2, self.conv, self.concat, self.reduce_mean) 294 loss = self.reduce_mean(ssim_ave_channel, -1) 295 296 return loss 297 298 299def _downsample(img1, img2, op): 300 a = op(img1) 301 b = op(img2) 302 return a, b 303 304 305class MSSSIM(Cell): 306 r""" 307 Returns MS-SSIM index between two images. 308 309 Its implementation is based on `Multiscale structural similarity 310 for image quality assessment <https://ieeexplore.ieee.org/document/1292216>`_ 311 by Zhou Wang, Eero P. Simoncelli, and Alan C. Bovik, published on Signals, Systems and Computers in 2004. 312 313 .. math:: 314 315 l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\ 316 c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\ 317 s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\ 318 MSSSIM(x,y)&=l^\alpha_M*{\prod_{1\leq j\leq M} (c^\beta_j*s^\gamma_j)}. 319 320 Args: 321 max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images). 322 Default: ``1.0`` . 323 power_factors (Union[tuple, list]): Iterable of weights for each scale. 324 Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). Default values obtained by Wang et al. 325 filter_size (int): The size of the Gaussian filter. Default: ``11`` . 326 filter_sigma (float): The standard deviation of Gaussian kernel. Default: ``1.5`` . 327 k1 (float): The constant used to generate c1 in the luminance comparison function. Default: ``0.01`` . 328 k2 (float): The constant used to generate c2 in the contrast comparison function. Default: ``0.03`` . 329 330 Inputs: 331 - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2. 332 - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1. 333 334 Outputs: 335 Tensor, the value is in range [0, 1]. It is a 1-D tensor with shape N, where N is the batch num of img1. 336 337 Raises: 338 TypeError: If `max_val` is neither int nor float. 339 TypeError: If `power_factors` is neither tuple nor list. 340 TypeError: If `k1`, `k2` or `filter_sigma` is not a float. 341 TypeError: If `filter_size` is not an int. 342 ValueError: If `max_val` or `filter_sigma` is less than or equal to 0. 343 ValueError: If `filter_size` is less than 0. 344 ValueError: If length of shape of `img1` or `img2` is not equal to 4. 345 346 Supported Platforms: 347 ``Ascend`` ``GPU`` 348 349 Examples: 350 >>> import numpy as np 351 >>> import mindspore as ms 352 >>> net = ms.nn.MSSSIM(power_factors=(0.033, 0.033, 0.033)) 353 >>> img1 = ms.Tensor(np.ones((1, 3, 128, 128)).astype(np.float32)) 354 >>> img2 = ms.Tensor(np.ones((1, 3, 128, 128)).astype(np.float32)) 355 >>> output = net(img1, img2) 356 >>> print(output) 357 [1.] 358 """ 359 def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11, 360 filter_sigma=1.5, k1=0.01, k2=0.03): 361 super(MSSSIM, self).__init__() 362 validator.check_value_type('max_val', max_val, [int, float], self.cls_name) 363 validator.check_number('max_val', max_val, 0.0, validator.GT, self.cls_name) 364 self.max_val = max_val 365 validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name) 366 self.filter_size = validator.check_int(filter_size, 1, validator.GE, 'filter_size', self.cls_name) 367 self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name) 368 self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) 369 self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) 370 window = _create_window(filter_size, filter_sigma) 371 self.level = len(power_factors) 372 self.conv = [] 373 for i in range(self.level): 374 self.conv.append(_conv2d(1, 1, filter_size, Tensor(window))) 375 self.conv[i].weight.requires_grad = False 376 self.multi_convs_list = CellList(self.conv) 377 self.weight_tensor = Tensor(power_factors, mstype.float32) 378 self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid') 379 self.relu = ReLU() 380 self.reduce_mean = P.ReduceMean() 381 self.prod = P.ReduceProd() 382 self.pow = P.Pow() 383 self.stack = P.Stack(axis=-1) 384 self.concat = P.Concat(axis=1) 385 386 def construct(self, img1, img2): 387 _check_input_4d(F.shape(img1), "img1", self.cls_name) 388 _check_input_4d(F.shape(img2), "img2", self.cls_name) 389 valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8] 390 _check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name) 391 inner.SameTypeShape()(img1, img2) 392 dtype_max_val = _get_dtype_max(F.dtype(img1)) 393 max_val = _convert_img_dtype_to_float32(self.max_val, dtype_max_val) 394 img1 = _convert_img_dtype_to_float32(img1, dtype_max_val) 395 img2 = _convert_img_dtype_to_float32(img2, dtype_max_val) 396 397 c1 = (self.k1 * max_val) ** 2 398 c2 = (self.k2 * max_val) ** 2 399 400 sim = () 401 mcs = () 402 403 for i in range(self.level): 404 sim, cs = _compute_multi_channel_loss(c1, c2, img1, img2, 405 self.multi_convs_list[i], self.concat, self.reduce_mean) 406 mcs += (self.relu(cs),) 407 img1, img2 = _downsample(img1, img2, self.avg_pool) 408 409 mcs = mcs[0:-1:1] 410 mcs_and_ssim = self.stack(mcs + (self.relu(sim),)) 411 mcs_and_ssim = self.pow(mcs_and_ssim, self.weight_tensor) 412 ms_ssim = self.prod(mcs_and_ssim, -1) 413 loss = self.reduce_mean(ms_ssim, -1) 414 415 return loss 416 417 418class PSNR(Cell): 419 r""" 420 Returns Peak Signal-to-Noise Ratio of two image batches. 421 422 It produces a PSNR value for each image in batch. 423 Assume inputs are :math:`I` and :math:`K`, both with shape :math:`h*w`. 424 :math:`MAX` represents the dynamic range of pixel values. 425 426 .. math:: 427 428 MSE&=\frac{1}{hw}\sum\limits_{i=0}^{h-1}\sum\limits_{j=0}^{w-1}[I(i,j)-K(i,j)]^2\\ 429 PSNR&=10*log_{10}(\frac{MAX^2}{MSE}) 430 431 Args: 432 max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images). 433 The value must be greater than 0. Default: ``1.0`` . 434 435 Inputs: 436 - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2. 437 - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1. 438 439 Outputs: 440 Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1. 441 442 Raises: 443 TypeError: If `max_val` is neither int nor float. 444 ValueError: If `max_val` is less than or equal to 0. 445 ValueError: If length of shape of `img1` or `img2` is not equal to 4. 446 447 Supported Platforms: 448 ``Ascend`` ``GPU`` ``CPU`` 449 450 Examples: 451 >>> import mindspore as ms 452 >>> net = ms.nn.PSNR() 453 >>> img1 = ms.Tensor([[[[1, 2, 3, 4], [1, 2, 3, 4]]]]) 454 >>> img2 = ms.Tensor([[[[3, 4, 5, 6], [3, 4, 5, 6]]]]) 455 >>> output = net(img1, img2) 456 >>> print(output) 457 [-6.0206] 458 """ 459 def __init__(self, max_val=1.0): 460 super(PSNR, self).__init__() 461 validator.check_value_type('max_val', max_val, [int, float], self.cls_name) 462 validator.check_number('max_val', max_val, 0.0, validator.GT, self.cls_name) 463 self.max_val = max_val 464 465 def construct(self, img1, img2): 466 _check_input_4d(F.shape(img1), "img1", self.cls_name) 467 _check_input_4d(F.shape(img2), "img2", self.cls_name) 468 inner.SameTypeShape()(img1, img2) 469 dtype_max_val = _get_dtype_max(F.dtype(img1)) 470 max_val = _convert_img_dtype_to_float32(self.max_val, dtype_max_val) 471 img1 = _convert_img_dtype_to_float32(img1, dtype_max_val) 472 img2 = _convert_img_dtype_to_float32(img2, dtype_max_val) 473 474 mse = P.ReduceMean()(F.square(img1 - img2), (-3, -2, -1)) 475 psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0) 476 477 return psnr 478 479 480@_primexpr 481def _check_rank(rank, input_shape, param_name, func_name): 482 """raise error if input is not 3d or 4d""" 483 def _check(): 484 if rank not in (3, 4): 485 raise ValueError(f"{func_name} {param_name} must be 3d or 4d, but got shape {input_shape}") 486 _check() 487 488 489@_primexpr 490def _get_bbox(rank, shape, central_fraction): 491 """get bbox start and size for slice""" 492 n, c, h, w = -1, -1, -1, -1 493 if rank == 3: 494 c, h, w = shape 495 else: 496 n, c, h, w = shape 497 498 bbox_h_start = int((float(h) - float(h * central_fraction)) / 2) 499 bbox_w_start = int((float(w) - float(w * central_fraction)) / 2) 500 bbox_h_size = h - bbox_h_start * 2 501 bbox_w_size = w - bbox_w_start * 2 502 503 if rank == 3: 504 bbox_begin = (0, bbox_h_start, bbox_w_start) 505 bbox_size = (c, bbox_h_size, bbox_w_size) 506 else: 507 bbox_begin = (0, 0, bbox_h_start, bbox_w_start) 508 bbox_size = (n, c, bbox_h_size, bbox_w_size) 509 510 return bbox_begin, bbox_size 511 512 513class CentralCrop(Cell): 514 """ 515 Crops the central region of the images with the central_fraction. 516 517 Args: 518 central_fraction (float): Fraction of size to crop. It must be float and in range (0.0, 1.0]. 519 520 Inputs: 521 - **image** (Tensor) - A 3-D tensor of shape [C, H, W], or a 4-D tensor of shape [N, C, H, W]. 522 523 Outputs: 524 Tensor, 3-D or 4-D float tensor, according to the input. 525 526 Raises: 527 TypeError: If `central_fraction` is not a float. 528 ValueError: If `central_fraction` is not in range (0.0, 1.0]. 529 530 Supported Platforms: 531 ``Ascend`` ``GPU`` ``CPU`` 532 533 Examples: 534 >>> import mindspore as ms 535 >>> import numpy as np 536 >>> net = ms.nn.CentralCrop(central_fraction=0.5) 537 >>> image = ms.Tensor(np.random.random((4, 3, 4, 4)), ms.float32) 538 >>> output = net(image) 539 >>> print(output.shape) 540 (4, 3, 2, 2) 541 """ 542 543 def __init__(self, central_fraction): 544 super(CentralCrop, self).__init__() 545 validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) 546 validator.check_float_range(central_fraction, 0.0, 1.0, validator.INC_RIGHT, 'central_fraction', self.cls_name) 547 self.central_fraction = central_fraction 548 self.slice = P.Slice() 549 550 def construct(self, image): 551 image_shape = F.shape(image) 552 rank = len(image_shape) 553 _check_rank(rank, image_shape, "image", self.cls_name) 554 if self.central_fraction == 1.0: 555 return image 556 557 bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction) 558 image = self.slice(image, bbox_begin, bbox_size) 559 560 return image 561 562 563class PixelShuffle(Cell): 564 r""" 565 Applies the PixelShuffle operation over input which implements sub-pixel convolutions 566 with stride :math:`1/r` . For more details, refer to 567 `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network 568 <https://arxiv.org/abs/1609.05158>`_ . 569 570 Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape 571 :math:`(*, C, H \times r, W \times r)`, 572 where :math:`r` is an upscale factor and :math:`*` is zero or more batch dimensions. 573 574 Note: 575 The dimension of input Tensor on Ascend should be less than 7. 576 577 Args: 578 upscale_factor (int): factor to shuffle the input, and is a positive integer. 579 `upscale_factor` is the above-mentioned :math:`r`. 580 581 Inputs: 582 - **input** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` . The dimension of `x` is larger than 2, 583 and the length of third to last dimension can be divisible by `upscale_factor` squared. 584 585 Outputs: 586 - **output** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` . 587 588 Raises: 589 ValueError: If `upscale_factor` is not a positive integer. 590 ValueError: If the length of third to last dimension of `input` is not divisible by `upscale_factor` squared. 591 ValueError: If the dimension of `input` is less than 3. 592 TypeError: If `input` is not a Tensor. 593 594 Supported Platforms: 595 ``Ascend`` ``GPU`` ``CPU`` 596 597 Examples: 598 >>> import mindspore as ms 599 >>> import numpy as np 600 >>> input_x = np.arange(3 * 2 * 8 * 4 * 4).reshape((3, 2, 8, 4, 4)) 601 >>> input_x = ms.Tensor(input_x, ms.dtype.int32) 602 >>> pixel_shuffle = ms.nn.PixelShuffle(2) 603 >>> output = pixel_shuffle(input_x) 604 >>> print(output.shape) 605 (3, 2, 2, 8, 8) 606 """ 607 def __init__(self, upscale_factor): 608 super(PixelShuffle, self).__init__() 609 self.upscale_factor = upscale_factor 610 611 def construct(self, input): 612 return ops.pixel_shuffle(input, self.upscale_factor) 613 614 615class PixelUnshuffle(Cell): 616 r""" 617 Applies the PixelUnshuffle operation over input which is the inverse of PixelShuffle. For more details, refer 618 to `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network 619 <https://arxiv.org/abs/1609.05158>`_ . 620 621 Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape 622 :math:`(*, C \times r^2, H, W)` , 623 where :math:`r` is a downscale factor and :math:`*` is zero or more batch dimensions. 624 625 Args: 626 downscale_factor (int): factor to unshuffle the input, and is a positive integer. 627 `downscale_factor` is the above-mentioned :math:`r`. 628 629 Inputs: 630 - **input** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` . The dimension of `input` is 631 larger than 2, and the length of second to last dimension or last dimension can be divisible by 632 `downscale_factor` . 633 634 Outputs: 635 - **output** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` . 636 637 Raises: 638 ValueError: If `downscale_factor` is not a positive integer. 639 ValueError: If the length of second to last dimension or last dimension is not divisible by `downscale_factor` . 640 ValueError: If the dimension of `input` is less than 3. 641 TypeError: If `input` is not a Tensor. 642 643 Supported Platforms: 644 ``Ascend`` ``GPU`` ``CPU`` 645 646 Examples: 647 >>> import mindspore as ms 648 >>> import numpy as np 649 >>> pixel_unshuffle = ms.nn.PixelUnshuffle(2) 650 >>> input_x = np.arange(8 * 8).reshape((1, 1, 8, 8)) 651 >>> input_x = ms.Tensor(input_x, ms.dtype.int32) 652 >>> output = pixel_unshuffle(input_x) 653 >>> print(output.shape) 654 (1, 4, 4, 4) 655 """ 656 def __init__(self, downscale_factor): 657 super(PixelUnshuffle, self).__init__() 658 self.downscale_factor = downscale_factor 659 660 def construct(self, input): 661 return ops.pixel_unshuffle(input, self.downscale_factor) 662