# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Defined callback for DeepFM.
"""
import time

from mindspore.train.callback import Callback


def add_write(file_path, out_str):
    """
    Append one line of text to a file.

    The file is opened in append mode with UTF-8 encoding and a newline is
    written after ``out_str``.

    Args:
        file_path (str): Path of the file to append to.
        out_str (str): Text to write; the trailing newline is added here.
    """
    with open(file_path, 'a+', encoding='utf-8') as file_out:
        file_out.write(out_str + '\n')


class EvalCallBack(Callback):
    """
    Evaluate the model at the end of every epoch and log the result.

    The formatted result is printed and appended to ``eval_file_path``.

    Args:
        model: Object whose ``eval(eval_dataset)`` method is called each epoch.
        eval_dataset: Dataset passed unchanged to ``model.eval``.
        auc_metric: Metric object; its ``clear()`` method is called once on
            construction.
        eval_file_path (str): File the evaluation log lines are appended to.
    """
    def __init__(self, model, eval_dataset, auc_metric, eval_file_path):
        super().__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        # Attribute name kept as-is: it is part of the public interface.
        self.aucMetric = auc_metric
        self.aucMetric.clear()
        self.eval_file_path = eval_file_path

    def epoch_end(self, run_context):
        """
        Run evaluation, then print and persist the metrics and elapsed time.

        Args:
            run_context: Callback context supplied by the trainer (unused).
        """
        start_time = time.time()
        # NOTE(review): `out` is presumably a dict of metric name -> value;
        # only its `.values()` view is logged.
        out = self.model.eval(self.eval_dataset)
        eval_time = int(time.time() - start_time)
        time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        out_str = "{} EvalCallBack metric{}; eval_time{}s".format(
            time_str, out.values(), eval_time)
        print(out_str)
        add_write(self.eval_file_path, out_str)


class LossCallBack(Callback):
    """
    Record the training loss to stdout and to a file.

    The most recent loss is also kept in ``self.loss``. No NaN/Inf check is
    performed by this callback.

    Note:
        If per_print_times is 0 do not print loss.

    Args:
        loss_file_path (str): The file absolute path, to save as loss_file.
        per_print_times (int): Print loss every times. Default: 1.

    Raises:
        ValueError: If ``per_print_times`` is not an int or is negative.
    """
    def __init__(self, loss_file_path, per_print_times=1):
        super().__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("per_print_times must be int and >= 0.")
        self.loss_file_path = loss_file_path
        self._per_print_times = per_print_times
        self.loss = 0

    def step_end(self, run_context):
        """
        Log the loss of the step that just finished, if it is a logging step.

        Args:
            run_context: Callback context; ``original_args()`` must expose
                ``net_outputs``, ``cur_step_num``, ``batch_num`` and
                ``cur_epoch_num``.
        """
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs.asnumpy()
        cur_num = cb_params.cur_step_num
        if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
            # Convert the global step counter to a 1-based index in the epoch.
            cur_step_in_epoch = (cur_num - 1) % cb_params.batch_num + 1
            time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            message = "epoch: {} step: {}, loss is {}".format(
                cb_params.cur_epoch_num, cur_step_in_epoch, loss)
            # Reuse add_write so the loss log is always written as UTF-8,
            # consistent with the evaluation log.
            add_write(self.loss_file_path, "{} {}".format(time_str, message))
            print(message + "\n")
        self.loss = loss


class TimeMonitor(Callback):
    """
    Time monitor for calculating cost of each epoch.

    Args:
        data_size (int): step size of an epoch.
    """
    def __init__(self, data_size):
        super().__init__()
        self.data_size = data_size
        self.per_step_time = 0
        # Initialise both timestamps so the *_end hooks cannot fail with
        # AttributeError when the matching *_begin hook was not invoked.
        self.epoch_time = time.time()
        self.step_time = time.time()

    def epoch_begin(self, run_context):
        """Record the wall-clock start time of the epoch."""
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """
        Print the epoch duration and the average per-step duration (ms).

        The per-step value is also stored in ``self.per_step_time``. Raises
        ZeroDivisionError if ``data_size`` is 0.
        """
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / self.data_size
        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
        self.per_step_time = per_step_mseconds

    def step_begin(self, run_context):
        """Record the wall-clock start time of the step."""
        self.step_time = time.time()

    def step_end(self, run_context):
        """Print the duration of the step that just finished (ms)."""
        step_mseconds = (time.time() - self.step_time) * 1000
        print(f"step time {step_mseconds}", flush=True)