# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset helper for the minddata dataset."""
from __future__ import absolute_import

import math
import copy

from mindspore import _checkparam as Validator
from mindspore import log as logger
from mindspore.common._auto_dynamic import is_auto_dynamic, convert_new_shapes
from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import _cell_graph_executor, _is_args_fullmode, ARG_SPECIFIED
from mindspore.common._utils import is_shape_unknown
from mindspore.dataset.engine import offload
from mindspore import context, nn
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, \
    _to_full_shapes, _get_pipeline_stages, _change_symbols_for_parallel, _is_in_auto_parallel_mode, \
    _origin_shapes, _dynamic_shape_for_dataset
from mindspore.parallel._ps_context import _is_role_sched
from mindspore.ops import operations as P
from mindspore.common.auto_dynamic_shape import _auto_dynamic_shape


def _send_data(dataset, epoch_num):
    """Engine dataset to write data to tdt queue."""
    if not hasattr(dataset, '__has_sent__'):
        exec_dataset = dataset.__transfer_dataset__
        exec_dataset.send(epoch_num)
        dataset.__has_sent__ = True


def _send_data_no_flag(dataset, epoch_num):
    """Engine dataset to write data to tdt queue directly."""
    exec_dataset = dataset.__transfer_dataset__
    exec_dataset.send(epoch_num)


def _dynamic_sink_data(dataset, dataset_iter):
    """Special scenario for dataset with sink_size=1."""
    if hasattr(dataset_iter, "sink_size") and \
            dataset_iter.sink_size == 1 and \
            dataset.get_dataset_size() != 1 and \
            not hasattr(dataset, "__no_send__") and \
            hasattr(dataset_iter, "sink_count") and \
            dataset_iter.sink_count == 1:
        return True
    return False


def _dynamic_sink_exception_scenario(dataset_iter, is_dynamic):
    """The exception scenario where dynamic data sinking is not applicable."""
    if context.get_context("mode") != context.GRAPH_MODE or is_dynamic:
        return True
    return False


def _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic):
    """Special scenario with dynamic shape and sink_size=1."""
    flag = False

    # This is used only for test
    if is_auto_dynamic():
        return False

    if _dynamic_sink_data(dataset, dataset_iter) and not _dynamic_sink_exception_scenario(dataset_iter, is_dynamic):
        flag = True

    return flag


class _DataWrapper(nn.Cell):
    """
    Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
    dataset channel 'queue_name' and performs the forward computation.
    """

    def __init__(self, network, dataset_types, dataset_shapes, queue_name):
        super(_DataWrapper, self).__init__(
            auto_prefix=False, flags=network.get_flags())
        # Also copy the flags set on `network.construct`
        flags = getattr(network.__class__.construct, "_func_graph_flags", {})
        self.info = (dataset_types, dataset_shapes)
        self.add_flags(**flags)
        self.get_next = P.GetNext(
            dataset_types, dataset_shapes, len(dataset_types), queue_name)
        if network.get_inputs() is not None:
            network_inputs = network.get_inputs()
            is_fullmode = _is_args_fullmode(network_inputs, False)
            if is_fullmode:
                symbol_inputs = [getattr(inp, "symbolic_shape", None) for inp in network.get_inputs()]
            else:
                symbol_inputs = [None for _ in dataset_shapes]
                arg_specified = network_inputs.get(ARG_SPECIFIED, [])
                for idx, inp in arg_specified:
                    symbol_inputs[idx] = getattr(inp, "symbolic_shape", None)
            symbols_for_parallel = _change_symbols_for_parallel(dataset_shapes, copy.deepcopy(symbol_inputs))
            if any((s is not None for s in symbols_for_parallel)):
                self.get_next.add_prim_attr("symbols", symbol_inputs)
                self.get_next.add_prim_attr("symbols_for_parallel", symbols_for_parallel)
        self.network = network
        self._get_attr_from_cell(network)

    def construct(self):
        outputs = self.get_next()
        return self.network(*outputs)
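

# A rough conceptual sketch of the wrapper above (illustration only, not executed here):
# in sink mode the generated cell takes no arguments; every call pulls one batch from the
# device data channel `queue_name` via P.GetNext and feeds it to the wrapped network,
# roughly equivalent to the following hypothetical helper:
#
#     def sink_step(network, get_next):
#         inputs = get_next()          # fetch one batch from the data channel
#         return network(*inputs)      # forward computation on the fetched batch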


def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name):
    if not isinstance(network, _DataWrapper):
        network = _DataWrapper(
            network, dataset_types, dataset_shapes, queue_name)
    return network


def _has_dynamic_shape(dataset_shapes):
    for shape in dataset_shapes:
        if is_shape_unknown(shape):
            return True
    return False


def _generate_network_with_dataset(network, dataset_helper, queue_name):
    """
    Generate a new network from the network and the dataset info.
    """
    dataset_types, dataset_shapes = dataset_helper.types_shapes()

    # This is used only for test
    if is_auto_dynamic():
        new_shapes = convert_new_shapes(dataset_shapes)
        return _generate_dataset_sink_mode_net(network, new_shapes, dataset_types, queue_name)

    if network.get_inputs() and None not in network.get_inputs():
        if _is_in_auto_parallel_mode():
            # here, the dataset shapes have been processed by full_shape(), so restore them to the original shapes;
            # _check_inputs() will change the static origin_shape to a dynamic shape,
            # and after _check_inputs(), dataset_shapes is converted to a dynamic shape as well
            origin_shape = _origin_shapes(dataset_shapes)
            _check_inputs(network.get_inputs(), origin_shape, dataset_types)
            dataset_shapes = _dynamic_shape_for_dataset(dataset_shapes, origin_shape)
        else:
            _check_inputs(network.get_inputs(), dataset_shapes, dataset_types)
    elif context.get_context("mode") == context.PYNATIVE_MODE:
        dataset_shapes = tuple([(-2,)] * len(dataset_shapes))
    network = _generate_dataset_sink_mode_net(
        network, dataset_shapes, dataset_types, queue_name)
    return network


def _check_inputs(network_shapes, dataset_shapes, dataset_types):
    """
    Check if the set inputs are correct.
    """
    if not _is_args_fullmode(network_shapes, False):
        temp_network_shapes = [None for _ in dataset_shapes]
        arg_specified = network_shapes.get(ARG_SPECIFIED, [])
        for idx, inp in arg_specified:
            temp_network_shapes[idx] = inp
        network_shapes = temp_network_shapes

    for tensor_index, ele_dataset_shape in enumerate(dataset_shapes):
        if network_shapes[tensor_index] is None:
            continue
        set_inputs_shape = list(network_shapes[tensor_index].shape)
        inputs_shape = list(ele_dataset_shape)
        if dataset_types[tensor_index] != network_shapes[tensor_index].dtype:
            raise TypeError(
                f"The {tensor_index + 1}th input type of 'set_inputs' must be the same as network's input, "
                f"but got 'set_inputs': {network_shapes[tensor_index].dtype} and network's "
                f"input: {dataset_types[tensor_index]}."
            )
        if len(inputs_shape) != len(set_inputs_shape):
            raise ValueError(
                f"The {tensor_index + 1}th input dims of 'set_inputs' must be the same as network's input, "
                f"but got 'set_inputs': {len(set_inputs_shape)} and network's input: {len(inputs_shape)}.")
        for index, ele_shape in enumerate(ele_dataset_shape):
            if network_shapes[tensor_index].shape[index] != -1:
                if set_inputs_shape[index] != ele_shape:
                    raise ValueError(
                        f"The {index + 1}th input shape of 'set_inputs' must be the same as network's input, "
                        f"but got 'set_inputs': {set_inputs_shape[index]} and network's input: "
                        f"{dataset_shapes[tensor_index][index]}.")
            else:
                dataset_shapes[tensor_index][index] = -1
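

# A minimal sketch of the `set_inputs` contract enforced by _check_inputs above (assumed
# user-side code for illustration, not part of this module): a dimension declared as dynamic
# in `set_inputs` is skipped by the per-dimension comparison and propagated as -1 into the
# dataset shape; every other dimension and the dtype must match the dataset exactly.
#
#     import mindspore as ms
#     from mindspore import Tensor, nn
#     net = nn.Dense(10, 5)
#     # dynamic batch dimension, fixed feature dimension of 10
#     net.set_inputs(Tensor(shape=[None, 10], dtype=ms.float32))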


class _DatasetAux:
    @staticmethod
    def __deepcopy__(memodict):
        return


def _get_dataset_aux(dataset):
    if not hasattr(dataset, '__network_aux__'):
        dataset.__network_aux__ = _DatasetAux()
    return dataset.__network_aux__


def connect_network_with_dataset(network, dataset_helper):
    """
    Connect the `network` with the dataset in `dataset_helper`. Only supported in `sink mode
    <https://mindspore.cn/tutorials/experts/en/master/optimize/execution_opt.html>`_ (dataset_sink_mode=True).

    Args:
        network (Cell): The training network for the dataset.
        dataset_helper (DatasetHelper): A class to process the MindData dataset, which provides the type, shape and
            queue name of the dataset.

    Returns:
        Cell, a new network containing the type, shape and queue name of the dataset info.

    Raises:
        RuntimeError: If the API was not called in dataset sink mode.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> from mindspore import dataset as ds
        >>>
        >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
        >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
        >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
        >>> net = nn.Dense(10, 5)
        >>> net_with_dataset = ms.connect_network_with_dataset(net, dataset_helper)
    """
    dataset_iter = dataset_helper.iter
    dataset = dataset_iter.dataset
    aux = _get_dataset_aux(dataset)

    if isinstance(dataset_iter, _DatasetIterNormal):
        raise RuntimeError(
            "The API 'connect_network_with_dataset' should be called in dataset sink mode.")

    if _is_role_sched():
        network.add_flags(sink_mode=True)
        return network

    if not hasattr(aux, '__network__'):
        aux.__network__ = network

    if aux.__network__ is not network:
        raise ValueError(
            "The dataset has been connected to another network, please check the code.")
    is_dynamic = bool(network.get_inputs())
    queue_name = dataset.__transfer_dataset__.queue_name
    if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic):
        dataset_types, dataset_shapes = dataset_helper.get_data_info()
        # Need to do full_batch for the shapes, as is also done in _DatasetIterMSLoopSink
        if _need_to_full():
            dataset_shapes = _to_full_shapes(dataset_shapes, _get_device_num() // _get_pipeline_stages())
        dataset_types = [pytype_to_dtype(x) for x in dataset_types]
        if not is_dynamic:
            dataset_shapes = _auto_dynamic_shape.auto_dynamic_generate_compile_args(dataset_shapes, True)
        key = str(dataset_types) + str(dataset_shapes)

        if hasattr(aux, "__shape_type__") and aux.__shape_type__ != key:
            _auto_dynamic_shape.update_phase_and_compile_args(dataset_shapes, key, True, aux)
        if hasattr(aux, '__network_manage__') and key in aux.__network_manage__:
            network = aux.__network_manage__[key]
        else:
            network = _generate_dataset_sink_mode_net(
                network, dataset_shapes, dataset_types, queue_name)
            if not hasattr(aux, '__network_manage__'):
                aux.__network_manage__ = dict()
            aux.__network_manage__[key] = network
        network.add_flags(sink_mode=True)
        return network

    if hasattr(aux, '__sink_network__'):
        network = aux.__sink_network__
    else:
        if context.get_context("device_target") in ("Ascend", "GPU"):
            network = offload.check_add_offload_sink_mode(
                dataset, dataset_helper, network)
            network = _generate_network_with_dataset(
                network, dataset_helper, queue_name)
            aux.__sink_network__ = network
            dataset_types, dataset_shapes = dataset_helper.types_shapes()
            aux.__shape_type__ = str(dataset_types) + str(dataset_shapes)

    if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic):
        dataset_helper.get_data_info()
    network.add_flags(sink_mode=True)
    return network


class DatasetHelper:
    """
    DatasetHelper is a class to process the MindData dataset and provide information about the dataset.

    According to the execution context, it adjusts the dataset iteration so that the same for-loop can be used in
    different contexts.

    Note:
        One iteration of DatasetHelper provides one epoch of data.

    Args:
        dataset (Dataset): The dataset iterator. The dataset can be generated by a dataset generator API in the
            `mindspore.dataset` module, such as :class:`mindspore.dataset.ImageFolderDataset`.
        dataset_sink_mode (bool): If the value is True, GetNext is employed to fetch the data at device through the
            dataset pipeline, otherwise fetch the data at host by iterating through the dataset.
            Default: ``True``.
        sink_size (int): Control the amount of data in each sink.
            If sink_size=-1, sink the complete dataset for each epoch.
            If sink_size>0, sink sink_size data for each epoch.
            Default: -1.
        epoch_num (int): The number of passes of the entire dataset to be sent. Default: 1.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> from mindspore import dataset as ds
        >>>
        >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
        >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
        >>> set_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=False)
        >>>
        >>> net = nn.Dense(10, 5)
        >>> # Object of DatasetHelper is iterable
        >>> for next_element in set_helper:
        ...     # `next_element` includes data and label, using data to run the net
        ...     data = next_element[0]
        ...     result = net(data)
    """

    def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
        Validator.check_is_int(sink_size)
        if sink_size < -1 or sink_size == 0:
            raise ValueError(
                "The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size))
        if sink_size == -1:
            sink_size = dataset.get_dataset_size()

        if dataset_sink_mode:
            if context.get_context("mode") == context.GRAPH_MODE:
                if _is_role_sched():
                    iterclass = _DatasetIterPSServer
                elif (context.get_context("device_target") == "Ascend") or \
                        (context.get_context("device_target") == "GPU"):
                    iterclass = _DatasetIterMSLoopSink
                else:
                    target = context.get_context("device_target")
                    raise RuntimeError("Currently dataset sink mode is not supported when the device "
                                       "target is {}, please set dataset_sink_mode to False "
                                       "in Model.train()".format(target))
            else:
                iterclass = _DatasetIterPyNative
            self.iter = iterclass(dataset, sink_size, epoch_num)
        else:
            iterclass = _DatasetIterNormal
            self.iter = iterclass(dataset, epoch_num=epoch_num)

    def __iter__(self):
        return self.iter.__iter__()

    # A temp solution for loop sink. Delete later
    def types_shapes(self):
        """
        Get the types and shapes of the dataset under the current configuration.

        Examples:
            >>> import mindspore as ms
            >>> import numpy as np
            >>>
            >>> # Define a dataset pipeline
            >>> def generator():
            ...     for i in range(5):
            ...         yield (np.ones((32, 10)),)
            >>>
            >>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"])
            >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
            >>>
            >>> types, shapes = dataset_helper.types_shapes()
        """
        return self.iter.types_shapes()

    def sink_size(self):
        """
        Get sink_size for each iteration.

        Examples:
            >>> import mindspore as ms
            >>> import numpy as np
            >>>
            >>> # Define a dataset pipeline
            >>> def generator():
            ...     for i in range(5):
            ...         yield (np.ones((32, 10)),)
            >>>
            >>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"])
            >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True, sink_size=-1)
            >>>
            >>> # if sink_size==-1, the full size of the source dataset is returned.
            >>> sink_size = dataset_helper.sink_size()
        """
        return self.iter.get_sink_size()

    def stop_send(self):
        """
        Stop sending data to the data sink.

        Examples:
            >>> import mindspore as ms
            >>> import numpy as np
            >>> # Define a dataset pipeline
            >>> def generator():
            ...     for i in range(5):
            ...         yield (np.ones((32, 10)),)
            >>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"])
            >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True, sink_size=-1)
            >>> dataset_helper.stop_send()
        """
        self.iter.stop_send()

    def release(self):
        """
        Free up the resources used by the data sink.

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore import dataset as ds
            >>>
            >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
            >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
            >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
            >>> dataset_helper.release()
        """
        self.iter.release()

    def continue_send(self):
        """
        Continue to send data to the device at the beginning of an epoch.

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore import dataset as ds
            >>>
            >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
            >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
            >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
            >>> dataset_helper.continue_send()
        """
        self.iter.continue_send()

    def _reset(self, step, dataset_size):
        """Reset the dataset to the provided step and epoch."""
        self.iter._reset(step, dataset_size)  # pylint: disable=protected-access

    # pylint: disable=missing-docstring
    def get_data_info(self):
        # In sink mode, it returns the types and shapes of the current data.
        # Generally, it works in dynamic shape scenarios.
        return self.iter.get_data_info()
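
    # A hedged usage sketch for get_data_info (illustration only; `helper` is assumed to be a
    # DatasetHelper created with dataset_sink_mode=True on a dynamic-shape dataset):
    #
    #     types, shapes = helper.get_data_info()
    #     # `types` holds the Python/NumPy types of the current batch (convert them with
    #     # pytype_to_dtype if MindSpore dtypes are needed), `shapes` its concrete shapes,
    #     # which may differ from batch to batch.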

    # pylint: disable=missing-docstring
    def get_mbuf_queue_size(self):
        # In sink mode, it returns the number of elements currently inside the mbuf channel.
        return self.iter.get_mbuf_queue_size()

    # pylint: disable=missing-docstring
    def get_send_info(self, run_context):
        # In sink mode, it returns the send information of the dataset at this moment.
        # Send information includes the number of sent batches, a time summary of fetching data on the host
        # and a time summary of sending data.
        class InfoViewer:
            '''
            Inner class for parsing send info.
            '''

            def __init__(self, send_info, run_context):
                self.info_ = {}
                self.sink_size = run_context.original_args()["batch_num"]
                if run_context.original_args().get("train_dataset", None) is not None:
                    self.dataset_size = run_context.original_args()["train_dataset"].get_dataset_size()
                elif run_context.original_args().get("valid_dataset", None) is not None:
                    self.dataset_size = run_context.original_args()["valid_dataset"].get_dataset_size()
                else:
                    raise RuntimeError("Could not find a proper dataset to estimate dataset size.")
                if not send_info:
                    epoch = 1
                    self.info_[epoch] = {'fetch_data_num': 0, 'fetch_data_time': 0, 'first_data_time': 0}
                else:
                    for info_per_epoch in send_info:
                        epoch, fetch_data_num, first_data_time, fetch_data_time = info_per_epoch
                        if fetch_data_num > 1:
                            fetch_data_time = (fetch_data_time - first_data_time) / (fetch_data_num - 1) * 1000.
                        self.info_[epoch] = {'fetch_data_num': fetch_data_num,
                                             'fetch_data_time': fetch_data_time,
                                             'first_data_time': first_data_time}

            def epoch(self, epoch):
                if self.sink_size == self.dataset_size:
                    return self.info_[epoch]
                global_step = epoch * self.sink_size
                data_epoch = math.ceil(global_step / self.dataset_size)
                return self.info_[data_epoch]

        # send info struct: [epoch, data_num_per_epoch, first_data_time, accumulate_data_time]
        # for example [1, 1875, 0.421, 0.362]
        send_info = self.iter.get_send_info()
        return InfoViewer(send_info, run_context)
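

# A hedged usage sketch for DatasetHelper.get_send_info above (illustration only; it assumes
# `helper` is the DatasetHelper used for training and `run_context` is the RunContext object
# passed to a mindspore.train.Callback method):
#
#     viewer = helper.get_send_info(run_context)
#     stats = viewer.epoch(1)
#     # `stats` is a dict with 'fetch_data_num', 'fetch_data_time' (average per-batch host
#     # fetch time in milliseconds after the first batch) and 'first_data_time'.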


class _DatasetIter:
    """Base iter for dataset helper"""

    def __init__(self, dataset, sink_size, epoch_num):
        self.dataset = dataset
        self.sink_size = sink_size
        self.sink_count = self.get_sink_count(dataset)
        self.dataset_types, self.dataset_shapes = _get_types_and_shapes(
            dataset)

        if dataset.get_init_step() % sink_size != 0:
            init_epoch = dataset.get_init_step() // sink_size
            init_step = init_epoch * sink_size
            logger.warning("Init global step must be the end of the epoch in sink mode, "
                           "but got: {0}. Reset it to the end of epoch {1} at step {2}."
                           .format(dataset.get_init_step(), init_epoch, init_step))
            dataset.set_init_step(init_step)

        if not hasattr(dataset, '__transfer_dataset__'):
            if hasattr(dataset, '__loop_size__'):
                self.sink_size = dataset.__loop_size__
            create_data_info_queue = (
                sink_size == 1 and self.sink_count == 1 and dataset.get_dataset_size() != 1)
            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
                                                           create_data_info_queue=create_data_info_queue)

            if not hasattr(dataset, '__no_send__'):
                _send_data(dataset, epoch_num)
        else:
            # if using an existing __transfer_dataset__, set the queue_name directly
            if not dataset.__transfer_dataset__.queue_name:
                _cell_graph_executor.set_queue_name(
                    dataset.__transfer_dataset__.queue_name)
            _send_data_no_flag(dataset, epoch_num)

        self.stop_send = dataset.__transfer_dataset__.stop_send
        self.release = dataset.__transfer_dataset__.release
        self.continue_send = dataset.__transfer_dataset__.continue_send
        self.get_data_info = dataset.__transfer_dataset__.get_data_info
        self.get_mbuf_queue_size = dataset.__transfer_dataset__.get_mbuf_queue_size
        self.get_send_info = dataset.__transfer_dataset__.get_send_info
        if hasattr(dataset.__transfer_dataset__, "_reset"):
            self._reset = dataset.__transfer_dataset__._reset  # pylint: disable=protected-access

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index >= self.sink_count:
            raise StopIteration()
        self.index += 1
        return self.op()

    def types_shapes(self):
        """
        Return the types and shapes of the dataset. The type and shape of each element in the dataset
        should be consistent.
        """
        return self.dataset_types, self.dataset_shapes

    def get_sink_count(self, dataset):
        sink_count = 1
        if hasattr(dataset, '__loop_size__'):
            loop_size = dataset.__loop_size__
            if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
                raise ValueError(f"Dataset size {dataset.get_dataset_size()} and 'sink_size' {loop_size} "
                                 f"do not match, the dataset size should be divisible by 'sink_size'.")
            sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
        return sink_count

    def get_sink_size(self):
        """get sink_size to device"""
        sink_size = 1
        if hasattr(self.dataset, '__loop_size__'):
            sink_size = self.dataset.__loop_size__
        else:
            if context.get_context("device_target") == "Ascend" or context.get_context("device_target") == "GPU":
                if self.sink_size > 0:
                    sink_size = self.sink_size
                else:
                    sink_size = self.dataset.get_dataset_size()
        return sink_size
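

# A worked example of the sink arithmetic above (illustrative values only): if
# dataset.get_dataset_size() is 100 and __loop_size__ is 20, get_sink_count() returns
# math.ceil(100 / 20) = 5, i.e. five sink iterations make up one pass over the data;
# a loop size of 30 would raise ValueError because 100 is not divisible by 30.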


class _DatasetIterPyNative(_DatasetIter):
    """Iter for context (mode=PYNATIVE_MODE)."""

    def __init__(self, dataset, sink_size, epoch_num):
        super().__init__(dataset, sink_size, epoch_num)
        if sink_size > 0:
            self.sink_count = sink_size
        else:
            self.sink_count = dataset.get_dataset_size()

        def op():
            return tuple()

        self.op = op


class _DatasetIterMSLoopSink(_DatasetIter):
    """Iter for context (device_target=Ascend or GPU)."""

    def __init__(self, dataset, sink_size, epoch_num):
        super().__init__(dataset, sink_size, epoch_num)
        self.sink_count = self.get_sink_count(dataset)
        # For self._parallel_mode equal to semi_auto_parallel or auto_parallel without full_batch,
        # compile with a complete tensor and run with a sliced tensor. The batch dimension of the tensors
        # used for compiling is device_number times the batch dimension of the tensors used for running.
        # Currently only LoopSink is supported.
        if _need_to_full():
            device_num = _get_device_num() // _get_pipeline_stages()
            self.dataset_shapes = _to_full_shapes(
                self.dataset_shapes, device_num)

        def op():
            return tuple()

        self.op = op


class _DatasetIterPSServer(_DatasetIter):
    """Iter for context on MS_PSERVER or MS_SCHED"""

    def __init__(self, dataset, sink_size, epoch_num):
        super().__init__(dataset, sink_size, epoch_num)
        self.sink_count = 1
        self.sink_size = 1
        self.op = None

        def op():
            return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)

        self.op = op


class _DatasetIterNormal:
    """Iter for normal (non-sink) mode, feed the data from host."""

    def __init__(self, dataset, epoch_num=-1):
        self.dataset = dataset
        self.device_num = _get_device_num()
        self.global_rank = _get_global_rank()
        self.iter = self.dataset.create_tuple_iterator(
            num_epochs=epoch_num, do_copy=True)

    def __iter__(self):
        return self

    def __next__(self):
        data = self.iter.__next__()
        return data


__all__ = ["DatasetHelper", "connect_network_with_dataset"]