# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parallel-compile tests for ``P.TopK`` fed by a sharded ``P.Mul``."""
import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from tests.dataset_mock import MindData


class Dataset(MindData):
    """Mock dataset that yields the same ``(predict, label)`` pair ``length`` times.

    ``reset`` rewinds the iterator so the dataset can be replayed each epoch.
    """

    def __init__(self, predict, label, length=3):
        super().__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class Net(Cell):
    """Multiply the input by a trainable weight, then keep the top-8 values.

    Args:
        w1 (Tensor): initial value of the weight ``Parameter`` named ``"w1"``.
        strategy1: shard strategy passed to ``P.Mul().shard``.
        strategy2: shard strategy passed to ``P.TopK().shard``.
    """

    def __init__(self, w1, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.w1 = Parameter(w1, "w1")
        self.topk = P.TopK().shard(strategy2)

    def construct(self, x, b):
        # ``b`` receives the second dataset column (the label); it is unused here.
        out = self.mul(x, self.w1)
        # TopK returns (values, indices); only the values are used as the output.
        out, _ = self.topk(out, 8)
        return out


_x = Tensor(np.ones([16, 64]), dtype=ms.float32)
_b = Tensor(np.ones([16, 64]), dtype=ms.float32)
# Must be broadcast-compatible with ``_x`` ([16, 64]) for the element-wise Mul.
_w1 = Tensor(np.ones([16, 64]), dtype=ms.float32)


def compile_net(net):
    """Train ``net`` for two epochs on the mock dataset, then reset the parallel context.

    The reset runs in ``finally`` so that a failing compile (as in the
    strategy-error test, where the exception is caught by ``pytest.raises``)
    does not leave the auto-parallel configuration set for later tests.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    try:
        dataset = Dataset(_x, _b)
        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, optimizer=opt)
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        context.reset_auto_parallel_context()


def test_topk_data_parallel():
    """Split only the batch (first) axis across all 8 devices."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1), (8, 1))
    strategy2 = ((8, 1),)
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_topk_model_parallel():
    """Mul splits both axes (2 x 4); TopK splits the first axis only."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 4), (2, 4))
    strategy2 = ((2, 1),)
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_topk_auto_parallel():
    """Let auto_parallel choose the strategies (none are given)."""
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net(_w1)
    compile_net(net)


def test_topk_strategy_error():
    """A TopK strategy that splits the last axis is expected to raise RuntimeError."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1), (8, 1))
    strategy2 = ((1, 8),)
    net = Net(_w1, strategy1, strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)