# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Removes unneeded nodes from a GraphDef file.

This script is designed to help streamline models, by taking the input and
output nodes that will be used by an application and figuring out the smallest
set of operations that are required to run for those arguments. The resulting
minimal graph is then saved out.

The advantages of running this script are:
 - You may be able to shrink the file size.
 - Operations that are unsupported on your platform but still present can be
   safely removed.
The resulting graph may not be as flexible as the original though, since any
input nodes that weren't explicitly mentioned may not be accessible any more.

An example of command-line usage is:
bazel build tensorflow/python/tools:strip_unused && \
bazel-bin/tensorflow/python/tools/strip_unused \
--input_graph=some_graph_def.pb \
--output_graph=/tmp/stripped_graph.pb \
--input_node_names=input0 \
--output_node_names=softmax

You can also look at strip_unused_test.py for an example of how to use it.

"""
38 39""" 40from __future__ import absolute_import 41from __future__ import division 42from __future__ import print_function 43 44import argparse 45import sys 46 47from tensorflow.python.framework import dtypes 48from tensorflow.python.platform import app 49from tensorflow.python.tools import strip_unused_lib 50 51FLAGS = None 52 53 54def main(unused_args): 55 strip_unused_lib.strip_unused_from_files(FLAGS.input_graph, 56 FLAGS.input_binary, 57 FLAGS.output_graph, 58 FLAGS.output_binary, 59 FLAGS.input_node_names, 60 FLAGS.output_node_names, 61 FLAGS.placeholder_type_enum) 62 63 64if __name__ == '__main__': 65 parser = argparse.ArgumentParser() 66 parser.register('type', 'bool', lambda v: v.lower() == 'true') 67 parser.add_argument( 68 '--input_graph', 69 type=str, 70 default='', 71 help='TensorFlow \'GraphDef\' file to load.') 72 parser.add_argument( 73 '--input_binary', 74 nargs='?', 75 const=True, 76 type='bool', 77 default=False, 78 help='Whether the input files are in binary format.') 79 parser.add_argument( 80 '--output_graph', 81 type=str, 82 default='', 83 help='Output \'GraphDef\' file name.') 84 parser.add_argument( 85 '--output_binary', 86 nargs='?', 87 const=True, 88 type='bool', 89 default=True, 90 help='Whether to write a binary format graph.') 91 parser.add_argument( 92 '--input_node_names', 93 type=str, 94 default='', 95 help='The name of the input nodes, comma separated.') 96 parser.add_argument( 97 '--output_node_names', 98 type=str, 99 default='', 100 help='The name of the output nodes, comma separated.') 101 parser.add_argument( 102 '--placeholder_type_enum', 103 type=int, 104 default=dtypes.float32.as_datatype_enum, 105 help='The AttrValue enum to use for placeholders.') 106 FLAGS, unparsed = parser.parse_known_args() 107 app.run(main=main, argv=[sys.argv[0]] + unparsed) 108