/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/graph_compiler.h"

#include <deque>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {
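// Builds the XlaCompiler::Argument descriptors for a call to function `func`:
// each input expression becomes a compile-time constant, a runtime parameter,
// a resource, or a tensor list, guided by const analysis of the callee graph.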
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
                        const std::vector<const XlaExpression*>& expressions,
                        const NameAttrList& func,
                        std::vector<XlaCompiler::Argument>* args) {
  auto client = ctx->compiler()->client();
  std::vector<bool> arg_must_be_compile_time_constant(expressions.size());

  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
      *graph, &arg_must_be_compile_time_constant,
      /*compile_time_const_nodes=*/nullptr, ctx->function_library()));

  args->resize(expressions.size());
  for (int i = 0, end = args->size(); i < end; ++i) {
    XlaCompiler::Argument& arg = (*args)[i];
    arg.type = ctx->input_type(i);
    arg.shape = ctx->InputShape(i);

    switch (expressions[i]->kind()) {
      case XlaExpression::Kind::kConstant:
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = *expressions[i]->constant_value();
        break;
      case XlaExpression::Kind::kXlaOp:
        if (arg_must_be_compile_time_constant[i]) {
          TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
                              expressions[i]->ResolveConstant(client));
          if (value.has_value()) {
            arg.kind = XlaCompiler::Argument::kConstant;
            arg.constant_value = *value;
          } else {
            arg.kind = XlaCompiler::Argument::kParameter;
          }
        } else {
          arg.kind = XlaCompiler::Argument::kParameter;
        }
        break;
      case XlaExpression::Kind::kResource: {
        XlaResource* resource = expressions[i]->resource();
        XlaCompiler::PopulateArgumentFromResource(*resource, &arg);
        break;
      }
      case XlaExpression::Kind::kTensorList: {
        arg.kind = XlaCompiler::Argument::kTensorList;
        const xla::XlaOp& tensor_list = expressions[i]->handle();
        arg.shape = tensor_list.builder()->GetShape(tensor_list).ValueOrDie();
        break;
      }
      case XlaExpression::Kind::kInvalid:
        return errors::InvalidArgument("Invalid function argument");
    }
  }
  return Status::OK();
}
}  // namespace
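// Compiles the graph node by node in a deterministic reverse post-order,
// invoking each op's symbolic XLA kernel and recursing into function calls.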
Status GraphCompiler::Compile() {
  // Check that the graph has no illegal cycles.
  TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_));
  // Maintain a mapping from node id to node outputs.
  using NodeOutputs = std::vector<TensorValue>;
  std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
  auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] {
    for (const NodeOutputs& outputs : output_registry) {
      for (const TensorValue& value : outputs) {
        CHECK(!value.is_ref());
        delete value.tensor;
      }
    }
  });

  // XLA requires determinism; generate a stable ordering from DFS.
  std::vector<Node*> topo_sorted_nodes;
  GetReversePostOrder(*graph_, &topo_sorted_nodes,
                      /*stable_comparator=*/NodeComparatorName());

  OpKernelContext::Params params;
  PartiallySetupParams(&params);

  for (Node* n : topo_sorted_nodes) {
    OpKernel* op_kernel_raw = nullptr;
    // The kernel is not actually run for functional ops; we only need it for
    // its metadata.
    Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw);
    // Transfer ownership of the kernel to a local smart pointer.
    std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);

    if (!s.ok()) {
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }

    TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
        << "Not supported node: " << n->DebugString();
    params.op_kernel = op_kernel.get();
    absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
    params.output_attr_array = output_attr.data();

    // tensor_inputs_ is a buffer reused across the graph traversal: clear and
    // reinitialize it before visiting each node.
    tensor_inputs_.clear();
    tensor_inputs_.resize(n->num_inputs());

    // Set up inputs from the outputs of previous nodes.
    for (auto* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      const Node* src = e->src();
      const int output_registry_size = output_registry.size();
      TF_RET_CHECK(src->id() < output_registry_size);
      const NodeOutputs& src_outputs = output_registry[src->id()];

      tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
    }

    OpKernelContext op_context(&params, n->num_outputs());
    VLOG(3) << "Translating " << params.op_kernel->name();
    if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) {
      TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
    } else {
      device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
      Status s = op_context.status();
      if (!s.ok()) {
        return AttachDef(s, n->def());
      }
    }

    // Set up outputs, and check that the outputs of the computation just run
    // are valid.
    NodeOutputs& outputs = output_registry[n->id()];
    outputs.resize(n->num_outputs());
    for (int o = 0; o < n->num_outputs(); ++o) {
      outputs[o] = op_context.release_output(o);
      if (outputs[o].tensor == nullptr) {
        return errors::Internal("Missing xla_context ", o, "-th output from ",
                                FormatNodeForError(*n));
      }
    }
  }
  return Status::OK();
}

namespace {

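// Returns the name and attributes of the function called by `node`.
// PartitionedCall ops carry the function in their 'f' attribute; otherwise
// the node either names a library function directly via its op type or is a
// SymbolicGradient call.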
Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib,
                              const Node& node, NameAttrList* func) {
  if (node.IsPartitionedCall()) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(
        node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
    if (!attr_value->has_func()) {
      return errors::InvalidArgument(
          "The attribute value for attribute 'f' in node ", node.DebugString(),
          " does not have 'func' field set");
    }
    *func = attr_value->func();
    return Status::OK();
  }

  if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) {
    func->set_name(node.type_string());
  } else {
    func->set_name(FunctionLibraryDefinition::kGradientOp);
  }
  *func->mutable_attr() = node.def().attr();
  return Status::OK();
}

}  // namespace

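// Compiles a function-call node by compiling the callee with the XlaCompiler,
// emitting an xla::Call, and unpacking the result tuple into the node's
// outputs, resource updates, and (optionally) a side-effect token.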
Status GraphCompiler::CompileFunctionalNode(Node* n,
                                            OpKernelContext* op_context) {
  TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n));
  // For functional nodes, compile them using the compiler from the context
  // and call into the compiled functions.
  XlaOpKernelContext xla_op_context(op_context);

  XlaContext& context = XlaContext::Get(op_context);
  auto* b = context.builder();

  XlaCompiler* compiler = xla_op_context.compiler();

  NameAttrList func;
  TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func));

  std::vector<const XlaExpression*> expressions;

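  // Tensors produced on the XLA compilation device carry an XlaExpression in
  // their data buffer; recover those expressions from the node's inputs.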
  for (auto tensor : tensor_inputs_) {
    auto expression =
        reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
    expressions.push_back(expression);
  }

  // Prepare the arguments and compile the function.
  std::vector<XlaCompiler::Argument> arguments;
  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody));

  auto graph = compiler->GetGraph(fbody);

  TF_RETURN_IF_ERROR(PrepareArguments(&xla_op_context, graph.get(), expressions,
                                      func, &arguments));

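  // Functions annotated with token input nodes participate in side-effect
  // ordering: an extra token is threaded through the computation as an added
  // input and output.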
  bool add_token_input_output =
      func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end();

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = false;
  compile_options.add_token_input_output = add_token_input_output;
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(
      compiler->CompileFunction(compile_options, func, arguments, &result));

  TF_RET_CHECK(arguments.size() == expressions.size());

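  // Gather the XLA operands for the call. Constant arguments were folded into
  // the compiled computation and take no runtime operand; resource arguments
  // pass their current value.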
  std::vector<xla::XlaOp> handles;
  for (int64_t i = 0, end = expressions.size(); i < end; ++i) {
    if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
      continue;
    }
    if (arguments[i].kind == XlaCompiler::Argument::kResource) {
      handles.push_back(expressions[i]->resource()->value());
    } else {
      handles.push_back(expressions[i]->handle());
    }
  }
  if (add_token_input_output) {
    std::vector<string> token_input_nodes;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
                                   kXlaTokenInputNodesAttrName,
                                   &token_input_nodes));
    std::vector<xla::XlaOp> token_inputs;
    for (const string& node_name : token_input_nodes) {
      auto token_or = compiler->GetNodeToken(node_name);
      TF_RETURN_IF_ERROR(token_or.status());
      token_inputs.push_back(token_or.ConsumeValueOrDie());
    }
    xla::XlaOp token_input = xla::AfterAll(b, token_inputs);
    handles.push_back(token_input);
  }

  auto output_handle = xla::Call(b, *result.computation, handles);
  // The output handle of the `Call` computation is a tuple; unpack it so the
  // elements can feed into downstream computations.
  int computation_output = 0;
  for (int64_t i = 0; i < n->num_outputs(); ++i) {
    if (result.outputs[i].is_constant) {
      xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
    } else {
      if (result.outputs[i].is_tensor_list) {
        xla_op_context.SetTensorListOutput(
            i, xla::GetTupleElement(output_handle, computation_output));
      } else {
        xla_op_context.SetOutput(
            i, xla::GetTupleElement(output_handle, computation_output));
      }
      ++computation_output;
    }
  }

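  // Propagate resource updates: updated resource values follow the regular
  // outputs in the result tuple.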
  for (int64_t i = 0, end = result.resource_updates.size(); i < end; i++) {
    if (result.resource_updates[i].modified) {
      XlaResource* resource =
          expressions[result.resource_updates[i].input_index]->resource();
      xla::XlaOp updated_value =
          xla::GetTupleElement(output_handle, i + n->num_outputs());
      TF_RETURN_IF_ERROR(resource->SetValue(updated_value));
    }
  }

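  // Record this node's output token so that downstream side-effecting ops can
  // order themselves after it.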
  if (add_token_input_output) {
    std::string node_name;
    if (!GetNodeAttr(n->attrs(), kXlaOriginalOutsideCompilationNodeName,
                     &node_name)
             .ok())
      node_name = n->name();
    TF_RETURN_IF_ERROR(compiler->SetNodeToken(
        node_name, xla::GetTupleElement(output_handle, computation_output)));
  }
  return b->first_error();
}

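// Fills in the OpKernelContext::Params fields that are shared by every node;
// per-node fields (the kernel, inputs, and output attributes) are set in
// Compile().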
void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) {
  params->device = device_;
  params->inputs = &tensor_inputs_;
  params->step_container = step_container_;
  params->resource_manager = device_->resource_manager();
  params->function_library = flib_;
}

}  // namespace tensorflow