• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Train utility."""
16from __future__ import absolute_import
17
18import os
19import json
20from collections.abc import Iterable
21
22import numpy as np
23
24from mindspore.common.tensor import Tensor
25from mindspore._c_expression import Tensor as Tensor_
26from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
27from mindspore.common import dtype as mstype
28from mindspore import log as logger
29from mindspore import _checkparam as Validator
30from mindspore.common.api import _cell_graph_executor
31from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
32from mindspore.train.checkpoint_pb2 import Checkpoint
33from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy
34from mindspore.train.lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
35from mindspore.parallel._parallel_serialization import _make_dir
36from mindspore.ops.operations import debug_ops
37
38
def _convert_type(types):
    """
    Convert the numpy element types of a dataset into MindSpore dtypes.

    Args:
        types (list): Numpy type list of element in dataset.

    Returns:
        list, the result of ``pytype_to_dtype`` for each entry of `types`, in the same order.
    """
    return [pytype_to_dtype(np_type) for np_type in types]
54
55
def _get_types_and_shapes(dataset):
    """
    Get dataset types and shapes.

    Returns:
        tuple, ``(types, shapes)`` where `types` is ``dataset.output_types()`` passed through
        ``_convert_type`` and `shapes` is ``dataset.output_shapes()`` unchanged.
    """
    converted_types = _convert_type(dataset.output_types())
    return converted_types, dataset.output_shapes()
61
62
def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False):
    """
    Initialize and execute the dataset graph.

    Args:
        exec_dataset: Dataset whose ``device_que`` is created and registered with the graph executor.
        dataset_size: Passed to ``init_dataset``; a value of ``-1`` sets ``send_epoch_end=True``
            on the device queue.
        phase (str): Phase used to look up the queue name and passed to ``init_dataset``.
            Default: ``'dataset'``.
        create_data_info_queue (bool): Forwarded to ``device_que``. Default: ``False``.

    Returns:
        The object returned by ``exec_dataset.device_que(...)``.
    """
    batch_size = exec_dataset.get_batch_size()
    input_indexs = exec_dataset.input_indexs
    dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
    end_of_epoch = bool(dataset_size == -1)

    # The executor may not know a queue name for this phase; fall back to an empty name.
    name = _cell_graph_executor.get_queue_name(phase)
    name = "" if name is None else name

    device_dataset = exec_dataset.device_que(send_epoch_end=end_of_epoch,
                                             create_data_info_queue=create_data_info_queue,
                                             queue_name=name)
    _cell_graph_executor.init_dataset(device_dataset.queue_name, dataset_size, batch_size,
                                      dataset_types, dataset_shapes, input_indexs, phase=phase)
    return device_dataset
84
85
def _make_directory(path, arg_name='path'):
    """
    Make directory.

    Delegates to ``_make_dir`` with the same arguments and returns its result.
    """
    created = _make_dir(path, arg_name)
    return created
89
90
def _construct_tensor_list(types, shapes, batch_expand_num=1):
    """
    Construct list of tensors with types and shapes, used to initialize the network.

    Each tensor is zero-filled and marked with ``virtual_flag = True``.

    Args:
        types: List or Tuple. The output types of element in dataset.
        shapes: List or Tuple. The output shapes of element in dataset.
        batch_expand_num (int): Factor applied to the first dimension of every shape.

    Returns:
        List, list of Tensors.

    Raises:
        ValueError: If `types` and `shapes` have different lengths.
    """
    if len(types) != len(shapes):
        raise ValueError("The length of dataset types must be equal to dataset shapes, "
                         "but got dataset types={} and dataset shapes={}".format(types, shapes))
    placeholders = []
    for dtype_, shape in zip(types, shapes):
        dims = list(shape)
        if dims:
            # Only the leading dimension (presumably the batch dimension) is expanded.
            dims[0] = dims[0] * batch_expand_num
        placeholder = Tensor(np.zeros(tuple(dims), dtype_to_nptype(dtype_)), dtype=dtype_)
        placeholder.virtual_flag = True
        placeholders.append(placeholder)
    return placeholders
118
119
def _to_tensor(elem, scaling_sens=None):
    """
    Convert numpy to tensor, adapt to feed the data from host solution.

    Args:
        elem: A single item or a tuple/list of items.
        scaling_sens: When truthy, appended as a float32 Tensor to the result.

    Returns:
        If every item is a ``numpy.ndarray``, each is wrapped in a Tensor. The result is the single
        element when the resulting list has exactly one entry, otherwise a tuple.
        As soon as a non-ndarray item is met, a tuple of the original, unconverted items is returned
        instead (plus the scaling Tensor when `scaling_sens` is truthy).
    """
    items = elem if isinstance(elem, (tuple, list)) else [elem]
    converted = []
    for item in items:
        if not isinstance(item, np.ndarray):
            passthrough = tuple(items)
            if scaling_sens:
                passthrough += (Tensor(scaling_sens, mstype.float32),)
            return passthrough
        converted.append(Tensor(item))
    if scaling_sens:
        converted.append(Tensor(scaling_sens, mstype.float32))
    if len(converted) == 1:
        return converted[0]
    return tuple(converted)
137
138
def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1):
    """
    Construct tensor list to initialize the network which implemented in dataset sink.

    Returns:
        tuple, ``(tensor_list_run, tensor_list_compile)``. Both come from ``_construct_tensor_list``:
        the run list uses an expand factor of 1, the compile list uses `device_number`.
    """
    run_tensors, compile_tensors = (
        _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=factor)
        for factor in (1, device_number)
    )
    return run_tensors, compile_tensors
144
145
def _check_to_numpy(plugin, tensor, prim=None):
    """
    Validate a summary value for the given plugin and return it as a numpy array.

    Args:
        plugin (str): Plugin name; ``'scalar'``, ``'image'``, ``'tensor'`` and ``'histogram'`` are
            validated, any other name is accepted as-is.
        tensor: Object exposing ``asnumpy()``.
        prim: When truthy, error messages name ``'<Plugin>Summary'``; otherwise ``'SummaryRecord'``.

    Returns:
        numpy.ndarray, a copy of ``tensor.asnumpy()``.

    Raises:
        ValueError: If a ``'scalar'`` value does not hold exactly one element, an ``'image'`` value is
            not 4-dimensional, or a ``'tensor'``/``'histogram'`` value is 0-dimensional.
    """
    np_value = tensor.asnumpy().copy()
    summary_name = plugin.capitalize() + "Summary" if prim else "SummaryRecord"
    if plugin == 'scalar':
        if np_value.size == 1:
            return np_value
        # Report the real shape: len() would only give the first dimension, which is neither the
        # rank nor the element count that the check above actually tests.
        raise ValueError(
            f'For "{summary_name}", the v rank must be less than or equal to 1 and the v must hold '
            f'exactly one element, but got shape {np_value.shape}.')
    if plugin == 'image':
        if np_value.ndim == 4:
            return np_value
        raise ValueError(f'For "{summary_name}", the tensor seems not to hold a valid image.')
    if plugin in ('tensor', 'histogram'):
        if np_value.ndim > 0:
            return np_value
        raise ValueError(f'For "{summary_name}", the value should not be empty.')
    return np_value
165
166
def check_summary_param(summary_name, tag, tensor):
    """
    Checks the tag is valid for summary.

    Args:
        summary_name (str): Summary primitive name such as ``'ScalarSummary'``; the text before
            ``'Summary'``, lower-cased, selects the plugin validated by ``_check_to_numpy``.
        tag (str): Summary tag; must be a non-empty string.
        tensor (Tensor): Value to be recorded.

    Raises:
        TypeError: If `tag` is not a non-empty str, or `tensor` is not a Tensor.
        ValueError: If ``_check_to_numpy`` rejects the value.
    """
    plugin = summary_name.split('Summary')[0].lower()
    try:
        if not isinstance(tag, str) or not tag:
            raise TypeError(f'For "{summary_name}", the name must be valid string, but got "{tag}".')
        if not isinstance(tensor, (Tensor, Tensor_)):
            raise TypeError(f'For "{summary_name}", the parameter "value" expect to be Tensor, '
                            f'but got {type(tensor).__name__}')
        _check_to_numpy(plugin, tensor, prim=True)
    finally:
        # Errors propagate unchanged (no re-wrapping); the summary cache is reset on every path.
        debug_ops.SUMMARY_TENSOR_CACHE = []
183
184
def _check_lineage_value(plugin, value):
    """
    Check the lineage value.

    For the lineage plugins listed below, `value` must be an instance of the matching proto class;
    other plugin names are not checked.

    Raises:
        TypeError: If `value` does not match the type expected for `plugin`.
    """
    expected = (
        ('dataset_graph', DatasetGraph),
        ('eval_lineage', EvaluationLineage),
        ('train_lineage', TrainLineage),
        ('custom_lineage_data', UserDefinedInfo),
    )
    for plugin_name, prototype in expected:
        if plugin == plugin_name and not isinstance(value, prototype):
            raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.')
202
203
def check_value_type(arg_name, arg_value, valid_types):
    """
    Checks whether a value is instance of some types.

    Args:
        arg_name (str): Argument name used in the error message.
        arg_value: Value to check.
        valid_types (Union[type, Iterable[type]]): A single type or an iterable of types.

    Raises:
        TypeError: If `arg_value` is not an instance of any of `valid_types`, or if it is a bool
            while ``bool`` itself is not listed (bool is a subclass of int).
    """
    # A class is always a single type, even if the class object itself is iterable (e.g. an Enum
    # class); otherwise tuple(valid_types) would expand it into its members.
    if isinstance(valid_types, type) or not isinstance(valid_types, Iterable):
        valid_types = (valid_types,)
    else:
        valid_types = tuple(valid_types)

    # bool is a subclass of int, so a bool value is only accepted when bool is listed explicitly.
    if not isinstance(arg_value, valid_types) or (isinstance(arg_value, bool) and bool not in valid_types):
        raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
                        f'but got {type(arg_value).__name__}.')
219
220
def read_proto(file_name, proto_format="MINDIR", display_data=False):
    """
    Read protobuf file.

    Args:
        file_name (str): File name.
        proto_format (str): Proto format {MINDIR, CKPT, CKPT_STRATEGY}.  Default: MINDIR.
        display_data (bool): Whether display data. Default: ``False``.

    Returns:
        Object, proto object. Unless `display_data` is True, the raw parameter data of a MINDIR
        proto and the tensor contents of a CKPT proto are each replaced with a single null byte.

    Raises:
        ValueError: If `proto_format` is unsupported, or the file cannot be read or parsed.
    """
    Validator.check_file_name_by_regular(file_name)
    file_name = os.path.realpath(file_name)
    if proto_format == "MINDIR":
        model = mindir_model()
    elif proto_format == "CKPT":
        model = Checkpoint()
    elif proto_format == "CKPT_STRATEGY":
        model = ckpt_strategy()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as f:
            pb_content = f.read()
        model.ParseFromString(pb_content)
    except Exception as e:
        # Catch ordinary errors only: KeyboardInterrupt/SystemExit must not become a ValueError.
        logger.critical(f"Failed to parse the file: {file_name} as format: {proto_format},"
                        f" please check the correct file and format.")
        raise ValueError(str(e)) from e

    if proto_format == "MINDIR" and not display_data:
        for param_proto in model.graph.parameter:
            param_proto.raw_data = b'\0'

    if proto_format == "CKPT" and not display_data:
        for element in model.value:
            if element.tensor.ByteSize() != 0:
                element.tensor.tensor_content = b'\0'
            else:
                # Entries without a dense tensor: blank every tensor inside the map tensor instead.
                for ele in element.maptensor.tensor:
                    ele.tensor_content = b'\0'

    return model
268
269
def _first_field_values(message):
    """Return the values of the first populated field of a protobuf message as a list, or [] if none."""
    fields = message.ListFields()
    if not fields:
        # ListFields() omits unset/empty fields, e.g. a tensor map with no dimensions.
        return []
    return list(fields[0][1])


def parse_strategy_ckpt(file_name):
    """
    Parses a strategy ckpt layout file and returns the rank location dict.

    Args:
        file_name (str): Strategy ckpt file name.

    Returns:
        Dict, layout dict. Key is parameter name, value is ``[dev_matrix, tensor_map]``, each a list
        of ints. A layout message with no populated field yields an empty list.

    Examples:
        >>> from mindspore.train.utils import parse_strategy_ckpt
        >>> layout_dict = parse_strategy_ckpt("/path/to/strategy.ckpt")
        >>> print(layout_dict)
        {'param1': [[4, 4], [0, -1]], 'param2': [[4, 4], [-1, 0]], ...}
    """
    model = ckpt_strategy()
    with open(file_name, "rb") as f:
        pb_content = f.read()
    model.ParseFromString(pb_content)

    layout_dict = {}
    for param in model.parallel_layout_item:
        layouts = param.parallel_layouts
        dev_matrix = _first_field_values(layouts.dev_matrix[0])
        tensor_map = _first_field_values(layouts.tensor_map[0])
        layout_dict[param.param_name] = [dev_matrix, tensor_map]
    return layout_dict
300
301
def get_parameter_redundancy(layout_obj, initial_rank=0):
    """
    Get parameter redundancy map.

    For every parameter, ranks that end up with identical coordinates along the device-matrix axes
    referenced by the tensor map are grouped together.

    Args:
        layout_obj (Union[str, dict]): File name of `strategy.ckpt` or net.parameter_layout_dict.
            For a dict, the first two items of each value are used as ``(dev_matrix, tensor_map)``.
        initial_rank (int): Start rank id for each pipeline. Default: 0.

    Returns:
        Dict, dict of parameter redundancy info. Each value is a tuple of rank-id tuples, ordered by
        the coordinates that define the group.

    Examples:
        >>> from mindspore.train.utils import get_parameter_redundancy
        >>> param_redundancy_dict = get_parameter_redundancy("/path/to/strategy.ckpt")
        >>> print(param_redundancy_dict)
        {'param1': ((0, 1, 2, 3, 4, 5, 6, 7),),
         'param2': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)),
         'param3': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15)),
         'param4': ((0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15))}
    """
    if isinstance(layout_obj, str):
        parameter_layout = parse_strategy_ckpt(layout_obj)
    else:
        parameter_layout = {k: v[:2] for k, v in layout_obj.items()}

    param_redundancy_dict = {}
    for key, (dev_matrix, tensor_map, *_) in parameter_layout.items():
        # Accept tuples as well as lists for the device matrix.
        dev_matrix = list(dev_matrix)
        num_axes = len(dev_matrix)
        # One cell per rank (row-major over dev_matrix); the trailing axis stores that rank's
        # coordinate along each device-matrix axis used by the tensor map. int64 instead of int8
        # so coordinates of axes larger than 127 neither wrap nor overflow.
        coords = np.zeros(shape=dev_matrix + [num_axes], dtype=np.int64)
        for map_dim in tensor_map:
            if map_dim == -1:
                # -1 entries do not select a device axis (presumably an unsplit tensor dim).
                continue
            trailing = tuple(slice(None) for _ in range(map_dim))
            for coord in range(dev_matrix[-map_dim - 1]):
                coords[(..., coord) + trailing + (map_dim,)] = coord

        groups = {}
        for offset, rank_coords in enumerate(coords.reshape((-1, num_axes)).tolist()):
            groups.setdefault(tuple(rank_coords), []).append(offset + initial_rank)
        param_redundancy_dict[key] = tuple(tuple(ranks) for _, ranks in sorted(groups.items()))
    return param_redundancy_dict
348
349
def _collect_settings_by_rank(redundancy_map):
    """
    Invert a parameter redundancy map so that it is keyed by rank group.

    Every rank group of every parameter becomes a key. Its value is a list of
    ``(param_name, group_index)`` pairs, where ``group_index`` is the position of the group inside
    that parameter's redundancy tuple.

    {"param1": ((1, 3, 5, 7), (2, 4, 6, 8)), "param2": ((1, 3, 5, 7), (2, 4, 6, 8))}
    -> {(1, 3, 5, 7): [("param1", 0), ("param2", 0)], (2, 4, 6, 8): [("param1", 1), ("param2", 1)]}
    """
    redundancy_map_reversed = {}
    for key, redundancy in redundancy_map.items():
        for index, group in enumerate(redundancy):
            redundancy_map_reversed.setdefault(group, []).append((key, index))
    return redundancy_map_reversed
363
364
def _restructure(input_dict):
    """
    Flatten and reorganize the nested dictionary structure.

    For each tuple key, the ``k``-th value is appended under ``key[k % len(key)]``. This pass repeats
    while any key is still a tuple. A dict without tuple keys is returned unchanged (same object).
    """
    current = input_dict
    while any(isinstance(layout, tuple) for layout in current):
        regrouped = {}
        for layout, names in current.items():
            for pos, name in enumerate(names):
                regrouped.setdefault(layout[pos % len(layout)], []).append(name)
        current = regrouped
    return current
375
376
def _rotate_list_elements(i, input_list):
    """
    Rotate element list.

    Returns a new list whose element ``j`` is ``input_list[(i + j) % len(input_list)]``, i.e. the
    input rotated left by `i`. An empty input yields ``[]``.
    """
    size = len(input_list)
    return [input_list[pos % size] for pos in range(i, i + size)]
382
383
def remove_param_redundancy(param_redundancy_dict, keep_redundancy=1):
    """
    Remove parameter redundancy, get the single parameter for each rank id.

    Args:
        param_redundancy_dict (Dict): Parameter redundancy dict, e.g. as returned by
            ``get_parameter_redundancy``.
        keep_redundancy (Int): Keep redundancy number. Default: 1.

    Returns:
        Dict, single parameter for each rank id. Key is rank_id, value is set(params).

    Examples:
        >>> from mindspore.train.utils import get_parameter_redundancy, remove_param_redundancy
        >>> param_redundancy_dict = get_parameter_redundancy("/path/to/strategy.ckpt")
        >>> single_parameter = remove_param_redundancy(param_redundancy_dict)
        >>> print(single_parameter)
        {0: {'param1', 'param3'}, 1: {'param2', 'param4'}, ...}
    """
    # Rank group -> names of the parameters sharing that group (the group index is dropped).
    names_by_group = {
        group: [name for name, _ in entries]
        for group, entries in _collect_settings_by_rank(param_redundancy_dict).items()
    }
    assignment = {}
    for shift in range(keep_redundancy):
        # Each round rotates every group by `shift` before distributing its parameter names
        # round-robin over the group's ranks, so successive rounds start at a different rank.
        rotated = {tuple(_rotate_list_elements(shift, group)): names
                   for group, names in names_by_group.items()}
        for rank_id, names in _restructure(rotated).items():
            assignment.setdefault(rank_id, set()).update(names)
    return assignment
413
414
def parse_hccl_file(hccl_file_path):
    """
    Parses an HCCL configuration JSON file, return a dict key is rank_id, value is device_ip.

    Args:
        hccl_file_path (str): The path to the HCCL configuration JSON file.

    Returns:
        Dict: A Dict, key is rank_id (converted to int), value is device_ip. If a rank_id appears
        more than once, the last occurrence wins.

    Raises:
        KeyError: If ``server_list``, ``device``, ``rank_id`` or ``device_ip`` is missing.

    Examples:
        >>> from mindspore.train.utils import parse_hccl_file
        >>> rankid_dict = parse_hccl_file("/path/to/hccl.json")
        >>> print(rankid_dict)
        {0: '10.11.10.163', 1: '10.11.10.164', 2: '10.11.10.165', 3: '10.11.10.166', ...}
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(hccl_file_path, encoding="utf-8") as f:
        hccl_dict = json.load(f)
    return {
        int(device["rank_id"]): device["device_ip"]
        for server in hccl_dict["server_list"]
        for device in server["device"]
    }
440