/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/graph_compiler.h"

#include <deque>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {
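
// Builds the XlaCompiler::Argument descriptors for a function call from the
// XlaExpressions of its inputs. Arguments that the backwards const-analysis
// marks as compile-time constants are resolved to constant values; all other
// XLA-op arguments become ordinary parameters.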
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
                        const std::vector<const XlaExpression*>& expressions,
                        std::vector<XlaCompiler::Argument>* args) {
  auto client = ctx->compiler()->client();
  std::vector<bool> arg_must_be_compile_time_constant(expressions.size());

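  // Determine which arguments must be resolvable to compile-time constants
  // because some node in the function body requires a static value for them.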
  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
      *graph, &arg_must_be_compile_time_constant,
      /*compile_time_const_nodes=*/nullptr, ctx->function_library()));

  args->resize(expressions.size());
  for (int i = 0; i < args->size(); ++i) {
    XlaCompiler::Argument& arg = (*args)[i];
    arg.type = ctx->input_type(i);
    arg.shape = ctx->InputShape(i);

    switch (expressions[i]->kind()) {
      case XlaExpression::Kind::kConstant:
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = expressions[i]->constant_value();
        break;
      case XlaExpression::Kind::kXlaOp:
        if (arg_must_be_compile_time_constant[i]) {
          TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
                              expressions[i]->ResolveConstant(client));
          if (!value.has_value()) {
            return errors::InvalidArgument(
                "Argument to function must be a compile-time constant, but "
                "unable to resolve argument value to a constant.");
          }
          arg.kind = XlaCompiler::Argument::kConstant;
          arg.constant_value = *value;
        } else {
          arg.kind = XlaCompiler::Argument::kParameter;
        }
        break;
      case XlaExpression::Kind::kResource:
        // TODO(b/126601755): This is a fairly common use case in TF 2.0 that
        // we can hit when inlining is disabled or fails.
        return errors::Unimplemented(
            "Resource as function argument is not yet implemented.");
      case XlaExpression::Kind::kTensorList:
        return errors::Unimplemented(
            "TensorList as function argument is not yet implemented.");
      case XlaExpression::Kind::kInvalid:
        return errors::InvalidArgument("Invalid function argument");
    }
  }
  return Status::OK();
}
}  // namespace
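
// Compiles the whole graph: validates that it is acyclic, walks its nodes in
// a stable reverse post-order, and translates each node by running its
// kernel's compile-time implementation (or, for function-call nodes, by
// compiling the callee).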
Status GraphCompiler::Compile() {
  // Check that the graph has no illegal cycles.
  TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_));
  // Maintain a mapping from node id to node outputs.
  using NodeOutputs = std::vector<TensorValue>;
  std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
  auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] {
    for (const NodeOutputs& outputs : output_registry) {
      for (const TensorValue& value : outputs) {
        CHECK(!value.is_ref());
        delete value.tensor;
      }
    }
  });

  // XLA requires determinism; generate a stable node ordering via a reverse
  // post-order DFS, breaking ties by node name.
  std::vector<Node*> topo_sorted_nodes;
  GetReversePostOrder(*graph_, &topo_sorted_nodes,
                      /*stable_comparator=*/NodeComparatorName());

  OpKernelContext::Params params;
  PartiallySetupParams(&params);

  for (Node* n : topo_sorted_nodes) {
    OpKernel* op_kernel_raw = nullptr;
    // The kernel is not actually run for functional ops; we only need it for
    // its metadata.
    Status s = flib_->CreateKernel(n->def(), &op_kernel_raw);
    // Transfer ownership of the kernel to a local smart pointer.
    std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);

    if (!s.ok()) {
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }

    TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
        << "Not supported node: " << n->DebugString();
    params.op_kernel = op_kernel.get();
    absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
    params.output_attr_array = output_attr.data();

    // tensor_inputs_ is a buffer reused across the graph traversal. Clear and
    // reinitialize the buffer before visiting each new node.
    tensor_inputs_.clear();
    tensor_inputs_.resize(n->num_inputs());

    // Set up inputs from outputs of previous nodes.
    for (auto* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      const Node* src = e->src();
      TF_RET_CHECK(src->id() < output_registry.size());
      const NodeOutputs& src_outputs = output_registry[src->id()];

      tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
    }

    OpKernelContext op_context(&params, n->num_outputs());
    VLOG(3) << "Translating " << params.op_kernel->name();
    if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) {
      TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
    } else {
      device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
      Status s = op_context.status();
      if (!s.ok()) {
        return AttachDef(s, n->def());
      }
    }

    // Set up outputs. Also check that the outputs from the computation above
    // are valid.
    NodeOutputs& outputs = output_registry[n->id()];
    outputs.resize(n->num_outputs());
    for (int o = 0; o < n->num_outputs(); ++o) {
      outputs[o] = op_context.release_output(o);
      if (outputs[o].tensor == nullptr) {
        return errors::Internal("Missing xla_context ", o, "-th output from ",
                                FormatNodeForError(*n));
      }
    }
  }
  return Status::OK();
}

namespace {

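// Extracts the name and attributes of the function called by `node`: for
// PartitionedCall ops the callee comes from the 'f' attribute; otherwise the
// op type itself names the function if it is in the library, else the node is
// treated as a symbolic-gradient call.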
Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib,
                              const Node& node, NameAttrList* func) {
  if (node.IsPartitionedCall()) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(
        node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
    if (!attr_value->has_func()) {
      return errors::InvalidArgument(
          "The attribute value for attribute 'f' in node ", node.DebugString(),
          " does not have 'func' field set");
    }
    *func = attr_value->func();
    return Status::OK();
  }

  if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) {
    func->set_name(node.type_string());
  } else {
    func->set_name(FunctionLibraryDefinition::kGradientOp);
  }
  *func->mutable_attr() = node.def().attr();
  return Status::OK();
}

}  // namespace

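// Compiles a function-call node: the callee is compiled into its own XLA
// computation, which is then invoked via xla::Call with the node's
// non-constant inputs (plus an optional XLA token for side-effect ordering).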
Status GraphCompiler::CompileFunctionalNode(Node* n,
                                            OpKernelContext* op_context) {
  TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n));
  // For functional nodes, compile the callee using the compiler from the
  // context, then emit a call into the compiled function.
  XlaOpKernelContext xla_op_context(op_context);

  XlaContext& context = XlaContext::Get(op_context);
  auto* b = context.builder();

  XlaCompiler* compiler = xla_op_context.compiler();

  NameAttrList func;
  TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func));

  std::vector<const XlaExpression*> expressions;

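  // Each input tensor's backing buffer actually stores an XlaExpression;
  // recover those expressions so the inputs can be described to the compiler.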
  for (auto tensor : tensor_inputs_) {
    auto expression =
        reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
    expressions.push_back(expression);
  }

  // Prepare the arguments and compile the function.
  std::vector<XlaCompiler::Argument> arguments;
  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody));

  auto graph = compiler->GetGraph(fbody);

  TF_RETURN_IF_ERROR(
      PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));

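  // If the function carries the token-input-nodes attribute, it participates
  // in side-effect ordering and needs an extra XLA token input and output.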
  bool add_token_input_output =
      func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end();

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = false;
  compile_options.add_token_input_output = add_token_input_output;
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(
      compiler->CompileFunction(compile_options, func, arguments, &result));

  TF_RET_CHECK(arguments.size() == expressions.size());

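  // Gather the operands for the call. Constant arguments were folded into the
  // compiled function, so only the remaining arguments are passed as handles.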
  std::vector<xla::XlaOp> handles;
  for (int64 i = 0; i < expressions.size(); ++i) {
    if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
      continue;
    }
    handles.push_back(expressions[i]->handle());
  }
  if (add_token_input_output) {
    std::vector<string> token_input_nodes;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
                                   kXlaTokenInputNodesAttrName,
                                   &token_input_nodes));
    std::vector<xla::XlaOp> token_inputs;
    for (const string& node_name : token_input_nodes) {
      auto token_or = compiler->GetNodeToken(node_name);
      TF_RETURN_IF_ERROR(token_or.status());
      token_inputs.push_back(token_or.ConsumeValueOrDie());
    }
    xla::XlaOp token_input = xla::AfterAll(b, token_inputs);
    handles.push_back(token_input);
  }

  auto output_handle = xla::Call(b, *result.computation, handles);
  // The output handle of the `Call` computation is a tuple; unpack it so the
  // elements can feed into downstream computations.
  int computation_output = 0;
  for (int64 i = 0; i < n->num_outputs(); ++i) {
    if (result.outputs[i].is_constant) {
      xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
    } else {
      xla_op_context.SetOutput(
          i, xla::GetTupleElement(output_handle, computation_output));
      ++computation_output;
    }
  }
  if (add_token_input_output) {
    TF_RETURN_IF_ERROR(compiler->SetNodeToken(
        n->name(), xla::GetTupleElement(output_handle, computation_output)));
  }
  return b->first_error();
}

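// Fills in the parts of OpKernelContext::Params that are the same for every
// node: the device, the reused input buffer, the step container, the resource
// manager, and the function library.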
void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) {
  params->device = device_;
  params->inputs = &tensor_inputs_;
  params->step_container = step_container_;
  params->resource_manager = device_->resource_manager();
  params->function_library = flib_;
}

}  // namespace tensorflow