# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train utility."""
import os
from collections.abc import Iterable

import numpy as np

from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
from mindspore.common import dtype as mstype
from mindspore import log as logger
from mindspore.common.api import _cell_graph_executor
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.checkpoint_pb2 import Checkpoint
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy

from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo

# Maximum allowed length (in characters) of a resolved directory path.
MAX_PATH_LENGTH = 1024


def _convert_type(types):
    """
    Convert from numpy type to tensor type.

    Args:
        types (list): Numpy type list of element in dataset.

    Returns:
        list, list of element in dataset.
    """
    return [pytype_to_dtype(np_type) for np_type in types]


def _get_types_and_shapes(dataset):
    """Get dataset types and shapes.

    Args:
        dataset: Dataset object exposing ``output_types()`` and ``output_shapes()``.

    Returns:
        tuple, (list of MindSpore dtypes, list of output shapes).
    """
    dataset_types = _convert_type(dataset.output_types())
    dataset_shapes = dataset.output_shapes()
    return dataset_types, dataset_shapes


def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False):
    """Initialize and execute the dataset graph.

    Args:
        exec_dataset: Dataset to be sent to device.
        dataset_size (int): Number of batches per epoch; -1 means size is
            unknown, in which case an epoch-end flag is sent instead.
        phase (str): Graph phase name. Default: 'dataset'.
        create_data_info_queue (bool): Whether to create a queue carrying
            per-batch type/shape info. Default: False.

    Returns:
        The dataset wrapped as a device queue.
    """
    batch_size = exec_dataset.get_batch_size()
    input_indexs = exec_dataset.input_indexs

    # transform data format
    dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
    # When the dataset size is unknown (-1), an explicit epoch-end signal must
    # be sent through the queue so the device knows where epochs stop.
    send_epoch_end = bool(dataset_size == -1)
    exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end,
                                           create_data_info_queue=create_data_info_queue)

    _cell_graph_executor.init_dataset(exec_dataset.queue_name,
                                      dataset_size,
                                      batch_size,
                                      dataset_types,
                                      dataset_shapes,
                                      input_indexs,
                                      phase=phase)

    return exec_dataset


def _make_directory(path, arg_name='path'):
    """Create the directory `path` if needed and return its absolute path.

    Args:
        path (str): Directory path to validate/create.
        arg_name (str): Argument name used in error messages. Default: 'path'.

    Returns:
        str, the resolved absolute directory path.

    Raises:
        TypeError: If `path` is not a string, or the directory cannot be
            created due to missing write permission (TypeError kept for
            backward compatibility with existing callers).
        ValueError: If `path` is blank or longer than MAX_PATH_LENGTH.
        NotADirectoryError: If `path` exists but is a regular file.
    """
    if not isinstance(path, str):
        logger.error("The %s is invalid, the type should be string.", arg_name)
        raise TypeError("The {} is invalid, the type should be string.".format(arg_name))
    if path.strip() == "":
        logger.error("The %s is invalid, it should be non-blank.", arg_name)
        raise ValueError("The {} is invalid, it should be non-blank.".format(arg_name))

    path = os.path.realpath(path)

    if len(path) > MAX_PATH_LENGTH:
        logger.error("The %s length is too long, it should be limited in %s.", arg_name, MAX_PATH_LENGTH)
        raise ValueError("The {} length is too long, it should be limited in {}.".format(arg_name, MAX_PATH_LENGTH))

    logger.debug("The abs path is %r", path)

    if os.path.exists(path):
        if not os.path.isdir(path):
            logger.error("The path(%r) is a file path, it should be a directory path.", path)
            raise NotADirectoryError("The path({}) is a file path, it should be a directory path.".format(path))
        real_path = path
    else:
        logger.debug("The directory(%s) doesn't exist, will create it", path)
        try:
            # Owner gets rwx (0o700); the umask blocks group/other access so
            # the created directory is private to the current user.
            permissions = os.R_OK | os.W_OK | os.X_OK
            os.umask(permissions << 3 | permissions)
            mode = permissions << 6
            os.makedirs(path, mode=mode, exist_ok=True)
            real_path = path
        except PermissionError as e:
            logger.error("No write permission on the directory(%r), error = %r", path, e)
            # NOTE: TypeError is historical here; kept so existing callers that
            # catch TypeError keep working. Chain the cause for debuggability.
            raise TypeError("No write permission on the directory.") from e
    return real_path


def _construct_tensor_list(types, shapes, batch_expand_num=1):
    """
    Construct list of tensors with types and shapes, used to initialize the network.

    Args:
        types: List or Tuple. The output types of element in dataset.
        shapes: List or Tuple. The output shapes of element in dataset.
        batch_expand_num (int): Batch expand number; multiplies the leading
            (batch) dimension of every shape. Default: 1.

    Returns:
        List, list of Tensors.

    Raises:
        ValueError: If `types` and `shapes` have different lengths.
    """
    if len(types) != len(shapes):
        raise ValueError("The length of dataset types must equal to dataset shapes, "
                         "but got dataset types={} and dataset shapes={}".format(types, shapes))
    tensor_list = []
    for type_, shape in zip(types, shapes):
        # Scale only the first (batch) dimension; keep the rest unchanged.
        new_shape = tuple(dim * batch_expand_num if i == 0 else dim
                          for i, dim in enumerate(shape))
        tensor = Tensor(np.zeros(new_shape, dtype_to_nptype(type_)))
        # Mark as a placeholder tensor used only for graph initialization.
        tensor.virtual_flag = True
        tensor_list.append(tensor)
    return tensor_list


def _to_tensor(elem, scaling_sens=None):
    """Convert numpy to tensor, adapt to feed the data from host solution.

    Args:
        elem: A numpy array, or a tuple/list of numpy arrays. Elements that
            are already Tensors (not ndarrays) cause the input to be returned
            as a tuple unchanged (plus the loss-scale tensor if given).
        scaling_sens: Optional loss-scaling sensitivity; when truthy it is
            appended as a float32 Tensor.

    Returns:
        A single Tensor when one element results, otherwise a tuple.
    """
    lst = []
    if not isinstance(elem, (tuple, list)):
        elem = [elem]
    for data in elem:
        if not isinstance(data, np.ndarray):
            # Mixed/already-converted input: pass everything through as-is
            # (any partially-built `lst` is intentionally discarded).
            if scaling_sens:
                elem_tuple = tuple(elem) + (Tensor(scaling_sens, mstype.float32),)
            else:
                elem_tuple = tuple(elem)
            return elem_tuple
        lst.append(Tensor(data))
    if scaling_sens:
        lst.append(Tensor(scaling_sens, mstype.float32))

    return lst[0] if len(lst) == 1 else tuple(lst)


def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1):
    """Construct tensor list to initialize the network which implemented in dataset sink.

    Args:
        dataset_types (list): MindSpore dtypes of the dataset outputs.
        dataset_shapes (list): Shapes of the dataset outputs.
        device_number (int): Number of devices; expands the batch dimension
            of the compile-time tensors. Default: 1.

    Returns:
        tuple, (run-time tensor list, compile-time tensor list).
    """
    tensor_list_run = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=1)
    tensor_list_compile = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=device_number)
    return tensor_list_run, tensor_list_compile


def _check_to_numpy(plugin, tensor):
    """Check the tensor and return a numpy.ndarray.

    Args:
        plugin (str): Summary plugin kind ('scalar', 'image', 'tensor',
            'histogram', or anything else for a pass-through).
        tensor: Tensor to convert.

    Returns:
        numpy.ndarray, a copy of the tensor's data.

    Raises:
        ValueError: If the data does not match the plugin's expectations.
    """
    np_value = tensor.asnumpy()
    # Copy so later in-place changes to the tensor cannot affect the summary.
    np_value = np_value.copy()
    if plugin == 'scalar':
        if np_value.size == 1:
            return np_value
        raise ValueError('The tensor holds more than one value, but the scalar plugin expects one value.')
    if plugin == 'image':
        # Images are expected as a 4-D array (NCHW/NHWC batch of images).
        if np_value.ndim == 4:
            return np_value
        raise ValueError('The tensor seems not to hold a valid image.')
    if plugin in ('tensor', 'histogram'):
        if np_value.ndim > 0:
            return np_value
        raise ValueError('The tensor should not be empty.')
    return np_value


def _check_lineage_value(plugin, value):
    """Check the lineage value.

    Raises:
        TypeError: If `value` does not match the protobuf type that the given
            lineage `plugin` expects.
    """
    def raises(plugin, prototype):
        raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.')

    if plugin == 'dataset_graph' and not isinstance(value, DatasetGraph):
        raises(plugin, DatasetGraph)

    if plugin == 'eval_lineage' and not isinstance(value, EvaluationLineage):
        raises(plugin, EvaluationLineage)

    if plugin == 'train_lineage' and not isinstance(value, TrainLineage):
        raises(plugin, TrainLineage)

    if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo):
        raises(plugin, UserDefinedInfo)


def check_value_type(arg_name, arg_value, valid_types):
    """Checks whether a value is instance of some types.

    Args:
        arg_name (str): Argument name used in the error message.
        arg_value: Value to check.
        valid_types: A type or iterable of types that are acceptable.

    Raises:
        TypeError: If `arg_value` is not an instance of `valid_types`, or is
            a bool while bool is not explicitly listed.
    """
    valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,)

    # bool is a subclass of int, so a bool value would wrongly pass an `int`
    # check; reject it unless bool is explicitly allowed.
    is_valid = not (isinstance(arg_value, bool) and bool not in valid_types)

    if not isinstance(arg_value, valid_types):
        is_valid = False

    if not is_valid:
        raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
                        f'but got {type(arg_value).__name__}.')


def read_proto(file_name, proto_format="MINDIR", display_data=False):
    """
    Read protobuf file.

    Args:
        file_name (str): File name.
        proto_format (str): Proto format {MINDIR, CKPT, CKPT_STRATEGY}. Default: MINDIR.
        display_data (bool): Whether display data. Default: False.

    Returns:
        Object, proto object.

    Raises:
        ValueError: If `proto_format` is unsupported or the file cannot be
            read/parsed.
    """
    if proto_format == "MINDIR":
        model = mindir_model()
    elif proto_format == "CKPT":
        model = Checkpoint()
    elif proto_format == "CKPT_STRATEGY":
        model = ckpt_strategy()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as f:
            pb_content = f.read()
        model.ParseFromString(pb_content)
    except Exception as e:
        # Narrowed from BaseException so KeyboardInterrupt/SystemExit are not
        # swallowed; chain the cause so the original parse/IO error survives.
        logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name)
        raise ValueError(str(e)) from e

    # Strip bulky raw tensor data unless the caller asked to see it.
    if proto_format == "MINDIR" and not display_data:
        for param_proto in model.graph.parameter:
            param_proto.raw_data = b'\0'

    if proto_format == "CKPT" and not display_data:
        for element in model.value:
            element.tensor.tensor_content = b'\0'

    return model