# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.ops.composite as C
from mindspore import Tensor, Parameter
from mindspore import nn
from mindspore.nn import Cell
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)

grad_all = C.GradOperation(get_all=True)
grad_all_with_sens = C.GradOperation(sens_param=True)


def test_parser_three_default_mixed_args_subnet():
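    """
    Feature: Default arguments in construct
    Description: Call a sub-net whose construct mixes a positional argument with int, None and tuple defaults
    Expectation: no error, the default branch returns the input tensor unchanged
    """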
    class SubNetDefaultMixedArgs(Cell):
        def __init__(self):
            super().__init__()

        def construct(self, y, x=3, x1=None, x2=(1, 2)):
            if x == 3:
                if x1 is None:
                    return y
            return -y

    class NetOut(Cell):
        def __init__(self):
            super(NetOut, self).__init__()
            self.net_inside = SubNetDefaultMixedArgs()

        def construct(self, x, y=3):
            z = self.net_inside(x)

            return z

    tensor1 = Tensor(np.full((2, 3), 2).astype(np.float32))
    tensor2 = Tensor(np.full((3, 2), 4).astype(np.float32))
    net = NetOut()
    assert np.all(net(tensor1, tensor2).asnumpy() == tensor1.asnumpy())


# pylint: disable=keyword-arg-before-vararg
def test_net_vararg_kwonlyarg_kwarg():
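    """
    Feature: Net's inputs
    Description: Pass positional, *var, keyword-only and **kwargs arguments between nested cells with constant inputs
    Expectation: no error
    """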
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.net = SecondNet()

        def construct(self, x=1, z=2 + 2 + 4, y=3):
            c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
            a = x - y
            b = p * q
            c = a / b
            d = var[0] * var[1] * var[2] * var[3]
            e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
            return a + b + c + d + e

    net = FirstNet()
    net()


# pylint: disable=keyword-arg-before-vararg
def test_net_vararg_normal_input():
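    """
    Feature: Net's inputs
    Description: Pass tensor inputs through a net whose sub-net takes *var, keyword-only and **kwargs arguments
    Expectation: no error
    """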
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.net = SecondNet()

        def construct(self, x=1, z=2 + 2 + 4, y=3):
            c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
            a = x - y
            b = p * q
            c = a / b
            d = var[0] * var[1] * var[2] * var[3]
            e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
            return a + b + c + d + e

    x = Tensor(np.ones((2, 3, 4), np.int32))
    net = FirstNet()
    net(x, x, x)


def test_prim_vararg_kwonlyarg():
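    """
    Feature: Net's inputs
    Description: Unpack *args into primitive calls together with keyword-only arguments
    Expectation: no error
    """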
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            self.net = SecondNet()
            self.x = Tensor(np.ones((2, 3, 4), np.float32))
            self.y = Tensor(np.ones((2, 3, 4), np.float32))

        def construct(self):
            a = self.max(self.x, self.y)
            b = self.min(self.x, self.y)
            t = {"x": a, "y": b}
            c = self.net(t["x"], t["y"], a, b, z=a, r=b)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.addN = P.AddN()
            self.max = P.Maximum()
            self.add = P.Add()

        def construct(self, x, y, *args, z=0, r=1):
            c = self.max(args[0], args[1])
            d = self.addN(args)
            e = self.max(*args)
            ret = x + y + c + d + e + z + r
            return ret

    net = FirstNet()
    net()


def test_no_vararg():
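    """
    Feature: Net's inputs
    Description: Call a sub-net that takes keyword-only arguments without *args
    Expectation: no error
    """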
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            self.net = SecondNet()
            self.x = Tensor(np.ones((2, 3, 4), np.float32))
            self.y = Tensor(np.ones((2, 3, 4), np.float32))

        def construct(self):
            t = {"x": self.x, "y": self.y}
            a = self.max(self.x, self.y)
            b = self.min(self.x, self.y)
            c = self.net(a, b, z=a, r=b)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y, *, z=0, r=1):
            ret = x + y + z + r
            return ret

    net = FirstNet()
    net()


def test_net_variable_and_weights():
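    """
    Feature: Net's inputs
    Description: Mix *args inputs with Parameter weights in nested cells
    Expectation: no error
    """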
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            self.net = SecondNet()
            self.x = Tensor(np.ones((3, 4), np.float32))
            self.y = Tensor(np.ones((3, 4), np.float32))
            self.weight = Parameter(Tensor(np.ones((2, 3, 4)).astype(np.float32)), "w1", requires_grad=True)

        def construct(self, *args):
            t = (self.x, self.y)
            a = self.max(self.x, self.weight)
            b = self.min(self.weight, args[0])
            c = self.net(a, b, *t)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.addN = P.AddN()
            self.max = P.Maximum()
            self.add = P.Add()
            self.weight = Parameter(Tensor(np.ones((2, 3, 4), np.float32)), "w2", requires_grad=True)

        def construct(self, a, b, *args):
            c = self.max(args[0], a)
            d = self.addN(args)
            ret = a + b + c + d + self.weight
            return ret

    net = FirstNet()
    x = Tensor(np.ones((4,), np.float32))
    y = Tensor(np.ones((4,), np.float32))
    z = Tensor(np.ones((4,), np.float32))
    net(x, y, z)


def test_net_vargs_expand():
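    """
    Feature: Net's inputs
    Description: Expand *inputs into a gradient computation that takes a sens parameter
    Expectation: no error
    """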
    class InputBackward(Cell):
        """ InputBackward definition """

        def __init__(self, network, c1=None, c2=None):
            super(InputBackward, self).__init__()
            self.network = network
            self.network.set_train()
            self.grad = grad_all_with_sens
            self.c1 = c1
            self.c2 = c2

        def construct(self, *inputs):
            return self.grad(self.network)(*inputs)

    class AddNet(Cell):
        def __init__(self):
            super(AddNet, self).__init__()

        def construct(self, x, y):
            return x + y

    net = InputBackward(AddNet())
    x = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    y = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))

    net.set_train()
    net(x, y, sens)


def test_mixed_precision_const_parameter():
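    """
    Feature: Net's inputs
    Description: Pass shape values as *args into a loss net running with fp32 flags
    Expectation: no error
    """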
    class NetLoss(Cell):
        def __init__(self):
            super(NetLoss, self).__init__()
            self.shape = P.Shape()
            self.up_sample1 = P.ResizeBilinear((14, 14))
            self.up_sample2 = P.ResizeBilinear((28, 28))
            self.up_sample3 = P.ResizeBilinear((36, 36))

        def construct(self, x, y, z, *args):
            ret = 0
            if args[0] == self.shape(z)[2]:
                if args[0] == 14:
                    ret = self.up_sample1(y) + x
                elif args[0] == 28:
                    ret = self.up_sample2(y) - x
                else:
                    ret = x / y
            else:
                ret = x * y
            ret = ret * z
            return ret

    class NetMain(Cell):
        def __init__(self, loss_fn):
            super(NetMain, self).__init__()
            self.loss_fn = loss_fn
            self.shape = P.Shape()

        def construct(self, x, y, z):
            size_x = self.shape(x)[2]
            size_y = self.shape(y)[2]
            ret = self.loss_fn(x, y, z, size_x, size_y)
            return ret

    loss_fn = NetLoss()
    net = NetMain(loss_fn)
    net.add_flags_recursive(fp32=True)
    x = Tensor(np.ones((1, 3, 28, 28), np.float32))
    y = Tensor(np.ones((1, 3, 14, 14), np.float32))
    z = Tensor(np.ones((1, 3, 28, 28), np.float32))
    _ = net(x, y, z)


def test_pass_args_by_keyword_way():
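    """
    Feature: Net's inputs
    Description: Pass net and grad net inputs by keyword
    Expectation: no error
    """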
    class KeywordNet(Cell):
        def __init__(self):
            super(KeywordNet, self).__init__()

        def construct(self, x, y, z):
            return x + y - z

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.grad = C.GradOperation(get_all=True, sens_param=True)
            self.net = net
            self.sens = Tensor(np.ones((3, 3, 4), np.float32))

        def construct(self, x, y, z, sens):
            return self.grad(self.net)(x, y, z, sens)

    x = Tensor(np.ones((1, 3, 4), np.float32))
    y = Tensor(np.ones((1, 3, 4), np.float32))
    z = Tensor(np.ones((3, 3, 4), np.float32))
    net = KeywordNet()
    net(x, z=z, y=y)
    grad_net = GradNet(net)
    sens = Tensor(np.ones((3, 3, 4), np.float32))
    grad_net(x, y=y, z=z, sens=sens)


def test_none_input():
319    """
320    Feature: Net's inputs
321    Description: Support none input for the outermost net
322    Expectation: no error
323    """
324
325    class Net(Cell):
326        def __init__(self):
327            super(Net, self).__init__()
328            self.op = nn.ResizeBilinear()
329
330        def construct(self, a, b, c, d):
331            return self.op(a, b, c, d)
332
333    x = Tensor(np.array([1, 2, 3, 4]).astype(np.float32).reshape((1, 1, 2, 2,)))
334    net = Net()
335    net(x, (4, 4), None, True)
336