# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Assessment methods (evaluation metrics).

Every metric class follows the same small protocol:

* ``name`` -- display name of the metric;
* ``update(logits, labels)`` -- accumulate one batch; both arguments must
  provide an ``asnumpy()`` method returning a NumPy array;
* ``get_metrics()`` -- return the accumulated score as a percentage.
"""

import numpy as np


class Accuracy:
    """Top-1 classification accuracy, reported in percent."""

    def __init__(self):
        self.acc_num = 0      # number of correctly classified samples so far
        self.total_num = 0    # number of samples seen so far
        self.name = 'Accuracy'

    def update(self, logits, labels):
        """Accumulate one batch.

        Args:
            logits: object whose ``asnumpy()`` returns per-class scores with
                the class dimension last.
            labels: object whose ``asnumpy()`` returns class ids; it is
                flattened, so ``(N,)`` and ``(N, 1)`` are both accepted.
        """
        labels = np.reshape(labels.asnumpy(), -1)
        logit_id = np.argmax(logits.asnumpy(), axis=-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)

    def get_metrics(self):
        """Return accuracy in percent, or 0.0 if no samples were seen."""
        if self.total_num == 0:
            return 0.0
        return self.acc_num / self.total_num * 100.0


class TopK:
    """Top-k classification accuracy, reported in percent.

    A sample counts as correct when its label is among the ``k``
    highest-scoring classes.
    """

    def __init__(self, k=5):
        self.acc_num = 0      # samples whose label was within the top k
        self.total_num = 0    # number of samples seen so far
        self.k = k
        self.name = 'Top' + str(k)

    def update(self, logits, labels):
        """Accumulate one batch.

        Args:
            logits: object whose ``asnumpy()`` returns a 2-D array of
                per-class scores, shape ``(N, num_classes)``.
            labels: object whose ``asnumpy()`` returns ``N`` class ids
                (flattened, so ``(N,)`` and ``(N, 1)`` are both accepted).
        """
        labels = np.reshape(labels.asnumpy(), -1)
        logits = logits.asnumpy()
        num_classes = logits.shape[-1]
        # Clamp k so k > num_classes means "all classes" and k <= 0 means none.
        num_top = max(0, min(self.k, num_classes))
        # argsort is ascending, so the last `num_top` columns are the best classes.
        top_k = np.argsort(logits, axis=-1)[:, num_classes - num_top:]
        hits = np.any(top_k == labels[:, None], axis=1)
        self.acc_num += int(np.sum(hits))
        self.total_num += len(labels)

    def get_metrics(self):
        """Return top-k accuracy in percent, or 0.0 if no samples were seen."""
        if self.total_num == 0:
            return 0.0
        return self.acc_num / self.total_num * 100.0


class F1:
    """Binary F1 score of the positive class, reported in percent.

    Labels and predictions are cast to bool, so any non-zero label and any
    non-zero argmax class are treated as the positive class.
    """

    def __init__(self):
        self.logits_array = np.array([])   # accumulated predictions (bool after update)
        self.labels_array = np.array([])   # accumulated labels (bool after update)
        self.name = 'F1'

    def update(self, logits, labels):
        """Accumulate one batch.

        Args:
            logits: object whose ``asnumpy()`` returns per-class scores with
                the class dimension last; the argmax is the prediction.
            labels: object whose ``asnumpy()`` returns class ids (flattened).
        """
        labels = np.reshape(labels.asnumpy(), -1)
        preds = np.reshape(np.argmax(logits.asnumpy(), axis=-1), -1)
        # `np.bool` was removed in NumPy 1.24; the builtin `bool` is equivalent.
        self.labels_array = np.concatenate([self.labels_array, labels]).astype(bool)
        self.logits_array = np.concatenate([self.logits_array, preds]).astype(bool)

    def get_metrics(self):
        """Return F1 in percent.

        Returns 0.0 when fewer than two samples were seen or when there are no
        true positives (precision/recall would otherwise be 0/0 and yield NaN).
        """
        if len(self.labels_array) < 2:
            return 0.0
        labels = self.labels_array
        preds = self.logits_array
        tp = np.sum(labels & preds)
        fp = np.sum(~labels & preds)
        fn = np.sum(labels & ~preds)
        if tp == 0:
            return 0.0
        # 2PR / (P + R) simplifies to 2TP / (2TP + FP + FN).
        return 2.0 * tp / (2 * tp + fp + fn) * 100.0


class Pearsonr:
    """Pearson correlation between predictions and labels, in percent."""

    def __init__(self):
        self.logits_array = np.array([])   # accumulated flattened predictions
        self.labels_array = np.array([])   # accumulated flattened labels
        self.name = 'Pearsonr'

    def update(self, logits, labels):
        """Accumulate one batch; both inputs are flattened before appending."""
        labels = np.reshape(labels.asnumpy(), -1)
        logits = np.reshape(logits.asnumpy(), -1)
        self.labels_array = np.concatenate([self.labels_array, labels])
        self.logits_array = np.concatenate([self.logits_array, logits])

    def get_metrics(self):
        """Return the Pearson correlation coefficient times 100.

        Returns 0.0 when fewer than two samples were seen or when either
        series is constant (the correlation is undefined there).
        """
        if len(self.labels_array) < 2:
            return 0.0
        xm = self.logits_array - self.logits_array.mean()
        ym = self.labels_array - self.labels_array.mean()
        norm_xm = np.linalg.norm(xm)
        norm_ym = np.linalg.norm(ym)
        if norm_xm == 0 or norm_ym == 0:
            return 0.0
        return np.dot(xm / norm_xm, ym / norm_ym) * 100.0


class Matthews:
    """Matthews correlation coefficient (binary), reported in percent.

    Labels and predictions are cast to bool, so any non-zero label and any
    non-zero argmax class are treated as the positive class.
    """

    def __init__(self):
        self.logits_array = np.array([])   # accumulated predictions (bool after update)
        self.labels_array = np.array([])   # accumulated labels (bool after update)
        self.name = 'Matthews'

    def update(self, logits, labels):
        """Accumulate one batch.

        Args:
            logits: object whose ``asnumpy()`` returns per-class scores with
                the class dimension last; the argmax is the prediction.
            labels: object whose ``asnumpy()`` returns class ids (flattened).
        """
        labels = np.reshape(labels.asnumpy(), -1)
        preds = np.reshape(np.argmax(logits.asnumpy(), axis=-1), -1)
        # `np.bool` was removed in NumPy 1.24; the builtin `bool` is equivalent.
        self.labels_array = np.concatenate([self.labels_array, labels]).astype(bool)
        self.logits_array = np.concatenate([self.logits_array, preds]).astype(bool)

    def get_metrics(self):
        """Return MCC times 100.

        Returns 0.0 when fewer than two samples were seen or when any marginal
        of the confusion matrix is empty (denominator of zero).
        """
        if len(self.labels_array) < 2:
            return 0.0
        labels = self.labels_array
        preds = self.logits_array
        # Work in float: the product of four integer counts overflows int64
        # once each count is around 1e5 (and int32 much sooner).
        tp = float(np.sum(labels & preds))
        tn = float(np.sum(~labels & ~preds))
        fp = float(np.sum(~labels & preds))
        fn = float(np.sum(labels & ~preds))
        denom = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        if denom == 0:
            return 0.0
        return (tp * tn - fp * fn) / denom * 100.0