# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_operator """
import numpy as np

from mindspore import Tensor, Model, context
from mindspore.nn import Cell
from mindspore.nn import ReLU
from mindspore.ops import operations as P
from ...ut_filter import non_graph_engine


class arithmetic_Net(Cell):
    """ arithmetic_Net definition

    Exercises Python arithmetic operators on scalar ints inside a graph-mode
    ``construct``.

    Args:
        symbol: Branch selector. 1..7 pick one arithmetic expression
            (see ``arithmetic_operator_base`` for the symbol mapping); any
            other value takes the ``not`` branch.
        loop_count: Pair ``(a, b)`` of scalars unpacked in ``construct``.
            Defaults to ``(1, 3)``.
    """

    def __init__(self, symbol, loop_count=(1, 3)):
        super().__init__()
        self.symbol = symbol
        self.loop_count = loop_count
        self.relu = ReLU()

    # NOTE: ``construct`` is compiled by MindSpore's graph parser, which only
    # supports a subset of Python; keep its statements exactly as written.
    # Each ``for _ in (p, q)`` iterates over a 2-tuple (not a range), so the
    # ReLU is applied exactly twice in branches 1..7.
    def construct(self, x):
        a, b = self.loop_count
        y = self.symbol
        if y == 1:
            a += b
            for _ in (b, a):
                x = self.relu(x)
        elif y == 2:
            b -= a
            for _ in (a, b):
                x = self.relu(x)
        elif y == 3:
            z = a + b
            for _ in (b, z):
                x = self.relu(x)
        elif y == 4:
            z = b - a
            for _ in (z, b):
                x = self.relu(x)
        elif y == 5:
            z = a * b
            for _ in (a, z):
                x = self.relu(x)
        elif y == 6:
            z = b / a
            for _ in (a, z):
                x = self.relu(x)
        elif y == 7:
            z = b % a + 1
            for _ in (a, z):
                x = self.relu(x)
        else:
            # ``not`` branch: ReLU is applied only when ``a`` is falsy
            # (never with the default loop_count of (1, 3)).
            if not a:
                x = self.relu(x)
        return x


class logical_Net(Cell):
    """ logical_Net definition

    Exercises the ``and`` / ``or`` boolean operators on scalars inside a
    graph-mode ``construct``: when the condition holds the input goes through
    ReLU, otherwise through ``P.Flatten``.

    Args:
        symbol: 1 selects ``and``; any other value selects ``or``.
        loop_count: Pair ``(a, b)`` used as the boolean operands.
            Defaults to ``(1, 3)`` (both truthy).
    """

    def __init__(self, symbol, loop_count=(1, 3)):
        super().__init__()
        self.symbol = symbol
        self.loop_count = loop_count
        self.fla = P.Flatten()
        self.relu = ReLU()

    # NOTE: compiled by MindSpore's graph parser; keep statements unchanged.
    def construct(self, x):
        a, b = self.loop_count
        y = self.symbol
        if y == 1:
            if b and a:
                x = self.relu(x)
            else:
                x = self.fla(x)
        else:
            if b or a:
                x = self.relu(x)
            else:
                x = self.fla(x)
        return x


def _predict_with_random_input(net):
    """Run ``net`` once through ``Model.predict`` in graph mode.

    The input is a float32 Tensor of shape (2, 3, 4, 5) drawn from
    ``np.random.randn``. The prediction is discarded; the caller only checks
    that graph compilation and execution complete without raising.
    """
    input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(input_np)
    context.set_context(mode=context.GRAPH_MODE)
    model = Model(net)
    model.predict(input_me)


def arithmetic_operator_base(symbol):
    """ arithmetic_operator_base

    Build an ``arithmetic_Net`` for ``symbol`` and run it once in graph mode.

    Args:
        symbol: One of "++", "--", "+", "-", "*", "/", "%", "not".

    Raises:
        KeyError: If ``symbol`` is not one of the keys above.
    """
    # Maps each operator symbol to the branch id ``arithmetic_Net`` dispatches on.
    arithmetic_operator = {"++": 1, "--": 2, "+": 3, "-": 4, "*": 5, "/": 6, "%": 7, "not": 8}
    x = arithmetic_operator[symbol]
    net = arithmetic_Net(x)
    _predict_with_random_input(net)


def logical_operator_base(symbol):
    """ logical_operator_base

    Build a ``logical_Net`` for ``symbol`` and run it once in graph mode.

    Args:
        symbol: "and" or "or".

    Raises:
        KeyError: If ``symbol`` is not one of the keys above.
    """
    # Maps each operator symbol to the branch id ``logical_Net`` dispatches on.
    logical_operator = {"and": 1, "or": 2}
    x = logical_operator[symbol]
    net = logical_Net(x)
    _predict_with_random_input(net)


@non_graph_engine
def test_ME_arithmetic_operator_0080():
    """ test_ME_arithmetic_operator_0080: the ``not`` operator. """
    arithmetic_operator_base('not')


@non_graph_engine
def test_ME_arithmetic_operator_0070():
    """ test_ME_arithmetic_operator_0070: the logical ``and`` operator. """
    # NOTE(review): despite its name this exercises logical ``and``; it is the
    # counterpart of test_ME_logical_operator_0020. Name kept for test-id stability.
    logical_operator_base('and')


@non_graph_engine
def test_ME_logical_operator_0020():
    """ test_ME_logical_operator_0020: the logical ``or`` operator. """
    logical_operator_base('or')


def test_ops():
    """Check that graph mode compiles ``//``, ``**`` and ``%`` on Tensor and
    scalar operands, plus scalar comparisons, ``is not None`` and str ``+``/``==``.
    """
    class OpsNet(Cell):
        """ OpsNet definition """

        def __init__(self, x, y):
            super().__init__()
            self.x = x
            self.y = y
            self.int = 4
            self.float = 3.2
            self.str_a = "hello"
            self.str_b = "world"

        # NOTE: compiled by MindSpore's graph parser; keep statements unchanged.
        def construct(self, x, y):
            h = x // y
            m = x ** y
            n = x % y
            # Scalar counterparts: with OpsNet(9, 2), r + s + t = 4 + 81 + 1 = 86.
            r = self.x // self.y
            s = self.x ** self.y
            t = self.x % self.y
            p = h + m + n
            q = r + s + t
            ret_pow = p ** q + q ** p
            ret_mod = p % q + q % p
            ret_floor = p // q + q // p
            ret = ret_pow + ret_mod + ret_floor
            # Every condition below holds for OpsNet(9, 2), so ``ret`` is returned.
            if self.int > self.float:
                if [1, 2, 3] is not None:
                    if self.str_a + self.str_b == "helloworld":
                        if q == 86:
                            return ret
            return x

    net = OpsNet(9, 2)
    # NOTE(review): x ** y with x in [1, 10) and y in [10, 20) can exceed the
    # int32 range; the numeric result is not checked, only that it runs.
    x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
    y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
    context.set_context(mode=context.GRAPH_MODE)
    net(x, y)


def test_in_dict():
    """Check that graph mode compiles ``in`` / ``not in`` on a dict literal."""
    class InDictNet(Cell):
        """ InDictNet definition """

        def __init__(self, key_in, key_not_in):
            super().__init__()
            self.key_in = key_in
            self.key_not_in = key_not_in

        # NOTE: compiled by MindSpore's graph parser; keep statements unchanged.
        def construct(self, x, y, z):
            d = {"a": x, "b": y}
            ret_in = 1
            ret_not_in = 2
            if self.key_in in d:
                ret_in = d[self.key_in]
            if self.key_not_in not in d:
                ret_not_in = z
            ret = ret_in + ret_not_in
            return ret

    # "a" is a key of d and "c" is not, so both branches fire: ret = x + z.
    net = InDictNet("a", "c")
    x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
    y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
    z = Tensor(np.random.randint(low=20, high=30, size=(2, 3, 4), dtype=np.int32))
    context.set_context(mode=context.GRAPH_MODE)
    net(x, y, z)