# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

from util import save_and_check_tuple

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
from mindspore.common import dtype as mstype

# TFRecord fixture shared by all test cases below.
DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
# Set True to (re)generate the golden .npz files instead of comparing against them.
GENERATE_GOLDEN = False


def test_case_project_single_column():
    """Project a single column and compare the pipeline output against the golden file."""
    proj_cols = ["col_sint32"]

    dataset = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False).project(columns=proj_cols)

    save_and_check_tuple(dataset, {"params": {'columns': proj_cols}},
                         "project_single_column_result.npz",
                         generate_golden=GENERATE_GOLDEN)


def test_case_project_multiple_columns_in_order():
    """Project several columns listed in schema order and validate against the golden file."""
    wanted = ["col_sint16", "col_float", "col_2d"]
    params = {"params": {'columns': wanted}}

    pipeline = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    pipeline = pipeline.project(columns=wanted)

    golden_file = "project_multiple_columns_in_order_result.npz"
    save_and_check_tuple(pipeline, params, golden_file, generate_golden=GENERATE_GOLDEN)


def test_case_project_multiple_columns_out_of_order():
    """Project columns in an order different from the schema and validate the result."""
    reordered = ["col_3d", "col_sint64", "col_2d"]

    result = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False).project(columns=reordered)

    save_and_check_tuple(result,
                         {"params": {'columns': reordered}},
                         "project_multiple_columns_out_of_order_result.npz",
                         generate_golden=GENERATE_GOLDEN)


def test_case_project_map():
    """Apply a project followed by a type-cast map and validate against the golden file."""
    selected = ["col_3d", "col_sint64", "col_2d"]

    pipeline = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    pipeline = pipeline.project(columns=selected)
    # Cast runs after the projection, so the map sees only the projected columns.
    pipeline = pipeline.map(operations=C.TypeCast(mstype.int64), input_columns=["col_3d"])

    save_and_check_tuple(pipeline, {"params": {'columns': selected}},
                         "project_map_after_result.npz", generate_golden=GENERATE_GOLDEN)


def test_case_map_project():
    """Apply a type-cast map first, then a project, and validate against the golden file."""
    selected = ["col_3d", "col_sint64", "col_2d"]

    pipeline = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    # Cast runs before the projection, on the full column set.
    pipeline = pipeline.map(operations=C.TypeCast(mstype.int64), input_columns=["col_sint64"])
    pipeline = pipeline.project(columns=selected)

    save_and_check_tuple(pipeline, {"params": {'columns': selected}},
                         "project_map_before_result.npz", generate_golden=GENERATE_GOLDEN)


def test_case_project_between_maps():
    """Sandwich a project between stacks of type-cast maps and validate the result."""
    proj_cols = ["col_3d", "col_sint64", "col_2d"]

    pipeline = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    cast_to_int64 = C.TypeCast(mstype.int64)

    # Four identical casts before the projection ...
    for _ in range(4):
        pipeline = pipeline.map(operations=cast_to_int64, input_columns=["col_3d"])

    pipeline = pipeline.project(columns=proj_cols)

    # ... and five more after it.
    for _ in range(5):
        pipeline = pipeline.map(operations=cast_to_int64, input_columns=["col_3d"])

    save_and_check_tuple(pipeline, {"params": {'columns': proj_cols}},
                         "project_between_maps_result.npz", generate_golden=GENERATE_GOLDEN)


def test_case_project_repeat():
    """Project before repeat(3) and validate against the golden file."""
    chosen = ["col_3d", "col_sint64", "col_2d"]

    pipeline = (ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
                .project(columns=chosen)
                .repeat(3))

    save_and_check_tuple(pipeline, {"params": {'columns': chosen}},
                         "project_before_repeat_result.npz", generate_golden=GENERATE_GOLDEN)


def test_case_repeat_project():
    """Repeat(3) before project and validate against the golden file."""
    chosen = ["col_3d", "col_sint64", "col_2d"]

    pipeline = (ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
                .repeat(3)
                .project(columns=chosen))

    save_and_check_tuple(pipeline, {"params": {'columns': chosen}},
                         "project_after_repeat_result.npz", generate_golden=GENERATE_GOLDEN)


def test_case_map_project_map_project():
    """Alternate map and project operations twice and validate the final output."""
    order = ["col_3d", "col_sint64", "col_2d"]
    caster = C.TypeCast(mstype.int64)

    pipeline = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    pipeline = pipeline.map(operations=caster, input_columns=["col_sint64"])
    pipeline = pipeline.project(columns=order)
    pipeline = pipeline.map(operations=caster, input_columns=["col_2d"])
    pipeline = pipeline.project(columns=order)

    save_and_check_tuple(pipeline, {"params": {'columns': order}},
                         "project_alternate_parallel_inline_result.npz",
                         generate_golden=GENERATE_GOLDEN)


def test_column_order():
    """Check that the dict produced after project() preserves the requested column order."""

    def three_col_gen(count):
        # Each row is three 1-element arrays holding consecutive integers.
        for idx in range(count):
            base = idx * 3
            yield np.array([base]), np.array([base + 1]), np.array([base + 2])

    def run_pipeline(count, order):
        # Batch all rows into one batch, project into the requested order,
        # and return the last (only) row from a dict iterator.
        pipeline = ds.GeneratorDataset((lambda: three_col_gen(count)),
                                       ["col1", "col2", "col3"]).batch(batch_size=count)
        pipeline = pipeline.project(order)
        last_row = dict()
        for row in pipeline.create_dict_iterator(num_epochs=1):
            last_row = row
        return last_row

    assert list(run_pipeline(1, ["col3", "col2", "col1"]).keys()) == ["col3", "col2", "col1"]
    assert list(run_pipeline(2, ["col1", "col2", "col3"]).keys()) == ["col1", "col2", "col3"]
    assert list(run_pipeline(3, ["col2", "col3", "col1"]).keys()) == ["col2", "col3", "col1"]


if __name__ == "__main__":
    # Allow running this case directly as a script.
    test_column_order()