# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper of items."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import item
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class ItemTest(test.TestCase):
  """Exercises `item.Item`, built from a MetaGraphDef."""

  def testInvalidItem(self):
    """An Item cannot be built from a metagraph with no TRAIN_OP collection."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(10)
      b = constant_op.constant(20)
      c = a + b  # pylint: disable=unused-variable
      mg = meta_graph.create_meta_graph_def(graph=g)

    # The train op isn't specified: this should raise an InvalidArgumentError
    # exception.
    with self.assertRaises(errors_impl.InvalidArgumentError):
      item.Item(mg)

  def testImportantOps(self):
    """With `add` registered as the train op, all three ops are important."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(10)
      b = constant_op.constant(20)
      c = a + b
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      op_list = grappler_item.IdentifyImportantOps()
      # Order is not part of the contract; compare as a multiset.
      self.assertCountEqual(['Const', 'Const_1', 'add'], op_list)

  def testOpProperties(self):
    """GetOpProperties reports the output tensor properties of each node."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(10)
      b = constant_op.constant(20)
      c = a + b
      z = control_flow_ops.no_op()
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      op_properties = grappler_item.GetOpProperties()

      # The NoOp node has no outputs; every other node in this model has
      # exactly one scalar int32 output.
      for node in grappler_item.metagraph.graph_def.node:
        node_prop = op_properties[node.name]

        if node.name == z.name:
          self.assertEqual(0, len(node_prop))
        else:
          self.assertEqual(1, len(node_prop))
          self.assertEqual(dtypes.int32, node_prop[0].dtype)
          self.assertEqual(tensor_shape.TensorShape([]), node_prop[0].shape)

  def testUpdates(self):
    """tf_item changes only when the metagraph content actually changes.

    Reading `tf_item` twice without edits yields equal values. Changing
    the device placement yields a different value. Re-assigning the
    placement the nodes already have yields a value equal to the previous
    one.
    """
    with ops.Graph().as_default() as g:
      a = constant_op.constant(10)
      b = constant_op.constant(20)
      c = a + b
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)

    initial_tf_item = grappler_item.tf_item
    no_change_tf_item = grappler_item.tf_item
    self.assertEqual(initial_tf_item, no_change_tf_item)

    # Modify the placement.
    for node in grappler_item.metagraph.graph_def.node:
      node.device = '/cpu:0'
    new_tf_item = grappler_item.tf_item
    self.assertNotEqual(initial_tf_item, new_tf_item)

    # Assign the same placement.
    for node in grappler_item.metagraph.graph_def.node:
      node.device = '/cpu:0'
    newest_tf_item = grappler_item.tf_item
    self.assertEqual(new_tf_item, newest_tf_item)

  # NOTE(review): presumably v1-only because VariableV1/ref_identity rely on
  # ref-typed variables; see the linked bug to confirm.
  @test_util.run_v1_only('b/120545219')
  def testColocationConstraints(self):
    """A variable and the ops writing through its ref share one group."""
    with ops.Graph().as_default() as g:
      c = constant_op.constant([10])
      v = variables.VariableV1([3], dtype=dtypes.int32)
      i = gen_array_ops.ref_identity(v)
      a = state_ops.assign(i, c)
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(a)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      groups = grappler_item.GetColocationGroups()
      self.assertEqual(1, len(groups))
      # Membership order within a group is not significant.
      self.assertCountEqual(
          groups[0], ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign'])


if __name__ == '__main__':
  test.main()