# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module is to write data into mindrecord.
"""
import os
import sys
import threading
import traceback

from inspect import signature
from functools import wraps

import numpy as np
import mindspore._c_mindrecord as ms
from .common.exceptions import ParamValueError, MRMUnsupportedSchemaError

SUCCESS = ms.MSRStatus.SUCCESS
FAILED = ms.MSRStatus.FAILED
DATASET_NLP = ms.ShardType.NLP
DATASET_CV = ms.ShardType.CV

MIN_HEADER_SIZE = ms.MIN_HEADER_SIZE
MAX_HEADER_SIZE = ms.MAX_HEADER_SIZE
MIN_PAGE_SIZE = ms.MIN_PAGE_SIZE
MAX_PAGE_SIZE = ms.MAX_PAGE_SIZE
MIN_SHARD_COUNT = ms.MIN_SHARD_COUNT
MAX_SHARD_COUNT = ms.MAX_SHARD_COUNT
MIN_CONSUMER_COUNT = ms.MIN_CONSUMER_COUNT
# NOTE: unlike the other limits this is a callable, not a value; it is invoked
# as MAX_CONSUMER_COUNT() each time the upper bound is needed.
MAX_CONSUMER_COUNT = ms.get_max_thread_num
MIN_FILE_SIZE = ms.MIN_FILE_SIZE

VALUE_TYPE_MAP = {"int": ["int32", "int64"], "float": ["float32", "float64"], "str": "string", "bytes": "bytes",
                  "int32": "int32", "int64": "int64", "float32": "float32", "float64": "float64",
                  "ndarray": ["int32", "int64", "float32", "float64"]}

VALID_ATTRIBUTES = ["int32", "int64", "float32", "float64", "string", "bytes"]
VALID_ARRAY_ATTRIBUTES = ["int32", "int64", "float32", "float64"]


class ExceptionThread(threading.Thread):
    """
    Thread that records the outcome of its target instead of losing it.

    After the thread has run, the following attributes describe what happened:

    - res: return value of the target (initialised to SUCCESS).
    - exitcode: 0 initially, set to 1 if the target raised an exception.
    - exception: the exception instance raised by the target, or None.
    - exc_traceback: formatted traceback of that exception, or ''.
    """

    def __init__(self, *args, **kwargs):
        threading.Thread.__init__(self, *args, **kwargs)
        self.res = SUCCESS
        self.exitcode = 0
        self.exception = None
        self.exc_traceback = ''

    def run(self):
        """Run the target and capture either its return value or its exception."""
        try:
            if self._target:
                self.res = self._target(*self._args, **self._kwargs)
        except Exception as e:  # pylint: disable=W0703
            # Deliberately broad: the exception is stored on the thread object
            # so that whoever inspects the thread can examine it.
            self.exitcode = 1
            self.exception = e
            self.exc_traceback = ''.join(traceback.format_exception(*sys.exc_info()))


def check_filename(path, arg_name=""):
    """
    check the filename in the path.

    Args:
        path (str): the path.
        arg_name (str): name of the argument being checked, used in the
            "None or empty" error message. Default: "" (reported as "File path").

    Raises:
        ParamValueError: If path is None or empty, is not a string, ends with
            '/', or if its file name contains a forbidden character or
            starts/ends with a space.

    Returns:
        Bool, always True when the file name is valid (invalid names raise).
    """
    if arg_name == "":
        arg_name = "File path"
    else:
        arg_name = "'{}'".format(arg_name)
    if not path:
        raise ParamValueError('{} is not allowed None or empty!'.format(arg_name))
    if not isinstance(path, str):
        raise ParamValueError("File path: {} is not string.".format(path))
    if path.endswith("/"):
        raise ParamValueError("File path can not end with '/'")
    # Only the last path component is validated below.
    file_name = os.path.basename(path)

    # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
    # '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>',
    # '*', '(', '%', ')', '-', '=', '{', '?', '$'
    forbidden_symbols = set(r'\/:*?"<>|`&\';')

    if set(file_name) & forbidden_symbols:
        raise ParamValueError(r"File name should not contains \/:*?\"<>|`&;\'")

    if file_name.startswith(' ') or file_name.endswith(' '):
        raise ParamValueError("File name should not start/end with space.")

    return True


def check_parameter(func):
    """
    decorator for parameter check

    Before calling the wrapped function, the arguments are bound to its
    signature and two parameters are validated when present:

    - file_name: a single path or a list of paths, each checked with
      check_filename.
    - num_consumer: must be an int between MIN_CONSUMER_COUNT and
      MAX_CONSUMER_COUNT() inclusive.

    Args:
        func (Callable): the function whose arguments are to be checked.

    Raises:
        ParamValueError: If file_name or num_consumer is invalid.

    Returns:
        Callable, the wrapped function.
    """
    sig = signature(func)

    @wraps(func)
    def wrapper(*args, **kw):
        bound = sig.bind(*args, **kw)
        for name, value in bound.arguments.items():
            if name == 'file_name':
                if isinstance(value, list):
                    for f in value:
                        check_filename(f)
                else:
                    check_filename(value)
            if name == 'num_consumer':
                if value is None:
                    raise ParamValueError("Parameter num_consumer is None.")
                if isinstance(value, int):
                    if value < MIN_CONSUMER_COUNT or value > MAX_CONSUMER_COUNT():
                        raise ParamValueError("Parameter num_consumer: {} should between {} and {}."
                                              .format(value, MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
                else:
                    raise ParamValueError("Parameter num_consumer is not int.")
        return func(*args, **kw)

    return wrapper


def populate_data(raw, blob, columns, blob_fields, schema):
    """
    Reconstruct data from raw and blob data.

    Args:
        raw (Dict): Data contain primitive data like "int32", "int64", "float32", "float64", "string", "bytes".
        blob (Bytes): Data contain bytes and ndarray data.
        columns(List): List of column name which will be populated.
        blob_fields (List): Refer to the field which data stored in blob.
        schema(Dict): Dict of Schema

    Raises:
        MRMUnsupportedSchemaError: If schema is invalid.

    Returns:
        Dict, the raw fields that appear in schema, plus one entry per loaded
        blob field (an ndarray when the schema declares a shape, otherwise bytes).
    """
    if raw:
        # remove dummy fields
        raw = {k: v for k, v in raw.items() if k in schema}
    else:
        raw = {}
    if not blob_fields:
        return raw

    # With an explicit column selection only the selected blob fields are
    # decoded; otherwise every blob field is.
    if columns:
        loaded_columns = [column for column in columns if column in blob_fields]
    else:
        loaded_columns = blob_fields

    def _render_raw(field, blob_data):
        """Decode one blob field into raw, reshaping it if the schema has a shape."""
        data_type = schema[field]['type']
        data_shape = schema[field]['shape'] if 'shape' in schema[field] else []
        if data_shape:
            try:
                raw[field] = np.reshape(np.frombuffer(blob_data, dtype=data_type), data_shape)
            except ValueError as err:
                # Chain the cause so the underlying numpy error stays visible.
                raise MRMUnsupportedSchemaError('Shape in schema is illegal.') from err
        else:
            raw[field] = blob_data

    for blob_field in loaded_columns:
        _render_raw(blob_field, bytes(blob[blob_field]))
    return raw