1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Note:
17    Transformer Networks. This is an experimental interface that is subject to change and/or deletion.
18"""
19import math
20import numpy as np
21from mindspore.common.tensor import Tensor
22from mindspore.common.parameter import Parameter
23from mindspore.common.initializer import initializer
24from mindspore import nn
25from mindspore import context
26import mindspore.common.dtype as mstype
27from mindspore.ops import operations as P
28from mindspore.ops import functional as F
29from mindspore.nn.cell import Cell
30from mindspore._checkparam import Validator
31from mindspore import log as logger
32from mindspore.parallel._utils import _get_parallel_mode
33from mindspore.context import ParallelMode
34from .layers import _LayerNorm, _Linear, _check_input_shape, \
35    _args_type_validator_check, _valid_type_checks, _valid_value_checks, \
36    _check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value
37from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config
38from .moe import default_moe_config, MoE
39
40__all__ = [
41    "AttentionMask",
42    "VocabEmbedding",
43    "MultiHeadAttention",
44    "FeedForward",
45    "TransformerEncoder",
46    "TransformerDecoder",
47    "TransformerEncoderLayer",
48    "TransformerDecoderLayer",
49    "Transformer",
50    "TransformerOpParallelConfig",
51    "EmbeddingOpParallelConfig"]
52
53
54class EmbeddingOpParallelConfig(_Config):
55    r"""
        EmbeddingOpParallelConfig is the configuration for setting the embedding table to data parallel or row slice (model parallel).
57
58        Args:
59            data_parallel (int): The data parallel way. Default: 1
60            model_parallel (int): The model parallel way. Default: 1
            vocab_emb_dp (bool): Whether to perform the embedding lookup in data parallel (True) or to row-slice
                the embedding table in model parallel (False). Default: True
62
63        Supported Platforms:
64            ``Ascend`` ``GPU``
65
66        Examples:
67            >>> config=EmbeddingOpParallelConfig(data_parallel=1, model_parallel=1, vocab_emb_dp=True)
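            >>> # A hedged illustration of the row-slice alternative described above: with vocab_emb_dp=False the
            >>> # embedding table is sharded along its 0-th dimension; the parallel values here are illustrative only.
            >>> row_slice_config = EmbeddingOpParallelConfig(data_parallel=1, model_parallel=2, vocab_emb_dp=False)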
68    """
69
70    def __init__(self, data_parallel=1, model_parallel=1, vocab_emb_dp=True):
71        self._dp_mp_config = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel)
72        Validator.check_bool(vocab_emb_dp, "vocab_emb_dp")
73        self.vocab_emb_dp = vocab_emb_dp
74
75    @property
76    def data_parallel(self):
77        return self._dp_mp_config.data_parallel
78
79    @data_parallel.setter
80    def data_parallel(self, value):
81        self._dp_mp_config.data_parallel = value
82
83    @property
84    def model_parallel(self):
85        return self._dp_mp_config.model_parallel
86
87    @model_parallel.setter
88    def model_parallel(self, value):
89        self._dp_mp_config.model_parallel = value
90
91    @property
92    def vocab_emb_dp(self):
93        return self._vocab_emb_dp
94
95    @vocab_emb_dp.setter
96    def vocab_emb_dp(self, value):
97        Validator.check_bool(value, "vocab_emb_dp")
98        self._vocab_emb_dp = value
99
100    @property
101    def dp_mp_config(self):
102        r"""
            To obtain the OpParallelConfig that holds the data parallel and model parallel settings.
104
105            Supported Platforms:
106                ``Ascend`` ``GPU``
107
108            Examples:
109                >>> config=EmbeddingOpParallelConfig(data_parallel=1, model_parallel=1, vocab_emb_dp=True)
110                >>> parallel_config = config.dp_mp_config
111        """
112        return self._dp_mp_config
113
114
115class TransformerOpParallelConfig(_Config):
116    r"""
        TransformerOpParallelConfig is the configuration for setting the global data parallel, model parallel and
        fusion group of the transformer network.
119
120        Note:
            Except for the recompute argument, the other arguments take effect only when the user sets the
            auto_parallel_context to `SEMI_AUTO_PARALLEL` or `AUTO_PARALLEL`.
            The micro_batch_num must be greater than or equal to pipeline_stage. The product
            data_parallel \* model_parallel \* pipeline_stage must be less than or equal to the number of available
            devices. When pipeline_stage or optimizer_shard is set, the config overwrites the auto_parallel_context.
126
127        Args:
128            data_parallel (int): The data parallel way. Default: 1.
129            model_parallel (int): The model parallel way. Default: 1.
130            pipeline_stage (int): The number of the pipeline stage. Should be a positive value. Default: 1.
            micro_batch_num (int): The number of micro batches used in pipeline training. Must be greater than or
                equal to pipeline_stage. Default: 1.
            optimizer_shard (bool): Whether to enable optimizer shard. Default: False.
133            gradient_aggregation_group (int): The fusion group size of the optimizer state sharding. Default: 4.
134            recompute (bool): Enable recomputation of the transformer block or not. Default: False.
            vocab_emb_dp (bool): Whether to perform the embedding lookup in data parallel (True) or to row-slice
                the embedding table in model parallel (False). Default: True.
136
137        Supported Platforms:
138            ``Ascend`` ``GPU``
139
140        Examples:
141            >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
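            >>> # A hedged sketch for an assumed 8-device setup: data_parallel * model_parallel * pipeline_stage
            >>> # = 2 * 2 * 2 = 8 <= device number, and micro_batch_num (4) >= pipeline_stage (2), per the Note above.
            >>> config_8p = TransformerOpParallelConfig(data_parallel=2, model_parallel=2, pipeline_stage=2,
            ...                                         micro_batch_num=4)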
142    """
143
144    def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False,
145                 optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
146        self.recompute = recompute
147        self.optimizer_shard = optimizer_shard
148        self.gradient_aggregation_group = gradient_aggregation_group
149        self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
150                                                             vocab_emb_dp=vocab_emb_dp)
151        self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
152
153    @property
154    def recompute(self):
155        return self._recompute
156
157    @recompute.setter
158    def recompute(self, value):
159        Validator.check_bool(value, "recompute")
160        self._recompute = value
161
162    @property
163    def vocab_emb_dp(self):
164        return self._embed_dp_mp_config.vocab_emb_dp
165
166    @vocab_emb_dp.setter
167    def vocab_emb_dp(self, value):
168        self._embed_dp_mp_config.vocab_emb_dp = value
169
170    @property
171    def gradient_aggregation_group(self):
172        return self._gradient_aggregation_group
173
174    @gradient_aggregation_group.setter
175    def gradient_aggregation_group(self, value):
176        Validator.check_positive_int(value, "gradient_aggregation_group")
177        self._gradient_aggregation_group = value
178
179    @property
180    def micro_batch_num(self):
181        return self._pp_config.micro_batch_num
182
183    @micro_batch_num.setter
184    def micro_batch_num(self, value):
185        self._pp_config.micro_batch_num = value
186
187    @property
188    def model_parallel(self):
189        return self._embed_dp_mp_config.model_parallel
190
191    @model_parallel.setter
192    def model_parallel(self, value):
193        self._embed_dp_mp_config.model_parallel = value
194
195    @property
196    def data_parallel(self):
197        return self._embed_dp_mp_config.data_parallel
198
199    @data_parallel.setter
200    def data_parallel(self, value):
201        self._embed_dp_mp_config.data_parallel = value
202
203    @property
204    def pipeline_stage(self):
205        return self._pp_config.pipeline_stage
206
207    @pipeline_stage.setter
208    def pipeline_stage(self, value):
209        self._pp_config.pipeline_stage = value
210
211    @property
212    def optimizer_shard(self):
213        return self._optimizer_shard
214
215    @optimizer_shard.setter
216    def optimizer_shard(self, value):
217        Validator.check_bool(value, "optimizer_shard")
218        self._optimizer_shard = value
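        # Note: besides storing the flag, this setter also updates the global auto parallel context below,
        # so enabling optimizer_shard turns on the parallel optimizer for the whole network.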
219        context.set_auto_parallel_context(enable_parallel_optimizer=value)
220
221    @property
222    def embedding_dp_mp_config(self):
223        r"""
            To obtain the EmbeddingOpParallelConfig that holds the data parallel, model parallel and embedding
            sharding settings.
226
227            Supported Platforms:
228                ``Ascend`` ``GPU``
229
230            Examples:
231                >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, vocab_emb_dp=True)
232                >>> parallel_config = config.embedding_dp_mp_config
233        """
234        return self._embed_dp_mp_config
235
236    @property
237    def dp_mp_config(self):
238        r"""
            To obtain the OpParallelConfig that holds the data parallel and model parallel settings.
241
242            Supported Platforms:
243                ``Ascend`` ``GPU``
244
245            Examples:
246                >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, vocab_emb_dp=True)
247                >>> parallel_config = config.dp_mp_config
248        """
249        return self._embed_dp_mp_config.dp_mp_config
250
251
252default_transformer_config = TransformerOpParallelConfig()
253default_embedding_parallel_config = EmbeddingOpParallelConfig()
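# These module-level instances are the default parallel configurations used by the layers below
# when the caller does not pass an explicit parallel_config.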
254
255
256class FeedForward(Cell):
257    r"""
    The multilayer perceptron with two linear layers and dropout applied to the final output. The first linear
    layer projects the input dimension from hidden_size to ffn_hidden_size, and the second projects it back
    from ffn_hidden_size to hidden_size. The first linear layer is sharded along its output dimension and the
    second along its input (reduction) dimension. The overall process can be described as
262
263    .. math::
        Dropout((xW_1+b_1)W_2 + b_2)
265
266    where the :math:`W_1, W_2, b_1` and :math:`b_2` are trainable parameters.
267
268    Args:
269        hidden_size (int): The dimension of the inputs.
270        ffn_hidden_size (int): The intermediate hidden size.
271        dropout_rate (float): The dropout rate for the second linear's output.
272        hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
273                         'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
274                         'hsigmoid', 'logsigmoid' and so on. Default: gelu.
        expert_num (int): The number of experts used in the Linear layers. When expert_num > 1, BatchMatMul is
            used and the first dimension of the BatchMatMul inputs indicates expert_num. Default: 1.
277        param_init_type (dtype.Number): The parameter initialization type. Should be dtype.float32 or dtype.float16.
278                                        Default: dtype.float32.
279        parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
280                                           Default `default_dpmp_config`, an instance of `OpParallelConfig` with
281                                           default args.
282
283    Inputs:
284        - **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
285          Float tensor.
286
287    Outputs:
288        Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]
289        or [batch * seq_length, hidden_size]`.
290
291    Raises:
292        ValueError: `hidden_act` is not a string.
293        TypeError: `parallel_config` is not a subclass of OpParallelConfig.
294        ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
295        ValueError: `hidden_size` is not a multiple of the model parallel way.
296
297    Supported Platforms:
298        ``Ascend`` ``GPU``
299
300    Examples:
301        >>> import numpy as np
302        >>> from mindspore.parallel.nn import FeedForward
303        >>> from mindspore import dtype as mstype
304        >>> from mindspore import Tensor
305        >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1)
306        >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
307        >>> output = model(tensor)
308        >>> print(output.shape)
309        (2, 20, 15)
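        >>> # A hedged extra check: as described above, the layer also accepts 2-D inputs of shape
        >>> # [batch * seq_length, hidden_size]; the sizes below are illustrative.
        >>> tensor_2d = Tensor(np.ones((40, 15)), mstype.float32)
        >>> output_2d = model(tensor_2d)
        >>> print(output_2d.shape)
        (40, 15)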
310    """
311
312    @_args_type_validator_check(hidden_size=Validator.check_positive_int,
313                                ffn_hidden_size=Validator.check_positive_int,
314                                dropout_rate=Validator.check_non_negative_float,
315                                hidden_act=_valid_type_checks([str], "FeedForward"),
316                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
317                                                                    "FeedForward"),
318                                parallel_config=_valid_type_checks([OpParallelConfig],
319                                                                   "FeedForward"))
320    def __init__(self, hidden_size,
321                 ffn_hidden_size,
322                 dropout_rate,
323                 hidden_act='gelu',
324                 expert_num=1,
325                 param_init_type=mstype.float32,
326                 parallel_config=default_dpmp_config):
327        super(FeedForward, self).__init__()
328        _check_config(parallel_config)
329        dp = parallel_config.data_parallel
330        mp = parallel_config.model_parallel
331        if ffn_hidden_size % mp != 0:
332            raise ValueError(f"ffn_hidden_size {ffn_hidden_size} should be a multiple of the model parallel way {mp}")
333        if hidden_size % mp != 0:
334            raise ValueError(f"hidden_size {hidden_size} should be a multiple of the model parallel way {mp}")
335        if dropout_rate < 0 or dropout_rate >= 1:
336            raise ValueError(f"dropout_rate probability should be a number in range [0, 1.0), "
337                             f"but got {dropout_rate}")
338        input_size = hidden_size
339        output_size = ffn_hidden_size
340        # Here, 'ep' stands for expert parallel number, which is equal to data parallel number.
341        ep = dp
342        # Project to ffn_hidden_size
343        self.mapping = _Linear(in_channels=input_size,
344                               out_channels=output_size,
345                               activation=hidden_act,
346                               transpose_b=False,
347                               expert_num=expert_num,
348                               param_init_type=param_init_type)
349
350        if expert_num > 1:
351            self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
352                               strategy_bias=((ep, 1, mp), (mp,)),
353                               strategy_activation=((ep, 1, mp),))
354        else:
355            self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
356                               strategy_bias=((dp, mp), (mp,)),
357                               strategy_activation=((dp, mp),))
358        # Project back to hidden_size
359        self.projection = _Linear(in_channels=output_size,
360                                  out_channels=input_size,
361                                  transpose_b=False,
362                                  expert_num=expert_num,
363                                  param_init_type=param_init_type)
364        if expert_num > 1:
365            self.projection.shard(strategy_matmul=((ep, 1, mp), (ep, mp, 1)),
366                                  strategy_bias=((ep, 1, 1), (1,)))
367        else:
368            self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
369                                  strategy_bias=((dp, 1), (1,)))
370        self.projection.bias.parallel_optimizer = False
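        # Two dropout instances are kept so that the shard strategy matches the input rank:
        # `dropout` for 2-D [bs * seq_length, hidden_size] outputs and `dropout_3d` for 3-D outputs.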
371        self.dropout = nn.Dropout(1 - dropout_rate)
372        self.dropout.dropout.shard(((dp, 1),))
373        self.dropout_3d = nn.Dropout(1 - dropout_rate)
374        self.dropout_3d.dropout.shard(((dp, 1, 1),))
375        self.cast = P.Cast()
376
377    def construct(self, x):
378        _check_input_shape(F.shape(x), "x", self.cls_name, [2, 3])
379        _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
380        x = self.cast(x, mstype.float16)
381        # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
382        hidden = self.mapping(x)
383        output = self.projection(hidden)
        # returned shape is [bs, seq_length, hidden_size] or [bs * seq_length, hidden_size]
385        if len(F.shape(output)) == 3:
386            output = self.dropout_3d(output)
387        else:
388            output = self.dropout(output)
389        return output
390
391
392class AttentionMask(Cell):
393    r"""
    Get the lower triangular attention matrix from the input mask. The input mask is a 2-D tensor of shape
    (batch_size, seq_length) with values 1 and 0, where 1 indicates a valid token and 0 indicates padding.
396
397    Args:
398        seq_length(int): The sequence length of the input tensor.
399        parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
400                                           an instance of `OpParallelConfig` with default args.
401
402    Inputs:
403        - **input_mask** (Tensor) - The mask indicating whether each position is a valid input with
404          (batch_size, seq_length).
405
406    Outputs:
407        Tensor. The attention mask matrix with shape (batch_size, seq_length, seq_length).
408
409    Raises:
410        TypeError: `seq_length` is not an integer.
411        ValueError: `seq_length` is not a positive value.
412        TypeError: `parallel_config` is not a subclass of OpParallelConfig.
413
414    Supported Platforms:
415        ``Ascend`` ``GPU``
416
417    Examples:
418        >>> import numpy as np
419        >>> from mindspore.parallel.nn import AttentionMask
420        >>> from mindspore import Tensor
421        >>> mask = AttentionMask(seq_length=4)
422        >>> mask_array = np.array([[1, 1, 1, 0]], np.float32)
423        >>> inputs = Tensor(mask_array)
424        >>> res = mask(inputs)
425        >>> print(res)
        [[[1. 0. 0. 0.]
          [1. 1. 0. 0.]
          [1. 1. 1. 0.]
          [0. 0. 0. 0.]]]
430    """
431
432    @_args_type_validator_check(seq_length=Validator.check_positive_int,
433                                parallel_config=_valid_type_checks([OpParallelConfig], "AttentionMask"))
434    def __init__(self, seq_length, parallel_config=default_dpmp_config):
435        super(AttentionMask, self).__init__()
436        self.seq_length = seq_length
437        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
438        self.reshape = P.Reshape()
439        self.mul = P.BatchMatMul().shard(
440            ((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
441        self.expand_dim = P.ExpandDims().shard(((1, 1),))
442        ones = np.ones(shape=(seq_length, seq_length))
443        # Default lower triangle mask matrix
444        self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
445        self.multiply = P.Mul().shard(((parallel_config.data_parallel, 1, 1), (1, 1, 1)))
446
447    def construct(self, input_mask):
448        _check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 2)
449        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
450        _check_input_shape_value(F.shape(input_mask), 1, "input_mask", self.cls_name, self.seq_length)
451        input_mask = P.Cast()(self.not_equal(input_mask, 0), mstype.float16)
452        input_shape = P.Shape()(input_mask)
453        shape_right = (input_shape[0], 1, input_shape[1])
454        shape_left = input_shape + (1,)
455        # Mask the padded inputs
456        mask_left = self.reshape(input_mask, shape_left)
457        mask_right = self.reshape(input_mask, shape_right)
458        attention_mask = self.mul(mask_left, mask_right)
        lower_triangle = self.expand_dim(self.lower_triangle_mask, 0)
        # the returned shape is [bs, seq_length, seq_length]
        attention_mask = self.multiply(
            attention_mask, lower_triangle)
463        return attention_mask
464
465
466class VocabEmbedding(Cell):
467    """
    The embedding lookup table built from the 0-th dimension of the parameter table. When
    parallel_config.vocab_emb_dp is True and the network runs in `AUTO_PARALLEL_MODE`, the embedding lookup is
    performed in a `parallel_config.data_parallel` data parallel way; otherwise the parameter is sharded at the
    0-th dimension into `parallel_config.model_parallel` slices, the so-called row slice of the embedding table.
472
473    Args:
474        vocab_size (int): Size of the dictionary of embeddings.
475        embedding_size (int): The size of each embedding vector.
476        param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
477            Refer to class `initializer` for the values of string when a string
478            is specified. Default: 'normal'.
479        parallel_config(EmbeddingOpParallelConfig): The parallel config of network. Default
480            `default_embedding_parallel_config`, an instance of `EmbeddingOpParallelConfig` with default args.
481
482    Inputs:
        - **input_ids** (Tensor) - The tokenized inputs with datatype int32 and shape (batch_size, seq_length).
484
485    Outputs:
486        Tuple, a tuple contains (`output`, `embedding_table`)
487
488        - **output** (Tensor) - The embedding vector for the input with shape (batch_size,
489          seq_length, embedding_size).
        - **embedding_table** (Tensor) - The embedding table with shape (vocab_size, embedding_size).
491
492    Raises:
        ValueError: If parallel_config.vocab_emb_dp is False and the vocab size is not a multiple of
            parallel_config.model_parallel.
495        ValueError: `vocab_size` is not a positive value.
496        ValueError: `embedding_size` is not a positive value.
497        TypeError: `parallel_config` is not a subclass of OpParallelConfig.
498
499    Supported Platforms:
500        ``Ascend`` ``GPU``
501
502    Examples:
503        >>> import numpy as np
504        >>> from mindspore.parallel.nn import VocabEmbedding
505        >>> from mindspore import Tensor
506        >>> from mindspore import dtype as mstype
507        >>> model = VocabEmbedding(vocab_size=30, embedding_size=30)
508        >>> tensor = Tensor(np.ones((20, 15)), mstype.int32)
509        >>> output, table = model(tensor)
510        >>> print(output.shape)
511        (20, 15, 30)
512        >>> print(table.shape)
513        (30, 30)
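        >>> # A hedged sketch of the row-slice mode described above (import path assumed): with vocab_emb_dp=False
        >>> # the table is sharded along dimension 0, so vocab_size must be a multiple of model_parallel.
        >>> from mindspore.parallel.nn import EmbeddingOpParallelConfig
        >>> mp_config = EmbeddingOpParallelConfig(data_parallel=1, model_parallel=1, vocab_emb_dp=False)
        >>> model = VocabEmbedding(vocab_size=30, embedding_size=30, parallel_config=mp_config)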
514    """
515
516    @_args_type_validator_check(vocab_size=Validator.check_positive_int,
517                                embedding_size=Validator.check_positive_int,
518                                parallel_config=_valid_type_checks([EmbeddingOpParallelConfig], "VocabEmbedding"))
519    def __init__(self, vocab_size, embedding_size, parallel_config=default_embedding_parallel_config,
520                 param_init='normal'):
521        super(VocabEmbedding, self).__init__()
522        _check_config(parallel_config)
523        self.vocab_size = vocab_size
524        self.embedding_size = embedding_size
525        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
526                                         name='embedding_table', parallel_optimizer=False)
527        if parallel_config.vocab_emb_dp:
528            self.gather = P.GatherV2().shard(((1, 1), (parallel_config.data_parallel, 1)))
529            logger.info(f"Using {parallel_config.data_parallel} data parallel for the embedding lookup.")
530        else:
531            if self.vocab_size % parallel_config.model_parallel != 0:
532                raise ValueError(f"The vocab size of the embedding {self.vocab_size} must be a "
533                                 f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
534            self.gather = P.GatherV2().shard(((parallel_config.model_parallel, 1), (1, 1)))
            logger.info(f"Using {parallel_config.model_parallel} model parallel for the embedding lookup.")
536
537    def construct(self, input_ids):
538        _check_input_shape(F.shape(input_ids), "input_ids", self.cls_name, 2)
539        _check_input_dtype(F.dtype(input_ids), "input_ids", [mstype.int32], self.cls_name)
540        output = self.gather(self.embedding_table, input_ids, 0)
541        return output, self.embedding_table
542
543
544class MultiHeadAttention(Cell):
545    r"""
546    This is an implementation of multihead attention in the paper `Attention is all you need
547    <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector with source length, and the
    key and value vectors with target length, the attention is performed as follows
549
550    .. math::
           MultiHeadAttention(query, key, value) = Concat(head_1, \dots, head_h)W^O
552
553    where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. The default is with a bias.
554
    If the query, key and value tensors are the same, the layer performs self-attention.
556
557    Args:
558        batch_size(int): The batch size of the input tensor.
559        src_seq_length(int): The sequence length of the query vector.
560        tgt_seq_length(int): The sequence length of the key and value vector.
561        hidden_size(int): The hidden size of the input.
562        num_heads(int): The number of the heads.
563        hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1
564        attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1
565        compute_dtype(dtype.Number): The computation type of dense. Default dtype.float16.
566            Should be dtype.float32 or dtype.float16.
567        param_init_type(dtype.Number): The parameter initialization type of the module. Default dtype.float32.
568            Should be dtype.float32 or dtype.float16.
569        softmax_compute_type(dtype.Number): The type of softmax computation module. Default dtype.float32.
570            Should be dtype.float32 or dtype.float16.
        use_past(bool): Whether to use the past state for incremental prediction. For example, if we have two
            words and want to generate ten more words, we only need to compute the two words' state once and then
            generate the following words one by one. When use_past is True, there are two steps to run the
            prediction. In the first step, set is_first_iteration to True by
            `model.add_flags_recursive(is_first_iteration=True)` and pass the full inputs. Then, set
            is_first_iteration to False by `model.add_flags_recursive(is_first_iteration=False)`. From this moment
            on, pass the single step's input tensor and loop over it. Default: False.
578        parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
579                                           an instance of `OpParallelConfig` with default args.
580
581    Inputs:
582        - **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size) or
583          (batch_size * src_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. Otherwise,
584          must be (batch_size, 1, hidden_size)
585        - **key_tensor** (Tensor) - the key vector with shape (batch_size, tgt_seq_length, hidden_size) or
586          (batch_size * tgt_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. Otherwise,
587          must be (batch_size, 1, hidden_size)
588        - **value_tensor** (Tensor) - the value vector with shape (batch_size, tgt_seq_length, hidden_size) or
589          (batch_size * tgt_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. Otherwise,
590          must be (batch_size, 1, hidden_size)
591        - **attention_mask** (Tensor) - the attention mask matrix with shape (batch_size, src_seq_length,
592          tgt_seq_length), if the use_past is False or is_first_iteration=True. Otherwise,
593          must be (batch_size, 1, tgt_seq_length)
594        - **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length).
595          The past calculated key vector. Used for incremental prediction when the use_past is True.
596          Default None.
597        - **value_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, tgt_seq_length, size_per_head).
598          The past calculated value vector. Used for incremental prediction when the use_past is True.
599          Default None.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) recording the index (valid
          length) already computed. Used for incremental prediction when the use_past is True. Default None.
602
603    Outputs:
        Tuple, a tuple contains (`output`, `layer_present`)
605
606        - **output** (Tensor) - Tensor, the float tensor of the output of the layer with
607          shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size),
608          if the use_past is False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size).
609
610        - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
611          ((batch_size, num_heads, size_per_head, tgt_seq_length),
612          (batch_size, num_heads, tgt_seq_length, size_per_head)).
613
614    Supported Platforms:
615        ``Ascend`` ``GPU``
616
617    Examples:
618        >>> import numpy as np
619        >>> from mindspore.parallel.nn import MultiHeadAttention
620        >>> from mindspore import dtype as mstype
621        >>> from mindspore import Tensor
622        >>> model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
623        ...                            num_heads=3)
624        >>> from_tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
625        >>> to_tensor = Tensor(np.ones((2, 20, 15)), mstype.float16)
626        >>> attention_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
627        >>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask)
628        >>> print(attn_out.shape)
629        (2, 20, 15)
630        >>> print(past[0].shape)
631        (2, 3, 5, 20)
632        >>> print(past[1].shape)
633        (2, 3, 20, 5)
        # When use_past=True, it takes two steps to run the incremental prediction.
        # Step 1: set is_first_iteration=True, and input the full sequence length's state.
        # We need to prepare the memory parameters for saving key and value states first.
637        >>> model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
638        ...                            num_heads=3, use_past=True)
639        >>> key_past = Tensor(np.zeros(shape=(2, 3, 5, 20)), mstype.float16)
640        >>> value_past = Tensor(np.zeros(shape=(2, 3, 20, 5)), mstype.float16)
641        >>> batch_valid_length = Tensor(np.ones((2,)), mstype.int32)
642        # Set is_first_iteration=True to generate the full memory states
643        >>> model.add_flags_recursive(is_first_iteration=True)
644        >>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past,
645        ...                        batch_valid_length)
646        >>> print(attn_out.shape)
647        (2, 20, 15)
648        >>> print(past[0].shape)
649        (2, 3, 5, 20)
650        >>> print(past[1].shape)
651        (2, 3, 20, 5)
652        >>> from_tensor = Tensor(np.ones((2, 1, 15)), mstype.float32)
653        >>> to_tensor = Tensor(np.ones((2, 1, 15)), mstype.float16)
654        >>> attention_mask = Tensor(np.ones((2, 1, 20)), mstype.float16)
655        # Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than the full
656        # sequence.
657        >>> model.add_flags_recursive(is_first_iteration=False)
658        >>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past,
659        ...                        batch_valid_length)
660        >>> print(attn_out.shape)
661        (2, 1, 15)
662        >>> print(past[0].shape)
663        (2, 3, 5, 20)
664        >>> print(past[1].shape)
665        (2, 3, 20, 5)
666    """
667
668    @_args_type_validator_check(batch_size=Validator.check_positive_int,
669                                hidden_size=Validator.check_positive_int,
670                                num_heads=Validator.check_positive_int,
671                                src_seq_length=Validator.check_positive_int,
672                                tgt_seq_length=Validator.check_positive_int,
673                                attention_dropout_rate=Validator.check_non_negative_float,
674                                hidden_dropout_rate=Validator.check_non_negative_float,
675                                compute_dtype=_valid_value_checks([mstype.float32, mstype.float16],
676                                                                  "MultiHeadAttention"),
677                                softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
678                                                                         "MultiHeadAttention"),
679                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
680                                                                    "MultiHeadAttention"),
681                                parallel_config=_valid_type_checks([OpParallelConfig],
682                                                                   "MultiHeadAttention"),
683                                use_past=Validator.check_bool)
684    def __init__(self, batch_size,
685                 src_seq_length,
686                 tgt_seq_length,
687                 hidden_size,
688                 num_heads,
689                 hidden_dropout_rate=0.1,
690                 attention_dropout_rate=0.1,
691                 compute_dtype=mstype.float16,
692                 softmax_compute_type=mstype.float32,
693                 param_init_type=mstype.float32,
694                 use_past=False,
695                 parallel_config=default_dpmp_config):
696        super(MultiHeadAttention, self).__init__()
697        _check_config(parallel_config)
698        self.is_parallel_mode = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
699        self.src_seq_length = src_seq_length
700        self.tgt_seq_length = tgt_seq_length
701        self.hidden_size = hidden_size
702        self.batch_size = batch_size
703        if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
704            raise ValueError(f"hidden_dropout_rate probability should be a number in range [0, 1.0), "
705                             f"but got {hidden_dropout_rate}")
706        if attention_dropout_rate < 0 or attention_dropout_rate >= 1:
707            raise ValueError(f"attention_dropout_rate probability should be a number in range [0, 1.0), "
708                             f"but got {attention_dropout_rate}")
709        if hidden_size % num_heads != 0:
710            raise ValueError(f"The hidden size {hidden_size} should be a multiple of num_heads {num_heads}")
711        if num_heads % parallel_config.model_parallel != 0:
712            raise ValueError(f"The number of heads {num_heads} must be a "
713                             f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
714        if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
715            raise ValueError(f"The batch size {batch_size} must be a "
716                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
717        self.is_first_iteration = True
718        # Output layer
719        self.projection = _Linear(in_channels=hidden_size,
720                                  out_channels=hidden_size,
721                                  transpose_b=False,
722                                  param_init_type=param_init_type).to_float(compute_dtype)
723        self.projection.shard(strategy_bias=((parallel_config.data_parallel, 1), (1,)),
724                              strategy_matmul=((parallel_config.data_parallel, parallel_config.model_parallel),
725                                               (parallel_config.model_parallel, 1)))
726        self.projection.bias.parallel_optimizer = False
727        self.transpose = P.Transpose().shard(((parallel_config.data_parallel, 1, parallel_config.model_parallel, 1),))
728        self.merger_head_transpose = P.Transpose().shard(
729            ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
730        self.reshape = P.Reshape()
731        self.n_head = num_heads
732        # embedding size per head
733        self.size_per_head = hidden_size // self.n_head
734        self.concat_k = P.Concat(axis=3)
735        self.concat_v = P.Concat(axis=2)
736        self.multiply_data = Tensor([
737            -10000.0,
738        ], dtype=softmax_compute_type)
739        self.batch_matmul = P.BatchMatMul().shard(
740            ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),
741             (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1)))
742        self.real_div = P.RealDiv().shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1), ()))
743        self.sub = P.Sub().shard(
744            ((1,), (parallel_config.data_parallel, 1, 1, 1)))
745        self.mul = P.Mul().shard(
746            ((parallel_config.data_parallel, 1, 1, 1), (1,)))
747        self.add = P.Add().shard(
748            ((parallel_config.data_parallel, 1, 1, 1),
749             (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1)))
750        # Normalize factor for attention, sqrt(dk) as widely used
751        self.scale_factor = Tensor(math.sqrt(self.size_per_head))
752        self.use_past = use_past
753        self.dropout = nn.Dropout(1 - hidden_dropout_rate)
754        self.dropout.dropout.shard(((parallel_config.data_parallel, 1),))
755        self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
756        self.prob_dropout.dropout.shard(
757            ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
758        self.softmax = nn.Softmax().to_float(softmax_compute_type)
759        self.softmax.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1),))
760        self.expand_dims = P.ExpandDims().shard(((parallel_config.data_parallel, 1, 1),))
761
762        # Query
763        self.dense1 = _Linear(hidden_size,
764                              hidden_size,
765                              param_init_type=param_init_type).to_float(compute_dtype)
766        self.dense1.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
767                          strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
768                                         (parallel_config.model_parallel,)))
769        # Key
770        self.dense2 = _Linear(hidden_size,
771                              hidden_size,
772                              param_init_type=param_init_type).to_float(compute_dtype)
773        self.dense2.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
774                          strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
775                                         (parallel_config.model_parallel,)))
776
777        # Value
778        self.dense3 = _Linear(hidden_size,
779                              hidden_size,
780                              param_init_type=param_init_type).to_float(compute_dtype)
781        self.dense3.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
782                          strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
783                                         (parallel_config.model_parallel,)))
784        self.dtype = compute_dtype
785        self.softmax_dtype = softmax_compute_type
786        if self.use_past:
787            # operators used for state reuse
788            seq_range = np.arange(src_seq_length).reshape(1, 1, -1)
789            self.range = Tensor(np.tile(seq_range, (batch_size, 1, 1)), mstype.int32)
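            # `range` holds the position indices [0, src_seq_length) for every batch item; it is compared against
            # batch_valid_length to build masks over the cached key and value states.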
790            self.seq_length = src_seq_length
791            self.attention_mask = Tensor(np.tril(np.ones(shape=(self.seq_length, self.seq_length))), mstype.int32)
792            self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
793            self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
794            self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
795            self.expand_dims = P.ExpandDims().shard(((1, 1, 1),))
796            self.tensor_le = P.LessEqual().shard(((1, 1, 1), (1, 1, 1)))
797            self.add = P.Add().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
798            self.equal = P.Equal().shard(((1, 1, 1), (1, 1, 1)))
799            self.sub1 = P.Sub().shard(((1,), ()))
800            self.tile = P.Tile().shard(((1, 1, 1, 1),))
801            self.less = P.Less().shard(((1, 1, 1), (1, 1, 1)))
802            self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
803
804    def construct(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
805                  value_past=None, batch_valid_length=None):
806        self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
807                           value_past, batch_valid_length)
808        query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
809                                                                                                   key_tensor,
810                                                                                                   value_tensor,
811                                                                                                   attention_mask)
812
813        # multi head attention: query, key, value are derived from the same inputs
814        query = self.dense1(query_tensor)
815        key = self.dense2(key_tensor)
816        value = self.dense3(value_tensor)
817        # the returned shape is [bs, num_heads, seq_length, size_per_head]
818        query = self.transpose(
819            F.reshape(
820                query,
821                (batch_size, -1, self.n_head, self.size_per_head)),
822            (0, 2, 1, 3))
        # the returned shape is [bs, num_heads, size_per_head, seq_length]
824        key = self.transpose(
825            F.reshape(
826                key, (batch_size, -1, self.n_head, self.size_per_head)),
827            (0, 2, 3, 1))
828        # the returned shape is [bs, num_heads, seq_length, size_per_head]
829        value = self.transpose(
830            F.reshape(
831                value,
832                (batch_size, -1, self.n_head, self.size_per_head)),
833            (0, 2, 1, 3))
834        # support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
835        if len(F.shape(attention_mask)) == 3:
836            # expand attention mask from [bs, seq, seq] -> [bs, 1, seq, seq]
837            attention_mask = self.expand_dims(attention_mask, 1)
838        # key and value for current token(s)
839        key_present = key
840        value_present = value
841        if self.use_past:
842            # The first graph with the input size of (bs, seq_length)
843            if self.is_first_iteration:
844                # Get the valid input length without padding
845                valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(-1, 1, 1)), self.dtype)
                # Mask out the key and value entries that correspond to the padding positions
847                key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
848                value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
            # The second graph with the input size of (bs, 1)
850            # the shape of query is (bs, num_heads, 1, size_per_head)
851            # the shape of key is   (bs, num_heads, size_per_head, 1)
852            # the shape of value is (bs, num_heads, 1, size_per_head)
853            else:
854                # Get the current token position index
855                valid_length = self.reducesum(F.cast(self.not_equal(self.slice(key_past, (0, 0, 0, 0),
856                                                                               (F.shape(key_tensor)[0], 1, 1,
857                                                                                self.src_seq_length),
858                                                                               (1, 1, 1, 1)),
859                                                                    0), mstype.float32), (1, 2, 3))
860                valid_length = F.reshape(valid_length, (-1, 1, 1))
861                valid_length_vector = F.cast(self.equal(valid_length, self.range), self.dtype)
                # Pad the key and value to seq_length so that only the current position index is non-zero
863                current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)),
864                                        self.expand_dims(valid_length_vector, 2))
865                current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)),
866                                          self.expand_dims(valid_length_vector, 3))
867                # Concat the previous saved state and current state
868                key = self.add(key_past, current_key)
869                value = self.add(value_past, current_value)
870                # Update key_present and value_present for state update
871                key_present = key
872                value_present = value
873                attention_mask = F.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1))
874
875        layer_present = (key_present, value_present)
876        # multi head attention considering attention mask
877        # the return shape is [bs * seq_length, hidden_size]
878        attention = self._attn(query, key, value, attention_mask)
879        # Output
880        output = self.projection(attention)
881        output = self.dropout(output)
882        output = F.reshape(output, ori_shape)
883        return output, layer_present
884
885    def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
886                      value_past=None, batch_valid_length=None):
887        r"""Check inputs"""
888        if not self.use_past or (self.use_past and self.is_first_iteration):
889            _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
890                               [[self.batch_size, self.src_seq_length, self.hidden_size],
891                                [self.batch_size * self.src_seq_length, self.hidden_size]])
892            _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
893                               [[self.batch_size, self.tgt_seq_length, self.hidden_size],
894                                [self.batch_size * self.tgt_seq_length, self.hidden_size]])
895            _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
896                               [[self.batch_size, self.tgt_seq_length, self.hidden_size],
897                                [self.batch_size * self.tgt_seq_length, self.hidden_size]])
898            _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
899                               [self.batch_size, self.src_seq_length, self.tgt_seq_length])
900        else:
901            _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
902                               [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
903            _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
904                               [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
905            _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
906                               [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
907            _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
908                               [[self.batch_size, 1, self.tgt_seq_length], [self.batch_size, self.hidden_size]])
909
910        _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
911        _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
912        _check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
913        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
914
915        key_is_tensor = isinstance(key_past, Tensor)
916        value_is_tensor = isinstance(value_past, Tensor)
917        batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
918        key_is_default = key_past is None
919        value_is_default = value_past is None
920        batch_is_default = batch_valid_length is None
921        _check_past_none_input_none(self.use_past, "key_past", self.cls_name, None, key_is_tensor,
922                                    key_is_default)
923        _check_past_none_input_none(self.use_past, "value_past", self.cls_name, None, value_is_tensor,
924                                    value_is_default)
925        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
926                                    batch_valid_length_is_tensor, batch_is_default)
927        if self.use_past:
928            _check_shape_equal(F.shape(key_past), "key_past", self.cls_name,
929                               [self.batch_size, self.n_head, self.size_per_head, self.tgt_seq_length])
930            _check_input_dtype(F.dtype(key_past), "key_past", [mstype.float16], self.cls_name)
931            _check_shape_equal(F.shape(value_past), "value_past", self.cls_name,
932                               [self.batch_size, self.n_head, self.tgt_seq_length, self.size_per_head])
933            _check_input_dtype(F.dtype(value_past), "value_past", [mstype.float16], self.cls_name)
934            _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
935            _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
936        return True
937
938    def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask):
939        """convert a nd tensor to a 2d tensor"""
940        query_shape = F.shape(query_tensor)
941        query_tensor = F.reshape(query_tensor, (-1, query_shape[-1]))
942        key_shape = F.shape(key_tensor)
943        key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
944        value_shape = F.shape(value_tensor)
945        value_tensor = F.reshape(value_tensor, (-1, value_shape[-1]))
946        return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape
947
948    def _merge_heads(self, x):
949        """
950        convert a 4d input to a 2d output
951
952        Inputs:
953            x: input tensor
954
955        Output:
956            x_merge: the 2d output
957        """
958        x = self.merger_head_transpose(
959            x, (0, 2, 1, 3))  # bs, seq_length, head, size_per_head
960        x_shape = P.Shape()(x)
961        new_shape = (-1, x_shape[-2] * x_shape[-1])
962        x_merge = self.reshape(x, new_shape)
963        return x_merge
964
965    def _attn(self, query, key, value, attention_mask):
966        """
967        Get the weighted score along the seq_length
968
969        Inputs:
970            query: the query matrix
971            key: the key matrix
972            value: the value matrix
973            attention_mask: the attention mask matrix with shape (batch_size,
974            1, seq_length, seq_length)
975        Outputs:
976            weighted_values: Tensor, the weighted sum scores
977        """
978        # Normalize query and key before MatMul, default off
979        # Attention score [bs, num_heads, seq_length, seq_length]
980        score = self.batch_matmul(query, key)
981        # Normalize after query and key MatMul
982        score = self.real_div(
983            score,
984            P.Cast()(self.scale_factor, P.DType()(score)))
985
986        ori_dtype = P.DType()(score)
987        score = P.Cast()(score, self.softmax_dtype)
988
989        # for input size of (bs, 1) namely the second graph,
990        # the shape of attention_mask matrix should be (bs, 1, 1, seq_length)
991        if self.use_past and not self.is_first_iteration:
992            # Calculate the current total token
993            current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
994                                                                            (F.shape(query)[0], 1, 1, self.seq_length),
995                                                                            (1, 1, 1, 1)),
996                                                                 0), mstype.float32), (1, 2, 3))
997            # Get the precise position index
998            index = self.sub1(F.cast(current_index, mstype.int32), 1)
999            index = F.reshape(index, (-1, 1, 1))
1000            # Calculate the attention_mask matrix via the position index
1001            attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
1002            attention_mask = self.expand_dims(attention_mask, 2)
1003
        # Add -10000 at the masked positions (where attention_mask is 0) to exclude them from the softmax
        multiply_out = self.sub(
            P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
            P.Cast()(attention_mask, P.DType()(score)))

        adder = self.mul(multiply_out, self.multiply_data)
1010        attention_scores = self.add(adder, score)
1011
1012        shape = F.shape(attention_scores)
1013        # attention probs
1014        attention_probs = self.softmax(
1015            F.reshape(attention_scores,
1016                      (shape[0], -1, shape[-1])))
1017        attention_probs = P.Cast()(attention_probs, ori_dtype)
1018        attention_probs = F.reshape(attention_probs, shape)
1019
1020        attention_probs = self.prob_dropout(attention_probs)
1021        # Weighted sum output [bs, num_heads, seq_length, size_per_head]
1022        weighted_values = self.batch_matmul(attention_probs, value)
1023        attention_merge = self._merge_heads(weighted_values)
1024        return attention_merge
1025
1026
1027class TransformerEncoderLayer(Cell):
1028    r"""
1029    Transformer Encoder Layer. This is an implementation of the single layer of the transformer
    encoder layer, including multihead attention and feedforward layer.
1031
1032    Args:
1033        batch_size(int): The batch size of the input tensor.
1034        hidden_size(int): The hidden size of the input.
1035        seq_length(int): The input sequence length.
1036        ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
1037        num_heads(int): The number of the heads.
        hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default: 0.1.
        attention_dropout_rate(float): The dropout rate of the attention scores. Default: 0.1.
        post_layernorm_residual(bool): Whether to apply the residual add before the layernorm. Default: False.
1041        hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
1042                         'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
1043                         'hsigmoid', 'logsigmoid' and so on. Default: gelu.
1044        layernorm_compute_type(dtype.Number): The computation type of the layernorm.
1045            Should be dtype.float32 or dtype.float16. Default dtype.float32.
1046        softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
1047            Should be dtype.float32 or dtype.float16. Default mstype.float32.
1048        param_init_type(dtype.Number): The parameter initialization type of the module.
1049            Should be dtype.float32 or dtype.float16. Default dtype.float32.
        use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have
            two words and want to generate ten more, we only need to compute the state of the two words once and
            then generate the following words one by one. When use_past is True, the prediction runs in two steps.
            First, set is_first_iteration to True by `model.add_flags_recursive(is_first_iteration=True)` and pass
            the full inputs. Then, set is_first_iteration to False by
            `model.add_flags_recursive(is_first_iteration=False)`, pass the single step's input tensor, and loop
            over it. Default: False.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
1058        parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
1059                                           an instance of `OpParallelConfig` with default args.
1060
1061    Inputs:
1062        - **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
1063          [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
1064          should be [batch_size, 1, hidden_size]
        - **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length],
          if the use_past is False or is_first_iteration=True. Otherwise, should be [batch_size, 1, seq_length].
        - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
          past value parameter used in the incremental prediction. Only valid when use_past is True. Default: True.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the tokens that have
          already been calculated. Used for incremental prediction when use_past is True. Default: None.
1071
1072    Outputs:
        Tuple, a tuple containing (`output`, `layer_present`).
1074
1075        - **output** (Tensor) - The float tensor of the output of the layer with
1076          shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is
1077          False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size)
1078
1079        - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
1080          ((batch_size, num_heads, size_per_head, seq_length),
1081          (batch_size, num_heads, seq_length, size_per_head)).
1082
1083    Supported Platforms:
1084        ``Ascend`` ``GPU``
1085
1086    Examples:
1087        >>> import numpy as np
1088        >>> from mindspore import dtype as mstype
1089        >>> from mindspore.parallel.nn import TransformerEncoderLayer
1090        >>> from mindspore import Tensor
1091        >>> model = TransformerEncoderLayer(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
1092        ...                                 num_heads=2)
1093        >>> encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
1094        >>> encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
1095        >>> output, past = model(encoder_input_value, encoder_input_mask)
1096        >>> print(output.shape)
1097        (2, 16, 8)
1098        >>> print(past[0].shape)
1099        (2, 2, 4, 16)
1100        >>> print(past[1].shape)
1101        (2, 2, 16, 4)
        # When using use_past=True, the incremental prediction runs in two steps.
1103        # Step 1: set is_first_iteration=True, and input the full sequence length's state.
1104        >>> batch_valid_length = Tensor(np.ones((2,)), mstype.int32)
1105        >>> init_reset = Tensor([True], mstype.bool_)
1106        # Set is_first_iteration=True to generate the full memory states
1107        >>> model = TransformerEncoderLayer(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
1108        ...                                 num_heads=2, use_past=True)
1109        >>> model.add_flags_recursive(is_first_iteration=True)
1110        >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length)
1111        >>> print(hidden.shape)
1112        (2, 16, 8)
1113        >>> print(past[0].shape)
1114        (2, 2, 4, 16)
1115        >>> print(past[1].shape)
1116        (2, 2, 16, 4)
1117        >>> encoder_input_value = Tensor(np.ones((2, 1, 8)), mstype.float32)
1118        >>> encoder_input_mask = Tensor(np.ones((2, 1, 16)), mstype.float16)
1119        >>> init_reset = Tensor([False], mstype.bool_)
1120        # Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than the full
1121        # sequence.
1122        >>> model.add_flags_recursive(is_first_iteration=False)
1123        >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length)
1124        >>> print(hidden.shape)
1125        (2, 1, 8)
1126        >>> print(past[0].shape)
1127        (2, 2, 4, 16)
1128        >>> print(past[1].shape)
1129        (2, 2, 16, 4)
1130    """
1131
1132    @_args_type_validator_check(batch_size=Validator.check_positive_int,
1133                                hidden_size=Validator.check_positive_int,
1134                                num_heads=Validator.check_positive_int,
1135                                ffn_hidden_size=Validator.check_positive_int,
1136                                seq_length=Validator.check_positive_int,
1137                                attention_dropout_rate=Validator.check_non_negative_float,
1138                                hidden_dropout_rate=Validator.check_non_negative_float,
1139                                hidden_act=_valid_type_checks([str], "TransformerEncoderLayer"),
1140                                post_layernorm_residual=Validator.check_bool,
1141                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1142                                                                           "TransformerEncoderLayer"),
1143                                softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1144                                                                         "TransformerEncoderLayer"),
1145                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
1146                                                                    "TransformerEncoderLayer"),
1147                                parallel_config=_valid_type_checks([OpParallelConfig],
1148                                                                   "TransformerEncoderLayer"),
1149                                use_past=Validator.check_bool)
1150    def __init__(self,
1151                 batch_size,
1152                 hidden_size,
1153                 ffn_hidden_size,
1154                 num_heads,
1155                 seq_length,
1156                 attention_dropout_rate=0.1,
1157                 hidden_dropout_rate=0.1,
1158                 post_layernorm_residual=False,
1159                 layernorm_compute_type=mstype.float32,
1160                 softmax_compute_type=mstype.float32,
1161                 param_init_type=mstype.float32,
1162                 hidden_act='gelu',
1163                 use_past=False,
1164                 moe_config=default_moe_config,
1165                 parallel_config=default_dpmp_config):
1166        super(TransformerEncoderLayer, self).__init__()
1167        _check_config(parallel_config)
        if num_heads % parallel_config.model_parallel != 0:
            raise ValueError(
                f"num_heads must be divisible by the model parallel way {parallel_config.model_parallel}, "
                f"but found {num_heads}")
        if hidden_size % parallel_config.model_parallel != 0:
            raise ValueError(
                f"hidden_size must be divisible by the model parallel way {parallel_config.model_parallel}, "
                f"but found {hidden_size}")
        if ffn_hidden_size % parallel_config.model_parallel != 0:
            raise ValueError(
                f"ffn_hidden_size must be divisible by the model parallel way {parallel_config.model_parallel}, "
                f"but found {ffn_hidden_size}")
1180        self.use_past = use_past
1181        self.seq_length = seq_length
1182        self.hidden_size = hidden_size
1183        self.batch_size = batch_size
1184        self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1185        self.layernorm1.shard(((parallel_config.data_parallel, 1),))
1186        self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1187        self.layernorm2.shard(((parallel_config.data_parallel, 1),))
1188
1189        self.attention = MultiHeadAttention(batch_size=batch_size,
1190                                            src_seq_length=seq_length,
1191                                            tgt_seq_length=seq_length,
1192                                            hidden_size=hidden_size,
1193                                            num_heads=num_heads,
1194                                            hidden_dropout_rate=hidden_dropout_rate,
1195                                            attention_dropout_rate=attention_dropout_rate,
1196                                            softmax_compute_type=softmax_compute_type,
1197                                            param_init_type=param_init_type,
1198                                            use_past=use_past,
1199                                            parallel_config=parallel_config)
1200        self.use_moe = (moe_config.expert_num > 1)
1201        if self.use_moe is True:
1202            self.output = MoE(hidden_size=hidden_size,
1203                              dropout_rate=hidden_dropout_rate,
1204                              ffn_hidden_size=ffn_hidden_size,
1205                              param_init_type=param_init_type,
1206                              hidden_act=hidden_act,
1207                              moe_config=moe_config,
1208                              parallel_config=parallel_config)
1209        else:
1210            # Feed Forward Network, FFN
1211            self.output = FeedForward(hidden_size=hidden_size,
1212                                      dropout_rate=hidden_dropout_rate,
1213                                      ffn_hidden_size=ffn_hidden_size,
1214                                      param_init_type=param_init_type,
1215                                      hidden_act=hidden_act,
1216                                      parallel_config=parallel_config)
1217        self.post_layernorm_residual = post_layernorm_residual
1218        self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
1219        self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
1220        self.dtype = mstype.float16
1221        self.key_past = None
1222        self.value_past = None
1223
1224        if self.use_past:
1225            # operator used for state reuse
1226            self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
1227            self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
1228            self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
1229            size_per_head = int(hidden_size / num_heads)
1230            self.key_shape = (batch_size, num_heads, size_per_head, seq_length)
1231            self.value_shape = (batch_size, num_heads, seq_length, size_per_head)
1232            # parameters saving key and value states
1233            self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
1234            self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
1235            self.tile = P.Tile().shard(((1, 1),))
1236            self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
1237            self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
1238
1239    def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
1240        self._check_input(x, input_mask, init_reset, batch_valid_length)
1241        x_shape = F.shape(x)
1242        x = F.reshape(x, (-1, x_shape[-1]))
1243        input_x = self.layernorm1(x)
1244        input_x = F.cast(input_x, self.dtype)
1245
1246        # indicate whether reset saved states
1247        key_reset = None
1248        value_reset = None
1249
1250        if self.use_past:
1251            # reset states, init_reset True for reuse and False for reset
1252            key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
1253            value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
1254            # add dependency for desired execution order
1255            input_x = F.depend(input_x, key_reset)
1256            input_x = F.depend(input_x, value_reset)
1257
1258        attention, layer_present = self.attention(input_x, input_x, input_x, input_mask,
1259                                                  self.key_past, self.value_past, batch_valid_length)
1260        # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm
1261        if self.post_layernorm_residual:
1262            x = self.add(input_x, attention)
1263        # For pre-layernorm the inputs for residual path are output of self-attention and input of this layer
1264        else:
1265            x = self.add(x, attention)
1266
1267        output_x = self.layernorm2(x)
1268        output_x = F.cast(output_x, self.dtype)
1269        aux_loss = None
1270        if self.use_moe is True:
1271            mlp_logit, aux_loss = self.output(output_x)
1272        else:
1273            mlp_logit = self.output(output_x)
1274
1275        value_update = None
1276        key_update = None
1277        if self.use_past:
1278            # current key and value
1279            key_present, value_present = layer_present
1280            # update key and value calculated this step
1281            key_update = self.assign(self.key_past, key_present)
1282            value_update = self.assign(self.value_past, value_present)
1283            # add dependency for desired execution order
1284            key_update = F.depend(key_update, key_reset)
1285            value_update = F.depend(value_update, value_reset)
1286
1287        # add dependency for desired execution order
1288        mlp_logit = F.depend(mlp_logit, value_update)
1289        mlp_logit = F.depend(mlp_logit, key_update)
1290
1291        # if shape is 3d, we reshape the inputs of the add
1292        if len(x_shape) == 3:
1293            output_x = P.Reshape()(output_x, x_shape)
1294            mlp_logit = P.Reshape()(mlp_logit, x_shape)
1295            x = P.Reshape()(x, x_shape)
1296
1297            if self.post_layernorm_residual:
1298                output = self.add_3d(output_x, mlp_logit)
1299            else:
1300                output = self.add_3d(x, mlp_logit)
1301        else:
1302            if self.post_layernorm_residual:
1303                output = self.add(output_x, mlp_logit)
1304            else:
1305                output = self.add(x, mlp_logit)
1306            output = F.reshape(output, x_shape)
1307
1308        if self.use_moe is True:
1309            return output, layer_present, aux_loss
1310        return output, layer_present
1311
1312    def _check_input(self, x, input_mask, init_reset, batch_valid_length):
1313        r"""Check inputs"""
1314        if not self.use_past or (self.use_past and self.is_first_iteration):
1315            _check_shape_equal(F.shape(x), "x", self.cls_name,
1316                               [[self.batch_size, self.seq_length, self.hidden_size],
1317                                [self.batch_size * self.seq_length, self.hidden_size]])
1318            _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
1319                               [self.batch_size, self.seq_length, self.seq_length])
1320        else:
1321            _check_shape_equal(F.shape(x), "x", self.cls_name, [self.batch_size, 1, self.hidden_size])
1322            _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
1323                               [self.batch_size, 1, self.seq_length])
1324        _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
1325        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
1326
1327        init_reset_is_tensor = isinstance(init_reset, Tensor)
1328        init_reset_is_default = init_reset is True
1329        batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
1330        batch_is_default = batch_valid_length is None
1331        _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor,
1332                                    init_reset_is_default)
1333        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
1334                                    batch_valid_length_is_tensor, batch_is_default)
1335
1336        if self.use_past:
1337            _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
1338            _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
1339            _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
1340            _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
1341        return True
1342
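# A minimal sketch, assuming a standalone helper that is not part of this module,
# of the two residual orderings handled in TransformerEncoderLayer.construct
# above: with post_layernorm_residual=True the residual branch is the layernorm
# output, otherwise it is the raw sublayer input.
def _residual_ordering_sketch(x, sublayer, layernorm, post_layernorm_residual=False):
    """Combine a sublayer output with the residual branch selected by the flag."""
    normed = layernorm(x)
    out = sublayer(normed)
    if post_layernorm_residual:
        # post-layernorm residual: add the sublayer output to the normalized input
        return normed + out
    # pre-layernorm residual: add the sublayer output to the raw input
    return x + out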
1343
1344class TransformerDecoderLayer(Cell):
1345    r"""
    Transformer Decoder Layer. This is an implementation of a single transformer decoder layer,
    including self-attention, cross attention and a feedforward layer. When the encoder_output is None,
    the cross attention is not applied.
1349
1350    Args:
1351        batch_size(int): The batch size of the input tensor.
1352        hidden_size(int): The hidden size of the input.
1353        src_seq_length(int): The input source sequence length.
1354        tgt_seq_length(int): The input target sequence length.
1355        ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
1356        num_heads(int): The number of the heads.
        hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default: 0.1.
        attention_dropout_rate(float): The dropout rate of the attention scores. Default: 0.1.
        post_layernorm_residual(bool): Whether to apply the residual add before the layernorm. Default: False.
1360        hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
1361                         'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
1362                         'hsigmoid', 'logsigmoid' and so on. Default: gelu.
1363        layernorm_compute_type(dtype.Number): The computation type of the layernorm.
1364            Should be dtype.float32 or dtype.float16. Default dtype.float32.
1365        softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
1366            Should be dtype.float32 or dtype.float16. Default mstype.float32.
1367        param_init_type(dtype.Number): The parameter initialization type of the module.
1368            Should be dtype.float32 or dtype.float16. Default dtype.float32.
        use_past(bool): Use the past state to compute, used for incremental prediction. Setting use_past=True is
            not supported for this layer yet. Default: False.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
1371        parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
1372                                           an instance of `OpParallelConfig` with default args.
1373
1374    Inputs:
        - **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size] or
          [batch_size * tgt_seq_length, hidden_size].
        - **decoder_mask** (Tensor) - the attention mask for the decoder with shape [batch_size, tgt_seq_length,
          tgt_seq_length].
        - **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, src_seq_length, hidden_size]
          or [batch_size * src_seq_length, hidden_size]. Note that this argument cannot be None when the net is the
          outermost layer. Default: None.
        - **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch_size, tgt_seq_length,
          src_seq_length], where tgt_seq_length is the length of the decoder. Note that this argument cannot be None
          when the net is the outermost layer. Default: None.
        - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
          past value parameter used in the incremental prediction. Only valid when use_past is True. Default: True.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the tokens that have
          already been calculated. Used for incremental prediction when use_past is True. Default: None.
1389
1390    Outputs:
        Tuple, a tuple containing (`output`, `layer_present`).

        - **output** (Tensor) - the output of this layer with shape [batch_size, tgt_seq_length, hidden_size] or
          [batch_size * tgt_seq_length, hidden_size].
        - **layer_present** (Tuple) - A tuple of tensors: the projected key and value vectors of the self-attention
          with shapes (batch_size, num_heads, size_per_head, tgt_seq_length) and
          (batch_size, num_heads, tgt_seq_length, size_per_head), followed by the projected key and value vectors
          of the cross attention with shapes (batch_size, num_heads, size_per_head, src_seq_length) and
          (batch_size, num_heads, src_seq_length, size_per_head).
1400
1401    Supported Platforms:
1402        ``Ascend`` ``GPU``
1403
1404    Examples:
1405        >>> import numpy as np
1406        >>> from mindspore import dtype as mstype
1407        >>> from mindspore.parallel.nn import TransformerDecoderLayer
1408        >>> from mindspore import Tensor
1409        >>> model = TransformerDecoderLayer(batch_size=2, hidden_size=64, ffn_hidden_size=64, num_heads=2,
1410        ...                                 src_seq_length=20, tgt_seq_length=10)
1411        >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
1412        >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
1413        >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
1414        >>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
1415        >>> output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
1416        >>> print(output.shape)
1417        (2, 10, 64)
1418        >>> print(past[0].shape)
1419        (2, 2, 32, 10)
1420        >>> print(past[1].shape)
1421        (2, 2, 10, 32)
1422        >>> print(past[2].shape)
1423        (2, 2, 32, 20)
1424        >>> print(past[3].shape)
1425        (2, 2, 20, 32)
1426    """
1427
1428    @_args_type_validator_check(batch_size=Validator.check_positive_int,
1429                                hidden_size=Validator.check_positive_int,
1430                                num_heads=Validator.check_positive_int,
1431                                ffn_hidden_size=Validator.check_positive_int,
1432                                src_seq_length=Validator.check_positive_int,
1433                                tgt_seq_length=Validator.check_positive_int,
1434                                attention_dropout_rate=Validator.check_non_negative_float,
1435                                hidden_dropout_rate=Validator.check_non_negative_float,
1436                                hidden_act=_valid_type_checks([str], "TransformerDecoderLayer"),
1437                                post_layernorm_residual=Validator.check_bool,
1438                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1439                                                                           "TransformerDecoderLayer"),
1440                                softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1441                                                                         "TransformerDecoderLayer"),
1442                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
1443                                                                    "TransformerDecoderLayer"),
1444                                parallel_config=_valid_type_checks([OpParallelConfig],
1445                                                                   "TransformerDecoderLayer"),
1446                                use_past=Validator.check_bool)
1447    def __init__(self, hidden_size,
1448                 ffn_hidden_size,
1449                 num_heads,
1450                 batch_size,
1451                 src_seq_length,
1452                 tgt_seq_length,
1453                 attention_dropout_rate=0.1,
1454                 hidden_dropout_rate=0.1,
1455                 post_layernorm_residual=False,
1456                 use_past=False,
1457                 layernorm_compute_type=mstype.float32,
1458                 softmax_compute_type=mstype.float32,
1459                 param_init_type=mstype.float32,
1460                 hidden_act='gelu',
1461                 moe_config=default_moe_config,
1462                 parallel_config=default_dpmp_config):
1463        super(TransformerDecoderLayer, self).__init__()
1464        _check_config(parallel_config)
        if num_heads % parallel_config.model_parallel != 0:
            raise ValueError(
                f"num_heads must be divisible by the model parallel way {parallel_config.model_parallel}, "
                f"but found {num_heads}")
        if hidden_size % parallel_config.model_parallel != 0:
            raise ValueError(
                f"hidden_size must be divisible by the model parallel way {parallel_config.model_parallel}, "
                f"but found {hidden_size}")
        if ffn_hidden_size % parallel_config.model_parallel != 0:
            raise ValueError(
                f"ffn_hidden_size must be divisible by the model parallel way {parallel_config.model_parallel}, "
                f"but found {ffn_hidden_size}")
1477        if use_past is True:
1478            raise ValueError(f"The {self.cls_name} does not support use_past=True.")
1479        self.batch_size = batch_size
1480        self.use_past = use_past
1481        self.softmax_compute_type = softmax_compute_type
1482
1483        self.src_seq_length = src_seq_length
1484        self.tgt_seq_length = tgt_seq_length
1486        self.hidden_size = hidden_size
1487
1488        self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1489        self.layernorm1.shard(((parallel_config.data_parallel, 1),))
1490        self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1491        self.layernorm2.shard(((parallel_config.data_parallel, 1),))
1492
1493        self.attention = MultiHeadAttention(hidden_size=hidden_size,
1494                                            num_heads=num_heads,
1495                                            batch_size=batch_size,
1496                                            src_seq_length=tgt_seq_length,
1497                                            tgt_seq_length=tgt_seq_length,
1498                                            hidden_dropout_rate=hidden_dropout_rate,
1499                                            attention_dropout_rate=attention_dropout_rate,
1500                                            use_past=use_past,
1501                                            softmax_compute_type=softmax_compute_type,
1502                                            param_init_type=param_init_type,
1503                                            parallel_config=parallel_config)
1504        # Cross attention with the output of encoder as memory tensor
1505        self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
1506                                                  num_heads=num_heads,
1507                                                  batch_size=batch_size,
1508                                                  src_seq_length=tgt_seq_length,
1509                                                  tgt_seq_length=src_seq_length,
1510                                                  hidden_dropout_rate=hidden_dropout_rate,
1511                                                  attention_dropout_rate=attention_dropout_rate,
1512                                                  softmax_compute_type=softmax_compute_type,
1513                                                  use_past=use_past,
1514                                                  param_init_type=param_init_type,
1515                                                  parallel_config=parallel_config)
1516        self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float(
1517            layernorm_compute_type)
1518        self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),))
1519        self.use_moe = (moe_config.expert_num > 1)
1520        if self.use_moe is True:
1521            self.output = MoE(hidden_size=hidden_size,
1522                              dropout_rate=hidden_dropout_rate,
1523                              ffn_hidden_size=ffn_hidden_size,
1524                              param_init_type=param_init_type,
1525                              hidden_act=hidden_act,
1526                              moe_config=moe_config,
1527                              parallel_config=parallel_config)
1528        else:
1529            # Feed Forward Network, FFN
1530            self.output = FeedForward(hidden_size=hidden_size,
1531                                      dropout_rate=hidden_dropout_rate,
1532                                      ffn_hidden_size=ffn_hidden_size,
1533                                      hidden_act=hidden_act,
1534                                      param_init_type=param_init_type,
1535                                      parallel_config=parallel_config)
1536        self.post_layernorm_residual = post_layernorm_residual
1537        self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
1538        self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
1539        self.dtype = mstype.float16
1540        self.key_past = None
1541        self.value_past = None
1542        if self.use_past:
1543            # operator used for state reuse
1544            self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
1545            self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
1546            self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
1547            size_per_head = int(hidden_size / num_heads)
1548            self.key_shape = (batch_size, num_heads, size_per_head, tgt_seq_length)
1549            self.value_shape = (batch_size, num_heads, tgt_seq_length, size_per_head)
1550            # parameters saving key and value states
1551            self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
1552            self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
1553            self.tile = P.Tile().shard(((1, 1),))
1554            self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
1555            self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
1556
1557    def construct(self, hidden_stats,
1558                  decoder_mask,
1559                  encoder_output=None,
1560                  memory_mask=None,
1561                  init_reset=True, batch_valid_length=None):
1562        self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length)
1563        # the returned shape is [bs, seq_length, embedding_size] or [bs * seq_length, embedding_size]
1564        hidden_shape = F.shape(hidden_stats)
1565        hidden_stats = F.reshape(hidden_stats, (-1, hidden_shape[-1]))
1566        input_x = self.layernorm1(hidden_stats)
1567        input_x = F.cast(input_x, self.dtype)
1568
1569        # indicate whether reset saved states
1570        key_reset = None
1571        value_reset = None
1572        if self.use_past:
1573            # reset states, init_reset True for reuse and False for reset
1574            key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
1575            value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
1576            # add dependency for desired execution order
1577            input_x = F.depend(input_x, key_reset)
1578            input_x = F.depend(input_x, value_reset)
1579
1580        attention, layer_present = self.attention(input_x, input_x, input_x, decoder_mask, self.key_past,
1581                                                  self.value_past, batch_valid_length)
1582        # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm
1583        if self.post_layernorm_residual:
1584            x = self.add(input_x, attention)
1585        # For pre-layernorm the inputs for residual path are output of self-attention and input of this layer
1586        else:
1587            x = self.add(hidden_stats, attention)
1588
1589        middle_output = None
1590        if encoder_output is not None:
1591            middle_output = self.cross_attention_layernorm(x)
1592            middle_output = F.cast(middle_output, self.dtype)
1593            cross_attn_output, cross_layer_present = self.cross_attention(middle_output, encoder_output,
1594                                                                          encoder_output,
1595                                                                          memory_mask, self.key_past,
1596                                                                          self.value_past, batch_valid_length)
1597            layer_present += cross_layer_present
1598            if self.post_layernorm_residual:
1599                x = self.add(middle_output, cross_attn_output)
1600            else:
1601                x = self.add(x, cross_attn_output)
1602
1603        output_x = self.layernorm2(x)
1604        output_x = F.cast(output_x, self.dtype)
1605        aux_loss = None
1606        if self.use_moe is True:
1607            mlp_logit, aux_loss = self.output(output_x)
1608        else:
1609            mlp_logit = self.output(output_x)
1610
1611        value_update = None
1612        key_update = None
1613        if self.use_past:
1614            # current key and value
1615            key_present, value_present = layer_present
1616            # update key and value calculated this step
1617            key_update = self.assign(self.key_past, key_present)
1618            value_update = self.assign(self.value_past, value_present)
1619            # add dependency for desired execution order
1620            key_update = F.depend(key_update, key_reset)
1621            value_update = F.depend(value_update, value_reset)
1622
1623        # add dependency for desired execution order
1624        mlp_logit = F.depend(mlp_logit, value_update)
1625        mlp_logit = F.depend(mlp_logit, key_update)
1626
1627        # if shape is 3d, we reshape the inputs of the add
1628        if len(hidden_shape) == 3:
1629            output_x = P.Reshape()(output_x, hidden_shape)
1630            mlp_logit = P.Reshape()(mlp_logit, hidden_shape)
1631            x = P.Reshape()(x, hidden_shape)
1632
1633            if self.post_layernorm_residual:
1634                output = self.add_3d(output_x, mlp_logit)
1635            else:
1636                output = self.add_3d(x, mlp_logit)
1637        else:
1638            if self.post_layernorm_residual:
1639                output = self.add(output_x, mlp_logit)
1640            else:
1641                output = self.add(x, mlp_logit)
1642            output = F.reshape(output, hidden_shape)
1643
1644        if self.use_moe is True:
1645            return output, layer_present, aux_loss
1646        return output, layer_present
1647
1648    def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
1649        r"""Check inputs"""
1650        if not self.use_past or (self.use_past and self.is_first_iteration):
1651            _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
1652                               [[self.batch_size, self.tgt_seq_length, self.hidden_size],
1653                                [self.batch_size * self.tgt_seq_length, self.hidden_size]])
1654            _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
1655                               [self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
1656
1657        else:
1658            _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
1659                               [self.batch_size, 1, self.hidden_size])
1660            _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
1661                               [self.batch_size, 1, self.tgt_seq_length])
1662        _check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
1663        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
1664        if encoder_output is not None:
1665            _check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name,
1666                               [[self.batch_size, self.src_seq_length, self.hidden_size],
1667                                [self.batch_size * self.src_seq_length, self.hidden_size]])
1668            _check_input_dtype(F.dtype(encoder_output), "encoder_output",
1669                               [mstype.float32, mstype.float16], self.cls_name)
1670        if memory_mask is not None:
1671            _check_shape_equal(F.shape(memory_mask), "memory_mask", self.cls_name,
1672                               [self.batch_size, self.tgt_seq_length, self.src_seq_length])
1673            _check_input_dtype(F.dtype(memory_mask), "memory_mask",
1674                               [mstype.float32, mstype.float16], self.cls_name)
1675
1676        init_reset_is_tensor = isinstance(init_reset, Tensor)
1677        init_reset_is_default = init_reset is True
1678        batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
1679        batch_is_default = batch_valid_length is None
1680        _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor,
1681                                    init_reset_is_default)
1682        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
1683                                    batch_valid_length_is_tensor, batch_is_default)
1684
1685        if self.use_past:
1686            _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
1687            _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
1688            _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
1689            _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
1690        return True
1691
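# Minimal sketch (a hypothetical standalone helper, not part of this module) of
# the data flow in TransformerDecoderLayer.construct above with the default
# pre-layernorm residuals: self-attention over the target sequence, optional
# cross attention over the encoder output, then the feedforward block. Masks and
# incremental-prediction states are omitted for brevity.
def _decoder_layer_flow_sketch(hidden_states, self_attention, cross_attention,
                               feed_forward, norm1, cross_norm, norm2,
                               encoder_output=None):
    """Apply self-attention, optional cross attention and the FFN with residuals."""
    x = hidden_states + self_attention(norm1(hidden_states))
    if encoder_output is not None:
        # Cross attention uses the encoder output as the key/value memory
        x = x + cross_attention(cross_norm(x), encoder_output)
    return x + feed_forward(norm2(x))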
1692
1693def _get_lambda_func(total_layer=None):
    r"""
        A wrapper function for specifying the pipeline stage and gradient aggregation fusion. If the total layer
        count is not None, for example when it is set by the transformer model, the pipeline stage setting function
        is `(layer_id + 0) // (total_layer / parallel_config.pipeline_stage)` for the encoder and
        `(layer_id + offset) // (total_layer / parallel_config.pipeline_stage)` for the decoder,
        where `offset` is the number of layers in the encoder.
    """
1701
1702    def _set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers):
        r"""
            Default setting for the pipeline stage is: `(layer_id + offset) // (layers / pipeline_stage)`.

            Args:
                network(Cell) - Represents the transformer block.
                layer_id(int) - The layer index of the current module, counting from zero.
                offset(int) - The offset added to the layer index when there are other modules in the net.
                parallel_config - The parallel configuration providing the pipeline_stage and fusion settings.
                layers(int) - The total number of layers used for the model.
        """
1712        # override the layers
1713        if total_layer:
1714            layers = total_layer
1715        # Used for the pipeline's stages setting
1716        if layers < parallel_config.pipeline_stage:
            raise ValueError(f"layers {layers} must not be less than pipeline stage {parallel_config.pipeline_stage}")
1718
1719        pp_dis = max(int(layers / parallel_config.pipeline_stage), 1)
1720        # the pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
1721        pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
1722        network.pipeline_stage = pp_id
1723        logger.info(f"pipeline stage id is {pp_id}")
1724
1725        # Used for optimizer's fusion tag
1726        dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
1727        network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
1728        # Used for enabling recomputation of the block
1729        if parallel_config.recompute:
1730            network.recompute()
1731
1732    return _set_parallel_configure_for_layer
1733
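# Worked example (the numbers are illustrative assumptions) of the default stage
# formula produced by _get_lambda_func above: with layers=8, pipeline_stage=4 and
# offset=0, pp_dis = max(8 // 4, 1) = 2, so layer_id 0..7 map to pipeline stages
# [0, 0, 1, 1, 2, 2, 3, 3]. The helper below is hypothetical and only reproduces
# that mapping.
def _pipeline_stage_assignment_example(layers=8, pipeline_stage=4, offset=0):
    """Reproduce the `(layer_id + offset) // pp_dis` stage mapping."""
    pp_dis = max(int(layers / pipeline_stage), 1)
    return [min((layer_id + offset) // pp_dis, pipeline_stage - 1)
            for layer_id in range(layers)]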
1734
1735class TransformerEncoder(Cell):
1736    r"""
    Transformer Encoder module consisting of a multi-layer stack of `TransformerEncoderLayer`, including
    multi-head self attention and a feedforward layer.
1739
1740    Args:
1741        batch_size(int): The batch size of the input tensor.
        num_layers(int): The number of `TransformerEncoderLayer` layers to stack.
1743        hidden_size(int): The hidden size of the input.
1744        ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
1745        seq_length(int): The seq_length of the input tensor.
1746        num_heads(int): The number of the heads.
        hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default: 0.1.
        attention_dropout_rate(float): The dropout rate of the attention scores. Default: 0.1.
        post_layernorm_residual(bool): Whether to apply the residual add before the layernorm. Default: False.
1750        hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
1751                         'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
1752                         'hsigmoid', 'logsigmoid' and so on. Default: gelu.
1753        layernorm_compute_type(dtype.Number): The computation type of the layernorm.
1754            Should be dtype.float32 or dtype.float16. Default dtype.float32.
1755        softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
1756            Should be dtype.float32 or dtype.float16. Default mstype.float32.
1757        param_init_type(dtype.Number): The parameter initialization type of the module.
1758            Should be dtype.float32 or dtype.float16. Default dtype.float32.
        use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have
            two words and want to generate ten more, we only need to compute the state of the two words once and
            then generate the following words one by one. When use_past is True, the prediction runs in two steps.
            First, set is_first_iteration to True by `model.add_flags_recursive(is_first_iteration=True)` and pass
            the full inputs. Then, set is_first_iteration to False by
            `model.add_flags_recursive(is_first_iteration=False)`, pass the single step's input tensor, and loop
            over it. Default: False.
        lambda_func: A function that determines the fusion index, pipeline stage and recompute attribute of each
            layer. Users who want to customize the pipeline stage and gradient aggregation fusion can pass a
            function that accepts `network`, `layer_id`, `offset`, `parallel_config` and `layers`. `network(Cell)`
            represents the transformer block, `layer_id(int)` is the layer index of the current module, counting
            from zero, and `offset(int)` is the offset added to the layer index when there are other modules in the
            net. The default setting for the pipeline stage is `(layer_id + offset) // (layers / pipeline_stage)`;
            a hedged sketch of such a function is given after this class definition. Default: None.
        offset(int): The initial layer index of the `decoder`, used for setting the fusion id and stage id so that
            they do not overlap with the encoder layers. Default: 0.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
1775        parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`,
1776                                           an instance of `TransformerOpParallelConfig` with default args.
1777
1778    Inputs:
1779        - **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or
1780          [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
1781          should be [batch_size, 1, hidden_size].
1782        - **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
        - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
          past value parameter used in the incremental prediction. Only valid when use_past is True. Default: True.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the tokens that have
          already been calculated. Used for incremental prediction when use_past is True. Default: None.
1787
1788    Outputs:
        Tuple, a tuple containing (`output`, `layer_present`).
1790
1791        - **output** (Tensor) - The float tensor of the output of the layer with
1792          shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is
1793          False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size).
        - **layer_present** (Tuple) - A tuple with size num_layers, where each element is a tuple of the Tensors of
          the projected key and value vectors with shapes ((batch_size, num_heads, size_per_head, seq_length) and
          (batch_size, num_heads, seq_length, size_per_head)).
1797
1798    Supported Platforms:
1799        ``Ascend`` ``GPU``
1800
1801    Examples:
1802        >>> import numpy as np
1803        >>> from mindspore import dtype as mstype
1804        >>> from mindspore.parallel.nn import TransformerEncoder
1805        >>> from mindspore import Tensor
1806        >>> model = TransformerEncoder(batch_size=2, num_layers=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
1807        ...                            num_heads=2)
1808        >>> encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
1809        >>> encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
1810        >>> output, past = model(encoder_input_value, encoder_input_mask)
1811        >>> print(output.shape)
1812        (2, 16, 8)
1813        >>> print(len(past))
1814        2
1815        >>> print(past[0][0].shape)
1816        (2, 2, 4, 16)
1817        >>> print(past[0][1].shape)
1818        (2, 2, 16, 4)
        # When using use_past=True, the incremental prediction runs in two steps.
1820        # Step 1: set is_first_iteration=True, and input the full sequence length's state.
1821        >>> batch_valid_length = Tensor(np.ones((2,)), mstype.int32)
1822        >>> init_reset = Tensor([True], mstype.bool_)
1823        # Set is_first_iteration=True to generate the full memory states
1824        >>> model = TransformerEncoder(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
1825        ...                            num_heads=2, num_layers=2, use_past=True)
1826        >>> model.add_flags_recursive(is_first_iteration=True)
1827        >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length)
1828        >>> print(hidden.shape)
1829        (2, 16, 8)
        >>> print(past[0][0].shape)
        (2, 2, 4, 16)
        >>> print(past[0][1].shape)
        (2, 2, 16, 4)
1834        >>> encoder_input_value = Tensor(np.ones((2, 1, 8)), mstype.float32)
1835        >>> encoder_input_mask = Tensor(np.ones((2, 1, 16)), mstype.float16)
1836        >>> init_reset = Tensor([False], mstype.bool_)
1837        # Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than the full
1838        # sequence.
1839        >>> model.add_flags_recursive(is_first_iteration=False)
1840        >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length)
1841        >>> print(hidden.shape)
1842        (2, 1, 8)
        >>> print(past[0][0].shape)
        (2, 2, 4, 16)
        >>> print(past[0][1].shape)
        (2, 2, 16, 4)
1847    """
1848
1849    @_args_type_validator_check(batch_size=Validator.check_positive_int,
1850                                hidden_size=Validator.check_positive_int,
1851                                num_heads=Validator.check_positive_int,
1852                                ffn_hidden_size=Validator.check_positive_int,
1853                                seq_length=Validator.check_positive_int,
1854                                num_layers=Validator.check_positive_int,
1855                                offset=Validator.check_non_negative_int,
1856                                attention_dropout_rate=Validator.check_non_negative_float,
1857                                hidden_dropout_rate=Validator.check_non_negative_float,
1858                                hidden_act=_valid_type_checks([str], "TransformerEncoder"),
1859                                post_layernorm_residual=Validator.check_bool,
1860                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1861                                                                           "TransformerEncoder"),
1862                                softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1863                                                                         "TransformerEncoder"),
1864                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
1865                                                                    "TransformerEncoder"),
1866                                parallel_config=_valid_type_checks([TransformerOpParallelConfig],
1867                                                                   "TransformerEncoder"),
1868                                use_past=Validator.check_bool)
1869    def __init__(self,
1870                 batch_size,
1871                 num_layers,
1872                 hidden_size,
1873                 ffn_hidden_size,
1874                 seq_length,
1875                 num_heads,
1876                 attention_dropout_rate=0.1,
1877                 hidden_dropout_rate=0.1,
1878                 hidden_act='gelu',
1879                 post_layernorm_residual=False,
1880                 layernorm_compute_type=mstype.float32,
1881                 softmax_compute_type=mstype.float32,
1882                 param_init_type=mstype.float32,
1883                 lambda_func=None,
1884                 offset=0,
1885                 use_past=False,
1886                 moe_config=default_moe_config,
1887                 parallel_config=default_transformer_config):
1888        super(TransformerEncoder, self).__init__()
1889        _check_config(parallel_config)
1890
1891        self.use_moe = (moe_config.expert_num > 1)
1892        self.add = P.Add().shard(((), ()))
1893        self.aux_loss = Tensor(0.0, mstype.float32)
1894        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
1895            raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.")
1896        self.num_layers = num_layers
1897        self.blocks = nn.CellList()
1898        for i in range(num_layers):
1899            block = TransformerEncoderLayer(hidden_size=hidden_size,
1900                                            batch_size=batch_size,
1901                                            ffn_hidden_size=ffn_hidden_size,
1902                                            seq_length=seq_length,
1903                                            attention_dropout_rate=attention_dropout_rate,
1904                                            hidden_dropout_rate=hidden_dropout_rate,
1905                                            layernorm_compute_type=layernorm_compute_type,
1906                                            softmax_compute_type=softmax_compute_type,
1907                                            num_heads=num_heads,
1908                                            hidden_act=hidden_act,
1909                                            post_layernorm_residual=post_layernorm_residual,
1910                                            param_init_type=param_init_type,
1911                                            use_past=use_past,
1912                                            moe_config=moe_config,
1913                                            parallel_config=parallel_config.dp_mp_config)
1914            # If the user doesn't pass the fusion function, use the default one
1915            if not lambda_func:
1916                lambda_func = _get_lambda_func()
1917
1918            lambda_func(block, layer_id=i, layers=num_layers,
1919                        offset=offset, parallel_config=parallel_config)
1920            self.blocks.append(block)
1921
1922    def construct(self, hidden_states, attention_mask, init_reset=True, batch_valid_length=None):
1923        present_layer = ()
        if self.use_moe is True:
            accum_loss = self.aux_loss
            for i in range(self.num_layers):
                hidden_states, present, aux_loss = self.blocks[i](hidden_states,
                                                                  attention_mask,
                                                                  init_reset,
                                                                  batch_valid_length)
                present_layer = present_layer + (present,)
                accum_loss = self.add(accum_loss, aux_loss)
            return hidden_states, present_layer, accum_loss

        for i in range(self.num_layers):
            hidden_states, present = self.blocks[i](hidden_states,
                                                    attention_mask,
                                                    init_reset,
                                                    batch_valid_length)
            present_layer = present_layer + (present,)

        return hidden_states, present_layer


class TransformerDecoder(Cell):
    r"""
    Transformer Decoder module consisting of multiple stacked `TransformerDecoderLayer` layers, each including
    multi-head self attention, cross attention and a feedforward layer.

    Args:
        batch_size(int): The batch size of the input tensor.
        num_layers(int): The number of stacked `TransformerDecoderLayer` layers.
        hidden_size(int): The hidden size of the input.
        ffn_hidden_size(int): The hidden size of the bottleneck in the feedforward layer.
        src_seq_length(int): The input source sequence length.
        tgt_seq_length(int): The input target sequence length.
        num_heads(int): The number of attention heads.
        hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default: 0.1.
        attention_dropout_rate(float): The dropout rate of the attention scores. Default: 0.1.
        post_layernorm_residual(bool): Whether to apply the residual addition before the layernorm. Default: False.
        hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
                         'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
                         'hsigmoid', 'logsigmoid' and so on. Default: gelu.
        layernorm_compute_type(dtype.Number): The computation type of the layernorm.
            Should be dtype.float32 or dtype.float16. Default: dtype.float32.
        softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
            Should be dtype.float32 or dtype.float16. Default: mstype.float32.
        param_init_type(dtype.Number): The parameter initialization type of the module.
            Should be dtype.float32 or dtype.float16. Default: dtype.float32.
        offset(int): The initial layer index of the `decoder`, used to set the fusion id and pipeline stage id so
            that they do not overlap with those of the encoder layers.
        lambda_func: A function that determines the fusion index, pipeline stage and recompute attribute. If the user
            wants to customize the pipeline stage and gradient aggregation fusion, the user can pass a function
            that accepts `network`, `layer_id`, `offset`, `parallel_config` and `layers`. `network(Cell)`
            represents the transformer block, `layer_id(int)` is the index of the current layer, counting from
            zero, and `offset(int)` is the offset added to the layer index when there are other modules in the net.
            The default pipeline stage is computed as `(layer_id + offset) // (layers / pipeline_stage)`.
            A sketch of such a function is given after the examples below. Default: None.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
        parallel_config(TransformerOpParallelConfig): The parallel configuration. Default `default_transformer_config`,
                                           an instance of `TransformerOpParallelConfig` with default args.

    Inputs:
        - **hidden_states** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or
          [batch_size * seq_length, hidden_size]
        - **attention_mask** (Tensor) - the attention mask for the decoder with shape
          [batch_size, seq_length, seq_length]
        - **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or
          [batch_size * seq_length, hidden_size]. Note that this argument cannot be None when the net is the
          outermost layer. Default: None.
        - **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
          src_seq_length], where tgt_seq_length is the length of the decoder input. Note that this argument cannot
          be None when the net is the outermost layer. Default: None.
        - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key and value parameters
          used in the incremental prediction. Only valid when use_past is True. Default: True.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the past calculated index.
          Used for incremental prediction when use_past is True. Default: None.

    Outputs:
        Tuple, a tuple containing (`output`, `layer_present`); when MoE is enabled (`moe_config.expert_num > 1`),
        the accumulated auxiliary loss is appended as a third element.

        - **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size] or
          [batch * tgt_seq_length, hidden_size]
        - **layer_present** (Tuple) - A tuple with size of num_layers, where each element contains the tensors of the
          projected key and value vectors in self attention with shape ((batch_size, num_heads, size_per_head,
          tgt_seq_length), (batch_size, num_heads, tgt_seq_length, size_per_head)), and of the projected key and
          value vectors in cross attention with shape ((batch_size, num_heads, size_per_head, src_seq_length),
          (batch_size, num_heads, src_seq_length, size_per_head)).

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import dtype as mstype
        >>> from mindspore.parallel.nn import TransformerDecoder
        >>> from mindspore import Tensor
        >>> model = TransformerDecoder(batch_size=2, num_layers=1, hidden_size=64, ffn_hidden_size=64,
        ...                            num_heads=2, src_seq_length=20, tgt_seq_length=10)
        >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
        >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
        >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
        >>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
        >>> output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
        >>> print(output.shape)
        (2, 10, 64)
        >>> print(len(past))
        1
        >>> print(past[0][0].shape)
        (2, 2, 32, 10)
        >>> print(past[0][1].shape)
        (2, 2, 10, 32)
        >>> print(past[0][2].shape)
        (2, 2, 32, 20)
        >>> print(past[0][3].shape)
        (2, 2, 20, 32)

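    The following is a minimal sketch of a custom `lambda_func` with the signature described in the Args section.
    The attribute hooks used here (`pipeline_stage`, `set_comm_fusion`) are the usual Cell hooks for stage placement
    and gradient-aggregation fusion; the concrete policy (two pipeline stages, one fusion group per stage) and the
    helper name are only assumptions for illustration.

        >>> def my_layer_placement(network, layer_id, offset, parallel_config, layers):
        ...     # put the first half of the layers on pipeline stage 0, the rest on stage 1
        ...     half = max(layers // 2, 1)
        ...     network.pipeline_stage = min((layer_id + offset) // half, 1)
        ...     # use one gradient-aggregation fusion group per stage
        ...     network.set_comm_fusion(network.pipeline_stage + 1)
        >>> model = TransformerDecoder(batch_size=2, num_layers=1, hidden_size=64, ffn_hidden_size=64,
        ...                            num_heads=2, src_seq_length=20, tgt_seq_length=10,
        ...                            lambda_func=my_layer_placement)
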
    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                hidden_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                ffn_hidden_size=Validator.check_positive_int,
                                src_seq_length=Validator.check_positive_int,
                                num_layers=Validator.check_positive_int,
                                tgt_seq_length=Validator.check_positive_int,
                                offset=Validator.check_non_negative_int,
                                attention_dropout_rate=Validator.check_non_negative_float,
                                hidden_dropout_rate=Validator.check_non_negative_float,
                                hidden_act=_valid_type_checks([str], "TransformerDecoder"),
                                post_layernorm_residual=Validator.check_bool,
                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                           "TransformerDecoder"),
                                softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                         "TransformerDecoder"),
                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                    "TransformerDecoder"),
                                parallel_config=_valid_type_checks([TransformerOpParallelConfig],
                                                                   "TransformerDecoder"),
                                use_past=Validator.check_bool)
    def __init__(self,
                 num_layers,
                 batch_size,
                 hidden_size,
                 ffn_hidden_size,
                 src_seq_length,
                 tgt_seq_length,
                 num_heads,
                 attention_dropout_rate=0.1,
                 hidden_dropout_rate=0.1,
                 post_layernorm_residual=False,
                 layernorm_compute_type=mstype.float32,
                 softmax_compute_type=mstype.float32,
                 param_init_type=mstype.float32,
                 hidden_act='gelu',
                 lambda_func=None,
                 use_past=False,
                 offset=0,
                 moe_config=default_moe_config,
                 parallel_config=default_transformer_config):
        super(TransformerDecoder, self).__init__()
        _check_config(parallel_config)

        self.add = P.Add().shard(((), ()))
        self.aux_loss = Tensor(0.0, mstype.float32)
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
            raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.")
        self.num_layers = num_layers
        self.blocks = nn.CellList()
        self.use_moe = (moe_config.expert_num > 1)
        for i in range(num_layers):
            block = TransformerDecoderLayer(hidden_size=hidden_size,
                                            batch_size=batch_size,
                                            ffn_hidden_size=ffn_hidden_size,
                                            src_seq_length=src_seq_length,
                                            tgt_seq_length=tgt_seq_length,
                                            attention_dropout_rate=attention_dropout_rate,
                                            hidden_dropout_rate=hidden_dropout_rate,
                                            num_heads=num_heads,
                                            layernorm_compute_type=layernorm_compute_type,
                                            softmax_compute_type=softmax_compute_type,
                                            hidden_act=hidden_act,
                                            use_past=use_past,
                                            param_init_type=param_init_type,
                                            post_layernorm_residual=post_layernorm_residual,
                                            moe_config=moe_config,
                                            parallel_config=parallel_config.dp_mp_config)
            # If the user doesn't pass the fusion function, use the default one
            if not lambda_func:
                lambda_func = _get_lambda_func()

            lambda_func(block, layer_id=i, layers=num_layers,
                        offset=offset, parallel_config=parallel_config)

            self.blocks.append(block)

    def construct(self, hidden_states, attention_mask, encoder_output=None, memory_mask=None,
                  init_reset=True, batch_valid_length=None):
        present_layer = ()
        if self.use_moe is True:
            accum_loss = self.aux_loss
            for i in range(self.num_layers):
                hidden_states, present, aux_loss = self.blocks[i](hidden_states,
                                                                  attention_mask,
                                                                  encoder_output,
                                                                  memory_mask,
                                                                  init_reset,
                                                                  batch_valid_length)
                present_layer = present_layer + (present,)
                accum_loss = self.add(accum_loss, aux_loss)
            return hidden_states, present_layer, accum_loss

        # Loop through each decoder layer
        for i in range(self.num_layers):
            hidden_states, present = self.blocks[i](hidden_states,
                                                    attention_mask,
                                                    encoder_output,
                                                    memory_mask,
                                                    init_reset,
                                                    batch_valid_length)
            present_layer = present_layer + (present,)

        return hidden_states, present_layer


class Transformer(Cell):
    r"""
    Transformer module including encoder and decoder. The difference from the original implementation is that this
    module applies the residual addition before the layer normalization. The default hidden activation is `gelu`.
    The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.

    Note:
        This is an experimental interface that is subject to change and/or deletion.

    Args:
        batch_size(int): The batch size of the input tensor.
        encoder_layers(int): The number of `TransformerEncoderLayer` layers.
        decoder_layers(int): The number of `TransformerDecoderLayer` layers.
        hidden_size(int): The hidden size of the input.
        ffn_hidden_size(int): The hidden size of the bottleneck in the feedforward layer.
        src_seq_length(int): The seq_length of the encoder's input tensor.
        tgt_seq_length(int): The seq_length of the decoder's input tensor.
        num_heads(int): The number of attention heads. Default: 2.
        hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default: 0.1.
        attention_dropout_rate(float): The dropout rate of the attention scores. Default: 0.1.
        post_layernorm_residual(bool): Whether to apply the residual addition before the layernorm. Default: False.
        layernorm_compute_type(dtype.Number): The computation type of the layernorm.
            Should be dtype.float32 or dtype.float16. Default: dtype.float32.
        softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
            Should be dtype.float32 or dtype.float16. Default: mstype.float32.
        param_init_type(dtype.Number): The parameter initialization type of the module.
            Should be dtype.float32 or dtype.float16. Default: dtype.float32.
        hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
                         'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
                         'hsigmoid', 'logsigmoid' and so on. Default: gelu.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
        lambda_func: A function that determines the fusion index, pipeline stage and recompute attribute. If the user
            wants to customize the pipeline stage and gradient aggregation fusion, the user can pass a function
            that accepts `network`, `layer_id`, `offset`, `parallel_config` and `layers`. `network(Cell)`
            represents the transformer block, `layer_id(int)` is the index of the current layer, counting from
            zero, and `offset(int)` is the offset added to the layer index when there are other modules in the net.
            The default pipeline stage is computed as `(layer_id + offset) // ((encoder_layers + decoder_layers)
            / pipeline_stage)`. Default: None.
        parallel_config(TransformerOpParallelConfig): The parallel configuration. Default `default_transformer_config`,
                                           an instance of `TransformerOpParallelConfig` with default args.

    Inputs:
        - **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or
          [batch_size * seq_length, hidden_size].
        - **encoder_masks** (Tensor) - the attention mask for the encoder with shape
          [batch_size, seq_length, seq_length].
        - **decoder_inputs** (Tensor) - the input tensor of the decoder with shape [batch_size, seq_length,
          hidden_size] or [batch_size * seq_length, hidden_size]; this should be None if decoder_layers is 0.
        - **decoder_masks** (Tensor) - the attention mask for the decoder with shape
          [batch_size, seq_length, seq_length].
        - **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
          src_seq_length], where tgt_seq_length is the length of the decoder input; this should be None if
          decoder_layers is 0. A usage sketch for the encoder-only case is given after the examples below.
        - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key and value parameters
          used in the incremental prediction. Only valid when use_past is True. Default: True.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the past calculated index. Used
          for incremental prediction when use_past is True. Default: None.

    Outputs:
        Tuple, a tuple containing (`output`, `encoder_layer_present`, `decoder_layer_present`); when MoE is enabled
        (`moe_config.expert_num > 1`), the accumulated auxiliary loss is appended as a fourth element.

        - **output** (Tensor) - If there is only an encoder, the output logit of the encoder layer with shape
          [batch, src_seq_length, hidden_size] or [batch * src_seq_length, hidden_size]; if there are both encoder
          and decoder, the output comes from the decoder layer with shape [batch, tgt_seq_length, hidden_size] or
          [batch * tgt_seq_length, hidden_size].
        - **encoder_layer_present** (Tuple) - A tuple with size of num_layers, where each element contains the
          tensors of the projected key and value vectors in self attention with shape ((batch_size, num_heads,
          size_per_head, src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head)).
        - **decoder_layer_present** (Tuple) - A tuple with size of num_layers, where each element contains the
          tensors of the projected key and value vectors in self attention with shape ((batch_size, num_heads,
          size_per_head, tgt_seq_length), (batch_size, num_heads, tgt_seq_length, size_per_head)), and of the
          projected key and value vectors in cross attention with shape ((batch_size, num_heads, size_per_head,
          src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head)). If the decoder is not set, the
          returned value will be None.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import dtype as mstype
        >>> from mindspore.parallel.nn import Transformer
        >>> from mindspore import Tensor
        >>> model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64, ffn_hidden_size=64,
        ...         src_seq_length=20, tgt_seq_length=10)
        >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
        >>> encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
        >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
        >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
        >>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
        >>> output, en_past, de_past = model(encoder_input_value, encoder_input_mask, decoder_input_value,
        ...                                  decoder_input_mask, memory_mask)
        >>> print(output.shape)
        (2, 10, 64)
        >>> print(len(en_past))
        1
        >>> print(len(de_past))
        2
        >>> print(en_past[0][0].shape)
        (2, 2, 32, 20)
        >>> print(en_past[0][1].shape)
        (2, 2, 20, 32)
        >>> print(de_past[0][0].shape)
        (2, 2, 32, 10)
        >>> print(de_past[0][1].shape)
        (2, 2, 10, 32)
        >>> print(de_past[0][2].shape)
        (2, 2, 32, 20)
        >>> print(de_past[0][3].shape)
        (2, 2, 20, 32)

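    The following is a minimal encoder-only sketch (decoder_layers=0). It only illustrates the Inputs/Outputs notes
    above: the decoder-related inputs are left as None and the returned `decoder_layer_present` is None. The variable
    name `encoder_only` is ours, and the tensors reuse the example values above.

        >>> encoder_only = Transformer(batch_size=2, encoder_layers=1, decoder_layers=0, hidden_size=64,
        ...                            ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10)
        >>> output, en_past, de_past = encoder_only(encoder_input_value, encoder_input_mask)
        >>> # output should have shape (2, 20, 64); de_past is None in this configuration
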
    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                hidden_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                ffn_hidden_size=Validator.check_positive_int,
                                src_seq_length=Validator.check_positive_int,
                                encoder_layers=Validator.check_positive_int,
                                decoder_layers=Validator.check_non_negative_int,
                                tgt_seq_length=Validator.check_positive_int,
                                attention_dropout_rate=Validator.check_non_negative_float,
                                hidden_dropout_rate=Validator.check_non_negative_float,
                                hidden_act=_valid_type_checks([str], "Transformer"),
                                post_layernorm_residual=Validator.check_bool,
                                layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                           "Transformer"),
                                softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                         "Transformer"),
                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16], "Transformer"),
                                parallel_config=_valid_type_checks([TransformerOpParallelConfig], "Transformer"),
                                use_past=Validator.check_bool)
    def __init__(self,
                 hidden_size,
                 batch_size,
                 ffn_hidden_size,
                 src_seq_length,
                 tgt_seq_length,
                 encoder_layers=3,
                 decoder_layers=3,
                 num_heads=2,
                 attention_dropout_rate=0.1,
                 hidden_dropout_rate=0.1,
                 hidden_act='gelu',
                 post_layernorm_residual=False,
                 layernorm_compute_type=mstype.float32,
                 softmax_compute_type=mstype.float32,
                 param_init_type=mstype.float32,
                 lambda_func=None,
                 use_past=False,
                 moe_config=default_moe_config,
                 parallel_config=default_transformer_config):
        super(Transformer, self).__init__()
        _check_config(parallel_config)
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.src_seq_length = src_seq_length
        self.tgt_seq_length = tgt_seq_length
        self.use_past = use_past
        if encoder_layers <= 0 < decoder_layers:
            raise ValueError(f"Transformer does not support encoder layer {encoder_layers} and decoder "
                             f"layer {decoder_layers}, please use TransformerDecoder")
        if encoder_layers > 0 and decoder_layers > 0 and use_past is True:
            raise ValueError(f"The {self.cls_name} with encoder and decoder does not support use_past=True.")
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
            raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.")
        # The shard setting of Transformer is set within the TransformerEncoderLayer
        if not lambda_func:
            lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers)

        self.use_moe = (moe_config.expert_num > 1)
        self.add = P.Add().shard(((), ()))
        self.aux_loss = Tensor(0.0, mstype.float32)
        if encoder_layers > 0:
            self.encoder = TransformerEncoder(num_layers=encoder_layers,
                                              batch_size=batch_size,
                                              hidden_size=hidden_size,
                                              ffn_hidden_size=ffn_hidden_size,
                                              num_heads=num_heads,
                                              seq_length=src_seq_length,
                                              attention_dropout_rate=attention_dropout_rate,
                                              hidden_dropout_rate=hidden_dropout_rate,
                                              hidden_act=hidden_act,
                                              layernorm_compute_type=layernorm_compute_type,
                                              softmax_compute_type=softmax_compute_type,
                                              post_layernorm_residual=post_layernorm_residual,
                                              param_init_type=param_init_type,
                                              lambda_func=lambda_func,
                                              use_past=use_past,
                                              moe_config=moe_config,
                                              parallel_config=parallel_config)
        else:
            self.encoder = None

        # Offset is needed because the encoder has already consumed some layer ids,
        # so the decoder needs to shift its layer indices by the number of encoder layers.
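        # For example (purely illustrative, assuming the default placement rule
        # `(layer_id + offset) // (total_layers / pipeline_stage)` documented above):
        # with encoder_layers=2, decoder_layers=2 and parallel_config.pipeline_stage=2,
        # decoder layer 0 gets offset=2, so its stage is (0 + 2) // (4 / 2) = 1,
        # placing the whole decoder on the second pipeline stage.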
        self.decoder = None
        if decoder_layers > 0:
            self.decoder = TransformerDecoder(num_layers=decoder_layers,
                                              batch_size=batch_size,
                                              hidden_size=hidden_size,
                                              ffn_hidden_size=ffn_hidden_size,
                                              num_heads=num_heads,
                                              src_seq_length=src_seq_length,
                                              tgt_seq_length=tgt_seq_length,
                                              attention_dropout_rate=attention_dropout_rate,
                                              hidden_dropout_rate=hidden_dropout_rate,
                                              hidden_act=hidden_act,
                                              post_layernorm_residual=post_layernorm_residual,
                                              layernorm_compute_type=layernorm_compute_type,
                                              softmax_compute_type=softmax_compute_type,
                                              lambda_func=lambda_func,
                                              use_past=use_past,
                                              param_init_type=param_init_type,
                                              offset=encoder_layers,
                                              moe_config=moe_config,
                                              parallel_config=parallel_config)

    def construct(self, encoder_inputs,
                  encoder_masks,
                  decoder_inputs=None,
                  decoder_masks=None,
                  memory_mask=None,
                  init_reset=True,
                  batch_valid_length=None):

        encoder_output = None
        output = None
        encoder_layer_present = None
        decoder_layer_present = None
        accum_loss = self.aux_loss
        if self.encoder is not None:
            if self.use_moe is True:
                encoder_output, encoder_layer_present, encoder_aux_loss = self.encoder(encoder_inputs, encoder_masks,
                                                                                       init_reset, batch_valid_length)
                accum_loss = self.add(accum_loss, encoder_aux_loss)
            else:
                encoder_output, encoder_layer_present = self.encoder(encoder_inputs, encoder_masks, init_reset,
                                                                     batch_valid_length)
            output = encoder_output

        if self.decoder is not None:
            # decoder mask should be created outside of the model
            if self.use_moe is True:
                decoder_output, decoder_layer_present, decoder_aux_loss = self.decoder(decoder_inputs, decoder_masks,
                                                                                       encoder_output, memory_mask,
                                                                                       init_reset, batch_valid_length)
                accum_loss = self.add(accum_loss, decoder_aux_loss)
            else:
                decoder_output, decoder_layer_present = self.decoder(decoder_inputs,
                                                                     decoder_masks,
                                                                     encoder_output,
                                                                     memory_mask, init_reset,
                                                                     batch_valid_length)
            output = decoder_output
        if self.use_moe is True:
            return output, encoder_layer_present, decoder_layer_present, accum_loss
        return output, encoder_layer_present, decoder_layer_present

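# The commented-out snippet below is a purely illustrative sketch of how a caller might enable MoE and consume
# the extra accumulated auxiliary loss returned by `construct` when `moe_config.expert_num > 1`. `MoEConfig` is
# assumed to be the configuration class exported alongside `default_moe_config`; `task_loss` stands for whatever
# task loss the caller computes, and the 0.01 weight is an arbitrary example value, not a recommendation.
#
#     moe_config = MoEConfig(expert_num=4)
#     model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64, ffn_hidden_size=64,
#                         src_seq_length=20, tgt_seq_length=10, moe_config=moe_config)
#     output, en_past, de_past, aux_loss = model(encoder_input_value, encoder_input_mask, decoder_input_value,
#                                                decoder_input_mask, memory_mask)
#     total_loss = task_loss + 0.01 * aux_loss  # add the balancing loss to the task loss with a small weight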