# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

'''
CRF script.
'''

import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype


class CRF(nn.Cell):
    '''
    Conditional Random Field.

    The transition matrix is laid out as ``transitions[next_tag, previous_tag]``:
    the whole "<START>" row and the whole "<STOP>" column are initialised to
    -10000 so that no path moves into "<START>" or out of "<STOP>".

    Args:
        tag_to_index: The dict for tag to index mapping with extra "<START>" and "<STOP>" sign.
        batch_size: Batch size, i.e., the length of the first dimension.
        seq_length: Sequence length, i.e., the length of the second dimension.
        is_training: Specifies whether to use training mode.
    Returns:
        Training mode: Tensor, total loss.
        Evaluation mode: Tuple, the index for each step with the highest score; Tuple, the index for the last
        step with the highest score.
    '''
    def __init__(self, tag_to_index, batch_size=1, seq_length=128, is_training=True):

        super(CRF, self).__init__()
        self.target_size = len(tag_to_index)
        self.is_training = is_training
        self.tag_to_index = tag_to_index
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.START_TAG = "<START>"
        self.STOP_TAG = "<STOP>"
        # Resolve the special tag indices from the mapping once, so every use
        # below (transition mask, initial alphas, gold path, terminal step)
        # agrees. With the usual layout ("<START>" and "<STOP>" as the last two
        # entries) these equal target_size - 2 and target_size - 1.
        self.start_index = int(tag_to_index[self.START_TAG])
        self.stop_index = int(tag_to_index[self.STOP_TAG])
        self.START_VALUE = Tensor(self.start_index, dtype=mstype.int32)
        self.STOP_VALUE = Tensor(self.stop_index, dtype=mstype.int32)
        transitions = np.random.normal(size=(self.target_size, self.target_size)).astype(np.float32)
        # Forbid transitions into "<START>" and out of "<STOP>".
        transitions[self.start_index, :] = -10000
        transitions[:, self.stop_index] = -10000
        self.transitions = Parameter(Tensor(transitions), name="transition_matrix")
        self.cat = P.Concat(axis=-1)
        self.argmax = P.ArgMaxWithValue(axis=-1)
        self.log = P.Log()
        self.exp = P.Exp()
        self.sum = P.ReduceSum()
        self.tile = P.Tile()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.reshape = P.Reshape()
        self.expand = P.ExpandDims()
        self.mean = P.ReduceMean()
        # Every path starts in "<START>": score 0 there, -10000 everywhere else.
        init_alphas = np.full((self.batch_size, self.target_size), -10000.0)
        init_alphas[:, self.start_index] = 0.
        self.init_alphas = Tensor(init_alphas, dtype=mstype.float32)
        self.cast = P.Cast()
        self.reduce_max = P.ReduceMax(keep_dims=True)
        self.on_value = Tensor(1.0, dtype=mstype.float32)
        self.off_value = Tensor(0.0, dtype=mstype.float32)
        self.onehot = P.OneHot()

    def log_sum_exp(self, logits):
        '''
        Compute the log_sum_exp score for Normalization factor.

        The reduction runs over the last axis and keeps that axis with size 1.
        The per-row maximum is subtracted before exponentiating to keep the
        computation numerically stable.
        '''
        max_score = self.reduce_max(logits, -1)
        score = self.log(self.reduce_sum(self.exp(logits - max_score), -1))
        score = max_score + score
        return score

    def _realpath_score(self, features, label):
        '''
        Compute the emission and transition score for the real path.

        Args:
            features: Emission scores, indexed as (batch_size, seq_length, target_size).
            label: Gold tag indices, indexed as (batch_size, seq_length).
        Returns:
            Tensor of shape (batch_size, 1), the score of the gold path
            including the "<START>" and "<STOP>" transitions.
        '''
        # NOTE(review): presumably forces a fresh tensor for `label` - confirm before removing.
        label = label * 1
        start_column = self.tile(self.reshape(self.START_VALUE, (1,)), (self.batch_size,))
        start_column = self.reshape(start_column, (self.batch_size, 1))
        # Gold path with "<START>" prepended: seq_length + 1 tags per sample.
        labels = self.cat((start_column, label))
        onehot_label = self.onehot(label, self.target_size, self.on_value, self.off_value)
        emits = features * onehot_label
        labels = self.onehot(labels, self.target_size, self.on_value, self.off_value)
        # next_tags[t] is the gold tag at step t, prev_tags[t] the tag before it.
        next_tags = labels[:, 1:, :]
        prev_tags = labels[:, :self.seq_length, :]
        next_tags = self.expand(next_tags, 3)
        prev_tags = self.expand(prev_tags, 2)
        # Outer product selects transitions[next, prev] for every step.
        label_trans = next_tags * prev_tags
        transitions = self.expand(self.expand(self.transitions, 0), 0)
        trans = transitions * label_trans
        score = self.sum(emits, (1, 2)) + self.sum(trans, (1, 2, 3))
        # The final gold tag sits at index seq_length because "<START>" was
        # prepended; slicing at seq_length - 1 would pick the second-to-last
        # tag and disagree with _normalization_factor, which adds the
        # "<STOP>" transition after the last step.
        stop_value_index = labels[:, self.seq_length:(self.seq_length + 1), :]
        stop_value = self.transitions[self.stop_index:(self.stop_index + 1), :]
        stop_score = stop_value * self.reshape(stop_value_index, (self.batch_size, self.target_size))
        score = score + self.sum(stop_score, 1)
        score = self.reshape(score, (self.batch_size, -1))
        return score

    def _normalization_factor(self, features):
        '''
        Compute the total score for all the paths.

        Args:
            features: Emission scores, indexed as (batch_size, seq_length, target_size).
        Returns:
            Tensor of shape (batch_size, 1), the log-sum-exp over all tag paths.
        '''
        forward_var = self.init_alphas
        forward_var = self.expand(forward_var, 1)
        for idx in range(self.seq_length):
            feat = features[:, idx:(idx+1), :]
            emit_score = self.reshape(feat, (self.batch_size, self.target_size, 1))
            # Broadcasts to (batch, next_tag, prev_tag).
            next_tag_var = emit_score + self.transitions + forward_var
            forward_var = self.log_sum_exp(next_tag_var)
            forward_var = self.reshape(forward_var, (self.batch_size, 1, self.target_size))
        stop_transitions = self.transitions[self.stop_index:(self.stop_index + 1), :]
        terminal_var = forward_var + self.reshape(stop_transitions, (1, -1))
        alpha = self.log_sum_exp(terminal_var)
        alpha = self.reshape(alpha, (self.batch_size, -1))
        return alpha

    def _decoder(self, features):
        '''
        Viterbi decode for evaluation.

        Args:
            features: Emission scores, indexed as (batch_size, seq_length, target_size).
        Returns:
            Tuple of (backpointers, best_tag_id). `backpointers` holds one
            1-tuple per step wrapping the best-previous-tag indices;
            `best_tag_id` is the best final tag for each sample.
        '''
        backpointers = ()
        forward_var = self.init_alphas
        for idx in range(self.seq_length):
            feat = features[:, idx:(idx+1), :]
            feat = self.reshape(feat, (self.batch_size, self.target_size))
            bptrs_t = ()

            # (batch, next_tag, prev_tag); argmax over the previous tag.
            next_tag_var = self.expand(forward_var, 1) + self.transitions
            best_tag_id, best_tag_value = self.argmax(next_tag_var)
            bptrs_t += (best_tag_id,)
            forward_var = best_tag_value + feat

            backpointers += (bptrs_t,)
        stop_transitions = self.transitions[self.stop_index:(self.stop_index + 1), :]
        terminal_var = forward_var + self.reshape(stop_transitions, (1, -1))
        best_tag_id, _ = self.argmax(terminal_var)
        return backpointers, best_tag_id

    def construct(self, features, label):
        '''
        Return the mean CRF loss in training mode, otherwise the Viterbi
        backpointers and the best final tag.
        '''
        if self.is_training:
            forward_score = self._normalization_factor(features)
            gold_score = self._realpath_score(features, label)
            return_value = self.mean(forward_score - gold_score)
        else:
            path_list, tag = self._decoder(features)
            return_value = path_list, tag
        return return_value


def postprocess(backpointers, best_tag_id):
    '''
    Do postprocess: follow the Viterbi backpointers to recover the best path.

    Args:
        backpointers: Sequence with one entry per step; element 0 of each entry
            must provide `asnumpy()` returning an array indexed as
            [sample][tag] -> best previous tag.
        best_tag_id: Object providing `asnumpy()` that yields the best final
            tag for each sample.
    Returns:
        List with one list of tag ids per sample, in forward order, with the
        start tag removed.
    '''
    best_tag_id = best_tag_id.asnumpy()
    # Convert every step once (latest step first) instead of once per sample.
    reversed_bptrs = [bptrs_t[0].asnumpy() for bptrs_t in reversed(backpointers)]
    best_path = []
    for i, best_local_id in enumerate(best_tag_id):
        path = [best_local_id]
        for bptrs in reversed_bptrs:
            best_local_id = bptrs[i][best_local_id]
            path.append(best_local_id)
        # Pop off the start tag (we dont want to return that to the caller)
        path.pop()
        path.reverse()
        best_path.append(path)
    return best_path