# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for GeneratorDataset pipelines whose columns contain empty (zero-element) arrays."""
import numpy as np

import mindspore.dataset as ds

# Number of rows yielded by every generator in this file.
NUM_ROWS = 4


def test_tensor_empty():
    """Empty int64 and bytes columns pass unchanged through a plain GeneratorDataset."""
    def gen():
        for _ in range(NUM_ROWS):
            yield (np.array([], dtype=np.int64),
                   np.array([], dtype='S').reshape([0, 4]),
                   np.array([1], dtype=np.float64))

    data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"])

    num_iter = 0
    for d in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(np.array([], dtype=np.int64), d[0])
        np.testing.assert_array_equal(np.array([], dtype='S').reshape([0, 4]), d[1])
        np.testing.assert_array_equal(np.array([1], dtype=np.float64), d[2])
        num_iter += 1
    # Without this, the test would pass vacuously if the pipeline produced no rows.
    assert num_iter == NUM_ROWS


def test_tensor_empty_map():
    """A map op can turn empty columns into non-empty ones and vice versa."""
    def gen():
        for _ in range(NUM_ROWS):
            yield (np.array([], dtype=np.int64),
                   np.array([], dtype='S'),
                   np.array([1], dtype=np.float64))

    data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"])

    def func(_x, _y, _z):
        # Inputs are intentionally ignored: every row is replaced by fixed outputs,
        # swapping which columns are empty (col1/col2 become non-empty, col3 empty).
        return (np.array([1], dtype=np.int64),
                np.array(["Hi"], dtype='S'),
                np.array([], dtype=np.float64))

    data = data.map(operations=func, input_columns=["col1", "col2", "col3"])

    num_iter = 0
    for d in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(np.array([1], dtype=np.int64), d[0])
        np.testing.assert_array_equal(np.array(["Hi"], dtype='S'), d[1])
        np.testing.assert_array_equal(np.array([], dtype=np.float64), d[2])
        num_iter += 1
    # Without this, the test would pass vacuously if the pipeline produced no rows.
    assert num_iter == NUM_ROWS


def test_tensor_empty_batch():
    """Batching rows that contain empty columns stacks them along a new leading axis."""
    batch_size = 2

    def gen():
        for _ in range(NUM_ROWS):
            yield (np.array([], dtype=np.int64),
                   np.array([], dtype='S').reshape([0, 4]),
                   np.array([1], dtype=np.float64))

    data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"]).batch(batch_size)

    num_iter = 0
    for d in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(np.array([], dtype=np.int64).reshape([batch_size, 0]), d[0])
        np.testing.assert_array_equal(np.array([], dtype='S').reshape([batch_size, 0, 4]), d[1])
        np.testing.assert_array_equal(np.array([[1], [1]], dtype=np.float64), d[2])
        num_iter += 1
    # NUM_ROWS is divisible by batch_size, so every batch is full and none is dropped.
    assert num_iter == NUM_ROWS // batch_size


if __name__ == '__main__':
    test_tensor_empty()
    test_tensor_empty_map()
    test_tensor_empty_batch()