# coding:utf-8
import os
import pickle
import collections
import argparse
import numpy as np
import pandas as pd

TRAIN_LINE_COUNT = 45840617
TEST_LINE_COUNT = 6042135


class DataStatsDict():
    def __init__(self):
        self.field_size = 39  # val_1-13; cat_1-26
        self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
        self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]

        self.val_min_dict = {col: 0 for col in self.val_cols}
        self.val_max_dict = {col: 0 for col in self.val_cols}
        self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}

        self.oov_prefix = "OOV_"
        self.cat2id_dict = {}
        self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
        self.cat2id_dict.update({self.oov_prefix + col: i + len(self.val_cols)
                                 for i, col in enumerate(self.cat_cols)})
        # {"val_1": 0, ..., "val_13": 12, "OOV_cat_1": 13, ..., "OOV_cat_26": 38}

    def stats_vals(self, val_list):
        """Track the running min/max of every dense column."""
        assert len(val_list) == len(self.val_cols)

        def map_max_min(i, val):
            key = self.val_cols[i]
            if val != "":
                if float(val) > self.val_max_dict[key]:
                    self.val_max_dict[key] = float(val)
                if float(val) < self.val_min_dict[key]:
                    self.val_min_dict[key] = float(val)

        for i, val in enumerate(val_list):
            map_max_min(i, val)

    def stats_cats(self, cat_list):
        """Count the occurrences of every category string per column."""
        assert len(cat_list) == len(self.cat_cols)

        def map_cat_count(i, cat):
            key = self.cat_cols[i]
            self.cat_count_dict[key][cat] += 1

        for i, cat in enumerate(cat_list):
            map_cat_count(i, cat)

    def save_dict(self, output_path, prefix=""):
        with open(os.path.join(output_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_max_dict, file_wrt)
        with open(os.path.join(output_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_min_dict, file_wrt)
        with open(os.path.join(output_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.cat_count_dict, file_wrt)

    def load_dict(self, dict_path, prefix=""):
        with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_max_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_min_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.cat_count_dict = pickle.load(file_wrt)
        print("val_max_dict.items(): {}".format(list(self.val_max_dict.items())))
        print("val_min_dict.items(): {}".format(list(self.val_min_dict.items())))

    def get_cat2id(self, threshold=100):
        """Assign an id to every category seen more than `threshold` times; rare ones fall back to OOV."""
        for key, cat_count_d in self.cat_count_dict.items():
            new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items()))
            for cat_str, _ in new_cat_count_d.items():
                self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
        # before_all_count: 33762577; after_all_count: 184926
        print("cat2id_dict.size: {}".format(len(self.cat2id_dict)))
        print("cat2id_dict.items()[:50]: {}".format(list(self.cat2id_dict.items())[:50]))

    def map_cat2id(self, values, cats):
        """Map one raw line (13 dense values + 26 categories) to parallel id and weight lists."""
        def minmax_scale_value(i, val):
            # min_v = float(self.val_min_dict["val_{}".format(i + 1)])
            max_v = float(self.val_max_dict["val_{}".format(i + 1)])
            # return (float(val) - min_v) * 1.0 / (max_v - min_v)
            return float(val) * 1.0 / max_v

        id_list = []
        weight_list = []
        for i, val in enumerate(values):
            if val == "":
                id_list.append(i)
                weight_list.append(0)
            else:
                key = "val_{}".format(i + 1)
                id_list.append(self.cat2id_dict[key])
                weight_list.append(minmax_scale_value(i, float(val)))

        for i, cat_str in enumerate(cats):
            key = "cat_{}".format(i + 1) + "_" + cat_str
            if key in self.cat2id_dict:
                id_list.append(self.cat2id_dict[key])
            else:
                id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
            weight_list.append(1.0)
        return id_list, weight_list
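
# Illustrative note (added sketch, not part of the original pipeline): once the
# vocabulary has been built with get_cat2id(), map_cat2id() turns one raw Criteo
# line into two parallel lists of length field_size = 39:
#   ids[0:13]       fixed slots 0..12 for the dense columns val_1..val_13
#   ids[13:39]      vocabulary ids for cat_1..cat_26, falling back to the per-column
#                   OOV ids 13..38 for unseen or below-threshold categories
#   weights[0:13]   val / max(val) per column (0 when the field is empty)
#   weights[13:39]  always 1.0
# Each sample is later stored as ids + weights, i.e. 78 columns per row.
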
def mkdir_path(file_path):
    if not os.path.exists(file_path):
        os.makedirs(file_path)


def statsdata(data_source_path, output_path, data_stats1):
    """First pass over the raw tab-separated file: collect min/max values and category counts."""
    with open(data_source_path, encoding="utf-8") as file_in:
        errorline_list = []
        count = 0
        for line in file_in:
            count += 1
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                errorline_list.append(count)
                print("line: {}".format(line))
                continue
            if count % 1000000 == 0:
                print("Have handled {}w lines.".format(count // 10000))
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "values.size: {}".format(len(values))
            assert len(cats) == 26, "cats.size: {}".format(len(cats))
            data_stats1.stats_vals(values)
            data_stats1.stats_cats(cats)
    data_stats1.save_dict(output_path)


def add_write(file_path, wrt_str):
    with open(file_path, 'a', encoding="utf-8") as file_out:
        file_out.write(wrt_str + "\n")
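
# Read-back sketch (an assumption for illustration only; file names follow the patterns
# defined in random_split_trans2h5 below). Every train_input_part_<k>.h5 holds up to
# part_rows rows of 78 columns (39 ids followed by 39 weights) under the key "fixed",
# and train_output_part_<k>.h5 the matching labels:
#
#   import pandas as pd
#   feats = pd.read_hdf("train_input_part_0.h5", key="fixed").values
#   labels = pd.read_hdf("train_output_part_0.h5", key="fixed").values
#   ids, weights = feats[:, :39], feats[:, 39:]
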
def random_split_trans2h5(input_file_path, output_path, data_stats2, part_rows=2000000, test_size=0.1, seed=2020):
    """Second pass: encode every line, split it into train/test via a shuffled index set,
    and dump the result into fixed-size h5 part files."""
    test_size = int(TRAIN_LINE_COUNT * test_size)

    all_indices = [i for i in range(TRAIN_LINE_COUNT)]
    np.random.seed(seed)
    np.random.shuffle(all_indices)
    print("all_indices.size: {}".format(len(all_indices)))
    test_indices_set = set(all_indices[:test_size])
    print("test_indices_set.size: {}".format(len(test_indices_set)))
    print("----------" * 10 + "\n" * 2)

    train_feature_file_name = os.path.join(output_path, "train_input_part_{}.h5")
    train_label_file_name = os.path.join(output_path, "train_output_part_{}.h5")
    test_feature_file_name = os.path.join(output_path, "test_input_part_{}.h5")
    test_label_file_name = os.path.join(output_path, "test_output_part_{}.h5")

    train_feature_list = []
    train_label_list = []
    test_feature_list = []
    test_label_list = []
    with open(input_file_path, encoding="utf-8") as file_in:
        count = 0
        train_part_number = 0
        test_part_number = 0
        for i, line in enumerate(file_in):
            count += 1
            if count % 1000000 == 0:
                print("Have handled {}w lines.".format(count // 10000))
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                continue
            label = float(items[0])
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "values.size: {}".format(len(values))
            assert len(cats) == 26, "cats.size: {}".format(len(cats))
            ids, wts = data_stats2.map_cat2id(values, cats)
            if i not in test_indices_set:
                train_feature_list.append(ids + wts)
                train_label_list.append(label)
            else:
                test_feature_list.append(ids + wts)
                test_label_list.append(label)
            if train_label_list and (len(train_label_list) % part_rows == 0):
                pd.DataFrame(np.asarray(train_feature_list)).to_hdf(
                    train_feature_file_name.format(train_part_number), key="fixed")
                pd.DataFrame(np.asarray(train_label_list)).to_hdf(
                    train_label_file_name.format(train_part_number), key="fixed")
                train_feature_list = []
                train_label_list = []
                train_part_number += 1
            if test_label_list and (len(test_label_list) % part_rows == 0):
                pd.DataFrame(np.asarray(test_feature_list)).to_hdf(
                    test_feature_file_name.format(test_part_number), key="fixed")
                pd.DataFrame(np.asarray(test_label_list)).to_hdf(
                    test_label_file_name.format(test_part_number), key="fixed")
                test_feature_list = []
                test_label_list = []
                test_part_number += 1

    # flush the last, partially filled part files
    if train_label_list:
        pd.DataFrame(np.asarray(train_feature_list)).to_hdf(
            train_feature_file_name.format(train_part_number), key="fixed")
        pd.DataFrame(np.asarray(train_label_list)).to_hdf(
            train_label_file_name.format(train_part_number), key="fixed")

    if test_label_list:
        pd.DataFrame(np.asarray(test_feature_list)).to_hdf(
            test_feature_file_name.format(test_part_number), key="fixed")
        pd.DataFrame(np.asarray(test_label_list)).to_hdf(
            test_label_file_name.format(test_part_number), key="fixed")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Get and Process datasets')
    parser.add_argument('--base_path', default="/home/wushuquan/tmp/", help='The path to save dataset')
    parser.add_argument('--output_path', default="/home/wushuquan/tmp/h5dataset/",
                        help='The path to save h5 dataset')

    args, _ = parser.parse_known_args()
    base_path = args.base_path
    data_path = base_path

    os.system("tar -zxvf {}dac.tar.gz".format(data_path))
    print("********tar end***********")
    data_stats = DataStatsDict()

    # step 1, stats the vocab and normalize value
    data_file_path = "./train.txt"
    stats_output_path = base_path + "stats_dict/"
    mkdir_path(stats_output_path)
    statsdata(data_file_path, stats_output_path, data_stats)
    print("----------" * 10)
    data_stats.load_dict(dict_path=stats_output_path, prefix="")
    data_stats.get_cat2id(threshold=100)

    # step 2, transform data trans2h5; version 2: np.random.shuffle
    in_file_path = "./train.txt"
    mkdir_path(args.output_path)
    random_split_trans2h5(in_file_path, args.output_path, data_stats, part_rows=2000000, test_size=0.1, seed=2020)
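
# Example invocation (sketch; assumes this script is saved as process_data.py and that
# the Criteo archive dac.tar.gz has already been downloaded into --base_path):
#   python process_data.py --base_path=/path/to/criteo/ --output_path=/path/to/criteo/h5dataset/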