# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adapter module that converts different input data objects into `tf.data.Dataset`s."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib
import functools
import itertools
import math
import random

import numpy as np
import six

from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import dataset_creator
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

keras_data_adapter_gauge = monitoring.BoolGauge(
    "/tensorflow/api/keras/data_adapters", "keras data adapter usage", "method")

try:
  from scipy import sparse as scipy_sparse  # pylint: disable=g-import-not-at-top
except ImportError:
  scipy_sparse = None
try:
  import pandas as pd  # pylint: disable=g-import-not-at-top
except ImportError:
  pd = None


@six.add_metaclass(abc.ABCMeta)
class DataAdapter(object):
  """Base class for input data adapters.

  In TF 2.0, tf.data is the preferred API for users to feed in data. In order
  to simplify the training code path, all input data objects will be converted
  to `tf.data.Dataset` if possible.

  Note that since this class is mainly targeted at TF 2.0, it makes a number of
  assumptions under the hood, e.g. an eager context by default, awareness of
  distribution strategies, etc. In the meantime, some legacy feature support
  might be dropped, e.g. the `Iterator` from the v1 dataset API.

  Sample usage of this class:

  ```
  x = tf.data.Dataset.range(100)
  adapter_cls = [NumpyArrayDataAdapter, ..., DatasetAdapter]
  applicable_adapters = [cls for cls in adapter_cls if cls.can_handle(x)]
  if len(applicable_adapters) != 1:
    raise ValueError("Expect only one adapter class to handle the input")

  dataset = applicable_adapters[0](x).get_dataset()
  for data in dataset:
    # training
  ```
  """

  @staticmethod
  def can_handle(x, y=None):
    """Whether the current DataAdapter could handle the input x and y.

    Structure-wise, x and y can be single objects, lists of objects if there
    are multiple inputs/outputs, or dictionaries of objects when the
    inputs/outputs are named.

    Args:
      x: input features.
      y: target labels. Note that y could be None in the case of prediction.

    Returns:
      boolean
    """
    raise NotImplementedError

  @abc.abstractmethod
  def __init__(self, x, y=None, **kwargs):
    """Create a DataAdapter based on data inputs.

    The caller must make sure to call `can_handle()` first before invoking this
    method. Providing an unsupported data type will result in unexpected
    behavior.

    Args:
      x: input features.
      y: target labels. Note that y could be None in the case of prediction.
      **kwargs: Other keyword arguments for DataAdapter during the construction
        of the tf.data.Dataset. For example:
        - Numpy data might have `sample_weights` which will be used for
          weighting the loss function during training.
        - Numpy data might need to have a `batch_size` parameter when
          constructing the dataset and iterator.
        - Certain inputs might need to be distribution-strategy aware. When
          `distribution_strategy` is passed, the created dataset needs to
          respect the strategy.
        DataAdapter might choose to ignore any keyword argument if it doesn't
        use it, or raise an exception if any required argument is not provided.
    """
    if not self.can_handle(x, y):
      raise ValueError("{} Cannot handle input {}, {}".format(
          self.__class__, x, y))

  @abc.abstractmethod
  def get_dataset(self):
    """Get a dataset instance for the current DataAdapter.

    Note that the dataset returned does not repeat for epochs, so the caller
    might need to create a new iterator for the same dataset at the beginning
    of each epoch. This behavior might change in the future.

    Returns:
      A `tf.data.Dataset`. The caller might use the dataset in different
      contexts, e.g. `iter(dataset)` in eager mode to get the values directly,
      or in graph mode provide the iterator tensor to the Keras model function.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def get_size(self):
    """Return the size (number of batches) for the dataset created.

    For certain types of input data the number of batches is known, e.g. for
    NumPy data the size is `ceil(number_of_elements / batch_size)`. Whereas for
    a dataset or Python generator the size is unknown, since it may or may not
    have an end state.

    Returns:
      int, the number of batches for the dataset, or None if it is unknown. The
      caller could use this to control the training loop, show a progress bar,
      or handle an unexpected StopIteration error.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def batch_size(self):
    """Return the batch size of the dataset created.

    For certain types of input data the batch size is known, and even
    required, e.g. for NumPy arrays. Whereas for a dataset the batch size is
    unknown unless we take a peek.

    Returns:
      int, the batch size of the dataset, or None if it is unknown.
    """
    raise NotImplementedError

  def representative_batch_size(self):
    """Return a representative size for batches in the dataset.

    This is not guaranteed to be the batch size for all batches in the
    dataset. It just needs to be a rough approximation for batch sizes in
    the dataset.

    Returns:
      int, a representative size for batches found in the dataset,
      or None if it is unknown.
    """
    return self.batch_size()

  @abc.abstractmethod
  def has_partial_batch(self):
    """Whether the dataset has a partial batch at the end."""
    raise NotImplementedError

  @abc.abstractmethod
  def partial_batch_size(self):
    """The size of the final partial batch for the dataset.

    Will return None if has_partial_batch is False or batch_size is None.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def should_recreate_iterator(self):
    """Returns whether a new iterator should be created every epoch."""
    raise NotImplementedError

  def get_samples(self):
    """Returns number of samples in the data, or `None`."""
    if not self.get_size() or not self.batch_size():
      return None
    total_sample = self.get_size() * self.batch_size()
    if self.has_partial_batch():
      total_sample -= (self.batch_size() - self.partial_batch_size())
    return total_sample

  def on_epoch_end(self):
    """A hook called after each epoch."""
    pass


class TensorLikeDataAdapter(DataAdapter):
  """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""

  @staticmethod
  def can_handle(x, y=None):
    # TODO(kaftan): Check performance implications of using a flatten
    # here for other types of inputs.
    flat_inputs = nest.flatten(x)
    if y is not None:
      flat_inputs += nest.flatten(y)

    tensor_types = (ops.Tensor, np.ndarray)
    if pd:
      tensor_types = (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame)

    def _is_tensor(v):
      if isinstance(v, tensor_types):
        return True
      return False

    return all(_is_tensor(v) for v in flat_inputs)

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               sample_weight_modes=None,
               batch_size=None,
               epochs=1,
               steps=None,
               shuffle=False,
               **kwargs):
    super(TensorLikeDataAdapter, self).__init__(x, y, **kwargs)
    x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
    sample_weight_modes = broadcast_sample_weight_modes(
        sample_weights, sample_weight_modes)

    # If sample_weights are not specified for an output use 1.0 as weights.
    (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
        y, sample_weights, sample_weight_modes, check_all_flat=True)

    inputs = pack_x_y_sample_weight(x, y, sample_weights)

    num_samples = set(int(i.shape[0]) for i in nest.flatten(inputs)).pop()
    _check_data_cardinality(inputs)

    # If batch_size is not passed but steps is, calculate from the input data.
    # Default to 32 for backwards compat.
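    # For example, 1000 samples with `steps=50` yields
    # batch_size = ceil(1000 / 50) = 20; with neither `batch_size` nor `steps`
    # given, the fallback below is 32.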
    if not batch_size:
      batch_size = int(math.ceil(num_samples / steps)) if steps else 32

    self._size = int(math.ceil(num_samples / batch_size))
    self._batch_size = batch_size

    num_full_batches = int(num_samples // batch_size)
    self._partial_batch_size = num_samples % batch_size

    if isinstance(shuffle, str):
      shuffle = shuffle.lower()

    self._shuffle = shuffle
    # Vectorized version of shuffle.
    # This is a performance improvement over using `from_tensor_slices`.
    # The indices of the data are shuffled and batched, and these indices
    # are then zipped with the data and used to extract a batch of the data
    # at each step. The performance improvements here come from:
    # 1. vectorized batch using gather
    # 2. parallelized map
    # 3. pipelined permutation generation
    # 4. optimized permutation batching
    # 5. disabled static optimizations

    indices_dataset = dataset_ops.DatasetV2.range(1)
    if shuffle != "batch":
      indices_dataset = indices_dataset.repeat(epochs)

    def permutation(_):
      # It turns out to be more performant to make a new set of indices rather
      # than reusing the same range Tensor. (presumably because of buffer
      # forwarding.)
      indices = math_ops.range(num_samples, dtype=dtypes.int64)
      if shuffle and shuffle != "batch":
        indices = random_ops.random_shuffle(indices)
      return indices

    # We prefetch a single element. Computing large permutations can take quite
    # a while so we don't want to wait for prefetching over an epoch boundary to
    # trigger the next permutation. On the other hand, too many simultaneous
    # shuffles can contend on a hardware level and degrade all performance.
    indices_dataset = indices_dataset.map(permutation).prefetch(1)

    def slice_batch_indices(indices):
      """Convert a Tensor of indices into a dataset of batched indices.

      This step can be accomplished in several ways. The most natural is to
      slice the Tensor in a Dataset map. (With a condition on the upper index
      to handle the partial batch.) However it turns out that coercing the
      Tensor into a shape which is divisible by the batch size (and handling
      the last partial batch separately) allows for a much more favorable
      memory access pattern and improved performance.

      Args:
        indices: Tensor which determines the data order for an entire epoch.

      Returns:
        A Dataset of batched indices.
      """
      num_in_full_batch = num_full_batches * batch_size
      first_k_indices = array_ops.slice(indices, [0], [num_in_full_batch])
      first_k_indices = array_ops.reshape(
          first_k_indices, [num_full_batches, batch_size])

      flat_dataset = dataset_ops.DatasetV2.from_tensor_slices(first_k_indices)
      if self._partial_batch_size:
        index_remainder = dataset_ops.DatasetV2.from_tensors(array_ops.slice(
            indices, [num_in_full_batch], [self._partial_batch_size]))
        flat_dataset = flat_dataset.concatenate(index_remainder)

      if shuffle == "batch":
        # 1024 is a magic constant that has not been properly evaluated
        flat_dataset = flat_dataset.shuffle(1024).repeat(epochs)
      return flat_dataset

    indices_dataset = indices_dataset.flat_map(slice_batch_indices)

    dataset = self.slice_inputs(indices_dataset, inputs)

    if shuffle == "batch":
      def shuffle_batch(*batch):
        return nest.map_structure(random_ops.random_shuffle, batch)
      dataset = dataset.map(shuffle_batch)

    self._dataset = dataset

  def slice_inputs(self, indices_dataset, inputs):
    """Slice inputs into a Dataset of batches.

    Given a Dataset of batch indices and the unsliced inputs,
    this step slices the inputs in a parallelized fashion
    and produces a dataset of input batches.

    Args:
      indices_dataset: A Dataset of batched indices
      inputs: A python data structure that contains the inputs, targets,
        and possibly sample weights.

    Returns:
      A Dataset of input batches matching the batch indices.
    """
    dataset = dataset_ops.DatasetV2.zip((
        indices_dataset,
        dataset_ops.DatasetV2.from_tensors(inputs).repeat()
    ))

    def grab_batch(i, data):
      return nest.map_structure(lambda d: array_ops.gather(d, i, axis=0), data)

    dataset = dataset.map(
        grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)

    # Default optimizations are disabled to avoid the overhead of (unnecessary)
    # input pipeline graph serialization and deserialization
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    if self._shuffle:
      # See b/141490660 for more details.
      options.experimental_external_state_policy = (
          distribute_options.ExternalStatePolicy.IGNORE)
    dataset = dataset.with_options(options)
    return dataset

  def get_dataset(self):
    return self._dataset

  def get_size(self):
    return self._size

  def batch_size(self):
    return self._batch_size

  def has_partial_batch(self):
    return self._partial_batch_size > 0

  def partial_batch_size(self):
    return self._partial_batch_size or None

  def should_recreate_iterator(self):
    # An infinite dataset is always created here.
    return False


class GenericArrayLikeDataAdapter(TensorLikeDataAdapter):
  """Adapter that handles array-like data without forcing it into memory.

  This adapter handles array-like datasets that may be too big to fully
  fit into memory.

  Specifically, this adapter handles any Python class which implements
  `__getitem__`, `__len__`, `shape`, and `dtype` with the same meanings
  as Numpy, but it ignores any case where all the inputs are Tensors or Numpy
  arrays (because that case is handled by the base TensorLikeDataAdapter).

  It ignores scipy sparse matrices and Composite Tensors because those are
  handled by the CompositeTensorDataAdapter.

  It also does not handle lists/tuples of scalars, because those are handled
  by the ListsOfScalarsDataAdapter.
  """

  @staticmethod
  def can_handle(x, y=None):
    flat_inputs = nest.flatten(x)
    if y is not None:
      flat_inputs += nest.flatten(y)

    def _is_array_like(v):
      """Return True if v is a Tensor, array, or is array-like."""
      return (
          hasattr(v, "__getitem__") and
          hasattr(v, "shape") and
          hasattr(v, "dtype") and
          hasattr(v, "__len__")
      )

    if (not TensorLikeDataAdapter.can_handle(x, y) and
        not CompositeTensorDataAdapter.can_handle(x, y)):
      return all(_is_array_like(v) for v in flat_inputs)
    else:
      return False

  def __init__(self, *args, **kwargs):
    logging.warn(
        "Keras is training/fitting/evaluating on array-like data. Keras may "
        "not be optimized for this format, so if your input data format is "
        "supported by TensorFlow I/O (https://github.com/tensorflow/io) we "
        "recommend using that to load a Dataset instead.")

    super(GenericArrayLikeDataAdapter, self).__init__(*args, **kwargs)

  def slice_inputs(self, indices_dataset, inputs):
    """Slice inputs into a Dataset of batches.

    Given a Dataset of batch indices and the unsliced inputs,
    this step slices the inputs in a parallelized fashion
    and produces a dataset of input batches.

    Args:
      indices_dataset: A Dataset of batched indices
      inputs: A python data structure that contains the inputs, targets,
        and possibly sample weights.

    Returns:
      A Dataset of input batches matching the batch indices.
    """
    flat_inputs = nest.flatten(inputs)
    def dynamic_shape_like(t):
      shape = list(t.shape)
      shape[0] = None
      return tuple(shape)

    flat_dtypes = [inp.dtype for inp in flat_inputs]
    contiguous = True
    if self._shuffle and self._shuffle != "batch":
      contiguous = False

    def grab_batch(indices):
      """Grab a batch of data from the inputs."""
      # This uses a py_function to avoid converting the array-like
      # into a Tensor before slicing it, because converting the array-like
      # to a Tensor may force it into memory.
      def py_method(ind):
        def slice_array(data):
          return training_utils.slice_arrays(data, ind.numpy(),
                                             contiguous=contiguous)
        return [slice_array(inp) for inp in flat_inputs]

      flat_out = script_ops.eager_py_func(py_method, [indices], flat_dtypes)
      for v, original_inp in zip(flat_out, flat_inputs):
        v.set_shape(dynamic_shape_like(original_inp))
      return nest.pack_sequence_as(inputs, flat_out)

    dataset = indices_dataset.map(
        grab_batch, num_parallel_calls=dataset_ops.AUTOTUNE)

    return dataset


class DatasetCreatorAdapter(DataAdapter):
  """Adapter that handles `dataset_creator.DatasetCreator` input."""

  def __init__(self, *args, **kwargs):
    super(DatasetCreatorAdapter, self).__init__(*args, **kwargs)

  @staticmethod
  def can_handle(x, y=None):
    if isinstance(x, dataset_creator.DatasetCreator):
      assert y is None
      return True

  def should_recreate_iterator(self):
    # We expect users to shuffle the dataset in their `dataset_fn` supplied to
    # `DatasetCreator`. Since that is a buffered shuffle, we intend to not reset
    # the dataset so the batches that are not shuffled can still be pulled.
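    # Returning False here means `DataHandler` will reuse the same iterator
    # across epochs instead of re-creating it (see
    # `DataHandler.enumerate_epochs`).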
    return False

  def get_size(self):
    raise NotImplementedError()

  def get_dataset(self):
    raise NotImplementedError()

  def batch_size(self):
    raise NotImplementedError()

  def has_partial_batch(self):
    raise NotImplementedError()

  def partial_batch_size(self):
    raise NotImplementedError()


class CompositeTensorDataAdapter(DataAdapter):
  """Adapter that handles composite tensors."""

  @staticmethod
  def can_handle(x, y=None):
    flat_inputs = nest.flatten(x)
    if y is not None:
      flat_inputs += nest.flatten(y)

    def _is_composite(v):
      # Datasets and iterators inherit from CompositeTensor but should be
      # handled by DatasetAdapter and GeneratorAdapter.
      if (tf_utils.is_extension_type(v) and
          not isinstance(v, (dataset_ops.DatasetV2,
                             iterator_ops.IteratorBase))):
        return True
      # Support SciPy sparse matrices if SciPy is installed.
      if scipy_sparse is not None and scipy_sparse.issparse(v):
        return True
      return False

    def _is_tensor_or_composite(v):
      if isinstance(v, (ops.Tensor, np.ndarray)):
        return True
      return _is_composite(v)

    return (any(_is_composite(v) for v in flat_inputs) and
            all(_is_tensor_or_composite(v) for v in flat_inputs))

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               sample_weight_modes=None,
               batch_size=None,
               steps=None,
               shuffle=False,
               **kwargs):
    super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs)
    x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
    sample_weight_modes = broadcast_sample_weight_modes(
        sample_weights, sample_weight_modes)

    # If sample_weights are not specified for an output use 1.0 as weights.
    (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
        y, sample_weights, sample_weight_modes, check_all_flat=True)

    inputs = pack_x_y_sample_weight(x, y, sample_weights)

    dataset = dataset_ops.DatasetV2.from_tensor_slices(inputs)
    num_samples = int(nest.flatten(x)[0].shape[0])
    if shuffle:
      dataset = dataset.shuffle(num_samples)

    # If batch_size is not passed but steps is, calculate from the input data.
    # Default to 32 for backwards compat.
    if not batch_size:
      batch_size = int(math.ceil(num_samples / steps)) if steps else 32

    dataset = dataset.batch(batch_size)
    self._size = int(math.ceil(num_samples / batch_size))
    self._batch_size = batch_size
    self._has_partial_batch = (self._size != (num_samples // batch_size))

    self._partial_batch_size = None
    if self._has_partial_batch:
      self._partial_batch_size = (
          num_samples - (self._size - 1) * self._batch_size)

    self._dataset = dataset

  def get_dataset(self):
    return self._dataset

  def get_size(self):
    return self._size

  def batch_size(self):
    return self._batch_size

  def has_partial_batch(self):
    return self._has_partial_batch

  def partial_batch_size(self):
    return self._partial_batch_size

  def should_recreate_iterator(self):
    return True


class ListsOfScalarsDataAdapter(DataAdapter):
  """Adapter that handles lists of scalars and lists of lists of scalars."""

  @staticmethod
  def can_handle(x, y=None):
    handles_x = ListsOfScalarsDataAdapter._is_list_of_scalars(x)
    handles_y = True
    if y is not None:
      handles_y = ListsOfScalarsDataAdapter._is_list_of_scalars(y)
    return handles_x and handles_y

  @staticmethod
  def _is_list_of_scalars(inp):
    if isinstance(inp, (float, int, str, bytes, bytearray)):
      return True
    if isinstance(inp, (list, tuple)) and inp:
      return ListsOfScalarsDataAdapter._is_list_of_scalars(inp[0])
    return False

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               sample_weight_modes=None,
               batch_size=None,
               shuffle=False,
               **kwargs):
    super(ListsOfScalarsDataAdapter, self).__init__(x, y, **kwargs)
    x = np.asarray(x)
    if y is not None:
      y = np.asarray(y)
    if sample_weights is not None:
      sample_weights = np.asarray(sample_weights)
    sample_weight_modes = broadcast_sample_weight_modes(
        sample_weights, sample_weight_modes)

    self._internal_adapter = TensorLikeDataAdapter(
        x,
        y=y,
        sample_weights=sample_weights,
        sample_weight_modes=sample_weight_modes,
        batch_size=batch_size,
        shuffle=shuffle,
        **kwargs)

  def get_dataset(self):
    return self._internal_adapter.get_dataset()

  def get_size(self):
    return self._internal_adapter.get_size()

  def batch_size(self):
    return self._internal_adapter.batch_size()

  def has_partial_batch(self):
    return self._internal_adapter.has_partial_batch()

  def partial_batch_size(self):
    return self._internal_adapter.partial_batch_size()

  def should_recreate_iterator(self):
    return True


class DatasetAdapter(DataAdapter):
  """Adapter that handles `tf.data.Dataset`."""

  @staticmethod
  def can_handle(x, y=None):
    return (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) or
            _is_distributed_dataset(x))

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               steps=None,
               **kwargs):
    super(DatasetAdapter, self).__init__(x, y, **kwargs)
    # Note that the dataset instance is immutable, so it's fine to reuse the
    # user-provided dataset.
    self._dataset = x

    # The user-provided steps.
    self._user_steps = steps

    self._validate_args(y, sample_weights, steps)

  def get_dataset(self):
    return self._dataset

  def get_size(self):
    return  # Inferred in `DataHandler`.

  def batch_size(self):
    return None

  def has_partial_batch(self):
    return False

  def partial_batch_size(self):
    return None

  def should_recreate_iterator(self):
    # Since DistributedDatasets have no cardinality, the user must provide
    # all steps that need to be run, calling `.repeat()` as needed.
    if _is_distributed_dataset(self._dataset):
      return False

    # If user doesn't supply `steps`, or if they supply `steps` that
    # exactly equals the size of the `Dataset`, create a new iterator
    # each epoch.
    return (self._user_steps is None or
            cardinality.cardinality(self._dataset).numpy() == self._user_steps)

  def _validate_args(self, y, sample_weights, steps):
    """Validates `__init__` arguments."""
    # Arguments that shouldn't be passed.
    if not is_none_or_empty(y):
      raise ValueError("`y` argument is not supported when using "
                       "dataset as input.")
    if not is_none_or_empty(sample_weights):
      raise ValueError("`sample_weight` argument is not supported when using "
                       "dataset as input.")

    if steps is None:
      if _is_distributed_dataset(self._dataset):
        raise ValueError("When providing a distributed dataset, you must "
                         "specify the number of steps to run.")

      size = cardinality.cardinality(self._dataset).numpy()
      if size == cardinality.INFINITE and steps is None:
        raise ValueError(
            "When providing an infinite dataset, you must specify "
            "the number of steps to run (if you did not intend to "
            "create an infinite dataset, make sure to not call "
            "`repeat()` on the dataset).")


class GeneratorDataAdapter(DataAdapter):
  """Adapter that handles python generators and iterators."""

  @staticmethod
  def can_handle(x, y=None):
    return ((hasattr(x, "__next__") or hasattr(x, "next"))
            and hasattr(x, "__iter__")
            and not isinstance(x, data_utils.Sequence))

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               workers=1,
               use_multiprocessing=False,
               max_queue_size=10,
               model=None,
               **kwargs):
    # Generators should never shuffle as exhausting the generator in order to
    # shuffle the batches is inefficient.
    kwargs.pop("shuffle", None)

    if not is_none_or_empty(y):
      raise ValueError("`y` argument is not supported when using "
                       "python generator as input.")
    if not is_none_or_empty(sample_weights):
      raise ValueError("`sample_weight` argument is not supported when using "
                       "python generator as input.")

    super(GeneratorDataAdapter, self).__init__(x, y, **kwargs)

    # Since we have to know the dtype of the python generator when we build the
    # dataset, we have to look at a batch to infer the structure.
    peek, x = self._peek_and_restore(x)
    peek = self._standardize_batch(peek)
    peek = _process_tensorlike(peek)

    # Need to build the Model on concrete input shapes.
    if model is not None and not model.built:
      concrete_x, _, _ = unpack_x_y_sample_weight(peek)
      model.distribute_strategy.run(
          lambda x: model(x, training=False), args=(concrete_x,))

    self._first_batch_size = int(nest.flatten(peek)[0].shape[0])

    def _get_dynamic_shape(t):
      shape = t.shape
      # Unknown number of dimensions, `as_list` cannot be called.
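      # (When the rank is known, every dimension is mapped to `None` below,
      # e.g. TensorShape([32, 10]) -> TensorShape([None, None]), so batches
      # yielded by the generator may vary in size and shape.)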
      if shape.rank is None:
        return shape
      return tensor_shape.TensorShape([None for _ in shape.as_list()])

    output_shapes = nest.map_structure(_get_dynamic_shape, peek)
    output_types = nest.map_structure(lambda t: t.dtype, peek)

    # Note that the dataset API takes a callable that creates a generator
    # object, rather than the generator itself, which is why we define a
    # function here.
    generator_fn = self._handle_multiprocessing(x, workers, use_multiprocessing,
                                                max_queue_size)

    def wrapped_generator():
      for data in generator_fn():
        yield self._standardize_batch(data)

    dataset = dataset_ops.DatasetV2.from_generator(
        wrapped_generator, output_types, output_shapes=output_shapes)

    if workers == 1 and not use_multiprocessing:
      dataset = dataset.prefetch(1)

    self._dataset = dataset

  def _standardize_batch(self, data):
    """Standardizes a batch output by a generator."""
    # Removes `None`s.
    x, y, sample_weight = unpack_x_y_sample_weight(data)
    data = pack_x_y_sample_weight(x, y, sample_weight)

    data = nest.list_to_tuple(data)

    def _convert_dtype(t):
      if (isinstance(t, np.ndarray) and issubclass(t.dtype.type, np.floating)):
        return np.array(t, dtype=backend.floatx())
      return t

    data = nest.map_structure(_convert_dtype, data)
    return data

  @staticmethod
  def _peek_and_restore(x):
    peek = next(x)
    return peek, itertools.chain([peek], x)

  def _handle_multiprocessing(self, x, workers, use_multiprocessing,
                              max_queue_size):
    """Create a callable, possibly including an Enqueuer."""
    if workers > 1 or (workers > 0 and use_multiprocessing):
      def generator_fn():
        enqueuer = data_utils.GeneratorEnqueuer(
            x, use_multiprocessing=use_multiprocessing)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        return enqueuer.get()
    else:
      generator_fn = lambda: x
    return generator_fn

  def get_dataset(self):
    return self._dataset

  def get_size(self):
    return None

  def batch_size(self):
    return None

  def representative_batch_size(self):
    return self._first_batch_size

  def has_partial_batch(self):
    return False

  def partial_batch_size(self):
    return

  def should_recreate_iterator(self):
    return False


class KerasSequenceAdapter(GeneratorDataAdapter):
  """Adapter that handles `keras.utils.Sequence`."""

  @staticmethod
  def can_handle(x, y=None):
    return isinstance(x, data_utils.Sequence)

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               shuffle=False,
               workers=1,
               use_multiprocessing=False,
               max_queue_size=10,
               model=None,
               **kwargs):
    if not is_none_or_empty(y):
      raise ValueError("`y` argument is not supported when using "
                       "`keras.utils.Sequence` as input.")
    if not is_none_or_empty(sample_weights):
      raise ValueError("`sample_weight` argument is not supported when using "
                       "`keras.utils.Sequence` as input.")

    self._size = len(x)
    self._shuffle_sequence = shuffle
    self._keras_sequence = x
    self._enqueuer = None
    super(KerasSequenceAdapter, self).__init__(
        x,
        shuffle=False,  # Shuffling is handled in `_handle_multiprocessing`.
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
        model=model,
        **kwargs)

  @staticmethod
  def _peek_and_restore(x):
    return x[0], x

  def _handle_multiprocessing(self, x, workers, use_multiprocessing,
                              max_queue_size):
    if workers > 1 or (workers > 0 and use_multiprocessing):
      def generator_fn():
        self._enqueuer = data_utils.OrderedEnqueuer(
            x, use_multiprocessing=use_multiprocessing,
            shuffle=self._shuffle_sequence)
        self._enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        return self._enqueuer.get()
    else:
      def generator_fn():
        order = range(len(x))
        if self._shuffle_sequence:
          # Match the shuffle convention in OrderedEnqueuer.
          order = list(order)
          random.shuffle(order)

        for i in order:
          yield x[i]

    return generator_fn

  def get_size(self):
    return self._size

  def should_recreate_iterator(self):
    return True

  def on_epoch_end(self):
    if self._enqueuer:
      self._enqueuer.stop()
    self._keras_sequence.on_epoch_end()


ALL_ADAPTER_CLS = [
    ListsOfScalarsDataAdapter, TensorLikeDataAdapter,
    GenericArrayLikeDataAdapter, DatasetAdapter, GeneratorDataAdapter,
    KerasSequenceAdapter, CompositeTensorDataAdapter, DatasetCreatorAdapter
]


def select_data_adapter(x, y):
  """Selects a data adapter that can handle a given x and y."""
  adapter_cls = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)]
  if not adapter_cls:
    # TODO(scottzhu): This should be a less implementation-specific error.
    raise ValueError(
        "Failed to find data adapter that can handle "
        "input: {}, {}".format(
            _type_name(x), _type_name(y)))
  elif len(adapter_cls) > 1:
    raise RuntimeError(
        "Data adapters should be mutually exclusive for "
        "handling inputs. Found multiple adapters {} to handle "
        "input: {}, {}".format(
            adapter_cls, _type_name(x), _type_name(y)))
  # Instrument the data adapter usage before returning it.
  keras_data_adapter_gauge.get_cell(adapter_cls[0].__name__).set(True)
  return adapter_cls[0]


def _type_name(x):
  """Generates a description of the type of an object."""
  if isinstance(x, dict):
    key_types = set(_type_name(key) for key in x.keys())
    val_types = set(_type_name(key) for key in x.values())
    return "({} containing {} keys and {} values)".format(
        type(x), key_types, val_types)
  if isinstance(x, (list, tuple)):
    types = set(_type_name(val) for val in x)
    return "({} containing values of types {})".format(
        type(x), types)
  return str(type(x))


def _process_tensorlike(inputs):
  """Process tensor-like inputs.

  This function:

  (1) Converts `Numpy` arrays to `Tensor`s.
  (2) Converts `Scipy` sparse matrices to `SparseTensor`s.
  (3) Converts `list`s to `tuple`s (for `tf.data` support).

  Args:
    inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.

  Returns:
    Structure of `Tensor`s or tensor-like.
  """

  def _convert_numpy_and_scipy(x):
    if isinstance(x, np.ndarray):
      dtype = None
      if issubclass(x.dtype.type, np.floating):
        dtype = backend.floatx()
      return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype)
    elif scipy_sparse and scipy_sparse.issparse(x):
      return _scipy_sparse_to_sparse_tensor(x)
    return x

  inputs = nest.map_structure(_convert_numpy_and_scipy, inputs)
  return nest.list_to_tuple(inputs)


def is_none_or_empty(inputs):
  # Utility to check whether the input is None or an empty list. A plain
  # Python "not" check would raise an error like the following if the input
  # is a NumPy array:
  # "The truth value of an array with more than one element is ambiguous.
  # Use a.any() or a.all()"
  return inputs is None or not nest.flatten(inputs)


def broadcast_sample_weight_modes(target_structure, sample_weight_modes):
  """Match sample_weight_modes structure with output structure."""
  if target_structure is None or not nest.flatten(target_structure):
    return sample_weight_modes

  if isinstance(sample_weight_modes, str):
    if isinstance(target_structure, dict):
      return {key: sample_weight_modes for key in target_structure.keys()}
    return [sample_weight_modes for _ in target_structure]

  if sample_weight_modes:
    try:
      nest.assert_same_structure(
          training_utils.list_to_tuple(target_structure),
          training_utils.list_to_tuple(sample_weight_modes))
    except (ValueError, TypeError):
      target_str = str(nest.map_structure(lambda _: "...", target_structure))
      mode_str = str(nest.map_structure(lambda _: "...", sample_weight_modes))

      # Attempt to coerce sample_weight_modes to the target structure. This
      # implicitly depends on the fact that Model flattens outputs for its
      # internal representation.
      try:
        sample_weight_modes = nest.pack_sequence_as(
            target_structure, nest.flatten(sample_weight_modes))
        logging.warning(
            "sample_weight modes were coerced from\n  {}\n    to  \n  {}"
            .format(target_str, mode_str))
      except (ValueError, TypeError):
        raise ValueError(
            "Unable to match target structure and sample_weight_modes "
            "structure:\n  {}\n    to  \n  {}".format(target_str, mode_str))

  return sample_weight_modes


class DataHandler(object):
  """Handles iterating over epoch-level `tf.data.Iterator` objects."""

  def __init__(self,
               x,
               y=None,
               sample_weight=None,
               batch_size=None,
               steps_per_epoch=None,
               initial_epoch=0,
               epochs=1,
               shuffle=False,
               class_weight=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False,
               model=None,
               steps_per_execution=None,
               distribute=True):
    """Initializes a `DataHandler`.

    Arguments:
      x: See `Model.fit`.
      y: See `Model.fit`.
      sample_weight: See `Model.fit`.
      batch_size: See `Model.fit`.
      steps_per_epoch: See `Model.fit`.
      initial_epoch: See `Model.fit`.
      epochs: See `Model.fit`.
      shuffle: See `Model.fit`.
      class_weight: See `Model.fit`.
      max_queue_size: See `Model.fit`.
      workers: See `Model.fit`.
      use_multiprocessing: See `Model.fit`.
      model: The `Model` instance. Needed in order to correctly `build` the
        `Model` using generator-like inputs (see `GeneratorDataAdapter`).
      steps_per_execution: See `Model.compile`.
      distribute: Whether to distribute the `tf.data.Dataset`.
        `PreprocessingLayer.adapt` does not support distributed datasets, so it
        passes `distribute=False`; `Model` should always set this to `True`.
    """

    self._initial_epoch = initial_epoch
    self._epochs = epochs
    self._insufficient_data = False
    self._model = model

    # `steps_per_execution_value` is the cached initial value.
    # `steps_per_execution` is mutable and may be changed by the DataHandler
    # to handle partial executions.
    if steps_per_execution is None:
      self._steps_per_execution = 1
      self._steps_per_execution_value = 1
    else:
      self._steps_per_execution = steps_per_execution
      self._steps_per_execution_value = steps_per_execution.numpy().item()

    adapter_cls = select_data_adapter(x, y)
    self._verify_data_adapter_compatibility(adapter_cls)
    self._adapter = adapter_cls(
        x,
        y,
        batch_size=batch_size,
        steps=steps_per_epoch,
        epochs=epochs - initial_epoch,
        sample_weights=sample_weight,
        shuffle=shuffle,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        distribution_strategy=ds_context.get_strategy(),
        model=model)

    strategy = ds_context.get_strategy()

    self._current_step = 0
    self._step_increment = self._steps_per_execution_value - 1
    self._insufficient_data = False

    self._configure_dataset_and_inferred_steps(strategy, x, steps_per_epoch,
                                               class_weight, distribute)

  def _verify_data_adapter_compatibility(self, adapter_cls):
    if adapter_cls == DatasetCreatorAdapter:
      raise NotImplementedError("`DatasetCreator` input is only supported in "
                                "`ParameterServerStrategy` at this time.")

  def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
                                            class_weight, distribute):
    """Configure the `_dataset` and `_inferred_steps` attributes."""
    del x
    dataset = self._adapter.get_dataset()
    if class_weight:
      dataset = dataset.map(_make_class_weight_map_fn(class_weight))
    self._inferred_steps = self._infer_steps(steps_per_epoch, dataset)

    # `PreprocessingLayer.adapt` does not currently support distributed
    # datasets, so we pass `distribute=False` there.
    if distribute and not _is_distributed_dataset(dataset):
      dataset = strategy.experimental_distribute_dataset(dataset)
    self._dataset = dataset
    self._validate_data_handler()

  def enumerate_epochs(self):
    """Yields `(epoch, tf.data.Iterator)`."""
    with self._truncate_execution_to_epoch():
      data_iterator = iter(self._dataset)
      for epoch in range(self._initial_epoch, self._epochs):
        if self._insufficient_data:  # Set by `catch_stop_iteration`.
          break
        if self._adapter.should_recreate_iterator():
          data_iterator = iter(self._dataset)
        yield epoch, data_iterator
        self._adapter.on_epoch_end()

  @contextlib.contextmanager
  def _truncate_execution_to_epoch(self):
    """Truncates steps per execution to at most one epoch."""
    should_truncate = (
        self._inferred_steps is not None and
        self._steps_per_execution_value > self._inferred_steps)
    original_value = self._steps_per_execution_value
    try:
      if should_truncate:
        self._steps_per_execution.assign(self._inferred_steps)
        self._steps_per_execution_value = self._inferred_steps
      yield
    finally:
      if should_truncate:
        self._steps_per_execution.assign(original_value)
        self._steps_per_execution_value = original_value

  def sync(self):
    context.async_wait()

  @contextlib.contextmanager
  def catch_stop_iteration(self):
    """Catches errors when an iterator runs out of data."""
    try:
      yield
      self.sync()
    except (StopIteration, errors.OutOfRangeError):
      if self._inferred_steps is None:
        self._inferred_steps = self._current_step
      else:
        self._insufficient_data = True
        total_epochs = self._epochs - self._initial_epoch
        logging.warning(
            "Your input ran out of data; interrupting training. "
            "Make sure that your dataset or generator can generate at "
            "least `steps_per_epoch * epochs` batches (in this case, "
            "{} batches). You may need to use the repeat() function "
            "when building your dataset.".format(total_epochs *
                                                 self._inferred_steps))

  def steps(self):
    """Yields steps for the current epoch."""
    self._current_step = 0
    # `self._inferred_steps` can be changed by `catch_stop_iteration`.
    while (self._inferred_steps is None or
           self._current_step < self._inferred_steps):
      if self._insufficient_data:  # Set by `catch_stop_iteration`.
        break

      can_run_full_execution = (
          self._steps_per_execution_value == 1 or
          self._inferred_steps is None or
          self._inferred_steps - self._current_step >=
          self._steps_per_execution_value)

      if can_run_full_execution:
        self._step_increment = self._steps_per_execution_value - 1
        yield self._current_step
        self._current_step += self._steps_per_execution_value
      else:
        # Last partial execution.
        steps_remaining = self._inferred_steps - self._current_step
        self._steps_per_execution.assign(steps_remaining)
        self._step_increment = steps_remaining - 1
        yield self._current_step
        self._current_step += steps_remaining
        self._steps_per_execution.assign(self._steps_per_execution_value)

  @property
  def step_increment(self):
    """The number to increment the step for `on_batch_end` methods."""
    return self._step_increment

  @property
  def inferred_steps(self):
    """The inferred steps per epoch of the created `Dataset`.

    This will be `None` in the case where:

    (1) A `Dataset` of unknown cardinality was passed to the `DataHandler`, and
    (2) `steps_per_epoch` was not provided, and
    (3) The first epoch of iteration has not yet completed.

    Returns:
      The inferred steps per epoch of the created `Dataset`.
    """
    return self._inferred_steps

  @property
  def should_sync(self):
    # Catch OutOfRangeError for Datasets of unknown size.
    # This blocks until the batch has finished executing.
    # TODO(b/150292341): Allow multiple async steps here.
    return self._inferred_steps is None

  def _infer_steps(self, steps, dataset):
    """Infers steps_per_epoch needed to loop through a dataset."""
    if steps is not None:
      return steps

    adapter_steps = self._adapter.get_size()
    if adapter_steps is not None:
      return adapter_steps

    size = cardinality.cardinality(dataset)
    if size == cardinality.INFINITE and steps is None:
      raise ValueError("When passing an infinitely repeating dataset, you "
                       "must specify how many steps to draw.")
    if size >= 0:
      return size.numpy().item()
    return None

  @property
  def _samples(self):
    return self._adapter.get_samples()

  def _validate_data_handler(self):
    # TODO(b/152094471): Support this with DistIter.get_next_as_optional.
    if self._steps_per_execution_value > 1 and self._inferred_steps is None:
      raise ValueError(
          "Could not infer the size of the data. With "
          "`steps_per_execution > 1`, you must specify the number of steps "
          "to run.")

  def resolve_logs(self, logs):
    return logs


class _ClusterCoordinatorDataHandler(DataHandler):
  """A `DataHandler` that is compatible with `ClusterCoordinator`."""

  def _verify_data_adapter_compatibility(self, adapter_cls):
    if adapter_cls != DatasetCreatorAdapter:
      raise NotImplementedError("Only `DatasetCreator` input is supported in "
                                "`ParameterServerStrategy` at this time.")

  def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
                                            class_weight, distribute):
    if not isinstance(x, dataset_creator.DatasetCreator):
      raise TypeError("When using `ParameterServerStrategy`, `x` must be a "
                      "`DatasetCreator`.")

    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(x)

    self._dataset = self._model._cluster_coordinator.create_per_worker_dataset(  # pylint: disable=protected-access
        per_worker_dataset_fn)
    if steps_per_epoch is None:
      raise ValueError(
          "`steps_per_epoch` must be specified with `ParameterServerStrategy`.")
    self._inferred_steps = steps_per_epoch

  def sync(self):
    self._model._cluster_coordinator.join()  # pylint: disable=protected-access

  def resolve_logs(self, logs):
    return logs.fetch()


def get_data_handler(*args, **kwargs):
  if getattr(kwargs["model"], "_cluster_coordinator", None):
    return _ClusterCoordinatorDataHandler(*args, **kwargs)
  return DataHandler(*args, **kwargs)


def _make_class_weight_map_fn(class_weight):
  """Applies class weighting to a `Dataset`.

  The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
  `y` must be a single `Tensor`.

  Args:
    class_weight: A map where the keys are integer class ids and values are
      the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`

  Returns:
    A function that can be used with `tf.data.Dataset.map` to apply class
    weighting.
  """
  class_ids = list(sorted(class_weight.keys()))
  expected_class_ids = list(range(len(class_ids)))
  if class_ids != expected_class_ids:
    error_msg = (
        "Expected `class_weight` to be a dict with keys from 0 to one less "
        "than the number of classes, found {}").format(class_weight)
    raise ValueError(error_msg)

  class_weight_tensor = ops.convert_to_tensor_v2_with_dispatch(
      [class_weight[int(c)] for c in class_ids])

  def _class_weights_map_fn(*data):
    """Convert `class_weight` to `sample_weight`."""
    x, y, sw = unpack_x_y_sample_weight(data)

    if nest.is_nested(y):
      raise ValueError(
          "`class_weight` is only supported for Models with a single output.")

    if y.shape.rank > 2:
      raise ValueError("`class_weight` not supported for "
                       "3+ dimensional targets.")

    y_classes = smart_cond.smart_cond(
        y.shape.rank == 2 and backend.shape(y)[1] > 1,
        lambda: backend.argmax(y, axis=1),
        lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64))

    cw = array_ops.gather_v2(class_weight_tensor, y_classes)
    if sw is not None:
      cw = math_ops.cast(cw, sw.dtype)
      sw, cw = expand_1d((sw, cw))
      # `class_weight` and `sample_weight` are multiplicative.
      sw = sw * cw
    else:
      sw = cw

    return x, y, sw

  return _class_weights_map_fn


def expand_1d(data):
  """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s."""

  def _expand_single_1d_tensor(t):
    # Leaves `CompositeTensor`s as-is.
    if (isinstance(t, ops.Tensor) and
        isinstance(t.shape, tensor_shape.TensorShape) and t.shape.rank == 1):
      return array_ops.expand_dims_v2(t, axis=-1)
    return t

  return nest.map_structure(_expand_single_1d_tensor, data)


def train_validation_split(arrays, validation_split):
  """Split arrays into train and validation subsets in deterministic order.

  The last part of data will become validation data.

  Args:
    arrays: Tensors to split. Allowed inputs are arbitrarily nested structures
      of Tensors and NumPy arrays.
    validation_split: Float between 0 and 1. The proportion of the dataset to
      include in the validation split. The rest of the dataset will be included
      in the training split.
  Returns:
    `(train_arrays, validation_arrays)`
  """

  def _can_split(t):
    tensor_types = (ops.Tensor, np.ndarray)
    if pd:
      tensor_types = (ops.Tensor, np.ndarray, pd.Series, pd.DataFrame)
    return isinstance(t, tensor_types) or t is None

  flat_arrays = nest.flatten(arrays)
  unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
  if unsplitable:
    raise ValueError(
        "`validation_split` is only supported for Tensors or NumPy "
        "arrays, found following types in the input: {}".format(unsplitable))

  if all(t is None for t in flat_arrays):
    return arrays, arrays

  first_non_none = None
  for t in flat_arrays:
    if t is not None:
      first_non_none = t
      break

  # Assumes all arrays have the same batch shape or are `None`.
  batch_dim = int(first_non_none.shape[0])
  split_at = int(math.floor(batch_dim * (1. - validation_split)))

  if split_at == 0 or split_at == batch_dim:
    raise ValueError(
        "Training data contains {batch_dim} samples, which is not sufficient "
        "to split it into a validation and training set as specified by "
        "`validation_split={validation_split}`. Either provide more data, or a "
        "different value for the `validation_split` argument.".format(
            batch_dim=batch_dim, validation_split=validation_split))

  def _split(t, start, end):
    if t is None:
      return t
    return t[start:end]

  train_arrays = nest.map_structure(
      functools.partial(_split, start=0, end=split_at), arrays)
  val_arrays = nest.map_structure(
      functools.partial(_split, start=split_at, end=batch_dim), arrays)

  return train_arrays, val_arrays


@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[])
def unpack_x_y_sample_weight(data):
  """Unpacks user-provided data tuple.

  This is a convenience utility to be used when overriding
  `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
  This utility makes it easy to support data of the form `(x,)`,
  `(x, y)`, or `(x, y, sample_weight)`.

  Standalone usage:

  >>> features_batch = tf.ones((10, 5))
  >>> labels_batch = tf.zeros((10, 5))
  >>> data = (features_batch, labels_batch)
  >>> # `y` and `sample_weight` will default to `None` if not provided.
  >>> x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
  >>> sample_weight is None
  True

  Example in overridden `Model.train_step`:

  ```python
  class MyModel(tf.keras.Model):

    def train_step(self, data):
      # If `sample_weight` is not provided, all samples will be weighted
      # equally.
      x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

      with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=self.losses)
        trainable_variables = self.trainable_variables
        gradients = tape.gradient(loss, trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))

      self.compiled_metrics.update_state(y, y_pred, sample_weight)
      return {m.name: m.result() for m in self.metrics}
  ```

  Args:
    data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.

  Returns:
    The unpacked tuple, with `None`s for `y` and `sample_weight` if they are
    not provided.
  """
  if not isinstance(data, tuple):
    return (data, None, None)
  elif len(data) == 1:
    return (data[0], None, None)
  elif len(data) == 2:
    return (data[0], data[1], None)
  elif len(data) == 3:
    return (data[0], data[1], data[2])
  else:
    error_msg = ("Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
                 "or `(x, y, sample_weight)`, found: {}").format(data)
    raise ValueError(error_msg)


@keras_export("keras.utils.pack_x_y_sample_weight", v1=[])
def pack_x_y_sample_weight(x, y=None, sample_weight=None):
  """Packs user-provided data into a tuple.

  This is a convenience utility for packing data into the tuple formats
  that `Model.fit` uses.

  Standalone usage:

  >>> x = tf.ones((10, 1))
  >>> data = tf.keras.utils.pack_x_y_sample_weight(x)
  >>> isinstance(data, tf.Tensor)
  True
  >>> y = tf.ones((10, 1))
  >>> data = tf.keras.utils.pack_x_y_sample_weight(x, y)
  >>> isinstance(data, tuple)
  True
  >>> x, y = data

  Args:
    x: Features to pass to `Model`.
    y: Ground-truth targets to pass to `Model`.
    sample_weight: Sample weight for each element.

  Returns:
    Tuple in the format used in `Model.fit`.
  """
  if y is None:
    # For single x-input, we do no tuple wrapping since in this case
    # there is no ambiguity. This also makes NumPy and Dataset
    # consistent in that the user does not have to wrap their Dataset
    # data in an unnecessary tuple.
    if not nest.is_nested(x):
      return x
    else:
      return (x,)
  elif sample_weight is None:
    return (x, y)
  else:
    return (x, y, sample_weight)


def single_batch_iterator(strategy,
                          x,
                          y=None,
                          sample_weight=None,
                          class_weight=None):
  """Creates a single-batch dataset."""
  x, y, sample_weight = _process_tensorlike((x, y, sample_weight))
  if y is None:
    data = (x,)
  elif sample_weight is None:
    data = (x, y)
  else:
    data = (x, y, sample_weight)

  _check_data_cardinality(data)
  dataset = dataset_ops.DatasetV2.from_tensors(data)
  if class_weight:
    dataset = dataset.map(_make_class_weight_map_fn(class_weight))
  dataset = strategy.experimental_distribute_dataset(dataset)
  return iter(dataset)


def _check_data_cardinality(data):
  num_samples = set(int(i.shape[0]) for i in nest.flatten(data))
  if len(num_samples) > 1:
    msg = "Data cardinality is ambiguous:\n"
    for label, single_data in zip(["x", "y", "sample_weight"], data):
      msg += "  {} sizes: {}\n".format(
          label, ", ".join(str(i.shape[0]) for i in nest.flatten(single_data)))
    msg += "Make sure all arrays contain the same number of samples."
    raise ValueError(msg)


def _scipy_sparse_to_sparse_tensor(t):
  """Converts a SciPy sparse matrix to a SparseTensor."""
  sparse_coo = t.tocoo()
  row, col = sparse_coo.row, sparse_coo.col
  data, shape = sparse_coo.data, sparse_coo.shape
  if issubclass(data.dtype.type, np.floating):
    data = data.astype(backend.floatx())
  indices = np.concatenate(
      (np.expand_dims(row, axis=1), np.expand_dims(col, axis=1)), axis=1)
  return sparse_tensor.SparseTensor(indices, data, shape)


def _is_distributed_dataset(ds):
  return isinstance(ds, input_lib.DistributedDatasetInterface)
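

# Illustrative sketch (assuming an already-built Keras `model` and in-memory
# arrays `x` and `y`): `Model.fit`-style code typically drives the classes
# above roughly as follows:
#
#   data_handler = get_data_handler(
#       x, y, batch_size=32, epochs=2, shuffle=True, model=model)
#   for epoch, iterator in data_handler.enumerate_epochs():
#     for step in data_handler.steps():
#       with data_handler.catch_stop_iteration():
#         data = next(iterator)  # The (possibly distributed) batch for `step`.
#         # ... run one train/test/predict step on `data` ...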