• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Train utility."""
16import os
17from collections.abc import Iterable
18
19import numpy as np
20
21from mindspore.common.tensor import Tensor
22from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
23from mindspore.common import dtype as mstype
24from mindspore import log as logger
25from mindspore.common.api import _cell_graph_executor
26from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
27from mindspore.train.checkpoint_pb2 import Checkpoint
28from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy
29
30from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
31
32MAX_PATH_LENGTH = 1024
33
34
def _convert_type(types):
    """
    Convert from numpy type to tensor type.

    Args:
        types (list): Numpy type list of element in dataset.

    Returns:
        list, list of element in dataset.
    """
    # Map every numpy element type to its MindSpore dtype counterpart.
    return [pytype_to_dtype(np_type) for np_type in types]
50
51
def _get_types_and_shapes(dataset):
    """Return the (types, shapes) pair describing the dataset's outputs."""
    converted_types = _convert_type(dataset.output_types())
    output_shapes = dataset.output_shapes()
    return converted_types, output_shapes
57
58
def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False):
    """Initialize and execute the dataset graph."""
    queue_batch_size = exec_dataset.get_batch_size()
    queue_input_indexs = exec_dataset.input_indexs

    # Capture output types/shapes before the dataset is wrapped in a device queue.
    queue_types, queue_shapes = _get_types_and_shapes(exec_dataset)

    # The epoch-end flag is sent only for the sentinel size -1.
    exec_dataset = exec_dataset.device_que(send_epoch_end=(dataset_size == -1),
                                           create_data_info_queue=create_data_info_queue)

    _cell_graph_executor.init_dataset(exec_dataset.queue_name,
                                      dataset_size,
                                      queue_batch_size,
                                      queue_types,
                                      queue_shapes,
                                      queue_input_indexs,
                                      phase=phase)
    return exec_dataset
78
79
def _make_directory(path, arg_name='path'):
    """
    Resolve `path` to an absolute directory path, creating the directory if needed.

    Args:
        path (str): Directory path to validate and, if absent, create.
        arg_name (str): Argument name used in error messages. Default: 'path'.

    Returns:
        str, the resolved absolute directory path.

    Raises:
        TypeError: If `path` is not a string, or the directory cannot be created
            for lack of write permission (TypeError is kept, not PermissionError,
            so existing callers catching TypeError keep working).
        ValueError: If `path` is blank or longer than MAX_PATH_LENGTH.
        NotADirectoryError: If `path` exists but is a regular file.
    """
    if not isinstance(path, str):
        logger.error("The %s is invalid, the type should be string.", arg_name)
        raise TypeError("The {} is invalid, the type should be string.".format(arg_name))
    if path.strip() == "":
        logger.error("The %s is invalid, it should be non-blank.", arg_name)
        raise ValueError("The {} is invalid, it should be non-blank.".format(arg_name))

    path = os.path.realpath(path)

    if len(path) > MAX_PATH_LENGTH:
        logger.error("The %s length is too long, it should be limited in %s.", arg_name, MAX_PATH_LENGTH)
        raise ValueError("The {} length is too long, it should be limited in {}.".format(arg_name, MAX_PATH_LENGTH))

    logger.debug("The abs path is %r", path)

    if os.path.exists(path):
        if not os.path.isdir(path):
            logger.error("The path(%r) is a file path, it should be a directory path.", path)
            raise NotADirectoryError("The path({}) is a file path, it should be a directory path.".format(path))
        return path

    logger.debug("The directory(%s) doesn't exist, will create it", path)
    try:
        # Owner-only rwx (0o700): group/other bits are masked out, owner bits
        # go into the creation mode.
        # NOTE(review): os.umask permanently alters the process-wide umask;
        # confirm this side effect is intended.
        permissions = os.R_OK | os.W_OK | os.X_OK
        os.umask(permissions << 3 | permissions)
        mode = permissions << 6
        os.makedirs(path, mode=mode, exist_ok=True)
    except PermissionError as e:
        logger.error("No write permission on the directory(%r), error = %r", path, e)
        # Chain the cause so the original PermissionError is not lost.
        raise TypeError("No write permission on the directory.") from e
    return path
116
117
def _construct_tensor_list(types, shapes, batch_expand_num=1):
    """
    Construct list of tensors with types and shapes, used to initialize the network.

    Args:
        types: List or Tuple. The output types of element in dataset.
        shapes: List or Tuple. The output shapes of element in dataset.
        batch_expand_num (int): Batch expand number.

    Returns:
        List, list of Tensors.
    """
    if len(types) != len(shapes):
        raise ValueError("The length of dataset types must equal to dataset shapes, "
                         "but got dataset types={} and dataset shapes={}".format(types, shapes))
    tensors = []
    for ms_type, shape in zip(types, shapes):
        # Only the leading dimension is scaled by batch_expand_num; the rest
        # of the shape is kept unchanged.
        expanded_shape = tuple(dim * batch_expand_num if axis == 0 else dim
                               for axis, dim in enumerate(shape))
        zero_tensor = Tensor(np.zeros(expanded_shape, dtype_to_nptype(ms_type)))
        zero_tensor.virtual_flag = True
        tensors.append(zero_tensor)
    return tensors
145
146
def _to_tensor(elem, scaling_sens=None):
    """Convert numpy to tensor, adapt to feed the data from host solution."""
    lst = []
    # Normalize a single element into a one-item list so both cases share the loop.
    if not isinstance(elem, (tuple, list)):
        elem = [elem]
    for data in elem:
        if not isinstance(data, np.ndarray):
            # A non-ndarray element means the sequence is presumably already in
            # tensor form — pass the WHOLE original sequence through unconverted,
            # appending the scaling_sens tensor when one is given.
            if scaling_sens:
                elem_tuple = tuple(elem) + (Tensor(scaling_sens, mstype.float32),)
            else:
                elem_tuple = tuple(elem)
            return elem_tuple
        lst.append(Tensor(data))
    if scaling_sens:
        # NOTE(review): truthiness check — a scaling_sens of 0 is silently skipped;
        # confirm that is intended.
        lst.append(Tensor(scaling_sens, mstype.float32))

    # Unwrap a lone tensor; otherwise return the converted tuple.
    return lst[0] if len(lst) == 1 else tuple(lst)
164
165
def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1):
    """Construct tensor list to initialize the network which implemented in dataset sink."""
    # Compile-time tensors scale the leading dimension by device_number;
    # run-time tensors keep it as-is.
    tensors_for_compile = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=device_number)
    tensors_for_run = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=1)
    return tensors_for_run, tensors_for_compile
171
172
173def _check_to_numpy(plugin, tensor):
174    """Check the tensor and return a numpy.ndarray."""
175    np_value = tensor.asnumpy()
176    np_value = np_value.copy()
177    if plugin == 'scalar':
178        if np_value.size == 1:
179            return np_value
180        raise ValueError('The tensor holds more than one value, but the scalar plugin expects on value.')
181    if plugin == 'image':
182        if np_value.ndim == 4:
183            return np_value
184        raise ValueError('The tensor seems not to hold a valid image.')
185    if plugin in ('tensor', 'histogram'):
186        if np_value.ndim > 0:
187            return np_value
188        raise ValueError('The tensor should not be empty.')
189    return np_value
190
191
def _check_lineage_value(plugin, value):
    """Check the lineage value."""
    expected_proto = {
        'dataset_graph': DatasetGraph,
        'eval_lineage': EvaluationLineage,
        'train_lineage': TrainLineage,
        'custom_lineage_data': UserDefinedInfo,
    }
    prototype = expected_proto.get(plugin)
    if prototype is not None and not isinstance(value, prototype):
        raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.')
208
209
def check_value_type(arg_name, arg_value, valid_types):
    """
    Check whether a value is an instance of one of the given types.

    Args:
        arg_name (str): Argument name used in the error message.
        arg_value: Value to check.
        valid_types: A type, or an iterable of types, that are acceptable.

    Raises:
        TypeError: If `arg_value` is not an instance of any valid type, or is a
            bool while bool itself is not listed (bool subclasses int, so it
            would otherwise slip through an `int` check).
    """
    valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,)

    # bool is a subclass of int, so reject a bool unless bool is explicitly
    # allowed. (The original `isinstance(arg_value, int) and isinstance(arg_value,
    # bool)` was redundant: every bool is an int.)
    is_valid = not (isinstance(arg_value, bool) and bool not in valid_types)

    if not isinstance(arg_value, valid_types):
        is_valid = False

    if not is_valid:
        raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
                        f'but got {type(arg_value).__name__}.')
225
226
def read_proto(file_name, proto_format="MINDIR", display_data=False):
    """
    Read protobuf file.

    Args:
        file_name (str): File name.
        proto_format (str): Proto format {MINDIR, CKPT, CKPT_STRATEGY}. Default: MINDIR.
        display_data (bool): Whether display data. Default: False.

    Returns:
        Object, proto object.

    Raises:
        ValueError: If `proto_format` is unsupported or the file cannot be
            read/parsed.
    """
    if proto_format == "MINDIR":
        model = mindir_model()
    elif proto_format == "CKPT":
        model = Checkpoint()
    elif proto_format == "CKPT_STRATEGY":
        model = ckpt_strategy()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as f:
            model.ParseFromString(f.read())
    except Exception as e:
        # Narrowed from BaseException so KeyboardInterrupt/SystemExit still
        # propagate; chain the cause so the original error is not lost.
        logger.error("Failed to read the file `%s`, please check the correctness of the file.", file_name)
        raise ValueError(str(e)) from e

    # Strip bulky raw tensor payloads unless the caller asked to keep them.
    if proto_format == "MINDIR" and not display_data:
        for param_proto in model.graph.parameter:
            param_proto.raw_data = b'\0'

    if proto_format == "CKPT" and not display_data:
        for element in model.value:
            element.tensor.tensor_content = b'\0'

    return model
268