# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


# Gradient operation that returns the gradients w.r.t. all network inputs.
grad_all = C.GradOperation(get_all=True)


class NetWithLossNoBias(nn.Cell):
    """Wrap a two-input network with a VirtualLoss head for compile tests."""

    def __init__(self, network):
        super(NetWithLossNoBias, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        # Forward through the wrapped network, then apply the virtual loss.
        predict = self.network(x, y)
        return self.loss(predict)


class NetWithLoss(nn.Cell):
    """Wrap a three-input network (x, y, bias b) with a VirtualLoss head."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        # Forward through the wrapped network, then apply the virtual loss.
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrapNoBias(nn.Cell):
    """Compute gradients of a two-input network w.r.t. both inputs."""

    def __init__(self, network):
        super(GradWrapNoBias, self).__init__()
        self.network = network

    def construct(self, x, y):
        return grad_all(self.network)(x, y)


class GradWrap(nn.Cell):
    """Compute gradients of a three-input network w.r.t. all inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


def compile_net_no_bias(net, x, y):
    """Compile a two-input net in training mode under auto-parallel."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)


def compile_net(net, x, y, b):
    """Compile a three-input net in training mode under auto-parallel."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)


# model_parallel test
def test_sum_mul():
    """Compile mul-reducesum(axis=(1,))-mul net, semi_auto_parallel, 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, (1,))
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 1, 8), (1, 1, 8))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_sum_mul2():
    """Compile mul-reducesum(axis=(0, 1))-mul net on 4-D inputs, 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, (0, 1))
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 1, 4, 2), (1, 1, 4, 2))
    strategy2 = ((2, 4, 1, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_sum_mul3():
    """Compile mul-reducesum(axis=-1)-mul net, semi_auto_parallel, 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, -1)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 2, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_sum_mul4():
    """Compile mul-reducesum(axis=-1, keep_dims=True)-mul net, 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, -1)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((2, 2, 2),)
    strategy3 = ((4, 2, 1), (4, 2, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    # keep_dims=True keeps the reduced axis, hence the trailing 1 in b's shape.
    b = Tensor(np.ones([128, 32, 1]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_sum_mul5():
    """Compile mul-reducesum(axis=0, keep_dims=True) net, 64 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, 0)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 8), (1, 8, 8))
    strategy2 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_sum_mul6():
    """Compile mul-reducesum(axis=1, keep_dims=True) net, 64 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, 1)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 8), (1, 8, 8))
    strategy2 = ((2, 1, 4),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_sum_mul7():
    """Compile mul-reducesum(axis=(0, 1), keep_dims=True) net, 64 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, (0, 1))
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 8), (1, 8, 8))
    strategy2 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_max_mul():
    """Compile mul-reducemax(axis=-1)-mul net, semi_auto_parallel, 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_max = P.ReduceMax(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_max(out, -1)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_min_mul():
    """Compile mul-reducemin(axis=0)-mul net, semi_auto_parallel, 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_min = P.ReduceMin(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_min(out, 0)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_reduce_mean_mul_float32():
    """Compile mul-reducemean(axis=0)-mul net on float32 inputs, 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_mean = P.ReduceMean(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            out = self.reduce_mean(out, 0)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)

    compile_net(net, x, y, b)


class ArgMaxWithValueNet(nn.Cell):
    """mul -> ArgMaxWithValue(axis=-1) -> mul net; keeps the max value only."""

    def __init__(self, strategy1, strategy2, strategy3):
        super().__init__()
        self.mul1 = P.Mul().shard(strategy1)
        self.arg_max_with_value = P.ArgMaxWithValue(keep_dims=False, axis=-1).shard(strategy2)
        self.mul2 = P.Mul().shard(strategy3)

    def construct(self, x, y, b):
        out = self.mul1(x, y)
        # ArgMaxWithValue returns (index, value); only the value is used.
        _, out = self.arg_max_with_value(out)
        out = self.mul2(out, b)
        return out


class ArgMinWithValueNet(nn.Cell):
    """mul -> ArgMinWithValue(axis=-1) -> mul net; keeps the min value only."""

    def __init__(self, strategy1, strategy2, strategy3):
        super().__init__()
        self.mul1 = P.Mul().shard(strategy1)
        self.arg_min_with_value = P.ArgMinWithValue(keep_dims=False, axis=-1).shard(strategy2)
        self.mul2 = P.Mul().shard(strategy3)

    def construct(self, x, y, b):
        out = self.mul1(x, y)
        # ArgMinWithValue returns (index, value); only the value is used.
        _, out = self.arg_min_with_value(out)
        out = self.mul2(out, b)
        return out


def gen_inputs_and_compile_net(net):
    """Build fixed three-tensor inputs and compile `net` with them."""
    x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def gen_inputs_and_compile_net_no_bias(net):
    """Build fixed two-tensor inputs and compile `net` with them."""
    x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


# NOTE: prefixed "tobefixed_" so pytest does not collect it until repaired.
def tobefixed_test_arg_max_with_value_mul_semi_axis_parallel():
    """ArgMaxWithValue with the reduced axis sharded (strategy2 last dim = 2)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)


def test_arg_max_with_value_mul_semi():
    """ArgMaxWithValue net with unsharded reduce axis, semi_auto_parallel."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)


def test_arg_max_with_value_mul_auto():
    """ArgMaxWithValue net with strategies left to the auto_parallel searcher."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = None
    strategy2 = None
    strategy3 = None
    net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net(net)


def test_arg_min_with_value_mul_semi_axis_parallel():
    """ArgMinWithValue net with the reduced axis sharded, semi_auto_parallel."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)


def test_arg_min_with_value_mul_semi():
    """ArgMinWithValue net with unsharded reduce axis, semi_auto_parallel."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)


def test_arg_min_with_value_mul_auto():
    """ArgMinWithValue net with strategies left to the auto_parallel searcher."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = None
    strategy2 = None
    strategy3 = None
    net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net(net)


class ArgMinWithValueNet2(nn.Cell):
    """mul -> ArgMinWithValue(axis=-1, keep_dims=True) -> ReLU net (two inputs)."""

    def __init__(self, strategy1, strategy2, strategy3):
        super().__init__()
        self.mul1 = P.Mul().shard(strategy1)
        self.arg_min_with_value = P.ArgMinWithValue(keep_dims=True, axis=-1).shard(strategy2)
        self.relu = P.ReLU().shard(strategy3)

    def construct(self, x, y):
        out = self.mul1(x, y)
        # ArgMinWithValue returns (index, value); only the value is used.
        _, out = self.arg_min_with_value(out)
        out = self.relu(out)
        return out


# NOTE: prefixed "tobefixed_" so pytest does not collect it until repaired.
def tobefixed_test_arg_min_with_value_mul_semi_axis_parallel2():
    """ArgMinWithValueNet2 with the reduced axis sharded (strategy2 last dim = 2)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)


def test_arg_min_with_value_mul_semi2():
    """ArgMinWithValueNet2 with unsharded reduce axis, semi_auto_parallel."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 1),)
    strategy3 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)


def test_arg_min_with_value_mul_auto2():
    """ArgMinWithValueNet2 with strategies left to the auto_parallel searcher."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = None
    strategy2 = None
    strategy3 = None
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)


def test_cross_batch():
    """ReduceMean carrying the "cross_batch" prim attr after ReduceSum, 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.reduce_mean = P.ReduceMean(keep_dims=False).shard(strategy3).add_prim_attr("cross_batch", True)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_sum(out, -1)
            out = self.reduce_mean(out, 0)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((4, 2), (4, 2))
    strategy2 = ((2, 1),)
    strategy3 = ((8,),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_cross_batch2():
    """ReduceSum carrying the "cross_batch" prim attr after ReduceMean, 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_mean = P.ReduceMean(keep_dims=False).shard(strategy2)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy3).add_prim_attr("cross_batch", True)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_mean(out, -1)
            out = self.reduce_sum(out, 0)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((4, 2), (4, 2))
    strategy2 = ((2, 1),)
    strategy3 = ((8,),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_cross_batch_auto():
    """"cross_batch" prim attr on ReduceSum with strategies chosen by auto_parallel."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul1 = P.Mul()
            self.reduce_mean = P.ReduceMean(keep_dims=False)
            self.reduce_sum = P.ReduceSum(keep_dims=False).add_prim_attr("cross_batch", True)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_mean(out, -1)
            out = self.reduce_sum(out, 0)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = GradWrapNoBias(NetWithLossNoBias(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")

    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)


def test_max_empty_tuple():
    """ReduceMax called with no axis argument (reduce all dims), then Add, 8 devices."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.mul = P.Mul().shard(strategy1)
            self.reduce_max = P.ReduceMax(keep_dims=False).shard(strategy2)
            self.add = P.Add().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul(x, y)
            # No axis: ReduceMax reduces over all dimensions.
            out = self.reduce_max(out)
            out = self.add(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    # Empty tuple strategy for the scalar left operand of Add.
    strategy3 = ((), (1, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)

    compile_net(net, x, y, b)


def test_any_mul():
    """ReduceAny(axis=1) with a strategy sharding the reduced axis: compile must raise."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy2)
            self.cast = P.Cast()

        def construct(self, x, y):
            out = self.mul1(x, y)
            # ReduceAny requires a boolean input.
            out = self.cast(out, ms.bool_)
            out = self.reduce_any(out, 1)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 1), (1, 8, 1))
    strategy2 = ((1, 8, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    # Sharding the reduced axis of ReduceAny is rejected at compile time.
    with pytest.raises(RuntimeError):
        compile_net_no_bias(net, x, y)


def test_any_mul2():
    """ReduceAny(axis=-1) with the reduced axis unsharded: compile succeeds."""
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy2)
            self.cast = P.Cast()

        def construct(self, x, y):
            out = self.mul1(x, y)
            # ReduceAny requires a boolean input.
            out = self.cast(out, ms.bool_)
            out = self.reduce_any(out, -1)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((8, 1, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)