# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the dataset `project` operation: selecting and reordering columns,
alone and combined with `map` and `repeat`, checked against golden files."""
import numpy as np

from util import save_and_check_tuple

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
from mindspore.common import dtype as mstype

DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
GENERATE_GOLDEN = False


def _check_against_golden(data, columns, filename):
    """Compare the pipeline `data` against the golden file `filename`.

    `columns` is recorded in the saved parameters so the golden file is
    self-describing.  Regenerates the golden file when GENERATE_GOLDEN is True.
    """
    parameters = {"params": {'columns': columns}}
    save_and_check_tuple(data, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_case_project_single_column():
    """Project a single column out of the TFRecord dataset."""
    columns = ["col_sint32"]
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    data1 = data1.project(columns=columns)

    _check_against_golden(data1, columns, "project_single_column_result.npz")


def test_case_project_multiple_columns_in_order():
    """Project several columns, keeping their schema order."""
    columns = ["col_sint16", "col_float", "col_2d"]
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    data1 = data1.project(columns=columns)

    _check_against_golden(data1, columns, "project_multiple_columns_in_order_result.npz")


def test_case_project_multiple_columns_out_of_order():
    """Project several columns in an order different from the schema order."""
    columns = ["col_3d", "col_sint64", "col_2d"]
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    data1 = data1.project(columns=columns)

    _check_against_golden(data1, columns, "project_multiple_columns_out_of_order_result.npz")


def test_case_project_map():
    """Apply `project` first, then `map` on one of the projected columns."""
    columns = ["col_3d", "col_sint64", "col_2d"]
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    data1 = data1.project(columns=columns)

    type_cast_op = C.TypeCast(mstype.int64)
    data1 = data1.map(operations=type_cast_op, input_columns=["col_3d"])

    _check_against_golden(data1, columns, "project_map_after_result.npz")


def test_case_map_project():
    """Apply `map` first, then `project` over the mapped pipeline."""
    columns = ["col_3d", "col_sint64", "col_2d"]
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)

    type_cast_op = C.TypeCast(mstype.int64)
    data1 = data1.map(operations=type_cast_op, input_columns=["col_sint64"])

    data1 = data1.project(columns=columns)

    _check_against_golden(data1, columns, "project_map_before_result.npz")


def test_case_project_between_maps():
    """Sandwich a `project` between several identical `map` operations."""
    columns = ["col_3d", "col_sint64", "col_2d"]
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)

    type_cast_op = C.TypeCast(mstype.int64)
    # Four casts before the project, five after — int64 -> int64 casts are
    # value-preserving, so only the pipeline shape is being exercised.
    for _ in range(4):
        data1 = data1.map(operations=type_cast_op, input_columns=["col_3d"])

    data1 = data1.project(columns=columns)

    for _ in range(5):
        data1 = data1.map(operations=type_cast_op, input_columns=["col_3d"])

    _check_against_golden(data1, columns, "project_between_maps_result.npz")


def test_case_project_repeat():
    """Apply `project` first, then `repeat`."""
    columns = ["col_3d", "col_sint64", "col_2d"]
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    data1 = data1.project(columns=columns)

    repeat_count = 3
    data1 = data1.repeat(repeat_count)

    _check_against_golden(data1, columns, "project_before_repeat_result.npz")


def test_case_repeat_project():
    """Apply `repeat` first, then `project`."""
    columns = ["col_3d", "col_sint64", "col_2d"]
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)

    repeat_count = 3
    data1 = data1.repeat(repeat_count)

    data1 = data1.project(columns=columns)

    _check_against_golden(data1, columns, "project_after_repeat_result.npz")


def test_case_map_project_map_project():
    """Alternate `map` and `project` twice in one pipeline."""
    columns = ["col_3d", "col_sint64", "col_2d"]
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)

    type_cast_op = C.TypeCast(mstype.int64)
    data1 = data1.map(operations=type_cast_op, input_columns=["col_sint64"])

    data1 = data1.project(columns=columns)

    data1 = data1.map(operations=type_cast_op, input_columns=["col_2d"])

    data1 = data1.project(columns=columns)

    _check_against_golden(data1, columns, "project_alternate_parallel_inline_result.npz")


def test_column_order():
    """test the output dict has maintained an insertion order."""

    def gen_3_cols(num):
        for i in range(num):
            yield (np.array([i * 3]), np.array([i * 3 + 1]), np.array([i * 3 + 2]))

    def test_config(num, col_order):
        dst = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]).batch(batch_size=num)
        dst = dst.project(col_order)
        # The loop yields exactly one batched row; keep the last (only) item.
        res = {}
        for item in dst.create_dict_iterator(num_epochs=1):
            res = item
        return res

    assert list(test_config(1, ["col3", "col2", "col1"]).keys()) == ["col3", "col2", "col1"]
    assert list(test_config(2, ["col1", "col2", "col3"]).keys()) == ["col1", "col2", "col3"]
    assert list(test_config(3, ["col2", "col3", "col1"]).keys()) == ["col2", "col3", "col1"]


if __name__ == '__main__':
    # Only test_column_order is run directly: it is the one test that needs
    # neither the TFRecord data files nor the golden .npz files.
    test_column_order()