• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
17 
#include <algorithm>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
26 
27 #include "tensorflow/core/common_runtime/constant_folding.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/shape_refiner.h"
30 #include "tensorflow/core/graph/node_builder.h"
31 #include "tensorflow/core/graph/subgraph.h"
32 #include "tensorflow/core/lib/core/stringpiece.h"
33 #include "tensorflow/core/lib/strings/numbers.h"
34 #include "tensorflow/core/platform/init_main.h"
35 #include "tensorflow/core/public/session.h"
36 #include "tensorflow/tools/graph_transforms/transform_utils.h"
37 
38 namespace tensorflow {
39 namespace graph_transforms {
40 namespace {
41 using StringPieceSet = std::unordered_set<StringPiece, StringPieceHasher>;
42 template <typename T>
43 using StringPieceMap = std::unordered_map<StringPiece, T, StringPieceHasher>;
44 }  // namespace
45 
ReplaceSendRecvs(const GraphDef & original_graph_def,const GraphDef & rewritten_graph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,GraphDef * output_graph_def)46 Status ReplaceSendRecvs(const GraphDef& original_graph_def,
47                         const GraphDef& rewritten_graph_def,
48                         const std::vector<string>& inputs,
49                         const std::vector<string>& outputs,
50                         GraphDef* output_graph_def) {
51   // recv_node_names serves as a string storage for recv node names.
52   std::vector<string> recv_node_names(inputs.size());
53   StringPieceMap<TensorId> recv_node_map;
54   StringPieceSet input_nodes;
55   for (int i = 0; i < inputs.size(); ++i) {
56     // RewriteGraphForExecution adds a recv node for each input edge. We assume
57     // here that adding such recv node did not fail. For example, the original
58     // graph did not already have a node with the name for the new added recv
59     // node.
60     TensorId id = ParseTensorName(inputs[i]);
61     input_nodes.insert(id.first);
62     string& recv_node_name = recv_node_names[i];
63     recv_node_name = strings::StrCat("_recv_", id.first, "_", id.second);
64     recv_node_map.emplace(recv_node_name, id);
65   }
66 
67   StringPieceMap<const NodeDef*> original_map;
68   for (const NodeDef& node : original_graph_def.node()) {
69     original_map.emplace(node.name(), &node);
70   }
71 
72   for (const NodeDef& node : rewritten_graph_def.node()) {
73     if ((node.op() == "_Send") || (node.op() == "_Recv")) {
74       // If the op is a Send or Recv that wasn't in the original, skip it.
75       if (original_map.count(node.name()) == 0) {
76         continue;
77       }
78     }
79 
80     NodeDef* new_node = output_graph_def->add_node();
81     new_node->MergeFrom(node);
82     for (int i = 0; i < new_node->input_size(); ++i) {
83       string& input = *new_node->mutable_input(i);
84       TensorId id = ParseTensorName(input);
85       const auto iter = recv_node_map.find(id.first);
86       if (iter != recv_node_map.end()) {
87         // The node being substituted is a Recv node, and it has only one
88         // output. If this input is not a control input, then replace the input
89         // with the mapped value. Otherwise, replace the node name only.
90         if (id.second != Graph::kControlSlot) {
91           CHECK_EQ(id.second, 0);
92           input = iter->second.ToString();
93         } else {
94           id.first = iter->second.first;
95           input = id.ToString();
96         }
97       }
98     }
99 
100     // RewriteGraphForExecution() did not remove this input node. Remove this
101     // node name from input_nodes so that a duplicate does not get added to the
102     // output_graph_def.
103     auto iter = input_nodes.find(new_node->name());
104     if (iter != input_nodes.end()) {
105       input_nodes.erase(iter);
106     }
107   }
108 
109   // Some input nodes are removed in rewrite_graph_def. Add those nodes to
110   // output_graph_def.
111   for (StringPiece name : input_nodes) {
112     const NodeDef& removed_node = *CHECK_NOTNULL(original_map[name]);
113     output_graph_def->add_node()->MergeFrom(removed_node);
114   }
115 
116   return OkStatus();
117 }
118 
RewriteInputsAsPlaceholders(const TransformFuncContext & context,GraphDef * graph_def)119 Status RewriteInputsAsPlaceholders(const TransformFuncContext& context,
120                                    GraphDef* graph_def) {
121   std::unordered_set<string> input_names;
122   for (const string& input_name : context.input_names) {
123     input_names.emplace(ParseTensorName(input_name).first);
124   }
125 
126   for (NodeDef& node : *graph_def->mutable_node()) {
127     if (input_names.find(node.name()) == input_names.end()) {
128       continue;
129     }
130     if (node.op() == "PlaceholderWithDefault") {
131       node.set_op("Placeholder");
132       node.clear_input();
133     } else if (node.op() != "Placeholder") {
134       return errors::InvalidArgument(
135           "Input '", node.name(),
136           "' was expected to be a Placeholder or PlaceholderWithDefault op, "
137           "but was ",
138           node.op());
139     }
140   }
141   return OkStatus();
142 }
143 
RemoveUnusedNodes(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)144 Status RemoveUnusedNodes(const GraphDef& input_graph_def,
145                          const TransformFuncContext& context,
146                          GraphDef* output_graph_def) {
147   StringPieceMap<const NodeDef*> node_map;
148   for (const NodeDef& node : input_graph_def.node()) {
149     node_map.emplace(node.name(), &node);
150   }
151 
152   std::unordered_set<TensorId, TensorId::Hasher> input_names;
153   for (const string& input : context.input_names) {
154     input_names.insert(ParseTensorName(input));
155   }
156   StringPieceSet used_nodes;
157   StringPieceSet current_nodes;
158   for (const string& name : context.output_names) {
159     TensorId id = ParseTensorName(name);
160     used_nodes.insert(id.first);
161     current_nodes.insert(id.first);
162   }
163   while (!current_nodes.empty()) {
164     StringPieceSet next_nodes;
165     for (StringPiece node_name : current_nodes) {
166       if (node_map.count(node_name) == 0) {
167         LOG(ERROR) << "Bad graph structure, no node named '" << node_name
168                    << "' found for input lookup";
169         return errors::InvalidArgument("Bad graph structure, no node named '",
170                                        node_name, "' found for input lookup");
171       }
172       const NodeDef& node = *(node_map[node_name]);
173       for (const string& input : node.input()) {
174         TensorId id = ParseTensorName(input);
175         if (input_names.count(id) > 0) {
176           continue;
177         }
178         if (used_nodes.insert(id.first).second) {
179           next_nodes.insert(id.first);
180         }
181       }
182     }
183     current_nodes.swap(next_nodes);
184   }
185   for (const TensorId& id : input_names) {
186     used_nodes.insert(id.first);
187   }
188   FilterGraphDef(
189       input_graph_def,
190       [&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; },
191       output_graph_def);
192   TF_RETURN_IF_ERROR(RewriteInputsAsPlaceholders(context, output_graph_def));
193 
194   return OkStatus();
195 }
196 
197 // Converts a shape inference handle to a PartialTensorShape.
ShapeHandleToTensorShape(const shape_inference::ShapeHandle & handle,shape_inference::InferenceContext * context,PartialTensorShape * shape)198 Status ShapeHandleToTensorShape(const shape_inference::ShapeHandle& handle,
199                                 shape_inference::InferenceContext* context,
200                                 PartialTensorShape* shape) {
201   // The default is already unknown.
202   if (!context->RankKnown(handle)) return OkStatus();
203 
204   std::vector<int64_t> dims(context->Rank(handle));
205   for (int32_t i = 0; i < dims.size(); ++i) {
206     dims[i] = context->Value(context->Dim(handle, i));
207   }
208   return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
209 }
210 
211 // Converts any sub-graphs that can be resolved into constant expressions into
212 // single Const ops.
FoldConstants(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)213 Status FoldConstants(const GraphDef& input_graph_def,
214                      const TransformFuncContext& context,
215                      GraphDef* output_graph_def) {
216   Graph input_graph(OpRegistry::Global());
217   TF_RETURN_IF_ERROR(input_graph.AddFunctionLibrary(input_graph_def.library()));
218 
219   ShapeRefiner shape_refiner(input_graph.versions(), input_graph.op_registry());
220   shape_refiner.set_require_shape_inference_fns(false);
221   shape_refiner.set_disable_constant_propagation(false);
222   shape_refiner.set_function_library_for_shape_inference(
223       &input_graph.flib_def());
224 
225   bool clear_output_shapes;
226   TF_RETURN_IF_ERROR(context.GetOneBoolParameter("clear_output_shapes", true,
227                                                  &clear_output_shapes));
228   if (clear_output_shapes) {
229     // Some older GraphDefs have saved _output_shapes attributes which are out
230     // of date and cause import errors, so clean them up first.
231     GraphDef cleaned_graph_def;
232     RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def);
233 
234     TF_RETURN_IF_ERROR(
235         ImportGraphDef({}, cleaned_graph_def, &input_graph, &shape_refiner));
236   } else {
237     TF_RETURN_IF_ERROR(
238         ImportGraphDef({}, input_graph_def, &input_graph, &shape_refiner));
239   }
240 
241   // Sorted array of input names as lookup table.
242   std::vector<TensorId> input_names;
243   input_names.reserve(context.input_names.size());
244   std::transform(context.input_names.begin(), context.input_names.end(),
245                  std::back_inserter(input_names),
246                  [](const string& name) { return ParseTensorName(name); });
247 
248   const auto compare = [](TensorId lhs, TensorId rhs) {
249     return lhs.first < rhs.first;
250   };
251 
252   std::sort(input_names.begin(), input_names.end(), compare);
253 
254   // Set statically inferred shapes.
255   std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
256   for (const Node* const node : input_graph.nodes()) {
257     auto ctx = shape_refiner.GetContext(node);
258     if (ctx == nullptr) {
259       continue;
260     }
261 
262     std::vector<PartialTensorShape>& partial_shapes = shape_map[node->name()];
263     if (ctx->num_outputs() <= 0) continue;
264     partial_shapes.resize(ctx->num_outputs());
265 
266     // Check all outputs.
267     for (const Edge* out_edge : node->out_edges()) {
268       if (out_edge->IsControlEdge()) continue;
269 
270       const int output_idx = out_edge->src_output();
271       TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(ctx->output(output_idx), ctx,
272                                                   &partial_shapes[output_idx]));
273     }
274 
275     // RewriteGraphForExecution() will add a Recv node for each input. Shape
276     // refiner does not include shape information of these Recv nodes. Therefore
277     // we add entries for Recv nodes here.
278     const auto pair = std::equal_range(input_names.begin(), input_names.end(),
279                                        TensorId{node->name(), 0}, compare);
280     for (auto it = pair.first; it != pair.second; ++it) {
281       const string recv_name =
282           strings::StrCat("_recv_", it->first, "_", it->second);
283       auto& recv_partial_shapes = shape_map[recv_name];
284       // For whatever reason (for example, name collision) if the map entry was
285       // already there, then do nothing.
286       if (recv_partial_shapes.empty()) {
287         recv_partial_shapes.push_back(partial_shapes[it->second]);
288       }
289     }
290   }
291 
292   subgraph::RewriteGraphMetadata unused_metadata;
293   TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
294       &input_graph, context.input_names, context.output_names, {}, {},
295       false /* use_function_convention */, &unused_metadata));
296 
297   ConstantFoldingOptions cf_opts;
298   cf_opts.shape_map = &shape_map;
299 
300   // Exclude specified nodes from constant folding.
301   std::set<string> excluded_ops, excluded_nodes;
302   if (context.params.count("exclude_op") > 0) {
303     const auto& ops = context.params.at("exclude_op");
304     excluded_ops = std::set<string>(ops.begin(), ops.end());
305   }
306   if (context.params.count("exclude_node") > 0) {
307     const auto& nodes = context.params.at("exclude_node");
308     excluded_nodes = std::set<string>(nodes.begin(), nodes.end());
309   }
310   if (!excluded_ops.empty() || !excluded_nodes.empty()) {
311     cf_opts.consider = [excluded_ops, excluded_nodes](const Node* n) {
312       return excluded_ops.find(n->op_def().name()) == excluded_ops.end() &&
313              excluded_nodes.find(n->name()) == excluded_nodes.end();
314     };
315   }
316 
317   TF_RETURN_IF_ERROR(context.GetOneInt64Parameter(
318       "max_constant_size_in_bytes", cf_opts.max_constant_size_in_bytes,
319       &cf_opts.max_constant_size_in_bytes));
320 
321   // Constant folding.
322   bool was_mutated;
323   TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr,
324                                   &input_graph, &was_mutated));
325   GraphDef folded_graph_def;
326   input_graph.ToGraphDef(&folded_graph_def);
327   GraphDef send_recvs_replaced;
328   TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def,
329                                       context.input_names, context.output_names,
330                                       &send_recvs_replaced));
331   TF_RETURN_IF_ERROR(
332       RemoveUnusedNodes(send_recvs_replaced, context, output_graph_def));
333   return OkStatus();
334 }
335 
336 REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants);
337 
338 }  // namespace graph_transforms
339 }  // namespace tensorflow
340