1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17
18 #include <numeric>
19 #include <vector>
20
21 #include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/types/variant.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/jit/flags.h"
27 #include "tensorflow/compiler/jit/shape_inference.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
29 #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
30 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
31 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
32 #include "tensorflow/compiler/tf2xla/shape_util.h"
33 #include "tensorflow/compiler/tf2xla/sharding_util.h"
34 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
35 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
36 #include "tensorflow/compiler/tf2xla/type_util.h"
37 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
38 #include "tensorflow/compiler/tf2xla/xla_context.h"
39 #include "tensorflow/compiler/xla/client/client_library.h"
40 #include "tensorflow/compiler/xla/client/xla_builder.h"
41 #include "tensorflow/compiler/xla/client/xla_computation.h"
42 #include "tensorflow/compiler/xla/protobuf_util.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/core/common_runtime/device.h"
46 #include "tensorflow/core/common_runtime/executor.h"
47 #include "tensorflow/core/common_runtime/function.h"
48 #include "tensorflow/core/common_runtime/graph_constructor.h"
49 #include "tensorflow/core/common_runtime/graph_optimizer.h"
50 #include "tensorflow/core/framework/attr_value_util.h"
51 #include "tensorflow/core/framework/function.h"
52 #include "tensorflow/core/framework/node_def_util.h"
53 #include "tensorflow/core/framework/types.h"
54 #include "tensorflow/core/graph/node_builder.h"
55 #include "tensorflow/core/lib/core/errors.h"
56 #include "tensorflow/core/lib/gtl/cleanup.h"
57 #include "tensorflow/core/lib/hash/hash.h"
58 #include "tensorflow/core/platform/logging.h"
59 #include "tensorflow/core/protobuf/error_codes.pb.h"
60 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
61 #include "tensorflow/core/util/dump_graph.h"
62
63 namespace tensorflow {
64 namespace {
65
66 // Checks that arguments `args` match types `types`.
67 Status CheckSignature(const DataTypeVector& types,
68 absl::Span<const XlaCompiler::Argument> args) {
69 if (args.size() != types.size()) {
70 return errors::Internal("Compilation arguments have ", args.size(),
71 " elements while function has ", types.size());
72 }
73 for (int i = 0, end = types.size(); i < end; ++i) {
74 // Don't perform type checks on resource variables and tensor
75 // lists (DT_VARIANT) as we have to trick the type system in order to
76 // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor.
77 if (types[i] != args[i].type && types[i] != DT_RESOURCE &&
78 types[i] != DT_VARIANT) {
79 return errors::Internal(
80 "Argument ", i, " has declared type ", DataTypeString(args[i].type),
81 " but function parameter has type ", DataTypeString(types[i]));
82 }
83 }
84 return Status::OK();
85 }
86
87 // Uses the _Arg and _Retval nodes in the graph to determine an OpSharding for
88 // each argument and return value.
89 xla::StatusOr<
90 std::pair<std::map<int, xla::OpSharding>, std::map<int, xla::OpSharding>>>
91 ComputeArgAndRetvalShardings(const Graph& graph) {
92 auto get_sharding_for_node =
93 [](const Node* n) -> xla::StatusOr<absl::optional<xla::OpSharding>> {
94 TF_ASSIGN_OR_RETURN(
95 auto sharding,
96 ParseShardingFromDevice(*n, std::numeric_limits<int32>::max(),
97 /*add_metadata=*/false));
98 return sharding;
99 };
100 std::map<int, xla::OpSharding> arg_shardings;
101 std::map<int, xla::OpSharding> retval_shardings;
102 for (const Node* n : graph.nodes()) {
103 if (n->IsArg()) {
104 TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
105 if (!sharding.has_value()) continue;
106 int index;
107 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
108 TF_RET_CHECK(index >= 0) << "Negative _Arg index";
109 arg_shardings[index] = std::move(*sharding);
110 } else if (n->IsRetval()) {
111 TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
112 if (!sharding.has_value()) continue;
113 int index;
114 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
115 TF_RET_CHECK(index >= 0) << "Negative _Retval index";
116 retval_shardings[index] = std::move(*sharding);
117 }
118 }
119 return std::make_pair(std::move(arg_shardings), std::move(retval_shardings));
120 }
121
122 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
123 XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
124 int64 step_id) {
125 // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
126 // resource manager takes ownership via Create, and unrefs via Cleanup. We
127 // explicitly add a reference to ensure the refcount at entry is maintained at
128 // all exit points; Create and Cleanup are always called in this function.
129 //
130 // The Executor requires us to use ScopedStepContainer. We wrap it in a
131 // unique_ptr so we can capture the cleanup status at the end.
132 xla_context->Ref();
133 Status status;
134 auto step_container = absl::make_unique<ScopedStepContainer>(
135 step_id, [&status, device](const string& name) {
136 status = device->resource_manager()->Cleanup(name);
137 });
138 TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(),
139 XlaContext::kXlaContextResourceName,
140 xla_context));
141
142 GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
143 TF_RETURN_IF_ERROR(graph_compiler.Compile());
144 // Explicitly clean up the step container, to capture the cleanup status.
145 step_container.reset();
146 return status;
147 }
148
149 // Builds the XLA computation.
150 // - `args` is the list of input arguments
151 // - `retvals` is the list of retvals produced by _Retval operators, in index
152 // order.
153 // - `arg_shardings` and `retval_shardings` are mapping from arg/return indices
154 // to sharding.
155 // - If `return_updated_values_for_all_resources` is true, all resources will be
156 // included in `resource_updates`, regardless of whether their value changed.
157 // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
158 // - Sets `*resource_updates` to a description of resources whose values are
159 // written by the computation; the variable writes are the last
160 // `resource_updates.size()` return values from the computation. Each entry in
161 // `resource_updates` is a ResourceUpdate, whose `index` is the index of a
162 // resource variable argument to the computation to be updated, and `type` is
163 // the type of the final output.
164 Status BuildComputation(
165 const std::vector<XlaCompiler::Argument>& args,
166 const std::vector<XlaExpression>& retvals,
167 const std::map<int, xla::OpSharding>& arg_shardings,
168 const std::map<int, xla::OpSharding>& retval_shardings,
169 const std::vector<std::unique_ptr<XlaResource>>& resources,
170 std::unique_ptr<xla::XlaOp> token_output,
171 const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
172 bool is_entry_computation, bool return_updated_values_for_all_resources,
173 bool always_return_tuple, bool use_tuple_arg, bool alias_resource_update,
174 xla::XlaBuilder* builder, xla::XlaComputation* computation,
175 int* num_computation_outputs, int* num_nonconst_outputs,
176 std::vector<XlaCompiler::OutputDescription>* outputs,
177 std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
178 xla::Shape* output_shape, absl::Span<int const> input_mapping) {
179 // Attach a common operator name as metadata. This has no semantic effect — it
180 // merely makes the HLO graph more readable when visualized via TensorBoard,
181 // since TensorBoard forms groups out of operators with similar names.
182 xla::OpMetadata retval_metadata;
183 retval_metadata.set_op_name("XLA_Retvals");
184 builder->SetOpMetadata(retval_metadata);
185 VLOG(1) << "Building new computation";
186 auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
187
188 // Builds a no-op XLA computation. We need to set the sharding of outputs, but
189 // cannot change the sharding of the existing output op. To do this, we build
190 // a new identity op to which shardings can be applied.
191 auto identity_op = [builder](xla::XlaOp op) {
192 return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
193 };
194
195 std::vector<xla::XlaOp> elems;
196 elems.reserve(retvals.size());
197
198 // Keeps track of sharding of each retval. If a retval is not in this list,
199 // replicate sharding is used. The first element is the output index, second
200 // element is the sharding.
201 std::unordered_map<int, xla::OpSharding> retval_index_and_sharding;
202 for (int i = 0, end = retvals.size(); i < end; ++i) {
203 XlaCompiler::OutputDescription& output = (*outputs)[i];
204 const XlaExpression& retval = retvals[i];
205 output.type = retval.dtype();
206 switch (retval.kind()) {
207 case XlaExpression::Kind::kConstant:
208 output.is_constant = true;
209 output.constant_value = *retval.constant_value();
210 output.shape = output.constant_value.shape();
211 break;
212
213 case XlaExpression::Kind::kTensorList: {
214 output.is_tensor_list = true;
215 xla::XlaOp value = retval.handle();
216 elems.push_back(value);
217 break;
218 }
219
220 case XlaExpression::Kind::kXlaOp: {
221 output.is_constant = false;
222 TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
223 xla::XlaOp value = retval.handle();
224 auto it = retval_shardings.find(i);
225 absl::optional<xla::OpSharding> sharding =
226 it == retval_shardings.end() ? absl::optional<xla::OpSharding>()
227 : it->second;
228 if (it != retval_shardings.end()) {
229 retval_index_and_sharding[elems.size()] = it->second;
230 }
231 if (shape_representation_fn) {
232 TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value));
233 TF_ASSIGN_OR_RETURN(value,
234 ReshapeWithCorrectRepresentationAndSharding(
235 builder, value, original_shape,
236 shape_representation_fn, sharding,
237 /*fast_mem=*/false));
238 }
239 if (it != retval_shardings.end()) {
240 xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
241 // Apply the sharding to the output, if there is a core assignment.
242 value = identity_op(value);
243 }
244
245 elems.push_back(value);
246 break;
247 }
248
249 case XlaExpression::Kind::kResource:
250 // Resources will be pushed into elems later when processing resource
251 // arguments below.
252 output.is_constant = false;
253 output.input_index = retval.resource()->arg_num();
254 output.shape = retval.resource()->shape();
255 break;
256
257 case XlaExpression::Kind::kInvalid:
258 return errors::InvalidArgument(
259 "Invalid expression returned by computation. "
260 "This probably means a return value was not set.");
261 }
262 }
263 *num_nonconst_outputs = elems.size();
264
265 // Add return values for resources whose values have changed.
266 std::vector<const XlaResource*> arg_resources;
267 arg_resources.reserve(resources.size());
268 for (const auto& resource : resources) {
269 if (resource->arg_num() >= 0) {
270 arg_resources.push_back(resource.get());
271 }
272 }
273 std::sort(arg_resources.begin(), arg_resources.end(),
274 [](const XlaResource* a, const XlaResource* b) {
275 return a->arg_num() < b->arg_num();
276 });
277
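// Invert input_mapping: record, for each original argument index, the
// position of the corresponding XLA runtime parameter.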
278 absl::flat_hash_map<int, int> argument_to_xla_arg;
279 for (int xla_arg = 0; xla_arg < input_mapping.size(); xla_arg++) {
280 argument_to_xla_arg[input_mapping[xla_arg]] = xla_arg;
281 }
282
283 std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
284 for (const XlaResource* resource : arg_resources) {
285 DCHECK_LT(resource->arg_num(), args.size());
286 const XlaCompiler::Argument& arg = args[resource->arg_num()];
287 auto it = arg_shardings.find(resource->arg_num());
288 bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
289 // TensorArray gradients were modified if their values changed or there are
290 // any newly created gradients.
291 for (const auto& grad : resource->tensor_array_gradients()) {
292 modified =
293 modified ||
294 !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
295 arg.tensor_array_gradients.count(grad.first) == 0;
296 }
297
298 if (return_updated_values_for_all_resources || modified) {
299 resource_updates->emplace_back();
300 XlaCompiler::ResourceUpdate& update = resource_updates->back();
301 update.input_index = resource->arg_num();
302 update.type = resource->type();
303 update.shape = resource->shape();
304 update.modified = modified;
305 int param_num = use_tuple_arg ? 0 : update.input_index;
306 if (is_entry_computation &&
307 arg.resource_kind != XlaResource::kTensorArray &&
308 alias_resource_update && argument_to_xla_arg.count(param_num)) {
309 // Assuming tuple arg and results are used.
310 xla::ShapeIndex param_index =
311 use_tuple_arg ? xla::ShapeIndex({update.input_index})
312 : xla::ShapeIndex{};
313 int xla_param_num = argument_to_xla_arg[param_num];
314 int64 output_index_num = elems.size();
315 xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
316 VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
317 << xla_param_num << ", " << param_index.ToString() << ")";
318 aliases.push_back({output_index, xla_param_num, param_index});
319 }
320 for (const auto& grad : resource->tensor_array_gradients()) {
321 update.tensor_array_gradients_accessed.insert(grad.first);
322 }
323
324 xla::XlaOp handle;
325 TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
326 auto sharding = it == arg_shardings.end()
327 ? absl::optional<xla::OpSharding>()
328 : it->second;
329 // Set layout of the retval to device representation layout.
330 if (shape_representation_fn) {
331 TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle));
332 TF_ASSIGN_OR_RETURN(
333 handle, ReshapeWithCorrectRepresentationAndSharding(
334 builder, handle, original_shape,
335 shape_representation_fn, sharding, arg.fast_mem));
336 }
337
338 // Request that the value be returned on a specific core.
339 xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
340 if (it != arg_shardings.end()) {
341 retval_index_and_sharding[elems.size()] = it->second;
342 }
343 // Ensures the correct sharding is applied to the output.
344 handle = identity_op(handle);
345 elems.push_back(handle);
346 }
347 }
348
349 // If we have token output, append it as the last one.
350 if (token_output) {
351 elems.push_back(*token_output);
352 }
353
354 *num_computation_outputs = elems.size();
355
356 // Builds the XLA computation. We *always* form a tuple here to ensure that
357 // the output value is the last thing added into the XLA computation, even
358 // if there is only one output value.
359 xla::XlaOp tuple;
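// When the entry computation has sharded return values, build an explicit
// tuple sharding from the per-element shardings before emitting the tuple;
// otherwise a plain tuple is sufficient.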
360 if (retval_index_and_sharding.empty() || !is_entry_computation) {
361 tuple = xla::Tuple(builder, elems);
362 } else {
363 std::vector<xla::Shape> elem_shapes;
364 for (const auto& elem : elems) {
365 TF_ASSIGN_OR_RETURN(xla::Shape elem_shape,
366 elem.builder()->GetShape(elem));
367 elem_shapes.push_back(elem_shape);
368 }
369 xla::Shape shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
370 // Copy specified sharding from retval_index_and_sharding.
371 std::vector<xla::HloSharding> sharding_elems;
372 for (int i = 0, end = elems.size(); i < end; i++) {
373 const auto& iter = retval_index_and_sharding.find(i);
374 TF_RET_CHECK(iter != retval_index_and_sharding.end());
375 const xla::OpSharding& sub_op_sharding = iter->second;
376 TF_ASSIGN_OR_RETURN(xla::HloSharding sub_sharding,
377 xla::HloSharding::FromProto(sub_op_sharding));
378 if (elem_shapes[i].IsTuple()) {
379 const std::vector<xla::HloSharding> sub_sharding_elems =
380 sub_sharding.tuple_elements();
381 const int64 sub_sharding_elems_size = sub_sharding_elems.size();
382 TF_RET_CHECK(sub_sharding_elems_size ==
383 xla::ShapeUtil::GetLeafCount(elem_shapes[i]));
384 for (const auto& sub_sharding_elem : sub_sharding_elems) {
385 sharding_elems.push_back(sub_sharding_elem);
386 }
387 } else {
388 sharding_elems.push_back(sub_sharding);
389 }
390 }
391 xla::HloSharding modified_sharding =
392 xla::HloSharding::Tuple(shape, sharding_elems);
393 xla::OpSharding op_sharding = modified_sharding.ToProto();
394 // Assign proper sharding to the tuple instruction.
395 xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
396 tuple = xla::Tuple(builder, elems);
397 }
398 bool returns_tuple = always_return_tuple || elems.size() != 1;
399 VLOG(3) << "Computation returns a tuple=" << returns_tuple;
400 if (!returns_tuple) {
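// A single return value is unwrapped with GetTupleElement so that the value
// itself, being the last instruction added, becomes the computation root
// instead of a one-element tuple.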
401 xla::GetTupleElement(tuple, 0);
402
403 for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
404 if (alias.output_index == xla::ShapeIndex({0})) {
405 VLOG(3) << "For aliased parameter " << alias.param_number << ": "
406 << alias.param_index.ToString()
407 << " normalizing output_index from {0} to {}, as a scalar is "
408 "returned from the cluster";
409 alias.output_index = xla::ShapeIndex({});
410 }
411 }
412 }
413
414 for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
415 builder->SetUpAlias(alias.output_index, alias.param_number,
416 alias.param_index);
417 }
418
419 xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
420 if (!computation_status.ok()) {
421 return computation_status.status();
422 }
423 *computation = computation_status.ConsumeValueOrDie();
424
425 TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape());
426 *output_shape = program_shape.result();
427 return Status::OK();
428 }
429
430 } // namespace
431
432
433 string XlaCompiler::Argument::HumanString() const {
434 string common;
435 if (!name.empty()) {
436 common = absl::StrCat(" name=", name);
437 }
438 absl::StrAppend(&common, " type=", DataTypeString(type),
439 " shape=", ShapeHumanString());
440 absl::StrAppend(
441 &common, " is_same_data_across_replicas=", is_same_data_across_replicas);
442 switch (kind) {
443 case kInvalid:
444 return "invalid";
445 case kConstant:
446 return absl::StrCat("kind=constant", common,
447 " value=", constant_value.DebugString());
448 case kConstantResource:
449 return absl::StrCat("kind=constant-resource", common,
450 " value=", constant_value.DebugString());
451 case kResource: {
452 string output = absl::StrCat(
453 "kind=resource", common,
454 " resource_kind=", XlaResource::KindToString(resource_kind),
455 " initialized=", initialized, " is_fast_mem=", fast_mem);
456 if (max_array_size >= 0) {
457 absl::StrAppend(&output, " max_array_size=", max_array_size);
458 }
459 if (!tensor_array_gradients.empty()) {
460 absl::StrAppend(&output, " tensor_array_gradients=",
461 absl::StrJoin(tensor_array_gradients, ","));
462 }
463 return output;
464 }
465 case kParameter:
466 return absl::StrCat("kind=parameter", common);
467 case kTensorList:
468 return absl::StrCat("kind=tensorlist", common);
469 case kToken:
470 return absl::StrCat("token", common);
471 }
472 }
473
474 std::vector<int64> XlaCompiler::Argument::DimensionSizes() const {
475 if (absl::holds_alternative<TensorShape>(shape)) {
476 return xla::InlinedVectorToVector(
477 absl::get<TensorShape>(shape).dim_sizes());
478 } else {
479 return xla::SpanToVector(absl::get<xla::Shape>(shape).dimensions());
480 }
481 }
482
483 absl::InlinedVector<int64, 4>
484 XlaCompiler::Argument::DimensionSizesAsInlinedVector() const {
485 if (absl::holds_alternative<TensorShape>(shape)) {
486 return absl::get<TensorShape>(shape).dim_sizes();
487 } else {
488 auto v = absl::get<xla::Shape>(shape).dimensions();
489 return absl::InlinedVector<int64, 4>(v.begin(), v.end());
490 }
491 }
492
493 string XlaCompiler::Argument::ShapeHumanString() const {
494 if (absl::holds_alternative<TensorShape>(shape)) {
495 return absl::get<TensorShape>(shape).DebugString();
496 } else {
497 return absl::get<xla::Shape>(shape).DebugString();
498 }
499 }
500
501 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
502 : options_(options),
503 initialization_status_(Status::OK()),
504 next_step_id_(1),
505 device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
506 device_mgr_(absl::WrapUnique(device_)) {
507 CHECK(!options_.device_type.type_string().empty());
508 if (options_.populate_resource_manager) {
509 initialization_status_ =
510 (*options_.populate_resource_manager)(device_->resource_manager());
511 }
512
513 local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
514 FunctionDefLibrary{}));
515 local_pflr_.reset(new ProcessFunctionLibraryRuntime(
516 &device_mgr_, Env::Default(), /*config=*/nullptr,
517 options.graph_def_version, local_flib_def_.get(), OptimizerOptions()));
518 pflr_.reset(new ProcessFunctionLibraryRuntime(
519 &device_mgr_, Env::Default(), /*config=*/nullptr,
520 options.graph_def_version, options.flib_def, OptimizerOptions()));
521
522 local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
523 flib_runtime_ = pflr_->GetFLR(device_->name());
524
525 // The default shape representation function is the identity.
526 if (!options_.shape_representation_fn) {
527 options_.shape_representation_fn = IdentityShapeRepresentationFn();
528 }
529 }
530
531 XlaCompiler::~XlaCompiler() = default;
532
533 int64 XlaCompiler::NextStepId() { return next_step_id_++; }
534
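// Hashing only the canonicalized function name is sufficient: the cache's key
// equality still compares the full (name, arguments) signature, so collisions
// merely cost extra comparisons.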
535 uint64 XlaCompiler::SignatureHash::operator()(
536 const std::pair<string, std::vector<Argument>>& signature) const {
537 return std::hash<string>()(signature.first);
538 }
539
540 static Status GetFunctionBody(const NameAttrList& function,
541 FunctionLibraryRuntime* flib_runtime,
542 const FunctionBody** fbody) {
543 FunctionLibraryRuntime::Handle handle;
544 TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
545 function.name(), AttrSlice(&function.attr()), &handle));
546
547 *fbody = flib_runtime->GetFunctionBody(handle);
548 TF_RET_CHECK(*fbody);
549 return Status::OK();
550 }
551
552 Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
553 const FunctionBody** fbody,
554 const ConfigProto** config_proto) {
555 // The function may be in either the local_flib_runtime_ or flib_runtime_.
556 // Look up the function in local_flib_runtime_ first; if it is not found
557 // there, look it up in flib_runtime_.
558 auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
559 if (!status.ok()) {
560 if (!errors::IsNotFound(status)) {
561 return status;
562 }
563 TF_RETURN_WITH_CONTEXT_IF_ERROR(
564 GetFunctionBody(function, flib_runtime_, fbody),
565 "Local lookup failed with: ", status.error_message());
566 if (config_proto) {
567 *config_proto = flib_runtime_->config_proto();
568 }
569 VLOG(4) << "Function " << function.name() << " in flib_runtime_";
570 } else {
571 if (config_proto) {
572 *config_proto = local_flib_runtime_->config_proto();
573 }
574 VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
575 }
576 return Status::OK();
577 }
578
579 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
580 std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
581 CopyGraph(*fbody->graph, graph.get());
582
583 bool is_inside_mustcompile = false;
584 TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
585 &is_inside_mustcompile);
586
587 // Performs a first function inlining pass before shape inference, since
588 // otherwise shape inference can't see inside functions and a comprehensive
589 // shape_map, including function ops, is needed to constant-propagate Shape
590 // Ops below.
591 auto flags = GetBuildXlaOpsPassFlags();
592 OptimizerOptions opts;
593 opts.set_opt_level(OptimizerOptions::L0);
594 opts.set_do_common_subexpression_elimination(false);
595 opts.set_do_function_inlining(true);
596 opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
597 GraphOptimizer optimizer(opts);
598 // Do not constant fold nodes that output DT_VARIANT type tensors.
599 // XLA does not support Const nodes of Variant type since it needs
600 // to know the original ops to be able to compile them to the relevant
601 // XLA form.
602 // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
603 // the form:
604 // Const
605 // |
606 // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
607 // |
608 // (Discard popped list)
609 //
610 // Would have been reduced to "Const -> Op" without this filter.
611 // However since we are only allowed to specify the filter at the "Node"
612 // level there is no good way to allow the above behavior. So we
613 // disallow any sort of constant folding on Variant nodes for now.
614 //
615 // Also do not consider constant folding Shape ops. When there is a dynamic
616 // dimension in a tensor, TF2XLA currently represents it as its static
617 // upper-bound shape, which could be constant folded away and lose the
618 // information that the shape is dynamic.
619 auto cf_consider_fn = [](const Node* n) {
620 for (const auto& output_arg : n->op_def().output_arg()) {
621 if (output_arg.type() == DT_VARIANT) {
622 return false;
623 }
624 }
625 const auto& ts = n->type_string();
626 // XLA has special logic to handle dynamic shapes, don't constant fold
627 // them.
628 if (ts == "Shape" || ts == "ShapeN" || ts == "Size") {
629 return false;
630 }
631 return true;
632 };
633 GraphOptimizer::Options graph_optimizer_options;
634 graph_optimizer_options.cf_consider_fn = cf_consider_fn;
635 graph_optimizer_options.inline_multi_device_functions = true;
636 graph_optimizer_options.inline_impl_selection_group_functions = true;
637 graph_optimizer_options.inline_with_single_device_body_placer = true;
638 graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
639
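// First pass: run shape inference, hand the inferred shapes to the optimizer,
// and inline/constant-fold. A second, identical pass below re-runs shape
// inference so it can see into the bodies exposed by this round of inlining.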
640 {
641 GraphShapeInfo shape_info;
642 InferShapes(graph.get(), /*arg_shapes=*/{},
643 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
644 .IgnoreError();
645 auto node_name_index = graph->BuildNodeNameIndex();
646 std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
647 for (const auto& node_shape_info : shape_info) {
648 const string& node_name = node_shape_info.first;
649 const std::vector<InferredShape>& output_shapes = node_shape_info.second;
650 const auto& node_iter = node_name_index.find(node_name);
651 if (node_iter != node_name_index.end()) {
652 auto& partial_shapes = shape_map[node_name];
653 for (const auto& inferred_shape : output_shapes) {
654 partial_shapes.push_back(inferred_shape.shape);
655 }
656 }
657 }
658 graph_optimizer_options.shape_map = &shape_map;
659 optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
660 /*device=*/nullptr, &graph, graph_optimizer_options);
661 }
662
663 // Run shape inference on the graph and optimize the graph again.
664 GraphShapeInfo shape_info;
665 InferShapes(graph.get(), /*arg_shapes=*/{},
666 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
667 .IgnoreError();
668 auto node_name_index = graph->BuildNodeNameIndex();
669 std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
670 for (const auto& node_shape_info : shape_info) {
671 const string& node_name = node_shape_info.first;
672 const std::vector<InferredShape>& output_shapes = node_shape_info.second;
673 const auto& node_iter = node_name_index.find(node_name);
674 if (node_iter != node_name_index.end()) {
675 auto& partial_shapes = shape_map[node_name];
676 for (const auto& inferred_shape : output_shapes) {
677 partial_shapes.push_back(inferred_shape.shape);
678 }
679 }
680 }
681 graph_optimizer_options.shape_map = &shape_map;
682 optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
683 /*device=*/nullptr, &graph, graph_optimizer_options);
684
685 return graph;
686 }
687
688 // Collects all control rets from `orig_control_ret_nodes` that are still valid,
689 // keeping the same order.
690 std::vector<std::string> GetValidControlRets(
691 absl::Span<Node* const> orig_control_ret_nodes, const Graph& graph) {
692 // Build map from control ret node to index.
693 absl::flat_hash_map<const Node*, int> control_ret_nodes_map;
694 for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
695 const Node* n = orig_control_ret_nodes[i];
696 control_ret_nodes_map[n] = i;
697 }
698 // Check which control rets are still valid.
699 std::vector<bool> is_valid_control_ret(orig_control_ret_nodes.size(), false);
700 int num_valid_control_rets = 0;
701 for (const Node* n : graph.nodes()) {
702 auto iter = control_ret_nodes_map.find(n);
703 if (iter != control_ret_nodes_map.end()) {
704 ++num_valid_control_rets;
705 is_valid_control_ret[iter->second] = true;
706 }
707 }
708 // Return valid control rets in same order as they appear in
709 // `orig_control_ret_nodes`.
710 std::vector<std::string> valid_control_rets;
711 valid_control_rets.reserve(num_valid_control_rets);
712 for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
713 if (is_valid_control_ret[i]) {
714 valid_control_rets.push_back(orig_control_ret_nodes[i]->name());
715 }
716 }
717 return valid_control_rets;
718 }
719
720 Status XlaCompiler::CompileFunction(
721 const XlaCompiler::CompileOptions& options,
722 const NameAttrList& fn_name_attrs,
723 absl::Span<const XlaCompiler::Argument> args,
724 XlaCompiler::CompilationResult* result) {
725 const string function_id =
726 Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr()));
727 VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
728
729 const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
730 auto it = cache_.find({function_id, arg_vector});
731 if (it != cache_.end()) {
732 *result = it->second;
733 return Status::OK();
734 }
735
736 const FunctionBody* fbody;
737 const ConfigProto* config = nullptr;
738 TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody, &config));
739
740 absl::optional<ConfigProto> config_proto;
741 if (config) {
742 config_proto = *config;
743 }
744
745 TF_RETURN_WITH_CONTEXT_IF_ERROR(
746 CheckSignature(fbody->arg_types, args),
747 "Signature check failure while compiling: ", fn_name_attrs.name());
748
749 // Set shapes for _Arg nodes. They are useful for constant folding (e.g. an
750 // XLA op requires a compile-time constant input, and that input is the shape
751 // of an _Arg node).
752 for (int i = 0, end = args.size(); i < end; i++) {
753 // Skip resource variables and tensor lists.
754 DataType dtype;
755 TF_RETURN_IF_ERROR(GetNodeAttr(fbody->arg_nodes[i]->def(), "T", &dtype));
756 if (dtype == DT_RESOURCE || dtype == DT_VARIANT) {
757 continue;
758 }
759
760 if (absl::holds_alternative<xla::Shape>(args[i].shape)) {
761 xla::Shape xla_shape = absl::get<xla::Shape>(args[i].shape);
762 TensorShape tensor_shape;
763 // If xla_shape is dynamic, prevent constant folding by not setting
764 // output_shapes.
765 if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() &&
766 xla_shape.is_static()) {
767 fbody->arg_nodes[i]->ClearAttr("_output_shapes");
768 fbody->arg_nodes[i]->AddAttr("_output_shapes",
769 std::vector<TensorShape>{tensor_shape});
770 }
771 } else {
772 TensorShape tensor_shape = absl::get<TensorShape>(args[i].shape);
773 fbody->arg_nodes[i]->ClearAttr("_output_shapes");
774 fbody->arg_nodes[i]->AddAttr("_output_shapes",
775 std::vector<TensorShape>{tensor_shape});
776 }
777 }
778
779 std::unique_ptr<Graph> graph = GetGraph(fbody);
780
781 // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
782 // they are added when the function body is looked up. Therefore, they don't
783 // have core assignments here.
784 // Attempt to assign a core to each _Retval and _Arg. Chooses the
785 // lowest-numbered core that consumes the argument. We choose the
786 // lowest-numbered core so the assignment is deterministic.
787 for (Node* n : graph->nodes()) {
788 if (n->IsArg()) {
789 TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
790 }
791 }
792 // Do _Retval as a second loop, in case the retval's input is an _Arg (which
793 // may have gotten a device assignment from the first loop).
794 for (Node* n : graph->nodes()) {
795 if (n->IsRetval()) {
796 TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
797 }
798 }
799
800 if (VLOG_IS_ON(2)) {
801 VLOG(2) << "XlaCompiler::CompileFunction: "
802 << DumpGraphToFile(
803 absl::StrCat("xla_compile_function_", function_id), *graph);
804 }
805
806 VLOG(1) << "====================================================";
807 MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
808 *graph, config_proto,
809 /*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
810 if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
811 VLOG(1) << "Using MLIR bridge";
812 GraphDebugInfo debug_info;
813
814 std::vector<std::string> valid_control_rets =
815 GetValidControlRets(fbody->control_ret_nodes, *graph);
816
817 TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
818 std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
819 valid_control_rets, options_.device_type.type_string(),
820 options.use_tuple_arg, *options_.flib_def, debug_info,
821 options_.shape_representation_fn, result));
822 } else {
823 TF_RETURN_IF_ERROR(
824 CompileGraph(options, function_id, std::move(graph), args, result));
825 }
826 VLOG(1) << "====================================================";
827
828 cache_[{function_id, arg_vector}] = *result;
829 return Status::OK();
830 }
831
832 // Computes the XLA shape for argument 'arg'.
833 Status XlaCompiler::XLAShapeForArgument(
834 const XlaCompiler::Argument& arg, bool is_entry_computation,
835 const absl::optional<xla::HloSharding>& arg_sharding,
836 xla::Shape* xla_shape) const {
837 switch (arg.kind) {
838 case XlaCompiler::Argument::kConstant:
839 LOG(FATAL) << "Unreachable case";
840 case XlaCompiler::Argument::kParameter: {
841 if (is_entry_computation) {
842 TensorShape shape;
843 if (absl::holds_alternative<TensorShape>(arg.shape)) {
844 shape = absl::get<TensorShape>(arg.shape);
845 } else {
846 TF_RETURN_IF_ERROR(
847 XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &shape));
848 }
849 TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
850 shape, arg.type,
851 /*use_fast_memory=*/false));
852 TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
853 arg_sharding, /*use_fast_memory=*/false,
854 options_.shape_representation_fn, xla_shape));
855 } else {
856 if (absl::holds_alternative<xla::Shape>(arg.shape)) {
857 *xla_shape = absl::get<xla::Shape>(arg.shape);
858 } else {
859 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
860 arg.type, absl::get<TensorShape>(arg.shape), xla_shape));
861 }
862 }
863 return Status::OK();
864 }
865 case XlaCompiler::Argument::kTensorList: {
866 TF_RET_CHECK(absl::holds_alternative<xla::Shape>(arg.shape));
867 *xla_shape = absl::get<xla::Shape>(arg.shape);
868 return Status::OK();
869 }
870 case XlaCompiler::Argument::kConstantResource:
871 case XlaCompiler::Argument::kResource: {
872 TF_RET_CHECK(arg.initialized);
873
874 switch (arg.resource_kind) {
875 case XlaResource::kVariable: {
876 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
877 TF_ASSIGN_OR_RETURN(*xla_shape,
878 options_.shape_representation_fn(
879 absl::get<TensorShape>(arg.shape), arg.type,
880 /*use_fast_memory=*/arg.fast_mem));
881 TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
882 arg_sharding, arg.fast_mem, options_.shape_representation_fn,
883 xla_shape));
884 return Status::OK();
885 }
886 case XlaResource::kTensorArray: {
887 if (arg.max_array_size < 0) {
888 return errors::InvalidArgument(
889 "Negative max_array_size in XLAShapeForArgument");
890 }
891 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
892 TensorShape shape;
893 shape.AddDim(arg.max_array_size);
894 shape.AppendShape(absl::get<TensorShape>(arg.shape));
895 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
896
897 if (!arg.tensor_array_gradients.empty()) {
898 std::vector<xla::Shape> tuple_shape(
899 arg.tensor_array_gradients.size() + 1, *xla_shape);
900 *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
901 }
902 return Status::OK();
903 }
904 case XlaResource::kStack: {
905 if (arg.max_array_size < 0) {
906 return errors::InvalidArgument(
907 "Negative max_array_size in XLAShapeForArgument");
908 }
909 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
910 TensorShape shape;
911 shape.AddDim(arg.max_array_size);
912 shape.AppendShape(absl::get<TensorShape>(arg.shape));
913 xla::Shape buffer_shape;
914 TF_RETURN_IF_ERROR(
915 TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
916 *xla_shape = xla::ShapeUtil::MakeTupleShape(
917 {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
918 return Status::OK();
919 }
920
921 case XlaResource::kInvalid:
922 return errors::Internal(
923 "Invalid resource type in XLAShapeForArgument()");
924 }
925 }
926 case XlaCompiler::Argument::kToken: {
927 *xla_shape = xla::ShapeUtil::MakeTokenShape();
928 return Status::OK();
929 }
930 case XlaCompiler::Argument::kInvalid:
931 return errors::Internal("Invalid argument type in XLAShapeForArgument()");
932 }
933 }
934
935 /* static */
936 void XlaCompiler::PopulateArgumentFromResource(const XlaResource& resource,
937 Argument* arg) {
938 arg->initialized = resource.initialized();
939 arg->kind = XlaCompiler::Argument::kResource;
940 arg->resource_kind = resource.kind();
941
942 arg->type = resource.type();
943 arg->shape = resource.shape();
944 arg->max_array_size = resource.max_array_size();
945 for (const auto& gradient : resource.tensor_array_gradients()) {
946 arg->tensor_array_gradients.insert(gradient.first);
947 }
948 arg->name = resource.name();
949 }
950
951 // Builds XLA computations for each of the arguments to the computation.
952 // `args` are the arguments to the computation.
953 Status XlaCompiler::BuildArguments(
954 const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
955 bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
956 const std::map<int, xla::OpSharding>& arg_shardings,
957 std::vector<XlaExpression>* arg_expressions,
958 std::vector<int>* input_to_args, std::vector<xla::Shape>* input_shapes,
959 bool is_entry_computation) {
960 arg_expressions->resize(args.size());
961
962 // Argument numbers of arguments and resources that are to be passed to the
963 // XLA computation as runtime parameters. `input_to_args[a] = b` means that
964 // the a'th XLA input corresponds to the b'th original arg index.
965 input_to_args->clear();
966 input_to_args->reserve(args.size());
967
968 // Fills in constant arguments, and computes non-constant argument order.
969 for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
970 ++i) {
971 const XlaCompiler::Argument& arg = args[i];
972 XlaExpression& arg_expression = (*arg_expressions)[i];
973 switch (arg.kind) {
974 case XlaCompiler::Argument::kConstantResource:
975 case XlaCompiler::Argument::kResource: {
976 TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
977 TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
978 // TODO(phawkins): this code assumes that resource arguments do not
979 // alias.
980 XlaResource* resource =
981 context->AddResource(absl::make_unique<XlaResource>(
982 arg.resource_kind, i, arg.name, arg.type,
983 absl::get<TensorShape>(arg.shape), xla::XlaOp(),
984 /*max_array_size=*/arg.max_array_size,
985 /*tensor_array_gradients=*/arg.tensor_array_gradients,
986 /*tensor_array_multiple_writes_aggregate=*/true));
987 arg_expression =
988 arg.kind == XlaCompiler::Argument::kResource
989 ? XlaExpression::Resource(resource)
990 : XlaExpression::ConstantResource(arg.constant_value, resource);
991 if (arg.initialized) {
992 input_to_args->push_back(i);
993 }
994 break;
995 }
996 case XlaCompiler::Argument::kParameter:
997 case XlaCompiler::Argument::kTensorList:
998 case XlaCompiler::Argument::kToken: {
999 input_to_args->push_back(i);
1000 break;
1001 }
1002 case XlaCompiler::Argument::kConstant:
1003 arg_expression = XlaExpression::Constant(arg.constant_value);
1004 break;
1005 case XlaCompiler::Argument::kInvalid:
1006 return errors::Internal(
1007 "Unreachable case in BuildArguments() while filling constant args");
1008 }
1009 }
1010
1011 if (input_to_args->empty() && !use_tuple_arg) {
1012 return Status::OK();
1013 }
1014
1015 // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds
1016 // to the d'th XLA input. Note that the value -1 corresponds to constants, or
1017 // other args that don't correspond to an input.
1018 std::vector<int> arg_to_inputs(args.size(), -1);
1019 for (int i = 0, end = input_to_args->size(); i < end; i++) {
1020 arg_to_inputs[input_to_args->at(i)] = i;
1021 }
1022
1023 std::vector<xla::Shape> arg_shapes(input_to_args->size());
1024 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1025 // Computes the shapes of non-constant arguments.
1026 auto arg_sharding = arg_shardings.find((*input_to_args)[i]);
1027 absl::optional<xla::HloSharding> sharding;
1028 if (arg_sharding != arg_shardings.end()) {
1029 TF_ASSIGN_OR_RETURN(auto hlo_sharding,
1030 xla::HloSharding::FromProto(arg_sharding->second));
1031 sharding = hlo_sharding;
1032 }
1033 TF_RETURN_IF_ERROR(XLAShapeForArgument(args[(*input_to_args)[i]],
1034 is_entry_computation, sharding,
1035 &arg_shapes[i]));
1036 }
1037
1038 if (use_tuple_arg) {
1039 input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
1040 } else {
1041 *input_shapes = arg_shapes;
1042 }
1043
1044 // Attach a common operator name as metadata. This has no semantic effect — it
1045 // merely makes the HLO graph more readable when visualized via TensorBoard,
1046 // since TensorBoard forms groups out of operators with similar names.
1047 xla::OpMetadata arg_metadata;
1048 arg_metadata.set_op_name("XLA_Args");
1049 builder->SetOpMetadata(arg_metadata);
1050
1051 // Build parameter handles for non-constant arguments.
1052 std::vector<xla::XlaOp> arg_handles(input_to_args->size());
1053 if (use_tuple_arg) {
1054 xla::XlaOp tuple;
1055 if (is_entry_computation) {
1056 xla::OpSharding tuple_sharding;
1057 tuple_sharding.set_type(xla::OpSharding::TUPLE);
1058 for (int64 parameter : *input_to_args) {
1059 auto it = arg_shardings.find(parameter);
1060 *tuple_sharding.add_tuple_shardings() =
1061 it == arg_shardings.end() ? xla::sharding_builder::AssignDevice(0)
1062 : it->second;
1063 }
1064 std::vector<bool> is_same_across_replicas;
1065 for (int i = 0, end = input_to_args->size(); i < end; ++i) {
1066 // Add an entry to is_same_across_replicas for every leaf buffer.
1067 is_same_across_replicas.insert(
1068 is_same_across_replicas.end(),
1069 xla::ShapeUtil::GetLeafCount(arg_shapes[i]),
1070 args[input_to_args->at(i)].is_same_data_across_replicas);
1071 }
1072 xla::XlaScopedShardingAssignment assign_tuple_sharding(
1073 builder, input_to_args->empty() ? absl::optional<xla::OpSharding>()
1074 : tuple_sharding);
1075 tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple",
1076 is_same_across_replicas);
1077 } else {
1078 tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
1079 }
1080
1081 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1082 auto it = arg_shardings.find(i);
1083 xla::XlaScopedShardingAssignment assign_sharding(
1084 builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
1085 : it->second);
1086 auto& arg = args[input_to_args->at(i)];
1087
1088 xla::OpMetadata arg_metadata;
1089 arg_metadata.set_op_name(arg.node_name);
1090 builder->SetOneShotOpMetadata(arg_metadata);
1091 arg_handles[i] = xla::GetTupleElement(tuple, i);
1092 }
1093 } else {
1094 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1095 auto it = arg_shardings.find(i);
1096 xla::XlaScopedShardingAssignment assign_sharding(
1097 builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
1098 : it->second);
1099 if (is_entry_computation) {
1100 // Add an entry to is_same_across_replicas for every leaf buffer.
1101 std::vector<bool> is_same_across_replicas(
1102 xla::ShapeUtil::GetLeafCount((*input_shapes)[i]),
1103 args[input_to_args->at(i)].is_same_data_across_replicas);
1104 arg_handles[i] =
1105 xla::Parameter(builder, i, (*input_shapes)[i],
1106 absl::StrCat("arg", i), is_same_across_replicas);
1107 } else {
1108 arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
1109 absl::StrCat("arg", i));
1110 }
1111 }
1112 }
1113
1114 for (int i = 0, end = input_to_args->size(); i < end; ++i) {
1115 const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
1116 for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
1117 int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
1118 VLOG(1) << "Setting dynamic size " << i << " -> "
1119 << dynamic_size_param_index;
1120 arg_handles[i] = xla::SetDimensionSize(
1121 arg_handles[i], arg_handles[dynamic_size_param_index],
1122 dim_and_arg_num.first);
1123 }
1124 }
1125
1126 builder->ClearOpMetadata();
1127
1128 // Fill in the handles in non-constant arguments, and reshape parameters
1129 // back to their correct shapes.
1130 VLOG(2) << "XLA computation inputs:";
1131 for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1132 const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
1133 VLOG(2) << " XLA arg " << i
1134 << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
1135 << " name: " << arg.name << " TF arg " << input_to_args->at(i)
1136 << " node name: " << arg.node_name
1137 << (arg_shardings.find(i) == arg_shardings.end()
1138 ? ""
1139 : absl::StrCat(" sharding: ",
1140 arg_shardings.at(i).DebugString()));
1141 XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
1142 switch (arg.kind) {
1143 case XlaCompiler::Argument::kConstantResource:
1144 case XlaCompiler::Argument::kResource: {
1145 TF_RET_CHECK(arg.initialized);
1146 XlaResource* resource = arg_expression.resource();
1147 TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
1148 arg_handles[i], builder));
1149 VLOG(2) << " resource: num_gradients: "
1150 << arg.tensor_array_gradients.size();
1151 break;
1152 }
1153 case XlaCompiler::Argument::kParameter:
1154 // Reshape parameters back to their correct shapes.
1155 // TODO(b/76097077): propagate device assignments onto arguments and
1156 // return values of functions, and then reshape unconditionally.
1157 if (is_entry_computation) {
1158 arg_expression = XlaExpression::XlaOp(
1159 xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
1160 } else {
1161 arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1162 if (arg.value_bound) {
1163 // Propagate upper bound to arg_expression.
1164 arg_expression.set_value_bound(arg.value_bound.value());
1165 }
1166 }
1167 break;
1168 case XlaCompiler::Argument::kTensorList: {
1169 arg_expression = XlaExpression::TensorList(arg_handles[i]);
1170 break;
1171 }
1172 case XlaCompiler::Argument::kToken: {
1173 arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1174 break;
1175 }
1176 case XlaCompiler::Argument::kConstant:
1177 case XlaCompiler::Argument::kInvalid:
1178 return errors::Internal(
1179 "Unreachable case in BuildArguments() while filling handles");
1180 }
1181 }
1182
1183 return Status::OK();
1184 }
1185
1186 namespace {
1187
1188 // Check that the ops of all non-functional nodes have been registered.
1189 Status ValidateFunctionDef(const FunctionDef* fdef,
1190 const FunctionLibraryDefinition& flib_def) {
1191 for (const NodeDef& node : fdef->node_def()) {
1192 const string& op = node.op();
1193 if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
1194 continue;
1195 }
1196 const OpDef* op_def;
1197 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
1198 }
1199 return Status::OK();
1200 }
1201
1202 // If node is PartitionedCall or StatefulPartitionedCall, returns the
1203 // name from the "f" attr, else returns node.def().op().
1204 // Returned pointer points to the internal string either in node's attributes
1205 // or in its NodeDef. This pointer is valid as long as the node has not been
1206 // modified.
1207 Status GetPotentialFunctionName(const Node& node, const string** name) {
1208 if (node.IsPartitionedCall()) {
1209 const AttrValue* attr_value;
1210 TF_RETURN_IF_ERROR(
1211 node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
1212 if (!attr_value->has_func()) {
1213 return errors::InvalidArgument(
1214 "The attribute value for attribute 'f' in node ", node.DebugString(),
1215 " does not have 'func' field set");
1216 }
1217 *name = &attr_value->func().name();
1218 return Status::OK();
1219 }
1220 *name = &node.type_string();
1221 return Status::OK();
1222 }
1223
1224 // Check that the graph doesn't have any invalid nodes (e.g. incompatible with
1225 // given device_type, invalid data type, missing attributes...)
1226 Status ValidateGraph(const Graph* graph,
1227 const FunctionLibraryDefinition& flib_def,
1228 const DeviceType& device_type, const string& name) {
1229 // Make sure the XLA compilation kernels are registered. This operation is
1230 // idempotent so it is fine if someone called it already.
1231 XlaOpRegistry::RegisterCompilationKernels();
1232
1233 auto maybe_error = [&](const Node* node, const Status& s) -> Status {
1234 if (!s.ok()) {
1235 return errors::InvalidArgument(absl::StrCat(
1236 "Detected unsupported operations when trying to compile graph ", name,
1237 " on ", device_type.type_string(), ": ", node->def().op(), " (",
1238 s.error_message(), ")", FormatNodeForError(*node),
1239 "One approach is to outside compile the unsupported ops to run on "
1240 "CPUs by enabling soft placement "
1241 "`tf.config.set_soft_device_placement(True)`."
1242 " This has a potential performance penalty."));
1243 }
1244 return Status::OK();
1245 };
1246
1247 for (const Node* node : graph->nodes()) {
1248 if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
1249 continue;
1250 }
1251 const string* function_name;
1252 TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
1253 const FunctionDef* fdef = flib_def.Find(*function_name);
1254 Status s;
1255 if (fdef) {
1256 s = ValidateFunctionDef(fdef, flib_def);
1257 TF_RETURN_IF_ERROR(maybe_error(node, s));
1258 continue;
1259 }
1260 const OpDef* op_def;
1261 s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
1262 TF_RETURN_IF_ERROR(maybe_error(node, s));
1263 TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
1264 s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
1265 TF_RETURN_IF_ERROR(maybe_error(node, s));
1266 }
1267 return Status::OK();
1268 }
1269
1270 void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
1271 absl::Span<XlaExpression> expressions) {
1272 for (XlaExpression& expression : expressions) {
1273 if (expression.kind() == XlaExpression::Kind::kConstant) {
1274 expression =
1275 XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
1276 }
1277 }
1278 }
1279
1280 } // namespace
1281
1282 Status XlaCompiler::CompileGraph(
1283 const XlaCompiler::CompileOptions& options, string const& name,
1284 std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
1285 CompilationResult* result) {
1286 VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;
1287
1288 TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
1289 graph.get(), options_.flib_def, local_flib_def_.get()));
1290 TF_RETURN_IF_ERROR(RearrangeFunctionArguments(
1291 [this](const NameAttrList& function, const FunctionBody** fbody) {
1292 return FindFunctionBody(function, fbody);
1293 },
1294 graph.get(), local_flib_def_.get(),
1295 pflr_->GetFunctionLibraryDefinition()));
1296
1297 if (VLOG_IS_ON(2)) {
1298 VLOG(2) << "XlaCompiler::CompileGraph: "
1299 << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
1300 flib_runtime_->GetFunctionLibraryDefinition());
1301 }
1302
1303 // Report the error here if initialization failed.
1304 TF_RETURN_IF_ERROR(initialization_status_);
1305
1306 // Detect invalid nodes.
1307 // FunctionalizeControlFlow may remove some nodes from the graph.
1308 TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
1309 options_.device_type, name));
1310
1311 xla::XlaBuilder builder(name);
1312 XlaContext* context = new XlaContext(this, &builder, graph.get());
1313 core::ScopedUnref context_unref(context);
1314
1315 std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
1316 int token_input_index = -1;
1317 std::unique_ptr<xla::XlaOp> token_output;
1318 if (options.add_token_input_output) {
1319 // Add extra token input.
1320 token_input_index = real_args.size();
1321
1322 XlaCompiler::Argument token_arg;
1323 token_arg.kind = XlaCompiler::Argument::kToken;
1324 real_args.push_back(token_arg);
1325 }
1326
1327 std::map<int, xla::OpSharding> arg_shardings;
1328 std::map<int, xla::OpSharding> retval_shardings;
1329 TF_ASSIGN_OR_RETURN(std::tie(arg_shardings, retval_shardings),
1330 ComputeArgAndRetvalShardings(*graph));
1331
1332 std::vector<XlaExpression> arg_expressions;
1333 TF_RETURN_IF_ERROR(BuildArguments(
1334 *graph, real_args, options.use_tuple_arg, &builder, context,
1335 arg_shardings, &arg_expressions, &result->input_mapping,
1336 &result->xla_input_shapes, options.is_entry_computation));
1337 context->set_args(std::move(arg_expressions));
1338
1339 PushNodeTokenMapping();
1340 // Use std::set instead of std::unordered_set to ensure determinism.
1341 std::set<std::string> output_node_token_inputs;
1342 if (token_input_index != -1) {
1343 // Original token comes from input.
1344 auto arg_expression = context->args()[token_input_index];
1345 TF_RETURN_IF_ERROR(
1346 SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
1347
1348 // Calculate token inputs for output token.
1349 output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
1350
1351 // If there's no side-effecting op in the graph, use token input as token
1352 // output.
1353 if (output_node_token_inputs.empty()) {
1354 output_node_token_inputs.insert(kXlaTokenArgNodeName);
1355 }
1356 } else if (options.is_entry_computation) {
1357 // Original token is manually created.
1358 if (HasSideEffectingNodes(*graph)) {
1359 TF_RETURN_IF_ERROR(
1360 SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
1361 }
1362 }
1363
1364 TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
1365 flib_runtime_, NextStepId()));
1366 if (token_input_index != -1) {
1367 // Add extra token output.
1368 std::vector<xla::XlaOp> token_inputs;
1369 for (const auto& node_name : output_node_token_inputs) {
1370 auto token_or = GetNodeToken(node_name);
1371 TF_RETURN_IF_ERROR(token_or.status());
1372 token_inputs.push_back(token_or.ValueOrDie());
1373 }
1374 token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs)));
1375 }
1376 TF_RETURN_IF_ERROR(PopNodeTokenMapping());
1377
1378 int num_nonconst_outputs;
1379 int num_computation_outputs;
1380 result->computation = std::make_shared<xla::XlaComputation>();
1381 result->outputs.resize(context->retvals().size());
1382 std::vector<XlaExpression> retvals = context->retvals();
1383 ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
1384 TF_RETURN_IF_ERROR(BuildComputation(
1385 real_args, retvals, arg_shardings, retval_shardings, context->resources(),
1386 std::move(token_output),
1387 options.is_entry_computation ? options_.shape_representation_fn
1388 : ShapeRepresentationFn{},
1389 options.is_entry_computation,
1390 options.return_updated_values_for_all_resources,
1391 options.always_return_tuple, options.use_tuple_arg,
1392 options.alias_resource_update, &builder, result->computation.get(),
1393 &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
1394 &result->resource_updates, &result->xla_output_shape,
1395 result->input_mapping));
1396
1397 VLOG(2) << "Outputs: total: " << context->retvals().size()
1398 << " nonconstant: " << num_nonconst_outputs;
1399 VLOG(2) << "XLA output shape: "
1400 << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape);
1401 return Status::OK();
1402 }
1403
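// Channel handles are created lazily: the first call for `key` asks the
// client for a new handle and caches it, and later calls with the same key
// return the cached handle. Illustrative use (a sketch; the key is the
// caller's):
//
//   xla::ChannelHandle channel;
//   TF_RETURN_IF_ERROR(compiler.GetChannelHandle("token_key", &channel));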
Status XlaCompiler::GetChannelHandle(const string& key,
                                     xla::ChannelHandle* channel) {
  auto result = channels_.emplace(key, xla::ChannelHandle());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
  }
  *channel = result.first->second;
  VLOG(1) << "Channel: " << key << " " << channel->DebugString();
  return Status::OK();
}

Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
                                                 xla::ChannelHandle* channel) {
  auto result = channels_.emplace(key, xla::ChannelHandle());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second,
                        client()->CreateHostToDeviceChannelHandle());
  }
  *channel = result.first->second;
  VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
  return Status::OK();
}

Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
                                                 xla::ChannelHandle* channel) {
  auto result = channels_.emplace(key, xla::ChannelHandle());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second,
                        client()->CreateDeviceToHostChannelHandle());
  }
  *channel = result.first->second;
  VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
  return Status::OK();
}

namespace {

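// Fills `transfer` with `key` and one TensorMetadata entry per type/shape
// pair. `types` and `shapes` must have the same length.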
void SetTransfer(const string& key, absl::Span<const DataType> types,
                 absl::Span<const TensorShape> shapes,
                 tf2xla::HostTransferMetadata* transfer) {
  transfer->set_key(key);
  CHECK(types.size() == shapes.size());
  for (int i = 0, end = types.size(); i < end; ++i) {
    tf2xla::TensorMetadata* metadata = transfer->add_metadata();
    metadata->set_type(types[i]);
    shapes[i].AsProto(metadata->mutable_shape());
  }
}

}  // namespace

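// Registers the types and shapes sent from device to host under `key`.
// Re-registering the same key is a no-op if the metadata is identical and an
// InvalidArgument error otherwise.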
Status XlaCompiler::SetDeviceToHostMetadata(
    const string& key, absl::Span<const DataType> types,
    absl::Span<const TensorShape> shapes) {
  if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
    tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key];
    tf2xla::HostTransferMetadata new_transfer;
    SetTransfer(key, types, shapes, &new_transfer);
    if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
      return Status::OK();
    } else {
      return errors::InvalidArgument(
          "Duplicate calls to SetDeviceToHostMetadata with key ", key);
    }
  }
  tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
  SetTransfer(key, types, shapes, &transfer);
  return Status::OK();
}

Status XlaCompiler::GetDeviceToHostShapes(
    const string& key, std::vector<TensorShape>* shapes) const {
  const auto iter = host_compute_sends_.find(key);
  if (iter == host_compute_sends_.end()) {
    return errors::InvalidArgument(
        "No host compute send shapes registered for key ", key);
  }
  shapes->clear();
  for (int i = 0; i < iter->second.metadata_size(); ++i) {
    TensorShape shape(iter->second.metadata(i).shape());
    shapes->push_back(shape);
  }
  return Status::OK();
}

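// Mirror of SetDeviceToHostMetadata for transfers from host to device:
// duplicate registrations are accepted only when the metadata matches the
// existing entry exactly.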
Status XlaCompiler::SetHostToDeviceMetadata(
    const string& key, absl::Span<const DataType> types,
    absl::Span<const TensorShape> shapes) {
  if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
    tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key];
    tf2xla::HostTransferMetadata new_transfer;
    SetTransfer(key, types, shapes, &new_transfer);
    if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
      return Status::OK();
    } else {
      return errors::InvalidArgument(
          "Duplicate calls to SetHostToDeviceMetadata with key ", key);
    }
  }
  tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
  SetTransfer(key, types, shapes, &transfer);
  return Status::OK();
}

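// Control-dependency handles for host-compute ops are registered once via
// SetHostComputeControlDependency and looked up here by op name.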
Status XlaCompiler::GetHostComputeControlDependency(
    const string& host_compute_name, xla::XlaOp* handle) {
  const auto iter = host_compute_control_output_.find(host_compute_name);
  if (iter == host_compute_control_output_.end()) {
    return errors::InvalidArgument(
        "No registered control handle for host compute Op '", host_compute_name,
        "'");
  } else {
    *handle = iter->second;
  }
  return Status::OK();
}

Status XlaCompiler::SetHostComputeControlDependency(
    const string& host_compute_name, const xla::XlaOp& handle) {
  if (host_compute_control_output_.find(host_compute_name) !=
      host_compute_control_output_.end()) {
    return errors::InvalidArgument(
        "Duplicate control handles registered for host compute Op ",
        host_compute_name);
  }
  host_compute_control_output_[host_compute_name] = handle;
  return Status::OK();
}

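// The node-token mapping records, keyed by node name, the XLA tokens produced
// during a single symbolic execution of a graph. CompileGraph pushes a fresh
// map before executing the graph, SetNodeToken/GetNodeToken operate on the
// top of the stack, and PopNodeTokenMapping removes it afterwards.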
void XlaCompiler::PushNodeTokenMapping() {
  node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
}

Status XlaCompiler::PopNodeTokenMapping() {
  if (node_token_mapping_stack_.empty()) {
    return errors::FailedPrecondition(
        "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
        "empty.");
  }
  node_token_mapping_stack_.pop();
  return Status::OK();
}

Status XlaCompiler::SetNodeToken(const string& node_name,
                                 const xla::XlaOp& op) {
  if (node_token_mapping_stack_.empty()) {
    return errors::FailedPrecondition(
        "Calling SetNodeToken() when node_token_mapping_stack_ is "
        "empty.");
  }
  auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
  if (!insert_result.second) {
    return errors::FailedPrecondition("Token mapping already exists for node ",
                                      node_name);
  }
  return Status::OK();
}

xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
  if (node_token_mapping_stack_.empty()) {
    return errors::FailedPrecondition(
        "Calling GetNodeToken() when node_token_mapping_stack_ is "
        "empty.");
  }
  auto iter = node_token_mapping_stack_.top().find(node_name);
  if (iter == node_token_mapping_stack_.top().end()) {
    return errors::FailedPrecondition("Cannot find token mapping for node ",
                                      node_name);
  }
  return iter->second;
}

}  // namespace tensorflow