# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for sharding strategies of the Gather (GatherV2) operator under
semi-auto parallel mode."""
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.parameter import ParameterTuple
from mindspore.communication.management import init
from mindspore.nn import Dense, Cell
from mindspore.nn.loss.loss import LossBase
from mindspore.nn.optim import Momentum
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm

context.set_context(mode=context.GRAPH_MODE)
device_number = 32
batch_size_per_device = 128


class Dataset:
    """Minimal in-memory dataset that yields the same tensor a fixed number of times."""

    def __init__(self, predict, length=3):
        self.predict = predict
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return (self.predict,)

    def reset(self):
        self.index = 0

    def get_dataset_size(self):
        return 128

    def get_repeat_count(self):
        return 1

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        return self


class GatherV2(LossBase):
    """Loss cell that gathers two interleaved index sets along axis 0 and
    returns their squared difference."""

    def __init__(self, index_dim, strategy, index_size=16):
        super(GatherV2, self).__init__()
        self.pow = P.Pow()
        # Scalar placeholders; overwritten below for the supported index_dim values.
        emb1_list = 21
        emb2_list = 2
        if index_dim == 1:
            emb_list = list(range(index_size))
            emb1_list = emb_list[0::2]
            emb2_list = emb_list[1::2]
        if index_dim == 2:
            emb_list = np.arange(index_size * 16)
            emb1_list = np.reshape(emb_list[0::2], (int(index_size / 2), 16))
            emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), 16))
        self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
        self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
        self.gatherv2 = P.Gather().shard(strategy).add_prim_attr("data_parallel", True)

    def construct(self, nembeddings):
        emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)
        emb2 = self.gatherv2(nembeddings, self.emb2_param, 0)
        return self.pow((emb1 - emb2), 2.0)


def fc_with_initialize(input_channels, out_channels):
    return Dense(input_channels, out_channels)


class BuildTrainNetwork(nn.Cell):
    """Wraps a backbone network and a criterion into a single loss-producing cell."""

    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data):
        embeddings = self.network(input_data)
        loss = self.criterion(embeddings)
        return loss


class TrainOneStepCell(Cell):
    """Computes the loss, back-propagates gradients, and applies one optimizer step."""

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.sens = sens

    def construct(self, data):
        weights = self.weights
        loss = self.network(data)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(data, sens)
        self.optimizer(grads)
        return loss


def net_trains(criterion, rank):
    # Temporarily disable the distributed-environment checks so init() can run
    # in a single-process test.
    GlobalComm.CHECK_ENVS = False
    init()
    GlobalComm.CHECK_ENVS = True
    lr = 0.1
    momentum = 0.9
    max_epoch = 20
    input_channels = 256
    out_channels = 512
    context.set_context(mode=context.GRAPH_MODE)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number,
                                      global_rank=rank)
    predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32)
    dataset = Dataset(predict, 4)

    network = fc_with_initialize(input_channels, out_channels)
    network.set_train()

    train_network = BuildTrainNetwork(network, criterion)
    train_network.set_train()
    opt = Momentum(train_network.trainable_params(), lr, momentum)
    train_net = TrainOneStepCell(train_network, opt).set_train()

    model = Model(train_net)
    model.train(max_epoch, dataset, dataset_sink_mode=False)
    context.reset_auto_parallel_context()


def test_auto_batch_parallel():
    gather_v2_strategy = None
    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
    rank = 2
    net_trains(criterion, rank)


def test_2d_index_auto_batch_parallel():
    gather_v2_strategy = None
    criterion = GatherV2(2, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
    rank = 2
    net_trains(criterion, rank)


def test_batch_parallel():
    gather_v2_strategy = ((device_number, 1),)
    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
    rank = 2
    net_trains(criterion, rank)


def test_strategy1():
    gather_v2_strategy = ((16, 2),)
    rank = 2
    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
    net_trains(criterion, rank)


def test_strategy2():
    gather_v2_strategy = ((1, device_number),)
    rank = 2
    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
    net_trains(criterion, rank)


def test_strategy3():
    gather_v2_strategy = ((8, 1),)
    rank = 2
    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
    net_trains(criterion, rank)


class GatherV2Axis1(LossBase):
    """Same as GatherV2, but gathers along axis 1 of the embedding table."""

    def __init__(self, index_dim, strategy, index_size=16):
        super(GatherV2Axis1, self).__init__()
        self.pow = P.Pow()
        # Scalar placeholders; overwritten below for the supported index_dim values.
        emb1_list = 21
        emb2_list = 2
        if index_dim == 1:
            emb_list = list(range(index_size))
            emb1_list = emb_list[0::2]
            emb2_list = emb_list[1::2]
        if index_dim == 2:
            emb_list = np.arange(index_size * index_size)
            emb1_list = np.reshape(emb_list[0::2], (int(index_size / 2), index_size))
            emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), index_size))
        self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
        self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
        self.gatherv2 = P.Gather().shard(strategy)

    def construct(self, nembeddings):
        emb1 = self.gatherv2(nembeddings, self.emb1_param, 1)
        emb2 = self.gatherv2(nembeddings, self.emb2_param, 1)
        return self.pow((emb1 - emb2), 2.0)


def test_axis1_auto_batch_parallel():
    gather_v2_strategy = None
    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
    rank = 2
    net_trains(criterion, rank)


def test_axis1_batch_parallel():
    gather_v2_strategy = ((device_number, 1), (1,))
    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
    rank = 2
    net_trains(criterion, rank)


def test_axis1_strategy1():
    gather_v2_strategy = ((16, 2), (1,))
    rank = 17
    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
    net_trains(criterion, rank)
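
# ----------------------------------------------------------------------------
# Minimal sketch of invoking one case directly; these functions are normally
# collected and run by pytest. Assumes a MindSpore install where init()
# succeeds with GlobalComm.CHECK_ENVS disabled, as handled in net_trains above.
if __name__ == "__main__":
    test_batch_parallel()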