# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Callback related classes and functions."""

from contextlib import ExitStack

from mindspore import log as logger
from mindspore.train.serialization import _fill_param_into_net
from mindspore.train.summary.summary_record import _cache_summary_tensor_data

# Network that `checkpoint_cb_for_save_op` fills parameters into.
# Registered through `set_cur_net` and cleared again after each successful fill.
_cur_net = None


def set_cur_net(net):
    """
    Set current net for which we are using to save checkpoint.

    Args:
        net (Cell): train network. Passing None clears the registered net.
    """
    global _cur_net
    _cur_net = net


def checkpoint_cb_for_save_op(parameter_list):
    """
    The checkpoint callback function for MindSpore.

    Will be executed by checkpoint save op.

    Args:
        parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor.

    Returns:
        bool, True if the parameters were filled into the registered net,
        False if no net is currently registered (nothing is updated).
    """
    if _cur_net is None:
        logger.warning("_cur_net is None. parameters are not updated.")
        return False

    logger.info("update parameters in the net.")
    _fill_param_into_net(_cur_net, parameter_list)
    # Drop the registration so the same net is not filled again by a later call.
    set_cur_net(None)
    return True


def summary_cb_for_save_op(summary_list):
    """
    The summary callback function for MindSpore.

    Will be executed by summary op.

    Args:
        summary_list (list): Format is like [{"name": tag_name, "data": tensor},...]
            and value is Scalar/Tensor.

    Returns:
        bool, true: means save summary success. The value is whatever
        `_cache_summary_tensor_data` returns for `summary_list`.
    """
    return _cache_summary_tensor_data(summary_list)


class Callback:
    """
    Abstract base class used to build a callback class. Callbacks are context managers
    which will be entered and exited when passing into the Model.
    You can use this mechanism to initialize and release resources automatically.

    Callback function will execute some operations in the current step or epoch.

    It holds the information of the model. Such as `network`, `train_network`, `epoch_num`, `batch_num`,
    `loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`,
    `cur_step_num`, `dataset_sink_mode`, `net_outputs` and so on.

    Examples:
        >>> from mindspore import Model, nn
        >>> from mindspore.train.callback import Callback
        >>> class Print_info(Callback):
        ...     def step_end(self, run_context):
        ...         cb_params = run_context.original_args()
        ...         print("step_num: ", cb_params.cur_step_num)
        >>>
        >>> print_cb = Print_info()
        >>> dataset = create_custom_dataset()
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        >>> model.train(1, dataset, callbacks=print_cb)
        step_num: 1
    """

    def __enter__(self):
        """Return the enter target."""
        return self

    def __exit__(self, *err):
        """Release resources here if have any."""

    def begin(self, run_context):
        """
        Called once before the network executing.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_begin(self, run_context):
        """
        Called before each epoch beginning.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_end(self, run_context):
        """
        Called after each epoch finished.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_begin(self, run_context):
        """
        Called before each step beginning.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_end(self, run_context):
        """
        Called after each step finished.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def end(self, run_context):
        """
        Called once after network training.

        Args:
            run_context (RunContext): Include some information of the model.
        """


class CallbackManager(Callback):
    """
    Sequential execution of callback functions.

    Execute Callback functions at certain points. Every hook simply forwards to
    the managed callbacks in the order they were given.

    Args:
        callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.

    Raises:
        TypeError: If `callbacks` is neither None, a Callback, nor a list of Callback.
    """

    def __init__(self, callbacks):
        self._callbacks, self._stack = [], None
        if isinstance(callbacks, Callback):
            self._callbacks.append(callbacks)
        elif isinstance(callbacks, list):
            for cb in callbacks:
                if not isinstance(cb, Callback):
                    raise TypeError("When the 'callbacks' is a list, the elements in "
                                    "'callbacks' must be Callback functions.")
                self._callbacks.append(cb)
        elif callbacks is not None:
            raise TypeError("The 'callbacks' is not a Callback or a list of Callback.")

    def __enter__(self):
        """
        Enter every managed callback and return the manager itself.

        Each callback's enter target replaces the callback in the managed list
        when it is itself a Callback; otherwise a warning is logged and the
        original callback is kept. Entering is skipped when the manager has
        already been entered.

        If entering one callback raises, the callbacks entered so far are exited
        (receiving the exception details) before the exception propagates, and
        the manager is left in its not-entered state.
        """
        if self._stack is None:
            entered, stack = [], ExitStack()
            try:
                for callback in self._callbacks:
                    target = stack.enter_context(callback)
                    if not isinstance(target, Callback):
                        logger.warning("Please return 'self' or a Callback as the enter target.")
                        entered.append(callback)
                    else:
                        entered.append(target)
            except BaseException as exc:
                # `with` never calls our __exit__ when __enter__ fails, so unwind
                # the already-entered callbacks here to avoid leaking resources.
                stack.__exit__(type(exc), exc, exc.__traceback__)
                # Always re-raise, even if a callback's __exit__ asked to suppress.
                raise
            # Publish the new state only once every callback entered successfully.
            self._callbacks, self._stack = entered, stack
        return self

    def __exit__(self, *err):
        """
        Exit all entered callbacks in reverse order.

        Returns:
            The result of the underlying ExitStack's `__exit__`, or False when
            the manager was never entered (there is nothing to release).
        """
        if self._stack is None:
            return False
        return self._stack.__exit__(*err)

    def begin(self, run_context):
        """Called once before network training."""
        for cb in self._callbacks:
            cb.begin(run_context)

    def epoch_begin(self, run_context):
        """Called before each epoch begin."""
        for cb in self._callbacks:
            cb.epoch_begin(run_context)

    def epoch_end(self, run_context):
        """Called after each epoch finished."""
        for cb in self._callbacks:
            cb.epoch_end(run_context)

    def step_begin(self, run_context):
        """Called before each step begin."""
        for cb in self._callbacks:
            cb.step_begin(run_context)

    def step_end(self, run_context):
        """Called after each step finished."""
        for cb in self._callbacks:
            cb.step_end(run_context)

    def end(self, run_context):
        """Called once after network training."""
        for cb in self._callbacks:
            cb.end(run_context)


class _MissingParamError(AttributeError, KeyError):
    """
    Raised when an InternalCallbackParam attribute does not exist.

    It is an AttributeError so that `hasattr` and `getattr` with a default work,
    and still a KeyError so that existing `except KeyError` handlers keep working.
    """


class InternalCallbackParam(dict):
    """
    Internal callback object's parameters.

    A dict whose items can also be read and written as attributes:
    `params.x` is `params["x"]` and `params.x = v` is `params["x"] = v`.
    """

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods still win.
        try:
            return self[key]
        except KeyError:
            raise _MissingParamError(key) from None

    def __setattr__(self, key, value):
        self[key] = value


class RunContext:
    """
    Provide information about the model.

    Provide information about original request to model function.
    Callback objects can stop the loop by calling request_stop() of run_context.

    Args:
        original_args (dict): Holding the related information of model.

    Raises:
        TypeError: If `original_args` is not a dict.
    """
    def __init__(self, original_args):
        if not isinstance(original_args, dict):
            raise TypeError("The argument 'original_args' of RunContext should be dict type, "
                            "but got {}.".format(type(original_args)))
        self._original_args = original_args
        self._stop_requested = False

    def original_args(self):
        """
        Get the _original_args object.

        Returns:
            Dict, an object that holds the original arguments of model.
        """
        return self._original_args

    def request_stop(self):
        """
        Set stop requirement during training.

        Callbacks can use this function to request stop of iterations.
        model.train() checks whether this is called or not.
        """
        self._stop_requested = True

    def get_stop_requested(self):
        """
        Return whether a stop is requested or not.

        Returns:
            bool, if true, model.train() stops iterations.
        """
        return self._stop_requested