• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/train/loss_monitor.h"
18 #include <sys/stat.h>
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 #include <iostream>
23 #include "src/litert/lite_session.h"
24 #include "src/common/utils.h"
25 #include "src/tensor.h"
26 
27 namespace mindspore {
28 namespace lite {
Begin(const TrainLoopCallBackData & cb_data)29 void LossMonitor::Begin(const TrainLoopCallBackData &cb_data) {
30   if (cb_data.epoch_ == 0) losses_.clear();
31 }
32 
EpochBegin(const TrainLoopCallBackData & cb_data)33 void LossMonitor::EpochBegin(const TrainLoopCallBackData &cb_data) {
34   if (losses_.size() != cb_data.epoch_) {
35     MS_LOG(WARNING) << "losses array does not match epoch number";
36   } else {
37     losses_.push_back(std::make_pair(cb_data.epoch_, 0.0));
38   }
39 }
40 
EpochEnd(const TrainLoopCallBackData & cb_data)41 int LossMonitor::EpochEnd(const TrainLoopCallBackData &cb_data) {
42   if (cb_data.step_ > 0) losses_.at(cb_data.epoch_).second /= static_cast<float>(cb_data.step_ + 1);
43   if (print_every_n_ > 0) {
44     std::cout << "Epoch (" << (cb_data.epoch_ + 1) << "):\tLoss is " << losses_.at(cb_data.epoch_).second << std::endl;
45   }
46   return RET_CONTINUE;
47 }
48 
StepEnd(const TrainLoopCallBackData & cb_data)49 void LossMonitor::StepEnd(const TrainLoopCallBackData &cb_data) {
50   auto outputs = cb_data.session_->GetOutputs();
51   for (auto it = outputs.begin(); it != outputs.end(); ++it) {
52     if (it->second->ElementsNum() == 1) {
53       auto loss = reinterpret_cast<float *>(it->second->MutableData());
54       losses_.at(cb_data.epoch_).second += loss[0];
55       if ((static_cast<int>(cb_data.step_) + 1) % print_every_n_ == 0)
56         std::cout << (cb_data.epoch_ + 1) << "." << (cb_data.step_ + 1) << ":\tLoss is " << loss[0] << std::endl;
57       return;
58     }
59   }
60   MS_LOG(WARNING) << "Model does not have a loss output tensor of size 1";
61 }
62 }  // namespace lite
63 }  // namespace mindspore
64