# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer


class Net(Cell):
    """Gather -> Mul -> Reshape -> MatMul network used by the tests below.

    Args:
        strategy1: Shard strategy passed to the Gather primitive.
        strategy2: Shard strategy passed to the Mul primitive.
        strategy3: Shard strategy passed to the MatMul primitive.
        axis: Axis argument forwarded to Gather in ``construct``.
        init_flag: When true the gather parameter is created through
            ``initializer("ones", ...)``; otherwise it is created from a
            ``Tensor`` wrapping ``np.ones(param_shape)``.
        split_tuple: Value of the primitive attribute added to Gather.
        split_string: Name of the primitive attribute added to Gather.
        param_shape: Shape of the gather parameter.
    """

    def __init__(self,
                 strategy1=None,
                 strategy2=None,
                 strategy3=None,
                 axis=0,
                 init_flag=True,
                 split_tuple=(4, 4),
                 split_string="manual_split",
                 param_shape=(8, 8)):
        super().__init__()
        self.gatherv2 = P.Gather().shard(strategy1)
        # The split configuration is attached to the primitive as an attribute
        # whose name and value both come from the constructor arguments.
        self.gatherv2.add_prim_attr(split_string, split_tuple)
        self.mul = P.Mul().shard(strategy2)
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy3)
        self.matmul.add_prim_attr("forward_reduce_scatter", True)
        if init_flag:
            self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param")
        else:
            self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param")
        self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight")
        self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight")
        self.axis = axis

    def construct(self, x, b):
        """Run the forward pass.

        ``x`` is used as the Gather indices; ``b`` is accepted so the call
        signature matches the inputs given to ``compile`` but it is not used
        in the computation.
        """
        out = self.gatherv2(self.param, x, self.axis)
        out = self.mul(out, self.mul_weight)
        out = self.reshape(out, (8, 64))
        out = self.matmul(out, self.matmul_weight)
        return out


# Inputs shared by every test: int32 indices of shape (8, 8) and a float32
# tensor of shape (64, 8).
_x = Tensor(np.ones([8, 8]), dtype=ms.int32)
_b = Tensor(np.ones([64, 8]), dtype=ms.float32)


def compile_net(net):
    """Wrap ``net`` in a Momentum training step and compile it in parallel mode.

    The auto-parallel context is always reset afterwards, including when
    building or compiling the network raises. Many tests in this file expect
    compilation to fail, and without the ``finally`` the context configured by
    such a test would remain set for whichever test runs next.

    Args:
        net: The ``Cell`` to compile.

    Raises:
        Whatever the optimizer construction or graph compilation raises; the
        exception is propagated unchanged after the context has been reset.
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x, _b, auto_parallel_mode=True)
    finally:
        context.reset_auto_parallel_context()


def test_normal_split():
    """Compile on 2 devices with the default ``manual_split`` of (4, 4)."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    compile_net(net)


def test_normal_split2():
    """Compile on 4 devices with split (10, 20, 30, 4) over a (64, 8) parameter."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
    strategy1 = ((4, 1), (1, 4))
    strategy2 = ((1, 4, 1), (1, 4, 1))
    strategy3 = ((1, 4), (4, 1))
    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
    compile_net(net)


def test_normal_split3():
    """Compile on 32 devices as rank 17 with split (10, 20, 30, 4)."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
    strategy1 = ((4, 8), (1, 4))
    strategy2 = ((1, 4, 8), (1, 4, 8))
    strategy3 = ((1, 32), (32, 1))
    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
    compile_net(net)


def test_normal_split_with_offset():
    """Compile on 2 devices using the ``manual_split_with_offset`` attribute."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4)))
    compile_net(net)


def test_auto_parallel_error():
    """Expect RuntimeError in ``auto_parallel`` mode with a default ``Net()``."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
    net = Net()
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_axis_error():
    """Expect RuntimeError when Gather is given ``axis=1``."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3, axis=1)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_strategy_error():
    """Expect RuntimeError for Gather strategy ((4, 1), (8, 1)) on 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((4, 1), (8, 1))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_strategy_error2():
    """Expect RuntimeError for Gather strategy ((4, 1), (1, 8)) on 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((4, 1), (1, 8))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_strategy_error3():
    """Expect RuntimeError for Gather strategy ((2, 1), (1, 2)) on 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_strategy_error4():
    """Expect RuntimeError for Gather strategy ((2, 8), (1, 2)) on 2 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 8), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_strategy_error5():
    """Expect RuntimeError for strategy ((4, 1), (1, 4)) with the default (4, 4) split."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
    strategy1 = ((4, 1), (1, 4))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_split_tuple_error():
    """Expect RuntimeError when ``manual_split`` is given ((5, 0), (5, 5))."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5)))
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_parameter_use_tensor_error():
    """Expect RuntimeError when the gather parameter is built from a Tensor."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3, init_flag=False)
    with pytest.raises(RuntimeError):
        compile_net(net)