• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "include/train/accuracy_monitor.h"
18 #include <sys/stat.h>
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 #include <iostream>
23 #include <fstream>
24 #include <memory>
25 #include "include/errorcode.h"
26 #include "src/train/train_loop.h"
27 #include "src/common/utils.h"
28 #include "src/tensor.h"
29 
30 namespace mindspore {
31 namespace lite {
Begin(const lite::TrainLoopCallBackData & cb_data)32 void AccuracyMonitor::Begin(const lite::TrainLoopCallBackData &cb_data) {
33   if (cb_data.epoch_ == 0) accuracies_.clear();
34 }
35 
EpochEnd(const lite::TrainLoopCallBackData & cb_data)36 int AccuracyMonitor::EpochEnd(const lite::TrainLoopCallBackData &cb_data) {
37   if ((static_cast<int>(cb_data.epoch_) + 1) % check_every_n_ == 0) {
38     auto ret = cb_data.loop_->Eval(ds_, {}, nullptr, max_steps_);
39     if (ret != RET_OK) {
40       MS_LOG(ERROR) << "Eval failed.";
41       return RET_ERROR;
42     }
43   }
44   accuracies_.push_back(std::make_pair(cb_data.epoch_, 0.0));
45   return RET_CONTINUE;
46 }
47 }  // namespace lite
48 }  // namespace mindspore
49