/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eval_const_tensor.h"

#include <deque>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

using shape_inference::InferenceContext;

namespace {

// Tries to infer tensor output based on the input shapes of the node. In some
// cases, the shapes of the inputs are sufficient for inferring the contents of
// the output tensor. For example, a Shape op with fully defined input shapes
// can have its output tensor inferred.
Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
                                             const ShapeRefiner& refiner,
                                             Tensor* output, bool* success) {
  *success = false;
  const Node* node = edge.src();
  InferenceContext* c = refiner.GetContext(node);
  if (c == nullptr) {
    // An input without context is a soft failure; we sometimes need to break
    // control flow loops by running shape inference on a node without first
    // adding its input.
    return Status::OK();
  }

  if (node->type_string() == "Shape") {
    // If input shapes to the shape op are fully defined,
    // we can infer the shape op's output tensor.
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({input_rank}));
      if (node->output_type(0) == DT_INT32) {
        auto flat = t.flat<int>();
        for (int i = 0; i < input_rank; i++) {
          int64 dimension = c->Value(c->Dim(c->input(0), i));
          if (!FastBoundsCheck(dimension, std::numeric_limits<int32>::max())) {
            return errors::InvalidArgument(
                "Shape has output type int32, but dimension exceeds maximum "
                "int32 value");
          }
          flat(i) = static_cast<int32>(dimension);
        }
      } else if (node->output_type(0) == DT_INT64) {
        auto flat = t.flat<int64>();
        for (int i = 0; i < input_rank; i++) {
          flat(i) = c->Value(c->Dim(c->input(0), i));
        }
      } else {
        return errors::FailedPrecondition(
            "Shape has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Rank") {
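    // Unlike Shape, Rank only requires the input's rank to be known;
    // individual dimensions may still be unknown.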
    bool rank_known = c->RankKnown(c->input(0));
    if (rank_known) {
      int32 input_rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      t.flat<int32>()(0) = input_rank;
      *output = t;
      *success = true;
    }
  } else if (node->type_string() == "Size") {
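    // Size is the product of all the input's dimensions, so every dimension
    // must be fully defined.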
    bool fully_defined_inputs = c->FullyDefined(c->input(0));
    if (fully_defined_inputs) {
      int32 rank = c->Rank(c->input(0));
      Tensor t(node->output_type(0), TensorShape({}));
      int64 size = 1;
      for (int i = 0; i < rank; i++) {
        size *= c->Value(c->Dim(c->input(0), i));
      }
      if (node->output_type(0) == DT_INT32) {
        if (!FastBoundsCheck(size, std::numeric_limits<int32>::max())) {
          return errors::InvalidArgument(
              "Size has output type int32, but size exceeds maximum int32 "
              "value");
        }
        t.flat<int32>()(0) = static_cast<int32>(size);
      } else if (node->output_type(0) == DT_INT64) {
        t.flat<int64>()(0) = size;
      } else {
        return errors::FailedPrecondition(
            "Size has output type that is not int32 or int64");
      }
      *output = t;
      *success = true;
    }
  }
  return Status::OK();
}

// Returns true if 'node' has a registered CPU kernel.
bool HasCpuKernel(const Node& node) {
  return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
                       /*kernel_class_name=*/nullptr)
      .ok();
}

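// Extracts the "index" attr of an Arg node into 'index' and validates that it
// lies in [0, num_function_inputs).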
Status GetArgNodeIndex(const Node* node, int num_function_inputs, int* index) {
  DCHECK(node->IsArg());
  TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(node->def()), "index", index));
  if (*index < 0 || num_function_inputs <= *index) {
    return errors::Internal(
        "Function instantiation included invalid input index: ", *index,
        " not in [0, ", num_function_inputs, ").");
  }
  return Status::OK();
}

// Extracts the subgraph ending at 'target_node' that is statically computable
// and inserts it into 'out_graph'. If the subgraph is statically computable,
// 'is_constant_graph' will be set to true.
Status ExtractConstantSubgraph(
    const Node& target_node, const ShapeRefiner& refiner,
    const std::unordered_map<string, Tensor>* cached_values, Graph* out_graph,
    bool* is_constant_graph,
    std::vector<std::pair<string, Tensor>>* const_inputs,
    InferenceContext* outer_context) {
  *is_constant_graph = false;
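  // Tensor names that have already been added to 'const_inputs', used to
  // avoid feeding the same tensor more than once.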
  std::unordered_set<string> const_inputs_added;

  if (target_node.op_def().is_stateful()) {
    return Status::OK();
  }

  if (IsMerge(&target_node)) {
    return Status::OK();
  }

  if (target_node.type_string() == "PlaceholderWithDefault") {
    return Status::OK();
  }

  // Since constant-folding runs on the CPU, do not attempt to constant-fold
  // operators that have no CPU kernel.
  if (!HasCpuKernel(target_node)) {
    return Status::OK();
  }

  // TODO(skyewm): should more of the filtering applied to the input nodes
  // below be applied to target_node here?

  // Identify the possibly constant subgraph by recursively iterating backwards
  // through the inputs to 'target_node' until we either 1) find an input that
  // is already in 'const_inputs', 2) discover that the graph is not constant,
  // or 3) hit a root node.

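  // For each node visited so far, tracks its copy in 'out_graph' and whether
  // its input edges have already been enqueued for traversal.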
  struct NodeAndRecursed {
    Node* new_node = nullptr;
    bool recursed = false;
  };

  std::map<const Node*, NodeAndRecursed> old_to_new_and_recursed;
  Node* target_node_copy = out_graph->CopyNode(&target_node);
  old_to_new_and_recursed[&target_node].new_node = target_node_copy;
  old_to_new_and_recursed[&target_node].recursed = true;

  // Add the target node's inputs to seed the recursion.
  std::deque<const Edge*> edges_to_visit;
  for (const Edge* e : target_node.in_edges()) {
    // TODO(skyewm): control edges will be meaningful if/when we handle control
    // flow (e.g. constants in cond branches are triggered via control edges).
    if (e->IsControlEdge()) continue;
    edges_to_visit.push_back(e);
  }

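  // Assume the subgraph is constant until one of the early returns below
  // proves otherwise.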
  *is_constant_graph = true;

  // Iterate over the set of edges to visit (backwards).
  while (!edges_to_visit.empty()) {
    const Edge* current_edge = edges_to_visit.front();
    edges_to_visit.pop_front();
    Node* current_node = current_edge->src();

    // If the node is stateful, assume the graph is not constant unless it is
    // an Arg node, which is handled later on.
    if (!current_node->IsArg() && current_node->op_def().is_stateful()) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // During construction or import from GraphConstructor, back edges may not
    // be filled in. In addition, control flow constructs may depend on control
    // edges which aren't handled by this method. Don't constant fold through
    // merges at all for now.
    if (IsMerge(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Don't constant fold enter/exit currently either, as it's easy to end
    // up with a partial frame.
    if (IsEnter(current_node) || IsExit(current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // Placeholders should never be constant folded because their outputs are
    // fed by the user. Note that "Placeholder" nodes have no inputs and so are
    // handled below.
    if (current_node->type_string() == "PlaceholderWithDefault") {
      *is_constant_graph = false;
      return Status::OK();
    }

    if (!HasCpuKernel(*current_node)) {
      *is_constant_graph = false;
      return Status::OK();
    }

    // If there is nothing more to recurse down, see if the generator node is a
    // constant or an Arg node whose value is available in the `outer_context`.
    if (current_node->num_inputs() == 0) {
      if (outer_context && current_node->IsArg()) {
        const string& tensor_name =
            strings::StrCat(current_node->name(), ":", 0);
        // If we do not already have a constant Tensor for this Arg, try to
        // fetch it from the outer context.
        if (const_inputs_added.count(tensor_name) == 0) {
          int index;
          TF_RETURN_IF_ERROR(GetArgNodeIndex(
              current_node, outer_context->num_inputs(), &index));
          const Tensor* const_tensor = outer_context->input_tensor(index);
          if (const_tensor) {
            const_inputs->emplace_back(tensor_name, *const_tensor);
            const_inputs_added.insert(tensor_name);
          } else {
            // Request a constant value for this Arg. If it is statically
            // computable, the shape refiner will re-run shape inference for
            // this function with this tensor's value.
            outer_context->request_input_tensor(index);
            *is_constant_graph = false;
            return Status::OK();
          }
        }
      } else if (!current_node->IsConstant()) {
        // The generator node is not a constant, so the subgraph is not
        // constant.
        *is_constant_graph = false;
        return Status::OK();
      }
    }

    // Either the node is a constant, or the node is a potential
    // intermediate node on the path from a constant.
    //
    // Add a copy of its node and a new edge to the new subgraph.

    // Get or create the version of 'current_node' in the new graph.
    Node* current_node_copy;
    // This gets or creates the NodeAndRecursed entry for current_node.
    NodeAndRecursed* node_and_recursed =
        &old_to_new_and_recursed[current_node];
    if (node_and_recursed->new_node == nullptr) {
      // First time processing this node.
      current_node_copy = out_graph->CopyNode(current_node);
      // Track the mapping from the original node to the new one.
      node_and_recursed->new_node = current_node_copy;
    } else {
      current_node_copy = node_and_recursed->new_node;
    }

    // Add the edge to the destination node.
    {
      auto it = old_to_new_and_recursed.find(current_edge->dst());
      if (it == old_to_new_and_recursed.end()) {
        return errors::Internal(
            "Could not find mapping from old to new copy of destination node: ",
            current_edge->dst()->name());
      }
      Node* dst_copy = it->second.new_node;

      out_graph->AddEdge(current_node_copy, current_edge->src_output(),
                         dst_copy, current_edge->dst_input());
    }

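    // Name of the tensor produced by this edge, used as the key for cached
    // values and constant feeds.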
    const string& output_tensor_name =
        strings::StrCat(current_node->name(), ":", current_edge->src_output());

    // Some tensor values can be inferred. For example, a shape op
    // with input shapes fully defined can have its output tensor inferred.
    Tensor tensor_inferred;
    bool successfully_inferred_tensor = false;
    TF_RETURN_IF_ERROR(TryToInferTensorOutputFromInputShapes(
        *current_edge, refiner, &tensor_inferred,
        &successfully_inferred_tensor));
    if (successfully_inferred_tensor) {
      const_inputs->emplace_back(output_tensor_name, tensor_inferred);
      const_inputs_added.insert(output_tensor_name);
      continue;
    }

    // If we have a copy of the input tensor materialized already,
    // then add it to the list of inputs to feed and do not recurse further.
    if (cached_values != nullptr) {
      auto it = cached_values->find(output_tensor_name);
      if (it != cached_values->end() &&
          const_inputs_added.count(output_tensor_name) == 0) {
        const_inputs->emplace_back(output_tensor_name, it->second);
        const_inputs_added.insert(output_tensor_name);
        continue;
      }
    }

    // If this node's inputs have not been processed already, do so now.
    if (!node_and_recursed->recursed) {
      node_and_recursed->recursed = true;
      for (const Edge* e : current_node->in_edges()) {
        if (e->IsControlEdge()) continue;
        edges_to_visit.push_back(e);
      }
    }
  }

  return Status::OK();
}

}  // namespace

Status EvaluateConstantTensor(OutputTensor tensor, const ShapeRefiner& refiner,
                              const OpRegistryInterface& ops,
                              int32 graph_def_version, bool* evaluated,
                              Tensor* result, GraphRunner* graph_runner,
                              std::unordered_map<string, Tensor>* cached_values,
                              int64 max_cached_value_size,
                              bool disable_constant_propagation,
                              InferenceContext* outer_context) {
  *evaluated = false;
  const Node* src = tensor.node;

  // Simple case: the source node is a constant.
  if (src->IsConstant()) {
    if (result->FromProto(src->def().attr().at("value").tensor())) {
      *evaluated = true;
      return Status::OK();
    }
  }

  // If the source node is an Arg, return its value if it is available in the
  // outer context.
  if (src->IsArg() && outer_context) {
    int index;
    TF_RETURN_IF_ERROR(
        GetArgNodeIndex(src, outer_context->num_inputs(), &index));
    const Tensor* const_tensor = outer_context->input_tensor(index);
    if (const_tensor) {
      *evaluated = true;
      *result = *const_tensor;
    } else {
      outer_context->request_input_tensor(index);
    }
    return Status::OK();
  }

  if (disable_constant_propagation) {
    return Status::OK();
  }

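  // Give the subgraph the same producer version as the original graph so that
  // its ops are interpreted with the same GraphDef semantics.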
  bool is_constant_graph = false;
  Graph subgraph(&ops);
  auto versions = subgraph.versions();
  versions.set_producer(graph_def_version);
  subgraph.set_versions(versions);

  std::vector<std::pair<string, Tensor>> const_inputs;
  TF_RETURN_IF_ERROR(ExtractConstantSubgraph(*src, refiner, cached_values,
                                             &subgraph, &is_constant_graph,
                                             &const_inputs, outer_context));
  if (!is_constant_graph) {
    return Status::OK();
  }
  const string output_tensor_name =
      strings::StrCat(src->name(), ":", tensor.index);
  std::vector<Tensor> outputs;

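  // If the caller did not supply a GraphRunner, create a temporary one that
  // lives only for the duration of this evaluation.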
  std::unique_ptr<GraphRunner> graph_runner_storage;
  if (graph_runner == nullptr) {
    // TODO(skyewm): Convert to std::make_unique when available.
    graph_runner_storage.reset(new GraphRunner(Env::Default()));
    graph_runner = graph_runner_storage.get();
  }

  // NOTE: we should pass in a function library runtime if we want
  // to support constant-expression evaluation on functions.
  Status s = graph_runner->Run(&subgraph, nullptr /* function_library */,
                               const_inputs, {output_tensor_name}, &outputs);

  // If not all kernels in the constant graph are registered
  // in the process, GraphRunner::Run may fail, in which case
  // we cannot propagate constants, so this is best-effort.
  if (s.ok()) {
    *result = outputs[0];
    *evaluated = true;

    // We memoize (small) constants evaluated so far, so
    // ExtractConstantSubgraph can avoid extracting the full
    // subgraph. As we build up large graphs, this avoids
    // repeated computation of the early parts of a constant
    // graph.
    if (cached_values != nullptr &&
        outputs[0].TotalBytes() <= max_cached_value_size) {
      (*cached_values)[output_tensor_name] = outputs[0];
    }
  }
  return Status::OK();
}

}  // namespace tensorflow