# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert model."""

import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter


class BertConfig:
    """
    Configuration for `BertModel`.

    Args:
        batch_size (int): Batch size of the input dataset.
        seq_length (int): Length of the input sequence. Default: 128.
        vocab_size (int): Size of the token vocabulary. Default: 32000.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
                           cell. Default: 12.
        num_attention_heads (int): Number of attention heads in the BertTransformer
                            encoder cell. Default: 12.
        intermediate_size (int): Size of the intermediate layer in the BertTransformer
                          encoder cell. Default: 3072.
        hidden_act (str): Activation function used in the BertTransformer encoder
                   cell. Default: "gelu".
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this
                                 model. Default: 512.
        type_vocab_size (int): Size of the token type vocab. Default: 16.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        input_mask_from_dataset (bool): Specifies whether to use the input mask loaded from the
                                 dataset. Default: True.
        token_type_ids_from_dataset (bool): Specifies whether to use the token type ids loaded
                                     from the dataset. Default: True.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
        enable_fused_layernorm (bool): Specifies whether to use the fused LayerNorm kernel. Default: False.
61 """ 62 def __init__(self, 63 batch_size, 64 seq_length=128, 65 vocab_size=32000, 66 hidden_size=768, 67 num_hidden_layers=12, 68 num_attention_heads=12, 69 intermediate_size=3072, 70 hidden_act="gelu", 71 hidden_dropout_prob=0.1, 72 attention_probs_dropout_prob=0.1, 73 max_position_embeddings=512, 74 type_vocab_size=16, 75 initializer_range=0.02, 76 use_relative_positions=False, 77 input_mask_from_dataset=True, 78 token_type_ids_from_dataset=True, 79 dtype=mstype.float32, 80 compute_type=mstype.float32, 81 enable_fused_layernorm=False): 82 self.batch_size = batch_size 83 self.seq_length = seq_length 84 self.vocab_size = vocab_size 85 self.hidden_size = hidden_size 86 self.num_hidden_layers = num_hidden_layers 87 self.num_attention_heads = num_attention_heads 88 self.hidden_act = hidden_act 89 self.intermediate_size = intermediate_size 90 self.hidden_dropout_prob = hidden_dropout_prob 91 self.attention_probs_dropout_prob = attention_probs_dropout_prob 92 self.max_position_embeddings = max_position_embeddings 93 self.type_vocab_size = type_vocab_size 94 self.initializer_range = initializer_range 95 self.input_mask_from_dataset = input_mask_from_dataset 96 self.token_type_ids_from_dataset = token_type_ids_from_dataset 97 self.use_relative_positions = use_relative_positions 98 self.dtype = dtype 99 self.compute_type = compute_type 100 self.enable_fused_layernorm = enable_fused_layernorm 101 102 103class EmbeddingLookup(nn.Cell): 104 """ 105 A embeddings lookup table with a fixed dictionary and size. 106 107 Args: 108 vocab_size (int): Size of the dictionary of embeddings. 109 embedding_size (int): The size of each embedding vector. 110 embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of 111 each embedding vector. 112 use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. 113 initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. 114 """ 115 def __init__(self, 116 vocab_size, 117 embedding_size, 118 embedding_shape, 119 use_one_hot_embeddings=False, 120 initializer_range=0.02): 121 super(EmbeddingLookup, self).__init__() 122 self.vocab_size = vocab_size 123 self.use_one_hot_embeddings = use_one_hot_embeddings 124 self.embedding_table = Parameter(initializer 125 (TruncatedNormal(initializer_range), 126 [vocab_size, embedding_size]), 127 name='embedding_table') 128 self.expand = P.ExpandDims() 129 self.shape_flat = (-1,) 130 self.gather = P.Gather() 131 self.one_hot = P.OneHot() 132 self.on_value = Tensor(1.0, mstype.float32) 133 self.off_value = Tensor(0.0, mstype.float32) 134 self.array_mul = P.MatMul() 135 self.reshape = P.Reshape() 136 self.shape = tuple(embedding_shape) 137 138 def construct(self, input_ids): 139 extended_ids = self.expand(input_ids, -1) 140 flat_ids = self.reshape(extended_ids, self.shape_flat) 141 if self.use_one_hot_embeddings: 142 one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) 143 output_for_reshape = self.array_mul( 144 one_hot_ids, self.embedding_table) 145 else: 146 output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) 147 output = self.reshape(output_for_reshape, self.shape) 148 return output, self.embedding_table 149 150 151class EmbeddingPostprocessor(nn.Cell): 152 """ 153 Postprocessors apply positional and token type embeddings to word embeddings. 154 155 Args: 156 embedding_size (int): The size of each embedding vector. 


class EmbeddingPostprocessor(nn.Cell):
    """
    Postprocessor that applies positional and token type embeddings to the word embeddings.

    Args:
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
                         the output embedding tensor.
        use_relative_positions (bool): Specifies whether relative positions are used; if True,
                               the absolute position embeddings are skipped. Default: False.
        use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
        token_type_vocab_size (int): Size of the token type vocab. Default: 16.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        max_position_embeddings (int): Maximum length of sequences used in this
                                 model. Default: 512.
        dropout_prob (float): The dropout probability. Default: 0.1.
    """
    def __init__(self,
                 embedding_size,
                 embedding_shape,
                 use_relative_positions=False,
                 use_token_type=False,
                 token_type_vocab_size=16,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 max_position_embeddings=512,
                 dropout_prob=0.1):
        super(EmbeddingPostprocessor, self).__init__()
        self.use_token_type = use_token_type
        self.token_type_vocab_size = token_type_vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.max_position_embeddings = max_position_embeddings
        self.embedding_table = Parameter(initializer
                                         (TruncatedNormal(initializer_range),
                                          [token_type_vocab_size,
                                           embedding_size]),
                                         name='embedding_table')

        self.shape_flat = (-1,)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)
        self.layernorm = nn.LayerNorm((embedding_size,))
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.gather = P.Gather()
        self.use_relative_positions = use_relative_positions
        self.slice = P.StridedSlice()
        self.full_position_embeddings = Parameter(initializer
                                                  (TruncatedNormal(initializer_range),
                                                   [max_position_embeddings,
                                                    embedding_size]),
                                                  name='full_position_embeddings')

    def construct(self, token_type_ids, word_embeddings):
        output = word_embeddings
        if self.use_token_type:
            flat_ids = self.reshape(token_type_ids, self.shape_flat)
            if self.use_one_hot_embeddings:
                one_hot_ids = self.one_hot(flat_ids,
                                           self.token_type_vocab_size, self.on_value, self.off_value)
                token_type_embeddings = self.array_mul(one_hot_ids,
                                                       self.embedding_table)
            else:
                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
            output += token_type_embeddings
        if not self.use_relative_positions:
            _, seq, width = self.shape
            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
            output += position_embeddings
        output = self.layernorm(output)
        output = self.dropout(output)
        return output


class BertOutput(nn.Cell):
    """
    Apply a linear transformation to the hidden states, followed by dropout, a residual
    connection to the input, and layer normalization.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        dropout_prob (float): The dropout probability. Default: 0.1.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
239 """ 240 def __init__(self, 241 in_channels, 242 out_channels, 243 initializer_range=0.02, 244 dropout_prob=0.1, 245 compute_type=mstype.float32, 246 enable_fused_layernorm=False): 247 super(BertOutput, self).__init__() 248 self.dense = nn.Dense(in_channels, out_channels, 249 weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) 250 self.dropout = nn.Dropout(1 - dropout_prob) 251 self.dropout_prob = dropout_prob 252 self.add = P.Add() 253 self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) 254 self.cast = P.Cast() 255 256 def construct(self, hidden_status, input_tensor): 257 output = self.dense(hidden_status) 258 output = self.dropout(output) 259 output = self.add(output, input_tensor) 260 output = self.layernorm(output) 261 return output 262 263 264class RelaPosMatrixGenerator(nn.Cell): 265 """ 266 Generates matrix of relative positions between inputs. 267 268 Args: 269 length (int): Length of one dim for the matrix to be generated. 270 max_relative_position (int): Max value of relative position. 271 """ 272 def __init__(self, length, max_relative_position): 273 super(RelaPosMatrixGenerator, self).__init__() 274 self._length = length 275 self._max_relative_position = max_relative_position 276 self._min_relative_position = -max_relative_position 277 self.range_length = -length + 1 278 279 self.tile = P.Tile() 280 self.range_mat = P.Reshape() 281 self.sub = P.Sub() 282 self.expanddims = P.ExpandDims() 283 self.cast = P.Cast() 284 285 def construct(self): 286 range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32) 287 range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1)) 288 tile_row_out = self.tile(range_vec_row_out, (self._length,)) 289 tile_col_out = self.tile(range_vec_col_out, (1, self._length)) 290 range_mat_out = self.range_mat(tile_row_out, (self._length, self._length)) 291 transpose_out = self.range_mat(tile_col_out, (self._length, self._length)) 292 distance_mat = self.sub(range_mat_out, transpose_out) 293 294 distance_mat_clipped = C.clip_by_value(distance_mat, 295 self._min_relative_position, 296 self._max_relative_position) 297 298 # Shift values to be >=0. Each integer still uniquely identifies a 299 # relative position difference. 300 final_mat = distance_mat_clipped + self._max_relative_position 301 return final_mat 302 303 304class RelaPosEmbeddingsGenerator(nn.Cell): 305 """ 306 Generates tensor of size [length, length, depth]. 307 308 Args: 309 length (int): Length of one dim for the matrix to be generated. 310 depth (int): Size of each attention head. 311 max_relative_position (int): Maxmum value of relative position. 312 initializer_range (float): Initialization value of TruncatedNormal. 313 use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. 
314 """ 315 def __init__(self, 316 length, 317 depth, 318 max_relative_position, 319 initializer_range, 320 use_one_hot_embeddings=False): 321 super(RelaPosEmbeddingsGenerator, self).__init__() 322 self.depth = depth 323 self.vocab_size = max_relative_position * 2 + 1 324 self.use_one_hot_embeddings = use_one_hot_embeddings 325 326 self.embeddings_table = Parameter( 327 initializer(TruncatedNormal(initializer_range), 328 [self.vocab_size, self.depth]), 329 name='embeddings_for_position') 330 331 self.relative_positions_matrix = RelaPosMatrixGenerator(length=length, 332 max_relative_position=max_relative_position) 333 self.reshape = P.Reshape() 334 self.one_hot = nn.OneHot(depth=self.vocab_size) 335 self.shape = P.Shape() 336 self.gather = P.Gather() # index_select 337 self.matmul = P.BatchMatMul() 338 339 def construct(self): 340 relative_positions_matrix_out = self.relative_positions_matrix() 341 342 # Generate embedding for each relative position of dimension depth. 343 if self.use_one_hot_embeddings: 344 flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) 345 one_hot_relative_positions_matrix = self.one_hot( 346 flat_relative_positions_matrix) 347 embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) 348 my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) 349 embeddings = self.reshape(embeddings, my_shape) 350 else: 351 embeddings = self.gather(self.embeddings_table, 352 relative_positions_matrix_out, 0) 353 return embeddings 354 355 356class SaturateCast(nn.Cell): 357 """ 358 Performs a safe saturating cast. This operation applies proper clamping before casting to prevent 359 the danger that the value will overflow or underflow. 360 361 Args: 362 src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32. 363 dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. 364 """ 365 def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): 366 super(SaturateCast, self).__init__() 367 np_type = mstype.dtype_to_nptype(dst_type) 368 min_type = float(np.finfo(np_type).min) 369 max_type = float(np.finfo(np_type).max) 370 371 self.tensor_min_type = min_type 372 self.tensor_max_type = max_type 373 374 self.min_op = P.Minimum() 375 self.max_op = P.Maximum() 376 self.cast = P.Cast() 377 self.dst_type = dst_type 378 379 def construct(self, x): 380 out = self.max_op(x, self.tensor_min_type) 381 out = self.min_op(out, self.tensor_max_type) 382 return self.cast(out, self.dst_type) 383 384 385class BertAttention(nn.Cell): 386 """ 387 Apply multi-headed attention from "from_tensor" to "to_tensor". 388 389 Args: 390 batch_size (int): Batch size of input datasets. 391 from_tensor_width (int): Size of last dim of from_tensor. 392 to_tensor_width (int): Size of last dim of to_tensor. 393 from_seq_length (int): Length of from_tensor sequence. 394 to_seq_length (int): Length of to_tensor sequence. 395 num_attention_heads (int): Number of attention heads. Default: 1. 396 size_per_head (int): Size of each attention head. Default: 512. 397 query_act (str): Activation function for the query transform. Default: None. 398 key_act (str): Activation function for the key transform. Default: None. 399 value_act (str): Activation function for the value transform. Default: None. 400 has_attention_mask (bool): Specifies whether to use attention mask. Default: False. 
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.0.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        do_return_2d_tensor (bool): If True, return a 2d tensor; otherwise return a 3d
                             tensor. Default: False.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 from_tensor_width,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 has_attention_mask=False,
                 attention_probs_dropout_prob=0.0,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 do_return_2d_tensor=False,
                 use_relative_positions=False,
                 compute_type=mstype.float32):

        super(BertAttention, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.num_attention_heads = num_attention_heads
        self.size_per_head = size_per_head
        self.has_attention_mask = has_attention_mask
        self.use_relative_positions = use_relative_positions

        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
        self.reshape = P.Reshape()
        self.shape_from_2d = (-1, from_tensor_width)
        self.shape_to_2d = (-1, to_tensor_width)
        weight = TruncatedNormal(initializer_range)
        units = num_attention_heads * size_per_head
        self.query_layer = nn.Dense(from_tensor_width,
                                    units,
                                    activation=query_act,
                                    weight_init=weight).to_float(compute_type)
        self.key_layer = nn.Dense(to_tensor_width,
                                  units,
                                  activation=key_act,
                                  weight_init=weight).to_float(compute_type)
        self.value_layer = nn.Dense(to_tensor_width,
                                    units,
                                    activation=value_act,
                                    weight_init=weight).to_float(compute_type)

        self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
        self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)

        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        self.multiply_data = -10000.0
        self.batch_num = batch_size * num_attention_heads
        self.matmul = P.BatchMatMul()

        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)

        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if do_return_2d_tensor:
            self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
        else:
            self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        if self.use_relative_positions:
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(length=to_seq_length,
                                           depth=size_per_head,
                                           max_relative_position=16,
                                           initializer_range=initializer_range,
                                           use_one_hot_embeddings=use_one_hot_embeddings)

    def construct(self, from_tensor, to_tensor, attention_mask):
        # reshape 2d/3d input tensors to 2d
        from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
        query_out = self.query_layer(from_tensor_2d)
        key_out = self.key_layer(to_tensor_2d)
        value_out = self.value_layer(to_tensor_2d)

        query_layer = self.reshape(query_out, self.shape_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shape_to)
        key_layer = self.transpose(key_layer, self.trans_shape)

        attention_scores = self.matmul_trans_b(query_layer, key_layer)

        # use_relative_positions, supplementary logic
        if self.use_relative_positions:
            # relations_keys is [F|T, F|T, H]
            relations_keys = self._generate_relative_positions_embeddings()
            relations_keys = self.cast_compute_type(relations_keys)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length,
                                          self.batch_num,
                                          self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length,
                                                  self.batch_size,
                                                  self.num_attention_heads,
                                                  self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            attention_scores = attention_scores + key_position_scores_r_t

        attention_scores = self.multiply(self.scores_mul, attention_scores)

        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))

            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)

        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)

        value_layer = self.reshape(value_out, self.shape_to)
        value_layer = self.transpose(value_layer, self.trans_shape)
        context_layer = self.matmul(attention_probs, value_layer)

        # use_relative_positions, supplementary logic
        if self.use_relative_positions:
            # relations_values is [F|T, F|T, H]
            relations_values = self._generate_relative_positions_embeddings()
            relations_values = self.cast_compute_type(relations_values)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(
                attention_probs_t,
                (self.from_seq_length,
                 self.batch_num,
                 self.to_seq_length))
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length,
                                                    self.batch_size,
                                                    self.num_attention_heads,
                                                    self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            context_layer = context_layer + value_position_scores_r_t

        context_layer = self.transpose(context_layer, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shape_return)

        return context_layer


class BertSelfAttention(nn.Cell):
    """
    Apply self-attention.

    Args:
        batch_size (int): Batch size of the input dataset.
        seq_length (int): Length of the input sequence.
        hidden_size (int): Size of the bert encoder layers.
        num_attention_heads (int): Number of attention heads. Default: 12.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 seq_length,
                 hidden_size,
                 num_attention_heads=12,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertSelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (hidden_size, num_attention_heads))

        self.size_per_head = int(hidden_size / num_attention_heads)

        self.attention = BertAttention(
            batch_size=batch_size,
            from_tensor_width=hidden_size,
            to_tensor_width=hidden_size,
            from_seq_length=seq_length,
            to_seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            size_per_head=self.size_per_head,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            use_relative_positions=use_relative_positions,
            has_attention_mask=True,
            do_return_2d_tensor=True,
            compute_type=compute_type)

        self.output = BertOutput(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type,
                                 enable_fused_layernorm=enable_fused_layernorm)
        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)

    def construct(self, input_tensor, attention_mask):
        input_tensor = self.reshape(input_tensor, self.shape)
        attention_output = self.attention(input_tensor, input_tensor, attention_mask)
        output = self.output(attention_output, input_tensor)
        return output


class BertEncoderCell(nn.Cell):
    """
    Encoder cells used in BertTransformer.

    Args:
        batch_size (int): Batch size of the input dataset.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        seq_length (int): Length of the input sequence. Default: 512.
        num_attention_heads (int): Number of attention heads. Default: 12.
        intermediate_size (int): Size of the intermediate layer. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.02.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 hidden_size=768,
                 seq_length=512,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.02,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertEncoderCell, self).__init__()
        self.attention = BertSelfAttention(
            batch_size=batch_size,
            hidden_size=hidden_size,
            seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            hidden_dropout_prob=hidden_dropout_prob,
            use_relative_positions=use_relative_positions,
            compute_type=compute_type,
            enable_fused_layernorm=enable_fused_layernorm)
        self.intermediate = nn.Dense(in_channels=hidden_size,
                                     out_channels=intermediate_size,
                                     activation=hidden_act,
                                     weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.output = BertOutput(in_channels=intermediate_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type,
                                 enable_fused_layernorm=enable_fused_layernorm)

    def construct(self, hidden_states, attention_mask):
        # self-attention
        attention_output = self.attention(hidden_states, attention_mask)
        # feed forward
        intermediate_output = self.intermediate(attention_output)
        # add and normalize
        output = self.output(intermediate_output, attention_output)
        return output


class BertTransformer(nn.Cell):
    """
    Multi-layer bert transformer.

    Args:
        batch_size (int): Batch size of the input dataset.
        hidden_size (int): Size of the encoder layers.
        seq_length (int): Length of the input sequence.
        num_hidden_layers (int): Number of hidden layers in the encoder cells.
        num_attention_heads (int): Number of attention heads in the encoder cells. Default: 12.
        intermediate_size (int): Size of the intermediate layer in the encoder cells. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for
                                      BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
        return_all_encoders (bool): Specifies whether to return the outputs of all encoder layers. Default: False.
740 """ 741 def __init__(self, 742 batch_size, 743 hidden_size, 744 seq_length, 745 num_hidden_layers, 746 num_attention_heads=12, 747 intermediate_size=3072, 748 attention_probs_dropout_prob=0.1, 749 use_one_hot_embeddings=False, 750 initializer_range=0.02, 751 hidden_dropout_prob=0.1, 752 use_relative_positions=False, 753 hidden_act="gelu", 754 compute_type=mstype.float32, 755 return_all_encoders=False, 756 enable_fused_layernorm=False): 757 super(BertTransformer, self).__init__() 758 self.return_all_encoders = return_all_encoders 759 760 layers = [] 761 for _ in range(num_hidden_layers): 762 layer = BertEncoderCell(batch_size=batch_size, 763 hidden_size=hidden_size, 764 seq_length=seq_length, 765 num_attention_heads=num_attention_heads, 766 intermediate_size=intermediate_size, 767 attention_probs_dropout_prob=attention_probs_dropout_prob, 768 use_one_hot_embeddings=use_one_hot_embeddings, 769 initializer_range=initializer_range, 770 hidden_dropout_prob=hidden_dropout_prob, 771 use_relative_positions=use_relative_positions, 772 hidden_act=hidden_act, 773 compute_type=compute_type, 774 enable_fused_layernorm=enable_fused_layernorm) 775 layers.append(layer) 776 777 self.layers = nn.CellList(layers) 778 779 self.reshape = P.Reshape() 780 self.shape = (-1, hidden_size) 781 self.out_shape = (batch_size, seq_length, hidden_size) 782 783 def construct(self, input_tensor, attention_mask): 784 prev_output = self.reshape(input_tensor, self.shape) 785 786 all_encoder_layers = () 787 for layer_module in self.layers: 788 layer_output = layer_module(prev_output, attention_mask) 789 prev_output = layer_output 790 791 if self.return_all_encoders: 792 layer_output = self.reshape(layer_output, self.out_shape) 793 all_encoder_layers = all_encoder_layers + (layer_output,) 794 795 if not self.return_all_encoders: 796 prev_output = self.reshape(prev_output, self.out_shape) 797 all_encoder_layers = all_encoder_layers + (prev_output,) 798 return all_encoder_layers 799 800 801class CreateAttentionMaskFromInputMask(nn.Cell): 802 """ 803 Create attention mask according to input mask. 804 805 Args: 806 config (Class): Configuration for BertModel. 807 """ 808 def __init__(self, config): 809 super(CreateAttentionMaskFromInputMask, self).__init__() 810 self.input_mask_from_dataset = config.input_mask_from_dataset 811 self.input_mask = None 812 813 if not self.input_mask_from_dataset: 814 self.input_mask = initializer( 815 "ones", [config.batch_size, config.seq_length], mstype.int32).init_data() 816 817 self.cast = P.Cast() 818 self.reshape = P.Reshape() 819 self.shape = (config.batch_size, 1, config.seq_length) 820 self.broadcast_ones = initializer( 821 "ones", [config.batch_size, config.seq_length, 1], mstype.float32).init_data() 822 self.batch_matmul = P.BatchMatMul() 823 824 def construct(self, input_mask): 825 if not self.input_mask_from_dataset: 826 input_mask = self.input_mask 827 828 input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32) 829 attention_mask = self.batch_matmul(self.broadcast_ones, input_mask) 830 return attention_mask 831 832 833class BertModel(nn.Cell): 834 """ 835 Bidirectional Encoder Representations from Transformers. 836 837 Args: 838 config (Class): Configuration for BertModel. 839 is_training (bool): True for training mode. False for eval mode. 840 use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. 
841 """ 842 def __init__(self, 843 config, 844 is_training, 845 use_one_hot_embeddings=False): 846 super(BertModel, self).__init__() 847 config = copy.deepcopy(config) 848 if not is_training: 849 config.hidden_dropout_prob = 0.0 850 config.attention_probs_dropout_prob = 0.0 851 852 self.input_mask_from_dataset = config.input_mask_from_dataset 853 self.token_type_ids_from_dataset = config.token_type_ids_from_dataset 854 self.batch_size = config.batch_size 855 self.seq_length = config.seq_length 856 self.hidden_size = config.hidden_size 857 self.num_hidden_layers = config.num_hidden_layers 858 self.embedding_size = config.hidden_size 859 self.token_type_ids = None 860 861 self.last_idx = self.num_hidden_layers - 1 862 output_embedding_shape = [self.batch_size, self.seq_length, 863 self.embedding_size] 864 865 if not self.token_type_ids_from_dataset: 866 self.token_type_ids = initializer( 867 "zeros", [self.batch_size, self.seq_length], mstype.int32).init_data() 868 869 self.bert_embedding_lookup = EmbeddingLookup( 870 vocab_size=config.vocab_size, 871 embedding_size=self.embedding_size, 872 embedding_shape=output_embedding_shape, 873 use_one_hot_embeddings=use_one_hot_embeddings, 874 initializer_range=config.initializer_range) 875 876 self.bert_embedding_postprocessor = EmbeddingPostprocessor( 877 embedding_size=self.embedding_size, 878 embedding_shape=output_embedding_shape, 879 use_relative_positions=config.use_relative_positions, 880 use_token_type=True, 881 token_type_vocab_size=config.type_vocab_size, 882 use_one_hot_embeddings=use_one_hot_embeddings, 883 initializer_range=0.02, 884 max_position_embeddings=config.max_position_embeddings, 885 dropout_prob=config.hidden_dropout_prob) 886 887 self.bert_encoder = BertTransformer( 888 batch_size=self.batch_size, 889 hidden_size=self.hidden_size, 890 seq_length=self.seq_length, 891 num_attention_heads=config.num_attention_heads, 892 num_hidden_layers=self.num_hidden_layers, 893 intermediate_size=config.intermediate_size, 894 attention_probs_dropout_prob=config.attention_probs_dropout_prob, 895 use_one_hot_embeddings=use_one_hot_embeddings, 896 initializer_range=config.initializer_range, 897 hidden_dropout_prob=config.hidden_dropout_prob, 898 use_relative_positions=config.use_relative_positions, 899 hidden_act=config.hidden_act, 900 compute_type=config.compute_type, 901 return_all_encoders=True, 902 enable_fused_layernorm=config.enable_fused_layernorm) 903 904 self.cast = P.Cast() 905 self.dtype = config.dtype 906 self.cast_compute_type = SaturateCast(dst_type=config.compute_type) 907 self.slice = P.StridedSlice() 908 909 self.squeeze_1 = P.Squeeze(axis=1) 910 self.dense = nn.Dense(self.hidden_size, self.hidden_size, 911 activation="tanh", 912 weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) 913 self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) 914 915 def construct(self, input_ids, token_type_ids, input_mask): 916 917 # embedding 918 if not self.token_type_ids_from_dataset: 919 token_type_ids = self.token_type_ids 920 word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids) 921 embedding_output = self.bert_embedding_postprocessor(token_type_ids, 922 word_embeddings) 923 924 # attention mask [batch_size, seq_length, seq_length] 925 attention_mask = self._create_attention_mask_from_input_mask(input_mask) 926 927 # bert encoder 928 encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output), 929 attention_mask) 930 931 sequence_output = 

        sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)

        # pooler
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (self.batch_size, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.dense(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)

        return sequence_output, pooled_output, embedding_tables
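

# An end-to-end usage sketch (not part of the original model): it builds a small
# BertConfig, instantiates BertModel in eval mode, and runs a forward pass on
# dummy inputs. All sizes are illustrative assumptions, and running it requires
# a properly configured MindSpore context/device.
def _bert_model_example():
    config = BertConfig(batch_size=2,
                        seq_length=32,
                        vocab_size=32000,
                        hidden_size=768,
                        num_hidden_layers=2,
                        num_attention_heads=12)
    model = BertModel(config, is_training=False)
    input_ids = Tensor(np.random.randint(0, config.vocab_size, (2, 32)), mstype.int32)
    token_type_ids = Tensor(np.zeros((2, 32)), mstype.int32)
    input_mask = Tensor(np.ones((2, 32)), mstype.int32)
    sequence_output, pooled_output, embedding_tables = model(input_ids, token_type_ids, input_mask)
    # sequence_output: [2, 32, 768]; pooled_output: [2, 768]; embedding_tables: [32000, 768]
    return sequence_output, pooled_output, embedding_tables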