# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import math
import copy
import numpy as np
from mindspore import nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter


class AlbertConfig:
    """
    Configuration for `AlbertModel`.

    Args:
        seq_length (int): Length of the input sequence. Default: 256.
        vocab_size (int): Size of the vocabulary, i.e. the number of rows of the word
                    embedding table. Default: 21128.
        hidden_size (int): Size of the Albert encoder layers. Default: 312.
        num_hidden_groups (int): Number of groups of hidden layers; layers in the same
                           group share parameters. Default: 1.
        num_hidden_layers (int): Number of hidden layers in the AlbertTransformer encoder
                           cell. Default: 4.
        inner_group_num (int): Number of AlbertLayer cells in each hidden group. Default: 1.
        num_attention_heads (int): Number of attention heads in the AlbertTransformer
                             encoder cell. Default: 12.
        intermediate_size (int): Size of the intermediate layer in the AlbertTransformer
                           encoder cell. Default: 1248.
        hidden_act (str): Activation function used in the AlbertTransformer encoder
                    cell. Default: "gelu".
        hidden_dropout_prob (float): The dropout probability for AlbertOutput. Default: 0.0.
        attention_probs_dropout_prob (float): The dropout probability for
                                      AlbertAttention. Default: 0.0.
        max_position_embeddings (int): Maximum length of sequences used in this
                                 model. Default: 512.
        type_vocab_size (int): Size of token type vocab. Default: 2.
        initializer_range (float): Initialization value (sigma) of TruncatedNormal. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        classifier_dropout_prob (float): The dropout probability for the classification heads. Default: 0.1.
        embedding_size (int): Size of the word embedding vectors. Default: 128.
        num_labels (int): Number of labels predicted by the classification heads. Default: 5.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in AlbertTransformer. Default: mstype.float32.
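
    Examples:
        A minimal, illustrative sketch using the defaults listed above; it only shows
        how a config is built and how the per-head size is derived from it:

        >>> config = AlbertConfig(seq_length=128)
        >>> config.hidden_size // config.num_attention_heads  # size of each attention head
        26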
    """

    def __init__(self,
                 seq_length=256,
                 vocab_size=21128,
                 hidden_size=312,
                 num_hidden_groups=1,
                 num_hidden_layers=4,
                 inner_group_num=1,
                 num_attention_heads=12,
                 intermediate_size=1248,
                 hidden_act="gelu",
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 hidden_dropout_prob=0.0,
                 attention_probs_dropout_prob=0.0,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 classifier_dropout_prob=0.1,
                 embedding_size=128,
                 layer_norm_eps=1e-12,
                 has_attention_mask=True,
                 do_return_2d_tensor=True,
                 use_one_hot_embeddings=False,
                 use_token_type=True,
                 return_all_encoders=False,
                 output_attentions=False,
                 output_hidden_states=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32,
                 is_training=True,
                 num_labels=5,
                 use_word_embeddings=True):
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.inner_group_num = inner_group_num
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.query_act = query_act
        self.key_act = key_act
        self.value_act = value_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.use_relative_positions = use_relative_positions
        self.classifier_dropout_prob = classifier_dropout_prob
        self.embedding_size = embedding_size
        self.layer_norm_eps = layer_norm_eps
        self.num_hidden_groups = num_hidden_groups
        self.has_attention_mask = has_attention_mask
        self.do_return_2d_tensor = do_return_2d_tensor
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.use_token_type = use_token_type
        self.return_all_encoders = return_all_encoders
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.dtype = dtype
        self.compute_type = compute_type
        self.is_training = is_training
        self.num_labels = num_labels
        self.use_word_embeddings = use_word_embeddings


class EmbeddingLookup(nn.Cell):
    """
    An embedding lookup table with a fixed dictionary and size.

    Args:
        config (AlbertConfig): Albert Config.
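
    Examples:
        Illustrative sketch only; shapes assume the default AlbertConfig
        (seq_length=256, embedding_size=128):

        >>> config = AlbertConfig()
        >>> lookup = EmbeddingLookup(config)
        >>> input_ids = Tensor(np.zeros((2, config.seq_length)), mstype.int32)
        >>> embeddings, table = lookup(input_ids)
        >>> embeddings.shape
        (2, 256, 128)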
    """

    def __init__(self, config):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = config.vocab_size
        self.use_one_hot_embeddings = config.use_one_hot_embeddings
        self.embedding_table = Parameter(initializer
                                         (TruncatedNormal(config.initializer_range),
                                          [config.vocab_size, config.embedding_size]),
                                         name='embedding_table')
        self.expand = P.ExpandDims()
        self.shape_flat = (-1,)
        self.gather = P.Gather()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = (-1, config.seq_length, config.embedding_size)

    def construct(self, input_ids):
        """embedding lookup"""
        flat_ids = self.reshape(input_ids, self.shape_flat)
        if self.use_one_hot_embeddings:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(
                one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, self.shape)
        return output, self.embedding_table


class EmbeddingPostprocessor(nn.Cell):
    """
    Applies token-type and positional embeddings on top of the word embeddings.

    Args:
        config (AlbertConfig): Albert Config.
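
    Examples:
        Illustrative sketch; shapes assume the default AlbertConfig
        (seq_length=256, embedding_size=128):

        >>> config = AlbertConfig()
        >>> postprocessor = EmbeddingPostprocessor(config)
        >>> token_type_ids = Tensor(np.zeros((2, config.seq_length)), mstype.int32)
        >>> word_embeddings = Tensor(np.zeros((2, config.seq_length, config.embedding_size)), mstype.float32)
        >>> postprocessor(token_type_ids, word_embeddings).shape
        (2, 256, 128)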
    """

    def __init__(self, config):
        super(EmbeddingPostprocessor, self).__init__()
        self.use_token_type = config.use_token_type
        self.token_type_vocab_size = config.type_vocab_size
        self.use_one_hot_embeddings = config.use_one_hot_embeddings
        self.max_position_embeddings = config.max_position_embeddings
        self.embedding_table = Parameter(initializer
                                         (TruncatedNormal(config.initializer_range),
                                          [config.type_vocab_size,
                                           config.embedding_size]))
        self.shape_flat = (-1,)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        # off value of the one-hot encoding must be 0.0 so the matmul lookup matches a plain gather
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = (-1, config.seq_length, config.embedding_size)
        self.layernorm = nn.LayerNorm((config.embedding_size,))
        self.dropout = nn.Dropout(1 - config.hidden_dropout_prob)
        self.gather = P.Gather()
        self.use_relative_positions = config.use_relative_positions
        self.slice = P.StridedSlice()
        self.full_position_embeddings = Parameter(initializer
                                                  (TruncatedNormal(config.initializer_range),
                                                   [config.max_position_embeddings,
                                                    config.embedding_size]))

    def construct(self, token_type_ids, word_embeddings):
        """embedding postprocessor"""
        output = word_embeddings
        if self.use_token_type:
            flat_ids = self.reshape(token_type_ids, self.shape_flat)
            if self.use_one_hot_embeddings:
                one_hot_ids = self.one_hot(flat_ids,
                                           self.token_type_vocab_size, self.on_value, self.off_value)
                token_type_embeddings = self.array_mul(one_hot_ids,
                                                       self.embedding_table)
            else:
                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
            output += token_type_embeddings
        if not self.use_relative_positions:
            _, seq, width = self.shape
            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
            output += position_embeddings
        output = self.layernorm(output)
        output = self.dropout(output)
        return output


class AlbertOutput(nn.Cell):
    """
    Applies a linear projection and dropout to the hidden states, then adds the
    residual input and applies layer normalization.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertOutput, self).__init__()
        self.dense = nn.Dense(config.hidden_size, config.hidden_size,
                              weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - config.hidden_dropout_prob)
        self.add = P.Add()
        self.is_gpu = context.get_context('device_target') == "GPU"
        if self.is_gpu:
            self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(mstype.float32)
            self.compute_type = config.compute_type
        else:
            self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)

        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        """bert output"""
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(input_tensor, output)
        output = self.layernorm(output)
        if self.is_gpu:
            output = self.cast(output, self.compute_type)
        return output


class RelaPosMatrixGenerator(nn.Cell):
    """
    Generates a matrix of clipped relative positions between all pairs of input positions.

    Args:
        length (int): Length of one dim for the matrix to be generated.
        max_relative_position (int): Max value of relative position.
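
    Examples:
        A small, illustrative case: for length=3 and max_relative_position=2 the cell
        returns the clipped pairwise distances shifted to be non-negative:

        >>> RelaPosMatrixGenerator(length=3, max_relative_position=2)().asnumpy().tolist()
        [[2, 3, 4], [1, 2, 3], [0, 1, 2]]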
    """

    def __init__(self, length, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
        self.range_length = -length + 1
        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()

    def construct(self):
        """position matrix generator"""
        range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
        range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
        tile_row_out = self.tile(range_vec_row_out, (self._length,))
        tile_col_out = self.tile(range_vec_col_out, (1, self._length))
        range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
        transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
        distance_mat = self.sub(range_mat_out, transpose_out)
        distance_mat_clipped = C.clip_by_value(distance_mat,
                                               self._min_relative_position,
                                               self._max_relative_position)
        # Shift values to be >=0. Each integer still uniquely identifies a
        # relative position difference.
        final_mat = distance_mat_clipped + self._max_relative_position
        return final_mat


class RelaPosEmbeddingsGenerator(nn.Cell):
    """
    Generates a tensor of relative position embeddings with shape [length, length, depth].

    Args:
        length (int): Length of one dim for the matrix to be generated.
        depth (int): Size of each attention head.
        max_relative_position (int): Maximum value of relative position.
        initializer_range (float): Initialization value of TruncatedNormal.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """

    def __init__(self,
                 length,
                 depth,
                 max_relative_position,
                 initializer_range,
                 use_one_hot_embeddings=False):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position * 2 + 1
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embeddings_table = Parameter(
            initializer(TruncatedNormal(initializer_range),
                        [self.vocab_size, self.depth]),
            name='embeddings_for_position')
        self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                max_relative_position=max_relative_position)
        self.reshape = P.Reshape()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.shape = P.Shape()
        self.gather = P.Gather()  # index_select
        self.matmul = P.BatchMatMul()

    def construct(self):
        """position embedding generation"""
        relative_positions_matrix_out = self.relative_positions_matrix()
        # Generate embedding for each relative position of dimension depth.
        if self.use_one_hot_embeddings:
            flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
            one_hot_relative_positions_matrix = self.one_hot(
                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
            embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
            my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
            embeddings = self.reshape(embeddings, my_shape)
        else:
            embeddings = self.gather(self.embeddings_table,
                                     relative_positions_matrix_out, 0)
        return embeddings


class SaturateCast(nn.Cell):
    """
    Performs a saturating cast: the input is clamped to the representable range of the
    destination type before casting, so the cast cannot overflow or underflow.

    Args:
        src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
        dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
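
    Examples:
        Illustrative sketch: values outside the float16 range are clamped to the
        float16 minimum/maximum before the cast instead of overflowing:

        >>> cast = SaturateCast(src_type=mstype.float32, dst_type=mstype.float16)
        >>> x = Tensor(np.array([1e10, -1e10, 1.5], dtype=np.float32))
        >>> cast(x).asnumpy().tolist()
        [65504.0, -65504.0, 1.5]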
    """

    def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
        super(SaturateCast, self).__init__()
        np_type = mstype.dtype_to_nptype(dst_type)
        min_type = np.finfo(np_type).min
        max_type = np.finfo(np_type).max
        self.tensor_min_type = Tensor([min_type], dtype=src_type)
        self.tensor_max_type = Tensor([max_type], dtype=src_type)
        self.min_op = P.Minimum()
        self.max_op = P.Maximum()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        """saturate cast"""
        out = self.max_op(x, self.tensor_min_type)
        out = self.min_op(out, self.tensor_max_type)
        return self.cast(out, self.dst_type)


class AlbertAttention(nn.Cell):
    """
    Apply multi-headed attention from "from_tensor" to "to_tensor".

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertAttention, self).__init__()
        self.from_seq_length = config.seq_length
        self.to_seq_length = config.seq_length
        self.num_attention_heads = config.num_attention_heads
        self.size_per_head = int(config.hidden_size / config.num_attention_heads)
        self.has_attention_mask = config.has_attention_mask
        self.use_relative_positions = config.use_relative_positions
        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=config.compute_type)
        self.reshape = P.Reshape()
        self.shape_from_2d = (-1, config.hidden_size)
        self.shape_to_2d = (-1, config.hidden_size)
        weight = TruncatedNormal(config.initializer_range)

        self.query = nn.Dense(config.hidden_size,
                              config.hidden_size,
                              activation=config.query_act,
                              weight_init=weight).to_float(config.compute_type)
        self.key = nn.Dense(config.hidden_size,
                            config.hidden_size,
                            activation=config.key_act,
                            weight_init=weight).to_float(config.compute_type)
        self.value = nn.Dense(config.hidden_size,
                              config.hidden_size,
                              activation=config.value_act,
                              weight_init=weight).to_float(config.compute_type)
        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.matmul = P.BatchMatMul()
        self.shape_from = (-1, config.seq_length, config.num_attention_heads, self.size_per_head)
        self.shape_to = (-1, config.seq_length, config.num_attention_heads, self.size_per_head)
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        self.multiply_data = Tensor([-10000.0], dtype=config.compute_type)
        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - config.attention_probs_dropout_prob)
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if config.do_return_2d_tensor:
            self.shape_return = (-1, config.hidden_size)
        else:
            self.shape_return = (-1, config.seq_length, config.hidden_size)
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        if self.use_relative_positions:
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(length=config.seq_length,
                                           depth=self.size_per_head,
                                           max_relative_position=16,
                                           initializer_range=config.initializer_range,
                                           use_one_hot_embeddings=config.use_one_hot_embeddings)

    def construct(self, from_tensor, to_tensor, attention_mask):
        """bert attention"""
        # reshape 2d/3d input tensors to 2d
        from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
        query_out = self.query(from_tensor_2d)
        key_out = self.key(to_tensor_2d)
        value_out = self.value(to_tensor_2d)
        query_layer = self.reshape(query_out, self.shape_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shape_to)
        key_layer = self.transpose(key_layer, self.trans_shape)
        attention_scores = self.matmul_trans_b(query_layer, key_layer)
        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # relations_keys is [F|T, F|T, H]
            relations_keys = self._generate_relative_positions_embeddings()
            relations_keys = self.cast_compute_type(relations_keys)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length,
                                          -1,
                                          self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length,
                                                  -1,
                                                  self.num_attention_heads,
                                                  self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            attention_scores = attention_scores + key_position_scores_r_t
        attention_scores = self.multiply(self.scores_mul, attention_scores)
        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))
            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)
        value_layer = self.reshape(value_out, self.shape_to)
        value_layer = self.transpose(value_layer, self.trans_shape)
        context_layer = self.matmul(attention_probs, value_layer)
        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # relations_values is [F|T, F|T, H]
            relations_values = self._generate_relative_positions_embeddings()
            relations_values = self.cast_compute_type(relations_values)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(
                attention_probs_t,
                (self.from_seq_length,
                 -1,
                 self.to_seq_length))
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length,
                                                    -1,
                                                    self.num_attention_heads,
                                                    self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            context_layer = context_layer + value_position_scores_r_t
        context_layer = self.transpose(context_layer, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shape_return)
        return context_layer, attention_scores


class AlbertSelfAttention(nn.Cell):
    """
    Apply self-attention.

    Args:
        config (AlbertConfig): Albert Config.
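
    Examples:
        Illustrative sketch; shapes assume the default AlbertConfig (hidden_size=312,
        num_attention_heads=12, seq_length=256), and the attention mask is the
        [batch_size, 1, seq_length] tensor produced by CreateAttentionMaskFromInputMask:

        >>> config = AlbertConfig()
        >>> attention = AlbertSelfAttention(config)
        >>> hidden_states = Tensor(np.zeros((2, config.seq_length, config.hidden_size)), mstype.float32)
        >>> attention_mask = Tensor(np.ones((2, 1, config.seq_length)), mstype.float32)
        >>> output, scores = attention(hidden_states, attention_mask)
        >>> output.shape, scores.shape  # 2D output since do_return_2d_tensor=True
        ((512, 312), (2, 12, 256, 256))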
    """

    def __init__(self, config):
        super(AlbertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.attention = AlbertAttention(config)
        self.output = AlbertOutput(config)
        self.reshape = P.Reshape()
        self.shape = (-1, config.hidden_size)

    def construct(self, input_tensor, attention_mask):
        """bert self attention"""
        input_tensor = self.reshape(input_tensor, self.shape)
        attention_output, attention_scores = self.attention(input_tensor, input_tensor, attention_mask)
        output = self.output(attention_output, input_tensor)
        return output, attention_scores


class AlbertEncoderCell(nn.Cell):
    """
    Encoder cells used in AlbertTransformer.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertEncoderCell, self).__init__()
        self.attention = AlbertSelfAttention(config)
        self.intermediate = nn.Dense(in_channels=config.hidden_size,
                                     out_channels=config.intermediate_size,
                                     activation=config.hidden_act,
                                     weight_init=TruncatedNormal(config.initializer_range)
                                     ).to_float(config.compute_type)
        self.output = AlbertOutput(config)

    def construct(self, hidden_states, attention_mask):
        """bert encoder cell"""
        # self-attention
        attention_output, attention_scores = self.attention(hidden_states, attention_mask)
        # feed forward
        intermediate_output = self.intermediate(attention_output)
        # add and normalize
        output = self.output(intermediate_output, attention_output)
        return output, attention_scores


class AlbertLayer(nn.Cell):
    """
    A single Albert layer: self-attention followed by a feed-forward network with a
    residual connection and layer normalization.

    Args:
        config (AlbertConfig): Albert Config.
    """
    def __init__(self, config):
        super(AlbertLayer, self).__init__()

        self.output_attentions = config.output_attentions
        self.attention = AlbertSelfAttention(config)
        self.ffn = nn.Dense(config.hidden_size,
                            config.intermediate_size,
                            activation=config.hidden_act).to_float(config.compute_type)
        self.ffn_output = nn.Dense(config.intermediate_size, config.hidden_size)
        self.full_layer_layer_norm = nn.LayerNorm((config.hidden_size,))
        self.shape = (-1, config.seq_length, config.hidden_size)
        self.reshape = P.Reshape()

    def construct(self, hidden_states, attention_mask):
        attention_output, attention_scores = self.attention(hidden_states, attention_mask)

        ffn_output = self.ffn(attention_output)
        ffn_output = self.ffn_output(ffn_output)
        ffn_output = self.reshape(ffn_output + attention_output, self.shape)
        hidden_states = self.full_layer_layer_norm(ffn_output)

        return hidden_states, attention_scores


class AlbertLayerGroup(nn.Cell):
    """
    A group of Albert layers; the inner_group_num AlbertLayer cells in the group are
    applied in sequence.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertLayerGroup, self).__init__()

        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.albert_layers = nn.CellList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def construct(self, hidden_states, attention_mask):
        layer_hidden_states = ()
        layer_attentions = ()

        for _, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask)
            hidden_states = layer_output[0]
            if self.output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)
            if self.output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_attentions:
            outputs = outputs + (layer_attentions,)
        if self.output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        return outputs


class AlbertTransformer(nn.Cell):
    """
    Multi-layer Albert transformer.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertTransformer, self).__init__()
        self.num_hidden_layers = config.num_hidden_layers
        self.num_hidden_groups = config.num_hidden_groups
        self.group_idx_list = [int(_ / (config.num_hidden_layers / config.num_hidden_groups))
                               for _ in range(config.num_hidden_layers)]

        self.embedding_hidden_mapping_in = nn.Dense(config.embedding_size, config.hidden_size)
        self.return_all_encoders = config.return_all_encoders
        layers = []
        for _ in range(config.num_hidden_groups):
            layer = AlbertLayerGroup(config)
            layers.append(layer)
        self.albert_layer_groups = nn.CellList(layers)
        self.reshape = P.Reshape()
        self.shape = (-1, config.embedding_size)
        self.out_shape = (-1, config.seq_length, config.hidden_size)

    def construct(self, input_tensor, attention_mask):
        """bert transformer"""
        prev_output = self.reshape(input_tensor, self.shape)
        prev_output = self.embedding_hidden_mapping_in(prev_output)
        all_encoder_layers = ()
        all_encoder_atts = ()
        all_encoder_outputs = (prev_output,)
        # for layer_module in self.layers:
        for i in range(self.num_hidden_layers):
            # Index of the hidden group
            group_idx = self.group_idx_list[i]

            layer_output, encoder_att = self.albert_layer_groups[group_idx](prev_output, attention_mask)
            prev_output = layer_output
            if self.return_all_encoders:
                all_encoder_outputs += (layer_output,)
                layer_output = self.reshape(layer_output, self.out_shape)
                all_encoder_layers += (layer_output,)
                all_encoder_atts += (encoder_att,)
        if not self.return_all_encoders:
            prev_output = self.reshape(prev_output, self.out_shape)
            all_encoder_layers += (prev_output,)
        return prev_output


class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    Args:
        config (AlbertConfig): Configuration for AlbertModel.
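
    Examples:
        Illustrative sketch; the 2D input mask is reshaped to the
        [batch_size, 1, seq_length] float mask consumed by AlbertAttention:

        >>> config = AlbertConfig()
        >>> make_mask = CreateAttentionMaskFromInputMask(config)
        >>> input_mask = Tensor(np.ones((2, config.seq_length)), mstype.int32)
        >>> make_mask(input_mask).shape
        (2, 1, 256)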
    """

    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = (-1, 1, config.seq_length)

    def construct(self, input_mask):
        attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
        return attention_mask


class AlbertModel(nn.Cell):
    """
    ALBERT (A Lite BERT) backbone: embeddings, a parameter-shared bidirectional
    Transformer encoder, and a pooler over the first token.

    Args:
        config (AlbertConfig): Configuration for AlbertModel.
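
    Examples:
        Illustrative sketch; output_attentions=True is set here so that each
        AlbertLayerGroup returns the (hidden_states, attentions) pair that
        AlbertTransformer unpacks:

        >>> config = AlbertConfig(output_attentions=True)
        >>> model = AlbertModel(config)
        >>> input_ids = Tensor(np.ones((1, config.seq_length)), mstype.int32)
        >>> token_type_ids = Tensor(np.zeros((1, config.seq_length)), mstype.int32)
        >>> input_mask = Tensor(np.ones((1, config.seq_length)), mstype.int32)
        >>> sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
        >>> sequence_output.shape, pooled_output.shape
        ((1, 256, 312), (1, 312))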
    """

    def __init__(self, config):
        super(AlbertModel, self).__init__()
        config = copy.deepcopy(config)
        if not config.is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0
        self.seq_length = config.seq_length
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embedding_size = config.hidden_size
        self.token_type_ids = None
        self.last_idx = self.num_hidden_layers - 1
        self.use_word_embeddings = config.use_word_embeddings
        if self.use_word_embeddings:
            self.word_embeddings = EmbeddingLookup(config)
        self.embedding_postprocessor = EmbeddingPostprocessor(config)
        self.encoder = AlbertTransformer(config)
        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()
        self.squeeze_1 = P.Squeeze(axis=1)
        self.pooler = nn.Dense(self.hidden_size, self.hidden_size,
                               activation="tanh",
                               weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

    def construct(self, input_ids, token_type_ids, input_mask):
        """bert model"""
        # embedding
        if self.use_word_embeddings:
            word_embeddings, _ = self.word_embeddings(input_ids)
        else:
            word_embeddings = input_ids
        embedding_output = self.embedding_postprocessor(token_type_ids, word_embeddings)
        # attention mask [batch_size, 1, seq_length]
        attention_mask = self._create_attention_mask_from_input_mask(input_mask)
        # bert encoder
        encoder_output = self.encoder(self.cast_compute_type(embedding_output), attention_mask)
        sequence_output = self.cast(encoder_output, self.dtype)
        # pooler
        batch_size = P.Shape()(input_ids)[0]
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (batch_size, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.pooler(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)
        return sequence_output, pooled_output


class AlbertMLMHead(nn.Cell):
    """
    Get masked lm output.

    Args:
        config (AlbertConfig): The config of AlbertModel.

    Returns:
        Tensor, masked lm output.
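
    Examples:
        Illustrative sketch; the head maps encoder states of width hidden_size back to
        vocabulary logits (default hidden_size=312, vocab_size=21128):

        >>> config = AlbertConfig()
        >>> mlm_head = AlbertMLMHead(config)
        >>> sequence_output = Tensor(np.zeros((2, config.seq_length, config.hidden_size)), mstype.float32)
        >>> mlm_head(sequence_output).shape
        (2, 256, 21128)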
    """
    def __init__(self, config):
        super(AlbertMLMHead, self).__init__()

        self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_type)
        self.dense = nn.Dense(
            config.hidden_size,
            config.embedding_size,
            weight_init=TruncatedNormal(config.initializer_range),
            activation=config.hidden_act
        ).to_float(config.compute_type)
        self.decoder = nn.Dense(
            config.embedding_size,
            config.vocab_size,
            weight_init=TruncatedNormal(config.initializer_range),
        ).to_float(config.compute_type)

    def construct(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class AlbertModelCLS(nn.Cell):
    """
    Albert model for classification task evaluation,
    e.g. mnli(num_labels=3), qnli(num_labels=2), qqp(num_labels=2).
    The returned output is the final logits; since log_softmax is monotonic in the
    logits, the logits can be used directly for prediction.
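
    Examples:
        Illustrative sketch; output_attentions=True is set here so that the underlying
        AlbertTransformer receives the (hidden_states, attentions) pair it unpacks,
        and num_labels=3 stands in for an MNLI-style task:

        >>> config = AlbertConfig(num_labels=3, output_attentions=True)
        >>> model = AlbertModelCLS(config)
        >>> input_ids = Tensor(np.ones((1, config.seq_length)), mstype.int32)
        >>> input_mask = Tensor(np.ones((1, config.seq_length)), mstype.int32)
        >>> token_type_id = Tensor(np.zeros((1, config.seq_length)), mstype.int32)
        >>> model(input_ids, input_mask, token_type_id).shape
        (1, 3)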
    """

    def __init__(self, config):
        super(AlbertModelCLS, self).__init__()
        self.albert = AlbertModel(config)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.classifier = nn.Dense(config.hidden_size, config.num_labels, weight_init=self.weight_init,
                                   has_bias=True).to_float(config.compute_type)
        self.relu = nn.ReLU()
        self.is_training = config.is_training
        if self.is_training:
            self.dropout = nn.Dropout(1 - config.classifier_dropout_prob)

    def construct(self, input_ids, input_mask, token_type_id):
        """classification albert model"""
        _, pooled_output = self.albert(input_ids, token_type_id, input_mask)
        # pooled_output = self.relu(pooled_output)
        if self.is_training:
            pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        logits = self.cast(logits, self.dtype)
        return logits


class AlbertModelForAD(nn.Cell):
    """albert model for ad"""

    def __init__(self, config):
        super(AlbertModelForAD, self).__init__()

        # main model
        self.albert = AlbertModel(config)

        # classifier head
        self.cast = P.Cast()
        self.dtype = config.dtype
        self.classifier = nn.Dense(config.hidden_size, config.num_labels,
                                   weight_init=TruncatedNormal(config.initializer_range),
                                   has_bias=True).to_float(config.compute_type)
        self.is_training = config.is_training
        if self.is_training:
            self.dropout = nn.Dropout(1 - config.classifier_dropout_prob)

        # masked language model head
        self.predictions = AlbertMLMHead(config)

    def construct(self, input_ids, input_mask, token_type_id):
        """albert model for ad"""
        sequence_output, pooled_output = self.albert(input_ids, token_type_id, input_mask)
        if self.is_training:
            pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        logits = self.cast(logits, self.dtype)
        prediction_scores = self.predictions(sequence_output)
        prediction_scores = self.cast(prediction_scores, self.dtype)
        return prediction_scores, logits


class AlbertModelMLM(nn.Cell):
    """albert model for mlm"""

    def __init__(self, config):
        super(AlbertModelMLM, self).__init__()
        self.cast = P.Cast()
        self.dtype = config.dtype

        # main model
        self.albert = AlbertModel(config)

        # masked language model head
        self.predictions = AlbertMLMHead(config)

    def construct(self, input_ids, input_mask, token_type_id):
        """albert model for mlm"""
        sequence_output, _ = self.albert(input_ids, token_type_id, input_mask)
        prediction_scores = self.predictions(sequence_output)
        prediction_scores = self.cast(prediction_scores, self.dtype)
        return prediction_scores