# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto-parallel compile tests for a CPU-targeted, manually split EmbeddingLookup."""

import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, LazyAdam
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer


@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    """Turn on ``enable_sparse`` for every test in this module, then switch it off afterwards."""
    context.set_context(enable_sparse=True)
    yield
    context.set_context(enable_sparse=False)


class Net(Cell):
    """EmbeddingLookup (sharded, manually split, CPU-targeted) followed by an elementwise Mul.

    Args:
        strategy1: shard strategy for the EmbeddingLookup primitive.
        strategy2: shard strategy for the Mul primitive.
        strategy3: shard strategy for the MatMul primitive (not used by ``construct``).
        axis: value passed as the third input of EmbeddingLookup.
        init_flag: if True, build the lookup table with ``initializer``; otherwise
            build it from a concrete ``Tensor``.
        split_tuple: value stored under the ``split_string`` primitive attribute.
        split_string: name of the split attribute set on EmbeddingLookup
            (e.g. "manual_split" or "manual_split_with_offset").
        param_shape: shape of the lookup-table parameter.
    """

    def __init__(self,
                 strategy1=None,
                 strategy2=None,
                 strategy3=None,
                 axis=0,
                 init_flag=True,
                 split_tuple=(4, 4),
                 split_string="manual_split",
                 param_shape=(8, 8)):
        super().__init__()
        self.gatherv2 = P.EmbeddingLookup().shard(strategy1)
        self.gatherv2.add_prim_attr(split_string, split_tuple)
        self.gatherv2.add_prim_attr("primitive_target", "CPU")
        self.mul = P.Mul().shard(strategy2)
        # NOTE(review): reshape / matmul / matmul_weight are not used by construct.
        # matmul_weight is still a trainable Parameter, so it is handed to the
        # optimizer in compile_net.
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(strategy3)
        self.matmul.add_prim_attr("forward_reduce_scatter", True)
        if init_flag:
            self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param")
        else:
            self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param")
        self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight")
        self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight")
        self.axis = axis

    def construct(self, x, b):
        """Look up rows of ``self.param`` at indices ``x`` and multiply the result by ``b``."""
        # NOTE(review): the third EmbeddingLookup input is documented by MindSpore
        # as ``offset``; this test names it ``axis`` — confirm the intent.
        out = self.gatherv2(self.param, x, self.axis)
        out = self.mul(out, b)
        return out


# Fixed inputs shared by every test: an (8, 8) int32 index tensor and an
# (8, 8, 8) float32 multiplier.
_x = Tensor(np.ones([8, 8]), dtype=ms.int32)
_b = Tensor(np.ones([8, 8, 8]), dtype=ms.float32)


def compile_net(net):
    """Wrap ``net`` in a LazyAdam training cell and compile it under the current auto-parallel context.

    The auto-parallel context is reset in a ``finally`` clause, so a case whose
    compilation raises (including the ``pytest.raises`` cases below) cannot
    leak its configuration into the next test. Any exception still propagates.
    """
    try:
        optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1)
        # Target the sparse optimizer op at CPU as well, as done for the
        # EmbeddingLookup primitive in Net.
        optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x, _b, auto_parallel_mode=True)
    finally:
        context.reset_auto_parallel_context()


def test_normal_split():
    """Default (4, 4) manual split of an (8, 8) table on 2 devices compiles."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    compile_net(net)


def test_normal_split2():
    """Uneven (10, 20, 30, 4) manual split of a (64, 8) table on 4 devices compiles."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
    strategy1 = ((4, 1), (1, 4))
    strategy2 = ((1, 4, 1), (1, 4, 1))
    strategy3 = ((1, 4), (4, 1))
    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
    compile_net(net)


def test_normal_split3():
    """Uneven manual split with a 2-D table strategy on 32 devices (rank 17) compiles."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
    strategy1 = ((4, 8), (1, 4))
    strategy2 = ((1, 4, 8), (1, 4, 8))
    strategy3 = ((1, 32), (32, 1))
    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
    compile_net(net)


def test_normal_split_with_offset():
    """``manual_split_with_offset`` with ((4, 0), (4, 4)) on 2 devices compiles."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4)))
    compile_net(net)


def test_auto_parallel_error():
    """The ``manual_split`` attribute under ``auto_parallel`` mode is expected to raise."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
    net = Net()
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_auto_parallel():
    """Under ``auto_parallel`` with a non-split attribute name ("fake"), compilation succeeds."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
    net = Net(split_string="fake")
    compile_net(net)


def test_axis_error():
    """A third EmbeddingLookup input of 1 instead of 0 is expected to raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3, axis=1)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_strategy_error():
    """Index strategy (8, 1) paired with table strategy (4, 1) on 8 devices is expected to raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((4, 1), (8, 1))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_strategy_error2():
    """Strategy ((4, 1), (1, 8)) with the default (4, 4) split on 8 devices is expected to raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((4, 1), (1, 8))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_strategy_error3():
    """A 2-device strategy used with ``device_num=8`` is expected to raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_strategy_error4():
    """Table strategy (2, 8) on 2 devices is expected to raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 8), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_strategy_error5():
    """A 4-way table split with the default 2-entry split tuple on 4 devices is expected to raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
    strategy1 = ((4, 1), (1, 4))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_split_tuple_error():
    """``manual_split`` with split tuple ((5, 0), (5, 5)) for an (8, 8) table is expected to raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5)))
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_parameter_use_tensor_error():
    """A lookup table built from a concrete Tensor (``init_flag=False``) is expected to raise."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
    strategy1 = ((2, 1), (1, 2))
    strategy2 = ((1, 2, 1), (1, 2, 1))
    strategy3 = ((1, 2), (2, 1))
    net = Net(strategy1, strategy2, strategy3, init_flag=False)
    with pytest.raises(RuntimeError):
        compile_net(net)