# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Bert submodules."""

# pylint: disable=missing-docstring, arguments-differ

import math
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.ops.functional as F
from mindspore import nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.common.tensor import Tensor
from mindspore.tests.models.Bert_NEZHA.bert_model import SaturateCast, RelaPosEmbeddingsGenerator
from mindspore.ops import operations as P


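# Computes raw attention scores: projects from_tensor/to_tensor through the query
# and key dense layers, reshapes them to [batch, heads, seq, head_size] and
# returns (query_layer, key_layer, query x key^T scores).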
class BertAttentionQueryKeyMul(nn.Cell):
    def __init__(self,
                 batch_size,
                 from_tensor_width,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 initializer_range=0.02):
        super(BertAttentionQueryKeyMul, self).__init__()
        self.from_tensor_width = from_tensor_width
        self.to_tensor_width = to_tensor_width
        self.units = num_attention_heads * size_per_head
        self.weight = TruncatedNormal(initializer_range)

        self.trans_shape = (0, 2, 1, 3)
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()
        self.shp_from_2d = (-1, self.from_tensor_width)
        self.shp_to_2d = (-1, self.to_tensor_width)
        self.query_layer = nn.Dense(self.from_tensor_width,
                                    self.units,
                                    activation=query_act,
                                    weight_init=self.weight)
        self.key_layer = nn.Dense(self.to_tensor_width,
                                  self.units,
                                  activation=key_act,
                                  weight_init=self.weight)

        self.shp_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
        self.shp_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)

        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.cast = P.Cast()

    def construct(self, from_tensor, to_tensor):
        from_tensor_2d = self.reshape(from_tensor, self.shp_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shp_to_2d)
        from_tensor_2d = self.cast(from_tensor_2d, mstype.float32)
        to_tensor_2d = self.cast(to_tensor_2d, mstype.float32)
        query_out = self.query_layer(from_tensor_2d)
        key_out = self.key_layer(to_tensor_2d)

        query_layer = self.reshape(query_out, self.shp_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shp_to)
        key_layer = self.transpose(key_layer, self.trans_shape)

        attention_scores = self.matmul_trans_b(query_layer, key_layer)

        return query_layer, key_layer, attention_scores


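# Scales attention scores by 1/sqrt(size_per_head); when use_relative_positions is
# enabled, it first adds relative-position key scores to the incoming scores.
# Returns the generated relative-position embeddings together with the scores.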
class BertAttentionRelativePositionKeys(nn.Cell):
    def __init__(self,
                 batch_size,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32):
        super(BertAttentionRelativePositionKeys, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.use_relative_positions = use_relative_positions
        self.size_per_head = size_per_head
        self.num_attention_heads = num_attention_heads
        self.trans_shape_position = (1, 2, 0, 3)
        self.trans_shape_relative = (2, 0, 1, 3)

        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))

        self.reshape = P.Reshape()
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.batch_num = batch_size * num_attention_heads
        self.cast = P.Cast()

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        self._generate_relative_positions_embeddings = \
            RelaPosEmbeddingsGenerator(length=self.to_seq_length,
                                       depth=self.size_per_head,
                                       max_relative_position=16,
                                       initializer_range=initializer_range,
                                       use_one_hot_embeddings=use_one_hot_embeddings)

    def construct(self, input_tensor, query_layer):
        # supplementary logic used when use_relative_positions is enabled
        relations_keys_embeddings = self._generate_relative_positions_embeddings()
        if self.use_relative_positions:
            # relations_keys is [F|T, F|T, H]
            relations_keys = self.cast_compute_type(relations_keys_embeddings)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length,
                                          self.batch_num,
                                          self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            query_layer_r = self.cast(query_layer_r, mstype.float32)
            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length,
                                                  self.batch_size,
                                                  self.num_attention_heads,
                                                  self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            input_tensor = self.cast(input_tensor, mstype.float32)

            input_tensor = input_tensor + key_position_scores_r_t

        attention_scores = self.multiply(input_tensor, self.scores_mul)

        return relations_keys_embeddings, attention_scores


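# Applies the attention mask: positions where the mask is 0 receive a -1000.0 bias
# so that they are suppressed by the subsequent softmax.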
class BertAttentionMask(nn.Cell):
    def __init__(self,
                 has_attention_mask=False,
                 dtype=mstype.float32):
        super(BertAttentionMask, self).__init__()
        self.has_attention_mask = has_attention_mask
        self.multiply_data = Tensor([-1000.0], dtype=dtype)
        self.multiply = P.Mul()
        # cast is used in construct even when no attention mask is applied
        self.cast = P.Cast()

        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.get_dtype = P.DType()

    def construct(self, input_tensor, attention_mask):
        attention_scores = input_tensor
        attention_scores = self.cast(attention_scores, mstype.float32)
        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), mstype.float32),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))

            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)

        return attention_scores


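# Same masking scheme as BertAttentionMask, but the mask is a constant all-ones
# tensor of shape attention_mask_shape held inside the cell rather than a network input.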
class BertAttentionMaskBackward(nn.Cell):
    def __init__(self,
                 attention_mask_shape,
                 has_attention_mask=False,
                 dtype=mstype.float32):
        super(BertAttentionMaskBackward, self).__init__()
        self.has_attention_mask = has_attention_mask
        self.multiply_data = Tensor([-1000.0], dtype=dtype)
        self.multiply = P.Mul()
        self.attention_mask = Tensor(np.ones(shape=attention_mask_shape).astype(np.float32))
        # cast is used in construct even when no attention mask is applied
        self.cast = P.Cast()
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.Add()
            self.get_dtype = P.DType()

    def construct(self, input_tensor):
        attention_scores = input_tensor
        attention_scores = self.cast(attention_scores, mstype.float32)
        if self.has_attention_mask:
            attention_mask = self.expand_dims(self.attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), mstype.float32),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))

            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)
        return attention_scores


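# Normalizes attention scores with softmax, projects to_tensor through the value
# dense layer and multiplies the probabilities by the values to produce the
# context layer.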
class BertAttentionSoftmax(nn.Cell):
    def __init__(self,
                 batch_size,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 value_act=None,
                 attention_probs_dropout_prob=0.0,
                 initializer_range=0.02):
        super(BertAttentionSoftmax, self).__init__()
        self.to_tensor_width = to_tensor_width
        self.value_act = value_act

        self.reshape = P.Reshape()

        self.shp_to_2d = (-1, self.to_tensor_width)
        self.shp_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
        self.shp_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)

        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_start = (0, 1)
        self.matmul = P.BatchMatMul()

        self.units = num_attention_heads * size_per_head
        self.weight = TruncatedNormal(initializer_range)

        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
        self.transpose = P.Transpose()

        self.value_layer = nn.Dense(self.to_tensor_width,
                                    self.units,
                                    activation=value_act,
                                    weight_init=self.weight)
        self.cast = P.Cast()

    def construct(self, to_tensor, attention_scores):
        to_tensor = self.transpose(to_tensor, self.trans_shape_start)
        to_tensor_2d = self.reshape(to_tensor, self.shp_to_2d)
        to_tensor_2d = self.cast(to_tensor_2d, mstype.float32)
        value_out = self.value_layer(to_tensor_2d)

        attention_probs = self.softmax(attention_scores)
        attention_probs = self.cast(attention_probs, mstype.float32)

        value_layer = self.reshape(value_out, self.shp_to)
        value_layer = self.transpose(value_layer, self.trans_shape)

        context_layer = self.matmul(attention_probs, value_layer)

        return value_layer, context_layer


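# When use_relative_positions is enabled, adds relative-position value scores to
# the context tensor, then transposes/reshapes it to the configured output shape
# (2-D or 3-D depending on do_return_2d_tensor).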
class BertAttentionRelativePositionValues(nn.Cell):
    def __init__(self,
                 batch_size,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 do_return_2d_tensor=False,
                 use_relative_positions=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32):
        super(BertAttentionRelativePositionValues, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.use_relative_positions = use_relative_positions
        self.size_per_head = size_per_head
        self.num_attention_heads = num_attention_heads
        self.trans_shape_position = (1, 2, 0, 3)
        self.trans_shape_relative = (2, 0, 1, 3)

        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
        self.trans_shape = (0, 2, 1, 3)

        self.reshape = P.Reshape()
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.batch_num = batch_size * num_attention_heads
        self.matmul = P.BatchMatMul()
        self.do_return_2d_tensor = do_return_2d_tensor
        if self.do_return_2d_tensor:
            self.shp_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
        else:
            self.shp_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        self._generate_relative_positions_embeddings = \
            RelaPosEmbeddingsGenerator(length=self.to_seq_length,
                                       depth=self.size_per_head,
                                       max_relative_position=16,
                                       initializer_range=initializer_range,
                                       use_one_hot_embeddings=use_one_hot_embeddings)
        self.fill = P.Fill()
        self.type = P.DType()
        self.cast = P.Cast()

    def construct(self, input_tensor, attention_probs):
        # supplementary logic used when use_relative_positions is enabled
        relations_values_embedding = self._generate_relative_positions_embeddings()  # (128, 128, 64)
        if self.use_relative_positions:
            # relations_values is [F|T, F|T, H]
            relations_values = self.cast_compute_type(relations_values_embedding)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(
                attention_probs_t,
                (self.from_seq_length,
                 self.batch_num,
                 self.to_seq_length))  # (128, 768, 128)
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length,
                                                    self.batch_size,
                                                    self.num_attention_heads,
                                                    self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            input_tensor = input_tensor + value_position_scores_r_t

        context_layer = self.transpose(input_tensor, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shp_return)
        # on the GE backend a reshape must not be the cell's last op, so append a no-op multiply
        ones = self.cast(self.fill((1, 1), 1), self.type(context_layer))
        context_layer = self.multiply(context_layer, ones)
        return relations_values_embedding, context_layer


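# Feed-forward intermediate projection: a single dense layer mapping the attention
# output from hidden_size to intermediate_size.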
class BertDense(nn.Cell):
    def __init__(self,
                 hidden_size=768,
                 intermediate_size=3072,
                 initializer_range=0.02):
        super(BertDense, self).__init__()
        self.intermediate = nn.Dense(in_channels=hidden_size,
                                     out_channels=intermediate_size,
                                     activation=None,
                                     weight_init=TruncatedNormal(initializer_range))
        self.cast = P.Cast()

    def construct(self, attention_output):
        attention_output = self.cast(attention_output, mstype.float32)
        intermediate_output = self.intermediate(attention_output)
        return intermediate_output
384