# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CPU tests for the ``Unstack`` operator in GRAPH_MODE and PYNATIVE_MODE."""

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.array_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter

# Hand-written input fixture of shape (3, 3, 2, 2, 2), shared by the graph and
# PyNative tests. Along axis 3, index 0 is all zeros and index 1 holds the
# integers -2..33 in C order, which is what _expected_outputs encodes.
_INPUT_DATA = [[[[[0, 0],
                  [-2, -1]],
                 [[0, 0],
                  [0, 1]]],
                [[[0, 0],
                  [2, 3]],
                 [[0, 0],
                  [4, 5]]],
                [[[0, 0],
                  [6, 7]],
                 [[0, 0],
                  [8, 9]]]],
               [[[[0, 0],
                  [10, 11]],
                 [[0, 0],
                  [12, 13]]],
                [[[0, 0],
                  [14, 15]],
                 [[0, 0],
                  [16, 17]]],
                [[[0, 0],
                  [18, 19]],
                 [[0, 0],
                  [20, 21]]]],
               [[[[0, 0],
                  [22, 23]],
                 [[0, 0],
                  [24, 25]]],
                [[[0, 0],
                  [26, 27]],
                 [[0, 0],
                  [28, 29]]],
                [[[0, 0],
                  [30, 31]],
                 [[0, 0],
                  [32, 33]]]]]


def _input_data(nptype):
    """Return the shared (3, 3, 2, 2, 2) input fixture cast to ``nptype``.

    The list is first materialized with NumPy's default integer dtype and then
    cast with ``astype``. Passing ``dtype=nptype`` directly would reject the
    negative entries for unsigned types on newer NumPy. ``astype`` wraps them,
    and ``_expected_outputs`` applies the same wrap, so the two stay consistent.
    """
    return np.array(_INPUT_DATA).astype(nptype)


def _expected_outputs(nptype):
    """Return the two (3, 3, 2, 2) slices ``Unstack(axis=3)`` should produce."""
    return (np.zeros((3, 3, 2, 2), dtype=nptype),
            np.arange(-2, 34, 1).reshape(3, 3, 2, 2).astype(nptype))


def _assert_outputs_match(output, expect):
    """Assert ``output`` has exactly ``len(expect)`` tensors equal to ``expect``.

    ``assert_array_equal`` also checks shapes. A plain ``(a == b).all()`` would
    silently accept a broadcast-compatible but wrongly shaped result.
    """
    assert len(output) == len(expect)
    for actual, exp in zip(output, expect):
        np.testing.assert_array_equal(actual.asnumpy(), exp)


class Net(nn.Cell):
    """Cell that unstacks the fixture, stored as a Parameter, along axis 3."""

    def __init__(self, nptype):
        super(Net, self).__init__()

        self.unstack = P.Unstack(axis=3)
        self.data_np = _input_data(nptype)
        self.x1 = Parameter(initializer(Tensor(self.data_np), [3, 3, 2, 2, 2]), name='x1')

    @ms_function
    def construct(self):
        return self.unstack(self.x1)


def unpack(nptype):
    """Run ``Unstack(axis=3)`` in GRAPH_MODE on CPU and check both slices."""
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    unpack_ = Net(nptype)
    output = unpack_()
    _assert_outputs_match(output, _expected_outputs(nptype))


def unpack_pynative(nptype):
    """Run ``Unstack(axis=3)`` in PYNATIVE_MODE on CPU and check both slices."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
    x1 = Tensor(_input_data(nptype))
    output = P.Unstack(axis=3)(x1)
    _assert_outputs_match(output, _expected_outputs(nptype))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_float32():
    unpack(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_float16():
    unpack(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_int32():
    unpack(np.int32)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_int16():
    unpack(np.int16)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_uint8():
    unpack(np.uint8)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_graph_bool():
    # np.bool (an alias of builtin bool) was removed in NumPy 1.24;
    # np.bool_ is the NumPy boolean dtype on every NumPy version.
    unpack(np.bool_)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_float32():
    unpack_pynative(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_float16():
    unpack_pynative(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_int32():
    unpack_pynative(np.int32)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_int16():
    unpack_pynative(np.int16)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_uint8():
    unpack_pynative(np.uint8)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_unpack_pynative_bool():
    # np.bool (an alias of builtin bool) was removed in NumPy 1.24;
    # np.bool_ is the NumPy boolean dtype on every NumPy version.
    unpack_pynative(np.bool_)