/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/graph_compiler.h"

#include <deque>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
                        const std::vector<const XlaExpression*>& expressions,
                        std::vector<XlaCompiler::Argument>* args) {
  auto client = ctx->compiler()->client();
  std::vector<bool> arg_must_be_compile_time_constant(expressions.size());

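  // Propagate "must be a compile-time constant" requirements backwards
  // through the function body so that matching arguments can be materialized
  // as constants below.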
  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
      *graph, &arg_must_be_compile_time_constant,
      /*compile_time_const_nodes=*/nullptr, ctx->function_library()));

  args->resize(expressions.size());
  for (int i = 0; i < args->size(); ++i) {
    XlaCompiler::Argument& arg = (*args)[i];
    arg.type = ctx->input_type(i);
    arg.shape = ctx->InputShape(i);

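    // Translate the input expression's kind into the corresponding
    // XlaCompiler::Argument kind.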
    switch (expressions[i]->kind()) {
      case XlaExpression::Kind::kConstant:
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = expressions[i]->constant_value();
        break;
      case XlaExpression::Kind::kXlaOp:
        if (arg_must_be_compile_time_constant[i]) {
          TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
                              expressions[i]->ResolveConstant(client));
          if (!value.has_value()) {
            return errors::InvalidArgument(
                "Argument to function must be a compile-time constant, but "
                "unable to resolve argument value to a constant.");
          }
          arg.kind = XlaCompiler::Argument::kConstant;
          arg.constant_value = *value;
        } else {
          arg.kind = XlaCompiler::Argument::kParameter;
        }
        break;
      case XlaExpression::Kind::kResource:
        // TODO(b/126601755): This is a fairly common use case in TF 2.0 that
        // we can hit when inlining is disabled or fails.
        return errors::Unimplemented(
            "Resource as function argument is not yet implemented.");
      case XlaExpression::Kind::kTensorList:
        return errors::Unimplemented(
            "TensorList as function argument is not yet implemented.");
      case XlaExpression::Kind::kInvalid:
        return errors::InvalidArgument("Invalid function argument");
    }
  }
  return Status::OK();
}
}  // namespace

Status GraphCompiler::Compile() {
  // Check that the graph has no illegal cycles.
  TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_));
  // Maintain a mapping from node id to node outputs.
  using NodeOutputs = std::vector<TensorValue>;
  std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
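  // The registry owns the output tensors; free them when compilation
  // finishes, whether it succeeds or fails.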
  auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] {
    for (const NodeOutputs& outputs : output_registry) {
      for (const TensorValue& value : outputs) {
        CHECK(!value.is_ref());
        delete value.tensor;
      }
    }
  });

  // XLA requires determinism; generate a stable node ordering from a DFS.
  std::vector<Node*> topo_sorted_nodes;
  GetReversePostOrder(*graph_, &topo_sorted_nodes,
                      /*stable_comparator=*/NodeComparatorName());

  OpKernelContext::Params params;
  PartiallySetupParams(&params);

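  // Visit the nodes in the stable order computed above, translating each one
  // into XLA operations.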
  for (Node* n : topo_sorted_nodes) {
    OpKernel* op_kernel_raw = nullptr;
    // The kernel is not actually run for functional ops; we only need it
    // for its metadata.
    Status s = flib_->CreateKernel(n->def(), &op_kernel_raw);
    // Transfer ownership of the kernel to a local smart pointer.
    std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);

    if (!s.ok()) {
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }

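    // Send/Recv and Switch nodes have no direct XLA translation; control flow
    // is expected to have been functionalized before this pass runs.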
    TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
        << "Unsupported node: " << n->DebugString();
    params.op_kernel = op_kernel.get();
    absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
    params.output_attr_array = output_attr.data();

    // tensor_inputs_ is a buffer reused across the graph traversal. We clean
    // up and reinitialize the buffer before we visit a new node.
    tensor_inputs_.clear();
    tensor_inputs_.resize(n->num_inputs());

    // Set up inputs from outputs of previous nodes.
    for (auto* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      const Node* src = e->src();
      TF_RET_CHECK(src->id() < output_registry.size());
      const NodeOutputs& src_outputs = output_registry[src->id()];

      tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
    }

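    // Function calls are compiled recursively into XLA computations; all
    // other nodes run their compilation-device kernel, which emits XLA
    // operations rather than computing real tensor values.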
    OpKernelContext op_context(&params, n->num_outputs());
    VLOG(3) << "Translating " << params.op_kernel->name();
    if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) {
      TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
    } else {
      device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
      Status s = op_context.status();
      if (!s.ok()) {
        return AttachDef(s, n->def());
      }
    }

    // Set up outputs. Also check that the outputs from the previous
    // computation are valid.
    NodeOutputs& outputs = output_registry[n->id()];
    outputs.resize(n->num_outputs());
    for (int o = 0; o < n->num_outputs(); ++o) {
      outputs[o] = op_context.release_output(o);
      if (outputs[o].tensor == nullptr) {
        return errors::Internal("Missing xla_context ", o, "-th output from ",
                                FormatNodeForError(*n));
      }
    }
  }
  return Status::OK();
}

namespace {

Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib,
                              const Node& node, NameAttrList* func) {
  if (node.IsPartitionedCall()) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(
        node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
    if (!attr_value->has_func()) {
      return errors::InvalidArgument(
          "The attribute value for attribute 'f' in node ", node.DebugString(),
          " does not have 'func' field set");
    }
    *func = attr_value->func();
    return Status::OK();
  }

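  // If the op name resolves to a function in the library, call it directly;
  // otherwise assume this is a gradient (SymbolicGradient) node.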
  if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) {
    func->set_name(node.type_string());
  } else {
    func->set_name(FunctionLibraryDefinition::kGradientOp);
  }
  *func->mutable_attr() = node.def().attr();
  return Status::OK();
}

}  // namespace

Status GraphCompiler::CompileFunctionalNode(Node* n,
                                            OpKernelContext* op_context) {
  TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n));
  // For functional nodes, compile the function body using the compiler from
  // the context, then emit a call into the resulting computation.
  XlaOpKernelContext xla_op_context(op_context);

  XlaContext& context = XlaContext::Get(op_context);
  auto* b = context.builder();

  XlaCompiler* compiler = xla_op_context.compiler();

  NameAttrList func;
  TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func));

  std::vector<const XlaExpression*> expressions;

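  // On the XLA compilation device, a tensor's data buffer holds an
  // XlaExpression describing its symbolic value, so the inputs can be
  // reinterpreted instead of being read as real data.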
  for (auto tensor : tensor_inputs_) {
    auto expression =
        reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
    expressions.push_back(expression);
  }

  // Prepare the arguments and compile the function.
  std::vector<XlaCompiler::Argument> arguments;
  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody));

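  // Materialize the function body as a Graph so that the constant analysis
  // in PrepareArguments() can run over it.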
  auto graph = compiler->GetGraph(fbody);

  TF_RETURN_IF_ERROR(
      PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));

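  // Functions annotated with kXlaTokenInputNodesAttrName participate in
  // side-effect ordering: an XLA token is threaded through the call as an
  // extra input and output.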
  bool add_token_input_output =
      func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end();

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = false;
  compile_options.add_token_input_output = add_token_input_output;
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(
      compiler->CompileFunction(compile_options, func, arguments, &result));

  TF_RET_CHECK(arguments.size() == expressions.size());

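  // Compile-time constant arguments were baked into the compiled computation,
  // so only the remaining arguments become operands of the call.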
  std::vector<xla::XlaOp> handles;
  for (int64 i = 0; i < expressions.size(); ++i) {
    if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
      continue;
    }
    handles.push_back(expressions[i]->handle());
  }
  if (add_token_input_output) {
    std::vector<string> token_input_nodes;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
                                   kXlaTokenInputNodesAttrName,
                                   &token_input_nodes));
    std::vector<xla::XlaOp> token_inputs;
    for (const string& node_name : token_input_nodes) {
      auto token_or = compiler->GetNodeToken(node_name);
      TF_RETURN_IF_ERROR(token_or.status());
      token_inputs.push_back(token_or.ConsumeValueOrDie());
    }
    xla::XlaOp token_input = xla::AfterAll(b, token_inputs);
    handles.push_back(token_input);
  }

  auto output_handle = xla::Call(b, *result.computation, handles);
  // The output of the `Call` computation is a tuple; unpack it so that the
  // individual elements can feed downstream computations.
  int computation_output = 0;
  for (int64 i = 0; i < n->num_outputs(); ++i) {
    if (result.outputs[i].is_constant) {
      xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
    } else {
      xla_op_context.SetOutput(
          i, xla::GetTupleElement(output_handle, computation_output));
      ++computation_output;
    }
  }
  if (add_token_input_output) {
    TF_RETURN_IF_ERROR(compiler->SetNodeToken(
        n->name(), xla::GetTupleElement(output_handle, computation_output)));
  }
  return b->first_error();
}

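// Fills in the fields of OpKernelContext::Params that are invariant across
// the graph traversal; per-node fields such as op_kernel and
// output_attr_array are set in Compile().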
void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) {
  params->device = device_;
  params->inputs = &tensor_inputs_;
  params->step_container = step_container_;
  params->resource_manager = device_->resource_manager();
  params->function_library = flib_;
}

}  // namespace tensorflow