1# Copyright 2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15""" 16Note: 17 Transformer Networks. This is an experimental interface that is subject to change and/or deletion. 18""" 19import math 20import numpy as np 21from mindspore.common.tensor import Tensor 22from mindspore.common.parameter import Parameter 23from mindspore.common.initializer import initializer 24from mindspore import nn 25from mindspore import context 26import mindspore.common.dtype as mstype 27from mindspore.ops import operations as P 28from mindspore.ops import functional as F 29from mindspore.nn.cell import Cell 30from mindspore._checkparam import Validator 31from mindspore import log as logger 32from mindspore.parallel._utils import _get_parallel_mode 33from mindspore.context import ParallelMode 34from .layers import _LayerNorm, _Linear, _check_input_shape, \ 35 _args_type_validator_check, _valid_type_checks, _valid_value_checks, \ 36 _check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value 37from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config 38from .moe import default_moe_config, MoE 39 40__all__ = [ 41 "AttentionMask", 42 "VocabEmbedding", 43 "MultiHeadAttention", 44 "FeedForward", 45 "TransformerEncoder", 46 "TransformerDecoder", 47 "TransformerEncoderLayer", 48 "TransformerDecoderLayer", 49 "Transformer", 50 "TransformerOpParallelConfig", 51 "EmbeddingOpParallelConfig"] 52 53 54class EmbeddingOpParallelConfig(_Config): 55 r""" 56 EmbeddingOpParallelConfig for the setting data parallel or row slice for the embedding table. 57 58 Args: 59 data_parallel (int): The data parallel way. Default: 1 60 model_parallel (int): The model parallel way. Default: 1 61 vocab_emb_dp (bool): Shard embedding in model parallel or data parallel. Default: True 62 63 Supported Platforms: 64 ``Ascend`` ``GPU`` 65 66 Examples: 67 >>> config=EmbeddingOpParallelConfig(data_parallel=1, model_parallel=1, vocab_emb_dp=True) 68 """ 69 70 def __init__(self, data_parallel=1, model_parallel=1, vocab_emb_dp=True): 71 self._dp_mp_config = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel) 72 Validator.check_bool(vocab_emb_dp, "vocab_emb_dp") 73 self.vocab_emb_dp = vocab_emb_dp 74 75 @property 76 def data_parallel(self): 77 return self._dp_mp_config.data_parallel 78 79 @data_parallel.setter 80 def data_parallel(self, value): 81 self._dp_mp_config.data_parallel = value 82 83 @property 84 def model_parallel(self): 85 return self._dp_mp_config.model_parallel 86 87 @model_parallel.setter 88 def model_parallel(self, value): 89 self._dp_mp_config.model_parallel = value 90 91 @property 92 def vocab_emb_dp(self): 93 return self._vocab_emb_dp 94 95 @vocab_emb_dp.setter 96 def vocab_emb_dp(self, value): 97 Validator.check_bool(value, "vocab_emb_dp") 98 self._vocab_emb_dp = value 99 100 @property 101 def dp_mp_config(self): 102 r""" 103 To obtain the DPMPlConfig for the setting data parallel, model parallel 104 105 Supported Platforms: 106 ``Ascend`` ``GPU`` 107 108 Examples: 109 >>> config=EmbeddingOpParallelConfig(data_parallel=1, model_parallel=1, vocab_emb_dp=True) 110 >>> parallel_config = config.dp_mp_config 111 """ 112 return self._dp_mp_config 113 114 115class TransformerOpParallelConfig(_Config): 116 r""" 117 TransformerOpParallelConfig for the setting global data parallel, model parallel and fusion group. 118 The parallel configure setting. 119 120 Note: 121 Except the recompute argument, other arguments will not be effective when the user doesn't set 122 auto_parallel_context to `SEMI_AUTO_PARALLEL` or `AUTO_PARALLEL`. 123 The micro_batch_num must be greater than or equal to pipeline_stage. The data_parallel\*model_parallel 124 \*pipeline_stage must be equal or less equal to the device. When setting the pipeline stage and 125 optimizer_shard, the config will overwrite the auto_parallel_context. 126 127 Args: 128 data_parallel (int): The data parallel way. Default: 1. 129 model_parallel (int): The model parallel way. Default: 1. 130 pipeline_stage (int): The number of the pipeline stage. Should be a positive value. Default: 1. 131 micro_batch_num (int): The microe size of the batches for the pipeline training. Default: 1. 132 optimizer_shard (bool): Whether to enable optimizer shard. Default False. 133 gradient_aggregation_group (int): The fusion group size of the optimizer state sharding. Default: 4. 134 recompute (bool): Enable recomputation of the transformer block or not. Default: False. 135 vocab_emb_dp (bool): Shard embedding in model parallel or data parallel. Default: True. 136 137 Supported Platforms: 138 ``Ascend`` ``GPU`` 139 140 Examples: 141 >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1) 142 """ 143 144 def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False, 145 optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True): 146 self.recompute = recompute 147 self.optimizer_shard = optimizer_shard 148 self.gradient_aggregation_group = gradient_aggregation_group 149 self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel, 150 vocab_emb_dp=vocab_emb_dp) 151 self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num) 152 153 @property 154 def recompute(self): 155 return self._recompute 156 157 @recompute.setter 158 def recompute(self, value): 159 Validator.check_bool(value, "recompute") 160 self._recompute = value 161 162 @property 163 def vocab_emb_dp(self): 164 return self._embed_dp_mp_config.vocab_emb_dp 165 166 @vocab_emb_dp.setter 167 def vocab_emb_dp(self, value): 168 self._embed_dp_mp_config.vocab_emb_dp = value 169 170 @property 171 def gradient_aggregation_group(self): 172 return self._gradient_aggregation_group 173 174 @gradient_aggregation_group.setter 175 def gradient_aggregation_group(self, value): 176 Validator.check_positive_int(value, "gradient_aggregation_group") 177 self._gradient_aggregation_group = value 178 179 @property 180 def micro_batch_num(self): 181 return self._pp_config.micro_batch_num 182 183 @micro_batch_num.setter 184 def micro_batch_num(self, value): 185 self._pp_config.micro_batch_num = value 186 187 @property 188 def model_parallel(self): 189 return self._embed_dp_mp_config.model_parallel 190 191 @model_parallel.setter 192 def model_parallel(self, value): 193 self._embed_dp_mp_config.model_parallel = value 194 195 @property 196 def data_parallel(self): 197 return self._embed_dp_mp_config.data_parallel 198 199 @data_parallel.setter 200 def data_parallel(self, value): 201 self._embed_dp_mp_config.data_parallel = value 202 203 @property 204 def pipeline_stage(self): 205 return self._pp_config.pipeline_stage 206 207 @pipeline_stage.setter 208 def pipeline_stage(self, value): 209 self._pp_config.pipeline_stage = value 210 211 @property 212 def optimizer_shard(self): 213 return self._optimizer_shard 214 215 @optimizer_shard.setter 216 def optimizer_shard(self, value): 217 Validator.check_bool(value, "optimizer_shard") 218 self._optimizer_shard = value 219 context.set_auto_parallel_context(enable_parallel_optimizer=value) 220 221 @property 222 def embedding_dp_mp_config(self): 223 r""" 224 To obtain the EmbeddingParallelConfig for the setting data parallel, model parallel and embedding 225 parallel. 226 227 Supported Platforms: 228 ``Ascend`` ``GPU`` 229 230 Examples: 231 >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, vocab_emb_dp=True) 232 >>> parallel_config = config.embedding_dp_mp_config 233 """ 234 return self._embed_dp_mp_config 235 236 @property 237 def dp_mp_config(self): 238 r""" 239 To obtain the EmbeddingParallelConfig for the setting data parallel, model parallel and embedding 240 parallel. 241 242 Supported Platforms: 243 ``Ascend`` ``GPU`` 244 245 Examples: 246 >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, vocab_emb_dp=True) 247 >>> parallel_config = config.dp_mp_config 248 """ 249 return self._embed_dp_mp_config.dp_mp_config 250 251 252default_transformer_config = TransformerOpParallelConfig() 253default_embedding_parallel_config = EmbeddingOpParallelConfig() 254 255 256class FeedForward(Cell): 257 r""" 258 The multilayer perceptron with two linear layers with dropout applied at final output. The first linear 259 will project the input dimension from hidden_size to ffn_hidden_size, the second linear will project the 260 dimension from ffn_hidden_size to hidden_size. The first linear is sharded on the relative dimension, 261 the second linear is sharded on the output dimension. The overview process can be 262 263 .. math:: 264 Dropout((xW_1+b_1)W_2 + b_2)) 265 266 where the :math:`W_1, W_2, b_1` and :math:`b_2` are trainable parameters. 267 268 Args: 269 hidden_size (int): The dimension of the inputs. 270 ffn_hidden_size (int): The intermediate hidden size. 271 dropout_rate (float): The dropout rate for the second linear's output. 272 hidden_act (str): The activation of the internal feedforward layer. Supports 'relu', 273 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 274 'hsigmoid', 'logsigmoid' and so on. Default: gelu. 275 expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used 276 and the first dimension in BatchMatMul indicate expert_num. Default: 1. 277 param_init_type (dtype.Number): The parameter initialization type. Should be dtype.float32 or dtype.float16. 278 Default: dtype.float32. 279 parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`. 280 Default `default_dpmp_config`, an instance of `OpParallelConfig` with 281 default args. 282 283 Inputs: 284 - **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`. 285 Float tensor. 286 287 Outputs: 288 Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] 289 or [batch * seq_length, hidden_size]`. 290 291 Raises: 292 ValueError: `hidden_act` is not a string. 293 TypeError: `parallel_config` is not a subclass of OpParallelConfig. 294 ValueError: `ffn_hidden_size` is not a multiple of the model parallel way. 295 ValueError: `hidden_size` is not a multiple of the model parallel way. 296 297 Supported Platforms: 298 ``Ascend`` ``GPU`` 299 300 Examples: 301 >>> import numpy as np 302 >>> from mindspore.parallel.nn import FeedForward 303 >>> from mindspore import dtype as mstype 304 >>> from mindspore import Tensor 305 >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1) 306 >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32) 307 >>> output = model(tensor) 308 >>> print(output.shape) 309 (2, 20, 15) 310 """ 311 312 @_args_type_validator_check(hidden_size=Validator.check_positive_int, 313 ffn_hidden_size=Validator.check_positive_int, 314 dropout_rate=Validator.check_non_negative_float, 315 hidden_act=_valid_type_checks([str], "FeedForward"), 316 param_init_type=_valid_value_checks([mstype.float32, mstype.float16], 317 "FeedForward"), 318 parallel_config=_valid_type_checks([OpParallelConfig], 319 "FeedForward")) 320 def __init__(self, hidden_size, 321 ffn_hidden_size, 322 dropout_rate, 323 hidden_act='gelu', 324 expert_num=1, 325 param_init_type=mstype.float32, 326 parallel_config=default_dpmp_config): 327 super(FeedForward, self).__init__() 328 _check_config(parallel_config) 329 dp = parallel_config.data_parallel 330 mp = parallel_config.model_parallel 331 if ffn_hidden_size % mp != 0: 332 raise ValueError(f"ffn_hidden_size {ffn_hidden_size} should be a multiple of the model parallel way {mp}") 333 if hidden_size % mp != 0: 334 raise ValueError(f"hidden_size {hidden_size} should be a multiple of the model parallel way {mp}") 335 if dropout_rate < 0 or dropout_rate >= 1: 336 raise ValueError(f"dropout_rate probability should be a number in range [0, 1.0), " 337 f"but got {dropout_rate}") 338 input_size = hidden_size 339 output_size = ffn_hidden_size 340 # Here, 'ep' stands for expert parallel number, which is equal to data parallel number. 341 ep = dp 342 # Project to ffn_hidden_size 343 self.mapping = _Linear(in_channels=input_size, 344 out_channels=output_size, 345 activation=hidden_act, 346 transpose_b=False, 347 expert_num=expert_num, 348 param_init_type=param_init_type) 349 350 if expert_num > 1: 351 self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)), 352 strategy_bias=((ep, 1, mp), (mp,)), 353 strategy_activation=((ep, 1, mp),)) 354 else: 355 self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)), 356 strategy_bias=((dp, mp), (mp,)), 357 strategy_activation=((dp, mp),)) 358 # Project back to hidden_size 359 self.projection = _Linear(in_channels=output_size, 360 out_channels=input_size, 361 transpose_b=False, 362 expert_num=expert_num, 363 param_init_type=param_init_type) 364 if expert_num > 1: 365 self.projection.shard(strategy_matmul=((ep, 1, mp), (ep, mp, 1)), 366 strategy_bias=((ep, 1, 1), (1,))) 367 else: 368 self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)), 369 strategy_bias=((dp, 1), (1,))) 370 self.projection.bias.parallel_optimizer = False 371 self.dropout = nn.Dropout(1 - dropout_rate) 372 self.dropout.dropout.shard(((dp, 1),)) 373 self.dropout_3d = nn.Dropout(1 - dropout_rate) 374 self.dropout_3d.dropout.shard(((dp, 1, 1),)) 375 self.cast = P.Cast() 376 377 def construct(self, x): 378 _check_input_shape(F.shape(x), "x", self.cls_name, [2, 3]) 379 _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name) 380 x = self.cast(x, mstype.float16) 381 # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size] 382 hidden = self.mapping(x) 383 output = self.projection(hidden) 384 # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size] 385 if len(F.shape(output)) == 3: 386 output = self.dropout_3d(output) 387 else: 388 output = self.dropout(output) 389 return output 390 391 392class AttentionMask(Cell): 393 r""" 394 Get the Lower triangular matrix from the input mask. The input mask is a 2D tensor (batch_size, seq_length) 395 with 1 and 0. 1 indicates the current position is a valid token, otherwise not. 396 397 Args: 398 seq_length(int): The sequence length of the input tensor. 399 parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`, 400 an instance of `OpParallelConfig` with default args. 401 402 Inputs: 403 - **input_mask** (Tensor) - The mask indicating whether each position is a valid input with 404 (batch_size, seq_length). 405 406 Outputs: 407 Tensor. The attention mask matrix with shape (batch_size, seq_length, seq_length). 408 409 Raises: 410 TypeError: `seq_length` is not an integer. 411 ValueError: `seq_length` is not a positive value. 412 TypeError: `parallel_config` is not a subclass of OpParallelConfig. 413 414 Supported Platforms: 415 ``Ascend`` ``GPU`` 416 417 Examples: 418 >>> import numpy as np 419 >>> from mindspore.parallel.nn import AttentionMask 420 >>> from mindspore import Tensor 421 >>> mask = AttentionMask(seq_length=4) 422 >>> mask_array = np.array([[1, 1, 1, 0]], np.float32) 423 >>> inputs = Tensor(mask_array) 424 >>> res = mask(inputs) 425 >>> print(res) 426 [[[1. 0. 0. 0], 427 [1. 1. 0. 0], 428 [1. 1. 1. 0], 429 [0. 0. 0. 0]]] 430 """ 431 432 @_args_type_validator_check(seq_length=Validator.check_positive_int, 433 parallel_config=_valid_type_checks([OpParallelConfig], "AttentionMask")) 434 def __init__(self, seq_length, parallel_config=default_dpmp_config): 435 super(AttentionMask, self).__init__() 436 self.seq_length = seq_length 437 self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ())) 438 self.reshape = P.Reshape() 439 self.mul = P.BatchMatMul().shard( 440 ((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) 441 self.expand_dim = P.ExpandDims().shard(((1, 1),)) 442 ones = np.ones(shape=(seq_length, seq_length)) 443 # Default lower triangle mask matrix 444 self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32) 445 self.multiply = P.Mul().shard(((parallel_config.data_parallel, 1, 1), (1, 1, 1))) 446 447 def construct(self, input_mask): 448 _check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 2) 449 _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name) 450 _check_input_shape_value(F.shape(input_mask), 1, "input_mask", self.cls_name, self.seq_length) 451 input_mask = P.Cast()(self.not_equal(input_mask, 0), mstype.float16) 452 input_shape = P.Shape()(input_mask) 453 shape_right = (input_shape[0], 1, input_shape[1]) 454 shape_left = input_shape + (1,) 455 # Mask the padded inputs 456 mask_left = self.reshape(input_mask, shape_left) 457 mask_right = self.reshape(input_mask, shape_right) 458 attention_mask = self.mul(mask_left, mask_right) 459 lower_traiangle = self.expand_dim(self.lower_triangle_mask, 0) 460 # the returned shape is [bs, seq_length, seq_length] 461 attention_mask = self.multiply( 462 attention_mask, lower_traiangle) 463 return attention_mask 464 465 466class VocabEmbedding(Cell): 467 """ 468 The embedding lookup table from the 0-th dim of the parameter table. When the parallel_config.vocab_emb_dp is 469 True and in the `AUTO_PARALLEL_MODE`, the embedding lookup will be a `parallel_config.data_parallel` 470 data parallel way, or will shard the parameter at the 0-th dimension in `parallel_config.model_parallel`, so-called 471 row slice of the embedding table. 472 473 Args: 474 vocab_size (int): Size of the dictionary of embeddings. 475 embedding_size (int): The size of each embedding vector. 476 param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table. 477 Refer to class `initializer` for the values of string when a string 478 is specified. Default: 'normal'. 479 parallel_config(EmbeddingOpParallelConfig): The parallel config of network. Default 480 `default_embedding_parallel_config`, an instance of `EmbeddingOpParallelConfig` with default args. 481 482 Inputs: 483 **input_ids** (Tensor) - The tokenized inputs with datatype int32 with shape (batch_size, seq_length) 484 485 Outputs: 486 Tuple, a tuple contains (`output`, `embedding_table`) 487 488 - **output** (Tensor) - The embedding vector for the input with shape (batch_size, 489 seq_length, embedding_size). 490 - **weight** (Tensor) - The embedding table with shape (vocab_size, embedding_size). 491 492 Raises: 493 ValueError: If the parallel_config.vocab_emb_dp is True, the vocab size is not a multiple of 494 parallel_config.model_parallel 495 ValueError: `vocab_size` is not a positive value. 496 ValueError: `embedding_size` is not a positive value. 497 TypeError: `parallel_config` is not a subclass of OpParallelConfig. 498 499 Supported Platforms: 500 ``Ascend`` ``GPU`` 501 502 Examples: 503 >>> import numpy as np 504 >>> from mindspore.parallel.nn import VocabEmbedding 505 >>> from mindspore import Tensor 506 >>> from mindspore import dtype as mstype 507 >>> model = VocabEmbedding(vocab_size=30, embedding_size=30) 508 >>> tensor = Tensor(np.ones((20, 15)), mstype.int32) 509 >>> output, table = model(tensor) 510 >>> print(output.shape) 511 (20, 15, 30) 512 >>> print(table.shape) 513 (30, 30) 514 """ 515 516 @_args_type_validator_check(vocab_size=Validator.check_positive_int, 517 embedding_size=Validator.check_positive_int, 518 parallel_config=_valid_type_checks([EmbeddingOpParallelConfig], "VocabEmbedding")) 519 def __init__(self, vocab_size, embedding_size, parallel_config=default_embedding_parallel_config, 520 param_init='normal'): 521 super(VocabEmbedding, self).__init__() 522 _check_config(parallel_config) 523 self.vocab_size = vocab_size 524 self.embedding_size = embedding_size 525 self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), 526 name='embedding_table', parallel_optimizer=False) 527 if parallel_config.vocab_emb_dp: 528 self.gather = P.GatherV2().shard(((1, 1), (parallel_config.data_parallel, 1))) 529 logger.info(f"Using {parallel_config.data_parallel} data parallel for the embedding lookup.") 530 else: 531 if self.vocab_size % parallel_config.model_parallel != 0: 532 raise ValueError(f"The vocab size of the embedding {self.vocab_size} must be a " 533 f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.") 534 self.gather = P.GatherV2().shard(((parallel_config.model_parallel, 1), (1, 1))) 535 logger.info(f"Using {parallel_config.data_parallel} model parallel for the embedding lookup.") 536 537 def construct(self, input_ids): 538 _check_input_shape(F.shape(input_ids), "input_ids", self.cls_name, 2) 539 _check_input_dtype(F.dtype(input_ids), "input_ids", [mstype.int32], self.cls_name) 540 output = self.gather(self.embedding_table, input_ids, 0) 541 return output, self.embedding_table 542 543 544class MultiHeadAttention(Cell): 545 r""" 546 This is an implementation of multihead attention in the paper `Attention is all you need 547 <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector with source length, and the 548 key and value vector with target length, the attention will be performed as the following 549 550 .. math:: 551 MultiHeadAttention(query, key, vector) = Concat(head_1, \dots, head_h)W^O 552 553 where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. The default is with a bias. 554 555 if query, key and value tensor is same, then it will be self attention. 556 557 Args: 558 batch_size(int): The batch size of the input tensor. 559 src_seq_length(int): The sequence length of the query vector. 560 tgt_seq_length(int): The sequence length of the key and value vector. 561 hidden_size(int): The hidden size of the input. 562 num_heads(int): The number of the heads. 563 hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1 564 attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1 565 compute_dtype(dtype.Number): The computation type of dense. Default dtype.float16. 566 Should be dtype.float32 or dtype.float16. 567 param_init_type(dtype.Number): The parameter initialization type of the module. Default dtype.float32. 568 Should be dtype.float32 or dtype.float16. 569 softmax_compute_type(dtype.Number): The type of softmax computation module. Default dtype.float32. 570 Should be dtype.float32 or dtype.float16. 571 use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two 572 words and want to generate the ten more words. We just need to compute the two words's state only once, 573 and generate the next word one by one. When use_past is True, there are two steps to run the prediction. 574 The first step, set the is_first_iteration to be True by 575 `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the 576 is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment, 577 pass the single step's input tensor, and loop it. Default False. 578 parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`, 579 an instance of `OpParallelConfig` with default args. 580 581 Inputs: 582 - **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size) or 583 (batch_size * src_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. Otherwise, 584 must be (batch_size, 1, hidden_size) 585 - **key_tensor** (Tensor) - the key vector with shape (batch_size, tgt_seq_length, hidden_size) or 586 (batch_size * tgt_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. Otherwise, 587 must be (batch_size, 1, hidden_size) 588 - **value_tensor** (Tensor) - the value vector with shape (batch_size, tgt_seq_length, hidden_size) or 589 (batch_size * tgt_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. Otherwise, 590 must be (batch_size, 1, hidden_size) 591 - **attention_mask** (Tensor) - the attention mask matrix with shape (batch_size, src_seq_length, 592 tgt_seq_length), if the use_past is False or is_first_iteration=True. Otherwise, 593 must be (batch_size, 1, tgt_seq_length) 594 - **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length). 595 The past calculated key vector. Used for incremental prediction when the use_past is True. 596 Default None. 597 - **value_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, tgt_seq_length, size_per_head). 598 The past calculated value vector. Used for incremental prediction when the use_past is True. 599 Default None. 600 - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index. 601 Used for incremental prediction when the use_past is True. Default None. 602 603 Outputs: 604 Tuple, a tuple contains(`output`, `layer_present`) 605 606 - **output** (Tensor) - Tensor, the float tensor of the output of the layer with 607 shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size), 608 if the use_past is False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size). 609 610 - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with 611 ((batch_size, num_heads, size_per_head, tgt_seq_length), 612 (batch_size, num_heads, tgt_seq_length, size_per_head)). 613 614 Supported Platforms: 615 ``Ascend`` ``GPU`` 616 617 Examples: 618 >>> import numpy as np 619 >>> from mindspore.parallel.nn import MultiHeadAttention 620 >>> from mindspore import dtype as mstype 621 >>> from mindspore import Tensor 622 >>> model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20, 623 ... num_heads=3) 624 >>> from_tensor = Tensor(np.ones((2, 20, 15)), mstype.float32) 625 >>> to_tensor = Tensor(np.ones((2, 20, 15)), mstype.float16) 626 >>> attention_mask = Tensor(np.ones((2, 20, 20)), mstype.float16) 627 >>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask) 628 >>> print(attn_out.shape) 629 (2, 20, 15) 630 >>> print(past[0].shape) 631 (2, 3, 5, 20) 632 >>> print(past[1].shape) 633 (2, 3, 20, 5) 634 # When use use_past=True, it includes two steps to implement the incremental prediction. 635 # Step 1: set is_first_iteration=True, and input the full sequence length's state. 636 # We need to prepare the memory parameters for saving key and value states firstly. 637 >>> model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20, 638 ... num_heads=3, use_past=True) 639 >>> key_past = Tensor(np.zeros(shape=(2, 3, 5, 20)), mstype.float16) 640 >>> value_past = Tensor(np.zeros(shape=(2, 3, 20, 5)), mstype.float16) 641 >>> batch_valid_length = Tensor(np.ones((2,)), mstype.int32) 642 # Set is_first_iteration=True to generate the full memory states 643 >>> model.add_flags_recursive(is_first_iteration=True) 644 >>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past, 645 ... batch_valid_length) 646 >>> print(attn_out.shape) 647 (2, 20, 15) 648 >>> print(past[0].shape) 649 (2, 3, 5, 20) 650 >>> print(past[1].shape) 651 (2, 3, 20, 5) 652 >>> from_tensor = Tensor(np.ones((2, 1, 15)), mstype.float32) 653 >>> to_tensor = Tensor(np.ones((2, 1, 15)), mstype.float16) 654 >>> attention_mask = Tensor(np.ones((2, 1, 20)), mstype.float16) 655 # Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than the full 656 # sequence. 657 >>> model.add_flags_recursive(is_first_iteration=False) 658 >>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past, 659 ... batch_valid_length) 660 >>> print(attn_out.shape) 661 (2, 1, 15) 662 >>> print(past[0].shape) 663 (2, 3, 5, 20) 664 >>> print(past[1].shape) 665 (2, 3, 20, 5) 666 """ 667 668 @_args_type_validator_check(batch_size=Validator.check_positive_int, 669 hidden_size=Validator.check_positive_int, 670 num_heads=Validator.check_positive_int, 671 src_seq_length=Validator.check_positive_int, 672 tgt_seq_length=Validator.check_positive_int, 673 attention_dropout_rate=Validator.check_non_negative_float, 674 hidden_dropout_rate=Validator.check_non_negative_float, 675 compute_dtype=_valid_value_checks([mstype.float32, mstype.float16], 676 "MultiHeadAttention"), 677 softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 678 "MultiHeadAttention"), 679 param_init_type=_valid_value_checks([mstype.float32, mstype.float16], 680 "MultiHeadAttention"), 681 parallel_config=_valid_type_checks([OpParallelConfig], 682 "MultiHeadAttention"), 683 use_past=Validator.check_bool) 684 def __init__(self, batch_size, 685 src_seq_length, 686 tgt_seq_length, 687 hidden_size, 688 num_heads, 689 hidden_dropout_rate=0.1, 690 attention_dropout_rate=0.1, 691 compute_dtype=mstype.float16, 692 softmax_compute_type=mstype.float32, 693 param_init_type=mstype.float32, 694 use_past=False, 695 parallel_config=default_dpmp_config): 696 super(MultiHeadAttention, self).__init__() 697 _check_config(parallel_config) 698 self.is_parallel_mode = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) 699 self.src_seq_length = src_seq_length 700 self.tgt_seq_length = tgt_seq_length 701 self.hidden_size = hidden_size 702 self.batch_size = batch_size 703 if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1: 704 raise ValueError(f"hidden_dropout_rate probability should be a number in range [0, 1.0), " 705 f"but got {hidden_dropout_rate}") 706 if attention_dropout_rate < 0 or attention_dropout_rate >= 1: 707 raise ValueError(f"attention_dropout_rate probability should be a number in range [0, 1.0), " 708 f"but got {attention_dropout_rate}") 709 if hidden_size % num_heads != 0: 710 raise ValueError(f"The hidden size {hidden_size} should be a multiple of num_heads {num_heads}") 711 if num_heads % parallel_config.model_parallel != 0: 712 raise ValueError(f"The number of heads {num_heads} must be a " 713 f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.") 714 if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0: 715 raise ValueError(f"The batch size {batch_size} must be a " 716 f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.") 717 self.is_first_iteration = True 718 # Output layer 719 self.projection = _Linear(in_channels=hidden_size, 720 out_channels=hidden_size, 721 transpose_b=False, 722 param_init_type=param_init_type).to_float(compute_dtype) 723 self.projection.shard(strategy_bias=((parallel_config.data_parallel, 1), (1,)), 724 strategy_matmul=((parallel_config.data_parallel, parallel_config.model_parallel), 725 (parallel_config.model_parallel, 1))) 726 self.projection.bias.parallel_optimizer = False 727 self.transpose = P.Transpose().shard(((parallel_config.data_parallel, 1, parallel_config.model_parallel, 1),)) 728 self.merger_head_transpose = P.Transpose().shard( 729 ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),)) 730 self.reshape = P.Reshape() 731 self.n_head = num_heads 732 # embedding size per head 733 self.size_per_head = hidden_size // self.n_head 734 self.concat_k = P.Concat(axis=3) 735 self.concat_v = P.Concat(axis=2) 736 self.multiply_data = Tensor([ 737 -10000.0, 738 ], dtype=softmax_compute_type) 739 self.batch_matmul = P.BatchMatMul().shard( 740 ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1), 741 (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1))) 742 self.real_div = P.RealDiv().shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1), ())) 743 self.sub = P.Sub().shard( 744 ((1,), (parallel_config.data_parallel, 1, 1, 1))) 745 self.mul = P.Mul().shard( 746 ((parallel_config.data_parallel, 1, 1, 1), (1,))) 747 self.add = P.Add().shard( 748 ((parallel_config.data_parallel, 1, 1, 1), 749 (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1))) 750 # Normalize factor for attention, sqrt(dk) as widely used 751 self.scale_factor = Tensor(math.sqrt(self.size_per_head)) 752 self.use_past = use_past 753 self.dropout = nn.Dropout(1 - hidden_dropout_rate) 754 self.dropout.dropout.shard(((parallel_config.data_parallel, 1),)) 755 self.prob_dropout = nn.Dropout(1 - attention_dropout_rate) 756 self.prob_dropout.dropout.shard( 757 ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),)) 758 self.softmax = nn.Softmax().to_float(softmax_compute_type) 759 self.softmax.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1),)) 760 self.expand_dims = P.ExpandDims().shard(((parallel_config.data_parallel, 1, 1),)) 761 762 # Query 763 self.dense1 = _Linear(hidden_size, 764 hidden_size, 765 param_init_type=param_init_type).to_float(compute_dtype) 766 self.dense1.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)), 767 strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel), 768 (parallel_config.model_parallel,))) 769 # Key 770 self.dense2 = _Linear(hidden_size, 771 hidden_size, 772 param_init_type=param_init_type).to_float(compute_dtype) 773 self.dense2.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)), 774 strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel), 775 (parallel_config.model_parallel,))) 776 777 # Value 778 self.dense3 = _Linear(hidden_size, 779 hidden_size, 780 param_init_type=param_init_type).to_float(compute_dtype) 781 self.dense3.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)), 782 strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel), 783 (parallel_config.model_parallel,))) 784 self.dtype = compute_dtype 785 self.softmax_dtype = softmax_compute_type 786 if self.use_past: 787 # operators used for state reuse 788 seq_range = np.arange(src_seq_length).reshape(1, 1, -1) 789 self.range = Tensor(np.tile(seq_range, (batch_size, 1, 1)), mstype.int32) 790 self.seq_length = src_seq_length 791 self.attention_mask = Tensor(np.tril(np.ones(shape=(self.seq_length, self.seq_length))), mstype.int32) 792 self.slice = P.StridedSlice().shard(((1, 1, 1, 1),)) 793 self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ())) 794 self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),)) 795 self.expand_dims = P.ExpandDims().shard(((1, 1, 1),)) 796 self.tensor_le = P.LessEqual().shard(((1, 1, 1), (1, 1, 1))) 797 self.add = P.Add().shard(((1, 1, 1, 1), (1, 1, 1, 1))) 798 self.equal = P.Equal().shard(((1, 1, 1), (1, 1, 1))) 799 self.sub1 = P.Sub().shard(((1,), ())) 800 self.tile = P.Tile().shard(((1, 1, 1, 1),)) 801 self.less = P.Less().shard(((1, 1, 1), (1, 1, 1))) 802 self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1))) 803 804 def construct(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None, 805 value_past=None, batch_valid_length=None): 806 self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past, 807 value_past, batch_valid_length) 808 query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor, 809 key_tensor, 810 value_tensor, 811 attention_mask) 812 813 # multi head attention: query, key, value are derived from the same inputs 814 query = self.dense1(query_tensor) 815 key = self.dense2(key_tensor) 816 value = self.dense3(value_tensor) 817 # the returned shape is [bs, num_heads, seq_length, size_per_head] 818 query = self.transpose( 819 F.reshape( 820 query, 821 (batch_size, -1, self.n_head, self.size_per_head)), 822 (0, 2, 1, 3)) 823 # the returned shape is [bs, size_per_head, seq_length, num_heads] 824 key = self.transpose( 825 F.reshape( 826 key, (batch_size, -1, self.n_head, self.size_per_head)), 827 (0, 2, 3, 1)) 828 # the returned shape is [bs, num_heads, seq_length, size_per_head] 829 value = self.transpose( 830 F.reshape( 831 value, 832 (batch_size, -1, self.n_head, self.size_per_head)), 833 (0, 2, 1, 3)) 834 # support input shape is [bs, seq, seq] or [bs, heads, seq, seq] 835 if len(F.shape(attention_mask)) == 3: 836 # expand attention mask from [bs, seq, seq] -> [bs, 1, seq, seq] 837 attention_mask = self.expand_dims(attention_mask, 1) 838 # key and value for current token(s) 839 key_present = key 840 value_present = value 841 if self.use_past: 842 # The first graph with the input size of (bs, seq_length) 843 if self.is_first_iteration: 844 # Get the valid input length without padding 845 valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(-1, 1, 1)), self.dtype) 846 # Cover the key and value numbers corresponding to the padding position 847 key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2)) 848 value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3)) 849 # The second graph with the inpus size of (bs, 1) 850 # the shape of query is (bs, num_heads, 1, size_per_head) 851 # the shape of key is (bs, num_heads, size_per_head, 1) 852 # the shape of value is (bs, num_heads, 1, size_per_head) 853 else: 854 # Get the current token position index 855 valid_length = self.reducesum(F.cast(self.not_equal(self.slice(key_past, (0, 0, 0, 0), 856 (F.shape(key_tensor)[0], 1, 1, 857 self.src_seq_length), 858 (1, 1, 1, 1)), 859 0), mstype.float32), (1, 2, 3)) 860 valid_length = F.reshape(valid_length, (-1, 1, 1)) 861 valid_length_vector = F.cast(self.equal(valid_length, self.range), self.dtype) 862 # Pad the key and value to seq_length with only the position index not zero 863 current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)), 864 self.expand_dims(valid_length_vector, 2)) 865 current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)), 866 self.expand_dims(valid_length_vector, 3)) 867 # Concat the previous saved state and current state 868 key = self.add(key_past, current_key) 869 value = self.add(value_past, current_value) 870 # Update key_present and value_present for state update 871 key_present = key 872 value_present = value 873 attention_mask = F.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1)) 874 875 layer_present = (key_present, value_present) 876 # multi head attention considering attention mask 877 # the return shape is [bs * seq_length, hidden_size] 878 attention = self._attn(query, key, value, attention_mask) 879 # Output 880 output = self.projection(attention) 881 output = self.dropout(output) 882 output = F.reshape(output, ori_shape) 883 return output, layer_present 884 885 def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None, 886 value_past=None, batch_valid_length=None): 887 r"""Check inputs""" 888 if not self.use_past or (self.use_past and self.is_first_iteration): 889 _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name, 890 [[self.batch_size, self.src_seq_length, self.hidden_size], 891 [self.batch_size * self.src_seq_length, self.hidden_size]]) 892 _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name, 893 [[self.batch_size, self.tgt_seq_length, self.hidden_size], 894 [self.batch_size * self.tgt_seq_length, self.hidden_size]]) 895 _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name, 896 [[self.batch_size, self.tgt_seq_length, self.hidden_size], 897 [self.batch_size * self.tgt_seq_length, self.hidden_size]]) 898 _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, 899 [self.batch_size, self.src_seq_length, self.tgt_seq_length]) 900 else: 901 _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name, 902 [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]]) 903 _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name, 904 [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]]) 905 _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name, 906 [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]]) 907 _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, 908 [[self.batch_size, 1, self.tgt_seq_length], [self.batch_size, self.hidden_size]]) 909 910 _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name) 911 _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name) 912 _check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name) 913 _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name) 914 915 key_is_tensor = isinstance(key_past, Tensor) 916 value_is_tensor = isinstance(value_past, Tensor) 917 batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor) 918 key_is_default = key_past is None 919 value_is_default = value_past is None 920 batch_is_default = batch_valid_length is None 921 _check_past_none_input_none(self.use_past, "key_past", self.cls_name, None, key_is_tensor, 922 key_is_default) 923 _check_past_none_input_none(self.use_past, "value_past", self.cls_name, None, value_is_tensor, 924 value_is_default) 925 _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None, 926 batch_valid_length_is_tensor, batch_is_default) 927 if self.use_past: 928 _check_shape_equal(F.shape(key_past), "key_past", self.cls_name, 929 [self.batch_size, self.n_head, self.size_per_head, self.tgt_seq_length]) 930 _check_input_dtype(F.dtype(key_past), "key_past", [mstype.float16], self.cls_name) 931 _check_shape_equal(F.shape(value_past), "value_past", self.cls_name, 932 [self.batch_size, self.n_head, self.tgt_seq_length, self.size_per_head]) 933 _check_input_dtype(F.dtype(value_past), "value_past", [mstype.float16], self.cls_name) 934 _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size]) 935 _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name) 936 return True 937 938 def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask): 939 """convert a nd tensor to a 2d tensor""" 940 query_shape = F.shape(query_tensor) 941 query_tensor = F.reshape(query_tensor, (-1, query_shape[-1])) 942 key_shape = F.shape(key_tensor) 943 key_tensor = F.reshape(key_tensor, (-1, key_shape[-1])) 944 value_shape = F.shape(value_tensor) 945 value_tensor = F.reshape(value_tensor, (-1, value_shape[-1])) 946 return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape 947 948 def _merge_heads(self, x): 949 """ 950 convert a 4d input to a 2d output 951 952 Inputs: 953 x: input tensor 954 955 Output: 956 x_merge: the 2d output 957 """ 958 x = self.merger_head_transpose( 959 x, (0, 2, 1, 3)) # bs, seq_length, head, size_per_head 960 x_shape = P.Shape()(x) 961 new_shape = (-1, x_shape[-2] * x_shape[-1]) 962 x_merge = self.reshape(x, new_shape) 963 return x_merge 964 965 def _attn(self, query, key, value, attention_mask): 966 """ 967 Get the weighted score along the seq_length 968 969 Inputs: 970 query: the query matrix 971 key: the key matrix 972 value: the value matrix 973 attention_mask: the attention mask matrix with shape (batch_size, 974 1, seq_length, seq_length) 975 Outputs: 976 weighted_values: Tensor, the weighted sum scores 977 """ 978 # Normalize query and key before MatMul, default off 979 # Attention score [bs, num_heads, seq_length, seq_length] 980 score = self.batch_matmul(query, key) 981 # Normalize after query and key MatMul 982 score = self.real_div( 983 score, 984 P.Cast()(self.scale_factor, P.DType()(score))) 985 986 ori_dtype = P.DType()(score) 987 score = P.Cast()(score, self.softmax_dtype) 988 989 # for input size of (bs, 1) namely the second graph, 990 # the shape of attention_mask matrix should be (bs, 1, 1, seq_length) 991 if self.use_past and not self.is_first_iteration: 992 # Calculate the current total token 993 current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0), 994 (F.shape(query)[0], 1, 1, self.seq_length), 995 (1, 1, 1, 1)), 996 0), mstype.float32), (1, 2, 3)) 997 # Get the precise position index 998 index = self.sub1(F.cast(current_index, mstype.int32), 1) 999 index = F.reshape(index, (-1, 1, 1)) 1000 # Calculate the attention_mask matrix via the position index 1001 attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32) 1002 attention_mask = self.expand_dims(attention_mask, 2) 1003 1004 # Minus 10000 for the position where masked to exclude them from softmax 1005 multiplu_out = self.sub( 1006 P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)), 1007 P.Cast()(attention_mask, P.DType()(score))) 1008 1009 adder = self.mul(multiplu_out, self.multiply_data) 1010 attention_scores = self.add(adder, score) 1011 1012 shape = F.shape(attention_scores) 1013 # attention probs 1014 attention_probs = self.softmax( 1015 F.reshape(attention_scores, 1016 (shape[0], -1, shape[-1]))) 1017 attention_probs = P.Cast()(attention_probs, ori_dtype) 1018 attention_probs = F.reshape(attention_probs, shape) 1019 1020 attention_probs = self.prob_dropout(attention_probs) 1021 # Weighted sum output [bs, num_heads, seq_length, size_per_head] 1022 weighted_values = self.batch_matmul(attention_probs, value) 1023 attention_merge = self._merge_heads(weighted_values) 1024 return attention_merge 1025 1026 1027class TransformerEncoderLayer(Cell): 1028 r""" 1029 Transformer Encoder Layer. This is an implementation of the single layer of the transformer 1030 encoder layer, including multihead attention and feedward layer. 1031 1032 Args: 1033 batch_size(int): The batch size of the input tensor. 1034 hidden_size(int): The hidden size of the input. 1035 seq_length(int): The input sequence length. 1036 ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer. 1037 num_heads(int): The number of the heads. 1038 hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1 1039 attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1 1040 post_layernorm_residual(bool): Do residuals adds before the layernorm. Default False. 1041 hidden_act(str): The activation of the internal feedforward layer. Supports 'relu', 1042 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 1043 'hsigmoid', 'logsigmoid' and so on. Default: gelu. 1044 layernorm_compute_type(dtype.Number): The computation type of the layernorm. 1045 Should be dtype.float32 or dtype.float16. Default dtype.float32. 1046 softmax_compute_type(dtype.Number): The computation type of the softmax in the attention. 1047 Should be dtype.float32 or dtype.float16. Default mstype.float32. 1048 param_init_type(dtype.Number): The parameter initialization type of the module. 1049 Should be dtype.float32 or dtype.float16. Default dtype.float32. 1050 use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two 1051 words and want to generate the ten more words. We just need to compute the two words's state only once, 1052 and generate the next word one by one. When use_past is True, there are two steps to run the prediction. 1053 The first step, set the is_first_iteration to be True by 1054 `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the 1055 is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment, 1056 pass the single step's input tensor, and loop it. Default False. 1057 moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). 1058 parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`, 1059 an instance of `OpParallelConfig` with default args. 1060 1061 Inputs: 1062 - **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or 1063 [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise, 1064 should be [batch_size, 1, hidden_size] 1065 - **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length], 1066 if the use_past is False or is_first_iteration=True. Otherwise, should be [batch_size, 1, hidden_size] 1067 - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and 1068 past value parameter used in the incremental prediction. Only valid when use_past is True. Default True. 1069 - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. Used 1070 for incremental prediction when the use_past is True. Default None. 1071 1072 Outputs: 1073 Tuple, a tuple contains(`output`, `layer_present`). 1074 1075 - **output** (Tensor) - The float tensor of the output of the layer with 1076 shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is 1077 False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size) 1078 1079 - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with 1080 ((batch_size, num_heads, size_per_head, seq_length), 1081 (batch_size, num_heads, seq_length, size_per_head)). 1082 1083 Supported Platforms: 1084 ``Ascend`` ``GPU`` 1085 1086 Examples: 1087 >>> import numpy as np 1088 >>> from mindspore import dtype as mstype 1089 >>> from mindspore.parallel.nn import TransformerEncoderLayer 1090 >>> from mindspore import Tensor 1091 >>> model = TransformerEncoderLayer(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16, 1092 ... num_heads=2) 1093 >>> encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32) 1094 >>> encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16) 1095 >>> output, past = model(encoder_input_value, encoder_input_mask) 1096 >>> print(output.shape) 1097 (2, 16, 8) 1098 >>> print(past[0].shape) 1099 (2, 2, 4, 16) 1100 >>> print(past[1].shape) 1101 (2, 2, 16, 4) 1102 # When use use_past=True, it includes two steps to implement the incremental prediction. 1103 # Step 1: set is_first_iteration=True, and input the full sequence length's state. 1104 >>> batch_valid_length = Tensor(np.ones((2,)), mstype.int32) 1105 >>> init_reset = Tensor([True], mstype.bool_) 1106 # Set is_first_iteration=True to generate the full memory states 1107 >>> model = TransformerEncoderLayer(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16, 1108 ... num_heads=2, use_past=True) 1109 >>> model.add_flags_recursive(is_first_iteration=True) 1110 >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length) 1111 >>> print(hidden.shape) 1112 (2, 16, 8) 1113 >>> print(past[0].shape) 1114 (2, 2, 4, 16) 1115 >>> print(past[1].shape) 1116 (2, 2, 16, 4) 1117 >>> encoder_input_value = Tensor(np.ones((2, 1, 8)), mstype.float32) 1118 >>> encoder_input_mask = Tensor(np.ones((2, 1, 16)), mstype.float16) 1119 >>> init_reset = Tensor([False], mstype.bool_) 1120 # Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than the full 1121 # sequence. 1122 >>> model.add_flags_recursive(is_first_iteration=False) 1123 >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length) 1124 >>> print(hidden.shape) 1125 (2, 1, 8) 1126 >>> print(past[0].shape) 1127 (2, 2, 4, 16) 1128 >>> print(past[1].shape) 1129 (2, 2, 16, 4) 1130 """ 1131 1132 @_args_type_validator_check(batch_size=Validator.check_positive_int, 1133 hidden_size=Validator.check_positive_int, 1134 num_heads=Validator.check_positive_int, 1135 ffn_hidden_size=Validator.check_positive_int, 1136 seq_length=Validator.check_positive_int, 1137 attention_dropout_rate=Validator.check_non_negative_float, 1138 hidden_dropout_rate=Validator.check_non_negative_float, 1139 hidden_act=_valid_type_checks([str], "TransformerEncoderLayer"), 1140 post_layernorm_residual=Validator.check_bool, 1141 layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 1142 "TransformerEncoderLayer"), 1143 softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 1144 "TransformerEncoderLayer"), 1145 param_init_type=_valid_value_checks([mstype.float32, mstype.float16], 1146 "TransformerEncoderLayer"), 1147 parallel_config=_valid_type_checks([OpParallelConfig], 1148 "TransformerEncoderLayer"), 1149 use_past=Validator.check_bool) 1150 def __init__(self, 1151 batch_size, 1152 hidden_size, 1153 ffn_hidden_size, 1154 num_heads, 1155 seq_length, 1156 attention_dropout_rate=0.1, 1157 hidden_dropout_rate=0.1, 1158 post_layernorm_residual=False, 1159 layernorm_compute_type=mstype.float32, 1160 softmax_compute_type=mstype.float32, 1161 param_init_type=mstype.float32, 1162 hidden_act='gelu', 1163 use_past=False, 1164 moe_config=default_moe_config, 1165 parallel_config=default_dpmp_config): 1166 super(TransformerEncoderLayer, self).__init__() 1167 _check_config(parallel_config) 1168 if num_heads % parallel_config.model_parallel != 0: 1169 raise ValueError( 1170 f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel}, " 1171 f"but found {num_heads}") 1172 if hidden_size % parallel_config.model_parallel != 0: 1173 raise ValueError( 1174 f"hidden_size must be divisibled by the model parallel way {parallel_config.model_parallel}, " 1175 f"but found {hidden_size}") 1176 if ffn_hidden_size % parallel_config.model_parallel != 0: 1177 raise ValueError( 1178 f"ffn_hidden_size must be divisibled by the model parallel way {parallel_config.model_parallel}, " 1179 f"but found {ffn_hidden_size}") 1180 self.use_past = use_past 1181 self.seq_length = seq_length 1182 self.hidden_size = hidden_size 1183 self.batch_size = batch_size 1184 self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type) 1185 self.layernorm1.shard(((parallel_config.data_parallel, 1),)) 1186 self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type) 1187 self.layernorm2.shard(((parallel_config.data_parallel, 1),)) 1188 1189 self.attention = MultiHeadAttention(batch_size=batch_size, 1190 src_seq_length=seq_length, 1191 tgt_seq_length=seq_length, 1192 hidden_size=hidden_size, 1193 num_heads=num_heads, 1194 hidden_dropout_rate=hidden_dropout_rate, 1195 attention_dropout_rate=attention_dropout_rate, 1196 softmax_compute_type=softmax_compute_type, 1197 param_init_type=param_init_type, 1198 use_past=use_past, 1199 parallel_config=parallel_config) 1200 self.use_moe = (moe_config.expert_num > 1) 1201 if self.use_moe is True: 1202 self.output = MoE(hidden_size=hidden_size, 1203 dropout_rate=hidden_dropout_rate, 1204 ffn_hidden_size=ffn_hidden_size, 1205 param_init_type=param_init_type, 1206 hidden_act=hidden_act, 1207 moe_config=moe_config, 1208 parallel_config=parallel_config) 1209 else: 1210 # Feed Forward Network, FFN 1211 self.output = FeedForward(hidden_size=hidden_size, 1212 dropout_rate=hidden_dropout_rate, 1213 ffn_hidden_size=ffn_hidden_size, 1214 param_init_type=param_init_type, 1215 hidden_act=hidden_act, 1216 parallel_config=parallel_config) 1217 self.post_layernorm_residual = post_layernorm_residual 1218 self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1))) 1219 self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) 1220 self.dtype = mstype.float16 1221 self.key_past = None 1222 self.value_past = None 1223 1224 if self.use_past: 1225 # operator used for state reuse 1226 self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),)) 1227 self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ())) 1228 self.slice = P.StridedSlice().shard(((1, 1, 1, 1),)) 1229 size_per_head = int(hidden_size / num_heads) 1230 self.key_shape = (batch_size, num_heads, size_per_head, seq_length) 1231 self.value_shape = (batch_size, num_heads, seq_length, size_per_head) 1232 # parameters saving key and value states 1233 self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past") 1234 self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past") 1235 self.tile = P.Tile().shard(((1, 1),)) 1236 self.mul = P.Mul().shard(((1, 1, 1, 1), (1,))) 1237 self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1))) 1238 1239 def construct(self, x, input_mask, init_reset=True, batch_valid_length=None): 1240 self._check_input(x, input_mask, init_reset, batch_valid_length) 1241 x_shape = F.shape(x) 1242 x = F.reshape(x, (-1, x_shape[-1])) 1243 input_x = self.layernorm1(x) 1244 input_x = F.cast(input_x, self.dtype) 1245 1246 # indicate whether reset saved states 1247 key_reset = None 1248 value_reset = None 1249 1250 if self.use_past: 1251 # reset states, init_reset True for reuse and False for reset 1252 key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype))) 1253 value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype))) 1254 # add dependency for desired execution order 1255 input_x = F.depend(input_x, key_reset) 1256 input_x = F.depend(input_x, value_reset) 1257 1258 attention, layer_present = self.attention(input_x, input_x, input_x, input_mask, 1259 self.key_past, self.value_past, batch_valid_length) 1260 # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm 1261 if self.post_layernorm_residual: 1262 x = self.add(input_x, attention) 1263 # For pre-layernorm the inputs for residual path are output of self-attention and input of this layer 1264 else: 1265 x = self.add(x, attention) 1266 1267 output_x = self.layernorm2(x) 1268 output_x = F.cast(output_x, self.dtype) 1269 aux_loss = None 1270 if self.use_moe is True: 1271 mlp_logit, aux_loss = self.output(output_x) 1272 else: 1273 mlp_logit = self.output(output_x) 1274 1275 value_update = None 1276 key_update = None 1277 if self.use_past: 1278 # current key and value 1279 key_present, value_present = layer_present 1280 # update key and value calculated this step 1281 key_update = self.assign(self.key_past, key_present) 1282 value_update = self.assign(self.value_past, value_present) 1283 # add dependency for desired execution order 1284 key_update = F.depend(key_update, key_reset) 1285 value_update = F.depend(value_update, value_reset) 1286 1287 # add dependency for desired execution order 1288 mlp_logit = F.depend(mlp_logit, value_update) 1289 mlp_logit = F.depend(mlp_logit, key_update) 1290 1291 # if shape is 3d, we reshape the inputs of the add 1292 if len(x_shape) == 3: 1293 output_x = P.Reshape()(output_x, x_shape) 1294 mlp_logit = P.Reshape()(mlp_logit, x_shape) 1295 x = P.Reshape()(x, x_shape) 1296 1297 if self.post_layernorm_residual: 1298 output = self.add_3d(output_x, mlp_logit) 1299 else: 1300 output = self.add_3d(x, mlp_logit) 1301 else: 1302 if self.post_layernorm_residual: 1303 output = self.add(output_x, mlp_logit) 1304 else: 1305 output = self.add(x, mlp_logit) 1306 output = F.reshape(output, x_shape) 1307 1308 if self.use_moe is True: 1309 return output, layer_present, aux_loss 1310 return output, layer_present 1311 1312 def _check_input(self, x, input_mask, init_reset, batch_valid_length): 1313 r"""Check inputs""" 1314 if not self.use_past or (self.use_past and self.is_first_iteration): 1315 _check_shape_equal(F.shape(x), "x", self.cls_name, 1316 [[self.batch_size, self.seq_length, self.hidden_size], 1317 [self.batch_size * self.seq_length, self.hidden_size]]) 1318 _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name, 1319 [self.batch_size, self.seq_length, self.seq_length]) 1320 else: 1321 _check_shape_equal(F.shape(x), "x", self.cls_name, [self.batch_size, 1, self.hidden_size]) 1322 _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name, 1323 [self.batch_size, 1, self.seq_length]) 1324 _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name) 1325 _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name) 1326 1327 init_reset_is_tensor = isinstance(init_reset, Tensor) 1328 init_reset_is_default = init_reset is True 1329 batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor) 1330 batch_is_default = batch_valid_length is None 1331 _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor, 1332 init_reset_is_default) 1333 _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None, 1334 batch_valid_length_is_tensor, batch_is_default) 1335 1336 if self.use_past: 1337 _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1]) 1338 _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name) 1339 _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size]) 1340 _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name) 1341 return True 1342 1343 1344class TransformerDecoderLayer(Cell): 1345 r""" 1346 Transformer Decoder Layer. This is an implementation of the single layer of the transformer 1347 decoder layer, including self-attention, cross attention and feedward layer. When the encoder_output is None, 1348 the cross attention will not be effective. 1349 1350 Args: 1351 batch_size(int): The batch size of the input tensor. 1352 hidden_size(int): The hidden size of the input. 1353 src_seq_length(int): The input source sequence length. 1354 tgt_seq_length(int): The input target sequence length. 1355 ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer. 1356 num_heads(int): The number of the heads. 1357 hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1. 1358 attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1. 1359 post_layernorm_residual(bool): Do residuals adds before the layernorm. Default False. 1360 hidden_act(str): The activation of the internal feedforward layer. Supports 'relu', 1361 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 1362 'hsigmoid', 'logsigmoid' and so on. Default: gelu. 1363 layernorm_compute_type(dtype.Number): The computation type of the layernorm. 1364 Should be dtype.float32 or dtype.float16. Default dtype.float32. 1365 softmax_compute_type(dtype.Number): The computation type of the softmax in the attention. 1366 Should be dtype.float32 or dtype.float16. Default mstype.float32. 1367 param_init_type(dtype.Number): The parameter initialization type of the module. 1368 Should be dtype.float32 or dtype.float16. Default dtype.float32. 1369 use_past(bool): Use the past state to compute, used for incremental prediction. Default False. 1370 moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). 1371 parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`, 1372 an instance of `OpParallelConfig` with default args. 1373 1374 Inputs: 1375 - **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size] or 1376 [batch_size * tgt_seq_length, hidden_size]. 1377 - **decoder_mask** (Tensor) - the attention mask for decoder with shape [batch_size, src_seq_length, 1378 seq_length]. 1379 - **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or 1380 [batch_size * seq_length, hidden_size]. Note this args can not be passed by None when the net is in outermost 1381 layer. Default None. 1382 - **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length, 1383 src_seq_length] where tgt_seq_length is the length of the decoder. Note this args can not be passed by 1384 None when the net is in outermost layer. Default None. 1385 - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and 1386 past value parameter used in the incremental prediction. Only valid when use_past is True. Default True. 1387 - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. Used 1388 for incremental prediction when the use_past is True. Default None. 1389 1390 Outputs: 1391 Tuple, a tuple contains(`output`, `layer_present`) 1392 1393 - **output** (Tensor) - the output logit of this layer. The shape is [batch, seq_length, hidden_size] or 1394 [batch * seq_length, hidden_size]. 1395 - **layer_present** (Tensor) - A tuple, where each tuple is the tensor of the projected key and value 1396 vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length), 1397 (batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector 1398 in cross attention with shape (batch_size, num_heads, size_per_head, src_seq_length), 1399 (batch_size, num_heads, src_seq_length, size_per_head)). 1400 1401 Supported Platforms: 1402 ``Ascend`` ``GPU`` 1403 1404 Examples: 1405 >>> import numpy as np 1406 >>> from mindspore import dtype as mstype 1407 >>> from mindspore.parallel.nn import TransformerDecoderLayer 1408 >>> from mindspore import Tensor 1409 >>> model = TransformerDecoderLayer(batch_size=2, hidden_size=64, ffn_hidden_size=64, num_heads=2, 1410 ... src_seq_length=20, tgt_seq_length=10) 1411 >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32) 1412 >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32) 1413 >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16) 1414 >>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16) 1415 >>> output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask) 1416 >>> print(output.shape) 1417 (2, 10, 64) 1418 >>> print(past[0].shape) 1419 (2, 2, 32, 10) 1420 >>> print(past[1].shape) 1421 (2, 2, 10, 32) 1422 >>> print(past[2].shape) 1423 (2, 2, 32, 20) 1424 >>> print(past[3].shape) 1425 (2, 2, 20, 32) 1426 """ 1427 1428 @_args_type_validator_check(batch_size=Validator.check_positive_int, 1429 hidden_size=Validator.check_positive_int, 1430 num_heads=Validator.check_positive_int, 1431 ffn_hidden_size=Validator.check_positive_int, 1432 src_seq_length=Validator.check_positive_int, 1433 tgt_seq_length=Validator.check_positive_int, 1434 attention_dropout_rate=Validator.check_non_negative_float, 1435 hidden_dropout_rate=Validator.check_non_negative_float, 1436 hidden_act=_valid_type_checks([str], "TransformerDecoderLayer"), 1437 post_layernorm_residual=Validator.check_bool, 1438 layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 1439 "TransformerDecoderLayer"), 1440 softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 1441 "TransformerDecoderLayer"), 1442 param_init_type=_valid_value_checks([mstype.float32, mstype.float16], 1443 "TransformerDecoderLayer"), 1444 parallel_config=_valid_type_checks([OpParallelConfig], 1445 "TransformerDecoderLayer"), 1446 use_past=Validator.check_bool) 1447 def __init__(self, hidden_size, 1448 ffn_hidden_size, 1449 num_heads, 1450 batch_size, 1451 src_seq_length, 1452 tgt_seq_length, 1453 attention_dropout_rate=0.1, 1454 hidden_dropout_rate=0.1, 1455 post_layernorm_residual=False, 1456 use_past=False, 1457 layernorm_compute_type=mstype.float32, 1458 softmax_compute_type=mstype.float32, 1459 param_init_type=mstype.float32, 1460 hidden_act='gelu', 1461 moe_config=default_moe_config, 1462 parallel_config=default_dpmp_config): 1463 super(TransformerDecoderLayer, self).__init__() 1464 _check_config(parallel_config) 1465 if num_heads % parallel_config.model_parallel != 0: 1466 raise ValueError( 1467 f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel}, " 1468 f"but found {num_heads}") 1469 if hidden_size % parallel_config.model_parallel != 0: 1470 raise ValueError( 1471 f"hidden_size must be divisibled by the model parallel way {parallel_config.model_parallel}, " 1472 f"but found {hidden_size}") 1473 if ffn_hidden_size % parallel_config.model_parallel != 0: 1474 raise ValueError( 1475 f"ffn_hidden_size must be divisibled by the model parallel way {parallel_config.model_parallel}, " 1476 f"but found {ffn_hidden_size}") 1477 if use_past is True: 1478 raise ValueError(f"The {self.cls_name} does not support use_past=True.") 1479 self.batch_size = batch_size 1480 self.use_past = use_past 1481 self.softmax_compute_type = softmax_compute_type 1482 1483 self.src_seq_length = src_seq_length 1484 self.tgt_seq_length = tgt_seq_length 1485 self.use_past = use_past 1486 self.hidden_size = hidden_size 1487 1488 self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type) 1489 self.layernorm1.shard(((parallel_config.data_parallel, 1),)) 1490 self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type) 1491 self.layernorm2.shard(((parallel_config.data_parallel, 1),)) 1492 1493 self.attention = MultiHeadAttention(hidden_size=hidden_size, 1494 num_heads=num_heads, 1495 batch_size=batch_size, 1496 src_seq_length=tgt_seq_length, 1497 tgt_seq_length=tgt_seq_length, 1498 hidden_dropout_rate=hidden_dropout_rate, 1499 attention_dropout_rate=attention_dropout_rate, 1500 use_past=use_past, 1501 softmax_compute_type=softmax_compute_type, 1502 param_init_type=param_init_type, 1503 parallel_config=parallel_config) 1504 # Cross attention with the output of encoder as memory tensor 1505 self.cross_attention = MultiHeadAttention(hidden_size=hidden_size, 1506 num_heads=num_heads, 1507 batch_size=batch_size, 1508 src_seq_length=tgt_seq_length, 1509 tgt_seq_length=src_seq_length, 1510 hidden_dropout_rate=hidden_dropout_rate, 1511 attention_dropout_rate=attention_dropout_rate, 1512 softmax_compute_type=softmax_compute_type, 1513 use_past=use_past, 1514 param_init_type=param_init_type, 1515 parallel_config=parallel_config) 1516 self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float( 1517 layernorm_compute_type) 1518 self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),)) 1519 self.use_moe = (moe_config.expert_num > 1) 1520 if self.use_moe is True: 1521 self.output = MoE(hidden_size=hidden_size, 1522 dropout_rate=hidden_dropout_rate, 1523 ffn_hidden_size=ffn_hidden_size, 1524 param_init_type=param_init_type, 1525 hidden_act=hidden_act, 1526 moe_config=moe_config, 1527 parallel_config=parallel_config) 1528 else: 1529 # Feed Forward Network, FFN 1530 self.output = FeedForward(hidden_size=hidden_size, 1531 dropout_rate=hidden_dropout_rate, 1532 ffn_hidden_size=ffn_hidden_size, 1533 hidden_act=hidden_act, 1534 param_init_type=param_init_type, 1535 parallel_config=parallel_config) 1536 self.post_layernorm_residual = post_layernorm_residual 1537 self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1))) 1538 self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) 1539 self.dtype = mstype.float16 1540 self.key_past = None 1541 self.value_past = None 1542 if self.use_past: 1543 # operator used for state reuse 1544 self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),)) 1545 self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ())) 1546 self.slice = P.StridedSlice().shard(((1, 1, 1, 1),)) 1547 size_per_head = int(hidden_size / num_heads) 1548 self.key_shape = (batch_size, num_heads, size_per_head, tgt_seq_length) 1549 self.value_shape = (batch_size, num_heads, tgt_seq_length, size_per_head) 1550 # parameters saving key and value states 1551 self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past") 1552 self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past") 1553 self.tile = P.Tile().shard(((1, 1),)) 1554 self.mul = P.Mul().shard(((1, 1, 1, 1), (1,))) 1555 self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1))) 1556 1557 def construct(self, hidden_stats, 1558 decoder_mask, 1559 encoder_output=None, 1560 memory_mask=None, 1561 init_reset=True, batch_valid_length=None): 1562 self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length) 1563 # the returned shape is [bs, seq_length, embedding_size] or [bs * seq_length, embedding_size] 1564 hidden_shape = F.shape(hidden_stats) 1565 hidden_stats = F.reshape(hidden_stats, (-1, hidden_shape[-1])) 1566 input_x = self.layernorm1(hidden_stats) 1567 input_x = F.cast(input_x, self.dtype) 1568 1569 # indicate whether reset saved states 1570 key_reset = None 1571 value_reset = None 1572 if self.use_past: 1573 # reset states, init_reset True for reuse and False for reset 1574 key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype))) 1575 value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype))) 1576 # add dependency for desired execution order 1577 input_x = F.depend(input_x, key_reset) 1578 input_x = F.depend(input_x, value_reset) 1579 1580 attention, layer_present = self.attention(input_x, input_x, input_x, decoder_mask, self.key_past, 1581 self.value_past, batch_valid_length) 1582 # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm 1583 if self.post_layernorm_residual: 1584 x = self.add(input_x, attention) 1585 # For pre-layernorm the inputs for residual path are output of self-attention and input of this layer 1586 else: 1587 x = self.add(hidden_stats, attention) 1588 1589 middle_output = None 1590 if encoder_output is not None: 1591 middle_output = self.cross_attention_layernorm(x) 1592 middle_output = F.cast(middle_output, self.dtype) 1593 cross_attn_output, cross_layer_present = self.cross_attention(middle_output, encoder_output, 1594 encoder_output, 1595 memory_mask, self.key_past, 1596 self.value_past, batch_valid_length) 1597 layer_present += cross_layer_present 1598 if self.post_layernorm_residual: 1599 x = self.add(middle_output, cross_attn_output) 1600 else: 1601 x = self.add(x, cross_attn_output) 1602 1603 output_x = self.layernorm2(x) 1604 output_x = F.cast(output_x, self.dtype) 1605 aux_loss = None 1606 if self.use_moe is True: 1607 mlp_logit, aux_loss = self.output(output_x) 1608 else: 1609 mlp_logit = self.output(output_x) 1610 1611 value_update = None 1612 key_update = None 1613 if self.use_past: 1614 # current key and value 1615 key_present, value_present = layer_present 1616 # update key and value calculated this step 1617 key_update = self.assign(self.key_past, key_present) 1618 value_update = self.assign(self.value_past, value_present) 1619 # add dependency for desired execution order 1620 key_update = F.depend(key_update, key_reset) 1621 value_update = F.depend(value_update, value_reset) 1622 1623 # add dependency for desired execution order 1624 mlp_logit = F.depend(mlp_logit, value_update) 1625 mlp_logit = F.depend(mlp_logit, key_update) 1626 1627 # if shape is 3d, we reshape the inputs of the add 1628 if len(hidden_shape) == 3: 1629 output_x = P.Reshape()(output_x, hidden_shape) 1630 mlp_logit = P.Reshape()(mlp_logit, hidden_shape) 1631 x = P.Reshape()(x, hidden_shape) 1632 1633 if self.post_layernorm_residual: 1634 output = self.add_3d(output_x, mlp_logit) 1635 else: 1636 output = self.add_3d(x, mlp_logit) 1637 else: 1638 if self.post_layernorm_residual: 1639 output = self.add(output_x, mlp_logit) 1640 else: 1641 output = self.add(x, mlp_logit) 1642 output = F.reshape(output, hidden_shape) 1643 1644 if self.use_moe is True: 1645 return output, layer_present, aux_loss 1646 return output, layer_present 1647 1648 def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length): 1649 r"""Check inputs""" 1650 if not self.use_past or (self.use_past and self.is_first_iteration): 1651 _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name, 1652 [[self.batch_size, self.tgt_seq_length, self.hidden_size], 1653 [self.batch_size * self.tgt_seq_length, self.hidden_size]]) 1654 _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, 1655 [self.batch_size, self.tgt_seq_length, self.tgt_seq_length]) 1656 1657 else: 1658 _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name, 1659 [self.batch_size, 1, self.hidden_size]) 1660 _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, 1661 [self.batch_size, 1, self.tgt_seq_length]) 1662 _check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name) 1663 _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name) 1664 if encoder_output is not None: 1665 _check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name, 1666 [[self.batch_size, self.src_seq_length, self.hidden_size], 1667 [self.batch_size * self.src_seq_length, self.hidden_size]]) 1668 _check_input_dtype(F.dtype(encoder_output), "encoder_output", 1669 [mstype.float32, mstype.float16], self.cls_name) 1670 if memory_mask is not None: 1671 _check_shape_equal(F.shape(memory_mask), "memory_mask", self.cls_name, 1672 [self.batch_size, self.tgt_seq_length, self.src_seq_length]) 1673 _check_input_dtype(F.dtype(memory_mask), "memory_mask", 1674 [mstype.float32, mstype.float16], self.cls_name) 1675 1676 init_reset_is_tensor = isinstance(init_reset, Tensor) 1677 init_reset_is_default = init_reset is True 1678 batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor) 1679 batch_is_default = batch_valid_length is None 1680 _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor, 1681 init_reset_is_default) 1682 _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None, 1683 batch_valid_length_is_tensor, batch_is_default) 1684 1685 if self.use_past: 1686 _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1]) 1687 _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name) 1688 _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size]) 1689 _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name) 1690 return True 1691 1692 1693def _get_lambda_func(total_layer=None): 1694 r""" 1695 A wrapper function of specifying pipeline stage and gradient aggregation fusion. If the total layer 1696 is not None, for example, set in the transformer model, the pipeline stage setting function will be 1697 `(layer_id + 0) // (total_layers / parallel_config.pipeline_stage)` for the encoder and, 1698 `(layer_id + offset) // 1699 (total_layers / parallel_config.pipeline_stage)` for the decoder, where `offset` is the layers in the encoder. 1700 """ 1701 1702 def _set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers): 1703 r""" 1704 Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`. 1705 1706 Args: 1707 network(Cell) - Represents the transformer block 1708 layer_id(int) - Means the layer index for the current module, counts from zero. 1709 offset(int) - Means the layer_index needs an offset, if there are other modules in the net. 1710 layers(int) - The total layers used for the model. 1711 """ 1712 # override the layers 1713 if total_layer: 1714 layers = total_layer 1715 # Used for the pipeline's stages setting 1716 if layers < parallel_config.pipeline_stage: 1717 raise ValueError(f"layers {layers} must be larger than pipeline stage {parallel_config.pipeline_stage}") 1718 1719 pp_dis = max(int(layers / parallel_config.pipeline_stage), 1) 1720 # the pipeline stage must be in [0, parallel_config.pipeline_stage - 1] 1721 pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1) 1722 network.pipeline_stage = pp_id 1723 logger.info(f"pipeline stage id is {pp_id}") 1724 1725 # Used for optimizer's fusion tag 1726 dis = max(int(layers / parallel_config.gradient_aggregation_group), 1) 1727 network.set_comm_fusion(int((layer_id + offset) / dis) + 1) 1728 # Used for enabling recomputation of the block 1729 if parallel_config.recompute: 1730 network.recompute() 1731 1732 return _set_parallel_configure_for_layer 1733 1734 1735class TransformerEncoder(Cell): 1736 r""" 1737 Transformer Encoder module with multi-layer stacked of `TransformerEncoderLayer`, including multihead self 1738 attention and feedforward layer. 1739 1740 Args: 1741 batch_size(int): The batch size of the input tensor. 1742 num_layers(int): The layers of the `TransformerEncoderLayer` 1743 hidden_size(int): The hidden size of the input. 1744 ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer. 1745 seq_length(int): The seq_length of the input tensor. 1746 num_heads(int): The number of the heads. 1747 hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1 1748 attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1 1749 post_layernorm_residual(bool): Do residuals adds before the layernorm. Default False. 1750 hidden_act(str): The activation of the internal feedforward layer. Supports 'relu', 1751 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 1752 'hsigmoid', 'logsigmoid' and so on. Default: gelu. 1753 layernorm_compute_type(dtype.Number): The computation type of the layernorm. 1754 Should be dtype.float32 or dtype.float16. Default dtype.float32. 1755 softmax_compute_type(dtype.Number): The computation type of the softmax in the attention. 1756 Should be dtype.float32 or dtype.float16. Default mstype.float32. 1757 param_init_type(dtype.Number): The parameter initialization type of the module. 1758 Should be dtype.float32 or dtype.float16. Default dtype.float32. 1759 use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two 1760 words and want to generate the ten more words. We just need to compute the two words's state only once, 1761 and generate the next word one by one. When use_past is True, there are two steps to run the prediction. 1762 The first step, set the is_first_iteration to be True by 1763 `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the 1764 is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment, 1765 pass the single step's input tensor, and loop it. Default False. 1766 lambda_func: A function can determine the fusion index, pipeline stages and recompute attribute. If the user 1767 wants to determine the pipeline stage and gradient aggregation fusion, the user can pass a function 1768 that accepts `network`, `layer_id`, `offset`, `parallel_config`, `layers`. The `network(Cell)` 1769 represents the transformer block, `layer_id(int)` means the layer index for the current module, counts from 1770 zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net. The 1771 default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`. 1772 offset(int): The initial layer index for the `decoder`. Used for setting the fusion id and stage id, to not 1773 overlap with the encoder layer. 1774 moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). 1775 parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`, 1776 an instance of `TransformerOpParallelConfig` with default args. 1777 1778 Inputs: 1779 - **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or 1780 [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise, 1781 should be [batch_size, 1, hidden_size]. 1782 - **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length] 1783 - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and 1784 past value parameter used in the incremental prediction. Only valid when use_past is True. Default True 1785 - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. Used 1786 for incremental prediction when the use_past is True. Default None. 1787 1788 Outputs: 1789 Tuple, a tuple contains(`output`, `layer_present`) 1790 1791 - **output** (Tensor) - The float tensor of the output of the layer with 1792 shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is 1793 False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size). 1794 - **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple contains the Tensor the 1795 projected key and value vector with shape ((batch_size, num_heads, size_per_head, seq_length), 1796 and (batch_size, num_heads, seq_length, size_per_head)). 1797 1798 Supported Platforms: 1799 ``Ascend`` ``GPU`` 1800 1801 Examples: 1802 >>> import numpy as np 1803 >>> from mindspore import dtype as mstype 1804 >>> from mindspore.parallel.nn import TransformerEncoder 1805 >>> from mindspore import Tensor 1806 >>> model = TransformerEncoder(batch_size=2, num_layers=2, hidden_size=8, ffn_hidden_size=64, seq_length=16, 1807 ... num_heads=2) 1808 >>> encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32) 1809 >>> encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16) 1810 >>> output, past = model(encoder_input_value, encoder_input_mask) 1811 >>> print(output.shape) 1812 (2, 16, 8) 1813 >>> print(len(past)) 1814 2 1815 >>> print(past[0][0].shape) 1816 (2, 2, 4, 16) 1817 >>> print(past[0][1].shape) 1818 (2, 2, 16, 4) 1819 # When use use_past=True, it includes two steps to implement the incremental prediction. 1820 # Step 1: set is_first_iteration=True, and input the full sequence length's state. 1821 >>> batch_valid_length = Tensor(np.ones((2,)), mstype.int32) 1822 >>> init_reset = Tensor([True], mstype.bool_) 1823 # Set is_first_iteration=True to generate the full memory states 1824 >>> model = TransformerEncoder(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16, 1825 ... num_heads=2, num_layers=2, use_past=True) 1826 >>> model.add_flags_recursive(is_first_iteration=True) 1827 >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length) 1828 >>> print(hidden.shape) 1829 (2, 16, 8) 1830 >>> print(past[0].shape) 1831 (2, 2, 4, 16) 1832 >>> print(past[1].shape) 1833 (2, 2, 16, 4) 1834 >>> encoder_input_value = Tensor(np.ones((2, 1, 8)), mstype.float32) 1835 >>> encoder_input_mask = Tensor(np.ones((2, 1, 16)), mstype.float16) 1836 >>> init_reset = Tensor([False], mstype.bool_) 1837 # Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than the full 1838 # sequence. 1839 >>> model.add_flags_recursive(is_first_iteration=False) 1840 >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length) 1841 >>> print(hidden.shape) 1842 (2, 1, 8) 1843 >>> print(past[0].shape) 1844 (2, 2, 4, 16) 1845 >>> print(past[1].shape) 1846 (2, 2, 16, 4) 1847 """ 1848 1849 @_args_type_validator_check(batch_size=Validator.check_positive_int, 1850 hidden_size=Validator.check_positive_int, 1851 num_heads=Validator.check_positive_int, 1852 ffn_hidden_size=Validator.check_positive_int, 1853 seq_length=Validator.check_positive_int, 1854 num_layers=Validator.check_positive_int, 1855 offset=Validator.check_non_negative_int, 1856 attention_dropout_rate=Validator.check_non_negative_float, 1857 hidden_dropout_rate=Validator.check_non_negative_float, 1858 hidden_act=_valid_type_checks([str], "TransformerEncoder"), 1859 post_layernorm_residual=Validator.check_bool, 1860 layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 1861 "TransformerEncoder"), 1862 softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 1863 "TransformerEncoder"), 1864 param_init_type=_valid_value_checks([mstype.float32, mstype.float16], 1865 "TransformerEncoder"), 1866 parallel_config=_valid_type_checks([TransformerOpParallelConfig], 1867 "TransformerEncoder"), 1868 use_past=Validator.check_bool) 1869 def __init__(self, 1870 batch_size, 1871 num_layers, 1872 hidden_size, 1873 ffn_hidden_size, 1874 seq_length, 1875 num_heads, 1876 attention_dropout_rate=0.1, 1877 hidden_dropout_rate=0.1, 1878 hidden_act='gelu', 1879 post_layernorm_residual=False, 1880 layernorm_compute_type=mstype.float32, 1881 softmax_compute_type=mstype.float32, 1882 param_init_type=mstype.float32, 1883 lambda_func=None, 1884 offset=0, 1885 use_past=False, 1886 moe_config=default_moe_config, 1887 parallel_config=default_transformer_config): 1888 super(TransformerEncoder, self).__init__() 1889 _check_config(parallel_config) 1890 1891 self.use_moe = (moe_config.expert_num > 1) 1892 self.add = P.Add().shard(((), ())) 1893 self.aux_loss = Tensor(0.0, mstype.float32) 1894 if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,): 1895 raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.") 1896 self.num_layers = num_layers 1897 self.blocks = nn.CellList() 1898 for i in range(num_layers): 1899 block = TransformerEncoderLayer(hidden_size=hidden_size, 1900 batch_size=batch_size, 1901 ffn_hidden_size=ffn_hidden_size, 1902 seq_length=seq_length, 1903 attention_dropout_rate=attention_dropout_rate, 1904 hidden_dropout_rate=hidden_dropout_rate, 1905 layernorm_compute_type=layernorm_compute_type, 1906 softmax_compute_type=softmax_compute_type, 1907 num_heads=num_heads, 1908 hidden_act=hidden_act, 1909 post_layernorm_residual=post_layernorm_residual, 1910 param_init_type=param_init_type, 1911 use_past=use_past, 1912 moe_config=moe_config, 1913 parallel_config=parallel_config.dp_mp_config) 1914 # If the user doesn't pass the fusion function, use the default one 1915 if not lambda_func: 1916 lambda_func = _get_lambda_func() 1917 1918 lambda_func(block, layer_id=i, layers=num_layers, 1919 offset=offset, parallel_config=parallel_config) 1920 self.blocks.append(block) 1921 1922 def construct(self, hidden_states, attention_mask, init_reset=True, batch_valid_length=None): 1923 present_layer = () 1924 if self.use_moe is True: 1925 accum_loss = self.aux_loss 1926 for i in range(self.num_layers): 1927 hidden_states, present, aux_loss = self.blocks[i](hidden_states, 1928 attention_mask, 1929 init_reset, 1930 batch_valid_length) 1931 present_layer = present_layer + (present,) 1932 accum_loss = self.add(accum_loss, aux_loss) 1933 return hidden_states, present_layer, accum_loss 1934 1935 for i in range(self.num_layers): 1936 hidden_states, present = self.blocks[i](hidden_states, 1937 attention_mask, 1938 init_reset, 1939 batch_valid_length) 1940 present_layer = present_layer + (present,) 1941 1942 return hidden_states, present_layer 1943 1944 1945class TransformerDecoder(Cell): 1946 r""" 1947 Transformer Decoder module with multi-layer stacked of `TransformerDecoderLayer`, including multihead self 1948 attention, cross attention and feedforward layer. 1949 1950 Args: 1951 batch_size(int): The batch size of the input tensor. 1952 num_layers(int): The layers of the `TransformerDecoderLayer`. 1953 hidden_size(int): The hidden size of the input. 1954 ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer. 1955 src_seq_length(int): The input source sequence length. 1956 tgt_seq_length(int): The input target sequence length. 1957 num_heads(int): The number of the heads. 1958 hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1. 1959 attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1. 1960 post_layernorm_residual(bool): Do residuals adds before the layernorm. Default False. 1961 hidden_act(str): The activation of the internal feedforward layer. Supports 'relu', 1962 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 1963 'hsigmoid', 'logsigmoid' and so on. Default: gelu. 1964 layernorm_compute_type(dtype.Number): The computation type of the layernorm. 1965 Should be dtype.float32 or dtype.float16. Default dtype.float32. 1966 softmax_compute_type(dtype.Number): The computation type of the softmax in the attention. 1967 Should be dtype.float32 or dtype.float16. Default mstype.float32. 1968 param_init_type(dtype.Number): The parameter initialization type of the module. 1969 Should be dtype.float32 or dtype.float16. Default dtype.float32. 1970 offset(int): The initial layer index for the `decoder`. Used for setting the fusion id and stage id, to not 1971 overlap with the encoder layer. 1972 lambda_func: A function can determine the fusion index, pipeline stages and recompute attribute. If the user 1973 wants to determine the pipeline stage and gradient aggregation fusion, the user can pass a function 1974 that accepts `network`, `layer_id`, `offset`, `parallel_config`, `layers`. The `network(Cell)` 1975 represents the transformer block, `layer_id(int)` means the layer index for the current module, counts from 1976 zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net. The 1977 default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`. 1978 Default: None 1979 moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). 1980 parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`, 1981 an instance of `TransformerOpParallelConfig` with default args. 1982 1983 Inputs: 1984 - **hidden_stats** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or 1985 [batch_size * seq_length, hidden_size] 1986 - **attention_mask** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length] 1987 - **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or 1988 [batch_size * seq_length, hidden_size]. Note this args can not be passed by None when the net is in outermost 1989 layer. Default None. 1990 - **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length, 1991 src_seq_length] where tgt_seq_length is the length of the decoder. Note this args can not be passed by 1992 None when the net is in outermost layer. Default None. 1993 - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and 1994 past value parameter used in the incremental prediction. Only valid when use_past is True. Default True 1995 - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. 1996 Used for incremental prediction when the use_past is True. Default None. 1997 1998 Outputs: 1999 Tuple, a tuple contains(`output`, `layer_present`) 2000 2001 - **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size] or 2002 [batch * tgt_seq_length, hidden_size] 2003 - **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the projected 2004 key and value vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length), 2005 (batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector 2006 in cross attention with shape (batch_size, num_heads, size_per_head, src_seq_length), 2007 (batch_size, num_heads, src_seq_length, size_per_head)). 2008 2009 Supported Platforms: 2010 ``Ascend`` ``GPU`` 2011 2012 Examples: 2013 >>> import numpy as np 2014 >>> from mindspore import dtype as mstype 2015 >>> from mindspore.parallel.nn import TransformerDecoder 2016 >>> from mindspore import Tensor 2017 >>> model = TransformerDecoder(batch_size=2, num_layers=1, hidden_size=64, ffn_hidden_size=64, 2018 ... num_heads=2, src_seq_length=20, tgt_seq_length=10) 2019 >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32) 2020 >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32) 2021 >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16) 2022 >>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16) 2023 >>> output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask) 2024 >>> print(output.shape) 2025 (2, 10, 64) 2026 >>> print(len(past)) 2027 1 2028 >>> print(past[0][0].shape) 2029 (2, 2, 32, 10) 2030 >>> print(past[0][1].shape) 2031 (2, 2, 10, 32) 2032 >>> print(past[0][2].shape) 2033 (2, 2, 32, 20) 2034 >>> print(past[0][3].shape) 2035 (2, 2, 20, 32) 2036 2037 """ 2038 2039 @_args_type_validator_check(batch_size=Validator.check_positive_int, 2040 hidden_size=Validator.check_positive_int, 2041 num_heads=Validator.check_positive_int, 2042 ffn_hidden_size=Validator.check_positive_int, 2043 src_seq_length=Validator.check_positive_int, 2044 num_layers=Validator.check_positive_int, 2045 tgt_seq_length=Validator.check_positive_int, 2046 offset=Validator.check_non_negative_int, 2047 attention_dropout_rate=Validator.check_non_negative_float, 2048 hidden_dropout_rate=Validator.check_non_negative_float, 2049 hidden_act=_valid_type_checks([str], "TransformerDecoder"), 2050 post_layernorm_residual=Validator.check_bool, 2051 layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 2052 "TransformerDecoder"), 2053 softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 2054 "TransformerDecoder"), 2055 param_init_type=_valid_value_checks([mstype.float32, mstype.float16], 2056 "TransformerDecoder"), 2057 parallel_config=_valid_type_checks([TransformerOpParallelConfig], 2058 "TransformerDecoder"), 2059 use_past=Validator.check_bool) 2060 def __init__(self, 2061 num_layers, 2062 batch_size, 2063 hidden_size, 2064 ffn_hidden_size, 2065 src_seq_length, 2066 tgt_seq_length, 2067 num_heads, 2068 attention_dropout_rate=0.1, 2069 hidden_dropout_rate=0.1, 2070 post_layernorm_residual=False, 2071 layernorm_compute_type=mstype.float32, 2072 softmax_compute_type=mstype.float32, 2073 param_init_type=mstype.float32, 2074 hidden_act='gelu', 2075 lambda_func=None, 2076 use_past=False, 2077 offset=0, 2078 moe_config=default_moe_config, 2079 parallel_config=default_transformer_config): 2080 super(TransformerDecoder, self).__init__() 2081 _check_config(parallel_config) 2082 2083 self.add = P.Add().shard(((), ())) 2084 self.aux_loss = Tensor(0.0, mstype.float32) 2085 if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,): 2086 raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.") 2087 self.num_layers = num_layers 2088 self.blocks = nn.CellList() 2089 self.use_moe = (moe_config.expert_num > 1) 2090 for i in range(num_layers): 2091 block = TransformerDecoderLayer(hidden_size=hidden_size, 2092 batch_size=batch_size, 2093 ffn_hidden_size=ffn_hidden_size, 2094 src_seq_length=src_seq_length, 2095 tgt_seq_length=tgt_seq_length, 2096 attention_dropout_rate=attention_dropout_rate, 2097 hidden_dropout_rate=hidden_dropout_rate, 2098 num_heads=num_heads, 2099 layernorm_compute_type=layernorm_compute_type, 2100 softmax_compute_type=softmax_compute_type, 2101 hidden_act=hidden_act, 2102 use_past=use_past, 2103 param_init_type=param_init_type, 2104 post_layernorm_residual=post_layernorm_residual, 2105 moe_config=moe_config, 2106 parallel_config=parallel_config.dp_mp_config) 2107 # If the user doesn't pass the fusion function, use the default one 2108 if not lambda_func: 2109 lambda_func = _get_lambda_func() 2110 2111 lambda_func(block, layer_id=i, layers=num_layers, 2112 offset=offset, parallel_config=parallel_config) 2113 2114 self.blocks.append(block) 2115 2116 def construct(self, hidden_states, attention_mask, encoder_output=None, memory_mask=None, 2117 init_reset=True, batch_valid_length=None): 2118 present_layer = () 2119 if self.use_moe is True: 2120 accum_loss = self.aux_loss 2121 for i in range(self.num_layers): 2122 hidden_states, present, aux_loss = self.blocks[i](hidden_states, 2123 attention_mask, 2124 encoder_output, 2125 memory_mask, 2126 init_reset, 2127 batch_valid_length) 2128 present_layer = present_layer + (present,) 2129 accum_loss = self.add(accum_loss, aux_loss) 2130 return hidden_states, present_layer, accum_loss 2131 2132 # Loop through each self-attention layer 2133 for i in range(self.num_layers): 2134 hidden_states, present = self.blocks[i](hidden_states, 2135 attention_mask, 2136 encoder_output, 2137 memory_mask, 2138 init_reset, 2139 batch_valid_length) 2140 present_layer = present_layer + (present,) 2141 2142 return hidden_states, present_layer 2143 2144 2145class Transformer(Cell): 2146 r""" 2147 Transformer module including encoder and decoder. The difference with the original implements is the module use 2148 the residual addition before the layer normalization. And the default hidden act is `gelu`. 2149 The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_. 2150 2151 Note: 2152 This is an experimental interface that is subject to change and/or deletion. 2153 2154 Args: 2155 batch_size(int): The batch size of the input tensor. 2156 encoder_layers(int): The layers of the `TransformerEncoderLayer`. 2157 decoder_layers(int): The layers of the `TransformerDecoderLayer`. 2158 hidden_size(int): The hidden size of the input. 2159 ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer. 2160 src_seq_length(int): The seq_length of the encoder's input tensor. 2161 tgt_seq_length(int): The seq_length of the decoder's input tensor. 2162 num_heads(int): The number of the heads. Default: 2. 2163 hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1 2164 attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1 2165 post_layernorm_residual(bool): Do residuals adds before the layernorm. Default False. 2166 layernorm_compute_type(dtype.Number): The computation type of the layernorm. 2167 Should be dtype.float32 or dtype.float16. Default dtype.float32. 2168 softmax_compute_type(dtype.Number): The computation type of the softmax in the attention. 2169 Should be dtype.float32 or dtype.float16. Default mstype.float32. 2170 param_init_type(dtype.Number): The parameter initialization type of the module. 2171 Should be dtype.float32 or dtype.float16. Default dtype.float32. 2172 hidden_act(str): The activation of the internal feedforward layer. Supports 'relu', 2173 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 2174 'hsigmoid', 'logsigmoid' and so on. Default: gelu. 2175 moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). 2176 lambda_func: A function can determine the fusion index, pipeline stages and recompute attribute. If the user 2177 wants to determine the pipeline stage and gradient aggregation fusion, the user can pass a function 2178 that accepts `network`, `layer_id`, `offset`, `parallel_config`, `layers`. The `network(Cell)` 2179 represents the transformer block, `layer_id(int)` means the layer index for the current module, counts from 2180 zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net. The 2181 default setting for the pipeline is: `(layer_id + offset) // ((encoder_layers + decoder_length) 2182 / pipeline_stage)`. 2183 parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`, 2184 an instance of `TransformerOpParallelConfig` with default args. 2185 2186 Inputs: 2187 - **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or 2188 [batch_size * seq_length, hidden_size]. 2189 - **encoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length]. 2190 - **decoder_inputs** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or 2191 [batch_size * seq_length, hidden_size], 2192 this should be none if the decoder layer is 0. 2193 - **decoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length] 2194 - **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length, 2195 src_seq_length] 2196 where tgt_seq_length is the length of the decoder. the output of the encoder with shape [batch_size, 2197 seq_length, hidden_size], this should be none if the decoder layer is 0. 2198 - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and 2199 past value parameter used in the incremental prediction. Only valid when use_past is True. Default True 2200 - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. Used 2201 for incremental prediction when the use_past is True. Default None. 2202 2203 Outputs: 2204 Tuple, a tuple contains(`output`, `encoder_layer_present`, `encoder_layer_present`) 2205 2206 - **output** (Tensor) - If there is only encoder, the output logit of the encoder layer. The shape is 2207 [batch, src_seq_length, hidden_size] or [batch * src_seq_length, hidden_size], if there are encoder and 2208 decoders, the output is from the decoder layer. The shape is [batch, tgt_seq_length, hidden_size] or 2209 [batch * tgt_seq_length, hidden_size]. 2210 - **encoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor the 2211 projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head, 2212 src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head)). 2213 - **decoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor 2214 of the projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head, 2215 tgt_seq_length), (batch_size, num_heads, tgt_seq_length, size_per_head)), and the 2216 projected key and value vector in cross attention with shape 2217 (batch_size, num_heads, size_per_head, src_seq_length), 2218 (batch_size, num_heads, src_seq_length, size_per_head)). If the decoder is not set, the 2219 returned value will be None. 2220 2221 Supported Platforms: 2222 ``Ascend`` ``GPU`` 2223 2224 Examples: 2225 >>> import numpy as np 2226 >>> from mindspore import dtype as mstype 2227 >>> from mindspore.parallel.nn import Transformer 2228 >>> from mindspore import Tensor 2229 >>> model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64, ffn_hidden_size=64, 2230 ... src_seq_length=20, tgt_seq_length=10) 2231 >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32) 2232 >>> encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16) 2233 >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32) 2234 >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16) 2235 >>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16) 2236 >>> output, en_past, de_past = model(encoder_input_value, encoder_input_mask, decoder_input_value, 2237 ... decoder_input_mask, memory_mask) 2238 >>> print(output.shape) 2239 (2, 10, 64) 2240 >>> print(len(en_past)) 2241 1 2242 >>> print(len(de_past)) 2243 2 2244 >>> print(en_past[0][0].shape) 2245 (2, 2, 32, 20) 2246 >>> print(en_past[0][1].shape) 2247 (2, 2, 20, 32) 2248 >>> print(de_past[0][0].shape) 2249 (2, 2, 32, 10) 2250 >>> print(de_past[0][1].shape) 2251 (2, 2, 10, 32) 2252 >>> print(de_past[0][2].shape) 2253 (2, 2, 32, 20) 2254 >>> print(de_past[0][3].shape) 2255 (2, 2, 20, 32) 2256 2257 """ 2258 2259 @_args_type_validator_check(batch_size=Validator.check_positive_int, 2260 hidden_size=Validator.check_positive_int, 2261 num_heads=Validator.check_positive_int, 2262 ffn_hidden_size=Validator.check_positive_int, 2263 src_seq_length=Validator.check_positive_int, 2264 encoder_layers=Validator.check_positive_int, 2265 decoder_layers=Validator.check_non_negative_int, 2266 tgt_seq_length=Validator.check_positive_int, 2267 attention_dropout_rate=Validator.check_non_negative_float, 2268 hidden_dropout_rate=Validator.check_non_negative_float, 2269 hidden_act=_valid_type_checks([str], "Transformer"), 2270 post_layernorm_residual=Validator.check_bool, 2271 layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 2272 "Transformer"), 2273 softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16], 2274 "Transformer"), 2275 param_init_type=_valid_value_checks([mstype.float32, mstype.float16], "Transformer"), 2276 parallel_config=_valid_type_checks([TransformerOpParallelConfig], "Transformer"), 2277 use_past=Validator.check_bool) 2278 def __init__(self, 2279 hidden_size, 2280 batch_size, 2281 ffn_hidden_size, 2282 src_seq_length, 2283 tgt_seq_length, 2284 encoder_layers=3, 2285 decoder_layers=3, 2286 num_heads=2, 2287 attention_dropout_rate=0.1, 2288 hidden_dropout_rate=0.1, 2289 hidden_act='gelu', 2290 post_layernorm_residual=False, 2291 layernorm_compute_type=mstype.float32, 2292 softmax_compute_type=mstype.float32, 2293 param_init_type=mstype.float32, 2294 lambda_func=None, 2295 use_past=False, 2296 moe_config=default_moe_config, 2297 parallel_config=default_transformer_config): 2298 super(Transformer, self).__init__() 2299 _check_config(parallel_config) 2300 self.batch_size = batch_size 2301 self.hidden_size = hidden_size 2302 self.src_seq_length = src_seq_length 2303 self.tgt_seq_length = tgt_seq_length 2304 self.use_past = use_past 2305 if encoder_layers <= 0 < decoder_layers: 2306 raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder" 2307 f"layer {decoder_layers}, please use TransformerDecoder") 2308 if encoder_layers > 0 and decoder_layers > 0 and use_past is True: 2309 raise ValueError(f"The {self.cls_name} with encoder and decoder does not support use_past=True.") 2310 if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,): 2311 raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.") 2312 # The shard setting of Transformer is set within the TransformerEncoderLayer 2313 if not lambda_func: 2314 lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers) 2315 2316 self.use_moe = (moe_config.expert_num > 1) 2317 self.add = P.Add().shard(((), ())) 2318 self.aux_loss = Tensor(0.0, mstype.float32) 2319 if encoder_layers > 0: 2320 self.encoder = TransformerEncoder(num_layers=encoder_layers, 2321 batch_size=batch_size, 2322 hidden_size=hidden_size, 2323 ffn_hidden_size=ffn_hidden_size, 2324 num_heads=num_heads, 2325 seq_length=src_seq_length, 2326 attention_dropout_rate=attention_dropout_rate, 2327 hidden_dropout_rate=hidden_dropout_rate, 2328 hidden_act=hidden_act, 2329 layernorm_compute_type=layernorm_compute_type, 2330 softmax_compute_type=softmax_compute_type, 2331 post_layernorm_residual=post_layernorm_residual, 2332 param_init_type=param_init_type, 2333 lambda_func=lambda_func, 2334 use_past=use_past, 2335 moe_config=moe_config, 2336 parallel_config=parallel_config) 2337 else: 2338 self.encoder = None 2339 2340 # Offset is needed as the encoder has consumed some flags. 2341 # so the decoder need to increase the flags based on the encoder layer 2342 self.decoder = None 2343 if decoder_layers > 0: 2344 self.decoder = TransformerDecoder(num_layers=decoder_layers, 2345 batch_size=batch_size, 2346 hidden_size=hidden_size, 2347 ffn_hidden_size=ffn_hidden_size, 2348 num_heads=num_heads, 2349 src_seq_length=src_seq_length, 2350 tgt_seq_length=tgt_seq_length, 2351 attention_dropout_rate=attention_dropout_rate, 2352 hidden_dropout_rate=hidden_dropout_rate, 2353 hidden_act=hidden_act, 2354 post_layernorm_residual=post_layernorm_residual, 2355 layernorm_compute_type=layernorm_compute_type, 2356 softmax_compute_type=softmax_compute_type, 2357 lambda_func=lambda_func, 2358 use_past=use_past, 2359 param_init_type=param_init_type, 2360 offset=encoder_layers, 2361 moe_config=moe_config, 2362 parallel_config=parallel_config) 2363 2364 def construct(self, encoder_inputs, 2365 encoder_masks, 2366 decoder_inputs=None, 2367 decoder_masks=None, 2368 memory_mask=None, 2369 init_reset=True, 2370 batch_valid_length=None): 2371 2372 encoder_output = None 2373 output = None 2374 encoder_layer_present = None 2375 decoder_layer_present = None 2376 accum_loss = self.aux_loss 2377 if self.encoder is not None: 2378 if self.use_moe is True: 2379 encoder_output, encoder_layer_present, encoder_aux_loss = self.encoder(encoder_inputs, encoder_masks, 2380 init_reset, batch_valid_length) 2381 accum_loss = self.add(accum_loss, encoder_aux_loss) 2382 else: 2383 encoder_output, encoder_layer_present = self.encoder(encoder_inputs, encoder_masks, init_reset, 2384 batch_valid_length) 2385 output = encoder_output 2386 2387 if self.decoder is not None: 2388 # decoder mask should be created outside of the model 2389 if self.use_moe is True: 2390 decoder_output, decoder_layer_present, decoder_aux_loss = self.decoder(decoder_inputs, decoder_masks, 2391 encoder_output, memory_mask, 2392 init_reset, batch_valid_length) 2393 accum_loss = self.add(accum_loss, decoder_aux_loss) 2394 else: 2395 decoder_output, decoder_layer_present = self.decoder(decoder_inputs, 2396 decoder_masks, 2397 encoder_output, 2398 memory_mask, init_reset, 2399 batch_valid_length) 2400 output = decoder_output 2401 if self.use_moe is True: 2402 return output, encoder_layer_present, decoder_layer_present, accum_loss 2403 return output, encoder_layer_present, decoder_layer_present 2404