• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16This module is to write data into mindrecord.
17"""
18import os
19import sys
20import threading
21import traceback
22
23from inspect import signature
24from functools import wraps
25
26import numpy as np
27import mindspore._c_mindrecord as ms
28from .common.exceptions import ParamValueError, MRMUnsupportedSchemaError
29
# Status codes and shard (dataset) types re-exported from mindspore._c_mindrecord.
SUCCESS = ms.MSRStatus.SUCCESS
FAILED = ms.MSRStatus.FAILED
DATASET_NLP = ms.ShardType.NLP
DATASET_CV = ms.ShardType.CV

# Size and count limits re-exported from mindspore._c_mindrecord.
MIN_HEADER_SIZE = ms.MIN_HEADER_SIZE
MAX_HEADER_SIZE = ms.MAX_HEADER_SIZE
MIN_PAGE_SIZE = ms.MIN_PAGE_SIZE
MAX_PAGE_SIZE = ms.MAX_PAGE_SIZE
MIN_SHARD_COUNT = ms.MIN_SHARD_COUNT
MAX_SHARD_COUNT = ms.MAX_SHARD_COUNT
MIN_CONSUMER_COUNT = ms.MIN_CONSUMER_COUNT
# NOTE: unlike its neighbours this is a function, not a value; callers must
# invoke MAX_CONSUMER_COUNT() to obtain the limit.
MAX_CONSUMER_COUNT = ms.get_max_thread_num
MIN_FILE_SIZE = ms.MIN_FILE_SIZE

# Numeric schema type names (presumably the ones allowed for array-valued fields).
VALID_ARRAY_ATTRIBUTES = ["int32", "int64", "float32", "float64"]
# The numeric type names above plus "string" and "bytes" (a new, independent list).
VALID_ATTRIBUTES = VALID_ARRAY_ATTRIBUTES + ["string", "bytes"]

# Python/numpy type name -> schema type name(s); a list value holds several
# candidate schema types. Each list is its own object, never shared.
VALUE_TYPE_MAP = {
    "int": ["int32", "int64"],
    "float": ["float32", "float64"],
    "str": "string",
    "bytes": "bytes",
    "int32": "int32",
    "int64": "int64",
    "float32": "float32",
    "float64": "float64",
    "ndarray": list(VALID_ARRAY_ATTRIBUTES),
}
51
52
class ExceptionThread(threading.Thread):
    """
    Thread that records the outcome of its target instead of letting an exception escape.

    After the thread has run, the joining thread can inspect:

    Attributes:
        res: Return value of the target. Stays ``SUCCESS`` when there is no target
            or when the target raised.
        exitcode (int): 0 if the target returned normally, 1 if it raised.
        exception (Exception): The exception raised by the target, or None.
        exc_traceback (str): Formatted traceback of that exception, or '' if none.
    """
    def __init__(self, *args, **kwargs):
        threading.Thread.__init__(self, *args, **kwargs)
        self.res = SUCCESS
        self.exitcode = 0
        self.exception = None
        self.exc_traceback = ''

    def run(self):
        try:
            if self._target is not None:
                self.res = self._target(*self._args, **self._kwargs)
        except Exception as e:  # pylint: disable=W0703
            # Capture instead of propagating so the caller can inspect the failure after join().
            self.exitcode = 1
            self.exception = e
            self.exc_traceback = ''.join(traceback.format_exception(*sys.exc_info()))
        finally:
            # Same as threading.Thread.run: drop references to the target and its arguments
            # so they can be freed once the thread is done. This avoids a reference cycle
            # when an argument refers back to this thread.
            del self._target, self._args, self._kwargs
70
71
def check_filename(path, arg_name=""):
    """
    Check the file name in the path.

    Args:
        path (str): The path to check.
        arg_name (str, optional): Name of the argument being checked, used in error
            messages. When empty (default), the value is reported as "File path".

    Raises:
        ParamValueError: If path is None or empty, is not a string, or ends with '/'.
        ParamValueError: If the base name contains a forbidden character (see
            ``forbidden_symbols`` below), or starts or ends with a space.

    Returns:
        bool, True when the path is valid.
    """
    arg_name = "File path" if arg_name == "" else "'{}'".format(arg_name)
    if not path:
        raise ParamValueError('{} is not allowed None or empty!'.format(arg_name))
    if not isinstance(path, str):
        raise ParamValueError("{}: {} is not string.".format(arg_name, path))
    if path.endswith("/"):
        raise ParamValueError("{} can not end with '/'".format(arg_name))
    file_name = os.path.basename(path)

    # Characters rejected in the base name. Only the final path component is checked;
    # directory components are not validated here.
    forbidden_symbols = set(r'\/:*?"<>|`&\';')

    if set(file_name) & forbidden_symbols:
        raise ParamValueError(r"File name should not contains \/:*?\"<>|`&;\'")

    if file_name.startswith(' ') or file_name.endswith(' '):
        raise ParamValueError("File name should not start/end with space.")

    return True
110
111
def check_parameter(func):
    """
    Decorator that validates selected arguments of ``func`` before calling it.

    Only the arguments the caller actually passed are inspected; defaults are not
    bound, so they are not checked. Two parameter names are recognised:

    * ``file_name``: a single path, or a list of paths, each checked with ``check_filename``.
    * ``num_consumer``: must be a non-None int between MIN_CONSUMER_COUNT and
      MAX_CONSUMER_COUNT() inclusive.

    The returned wrapper raises ParamValueError when a recognised argument is invalid.
    """
    func_signature = signature(func)

    @wraps(func)
    def wrapper(*args, **kw):
        passed = func_signature.bind(*args, **kw).arguments
        for arg_name, arg_value in passed.items():
            if arg_name == 'file_name':
                # Treat a single path as a one-element list.
                paths = arg_value if isinstance(arg_value, list) else [arg_value]
                for path in paths:
                    check_filename(path)
            elif arg_name == 'num_consumer':
                if arg_value is None:
                    raise ParamValueError("Parameter num_consumer is None.")
                if not isinstance(arg_value, int):
                    raise ParamValueError("Parameter num_consumer is not int.")
                if not MIN_CONSUMER_COUNT <= arg_value <= MAX_CONSUMER_COUNT():
                    raise ParamValueError("Parameter num_consumer: {} should between {} and {}."
                                          .format(arg_value, MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
        return func(*args, **kw)

    return wrapper
140
141
def populate_data(raw, blob, columns, blob_fields, schema):
    """
    Reconstruct data from raw and blob data.

    Args:
        raw (Dict): Data contain primitive data like "int32", "int64", "float32", "float64", "string", "bytes".
            Keys not present in ``schema`` are dropped; None or empty is treated as {}.
        blob (Dict): Blob data keyed by field name. Each value is converted with ``bytes()``
            before it is decoded.
        columns (List): Column names to populate. Only blob fields listed here are decoded;
            when None or empty, every field in ``blob_fields`` is decoded. Raw fields are not
            filtered by ``columns``.
        blob_fields (List): Names of the fields whose data is stored in ``blob``.
        schema (Dict): Maps field name to its schema entry. ``schema[field]['type']`` is used
            as the numpy dtype, and the optional ``schema[field]['shape']`` as the target shape.

    Returns:
        Dict, the filtered raw fields plus one entry per decoded blob field. The entry is a
        numpy array of the schema type and shape when the field has a non-empty shape;
        otherwise it is the blob's bytes.

    Raises:
        MRMUnsupportedSchemaError: If the blob data cannot be decoded into the schema's
            type and shape.
    """
    # Remove dummy fields that are not declared in the schema.
    raw = {key: value for key, value in raw.items() if key in schema} if raw else {}
    if not blob_fields:
        return raw

    # Decode only the requested blob fields, or all of them when no columns are given.
    if columns:
        loaded_columns = [column for column in columns if column in blob_fields]
    else:
        loaded_columns = blob_fields

    for field in loaded_columns:
        blob_data = bytes(blob[field])
        data_type = schema[field]['type']
        data_shape = schema[field].get('shape', [])
        if not data_shape:
            # No shape declared: store the bytes unchanged.
            raw[field] = blob_data
            continue
        try:
            raw[field] = np.reshape(np.frombuffer(blob_data, dtype=data_type), data_shape)
        except ValueError as err:
            # Chain the numpy error so the underlying size/shape mismatch stays visible.
            raise MRMUnsupportedSchemaError('Shape in schema is illegal.') from err
    return raw
186