# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Python callback classes for the dataset pipeline.
"""
import threading
from mindspore._c_dataengine import PyDSCallback
from mindspore.train.callback import Callback
import mindspore.dataset as ds
from .validators import check_callback


class DSCallback:
    """
    Abstract base class used to build a dataset callback class.

    Subclasses override one or more of the ``ds_*`` hooks below; only the
    overridden hooks are registered on the runtime object.

    Args:
        step_size (int, optional): The number of steps between adjacent calls to
            ds_step_begin/ds_step_end (Default=1).

    Examples:
        >>> from mindspore.dataset import DSCallback
        >>>
        >>> class PrintInfo(DSCallback):
        ...     def ds_epoch_end(self, ds_run_context):
        ...         print(ds_run_context.cur_epoch_num)
        ...         print(ds_run_context.cur_step_num)
        >>>
        >>> # dataset is an instance of Dataset object
        >>> dataset = dataset.map(operations=op, callbacks=PrintInfo())
    """

    @check_callback
    def __init__(self, step_size=1):
        self.step_size = step_size

    def ds_begin(self, ds_run_context):
        """
        Called before the data pipeline is started.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_epoch_begin(self, ds_run_context):
        """
        Called before a new epoch is started.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_epoch_end(self, ds_run_context):
        """
        Called after an epoch is finished.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_step_begin(self, ds_run_context):
        """
        Called before each step starts.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_step_end(self, ds_run_context):
        """
        Called after each step is finished.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def create_runtime_obj(self):
        """
        Creates a runtime (C++) object from the callback methods defined by the user.

        Returns:
            _c_dataengine.PyDSCallback.

        Raises:
            AttributeError: If the class overrides none of the ds_* hooks.
        """
        runtime_cb = PyDSCallback(self.step_size)
        user_cls = self.__class__
        overridden = False

        # A hook is registered only when the user's class replaced the no-op
        # defined on DSCallback. Each suffix names both the "ds_<suffix>" hook
        # and the matching "set_<suffix>" setter on the runtime object.
        for suffix in ("begin", "epoch_begin", "epoch_end", "step_begin", "step_end"):
            hook_name = "ds_" + suffix
            if getattr(user_cls, hook_name) != getattr(DSCallback, hook_name):
                getattr(runtime_cb, "set_" + suffix)(getattr(self, hook_name))
                overridden = True

        if not overridden:
            # NOTE(review): the message says 6, but only 5 hooks are checked above.
            # The text is kept verbatim in case callers match on it - confirm.
            raise AttributeError("Provided Callback class did not override any of the 6 callback methods.")

        return runtime_cb


class WaitedDSCallback(Callback, DSCallback):
    """
    Abstract base class used to build a dataset callback class that is synchronized with the training callback.

    This class can be used to execute a user defined logic right after the previous step or epoch.
    For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.

    Subclasses override sync_epoch_begin and/or sync_step_begin; the other
    methods defined here are internal plumbing and should not be overridden.

    Args:
        step_size (int, optional): The number of rows in each step. Usually the step size
            will be equal to the batch size (Default=1).

    Examples:
        >>> from mindspore.dataset import WaitedDSCallback
        >>>
        >>> class PrintEpoch(WaitedDSCallback):
        ...     def sync_epoch_begin(self, train_run_context, ds_run_context):
        ...         print(ds_run_context.cur_epoch_num)
        >>>
        >>> my_cb = PrintEpoch(32)
        >>> # dataset is an instance of Dataset object
        >>> dataset = dataset.map(operations=AugOp(), callbacks=my_cb)
        >>> dataset = dataset.batch(32)
        >>> # define the model
        >>> model.train(epochs, dataset, callbacks=[my_cb])
    """

    def __init__(self, step_size=1):
        super().__init__()
        # Assigned after super().__init__() so the caller's value is the one that sticks.
        self.step_size = step_size

        # Set by step_end() and consumed by ds_step_begin().
        self.step_event = threading.Event()
        self.step_run_context = None

        # Set by epoch_end() and consumed by ds_epoch_begin().
        self.epoch_event = threading.Event()
        self.epoch_run_context = None

        # Once True, the ds_* hooks no longer wait on the events.
        self.training_ended = False

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        """
        Called before a new dataset epoch is started and after the previous training epoch is ended.

        Args:
            train_run_context: Include some information of the model with feedback from the previous epoch.
            ds_run_context: Include some information of the dataset pipeline.
        """

    def sync_step_begin(self, train_run_context, ds_run_context):
        """
        Called before a new dataset step is started and after the previous training step is ended.

        Args:
            train_run_context: Include some information of the model with feedback from the previous step.
            ds_run_context: Include some information of the dataset pipeline.
        """

    def epoch_end(self, run_context):
        """
        Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.

        Args:
            run_context: Include some information of the model.
        """
        # Store the context before signalling so the waiter finds it in place.
        self.epoch_run_context = run_context
        self.epoch_event.set()

    def ds_epoch_begin(self, ds_run_context):
        """
        Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.

        Args:
            ds_run_context: Include some information of the pipeline.

        Raises:
            RuntimeError: If epoch_end does not signal within the configured callback timeout.
        """
        # The first epoch has no preceding training epoch to synchronize with.
        if ds_run_context.cur_epoch_num <= 1:
            return

        if not self.training_ended:
            signalled = self.epoch_event.wait(timeout=ds.config.get_callback_timeout())
            self.epoch_event.clear()
            if not signalled:
                raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s).")

        # by the time this thread wakes up, self.epoch_run_context is already available
        self.sync_epoch_begin(self.epoch_run_context, ds_run_context)

    def step_end(self, run_context):
        """
        Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.

        Args:
            run_context: Include some information of the model.
        """
        # Store the context before signalling so the waiter finds it in place.
        self.step_run_context = run_context
        self.step_event.set()

    def ds_step_begin(self, ds_run_context):
        """
        Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.

        Args:
            ds_run_context: Include some information of the pipeline.

        Raises:
            RuntimeError: If step_end does not signal within the configured callback timeout.
        """
        # Nothing to synchronize with until cur_step_num exceeds step_size.
        if ds_run_context.cur_step_num <= self.step_size:
            return

        if not self.training_ended:
            signalled = self.step_event.wait(timeout=ds.config.get_callback_timeout())
            self.step_event.clear()
            if not signalled:
                raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s).")

        # by the time this thread wakes up, self.step_run_context is already available
        self.sync_step_begin(self.step_run_context, ds_run_context)

    def create_runtime_obj(self):
        """
        Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.

        Returns:
            _c_dataengine.PyDSCallback.

        Raises:
            AttributeError: If the class overrides neither sync_step_begin nor sync_epoch_begin.
        """
        runtime_cb = PyDSCallback(self.step_size)
        user_cls = self.__class__
        overridden = False

        # The user overrides sync_*; the runtime object is handed the internal
        # ds_* wrapper, which waits on the training side before calling sync_*.
        wiring = (
            ("sync_step_begin", "set_step_begin", "ds_step_begin"),
            ("sync_epoch_begin", "set_epoch_begin", "ds_epoch_begin"),
        )
        for sync_name, setter_name, wrapper_name in wiring:
            if getattr(user_cls, sync_name) != getattr(WaitedDSCallback, sync_name):
                getattr(runtime_cb, setter_name)(getattr(self, wrapper_name))
                overridden = True

        if not overridden:
            raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")

        return runtime_cb

    def end(self, run_context):
        """
        Internal method, release the wait if training is ended.

        Args:
            run_context: Include some information of the model.
        """
        # Set both events so a blocked ds_epoch_begin/ds_step_begin can proceed,
        # then raise the flag so later calls skip the wait altogether.
        self.epoch_end(run_context)
        self.step_end(run_context)
        self.training_ended = True