# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for StatSummarizer Python wrapper."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import test
from tensorflow.tools.graph_transforms import TransformGraph


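# TransformGraph is the Python wrapper around the C++ graph-transform tool: it
# takes a GraphDef plus lists of input and output node names and a list of
# named transforms, and returns the rewritten GraphDef.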
class TransformGraphTest(test.TestCase):

  # This test constructs a graph with a relu op that's not used by the normal
  # inference path, and then checks that the strip_unused_nodes transform
  # removes it as expected.
  def testTransformGraph(self):
    input_graph_def = graph_pb2.GraphDef()

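    # Build a Const node producing the 1x2 float32 tensor [[1, 2]]; the dtype
    # attr carries the element type and the value attr carries the tensor.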
    const_op1 = input_graph_def.node.add()
    const_op1.op = "Const"
    const_op1.name = "const_op1"
    const_op1.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(
        type=dtypes.float32.as_datatype_enum))
    const_op1.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            [1, 2], dtypes.float32, [1, 2])))

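    # A second Const node, shaped the same way, holding [[3, 4]].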
    const_op2 = input_graph_def.node.add()
    const_op2.op = "Const"
    const_op2.name = "const_op2"
    const_op2.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(
        type=dtypes.float32.as_datatype_enum))
    const_op2.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            [3, 4], dtypes.float32, [1, 2])))

    # Create an add that has two constants as inputs.
    add_op = input_graph_def.node.add()
    add_op.op = "Add"
    add_op.attr["T"].CopyFrom(attr_value_pb2.AttrValue(
        type=dtypes.float32.as_datatype_enum))
    add_op.name = "add_op"
    add_op.input.extend(["const_op1", "const_op2"])

    # Create a relu that reads from the add.
    relu_op = input_graph_def.node.add()
    relu_op.op = "Relu"
    relu_op.attr["T"].CopyFrom(attr_value_pb2.AttrValue(
        type=dtypes.float32.as_datatype_enum))
    relu_op.name = "relu_op"
    relu_op.input.extend(["add_op"])

    # We're specifying that add_op is the final output, and so the relu isn't
    # needed.
    input_names = []
    output_names = ["add_op"]
    transforms = ["strip_unused_nodes"]
    transformed_graph_def = TransformGraph(input_graph_def, input_names,
                                           output_names, transforms)

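    # Note: several transforms can be chained in a single call; for example
    # (an illustrative list, not exercised by this test):
    #   transforms = ["strip_unused_nodes",
    #                 "fold_constants(ignore_errors=true)"]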
    # We expect that the relu is no longer present after running the transform.
    for node in transformed_graph_def.node:
      self.assertNotEqual("Relu", node.op)
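
    # A complementary check (a sketch added here, not part of the original
    # test): the add op we named as the output should survive the transform.
    remaining_ops = [node.op for node in transformed_graph_def.node]
    self.assertIn("Add", remaining_ops)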


if __name__ == "__main__":
  test.main()