• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16import pytest
17
18import mindspore as ms
19from mindspore import context, Tensor, Parameter
20from mindspore.nn import Cell, Momentum
21from mindspore.ops import operations as P
22from mindspore.train import Model
23from tests.dataset_mock import MindData
24
25
class Dataset(MindData):
    """Mock dataset that yields the same ``(predict, label)`` pair a fixed number of times."""

    def __init__(self, predict, label, length=3):
        """Store the sample pair and the number of items to yield before stopping."""
        super().__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        """Return the dataset itself; it is its own iterator."""
        return self

    def __next__(self):
        """Return the stored pair, or raise StopIteration once ``length`` items were served."""
        if self.index < self.length:
            self.index += 1
            return self.predict, self.label
        raise StopIteration

    def reset(self):
        """Rewind the iteration counter so the dataset can be consumed again."""
        self.index = 0
45
46
class Net(Cell):
    """Network computing ``ReLU(GatherNd(Mul(x, w1), indices))`` with per-operator shard strategies."""

    def __init__(self, w1_shape, indices_shape, strategy1=None, strategy2=None, strategy3=None):
        """Build the operators; strategy1/2/3 are applied to Mul, GatherNd and ReLU respectively."""
        super().__init__()
        weight_init = np.ones(w1_shape)
        indices_init = np.ones(indices_shape)
        self.mul = P.Mul().shard(strategy1)
        self.w1 = Parameter(Tensor(weight_init, dtype=ms.float32), "w1")
        self.indices = Tensor(indices_init, dtype=ms.int32)
        self.gathernd = P.GatherNd().shard(strategy2)
        self.relu = P.ReLU().shard(strategy3)

    def construct(self, x, b):
        """Run Mul -> GatherNd -> ReLU on ``x``; ``b`` is accepted but not used."""
        scaled = self.mul(x, self.w1)
        gathered = self.gathernd(scaled, self.indices)
        return self.relu(gathered)
61
62
class Net2(Cell):
    """Network computing ``GatherNd(Mul(x, w1), indices)``; GatherNd is the last operator."""

    def __init__(self, w1_shape, indices_shape, strategy1=None, strategy2=None, strategy3=None):
        """Build the operators; strategy1/2/3 are applied to Mul, GatherNd and ReLU respectively."""
        super().__init__()
        weight_init = np.ones(w1_shape)
        indices_init = np.ones(indices_shape)
        self.mul = P.Mul().shard(strategy1)
        self.w1 = Parameter(Tensor(weight_init, dtype=ms.float32), "w1")
        self.indices = Tensor(indices_init, dtype=ms.int32)
        self.gathernd = P.GatherNd().shard(strategy2)
        # The ReLU operator is created but construct() never calls it.
        self.relu = P.ReLU().shard(strategy3)

    def construct(self, x, b):
        """Run Mul -> GatherNd on ``x``; ``b`` is accepted but not used."""
        scaled = self.mul(x, self.w1)
        return self.gathernd(scaled, self.indices)
76
77
class Net3(Cell):
    """Network computing ``Mul(ReLU(GatherNd(x, indices)), w1)``; GatherNd is the first operator."""

    def __init__(self, w1_shape, indices_shape, strategy1=None, strategy2=None, strategy3=None):
        """Build the operators; strategy1/2/3 are applied to Mul, GatherNd and ReLU respectively."""
        super().__init__()
        weight_init = np.ones(w1_shape)
        indices_init = np.ones(indices_shape)
        self.mul = P.Mul().shard(strategy1)
        self.w1 = Parameter(Tensor(weight_init, dtype=ms.float32), "w1")
        self.indices = Tensor(indices_init, dtype=ms.int32)
        self.gathernd = P.GatherNd().shard(strategy2)
        self.relu = P.ReLU().shard(strategy3)

    def construct(self, x, b):
        """Run GatherNd -> ReLU -> Mul on ``x``; ``b`` is accepted but not used."""
        gathered = self.gathernd(x, self.indices)
        activated = self.relu(gathered)
        return self.mul(activated, self.w1)
92
93
# Shared all-ones input and label tensors fed to every test network.
# full_batch = false (presumably these are one device's slice of the global batch -- confirm).
_x = Tensor(np.ones((1, 16, 32)), dtype=ms.float32)
_b = Tensor(np.ones((1, 16, 32)), dtype=ms.float32)
97
98
def compile_net(net):
    """Train ``net`` for two epochs on the mock dataset, then reset the auto-parallel context.

    Args:
        net: the Cell to wrap in a ``Model`` with a Momentum optimizer.

    The auto-parallel context is reset even when training raises, so a failing
    (or intentionally failing) case cannot leak its parallel configuration into
    the tests that run afterwards. Any exception is propagated unchanged.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    try:
        dataset = Dataset(_x, _b)
        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, optimizer=opt)
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        # Always restore the default context, including on the error path.
        context.reset_auto_parallel_context()
108
109
def test_gathernd_data_parallel():
    """Compile Net in semi_auto_parallel with indices [8, 4, 2, 1] split 8-way on the first dim."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (8, 1, 1, 1))
    relu_strategy = ((8, 1, 1, 1, 1),)
    net = Net(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 1], strategy1=mul_strategy,
              strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
120
121
def test_gathernd_data_parallel2():
    """Compile Net in semi_auto_parallel with indices [8, 4, 2, 2] split 8-way on the first dim."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (8, 1, 1, 1))
    relu_strategy = ((8, 1, 1, 1),)
    net = Net(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 2], strategy1=mul_strategy,
              strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
132
133
def test_gathernd_data_parallel3():
    """Compile Net in semi_auto_parallel with indices [8, 4, 2, 3] split 8-way on the first dim."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (8, 1, 1, 1))
    relu_strategy = ((8, 1, 1),)
    net = Net(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 3], strategy1=mul_strategy,
              strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
144
145
def test_gathernd_data_parallel4():
    """Compile Net2 in semi_auto_parallel with indices [8, 4, 2, 1] split 8-way on the first dim."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (8, 1, 1, 1))
    relu_strategy = ((8, 1, 1, 1, 1),)
    net = Net2(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 1], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
156
157
def test_gathernd_data_parallel5():
    """Compile Net2 in semi_auto_parallel with indices [8, 4, 2, 2] split 8-way on the first dim."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (8, 1, 1, 1))
    relu_strategy = ((8, 1, 1, 1),)
    net = Net2(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 2], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
168
169
def test_gathernd_data_parallel6():
    """Compile Net2 in semi_auto_parallel with indices [8, 4, 2, 3] split 8-way on the first dim."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (8, 1, 1, 1))
    relu_strategy = ((8, 1, 1),)
    net = Net2(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 3], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
180
181
def test_gathernd_data_parallel7():
    """Compile Net3 in semi_auto_parallel with indices [8, 4, 2, 1] split 8-way on the first dim."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1, 1, 1), (8, 1, 1, 1, 1))
    gather_strategy = ((1, 1, 1), (8, 1, 1, 1))
    relu_strategy = ((8, 1, 1, 1, 1),)
    net = Net3(w1_shape=[8, 4, 2, 16, 32], indices_shape=[8, 4, 2, 1], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
192
193
def test_gathernd_data_parallel8():
    """Compile Net3 in semi_auto_parallel with indices [8, 4, 2, 2] split 8-way on the first dim."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1, 1), (8, 1, 1, 1))
    gather_strategy = ((1, 1, 1), (8, 1, 1, 1))
    relu_strategy = ((8, 1, 1, 1),)
    net = Net3(w1_shape=[8, 4, 2, 32], indices_shape=[8, 4, 2, 2], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
204
205
def test_gathernd_data_parallel9():
    """Compile Net3 in semi_auto_parallel with indices [8, 4, 2, 3] split 8-way on the first dim."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (8, 1, 1, 1))
    relu_strategy = ((8, 1, 1),)
    net = Net3(w1_shape=[8, 4, 2], indices_shape=[8, 4, 2, 3], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
216
217
def test_gathernd_model_parallel():
    """Compile Net in semi_auto_parallel with indices [8, 4, 2, 1] split 2x2x2 on its first three dims."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (2, 2, 2, 1))
    relu_strategy = ((8, 1, 1, 1, 1),)
    net = Net(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 1], strategy1=mul_strategy,
              strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
228
229
def test_gathernd_model_parallel2():
    """Compile Net in semi_auto_parallel with indices [8, 4, 2, 2] split 2x2x2 on its first three dims."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (2, 2, 2, 1))
    relu_strategy = ((8, 1, 1, 1),)
    net = Net(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 2], strategy1=mul_strategy,
              strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
240
241
def test_gathernd_model_parallel3():
    """Compile Net in semi_auto_parallel with indices [8, 4, 2, 3] split 2x2x2 on its first three dims."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (2, 2, 2, 1))
    relu_strategy = ((8, 1, 1),)
    net = Net(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 3], strategy1=mul_strategy,
              strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
252
253
def test_gathernd_model_parallel4():
    """Compile Net2 in semi_auto_parallel with indices [8, 4, 2, 1] split 2x2x2 on its first three dims."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (2, 2, 2, 1))
    relu_strategy = ((8, 1, 1, 1, 1),)
    net = Net2(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 1], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
264
265
def test_gathernd_model_parallel5():
    """Compile Net2 in semi_auto_parallel with indices [8, 4, 2, 2] split 2x2x2 on its first three dims."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (2, 2, 2, 1))
    relu_strategy = ((8, 1, 1, 1),)
    net = Net2(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 2], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
276
277
def test_gathernd_model_parallel6():
    """Compile Net2 in semi_auto_parallel with indices [8, 4, 2, 3] split 2x2x2 on its first three dims."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (2, 2, 2, 1))
    relu_strategy = ((8, 1, 1),)
    net = Net2(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 3], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
288
289
def test_gathernd_model_parallel7():
    """Compile Net3 in semi_auto_parallel with indices [8, 4, 2, 1] split 2x2x2 on its first three dims."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1, 1, 1), (8, 1, 1, 1, 1))
    gather_strategy = ((1, 1, 1), (2, 2, 2, 1))
    relu_strategy = ((8, 1, 1, 1, 1),)
    net = Net3(w1_shape=[8, 4, 2, 16, 32], indices_shape=[8, 4, 2, 1], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
300
301
def test_gathernd_model_parallel8():
    """Compile Net3 in semi_auto_parallel with indices [8, 4, 2, 2] split 2x2x2 on its first three dims."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1, 1), (8, 1, 1, 1))
    gather_strategy = ((1, 1, 1), (2, 2, 2, 1))
    relu_strategy = ((8, 1, 1, 1),)
    net = Net3(w1_shape=[8, 4, 2, 32], indices_shape=[8, 4, 2, 2], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
312
313
def test_gathernd_model_parallel9():
    """Compile Net3 in semi_auto_parallel with indices [8, 4, 2, 3] split 2x2x2 on its first three dims."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0)
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    gather_strategy = ((1, 1, 1), (2, 2, 2, 1))
    relu_strategy = ((8, 1, 1),)
    net = Net3(w1_shape=[8, 4, 2], indices_shape=[8, 4, 2, 3], strategy1=mul_strategy,
               strategy2=gather_strategy, strategy3=relu_strategy)
    compile_net(net)
324
def test_gathernd_auto_parallel():
    """Compile Net in auto_parallel mode with indices [8, 4, 2, 1] and no explicit strategies."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                      device_num=8, global_rank=0)
    net = Net(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 1])
    compile_net(net)
332
333
def test_gathernd_auto_parallel2():
    """Compile Net in auto_parallel mode with indices [8, 4, 2, 2] and no explicit strategies."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                      device_num=8, global_rank=0)
    net = Net(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 2])
    compile_net(net)
341
342
def test_gathernd_auto_parallel3():
    """Compile Net in auto_parallel mode with indices [8, 4, 2, 3] and no explicit strategies."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                      device_num=8, global_rank=0)
    net = Net(w1_shape=[8, 16, 32], indices_shape=[8, 4, 2, 3])
    compile_net(net)
350
351
def test_gathernd_strategy_error():
    """Expect RuntimeError when the GatherNd strategy splits the input tensor's first dim.

    strategy2 shards the GatherNd input as (2, 1, 1); compiling the network is
    expected to fail with RuntimeError (presumably because GatherNd does not
    support a split input -- confirm against the operator's parallel rules).
    """
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    try:
        w1_shape = [8, 16, 32]
        indices_shape = [8, 4, 2, 3]
        strategy1 = ((8, 1, 1), (8, 1, 1))
        strategy2 = ((2, 1, 1), (1, 2, 2, 1))
        strategy3 = ((8, 1, 1),)
        net = Net(w1_shape, indices_shape, strategy1, strategy2, strategy3)
        with pytest.raises(RuntimeError):
            compile_net(net)
    finally:
        # The expected failure can skip the reset inside compile_net, so restore
        # the default context here to avoid leaking it into later tests.
        context.reset_auto_parallel_context()
363
364
def test_gathernd_strategy_error2():
    """Expect RuntimeError when the GatherNd strategy splits the last dim of the indices.

    strategy2 shards the indices as (1, 2, 2, 2); compiling the network is
    expected to fail with RuntimeError (presumably because the last indices
    dimension must stay unsplit -- confirm against the operator's parallel rules).
    """
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    try:
        w1_shape = [8, 16, 32]
        indices_shape = [8, 4, 2, 3]
        strategy1 = ((8, 1, 1), (8, 1, 1))
        strategy2 = ((1, 1, 1), (1, 2, 2, 2))
        strategy3 = ((8, 1, 1),)
        net = Net(w1_shape, indices_shape, strategy1, strategy2, strategy3)
        with pytest.raises(RuntimeError):
            compile_net(net)
    finally:
        # The expected failure can skip the reset inside compile_net, so restore
        # the default context here to avoid leaking it into later tests.
        context.reset_auto_parallel_context()
376