• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Data operations, will be used in run_pretrain.py
17"""
18import os
19import mindspore.common.dtype as mstype
20import mindspore.dataset as ds
21import mindspore.dataset.transforms.c_transforms as C
22from mindspore import log as logger
23from .config import bert_net_cfg
24
25
def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
    """Create the BERT pre-training dataset from the TFRecord files in a directory.

    Args:
        epoch_size: Repeat count applied to the batched dataset.
        device_num: Forwarded to ``TFRecordDataset`` as ``num_shards``.
        rank: Forwarded to ``TFRecordDataset`` as ``shard_id``.
        do_shuffle: Shuffling is enabled only when this equals the string ``"true"``.
        data_dir: Directory to scan; every entry whose name contains ``"tfrecord"`` is
            used as a data file.
        schema_dir: Schema argument forwarded to ``TFRecordDataset``; an empty string is
            replaced by ``None``.

    Returns:
        A ``(data_set, new_repeat_count)`` tuple: the cast, batched and repeated dataset,
        and the repeat count derived from ``epoch_size``.

    Raises:
        ValueError: If ``data_dir`` is ``None``, contains no entry with ``"tfrecord"`` in
            its name, or the resulting dataset reports a size of zero.
    """
    if data_dir is None:
        # os.listdir(None) would silently scan the current working directory and the
        # later os.path.join(None, ...) would fail with an unrelated TypeError.
        raise ValueError("data_dir must be set to the directory holding the tfrecord files.")

    repeat_count = epoch_size

    # os.listdir returns entries in arbitrary order; sort so the file list does not
    # depend on the file system of the machine building the sharded dataset.
    data_files = sorted(os.path.join(data_dir, file_name)
                        for file_name in os.listdir(data_dir)
                        if "tfrecord" in file_name)
    if not data_files:
        raise ValueError("No tfrecord files found in data_dir: {}".format(data_dir))

    schema = None if schema_dir == "" else schema_dir
    data_set = ds.TFRecordDataset(data_files, schema,
                                  columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
                                                "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
                                  shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                                  shard_equal_rows=True)

    ori_dataset_size = data_set.get_dataset_size()
    print('origin dataset size: ', ori_dataset_size)
    if ori_dataset_size == 0:
        # The ratio below would otherwise fail with a bare ZeroDivisionError.
        raise ValueError("Dataset built from data_dir {} is empty.".format(data_dir))

    # Numerator and denominator use the same size (nothing changes the pipeline between
    # the two look-ups), so for an integer epoch_size this equals epoch_size; the
    # expression is kept so the returned value is unchanged.
    new_repeat_count = int(repeat_count * ori_dataset_size // ori_dataset_size)

    # Cast the integer feature columns to int32; "masked_lm_weights" is left as read.
    type_cast_op = C.TypeCast(mstype.int32)
    for column in ("masked_lm_ids", "masked_lm_positions", "next_sentence_labels",
                   "segment_ids", "input_mask", "input_ids"):
        data_set = data_set.map(operations=type_cast_op, input_columns=column)

    # apply batch operations, then repeat the batched data
    data_set = data_set.batch(bert_net_cfg.batch_size, drop_remainder=True)
    data_set = data_set.repeat(max(new_repeat_count, repeat_count))
    logger.info("data size: {}".format(data_set.get_dataset_size()))
    logger.info("repeatcount: {}".format(data_set.get_repeat_count()))
    return data_set, new_repeat_count
56