• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Tests for layer graphs construction & handling."""
16
17from tensorflow.python.keras import keras_parameterized
18from tensorflow.python.keras.engine import base_layer
19from tensorflow.python.keras.engine import node as node_module
20from tensorflow.python.platform import test
21
22
class DummyTensor:
  """Bare-bones tensor stand-in whose only data attribute is `shape`.

  The tests read bookkeeping such as `_keras_history` back off tensors that
  were passed to `node_module.Node(..., outputs=...)`. Instances therefore
  keep an ordinary, writable attribute dict.
  """

  def __init__(self, shape=None):
    # Stored verbatim; tests that do not care about shapes leave it as None.
    self.shape = shape
27
28
class DummyLayer(base_layer.Layer):
  """Featureless `Layer` subclass that the tests attach hand-built nodes to."""
31
32
class NetworkConstructionTest(keras_parameterized.TestCase):
  """Checks how `node_module.Node` wires layers, nodes and tensors together."""

  def test_chained_node_construction(self):
    """Input nodes, then one layer called separately on each input tensor."""
    a = DummyTensor(shape=(None, 32))
    b = DummyTensor(shape=(None, 32))

    a_layer = DummyLayer()
    node = node_module.Node(a_layer, outputs=a)
    self.assertEqual(node.outbound_layer, a_layer)

    # A node built without call args is an input node: it has no inbound
    # layers and reports its own outputs as its input tensors.
    self.assertTrue(node.is_input)
    self.assertListEqual(node.inbound_layers, [])
    self.assertListEqual(node.input_tensors, [a])
    self.assertListEqual(node.input_shapes, [(None, 32)])
    self.assertListEqual(node.output_tensors, [a])
    self.assertListEqual(node.output_shapes, [(None, 32)])

    b_layer = DummyLayer()
    node_module.Node(b_layer, outputs=b)

    # Use the same layer twice; each call records its own inbound node.
    dense = DummyLayer()
    a_2 = DummyTensor()
    node_a = node_module.Node(layer=dense, call_args=(a,), outputs=a_2)
    b_2 = DummyTensor()
    node_b = node_module.Node(layer=dense, call_args=(b,), outputs=b_2)

    # Test the node attributes.
    self.assertFalse(node_a.is_input)
    self.assertFalse(node_b.is_input)
    self.assertEqual(node_a.call_args, (a,))
    self.assertEqual(node_a.call_kwargs, {})
    self.assertEqual(node_a.outputs, a_2)

    # Test the layer wiring. With a single (non-list) first call argument,
    # `inbound_layers` / `input_tensors` are single objects, not lists.
    self.assertLen(dense._inbound_nodes, 2)
    self.assertLen(dense._outbound_nodes, 0)
    self.assertEqual(dense._inbound_nodes, [node_a, node_b])
    self.assertEqual(dense._inbound_nodes[0].inbound_layers, a_layer)
    self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense)
    self.assertEqual(dense._inbound_nodes[1].inbound_layers, b_layer)
    self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense)
    self.assertIs(dense._inbound_nodes[0].input_tensors, a)
    self.assertIs(dense._inbound_nodes[1].input_tensors, b)

  def test_multi_input_node(self):
    """A layer called on a list of two tensors records one two-input node."""
    a = DummyTensor()
    b = DummyTensor()

    dense = DummyLayer()
    a_2 = DummyTensor()
    node_module.Node(layer=dense, call_args=(a,), outputs=a_2)
    b_2 = DummyTensor()
    node_module.Node(layer=dense, call_args=(b,), outputs=b_2)

    concat_layer = DummyLayer()
    merged = DummyTensor()
    node_module.Node(layer=concat_layer, call_args=([a_2, b_2],),
                     outputs=merged)

    merge_layer, merge_node_index, merge_tensor_index = merged._keras_history

    # The history on `merged` must point back at the layer that produced it.
    self.assertIs(merge_layer, concat_layer)
    self.assertEqual(merge_node_index, 0)
    self.assertEqual(merge_tensor_index, 0)

    self.assertLen(merge_layer._inbound_nodes, 1)
    self.assertLen(merge_layer._outbound_nodes, 0)

    self.assertLen(merge_layer._inbound_nodes[0].input_tensors, 2)
    self.assertEqual(merge_layer._inbound_nodes[0].input_tensors, [a_2, b_2])
    self.assertLen(merge_layer._inbound_nodes[0].inbound_layers, 2)

  def test_arg_and_kwarg_mix(self):
    """Tensors reach a node through nested args, plain args and kwargs."""
    input_layer = DummyLayer()
    input_layer_2 = DummyLayer()
    a = DummyTensor()
    node_a = node_module.Node(layer=input_layer, outputs=a)
    b = DummyTensor()
    node_b = node_module.Node(layer=input_layer_2, outputs=b)

    # `arg_2` and `kwarg_x` are never produced by a node; `arg_3` and
    # `kwarg_y` are.
    arg_2 = DummyTensor()
    arg_3 = DummyTensor()
    node_c = node_module.Node(layer=input_layer, outputs=arg_3)

    kwarg_x = DummyTensor()
    kwarg_y = DummyTensor()
    node_d = node_module.Node(layer=input_layer, outputs=kwarg_y)

    merge_layer = DummyLayer()
    merged = DummyTensor()
    node = node_module.Node(layer=merge_layer,
                            call_args=([a, b], arg_2, arg_3),
                            call_kwargs={'x': kwarg_x, 'y': kwarg_y},
                            outputs=merged)

    # Unpack under a distinct name so we can check it against `merge_layer`
    # rather than silently overwriting it.
    history_layer, merge_node_index, merge_tensor_index = (
        merged._keras_history)
    self.assertIs(history_layer, merge_layer)

    # Check the saved call args/kwargs
    self.assertEqual(([a, b], arg_2, arg_3), node.call_args)
    self.assertEqual({'x': kwarg_x, 'y': kwarg_y}, node.call_kwargs)

    # Only the inputs that were produced by input nodes should appear in
    # `keras_inputs` (so `arg_2` and `kwarg_x` are excluded).
    self.assertEqual({a, b, arg_3, kwarg_y}, set(node.keras_inputs))
    self.assertEqual(set(node.parent_nodes), {node_a, node_b, node_c, node_d})

    # Check the layer wirings. `input_layer` produced `a`, `arg_3` and
    # `kwarg_y`, all consumed by the merge node, hence three outbound nodes.
    self.assertEqual(merge_node_index, 0)
    self.assertEqual(merge_tensor_index, 0)
    self.assertLen(merge_layer._inbound_nodes, 1)
    self.assertLen(merge_layer._outbound_nodes, 0)
    self.assertLen(input_layer._outbound_nodes, 3)
    self.assertLen(input_layer_2._outbound_nodes, 1)

    # The 'backwards compatibility' attributes should only check the
    # first call argument
    self.assertLen(merge_layer._inbound_nodes[0].input_tensors, 2)
    self.assertEqual(merge_layer._inbound_nodes[0].input_tensors, [a, b])
    self.assertLen(merge_layer._inbound_nodes[0].inbound_layers, 2)
153
154
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point
  # (`tensorflow.python.platform.test`) when executed as a script.
  test.main()
157