• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Python callback class
17"""
18import threading
19from mindspore._c_dataengine import PyDSCallback
20from mindspore.train.callback import Callback
21import mindspore.dataset as ds
22from .validators import check_callback
23
24
class DSCallback:
    """
    Abstract base class used to build a dataset callback class.

    Subclasses override one or more of the ``ds_*`` hook methods below. Only the hooks that a
    subclass actually overrides are registered on the runtime object built by
    :meth:`create_runtime_obj`.

    Args:
        step_size (int, optional): The number of steps between adjacent calls of ds_step_begin
            and ds_step_end (Default=1).

    Examples:
        >>> from mindspore.dataset import DSCallback
        >>>
        >>> class PrintInfo(DSCallback):
        ...     def ds_epoch_end(self, ds_run_context):
        ...         print(ds_run_context.cur_epoch_num)
        ...         print(ds_run_context.cur_step_num)
        >>>
        >>> # dataset is an instance of Dataset object
        >>> dataset = dataset.map(operations=op, callbacks=PrintInfo())
    """

    @check_callback
    def __init__(self, step_size=1):
        self.step_size = step_size

    def ds_begin(self, ds_run_context):
        """
        Called before the data pipeline is started.

        Args:
            ds_run_context (RunContext): Includes some information about the pipeline.
        """

    def ds_epoch_begin(self, ds_run_context):
        """
        Called before a new epoch is started.

        Args:
            ds_run_context (RunContext): Includes some information about the pipeline.
        """

    def ds_epoch_end(self, ds_run_context):
        """
        Called after an epoch is finished.

        Args:
            ds_run_context (RunContext): Includes some information about the pipeline.
        """

    def ds_step_begin(self, ds_run_context):
        """
        Called before each step starts.

        Args:
            ds_run_context (RunContext): Includes some information about the pipeline.
        """

    def ds_step_end(self, ds_run_context):
        """
        Called after each step is finished.

        Args:
            ds_run_context (RunContext): Includes some information about the pipeline.
        """

    def create_runtime_obj(self):
        """
        Creates a runtime (C++) object from the callback methods defined by the user.

        Only the hooks overridden by the subclass are registered on the returned object.

        Returns:
            _c_dataengine.PyDSCallback.

        Raises:
            AttributeError: If the subclass overrides none of the 5 ``ds_*`` callback methods.
        """
        c_cb = PyDSCallback(self.step_size)

        # (hook method name, name of the PyDSCallback setter that registers it)
        hook_setters = (
            ("ds_begin", "set_begin"),
            ("ds_epoch_begin", "set_epoch_begin"),
            ("ds_epoch_end", "set_epoch_end"),
            ("ds_step_begin", "set_step_begin"),
            ("ds_step_end", "set_step_end"),
        )

        at_least_one = False
        for hook_name, setter_name in hook_setters:
            # A hook that was not overridden resolves, on the class, to the very same function
            # object that DSCallback defines; only register hooks that differ from it.
            if getattr(self.__class__, hook_name) is not getattr(DSCallback, hook_name):
                getattr(c_cb, setter_name)(getattr(self, hook_name))
                at_least_one = True

        if not at_least_one:
            raise AttributeError("Provided Callback class did not override any of the 5 callback methods.")

        return c_cb
121
class WaitedDSCallback(Callback, DSCallback):
    """
    Abstract base class used to build a dataset callback class that is synchronized with the training callback.

    This class can be used to execute a user defined logic right after the previous step or epoch.
    For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.

    Subclasses override :meth:`sync_epoch_begin` and/or :meth:`sync_step_begin`. The dataset-side
    hooks block on ``threading.Event`` objects until the corresponding training-side callback
    (``epoch_end`` / ``step_end``) sets them, or until the configured callback timeout expires.
    The same instance must therefore be passed both to the dataset operation and to ``model.train``.

    Args:
       step_size (int, optional): The number of rows in each step. Usually the step size
           will be equal to the batch size (Default=1).

    Examples:
        >>> from mindspore.dataset import WaitedDSCallback
        >>>
        >>> class MyWaitedCallback(WaitedDSCallback):
        ...     def sync_epoch_begin(self, train_run_context, ds_run_context):
        ...         # user logic that needs feedback from the previous training epoch
        ...         print(ds_run_context.cur_epoch_num)
        >>>
        >>> my_cb = MyWaitedCallback(32)
        >>> # dataset is an instance of Dataset object
        >>> dataset = dataset.map(operations=AugOp(), callbacks=my_cb)
        >>> dataset = dataset.batch(32)
        >>> # define the model
        >>> model.train(epochs, dataset, callbacks=[my_cb])
    """

    def __init__(self, step_size=1):
        super().__init__()
        self.step_size = step_size
        # Set by step_end (training side), awaited by ds_step_begin (dataset side).
        self.step_event = threading.Event()
        self.step_run_context = None

        # Set by epoch_end (training side), awaited by ds_epoch_begin (dataset side).
        self.epoch_event = threading.Event()
        self.epoch_run_context = None

        # Once True, the dataset-side hooks stop waiting for training feedback.
        self.training_ended = False

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        """
        Called before a new dataset epoch is started and after the previous training epoch is ended.

        Args:
            train_run_context: Include some information of the model with feedback from the previous epoch.
            ds_run_context: Include some information of the dataset pipeline.
        """

    def sync_step_begin(self, train_run_context, ds_run_context):
        """
        Called before a new dataset step is started and after the previous training step is ended.

        Args:
            train_run_context: Include some information of the model with feedback from the previous step.
            ds_run_context: Include some information of the dataset pipeline.
        """

    def epoch_end(self, run_context):
        """
        Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.

        Args:
          run_context: Include some information of the model.
        """
        # Store the context before signalling so the woken waiter always sees it.
        self.epoch_run_context = run_context
        self.epoch_event.set()

    def ds_epoch_begin(self, ds_run_context):
        """
        Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.

        The first epoch does not wait, since there is no previous training epoch to wait for.

        Args:
          ds_run_context: Include some information of the pipeline.

        Raises:
            RuntimeError: If the training epoch_end signal does not arrive within the callback timeout.
        """
        if ds_run_context.cur_epoch_num > 1:
            if not self.training_ended:
                # Read the timeout once so the error message reports the value actually used.
                timeout = ds.config.get_callback_timeout()
                success = self.epoch_event.wait(timeout=timeout)
                self.epoch_event.clear()
                if not success:
                    raise RuntimeError(f"ds_epoch_begin timed out after {timeout} second(s).")
            # by the time this thread wakes up, self.epoch_run_context is already available
            self.sync_epoch_begin(self.epoch_run_context, ds_run_context)

    def step_end(self, run_context):
        """
        Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.

        Args:
          run_context: Include some information of the model.
        """
        # Store the context before signalling so the woken waiter always sees it.
        self.step_run_context = run_context
        self.step_event.set()

    def ds_step_begin(self, ds_run_context):
        """
        Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.

        Steps up to ``step_size`` do not wait, since no training step has finished yet.

        Args:
            ds_run_context: Include some information of the pipeline.

        Raises:
            RuntimeError: If the training step_end signal does not arrive within the callback timeout.
        """
        if ds_run_context.cur_step_num > self.step_size:
            if not self.training_ended:
                # Read the timeout once so the error message reports the value actually used.
                timeout = ds.config.get_callback_timeout()
                success = self.step_event.wait(timeout=timeout)
                self.step_event.clear()
                if not success:
                    raise RuntimeError(f"ds_step_begin timed out after {timeout} second(s).")
            # by the time this thread wakes up, self.step_run_context is already available
            self.sync_step_begin(self.step_run_context, ds_run_context)

    def create_runtime_obj(self):
        """
        Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.

        Returns:
            _c_dataengine.PyDSCallback.

        Raises:
            AttributeError: If the subclass overrides neither sync_step_begin nor sync_epoch_begin.
        """
        c_cb = PyDSCallback(self.step_size)
        at_least_one = False

        if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin:
            c_cb.set_step_begin(self.ds_step_begin)
            at_least_one = True

        if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin:
            c_cb.set_epoch_begin(self.ds_epoch_begin)
            at_least_one = True

        if not at_least_one:
            raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")

        return c_cb

    def end(self, run_context):
        """
        Internal method, release the wait if training is ended.

        Args:
          run_context: Include some information of the model.
        """
        # Publish the final contexts and the "ended" flag BEFORE releasing the events.
        # Otherwise a dataset thread woken by an event could reach its next hook while
        # training_ended is still False, block again and time out. Any thread that sees
        # training_ended=True then also sees the final contexts.
        self.epoch_run_context = run_context
        self.step_run_context = run_context
        self.training_ended = True
        self.epoch_end(run_context)
        self.step_end(run_context)
258