1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17 
18 #include <numeric>
19 #include <vector>
20 
21 #include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/types/variant.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/jit/flags.h"
27 #include "tensorflow/compiler/jit/shape_inference.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
29 #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
30 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
31 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
32 #include "tensorflow/compiler/tf2xla/shape_util.h"
33 #include "tensorflow/compiler/tf2xla/sharding_util.h"
34 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
35 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
36 #include "tensorflow/compiler/tf2xla/type_util.h"
37 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
38 #include "tensorflow/compiler/tf2xla/xla_context.h"
39 #include "tensorflow/compiler/xla/client/client_library.h"
40 #include "tensorflow/compiler/xla/client/xla_builder.h"
41 #include "tensorflow/compiler/xla/client/xla_computation.h"
42 #include "tensorflow/compiler/xla/protobuf_util.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/core/common_runtime/device.h"
46 #include "tensorflow/core/common_runtime/executor.h"
47 #include "tensorflow/core/common_runtime/function.h"
48 #include "tensorflow/core/common_runtime/graph_constructor.h"
49 #include "tensorflow/core/common_runtime/graph_optimizer.h"
50 #include "tensorflow/core/framework/attr_value_util.h"
51 #include "tensorflow/core/framework/function.h"
52 #include "tensorflow/core/framework/node_def_util.h"
53 #include "tensorflow/core/framework/types.h"
54 #include "tensorflow/core/graph/node_builder.h"
55 #include "tensorflow/core/lib/core/errors.h"
56 #include "tensorflow/core/lib/gtl/cleanup.h"
57 #include "tensorflow/core/lib/hash/hash.h"
58 #include "tensorflow/core/platform/logging.h"
59 #include "tensorflow/core/protobuf/error_codes.pb.h"
60 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
61 #include "tensorflow/core/util/dump_graph.h"
62 
63 namespace tensorflow {
64 namespace {
65 
66 // Checks that arguments `args` match types `types`.
67 Status CheckSignature(const DataTypeVector& types,
68                       absl::Span<const XlaCompiler::Argument> args) {
69   if (args.size() != types.size()) {
70     return errors::Internal("Compilation arguments have ", args.size(),
71                             " elements while function has ", types.size());
72   }
73   for (int i = 0, end = types.size(); i < end; ++i) {
74     // Don't perform type checks on resource variables and tensor
75     // lists (DT_VARIANT) as we have to trick the type system in order to
76     // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor.
77     if (types[i] != args[i].type && types[i] != DT_RESOURCE &&
78         types[i] != DT_VARIANT) {
79       return errors::Internal(
80           "Argument ", i, " has declared type ", DataTypeString(args[i].type),
81           " but function parameter has type ", DataTypeString(types[i]));
82     }
83   }
84   return Status::OK();
85 }
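// Illustrative usage sketch (hypothetical caller, not part of this file):
//
//   DataTypeVector types = {DT_FLOAT, DT_RESOURCE};
//   std::vector<XlaCompiler::Argument> args(2);
//   args[0].type = DT_FLOAT;   // must match types[0]
//   args[1].type = DT_INT32;   // not checked: DT_RESOURCE parameters are exempt
//   TF_RETURN_IF_ERROR(CheckSignature(types, args));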
86 
87 // Uses the _Arg and _Retval nodes in the graph to determine an OpSharding for
88 // each argument and return value.
89 StatusOr<
90     std::pair<std::map<int, xla::OpSharding>, std::map<int, xla::OpSharding>>>
91 ComputeArgAndRetvalShardings(const Graph& graph) {
92   auto get_sharding_for_node =
93       [](const Node* n) -> StatusOr<absl::optional<xla::OpSharding>> {
94     TF_ASSIGN_OR_RETURN(
95         auto sharding,
96         ParseShardingFromDevice(*n, std::numeric_limits<int32>::max(),
97                                 /*add_metadata=*/false));
98     return sharding;
99   };
100   std::map<int, xla::OpSharding> arg_shardings;
101   std::map<int, xla::OpSharding> retval_shardings;
102   for (const Node* n : graph.nodes()) {
103     if (n->IsArg()) {
104       TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
105       if (!sharding.has_value()) continue;
106       int index;
107       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
108       TF_RET_CHECK(index >= 0) << "Negative _Arg index";
109       arg_shardings[index] = std::move(*sharding);
110     } else if (n->IsRetval()) {
111       TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
112       if (!sharding.has_value()) continue;
113       int index;
114       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
115       TF_RET_CHECK(index >= 0) << "Negative _Retval index";
116       retval_shardings[index] = std::move(*sharding);
117     }
118   }
119   return std::make_pair(std::move(arg_shardings), std::move(retval_shardings));
120 }
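// Illustrative sketch of the result (assuming a graph whose _Arg 0 carries a
// TPU core assignment and whose single _Retval carries none): arg_shardings
// would contain {0 -> sharding for that core} and retval_shardings would be
// empty, so the return value falls back to the default (replicated) handling.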
121 
122 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
123                     XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
124                     int64_t step_id) {
125   // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
126   // resource manager takes ownership via Create, and unrefs via Cleanup.  We
127   // explicitly add a reference to ensure the refcount at entry is maintained at
128   // all exit points; Create and Cleanup are always called in this function.
129   //
130   // The Executor requires us to use ScopedStepContainer. We wrap it in a
131   // unique_ptr so we can capture the cleanup status in the end.
132   xla_context->Ref();
133   Status status;
134   auto step_container = absl::make_unique<ScopedStepContainer>(
135       step_id, [&status, device](const string& name) {
136         status = device->resource_manager()->Cleanup(name);
137       });
138   TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(),
139                                             XlaContext::kXlaContextResourceName,
140                                             xla_context));
141 
142   GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
143   TF_RETURN_IF_ERROR(graph_compiler.Compile());
144   // Explicitly clean up the step container, to capture the cleanup status.
145   step_container.reset();
146   return status;
147 }
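// Note on the cleanup ordering above: resetting `step_container` explicitly
// runs the Cleanup lambda before returning, so a successful Compile() still
// surfaces any resource-manager cleanup failure through `status`.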
148 
149 // Builds the XLA computation.
150 // - `args` is the list of input arguments
151 // - `retvals` is the list of retvals produced by _Retval operators, in index
152 //   order.
153 // - `arg_shardings` and `retval_shardings` map arg/return indices to their
154 //   shardings.
155 // - If `return_updated_values_for_all_resources` is true, all resources will be
156 //   included in `resource_updates`, regardless of whether their value changed.
157 // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
158 // - Sets `*resource_updates` to a description of resources whose values are
159 //   written by the computation; the variable writes are the last
160 //   `resource_updates.size()` return values from the computation. Each entry in
161 //   `resource_updates` is a ResourceUpdate, whose `index` is the index of a
162 //   resource variable argument to the computation to be updated, and `type` is
163 //   the type of the final output.
164 Status BuildComputation(
165     const std::vector<XlaCompiler::Argument>& args,
166     const std::vector<XlaExpression>& retvals,
167     const std::map<int, xla::OpSharding>& arg_shardings,
168     const std::map<int, xla::OpSharding>& retval_shardings,
169     const std::vector<std::unique_ptr<XlaResource>>& resources,
170     std::unique_ptr<xla::XlaOp> token_output,
171     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
172     bool is_entry_computation, bool return_updated_values_for_all_resources,
173     bool always_return_tuple, bool use_tuple_arg, bool alias_resource_update,
174     xla::XlaBuilder* builder, xla::XlaComputation* computation,
175     int* num_computation_outputs, int* num_nonconst_outputs,
176     std::vector<XlaCompiler::OutputDescription>* outputs,
177     std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
178     xla::Shape* output_shape, absl::Span<int const> input_mapping) {
179   // Attach a common operator name as metadata. This has no semantic effect — it
180   // merely makes the HLO graph more readable when visualized via TensorBoard,
181   // since TensorBoard forms groups out of operators with similar names.
182   xla::OpMetadata retval_metadata;
183   retval_metadata.set_op_name("XLA_Retvals");
184   builder->SetOpMetadata(retval_metadata);
185   VLOG(1) << "Building new computation";
186   auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
187 
188   // Builds a no-op XLA computation. We need to set the sharding of outputs, but
189   // cannot change the sharding of the existing output op. To do this, we build
190   // a new identity op to which shardings can be applied.
191   auto identity_op = [builder](xla::XlaOp op) {
192     return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
193   };
194 
195   std::vector<xla::XlaOp> elems;
196   elems.reserve(retvals.size());
197 
198   // Keeps track of sharding of each retval. If a retval is not in this list,
199   // replicate sharding is used. The first element is the output index, second
200   // element is the sharding.
201   std::unordered_map<int, xla::OpSharding> retval_index_and_sharding;
202   for (int i = 0, end = retvals.size(); i < end; ++i) {
203     XlaCompiler::OutputDescription& output = (*outputs)[i];
204     const XlaExpression& retval = retvals[i];
205     output.type = retval.dtype();
206     switch (retval.kind()) {
207       case XlaExpression::Kind::kConstant:
208         output.is_constant = true;
209         output.constant_value = *retval.constant_value();
210         output.shape = output.constant_value.shape();
211         break;
212 
213       case XlaExpression::Kind::kTensorList: {
214         output.is_tensor_list = true;
215         xla::XlaOp value = retval.handle();
216         elems.push_back(value);
217         break;
218       }
219 
220       case XlaExpression::Kind::kXlaOp: {
221         output.is_constant = false;
222         TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
223         xla::XlaOp value = retval.handle();
224         auto it = retval_shardings.find(i);
225         absl::optional<xla::OpSharding> sharding =
226             it == retval_shardings.end() ? absl::optional<xla::OpSharding>()
227                                          : it->second;
228         if (it != retval_shardings.end()) {
229           retval_index_and_sharding[elems.size()] = it->second;
230         }
231         if (shape_representation_fn) {
232           TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value));
233           TF_ASSIGN_OR_RETURN(value,
234                               ReshapeWithCorrectRepresentationAndSharding(
235                                   builder, value, original_shape,
236                                   shape_representation_fn, sharding,
237                                   /*fast_mem=*/false));
238         }
239         if (it != retval_shardings.end()) {
240           xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
241           // Apply the sharding to the output, if there is a core assignment.
242           value = identity_op(value);
243         }
244 
245         elems.push_back(value);
246         break;
247       }
248 
249       case XlaExpression::Kind::kResource:
250         // Resources will be pushed into elems later when processing resource
251         // arguments below.
252         output.is_constant = false;
253         output.input_index = retval.resource()->arg_num();
254         output.shape = retval.resource()->shape();
255         break;
256 
257       case XlaExpression::Kind::kInvalid:
258         return errors::InvalidArgument(
259             "Invalid expression returned by computation. "
260             "This probably means a return value was not set.");
261     }
262   }
263   *num_nonconst_outputs = elems.size();
264 
265   // Add return values for resources whose values have changed.
266   std::vector<const XlaResource*> arg_resources;
267   arg_resources.reserve(resources.size());
268   for (const auto& resource : resources) {
269     if (resource->arg_num() >= 0) {
270       arg_resources.push_back(resource.get());
271     }
272   }
273   std::sort(arg_resources.begin(), arg_resources.end(),
274             [](const XlaResource* a, const XlaResource* b) {
275               return a->arg_num() < b->arg_num();
276             });
277 
278   absl::flat_hash_map<int, int> argument_to_xla_arg;
279   for (int xla_arg = 0; xla_arg < input_mapping.size(); xla_arg++) {
280     argument_to_xla_arg[input_mapping[xla_arg]] = xla_arg;
281   }
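// For example, with input_mapping = {0, 2, 5} the map above becomes
// {0 -> 0, 2 -> 1, 5 -> 2}, i.e. TF argument 5 corresponds to XLA input 2.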
282 
283   std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
284   for (const XlaResource* resource : arg_resources) {
285     DCHECK_LT(resource->arg_num(), args.size());
286     const XlaCompiler::Argument& arg = args[resource->arg_num()];
287     auto it = arg_shardings.find(resource->arg_num());
288     bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
289     // TensorArray gradients were modified if their values changed or there are
290     // any newly created gradients.
291     for (const auto& grad : resource->tensor_array_gradients()) {
292       modified =
293           modified ||
294           !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
295           arg.tensor_array_gradients.count(grad.first) == 0;
296     }
297 
298     if (return_updated_values_for_all_resources || modified ||
299         arg.requires_broadcast) {
300       resource_updates->emplace_back();
301       XlaCompiler::ResourceUpdate& update = resource_updates->back();
302       update.input_index = resource->arg_num();
303       update.type = resource->type();
304       update.shape = resource->shape();
305       update.modified = modified;
306       int param_num = use_tuple_arg ? 0 : update.input_index;
307       if (is_entry_computation &&
308           arg.resource_kind != XlaResource::kTensorArray &&
309           alias_resource_update && argument_to_xla_arg.count(param_num)) {
310         // Assuming tuple arg and results are used.
311         xla::ShapeIndex param_index =
312             use_tuple_arg ? xla::ShapeIndex({update.input_index})
313                           : xla::ShapeIndex{};
314         int xla_param_num = argument_to_xla_arg[param_num];
315         int64_t output_index_num = elems.size();
316         xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
317         VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
318                 << xla_param_num << ", " << param_index.ToString() << ")";
319         aliases.push_back({output_index, xla_param_num, param_index});
320       }
321       for (const auto& grad : resource->tensor_array_gradients()) {
322         update.tensor_array_gradients_accessed.insert(grad.first);
323       }
324 
325       xla::XlaOp handle;
326       TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
327       auto sharding = it == arg_shardings.end()
328                           ? absl::optional<xla::OpSharding>()
329                           : it->second;
330       // Set layout of the retval to device representation layout.
331       if (shape_representation_fn) {
332         TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle));
333         TF_ASSIGN_OR_RETURN(
334             handle, ReshapeWithCorrectRepresentationAndSharding(
335                         builder, handle, original_shape,
336                         shape_representation_fn, sharding, arg.fast_mem));
337       }
338 
339       // Request that the value be returned on a specific core.
340       xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
341       if (it != arg_shardings.end()) {
342         retval_index_and_sharding[elems.size()] = it->second;
343       }
344       // Ensures the correct sharding is applied to the output.
345       handle = identity_op(handle);
346       elems.push_back(handle);
347     }
348   }
349 
350   // If we have a token output, append it as the last element.
351   if (token_output) {
352     elems.push_back(*token_output);
353   }
354 
355   *num_computation_outputs = elems.size();
356 
357   // Builds the XLA computation. We *always* form a tuple here to ensure that
358   // the output value is the last thing added into the XLA computation, even
359   // if there is only one output value.
360   xla::XlaOp tuple;
361   if (retval_index_and_sharding.empty() || !is_entry_computation) {
362     tuple = xla::Tuple(builder, elems);
363   } else {
364     std::vector<xla::Shape> elem_shapes;
365     for (const auto& elem : elems) {
366       TF_ASSIGN_OR_RETURN(xla::Shape elem_shape,
367                           elem.builder()->GetShape(elem));
368       elem_shapes.push_back(elem_shape);
369     }
370     xla::Shape shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
371     // Copy specified sharding from retval_index_and_sharding.
372     std::vector<xla::HloSharding> sharding_elems;
373     for (int i = 0, end = elems.size(); i < end; i++) {
374       const auto& iter = retval_index_and_sharding.find(i);
375       TF_RET_CHECK(iter != retval_index_and_sharding.end());
376       const xla::OpSharding& sub_op_sharding = iter->second;
377       TF_ASSIGN_OR_RETURN(xla::HloSharding sub_sharding,
378                           xla::HloSharding::FromProto(sub_op_sharding));
379       if (elem_shapes[i].IsTuple()) {
380         const std::vector<xla::HloSharding> sub_sharding_elems =
381             sub_sharding.tuple_elements();
382         const int64_t sub_sharding_elems_size = sub_sharding_elems.size();
383         TF_RET_CHECK(sub_sharding_elems_size ==
384                      xla::ShapeUtil::GetLeafCount(elem_shapes[i]));
385         for (const auto& sub_sharding_elem : sub_sharding_elems) {
386           sharding_elems.push_back(sub_sharding_elem);
387         }
388       } else {
389         sharding_elems.push_back(sub_sharding);
390       }
391     }
392     xla::HloSharding modified_sharding =
393         xla::HloSharding::Tuple(shape, sharding_elems);
394     xla::OpSharding op_sharding = modified_sharding.ToProto();
395     // Assign proper sharding to the tuple instruction.
396     xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
397     tuple = xla::Tuple(builder, elems);
398   }
399   bool returns_tuple = always_return_tuple || elems.size() != 1;
400   VLOG(3) << "Computation returns a tuple=" << returns_tuple;
401   if (!returns_tuple) {
402     xla::GetTupleElement(tuple, 0);
403 
404     for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
405       if (alias.output_index == xla::ShapeIndex({0})) {
406         VLOG(3) << "For aliased parameter " << alias.param_number << ": "
407                 << alias.param_index.ToString()
408                 << " normalizing output_index from {0} to {}, as a scalar is "
409                    "returned from the cluster";
410         alias.output_index = xla::ShapeIndex({});
411       }
412     }
413   }
414 
415   for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
416     builder->SetUpAlias(alias.output_index, alias.param_number,
417                         alias.param_index);
418   }
419 
420   StatusOr<xla::XlaComputation> computation_status = builder->Build();
421   if (!computation_status.ok()) {
422     return computation_status.status();
423   }
424   *computation = computation_status.ConsumeValueOrDie();
425 
426   TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape());
427   *output_shape = program_shape.result();
428   return Status::OK();
429 }
430 
431 }  // namespace
432 
433 
434 string XlaCompiler::Argument::HumanString() const {
435   string common;
436   if (!name.empty()) {
437     common = absl::StrCat(" name=", name);
438   }
439   absl::StrAppend(&common, " type=", DataTypeString(type),
440                   " shape=", ShapeHumanString());
441   absl::StrAppend(
442       &common, " is_same_data_across_replicas=", is_same_data_across_replicas);
443   switch (kind) {
444     case kInvalid:
445       return "invalid";
446     case kConstant:
447       return absl::StrCat("kind=constant", common,
448                           " value=", constant_value.DebugString());
449     case kConstantResource:
450       return absl::StrCat("kind=constant-resource", common,
451                           " value=", constant_value.DebugString());
452     case kResource: {
453       string output = absl::StrCat(
454           "kind=resource", common,
455           " resource_kind=", XlaResource::KindToString(resource_kind),
456           " initialized=", initialized, " is_fast_mem=", fast_mem);
457       if (max_array_size >= 0) {
458         absl::StrAppend(&output, " max_array_size=", max_array_size);
459       }
460       if (!tensor_array_gradients.empty()) {
461         absl::StrAppend(&output, " tensor_array_gradients=",
462                         absl::StrJoin(tensor_array_gradients, ","));
463       }
464       return output;
465     }
466     case kParameter:
467       return absl::StrCat("kind=parameter", common);
468     case kTensorList:
469       return absl::StrCat("kind=tensorlist", common);
470     case kToken:
471       return absl::StrCat("token", common);
472   }
473 }
474 
475 std::vector<int64> XlaCompiler::Argument::DimensionSizes() const {
476   if (absl::holds_alternative<TensorShape>(shape)) {
477     return xla::InlinedVectorToVector(
478         absl::get<TensorShape>(shape).dim_sizes());
479   } else {
480     return xla::SpanToVector(absl::get<xla::Shape>(shape).dimensions());
481   }
482 }
483 
484 absl::InlinedVector<int64, 4>
485 XlaCompiler::Argument::DimensionSizesAsInlinedVector() const {
486   if (absl::holds_alternative<TensorShape>(shape)) {
487     return absl::get<TensorShape>(shape).dim_sizes();
488   } else {
489     auto v = absl::get<xla::Shape>(shape).dimensions();
490     return absl::InlinedVector<int64, 4>(v.begin(), v.end());
491   }
492 }
493 
494 string XlaCompiler::Argument::ShapeHumanString() const {
495   if (absl::holds_alternative<TensorShape>(shape)) {
496     return absl::get<TensorShape>(shape).DebugString();
497   } else {
498     return absl::get<xla::Shape>(shape).DebugString();
499   }
500 }
501 
502 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
503     : options_(options),
504       initialization_status_(Status::OK()),
505       next_step_id_(1),
506       device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
507       device_mgr_(absl::WrapUnique(device_)) {
508   CHECK(!options_.device_type.type_string().empty());
509   if (options_.populate_resource_manager) {
510     initialization_status_ =
511         (*options_.populate_resource_manager)(device_->resource_manager());
512   }
513 
514   local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
515                                                       FunctionDefLibrary{}));
516   local_pflr_.reset(new ProcessFunctionLibraryRuntime(
517       &device_mgr_, Env::Default(), /*config=*/nullptr,
518       options.graph_def_version, local_flib_def_.get(), OptimizerOptions()));
519   pflr_.reset(new ProcessFunctionLibraryRuntime(
520       &device_mgr_, Env::Default(), /*config=*/nullptr,
521       options.graph_def_version, options.flib_def, OptimizerOptions()));
522 
523   local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
524   flib_runtime_ = pflr_->GetFLR(device_->name());
525 
526   // The default shape representation function is the identity.
527   if (!options_.shape_representation_fn) {
528     options_.shape_representation_fn = IdentityShapeRepresentationFn();
529   }
530 }
531 
532 XlaCompiler::~XlaCompiler() = default;
533 
534 int64 XlaCompiler::NextStepId() { return next_step_id_++; }
535 
536 uint64 XlaCompiler::SignatureHash::operator()(
537     const std::pair<string, std::vector<Argument>>& signature) const {
538   return std::hash<string>()(signature.first);
539 }
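// Note: only the function name contributes to this hash; signatures that share
// a name but differ in arguments hash identically and are distinguished by the
// cache map's key equality (see the cache_ lookups in CompileFunction below).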
540 
541 static Status GetFunctionBody(const NameAttrList& function,
542                               FunctionLibraryRuntime* flib_runtime,
543                               const FunctionBody** fbody) {
544   FunctionLibraryRuntime::Handle handle;
545   TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
546       function.name(), AttrSlice(&function.attr()), &handle));
547 
548   *fbody = flib_runtime->GetFunctionBody(handle);
549   TF_RET_CHECK(*fbody);
550   return Status::OK();
551 }
552 
553 Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
554                                      const FunctionBody** fbody,
555                                      const ConfigProto** config_proto) {
556   // The function may be in either the local_flib_runtime_ or flib_runtime_.
557   // Look it up in local_flib_runtime_ first; if it is not found there, look it
558   // up in flib_runtime_.
559   auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
560   if (!status.ok()) {
561     if (!errors::IsNotFound(status)) {
562       return status;
563     }
564     TF_RETURN_WITH_CONTEXT_IF_ERROR(
565         GetFunctionBody(function, flib_runtime_, fbody),
566         "Local lookup failed with: ", status.error_message());
567     if (config_proto) {
568       *config_proto = flib_runtime_->config_proto();
569     }
570     VLOG(4) << "Function " << function.name() << " in flib_runtime_";
571   } else {
572     if (config_proto) {
573       *config_proto = local_flib_runtime_->config_proto();
574     }
575     VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
576   }
577   return Status::OK();
578 }
579 
580 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
581   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
582   CopyGraph(*fbody->graph, graph.get());
583 
584   bool is_inside_mustcompile = false;
585   TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
586                  &is_inside_mustcompile);
587 
588   // Performs a first function inlining pass before shape inference, since
589   // otherwise shape inference can't see inside functions and a comprehensive
590   // shape_map, including function ops, is needed to constant-propagate Shape
591   // Ops below.
592   auto flags = GetBuildXlaOpsPassFlags();
593   OptimizerOptions opts;
594   opts.set_opt_level(OptimizerOptions::L0);
595   opts.set_do_common_subexpression_elimination(false);
596   opts.set_do_function_inlining(true);
597   opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
598   GraphOptimizer optimizer(opts);
599   // Do not constant fold nodes that output DT_VARIANT type tensors.
600   // XLA does not support Const nodes of Variant type since it needs
601   // to know the original ops to be able to compile them to the relevant
602   // XLA form.
603   // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
604   // the form:
605   //                          Const
606   //                            |
607   // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
608   //                                                  |
609   //                                        (Discard popped list)
610   //
611   // Would have been reduced to "Const -> Op" without this filter.
612   // However since we are only allowed to specify the filter at the "Node"
613   // level there is no good way to allow the above behavior. So we
614   // disallow any sort of constant folding on Variant nodes for now.
615   //
616   // Also do not consider constant folding Shape ops. When there is a dynamic
617   // dimension in a tensor, TF2XLA currently represents it as the static
618   // upper-bound shape, which can be constant folded and then lose the info
619   // that this Shape is dynamic.
620   auto cf_consider_fn = [](const Node* n) {
621     for (const auto& output_arg : n->op_def().output_arg()) {
622       if (output_arg.type() == DT_VARIANT) {
623         return false;
624       }
625     }
626     const auto& ts = n->type_string();
627     // XLA has special logic to handle dynamic shapes, don't constant fold
628     // them.
629     if (ts == "Shape" || ts == "ShapeN" || ts == "Size") {
630       return false;
631     }
632     return true;
633   };
634   GraphOptimizer::Options graph_optimizer_options;
635   graph_optimizer_options.cf_consider_fn = cf_consider_fn;
636   graph_optimizer_options.inline_multi_device_functions = true;
637   graph_optimizer_options.inline_impl_selection_group_functions = true;
638   graph_optimizer_options.inline_with_single_device_body_placer = true;
639   graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
640 
641   {
642     GraphShapeInfo shape_info;
643     InferShapes(graph.get(), /*arg_shapes=*/{},
644                 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
645         .IgnoreError();
646     auto node_name_index = graph->BuildNodeNameIndex();
647     std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
648     for (const auto& node_shape_info : shape_info) {
649       const string& node_name = node_shape_info.first;
650       const std::vector<InferredShape>& output_shapes = node_shape_info.second;
651       const auto& node_iter = node_name_index.find(node_name);
652       if (node_iter != node_name_index.end()) {
653         auto& partial_shapes = shape_map[node_name];
654         for (const auto& inferred_shape : output_shapes) {
655           partial_shapes.push_back(inferred_shape.shape);
656         }
657       }
658     }
659     graph_optimizer_options.shape_map = &shape_map;
660     optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
661                        /*device=*/nullptr, &graph, graph_optimizer_options);
662   }
663 
664   // Run shape inference on the graph and optimize the graph again.
665   GraphShapeInfo shape_info;
666   InferShapes(graph.get(), /*arg_shapes=*/{},
667               flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
668       .IgnoreError();
669   auto node_name_index = graph->BuildNodeNameIndex();
670   std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
671   for (const auto& node_shape_info : shape_info) {
672     const string& node_name = node_shape_info.first;
673     const std::vector<InferredShape>& output_shapes = node_shape_info.second;
674     const auto& node_iter = node_name_index.find(node_name);
675     if (node_iter != node_name_index.end()) {
676       auto& partial_shapes = shape_map[node_name];
677       for (const auto& inferred_shape : output_shapes) {
678         partial_shapes.push_back(inferred_shape.shape);
679       }
680     }
681   }
682   graph_optimizer_options.shape_map = &shape_map;
683   optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
684                      /*device=*/nullptr, &graph, graph_optimizer_options);
685 
686   return graph;
687 }
688 
689 // Collects all control rets from `orig_control_ret_nodes` that are still valid,
690 // keeping the same order.
691 std::vector<std::string> GetValidControlRets(
692     absl::Span<Node* const> orig_control_ret_nodes, const Graph& graph) {
693   // Build map from control ret node name to index.
694   // We use Node name instead of Node* here to index into the map as we populate
695   // the map with nodes in FunctionDef control_ret_nodes and later query it
696   // using the nodes in `graph`. The Node pointers would be different but the
697   // Node name is expected to remain the same between the two.
698   absl::flat_hash_map<const string, int> control_ret_nodes_map;
699   for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
700     const Node* n = orig_control_ret_nodes[i];
701     control_ret_nodes_map[n->name()] = i;
702   }
703   // Check which control rets are still valid.
704   std::vector<bool> is_valid_control_ret(orig_control_ret_nodes.size(), false);
705   int num_valid_control_rets = 0;
706   for (const Node* n : graph.nodes()) {
707     auto iter = control_ret_nodes_map.find(n->name());
708     if (iter != control_ret_nodes_map.end()) {
709       ++num_valid_control_rets;
710       is_valid_control_ret[iter->second] = true;
711     }
712   }
713   // Return valid control rets in same order as they appear in
714   // `orig_control_ret_nodes`.
715   std::vector<std::string> valid_control_rets;
716   valid_control_rets.reserve(num_valid_control_rets);
717   for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
718     if (is_valid_control_ret[i]) {
719       valid_control_rets.push_back(orig_control_ret_nodes[i]->name());
720     }
721   }
722   return valid_control_rets;
723 }
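// Illustrative sketch: if `orig_control_ret_nodes` named nodes {A, B, C} and B
// was pruned from `graph` by an earlier optimization pass, this returns
// {"A", "C"}, preserving the original ordering.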
724 
725 Status XlaCompiler::CompileFunction(
726     const XlaCompiler::CompileOptions& options,
727     const NameAttrList& fn_name_attrs,
728     absl::Span<const XlaCompiler::Argument> args,
729     XlaCompiler::CompilationResult* result) {
730   const string function_id =
731       Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr()));
732   VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
733 
734   const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
735   auto it = cache_.find({function_id, arg_vector});
736   if (it != cache_.end()) {
737     *result = it->second;
738     return Status::OK();
739   }
740 
741   const FunctionBody* fbody;
742   const ConfigProto* config = nullptr;
743   TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody, &config));
744 
745   absl::optional<ConfigProto> config_proto;
746   if (config) {
747     config_proto = *config;
748   }
749 
750   TF_RETURN_WITH_CONTEXT_IF_ERROR(
751       CheckSignature(fbody->arg_types, args),
752       "Signature check failure while compiling: ", fn_name_attrs.name());
753 
754   // Set shapes for _Arg nodes. They are useful for constant folding (e.g. an
755   // Xla op requires a compile-time constant input, and that input is the shape
756   // of an _Arg node).
757   for (int i = 0, end = args.size(); i < end; i++) {
758     // Skip resource variables and tensor lists.
759     DataType dtype;
760     TF_RETURN_IF_ERROR(GetNodeAttr(fbody->arg_nodes[i]->def(), "T", &dtype));
761     if (dtype == DT_RESOURCE || dtype == DT_VARIANT) {
762       continue;
763     }
764 
765     if (absl::holds_alternative<xla::Shape>(args[i].shape)) {
766       xla::Shape xla_shape = absl::get<xla::Shape>(args[i].shape);
767       TensorShape tensor_shape;
768       // If xla_shape is dynamic, prevent constant folding by not setting
769       // output_shapes.
770       if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() &&
771           xla_shape.is_static()) {
772         fbody->arg_nodes[i]->ClearAttr("_output_shapes");
773         fbody->arg_nodes[i]->AddAttr("_output_shapes",
774                                      std::vector<TensorShape>{tensor_shape});
775       }
776     } else {
777       TensorShape tensor_shape = absl::get<TensorShape>(args[i].shape);
778       fbody->arg_nodes[i]->ClearAttr("_output_shapes");
779       fbody->arg_nodes[i]->AddAttr("_output_shapes",
780                                    std::vector<TensorShape>{tensor_shape});
781     }
782   }
783 
784   std::unique_ptr<Graph> graph = GetGraph(fbody);
785 
786   // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
787   // they are added when the function body is looked up.  Therefore, they don't have
788   // core assignments here.
789   // Attempt to assign a core to each _Retval and _Arg. Chooses the
790   // lowest-numbered core that consumes the argument. We choose the
791   // lowest-numbered core so the assignment is deterministic.
792   for (Node* n : graph->nodes()) {
793     if (n->IsArg()) {
794       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
795     }
796   }
797   // Do _Retval as a second loop, in case the retval's input is an _Arg (which
798   // may have gotten a device assignment from the first loop).
799   for (Node* n : graph->nodes()) {
800     if (n->IsRetval()) {
801       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
802     }
803   }
804 
805   if (VLOG_IS_ON(2)) {
806     VLOG(2) << "XlaCompiler::CompileFunction: "
807             << DumpGraphToFile(
808                    absl::StrCat("xla_compile_function_", function_id), *graph);
809   }
810 
811   VLOG(1) << "====================================================";
812   MlirBridgeRolloutPolicy policy = MlirBridgeRolloutPolicy::kDisabledByUser;
813   if (options.is_entry_computation) {
814     policy = GetMlirBridgeRolloutPolicy(
815         *graph, /*function_library=*/nullptr, config_proto,
816         /*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
817   }
818   if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
819     VLOG(1) << "Using MLIR bridge to compile the function";
820     GraphDebugInfo debug_info;
821 
822     std::vector<std::string> valid_control_rets =
823         GetValidControlRets(fbody->control_ret_nodes, *graph);
824 
825     TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
826         std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
827         valid_control_rets, options_.device_type.type_string(),
828         options.use_tuple_arg, /*analyse_graph=*/false, *options_.flib_def,
829         debug_info, options_.shape_representation_fn, result));
830   } else {
831     VLOG(1) << "Using the old bridge to compile the function";
832     TF_RETURN_IF_ERROR(
833         CompileGraph(options, function_id, std::move(graph), args, result));
834   }
835   VLOG(1) << "====================================================";
836 
837   cache_[{function_id, arg_vector}] = *result;
838   return Status::OK();
839 }
840 
841 // Computes the XLA shape for argument 'arg'.
842 Status XlaCompiler::XLAShapeForArgument(
843     const XlaCompiler::Argument& arg, bool is_entry_computation,
844     const absl::optional<xla::HloSharding>& arg_sharding,
845     xla::Shape* xla_shape) const {
846   switch (arg.kind) {
847     case XlaCompiler::Argument::kConstant:
848       LOG(FATAL) << "Unreachable case";
849     case XlaCompiler::Argument::kParameter: {
850       if (is_entry_computation) {
851         TensorShape shape;
852         if (absl::holds_alternative<TensorShape>(arg.shape)) {
853           shape = absl::get<TensorShape>(arg.shape);
854         } else {
855           TF_RETURN_IF_ERROR(
856               XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &shape));
857         }
858         TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
859                                             shape, arg.type,
860                                             /*use_fast_memory=*/false));
861         TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
862             arg_sharding, /*use_fast_memory=*/false,
863             options_.shape_representation_fn, xla_shape));
864       } else {
865         if (absl::holds_alternative<xla::Shape>(arg.shape)) {
866           *xla_shape = absl::get<xla::Shape>(arg.shape);
867         } else {
868           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
869               arg.type, absl::get<TensorShape>(arg.shape), xla_shape));
870         }
871       }
872       return Status::OK();
873     }
874     case XlaCompiler::Argument::kTensorList: {
875       TF_RET_CHECK(absl::holds_alternative<xla::Shape>(arg.shape));
876       *xla_shape = absl::get<xla::Shape>(arg.shape);
877       return Status::OK();
878     }
879     case XlaCompiler::Argument::kConstantResource:
880     case XlaCompiler::Argument::kResource: {
881       TF_RET_CHECK(arg.initialized);
882 
883       switch (arg.resource_kind) {
884         case XlaResource::kVariable: {
885           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
886           TF_ASSIGN_OR_RETURN(*xla_shape,
887                               options_.shape_representation_fn(
888                                   absl::get<TensorShape>(arg.shape), arg.type,
889                                   /*use_fast_memory=*/arg.fast_mem));
890           TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
891               arg_sharding, arg.fast_mem, options_.shape_representation_fn,
892               xla_shape));
893           return Status::OK();
894         }
895         case XlaResource::kTensorArray: {
896           if (arg.max_array_size < 0) {
897             return errors::InvalidArgument(
898                 "Negative max_array_size in XLAShapeForArgument");
899           }
900           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
901           TensorShape shape;
902           shape.AddDim(arg.max_array_size);
903           shape.AppendShape(absl::get<TensorShape>(arg.shape));
904           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
905 
906           if (!arg.tensor_array_gradients.empty()) {
907             std::vector<xla::Shape> tuple_shape(
908                 arg.tensor_array_gradients.size() + 1, *xla_shape);
909             *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
910           }
911           return Status::OK();
912         }
913         case XlaResource::kStack: {
914           if (arg.max_array_size < 0) {
915             return errors::InvalidArgument(
916                 "Negative max_array_size in XLAShapeForArgument");
917           }
918           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
919           TensorShape shape;
920           shape.AddDim(arg.max_array_size);
921           shape.AppendShape(absl::get<TensorShape>(arg.shape));
922           xla::Shape buffer_shape;
923           TF_RETURN_IF_ERROR(
924               TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
925           *xla_shape = xla::ShapeUtil::MakeTupleShape(
926               {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
927           return Status::OK();
928         }
929 
930         case XlaResource::kInvalid:
931           return errors::Internal(
932               "Invalid resource type in XLAShapeForArgument()");
933       }
934     }
935     case XlaCompiler::Argument::kToken: {
936       *xla_shape = xla::ShapeUtil::MakeTokenShape();
937       return Status::OK();
938     }
939     case XlaCompiler::Argument::kInvalid:
940       return errors::Internal("Invalid argument type in XLAShapeForArgument()");
941   }
942 }
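// Illustrative sketch (hypothetical argument, not from this file): a kResource
// argument of kind kTensorArray with type DT_FLOAT, shape [3], and
// max_array_size 10 yields the XLA shape f32[10,3]; if it also carries one
// tensor-array gradient, the result becomes the tuple (f32[10,3], f32[10,3]).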
943 
944 /* static */
945 void XlaCompiler::PopulateArgumentFromResource(const XlaResource& resource,
946                                                Argument* arg) {
947   arg->initialized = resource.initialized();
948   arg->kind = XlaCompiler::Argument::kResource;
949   arg->resource_kind = resource.kind();
950 
951   arg->type = resource.type();
952   arg->shape = resource.shape();
953   arg->max_array_size = resource.max_array_size();
954   for (const auto& gradient : resource.tensor_array_gradients()) {
955     arg->tensor_array_gradients.insert(gradient.first);
956   }
957   arg->name = resource.name();
958 }
959 
960 // Builds XLA computations for each of the arguments to the computation.
961 // `args` are the arguments to the computation.
962 Status XlaCompiler::BuildArguments(
963     const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
964     bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
965     const std::map<int, xla::OpSharding>& arg_shardings,
966     std::vector<XlaExpression>* arg_expressions,
967     std::vector<int>* input_to_args, std::vector<xla::Shape>* input_shapes,
968     bool is_entry_computation) {
969   arg_expressions->resize(args.size());
970 
971   // Argument numbers of arguments and resources that are to be passed to the
972   // XLA computation as runtime parameters. `input_to_args[a] = b` means that
973   // the a'th XLA input corresponds to the b'th original arg index.
974   input_to_args->clear();
975   input_to_args->reserve(args.size());
976 
977   // Fills in constant arguments, and computes non-constant argument order.
978   for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
979        ++i) {
980     const XlaCompiler::Argument& arg = args[i];
981     XlaExpression& arg_expression = (*arg_expressions)[i];
982     switch (arg.kind) {
983       case XlaCompiler::Argument::kConstantResource:
984       case XlaCompiler::Argument::kResource: {
985         TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
986         TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
987         // TODO(phawkins): this code assumes that resource arguments do not
988         // alias.
989         XlaResource* resource =
990             context->AddResource(absl::make_unique<XlaResource>(
991                 arg.resource_kind, i, arg.name, arg.type,
992                 absl::get<TensorShape>(arg.shape), xla::XlaOp(),
993                 /*max_array_size=*/arg.max_array_size,
994                 /*tensor_array_gradients=*/arg.tensor_array_gradients,
995                 /*tensor_array_multiple_writes_aggregate=*/true,
996                 arg.definition_stack_trace));
997         arg_expression =
998             arg.kind == XlaCompiler::Argument::kResource
999                 ? XlaExpression::Resource(resource)
1000                 : XlaExpression::ConstantResource(arg.constant_value, resource);
1001         if (arg.initialized) {
1002           input_to_args->push_back(i);
1003         }
1004         break;
1005       }
1006       case XlaCompiler::Argument::kParameter:
1007       case XlaCompiler::Argument::kTensorList:
1008       case XlaCompiler::Argument::kToken: {
1009         input_to_args->push_back(i);
1010         break;
1011       }
1012       case XlaCompiler::Argument::kConstant:
1013         arg_expression = XlaExpression::Constant(arg.constant_value);
1014         break;
1015       case XlaCompiler::Argument::kInvalid:
1016         return errors::Internal(
1017             "Unreachable case in BuildArguments() while filling constant args");
1018     }
1019   }
1020 
1021   if (input_to_args->empty() && !use_tuple_arg) {
1022     return Status::OK();
1023   }
1024 
1025   // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds
1026   // to the d'th XLA input. Note that the value -1 corresponds to constants, or
1027   // other args that don't correspond to an input.
1028   std::vector<int> arg_to_inputs(args.size(), -1);
1029   for (int i = 0, end = input_to_args->size(); i < end; i++) {
1030     arg_to_inputs[input_to_args->at(i)] = i;
1031   }
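// For example, with four arguments where argument 1 is a compile-time
// constant, input_to_args is {0, 2, 3} and arg_to_inputs is {0, -1, 1, 2}:
// constants never become XLA inputs, so they keep the sentinel -1.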
1032 
1033   std::vector<xla::Shape> arg_shapes(input_to_args->size());
1034   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1035     // Computes the shapes of non-constant arguments.
1036     auto arg_sharding = arg_shardings.find((*input_to_args)[i]);
1037     absl::optional<xla::HloSharding> sharding;
1038     if (arg_sharding != arg_shardings.end()) {
1039       TF_ASSIGN_OR_RETURN(auto hlo_sharding,
1040                           xla::HloSharding::FromProto(arg_sharding->second));
1041       sharding = hlo_sharding;
1042     }
1043     TF_RETURN_IF_ERROR(XLAShapeForArgument(args[(*input_to_args)[i]],
1044                                            is_entry_computation, sharding,
1045                                            &arg_shapes[i]));
1046   }
1047 
1048   if (use_tuple_arg) {
1049     input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
1050   } else {
1051     *input_shapes = arg_shapes;
1052   }
1053 
1054   // Attach a common operator name as metadata. This has no semantic effect — it
1055   // merely makes the HLO graph more readable when visualized via TensorBoard,
1056   // since TensorBoard forms groups out of operators with similar names.
1057   xla::OpMetadata arg_metadata;
1058   arg_metadata.set_op_name("XLA_Args");
1059   builder->SetOpMetadata(arg_metadata);
1060 
1061   // Build parameter handles for non-constant arguments.
1062   std::vector<xla::XlaOp> arg_handles(input_to_args->size());
1063   if (use_tuple_arg) {
1064     xla::XlaOp tuple;
1065     if (is_entry_computation) {
1066       xla::OpSharding tuple_sharding;
1067       tuple_sharding.set_type(xla::OpSharding::TUPLE);
1068       for (int64_t parameter : *input_to_args) {
1069         auto it = arg_shardings.find(parameter);
1070         *tuple_sharding.add_tuple_shardings() =
1071             it == arg_shardings.end() ? xla::sharding_builder::AssignDevice(0)
1072                                       : it->second;
1073       }
1074       std::vector<bool> is_same_across_replicas;
1075       for (int i = 0, end = input_to_args->size(); i < end; ++i) {
1076         // Add an entry to is_same_across_replicas for every leaf buffer.
1077         is_same_across_replicas.insert(
1078             is_same_across_replicas.end(),
1079             xla::ShapeUtil::GetLeafCount(arg_shapes[i]),
1080             args[input_to_args->at(i)].is_same_data_across_replicas);
1081       }
1082       xla::XlaScopedShardingAssignment assign_tuple_sharding(
1083           builder, input_to_args->empty() ? absl::optional<xla::OpSharding>()
1084                                           : tuple_sharding);
1085       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple",
1086                              is_same_across_replicas);
1087     } else {
1088       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
1089     }
1090 
1091     for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1092       auto it = arg_shardings.find(i);
1093       xla::XlaScopedShardingAssignment assign_sharding(
1094           builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
1095                                              : it->second);
1096       auto& arg = args[input_to_args->at(i)];
1097 
1098       xla::OpMetadata arg_metadata;
1099       arg_metadata.set_op_name(arg.node_name);
1100       builder->SetOneShotOpMetadata(arg_metadata);
1101       arg_handles[i] = xla::GetTupleElement(tuple, i);
1102     }
1103   } else {
1104     for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1105       auto it = arg_shardings.find(i);
1106       xla::XlaScopedShardingAssignment assign_sharding(
1107           builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
1108                                              : it->second);
1109       if (is_entry_computation) {
1110         // Add an entry to is_same_across_replicas for every leaf buffer.
1111         std::vector<bool> is_same_across_replicas(
1112             xla::ShapeUtil::GetLeafCount((*input_shapes)[i]),
1113             args[input_to_args->at(i)].is_same_data_across_replicas);
1114         arg_handles[i] =
1115             xla::Parameter(builder, i, (*input_shapes)[i],
1116                            absl::StrCat("arg", i), is_same_across_replicas);
1117       } else {
1118         arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
1119                                         absl::StrCat("arg", i));
1120       }
1121     }
1122   }
1123 
1124   builder->ClearOpMetadata();
1125 
1126   // Fill in the handles in non-constant arguments, and reshape parameters
1127   // back to their correct shapes.
1128   VLOG(2) << "XLA computation inputs:";
1129   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1130     const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
1131     VLOG(2) << "  XLA arg " << i
1132             << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
1133             << " name: " << arg.name << " TF arg " << input_to_args->at(i)
1134             << " node name: " << arg.node_name
1135             << (arg_shardings.find(i) == arg_shardings.end()
1136                     ? ""
1137                     : absl::StrCat(" sharding: ",
1138                                    arg_shardings.at(i).DebugString()));
1139     XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
1140     switch (arg.kind) {
1141       case XlaCompiler::Argument::kConstantResource:
1142       case XlaCompiler::Argument::kResource: {
1143         TF_RET_CHECK(arg.initialized);
1144         XlaResource* resource = arg_expression.resource();
1145         TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
1146                                                  arg_handles[i], builder));
1147         VLOG(2) << "    resource: num_gradients: "
1148                 << arg.tensor_array_gradients.size();
1149         break;
1150       }
1151       case XlaCompiler::Argument::kParameter:
1152         // Reshape parameters back to their correct shapes.
1153         // TODO(b/76097077): propagate device assignments onto arguments and
1154         // return values of functions, and then reshape unconditionally.
1155         if (is_entry_computation) {
1156           arg_expression = XlaExpression::XlaOp(
1157               xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
1158         } else {
1159           arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1160           if (arg.value_bound) {
1161             TF_RET_CHECK(arg.value_dynamism);
1162             // Propagate upper bound and value dynamism to arg_expression.
1163             arg_expression.set_value_bound(arg.value_bound.value());
1164             arg_expression.set_value_dynamism(arg.value_dynamism.value());
1165           }
1166         }
1167         break;
1168       case XlaCompiler::Argument::kTensorList: {
1169         arg_expression = XlaExpression::TensorList(arg_handles[i]);
1170         break;
1171       }
1172       case XlaCompiler::Argument::kToken: {
1173         arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1174         break;
1175       }
1176       case XlaCompiler::Argument::kConstant:
1177       case XlaCompiler::Argument::kInvalid:
1178         return errors::Internal(
1179             "Unreachable case in BuildArguments() while filling handles");
1180     }
1181   }
1182 
1183   return Status::OK();
1184 }
1185 
1186 namespace {
1187 
1188 // Check that the ops of all non-functional nodes have been registered.
1189 Status ValidateFunctionDef(const FunctionDef* fdef,
1190                            const FunctionLibraryDefinition& flib_def) {
1191   for (const NodeDef& node : fdef->node_def()) {
1192     const string& op = node.op();
1193     if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
1194       continue;
1195     }
1196     const OpDef* op_def;
1197     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
1198   }
1199   return Status::OK();
1200 }
1201 
1202 // If node is PartitionedCall or StatefulPartitionedCall, returns the
1203 // name from the "f" attr, else returns node.def().op().
1204 // Returned pointer points to the internal string either in node's attributes
1205 // or in its NodeDef. This pointer is valid as long as the node has not been
1206 // modified.
1207 Status GetPotentialFunctionName(const Node& node, const string** name) {
1208   if (node.IsPartitionedCall()) {
1209     const AttrValue* attr_value;
1210     TF_RETURN_IF_ERROR(
1211         node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
1212     if (!attr_value->has_func()) {
1213       return errors::InvalidArgument(
1214           "The attribute value for attribute 'f' in node ", node.DebugString(),
1215           " does not have 'func' field set");
1216     }
1217     *name = &attr_value->func().name();
1218     return Status::OK();
1219   }
1220   *name = &node.type_string();
1221   return Status::OK();
1222 }
1223 
1224 // Check that the graph doesn't have any invalid nodes (e.g. incompatible with
1225 // given device_type, invalid data type, missing attributes...)
1226 Status ValidateGraph(const Graph* graph,
1227                      const FunctionLibraryDefinition& flib_def,
1228                      const DeviceType& device_type, const string& name) {
1229   // Make sure the XLA compilation kernels are registered.  This operation is
1230   // idempotent so it is fine if someone called it already.
1231   XlaOpRegistry::RegisterCompilationKernels();
1232 
1233   auto maybe_error = [&](const Node* node, const Status& s) -> Status {
1234     if (!s.ok()) {
1235       std::string errmsg = absl::StrCat(
1236           "Detected unsupported operations when trying to compile graph ", name,
1237           " on ", device_type.type_string(), ": ", node->def().op(), " (",
1238           s.error_message(), ")", FormatNodeForError(*node));
1239       if (absl::StrContains(device_type.type_string(), "TPU")) {
1240         absl::StrAppend(&errmsg,
1241                         "\nOne approach is to outside compile the unsupported "
1242                         "ops to run on CPUs by enabling soft placement "
1243                         "`tf.config.set_soft_device_placement(True)`."
1244                         " This has a potential performance penalty.\n");
1245       }
1246       if (std::shared_ptr<AbstractStackTrace> stack_trace =
1247               node->GetStackTrace()) {
1248         absl::StrAppend(
1249             &errmsg, "\nThe op is created at: \n",
1250             stack_trace->ToString({/*show_line_contents =*/true,
1251                                    /*filter_common_prefix =*/true,
1252                                    /*drop_internal_frames =*/true}));
1253       }
1254 
1255       return errors::InvalidArgument(errmsg);
1256     }
1257     return Status::OK();
1258   };
1259 
1260   for (const Node* node : graph->nodes()) {
1261     if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
1262       continue;
1263     }
1264     const string* function_name;
1265     TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
1266     const FunctionDef* fdef = flib_def.Find(*function_name);
1267     Status s;
1268     if (fdef) {
1269       s = ValidateFunctionDef(fdef, flib_def);
1270       TF_RETURN_IF_ERROR(maybe_error(node, s));
1271       continue;
1272     }
1273     const OpDef* op_def;
1274     s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
1275     TF_RETURN_IF_ERROR(maybe_error(node, s));
1276     TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
1277     s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
1278     TF_RETURN_IF_ERROR(maybe_error(node, s));
1279   }
1280   return Status::OK();
1281 }
1282 
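// Rewrites every constant expression in `expressions` in place as an
// equivalent xla::XlaOp built on `builder`; non-constant expressions are
// left untouched.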
1283 void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
1284                                    absl::Span<XlaExpression> expressions) {
1285   for (XlaExpression& expression : expressions) {
1286     if (expression.kind() == XlaExpression::Kind::kConstant) {
1287       expression =
1288           XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
1289     }
1290   }
1291 }
1292 
1293 }  // namespace
1294 
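// Compiles `graph` into an xla::XlaComputation: validates the graph, builds
// XLA parameters for `args` (optionally threading a token input/output for
// side-effecting ops), executes the graph symbolically to populate the
// XlaBuilder, and fills `result` with the computation, input/output shapes
// and resource updates.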
1295 Status XlaCompiler::CompileGraph(
1296     const XlaCompiler::CompileOptions& options, string const& name,
1297     std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
1298     CompilationResult* result) {
1299   VLOG(1) << "Executing graph symbolically to populate XlaBuilder: " << name;
1300 
1301   TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
1302       graph.get(), options_.flib_def, local_flib_def_.get()));
1303   TF_RETURN_IF_ERROR(RearrangeFunctionArguments(
1304       [this](const NameAttrList& function, const FunctionBody** fbody) {
1305         return FindFunctionBody(function, fbody);
1306       },
1307       graph.get(), local_flib_def_.get(),
1308       pflr_->GetFunctionLibraryDefinition()));
1309 
1310   if (VLOG_IS_ON(2)) {
1311     VLOG(2) << "XlaCompiler::CompileGraph: "
1312             << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
1313                                flib_runtime_->GetFunctionLibraryDefinition());
1314   }
1315 
1316   // Report the error here if initialization failed.
1317   TF_RETURN_IF_ERROR(initialization_status_);
1318 
1319   // Detect invalid nodes.
1320   // FunctionalizeControlFlow may remove some nodes from the graph.
1321   TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
1322                                    options_.device_type, name));
1323   xla::XlaBuilder builder(name);
1324   XlaContext* context = new XlaContext(this, &builder, graph.get());
1325   core::ScopedUnref context_unref(context);
1326 
1327   std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
1328   int token_input_index = -1;
1329   std::unique_ptr<xla::XlaOp> token_output;
1330   if (options.add_token_input_output) {
1331     // Add extra token input.
1332     token_input_index = real_args.size();
1333 
1334     XlaCompiler::Argument token_arg;
1335     token_arg.kind = XlaCompiler::Argument::kToken;
1336     real_args.push_back(token_arg);
1337   }
1338 
1339   std::map<int, xla::OpSharding> arg_shardings;
1340   std::map<int, xla::OpSharding> retval_shardings;
1341   TF_ASSIGN_OR_RETURN(std::tie(arg_shardings, retval_shardings),
1342                       ComputeArgAndRetvalShardings(*graph));
1343 
1344   std::vector<XlaExpression> arg_expressions;
1345   TF_RETURN_IF_ERROR(BuildArguments(
1346       *graph, real_args, options.use_tuple_arg, &builder, context,
1347       arg_shardings, &arg_expressions, &result->input_mapping,
1348       &result->xla_input_shapes, options.is_entry_computation));
1349   context->set_args(std::move(arg_expressions));
1350 
1351   PushNodeTokenMapping();
1352   // Use std::set instead of std::unordered_set to ensure determinism.
1353   std::set<std::string> output_node_token_inputs;
1354   if (token_input_index != -1) {
1355     // Original token comes from input.
1356     auto arg_expression = context->args()[token_input_index];
1357     TF_RETURN_IF_ERROR(
1358         SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
1359 
1360     // Calculate token inputs for output token.
1361     output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
1362 
1363     // If there's no side-effecting op in the graph, use token input as token
1364     // output.
1365     if (output_node_token_inputs.empty()) {
1366       output_node_token_inputs.insert(kXlaTokenArgNodeName);
1367     }
1368   } else if (options.is_entry_computation) {
1369     // Original token is manually created.
1370     if (HasSideEffectingNodes(*graph)) {
1371       TF_RETURN_IF_ERROR(
1372           SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
1373     }
1374   }
1375 
1376   Status execute_status = ExecuteGraph(context, std::move(graph), device_,
1377                                        flib_runtime_, NextStepId());
1378   if (!execute_status.ok()) {
1379     VLOG(1) << "Failed executing graph " << name;
1380     return execute_status;
1381   }
1382   if (token_input_index != -1) {
1383     // Add extra token output.
1384     std::vector<xla::XlaOp> token_inputs;
1385     for (const auto& node_name : output_node_token_inputs) {
1386       auto token_or = GetNodeToken(node_name);
1387       TF_RETURN_IF_ERROR(token_or.status());
1388       token_inputs.push_back(token_or.ValueOrDie());
1389     }
1390     token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs)));
1391   }
1392   TF_RETURN_IF_ERROR(PopNodeTokenMapping());
1393 
1394   int num_nonconst_outputs;
1395   int num_computation_outputs;
1396   result->computation = std::make_shared<xla::XlaComputation>();
1397   result->outputs.resize(context->retvals().size());
1398   std::vector<XlaExpression> retvals = context->retvals();
1399   ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
1400   TF_RETURN_IF_ERROR(BuildComputation(
1401       real_args, retvals, arg_shardings, retval_shardings, context->resources(),
1402       std::move(token_output),
1403       options.is_entry_computation ? options_.shape_representation_fn
1404                                    : ShapeRepresentationFn{},
1405       options.is_entry_computation,
1406       options.return_updated_values_for_all_resources,
1407       options.always_return_tuple, options.use_tuple_arg,
1408       options.alias_resource_update, &builder, result->computation.get(),
1409       &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
1410       &result->resource_updates, &result->xla_output_shape,
1411       result->input_mapping));
1412 
1413   VLOG(2) << "Outputs: total: " << context->retvals().size()
1414           << " nonconstant: " << num_nonconst_outputs;
1415   VLOG(2) << "XLA output shape: "
1416           << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape);
1417   result->collective_reduce_info = context->GetCollectiveReduceV2OpInfo();
1418   return Status::OK();
1419 }
1420 
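// Returns the xla::ChannelHandle associated with `key`, creating one via the
// client and caching it in `channels_` on first use.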
1421 Status XlaCompiler::GetChannelHandle(const string& key,
1422                                      xla::ChannelHandle* channel) {
1423   auto result = channels_.emplace(key, xla::ChannelHandle());
1424   if (result.second) {
1425     TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
1426   }
1427   *channel = result.first->second;
1428   VLOG(1) << "Channel: " << key << " " << channel->DebugString();
1429   return Status::OK();
1430 }
1431 
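// As GetChannelHandle, but the handle created on first use is a
// host-to-device channel.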
1432 Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
1433                                                  xla::ChannelHandle* channel) {
1434   auto result = channels_.emplace(key, xla::ChannelHandle());
1435   if (result.second) {
1436     TF_ASSIGN_OR_RETURN(result.first->second,
1437                         client()->CreateHostToDeviceChannelHandle());
1438   }
1439   *channel = result.first->second;
1440   VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
1441   return Status::OK();
1442 }
1443 
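// As GetChannelHandle, but the handle created on first use is a
// device-to-host channel.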
1444 Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
1445                                                  xla::ChannelHandle* channel) {
1446   auto result = channels_.emplace(key, xla::ChannelHandle());
1447   if (result.second) {
1448     TF_ASSIGN_OR_RETURN(result.first->second,
1449                         client()->CreateDeviceToHostChannelHandle());
1450   }
1451   *channel = result.first->second;
1452   VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
1453   return Status::OK();
1454 }
1455 
1456 namespace {
1457 
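// Fills `transfer` with `key` and one TensorMetadata entry per (type, shape)
// pair. `types` and `shapes` must have the same length.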
1458 void SetTransfer(const string& key, absl::Span<const DataType> types,
1459                  absl::Span<const TensorShape> shapes,
1460                  tf2xla::HostTransferMetadata* transfer) {
1461   transfer->set_key(key);
1462   CHECK_EQ(types.size(), shapes.size());
1463   for (int i = 0, end = types.size(); i < end; ++i) {
1464     tf2xla::TensorMetadata* metadata = transfer->add_metadata();
1465     metadata->set_type(types[i]);
1466     shapes[i].AsProto(metadata->mutable_shape());
1467   }
1468 }
1469 
1470 }  // namespace
1471 
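// Registers the types and shapes of a device-to-host transfer under `key`.
// A repeated call with identical metadata is a no-op; conflicting metadata
// for the same key is an InvalidArgument error.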
1472 Status XlaCompiler::SetDeviceToHostMetadata(
1473     const string& key, absl::Span<const DataType> types,
1474     absl::Span<const TensorShape> shapes) {
1475   if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
1476     tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key];
1477     tf2xla::HostTransferMetadata new_transfer;
1478     SetTransfer(key, types, shapes, &new_transfer);
1479     if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
1480       return Status::OK();
1481     } else {
1482       return errors::InvalidArgument(
1483           "Duplicate calls to SetDeviceToHostMetadata with key ", key);
1484     }
1485   }
1486   tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
1487   SetTransfer(key, types, shapes, &transfer);
1488   return Status::OK();
1489 }
1490 
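// Returns the shapes previously registered with SetDeviceToHostMetadata for
// `key`, or InvalidArgument if no metadata was registered.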
1491 Status XlaCompiler::GetDeviceToHostShapes(
1492     const string& key, std::vector<TensorShape>* shapes) const {
1493   const auto iter = host_compute_sends_.find(key);
1494   if (iter == host_compute_sends_.end()) {
1495     return errors::InvalidArgument(
1496         "No host compute send shapes registered for key ", key);
1497   }
1498   shapes->clear();
1499   for (int i = 0; i < iter->second.metadata_size(); ++i) {
1500     TensorShape shape(iter->second.metadata(i).shape());
1501     shapes->push_back(shape);
1502   }
1503   return Status::OK();
1504 }
1505 
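// Registers the types and shapes of a host-to-device transfer under `key`,
// with the same duplicate-call semantics as SetDeviceToHostMetadata.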
1506 Status XlaCompiler::SetHostToDeviceMetadata(
1507     const string& key, absl::Span<const DataType> types,
1508     absl::Span<const TensorShape> shapes) {
1509   if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
1510     tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key];
1511     tf2xla::HostTransferMetadata new_transfer;
1512     SetTransfer(key, types, shapes, &new_transfer);
1513     if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
1514       return Status::OK();
1515     } else {
1516       return errors::InvalidArgument(
1517           "Duplicate calls to SetHostToDeviceMetadata with key ", key);
1518     }
1519   }
1520   tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
1521   SetTransfer(key, types, shapes, &transfer);
1522   return Status::OK();
1523 }
1524 
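// Returns the control-dependency handle registered for the host compute op
// `host_compute_name`, or InvalidArgument if none has been registered.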
1525 Status XlaCompiler::GetHostComputeControlDependency(
1526     const string& host_compute_name, xla::XlaOp* handle) {
1527   const auto iter = host_compute_control_output_.find(host_compute_name);
1528   if (iter == host_compute_control_output_.end()) {
1529     return errors::InvalidArgument(
1530         "No registered control handle for host compute Op '", host_compute_name,
1531         "'");
1532   } else {
1533     *handle = iter->second;
1534   }
1535   return Status::OK();
1536 }
1537 
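// Registers `handle` as the control-dependency output of the host compute op
// `host_compute_name`; at most one handle may be registered per op.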
1538 Status XlaCompiler::SetHostComputeControlDependency(
1539     const string& host_compute_name, const xla::XlaOp& handle) {
1540   if (host_compute_control_output_.find(host_compute_name) !=
1541       host_compute_control_output_.end()) {
1542     return errors::InvalidArgument(
1543         "Duplicate control handles registered for host compute Op ",
1544         host_compute_name);
1545   }
1546   host_compute_control_output_[host_compute_name] = handle;
1547   return Status::OK();
1548 }
1549 
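// Pushes a fresh, empty node-name-to-token map onto the mapping stack.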
1550 void XlaCompiler::PushNodeTokenMapping() {
1551   node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
1552 }
1553 
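// Pops the most recently pushed node-token map; it is an error to pop when
// the stack is empty.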
1554 Status XlaCompiler::PopNodeTokenMapping() {
1555   if (node_token_mapping_stack_.empty()) {
1556     return errors::FailedPrecondition(
1557         "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
1558         "empty.");
1559   }
1560   node_token_mapping_stack_.pop();
1561   return Status::OK();
1562 }
1563 
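// Records `op` as the token produced by `node_name` in the current (top)
// mapping; fails if the stack is empty or the node already has a token.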
1564 Status XlaCompiler::SetNodeToken(const string& node_name,
1565                                  const xla::XlaOp& op) {
1566   if (node_token_mapping_stack_.empty()) {
1567     return errors::FailedPrecondition(
1568         "Calling SetNodeToken() when node_token_mapping_stack_ is "
1569         "empty.");
1570   }
1571   auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
1572   if (!insert_result.second) {
1573     return errors::FailedPrecondition("Token mapping already exists for node ",
1574                                       node_name);
1575   }
1576   return Status::OK();
1577 }
1578 
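// Returns the token previously recorded for `node_name` with SetNodeToken;
// fails if the stack is empty or no token was recorded.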
1579 StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
1580   if (node_token_mapping_stack_.empty()) {
1581     return errors::FailedPrecondition(
1582         "Calling GetNodeToken() when node_token_mapping_stack_ is "
1583         "empty.");
1584   }
1585   auto iter = node_token_mapping_stack_.top().find(node_name);
1586   if (iter == node_token_mapping_stack_.top().end()) {
1587     return errors::FailedPrecondition("Cannot find token mapping for node ",
1588                                       node_name);
1589   }
1590   return iter->second;
1591 }
1592 
1593 }  // namespace tensorflow
1594