# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.array_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter


class PackNet(nn.Cell):
    """Cell that stacks two fixed (2, 2, 2, 2) parameters along axis 2.

    ``x1`` is all zeros and ``x2`` holds the values 0..15; both are cast to
    ``nptype`` before being wrapped as parameters.
    """

    def __init__(self, nptype):
        super(PackNet, self).__init__()
        dims = [2, 2, 2, 2]
        self.stack = P.Stack(axis=2)

        # All-zero operand, kept on the cell as a NumPy array as well.
        self.data_np = np.zeros(tuple(dims)).astype(nptype)
        self.x1 = Parameter(initializer(Tensor(self.data_np), dims), name='x1')

        # Second operand: 0..15 laid out in the same 4-D shape.
        ramp = np.arange(16).reshape(2, 2, 2, 2).astype(nptype)
        self.x2 = Parameter(initializer(Tensor(ramp), dims), name='x2')

    @ms_function
    def construct(self):
        operands = (self.x1, self.x2)
        return self.stack(operands)


def pack(nptype):
    """Run ``PackNet`` in graph mode on CPU and compare with the expected stack.

    Args:
        nptype: NumPy scalar type used for both inputs and the expected output.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    net = PackNet(nptype)
    result = net()

    # Expected layout: for every leading index (i, j) the new axis 2 holds the
    # zero block first and the matching 2x2 block of 0..15 second.
    ramp = np.arange(16).reshape(2, 2, 2, 2)
    expect = np.stack((np.zeros_like(ramp), ramp), axis=2).astype(nptype)

    matches = result.asnumpy() == expect
    assert matches.all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pack_graph_float32():
    """Check the graph-mode CPU stack result for float32 inputs."""
    pack(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pack_graph_float16():
    """Check the graph-mode CPU stack result for float16 inputs."""
    pack(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pack_graph_int32():
    """Check the graph-mode CPU stack result for int32 inputs."""
    pack(np.int32)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pack_graph_int16():
    """Check the graph-mode CPU stack result for int16 inputs."""
    pack(np.int16)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pack_graph_uint8():
    """Check the graph-mode CPU stack result for uint8 inputs."""
    pack(np.uint8)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pack_graph_bool():
    """Check the graph-mode CPU stack result for boolean inputs."""
    # np.bool was a deprecated alias of the builtin bool and is missing from
    # NumPy 1.24+; np.bool_ is the NumPy boolean scalar type in all releases.
    pack(np.bool_)