• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16test_structure_output
17"""
18import numpy as np
19
20import mindspore.ops.operations as P
21from mindspore import Tensor, context
22from mindspore.nn import Cell
23from mindspore.ops.functional import depend
24
# Select GRAPH_MODE once at import time; every test in this module runs under it.
context.set_context(mode=context.GRAPH_MODE)
26
27
def test_output_const_tuple_0():
    """Check that a cell returning a stored constant tuple yields that tuple."""
    expected = (1, 2, 3)

    class Net(Cell):
        """Cell whose only output is a tuple saved at construction time."""

        def __init__(self):
            super().__init__()
            self.x = (1, 2, 3)

        def construct(self):
            return self.x

    assert Net()() == expected
40
41
def test_output_const_tuple_1():
    """Check that adding two constant tuples in construct returns the joined tuple."""

    class Net(Cell):
        """Cell that returns the concatenation of its two stored tuples."""

        def __init__(self):
            super().__init__()
            self.tuple_1, self.tuple_2 = (1, 2, 3), (4, 5, 6)

        def construct(self):
            return self.tuple_1 + self.tuple_2

    joined = Net()()
    assert joined == (1, 2, 3, 4, 5, 6)
55
56
def test_output_const_list():
    """Check the output of a cell that returns a stored constant list.

    The assertion compares the result against the tuple (1, 2, 3).
    """

    class Net(Cell):
        """Cell that returns the list it stored at construction time."""

        def __init__(self):
            super().__init__()
            self.list_1 = [1, 2, 3]

        def construct(self):
            return self.list_1

    result = Net()()
    assert result == (1, 2, 3)
69
70
def test_output_const_int():
    """Check that the sum of two constant ints computed in construct is returned."""

    class Net(Cell):
        """Cell that adds two stored integer constants."""

        def __init__(self):
            super().__init__()
            self.number_1, self.number_2 = 2, 3

        def construct(self):
            return self.number_1 + self.number_2

    total = Net()()
    assert total == 5
84
85
def test_output_const_str():
    """Check that a cell returning a stored constant string yields that string."""

    class Net(Cell):
        """Cell whose output is a string saved at construction time."""

        def __init__(self):
            super().__init__()
            self.greeting = "hello world"

        def construct(self):
            return self.greeting

    result = Net()()
    assert result == "hello world"
98
99
def test_output_parameter_int():
    """Check that a cell returning its input tensor unchanged gives back an equal tensor."""

    class Net(Cell):
        """Identity cell: construct hands its argument straight back."""

        def __init__(self):
            super().__init__()

        def construct(self, x):
            return x

    scalar = Tensor(np.array(88, dtype=np.int32))
    identity = Net()
    assert identity(scalar) == scalar
111
112
def test_output_parameter_str():
    """Check that a cell returning its stored string attribute yields an equal string."""
    expected = "hello world"

    class Net(Cell):
        """Cell that returns the string kept in ``self.x``."""

        def __init__(self):
            super().__init__()
            self.x = "hello world"

        def construct(self):
            return self.x

    assert Net()() == expected
125
126
def test_tuple_tuple_0():
    """Run a cell whose output is a doubly nested tuple of tensors.

    The result is only produced, not compared against expected values.
    """

    class Net(Cell):
        """Builds ((x+x, y+y), (x-x, y-y)) and returns it paired with itself."""

        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        def construct(self, x, y):
            sums = (self.add(x, x), self.add(y, y))
            diffs = (self.sub(x, x), self.sub(y, y))
            nested = (sums, diffs)
            return (nested, nested)

    ones = Tensor(np.ones([2], dtype=np.int32))
    zeros = Tensor(np.zeros([3], dtype=np.int32))
    Net()(ones, zeros)
147
148
def test_tuple_tuple_1():
    """Run a cell whose output mixes a nested tuple with a plain input tensor.

    The cell returns (((x+x, y+y), x), ((x+x, y+y), x)). The result is only
    produced, not compared against expected values.
    """

    class Net(Cell):
        """Pairs the tuple (x+x, y+y) with the raw input x, then duplicates it."""

        def __init__(self):
            super().__init__()
            # Only Add is needed; construct never performs a subtraction.
            self.add = P.Add()

        def construct(self, x, y):
            xx = self.add(x, x)
            yy = self.add(y, y)
            ret = ((xx, yy), x)
            return (ret, ret)

    net = Net()
    x = Tensor(np.ones([2], np.int32))
    y = Tensor(np.zeros([3], np.int32))
    net(x, y)
167
168
def test_tuple_tuple_2():
    """Run a cell that passes a doubly nested tuple of tensors through depend.

    The result is only produced, not compared against expected values.
    """

    class Net(Cell):
        """Builds ((x+x, y+y), (x-x, y-y)), duplicates it and feeds it to depend with relu(x)."""

        def __init__(self):
            super().__init__()
            self.depend = depend
            self.relu = P.ReLU()
            self.add = P.Add()
            self.sub = P.Sub()

        def construct(self, x, y):
            sums = (self.add(x, x), self.add(y, y))
            diffs = (self.sub(x, x), self.sub(y, y))
            activated = self.relu(x)
            nested = (sums, diffs)
            return self.depend((nested, nested), activated)

    ones = Tensor(np.ones([2], dtype=np.int32))
    zeros = Tensor(np.zeros([3], dtype=np.int32))
    Net()(ones, zeros)
193
194
def test_tuple_tuple_3():
    """Run a cell that passes a mixed nested tuple through depend.

    The cell builds (((x+x, y+y), x), ((x+x, y+y), x)) and hands it to
    ``depend`` together with relu(x). The result is only produced, not
    compared against expected values.
    """

    class Net(Cell):
        """Pairs (x+x, y+y) with x, duplicates it and feeds it to depend with relu(x)."""

        def __init__(self):
            super().__init__()
            # Only Add and ReLU are needed; construct never performs a subtraction.
            self.add = P.Add()
            self.relu = P.ReLU()
            self.depend = depend

        def construct(self, x, y):
            xx = self.add(x, x)
            yy = self.add(y, y)
            z = self.relu(x)
            ret = ((xx, yy), x)
            ret = (ret, ret)
            return self.depend(ret, z)

    net = Net()
    x = Tensor(np.ones([2], np.int32))
    y = Tensor(np.zeros([3], np.int32))
    net(x, y)
217
218
def test_soft():
    """Run a cell that nests an operator result, an input and a constant tensor tuple.

    The result is only produced, not compared against expected values.
    """

    class SoftmaxCrossEntropyWithLogitsNet(Cell):
        """Returns ((soft(x + y, x - y), z), value) where value is a stored tensor pair."""

        def __init__(self):
            super().__init__()
            self.soft = P.SoftmaxCrossEntropyWithLogits()
            zeros = Tensor(np.zeros((2, 2)).astype(np.float32))
            ones = Tensor(np.ones((2, 2)).astype(np.float32))
            self.value = (zeros, ones)

        def construct(self, x, y, z):
            summed = x + y
            diff = x - y
            soft_out = self.soft(summed, diff)
            return ((soft_out, z), self.value)

    first = Tensor(np.zeros((2, 2)).astype(np.float32))
    second = Tensor(np.ones((2, 2)).astype(np.float32))
    third = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
    SoftmaxCrossEntropyWithLogitsNet()(first, second, third)
239
240
def test_const_depend():
    """Run a cell that passes a constant tensor tuple through depend.

    The result is only produced, not compared against expected values.
    """

    class ConstDepend(Cell):
        """Returns (depend(value, (x + y) * z), soft(x, y)) with value a stored tensor pair."""

        def __init__(self):
            super().__init__()
            zeros = Tensor(np.zeros((2, 3)).astype(np.float32))
            ones = Tensor(np.ones((2, 3)).astype(np.float32))
            self.value = (zeros, ones)
            self.soft = P.SoftmaxCrossEntropyWithLogits()
            self.depend = depend

        def construct(self, x, y, z):
            scaled = (x + y) * z
            guarded = self.depend(self.value, scaled)
            return (guarded, self.soft(x, y))

    first = Tensor(np.zeros((2, 2)).astype(np.float32))
    second = Tensor(np.ones((2, 2)).astype(np.float32))
    third = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
    ConstDepend()(first, second, third)
261