1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/train/accuracy_metrics.h"
18 #include "include/errorcode.h"
19 #include "src/common/utils.h"
20 #include "src/tensor.h"
21 #include "src/train/train_utils.h"
22
23 namespace mindspore {
24 namespace lite {
AccuracyMetrics(int accuracy_metrics,const std::vector<int> & input_indexes,const std::vector<int> & output_indexes)25 AccuracyMetrics::AccuracyMetrics(int accuracy_metrics, const std::vector<int> &input_indexes,
26 const std::vector<int> &output_indexes)
27 : Metrics() {
28 if (input_indexes.size() == output_indexes.size()) {
29 input_indexes_ = input_indexes;
30 output_indexes_ = output_indexes;
31 } else {
32 MS_LOG(WARNING) << "input to output mapping vectors sizes do not match";
33 }
34 if (accuracy_metrics != METRICS_CLASSIFICATION) {
35 MS_LOG(WARNING) << "Only classification metrics is supported";
36 } else {
37 accuracy_metrics_ = accuracy_metrics;
38 }
39 }
40
Update(std::vector<lite::Tensor * > inputs,std::vector<lite::Tensor * > outputs)41 void AccuracyMetrics::Update(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
42 for (unsigned int i = 0; i < input_indexes_.size(); i++) {
43 if ((inputs.size() <= static_cast<unsigned int>(input_indexes_[i])) ||
44 (outputs.size() <= static_cast<unsigned int>(output_indexes_[i]))) {
45 MS_LOG(WARNING) << "indices " << input_indexes_[i] << "/" << output_indexes_[i]
46 << " is outside of input/output range";
47 return;
48 }
49 float accuracy = 0.0;
50 if (inputs.at(input_indexes_[i])->data_type() == kNumberTypeInt32) {
51 accuracy = CalculateSparseClassification(inputs.at(input_indexes_[i]), outputs.at(output_indexes_[i]));
52 } else {
53 accuracy = CalculateOneHotClassification(inputs.at(input_indexes_[i]), outputs.at(output_indexes_[i]));
54 }
55 total_accuracy_ += accuracy;
56 total_steps_ += 1.0;
57 }
58 }
59
Eval()60 float AccuracyMetrics::Eval() {
61 if (total_steps_ == 0.0) {
62 MS_LOG(WARNING) << "Accuary can not be calculated, because the number of samples is 0.";
63 return 0.0;
64 }
65
66 return (total_accuracy_ / total_steps_);
67 }
68 } // namespace lite
69 } // namespace mindspore
70