• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15import pytest
16import numpy as np
17
18import mindspore.context as context
19from mindspore import Tensor
20from mindspore import nn
21import tests.st.utils.test_utils as test_utils
22
23
# Test helper: applies nn.ChannelShuffle (positional argument 2, the number of
# channel groups) to `x` and returns the shuffled tensor.
# NOTE(review): test_utils.run_with_cell presumably wraps this function in an
# nn.Cell so it executes under whichever mode context.set_context selected
# (GRAPH_MODE or PYNATIVE_MODE) — confirm in tests/st/utils/test_utils.py.
@test_utils.run_with_cell
def channel_shuffle(x):
    # A new ChannelShuffle cell is built on every call and applied immediately.
    # Assumes x is (N, C, H, W) with C divisible by 2 — TODO confirm.
    return nn.ChannelShuffle(2)(x)
27
28
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("context_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize("dtype", [np.float32, np.int32])
def test_net_channelshuffle_float32(context_mode, dtype):
    """
    Feature: channelshuffle
    Description: test nn.ChannelShuffle with 2 groups on a (1, 4, 2, 2) input in
        graph and pynative mode, for float32 (the dtype named by this test) and
        int32 (the dtype previously exercised).
    Expectation: channels [0, 1, 2, 3] are reordered to [0, 2, 1, 3]; the output
        matches the expected array in values, shape and dtype.
    """
    context.set_context(mode=context_mode, device_target="Ascend")
    x = Tensor(np.arange(16).astype(dtype).reshape(1, 4, 2, 2))
    output = channel_shuffle(x)
    expected = np.array([[[[0, 1],
                           [2, 3]],
                          [[8, 9],
                           [10, 11]],
                          [[4, 5],
                           [6, 7]],
                          [[12, 13],
                           [14, 15]]]], dtype)
    out_np = output.asnumpy()
    # Channel shuffle only permutes elements, so the input dtype must survive.
    assert out_np.dtype == expected.dtype
    # assert_array_equal also checks shape; the former `np.all(a == b)` would
    # broadcast and accept e.g. a missing batch axis. All values are small
    # integers, so exact comparison is valid for float32 as well.
    np.testing.assert_array_equal(out_np, expected)
52