1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17
18 #include <numeric>
19 #include <vector>
20
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/sharding_util.h"
25 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
27 #include "tensorflow/compiler/tf2xla/type_util.h"
28 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
29 #include "tensorflow/compiler/tf2xla/xla_context.h"
30 #include "tensorflow/compiler/xla/client/client_library.h"
31 #include "tensorflow/compiler/xla/client/xla_builder.h"
32 #include "tensorflow/compiler/xla/client/xla_computation.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/common_runtime/device.h"
35 #include "tensorflow/core/common_runtime/executor.h"
36 #include "tensorflow/core/common_runtime/function.h"
37 #include "tensorflow/core/common_runtime/graph_optimizer.h"
38 #include "tensorflow/core/framework/attr_value_util.h"
39 #include "tensorflow/core/framework/function.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/types.h"
42 #include "tensorflow/core/graph/algorithm.h"
43 #include "tensorflow/core/graph/graph_constructor.h"
44 #include "tensorflow/core/graph/node_builder.h"
45 #include "tensorflow/core/lib/core/error_codes.pb.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/gtl/cleanup.h"
48 #include "tensorflow/core/lib/hash/hash.h"
49 #include "tensorflow/core/platform/logging.h"
50 #include "tensorflow/core/util/dump_graph.h"
51
52 namespace tensorflow {
53 namespace {
54
55 // Checks that arguments `args` match types `types`.
56 Status CheckSignature(const DataTypeVector& types,
57 absl::Span<const XlaCompiler::Argument> args) {
58 if (args.size() != types.size()) {
59 return errors::Internal("Compilation arguments have ", args.size(),
60 " elements while function has ", types.size());
61 }
62 for (int i = 0; i < types.size(); ++i) {
63 // Don't perform type checks on resource variables and tensor
64 // lists (DT_VARIANT) as we have to trick the type system in order to
65 // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor.
66 if (types[i] != args[i].type && types[i] != DT_RESOURCE &&
67 types[i] != DT_VARIANT) {
68 return errors::Internal(
69 "Argument ", i, " has declared type ", DataTypeString(args[i].type),
70 " but function parameter has type ", DataTypeString(types[i]));
71 }
72 }
73 return Status::OK();
74 }
75
76 // Uses the _Arg and _Retval nodes in the graph to determine a core assignment
77 // for each argument and return value.
78 xla::StatusOr<std::pair<std::map<int, int>, std::map<int, int>>>
79 ComputeArgAndRetvalCores(const Graph& graph) {
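// Returns the device id from a node's MAXIMAL sharding, or -1 if the node
// carries no sharding annotation; any non-MAXIMAL sharding is an error.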
80 auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr<int> {
81 TF_ASSIGN_OR_RETURN(
82 auto sharding,
83 ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
84 if (sharding.has_value()) {
85 TF_RET_CHECK(sharding.value().type() ==
86 xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
87 return sharding.value().tile_assignment_devices(0);
88 } else {
89 return -1;
90 }
91 };
92 std::map<int, int> arg_cores;
93 std::map<int, int> retval_cores;
94 for (const Node* n : graph.nodes()) {
95 if (n->IsArg()) {
96 TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
97 if (core < 0) continue;
98 int index;
99 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
100 TF_RET_CHECK(index >= 0) << "Negative _Arg index";
101 arg_cores[index] = core;
102 } else if (n->IsRetval()) {
103 TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
104 if (core < 0) continue;
105 int index;
106 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
107 TF_RET_CHECK(index >= 0) << "Negative _Retval index";
108 retval_cores[index] = core;
110 }
111 }
112 return std::make_pair(std::move(arg_cores), std::move(retval_cores));
113 }
114
115 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
116 XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
117 int64 step_id) {
118 // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
119 // resource manager takes ownership via Create, and unrefs via Cleanup. We
120 // explicitly add a reference to ensure the refcount at entry is maintained at
121 // all exit points; Create and Cleanup are always called in this function.
122 //
123 // The Executor requires us to use ScopedStepContainer. We wrap it in a
124 // unique_ptr so we can capture the cleanup status in the end.
125 xla_context->Ref();
126 Status status;
127 auto step_container = absl::make_unique<ScopedStepContainer>(
128 step_id, [&status, device](const string& name) {
129 status = device->resource_manager()->Cleanup(name);
130 });
131 TF_RETURN_IF_ERROR(device->resource_manager()->Create(
132 step_container->name(), XlaContext::kXlaContextResourceName,
133 xla_context));
134
135 GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
136 TF_RETURN_IF_ERROR(graph_compiler.Compile());
137 // Explicitly clean up the step container, to capture the cleanup status.
138 step_container.reset();
139 return status;
140 }
141
142 // Builds the XLA computation.
143 // - `args` is the list of input arguments
144 // - `retvals` is the list of retvals produced by _Retval operators, in index
145 // order.
146 // - `arg_cores` and `retval_cores` are mappings from arg/return indices to core
147 // assignments.
148 // - If `return_updated_values_for_all_resources` is true, all resources will be
149 // included in `resource_updates`, regardless of whether their value changed.
150 // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
151 // - Sets `*resource_updates` to a description of resources whose values are
152 // written by the computation; the variable writes are the last
153 // `resource_updates.size()` return values from the computation. Each entry in
154 // `resource_updates` is a ResourceUpdate, whose `input_index` is the index of a
155 // resource variable argument to the computation to be updated, and `type` is
156 // the type of the final output.
157 Status BuildComputation(
158 const std::vector<XlaCompiler::Argument>& args,
159 const std::vector<XlaExpression>& retvals,
160 const std::map<int, int>& arg_cores, const std::map<int, int>& retval_cores,
161 const std::vector<std::unique_ptr<XlaResource>>& resources,
162 std::unique_ptr<xla::XlaOp> token_output,
163 const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
164 bool return_updated_values_for_all_resources, bool always_return_tuple,
165 xla::XlaBuilder* builder, xla::XlaComputation* computation,
166 int* num_computation_outputs, int* num_nonconst_outputs,
167 std::vector<XlaCompiler::OutputDescription>* outputs,
168 std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
169 xla::Shape* output_shape) {
170 // Attach a common operator name as metadata. This has no semantic effect; it
171 // merely makes the HLO graph more readable when visualized via TensorBoard,
172 // since TensorBoard forms groups out of operators with similar names.
173 xla::OpMetadata retval_metadata;
174 retval_metadata.set_op_name("XLA_Retvals");
175 builder->SetOpMetadata(retval_metadata);
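// MakeCleanup runs its lambda when `cleanup` goes out of scope, so the op
// metadata set above is cleared on every exit path from this function.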
176 auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
177
178 // Wraps an op in a no-op identity. We need to set the sharding of outputs, but
179 // cannot change the sharding of the existing output op. To do this, we build
180 // a new identity op to which shardings can be applied.
181 auto identity_op = [builder](xla::XlaOp op) {
182 return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
183 };
184
185 std::vector<xla::XlaOp> elems;
186 elems.reserve(retvals.size());
187
188 // Keeps track of the layout of each retval. If a retval is not in this list,
189 // a descending layout is used. The first element is the output index, second
190 // element is the new layout.
191 std::vector<std::pair<int64, xla::Layout>> retval_index_and_layout;
192 for (int i = 0; i < retvals.size(); ++i) {
193 XlaCompiler::OutputDescription& output = (*outputs)[i];
194 const XlaExpression& retval = retvals[i];
195 output.type = retval.dtype();
196 switch (retval.kind()) {
197 case XlaExpression::Kind::kConstant:
198 output.is_constant = true;
199 output.constant_value = retval.constant_value();
200 output.shape = output.constant_value.shape();
201 break;
202
203 case XlaExpression::Kind::kTensorList:
204 TF_FALLTHROUGH_INTENDED;
205 case XlaExpression::Kind::kXlaOp: {
206 output.is_constant = false;
207 TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
208 xla::XlaOp value = retval.handle();
209 auto it = retval_cores.find(i);
210 xla::XlaScopedShardingAssignment assign_sharding(
211 builder, it == retval_cores.end()
212 ? absl::optional<xla::OpSharding>()
213 : xla::sharding_builder::AssignDevice(it->second));
214 if (shape_representation_fn) {
215 // If there is a shape representation function, reshape the output
216 // tensor to the shape given by the representation shape function.
217 TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
218 output.shape, output.type));
219 value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
220 retval_index_and_layout.emplace_back(elems.size(), shape.layout());
221 } else if (it != retval_cores.end()) {
222 // Apply the sharding to the output, if there is a core assignment.
223 value = identity_op(value);
224 }
225
226 elems.push_back(value);
227 break;
228 }
229
230 case XlaExpression::Kind::kResource:
231 output.is_constant = false;
232 output.input_index = retval.resource()->arg_num();
233 output.shape = retval.resource()->shape();
234 break;
235
236 case XlaExpression::Kind::kInvalid:
237 return errors::InvalidArgument(
238 "Invalid expression returned by computation. "
239 "This probably means a return value was not set.");
240 }
241 }
242 *num_nonconst_outputs = elems.size();
243
244 // Add return values for resources whose values have changed.
245 std::vector<const XlaResource*> arg_resources;
246 arg_resources.reserve(resources.size());
247 for (const auto& resource : resources) {
248 if (resource->arg_num() >= 0) {
249 arg_resources.push_back(resource.get());
250 }
251 }
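// Sort by argument number so resource updates are emitted in a deterministic
// order, regardless of the order in which the resources were created.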
252 std::sort(arg_resources.begin(), arg_resources.end(),
253 [](const XlaResource* a, const XlaResource* b) {
254 return a->arg_num() < b->arg_num();
255 });
256
257 for (const XlaResource* resource : arg_resources) {
258 DCHECK_LT(resource->arg_num(), args.size());
259 const XlaCompiler::Argument& arg = args[resource->arg_num()];
260 auto it = arg_cores.find(resource->arg_num());
261 const int core = it == arg_cores.end() ? -1 : it->second;
262 bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
263 // TensorArray gradients were modified if their values changed or there are
264 // any newly created gradients.
265 for (const auto& grad : resource->tensor_array_gradients()) {
266 modified =
267 modified ||
268 !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
269 arg.tensor_array_gradients.count(grad.first) == 0;
270 }
271 if (return_updated_values_for_all_resources || modified) {
272 resource_updates->emplace_back();
273 XlaCompiler::ResourceUpdate& update = resource_updates->back();
274 update.input_index = resource->arg_num();
275 update.type = resource->type();
276 update.shape = resource->shape();
277 update.modified = modified;
278 for (const auto& grad : resource->tensor_array_gradients()) {
279 update.tensor_array_gradients_accessed.insert(grad.first);
280 }
281
282 // Request that the value be returned on a specific core.
283 xla::XlaScopedShardingAssignment assign_sharding(
284 builder, core == -1 ? absl::optional<xla::OpSharding>()
285 : xla::sharding_builder::AssignDevice(core));
286
287 xla::XlaOp handle;
288 TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
289
290 // Ensures the correct sharding is applied to the output.
291 handle = identity_op(handle);
292
293 // Set layout of the retval to device representation layout.
294 if (resource->representation_shape().has_value()) {
295 retval_index_and_layout.emplace_back(
296 elems.size(), resource->representation_shape()->layout());
297 }
298 elems.push_back(handle);
299 }
300 }
301
302 // If we have token output, append it as the last one.
303 if (token_output) {
304 elems.push_back(*token_output);
305 }
306
307 *num_computation_outputs = elems.size();
308
309 // Builds the XLA computation. We *always* form a tuple here to ensure that
310 // the output value is the last thing added into the XLA computation, even
311 // if there is only one output value.
312 auto tuple = xla::Tuple(builder, elems);
313 if (!always_return_tuple && elems.size() == 1) {
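// The result is intentionally discarded: XlaBuilder uses the most recently
// added op as the computation root when Build() is called below.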
314 xla::GetTupleElement(tuple, 0);
315 }
316
317 xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
318 if (!computation_status.ok()) {
319 return computation_status.status();
320 }
321 *computation = computation_status.ConsumeValueOrDie();
322
323 TF_ASSIGN_OR_RETURN(const auto& program_shape,
324 computation->GetProgramShape());
325 *output_shape = program_shape.result();
326 // Update the output layout to the layout of retval.
327 for (auto& index_and_layout : retval_index_and_layout) {
328 if (!always_return_tuple && elems.size() == 1) {
329 *output_shape->mutable_layout() = index_and_layout.second;
330 continue;
331 }
332
333 xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape(
334 output_shape, {index_and_layout.first});
335 *output_sub_shape->mutable_layout() = index_and_layout.second;
336 }
337 return Status::OK();
338 }
339
340 } // namespace
341
342 bool XlaCompiler::Argument::operator==(
343 const XlaCompiler::Argument& other) const {
344 if (std::tie(kind, resource_kind, type, name, initialized, max_array_size,
345 tensor_array_gradients) !=
346 std::tie(other.kind, other.resource_kind, other.type, other.name,
347 other.initialized, other.max_array_size,
348 other.tensor_array_gradients)) {
349 return false;
350 }
351 if (absl::holds_alternative<xla::Shape>(shape)) {
352 if (!absl::holds_alternative<xla::Shape>(other.shape)) {
353 return false;
354 }
355 if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape),
356 absl::get<xla::Shape>(other.shape))) {
357 return false;
358 }
359 } else {
360 if (!absl::holds_alternative<TensorShape>(other.shape)) {
361 return false;
362 }
363 if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) {
364 return false;
365 }
366 }
367 if (constant_value.shape() != other.constant_value.shape()) {
368 return false;
369 }
370 return constant_value.tensor_data() == other.constant_value.tensor_data();
371 }
372
373 string XlaCompiler::Argument::HumanString() const {
374 string common;
375 if (!name.empty()) {
376 common = absl::StrCat(" name=", name);
377 }
378 absl::StrAppend(&common, " type=", DataTypeString(type),
379 " shape=", ShapeHumanString());
380 switch (kind) {
381 case kInvalid:
382 return "invalid";
383 case kConstant:
384 return absl::StrCat("kind=constant", common,
385 " value=", constant_value.DebugString());
386 case kResource: {
387 string output = absl::StrCat("kind=resource", common, " resource_kind=",
388 XlaResource::KindToString(resource_kind),
389 " initialized=", initialized);
390 if (max_array_size >= 0) {
391 absl::StrAppend(&output, " max_array_size=", max_array_size);
392 }
393 if (!tensor_array_gradients.empty()) {
394 absl::StrAppend(&output, " tensor_array_gradients=",
395 absl::StrJoin(tensor_array_gradients, ","));
396 }
397 return output;
398 }
399 case kParameter:
400 return absl::StrCat("kind=parameter", common);
401 case kToken:
402 return absl::StrCat("token", common);
403 }
404 }
405
406 std::vector<int64> XlaCompiler::Argument::DimensionSizes() const {
407 if (absl::holds_alternative<TensorShape>(shape)) {
408 return xla::InlinedVectorToVector(
409 absl::get<TensorShape>(shape).dim_sizes());
410 } else {
411 return absl::get<xla::Shape>(shape).dimensions();
412 }
413 }
414
415 string XlaCompiler::Argument::ShapeHumanString() const {
416 if (absl::holds_alternative<TensorShape>(shape)) {
417 return absl::get<TensorShape>(shape).DebugString();
418 } else {
419 return absl::get<xla::Shape>(shape).DebugString();
420 }
421 }
422
423 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
424 : options_(options),
425 initialization_status_(Status::OK()),
426 next_step_id_(1),
427 device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
428 device_mgr_(absl::WrapUnique(device_)) {
429 CHECK(!options_.device_type.type_string().empty());
430 if (options_.populate_resource_manager) {
431 initialization_status_ =
432 (*options_.populate_resource_manager)(device_->resource_manager());
433 }
434
435 local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
436 FunctionDefLibrary{}));
437 local_pflr_.reset(new ProcessFunctionLibraryRuntime(
438 &device_mgr_, Env::Default(), options.graph_def_version,
439 local_flib_def_.get(), OptimizerOptions(),
440 nullptr /* custom_kernel_creator */));
441 pflr_.reset(new ProcessFunctionLibraryRuntime(
442 &device_mgr_, Env::Default(), options.graph_def_version, options.flib_def,
443 OptimizerOptions(), nullptr /* custom_kernel_creator */));
444
445 local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
446 flib_runtime_ = pflr_->GetFLR(device_->name());
447
448 // The default shape representation function is the identity.
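// That is, it just converts the TensorShape/DataType pair into the equivalent
// xla::Shape. Callers may supply their own function, e.g. one that imposes a
// device-preferred layout on parameters and return values.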
449 if (!options_.shape_representation_fn) {
450 options_.shape_representation_fn =
451 [](const TensorShape& shape,
452 DataType dtype) -> xla::StatusOr<xla::Shape> {
453 xla::Shape xla_shape;
454 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
455 return xla_shape;
456 };
457 }
458 }
459
460 XlaCompiler::~XlaCompiler() = default;
461
462 int64 XlaCompiler::NextStepId() { return next_step_id_++; }
463
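// Note: the hash covers only the function name. Signatures that share a name
// but differ in arguments land in the same bucket and are distinguished by the
// cache's equality comparison on the full (name, arguments) pair.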
464 uint64 XlaCompiler::SignatureHash::operator()(
465 const std::pair<string, std::vector<Argument>>& signature) const {
466 return std::hash<string>()(signature.first);
467 }
468
469 static Status GetFunctionBody(const NameAttrList& function,
470 FunctionLibraryRuntime* flib_runtime,
471 const FunctionBody** fbody) {
472 FunctionLibraryRuntime::Handle handle;
473 TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
474 function.name(), AttrSlice(&function.attr()), &handle));
475
476 *fbody = flib_runtime->GetFunctionBody(handle);
477 TF_RET_CHECK(*fbody);
478 return Status::OK();
479 }
480
481 Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
482 const FunctionBody** fbody) {
483 // The function may be in either the local_flib_runtime_ or flib_runtime_.
484 // Look it up in local_flib_runtime_ first; if it is not found there, fall
485 // back to flib_runtime_.
486 auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
487 if (!status.ok()) {
488 if (!errors::IsNotFound(status)) {
489 return status;
490 }
491 TF_RETURN_WITH_CONTEXT_IF_ERROR(
492 GetFunctionBody(function, flib_runtime_, fbody),
493 "Local lookup failed with: ", status.error_message());
494 VLOG(4) << "Function " << function.name() << " in flib_runtime_";
495 } else {
496 VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
497 }
498 return Status::OK();
499 }
500
501 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
502 std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
503 CopyGraph(*fbody->graph, graph.get());
504 OptimizerOptions opts;
505 opts.set_opt_level(OptimizerOptions::L0);
506 opts.set_do_common_subexpression_elimination(false);
507 opts.set_do_function_inlining(true);
508 opts.set_do_constant_folding(true);
509 GraphOptimizer optimizer(opts);
510 // Do not constant fold nodes that output DT_VARIANT type tensors.
511 // XLA does not support Const nodes of Variant type since it needs
512 // to know the original ops to be able to compile them to the relevant
513 // XLA form.
514 // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
515 // the form:
516 // Const
517 // |
518 // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
519 // |
520 // (Discard popped list)
521 //
522 // Would have been reduced to "Const -> Op" without this filter.
523 // However since we are only allowed to specify the filter at the "Node"
524 // level there is no good way to allow the above behavior. So we
525 // disallow any sort of constant folding on Variant nodes for now.
526 auto cf_consider_fn = [](const Node* n) {
527 for (const auto& output_arg : n->op_def().output_arg()) {
528 if (output_arg.type() == DT_VARIANT) {
529 return false;
530 }
531 }
532 return true;
533 };
534 GraphOptimizer::Options graph_optimizer_options;
535 graph_optimizer_options.cf_consider_fn = cf_consider_fn;
536 optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
537 /*device=*/nullptr, &graph, graph_optimizer_options);
538
539 return graph;
540 }
541
542 Status XlaCompiler::CompileFunction(
543 const XlaCompiler::CompileOptions& options, const NameAttrList& function,
544 absl::Span<const XlaCompiler::Argument> args,
545 XlaCompiler::CompilationResult* result) {
546 const string function_id =
547 Canonicalize(function.name(), AttrSlice(&function.attr()));
548 VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
549
550 const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
551 auto it = cache_.find({function_id, arg_vector});
552 if (it != cache_.end()) {
553 *result = it->second;
554 return Status::OK();
555 }
556
557 const FunctionBody* fbody;
558 TF_RETURN_IF_ERROR(FindFunctionBody(function, &fbody));
559
560 TF_RETURN_WITH_CONTEXT_IF_ERROR(
561 CheckSignature(fbody->arg_types, args),
562 "Signature check failure while compiling: ", function.name());
563
564 std::unique_ptr<Graph> graph = GetGraph(fbody);
565
566 // Clear the "_kernel" attribute if it is set to "host". This is used to
567 // indicate that a computation should happen on the host instead of the
568 // accelerator, but doesn't make sense in XLA.
569 const char* const kKernelAttr = "_kernel";
570 for (Node* n : graph->nodes()) {
571 string value;
572 if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") {
573 n->ClearAttr(kKernelAttr);
574 }
575 }
576
577 // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
578 // they are added when the function body is instantiated. Therefore, they don't have
579 // core assignments here.
580 // Attempt to assign a core to each _Retval and _Arg. Chooses the
581 // lowest-numbered core that consumes the argument. We choose the
582 // lowest-numbered core so the assignment is deterministic.
583 for (Node* n : graph->nodes()) {
584 if (n->IsArg()) {
585 TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
586 }
587 }
588 // Do _Retval as a second loop, in case the retval's input is an _Arg (which
589 // may have gotten a device assignment from the first loop).
590 for (Node* n : graph->nodes()) {
591 if (n->IsRetval()) {
592 TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
593 }
594 }
595
596 if (VLOG_IS_ON(2)) {
597 VLOG(2) << "XlaCompiler::CompileFunction: "
598 << DumpGraphToFile(
599 absl::StrCat("xla_compile_function_", function_id), *graph);
600 }
601
602 VLOG(1) << "====================================================";
603 TF_RETURN_IF_ERROR(
604 CompileGraph(options, function_id, std::move(graph), args, {}, result));
605 VLOG(1) << "====================================================";
606
607 cache_[{function_id, arg_vector}] = *result;
608 return Status::OK();
609 }
610
611 // Computes the XLA shape for argument 'arg'.
612 Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
613 bool is_entry_computation,
614 xla::Shape* xla_shape) const {
615 switch (arg.kind) {
616 case XlaCompiler::Argument::kConstant:
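// Constant arguments are materialized directly into the computation by
// BuildArguments() and never become XLA parameters, so no parameter shape
// should ever be requested for them.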
617 LOG(FATAL) << "Unreachable case";
618 case XlaCompiler::Argument::kParameter: {
619 if (is_entry_computation) {
620 TensorShape shape;
621 if (absl::holds_alternative<TensorShape>(arg.shape)) {
622 shape = absl::get<TensorShape>(arg.shape);
623 } else {
624 TF_RETURN_IF_ERROR(
625 XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &shape));
626 }
627 TF_ASSIGN_OR_RETURN(*xla_shape,
628 options_.shape_representation_fn(shape, arg.type));
629 } else {
630 if (absl::holds_alternative<xla::Shape>(arg.shape)) {
631 *xla_shape = absl::get<xla::Shape>(arg.shape);
632 } else {
633 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
634 arg.type, absl::get<TensorShape>(arg.shape), xla_shape));
635 }
636 }
637 return Status::OK();
638 }
639 case XlaCompiler::Argument::kResource: {
640 TF_RET_CHECK(arg.initialized);
641
642 switch (arg.resource_kind) {
643 case XlaResource::kVariable: {
644 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
645 TF_ASSIGN_OR_RETURN(*xla_shape,
646 options_.shape_representation_fn(
647 absl::get<TensorShape>(arg.shape), arg.type));
648
649 return Status::OK();
650 }
651 case XlaResource::kTensorArray: {
652 if (arg.max_array_size < 0) {
653 return errors::InvalidArgument(
654 "Negative max_array_size in XLAShapeForArgument");
655 }
656 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
657 TensorShape shape;
658 shape.AddDim(arg.max_array_size);
659 shape.AppendShape(absl::get<TensorShape>(arg.shape));
660 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
661
662 if (!arg.tensor_array_gradients.empty()) {
663 std::vector<xla::Shape> tuple_shape(
664 arg.tensor_array_gradients.size() + 1, *xla_shape);
665 *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
666 }
667 return Status::OK();
668 }
669 case XlaResource::kStack: {
670 if (arg.max_array_size < 0) {
671 return errors::InvalidArgument(
672 "Negative max_array_size in XLAShapeForArgument");
673 }
674 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
675 TensorShape shape;
676 shape.AddDim(arg.max_array_size);
677 shape.AppendShape(absl::get<TensorShape>(arg.shape));
678 xla::Shape buffer_shape;
679 TF_RETURN_IF_ERROR(
680 TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
681 *xla_shape = xla::ShapeUtil::MakeTupleShape(
682 {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
683 return Status::OK();
684 }
685
686 case XlaResource::kInvalid:
687 return errors::Internal(
688 "Invalid resource type in XLAShapeForArgument()");
689 }
690 }
691 case XlaCompiler::Argument::kToken: {
692 *xla_shape = xla::ShapeUtil::MakeTokenShape();
693 return Status::OK();
694 }
695 case XlaCompiler::Argument::kInvalid:
696 return errors::Internal("Invalid argument type in XLAShapeForArgument()");
697 }
698 }
699
700 // Builds XLA computations for each of the arguments to the computation.
701 // `args` are the arguments to the computation.
702 Status XlaCompiler::BuildArguments(
703 const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
704 bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
705 const std::map<int, int>& arg_cores,
706 std::vector<XlaExpression>* arg_expressions,
707 std::vector<int>* input_to_args, std::vector<xla::Shape>* input_shapes,
708 bool is_entry_computation) {
709 arg_expressions->resize(args.size());
710
711 // Argument numbers of arguments and resources that are to be passed to the
712 // XLA computation as runtime parameters. `input_to_args[a] = b` means that
713 // the a'th XLA input corresponds to the b'th original arg index.
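// For example, with args = {constant, parameter, initialized resource},
// input_to_args ends up as {1, 2}: the constant produces no runtime input.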
714 input_to_args->clear();
715 input_to_args->reserve(args.size());
716
717 // Fills in constant arguments, and computes non-constant argument order.
718 for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
719 ++i) {
720 const XlaCompiler::Argument& arg = args[i];
721 XlaExpression& arg_expression = (*arg_expressions)[i];
722 switch (arg.kind) {
723 case XlaCompiler::Argument::kResource: {
724 TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
725 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
726 // TODO(phawkins): this code assumes that resource arguments do not
727 // alias.
728 XlaResource* resource =
729 context->AddResource(absl::make_unique<XlaResource>(
730 arg.resource_kind, i, arg.name, arg.type,
731 absl::get<TensorShape>(arg.shape), xla::XlaOp(),
732 /*max_array_size=*/arg.max_array_size,
733 /*tensor_array_gradients=*/arg.tensor_array_gradients,
734 /*tensor_array_multiple_writes_aggregate=*/true));
735 arg_expression = XlaExpression::Resource(resource);
736 if (arg.initialized) {
737 input_to_args->push_back(i);
738 }
739 break;
740 }
741 case XlaCompiler::Argument::kParameter:
742 case XlaCompiler::Argument::kToken: {
743 input_to_args->push_back(i);
744 break;
745 }
746 case XlaCompiler::Argument::kConstant:
747 arg_expression = XlaExpression::Constant(arg.constant_value);
748 break;
749 case XlaCompiler::Argument::kInvalid:
750 return errors::Internal(
751 "Unreachable case in BuildArguments() while filling constant args");
752 }
753 }
754
755 if (input_to_args->empty()) {
756 return Status::OK();
757 }
758
759 // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds
760 // to the d'th XLA input. Note that the value -1 corresponds to constants, or
761 // other args that don't correspond to an input.
762 std::vector<int> arg_to_inputs(args.size(), -1);
763 for (int i = 0; i < input_to_args->size(); i++) {
764 arg_to_inputs[input_to_args->at(i)] = i;
765 }
766
767 std::vector<xla::Shape> arg_shapes(input_to_args->size());
768 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
769 // Computes the shapes of non-constant arguments.
770 TF_RETURN_IF_ERROR(XLAShapeForArgument(
771 args[(*input_to_args)[i]], is_entry_computation, &arg_shapes[i]));
772 }
773
774 if (use_tuple_arg) {
775 input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
776 } else {
777 *input_shapes = arg_shapes;
778 }
779
780 // Attach a common operator name as metadata. This has no semantic effect; it
781 // merely makes the HLO graph more readable when visualized via TensorBoard,
782 // since TensorBoard forms groups out of operators with similar names.
783 xla::OpMetadata arg_metadata;
784 arg_metadata.set_op_name("XLA_Args");
785 builder->SetOpMetadata(arg_metadata);
786
787 // Build parameter handles for non-constant arguments.
788 std::vector<xla::XlaOp> arg_handles(input_to_args->size());
789 if (use_tuple_arg) {
790 xla::XlaOp tuple;
791 if (is_entry_computation) {
792 xla::OpSharding tuple_sharding;
793 tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
794 for (int64 parameter : *input_to_args) {
795 auto it = arg_cores.find(parameter);
796 const int core = it == arg_cores.end() ? 0 : it->second;
797 *tuple_sharding.add_tuple_shardings() =
798 xla::sharding_builder::AssignDevice(core);
799 }
800 xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
801 tuple_sharding);
802 tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
803 } else {
804 tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
805 }
806
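// Record dynamic dimension bindings: each entry in `dynamic_dim_to_arg_num_map`
// says that dimension `first` of this argument takes its runtime size from the
// argument numbered `second`; both live inside the single tuple parameter here.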
807 for (int i = 0; i < input_to_args->size(); ++i) {
808 const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
809 for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
810 int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
811 TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
812 /*dynamic_size_param_num=*/0, {dynamic_size_param_index},
813 /*target_param_num=*/0, /*target_param_index=*/{i},
814 dim_and_arg_num.first));
815 }
816 }
817
818 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
819 auto it = arg_cores.find(i);
820 const int core = it == arg_cores.end() ? -1 : it->second;
821 xla::XlaScopedShardingAssignment assign_sharding(
822 builder, core == -1 ? absl::optional<xla::OpSharding>()
823 : xla::sharding_builder::AssignDevice(core));
824 arg_handles[i] = xla::GetTupleElement(tuple, i);
825 }
826 } else {
827 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
828 auto it = arg_cores.find(i);
829 const int core = it == arg_cores.end() ? -1 : it->second;
830 xla::XlaScopedShardingAssignment assign_sharding(
831 builder, core == -1 ? absl::optional<xla::OpSharding>()
832 : xla::sharding_builder::AssignDevice(core));
833 arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
834 absl::StrCat("arg", i));
835 }
836
837 for (int i = 0; i < input_to_args->size(); ++i) {
838 const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
839 for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
840 int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
841 TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
842 /*dynamic_size_param_num=*/dynamic_size_param_index, {},
843 /*target_param_num=*/i, /*target_param_index=*/{},
844 dim_and_arg_num.first));
845 }
846 }
847 }
848
849 builder->ClearOpMetadata();
850
851 // Fill in the handles in non-constant arguments, and reshape parameters
852 // back to their correct shapes.
853 VLOG(2) << "XLA computation inputs:";
854 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
855 const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
856 VLOG(2) << " XLA arg " << i
857 << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
858 << " name: " << arg.name << " TF arg " << input_to_args->at(i);
859 XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
860 switch (arg.kind) {
861 case XlaCompiler::Argument::kResource: {
862 TF_RET_CHECK(arg.initialized);
863 XlaResource* resource = arg_expression.resource();
864 TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
865 arg_handles[i], builder));
866 VLOG(2) << " resource: num_gradients: "
867 << arg.tensor_array_gradients.size();
868 break;
869 }
870 case XlaCompiler::Argument::kParameter:
871 // Reshape parameters back to their correct shapes.
872 // TODO(b/76097077): propagate device assignments onto arguments and
873 // return values of functions, and then reshape unconditionally.
874 if (is_entry_computation) {
875 arg_expression = XlaExpression::XlaOp(
876 xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
877 } else {
878 arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
879 }
880 break;
881 case XlaCompiler::Argument::kToken: {
882 arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
883 break;
884 }
885 case XlaCompiler::Argument::kConstant:
886 case XlaCompiler::Argument::kInvalid:
887 return errors::Internal(
888 "Unreachable case in BuildArguments() while filling handles");
889 }
890 }
891
892 return Status::OK();
893 }
894
895 Status XlaCompiler::CompileSingleOp(
896 const XlaCompiler::CompileOptions& options, const NodeDef& node_def,
897 absl::Span<const XlaCompiler::Argument> args,
898 absl::Span<const DataType> result_types, CompilationResult* result) {
899 // TODO(b/74182462): We implement this by creating a new dummy Graph including
900 // _Arg nodes, and let CompileGraph walk it. This could be optimized.
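// The dummy graph built below is simply
//   _argI --(data)--> main_node --(data)--> _retvalJ
// with a control edge from _SOURCE to each _Arg; FixupSourceAndSinkEdges()
// then wires any remaining stray nodes to _SOURCE/_SINK.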
901 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
902
903 Status status;
904 // First create the actual node we care about computing.
905 Node* main_node = graph->AddNode(node_def, &status);
906 TF_RETURN_IF_ERROR(status);
907
908 // Create dummy _Arg nodes. Link these to `main_node` and also via a control
909 // dependency edge to the _SOURCE node.
910 for (int64 i = 0; i < args.size(); ++i) {
911 Node* node;
912 string arg_name = absl::StrCat("_arg", i);
913 Status status =
914 NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
915 .ControlInput(graph->source_node())
916 .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
917 : args[i].type)
918 .Attr("index", i)
919 .Finalize(graph.get(), &node);
920 TF_RETURN_IF_ERROR(status);
921 graph->AddEdge(node, 0, main_node, i);
922 }
923
924 // Similarly with return values, create dummy _Retval nodes fed by `node`.
925 for (int64 i = 0; i < result_types.size(); ++i) {
926 Node* node;
927 string retval_name = absl::StrCat("_retval", i);
928 Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
929 .Input(main_node, i)
930 .Attr("T", result_types[i])
931 .Attr("index", i)
932 .Finalize(graph.get(), &node);
933 TF_RETURN_IF_ERROR(status);
934 }
935 FixupSourceAndSinkEdges(graph.get());
936
937 return CompileGraph(options, node_def.name(), std::move(graph), args, {},
938 result);
939 }
940
941 namespace {
942
943 // Check that the ops of all non-functional nodes have been registered.
944 Status ValidateFunctionDef(const FunctionDef* fdef,
945 const FunctionLibraryDefinition& flib_def) {
946 for (const NodeDef& node : fdef->node_def()) {
947 const string& op = node.op();
948 if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
949 continue;
950 }
951 const OpDef* op_def;
952 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
953 }
954 return Status::OK();
955 }
956
957 // If node is PartitionedCall or StatefulPartitionedCall, returns the
958 // name from the "f" attr, else returns node.def().op().
959 // Returned pointer points to the internal string either in node's attributes
960 // or in its NodeDef. This pointer is valid as long as the node has not been
961 // modified.
962 Status GetPotentialFunctionName(const Node& node, const string** name) {
963 if (node.IsPartitionedCall()) {
964 const AttrValue* attr_value;
965 TF_RETURN_IF_ERROR(
966 node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
967 if (!attr_value->has_func()) {
968 return errors::InvalidArgument(
969 "The attribute value for attribute 'f' in node ", node.DebugString(),
970 " does not have 'func' field set");
971 }
972 *name = &attr_value->func().name();
973 return Status::OK();
974 }
975 *name = &node.type_string();
976 return Status::OK();
977 }
978
979 // Check that the graph doesn't have any invalid nodes (e.g. incompatible with
980 // given device_type, invalid data type, missing attributes...)
981 Status ValidateGraph(const Graph* graph,
982 const FunctionLibraryDefinition& flib_def,
983 const DeviceType& device_type, const string& name) {
984 auto maybe_error = [&](const Node* node, const Status& s) -> Status {
985 if (!s.ok()) {
986 return errors::InvalidArgument(absl::StrCat(
987 "Detected unsupported operations when trying to compile graph ", name,
988 " on ", device_type.type_string(), ": ", node->def().op(), " (",
989 s.error_message(), ")", FormatNodeForError(*node)));
990 }
991 return Status::OK();
992 };
993
994 for (const Node* node : graph->nodes()) {
995 if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
996 continue;
997 }
998 const string* function_name;
999 TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
1000 const FunctionDef* fdef = flib_def.Find(*function_name);
1001 Status s;
1002 if (fdef) {
1003 s = ValidateFunctionDef(fdef, flib_def);
1004 TF_RETURN_IF_ERROR(maybe_error(node, s));
1005 continue;
1006 }
1007 const OpDef* op_def;
1008 s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
1009 TF_RETURN_IF_ERROR(maybe_error(node, s));
1010 TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
1011 s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
1012 TF_RETURN_IF_ERROR(maybe_error(node, s));
1013 }
1014 return Status::OK();
1015 }
1016
1017 // Converts the value of any expressions whose values are known at compile-time
1018 // to constants.
1019 Status ResolveConstantExpressionsToConstants(
1020 xla::Client* client, absl::Span<XlaExpression> expressions) {
1021 for (XlaExpression& expression : expressions) {
1022 if (expression.kind() == XlaExpression::Kind::kXlaOp) {
1023 TF_ASSIGN_OR_RETURN(absl::optional<Tensor> constant,
1024 expression.ResolveConstant(client));
1025 if (constant.has_value()) {
1026 expression = XlaExpression::Constant(*constant);
1027 }
1028 }
1029 }
1030 return Status::OK();
1031 }
1032
1033 void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
1034 absl::Span<XlaExpression> expressions) {
1035 for (XlaExpression& expression : expressions) {
1036 if (expression.kind() == XlaExpression::Kind::kConstant) {
1037 expression =
1038 XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
1039 }
1040 }
1041 }
1042
1043 } // namespace
1044
1045 Status XlaCompiler::CompileGraph(
1046 const XlaCompiler::CompileOptions& options, string const& name,
1047 std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
1048 absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
1049 CompilationResult* result) {
1050 VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";
1051
1052 TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
1053 graph.get(), options_.flib_def, local_flib_def_.get()));
1054 if (VLOG_IS_ON(2)) {
1055 VLOG(2) << "XlaCompiler::CompileGraph: "
1056 << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
1057 flib_runtime_->GetFunctionLibraryDefinition());
1058 }
1059
1060 // Report the error here if initialization failed.
1061 TF_RETURN_IF_ERROR(initialization_status_);
1062
1063 // Detect invalid nodes.
1064 // FunctionalizeControlFlow may remove some nodes from the graph.
1065 TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
1066 options_.device_type, name));
1067
1068 xla::XlaBuilder builder(name);
1069 XlaContext* context = new XlaContext(this, &builder);
1070 core::ScopedUnref context_unref(context);
1071
1072 std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
1073 int token_input_index = -1;
1074 std::unique_ptr<xla::XlaOp> token_output;
1075 if (options.add_token_input_output) {
1076 // Add extra token input.
1077 token_input_index = real_args.size();
1078
1079 XlaCompiler::Argument token_arg;
1080 token_arg.kind = XlaCompiler::Argument::kToken;
1081 real_args.push_back(token_arg);
1082 }
1083
1084 std::map<int, int> arg_cores;
1085 std::map<int, int> retval_cores;
1086 TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores),
1087 ComputeArgAndRetvalCores(*graph));
1088
1089 std::vector<XlaExpression> arg_expressions;
1090 TF_RETURN_IF_ERROR(BuildArguments(
1091 *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores,
1092 &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
1093 options.is_entry_computation));
1094 context->set_args(std::move(arg_expressions));
1095
1096 // Propagate any aliases given to us by the user.
1097 for (const xla::XlaBuilder::InputOutputAlias& alias : user_aliases) {
1098 builder.SetUpAlias(alias.output_index, alias.param_number,
1099 alias.param_index);
1100 }
1101
1102 PushNodeTokenMapping();
1103 // Use std::set instead of std::unordered_set to ensure determinism.
1104 std::set<std::string> output_node_token_inputs;
1105 if (token_input_index != -1) {
1106 // Original token comes from input.
1107 auto arg_expression = context->args()[token_input_index];
1108 TF_RETURN_IF_ERROR(
1109 SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
1110
1111 // Calculate token inputs for output token.
1112 output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
1113
1114 // If there's no side-effecting op in the graph, use token input as token
1115 // output.
1116 if (output_node_token_inputs.empty()) {
1117 output_node_token_inputs.insert(kXlaTokenArgNodeName);
1118 }
1119 } else if (options.is_entry_computation) {
1120 // Original token is manually created.
1121 if (HasSideEffectingNodes(*graph)) {
1122 TF_RETURN_IF_ERROR(
1123 SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
1124 }
1125 }
1126
1127 TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
1128 flib_runtime_, NextStepId()));
1129 if (token_input_index != -1) {
1130 // Add extra token output.
1131 std::vector<xla::XlaOp> token_inputs;
1132 for (const auto& node_name : output_node_token_inputs) {
1133 auto token_or = GetNodeToken(node_name);
1134 TF_RETURN_IF_ERROR(token_or.status());
1135 token_inputs.push_back(token_or.ValueOrDie());
1136 }
1137 token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs)));
1138 }
1139 TF_RETURN_IF_ERROR(PopNodeTokenMapping());
1140
1141 int num_nonconst_outputs;
1142 int num_computation_outputs;
1143 result->computation = std::make_shared<xla::XlaComputation>();
1144 result->outputs.resize(context->retvals().size());
1145 std::vector<XlaExpression> retvals = context->retvals();
1146 if (options.resolve_compile_time_constants) {
1147 Status status = ResolveConstantExpressionsToConstants(
1148 client(), absl::Span<XlaExpression>(retvals));
1149
1150 // If the HloEvaluator has not implemented an expression, just evaluate it
1151 // at runtime.
1152 if (status.code() == error::UNIMPLEMENTED) {
1153 ConvertConstantsToExpressions(&builder,
1154 absl::Span<XlaExpression>(retvals));
1155 } else {
1156 TF_RETURN_IF_ERROR(status);
1157 }
1158 } else {
1159 ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
1160 }
1161 TF_RETURN_IF_ERROR(BuildComputation(
1162 real_args, retvals, arg_cores, retval_cores, context->resources(),
1163 std::move(token_output),
1164 options.is_entry_computation ? options_.shape_representation_fn
1165 : ShapeRepresentationFn{},
1166 options.return_updated_values_for_all_resources,
1167 options.always_return_tuple, &builder, result->computation.get(),
1168 &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
1169 &result->resource_updates, &result->xla_output_shape));
1170
1171 VLOG(2) << "Outputs: total: " << context->retvals().size()
1172 << " nonconstant: " << num_nonconst_outputs;
1173 VLOG(2) << "XLA output shape: "
1174 << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape);
1175 return Status::OK();
1176 }
1177
1178 Status XlaCompiler::GetChannelHandle(const string& key,
1179 xla::ChannelHandle* channel) {
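// emplace() inserts only if `key` is new, so a channel handle is created on
// the first request for a key and shared by every subsequent request for it.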
1180 auto result = channels_.emplace(key, xla::ChannelHandle());
1181 if (result.second) {
1182 TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
1183 }
1184 *channel = result.first->second;
1185 VLOG(1) << "Channel: " << key << " " << channel->DebugString();
1186 return Status::OK();
1187 }
1188
1189 Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
1190 xla::ChannelHandle* channel) {
1191 auto result = channels_.emplace(key, xla::ChannelHandle());
1192 if (result.second) {
1193 TF_ASSIGN_OR_RETURN(result.first->second,
1194 client()->CreateHostToDeviceChannelHandle());
1195 }
1196 *channel = result.first->second;
1197 VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
1198 return Status::OK();
1199 }
1200
1201 Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
1202 xla::ChannelHandle* channel) {
1203 auto result = channels_.emplace(key, xla::ChannelHandle());
1204 if (result.second) {
1205 TF_ASSIGN_OR_RETURN(result.first->second,
1206 client()->CreateDeviceToHostChannelHandle());
1207 }
1208 *channel = result.first->second;
1209 VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
1210 return Status::OK();
1211 }
1212
1213 namespace {
1214
1215 void SetTransfer(const string& key, absl::Span<const DataType> types,
1216 absl::Span<const TensorShape> shapes,
1217 tf2xla::HostTransferMetadata* transfer) {
1218 transfer->set_key(key);
1219 CHECK(types.size() == shapes.size());
1220 for (int i = 0; i < types.size(); ++i) {
1221 tf2xla::TensorMetadata* metadata = transfer->add_metadata();
1222 metadata->set_type(types[i]);
1223 shapes[i].AsProto(metadata->mutable_shape());
1224 }
1225 }
1226
1227 } // namespace
1228
1229 Status XlaCompiler::SetDeviceToHostMetadata(
1230 const string& key, absl::Span<const DataType> types,
1231 absl::Span<const TensorShape> shapes) {
1232 if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
1233 return errors::InvalidArgument(
1234 "Duplicate calls to SetDeviceToHostMetadata with key ", key);
1235 }
1236 tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
1237 SetTransfer(key, types, shapes, &transfer);
1238 return Status::OK();
1239 }
1240
1241 Status XlaCompiler::GetDeviceToHostShapes(
1242 const string& key, std::vector<TensorShape>* shapes) const {
1243 const auto iter = host_compute_sends_.find(key);
1244 if (iter == host_compute_sends_.end()) {
1245 return errors::InvalidArgument(
1246 "No host compute send shapes registered for key ", key);
1247 }
1248 shapes->clear();
1249 for (int i = 0; i < iter->second.metadata_size(); ++i) {
1250 TensorShape shape(iter->second.metadata(i).shape());
1251 shapes->push_back(shape);
1252 }
1253 return Status::OK();
1254 }
1255
1256 Status XlaCompiler::SetHostToDeviceMetadata(
1257 const string& key, absl::Span<const DataType> types,
1258 absl::Span<const TensorShape> shapes) {
1259 if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
1260 return errors::InvalidArgument(
1261 "Duplicate calls to SetHostToDeviceMetadata with key ", key);
1262 }
1263 tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
1264 SetTransfer(key, types, shapes, &transfer);
1265 return Status::OK();
1266 }
1267
1268 Status XlaCompiler::GetHostComputeControlDependency(
1269 const string& host_compute_name, xla::XlaOp* handle) {
1270 const auto iter = host_compute_control_output_.find(host_compute_name);
1271 if (iter == host_compute_control_output_.end()) {
1272 return errors::InvalidArgument(
1273 "No registered control handle for host compute Op '", host_compute_name,
1274 "'");
1275 } else {
1276 *handle = iter->second;
1277 }
1278 return Status::OK();
1279 }
1280
1281 Status XlaCompiler::SetHostComputeControlDependency(
1282 const string& host_compute_name, const xla::XlaOp& handle) {
1283 if (host_compute_control_output_.find(host_compute_name) !=
1284 host_compute_control_output_.end()) {
1285 return errors::InvalidArgument(
1286 "Duplicate control handles registered for for host compute Op ",
1287 host_compute_name);
1288 }
1289 host_compute_control_output_[host_compute_name] = handle;
1290 return Status::OK();
1291 }
1292
1293 void XlaCompiler::PushNodeTokenMapping() {
1294 node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
1295 }
1296
1297 Status XlaCompiler::PopNodeTokenMapping() {
1298 if (node_token_mapping_stack_.empty()) {
1299 return errors::FailedPrecondition(
1300 "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
1301 "empty.");
1302 }
1303 node_token_mapping_stack_.pop();
1304 return Status::OK();
1305 }
1306
1307 Status XlaCompiler::SetNodeToken(const string& node_name,
1308 const xla::XlaOp& op) {
1309 if (node_token_mapping_stack_.empty()) {
1310 return errors::FailedPrecondition(
1311 "Calling SetNodeToken() when node_token_mapping_stack_ is "
1312 "empty.");
1313 }
1314 auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
1315 if (!insert_result.second) {
1316 return errors::FailedPrecondition("Token mapping already exists for node ",
1317 node_name);
1318 }
1319 return Status::OK();
1320 }
1321
1322 xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
1323 if (node_token_mapping_stack_.empty()) {
1324 return errors::FailedPrecondition(
1325 "Calling GetNodeToken() when node_token_mapping_stack_ is "
1326 "empty.");
1327 }
1328 auto iter = node_token_mapping_stack_.top().find(node_name);
1329 if (iter == node_token_mapping_stack_.top().end()) {
1330 return errors::FailedPrecondition("Cannot find token mapping for node ",
1331 node_name);
1332 }
1333 return iter->second;
1334 }
1335
1336 } // namespace tensorflow
1337