# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utilities for tests."""

import os
import re
import string
import collections
import json
import numpy as np

from mindspore import log as logger


def get_data(dir_name):
    """
    Return raw data of the ImageNet dataset.

    Args:
        dir_name (str): Path of the ImageNet dataset directory.

    Returns:
        list, one dict per sample with file name, image bytes and label.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for line in lines:
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list


def get_two_bytes_data(file_name):
    """
    Return raw data of the two-bytes dataset.

    Args:
        file_name (str): Path of the two-bytes dataset's map file.

    Returns:
        list, one dict per sample with image and label file contents.
    """
    if not os.path.exists(file_name):
        raise IOError("map file {} does not exist".format(file_name))
    dir_name = os.path.dirname(file_name)
    with open(file_name, "r") as file_reader:
        lines = file_reader.readlines()
    data_list = []
    row_num = 0
    for line in lines:
        try:
            img, label = line.strip('\n').split(" ")
            with open(os.path.join(dir_name, img), "rb") as file_reader:
                img_data = file_reader.read()
            with open(os.path.join(dir_name, label), "rb") as file_reader:
                label_data = file_reader.read()
            data_json = {"file_name": img,
                         "img_data": img_data,
                         "label_name": label,
                         "label_data": label_data,
                         "id": row_num}
            row_num += 1
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list


def get_multi_bytes_data(file_name, bytes_num=3):
    """
    Return raw data of the multi-bytes dataset.

    Args:
        file_name (str): Path of the multi-bytes dataset's map file.
        bytes_num (int): Number of bytes fields to read per row.

    Returns:
        list, one dict per sample with one "image_{i}" entry per bytes field.
    """
    if not os.path.exists(file_name):
        raise IOError("map file {} does not exist".format(file_name))
    dir_name = os.path.dirname(file_name)
    with open(file_name, "r") as file_reader:
        lines = file_reader.readlines()
    data_list = []
    row_num = 0
    for line in lines:
        try:
            img10_path = line.strip('\n').split(" ")
            img5 = []
            for path in img10_path[:bytes_num]:
                with open(os.path.join(dir_name, path), "rb") as file_reader:
                    img5 += [file_reader.read()]
            data_json = {"image_{}".format(i): img5[i]
                         for i in range(len(img5))}
            data_json.update({"id": row_num})
            row_num += 1
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list


def get_mkv_data(dir_name):
    """
    Return raw data of the Vehicle_and_Person dataset.

    Args:
        dir_name (str): Path of the Vehicle_and_Person dataset directory.

    Returns:
        list, one dict per sample with file name, prelabel, image bytes and id.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "Image")
    label_dir = os.path.join(dir_name, "prelabel")

    data_list = []
    file_list = os.listdir(label_dir)

    index = 1
    for file in file_list:
        if os.path.splitext(file)[1] == '.json':
            file_path = os.path.join(label_dir, file)

            image_name = ''.join([os.path.splitext(file)[0], ".jpg"])
            image_path = os.path.join(img_dir, image_name)

            with open(file_path, "r") as load_f:
                load_dict = json.load(load_f)

            if os.path.exists(image_path):
                with open(image_path, "rb") as file_reader:
                    img = file_reader.read()
                data_json = {"file_name": image_name,
                             "prelabel": str(load_dict),
                             "data": img,
                             "id": index}
                data_list.append(data_json)
                index += 1
    logger.info('{} images are missing'.format(len(file_list) - len(data_list)))
    return data_list


def get_nlp_data(dir_name, vocab_file, num):
    """
    Yield raw data of the aclImdb dataset.

    Args:
        dir_name (str): Path of the aclImdb dataset directory.
        vocab_file (str): Path of the vocabulary file.
        num (int): Number of samples to take from each directory.

    Returns:
        generator, one dict per sample.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    for root, _, files in os.walk(dir_name):
        for index, file_name_extension in enumerate(files):
            if index < num:
                file_path = os.path.join(root, file_name_extension)
                file_name, _ = file_name_extension.split('.', 1)
                id_, rating = file_name.split('_', 1)
                with open(file_path, 'r') as f:
                    raw_content = f.read()

                dictionary = load_vocab(vocab_file)
                vectors = [dictionary.get('[CLS]')]
                # Split the review into words (keeping apostrophes) and
                # punctuation marks, mapping unknown tokens to '[UNK]'.
                vectors += [dictionary.get(i) if i in dictionary
                            else dictionary.get('[UNK]')
                            for i in re.findall(r"[\w']+|[{}]"
                                                .format(string.punctuation),
                                                raw_content)]
                vectors += [dictionary.get('[SEP]')]
                input_, mask, segment = inputs(vectors)
                input_ids = np.reshape(np.array(input_), [1, -1])
                input_mask = np.reshape(np.array(mask), [1, -1])
                segment_ids = np.reshape(np.array(segment), [1, -1])
                data = {
                    "label": 1,
                    "id": id_,
                    "rating": float(rating),
                    "input_ids": input_ids,
                    "input_mask": input_mask,
                    "segment_ids": segment_ids
                }
                yield data


def convert_to_uni(text):
    """Convert bytes or str input to a unicode str."""
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode('utf-8', 'ignore')
    raise TypeError("The type %s cannot be converted to str!"
                    % type(text))


def load_vocab(vocab_file):
    """Load the vocabulary used to translate statements into token ids."""
    vocab = collections.OrderedDict()
    # Pre-seed a 'blank' token; it is overwritten if present in the vocab file.
    vocab.setdefault('blank', 2)
    index = 0
    with open(vocab_file) as reader:
        while True:
            tmp = reader.readline()
            if not tmp:
                break
            token = convert_to_uni(tmp)
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def inputs(vectors, maxlen=50):
    """Truncate or zero-pad `vectors` to `maxlen`; return ids, mask, segment."""
    length = len(vectors)
    if length > maxlen:
        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
    input_ = vectors + [0] * (maxlen - length)
    mask = [1] * length + [0] * (maxlen - length)
    segment = [0] * maxlen
    return input_, mask, segment
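

# Minimal usage sketch for `inputs` (illustrative only: the token ids below
# are made up, and no dataset files are read).
if __name__ == "__main__":
    sample_vectors = list(range(10))

    # Longer than maxlen: the sequence is truncated and every position is real.
    input_, mask, segment = inputs(sample_vectors, maxlen=8)
    assert input_ == sample_vectors[:8]
    assert mask == [1] * 8
    assert segment == [0] * 8

    # Shorter than maxlen: zero-padded, with the mask marking only real tokens.
    input_, mask, segment = inputs(sample_vectors, maxlen=16)
    assert input_ == sample_vectors + [0] * 6
    assert mask == [1] * 10 + [0] * 6
    assert segment == [0] * 16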