• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16import mindspore as ms
17import mindspore.nn as nn
18from mindspore import Tensor
19from mindspore import context
20from mindspore.common.api import _cell_graph_executor
21from mindspore.ops import composite as C
22from mindspore.ops import operations as P
23from tests.ut.python.ops.test_math_ops import VirtualLoss
24
25
# Module-level gradient transform shared by GradWrap. get_all=True asks
# GradOperation for the gradients with respect to all inputs of the wrapped
# network (here both x and y).
grad_all = C.GradOperation(get_all=True)
27
28
class NetWithLoss(nn.Cell):
    """Wrap ``network`` so that its output is passed through ``VirtualLoss``.

    Args:
        network: Cell that is called as ``network(x, y)``.
    """

    def __init__(self, network):
        # Zero-argument super(), consistent with ``Net`` in this module.
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    # Run the wrapped network on (x, y) and feed its prediction to the loss cell.
    def construct(self, x, y):
        predict = self.network(x, y)
        return self.loss(predict)
38
39
class GradWrap(nn.Cell):
    """Apply the module-level ``grad_all`` transform to ``network``.

    Args:
        network: Cell that is called as ``network(x, y)``.
    """

    def __init__(self, network):
        # Zero-argument super(), consistent with ``Net`` in this module.
        super().__init__()
        self.network = network

    # Build the gradient function of the wrapped network and evaluate it on (x, y).
    def construct(self, x, y):
        return grad_all(self.network)(x, y)
47
48
class Net(nn.Cell):
    """Gather from ``x`` with a constant all-ones int32 index, then multiply by ``y``.

    Args:
        axis: Axis passed to the Gather primitive.
        strategy1: Shard strategy applied to Gather (some tests pass None).
        strategy2: Shard strategy applied to Mul.
        shape: Shape of the constant index tensor; [64, 64] when omitted.
        target: Value stored in Gather's "primitive_target" attribute.
    """

    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
        super().__init__()
        index_shape = [64, 64] if shape is None else shape
        sharded_gather = P.Gather().shard(strategy1)
        self.gatherv2 = sharded_gather.add_prim_attr("primitive_target", target)
        self.mul = P.Mul().shard(strategy2)
        self.index = Tensor(np.ones(index_shape), dtype=ms.int32)
        self.axis = axis

    # Computes Mul(Gather(x, index, axis), y).
    def construct(self, x, y):
        gathered = self.gatherv2(x, self.index, self.axis)
        product = self.mul(gathered, y)
        return product
63
64
def test_gatherv2_semi_auto0():
    """Semi-auto parallel, Gather on axis 0 with strategy ((1, 8), (1, 1)).

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 64] ``x`` and a [64, 64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    gather_strategy = ((1, 8), (1, 1))
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    input_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)

    model = Net(axis=0, strategy1=gather_strategy, strategy2=mul_strategy)
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
76
77
def test_gatherv2_semi_auto1():
    """Semi-auto parallel, Gather on axis 0 with strategy ((8, 1), (1, 1)).

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 64] ``x`` and a [64, 64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    gather_strategy = ((8, 1), (1, 1))
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    input_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)

    model = Net(axis=0, strategy1=gather_strategy, strategy2=mul_strategy)
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
89
90
def test_gatherv2_semi_auto2():
    """Semi-auto parallel, Gather on axis 0 with strategy ((2, 4), (1, 1)).

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 64] ``x`` and a [64, 64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    gather_strategy = ((2, 4), (1, 1))
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    input_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)

    model = Net(axis=0, strategy1=gather_strategy, strategy2=mul_strategy)
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
102
103
def test_gatherv2_semi_auto3():
    """Semi-auto parallel, Gather on axis 1 with strategy ((1, 8), (1, 1)).

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 64] ``x`` and a [64, 64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    gather_strategy = ((1, 8), (1, 1))
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    input_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)

    model = Net(axis=1, strategy1=gather_strategy, strategy2=mul_strategy)
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
115
116
def test_gatherv2_semi_auto4():
    """Semi-auto parallel, Gather on axis 1 with strategy ((8, 1), (1, 1)).

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 32] ``x`` and a [64, 64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    gather_strategy = ((8, 1), (1, 1))
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    input_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)

    model = Net(axis=1, strategy1=gather_strategy, strategy2=mul_strategy)
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
128
129
def test_gatherv2_semi_auto5():
    """Semi-auto parallel, Gather on axis 1 with strategy ((2, 4), (1, 1)).

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 32] ``x`` and a [64, 64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    gather_strategy = ((2, 4), (1, 1))
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    input_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)

    model = Net(axis=1, strategy1=gather_strategy, strategy2=mul_strategy)
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
141
142
def test_gatherv2_semi_auto6():
    """Semi-auto parallel, Gather on axis 0 with no Gather strategy given.

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 32] ``x`` and a [64, 64, 32] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    input_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)

    # strategy1 is left at its default (None).
    model = Net(axis=0, strategy2=mul_strategy)
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
153
154
def test_gatherv2_semi_auto7():
    """Semi-auto parallel, Gather on axis 1 with no Gather strategy given.

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 32] ``x`` and a [64, 64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    input_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)

    # strategy1 is left at its default (None).
    model = Net(axis=1, strategy2=mul_strategy)
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
165
166
def test_gatherv2_semi_auto8():
    """Semi-auto parallel, Gather on axis 0 over a 1-D ``x`` with strategy ((8,), (1, 1)).

    There is no assertion: the test passes when compiling the gradient graph
    for a [64] ``x`` and a [64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    gather_strategy = ((8,), (1, 1))
    mul_strategy = ((4, 2), (4, 2))
    input_x = Tensor(np.ones([64]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64]), dtype=ms.float32)

    model = Net(axis=0, strategy1=gather_strategy, strategy2=mul_strategy)
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
178
179
def test_gatherv2_forward_all_reduce():
    """Semi-auto parallel, Gather on axis 0, strategy ((8, 1), (1, 1)), [2, 64] index.

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 64] ``x`` and a [2, 64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    gather_strategy = ((8, 1), (1, 1))
    mul_strategy = ((2, 4, 1), (2, 4, 1))
    input_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    input_y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)

    model = Net(axis=0, strategy1=gather_strategy, strategy2=mul_strategy, shape=[2, 64])
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
191
192
def test_gatherv2_shard_batch_and_axis():
    """Semi-auto parallel, Gather on axis 0, strategy ((4, 1), (2, 1)), [2, 64] index.

    Unlike the other semi-auto cases here, the index strategy is not all ones.
    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 64] ``x`` and a [2, 64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    gather_strategy = ((4, 1), (2, 1))
    mul_strategy = ((2, 4, 1), (2, 4, 1))
    input_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    input_y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)

    model = Net(axis=0, strategy1=gather_strategy, strategy2=mul_strategy, shape=[2, 64])
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
204
205
def test_gatherv2_split_axis_0_repeat_calc():
    """Semi-auto parallel on global rank 7, Gather strategy ((4, 1), (1, 1)), [2, 64] index.

    The Gather strategy's slice counts multiply to 4 while device_num is 8;
    presumably this is the "repeat calc" situation named by the test.
    There is no assertion: the test passes when compilation raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=7, parallel_mode="semi_auto_parallel")
    gather_strategy = ((4, 1), (1, 1))
    mul_strategy = ((2, 4, 1), (2, 4, 1))
    input_x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    input_y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)

    model = Net(axis=0, strategy1=gather_strategy, strategy2=mul_strategy, shape=[2, 64])
    grad_net = GradWrap(NetWithLoss(model))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
217
218
def test_gatherv2_auto0():
    """Auto parallel (no explicit strategies), Gather on axis 0.

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 32] ``x`` and a [64, 64, 32] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
    input_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)

    grad_net = GradWrap(NetWithLoss(Net(axis=0)))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
227
228
def test_gatherv2_auto1():
    """Auto parallel (no explicit strategies), Gather on axis 1.

    There is no assertion: the test passes when compiling the gradient graph
    for a [64, 32] ``x`` and a [64, 64, 64] ``y`` raises nothing.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
    input_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    input_y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)

    grad_net = GradWrap(NetWithLoss(Net(axis=1)))
    grad_net.set_auto_parallel()
    grad_net.set_train()
    _cell_graph_executor.compile(grad_net, input_x, input_y)
237