# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import math
import copy
import numpy as np
from mindspore import nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter


class AlbertConfig:
    """
    Configuration for `AlbertModel`.

    Args:
        seq_length (int): Length of input sequence. Default: 256.
        vocab_size (int): Size of the vocabulary, i.e. number of rows of the word embedding table. Default: 21128.
        hidden_size (int): Size of the ALBERT encoder layers. Default: 312.
        num_hidden_groups (int): Number of groups of hidden layers that share parameters. Default: 1.
        num_hidden_layers (int): Number of hidden layers in the AlbertTransformer encoder
            cell. Default: 4.
        inner_group_num (int): Number of AlbertLayer cells inside each AlbertLayerGroup. Default: 1.
        num_attention_heads (int): Number of attention heads in the AlbertTransformer
            encoder cell. Default: 12.
        intermediate_size (int): Size of the intermediate layer in the AlbertTransformer
            encoder cell. Default: 1248.
        hidden_act (str): Activation function used in the AlbertTransformer encoder
            cell. Default: "gelu".
        hidden_dropout_prob (float): The dropout probability for AlbertOutput. Default: 0.0.
        attention_probs_dropout_prob (float): The dropout probability for
            AlbertAttention. Default: 0.0.
        max_position_embeddings (int): Maximum length of sequences used in this
            model. Default: 512.
        type_vocab_size (int): Size of the token type vocab. Default: 2.
        initializer_range (float): Standard deviation of the TruncatedNormal initializer. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        classifier_dropout_prob (float): The dropout probability for the classifier head. Default: 0.1.
        embedding_size (int): Size of the factorized word embeddings. Default: 128.
        num_labels (int): Number of labels for the classification heads. Default: 5.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in AlbertTransformer. Default: mstype.float32.
    """

    def __init__(self,
                 seq_length=256,
                 vocab_size=21128,
                 hidden_size=312,
                 num_hidden_groups=1,
                 num_hidden_layers=4,
                 inner_group_num=1,
                 num_attention_heads=12,
                 intermediate_size=1248,
                 hidden_act="gelu",
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 hidden_dropout_prob=0.0,
                 attention_probs_dropout_prob=0.0,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 classifier_dropout_prob=0.1,
                 embedding_size=128,
                 layer_norm_eps=1e-12,
                 has_attention_mask=True,
                 do_return_2d_tensor=True,
                 use_one_hot_embeddings=False,
                 use_token_type=True,
                 return_all_encoders=False,
                 output_attentions=False,
                 output_hidden_states=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32,
                 is_training=True,
                 num_labels=5,
                 use_word_embeddings=True):
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.inner_group_num = inner_group_num
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.query_act = query_act
        self.key_act = key_act
        self.value_act = value_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.use_relative_positions = use_relative_positions
        self.classifier_dropout_prob = classifier_dropout_prob
        self.embedding_size = embedding_size
        self.layer_norm_eps = layer_norm_eps
        self.num_hidden_groups = num_hidden_groups
        self.has_attention_mask = has_attention_mask
        self.do_return_2d_tensor = do_return_2d_tensor
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.use_token_type = use_token_type
        self.return_all_encoders = return_all_encoders
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.dtype = dtype
        self.compute_type = compute_type
        self.is_training = is_training
        self.num_labels = num_labels
        self.use_word_embeddings = use_word_embeddings
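
# Usage sketch for AlbertConfig (the values below are illustrative assumptions, not
# settings taken from a released checkpoint). The resulting config object is consumed by
# the model classes defined below, e.g. AlbertModel and AlbertModelCLS:
#
#     config = AlbertConfig(seq_length=128,
#                           vocab_size=21128,
#                           hidden_size=312,
#                           num_hidden_layers=4,
#                           num_attention_heads=12,
#                           intermediate_size=1248,
#                           embedding_size=128,
#                           num_labels=5)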


class EmbeddingLookup(nn.Cell):
    """
    An embedding lookup table with a fixed dictionary and size.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = config.vocab_size
        self.use_one_hot_embeddings = config.use_one_hot_embeddings
        self.embedding_table = Parameter(initializer(TruncatedNormal(config.initializer_range),
                                                     [config.vocab_size, config.embedding_size]),
                                         name='embedding_table')
        self.expand = P.ExpandDims()
        self.shape_flat = (-1,)
        self.gather = P.Gather()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = (-1, config.seq_length, config.embedding_size)

    def construct(self, input_ids):
        """embedding lookup"""
        flat_ids = self.reshape(input_ids, self.shape_flat)
        if self.use_one_hot_embeddings:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(
                one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, self.shape)
        return output, self.embedding_table


class EmbeddingPostprocessor(nn.Cell):
    """
    Postprocessor that applies positional and token type embeddings to word embeddings.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(EmbeddingPostprocessor, self).__init__()
        self.use_token_type = config.use_token_type
        self.token_type_vocab_size = config.type_vocab_size
        self.use_one_hot_embeddings = config.use_one_hot_embeddings
        self.max_position_embeddings = config.max_position_embeddings
        self.embedding_table = Parameter(initializer(TruncatedNormal(config.initializer_range),
                                                     [config.type_vocab_size,
                                                      config.embedding_size]))
        self.shape_flat = (-1,)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = (-1, config.seq_length, config.embedding_size)
        self.layernorm = nn.LayerNorm((config.embedding_size,))
        self.dropout = nn.Dropout(1 - config.hidden_dropout_prob)
        self.gather = P.Gather()
        self.use_relative_positions = config.use_relative_positions
        self.slice = P.StridedSlice()
        self.full_position_embeddings = Parameter(initializer(TruncatedNormal(config.initializer_range),
                                                              [config.max_position_embeddings,
                                                               config.embedding_size]))

    def construct(self, token_type_ids, word_embeddings):
        """embedding postprocessor"""
        output = word_embeddings
        if self.use_token_type:
            flat_ids = self.reshape(token_type_ids, self.shape_flat)
            if self.use_one_hot_embeddings:
                one_hot_ids = self.one_hot(flat_ids,
                                           self.token_type_vocab_size, self.on_value, self.off_value)
                token_type_embeddings = self.array_mul(one_hot_ids,
                                                       self.embedding_table)
            else:
                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
            output += token_type_embeddings
        if not self.use_relative_positions:
            _, seq, width = self.shape
            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
            output += position_embeddings
        output = self.layernorm(output)
        output = self.dropout(output)
        return output
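
# Shape sketch for the two lookup paths in EmbeddingLookup above, assuming a config with
# seq_length=4 and a batch of 2 (illustrative values):
#
#     input_ids:     (2, 4) int32 token ids
#     flat_ids:      (8,) after reshape to shape_flat
#     one-hot path:  (8, vocab_size) @ (vocab_size, embedding_size) -> (8, embedding_size)
#     gather path:   embedding_table[flat_ids]                      -> (8, embedding_size)
#     output:        (2, 4, embedding_size) after the final reshape
#
# Both paths yield the same embeddings; the one-hot matmul form is kept for backends where
# a dense matmul can be preferable to a gather.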


class AlbertOutput(nn.Cell):
    """
    Apply a linear computation to the hidden states and a residual computation to the input.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertOutput, self).__init__()
        self.dense = nn.Dense(config.hidden_size, config.hidden_size,
                              weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - config.hidden_dropout_prob)
        self.add = P.Add()
        self.is_gpu = context.get_context('device_target') == "GPU"
        if self.is_gpu:
            self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(mstype.float32)
            self.compute_type = config.compute_type
        else:
            self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)

        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        """albert output"""
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(input_tensor, output)
        output = self.layernorm(output)
        if self.is_gpu:
            output = self.cast(output, self.compute_type)
        return output


class RelaPosMatrixGenerator(nn.Cell):
    """
    Generates a matrix of relative positions between inputs.

    Args:
        length (int): Length of one dim for the matrix to be generated.
        max_relative_position (int): Max value of relative position.
    """

    def __init__(self, length, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
        self.range_length = -length + 1
        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()

    def construct(self):
        """position matrix generator"""
        range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
        range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
        tile_row_out = self.tile(range_vec_row_out, (self._length,))
        tile_col_out = self.tile(range_vec_col_out, (1, self._length))
        range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
        transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
        distance_mat = self.sub(range_mat_out, transpose_out)
        distance_mat_clipped = C.clip_by_value(distance_mat,
                                               self._min_relative_position,
                                               self._max_relative_position)
        # Shift values to be >=0. Each integer still uniquely identifies a
        # relative position difference.
        final_mat = distance_mat_clipped + self._max_relative_position
        return final_mat
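
# Worked example for RelaPosMatrixGenerator (values computed by hand from the ops above,
# with the assumed settings length=4, max_relative_position=2):
#
#     distance_mat (j - i):        clipped to [-2, 2]:        shifted by +2:
#     [[ 0,  1,  2,  3],           [[ 0,  1,  2,  2],         [[2, 3, 4, 4],
#      [-1,  0,  1,  2],            [-1,  0,  1,  2],          [1, 2, 3, 4],
#      [-2, -1,  0,  1],            [-2, -1,  0,  1],          [0, 1, 2, 3],
#      [-3, -2, -1,  0]]            [-2, -2, -1,  0]]          [0, 0, 1, 2]]
#
# The shifted matrix indexes an embedding table of size 2 * max_relative_position + 1.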


class RelaPosEmbeddingsGenerator(nn.Cell):
    """
    Generates a tensor of size [length, length, depth].

    Args:
        length (int): Length of one dim for the matrix to be generated.
        depth (int): Size of each attention head.
        max_relative_position (int): Maximum value of relative position.
        initializer_range (float): Initialization value of TruncatedNormal.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
    """

    def __init__(self,
                 length,
                 depth,
                 max_relative_position,
                 initializer_range,
                 use_one_hot_embeddings=False):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position * 2 + 1
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embeddings_table = Parameter(
            initializer(TruncatedNormal(initializer_range),
                        [self.vocab_size, self.depth]),
            name='embeddings_for_position')
        self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                max_relative_position=max_relative_position)
        self.reshape = P.Reshape()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.shape = P.Shape()
        self.gather = P.Gather()  # index_select
        self.matmul = P.BatchMatMul()

    def construct(self):
        """position embedding generation"""
        relative_positions_matrix_out = self.relative_positions_matrix()
        # Generate embedding for each relative position of dimension depth.
        if self.use_one_hot_embeddings:
            flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
            one_hot_relative_positions_matrix = self.one_hot(
                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
            embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
            my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
            embeddings = self.reshape(embeddings, my_shape)
        else:
            embeddings = self.gather(self.embeddings_table,
                                     relative_positions_matrix_out, 0)
        return embeddings


class SaturateCast(nn.Cell):
    """
    Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
    the danger that the value will overflow or underflow.

    Args:
        src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
        dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
    """

    def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
        super(SaturateCast, self).__init__()
        np_type = mstype.dtype_to_nptype(dst_type)
        min_type = np.finfo(np_type).min
        max_type = np.finfo(np_type).max
        self.tensor_min_type = Tensor([min_type], dtype=src_type)
        self.tensor_max_type = Tensor([max_type], dtype=src_type)
        self.min_op = P.Minimum()
        self.max_op = P.Maximum()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        """saturate cast"""
        out = self.max_op(x, self.tensor_min_type)
        out = self.min_op(out, self.tensor_max_type)
        return self.cast(out, self.dst_type)
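
# Numeric sketch for SaturateCast when casting fp32 to fp16 (the input values are
# illustrative): float16 can represent at most +/-65504, so out-of-range values are
# clamped before the cast instead of overflowing to inf:
#
#     cast_fp16 = SaturateCast(src_type=mstype.float32, dst_type=mstype.float16)
#     cast_fp16(Tensor([1e5, -1e5, 3.14], mstype.float32))
#     # -> approximately [65504.0, -65504.0, 3.14] in float16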


class AlbertAttention(nn.Cell):
    """
    Apply multi-headed attention from "from_tensor" to "to_tensor".

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertAttention, self).__init__()
        self.from_seq_length = config.seq_length
        self.to_seq_length = config.seq_length
        self.num_attention_heads = config.num_attention_heads
        self.size_per_head = int(config.hidden_size / config.num_attention_heads)
        self.has_attention_mask = config.has_attention_mask
        self.use_relative_positions = config.use_relative_positions
        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=config.compute_type)
        self.reshape = P.Reshape()
        self.shape_from_2d = (-1, config.hidden_size)
        self.shape_to_2d = (-1, config.hidden_size)
        weight = TruncatedNormal(config.initializer_range)

        self.query = nn.Dense(config.hidden_size,
                              config.hidden_size,
                              activation=config.query_act,
                              weight_init=weight).to_float(config.compute_type)
        self.key = nn.Dense(config.hidden_size,
                            config.hidden_size,
                            activation=config.key_act,
                            weight_init=weight).to_float(config.compute_type)
        self.value = nn.Dense(config.hidden_size,
                              config.hidden_size,
                              activation=config.value_act,
                              weight_init=weight).to_float(config.compute_type)
        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.matmul = P.BatchMatMul()
        self.shape_from = (-1, config.seq_length, config.num_attention_heads, self.size_per_head)
        self.shape_to = (-1, config.seq_length, config.num_attention_heads, self.size_per_head)
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        self.multiply_data = Tensor([-10000.0], dtype=config.compute_type)
        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - config.attention_probs_dropout_prob)
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if config.do_return_2d_tensor:
            self.shape_return = (-1, config.hidden_size)
        else:
            self.shape_return = (-1, config.seq_length, config.hidden_size)
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        if self.use_relative_positions:
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(length=config.seq_length,
                                           depth=self.size_per_head,
                                           max_relative_position=16,
                                           initializer_range=config.initializer_range,
                                           use_one_hot_embeddings=config.use_one_hot_embeddings)

    def construct(self, from_tensor, to_tensor, attention_mask):
        """albert attention"""
        # reshape 2d/3d input tensors to 2d
        from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
        query_out = self.query(from_tensor_2d)
        key_out = self.key(to_tensor_2d)
        value_out = self.value(to_tensor_2d)
        query_layer = self.reshape(query_out, self.shape_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shape_to)
        key_layer = self.transpose(key_layer, self.trans_shape)
        attention_scores = self.matmul_trans_b(query_layer, key_layer)
        # use_relative_positions, supplementary logic
        if self.use_relative_positions:
            # relations_keys is [F|T, F|T, H]
            relations_keys = self._generate_relative_positions_embeddings()
            relations_keys = self.cast_compute_type(relations_keys)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length,
                                          -1,
                                          self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length,
                                                  -1,
                                                  self.num_attention_heads,
                                                  self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            attention_scores = attention_scores + key_position_scores_r_t
        attention_scores = self.multiply(self.scores_mul, attention_scores)
        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))
            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)
        value_layer = self.reshape(value_out, self.shape_to)
        value_layer = self.transpose(value_layer, self.trans_shape)
        context_layer = self.matmul(attention_probs, value_layer)
        # use_relative_positions, supplementary logic
        if self.use_relative_positions:
            # relations_values is [F|T, F|T, H]
            relations_values = self._generate_relative_positions_embeddings()
            relations_values = self.cast_compute_type(relations_values)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(
                attention_probs_t,
                (self.from_seq_length,
                 -1,
                 self.to_seq_length))
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length,
                                                    -1,
                                                    self.num_attention_heads,
                                                    self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            context_layer = context_layer + value_position_scores_r_t
        context_layer = self.transpose(context_layer, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shape_return)
        return context_layer, attention_scores
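
# Sketch of the attention-mask handling above, assuming the usual convention that a mask
# value of 1 marks a real token and 0 marks padding:
#
#     adder            = (1.0 - attention_mask) * (-10000.0)
#     attention_scores = attention_scores + adder
#
# so padded positions receive a large negative score and contribute ~0 probability after
# the softmax, while real positions are left unchanged.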


class AlbertSelfAttention(nn.Cell):
    """
    Apply self-attention.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.attention = AlbertAttention(config)
        self.output = AlbertOutput(config)
        self.reshape = P.Reshape()
        self.shape = (-1, config.hidden_size)

    def construct(self, input_tensor, attention_mask):
        """albert self attention"""
        input_tensor = self.reshape(input_tensor, self.shape)
        attention_output, attention_scores = self.attention(input_tensor, input_tensor, attention_mask)
        output = self.output(attention_output, input_tensor)
        return output, attention_scores


class AlbertEncoderCell(nn.Cell):
    """
    Encoder cell used in AlbertTransformer.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertEncoderCell, self).__init__()
        self.attention = AlbertSelfAttention(config)
        self.intermediate = nn.Dense(in_channels=config.hidden_size,
                                     out_channels=config.intermediate_size,
                                     activation=config.hidden_act,
                                     weight_init=TruncatedNormal(config.initializer_range)
                                     ).to_float(config.compute_type)
        self.output = AlbertOutput(config)

    def construct(self, hidden_states, attention_mask):
        """albert encoder cell"""
        # self-attention
        attention_output, attention_scores = self.attention(hidden_states, attention_mask)
        # feed forward
        intermediate_output = self.intermediate(attention_output)
        # add and normalize
        output = self.output(intermediate_output, attention_output)
        return output, attention_scores


class AlbertLayer(nn.Cell):
    """
    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertLayer, self).__init__()

        self.output_attentions = config.output_attentions
        self.attention = AlbertSelfAttention(config)
        self.ffn = nn.Dense(config.hidden_size,
                            config.intermediate_size,
                            activation=config.hidden_act).to_float(config.compute_type)
        self.ffn_output = nn.Dense(config.intermediate_size, config.hidden_size)
        self.full_layer_layer_norm = nn.LayerNorm((config.hidden_size,))
        self.shape = (-1, config.seq_length, config.hidden_size)
        self.reshape = P.Reshape()

    def construct(self, hidden_states, attention_mask):
        attention_output, attention_scores = self.attention(hidden_states, attention_mask)

        ffn_output = self.ffn(attention_output)
        ffn_output = self.ffn_output(ffn_output)
        ffn_output = self.reshape(ffn_output + attention_output, self.shape)
        hidden_states = self.full_layer_layer_norm(ffn_output)

        return hidden_states, attention_scores
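
# Data-flow sketch for AlbertLayer.construct above (reshapes between 2-D and 3-D tensors
# omitted for brevity):
#
#     attn_out, scores = attention(hidden_states, attention_mask)   # residual + LayerNorm inside
#     hidden_states    = full_layer_layer_norm(ffn_output(ffn(attn_out)) + attn_out)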


class AlbertLayerGroup(nn.Cell):
    """
    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertLayerGroup, self).__init__()

        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.albert_layers = nn.CellList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def construct(self, hidden_states, attention_mask):
        layer_hidden_states = ()
        layer_attentions = ()

        for _, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask)
            hidden_states = layer_output[0]
            if self.output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)
            if self.output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_attentions:
            outputs = outputs + (layer_attentions,)
        if self.output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        return outputs


class AlbertTransformer(nn.Cell):
    """
    Multi-layer ALBERT transformer.

    Args:
        config (AlbertConfig): Albert Config.
    """

    def __init__(self, config):
        super(AlbertTransformer, self).__init__()
        self.num_hidden_layers = config.num_hidden_layers
        self.num_hidden_groups = config.num_hidden_groups
        self.group_idx_list = [int(_ / (config.num_hidden_layers / config.num_hidden_groups))
                               for _ in range(config.num_hidden_layers)]

        self.embedding_hidden_mapping_in = nn.Dense(config.embedding_size, config.hidden_size)
        self.return_all_encoders = config.return_all_encoders
        layers = []
        for _ in range(config.num_hidden_groups):
            layer = AlbertLayerGroup(config)
            layers.append(layer)
        self.albert_layer_groups = nn.CellList(layers)
        self.reshape = P.Reshape()
        self.shape = (-1, config.embedding_size)
        self.out_shape = (-1, config.seq_length, config.hidden_size)

    def construct(self, input_tensor, attention_mask):
        """albert transformer"""
        prev_output = self.reshape(input_tensor, self.shape)
        prev_output = self.embedding_hidden_mapping_in(prev_output)
        all_encoder_layers = ()
        all_encoder_atts = ()
        all_encoder_outputs = (prev_output,)
        for i in range(self.num_hidden_layers):
            # Index of the hidden group
            group_idx = self.group_idx_list[i]

            layer_output, encoder_att = self.albert_layer_groups[group_idx](prev_output, attention_mask)
            prev_output = layer_output
            if self.return_all_encoders:
                all_encoder_outputs += (layer_output,)
                layer_output = self.reshape(layer_output, self.out_shape)
                all_encoder_layers += (layer_output,)
                all_encoder_atts += (encoder_att,)
        if not self.return_all_encoders:
            prev_output = self.reshape(prev_output, self.out_shape)
            all_encoder_layers += (prev_output,)
        return prev_output
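
# Layer-to-group mapping sketch for AlbertTransformer.group_idx_list, computed from the
# expression above (the layer/group counts are illustrative):
#
#     num_hidden_layers=4, num_hidden_groups=1  ->  [0, 0, 0, 0]
#         (all layers reuse the single AlbertLayerGroup, i.e. ALBERT's cross-layer
#          parameter sharing)
#     num_hidden_layers=4, num_hidden_groups=2  ->  [0, 0, 1, 1]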


class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    Args:
        config (AlbertConfig): Configuration for AlbertModel.
    """

    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = (-1, 1, config.seq_length)

    def construct(self, input_mask):
        attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
        return attention_mask


class AlbertModel(nn.Cell):
    """
    Bidirectional Encoder Representations from Transformers, ALBERT variant.

    Args:
        config (AlbertConfig): Configuration for AlbertModel.
    """

    def __init__(self, config):
        super(AlbertModel, self).__init__()
        config = copy.deepcopy(config)
        if not config.is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0
        self.seq_length = config.seq_length
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embedding_size = config.hidden_size
        self.token_type_ids = None
        self.last_idx = self.num_hidden_layers - 1
        self.use_word_embeddings = config.use_word_embeddings
        if self.use_word_embeddings:
            self.word_embeddings = EmbeddingLookup(config)
        self.embedding_postprocessor = EmbeddingPostprocessor(config)
        self.encoder = AlbertTransformer(config)
        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()
        self.squeeze_1 = P.Squeeze(axis=1)
        self.pooler = nn.Dense(self.hidden_size, self.hidden_size,
                               activation="tanh",
                               weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

    def construct(self, input_ids, token_type_ids, input_mask):
        """albert model"""
        # embedding
        if self.use_word_embeddings:
            word_embeddings, _ = self.word_embeddings(input_ids)
        else:
            word_embeddings = input_ids
        embedding_output = self.embedding_postprocessor(token_type_ids, word_embeddings)
        # attention mask [batch_size, 1, seq_length]
        attention_mask = self._create_attention_mask_from_input_mask(input_mask)
        # albert encoder
        encoder_output = self.encoder(self.cast_compute_type(embedding_output), attention_mask)
        sequence_output = self.cast(encoder_output, self.dtype)
        # pooler
        batch_size = P.Shape()(input_ids)[0]
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (batch_size, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.pooler(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)
        return sequence_output, pooled_output
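
# Forward-pass shape sketch for AlbertModel (the batch size of 2 is an assumed value;
# output_attentions=True is used because AlbertTransformer unpacks two values from each
# AlbertLayerGroup call):
#
#     config = AlbertConfig(seq_length=256, output_attentions=True)
#     model = AlbertModel(config)
#     input_ids      = Tensor(np.ones((2, 256), dtype=np.int32))
#     token_type_ids = Tensor(np.zeros((2, 256), dtype=np.int32))
#     input_mask     = Tensor(np.ones((2, 256), dtype=np.int32))
#     sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
#     # sequence_output: (2, 256, hidden_size), pooled_output: (2, hidden_size)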


class AlbertMLMHead(nn.Cell):
    """
    Get masked lm output.

    Args:
        config (AlbertConfig): The config of AlbertModel.

    Returns:
        Tensor, masked lm output.
    """

    def __init__(self, config):
        super(AlbertMLMHead, self).__init__()

        self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(config.compute_type)
        self.dense = nn.Dense(
            config.hidden_size,
            config.embedding_size,
            weight_init=TruncatedNormal(config.initializer_range),
            activation=config.hidden_act
        ).to_float(config.compute_type)
        self.decoder = nn.Dense(
            config.embedding_size,
            config.vocab_size,
            weight_init=TruncatedNormal(config.initializer_range),
        ).to_float(config.compute_type)

    def construct(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class AlbertModelCLS(nn.Cell):
    """
    This class is responsible for classification task evaluation,
    e.g. MNLI (num_labels=3), QNLI (num_labels=2), QQP (num_labels=2).
    The returned output represents the final logits, since log_softmax preserves the ordering of softmax.
    """

    def __init__(self, config):
        super(AlbertModelCLS, self).__init__()
        self.albert = AlbertModel(config)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.classifier = nn.Dense(config.hidden_size, config.num_labels, weight_init=self.weight_init,
                                   has_bias=True).to_float(config.compute_type)
        self.relu = nn.ReLU()
        self.is_training = config.is_training
        if self.is_training:
            self.dropout = nn.Dropout(1 - config.classifier_dropout_prob)

    def construct(self, input_ids, input_mask, token_type_id):
        """classification albert model"""
        _, pooled_output = self.albert(input_ids, token_type_id, input_mask)
        # pooled_output = self.relu(pooled_output)
        if self.is_training:
            pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        logits = self.cast(logits, self.dtype)
        return logits


class AlbertModelForAD(nn.Cell):
    """albert model for ad"""

    def __init__(self, config):
        super(AlbertModelForAD, self).__init__()

        # main model
        self.albert = AlbertModel(config)

        # classifier head
        self.cast = P.Cast()
        self.dtype = config.dtype
        self.classifier = nn.Dense(config.hidden_size, config.num_labels,
                                   weight_init=TruncatedNormal(config.initializer_range),
                                   has_bias=True).to_float(config.compute_type)
        self.is_training = config.is_training
        if self.is_training:
            self.dropout = nn.Dropout(1 - config.classifier_dropout_prob)

        # masked language model head
        self.predictions = AlbertMLMHead(config)

    def construct(self, input_ids, input_mask, token_type_id):
        """albert model for ad"""
        sequence_output, pooled_output = self.albert(input_ids, token_type_id, input_mask)
        if self.is_training:
            pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        logits = self.cast(logits, self.dtype)
        prediction_scores = self.predictions(sequence_output)
        prediction_scores = self.cast(prediction_scores, self.dtype)
        return prediction_scores, logits


class AlbertModelMLM(nn.Cell):
    """albert model for mlm"""

    def __init__(self, config):
        super(AlbertModelMLM, self).__init__()
        self.cast = P.Cast()
        self.dtype = config.dtype

        # main model
        self.albert = AlbertModel(config)

        # masked language model head
        self.predictions = AlbertMLMHead(config)

    def construct(self, input_ids, input_mask, token_type_id):
        """albert model for mlm"""
        sequence_output, _ = self.albert(input_ids, token_type_id, input_mask)
        prediction_scores = self.predictions(sequence_output)
        prediction_scores = self.cast(prediction_scores, self.dtype)
        return prediction_scores
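

if __name__ == "__main__":
    # Minimal smoke test, meant as a usage sketch rather than part of the model definition.
    # The batch size, sequence length and all-ones/zeros inputs are assumed values;
    # output_attentions=True is needed because AlbertTransformer unpacks two values from
    # each AlbertLayerGroup call.
    smoke_config = AlbertConfig(seq_length=32, num_labels=5, is_training=False,
                                output_attentions=True)
    smoke_model = AlbertModelCLS(smoke_config)
    smoke_input_ids = Tensor(np.ones((2, 32), dtype=np.int32))
    smoke_input_mask = Tensor(np.ones((2, 32), dtype=np.int32))
    smoke_token_type_ids = Tensor(np.zeros((2, 32), dtype=np.int32))
    smoke_logits = smoke_model(smoke_input_ids, smoke_input_mask, smoke_token_type_ids)
    print(smoke_logits.shape)  # expected: (2, 5)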