• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for unit-testing Keras."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import threading
23
24import numpy as np
25
26from tensorflow.python import keras
27from tensorflow.python import tf2
28from tensorflow.python.eager import context
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.framework import test_util
32from tensorflow.python.keras.engine import base_layer_utils
33from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
34from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
35from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
36from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
37from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
38from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
39from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
40from tensorflow.python.util import tf_contextlib
41from tensorflow.python.util import tf_decorator
42from tensorflow.python.util import tf_inspect
43
44
def get_test_data(train_samples,
                  test_samples,
                  input_shape,
                  num_classes,
                  random_seed=None):
  """Generates test data to train a model on.

  Each class gets a random "template" array with values in
  `[0, 2 * num_classes)`; every sample is the template of its randomly drawn
  class plus Gaussian noise with unit standard deviation.

  Arguments:
    train_samples: Integer, how many training samples to generate.
    test_samples: Integer, how many test samples to generate.
    input_shape: Sequence of integers (tuple or list), shape of the inputs.
    num_classes: Integer, number of classes for the data and targets.
    random_seed: Integer, random seed used by numpy to generate data. When
      given, the global numpy RNG is reseeded with it.

  Returns:
    A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    The `x_*` arrays are float32 with shape `(samples,) + input_shape`; the
    `y_*` arrays hold integer class indices in `[0, num_classes)`.
  """
  if random_seed is not None:
    np.random.seed(random_seed)
  # Normalize to a tuple so that list shapes can be concatenated with the
  # leading class/sample dimension below (tuple + list raises TypeError).
  input_shape = tuple(input_shape)
  num_sample = train_samples + test_samples
  templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
  y = np.random.randint(0, num_classes, size=(num_sample,))
  x = np.zeros((num_sample,) + input_shape, dtype=np.float32)
  for i in range(num_sample):
    x[i] = templates[y[i]] + np.random.normal(loc=0, scale=1., size=input_shape)
  return ((x[:train_samples], y[:train_samples]),
          (x[train_samples:], y[train_samples:]))
72
73
@test_util.disable_cudnn_autotune
def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
               input_data=None, expected_output=None,
               expected_output_dtype=None, expected_output_shape=None,
               validate_training=True, adapt_data=None):
  """Test routine for a layer with a single input and single output.

  The layer is exercised, in order: optional `adapt()`, weight get/set,
  re-instantiation from its weights (only when `layer_cls.__init__` declares
  a `weights` parameter), use in a functional model (output dtype/shape
  checks, shape inference via `compute_output_shape` and
  `compute_output_signature`, config round-trip), an optional
  `train_on_batch`, and finally use as the first layer of a `Sequential`
  model rebuilt from the layer's config (with another config round-trip).

  Arguments:
    layer_cls: Layer class object.
    kwargs: Optional dictionary of keyword arguments for instantiating the
      layer. It is copied, never mutated.
    input_shape: Input shape tuple. `None` entries are replaced by a random
      size in [1, 3] when input data has to be generated.
    input_dtype: Data type of the input data.
    input_data: Numpy array of input data.
    expected_output: Numpy array of the expected output.
    expected_output_dtype: Data type expected for the output.
    expected_output_shape: Shape tuple for the expected shape of the output.
    validate_training: Whether to attempt to validate training on this layer.
      This might be set to False for non-differentiable layers that output
      string or integer values.
    adapt_data: Optional data for an 'adapt' call. If None, adapt() will not
      be tested for this layer. This is only relevant for PreprocessingLayers.

  Returns:
    The output data (Numpy array) returned by the layer, for additional
    checks to be done by the calling code.

  Raises:
    ValueError: if both `input_data` and `input_shape` are None.
    AssertionError: if any dtype, shape or value check fails.
  """
  if input_data is None:
    if input_shape is None:
      raise ValueError('input_shape is None')
    if not input_dtype:
      input_dtype = 'float32'
    # Replace unknown (None) dimensions with a small random size.
    input_data_shape = list(input_shape)
    for i, e in enumerate(input_data_shape):
      if e is None:
        input_data_shape[i] = np.random.randint(1, 4)
    input_data = 10 * np.random.random(input_data_shape)
    if input_dtype[:5] == 'float':
      input_data -= 0.5
    input_data = input_data.astype(input_dtype)
  elif input_shape is None:
    input_shape = input_data.shape
  if input_dtype is None:
    input_dtype = input_data.dtype
  if expected_output_dtype is None:
    expected_output_dtype = input_dtype

  # instantiation. Copy kwargs so the caller's dict is not mutated when
  # `weights` is injected below.
  kwargs = dict(kwargs or {})
  layer = layer_cls(**kwargs)

  # Test adapt, if data was passed.
  if adapt_data is not None:
    layer.adapt(adapt_data)

  # test get_weights , set_weights at layer level
  weights = layer.get_weights()
  layer.set_weights(weights)

  # test and instantiation from weights. The parameter names live in the
  # ArgSpec's `.args` field; testing membership in the ArgSpec tuple itself
  # compares against its fields (a list, names, defaults), so it never
  # matched and this path was silently skipped.
  if 'weights' in tf_inspect.getargspec(layer_cls.__init__).args:
    kwargs['weights'] = weights
    layer = layer_cls(**kwargs)

  # test in functional API
  x = keras.layers.Input(shape=input_shape[1:], dtype=input_dtype)
  y = layer(x)
  if keras.backend.dtype(y) != expected_output_dtype:
    raise AssertionError('When testing layer %s, for input %s, found output '
                         'dtype=%s but expected to find %s.\nFull kwargs: %s' %
                         (layer_cls.__name__,
                          x,
                          keras.backend.dtype(y),
                          expected_output_dtype,
                          kwargs))

  def assert_shapes_equal(expected, actual):
    """Asserts that the output shape from the layer matches the actual shape.

    `None` dimensions in `expected` act as wildcards; ranks must match.
    """
    if len(expected) != len(actual):
      raise AssertionError(
          'When testing layer %s, for input %s, found output_shape='
          '%s but expected to find %s.\nFull kwargs: %s' %
          (layer_cls.__name__, x, actual, expected, kwargs))

    for expected_dim, actual_dim in zip(expected, actual):
      if isinstance(expected_dim, tensor_shape.Dimension):
        expected_dim = expected_dim.value
      if isinstance(actual_dim, tensor_shape.Dimension):
        actual_dim = actual_dim.value
      if expected_dim is not None and expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s, for input %s, found output_shape='
            '%s but expected to find %s.\nFull kwargs: %s' %
            (layer_cls.__name__, x, actual, expected, kwargs))

  if expected_output_shape is not None:
    assert_shapes_equal(tensor_shape.TensorShape(expected_output_shape),
                        y.shape)

  # check shape inference
  model = keras.models.Model(x, y)
  computed_output_shape = tuple(
      layer.compute_output_shape(
          tensor_shape.TensorShape(input_shape)).as_list())
  computed_output_signature = layer.compute_output_signature(
      tensor_spec.TensorSpec(shape=input_shape, dtype=input_dtype))
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  assert_shapes_equal(computed_output_shape, actual_output_shape)
  assert_shapes_equal(computed_output_signature.shape, actual_output_shape)
  if computed_output_signature.dtype != actual_output.dtype:
    raise AssertionError(
        'When testing layer %s, for input %s, found output_dtype='
        '%s but expected to find %s.\nFull kwargs: %s' %
        (layer_cls.__name__, x, actual_output.dtype,
         computed_output_signature.dtype, kwargs))
  if expected_output is not None:
    np.testing.assert_allclose(actual_output, expected_output,
                               rtol=1e-3, atol=1e-6)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = keras.models.Model.from_config(model_config)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6)

  # test training mode (e.g. useful for dropout tests)
  # Rebuild the model to avoid the graph being reused between predict() and
  # train_on_batch(). See b/120160788 for more details. This should be
  # mitigated after 2.0.
  if validate_training:
    model = keras.models.Model(x, layer(x))
    if _thread_local_data.run_eagerly is not None:
      model.compile(
          'rmsprop',
          'mse',
          weighted_metrics=['acc'],
          run_eagerly=should_run_eagerly())
    else:
      model.compile('rmsprop', 'mse', weighted_metrics=['acc'])
    model.train_on_batch(input_data, actual_output)

  # test as first layer in Sequential API
  layer_config = layer.get_config()
  layer_config['batch_input_shape'] = input_shape
  layer = layer.__class__.from_config(layer_config)

  # Test adapt, if data was passed.
  if adapt_data is not None:
    layer.adapt(adapt_data)

  model = keras.models.Sequential()
  model.add(layer)
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  for expected_dim, actual_dim in zip(computed_output_shape,
                                      actual_output_shape):
    if expected_dim is not None:
      if expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s **after deserialization**, '
            'for input %s, found output_shape='
            '%s but expected to find inferred shape %s.\nFull kwargs: %s' %
            (layer_cls.__name__,
             x,
             actual_output_shape,
             computed_output_shape,
             kwargs))
  if expected_output is not None:
    np.testing.assert_allclose(actual_output, expected_output,
                               rtol=1e-3, atol=1e-6)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = keras.models.Sequential.from_config(model_config)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6)

  # for further checks in the caller function
  return actual_output
262
263
264_thread_local_data = threading.local()
265_thread_local_data.model_type = None
266_thread_local_data.run_eagerly = None
267_thread_local_data.experimental_run_tf_function = None
268_thread_local_data.saved_model_format = None
269
270
@tf_contextlib.contextmanager
def model_type_scope(value):
  """Temporarily selects the model type under test for the current thread.

  While the scope is active, `get_model_type()` sees `value`. Whatever model
  type was set before entering is put back on exit, even if the body raises.

  Arguments:
     value: model type value

  Yields:
    The provided value.
  """
  outer_value, _thread_local_data.model_type = (
      _thread_local_data.model_type, value)
  try:
    yield value
  finally:
    # Reinstate the model type that was active outside this scope.
    _thread_local_data.model_type = outer_value
290
291
@tf_contextlib.contextmanager
def run_eagerly_scope(value):
  """Temporarily sets whether compiled test models should run eagerly.

  The setting is per-thread and is read by `should_run_eagerly()`. The
  previous setting is put back on exit, even if the body raises.

  Arguments:
     value: Bool specifying if we should run models eagerly in the active test.
     Should be True or False.

  Yields:
    The provided value.
  """
  outer_value, _thread_local_data.run_eagerly = (
      _thread_local_data.run_eagerly, value)
  try:
    yield value
  finally:
    # Reinstate the run-eagerly setting that was active outside this scope.
    _thread_local_data.run_eagerly = outer_value
312
313
def should_run_eagerly():
  """Returns whether the models we are testing should be run eagerly.

  The result is truthy only if the active `run_eagerly_scope()` requested
  eager execution and the current context is executing eagerly.

  Raises:
    ValueError: if no run-eagerly setting is active in this thread.
  """
  requested = _thread_local_data.run_eagerly
  if requested is None:
    raise ValueError('Cannot call `should_run_eagerly()` outside of a '
                     '`run_eagerly_scope()` or `run_all_keras_modes` '
                     'decorator.')
  return requested and context.executing_eagerly()
322
323
@tf_contextlib.contextmanager
def experimental_run_tf_function_scope(value):
  """Temporarily sets the `experimental_run_tf_function` test setting.

  The setting is per-thread and is read by `should_run_tf_function()`. The
  original docs describe it as selecting whether models are compiled to run
  with (default) distribution. The previous setting is put back on exit, even
  if the body raises.

  Arguments:
     value: Bool specifying if we should run models with default distribution
     in the active test. Should be True or False.

  Yields:
    The provided value.
  """
  outer_value, _thread_local_data.experimental_run_tf_function = (
      _thread_local_data.experimental_run_tf_function, value)
  try:
    yield value
  finally:
    # Reinstate the run-tf-function setting active outside this scope.
    _thread_local_data.experimental_run_tf_function = outer_value
344
345
def should_run_tf_function():
  """Returns whether the models we are testing should be run distributed.

  The result is truthy only if the active
  `experimental_run_tf_function_scope()` enabled the setting and the current
  context is executing eagerly.

  Raises:
    ValueError: if no run-tf-function setting is active in this thread.
  """
  requested = _thread_local_data.experimental_run_tf_function
  if requested is None:
    raise ValueError(
        'Cannot call `should_run_tf_function()` outside of a '
        '`experimental_run_tf_function_scope()` or `run_all_keras_modes` '
        'decorator.')
  return requested and context.executing_eagerly()
356
357
@tf_contextlib.contextmanager
def saved_model_format_scope(value):
  """Temporarily selects the saved model format to test.

  The setting is per-thread and is read by `get_save_format()`. The previous
  format is put back on exit, even if the body raises.

  Arguments:
     value: saved model format value

  Yields:
    The provided value.
  """
  outer_value, _thread_local_data.saved_model_format = (
      _thread_local_data.saved_model_format, value)
  try:
    yield value
  finally:
    # Reinstate the saved model format that was active outside this scope.
    _thread_local_data.saved_model_format = outer_value
378
379
def get_save_format():
  """Returns the saved model format selected for the current test.

  Returns:
    The value set by the innermost active `saved_model_format_scope()`.

  Raises:
    ValueError: if no saved model format is set in this thread.
  """
  save_format = _thread_local_data.saved_model_format
  if save_format is not None:
    return save_format
  raise ValueError(
      'Cannot call `get_save_format()` outside of a '
      '`saved_model_format_scope()` or `run_with_all_saved_model_formats` '
      'decorator.')
387
388
def get_model_type():
  """Gets the model type that should be tested.

  Returns:
    The value set by the innermost active `model_type_scope()`.

  Raises:
    ValueError: if no model type is set in this thread.
  """
  model_type = _thread_local_data.model_type
  if model_type is not None:
    return model_type
  raise ValueError('Cannot call `get_model_type()` outside of a '
                   '`model_type_scope()` or `run_with_all_model_types` '
                   'decorator.')
397
398
def get_small_sequential_mlp(num_hidden, num_classes, input_dim=None):
  """Builds a two-layer `Sequential` MLP.

  Args:
    num_hidden: Units in the hidden ReLU layer.
    num_classes: Units in the output layer. A value of 1 gives a sigmoid
      output; any other value gives a softmax output.
    input_dim: Optional input feature size. If truthy, it is passed to the
      first layer so the model is built eagerly; otherwise the model is left
      unbuilt.

  Returns:
    The `keras.models.Sequential` model.
  """
  model = keras.models.Sequential()
  hidden_kwargs = {'activation': 'relu'}
  if input_dim:
    hidden_kwargs['input_dim'] = input_dim
  model.add(keras.layers.Dense(num_hidden, **hidden_kwargs))
  output_activation = 'softmax' if num_classes != 1 else 'sigmoid'
  model.add(keras.layers.Dense(num_classes, activation=output_activation))
  return model
409
410
def get_small_functional_mlp(num_hidden, num_classes, input_dim):
  """Builds a two-layer functional MLP over `input_dim`-feature inputs.

  Args:
    num_hidden: Units in the hidden ReLU layer.
    num_classes: Units in the output layer (sigmoid if 1, else softmax).
    input_dim: Size of the input feature vector.

  Returns:
    A functional `keras.Model`.
  """
  model_input = keras.Input(shape=(input_dim,))
  hidden = keras.layers.Dense(num_hidden, activation='relu')(model_input)
  output_activation = 'softmax' if num_classes != 1 else 'sigmoid'
  prediction = keras.layers.Dense(
      num_classes, activation=output_activation)(hidden)
  return keras.Model(model_input, prediction)
417
418
class SmallSubclassMLP(keras.Model):
  """A subclass model based small MLP.

  Computes `layer_b(bn(dp(layer_a(x))))`. The dropout stage (`use_dp`) and the
  batch-normalization stage (`use_bn`) are present only when enabled.
  """

  def __init__(self, num_hidden, num_classes, use_bn=False, use_dp=False):
    super(SmallSubclassMLP, self).__init__(name='test_model')
    # NOTE: attribute assignment order is kept as in the original; it
    # presumably affects Keras layer/weight tracking order.
    self.use_bn = use_bn
    self.use_dp = use_dp

    self.layer_a = keras.layers.Dense(num_hidden, activation='relu')
    output_activation = 'softmax' if num_classes != 1 else 'sigmoid'
    self.layer_b = keras.layers.Dense(num_classes, activation=output_activation)
    if self.use_dp:
      self.dp = keras.layers.Dropout(0.5)
    if self.use_bn:
      self.bn = keras.layers.BatchNormalization(axis=-1)

  def call(self, inputs, **kwargs):
    hidden = self.layer_a(inputs)
    if self.use_dp:
      hidden = self.dp(hidden)
    if self.use_bn:
      hidden = self.bn(hidden)
    return self.layer_b(hidden)
442
443
class _SmallSubclassMLPCustomBuild(keras.Model):
  """A two-layer MLP whose Dense layers are only created in `build()`."""

  def __init__(self, num_hidden, num_classes):
    super(_SmallSubclassMLPCustomBuild, self).__init__()
    # Placeholders: the real layers are created lazily in `build`.
    self.layer_a = None
    self.layer_b = None
    self.num_hidden = num_hidden
    self.num_classes = num_classes

  def build(self, input_shape):
    self.layer_a = keras.layers.Dense(self.num_hidden, activation='relu')
    output_activation = 'softmax' if self.num_classes != 1 else 'sigmoid'
    self.layer_b = keras.layers.Dense(self.num_classes,
                                      activation=output_activation)

  def call(self, inputs, **kwargs):
    return self.layer_b(self.layer_a(inputs))
462
463
def get_small_subclass_mlp(num_hidden, num_classes):
  """Returns a `SmallSubclassMLP` without dropout or batch normalization."""
  model = SmallSubclassMLP(num_hidden, num_classes)
  return model
466
467
def get_small_subclass_mlp_with_custom_build(num_hidden, num_classes):
  """Returns a small subclassed MLP whose layers are created in `build()`."""
  model = _SmallSubclassMLPCustomBuild(num_hidden, num_classes)
  return model
470
471
def get_small_mlp(num_hidden, num_classes, input_dim):
  """Get a small mlp of the model type specified by `get_model_type`.

  Raises:
    ValueError: if the active model type is not one of 'subclass',
      'subclass_custom_build', 'sequential' or 'functional'.
  """
  model_type = get_model_type()
  builders = (
      ('subclass',
       lambda: get_small_subclass_mlp(num_hidden, num_classes)),
      ('subclass_custom_build',
       lambda: get_small_subclass_mlp_with_custom_build(num_hidden,
                                                        num_classes)),
      ('sequential',
       lambda: get_small_sequential_mlp(num_hidden, num_classes, input_dim)),
      ('functional',
       lambda: get_small_functional_mlp(num_hidden, num_classes, input_dim)),
  )
  for candidate, build_model in builders:
    if model_type == candidate:
      return build_model()
  raise ValueError('Unknown model type {}'.format(model_type))
484
485
class _SubclassModel(keras.Model):
  """A Keras subclass model that applies a list of layers in sequence."""

  def __init__(self, layers, *args, **kwargs):
    """Instantiate a model.

    Args:
      layers: a list of layers to be added to the model.
      *args: Model's args
      **kwargs: Model's keyword args, at most one of
        input_tensor -> the input tensor required for ragged/sparse input.
    """
    input_tensor = kwargs.pop('input_tensor', None)
    super(_SubclassModel, self).__init__(*args, **kwargs)
    # Clone and build don't support lists of layers in subclassed models, so
    # each layer gets its own attribute: layer0, layer1, ...
    for index, layer in enumerate(layers):
      setattr(self, self._layer_name_for_i(index), layer)

    self.num_layers = len(layers)

    if input_tensor is not None:
      self._set_inputs(input_tensor)

  def _layer_name_for_i(self, i):
    """Returns the attribute name under which layer `i` is stored."""
    return 'layer%d' % i

  def call(self, inputs, **kwargs):
    outputs = inputs
    for index in range(self.num_layers):
      outputs = getattr(self, self._layer_name_for_i(index))(outputs)
    return outputs
520
521
class _SubclassModelCustomBuild(keras.Model):
  """A Keras subclass model whose layers are produced lazily in `build()`.

  `layer_generating_func` is called with no arguments at build time; the
  layers it yields are then applied in order by `call`.
  """

  def __init__(self, layer_generating_func, *args, **kwargs):
    super(_SubclassModelCustomBuild, self).__init__(*args, **kwargs)
    self.all_layers = None
    self._layer_generating_func = layer_generating_func

  def build(self, input_shape):
    self.all_layers = list(self._layer_generating_func())

  def call(self, inputs, **kwargs):
    outputs = inputs
    for layer in self.all_layers:
      outputs = layer(outputs)
    return outputs
541
542
def get_model_from_layers(layers,
                          input_shape=None,
                          input_dtype=None,
                          name=None,
                          input_ragged=None,
                          input_sparse=None):
  """Builds a model of the type given by `get_model_type()` from layers.

  Args:
    layers: The layers used to build the network, applied in order.
    input_shape: Shape tuple of the input or 'TensorShape' instance.
    input_dtype: Datatype of the input.
    name: Name for the model.
    input_ragged: Boolean, whether the input data is a ragged tensor.
    input_sparse: Boolean, whether the input data is a sparse tensor.

  Returns:
    A Keras model.

  Raises:
    ValueError: if the model type is 'functional' but `input_shape` is falsy,
      or if the model type is not recognized.
  """
  model_type = get_model_type()
  # Input-construction keyword arguments shared by every branch below.
  input_kwargs = dict(dtype=input_dtype, ragged=input_ragged,
                      sparse=input_sparse)

  if model_type == 'subclass':
    # An explicit input tensor is only built for ragged or sparse data.
    input_tensor = None
    if input_ragged or input_sparse:
      input_tensor = keras.Input(shape=input_shape, **input_kwargs)
    return _SubclassModel(layers, name=name, input_tensor=input_tensor)

  if model_type == 'subclass_custom_build':
    return _SubclassModelCustomBuild(lambda: layers, name=name)

  if model_type == 'sequential':
    model = keras.models.Sequential(name=name)
    if input_shape:
      model.add(keras.layers.InputLayer(input_shape=input_shape,
                                        **input_kwargs))
    for layer in layers:
      model.add(layer)
    return model

  if model_type == 'functional':
    if not input_shape:
      raise ValueError('Cannot create a functional model from layers with no '
                       'input shape.')
    inputs = keras.Input(shape=input_shape, **input_kwargs)
    outputs = inputs
    for layer in layers:
      outputs = layer(outputs)
    return keras.Model(inputs, outputs, name=name)

  raise ValueError('Unknown model type {}'.format(model_type))
606
607
class Bias(keras.layers.Layer):
  """Adds a single bias variable of shape `(1,)`, initialized to zeros.

  The bias is broadcast-added to every element of the input.
  """

  def build(self, input_shape):
    # `add_weight` is the supported API; `add_variable` is a deprecated alias
    # for it that emits a deprecation warning.
    self.bias = self.add_weight('bias', (1,), initializer='zeros')

  def call(self, inputs):
    return inputs + self.bias
615
616
class _MultiIOSubclassModel(keras.Model):
  """Multi IO Keras subclass model.

  Without `shared_input_branch`, the model expects a pair of inputs `(a, b)`.
  With it, a single input is fed through the shared layers and the result
  feeds both branches. The branch outputs are returned as the list `[a, b]`,
  or, if `shared_output_branch` is given, that list is passed through the
  shared output layers.
  """

  def __init__(self, branch_a, branch_b, shared_input_branch=None,
               shared_output_branch=None, name=None):
    super(_MultiIOSubclassModel, self).__init__(name=name)
    # NOTE: attribute assignment order is kept as in the original; it
    # presumably affects Keras layer/weight tracking order.
    self._shared_input_branch = shared_input_branch
    self._branch_a = branch_a
    self._branch_b = branch_b
    self._shared_output_branch = shared_output_branch

  def call(self, inputs, **kwargs):
    if not self._shared_input_branch:
      a, b = inputs
    else:
      shared = inputs
      for layer in self._shared_input_branch:
        shared = layer(shared)
      a = b = shared

    for layer in self._branch_a:
      a = layer(a)
    for layer in self._branch_b:
      b = layer(b)
    result = [a, b]

    if self._shared_output_branch:
      for layer in self._shared_output_branch:
        result = layer(result)

    return result
648
649
class _MultiIOSubclassModelCustomBuild(keras.Model):
  """Multi IO Keras subclass model that uses a custom build method.

  Every branch is given as a zero-argument factory returning a sequence of
  layers. The factories are called in `build()`. The shared input/output
  factories are optional and may be None.
  """

  def __init__(self, branch_a_func, branch_b_func,
               shared_input_branch_func=None,
               shared_output_branch_func=None):
    super(_MultiIOSubclassModelCustomBuild, self).__init__()
    self._shared_input_branch_func = shared_input_branch_func
    self._branch_a_func = branch_a_func
    self._branch_b_func = branch_b_func
    self._shared_output_branch_func = shared_output_branch_func

    self._shared_input_branch = None
    self._branch_a = None
    self._branch_b = None
    self._shared_output_branch = None

  def build(self, input_shape):
    # The shared-branch factories default to None, so only call them when
    # provided (calling None raised TypeError). Each factory is called once;
    # a falsy result leaves that branch unset.
    if self._shared_input_branch_func is not None:
      shared_input_branch = self._shared_input_branch_func()
      if shared_input_branch:
        self._shared_input_branch = shared_input_branch
    self._branch_a = self._branch_a_func()
    self._branch_b = self._branch_b_func()

    if self._shared_output_branch_func is not None:
      shared_output_branch = self._shared_output_branch_func()
      if shared_output_branch:
        self._shared_output_branch = shared_output_branch

  def call(self, inputs, **kwargs):
    if self._shared_input_branch:
      for layer in self._shared_input_branch:
        inputs = layer(inputs)
      a = inputs
      b = inputs
    else:
      a, b = inputs

    for layer in self._branch_a:
      a = layer(a)
    for layer in self._branch_b:
      b = layer(b)
    outs = a, b

    if self._shared_output_branch:
      for layer in self._shared_output_branch:
        outs = layer(outs)

    return outs
696
697
def get_multi_io_model(
    branch_a,
    branch_b,
    shared_input_branch=None,
    shared_output_branch=None):
  """Builds a multi-io model that contains two branches.

  The produced model will be of the type specified by `get_model_type`.

  To build a two-input, two-output model:
    Specify a list of layers for branch a and branch b, but do not specify any
    shared input branch or shared output branch. The resulting model will apply
    each branch to a different input, to produce two outputs.

    The first value in branch_a must be the Keras 'Input' layer for branch a,
    and the first value in branch_b must be the Keras 'Input' layer for
    branch b.

    example usage:
    ```
    branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()]
    branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()]

    model = get_multi_io_model(branch_a, branch_b)
    ```

  To build a two-input, one-output model:
    Specify a list of layers for branch a and branch b, and specify a
    shared output branch. The resulting model will apply
    each branch to a different input. It will then apply the shared output
    branch to a tuple containing the intermediate outputs of each branch,
    to produce a single output. The first layer in the shared_output_branch
    must be able to merge a tuple of two tensors.

    The first value in branch_a must be the Keras 'Input' layer for branch a,
    and the first value in branch_b must be the Keras 'Input' layer for
    branch b.

    example usage:
    ```
    input_branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()]
    input_branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()]
    shared_output_branch = [Concatenate(), Dense(), Dense()]

    model = get_multi_io_model(input_branch_a, input_branch_b,
                               shared_output_branch=shared_output_branch)
    ```

  To build a one-input, two-output model:
    Specify a list of layers for branch a and branch b, and specify a
    shared input branch. The resulting model will take one input, and apply
    the shared input branch to it. It will then respectively apply each branch
    to that intermediate result in parallel, to produce two outputs.

    The first value in the shared_input_branch must be the Keras 'Input' layer
    for the whole model. Branch a and branch b should not contain any Input
    layers.

    example usage:
    ```
    shared_input_branch = [Input(shape=(2,), name='in'), Dense(), Dense()]
    output_branch_a = [Dense(), Dense()]
    output_branch_b = [Dense(), Dense()]

    model = get_multi_io_model(output_branch_a, output_branch_b,
                               shared_input_branch=shared_input_branch)
    ```

  The leading 'Input' entries are always stripped from the given lists. Only
  the functional model type uses them; the subclass model types ignore them.

  Args:
    branch_a: A sequence of layers for branch a of the model.
    branch_b: A sequence of layers for branch b of the model.
    shared_input_branch: An optional sequence of layers to apply to a single
      input, before applying both branches to that intermediate result. If set,
      the model will take only one input instead of two. Defaults to None.
    shared_output_branch: An optional sequence of layers to merge the
      intermediate results produced by branch a and branch b. If set,
      the model will produce only one output instead of two. Defaults to None.

  Returns:
    A multi-io model of the type specified by `get_model_type`, specified
    by the different branches.

  Raises:
    ValueError: if the model type is 'sequential' (not supported) or unknown.
  """
  # Extract the functional inputs from the layer lists
  if shared_input_branch:
    inputs = shared_input_branch[0]
    shared_input_branch = shared_input_branch[1:]
  else:
    inputs = branch_a[0], branch_b[0]
    branch_a = branch_a[1:]
    branch_b = branch_b[1:]

  model_type = get_model_type()
  if model_type == 'subclass':
    return _MultiIOSubclassModel(branch_a, branch_b, shared_input_branch,
                                 shared_output_branch)

  if model_type == 'subclass_custom_build':
    return _MultiIOSubclassModelCustomBuild((lambda: branch_a),
                                            (lambda: branch_b),
                                            (lambda: shared_input_branch),
                                            (lambda: shared_output_branch))

  if model_type == 'sequential':
    raise ValueError('Cannot use `get_multi_io_model` to construct '
                     'sequential models')

  if model_type == 'functional':
    if shared_input_branch:
      a_and_b = inputs
      for layer in shared_input_branch:
        a_and_b = layer(a_and_b)
      a = a_and_b
      b = a_and_b
    else:
      a, b = inputs

    for layer in branch_a:
      a = layer(a)
    for layer in branch_b:
      b = layer(b)
    outputs = a, b

    if shared_output_branch:
      for layer in shared_output_branch:
        outputs = layer(outputs)

    return keras.Model(inputs, outputs)

  raise ValueError('Unknown model type {}'.format(model_type))
827
828
# Maps the string names accepted by `get_v2_optimizer` to the optimizer_v2
# classes that are constructed for them.
_V2_OPTIMIZER_MAP = dict(
    adadelta=adadelta_v2.Adadelta,
    adagrad=adagrad_v2.Adagrad,
    adam=adam_v2.Adam,
    adamax=adamax_v2.Adamax,
    nadam=nadam_v2.Nadam,
    rmsprop=rmsprop_v2.RMSprop,
    sgd=gradient_descent_v2.SGD,
)
838
839
def get_v2_optimizer(name, **kwargs):
  """Get the v2 optimizer requested.

  This is only necessary until v2 are the default, as we are testing in Eager,
  and Eager + v1 optimizers fail tests. When we are in v2, the strings alone
  should be sufficient, and this mapping can theoretically be removed.

  Args:
    name: string name of Keras v2 optimizer.
    **kwargs: any kwargs to pass to the optimizer constructor.

  Returns:
    Initialized Keras v2 optimizer.

  Raises:
    ValueError: if an unknown name was passed.
  """
  try:
    optimizer_cls = _V2_OPTIMIZER_MAP[name]
  except KeyError:
    raise ValueError(
        'Could not find requested v2 optimizer: {}\nValid choices: {}'.format(
            name, list(_V2_OPTIMIZER_MAP.keys())))
  # Construct outside the `try`: a KeyError raised by the optimizer's own
  # constructor must propagate as-is, not be misreported as an unknown name.
  return optimizer_cls(**kwargs)
863
864
def get_expected_metric_variable_names(var_names, name_suffix=''):
  """Returns the expected metric variable names (with ':0') for `var_names`.

  In V2 or eager mode variable names are not made unique, so only ':0' is
  appended. In V1 graph mode, `name_suffix` is appended before ':0'.
  """
  in_v2_or_eager = tf2.enabled() or context.executing_eagerly()
  suffix = '' if in_v2_or_eager else name_suffix
  return [var_name + suffix + ':0' for var_name in var_names]
872
873
def enable_v2_dtype_behavior(fn):
  """Decorator that runs the test `fn` with layer V2 dtype behavior on."""
  wrapped = _set_v2_dtype_behavior(fn, True)
  return wrapped
877
878
def disable_v2_dtype_behavior(fn):
  """Decorator that runs the test `fn` with layer V2 dtype behavior off."""
  wrapped = _set_v2_dtype_behavior(fn, False)
  return wrapped
882
883
def _set_v2_dtype_behavior(fn, enabled):
  """Wraps `fn` so each call runs with `V2_DTYPE_BEHAVIOR` forced to `enabled`.

  The previous value of `base_layer_utils.V2_DTYPE_BEHAVIOR` is restored after
  every call, even if `fn` raises.
  """

  @functools.wraps(fn)
  def _run_with_flag(*args, **kwargs):
    saved_flag = base_layer_utils.V2_DTYPE_BEHAVIOR
    base_layer_utils.V2_DTYPE_BEHAVIOR = enabled
    try:
      return fn(*args, **kwargs)
    finally:
      base_layer_utils.V2_DTYPE_BEHAVIOR = saved_flag

  return tf_decorator.make_decorator(fn, _run_with_flag)
896