# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest

import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
import mindspore.dataset as ds
from mindspore.dataset.engine.iterators import ITERATORS_LIST, _cleanup

DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
           "col_sint16", "col_sint32", "col_sint64"]


def check(project_columns):
    """
    Compare a tuple iterator projected onto project_columns against a dataset created with the same column projection
    """
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS, shuffle=False)
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns, shuffle=False)

    for data_actual, data_expected in zip(data1.create_tuple_iterator(project_columns, num_epochs=1, output_numpy=True),
                                          data2.create_tuple_iterator(num_epochs=1, output_numpy=True)):
        assert len(data_actual) == len(data_expected)
        assert all([np.array_equal(d1, d2) for d1, d2 in zip(data_actual, data_expected)])


def test_iterator_create_tuple_numpy():
    """
    Test creating tuple iterator with output NumPy
    """
    check(COLUMNS)
    check(COLUMNS[0:1])
    check(COLUMNS[0:2])
    check(COLUMNS[0:7])
    check(COLUMNS[7:8])
    check(COLUMNS[0:2:8])


def test_iterator_create_dict_mstensor():
    """
    Test creating dict iterator with output MSTensor
    """
    def generator():
        for i in range(64):
            yield (np.array([i], dtype=np.float32),)

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator, ["data"])

    i = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        golden = np.array([i], dtype=np.float32)
        np.testing.assert_array_equal(item["data"].asnumpy(), golden)
        assert isinstance(item["data"], Tensor)
        assert item["data"].dtype == mstype.float32
        i += 1
    assert i == 64


def test_iterator_create_tuple_mstensor():
    """
    Test creating tuple iterator with output MSTensor
    """
    def generator():
        for i in range(64):
            yield (np.array([i], dtype=np.float32),)

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator, ["data"])

    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1):
        golden = np.array([i], dtype=np.float32)
        np.testing.assert_array_equal(item[0].asnumpy(), golden)
        assert isinstance(item[0], Tensor)
        assert item[0].dtype == mstype.float32
        i += 1
    assert i == 64


def test_iterator_weak_ref():
    """
    Test that ITERATORS_LIST tracks iterators as weak references and that _cleanup releases them
    """
    ITERATORS_LIST.clear()
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    itr1 = data.create_tuple_iterator(num_epochs=1)
    itr2 = data.create_tuple_iterator(num_epochs=1)
    itr3 = data.create_tuple_iterator(num_epochs=1)

    assert len(ITERATORS_LIST) == 3
    assert sum(itr() is not None for itr in ITERATORS_LIST) == 3

    del itr1
    assert len(ITERATORS_LIST) == 2
    assert sum(itr() is not None for itr in ITERATORS_LIST) == 2

    del itr2
    assert len(ITERATORS_LIST) == 1
    assert sum(itr() is not None for itr in ITERATORS_LIST) == 1

    del itr3
    assert ITERATORS_LIST == []
    assert sum(itr() is not None for itr in ITERATORS_LIST) == 0

    itr1 = data.create_tuple_iterator(num_epochs=1)
    itr2 = data.create_tuple_iterator(num_epochs=1)
    itr3 = data.create_tuple_iterator(num_epochs=1)

    _cleanup()
    with pytest.raises(AttributeError) as info:
        itr2.__next__()
    assert "object has no attribute '_runtime_context'" in str(info.value)

    del itr1
    assert ITERATORS_LIST == []

    _cleanup()


def test_iterator_exception():
    """
    Test that create_dict_iterator and create_tuple_iterator raise TypeError for invalid output_numpy values
    """
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    try:
        _ = data.create_dict_iterator(output_numpy="123")
        assert False
    except TypeError as e:
        assert "Argument output_numpy with value 123 is not of type" in str(e)

    try:
        _ = data.create_dict_iterator(output_numpy=123)
        assert False
    except TypeError as e:
        assert "Argument output_numpy with value 123 is not of type" in str(e)

    try:
        _ = data.create_tuple_iterator(output_numpy="123")
        assert False
    except TypeError as e:
        assert "Argument output_numpy with value 123 is not of type" in str(e)

    try:
        _ = data.create_tuple_iterator(output_numpy=123)
        assert False
    except TypeError as e:
        assert "Argument output_numpy with value 123 is not of type" in str(e)


class MyDict(dict):
    """Callable dict subclass used as a pyfunc that cannot be pickled (see test_tree_copy)."""

    def __getattr__(self, key):
        return self[key]

    def __setattr__(self, key, value):
        self[key] = value

    def __call__(self, t):
        return t


def test_tree_copy():
    """
    Test copying the tree with a pyfunc that cannot be pickled
    """
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS)
    data1 = data.map(operations=[MyDict()])

    itr = data1.create_tuple_iterator(num_epochs=1)

    assert id(data1) != id(itr.dataset)
    assert id(data) != id(itr.dataset.children[0])
    assert id(data1.operations[0]) == id(itr.dataset.operations[0])

    itr.release()


if __name__ == '__main__':
    test_iterator_create_tuple_numpy()
    test_iterator_weak_ref()
    test_iterator_exception()
    test_tree_copy()