• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15import numpy as np
16
17import mindspore.dataset as ds
18
19
def test_tensor_empty():
    """
    Feature: GeneratorDataset with empty tensors
    Description: a Python generator yields 4 rows, each holding an empty int64 tensor,
        an empty bytes tensor of shape (0, 4) and a one-element float64 tensor
    Expectation: every row comes back unchanged and exactly 4 rows are produced
    """
    num_rows = 4
    expected_col1 = np.array([], dtype=np.int64)
    expected_col2 = np.array([], dtype='S').reshape([0, 4])
    expected_col3 = np.array([1], dtype=np.float64)

    def gen():
        for _ in range(num_rows):
            yield (np.array([], dtype=np.int64),
                   np.array([], dtype='S').reshape([0, 4]),
                   np.array([1], dtype=np.float64))

    data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"])

    num_iter = 0
    for d in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(expected_col1, d[0])
        np.testing.assert_array_equal(expected_col2, d[1])
        np.testing.assert_array_equal(expected_col3, d[2])
        num_iter += 1
    # Without this count the per-row assertions would pass vacuously if the
    # pipeline produced no rows at all.
    assert num_iter == num_rows
32
33
def test_tensor_empty_map():
    """
    Feature: map over columns that contain empty tensors
    Description: a generator yields 4 rows of (empty int64, empty bytes, [1.0] float64);
        a map op replaces them with (non-empty int64, non-empty bytes, empty float64)
    Expectation: every row carries the mapped values and exactly 4 rows are produced
    """
    num_rows = 4
    expected_col1 = np.array([1], dtype=np.int64)
    expected_col2 = np.array(["Hi"], dtype='S')
    expected_col3 = np.array([], dtype=np.float64)

    def gen():
        for _ in range(num_rows):
            yield (np.array([], dtype=np.int64),
                   np.array([], dtype='S'),
                   np.array([1], dtype=np.float64))

    data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"])

    def func(x, y, z):
        # The incoming values are intentionally ignored: the point is to turn
        # empty columns into non-empty ones and a non-empty column into an empty one.
        x = np.array([1], dtype=np.int64)
        y = np.array(["Hi"], dtype='S')
        z = np.array([], dtype=np.float64)
        return x, y, z

    data = data.map(operations=func, input_columns=["col1", "col2", "col3"])

    num_iter = 0
    for d in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(expected_col1, d[0])
        np.testing.assert_array_equal(expected_col2, d[1])
        np.testing.assert_array_equal(expected_col3, d[2])
        num_iter += 1
    # Guard against a vacuous pass when no rows come out of the pipeline.
    assert num_iter == num_rows
53
54
def test_tensor_empty_batch():
    """
    Feature: batching columns that contain empty tensors
    Description: a generator yields 4 rows of (empty int64, empty (0, 4) bytes, [1.0] float64),
        which are batched with batch size 2
    Expectation: each batch has a leading axis of size 2 (e.g. shapes (2, 0) and (2, 0, 4)
        for the empty columns), and exactly 4 / 2 = 2 batches are produced
    """
    num_rows = 4
    batch_size = 2
    expected_col1 = np.array([], dtype=np.int64).reshape([batch_size, 0])
    expected_col2 = np.array([], dtype='S').reshape([batch_size, 0, 4])
    expected_col3 = np.array([[1], [1]], dtype=np.float64)

    def gen():
        for _ in range(num_rows):
            yield (np.array([], dtype=np.int64),
                   np.array([], dtype='S').reshape([0, 4]),
                   np.array([1], dtype=np.float64))

    data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"]).batch(batch_size)

    num_iter = 0
    for d in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(expected_col1, d[0])
        np.testing.assert_array_equal(expected_col2, d[1])
        np.testing.assert_array_equal(expected_col3, d[2])
        num_iter += 1
    # num_rows divides evenly by batch_size, so every batch is full.
    assert num_iter == num_rows // batch_size
67
68
if __name__ == '__main__':
    # Run the tests directly (outside pytest) in their declaration order.
    for test_fn in (test_tensor_empty, test_tensor_empty_map, test_tensor_empty_batch):
        test_fn()
73