# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for layer graphs construction & handling."""

from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.platform import test


class DummyTensor:
  """Minimal tensor stand-in that only stores a `shape` attribute."""

  def __init__(self, shape=None):
    self.shape = shape


class DummyLayer(base_layer.Layer):
  """Plain `Layer` subclass used as a node endpoint; adds no behavior."""


class NetworkConstructionTest(keras_parameterized.TestCase):
  """Checks the attributes and layer wiring recorded by `Node` objects."""

  def test_chained_node_construction(self):
    """Builds two input nodes and calls one shared layer on each output."""
    shape = (None, 32)
    input_a = DummyTensor(shape=shape)
    input_b = DummyTensor(shape=shape)

    # A node created with only a layer and outputs is expected to report
    # itself as an input node whose inputs mirror its outputs.
    source_a = DummyLayer()
    input_node = node_module.Node(source_a, outputs=input_a)
    self.assertEqual(input_node.outbound_layer, source_a)

    self.assertTrue(input_node.is_input)
    self.assertListEqual(input_node.inbound_layers, [])
    self.assertListEqual(input_node.input_tensors, [input_a])
    self.assertListEqual(input_node.input_shapes, [shape])
    self.assertListEqual(input_node.output_tensors, [input_a])
    self.assertListEqual(input_node.output_shapes, [shape])

    source_b = DummyLayer()
    node_module.Node(source_b, outputs=input_b)

    # Call the same layer twice, once per input tensor. The creation order
    # matters: it fixes the order of `shared._inbound_nodes` checked below.
    shared = DummyLayer()
    out_a = DummyTensor()
    call_a = node_module.Node(
        layer=shared, call_args=(input_a,), outputs=out_a)
    out_b = DummyTensor()
    call_b = node_module.Node(
        layer=shared, call_args=(input_b,), outputs=out_b)

    # Node attributes.
    for call_node in (call_a, call_b):
      self.assertFalse(call_node.is_input)
    self.assertEqual(call_a.call_args, (input_a,))
    self.assertEqual(call_a.call_kwargs, {})
    self.assertEqual(call_a.outputs, out_a)

    # Layer wiring.
    inbound = shared._inbound_nodes
    self.assertLen(inbound, 2)
    self.assertLen(shared._outbound_nodes, 0)
    self.assertEqual(inbound, [call_a, call_b])

    wiring = (
        (inbound[0], source_a, input_a),
        (inbound[1], source_b, input_b),
    )
    for recorded, source, _ in wiring:
      self.assertEqual(recorded.inbound_layers, source)
      self.assertEqual(recorded.outbound_layer, shared)
    for recorded, _, tensor in wiring:
      self.assertIs(recorded.input_tensors, tensor)

  def test_multi_input_node(self):
    """Passes the outputs of two layer calls to one node as a list."""
    first_in = DummyTensor()
    second_in = DummyTensor()

    shared = DummyLayer()
    first_out = DummyTensor()
    node_module.Node(layer=shared, call_args=(first_in,), outputs=first_out)
    second_out = DummyTensor()
    node_module.Node(layer=shared, call_args=(second_in,), outputs=second_out)

    combiner = DummyLayer()
    merged = DummyTensor()
    node_module.Node(
        layer=combiner, call_args=([first_out, second_out],), outputs=merged)

    # The history attached to the output tensor identifies the layer, the
    # node index and the tensor index that produced it.
    history_layer, node_index, tensor_index = merged._keras_history

    self.assertEqual(node_index, 0)
    self.assertEqual(tensor_index, 0)

    self.assertLen(history_layer._inbound_nodes, 1)
    self.assertLen(history_layer._outbound_nodes, 0)

    merge_node = history_layer._inbound_nodes[0]
    self.assertLen(merge_node.input_tensors, 2)
    self.assertEqual(merge_node.input_tensors, [first_out, second_out])
    self.assertLen(merge_node.inbound_layers, 2)

  def test_arg_and_kwarg_mix(self):
    """Mixes node-produced and plain tensors across call args and kwargs."""
    producer = DummyLayer()
    other_producer = DummyLayer()

    tensor_a = DummyTensor()
    node_a = node_module.Node(layer=producer, outputs=tensor_a)
    tensor_b = DummyTensor()
    node_b = node_module.Node(layer=other_producer, outputs=tensor_b)

    # `plain_arg` is never the output of a node; `produced_arg` is.
    plain_arg = DummyTensor()
    produced_arg = DummyTensor()
    node_c = node_module.Node(layer=producer, outputs=produced_arg)

    # Likewise for the keyword arguments.
    plain_kwarg = DummyTensor()
    produced_kwarg = DummyTensor()
    node_d = node_module.Node(layer=producer, outputs=produced_kwarg)

    merge_layer = DummyLayer()
    merged = DummyTensor()
    node = node_module.Node(
        layer=merge_layer,
        call_args=([tensor_a, tensor_b], plain_arg, produced_arg),
        call_kwargs={'x': plain_kwarg, 'y': produced_kwarg},
        outputs=merged)

    history_layer, node_index, tensor_index = merged._keras_history

    # Check the saved call args/kwargs against freshly built expectations.
    self.assertEqual(
        ([tensor_a, tensor_b], plain_arg, produced_arg), node.call_args)
    self.assertEqual(
        {'x': plain_kwarg, 'y': produced_kwarg}, node.call_kwargs)

    # Only the inputs that were produced by input nodes should appear in
    # keras_inputs, and only their producing nodes in parent_nodes.
    self.assertEqual(
        {tensor_a, tensor_b, produced_arg, produced_kwarg},
        set(node.keras_inputs))
    self.assertEqual(set(node.parent_nodes), {node_a, node_b, node_c, node_d})

    # Check the layer wirings.
    self.assertEqual(node_index, 0)
    self.assertEqual(tensor_index, 0)
    self.assertLen(history_layer._inbound_nodes, 1)
    self.assertLen(history_layer._outbound_nodes, 0)
    self.assertLen(producer._outbound_nodes, 3)
    self.assertLen(other_producer._outbound_nodes, 1)

    # The 'backwards compatibility' attributes should only check the
    # first call argument.
    merge_node = history_layer._inbound_nodes[0]
    self.assertLen(merge_node.input_tensors, 2)
    self.assertEqual(merge_node.input_tensors, [tensor_a, tensor_b])
    self.assertLen(merge_node.inbound_layers, 2)


if __name__ == '__main__':
  test.main()