# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Bert submodules."""

# pylint: disable=missing-docstring, arguments-differ

import math
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.ops.functional as F
from mindspore import nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.common.tensor import Tensor
from mindspore.tests.models.Bert_NEZHA.bert_model import SaturateCast, RelaPosEmbeddingsGenerator
from mindspore.ops import operations as P


class BertAttentionQueryKeyMul(nn.Cell):
    def __init__(self,
                 batch_size,
                 from_tensor_width,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 initializer_range=0.02):
        super(BertAttentionQueryKeyMul, self).__init__()
        self.from_tensor_width = from_tensor_width
        self.to_tensor_width = to_tensor_width
        self.units = num_attention_heads * size_per_head
        self.weight = TruncatedNormal(initializer_range)

        self.trans_shape = (0, 2, 1, 3)
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()
        self.shp_from_2d = (-1, self.from_tensor_width)
        self.shp_to_2d = (-1, self.to_tensor_width)
        self.query_layer = nn.Dense(self.from_tensor_width,
                                    self.units,
                                    activation=query_act,
                                    weight_init=self.weight)
        self.key_layer = nn.Dense(self.to_tensor_width,
                                  self.units,
                                  activation=key_act,
                                  weight_init=self.weight)

        self.shp_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
        self.shp_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)

        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.cast = P.Cast()

    def construct(self, from_tensor, to_tensor):
        from_tensor_2d = self.reshape(from_tensor, self.shp_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shp_to_2d)
        from_tensor_2d = self.cast(from_tensor_2d, mstype.float32)
        to_tensor_2d = self.cast(to_tensor_2d, mstype.float32)
        query_out = self.query_layer(from_tensor_2d)
        key_out = self.key_layer(to_tensor_2d)

        query_layer = self.reshape(query_out, self.shp_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shp_to)
        key_layer = self.transpose(key_layer, self.trans_shape)

        attention_scores = self.matmul_trans_b(query_layer, key_layer)

        return query_layer, key_layer, attention_scores


class BertAttentionRelativePositionKeys(nn.Cell):
    def __init__(self,
                 batch_size,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32):
        super(BertAttentionRelativePositionKeys, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.use_relative_positions = use_relative_positions
        self.size_per_head = size_per_head
        self.num_attention_heads = num_attention_heads
        self.trans_shape_position = (1, 2, 0, 3)
        self.trans_shape_relative = (2, 0, 1, 3)

        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))

        self.reshape = P.Reshape()
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.batch_num = batch_size * num_attention_heads
        self.cast = P.Cast()

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        self._generate_relative_positions_embeddings = \
            RelaPosEmbeddingsGenerator(length=self.to_seq_length,
                                       depth=self.size_per_head,
                                       max_relative_position=16,
                                       initializer_range=initializer_range,
                                       use_one_hot_embeddings=use_one_hot_embeddings)

    def construct(self, input_tensor, query_layer):
        # use_relative_position, supplementary logic
        relations_keys_embeddings = self._generate_relative_positions_embeddings()
        if self.use_relative_positions:
            # 'relations_keys' = [F|T, F|T, H]
            relations_keys = self.cast_compute_type(relations_keys_embeddings)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length,
                                          self.batch_num,
                                          self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            query_layer_r = self.cast(query_layer_r, mstype.float32)
            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length,
                                                  self.batch_size,
                                                  self.num_attention_heads,
                                                  self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            input_tensor = self.cast(input_tensor, mstype.float32)

            input_tensor = input_tensor + key_position_scores_r_t

        attention_scores = self.multiply(input_tensor, self.scores_mul)

        return relations_keys_embeddings, attention_scores


class BertAttentionMask(nn.Cell):
    def __init__(self,
                 has_attention_mask=False,
                 dtype=mstype.float32):

        super(BertAttentionMask, self).__init__()
        self.has_attention_mask = has_attention_mask
        self.multiply_data = Tensor([-1000.0,], dtype=dtype)
        self.multiply = P.Mul()
        # Cast is used unconditionally in construct, so create it outside the branch.
        self.cast = P.Cast()

        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.get_dtype = P.DType()

    def construct(self, input_tensor, attention_mask):
        attention_scores = input_tensor
        attention_scores = self.cast(attention_scores, mstype.float32)
        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), mstype.float32),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))

            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)

        return attention_scores


class BertAttentionMaskBackward(nn.Cell):
    def __init__(self,
                 attention_mask_shape,
                 has_attention_mask=False,
                 dtype=mstype.float32):
        super(BertAttentionMaskBackward, self).__init__()
        self.has_attention_mask = has_attention_mask
        self.multiply_data = Tensor([-1000.0,], dtype=dtype)
        self.multiply = P.Mul()
        # Cast is used unconditionally in construct, so create it outside the branch.
        self.cast = P.Cast()
        self.attention_mask = Tensor(np.ones(shape=attention_mask_shape).astype(np.float32))
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.get_dtype = P.DType()

    def construct(self, input_tensor):
        attention_scores = input_tensor
        attention_scores = self.cast(attention_scores, mstype.float32)
        if self.has_attention_mask:
            attention_mask = self.expand_dims(self.attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), mstype.float32),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))

            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)
        return attention_scores


class BertAttentionSoftmax(nn.Cell):
    def __init__(self,
                 batch_size,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 value_act=None,
                 attention_probs_dropout_prob=0.0,
                 initializer_range=0.02):
        super(BertAttentionSoftmax, self).__init__()
        self.to_tensor_width = to_tensor_width
        self.value_act = value_act

        self.reshape = P.Reshape()

        self.shp_to_2d = (-1, self.to_tensor_width)
        self.shp_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
        self.shp_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)

        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_start = (0, 1)
        self.matmul = P.BatchMatMul()

        self.units = num_attention_heads * size_per_head
        self.weight = TruncatedNormal(initializer_range)

        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
        self.transpose = P.Transpose()

        self.value_layer = nn.Dense(self.to_tensor_width,
                                    self.units,
                                    activation=value_act,
                                    weight_init=self.weight)
        self.cast = P.Cast()

    def construct(self, to_tensor, attention_scores):
        to_tensor = self.transpose(to_tensor, self.trans_shape_start)
        to_tensor_2d = self.reshape(to_tensor, self.shp_to_2d)
        to_tensor_2d = self.cast(to_tensor_2d, mstype.float32)
        value_out = self.value_layer(to_tensor_2d)

        attention_probs = self.softmax(attention_scores)
        attention_probs = self.cast(attention_probs, mstype.float32)

        value_layer = self.reshape(value_out, self.shp_to)
        value_layer = self.transpose(value_layer, self.trans_shape)

        context_layer = self.matmul(attention_probs, value_layer)

        return value_layer, context_layer


class BertAttentionRelativePositionValues(nn.Cell):
    def __init__(self,
                 batch_size,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 do_return_2d_tensor=False,
                 use_relative_positions=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32):

        super(BertAttentionRelativePositionValues, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.use_relative_positions = use_relative_positions
        self.size_per_head = size_per_head
        self.num_attention_heads = num_attention_heads
        self.trans_shape_position = (1, 2, 0, 3)
        self.trans_shape_relative = (2, 0, 1, 3)

        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
        self.trans_shape = (0, 2, 1, 3)

        self.reshape = P.Reshape()
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.batch_num = batch_size * num_attention_heads
        self.matmul = P.BatchMatMul()
        self.do_return_2d_tensor = do_return_2d_tensor
        if self.do_return_2d_tensor:
            self.shp_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
        else:
            self.shp_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        self._generate_relative_positions_embeddings = \
            RelaPosEmbeddingsGenerator(length=self.to_seq_length,
                                       depth=self.size_per_head,
                                       max_relative_position=16,
                                       initializer_range=initializer_range,
                                       use_one_hot_embeddings=use_one_hot_embeddings)
        self.fill = P.Fill()
        self.type = P.DType()
        self.cast = P.Cast()

    def construct(self, input_tensor, attention_probs):
        # use_relative_position, supplementary logic
        relations_values_embedding = self._generate_relative_positions_embeddings()  # (128, 128, 64)
        if self.use_relative_positions:
            # 'relations_values' = [F|T, F|T, H]
            relations_values = self.cast_compute_type(relations_values_embedding)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(attention_probs_t,
                                             (self.from_seq_length,
                                              self.batch_num,
                                              self.to_seq_length))  # (128, 768, 128)
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length,
                                                    self.batch_size,
                                                    self.num_attention_heads,
                                                    self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            input_tensor = input_tensor + value_position_scores_r_t

        context_layer = self.transpose(input_tensor, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shp_return)
        # ge reshape should not return, need an operator here
        ones = self.cast(self.fill((1, 1), 1), self.type(context_layer))
        context_layer = self.multiply(context_layer, ones)
        return relations_values_embedding, context_layer


class BertDense(nn.Cell):
    def __init__(self,
                 hidden_size=768,
                 intermediate_size=3072,
                 initializer_range=0.02):
        super(BertDense, self).__init__()
        self.intermediate = nn.Dense(in_channels=hidden_size,
                                     out_channels=intermediate_size,
                                     activation=None,
                                     weight_init=TruncatedNormal(initializer_range))
        self.cast = P.Cast()

    def construct(self, attention_output):
        attention_output = self.cast(attention_output, mstype.float32)
        intermediate_output = self.intermediate(attention_output)
        return intermediate_output
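

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original test suite). The batch size,
# sequence length, widths, head counts and the CPU/PyNative context below are
# illustrative assumptions chosen only to exercise two of the submodules above
# and show the expected tensor layouts.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from mindspore import context

    # Assumes a build with a CPU backend; adjust device_target if needed.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

    batch, seq, width, heads, head_size = 2, 8, 64, 4, 16
    from_tensor = Tensor(np.random.rand(batch, seq, width).astype(np.float32))
    to_tensor = Tensor(np.random.rand(batch, seq, width).astype(np.float32))

    # Query/key projections followed by the batched Q x K^T score computation.
    qk_mul = BertAttentionQueryKeyMul(batch_size=batch,
                                      from_tensor_width=width,
                                      to_tensor_width=width,
                                      from_seq_length=seq,
                                      to_seq_length=seq,
                                      num_attention_heads=heads,
                                      size_per_head=head_size)
    query_layer, key_layer, attention_scores = qk_mul(from_tensor, to_tensor)
    # query_layer, key_layer: [B, N, S, H]; attention_scores: [B, N, S, S]
    print(attention_scores)

    # Feed-forward projection applied to 2d [B * S, hidden] activations.
    dense = BertDense(hidden_size=width, intermediate_size=4 * width)
    hidden_2d = Tensor(np.random.rand(batch * seq, width).astype(np.float32))
    print(dense(hidden_2d))  # [B * S, 4 * width]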