# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper tf_optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import item as gitem
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class PyWrapOptimizeGraphTest(test.TestCase):
  """Runs tf_optimizer.OptimizeGraph on small graphs and checks the result."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Make sure arguments can be passed correctly."""
    a = constant_op.constant(10, name='a')
    b = constant_op.constant(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    # Adding 'd' to the TRAIN_OP collection makes it a fetch node.
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    rewriter_config.optimizers.append('constfold')
    # NOTE(review): a negative value presumably disables the "graph too small,
    # skip optimization" threshold -- confirm against RewriterConfig.
    rewriter_config.min_graph_nodes = -1

    graph = tf_optimizer.OptimizeGraph(config, mg)

    # The whole graph is expected to collapse into the single fetch node 'd'.
    self.assertEqual(len(graph.node), 1)
    self.assertCountEqual([node.name for node in graph.node], ['d'])

  @test_util.run_v1_only('b/120545219')
  def testKeepNodes(self):
    """Checks which nodes are still present after optimizing the graph.

    The expected survivors are the variable (plus its initial value and assign
    ops), the constant added to a collection, the constant created under the
    '_grappler_do_not_remove' attribute scope, and the fetch node. The two
    constants feeding the matmul are not expected to remain.
    """
    g = ops.Graph()
    with g.as_default():
      a1 = variables.VariableV1(
          1.0)  # Must be preserved since it's in the collection 'variables'.
      a2 = constant_op.constant(0, shape=[50, 50], name='keep')
      ops.add_to_collection('a2', a2)  # Explicitly add to collection.
      with g._attr_scope(
          {'_grappler_do_not_remove': attr_value_pb2.AttrValue(b=True)}):
        a3 = constant_op.constant(0, name='keep2')
      b = constant_op.constant(1, shape=[100, 10])
      c = constant_op.constant(0, shape=[10, 30])
      d = math_ops.matmul(b, c)
      ops.add_to_collection('train_op', d)  # d is the fetch node.

    # Optimize the graph.
    mg = meta_graph.create_meta_graph_def(graph=g)
    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    rewriter_config.min_graph_nodes = -1
    optimized_graph = tf_optimizer.OptimizeGraph(config, mg)

    # Check that the nodes referenced in various collections have been preserved
    optimized_graph_nodes = [node.name for node in optimized_graph.node]
    expected_nodes = [
        d.op.name, a1.op.name, a2.op.name, a3.op.name, 'Variable/initial_value',
        'Variable/Assign'
    ]
    # Same count plus membership: nothing extra survived, nothing is missing.
    self.assertEqual(len(optimized_graph_nodes), len(expected_nodes))
    self.assertAllInSet(optimized_graph_nodes, expected_nodes)

  @test_util.run_v1_only('b/120545219')
  def testLoops(self):
    """Checks op properties of a while-loop output after optimization.

    Builds a while loop that appends the counter to a growing buffer, optimizes
    the graph, and verifies that the properties reported for the loop output
    equal those reported for a tensor derived from it with ones_like.
    """
    g = ops.Graph()
    with g.as_default():

      def _Cond(_, counter):
        # 'end' is the placeholder defined below; it is resolved when
        # while_loop invokes this callback.
        return counter < end

      def _Body(buf, counter):
        buf = array_ops.concat([buf, [counter]], 0)
        counter += 1
        return [buf, counter]

      start = array_ops.placeholder(shape=[], dtype=dtypes.int32)
      end = array_ops.placeholder(shape=[], dtype=dtypes.int32)
      init_buf = array_ops.zeros(shape=[0], dtype=dtypes.int32)
      loop_vars = [init_buf, start]
      # The buffer grows on every iteration, so its length is left unknown.
      shape_inv = [
          tensor_shape.TensorShape([None]),
          tensor_shape.TensorShape([])
      ]
      buf, _ = control_flow_ops.while_loop(_Cond, _Body, loop_vars, shape_inv)

      f = -array_ops.ones_like(buf, optimize=False)
      buf_shape = array_ops.shape(buf)
      f_shape = array_ops.shape(f)
      # Both shape ops are registered as fetch nodes.
      ops.add_to_collection('train_op', buf_shape)
      ops.add_to_collection('train_op', f_shape)

    # Optimize the graph.
    mg = meta_graph.create_meta_graph_def(graph=g)
    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    rewriter_config.min_graph_nodes = -1
    optimized_graph = tf_optimizer.OptimizeGraph(config, mg)
    mg.graph_def.CopyFrom(optimized_graph)

    # Check that the op properties reported for the loop output and for the
    # tensor derived from it are identical in the optimized graph.
    item = gitem.Item(mg)
    props = item.GetOpProperties()
    buf_prop = props[buf.op.name]
    f_prop = props[f.op.name]
    self.assertEqual(buf_prop, f_prop)


if __name__ == '__main__':
  test.main()