# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Compile-only tests for the Concat operator under (semi-)auto parallel modes.

Each test configures the auto-parallel context for 8 devices, builds a small
network around ``P.Concat`` and checks that the training graph compiles.
"""
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    """Concatenate two weights along axis 0, then multiply the result with ``x``.

    Args:
        weight: First concat input; wrapped in a ``Parameter`` named "w1" when
            ``is_parameter`` is true, otherwise used as passed.
        weight2: Second concat input; always wrapped in a ``Parameter`` named "w2".
        strategy1: Shard strategy passed to the Concat operator.
        strategy2: Shard strategy passed to the Mul operator.
        is_parameter: Whether ``weight`` is wrapped in a ``Parameter``.
    """

    def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.concat = P.Concat(axis=0).shard(strategy1)
        if is_parameter:
            self.weight = Parameter(weight, "w1")
        else:
            self.weight = weight
        self.mul = P.Mul().shard(strategy2)
        self.weight2 = Parameter(weight2, "w2")

    def construct(self, x, b):
        """Return ``x * concat(weight, weight2)``.

        ``b`` is not used here; it is accepted because ``compile_net`` always
        compiles the network with two inputs.
        """
        out = self.concat((self.weight, self.weight2))
        out = self.mul(x, out)
        return out


class Net2(Cell):
    """Multiply ``x`` with ``b``, then concatenate the product with a weight.

    Args:
        weight: Tensor wrapped in a ``Parameter`` named "w" and used as the
            second concat input.
        strategy1: Shard strategy passed to the Mul operator.
        strategy2: Shard strategy passed to the Concat operator.
        axis: Axis along which Concat joins its inputs.
    """

    def __init__(self, weight, strategy1=None, strategy2=None, axis=0):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.concat = P.Concat(axis=axis).shard(strategy2)
        self.weight = Parameter(weight, "w")

    def construct(self, x, b):
        """Return ``concat(x * b, weight)``."""
        out = self.mul(x, b)
        out = self.concat((out, self.weight))
        return out


class Net3(Cell):
    """Concatenate three weights along axis 0, then multiply the result with ``x``.

    Args:
        weight: First concat input; wrapped in a ``Parameter`` named "w1" when
            ``is_parameter`` is true, otherwise used as passed.
        weight2: Second concat input; always wrapped in a ``Parameter`` named "w2".
        weight3: Third concat input; always wrapped in a ``Parameter`` named "w3".
        strategy1: Shard strategy passed to the Concat operator.
        strategy2: Shard strategy passed to the Mul operator.
        is_parameter: Whether ``weight`` is wrapped in a ``Parameter``.
    """

    def __init__(self, weight, weight2, weight3, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.concat = P.Concat(axis=0).shard(strategy1)
        if is_parameter:
            self.weight = Parameter(weight, "w1")
        else:
            self.weight = weight
        self.mul = P.Mul().shard(strategy2)
        self.weight2 = Parameter(weight2, "w2")
        self.weight3 = Parameter(weight3, "w3")

    def construct(self, x, b):
        """Return ``x * concat(weight, weight2, weight3)``.

        ``b` is not used here; it is accepted because ``compile_net`` always
        compiles the network with two inputs.
        """
        out = self.concat((self.weight, self.weight2, self.weight3))
        out = self.mul(x, out)
        return out


# Network inputs handed to the graph executor by compile_net.
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)

# Weights for the two-input nets: _w1 and _w2 stack to 128 rows on axis 0,
# while _w3 has a different size on axis 1 and is used with axis=1 concat.
_w1 = Tensor(np.ones([96, 64, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([32, 64, 32]), dtype=ms.float32)
_w3 = Tensor(np.ones([128, 16, 32]), dtype=ms.float32)

# Weights for Net3: 48 + 16 + 64 rows stack to 128 rows on axis 0.
w1 = Tensor(np.ones([48, 64, 32]), dtype=ms.float32)
w2 = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
w3 = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)


def compile_net(net):
    """Wrap ``net`` in a Momentum training step and compile it with ``_x`` and ``_b``.

    The auto-parallel context is reset afterwards even when building or
    compiling raises, so a failing test cannot leak its parallel settings into
    the tests that run after it.
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x, _b)
    finally:
        context.reset_auto_parallel_context()


def test_concat_parameter():
    """Concat of two parameters; the concat strategy uses all 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_concat_parameter_no_full_split():
    """Concat of two parameters; the concat strategy covers 4 of the 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2), (1, 2, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_concat_tensor_and_parameter():
    """Concat of a plain tensor and a parameter (``is_parameter=False``)."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2), (1, 2, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False)
    compile_net(net)


def test_concat_output():
    """Concat of a Mul output and a parameter; the concat strategy uses all 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net2(_w1, strategy1, strategy2)
    compile_net(net)


def test_concat_output_no_full_split():
    """Concat of a Mul output and a parameter; the concat strategy covers 4 of the 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 2, 2), (1, 2, 2))
    net = Net2(_w1, strategy1, strategy2)
    compile_net(net)


def test_concat_no_strategy():
    """Axis-1 concat with no shard strategy set on Concat in semi-auto mode."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = None
    net = Net2(_w3, strategy1, strategy2, axis=1)
    compile_net(net)


def test_concat_auto_parallel():
    """Axis-0 concat in auto_parallel mode with no strategies given."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net2(_w2)
    compile_net(net)


def test_concat_auto_parallel2():
    """Axis-1 concat in auto_parallel mode with no strategies given."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    strategy1 = None
    strategy2 = None
    net = Net2(_w3, strategy1, strategy2, axis=1)
    compile_net(net)


def test_concat_auto_parallel_3_tensor():
    """Three-input axis-0 concat in auto_parallel mode with no strategies given."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net3(w1, w2, w3)
    compile_net(net)