/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

namespace {

Status TypeForPlaceholder(const TransformFuncContext& context,
                          const string& node_name, DataType* result) {
  // If we don't find anything else, return float.
  *result = DT_FLOAT;

  // Check to see if we have been given a default for all placeholders.
  if (context.params.count("type")) {
    if (context.params.at("type").size() != 1) {
      return errors::InvalidArgument(
          "You must pass no more than one default 'type' to "
          "strip_unused_nodes");
    }
    const string& type_string = context.params.at("type")[0];
    if (!DataTypeFromString(type_string, result)) {
      return errors::InvalidArgument("Couldn't understand type argument '",
                                     type_string, "'");
    }
  }

  // See if there's a particular type specified for this placeholder.
  if (context.params.count("name") || context.params.count("type_for_name")) {
    if (!context.params.count("name") ||
        !context.params.count("type_for_name") ||
        (context.params.at("type_for_name").size() !=
         context.params.at("name").size())) {
      return errors::InvalidArgument(
          "You must pass a 'type_for_name' arg for every 'name', e.g. "
          "strip_unused_nodes(name=foo, type_for_name=float, name=bar, "
          "type_for_name=quint8)");
    }
    const int name_count = context.params.at("name").size();
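    // If the same node name is listed more than once, the last matching
    // (name, type_for_name) pair wins, since each match overwrites *result.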
    for (int i = 0; i < name_count; ++i) {
      if (context.params.at("name")[i] == node_name) {
        const string& type_string = context.params.at("type_for_name")[i];
        if (!DataTypeFromString(type_string, result)) {
          return errors::InvalidArgument("Couldn't understand type argument '",
                                         type_string, "'");
        }
      }
    }
  }

  return OkStatus();
}

Status ShapeForPlaceholder(const TransformFuncContext& context,
                           const string& node_name, TensorShape* result) {
  // If we don't find anything else, return scalar.
  *result = {};

  // Check to see if we have been given a default for all placeholders.
  if (context.params.count("shape")) {
    if (context.params.at("shape").size() != 1) {
      return errors::InvalidArgument(
          "You must pass no more than one default 'shape' to "
          "strip_unused_nodes");
    }
    const string& shape_string = context.params.at("shape")[0];
    TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result));
  }

  // See if there's a particular shape specified for this placeholder.
  if (context.params.count("name") || context.params.count("shape_for_name")) {
    if (!context.params.count("name") ||
        !context.params.count("shape_for_name") ||
        (context.params.at("shape_for_name").size() !=
         context.params.at("name").size())) {
      return errors::InvalidArgument(
          "You must pass a 'shape_for_name' arg for every 'name', e.g. "
          "strip_unused_nodes(name=foo, shape_for_name=\"2,2,1\", name=bar, "
          "shape_for_name=\"1\")");
    }
    const int name_count = context.params.at("name").size();
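    // As with types, the last matching (name, shape_for_name) pair wins.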
    for (int i = 0; i < name_count; ++i) {
      if (context.params.at("name")[i] == node_name) {
        const string& shape_string = context.params.at("shape_for_name")[i];
        TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result));
      }
    }
  }

  return OkStatus();
}
}  // namespace

// Delete any nodes that don't contribute to the inference result.
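// Nodes named in context.input_names are kept and rewritten as Placeholders
// (unless they already are one), with dtype and shape taken from the
// 'type'/'shape' and 'type_for_name'/'shape_for_name' transform parameters;
// everything not reachable from context.output_names is removed.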
Status StripUnusedNodes(const GraphDef& input_graph_def,
                        const TransformFuncContext& context,
                        GraphDef* output_graph_def) {
  std::set<string> required_nodes;
  std::set<string> input_nodes;
  for (const string& input : context.input_names) {
    required_nodes.insert(NodeNameFromInput(input));
    input_nodes.insert(NodeNameFromInput(input));
  }
  for (const string& output : context.output_names) {
    required_nodes.insert(output);
  }

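  // Index the graph by node name so the backwards traversal below can
  // resolve each input string to its defining node.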
  std::map<string, const NodeDef*> node_lookup;
  MapNamesToNodes(input_graph_def, &node_lookup);

  std::vector<string> current_inputs;
  for (const string& output_name : context.output_names) {
    current_inputs.push_back(NodeNameFromInput(output_name));
  }

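  // Breadth-first traversal from the outputs back towards the inputs,
  // marking every visited node as required. Declared input nodes act as
  // boundaries: their own inputs are deliberately not followed.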
  while (!current_inputs.empty()) {
    std::set<string> next_inputs;
    for (const string& current_input : current_inputs) {
      required_nodes.insert(current_input);
      if (input_nodes.count(current_input)) {
        continue;
      }
      if (!node_lookup.count(current_input)) {
        return errors::InvalidArgument("Input node ", current_input,
                                       " not found in graph");
      }
      const NodeDef* current_node = node_lookup[current_input];
      for (const string& input_name : current_node->input()) {
        string input_node_name = NodeNameFromInput(input_name);
        if (!required_nodes.count(input_node_name)) {
          next_inputs.insert(input_node_name);
        }
      }
    }
    current_inputs =
        std::vector<string>(next_inputs.begin(), next_inputs.end());
  }

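  // Drop every node that was never marked as required.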
  GraphDef filtered_graph_def;
  FilterGraphDef(input_graph_def,
                 [&](const NodeDef& node) {
                   return required_nodes.count(node.name()) > 0;
                 },
                 &filtered_graph_def);

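  // Copy the surviving nodes to the output, converting each declared input
  // node into a Placeholder with the requested dtype and shape (nodes that
  // are already Placeholders are copied through unchanged).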
  output_graph_def->Clear();
  for (const NodeDef& node : filtered_graph_def.node()) {
    if (input_nodes.count(node.name())) {
      NodeDef placeholder_node;
      if (node.op() == "Placeholder") {
        placeholder_node = node;
      } else {
        placeholder_node.set_op("Placeholder");
        placeholder_node.set_name(node.name());
        DataType type;
        TF_RETURN_IF_ERROR(TypeForPlaceholder(context, node.name(), &type));
        TensorShape shape;
        TF_RETURN_IF_ERROR(ShapeForPlaceholder(context, node.name(), &shape));
        SetNodeAttr("dtype", type, &placeholder_node);
        SetNodeAttr("shape", shape, &placeholder_node);
      }
      *(output_graph_def->mutable_node()->Add()) = placeholder_node;
    } else {
      *(output_graph_def->mutable_node()->Add()) = node;
    }
  }
  return OkStatus();
}

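// Example invocation through the transform_graph tool (the graph paths and
// node names below are illustrative placeholders, not part of this file):
//
//   bazel run tensorflow/tools/graph_transforms:transform_graph -- \
//     --in_graph=model.pb --out_graph=stripped.pb \
//     --inputs=input --outputs=softmax \
//     --transforms='strip_unused_nodes(type=float, shape="1,299,299,3")'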
REGISTER_GRAPH_TRANSFORM("strip_unused_nodes", StripUnusedNodes);

}  // namespace graph_transforms
}  // namespace tensorflow