• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16'''
17CRF script.
18'''
19
20import numpy as np
21import mindspore.nn as nn
22from mindspore.ops import operations as P
23from mindspore.common.tensor import Tensor
24from mindspore.common.parameter import Parameter
25import mindspore.common.dtype as mstype
26
class CRF(nn.Cell):
    '''
    Conditional Random Field.

    Transition scores live in ``self.transitions`` and are indexed as ``[next_tag, prev_tag]``.

    Args:
        tag_to_index: The dict for tag to index mapping, including the extra "<START>" and "<STOP>" tags.
        batch_size: Batch size, i.e., the length of the first dimension.
        seq_length: Sequence length, i.e., the length of the second dimension.
        is_training: Specifies whether to use training mode.
    Returns:
        Training mode: Tensor, the batch mean of (log partition function - gold path score).
        Evaluation mode: Tuple, the back-pointers (index of the best previous tag) for each step;
        Tensor, the index of the best tag at the last step. ``postprocess`` turns both into tag paths.
    '''
    def __init__(self, tag_to_index, batch_size=1, seq_length=128, is_training=True):

        super(CRF, self).__init__()
        self.target_size = len(tag_to_index)
        self.is_training = is_training
        self.tag_to_index = tag_to_index
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.START_TAG = "<START>"
        self.STOP_TAG = "<STOP>"
        # Look the special tags up instead of assuming they are the last two entries,
        # so every method agrees with the transition initialisation below.
        self.start_index = tag_to_index[self.START_TAG]
        self.stop_index = tag_to_index[self.STOP_TAG]
        self.START_VALUE = Tensor(self.start_index, dtype=mstype.int32)
        self.STOP_VALUE = Tensor(self.stop_index, dtype=mstype.int32)
        transitions = np.random.normal(size=(self.target_size, self.target_size)).astype(np.float32)
        # Forbid transitions into START (row) and out of STOP (column).
        transitions[self.start_index, :] = -10000
        transitions[:, self.stop_index] = -10000
        self.transitions = Parameter(Tensor(transitions), name="transition_matrix")
        self.cat = P.Concat(axis=-1)
        self.argmax = P.ArgMaxWithValue(axis=-1)
        self.log = P.Log()
        self.exp = P.Exp()
        self.sum = P.ReduceSum()
        self.tile = P.Tile()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.reshape = P.Reshape()
        self.expand = P.ExpandDims()
        self.mean = P.ReduceMean()
        # Initial scores: every path must start from START.
        init_alphas = np.ones(shape=(self.batch_size, self.target_size)) * -10000.0
        init_alphas[:, self.start_index] = 0.
        self.init_alphas = Tensor(init_alphas, dtype=mstype.float32)
        self.cast = P.Cast()
        self.reduce_max = P.ReduceMax(keep_dims=True)
        self.on_value = Tensor(1.0, dtype=mstype.float32)
        self.off_value = Tensor(0.0, dtype=mstype.float32)
        self.onehot = P.OneHot()

    def log_sum_exp(self, logits):
        '''
        Compute log(sum(exp(logits))) over the last axis, keeping that axis.
        '''
        # Subtract the per-row max before exponentiating for numerical stability.
        max_score = self.reduce_max(logits, -1)
        score = self.log(self.reduce_sum(self.exp(logits - max_score), -1))
        score = max_score + score
        return score

    def _stop_transition_scores(self):
        '''
        Return the scores for moving from every tag into STOP as a (1, target_size) tensor.
        '''
        stop_row = self.transitions[self.stop_index:(self.stop_index + 1), :]
        return self.reshape(stop_row, (1, -1))

    def _realpath_score(self, features, label):
        '''
        Compute the emission and transition score for the real (gold) path.

        Returns a (batch_size, 1) tensor of gold path scores.
        '''
        # NOTE(review): `* 1` yields a new tensor from `label`; purpose not evident here.
        label = label * 1
        # Prepend START to every label sequence: labels becomes (batch_size, seq_length + 1).
        concat_A = self.tile(self.reshape(self.START_VALUE, (1,)), (self.batch_size,))
        concat_A = self.reshape(concat_A, (self.batch_size, 1))
        labels = self.cat((concat_A, label))
        # Emission score: keep only features[b, t, label[b, t]] through a one-hot mask.
        onehot_label = self.onehot(label, self.target_size, self.on_value, self.off_value)
        emits = features * onehot_label
        labels = self.onehot(labels, self.target_size, self.on_value, self.off_value)
        # label1 holds the tag at each step and label2 the tag before it (START for step 0);
        # their outer product selects transitions[next, prev] along the gold path.
        label1 = labels[:, 1:, :]
        label2 = labels[:, :self.seq_length, :]
        label1 = self.expand(label1, 3)
        label2 = self.expand(label2, 2)
        label_trans = label1 * label2
        transitions = self.expand(self.expand(self.transitions, 0), 0)
        trans = transitions * label_trans
        score = self.sum(emits, (1, 2)) + self.sum(trans, (1, 2, 3))
        # Transition from the last gold tag into STOP. Because of the START prefix the last
        # tag sits at index seq_length; index seq_length - 1 would be the second-to-last tag
        # and disagree with the STOP term used in _normalization_factor.
        stop_value_index = labels[:, self.seq_length:(self.seq_length + 1), :]
        stop_value = self._stop_transition_scores()
        stop_score = stop_value * self.reshape(stop_value_index, (self.batch_size, self.target_size))
        score = score + self.sum(stop_score, 1)
        score = self.reshape(score, (self.batch_size, -1))
        return score

    def _normalization_factor(self, features):
        '''
        Compute the total score for all the paths (log partition function) with the
        forward algorithm. Returns a (batch_size, 1) tensor.
        '''
        # forward_var[b, 0, j]: log-sum of the scores of all partial paths ending in tag j.
        forward_var = self.expand(self.init_alphas, 1)
        for idx in range(self.seq_length):
            feat = features[:, idx:(idx+1), :]
            emit_score = self.reshape(feat, (self.batch_size, self.target_size, 1))
            # Broadcasts to (batch, next_tag, prev_tag); log_sum_exp reduces over prev_tag.
            next_tag_var = emit_score + self.transitions + forward_var
            forward_var = self.log_sum_exp(next_tag_var)
            forward_var = self.reshape(forward_var, (self.batch_size, 1, self.target_size))
        terminal_var = forward_var + self._stop_transition_scores()
        alpha = self.log_sum_exp(terminal_var)
        alpha = self.reshape(alpha, (self.batch_size, -1))
        return alpha

    def _decoder(self, features):
        '''
        Viterbi decode for evaluation.

        Returns:
            backpointers: tuple with one entry per step; each entry is a 1-tuple holding, for every
                current tag, the index of the best previous tag.
            best_tag_id: index of the best tag at the last step, after adding the STOP transition.
        '''
        backpointers = ()
        forward_var = self.init_alphas
        for idx in range(self.seq_length):
            feat = features[:, idx:(idx+1), :]
            feat = self.reshape(feat, (self.batch_size, self.target_size))
            # (batch, next_tag, prev_tag): best previous tag and its score for every next tag.
            next_tag_var = self.expand(forward_var, 1) + self.transitions
            best_tag_id, best_tag_value = self.argmax(next_tag_var)
            forward_var = best_tag_value + feat
            # Stored as a 1-tuple because postprocess reads bptrs_t[0].
            backpointers += ((best_tag_id,),)
        terminal_var = forward_var + self._stop_transition_scores()
        best_tag_id, _ = self.argmax(terminal_var)
        return backpointers, best_tag_id

    def construct(self, features, label):
        '''
        Return the mean CRF loss in training mode, or the Viterbi back-pointers and the best
        last tag in evaluation mode (``label`` is unused then).
        '''
        if self.is_training:
            forward_score = self._normalization_factor(features)
            gold_score = self._realpath_score(features, label)
            return_value = self.mean(forward_score - gold_score)
        else:
            path_list, tag = self._decoder(features)
            return_value = path_list, tag
        return return_value
157
def _as_numpy(value):
    '''
    Return ``value`` as a numpy array.

    Objects exposing ``asnumpy()`` (MindSpore tensors) are converted with it; anything
    else (numpy arrays, nested lists) goes through ``np.asarray``.
    '''
    if hasattr(value, "asnumpy"):
        return value.asnumpy()
    return np.asarray(value)


def postprocess(backpointers, best_tag_id):
    '''
    Back-track the Viterbi pointers returned by ``CRF`` in evaluation mode.

    Args:
        backpointers: Tuple with one entry per time step. Each entry is a tuple whose first
            element maps (sample, current tag) to the index of the best previous tag.
        best_tag_id: Index of the best final tag for every sample.

    Returns:
        list of lists: for every sample, the best tag sequence in time order (one tag per
        step), with the leading START tag removed.
    '''
    best_tag_id = _as_numpy(best_tag_id)
    # Convert each step's back-pointers once, instead of once per sample.
    steps = [_as_numpy(bptrs_t[0]) for bptrs_t in backpointers]
    best_path = []
    for i, best_local_id in enumerate(best_tag_id):
        path = [best_local_id]
        # Walk backwards in time, following each step's pointer to the best previous tag.
        for bptrs in reversed(steps):
            best_local_id = bptrs[i][best_local_id]
            path.append(best_local_id)
        # Pop off the start tag (we don't want to return that to the caller).
        path.pop()
        path.reverse()
        best_path.append(path)
    return best_path
178