• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Callback related classes and functions."""
16
17from contextlib import ExitStack
18
19from mindspore import log as logger
20from mindspore.train.serialization import _fill_param_into_net
21from mindspore.train.summary.summary_record import _cache_summary_tensor_data
22
# Network whose parameters `checkpoint_cb_for_save_op` fills in; None when no network is registered.
_cur_net = None


def set_cur_net(net):
    """
    Register the network that the checkpoint callback should update.

    Args:
        net (Cell): The train network, or None to clear the registration.
    """
    global _cur_net
    _cur_net = net
36
37
def checkpoint_cb_for_save_op(parameter_list):
    """
    The checkpoint callback function for MindSpore.

    Will be executed by checkpoint save op. Fills the parameters in `parameter_list`
    into the network registered with `set_cur_net`, then clears that registration
    (via `set_cur_net(None)`), so a network must be registered again before the next call.

    Args:
        parameter_list (list): Format is like [{"name": name, "data": value}, ...] and value type is Tensor.

    Returns:
        bool, True if the parameters were filled into the registered network;
        False if no network is registered (a warning is logged and nothing is updated).
    """
    if _cur_net is None:
        logger.warning("_cur_net is None. parameters are not updated.")
        return False

    logger.info("update parameters in the net.")
    _fill_param_into_net(_cur_net, parameter_list)
    # Registration is one-shot: clear it once the parameters have been applied.
    set_cur_net(None)
    return True
58
59
def summary_cb_for_save_op(summary_list):
    """
    Summary callback invoked by MindSpore summary ops.

    Passes `summary_list` straight to the summary-record cache and forwards
    whatever that cache reports.

    Args:
        summary_list (list): Entries shaped like [{"name": tag_name, "data": tensor}, ...];
            each value is a Scalar or Tensor.

    Returns:
        bool, the result of `_cache_summary_tensor_data`; True means the summary was saved.
    """
    return _cache_summary_tensor_data(summary_list)
74
75
class Callback:
    """
    Abstract base class used to build a callback class. Callbacks are context managers
    which will be entered and exited when passed into the Model.
    You can use this mechanism to initialize and release resources automatically.

    Subclasses override any of the hooks below (`begin`, `epoch_begin`, `step_begin`,
    `step_end`, `epoch_end`, `end`); the base implementations do nothing.

    Every hook receives a `run_context`; `run_context.original_args()` holds information
    about the model, such as `network`, `train_network`, `epoch_num`, `batch_num`,
    `loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`,
    `cur_epoch_num`, `cur_step_num`, `dataset_sink_mode`, `net_outputs` and so on.

    Examples:
        >>> from mindspore import Model, nn
        >>> from mindspore.train.callback import Callback
        >>> class PrintInfo(Callback):
        ...     def step_end(self, run_context):
        ...         cb_params = run_context.original_args()
        ...         print("step_num:", cb_params.cur_step_num)
        ...
        >>> print_cb = PrintInfo()
        >>> dataset = create_custom_dataset()
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        >>> model.train(1, dataset, callbacks=print_cb)
        step_num: 1
    """

    def __enter__(self):
        """Return the enter target; by default the callback itself."""
        return self

    def __exit__(self, *err):
        """
        Release resources here, if any.

        Returns None, so an exception raised inside the `with` body is not suppressed.
        """

    def begin(self, run_context):
        """
        Called once before the network starts executing.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_begin(self, run_context):
        """
        Called at the beginning of each epoch.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_end(self, run_context):
        """
        Called after each epoch finishes.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_begin(self, run_context):
        """
        Called at the beginning of each step.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_end(self, run_context):
        """
        Called after each step finishes.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def end(self, run_context):
        """
        Called once after network training.

        Args:
            run_context (RunContext): Include some information of the model.
        """
160
161
class CallbackManager(Callback):
    """
    Sequential execution of callback functions.

    Execute Callback functions at certain points: every hook call is forwarded to
    each managed callback in the order the callbacks were given.

    Used as a context manager, it enters every managed callback in order and exits
    them in reverse order. If entering one of them raises, the callbacks already
    entered are exited before the exception propagates.

    Args:
        callbacks (Optional[list[Callback], Callback]): None, callback, or callbacks list.

    Raises:
        TypeError: If `callbacks` is not None, a Callback, or a list of Callback.
    """

    def __init__(self, callbacks):
        self._callbacks, self._stack = [], None
        if isinstance(callbacks, Callback):
            self._callbacks.append(callbacks)
        elif isinstance(callbacks, list):
            for cb in callbacks:
                if not isinstance(cb, Callback):
                    raise TypeError("When the 'callbacks' is a list, the elements in "
                                    "'callbacks' must be Callback functions.")
                self._callbacks.append(cb)
        elif callbacks is not None:
            raise TypeError("The 'callbacks' is not a Callback or a list of Callback.")

    def __enter__(self):
        """
        Enter every managed callback once and return this manager.

        The enter target of each callback replaces it when that target is a Callback;
        otherwise a warning is logged and the original callback is kept. Once entered,
        later calls do not enter the callbacks again.
        """
        if self._stack is None:
            stack, entered = ExitStack(), []
            try:
                for callback in self._callbacks:
                    target = stack.enter_context(callback)
                    if isinstance(target, Callback):
                        entered.append(target)
                    else:
                        logger.warning("Please return 'self' or a Callback as the enter target.")
                        entered.append(callback)
            except BaseException:
                # A failing __enter__ means our __exit__ will never run, so unwind
                # the callbacks entered so far before re-raising.
                stack.close()
                raise
            self._stack, self._callbacks = stack, entered
        return self

    def __exit__(self, *err):
        """Exit the managed callbacks in reverse order; a no-op if never entered."""
        if self._stack is None:
            return False
        return self._stack.__exit__(*err)

    def begin(self, run_context):
        """Called once before network training."""
        for cb in self._callbacks:
            cb.begin(run_context)

    def epoch_begin(self, run_context):
        """Called at the beginning of each epoch."""
        for cb in self._callbacks:
            cb.epoch_begin(run_context)

    def epoch_end(self, run_context):
        """Called after each epoch finishes."""
        for cb in self._callbacks:
            cb.epoch_end(run_context)

    def step_begin(self, run_context):
        """Called at the beginning of each step."""
        for cb in self._callbacks:
            cb.step_begin(run_context)

    def step_end(self, run_context):
        """Called after each step finishes."""
        for cb in self._callbacks:
            cb.step_end(run_context)

    def end(self, run_context):
        """Called once after network training."""
        for cb in self._callbacks:
            cb.end(run_context)
230
231
class _MissingCallbackParamError(AttributeError, KeyError):
    """
    Raised when an absent key is read as an attribute of InternalCallbackParam.

    Deriving from AttributeError lets `hasattr`, `getattr(obj, name, default)` and
    `copy.deepcopy` treat the name as a missing attribute. Deriving from KeyError keeps
    existing `except KeyError` handlers working.
    """


class InternalCallbackParam(dict):
    """
    Internal callback object's parameters.

    A dict whose items can also be read and written as attributes:
    ``params.cur_step_num`` is the same as ``params["cur_step_num"]``.
    """

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails (the name is not a real
        # attribute or dict method), so fall back to the dict's items.
        try:
            return self[key]
        except KeyError:
            raise _MissingCallbackParamError(key) from None

    def __setattr__(self, key, value):
        # Every attribute assignment is stored as a dict item.
        self[key] = value
240
241
class RunContext:
    """
    Provide information about the model.

    Carries the original arguments given to a model function together with a stop
    flag. Callback objects can stop the loop by calling `request_stop()`.

    Args:
        original_args (dict): Holding the related information of model.

    Raises:
        TypeError: If `original_args` is not a dict.
    """

    def __init__(self, original_args):
        if isinstance(original_args, dict):
            self._original_args = original_args
            self._stop_requested = False
            return
        raise TypeError("The argument 'original_args' of RunContext should be dict type, "
                        f"but got {type(original_args)}.")

    def original_args(self):
        """
        Get the `_original_args` object.

        Returns:
            Dict, the very object that was passed to the constructor.
        """
        return self._original_args

    def request_stop(self):
        """
        Set the stop requirement during training.

        Callbacks can use this function to request that iterations stop;
        model.train() checks whether it has been called.
        """
        self._stop_requested = True

    def get_stop_requested(self):
        """
        Report whether a stop has been requested.

        Returns:
            bool, True once `request_stop()` has been called; model.train() then stops iterations.
        """
        return self._stop_requested
285