# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model."""
from collections.abc import Iterable

import os
import math
import numpy as np

from mindspore import log as logger
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, Validator
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
    _check_task_sink_envs
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..boost import AutoBoost
from ..context import ParallelMode
from ..parallel._cost_model_context import _set_multi_subgraphs
from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp
from ..common.api import _pynative_executor


def _transfer_tensor_to_tuple(inputs):
    """
    If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
    """
    if isinstance(inputs, Tensor):
        return (inputs,)

    return inputs


class _StepSync(Callback):
    """Internal callback that synchronizes the PyNative executor at the end of each step."""

    @staticmethod
    def step_end(run_context):
        _pynative_executor.sync()


59class Model:
60    """
61    High-Level API for training or inference.
62
63    `Model` groups layers into an object with training and inference features.
64
65    Args:
66        network (Cell): A training or testing network.
        loss_fn (Cell): Objective function. If `loss_fn` is None, the
                             network should contain the logic of loss and gradients calculation,
                             and parallel computation if needed. Default: None.
70        optimizer (Cell): Optimizer for updating the weights. Default: None.
        metrics (Union[dict, set]): A dictionary or a set of metrics to be evaluated by the model during
                        training and inference, e.g. {'accuracy', 'recall'}. Default: None.
73        eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
74                             `eval_network` . Default: None.
75        eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the
76                             `eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
77                             elements, including the positions of loss value, predicted value and label. The loss
78                             value would be passed to the `Loss` metric, the predicted value and label would be passed
                             to other metrics. Default: None.
80        amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network` , level for mixed
81            precision training. Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
82
83            - O0: Do not change.
            - O2: Cast the network to float16, keep batchnorm running in float32, and use dynamic loss scale.
            - O3: Cast the network to float16, with the additional property `keep_batchnorm_fp32=False` .
            - auto: Set the level to the recommended level for the device in use: O2 on GPU and
              O3 on Ascend. The recommended level is chosen from expert experience and cannot cover
              all scenarios. Users should specify the level explicitly for special networks.

            O2 is recommended on GPU, O3 is recommended on Ascend. A more detailed explanation of the `amp_level`
            setting can be found at `mindspore.amp.build_train_network` .
92        boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
93            training. Supports ["O0", "O1", "O2"]. Default: "O0".
94
95            - O0: Do not change.
96            - O1: Enable the boost mode, the performance is improved by about 20%, and
97              the accuracy is the same as the original accuracy.
98            - O2: Enable the boost mode, the performance is improved by about 30%, and
99              the accuracy is reduced by less than 3%.
100    Examples:
101        >>> from mindspore import Model, nn
102        >>>
103        >>> class Net(nn.Cell):
104        ...     def __init__(self, num_class=10, num_channel=1):
105        ...         super(Net, self).__init__()
106        ...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
107        ...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
108        ...         self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
109        ...         self.fc2 = nn.Dense(120, 84, weight_init='ones')
110        ...         self.fc3 = nn.Dense(84, num_class, weight_init='ones')
111        ...         self.relu = nn.ReLU()
112        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
113        ...         self.flatten = nn.Flatten()
114        ...
115        ...     def construct(self, x):
116        ...         x = self.max_pool2d(self.relu(self.conv1(x)))
117        ...         x = self.max_pool2d(self.relu(self.conv2(x)))
118        ...         x = self.flatten(x)
119        ...         x = self.relu(self.fc1(x))
120        ...         x = self.relu(self.fc2(x))
121        ...         x = self.fc3(x)
122        ...         return x
123        >>>
124        >>> net = Net()
125        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
126        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
127        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
128        >>> # For details about how to build the dataset, please refer to the tutorial
129        >>> # document on the official website.
130        >>> dataset = create_custom_dataset()
131        >>> model.train(2, dataset)
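        >>> # A hedged variation: the same network with an accuracy metric and auto mixed precision;
        >>> # the {'accuracy'} metric and the "O2" level are illustrative choices, not defaults.
        >>> model_amp = Model(net, loss_fn=loss, optimizer=optim, metrics={'accuracy'}, amp_level="O2")
        >>> model_amp.train(2, dataset)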
132    """
133
134    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
135                 eval_indexes=None, amp_level="O0", boost_level="O0", **kwargs):
136        self._network = network
137        self._loss_fn = loss_fn
138        self._optimizer = optimizer
139        self._loss_scale_manager = None
140        self._loss_scale_manager_set = False
141        self._keep_bn_fp32 = True
142        self._check_kwargs(kwargs)
143        self._amp_level = amp_level
144        self._boost_level = boost_level
145        self._eval_network = eval_network
146        self._process_amp_args(kwargs)
147        self._parallel_mode = _get_parallel_mode()
148        self._device_number = _get_device_num()
149        self._global_rank = _get_global_rank()
150        self._parameter_broadcast = _get_parameter_broadcast()
151        self._metrics = metrics
152
153        self._check_amp_level_arg(optimizer, amp_level)
154        self._check_for_graph_cell(kwargs)
155        self._build_boost_network(kwargs)
156        self._train_network = self._build_train_network()
157        self._build_eval_network(metrics, self._eval_network, eval_indexes)
158        self._build_predict_network()
159
160    def _check_for_graph_cell(self, kwargs):
161        """Check for graph cell"""
162        if not isinstance(self._network, nn.GraphCell):
163            return
164        if self._amp_level != "O0":
165            logger.warning("amp_level will not work when network is a GraphCell.")
166
167        if self._loss_fn is not None or self._optimizer is not None:
168            raise ValueError("For 'Model', 'loss_fn' and 'optimizer' should be None when network is a GraphCell, "
169                             "but got 'loss_fn': {}, 'optimizer': {}.".format(self._loss_fn, self._optimizer))
170        if kwargs:
171            raise ValueError("For 'Model', the '**kwargs' argument should be empty when network is a GraphCell.")
172
173    def _process_amp_args(self, kwargs):
174        if self._amp_level in ["O0", "O3"]:
175            self._keep_bn_fp32 = False
176        if 'keep_batchnorm_fp32' in kwargs:
177            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
178        if 'loss_scale_manager' in kwargs:
179            self._loss_scale_manager = kwargs['loss_scale_manager']
180            self._loss_scale_manager_set = True
181
182    def _check_amp_level_arg(self, optimizer, amp_level):
183        if optimizer is None and amp_level != "O0":
184            raise ValueError(
                "Auto mixed precision will not work because 'optimizer' is None. Please set amp_level='O0' "
                "to disable auto mixed precision, or set 'optimizer' to a valid optimizer to use auto mixed precision.")
187
188    def _check_kwargs(self, kwargs):
189        for arg in kwargs:
190            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
191                raise ValueError(f"The argument in 'kwargs' should be 'loss_scale_manager' or "
192                                 f"'keep_batchnorm_fp32', but got '{arg}'.")
193
194    def _check_reuse_dataset(self, dataset):
195        if not hasattr(dataset, '__model_hash__'):
196            dataset.__model_hash__ = hash(self)
197        if hasattr(dataset, '__model_hash__') and dataset.__model_hash__ != hash(self):
198            raise RuntimeError('The Dataset cannot be bound to different models, please create a new dataset.')
199
200    def _build_boost_network(self, kwargs):
201        """Build the boost network."""
202        processor = AutoBoost(self._boost_level, kwargs)
203        if processor.level not in ["O1", "O2"]:
204            return
205        if self._optimizer is None:
206            logger.warning("In boost mode, the optimizer must be defined.")
207            return
208        if self._eval_network is None and self._metrics is None:
209            logger.warning("In boost mode, the eval_network and metrics cannot be undefined at the same time.")
210            return
211
212        self._network, self._optimizer = processor.network_auto_process_train(self._network, self._optimizer)
213        if self._eval_network is not None:
214            self._eval_network = processor.network_auto_process_eval(self._eval_network)
215
216    def _build_train_network(self):
217        """Build train network"""
218        network = self._network
219        if self._loss_scale_manager is not None and self._optimizer is None:
            raise ValueError("The argument 'optimizer' can not be None when 'loss_scale_manager' is set.")
221
222        if self._optimizer:
223            if self._loss_scale_manager_set:
224                network = amp.build_train_network(network,
225                                                  self._optimizer,
226                                                  self._loss_fn,
227                                                  level=self._amp_level,
228                                                  boost_level=self._boost_level,
229                                                  loss_scale_manager=self._loss_scale_manager,
230                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
231            else:
232                network = amp.build_train_network(network,
233                                                  self._optimizer,
234                                                  self._loss_fn,
235                                                  level=self._amp_level,
236                                                  boost_level=self._boost_level,
237                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
238        elif self._loss_fn:
239            if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
240                network = _VirtualDatasetCell(network)
241            network = nn.WithLossCell(network, self._loss_fn)
        # Note: the case where loss_fn is not None but optimizer is None may still need an explicit check here.
243
244        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
245            network.set_auto_parallel()
246            if self._optimizer is None:
                # In this case, multiple optimizers are supposed to be included in 'self._network'
248                _set_multi_subgraphs()
249        return network
250
251    def _build_eval_network(self, metrics, eval_network, eval_indexes):
252        """Build the network for evaluation."""
253        self._metric_fns = get_metrics(metrics)
254        if not self._metric_fns:
255            return
256
257        if eval_network is not None:
258            if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
259                raise ValueError("The argument 'eval_indexes' must be a list or None. If 'eval_indexes' is a list, "
                                 "its length must be three, but got 'eval_indexes': {}.".format(eval_indexes))
261
262            self._eval_network = eval_network
263            self._eval_indexes = eval_indexes
264        else:
265            if self._loss_fn is None:
                raise ValueError(f"If `metrics` is set, either `eval_network` or `loss_fn` must not be None. Do not"
                                 f" set `metrics` if you don't want an evaluation.\n"
                                 f"If evaluation is required, you need to specify `eval_network`, which will be used"
                                 f" by the framework to evaluate the model.\n"
                                 f"For the simple scenario with one data, one label and one logits, `eval_network` is"
                                 f" optional: you can set either `eval_network` or `loss_fn`. In the latter case, the"
                                 f" framework will automatically build an evaluation network with `network` and"
                                 f" `loss_fn`.")
274
275            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3", "auto"])
276            self._eval_indexes = [0, 1, 2]
277
278        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
279            if self._optimizer:
280                self._eval_network = _VirtualDatasetCell(self._eval_network)
281            if self._optimizer is None:
                # In this case, multiple optimizers are supposed to be included in 'self._network'
283                _set_multi_subgraphs()
284            self._eval_network.set_auto_parallel()
285
286    def _build_predict_network(self):
287        """Build the network for prediction."""
288        self._predict_network = self._network
289        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
290            self._predict_network = _VirtualDatasetCell(self._network)
291            # Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
292            self._predict_network.set_auto_parallel()
293
294    def _clear_metrics(self):
295        """Clear metrics local values."""
296        for metric in self._metric_fns.values():
297            metric.clear()
298
299    def _update_metrics(self, outputs):
300        """Update metrics local values."""
301        if isinstance(outputs, Tensor):
302            outputs = (outputs,)
303        if not isinstance(outputs, tuple):
            raise ValueError(f"The argument 'outputs' should be a tuple, but got {type(outputs)}.")
305
306        if self._eval_indexes is not None and len(outputs) < 3:
307            raise ValueError("The length of 'outputs' must be >= 3, but got {}".format(len(outputs)))
308
309        for metric in self._metric_fns.values():
310            if self._eval_indexes is None:
311                metric.update(*outputs)
312            else:
313                if isinstance(metric, Loss):
314                    metric.update(outputs[self._eval_indexes[0]])
315                else:
316                    metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])
317
318    def _get_metrics(self):
319        """Get metrics local values."""
320        metrics = dict()
321        for key, value in self._metric_fns.items():
322            metrics[key] = value.eval()
323        return metrics
324
325    def _get_scaling_sens(self):
        """Get the scaling sensitivity (loss scale) used for the backward pass."""
327        scaling_sens = 1
328        if self._loss_scale_manager is not None:
329            scaling_sens = self._loss_scale_manager.get_loss_scale()
330        if self._parallel_mode == ParallelMode.DATA_PARALLEL:
331            scaling_sens /= self._device_number
332        return scaling_sens
333
334    def _exec_preprocess(self, is_train, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1, dataset_helper=None):
335        """Initializes dataset."""
336        if is_train:
337            network = self._train_network
338            phase = 'train'
339        else:
340            network = self._eval_network
341            phase = 'eval'
342
343        if dataset_sink_mode and not is_train:
344            dataset.__loop_size__ = 1
345
346        if dataset_helper is None:
347            dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
348
349        if dataset_sink_mode:
350            network = connect_network_with_dataset(network, dataset_helper)
351
352        network.set_train(is_train)
353        network.phase = phase
354
355        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
356            network.set_auto_parallel()
357
358        return dataset_helper, network
359
360    def _warmup_dataset(self, epoch, train_dataset, sink_size=-1):
361        """
362        Trigger dataset pipeline running before graph compiling.
363
364        Args:
365            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator used to start the data pipeline
                                     before graph compiling.
368            sink_size (int): Control the amount of data in each sink. Default: -1.
369        """
370        if sink_size == -1:
371            epoch_num = epoch
372        else:
373            epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
374            train_dataset.__total_batch__ = epoch * sink_size
375        dataset_helper = None
376        dataset_helper, _ = self._exec_preprocess(is_train=True,
377                                                  dataset=train_dataset,
378                                                  dataset_sink_mode=True,
379                                                  sink_size=sink_size,
380                                                  epoch_num=epoch_num,
381                                                  dataset_helper=dataset_helper)
382        train_dataset._dataset_helper = dataset_helper
383        train_dataset._warmup_epoch = epoch
384
385    def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
386        """
387        Initialize compute graphs and data graphs with the sink mode.
388
389        Note:
390            Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.
391
392        Args:
393            train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
394                                     initialized. Default: None.
            valid_dataset (Dataset): An evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
396                                     will be initialized, and `metrics` in `Model` can not be None. Default: None.
397            sink_size (int): Control the amount of data in each sink. Default: -1.
398            epoch (int): Total number of iterations on the data. Default: 1.
399        """
400        if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
401            raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
402
403        if not train_dataset and not valid_dataset:
            raise ValueError("The arguments 'train_dataset' and 'valid_dataset' can not both be None or empty.")
405
406        _device_number_check(self._parallel_mode, self._device_number)
407
408        if train_dataset:
409            _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
410            if self._parameter_broadcast:
411                self._train_network.set_broadcast_flag()
412
413            train_dataset.__no_send__ = True
414            train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
415                                                                        dataset=train_dataset,
416                                                                        dataset_sink_mode=True,
417                                                                        sink_size=sink_size)
418            self._warmup_dataset(epoch, train_dataset, sink_size)
419            self._train_network = train_network
420            if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
421                self._train_network.add_flags_recursive(is_first_iteration=True)
422            for inputs in train_dataset_helper:
423                self._train_network.compile(*inputs)
424                break
425
426        if valid_dataset:
427            if not self._metric_fns:
                raise RuntimeError("If 'valid_dataset' is defined, the metric functions can not be None or empty; "
                                   "you should set the argument 'metrics' for the model.")
430
431            valid_dataset.__no_send__ = True
432            valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False,
433                                                                       dataset=valid_dataset,
434                                                                       dataset_sink_mode=True)
435            self._eval_network = eval_network
436            if context.get_auto_parallel_context("pipeline_stages") > 1:
437                self._eval_network.add_flags_recursive(is_first_iteration=False)
438            for inputs in valid_dataset_helper:
439                self._eval_network.compile(*inputs)
440                break
441
442    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
443        """
444        Training.
445
446        Args:
447            epoch (int): Total number of iterations on the data.
448            train_dataset (Dataset): A training dataset iterator. If there is no
449                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) will be
450                                     returned and passed to the network. Otherwise, a tuple (data, label) will
451                                     be returned. The data and label would be passed to the network and loss
452                                     function respectively.
453            callbacks (list): List of callback objects which should be executed while training. Default: None.
454            dataset_sink_mode (bool): Determine whether the data should be passed through the dataset channel.
455                                      Default: True.
                                      When PyNative mode or the CPU target is configured, the training
                                      process will be performed in dataset non-sink mode.
458            sink_size (int): Control the amount of data in each sink. Default: -1.
459        """
460        epoch = Validator.check_positive_int(epoch)
461        if context.get_context("device_target") == "Ascend" and \
462           context.get_context("mode") == context.GRAPH_MODE and not \
463           _check_task_sink_envs() and \
464           dataset_sink_mode:
465            dataset_sink_mode = False
            logger.warning("Ascend does not support dataset sink mode when running in non-task-sink mode, "
                           "so the training process will be performed in dataset non-sink mode.")
468
469        if self._parameter_broadcast:
470            self._train_network.set_broadcast_flag()
471
472        cb_params = _InternalCallbackParam()
473        cb_params.train_network = self._train_network
474        cb_params.epoch_num = epoch
475        if dataset_sink_mode and sink_size > 0:
476            cb_params.batch_num = sink_size
477        else:
478            cb_params.batch_num = train_dataset.get_dataset_size()
479        cb_params.mode = "train"
480        cb_params.loss_fn = self._loss_fn
481        cb_params.optimizer = self._optimizer
482        cb_params.parallel_mode = self._parallel_mode
483        cb_params.device_number = self._device_number
484        cb_params.train_dataset = train_dataset
485        cb_params.list_callback = self._transform_callbacks(callbacks)
486        if context.get_context("mode") == context.PYNATIVE_MODE:
487            cb_params.list_callback.insert(0, _StepSync())
488            callbacks = cb_params.list_callback
489        cb_params.train_dataset_element = None
490        cb_params.network = self._network
491        if _is_role_pserver() or _is_role_sched():
492            epoch = 1
493
494        # build callback list
495        with _CallbackManager(callbacks) as list_callback:
496            self._check_reuse_dataset(train_dataset)
497            if not dataset_sink_mode:
498                self._train_process(epoch, train_dataset, list_callback, cb_params)
499            elif context.get_context("device_target") == "CPU":
                logger.warning("The CPU target does not support dataset sink mode currently, "
                               "so the training process will be performed in dataset non-sink mode.")
502                self._train_process(epoch, train_dataset, list_callback, cb_params)
503            else:
504                self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
505
506    @staticmethod
507    def _transform_callbacks(callbacks):
508        """Transform callback to a list."""
509        if callbacks is None:
510            return []
511
512        if isinstance(callbacks, Iterable):
513            return list(callbacks)
514
515        return [callbacks]
516
517    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
518        """
519        Training process. The data would be passed to network through dataset channel.
520
521        Args:
522            epoch (int): Total number of iterations on the data.
523            train_dataset (Dataset): A training dataset iterator. If there is no
524                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
525                                     returned and passed to the network. Otherwise, a tuple (data, label) should
526                                     be returned. The data and label would be passed to the network and loss
527                                     function respectively.
528            list_callback (Callback): Executor of callback list. Default: None.
529            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
530            sink_size (int): Control the amount of data in each sink. Default: -1.
531        """
532        is_graph = (context.get_context("mode") == context.GRAPH_MODE)
533        if sink_size == -1:
534            epoch_num = epoch
535        else:
536            epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
537            train_dataset.__total_batch__ = epoch * sink_size
538
539        cb_params.cur_step_num = 0
540        cb_params.dataset_sink_mode = True
541
542        run_context = RunContext(cb_params)
543        list_callback.begin(run_context)
        # used to stop training early, e.g. by callbacks such as StopAtTime or StopAtStep
545        should_stop = False
546        dataset_helper = None
547        if hasattr(train_dataset, '_dataset_helper'):
548            dataset_helper = train_dataset._dataset_helper
549        for i in range(epoch):
550            cb_params.cur_epoch_num = i + 1
551            list_callback.epoch_begin(run_context)
552            dataset_helper, train_network = self._exec_preprocess(is_train=True,
553                                                                  dataset=train_dataset,
554                                                                  dataset_sink_mode=True,
555                                                                  sink_size=sink_size,
556                                                                  epoch_num=epoch_num,
557                                                                  dataset_helper=dataset_helper)
558
559            self._train_network = train_network
560            cb_params.train_network = self._train_network
561
            # In dataset sink mode, dataset_helper only iterates once; otherwise it iterates epoch_size times.
563            for inputs in dataset_helper:
564                cb_params.train_dataset_element = inputs
565                list_callback.step_begin(run_context)
566                outputs = self._train_network(*inputs)
567                if is_graph:
568                    cb_params.cur_step_num += dataset_helper.sink_size()
569                else:
570                    cb_params.cur_step_num += 1
571                cb_params.net_outputs = outputs
572                list_callback.step_end(run_context)
573                if _is_role_pserver():
574                    os._exit(0)
575
576            dataset_helper.continue_send()
577            list_callback.epoch_end(run_context)
578            should_stop = should_stop or run_context.get_stop_requested()
579            if should_stop:
580                break
581        dataset_helper.stop_send()
582        dataset_helper.release()
583
584        list_callback.end(run_context)
585
586    def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
587        """
588        Training process. The data would be passed to network directly.
589
590        Args:
591            epoch (int): Total number of iterations on the data.
592            train_dataset (Dataset): A training dataset iterator. If there is no
593                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
594                                     returned and passed to the network. Otherwise, a tuple (data, label) should
595                                     be returned. The data and label would be passed to the network and loss
596                                     function respectively.
597            list_callback (Callback): Executor of callback list. Default: None.
598            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
599        """
600        dataset_helper, _ = self._exec_preprocess(is_train=True,
601                                                  dataset=train_dataset,
602                                                  dataset_sink_mode=False,
603                                                  epoch_num=epoch)
604        cb_params.cur_step_num = 0
605        cb_params.dataset_sink_mode = False
606        run_context = RunContext(cb_params)
607        list_callback.begin(run_context)
        # used to stop training early, e.g. by callbacks such as StopAtTime or StopAtStep
609        should_stop = False
610        for i in range(epoch):
611            cb_params.cur_epoch_num = i + 1
612
613            list_callback.epoch_begin(run_context)
614
615            for next_element in dataset_helper:
616                len_element = len(next_element)
617                next_element = _transfer_tensor_to_tuple(next_element)
618                if self._loss_fn and len_element != 2:
619                    raise ValueError("When 'loss_fn' is not None, 'train_dataset' should return "
620                                     "two elements, but got {}, please check the number of elements "
621                                     "returned by 'train_dataset'".format(len_element))
622                cb_params.cur_step_num += 1
623
624                cb_params.train_dataset_element = next_element
625                list_callback.step_begin(run_context)
626                outputs = self._train_network(*next_element)
627                cb_params.net_outputs = outputs
628                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
629                    _, overflow, _ = outputs
630                    overflow = np.all(overflow.asnumpy())
631                    self._loss_scale_manager.update_loss_scale(overflow)
632
633                list_callback.step_end(run_context)
634                if _is_role_pserver():
635                    os._exit(0)
636                should_stop = should_stop or run_context.get_stop_requested()
637                if should_stop:
638                    break
639
640            train_dataset.reset()
641
            # if a parameter has cache enabled, flush its data from cache to host before the epoch ends
643            self._flush_from_cache(cb_params)
644
645            list_callback.epoch_end(run_context)
646            should_stop = should_stop or run_context.get_stop_requested()
647            if should_stop:
648                break
649
650        list_callback.end(run_context)
651
652    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
653        """
654        Training API where the iteration is controlled by python front-end.
655
        When PyNative mode or the CPU target is set, the training process will be performed in dataset non-sink mode.
657
658        Note:
659            If dataset_sink_mode is True, data will be sent to device. If the device is Ascend, features
660            of data will be transferred one by one. The limitation of data transmission per time is 256M.
661            When dataset_sink_mode is True, the step_end method of the Callback class will be executed when
662            the epoch_end method is called.
            If sink_size > 0, each epoch may traverse the dataset any number of times until sink_size
            elements of the dataset have been fetched. The next epoch continues traversing from the end position
            of the previous traversal. This interface builds the computational graphs and then executes them.
            However, if 'model.build' has been executed first, this interface only performs graph execution.
667
668        Args:
            epoch (int): Total number of iterations on the data per epoch.
                         When dataset_sink_mode is set to True and sink_size > 0, each epoch sinks sink_size
                         steps of data instead of the total number of iterations.
672            train_dataset (Dataset): A training dataset iterator. If there is no
673                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
674                                     returned and passed to the network. Otherwise, a tuple (data, label) should
675                                     be returned. The data and label would be passed to the network and loss
676                                     function respectively.
677            callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object,
678                                                            which should be executed while training.
679                                                            Default: None.
680            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
                                      When PyNative mode or the CPU target is configured, the training process
                                      will be performed in dataset non-sink mode. Default: True.
683            sink_size (int): Control the amount of data in each sink.
684                             If sink_size = -1, sink the complete dataset for each epoch.
685                             If sink_size > 0, sink sink_size data for each epoch.
                             If dataset_sink_mode is False, sink_size has no effect.
687                             Default: -1.
688
689        Examples:
690            >>> from mindspore import Model, nn, FixedLossScaleManager
691            >>>
692            >>> # For details about how to build the dataset, please refer to the tutorial
693            >>> # document on the official website.
694            >>> dataset = create_custom_dataset()
695            >>> net = Net()
696            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
697            >>> loss_scale_manager = FixedLossScaleManager()
698            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
699            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
700            >>> model.train(2, dataset)
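            >>> # Hedged variations: train without dataset sinking, or sink 100 batches per epoch;
            >>> # the epoch count and sink_size here are illustrative only.
            >>> model.train(2, dataset, dataset_sink_mode=False)
            >>> model.train(2, dataset, dataset_sink_mode=True, sink_size=100)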
701        """
702        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
703        if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
704            raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
705
706        if hasattr(train_dataset, '_warmup_epoch') and train_dataset._warmup_epoch != epoch:
            raise ValueError("Model.build was used to initialize the model, but the value of parameter `epoch` in "
                             "Model.build is not equal to the value in Model.train, got {} and {} respectively."
709                             .format(train_dataset._warmup_epoch, epoch))
710
711        Validator.check_is_int(sink_size)
712        dataset_size = train_dataset.get_dataset_size()
713        if dataset_size == 0:
            raise ValueError("There is no valid data in the dataset, please check the dataset file first.")
715        if sink_size == -1:
716            sink_size = dataset_size
717        if sink_size < -1 or sink_size == 0:
718            raise ValueError("The argument 'sink_size' must be -1 or positive, but got {}.".format(sink_size))
719
720        _device_number_check(self._parallel_mode, self._device_number)
721
722        self._train(epoch,
723                    train_dataset,
724                    callbacks=callbacks,
725                    dataset_sink_mode=dataset_sink_mode,
726                    sink_size=sink_size)
727
728    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
729        """
730        Build computational graphs and data graphs with the sink mode.
731
732        .. warning::
733            This is an experimental prototype that is subject to change and/or deletion.
734
735        Note:
736            Pre-build process only supports `GRAPH_MODE` and `Ascend` target currently.
            The interface builds the computational graphs; if it is executed first,
            a subsequent 'model.train' only performs graph execution.
            Only dataset sink mode is supported.
740
741        Args:
742            train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
743                                     initialized. Default: None.
744            valid_dataset (Dataset): An evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
745                                     will be initialized, and `metrics` in `Model` can not be None. Default: None.
746            sink_size (int): Control the amount of data in each sink. Default: -1.
747            epoch (int): Control the training epochs. Default: 1.
748
749        Examples:
750            >>> from mindspore import Model, nn, FixedLossScaleManager
751            >>>
752            >>> # For details about how to build the dataset, please refer to the tutorial
753            >>> # document on the official website.
754            >>> dataset = create_custom_dataset()
755            >>> net = Net()
756            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
757            >>> loss_scale_manager = FixedLossScaleManager()
758            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
759            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
760            >>> model.build(dataset, epoch=2)
761            >>> model.train(2, dataset)
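            >>> # Note that the epoch passed to build() must match the one later passed to train();
            >>> # mismatched values raise a ValueError in Model.train.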
762        """
763        self._init(train_dataset, valid_dataset, sink_size, epoch)
764
765    def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
766        """
767        Evaluation. The data would be passed to network through dataset channel.
768
769        Args:
770            valid_dataset (Dataset): Dataset to evaluate the model.
771            list_callback (Callback): Executor of callback list. Default: None.
772            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
773
774        Returns:
775            Dict, which returns the loss value and metrics values for the model in the test mode.
776        """
777        run_context = RunContext(cb_params)
778
779        dataset_helper, eval_network = self._exec_preprocess(is_train=False,
780                                                             dataset=valid_dataset,
781                                                             dataset_sink_mode=True)
782        self._eval_network = eval_network
783        cb_params.eval_network = self._eval_network
784        cb_params.dataset_sink_mode = True
785        list_callback.begin(run_context)
786        list_callback.epoch_begin(run_context)
787        for inputs in dataset_helper:
788            cb_params.cur_step_num += 1
789            list_callback.step_begin(run_context)
790            outputs = self._eval_network(*inputs)
791            cb_params.net_outputs = outputs
792            list_callback.step_end(run_context)
793            self._update_metrics(outputs)
794
795        list_callback.epoch_end(run_context)
796        metrics = self._get_metrics()
797        cb_params.metrics = metrics
798        list_callback.end(run_context)
799
800        return metrics
801
802    def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
803        """
804        Evaluation. The data would be passed to network directly.
805
806        Args:
807            valid_dataset (Dataset): Dataset to evaluate the model.
808            list_callback (Callback): Executor of callback list. Default: None.
809            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
810
811        Returns:
812            Dict, which returns the loss value and metrics values for the model in the test mode.
813        """
814        run_context = RunContext(cb_params)
815        cb_params.dataset_sink_mode = False
816        list_callback.begin(run_context)
817        dataset_helper, _ = self._exec_preprocess(is_train=False,
818                                                  dataset=valid_dataset,
819                                                  dataset_sink_mode=False)
820        list_callback.epoch_begin(run_context)
821        for next_element in dataset_helper:
822            cb_params.cur_step_num += 1
823            list_callback.step_begin(run_context)
824            next_element = _transfer_tensor_to_tuple(next_element)
825            outputs = self._eval_network(*next_element)
826            cb_params.net_outputs = outputs
827            list_callback.step_end(run_context)
828            self._update_metrics(outputs)
829
830        list_callback.epoch_end(run_context)
831        valid_dataset.reset()
832        metrics = self._get_metrics()
833        cb_params.metrics = metrics
834        list_callback.end(run_context)
835        return metrics
836
837    def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
838        """
839        Evaluation API where the iteration is controlled by python front-end.
840
        When PyNative mode or the CPU target is configured, the evaluating process is performed in dataset non-sink mode.
842
843        Note:
844            If dataset_sink_mode is True, data will be sent to device. If the device is Ascend, features
845            of data will be transferred one by one. The limitation of data transmission per time is 256M.
846            When dataset_sink_mode is True, the step_end method of the Callback class will be executed when
847            the epoch_end method is called.
848
849        Args:
850            valid_dataset (Dataset): Dataset to evaluate the model.
            callbacks (Optional[list(Callback)]): List of callback objects which should be executed
                while evaluating. Default: None.
853            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
854                Default: True.
855
856        Returns:
857            Dict, the key is the metric name defined by users and the value is the metrics value for
858            the model in the test mode.
859
860        Examples:
861            >>> from mindspore import Model, nn
862            >>>
863            >>> # For details about how to build the dataset, please refer to the tutorial
864            >>> # document on the official website.
865            >>> dataset = create_custom_dataset()
866            >>> net = Net()
867            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
868            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
869            >>> acc = model.eval(dataset, dataset_sink_mode=False)
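            >>> # 'acc' is a dict keyed by metric name, e.g. {'acc': 0.93}; the value shown is illustrative.
            >>> print(acc)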
870        """
871        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
872
873        _device_number_check(self._parallel_mode, self._device_number)
874        if not self._metric_fns:
875            raise ValueError("The model argument 'metrics' can not be None or empty, "
876                             "you should set the argument 'metrics' for model.")
877        if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode is True:
878            raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
879
880        cb_params = _InternalCallbackParam()
881        cb_params.eval_network = self._eval_network
882        cb_params.valid_dataset = valid_dataset
883        cb_params.batch_num = valid_dataset.get_dataset_size()
884        cb_params.mode = "eval"
885        cb_params.cur_step_num = 0
886        cb_params.list_callback = self._transform_callbacks(callbacks)
887        cb_params.network = self._network
888
889        self._clear_metrics()
890
891        if context.get_context("device_target") == "CPU" and dataset_sink_mode:
892            dataset_sink_mode = False
            logger.warning("The CPU target does not support dataset sink mode currently, "
                           "so the evaluating process will be performed in dataset non-sink mode.")
895        if context.get_context("device_target") == "Ascend" and \
896           context.get_context("mode") == context.GRAPH_MODE and not \
897           _check_task_sink_envs() and \
898           dataset_sink_mode:
899            dataset_sink_mode = False
            logger.warning("Ascend does not support dataset sink mode when running in non-task-sink mode, "
                           "so the evaluating process will be performed in dataset non-sink mode.")
902
903        with _CallbackManager(callbacks) as list_callback:
904            if dataset_sink_mode:
905                return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
906            return self._eval_process(valid_dataset, list_callback, cb_params)
907
908    def predict(self, *predict_data):
909        """
910        Generate output predictions for the input samples.
911
        Data could be a single tensor, a list of tensors, or a tuple of tensors.
913
917        Args:
            predict_data (Optional[Tensor, list[Tensor], tuple[Tensor]]): The predict data, which can be a single
                tensor, a list of tensors, or a tuple of tensors.
920
921        Returns:
922            Tensor, array(s) of predictions.
923
924        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
926            >>> from mindspore import Model, Tensor
927            >>>
928            >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
929            >>> model = Model(Net())
930            >>> result = model.predict(input_data)
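            >>> # For a network with several inputs, pass them positionally; a sketch with a
            >>> # hypothetical second input tensor:
            >>> # result = model.predict(input_data1, input_data2)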
931        """
932        self._predict_network.set_train(False)
933        check_input_data(*predict_data, data_class=(int, float, str, None, Tensor))
934        _parallel_predict_check()
935        result = self._predict_network(*predict_data)
936
937        check_output_data(result)
938        return result
939
940    def _infer_train_check(self, train_dataset, dataset_sink_mode, sink_size):
941        """
942        Check arguments of training.
943
944        Args:
945            train_dataset (Dataset): A training dataset iterator.
946            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
947            sink_size (int): Control the amount of data in each sink.
948        """
949        if context.get_context("mode") != context.GRAPH_MODE:
            raise RuntimeError("The pre-compile process that generates the parameter layout for the train network "
                               "only supports GRAPH MODE and Ascend target currently.")
952        if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
953            raise RuntimeError("'infer_train_layout' only supports 'semi_auto_parallel' and 'auto_parallel' "
954                               "mode, but got {}.".format(_get_parallel_mode()))
955        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
956        if not dataset_sink_mode:
957            raise ValueError("Only dataset sink mode is supported for now.")
958        if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
959            raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
960        Validator.check_is_int(sink_size)
961        dataset_size = train_dataset.get_dataset_size()
962        if dataset_size == 0:
            raise ValueError("There is no valid data in the dataset, please check the dataset file first.")
964        if sink_size == -1:
965            sink_size = dataset_size
966        if sink_size < -1 or sink_size == 0:
967            raise ValueError("The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size))
968
969    def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
970        """
971        Generate parameter layout for the train network in auto or semi auto parallel mode.
972        Only dataset sink mode is supported for now.
973
974        .. warning::
975            This is an experimental prototype that is subject to change and/or deletion.
976
977        Note:
978            This is a pre-compile function. The arguments should be the same with model.train() function.
979
980        Args:
981            train_dataset (Dataset): A training dataset iterator. If there is no
982                         loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
983                         returned and passed to the network. Otherwise, a tuple (data, label) should
984                         be returned. The data and label would be passed to the network and loss
985                         function respectively.
986            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
                                      Only dataset sink mode is supported for this interface; setting it to False
                                      will raise an error. Default: True.
989            sink_size (int): Control the amount of data in each sink.
990                             If sink_size = -1, sink the complete dataset for each epoch.
991                             If sink_size > 0, sink sink_size data for each epoch.
                             If dataset_sink_mode is False, sink_size has no effect.
993                             Default: -1.
994
995        Returns:
            Dict, parameter layout dictionary used for loading distributed checkpoints.
997
998        Examples:
999            >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
1000            >>> # mindspore.cn.
1001            >>> import numpy as np
1002            >>> import mindspore as ms
1003            >>> from mindspore import Model, context, Tensor, nn, FixedLossScaleManager
1004            >>> from mindspore.context import ParallelMode
1005            >>> from mindspore.communication import init
1006            >>>
1007            >>> context.set_context(mode=context.GRAPH_MODE)
1008            >>> init()
1009            >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
1010            >>>
1011            >>> # For details about how to build the dataset, please refer to the tutorial
1012            >>> # document on the official website.
1013            >>> dataset = create_custom_dataset()
1014            >>> net = Net()
1015            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
1016            >>> loss_scale_manager = FixedLossScaleManager()
1017            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
1018            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
1019            >>> layout_dict = model.infer_train_layout(dataset)
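            >>> # layout_dict maps each parameter name to its sliced layout; as the Returns section notes,
            >>> # it is typically consumed when loading distributed checkpoints (sketch only, not executed).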
1020        """
1021        self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)
1022
1023        train_dataset.__no_send__ = True
1024        train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
1025                                                                    dataset=train_dataset,
1026                                                                    dataset_sink_mode=dataset_sink_mode,
1027                                                                    sink_size=sink_size)
1028        self._train_network = train_network
1029        for inputs in train_dataset_helper:
1030            self._train_network.compile(*inputs)
1031            break
1032        train_dataset.__model_hash__ = hash(self)
1033        return self._train_network.parameter_layout_dict
1034
1035    def infer_predict_layout(self, *predict_data):
1036        """
1037        Generate parameter layout for the predict network in auto or semi auto parallel mode.
1038
1039        Data could be a single tensor or multiple tensors.
1040
1041        Note:
1042            Batch data should be put together in one tensor.
1043
1044        Args:
1045            predict_data (Tensor): One tensor or multiple tensors of predict data.
1046
1047        Returns:
            Dict, parameter layout dictionary used for loading distributed checkpoints.
1049
1050        Raises:
            RuntimeError: If the context mode is not GRAPH_MODE.
1052
1053        Examples:
1054            >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
1055            >>> # mindspore.cn.
1056            >>> import numpy as np
1057            >>> import mindspore as ms
1058            >>> from mindspore import Model, context, Tensor
1059            >>> from mindspore.context import ParallelMode
1060            >>> from mindspore.communication import init
1061            >>>
1062            >>> context.set_context(mode=context.GRAPH_MODE)
1063            >>> init()
1064            >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
1065            >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
1066            >>> model = Model(Net())
1067            >>> predict_map = model.infer_predict_layout(input_data)
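            >>> # predict_map can then be fed to mindspore.load_distributed_checkpoint as the predict
            >>> # strategy when restoring sharded weights (usage sketch, checkpoint names are placeholders):
            >>> # load_distributed_checkpoint(model.predict_network, ckpt_file_list, predict_map)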
1068        """
1069        if context.get_context("mode") != context.GRAPH_MODE:
            raise RuntimeError("The pre-compile process that generates the parameter layout for the predict network "
                               "only supports GRAPH MODE and Ascend target currently.")
1072        if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
1073            raise RuntimeError('Infer predict layout only supports semi auto parallel and auto parallel mode.')
1074        _parallel_predict_check()
1075        check_input_data(*predict_data, data_class=Tensor)
1076
1077        predict_net = self._predict_network
1078        # Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
1079        predict_net.set_auto_parallel()
1080        predict_net.set_train(False)
1081        predict_net.compile(*predict_data)
1082        return predict_net.parameter_layout_dict
1083
    def _flush_from_cache(self, cb_params):
        """Flush cached data back to the host for parameters that have cache enabled."""
        params = cb_params.train_network.get_parameters()
        for param in params:
            if param.cache_enable:
                Tensor(param).flush_from_cache()

    @property
    def train_network(self):
        """Get the model's train_network."""
        return self._train_network

    @property
    def predict_network(self):
        """Get the model's predict_network."""
        return self._predict_network

    @property
    def eval_network(self):
        """Get the model's eval_network."""
        return self._eval_network

__all__ = ["Model"]