# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger


def gen():
    for i in range(100):
        yield (np.array(i),)


class Augment:
    """Minimal augmentation stub whose parameter is refreshed via sync_update."""

    def __init__(self, loss):
        self.loss = loss

    def preprocess(self, input_):
        return input_

    def update(self, data):
        self.loss = data["loss"]


def test_simple_sync_wait():
    """
    Test simple sync wait: test sync in dataset pipeline
    """
    logger.info("test_simple_sync_wait")
    batch_size = 4
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size)
    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert data["input"][0] == count
        count += batch_size
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_simple_shuffle_sync():
    """
    Test simple shuffle sync: test shuffle before sync
    """
    logger.info("test_simple_shuffle_sync")
    shuffle_size = 4
    batch_size = 10

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_two_sync():
    """
    Test two sync: dataset pipeline with two sync_wait operators
    """
    logger.info("test_two_sync")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # note that with this design, the step size needs to equal the shuffle size
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)

    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])

    dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")

    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="every batch", data=data)
        if count % 2 == 0:
            dataset.sync_update(condition_name="every 2 batches")
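

# A minimal sketch of the handshake the tests in this file exercise: sync_wait
# holds the pipeline until the matching sync_update releases it, so the
# callback sees the latest "loss" before the next rows are produced. This
# helper is illustrative only (not collected by pytest); train_step and the
# "sketch" condition name are hypothetical stand-ins, not part of the suite.
def _sync_usage_sketch():
    def train_step(batch):
        # Hypothetical training step; the "loss" here is just the batch mean.
        return float(np.mean(batch["input"]))

    aug = Augment(0)
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    dataset = dataset.sync_wait(condition_name="sketch", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(4)
    for batch in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        loss = train_step(batch)
        dataset.sync_update(condition_name="sketch", data={"loss": loss})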


def test_sync_epoch():
    """
    Test sync wait with epochs: test sync with epochs in dataset pipeline
    """
    logger.info("test_sync_epoch")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size, drop_remainder=True)

    for _ in range(3):
        aug.update({"loss": 0})
        count = 0
        for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert data["input"][0] == count
            count += batch_size
            data = {"loss": count}
            dataset.sync_update(condition_name="policy", data=data)


def test_multiple_iterators():
    """
    Test sync wait with multiple iterators: will start iterators over two
    synced pipelines and step them in lockstep
    """
    logger.info("test_multiple_iterators")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # 2nd dataset
    dataset2 = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug.update)
    dataset2 = dataset2.map(operations=[aug.preprocess], input_columns=["input"])
    dataset2 = dataset2.batch(batch_size, drop_remainder=True)

    for item1, item2 in zip(dataset.create_dict_iterator(num_epochs=1, output_numpy=True),
                            dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        assert item1["input"][0] == item2["input"][0]
        data1 = {"loss": item1["input"][0]}
        data2 = {"loss": item2["input"][0]}
        dataset.sync_update(condition_name="policy", data=data1)
        dataset2.sync_update(condition_name="policy", data=data2)


def test_sync_exception_01():
    """
    Test sync: with shuffle in sync mode
    """
    logger.info("test_sync_exception_01")
    shuffle_size = 4

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])

    with pytest.raises(RuntimeError) as e:
        dataset.shuffle(shuffle_size)
    assert "No shuffle after sync operators" in str(e.value)


def test_sync_exception_02():
    """
    Test sync: with duplicated condition name
    """
    logger.info("test_sync_exception_02")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)

    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])

    with pytest.raises(RuntimeError) as e:
        dataset.sync_wait(num_batch=2, condition_name="every batch")
    assert "Condition name is already in use" in str(e.value)


def test_sync_exception_03():
    """
    Test sync: with wrong batch size
    """
    logger.info("test_sync_exception_03")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # try to create a sync_wait node with num_batch < 0
    with pytest.raises(ValueError) as e:
        dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update)
    assert "num_batch need to be greater than 0." in str(e.value)


def test_sync_exception_04():
    """
    Test sync: with negative batch size in update
    """
    logger.info("test_sync_exception_04")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    count = 0
    with pytest.raises(RuntimeError) as e:
        for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            count += 1
            data = {"loss": count}
            # try to update with num_batch < 0
            dataset.sync_update(condition_name="every batch", num_batch=-1, data=data)
    assert "Sync_update batch size can only be positive" in str(e.value)


def test_sync_exception_05():
    """
    Test sync: update with a condition name that was never registered
    """
    logger.info("test_sync_exception_05")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    count = 0
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    with pytest.raises(RuntimeError) as e:
        for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            dataset.disable_sync()
            count += 1
            data = {"loss": count}
            dataset.disable_sync()
            # "every" was never registered as a condition name
            dataset.sync_update(condition_name="every", data=data)
    assert "Condition name not found" in str(e.value)


def test_simple_sync_wait_empty_condition_name():
    """
    Test sync_wait with no callback and an empty condition_name ('') in both
    sync_wait and sync_update
    """
    logger.info("test_simple_sync_wait_empty_condition_name")
    batch_size = 10
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name='', num_batch=1)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size)

    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="", data=data)


def test_sync_exception_06():
    """
    Test sync: with string batch size
    """
    logger.info("test_sync_exception_06")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # try to create a sync_wait node with a string num_batch
    with pytest.raises(TypeError) as e:
        dataset.sync_wait(condition_name="every batch", num_batch="123", callback=aug.update)
    assert "is not of type [<class 'int'>]" in str(e.value)


if __name__ == "__main__":
    test_simple_sync_wait()
    test_simple_shuffle_sync()
    test_two_sync()
    test_sync_exception_01()
    test_sync_exception_02()
    test_sync_exception_03()
    test_sync_exception_04()
    test_sync_exception_05()
    test_sync_exception_06()
    test_sync_epoch()
    test_multiple_iterators()
    test_simple_sync_wait_empty_condition_name()