# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Tests for the dataset ``take`` operation, on its own and combined with
``repeat``, ``batch``, ``skip`` and ``filter``.
"""
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger


def generator():
    """Source for GeneratorDataset: 3 single-column rows holding the values 0, 1, 2."""
    for i in range(3):
        yield (np.array([i]),)


def generator_10():
    """Source for GeneratorDataset: 10 single-column rows holding the values 0, 1, 2 ... 9."""
    for i in range(10):
        yield (np.array([i]),)


def filter_func_ge(data):
    """
    Filter predicate: keep a row only when its value is less than or equal to 3.

    Despite the ``_ge`` suffix, rows with a value greater than 3 are the ones dropped.
    """
    if data > 3:
        return False
    return True


def test_take_01():
    """
    Test take: the origin dataset has 3 rows and take(1) keeps 1 row; in this case take will
    not meet eoe and eof
    """
    logger.info("test_take_01")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(1)
    data1 = data1.repeat(2)

    # d refers to a data element; every repeat yields only the first row
    for d in data1:
        assert d[0].asnumpy()[0] == 0

    assert sum(1 for _ in data1) == 2


def test_take_02():
    """
    Test take: the origin dataset has 3 rows and take(2) keeps 2 rows; in this case take will
    meet eoe
    """
    logger.info("test_take_02")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(2)
    data1 = data1.repeat(2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data1):
        assert i % 2 == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 4


def test_take_03():
    """
    Test take: the origin dataset has 3 rows and take(3) keeps all 3 rows; in this case take
    will meet eoe and eof
    """
    logger.info("test_take_03")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(3)
    data1 = data1.repeat(2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data1):
        assert i % 3 == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 6


def test_take_04():
    """
    Test take: the origin dataset has 3 rows and take(4) asks for more rows than the dataset
    holds, so all 3 rows are kept
    """
    logger.info("test_take_04")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(4)
    data1 = data1.repeat(2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data1):
        assert i % 3 == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 6


def test_take_05():
    """
    Test take: there is no repeat op
    """
    logger.info("test_take_05")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data1):
        assert i == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 2


def test_take_06():
    """
    Test take: repeat is before take
    """
    logger.info("test_take_06")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.repeat(2)
    data1 = data1.take(4)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data1):
        assert i % 3 == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 4


def test_take_07():
    """
    Test take: take is before batch, so in take(N) the count N refers to rows
    """
    logger.info("test_take_07")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(2)
    data1 = data1.batch(2)
    assert sum(1 for _ in data1) == 1


def test_take_08():
    """
    Test take: take is after batch, so in take(N) the count N refers to batches
    """
    logger.info("test_take_08")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.batch(2)
    data1 = data1.take(2)
    assert sum(1 for _ in data1) == 2


def test_take_09():
    """
    Test take: take count is -1, which reads the whole dataset; take is after repeat
    """
    logger.info("test_take_09")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.repeat(2)
    data1 = data1.take(-1)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data1):
        assert i % 3 == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 6


def test_take_10():
    """
    Test take: take count is -1, which reads the whole dataset; take is before repeat
    """
    logger.info("test_take_10")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(-1)
    data1 = data1.repeat(2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data1):
        assert i % 3 == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 6


def test_take_11():
    """
    Test take: batch first, then do repeat and take operation
    """
    logger.info("test_take_11")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.batch(2)
    data1 = data1.repeat(2)
    data1 = data1.take(-1)

    # Here i refers to index, d refers to data element; batches start at 0 and 2 alternately
    for i, d in enumerate(data1):
        assert 2 * (i % 2) == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 4


def test_take_12():
    """
    Test take: take first, then do batch and repeat operation
    """
    logger.info("test_take_12")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(2)
    data1 = data1.batch(2)
    data1 = data1.repeat(2)

    # d refers to a data element; the single batch per repeat starts with row 0
    for d in data1:
        assert d[0].asnumpy()[0] == 0

    assert sum(1 for _ in data1) == 2


def test_take_13():
    """
    Test take: skip first, then do take, batch and repeat operation
    """
    logger.info("test_take_13")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.skip(2)
    data1 = data1.take(-1)
    data1 = data1.batch(2)
    data1 = data1.repeat(2)

    # d refers to a data element; only row 2 survives the skip
    for d in data1:
        assert d[0].asnumpy()[0] == 2

    assert sum(1 for _ in data1) == 2


def test_take_14():
    """
    Test take: take first, then do batch, skip and repeat operation
    """
    logger.info("test_take_14")
    data1 = ds.GeneratorDataset(generator, ["data"])

    data1 = data1.take(-1)
    data1 = data1.batch(2)
    data1 = data1.skip(1)
    data1 = data1.repeat(2)

    # d refers to a data element; skipping the first batch leaves the batch holding row 2
    for d in data1:
        assert d[0].asnumpy()[0] == 2

    assert sum(1 for _ in data1) == 2


def test_take_15():
    """
    Test take: large amount of data, take a part, then do skip operation
    """
    logger.info("test_take_15")
    data1 = ds.GeneratorDataset(generator_10, ["data"])

    data1 = data1.take(6)
    data1 = data1.skip(2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data1):
        assert (i + 2) == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 4


def test_take_16():
    """
    Test take: large amount of data, skip a part, then do take operation
    """
    logger.info("test_take_16")
    data1 = ds.GeneratorDataset(generator_10, ["data"])

    data1 = data1.skip(3)
    data1 = data1.take(5)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data1):
        assert (i + 3) == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 5


def test_take_17():
    """
    Test take: take first, then do filter operation
    """
    logger.info("test_take_17")
    data1 = ds.GeneratorDataset(generator_10, ["data"])

    data1 = data1.take(8)
    data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data1):
        assert i == d[0].asnumpy()[0]

    assert sum(1 for _ in data1) == 4


def test_take_18():
    """
    Test take: take first, then do filter, skip, batch and repeat operation
    """
    logger.info("test_take_18")
    data1 = ds.GeneratorDataset(generator_10, ["data"])

    data1 = data1.take(8)
    data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
    data1 = data1.skip(2)

    data1 = data1.batch(2)
    data1 = data1.repeat(2)

    # d refers to a data element; after the filter and skip, each batch starts with row 2
    for d in data1:
        assert d[0].asnumpy()[0] == 2

    assert sum(1 for _ in data1) == 2


def test_take_19():
    """
    Test take: a count of 0 is invalid, so take(0) after batch must raise ValueError
    """
    logger.info("test_take_19")
    data1 = ds.GeneratorDataset(generator, ["data"])
    data1 = data1.batch(2)

    # Only take(0) is expected to raise. Keeping dataset construction and batch outside the
    # raises block stops an unrelated ValueError from making this test pass spuriously.
    with pytest.raises(ValueError) as info:
        data1.take(0)
    assert "within the required interval" in str(info.value)


if __name__ == '__main__':
    test_take_01()
    test_take_02()
    test_take_03()
    test_take_04()
    test_take_05()
    test_take_06()
    test_take_07()
    test_take_08()
    test_take_09()
    test_take_10()
    test_take_11()
    test_take_12()
    test_take_13()
    test_take_14()
    test_take_15()
    test_take_16()
    test_take_17()
    test_take_18()
    test_take_19()
    logger.info('== test take operation finished ==')