1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/const_analysis.h"
17
18 #include <unordered_map>
19 #include <unordered_set>
20
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/lib/core/errors.h"
30
31 namespace tensorflow {
32
33 namespace {
34
GetFunctionBody(FunctionLibraryRuntime * flib_runtime,const NodeDef & node,StringPiece func_attr_name,const FunctionBody ** fbody)35 Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime,
36 const NodeDef& node, StringPiece func_attr_name,
37 const FunctionBody** fbody) {
38 NameAttrList name_attr_list;
39 TF_RETURN_IF_ERROR(GetNodeAttr(node, func_attr_name, &name_attr_list));
40 FunctionLibraryRuntime::Handle func_handle;
41 TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
42 name_attr_list.name(), AttrSlice(&name_attr_list.attr()), &func_handle));
43 *fbody = flib_runtime->GetFunctionBody(func_handle);
44 return Status::OK();
45 }
46
GetFunctionBodies(FunctionLibraryRuntime * flib_runtime,const NodeDef & node,StringPiece func_list_attr_name,std::vector<const FunctionBody * > * fbodies)47 Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime,
48 const NodeDef& node, StringPiece func_list_attr_name,
49 std::vector<const FunctionBody*>* fbodies) {
50 std::vector<NameAttrList> name_attr_lists;
51 TF_RETURN_IF_ERROR(GetNodeAttr(node, func_list_attr_name, &name_attr_lists));
52 for (const NameAttrList& name_attr_list : name_attr_lists) {
53 FunctionLibraryRuntime::Handle func_handle;
54 TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
55 name_attr_list.name(), AttrSlice(&name_attr_list.attr()),
56 &func_handle));
57 fbodies->push_back(flib_runtime->GetFunctionBody(func_handle));
58 }
59 return Status::OK();
60 }
61
CondConstInputIndices(absl::Span<const FunctionBody * const> branch_bodies,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)62 Status CondConstInputIndices(
63 absl::Span<const FunctionBody* const> branch_bodies,
64 std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime) {
65 TF_RET_CHECK(!branch_bodies.empty());
66 TF_RET_CHECK(branch_bodies[0] != nullptr);
67 int num_inputs = branch_bodies[0]->fdef.signature().input_arg_size();
68 // Stores indices of the "branch function" inputs that are expected to be
69 // compile time constants.
70 std::vector<bool> compile_time_const_arg_indices(num_inputs);
71 for (auto fbody : branch_bodies) {
72 TF_RET_CHECK(fbody != nullptr);
73 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
74 *(fbody->graph), &compile_time_const_arg_indices,
75 /*compile_time_const_nodes=*/nullptr, flib_runtime));
76 }
77 for (int i = 0, end = compile_time_const_arg_indices.size(); i < end; i++) {
78 if (compile_time_const_arg_indices[i]) {
79 // The 0th input is the pred or branch index, which is not passed to the
80 // branches. So the i'th input of a branch function corresponds to the
81 // i + 1'th input of the If/Case op.
82 const_input_idxs->push_back(i + 1);
83 }
84 }
85 return Status::OK();
86 }
87
// Computes the indices of `node`'s inputs that must be compile-time constants
// for XLA compilation, appending them to `const_input_idxs`. Functional
// control-flow ops (While, If, Case, PartitionedCall) are handled by recursing
// into their function bodies; all other ops are answered from the XLA op
// registry. Exactly one of `op_kernel` / `op_def` may be null; the non-null
// one is used for the registry lookup.
Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
                                 const OpDef* op_def,
                                 std::vector<int>* const_input_idxs,
                                 FunctionLibraryRuntime* flib_runtime) {
  DCHECK(op_def != nullptr || op_kernel != nullptr);
  if (node.op() == "While" || node.op() == "StatelessWhile") {
    // For While nodes, recurse into the body and cond graphs.
    const FunctionBody* fcond = nullptr;
    const FunctionBody* fbody = nullptr;
    TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "cond", &fcond));
    TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "body", &fbody));
    TF_RET_CHECK(fcond);
    TF_RET_CHECK(fbody);
    int num_inputs = fbody->fdef.signature().input_arg_size();

    // Stores which of the loop inputs are expected to be compile time
    // constants. Both analyses below OR into this vector, so an argument is
    // marked if either the condition or the body needs it constant.
    std::vector<bool> compile_time_const_arg_indices(num_inputs);
    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
        *(fcond->graph), &compile_time_const_arg_indices,
        /*compile_time_const_nodes=*/nullptr, flib_runtime));
    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
        *(fbody->graph), &compile_time_const_arg_indices,
        /*compile_time_const_nodes=*/nullptr, flib_runtime));
    for (int i = 0; i < num_inputs; i++) {
      if (compile_time_const_arg_indices[i]) {
        // Check that this input is actually a loop invariant: the body must
        // return argument i unchanged (ret_i's 0th input is arg_i itself).
        // Only invariant constants can be resolved before compiling the loop.
        Node* arg_i = fbody->arg_nodes[i];
        Node* ret_i = fbody->ret_nodes[i];
        const Node* ret_i_input_0;
        TF_RETURN_IF_ERROR(ret_i->input_node(0, &ret_i_input_0));
        if (ret_i_input_0->id() == arg_i->id()) {
          const_input_idxs->push_back(i);
        } else {
          // TODO(b/178546817): Verify that it's OK and raise an error if we are
          // using this branch from jit_compile=True.
          VLOG(1) << "Argument " << i << " to while-loop "
                  << node.ShortDebugString()
                  << " has to be constant, but it's not a loop invariant, "
                     "cluster compilation likely to fail at compile time: "
                  << arg_i->def().ShortDebugString() << " vs. "
                  << ret_i->def().ShortDebugString();
        }
      }
    }
    return Status::OK();
  } else if (node.op() == "If" || node.op() == "StatelessIf") {
    // For If nodes, union the constant requirements of both branches.
    const FunctionBody* fthen = nullptr;
    const FunctionBody* felse = nullptr;
    TF_RETURN_IF_ERROR(
        GetFunctionBody(flib_runtime, node, "then_branch", &fthen));
    TF_RETURN_IF_ERROR(
        GetFunctionBody(flib_runtime, node, "else_branch", &felse));
    return CondConstInputIndices({fthen, felse}, const_input_idxs,
                                 flib_runtime);
  } else if (node.op() == "Case" || node.op() == "StatelessCase") {
    // For Case nodes, union the constant requirements of all branches.
    std::vector<const FunctionBody*> branch_bodies;
    TF_RETURN_IF_ERROR(
        GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
    return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
  } else if (node.op() == "PartitionedCall" ||
             node.op() == "StatefulPartitionedCall") {
    // For call nodes, recurse into the callee; call inputs map 1:1 to the
    // callee's arguments (no extra predicate input as with If/Case).
    const FunctionBody* fbody;
    TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
    int num_inputs = fbody->fdef.signature().input_arg_size();
    std::vector<bool> compile_time_const_arg_indices(num_inputs);
    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
        *(fbody->graph), &compile_time_const_arg_indices,
        /*compile_time_const_nodes=*/nullptr, flib_runtime));
    for (int i = 0; i < num_inputs; i++) {
      if (compile_time_const_arg_indices[i]) {
        const_input_idxs->push_back(i);
      }
    }
    return Status::OK();
  } else if (op_def != nullptr) {
    // Non-functional op with a known OpDef: consult the XLA op registry.
    return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def,
                                                    const_input_idxs);
  } else {
    // Fall back to the registry lookup keyed on the instantiated kernel.
    return XlaOpRegistry::CompileTimeConstantInputs(*op_kernel,
                                                    const_input_idxs);
  }
}
171
GetCompileTimeConstInputs(const Node * node,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)172 Status GetCompileTimeConstInputs(const Node* node,
173 std::vector<int>* const_input_idxs,
174 FunctionLibraryRuntime* flib_runtime) {
175 return GetCompileTimeConstInputs(node->def(), /*op_kernel=*/nullptr,
176 &node->op_def(), const_input_idxs,
177 flib_runtime);
178 }
179
180 } // namespace
181
// Backwards dataflow analysis that finds arguments to a graph that must be
// compile-time constants.
//
// Traverses `g` in reverse topological order, seeding constness from the
// per-op compile-time-constant input requirements (GetCompileTimeConstInputs)
// and propagating it backwards through data edges. On return:
//   - `compile_time_const_arg_indices` (if non-null) has `true` at index i
//     when the _Arg with attr index==i must be a compile-time constant;
//     results are ORed into the existing contents.
//   - `compile_time_const_nodes` (if non-null, sized g.num_node_ids()) has
//     `true` for every node whose output must be constant.
//   - `edge_filter_input` (if set) restricts propagation to edges it accepts.
// When neither `compile_time_const_nodes` nor an edge filter is supplied, the
// per-graph argument-index cache is consulted and populated.
Status BackwardsConstAnalysis(
    const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
    std::vector<bool>* compile_time_const_nodes,
    FunctionLibraryRuntime* flib_runtime,
    std::function<bool(const Edge&)> edge_filter_input) {
  // Fast path: the cache is only valid for the default query (no node-level
  // output requested, no custom edge filter).
  if (!compile_time_const_nodes && g.GetConstArgIndicesCache().has_value() &&
      !edge_filter_input) {
    VLOG(5) << "Using cached argument indices on graph " << &g;
    *compile_time_const_arg_indices = g.GetConstArgIndicesCache().value();
    return Status::OK();
  }
  // Absent a user-supplied filter, every edge participates.
  auto edge_filter = [&](const Edge& e) {
    return edge_filter_input ? edge_filter_input(e) : true;
  };

  // If the caller doesn't want per-node results, use a local scratch vector so
  // the propagation logic below can treat `compile_time_const_nodes` as
  // always present.
  std::vector<bool> compile_time_const_nodes_impl;
  if (compile_time_const_nodes) {
    CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
  } else {
    compile_time_const_nodes_impl.resize(g.num_node_ids());
    compile_time_const_nodes = &compile_time_const_nodes_impl;
  }

  Status status;
  auto visit = [&](Node* node) {
    if (!status.ok()) return;

    // If this is a metadata-only op, don't propagate the const requirement.
    if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
      return;
    }

    // If this node must be const, and it isn't a metadata op, then all of its
    // parents must be const.
    if ((*compile_time_const_nodes)[node->id()]) {
      if (node->type_string() == "_Arg") {
        // A constant requirement reached a graph argument: record it by the
        // argument's "index" attr and stop (args have no data parents).
        int index;
        status = GetNodeAttr(node->attrs(), "index", &index);
        if (!status.ok()) return;
        if (compile_time_const_arg_indices) {
          (*compile_time_const_arg_indices)[index] = true;
        }
        return;
      }
      for (const Edge* pred : node->in_edges()) {
        if (!pred->IsControlEdge() && edge_filter(*pred)) {
          // If the src node of the `pred` is an IdentityN do not mark it as a
          // compile-time const. Only mark the corresponding input to the
          // IdentityN node as a const.
          // Note: XLA IdentityN op simply forwards its inputs so this is safe.
          // The loop rebinds `pred` to walk through chains of IdentityN ops.
          while (edge_filter(*pred) &&
                 pred->src()->type_string() == "IdentityN") {
            status = pred->src()->input_edge(pred->src_output(), &pred);
            if (!status.ok()) return;
          }
          if (edge_filter(*pred)) {
            (*compile_time_const_nodes)[pred->src()->id()] = true;
          }
        }
      }
      return;
    }

    // Mark any compile-time constant operator arguments as const.
    std::vector<int> const_input_idxs;
    status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime);

    if (!status.ok()) {
      return;
    }

    // `const_input_idxs` is produced in ascending order, which is what makes
    // the binary search below valid.
    for (Edge const* edge : node->in_edges()) {
      if (!edge->IsControlEdge() &&
          absl::c_binary_search(const_input_idxs, edge->dst_input()) &&
          edge_filter(*edge)) {
        // Do not mark IdentityN nodes as compile-time const.
        // If the src node of the `pred` is an IdentityN do not mark it as a
        // compile-time const. Only mark the corresponding input to the
        // IdentityN node as a const.
        // Note: XLA IdentityN op simply forwards its inputs so this is safe.
        while (edge_filter(*edge) &&
               edge->src()->type_string() == "IdentityN") {
          status = edge->src()->input_edge(edge->src_output(), &edge);
          if (!status.ok()) return;
        }
        if (edge_filter(*edge)) {
          (*compile_time_const_nodes)[edge->src()->id()] = true;
        }
      }
    }
  };

  // Post-order traversal visits nodes in reverse topological order for an
  // acyclic graph. NextIteration back-edges are skipped to keep the traversal
  // acyclic on graphs containing loops.
  DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
      [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
  // Only cache results for the default query (see the fast path above);
  // filtered analyses would poison the cache.
  if (compile_time_const_arg_indices && !edge_filter_input) {
    VLOG(5) << "Setting the cache on the graph: " << &g;
    g.GetConstArgIndicesCache() = *compile_time_const_arg_indices;
  }
  return status;
}
286
GetCompileTimeConstInputs(const OpKernel * op_kernel,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)287 Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
288 std::vector<int>* const_input_idxs,
289 FunctionLibraryRuntime* flib_runtime) {
290 return GetCompileTimeConstInputs(op_kernel->def(), op_kernel,
291 /*op_def=*/nullptr, const_input_idxs,
292 flib_runtime);
293 }
294
295 } // namespace tensorflow
296