1# Copyright 2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15""" 16The basic layer of the Transformer Networks. This is an experimental interface that is subject to 17change and/or deletion. 18""" 19from functools import wraps, partial 20import inspect 21import math 22import numpy as np 23from mindspore.common.parameter import Parameter 24from mindspore.common.initializer import initializer, Tensor 25import mindspore.common.dtype as mstype 26from mindspore.ops import operations as P 27from mindspore._extends import cell_attr_register 28from mindspore.nn.cell import Cell 29from mindspore import nn 30from mindspore.nn.layer.activation import get_activation 31from mindspore.ops import functional as F 32from mindspore._checkparam import Validator 33from mindspore.ops.primitive import constexpr, Primitive 34from .op_parallel_config import default_dpmp_config, OpParallelConfig 35 36__all__ = [ 37 "FixedSparseAttention" 38] 39 40 41def _args_type_validator_check(*type_args, **type_kwargs): 42 """Check whether input data type is correct.""" 43 44 def type_check(func): 45 sig = inspect.signature(func) 46 bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments 47 48 @wraps(func) 49 def wrapper(*args, **kwargs): 50 nonlocal bound_types 51 bound_values = sig.bind(*args, **kwargs) 52 53 argument_dict = bound_values.arguments 54 if "kwargs" in bound_types: 55 bound_types = bound_types["kwargs"] 56 if "kwargs" in argument_dict: 57 argument_dict = argument_dict["kwargs"] 58 for name, value in argument_dict.items(): 59 if name in bound_types: 60 bound_types[name](value, name) 61 return func(*args, **kwargs) 62 63 return wrapper 64 65 return type_check 66 67 68def _valid_type_checks(types, class_name): 69 # types should be a list of types, this function check if the type is in the valid dtypes 70 def validator_check_func(value, name): 71 # The args of Validator.check_type_name is (arg_name, arg_type, valid_types, prim_name) 72 # as the input of _args_type_validator_check is fixed, so we need to manually change the input order 73 partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name) 74 return partial_check(name, type(value)) 75 76 return validator_check_func 77 78 79def _valid_value_checks(types, class_name): 80 # the value should be a list of types, this function check if the value is in the valid dtypes 81 def validator_check_func(value, name): 82 # The args of Validator.check_type_name is (arg_name, arg_type, valid_types, prim_name) 83 # as the input of _args_type_validator_check is fixed, so we need to manually change the input order 84 partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name) 85 return partial_check(name, value) 86 87 return validator_check_func 88 89 90class _LayerInputCheck: 91 """ 92 A input check class for the inputs of the transformer model. 93 """ 94 @staticmethod 95 def check_shape_length(input_shape, param_name, func_name, target_len): 96 """ 97 Check the input shape's length is equal to the expected shape 98 :param input_shape(list): a list of the tensor shapes. 99 :param param_name(str): the name of the checked parameter. 100 :param func_name(str): the name of the function. 101 :param target_len: the expected length of the shape. 102 :return: 103 """ 104 if not isinstance(target_len, list): 105 target_len = [target_len] 106 matched = False 107 for item in target_len: 108 if len(input_shape) == item: 109 matched = True 110 if not matched: 111 raise ValueError(f"{func_name} {param_name} shape length should be one of {target_len} dimension, " 112 f"but got shape {input_shape}") 113 return True 114 115 @staticmethod 116 def check_shape_equal(input_shape, param_name, func_name, target_shape): 117 """ 118 Check the input shape's is equal to the expected shape 119 :param input_shape(list): a list of the tensor shapes. 120 :param param_name(str): the name of the checked parameter. 121 :param func_name(str): the name of the function. 122 :param target_shape: the expected shape. 123 :return: 124 """ 125 if not isinstance(target_shape[0], list): 126 target_shape = [target_shape] 127 if isinstance(input_shape, tuple): 128 input_shape = list(input_shape) 129 _LayerInputCheck.check_shape_length(input_shape, param_name, func_name, 130 [len(item) for item in target_shape]) 131 matched = False 132 for item in target_shape: 133 if item == input_shape: 134 matched = True 135 break 136 137 if not matched: 138 raise ValueError(f"{func_name} {param_name} shape should be one of {target_shape}," 139 f"but got {input_shape}") 140 return True 141 142 @staticmethod 143 def check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value): 144 if input_shape[dim] != target_value: 145 raise ValueError(f"{cls_name} {param_name} at {dim} shape should be {target_value}," 146 f"but got {input_shape[dim]}") 147 return True 148 149 150 151@constexpr 152def _check_past_none_input_none(use_past, param_name, func_name, default_value, is_tensor, is_default): 153 """ If the past is True, check whether the inputs is None""" 154 if not use_past: 155 if is_tensor: 156 raise TypeError(f"{func_name} {param_name} should be {default_value}, if use_pat is False, but found " 157 f"a tensor") 158 if not is_default: 159 raise TypeError(f"{func_name} {param_name} should be {default_value}, if use_pat is False.") 160 else: 161 if not is_tensor: 162 raise TypeError(f"{func_name} {param_name} should be tensor, if use_pat is True") 163 return True 164 165 166 167@constexpr 168def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name): 169 Validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name) 170 171 172@constexpr 173def _check_input_shape(input_shape, param_name, func_name, target_len): 174 # check the input length 175 _LayerInputCheck.check_shape_length(input_shape, param_name, func_name, target_len) 176 177 178@constexpr 179def _check_shape_equal(input_shape, param_name, func_name, target_shape): 180 # check the input length 181 _LayerInputCheck.check_shape_equal(input_shape, param_name, func_name, target_shape) 182 183 184@constexpr 185def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_value): 186 _LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value) 187 188 189class _LayerNorm(Cell): 190 r""" 191 A self-defined layer norm operation using reduce sum and reduce mean 192 193 Args: 194 normalized_shape (tuple): The shape of the input tensor 195 eps (float): The epsilon value of the denominator. Default 1e-5. 196 param_init_type: The param init type. 197 Inputs: 198 - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`. 199 200 Outputs: 201 Tensor of shape :math:`(batch, seq_length, hidden_size)`. 202 """ 203 204 def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32): 205 super(_LayerNorm, self).__init__() 206 if param_init_type not in [mstype.float32, mstype.float16]: 207 raise TypeError(f"param type should in [float32, float16], but found type {type(param_init_type)}") 208 if normalized_shape[0] <= 1024: 209 self.layer_norm = P.LayerNorm(begin_norm_axis=-1, 210 begin_params_axis=-1, 211 epsilon=eps) 212 self.is_self_defined = normalized_shape[0] > 1024 213 self.gamma = Parameter(initializer('ones', normalized_shape, param_init_type), name="gamma", 214 parallel_optimizer=False) 215 self.beta = Parameter(initializer('zeros', normalized_shape, param_init_type), name="beta", 216 parallel_optimizer=False) 217 self.mean = P.ReduceMean(keep_dims=True) 218 self.square = P.Square() 219 self.sqrt = P.Sqrt() 220 self.sub1 = P.Sub() 221 self.sub2 = P.Sub() 222 self.add = P.Add() 223 self.eps = eps 224 self.mul = P.Mul() 225 self.add2 = P.Add() 226 self.real_div = P.RealDiv() 227 228 def construct(self, x): 229 r""" 230 x : batch x seq_length x hidden_size 231 """ 232 if self.is_self_defined: 233 mean = self.mean(x, -1) 234 diff = self.sub1(x, mean) 235 variance = self.mean(self.square(diff), -1) 236 variance_eps = self.sqrt(self.add(variance, self.eps)) 237 output = self.real_div(diff, variance_eps) 238 output = self.add2(self.mul(output, self.gamma), self.beta) 239 else: 240 output, _, _ = self.layer_norm(x, self.gamma, self.beta) 241 return output 242 243 def shard(self, strategy): 244 r""" 245 Set the shard for the layer norm. the strategy size should be equal to the inputs. 246 247 Note: 248 It is valid only in semi auto parallel or auto parallel mode. 249 In other parallel modes, strategies set here will be ignored. 250 251 Args: 252 strategy (tuple): The strategy for the dropout. Should be the same shape as the inputs. 253 Examples: 254 >>> import mindspore 255 >>> net = mindspore.parallel.nn.transformer.LayerNorm(normalized_shape=(1024, 10)) 256 >>> net.shard(((10, 2, 1),)) 257 """ 258 if self.is_self_defined: 259 self.mean.shard(strategy) 260 self.square.shard(strategy) 261 self.sqrt.shard(strategy) 262 self.sub1.shard((strategy[0], strategy[0])) 263 self.sub2.shard((strategy[0], strategy[0])) 264 self.add.shard((strategy[0], ())) 265 self.mul.shard((strategy[0], (1,))) 266 self.add2.shard((strategy[0], (1,))) 267 self.real_div.shard((strategy[0], strategy[0])) 268 else: 269 self.layer_norm.shard((strategy[0], (1,), (1,))) 270 271 return self 272 273 274class _Linear(Cell): 275 r""" 276 The dense connected layer. Once the parallel mode is enabled, the input shape should be 277 3-D tensor. 278 279 Applies dense connected layer for the input. This layer implements the operation as: 280 281 .. math:: 282 \text{outputs} = \text{activation}(\text{X} * \text{kernel} + \text{bias}), 283 284 where :math:`X` is the input tensors, :math:`\text{activation}` is the activation function passed as the activation 285 argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same 286 data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector 287 with the same data type as the :math:`X` created by the layer (only if has_bias is True). 288 289 Args: 290 in_channels (int): The number of channels in the input space. 291 out_channels (int): The number of channels in the output space. 292 weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype 293 is same as `x`. The values of str refer to the function `initializer`. Default: 'normal'. 294 bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is 295 same as `x`. The values of str refer to the function `initializer`. Default: 'zeros'. 296 has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. 297 activation (str): activate function applied to the output of the fully connected layer, 298 eg. 'ReLU'.Default: None. 299 expert_num (int): The number of experts used in this Linear. Here, for the case expert_num > 1, BatchMatMul is 300 used and the first dimension in BatchMatMul indicate expert_num. Default: 1. 301 compute_dtype (dtype.Number): The computation type. Default: mstype.float16 302 Inputs: 303 - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal 304 to :math:`in\_channels` in `Inputs`. 305 306 Outputs: 307 Tensor of shape :math:`(*, out\_channels)`. 308 309 Raises: 310 TypeError: If `in_channels` or `out_channels` is not an int. 311 TypeError: If `has_bias` is not a bool. 312 TypeError: If `activation` is not one of str, Cell, Primitive, None. 313 ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init` 314 is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`. 315 ValueError: If length of shape of `bias_init` is not equal to 1 316 or shape[0] of `bias_init` is not equal to `out_channels`. 317 318 Supported Platforms: 319 ``Ascend`` ``GPU`` 320 """ 321 322 @cell_attr_register(attrs=['has_bias', 'in_channels', 'out_channels', 'shard_output', 'activation']) 323 def __init__(self, 324 in_channels, 325 out_channels, 326 weight_init='normal', 327 bias_init='zeros', 328 has_bias=True, 329 activation=None, 330 transpose_b=True, 331 expert_num=1, 332 param_init_type=mstype.float32, 333 compute_dtype=mstype.float16): 334 super(_Linear, self).__init__() 335 self.in_channels = Validator.check_positive_int(in_channels) 336 self.out_channels = Validator.check_positive_int(out_channels) 337 if param_init_type not in [mstype.float32, mstype.float16]: 338 raise TypeError(f"param type should in [float32, float16], but found type {type(param_init_type)}") 339 if activation and not isinstance(activation, str): 340 raise ValueError("Activation can only be str, but found type {}".format(activation)) 341 if isinstance(weight_init, Tensor): 342 if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \ 343 weight_init.shape[1] != in_channels: 344 raise ValueError("Weight init shape error.") 345 if transpose_b: 346 weight_shape = [out_channels, in_channels] 347 else: 348 weight_shape = [in_channels, out_channels] 349 self.expert_num = expert_num 350 if self.expert_num > 1: 351 self.expert_flag = True 352 self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type), 353 name="weight") 354 self.matmul = P.BatchMatMul(transpose_b=transpose_b) 355 else: 356 self.expert_flag = False 357 self.weight = Parameter(initializer(weight_init, weight_shape, param_init_type), name="weight") 358 self.matmul = P.MatMul(transpose_b=transpose_b) 359 self.bias = None 360 self.has_bias = has_bias 361 if self.has_bias: 362 if isinstance(bias_init, Tensor): 363 if bias_init.ndim != 1 or bias_init.shape[0] != out_channels: 364 raise ValueError("Bias init shape error.") 365 self.bias = Parameter(initializer(bias_init, [out_channels], param_init_type), name="bias") 366 self.bias_add = P.Add() 367 self.act_name = activation 368 self.activation = get_activation(activation) if isinstance(activation, str) else activation 369 if activation is not None and not isinstance(self.activation, (Cell, Primitive)): 370 raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) 371 self.activation_flag = self.activation is not None 372 self.dtype = compute_dtype 373 self.cast = P.Cast() 374 375 def construct(self, x): 376 out_shape = P.Shape()(x)[:-1] + (self.out_channels,) 377 x = P.Reshape()(x, (-1, self.in_channels)) 378 if self.expert_flag is True: 379 x = P.Reshape()(x, (self.expert_num, -1, self.in_channels)) 380 weight = self.cast(self.weight, self.dtype) 381 x = self.matmul(x, weight) 382 if self.has_bias: 383 x = self.bias_add(x, self.cast(self.bias, self.dtype)) 384 if self.activation_flag: 385 x = self.activation(x) 386 output = P.Reshape()(x, out_shape) 387 return output 388 389 def shard(self, strategy_matmul, strategy_bias=None, strategy_activation=None): 390 r""" 391 Set the shard for the linear. the strategy size should be equal to the inputs. 392 393 Note: 394 It is valid only in semi auto parallel or auto parallel mode. 395 In other parallel modes, strategies set here will be ignored. 396 397 Args: 398 strategy_matmul (tuple): The strategy for the matmul. Should be the same shape as the inputs. 399 strategy_bias (tuple): The strategy for the bias_add. Should be the same shape as the inputs. 400 strategy_activation (tuple): The strategy for the strategy_activation. Should be the same shape as 401 the inputs. 402 """ 403 self.matmul.shard(strategy_matmul) 404 if self.has_bias: 405 self.bias_add.shard(strategy_bias) 406 if self.activation_flag: 407 # some operations has many primitives, need to manually set the shard 408 if self.act_name.lower() == "leakyrelu": 409 self.activation.select_op.shard((strategy_activation[0], strategy_activation[0])) 410 elif self.act_name.lower() == "logsigmoid": 411 self.activation.mul.shard((strategy_activation[0], ())) 412 self.activation.exp.shard(strategy_activation) 413 self.activation.add.shard((strategy_activation[0], ())) 414 self.activation.rec.shard(strategy_activation) 415 self.activation.log.shard(strategy_activation) 416 elif self.act_name.lower() == "logsoftmax": 417 raise ValueError("logsoftmax is not supported.") 418 else: 419 getattr(self.activation, self.act_name).shard(strategy_activation) 420 421 return self 422 423 424class FixedSparseAttention(nn.Cell): 425 """ 426 Fixed Sparse Attention Layer 427 428 This function contains the sparse attention primitives used in Sparse Transformers (see paper). 429 https://arxiv.org/abs/1904.10509 430 Specifically, it includes the following: 431 1. A faster implementation of normal attention (the upper triangle is not computed, and many operations are fused). 432 2. An implementation of "strided" and "fixed" attention, as in the Sparse Transformers paper. 433 434 Args: 435 batch_size (int): Number of input batch size. 436 num_heads (int): Number of attention heads. 437 block_size (int): An integer determining the block size. Current implementation of sparse self-attention 438 is based on blocked sparse matrices. In which this parameter defines size of such blocks, 439 Block X Block. only supports 64 for now 440 seq_length (int): length of input sequence, only supports 1024 for now 441 num_different_global_patterns (int):An integer determining number of different global attentions layouts. 442 While global attention can be fixed by which block/s are representative of 443 any local window, since there are multi-heads, each head can use a 444 different global representative, only supports 4 for now 445 size_per_head (int): An integer determining embedding size of each attention head, 446 only supports 64, 128 for now 447 448 Inputs: 449 - **q** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of 450 queries to query the context. 451 - **k** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of 452 queries to query the context. 453 - **v** (Tensor) - Tensor value (:class:`mstype.fp16` [batch size, sequence length, Embedding Size]): 454 Sequence of queries to query the context. 455 - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp32`, :class:`mstype.fp16` 456 [batch_size, seq_length, seq_length]): Lower triangular matrix to pass masked information. 457 458 Outputs: 459 A Tensor. The output of the attention with shape [batch_size, seq_length, hidden_size] 460 461 Supported Platforms: 462 ``Ascend`` 463 464 Examples: 465 >>> import numpy as np 466 >>> from mindspore import dtype as mstype 467 >>> from mindspore.parallel.nn import FixedSparseAttention 468 >>> from mindspore import Tensor 469 >>> model = FixedSparseAttention(batch_size=2, 470 ... num_heads=8, 471 ... size_per_head=64, 472 ... block_size=64) 473 >>> q = Tensor(np.ones((2, 1024, 8*64)), mstype.float16) 474 >>> k = Tensor(np.ones((2, 1024, 8*64)), mstype.float16) 475 >>> v = Tensor(np.ones((2, 1024, 8*64)), mstype.float16) 476 >>> attention_mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32) 477 >>> output = model(q, k, v, attention_mask) 478 >>> print(output.shape) 479 (2, 1024, 512) 480 """ 481 482 @_args_type_validator_check(batch_size=Validator.check_positive_int, 483 num_heads=Validator.check_positive_int, 484 size_per_head=Validator.check_positive_int, 485 block_size=Validator.check_positive_int, 486 seq_length=Validator.check_positive_int, 487 num_different_global_patterns=Validator.check_positive_int, 488 parallel_config=_valid_type_checks([OpParallelConfig], "FixedSparseAttention")) 489 def __init__(self, 490 batch_size, 491 num_heads, 492 size_per_head, 493 block_size, 494 seq_length=1024, 495 num_different_global_patterns=4, 496 parallel_config=default_dpmp_config): 497 super(FixedSparseAttention, self).__init__() 498 dp, mp = parallel_config.data_parallel, parallel_config.model_parallel 499 if num_heads % mp != 0: 500 raise ValueError(f"The number of heads {num_heads} must be a " 501 f"multiple of parallel_config.model_parallel {mp}.") 502 if batch_size % dp != 0: 503 raise ValueError(f"The batch_size {batch_size} must be a " 504 f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.") 505 self.seq_length = seq_length 506 self.batch_size = batch_size 507 self.hidden_size = size_per_head * num_heads 508 self.num_heads = num_heads 509 self.block_size = block_size 510 self.block_num = seq_length // block_size 511 self.size_per_head = size_per_head 512 self.global_size = seq_length // 4 513 self.reshape = P.Reshape() 514 self.transpose = P.Transpose().shard(((dp, 1, mp, 1),)) 515 self.batch_matmul = P.BatchMatMul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) 516 self.multiply = P.Mul().shard(((dp, 1, 1, 1), (1, 1, 1))) 517 self.multiply_data = Tensor([-10000.0], dtype=mstype.float32) 518 self.parallel_config = parallel_config 519 size_per_head_list = [64, 128] 520 if self.seq_length != 1024: 521 raise ValueError("seq_length only supports 1024 for now.") 522 if self.block_size != 64: 523 raise ValueError("block_size only supports 64 for now.") 524 if num_different_global_patterns != 4: 525 raise ValueError("num_different_global_patterns only supports 4 for now.") 526 if self.size_per_head not in size_per_head_list: 527 raise ValueError(f"size_per_head only supports {size_per_head_list} for now, " 528 f"but found {self.size_per_head}") 529 local_ones = np.ones((self.block_size, self.block_size), 530 dtype=np.float16) 531 global_mask_original = np.ones((self.seq_length, self.global_size), dtype=np.float16) 532 for i in range(self.seq_length): 533 for j in range(self.global_size): 534 if i // 16 >= (j // 16 + 1) * 4: 535 global_mask_original[i, j] = 0.0 536 537 global_mask_original = -10000 * global_mask_original 538 global_mask_fx = global_mask_original.reshape((self.seq_length // 16, 16, self.global_size // 16, 16)) 539 global_mask = np.transpose(global_mask_fx, (2, 0, 1, 3)) 540 global_mask = np.repeat(global_mask[np.newaxis, :, :, :, :], self.batch_size, axis=0) 541 global_mask = global_mask.reshape((self.batch_size * self.global_size // 16, self.seq_length // 16, 16, 16)) 542 self.global_mask = Tensor(global_mask, mstype.float32) 543 self.local_mask_triangle = Tensor(np.tril(local_ones), mstype.float32) 544 self.scale_factor = Tensor((math.sqrt(self.size_per_head))) 545 self.matmul_dds = P.MatmulDDS(self.batch_size, self.num_heads).shard(((mp, dp, 1, 1), 546 (mp, dp, 1, 1), 547 (1, dp, 1, 1), 548 (dp, 1, 1, 1))) 549 self.matmul_dsd = P.DSDMatmul().shard(((dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1))) 550 self.sub1 = P.Sub().shard(((1,), (dp, 1, 1, 1))) 551 self.mul1 = P.Mul().shard(((dp, 1, 1, 1), (1,))) 552 self.transpose1 = P.Transpose().shard(((dp, 1, 1, 1),)) 553 self.transpose2 = P.Transpose().shard(((dp, 1, 1, 1),)) 554 self.transpose3 = P.Transpose().shard(((dp, mp, 1, 1, 1, 1),)) 555 self.transpose4 = P.Transpose().shard(((dp, mp, 1, 1),)) 556 self.div = P.RealDiv().shard(((mp, dp, 1, 1), ())) 557 self.slice1 = P.StridedSlice().shard(((dp, 1, 1),)) 558 559 def _transpose_inputs(self, q, k, v): 560 """ 561 do reshape and transpose to inputs 562 """ 563 q = self.transpose( 564 self.reshape( 565 q, 566 (-1, 16, self.num_heads * self.size_per_head // 16, 16)), 567 (2, 0, 1, 3)) 568 k = self.transpose( 569 self.reshape( 570 k, (-1, 16, self.num_heads * self.size_per_head // 16, 16)), 571 (2, 0, 1, 3)) 572 v = self.transpose( 573 self.reshape( 574 v, 575 (-1, 16, self.num_heads * self.size_per_head // 16, 16)), 576 (0, 2, 3, 1)) 577 578 return q, k, v 579 580 def _generate_attention_mask(self, attention_mask): 581 """ 582 generate global attention mask and local attention mask from origin attention mask 583 """ 584 attention_mask = self.reshape(attention_mask, (-1, self.seq_length, self.seq_length)) 585 input_mask = self.slice1(attention_mask, (0, self.seq_length - 1, 0), 586 (self.batch_size, self.seq_length, self.seq_length), (1, 1, 1)) 587 input_mask = self.reshape(input_mask, (-1, self.seq_length)) 588 input_shape = P.Shape()(input_mask) # bs, seq_length 589 # bs, block_num, 1, block_size 590 local_shape_right = (input_shape[0], self.block_num, 1, self.block_size) 591 # bs, block_num, block_size, 1 592 local_shape_left = (input_shape[0], self.block_num, self.block_size, 1) 593 local_mask_left = self.reshape(input_mask, local_shape_left) 594 local_mask_right = self.reshape(input_mask, local_shape_right) 595 # bs, block_num, block_size, block_size 596 local_attention_mask = self.batch_matmul(local_mask_left, local_mask_right) 597 lower_triangle = P.ExpandDims()(self.local_mask_triangle, 0) 598 local_attention_mask = self.multiply(local_attention_mask, lower_triangle) 599 local_multiplied_out = self.sub1(P.Cast()(F.tuple_to_array((1.0,)), mstype.float32), 600 P.Cast()(local_attention_mask, mstype.float32)) 601 local_adder = self.mul1(local_multiplied_out, self.multiply_data) 602 local_mask_original = self.transpose1(local_adder, (0, 2, 1, 3)) 603 local_mask_original = self.reshape( 604 local_mask_original, 605 (self.batch_size * self.block_size, self.block_num * self.block_size)) 606 local_mask_fx = self.reshape( 607 local_mask_original, 608 (self.batch_size * self.block_size // 16, 16, 609 self.block_num * self.block_size // 16, 16)) 610 local_mask = self.transpose2(local_mask_fx, (2, 0, 1, 3)) 611 global_mask = self.global_mask 612 613 return local_mask, global_mask 614 615 def construct(self, q, k, v, attention_mask): 616 _check_shape_equal(F.shape(q), "q", self.cls_name, 617 [self.batch_size, self.seq_length, self.hidden_size]) 618 _check_input_dtype(F.dtype(q), "q", [mstype.float16], self.cls_name) 619 _check_shape_equal(F.shape(k), "k", self.cls_name, 620 [self.batch_size, self.seq_length, self.hidden_size]) 621 _check_input_dtype(F.dtype(k), "k", [mstype.float16], self.cls_name) 622 _check_shape_equal(F.shape(v), "v", self.cls_name, 623 [self.batch_size, self.seq_length, self.hidden_size]) 624 _check_input_dtype(F.dtype(v), "v", [mstype.float16], self.cls_name) 625 _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, 626 [self.batch_size, self.seq_length, self.seq_length]) 627 _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name) 628 629 q, k, v = self._transpose_inputs(q, k, v) 630 local_mask, global_mask = self._generate_attention_mask(attention_mask) 631 q = self.div(q, F.cast(self.scale_factor, F.dtype(q))) 632 k = self.div(k, F.cast(self.scale_factor, F.dtype(k))) 633 local_prob, global_prob = self.matmul_dds(q, k, local_mask, global_mask) 634 attention = self.matmul_dsd(local_prob, global_prob, v) 635 attention_merge = self.transpose3(attention, (0, 1, 3, 4, 2, 5)) 636 attention_merge = F.reshape( 637 attention_merge, 638 (-1, self.num_heads, self.seq_length, self.size_per_head)) 639 attention_merge = self.transpose4(attention_merge, (0, 2, 1, 3)) 640 attention_merge = F.reshape( 641 attention_merge, 642 (-1, self.seq_length, self.size_per_head * self.num_heads)) 643 644 return attention_merge 645