# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parallel compile/train UTs for the MoE Transformer under semi-auto parallel."""

import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
from mindspore.train import Model
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss

grad_all = C.GradOperation(get_all=True)


class Dataset(MindData):
    """Mock dataset: yields the same fixed tuple of tensors `length` times."""

    def __init__(self, *inputs, length=3):
        super(Dataset, self).__init__(size=length)
        self.inputs = inputs   # fixed batch returned on every step
        self.index = 0         # current step within the epoch
        self.length = length   # steps per epoch

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.inputs

    def reset(self):
        # Allow the dataset to be iterated again for a new epoch.
        self.index = 0


# Shared parallel configuration: 2-way data parallel x 8-way model parallel.
config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, vocab_emb_dp=False)
moe_config = MoEConfig(expert_num=4)


class NetWithLossFiveInputs(nn.Cell):
    """Wrap a five-input network with a virtual loss for parallel compile tests."""

    def __init__(self, network):
        super(NetWithLossFiveInputs, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x1, x2, x3, x4, x5):
        # The Transformer returns a 4-tuple; only the first element (the
        # decoder output) feeds the loss.
        predict, _, _, _ = self.network(x1, x2, x3, x4, x5)
        return self.loss(predict)


def _run_transformer_train(encoder_input_value, decoder_input_value):
    """Build the MoE Transformer, wrap it for training, and run one epoch.

    Args:
        encoder_input_value (Tensor): encoder hidden input; 3-D
            (batch, src_seq, hidden) or flattened 2-D (batch*src_seq, hidden).
        decoder_input_value (Tensor): decoder hidden input; 3-D
            (batch, tgt_seq, hidden) or flattened 2-D (batch*tgt_seq, hidden).
    """
    set_auto_parallel_context(device_num=16, global_rank=0,
                              full_batch=True, enable_alltoall=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=1,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      moe_config=moe_config,
                      parallel_config=config)
    net = _VirtualDatasetCell(net)
    # Attention masks are identical in both test variants.
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    optimizer = AdamWeightDecay(net.trainable_params())
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value,
                      decoder_input_mask, memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model():
    """MoE Transformer with 3-D (batch, seq, hidden) inputs."""
    _run_transformer_train(
        encoder_input_value=Tensor(np.ones((2, 20, 64)), mstype.float32),
        decoder_input_value=Tensor(np.ones((2, 10, 64)), mstype.float32))


def test_transformer_model_2d():
    """MoE Transformer with flattened 2-D (batch*seq, hidden) inputs."""
    _run_transformer_train(
        encoder_input_value=Tensor(np.ones((40, 64)), mstype.float32),
        decoder_input_value=Tensor(np.ones((20, 64)), mstype.float32))