1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training-related utilities.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import collections 22from collections import OrderedDict 23 24import numpy as np 25import six 26 27from tensorflow.python import tf2 28from tensorflow.python.data.experimental.ops import cardinality 29from tensorflow.python.data.ops import dataset_ops 30from tensorflow.python.data.ops import iterator_ops 31from tensorflow.python.data.ops import readers 32from tensorflow.python.eager import context 33from tensorflow.python.framework import constant_op 34from tensorflow.python.framework import dtypes 35from tensorflow.python.framework import errors 36from tensorflow.python.framework import ops 37from tensorflow.python.framework import tensor_shape 38from tensorflow.python.framework import tensor_util 39from tensorflow.python.keras import backend as K 40from tensorflow.python.keras import callbacks as cbks 41from tensorflow.python.keras import losses 42from tensorflow.python.keras import metrics as metrics_module 43from tensorflow.python.keras.utils import generic_utils 44from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions 45from tensorflow.python.ops import array_ops 46from tensorflow.python.ops import math_ops 47from tensorflow.python.platform import tf_logging as logging 48from tensorflow.python.util import nest 49 50 51@six.add_metaclass(abc.ABCMeta) 52class Aggregator(object): 53 """Abstract base class used to aggregate batch-level outputs of a loop. 54 55 Attributes: 56 use_steps: Whether the loop is using `step` or `batch_size`. 57 num_samples_or_steps: Either `batch_size*num_batches` or `steps`. 58 results: What to return at the end of the aggregation loop. 59 """ 60 61 def __init__(self, use_steps, num_samples_or_steps): 62 self.use_steps = use_steps 63 self.num_samples_or_steps = num_samples_or_steps 64 self.results = [] 65 66 @abc.abstractmethod 67 def create(self, batch_outs): 68 """Creates the initial results from the first batch outputs. 69 70 Arguments: 71 batch_outs: A list of batch-level outputs. 72 """ 73 raise NotImplementedError('Must be implemented in subclasses.') 74 75 @abc.abstractmethod 76 def aggregate(self, batch_outs, batch_start=None, batch_end=None): 77 """Aggregates batch-level results into total results. 78 79 Arguments: 80 batch_outs: A list of batch-level outputs. 81 batch_start: The start index of this batch. Always `None` if `use_steps` 82 is `True`. 83 batch_end: The end index of this batch. Always `None` if `use_steps` is 84 `True`. 85 """ 86 raise NotImplementedError('Must be implemented in subclasses.') 87 88 @abc.abstractmethod 89 def finalize(self): 90 """Prepares the total results to be returned.""" 91 raise NotImplementedError('Must be implemented in subclasses.') 92 93 94class MetricsAggregator(Aggregator): 95 """Aggregator that calculates loss and metrics info.""" 96 97 def create(self, batch_outs): 98 self.results = [0.] * len(batch_outs) 99 100 def aggregate(self, batch_outs, batch_start=None, batch_end=None): 101 # Loss. 102 if self.use_steps: 103 self.results[0] += batch_outs[0] 104 else: 105 self.results[0] += batch_outs[0] * (batch_end - batch_start) 106 # Metrics (always stateful, just grab current values.) 107 self.results[1:] = batch_outs[1:] 108 109 def finalize(self): 110 if not self.results: 111 raise ValueError('Empty training data.') 112 self.results[0] /= self.num_samples_or_steps 113 114 115class OutputsAggregator(Aggregator): 116 """Aggregator that concatenates outputs.""" 117 118 def create(self, batch_outs): 119 if self.use_steps: 120 # Cannot pre-allocate the returned NumPy arrays bc 121 # batch sizes are unknown. Concatenate batches at the end. 122 for _ in batch_outs: 123 self.results.append([]) 124 else: 125 # Pre-allocate NumPy arrays. 126 for batch_out in batch_outs: 127 shape = (self.num_samples_or_steps,) + batch_out.shape[1:] 128 self.results.append(np.zeros(shape, dtype=batch_out.dtype)) 129 130 def aggregate(self, batch_outs, batch_start=None, batch_end=None): 131 if self.use_steps: 132 for i, batch_out in enumerate(batch_outs): 133 self.results[i].append(batch_out) 134 else: 135 for i, batch_out in enumerate(batch_outs): 136 self.results[i][batch_start:batch_end] = batch_out 137 138 def finalize(self): 139 if self.use_steps: 140 self.results = [np.concatenate(result, axis=0) for result in self.results] 141 142 143def get_progbar(model, count_mode): 144 """Get Progbar.""" 145 stateful_metric_names = None 146 if hasattr(model, 'metrics_names'): 147 stateful_metric_names = model.metrics_names[1:] # Exclude `loss` 148 return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names) 149 150 151def slice_arrays(arrays, indices, contiguous=True): 152 """Slices batches out of provided arrays (workaround for eager tensors). 153 154 Unfortunately eager tensors don't have the same slicing behavior as 155 Numpy arrays (they follow the same slicing behavior as symbolic TF tensors), 156 hence we cannot use `generic_utils.slice_arrays` directly 157 and we have to implement this workaround based on `concat`. This has a 158 performance cost. 159 160 Arguments: 161 arrays: Single array or list of arrays. 162 indices: List of indices in the array that should be included in the output 163 batch. 164 contiguous: Boolean flag indicating whether the indices are contiguous. 165 166 Returns: 167 Slice of data (either single array or list of arrays). 168 """ 169 converted_to_list = False 170 if not isinstance(arrays, list): 171 converted_to_list = True 172 arrays = [arrays] 173 if any(tensor_util.is_tensor(x) for x in arrays): 174 if not contiguous: 175 entries = [[x[i:i + 1] for i in indices] for x in arrays] 176 slices = [array_ops.concat(x, axis=0) for x in entries] 177 else: 178 slices = [x[indices[0]:indices[-1] + 1] for x in arrays] 179 else: 180 slices = generic_utils.slice_arrays(arrays, indices) 181 182 if converted_to_list: 183 slices = slices[0] 184 return slices 185 186 187def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'): 188 """Determine the number of samples provided for training and evaluation. 189 190 The number of samples is not defined when running with `steps`, 191 in which case the number of samples is set to `None`. 192 193 Arguments: 194 ins: List of tensors to be fed to the Keras function. 195 batch_size: Integer batch size or `None` if not defined. 196 steps: Total number of steps (batches of samples) before declaring 197 `_predict_loop` finished. Ignored with the default value of `None`. 198 steps_name: The public API's parameter name for `steps`. 199 200 Raises: 201 ValueError: when `steps` is `None` and the attribute `ins.shape` 202 does not exist. Also raises ValueError when `steps` is not `None` 203 and `batch_size` is not `None` because they are mutually 204 exclusive. 205 206 Returns: 207 When steps is `None`, returns the number of samples to be 208 processed based on the size of the first dimension of the 209 first input numpy array. When steps is not `None` and 210 `batch_size` is `None`, returns `None`. 211 """ 212 if steps is not None and batch_size is not None: 213 raise ValueError('If ' + steps_name + 214 ' is set, the `batch_size` must be None.') 215 if check_steps_argument(ins, steps, steps_name): 216 return None 217 if hasattr(ins[0], 'shape'): 218 return int(ins[0].shape[0]) 219 return None # Edge case where ins == [static_learning_phase] 220 221 222def standardize_single_array(x, expected_shape=None): 223 """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1.""" 224 if x is None: 225 return None 226 227 if (x.shape is not None and len(x.shape) == 1 and 228 (expected_shape is None or len(expected_shape) != 1)): 229 if tensor_util.is_tensor(x): 230 x = array_ops.expand_dims(x, axis=1) 231 else: 232 x = np.expand_dims(x, 1) 233 return x 234 235 236def standardize_input_data(data, 237 names, 238 shapes=None, 239 check_batch_axis=True, 240 exception_prefix=''): 241 """Normalizes inputs and targets provided by users. 242 243 Users may pass data as a list of arrays, dictionary of arrays, 244 or as a single array. We normalize this to an ordered list of 245 arrays (same order as `names`), while checking that the provided 246 arrays have shapes that match the network's expectations. 247 248 Arguments: 249 data: User-provided input data (polymorphic). 250 names: List of expected array names. 251 shapes: Optional list of expected array shapes. 252 check_batch_axis: Boolean; whether to check that the batch axis of the 253 arrays matches the expected value found in `shapes`. 254 exception_prefix: String prefix used for exception formatting. 255 256 Returns: 257 List of standardized input arrays (one array per model input). 258 259 Raises: 260 ValueError: in case of improperly formatted user-provided data. 261 """ 262 if not names: 263 if (data is not None and hasattr(data, '__len__') and len(data) and 264 not isinstance(data, dict)): 265 raise ValueError( 266 'Error when checking model ' + exception_prefix + ': ' 267 'expected no data, but got:', data) 268 return [] 269 if data is None: 270 return [None for _ in range(len(names))] 271 272 if isinstance(data, dict): 273 try: 274 data = [ 275 data[x].values 276 if data[x].__class__.__name__ == 'DataFrame' else data[x] 277 for x in names 278 ] 279 except KeyError as e: 280 raise ValueError('No data provided for "' + e.args[0] + '". Need data ' 281 'for each key in: ' + str(names)) 282 elif isinstance(data, (list, tuple)): 283 if isinstance(data[0], (list, tuple)): 284 data = [np.asarray(d) for d in data] 285 elif len(names) == 1 and isinstance(data[0], (float, int)): 286 data = [np.asarray(data)] 287 else: 288 data = [ 289 x.values if x.__class__.__name__ == 'DataFrame' else x for x in data 290 ] 291 else: 292 data = data.values if data.__class__.__name__ == 'DataFrame' else data 293 data = [data] 294 if shapes is not None: 295 data = [ 296 standardize_single_array(x, shape) for (x, shape) in zip(data, shapes) 297 ] 298 else: 299 data = [standardize_single_array(x) for x in data] 300 301 if len(data) != len(names): 302 if data and hasattr(data[0], 'shape'): 303 raise ValueError('Error when checking model ' + exception_prefix + 304 ': the list of Numpy arrays that you are passing to ' 305 'your model is not the size the model expected. ' 306 'Expected to see ' + str(len(names)) + ' array(s), ' 307 'but instead got the following list of ' + 308 str(len(data)) + ' arrays: ' + str(data)[:200] + '...') 309 elif len(names) > 1: 310 raise ValueError('Error when checking model ' + exception_prefix + 311 ': you are passing a list as input to your model, ' 312 'but the model expects a list of ' + str(len(names)) + 313 ' Numpy arrays instead. The list you passed was: ' + 314 str(data)[:200]) 315 elif len(data) == 1 and not hasattr(data[0], 'shape'): 316 raise TypeError('Error when checking model ' + exception_prefix + 317 ': data should be a Numpy array, or list/dict of ' 318 'Numpy arrays. Found: ' + str(data)[:200] + '...') 319 elif len(names) == 1: 320 data = [np.asarray(data)] 321 322 # Check shapes compatibility. 323 if shapes: 324 for i in range(len(names)): 325 if shapes[i] is not None: 326 if tensor_util.is_tensor(data[i]): 327 tensorshape = data[i].get_shape() 328 if not tensorshape: 329 continue 330 data_shape = tuple(tensorshape.as_list()) 331 else: 332 data_shape = data[i].shape 333 shape = shapes[i] 334 if len(data_shape) != len(shape): 335 raise ValueError('Error when checking ' + exception_prefix + 336 ': expected ' + names[i] + ' to have ' + 337 str(len(shape)) + ' dimensions, but got array ' 338 'with shape ' + str(data_shape)) 339 if not check_batch_axis: 340 data_shape = data_shape[1:] 341 shape = shape[1:] 342 for dim, ref_dim in zip(data_shape, shape): 343 if ref_dim != dim and ref_dim is not None and dim is not None: 344 raise ValueError('Error when checking ' + exception_prefix + 345 ': expected ' + names[i] + ' to have shape ' + 346 str(shape) + ' but got array with shape ' + 347 str(data_shape)) 348 return data 349 350 351def standardize_sample_or_class_weights(x_weight, output_names, weight_type): 352 """Maps `sample_weight` or `class_weight` to model outputs. 353 354 Arguments: 355 x_weight: User-provided `sample_weight` or `class_weight` argument. 356 output_names: List of output names (strings) in the model. 357 weight_type: A string used purely for exception printing. 358 359 Returns: 360 A list of `sample_weight` or `class_weight` where there are exactly 361 one element per model output. 362 363 Raises: 364 ValueError: In case of invalid user-provided argument. 365 """ 366 if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test 367 return [None for _ in output_names] 368 if len(output_names) == 1: 369 if isinstance(x_weight, list) and len(x_weight) == 1: 370 return x_weight 371 if isinstance(x_weight, dict) and output_names[0] in x_weight: 372 return [x_weight[output_names[0]]] 373 else: 374 return [x_weight] 375 if isinstance(x_weight, list): 376 if len(x_weight) != len(output_names): 377 raise ValueError('Provided `' + weight_type + '` was a list of ' + 378 str(len(x_weight)) + ' elements, but the model has ' + 379 str(len(output_names)) + ' outputs. ' 380 'You should provide one `' + weight_type + '`' 381 'array per model output.') 382 return x_weight 383 if isinstance(x_weight, dict): 384 x_weights = [] 385 for name in output_names: 386 x_weights.append(x_weight.get(name)) 387 return x_weights 388 else: 389 raise TypeError('The model has multiple outputs, so `' + weight_type + '` ' 390 'should be either a list or a dict. ' 391 'Provided `' + weight_type + '` type not understood: ' + 392 str(x_weight)) 393 394 395def standardize_class_weights(class_weight, output_names): 396 return standardize_sample_or_class_weights(class_weight, output_names, 397 'class_weight') 398 399 400def standardize_sample_weights(sample_weight, output_names): 401 return standardize_sample_or_class_weights(sample_weight, output_names, 402 'sample_weight') 403 404 405def check_array_lengths(inputs, targets, weights=None): 406 """Does user input validation for numpy arrays. 407 408 Arguments: 409 inputs: list of Numpy arrays of inputs. 410 targets: list of Numpy arrays of targets. 411 weights: list of Numpy arrays of sample weights. 412 413 Raises: 414 ValueError: in case of incorrectly formatted data. 415 """ 416 417 def set_of_lengths(x): 418 # Returns a set with the variation between 419 # different shapes, with None => 0 420 if x is None: 421 return {} 422 else: 423 return set([ 424 y.shape[0] 425 for y in x 426 if y is not None and not tensor_util.is_tensor(y) 427 ]) 428 429 set_x = set_of_lengths(inputs) 430 set_y = set_of_lengths(targets) 431 set_w = set_of_lengths(weights) 432 if len(set_x) > 1: 433 raise ValueError('All input arrays (x) should have ' 434 'the same number of samples. Got array shapes: ' + 435 str([x.shape for x in inputs])) 436 if len(set_y) > 1: 437 raise ValueError('All target arrays (y) should have ' 438 'the same number of samples. Got array shapes: ' + 439 str([y.shape for y in targets])) 440 if set_x and set_y and list(set_x)[0] != list(set_y)[0]: 441 raise ValueError('Input arrays should have ' 442 'the same number of samples as target arrays. ' 443 'Found ' + str(list(set_x)[0]) + ' input samples ' 444 'and ' + str(list(set_y)[0]) + ' target samples.') 445 if len(set_w) > 1: 446 raise ValueError('All sample_weight arrays should have ' 447 'the same number of samples. Got array shapes: ' + 448 str([w.shape for w in weights])) 449 if set_y and set_w and list(set_y)[0] != list(set_w)[0]: 450 raise ValueError('Sample_weight arrays should have ' 451 'the same number of samples as target arrays. Got ' + 452 str(list(set_y)[0]) + ' input samples and ' + 453 str(list(set_w)[0]) + ' target samples.') 454 455 456def check_loss_and_target_compatibility(targets, loss_fns, output_shapes): 457 """Does validation on the compatibility of targets and loss functions. 458 459 This helps prevent users from using loss functions incorrectly. This check 460 is purely for UX purposes. 461 462 Arguments: 463 targets: list of Numpy arrays of targets. 464 loss_fns: list of loss functions. 465 output_shapes: list of shapes of model outputs. 466 467 Raises: 468 ValueError: if a loss function or target array 469 is incompatible with an output. 470 """ 471 key_loss_fns = { 472 losses.mean_squared_error, losses.binary_crossentropy, 473 losses.categorical_crossentropy 474 } 475 key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy, 476 losses.CategoricalCrossentropy) 477 for y, loss, shape in zip(targets, loss_fns, output_shapes): 478 if y is None or loss is None or tensor_util.is_tensor(y): 479 continue 480 if losses.is_categorical_crossentropy(loss): 481 if y.shape[-1] == 1: 482 raise ValueError('You are passing a target array of shape ' + 483 str(y.shape) + 484 ' while using as loss `categorical_crossentropy`. ' 485 '`categorical_crossentropy` expects ' 486 'targets to be binary matrices (1s and 0s) ' 487 'of shape (samples, classes). ' 488 'If your targets are integer classes, ' 489 'you can convert them to the expected format via:\n' 490 '```\n' 491 'from keras.utils import to_categorical\n' 492 'y_binary = to_categorical(y_int)\n' 493 '```\n' 494 '\n' 495 'Alternatively, you can use the loss function ' 496 '`sparse_categorical_crossentropy` instead, ' 497 'which does expect integer targets.') 498 499 is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper) 500 if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and 501 (loss.fn in key_loss_fns))): 502 for target_dim, out_dim in zip(y.shape[1:], shape[1:]): 503 if out_dim is not None and target_dim != out_dim: 504 loss_name = loss.name 505 if loss_name is None: 506 loss_type = loss.fn if is_loss_wrapper else type(loss) 507 loss_name = loss_type.__name__ 508 raise ValueError('A target array with shape ' + str(y.shape) + 509 ' was passed for an output of shape ' + str(shape) + 510 ' while using as loss `' + loss_name + '`. ' 511 'This loss expects targets to have the same shape ' 512 'as the output.') 513 514 515def collect_per_output_metric_info(metrics, 516 output_names, 517 output_shapes, 518 loss_fns, 519 is_weighted=False): 520 """Maps metric names and functions to model outputs. 521 522 Arguments: 523 metrics: a list or a list of lists or a dict of metric functions. 524 output_names: a list of the names (strings) of model outputs. 525 output_shapes: a list of the shapes (strings) of model outputs. 526 loss_fns: a list of the loss functions corresponding to the model outputs. 527 is_weighted: Boolean indicating whether the given metrics are weighted. 528 529 Returns: 530 A list (one entry per model output) of dicts. 531 For instance, if the model has 2 outputs, and for the first output 532 we want to compute "binary_accuracy" and "binary_crossentropy", 533 and just "binary_accuracy" for the second output, 534 the list would look like: `[{ 535 'acc': binary_accuracy(), 536 'ce': binary_crossentropy(), 537 }, { 538 'acc': binary_accuracy(), 539 }]` 540 541 Raises: 542 TypeError: if an incorrect type is passed for the `metrics` argument. 543 """ 544 if not metrics: 545 return [{} for _ in output_names] 546 547 if isinstance(metrics, list): 548 any_sub_list = any(isinstance(m, list) for m in metrics) 549 if any_sub_list: 550 if len(metrics) != len(output_names): 551 raise ValueError('When passing a list of lists as `metrics`, ' 552 'it should have one entry per model output. ' 553 'The model has ' + str(len(output_names)) + 554 ' outputs, but you passed metrics=' + str(metrics)) 555 # User has provided a list of len = len(outputs). 556 nested_metrics = [generic_utils.to_list(m) for m in metrics] 557 else: 558 # If it is a single list we then apply all metrics to all outputs. 559 if len(output_names) > 1: 560 nested_metrics = [] 561 for _ in output_names: 562 nested_metrics.append( 563 [metrics_module.clone_metric(m) for m in metrics]) 564 else: 565 nested_metrics = [metrics] 566 elif isinstance(metrics, dict): 567 nested_metrics = [] 568 for name in output_names: 569 output_metrics = generic_utils.to_list(metrics.get(name, [])) 570 nested_metrics.append(output_metrics) 571 else: 572 raise TypeError('Type of `metrics` argument not understood. ' 573 'Expected a list or dictionary, found: ' + str(metrics)) 574 575 per_output_metrics = [] 576 for i, metrics in enumerate(nested_metrics): 577 metrics_dict = OrderedDict() 578 for metric in metrics: 579 metric_name = get_metric_name(metric, is_weighted) 580 metric_fn = get_metric_function( 581 metric, output_shape=output_shapes[i], loss_fn=loss_fns[i]) 582 583 # If the metric function is not stateful, we create a stateful version. 584 if not isinstance(metric_fn, metrics_module.Metric): 585 metric_fn = metrics_module.MeanMetricWrapper( 586 metric_fn, name=metric_name) 587 metrics_dict[metric_name] = metric_fn 588 per_output_metrics.append(metrics_dict) 589 590 return per_output_metrics 591 592 593def batch_shuffle(index_array, batch_size): 594 """Shuffles an array in a batch-wise fashion. 595 596 Useful for shuffling HDF5 arrays 597 (where one cannot access arbitrary indices). 598 599 Arguments: 600 index_array: array of indices to be shuffled. 601 batch_size: integer. 602 603 Returns: 604 The `index_array` array, shuffled in a batch-wise fashion. 605 """ 606 batch_count = int(len(index_array) / batch_size) 607 # to reshape we need to be cleanly divisible by batch size 608 # we stash extra items and reappend them after shuffling 609 last_batch = index_array[batch_count * batch_size:] 610 index_array = index_array[:batch_count * batch_size] 611 index_array = index_array.reshape((batch_count, batch_size)) 612 np.random.shuffle(index_array) 613 index_array = index_array.flatten() 614 return np.append(index_array, last_batch) 615 616 617def standardize_weights(y, 618 sample_weight=None, 619 class_weight=None, 620 sample_weight_mode=None): 621 """Performs sample weight validation and standardization. 622 623 Everything gets normalized to a single sample-wise (or timestep-wise) 624 weight array. If both `sample_weight` and `class_weight` are provided, 625 the weights are multiplied. 626 627 Arguments: 628 y: Numpy array of model targets to be weighted. 629 sample_weight: User-provided `sample_weight` argument. 630 class_weight: User-provided `class_weight` argument. 631 sample_weight_mode: One of `None` or `"temporal"`. `"temporal"` indicated 632 that we expect 2D weight data that will be applied to the last 2 633 dimensions of the targets (i.e. we are weighting timesteps, not 634 samples). 635 636 Returns: 637 A numpy array of target weights, one entry per sample to weight. 638 639 Raises: 640 ValueError: In case of invalid user-provided arguments. 641 """ 642 # Iterator may return sample_weight as 1-tuple 643 if isinstance(sample_weight, tuple): 644 sample_weight = sample_weight[0] 645 if sample_weight_mode is not None: 646 if sample_weight_mode != 'temporal': 647 raise ValueError('"sample_weight_mode ' 648 'should be None or "temporal". ' 649 'Found: ' + str(sample_weight_mode)) 650 if len(y.shape) < 3: 651 raise ValueError('Found a sample_weight array for ' 652 'an input with shape ' + str(y.shape) + '. ' 653 'Timestep-wise sample weighting (use of ' 654 'sample_weight_mode="temporal") is restricted to ' 655 'outputs that are at least 3D, i.e. that have ' 656 'a time dimension.') 657 if sample_weight is not None and len(sample_weight.shape) != 2: 658 raise ValueError('Found a sample_weight array with shape ' + 659 str(sample_weight.shape) + '. ' 660 'In order to use timestep-wise sample weighting, ' 661 'you should pass a 2D sample_weight array.') 662 else: 663 if sample_weight is not None and len(sample_weight.shape) != 1: 664 raise ValueError('Found a sample_weight array with shape ' + 665 str(sample_weight.shape) + '. ' 666 'In order to use timestep-wise sample weights, ' 667 'you should specify ' 668 'sample_weight_mode="temporal" ' 669 'in compile(). If you just mean to use ' 670 'sample-wise weights, make sure your ' 671 'sample_weight array is 1D.') 672 673 if sample_weight is not None: 674 if len(sample_weight.shape) > len(y.shape): 675 raise ValueError('Found a sample_weight with shape' + 676 str(sample_weight.shape) + '.' 677 'Expected sample_weight with rank ' 678 'less than or equal to ' + str(len(y.shape))) 679 680 if (not tensor_util.is_tensor(sample_weight) and 681 y.shape[:sample_weight.ndim] != sample_weight.shape): 682 raise ValueError('Found a sample_weight array with shape ' + 683 str(sample_weight.shape) + ' for an input with shape ' + 684 str(y.shape) + '. ' 685 'sample_weight cannot be broadcast.') 686 687 # Class weights applied per-sample. 688 class_sample_weight = None 689 if isinstance(class_weight, dict): 690 if len(y.shape) > 2: 691 raise ValueError('`class_weight` not supported for ' 692 '3+ dimensional targets.') 693 694 if len(y.shape) == 2: 695 if y.shape[1] > 1: 696 y_classes = np.argmax(y, axis=1) 697 elif y.shape[1] == 1: 698 y_classes = np.reshape(y, y.shape[0]) 699 else: 700 y_classes = y 701 702 class_sample_weight = np.asarray( 703 [class_weight[cls] for cls in y_classes if cls in class_weight]) 704 705 if len(class_sample_weight) != len(y_classes): 706 # subtract the sets to pick all missing classes 707 existing_classes = set(y_classes) 708 existing_class_weight = set(class_weight.keys()) 709 raise ValueError( 710 '`class_weight` must contain all classes in the data.' 711 ' The classes %s exist in the data but not in ' 712 '`class_weight`.' % (existing_classes - existing_class_weight)) 713 714 if class_sample_weight is not None and sample_weight is not None: 715 # Multiply weights if both are provided. 716 return class_sample_weight * sample_weight 717 if sample_weight is not None: 718 return sample_weight 719 if class_sample_weight is not None: 720 return class_sample_weight 721 return None 722 723 724def has_symbolic_tensors(ls): 725 if context.executing_eagerly(): 726 return False 727 return has_tensors(ls) 728 729 730def has_tensors(ls): 731 if isinstance(ls, (list, tuple)): 732 return any(tensor_util.is_tensor(v) for v in ls) 733 if isinstance(ls, dict): 734 return any(tensor_util.is_tensor(v) for _, v in six.iteritems(ls)) 735 return tensor_util.is_tensor(ls) 736 737 738def get_metric_name(metric, weighted=False): 739 """Returns the name corresponding to the given metric input. 740 741 Arguments: 742 metric: Metric function name or reference. 743 weighted: Boolean indicating if the given metric is weighted. 744 745 Returns: 746 The metric name. 747 """ 748 if tf2.enabled(): 749 # We keep the string that the user has set in compile as the metric name. 750 if isinstance(metric, six.string_types): 751 return metric 752 753 metric = metrics_module.get(metric) 754 return metric.name if hasattr(metric, 'name') else metric.__name__ 755 else: 756 metric_name_prefix = 'weighted_' if weighted else '' 757 if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): 758 if metric in ('accuracy', 'acc'): 759 suffix = 'acc' 760 elif metric in ('crossentropy', 'ce'): 761 suffix = 'ce' 762 else: 763 metric_fn = metrics_module.get(metric) 764 # Get metric name as string 765 if hasattr(metric_fn, 'name'): 766 suffix = metric_fn.name 767 else: 768 suffix = metric_fn.__name__ 769 metric_name = metric_name_prefix + suffix 770 return metric_name 771 772 773def get_metric_function(metric, output_shape=None, loss_fn=None): 774 """Returns the metric function corresponding to the given metric input. 775 776 Arguments: 777 metric: Metric function name or reference. 778 output_shape: The shape of the output that this metric will be calculated 779 for. 780 loss_fn: The loss function used. 781 782 Returns: 783 The metric function. 784 """ 785 if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']: 786 return metrics_module.get(metric) 787 788 is_sparse_categorical_crossentropy = ( 789 isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or 790 (isinstance(loss_fn, losses.LossFunctionWrapper) and 791 loss_fn.fn == losses.sparse_categorical_crossentropy)) 792 793 is_binary_crossentropy = ( 794 isinstance(loss_fn, losses.BinaryCrossentropy) or 795 (isinstance(loss_fn, losses.LossFunctionWrapper) and 796 loss_fn.fn == losses.binary_crossentropy)) 797 798 if metric in ['accuracy', 'acc']: 799 if output_shape[-1] == 1 or is_binary_crossentropy: 800 return metrics_module.binary_accuracy 801 elif is_sparse_categorical_crossentropy: 802 return metrics_module.sparse_categorical_accuracy 803 # If the output_shape[-1] is not 1, then we know output is `categorical`. 804 # We assume it is sparse categorical only if loss is explicitly given 805 # as sparse categorical crossentropy loss. 806 return metrics_module.categorical_accuracy 807 else: 808 if output_shape[-1] == 1 or is_binary_crossentropy: 809 return metrics_module.binary_crossentropy 810 elif is_sparse_categorical_crossentropy: 811 return metrics_module.sparse_categorical_crossentropy 812 return metrics_module.categorical_crossentropy 813 814 815def call_metric_function(metric_fn, y_true, y_pred, weights=None, mask=None): 816 """Invokes metric function and returns the metric result tensor.""" 817 if mask is None: 818 return metric_fn(y_true, y_pred, sample_weight=weights) 819 820 mask = math_ops.cast(mask, y_pred.dtype) 821 if weights is None: 822 # Use mask as sample weight. 823 return metric_fn(y_true, y_pred, sample_weight=mask) 824 825 # Update dimensions of weights to match with mask. 826 mask, _, weights = squeeze_or_expand_dimensions(mask, None, weights) 827 weights *= mask 828 return metric_fn(y_true, y_pred, sample_weight=weights) 829 830 831def get_loss_function(loss): 832 """Returns the loss function corresponding to the given loss input.""" 833 if loss is None or isinstance(loss, losses.Loss): 834 return loss 835 836 # Deserialize loss configuration, if needed. 837 if isinstance(loss, collections.Mapping): 838 loss = losses.get(loss) 839 840 # Custom callable class. 841 if callable(loss) and not hasattr(loss, '__name__'): 842 return loss 843 844 # Wrap loss function with signature `(y_true, y_pred, **kwargs)` 845 # in `LossFunctionWrapper` class. 846 loss_fn = losses.get(loss) 847 return losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__) 848 849 850def validate_dataset_input(x, y, sample_weight, validation_split=None): 851 """Validates user input arguments when a dataset iterator is passed. 852 853 Arguments: 854 x: Input data. A `tf.data` dataset or iterator. 855 y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s). 856 Expected to be `None` when `x` is a dataset iterator. 857 sample_weight: An optional sample-weight array passed by the user to weight 858 the importance of each sample in `x`. Expected to be `None` when `x` is a 859 dataset iterator 860 validation_split: Float between 0 and 1. Fraction of the training data to be 861 used as validation data. Expected to be `None` when `x` is a dataset 862 iterator. 863 864 Raises: 865 ValueError: if argument `y` or `sample_weight` or `validation_split` are 866 provided by user. 867 """ 868 if y is not None: 869 raise ValueError('You passed a dataset or dataset iterator (%s) as ' 870 'input `x` to your model. In that case, you should ' 871 'not specify a target (`y`) argument, since the dataset ' 872 'or dataset iterator generates both input data and ' 873 'target data. ' 874 'Received: %s' % (x, y)) 875 if sample_weight is not None: 876 raise ValueError('`sample_weight` argument is not supported when input ' 877 '`x` is a dataset or a dataset iterator. Instead, you' 878 'can provide sample_weight as the third element of your' 879 'dataset, i.e. (inputs, targets, sample_weight). ' 880 'Received: x=%s, sample_weight=%s' % (x, sample_weight)) 881 if validation_split is not None and validation_split != 0.0: 882 raise ValueError( 883 '`validation_split` argument is not supported when ' 884 'input `x` is a dataset or a dataset iterator. ' 885 'Received: x=%s, validation_split=%f' % (x, validation_split)) 886 887 888def check_generator_arguments(y=None, sample_weight=None, 889 validation_split=None): 890 """Validates arguments passed when using a generator.""" 891 if y is not None: 892 raise ValueError('`y` argument is not supported when data is' 893 'a generator or Sequence instance. Instead pass targets' 894 ' as the second element of the generator.') 895 if sample_weight is not None: 896 raise ValueError('`sample_weight` argument is not supported when data is' 897 'a generator or Sequence instance. Instead pass sample' 898 ' weights as the third element of the generator.') 899 if validation_split: 900 raise ValueError('If your data is in the form of a Python generator, ' 901 'you cannot use `validation_split`.') 902 903 904def check_steps_argument(input_data, steps, steps_name): 905 """Validates `steps` argument based on input data's type. 906 907 The cases when `steps` value must be provided are when 908 1. input data passed is an iterator. 909 2. model was built on top of symbolic tensors, input data is not 910 required and is `None`. 911 3. input data passed is a symbolic tensor. 912 913 Arguments: 914 input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or 915 tf.data.Dataset iterator or `None`. 916 steps: Integer or `None`. Total number of steps (batches of samples) to 917 execute. 918 steps_name: The public API's parameter name for `steps`. 919 920 Returns: 921 boolean, True if `steps` argument is required, else False. 922 923 Raises: 924 ValueError: if `steps` argument is required for given input data type 925 but not provided. 926 """ 927 # TODO(fchollet): allow datasets with steps=None if cardinality is known. 928 is_x_iterator = isinstance( 929 input_data, (iterator_ops.Iterator, iterator_ops.EagerIterator)) 930 if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or 931 (isinstance(input_data, list) and not input_data)): 932 if steps is None: 933 input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors' 934 raise ValueError('When using {input_type} as input to a model, you should' 935 ' specify the `{steps_name}` argument.'.format( 936 input_type=input_type_str, steps_name=steps_name)) 937 return True 938 return False 939 940 941def cast_single_tensor(x): 942 if tensor_util.is_tensor(x) and x.dtype.is_floating: 943 return math_ops.cast(x, dtype=K.floatx()) 944 return x 945 946 947def cast_if_floating_dtype(x): 948 """Casts the given data tensors to the default floating point type. 949 950 Casts only if the input is already a floating point type. 951 Args: 952 x: tensor or list/tuple of tensors. 953 954 Returns: 955 Converted input. 956 957 Raises: 958 RuntimeError: if data isn't tensors. 959 """ 960 if not has_tensors(x): 961 raise RuntimeError( 962 'Please provide tensors for casting, got: {x}'.format(x=x)) 963 964 return nest.map_structure(cast_single_tensor, x) 965 966 967def get_output_sample_weight_and_mode(skip_target_weighing_indices, 968 sample_weight_mode, output_name, 969 output_index): 970 """Returns the sample weight and weight mode for a single output.""" 971 if output_index in skip_target_weighing_indices: 972 return None, None 973 974 if sample_weight_mode == 'temporal': 975 default_value = [[1.]] 976 shape = [None, None] 977 mode = 'temporal' 978 else: 979 default_value = [1.] 980 shape = [None] 981 mode = None 982 if context.executing_eagerly(): 983 weight = None 984 else: 985 weight = array_ops.placeholder_with_default( 986 constant_op.constant(default_value, dtype=K.floatx()), 987 shape=shape, 988 name=output_name + '_sample_weights') 989 return weight, mode 990 991 992def prepare_sample_weights(output_names, sample_weight_mode, 993 skip_target_weighing_indices): 994 """Prepares sample weights for the model. 995 996 Args: 997 output_names: List of model output names. 998 sample_weight_mode: sample weight mode user input passed from compile API. 999 skip_target_weighing_indices: Indices of output for which sample weights 1000 should be skipped. 1001 1002 Returns: 1003 A pair of list of sample weights and sample weight modes 1004 (one for each output). 1005 1006 Raises: 1007 ValueError: In case of invalid `sample_weight_mode` input. 1008 """ 1009 sample_weights = [] 1010 sample_weight_modes = [] 1011 if isinstance(sample_weight_mode, dict): 1012 unknown_output = set(sample_weight_mode.keys()) - set(output_names) 1013 if unknown_output: 1014 raise ValueError('Unknown entry in ' 1015 'sample_weight_mode dictionary: "' + unknown_output + 1016 '". Only expected the following keys: ' + 1017 str(output_names)) 1018 for i, name in enumerate(output_names): 1019 if (i not in skip_target_weighing_indices and 1020 name not in sample_weight_mode): 1021 raise ValueError('Output missing from sample_weight_modes dictionary') 1022 weight, mode = get_output_sample_weight_and_mode( 1023 skip_target_weighing_indices, sample_weight_mode.get(name), name, i) 1024 sample_weights.append(weight) 1025 sample_weight_modes.append(mode) 1026 elif isinstance(sample_weight_mode, list): 1027 if len(sample_weight_mode) != len(output_names): 1028 raise ValueError('When passing a list as sample_weight_mode, ' 1029 'it should have one entry per model output. ' 1030 'The model has ' + str(len(output_names)) + 1031 ' outputs, but you passed ' + 1032 str(len(sample_weight_mode)) + 'sample_weight_modes') 1033 for i, name in enumerate(output_names): 1034 weight, mode = get_output_sample_weight_and_mode( 1035 skip_target_weighing_indices, sample_weight_mode[i], name, i) 1036 sample_weights.append(weight) 1037 sample_weight_modes.append(mode) 1038 else: 1039 for i, name in enumerate(output_names): 1040 weight, mode = get_output_sample_weight_and_mode( 1041 skip_target_weighing_indices, sample_weight_mode, name, i) 1042 sample_weights.append(weight) 1043 sample_weight_modes.append(mode) 1044 return sample_weights, sample_weight_modes 1045 1046 1047def prepare_loss_functions(loss, output_names): 1048 """Converts loss to a list of loss functions. 1049 1050 Arguments: 1051 loss: String (name of objective function), objective function or 1052 `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple 1053 outputs, you can use a different loss on each output by passing a 1054 dictionary or a list of losses. The loss value that will be minimized by 1055 the model will then be the sum of all individual losses. 1056 output_names: List of model output names. 1057 1058 Returns: 1059 A list of loss objective functions. 1060 1061 Raises: 1062 ValueError: If loss is a dict with keys not in model output names, 1063 or if loss is a list with len not equal to model outputs. 1064 """ 1065 if isinstance(loss, collections.Mapping): 1066 for name in loss: 1067 if name not in output_names: 1068 raise ValueError('Unknown entry in loss dictionary: {}. Only expected ' 1069 'following keys: {}'.format(name, output_names)) 1070 loss_functions = [] 1071 for name in output_names: 1072 if name not in loss: 1073 logging.warning( 1074 'Output {0} missing from loss dictionary. We assume ' 1075 'this was done on purpose. The fit and evaluate APIs will not be ' 1076 'expecting any data to be passed to {0}.'.format(name)) 1077 loss_functions.append(get_loss_function(loss.get(name, None))) 1078 elif isinstance(loss, six.string_types): 1079 loss_functions = [get_loss_function(loss) for _ in output_names] 1080 elif isinstance(loss, collections.Sequence): 1081 if len(loss) != len(output_names): 1082 raise ValueError('When passing a list as loss, it should have one entry ' 1083 'per model outputs. The model has {} outputs, but you ' 1084 'passed loss={}'.format(len(output_names), loss)) 1085 loss_functions = nest.map_structure(get_loss_function, loss) 1086 else: 1087 loss_functions = [get_loss_function(loss) for _ in range(len(output_names))] 1088 1089 return loss_functions 1090 1091 1092def prepare_loss_weights(output_names, loss_weights=None): 1093 """Converts loss weights to a list of loss weights. 1094 1095 Arguments: 1096 output_names: List of model output names. 1097 loss_weights: Optional list or dictionary specifying scalar coefficients 1098 (Python floats) to weight the loss contributions of different model 1099 outputs. The loss value that will be minimized by the model will then be 1100 the *weighted sum* of all individual losses, weighted by the 1101 `loss_weights` coefficients. If a list, it is expected to have a 1:1 1102 mapping to the model's outputs. If a dict, it is expected to map 1103 output names (strings) to scalar coefficients. 1104 1105 Returns: 1106 A list of loss weights of python floats. 1107 1108 Raises: 1109 ValueError: If loss weight is a dict with key not in model output names, 1110 or if loss is a list with len not equal to model outputs. 1111 """ 1112 if loss_weights is None: 1113 weights_list = [1.] * len(output_names) 1114 elif isinstance(loss_weights, dict): 1115 for name in loss_weights: 1116 if name not in output_names: 1117 raise ValueError('Unknown entry in loss_weights dictionary: {}. ' 1118 'Only expected the following keys: {}'.format( 1119 name, output_names)) 1120 weights_list = [loss_weights.get(name, 1.) for name in output_names] 1121 elif isinstance(loss_weights, list): 1122 if len(loss_weights) != len(output_names): 1123 raise ValueError('When passing a list as loss_weights, ' 1124 'it should have one entry per model output. ' 1125 'The model has ' + str(len(output_names)) + 1126 ' outputs, but you passed loss_weights=' + 1127 str(loss_weights)) 1128 weights_list = loss_weights 1129 else: 1130 raise TypeError('Could not interpret loss_weights argument: ' + 1131 str(loss_weights) + ' - expected a list of dicts.') 1132 1133 return weights_list 1134 1135 1136# TODO(rohanj): This is a hack to get around not depending on feature_column and 1137# create a cyclical dependency. Figure out a cleaner solution 1138def is_feature_layer(layer): 1139 """Returns whether `layer` is a FeatureLayer or not.""" 1140 return getattr(layer, '_is_feature_layer', False) 1141 1142 1143def is_eager_dataset_or_iterator(data): 1144 return context.executing_eagerly() and isinstance( 1145 data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, 1146 iterator_ops.EagerIterator)) 1147 1148 1149# pylint: disable=protected-access 1150def assert_not_batched(dataset): 1151 """Asserts that `dataset` is not batched. 1152 1153 The algorithm used by this method is sound but not complete. In other words, 1154 if the method fails to establish the assertion, it does not mean the dataset 1155 is batched. 1156 1157 Example usage: 1158 ```python 1159 try: 1160 assert_not_batched(dataset) 1161 # safe to assume `dataset` it not batched here 1162 expect ValueError: 1163 # make no assumptions about `dataset` 1164 ``` 1165 1166 Args: 1167 dataset: The dataset to analyze. 1168 1169 Raises: 1170 ValueError: If the method cannot establish the assertion. 1171 """ 1172 if isinstance(dataset, dataset_ops.DatasetV1Adapter): 1173 return assert_not_batched(dataset._dataset) 1174 else: 1175 whitelisted_types = [ 1176 dataset_ops._OptionsDataset, 1177 dataset_ops.ConcatenateDataset, 1178 dataset_ops.CacheDataset, 1179 dataset_ops.FilterDataset, 1180 dataset_ops.MapDataset, 1181 dataset_ops.ParallelMapDataset, 1182 dataset_ops.PrefetchDataset, 1183 dataset_ops.RangeDataset, 1184 dataset_ops.RepeatDataset, 1185 dataset_ops.ShuffleDataset, 1186 dataset_ops.SkipDataset, 1187 dataset_ops.SparseTensorSliceDataset, 1188 dataset_ops.TakeDataset, 1189 dataset_ops.TensorDataset, 1190 dataset_ops.TensorSliceDataset, 1191 dataset_ops.ZipDataset, 1192 readers.FixedLengthRecordDatasetV2, 1193 readers.TextLineDatasetV2, 1194 readers.TFRecordDatasetV2, 1195 ] 1196 for ty in whitelisted_types: 1197 if isinstance(dataset, ty): 1198 for input_dataset in dataset._inputs(): 1199 assert_not_batched(input_dataset) 1200 return 1201 raise ValueError('Could not assert that dataset is not batched.') 1202 1203 1204# pylint: disable=protected-access 1205def assert_not_shuffled(dataset): 1206 """Asserts that `dataset` is not shuffled. 1207 1208 The algorithm used by this method is sound but not complete. In other words, 1209 if the method fails to establish the assertion, it does not mean the dataset 1210 is shuffled. 1211 1212 Example usage: 1213 ```python 1214 try: 1215 assert_not_shuffled(dataset) 1216 # safe to assume `dataset` it not shuffled here 1217 expect ValueError: 1218 # make no assumptions about `dataset` 1219 ``` 1220 1221 Args: 1222 dataset: The dataset to analyze. 1223 1224 Raises: 1225 ValueError: If the method cannot establish the assertion. 1226 """ 1227 if isinstance(dataset, dataset_ops.DatasetV1Adapter): 1228 return assert_not_shuffled(dataset._dataset) 1229 else: 1230 whitelisted_types = [ 1231 dataset_ops._OptionsDataset, 1232 dataset_ops.BatchDataset, 1233 dataset_ops.ConcatenateDataset, 1234 dataset_ops.CacheDataset, 1235 dataset_ops.FilterDataset, 1236 dataset_ops.MapDataset, 1237 dataset_ops.PaddedBatchDataset, 1238 dataset_ops.ParallelMapDataset, 1239 dataset_ops.PrefetchDataset, 1240 dataset_ops.RangeDataset, 1241 dataset_ops.RepeatDataset, 1242 dataset_ops.SkipDataset, 1243 dataset_ops.SparseTensorSliceDataset, 1244 dataset_ops.TakeDataset, 1245 dataset_ops.TensorDataset, 1246 dataset_ops.TensorSliceDataset, 1247 dataset_ops.WindowDataset, 1248 dataset_ops.ZipDataset, 1249 readers.FixedLengthRecordDatasetV2, 1250 readers.TextLineDatasetV2, 1251 readers.TFRecordDatasetV2, 1252 ] 1253 for ty in whitelisted_types: 1254 if isinstance(dataset, ty): 1255 for input_dataset in dataset._inputs(): 1256 assert_not_shuffled(input_dataset) 1257 return 1258 raise ValueError('Could not assert that dataset is not shuffled.') 1259 1260 1261def verify_dataset_shuffled(x): 1262 """Verifies that the dataset is shuffled. 1263 1264 Args: 1265 x: Dataset passed as an input to the model. 1266 1267 Raises: 1268 ValueError: if the dataset is not already shuffled. 1269 """ 1270 assert isinstance(x, dataset_ops.DatasetV2) 1271 try: 1272 assert_not_shuffled(x) 1273 except ValueError: 1274 # Dataset may or may not be shuffled. 1275 return 1276 else: 1277 logging.warning('Expected a shuffled dataset but input dataset `x` is ' 1278 'not shuffled. Please invoke `shuffle()` on input dataset.') 1279 1280 1281def is_dataset_or_iterator(data): 1282 return isinstance(data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, 1283 iterator_ops.EagerIterator, iterator_ops.Iterator)) 1284 1285 1286def get_iterator(dataset): 1287 """Create and initialize an iterator from a dataset.""" 1288 iterator = dataset_ops.make_initializable_iterator(dataset) 1289 initialize_iterator(iterator) 1290 return iterator 1291 1292 1293def initialize_iterator(iterator): 1294 init_op = iterator.initializer 1295 if not context.executing_eagerly(): 1296 K.get_session((init_op,)).run(init_op) 1297 1298 1299def extract_tensors_from_dataset(dataset): 1300 """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset. 1301 1302 Arguments: 1303 dataset: Dataset instance. 1304 1305 Returns: 1306 Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. 1307 """ 1308 iterator = get_iterator(dataset) 1309 inputs, targets, sample_weight = unpack_iterator_input(iterator) 1310 return inputs, targets, sample_weight 1311 1312 1313def unpack_iterator_input(iterator): 1314 """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`. 1315 1316 Arguments: 1317 iterator: Instance of a dataset iterator. 1318 1319 Returns: 1320 Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. 1321 """ 1322 try: 1323 next_element = iterator.get_next() 1324 except errors.OutOfRangeError: 1325 raise RuntimeError('Your dataset iterator ran out of data; ' 1326 'Make sure that your dataset can generate ' 1327 'required number of samples.') 1328 1329 if isinstance(next_element, (list, tuple)): 1330 if len(next_element) not in [2, 3]: 1331 raise ValueError( 1332 'Please provide model inputs as a list or tuple of 2 or 3 ' 1333 'elements: (input, target) or (input, target, sample_weights) ' 1334 'Received %s' % next_element) 1335 if len(next_element) == 2: 1336 x, y = next_element 1337 weights = None 1338 else: 1339 x, y, weights = next_element 1340 else: 1341 x = next_element 1342 y = None 1343 weights = None 1344 return x, y, weights 1345 1346 1347def infer_steps_for_dataset(dataset, steps, epochs=1, steps_name='steps'): 1348 """Infers steps_per_epoch needed to loop through a dataset. 1349 1350 Arguments: 1351 dataset: Input data of type tf.data.Dataset. 1352 steps: Number of steps to draw from the dataset (may be None if unknown). 1353 epochs: Number of times to iterate over the dataset. 1354 steps_name: The string name of the steps argument, either `steps`, 1355 `validation_steps`, or `steps_per_epoch`. Only used for error message 1356 formatting. 1357 1358 Returns: 1359 Integer or `None`. Inferred number of steps to loop through the dataset. 1360 `None` is returned if the size of the dataset is unknown and `steps` was 1361 not specified. 1362 1363 Raises: 1364 ValueError: In case of invalid argument values. 1365 """ 1366 assert isinstance(dataset, dataset_ops.DatasetV2) 1367 size = K.get_value(cardinality.cardinality(dataset)) 1368 if size == cardinality.INFINITE and steps is None: 1369 raise ValueError('When passing an infinitely repeating dataset, you ' 1370 'must specify the `%s` argument.' % (steps_name,)) 1371 if size >= 0: 1372 if steps is not None and steps * epochs > size: 1373 if epochs > 1: 1374 raise ValueError('The dataset you passed contains %s batches, but you ' 1375 'passed `epochs=%s` and `%s=%s`, which is a total of ' 1376 '%s steps. We cannot draw that many steps from this ' 1377 'dataset. We suggest to set `%s=%s`.' % 1378 (size, epochs, steps_name, steps, steps * epochs, 1379 steps_name, size // epochs)) 1380 else: 1381 raise ValueError('The dataset you passed contains %s batches, but you ' 1382 'passed `%s=%s`. We cannot draw that many steps from ' 1383 'this dataset. We suggest to set `%s=%s`.' % 1384 (size, steps_name, steps, steps_name, size)) 1385 if steps is None: 1386 if size >= 0: 1387 return size 1388 return None 1389 return steps 1390 1391 1392class ModelInputs(object): 1393 """Encapsulates model inputs. 1394 1395 Allows for transforming model inputs while keeping the same structure. 1396 """ 1397 1398 def __init__(self, inputs): 1399 self._inputs = inputs 1400 self._is_dict = isinstance(self._inputs, dict) 1401 self._is_single_input = not isinstance(self._inputs, (list, tuple, dict)) 1402 1403 self._flattened_inputs = [] 1404 self._input_names = [] 1405 1406 if self._is_dict: 1407 for k in sorted(self._inputs.keys()): 1408 self._flattened_inputs.append(self._inputs[k]) 1409 self._input_names.append(k) 1410 else: 1411 self._flattened_inputs = nest.flatten(self._inputs) 1412 self._input_names = [ 1413 'input_%d' % (i + 1) for i in range(len(self._flattened_inputs)) 1414 ] 1415 1416 def get_input_names(self): 1417 """Returns keys to name inputs by. 1418 1419 In case inputs provided were a list, tuple or single entry, we make up a 1420 key 'input_%d'. For dictionary case, we return a sorted list of keys. 1421 """ 1422 return self._input_names 1423 1424 def get_symbolic_inputs(self, return_single_as_list=False): 1425 """Returns inputs to be set as self.inputs for a model.""" 1426 # TODO(karmel): There is a side-effect here where what you get 1427 # with as_list and as_dict depends on whether you have called this 1428 # method first, since it modifies in place. 1429 for i in range(len(self._flattened_inputs)): 1430 k = self._input_names[i] 1431 v = self._flattened_inputs[i] 1432 if isinstance(v, (list, float, int)): 1433 v = np.asarray(v) 1434 if v.ndim == 1: 1435 v = np.expand_dims(v, 1) 1436 1437 if isinstance(v, (np.ndarray, ops.EagerTensor)): 1438 # We fix the placeholder shape except the batch size. 1439 # This is suboptimal, but it is the best we can do with the info 1440 # we have. The user should call `model._set_inputs(placeholders)` 1441 # to specify custom placeholders if the need arises. 1442 shape = (None,) + tuple(v.shape[1:]) 1443 dtype = dtypes.as_dtype(v.dtype) 1444 if dtype.is_floating: 1445 dtype = K.floatx() 1446 v = K.placeholder(shape=shape, name=k, dtype=dtype) 1447 elif isinstance(v, tensor_shape.TensorShape): 1448 shape = (None,) + tuple(v.as_list()[1:]) 1449 v = K.placeholder(shape=shape, name=k) 1450 1451 self._flattened_inputs[i] = v 1452 1453 if self._is_dict: 1454 return dict(zip(self._input_names, self._flattened_inputs)) 1455 if self._is_single_input and not return_single_as_list: 1456 return self._flattened_inputs[0] 1457 return self._flattened_inputs 1458 1459 def as_dict(self): 1460 """An iterable over a dictionary version of inputs.""" 1461 for i in range(len(self._flattened_inputs)): 1462 yield self._input_names[i], self._flattened_inputs[i] 1463 1464 def as_list(self): 1465 """Returning the inputs as a list.""" 1466 return self._flattened_inputs 1467 1468 1469# Allow use of methods not exposed to the user. 1470# pylint: disable=protected-access 1471def get_input_shape_and_dtype(layer): 1472 """Retrieves input shape and input dtype of layer if applicable. 1473 1474 Args: 1475 layer: Layer (or model) instance. 1476 1477 Returns: 1478 Tuple (input_shape, input_dtype). Both could be None if the layer 1479 does not have a defined input shape. 1480 1481 Raises: 1482 ValueError: in case an empty Sequential or Functional model is passed. 1483 """ 1484 1485 def _is_graph_model(layer): 1486 return ((hasattr(layer, '_is_graph_network') and layer._is_graph_network) or 1487 layer.__class__.__name__ == 'Sequential') 1488 1489 # In case of nested models: recover the first layer 1490 # of the deepest model to infer input shape and dtype. 1491 # Subclassed Models may not have been built so can't be checked. 1492 while _is_graph_model(layer): 1493 if not layer.layers: 1494 raise ValueError('An empty Model cannot be used as a Layer.') 1495 layer = layer.layers[0] 1496 1497 if hasattr(layer, '_batch_input_shape'): 1498 return layer._batch_input_shape, layer.dtype 1499 return None, None 1500 1501 1502# pylint: enable=protected-access 1503 1504 1505def get_static_batch_size(layer): 1506 """Gets the static batch size of a Layer. 1507 1508 Arguments: 1509 layer: a `Layer` instance. 1510 1511 Returns: 1512 The static batch size of a Layer. 1513 """ 1514 batch_input_shape, _ = get_input_shape_and_dtype(layer) 1515 if batch_input_shape is not None: 1516 return tensor_shape.as_dimension(batch_input_shape[0]).value 1517 return None 1518 1519 1520def generic_output_names(outputs_list): 1521 return ['output_%d' % (i + 1) for i in range(len(outputs_list))] 1522 1523 1524def convert_eager_tensors_to_numpy(structure): 1525 """Convert every EagerTensor in `structure` to NumPy. 1526 1527 Arguments: 1528 structure: An arbitrary structure of elements to be converted to NumPy 1529 arrays. 1530 1531 Returns: 1532 An identical structure with EagerTensors converted to NumPy arrays. 1533 """ 1534 1535 def _convert(element): 1536 if isinstance(element, ops.EagerTensor): 1537 return element.numpy() 1538 return element 1539 1540 return nest.map_structure(_convert, structure) 1541 1542 1543def should_run_validation(validation_freq, epoch): 1544 """Checks if validation should be run this epoch. 1545 1546 Arguments: 1547 validation_freq: Integer or list. If an integer, specifies how many training 1548 epochs to run before a new validation run is performed. If a list, 1549 specifies the epochs on which to run validation. 1550 epoch: Integer, the number of the training epoch just completed. 1551 1552 Returns: 1553 Bool, True if validation should be run. 1554 1555 Raises: 1556 ValueError: if `validation_freq` is an Integer and less than 1, or if 1557 it is neither an Integer nor a Sequence. 1558 """ 1559 # `epoch` is 0-indexed internally but 1-indexed in the public API. 1560 one_indexed_epoch = epoch + 1 1561 1562 if isinstance(validation_freq, int): 1563 if validation_freq < 1: 1564 raise ValueError('`validation_freq` can not be less than 1.') 1565 return one_indexed_epoch % validation_freq == 0 1566 1567 if not isinstance(validation_freq, collections.Container): 1568 raise ValueError('`validation_freq` must be an Integer or ' 1569 '`collections.Container` (e.g. list, tuple, etc.)') 1570 return one_indexed_epoch in validation_freq 1571