# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert model."""

import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter


class BertConfig:
    """
    Configuration for `BertModel`.

    Args:
        batch_size (int): Batch size of input dataset.
        seq_length (int): Length of input sequence. Default: 128.
        vocab_size (int): Size of the vocabulary. Default: 32000.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
                           cell. Default: 12.
        num_attention_heads (int): Number of attention heads in the BertTransformer
                             encoder cell. Default: 12.
        intermediate_size (int): Size of intermediate layer in the BertTransformer
                           encoder cell. Default: 3072.
        hidden_act (str): Activation function used in the BertTransformer encoder
                    cell. Default: "gelu".
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this
                                 model. Default: 512.
        type_vocab_size (int): Size of token type vocab. Default: 16.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        input_mask_from_dataset (bool): Specifies whether to use the input mask that is loaded
                                 from the dataset. Default: True.
        token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that are
                                     loaded from the dataset. Default: True.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
        enable_fused_layernorm (bool): Specifies whether to use fused LayerNorm. Default: False.
    """
    def __init__(self,
                 batch_size,
                 seq_length=128,
                 vocab_size=32000,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 input_mask_from_dataset=True,
                 token_type_ids_from_dataset=True,
                 dtype=mstype.float32,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.input_mask_from_dataset = input_mask_from_dataset
        self.token_type_ids_from_dataset = token_type_ids_from_dataset
        self.use_relative_positions = use_relative_positions
        self.dtype = dtype
        self.compute_type = compute_type
        self.enable_fused_layernorm = enable_fused_layernorm
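
# Example (illustrative only; the batch size and the non-default vocab size below are
# arbitrary values chosen for the sketch, not mandated by this module):
#
#   config = BertConfig(batch_size=32, seq_length=128, vocab_size=21128)
#   assert config.hidden_size == 768 and config.num_hidden_layers == 12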


class EmbeddingLookup(nn.Cell):
    """
    An embedding lookup table with a fixed dictionary and size.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
                         the output embedding tensor.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
    """
    def __init__(self,
                 vocab_size,
                 embedding_size,
                 embedding_shape,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embedding_table = Parameter(initializer
                                         (TruncatedNormal(initializer_range),
                                          [vocab_size, embedding_size]),
                                         name='embedding_table')
        self.expand = P.ExpandDims()
        self.shape_flat = (-1,)
        self.gather = P.Gather()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)

    def construct(self, input_ids):
        extended_ids = self.expand(input_ids, -1)
        flat_ids = self.reshape(extended_ids, self.shape_flat)
        if self.use_one_hot_embeddings:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(
                one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, self.shape)
        return output, self.embedding_table
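
# Shape sketch (illustrative; the concrete sizes are assumptions for the example):
# with vocab_size=32000, embedding_size=768 and embedding_shape=[32, 128, 768], an
# int32 `input_ids` tensor of shape [32, 128] is flattened, looked up in the
# [32000, 768] embedding table, and reshaped back to [32, 128, 768]. The table itself
# is returned as the second output so callers can reuse it (e.g. for a tied output
# projection).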


class EmbeddingPostprocessor(nn.Cell):
    """
    Postprocessor that applies positional and token type embeddings to word embeddings.

    Args:
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
                         the output embedding tensor.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
        token_type_vocab_size (int): Size of token type vocab. Default: 16.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        max_position_embeddings (int): Maximum length of sequences used in this
                                 model. Default: 512.
        dropout_prob (float): The dropout probability. Default: 0.1.
    """
    def __init__(self,
                 embedding_size,
                 embedding_shape,
                 use_relative_positions=False,
                 use_token_type=False,
                 token_type_vocab_size=16,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 max_position_embeddings=512,
                 dropout_prob=0.1):
        super(EmbeddingPostprocessor, self).__init__()
        self.use_token_type = use_token_type
        self.token_type_vocab_size = token_type_vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.max_position_embeddings = max_position_embeddings
        self.embedding_table = Parameter(initializer
                                         (TruncatedNormal(initializer_range),
                                          [token_type_vocab_size,
                                           embedding_size]),
                                         name='embedding_table')

        self.shape_flat = (-1,)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)
        self.layernorm = nn.LayerNorm((embedding_size,))
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.gather = P.Gather()
        self.use_relative_positions = use_relative_positions
        self.slice = P.StridedSlice()
        self.full_position_embeddings = Parameter(initializer
                                                  (TruncatedNormal(initializer_range),
                                                   [max_position_embeddings,
                                                    embedding_size]),
                                                  name='full_position_embeddings')

    def construct(self, token_type_ids, word_embeddings):
        output = word_embeddings
        if self.use_token_type:
            flat_ids = self.reshape(token_type_ids, self.shape_flat)
            if self.use_one_hot_embeddings:
                one_hot_ids = self.one_hot(flat_ids,
                                           self.token_type_vocab_size, self.on_value, self.off_value)
                token_type_embeddings = self.array_mul(one_hot_ids,
                                                       self.embedding_table)
            else:
                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
            output += token_type_embeddings
        if not self.use_relative_positions:
            _, seq, width = self.shape
            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
            output += position_embeddings
        output = self.layernorm(output)
        output = self.dropout(output)
        return output
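
# In formula form (the positional term is skipped when relative positions are used):
#
#   output = Dropout(LayerNorm(word_embeddings
#                              + token_type_table[token_type_ids]
#                              + position_table[:seq_length]))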


class BertOutput(nn.Cell):
    """
    Apply a linear transformation to the hidden states, followed by dropout, a residual
    connection to the input tensor, and layer normalization.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        dropout_prob (float): The dropout probability. Default: 0.1.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 initializer_range=0.02,
                 dropout_prob=0.1,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertOutput, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.dropout_prob = dropout_prob
        self.add = P.Add()
        self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(output, input_tensor)
        output = self.layernorm(output)
        return output
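
# In short: output = LayerNorm(Dropout(Dense(hidden_states)) + input_tensor), the
# standard post-attention / post-feed-forward residual block used throughout BERT.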


class RelaPosMatrixGenerator(nn.Cell):
    """
    Generates matrix of relative positions between inputs.

    Args:
        length (int): Length of one dim for the matrix to be generated.
        max_relative_position (int): Max value of relative position.
    """
    def __init__(self, length, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = max_relative_position
        self._min_relative_position = -max_relative_position
        self.range_length = -length + 1

        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()

    def construct(self):
        range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
        range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
        tile_row_out = self.tile(range_vec_row_out, (self._length,))
        tile_col_out = self.tile(range_vec_col_out, (1, self._length))
        range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
        transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
        distance_mat = self.sub(range_mat_out, transpose_out)

        distance_mat_clipped = C.clip_by_value(distance_mat,
                                               self._min_relative_position,
                                               self._max_relative_position)

        # Shift values to be >=0. Each integer still uniquely identifies a
        # relative position difference.
        final_mat = distance_mat_clipped + self._max_relative_position
        return final_mat
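
# Worked example (length=4, max_relative_position=2; values chosen only for illustration):
# distance_mat[i][j] = j - i, clipped to [-2, 2], then shifted by +2, giving
#   [[2, 3, 4, 4],
#    [1, 2, 3, 4],
#    [0, 1, 2, 3],
#    [0, 0, 1, 2]]
# so every entry is a non-negative index into a vocabulary of 2 * 2 + 1 = 5 buckets.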


class RelaPosEmbeddingsGenerator(nn.Cell):
    """
    Generates tensor of size [length, length, depth].

    Args:
        length (int): Length of one dim for the matrix to be generated.
        depth (int): Size of each attention head.
        max_relative_position (int): Maximum value of relative position.
        initializer_range (float): Initialization value of TruncatedNormal.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """
    def __init__(self,
                 length,
                 depth,
                 max_relative_position,
                 initializer_range,
                 use_one_hot_embeddings=False):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position * 2 + 1
        self.use_one_hot_embeddings = use_one_hot_embeddings

        self.embeddings_table = Parameter(
            initializer(TruncatedNormal(initializer_range),
                        [self.vocab_size, self.depth]),
            name='embeddings_for_position')

        self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                max_relative_position=max_relative_position)
        self.reshape = P.Reshape()
        self.one_hot = nn.OneHot(depth=self.vocab_size)
        self.shape = P.Shape()
        self.gather = P.Gather()  # index_select
        self.matmul = P.BatchMatMul()

    def construct(self):
        relative_positions_matrix_out = self.relative_positions_matrix()

        # Generate embedding for each relative position of dimension depth.
        if self.use_one_hot_embeddings:
            flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
            one_hot_relative_positions_matrix = self.one_hot(
                flat_relative_positions_matrix)
            embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
            my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
            embeddings = self.reshape(embeddings, my_shape)
        else:
            embeddings = self.gather(self.embeddings_table,
                                     relative_positions_matrix_out, 0)
        return embeddings


class SaturateCast(nn.Cell):
    """
    Performs a safe saturating cast. The input is clamped to the finite range of the
    destination type before casting, so the cast cannot overflow or underflow.

    Args:
        src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
        dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
    """
    def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
        super(SaturateCast, self).__init__()
        np_type = mstype.dtype_to_nptype(dst_type)
        min_type = float(np.finfo(np_type).min)
        max_type = float(np.finfo(np_type).max)

        self.tensor_min_type = min_type
        self.tensor_max_type = max_type

        self.min_op = P.Minimum()
        self.max_op = P.Maximum()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        out = self.max_op(x, self.tensor_min_type)
        out = self.min_op(out, self.tensor_max_type)
        return self.cast(out, self.dst_type)
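
# Example (illustrative): with dst_type=mstype.float16, np.finfo reports a finite range
# of roughly [-65504.0, 65504.0], so a float32 value of 1e6 is clamped to 65504.0
# before the cast instead of overflowing to +inf.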


class BertAttention(nn.Cell):
    """
    Apply multi-headed attention from "from_tensor" to "to_tensor".

    Args:
        batch_size (int): Batch size of input datasets.
        from_tensor_width (int): Size of last dim of from_tensor.
        to_tensor_width (int): Size of last dim of to_tensor.
        from_seq_length (int): Length of from_tensor sequence.
        to_seq_length (int): Length of to_tensor sequence.
        num_attention_heads (int): Number of attention heads. Default: 1.
        size_per_head (int): Size of each attention head. Default: 512.
        query_act (str): Activation function for the query transform. Default: None.
        key_act (str): Activation function for the key transform. Default: None.
        value_act (str): Activation function for the value transform. Default: None.
        has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.0.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        do_return_2d_tensor (bool): If True, return a 2d tensor; otherwise return a 3d
                             tensor. Default: False.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 from_tensor_width,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 has_attention_mask=False,
                 attention_probs_dropout_prob=0.0,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 do_return_2d_tensor=False,
                 use_relative_positions=False,
                 compute_type=mstype.float32):

        super(BertAttention, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.num_attention_heads = num_attention_heads
        self.size_per_head = size_per_head
        self.has_attention_mask = has_attention_mask
        self.use_relative_positions = use_relative_positions

        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
        self.reshape = P.Reshape()
        self.shape_from_2d = (-1, from_tensor_width)
        self.shape_to_2d = (-1, to_tensor_width)
        weight = TruncatedNormal(initializer_range)
        units = num_attention_heads * size_per_head
        self.query_layer = nn.Dense(from_tensor_width,
                                    units,
                                    activation=query_act,
                                    weight_init=weight).to_float(compute_type)
        self.key_layer = nn.Dense(to_tensor_width,
                                  units,
                                  activation=key_act,
                                  weight_init=weight).to_float(compute_type)
        self.value_layer = nn.Dense(to_tensor_width,
                                    units,
                                    activation=value_act,
                                    weight_init=weight).to_float(compute_type)

        self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
        self.shape_to = (
            batch_size, to_seq_length, num_attention_heads, size_per_head)

        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        self.multiply_data = -10000.0
        self.batch_num = batch_size * num_attention_heads
        self.matmul = P.BatchMatMul()

        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)

        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if do_return_2d_tensor:
            self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
        else:
            self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        if self.use_relative_positions:
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(length=to_seq_length,
                                           depth=size_per_head,
                                           max_relative_position=16,
                                           initializer_range=initializer_range,
                                           use_one_hot_embeddings=use_one_hot_embeddings)

    def construct(self, from_tensor, to_tensor, attention_mask):
        # reshape 2d/3d input tensors to 2d
        from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
        query_out = self.query_layer(from_tensor_2d)
        key_out = self.key_layer(to_tensor_2d)
        value_out = self.value_layer(to_tensor_2d)

        query_layer = self.reshape(query_out, self.shape_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shape_to)
        key_layer = self.transpose(key_layer, self.trans_shape)

        attention_scores = self.matmul_trans_b(query_layer, key_layer)

        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # 'relations_keys' = [F|T, F|T, H]
            relations_keys = self._generate_relative_positions_embeddings()
            relations_keys = self.cast_compute_type(relations_keys)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length,
                                          self.batch_num,
                                          self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length,
                                                  self.batch_size,
                                                  self.num_attention_heads,
                                                  self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            attention_scores = attention_scores + key_position_scores_r_t

        attention_scores = self.multiply(self.scores_mul, attention_scores)

        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))

            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)

        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)

        value_layer = self.reshape(value_out, self.shape_to)
        value_layer = self.transpose(value_layer, self.trans_shape)
        context_layer = self.matmul(attention_probs, value_layer)

        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # 'relations_values' = [F|T, F|T, H]
            relations_values = self._generate_relative_positions_embeddings()
            relations_values = self.cast_compute_type(relations_values)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(
                attention_probs_t,
                (self.from_seq_length,
                 self.batch_num,
                 self.to_seq_length))
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length,
                                                    self.batch_size,
                                                    self.num_attention_heads,
                                                    self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            context_layer = context_layer + value_position_scores_r_t

        context_layer = self.transpose(context_layer, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shape_return)

        return context_layer
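
# Shape flow for the masked self-attention case (B=batch_size, F=from_seq_length,
# T=to_seq_length, N=num_attention_heads, H=size_per_head; concrete values are whatever
# the caller configures):
#   query/key/value:  [B*F, N*H] -> reshape/transpose -> [B, N, F|T, H]
#   attention_scores: [B, N, F, T]   (Q @ K^T, scaled by 1/sqrt(H))
#   mask adder:       (1 - attention_mask) * -10000.0, broadcast over heads
#   context_layer:    softmax(scores) @ V -> [B, N, F, H] -> [B*F, N*H] or [B, F, N*H]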


class BertSelfAttention(nn.Cell):
    """
    Apply self-attention.

    Args:
        batch_size (int): Batch size of input dataset.
        seq_length (int): Length of input sequence.
        hidden_size (int): Size of the bert encoder layers.
        num_attention_heads (int): Number of attention heads. Default: 12.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 seq_length,
                 hidden_size,
                 num_attention_heads=12,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertSelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (hidden_size, num_attention_heads))

        self.size_per_head = int(hidden_size / num_attention_heads)

        self.attention = BertAttention(
            batch_size=batch_size,
            from_tensor_width=hidden_size,
            to_tensor_width=hidden_size,
            from_seq_length=seq_length,
            to_seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            size_per_head=self.size_per_head,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            use_relative_positions=use_relative_positions,
            has_attention_mask=True,
            do_return_2d_tensor=True,
            compute_type=compute_type)

        self.output = BertOutput(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type,
                                 enable_fused_layernorm=enable_fused_layernorm)
        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)

    def construct(self, input_tensor, attention_mask):
        input_tensor = self.reshape(input_tensor, self.shape)
        attention_output = self.attention(input_tensor, input_tensor, attention_mask)
        output = self.output(attention_output, input_tensor)
        return output


class BertEncoderCell(nn.Cell):
    """
    Encoder cells used in BertTransformer.

    Args:
        batch_size (int): Batch size of input dataset.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        seq_length (int): Length of input sequence. Default: 512.
        num_attention_heads (int): Number of attention heads. Default: 12.
        intermediate_size (int): Size of intermediate layer. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.02.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 hidden_size=768,
                 seq_length=512,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.02,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertEncoderCell, self).__init__()
        self.attention = BertSelfAttention(
            batch_size=batch_size,
            hidden_size=hidden_size,
            seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            hidden_dropout_prob=hidden_dropout_prob,
            use_relative_positions=use_relative_positions,
            compute_type=compute_type,
            enable_fused_layernorm=enable_fused_layernorm)
        self.intermediate = nn.Dense(in_channels=hidden_size,
                                     out_channels=intermediate_size,
                                     activation=hidden_act,
                                     weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.output = BertOutput(in_channels=intermediate_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type,
                                 enable_fused_layernorm=enable_fused_layernorm)

    def construct(self, hidden_states, attention_mask):
        # self-attention
        attention_output = self.attention(hidden_states, attention_mask)
        # feed forward
        intermediate_output = self.intermediate(attention_output)
        # add and normalize
        output = self.output(intermediate_output, attention_output)
        return output
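
# One encoder cell, written out (each BertOutput applies Dense -> Dropout -> residual -> LayerNorm):
#   a   = BertSelfAttention(h)          # multi-head attention followed by its BertOutput
#   out = BertOutput(Dense_gelu(a), a)  # the feed-forward half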


class BertTransformer(nn.Cell):
    """
    Multi-layer bert transformer.

    Args:
        batch_size (int): Batch size of input dataset.
        hidden_size (int): Size of the encoder layers.
        seq_length (int): Length of input sequence.
        num_hidden_layers (int): Number of hidden layers in encoder cells.
        num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
        intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
        return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
    """
    def __init__(self,
                 batch_size,
                 hidden_size,
                 seq_length,
                 num_hidden_layers,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 return_all_encoders=False,
                 enable_fused_layernorm=False):
        super(BertTransformer, self).__init__()
        self.return_all_encoders = return_all_encoders

        layers = []
        for _ in range(num_hidden_layers):
            layer = BertEncoderCell(batch_size=batch_size,
                                    hidden_size=hidden_size,
                                    seq_length=seq_length,
                                    num_attention_heads=num_attention_heads,
                                    intermediate_size=intermediate_size,
                                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                                    use_one_hot_embeddings=use_one_hot_embeddings,
                                    initializer_range=initializer_range,
                                    hidden_dropout_prob=hidden_dropout_prob,
                                    use_relative_positions=use_relative_positions,
                                    hidden_act=hidden_act,
                                    compute_type=compute_type,
                                    enable_fused_layernorm=enable_fused_layernorm)
            layers.append(layer)

        self.layers = nn.CellList(layers)

        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)
        self.out_shape = (batch_size, seq_length, hidden_size)

    def construct(self, input_tensor, attention_mask):
        prev_output = self.reshape(input_tensor, self.shape)

        all_encoder_layers = ()
        for layer_module in self.layers:
            layer_output = layer_module(prev_output, attention_mask)
            prev_output = layer_output

            if self.return_all_encoders:
                layer_output = self.reshape(layer_output, self.out_shape)
                all_encoder_layers = all_encoder_layers + (layer_output,)

        if not self.return_all_encoders:
            prev_output = self.reshape(prev_output, self.out_shape)
            all_encoder_layers = all_encoder_layers + (prev_output,)
        return all_encoder_layers
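
# Output contract: a tuple of tensors shaped [batch_size, seq_length, hidden_size].
# With return_all_encoders=True there is one entry per layer (BertModel below reads
# entry `num_hidden_layers - 1` as the final sequence output); otherwise the tuple
# holds only the last layer's output.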


class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    Args:
        config (Class): Configuration for BertModel.
    """
    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.input_mask_from_dataset = config.input_mask_from_dataset
        self.input_mask = None

        if not self.input_mask_from_dataset:
            self.input_mask = initializer(
                "ones", [config.batch_size, config.seq_length], mstype.int32).init_data()

        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = (config.batch_size, 1, config.seq_length)
        self.broadcast_ones = initializer(
            "ones", [config.batch_size, config.seq_length, 1], mstype.float32).init_data()
        self.batch_matmul = P.BatchMatMul()

    def construct(self, input_mask):
        if not self.input_mask_from_dataset:
            input_mask = self.input_mask

        input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
        attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
        return attention_mask
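
# Worked example (batch_size=1, seq_length=4, padding on the last two tokens):
#   input_mask = [[1, 1, 0, 0]]            # [B, S]
#   reshaped   = [[[1, 1, 0, 0]]]          # [B, 1, S]
#   ones [B, S, 1] @ reshaped -> attention_mask of shape [B, S, S]:
#     [[[1, 1, 0, 0],
#       [1, 1, 0, 0],
#       [1, 1, 0, 0],
#       [1, 1, 0, 0]]]
#   i.e. every query position may attend only to the two non-padded key positions.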


class BertModel(nn.Cell):
    """
    Bidirectional Encoder Representations from Transformers.

    Args:
        config (Class): Configuration for BertModel.
        is_training (bool): True for training mode. False for eval mode.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """
    def __init__(self,
                 config,
                 is_training,
                 use_one_hot_embeddings=False):
        super(BertModel, self).__init__()
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        self.input_mask_from_dataset = config.input_mask_from_dataset
        self.token_type_ids_from_dataset = config.token_type_ids_from_dataset
        self.batch_size = config.batch_size
        self.seq_length = config.seq_length
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embedding_size = config.hidden_size
        self.token_type_ids = None

        self.last_idx = self.num_hidden_layers - 1
        output_embedding_shape = [self.batch_size, self.seq_length,
                                  self.embedding_size]

        if not self.token_type_ids_from_dataset:
            self.token_type_ids = initializer(
                "zeros", [self.batch_size, self.seq_length], mstype.int32).init_data()

        self.bert_embedding_lookup = EmbeddingLookup(
            vocab_size=config.vocab_size,
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range)

        self.bert_embedding_postprocessor = EmbeddingPostprocessor(
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_relative_positions=config.use_relative_positions,
            use_token_type=True,
            token_type_vocab_size=config.type_vocab_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)

        self.bert_encoder = BertTransformer(
            batch_size=self.batch_size,
            hidden_size=self.hidden_size,
            seq_length=self.seq_length,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range,
            hidden_dropout_prob=config.hidden_dropout_prob,
            use_relative_positions=config.use_relative_positions,
            hidden_act=config.hidden_act,
            compute_type=config.compute_type,
            return_all_encoders=True,
            enable_fused_layernorm=config.enable_fused_layernorm)

        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()

        self.squeeze_1 = P.Squeeze(axis=1)
        self.dense = nn.Dense(self.hidden_size, self.hidden_size,
                              activation="tanh",
                              weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

    def construct(self, input_ids, token_type_ids, input_mask):

        # embedding
        if not self.token_type_ids_from_dataset:
            token_type_ids = self.token_type_ids
        word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
        embedding_output = self.bert_embedding_postprocessor(token_type_ids,
                                                             word_embeddings)

        # attention mask [batch_size, seq_length, seq_length]
        attention_mask = self._create_attention_mask_from_input_mask(input_mask)

        # bert encoder
        encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output),
                                           attention_mask)

        sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)

        # pooler
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (self.batch_size, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.dense(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)

        return sequence_output, pooled_output, embedding_tables

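
# Minimal usage sketch (illustrative only; the shapes, values, and dummy numpy inputs
# below are assumptions for the example, not part of this module's API):
#
#   config = BertConfig(batch_size=2, seq_length=128, vocab_size=32000)
#   model = BertModel(config, is_training=False)
#   input_ids      = Tensor(np.ones((2, 128)), mstype.int32)
#   token_type_ids = Tensor(np.zeros((2, 128)), mstype.int32)
#   input_mask     = Tensor(np.ones((2, 128)), mstype.int32)
#   sequence_output, pooled_output, embedding_tables = model(input_ids, token_type_ids, input_mask)
#   # sequence_output: [2, 128, 768]; pooled_output: [2, 768]; embedding_tables: [32000, 768]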