• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for graph_matcher."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.framework.python import ops as contrib_ops
22from tensorflow.contrib.layers.python.layers import initializers
23from tensorflow.contrib.layers.python.layers import layers
24from tensorflow.contrib.quantize.python import graph_matcher
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import init_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import nn_ops
33from tensorflow.python.platform import googletest
34
35
class GraphMatcherTest(test_util.TensorFlowTestCase):
  """Tests for `graph_matcher.GraphMatcher` and its pattern classes.

  Each test builds a small graph in a private `ops.Graph`, describes a
  sub-graph with `OpTypePattern` / `OneofPattern`, and checks what
  `GraphMatcher.match_graph` yields for that graph.
  """

  def test_conv_layer(self):
    """Matches Relu(FusedBatchNorm(Conv2D(inputs, *), *, *, *, *))."""
    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3])

      # The layer is built inside `g` so that its ops land in the graph that
      # is matched below.  Its output is deliberately not returned: returning
      # here would end the test before any assertion is executed.
      with contrib_ops.arg_scope(
          [layers.batch_norm], fused=True, is_training=True, trainable=True):
        layers.convolution(
            inputs,
            num_outputs=16,
            kernel_size=3,
            stride=1,
            padding='VALID',
            activation_fn=nn_ops.relu,
            normalizer_fn=layers.batch_norm,
            normalizer_params={},
            weights_initializer=initializers.xavier_initializer(),
            weights_regularizer=None,
            biases_initializer=init_ops.zeros_initializer(),
            biases_regularizer=None,
            reuse=None,
            trainable=True,
            scope=None)

    inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs')
    conv_pattern = graph_matcher.OpTypePattern(
        'Conv2D', inputs=[inputs_pattern, '*'])
    batch_norm_pattern = graph_matcher.OpTypePattern(
        'FusedBatchNorm', inputs=[conv_pattern, '*', '*', '*', '*'])
    relu_pattern = graph_matcher.OpTypePattern(
        'Relu', name='relu', inputs=[batch_norm_pattern])

    matcher = graph_matcher.GraphMatcher(relu_pattern)
    match_results = list(matcher.match_graph(g))
    self.assertEqual(1, len(match_results))
    match_result = match_results[0]
    # The matched tensor can be looked up by pattern object or pattern name.
    self.assertEqual(match_result.get_tensor(inputs_pattern), inputs)
    self.assertEqual(match_result.get_tensor('inputs'), inputs)

  def test_multiple_outputs(self):
    """Matches an input that is the second output of a multi-output op."""
    #   +         -
    #  / \y0   y1/ \
    # x    split    z
    #       |
    #       y         (nodes are ops; edges are going up)
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[1], name='x')
      y = array_ops.placeholder(dtypes.float32, shape=[2], name='y')
      y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0)
      z = array_ops.placeholder(dtypes.float32, shape=[1], name='z')
      math_ops.add(x, y0)
      math_ops.subtract(y1, z)

    y1_pattern = graph_matcher.OpTypePattern('*')
    minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*'])
    matcher = graph_matcher.GraphMatcher(minus_pattern)

    match_results = list(matcher.match_graph(g))
    self.assertEqual(1, len(match_results))
    match_result = match_results[0]

    # y0 and y1 come from the same op, so the op alone cannot tell them apart;
    # the matched tensor must be y1 specifically.
    self.assertEqual(y0.op, y1.op)
    self.assertEqual(match_result.get_op(y1_pattern), y1.op)
    self.assertEqual(match_result.get_tensor(y1_pattern), y1)

  def test_oneof_type_pattern(self):
    """Matches ops whose type is any of the '|'-separated alternatives."""
    #   +   -
    #  / \ / \
    # x   y   z
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
      y = array_ops.placeholder(dtypes.float32, shape=[], name='y')
      z = array_ops.placeholder(dtypes.float32, shape=[], name='z')
      plus = x + y
      minus = y - z

    add_or_sub_pattern = graph_matcher.OpTypePattern(
        'Add|Sub', inputs=['*', '*'])
    matcher = graph_matcher.GraphMatcher(add_or_sub_pattern)
    matched_ops = [
        match_result.get_op(add_or_sub_pattern)
        for match_result in matcher.match_graph(g)
    ]
    self.assertEqual(matched_ops, [plus.op, minus.op])

  def test_oneof_pattern(self):
    """Matches either alternative of a `OneofPattern`."""
    reshape_pattern = graph_matcher.OpTypePattern('Reshape')
    # Transpose(Slice(Reshape, *, *), *)  or  Transpose(Reshape, *).
    transpose_pattern = graph_matcher.OneofPattern([
        graph_matcher.OpTypePattern(
            'Transpose',
            name='transpose',
            inputs=[
                graph_matcher.OpTypePattern(
                    'Slice', name='slice', inputs=[reshape_pattern, '*', '*']),
                '*'
            ]),
        graph_matcher.OpTypePattern(
            'Transpose', name='transpose', inputs=[reshape_pattern, '*'])
    ])

    matcher = graph_matcher.GraphMatcher(transpose_pattern)

    # Graph without a Slice: only the second alternative applies.
    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=[6])
      reshape = array_ops.reshape(inputs, [2, 3])
      transpose = array_ops.transpose(reshape)
      [match_result] = list(matcher.match_graph(g))
      self.assertEqual(match_result.get_tensor(reshape_pattern), reshape)
      self.assertIsNone(match_result.get_tensor('slice'))
      self.assertEqual(match_result.get_op('transpose'), transpose.op)

    # Graph with a Slice between Reshape and Transpose: first alternative.
    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=[6])
      reshape = array_ops.reshape(inputs, [2, 3])
      slicing = array_ops.slice(reshape, [0, 0], [-1, -1])
      transpose = array_ops.transpose(slicing)
      [match_result] = list(matcher.match_graph(g))
      self.assertEqual(match_result.get_tensor(reshape_pattern), reshape)
      self.assertEqual(match_result.get_tensor('slice'), slicing)
      self.assertEqual(match_result.get_op('transpose'), transpose.op)

  def test_ordered_pattern(self):
    """Checks input-order sensitivity controlled by `ordered_inputs`."""
    #   +            +
    #  / \          / \
    # x   y  and   y   x  should both match when ordered_inputs is False,
    # even when x and y are different operations.
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
      y = constant_op.constant(1.0, dtype=dtypes.float32)
      plus = x + y

    add_pattern_a = graph_matcher.OpTypePattern(
        'Add', inputs=['Const', 'Placeholder'], ordered_inputs=False)
    add_pattern_b = graph_matcher.OpTypePattern(
        'Add', inputs=['Placeholder', 'Const'], ordered_inputs=False)
    add_pattern_fail = graph_matcher.OpTypePattern(
        'Add', inputs=['Const', 'Placeholder'], ordered_inputs=True)

    # Both add_pattern_a and add_pattern_b should match the graph since
    # ordered_inputs was set to False.
    matcher_a = graph_matcher.GraphMatcher(add_pattern_a)
    self.assertEqual([
        match_result.get_op(add_pattern_a)
        for match_result in matcher_a.match_graph(g)
    ], [plus.op])
    matcher_b = graph_matcher.GraphMatcher(add_pattern_b)
    self.assertEqual([
        match_result.get_op(add_pattern_b)
        for match_result in matcher_b.match_graph(g)
    ], [plus.op])

    # But if ordered_inputs is True, the match should fail when the inputs
    # are not listed in the order they appear on the op (Placeholder, Const).
    matcher_fail = graph_matcher.GraphMatcher(add_pattern_fail)
    self.assertEqual(0, len(list(matcher_fail.match_graph(g))))
204
205
if __name__ == '__main__':
  # Invoke the googletest entry point when this file is executed directly.
  googletest.main()
208