# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for layer graphs construction & handling."""

import warnings

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import initializers
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import models
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import input_layer as input_layer_lib
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
from tensorflow.python.training.tracking.util import Checkpoint


class NetworkConstructionTest(keras_parameterized.TestCase):

  def test_default_model_name(self):
    inputs = input_layer_lib.Input(shape=(1,))
    outputs = layers.Dense(1, activation='relu')(inputs)
    model = training_lib.Model(inputs=inputs, outputs=outputs)
    self.assertEqual(model.name, 'model')

    model_2 = training_lib.Model(inputs=inputs, outputs=outputs)
    self.assertEqual(model_2.name, 'model_1')

    model_3 = training_lib.Model(inputs=inputs, outputs=outputs)
    self.assertEqual(model_3.name, 'model_2')

  def test_get_updates(self):

    class MyLayer(layers.Layer):

      def build(self, input_shape):
        self.a = self.add_variable('a',
                                   (1, 1),
                                   'float32',
                                   trainable=False)
        self.b = self.add_variable('b',
                                   (1, 1),
                                   'float32',
                                   trainable=False)
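        # An update created in `build` is unconditional: it is not tied to
        # any particular layer inputs.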
        self.add_update(state_ops.assign_add(self.a, [[1.]],
                                             name='unconditional_update'))
        self.built = True

      def call(self, inputs):
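        # An update created in `call` with `inputs=True` is conditional on
        # the inputs of that particular call.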
        self.add_update(state_ops.assign_add(self.b, inputs,
                                             name='conditional_update'),
                        inputs=True)
        return inputs + 1

    with ops.Graph().as_default():
      x1 = input_layer_lib.Input(shape=(1,))
      layer = MyLayer()
      _ = layer(x1)

      self.assertEqual(len(layer.updates), 2)

      x2 = input_layer_lib.Input(shape=(1,))
      y2 = layer(x2)

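      # The second call adds one more conditional update; the unconditional
      # update from `build` is shared across calls.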
      self.assertEqual(len(layer.updates), 3)

      network = functional.Functional(x2, y2)
      self.assertEqual(len(network.updates), 3)

      x3 = input_layer_lib.Input(shape=(1,))
      _ = layer(x3)
      self.assertEqual(len(network.updates), 4)

      x4 = input_layer_lib.Input(shape=(1,))
      _ = network(x4)
      self.assertEqual(len(network.updates), 5)

      network.add_update(state_ops.assign_add(layer.a, [[1]]))
      self.assertEqual(len(network.updates), 6)

      network.add_update(state_ops.assign_add(layer.b, x4), inputs=True)
      self.assertEqual(len(network.updates), 7)

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_get_updates_bn(self):
    x1 = input_layer_lib.Input(shape=(1,))
    layer = layers.BatchNormalization()
    _ = layer(x1)

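    # BatchNormalization registers two updates per call: one for the moving
    # mean and one for the moving variance.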
    self.assertEqual(len(layer.updates), 2)

  def test_get_layer(self):
    # create a simple network
    x = input_layer_lib.Input(shape=(32,))
    dense_a = layers.Dense(4, name='dense_a')
    dense_b = layers.Dense(2, name='dense_b')
    y = dense_b(dense_a(x))
    network = functional.Functional(x, y, name='dense_network')

    # test various get_layer by index
    self.assertEqual(network.get_layer(index=1), dense_a)

    # test invalid get_layer by index
    with self.assertRaisesRegex(
        ValueError, 'Was asked to retrieve layer at index ' + str(3) +
        ' but model only has ' + str(len(network.layers)) + ' layers.'):
      network.get_layer(index=3)

    # test that only one of name or index may be provided
    with self.assertRaisesRegex(ValueError,
                                'Provide only a layer name or a layer index'):
      network.get_layer(index=1, name='dense_b')

    # test that a name or an index must be provided
    with self.assertRaisesRegex(ValueError,
                                'Provide either a layer name or layer index.'):
      network.get_layer()

    # test various get_layer by name
    self.assertEqual(network.get_layer(name='dense_a'), dense_a)

    # test invalid get_layer by name
    with self.assertRaisesRegex(ValueError, 'No such layer: dense_c.'):
      network.get_layer(name='dense_c')

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTopologicalAttributes(self):
    # test layer attributes / methods related to cross-layer connectivity.
    a = input_layer_lib.Input(shape=(32,), name='input_a')
    b = input_layer_lib.Input(shape=(32,), name='input_b')

    # test input, output, input_shape, output_shape
    test_layer = layers.Dense(16, name='test_layer')
    a_test = test_layer(a)
    self.assertIs(test_layer.input, a)
    self.assertIs(test_layer.output, a_test)
    self.assertEqual(test_layer.input_shape, (None, 32))
    self.assertEqual(test_layer.output_shape, (None, 16))

    # test `get_*_at` methods
    dense = layers.Dense(16, name='dense_1')
    a_2 = dense(a)
    b_2 = dense(b)

    self.assertIs(dense.get_input_at(0), a)
    self.assertIs(dense.get_input_at(1), b)
    self.assertIs(dense.get_output_at(0), a_2)
    self.assertIs(dense.get_output_at(1), b_2)
    self.assertEqual(dense.get_input_shape_at(0), (None, 32))
    self.assertEqual(dense.get_input_shape_at(1), (None, 32))
    self.assertEqual(dense.get_output_shape_at(0), (None, 16))
    self.assertEqual(dense.get_output_shape_at(1), (None, 16))

    # Test invalid value for attribute retrieval.
    with self.assertRaises(ValueError):
      dense.get_input_at(2)
    with self.assertRaises(AttributeError):
      new_dense = layers.Dense(16)
      _ = new_dense.input
    with self.assertRaises(AttributeError):
      new_dense = layers.Dense(16)
      _ = new_dense.output
    with self.assertRaises(AttributeError):
      new_dense = layers.Dense(16)
      _ = new_dense.output_shape
    with self.assertRaises(AttributeError):
      new_dense = layers.Dense(16)
      _ = new_dense.input_shape
    with self.assertRaises(AttributeError):
      new_dense = layers.Dense(16)
      a = input_layer_lib.Input(shape=(3, 32))
      a = input_layer_lib.Input(shape=(5, 32))
      a_2 = dense(a)
      b_2 = dense(b)
      _ = new_dense.input_shape
    with self.assertRaises(AttributeError):
      new_dense = layers.Dense(16)
      a = input_layer_lib.Input(shape=(3, 32))
      a = input_layer_lib.Input(shape=(5, 32))
      a_2 = dense(a)
      b_2 = dense(b)
      _ = new_dense.output_shape

  def _assertAllIs(self, a, b):
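    # Element-wise object-identity check (`is`), stricter than equality.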
    self.assertTrue(all(x is y for x, y in zip(a, b)))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTopologicalAttributesMultiOutputLayer(self):

    class PowersLayer(layers.Layer):

      def call(self, inputs):
        return [inputs**2, inputs**3]

    x = input_layer_lib.Input(shape=(32,))
    test_layer = PowersLayer()
    p1, p2 = test_layer(x)  # pylint: disable=not-callable

    self.assertIs(test_layer.input, x)
    self._assertAllIs(test_layer.output, [p1, p2])
    self.assertEqual(test_layer.input_shape, (None, 32))
    self.assertEqual(test_layer.output_shape, [(None, 32), (None, 32)])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTopologicalAttributesMultiInputLayer(self):

    class AddLayer(layers.Layer):

      def call(self, inputs):
        assert len(inputs) == 2
        return inputs[0] + inputs[1]

    a = input_layer_lib.Input(shape=(32,))
    b = input_layer_lib.Input(shape=(32,))
    test_layer = AddLayer()
    y = test_layer([a, b])  # pylint: disable=not-callable

    self._assertAllIs(test_layer.input, [a, b])
    self.assertIs(test_layer.output, y)
    self.assertEqual(test_layer.input_shape, [(None, 32), (None, 32)])
    self.assertEqual(test_layer.output_shape, (None, 32))

  def testBasicNetwork(self):
    with ops.Graph().as_default():
      # minimum viable network
      x = input_layer_lib.Input(shape=(32,))
      dense = layers.Dense(2)
      y = dense(x)
      network = functional.Functional(x, y, name='dense_network')

      # test basic attributes
      self.assertEqual(network.name, 'dense_network')
      self.assertEqual(len(network.layers), 2)  # InputLayer + Dense
      self.assertEqual(network.layers[1], dense)
      self._assertAllIs(network.weights, dense.weights)
      self._assertAllIs(network.trainable_weights, dense.trainable_weights)
      self._assertAllIs(network.non_trainable_weights,
                        dense.non_trainable_weights)

      # test callability on Input
      x_2 = input_layer_lib.Input(shape=(32,))
      y_2 = network(x_2)
      self.assertEqual(y_2.shape.as_list(), [None, 2])

      # test callability on regular tensor
      x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32))
      y_2 = network(x_2)
      self.assertEqual(y_2.shape.as_list(), [None, 2])

      # test network `trainable` attribute
      network.trainable = False
      self._assertAllIs(network.weights, dense.weights)
      self.assertEqual(network.trainable_weights, [])
      self._assertAllIs(network.non_trainable_weights,
                        dense.trainable_weights + dense.non_trainable_weights)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_trainable_weights(self):
    a = layers.Input(shape=(2,))
    b = layers.Dense(1)(a)
    model = training_lib.Model(a, b)

    weights = model.weights
    self._assertAllIs(model.trainable_weights, weights)
    self.assertListEqual(model.non_trainable_weights, [])

    model.trainable = False
    self.assertListEqual(model.trainable_weights, [])
    self._assertAllIs(model.non_trainable_weights, weights)

    model.trainable = True
    self._assertAllIs(model.trainable_weights, weights)
    self.assertListEqual(model.non_trainable_weights, [])

    model.layers[1].trainable = False
    self.assertListEqual(model.trainable_weights, [])
    self._assertAllIs(model.non_trainable_weights, weights)

    # sequential model
    model = sequential.Sequential()
    model.add(layers.Dense(1, input_dim=2))
    weights = model.weights

    self._assertAllIs(model.trainable_weights, weights)
    self.assertListEqual(model.non_trainable_weights, [])

    model.trainable = False
    self.assertListEqual(model.trainable_weights, [])
    self._assertAllIs(model.non_trainable_weights, weights)

    model.trainable = True
    self._assertAllIs(model.trainable_weights, weights)
    self.assertListEqual(model.non_trainable_weights, [])

    model.layers[0].trainable = False
    self.assertListEqual(model.trainable_weights, [])
    self._assertAllIs(model.non_trainable_weights, weights)

  def test_layer_call_arguments(self):
    with ops.Graph().as_default():
      # Test the ability to pass and serialize arguments to `call`.
      inp = layers.Input(shape=(2,))
      x = layers.Dense(3)(inp)
      x = layers.Dropout(0.5)(x, training=True)
      model = training_lib.Model(inp, x)
      # Would be `dropout/cond/Merge` by default
      self.assertIn('dropout', model.output.op.name)

      # Test that argument is kept when applying the model
      inp2 = layers.Input(shape=(2,))
      out2 = model(inp2)
      self.assertIn('dropout', out2.op.name)

      # Test that argument is kept after loading a model
      config = model.get_config()
      model = training_lib.Model.from_config(config)
      self.assertIn('dropout', model.output.op.name)

  def test_node_construction(self):
    # test basics
    a = layers.Input(shape=(32,), name='input_a')
    b = layers.Input(shape=(32,), name='input_b')

    with self.assertRaises(ValueError):
      _ = layers.Input(shape=(32,), batch_shape=(10, 32))
    with self.assertRaises(ValueError):
      _ = layers.Input(shape=(32,), unknown_kwarg=None)

    self.assertListEqual(a.shape.as_list(), [None, 32])
    a_layer, a_node_index, a_tensor_index = a._keras_history
    b_layer, _, _ = b._keras_history
    self.assertEqual(len(a_layer._inbound_nodes), 1)
    self.assertEqual(a_tensor_index, 0)
    node = a_layer._inbound_nodes[a_node_index]
    self.assertEqual(node.outbound_layer, a_layer)

    self.assertListEqual(node.inbound_layers, [])
    self.assertListEqual(node.input_tensors, [a])
    self.assertListEqual(node.input_shapes, [(None, 32)])
    self.assertListEqual(node.output_tensors, [a])
    self.assertListEqual(node.output_shapes, [(None, 32)])

    dense = layers.Dense(16, name='dense_1')
    a_2 = dense(a)
    b_2 = dense(b)

    self.assertEqual(len(dense._inbound_nodes), 2)
    self.assertEqual(len(dense._outbound_nodes), 0)
    self.assertEqual(dense._inbound_nodes[0].inbound_layers, a_layer)
    self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense)
    self.assertEqual(dense._inbound_nodes[1].inbound_layers, b_layer)
    self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense)
    self.assertIs(dense._inbound_nodes[0].input_tensors, a)
    self.assertIs(dense._inbound_nodes[1].input_tensors, b)

    # test layer properties
    test_layer = layers.Dense(16, name='test_layer')
    a_test = test_layer(a)
    self.assertListEqual(test_layer.kernel.shape.as_list(), [32, 16])
    self.assertIs(test_layer.input, a)
    self.assertIs(test_layer.output, a_test)
    self.assertEqual(test_layer.input_shape, (None, 32))
    self.assertEqual(test_layer.output_shape, (None, 16))

    self.assertIs(dense.get_input_at(0), a)
    self.assertIs(dense.get_input_at(1), b)
    self.assertIs(dense.get_output_at(0), a_2)
    self.assertIs(dense.get_output_at(1), b_2)
    self.assertEqual(dense.get_input_shape_at(0), (None, 32))
    self.assertEqual(dense.get_input_shape_at(1), (None, 32))
    self.assertEqual(dense.get_output_shape_at(0), (None, 16))
    self.assertEqual(dense.get_output_shape_at(1), (None, 16))
    self.assertEqual(dense.get_input_mask_at(0), None)
    self.assertEqual(dense.get_input_mask_at(1), None)
    self.assertEqual(dense.get_output_mask_at(0), None)
    self.assertEqual(dense.get_output_mask_at(1), None)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_multi_input_layer(self):
    with self.cached_session():
      # test multi-input layer
      a = layers.Input(shape=(32,), name='input_a')
      b = layers.Input(shape=(32,), name='input_b')

      dense = layers.Dense(16, name='dense_1')
      a_2 = dense(a)
      b_2 = dense(b)

      merged = layers.concatenate([a_2, b_2], name='merge')
      self.assertListEqual(merged.shape.as_list(), [None, 16 * 2])
      merge_layer, merge_node_index, merge_tensor_index = merged._keras_history

      self.assertEqual(merge_node_index, 0)
      self.assertEqual(merge_tensor_index, 0)

      self.assertEqual(len(merge_layer._inbound_nodes), 1)
      self.assertEqual(len(merge_layer._outbound_nodes), 0)

      self.assertEqual(len(merge_layer._inbound_nodes[0].input_tensors), 2)
      self.assertEqual(len(merge_layer._inbound_nodes[0].inbound_layers), 2)

      c = layers.Dense(64, name='dense_2')(merged)
      d = layers.Dense(5, name='dense_3')(c)

      model = training_lib.Model(inputs=[a, b], outputs=[c, d], name='model')
      self.assertEqual(len(model.layers), 6)
      output_shapes = model.compute_output_shape([(None, 32), (None, 32)])
      self.assertListEqual(output_shapes[0].as_list(), [None, 64])
      self.assertListEqual(output_shapes[1].as_list(), [None, 5])
      self.assertListEqual(
          model.compute_mask([a, b], [None, None]), [None, None])

      # we don't check names of first 2 layers (inputs) because
      # ordering of same-level layers is not fixed
      self.assertListEqual([l.name for l in model.layers][2:],
                           ['dense_1', 'merge', 'dense_2', 'dense_3'])
      self.assertListEqual([l.name for l in model._input_layers],
                           ['input_a', 'input_b'])
      self.assertListEqual([l.name for l in model._output_layers],
                           ['dense_2', 'dense_3'])

      # actually run model
      fn = backend.function(model.inputs, model.outputs)
      input_a_np = np.random.random((10, 32))
      input_b_np = np.random.random((10, 32))
      fn_outputs = fn([input_a_np, input_b_np])
      self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)])

      # test get_source_inputs
      self._assertAllIs(layer_utils.get_source_inputs(c), [a, b])

      # serialization / deserialization
      json_config = model.to_json()
      recreated_model = models.model_from_json(json_config)
      recreated_model.compile('rmsprop', 'mse')

      self.assertListEqual([l.name for l in recreated_model.layers][2:],
                           ['dense_1', 'merge', 'dense_2', 'dense_3'])
      self.assertListEqual([l.name for l in recreated_model._input_layers],
                           ['input_a', 'input_b'])
      self.assertListEqual([l.name for l in recreated_model._output_layers],
                           ['dense_2', 'dense_3'])

      fn = backend.function(recreated_model.inputs, recreated_model.outputs)
      input_a_np = np.random.random((10, 32))
      input_b_np = np.random.random((10, 32))
      fn_outputs = fn([input_a_np, input_b_np])
      self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)])

  def test_multi_output_layer_output_names(self):
    inp = layers.Input(name='inp', shape=(None,), dtype=dtypes.float32)

    class _MultiOutput(layers.Layer):

      def call(self, x):
        return x + 1., x + 2.

    out = _MultiOutput(name='out')(inp)
    model = training_lib.Model(inp, out)
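    # Duplicate output names are uniquified with numeric suffixes.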
    self.assertEqual(['out', 'out_1'], model.output_names)
    self.assertAllClose([2., 3.], model(1.))

  def test_recursion(self):
    with ops.Graph().as_default(), self.cached_session():
      a = layers.Input(shape=(32,), name='input_a')
      b = layers.Input(shape=(32,), name='input_b')

      dense = layers.Dense(16, name='dense_1')
      a_2 = dense(a)
      b_2 = dense(b)
      merged = layers.concatenate([a_2, b_2], name='merge')
      c = layers.Dense(64, name='dense_2')(merged)
      d = layers.Dense(5, name='dense_3')(c)

      model = training_lib.Model(inputs=[a, b], outputs=[c, d], name='model')

      e = layers.Input(shape=(32,), name='input_e')
      f = layers.Input(shape=(32,), name='input_f')
      self.assertEqual(len(model.inputs), 2)
      g, h = model([e, f])
      self.assertEqual(len(model.inputs), 2)
      self.assertEqual(g.name, 'model/dense_2/BiasAdd:0')

      self.assertListEqual(g.shape.as_list(), c.shape.as_list())
      self.assertListEqual(h.shape.as_list(), d.shape.as_list())

      # test separate manipulation of different layer outputs
      i = layers.Dense(7, name='dense_4')(h)

      final_model = training_lib.Model(
          inputs=[e, f], outputs=[i, g], name='final')
      self.assertEqual(len(final_model.inputs), 2)
      self.assertEqual(len(final_model.outputs), 2)
      self.assertEqual(len(final_model.layers), 4)

      # we don't check names of first 2 layers (inputs) because
      # ordering of same-level layers is not fixed
      self.assertListEqual([layer.name for layer in final_model.layers][2:],
                           ['model', 'dense_4'])
      self.assertListEqual(
          model.compute_mask([e, f], [None, None]), [None, None])
      self.assertListEqual(
          final_model.compute_output_shape([(10, 32), (10, 32)]), [(10, 7),
                                                                   (10, 64)])

      # run recursive model
      fn = backend.function(final_model.inputs, final_model.outputs)
      input_a_np = np.random.random((10, 32))
      input_b_np = np.random.random((10, 32))
      fn_outputs = fn([input_a_np, input_b_np])
      self.assertListEqual([x.shape for x in fn_outputs], [(10, 7), (10, 64)])

      # test serialization
      model_config = final_model.get_config()
      recreated_model = models.Model.from_config(model_config)

      fn = backend.function(recreated_model.inputs, recreated_model.outputs)
      input_a_np = np.random.random((10, 32))
      input_b_np = np.random.random((10, 32))
      fn_outputs = fn([input_a_np, input_b_np])
      self.assertListEqual([x.shape for x in fn_outputs], [(10, 7), (10, 64)])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_multi_input_multi_output_recursion(self):
    with self.cached_session():
      # test multi-input multi-output
      a = layers.Input(shape=(32,), name='input_a')
      b = layers.Input(shape=(32,), name='input_b')

      dense = layers.Dense(16, name='dense_1')
      a_2 = dense(a)
      b_2 = dense(b)
      merged = layers.concatenate([a_2, b_2], name='merge')
      c = layers.Dense(64, name='dense_2')(merged)
      d = layers.Dense(5, name='dense_3')(c)

      model = training_lib.Model(inputs=[a, b], outputs=[c, d], name='model')

      j = layers.Input(shape=(32,), name='input_j')
      k = layers.Input(shape=(32,), name='input_k')
      _, n = model([j, k])

      o = layers.Input(shape=(32,), name='input_o')
      p = layers.Input(shape=(32,), name='input_p')
      q, _ = model([o, p])

      self.assertListEqual(n.shape.as_list(), [None, 5])
      self.assertListEqual(q.shape.as_list(), [None, 64])
      s = layers.concatenate([n, q], name='merge_nq')
      self.assertListEqual(s.shape.as_list(), [None, 64 + 5])

      # test with single output as 1-elem list
      multi_io_model = training_lib.Model([j, k, o, p], [s])

      fn = backend.function(multi_io_model.inputs, multi_io_model.outputs)
      fn_outputs = fn([
          np.random.random((10, 32)), np.random.random((10, 32)),
          np.random.random((10, 32)), np.random.random((10, 32))
      ])
      self.assertListEqual([x.shape for x in fn_outputs], [(10, 69)])

      # test with single output as tensor
      multi_io_model = training_lib.Model([j, k, o, p], s)

      fn = backend.function(multi_io_model.inputs, multi_io_model.outputs)
      fn_outputs = fn([
          np.random.random((10, 32)), np.random.random((10, 32)),
          np.random.random((10, 32)), np.random.random((10, 32))
      ])
      # note that the output of the function will still be a 1-elem list
      self.assertListEqual([x.shape for x in fn_outputs], [(10, 69)])

      # test serialization
      model_config = multi_io_model.get_config()
      recreated_model = models.Model.from_config(model_config)

      fn = backend.function(recreated_model.inputs, recreated_model.outputs)
      fn_outputs = fn([
          np.random.random((10, 32)), np.random.random((10, 32)),
          np.random.random((10, 32)), np.random.random((10, 32))
      ])
      # note that the output of the function will still be a 1-elem list
      self.assertListEqual([x.shape for x in fn_outputs], [(10, 69)])

      config = model.get_config()
      models.Model.from_config(config)

      model.summary()
      json_str = model.to_json()
      models.model_from_json(json_str)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_invalid_graphs(self):
    a = layers.Input(shape=(32,), name='input_a')
    b = layers.Input(shape=(32,), name='input_b')

    dense = layers.Dense(16, name='dense_1')
    a_2 = dense(a)
    b_2 = dense(b)
    merged = layers.concatenate([a_2, b_2], name='merge')
    c = layers.Dense(64, name='dense_2')(merged)
    d = layers.Dense(5, name='dense_3')(c)

    model = training_lib.Model(inputs=[a, b], outputs=[c, d], name='model')

    # input is not an Input tensor
    j = layers.Input(shape=(32,), name='input_j')
    j = layers.Dense(32)(j)
    k = layers.Input(shape=(32,), name='input_k')
    m, n = model([j, k])

    with self.assertRaises(Exception):
      training_lib.Model([j, k], [m, n])

    # disconnected graph
    j = layers.Input(shape=(32,), name='input_j')
    k = layers.Input(shape=(32,), name='input_k')
    m, n = model([j, k])
    with self.assertRaises(Exception):
      training_lib.Model([j], [m, n])

    # redundant outputs
    j = layers.Input(shape=(32,), name='input_j')
    k = layers.Input(shape=(32,), name='input_k')
    m, n = model([j, k])

    training_lib.Model([j, k], [m, n, n])

    # redundant inputs
    j = layers.Input(shape=(32,), name='input_j')
    k = layers.Input(shape=(32,), name='input_k')
    m, n = model([j, k])
    with self.assertRaises(Exception):
      training_lib.Model([j, k, j], [m, n])

    # I have no idea what I'm doing: garbage as inputs/outputs
    j = layers.Input(shape=(32,), name='input_j')
    k = layers.Input(shape=(32,), name='input_k')
    m, n = model([j, k])
    with self.assertRaises(Exception):
      training_lib.Model([j, k], [m, n, 0])

  def test_raw_tf_compatibility(self):
    with ops.Graph().as_default():
      # test calling layers/models on TF tensors
      a = layers.Input(shape=(32,), name='input_a')
      b = layers.Input(shape=(32,), name='input_b')

      dense = layers.Dense(16, name='dense_1')
      a_2 = dense(a)
      b_2 = dense(b)
      merged = layers.concatenate([a_2, b_2], name='merge')
      c = layers.Dense(64, name='dense_2')(merged)
      d = layers.Dense(5, name='dense_3')(c)

      model = training_lib.Model(inputs=[a, b], outputs=[c, d], name='model')

      j = layers.Input(shape=(32,), name='input_j')
      k = layers.Input(shape=(32,), name='input_k')
      self.assertEqual(len(model.inputs), 2)
      m, n = model([j, k])
      self.assertEqual(len(model.inputs), 2)
      tf_model = training_lib.Model([j, k], [m, n])

      j_tf = array_ops.placeholder(dtype=dtypes.float32, shape=(None, 32))
      k_tf = array_ops.placeholder(dtype=dtypes.float32, shape=(None, 32))
      m_tf, n_tf = tf_model([j_tf, k_tf])
      self.assertListEqual(m_tf.shape.as_list(), [None, 64])
      self.assertListEqual(n_tf.shape.as_list(), [None, 5])

      # test merge
      layers.concatenate([j_tf, k_tf], axis=1)
      layers.add([j_tf, k_tf])

      # test tensor input
      x = array_ops.placeholder(shape=(None, 2), dtype=dtypes.float32)
      layers.InputLayer(input_tensor=x)

      x = layers.Input(tensor=x)
      layers.Dense(2)(x)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_basic_masking(self):
    a = layers.Input(shape=(10, 32), name='input_a')
    b = layers.Masking()(a)
    model = training_lib.Model(a, b)
    self.assertEqual(model.output_mask.shape.as_list(), [None, 10])

  def testMaskingSingleInput(self):

    class MaskedLayer(layers.Layer):

      def call(self, inputs, mask=None):
        if mask is not None:
          return inputs * mask
        return inputs

      def compute_mask(self, inputs, mask=None):
        return array_ops.ones_like(inputs)

    if context.executing_eagerly():
      a = constant_op.constant([2] * 32)
      mask = constant_op.constant([0, 1] * 16)
      a._keras_mask = mask
      b = MaskedLayer().apply(a)
      self.assertTrue(hasattr(b, '_keras_mask'))
      self.assertAllEqual(
          self.evaluate(array_ops.ones_like(mask)),
          self.evaluate(getattr(b, '_keras_mask')))
      self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
    else:
      x = input_layer_lib.Input(shape=(32,))
      y = MaskedLayer()(x)  # pylint: disable=not-callable
      network = functional.Functional(x, y)

      # test callability on Input
      x_2 = input_layer_lib.Input(shape=(32,))
      y_2 = network(x_2)
      self.assertEqual(y_2.shape.as_list(), [None, 32])

      # test callability on regular tensor
      x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32))
      y_2 = network(x_2)
      self.assertEqual(y_2.shape.as_list(), [None, 32])

  def test_activity_regularization_with_model_composition(self):

    def reg(x):
      return math_ops.reduce_sum(x)

    net_a_input = input_layer_lib.Input((2,))
    net_a = net_a_input
    net_a = layers.Dense(
        2, kernel_initializer='ones', use_bias=False, activity_regularizer=reg)(
            net_a)
    model_a = training_lib.Model([net_a_input], [net_a])

    net_b_input = input_layer_lib.Input((2,))
    net_b = model_a(net_b_input)
    model_b = training_lib.Model([net_b_input], [net_b])

    model_b.compile(optimizer='sgd', loss=None)
    x = np.ones((1, 2))
    loss = model_b.evaluate(x)
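    # With a ones input and a 2x2 ones kernel, each of the two outputs is 2.,
    # so the reduce_sum activity regularizer contributes a total loss of 4.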
    self.assertEqual(loss, 4.)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_layer_sharing_at_heterogenous_depth(self):
    x_val = np.random.random((10, 5))

    x = input_layer_lib.Input(shape=(5,))
    a = layers.Dense(5, name='A')
    b = layers.Dense(5, name='B')
    output = a(b(a(b(x))))
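    # Layers A and B are each reused at two different depths, so the config
    # round-trip below must preserve the shared-layer topology.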
    m = training_lib.Model(x, output)
    m.run_eagerly = testing_utils.should_run_eagerly()

    output_val = m.predict(x_val)

    config = m.get_config()
    weights = m.get_weights()

    m2 = models.Model.from_config(config)
    m2.set_weights(weights)

    output_val_2 = m2.predict(x_val)
    self.assertAllClose(output_val, output_val_2, atol=1e-6)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_layer_sharing_at_heterogenous_depth_with_concat(self):
    input_shape = (16, 9, 3)
    input_layer = input_layer_lib.Input(shape=input_shape)

    a = layers.Dense(3, name='dense_A')
    b = layers.Dense(3, name='dense_B')
    c = layers.Dense(3, name='dense_C')

    x1 = b(a(input_layer))
    x2 = a(c(input_layer))
    output = layers.concatenate([x1, x2])

    m = training_lib.Model(inputs=input_layer, outputs=output)
    m.run_eagerly = testing_utils.should_run_eagerly()

    x_val = np.random.random((10, 16, 9, 3))
    output_val = m.predict(x_val)

    config = m.get_config()
    weights = m.get_weights()

    m2 = models.Model.from_config(config)
    m2.set_weights(weights)

    output_val_2 = m2.predict(x_val)
    self.assertAllClose(output_val, output_val_2, atol=1e-6)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_explicit_training_argument(self):
    a = layers.Input(shape=(2,))
    b = layers.Dropout(0.5)(a)
    base_model = training_lib.Model(a, b)

    a = layers.Input(shape=(2,))
    b = base_model(a, training=False)
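    # `training=False` is baked into this call node, so dropout stays
    # disabled even when the outer model is trained below.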
    model = training_lib.Model(a, b)

    x = np.ones((100, 2))
    y = np.ones((100, 2))
    model.compile(
        optimizer='sgd',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())
    loss = model.train_on_batch(x, y)
    self.assertEqual(loss, 0)  # In inference mode, output is equal to input.

    a = layers.Input(shape=(2,))
    b = base_model(a, training=True)
    model = training_lib.Model(a, b)
    preds = model.predict(x)
    self.assertEqual(np.min(preds), 0.)  # At least one unit was dropped.

  @combinations.generate(combinations.keras_mode_combinations())
  def test_mask_derived_from_keras_layer(self):
    inputs = input_layer_lib.Input((5, 10))
    mask = input_layer_lib.Input((5,))
    outputs = layers.RNN(layers.LSTMCell(100))(inputs, mask=mask)
    model = training_lib.Model([inputs, mask], outputs)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[np.ones((10, 5, 10)), np.zeros((10, 5))],
        y=np.zeros((10, 100)),
        batch_size=2)
    # All data is masked, returned values are 0's.
    self.assertEqual(history.history['loss'][0], 0.0)
    history = model.fit(
        x=[np.ones((10, 5, 10)), np.ones((10, 5))],
        y=np.zeros((10, 100)),
        batch_size=2)
    # Data is not masked, returned values are random.
    self.assertGreater(history.history['loss'][0], 0.0)

    model = training_lib.Model.from_config(model.get_config())
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[np.ones((10, 5, 10)), np.zeros((10, 5))],
        y=np.zeros((10, 100)),
        batch_size=2)
    # All data is masked, returned values are 0's.
    self.assertEqual(history.history['loss'][0], 0.0)
    history = model.fit(
        x=[np.ones((10, 5, 10)), np.ones((10, 5))],
        y=np.zeros((10, 100)),
        batch_size=2)
    # Data is not masked, returned values are random.
    self.assertGreater(history.history['loss'][0], 0.0)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_call_arg_derived_from_keras_layer(self):

    class MyAdd(layers.Layer):

      def call(self, x1, x2):
        return x1 + x2

    input1 = input_layer_lib.Input(10)
    input2 = input_layer_lib.Input(10)
    outputs = MyAdd()(input1, input2)
    model = training_lib.Model([input1, input2], outputs)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
        y=10 * np.ones((10, 10)),
        batch_size=2)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

    # Check serialization.
    model = training_lib.Model.from_config(
        model.get_config(), custom_objects={'MyAdd': MyAdd})
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
        y=10 * np.ones((10, 10)),
        batch_size=2)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

  @combinations.generate(combinations.keras_mode_combinations(mode='eager'))
  def test_only_some_in_first_arg_derived_from_keras_layer_keras_tensors(self):
    # This functionality is unsupported in v1 graphs

    class MyAddAll(layers.Layer):

      def call(self, inputs):
        x = inputs[0]
        for inp in inputs[1:]:
          if inp is not None:
            x = x + inp
        return x

    input1 = input_layer_lib.Input(10)
    input2 = input_layer_lib.Input(10)
    layer = MyAddAll()
    outputs = layer([0.0, input1, None, input2, None])
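    # Only some elements of the first argument are Keras tensors; the
    # constant and the Nones must be handled without creating graph edges.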
    model = training_lib.Model([input1, input2], outputs)
    self.assertIn(layer, model.layers)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
        y=10 * np.ones((10, 10)),
        batch_size=2)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

    # Check serialization.
    model = training_lib.Model.from_config(
        model.get_config(), custom_objects={'MyAddAll': MyAddAll})
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
        y=10 * np.ones((10, 10)),
        batch_size=2)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

  @combinations.generate(
      combinations.times(
          combinations.keras_mode_combinations(),
          combinations.combine(share_already_used_layer=[True, False])))
  def test_call_kwarg_derived_from_keras_layer(self, share_already_used_layer):

    class MaybeAdd(layers.Layer):

      def call(self, x1, x2=None):
        if x2 is not None:
          return x1 + x2
        return x1

    class IdentityLayer(layers.Layer):

      def call(self, x):
        return x

    input1 = input_layer_lib.Input(10)
    input2 = input_layer_lib.Input(10)
    identity_layer = IdentityLayer()

    if share_already_used_layer:
      # Model serialization/deserialization has broken in the past when a
      # layer that already had a non-empty list of inbound nodes (from being
      # used to build other functional models) was reused to define the model
      # being serialized/deserialized, because the node_index was not
      # adjusted correctly during serialization/deserialization.
      # So, we explicitly test this case.
      training_lib.Model([input1], identity_layer(input1))

    outputs = MaybeAdd()(input1, x2=identity_layer(input2))
    model = training_lib.Model([input1, input2], outputs)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
        y=10 * np.ones((10, 10)),
        batch_size=2)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

    model = training_lib.Model.from_config(
        model.get_config(),
        custom_objects={
            'MaybeAdd': MaybeAdd,
            'IdentityLayer': IdentityLayer
        })
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
        y=10 * np.ones((10, 10)),
        batch_size=2)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_call_kwarg_dtype_serialization(self):

    class Double(layers.Layer):

      def call(self, x1, dtype=None):
        return math_ops.cast(x1 + x1, dtype=dtype)

    input1 = input_layer_lib.Input(10)
    outputs = Double()(input1, dtype=dtypes.float16)
    model = training_lib.Model([input1], outputs)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[3 * np.ones((10, 10))],
        y=6 * np.ones((10, 10)),
        batch_size=2)
    # Check that input was correctly doubled.
    self.assertEqual(history.history['loss'][0], 0.0)

    # Check the output dtype
    self.assertEqual(model(array_ops.ones((3, 10))).dtype, dtypes.float16)

    model = training_lib.Model.from_config(
        model.get_config(), custom_objects={'Double': Double})
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[3 * np.ones((10, 10))],
        y=6 * np.ones((10, 10)),
        batch_size=2)
    # Check that input was correctly doubled.
    self.assertEqual(history.history['loss'][0], 0.0)

    # Check the output dtype
    self.assertEqual(model(array_ops.ones((3, 10))).dtype, dtypes.float16)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_call_kwarg_nonserializable(self):

    class Double(layers.Layer):

      def call(self, x1, kwarg=None):
        return x1 + x1

    class NonSerializable(object):

      def __init__(self, foo=None):
        self.foo = foo

    input1 = input_layer_lib.Input(10)
    outputs = Double()(input1, kwarg=NonSerializable())
    model = training_lib.Model([input1], outputs)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[3 * np.ones((10, 10))],
        y=6 * np.ones((10, 10)),
        batch_size=2)
    # Check that input was correctly doubled.
    self.assertEqual(history.history['loss'][0], 0.0)
    with self.assertRaisesRegex(
        TypeError, 'Layer double was passed non-JSON-serializable arguments.'):
      model.get_config()

  @combinations.generate(
      combinations.times(
          combinations.keras_mode_combinations(),
          combinations.combine(share_already_used_layer=[True, False])))
  def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(
      self, share_already_used_layer):

    class IdentityLayer(layers.Layer):

      def call(self, x):
        return x

    class MaybeAdd(layers.Layer):

      def call(self, x1, x2=None):
        if x2 is not None:
          return x1 + x2
        return x1

    input2 = input_layer_lib.Input(10)
    identity_layer = IdentityLayer()
    if share_already_used_layer:
      # Model serialization/deserialization has broken in the past when a
      # layer that already had a non-empty list of inbound nodes (from being
      # used to build other functional models) was reused to define the model
      # being serialized/deserialized, because the node_index was not
      # adjusted correctly during serialization/deserialization.
      # So, we explicitly test this case.
      training_lib.Model([input2], identity_layer(input2))

    outputs = MaybeAdd()(3., x2=identity_layer(input2))
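    # The first positional argument is a plain constant; the model graph is
    # built from the Keras tensor passed via the `x2` kwarg.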
    model = training_lib.Model([input2], outputs)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=7 * np.ones((10, 10)),
        y=10 * np.ones((10, 10)),
        batch_size=2)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

    model = training_lib.Model.from_config(
        model.get_config(),
        custom_objects={
            'MaybeAdd': MaybeAdd,
            'IdentityLayer': IdentityLayer
        })
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=7 * np.ones((10, 10)),
        y=10 * np.ones((10, 10)),
        batch_size=2)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_composite_call_kwarg_derived_from_keras_layer(self):

    # Create a test layer that accepts composite tensor inputs.
    class MaybeAdd(layers.Layer):

      def call(self, x1, x2=None):
        # We need to convert this to a tensor for loss calculations -
        # losses don't play nicely with ragged tensors yet.
        if x2 is not None:
          return (x1 + x2).to_tensor(default_value=0)
        return x1.to_tensor(default_value=0)

    input1 = input_layer_lib.Input((None,), ragged=True)
    input2 = input_layer_lib.Input((None,), ragged=True)
    outputs = MaybeAdd()(input1, x2=input2)
    model = training_lib.Model([input1, input2], outputs)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    input_data = [
        ragged_factory_ops.constant([[3.0, 3.0], [3.0, 3.0], [3.0]]),
        ragged_factory_ops.constant([[7.0, 7.0], [7.0, 7.0], [7.0]])
    ]
    expected_data = np.array([[10.0, 10.0], [10.0, 10.0], [10.0, 0.0]])
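    # Ragged rows are summed elementwise; `to_tensor` pads the short last
    # row with the default value 0.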

    history = model.fit(x=input_data, y=expected_data)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

    model = training_lib.Model.from_config(
        model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(x=input_data, y=expected_data)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

  @combinations.generate(combinations.keras_mode_combinations(mode='eager'))
  def test_call_some_not_all_nested_in_first_arg_derived_from_keras_layer(self):
    # This functionality is unsupported in v1 graphs

    class AddAll(layers.Layer):

      def call(self, x1_x2, x3):
        x1, x2 = x1_x2
        out = x1 + x2
        if x3 is not None:
          for t in x3.values():
            out += t
        return out

    input1 = input_layer_lib.Input(10)
    input2 = input_layer_lib.Input(10)
    input3 = input_layer_lib.Input(10)

    layer = AddAll()
    outputs = layer(
        [input1, 4 * array_ops.ones((1, 10))],
        x3={
            'a': input2,
            'b': input3,
            'c': 5 * array_ops.ones((1, 10))
        })
    model = training_lib.Model([input1, input2, input3], outputs)
    self.assertIn(layer, model.layers)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
        y=15 * np.ones((10, 10)),
        batch_size=2)
    # Check that all inputs were correctly added.
    self.assertEqual(history.history['loss'][0], 0.0)

    model = training_lib.Model.from_config(
        model.get_config(), custom_objects={'AddAll': AddAll})
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
        y=15 * np.ones((10, 10)),
        batch_size=2)
    # Check that all inputs were correctly added.
    self.assertEqual(history.history['loss'][0], 0.0)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_call_nested_arg_derived_from_keras_layer(self):

    class AddAll(layers.Layer):

      def call(self, x1, x2, x3=None):
        out = x1 + x2
        if x3 is not None:
          for t in x3.values():
            out += t
        return out

    input1 = input_layer_lib.Input(10)
    input2 = input_layer_lib.Input(10)
    input3 = input_layer_lib.Input(10)
    outputs = AddAll()(
        input1,
        4 * array_ops.ones((1, 10)),
        x3={
            'a': input2,
            'b': input3,
            'c': 5 * array_ops.ones((1, 10))
        })
    model = training_lib.Model([input1, input2, input3], outputs)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
        y=15 * np.ones((10, 10)),
        batch_size=2)
    # Check that all inputs were correctly added.
    self.assertEqual(history.history['loss'][0], 0.0)

    model = training_lib.Model.from_config(
        model.get_config(), custom_objects={'AddAll': AddAll})
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
        y=15 * np.ones((10, 10)),
        batch_size=2)
    # Check that all inputs were correctly added.
    self.assertEqual(history.history['loss'][0], 0.0)

1312  @combinations.generate(combinations.keras_mode_combinations())
1313  def test_multi_output_model_with_none_masking(self):
1314    def func(x):
1315      return [x * 0.2, x * 0.3]
1316
1317    def output_shape(input_shape):
1318      return [input_shape, input_shape]
1319
1320    i = layers.Input(shape=(3, 2, 1))
1321    o = layers.Lambda(function=func, output_shape=output_shape)(i)
1322
1323    self.assertEqual(backend.int_shape(o[0]), (None, 3, 2, 1))
1324    self.assertEqual(backend.int_shape(o[1]), (None, 3, 2, 1))
1325
1326    o = layers.add(o)
1327    model = training_lib.Model(i, o)
1328    model.run_eagerly = testing_utils.should_run_eagerly()
1329
1330    i2 = layers.Input(shape=(3, 2, 1))
1331    o2 = model(i2)
1332    model2 = training_lib.Model(i2, o2)
1333    model2.run_eagerly = testing_utils.should_run_eagerly()
1334
1335    x = np.random.random((4, 3, 2, 1))
1336    out = model2.predict(x)
1337    assert out.shape == (4, 3, 2, 1)
1338    self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4)
1339
1340  @combinations.generate(combinations.keras_mode_combinations())
1341  def test_constant_initializer_with_numpy(self):
1342    initializer = initializers.Constant(np.ones((3, 2)))
1343    model = sequential.Sequential()
1344    model.add(layers.Dense(2, input_shape=(3,), kernel_initializer=initializer))
1345    model.add(layers.Dense(3))
1346    model.compile(
1347        loss='mse',
1348        optimizer='sgd',
1349        metrics=['acc'],
1350        run_eagerly=testing_utils.should_run_eagerly())
1351
1352    json_str = model.to_json()
1353    models.model_from_json(json_str)
1354
1355  def test_subclassed_error_if_init_not_called(self):
1356
1357    class MyNetwork(training_lib.Model):
1358
1359      def __init__(self):
1360        self._foo = [layers.Dense(10), layers.Dense(10)]
1361
1362    with self.assertRaisesRegex(RuntimeError, 'forgot to call'):
1363      MyNetwork()
1364
1365  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
1366  def test_int_input_shape(self):
1367    inputs = input_layer_lib.Input(10)
1368    self.assertEqual([None, 10], inputs.shape.as_list())
1369
1370    inputs_with_batch = input_layer_lib.Input(batch_size=20, shape=5)
1371    self.assertEqual([20, 5], inputs_with_batch.shape.as_list())
1372
1373  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
1374  def test_model_initialization(self):
1375    # Functional model
1376    inputs = input_layer_lib.Input(shape=(32,))
1377    outputs = layers.Dense(4)(inputs)
1378
1379    with self.assertRaisesRegex(TypeError,
1380                                'Keyword argument not understood'):
1381      model = training_lib.Model(
1382          inputs, outputs, name='m', trainable=False, dtype='int64')
1383    with self.assertRaisesRegex(TypeError,
1384                                'Keyword argument not understood'):
1385      model = training_lib.Model(
1386          inputs, outputs, name='m', trainable=False, dynamic=False)
1387
1388    model = training_lib.Model(inputs, outputs, name='m', trainable=False)
1389    self.assertEqual('m', model.name)
1390    self.assertFalse(model.trainable)
1391    self.assertFalse(model.dynamic)
1392
1393    class SubclassModel(training_lib.Model):
1394      pass
1395    # Subclassed model
1396    model = SubclassModel(
1397        name='subclassed', trainable=True, dtype='int64', dynamic=True)
1398    self.assertEqual('subclassed', model.name)
1399    self.assertTrue(model.dynamic)
1400    self.assertTrue(model.trainable)
1401    w = model.add_weight('w', [], initializer=initializers.Constant(1))
1402    self.assertEqual(dtypes.int64, w.dtype)
1403
1404  def test_disconnected_inputs(self):
1405    input_tensor1 = input_layer_lib.Input(shape=[200], name='a')
1406    input_tensor2 = input_layer_lib.Input(shape=[10], name='b')
1407    output_tensor1 = layers.Dense(units=10)(input_tensor1)
1408
    net = functional.Functional(
        inputs=[input_tensor1, input_tensor2], outputs=[output_tensor1])
    net2 = functional.Functional.from_config(net.get_config())
    self.assertLen(net2.inputs, 2)
    self.assertEqual('a', net2.layers[0].name)
    self.assertEqual('b', net2.layers[1].name)

  @combinations.generate(combinations.keras_model_type_combinations())
  def test_dependency_tracking(self):
    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    model.trackable = Checkpoint()
    self.assertIn('trackable', model._unconditional_dependency_names)
    self.assertEqual(model.trackable, model._lookup_dependency('trackable'))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_model_construction_in_tf_function(self):

    d = {'model': None}

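    # Build the model lazily on the first trace and cache it in the
    # closed-over dict; later calls reuse the same Functional instance.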
    @def_function.function
    def fn(x):
      if d['model'] is None:
        # Check that Functional can be built in a `tf.function`.
        inputs = input_layer_lib.Input(10)
        outputs = layers.Dense(1)(inputs)
        model = functional.Functional(inputs, outputs)
        d['model'] = model
      else:
        model = d['model']

      return model(x)

    x = array_ops.ones((10, 10))
    y = fn(x)
    self.assertEqual(y.shape.as_list(), [10, 1])


class DeferredModeTest(keras_parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testSimpleNetworkBuilding(self):
    inputs = input_layer_lib.Input(shape=(32,))
    if context.executing_eagerly():
      self.assertEqual(inputs.dtype.name, 'float32')
      self.assertEqual(inputs.shape.as_list(), [None, 32])

    x = layers.Dense(2)(inputs)
    if context.executing_eagerly():
      self.assertEqual(x.dtype.name, 'float32')
      self.assertEqual(x.shape.as_list(), [None, 2])

    outputs = layers.Dense(4)(x)
    network = functional.Functional(inputs, outputs)
    self.assertIsInstance(network, functional.Functional)

    if context.executing_eagerly():
      # It should be possible to call such a network on EagerTensors.
      inputs = constant_op.constant(
          np.random.random((10, 32)).astype('float32'))
      outputs = network(inputs)
      self.assertEqual(outputs.shape.as_list(), [10, 4])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testMultiIONetworkBuilding(self):
    input_a = input_layer_lib.Input(shape=(32,))
    input_b = input_layer_lib.Input(shape=(16,))
    a = layers.Dense(16)(input_a)

    class AddLayer(layers.Layer):

      def call(self, inputs):
        return inputs[0] + inputs[1]

    c = AddLayer()([a, input_b])  # pylint: disable=not-callable
    c = layers.Dense(2)(c)

    network = functional.Functional([input_a, input_b], [a, c])
    if context.executing_eagerly():
      a_val = constant_op.constant(
          np.random.random((10, 32)).astype('float32'))
      b_val = constant_op.constant(
          np.random.random((10, 16)).astype('float32'))
      outputs = network([a_val, b_val])
      self.assertEqual(len(outputs), 2)
      self.assertEqual(outputs[0].shape.as_list(), [10, 16])
      self.assertEqual(outputs[1].shape.as_list(), [10, 2])


class DefaultShapeInferenceBehaviorTest(keras_parameterized.TestCase):

  def _testShapeInference(self, model, input_shape, expected_output_shape):
    input_value = np.random.random(input_shape)
    output_value = model.predict(input_value)
    self.assertEqual(output_value.shape, expected_output_shape)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testSingleInputCase(self):

    class LayerWithOneInput(layers.Layer):

      def build(self, input_shape):
        self.w = array_ops.ones(shape=(3, 4))

      def call(self, inputs):
        return backend.dot(inputs, self.w)

    inputs = input_layer_lib.Input(shape=(3,))
    layer = LayerWithOneInput()

    if context.executing_eagerly():
      self.assertEqual(
          layer.compute_output_shape((None, 3)).as_list(), [None, 4])
      # As a side-effect, compute_output_shape builds the layer.
      self.assertTrue(layer.built)
      # We can still query the layer's compute_output_shape with compatible
      # input shapes.
      self.assertEqual(
          layer.compute_output_shape((6, 3)).as_list(), [6, 4])

    outputs = layer(inputs)
    model = training_lib.Model(inputs, outputs)
    self._testShapeInference(model, (2, 3), (2, 4))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testMultiInputOutputCase(self):

    class MultiInputOutputLayer(layers.Layer):

      def build(self, input_shape):
        self.w = array_ops.ones(shape=(3, 4))

      def call(self, inputs):
        a = backend.dot(inputs[0], self.w)
        b = a + inputs[1]
        return [a, b]

    input_a = input_layer_lib.Input(shape=(3,))
    input_b = input_layer_lib.Input(shape=(4,))
    output_a, output_b = MultiInputOutputLayer()([input_a, input_b])
    model = training_lib.Model([input_a, input_b], [output_a, output_b])
    output_a_val, output_b_val = model.predict(
        [np.random.random((2, 3)), np.random.random((2, 4))])
    self.assertEqual(output_a_val.shape, (2, 4))
    self.assertEqual(output_b_val.shape, (2, 4))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testTrainingArgument(self):

    class LayerWithTrainingArg(layers.Layer):

      def build(self, input_shape):
        self.w = array_ops.ones(shape=(3, 4))

      def call(self, inputs, training):
        return backend.dot(inputs, self.w)

    inputs = input_layer_lib.Input(shape=(3,))
    outputs = LayerWithTrainingArg()(inputs, training=False)
    model = training_lib.Model(inputs, outputs)
    self._testShapeInference(model, (2, 3), (2, 4))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoneInShape(self):

    class Model(training_lib.Model):

      def __init__(self):
        super(Model, self).__init__()
        self.conv1 = layers.Conv2D(8, 3)
        self.pool = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(3)

      def call(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.fc(x)
        return x

    model = Model()
    model.build(tensor_shape.TensorShape((None, None, None, 1)))
    self.assertTrue(model.built, 'Model should be built')
    self.assertTrue(model.weights,
                    'Model should have its weights created as it '
                    'has been built')
    sample_input = array_ops.ones((1, 10, 10, 1))
    output = model(sample_input)
    self.assertEqual(output.shape, (1, 3))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoneInShapeWithCompoundModel(self):

    class BasicBlock(training_lib.Model):

      def __init__(self):
        super(BasicBlock, self).__init__()
        self.conv1 = layers.Conv2D(8, 3)
        self.pool = layers.GlobalAveragePooling2D()
        self.dense = layers.Dense(3)

      def call(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.dense(x)
        return x

    class CompoundModel(training_lib.Model):

      def __init__(self):
        super(CompoundModel, self).__init__()
        self.block = BasicBlock()

      def call(self, x):
        x = self.block(x)  # pylint: disable=not-callable
        return x

    model = CompoundModel()
    model.build(tensor_shape.TensorShape((None, None, None, 1)))
    self.assertTrue(model.built, 'Model should be built')
    self.assertTrue(model.weights,
                    'Model should have its weights created as it '
                    'has been built')
    sample_input = array_ops.ones((1, 10, 10, 1))
    output = model(sample_input)  # pylint: disable=not-callable
    self.assertEqual(output.shape, (1, 3))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoneInShapeWithFunctionalAPI(self):

    class BasicBlock(training_lib.Model):
      # A subclassed Model used like a layer inside a model created with the
      # functional API.

      def __init__(self):
        super(BasicBlock, self).__init__()
        self.conv1 = layers.Conv2D(8, 3)

      def call(self, x):
        x = self.conv1(x)
        return x

    input_layer = layers.Input(shape=(None, None, 1))
    x = BasicBlock()(input_layer)
    x = layers.GlobalAveragePooling2D()(x)
    output_layer = layers.Dense(3)(x)

    model = training_lib.Model(inputs=input_layer, outputs=output_layer)

    model.build(tensor_shape.TensorShape((None, None, None, 1)))
    self.assertTrue(model.built, 'Model should be built')
    self.assertTrue(model.weights,
                    'Model should have its weights created as it '
                    'has been built')
    sample_input = array_ops.ones((1, 10, 10, 1))
    output = model(sample_input)
    self.assertEqual(output.shape, (1, 3))

  @combinations.generate(combinations.keras_mode_combinations())
  def test_sequential_as_downstream_of_masking_layer(self):
    inputs = layers.Input(shape=(3, 4))
    x = layers.Masking(mask_value=0., input_shape=(3, 4))(inputs)

    s = sequential.Sequential()
    s.add(layers.Dense(5, input_shape=(4,)))

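    # The mask produced by the Masking layer should propagate through
    # TimeDistributed into the wrapped Sequential model.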
    x = layers.wrappers.TimeDistributed(s)(x)
    model = training_lib.Model(inputs=inputs, outputs=x)
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())

    model_input = np.random.randint(
        low=1, high=5, size=(10, 3, 4)).astype('float32')
    for i in range(4):
      model_input[i, i:, :] = 0.
    model.fit(model_input,
              np.random.random((10, 3, 5)), epochs=1, batch_size=6)

    if not context.executing_eagerly():
      # Note: this doesn't work in eager mode due to a DeferredTensor/ops
      # compatibility issue.
      mask_outputs = [model.layers[1].compute_mask(model.layers[1].input)]
      mask_outputs += [model.layers[2].compute_mask(
          model.layers[2].input, mask_outputs[-1])]
      func = backend.function([model.input], mask_outputs)
      mask_outputs_val = func([model_input])
      self.assertAllClose(mask_outputs_val[0], np.any(model_input, axis=-1))
      self.assertAllClose(mask_outputs_val[1], np.any(model_input, axis=-1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_external_keras_serialization_compat_input_layers(self):
    inputs = input_layer_lib.Input(shape=(10,))
    outputs = layers.Dense(1)(inputs)
    model = training_lib.Model(inputs, outputs)
    config = model.get_config()
    # Checks that single inputs and outputs are still saved as 1-element lists.
    # Saving as 1-element lists or not is equivalent in TF Keras, but only the
    # 1-element list format is supported in TF.js and keras-team/Keras.
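    # As a rough sketch (the generated layer names may differ), the config is
    # expected to look like:
    #   config['input_layers']  == [['input_1', 0, 0]]
    #   config['output_layers'] == [['dense', 0, 0]]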
    self.assertLen(config['input_layers'], 1)
    self.assertLen(config['output_layers'], 1)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_external_keras_serialization_compat_inbound_nodes(self):
    # Check single Tensor input.
    inputs = input_layer_lib.Input(shape=(10,), name='in')
    outputs = layers.Dense(1)(inputs)
    model = training_lib.Model(inputs, outputs)
    config = model.get_config()
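    # Each inbound node entry has the form
    # [layer_name, node_index, tensor_index, kwargs].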
    self.assertEqual(config['layers'][1]['inbound_nodes'], [[['in', 0, 0, {}]]])

    # Check multiple Tensor input.
    inputs1 = input_layer_lib.Input(shape=(10,), name='in1')
    inputs2 = input_layer_lib.Input(shape=(10,), name='in2')
    outputs = layers.Add()([inputs1, inputs2])
    model = training_lib.Model([inputs1, inputs2], outputs)
    config = model.get_config()
    self.assertEqual(config['layers'][2]['inbound_nodes'],
                     [[['in1', 0, 0, {}], ['in2', 0, 0, {}]]])

  @combinations.generate(combinations.combine(mode=['eager']))
  def test_dict_inputs_tensors(self):
    # Note that this test runs with v2 eager only, since v1 behaves
    # differently with respect to dict inputs for training.
    inputs = {
        'sentence2': input_layer_lib.Input(
            shape=(), name='a', dtype=dtypes.string),
        'sentence1': input_layer_lib.Input(
            shape=(), name='b', dtype=dtypes.string),
    }
    strlen = layers.Lambda(string_ops.string_length_v2)
    diff = layers.Subtract()(
        [strlen(inputs['sentence1']), strlen(inputs['sentence2'])])
    diff = math_ops.cast(diff, dtypes.float32)
    model = training_lib.Model(inputs, diff)

    extra_keys = {
        'sentence1': constant_op.constant(['brown fox', 'lazy dog']),
        'sentence2': constant_op.constant(['owl', 'cheeky cat']),
        'label': constant_op.constant([0, 1]),
    }

    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter('always')
      model(extra_keys)
      self.assertIn('ignored by the model', str(w[-1].message))

    model.compile('sgd', 'mse')
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter('always')
      model.fit(extra_keys, y=constant_op.constant([0, 1]), steps_per_epoch=1)
      self.assertIn('ignored by the model', str(w[-1].message))

    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter('always')
      model.evaluate(extra_keys, constant_op.constant([0, 1]))
      self.assertIn('ignored by the model', str(w[-1].message))

    # Make sure the model inputs are sorted by the dict keys.
    self.assertEqual(model.inputs[0]._keras_history.layer.name, 'b')
    self.assertEqual(model.inputs[1]._keras_history.layer.name, 'a')


class GraphUtilsTest(test.TestCase):

  def testGetReachableFromInputs(self):

    with ops.Graph().as_default(), self.cached_session():
      pl_1 = array_ops.placeholder(shape=None, dtype='float32')
      pl_2 = array_ops.placeholder(shape=None, dtype='float32')
      pl_3 = array_ops.placeholder(shape=None, dtype='float32')
      x_1 = pl_1 + pl_2
      x_2 = pl_2 * 2
      x_3 = pl_3 + 1
      x_4 = x_1 + x_2
      x_5 = x_3 * pl_1

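      # get_reachable_from_inputs returns every op and tensor reachable
      # downstream of the given inputs, including the consuming ops themselves.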
      self.assertEqual(
          tf_utils.get_reachable_from_inputs([pl_1]),
          {pl_1, x_1, x_4, x_5, x_1.op, x_4.op, x_5.op})
      self.assertEqual(
          tf_utils.get_reachable_from_inputs([pl_1, pl_2]),
          {pl_1, pl_2, x_1, x_2, x_4, x_5, x_1.op, x_2.op, x_4.op, x_5.op})
      self.assertEqual(
          tf_utils.get_reachable_from_inputs([pl_3]),
          {pl_3, x_3, x_5, x_3.op, x_5.op})
      self.assertEqual(
          tf_utils.get_reachable_from_inputs([x_3]), {x_3, x_5, x_5.op})


class NestedNetworkTest(keras_parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_nested_inputs_network(self):
    inputs = {
        'x1': input_layer_lib.Input(shape=(1,)),
        'x2': input_layer_lib.Input(shape=(1,))
    }
    outputs = layers.Add()([inputs['x1'], inputs['x2']])
    network = functional.Functional(inputs, outputs)

    network = functional.Functional.from_config(network.get_config())

    result_tensor = network({
        'x1': array_ops.ones((1, 1), 'float32'),
        'x2': array_ops.ones((1, 1), 'float32')
    })
    result = self.evaluate(result_tensor)
    self.assertAllEqual(result, [[2.]])

    # TODO(b/122726584): Investigate why concrete batch is flaky in some builds.
    output_shape = network.compute_output_shape({
        'x1': (None, 1),
        'x2': (None, 1)
    })
    self.assertListEqual(output_shape.as_list(), [None, 1])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_nested_outputs_network(self):
    inputs = input_layer_lib.Input(shape=(1,))
    outputs = {
        'x+x': layers.Add()([inputs, inputs]),
        'x*x': layers.Multiply()([inputs, inputs])
    }

    network = functional.Functional(inputs, outputs)

    network = functional.Functional.from_config(network.get_config())

    result_tensor = network(array_ops.ones((1, 1), 'float32'))
    result = self.evaluate(result_tensor)
    self.assertAllEqual(result['x+x'], [[2.]])
    self.assertAllEqual(result['x*x'], [[1.]])

    output_shape = network.compute_output_shape((None, 1))
    self.assertListEqual(output_shape['x+x'].as_list(), [None, 1])
    self.assertListEqual(output_shape['x*x'].as_list(), [None, 1])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_nested_network_inside_network(self):
    inner_inputs = {
        'x1': input_layer_lib.Input(shape=(1,)),
        'x2': input_layer_lib.Input(shape=(1,))
    }
    inner_outputs = {
        'x1+x2': layers.Add()([inner_inputs['x1'], inner_inputs['x2']]),
        'x1*x2': layers.Multiply()([inner_inputs['x1'], inner_inputs['x2']])
    }
    inner_network = functional.Functional(
        inner_inputs, inner_outputs)

    inputs = [
        input_layer_lib.Input(shape=(1,)),
        input_layer_lib.Input(shape=(1,))
    ]
    middle = inner_network({'x1': inputs[0], 'x2': inputs[1]})
    outputs = layers.Add()([middle['x1+x2'], middle['x1*x2']])
    network = functional.Functional(inputs, outputs)

    network = functional.Functional.from_config(network.get_config())

    # Computes: `(x1+x2) + (x1*x2)`
    result_tensor = network(
        [array_ops.ones((1, 1), 'float32'),
         array_ops.ones((1, 1), 'float32')])
    result = self.evaluate(result_tensor)
    self.assertAllEqual(result, [[3.]])

    output_shape = network.compute_output_shape([(None, 1), (None, 1)])
    self.assertListEqual(output_shape.as_list(), [None, 1])

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_updates_with_direct_call(self):
    inputs = input_layer_lib.Input(shape=(10,))
    x = layers.BatchNormalization()(inputs)
    x = layers.Dense(10)(x)
    model = training_lib.Model(inputs, x)

    ph = backend.placeholder(shape=(10, 10))
    model(ph)

    self.assertLen(model.updates, 4)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_dict_mapping_input(self):

    class ReturnFirst(layers.Layer):

      def call(self, inputs):
        b, _ = inputs
        return b

    # Checks that inputs are put in the same order that the Model was
    # constructed with.
    b = input_layer_lib.Input(shape=(10,), name='b')
    a = input_layer_lib.Input(shape=(10,), name='a')
    outputs = ReturnFirst()([b, a])

    b_val = array_ops.ones((10, 10))
    a_val = array_ops.zeros((10, 10))

    model = training_lib.Model([b, a], outputs)
    res = model({'a': a_val, 'b': b_val})
    self.assertAllClose(self.evaluate(res), self.evaluate(b_val))

    reversed_model = training_lib.Model([a, b], outputs)
    res = reversed_model({'a': a_val, 'b': b_val})
    self.assertAllClose(self.evaluate(res), self.evaluate(b_val))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_dict_mapping_single_input(self):
    b = input_layer_lib.Input(shape=(1,), name='b')
    outputs = b * 2
    model = training_lib.Model(b, outputs)

    b_val = array_ops.ones((1, 1))
    extra_val = array_ops.ones((1, 10))

    inputs = {'a': extra_val, 'b': b_val}
    res = model(inputs)

    # Check that 'b' was used and 'a' was ignored.
    self.assertEqual(res.shape.as_list(), [1, 1])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_nested_dict_mapping(self):
    a = input_layer_lib.Input(shape=(1,), dtype='int32', name='a')
    b = input_layer_lib.Input(shape=(1,), dtype='int32', name='b')
    c = input_layer_lib.Input(shape=(1,), dtype='int32', name='c')
    d = input_layer_lib.Input(shape=(1,), dtype='int32', name='d')
    inputs = {'a': (a, b), 'c': (c, d)}
    outputs = 1000 * a + 100 * b + 10 * c + d
    model = training_lib.Model(inputs, outputs)

    a_val = array_ops.ones((1, 1), dtype='int32')
    b_val = 2 * array_ops.ones((1, 1), dtype='int32')
    c_val = 3 * array_ops.ones((1, 1), dtype='int32')
    d_val = 4 * array_ops.ones((1, 1), dtype='int32')

    inputs_val = {'a': (a_val, b_val), 'c': (c_val, d_val)}
    res = model(inputs_val)

    # Check that inputs were flattened in the correct order.
    self.assertFalse(model._enable_dict_to_input_mapping)
    self.assertEqual(self.evaluate(res), [1234])


@combinations.generate(combinations.keras_mode_combinations())
class AddLossTest(keras_parameterized.TestCase):

  def test_add_loss_outside_call_only_loss(self):
    inputs = input_layer_lib.Input((10,))
    mid = layers.Dense(10)(inputs)
    outputs = layers.Dense(1)(mid)
    model = training_lib.Model(inputs, outputs)
    model.add_loss(math_ops.reduce_mean(outputs))
    self.assertLen(model.losses, 1)

    initial_weights = model.get_weights()

    x = np.ones((10, 10))
    model.compile(
        'sgd',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(x, batch_size=2, epochs=1)

    model2 = model.from_config(model.get_config())
    model2.compile(
        'sgd',
        run_eagerly=testing_utils.should_run_eagerly())
    model2.set_weights(initial_weights)
    model2.fit(x, batch_size=2, epochs=1)

    # The TFOpLayer and the AddLoss layer are serialized.
    self.assertLen(model2.layers, 5)
    self.assertAllClose(model.get_weights(), model2.get_weights())

  def test_add_loss_outside_call_multiple_losses(self):
    inputs = input_layer_lib.Input((10,))
    x1 = layers.Dense(10)(inputs)
    x2 = layers.Dense(10)(x1)
    outputs = layers.Dense(1)(x2)
    model = training_lib.Model(inputs, outputs)
    model.add_loss(math_ops.reduce_sum(x1 * x2))
    model.add_loss(math_ops.reduce_mean(outputs))
    self.assertLen(model.losses, 2)

    initial_weights = model.get_weights()

    x, y = np.ones((10, 10)), np.ones((10, 1))
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(x, y, batch_size=2, epochs=1)

    model2 = model.from_config(model.get_config())
    model2.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    model2.set_weights(initial_weights)
    model2.fit(x, y, batch_size=2, epochs=1)

    self.assertAllClose(model.get_weights(), model2.get_weights())

  def test_add_loss_crossentropy_backtracking(self):
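    # Feed the labels in as a second model input and register the loss via
    # `add_loss`, so `compile` needs no `loss` argument.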
    inputs = input_layer_lib.Input((2,))
    labels = input_layer_lib.Input((1,))
    outputs = layers.Dense(1, activation='sigmoid')(inputs)
    model = functional.Functional([inputs, labels], outputs)
    model.add_loss(losses.binary_crossentropy(labels, outputs))
    model.compile('adam')
    x = np.random.random((2, 2))
    y = np.random.random((2, 1))
    model.fit([x, y])

    inputs = input_layer_lib.Input((2,))
    labels = input_layer_lib.Input((2,))
    outputs = layers.Dense(2, activation='softmax')(inputs)
    model = functional.Functional([inputs, labels], outputs)
    model.add_loss(losses.categorical_crossentropy(labels, outputs))
    model.compile('adam')
    x = np.random.random((2, 2))
    y = np.random.random((2, 2))
    model.fit([x, y])

    inputs = input_layer_lib.Input((2,))
    labels = input_layer_lib.Input((1,), dtype='int32')
    outputs = layers.Dense(2, activation='softmax')(inputs)
    model = functional.Functional([inputs, labels], outputs)
    model.add_loss(losses.sparse_categorical_crossentropy(labels, outputs))
    model.compile('adam')
    x = np.random.random((2, 2))
    y = np.random.randint(0, 2, size=(2, 1))
    model.fit([x, y])


@combinations.generate(combinations.keras_mode_combinations())
class WeightAccessTest(keras_parameterized.TestCase):

  def test_functional_model(self):
    inputs = input_layer_lib.Input((10,))
    x1 = layers.Dense(10)(inputs)
    x2 = layers.Dense(10)(x1)
    outputs = layers.Dense(1)(x2)
    model = training_lib.Model(inputs, outputs)

    self.assertEqual(len(model.weights), 6)

  def test_sequential_model_with_input_shape(self):
    x1 = layers.Dense(10, input_shape=(10,))
    x2 = layers.Dense(10)
    x3 = layers.Dense(1)
    model = sequential.Sequential([x1, x2, x3])

    self.assertEqual(len(model.weights), 6)

  def test_sequential_model_without_input_shape(self):
    x1 = layers.Dense(10)
    x2 = layers.Dense(10)
    x3 = layers.Dense(1)
    model = sequential.Sequential([x1, x2, x3])

    with self.assertRaisesRegex(
        ValueError, 'Weights for model .* have not yet been created'):
      _ = model.weights

  def test_subclass_model_with_build_method(self):

    class SubclassModel(models.Model):

      def build(self, input_shape):
        self.w = self.add_weight(shape=input_shape[-1], initializer='ones')

      def call(self, inputs):
        return inputs * self.w

    model = SubclassModel()

    with self.assertRaisesRegex(
        ValueError, 'Weights for model .* have not yet been created'):
      _ = model.weights

    model(input_layer_lib.Input((10,)))
    self.assertEqual(len(model.weights), 1)

  def test_subclass_model_without_build_method(self):

    class SubclassModel(models.Model):

      def __init__(self):
        super(SubclassModel, self).__init__()
        self.w = self.add_weight(shape=(), initializer='ones')

      def call(self, inputs):
        return inputs * self.w

    model = SubclassModel()
    self.assertEqual(len(model.weights), 1)


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class DTypeTest(keras_parameterized.TestCase):

  @testing_utils.enable_v2_dtype_behavior
  def test_graph_network_dtype(self):
    inputs = input_layer_lib.Input((10,))
    outputs = layers.Dense(10)(inputs)
    network = functional.Functional(inputs, outputs)
    self.assertEqual(network.dtype, 'float32')

  @testing_utils.enable_v2_dtype_behavior
  def test_subclassed_network_dtype(self):

    class IdentityNetwork(training_lib.Model):

      def call(self, inputs):
        return inputs

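    # With v2 dtype behavior, a layer casts floating-point inputs to its own
    # dtype unless it is constructed with autocast=False.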
    network = IdentityNetwork()
    self.assertEqual(network.dtype, 'float32')
    self.assertEqual(
        network(constant_op.constant(1., 'float64')).dtype, 'float32')

    network = IdentityNetwork(dtype='float16')
    self.assertEqual(network.dtype, 'float16')
    self.assertEqual(
        network(constant_op.constant(1., 'float64')).dtype, 'float16')

    network = IdentityNetwork(autocast=False)
    self.assertEqual(network.dtype, 'float32')
    self.assertEqual(
        network(constant_op.constant(1., 'float64')).dtype, 'float64')


class AttrTrackingLayer(base_layer.Layer):
  """Count how many times `dynamic` and `stateful` are called.

  These counts are used to test that the attribute cache behaves as expected.
  """

  def __init__(self, *args, **kwargs):
    self.stateful_count = 0
    self.dynamic_count = 0
    super(AttrTrackingLayer, self).__init__(*args, **kwargs)

  @base_layer.Layer.stateful.getter
  def stateful(self):
    self.stateful_count += 1
    return super(AttrTrackingLayer, self).stateful

  @property
  def dynamic(self):
    self.dynamic_count += 1
    return super(AttrTrackingLayer, self).dynamic


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class CacheCorrectnessTest(keras_parameterized.TestCase):

  def test_layer_and_network(self):
    # Top level layer
    network = functional.Functional()

    layer_0 = AttrTrackingLayer()

    sub_network = functional.Functional()
    layer_1 = AttrTrackingLayer(dynamic=True)
    layer_2 = AttrTrackingLayer()
    sub_network.sub_layers = [layer_1, layer_2]

    network.sub_layer = layer_0

    for _ in range(2):
      self.assertEqual(network.dynamic, False)
      self.assertEqual(network.stateful, False)

      # The second pass should be a cache hit.
      self.assertEqual(layer_0.dynamic_count, 1)
      self.assertEqual(layer_0.stateful_count, 1)

    # Mutations of the sub-layer should force recalculation of the network's
    # stateful attribute: mutations bubble up.
    layer_0.stateful = True
    self.assertEqual(network.stateful, True)
    self.assertEqual(layer_0.stateful_count, 2)

    layer_0.stateful = False
    self.assertEqual(network.stateful, False)
    self.assertEqual(layer_0.stateful_count, 3)

    # But changing stateful should not affect dynamic.
    self.assertEqual(network.dynamic, False)
    self.assertEqual(layer_0.dynamic_count, 1)

    network.sub_network = sub_network

    # Adding to the topology should invalidate the cache and reflect in the top
    # level network.
    self.assertEqual(network.dynamic, True)
    self.assertEqual(layer_0.dynamic_count, 2)
    self.assertEqual(layer_1.dynamic_count, 1)

    # Still dynamic, but we need to recompute.
    sub_network.sub_layers.pop()
    self.assertEqual(network.dynamic, True)
    self.assertEqual(layer_0.dynamic_count, 3)
    self.assertEqual(layer_1.dynamic_count, 2)

    # Now that we've removed the dynamic layer deep in the layer hierarchy, we
    # need to make sure that the change bubbles up through all the levels.
    sub_network.sub_layers.pop()
    self.assertEqual(network.dynamic, False)
    self.assertEqual(layer_0.dynamic_count, 4)
    self.assertEqual(layer_1.dynamic_count, 2)

    # Now check with a tracked dict.
    sub_network.sub_layers = {
        "layer_1": layer_1,
        "layer_2": layer_2,
    }

    self.assertEqual(network.dynamic, True)
    self.assertEqual(layer_0.dynamic_count, 5)
    self.assertEqual(layer_1.dynamic_count, 3)

    # In-place assignment should still invalidate the cache.
    sub_network.sub_layers["layer_1"] = layer_1
    self.assertEqual(network.dynamic, True)
    self.assertEqual(layer_0.dynamic_count, 6)
    self.assertEqual(layer_1.dynamic_count, 4)

    sub_network.sub_layers["layer_1"] = None
    for _ in range(2):
      self.assertEqual(network.dynamic, False)
      self.assertEqual(layer_0.dynamic_count, 7)
      self.assertEqual(layer_1.dynamic_count, 4)

    layer_3 = AttrTrackingLayer()
    layer_3.stateful = True

    sub_network.sub_layers = None
    self.assertEqual(network.dynamic, False)
    self.assertEqual(network.stateful, False)

    # Test duplicate layers.
    sub_network.sub_layers = [layer_1, layer_1, layer_1, layer_3]
    self.assertEqual(network.dynamic, True)
    self.assertEqual(network.stateful, True)

    for _ in range(3):
      sub_network.sub_layers.pop()
      self.assertEqual(network.dynamic, True)
      self.assertEqual(network.stateful, False)

    sub_network.sub_layers.pop()
    self.assertEqual(network.dynamic, False)
    self.assertEqual(network.stateful, False)

  def test_compute_output_shape_cache(self):
    # See https://github.com/tensorflow/tensorflow/issues/32029.
    x = input_layer_lib.Input(shape=(None, 32))
    dense = layers.Dense(2)
    y = dense(x)
    network = functional.Functional(x, y, name='dense_network')

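    # Each distinct time dimension should yield a freshly computed output
    # shape rather than a stale cached entry.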
    for i in range(999, 1024):
      self.assertEqual(network.compute_output_shape((1, i, 32)), (1, i, 2))

  def test_2d_inputs_squeezed_to_1d(self):
    input_1d = input_layer_lib.Input(shape=())
    outputs = input_1d * 2.
    net = functional.Functional(input_1d, outputs)

    x = np.ones((10, 1))
    y = net(x)
    self.assertEqual(y.shape.rank, 1)

  def test_1d_inputs_expanded_to_2d(self):
    input_1d = input_layer_lib.Input(shape=(1,))
    outputs = input_1d * 2.
    net = functional.Functional(input_1d, outputs)

    x = np.ones((10,))
    y = net(x)
    self.assertEqual(y.shape.rank, 2)

  def test_training_passed_during_construction(self):

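    # `_call` mirrors MyLayer.call below, so each assertion can state the
    # expected output for the effective `training` value.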
    def _call(inputs, training):
      if training is None:
        return inputs * -1.0
      elif training:
        return inputs
      else:
        return inputs * 0.0

    class MyLayer(base_layer.Layer):

      def call(self, inputs, training=True):
        return _call(inputs, training)

    my_layer = MyLayer()
    x = np.ones((1, 10))

    # Hard-coded `True` value passed during construction is respected.
    inputs = input_layer_lib.Input(10)
    outputs = my_layer(inputs, training=True)
    network = functional.Functional(inputs, outputs)
    self.assertAllEqual(network(x, training=True), _call(x, True))
    self.assertAllEqual(network(x, training=False), _call(x, True))
    self.assertAllEqual(network(x), _call(x, True))

    # Hard-coded `False` value passed during construction is respected.
    inputs = input_layer_lib.Input(10)
    outputs = my_layer(inputs, training=False)
    network = functional.Functional(inputs, outputs)
    self.assertAllEqual(network(x, training=True), _call(x, False))
    self.assertAllEqual(network(x, training=False), _call(x, False))
    self.assertAllEqual(network(x), _call(x, False))

    if context.executing_eagerly():
      # In v2, construction still works when no `training` is specified.
      # When no value is passed during construction, the local default is used.
      inputs = input_layer_lib.Input(10)
      outputs = my_layer(inputs)
      network = functional.Functional(inputs, outputs)
      self.assertAllEqual(network(x, training=True), _call(x, True))
      self.assertAllEqual(network(x, training=False), _call(x, False))
      self.assertAllEqual(network(x), _call(x, True))  # Use local default

    # A `None` value passed positionally during construction is ignored at
    # runtime.
    inputs = input_layer_lib.Input(10)
    outputs = my_layer(inputs, None)
    network = functional.Functional(inputs, outputs)
    self.assertAllEqual(network(x, training=True), _call(x, True))
    self.assertAllEqual(network(x, training=False), _call(x, False))
    if context.executing_eagerly():
      self.assertAllEqual(network(x), _call(x, True))  # Use local default
    else:
      # In v1, training would have defaulted to using the `None` inside the
      # layer if `training` is not passed at runtime.
      self.assertAllEqual(network(x), _call(x, None))

    # A `None` value passed as a kwarg during construction is ignored at
    # runtime.
    inputs = input_layer_lib.Input(10)
    outputs = my_layer(inputs, training=None)
    network = functional.Functional(inputs, outputs)
    self.assertAllEqual(network(x, training=True), _call(x, True))
    self.assertAllEqual(network(x, training=False), _call(x, False))
    if context.executing_eagerly():
      self.assertAllEqual(network(x), _call(x, True))  # Use local default
    else:
      # In v1, training would have defaulted to using the `None` inside the
      # layer if `training` is not passed at runtime.
      self.assertAllEqual(network(x), _call(x, None))


class InputsOutputsErrorTest(keras_parameterized.TestCase):

  @testing_utils.enable_v2_dtype_behavior
  def test_input_error(self):
    inputs = input_layer_lib.Input((10,))
    outputs = layers.Dense(10)(inputs)
    with self.assertRaisesRegex(
        TypeError, "('Keyword argument not understood:', 'input')"):
      models.Model(input=inputs, outputs=outputs)

  @testing_utils.enable_v2_dtype_behavior
  def test_output_error(self):
    inputs = input_layer_lib.Input((10,))
    outputs = layers.Dense(10)(inputs)
    with self.assertRaisesRegex(
        TypeError, "('Keyword argument not understood:', 'output')"):
      models.Model(inputs=inputs, output=outputs)

  def test_input_spec(self):
    if not context.executing_eagerly():
      return
    inputs = input_layer_lib.Input((10,))
    outputs = layers.Dense(10)(inputs)
    model = models.Model(inputs, outputs)
    with self.assertRaisesRegex(
        ValueError, r'.*expected shape=.*'):
      model(np.zeros((3, 11)))

  def test_input_spec_list_of_inputs(self):
    if not context.executing_eagerly():
      return
    input_1 = input_layer_lib.Input((10,), name='1')
    input_2 = input_layer_lib.Input((5,), name='2')
    x = layers.Concatenate()([input_1, input_2])
    outputs = layers.Dense(10)(x)
    model = models.Model([input_1, input_2], outputs)
    with self.assertRaisesRegex(
        ValueError, r'.*expects 2 input.*'):
      model(np.zeros((3, 10)))
    with self.assertRaisesRegex(
        ValueError, r'.*expects 2 input.*'):
      model([np.zeros((3, 10)), np.zeros((3, 5)), np.zeros((3, 10))])
    with self.assertRaisesRegex(
        ValueError, r'.*expected shape=.*'):
      model([np.zeros((3, 10)), np.zeros((3, 6))])

    # Test passing data via a dict keyed by input name.
    with self.assertRaisesRegex(
        ValueError, r'Missing data for input.*'):
      model({'1': np.zeros((3, 10))})
    with self.assertRaisesRegex(
        ValueError, r'.*expected shape=.*'):
      model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))})

  def test_input_spec_dict(self):
    if not context.executing_eagerly():
      return
    input_1 = input_layer_lib.Input((10,))
    input_2 = input_layer_lib.Input((5,))
    x = layers.Concatenate()([input_1, input_2])
    outputs = layers.Dense(10)(x)
    model = models.Model({'1': input_1, '2': input_2}, outputs)
    with self.assertRaisesRegex(
        ValueError, r'Missing data for input.*'):
      model({'1': np.zeros((3, 10))})
    with self.assertRaisesRegex(
        ValueError, r'.*expected shape=.*'):
      model({'1': np.zeros((3, 10)), '2': np.zeros((3, 6))})


class FunctionalSubclassModel(training_lib.Model):

  def __init__(self, *args, **kwargs):
    self.foo = {'foo': 'bar'}  # Make sure users can assign dict attributes
    my_input = input_layer_lib.Input(shape=(16,))
    dense = layers.Dense(32, activation='relu')
    output = dense(my_input)
    outputs = {'output': output}
    super().__init__(inputs=[my_input], outputs=outputs, *args, **kwargs)


class MixinClass(object):

  def __init__(self, foo, **kwargs):
    self._foo = foo
    super().__init__(**kwargs)

  def get_foo(self):
    return self._foo


class SubclassedModel(training_lib.Model):

  def __init__(self, bar, **kwargs):
    self._bar = bar
    super().__init__(**kwargs)

  def get_bar(self):
    return self._bar


class MultipleInheritanceModelTest(keras_parameterized.TestCase):

  def testFunctionalSubclass(self):
    m = FunctionalSubclassModel()
    # Smoke-test the weights and the output shape of the model.
    self.assertLen(m.weights, 2)
    self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])

  def testFunctionalSubclassPreMixin(self):
    class MixedFunctionalSubclassModel(MixinClass, FunctionalSubclassModel):
      pass

    m = MixedFunctionalSubclassModel(foo='123')
    self.assertTrue(m._is_graph_network)
    self.assertLen(m.weights, 2)
    self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
    self.assertEqual(m.get_foo(), '123')

  def testFunctionalSubclassPostMixin(self):
    # Make sure the mixin class is also initialized correctly when the order
    # of the base classes is changed.

    class MixedFunctionalSubclassModel(FunctionalSubclassModel, MixinClass):
      pass

    m = MixedFunctionalSubclassModel(foo='123')
    self.assertTrue(m._is_graph_network)
    self.assertLen(m.weights, 2)
    self.assertEqual(m.outputs[0].shape.as_list(), [None, 32])
    self.assertEqual(m.get_foo(), '123')

  def testSubclassModelPreMixin(self):
    class MixedSubclassModel(MixinClass, SubclassedModel):
      pass

    m = MixedSubclassModel(foo='123', bar='456')
    self.assertFalse(m._is_graph_network)
    self.assertEqual(m.get_foo(), '123')
    self.assertEqual(m.get_bar(), '456')


if __name__ == '__main__':
  test.main()