# coding:utf-8
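# Preprocessing script for the Criteo DAC click-through-rate dataset (the
# dac.tar.gz archive extracted below suggests this target): it scans
# train.txt to collect per-column statistics, builds a feature-id
# vocabulary, and writes shuffled train/test splits as HDF5 part files.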
import os
import pickle
import collections
import argparse

import numpy as np
import pandas as pd

TRAIN_LINE_COUNT = 45840617
TEST_LINE_COUNT = 6042135


class DataStatsDict:
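    """Collect dataset statistics and build the feature-id vocabulary.

    Tracks the min/max of the 13 numeric columns (val_1..val_13) and the
    frequency of every category in the 26 categorical columns
    (cat_1..cat_26), then maps each retained feature to an integer id.
    """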
    def __init__(self):
        self.field_size = 39  # 13 numeric (val) fields + 26 categorical (cat) fields
        self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
        self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]

        self.val_min_dict = {col: 0 for col in self.val_cols}
        self.val_max_dict = {col: 0 for col in self.val_cols}
        self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}

        self.oov_prefix = "OOV_"
        self.cat2id_dict = {}
        self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
        self.cat2id_dict.update({self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)})
        # ids 0-12 are reserved for "val_1".."val_13"; ids 13-38 are the
        # out-of-vocabulary buckets "OOV_cat_1".."OOV_cat_26".

    def stats_vals(self, val_list):
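        """Update the running min/max of each numeric column from one row."""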
        assert len(val_list) == len(self.val_cols)
        def map_max_min(i, val):
            key = self.val_cols[i]
            if val != "":
                val = float(val)
                if val > self.val_max_dict[key]:
                    self.val_max_dict[key] = val
                if val < self.val_min_dict[key]:
                    self.val_min_dict[key] = val
        for i, val in enumerate(val_list):
            map_max_min(i, val)

    def stats_cats(self, cat_list):
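        """Update the per-column category frequency counts from one row."""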
        assert len(cat_list) == len(self.cat_cols)
        def map_cat_count(i, cat):
            key = self.cat_cols[i]
            self.cat_count_dict[key][cat] += 1
        for i, cat in enumerate(cat_list):
            map_cat_count(i, cat)

    def save_dict(self, output_path, prefix=""):
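        """Pickle the collected statistics into `output_path`."""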
        with open(os.path.join(output_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_max_dict, file_wrt)
        with open(os.path.join(output_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_min_dict, file_wrt)
        with open(os.path.join(output_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.cat_count_dict, file_wrt)

    def load_dict(self, dict_path, prefix=""):
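        """Load previously pickled statistics from `dict_path`."""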
        with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_max_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_min_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.cat_count_dict = pickle.load(file_wrt)
        print("val_max_dict: {}".format(self.val_max_dict))
        print("val_min_dict: {}".format(self.val_min_dict))

    def get_cat2id(self, threshold=100):
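        """Assign an integer id to every category seen more than `threshold`
        times; rarer categories fall back to their column's OOV bucket."""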
        for key, cat_count_d in self.cat_count_dict.items():
            new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items()))
            for cat_str, _ in new_cat_count_d.items():
                self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
        # On the full train set this filtering shrinks the raw vocabulary of
        # 33762577 categories down to 184926 retained ids.
        print("cat2id_dict.size: {}".format(len(self.cat2id_dict)))
        print("cat2id_dict.items()[:50]: {}".format(list(self.cat2id_dict.items())[:50]))

    def map_cat2id(self, values, cats):
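        """Map one row's raw values to parallel (id, weight) lists: numeric
        fields get a max-scaled weight, categorical fields get weight 1.0,
        and unseen categories map to their column's OOV bucket."""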
        def minmax_scale_value(i, val):
            # The column minima are treated as (approximately) zero, so full
            # min-max scaling reduces to dividing by the column maximum.
            max_v = float(self.val_max_dict["val_{}".format(i + 1)])
            return float(val) * 1.0 / max_v

        id_list = []
        weight_list = []
        for i, val in enumerate(values):
            if val == "":
                # "val_{i+1}" was assigned id i in __init__, so the index
                # doubles as the feature id; weight 0 marks a missing value.
                id_list.append(i)
                weight_list.append(0)
            else:
                key = "val_{}".format(i + 1)
                id_list.append(self.cat2id_dict[key])
                weight_list.append(minmax_scale_value(i, float(val)))

        for i, cat_str in enumerate(cats):
            key = "cat_{}".format(i + 1) + "_" + cat_str
            if key in self.cat2id_dict:
                id_list.append(self.cat2id_dict[key])
            else:
                id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
            weight_list.append(1.0)
        return id_list, weight_list

def mkdir_path(file_path):
    # Create the directory if it does not already exist.
    if not os.path.exists(file_path):
        os.makedirs(file_path)

def statsdata(data_source_path, output_path, data_stats1):
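    """Single pass over the raw tab-separated data: each valid line holds a
    label, 13 numeric values, and 26 categorical values (40 fields total).
    Accumulates statistics into `data_stats1` and pickles them at the end."""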
    with open(data_source_path, encoding="utf-8") as file_in:
        errorline_list = []  # line numbers with an unexpected field count
        count = 0
        for line in file_in:
            count += 1
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                errorline_list.append(count)
                print("line: {}".format(line))
                continue
            if count % 1000000 == 0:
                print("Processed {}w lines.".format(count // 10000))  # 1w = 10,000 lines
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "values.size: {}".format(len(values))
            assert len(cats) == 26, "cats.size: {}".format(len(cats))
            data_stats1.stats_vals(values)
            data_stats1.stats_cats(cats)
    data_stats1.save_dict(output_path)

def add_write(file_path, wrt_str):
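    """Append `wrt_str` as one line to the file at `file_path`."""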
    with open(file_path, 'a', encoding="utf-8") as file_out:
        file_out.write(wrt_str + "\n")

def random_split_trans2h5(input_file_path, output_path, data_stats2, part_rows=2000000, test_size=0.1, seed=2020):
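    """Shuffle all row indices with a fixed seed, hold out a `test_size`
    fraction as the test split, and stream the mapped (ids + weights, label)
    rows into HDF5 part files of at most `part_rows` rows each."""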
    test_count = int(TRAIN_LINE_COUNT * test_size)

    all_indices = list(range(TRAIN_LINE_COUNT))
    np.random.seed(seed)
    np.random.shuffle(all_indices)
    print("all_indices.size: {}".format(len(all_indices)))
    test_indices_set = set(all_indices[:test_count])
    print("test_indices_set.size: {}".format(len(test_indices_set)))
    print("----------" * 10 + "\n" * 2)

    train_feature_file_name = os.path.join(output_path, "train_input_part_{}.h5")
    train_label_file_name = os.path.join(output_path, "train_output_part_{}.h5")
    test_feature_file_name = os.path.join(output_path, "test_input_part_{}.h5")
    test_label_file_name = os.path.join(output_path, "test_output_part_{}.h5")
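    # Features and labels are flushed in aligned chunks, so part N of the
    # input files corresponds row-for-row to part N of the output files.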

    train_feature_list = []
    train_label_list = []
    test_feature_list = []
    test_label_list = []
    with open(input_file_path, encoding="utf-8") as file_in:
        count = 0
        train_part_number = 0
        test_part_number = 0
        for i, line in enumerate(file_in):
            count += 1
            if count % 1000000 == 0:
                print("Processed {}w lines.".format(count // 10000))  # 1w = 10,000 lines
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                continue
            label = float(items[0])
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "values.size: {}".format(len(values))
            assert len(cats) == 26, "cats.size: {}".format(len(cats))
            ids, wts = data_stats2.map_cat2id(values, cats)
            if i not in test_indices_set:
                train_feature_list.append(ids + wts)
                train_label_list.append(label)
            else:
                test_feature_list.append(ids + wts)
                test_label_list.append(label)
            # Flush a full part to disk once it reaches part_rows rows.
            if train_label_list and (len(train_label_list) % part_rows == 0):
                pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number),
                                                                    key="fixed")
                pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number),
                                                                  key="fixed")
                train_feature_list = []
                train_label_list = []
                train_part_number += 1
            if test_label_list and (len(test_label_list) % part_rows == 0):
                pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number),
                                                                   key="fixed")
                pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number),
                                                                 key="fixed")
                test_feature_list = []
                test_label_list = []
                test_part_number += 1

        # Write out any remaining rows that did not fill a complete part.
        if train_label_list:
            pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number),
                                                                key="fixed")
            pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number),
                                                              key="fixed")

        if test_label_list:
            pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number),
                                                               key="fixed")
            pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number), key="fixed")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Get and process the dataset')
    parser.add_argument('--base_path', default="/home/wushuquan/tmp/", help='The path to save the dataset')
    parser.add_argument('--output_path', default="/home/wushuquan/tmp/h5dataset/",
                        help='The path to save the h5 dataset')

    args, _ = parser.parse_known_args()
    base_path = args.base_path
    data_path = base_path

    # Extract the Criteo DAC archive into the current directory (produces ./train.txt).
    os.system("tar -zxvf {}dac.tar.gz".format(data_path))
    print("********tar end***********")
    data_stats = DataStatsDict()

    # Step 1: collect the statistics used to build the vocabulary and normalize values.
    data_file_path = "./train.txt"
    stats_output_path = base_path + "stats_dict/"
    mkdir_path(stats_output_path)
    statsdata(data_file_path, stats_output_path, data_stats)
    print("----------" * 10)
    data_stats.load_dict(dict_path=stats_output_path, prefix="")
    data_stats.get_cat2id(threshold=100)
    # Step 2: transform the data to HDF5 (the test split is drawn via np.random.shuffle).
    in_file_path = "./train.txt"
    mkdir_path(args.output_path)
    random_split_trans2h5(in_file_path, args.output_path, data_stats, part_rows=2000000, test_size=0.1, seed=2020)
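# Example invocation (illustrative; assumes dac.tar.gz is already present in
# --base_path and that this file is saved as, e.g., preprocess_data.py):
#   python preprocess_data.py --base_path /tmp/criteo/ --output_path /tmp/criteo/h5dataset/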