• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""image"""
16import numbers
17import numpy as np
18import mindspore.common.dtype as mstype
19from mindspore.common.tensor import Tensor
20from mindspore.ops import operations as P
21from mindspore.ops import functional as F
22from mindspore.ops.primitive import constexpr
23from mindspore._checkparam import Rel, Validator as validator
24from .conv import Conv2d
25from .container import CellList
26from .pooling import AvgPool2d
27from .activation import ReLU
28from ..cell import Cell
29
30__all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop']
31
32
class ImageGradients(Cell):
    r"""
    Computes image gradients. Returns two tensors: the first is the gradient along the height dimension and the
    second is the gradient along the width dimension.

    Assume an image shape is :math:`h*w`. The gradients along the height and the width are :math:`dy` and :math:`dx`,
    respectively.

    .. math::
        dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr
        0, &if\ i==h-1\end{cases}

        dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr
        0, &if\ i==w-1\end{cases}

    Inputs:
        - **images** (Tensor) - The input image data, with format 'NCHW'.

    Outputs:
        - **dy** (Tensor) - vertical image gradients, the same type and shape as input.
        - **dx** (Tensor) - horizontal image gradients, the same type and shape as input.

    Raises:
        ValueError: If length of shape of `images` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.ImageGradients()
        >>> image = Tensor(np.array([[[[1, 2], [3, 4]]]]), dtype=mindspore.int32)
        >>> output = net(image)
        >>> print(output)
        (Tensor(shape=[1, 1, 2, 2], dtype=Int32, value=
        [[[[2, 2],
           [0, 0]]]]), Tensor(shape=[1, 1, 2, 2], dtype=Int32, value=
        [[[[1, 0],
           [1, 0]]]]))
    """
    def __init__(self):
        super(ImageGradients, self).__init__()

    def construct(self, images):
        # Rank check (raises ValueError unless 4-D). `F.depend` ties its result to `images`,
        # presumably so the check is not dropped from the graph as unused.
        check = _check_input_4d(F.shape(images), "images", self.cls_name)
        images = F.depend(images, check)
        batch_size, depth, height, width = P.Shape()(images)
        # Vertical gradient: forward difference along H, with a zero row appended at the bottom
        # so that dy keeps the input's shape. A single-row image has an all-zero gradient.
        if height == 1:
            dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
        else:
            dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
            dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
            dy = P.Concat(2)((dy, dy_last))

        # Horizontal gradient: forward difference along W, with a zero column appended on the right.
        if width == 1:
            dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
        else:
            dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
            dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
            dx = P.Concat(3)((dx, dx_last))
        return dy, dx
92
93
def _convert_img_dtype_to_float32(img, max_val):
    """Cast `img` to float32, dividing it by `max_val` when `max_val` > 1.

    Args:
        img: Tensor or scalar to convert. Callers in this file pass both image tensors and the
            metric's scalar `max_val`, so both are rescaled identically.
        max_val: Maximum of the source dtype's value range. Callers in this file pass the
            result of `_get_dtype_max` (1.0 for non-integer dtypes, the integer dtype's max otherwise).

    Returns:
        `img` as float32, multiplied by ``1 / max_val`` if ``max_val > 1``, otherwise only cast.
    """
    # max_val is usually 1.0 (float inputs) or an integer dtype's maximum (e.g. 255 for uint8).
    # Only in the latter case is the value rescaled; otherwise it is just cast.
    ret = F.cast(img, mstype.float32)
    max_val = F.scalar_cast(max_val, mstype.float32)
    if max_val > 1.:
        scale = 1. / max_val
        ret = ret * scale
    return ret
104
105
@constexpr
def _get_dtype_max(dtype):
    """Return the value-range maximum used to normalize images of the given MindSpore dtype.

    Integer dtypes map to the largest value they can represent, returned as ``np.float64``.
    Every other dtype maps to ``1.0``, so `_convert_img_dtype_to_float32` applies no rescaling.
    """
    numpy_type = mstype.dtype_to_nptype(dtype)
    if not issubclass(numpy_type, numbers.Integral):
        return 1.0
    return np.float64(np.iinfo(numpy_type).max)
115
116
@constexpr
def _check_input_4d(input_shape, param_name, func_name):
    """Return True if `input_shape` has exactly four entries; otherwise raise ValueError.

    Args:
        input_shape (tuple): Shape to check.
        param_name (str): Argument name, used in the error message.
        func_name (str): Name of the calling cell, used in the error message.
    """
    ndim = len(input_shape)
    if ndim == 4:
        return True
    raise ValueError(f"For '{func_name}', the dimension of '{param_name}' should be 4d, "
                     f"but got {ndim}.")
123
124
@constexpr
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
    """Check that `input_shape` is 4-D and that entries 2 and 3 (H and W for NCHW) are >= `filter_size`.

    The rank check is delegated to `_check_input_4d`. Each spatial comparison is delegated to
    ``validator.check`` with ``Rel.GE``, in order: axis 2, then axis 3.
    """
    _check_input_4d(input_shape, param_name, func_name)
    for axis in (2, 3):
        validator.check(param_name + f" shape[{axis}]", input_shape[axis],
                        "filter_size", filter_size, Rel.GE, func_name)
130
131
@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    """Validate that `input_dtype` is one of `allow_dtypes`.

    Delegates to ``validator.check_type_name``. Its return value is discarded, so this
    helper always returns None.
    """
    validate = validator.check_type_name
    validate(param_name, input_dtype, allow_dtypes, cls_name)
135
136
def _conv2d(in_channels, out_channels, kernel_size, weight, stride=1, padding=0):
    """Create a `Conv2d` cell whose weights are initialized from `weight`, using pad_mode "valid"."""
    conv_kwargs = {
        "kernel_size": kernel_size,
        "stride": stride,
        "weight_init": weight,
        "padding": padding,
        "pad_mode": "valid",
    }
    return Conv2d(in_channels, out_channels, **conv_kwargs)
140
141
def _create_window(size, sigma):
    """Build a normalized 2-D Gaussian window, used as the fixed weight of the SSIM convolution.

    Both axes use the integer offsets ``-size // 2 + 1`` through ``size // 2`` (inclusive), which
    gives ``size`` taps per axis for ``size >= 1``. For odd `size` the peak is centred. For even
    `size` it sits half a pixel before the geometric centre (e.g. offsets -1..2 for size 4).

    Args:
        size (int): Number of taps along each spatial axis.
        sigma (float): Standard deviation of the Gaussian.

    Returns:
        numpy.ndarray: float32 array of shape ``(1, 1, size, size)`` whose entries sum to 1.
    """
    first, stop = -size // 2 + 1, size // 2 + 1
    rows, cols = np.mgrid[first:stop, first:stop]
    squared_radius = rows.astype(np.float32) ** 2 + cols.astype(np.float32) ** 2
    kernel = np.exp(-squared_radius / (2 * sigma ** 2))
    kernel = kernel / np.sum(kernel)
    # Prepend two singleton axes (presumably the out/in channel axes of the 1->1 conv weight).
    return kernel[np.newaxis, np.newaxis, :, :]
151
152
def _split_img(x):
    """Split a 4-D tensor into single-channel pieces along axis 1.

    The 4-way unpack of the shape requires `x` to be 4-D; axis 1 is the channel axis for NCHW.

    Returns:
        tuple: ``(pieces, channels)``, where `pieces` is the output of ``P.Split(1, channels)``
        (indexed per channel by callers) and `channels` is ``x.shape[1]``.
    """
    _, channels, _, _ = F.shape(x)
    return P.Split(1, channels)(x), channels
158
159
def _compute_per_channel_loss(c1, c2, img1, img2, conv):
    """Compute the SSIM map and the contrast-structure (cs) map for one channel.

    Local statistics are obtained by applying `conv` (the averaging window):
    means ``mu = conv(x)``, variances ``conv(x*x) - mu*mu`` and covariance
    ``conv(img1*img2) - mu1*mu2``.

    Args:
        c1: Stabilizing constant of the luminance term.
        c2: Stabilizing constant of the contrast-structure term.
        img1: First single-channel image.
        img2: Second single-channel image.
        conv: Callable applying the local window.

    Returns:
        tuple: ``(ssim, cs)`` where
        ``ssim = (2*mu1*mu2 + c1)(2*sigma12 + c2) / ((mu1^2 + mu2^2 + c1)(sigma1^2 + sigma2^2 + c2))``
        and ``cs = (2*sigma12 + c2) / (sigma1^2 + sigma2^2 + c2)``.
    """
    # Local means and their products.
    mean_x = conv(img1)
    mean_y = conv(img2)
    mean_x_sq = mean_x * mean_x
    mean_y_sq = mean_y * mean_y
    mean_xy = mean_x * mean_y
    # Local variances and covariance via E[XY] - E[X]E[Y].
    var_x = conv(img1 * img1) - mean_x_sq
    var_y = conv(img2 * img2) - mean_y_sq
    cov_xy = conv(img1 * img2) - mean_xy

    luminance_num = 2 * mean_xy + c1
    luminance_den = mean_x_sq + mean_y_sq + c1
    cs_num = 2 * cov_xy + c2
    cs_den = var_x + var_y + c2

    ssim = (luminance_num * cs_num) / (luminance_den * cs_den)
    cs = cs_num / cs_den
    return ssim, cs
181
182
def _compute_multi_channel_loss(c1, c2, img1, img2, conv, concat, mean):
    """Compute spatially averaged SSIM and contrast-structure (cs) values per channel.

    Each channel of `img1`/`img2` is processed independently by `_compute_per_channel_loss`
    with the single-channel `conv`. The per-channel maps are re-joined with `concat` and
    averaged over axes 2 and 3 (H and W for NCHW input).

    Args:
        c1: Stabilizing constant of the luminance term.
        c2: Stabilizing constant of the contrast-structure term.
        img1: First 4-D image tensor (split along axis 1, the channel axis for NCHW).
        img2: Second 4-D image tensor with the same shape as `img1`.
        conv: Single-channel convolution applying the averaging window.
        concat: Concat primitive; callers in this file pass ``P.Concat(axis=1)``.
        mean: Reduce-mean primitive, called as ``mean(x, axes)``.

    Returns:
        tuple: ``(ssim, cs)``, each presumably of shape (N, C) after the spatial mean — confirm
        against ReduceMean's keep_dims default.
    """
    split_img1, c = _split_img(img1)
    split_img2, _ = _split_img(img2)
    multi_ssim = ()
    multi_cs = ()
    # Accumulate per-channel results in tuples so they can be concatenated back along axis 1.
    for i in range(c):
        ssim_per_channel, cs_per_channel = _compute_per_channel_loss(c1, c2, split_img1[i], split_img2[i], conv)
        multi_ssim += (ssim_per_channel,)
        multi_cs += (cs_per_channel,)

    multi_ssim = concat(multi_ssim)
    multi_cs = concat(multi_cs)

    # Average each channel's map over the two spatial axes.
    ssim = mean(multi_ssim, (2, 3))
    cs = mean(multi_cs, (2, 3))
    return ssim, cs
200
201
class SSIM(Cell):
    r"""
    Returns SSIM index between two images.

    Its implementation is based on Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). `Image quality
    assessment: from error visibility to structural similarity <https://ieeexplore.ieee.org/document/1284395>`_.
    IEEE transactions on image processing.

    SSIM is a measure of the similarity of two pictures.
    Like PSNR, SSIM is often used as an evaluation of image quality. SSIM is usually a number between 0 and 1 (this
    implementation does not clamp it, so negative values are possible). The larger it is, the smaller the gap
    between the output image and the undistorted image, that is, the better the image quality.
    When the two images are exactly the same, SSIM=1.

    .. math::

        l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
        c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
        s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
        SSIM(x,y)&=l*c*s\\&=\frac{(2\mu_x\mu_y+C_1)(2\sigma_{xy}+C_2)}{(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)}.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
          Default: 1.0.
        filter_size (int): The size of the Gaussian filter. Default: 11. The value must be greater than or equal to 1.
        filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
          The value must be greater than 0.
        k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
        k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW' and dtype float16 or float32. It must be the
          same shape and dtype as img2, and its height and width must be at least `filter_size`.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor of dtype float32 (the inputs are cast to float32 internally). It is a 1-D tensor with shape N, where N
        is the batch num of img1.

    Raises:
        TypeError: If `max_val` is neither int nor float.
        TypeError: If `k1`, `k2` or `filter_sigma` is not a float.
        TypeError: If `filter_size` is not an int.
        ValueError: If `max_val` or `filter_sigma` is less than or equal to 0.
        ValueError: If `filter_size` is less than 1.
        ValueError: If length of shape of `img1` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> net = nn.SSIM()
        >>> img1 = Tensor(np.ones([1, 3, 16, 16]).astype(np.float32))
        >>> img2 = Tensor(np.ones([1, 3, 16, 16]).astype(np.float32))
        >>> output = net(img1, img2)
        >>> print(output)
        [1.]
    """
    def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
        super(SSIM, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val
        self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
        self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
        self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
        # A single 1->1 channel Gaussian convolution, reused for every channel; its weights are frozen.
        window = _create_window(filter_size, filter_sigma)
        self.conv = _conv2d(1, 1, filter_size, Tensor(window))
        self.conv.weight.requires_grad = False
        self.reduce_mean = P.ReduceMean()
        self.concat = P.Concat(axis=1)

    def construct(self, img1, img2):
        # Only float16/float32 inputs are accepted, and H and W must be at least filter_size.
        _check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
        P.SameTypeShape()(img1, img2)
        # For float dtypes _get_dtype_max returns 1.0, so the conversions below only cast to float32.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        # C1 = (K1 * L)^2 and C2 = (K2 * L)^2, where L is the dynamic range.
        c1 = (self.k1 * max_val) ** 2
        c2 = (self.k2 * max_val) ** 2

        # Per-channel SSIM (averaged over H and W), then averaged over channels: one value per image.
        ssim_ave_channel, _ = _compute_multi_channel_loss(c1, c2, img1, img2, self.conv, self.concat, self.reduce_mean)
        loss = self.reduce_mean(ssim_ave_channel, -1)

        return loss
291
292
def _downsample(img1, img2, op):
    """Apply the same pooling `op` to both images and return ``(op(img1), op(img2))``, in that call order."""
    return op(img1), op(img2)
297
298
class MSSSIM(Cell):
    r"""
    Returns MS-SSIM index between two images.

    Its implementation is based on Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. `Multiscale structural similarity
    for image quality assessment <https://ieeexplore.ieee.org/document/1292216>`_.
    Signals, Systems and Computers, 2004.

    .. math::

        l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
        c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
        s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
        MSSSIM(x,y)&=l^\alpha_M*{\prod_{1\leq j\leq M} (c^\beta_j*s^\gamma_j)}.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
          Default: 1.0.
        power_factors (Union[tuple, list]): Iterable of weights for each scale. Its length sets the number of
          scales M; the images are downsampled by 2x2 average pooling between consecutive scales.
          Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). Default values obtained by Wang et al.
        filter_size (int): The size of the Gaussian filter. Default: 11.
        filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
        k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
        k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW' and dtype float64, float32, float16 or uint8.
          It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, the value is in range [0, 1]. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Raises:
        TypeError: If `max_val` is neither int nor float.
        TypeError: If `power_factors` is neither tuple nor list.
        TypeError: If `k1`, `k2` or `filter_sigma` is not a float.
        TypeError: If `filter_size` is not an int.
        ValueError: If `max_val` or `filter_sigma` is less than or equal to 0.
        ValueError: If `filter_size` is less than 1.
        ValueError: If length of shape of `img1` or `img2` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> net = nn.MSSSIM(power_factors=(0.033, 0.033, 0.033))
        >>> img1 = Tensor(np.ones((1, 3, 128, 128)).astype(np.float32))
        >>> img2 = Tensor(np.ones((1, 3, 128, 128)).astype(np.float32))
        >>> output = net(img1, img2)
        >>> print(output)
        [1.]
    """
    def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11,
                 filter_sigma=1.5, k1=0.01, k2=0.03):
        super(MSSSIM, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val
        validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
        self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
        self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
        self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
        window = _create_window(filter_size, filter_sigma)
        # One scale per power factor; each scale gets its own frozen 1->1 Gaussian convolution.
        self.level = len(power_factors)
        self.conv = []
        for i in range(self.level):
            self.conv.append(_conv2d(1, 1, filter_size, Tensor(window)))
            self.conv[i].weight.requires_grad = False
        self.multi_convs_list = CellList(self.conv)
        self.weight_tensor = Tensor(power_factors, mstype.float32)
        self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
        self.relu = ReLU()
        self.reduce_mean = P.ReduceMean()
        self.prod = P.ReduceProd()
        self.pow = P.Pow()
        self.stack = P.Stack(axis=-1)
        self.concat = P.Concat(axis=1)

    def construct(self, img1, img2):
        _check_input_4d(F.shape(img1), "img1", self.cls_name)
        _check_input_4d(F.shape(img2), "img2", self.cls_name)
        valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8]
        _check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name)
        P.SameTypeShape()(img1, img2)
        # uint8 images are divided by 255 (their dtype max); float images are only cast to float32.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        c1 = (self.k1 * max_val) ** 2
        c2 = (self.k2 * max_val) ** 2

        # `sim` is overwritten on every iteration; only the coarsest scale's SSIM is used below.
        sim = ()
        mcs = ()

        # NOTE(review): unlike SSIM, there is no explicit check that every scale is still at least
        # `filter_size` in H and W. Each scale halves the spatial size, so inputs that are too small
        # presumably fail inside the 'valid' convolution instead. The last iteration also
        # downsamples images that are never used afterwards.
        for i in range(self.level):
            sim, cs = _compute_multi_channel_loss(c1, c2, img1, img2,
                                                  self.multi_convs_list[i], self.concat, self.reduce_mean)
            # Clamp negative cs values to 0, presumably so the fractional `pow` below stays defined.
            mcs += (self.relu(cs),)
            img1, img2 = _downsample(img1, img2, self.avg_pool)

        # Keep cs for scales 1..M-1; the coarsest scale contributes its full SSIM instead.
        mcs = mcs[0:-1:1]
        mcs_and_ssim = self.stack(mcs + (self.relu(sim),))
        # Raise each scale to its power factor, then multiply across scales (last axis).
        mcs_and_ssim = self.pow(mcs_and_ssim, self.weight_tensor)
        ms_ssim = self.prod(mcs_and_ssim, -1)
        # Average over channels, giving one value per image.
        loss = self.reduce_mean(ms_ssim, -1)

        return loss
412
413
class PSNR(Cell):
    r"""
    Returns Peak Signal-to-Noise Ratio of two image batches.

    It produces a PSNR value for each image in batch.
    Assume inputs are :math:`I` and :math:`K`, both with shape :math:`h*w`.
    :math:`MAX` represents the dynamic range of pixel values.

    .. math::

        MSE&=\frac{1}{hw}\sum\limits_{i=0}^{h-1}\sum\limits_{j=0}^{w-1}[I(i,j)-K(i,j)]^2\\
        PSNR&=10*log_{10}(\frac{MAX^2}{MSE})

    Integer-typed images, and `max_val` with them, are divided by the integer dtype's maximum value before the
    MSE is computed. Because both are scaled by the same factor, the ratio is unchanged.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
          The value must be greater than 0. Default: 1.0.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Raises:
        TypeError: If `max_val` is neither int nor float.
        ValueError: If `max_val` is less than or equal to 0.
        ValueError: If length of shape of `img1` or `img2` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.PSNR()
        >>> img1 = Tensor([[[[1, 2, 3, 4], [1, 2, 3, 4]]]])
        >>> img2 = Tensor([[[[3, 4, 5, 6], [3, 4, 5, 6]]]])
        >>> output = net(img1, img2)
        >>> print(output)
        [-6.0206]
    """
    def __init__(self, max_val=1.0):
        super(PSNR, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val

    def construct(self, img1, img2):
        _check_input_4d(F.shape(img1), "img1", self.cls_name)
        _check_input_4d(F.shape(img2), "img2", self.cls_name)
        P.SameTypeShape()(img1, img2)
        # Rescale images and max_val consistently; float inputs are only cast to float32.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        # Per-image mean squared error over the C, H and W axes.
        mse = P.ReduceMean()(F.square(img1 - img2), (-3, -2, -1))
        # log10(x) computed as ln(x) / ln(10).
        # NOTE(review): identical images give mse == 0 and this division is unguarded.
        psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0)

        return psnr
474
475
@constexpr
def _raise_dims_rank_error(input_shape, param_name, func_name):
    """Unconditionally raise ValueError reporting that `param_name` is neither 3-D nor 4-D."""
    message = f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}"
    raise ValueError(message)
480
481
@constexpr
def _get_bbox(rank, shape, central_fraction):
    """Compute the Slice `begin` and `size` tuples that crop the centre of the last two axes.

    Args:
        rank (int): 3 for ``(C, H, W)`` shapes; any other value is treated as ``(N, C, H, W)``.
        shape (tuple): Input shape, unpacked to exactly 3 or 4 entries according to `rank`.
        central_fraction (float): Fraction of H and W to keep.

    Returns:
        tuple: ``(bbox_begin, bbox_size)``. Leading axes start at 0 and keep their full extent.
    """
    def _start_and_size(extent):
        # Equal margins on both sides; the kept size is whatever the two margins leave.
        start = int((float(extent) - np.float32(extent * central_fraction)) / 2)
        return start, extent - start * 2

    if rank == 3:
        c, h, w = shape
        leading = (c,)
    else:
        n, c, h, w = shape
        leading = (n, c)

    h_start, h_size = _start_and_size(h)
    w_start, w_size = _start_and_size(w)

    bbox_begin = (0,) * len(leading) + (h_start, w_start)
    bbox_size = leading + (h_size, w_size)
    return bbox_begin, bbox_size
503
504
class CentralCrop(Cell):
    """
    Crops the central region of the images with the central_fraction.

    Args:
        central_fraction (float): Fraction of size to crop. It must be float and in range (0.0, 1.0].
            The same fraction is applied to both the height and the width.

    Inputs:
        - **image** (Tensor) - A 3-D tensor of shape [C, H, W], or a 4-D tensor of shape [N, C, H, W].

    Outputs:
        Tensor, a 3-D or 4-D tensor with the same rank as `image`; only H and W are cropped.

    Raises:
        TypeError: If `central_fraction` is not a float.
        ValueError: If `central_fraction` is not in range (0, 1.0].
        ValueError: If `image` is neither 3-D nor 4-D.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.CentralCrop(central_fraction=0.5)
        >>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
        >>> output = net(image)
        >>> print(output.shape)
        (4, 3, 2, 2)
    """

    def __init__(self, central_fraction):
        super(CentralCrop, self).__init__()
        validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
        validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', self.cls_name)
        self.central_fraction = central_fraction
        self.slice = P.Slice()

    def construct(self, image):
        image_shape = F.shape(image)
        rank = len(image_shape)
        # Only 3-D (CHW) and 4-D (NCHW) inputs are supported; the helper raises ValueError.
        if not rank in (3, 4):
            return _raise_dims_rank_error(image_shape, "image", self.cls_name)
        # A fraction of 1.0 keeps the whole image, so skip the slice.
        if self.central_fraction == 1.0:
            return image

        # Crop H and W symmetrically; leading axes are kept in full.
        bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction)
        image = self.slice(image, bbox_begin, bbox_size)

        return image
552