# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LossMonitor Callback class."""

import numpy as np
from mindspore.common.tensor import Tensor

from ._callback import Callback


class LossMonitor(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, it will terminate training.

    Note:
        If per_print_times is 0, do not print loss.

    Args:
        per_print_times (int): Print the loss every `per_print_times` steps.
            Default: 1.

    Raises:
        ValueError: If per_print_times is not an integer or less than zero.
    """

    def __init__(self, per_print_times=1):
        super(LossMonitor, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("The argument 'per_print_times' must be int and >= 0, "
                             "but got {}".format(per_print_times))
        # Step interval between loss printouts; 0 disables printing.
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        """
        Print training loss at the end of step.

        Args:
            run_context (RunContext): Context of the train running.

        Raises:
            ValueError: If the loss is NAN or INF, to terminate training.
        """
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        # When the network returns multiple outputs, the first entry is
        # taken as the loss. Guard against an empty container so we fail
        # later with a readable message instead of an IndexError here.
        if isinstance(loss, (tuple, list)) and loss:
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        # Reduce a tensor loss (possibly multi-element) to a scalar float.
        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        # cur_step_num counts steps over the whole run; convert it to a
        # 1-based step index within the current epoch.
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)