# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The basic layer of the Transformer Networks. This is an experimental interface that is subject to
change and/or deletion.
"""
from functools import wraps, partial
import inspect
import math
import numpy as np
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore._extends import cell_attr_register
from mindspore.nn.cell import Cell
from mindspore import nn
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import functional as F
from mindspore._checkparam import Validator
from mindspore.ops.primitive import constexpr, Primitive
from .op_parallel_config import default_dpmp_config, OpParallelConfig

__all__ = [
    "FixedSparseAttention"
]


def _args_type_validator_check(*type_args, **type_kwargs):
    """Check whether the input data types are correct."""

    def type_check(func):
        sig = inspect.signature(func)
        bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments

        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal bound_types
            bound_values = sig.bind(*args, **kwargs)

            argument_dict = bound_values.arguments
            if "kwargs" in bound_types:
                bound_types = bound_types["kwargs"]
            if "kwargs" in argument_dict:
                argument_dict = argument_dict["kwargs"]
            for name, value in argument_dict.items():
                if name in bound_types:
                    bound_types[name](value, name)
            return func(*args, **kwargs)

        return wrapper

    return type_check


def _valid_type_checks(types, class_name):
    # `types` is a list of valid types; the returned function checks whether the type of `value`
    # is one of them.
    def validator_check_func(value, name):
        # The signature of Validator.check_type_name is (arg_name, arg_type, valid_types, prim_name).
        # Since the callback signature expected by _args_type_validator_check is fixed to
        # (value, name), the arguments have to be reordered manually.
        partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name)
        return partial_check(name, type(value))

    return validator_check_func


def _valid_value_checks(types, class_name):
    # `types` is a list of valid values (e.g. dtypes); the returned function checks whether `value`
    # itself is one of them.
    def validator_check_func(value, name):
        # The signature of Validator.check_type_name is (arg_name, arg_type, valid_types, prim_name).
        # Since the callback signature expected by _args_type_validator_check is fixed to
        # (value, name), the arguments have to be reordered manually.
        partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name)
        return partial_check(name, value)

    return validator_check_func

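
# A minimal usage sketch of the helpers above, mirroring how FixedSparseAttention applies them
# further below. Kept as a comment so nothing extra runs at import time; `MyCell` is a
# hypothetical class name used only for illustration:
#
#   @_args_type_validator_check(batch_size=Validator.check_positive_int,
#                               parallel_config=_valid_type_checks([OpParallelConfig], "MyCell"))
#   def __init__(self, batch_size, parallel_config=default_dpmp_config):
#       ...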
89
90class _LayerInputCheck:
91    """
92       A input check class for the inputs of the transformer model.
93    """
94    @staticmethod
95    def check_shape_length(input_shape, param_name, func_name, target_len):
96        """
97        Check the input shape's length is equal to the expected shape
98        :param input_shape(list): a list of the tensor shapes.
99        :param param_name(str): the name of the checked parameter.
100        :param func_name(str): the name of the function.
101        :param target_len: the expected length of the shape.
102        :return:
103        """
104        if not isinstance(target_len, list):
105            target_len = [target_len]
106        matched = False
107        for item in target_len:
108            if len(input_shape) == item:
109                matched = True
110        if not matched:
111            raise ValueError(f"{func_name} {param_name} shape length should be one of {target_len} dimension, "
112                             f"but got shape {input_shape}")
113        return True
114
115    @staticmethod
116    def check_shape_equal(input_shape, param_name, func_name, target_shape):
117        """
118        Check the input shape's is equal to the expected shape
119        :param input_shape(list): a list of the tensor shapes.
120        :param param_name(str): the name of the checked parameter.
121        :param func_name(str): the name of the function.
122        :param target_shape: the expected shape.
123        :return:
124        """
125        if not isinstance(target_shape[0], list):
126            target_shape = [target_shape]
127        if isinstance(input_shape, tuple):
128            input_shape = list(input_shape)
129        _LayerInputCheck.check_shape_length(input_shape, param_name, func_name,
130                                            [len(item) for item in target_shape])
131        matched = False
132        for item in target_shape:
133            if item == input_shape:
134                matched = True
135                break
136
137        if not matched:
138            raise ValueError(f"{func_name} {param_name} shape should be one of {target_shape},"
139                             f"but got {input_shape}")
140        return True
141
142    @staticmethod
143    def check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value):
144        if input_shape[dim] != target_value:
145            raise ValueError(f"{cls_name} {param_name} at {dim} shape should be {target_value},"
146                             f"but got {input_shape[dim]}")
147        return True
148
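
# For example (illustrative comment only; in practice these checks are reached through the
# @constexpr wrappers below, called from a cell's construct):
#
#   _LayerInputCheck.check_shape_equal((2, 1024, 512), "q", "FixedSparseAttention",
#                                      [2, 1024, 512])     # returns True
#   _LayerInputCheck.check_shape_length((2, 1024), "q", "FixedSparseAttention", 3)
#   # raises ValueError: the shape has 2 dimensions, but 3 were expected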


@constexpr
def _check_past_none_input_none(use_past, param_name, func_name, default_value, is_tensor, is_default):
    """If use_past is False, the input must keep its default value; if use_past is True, it must be a tensor."""
    if not use_past:
        if is_tensor:
            raise TypeError(f"{func_name} {param_name} should be {default_value} if use_past is False, but found "
                            f"a tensor")
        if not is_default:
            raise TypeError(f"{func_name} {param_name} should be {default_value} if use_past is False.")
    else:
        if not is_tensor:
            raise TypeError(f"{func_name} {param_name} should be a tensor if use_past is True")
    return True


@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    Validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)


@constexpr
def _check_input_shape(input_shape, param_name, func_name, target_len):
    # check the length of the input shape
    _LayerInputCheck.check_shape_length(input_shape, param_name, func_name, target_len)


@constexpr
def _check_shape_equal(input_shape, param_name, func_name, target_shape):
    # check that the input shape matches the target shape
    _LayerInputCheck.check_shape_equal(input_shape, param_name, func_name, target_shape)


@constexpr
def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_value):
    _LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value)

class _LayerNorm(Cell):
    r"""
        A self-defined layer norm operation using reduce mean and element-wise operations.

        Args:
            normalized_shape (tuple): The shape over which layer norm is computed (the trailing
                hidden dimension of the input).
            eps (float): The epsilon value of the denominator. Default 1e-5.
            param_init_type: The param init type.
        Inputs:
            - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.

        Outputs:
            Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.
    """

    def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32):
        super(_LayerNorm, self).__init__()
        if param_init_type not in [mstype.float32, mstype.float16]:
            raise TypeError(f"param type should be in [float32, float16], but found type {type(param_init_type)}")
        if normalized_shape[0] <= 1024:
            self.layer_norm = P.LayerNorm(begin_norm_axis=-1,
                                          begin_params_axis=-1,
                                          epsilon=eps)
        self.is_self_defined = normalized_shape[0] > 1024
        self.gamma = Parameter(initializer('ones', normalized_shape, param_init_type), name="gamma",
                               parallel_optimizer=False)
        self.beta = Parameter(initializer('zeros', normalized_shape, param_init_type), name="beta",
                              parallel_optimizer=False)
        self.mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.sub1 = P.Sub()
        self.sub2 = P.Sub()
        self.add = P.Add()
        self.eps = eps
        self.mul = P.Mul()
        self.add2 = P.Add()
        self.real_div = P.RealDiv()

    def construct(self, x):
        r"""
          x : batch x seq_length x hidden_size
        """
        if self.is_self_defined:
            mean = self.mean(x, -1)
            diff = self.sub1(x, mean)
            variance = self.mean(self.square(diff), -1)
            variance_eps = self.sqrt(self.add(variance, self.eps))
            output = self.real_div(diff, variance_eps)
            output = self.add2(self.mul(output, self.gamma), self.beta)
        else:
            output, _, _ = self.layer_norm(x, self.gamma, self.beta)
        return output

    def shard(self, strategy):
        r"""
        Set the shard strategy for the layer norm. The strategy size should match the inputs.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            strategy (tuple): The strategy for the layer norm. Should be the same shape as the inputs.
        Examples:
            >>> import mindspore
            >>> net = mindspore.parallel.nn.transformer.LayerNorm(normalized_shape=(1024, 10))
            >>> net.shard(((10, 2, 1),))
        """
        if self.is_self_defined:
            self.mean.shard(strategy)
            self.square.shard(strategy)
            self.sqrt.shard(strategy)
            self.sub1.shard((strategy[0], strategy[0]))
            self.sub2.shard((strategy[0], strategy[0]))
            self.add.shard((strategy[0], ()))
            self.mul.shard((strategy[0], (1,)))
            self.add2.shard((strategy[0], (1,)))
            self.real_div.shard((strategy[0], strategy[0]))
        else:
            self.layer_norm.shard((strategy[0], (1,), (1,)))

        return self


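# A minimal usage sketch for the _LayerNorm cell above (illustrative comment only; the shapes
# follow the docstring, and the shard call is only meaningful in (semi) auto parallel mode):
#
#   norm = _LayerNorm(normalized_shape=(512,), param_init_type=mstype.float32)
#   # x has shape (batch, seq_length, hidden_size) = (2, 16, 512)
#   y = norm(Tensor(np.ones((2, 16, 512)), mstype.float32))
#   norm.shard(((2, 1, 1),))
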
class _Linear(Cell):
    r"""
    The dense connected layer. Once the parallel mode is enabled, the input should be a
    3-D tensor.

    Applies a dense connected layer to the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{X} * \text{kernel} + \text{bias}),

    where :math:`X` is the input tensor, :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
    data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
    with the same data type as the :math:`X` created by the layer (only if has_bias is True).

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as `x`. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as `x`. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str): The activation function applied to the output of the fully connected layer,
            e.g. 'ReLU'. Default: None.
        expert_num (int): The number of experts used in this Linear. For the case expert_num > 1, BatchMatMul is
            used and the first dimension in BatchMatMul indicates expert_num. Default: 1.
        compute_dtype (dtype.Number): The computation type. Default: mstype.float16
    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
          to :math:`in\_channels` in `Inputs`.

    Outputs:
        Tensor of shape :math:`(*, out\_channels)`.

    Raises:
        TypeError: If `in_channels` or `out_channels` is not an int.
        TypeError: If `has_bias` is not a bool.
        TypeError: If `activation` is not one of str, Cell, Primitive, None.
        ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
                    is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`.
        ValueError: If length of shape of `bias_init` is not equal to 1
                    or shape[0] of `bias_init` is not equal to `out_channels`.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    @cell_attr_register(attrs=['has_bias', 'in_channels', 'out_channels', 'shard_output', 'activation'])
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None,
                 transpose_b=True,
                 expert_num=1,
                 param_init_type=mstype.float32,
                 compute_dtype=mstype.float16):
        super(_Linear, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        if param_init_type not in [mstype.float32, mstype.float16]:
            raise TypeError(f"param type should be in [float32, float16], but found type {type(param_init_type)}")
        if activation and not isinstance(activation, str):
            raise ValueError("Activation can only be str, but found type {}".format(activation))
        if isinstance(weight_init, Tensor):
            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("Weight init shape error.")
        if transpose_b:
            weight_shape = [out_channels, in_channels]
        else:
            weight_shape = [in_channels, out_channels]
        self.expert_num = expert_num
        if self.expert_num > 1:
            self.expert_flag = True
            self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type),
                                    name="weight")
            self.matmul = P.BatchMatMul(transpose_b=transpose_b)
        else:
            self.expert_flag = False
            self.weight = Parameter(initializer(weight_init, weight_shape, param_init_type), name="weight")
            self.matmul = P.MatMul(transpose_b=transpose_b)
        self.bias = None
        self.has_bias = has_bias
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("Bias init shape error.")
            self.bias = Parameter(initializer(bias_init, [out_channels], param_init_type), name="bias")
            self.bias_add = P.Add()
        self.act_name = activation
        self.activation = get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
        self.activation_flag = self.activation is not None
        self.dtype = compute_dtype
        self.cast = P.Cast()

    def construct(self, x):
        out_shape = P.Shape()(x)[:-1] + (self.out_channels,)
        x = P.Reshape()(x, (-1, self.in_channels))
        if self.expert_flag is True:
            x = P.Reshape()(x, (self.expert_num, -1, self.in_channels))
        weight = self.cast(self.weight, self.dtype)
        x = self.matmul(x, weight)
        if self.has_bias:
            x = self.bias_add(x, self.cast(self.bias, self.dtype))
        if self.activation_flag:
            x = self.activation(x)
        output = P.Reshape()(x, out_shape)
        return output

    def shard(self, strategy_matmul, strategy_bias=None, strategy_activation=None):
        r"""
        Set the shard strategies for the linear layer. The strategy size should match the inputs.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            strategy_matmul (tuple): The strategy for the matmul. Should be the same shape as the inputs.
            strategy_bias (tuple): The strategy for the bias_add. Should be the same shape as the inputs.
            strategy_activation (tuple): The strategy for the activation. Should be the same shape as
                the inputs.
        """
        self.matmul.shard(strategy_matmul)
        if self.has_bias:
            self.bias_add.shard(strategy_bias)
        if self.activation_flag:
            # some activations consist of several primitives, so their shards need to be set manually
            if self.act_name.lower() == "leakyrelu":
                self.activation.select_op.shard((strategy_activation[0], strategy_activation[0]))
            elif self.act_name.lower() == "logsigmoid":
                self.activation.mul.shard((strategy_activation[0], ()))
                self.activation.exp.shard(strategy_activation)
                self.activation.add.shard((strategy_activation[0], ()))
                self.activation.rec.shard(strategy_activation)
                self.activation.log.shard(strategy_activation)
            elif self.act_name.lower() == "logsoftmax":
                raise ValueError("logsoftmax is not supported.")
            else:
                getattr(self.activation, self.act_name).shard(strategy_activation)

        return self


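# A minimal usage sketch for the _Linear cell above (illustrative comment only; the shapes follow
# the docstring, and the shard strategies are an assumed data_parallel=2, model_parallel=1 layout,
# meaningful only in (semi) auto parallel mode):
#
#   dense = _Linear(in_channels=512, out_channels=1024, activation='gelu')
#   y = dense(Tensor(np.ones((2, 16, 512)), mstype.float16))   # y has shape (2, 16, 1024)
#   dense.shard(strategy_matmul=((2, 1), (1, 1)),
#               strategy_bias=((2, 1), (1,)),
#               strategy_activation=((2, 1),))
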
class FixedSparseAttention(nn.Cell):
    """
    Fixed Sparse Attention Layer.

    This class contains the sparse attention primitives used in Sparse Transformers (see the paper:
    https://arxiv.org/abs/1904.10509). Specifically, it includes the following:
    1. A faster implementation of normal attention (the upper triangle is not computed, and many operations are fused).
    2. An implementation of "strided" and "fixed" attention, as in the Sparse Transformers paper.

    Args:
        batch_size (int): Batch size of the input.
        num_heads (int): Number of attention heads.
        block_size (int): An integer determining the block size. The current implementation of sparse self-attention
                          is based on blocked sparse matrices, and this parameter defines the size of such blocks,
                          Block x Block. Only supports 64 for now.
        seq_length (int): Length of the input sequence. Only supports 1024 for now.
        num_different_global_patterns (int): An integer determining the number of different global attention layouts.
                                             While global attention can be fixed by which block(s) are representative
                                             of any local window, since there are multiple heads, each head can use a
                                             different global representative. Only supports 4 for now.
        size_per_head (int): An integer determining the embedding size of each attention head.
                             Only supports 64 and 128 for now.
        parallel_config (OpParallelConfig): The parallel configuration; see `OpParallelConfig`.
                                            Default: `default_dpmp_config`.

    Inputs:
        - **q** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
          queries to query the context.
        - **k** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
          keys to be attended over.
        - **v** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
          values to be aggregated by the attention weights.
        - **attention_mask** (Tensor) - Float Tensor, the mask (:class:`mstype.fp32`, :class:`mstype.fp16`
          [batch_size, seq_length, seq_length]): Lower triangular matrix to pass masked information.

    Outputs:
        A Tensor. The output of the attention with shape [batch_size, seq_length, hidden_size].

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> from mindspore import dtype as mstype
        >>> from mindspore.parallel.nn import FixedSparseAttention
        >>> from mindspore import Tensor
        >>> model = FixedSparseAttention(batch_size=2,
        ...                              num_heads=8,
        ...                              size_per_head=64,
        ...                              block_size=64)
        >>> q = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
        >>> k = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
        >>> v = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
        >>> attention_mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
        >>> output = model(q, k, v, attention_mask)
        >>> print(output.shape)
        (2, 1024, 512)
    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                size_per_head=Validator.check_positive_int,
                                block_size=Validator.check_positive_int,
                                seq_length=Validator.check_positive_int,
                                num_different_global_patterns=Validator.check_positive_int,
                                parallel_config=_valid_type_checks([OpParallelConfig], "FixedSparseAttention"))
    def __init__(self,
                 batch_size,
                 num_heads,
                 size_per_head,
                 block_size,
                 seq_length=1024,
                 num_different_global_patterns=4,
                 parallel_config=default_dpmp_config):
        super(FixedSparseAttention, self).__init__()
        dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
        if num_heads % mp != 0:
            raise ValueError(f"The number of heads {num_heads} must be a "
                             f"multiple of parallel_config.model_parallel {mp}.")
        if batch_size % dp != 0:
            raise ValueError(f"The batch_size {batch_size} must be a "
                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.hidden_size = size_per_head * num_heads
        self.num_heads = num_heads
        self.block_size = block_size
        self.block_num = seq_length // block_size
        self.size_per_head = size_per_head
        self.global_size = seq_length // 4
        self.reshape = P.Reshape()
        self.transpose = P.Transpose().shard(((dp, 1, mp, 1),))
        self.batch_matmul = P.BatchMatMul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
        self.multiply = P.Mul().shard(((dp, 1, 1, 1), (1, 1, 1)))
        self.multiply_data = Tensor([-10000.0], dtype=mstype.float32)
        self.parallel_config = parallel_config
        size_per_head_list = [64, 128]
        if self.seq_length != 1024:
            raise ValueError("seq_length only supports 1024 for now.")
        if self.block_size != 64:
            raise ValueError("block_size only supports 64 for now.")
        if num_different_global_patterns != 4:
            raise ValueError("num_different_global_patterns only supports 4 for now.")
        if self.size_per_head not in size_per_head_list:
            raise ValueError(f"size_per_head only supports {size_per_head_list} for now, "
                             f"but found {self.size_per_head}")
        local_ones = np.ones((self.block_size, self.block_size),
                             dtype=np.float16)
        global_mask_original = np.ones((self.seq_length, self.global_size), dtype=np.float16)
        for i in range(self.seq_length):
            for j in range(self.global_size):
                if i // 16 >= (j // 16 + 1) * 4:
                    global_mask_original[i, j] = 0.0

        global_mask_original = -10000 * global_mask_original
        global_mask_fx = global_mask_original.reshape((self.seq_length // 16, 16, self.global_size // 16, 16))
        global_mask = np.transpose(global_mask_fx, (2, 0, 1, 3))
        global_mask = np.repeat(global_mask[np.newaxis, :, :, :, :], self.batch_size, axis=0)
        global_mask = global_mask.reshape((self.batch_size * self.global_size // 16, self.seq_length // 16, 16, 16))
        self.global_mask = Tensor(global_mask, mstype.float32)
        self.local_mask_triangle = Tensor(np.tril(local_ones), mstype.float32)
        self.scale_factor = Tensor((math.sqrt(self.size_per_head)))
        self.matmul_dds = P.MatmulDDS(self.batch_size, self.num_heads).shard(((mp, dp, 1, 1),
                                                                              (mp, dp, 1, 1),
                                                                              (1, dp, 1, 1),
                                                                              (dp, 1, 1, 1)))
        self.matmul_dsd = P.DSDMatmul().shard(((dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1)))
        self.sub1 = P.Sub().shard(((1,), (dp, 1, 1, 1)))
        self.mul1 = P.Mul().shard(((dp, 1, 1, 1), (1,)))
        self.transpose1 = P.Transpose().shard(((dp, 1, 1, 1),))
        self.transpose2 = P.Transpose().shard(((dp, 1, 1, 1),))
        self.transpose3 = P.Transpose().shard(((dp, mp, 1, 1, 1, 1),))
        self.transpose4 = P.Transpose().shard(((dp, mp, 1, 1),))
        self.div = P.RealDiv().shard(((mp, dp, 1, 1), ()))
        self.slice1 = P.StridedSlice().shard(((dp, 1, 1),))

    def _transpose_inputs(self, q, k, v):
        """
        Reshape and transpose q, k and v into the blocked layout expected by the sparse matmul primitives.
        """
        q = self.transpose(
            self.reshape(
                q,
                (-1, 16, self.num_heads * self.size_per_head // 16, 16)),
            (2, 0, 1, 3))
        k = self.transpose(
            self.reshape(
                k, (-1, 16, self.num_heads * self.size_per_head // 16, 16)),
            (2, 0, 1, 3))
        v = self.transpose(
            self.reshape(
                v,
                (-1, 16, self.num_heads * self.size_per_head // 16, 16)),
            (0, 2, 3, 1))

        return q, k, v

    def _generate_attention_mask(self, attention_mask):
        """
        Generate the global attention mask and the local attention mask from the original attention mask.
        """
        attention_mask = self.reshape(attention_mask, (-1, self.seq_length, self.seq_length))
        input_mask = self.slice1(attention_mask, (0, self.seq_length - 1, 0),
                                 (self.batch_size, self.seq_length, self.seq_length), (1, 1, 1))
        input_mask = self.reshape(input_mask, (-1, self.seq_length))
        input_shape = P.Shape()(input_mask)  # bs, seq_length
        # bs, block_num, 1, block_size
        local_shape_right = (input_shape[0], self.block_num, 1, self.block_size)
        # bs, block_num, block_size, 1
        local_shape_left = (input_shape[0], self.block_num, self.block_size, 1)
        local_mask_left = self.reshape(input_mask, local_shape_left)
        local_mask_right = self.reshape(input_mask, local_shape_right)
        # bs, block_num, block_size, block_size
        local_attention_mask = self.batch_matmul(local_mask_left, local_mask_right)
        lower_triangle = P.ExpandDims()(self.local_mask_triangle, 0)
        local_attention_mask = self.multiply(local_attention_mask, lower_triangle)
        local_multiplied_out = self.sub1(P.Cast()(F.tuple_to_array((1.0,)), mstype.float32),
                                         P.Cast()(local_attention_mask, mstype.float32))
        local_adder = self.mul1(local_multiplied_out, self.multiply_data)
        local_mask_original = self.transpose1(local_adder, (0, 2, 1, 3))
        local_mask_original = self.reshape(
            local_mask_original,
            (self.batch_size * self.block_size, self.block_num * self.block_size))
        local_mask_fx = self.reshape(
            local_mask_original,
            (self.batch_size * self.block_size // 16, 16,
             self.block_num * self.block_size // 16, 16))
        local_mask = self.transpose2(local_mask_fx, (2, 0, 1, 3))
        global_mask = self.global_mask

        return local_mask, global_mask

    def construct(self, q, k, v, attention_mask):
        _check_shape_equal(F.shape(q), "q", self.cls_name,
                           [self.batch_size, self.seq_length, self.hidden_size])
        _check_input_dtype(F.dtype(q), "q", [mstype.float16], self.cls_name)
        _check_shape_equal(F.shape(k), "k", self.cls_name,
                           [self.batch_size, self.seq_length, self.hidden_size])
        _check_input_dtype(F.dtype(k), "k", [mstype.float16], self.cls_name)
        _check_shape_equal(F.shape(v), "v", self.cls_name,
                           [self.batch_size, self.seq_length, self.hidden_size])
        _check_input_dtype(F.dtype(v), "v", [mstype.float16], self.cls_name)
        _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
                           [self.batch_size, self.seq_length, self.seq_length])
        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)

        q, k, v = self._transpose_inputs(q, k, v)
        local_mask, global_mask = self._generate_attention_mask(attention_mask)
        q = self.div(q, F.cast(self.scale_factor, F.dtype(q)))
        k = self.div(k, F.cast(self.scale_factor, F.dtype(k)))
        local_prob, global_prob = self.matmul_dds(q, k, local_mask, global_mask)
        attention = self.matmul_dsd(local_prob, global_prob, v)
        attention_merge = self.transpose3(attention, (0, 1, 3, 4, 2, 5))
        attention_merge = F.reshape(
            attention_merge,
            (-1, self.num_heads, self.seq_length, self.size_per_head))
        attention_merge = self.transpose4(attention_merge, (0, 2, 1, 3))
        attention_merge = F.reshape(
            attention_merge,
            (-1, self.seq_length, self.size_per_head * self.num_heads))

        return attention_merge

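
# A usage sketch with an explicit parallel configuration, complementing the docstring example
# above (illustrative comment only; it assumes OpParallelConfig accepts `data_parallel` and
# `model_parallel` keyword arguments, and that batch_size and num_heads are multiples of them,
# as checked in __init__):
#
#   config = OpParallelConfig(data_parallel=2, model_parallel=1)
#   sparse_attn = FixedSparseAttention(batch_size=2, num_heads=8, size_per_head=64,
#                                      block_size=64, parallel_config=config)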