# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    """Mul -> ExpandDims(axis=-1) -> Mul network with a trainable weight.

    Args:
        mul_weight: Initial value wrapped into the Parameter named "w1".
        strategy1: Value passed to ``shard()`` of the first Mul.
        strategy2: Value passed to ``shard()`` of ExpandDims.
        strategy3: Value passed to ``shard()`` of the second Mul.
    """

    def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.expand_dims = P.ExpandDims().shard(strategy2)
        self.mul2 = P.Mul().shard(strategy3)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        """Return ``mul2(expand_dims(mul(x, w1), -1), b)``."""
        out = self.mul(x, self.mul_weight)
        out = self.expand_dims(out, -1)
        out = self.mul2(out, b)
        return out


class Net2(Cell):
    """ExpandDims(axis=-1) applied directly to the weight, followed by Mul.

    Args:
        mul_weight: Initial value wrapped into the Parameter named "w1".
        strategy1: Value passed to ``shard()`` of ExpandDims.
        strategy2: Value passed to ``shard()`` of Mul.
    """

    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.expand_dims = P.ExpandDims().shard(strategy1)
        self.mul = P.Mul().shard(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        """Return ``mul(expand_dims(w1, -1), b)``.

        ``x`` is not used; it is kept so this network accepts the same
        ``(x, b)`` inputs that ``compile_net`` compiles every network with.
        """
        out = self.expand_dims(self.mul_weight, -1)
        out = self.mul(out, b)
        return out


# Shared inputs: _x and _w1 are [128, 64, 32]; _b carries one extra trailing
# axis of size 1, i.e. [128, 64, 32, 1].
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32, 1]), dtype=ms.float32)


def compile_net(net):
    """Wrap ``net`` in a Momentum training step and compile it with (_x, _b).

    The auto-parallel context is reset when this function exits, whether
    compilation succeeded or raised, so a failing test does not leave its
    parallel configuration behind for tests that run afterwards.

    Args:
        net: The Cell to compile.
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x, _b)
    finally:
        # Always undo the context configured by the calling test.
        context.reset_auto_parallel_context()


def test_expand_dims_data_parallel():
    """Compile Net with every strategy splitting only the first dimension by 16."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1), (16, 1, 1))
    strategy2 = ((16, 1, 1),)
    strategy3 = ((16, 1, 1, 1), (16, 1, 1, 1))
    net = Net(_w1, strategy1, strategy2, strategy3)
    compile_net(net)


def test_expand_dims_model_parallel():
    """Compile Net with every strategy splitting only the third dimension by 16."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 1, 16), (1, 1, 16))
    strategy2 = ((1, 1, 16),)
    strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1))
    net = Net(_w1, strategy1, strategy2, strategy3)
    compile_net(net)


def test_expand_dims_hybrid_parallel():
    """Compile Net with (2, 2, 4) splits on all three operators."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 4), (2, 2, 4))
    strategy2 = ((2, 2, 4),)
    strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
    net = Net(_w1, strategy1, strategy2, strategy3)
    compile_net(net)


def test_expand_dims_auto_parallel():
    """Compile Net in "auto_parallel" mode with no explicit strategies."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
    net = Net(_w1)
    compile_net(net)


def test_expand_dims_repeat_calc():
    """Compile Net where the ExpandDims strategy covers fewer slices than devices.

    ExpandDims uses (1, 2, 2), i.e. 4 slices against device_num=16, while both
    Mul operators use (2, 2, 4) splits.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 4), (2, 2, 4))
    strategy2 = ((1, 2, 2),)
    strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
    net = Net(_w1, strategy1, strategy2, strategy3)
    compile_net(net)


def test_expand_dims_parameter():
    """Compile Net2, where ExpandDims is applied directly to the Parameter."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 2, 2),)
    strategy2 = ((2, 2, 4, 1), (2, 2, 4, 1))
    net = Net2(_w1, strategy1, strategy2)
    compile_net(net)