# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_matcher."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python import ops as contrib_ops
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest


class GraphMatcherTest(test_util.TensorFlowTestCase):

  def test_conv_layer(self):
    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3])

      # Build a Conv2D -> FusedBatchNorm -> Relu chain for the pattern
      # below to match.
      with contrib_ops.arg_scope(
          [layers.batch_norm], fused=True, is_training=True, trainable=True):
        layers.convolution(
            inputs,
            num_outputs=16,
            kernel_size=3,
            stride=1,
            padding='VALID',
            activation_fn=nn_ops.relu,
            normalizer_fn=layers.batch_norm,
            normalizer_params={},
            weights_initializer=initializers.xavier_initializer(),
            weights_regularizer=None,
            biases_initializer=init_ops.zeros_initializer(),
            biases_regularizer=None,
            reuse=None,
            trainable=True,
            scope=None)

    inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs')
    relu_pattern = graph_matcher.OpTypePattern(
        'Relu',
        name='relu',
        inputs=[
            graph_matcher.OpTypePattern(
                'FusedBatchNorm',
                inputs=[
                    graph_matcher.OpTypePattern(
                        'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*',
                    '*'
                ])
        ])
    matcher = graph_matcher.GraphMatcher(relu_pattern)
    match_results = list(matcher.match_graph(g))
    self.assertEqual(1, len(match_results))
    match_result = match_results[0]
    self.assertEqual(match_result.get_tensor(inputs_pattern), inputs)
    self.assertEqual(match_result.get_tensor('inputs'), inputs)
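
  # A minimal extra check, sketched under the assumption that the
  # graph_matcher API behaves as exercised elsewhere in this file: a leaf
  # OpTypePattern with no `inputs` constraint matches any op of the given
  # type, and a match can be read back either by pattern object or by
  # pattern name.
  def test_bare_op_type_pattern_sketch(self):
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
      relu = nn_ops.relu(x)

    relu_pattern = graph_matcher.OpTypePattern('Relu', name='relu')
    matcher = graph_matcher.GraphMatcher(relu_pattern)
    match_results = list(matcher.match_graph(g))
    self.assertEqual(1, len(match_results))
    # The same binding is reachable via the pattern object or its name.
    self.assertEqual(match_results[0].get_op(relu_pattern), relu.op)
    self.assertEqual(match_results[0].get_op('relu'), relu.op)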

  def test_multiple_outputs(self):
    #   -         +
    #  / \y0   y1/ \
    # x    split    z
    #        |
    #        y      (nodes are ops; edges are going up)
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[1], name='x')
      y = array_ops.placeholder(dtypes.float32, shape=[2], name='y')
      y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0)
      z = array_ops.placeholder(dtypes.float32, shape=[1], name='z')
      math_ops.add(x, y0)
      math_ops.subtract(y1, z)

    y1_pattern = graph_matcher.OpTypePattern('*')
    minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*'])
    matcher = graph_matcher.GraphMatcher(minus_pattern)

    match_results = list(matcher.match_graph(g))
    self.assertEqual(1, len(match_results))
    match_result = match_results[0]

    self.assertEqual(y0.op, y1.op)
    self.assertEqual(match_result.get_op(y1_pattern), y1.op)
    self.assertEqual(match_result.get_tensor(y1_pattern), y1)

  def test_oneof_type_pattern(self):
    #   -   +
    #  / \ / \
    # x   y   z
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
      y = array_ops.placeholder(dtypes.float32, shape=[], name='y')
      z = array_ops.placeholder(dtypes.float32, shape=[], name='z')
      plus = x + y
      minus = y - z

    add_or_sub_pattern = graph_matcher.OpTypePattern(
        'Add|Sub', inputs=['*', '*'])
    matcher = graph_matcher.GraphMatcher(add_or_sub_pattern)
    self.assertEqual([
        match_result.get_op(add_or_sub_pattern)
        for match_result in matcher.match_graph(g)
    ], [plus.op, minus.op])

  def test_oneof_pattern(self):
    reshape_pattern = graph_matcher.OpTypePattern('Reshape')
    transpose_pattern = graph_matcher.OneofPattern([
        graph_matcher.OpTypePattern(
            'Transpose',
            name='transpose',
            inputs=[
                graph_matcher.OpTypePattern(
                    'Slice', name='slice', inputs=[reshape_pattern, '*', '*']),
                '*'
            ]),
        graph_matcher.OpTypePattern(
            'Transpose', name='transpose', inputs=[reshape_pattern, '*'])
    ])

    matcher = graph_matcher.GraphMatcher(transpose_pattern)

    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=[6])
      reshape = array_ops.reshape(inputs, [2, 3])
      transpose = array_ops.transpose(reshape)
      [match_result] = list(matcher.match_graph(g))
      self.assertEqual(match_result.get_tensor(reshape_pattern), reshape)
      self.assertEqual(match_result.get_tensor('slice'), None)
      self.assertEqual(match_result.get_op('transpose'), transpose.op)

    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=[6])
      reshape = array_ops.reshape(inputs, [2, 3])
      slicing = array_ops.slice(reshape, [0, 0], [-1, -1])
      transpose = array_ops.transpose(slicing)
      [match_result] = list(matcher.match_graph(g))
      self.assertEqual(match_result.get_tensor(reshape_pattern), reshape)
      self.assertEqual(match_result.get_tensor('slice'), slicing)
      self.assertEqual(match_result.get_op('transpose'), transpose.op)
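
  # A sketch mirroring test_oneof_type_pattern, assuming OneofPattern (as
  # used above) also branches across different op types: each branch is a
  # plain OpTypePattern, and only the branch that actually matched is bound
  # in the result, so it can be queried with get_op.
  def test_oneof_pattern_across_op_types_sketch(self):
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
      y = array_ops.placeholder(dtypes.float32, shape=[], name='y')
      plus = x + y
      minus = x - y

    add_pattern = graph_matcher.OpTypePattern('Add')
    sub_pattern = graph_matcher.OpTypePattern('Sub')
    add_or_sub_pattern = graph_matcher.OneofPattern([add_pattern, sub_pattern])
    matcher = graph_matcher.GraphMatcher(add_or_sub_pattern)
    # Matches are yielded in graph (op creation) order, as in
    # test_oneof_type_pattern above.
    match_results = list(matcher.match_graph(g))
    self.assertEqual(2, len(match_results))
    self.assertEqual(match_results[0].get_op(add_pattern), plus.op)
    self.assertEqual(match_results[1].get_op(sub_pattern), minus.op)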

  def test_ordered_pattern(self):
    #   +           +
    #  / \   and   / \
    # x   y       y   x
    # should both match when ordered_inputs is False, even when x and y are
    # produced by different operations.
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
      y = constant_op.constant(1.0, dtype=dtypes.float32)
      plus = x + y

    add_pattern_a = graph_matcher.OpTypePattern(
        'Add', inputs=['Const', 'Placeholder'], ordered_inputs=False)
    add_pattern_b = graph_matcher.OpTypePattern(
        'Add', inputs=['Placeholder', 'Const'], ordered_inputs=False)
    add_pattern_fail = graph_matcher.OpTypePattern(
        'Add', inputs=['Const', 'Placeholder'], ordered_inputs=True)
    # Both add_pattern_a and add_pattern_b should match the graph since
    # ordered_inputs was set to False.
    matcher_a = graph_matcher.GraphMatcher(add_pattern_a)
    self.assertEqual([
        match_result.get_op(add_pattern_a)
        for match_result in matcher_a.match_graph(g)
    ], [plus.op])
    matcher_b = graph_matcher.GraphMatcher(add_pattern_b)
    self.assertEqual([
        match_result.get_op(add_pattern_b)
        for match_result in matcher_b.match_graph(g)
    ], [plus.op])
    # But when ordered_inputs is True, the match fails unless the inputs are
    # listed in the right order.
    matcher_fail = graph_matcher.GraphMatcher(add_pattern_fail)
    self.assertEqual(
        len([
            match_result.get_op(add_pattern_fail)
            for match_result in matcher_fail.match_graph(g)
        ]), 0)


if __name__ == '__main__':
  googletest.main()