# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import pickle
import numpy as np
from mindspore import dataset as ds
from mindspore.dataset.transforms import c_transforms as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype


class InputFeatures:
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id, seq_length=None):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
        self.seq_length = seq_length


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, cyclic_trunc=False):
    """Converts [text, label] examples into a list of `InputFeatures`."""

    label_map = {label: index for index, label in enumerate(label_list)}

    features = []
    for example in examples:
        tokens = tokenizer.tokenize(example[0])
        seq_length = len(tokens)
        # Reserve two positions for the [CLS] and [SEP] special tokens.
        if seq_length > max_seq_length - 2:
            if cyclic_trunc:
                # Cyclic truncation: take a window of max_seq_length - 2 tokens
                # starting at a random index, wrapping around the sequence end.
                rand_index = np.random.randint(0, seq_length)
                tokens = [tokens[idx % seq_length]
                          for idx in range(rand_index, rand_index + max_seq_length - 2)]
            else:
                tokens = tokens[: (max_seq_length - 2)]

        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        segment_ids = [0] * len(tokens)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)
        seq_length = len(input_ids)

        # Zero-pad ids, mask and segment ids up to max_seq_length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        label_id = label_map[example[1]]

        features.append(InputFeatures(input_ids=input_ids,
                                      input_mask=input_mask,
                                      segment_ids=segment_ids,
                                      label_id=label_id,
                                      seq_length=seq_length))
    return features


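# Hedged usage sketch, not part of the original pipeline: _ToyTokenizer is a
# hypothetical whitespace tokenizer standing in for a real BERT tokenizer; it
# exists only to illustrate the fixed-length output of
# convert_examples_to_features.
class _ToyTokenizer:
    def tokenize(self, text):
        return text.split()

    def convert_tokens_to_ids(self, tokens):
        # Deterministic fake ids; a real tokenizer looks tokens up in its vocab.
        return [sum(ord(ch) for ch in token) % 30000 for token in tokens]


def _demo_convert_examples():
    feats = convert_examples_to_features(examples=[["an example sentence", "good"]],
                                         label_list=['good', 'leimu', 'xiaoku', 'xin'],
                                         max_seq_length=8,
                                         tokenizer=_ToyTokenizer())
    # ids/mask/segments are padded to max_seq_length; label_id indexes label_list.
    assert len(feats[0].input_ids) == 8
    assert feats[0].label_id == 0

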
def load_dataset(data_path, max_seq_length, tokenizer, batch_size, label_list=None, do_shuffle=True,
                 drop_remainder=True, output_dir=None, i=0, cyclic_trunc=False):
    """Loads a single raw text file and returns (dataset, number of labels)."""
    if label_list is None:
        label_list = ['good', 'leimu', 'xiaoku', 'xin']
    with open(data_path, 'r', encoding='utf-8') as f:
        data = f.read()
    data_list = data.split('\n<<<')
    input_list = []
    # Each record has the form 'label>>>text'; store it as a [text, label] pair.
    # Whatever precedes the first '\n<<<' delimiter is discarded.
    for record in data_list[1:]:
        record = record.split('>>>')
        input_list.append([record[1], record[0]])
    datasets = create_ms_dataset(input_list, label_list, max_seq_length, tokenizer, batch_size,
                                 do_shuffle=do_shuffle, drop_remainder=drop_remainder, cyclic_trunc=cyclic_trunc)
    if output_dir is not None:
        output_path = os.path.join(output_dir, str(i) + '.dat')
        print(output_path)
        # Cache the batched tensors; using create_tuple_iterator here matches
        # what load_datasets below writes for the same file.
        with open(output_path, "wb") as f:
            pickle.dump(tuple(datasets.create_tuple_iterator()), f)
    del data, data_list, input_list
    return datasets, len(label_list)


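# Hedged usage sketch: the file name below is hypothetical, and _ToyTokenizer
# (defined above for illustration) stands in for a real tokenizer. Any object
# exposing tokenize and convert_tokens_to_ids will do.
def _demo_load_dataset():
    dataset, num_labels = load_dataset('client_0.txt', max_seq_length=64,
                                       tokenizer=_ToyTokenizer(), batch_size=16)
    print(num_labels)  # 4 with the default label list

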
def load_datasets(data_dir, max_seq_length, tokenizer, batch_size, label_list=None, do_shuffle=True,
                  drop_remainder=True, output_dir=None, cyclic_trunc=False):
    """Loads every raw text file under data_dir and returns (list of datasets, number of labels)."""
    if label_list is None:
        label_list = ['good', 'leimu', 'xiaoku', 'xin']
    data_path_list = os.listdir(data_dir)
    datasets_list = []
    for i, relative_path in enumerate(data_path_list):
        data_path = os.path.join(data_dir, relative_path)
        with open(data_path, 'r', encoding='utf-8') as f:
            data = f.read()
        data_list = data.split('\n<<<')
        input_list = []
        # Same 'label>>>text' record format as in load_dataset above.
        for record in data_list[1:]:
            record = record.split('>>>')
            input_list.append([record[1], record[0]])
        datasets = create_ms_dataset(input_list, label_list, max_seq_length, tokenizer, batch_size,
                                     do_shuffle=do_shuffle, drop_remainder=drop_remainder, cyclic_trunc=cyclic_trunc)
        if output_dir is not None:
            output_path = os.path.join(output_dir, str(i) + '.dat')
            print(output_path)
            with open(output_path, "wb") as f:
                pickle.dump(tuple(datasets.create_tuple_iterator()), f)
        datasets_list.append(datasets)
    return datasets_list, len(label_list)


def create_ms_dataset(data_list, label_list, max_seq_length, tokenizer, batch_size, do_shuffle=True,
                      drop_remainder=True, cyclic_trunc=False):
    """Wraps tokenized features into a batched MindSpore GeneratorDataset."""
    features = convert_examples_to_features(data_list, label_list, max_seq_length, tokenizer,
                                            cyclic_trunc=cyclic_trunc)

    def generator_func():
        for feature in features:
            yield (np.array(feature.input_ids),
                   np.array(feature.input_mask),
                   np.array(feature.segment_ids),
                   np.array(feature.label_id),
                   np.array(feature.seq_length))

    dataset = ds.GeneratorDataset(generator_func,
                                  ['input_ids', 'input_mask', 'token_type_id', 'label_ids', 'seq_length'])
    if do_shuffle:
        dataset = dataset.shuffle(buffer_size=10000)

    # Cast every column to int32; without input_columns the cast would only be
    # applied to the first column.
    type_cast_op = C.TypeCast(mstype.int32)
    for column in ['input_ids', 'input_mask', 'token_type_id', 'label_ids', 'seq_length']:
        dataset = dataset.map(operations=[type_cast_op], input_columns=[column])
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
    return dataset


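# Hedged usage sketch: pulling one batch out of a dataset built by
# create_ms_dataset; the column names match those declared in GeneratorDataset.
def _demo_iterate(dataset):
    for batch in dataset.create_dict_iterator(num_epochs=1):
        print(batch['input_ids'].shape)  # (batch_size, max_seq_length)
        print(batch['label_ids'].shape)  # (batch_size,)
        break

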
class ConstructMaskAndReplaceTensor:
    """Builds MLM-style mask and replacement tensors for a batch of sequences.

    For each real token position, 15% of tokens are selected for corruption:
    80% of those are replaced by the [MASK] token id (103), 10% are kept
    unchanged, and 10% are replaced by a random vocabulary id.
    """

    def __init__(self, batch_size, max_seq_length, vocab_size, keep_first_unchange=True, keep_last_unchange=True):
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size
        self.keep_first_unchange = keep_first_unchange
        self.keep_last_unchange = keep_last_unchange
        self.mask_tensor = np.ones((self.batch_size, self.max_seq_length))
        self.replace_tensor = np.zeros((self.batch_size, self.max_seq_length))

    def construct(self, seq_lengths):
        for i in range(self.batch_size):
            for j in range(seq_lengths[i]):
                rand1 = np.random.random()
                if rand1 < 0.15:
                    # Position selected for corruption: mask == 0 marks it.
                    self.mask_tensor[i, j] = 0
                    rand2 = np.random.random()
                    if rand2 < 0.8:
                        # 80%: replace with the [MASK] token id.
                        self.replace_tensor[i, j] = 103
                    elif rand2 < 0.9:
                        # 10%: keep the original token.
                        self.mask_tensor[i, j] = 1
                    else:
                        # 10%: replace with a random vocabulary id.
                        self.replace_tensor[i, j] = np.random.randint(0, self.vocab_size)
                else:
                    self.mask_tensor[i, j] = 1
                    self.replace_tensor[i, j] = 0
            # Padding positions are never corrupted.
            for j in range(seq_lengths[i], self.max_seq_length):
                self.mask_tensor[i, j] = 1
                self.replace_tensor[i, j] = 0
            # Optionally protect the first ([CLS]) and last real ([SEP]) positions.
            if self.keep_first_unchange:
                self.mask_tensor[i, 0] = 1
                self.replace_tensor[i, 0] = 0
            if self.keep_last_unchange:
                self.mask_tensor[i, seq_lengths[i] - 1] = 1
                self.replace_tensor[i, seq_lengths[i] - 1] = 0
        mask_tensor = Tensor(self.mask_tensor, dtype=mstype.int32)
        replace_tensor = Tensor(self.replace_tensor, dtype=mstype.int32)
        return mask_tensor, replace_tensor

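
# Hedged usage sketch: applying the mask/replace tensors to a batch of token
# ids in NumPy. How the tensors are consumed by the training graph is an
# assumption here, not shown in this file: positions with mask == 0 take the
# replacement id, all others keep the original token.
def _demo_apply_mask(input_ids_np, seq_lengths):
    constructor = ConstructMaskAndReplaceTensor(batch_size=input_ids_np.shape[0],
                                                max_seq_length=input_ids_np.shape[1],
                                                vocab_size=21128)  # hypothetical vocab size
    mask, replace = constructor.construct(seq_lengths)
    mask_np = mask.asnumpy()
    replace_np = replace.asnumpy()
    return input_ids_np * mask_np + replace_np * (1 - mask_np)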