# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph-kernel consistency tests for the Less/Greater/LessEqual/GreaterEqual ops.

Every test feeds the same inputs through one compare op twice -- once with
``enable_graph_kernel=True`` and once with it disabled -- and requires the two
results to match exactly in shape and value.
"""
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class LessNet(nn.Cell):
    """Cell that applies ``P.Less`` to its two inputs."""

    def __init__(self):
        super(LessNet, self).__init__()
        self.ops = P.Less()

    def construct(self, x, y):
        return self.ops(x, y)


class GreaterNet(nn.Cell):
    """Cell that applies ``P.Greater`` to its two inputs."""

    def __init__(self):
        super(GreaterNet, self).__init__()
        self.ops = P.Greater()

    def construct(self, x, y):
        return self.ops(x, y)


class LessEqualNet(nn.Cell):
    """Cell that applies ``P.LessEqual`` to its two inputs."""

    def __init__(self):
        super(LessEqualNet, self).__init__()
        self.ops = P.LessEqual()

    def construct(self, x, y):
        return self.ops(x, y)


class GreaterEqualNet(nn.Cell):
    """Cell that applies ``P.GreaterEqual`` to its two inputs."""

    def __init__(self):
        super(GreaterEqualNet, self).__init__()
        self.ops = P.GreaterEqual()

    def construct(self, x, y):
        return self.ops(x, y)


def gen_data():
    """Build the four (x, y) input pairs shared by every compare test.

    * pair 0: float32, ``x`` (2, 3, 4, 4) vs ``y`` (2, 1, 4, 4) -- ``y`` broadcasts.
    * pair 1: float16, ``x`` (2, 1, 1, 4) vs ``y`` (2, 3, 4, 4) -- ``x`` broadcasts.
    * pair 2: int32, one-element tensors.
    * pair 3: float32, fixed values 768 vs 3072.5.

    Random values are integers in [1, 5) drawn from a generator seeded with 0,
    so every call returns the same data.

    Returns:
        tuple: ``(x0, y0, x1, y1, x2, y2, x3, y3)`` as mindspore ``Tensor`` objects.
    """
    # A private RandomState(0) draws the same sequence that np.random.seed(0)
    # followed by np.random.randint would, but it leaves numpy's global RNG
    # state untouched for other tests running in the same process.
    rng = np.random.RandomState(0)
    x0_np = rng.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
    y0_np = rng.randint(1, 5, (2, 1, 4, 4)).astype(np.float32)
    x1_np = rng.randint(1, 5, (2, 1, 1, 4)).astype(np.float16)
    y1_np = rng.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
    x2_np = rng.randint(1, 5, 1).astype(np.int32)
    y2_np = rng.randint(1, 5, 1).astype(np.int32)
    x3_np = np.array([768]).astype(np.float32)
    y3_np = np.array([3072.5]).astype(np.float32)

    arrays = (x0_np, y0_np, x1_np, y1_np, x2_np, y2_np, x3_np, y3_np)
    return tuple(Tensor(arr) for arr in arrays)


def _run_compare_net(net_cls, inputs, enable_graph_kernel):
    """Run a fresh ``net_cls()`` on each consecutive (x, y) pair in ``inputs``.

    Sets the global ``enable_graph_kernel`` context flag before building the
    net. Returns a tuple with one numpy array per pair, in input order.
    """
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = net_cls()
    pairs = zip(inputs[0::2], inputs[1::2])
    return tuple(net(x, y).asnumpy() for x, y in pairs)


def get_less_net_output(x0, y0, x1, y1, x2, y2, x3, y3, enable_graph_kernel=False):
    """Return the four ``LessNet`` results (as numpy arrays) for the given pairs."""
    return _run_compare_net(LessNet, (x0, y0, x1, y1, x2, y2, x3, y3), enable_graph_kernel)


def get_greater_net_output(x0, y0, x1, y1, x2, y2, x3, y3, enable_graph_kernel=False):
    """Return the four ``GreaterNet`` results (as numpy arrays) for the given pairs."""
    return _run_compare_net(GreaterNet, (x0, y0, x1, y1, x2, y2, x3, y3), enable_graph_kernel)


def get_less_equal_net_output(x0, y0, x1, y1, x2, y2, x3, y3, enable_graph_kernel=False):
    """Return the four ``LessEqualNet`` results (as numpy arrays) for the given pairs."""
    return _run_compare_net(LessEqualNet, (x0, y0, x1, y1, x2, y2, x3, y3), enable_graph_kernel)


def get_greater_equal_net_output(x0, y0, x1, y1, x2, y2, x3, y3, enable_graph_kernel=False):
    """Return the four ``GreaterEqualNet`` results (as numpy arrays) for the given pairs."""
    return _run_compare_net(GreaterEqualNet, (x0, y0, x1, y1, x2, y2, x3, y3), enable_graph_kernel)


def _assert_graph_kernel_consistent(get_output):
    """Assert ``get_output`` gives identical results with graph kernel on and off.

    Args:
        get_output: one of the ``get_*_net_output`` helpers defined in this module.
    """
    inputs = gen_data()
    try:
        outputs_on = get_output(*inputs, enable_graph_kernel=True)
        outputs_off = get_output(*inputs, enable_graph_kernel=False)
    finally:
        # Never leave graph kernel switched on for later tests, even if the
        # enabled run raises.
        context.set_context(enable_graph_kernel=False)

    assert len(outputs_on) == len(outputs_off)
    for idx, (out_on, out_off) in enumerate(zip(outputs_on, outputs_off)):
        # Check shapes first: an element-wise ``==`` would silently broadcast
        # mismatched-but-compatible shapes.
        assert out_on.shape == out_off.shape, \
            "output {} shape mismatch: {} vs {}".format(idx, out_on.shape, out_off.shape)
        np.testing.assert_array_equal(
            out_on, out_off, err_msg="output {} differs with graph kernel on/off".format(idx))


# NOTE(review): the test_*_net functions below carry no pytest marks and do not
# set the device target, so when pytest collects them directly they run under
# whatever context is current. The *_gpu wrappers further down set GRAPH_MODE
# on GPU before delegating to them.
def test_less_net():
    """Less gives identical results with graph kernel enabled and disabled."""
    _assert_graph_kernel_consistent(get_less_net_output)


def test_greater_net():
    """Greater gives identical results with graph kernel enabled and disabled."""
    _assert_graph_kernel_consistent(get_greater_net_output)


def test_less_equal_net():
    """LessEqual gives identical results with graph kernel enabled and disabled."""
    _assert_graph_kernel_consistent(get_less_equal_net_output)


def test_greater_equal_net():
    """GreaterEqual gives identical results with graph kernel enabled and disabled."""
    _assert_graph_kernel_consistent(get_greater_equal_net_output)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_less_gpu():
    """Run the Less consistency check in graph mode on GPU."""
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    test_less_net()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_greater_gpu():
    """Run the Greater consistency check in graph mode on GPU."""
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    test_greater_net()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_less_equal_gpu():
    """Run the LessEqual consistency check in graph mode on GPU."""
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    test_less_equal_net()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_greater_equal_gpu():
    """Run the GreaterEqual consistency check in graph mode on GPU."""
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    test_greater_equal_net()