# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test_structure_output

Checks the values returned by a Cell's ``construct`` in graph mode when the
output is a constant (tuple, list, int, str), an input Tensor, or a nested
tuple of Tensors.
"""
import numpy as np

import mindspore.ops.operations as P
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.ops.functional import depend

# Every test in this module runs with the graph-mode context set here.
context.set_context(mode=context.GRAPH_MODE)


def test_output_const_tuple_0():
    """A constant tuple attribute returned from construct compares equal to the tuple."""
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.x = (1, 2, 3)

        def construct(self):
            return self.x

    x = (1, 2, 3)
    net = Net()
    assert net() == x


def test_output_const_tuple_1():
    """Concatenating two constant tuples in construct yields the joined tuple."""
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.tuple_1 = (1, 2, 3)
            self.tuple_2 = (4, 5, 6)

        def construct(self):
            ret = self.tuple_1 + self.tuple_2
            return ret

    net = Net()
    assert net() == (1, 2, 3, 4, 5, 6)


def test_output_const_list():
    """A constant list attribute returned from construct is expected back as a tuple."""
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.list_1 = [1, 2, 3]

        def construct(self):
            ret = self.list_1
            return ret

    net = Net()
    # The expected value is a tuple even though the attribute is a list.
    assert net() == (1, 2, 3)


def test_output_const_int():
    """The sum of two constant int attributes is returned as 5."""
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.number_1 = 2
            self.number_2 = 3

        def construct(self):
            ret = self.number_1 + self.number_2
            return ret

    net = Net()
    assert net() == 5


def test_output_const_str():
    """A constant string attribute returned from construct compares equal to the string."""
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.str = "hello world"

        def construct(self):
            ret = self.str
            return ret

    net = Net()
    assert net() == "hello world"


def test_output_parameter_int():
    """An int32 scalar Tensor passed to construct is returned and compares equal to the input."""
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            return x

    x = Tensor(np.array(88).astype(np.int32))
    net = Net()
    assert net(x) == x


def test_output_parameter_str():
    """A string attribute is returned unchanged.

    Despite the test name, the string is a Cell attribute here, not an
    argument of construct.
    """
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.x = "hello world"

        def construct(self):
            return self.x

    x = "hello world"
    net = Net()
    assert net() == x


def test_tuple_tuple_0():
    """Run a net whose output is a tuple of tuples of tuples of Tensors.

    No assertion is made on the result; the test only requires the call to
    complete.
    """
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        def construct(self, x, y):
            xx = self.add(x, x)
            yy = self.add(y, y)
            xxx = self.sub(x, x)
            yyy = self.sub(y, y)
            ret = ((xx, yy), (xxx, yyy))
            ret = (ret, ret)
            return ret

    net = Net()
    x = Tensor(np.ones([2], np.int32))
    y = Tensor(np.zeros([3], np.int32))
    net(x, y)


def test_tuple_tuple_1():
    """Run a net whose nested tuple output mixes computed Tensors with an input.

    No assertion is made on the result; the test only requires the call to
    complete.
    """
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.add = P.Add()

        def construct(self, x, y):
            xx = self.add(x, x)
            yy = self.add(y, y)
            # The input x itself is placed directly in the output structure.
            ret = ((xx, yy), x)
            ret = (ret, ret)
            return ret

    net = Net()
    x = Tensor(np.ones([2], np.int32))
    y = Tensor(np.zeros([3], np.int32))
    net(x, y)


def test_tuple_tuple_2():
    """Run a net whose nested tuple output is passed through ``depend``.

    No assertion is made on the result; the test only requires the call to
    complete.
    """
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.relu = P.ReLU()
            self.depend = depend

        def construct(self, x, y):
            xx = self.add(x, x)
            yy = self.add(y, y)
            xxx = self.sub(x, x)
            yyy = self.sub(y, y)
            z = self.relu(x)
            ret = ((xx, yy), (xxx, yyy))
            ret = (ret, ret)
            # The ReLU result is only used as the second argument of depend.
            ret = self.depend(ret, z)
            return ret

    net = Net()
    x = Tensor(np.ones([2], np.int32))
    y = Tensor(np.zeros([3], np.int32))
    net(x, y)


def test_tuple_tuple_3():
    """Run a net whose nested tuple output contains an input and goes through ``depend``.

    No assertion is made on the result; the test only requires the call to
    complete.
    """
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.add = P.Add()
            self.relu = P.ReLU()
            self.depend = depend

        def construct(self, x, y):
            xx = self.add(x, x)
            yy = self.add(y, y)
            z = self.relu(x)
            # The input x itself is placed directly in the output structure.
            ret = ((xx, yy), x)
            ret = (ret, ret)
            # The ReLU result is only used as the second argument of depend.
            ret = self.depend(ret, z)
            return ret

    net = Net()
    x = Tensor(np.ones([2], np.int32))
    y = Tensor(np.zeros([3], np.int32))
    net(x, y)


def test_soft():
    """Run a net that nests an operator result, an input and a constant Tensor tuple.

    No assertion is made on the result; the test only requires the call to
    complete.
    """
    class SoftmaxCrossEntropyWithLogitsNet(Cell):
        def __init__(self):
            super(SoftmaxCrossEntropyWithLogitsNet, self).__init__()
            self.soft = P.SoftmaxCrossEntropyWithLogits()
            # Constant tuple of Tensors that is returned as part of the output.
            self.value = (Tensor(np.zeros((2, 2)).astype(np.float32)), Tensor(np.ones((2, 2)).astype(np.float32)))

        def construct(self, x, y, z):
            xx = x + y
            yy = x - y
            ret = self.soft(xx, yy)
            ret = (ret, z)
            ret = (ret, self.value)
            return ret

    input1 = Tensor(np.zeros((2, 2)).astype(np.float32))
    input2 = Tensor(np.ones((2, 2)).astype(np.float32))
    input3 = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
    net = SoftmaxCrossEntropyWithLogitsNet()
    net(input1, input2, input3)


def test_const_depend():
    """Run a net that applies ``depend`` to a constant Tensor tuple.

    No assertion is made on the result; the test only requires the call to
    complete.
    """
    class ConstDepend(Cell):
        def __init__(self):
            super(ConstDepend, self).__init__()
            # Constant tuple of Tensors used as the first argument of depend.
            self.value = (Tensor(np.zeros((2, 3)).astype(np.float32)), Tensor(np.ones((2, 3)).astype(np.float32)))
            self.soft = P.SoftmaxCrossEntropyWithLogits()
            self.depend = depend

        def construct(self, x, y, z):
            ret = x + y
            ret = ret * z
            # The computed value is only used as the second argument of depend.
            ret = self.depend(self.value, ret)
            ret = (ret, self.soft(x, y))
            return ret

    input1 = Tensor(np.zeros((2, 2)).astype(np.float32))
    input2 = Tensor(np.ones((2, 2)).astype(np.float32))
    input3 = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
    net = ConstDepend()
    net(input1, input2, input3)