# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from tests.dataset_mock import MindData


class Dataset(MindData):
    """Mock dataset that yields the same (predict, label) pair `length` times."""

    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        # Number of items already handed out since the last reset().
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        """Rewind the dataset so it can be iterated again."""
        self.index = 0


class Net(Cell):
    """Concat two weights along axis 0, then multiply the input by the result.

    `strategy1` shards the Concat, `strategy2` shards the Mul. When
    `is_parameter` is False the first weight is stored as given instead of
    being wrapped in a Parameter.
    """

    def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.concat = P.Concat(axis=0).shard(strategy1)
        if is_parameter:
            self.weight = Parameter(weight, "w1")
        else:
            self.weight = weight
        self.mul = P.Mul().shard(strategy2)
        self.weight2 = Parameter(weight2, "w2")

    def construct(self, x, b):
        # `b` is not used; presumably it is accepted only because Dataset
        # yields a (predict, label) pair -- confirm against Model.train.
        out = self.concat((self.weight, self.weight2))
        out = self.mul(x, out)
        return out


class Net2(Cell):
    """Square the input with Mul, then concat the result with a weight.

    `strategy1` shards the Mul, `strategy2` shards the Concat, and `axis`
    selects the concatenation axis.
    """

    def __init__(self, weight, strategy1=None, strategy2=None, axis=0):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.concat = P.Concat(axis=axis).shard(strategy2)
        self.weight = Parameter(weight, "w")

    def construct(self, x, b):
        # `b` is not used; see the note in Net.construct.
        out = self.mul(x, x)
        out = self.concat((out, self.weight))
        return out


class Net3(Cell):
    """Concat three weights along axis 0, then multiply the input by the result.

    `strategy1` shards the Concat, `strategy2` shards the Mul. When
    `is_parameter` is False the first weight is stored as given instead of
    being wrapped in a Parameter.
    """

    def __init__(self, weight, weight2, weight3, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.concat = P.Concat(axis=0).shard(strategy1)
        if is_parameter:
            self.weight = Parameter(weight, "w1")
        else:
            self.weight = weight
        self.mul = P.Mul().shard(strategy2)
        self.weight2 = Parameter(weight2, "w2")
        self.weight3 = Parameter(weight3, "w3")

    def construct(self, x, b):
        # `b` is not used; see the note in Net.construct.
        out = self.concat((self.weight, self.weight2, self.weight3))
        out = self.mul(x, out)
        return out


# Inputs fed through Dataset by compile_net().
# NOTE(review): dim 0 of _x is 16 while the concatenated weights below have
# 128 rows; presumably the framework scales the per-device batch by
# device_num=8 -- confirm before changing any of these shapes.
_x = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([16, 64, 32, 32]), dtype=ms.int32)

# Weights used by Net / Net2.
_w1 = Tensor(np.ones([96, 64, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([32, 64, 32]), dtype=ms.float32)
_w3 = Tensor(np.ones([128, 16, 32]), dtype=ms.float32)

# Weights used by Net3 (48 + 16 + 64 = 128 rows once concatenated on axis 0).
w1 = Tensor(np.ones([48, 64, 32]), dtype=ms.float32)
w2 = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
w3 = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)


def compile_net(net):
    """Train `net` for two epochs on the mock dataset, then reset the parallel context.

    The auto-parallel context is process-global state set by each test before
    calling this helper. It is reset in a `finally` clause so that a failure
    while building or training the model does not leave that configuration
    behind for the tests that run afterwards. Any exception still propagates.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    dataset = Dataset(_x, _b)
    try:
        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, optimizer=opt, amp_level="O2")
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        context.reset_auto_parallel_context()


def test_concat_parameter():
    """Net with two Parameter weights; Concat and Mul both use (1, 4, 2)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_concat_parameter_no_full_split():
    """Net with two Parameter weights; Concat uses (1, 2, 2), Mul uses (1, 4, 2)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2), (1, 2, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_concat_tensor_and_parameter():
    """Net whose first weight is a plain Tensor and second weight a Parameter."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2), (1, 2, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False)
    compile_net(net)


def test_concat_output():
    """Net2: Mul uses (2, 2, 2), Concat of the Mul output and a weight uses (1, 4, 2)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net2(_w1, strategy1, strategy2)
    compile_net(net)


def test_concat_output_no_full_split():
    """Net2: Mul uses (2, 2, 2), Concat uses (1, 2, 2)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 2, 2), (1, 2, 2))
    net = Net2(_w1, strategy1, strategy2)
    compile_net(net)


def test_concat_no_strategy():
    """Net2 with axis=1: Mul uses (2, 2, 2), Concat is given no strategy."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = None
    net = Net2(_w3, strategy1, strategy2, axis=1)
    compile_net(net)


def test_concat_auto_parallel():
    """Net2 in auto_parallel mode with no strategies supplied."""
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net2(_w2)
    compile_net(net)


def test_concat_auto_parallel2():
    """Net2 with axis=1 in auto_parallel mode with no strategies supplied."""
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    strategy1 = None
    strategy2 = None
    net = Net2(_w3, strategy1, strategy2, axis=1)
    compile_net(net)


def test_concat_auto_parallel_3_tensor():
    """Net3 (three-way Concat) in auto_parallel mode with no strategies supplied."""
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net3(w1, w2, w3)
    compile_net(net)