• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/train/accuracy_metrics.h"
18 #include "include/errorcode.h"
19 #include "src/common/utils.h"
20 #include "src/tensor.h"
21 #include "src/train/train_utils.h"
22 
23 namespace mindspore {
24 namespace lite {
AccuracyMetrics(int accuracy_metrics,const std::vector<int> & input_indexes,const std::vector<int> & output_indexes)25 AccuracyMetrics::AccuracyMetrics(int accuracy_metrics, const std::vector<int> &input_indexes,
26                                  const std::vector<int> &output_indexes)
27     : Metrics() {
28   if (input_indexes.size() == output_indexes.size()) {
29     input_indexes_ = input_indexes;
30     output_indexes_ = output_indexes;
31   } else {
32     MS_LOG(WARNING) << "input to output mapping vectors sizes do not match";
33   }
34   if (accuracy_metrics != METRICS_CLASSIFICATION) {
35     MS_LOG(WARNING) << "Only classification metrics is supported";
36   } else {
37     accuracy_metrics_ = accuracy_metrics;
38   }
39 }
40 
Update(std::vector<lite::Tensor * > inputs,std::vector<lite::Tensor * > outputs)41 void AccuracyMetrics::Update(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
42   for (unsigned int i = 0; i < input_indexes_.size(); i++) {
43     if ((inputs.size() <= static_cast<unsigned int>(input_indexes_[i])) ||
44         (outputs.size() <= static_cast<unsigned int>(output_indexes_[i]))) {
45       MS_LOG(WARNING) << "indices " << input_indexes_[i] << "/" << output_indexes_[i]
46                       << " is outside of input/output range";
47       return;
48     }
49     float accuracy = 0.0;
50     if (inputs.at(input_indexes_[i])->data_type() == kNumberTypeInt32) {
51       accuracy = CalculateSparseClassification(inputs.at(input_indexes_[i]), outputs.at(output_indexes_[i]));
52     } else {
53       accuracy = CalculateOneHotClassification(inputs.at(input_indexes_[i]), outputs.at(output_indexes_[i]));
54     }
55     total_accuracy_ += accuracy;
56     total_steps_ += 1.0;
57   }
58 }
59 
Eval()60 float AccuracyMetrics::Eval() {
61   if (total_steps_ == 0.0) {
62     MS_LOG(WARNING) << "Accuary can not be calculated, because the number of samples is 0.";
63     return 0.0;
64   }
65 
66   return (total_accuracy_ / total_steps_);
67 }
68 }  // namespace lite
69 }  // namespace mindspore
70