# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""image"""
import numbers
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore._checkparam import Rel, Validator as validator
from .conv import Conv2d
from .container import CellList
from .pooling import AvgPool2d
from .activation import ReLU
from ..cell import Cell

__all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop']


class ImageGradients(Cell):
    r"""
    Returns two tensors, the first is along the height dimension and the second is along the width dimension.

    Assume an image shape is :math:`h*w`. The gradients along the height and the width are :math:`dy` and :math:`dx`,
    respectively.

    .. math::
        dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr
        0, &if\ i==h-1\end{cases}

        dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr
        0, &if\ i==w-1\end{cases}

    Inputs:
        - **images** (Tensor) - The input image data, with format 'NCHW'.

    Outputs:
        - **dy** (Tensor) - vertical image gradients, the same type and shape as input.
        - **dx** (Tensor) - horizontal image gradients, the same type and shape as input.

    Raises:
        ValueError: If length of shape of `images` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.ImageGradients()
        >>> image = Tensor(np.array([[[[1, 2], [3, 4]]]]), dtype=mindspore.int32)
        >>> output = net(image)
        >>> print(output)
        (Tensor(shape=[1, 1, 2, 2], dtype=Int32, value=
        [[[[2, 2],
           [0, 0]]]]), Tensor(shape=[1, 1, 2, 2], dtype=Int32, value=
        [[[[1, 0],
           [1, 0]]]]))
    """
    def __init__(self):
        super(ImageGradients, self).__init__()

    def construct(self, images):
        # Tie the shape check to `images` so the check is not dropped from the graph.
        check = _check_input_4d(F.shape(images), "images", self.cls_name)
        images = F.depend(images, check)
        batch_size, depth, height, width = P.Shape()(images)
        if height == 1:
            # A single row has no vertical neighbour: the gradient is all zeros.
            dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
        else:
            # Forward difference along H, then pad a zero row so dy keeps the input shape.
            dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
            dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
            dy = P.Concat(2)((dy, dy_last))

        if width == 1:
            # A single column has no horizontal neighbour: the gradient is all zeros.
            dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
        else:
            # Forward difference along W, then pad a zero column so dx keeps the input shape.
            dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
            dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
            dx = P.Concat(3)((dx, dx_last))
        return dy, dx


def _convert_img_dtype_to_float32(img, max_val):
    """
    Cast `img` to float32 and, when `max_val` is greater than 1, scale it by `1 / max_val`.

    Usually `max_val` is 1.0 or 255; when it is not greater than 1 the input is only cast.
    """
    ret = F.cast(img, mstype.float32)
    max_val = F.scalar_cast(max_val, mstype.float32)
    if max_val > 1.:
        scale = 1. / max_val
        ret = ret * scale
    return ret


@constexpr
def _get_dtype_max(dtype):
    """Return the largest value of an integral `dtype` as np.float64, or 1.0 for any other dtype."""
    np_type = mstype.dtype_to_nptype(dtype)
    if issubclass(np_type, numbers.Integral):
        dtype_max = np.float64(np.iinfo(np_type).max)
    else:
        dtype_max = 1.0
    return dtype_max


@constexpr
def _check_input_4d(input_shape, param_name, func_name):
    """Raise ValueError unless `input_shape` has exactly 4 dimensions; return True otherwise."""
    if len(input_shape) != 4:
        raise ValueError(f"For '{func_name}', the dimension of '{param_name}' should be 4d, "
                         f"but got {len(input_shape)}.")
    return True


@constexpr
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
    """Check that `input_shape` is 4d and that shape[2] and shape[3] are both at least `filter_size`."""
    _check_input_4d(input_shape, param_name, func_name)
    validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name)
    validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name)


@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    """Check that `input_dtype` is one of `allow_dtypes`."""
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)


def _conv2d(in_channels, out_channels, kernel_size, weight, stride=1, padding=0):
    """Build a Conv2d cell with pad_mode "valid" whose weight is initialized from `weight`."""
    return Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                  weight_init=weight, padding=padding, pad_mode="valid")


def _create_window(size, sigma):
    """
    Build a normalized 2-D Gaussian window.

    Returns a float32 numpy array of shape (1, 1, size, size) whose elements sum to 1.
    """
    # Integer coordinate grids of `size` points per axis (e.g. -5..5 for size 11).
    x_data, y_data = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1]
    # Append two trailing axes and square the coordinates: shape (size, size, 1, 1).
    x_data = np.expand_dims(x_data, axis=-1).astype(np.float32)
    x_data = np.expand_dims(x_data, axis=-1) ** 2
    y_data = np.expand_dims(y_data, axis=-1).astype(np.float32)
    y_data = np.expand_dims(y_data, axis=-1) ** 2
    # Denominator of the Gaussian exponent: 2 * sigma^2.
    sigma = 2 * sigma ** 2
    g = np.exp(-(x_data + y_data) / sigma)
    # Normalize, then move the two singleton axes to the front: (1, 1, size, size).
    return np.transpose(g / np.sum(g), (2, 3, 0, 1))


def _split_img(x):
    """Split `x` along axis 1 into one piece per channel; return the pieces and the channel count."""
    _, c, _, _ = F.shape(x)
    img_split = P.Split(1, c)
    output = img_split(x)
    return output, c


def _compute_per_channel_loss(c1, c2, img1, img2, conv):
    """
    Computes the ssim index between img1 and img2 for a single channel.

    `conv` is applied as the local averaging filter. Returns `(ssim, cs)`, where `cs` is the
    contrast-structure term `v1 / v2` and `ssim` additionally includes the luminance term.
    """
    dot_img = img1 * img2
    mu1 = conv(img1)
    mu2 = conv(img2)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    # Local (co)variances computed as E[xy] - E[x]E[y].
    sigma1_tmp = conv(img1 * img1)
    sigma1_sq = sigma1_tmp - mu1_sq
    sigma2_tmp = conv(img2 * img2)
    sigma2_sq = sigma2_tmp - mu2_sq
    sigma12_tmp = conv(dot_img)
    sigma12 = sigma12_tmp - mu1_mu2
    a = (2 * mu1_mu2 + c1)
    b = (mu1_sq + mu2_sq + c1)
    v1 = 2 * sigma12 + c2
    v2 = sigma1_sq + sigma2_sq + c2
    ssim = (a * v1) / (b * v2)
    cs = v1 / v2
    return ssim, cs


def _compute_multi_channel_loss(c1, c2, img1, img2, conv, concat, mean):
    """
    Computes the ssim index between img1 and img2 per color channel.

    Each channel is processed independently with `conv`, the per-channel maps are joined with
    `concat`, and `mean` reduces them over axes (2, 3). Returns `(ssim, cs)`.
    """
    split_img1, c = _split_img(img1)
    split_img2, _ = _split_img(img2)
    multi_ssim = ()
    multi_cs = ()
    for i in range(c):
        ssim_per_channel, cs_per_channel = _compute_per_channel_loss(c1, c2, split_img1[i], split_img2[i], conv)
        multi_ssim += (ssim_per_channel,)
        multi_cs += (cs_per_channel,)

    multi_ssim = concat(multi_ssim)
    multi_cs = concat(multi_cs)

    ssim = mean(multi_ssim, (2, 3))
    cs = mean(multi_cs, (2, 3))
    return ssim, cs


class SSIM(Cell):
    r"""
    Returns SSIM index between two images.

    Its implementation is based on Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). `Image quality
    assessment: from error visibility to structural similarity <https://ieeexplore.ieee.org/document/1284395>`_.
    IEEE transactions on image processing.

    SSIM is a measure of the similarity of two pictures.
    Like PSNR, SSIM is often used as an evaluation of image quality. SSIM is a number between 0 and 1. The larger it
    is, the smaller the gap between the output image and the undistorted image, that is, the better the image quality.
    When the two images are exactly the same, SSIM=1.

    .. math::

        l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
        c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
        s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
        SSIM(x,y)&=l*c*s\\&=\frac{(2\mu_x\mu_y+C_1)(2\sigma_{xy}+C_2)}{(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)}.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
          Default: 1.0.
        filter_size (int): The size of the Gaussian filter. Default: 11. The value must be greater than or equal to 1.
        filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
          The value must be greater than 0.
        k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
        k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, has the same dtype as img1. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Raises:
        TypeError: If `max_val` is neither int nor float.
        TypeError: If `k1`, `k2` or `filter_sigma` is not a float.
        TypeError: If `filter_size` is not an int.
        ValueError: If `max_val` or `filter_sigma` is less than or equal to 0.
        ValueError: If `filter_size` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> net = nn.SSIM()
        >>> img1 = Tensor(np.ones([1, 3, 16, 16]).astype(np.float32))
        >>> img2 = Tensor(np.ones([1, 3, 16, 16]).astype(np.float32))
        >>> output = net(img1, img2)
        >>> print(output)
        [1.]
    """
    def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
        super(SSIM, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val
        self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
        self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
        self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
        # The Gaussian window is a fixed filter, so its conv weight is excluded from training.
        window = _create_window(filter_size, filter_sigma)
        self.conv = _conv2d(1, 1, filter_size, Tensor(window))
        self.conv.weight.requires_grad = False
        self.reduce_mean = P.ReduceMean()
        self.concat = P.Concat(axis=1)

    def construct(self, img1, img2):
        _check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
        P.SameTypeShape()(img1, img2)
        # Bring max_val and both images onto the same float32 scale.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        c1 = (self.k1 * max_val) ** 2
        c2 = (self.k2 * max_val) ** 2

        ssim_ave_channel, _ = _compute_multi_channel_loss(c1, c2, img1, img2, self.conv,
                                                          self.concat, self.reduce_mean)
        # Average over channels to get one value per image in the batch.
        loss = self.reduce_mean(ssim_ave_channel, -1)

        return loss


def _downsample(img1, img2, op):
    """Apply `op` to both images and return the two results."""
    a = op(img1)
    b = op(img2)
    return a, b


class MSSSIM(Cell):
    r"""
    Returns MS-SSIM index between two images.

    Its implementation is based on Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. `Multiscale structural similarity
    for image quality assessment <https://ieeexplore.ieee.org/document/1292216>`_.
    Signals, Systems and Computers, 2004.

    .. math::

        l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
        c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
        s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
        MSSSIM(x,y)&=l^\alpha_M*{\prod_{1\leq j\leq M} (c^\beta_j*s^\gamma_j)}.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
          Default: 1.0.
        power_factors (Union[tuple, list]): Iterable of weights for each scale. It must not be empty.
          Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). Default values obtained by Wang et al.
        filter_size (int): The size of the Gaussian filter. Default: 11. The value must be greater than or equal to 1.
        filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
        k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
        k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, the value is in range [0, 1]. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Raises:
        TypeError: If `max_val` is neither int nor float.
        TypeError: If `power_factors` is neither tuple nor list.
        TypeError: If `k1`, `k2` or `filter_sigma` is not a float.
        TypeError: If `filter_size` is not an int.
        ValueError: If `max_val` or `filter_sigma` is less than or equal to 0.
        ValueError: If `filter_size` is less than 1.
        ValueError: If `power_factors` is empty.
        ValueError: If length of shape of `img1` or `img2` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> net = nn.MSSSIM(power_factors=(0.033, 0.033, 0.033))
        >>> img1 = Tensor(np.ones((1, 3, 128, 128)).astype(np.float32))
        >>> img2 = Tensor(np.ones((1, 3, 128, 128)).astype(np.float32))
        >>> output = net(img1, img2)
        >>> print(output)
        [1.]
    """
    def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11,
                 filter_sigma=1.5, k1=0.01, k2=0.03):
        super(MSSSIM, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val
        validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
        # With no scales, construct() would pass the initial empty tuple `sim` to relu.
        # Fail here with a clear message instead.
        if not power_factors:
            raise ValueError(f"For '{self.cls_name}', the 'power_factors' should not be empty.")
        self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
        self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
        self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
        window = _create_window(filter_size, filter_sigma)
        self.level = len(power_factors)
        # One fixed (non-trainable) Gaussian conv per scale.
        self.conv = []
        for i in range(self.level):
            self.conv.append(_conv2d(1, 1, filter_size, Tensor(window)))
            self.conv[i].weight.requires_grad = False
        self.multi_convs_list = CellList(self.conv)
        self.weight_tensor = Tensor(power_factors, mstype.float32)
        self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
        self.relu = ReLU()
        self.reduce_mean = P.ReduceMean()
        self.prod = P.ReduceProd()
        self.pow = P.Pow()
        self.stack = P.Stack(axis=-1)
        self.concat = P.Concat(axis=1)

    def construct(self, img1, img2):
        _check_input_4d(F.shape(img1), "img1", self.cls_name)
        _check_input_4d(F.shape(img2), "img2", self.cls_name)
        valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8]
        _check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name)
        P.SameTypeShape()(img1, img2)
        # Bring max_val and both images onto the same float32 scale.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        c1 = (self.k1 * max_val) ** 2
        c2 = (self.k2 * max_val) ** 2

        sim = ()
        mcs = ()

        # At every scale record the contrast-structure term, then halve the resolution.
        for i in range(self.level):
            sim, cs = _compute_multi_channel_loss(c1, c2, img1, img2,
                                                  self.multi_convs_list[i], self.concat, self.reduce_mean)
            mcs += (self.relu(cs),)
            img1, img2 = _downsample(img1, img2, self.avg_pool)

        # The coarsest scale contributes its full ssim instead of its cs term.
        mcs = mcs[0:-1:1]
        mcs_and_ssim = self.stack(mcs + (self.relu(sim),))
        mcs_and_ssim = self.pow(mcs_and_ssim, self.weight_tensor)
        ms_ssim = self.prod(mcs_and_ssim, -1)
        loss = self.reduce_mean(ms_ssim, -1)

        return loss


class PSNR(Cell):
    r"""
    Returns Peak Signal-to-Noise Ratio of two image batches.

    It produces a PSNR value for each image in batch.
    Assume inputs are :math:`I` and :math:`K`, both with shape :math:`h*w`.
    :math:`MAX` represents the dynamic range of pixel values.

    .. math::

        MSE&=\frac{1}{hw}\sum\limits_{i=0}^{h-1}\sum\limits_{j=0}^{w-1}[I(i,j)-K(i,j)]^2\\
        PSNR&=10*log_{10}(\frac{MAX^2}{MSE})

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
          The value must be greater than 0. Default: 1.0.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Raises:
        TypeError: If `max_val` is neither int nor float.
        ValueError: If `max_val` is less than or equal to 0.
        ValueError: If length of shape of `img1` or `img2` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.PSNR()
        >>> img1 = Tensor([[[[1, 2, 3, 4], [1, 2, 3, 4]]]])
        >>> img2 = Tensor([[[[3, 4, 5, 6], [3, 4, 5, 6]]]])
        >>> output = net(img1, img2)
        >>> print(output)
        [-6.0206]
    """
    def __init__(self, max_val=1.0):
        super(PSNR, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val

    def construct(self, img1, img2):
        _check_input_4d(F.shape(img1), "img1", self.cls_name)
        _check_input_4d(F.shape(img2), "img2", self.cls_name)
        P.SameTypeShape()(img1, img2)
        # Bring max_val and both images onto the same float32 scale.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        # Mean squared error over every non-batch axis.
        mse = P.ReduceMean()(F.square(img1 - img2), (-3, -2, -1))
        # Natural log divided by ln(10) gives log10.
        psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0)

        return psnr


@constexpr
def _raise_dims_rank_error(input_shape, param_name, func_name):
    """raise error if input is not 3d or 4d"""
    raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}")


@constexpr
def _get_bbox(rank, shape, central_fraction):
    """
    Get the bbox start and size for slice.

    `shape` is unpacked as (c, h, w) when `rank` is 3 and as (n, c, h, w) otherwise. Only the
    h and w axes are cropped; the same margin is removed from both sides of each axis.
    """
    if rank == 3:
        c, h, w = shape
    else:
        n, c, h, w = shape

    # Margin on each side: half of the part that is not kept, truncated to an int.
    bbox_h_start = int((float(h) - np.float32(h * central_fraction)) / 2)
    bbox_w_start = int((float(w) - np.float32(w * central_fraction)) / 2)
    bbox_h_size = h - bbox_h_start * 2
    bbox_w_size = w - bbox_w_start * 2

    if rank == 3:
        bbox_begin = (0, bbox_h_start, bbox_w_start)
        bbox_size = (c, bbox_h_size, bbox_w_size)
    else:
        bbox_begin = (0, 0, bbox_h_start, bbox_w_start)
        bbox_size = (n, c, bbox_h_size, bbox_w_size)

    return bbox_begin, bbox_size


class CentralCrop(Cell):
    """
    Crops the central region of the images with the central_fraction.

    Args:
        central_fraction (float): Fraction of size to crop. It must be float and in range (0.0, 1.0].

    Inputs:
        - **image** (Tensor) - A 3-D tensor of shape [C, H, W], or a 4-D tensor of shape [N, C, H, W].

    Outputs:
        Tensor, 3-D or 4-D float tensor, according to the input.

    Raises:
        TypeError: If `central_fraction` is not a float.
        ValueError: If `central_fraction` is not in range (0, 1.0].
        ValueError: If `image` is neither 3-D nor 4-D.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.CentralCrop(central_fraction=0.5)
        >>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
        >>> output = net(image)
        >>> print(output.shape)
        (4, 3, 2, 2)
    """

    def __init__(self, central_fraction):
        super(CentralCrop, self).__init__()
        validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
        validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', self.cls_name)
        self.central_fraction = central_fraction
        self.slice = P.Slice()

    def construct(self, image):
        image_shape = F.shape(image)
        rank = len(image_shape)
        if rank not in (3, 4):
            return _raise_dims_rank_error(image_shape, "image", self.cls_name)
        # A fraction of 1.0 keeps the whole image, so no slicing is needed.
        if self.central_fraction == 1.0:
            return image

        bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction)
        image = self.slice(image, bbox_begin, bbox_size)

        return image