1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16
17import mindspore as ms
18import mindspore.nn as nn
19from mindspore import Parameter, Tensor, context
20from mindspore.common.api import _cell_graph_executor
21from mindspore.ops import composite as C
22from mindspore.ops import operations as P
23from tests.ut.python.ops.test_math_ops import VirtualLoss
24
25
26grad_all = C.GradOperation(get_all=True)
27
28
class NetWithLoss(nn.Cell):
    """Wrap a three-input network and feed its prediction into VirtualLoss."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x, y, b):
        # Forward through the wrapped net, then reduce with the virtual loss.
        return self.loss(self.network(x, y, b))
38
39
class GradWrap(nn.Cell):
    """Compute gradients of the wrapped network w.r.t. all three inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        # grad_all returns the gradient for every positional input.
        gradient_fn = grad_all(self.network)
        return gradient_fn(x, y, b)
47
48
def compile_net(net, x, y, b):
    """Put ``net`` into auto-parallel training mode and compile its graph."""
    # Order matters: parallel/train flags must be set before compilation.
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)
53
54
def test_matmul_sub():
    """Compile MatMul followed by Sub with mismatched shard strategies."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.sub = P.Sub().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.sub(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
77
78
def test_matmul_add():
    """Compile MatMul followed by Add with mismatched shard strategies."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.add = P.Add().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.add(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
101
102
def test_matmul_mul():
    """Compile MatMul followed by Mul with mismatched shard strategies."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.mul = P.Mul().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.mul(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
125
def test_matmul_mod():
    """Compile MatMul followed by Mod with mismatched shard strategies."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.mod = P.Mod().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.mod(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
148
def test_matmul_floormod():
    """Compile MatMul followed by FloorMod with mismatched shard strategies."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.floormod = P.FloorMod().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.floormod(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
171
172
def test_matmul_atan2():
    """Compile MatMul followed by Atan2 with mismatched shard strategies."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.atan2 = P.Atan2().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.atan2(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
195
196
def test_matmul_divNoNan():
    """Compile MatMul followed by DivNoNan with mismatched shard strategies."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.divNoNan = P.DivNoNan().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.divNoNan(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
219
220
def test_matmul_logicaland():
    """Compile Equal/NotEqual results combined with LogicalAnd under sharding."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.equal = P.Equal().shard(elem_strategy)
            self.notequal = P.NotEqual().shard(elem_strategy)
            self.logical = P.LogicalAnd().shard(elem_strategy)

        def construct(self, x, y, b):
            # The matmul is evaluated once per comparison, as in the graph
            # this test was written to exercise.
            out1 = self.equal(self.matmul(x, y), b)
            out2 = self.notequal(self.matmul(x, y), b)
            return self.logical(out1, out2)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
248
249
def test_matmul_logicalor():
    """Compile Equal/NotEqual results combined with LogicalOr under sharding."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.equal = P.Equal().shard(elem_strategy)
            self.notequal = P.NotEqual().shard(elem_strategy)
            self.logical = P.LogicalOr().shard(elem_strategy)

        def construct(self, x, y, b):
            # The matmul is evaluated once per comparison, as in the graph
            # this test was written to exercise.
            out1 = self.equal(self.matmul(x, y), b)
            out2 = self.notequal(self.matmul(x, y), b)
            return self.logical(out1, out2)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
277
278
def test_matmul_div():
    """Compile MatMul followed by Div with mismatched shard strategies."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.div = P.Div().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.div(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
301
302
def test_matmul_add_broadcast():
    """Compile MatMul + Add where the second Add input broadcasts from 1-D."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.add = P.Add().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.add(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    # Elementwise strategy for b has rank 1 to match the broadcast operand.
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (2,)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)
325
326
def test_matmul_add_broadcast2():
    """Compile MatMul + Add where both Add operands broadcast (Nx1 vs 1xM)."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.add = P.Add().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.add(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 4), (4, 1)), ((4, 1), (1, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
349
350
def test_matmul_sub_broadcast():
    """Compile MatMul + Sub where the second Sub input broadcasts from 1-D."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.sub = P.Sub().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.sub(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    # Elementwise strategy for b has rank 1 to match the broadcast operand.
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (2,)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)
373
374
def test_matmul_sub_broadcast2():
    """Compile MatMul + Sub where both Sub operands broadcast (Nx1 vs 1xM)."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.sub = P.Sub().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.sub(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 4), (4, 1)), ((4, 1), (1, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
397
398
def test_matmul_mul_broadcast():
    """Compile MatMul + Mul where the second Mul input broadcasts from 1-D."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.mul = P.Mul().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.mul(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    # Elementwise strategy for b has rank 1 to match the broadcast operand.
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (2,)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)
421
422
def test_matmul_mul_broadcast2():
    """Compile MatMul + Mul where both Mul operands broadcast (Nx1 vs 1xM)."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.mul = P.Mul().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.mul(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 4), (4, 1)), ((4, 1), (1, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
445
446
def test_matmul_div_broadcast():
    """Compile MatMul + Div where the second Div input broadcasts from 1-D."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.div = P.Div().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.div(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    # Elementwise strategy for b has rank 1 to match the broadcast operand.
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (2,)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)
469
470
def test_matmul_div_broadcast2():
    """Compile MatMul + Div where both Div operands broadcast (Nx1 vs 1xM)."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.div = P.Div().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.div(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 4), (4, 1)), ((4, 1), (1, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
493
494
def test_matmul_greater_broadcast():
    """Compile MatMul + Greater where the comparison operand broadcasts from 1-D."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.greater = P.Greater().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.greater(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    # Elementwise strategy for b has rank 1 to match the broadcast operand.
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (2,)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)
517
518
def test_matmul_greater_broadcast2():
    """Compile MatMul + Greater where both operands broadcast (Nx1 vs 1xM)."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.greater = P.Greater().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.greater(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 4), (4, 1)), ((4, 1), (1, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
541
542
def test_matmul_floordiv():
    """Compile MatMul followed by FloorDiv with mismatched shard strategies."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.floordiv = P.FloorDiv().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.floordiv(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (4, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
565
566
def test_matmul_floordiv_broadcast():
    """Compile MatMul + FloorDiv where the divisor broadcasts from 1-D."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.floordiv = P.FloorDiv().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.floordiv(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    # Elementwise strategy for b has rank 1 to match the broadcast operand.
    net = GradWrap(NetWithLoss(Net(((2, 2), (2, 2)), ((4, 2), (2,)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.float32)
    compile_net(net, x, y, b)
589
590
def test_matmul_floordiv_broadcast2():
    """Compile MatMul + FloorDiv where both operands broadcast (Nx1 vs 1xM)."""
    class Net(nn.Cell):
        def __init__(self, mm_strategy, elem_strategy):
            super().__init__()
            self.matmul = P.MatMul().shard(mm_strategy)
            self.floordiv = P.FloorDiv().shard(elem_strategy)

        def construct(self, x, y, b):
            return self.floordiv(self.matmul(x, y), b)

    context.set_auto_parallel_context(
        device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = GradWrap(NetWithLoss(Net(((2, 4), (4, 1)), ((4, 1), (1, 2)))))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
613
614
def test_assign_sub():
    """Compile Mul feeding AssignSub on a parameter under semi-auto parallel."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.assign_sub = P.AssignSub()
            self.mul = P.Mul()
            self.mul_weight = Parameter(
                Tensor(np.full([128, 32], 0.5, dtype=np.float32)),
                name="mul_weight")
            self.assignsub_weight = Parameter(
                Tensor(np.full([128, 32], 1.1, dtype=np.float32)),
                name="assignsub_weight")

        def construct(self, x):
            # Scale the input, then subtract it in place from the parameter.
            scaled = self.mul(x, self.mul_weight)
            return self.assign_sub(self.assignsub_weight, scaled)

    class SubNetWithLoss(nn.Cell):
        """Single-input variant of NetWithLoss."""

        def __init__(self, network):
            super(SubNetWithLoss, self).__init__()
            self.network = network
            self.loss = VirtualLoss()

        def construct(self, x):
            return self.loss(self.network(x,))

    class SubGradWrap(nn.Cell):
        """Single-input variant of GradWrap."""

        def __init__(self, network):
            super(SubGradWrap, self).__init__()
            self.network = network

        def construct(self, x):
            return grad_all(self.network)(x)

    def compile_sub_net(net, x):
        net.set_auto_parallel()
        net.set_train()
        _cell_graph_executor.compile(net, x)

    context.set_auto_parallel_context(
        device_num=64, global_rank=15, parallel_mode="semi_auto_parallel")
    net = SubGradWrap(SubNetWithLoss(Net()))
    compile_sub_net(net, Tensor(np.ones([128, 32]), dtype=ms.float32))
661
662
def test_assign_add():
    """Compile Mul feeding AssignAdd on a parameter under semi-auto parallel.

    Fixes copy-paste naming from test_assign_sub: the attribute holding
    P.AssignAdd() and its target parameter are now named accordingly.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.assign_add = P.AssignAdd()
            self.mul = P.Mul()
            self.mul_weight = Parameter(
                Tensor(np.full([128, 32], 0.5, dtype=np.float32)),
                name="mul_weight")
            self.assignadd_weight = Parameter(
                Tensor(np.full([128, 32], 1.1, dtype=np.float32)),
                name="assignadd_weight")

        def construct(self, x):
            # Scale the input, then add it in place to the parameter.
            out = self.mul(x, self.mul_weight)
            return self.assign_add(self.assignadd_weight, out)

    class SubNetWithLoss(nn.Cell):
        """Single-input variant of NetWithLoss."""

        def __init__(self, network):
            super(SubNetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x):
            return self.loss(self.network(x,))

    class SubGradWrap(nn.Cell):
        """Single-input variant of GradWrap."""

        def __init__(self, network):
            super(SubGradWrap, self).__init__()
            self.network = network

        def construct(self, x):
            return grad_all(self.network)(x)

    def compile_sub_net(net, x):
        net.set_auto_parallel()
        net.set_train()
        _cell_graph_executor.compile(net, x)

    context.set_auto_parallel_context(
        device_num=64, global_rank=15, parallel_mode="semi_auto_parallel")
    net = SubGradWrap(SubNetWithLoss(Net()))
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_sub_net(net, x)
709
710
def test_assign():
    """Compile Mul feeding Assign on a parameter under semi-auto parallel.

    Fixes copy-paste naming from test_assign_sub: the attribute holding
    P.Assign() and its target parameter are now named accordingly.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.assign = P.Assign()
            self.mul = P.Mul()
            self.mul_weight = Parameter(
                Tensor(np.full([128, 32], 0.5, dtype=np.float32)),
                name="mul_weight")
            self.assign_weight = Parameter(
                Tensor(np.full([128, 32], 1.1, dtype=np.float32)),
                name="assign_weight")

        def construct(self, x):
            # Scale the input, then overwrite the parameter with the result.
            out = self.mul(x, self.mul_weight)
            return self.assign(self.assign_weight, out)

    class SubNetWithLoss(nn.Cell):
        """Single-input variant of NetWithLoss."""

        def __init__(self, network):
            super(SubNetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x):
            return self.loss(self.network(x,))

    class SubGradWrap(nn.Cell):
        """Single-input variant of GradWrap."""

        def __init__(self, network):
            super(SubGradWrap, self).__init__()
            self.network = network

        def construct(self, x):
            return grad_all(self.network)(x)

    def compile_sub_net(net, x):
        net.set_auto_parallel()
        net.set_train()
        _cell_graph_executor.compile(net, x)

    context.set_auto_parallel_context(
        device_num=64, global_rank=15, parallel_mode="semi_auto_parallel")
    net = SubGradWrap(SubNetWithLoss(Net()))
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_sub_net(net, x)
757