• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""High-Level API for Second Order Training or Testing.
16Second-order optimizer THOR reduces the computation workload and improves the computation speed by reducing the
17frequency of updating the second-order matrix. In order to optimize the overall performance, the ModelThor class
18is redefined to inherit the Model class provided by MindSpore. The parameter of THOR for controlling the frequency
19of updating the second-order matrix can be obtained by the ModelThor class. """
20import math
21from mindspore.train.callback import RunContext
22from mindspore import context
23from mindspore import nn
24from mindspore.context import ParallelMode
25from mindspore.train.model import Model
26from mindspore.train.dataset_helper import connect_network_with_dataset
27from mindspore.parallel._utils import _need_to_full, _to_full_tensor
28from mindspore.common.dtype import pytype_to_dtype
29from mindspore._c_expression import init_exec_dataset
30from .dataset_helper import DatasetHelper
31
32
33def _convert_to_ms_type(types):
34    """
35    Convert from numpy type to mindspore tensor type.
36
37    Args:
38        types (list): Numpy type list of element in dataset.
39
40    Returns:
41        list, list of element in dataset.
42    """
43    ms_types = []
44    for numpy_type in types:
45        ms_type = pytype_to_dtype(numpy_type)
46        ms_types.append(ms_type)
47    return ms_types
48
49
def _get_types_and_shapes(dataset):
    """
    Get dataset types and shapes.

    Returns:
        tuple, `(types, shapes)` where `types` is `dataset.output_types()` converted
        by `_convert_to_ms_type` and `shapes` is `dataset.output_shapes()`.
    """
    return _convert_to_ms_type(dataset.output_types()), dataset.output_shapes()
55
56
def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
    """
    Initialize and execute the dataset graph.

    Args:
        exec_dataset: Dataset providing the batch size, input indexes, output
            types/shapes and the transfer queue name.
        dataset_size: Size value forwarded to `init_exec_dataset`.
        phase (str): Phase name forwarded to `init_exec_dataset`. Default: 'dataset'.
    """
    batch = exec_dataset.get_batch_size()
    indexs = exec_dataset.input_indexs

    # Describe the dataset outputs in the form expected by `init_exec_dataset`.
    types, shapes = _get_types_and_shapes(exec_dataset)
    queue_name = exec_dataset.__transfer_dataset__.queue_name
    init_exec_dataset(queue_name, dataset_size, batch, types, shapes, indexs,
                      phase=phase, need_run=False)
72
73
class ModelThor(Model):
    """
    High-Level API for Training or Testing with the THOR optimizer.

    `ModelThor` inherits `Model` and overrides the dataset-sink training loop so that
    the train network alternates between the 'train0' phase (flagged `thor=True`) and
    the 'train1' phase (flagged `thor=False`). The alternation period is derived from
    the optimizer's `get_frequency()`.

    Args:
        network (Cell): A training or testing network.
        loss_fn (Cell): Objective function, if loss_fn is None, the
                             network should contain the logic of loss and grads calculation, and the logic
                             of parallel if needed. Default: None.
        optimizer (Cell): Optimizer for updating the weights. Default: None.
        metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during
                        training and testing. eg: {'accuracy', 'recall'}. Default: None.
        eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
                             `eval_network`. Default: None.
        eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the
                             `eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
                             elements, including the positions of loss value, predicted value and label. The loss
                             value would be passed to the `Loss` metric, the predicted value and label would be passed
                             to other metric. Default: None.
        amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
            precision training. Supports [O0, O2, O3]. Default: "O0".

            - O0: Do not change.
            - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
            - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.

            O2 is recommended on GPU, O3 is recommended on Ascend.

        loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled. Otherwise,
            scale the loss by LossScaleManager. It is a keyword argument (forwarded through `**kwargs`),
            e.g. use `loss_scale_manager=None` to set the value.
        keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If it is set to true, the level setting before
            will be overwritten. Keyword argument forwarded through `**kwargs`. Default: True.
    """

    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
                 eval_indexes=None, amp_level="O0", **kwargs):
        super().__init__(network, loss_fn, optimizer, metrics, eval_network,
                         eval_indexes, amp_level, **kwargs)
        # When the network is already a TrainOneStepCell, the optimizer is taken from the cell.
        thor_optimizer = network.optimizer if isinstance(network, nn.TrainOneStepCell) else optimizer
        self._frequency = thor_optimizer.get_frequency()
        # Used to stop training early, such as stopAtTime or stopAtStep.
        self.should_stop = False
        # True while the next sink step should run the 'train0' phase.
        self.switch_branch_one = True
        # Number of 'train1' steps executed since the last 'train0' step (GPU path).
        self.index_first_order = 0
        # True until both phases have had their `thor` flag set once.
        self.train_network_init_flag = True
        # True once the 'train1_dataset' graph has been initialized (Ascend path).
        self.has_do_dataset_init = False
        self._train_network = self._build_train_network()

    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
                         epoch_num=1, iter_first_order=1):
        """
        Initialize the dataset helper and prepare the network for the given phase.

        Returns:
            tuple, `(dataset_helper, network)`; `network` is the result of
            `connect_network_with_dataset` when sinking on a non-GPU target, otherwise
            the network that was passed in.
        """
        if dataset_sink_mode and not is_train:
            dataset.__loop_size__ = 1
        helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)

        needs_connect = dataset_sink_mode and context.get_context("device_target") != "GPU"
        if needs_connect:
            network = connect_network_with_dataset(network, helper)
        network.set_train(is_train)
        network.phase = phase

        auto_parallel_modes = (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        if self._parallel_mode in auto_parallel_modes:
            network.set_auto_parallel()

        return helper, network

    def _train_gpu_sink_step(self, cb_params, inputs, list_callback, iter_first_order, run_context):
        """
        Run one sink step on GPU.

        One 'train0' step is followed by `iter_first_order` 'train1' steps. `step_end` is
        called after the 'train0' step and after the last 'train1' step of each cycle.
        """
        # Every GPU sink step advances the step counter by exactly one.
        cb_params.cur_step_num += 1

        if self.switch_branch_one:
            if self.train_network_init_flag:
                self._train_network.add_flags_recursive(thor=True)
            self._train_network.phase = 'train0'
            self.switch_branch_one = not self.switch_branch_one
            cb_params.net_outputs = self._train_network(*inputs)
            list_callback.step_end(run_context)
            return

        if self.train_network_init_flag:
            self._train_network.add_flags_recursive(thor=False)
            self.train_network_init_flag = False
        self._train_network.phase = 'train1'
        cb_params.net_outputs = self._train_network(*inputs)
        self.index_first_order += 1
        if self.index_first_order == iter_first_order:
            # The 'train1' run is complete: reset the counter and go back to 'train0'.
            self.index_first_order = 0
            self.switch_branch_one = not self.switch_branch_one
            list_callback.step_end(run_context)

    def _train_ascend_sink_step(self, cb_params, train_dataset, iter_first_order, inputs, list_callback, run_context):
        """
        Run one sink step on Ascend.

        Calls alternate between the 'train0' phase (step counter advances by 1) and the
        'train1' phase (step counter advances by `iter_first_order`). The 'train1_dataset'
        graph is initialized the first time the 'train1' phase is entered.
        """
        run_train0 = self.switch_branch_one
        cb_params.cur_step_num += 1 if run_train0 else iter_first_order

        if run_train0:
            if self.train_network_init_flag:
                self._train_network.add_flags_recursive(thor=True)
            self._train_network.phase = 'train0'
        else:
            if self.train_network_init_flag:
                self._train_network.add_flags_recursive(thor=False)
                self.train_network_init_flag = False
            self._train_network.phase = 'train1'
            if not self.has_do_dataset_init:
                _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
                self.has_do_dataset_init = True

        self.switch_branch_one = not run_train0
        cb_params.net_outputs = self._train_network(*inputs)
        list_callback.step_end(run_context)

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
        """
        Training process. The data would be passed to network through dataset channel.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                                     returned and passed to the network. Otherwise, a tuple (data, label) should
                                     be returned. The data and label would be passed to the network and loss
                                     function respectively.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
            sink_size (int): Control the amount of data in each sink. Default: -1.
        """
        epoch_num = epoch if sink_size == -1 else math.ceil(epoch * sink_size / train_dataset.get_dataset_size())

        # Each cycle is one 'train0' iteration followed by (frequency - 1) 'train1' iterations.
        iter_first_order = self._frequency - 1
        train_dataset.__loop_size__ = 1
        dataset_helper, self._train_network = self._exec_preprocess(self._train_network,
                                                                    is_train=True,
                                                                    phase='train',
                                                                    dataset=train_dataset,
                                                                    dataset_sink_mode=True,
                                                                    sink_size=sink_size,
                                                                    epoch_num=epoch_num,
                                                                    iter_first_order=iter_first_order)

        cb_params.train_network = self._train_network
        cb_params.cur_step_num = 0

        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        for epoch_index in range(1, epoch + 1):
            cb_params.cur_epoch_num = epoch_index
            list_callback.epoch_begin(run_context)
            # In data sink mode dataset_helper iterates only once, otherwise it iterates epoch_size times.
            for inputs in dataset_helper:
                if _need_to_full() and context.get_context("device_target") == "GPU":
                    inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
                list_callback.step_begin(run_context)
                on_gpu = context.get_context("device_target") == "GPU"
                if on_gpu:
                    self._train_gpu_sink_step(cb_params, inputs, list_callback, iter_first_order, run_context)
                else:
                    self._train_ascend_sink_step(cb_params, train_dataset, iter_first_order, inputs, list_callback,
                                                 run_context)
            list_callback.epoch_end(run_context)
            # A stop request is sticky: once set it ends training after the current epoch.
            self.should_stop = self.should_stop or run_context.get_stop_requested()
            if self.should_stop:
                break
        dataset_helper.stop_send()

        list_callback.end(run_context)
249
250
# Names exported by a star-import of this module.
__all__ = ["ModelThor"]
252