# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parallel-compilation tests for a ``Less`` + ``Select`` network.

Each test configures the auto-parallel context, builds ``Net`` with optional
per-operator sharding strategies, and trains it briefly via ``compile_net``.
"""
import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from tests.dataset_mock import MindData


class Dataset(MindData):
    """Mock dataset that yields the same ``(predict, label)`` pair ``length`` times.

    Iteration stops with ``StopIteration`` after ``length`` items. ``reset``
    rewinds the counter so the dataset can be iterated again for a later epoch.

    Args:
        predict: first element of every yielded pair.
        label: second element of every yielded pair.
        length (int): number of pairs per pass; also forwarded to ``MindData``
            as ``size``.
    """

    def __init__(self, predict, label, length=3):
        super().__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        """Rewind to the first item."""
        self.index = 0


class Net(Cell):
    """Compute ``Select(Less(x, b), w1, w2)`` with trainable ``w1``/``w2``.

    Args:
        w1 (Tensor): initial value of parameter ``w1`` (chosen where ``x < b``).
        w2 (Tensor): initial value of parameter ``w2`` (chosen elsewhere).
        strategy1: sharding strategy passed to ``Less.shard``; one tuple per
            input (``x``, ``b``).
        strategy2: sharding strategy passed to ``Select.shard``; one tuple per
            input (condition, ``w1``, ``w2``).
    """

    def __init__(self, w1, w2, strategy1=None, strategy2=None):
        super().__init__()
        # NOTE(review): with ``None`` strategies the search is presumably left
        # to the auto-parallel planner (see test_select_auto_parallel) — confirm.
        self.less = P.Less().shard(strategy1)
        self.w1 = Parameter(w1, "w1")
        self.w2 = Parameter(w2, "w2")
        self.select = P.Select().shard(strategy2)

    def construct(self, x, b):
        out = self.less(x, b)
        out = self.select(out, self.w1, self.w2)
        return out


# Dataset inputs are 16 along the first axis while the parameters are 128.
# NOTE(review): presumably each of the 8 devices feeds a 16-row slice, giving
# a global first dimension of 128 that matches w1/w2 — confirm.
_x = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)


def compile_net(net):
    """Train ``net`` briefly so it is compiled under the current parallel context.

    Runs 2 epochs of non-sink-mode training with a Momentum optimizer
    (lr=0.1, momentum=0.9) over a ``Dataset`` built from ``_x`` and ``_b``.

    The auto-parallel context is always reset afterwards, including when
    training raises. A failing compilation (e.g. an invalid sharding strategy)
    therefore cannot leak its parallel settings into subsequent tests.

    Args:
        net (Cell): network to compile.

    Raises:
        Any exception from training propagates unchanged, e.g. the
        ``RuntimeError`` expected by ``test_select_strategy_error``.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    dataset = Dataset(_x, _b)
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, optimizer=opt)
    try:
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        # Reset even on failure so later tests start from a clean context.
        context.reset_auto_parallel_context()


def test_select_data_parallel():
    """Shard only the first axis 8 ways for every input (pure data parallel)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((8, 1, 1), (8, 1, 1), (8, 1, 1))
    net = Net(_w1, _w2, strategy1, strategy2)
    compile_net(net)


def test_select_model_parallel():
    """Split every axis 2 ways for every input (2 * 2 * 2 = 8 devices)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((2, 2, 2), (2, 2, 2), (2, 2, 2))
    net = Net(_w1, _w2, strategy1, strategy2)
    compile_net(net)


def test_select_mirror():
    """Leave the first axis unsplit, so the strategies use 4 of the 8 devices."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2), (1, 2, 2))
    strategy2 = ((1, 2, 2), (1, 2, 2), (1, 2, 2))
    net = Net(_w1, _w2, strategy1, strategy2)
    compile_net(net)


def test_select_auto_parallel():
    """Compile in ``auto_parallel`` mode without explicit strategies."""
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net(_w1, _w2)
    compile_net(net)


def test_select_strategy_error():
    """Expect a ``RuntimeError`` when Select's input strategies disagree.

    The condition input uses (8, 1, 1) while ``w1``/``w2`` use (2, 2, 2).
    """
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((8, 1, 1), (2, 2, 2), (2, 2, 2))
    net = Net(_w1, _w2, strategy1, strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)