# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Callbacks defined for DeepFM.
"""
import time
from mindspore.train.callback import Callback


def add_write(file_path, out_str):
    """Append out_str to the file at file_path, followed by a newline."""
    with open(file_path, 'a+', encoding='utf-8') as file_out:
        file_out.write(out_str + '\n')


class EvalCallBack(Callback):
    """
    Evaluate the model at the end of each epoch.

    The metric values and the evaluation time are printed and appended to
    eval_file_path, one line per epoch.
    """
    def __init__(self, model, eval_dataset, auc_metric, eval_file_path):
        super(EvalCallBack, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.aucMetric = auc_metric
        self.aucMetric.clear()
        self.eval_file_path = eval_file_path

    def epoch_end(self, run_context):
        start_time = time.time()
        out = self.model.eval(self.eval_dataset)
        eval_time = int(time.time() - start_time)
        time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        out_str = "{} EvalCallBack metric: {}; eval_time: {}s".format(
            time_str, out.values(), eval_time)
        print(out_str)
        add_write(self.eval_file_path, out_str)


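# Example: wiring EvalCallBack into evaluation (a minimal sketch; AUCMetric,
# train_net/eval_net and the datasets come from the surrounding DeepFM training
# script and are assumptions here, and the file path is a placeholder):
#
#     auc_metric = AUCMetric()
#     model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
#     eval_cb = EvalCallBack(model, eval_dataset, auc_metric, eval_file_path="./auc.log")
#
# Passing the same metric instance that was registered with the Model lets the
# callback reset it once at construction time; each epoch_end call then appends
# one metric line to eval_file_path.

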
class LossCallBack(Callback):
    """
    Monitor the loss during training.

    The loss is printed and appended to loss_file_path every per_print_times steps.

    Note:
        If per_print_times is 0, the loss is not logged.

    Args:
        loss_file_path (str): Absolute path of the file the loss is written to.
        per_print_times (int): Log the loss every per_print_times steps. Default: 1.
    """
    def __init__(self, loss_file_path, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("per_print_times must be an int and >= 0.")
        self.loss_file_path = loss_file_path
        self._per_print_times = per_print_times
        self.loss = 0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs.asnumpy()
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        cur_num = cb_params.cur_step_num
        if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
            with open(self.loss_file_path, "a+") as loss_file:
                time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                loss_file.write("{} epoch: {} step: {}, loss is {}\n".format(
                    time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss))
            print("epoch: {} step: {}, loss is {}".format(
                cb_params.cur_epoch_num, cur_step_in_epoch, loss))
            self.loss = loss


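# Example: constructing LossCallBack (a sketch; the path is a placeholder). With
# per_print_times=100, a line of the form
# "<timestamp> epoch: E step: S, loss is L" is appended to the file every 100
# global steps; per_print_times=0 disables logging altogether.
#
#     loss_cb = LossCallBack(loss_file_path="./loss.log", per_print_times=100)

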
class TimeMonitor(Callback):
    """
    Time monitor for measuring the time cost of each epoch and each step.

    Args:
        data_size (int): Number of steps in one epoch.
    """
    def __init__(self, data_size):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size
        self.per_step_time = 0
        # Timestamps are refreshed in epoch_begin/step_begin; initialized here for safety.
        self.epoch_time = time.time()
        self.step_time = time.time()

    def epoch_begin(self, run_context):
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / self.data_size
        print("epoch time: {0} ms, per step time: {1} ms".format(epoch_mseconds, per_step_mseconds), flush=True)
        self.per_step_time = per_step_mseconds

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        step_mseconds = (time.time() - self.step_time) * 1000
        print(f"step time: {step_mseconds} ms", flush=True)
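

# Putting the callbacks together (a minimal sketch, assuming `model`, `auc_metric`,
# `train_dataset` and `eval_dataset` are built as in the DeepFM training script;
# the epoch count and file paths are placeholders):
#
#     callbacks = [
#         TimeMonitor(data_size=train_dataset.get_dataset_size()),
#         LossCallBack(loss_file_path="./loss.log"),
#         EvalCallBack(model, eval_dataset, auc_metric, eval_file_path="./auc.log"),
#     ]
#     model.train(15, train_dataset, callbacks=callbacks)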