• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/const_analysis.h"
17 
18 #include <unordered_map>
19 #include <unordered_set>
20 
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 
31 namespace tensorflow {
32 
33 namespace {
34 
GetFunctionBody(FunctionLibraryRuntime * flib_runtime,const NodeDef & node,StringPiece func_attr_name,const FunctionBody ** fbody)35 Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime,
36                        const NodeDef& node, StringPiece func_attr_name,
37                        const FunctionBody** fbody) {
38   NameAttrList name_attr_list;
39   TF_RETURN_IF_ERROR(GetNodeAttr(node, func_attr_name, &name_attr_list));
40   FunctionLibraryRuntime::Handle func_handle;
41   TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
42       name_attr_list.name(), AttrSlice(&name_attr_list.attr()), &func_handle));
43   *fbody = flib_runtime->GetFunctionBody(func_handle);
44   return Status::OK();
45 }
46 
GetFunctionBodies(FunctionLibraryRuntime * flib_runtime,const NodeDef & node,StringPiece func_list_attr_name,std::vector<const FunctionBody * > * fbodies)47 Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime,
48                          const NodeDef& node, StringPiece func_list_attr_name,
49                          std::vector<const FunctionBody*>* fbodies) {
50   std::vector<NameAttrList> name_attr_lists;
51   TF_RETURN_IF_ERROR(GetNodeAttr(node, func_list_attr_name, &name_attr_lists));
52   for (const NameAttrList& name_attr_list : name_attr_lists) {
53     FunctionLibraryRuntime::Handle func_handle;
54     TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
55         name_attr_list.name(), AttrSlice(&name_attr_list.attr()),
56         &func_handle));
57     fbodies->push_back(flib_runtime->GetFunctionBody(func_handle));
58   }
59   return Status::OK();
60 }
61 
CondConstInputIndices(absl::Span<const FunctionBody * const> branch_bodies,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)62 Status CondConstInputIndices(
63     absl::Span<const FunctionBody* const> branch_bodies,
64     std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime) {
65   TF_RET_CHECK(!branch_bodies.empty());
66   TF_RET_CHECK(branch_bodies[0] != nullptr);
67   int num_inputs = branch_bodies[0]->fdef.signature().input_arg_size();
68   // Stores indices of the "branch function" inputs that are expected to be
69   // compile time constants.
70   std::vector<bool> compile_time_const_arg_indices(num_inputs);
71   for (auto fbody : branch_bodies) {
72     TF_RET_CHECK(fbody != nullptr);
73     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
74         *(fbody->graph), &compile_time_const_arg_indices,
75         /*compile_time_const_nodes=*/nullptr, flib_runtime));
76   }
77   for (int i = 0, end = compile_time_const_arg_indices.size(); i < end; i++) {
78     if (compile_time_const_arg_indices[i]) {
79       // The 0th input is the pred or branch index, which is not passed to the
80       // branches. So the i'th input of a branch function corresponds to the
81       // i + 1'th input of the If/Case op.
82       const_input_idxs->push_back(i + 1);
83     }
84   }
85   return Status::OK();
86 }
87 
GetCompileTimeConstInputs(const NodeDef & node,const OpKernel * op_kernel,const OpDef * op_def,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)88 Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
89                                  const OpDef* op_def,
90                                  std::vector<int>* const_input_idxs,
91                                  FunctionLibraryRuntime* flib_runtime) {
92   DCHECK(op_def != nullptr || op_kernel != nullptr);
93   if (node.op() == "While" || node.op() == "StatelessWhile") {
94     // For While nodes, recurse into the body and cond graphs.
95     const FunctionBody* fcond = nullptr;
96     const FunctionBody* fbody = nullptr;
97     TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "cond", &fcond));
98     TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "body", &fbody));
99     TF_RET_CHECK(fcond);
100     TF_RET_CHECK(fbody);
101     int num_inputs = fbody->fdef.signature().input_arg_size();
102 
103     // Stores which of the loop inputs are expected to be compile time
104     // constants.
105     std::vector<bool> compile_time_const_arg_indices(num_inputs);
106     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
107         *(fcond->graph), &compile_time_const_arg_indices,
108         /*compile_time_const_nodes=*/nullptr, flib_runtime));
109     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
110         *(fbody->graph), &compile_time_const_arg_indices,
111         /*compile_time_const_nodes=*/nullptr, flib_runtime));
112     for (int i = 0; i < num_inputs; i++) {
113       if (compile_time_const_arg_indices[i]) {
114         // Check that this input is actually a loop invariant.
115         Node* arg_i = fbody->arg_nodes[i];
116         Node* ret_i = fbody->ret_nodes[i];
117         const Node* ret_i_input_0;
118         TF_RETURN_IF_ERROR(ret_i->input_node(0, &ret_i_input_0));
119         if (ret_i_input_0->id() == arg_i->id()) {
120           const_input_idxs->push_back(i);
121         } else {
122           // TODO(b/178546817): Verify that it's OK and raise an error if we are
123           // using this branch from jit_compile=True.
124           VLOG(1) << "Argument " << i << " to while-loop "
125                   << node.ShortDebugString()
126                   << " has to be constant, but it's not a loop invariant, "
127                      "cluster compilation likely to fail at compile time: "
128                   << arg_i->def().ShortDebugString() << " vs. "
129                   << ret_i->def().ShortDebugString();
130         }
131       }
132     }
133     return Status::OK();
134   } else if (node.op() == "If" || node.op() == "StatelessIf") {
135     const FunctionBody* fthen = nullptr;
136     const FunctionBody* felse = nullptr;
137     TF_RETURN_IF_ERROR(
138         GetFunctionBody(flib_runtime, node, "then_branch", &fthen));
139     TF_RETURN_IF_ERROR(
140         GetFunctionBody(flib_runtime, node, "else_branch", &felse));
141     return CondConstInputIndices({fthen, felse}, const_input_idxs,
142                                  flib_runtime);
143   } else if (node.op() == "Case" || node.op() == "StatelessCase") {
144     std::vector<const FunctionBody*> branch_bodies;
145     TF_RETURN_IF_ERROR(
146         GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
147     return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
148   } else if (node.op() == "PartitionedCall" ||
149              node.op() == "StatefulPartitionedCall") {
150     const FunctionBody* fbody;
151     TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
152     int num_inputs = fbody->fdef.signature().input_arg_size();
153     std::vector<bool> compile_time_const_arg_indices(num_inputs);
154     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
155         *(fbody->graph), &compile_time_const_arg_indices,
156         /*compile_time_const_nodes=*/nullptr, flib_runtime));
157     for (int i = 0; i < num_inputs; i++) {
158       if (compile_time_const_arg_indices[i]) {
159         const_input_idxs->push_back(i);
160       }
161     }
162     return Status::OK();
163   } else if (op_def != nullptr) {
164     return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def,
165                                                     const_input_idxs);
166   } else {
167     return XlaOpRegistry::CompileTimeConstantInputs(*op_kernel,
168                                                     const_input_idxs);
169   }
170 }
171 
GetCompileTimeConstInputs(const Node * node,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)172 Status GetCompileTimeConstInputs(const Node* node,
173                                  std::vector<int>* const_input_idxs,
174                                  FunctionLibraryRuntime* flib_runtime) {
175   return GetCompileTimeConstInputs(node->def(), /*op_kernel=*/nullptr,
176                                    &node->op_def(), const_input_idxs,
177                                    flib_runtime);
178 }
179 
180 }  // namespace
181 
182 // Backwards dataflow analysis that finds arguments to a graph that must be
183 // compile-time constants.
BackwardsConstAnalysis(const Graph & g,std::vector<bool> * compile_time_const_arg_indices,std::vector<bool> * compile_time_const_nodes,FunctionLibraryRuntime * flib_runtime,std::function<bool (const Edge &)> edge_filter_input)184 Status BackwardsConstAnalysis(
185     const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
186     std::vector<bool>* compile_time_const_nodes,
187     FunctionLibraryRuntime* flib_runtime,
188     std::function<bool(const Edge&)> edge_filter_input) {
189   if (!compile_time_const_nodes && g.GetConstArgIndicesCache().has_value() &&
190       !edge_filter_input) {
191     VLOG(5) << "Using cached argument indices on graph " << &g;
192     *compile_time_const_arg_indices = g.GetConstArgIndicesCache().value();
193     return Status::OK();
194   }
195   auto edge_filter = [&](const Edge& e) {
196     return edge_filter_input ? edge_filter_input(e) : true;
197   };
198 
199   std::vector<bool> compile_time_const_nodes_impl;
200   if (compile_time_const_nodes) {
201     CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
202   } else {
203     compile_time_const_nodes_impl.resize(g.num_node_ids());
204     compile_time_const_nodes = &compile_time_const_nodes_impl;
205   }
206 
207   Status status;
208   auto visit = [&](Node* node) {
209     if (!status.ok()) return;
210 
211     // If this is a metadata-only op, don't propagate the const requirement.
212     if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
213       return;
214     }
215 
216     // If this node must be const, and it isn't a metadata op, then all of its
217     // parents must be const.
218     if ((*compile_time_const_nodes)[node->id()]) {
219       if (node->type_string() == "_Arg") {
220         int index;
221         status = GetNodeAttr(node->attrs(), "index", &index);
222         if (!status.ok()) return;
223         if (compile_time_const_arg_indices) {
224           (*compile_time_const_arg_indices)[index] = true;
225         }
226         return;
227       }
228       for (const Edge* pred : node->in_edges()) {
229         if (!pred->IsControlEdge() && edge_filter(*pred)) {
230           // If the src node of the `pred` is an IdentityN do not mark it as a
231           // compile-time const. Only mark the corresponding input to the
232           // IdentityN node as a const.
233           // Note: XLA IdentityN op simply forwards its inputs so this is safe.
234           while (edge_filter(*pred) &&
235                  pred->src()->type_string() == "IdentityN") {
236             status = pred->src()->input_edge(pred->src_output(), &pred);
237             if (!status.ok()) return;
238           }
239           if (edge_filter(*pred)) {
240             (*compile_time_const_nodes)[pred->src()->id()] = true;
241           }
242         }
243       }
244       return;
245     }
246 
247     // Mark any compile-time constant operator arguments as const.
248     std::vector<int> const_input_idxs;
249     status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime);
250 
251     if (!status.ok()) {
252       return;
253     }
254 
255     for (Edge const* edge : node->in_edges()) {
256       if (!edge->IsControlEdge() &&
257           absl::c_binary_search(const_input_idxs, edge->dst_input()) &&
258           edge_filter(*edge)) {
259         // Do not mark IdentityN nodes as compile-time const.
260         // If the src node of the `pred` is an IdentityN do not mark it as a
261         // compile-time const. Only mark the corresponding input to the
262         // IdentityN node as a const.
263         // Note: XLA IdentityN op simply forwards its inputs so this is safe.
264         while (edge_filter(*edge) &&
265                edge->src()->type_string() == "IdentityN") {
266           status = edge->src()->input_edge(edge->src_output(), &edge);
267           if (!status.ok()) return;
268         }
269         if (edge_filter(*edge)) {
270           (*compile_time_const_nodes)[edge->src()->id()] = true;
271         }
272       }
273     }
274   };
275 
276   // Post-order traversal visits nodes in reverse topological order for an
277   // acyclic graph.
278   DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
279       [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
280   if (compile_time_const_arg_indices && !edge_filter_input) {
281     VLOG(5) << "Setting the cache on the graph: " << &g;
282     g.GetConstArgIndicesCache() = *compile_time_const_arg_indices;
283   }
284   return status;
285 }
286 
GetCompileTimeConstInputs(const OpKernel * op_kernel,std::vector<int> * const_input_idxs,FunctionLibraryRuntime * flib_runtime)287 Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
288                                  std::vector<int>* const_input_idxs,
289                                  FunctionLibraryRuntime* flib_runtime) {
290   return GetCompileTimeConstInputs(op_kernel->def(), op_kernel,
291                                    /*op_def=*/nullptr, const_input_idxs,
292                                    flib_runtime);
293 }
294 
295 }  // namespace tensorflow
296