# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPU tests for the MindSpore ``P.Squeeze`` operator across NumPy dtypes."""
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class SqueezeNet(nn.Cell):
    """Minimal cell that applies ``P.Squeeze()`` (constructed with no arguments)."""

    def __init__(self):
        super(SqueezeNet, self).__init__()
        self.squeeze = P.Squeeze()

    def construct(self, tensor):
        return self.squeeze(tensor)


def squeeze(nptype):
    """Squeeze a (1, 16, 1, 1) array of ``nptype`` on GPU and compare with NumPy.

    Args:
        nptype: NumPy scalar type used for the input array.

    Raises:
        AssertionError: if the network output differs from ``x.squeeze()``
            in shape or in any element.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    np.random.seed(0)
    data = np.random.randn(1, 16, 1, 1)
    if np.issubdtype(nptype, np.unsignedinteger):
        # Casting negative floats to an unsigned integer type is undefined in
        # C and makes newer NumPy warn ("invalid value encountered in cast");
        # keep the source values non-negative for unsigned dtypes.
        data = np.abs(data)
    x = data.astype(nptype)
    net = SqueezeNet()
    output = net(Tensor(x))
    # Unlike np.all(a == b), which broadcasts mismatched shapes, this fails if
    # the output shape differs from NumPy's (e.g. if no axis was removed).
    np.testing.assert_array_equal(output.asnumpy(), x.squeeze())


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_bool():
    # np.bool was removed in NumPy 1.24; np.bool_ works on all versions.
    squeeze(np.bool_)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_uint8():
    squeeze(np.uint8)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_uint16():
    squeeze(np.uint16)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_uint32():
    squeeze(np.uint32)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_int8():
    squeeze(np.int8)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_int16():
    squeeze(np.int16)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_int32():
    squeeze(np.int32)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_int64():
    squeeze(np.int64)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_float16():
    squeeze(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_float32():
    squeeze(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_float64():
    squeeze(np.float64)