• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16
17import mindspore.dataset as ds
18from util import config_get_set_seed, config_get_set_num_parallel_workers
19
20
def generator_1d():
    """
    Generate 1d int numpy arrays with values 0 - 3, one per row.

    Each yielded item is a 1-tuple holding a single-element numpy array,
    so it maps onto one dataset column (e.g. ``["data"]``).
    """
    for i in range(4):
        yield (np.array([i]),)
25
26
def test_case_0():
    """
    Test 1D Generator.
    Test without explicit kwargs for input args.

    Builds a GeneratorDataset over generator_1d, shuffles it with a buffer
    size of 2, applies an identity map on column "data", batches by 2, and
    checks every batch against a fixed expectation for seed 55. The global
    seed and worker count are always restored, even on failure.
    """
    original_seed = config_get_set_seed(55)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # Apply dataset operations.
        data1 = ds.GeneratorDataset(generator_1d, ["data"])
        data1 = data1.shuffle(2)
        data1 = data1.map((lambda x: x), ["data"])
        data1 = data1.batch(2)

        # Expected order is tied to the seed (55) and single worker set above.
        expected_data = np.array([[[1], [2]], [[3], [0]]])
        num_rows = 0
        for i, data_row in enumerate(data1.create_tuple_iterator(output_numpy=True)):
            np.testing.assert_array_equal(data_row[0], expected_data[i])
            num_rows += 1
        # Without this check, a pipeline yielding fewer batches would pass vacuously.
        assert num_rows == len(expected_data)
    finally:
        # Restore global configuration even if an assertion above fails, so
        # the changed seed / worker count does not leak into other tests.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
48
49
def _main():
    """Run this module's test cases when it is executed directly as a script."""
    test_case_0()


if __name__ == "__main__":
    _main()
52