• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16import pytest
17
18import mindspore.common.dtype as mstype
19from mindspore.common.tensor import Tensor
20import mindspore.dataset as ds
21from mindspore.dataset.engine.iterators import ITERATORS_LIST, _cleanup
22
23DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
24SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
25COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
26           "col_sint16", "col_sint32", "col_sint64"]
27
28
def check(project_columns):
    """
    Compare iterator-time column projection with reader-time column projection.

    ``data1`` reads every column in COLUMNS and projects ``project_columns`` when
    the tuple iterator is created; ``data2`` reads only ``project_columns``.
    Both must yield the same number of rows with identical NumPy contents.

    Args:
        project_columns (list[str]): Column names to project.
    """
    # Local import keeps this helper self-contained.
    from itertools import zip_longest

    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS, shuffle=False)
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns, shuffle=False)

    actual_rows = data1.create_tuple_iterator(project_columns, num_epochs=1, output_numpy=True)
    expected_rows = data2.create_tuple_iterator(num_epochs=1, output_numpy=True)

    # zip() would stop at the shorter iterator and hide a row-count mismatch;
    # zip_longest pads the shorter side with a sentinel so the mismatch is caught.
    missing = object()
    for data_actual, data_expected in zip_longest(actual_rows, expected_rows, fillvalue=missing):
        assert data_actual is not missing and data_expected is not missing, \
            "projected datasets yield a different number of rows"
        assert len(data_actual) == len(data_expected)
        assert all(np.array_equal(d1, d2) for d1, d2 in zip(data_actual, data_expected))
37
38
def test_iterator_create_tuple_numpy():
    """
    Test creating tuple iterator with output NumPy.

    Each case projects a different subset of COLUMNS and checks that projecting
    at iterator creation matches reading only those columns.
    """
    check(COLUMNS)
    check(COLUMNS[0:1])
    check(COLUMNS[0:2])
    check(COLUMNS[0:7])
    check(COLUMNS[7:8])
    # Non-contiguous projection: every other column
    # (col_1d, col_3d, col_float, col_sint32).
    check(COLUMNS[0:8:2])
49
def test_iterator_create_dict_mstensor():
    """
    A dict iterator created without output_numpy must yield MindSpore Tensors.

    A generator of 64 single-element float32 arrays backs a GeneratorDataset.
    Every row's "data" entry must be a float32 Tensor holding the row index,
    and exactly 64 rows must be produced.
    """
    num_rows = 64

    def generator():
        for value in range(num_rows):
            yield (np.array([value], dtype=np.float32),)

    # apply dataset operations
    dataset = ds.GeneratorDataset(generator, ["data"])

    seen = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        tensor = row["data"]
        np.testing.assert_array_equal(tensor.asnumpy(), np.array([seen], dtype=np.float32))
        assert isinstance(tensor, Tensor)
        assert tensor.dtype == mstype.float32
        seen += 1
    assert seen == num_rows
69
def test_iterator_create_tuple_mstensor():
    """
    A tuple iterator created without output_numpy must yield MindSpore Tensors.

    A generator of 64 single-element float32 arrays backs a GeneratorDataset.
    The first element of every row must be a float32 Tensor holding the row
    index, and exactly 64 rows must be produced.
    """
    num_rows = 64

    def generator():
        for value in range(num_rows):
            yield (np.array([value], dtype=np.float32),)

    # apply dataset operations
    dataset = ds.GeneratorDataset(generator, ["data"])

    seen = 0
    for row in dataset.create_tuple_iterator(num_epochs=1):
        tensor = row[0]
        np.testing.assert_array_equal(tensor.asnumpy(), np.array([seen], dtype=np.float32))
        assert isinstance(tensor, Tensor)
        assert tensor.dtype == mstype.float32
        seen += 1
    assert seen == num_rows
89
90
def test_iterator_weak_ref():
    """
    Test that the global ITERATORS_LIST tracks live iterators and that _cleanup tears them down.

    The test checks three things:
    - Each created tuple iterator adds one entry to ITERATORS_LIST.
    - Deleting an iterator removes its entry.
    - After _cleanup(), a surviving iterator can no longer be advanced: it raises
      AttributeError about a missing '_runtime_context'.

    NOTE(review): the entries are called (``itr()``) to get the iterator or None,
    which suggests they are weak references. Removal immediately on ``del`` is
    presumably a weakref callback and relies on CPython reference counting;
    confirm against mindspore.dataset.engine.iterators.
    """
    # Start from an empty registry so the counts below are exact.
    ITERATORS_LIST.clear()
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    itr1 = data.create_tuple_iterator(num_epochs=1)
    itr2 = data.create_tuple_iterator(num_epochs=1)
    itr3 = data.create_tuple_iterator(num_epochs=1)

    # One registry entry per created iterator, each still resolving to a live object.
    assert len(ITERATORS_LIST) == 3
    assert sum(itr() is not None for itr in ITERATORS_LIST) == 3

    # Dropping the last reference to an iterator must remove its registry entry.
    del itr1
    assert len(ITERATORS_LIST) == 2
    assert sum(itr() is not None for itr in ITERATORS_LIST) == 2

    del itr2
    assert len(ITERATORS_LIST) == 1
    assert sum(itr() is not None for itr in ITERATORS_LIST) == 1

    del itr3
    assert ITERATORS_LIST == []
    assert sum(itr() is not None for itr in ITERATORS_LIST) == 0

    itr1 = data.create_tuple_iterator(num_epochs=1)
    itr2 = data.create_tuple_iterator(num_epochs=1)
    itr3 = data.create_tuple_iterator(num_epochs=1)

    # _cleanup() tears down registered iterators: a still-referenced iterator
    # then fails on __next__ because its '_runtime_context' attribute is gone.
    _cleanup()
    with pytest.raises(AttributeError) as info:
        itr2.__next__()
    assert "object has no attribute '_runtime_context'" in str(info.value)

    # NOTE(review): the registry is empty here even though itr2/itr3 are still
    # referenced, so _cleanup() presumably also clears ITERATORS_LIST.
    del itr1
    assert ITERATORS_LIST == []

    # Leave no registered iterators behind for later tests.
    _cleanup()
126
def test_iterator_exception():
    """
    Test that a non-bool ``output_numpy`` argument is rejected with TypeError.

    Both ``create_dict_iterator`` and ``create_tuple_iterator`` are checked with
    a string value ("123") and an integer value (123).
    """
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    # pytest.raises fails the test when nothing is raised. The previous
    # `try: ...; assert False` sentinel was stripped under `python -O`, so a
    # missing TypeError went unnoticed. `match` keeps the message check active too.
    for create_iterator in (data.create_dict_iterator, data.create_tuple_iterator):
        for bad_value in ("123", 123):
            with pytest.raises(TypeError, match="Argument output_numpy with value 123 is not of type"):
                _ = create_iterator(output_numpy=bad_value)
152
153
class MyDict(dict):
    """
    A dict whose attributes are its keys, usable as a callable map operation.

    test_tree_copy passes an instance to ``Dataset.map`` as its "pyfunc that
    cannot be pickled". Calling an instance returns its argument unchanged.
    """

    def __getattr__(self, key):
        # Only reached for names not found by normal attribute lookup. It
        # forwards to the dict, so a missing name raises KeyError, not
        # AttributeError.
        # NOTE(review): this looks deliberate. Probes such as
        # getattr(obj, name, default) then propagate KeyError, which is
        # presumably what makes the instance non-picklable/non-copyable for
        # test_tree_copy. Confirm before "fixing" it to raise AttributeError.
        return self[key]

    def __setattr__(self, key, value):
        # Attribute assignment stores a dict item instead of an instance attribute.
        self[key] = value

    def __call__(self, t):
        # Identity transform: the map operation returns its input unchanged.
        return t
163
164
def test_tree_copy():
    """
    Testing copying the tree with a pyfunc that cannot be pickled.

    Creating an iterator gives it a copy of the dataset tree. The dataset
    nodes seen by the iterator must be different objects from the user's
    nodes. The user-supplied pyfunc (a MyDict instance) must be shared,
    not copied.
    """
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS)
    data1 = data.map(operations=[MyDict()])

    itr = data1.create_tuple_iterator(num_epochs=1)
    try:
        assert itr.dataset is not data1
        assert itr.dataset.children[0] is not data
        assert itr.dataset.operations[0] is data1.operations[0]
    finally:
        # Release the iterator even when an assertion fails, so its pipeline
        # resources are not left behind for later tests.
        itr.release()
180
181
if __name__ == '__main__':
    # Run every test in this module, in definition order, when executed as a script.
    test_iterator_create_tuple_numpy()
    test_iterator_create_dict_mstensor()
    test_iterator_create_tuple_mstensor()
    test_iterator_weak_ref()
    test_iterator_exception()
    test_tree_copy()
187