1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training-related utilities.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import atexit 22import collections 23import functools 24import multiprocessing.pool 25import threading 26import time 27 28import numpy as np 29import six 30from six.moves import zip # pylint: disable=redefined-builtin 31 32from tensorflow.core.framework import graph_pb2 33from tensorflow.python import tf2 34from tensorflow.python.data.experimental.ops import cardinality 35from tensorflow.python.data.experimental.ops import distribute_options 36from tensorflow.python.data.ops import dataset_ops 37from tensorflow.python.data.ops import iterator_ops 38from tensorflow.python.eager import context 39from tensorflow.python.framework import composite_tensor 40from tensorflow.python.framework import dtypes 41from tensorflow.python.framework import errors 42from tensorflow.python.framework import ops 43from tensorflow.python.framework import smart_cond 44from tensorflow.python.framework import sparse_tensor 45from tensorflow.python.framework import tensor_spec 46from tensorflow.python.framework import tensor_util 47from tensorflow.python.keras import backend as K 48from tensorflow.python.keras import callbacks as cbks 49from tensorflow.python.keras import losses 50from tensorflow.python.keras import metrics as metrics_module 51from tensorflow.python.keras.utils import data_utils 52from tensorflow.python.keras.utils import generic_utils 53from tensorflow.python.keras.utils import losses_utils 54from tensorflow.python.keras.utils import tf_inspect 55from tensorflow.python.ops import array_ops 56from tensorflow.python.ops import gen_array_ops 57from tensorflow.python.ops import math_ops 58from tensorflow.python.ops import sparse_ops 59from tensorflow.python.ops.ragged import ragged_tensor 60from tensorflow.python.ops.ragged import ragged_tensor_value 61from tensorflow.python.platform import tf_logging as logging 62from tensorflow.python.util import nest 63 64 65def is_composite_or_composite_value(tensor): 66 """Returns true if 'tensor' is a CompositeTensor or a CT Value object.""" 67 # TODO(b/125094323): This should be isinstance(CompositeTensor) or 68 # isinstance(CompositeTensorValue) once we support that. 69 return isinstance( 70 tensor, 71 (composite_tensor.CompositeTensor, sparse_tensor.SparseTensorValue, 72 ragged_tensor_value.RaggedTensorValue)) 73 74 75@six.add_metaclass(abc.ABCMeta) 76class Aggregator(object): 77 """Abstract base class used to aggregate batch-level outputs of a loop. 78 79 Attributes: 80 use_steps: Whether the loop is using `step` or `batch_size`. 81 num_samples: Total number of samples: `batch_size * num_batches`. 82 steps: Total number of steps. 83 batch_size: Batch size. It is used for validation checks between inputs and 84 outputs. 85 results: What to return at the end of the aggregation loop. 86 """ 87 88 def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None): 89 self.use_steps = use_steps 90 self.num_samples = num_samples 91 self.steps = steps 92 self.batch_size = batch_size 93 self.results = [] 94 95 @abc.abstractmethod 96 def create(self, batch_outs): 97 """Creates the initial results from the first batch outputs. 98 99 Args: 100 batch_outs: A list of batch-level outputs. 101 """ 102 raise NotImplementedError('Must be implemented in subclasses.') 103 104 @abc.abstractmethod 105 def aggregate(self, batch_outs, batch_start=None, batch_end=None): 106 """Aggregates batch-level results into total results. 107 108 Args: 109 batch_outs: A list of batch-level outputs. 110 batch_start: The start index of this batch. Always `None` if `use_steps` 111 is `True`. 112 batch_end: The end index of this batch. Always `None` if `use_steps` is 113 `True`. 114 """ 115 raise NotImplementedError('Must be implemented in subclasses.') 116 117 @abc.abstractmethod 118 def finalize(self): 119 """Prepares the total results to be returned.""" 120 raise NotImplementedError('Must be implemented in subclasses.') 121 122 123class MetricsAggregator(Aggregator): 124 """Aggregator that calculates loss and metrics info. 125 126 Attributes: 127 use_steps: Whether the loop is using `step` or `batch_size`. 128 num_samples: Total number of samples: `batch_size*num_batches`. 129 steps: Total number of steps, ie number of times to iterate over a dataset 130 to cover all samples. 131 """ 132 133 def __init__(self, use_steps, num_samples=None, steps=None): 134 super(MetricsAggregator, self).__init__( 135 use_steps=use_steps, 136 num_samples=num_samples, 137 steps=steps, 138 batch_size=None) 139 140 def create(self, batch_outs): 141 self.results = [0.] * len(batch_outs) 142 143 def aggregate(self, batch_outs, batch_start=None, batch_end=None): 144 # Loss. 145 if self.use_steps: 146 self.results[0] += batch_outs[0] 147 else: 148 self.results[0] += batch_outs[0] * (batch_end - batch_start) 149 # Metrics (always stateful, just grab current values.) 150 self.results[1:] = batch_outs[1:] 151 152 def finalize(self): 153 if not self.results: 154 raise ValueError('Empty training data.') 155 self.results[0] /= (self.num_samples or self.steps) 156 157 158def _append_sparse_tensor_value(target, to_append): 159 """Append sparse tensor value objects.""" 160 # Make sure the sparse tensors are of the same size (except for the 0th dim). 161 if len(target.dense_shape) != len(to_append.dense_shape): 162 raise RuntimeError( 163 'Unable to concatenate %s and %s. The inner dense shapes do not ' 164 'have the same number of dimensions (%s vs %s)' % 165 (target, to_append, target.dense_shape, to_append.dense_shape)) 166 167 if target.dense_shape[1:] != to_append.dense_shape[1:]: 168 raise RuntimeError( 169 'Unable to concatenate %s and %s. The inner dense shapes do not ' 170 'match inner dimensions (%s vs %s)' % 171 (target, to_append, target.dense_shape[1:], to_append.dense_shape[1:])) 172 173 # Add the to_append indices to target, updating the 0th value, and keeping 174 # track of the maximum so we know the final dense_shape of this tensor. 175 base_dim0_value = target.dense_shape[0] 176 max_dim0_value = target.dense_shape[0] 177 new_indices = target.indices 178 for index in to_append.indices: 179 # Here, we iterate through the sparse indices of the tensor to append. For 180 # each index, we update its zeroth value (the batch index) by adding the 181 # number of batch items in the tensor we are appending to (so an index 182 # of [0, 0, 1] for a value that is being appended to a tensor with 0th dim 183 # size 3 would become [3, 0, 1].) 184 index[0] += base_dim0_value 185 max_dim0_value = max(max_dim0_value, index[0]) 186 new_indices = np.append(new_indices, [index], axis=0) 187 188 # Extend the values array to contain all of the appended values. These will 189 # be in the same order as the indices added above. 190 new_values = np.concatenate((target.values, to_append.values), axis=0) 191 192 # Create a new dense shape by replacing the value for the 0th dimension 193 # with the new max dim0 value. 194 new_dense_shape = list(target.dense_shape) 195 new_dense_shape[0] = max_dim0_value + 1 196 new_dense_shape = tuple(new_dense_shape) 197 198 return sparse_tensor.SparseTensorValue( 199 indices=new_indices, values=new_values, dense_shape=new_dense_shape) 200 201 202def _append_ragged_tensor_value(target, to_append): 203 """Append ragged tensor value objects.""" 204 # Make sure the ragged tensors are of the same size (save for the 0th dim). 205 if len(target.shape) != len(to_append.shape): 206 raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append)) 207 208 if target.shape[1:] != to_append.shape[1:]: 209 raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append)) 210 211 adjusted_row_splits = to_append.row_splits[1:] + target.row_splits[-1] 212 new_row_splits = np.append(target.row_splits, adjusted_row_splits) 213 if isinstance(target.values, ragged_tensor_value.RaggedTensorValue): 214 new_values = _append_ragged_tensor_value(target.values, to_append.values) 215 else: 216 new_values = np.concatenate((target.values, to_append.values), axis=0) 217 218 return ragged_tensor_value.RaggedTensorValue(new_values, new_row_splits) 219 220 221def _append_composite_tensor(target, to_append): 222 """Helper function to append composite tensors to each other in the 0 axis. 223 224 In order to support batching within a fit/evaluate/predict call, we need 225 to be able to aggregate within a CompositeTensor. Unfortunately, the CT 226 API currently does not make this easy - especially in V1 mode, where we're 227 working with CompositeTensor Value objects that have no connection with the 228 CompositeTensors that created them. 229 230 Args: 231 target: CompositeTensor or CompositeTensor value object that will be 232 appended to. 233 to_append: CompositeTensor or CompositeTensor value object to append to. 234 'target'. 235 236 Returns: 237 A CompositeTensor or CompositeTensor value object. 238 239 Raises: 240 RuntimeError: if concatenation is not possible. 241 """ 242 if type(target) is not type(to_append): 243 raise RuntimeError('Unable to concatenate %s and %s' % 244 (type(target), type(to_append))) 245 246 # Perform type-specific concatenation. 247 # TODO(b/125094323): This should be replaced by a simple call to 248 # target.append() that should work on all of the below classes. 249 250 # If we're seeing a CompositeTensor here, we know it's because we're in 251 # Eager mode (or else we'd have evaluated the CT to a CT Value object 252 # already). Therefore, it's safe to call concat() on it without evaluating 253 # the result any further. If not - that is, if we're seeing a 254 # SparseTensorValue or a RaggedTensorValue - we need to hand-update it 255 # since we're outside of the graph anyways. 256 if isinstance(target, sparse_tensor.SparseTensor): 257 # We need to invoke the sparse version of concatenate here - tf.concat 258 # won't work. 259 return sparse_ops.sparse_concat(sp_inputs=[target, to_append], axis=0) 260 elif isinstance(target, ragged_tensor.RaggedTensor): 261 return array_ops.concat([target, to_append], axis=0) 262 elif isinstance(target, sparse_tensor.SparseTensorValue): 263 return _append_sparse_tensor_value(target, to_append) 264 elif isinstance(target, ragged_tensor_value.RaggedTensorValue): 265 return _append_ragged_tensor_value(target, to_append) 266 else: 267 raise RuntimeError('Attempted to concatenate unsupported object %s.' % 268 type(target)) 269 270 271class ConcatAggregator(Aggregator): 272 """Combine tensor-likes which cannot be merged on the fly. 273 274 This class expects to aggregate a single tensor-like rather than a nested 275 structure of tensor-likes. 276 """ 277 278 def __init__(self, batch_size): 279 self.composite = None 280 super(ConcatAggregator, self).__init__( 281 use_steps=True, num_samples=None, steps=None, batch_size=batch_size) 282 283 def create(self, batch_element): 284 self.composite = is_composite_or_composite_value(batch_element) 285 286 def aggregate(self, batch_element, batch_start=None, batch_end=None): 287 288 # TODO(psv): Add num_samples check here to detect when output batch 289 # #samples is < batch size and != input batch #samples. 290 if self.batch_size and self.batch_size < batch_element.shape[0]: 291 raise ValueError( 292 'Mismatch between expected batch size and model output batch size. ' 293 'Output shape = {}, expected output shape = shape {}'.format( 294 batch_element.shape, 295 (self.batch_size,) + batch_element.shape[1:])) 296 self.results.append(batch_element) 297 298 def finalize(self): 299 # Special case of single batch inference which skips a copy. 300 if len(self.results) == 1: 301 self.results = self.results[0] 302 303 elif self.composite: 304 # TODO(taylorrobie): efficiently concatenate. 305 results = self.results[0] 306 for r in self.results[1:]: 307 results = _append_composite_tensor(results, r) 308 self.results = results 309 310 else: 311 self.results = np.concatenate(self.results, axis=0) 312 313 314_COPY_THREADS = 4 315_COPY_POOL = None 316 317 318def get_copy_pool(): 319 """Shared threadpool for copying arrays. 320 321 Pool instantiation takes ~ 2ms, so a singleton pool is used rather than 322 creating a pool per SliceAggregator. 323 324 Returns: 325 The global copy threadpool. 326 """ 327 global _COPY_POOL 328 if _COPY_POOL is None: 329 _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS) 330 atexit.register(_COPY_POOL.close) 331 return _COPY_POOL 332 333 334class SliceAggregator(Aggregator): 335 """Combine arrays where the final size is known. 336 337 This class expects to aggregate a single tensor-like rather than a nested 338 structure of tensor-likes. 339 340 NumPy copies are an operation that threads handle quite well because all of 341 the heavy lifting is in c and does not need the GIL. Moreover, we can perform 342 lock-free writes to the same buffer in multiple threads because the nature of 343 result aggregation guarantees that either the indices are disjoint or the 344 aggregator will throw an exception in finalize. Moreover, because aggregation 345 is performed on the slowest varying dimension, assignments for a given batch 346 will write to contiguous blocks of memory, further minimizing contention. 347 348 There is, however, some scheduling and context switching overhead which will 349 offset the gains from pipelining the slice assignment. Below a given threshold 350 it is faster to simply assign in the main thread rather than enqueue the 351 assignment in a side thread. The exact threshold will vary from system to 352 system, but the time is not very sensitive to the exact transition so a value 353 of 2 ** 14 was chosen which should be reasonable on most systems. 354 """ 355 356 _BINARY_SIZE_THRESHOLD = 2 ** 14 357 _MAX_COPY_SECONDS = 300 358 359 def __init__(self, num_samples, batch_size): 360 self._async_copies = [] 361 self._pool = get_copy_pool() 362 self._errors = [] 363 super(SliceAggregator, self).__init__( 364 use_steps=False, 365 num_samples=num_samples, 366 steps=None, 367 batch_size=batch_size) 368 369 def create(self, batch_element): 370 # This step does not need to be pipelined because NumPy empty array 371 # initialization is effectively instantaneous. 372 shape = (self.num_samples,) + batch_element.shape[1:] 373 dtype = batch_element.dtype 374 375 self.results = np.empty(shape=shape, dtype=dtype) 376 377 def aggregate(self, batch_element, batch_start, batch_end): 378 # Fail early. 379 if self._errors: 380 six.reraise(type(self._errors[0]), self._errors[0]) 381 382 # In the special case of single batch inference, no copy is needed. 383 if batch_end - batch_start == self.num_samples: 384 if self.num_samples != batch_element.shape[0]: 385 raise ValueError( 386 'Mismatch between expected batch size and model output batch size. ' 387 'Output shape = {}, expected output shape = shape {}'.format( 388 batch_element.shape, self.results.shape)) 389 390 self.results = batch_element 391 return 392 393 # This is an approximate threshold, so we don't need to consider the number 394 # of bytes per element. 395 num_elements = np.prod(batch_element.shape) 396 if num_elements < self._BINARY_SIZE_THRESHOLD: 397 self.results[batch_start:batch_end] = batch_element 398 else: 399 is_finished = threading.Event() 400 self._pool.apply_async( 401 self._slice_assign, 402 args=(batch_element, batch_start, batch_end, is_finished)) 403 self._async_copies.append(is_finished) 404 405 def _slice_assign(self, batch_element, batch_start, batch_end, is_finished): 406 """Legacy utility method to slice input arrays.""" 407 try: 408 self.results[batch_start:batch_end] = batch_element 409 410 except Exception as e: # pylint: disable=broad-except 411 # `_slice_assign` should only be called in threads and exceptions raised 412 # in threads do not carry over to the main thread. So instead we perform a 413 # a broad catch in the thread and then store the exception to be re-raised 414 # in the main thread. 415 self._errors.append(e) 416 417 finally: 418 is_finished.set() 419 420 def finalize(self): 421 start_time = time.time() 422 for is_finished in self._async_copies: 423 timeout = max([0., self._MAX_COPY_SECONDS - (time.time() - start_time)]) 424 if not is_finished.wait(timeout): 425 raise ValueError('Timed out waiting for copy to complete.') 426 427 if self._errors: 428 six.reraise(self._errors[0].__class__, self._errors[0]) 429 430 431class OutputsAggregator(Aggregator): 432 """Aggregator that concatenates outputs.""" 433 434 _structure = None 435 436 def create(self, batch_outs): 437 # SparseTensorValue is a named tuple which nest will flatten, so we need 438 # to guard it to properly handle the structure. 439 self._structure = nest.get_traverse_shallow_structure( 440 lambda x: not is_composite_or_composite_value(x), batch_outs) 441 batch_outs = nest.flatten_up_to(self._structure, batch_outs) 442 443 for batch_element in batch_outs: 444 if is_composite_or_composite_value(batch_element): 445 # If the output is not a ndarray, it will be either a composite tensor 446 # or a composite tensor's Value object. In either case, we can't 447 # allocate an array to hold the object - we'll handle it later. 448 self.results.append(ConcatAggregator(self.batch_size)) 449 elif isinstance(batch_element, np.ndarray): 450 self.results.append( 451 (ConcatAggregator(self.batch_size) if self.use_steps else 452 SliceAggregator(self.num_samples, self.batch_size))) 453 else: 454 # This is not a ndarray, a CompositeTensor, or a CompositeTensorValue. 455 # Fail fast rather than trying to concatenate it. 456 raise RuntimeError('Attempted to aggregate unsupported object {}.' 457 .format(batch_element)) 458 459 self.results[-1].create(batch_element) 460 461 def aggregate(self, batch_outs, batch_start=None, batch_end=None): 462 batch_outs = nest.flatten_up_to(self._structure, batch_outs) 463 for batch_element, result in zip(batch_outs, self.results): 464 result.aggregate(batch_element, batch_start, batch_end) 465 466 def finalize(self): 467 for result in self.results: 468 result.finalize() 469 self.results = [i.results for i in self.results] 470 self.results = nest.pack_sequence_as(self._structure, self.results) 471 472 473def get_progbar(model, count_mode, include_metrics=True): 474 """Get Progbar.""" 475 if include_metrics: 476 stateful_metric_names = getattr(model, 'metrics_names', None) 477 if stateful_metric_names: 478 stateful_metric_names = stateful_metric_names[1:] # Exclude `loss` 479 else: 480 stateful_metric_names = None 481 return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names) 482 483 484def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'): 485 """Determine the number of samples provided for training and evaluation. 486 487 The number of samples is not defined when running with `steps`, 488 in which case the number of samples is set to `None`. 489 490 Args: 491 ins: List of tensors to be fed to the Keras function. 492 batch_size: Integer batch size or `None` if not defined. 493 steps: Total number of steps (batches of samples) before declaring 494 `_predict_loop` finished. Ignored with the default value of `None`. 495 steps_name: The public API's parameter name for `steps`. 496 497 Raises: 498 ValueError: when `steps` is `None` and the attribute `ins.shape` 499 does not exist. Also raises ValueError when `steps` is not `None` 500 and `batch_size` is not `None` because they are mutually 501 exclusive. 502 503 Returns: 504 When steps is `None`, returns the number of samples to be 505 processed based on the size of the first dimension of the 506 first input numpy array. When steps is not `None` and 507 `batch_size` is `None`, returns `None`. 508 """ 509 if steps is not None and batch_size is not None: 510 raise ValueError('If ' + steps_name + 511 ' is set, the `batch_size` must be None.') 512 if check_steps_argument(ins, steps, steps_name): 513 return None 514 515 if hasattr(ins[0], 'shape'): 516 return int(ins[0].shape[0]) 517 return None # Edge case where ins == [static_learning_phase] 518 519 520def standardize_single_array(x, expected_shape=None): 521 """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1.""" 522 if x is None: 523 return None 524 525 if is_composite_or_composite_value(x): 526 return x 527 528 if isinstance(x, int): 529 raise ValueError( 530 'Expected an array data type but received an integer: {}'.format(x)) 531 532 if (x.shape is not None and len(x.shape) == 1 and 533 (expected_shape is None or len(expected_shape) != 1)): 534 if tensor_util.is_tf_type(x): 535 x = array_ops.expand_dims(x, axis=1) 536 else: 537 x = np.expand_dims(x, 1) 538 return x 539 540 541def get_composite_shape(tensor): 542 """Returns the shape of the passed composite tensor.""" 543 if isinstance(tensor, sparse_tensor.SparseTensorValue): 544 # SparseTensorValues use a 'dense_shape' attribute 545 return tensor.dense_shape 546 else: 547 return tensor.shape 548 549 550def standardize_input_data(data, 551 names, 552 shapes=None, 553 check_batch_axis=True, 554 exception_prefix=''): 555 """Normalizes inputs and targets provided by users. 556 557 Users may pass data as a list of arrays, dictionary of arrays, 558 or as a single array. We normalize this to an ordered list of 559 arrays (same order as `names`), while checking that the provided 560 arrays have shapes that match the network's expectations. 561 562 Args: 563 data: User-provided input data (polymorphic). 564 names: List of expected array names. 565 shapes: Optional list of expected array shapes. 566 check_batch_axis: Boolean; whether to check that the batch axis of the 567 arrays matches the expected value found in `shapes`. 568 exception_prefix: String prefix used for exception formatting. 569 570 Returns: 571 List of standardized input arrays (one array per model input). 572 573 Raises: 574 ValueError: in case of improperly formatted user-provided data. 575 """ 576 try: 577 data_len = len(data) 578 except TypeError: 579 # For instance if data is `None` or a symbolic Tensor. 580 data_len = None 581 582 if not names: 583 if data_len and not isinstance(data, dict): 584 raise ValueError( 585 'Error when checking model ' + exception_prefix + ': ' 586 'expected no data, but got:', data) 587 return [] 588 if data is None: 589 return [None for _ in range(len(names))] 590 591 if isinstance(data, dict): 592 try: 593 data = [ 594 data[x].values 595 if data[x].__class__.__name__ == 'DataFrame' else data[x] 596 for x in names 597 ] 598 except KeyError as e: 599 raise ValueError('No data provided for "' + e.args[0] + '". Need data ' 600 'for each key in: ' + str(names)) 601 elif isinstance(data, (list, tuple)): 602 if isinstance(data[0], (list, tuple)): 603 data = [np.asarray(d) for d in data] 604 elif len(names) == 1 and isinstance(data[0], (float, int)): 605 data = [np.asarray(data)] 606 else: 607 data = [ 608 x.values if x.__class__.__name__ == 'DataFrame' else x for x in data 609 ] 610 else: 611 data = data.values if data.__class__.__name__ == 'DataFrame' else data 612 data = [data] 613 614 if shapes is not None: 615 data = [ 616 standardize_single_array(x, shape) for (x, shape) in zip(data, shapes) 617 ] 618 else: 619 data = [standardize_single_array(x) for x in data] 620 621 if len(data) != len(names): 622 if data and hasattr(data[0], 'shape'): 623 raise ValueError('Error when checking model ' + exception_prefix + 624 ': the list of Numpy arrays that you are passing to ' 625 'your model is not the size the model expected. ' 626 'Expected to see ' + str(len(names)) + ' array(s), ' + 627 'for inputs ' + str(names) + ' but instead got the ' 628 'following list of ' + str(len(data)) + ' arrays: ' + 629 str(data)[:200] + '...') 630 elif len(names) > 1: 631 raise ValueError('Error when checking model ' + exception_prefix + 632 ': you are passing a list as input to your model, ' 633 'but the model expects a list of ' + str(len(names)) + 634 ' Numpy arrays instead. The list you passed was: ' + 635 str(data)[:200]) 636 elif len(data) == 1 and not hasattr(data[0], 'shape'): 637 raise TypeError('Error when checking model ' + exception_prefix + 638 ': data should be a Numpy array, or list/dict of ' 639 'Numpy arrays. Found: ' + str(data)[:200] + '...') 640 elif len(names) == 1: 641 data = [np.asarray(data)] 642 643 # Check shapes compatibility. 644 if shapes: 645 for i in range(len(names)): 646 if shapes[i] is not None: 647 if tensor_util.is_tf_type(data[i]): 648 tensorshape = data[i].shape 649 if not tensorshape: 650 continue 651 data_shape = tuple(tensorshape.as_list()) 652 elif is_composite_or_composite_value(data[i]): 653 tensorshape = get_composite_shape(data[i]) 654 data_shape = tuple(tensorshape.as_list()) 655 else: 656 data_shape = data[i].shape 657 658 shape = shapes[i] 659 if len(data_shape) != len(shape): 660 raise ValueError('Error when checking ' + exception_prefix + 661 ': expected ' + names[i] + ' to have ' + 662 str(len(shape)) + ' dimensions, but got array ' 663 'with shape ' + str(data_shape)) 664 if not check_batch_axis: 665 data_shape = data_shape[1:] 666 shape = shape[1:] 667 for dim, ref_dim in zip(data_shape, shape): 668 if ref_dim != dim and ref_dim is not None and dim is not None: 669 raise ValueError('Error when checking ' + exception_prefix + 670 ': expected ' + names[i] + ' to have shape ' + 671 str(shape) + ' but got array with shape ' + 672 str(data_shape)) 673 return data 674 675 676def standardize_sample_or_class_weights(x_weight, output_names, weight_type): 677 """Maps `sample_weight` or `class_weight` to model outputs. 678 679 Args: 680 x_weight: User-provided `sample_weight` or `class_weight` argument. 681 output_names: List of output names (strings) in the model. 682 weight_type: A string used purely for exception printing. 683 684 Returns: 685 A list of `sample_weight` or `class_weight` where there are exactly 686 one element per model output. 687 688 Raises: 689 ValueError: In case of invalid user-provided argument. 690 """ 691 if x_weight is None or (isinstance(x_weight, (list, tuple)) and 692 len(x_weight) == 0): # pylint: disable=g-explicit-length-test 693 return [None for _ in output_names] 694 if len(output_names) == 1: 695 if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1: 696 return x_weight 697 if isinstance(x_weight, dict) and output_names[0] in x_weight: 698 return [x_weight[output_names[0]]] 699 else: 700 return [x_weight] 701 if isinstance(x_weight, (list, tuple)): 702 if len(x_weight) != len(output_names): 703 raise ValueError('Provided `' + weight_type + '` was a list of ' + 704 str(len(x_weight)) + ' elements, but the model has ' + 705 str(len(output_names)) + ' outputs. ' 706 'You should provide one `' + weight_type + '`' 707 'array per model output.') 708 return x_weight 709 if isinstance(x_weight, collections.abc.Mapping): 710 generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names) 711 x_weights = [] 712 for name in output_names: 713 x_weights.append(x_weight.get(name)) 714 return x_weights 715 else: 716 raise TypeError('The model has multiple outputs, so `' + weight_type + '` ' 717 'should be either a list or a dict. ' 718 'Provided `' + weight_type + '` type not understood: ' + 719 str(x_weight)) 720 721 722def standardize_class_weights(class_weight, output_names): 723 return standardize_sample_or_class_weights(class_weight, output_names, 724 'class_weight') 725 726 727def standardize_sample_weights(sample_weight, output_names): 728 return standardize_sample_or_class_weights(sample_weight, output_names, 729 'sample_weight') 730 731 732def check_array_lengths(inputs, targets, weights=None): 733 """Does user input validation for numpy arrays. 734 735 Args: 736 inputs: list of Numpy arrays of inputs. 737 targets: list of Numpy arrays of targets. 738 weights: list of Numpy arrays of sample weights. 739 740 Raises: 741 ValueError: in case of incorrectly formatted data. 742 """ 743 744 def is_tensor_or_composite_tensor(x): 745 return tensor_util.is_tf_type(x) or is_composite_or_composite_value(x) 746 747 def set_of_lengths(x): 748 # Returns a set with the variation between 749 # different shapes, with None => 0 750 if x is None: 751 return {} 752 else: 753 return set([ 754 y.shape[0] 755 for y in x 756 if y is not None and not is_tensor_or_composite_tensor(y) 757 ]) 758 759 set_x = set_of_lengths(inputs) 760 set_y = set_of_lengths(targets) 761 set_w = set_of_lengths(weights) 762 if len(set_x) > 1: 763 raise ValueError('All input arrays (x) should have ' 764 'the same number of samples. Got array shapes: ' + 765 str([x.shape for x in inputs])) 766 if len(set_y) > 1: 767 raise ValueError('All target arrays (y) should have ' 768 'the same number of samples. Got array shapes: ' + 769 str([y.shape for y in targets])) 770 if set_x and set_y and list(set_x)[0] != list(set_y)[0]: 771 raise ValueError('Input arrays should have ' 772 'the same number of samples as target arrays. ' 773 'Found ' + str(list(set_x)[0]) + ' input samples ' 774 'and ' + str(list(set_y)[0]) + ' target samples.') 775 if len(set_w) > 1: 776 raise ValueError('All sample_weight arrays should have ' 777 'the same number of samples. Got array shapes: ' + 778 str([w.shape for w in weights])) 779 if set_y and set_w and list(set_y)[0] != list(set_w)[0]: 780 raise ValueError('Sample_weight arrays should have ' 781 'the same number of samples as target arrays. Got ' + 782 str(list(set_y)[0]) + ' input samples and ' + 783 str(list(set_w)[0]) + ' target samples.') 784 785 786def check_loss_and_target_compatibility(targets, loss_fns, output_shapes): 787 """Does validation on the compatibility of targets and loss functions. 788 789 This helps prevent users from using loss functions incorrectly. This check 790 is purely for UX purposes. 791 792 Args: 793 targets: list of Numpy arrays of targets. 794 loss_fns: list of loss functions. 795 output_shapes: list of shapes of model outputs. 796 797 Raises: 798 ValueError: if a loss function or target array 799 is incompatible with an output. 800 """ 801 key_loss_fns = { 802 losses.mean_squared_error, losses.binary_crossentropy, 803 losses.categorical_crossentropy 804 } 805 key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy, 806 losses.CategoricalCrossentropy) 807 for y, loss, shape in zip(targets, loss_fns, output_shapes): 808 if y is None or loss is None or tensor_util.is_tf_type(y): 809 continue 810 if losses.is_categorical_crossentropy(loss): 811 if y.shape[-1] == 1: 812 raise ValueError('You are passing a target array of shape ' + 813 str(y.shape) + 814 ' while using as loss `categorical_crossentropy`. ' 815 '`categorical_crossentropy` expects ' 816 'targets to be binary matrices (1s and 0s) ' 817 'of shape (samples, classes). ' 818 'If your targets are integer classes, ' 819 'you can convert them to the expected format via:\n' 820 '```\n' 821 'from keras.utils import to_categorical\n' 822 'y_binary = to_categorical(y_int)\n' 823 '```\n' 824 '\n' 825 'Alternatively, you can use the loss function ' 826 '`sparse_categorical_crossentropy` instead, ' 827 'which does expect integer targets.') 828 829 is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper) 830 if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and 831 (loss.fn in key_loss_fns))): 832 for target_dim, out_dim in zip(y.shape[1:], shape[1:]): 833 if out_dim is not None and target_dim != out_dim: 834 loss_name = loss.name 835 if loss_name is None: 836 loss_type = loss.fn if is_loss_wrapper else type(loss) 837 loss_name = loss_type.__name__ 838 raise ValueError('A target array with shape ' + str(y.shape) + 839 ' was passed for an output of shape ' + str(shape) + 840 ' while using as loss `' + loss_name + '`. ' 841 'This loss expects targets to have the same shape ' 842 'as the output.') 843 844 845def collect_per_output_metric_info(metrics, 846 output_names, 847 output_shapes, 848 loss_fns, 849 is_weighted=False): 850 """Maps metric names and functions to model outputs. 851 852 Args: 853 metrics: a list or a list of lists or a dict of metric functions. 854 output_names: a list of the names (strings) of model outputs. 855 output_shapes: a list of the shapes (strings) of model outputs. 856 loss_fns: a list of the loss functions corresponding to the model outputs. 857 is_weighted: Boolean indicating whether the given metrics are weighted. 858 859 Returns: 860 A list (one entry per model output) of dicts. 861 For instance, if the model has 2 outputs, and for the first output 862 we want to compute "binary_accuracy" and "binary_crossentropy", 863 and just "binary_accuracy" for the second output, 864 the list would look like: `[{ 865 'acc': binary_accuracy(), 866 'ce': binary_crossentropy(), 867 }, { 868 'acc': binary_accuracy(), 869 }]` 870 871 Raises: 872 TypeError: if an incorrect type is passed for the `metrics` argument. 873 """ 874 if not metrics: 875 return [{} for _ in output_names] 876 877 if isinstance(metrics, list): 878 any_sub_list = any(isinstance(m, list) for m in metrics) 879 if any_sub_list: 880 if len(metrics) != len(output_names): 881 raise ValueError('When passing a list of lists as `metrics`, ' 882 'it should have one entry per model output. ' 883 'The model has ' + str(len(output_names)) + 884 ' outputs, but you passed metrics=' + str(metrics)) 885 # User has provided a list of len = len(outputs). 886 nested_metrics = [generic_utils.to_list(m) for m in metrics] 887 else: 888 # If it is a single list we then apply all metrics to all outputs. 889 if len(output_names) > 1: 890 nested_metrics = [] 891 for _ in output_names: 892 nested_metrics.append( 893 [metrics_module.clone_metric(m) for m in metrics]) 894 else: 895 nested_metrics = [metrics] 896 elif isinstance(metrics, collections.abc.Mapping): 897 generic_utils.check_for_unexpected_keys('metrics', metrics, output_names) 898 nested_metrics = [] 899 for name in output_names: 900 output_metrics = generic_utils.to_list(metrics.get(name, [])) 901 nested_metrics.append(output_metrics) 902 else: 903 raise TypeError('Type of `metrics` argument not understood. ' 904 'Expected a list or dictionary, found: ' + str(metrics)) 905 906 per_output_metrics = [] 907 for i, metrics in enumerate(nested_metrics): 908 metrics_dict = collections.OrderedDict() 909 for metric in metrics: 910 metric_name = get_metric_name(metric, is_weighted) 911 metric_fn = get_metric_function( 912 metric, output_shape=output_shapes[i], loss_fn=loss_fns[i]) 913 914 # If the metric function is not stateful, we create a stateful version. 915 if not isinstance(metric_fn, metrics_module.Metric): 916 metric_fn = metrics_module.MeanMetricWrapper( 917 metric_fn, name=metric_name) 918 metrics_dict[metric_name] = metric_fn 919 per_output_metrics.append(metrics_dict) 920 921 return per_output_metrics 922 923 924def batch_shuffle(index_array, batch_size): 925 """Shuffles an array in a batch-wise fashion. 926 927 Useful for shuffling HDF5 arrays 928 (where one cannot access arbitrary indices). 929 930 Args: 931 index_array: array of indices to be shuffled. 932 batch_size: integer. 933 934 Returns: 935 The `index_array` array, shuffled in a batch-wise fashion. 936 """ 937 batch_count = int(len(index_array) / batch_size) 938 # to reshape we need to be cleanly divisible by batch size 939 # we stash extra items and reappend them after shuffling 940 last_batch = index_array[batch_count * batch_size:] 941 index_array = index_array[:batch_count * batch_size] 942 index_array = index_array.reshape((batch_count, batch_size)) 943 np.random.shuffle(index_array) 944 index_array = index_array.flatten() 945 return np.append(index_array, last_batch) 946 947 948def standardize_weights(y, 949 sample_weight=None, 950 class_weight=None, 951 sample_weight_mode=None): 952 """Performs sample weight validation and standardization. 953 954 Everything gets normalized to a single sample-wise (or timestep-wise) 955 weight array. If both `sample_weight` and `class_weight` are provided, 956 the weights are multiplied. 957 958 Args: 959 y: Numpy array or Tensor of model targets to be weighted. 960 sample_weight: User-provided `sample_weight` argument. 961 class_weight: User-provided `class_weight` argument. 962 sample_weight_mode: One of `None` or `"temporal"`. `"temporal"` indicated 963 that we expect 2D weight data that will be applied to the last 2 964 dimensions of the targets (i.e. we are weighting timesteps, not 965 samples). 966 967 Returns: 968 A numpy array of target weights, one entry per sample to weight. 969 970 Raises: 971 ValueError: In case of invalid user-provided arguments. 972 """ 973 # Iterator may return sample_weight as 1-tuple 974 if isinstance(sample_weight, tuple): 975 sample_weight = sample_weight[0] 976 if sample_weight_mode is not None and sample_weight_mode != 'samplewise': 977 if sample_weight_mode != 'temporal': 978 raise ValueError('"sample_weight_mode ' 979 'should be None or "temporal". ' 980 'Found: ' + str(sample_weight_mode)) 981 if len(y.shape) < 3: 982 raise ValueError('Found a sample_weight array for ' 983 'an input with shape ' + str(y.shape) + '. ' 984 'Timestep-wise sample weighting (use of ' 985 'sample_weight_mode="temporal") is restricted to ' 986 'outputs that are at least 3D, i.e. that have ' 987 'a time dimension.') 988 if sample_weight is not None and len(sample_weight.shape) != 2: 989 raise ValueError('Found a sample_weight array with shape ' + 990 str(sample_weight.shape) + '. ' 991 'In order to use timestep-wise sample weighting, ' 992 'you should pass a 2D sample_weight array.') 993 else: 994 if sample_weight is not None and len(sample_weight.shape) != 1: 995 raise ValueError( 996 'Found a sample_weight array with shape {}. In order to ' 997 'use timestep-wise sample weights, you should specify ' 998 'sample_weight_mode="temporal" in compile(); founssd "{}" ' 999 'instead. If you just mean to use sample-wise weights, ' 1000 'make sure your sample_weight array is 1D.'.format( 1001 sample_weight.shape, sample_weight_mode)) 1002 1003 if sample_weight is not None: 1004 if len(sample_weight.shape) > len(y.shape): 1005 raise ValueError('Found a sample_weight with shape' + 1006 str(sample_weight.shape) + '.' 1007 'Expected sample_weight with rank ' 1008 'less than or equal to ' + str(len(y.shape))) 1009 1010 if (not tensor_util.is_tf_type(sample_weight) and 1011 y.shape[:sample_weight.ndim] != sample_weight.shape): 1012 raise ValueError('Found a sample_weight array with shape ' + 1013 str(sample_weight.shape) + ' for an input with shape ' + 1014 str(y.shape) + '. ' 1015 'sample_weight cannot be broadcast.') 1016 1017 # Class weights applied per-sample. 1018 class_sample_weight = None 1019 if isinstance(class_weight, dict): 1020 if len(y.shape) > 2: 1021 raise ValueError('`class_weight` not supported for ' 1022 '3+ dimensional targets.') 1023 1024 if tensor_util.is_tf_type(y): 1025 # Few classes are expected, so densifying is reasonable. 1026 keys = np.array(sorted(class_weight.keys())) 1027 values = np.array([class_weight[i] for i in keys]) 1028 weight_vector = np.zeros(np.max(keys) + 1) 1029 weight_vector[:] = np.nan 1030 weight_vector[keys] = values 1031 1032 y_classes = smart_cond.smart_cond( 1033 len(y.shape.as_list()) == 2 and K.shape(y)[1] > 1, 1034 lambda: K.argmax(y, axis=1), 1035 lambda: math_ops.cast(K.reshape(y, (-1,)), dtypes.int64)) 1036 class_sample_weight = array_ops.gather(weight_vector, y_classes) 1037 gen_array_ops.check_numerics( 1038 class_sample_weight, 1039 'Invalid classes or class weights detected. NaN values indicate that ' 1040 'an appropriate class weight could not be determined.') 1041 class_sample_weight = math_ops.cast(class_sample_weight, K.floatx()) 1042 if sample_weight is not None: 1043 sample_weight = math_ops.cast( 1044 ops.convert_to_tensor_v2_with_dispatch(sample_weight), K.floatx()) 1045 else: 1046 y_classes = y 1047 if len(y.shape) == 2: 1048 if y.shape[1] > 1: 1049 y_classes = np.argmax(y, axis=1) 1050 elif y.shape[1] == 1: 1051 y_classes = np.reshape(y, y.shape[0]) 1052 1053 class_sample_weight = np.asarray( 1054 [class_weight[cls] for cls in y_classes if cls in class_weight]) 1055 1056 if len(class_sample_weight) != len(y_classes): 1057 # subtract the sets to pick all missing classes 1058 existing_classes = set(y_classes) 1059 existing_class_weight = set(class_weight.keys()) 1060 raise ValueError( 1061 '`class_weight` must contain all classes in the data.' 1062 ' The classes %s exist in the data but not in ' 1063 '`class_weight`.' % (existing_classes - existing_class_weight)) 1064 1065 if class_sample_weight is not None and sample_weight is not None: 1066 # Multiply weights if both are provided. 1067 return class_sample_weight * sample_weight 1068 if sample_weight is not None: 1069 return sample_weight 1070 if class_sample_weight is not None: 1071 return class_sample_weight 1072 return None 1073 1074 1075def has_symbolic_tensors(ls): 1076 if context.executing_eagerly(): 1077 return False 1078 return has_tensors(ls) 1079 1080 1081def has_tensors(ls): 1082 """Returns true if `ls` contains tensors.""" 1083 # Note: at some point in time ragged tensors didn't count as tensors, so this 1084 # returned false for ragged tensors. Making this return true fails some tests 1085 # which would then require a steps_per_epoch argument. 1086 if isinstance(ls, (list, tuple)): 1087 return any( 1088 tensor_util.is_tf_type(v) and 1089 not isinstance(v, ragged_tensor.RaggedTensor) for v in ls) 1090 if isinstance(ls, dict): 1091 return any( 1092 tensor_util.is_tf_type(v) and 1093 not isinstance(v, ragged_tensor.RaggedTensor) 1094 for _, v in six.iteritems(ls)) 1095 return tensor_util.is_tf_type(ls) and not isinstance( 1096 ls, ragged_tensor.RaggedTensor) 1097 1098 1099def get_metric_name(metric, weighted=False): 1100 """Returns the name corresponding to the given metric input. 1101 1102 Args: 1103 metric: Metric function name or reference. 1104 weighted: Boolean indicating if the given metric is weighted. 1105 1106 Returns: 1107 The metric name. 1108 """ 1109 if tf2.enabled(): 1110 # We keep the string that the user has set in compile as the metric name. 1111 if isinstance(metric, six.string_types): 1112 return metric 1113 1114 metric = metrics_module.get(metric) 1115 return metric.name if hasattr(metric, 'name') else metric.__name__ 1116 else: 1117 metric_name_prefix = 'weighted_' if weighted else '' 1118 if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): 1119 if metric in ('accuracy', 'acc'): 1120 suffix = 'acc' 1121 elif metric in ('crossentropy', 'ce'): 1122 suffix = 'ce' 1123 else: 1124 metric_fn = metrics_module.get(metric) 1125 # Get metric name as string 1126 if hasattr(metric_fn, 'name'): 1127 suffix = metric_fn.name 1128 else: 1129 suffix = metric_fn.__name__ 1130 metric_name = metric_name_prefix + suffix 1131 return metric_name 1132 1133 1134def get_metric_function(metric, output_shape=None, loss_fn=None): 1135 """Returns the metric function corresponding to the given metric input. 1136 1137 Args: 1138 metric: Metric function name or reference. 1139 output_shape: The shape of the output that this metric will be calculated 1140 for. 1141 loss_fn: The loss function used. 1142 1143 Returns: 1144 The metric function. 1145 """ 1146 if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']: 1147 return metrics_module.get(metric) 1148 1149 is_sparse_categorical_crossentropy = ( 1150 isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or 1151 (isinstance(loss_fn, losses.LossFunctionWrapper) and 1152 loss_fn.fn == losses.sparse_categorical_crossentropy)) 1153 1154 is_binary_crossentropy = ( 1155 isinstance(loss_fn, losses.BinaryCrossentropy) or 1156 (isinstance(loss_fn, losses.LossFunctionWrapper) and 1157 loss_fn.fn == losses.binary_crossentropy)) 1158 1159 if metric in ['accuracy', 'acc']: 1160 if output_shape[-1] == 1 or is_binary_crossentropy: 1161 return metrics_module.binary_accuracy 1162 elif is_sparse_categorical_crossentropy: 1163 return metrics_module.sparse_categorical_accuracy 1164 # If the output_shape[-1] is not 1, then we know output is `categorical`. 1165 # We assume it is sparse categorical only if loss is explicitly given 1166 # as sparse categorical crossentropy loss. 1167 return metrics_module.categorical_accuracy 1168 else: 1169 if output_shape[-1] == 1 or is_binary_crossentropy: 1170 return metrics_module.binary_crossentropy 1171 elif is_sparse_categorical_crossentropy: 1172 return metrics_module.sparse_categorical_crossentropy 1173 return metrics_module.categorical_crossentropy 1174 1175 1176def call_metric_function(metric_fn, 1177 y_true, 1178 y_pred=None, 1179 weights=None, 1180 mask=None): 1181 """Invokes metric function and returns the metric result tensor.""" 1182 if mask is not None: 1183 mask = math_ops.cast(mask, y_pred.dtype) 1184 if weights is None: 1185 # Use mask as sample weight. 1186 weights = mask 1187 else: 1188 # Update dimensions of weights to match with mask. 1189 weights = math_ops.cast(weights, dtype=y_pred.dtype) 1190 mask, _, weights = losses_utils.squeeze_or_expand_dimensions( 1191 mask, sample_weight=weights) 1192 weights *= mask 1193 1194 if y_pred is not None: 1195 return metric_fn(y_true, y_pred, sample_weight=weights) 1196 # `Mean` metric only takes a single value. 1197 return metric_fn(y_true, sample_weight=weights) 1198 1199 1200def get_loss_function(loss): 1201 """Returns the loss corresponding to the loss input in `compile` API.""" 1202 if loss is None or isinstance(loss, losses.Loss): 1203 return loss 1204 1205 if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss): 1206 # It is not safe to assume that the loss takes no constructor arguments. 1207 raise ValueError( 1208 'Received uninstantiated Loss class: {}\nPlease call loss ""classes ' 1209 'before passing them to Model.compile.'.format(loss)) 1210 1211 # Deserialize loss configuration, if needed. 1212 if isinstance(loss, collections.abc.Mapping): 1213 loss = losses.get(loss) 1214 1215 # Custom callable class. 1216 if callable(loss) and not hasattr(loss, '__name__'): 1217 return loss 1218 1219 # Wrap loss function with signature `(y_true, y_pred, **kwargs)` 1220 # in `LossFunctionWrapper` class. 1221 loss_fn = losses.get(loss) 1222 1223 # For losses which are given as strings/functions in the compile API, 1224 # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE` 1225 # (both in distribution strategy context and otherwise). 1226 return losses.LossFunctionWrapper( 1227 loss_fn, 1228 name=loss_fn.__name__, 1229 reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE) 1230 1231 1232def validate_dataset_input(x, y, sample_weight, validation_split=None): 1233 """Validates user input arguments when a dataset iterator is passed. 1234 1235 Args: 1236 x: Input data. A `tf.data` dataset or iterator. 1237 y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s). 1238 Expected to be `None` when `x` is a dataset iterator. 1239 sample_weight: An optional sample-weight array passed by the user to weight 1240 the importance of each sample in `x`. Expected to be `None` when `x` is a 1241 dataset iterator 1242 validation_split: Float between 0 and 1. Fraction of the training data to be 1243 used as validation data. Expected to be `None` when `x` is a dataset 1244 iterator. 1245 1246 Raises: 1247 ValueError: if argument `y` or `sample_weight` or `validation_split` are 1248 provided by user. 1249 """ 1250 if y is not None: 1251 raise ValueError('You passed a dataset or dataset iterator (%s) as ' 1252 'input `x` to your model. In that case, you should ' 1253 'not specify a target (`y`) argument, since the dataset ' 1254 'or dataset iterator generates both input data and ' 1255 'target data. ' 1256 'Received: %s' % (x, y)) 1257 if sample_weight is not None: 1258 raise ValueError('`sample_weight` argument is not supported when input ' 1259 '`x` is a dataset or a dataset iterator. Instead, you' 1260 'can provide sample_weight as the third element of your' 1261 'dataset, i.e. (inputs, targets, sample_weight). ' 1262 'Received: x=%s, sample_weight=%s' % (x, sample_weight)) 1263 if validation_split is not None and validation_split != 0.0: 1264 raise ValueError( 1265 '`validation_split` argument is not supported when ' 1266 'input `x` is a dataset or a dataset iterator. ' 1267 'Received: x=%s, validation_split=%f' % (x, validation_split)) 1268 1269 1270def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'): 1271 """Helper function to validate either inputs or targets.""" 1272 if isinstance(inp, (list, tuple)): 1273 if not all(isinstance(v, np.ndarray) or 1274 tensor_util.is_tf_type(v) for v in inp): 1275 raise ValueError( 1276 'Please provide as model inputs either a single array or a list of ' 1277 'arrays. You passed: {}={}'.format(field_name, str(orig_inp))) 1278 elif isinstance(inp, dict): 1279 if not allow_dict: 1280 raise ValueError( 1281 'You cannot pass a dictionary as model {}.'.format(field_name)) 1282 elif not isinstance(inp, np.ndarray) and not tensor_util.is_tf_type(inp): 1283 raise ValueError( 1284 'Please provide as model inputs either a single array or a list of ' 1285 'arrays. You passed: {}={}'.format(field_name, orig_inp)) 1286 1287 1288def check_generator_arguments(y=None, sample_weight=None, 1289 validation_split=None): 1290 """Validates arguments passed when using a generator.""" 1291 if y is not None: 1292 raise ValueError('`y` argument is not supported when data is' 1293 'a generator or Sequence instance. Instead pass targets' 1294 ' as the second element of the generator.') 1295 if sample_weight is not None: 1296 raise ValueError('`sample_weight` argument is not supported when data is' 1297 'a generator or Sequence instance. Instead pass sample' 1298 ' weights as the third element of the generator.') 1299 if validation_split: 1300 raise ValueError('If your data is in the form of a Python generator, ' 1301 'you cannot use `validation_split`.') 1302 1303 1304def check_steps_argument(input_data, steps, steps_name): 1305 """Validates `steps` argument based on input data's type. 1306 1307 The cases when `steps` value must be provided are when 1308 1. input data passed is an iterator. 1309 2. model was built on top of symbolic tensors, input data is not 1310 required and is `None`. 1311 3. input data passed is a symbolic tensor. 1312 1313 Args: 1314 input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or 1315 tf.data.Dataset iterator or `None`. 1316 steps: Integer or `None`. Total number of steps (batches of samples) to 1317 execute. 1318 steps_name: The public API's parameter name for `steps`. 1319 1320 Returns: 1321 boolean, True if `steps` argument is required, else False. 1322 1323 Raises: 1324 ValueError: if `steps` argument is required for given input data type 1325 but not provided. 1326 """ 1327 is_x_iterator = isinstance( 1328 input_data, (iterator_ops.Iterator, iterator_ops.IteratorBase)) 1329 if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or 1330 (isinstance(input_data, list) and not input_data)): 1331 if steps is None: 1332 input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors' 1333 raise ValueError('When using {input_type} as input to a model, you should' 1334 ' specify the `{steps_name}` argument.'.format( 1335 input_type=input_type_str, steps_name=steps_name)) 1336 return True 1337 1338 if isinstance(input_data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): 1339 return True 1340 1341 if steps is not None: 1342 list_types = (np.ndarray, list, tuple) 1343 if (isinstance(input_data, list_types) or 1344 (isinstance(input_data, dict) and 1345 any(isinstance(v, list_types) for v in input_data.values()))): 1346 logging.warning('When passing input data as arrays, do not specify ' 1347 '`steps_per_epoch`/`steps` argument. ' 1348 'Please use `batch_size` instead.') 1349 return False 1350 1351 1352def cast_single_tensor(x, dtype=None): 1353 if isinstance(x, np.ndarray): 1354 x = ops.convert_to_tensor_v2_with_dispatch(x) 1355 dtype = dtype or K.floatx() 1356 if x.dtype.is_floating: 1357 return math_ops.cast(x, dtype=dtype) 1358 return x 1359 1360 1361def cast_if_floating_dtype_and_mismatch(targets, outputs): 1362 """Returns target data tensors using correct datatype. 1363 1364 Checks that each target and output pair are the same datatype. If not, casts 1365 the target to the output's datatype. 1366 1367 Args: 1368 targets: tensor or list of targets. 1369 outputs: tensor or list of outputs. 1370 1371 Returns: 1372 Targets in appropriate datatype. 1373 """ 1374 if tensor_util.is_tf_type(targets): 1375 # There is one target, so output[0] should be the only output. 1376 return cast_single_tensor(targets, dtype=outputs[0].dtype) 1377 new_targets = [] 1378 for target, out in zip(targets, outputs): 1379 if isinstance(target, np.ndarray): 1380 target = ops.convert_to_tensor_v2_with_dispatch(target) 1381 if target.dtype != out.dtype: 1382 new_targets.append(cast_single_tensor(target, dtype=out.dtype)) 1383 else: 1384 new_targets.append(target) 1385 return new_targets 1386 1387 1388def cast_if_floating_dtype(x, dtype=None): 1389 """Casts the given data tensors to the default floating point type. 1390 1391 Casts only if the input is already a floating point type. 1392 Args: 1393 x: tensor or list/tuple of tensors. 1394 dtype: The dtype to which Tensors should be cast. 1395 1396 Returns: 1397 Converted input. 1398 """ 1399 return nest.map_structure(functools.partial(cast_single_tensor, dtype=dtype), 1400 x) 1401 1402 1403def cast_to_model_input_dtypes(x, model): 1404 """Casts the given data tensors to the dtypes of the model inputs. 1405 1406 Args: 1407 x: tensor or list/tuple of tensors. 1408 model: The model. 1409 1410 Returns: 1411 Converted input. Each tensor is casted to the corresponding input in 1412 `model.inputs`. 1413 """ 1414 input_dtypes = nest.map_structure(lambda t: t.dtype, model.inputs) 1415 return nest.map_structure(math_ops.cast, x, input_dtypes) 1416 1417 1418def prepare_sample_weight_modes(training_endpoints, sample_weight_mode): 1419 """Prepares sample weight modes for the model. 1420 1421 Args: 1422 training_endpoints: List of model _TrainingEndpoints. 1423 sample_weight_mode: sample weight mode user input passed from compile API. 1424 1425 Raises: 1426 ValueError: In case of invalid `sample_weight_mode` input. 1427 """ 1428 1429 if isinstance(sample_weight_mode, collections.abc.Mapping): 1430 generic_utils.check_for_unexpected_keys( 1431 'sample_weight_mode', sample_weight_mode, 1432 [e.output_name for e in training_endpoints]) 1433 1434 for end_point in training_endpoints: 1435 if not end_point.should_skip_target_weights(): 1436 if end_point.output_name not in sample_weight_mode: 1437 raise ValueError('Output ' + end_point.output_name + 1438 'missing from `_sample_weight_modes` dictionary') 1439 else: 1440 end_point.sample_weight_mode = sample_weight_mode.get( 1441 end_point.output_name) 1442 elif isinstance(sample_weight_mode, (list, tuple)): 1443 if len(sample_weight_mode) != len(training_endpoints): 1444 raise ValueError('When passing a list as sample_weight_mode, ' 1445 'it should have one entry per model output. ' 1446 'The model has ' + str(len(training_endpoints)) + 1447 ' outputs, but you passed ' + 1448 str(len(sample_weight_mode)) + '_sample_weight_modes.') 1449 for mode, endpoint in zip(sample_weight_mode, training_endpoints): 1450 if not endpoint.should_skip_target_weights(): 1451 endpoint.sample_weight_mode = mode 1452 else: 1453 for endpoint in training_endpoints: 1454 if not endpoint.should_skip_target_weights(): 1455 endpoint.sample_weight_mode = sample_weight_mode 1456 1457 1458def prepare_loss_functions(loss, output_names): 1459 """Converts loss to a list of loss functions. 1460 1461 Args: 1462 loss: String (name of objective function), objective function or 1463 `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple 1464 outputs, you can use a different loss on each output by passing a 1465 dictionary or a list of losses. The loss value that will be minimized by 1466 the model will then be the sum of all individual losses. 1467 output_names: List of model output names. 1468 1469 Returns: 1470 A list of loss objective functions. 1471 1472 Raises: 1473 ValueError: If loss is a dict with keys not in model output names, 1474 or if loss is a list with len not equal to model outputs. 1475 """ 1476 if isinstance(loss, collections.abc.Mapping): 1477 generic_utils.check_for_unexpected_keys('loss', loss, output_names) 1478 loss_functions = [] 1479 for name in output_names: 1480 if name not in loss: 1481 logging.warning( 1482 'Output {0} missing from loss dictionary. We assume ' 1483 'this was done on purpose. The fit and evaluate APIs will not be ' 1484 'expecting any data to be passed to {0}.'.format(name)) 1485 loss_functions.append(get_loss_function(loss.get(name, None))) 1486 elif isinstance(loss, six.string_types): 1487 loss_functions = [get_loss_function(loss) for _ in output_names] 1488 elif isinstance(loss, collections.abc.Sequence): 1489 if len(loss) != len(output_names): 1490 raise ValueError('When passing a list as loss, it should have one entry ' 1491 'per model outputs. The model has {} outputs, but you ' 1492 'passed loss={}'.format(len(output_names), loss)) 1493 loss_functions = nest.map_structure(get_loss_function, loss) 1494 else: 1495 loss_functions = [get_loss_function(loss) for _ in range(len(output_names))] 1496 1497 return loss_functions 1498 1499 1500def prepare_loss_weights(training_endpoints, loss_weights=None): 1501 """Converts loss weights to a list of loss weights. 1502 1503 The result loss weights will be populated on the training endpoint. 1504 1505 Args: 1506 training_endpoints: List of model training endpoints. 1507 loss_weights: Optional list or dictionary specifying scalar coefficients 1508 (Python floats) to weight the loss contributions of different model 1509 outputs. The loss value that will be minimized by the model will then be 1510 the *weighted sum* of all individual losses, weighted by the 1511 `loss_weights` coefficients. If a list, it is expected to have a 1:1 1512 mapping to the model's outputs. If a dict, it is expected to map 1513 output names (strings) to scalar coefficients. 1514 1515 Raises: 1516 ValueError: If loss weight is a dict with key not in model output names, 1517 or if loss is a list with len not equal to model outputs. 1518 """ 1519 if loss_weights is None: 1520 for e in training_endpoints: 1521 e.loss_weight = 1. 1522 elif isinstance(loss_weights, collections.abc.Mapping): 1523 generic_utils.check_for_unexpected_keys( 1524 'loss_weights', loss_weights, 1525 [e.output_name for e in training_endpoints]) 1526 for e in training_endpoints: 1527 e.loss_weight = loss_weights.get(e.output_name, 1.) 1528 elif isinstance(loss_weights, list): 1529 if len(loss_weights) != len(training_endpoints): 1530 raise ValueError('When passing a list as loss_weights, ' 1531 'it should have one entry per model output. ' 1532 'The model has ' + str(len(training_endpoints)) + 1533 ' outputs, but you passed loss_weights=' + 1534 str(loss_weights)) 1535 for w, e in zip(loss_weights, training_endpoints): 1536 e.loss_weight = w 1537 else: 1538 raise TypeError('Could not interpret loss_weights argument: ' + 1539 str(loss_weights) + ' - expected a list of dicts.') 1540 1541 1542# TODO(rohanj): This is a hack to get around not depending on feature_column and 1543# create a cyclical dependency. Figure out a cleaner solution 1544def is_feature_layer(layer): 1545 """Returns whether `layer` is a FeatureLayer or not.""" 1546 return getattr(layer, '_is_feature_layer', False) 1547 1548 1549def is_eager_dataset_or_iterator(data): 1550 return context.executing_eagerly() and isinstance( 1551 data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, 1552 iterator_ops.IteratorBase)) 1553 1554 1555# pylint: disable=protected-access 1556def get_dataset_graph_def(dataset): 1557 if context.executing_eagerly(): 1558 graph_def_str = dataset._as_serialized_graph().numpy() 1559 else: 1560 graph_def_str = K.get_value(dataset._as_serialized_graph()) 1561 return graph_pb2.GraphDef().FromString(graph_def_str) 1562 1563 1564def verify_dataset_shuffled(x): 1565 """Verifies that the dataset is shuffled. 1566 1567 Args: 1568 x: Dataset passed as an input to the model. 1569 1570 Returns: 1571 boolean, whether the input dataset is shuffled or not. 1572 """ 1573 assert isinstance(x, dataset_ops.DatasetV2) 1574 graph_def = get_dataset_graph_def(x) 1575 for node in graph_def.node: 1576 if node.op.startswith('ShuffleDataset'): 1577 return True 1578 # Also check graph_def.library.function for ds.interleave or ds.flat_map 1579 for function in graph_def.library.function: 1580 for node in function.node_def: 1581 if node.op.startswith('ShuffleDataset'): 1582 return True 1583 logging.warning('Expected a shuffled dataset but input dataset `x` is ' 1584 'not shuffled. Please invoke `shuffle()` on input dataset.') 1585 return False 1586 1587 1588def is_dataset_or_iterator(data): 1589 return isinstance(data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, 1590 iterator_ops.Iterator, iterator_ops.IteratorBase)) 1591 1592 1593def get_iterator(dataset): 1594 """Create and initialize an iterator from a dataset.""" 1595 if context.executing_eagerly(): 1596 iterator = dataset_ops.make_one_shot_iterator(dataset) 1597 else: 1598 iterator = dataset_ops.make_initializable_iterator(dataset) 1599 initialize_iterator(iterator) 1600 return iterator 1601 1602 1603def initialize_iterator(iterator): 1604 if not context.executing_eagerly(): 1605 init_op = iterator.initializer 1606 K.get_session((init_op,)).run(init_op) 1607 1608 1609def extract_tensors_from_dataset(dataset): 1610 """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset. 1611 1612 Args: 1613 dataset: Dataset instance. 1614 1615 Returns: 1616 Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. 1617 """ 1618 iterator = get_iterator(dataset) 1619 inputs, targets, sample_weight = unpack_iterator_input(iterator) 1620 return inputs, targets, sample_weight 1621 1622 1623def unpack_iterator_input(iterator): 1624 """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`. 1625 1626 Args: 1627 iterator: Instance of a dataset iterator. 1628 1629 Returns: 1630 Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. 1631 """ 1632 try: 1633 next_element = iterator.get_next() 1634 except errors.OutOfRangeError: 1635 raise RuntimeError('Your dataset iterator ran out of data; ' 1636 'Make sure that your dataset can generate ' 1637 'required number of samples.') 1638 1639 if isinstance(next_element, (list, tuple)): 1640 if len(next_element) not in [2, 3]: 1641 raise ValueError( 1642 'Please provide model inputs as a list or tuple of 2 or 3 ' 1643 'elements: (input, target) or (input, target, sample_weights) ' 1644 'Received %s' % next_element) 1645 if len(next_element) == 2: 1646 x, y = next_element 1647 weights = None 1648 else: 1649 x, y, weights = next_element 1650 else: 1651 x = next_element 1652 y = None 1653 weights = None 1654 return x, y, weights 1655 1656 1657def infer_steps_for_dataset(model, 1658 dataset, 1659 steps, 1660 epochs=1, 1661 steps_name='steps'): 1662 """Infers steps_per_epoch needed to loop through a dataset. 1663 1664 Args: 1665 model: Keras model instance. 1666 dataset: Input data of type tf.data.Dataset. 1667 steps: Number of steps to draw from the dataset (may be None if unknown). 1668 epochs: Number of times to iterate over the dataset. 1669 steps_name: The string name of the steps argument, either `steps`, 1670 `validation_steps`, or `steps_per_epoch`. Only used for error message 1671 formatting. 1672 1673 Returns: 1674 Integer or `None`. Inferred number of steps to loop through the dataset. 1675 `None` is returned if 1) the size of the dataset is unknown and `steps` was 1676 not specified, or 2) this is multi-worker training and auto sharding is 1677 enabled. 1678 1679 Raises: 1680 ValueError: In case of invalid argument values. 1681 """ 1682 assert isinstance(dataset, dataset_ops.DatasetV2) 1683 if (model._in_multi_worker_mode() and 1684 (dataset.options().experimental_distribute.auto_shard_policy != 1685 distribute_options.AutoShardPolicy.OFF)): 1686 # If the dataset would be auto-sharded, we should not infer a local 1687 # steps_per_epoch due to the possible inbalanced sharding between workers. 1688 return None 1689 1690 size = K.get_value(cardinality.cardinality(dataset)) 1691 if size == cardinality.INFINITE and steps is None: 1692 raise ValueError('When passing an infinitely repeating dataset, you ' 1693 'must specify the `%s` argument.' % (steps_name,)) 1694 if size >= 0: 1695 if steps is not None and steps * epochs > size: 1696 if epochs > 1: 1697 raise ValueError('The dataset you passed contains %s batches, but you ' 1698 'passed `epochs=%s` and `%s=%s`, which is a total of ' 1699 '%s steps. We cannot draw that many steps from this ' 1700 'dataset. We suggest to set `%s=%s`.' % 1701 (size, epochs, steps_name, steps, steps * epochs, 1702 steps_name, size // epochs)) 1703 else: 1704 raise ValueError('The dataset you passed contains %s batches, but you ' 1705 'passed `%s=%s`. We cannot draw that many steps from ' 1706 'this dataset. We suggest to set `%s=%s`.' % 1707 (size, steps_name, steps, steps_name, size)) 1708 if steps is None: 1709 if size >= 0: 1710 return size 1711 return None 1712 return steps 1713 1714 1715class ModelInputs(object): 1716 """Encapsulates model inputs. 1717 1718 Allows for transforming model inputs while keeping the same structure. 1719 """ 1720 1721 def __init__(self, inputs): 1722 self._inputs = inputs 1723 self._is_dict = isinstance(self._inputs, dict) 1724 self._is_single_input = not isinstance(self._inputs, (list, tuple, dict)) 1725 1726 self._flattened_inputs = [] 1727 self._input_names = [] 1728 1729 if self._is_dict: 1730 for k in sorted(self._inputs.keys()): 1731 self._flattened_inputs.append(self._inputs[k]) 1732 self._input_names.append(k) 1733 else: 1734 self._flattened_inputs = nest.flatten(self._inputs) 1735 self._input_names = [ 1736 'input_%d' % (i + 1) for i in range(len(self._flattened_inputs)) 1737 ] 1738 1739 def get_input_names(self): 1740 """Returns keys to name inputs by. 1741 1742 In case inputs provided were a list, tuple or single entry, we make up a 1743 key 'input_%d'. For dictionary case, we return a sorted list of keys. 1744 """ 1745 return self._input_names 1746 1747 def get_symbolic_inputs(self, return_single_as_list=False): 1748 """Returns inputs to be set as self.inputs for a model.""" 1749 # TODO(karmel): There is a side-effect here where what you get 1750 # with as_list and as_dict depends on whether you have called this 1751 # method first, since it modifies in place. 1752 for i, (k, v) in enumerate(zip(self._input_names, self._flattened_inputs)): 1753 if isinstance(v, (list, float, int)): 1754 v = np.asarray(v) 1755 if v.ndim == 1: 1756 v = np.expand_dims(v, 1) 1757 1758 if isinstance(v, np.ndarray): 1759 # We fix the placeholder shape except the batch size. 1760 # This is suboptimal, but it is the best we can do with the info 1761 # we have. The user should call `model._set_inputs(placeholders)` 1762 # to specify custom placeholders if the need arises. 1763 shape = (None,) + tuple(v.shape[1:]) 1764 if shape == (None,): 1765 shape = (None, 1) 1766 dtype = dtypes.as_dtype(v.dtype) 1767 if dtype.is_floating: 1768 dtype = K.floatx() 1769 v = K.placeholder(shape=shape, name=k, dtype=dtype) 1770 elif isinstance(v, tensor_spec.TensorSpec): 1771 shape = (None,) + tuple(v.shape.as_list()[1:]) 1772 if shape == (None,): 1773 shape = (None, 1) 1774 v = K.placeholder(shape=shape, name=k, dtype=v.dtype) 1775 1776 self._flattened_inputs[i] = v 1777 1778 if self._is_dict: 1779 return dict(zip(self._input_names, self._flattened_inputs)) 1780 if self._is_single_input and not return_single_as_list: 1781 return self._flattened_inputs[0] 1782 return self._flattened_inputs 1783 1784 def as_dict(self): 1785 """An iterable over a dictionary version of inputs.""" 1786 for k, v in zip(self._input_names, self._flattened_inputs): 1787 yield k, v 1788 1789 def as_list(self): 1790 """Returning the inputs as a list.""" 1791 return self._flattened_inputs 1792 1793 1794# Allow use of methods not exposed to the user. 1795# pylint: disable=protected-access 1796 1797 1798# pylint: enable=protected-access 1799 1800 1801def generic_output_names(outputs_list): 1802 return ['output_%d' % (i + 1) for i in range(len(outputs_list))] 1803 1804 1805def should_run_validation(validation_freq, epoch): 1806 """Checks if validation should be run this epoch. 1807 1808 Args: 1809 validation_freq: Integer or list. If an integer, specifies how many training 1810 epochs to run before a new validation run is performed. If a list, 1811 specifies the epochs on which to run validation. 1812 epoch: Integer, the number of the training epoch just completed. 1813 1814 Returns: 1815 Bool, True if validation should be run. 1816 1817 Raises: 1818 ValueError: if `validation_freq` is an Integer and less than 1, or if 1819 it is neither an Integer nor a Sequence. 1820 """ 1821 # `epoch` is 0-indexed internally but 1-indexed in the public API. 1822 one_indexed_epoch = epoch + 1 1823 1824 if isinstance(validation_freq, int): 1825 if validation_freq < 1: 1826 raise ValueError('`validation_freq` can not be less than 1.') 1827 return one_indexed_epoch % validation_freq == 0 1828 1829 if not isinstance(validation_freq, collections.abc.Container): 1830 raise ValueError('`validation_freq` must be an Integer or ' 1831 '`collections.abc.Container` (e.g. list, tuple, etc.)') 1832 return one_indexed_epoch in validation_freq 1833 1834 1835def split_training_and_validation_data(x, y, sample_weights, validation_split): 1836 """Split input data into train/eval section based on validation_split.""" 1837 if has_symbolic_tensors(x): 1838 raise ValueError('If your data is in the form of symbolic tensors, ' 1839 'you cannot use `validation_split`.') 1840 if hasattr(x[0], 'shape'): 1841 split_at = int(x[0].shape[0] * (1. - validation_split)) 1842 else: 1843 split_at = int(len(x[0]) * (1. - validation_split)) 1844 x, val_x = (generic_utils.slice_arrays(x, 0, split_at), 1845 generic_utils.slice_arrays(x, split_at)) 1846 y, val_y = (generic_utils.slice_arrays(y, 0, split_at), 1847 generic_utils.slice_arrays(y, split_at)) 1848 if sample_weights: 1849 sample_weights, val_sample_weights = ( 1850 generic_utils.slice_arrays(sample_weights, 0, split_at), 1851 generic_utils.slice_arrays(sample_weights, split_at), 1852 ) 1853 else: 1854 val_sample_weights = None 1855 return x, y, sample_weights, val_x, val_y, val_sample_weights 1856 1857 1858def unpack_validation_data(validation_data, raise_if_ambiguous=True): 1859 """Unpack validation data based input type. 1860 1861 The validation data is not touched if its dataset or dataset iterator. 1862 For other type of input (Numpy or tensor), it will be unpacked into tuple of 1863 3 which is x, y and sample weights. 1864 1865 Args: 1866 validation_data: dataset, dataset iterator, or numpy, tensor tuple. 1867 raise_if_ambiguous: boolean on whether to fail if validation_data cannot be 1868 parsed. Otherwise simply return validation_data, None, None and defer the 1869 decision to the caller. 1870 1871 Returns: 1872 tuple of 3, (x, y, sample_weights) for numpy and tensor input. 1873 """ 1874 if (isinstance(validation_data, (iterator_ops.Iterator, 1875 iterator_ops.IteratorBase, 1876 dataset_ops.DatasetV2, 1877 data_utils.Sequence)) 1878 or not hasattr(validation_data, '__len__')): 1879 val_x = validation_data 1880 val_y = None 1881 val_sample_weight = None 1882 elif len(validation_data) == 2: 1883 try: 1884 val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence 1885 val_sample_weight = None 1886 except ValueError: 1887 val_x, val_y, val_sample_weight = validation_data, None, None 1888 elif len(validation_data) == 3: 1889 try: 1890 val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence 1891 except ValueError: 1892 val_x, val_y, val_sample_weight = validation_data, None, None 1893 else: 1894 if raise_if_ambiguous: 1895 raise ValueError( 1896 'When passing a `validation_data` argument, ' 1897 'it must contain either 2 items (x_val, y_val), ' 1898 'or 3 items (x_val, y_val, val_sample_weights), ' 1899 'or alternatively it could be a dataset or a ' 1900 'dataset or a dataset iterator. ' 1901 'However we received `validation_data=%s`' % validation_data) 1902 val_x, val_y, val_sample_weight = validation_data, None, None 1903 return val_x, val_y, val_sample_weight 1904 1905 1906class TrainingLoop(object): 1907 """TrainingLoop is a wrapper class around the training logic. 1908 1909 This class is trying to encapsulate the different logic of fit/eval/predict 1910 with regard to different data input and model condition. 1911 1912 Note that TrainingLoop is stateless, which means it doesn't contain any 1913 internal field and can be reused with different model and inputs. 1914 """ 1915 1916 def fit(self, 1917 model, 1918 x=None, 1919 y=None, 1920 batch_size=None, 1921 epochs=1, 1922 verbose=1, 1923 callbacks=None, 1924 validation_split=0., 1925 validation_data=None, 1926 shuffle=True, 1927 class_weight=None, 1928 sample_weight=None, 1929 initial_epoch=0, 1930 steps_per_epoch=None, 1931 validation_steps=None, 1932 validation_freq=1, 1933 **kwargs): 1934 """Train the model with the inputs and targets.""" 1935 raise NotImplementedError() 1936 1937 def evaluate(self, 1938 model, 1939 x=None, 1940 y=None, 1941 batch_size=None, 1942 verbose=1, 1943 sample_weight=None, 1944 steps=None, 1945 callbacks=None, 1946 **kwargs): 1947 """Returns the loss value & metrics values for the model in test mode.""" 1948 raise NotImplementedError() 1949 1950 def predict(self, 1951 model, 1952 x, 1953 batch_size=None, 1954 verbose=0, 1955 steps=None, 1956 callbacks=None, 1957 **kwargs): 1958 raise NotImplementedError() 1959