• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eval_const_tensor.h"
17 
18 #include <deque>
19 
20 #include "tensorflow/core/common_runtime/graph_runner.h"
21 #include "tensorflow/core/common_runtime/shape_refiner.h"
22 #include "tensorflow/core/framework/bounds_check.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/shape_inference.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/versions.pb.h"
27 #include "tensorflow/core/graph/graph.h"
28 
29 namespace tensorflow {
30 
31 using shape_inference::InferenceContext;
32 
33 namespace {
34 
35 // Tries to infer tensor output based on the input shapes of the node. In some
36 // cases, the shapes of the inputs are sufficient for inferring the contents of
37 // the output tensor. For example, a Shape op with fully defined input shapes
38 // can have its output tensor inferred.
TryToInferTensorOutputFromInputShapes(const Edge & edge,const ShapeRefiner & refiner,Tensor * output,bool * success)39 Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
40                                              const ShapeRefiner& refiner,
41                                              Tensor* output, bool* success) {
42   *success = false;
43   const Node* node = edge.src();
44   InferenceContext* c = refiner.GetContext(node);
45   if (c == nullptr) {
46     // An input without context is a soft failure; we sometimes need to break
47     // control flow loops by running shape inference on a node without first
48     // adding its input.
49     return Status::OK();
50   }
51 
52   if (node->type_string() == "Shape") {
53     // If input shapes to the shape op are fully defined,
54     // we can infer the shape op's output tensor.
55     bool fully_defined_inputs = c->FullyDefined(c->input(0));
56     if (fully_defined_inputs) {
57       int input_rank = c->Rank(c->input(0));
58       Tensor t(node->output_type(0), TensorShape({input_rank}));
59       if (node->output_type(0) == DT_INT32) {
60         auto flat = t.flat<int>();
61         for (int i = 0; i < input_rank; i++) {
62           int64 dimension = c->Value(c->Dim(c->input(0), i));
63           if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
64             return errors::InvalidArgument(
65                 "Shape has output type int32, but dimension exceeds maximum "
66                 "int32 value");
67           }
68           flat(i) = static_cast<int32>(dimension);
69         }
70       } else if (node->output_type(0) == DT_INT64) {
71         auto flat = t.flat<int64>();
72         for (int i = 0; i < input_rank; i++) {
73           flat(i) = c->Value(c->Dim(c->input(0), i));
74         }
75       } else {
76         return errors::FailedPrecondition(
77             "Shape has output type that is not int32 or int64");
78       }
79       *output = t;
80       *success = true;
81     }
82   } else if (node->type_string() == "Rank") {
83     bool rank_known = c->RankKnown(c->input(0));
84     if (rank_known) {
85       int32 input_rank = c->Rank(c->input(0));
86       Tensor t(node->output_type(0), TensorShape({}));
87       t.flat<int32>()(0) = input_rank;
88       *output = t;
89       *success = true;
90     }
91   } else if (node->type_string() == "Size") {
92     bool fully_defined_inputs = c->FullyDefined(c->input(0));
93     if (fully_defined_inputs) {
94       int32 rank = c->Rank(c->input(0));
95       Tensor t(node->output_type(0), TensorShape({}));
96       int64 size = 1;
97       for (int i = 0; i < rank; i++) {
98         size *= c->Value(c->Dim(c->input(0), i));
99       }
100       if (node->output_type(0) == DT_INT32) {
101         if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
102           return errors::InvalidArgument(
103               "Size has output type int32, but size exceeds maximum int32 "
104               "value");
105         }
106         t.flat<int32>()(0) = static_cast<int32>(size);
107       } else if (node->output_type(0) == DT_INT64) {
108         t.flat<int64>()(0) = size;
109       } else {
110         return errors::FailedPrecondition(
111             "Size has output type that is not int32 or int64");
112       }
113       *output = t;
114       *success = true;
115     }
116   }
117   return Status::OK();
118 }
119 
120 // Returns true if 'node' has a registered CPU kernel.
HasCpuKernel(const Node & node)121 bool HasCpuKernel(const Node& node) {
122   return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
123                        /*kernel_class_name=*/nullptr)
124       .ok();
125 }
126 
GetArgNodeIndex(const Node * node,int num_function_inputs,int * index)127 Status GetArgNodeIndex(const Node* node, int num_function_inputs, int* index) {
128   DCHECK(node->IsArg());
129   TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", index));
130   if (*index < 0 || num_function_inputs <= *index) {
131     return errors::Internal(
132         "Function instantiation included invalid input index: ", index,
133         " not in [0, ", num_function_inputs, ").");
134   }
135   return Status::OK();
136 }
137 
138 // Extracts the subgraph ending at 'target_node' that is statically computable
139 // and inserts into 'out_graph'. If statically computable, 'is_constant_graph'
140 // will be set to true.
ExtractConstantSubgraph(const Node & target_node,const ShapeRefiner & refiner,const std::unordered_map<string,Tensor> * cached_values,Graph * out_graph,bool * is_constant_graph,std::vector<std::pair<string,Tensor>> * const_inputs,InferenceContext * outer_context)141 Status ExtractConstantSubgraph(
142     const Node& target_node, const ShapeRefiner& refiner,
143     const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
144     bool* is_constant_graph,
145     std::vector<std::pair<string, Tensor>>* const_inputs,
146     InferenceContext* outer_context) {
147   *is_constant_graph = false;
148   std::unordered_set<string> const_inputs_added;
149 
150   if (target_node.op_def().is_stateful()) {
151     return Status::OK();
152   }
153 
154   if (IsMerge(&target_node)) {
155     return Status::OK();
156   }
157 
158   if (target_node.type_string() == "PlaceholderWithDefault") {
159     return Status::OK();
160   }
161 
162   // Since constant-folding runs on the CPU, do not attempt to constant-fold
163   // operators that have no CPU kernel.
164   if (!HasCpuKernel(target_node)) {
165     return Status::OK();
166   }
167 
168   // TODO(skyewm): should more of the filtering applied in input nodes below be
169   // applied to target_node here?
170 
171   // Identify the possibly constant subgraph by recursively iterating backwards
172   // through the inputs to 'target_node' until we either 1) find an already
173   // existing input to our subgraph 'const_inputs', 2) Discover our graph is not
174   // constant, or 3) Hit a root node.
175 
176   struct NodeAndRecursed {
177     Node* new_node = nullptr;
178     bool recursed = false;
179   };
180 
181   std::map<const Node*, NodeAndRecursed> old_to_new_and_recursed;
182   Node* target_node_copy = out_graph->CopyNode(&target_node);
183   old_to_new_and_recursed[&target_node].new_node = target_node_copy;
184   old_to_new_and_recursed[&target_node].recursed = true;
185 
186   // Add the target node's inputs to seed the recursion.
187   std::deque<const Edge*> edges_to_visit;
188   for (const Edge* e : target_node.in_edges()) {
189     // TODO(skyewm): control edges will be meaningful if/when we handle control
190     // flow (e.g. constants in cond branches are triggered via control edges).
191     if (e->IsControlEdge()) continue;
192     edges_to_visit.push_back(e);
193   }
194 
195   *is_constant_graph = true;
196 
197   // Iterate over the set of edges to visit (backwards).
198   while (!edges_to_visit.empty()) {
199     const Edge* current_edge = edges_to_visit.front();
200     edges_to_visit.pop_front();
201     Node* current_node = current_edge->src();
202 
203     // If the node is stateful, assume the graph is not constant unless it is
204     // an Arg node which is handled later on.
205     if (!current_node->IsArg() && current_node->op_def().is_stateful()) {
206       *is_constant_graph = false;
207       return Status::OK();
208     }
209 
210     // During construction or import from GraphConstructor, back edges may not
211     // be filled in. In addition, control flow constructs may depend on control
212     // edges which aren't handled by this method. Don't constant fold through
213     // merges at all for now.
214     if (IsMerge(current_node)) {
215       *is_constant_graph = false;
216       return Status::OK();
217     }
218 
219     // Don't constant fold enter/exit currently either, as it's easy to end
220     // up with a partial frame.
221     if (IsEnter(current_node) || IsExit(current_node)) {
222       *is_constant_graph = false;
223       return Status::OK();
224     }
225 
226     // Placeholders should never be constant folded because their outputs are
227     // fed by the user. Note that "Placeholder" nodes have no inputs so are
228     // handled below.
229     if (current_node->type_string() == "PlaceholderWithDefault") {
230       *is_constant_graph = false;
231       return Status::OK();
232     }
233 
234     if (!HasCpuKernel(*current_node)) {
235       *is_constant_graph = false;
236       return Status::OK();
237     }
238 
239     // If there is nothing more to recurse down, see if
240     // the generator node is a constant or an Arg node whose value is available
241     // in the `outer_context`.
242     if (current_node->num_inputs() == 0) {
243       if (outer_context && current_node->IsArg()) {
244         const string& tensor_name =
245             strings::StrCat(current_node->name(), ":", 0);
246         // If we do not already have a constant Tensor for this Arg try to
247         // fetch it from the outer context.
248         if (const_inputs_added.count(tensor_name) == 0) {
249           int index;
250           TF_RETURN_IF_ERROR(GetArgNodeIndex(
251               current_node, outer_context->num_inputs(), &index));
252           const Tensor* const_tensor = outer_context->input_tensor(index);
253           if (const_tensor) {
254             const_inputs->emplace_back(tensor_name, *const_tensor);
255             const_inputs_added.insert(tensor_name);
256           } else {
257             // Request a constant value for this Arg. If that is statically
258             // computable, shape refiner will re-run the shape inference for
259             // this function with this tensor's value.
260             outer_context->request_input_tensor(index);
261             *is_constant_graph = false;
262             return Status::OK();
263           }
264         }
265       } else if (!current_node->IsConstant()) {
266         // Generator node is not a constant, so subgraph is not
267         // constant.
268         *is_constant_graph = false;
269         return Status::OK();
270       }
271     }
272 
273     // Either the node is a constant, or the node is a potential
274     // intermediate node on the path from a constant.
275     //
276     // Add a copy of its node and a new edge to the new subgraph.
277 
278     // Get or create the version of 'current_node' in the new graph.
279     Node* current_node_copy;
280     // This gets or creates the NodeAndRecursed entry for current_node.
281     NodeAndRecursed* node_and_recursed = &old_to_new_and_recursed[current_node];
282     if (node_and_recursed->new_node == nullptr) {
283       // First time processing this node.
284       current_node_copy = out_graph->CopyNode(current_node);
285       // Track the mapping from the original node to the new one.
286       node_and_recursed->new_node = current_node_copy;
287     } else {
288       current_node_copy = node_and_recursed->new_node;
289     }
290 
291     // Add the edge to the destination node.
292     {
293       auto it = old_to_new_and_recursed.find(current_edge->dst());
294       if (it == old_to_new_and_recursed.end()) {
295         return errors::Internal(
296             "Could not find mapping from old to new copy of destination node: ",
297             current_edge->dst()->name());
298       }
299       Node* dst_copy = it->second.new_node;
300 
301       out_graph->AddEdge(current_node_copy, current_edge->src_output(),
302                          dst_copy, current_edge->dst_input());
303     }
304 
305     const string& output_tensor_name =
306         strings::StrCat(current_node->name(), ":", current_edge->src_output());
307 
308     // Some tensor values can be inferred. For example, a shape op
309     // with input shapes fully defined can have its output tensor inferred.
310     Tensor tensor_inferred;
311     bool successfully_inferred_tensor = false;
312     TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
313         *current_edge, refiner, &tensor_inferred,
314         &successfully_inferred_tensor));
315     if (successfully_inferred_tensor) {
316       const_inputs->emplace_back(output_tensor_name, tensor_inferred);
317       const_inputs_added.insert(output_tensor_name);
318       continue;
319     }
320 
321     // If we have a copy of the input tensor materialized already,
322     // then add to the list of inputs to feed and do not recurse further.
323     if (cached_values != nullptr) {
324       auto it = cached_values->find(output_tensor_name);
325       if (it != cached_values->end() &&
326           const_inputs_added.count(output_tensor_name) == 0) {
327         const_inputs->emplace_back(output_tensor_name, it->second);
328         const_inputs_added.insert(output_tensor_name);
329         continue;
330       }
331     }
332 
333     // If this node's inputs have not been processed already, do so now.
334     if (!node_and_recursed->recursed) {
335       node_and_recursed->recursed = true;
336       for (const Edge* e : current_node->in_edges()) {
337         if (e->IsControlEdge()) continue;
338         edges_to_visit.push_back(e);
339       }
340     }
341   }
342 
343   return Status::OK();
344 }
345 
346 }  // namespace
347 
EvaluateConstantTensor(OutputTensor tensor,const ShapeRefiner & refiner,const OpRegistryInterface & ops,int32 graph_def_version,bool * evaluated,Tensor * result,GraphRunner * graph_runner,std::unordered_map<string,Tensor> * cached_values,int64 max_cached_value_size,bool disable_constant_propagation,InferenceContext * outer_context)348 Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
349                               const OpRegistryInterface& ops,
350                               int32 graph_def_version, bool* evaluated,
351                               Tensor* result, GraphRunner* graph_runner,
352                               std::unordered_map<string, Tensor>* cached_values,
353                               int64 max_cached_value_size,
354                               bool disable_constant_propagation,
355                               InferenceContext* outer_context) {
356   *evaluated = false;
357   const Node* src = tensor.node;
358 
359   // Simple case: the source node is a constant
360   if (src->IsConstant()) {
361     if (result->FromProto(src->def().attr().at("value").tensor())) {
362       *evaluated = true;
363       return Status::OK();
364     }
365   }
366 
367   // If the source node is an Arg return its value, if available in the outer
368   // context.
369   if (src->IsArg() && outer_context) {
370     int index;
371     TF_RETURN_IF_ERROR(
372         GetArgNodeIndex(src, outer_context->num_inputs(), &index));
373     const Tensor* const_tensor = outer_context->input_tensor(index);
374     if (const_tensor) {
375       *evaluated = true;
376       *result = *(outer_context->input_tensor(index));
377     } else {
378       outer_context->request_input_tensor(index);
379     }
380     return Status::OK();
381   }
382 
383   if (disable_constant_propagation) {
384     return Status::OK();
385   }
386 
387   bool is_constant_graph = false;
388   Graph subgraph(&ops);
389   auto versions = subgraph.versions();
390   versions.set_producer(graph_def_version);
391   subgraph.set_versions(versions);
392 
393   std::vector<std::pair<string, Tensor>> const_inputs;
394   TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
395                                              &subgraph, &is_constant_graph,
396                                              &const_inputs, outer_context));
397   if (!is_constant_graph) {
398     return Status::OK();
399   }
400   const string output_tensor_name =
401       strings::StrCat(src->name(), ":", tensor.index);
402   std::vector<Tensor> outputs;
403 
404   std::unique_ptr<GraphRunner> graph_runner_storage;
405   if (graph_runner == nullptr) {
406     // TODO(skyewm): Convert to std::make_unique when available.
407     graph_runner_storage.reset(new GraphRunner(Env::Default()));
408     graph_runner = graph_runner_storage.get();
409   }
410 
411   // NOTE; we should pass in a function library runtime if we want
412   // to support constant-expression evaluation on functions.
413   Status s = graph_runner->Run(&subgraph, nullptr /* function_library */,
414                                const_inputs, {output_tensor_name}, &outputs);
415 
416   // If all kernels in the constant graph are not registered
417   // in the process, GraphRunner::Run may fail, in which case
418   // we cannot propagate constants, so this is best-effort.
419   if (s.ok()) {
420     *result = outputs[0];
421     *evaluated = true;
422 
423     // We memoize (small) constants evaluated so far, so
424     // ExtractConstantSubgraph can avoid extracting the full
425     // subgraph.  As we build up large graphs, this avoids
426     // repeated computation of the early parts of a constant
427     // graph.
428     if (cached_values != nullptr &&
429         outputs[0].TotalBytes() <= max_cached_value_size) {
430       (*cached_values)[output_tensor_name] = outputs[0];
431     }
432   }
433   return Status::OK();
434 }
435 
436 }  // namespace tensorflow
437