# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16High-Level API for Second Order Training or Testing.
17Second-order optimizer THOR reduces the computation workload and improves the computation speed by reducing the
18frequency of updating the second-order matrix. In order to optimize the overall performance, the ModelThor class
19is redefined to inherit the Model class provided by MindSpore. The parameter of THOR for controlling the frequency
20of updating the second-order matrix can be obtained by the ModelThor class.
21"""
from __future__ import absolute_import

import math

from mindspore.train.callback import RunContext
from mindspore import context
from mindspore import nn
from mindspore.train.model import Model
from mindspore.train.dataset_helper import connect_network_with_dataset
from mindspore.parallel._utils import _need_to_full, _to_full_tensor
from mindspore.common.dtype import pytype_to_dtype
from mindspore._c_expression import init_exec_dataset
from mindspore.train.train_thor.dataset_helper import DatasetHelper


def _convert_to_ms_type(types):
    """
    Convert numpy types to MindSpore tensor types.

    Args:
        types (list): Numpy types of the elements in the dataset.

    Returns:
        list, MindSpore tensor types corresponding to the given numpy types.
    """
    ms_types = []
    for numpy_type in types:
        ms_type = pytype_to_dtype(numpy_type)
        ms_types.append(ms_type)
    return ms_types


def _get_types_and_shapes(dataset):
    """Get dataset types and shapes."""
    dataset_types = _convert_to_ms_type(dataset.output_types())
    dataset_shapes = dataset.output_shapes()
    return dataset_types, dataset_shapes


def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
    """Initialize and execute the dataset graph."""
    batch_size = exec_dataset.get_batch_size()
    input_indexs = exec_dataset.input_indexs

    # transform data format
    dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
    init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name,
                      dataset_size,
                      batch_size,
                      dataset_types,
                      dataset_shapes,
                      input_indexs,
                      phase=phase,
                      need_run=False)


class ModelThor(Model):
    """
    High-Level API for Training or Testing.

    `Model` groups layers into an object with training and inference features.

    Args:
        network (Cell): A training or testing network.
        loss_fn (Cell): Objective function. If `loss_fn` is None, the network should contain the
                             logic of loss and gradient calculation, and the logic of parallelism
                             if needed. Default: ``None``.
        optimizer (Cell): Optimizer for updating the weights. Default: ``None``.
        metrics (Union[dict, set]): A dictionary or a set of metrics to be evaluated by the model during
                        training and testing, e.g. {'accuracy', 'recall'}. Default: ``None``.
        eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
                             `eval_network`. Default: ``None``.
        eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the
                             `eval_network` would be passed to metrics; otherwise `eval_indexes` must contain three
                             elements: the positions of the loss value, the predicted value and the label. The loss
                             value would be passed to the `Loss` metric, and the predicted value and label would be
                             passed to the other metrics. Default: ``None``.
        amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
            precision training. Supports ["O0", "O2", "O3"]. Default: "O0".

            - O0: Do not change.
            - O2: Cast the network to float16, keep batchnorm running in float32, use dynamic loss scale.
            - O3: Cast the network to float16, with the additional property 'keep_batchnorm_fp32=False'.

            O2 is recommended on GPU, O3 is recommended on Ascend.

        loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled. Otherwise,
            the loss is scaled by the given LossScaleManager. It is a keyword argument,
            e.g. `loss_scale_manager=None`.
        keep_batchnorm_fp32 (bool): Keep batchnorm running in `float32`. If set to True, it overrides the
            earlier `amp_level` setting. Default: ``True``.
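
    Examples:
        A minimal usage sketch; `resnet50`, `ds_train`, the THOR hyper-parameters and the
        `thor` import path below are illustrative assumptions, not values defined in this module.

        >>> from mindspore import nn
        >>> from mindspore.nn import thor
        >>> net = resnet50()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
        >>> opt = thor(net, learning_rate=0.1, damping=0.05, momentum=0.9, frequency=100)
        >>> model = ModelThor(network=net, loss_fn=loss, optimizer=opt, metrics={'accuracy'})
        >>> model.train(2, ds_train, dataset_sink_mode=True)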
    """

    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
                 eval_indexes=None, amp_level="O0", **kwargs):
        super(ModelThor, self).__init__(network, loss_fn, optimizer, metrics, eval_network,
                                        eval_indexes, amp_level, **kwargs)
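        # THOR exposes its second-order update interval through get_frequency(); when the
        # network is already wrapped in a TrainOneStepCell, read it from the wrapped optimizer.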
        if isinstance(network, nn.TrainOneStepCell):
            self._frequency = network.optimizer.get_frequency()
        else:
            self._frequency = optimizer.get_frequency()
        # used to stop training early, for example by a stop-at-time or stop-at-step callback
        self.switch_branch_one = True
        self.index_first_order = 0
        self.train_network_init_flag = True
        self.has_do_dataset_init = False
        self._train_network = self._build_train_network()

    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
                         epoch_num=1, iter_first_order=1):
        """Initializes dataset."""
        if dataset_sink_mode and not is_train:
            dataset.__loop_size__ = 1
        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)

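        # On GPU the sink loop feeds data through dataset_helper directly, so the network is
        # fused with the dataset channel only on non-GPU targets.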
        if dataset_sink_mode and context.get_context("device_target") != "GPU":
            network = connect_network_with_dataset(network, dataset_helper)
        network.set_train(is_train)
        network.phase = phase

        return dataset_helper, network

    def _train_gpu_sink_step(self, cb_params, inputs, list_callback, iter_first_order, run_context):
        """Train one GPU sink step, alternating between the second-order graph ('train0')
        and the first-order graph ('train1')."""
        if self.switch_branch_one:
            cb_params.cur_step_num += 1
            if self.train_network_init_flag:
                self._train_network.add_flags_recursive(thor=True)
            self._train_network.phase = 'train0'
            self.switch_branch_one = not self.switch_branch_one
            outputs = self._train_network(*inputs)
            cb_params.net_outputs = outputs
            list_callback.on_train_step_end(run_context)
        else:
            cb_params.cur_step_num += 1
            if self.train_network_init_flag:
                self._train_network.add_flags_recursive(thor=False)
                self.train_network_init_flag = False
            self._train_network.phase = 'train1'
            outputs = self._train_network(*inputs)
            cb_params.net_outputs = outputs
            self.index_first_order += 1
            if self.index_first_order == iter_first_order:
                self.index_first_order = 0
                self.switch_branch_one = not self.switch_branch_one
                list_callback.on_train_step_end(run_context)

    def _train_ascend_sink_step(self, cb_params, train_dataset, iter_first_order, inputs, list_callback, run_context):
        """Train one Ascend sink step, alternating between the second-order graph ('train0')
        and the first-order graph ('train1')."""
        if self.switch_branch_one:
            cb_params.cur_step_num += 1
            if self.train_network_init_flag:
                self._train_network.add_flags_recursive(thor=True)
            self._train_network.phase = 'train0'
        else:
            cb_params.cur_step_num += iter_first_order
            if self.train_network_init_flag:
                self._train_network.add_flags_recursive(thor=False)
                self.train_network_init_flag = False
            self._train_network.phase = 'train1'
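            # The first-order branch sinks iter_first_order steps per call (cur_step_num was
            # advanced by that amount above), so it needs its own dataset graph, initialized
            # once with loop size iter_first_order.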
            if not self.has_do_dataset_init:
                _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
                self.has_do_dataset_init = True
        self.switch_branch_one = not self.switch_branch_one
        outputs = self._train_network(*inputs)
        cb_params.net_outputs = outputs
        list_callback.on_train_step_end(run_context)

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None,
                                    sink_size=-1, initial_epoch=0, valid_infos=None):
        """
        Training process. The data would be passed to the network through the dataset channel.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                                     returned and passed to the network. Otherwise, a tuple (data, label) should
                                     be returned. The data and label would be passed to the network and loss
                                     function respectively.
            list_callback (Callback): Executor of callback list. Default: ``None``.
            cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
            sink_size (int): Controls the amount of data in each sink. Default: -1.
            initial_epoch (int): Epoch at which to start training; useful for resuming a previous
                                 training run. Default: 0.
            valid_infos (tuple): In-training evaluation settings. Evaluation is not supported in the
                                 THOR scenario, so the dataset it carries must be None. Default: ``None``.
        """
        valid_dataset, _, _ = valid_infos
        if valid_dataset:
            raise ValueError("Evaluation in training is currently not supported in the second-order scenario of thor.")
        if sink_size == -1:
            epoch_num = epoch - initial_epoch
        else:
            epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) - initial_epoch

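        # Worked example with hypothetical numbers: epoch=90, sink_size=100, a dataset of
        # 300 steps per epoch and initial_epoch=0 give epoch_num = ceil(90 * 100 / 300) = 30.
        # Each cycle of self._frequency steps then runs one second-order step ('train0')
        # followed by self._frequency - 1 first-order steps ('train1'):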
        iter_first_order = self._frequency - 1
        iter_second_order = 1
        train_dataset.__loop_size__ = iter_second_order
        dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                              is_train=True,
                                                              phase='train',
                                                              dataset=train_dataset,
                                                              dataset_sink_mode=True,
                                                              sink_size=sink_size,
                                                              epoch_num=epoch_num,
                                                              iter_first_order=iter_first_order)

        self._train_network = train_network
        cb_params.train_network = self._train_network
        cb_params.cur_step_num = 0

        run_context = RunContext(cb_params)
        list_callback.on_train_begin(run_context)

        for i in range(initial_epoch, epoch):
            cb_params.cur_epoch_num = i + 1
            list_callback.on_train_epoch_begin(run_context)
            # In data sink mode, dataset_helper iterates only once here; otherwise it iterates
            # epoch_size times.
            for inputs in dataset_helper:
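                # In parallel modes that require full-batch inputs on GPU, expand the per-rank
                # inputs to full tensors before feeding the network.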
                if _need_to_full() and context.get_context("device_target") == "GPU":
                    inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
                list_callback.on_train_step_begin(run_context)
                if context.get_context("device_target") == "GPU":
                    self._train_gpu_sink_step(cb_params, inputs, list_callback, iter_first_order, run_context)
                else:
                    self._train_ascend_sink_step(cb_params, train_dataset, iter_first_order, inputs, list_callback,
                                                 run_context)
            list_callback.on_train_epoch_end(run_context)
            should_stop = run_context.get_stop_requested()
            if should_stop:
                break
        dataset_helper.stop_send()

        list_callback.on_train_end(run_context)


__all__ = ["ModelThor"]