1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for unit-testing Keras.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import contextlib 23import functools 24import itertools 25import threading 26 27import numpy as np 28 29from tensorflow.python import tf2 30from tensorflow.python.eager import context 31from tensorflow.python.framework import config 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import tensor_shape 35from tensorflow.python.framework import tensor_spec 36from tensorflow.python.framework import test_util 37from tensorflow.python.keras import backend 38from tensorflow.python.keras import layers 39from tensorflow.python.keras import models 40from tensorflow.python.keras.engine import base_layer_utils 41from tensorflow.python.keras.engine import keras_tensor 42from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2 43from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2 44from tensorflow.python.keras.optimizer_v2 import adam as adam_v2 45from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2 46from tensorflow.python.keras.optimizer_v2 import gradient_descent as 
gradient_descent_v2 47from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2 48from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2 49from tensorflow.python.keras.utils import tf_contextlib 50from tensorflow.python.keras.utils import tf_inspect 51from tensorflow.python.util import tf_decorator 52 53 54def string_test(actual, expected): 55 np.testing.assert_array_equal(actual, expected) 56 57 58def numeric_test(actual, expected): 59 np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-6) 60 61 62def get_test_data(train_samples, 63 test_samples, 64 input_shape, 65 num_classes, 66 random_seed=None): 67 """Generates test data to train a model on. 68 69 Args: 70 train_samples: Integer, how many training samples to generate. 71 test_samples: Integer, how many test samples to generate. 72 input_shape: Tuple of integers, shape of the inputs. 73 num_classes: Integer, number of classes for the data and targets. 74 random_seed: Integer, random seed used by numpy to generate data. 75 76 Returns: 77 A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. 
78 """ 79 if random_seed is not None: 80 np.random.seed(random_seed) 81 num_sample = train_samples + test_samples 82 templates = 2 * num_classes * np.random.random((num_classes,) + input_shape) 83 y = np.random.randint(0, num_classes, size=(num_sample,)) 84 x = np.zeros((num_sample,) + input_shape, dtype=np.float32) 85 for i in range(num_sample): 86 x[i] = templates[y[i]] + np.random.normal(loc=0, scale=1., size=input_shape) 87 return ((x[:train_samples], y[:train_samples]), 88 (x[train_samples:], y[train_samples:])) 89 90 91@test_util.disable_cudnn_autotune 92def layer_test(layer_cls, 93 kwargs=None, 94 input_shape=None, 95 input_dtype=None, 96 input_data=None, 97 expected_output=None, 98 expected_output_dtype=None, 99 expected_output_shape=None, 100 validate_training=True, 101 adapt_data=None, 102 custom_objects=None, 103 test_harness=None, 104 supports_masking=None): 105 """Test routine for a layer with a single input and single output. 106 107 Args: 108 layer_cls: Layer class object. 109 kwargs: Optional dictionary of keyword arguments for instantiating the 110 layer. 111 input_shape: Input shape tuple. 112 input_dtype: Data type of the input data. 113 input_data: Numpy array of input data. 114 expected_output: Numpy array of the expected output. 115 expected_output_dtype: Data type expected for the output. 116 expected_output_shape: Shape tuple for the expected shape of the output. 117 validate_training: Whether to attempt to validate training on this layer. 118 This might be set to False for non-differentiable layers that output 119 string or integer values. 120 adapt_data: Optional data for an 'adapt' call. If None, adapt() will not 121 be tested for this layer. This is only relevant for PreprocessingLayers. 122 custom_objects: Optional dictionary mapping name strings to custom objects 123 in the layer class. This is helpful for testing custom layers. 124 test_harness: The Tensorflow test, if any, that this function is being 125 called in. 
126 supports_masking: Optional boolean to check the `supports_masking` property 127 of the layer. If None, the check will not be performed. 128 129 Returns: 130 The output data (Numpy array) returned by the layer, for additional 131 checks to be done by the calling code. 132 133 Raises: 134 ValueError: if `input_shape is None`. 135 """ 136 if input_data is None: 137 if input_shape is None: 138 raise ValueError('input_shape is None') 139 if not input_dtype: 140 input_dtype = 'float32' 141 input_data_shape = list(input_shape) 142 for i, e in enumerate(input_data_shape): 143 if e is None: 144 input_data_shape[i] = np.random.randint(1, 4) 145 input_data = 10 * np.random.random(input_data_shape) 146 if input_dtype[:5] == 'float': 147 input_data -= 0.5 148 input_data = input_data.astype(input_dtype) 149 elif input_shape is None: 150 input_shape = input_data.shape 151 if input_dtype is None: 152 input_dtype = input_data.dtype 153 if expected_output_dtype is None: 154 expected_output_dtype = input_dtype 155 156 if dtypes.as_dtype(expected_output_dtype) == dtypes.string: 157 if test_harness: 158 assert_equal = test_harness.assertAllEqual 159 else: 160 assert_equal = string_test 161 else: 162 if test_harness: 163 assert_equal = test_harness.assertAllClose 164 else: 165 assert_equal = numeric_test 166 167 # instantiation 168 kwargs = kwargs or {} 169 layer = layer_cls(**kwargs) 170 171 if (supports_masking is not None 172 and layer.supports_masking != supports_masking): 173 raise AssertionError( 174 'When testing layer %s, the `supports_masking` property is %r' 175 'but expected to be %r.\nFull kwargs: %s' % 176 (layer_cls.__name__, layer.supports_masking, supports_masking, kwargs)) 177 178 # Test adapt, if data was passed. 
179 if adapt_data is not None: 180 layer.adapt(adapt_data) 181 182 # test get_weights , set_weights at layer level 183 weights = layer.get_weights() 184 layer.set_weights(weights) 185 186 # test and instantiation from weights 187 if 'weights' in tf_inspect.getargspec(layer_cls.__init__): 188 kwargs['weights'] = weights 189 layer = layer_cls(**kwargs) 190 191 # test in functional API 192 x = layers.Input(shape=input_shape[1:], dtype=input_dtype) 193 y = layer(x) 194 if backend.dtype(y) != expected_output_dtype: 195 raise AssertionError('When testing layer %s, for input %s, found output ' 196 'dtype=%s but expected to find %s.\nFull kwargs: %s' % 197 (layer_cls.__name__, x, backend.dtype(y), 198 expected_output_dtype, kwargs)) 199 200 def assert_shapes_equal(expected, actual): 201 """Asserts that the output shape from the layer matches the actual shape.""" 202 if len(expected) != len(actual): 203 raise AssertionError( 204 'When testing layer %s, for input %s, found output_shape=' 205 '%s but expected to find %s.\nFull kwargs: %s' % 206 (layer_cls.__name__, x, actual, expected, kwargs)) 207 208 for expected_dim, actual_dim in zip(expected, actual): 209 if isinstance(expected_dim, tensor_shape.Dimension): 210 expected_dim = expected_dim.value 211 if isinstance(actual_dim, tensor_shape.Dimension): 212 actual_dim = actual_dim.value 213 if expected_dim is not None and expected_dim != actual_dim: 214 raise AssertionError( 215 'When testing layer %s, for input %s, found output_shape=' 216 '%s but expected to find %s.\nFull kwargs: %s' % 217 (layer_cls.__name__, x, actual, expected, kwargs)) 218 219 if expected_output_shape is not None: 220 assert_shapes_equal(tensor_shape.TensorShape(expected_output_shape), 221 y.shape) 222 223 # check shape inference 224 model = models.Model(x, y) 225 computed_output_shape = tuple( 226 layer.compute_output_shape( 227 tensor_shape.TensorShape(input_shape)).as_list()) 228 computed_output_signature = layer.compute_output_signature( 229 
tensor_spec.TensorSpec(shape=input_shape, dtype=input_dtype)) 230 actual_output = model.predict(input_data) 231 actual_output_shape = actual_output.shape 232 assert_shapes_equal(computed_output_shape, actual_output_shape) 233 assert_shapes_equal(computed_output_signature.shape, actual_output_shape) 234 if computed_output_signature.dtype != actual_output.dtype: 235 raise AssertionError( 236 'When testing layer %s, for input %s, found output_dtype=' 237 '%s but expected to find %s.\nFull kwargs: %s' % 238 (layer_cls.__name__, x, actual_output.dtype, 239 computed_output_signature.dtype, kwargs)) 240 if expected_output is not None: 241 assert_equal(actual_output, expected_output) 242 243 # test serialization, weight setting at model level 244 model_config = model.get_config() 245 recovered_model = models.Model.from_config(model_config, custom_objects) 246 if model.weights: 247 weights = model.get_weights() 248 recovered_model.set_weights(weights) 249 output = recovered_model.predict(input_data) 250 assert_equal(output, actual_output) 251 252 # test training mode (e.g. useful for dropout tests) 253 # Rebuild the model to avoid the graph being reused between predict() and 254 # See b/120160788 for more details. This should be mitigated after 2.0. 255 layer_weights = layer.get_weights() # Get the layer weights BEFORE training. 256 if validate_training: 257 model = models.Model(x, layer(x)) 258 if _thread_local_data.run_eagerly is not None: 259 model.compile( 260 'rmsprop', 261 'mse', 262 weighted_metrics=['acc'], 263 run_eagerly=should_run_eagerly()) 264 else: 265 model.compile('rmsprop', 'mse', weighted_metrics=['acc']) 266 model.train_on_batch(input_data, actual_output) 267 268 # test as first layer in Sequential API 269 layer_config = layer.get_config() 270 layer_config['batch_input_shape'] = input_shape 271 layer = layer.__class__.from_config(layer_config) 272 273 # Test adapt, if data was passed. 
274 if adapt_data is not None: 275 layer.adapt(adapt_data) 276 277 model = models.Sequential() 278 model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype)) 279 model.add(layer) 280 281 layer.set_weights(layer_weights) 282 actual_output = model.predict(input_data) 283 actual_output_shape = actual_output.shape 284 for expected_dim, actual_dim in zip(computed_output_shape, 285 actual_output_shape): 286 if expected_dim is not None: 287 if expected_dim != actual_dim: 288 raise AssertionError( 289 'When testing layer %s **after deserialization**, ' 290 'for input %s, found output_shape=' 291 '%s but expected to find inferred shape %s.\nFull kwargs: %s' % 292 (layer_cls.__name__, 293 x, 294 actual_output_shape, 295 computed_output_shape, 296 kwargs)) 297 if expected_output is not None: 298 assert_equal(actual_output, expected_output) 299 300 # test serialization, weight setting at model level 301 model_config = model.get_config() 302 recovered_model = models.Sequential.from_config(model_config, custom_objects) 303 if model.weights: 304 weights = model.get_weights() 305 recovered_model.set_weights(weights) 306 output = recovered_model.predict(input_data) 307 assert_equal(output, actual_output) 308 309 # for further checks in the caller function 310 return actual_output 311 312 313_thread_local_data = threading.local() 314_thread_local_data.model_type = None 315_thread_local_data.run_eagerly = None 316_thread_local_data.saved_model_format = None 317_thread_local_data.save_kwargs = None 318 319 320@tf_contextlib.contextmanager 321def model_type_scope(value): 322 """Provides a scope within which the model type to test is equal to `value`. 323 324 The model type gets restored to its original value upon exiting the scope. 325 326 Args: 327 value: model type value 328 329 Yields: 330 The provided value. 
331 """ 332 previous_value = _thread_local_data.model_type 333 try: 334 _thread_local_data.model_type = value 335 yield value 336 finally: 337 # Restore model type to initial value. 338 _thread_local_data.model_type = previous_value 339 340 341@tf_contextlib.contextmanager 342def run_eagerly_scope(value): 343 """Provides a scope within which we compile models to run eagerly or not. 344 345 The boolean gets restored to its original value upon exiting the scope. 346 347 Args: 348 value: Bool specifying if we should run models eagerly in the active test. 349 Should be True or False. 350 351 Yields: 352 The provided value. 353 """ 354 previous_value = _thread_local_data.run_eagerly 355 try: 356 _thread_local_data.run_eagerly = value 357 yield value 358 finally: 359 # Restore model type to initial value. 360 _thread_local_data.run_eagerly = previous_value 361 362 363@tf_contextlib.contextmanager 364def use_keras_tensors_scope(value): 365 """Provides a scope within which we use KerasTensors in the func. API or not. 366 367 The boolean gets restored to its original value upon exiting the scope. 368 369 Args: 370 value: Bool specifying if we should build functional models 371 using KerasTensors in the active test. 372 Should be True or False. 373 374 Yields: 375 The provided value. 376 """ 377 previous_value = keras_tensor._KERAS_TENSORS_ENABLED # pylint: disable=protected-access 378 try: 379 keras_tensor._KERAS_TENSORS_ENABLED = value # pylint: disable=protected-access 380 yield value 381 finally: 382 # Restore KerasTensor usage to initial value. 
383 keras_tensor._KERAS_TENSORS_ENABLED = previous_value # pylint: disable=protected-access 384 385 386def should_run_eagerly(): 387 """Returns whether the models we are testing should be run eagerly.""" 388 if _thread_local_data.run_eagerly is None: 389 raise ValueError('Cannot call `should_run_eagerly()` outside of a ' 390 '`run_eagerly_scope()` or `run_all_keras_modes` ' 391 'decorator.') 392 393 return _thread_local_data.run_eagerly and context.executing_eagerly() 394 395 396@tf_contextlib.contextmanager 397def saved_model_format_scope(value, **kwargs): 398 """Provides a scope within which the savde model format to test is `value`. 399 400 The saved model format gets restored to its original value upon exiting the 401 scope. 402 403 Args: 404 value: saved model format value 405 **kwargs: optional kwargs to pass to the save function. 406 407 Yields: 408 The provided value. 409 """ 410 previous_format = _thread_local_data.saved_model_format 411 previous_kwargs = _thread_local_data.save_kwargs 412 try: 413 _thread_local_data.saved_model_format = value 414 _thread_local_data.save_kwargs = kwargs 415 yield 416 finally: 417 # Restore saved model format to initial value. 
418 _thread_local_data.saved_model_format = previous_format 419 _thread_local_data.save_kwargs = previous_kwargs 420 421 422def get_save_format(): 423 if _thread_local_data.saved_model_format is None: 424 raise ValueError( 425 'Cannot call `get_save_format()` outside of a ' 426 '`saved_model_format_scope()` or `run_with_all_saved_model_formats` ' 427 'decorator.') 428 return _thread_local_data.saved_model_format 429 430 431def get_save_kwargs(): 432 if _thread_local_data.save_kwargs is None: 433 raise ValueError( 434 'Cannot call `get_save_kwargs()` outside of a ' 435 '`saved_model_format_scope()` or `run_with_all_saved_model_formats` ' 436 'decorator.') 437 return _thread_local_data.save_kwargs or {} 438 439 440def get_model_type(): 441 """Gets the model type that should be tested.""" 442 if _thread_local_data.model_type is None: 443 raise ValueError('Cannot call `get_model_type()` outside of a ' 444 '`model_type_scope()` or `run_with_all_model_types` ' 445 'decorator.') 446 447 return _thread_local_data.model_type 448 449 450def get_small_sequential_mlp(num_hidden, num_classes, input_dim=None): 451 model = models.Sequential() 452 if input_dim: 453 model.add(layers.Dense(num_hidden, activation='relu', input_dim=input_dim)) 454 else: 455 model.add(layers.Dense(num_hidden, activation='relu')) 456 activation = 'sigmoid' if num_classes == 1 else 'softmax' 457 model.add(layers.Dense(num_classes, activation=activation)) 458 return model 459 460 461def get_small_functional_mlp(num_hidden, num_classes, input_dim): 462 inputs = layers.Input(shape=(input_dim,)) 463 outputs = layers.Dense(num_hidden, activation='relu')(inputs) 464 activation = 'sigmoid' if num_classes == 1 else 'softmax' 465 outputs = layers.Dense(num_classes, activation=activation)(outputs) 466 return models.Model(inputs, outputs) 467 468 469class SmallSubclassMLP(models.Model): 470 """A subclass model based small MLP.""" 471 472 def __init__(self, num_hidden, num_classes, use_bn=False, use_dp=False): 473 
super(SmallSubclassMLP, self).__init__(name='test_model') 474 self.use_bn = use_bn 475 self.use_dp = use_dp 476 477 self.layer_a = layers.Dense(num_hidden, activation='relu') 478 activation = 'sigmoid' if num_classes == 1 else 'softmax' 479 self.layer_b = layers.Dense(num_classes, activation=activation) 480 if self.use_dp: 481 self.dp = layers.Dropout(0.5) 482 if self.use_bn: 483 self.bn = layers.BatchNormalization(axis=-1) 484 485 def call(self, inputs, **kwargs): 486 x = self.layer_a(inputs) 487 if self.use_dp: 488 x = self.dp(x) 489 if self.use_bn: 490 x = self.bn(x) 491 return self.layer_b(x) 492 493 494class _SmallSubclassMLPCustomBuild(models.Model): 495 """A subclass model small MLP that uses a custom build method.""" 496 497 def __init__(self, num_hidden, num_classes): 498 super(_SmallSubclassMLPCustomBuild, self).__init__() 499 self.layer_a = None 500 self.layer_b = None 501 self.num_hidden = num_hidden 502 self.num_classes = num_classes 503 504 def build(self, input_shape): 505 self.layer_a = layers.Dense(self.num_hidden, activation='relu') 506 activation = 'sigmoid' if self.num_classes == 1 else 'softmax' 507 self.layer_b = layers.Dense(self.num_classes, activation=activation) 508 509 def call(self, inputs, **kwargs): 510 x = self.layer_a(inputs) 511 return self.layer_b(x) 512 513 514def get_small_subclass_mlp(num_hidden, num_classes): 515 return SmallSubclassMLP(num_hidden, num_classes) 516 517 518def get_small_subclass_mlp_with_custom_build(num_hidden, num_classes): 519 return _SmallSubclassMLPCustomBuild(num_hidden, num_classes) 520 521 522def get_small_mlp(num_hidden, num_classes, input_dim): 523 """Get a small mlp of the model type specified by `get_model_type`.""" 524 model_type = get_model_type() 525 if model_type == 'subclass': 526 return get_small_subclass_mlp(num_hidden, num_classes) 527 if model_type == 'subclass_custom_build': 528 return get_small_subclass_mlp_with_custom_build(num_hidden, num_classes) 529 if model_type == 'sequential': 530 
return get_small_sequential_mlp(num_hidden, num_classes, input_dim) 531 if model_type == 'functional': 532 return get_small_functional_mlp(num_hidden, num_classes, input_dim) 533 raise ValueError('Unknown model type {}'.format(model_type)) 534 535 536class _SubclassModel(models.Model): 537 """A Keras subclass model.""" 538 539 def __init__(self, model_layers, *args, **kwargs): 540 """Instantiate a model. 541 542 Args: 543 model_layers: a list of layers to be added to the model. 544 *args: Model's args 545 **kwargs: Model's keyword args, at most one of input_tensor -> the input 546 tensor required for ragged/sparse input. 547 """ 548 549 inputs = kwargs.pop('input_tensor', None) 550 super(_SubclassModel, self).__init__(*args, **kwargs) 551 # Note that clone and build doesn't support lists of layers in subclassed 552 # models. Adding each layer directly here. 553 for i, layer in enumerate(model_layers): 554 setattr(self, self._layer_name_for_i(i), layer) 555 556 self.num_layers = len(model_layers) 557 558 if inputs is not None: 559 self._set_inputs(inputs) 560 561 def _layer_name_for_i(self, i): 562 return 'layer{}'.format(i) 563 564 def call(self, inputs, **kwargs): 565 x = inputs 566 for i in range(self.num_layers): 567 layer = getattr(self, self._layer_name_for_i(i)) 568 x = layer(x) 569 return x 570 571 572class _SubclassModelCustomBuild(models.Model): 573 """A Keras subclass model that uses a custom build method.""" 574 575 def __init__(self, layer_generating_func, *args, **kwargs): 576 super(_SubclassModelCustomBuild, self).__init__(*args, **kwargs) 577 self.all_layers = None 578 self._layer_generating_func = layer_generating_func 579 580 def build(self, input_shape): 581 model_layers = [] 582 for layer in self._layer_generating_func(): 583 model_layers.append(layer) 584 self.all_layers = model_layers 585 586 def call(self, inputs, **kwargs): 587 x = inputs 588 for layer in self.all_layers: 589 x = layer(x) 590 return x 591 592 593def 
get_model_from_layers(model_layers, 594 input_shape=None, 595 input_dtype=None, 596 name=None, 597 input_ragged=None, 598 input_sparse=None): 599 """Builds a model from a sequence of layers. 600 601 Args: 602 model_layers: The layers used to build the network. 603 input_shape: Shape tuple of the input or 'TensorShape' instance. 604 input_dtype: Datatype of the input. 605 name: Name for the model. 606 input_ragged: Boolean, whether the input data is a ragged tensor. 607 input_sparse: Boolean, whether the input data is a sparse tensor. 608 609 Returns: 610 A Keras model. 611 """ 612 613 model_type = get_model_type() 614 if model_type == 'subclass': 615 inputs = None 616 if input_ragged or input_sparse: 617 inputs = layers.Input( 618 shape=input_shape, 619 dtype=input_dtype, 620 ragged=input_ragged, 621 sparse=input_sparse) 622 return _SubclassModel(model_layers, name=name, input_tensor=inputs) 623 624 if model_type == 'subclass_custom_build': 625 layer_generating_func = lambda: model_layers 626 return _SubclassModelCustomBuild(layer_generating_func, name=name) 627 628 if model_type == 'sequential': 629 model = models.Sequential(name=name) 630 if input_shape: 631 model.add( 632 layers.InputLayer( 633 input_shape=input_shape, 634 dtype=input_dtype, 635 ragged=input_ragged, 636 sparse=input_sparse)) 637 for layer in model_layers: 638 model.add(layer) 639 return model 640 641 if model_type == 'functional': 642 if not input_shape: 643 raise ValueError('Cannot create a functional model from layers with no ' 644 'input shape.') 645 inputs = layers.Input( 646 shape=input_shape, 647 dtype=input_dtype, 648 ragged=input_ragged, 649 sparse=input_sparse) 650 outputs = inputs 651 for layer in model_layers: 652 outputs = layer(outputs) 653 return models.Model(inputs, outputs, name=name) 654 655 raise ValueError('Unknown model type {}'.format(model_type)) 656 657 658class Bias(layers.Layer): 659 660 def build(self, input_shape): 661 self.bias = self.add_variable('bias', (1,), 
initializer='zeros') 662 663 def call(self, inputs): 664 return inputs + self.bias 665 666 667class _MultiIOSubclassModel(models.Model): 668 """Multi IO Keras subclass model.""" 669 670 def __init__(self, branch_a, branch_b, shared_input_branch=None, 671 shared_output_branch=None, name=None): 672 super(_MultiIOSubclassModel, self).__init__(name=name) 673 self._shared_input_branch = shared_input_branch 674 self._branch_a = branch_a 675 self._branch_b = branch_b 676 self._shared_output_branch = shared_output_branch 677 678 def call(self, inputs, **kwargs): 679 if self._shared_input_branch: 680 for layer in self._shared_input_branch: 681 inputs = layer(inputs) 682 a = inputs 683 b = inputs 684 elif isinstance(inputs, dict): 685 a = inputs['input_1'] 686 b = inputs['input_2'] 687 else: 688 a, b = inputs 689 690 for layer in self._branch_a: 691 a = layer(a) 692 for layer in self._branch_b: 693 b = layer(b) 694 outs = [a, b] 695 696 if self._shared_output_branch: 697 for layer in self._shared_output_branch: 698 outs = layer(outs) 699 700 return outs 701 702 703class _MultiIOSubclassModelCustomBuild(models.Model): 704 """Multi IO Keras subclass model that uses a custom build method.""" 705 706 def __init__(self, branch_a_func, branch_b_func, 707 shared_input_branch_func=None, 708 shared_output_branch_func=None): 709 super(_MultiIOSubclassModelCustomBuild, self).__init__() 710 self._shared_input_branch_func = shared_input_branch_func 711 self._branch_a_func = branch_a_func 712 self._branch_b_func = branch_b_func 713 self._shared_output_branch_func = shared_output_branch_func 714 715 self._shared_input_branch = None 716 self._branch_a = None 717 self._branch_b = None 718 self._shared_output_branch = None 719 720 def build(self, input_shape): 721 if self._shared_input_branch_func(): 722 self._shared_input_branch = self._shared_input_branch_func() 723 self._branch_a = self._branch_a_func() 724 self._branch_b = self._branch_b_func() 725 726 if 
self._shared_output_branch_func(): 727 self._shared_output_branch = self._shared_output_branch_func() 728 729 def call(self, inputs, **kwargs): 730 if self._shared_input_branch: 731 for layer in self._shared_input_branch: 732 inputs = layer(inputs) 733 a = inputs 734 b = inputs 735 else: 736 a, b = inputs 737 738 for layer in self._branch_a: 739 a = layer(a) 740 for layer in self._branch_b: 741 b = layer(b) 742 outs = a, b 743 744 if self._shared_output_branch: 745 for layer in self._shared_output_branch: 746 outs = layer(outs) 747 748 return outs 749 750 751def get_multi_io_model( 752 branch_a, 753 branch_b, 754 shared_input_branch=None, 755 shared_output_branch=None): 756 """Builds a multi-io model that contains two branches. 757 758 The produced model will be of the type specified by `get_model_type`. 759 760 To build a two-input, two-output model: 761 Specify a list of layers for branch a and branch b, but do not specify any 762 shared input branch or shared output branch. The resulting model will apply 763 each branch to a different input, to produce two outputs. 764 765 The first value in branch_a must be the Keras 'Input' layer for branch a, 766 and the first value in branch_b must be the Keras 'Input' layer for 767 branch b. 768 769 example usage: 770 ``` 771 branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()] 772 branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()] 773 774 model = get_multi_io_model(branch_a, branch_b) 775 ``` 776 777 To build a two-input, one-output model: 778 Specify a list of layers for branch a and branch b, and specify a 779 shared output branch. The resulting model will apply 780 each branch to a different input. It will then apply the shared output 781 branch to a tuple containing the intermediate outputs of each branch, 782 to produce a single output. The first layer in the shared_output_branch 783 must be able to merge a tuple of two tensors. 
784 785 The first value in branch_a must be the Keras 'Input' layer for branch a, 786 and the first value in branch_b must be the Keras 'Input' layer for 787 branch b. 788 789 example usage: 790 ``` 791 input_branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()] 792 input_branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()] 793 shared_output_branch = [Concatenate(), Dense(), Dense()] 794 795 model = get_multi_io_model(input_branch_a, input_branch_b, 796 shared_output_branch=shared_output_branch) 797 ``` 798 To build a one-input, two-output model: 799 Specify a list of layers for branch a and branch b, and specify a 800 shared input branch. The resulting model will take one input, and apply 801 the shared input branch to it. It will then respectively apply each branch 802 to that intermediate result in parallel, to produce two outputs. 803 804 The first value in the shared_input_branch must be the Keras 'Input' layer 805 for the whole model. Branch a and branch b should not contain any Input 806 layers. 807 808 example usage: 809 ``` 810 shared_input_branch = [Input(shape=(2,), name='in'), Dense(), Dense()] 811 output_branch_a = [Dense(), Dense()] 812 output_branch_b = [Dense(), Dense()] 813 814 815 model = get_multi_io_model(output__branch_a, output_branch_b, 816 shared_input_branch=shared_input_branch) 817 ``` 818 819 Args: 820 branch_a: A sequence of layers for branch a of the model. 821 branch_b: A sequence of layers for branch b of the model. 822 shared_input_branch: An optional sequence of layers to apply to a single 823 input, before applying both branches to that intermediate result. If set, 824 the model will take only one input instead of two. Defaults to None. 825 shared_output_branch: An optional sequence of layers to merge the 826 intermediate results produced by branch a and branch b. If set, 827 the model will produce only one output instead of two. Defaults to None. 
828 829 Returns: 830 A multi-io model of the type specified by `get_model_type`, specified 831 by the different branches. 832 """ 833 # Extract the functional inputs from the layer lists 834 if shared_input_branch: 835 inputs = shared_input_branch[0] 836 shared_input_branch = shared_input_branch[1:] 837 else: 838 inputs = branch_a[0], branch_b[0] 839 branch_a = branch_a[1:] 840 branch_b = branch_b[1:] 841 842 model_type = get_model_type() 843 if model_type == 'subclass': 844 return _MultiIOSubclassModel(branch_a, branch_b, shared_input_branch, 845 shared_output_branch) 846 847 if model_type == 'subclass_custom_build': 848 return _MultiIOSubclassModelCustomBuild((lambda: branch_a), 849 (lambda: branch_b), 850 (lambda: shared_input_branch), 851 (lambda: shared_output_branch)) 852 853 if model_type == 'sequential': 854 raise ValueError('Cannot use `get_multi_io_model` to construct ' 855 'sequential models') 856 857 if model_type == 'functional': 858 if shared_input_branch: 859 a_and_b = inputs 860 for layer in shared_input_branch: 861 a_and_b = layer(a_and_b) 862 a = a_and_b 863 b = a_and_b 864 else: 865 a, b = inputs 866 867 for layer in branch_a: 868 a = layer(a) 869 for layer in branch_b: 870 b = layer(b) 871 outputs = a, b 872 873 if shared_output_branch: 874 for layer in shared_output_branch: 875 outputs = layer(outputs) 876 877 return models.Model(inputs, outputs) 878 879 raise ValueError('Unknown model type {}'.format(model_type)) 880 881 882_V2_OPTIMIZER_MAP = { 883 'adadelta': adadelta_v2.Adadelta, 884 'adagrad': adagrad_v2.Adagrad, 885 'adam': adam_v2.Adam, 886 'adamax': adamax_v2.Adamax, 887 'nadam': nadam_v2.Nadam, 888 'rmsprop': rmsprop_v2.RMSprop, 889 'sgd': gradient_descent_v2.SGD 890} 891 892 893def get_v2_optimizer(name, **kwargs): 894 """Get the v2 optimizer requested. 895 896 This is only necessary until v2 are the default, as we are testing in Eager, 897 and Eager + v1 optimizers fail tests. 
When we are in v2, the strings alone 898 should be sufficient, and this mapping can theoretically be removed. 899 900 Args: 901 name: string name of Keras v2 optimizer. 902 **kwargs: any kwargs to pass to the optimizer constructor. 903 904 Returns: 905 Initialized Keras v2 optimizer. 906 907 Raises: 908 ValueError: if an unknown name was passed. 909 """ 910 try: 911 return _V2_OPTIMIZER_MAP[name](**kwargs) 912 except KeyError: 913 raise ValueError( 914 'Could not find requested v2 optimizer: {}\nValid choices: {}'.format( 915 name, list(_V2_OPTIMIZER_MAP.keys()))) 916 917 918def get_expected_metric_variable_names(var_names, name_suffix=''): 919 """Returns expected metric variable names given names and prefix/suffix.""" 920 if tf2.enabled() or context.executing_eagerly(): 921 # In V1 eager mode and V2 variable names are not made unique. 922 return [n + ':0' for n in var_names] 923 # In V1 graph mode variable names are made unique using a suffix. 924 return [n + name_suffix + ':0' for n in var_names] 925 926 927def enable_v2_dtype_behavior(fn): 928 """Decorator for enabling the layer V2 dtype behavior on a test.""" 929 return _set_v2_dtype_behavior(fn, True) 930 931 932def disable_v2_dtype_behavior(fn): 933 """Decorator for disabling the layer V2 dtype behavior on a test.""" 934 return _set_v2_dtype_behavior(fn, False) 935 936 937def _set_v2_dtype_behavior(fn, enabled): 938 """Returns version of 'fn' that runs with v2 dtype behavior on or off.""" 939 @functools.wraps(fn) 940 def wrapper(*args, **kwargs): 941 v2_dtype_behavior = base_layer_utils.V2_DTYPE_BEHAVIOR 942 base_layer_utils.V2_DTYPE_BEHAVIOR = enabled 943 try: 944 return fn(*args, **kwargs) 945 finally: 946 base_layer_utils.V2_DTYPE_BEHAVIOR = v2_dtype_behavior 947 948 return tf_decorator.make_decorator(fn, wrapper) 949 950 951@contextlib.contextmanager 952def device(should_use_gpu): 953 """Uses gpu when requested and available.""" 954 if should_use_gpu and test_util.is_gpu_available(): 955 dev = 
@contextlib.contextmanager
def use_gpu():
  """Uses the gpu when one is available, otherwise falls back to the CPU."""
  with device(should_use_gpu=True):
    yield


def for_all_test_methods(decorator, *args, **kwargs):
  """Generate class-level decorator from given method-level decorator.

  It is expected for the given decorator to take some arguments and return
  a method that is then called on the test method to produce a decorated
  method.

  Args:
    decorator: The decorator factory to apply. It is called as
      `decorator(*args, **kwargs)` once per test method, and the result is
      called on that method.
    *args: Positional arguments forwarded to `decorator`.
    **kwargs: Keyword arguments forwarded to `decorator`.

  Returns:
    Function that will decorate a given class's test methods with the
    decorator and return the same class object.
  """

  def all_test_methods_impl(cls):
    """Apply decorator to all test methods in class."""
    for name in dir(cls):
      # Filter on the name before touching the attribute, so that unrelated
      # descriptors whose class-level access raises are never evaluated.
      if not name.startswith('test') or name == 'test_session':
        continue
      value = getattr(cls, name)
      if callable(value):
        setattr(cls, name, decorator(*args, **kwargs)(value))
    return cls

  return all_test_methods_impl


# The description is just for documentation purposes.
def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute test with TensorFloat-32 disabled.

  While almost every real-world deep learning model runs fine with
  TensorFloat-32, many tests use assertAllClose or similar methods.
  TensorFloat-32 matmuls typically will cause such methods to fail with the
  default tolerances.

  Args:
    description: A description used for documentation purposes, describing why
      the test requires TensorFloat-32 to be disabled.

  Returns:
    Decorator which runs a test with TensorFloat-32 disabled.
  """

  def decorator(f):

    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      allowed = config.tensor_float_32_execution_enabled()
      try:
        config.enable_tensor_float_32_execution(False)
        # Propagate the wrapped method's result instead of dropping it.
        return f(self, *args, **kwargs)
      finally:
        # Restore the previous setting even when the test fails.
        config.enable_tensor_float_32_execution(allowed)

    return decorated

  return decorator
# The description is just for documentation purposes.
def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute all tests in a class with TensorFloat-32 disabled.

  Args:
    description: A description used for documentation purposes, describing why
      the tests require TensorFloat-32 to be disabled.

  Returns:
    Class decorator that applies `run_without_tensor_float_32` to every test
    method of the decorated class.
  """
  return for_all_test_methods(run_without_tensor_float_32, description)


def run_v2_only(func=None):
  """Execute the decorated test only if running in v2 mode.

  This function is intended to be applied to tests that exercise v2 only
  functionality. If the test is run in v1 mode it will simply be skipped.

  See go/tf-test-decorator-cheatsheet for the decorators to use in different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: if the decorator is applied to a class instead of a method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError('`run_v2_only` only supports test methods.')

    # Keep the test's name and docstring on the wrapper, as the other
    # decorators in this file do.
    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      if not tf2.enabled():
        self.skipTest('Test is only compatible with v2')

      return f(self, *args, **kwargs)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator


def generate_combinations_with_testcase_name(**kwargs):
  """Generate combinations based on its keyword arguments using combine().

  This function calls combine() and appends a testcase name to the list of
  dictionaries returned. The 'testcase_name' key is required for named
  parameterized tests.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`. Only `list` values are expanded; any other
      value (including a tuple) is treated as a single possibility.

  Returns:
    a list of dictionaries for each combination. Keys in the dictionaries are
    the keyword argument names. Each key has one value - one of the
    corresponding keyword argument values. Every dictionary additionally maps
    'testcase_name' to `_test` followed by `_<key>_<value>` for each option in
    sorted key order, with non-alphanumeric characters stripped from both.
  """

  def alnum_only(text):
    """Drops every character of `text` that is not alphanumeric."""
    return ''.join(ch for ch in text if ch.isalnum())

  # One list of (key, value) pairs per option, in sorted key order so that the
  # generated names are deterministic.
  per_option = []
  for key in sorted(kwargs):
    values = kwargs[key]
    if not isinstance(values, list):
      values = [values]
    per_option.append([(key, value) for value in values])

  named_combinations = []
  for assignment in itertools.product(*per_option):
    name = ''.join(
        '_{}_{}'.format(alnum_only(key), alnum_only(str(value)))
        for key, value in assignment)
    combination = collections.OrderedDict(assignment)
    combination['testcase_name'] = '_test{}'.format(name)
    named_combinations.append(combination)

  return named_combinations