1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/const_analysis.h"
17
18 #include <unordered_map>
19 #include <unordered_set>
20
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/lib/core/errors.h"
30
31 namespace tensorflow {
32
33 Status GetCompileTimeConstInputs(const Node* node,
34 std::vector<int>* const_input_idxs,
35 FunctionLibraryRuntime* flib_runtime);
36
37 // Backwards dataflow analysis that finds arguments to a graph that must be
38 // compile-time constants.
BackwardsConstAnalysis(const Graph & g,std::vector<bool> * compile_time_const_arg_indices,std::vector<bool> * compile_time_const_nodes,FunctionLibraryRuntime * flib_runtime,std::function<bool (const Edge &)> edge_filter)39 Status BackwardsConstAnalysis(const Graph& g,
40 std::vector<bool>* compile_time_const_arg_indices,
41 std::vector<bool>* compile_time_const_nodes,
42 FunctionLibraryRuntime* flib_runtime,
43 std::function<bool(const Edge&)> edge_filter) {
44 std::vector<bool> compile_time_const_nodes_impl;
45 if (compile_time_const_nodes) {
46 CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
47 } else {
48 compile_time_const_nodes_impl.resize(g.num_node_ids());
49 compile_time_const_nodes = &compile_time_const_nodes_impl;
50 }
51
52 Status status;
53 auto visit = [&](Node* node) {
54 if (!status.ok()) return;
55
56 // If this is a metadata-only op, don't propagate the const requirement.
57 if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
58 return;
59 }
60
61 // If this node must be const, and it isn't a metadata op, then all of its
62 // parents must be const.
63 if ((*compile_time_const_nodes)[node->id()]) {
64 if (node->type_string() == "_Arg") {
65 int index;
66 status = GetNodeAttr(node->attrs(), "index", &index);
67 if (!status.ok()) return;
68 if (compile_time_const_arg_indices) {
69 (*compile_time_const_arg_indices)[index] = true;
70 }
71 return;
72 }
73 for (const Edge* pred : node->in_edges()) {
74 if (!pred->IsControlEdge() && edge_filter(*pred)) {
75 // If the src node of the `pred` is an IdentityN do not mark it as a
76 // compile-time const. Only mark the corresponding input to the
77 // IdentityN node as a const.
78 // Note: XLA IdentityN op simply forwards its inputs so this is safe.
79 while (edge_filter(*pred) &&
80 pred->src()->type_string() == "IdentityN") {
81 status = pred->src()->input_edge(pred->src_output(), &pred);
82 if (!status.ok()) return;
83 }
84 if (edge_filter(*pred)) {
85 (*compile_time_const_nodes)[pred->src()->id()] = true;
86 }
87 }
88 }
89 return;
90 }
91
92 // Mark any compile-time constant operator arguments as const.
93 std::vector<int> const_input_idxs;
94 status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime);
95
96 if (!status.ok()) {
97 return;
98 }
99
100 for (Edge const* edge : node->in_edges()) {
101 if (!edge->IsControlEdge() &&
102 absl::c_binary_search(const_input_idxs, edge->dst_input()) &&
103 edge_filter(*edge)) {
104 // Do not mark IdentityN nodes as compile-time const.
105 // If the src node of the `pred` is an IdentityN do not mark it as a
106 // compile-time const. Only mark the corresponding input to the
107 // IdentityN node as a const.
108 // Note: XLA IdentityN op simply forwards its inputs so this is safe.
109 while (edge_filter(*edge) &&
110 edge->src()->type_string() == "IdentityN") {
111 status = edge->src()->input_edge(edge->src_output(), &edge);
112 if (!status.ok()) return;
113 }
114 if (edge_filter(*edge)) {
115 (*compile_time_const_nodes)[edge->src()->id()] = true;
116 }
117 }
118 }
119 };
120
121 // Post-order traversal visits nodes in reverse topological order for an
122 // acyclic graph.
123 DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
124 [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
125 return status;
126 }
127
GetCompileTimeConstInputs(const Node * node,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)128 Status GetCompileTimeConstInputs(const Node* node,
129 std::vector<int>* const_input_idxs,
130 FunctionLibraryRuntime* flib_runtime) {
131 if (node->type_string() != "While") {
132 return XlaOpRegistry::CompileTimeConstantInputs(node->def(), node->op_def(),
133 const_input_idxs);
134 }
135 // For While nodes, recurse into the body and cond graphs.
136 // TODO(b/124403063): Implement similar functionality for cond nodes and other
137 // functional ops.
138 NameAttrList cond_function;
139 TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "cond", &cond_function));
140 NameAttrList body_function;
141 TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "body", &body_function));
142 FunctionLibraryRuntime::Handle cond_handle;
143 FunctionLibraryRuntime::Handle body_handle;
144 TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
145 cond_function.name(), AttrSlice(&cond_function.attr()), &cond_handle));
146 TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
147 body_function.name(), AttrSlice(&body_function.attr()), &body_handle));
148 const FunctionBody* fcond = flib_runtime->GetFunctionBody(cond_handle);
149 const FunctionBody* fbody = flib_runtime->GetFunctionBody(body_handle);
150 TF_RET_CHECK(fcond);
151 TF_RET_CHECK(fbody);
152 int num_inputs = fbody->fdef.signature().input_arg_size();
153
154 // Stores which of the loop inputs are expected to be compile time constants.
155 std::vector<bool> compile_time_const_arg_indices(num_inputs);
156 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
157 *(fcond->graph), &compile_time_const_arg_indices,
158 /*compile_time_const_nodes=*/nullptr, flib_runtime));
159 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
160 *(fbody->graph), &compile_time_const_arg_indices,
161 /*compile_time_const_nodes=*/nullptr, flib_runtime));
162 for (int i = 0; i < num_inputs; i++) {
163 if (compile_time_const_arg_indices[i]) {
164 // Check that this input is actually a loop invariant.
165 // NOTE(srbs): Ideally this should raise an error if the loop body
166 // requires the input at this index to be a compile time const but it is
167 // not a loop invariant. However, that causes problems because const
168 // analysis is performed for the entire graph (in the
169 // MarkForCompilationPass for example) and not just for the ops
170 // that will actually be run using XLA kernels. So we silently return here
171 // and let the error be raised during the actual compilation of the
172 // XLA graph.
173 Node* arg_i = fbody->arg_nodes[i];
174 Node* ret_i = fbody->ret_nodes[i];
175 const Node* ret_i_input_0;
176 TF_RETURN_IF_ERROR(ret_i->input_node(0, &ret_i_input_0));
177 if (ret_i_input_0->id() == arg_i->id()) {
178 const_input_idxs->push_back(i);
179 }
180 }
181 }
182 return Status::OK();
183 }
184
185 } // namespace tensorflow
186