• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
17 
18 #include <deque>
19 #include <numeric>
20 #include <vector>
21 
22 #include "tensorflow/compiler/tf2xla/const_analysis.h"
23 #include "tensorflow/compiler/tf2xla/literal_util.h"
24 #include "tensorflow/compiler/tf2xla/shape_util.h"
25 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
26 #include "tensorflow/compiler/tf2xla/type_util.h"
27 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
28 #include "tensorflow/compiler/tf2xla/xla_context.h"
29 #include "tensorflow/compiler/tf2xla/xla_expression.h"
30 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
31 #include "tensorflow/compiler/xla/client/client_library.h"
32 #include "tensorflow/compiler/xla/client/xla_builder.h"
33 #include "tensorflow/core/common_runtime/device.h"
34 #include "tensorflow/core/common_runtime/executor.h"
35 #include "tensorflow/core/common_runtime/function.h"
36 #include "tensorflow/core/common_runtime/graph_constructor.h"
37 #include "tensorflow/core/common_runtime/graph_optimizer.h"
38 #include "tensorflow/core/framework/attr_value.pb.h"
39 #include "tensorflow/core/framework/attr_value_util.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/node_def_util.h"
42 #include "tensorflow/core/framework/op_kernel.h"
43 #include "tensorflow/core/graph/algorithm.h"
44 #include "tensorflow/core/graph/node_builder.h"
45 #include "tensorflow/core/graph/validate.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/gtl/cleanup.h"
48 #include "tensorflow/core/lib/hash/hash.h"
49 #include "tensorflow/core/platform/logging.h"
50 #include "tensorflow/core/public/version.h"
51 #include "tensorflow/core/util/dump_graph.h"
52 
53 namespace tensorflow {
54 
55 namespace {
PrepareArguments(XlaOpKernelContext * ctx,Graph * graph,const std::vector<const XlaExpression * > & expressions,const NameAttrList & func,std::vector<XlaCompiler::Argument> * args)56 Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
57                         const std::vector<const XlaExpression*>& expressions,
58                         const NameAttrList& func,
59                         std::vector<XlaCompiler::Argument>* args) {
60   auto client = ctx->compiler()->client();
61   std::vector<bool> arg_must_be_compile_time_constant(expressions.size());
62 
63   TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
64       *graph, &arg_must_be_compile_time_constant,
65       /*compile_time_const_nodes=*/nullptr, ctx->function_library()));
66 
67   args->resize(expressions.size());
68   for (int i = 0, end = args->size(); i < end; ++i) {
69     XlaCompiler::Argument& arg = (*args)[i];
70     arg.type = ctx->input_type(i);
71     arg.shape = ctx->InputShape(i);
72 
73     switch (expressions[i]->kind()) {
74       case XlaExpression::Kind::kConstant:
75         arg.kind = XlaCompiler::Argument::kConstant;
76         arg.constant_value = *expressions[i]->constant_value();
77         break;
78       case XlaExpression::Kind::kXlaOp:
79         if (arg_must_be_compile_time_constant[i]) {
80           TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
81                               expressions[i]->ResolveConstant(client));
82           if (value.has_value()) {
83             arg.kind = XlaCompiler::Argument::kConstant;
84             arg.constant_value = *value;
85           } else {
86             arg.kind = XlaCompiler::Argument::kParameter;
87           }
88 
89         } else {
90           arg.kind = XlaCompiler::Argument::kParameter;
91         }
92         break;
93       case XlaExpression::Kind::kResource: {
94         XlaResource* resource = expressions[i]->resource();
95         XlaCompiler::PopulateArgumentFromResource(*resource, &arg);
96         break;
97       }
98       case XlaExpression::Kind::kTensorList: {
99         arg.kind = XlaCompiler::Argument::kTensorList;
100         const xla::XlaOp& tensor_list = expressions[i]->handle();
101         arg.shape = tensor_list.builder()->GetShape(tensor_list).ValueOrDie();
102         break;
103       }
104       case XlaExpression::Kind::kInvalid:
105         return errors::InvalidArgument("Invalid function argument");
106     }
107   }
108   return Status::OK();
109 }
110 }  // namespace
Compile()111 Status GraphCompiler::Compile() {
112   // Check that the graph has no illegal cycles.
113   TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_));
114   // Maintain a mapping from node id to node outputs.
115   using NodeOutputs = std::vector<TensorValue>;
116   std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
117   auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] {
118     for (const NodeOutputs& outputs : output_registry) {
119       for (const TensorValue& value : outputs) {
120         CHECK(!value.is_ref());
121         delete value.tensor;
122       }
123     }
124   });
125 
126   // XLA requires determinism, generate a stable ordering from DFS.
127   std::vector<Node*> topo_sorted_nodes;
128   GetReversePostOrder(*graph_, &topo_sorted_nodes,
129                       /*stable_comparator=*/NodeComparatorName());
130 
131   OpKernelContext::Params params;
132   PartiallySetupParams(&params);
133 
134   for (Node* n : topo_sorted_nodes) {
135     OpKernel* op_kernel_raw = nullptr;
136     // The kernel is not actually run for functional ops, we just need it
137     // for metadata.
138     Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw);
139     // Transfer ownership of the kernel to a local smart pointer.
140     std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);
141 
142     if (!s.ok()) {
143       s = AttachDef(s, *n);
144       LOG(ERROR) << "Executor failed to create kernel. " << s;
145       return s;
146     }
147 
148     TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
149         << "Not supported node: " << n->DebugString();
150     params.op_kernel = op_kernel.get();
151     absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
152     params.output_attr_array = output_attr.data();
153 
154     // tensor_inputs_ is a buffer reused across graph traversal. We clean up and
155     // reinitialize the buffer before we visit a new node.
156     tensor_inputs_.clear();
157     tensor_inputs_.resize(n->num_inputs());
158 
159     // Set up inputs from outputs of previous nodes.
160     for (auto* e : n->in_edges()) {
161       if (e->IsControlEdge()) continue;
162       const Node* src = e->src();
163       const int output_registry_size = output_registry.size();
164       TF_RET_CHECK(src->id() < output_registry_size);
165       const NodeOutputs& src_outputs = output_registry[src->id()];
166 
167       tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
168     }
169 
170     OpKernelContext op_context(&params, n->num_outputs());
171     VLOG(3) << "Translating " << params.op_kernel->name();
172     if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) {
173       TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
174     } else {
175       device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
176       Status s = op_context.status();
177       if (!s.ok()) {
178         return AttachDef(s, n->def());
179       }
180     }
181 
182     // Set up outputs. Also check if outputs from the previous computation is
183     // valid.
184     NodeOutputs& outputs = output_registry[n->id()];
185     outputs.resize(n->num_outputs());
186     for (int o = 0; o < n->num_outputs(); ++o) {
187       outputs[o] = op_context.release_output(o);
188       if (outputs[o].tensor == nullptr) {
189         return errors::Internal("Missing xla_context ", o, "-th output from ",
190                                 FormatNodeForError(*n));
191       }
192     }
193   }
194   return Status::OK();
195 }
196 
197 namespace {
198 
GetFunctionNameAndAttr(const FunctionLibraryRuntime & flib,const Node & node,NameAttrList * func)199 Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib,
200                               const Node& node, NameAttrList* func) {
201   if (node.IsPartitionedCall()) {
202     const AttrValue* attr_value;
203     TF_RETURN_IF_ERROR(
204         node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
205     if (!attr_value->has_func()) {
206       return errors::InvalidArgument(
207           "The attribute value for attribute 'f' in node ", node.DebugString(),
208           " does not have 'func' field set");
209     }
210     *func = attr_value->func();
211     return Status::OK();
212   }
213 
214   if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) {
215     func->set_name(node.type_string());
216   } else {
217     func->set_name(FunctionLibraryDefinition::kGradientOp);
218   }
219   *func->mutable_attr() = node.def().attr();
220   return Status::OK();
221 }
222 
223 }  // namespace
224 
CompileFunctionalNode(Node * n,OpKernelContext * op_context)225 Status GraphCompiler::CompileFunctionalNode(Node* n,
226                                             OpKernelContext* op_context) {
227   TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n));
228   // For functional nodes, compile them using compiler from the context and call
229   // into the functions.
230   XlaOpKernelContext xla_op_context(op_context);
231 
232   XlaContext& context = XlaContext::Get(op_context);
233   auto* b = context.builder();
234 
235   XlaCompiler* compiler = xla_op_context.compiler();
236 
237   NameAttrList func;
238   TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func));
239 
240   std::vector<const XlaExpression*> expressions;
241 
242   for (auto tensor : tensor_inputs_) {
243     auto expression =
244         reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
245     expressions.push_back(expression);
246   }
247 
248   // Prepare the arguments and compile the function.
249   std::vector<XlaCompiler::Argument> arguments;
250   const FunctionBody* fbody;
251   TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody));
252 
253   auto graph = compiler->GetGraph(fbody);
254 
255   TF_RETURN_IF_ERROR(PrepareArguments(&xla_op_context, graph.get(), expressions,
256                                       func, &arguments));
257 
258   bool add_token_input_output =
259       func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end();
260 
261   XlaCompiler::CompileOptions compile_options;
262   compile_options.is_entry_computation = false;
263   compile_options.add_token_input_output = add_token_input_output;
264   XlaCompiler::CompilationResult result;
265   TF_RETURN_IF_ERROR(
266       compiler->CompileFunction(compile_options, func, arguments, &result));
267 
268   TF_RET_CHECK(arguments.size() == expressions.size());
269 
270   std::vector<xla::XlaOp> handles;
271   for (int64_t i = 0, end = expressions.size(); i < end; ++i) {
272     if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
273       continue;
274     }
275     if (arguments[i].kind == XlaCompiler::Argument::kResource) {
276       handles.push_back(expressions[i]->resource()->value());
277     } else {
278       handles.push_back(expressions[i]->handle());
279     }
280   }
281   if (add_token_input_output) {
282     std::vector<string> token_input_nodes;
283     TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
284                                    kXlaTokenInputNodesAttrName,
285                                    &token_input_nodes));
286     std::vector<xla::XlaOp> token_inputs;
287     for (const string& node_name : token_input_nodes) {
288       auto token_or = compiler->GetNodeToken(node_name);
289       TF_RETURN_IF_ERROR(token_or.status());
290       token_inputs.push_back(token_or.ConsumeValueOrDie());
291     }
292     xla::XlaOp token_input = xla::AfterAll(b, token_inputs);
293     handles.push_back(token_input);
294   }
295 
296   auto output_handle = xla::Call(b, *result.computation, handles);
297   // The output handle of `Call` computation is a tuple type. Unzip it so
298   // that it can fit into future computations.
299   int computation_output = 0;
300   for (int64_t i = 0; i < n->num_outputs(); ++i) {
301     if (result.outputs[i].is_constant) {
302       xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
303     } else {
304       if (result.outputs[i].is_tensor_list) {
305         xla_op_context.SetTensorListOutput(
306             i, xla::GetTupleElement(output_handle, computation_output));
307       } else {
308         xla_op_context.SetOutput(
309             i, xla::GetTupleElement(output_handle, computation_output));
310       }
311       ++computation_output;
312     }
313   }
314 
315   for (int64_t i = 0, end = result.resource_updates.size(); i < end; i++) {
316     if (result.resource_updates[i].modified) {
317       XlaResource* resource =
318           expressions[result.resource_updates[i].input_index]->resource();
319       xla::XlaOp updated_value =
320           xla::GetTupleElement(output_handle, i + n->num_outputs());
321       TF_RETURN_IF_ERROR(resource->SetValue(updated_value));
322     }
323   }
324 
325   if (add_token_input_output) {
326     std::string node_name;
327     if (!GetNodeAttr(n->attrs(), kXlaOriginalOutsideCompilationNodeName,
328                      &node_name)
329              .ok())
330       node_name = n->name();
331     TF_RETURN_IF_ERROR(compiler->SetNodeToken(
332         node_name, xla::GetTupleElement(output_handle, computation_output)));
333   }
334   return b->first_error();
335 }
336 
PartiallySetupParams(OpKernelContext::Params * params)337 void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) {
338   params->device = device_;
339   params->inputs = &tensor_inputs_;
340   params->step_container = step_container_;
341   params->resource_manager = device_->resource_manager();
342   params->function_library = flib_;
343 }
344 
345 }  // namespace tensorflow
346