# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPU tests for the Flatten operator, in both PyNative and Graph modes."""

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class NetFlatten(nn.Cell):
    """Applies a single Flatten op to its input."""

    def __init__(self):
        super(NetFlatten, self).__init__()
        self.flatten = P.Flatten()

    def construct(self, x):
        return self.flatten(x)


class NetAllFlatten(nn.Cell):
    """Applies Flatten repeatedly inside a while loop (control-flow coverage)."""

    def __init__(self):
        super(NetAllFlatten, self).__init__()
        self.flatten = P.Flatten()

    def construct(self, x):
        # The loop exercises graph-mode control flow; flattening a 2-D tensor
        # repeatedly is a no-op on its values.
        loop_count = 4
        while loop_count > 0:
            x = self.flatten(x)
            loop_count = loop_count - 1
        return x


class NetFirstFlatten(nn.Cell):
    """Flatten loop followed by a ReLU (op after control flow)."""

    def __init__(self):
        super(NetFirstFlatten, self).__init__()
        self.flatten = P.Flatten()
        self.relu = P.ReLU()

    def construct(self, x):
        loop_count = 4
        while loop_count > 0:
            x = self.flatten(x)
            loop_count = loop_count - 1
        x = self.relu(x)
        return x


class NetLastFlatten(nn.Cell):
    """ReLU followed by a Flatten loop (op before control flow)."""

    def __init__(self):
        super(NetLastFlatten, self).__init__()
        self.flatten = P.Flatten()
        self.relu = P.ReLU()

    def construct(self, x):
        loop_count = 4
        x = self.relu(x)
        while loop_count > 0:
            x = self.flatten(x)
            loop_count = loop_count - 1
        return x


def _run_net_both_modes(net_class, x, expect):
    """Run `net_class` on `x` in PyNative and Graph modes and check `expect`.

    A fresh net instance is built for each mode so that graph compilation
    state from one mode cannot leak into the other.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = net_class()
    output = net(x)
    assert (output.asnumpy() == expect).all()

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = net_class()
    output = net(x)
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_flatten():
    """A single Flatten on an already-2-D tensor preserves its values."""
    x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32))
    expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)
    _run_net_both_modes(NetFlatten, x, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_all_flatten():
    """Repeated Flatten in a loop is still the identity on a 2-D tensor."""
    x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32))
    expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)
    _run_net_both_modes(NetAllFlatten, x, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_first_flatten():
    """Flatten loop then ReLU: negatives are zeroed in the result."""
    x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32))
    expect = np.array([[0, 0.3, 3.6], [0.4, 0.5, 0]]).astype(np.float32)
    _run_net_both_modes(NetFirstFlatten, x, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_last_flatten():
    """ReLU then Flatten loop: same expected values as ReLU alone."""
    x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32))
    expect = np.array([[0, 0.3, 3.6], [0.4, 0.5, 0]]).astype(np.float32)
    _run_net_both_modes(NetLastFlatten, x, expect)