# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision


DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_tf_skip():
    """
    a simple skip operation.

    Decodes and resizes the "image" column of the TFRecord dataset, skips
    two rows and expects exactly one row to remain.
    """
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)

    resize_height, resize_width = 32, 32
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=resize_op, input_columns=["image"])
    data1 = data1.skip(2)

    num_iter = sum(1 for _ in data1.create_dict_iterator(num_epochs=1))
    assert num_iter == 1


def generator_md():
    """
    create a dataset with [0, 1, 2, 3, 4]
    """
    for i in range(5):
        yield (np.array([i]),)


def test_generator_skip():
    """
    skip on a generator dataset that is read with 4 parallel workers.
    """
    ds1 = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4)

    # Here ds1 should be [3, 4]
    ds1 = ds1.skip(3)

    buf = [data[0][0] for data in ds1.create_tuple_iterator(output_numpy=True)]
    assert len(buf) == 2
    assert buf == [3, 4]


def test_skip_1():
    """
    skip more rows than the dataset holds; nothing should be left.
    """
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be []
    ds1 = ds1.skip(7)

    buf = [data[0][0] for data in ds1.create_tuple_iterator(output_numpy=True)]
    assert buf == []


def test_skip_2():
    """
    skip zero rows; the dataset should be unchanged.
    """
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [0, 1, 2, 3, 4]
    ds1 = ds1.skip(0)

    buf = [data[0][0] for data in ds1.create_tuple_iterator(output_numpy=True)]
    assert len(buf) == 5
    assert buf == [0, 1, 2, 3, 4]


def test_skip_repeat_1():
    """
    repeat followed by skip.
    """
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    ds1 = ds1.repeat(2)

    # Here ds1 should be [3, 4, 0, 1, 2, 3, 4]
    ds1 = ds1.skip(3)

    buf = [data[0][0] for data in ds1.create_tuple_iterator(output_numpy=True)]
    assert len(buf) == 7
    assert buf == [3, 4, 0, 1, 2, 3, 4]


def test_skip_repeat_2():
    """
    skip followed by repeat.
    """
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [3, 4]
    ds1 = ds1.skip(3)

    # Here ds1 should be [3, 4, 3, 4]
    ds1 = ds1.repeat(2)

    buf = [data[0][0] for data in ds1.create_tuple_iterator(output_numpy=True)]
    assert len(buf) == 4
    assert buf == [3, 4, 3, 4]


def test_skip_repeat_3():
    """
    repeat, then skip, then repeat again.
    """
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    ds1 = ds1.repeat(2)

    # Here ds1 should be [3, 4]
    ds1 = ds1.skip(8)

    # Here ds1 should be [3, 4, 3, 4, 3, 4]
    ds1 = ds1.repeat(3)

    buf = [data[0][0] for data in ds1.create_tuple_iterator(output_numpy=True)]
    assert len(buf) == 6
    assert buf == [3, 4, 3, 4, 3, 4]


def test_skip_take_1():
    """
    take followed by skip.
    """
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [0, 1, 2, 3]
    ds1 = ds1.take(4)

    # Here ds1 should be [2, 3]
    ds1 = ds1.skip(2)

    buf = [data[0][0] for data in ds1.create_tuple_iterator(output_numpy=True)]
    assert len(buf) == 2
    assert buf == [2, 3]


def test_skip_take_2():
    """
    skip followed by take.
    """
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [2, 3, 4]
    ds1 = ds1.skip(2)

    # Here ds1 should be [2, 3]
    ds1 = ds1.take(2)

    buf = [data[0][0] for data in ds1.create_tuple_iterator(output_numpy=True)]
    assert len(buf) == 2
    assert buf == [2, 3]


def generator_1d():
    """
    create a dataset with [0, 1, ..., 63]
    """
    for i in range(64):
        yield (np.array([i]),)


def test_skip_filter_1():
    """
    skip followed by filter.
    """
    dataset = ds.GeneratorDataset(generator_1d, ['data'])
    dataset = dataset.skip(5)
    dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)

    buf = [item[0][0] for item in dataset.create_tuple_iterator(output_numpy=True)]
    assert buf == [5, 6, 7, 8, 9, 10]


def test_skip_filter_2():
    """
    filter followed by skip.
    """
    dataset = ds.GeneratorDataset(generator_1d, ['data'])
    dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
    dataset = dataset.skip(5)

    buf = [item[0][0] for item in dataset.create_tuple_iterator(output_numpy=True)]
    assert buf == [5, 6, 7, 8, 9, 10]


def test_skip_exception_1():
    """
    skip with a negative keyword count must raise ValueError.
    """
    data1 = ds.GeneratorDataset(generator_md, ["data"])

    # pytest.raises fails the test when nothing is raised; the previous
    # try/except passed silently in that case. The iteration stays inside the
    # context so the error is caught whether it surfaces at skip() or later.
    with pytest.raises(ValueError) as e:
        data1 = data1.skip(count=-1)
        for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "Input count is not within the required interval" in str(e.value)


def test_skip_exception_2():
    """
    skip with a negative positional count must raise ValueError.
    """
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    with pytest.raises(ValueError) as e:
        ds1 = ds1.skip(-2)
    assert "Input count is not within the required interval" in str(e.value)


if __name__ == "__main__":
    test_tf_skip()
    test_generator_skip()
    test_skip_1()
    test_skip_2()
    test_skip_repeat_1()
    test_skip_repeat_2()
    test_skip_repeat_3()
    test_skip_take_1()
    test_skip_take_2()
    test_skip_filter_1()
    test_skip_filter_2()
    test_skip_exception_1()
    test_skip_exception_2()