1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for unit-testing Keras.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import threading 23 24import numpy as np 25 26from tensorflow.python import keras 27from tensorflow.python import tf2 28from tensorflow.python.eager import context 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.framework import tensor_spec 31from tensorflow.python.framework import test_util 32from tensorflow.python.keras.engine import base_layer_utils 33from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2 34from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2 35from tensorflow.python.keras.optimizer_v2 import adam as adam_v2 36from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2 37from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 38from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2 39from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2 40from tensorflow.python.util import tf_contextlib 41from tensorflow.python.util import tf_decorator 42from tensorflow.python.util import tf_inspect 43 44 45def get_test_data(train_samples, 46 test_samples, 47 input_shape, 48 num_classes, 49 random_seed=None): 50 """Generates test data to train a model on. 51 52 Arguments: 53 train_samples: Integer, how many training samples to generate. 54 test_samples: Integer, how many test samples to generate. 55 input_shape: Tuple of integers, shape of the inputs. 56 num_classes: Integer, number of classes for the data and targets. 57 random_seed: Integer, random seed used by numpy to generate data. 58 59 Returns: 60 A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. 61 """ 62 if random_seed is not None: 63 np.random.seed(random_seed) 64 num_sample = train_samples + test_samples 65 templates = 2 * num_classes * np.random.random((num_classes,) + input_shape) 66 y = np.random.randint(0, num_classes, size=(num_sample,)) 67 x = np.zeros((num_sample,) + input_shape, dtype=np.float32) 68 for i in range(num_sample): 69 x[i] = templates[y[i]] + np.random.normal(loc=0, scale=1., size=input_shape) 70 return ((x[:train_samples], y[:train_samples]), 71 (x[train_samples:], y[train_samples:])) 72 73 74@test_util.disable_cudnn_autotune 75def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, 76 input_data=None, expected_output=None, 77 expected_output_dtype=None, expected_output_shape=None, 78 validate_training=True, adapt_data=None): 79 """Test routine for a layer with a single input and single output. 80 81 Arguments: 82 layer_cls: Layer class object. 83 kwargs: Optional dictionary of keyword arguments for instantiating the 84 layer. 85 input_shape: Input shape tuple. 86 input_dtype: Data type of the input data. 87 input_data: Numpy array of input data. 88 expected_output: Numpy array of the expected output. 89 expected_output_dtype: Data type expected for the output. 90 expected_output_shape: Shape tuple for the expected shape of the output. 91 validate_training: Whether to attempt to validate training on this layer. 92 This might be set to False for non-differentiable layers that output 93 string or integer values. 94 adapt_data: Optional data for an 'adapt' call. If None, adapt() will not 95 be tested for this layer. This is only relevant for PreprocessingLayers. 96 97 Returns: 98 The output data (Numpy array) returned by the layer, for additional 99 checks to be done by the calling code. 100 101 Raises: 102 ValueError: if `input_shape is None`. 103 """ 104 if input_data is None: 105 if input_shape is None: 106 raise ValueError('input_shape is None') 107 if not input_dtype: 108 input_dtype = 'float32' 109 input_data_shape = list(input_shape) 110 for i, e in enumerate(input_data_shape): 111 if e is None: 112 input_data_shape[i] = np.random.randint(1, 4) 113 input_data = 10 * np.random.random(input_data_shape) 114 if input_dtype[:5] == 'float': 115 input_data -= 0.5 116 input_data = input_data.astype(input_dtype) 117 elif input_shape is None: 118 input_shape = input_data.shape 119 if input_dtype is None: 120 input_dtype = input_data.dtype 121 if expected_output_dtype is None: 122 expected_output_dtype = input_dtype 123 124 # instantiation 125 kwargs = kwargs or {} 126 layer = layer_cls(**kwargs) 127 128 # Test adapt, if data was passed. 129 if adapt_data is not None: 130 layer.adapt(adapt_data) 131 132 # test get_weights , set_weights at layer level 133 weights = layer.get_weights() 134 layer.set_weights(weights) 135 136 # test and instantiation from weights 137 if 'weights' in tf_inspect.getargspec(layer_cls.__init__): 138 kwargs['weights'] = weights 139 layer = layer_cls(**kwargs) 140 141 # test in functional API 142 x = keras.layers.Input(shape=input_shape[1:], dtype=input_dtype) 143 y = layer(x) 144 if keras.backend.dtype(y) != expected_output_dtype: 145 raise AssertionError('When testing layer %s, for input %s, found output ' 146 'dtype=%s but expected to find %s.\nFull kwargs: %s' % 147 (layer_cls.__name__, 148 x, 149 keras.backend.dtype(y), 150 expected_output_dtype, 151 kwargs)) 152 153 def assert_shapes_equal(expected, actual): 154 """Asserts that the output shape from the layer matches the actual shape.""" 155 if len(expected) != len(actual): 156 raise AssertionError( 157 'When testing layer %s, for input %s, found output_shape=' 158 '%s but expected to find %s.\nFull kwargs: %s' % 159 (layer_cls.__name__, x, actual, expected, kwargs)) 160 161 for expected_dim, actual_dim in zip(expected, actual): 162 if isinstance(expected_dim, tensor_shape.Dimension): 163 expected_dim = expected_dim.value 164 if isinstance(actual_dim, tensor_shape.Dimension): 165 actual_dim = actual_dim.value 166 if expected_dim is not None and expected_dim != actual_dim: 167 raise AssertionError( 168 'When testing layer %s, for input %s, found output_shape=' 169 '%s but expected to find %s.\nFull kwargs: %s' % 170 (layer_cls.__name__, x, actual, expected, kwargs)) 171 172 if expected_output_shape is not None: 173 assert_shapes_equal(tensor_shape.TensorShape(expected_output_shape), 174 y.shape) 175 176 # check shape inference 177 model = keras.models.Model(x, y) 178 computed_output_shape = tuple( 179 layer.compute_output_shape( 180 tensor_shape.TensorShape(input_shape)).as_list()) 181 computed_output_signature = layer.compute_output_signature( 182 tensor_spec.TensorSpec(shape=input_shape, dtype=input_dtype)) 183 actual_output = model.predict(input_data) 184 actual_output_shape = actual_output.shape 185 assert_shapes_equal(computed_output_shape, actual_output_shape) 186 assert_shapes_equal(computed_output_signature.shape, actual_output_shape) 187 if computed_output_signature.dtype != actual_output.dtype: 188 raise AssertionError( 189 'When testing layer %s, for input %s, found output_dtype=' 190 '%s but expected to find %s.\nFull kwargs: %s' % 191 (layer_cls.__name__, x, actual_output.dtype, 192 computed_output_signature.dtype, kwargs)) 193 if expected_output is not None: 194 np.testing.assert_allclose(actual_output, expected_output, 195 rtol=1e-3, atol=1e-6) 196 197 # test serialization, weight setting at model level 198 model_config = model.get_config() 199 recovered_model = keras.models.Model.from_config(model_config) 200 if model.weights: 201 weights = model.get_weights() 202 recovered_model.set_weights(weights) 203 output = recovered_model.predict(input_data) 204 np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6) 205 206 # test training mode (e.g. useful for dropout tests) 207 # Rebuild the model to avoid the graph being reused between predict() and 208 # See b/120160788 for more details. This should be mitigated after 2.0. 209 if validate_training: 210 model = keras.models.Model(x, layer(x)) 211 if _thread_local_data.run_eagerly is not None: 212 model.compile( 213 'rmsprop', 214 'mse', 215 weighted_metrics=['acc'], 216 run_eagerly=should_run_eagerly()) 217 else: 218 model.compile('rmsprop', 'mse', weighted_metrics=['acc']) 219 model.train_on_batch(input_data, actual_output) 220 221 # test as first layer in Sequential API 222 layer_config = layer.get_config() 223 layer_config['batch_input_shape'] = input_shape 224 layer = layer.__class__.from_config(layer_config) 225 226 # Test adapt, if data was passed. 227 if adapt_data is not None: 228 layer.adapt(adapt_data) 229 230 model = keras.models.Sequential() 231 model.add(layer) 232 actual_output = model.predict(input_data) 233 actual_output_shape = actual_output.shape 234 for expected_dim, actual_dim in zip(computed_output_shape, 235 actual_output_shape): 236 if expected_dim is not None: 237 if expected_dim != actual_dim: 238 raise AssertionError( 239 'When testing layer %s **after deserialization**, ' 240 'for input %s, found output_shape=' 241 '%s but expected to find inferred shape %s.\nFull kwargs: %s' % 242 (layer_cls.__name__, 243 x, 244 actual_output_shape, 245 computed_output_shape, 246 kwargs)) 247 if expected_output is not None: 248 np.testing.assert_allclose(actual_output, expected_output, 249 rtol=1e-3, atol=1e-6) 250 251 # test serialization, weight setting at model level 252 model_config = model.get_config() 253 recovered_model = keras.models.Sequential.from_config(model_config) 254 if model.weights: 255 weights = model.get_weights() 256 recovered_model.set_weights(weights) 257 output = recovered_model.predict(input_data) 258 np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6) 259 260 # for further checks in the caller function 261 return actual_output 262 263 264_thread_local_data = threading.local() 265_thread_local_data.model_type = None 266_thread_local_data.run_eagerly = None 267_thread_local_data.experimental_run_tf_function = None 268_thread_local_data.saved_model_format = None 269 270 271@tf_contextlib.contextmanager 272def model_type_scope(value): 273 """Provides a scope within which the model type to test is equal to `value`. 274 275 The model type gets restored to its original value upon exiting the scope. 276 277 Arguments: 278 value: model type value 279 280 Yields: 281 The provided value. 282 """ 283 previous_value = _thread_local_data.model_type 284 try: 285 _thread_local_data.model_type = value 286 yield value 287 finally: 288 # Restore model type to initial value. 289 _thread_local_data.model_type = previous_value 290 291 292@tf_contextlib.contextmanager 293def run_eagerly_scope(value): 294 """Provides a scope within which we compile models to run eagerly or not. 295 296 The boolean gets restored to its original value upon exiting the scope. 297 298 Arguments: 299 value: Bool specifying if we should run models eagerly in the active test. 300 Should be True or False. 301 302 Yields: 303 The provided value. 304 """ 305 previous_value = _thread_local_data.run_eagerly 306 try: 307 _thread_local_data.run_eagerly = value 308 yield value 309 finally: 310 # Restore model type to initial value. 311 _thread_local_data.run_eagerly = previous_value 312 313 314def should_run_eagerly(): 315 """Returns whether the models we are testing should be run eagerly.""" 316 if _thread_local_data.run_eagerly is None: 317 raise ValueError('Cannot call `should_run_eagerly()` outside of a ' 318 '`run_eagerly_scope()` or `run_all_keras_modes` ' 319 'decorator.') 320 321 return _thread_local_data.run_eagerly and context.executing_eagerly() 322 323 324@tf_contextlib.contextmanager 325def experimental_run_tf_function_scope(value): 326 """Provides a scope within which we compile models to run with distribution. 327 328 The boolean gets restored to its original value upon exiting the scope. 329 330 Arguments: 331 value: Bool specifying if we should run models with default distribution 332 in the active test. Should be True or False. 333 334 Yields: 335 The provided value. 336 """ 337 previous_value = _thread_local_data.experimental_run_tf_function 338 try: 339 _thread_local_data.experimental_run_tf_function = value 340 yield value 341 finally: 342 # Restore model type to initial value. 343 _thread_local_data.experimental_run_tf_function = previous_value 344 345 346def should_run_tf_function(): 347 """Returns whether the models we are testing should be run distributed.""" 348 if _thread_local_data.experimental_run_tf_function is None: 349 raise ValueError( 350 'Cannot call `should_run_tf_function()` outside of a ' 351 '`experimental_run_tf_function_scope()` or `run_all_keras_modes` ' 352 'decorator.') 353 354 return (_thread_local_data.experimental_run_tf_function and 355 context.executing_eagerly()) 356 357 358@tf_contextlib.contextmanager 359def saved_model_format_scope(value): 360 """Provides a scope within which the savde model format to test is `value`. 361 362 The saved model format gets restored to its original value upon exiting the 363 scope. 364 365 Arguments: 366 value: saved model format value 367 368 Yields: 369 The provided value. 370 """ 371 previous_value = _thread_local_data.saved_model_format 372 try: 373 _thread_local_data.saved_model_format = value 374 yield value 375 finally: 376 # Restore saved model format to initial value. 377 _thread_local_data.saved_model_format = previous_value 378 379 380def get_save_format(): 381 if _thread_local_data.saved_model_format is None: 382 raise ValueError( 383 'Cannot call `get_save_format()` outside of a ' 384 '`saved_model_format_scope()` or `run_with_all_saved_model_formats` ' 385 'decorator.') 386 return _thread_local_data.saved_model_format 387 388 389def get_model_type(): 390 """Gets the model type that should be tested.""" 391 if _thread_local_data.model_type is None: 392 raise ValueError('Cannot call `get_model_type()` outside of a ' 393 '`model_type_scope()` or `run_with_all_model_types` ' 394 'decorator.') 395 396 return _thread_local_data.model_type 397 398 399def get_small_sequential_mlp(num_hidden, num_classes, input_dim=None): 400 model = keras.models.Sequential() 401 if input_dim: 402 model.add(keras.layers.Dense(num_hidden, activation='relu', 403 input_dim=input_dim)) 404 else: 405 model.add(keras.layers.Dense(num_hidden, activation='relu')) 406 activation = 'sigmoid' if num_classes == 1 else 'softmax' 407 model.add(keras.layers.Dense(num_classes, activation=activation)) 408 return model 409 410 411def get_small_functional_mlp(num_hidden, num_classes, input_dim): 412 inputs = keras.Input(shape=(input_dim,)) 413 outputs = keras.layers.Dense(num_hidden, activation='relu')(inputs) 414 activation = 'sigmoid' if num_classes == 1 else 'softmax' 415 outputs = keras.layers.Dense(num_classes, activation=activation)(outputs) 416 return keras.Model(inputs, outputs) 417 418 419class SmallSubclassMLP(keras.Model): 420 """A subclass model based small MLP.""" 421 422 def __init__(self, num_hidden, num_classes, use_bn=False, use_dp=False): 423 super(SmallSubclassMLP, self).__init__(name='test_model') 424 self.use_bn = use_bn 425 self.use_dp = use_dp 426 427 self.layer_a = keras.layers.Dense(num_hidden, activation='relu') 428 activation = 'sigmoid' if num_classes == 1 else 'softmax' 429 self.layer_b = keras.layers.Dense(num_classes, activation=activation) 430 if self.use_dp: 431 self.dp = keras.layers.Dropout(0.5) 432 if self.use_bn: 433 self.bn = keras.layers.BatchNormalization(axis=-1) 434 435 def call(self, inputs, **kwargs): 436 x = self.layer_a(inputs) 437 if self.use_dp: 438 x = self.dp(x) 439 if self.use_bn: 440 x = self.bn(x) 441 return self.layer_b(x) 442 443 444class _SmallSubclassMLPCustomBuild(keras.Model): 445 """A subclass model small MLP that uses a custom build method.""" 446 447 def __init__(self, num_hidden, num_classes): 448 super(_SmallSubclassMLPCustomBuild, self).__init__() 449 self.layer_a = None 450 self.layer_b = None 451 self.num_hidden = num_hidden 452 self.num_classes = num_classes 453 454 def build(self, input_shape): 455 self.layer_a = keras.layers.Dense(self.num_hidden, activation='relu') 456 activation = 'sigmoid' if self.num_classes == 1 else 'softmax' 457 self.layer_b = keras.layers.Dense(self.num_classes, activation=activation) 458 459 def call(self, inputs, **kwargs): 460 x = self.layer_a(inputs) 461 return self.layer_b(x) 462 463 464def get_small_subclass_mlp(num_hidden, num_classes): 465 return SmallSubclassMLP(num_hidden, num_classes) 466 467 468def get_small_subclass_mlp_with_custom_build(num_hidden, num_classes): 469 return _SmallSubclassMLPCustomBuild(num_hidden, num_classes) 470 471 472def get_small_mlp(num_hidden, num_classes, input_dim): 473 """Get a small mlp of the model type specified by `get_model_type`.""" 474 model_type = get_model_type() 475 if model_type == 'subclass': 476 return get_small_subclass_mlp(num_hidden, num_classes) 477 if model_type == 'subclass_custom_build': 478 return get_small_subclass_mlp_with_custom_build(num_hidden, num_classes) 479 if model_type == 'sequential': 480 return get_small_sequential_mlp(num_hidden, num_classes, input_dim) 481 if model_type == 'functional': 482 return get_small_functional_mlp(num_hidden, num_classes, input_dim) 483 raise ValueError('Unknown model type {}'.format(model_type)) 484 485 486class _SubclassModel(keras.Model): 487 """A Keras subclass model.""" 488 489 def __init__(self, layers, *args, **kwargs): 490 """Instantiate a model. 491 492 Args: 493 layers: a list of layers to be added to the model. 494 *args: Model's args 495 **kwargs: Model's keyword args, at most one of 496 input_tensor -> the input tensor required for ragged/sparse input. 497 """ 498 499 inputs = kwargs.pop('input_tensor', None) 500 super(_SubclassModel, self).__init__(*args, **kwargs) 501 # Note that clone and build doesn't support lists of layers in subclassed 502 # models. Adding each layer directly here. 503 for i, layer in enumerate(layers): 504 setattr(self, self._layer_name_for_i(i), layer) 505 506 self.num_layers = len(layers) 507 508 if inputs is not None: 509 self._set_inputs(inputs) 510 511 def _layer_name_for_i(self, i): 512 return 'layer{}'.format(i) 513 514 def call(self, inputs, **kwargs): 515 x = inputs 516 for i in range(self.num_layers): 517 layer = getattr(self, self._layer_name_for_i(i)) 518 x = layer(x) 519 return x 520 521 522class _SubclassModelCustomBuild(keras.Model): 523 """A Keras subclass model that uses a custom build method.""" 524 525 def __init__(self, layer_generating_func, *args, **kwargs): 526 super(_SubclassModelCustomBuild, self).__init__(*args, **kwargs) 527 self.all_layers = None 528 self._layer_generating_func = layer_generating_func 529 530 def build(self, input_shape): 531 layers = [] 532 for layer in self._layer_generating_func(): 533 layers.append(layer) 534 self.all_layers = layers 535 536 def call(self, inputs, **kwargs): 537 x = inputs 538 for layer in self.all_layers: 539 x = layer(x) 540 return x 541 542 543def get_model_from_layers(layers, 544 input_shape=None, 545 input_dtype=None, 546 name=None, 547 input_ragged=None, 548 input_sparse=None): 549 """Builds a model from a sequence of layers. 550 551 Args: 552 layers: The layers used to build the network. 553 input_shape: Shape tuple of the input or 'TensorShape' instance. 554 input_dtype: Datatype of the input. 555 name: Name for the model. 556 input_ragged: Boolean, whether the input data is a ragged tensor. 557 input_sparse: Boolean, whether the input data is a sparse tensor. 558 559 Returns: 560 A Keras model. 561 """ 562 563 model_type = get_model_type() 564 if model_type == 'subclass': 565 inputs = None 566 if input_ragged or input_sparse: 567 inputs = keras.Input( 568 shape=input_shape, 569 dtype=input_dtype, 570 ragged=input_ragged, 571 sparse=input_sparse) 572 return _SubclassModel(layers, name=name, input_tensor=inputs) 573 574 if model_type == 'subclass_custom_build': 575 layer_generating_func = lambda: layers 576 return _SubclassModelCustomBuild(layer_generating_func, name=name) 577 578 if model_type == 'sequential': 579 model = keras.models.Sequential(name=name) 580 if input_shape: 581 model.add( 582 keras.layers.InputLayer( 583 input_shape=input_shape, 584 dtype=input_dtype, 585 ragged=input_ragged, 586 sparse=input_sparse)) 587 for layer in layers: 588 model.add(layer) 589 return model 590 591 if model_type == 'functional': 592 if not input_shape: 593 raise ValueError('Cannot create a functional model from layers with no ' 594 'input shape.') 595 inputs = keras.Input( 596 shape=input_shape, 597 dtype=input_dtype, 598 ragged=input_ragged, 599 sparse=input_sparse) 600 outputs = inputs 601 for layer in layers: 602 outputs = layer(outputs) 603 return keras.Model(inputs, outputs, name=name) 604 605 raise ValueError('Unknown model type {}'.format(model_type)) 606 607 608class Bias(keras.layers.Layer): 609 610 def build(self, input_shape): 611 self.bias = self.add_variable('bias', (1,), initializer='zeros') 612 613 def call(self, inputs): 614 return inputs + self.bias 615 616 617class _MultiIOSubclassModel(keras.Model): 618 """Multi IO Keras subclass model.""" 619 620 def __init__(self, branch_a, branch_b, shared_input_branch=None, 621 shared_output_branch=None, name=None): 622 super(_MultiIOSubclassModel, self).__init__(name=name) 623 self._shared_input_branch = shared_input_branch 624 self._branch_a = branch_a 625 self._branch_b = branch_b 626 self._shared_output_branch = shared_output_branch 627 628 def call(self, inputs, **kwargs): 629 if self._shared_input_branch: 630 for layer in self._shared_input_branch: 631 inputs = layer(inputs) 632 a = inputs 633 b = inputs 634 else: 635 a, b = inputs 636 637 for layer in self._branch_a: 638 a = layer(a) 639 for layer in self._branch_b: 640 b = layer(b) 641 outs = [a, b] 642 643 if self._shared_output_branch: 644 for layer in self._shared_output_branch: 645 outs = layer(outs) 646 647 return outs 648 649 650class _MultiIOSubclassModelCustomBuild(keras.Model): 651 """Multi IO Keras subclass model that uses a custom build method.""" 652 653 def __init__(self, branch_a_func, branch_b_func, 654 shared_input_branch_func=None, 655 shared_output_branch_func=None): 656 super(_MultiIOSubclassModelCustomBuild, self).__init__() 657 self._shared_input_branch_func = shared_input_branch_func 658 self._branch_a_func = branch_a_func 659 self._branch_b_func = branch_b_func 660 self._shared_output_branch_func = shared_output_branch_func 661 662 self._shared_input_branch = None 663 self._branch_a = None 664 self._branch_b = None 665 self._shared_output_branch = None 666 667 def build(self, input_shape): 668 if self._shared_input_branch_func(): 669 self._shared_input_branch = self._shared_input_branch_func() 670 self._branch_a = self._branch_a_func() 671 self._branch_b = self._branch_b_func() 672 673 if self._shared_output_branch_func(): 674 self._shared_output_branch = self._shared_output_branch_func() 675 676 def call(self, inputs, **kwargs): 677 if self._shared_input_branch: 678 for layer in self._shared_input_branch: 679 inputs = layer(inputs) 680 a = inputs 681 b = inputs 682 else: 683 a, b = inputs 684 685 for layer in self._branch_a: 686 a = layer(a) 687 for layer in self._branch_b: 688 b = layer(b) 689 outs = a, b 690 691 if self._shared_output_branch: 692 for layer in self._shared_output_branch: 693 outs = layer(outs) 694 695 return outs 696 697 698def get_multi_io_model( 699 branch_a, 700 branch_b, 701 shared_input_branch=None, 702 shared_output_branch=None): 703 """Builds a multi-io model that contains two branches. 704 705 The produced model will be of the type specified by `get_model_type`. 706 707 To build a two-input, two-output model: 708 Specify a list of layers for branch a and branch b, but do not specify any 709 shared input branch or shared output branch. The resulting model will apply 710 each branch to a different input, to produce two outputs. 711 712 The first value in branch_a must be the Keras 'Input' layer for branch a, 713 and the first value in branch_b must be the Keras 'Input' layer for 714 branch b. 715 716 example usage: 717 ``` 718 branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()] 719 branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()] 720 721 model = get_multi_io_model(branch_a, branch_b) 722 ``` 723 724 To build a two-input, one-output model: 725 Specify a list of layers for branch a and branch b, and specify a 726 shared output branch. The resulting model will apply 727 each branch to a different input. It will then apply the shared output 728 branch to a tuple containing the intermediate outputs of each branch, 729 to produce a single output. The first layer in the shared_output_branch 730 must be able to merge a tuple of two tensors. 731 732 The first value in branch_a must be the Keras 'Input' layer for branch a, 733 and the first value in branch_b must be the Keras 'Input' layer for 734 branch b. 735 736 example usage: 737 ``` 738 input_branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()] 739 input_branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()] 740 shared_output_branch = [Concatenate(), Dense(), Dense()] 741 742 model = get_multi_io_model(input_branch_a, input_branch_b, 743 shared_output_branch=shared_output_branch) 744 ``` 745 To build a one-input, two-output model: 746 Specify a list of layers for branch a and branch b, and specify a 747 shared input branch. The resulting model will take one input, and apply 748 the shared input branch to it. It will then respectively apply each branch 749 to that intermediate result in parallel, to produce two outputs. 750 751 The first value in the shared_input_branch must be the Keras 'Input' layer 752 for the whole model. Branch a and branch b should not contain any Input 753 layers. 754 755 example usage: 756 ``` 757 shared_input_branch = [Input(shape=(2,), name='in'), Dense(), Dense()] 758 output_branch_a = [Dense(), Dense()] 759 output_branch_b = [Dense(), Dense()] 760 761 762 model = get_multi_io_model(output__branch_a, output_branch_b, 763 shared_input_branch=shared_input_branch) 764 ``` 765 766 Args: 767 branch_a: A sequence of layers for branch a of the model. 768 branch_b: A sequence of layers for branch b of the model. 769 shared_input_branch: An optional sequence of layers to apply to a single 770 input, before applying both branches to that intermediate result. If set, 771 the model will take only one input instead of two. Defaults to None. 772 shared_output_branch: An optional sequence of layers to merge the 773 intermediate results produced by branch a and branch b. If set, 774 the model will produce only one output instead of two. Defaults to None. 775 776 Returns: 777 A multi-io model of the type specified by `get_model_type`, specified 778 by the different branches. 779 """ 780 # Extract the functional inputs from the layer lists 781 if shared_input_branch: 782 inputs = shared_input_branch[0] 783 shared_input_branch = shared_input_branch[1:] 784 else: 785 inputs = branch_a[0], branch_b[0] 786 branch_a = branch_a[1:] 787 branch_b = branch_b[1:] 788 789 model_type = get_model_type() 790 if model_type == 'subclass': 791 return _MultiIOSubclassModel(branch_a, branch_b, shared_input_branch, 792 shared_output_branch) 793 794 if model_type == 'subclass_custom_build': 795 return _MultiIOSubclassModelCustomBuild((lambda: branch_a), 796 (lambda: branch_b), 797 (lambda: shared_input_branch), 798 (lambda: shared_output_branch)) 799 800 if model_type == 'sequential': 801 raise ValueError('Cannot use `get_multi_io_model` to construct ' 802 'sequential models') 803 804 if model_type == 'functional': 805 if shared_input_branch: 806 a_and_b = inputs 807 for layer in shared_input_branch: 808 a_and_b = layer(a_and_b) 809 a = a_and_b 810 b = a_and_b 811 else: 812 a, b = inputs 813 814 for layer in branch_a: 815 a = layer(a) 816 for layer in branch_b: 817 b = layer(b) 818 outputs = a, b 819 820 if shared_output_branch: 821 for layer in shared_output_branch: 822 outputs = layer(outputs) 823 824 return keras.Model(inputs, outputs) 825 826 raise ValueError('Unknown model type {}'.format(model_type)) 827 828 829_V2_OPTIMIZER_MAP = { 830 'adadelta': adadelta_v2.Adadelta, 831 'adagrad': adagrad_v2.Adagrad, 832 'adam': adam_v2.Adam, 833 'adamax': adamax_v2.Adamax, 834 'nadam': nadam_v2.Nadam, 835 'rmsprop': rmsprop_v2.RMSprop, 836 'sgd': gradient_descent_v2.SGD 837} 838 839 840def get_v2_optimizer(name, **kwargs): 841 """Get the v2 optimizer requested. 842 843 This is only necessary until v2 are the default, as we are testing in Eager, 844 and Eager + v1 optimizers fail tests. When we are in v2, the strings alone 845 should be sufficient, and this mapping can theoretically be removed. 846 847 Args: 848 name: string name of Keras v2 optimizer. 849 **kwargs: any kwargs to pass to the optimizer constructor. 850 851 Returns: 852 Initialized Keras v2 optimizer. 853 854 Raises: 855 ValueError: if an unknown name was passed. 856 """ 857 try: 858 return _V2_OPTIMIZER_MAP[name](**kwargs) 859 except KeyError: 860 raise ValueError( 861 'Could not find requested v2 optimizer: {}\nValid choices: {}'.format( 862 name, list(_V2_OPTIMIZER_MAP.keys()))) 863 864 865def get_expected_metric_variable_names(var_names, name_suffix=''): 866 """Returns expected metric variable names given names and prefix/suffix.""" 867 if tf2.enabled() or context.executing_eagerly(): 868 # In V1 eager mode and V2 variable names are not made unique. 869 return [n + ':0' for n in var_names] 870 # In V1 graph mode variable names are made unique using a suffix. 871 return [n + name_suffix + ':0' for n in var_names] 872 873 874def enable_v2_dtype_behavior(fn): 875 """Decorator for enabling the layer V2 dtype behavior on a test.""" 876 return _set_v2_dtype_behavior(fn, True) 877 878 879def disable_v2_dtype_behavior(fn): 880 """Decorator for disabling the layer V2 dtype behavior on a test.""" 881 return _set_v2_dtype_behavior(fn, False) 882 883 884def _set_v2_dtype_behavior(fn, enabled): 885 """Returns version of 'fn' that runs with v2 dtype behavior on or off.""" 886 @functools.wraps(fn) 887 def wrapper(*args, **kwargs): 888 v2_dtype_behavior = base_layer_utils.V2_DTYPE_BEHAVIOR 889 base_layer_utils.V2_DTYPE_BEHAVIOR = enabled 890 try: 891 return fn(*args, **kwargs) 892 finally: 893 base_layer_utils.V2_DTYPE_BEHAVIOR = v2_dtype_behavior 894 895 return tf_decorator.make_decorator(fn, wrapper) 896