• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15import numpy as np
16
17import mindspore.dataset as ds
18from mindspore import log as logger
19
# Fixture 1 (testTFBert5Rows1): TFRecord data file list and its schema.
DATA_DIR = ["../data/dataset/testTFBert5Rows1/5TFDatas.data"]
SCHEMA_DIR = "../data/dataset/testTFBert5Rows1/datasetSchema.json"

# Fixture 2 (testTFBert5Rows2): TFRecord data file list and its schema.
# test_rename below reads only this fixture.
DATA_DIR_2 = ["../data/dataset/testTFBert5Rows2/5TFDatas.data"]
SCHEMA_DIR_2 = "../data/dataset/testTFBert5Rows2/datasetSchema.json"
24
25
def test_rename():
    """Check that ``rename`` preserves column values through ``zip`` and ``repeat``.

    Two TFRecord pipelines are built over the same fixture (``DATA_DIR_2`` /
    ``SCHEMA_DIR_2``) with ``shuffle=False``. In the second pipeline,
    ``input_ids`` -> ``masks`` and ``segment_ids`` -> ``seg_ids`` are renamed.
    The pipelines are zipped and repeated. For every row, the renamed columns
    must equal the originals, and the total row count must equal
    rows-per-epoch times the repeat count.
    """
    num_repeats = 3
    # The original assertion expected 15 rows for repeat(3), i.e. 5 rows per epoch.
    rows_per_epoch = 5

    data1 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2, shuffle=False)
    data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2, shuffle=False)
    data2 = data2.rename(input_columns=["input_ids", "segment_ids"],
                         output_columns=["masks", "seg_ids"])

    # With the renamed names, both copies' columns sit side by side in each zipped row.
    data = ds.zip((data1, data2))
    data = data.repeat(num_repeats)

    num_iter = 0
    # Drain the whole pipeline so the final count reflects every produced row.
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("item[mask] is {}".format(item["masks"]))
        np.testing.assert_equal(item["masks"], item["input_ids"])
        logger.info("item[seg_ids] is {}".format(item["seg_ids"]))
        np.testing.assert_equal(item["segment_ids"], item["seg_ids"])
        num_iter += 1
    logger.info("Number of data in data: {}".format(num_iter))
    assert num_iter == rows_per_epoch * num_repeats
46
47
if __name__ == "__main__":
    # Allow running this check directly as a script, outside of pytest.
    banner = '===========test Rename Repeat==========='
    logger.info(banner)
    test_rename()
    logger.info('\n')
52