# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from tests.dataset_mock import MindData


class Dataset(MindData):
    """Mock dataset yielding the same ``(predict, label)`` pair ``length`` times.

    Args:
        predict: value returned as the first column of every item.
        label: value returned as the second column of every item.
        length: number of items produced before ``StopIteration`` (default 3).
    """

    def __init__(self, predict, label, length=3):
        super().__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        """Rewind the iterator so that iteration starts over from the first item."""
        self.index = 0


class Net(Cell):
    """Element-wise ``Mul`` with a trainable weight, followed by ``GatherNd``.

    Args:
        w1 (Tensor): initial value of the weight parameter ``w1``. It must be
            broadcast-compatible with the network input ``x`` because it is
            combined with ``x`` via element-wise ``Mul``.
        strategy1: shard strategy passed to ``Mul`` (``None`` when the strategy
            is left to the auto-parallel search).
        strategy2: shard strategy passed to ``GatherNd`` (same convention).
    """

    def __init__(self, w1, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.w1 = Parameter(w1, "w1")
        # 16 index rows, each equal to (1, 1): every gathered element is out[1, 1].
        self.indices = Tensor(np.ones([16, 2]), dtype=ms.int32)
        self.gathernd = P.GatherNd().shard(strategy2)

    def construct(self, x, b):
        # `b` receives the dataset's second (label) column; the network ignores it.
        out = self.mul(x, self.w1)
        out = self.gathernd(out, self.indices)
        return out


_x = Tensor(np.ones([16, 64]), dtype=ms.float32)
_b = Tensor(np.ones([16, 64]), dtype=ms.float32)
# Must match _x: Mul is element-wise, and a (128, 64) weight would not
# broadcast against the (16, 64) input, failing shape inference in every test.
_w1 = Tensor(np.ones([16, 64]), dtype=ms.float32)


def compile_net(net):
    """Train ``net`` for two epochs on the mock dataset to compile the graph.

    The auto-parallel context is reset afterwards in all cases, including when
    compilation raises (as ``test_gathernd_strategy_error`` expects). This keeps
    the parallel settings of one test from leaking into the next.
    """
    try:
        learning_rate = 0.1
        momentum = 0.9
        epoch_size = 2
        dataset = Dataset(_x, _b)
        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, optimizer=opt)
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        context.reset_auto_parallel_context()


def test_gathernd_data_parallel():
    """Both operators split 8-way along dimension 0 (GatherNd: along the indices)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1), (8, 1))
    strategy2 = ((1, 1), (8, 1))
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_gathernd_model_parallel():
    """Mul split (2, 4); GatherNd indices split 4-way, i.e. using 4 of the 8 devices."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 4), (2, 4))
    strategy2 = ((1, 1), (4, 1))
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_gathernd_auto_parallel():
    """No explicit strategies: auto_parallel mode chooses them."""
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net(_w1)
    compile_net(net)


def test_gathernd_strategy_error():
    """An indices strategy of (2, 4), which splits the last indices dimension, must fail to compile."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1), (8, 1))
    strategy2 = ((1, 1), (2, 4))
    net = Net(_w1, strategy1, strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)