# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TimeMonitor Callback class."""

import time

from ._callback import Callback


def _is_positive_int(value):
    """Return True if `value` is an ``int`` (but not a ``bool``) greater than zero."""
    # bool is a subclass of int, so without the explicit exclusion `True`
    # would be accepted as a step count of 1.
    return isinstance(value, int) and not isinstance(value, bool) and value > 0


class TimeMonitor(Callback):
    """
    Monitor the time in training.

    At the end of every epoch, prints the elapsed epoch time and the average
    time per step, both in milliseconds.

    Args:
        data_size (int): Number of steps in one epoch, used as the divisor when
            computing the average per-step time. If the run context provides a
            positive int `batch_num` during training, that value is used instead
            and `data_size` is ignored. Default: None.

    Raises:
        ValueError: At the end of an epoch, if the run context provides no
            usable `batch_num` and `data_size` is not a positive int.
    """

    def __init__(self, data_size=None):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size
        # Wall-clock timestamp of the current epoch start; reset in epoch_begin.
        self.epoch_time = time.time()

    def epoch_begin(self, run_context):
        """
        Record time at the beginning of an epoch.

        Args:
            run_context (RunContext): Context of the process running.
        """
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """
        Print the epoch cost time and the average per-step time.

        Args:
            run_context (RunContext): Context of the process running.

        Raises:
            ValueError: If neither `batch_num` from the run context nor
                `data_size` is a positive int.
        """
        epoch_ms = (time.time() - self.epoch_time) * 1000
        cb_params = run_context.original_args()

        # Prefer the step count reported by the run context; fall back to the
        # user-supplied data_size when it is absent or not a positive int.
        batch_num = getattr(cb_params, "batch_num", None)
        step_size = batch_num if _is_positive_int(batch_num) else self.data_size

        if not _is_positive_int(step_size):
            raise ValueError("data_size must be positive int.")

        step_ms = epoch_ms / step_size
        print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_ms, step_ms), flush=True)