1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17
18 #include <numeric>
19 #include <vector>
20
21 #include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/types/variant.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/jit/flags.h"
27 #include "tensorflow/compiler/jit/shape_inference.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
29 #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
30 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
31 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
32 #include "tensorflow/compiler/tf2xla/shape_util.h"
33 #include "tensorflow/compiler/tf2xla/sharding_util.h"
34 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
35 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
36 #include "tensorflow/compiler/tf2xla/type_util.h"
37 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
38 #include "tensorflow/compiler/tf2xla/xla_context.h"
39 #include "tensorflow/compiler/xla/client/client_library.h"
40 #include "tensorflow/compiler/xla/client/xla_builder.h"
41 #include "tensorflow/compiler/xla/client/xla_computation.h"
42 #include "tensorflow/compiler/xla/protobuf_util.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/core/common_runtime/device.h"
46 #include "tensorflow/core/common_runtime/executor.h"
47 #include "tensorflow/core/common_runtime/function.h"
48 #include "tensorflow/core/common_runtime/graph_constructor.h"
49 #include "tensorflow/core/common_runtime/graph_optimizer.h"
50 #include "tensorflow/core/framework/attr_value_util.h"
51 #include "tensorflow/core/framework/function.h"
52 #include "tensorflow/core/framework/node_def_util.h"
53 #include "tensorflow/core/framework/types.h"
54 #include "tensorflow/core/graph/node_builder.h"
55 #include "tensorflow/core/lib/core/errors.h"
56 #include "tensorflow/core/lib/gtl/cleanup.h"
57 #include "tensorflow/core/lib/hash/hash.h"
58 #include "tensorflow/core/platform/logging.h"
59 #include "tensorflow/core/protobuf/error_codes.pb.h"
60 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
61 #include "tensorflow/core/util/dump_graph.h"
62
63 namespace tensorflow {
64 namespace {
65
66 // Checks that arguments `args` match types `types`.
67 Status CheckSignature(const DataTypeVector& types,
68 absl::Span<const XlaCompiler::Argument> args) {
69 if (args.size() != types.size()) {
70 return errors::Internal("Compilation arguments have ", args.size(),
71 " elements while function has ", types.size());
72 }
73 for (int i = 0, end = types.size(); i < end; ++i) {
74 // Don't perform type checks on resource variables and tensor
75 // lists (DT_VARIANT) as we have to trick the type system in order to
76 // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor.
77 if (types[i] != args[i].type && types[i] != DT_RESOURCE &&
78 types[i] != DT_VARIANT) {
79 return errors::Internal(
80 "Argument ", i, " has declared type ", DataTypeString(args[i].type),
81 " but function parameter has type ", DataTypeString(types[i]));
82 }
83 }
84 return Status::OK();
85 }
86
87 // Uses the _Arg and _Retval nodes in the graph to determine an OpSharding for
88 // each argument and return value.
89 StatusOr<
90 std::pair<std::map<int, xla::OpSharding>, std::map<int, xla::OpSharding>>>
91 ComputeArgAndRetvalShardings(const Graph& graph) {
92 auto get_sharding_for_node =
93 [](const Node* n) -> StatusOr<absl::optional<xla::OpSharding>> {
94 TF_ASSIGN_OR_RETURN(
95 auto sharding,
96 ParseShardingFromDevice(*n, std::numeric_limits<int32>::max(),
97 /*add_metadata=*/false));
98 return sharding;
99 };
100 std::map<int, xla::OpSharding> arg_shardings;
101 std::map<int, xla::OpSharding> retval_shardings;
102 for (const Node* n : graph.nodes()) {
103 if (n->IsArg()) {
104 TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
105 if (!sharding.has_value()) continue;
106 int index;
107 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
108 TF_RET_CHECK(index >= 0) << "Negative _Arg index";
109 arg_shardings[index] = std::move(*sharding);
110 } else if (n->IsRetval()) {
111 TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
112 if (!sharding.has_value()) continue;
113 int index;
114 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
115 TF_RET_CHECK(index >= 0) << "Negative _Retval index";
116 retval_shardings[index] = std::move(*sharding);
117 }
118 }
119 return std::make_pair(std::move(arg_shardings), std::move(retval_shardings));
120 }
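// A minimal usage sketch, assuming a populated `graph` whose _Arg/_Retval
// nodes carry parseable device assignments (names below are placeholders):
//
//   TF_ASSIGN_OR_RETURN(auto shardings, ComputeArgAndRetvalShardings(*graph));
//   for (const auto& arg_sharding : shardings.first) {
//     VLOG(2) << "_Arg " << arg_sharding.first << " sharding: "
//             << arg_sharding.second.DebugString();
//   }
//
// Only nodes whose device string parses to a sharding (see
// ParseShardingFromDevice above) appear in the returned maps.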
121
122 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
123 XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
124 int64_t step_id) {
125 // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
126 // resource manager takes ownership via Create, and unrefs via Cleanup. We
127 // explicitly add a reference to ensure the refcount at entry is maintained at
128 // all exit points; Create and Cleanup are always called in this function.
129 //
130 // The Executor requires us to use ScopedStepContainer. We wrap it in a
131 // unique_ptr so we can capture the cleanup status in the end.
132 xla_context->Ref();
133 Status status;
134 auto step_container = absl::make_unique<ScopedStepContainer>(
135 step_id, [&status, device](const string& name) {
136 status = device->resource_manager()->Cleanup(name);
137 });
138 TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(),
139 XlaContext::kXlaContextResourceName,
140 xla_context));
141
142 GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
143 TF_RETURN_IF_ERROR(graph_compiler.Compile());
144 // Explicitly clean up the step container, to capture the cleanup status.
145 step_container.reset();
146 return status;
147 }
148
149 // Builds the XLA computation.
150 // - `args` is the list of input arguments
151 // - `retvals` is the list of retvals produced by _Retval operators, in index
152 // order.
153 // - `arg_shardings` and `retval_shardings` are mapping from arg/return indices
154 // to sharding.
155 // - If `return_updated_values_for_all_resources` is true, all resources will be
156 // included in `resource_updates`, regardless of whether their value changed.
157 // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
158 // - Sets `*resource_updates` to a description of resources whose values are
159 // written by the computation; the variable writes are the last
160 //   `resource_updates.size()` return values from the computation. Each entry in
161 // `resource_updates` is a ResourceUpdate, whose `index` is the index of a
162 // resource variable argument to the computation to be updated, and `type` is
163 // the type of the final output.
164 Status BuildComputation(
165 const std::vector<XlaCompiler::Argument>& args,
166 const std::vector<XlaExpression>& retvals,
167 const std::map<int, xla::OpSharding>& arg_shardings,
168 const std::map<int, xla::OpSharding>& retval_shardings,
169 const std::vector<std::unique_ptr<XlaResource>>& resources,
170 std::unique_ptr<xla::XlaOp> token_output,
171 const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
172 bool is_entry_computation, bool return_updated_values_for_all_resources,
173 bool always_return_tuple, bool use_tuple_arg, bool alias_resource_update,
174 xla::XlaBuilder* builder, xla::XlaComputation* computation,
175 int* num_computation_outputs, int* num_nonconst_outputs,
176 std::vector<XlaCompiler::OutputDescription>* outputs,
177 std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
178 xla::Shape* output_shape, absl::Span<int const> input_mapping) {
179 // Attach a common operator name as metadata. This has no semantic effect — it
180 // merely makes the HLO graph more readable when visualized via TensorBoard,
181 // since TensorBoard forms groups out of operators with similar names.
182 xla::OpMetadata retval_metadata;
183 retval_metadata.set_op_name("XLA_Retvals");
184 builder->SetOpMetadata(retval_metadata);
185 VLOG(1) << "Building new computation";
186 auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
187
188 // Builds a no-op XLA computation. We need to set the sharding of outputs, but
189 // cannot change the sharding of the existing output op. To do this, we build
190 // a new identity op to which shardings can be applied.
191 auto identity_op = [builder](xla::XlaOp op) {
192 return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
193 };
194
195 std::vector<xla::XlaOp> elems;
196 elems.reserve(retvals.size());
197
198 // Keeps track of sharding of each retval. If a retval is not in this list,
199 // replicate sharding is used. The first element is the output index, second
200 // element is the sharding.
201 std::unordered_map<int, xla::OpSharding> retval_index_and_sharding;
202 for (int i = 0, end = retvals.size(); i < end; ++i) {
203 XlaCompiler::OutputDescription& output = (*outputs)[i];
204 const XlaExpression& retval = retvals[i];
205 output.type = retval.dtype();
206 switch (retval.kind()) {
207 case XlaExpression::Kind::kConstant:
208 output.is_constant = true;
209 output.constant_value = *retval.constant_value();
210 output.shape = output.constant_value.shape();
211 break;
212
213 case XlaExpression::Kind::kTensorList: {
214 output.is_tensor_list = true;
215 xla::XlaOp value = retval.handle();
216 elems.push_back(value);
217 break;
218 }
219
220 case XlaExpression::Kind::kXlaOp: {
221 output.is_constant = false;
222 TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
223 xla::XlaOp value = retval.handle();
224 auto it = retval_shardings.find(i);
225 absl::optional<xla::OpSharding> sharding =
226 it == retval_shardings.end() ? absl::optional<xla::OpSharding>()
227 : it->second;
228 if (it != retval_shardings.end()) {
229 retval_index_and_sharding[elems.size()] = it->second;
230 }
231 if (shape_representation_fn) {
232 TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value));
233 TF_ASSIGN_OR_RETURN(value,
234 ReshapeWithCorrectRepresentationAndSharding(
235 builder, value, original_shape,
236 shape_representation_fn, sharding,
237 /*fast_mem=*/false));
238 }
239 if (it != retval_shardings.end()) {
240 xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
241 // Apply the sharding to the output, if there is a core assignment.
242 value = identity_op(value);
243 }
244
245 elems.push_back(value);
246 break;
247 }
248
249 case XlaExpression::Kind::kResource:
250 // Resources will be pushed into elems later when processing resource
251 // arguments below.
252 output.is_constant = false;
253 output.input_index = retval.resource()->arg_num();
254 output.shape = retval.resource()->shape();
255 break;
256
257 case XlaExpression::Kind::kInvalid:
258 return errors::InvalidArgument(
259 "Invalid expression returned by computation. "
260 "This probably means a return value was not set.");
261 }
262 }
263 *num_nonconst_outputs = elems.size();
264
265 // Add return values for resources whose values have changed.
266 std::vector<const XlaResource*> arg_resources;
267 arg_resources.reserve(resources.size());
268 for (const auto& resource : resources) {
269 if (resource->arg_num() >= 0) {
270 arg_resources.push_back(resource.get());
271 }
272 }
273 std::sort(arg_resources.begin(), arg_resources.end(),
274 [](const XlaResource* a, const XlaResource* b) {
275 return a->arg_num() < b->arg_num();
276 });
277
278 absl::flat_hash_map<int, int> argument_to_xla_arg;
279 for (int xla_arg = 0; xla_arg < input_mapping.size(); xla_arg++) {
280 argument_to_xla_arg[input_mapping[xla_arg]] = xla_arg;
281 }
282
283 std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
284 for (const XlaResource* resource : arg_resources) {
285 DCHECK_LT(resource->arg_num(), args.size());
286 const XlaCompiler::Argument& arg = args[resource->arg_num()];
287 auto it = arg_shardings.find(resource->arg_num());
288 bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
289 // TensorArray gradients were modified if their values changed or there are
290 // any newly created gradients.
291 for (const auto& grad : resource->tensor_array_gradients()) {
292 modified =
293 modified ||
294 !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
295 arg.tensor_array_gradients.count(grad.first) == 0;
296 }
297
298 if (return_updated_values_for_all_resources || modified ||
299 arg.requires_broadcast) {
300 resource_updates->emplace_back();
301 XlaCompiler::ResourceUpdate& update = resource_updates->back();
302 update.input_index = resource->arg_num();
303 update.type = resource->type();
304 update.shape = resource->shape();
305 update.modified = modified;
306 int param_num = use_tuple_arg ? 0 : update.input_index;
307 if (is_entry_computation &&
308 arg.resource_kind != XlaResource::kTensorArray &&
309 alias_resource_update && argument_to_xla_arg.count(param_num)) {
310 // Assuming tuple arg and results are used.
311 xla::ShapeIndex param_index =
312 use_tuple_arg ? xla::ShapeIndex({update.input_index})
313 : xla::ShapeIndex{};
314 int xla_param_num = argument_to_xla_arg[param_num];
315 int64_t output_index_num = elems.size();
316 xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
317 VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
318 << xla_param_num << ", " << param_index.ToString() << ")";
319 aliases.push_back({output_index, xla_param_num, param_index});
320 }
321 for (const auto& grad : resource->tensor_array_gradients()) {
322 update.tensor_array_gradients_accessed.insert(grad.first);
323 }
324
325 xla::XlaOp handle;
326 TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
327 auto sharding = it == arg_shardings.end()
328 ? absl::optional<xla::OpSharding>()
329 : it->second;
330 // Set layout of the retval to device representation layout.
331 if (shape_representation_fn) {
332 TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle));
333 TF_ASSIGN_OR_RETURN(
334 handle, ReshapeWithCorrectRepresentationAndSharding(
335 builder, handle, original_shape,
336 shape_representation_fn, sharding, arg.fast_mem));
337 }
338
339 // Request that the value be returned on a specific core.
340 xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
341 if (it != arg_shardings.end()) {
342 retval_index_and_sharding[elems.size()] = it->second;
343 }
344 // Ensures the correct sharding is applied to the output.
345 handle = identity_op(handle);
346 elems.push_back(handle);
347 }
348 }
349
350 // If we have token output, append it as the last one.
351 if (token_output) {
352 elems.push_back(*token_output);
353 }
354
355 *num_computation_outputs = elems.size();
356
357 // Builds the XLA computation. We *always* form a tuple here to ensure that
358 // the output value is the last thing added into the XLA computation, even
359 // if there is only one output value.
360 xla::XlaOp tuple;
361 if (retval_index_and_sharding.empty() || !is_entry_computation) {
362 tuple = xla::Tuple(builder, elems);
363 } else {
364 std::vector<xla::Shape> elem_shapes;
365 for (const auto& elem : elems) {
366 TF_ASSIGN_OR_RETURN(xla::Shape elem_shape,
367 elem.builder()->GetShape(elem));
368 elem_shapes.push_back(elem_shape);
369 }
370 xla::Shape shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
371 // Copy specified sharding from retval_index_and_sharding.
372 std::vector<xla::HloSharding> sharding_elems;
373 for (int i = 0, end = elems.size(); i < end; i++) {
374 const auto& iter = retval_index_and_sharding.find(i);
375 TF_RET_CHECK(iter != retval_index_and_sharding.end());
376 const xla::OpSharding& sub_op_sharding = iter->second;
377 TF_ASSIGN_OR_RETURN(xla::HloSharding sub_sharding,
378 xla::HloSharding::FromProto(sub_op_sharding));
379 if (elem_shapes[i].IsTuple()) {
380 const std::vector<xla::HloSharding> sub_sharding_elems =
381 sub_sharding.tuple_elements();
382 const int64_t sub_sharding_elems_size = sub_sharding_elems.size();
383 TF_RET_CHECK(sub_sharding_elems_size ==
384 xla::ShapeUtil::GetLeafCount(elem_shapes[i]));
385 for (const auto& sub_sharding_elem : sub_sharding_elems) {
386 sharding_elems.push_back(sub_sharding_elem);
387 }
388 } else {
389 sharding_elems.push_back(sub_sharding);
390 }
391 }
392 xla::HloSharding modified_sharding =
393 xla::HloSharding::Tuple(shape, sharding_elems);
394 xla::OpSharding op_sharding = modified_sharding.ToProto();
395 // Assign proper sharding to the tuple instruction.
396 xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
397 tuple = xla::Tuple(builder, elems);
398 }
399 bool returns_tuple = always_return_tuple || elems.size() != 1;
400 VLOG(3) << "Computation returns a tuple=" << returns_tuple;
401 if (!returns_tuple) {
402 xla::GetTupleElement(tuple, 0);
403
404 for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
405 if (alias.output_index == xla::ShapeIndex({0})) {
406 VLOG(3) << "For aliased parameter " << alias.param_number << ": "
407 << alias.param_index.ToString()
408 << " normalizing output_index from {0} to {}, as a scalar is "
409 "returned from the cluster";
410 alias.output_index = xla::ShapeIndex({});
411 }
412 }
413 }
414
415 for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
416 builder->SetUpAlias(alias.output_index, alias.param_number,
417 alias.param_index);
418 }
419
420 StatusOr<xla::XlaComputation> computation_status = builder->Build();
421 if (!computation_status.ok()) {
422 return computation_status.status();
423 }
424 *computation = computation_status.ConsumeValueOrDie();
425
426 TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape());
427 *output_shape = program_shape.result();
428 return Status::OK();
429 }
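// Recap of the value layout produced above: the first *num_nonconst_outputs
// entries are the non-constant _Retval values, followed by one entry per
// updated resource (in increasing arg_num order), with the optional token
// appended last. Input/output aliases are only registered for entry
// computations when `alias_resource_update` is set, the resource is not a
// TensorArray, and the resource argument maps to an actual XLA parameter.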
430
431 } // namespace
432
433
434 string XlaCompiler::Argument::HumanString() const {
435 string common;
436 if (!name.empty()) {
437 common = absl::StrCat(" name=", name);
438 }
439 absl::StrAppend(&common, " type=", DataTypeString(type),
440 " shape=", ShapeHumanString());
441 absl::StrAppend(
442 &common, " is_same_data_across_replicas=", is_same_data_across_replicas);
443 switch (kind) {
444 case kInvalid:
445 return "invalid";
446 case kConstant:
447 return absl::StrCat("kind=constant", common,
448 " value=", constant_value.DebugString());
449 case kConstantResource:
450 return absl::StrCat("kind=constant-resource", common,
451 " value=", constant_value.DebugString());
452 case kResource: {
453 string output = absl::StrCat(
454 "kind=resource", common,
455 " resource_kind=", XlaResource::KindToString(resource_kind),
456 " initialized=", initialized, " is_fast_mem=", fast_mem);
457 if (max_array_size >= 0) {
458 absl::StrAppend(&output, " max_array_size=", max_array_size);
459 }
460 if (!tensor_array_gradients.empty()) {
461 absl::StrAppend(&output, " tensor_array_gradients=",
462 absl::StrJoin(tensor_array_gradients, ","));
463 }
464 return output;
465 }
466 case kParameter:
467 return absl::StrCat("kind=parameter", common);
468 case kTensorList:
469 return absl::StrCat("kind=tensorlist", common);
470 case kToken:
471 return absl::StrCat("token", common);
472 }
473 }
474
475 std::vector<int64> XlaCompiler::Argument::DimensionSizes() const {
476 if (absl::holds_alternative<TensorShape>(shape)) {
477 return xla::InlinedVectorToVector(
478 absl::get<TensorShape>(shape).dim_sizes());
479 } else {
480 return xla::SpanToVector(absl::get<xla::Shape>(shape).dimensions());
481 }
482 }
483
484 absl::InlinedVector<int64, 4>
485 XlaCompiler::Argument::DimensionSizesAsInlinedVector() const {
486 if (absl::holds_alternative<TensorShape>(shape)) {
487 return absl::get<TensorShape>(shape).dim_sizes();
488 } else {
489 auto v = absl::get<xla::Shape>(shape).dimensions();
490 return absl::InlinedVector<int64, 4>(v.begin(), v.end());
491 }
492 }
493
494 string XlaCompiler::Argument::ShapeHumanString() const {
495 if (absl::holds_alternative<TensorShape>(shape)) {
496 return absl::get<TensorShape>(shape).DebugString();
497 } else {
498 return absl::get<xla::Shape>(shape).DebugString();
499 }
500 }
501
502 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
503 : options_(options),
504 initialization_status_(Status::OK()),
505 next_step_id_(1),
506 device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
507 device_mgr_(absl::WrapUnique(device_)) {
508 CHECK(!options_.device_type.type_string().empty());
509 if (options_.populate_resource_manager) {
510 initialization_status_ =
511 (*options_.populate_resource_manager)(device_->resource_manager());
512 }
513
514 local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
515 FunctionDefLibrary{}));
516 local_pflr_.reset(new ProcessFunctionLibraryRuntime(
517 &device_mgr_, Env::Default(), /*config=*/nullptr,
518 options.graph_def_version, local_flib_def_.get(), OptimizerOptions()));
519 pflr_.reset(new ProcessFunctionLibraryRuntime(
520 &device_mgr_, Env::Default(), /*config=*/nullptr,
521 options.graph_def_version, options.flib_def, OptimizerOptions()));
522
523 local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
524 flib_runtime_ = pflr_->GetFLR(device_->name());
525
526 // The default shape representation function is the identity.
527 if (!options_.shape_representation_fn) {
528 options_.shape_representation_fn = IdentityShapeRepresentationFn();
529 }
530 }
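// A minimal construction sketch, assuming the caller owns a
// FunctionLibraryDefinition and that `fn_name_attrs` names a function already
// registered in it; the device type and other values shown are placeholders,
// not requirements:
//
//   XlaCompiler::Options options;
//   options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
//   options.flib_def = flib_def;                        // caller-owned
//   options.client = xla::ClientLibrary::LocalClientOrDie();
//   XlaCompiler compiler(options);
//
//   std::vector<XlaCompiler::Argument> args(1);
//   args[0].kind = XlaCompiler::Argument::kParameter;
//   args[0].type = DT_FLOAT;
//   args[0].shape = TensorShape({2, 3});
//
//   XlaCompiler::CompilationResult result;
//   TF_RETURN_IF_ERROR(compiler.CompileFunction(
//       XlaCompiler::CompileOptions(), fn_name_attrs, args, &result));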
531
532 XlaCompiler::~XlaCompiler() = default;
533
534 int64 XlaCompiler::NextStepId() { return next_step_id_++; }
535
536 uint64 XlaCompiler::SignatureHash::operator()(
537 const std::pair<string, std::vector<Argument>>& signature) const {
538 return std::hash<string>()(signature.first);
539 }
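// Note that the hash above only covers the function name; signatures that
// share a name but differ in their argument vectors land in the same bucket
// and are distinguished by the cache's equality comparison on the full
// std::pair (see the `cache_` lookups in CompileFunction below).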
540
541 static Status GetFunctionBody(const NameAttrList& function,
542 FunctionLibraryRuntime* flib_runtime,
543 const FunctionBody** fbody) {
544 FunctionLibraryRuntime::Handle handle;
545 TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
546 function.name(), AttrSlice(&function.attr()), &handle));
547
548 *fbody = flib_runtime->GetFunctionBody(handle);
549 TF_RET_CHECK(*fbody);
550 return Status::OK();
551 }
552
553 Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
554 const FunctionBody** fbody,
555 const ConfigProto** config_proto) {
556 // The function may be in either the local_flib_runtime_ or flib_runtime_.
557 // Look up the function in local first and if it is not found then look up the
558 // function in flib_runtime_.
559 auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
560 if (!status.ok()) {
561 if (!errors::IsNotFound(status)) {
562 return status;
563 }
564 TF_RETURN_WITH_CONTEXT_IF_ERROR(
565 GetFunctionBody(function, flib_runtime_, fbody),
566 "Local lookup failed with: ", status.error_message());
567 if (config_proto) {
568 *config_proto = flib_runtime_->config_proto();
569 }
570 VLOG(4) << "Function " << function.name() << " in flib_runtime_";
571 } else {
572 if (config_proto) {
573 *config_proto = local_flib_runtime_->config_proto();
574 }
575 VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
576 }
577 return Status::OK();
578 }
579
580 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
581 std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
582 CopyGraph(*fbody->graph, graph.get());
583
584 bool is_inside_mustcompile = false;
585 TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
586 &is_inside_mustcompile);
587
588 // Performs a first function inlining pass before shape inference, since
589 // otherwise shape inference can't see inside functions and a comprehensive
590 // shape_map, including function ops, is needed to constant-propagate Shape
591 // Ops below.
592 auto flags = GetBuildXlaOpsPassFlags();
593 OptimizerOptions opts;
594 opts.set_opt_level(OptimizerOptions::L0);
595 opts.set_do_common_subexpression_elimination(false);
596 opts.set_do_function_inlining(true);
597 opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
598 GraphOptimizer optimizer(opts);
599 // Do not constant fold nodes that output DT_VARIANT type tensors.
600 // XLA does not support Const nodes of Variant type since it needs
601 // to know the original ops to be able to compile them to the relevant
602 // XLA form.
603 // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
604 // the form:
605 // Const
606 // |
607 // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
608 // |
609 // (Discard popped list)
610 //
611 // Would have been reduced to "Const -> Op" without this filter.
612 // However since we are only allowed to specify the filter at the "Node"
613 // level there is no good way to allow the above behavior. So we
614 // disallow any sort of constant folding on Variant nodes for now.
615 //
616 // Also do not consider constant folding Shape ops. When a tensor has a
617 // dynamic dimension, TF2XLA currently represents it as the static
618 // upper-bound shape, which could be constant folded and thereby lose the
619 // information that this Shape is dynamic.
620 auto cf_consider_fn = [](const Node* n) {
621 for (const auto& output_arg : n->op_def().output_arg()) {
622 if (output_arg.type() == DT_VARIANT) {
623 return false;
624 }
625 }
626 const auto& ts = n->type_string();
627 // XLA has special logic to handle dynamic shapes, don't constant fold
628 // them.
629 if (ts == "Shape" || ts == "ShapeN" || ts == "Size") {
630 return false;
631 }
632 return true;
633 };
634 GraphOptimizer::Options graph_optimizer_options;
635 graph_optimizer_options.cf_consider_fn = cf_consider_fn;
636 graph_optimizer_options.inline_multi_device_functions = true;
637 graph_optimizer_options.inline_impl_selection_group_functions = true;
638 graph_optimizer_options.inline_with_single_device_body_placer = true;
639 graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
640
641 {
642 GraphShapeInfo shape_info;
643 InferShapes(graph.get(), /*arg_shapes=*/{},
644 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
645 .IgnoreError();
646 auto node_name_index = graph->BuildNodeNameIndex();
647 std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
648 for (const auto& node_shape_info : shape_info) {
649 const string& node_name = node_shape_info.first;
650 const std::vector<InferredShape>& output_shapes = node_shape_info.second;
651 const auto& node_iter = node_name_index.find(node_name);
652 if (node_iter != node_name_index.end()) {
653 auto& partial_shapes = shape_map[node_name];
654 for (const auto& inferred_shape : output_shapes) {
655 partial_shapes.push_back(inferred_shape.shape);
656 }
657 }
658 }
659 graph_optimizer_options.shape_map = &shape_map;
660 optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
661 /*device=*/nullptr, &graph, graph_optimizer_options);
662 }
663
664 // Run shape inference on the graph and optimize the graph again.
665 GraphShapeInfo shape_info;
666 InferShapes(graph.get(), /*arg_shapes=*/{},
667 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
668 .IgnoreError();
669 auto node_name_index = graph->BuildNodeNameIndex();
670 std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
671 for (const auto& node_shape_info : shape_info) {
672 const string& node_name = node_shape_info.first;
673 const std::vector<InferredShape>& output_shapes = node_shape_info.second;
674 const auto& node_iter = node_name_index.find(node_name);
675 if (node_iter != node_name_index.end()) {
676 auto& partial_shapes = shape_map[node_name];
677 for (const auto& inferred_shape : output_shapes) {
678 partial_shapes.push_back(inferred_shape.shape);
679 }
680 }
681 }
682 graph_optimizer_options.shape_map = &shape_map;
683 optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
684 /*device=*/nullptr, &graph, graph_optimizer_options);
685
686 return graph;
687 }
688
689 // Collects all control rets from `orig_control_ret_nodes` that are still valid,
690 // keeping the same order.
691 std::vector<std::string> GetValidControlRets(
692 absl::Span<Node* const> orig_control_ret_nodes, const Graph& graph) {
693 // Build map from control ret node name to index.
694 // We use Node name instead of Node* here to index into the map as we populate
695 // the map with nodes in FunctionDef control_ret_nodes and later query it
696 // using the nodes in `graph`. The Node pointers would be different but the
697 // Node name is expected to remain the same between the two.
698 absl::flat_hash_map<const string, int> control_ret_nodes_map;
699 for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
700 const Node* n = orig_control_ret_nodes[i];
701 control_ret_nodes_map[n->name()] = i;
702 }
703 // Check which control rets are still valid.
704 std::vector<bool> is_valid_control_ret(orig_control_ret_nodes.size(), false);
705 int num_valid_control_rets = 0;
706 for (const Node* n : graph.nodes()) {
707 auto iter = control_ret_nodes_map.find(n->name());
708 if (iter != control_ret_nodes_map.end()) {
709 ++num_valid_control_rets;
710 is_valid_control_ret[iter->second] = true;
711 }
712 }
713 // Return valid control rets in same order as they appear in
714 // `orig_control_ret_nodes`.
715 std::vector<std::string> valid_control_rets;
716 valid_control_rets.reserve(num_valid_control_rets);
717 for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
718 if (is_valid_control_ret[i]) {
719 valid_control_rets.push_back(orig_control_ret_nodes[i]->name());
720 }
721 }
722 return valid_control_rets;
723 }
724
725 Status XlaCompiler::CompileFunction(
726 const XlaCompiler::CompileOptions& options,
727 const NameAttrList& fn_name_attrs,
728 absl::Span<const XlaCompiler::Argument> args,
729 XlaCompiler::CompilationResult* result) {
730 const string function_id =
731 Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr()));
732 VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
733
734 const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
735 auto it = cache_.find({function_id, arg_vector});
736 if (it != cache_.end()) {
737 *result = it->second;
738 return Status::OK();
739 }
740
741 const FunctionBody* fbody;
742 const ConfigProto* config = nullptr;
743 TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody, &config));
744
745 absl::optional<ConfigProto> config_proto;
746 if (config) {
747 config_proto = *config;
748 }
749
750 TF_RETURN_WITH_CONTEXT_IF_ERROR(
751 CheckSignature(fbody->arg_types, args),
752 "Signature check failure while compiling: ", fn_name_attrs.name());
753
754 // Set shapes for _Arg nodes. They are useful for constant folding (e.g. an
755 // Xla op requires a compile-time constant input, and that input is the shape
756 // of an _Arg node).
757 for (int i = 0, end = args.size(); i < end; i++) {
758 // Skip resource variables and tensor lists.
759 DataType dtype;
760 TF_RETURN_IF_ERROR(GetNodeAttr(fbody->arg_nodes[i]->def(), "T", &dtype));
761 if (dtype == DT_RESOURCE || dtype == DT_VARIANT) {
762 continue;
763 }
764
765 if (absl::holds_alternative<xla::Shape>(args[i].shape)) {
766 xla::Shape xla_shape = absl::get<xla::Shape>(args[i].shape);
767 TensorShape tensor_shape;
768 // If xla_shape is dynamic, prevent constant folding by not setting
769 // output_shapes.
770 if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() &&
771 xla_shape.is_static()) {
772 fbody->arg_nodes[i]->ClearAttr("_output_shapes");
773 fbody->arg_nodes[i]->AddAttr("_output_shapes",
774 std::vector<TensorShape>{tensor_shape});
775 }
776 } else {
777 TensorShape tensor_shape = absl::get<TensorShape>(args[i].shape);
778 fbody->arg_nodes[i]->ClearAttr("_output_shapes");
779 fbody->arg_nodes[i]->AddAttr("_output_shapes",
780 std::vector<TensorShape>{tensor_shape});
781 }
782 }
783
784 std::unique_ptr<Graph> graph = GetGraph(fbody);
785
786 // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
787 // they are added from the looked-up function body. Therefore, they don't have
788 // core assignments here.
789 // Attempt to assign a core to each _Retval and _Arg. Chooses the
790 // lowest-numbered core that consumes the argument. We choose the
791 // lowest-numbered core so the assignment is deterministic.
792 for (Node* n : graph->nodes()) {
793 if (n->IsArg()) {
794 TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
795 }
796 }
797 // Do _Retval as a second loop, in case the retval's input is an _Arg (which
798 // may have gotten a device assignment from the first loop).
799 for (Node* n : graph->nodes()) {
800 if (n->IsRetval()) {
801 TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
802 }
803 }
804
805 if (VLOG_IS_ON(2)) {
806 VLOG(2) << "XlaCompiler::CompileFunction: "
807 << DumpGraphToFile(
808 absl::StrCat("xla_compile_function_", function_id), *graph);
809 }
810
811 VLOG(1) << "====================================================";
812 MlirBridgeRolloutPolicy policy = MlirBridgeRolloutPolicy::kDisabledByUser;
813 if (options.is_entry_computation) {
814 policy = GetMlirBridgeRolloutPolicy(
815 *graph, /*function_library=*/nullptr, config_proto,
816 /*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
817 }
818 if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
819 VLOG(1) << "Using MLIR bridge to compile the function";
820 GraphDebugInfo debug_info;
821
822 std::vector<std::string> valid_control_rets =
823 GetValidControlRets(fbody->control_ret_nodes, *graph);
824
825 TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
826 std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
827 valid_control_rets, options_.device_type.type_string(),
828 options.use_tuple_arg, /*analyse_graph=*/false, *options_.flib_def,
829 debug_info, options_.shape_representation_fn, result));
830 } else {
831 VLOG(1) << "Using the old bridge to compile the function";
832 TF_RETURN_IF_ERROR(
833 CompileGraph(options, function_id, std::move(graph), args, result));
834 }
835 VLOG(1) << "====================================================";
836
837 cache_[{function_id, arg_vector}] = *result;
838 return Status::OK();
839 }
840
841 // Computes the XLA shape for argument 'arg'.
842 Status XlaCompiler::XLAShapeForArgument(
843 const XlaCompiler::Argument& arg, bool is_entry_computation,
844 const absl::optional<xla::HloSharding>& arg_sharding,
845 xla::Shape* xla_shape) const {
846 switch (arg.kind) {
847 case XlaCompiler::Argument::kConstant:
848 LOG(FATAL) << "Unreachable case";
849 case XlaCompiler::Argument::kParameter: {
850 if (is_entry_computation) {
851 TensorShape shape;
852 if (absl::holds_alternative<TensorShape>(arg.shape)) {
853 shape = absl::get<TensorShape>(arg.shape);
854 } else {
855 TF_RETURN_IF_ERROR(
856 XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &shape));
857 }
858 TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
859 shape, arg.type,
860 /*use_fast_memory=*/false));
861 TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
862 arg_sharding, /*use_fast_memory=*/false,
863 options_.shape_representation_fn, xla_shape));
864 } else {
865 if (absl::holds_alternative<xla::Shape>(arg.shape)) {
866 *xla_shape = absl::get<xla::Shape>(arg.shape);
867 } else {
868 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
869 arg.type, absl::get<TensorShape>(arg.shape), xla_shape));
870 }
871 }
872 return Status::OK();
873 }
874 case XlaCompiler::Argument::kTensorList: {
875 TF_RET_CHECK(absl::holds_alternative<xla::Shape>(arg.shape));
876 *xla_shape = absl::get<xla::Shape>(arg.shape);
877 return Status::OK();
878 }
879 case XlaCompiler::Argument::kConstantResource:
880 case XlaCompiler::Argument::kResource: {
881 TF_RET_CHECK(arg.initialized);
882
883 switch (arg.resource_kind) {
884 case XlaResource::kVariable: {
885 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
886 TF_ASSIGN_OR_RETURN(*xla_shape,
887 options_.shape_representation_fn(
888 absl::get<TensorShape>(arg.shape), arg.type,
889 /*use_fast_memory=*/arg.fast_mem));
890 TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
891 arg_sharding, arg.fast_mem, options_.shape_representation_fn,
892 xla_shape));
893 return Status::OK();
894 }
895 case XlaResource::kTensorArray: {
896 if (arg.max_array_size < 0) {
897 return errors::InvalidArgument(
898 "Negative max_array_size in XLAShapeForArgument");
899 }
900 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
901 TensorShape shape;
902 shape.AddDim(arg.max_array_size);
903 shape.AppendShape(absl::get<TensorShape>(arg.shape));
904 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
905
906 if (!arg.tensor_array_gradients.empty()) {
907 std::vector<xla::Shape> tuple_shape(
908 arg.tensor_array_gradients.size() + 1, *xla_shape);
909 *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
910 }
911 return Status::OK();
912 }
913 case XlaResource::kStack: {
914 if (arg.max_array_size < 0) {
915 return errors::InvalidArgument(
916 "Negative max_array_size in XLAShapeForArgument");
917 }
918 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
919 TensorShape shape;
920 shape.AddDim(arg.max_array_size);
921 shape.AppendShape(absl::get<TensorShape>(arg.shape));
922 xla::Shape buffer_shape;
923 TF_RETURN_IF_ERROR(
924 TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
925 *xla_shape = xla::ShapeUtil::MakeTupleShape(
926 {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
927 return Status::OK();
928 }
929
930 case XlaResource::kInvalid:
931 return errors::Internal(
932 "Invalid resource type in XLAShapeForArgument()");
933 }
934 }
935 case XlaCompiler::Argument::kToken: {
936 *xla_shape = xla::ShapeUtil::MakeTokenShape();
937 return Status::OK();
938 }
939 case XlaCompiler::Argument::kInvalid:
940 return errors::Internal("Invalid argument type in XLAShapeForArgument()");
941 }
942 }
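// A worked example of the TensorArray case above, assuming arg.type=DT_FLOAT,
// arg.max_array_size=8 and an element shape of [3, 4]: the buffer shape
// becomes f32[8,3,4]; if two gradients are attached, the result is the tuple
// (f32[8,3,4], f32[8,3,4], f32[8,3,4]). A kStack resource instead yields a
// (buffer, s32[] position) tuple.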
943
944 /* static */
945 void XlaCompiler::PopulateArgumentFromResource(const XlaResource& resource,
946 Argument* arg) {
947 arg->initialized = resource.initialized();
948 arg->kind = XlaCompiler::Argument::kResource;
949 arg->resource_kind = resource.kind();
950
951 arg->type = resource.type();
952 arg->shape = resource.shape();
953 arg->max_array_size = resource.max_array_size();
954 for (const auto& gradient : resource.tensor_array_gradients()) {
955 arg->tensor_array_gradients.insert(gradient.first);
956 }
957 arg->name = resource.name();
958 }
959
960 // Builds XLA computations for each of the arguments to the computation.
961 // `args` are the arguments to the computation.
962 Status XlaCompiler::BuildArguments(
963 const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
964 bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
965 const std::map<int, xla::OpSharding>& arg_shardings,
966 std::vector<XlaExpression>* arg_expressions,
967 std::vector<int>* input_to_args, std::vector<xla::Shape>* input_shapes,
968 bool is_entry_computation) {
969 arg_expressions->resize(args.size());
970
971 // Argument numbers of arguments and resources that are to be passed to the
972 // XLA computation as runtime parameters. `input_to_args[a] = b` means that
973 // the a'th XLA input corresponds to the b'th original arg index.
974 input_to_args->clear();
975 input_to_args->reserve(args.size());
976
977 // Fills in constant arguments, and computes non-constant argument order.
978 for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
979 ++i) {
980 const XlaCompiler::Argument& arg = args[i];
981 XlaExpression& arg_expression = (*arg_expressions)[i];
982 switch (arg.kind) {
983 case XlaCompiler::Argument::kConstantResource:
984 case XlaCompiler::Argument::kResource: {
985 TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
986 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
987 // TODO(phawkins): this code assumes that resource arguments do not
988 // alias.
989 XlaResource* resource =
990 context->AddResource(absl::make_unique<XlaResource>(
991 arg.resource_kind, i, arg.name, arg.type,
992 absl::get<TensorShape>(arg.shape), xla::XlaOp(),
993 /*max_array_size=*/arg.max_array_size,
994 /*tensor_array_gradients=*/arg.tensor_array_gradients,
995 /*tensor_array_multiple_writes_aggregate=*/true,
996 arg.definition_stack_trace));
997 arg_expression =
998 arg.kind == XlaCompiler::Argument::kResource
999 ? XlaExpression::Resource(resource)
1000 : XlaExpression::ConstantResource(arg.constant_value, resource);
1001 if (arg.initialized) {
1002 input_to_args->push_back(i);
1003 }
1004 break;
1005 }
1006 case XlaCompiler::Argument::kParameter:
1007 case XlaCompiler::Argument::kTensorList:
1008 case XlaCompiler::Argument::kToken: {
1009 input_to_args->push_back(i);
1010 break;
1011 }
1012 case XlaCompiler::Argument::kConstant:
1013 arg_expression = XlaExpression::Constant(arg.constant_value);
1014 break;
1015 case XlaCompiler::Argument::kInvalid:
1016 return errors::Internal(
1017 "Unreachable case in BuildArguments() while filling constant args");
1018 }
1019 }
1020
1021 if (input_to_args->empty() && !use_tuple_arg) {
1022 return Status::OK();
1023 }
1024
1025 // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds
1026 // to the d'th XLA input. Note that the value -1 corresponds to constants, or
1027 // other args that don't correspond to an input.
1028 std::vector<int> arg_to_inputs(args.size(), -1);
1029 for (int i = 0, end = input_to_args->size(); i < end; i++) {
1030 arg_to_inputs[input_to_args->at(i)] = i;
1031 }
1032
1033 std::vector<xla::Shape> arg_shapes(input_to_args->size());
1034 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1035 // Computes the shapes of non-constant arguments.
1036 auto arg_sharding = arg_shardings.find((*input_to_args)[i]);
1037 absl::optional<xla::HloSharding> sharding;
1038 if (arg_sharding != arg_shardings.end()) {
1039 TF_ASSIGN_OR_RETURN(auto hlo_sharding,
1040 xla::HloSharding::FromProto(arg_sharding->second));
1041 sharding = hlo_sharding;
1042 }
1043 TF_RETURN_IF_ERROR(XLAShapeForArgument(args[(*input_to_args)[i]],
1044 is_entry_computation, sharding,
1045 &arg_shapes[i]));
1046 }
1047
1048 if (use_tuple_arg) {
1049 input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
1050 } else {
1051 *input_shapes = arg_shapes;
1052 }
1053
1054 // Attach a common operator name as metadata. This has no semantic effect — it
1055 // merely makes the HLO graph more readable when visualized via TensorBoard,
1056 // since TensorBoard forms groups out of operators with similar names.
1057 xla::OpMetadata arg_metadata;
1058 arg_metadata.set_op_name("XLA_Args");
1059 builder->SetOpMetadata(arg_metadata);
1060
1061 // Build parameter handles for non-constant arguments.
1062 std::vector<xla::XlaOp> arg_handles(input_to_args->size());
1063 if (use_tuple_arg) {
1064 xla::XlaOp tuple;
1065 if (is_entry_computation) {
1066 xla::OpSharding tuple_sharding;
1067 tuple_sharding.set_type(xla::OpSharding::TUPLE);
1068 for (int64_t parameter : *input_to_args) {
1069 auto it = arg_shardings.find(parameter);
1070 *tuple_sharding.add_tuple_shardings() =
1071 it == arg_shardings.end() ? xla::sharding_builder::AssignDevice(0)
1072 : it->second;
1073 }
1074 std::vector<bool> is_same_across_replicas;
1075 for (int i = 0, end = input_to_args->size(); i < end; ++i) {
1076 // Add an entry to is_same_across_replicas for every leaf buffer.
1077 is_same_across_replicas.insert(
1078 is_same_across_replicas.end(),
1079 xla::ShapeUtil::GetLeafCount(arg_shapes[i]),
1080 args[input_to_args->at(i)].is_same_data_across_replicas);
1081 }
1082 xla::XlaScopedShardingAssignment assign_tuple_sharding(
1083 builder, input_to_args->empty() ? absl::optional<xla::OpSharding>()
1084 : tuple_sharding);
1085 tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple",
1086 is_same_across_replicas);
1087 } else {
1088 tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
1089 }
1090
1091 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1092 auto it = arg_shardings.find(i);
1093 xla::XlaScopedShardingAssignment assign_sharding(
1094 builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
1095 : it->second);
1096 auto& arg = args[input_to_args->at(i)];
1097
1098 xla::OpMetadata arg_metadata;
1099 arg_metadata.set_op_name(arg.node_name);
1100 builder->SetOneShotOpMetadata(arg_metadata);
1101 arg_handles[i] = xla::GetTupleElement(tuple, i);
1102 }
1103 } else {
1104 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1105 auto it = arg_shardings.find(i);
1106 xla::XlaScopedShardingAssignment assign_sharding(
1107 builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
1108 : it->second);
1109 if (is_entry_computation) {
1110 // Add an entry to is_same_across_replicas for every leaf buffer.
1111 std::vector<bool> is_same_across_replicas(
1112 xla::ShapeUtil::GetLeafCount((*input_shapes)[i]),
1113 args[input_to_args->at(i)].is_same_data_across_replicas);
1114 arg_handles[i] =
1115 xla::Parameter(builder, i, (*input_shapes)[i],
1116 absl::StrCat("arg", i), is_same_across_replicas);
1117 } else {
1118 arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
1119 absl::StrCat("arg", i));
1120 }
1121 }
1122 }
1123
1124 builder->ClearOpMetadata();
1125
1126 // Fill in the handles in non-constant arguments, and reshape parameters
1127 // back to their correct shapes.
1128 VLOG(2) << "XLA computation inputs:";
1129 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1130 const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
1131 VLOG(2) << " XLA arg " << i
1132 << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
1133 << " name: " << arg.name << " TF arg " << input_to_args->at(i)
1134 << " node name: " << arg.node_name
1135 << (arg_shardings.find(i) == arg_shardings.end()
1136 ? ""
1137 : absl::StrCat(" sharding: ",
1138 arg_shardings.at(i).DebugString()));
1139 XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
1140 switch (arg.kind) {
1141 case XlaCompiler::Argument::kConstantResource:
1142 case XlaCompiler::Argument::kResource: {
1143 TF_RET_CHECK(arg.initialized);
1144 XlaResource* resource = arg_expression.resource();
1145 TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
1146 arg_handles[i], builder));
1147 VLOG(2) << " resource: num_gradients: "
1148 << arg.tensor_array_gradients.size();
1149 break;
1150 }
1151 case XlaCompiler::Argument::kParameter:
1152 // Reshape parameters back to their correct shapes.
1153 // TODO(b/76097077): propagate device assignments onto arguments and
1154 // return values of functions, and then reshape unconditionally.
1155 if (is_entry_computation) {
1156 arg_expression = XlaExpression::XlaOp(
1157 xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
1158 } else {
1159 arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1160 if (arg.value_bound) {
1161 TF_RET_CHECK(arg.value_dynamism);
1162 // Propagate upper bound and value dynamism to arg_expression.
1163 arg_expression.set_value_bound(arg.value_bound.value());
1164 arg_expression.set_value_dynamism(arg.value_dynamism.value());
1165 }
1166 }
1167 break;
1168 case XlaCompiler::Argument::kTensorList: {
1169 arg_expression = XlaExpression::TensorList(arg_handles[i]);
1170 break;
1171 }
1172 case XlaCompiler::Argument::kToken: {
1173 arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1174 break;
1175 }
1176 case XlaCompiler::Argument::kConstant:
1177 case XlaCompiler::Argument::kInvalid:
1178 return errors::Internal(
1179 "Unreachable case in BuildArguments() while filling handles");
1180 }
1181 }
1182
1183 return Status::OK();
1184 }
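// Parameter layout recap for the function above: with `use_tuple_arg` the
// runtime inputs are packed into a single tuple parameter 0 and unpacked with
// GetTupleElement; otherwise each runtime input becomes its own numbered
// parameter. Constant arguments and uninitialized resources never become XLA
// parameters, which is why `input_to_args`/`arg_to_inputs` translate between
// the two index spaces.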
1185
1186 namespace {
1187
1188 // Check that the ops of all non-functional nodes have been registered.
1189 Status ValidateFunctionDef(const FunctionDef* fdef,
1190 const FunctionLibraryDefinition& flib_def) {
1191 for (const NodeDef& node : fdef->node_def()) {
1192 const string& op = node.op();
1193 if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
1194 continue;
1195 }
1196 const OpDef* op_def;
1197 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
1198 }
1199 return Status::OK();
1200 }
1201
1202 // If node is PartitionedCall or StatefulPartitionedCall, returns the
1203 // name from the "f" attr, else returns node.def().op().
1204 // Returned pointer points to the internal string either in node's attributes
1205 // or in its NodeDef. This pointer is valid as long as the node has not been
1206 // modified.
1207 Status GetPotentialFunctionName(const Node& node, const string** name) {
1208 if (node.IsPartitionedCall()) {
1209 const AttrValue* attr_value;
1210 TF_RETURN_IF_ERROR(
1211 node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
1212 if (!attr_value->has_func()) {
1213 return errors::InvalidArgument(
1214 "The attribute value for attribute 'f' in node ", node.DebugString(),
1215 " does not have 'func' field set");
1216 }
1217 *name = &attr_value->func().name();
1218 return Status::OK();
1219 }
1220 *name = &node.type_string();
1221 return Status::OK();
1222 }
1223
1224 // Check that the graph doesn't have any invalid nodes (e.g. incompatible with
1225 // given device_type, invalid data type, missing attributes...)
1226 Status ValidateGraph(const Graph* graph,
1227 const FunctionLibraryDefinition& flib_def,
1228 const DeviceType& device_type, const string& name) {
1229 // Make sure the XLA compilation kernels are registered. This operation is
1230 // idempotent so it is fine if someone called it already.
1231 XlaOpRegistry::RegisterCompilationKernels();
1232
1233 auto maybe_error = [&](const Node* node, const Status& s) -> Status {
1234 if (!s.ok()) {
1235 std::string errmsg = absl::StrCat(
1236 "Detected unsupported operations when trying to compile graph ", name,
1237 " on ", device_type.type_string(), ": ", node->def().op(), " (",
1238 s.error_message(), ")", FormatNodeForError(*node));
1239 if (absl::StrContains(device_type.type_string(), "TPU")) {
1240 absl::StrAppend(&errmsg,
1241 "\nOne approach is to outside compile the unsupported "
1242 "ops to run on CPUs by enabling soft placement "
1243 "`tf.config.set_soft_device_placement(True)`."
1244 " This has a potential performance penalty.\n");
1245 }
1246 if (std::shared_ptr<AbstractStackTrace> stack_trace =
1247 node->GetStackTrace()) {
1248 absl::StrAppend(
1249 &errmsg, "\nThe op is created at: \n",
1250 stack_trace->ToString({/*show_line_contents =*/true,
1251 /*filter_common_prefix =*/true,
1252 /*drop_internal_frames =*/true}));
1253 }
1254
1255 return errors::InvalidArgument(errmsg);
1256 }
1257 return Status::OK();
1258 };
1259
1260 for (const Node* node : graph->nodes()) {
1261 if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
1262 continue;
1263 }
1264 const string* function_name;
1265 TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
1266 const FunctionDef* fdef = flib_def.Find(*function_name);
1267 Status s;
1268 if (fdef) {
1269 s = ValidateFunctionDef(fdef, flib_def);
1270 TF_RETURN_IF_ERROR(maybe_error(node, s));
1271 continue;
1272 }
1273 const OpDef* op_def;
1274 s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
1275 TF_RETURN_IF_ERROR(maybe_error(node, s));
1276 TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
1277 s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
1278 TF_RETURN_IF_ERROR(maybe_error(node, s));
1279 }
1280 return Status::OK();
1281 }
1282
1283 void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
1284 absl::Span<XlaExpression> expressions) {
1285 for (XlaExpression& expression : expressions) {
1286 if (expression.kind() == XlaExpression::Kind::kConstant) {
1287 expression =
1288 XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
1289 }
1290 }
1291 }
1292
1293 } // namespace
1294
1295 Status XlaCompiler::CompileGraph(
1296 const XlaCompiler::CompileOptions& options, string const& name,
1297 std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
1298 CompilationResult* result) {
1299 VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;
1300
1301 TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
1302 graph.get(), options_.flib_def, local_flib_def_.get()));
1303 TF_RETURN_IF_ERROR(RearrangeFunctionArguments(
1304 [this](const NameAttrList& function, const FunctionBody** fbody) {
1305 return FindFunctionBody(function, fbody);
1306 },
1307 graph.get(), local_flib_def_.get(),
1308 pflr_->GetFunctionLibraryDefinition()));
1309
1310 if (VLOG_IS_ON(2)) {
1311 VLOG(2) << "XlaCompiler::CompileGraph: "
1312 << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
1313 flib_runtime_->GetFunctionLibraryDefinition());
1314 }
1315
1316 // Report the error here if initialization failed.
1317 TF_RETURN_IF_ERROR(initialization_status_);
1318
1319 // Detect invalid nodes.
1320 // FunctionalizeControlFlow may remove some nodes from the graph.
1321 TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
1322 options_.device_type, name));
1323 xla::XlaBuilder builder(name);
1324 XlaContext* context = new XlaContext(this, &builder, graph.get());
1325 core::ScopedUnref context_unref(context);
1326
1327 std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
1328 int token_input_index = -1;
1329 std::unique_ptr<xla::XlaOp> token_output;
1330 if (options.add_token_input_output) {
1331 // Add extra token input.
1332 token_input_index = real_args.size();
1333
1334 XlaCompiler::Argument token_arg;
1335 token_arg.kind = XlaCompiler::Argument::kToken;
1336 real_args.push_back(token_arg);
1337 }
1338
  std::map<int, xla::OpSharding> arg_shardings;
  std::map<int, xla::OpSharding> retval_shardings;
  TF_ASSIGN_OR_RETURN(std::tie(arg_shardings, retval_shardings),
                      ComputeArgAndRetvalShardings(*graph));

  std::vector<XlaExpression> arg_expressions;
  TF_RETURN_IF_ERROR(BuildArguments(
      *graph, real_args, options.use_tuple_arg, &builder, context,
      arg_shardings, &arg_expressions, &result->input_mapping,
      &result->xla_input_shapes, options.is_entry_computation));
  context->set_args(std::move(arg_expressions));

  PushNodeTokenMapping();
  // Use std::set instead of std::unordered_set to ensure determinism.
  std::set<std::string> output_node_token_inputs;
  if (token_input_index != -1) {
    // Original token comes from input.
    auto arg_expression = context->args()[token_input_index];
    TF_RETURN_IF_ERROR(
        SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));

    // Calculate token inputs for output token.
    output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);

    // If there's no side-effecting op in the graph, use token input as token
    // output.
    if (output_node_token_inputs.empty()) {
      output_node_token_inputs.insert(kXlaTokenArgNodeName);
    }
  } else if (options.is_entry_computation) {
    // Original token is manually created.
    if (HasSideEffectingNodes(*graph)) {
      TF_RETURN_IF_ERROR(
          SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
    }
  }

  Status execute_status = ExecuteGraph(context, std::move(graph), device_,
                                       flib_runtime_, NextStepId());
  if (!execute_status.ok()) {
    VLOG(1) << "Failed executing graph " << name;
    return execute_status;
  }
  if (token_input_index != -1) {
    // Add extra token output.
    std::vector<xla::XlaOp> token_inputs;
    for (const auto& node_name : output_node_token_inputs) {
      auto token_or = GetNodeToken(node_name);
      TF_RETURN_IF_ERROR(token_or.status());
      token_inputs.push_back(token_or.ValueOrDie());
    }
    token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs)));
  }
  TF_RETURN_IF_ERROR(PopNodeTokenMapping());

  int num_nonconst_outputs;
  int num_computation_outputs;
  result->computation = std::make_shared<xla::XlaComputation>();
  result->outputs.resize(context->retvals().size());
  std::vector<XlaExpression> retvals = context->retvals();
  ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
  TF_RETURN_IF_ERROR(BuildComputation(
      real_args, retvals, arg_shardings, retval_shardings, context->resources(),
      std::move(token_output),
      options.is_entry_computation ? options_.shape_representation_fn
                                   : ShapeRepresentationFn{},
      options.is_entry_computation,
      options.return_updated_values_for_all_resources,
      options.always_return_tuple, options.use_tuple_arg,
      options.alias_resource_update, &builder, result->computation.get(),
      &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
      &result->resource_updates, &result->xla_output_shape,
      result->input_mapping));

  VLOG(2) << "Outputs: total: " << context->retvals().size()
          << " nonconstant: " << num_nonconst_outputs;
  VLOG(2) << "XLA output shape: "
          << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape);
  result->collective_reduce_info = context->GetCollectiveReduceV2OpInfo();
  return Status::OK();
}
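
// Illustrative sketch (not used by this file): a typical caller configures an
// XlaCompiler, describes each graph input as an Argument, and then calls
// CompileGraph. The `graph` and `flib_def` values below are assumed to have
// been built elsewhere by the caller.
//
//   XlaCompiler::Options compiler_options;
//   compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
//   compiler_options.client = xla::ClientLibrary::LocalClientOrDie();
//   compiler_options.flib_def = flib_def.get();
//   XlaCompiler compiler(compiler_options);
//
//   std::vector<XlaCompiler::Argument> args(1);
//   args[0].kind = XlaCompiler::Argument::kParameter;
//   args[0].type = DT_FLOAT;
//   args[0].shape = TensorShape({2, 3});
//
//   XlaCompiler::CompilationResult result;
//   TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
//                                            "example", std::move(graph), args,
//                                            &result));
//   // result.computation now holds the generated xla::XlaComputation.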

// Returns the cached channel handle for `key`, creating one via the client on
// first use; repeated calls with the same key return the same handle. The
// host-to-device and device-to-host variants below behave the same way.
Status XlaCompiler::GetChannelHandle(const string& key,
                                     xla::ChannelHandle* channel) {
  auto result = channels_.emplace(key, xla::ChannelHandle());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
  }
  *channel = result.first->second;
  VLOG(1) << "Channel: " << key << " " << channel->DebugString();
  return Status::OK();
}

Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
                                                 xla::ChannelHandle* channel) {
  auto result = channels_.emplace(key, xla::ChannelHandle());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second,
                        client()->CreateHostToDeviceChannelHandle());
  }
  *channel = result.first->second;
  VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
  return Status::OK();
}

Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
                                                 xla::ChannelHandle* channel) {
  auto result = channels_.emplace(key, xla::ChannelHandle());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second,
                        client()->CreateDeviceToHostChannelHandle());
  }
  *channel = result.first->second;
  VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
  return Status::OK();
}

namespace {

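// Fills in `transfer` with `key` and one TensorMetadata entry (dtype and
// shape) per transferred tensor; `types` and `shapes` must have equal length.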
void SetTransfer(const string& key, absl::Span<const DataType> types,
                 absl::Span<const TensorShape> shapes,
                 tf2xla::HostTransferMetadata* transfer) {
  transfer->set_key(key);
  CHECK(types.size() == shapes.size());
  for (int i = 0, end = types.size(); i < end; ++i) {
    tf2xla::TensorMetadata* metadata = transfer->add_metadata();
    metadata->set_type(types[i]);
    shapes[i].AsProto(metadata->mutable_shape());
  }
}

}  // namespace

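// Registers the types and shapes sent from device to host under `key`.
// Repeated calls with the same key are allowed only if they register exactly
// the same metadata; otherwise an InvalidArgument error is returned.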
Status XlaCompiler::SetDeviceToHostMetadata(
    const string& key, absl::Span<const DataType> types,
    absl::Span<const TensorShape> shapes) {
  if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
    tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key];
    tf2xla::HostTransferMetadata new_transfer;
    SetTransfer(key, types, shapes, &new_transfer);
    if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
      return Status::OK();
    } else {
      return errors::InvalidArgument(
          "Duplicate calls to SetDeviceToHostMetadata with key ", key);
    }
  }
  tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
  SetTransfer(key, types, shapes, &transfer);
  return Status::OK();
}

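// Returns in `shapes` the tensor shapes previously registered for `key` via
// SetDeviceToHostMetadata.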
Status XlaCompiler::GetDeviceToHostShapes(
    const string& key, std::vector<TensorShape>* shapes) const {
  const auto iter = host_compute_sends_.find(key);
  if (iter == host_compute_sends_.end()) {
    return errors::InvalidArgument(
        "No host compute send shapes registered for key ", key);
  }
  shapes->clear();
  for (int i = 0; i < iter->second.metadata_size(); ++i) {
    TensorShape shape(iter->second.metadata(i).shape());
    shapes->push_back(shape);
  }
  return Status::OK();
}

Status XlaCompiler::SetHostToDeviceMetadata(
    const string& key, absl::Span<const DataType> types,
    absl::Span<const TensorShape> shapes) {
  if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
    tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key];
    tf2xla::HostTransferMetadata new_transfer;
    SetTransfer(key, types, shapes, &new_transfer);
    if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
      return Status::OK();
    } else {
      return errors::InvalidArgument(
          "Duplicate calls to SetHostToDeviceMetadata with key ", key);
    }
  }
  tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
  SetTransfer(key, types, shapes, &transfer);
  return Status::OK();
}

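// Returns in `handle` the control handle previously registered for the host
// compute op `host_compute_name` via SetHostComputeControlDependency.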
Status XlaCompiler::GetHostComputeControlDependency(
    const string& host_compute_name, xla::XlaOp* handle) {
  const auto iter = host_compute_control_output_.find(host_compute_name);
  if (iter == host_compute_control_output_.end()) {
    return errors::InvalidArgument(
        "No registered control handle for host compute Op '", host_compute_name,
        "'");
  } else {
    *handle = iter->second;
  }
  return Status::OK();
}

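// Records `handle` as the control output of the host compute op named
// `host_compute_name`. Each name may be registered at most once.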
Status XlaCompiler::SetHostComputeControlDependency(
    const string& host_compute_name, const xla::XlaOp& handle) {
  if (host_compute_control_output_.find(host_compute_name) !=
      host_compute_control_output_.end()) {
    return errors::InvalidArgument(
        "Duplicate control handles registered for host compute Op ",
        host_compute_name);
  }
  host_compute_control_output_[host_compute_name] = handle;
  return Status::OK();
}

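// Node-token mappings are kept as a stack of per-graph maps: CompileGraph
// pushes a fresh map before symbolically executing a graph and pops it again
// afterwards, so SetNodeToken and GetNodeToken always operate on the innermost
// graph being compiled.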
void XlaCompiler::PushNodeTokenMapping() {
  node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
}

Status XlaCompiler::PopNodeTokenMapping() {
  if (node_token_mapping_stack_.empty()) {
    return errors::FailedPrecondition(
        "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
        "empty.");
  }
  node_token_mapping_stack_.pop();
  return Status::OK();
}

Status XlaCompiler::SetNodeToken(const string& node_name,
                                 const xla::XlaOp& op) {
  if (node_token_mapping_stack_.empty()) {
    return errors::FailedPrecondition(
        "Calling SetNodeToken() when node_token_mapping_stack_ is "
        "empty.");
  }
  auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
  if (!insert_result.second) {
    return errors::FailedPrecondition("Token mapping already exists for node ",
                                      node_name);
  }
  return Status::OK();
}

StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
  if (node_token_mapping_stack_.empty()) {
    return errors::FailedPrecondition(
        "Calling GetNodeToken() when node_token_mapping_stack_ is "
        "empty.");
  }
  auto iter = node_token_mapping_stack_.top().find(node_name);
  if (iter == node_token_mapping_stack_.top().end()) {
    return errors::FailedPrecondition("Cannot find token mapping for node ",
                                      node_name);
  }
  return iter->second;
}

}  // namespace tensorflow