• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Assessment methods: evaluation metrics accumulated batch by batch."""
17
18import numpy as np
19
20
class Accuracy:
    """Top-1 classification accuracy, accumulated over batches.

    Call ``update`` once per batch, then ``get_metrics`` for the running
    accuracy expressed as a percentage.
    """

    def __init__(self):
        self.acc_num = 0  # correctly classified samples seen so far
        self.total_num = 0  # samples seen so far
        self.name = 'Accuracy'

    def update(self, logits, labels):
        """Accumulate one batch.

        Args:
            logits: tensor-like object exposing ``asnumpy()``; the predicted
                class is the argmax over its last axis (presumably shaped
                (batch, num_classes) -- confirm against callers).
            labels: tensor-like object exposing ``asnumpy()``; flattened to
                1-D before being compared with the predicted class ids.
        """
        labels = np.reshape(labels.asnumpy(), -1)
        logit_id = np.argmax(logits.asnumpy(), axis=-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)

    def get_metrics(self):
        """Return accuracy in percent, or 0.0 if no samples have been seen."""
        if self.total_num == 0:
            # Avoid ZeroDivisionError when queried before the first update.
            return 0.0
        return self.acc_num / self.total_num * 100.0
38
39
class TopK:
    """Top-k classification accuracy, accumulated over batches.

    A sample counts as correct when its label is among the ``k`` class ids
    with the largest logits.

    Args:
        k: number of highest-scoring classes to consider (default 5).
    """

    def __init__(self, k=5):
        self.acc_num = 0  # samples whose label is within the top k
        self.total_num = 0  # samples seen so far
        self.k = k
        self.name = 'Top' + str(k)

    def update(self, logits, labels):
        """Accumulate one batch.

        Args:
            logits: tensor-like object exposing ``asnumpy()``; indexed as a
                2-D (batch, num_classes) array.
            labels: tensor-like object exposing ``asnumpy()``; one class id
                per row of ``logits`` (shape (batch,) or (batch, 1)).
        """
        labels = np.reshape(labels.asnumpy(), (-1, 1))
        logits = logits.asnumpy()
        sorted_index = logits.argsort()  # ascending along the last axis
        # The last k columns hold the k largest logits. Clamp the start so a
        # k larger than the class count selects every class, and k <= 0
        # selects none (instead of the whole row, as a bare [-0:] would).
        start = max(sorted_index.shape[-1] - self.k, 0)
        top_k = sorted_index[:, start:]
        self.acc_num += int(np.sum(np.any(top_k == labels, axis=1)))
        self.total_num += len(labels)

    def get_metrics(self):
        """Return top-k accuracy in percent, or 0.0 if no samples were seen."""
        if self.total_num == 0:
            # Avoid ZeroDivisionError when queried before the first update.
            return 0.0
        return self.acc_num / self.total_num * 100.0
61
62
class F1:
    """Binary F1 score (in percent), accumulated over batches.

    Both labels and predicted class ids are cast to bool, so any non-zero
    class id is treated as the positive class.
    """

    def __init__(self):
        # Accumulated boolean predictions (argmax class ids), despite the name.
        self.logits_array = np.array([])
        # Accumulated boolean labels.
        self.labels_array = np.array([])
        self.name = 'F1'

    def update(self, logits, labels):
        """Accumulate one batch.

        Args:
            logits: tensor-like object exposing ``asnumpy()``; the predicted
                class is the argmax over axis 1.
            labels: tensor-like object exposing ``asnumpy()``; flattened to 1-D.
        """
        labels = np.reshape(labels.asnumpy(), -1)
        preds = np.argmax(logits.asnumpy(), axis=1)
        # Use the builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        self.labels_array = np.concatenate([self.labels_array, labels]).astype(bool)
        self.logits_array = np.concatenate([self.logits_array, preds]).astype(bool)

    def get_metrics(self):
        """Return F1 in percent.

        Returns 0.0 when fewer than two samples have been accumulated, or when
        neither the labels nor the predictions contain a positive (F1 is
        undefined there; the old formula produced nan).
        """
        if len(self.labels_array) < 2:
            return 0.0
        tp = np.sum(self.labels_array & self.logits_array)
        fp = np.sum(~self.labels_array & self.logits_array)
        fn = np.sum(self.labels_array & ~self.logits_array)
        # 2*P*R / (P + R) simplifies to 2*TP / (2*TP + FP + FN), which avoids
        # the 0/0 that precision or recall hit when TP == 0.
        denom = 2 * tp + fp + fn
        if denom == 0:
            return 0.0
        return 2.0 * tp / denom * 100.0
87
88
class Pearsonr:
    """Pearson correlation (in percent) between predictions and labels.

    Values are flattened and accumulated across batches; the correlation is
    computed over everything seen so far.
    """

    def __init__(self):
        self.logits_array = np.array([])  # accumulated flattened predictions
        self.labels_array = np.array([])  # accumulated flattened labels
        self.name = 'Pearsonr'

    def update(self, logits, labels):
        """Accumulate one batch.

        Args:
            logits: tensor-like object exposing ``asnumpy()``; flattened to 1-D.
            labels: tensor-like object exposing ``asnumpy()``; flattened to 1-D.
        """
        labels = np.reshape(labels.asnumpy(), -1)
        logits = np.reshape(logits.asnumpy(), -1)
        self.labels_array = np.concatenate([self.labels_array, labels])
        self.logits_array = np.concatenate([self.logits_array, logits])

    def get_metrics(self):
        """Return the Pearson correlation coefficient multiplied by 100.

        Returns 0.0 when fewer than two samples were seen, or when either
        series is constant (the coefficient is undefined; previously nan).
        """
        if len(self.labels_array) < 2:
            return 0.0
        x = self.logits_array
        y = self.labels_array
        # Test for exact constancy before centering: the mean of a constant
        # float array may not equal its elements exactly, which would leave a
        # tiny non-zero norm and a meaningless correlation.
        if x.max() == x.min() or y.max() == y.min():
            return 0.0
        xm = x - x.mean()
        ym = y - y.mean()
        norm_xm = np.linalg.norm(xm)
        norm_ym = np.linalg.norm(ym)
        if norm_xm == 0 or norm_ym == 0:
            return 0.0
        return np.dot(xm / norm_xm, ym / norm_ym) * 100.0
114
115
class Matthews:
    """Matthews correlation coefficient (in percent) for binary predictions.

    Both labels and predicted class ids are cast to bool, so any non-zero
    class id is treated as the positive class.
    """

    def __init__(self):
        # Accumulated boolean predictions (argmax class ids), despite the name.
        self.logits_array = np.array([])
        # Accumulated boolean labels.
        self.labels_array = np.array([])
        self.name = 'Matthews'

    def update(self, logits, labels):
        """Accumulate one batch.

        Args:
            logits: tensor-like object exposing ``asnumpy()``; the predicted
                class is the argmax over axis 1.
            labels: tensor-like object exposing ``asnumpy()``; flattened to 1-D.
        """
        labels = np.reshape(labels.asnumpy(), -1)
        preds = np.argmax(logits.asnumpy(), axis=1)
        # Use the builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        self.labels_array = np.concatenate([self.labels_array, labels]).astype(bool)
        self.logits_array = np.concatenate([self.logits_array, preds]).astype(bool)

    def get_metrics(self):
        """Return MCC multiplied by 100.

        Returns 0.0 when fewer than two samples were seen, or when all labels
        or all predictions fall in one class (zero denominator; previously nan).
        """
        if len(self.labels_array) < 2:
            return 0.0
        labels = self.labels_array
        preds = self.logits_array
        # Convert counts to float before multiplying: the four-way product in
        # the denominator overflows int64 for evaluation sets of roughly 1e5+
        # samples.
        tp = float(np.sum(labels & preds))
        tn = float(np.sum(~labels & ~preds))
        fp = float(np.sum(~labels & preds))
        fn = float(np.sum(labels & ~preds))
        denom = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        if denom == 0:
            return 0.0
        return (tp * tn - fp * fn) / denom * 100.0
139