# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""utils for test"""

import os
import re
import string
import collections
import json

import numpy as np

from mindspore import log as logger


def get_data(dir_name):
    """
    Return raw data of an imagenet-style dataset.

    Args:
        dir_name (str): Path of the imagenet dataset directory.

    Returns:
        list, one dict per sample with keys "file_name", "data" and "label".
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for line in lines:
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            # Skip annotation entries whose image file is missing.
            continue
    return data_list


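# Usage sketch for get_data (the layout below is an assumption for
# illustration, not something this module ships): <dir>/annotation.txt holds
# "filename,label" lines and <dir>/images/ holds the referenced files.
#
#     records = get_data("/path/to/imagenet_sample")
#     # records[0] -> {"file_name": "xxx.jpg", "data": b"...", "label": 0}

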
def get_two_bytes_data(file_name):
    """
    Return raw data of a two-bytes dataset (two binary fields per sample).

    Args:
        file_name (str): Path of the map file that lists the sample files.

    Returns:
        list, one dict per sample with the image and label bytes.
    """
    if not os.path.exists(file_name):
        raise IOError("Map file {} does not exist".format(file_name))
    dir_name = os.path.dirname(file_name)
    with open(file_name, "r") as file_reader:
        lines = file_reader.readlines()
    data_list = []
    row_num = 0
    for line in lines:
        try:
            # Each line of the map file: "<image file> <label file>".
            img, label = line.strip('\n').split(" ")
            with open(os.path.join(dir_name, img), "rb") as file_reader:
                img_data = file_reader.read()
            with open(os.path.join(dir_name, label), "rb") as file_reader:
                label_data = file_reader.read()
            data_json = {"file_name": img,
                         "img_data": img_data,
                         "label_name": label,
                         "label_data": label_data,
                         "id": row_num
                         }
            row_num += 1
            data_list.append(data_json)
        except FileNotFoundError:
            # Skip rows whose image or label file is missing.
            continue
    return data_list


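# Usage sketch for get_two_bytes_data (hypothetical paths): the map file
# "two_bytes.txt" is assumed to contain lines such as "img_0.bin label_0.bin",
# both relative to the map file's directory.
#
#     rows = get_two_bytes_data("/path/to/two_bytes.txt")
#     # rows[0] -> {"file_name": "img_0.bin", "img_data": b"...",
#     #             "label_name": "label_0.bin", "label_data": b"...", "id": 0}

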
def get_multi_bytes_data(file_name, bytes_num=3):
    """
    Return raw data of a multi-bytes dataset (several binary fields per sample).

    Args:
        file_name (str): Path of the map file that lists the sample files.
        bytes_num (int): Number of binary fields to read per sample.

    Returns:
        list, one dict per sample with keys "image_0" .. "image_{bytes_num-1}".
    """
    if not os.path.exists(file_name):
        raise IOError("Map file {} does not exist".format(file_name))
    dir_name = os.path.dirname(file_name)
    with open(file_name, "r") as file_reader:
        lines = file_reader.readlines()
    data_list = []
    row_num = 0
    for line in lines:
        try:
            # Each line of the map file lists the files of one sample.
            img_paths = line.strip('\n').split(" ")
            images = []
            for path in img_paths[:bytes_num]:
                with open(os.path.join(dir_name, path), "rb") as file_reader:
                    images.append(file_reader.read())
            data_json = {"image_{}".format(i): images[i]
                         for i in range(len(images))}
            data_json.update({"id": row_num})
            row_num += 1
            data_list.append(data_json)
        except FileNotFoundError:
            # Skip rows with missing files.
            continue
    return data_list


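# Usage sketch for get_multi_bytes_data (hypothetical paths): with a map file
# whose lines look like "a.bin b.bin c.bin d.bin", only the first bytes_num
# paths of each line are read.
#
#     rows = get_multi_bytes_data("/path/to/multi_bytes.txt", bytes_num=3)
#     # rows[0] -> {"image_0": b"...", "image_1": b"...", "image_2": b"...",
#     #             "id": 0}

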
def get_mkv_data(dir_name):
    """
    Return raw data of the Vehicle_and_Person dataset.

    Args:
        dir_name (str): Path of the Vehicle_and_Person dataset directory.

    Returns:
        list, one dict per sample with the image bytes and its pre-label.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "Image")
    label_dir = os.path.join(dir_name, "prelabel")

    data_list = []
    file_list = os.listdir(label_dir)

    index = 1
    for file in file_list:
        if os.path.splitext(file)[1] == '.json':
            file_path = os.path.join(label_dir, file)

            # Each label file <name>.json corresponds to <name>.jpg in Image/.
            image_name = ''.join([os.path.splitext(file)[0], ".jpg"])
            image_path = os.path.join(img_dir, image_name)

            with open(file_path, "r") as load_f:
                load_dict = json.load(load_f)

            if os.path.exists(image_path):
                with open(image_path, "rb") as file_reader:
                    img = file_reader.read()
                data_json = {"file_name": image_name,
                             "prelabel": str(load_dict),
                             "data": img,
                             "id": index}
                data_list.append(data_json)
            index += 1
    logger.info('{} images are missing'.format(len(file_list) - len(data_list)))
    return data_list


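# Usage sketch for get_mkv_data (hypothetical layout): assumes
# <dir>/prelabel/*.json label files and matching <dir>/Image/*.jpg images.
#
#     samples = get_mkv_data("/path/to/Vehicle_and_Person")
#     # samples[0] -> {"file_name": "000001.jpg", "prelabel": "{...}",
#     #                "data": b"...", "id": 1}

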
def get_nlp_data(dir_name, vocab_file, num):
    """
    Yield raw data of the aclImdb dataset.

    Args:
        dir_name (str): Path of the aclImdb dataset.
        vocab_file (str): Path of the vocabulary file.
        num (int): Number of samples to take from each directory.

    Yields:
        dict, one tokenized sample per review file.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    # Load the vocabulary once instead of once per file.
    dictionary = load_vocab(vocab_file)
    for root, _, files in os.walk(dir_name):
        for index, file_name_extension in enumerate(files):
            if index < num:
                file_path = os.path.join(root, file_name_extension)
                # aclImdb review files are named "<id>_<rating>.txt".
                file_name, _ = file_name_extension.split('.', 1)
                id_, rating = file_name.split('_', 1)
                with open(file_path, 'r') as f:
                    raw_content = f.read()

                # Tokenize on words and punctuation, mapping unknown tokens
                # to [UNK], and wrap the sequence in [CLS] ... [SEP].
                vectors = [dictionary.get('[CLS]')]
                vectors += [dictionary.get(i) if i in dictionary
                            else dictionary.get('[UNK]')
                            for i in re.findall(r"[\w']+|[{}]"
                                                .format(string.punctuation),
                                                raw_content)]
                vectors += [dictionary.get('[SEP]')]
                input_, mask, segment = inputs(vectors)
                input_ids = np.reshape(np.array(input_), [1, -1])
                input_mask = np.reshape(np.array(mask), [1, -1])
                segment_ids = np.reshape(np.array(segment), [1, -1])
                data = {
                    "label": 1,
                    "id": id_,
                    "rating": float(rating),
                    "input_ids": input_ids,
                    "input_mask": input_mask,
                    "segment_ids": segment_ids
                }
                yield data


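# Usage sketch for get_nlp_data (hypothetical paths): it is a generator, so
# iterate over it or pull samples with next().
#
#     samples = get_nlp_data("/path/to/aclImdb/train", "/path/to/vocab.txt", 10)
#     first = next(samples)
#     # first["input_ids"].shape -> (1, 50); see inputs() below
#     # first["id"] and first["rating"] come from the "<id>_<rating>.txt" name

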
def convert_to_uni(text):
    """Return `text` as a unicode string, decoding bytes as UTF-8."""
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode('utf-8', 'ignore')
    raise TypeError("Cannot convert type %s to a unicode string." % type(text))


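# Example: convert_to_uni(b"review text") -> "review text"; a str argument
# is returned unchanged, and any other type raises TypeError.

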
def load_vocab(vocab_file):
    """Load the vocabulary file into an ordered token -> id mapping."""
    vocab = collections.OrderedDict()
    # Pre-seed an id for the 'blank' token; the file's own entry, if any,
    # overwrites it.
    vocab.setdefault('blank', 2)
    index = 0
    with open(vocab_file) as reader:
        for line in reader:
            token = convert_to_uni(line).strip()
            vocab[token] = index
            index += 1
    return vocab


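# Usage sketch for load_vocab (hypothetical vocab file with one token per
# line, e.g. "[CLS]", "[SEP]", "[UNK]", "movie", ...):
#
#     vocab = load_vocab("/path/to/vocab.txt")
#     # vocab["[CLS]"] -> the 0-based line number of "[CLS]" in the file

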
def inputs(vectors, maxlen=50):
    """Pad or truncate token ids to `maxlen`; return (ids, mask, segment)."""
    length = len(vectors)
    if length > maxlen:
        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
    # Pad with zeros; the mask marks which positions hold real tokens.
    input_ = vectors + [0] * (maxlen - length)
    mask = [1] * length + [0] * (maxlen - length)
    segment = [0] * maxlen
    return input_, mask, segment
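

# Worked example for inputs(): with three ids and maxlen=5,
#     inputs([101, 7, 8], maxlen=5)
#     # -> ([101, 7, 8, 0, 0], [1, 1, 1, 0, 0], [0, 0, 0, 0, 0])
# With more than maxlen ids, all three lists are truncated to maxlen.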