• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""image"""
16from __future__ import absolute_import
17from __future__ import division
18
19import numbers
20import numpy as np
21
22import mindspore.common.dtype as mstype
23import mindspore.ops as ops
24from mindspore.common.tensor import Tensor
25from mindspore.ops import operations as P
26from mindspore.ops.operations import _inner_ops as inner
27from mindspore.ops import functional as F
28from mindspore.ops.primitive import constexpr, _primexpr
29from mindspore import _checkparam as validator
30from mindspore.nn.layer.conv import Conv2d
31from mindspore.nn.layer.container import CellList
32from mindspore.nn.layer.pooling import AvgPool2d
33from mindspore.nn.layer.activation import ReLU
34from mindspore.nn.cell import Cell
35
36__all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop', 'PixelShuffle', 'PixelUnshuffle']
37
38
class ImageGradients(Cell):
    r"""
    Computes the forward-difference gradients of an image batch along the height and the width dimension.

    Two tensors are returned: the first holds the differences along the height dimension, the second the
    differences along the width dimension. For an image of shape :math:`h*w`, these are :math:`dy` and :math:`dx`
    respectively. The last row of :math:`dy` and the last column of :math:`dx` are zero, so both outputs keep
    the shape of the input.

    .. math::
        dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr
        0, &if\ i==h-1\end{cases}

        dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr
        0, &if\ i==w-1\end{cases}

    Inputs:
        - **images** (Tensor) - The input image data, with format 'NCHW'.

    Outputs:
        - **dy** (Tensor) - Gradients along the height dimension, with the same type and shape as the input.
        - **dx** (Tensor) - Gradients along the width dimension, with the same type and shape as the input.

    Raises:
        ValueError: If `images` is not a 4-D tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> net = ms.nn.ImageGradients()
        >>> image = ms.Tensor(np.array([[[[1, 2], [3, 4]]]]), dtype=ms.int32)
        >>> output = net(image)
        >>> print(output)
        (Tensor(shape=[1, 1, 2, 2], dtype=Int32, value=
        [[[[2, 2],
           [0, 0]]]]), Tensor(shape=[1, 1, 2, 2], dtype=Int32, value=
        [[[[1, 0],
           [1, 0]]]]))
    """
    def __init__(self):
        super().__init__()

    def construct(self, images):
        image_shape = F.shape(images)
        _check_input_4d(image_shape, "images", self.cls_name)
        batch_size, depth, height, width = image_shape
        image_dtype = P.DType()(images)

        # A single zero row; it is the whole result when height == 1, otherwise it pads the differences.
        zero_row = F.fill(image_dtype, (batch_size, depth, 1, width), 0)
        if height == 1:
            dy = zero_row
        else:
            row_diff = images[:, :, 1:, :] - images[:, :, :height - 1, :]
            dy = P.Concat(2)((row_diff, zero_row))

        # Same scheme along the width dimension, with a single zero column.
        zero_col = F.fill(image_dtype, (batch_size, depth, height, 1), 0)
        if width == 1:
            dx = zero_col
        else:
            col_diff = images[:, :, :, 1:] - images[:, :, :, :width - 1]
            dx = P.Concat(3)((col_diff, zero_col))
        return dy, dx
99
100
def _convert_img_dtype_to_float32(img, max_val):
    """Cast `img` to float32 and, when `max_val` is greater than 1, multiply it by ``1 / max_val``."""
    as_float = F.cast(img, mstype.float32)
    limit = F.scalar_cast(max_val, mstype.float32)
    # max_val is usually 1.0 or 255; pixel values are only rescaled in the latter kind of case.
    if limit > 1.:
        return as_float * (1. / limit)
    return as_float
111
112
@constexpr
def _get_dtype_max(dtype):
    """Return the largest value of an integer `dtype` as a Python float, or 1.0 for any other dtype."""
    np_type = mstype.dtype_to_nptype(dtype)
    if not issubclass(np_type, numbers.Integral):
        return 1.0
    return np.float64(np.iinfo(np_type).max).item()
122
123
@_primexpr
def _check_input_4d(input_shape, param_name, func_name):
    """Raise ValueError unless `input_shape` has exactly four dimensions."""
    rank = len(input_shape)
    if rank != 4:
        raise ValueError(f"For '{func_name}', the dimension of '{param_name}' must be 4d, "
                         f"but got {rank}.")
129
130
@_primexpr
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
    """Check that `input_shape` is 4-D and that its dimensions 2 and 3 are each at least `filter_size`."""
    _check_input_4d(input_shape, param_name, func_name)
    height = input_shape[2]
    width = input_shape[3]
    validator.check(param_name + " shape[2]", height, "filter_size", filter_size, validator.GE, func_name)
    validator.check(param_name + " shape[3]", width, "filter_size", filter_size, validator.GE, func_name)
136
137
@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    """Check `input_dtype` against `allow_dtypes` through `validator.check_type_name`."""
    validator.check_type_name(
        param_name, input_dtype, allow_dtypes, cls_name)
141
142
def _conv2d(in_channels, out_channels, kernel_size, weight, stride=1, padding=0):
    """Build a `Conv2d` cell in "valid" pad mode whose weight is initialized from `weight`."""
    conv = Conv2d(in_channels=in_channels,
                  out_channels=out_channels,
                  kernel_size=kernel_size,
                  stride=stride,
                  pad_mode="valid",
                  padding=padding,
                  weight_init=weight)
    return conv
146
147
148def _create_window(size, sigma):
149    x_data, y_data = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1]
150    x_data = np.expand_dims(x_data, axis=-1).astype(np.float32)
151    x_data = np.expand_dims(x_data, axis=-1) ** 2
152    y_data = np.expand_dims(y_data, axis=-1).astype(np.float32)
153    y_data = np.expand_dims(y_data, axis=-1) ** 2
154    sigma = 2 * sigma ** 2
155    g = np.exp(-(x_data + y_data) / sigma)
156    return np.transpose(g / np.sum(g), (2, 3, 0, 1))
157
158
def _split_img(x):
    """Split a 4-D tensor along axis 1 into one slice per channel; return the slices and the channel count."""
    _batch, channels, _height, _width = F.shape(x)
    per_channel = P.Split(1, channels)(x)
    return per_channel, channels
164
165
166def _compute_per_channel_loss(c1, c2, img1, img2, conv):
167    """computes ssim index between img1 and img2 per single channel"""
168    dot_img = img1 * img2
169    mu1 = conv(img1)
170    mu2 = conv(img2)
171    mu1_sq = mu1 * mu1
172    mu2_sq = mu2 * mu2
173    mu1_mu2 = mu1 * mu2
174    sigma1_tmp = conv(img1 * img1)
175    sigma1_sq = sigma1_tmp - mu1_sq
176    sigma2_tmp = conv(img2 * img2)
177    sigma2_sq = sigma2_tmp - mu2_sq
178    sigma12_tmp = conv(dot_img)
179    sigma12 = sigma12_tmp - mu1_mu2
180    a = (2 * mu1_mu2 + c1)
181    b = (mu1_sq + mu2_sq + c1)
182    v1 = 2 * sigma12 + c2
183    v2 = sigma1_sq + sigma2_sq + c2
184    ssim = (a * v1) / (b * v2)
185    cs = v1 / v2
186    return ssim, cs
187
188
def _compute_multi_channel_loss(c1, c2, img1, img2, conv, concat, mean):
    """
    Compute the SSIM and contrast-structure values of two images channel by channel.

    Both images are split per channel, each channel pair goes through `_compute_per_channel_loss`, the
    per-channel maps are joined again with `concat` and finally reduced with `mean` over axes (2, 3).
    """
    channels1, num_channels = _split_img(img1)
    channels2, _ = _split_img(img2)

    ssim_maps = ()
    cs_maps = ()
    for idx in range(num_channels):
        channel_ssim, channel_cs = _compute_per_channel_loss(c1, c2, channels1[idx], channels2[idx], conv)
        ssim_maps = ssim_maps + (channel_ssim,)
        cs_maps = cs_maps + (channel_cs,)

    joined_ssim = concat(ssim_maps)
    joined_cs = concat(cs_maps)

    spatial_axes = (2, 3)
    return mean(joined_ssim, spatial_axes), mean(joined_cs, spatial_axes)
206
207
class SSIM(Cell):
    r"""
    Returns SSIM index between two images.

    Its implementation is based on Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004) `Image quality
    assessment: from error visibility to structural similarity <https://ieeexplore.ieee.org/document/1284395>`_ .

    SSIM is a measure of the similarity of two pictures and, like PSNR, is often used to evaluate image quality.
    The larger the value, the smaller the gap between the output image and the undistorted image, that is,
    the better the image quality. When the two images are exactly the same, SSIM=1.

    .. math::

        l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
        c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
        s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
        SSIM(x,y)&=l*c*s\\&=\frac{(2\mu_x\mu_y+C_1)(2\sigma_{xy}+C_2)}
        {(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)}.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
          The value must be greater than 0. Default: ``1.0`` .
        filter_size (int): The size of the Gaussian filter. The value must be greater than or equal to 1.
          Default: ``11`` .
        filter_sigma (float): The standard deviation of Gaussian kernel. The value must be greater than 0.
          Default: ``1.5`` .
        k1 (float): The constant used to generate c1 in the luminance comparison function. Default: ``0.01`` .
        k2 (float): The constant used to generate c2 in the contrast comparison function. Default: ``0.03`` .

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
          The dtype must be float16 or float32, and the height and the width must each be greater than or equal to
          `filter_size`.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor. It is a 1-D tensor with shape N, where N is the batch num of img1. Both inputs are cast to float32
        before the index is computed.

    Raises:
        TypeError: If `max_val` is neither int nor float.
        TypeError: If `k1`, `k2` or `filter_sigma` is not a float.
        TypeError: If `filter_size` is not an int.
        ValueError: If `max_val` or `filter_sigma` is less than or equal to 0.
        ValueError: If `filter_size` is less than 1.
        ValueError: If length of shape of `img1` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> net = ms.nn.SSIM()
        >>> img1 = ms.Tensor(np.ones([1, 3, 16, 16]).astype(np.float32))
        >>> img2 = ms.Tensor(np.ones([1, 3, 16, 16]).astype(np.float32))
        >>> output = net(img1, img2)
        >>> print(output)
        [1.]
    """
    def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
        super().__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, validator.GT, self.cls_name)
        self.max_val = max_val
        self.filter_size = validator.check_int(filter_size, 1, validator.GE, 'filter_size', self.cls_name)
        self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
        self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
        # One single-channel Gaussian filter is shared by every channel; its weight is a constant, not trained.
        window = _create_window(filter_size, filter_sigma)
        self.conv = _conv2d(1, 1, filter_size, Tensor(window))
        self.conv.weight.requires_grad = False
        self.reduce_mean = P.ReduceMean()
        self.concat = P.Concat(axis=1)

    def construct(self, img1, img2):
        _check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
        inner.SameTypeShape()(img1, img2)
        # max_val and both images go through the same conversion so they stay on the same scale.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(self.max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        c1 = (self.k1 * max_val) ** 2
        c2 = (self.k2 * max_val) ** 2

        # Per-channel SSIM averaged over height and width, then averaged over the channels.
        ssim_ave_channel, _ = _compute_multi_channel_loss(c1, c2, img1, img2, self.conv, self.concat, self.reduce_mean)
        loss = self.reduce_mean(ssim_ave_channel, -1)

        return loss
297
298
299def _downsample(img1, img2, op):
300    a = op(img1)
301    b = op(img2)
302    return a, b
303
304
class MSSSIM(Cell):
    r"""
    Returns MS-SSIM index between two images.

    Its implementation is based on `Multiscale structural similarity
    for image quality assessment <https://ieeexplore.ieee.org/document/1292216>`_
    by Zhou Wang, Eero P. Simoncelli, and Alan C. Bovik, published on Signals, Systems and Computers in 2004.

    .. math::

        l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
        c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
        s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
        MSSSIM(x,y)&=l^\alpha_M*{\prod_{1\leq j\leq M} (c^\beta_j*s^\gamma_j)}.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
          The value must be greater than 0. Default: ``1.0`` .
        power_factors (Union[tuple, list]): Iterable of weights for each scale. It must not be empty; its length
          is the number of scales. Default: ``(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)`` , the values
          obtained by Wang et al.
        filter_size (int): The size of the Gaussian filter. The value must be greater than or equal to 1.
          Default: ``11`` .
        filter_sigma (float): The standard deviation of Gaussian kernel. The value must be greater than 0.
          Default: ``1.5`` .
        k1 (float): The constant used to generate c1 in the luminance comparison function. Default: ``0.01`` .
        k2 (float): The constant used to generate c2 in the contrast comparison function. Default: ``0.03`` .

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
          The dtype must be one of float64, float32, float16 or uint8.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, the value is in range [0, 1]. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Raises:
        TypeError: If `max_val` is neither int nor float.
        TypeError: If `power_factors` is neither tuple nor list.
        TypeError: If `k1`, `k2` or `filter_sigma` is not a float.
        TypeError: If `filter_size` is not an int.
        ValueError: If `max_val` or `filter_sigma` is less than or equal to 0.
        ValueError: If `filter_size` is less than 1.
        ValueError: If `power_factors` is empty.
        ValueError: If length of shape of `img1` or `img2` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> net = ms.nn.MSSSIM(power_factors=(0.033, 0.033, 0.033))
        >>> img1 = ms.Tensor(np.ones((1, 3, 128, 128)).astype(np.float32))
        >>> img2 = ms.Tensor(np.ones((1, 3, 128, 128)).astype(np.float32))
        >>> output = net(img1, img2)
        >>> print(output)
        [1.]
    """
    def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11,
                 filter_sigma=1.5, k1=0.01, k2=0.03):
        super().__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, validator.GT, self.cls_name)
        self.max_val = max_val
        validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
        # With no scale at all, construct() would never compute `sim` and fail on an empty tuple,
        # so reject the configuration here with a clear message.
        if not power_factors:
            raise ValueError(f"For '{self.cls_name}', 'power_factors' must contain at least one weight, "
                             f"but got {power_factors}.")
        self.filter_size = validator.check_int(filter_size, 1, validator.GE, 'filter_size', self.cls_name)
        self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
        self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
        window = _create_window(filter_size, filter_sigma)
        self.level = len(power_factors)
        # One fixed (non-trainable) Gaussian convolution per scale.
        self.conv = []
        for i in range(self.level):
            self.conv.append(_conv2d(1, 1, filter_size, Tensor(window)))
            self.conv[i].weight.requires_grad = False
        self.multi_convs_list = CellList(self.conv)
        self.weight_tensor = Tensor(power_factors, mstype.float32)
        self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
        self.relu = ReLU()
        self.reduce_mean = P.ReduceMean()
        self.prod = P.ReduceProd()
        self.pow = P.Pow()
        self.stack = P.Stack(axis=-1)
        self.concat = P.Concat(axis=1)

    def construct(self, img1, img2):
        _check_input_4d(F.shape(img1), "img1", self.cls_name)
        _check_input_4d(F.shape(img2), "img2", self.cls_name)
        valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8]
        _check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name)
        inner.SameTypeShape()(img1, img2)
        # max_val and both images go through the same conversion so they stay on the same scale.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(self.max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        c1 = (self.k1 * max_val) ** 2
        c2 = (self.k2 * max_val) ** 2

        sim = ()
        mcs = ()

        # Evaluate every scale, halving the resolution with average pooling between scales.
        # NOTE(review): the images downsampled after the last scale are never used.
        for i in range(self.level):
            sim, cs = _compute_multi_channel_loss(c1, c2, img1, img2,
                                                  self.multi_convs_list[i], self.concat, self.reduce_mean)
            mcs += (self.relu(cs),)
            img1, img2 = _downsample(img1, img2, self.avg_pool)

        # Keep the contrast-structure term of every scale but the last, where the full SSIM term is used.
        mcs = mcs[0:-1:1]
        mcs_and_ssim = self.stack(mcs + (self.relu(sim),))
        mcs_and_ssim = self.pow(mcs_and_ssim, self.weight_tensor)
        ms_ssim = self.prod(mcs_and_ssim, -1)
        loss = self.reduce_mean(ms_ssim, -1)

        return loss
416
417
class PSNR(Cell):
    r"""
    Computes the Peak Signal-to-Noise Ratio between two batches of images.

    One PSNR value is produced for each image in the batch.
    Assume inputs are :math:`I` and :math:`K`, both with shape :math:`h*w`.
    :math:`MAX` represents the dynamic range of pixel values.

    .. math::

        MSE&=\frac{1}{hw}\sum\limits_{i=0}^{h-1}\sum\limits_{j=0}^{w-1}[I(i,j)-K(i,j)]^2\\
        PSNR&=10*log_{10}(\frac{MAX^2}{MSE})

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
          The value must be greater than 0. Default: ``1.0`` .

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Raises:
        TypeError: If `max_val` is neither int nor float.
        ValueError: If `max_val` is less than or equal to 0.
        ValueError: If length of shape of `img1` or `img2` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> net = ms.nn.PSNR()
        >>> img1 = ms.Tensor([[[[1, 2, 3, 4], [1, 2, 3, 4]]]])
        >>> img2 = ms.Tensor([[[[3, 4, 5, 6], [3, 4, 5, 6]]]])
        >>> output = net(img1, img2)
        >>> print(output)
        [-6.0206]
    """
    def __init__(self, max_val=1.0):
        super().__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, validator.GT, self.cls_name)
        self.max_val = max_val

    def construct(self, img1, img2):
        _check_input_4d(F.shape(img1), "img1", self.cls_name)
        _check_input_4d(F.shape(img2), "img2", self.cls_name)
        inner.SameTypeShape()(img1, img2)

        # max_val and both images go through the same conversion so they stay on the same scale.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        peak = _convert_img_dtype_to_float32(self.max_val, dtype_max_val)
        first = _convert_img_dtype_to_float32(img1, dtype_max_val)
        second = _convert_img_dtype_to_float32(img2, dtype_max_val)

        # Mean squared error over channel, height and width: one value per image.
        squared_error = F.square(first - second)
        mse = P.ReduceMean()(squared_error, (-3, -2, -1))

        # log10(x) is evaluated as ln(x) / ln(10).
        ratio = F.square(peak) / mse
        return 10 * P.Log()(ratio) / F.scalar_log(10.0)
478
479
@_primexpr
def _check_rank(rank, input_shape, param_name, func_name):
    """Raise ValueError, reporting `input_shape`, unless `rank` is 3 or 4."""
    def _raise_if_unsupported():
        supported = rank in (3, 4)
        if not supported:
            raise ValueError(f"{func_name} {param_name} must be 3d or 4d, but got shape {input_shape}")
    _raise_if_unsupported()
487
488
@_primexpr
def _get_bbox(rank, shape, central_fraction):
    """
    Compute the `begin` and `size` tuples that slice the central region out of a 3-D or 4-D shape.

    The region is centered along the last two dimensions; every leading dimension is kept whole.
    """
    is_3d = rank == 3
    batch = -1
    if is_3d:
        channels, height, width = shape
    else:
        batch, channels, height, width = shape

    # Offset of the crop from the border; the same margin is removed on both sides.
    top = int((float(height) - float(height * central_fraction)) / 2)
    left = int((float(width) - float(width * central_fraction)) / 2)
    crop_height = height - top * 2
    crop_width = width - left * 2

    if is_3d:
        bbox_begin = (0, top, left)
        bbox_size = (channels, crop_height, crop_width)
    else:
        bbox_begin = (0, 0, top, left)
        bbox_size = (batch, channels, crop_height, crop_width)

    return bbox_begin, bbox_size
511
512
class CentralCrop(Cell):
    """
    Cuts the central region out of the images, keeping the `central_fraction` of the height and the width.

    Args:
        central_fraction (float): Fraction of size to crop. It must be float and in range (0.0, 1.0].

    Inputs:
        - **image** (Tensor) - A 3-D tensor of shape [C, H, W], or a 4-D tensor of shape [N, C, H, W].

    Outputs:
        Tensor, 3-D or 4-D according to the input, holding the central region of the images.
        The input is returned unchanged when `central_fraction` is 1.0.

    Raises:
        TypeError: If `central_fraction` is not a float.
        ValueError: If `central_fraction` is not in range (0.0, 1.0].
        ValueError: If `image` is neither 3-D nor 4-D.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> net = ms.nn.CentralCrop(central_fraction=0.5)
        >>> image = ms.Tensor(np.random.random((4, 3, 4, 4)), ms.float32)
        >>> output = net(image)
        >>> print(output.shape)
        (4, 3, 2, 2)
    """

    def __init__(self, central_fraction):
        super().__init__()
        validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
        validator.check_float_range(central_fraction, 0.0, 1.0, validator.INC_RIGHT, 'central_fraction', self.cls_name)
        self.central_fraction = central_fraction
        self.slice = P.Slice()

    def construct(self, image):
        shape = F.shape(image)
        rank = len(shape)
        _check_rank(rank, shape, "image", self.cls_name)
        # Keeping the whole image needs no slicing.
        if self.central_fraction == 1.0:
            return image

        begin, size = _get_bbox(rank, shape, self.central_fraction)
        return self.slice(image, begin, size)
561
562
class PixelShuffle(Cell):
    r"""
    Applies the PixelShuffle operation over input which implements sub-pixel convolutions
    with stride :math:`1/r` . For more details, refer to
    `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
    <https://arxiv.org/abs/1609.05158>`_ .

    Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
    :math:`(*, C, H \times r, W \times r)`,
    where :math:`r` is an upscale factor and :math:`*` is zero or more batch dimensions.

    Note:
        The dimension of input Tensor on Ascend should be less than 7.

    Args:
        upscale_factor (int): factor to shuffle the input, and is a positive integer.
            `upscale_factor` is the above-mentioned :math:`r`.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` . The dimension of `input` must be
          larger than 2, and the length of the third to last dimension must be divisible by `upscale_factor`
          squared.

    Outputs:
        - **output** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` .

    Raises:
        ValueError: If `upscale_factor` is not a positive integer.
        ValueError: If the length of third to last dimension of `input` is not divisible by `upscale_factor` squared.
        ValueError: If the dimension of `input` is less than 3.
        TypeError: If `input` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> input_x = np.arange(3 * 2 * 8 * 4 * 4).reshape((3, 2, 8, 4, 4))
        >>> input_x = ms.Tensor(input_x, ms.dtype.int32)
        >>> pixel_shuffle = ms.nn.PixelShuffle(2)
        >>> output = pixel_shuffle(input_x)
        >>> print(output.shape)
        (3, 2, 2, 8, 8)
    """
    def __init__(self, upscale_factor):
        super().__init__()
        # The factor is stored without any check here.
        # NOTE(review): the errors listed under Raises presumably come from ops.pixel_shuffle at call time.
        self.upscale_factor = upscale_factor

    def construct(self, input):
        return ops.pixel_shuffle(input, self.upscale_factor)
613
614
class PixelUnshuffle(Cell):
    r"""
    Applies the PixelUnshuffle operation over input, which is the inverse of PixelShuffle. For more details, refer
    to `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
    <https://arxiv.org/abs/1609.05158>`_ .

    Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
    :math:`(*, C \times r^2, H, W)` ,
    where :math:`r` is a downscale factor and :math:`*` is zero or more batch dimensions.

    Args:
        downscale_factor (int): factor to unshuffle the input, and is a positive integer.
            `downscale_factor` is the above-mentioned :math:`r`.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` . The dimension of `input`
          must be larger than 2, and the lengths of the second to last dimension and of the last dimension must
          both be divisible by `downscale_factor` .

    Outputs:
        - **output** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` .

    Raises:
        ValueError: If `downscale_factor` is not a positive integer.
        ValueError: If the length of second to last dimension or last dimension is not divisible by `downscale_factor` .
        ValueError: If the dimension of `input` is less than 3.
        TypeError: If `input` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> pixel_unshuffle = ms.nn.PixelUnshuffle(2)
        >>> input_x = np.arange(8 * 8).reshape((1, 1, 8, 8))
        >>> input_x = ms.Tensor(input_x, ms.dtype.int32)
        >>> output = pixel_unshuffle(input_x)
        >>> print(output.shape)
        (1, 4, 4, 4)
    """
    def __init__(self, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor

    def construct(self, input):
        unshuffled = ops.pixel_unshuffle(input, self.downscale_factor)
        return unshuffled
662