# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
from mindspore.train import Model
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss

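# Gradient operator that returns gradients with respect to all network inputs.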
grad_all = C.GradOperation(get_all=True)


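# Minimal mock dataset: yields the same tuple of input tensors on every
# iteration for `length` steps, which is enough to drive Model.train here.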
class Dataset(MindData):
    def __init__(self, *inputs, length=3):
        super(Dataset, self).__init__(size=length)
        self.inputs = inputs
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.inputs

    def reset(self):
        self.index = 0


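# Shard the model across 16 devices: 2-way data parallel x 8-way model
# parallel, with the MoE layers using 4 experts.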
config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, vocab_emb_dp=False)
moe_config = MoEConfig(expert_num=4)


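# Wraps the Transformer with a virtual loss so it can be trained end to end;
# only the first of the network's outputs is fed into the loss.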
class NetWithLossFiveInputs(nn.Cell):
    def __init__(self, network):
        super(NetWithLossFiveInputs, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x1, x2, x3, x4, x5):
        predict, _, _, _ = self.network(x1, x2, x3, x4, x5)
        return self.loss(predict)


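# Trains a one-layer encoder/decoder Transformer with MoE for one epoch in
# semi-auto parallel mode, feeding 3D (batch, seq_length, hidden_size) inputs.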
def test_transformer_model():
    set_auto_parallel_context(device_num=16, global_rank=0,
                              full_batch=True, enable_alltoall=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=1,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      moe_config=moe_config,
                      parallel_config=config)
    net = _VirtualDatasetCell(net)
    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


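# Same as test_transformer_model, except that the encoder/decoder input values
# are flattened to 2D (batch * seq_length, hidden_size) tensors.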
def test_transformer_model_2d():
    set_auto_parallel_context(device_num=16, global_rank=0,
                              full_batch=True, enable_alltoall=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=1,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      moe_config=moe_config,
                      parallel_config=config)
    net = _VirtualDatasetCell(net)

    encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)