1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17 
18 #include <numeric>
19 #include <vector>
20 
21 #include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "absl/types/variant.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/jit/flags.h"
27 #include "tensorflow/compiler/jit/shape_inference.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
29 #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
30 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
31 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
32 #include "tensorflow/compiler/tf2xla/shape_util.h"
33 #include "tensorflow/compiler/tf2xla/sharding_util.h"
34 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
35 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
36 #include "tensorflow/compiler/tf2xla/type_util.h"
37 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
38 #include "tensorflow/compiler/tf2xla/xla_context.h"
39 #include "tensorflow/compiler/xla/client/client_library.h"
40 #include "tensorflow/compiler/xla/client/xla_builder.h"
41 #include "tensorflow/compiler/xla/client/xla_computation.h"
42 #include "tensorflow/compiler/xla/protobuf_util.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/core/common_runtime/device.h"
46 #include "tensorflow/core/common_runtime/executor.h"
47 #include "tensorflow/core/common_runtime/function.h"
48 #include "tensorflow/core/common_runtime/graph_constructor.h"
49 #include "tensorflow/core/common_runtime/graph_optimizer.h"
50 #include "tensorflow/core/framework/attr_value_util.h"
51 #include "tensorflow/core/framework/function.h"
52 #include "tensorflow/core/framework/node_def_util.h"
53 #include "tensorflow/core/framework/types.h"
54 #include "tensorflow/core/graph/node_builder.h"
55 #include "tensorflow/core/lib/core/errors.h"
56 #include "tensorflow/core/lib/gtl/cleanup.h"
57 #include "tensorflow/core/lib/hash/hash.h"
58 #include "tensorflow/core/platform/logging.h"
59 #include "tensorflow/core/protobuf/error_codes.pb.h"
60 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
61 #include "tensorflow/core/util/dump_graph.h"
62 
63 namespace tensorflow {
64 namespace {
65 
66 // Checks that arguments `args` match types `types`.
67 Status CheckSignature(const DataTypeVector& types,
68                       absl::Span<const XlaCompiler::Argument> args) {
69   if (args.size() != types.size()) {
70     return errors::Internal("Compilation arguments have ", args.size(),
71                             " elements while function has ", types.size());
72   }
73   for (int i = 0, end = types.size(); i < end; ++i) {
74     // Don't perform type checks on resource variables and tensor
75     // lists (DT_VARIANT) as we have to trick the type system in order to
76     // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor.
77     if (types[i] != args[i].type && types[i] != DT_RESOURCE &&
78         types[i] != DT_VARIANT) {
79       return errors::Internal(
80           "Argument ", i, " has declared type ", DataTypeString(args[i].type),
81           " but function parameter has type ", DataTypeString(types[i]));
82     }
83   }
84   return Status::OK();
85 }
86 
87 // Uses the _Arg and _Retval nodes in the graph to determine an OpSharding for
88 // each argument and return value.
89 xla::StatusOr<
90     std::pair<std::map<int, xla::OpSharding>, std::map<int, xla::OpSharding>>>
91 ComputeArgAndRetvalShardings(const Graph& graph) {
92   auto get_sharding_for_node =
93       [](const Node* n) -> xla::StatusOr<absl::optional<xla::OpSharding>> {
94     TF_ASSIGN_OR_RETURN(
95         auto sharding,
96         ParseShardingFromDevice(*n, std::numeric_limits<int32>::max(),
97                                 /*add_metadata=*/false));
98     return sharding;
99   };
100   std::map<int, xla::OpSharding> arg_shardings;
101   std::map<int, xla::OpSharding> retval_shardings;
102   for (const Node* n : graph.nodes()) {
103     if (n->IsArg()) {
104       TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
105       if (!sharding.has_value()) continue;
106       int index;
107       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
108       TF_RET_CHECK(index >= 0) << "Negative _Arg index";
109       arg_shardings[index] = std::move(*sharding);
110     } else if (n->IsRetval()) {
111       TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n));
112       if (!sharding.has_value()) continue;
113       int index;
114       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
115       TF_RET_CHECK(index >= 0) << "Negative _Retval index";
116       retval_shardings[index] = std::move(*sharding);
117     }
118   }
119   return std::make_pair(std::move(arg_shardings), std::move(retval_shardings));
120 }
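// For illustration only (device strings and indices below are hypothetical):
// an _Arg node with attribute index=0 that was placed on
// "/device:TPU_REPLICATED_CORE:1" would typically be parsed by
// ParseShardingFromDevice into a maximal sharding pinned to core 1, i.e.
//
//   arg_shardings[0] = xla::sharding_builder::AssignDevice(1);
//
// _Arg/_Retval nodes with no core placement and no explicit sharding attribute
// are simply omitted from the maps, and callers fall back to replicated
// sharding for them.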
121 
122 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
123                     XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
124                     int64 step_id) {
125   // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
126   // resource manager takes ownership via Create, and unrefs via Cleanup.  We
127   // explicitly add a reference to ensure the refcount at entry is maintained at
128   // all exit points; Create and Cleanup are always called in this function.
129   //
130   // The Executor requires us to use ScopedStepContainer. We wrap it in a
131   // unique_ptr so we can capture the cleanup status in the end.
132   xla_context->Ref();
133   Status status;
134   auto step_container = absl::make_unique<ScopedStepContainer>(
135       step_id, [&status, device](const string& name) {
136         status = device->resource_manager()->Cleanup(name);
137       });
138   TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(),
139                                             XlaContext::kXlaContextResourceName,
140                                             xla_context));
141 
142   GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
143   TF_RETURN_IF_ERROR(graph_compiler.Compile());
144   // Explicitly clean up the step container, to capture the cleanup status.
145   step_container.reset();
146   return status;
147 }
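// The cleanup-status idiom above is worth calling out: ScopedStepContainer's
// cleanup callback runs when the container is destroyed, so resetting the
// unique_ptr before returning lets the captured `status` observe any
// Cleanup() failure. A minimal standalone sketch of the same pattern (the
// resource manager `rm` and `step_id` below are placeholders):
//
//   Status status;
//   auto container = absl::make_unique<ScopedStepContainer>(
//       step_id, [&status, rm](const string& name) {
//         status = rm->Cleanup(name);
//       });
//   // ... run work that creates per-step resources in `rm` ...
//   container.reset();  // runs the cleanup callback, populating `status`.
//   return status;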
148 
149 // Builds the XLA computation.
150 // - `args` is the list of input arguments
151 // - `retvals` is the list of retvals produced by _Retval operators, in index
152 //   order.
153 // - `arg_shardings` and `retval_shardings` are mapping from arg/return indices
154 //   to sharding.
155 // - If `return_updated_values_for_all_resources` is true, all resources will be
156 //   included in `resource_updates`, regardless of whether their value changed.
157 // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
158 // - Sets `*resource_updates` to a description of resources whose values are
159 //   written by the computation; the variable writes are the last
160 //   `resource_updates.size()` return values from the computation. Each entry in
161 //   `resource_updates` is a ResourceUpdate, whose `index` is the index of a
162 //   resource variable argument to the computation to be updated, and `type` is
163 //   the type of the final output.
164 Status BuildComputation(
165     const std::vector<XlaCompiler::Argument>& args,
166     const std::vector<XlaExpression>& retvals,
167     const std::map<int, xla::OpSharding>& arg_shardings,
168     const std::map<int, xla::OpSharding>& retval_shardings,
169     const std::vector<std::unique_ptr<XlaResource>>& resources,
170     std::unique_ptr<xla::XlaOp> token_output,
171     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
172     bool is_entry_computation, bool return_updated_values_for_all_resources,
173     bool always_return_tuple, bool use_tuple_arg, bool alias_resource_update,
174     xla::XlaBuilder* builder, xla::XlaComputation* computation,
175     int* num_computation_outputs, int* num_nonconst_outputs,
176     std::vector<XlaCompiler::OutputDescription>* outputs,
177     std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
178     xla::Shape* output_shape, absl::Span<int const> input_mapping) {
179   // Attach a common operator name as metadata. This has no semantic effect — it
180   // merely makes the HLO graph more readable when visualized via TensorBoard,
181   // since TensorBoard forms groups out of operators with similar names.
182   xla::OpMetadata retval_metadata;
183   retval_metadata.set_op_name("XLA_Retvals");
184   builder->SetOpMetadata(retval_metadata);
185   VLOG(1) << "Building new computation";
186   auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
187 
188   // Builds a no-op XLA computation. We need to set the sharding of outputs, but
189   // cannot change the sharding of the existing output op. To do this, we build
190   // a new identity op to which shardings can be applied.
191   auto identity_op = [builder](xla::XlaOp op) {
192     return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
193   };
194 
195   std::vector<xla::XlaOp> elems;
196   elems.reserve(retvals.size());
197 
198   // Keeps track of sharding of each retval. If a retval is not in this list,
199   // replicate sharding is used. The first element is the output index, second
200   // element is the sharding.
201   std::unordered_map<int, xla::OpSharding> retval_index_and_sharding;
202   for (int i = 0, end = retvals.size(); i < end; ++i) {
203     XlaCompiler::OutputDescription& output = (*outputs)[i];
204     const XlaExpression& retval = retvals[i];
205     output.type = retval.dtype();
206     switch (retval.kind()) {
207       case XlaExpression::Kind::kConstant:
208         output.is_constant = true;
209         output.constant_value = *retval.constant_value();
210         output.shape = output.constant_value.shape();
211         break;
212 
213       case XlaExpression::Kind::kTensorList: {
214         output.is_tensor_list = true;
215         xla::XlaOp value = retval.handle();
216         elems.push_back(value);
217         break;
218       }
219 
220       case XlaExpression::Kind::kXlaOp: {
221         output.is_constant = false;
222         TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
223         xla::XlaOp value = retval.handle();
224         auto it = retval_shardings.find(i);
225         absl::optional<xla::OpSharding> sharding =
226             it == retval_shardings.end() ? absl::optional<xla::OpSharding>()
227                                          : it->second;
228         if (it != retval_shardings.end()) {
229           retval_index_and_sharding[elems.size()] = it->second;
230         }
231         if (shape_representation_fn) {
232           TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value));
233           TF_ASSIGN_OR_RETURN(value,
234                               ReshapeWithCorrectRepresentationAndSharding(
235                                   builder, value, original_shape,
236                                   shape_representation_fn, sharding,
237                                   /*fast_mem=*/false));
238         }
239         if (it != retval_shardings.end()) {
240           xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
241           // Apply the sharding to the output, if there is a core assignment.
242           value = identity_op(value);
243         }
244 
245         elems.push_back(value);
246         break;
247       }
248 
249       case XlaExpression::Kind::kResource:
250         // Resources will be pushed into elems later when processing resource
251         // arguments below.
252         output.is_constant = false;
253         output.input_index = retval.resource()->arg_num();
254         output.shape = retval.resource()->shape();
255         break;
256 
257       case XlaExpression::Kind::kInvalid:
258         return errors::InvalidArgument(
259             "Invalid expression returned by computation. "
260             "This probably means a return value was not set.");
261     }
262   }
263   *num_nonconst_outputs = elems.size();
264 
265   // Add return values for resources whose values have changed.
266   std::vector<const XlaResource*> arg_resources;
267   arg_resources.reserve(resources.size());
268   for (const auto& resource : resources) {
269     if (resource->arg_num() >= 0) {
270       arg_resources.push_back(resource.get());
271     }
272   }
273   std::sort(arg_resources.begin(), arg_resources.end(),
274             [](const XlaResource* a, const XlaResource* b) {
275               return a->arg_num() < b->arg_num();
276             });
277 
278   absl::flat_hash_map<int, int> argument_to_xla_arg;
279   for (int xla_arg = 0; xla_arg < input_mapping.size(); xla_arg++) {
280     argument_to_xla_arg[input_mapping[xla_arg]] = xla_arg;
281   }
282 
283   std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
284   for (const XlaResource* resource : arg_resources) {
285     DCHECK_LT(resource->arg_num(), args.size());
286     const XlaCompiler::Argument& arg = args[resource->arg_num()];
287     auto it = arg_shardings.find(resource->arg_num());
288     bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
289     // TensorArray gradients were modified if their values changed or there are
290     // any newly created gradients.
291     for (const auto& grad : resource->tensor_array_gradients()) {
292       modified =
293           modified ||
294           !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
295           arg.tensor_array_gradients.count(grad.first) == 0;
296     }
297 
298     if (return_updated_values_for_all_resources || modified) {
299       resource_updates->emplace_back();
300       XlaCompiler::ResourceUpdate& update = resource_updates->back();
301       update.input_index = resource->arg_num();
302       update.type = resource->type();
303       update.shape = resource->shape();
304       update.modified = modified;
305       int param_num = use_tuple_arg ? 0 : update.input_index;
306       if (is_entry_computation &&
307           arg.resource_kind != XlaResource::kTensorArray &&
308           alias_resource_update && argument_to_xla_arg.count(param_num)) {
309         // Assuming tuple arg and results are used.
310         xla::ShapeIndex param_index =
311             use_tuple_arg ? xla::ShapeIndex({update.input_index})
312                           : xla::ShapeIndex{};
313         int xla_param_num = argument_to_xla_arg[param_num];
314         int64 output_index_num = elems.size();
315         xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
316         VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
317                 << xla_param_num << ", " << param_index.ToString() << ")";
318         aliases.push_back({output_index, xla_param_num, param_index});
319       }
320       for (const auto& grad : resource->tensor_array_gradients()) {
321         update.tensor_array_gradients_accessed.insert(grad.first);
322       }
323 
324       xla::XlaOp handle;
325       TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
326       auto sharding = it == arg_shardings.end()
327                           ? absl::optional<xla::OpSharding>()
328                           : it->second;
329       // Set layout of the retval to device representation layout.
330       if (shape_representation_fn) {
331         TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle));
332         TF_ASSIGN_OR_RETURN(
333             handle, ReshapeWithCorrectRepresentationAndSharding(
334                         builder, handle, original_shape,
335                         shape_representation_fn, sharding, arg.fast_mem));
336       }
337 
338       // Request that the value be returned on a specific core.
339       xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
340       if (it != arg_shardings.end()) {
341         retval_index_and_sharding[elems.size()] = it->second;
342       }
343       // Ensures the correct sharding is applied to the output.
344       handle = identity_op(handle);
345       elems.push_back(handle);
346     }
347   }
348 
349   // If we have token output, append it as the last one.
350   if (token_output) {
351     elems.push_back(*token_output);
352   }
353 
354   *num_computation_outputs = elems.size();
355 
356   // Builds the XLA computation. We *always* form a tuple here to ensure that
357   // the output value is the last thing added into the XLA computation, even
358   // if there is only one output value.
359   xla::XlaOp tuple;
360   if (retval_index_and_sharding.empty() || !is_entry_computation) {
361     tuple = xla::Tuple(builder, elems);
362   } else {
363     std::vector<xla::Shape> elem_shapes;
364     for (const auto& elem : elems) {
365       TF_ASSIGN_OR_RETURN(xla::Shape elem_shape,
366                           elem.builder()->GetShape(elem));
367       elem_shapes.push_back(elem_shape);
368     }
369     xla::Shape shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
370     // Copy specified sharding from retval_index_and_sharding.
371     std::vector<xla::HloSharding> sharding_elems;
372     for (int i = 0, end = elems.size(); i < end; i++) {
373       const auto& iter = retval_index_and_sharding.find(i);
374       TF_RET_CHECK(iter != retval_index_and_sharding.end());
375       const xla::OpSharding& sub_op_sharding = iter->second;
376       TF_ASSIGN_OR_RETURN(xla::HloSharding sub_sharding,
377                           xla::HloSharding::FromProto(sub_op_sharding));
378       if (elem_shapes[i].IsTuple()) {
379         const std::vector<xla::HloSharding> sub_sharding_elems =
380             sub_sharding.tuple_elements();
381         const int64 sub_sharding_elems_size = sub_sharding_elems.size();
382         TF_RET_CHECK(sub_sharding_elems_size ==
383                      xla::ShapeUtil::GetLeafCount(elem_shapes[i]));
384         for (const auto& sub_sharding_elem : sub_sharding_elems) {
385           sharding_elems.push_back(sub_sharding_elem);
386         }
387       } else {
388         sharding_elems.push_back(sub_sharding);
389       }
390     }
391     xla::HloSharding modified_sharding =
392         xla::HloSharding::Tuple(shape, sharding_elems);
393     xla::OpSharding op_sharding = modified_sharding.ToProto();
394     // Assign proper sharding to the tuple instruction.
395     xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
396     tuple = xla::Tuple(builder, elems);
397   }
398   bool returns_tuple = always_return_tuple || elems.size() != 1;
399   VLOG(3) << "Computation returns a tuple=" << returns_tuple;
400   if (!returns_tuple) {
401     xla::GetTupleElement(tuple, 0);
402 
403     for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
404       if (alias.output_index == xla::ShapeIndex({0})) {
405         VLOG(3) << "For aliased parameter " << alias.param_number << ": "
406                 << alias.param_index.ToString()
407                 << " normalizing output_index from {0} to {}, as a scalar is "
408                    "returned from the cluster";
409         alias.output_index = xla::ShapeIndex({});
410       }
411     }
412   }
413 
414   for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
415     builder->SetUpAlias(alias.output_index, alias.param_number,
416                         alias.param_index);
417   }
418 
419   xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
420   if (!computation_status.ok()) {
421     return computation_status.status();
422   }
423   *computation = computation_status.ConsumeValueOrDie();
424 
425   TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape());
426   *output_shape = program_shape.result();
427   return Status::OK();
428 }
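// A sketch of the resulting output layout under a common entry-computation
// configuration (always_return_tuple=true, alias_resource_update=true); the
// concrete indices are illustrative only: with two non-constant retvals and
// one modified resource variable at argument index 3,
//
//   computation root          = tuple(retval0, retval1, updated_variable)
//   *num_nonconst_outputs     = 2
//   *num_computation_outputs  = 3
//   resource_updates          = [{input_index: 3, modified: true, ...}]
//
// and, if argument 3 was mapped to XLA parameter p, an input/output alias from
// output index {2} to parameter p is registered via builder->SetUpAlias.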
429 
430 }  // namespace
431 
432 
433 string XlaCompiler::Argument::HumanString() const {
434   string common;
435   if (!name.empty()) {
436     common = absl::StrCat(" name=", name);
437   }
438   absl::StrAppend(&common, " type=", DataTypeString(type),
439                   " shape=", ShapeHumanString());
440   absl::StrAppend(
441       &common, " is_same_data_across_replicas=", is_same_data_across_replicas);
442   switch (kind) {
443     case kInvalid:
444       return "invalid";
445     case kConstant:
446       return absl::StrCat("kind=constant", common,
447                           " value=", constant_value.DebugString());
448     case kConstantResource:
449       return absl::StrCat("kind=constant-resource", common,
450                           " value=", constant_value.DebugString());
451     case kResource: {
452       string output = absl::StrCat(
453           "kind=resource", common,
454           " resource_kind=", XlaResource::KindToString(resource_kind),
455           " initialized=", initialized, " is_fast_mem=", fast_mem);
456       if (max_array_size >= 0) {
457         absl::StrAppend(&output, " max_array_size=", max_array_size);
458       }
459       if (!tensor_array_gradients.empty()) {
460         absl::StrAppend(&output, " tensor_array_gradients=",
461                         absl::StrJoin(tensor_array_gradients, ","));
462       }
463       return output;
464     }
465     case kParameter:
466       return absl::StrCat("kind=parameter", common);
467     case kTensorList:
468       return absl::StrCat("kind=tensorlist", common);
469     case kToken:
470       return absl::StrCat("token", common);
471   }
472 }
473 
474 std::vector<int64> XlaCompiler::Argument::DimensionSizes() const {
475   if (absl::holds_alternative<TensorShape>(shape)) {
476     return xla::InlinedVectorToVector(
477         absl::get<TensorShape>(shape).dim_sizes());
478   } else {
479     return xla::SpanToVector(absl::get<xla::Shape>(shape).dimensions());
480   }
481 }
482 
483 absl::InlinedVector<int64, 4>
484 XlaCompiler::Argument::DimensionSizesAsInlinedVector() const {
485   if (absl::holds_alternative<TensorShape>(shape)) {
486     return absl::get<TensorShape>(shape).dim_sizes();
487   } else {
488     auto v = absl::get<xla::Shape>(shape).dimensions();
489     return absl::InlinedVector<int64, 4>(v.begin(), v.end());
490   }
491 }
492 
493 string XlaCompiler::Argument::ShapeHumanString() const {
494   if (absl::holds_alternative<TensorShape>(shape)) {
495     return absl::get<TensorShape>(shape).DebugString();
496   } else {
497     return absl::get<xla::Shape>(shape).DebugString();
498   }
499 }
500 
501 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
502     : options_(options),
503       initialization_status_(Status::OK()),
504       next_step_id_(1),
505       device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
506       device_mgr_(absl::WrapUnique(device_)) {
507   CHECK(!options_.device_type.type_string().empty());
508   if (options_.populate_resource_manager) {
509     initialization_status_ =
510         (*options_.populate_resource_manager)(device_->resource_manager());
511   }
512 
513   local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
514                                                       FunctionDefLibrary{}));
515   local_pflr_.reset(new ProcessFunctionLibraryRuntime(
516       &device_mgr_, Env::Default(), /*config=*/nullptr,
517       options.graph_def_version, local_flib_def_.get(), OptimizerOptions()));
518   pflr_.reset(new ProcessFunctionLibraryRuntime(
519       &device_mgr_, Env::Default(), /*config=*/nullptr,
520       options.graph_def_version, options.flib_def, OptimizerOptions()));
521 
522   local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
523   flib_runtime_ = pflr_->GetFLR(device_->name());
524 
525   // The default shape representation function is the identity.
526   if (!options_.shape_representation_fn) {
527     options_.shape_representation_fn = IdentityShapeRepresentationFn();
528   }
529 }
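// A minimal construction sketch; the option values below are placeholders
// that depend on the caller's environment:
//
//   XlaCompiler::Options options;
//   options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
//   options.client = xla::ClientLibrary::LocalClientOrDie();
//   options.flib_def = &graph->flib_def();  // some FunctionLibraryDefinition
//   options.graph_def_version = TF_GRAPH_DEF_VERSION;
//   XlaCompiler compiler(options);
//
// Leaving shape_representation_fn unset selects the identity representation,
// as handled at the end of the constructor above.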
530 
531 XlaCompiler::~XlaCompiler() = default;
532 
533 int64 XlaCompiler::NextStepId() { return next_step_id_++; }
534 
535 uint64 XlaCompiler::SignatureHash::operator()(
536     const std::pair<string, std::vector<Argument>>& signature) const {
537   return std::hash<string>()(signature.first);
538 }
539 
540 static Status GetFunctionBody(const NameAttrList& function,
541                               FunctionLibraryRuntime* flib_runtime,
542                               const FunctionBody** fbody) {
543   FunctionLibraryRuntime::Handle handle;
544   TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
545       function.name(), AttrSlice(&function.attr()), &handle));
546 
547   *fbody = flib_runtime->GetFunctionBody(handle);
548   TF_RET_CHECK(*fbody);
549   return Status::OK();
550 }
551 
552 Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
553                                      const FunctionBody** fbody,
554                                      const ConfigProto** config_proto) {
555   // The function may be in either the local_flib_runtime_ or flib_runtime_.
556   // Look up the function in local first and if it is not found then look up the
557   // function in flib_runtime_.
558   auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
559   if (!status.ok()) {
560     if (!errors::IsNotFound(status)) {
561       return status;
562     }
563     TF_RETURN_WITH_CONTEXT_IF_ERROR(
564         GetFunctionBody(function, flib_runtime_, fbody),
565         "Local lookup failed with: ", status.error_message());
566     if (config_proto) {
567       *config_proto = flib_runtime_->config_proto();
568     }
569     VLOG(4) << "Function " << function.name() << " in flib_runtime_";
570   } else {
571     if (config_proto) {
572       *config_proto = local_flib_runtime_->config_proto();
573     }
574     VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
575   }
576   return Status::OK();
577 }
578 
579 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
580   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
581   CopyGraph(*fbody->graph, graph.get());
582 
583   bool is_inside_mustcompile = false;
584   TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
585                  &is_inside_mustcompile);
586 
587   // Performs a first function inlining pass before shape inference, since
588   // otherwise shape inference can't see inside functions and a comprehensive
589   // shape_map, including function ops, is needed to constant-propagate Shape
590   // Ops below.
591   auto flags = GetBuildXlaOpsPassFlags();
592   OptimizerOptions opts;
593   opts.set_opt_level(OptimizerOptions::L0);
594   opts.set_do_common_subexpression_elimination(false);
595   opts.set_do_function_inlining(true);
596   opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
597   GraphOptimizer optimizer(opts);
598   // Do not constant fold nodes that output DT_VARIANT type tensors.
599   // XLA does not support Const nodes of Variant type since it needs
600   // to know the original ops to be able to compile them to the relevant
601   // XLA form.
602   // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
603   // the form:
604   //                          Const
605   //                            |
606   // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
607   //                                                  |
608   //                                        (Discard popped list)
609   //
610   // Would have been reduced to "Const -> Op" without this filter.
611   // However since we are only allowed to specify the filter at the "Node"
612   // level there is no good way to allow the above behavior. So we
613   // disallow any sort of constant folding on Variant nodes for now.
614   //
615   // Also do not consider constant folding Shape ops. When a tensor has a
616   // dynamic dimension, TF2XLA currently represents it with its static
617   // upper-bound shape, which can be constant folded and thereby lose the
618   // information that this Shape is dynamic.
619   auto cf_consider_fn = [](const Node* n) {
620     for (const auto& output_arg : n->op_def().output_arg()) {
621       if (output_arg.type() == DT_VARIANT) {
622         return false;
623       }
624     }
625     const auto& ts = n->type_string();
626     // XLA has special logic to handle dynamic shapes, don't constant fold
627     // them.
628     if (ts == "Shape" || ts == "ShapeN" || ts == "Size") {
629       return false;
630     }
631     return true;
632   };
633   GraphOptimizer::Options graph_optimizer_options;
634   graph_optimizer_options.cf_consider_fn = cf_consider_fn;
635   graph_optimizer_options.inline_multi_device_functions = true;
636   graph_optimizer_options.inline_impl_selection_group_functions = true;
637   graph_optimizer_options.inline_with_single_device_body_placer = true;
638   graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
639 
640   {
641     GraphShapeInfo shape_info;
642     InferShapes(graph.get(), /*arg_shapes=*/{},
643                 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
644         .IgnoreError();
645     auto node_name_index = graph->BuildNodeNameIndex();
646     std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
647     for (const auto& node_shape_info : shape_info) {
648       const string& node_name = node_shape_info.first;
649       const std::vector<InferredShape>& output_shapes = node_shape_info.second;
650       const auto& node_iter = node_name_index.find(node_name);
651       if (node_iter != node_name_index.end()) {
652         auto& partial_shapes = shape_map[node_name];
653         for (const auto& inferred_shape : output_shapes) {
654           partial_shapes.push_back(inferred_shape.shape);
655         }
656       }
657     }
658     graph_optimizer_options.shape_map = &shape_map;
659     optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
660                        /*device=*/nullptr, &graph, graph_optimizer_options);
661   }
662 
663   // Run shape inference on the graph and optimize the graph again.
664   GraphShapeInfo shape_info;
665   InferShapes(graph.get(), /*arg_shapes=*/{},
666               flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
667       .IgnoreError();
668   auto node_name_index = graph->BuildNodeNameIndex();
669   std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
670   for (const auto& node_shape_info : shape_info) {
671     const string& node_name = node_shape_info.first;
672     const std::vector<InferredShape>& output_shapes = node_shape_info.second;
673     const auto& node_iter = node_name_index.find(node_name);
674     if (node_iter != node_name_index.end()) {
675       auto& partial_shapes = shape_map[node_name];
676       for (const auto& inferred_shape : output_shapes) {
677         partial_shapes.push_back(inferred_shape.shape);
678       }
679     }
680   }
681   graph_optimizer_options.shape_map = &shape_map;
682   optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
683                      /*device=*/nullptr, &graph, graph_optimizer_options);
684 
685   return graph;
686 }
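// To make the constant-folding filter above concrete (node names are
// hypothetical): a "Shape" node reading a tensor with a possibly-dynamic
// dimension is never folded, even if shape inference produced a fully static
// upper bound, and any node producing a DT_VARIANT output is likewise left
// alone; a plain "Add" of two constants is still folded by the GraphOptimizer
// passes above.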
687 
688 // Collects all control rets from `orig_control_ret_nodes` that are still valid,
689 // keeping the same order.
690 std::vector<std::string> GetValidControlRets(
691     absl::Span<Node* const> orig_control_ret_nodes, const Graph& graph) {
692   // Build map from control ret node to index.
693   absl::flat_hash_map<const Node*, int> control_ret_nodes_map;
694   for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
695     const Node* n = orig_control_ret_nodes[i];
696     control_ret_nodes_map[n] = i;
697   }
698   // Check which control rets are still valid.
699   std::vector<bool> is_valid_control_ret(orig_control_ret_nodes.size(), false);
700   int num_valid_control_rets = 0;
701   for (const Node* n : graph.nodes()) {
702     auto iter = control_ret_nodes_map.find(n);
703     if (iter != control_ret_nodes_map.end()) {
704       ++num_valid_control_rets;
705       is_valid_control_ret[iter->second] = true;
706     }
707   }
708   // Return valid control rets in same order as they appear in
709   // `orig_control_ret_nodes`.
710   std::vector<std::string> valid_control_rets;
711   valid_control_rets.reserve(num_valid_control_rets);
712   for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
713     if (is_valid_control_ret[i]) {
714       valid_control_rets.push_back(orig_control_ret_nodes[i]->name());
715     }
716   }
717   return valid_control_rets;
718 }
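// Example (node names hypothetical): if orig_control_ret_nodes refers to nodes
// {A, B, C} and graph rewriting removed B, the result is {"A", "C"}, i.e. the
// surviving control rets in their original order.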
719 
720 Status XlaCompiler::CompileFunction(
721     const XlaCompiler::CompileOptions& options,
722     const NameAttrList& fn_name_attrs,
723     absl::Span<const XlaCompiler::Argument> args,
724     XlaCompiler::CompilationResult* result) {
725   const string function_id =
726       Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr()));
727   VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
728 
729   const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
730   auto it = cache_.find({function_id, arg_vector});
731   if (it != cache_.end()) {
732     *result = it->second;
733     return Status::OK();
734   }
735 
736   const FunctionBody* fbody;
737   const ConfigProto* config = nullptr;
738   TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody, &config));
739 
740   absl::optional<ConfigProto> config_proto;
741   if (config) {
742     config_proto = *config;
743   }
744 
745   TF_RETURN_WITH_CONTEXT_IF_ERROR(
746       CheckSignature(fbody->arg_types, args),
747       "Signature check failure while compiling: ", fn_name_attrs.name());
748 
749   // Set shapes for _Arg nodes. They are useful for constant folding (e.g. an
750   // Xla op requires a compile-time constant input, and that input is the shape
751   // of an _Arg node).
752   for (int i = 0, end = args.size(); i < end; i++) {
753     // Skip resource variables and tensor lists.
754     DataType dtype;
755     TF_RETURN_IF_ERROR(GetNodeAttr(fbody->arg_nodes[i]->def(), "T", &dtype));
756     if (dtype == DT_RESOURCE || dtype == DT_VARIANT) {
757       continue;
758     }
759 
760     if (absl::holds_alternative<xla::Shape>(args[i].shape)) {
761       xla::Shape xla_shape = absl::get<xla::Shape>(args[i].shape);
762       TensorShape tensor_shape;
763       // If xla_shape is dynamic, prevent constant folding by not setting
764       // output_shapes.
765       if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok() &&
766           xla_shape.is_static()) {
767         fbody->arg_nodes[i]->ClearAttr("_output_shapes");
768         fbody->arg_nodes[i]->AddAttr("_output_shapes",
769                                      std::vector<TensorShape>{tensor_shape});
770       }
771     } else {
772       TensorShape tensor_shape = absl::get<TensorShape>(args[i].shape);
773       fbody->arg_nodes[i]->ClearAttr("_output_shapes");
774       fbody->arg_nodes[i]->AddAttr("_output_shapes",
775                                    std::vector<TensorShape>{tensor_shape});
776     }
777   }
778 
779   std::unique_ptr<Graph> graph = GetGraph(fbody);
780 
781   // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
782   // they are added when the function body is looked up.  Therefore, they don't
783   // have core assignments here.
784   // Attempt to assign a core to each _Retval and _Arg. Chooses the
785   // lowest-numbered core that consumes the argument. We choose the
786   // lowest-numbered core so the assignment is deterministic.
787   for (Node* n : graph->nodes()) {
788     if (n->IsArg()) {
789       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
790     }
791   }
792   // Do _Retval as a second loop, in case the retval's input is an _Arg (which
793   // may have gotten a device assignment from the first loop).
794   for (Node* n : graph->nodes()) {
795     if (n->IsRetval()) {
796       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
797     }
798   }
799 
800   if (VLOG_IS_ON(2)) {
801     VLOG(2) << "XlaCompiler::CompileFunction: "
802             << DumpGraphToFile(
803                    absl::StrCat("xla_compile_function_", function_id), *graph);
804   }
805 
806   VLOG(1) << "====================================================";
807   MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
808       *graph, config_proto,
809       /*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
810   if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
811     VLOG(1) << "Using MLIR bridge";
812     GraphDebugInfo debug_info;
813 
814     std::vector<std::string> valid_control_rets =
815         GetValidControlRets(fbody->control_ret_nodes, *graph);
816 
817     TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
818         std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
819         valid_control_rets, options_.device_type.type_string(),
820         options.use_tuple_arg, *options_.flib_def, debug_info,
821         options_.shape_representation_fn, result));
822   } else {
823     TF_RETURN_IF_ERROR(
824         CompileGraph(options, function_id, std::move(graph), args, result));
825   }
826   VLOG(1) << "====================================================";
827 
828   cache_[{function_id, arg_vector}] = *result;
829   return Status::OK();
830 }
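// Usage sketch (names below are hypothetical): compiling the same function
// twice with an identical Argument vector hits the signature cache above, so
// only the first call pays for graph optimization and XLA lowering.
//
//   NameAttrList fn;
//   fn.set_name("my_func");
//   XlaCompiler::CompilationResult r1, r2;
//   TF_RETURN_IF_ERROR(compiler.CompileFunction(compile_opts, fn, args, &r1));
//   TF_RETURN_IF_ERROR(compiler.CompileFunction(compile_opts, fn, args, &r2));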
831 
832 // Computes the XLA shape for argument 'arg'.
833 Status XlaCompiler::XLAShapeForArgument(
834     const XlaCompiler::Argument& arg, bool is_entry_computation,
835     const absl::optional<xla::HloSharding>& arg_sharding,
836     xla::Shape* xla_shape) const {
837   switch (arg.kind) {
838     case XlaCompiler::Argument::kConstant:
839       LOG(FATAL) << "Unreachable case";
840     case XlaCompiler::Argument::kParameter: {
841       if (is_entry_computation) {
842         TensorShape shape;
843         if (absl::holds_alternative<TensorShape>(arg.shape)) {
844           shape = absl::get<TensorShape>(arg.shape);
845         } else {
846           TF_RETURN_IF_ERROR(
847               XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &shape));
848         }
849         TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
850                                             shape, arg.type,
851                                             /*use_fast_memory=*/false));
852         TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
853             arg_sharding, /*use_fast_memory=*/false,
854             options_.shape_representation_fn, xla_shape));
855       } else {
856         if (absl::holds_alternative<xla::Shape>(arg.shape)) {
857           *xla_shape = absl::get<xla::Shape>(arg.shape);
858         } else {
859           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
860               arg.type, absl::get<TensorShape>(arg.shape), xla_shape));
861         }
862       }
863       return Status::OK();
864     }
865     case XlaCompiler::Argument::kTensorList: {
866       TF_RET_CHECK(absl::holds_alternative<xla::Shape>(arg.shape));
867       *xla_shape = absl::get<xla::Shape>(arg.shape);
868       return Status::OK();
869     }
870     case XlaCompiler::Argument::kConstantResource:
871     case XlaCompiler::Argument::kResource: {
872       TF_RET_CHECK(arg.initialized);
873 
874       switch (arg.resource_kind) {
875         case XlaResource::kVariable: {
876           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
877           TF_ASSIGN_OR_RETURN(*xla_shape,
878                               options_.shape_representation_fn(
879                                   absl::get<TensorShape>(arg.shape), arg.type,
880                                   /*use_fast_memory=*/arg.fast_mem));
881           TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
882               arg_sharding, arg.fast_mem, options_.shape_representation_fn,
883               xla_shape));
884           return Status::OK();
885         }
886         case XlaResource::kTensorArray: {
887           if (arg.max_array_size < 0) {
888             return errors::InvalidArgument(
889                 "Negative max_array_size in XLAShapeForArgument");
890           }
891           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
892           TensorShape shape;
893           shape.AddDim(arg.max_array_size);
894           shape.AppendShape(absl::get<TensorShape>(arg.shape));
895           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
896 
897           if (!arg.tensor_array_gradients.empty()) {
898             std::vector<xla::Shape> tuple_shape(
899                 arg.tensor_array_gradients.size() + 1, *xla_shape);
900             *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
901           }
902           return Status::OK();
903         }
904         case XlaResource::kStack: {
905           if (arg.max_array_size < 0) {
906             return errors::InvalidArgument(
907                 "Negative max_array_size in XLAShapeForArgument");
908           }
909           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
910           TensorShape shape;
911           shape.AddDim(arg.max_array_size);
912           shape.AppendShape(absl::get<TensorShape>(arg.shape));
913           xla::Shape buffer_shape;
914           TF_RETURN_IF_ERROR(
915               TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
916           *xla_shape = xla::ShapeUtil::MakeTupleShape(
917               {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
918           return Status::OK();
919         }
920 
921         case XlaResource::kInvalid:
922           return errors::Internal(
923               "Invalid resource type in XLAShapeForArgument()");
924       }
925     }
926     case XlaCompiler::Argument::kToken: {
927       *xla_shape = xla::ShapeUtil::MakeTokenShape();
928       return Status::OK();
929     }
930     case XlaCompiler::Argument::kInvalid:
931       return errors::Internal("Invalid argument type in XLAShapeForArgument()");
932   }
933 }
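// Worked example for the kTensorArray case above (values are illustrative): a
// resource argument with type=DT_FLOAT, shape=[3, 2], max_array_size=10 and
// one registered gradient first produces the buffer shape f32[10,3,2]; because
// a gradient is present, *xla_shape becomes the tuple (f32[10,3,2],
// f32[10,3,2]).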
934 
935 /* static */
936 void XlaCompiler::PopulateArgumentFromResource(const XlaResource& resource,
937                                                Argument* arg) {
938   arg->initialized = resource.initialized();
939   arg->kind = XlaCompiler::Argument::kResource;
940   arg->resource_kind = resource.kind();
941 
942   arg->type = resource.type();
943   arg->shape = resource.shape();
944   arg->max_array_size = resource.max_array_size();
945   for (const auto& gradient : resource.tensor_array_gradients()) {
946     arg->tensor_array_gradients.insert(gradient.first);
947   }
948   arg->name = resource.name();
949 }
950 
951 // Builds XLA computations for each of the arguments to the computation.
952 // `args` are the arguments to the computation.
953 Status XlaCompiler::BuildArguments(
954     const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
955     bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
956     const std::map<int, xla::OpSharding>& arg_shardings,
957     std::vector<XlaExpression>* arg_expressions,
958     std::vector<int>* input_to_args, std::vector<xla::Shape>* input_shapes,
959     bool is_entry_computation) {
960   arg_expressions->resize(args.size());
961 
962   // Argument numbers of arguments and resources that are to be passed to the
963   // XLA computation as runtime parameters. `input_to_args[a] = b` means that
964   // the a'th XLA input corresponds to the b'th original argument index.
965   input_to_args->clear();
966   input_to_args->reserve(args.size());
967 
968   // Fills in constant arguments, and computes non-constant argument order.
969   for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
970        ++i) {
971     const XlaCompiler::Argument& arg = args[i];
972     XlaExpression& arg_expression = (*arg_expressions)[i];
973     switch (arg.kind) {
974       case XlaCompiler::Argument::kConstantResource:
975       case XlaCompiler::Argument::kResource: {
976         TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
977         TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
978         // TODO(phawkins): this code assumes that resource arguments do not
979         // alias.
980         XlaResource* resource =
981             context->AddResource(absl::make_unique<XlaResource>(
982                 arg.resource_kind, i, arg.name, arg.type,
983                 absl::get<TensorShape>(arg.shape), xla::XlaOp(),
984                 /*max_array_size=*/arg.max_array_size,
985                 /*tensor_array_gradients=*/arg.tensor_array_gradients,
986                 /*tensor_array_multiple_writes_aggregate=*/true));
987         arg_expression =
988             arg.kind == XlaCompiler::Argument::kResource
989                 ? XlaExpression::Resource(resource)
990                 : XlaExpression::ConstantResource(arg.constant_value, resource);
991         if (arg.initialized) {
992           input_to_args->push_back(i);
993         }
994         break;
995       }
996       case XlaCompiler::Argument::kParameter:
997       case XlaCompiler::Argument::kTensorList:
998       case XlaCompiler::Argument::kToken: {
999         input_to_args->push_back(i);
1000         break;
1001       }
1002       case XlaCompiler::Argument::kConstant:
1003         arg_expression = XlaExpression::Constant(arg.constant_value);
1004         break;
1005       case XlaCompiler::Argument::kInvalid:
1006         return errors::Internal(
1007             "Unreachable case in BuildArguments() while filling constant args");
1008     }
1009   }
1010 
1011   if (input_to_args->empty() && !use_tuple_arg) {
1012     return Status::OK();
1013   }
1014 
1015   // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds
1016   // to the d'th XLA input. Note that the value -1 corresponds to constants, or
1017   // other args that don't correspond to an input.
1018   std::vector<int> arg_to_inputs(args.size(), -1);
1019   for (int i = 0, end = input_to_args->size(); i < end; i++) {
1020     arg_to_inputs[input_to_args->at(i)] = i;
1021   }
1022 
1023   std::vector<xla::Shape> arg_shapes(input_to_args->size());
1024   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1025     // Computes the shapes of non-constant arguments.
1026     auto arg_sharding = arg_shardings.find((*input_to_args)[i]);
1027     absl::optional<xla::HloSharding> sharding;
1028     if (arg_sharding != arg_shardings.end()) {
1029       TF_ASSIGN_OR_RETURN(auto hlo_sharding,
1030                           xla::HloSharding::FromProto(arg_sharding->second));
1031       sharding = hlo_sharding;
1032     }
1033     TF_RETURN_IF_ERROR(XLAShapeForArgument(args[(*input_to_args)[i]],
1034                                            is_entry_computation, sharding,
1035                                            &arg_shapes[i]));
1036   }
1037 
1038   if (use_tuple_arg) {
1039     input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
1040   } else {
1041     *input_shapes = arg_shapes;
1042   }
1043 
1044   // Attach a common operator name as metadata. This has no semantic effect — it
1045   // merely makes the HLO graph more readable when visualized via TensorBoard,
1046   // since TensorBoard forms groups out of operators with similar names.
1047   xla::OpMetadata arg_metadata;
1048   arg_metadata.set_op_name("XLA_Args");
1049   builder->SetOpMetadata(arg_metadata);
1050 
1051   // Build parameter handles for non-constant arguments.
1052   std::vector<xla::XlaOp> arg_handles(input_to_args->size());
1053   if (use_tuple_arg) {
1054     xla::XlaOp tuple;
1055     if (is_entry_computation) {
1056       xla::OpSharding tuple_sharding;
1057       tuple_sharding.set_type(xla::OpSharding::TUPLE);
1058       for (int64 parameter : *input_to_args) {
1059         auto it = arg_shardings.find(parameter);
1060         *tuple_sharding.add_tuple_shardings() =
1061             it == arg_shardings.end() ? xla::sharding_builder::AssignDevice(0)
1062                                       : it->second;
1063       }
1064       std::vector<bool> is_same_across_replicas;
1065       for (int i = 0, end = input_to_args->size(); i < end; ++i) {
1066         // Add an entry to is_same_across_replicas for every leaf buffer.
1067         is_same_across_replicas.insert(
1068             is_same_across_replicas.end(),
1069             xla::ShapeUtil::GetLeafCount(arg_shapes[i]),
1070             args[input_to_args->at(i)].is_same_data_across_replicas);
1071       }
1072       xla::XlaScopedShardingAssignment assign_tuple_sharding(
1073           builder, input_to_args->empty() ? absl::optional<xla::OpSharding>()
1074                                           : tuple_sharding);
1075       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple",
1076                              is_same_across_replicas);
1077     } else {
1078       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
1079     }
1080 
1081     for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1082       auto it = arg_shardings.find(i);
1083       xla::XlaScopedShardingAssignment assign_sharding(
1084           builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
1085                                              : it->second);
1086       auto& arg = args[input_to_args->at(i)];
1087 
1088       xla::OpMetadata arg_metadata;
1089       arg_metadata.set_op_name(arg.node_name);
1090       builder->SetOneShotOpMetadata(arg_metadata);
1091       arg_handles[i] = xla::GetTupleElement(tuple, i);
1092     }
1093   } else {
1094     for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1095       auto it = arg_shardings.find(i);
1096       xla::XlaScopedShardingAssignment assign_sharding(
1097           builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
1098                                              : it->second);
1099       if (is_entry_computation) {
1100         // Add an entry to is_same_across_replicas for every leaf buffer.
1101         std::vector<bool> is_same_across_replicas(
1102             xla::ShapeUtil::GetLeafCount((*input_shapes)[i]),
1103             args[input_to_args->at(i)].is_same_data_across_replicas);
1104         arg_handles[i] =
1105             xla::Parameter(builder, i, (*input_shapes)[i],
1106                            absl::StrCat("arg", i), is_same_across_replicas);
1107       } else {
1108         arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
1109                                         absl::StrCat("arg", i));
1110       }
1111     }
1112   }
1113 
1114   for (int i = 0, end = input_to_args->size(); i < end; ++i) {
1115     const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
1116     for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
1117       int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
1118       VLOG(1) << "Setting dynamic size " << i << " -> "
1119               << dynamic_size_param_index;
1120       arg_handles[i] = xla::SetDimensionSize(
1121           arg_handles[i], arg_handles[dynamic_size_param_index],
1122           dim_and_arg_num.first);
1123     }
1124   }
1125 
1126   builder->ClearOpMetadata();
1127 
1128   // Fill in the handles in non-constant arguments, and reshape parameters
1129   // back to their correct shapes.
1130   VLOG(2) << "XLA computation inputs:";
1131   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
1132     const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
1133     VLOG(2) << "  XLA arg " << i
1134             << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
1135             << " name: " << arg.name << " TF arg " << input_to_args->at(i)
1136             << " node name: " << arg.node_name
1137             << (arg_shardings.find(i) == arg_shardings.end()
1138                     ? ""
1139                     : absl::StrCat(" sharding: ",
1140                                    arg_shardings.at(i).DebugString()));
1141     XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
1142     switch (arg.kind) {
1143       case XlaCompiler::Argument::kConstantResource:
1144       case XlaCompiler::Argument::kResource: {
1145         TF_RET_CHECK(arg.initialized);
1146         XlaResource* resource = arg_expression.resource();
1147         TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
1148                                                  arg_handles[i], builder));
1149         VLOG(2) << "    resource: num_gradients: "
1150                 << arg.tensor_array_gradients.size();
1151         break;
1152       }
1153       case XlaCompiler::Argument::kParameter:
1154         // Reshape parameters back to their correct shapes.
1155         // TODO(b/76097077): propagate device assignments onto arguments and
1156         // return values of functions, and then reshape unconditionally.
1157         if (is_entry_computation) {
1158           arg_expression = XlaExpression::XlaOp(
1159               xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
1160         } else {
1161           arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1162           if (arg.value_bound) {
1163             // Propagate upper bound to arg_expression.
1164             arg_expression.set_value_bound(arg.value_bound.value());
1165           }
1166         }
1167         break;
1168       case XlaCompiler::Argument::kTensorList: {
1169         arg_expression = XlaExpression::TensorList(arg_handles[i]);
1170         break;
1171       }
1172       case XlaCompiler::Argument::kToken: {
1173         arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
1174         break;
1175       }
1176       case XlaCompiler::Argument::kConstant:
1177       case XlaCompiler::Argument::kInvalid:
1178         return errors::Internal(
1179             "Unreachable case in BuildArguments() while filling handles");
1180     }
1181   }
1182 
1183   return Status::OK();
1184 }
1185 
1186 namespace {
1187 
1188 // Checks that the ops of all non-function-call nodes in `fdef` are registered.
1189 Status ValidateFunctionDef(const FunctionDef* fdef,
1190                            const FunctionLibraryDefinition& flib_def) {
1191   for (const NodeDef& node : fdef->node_def()) {
1192     const string& op = node.op();
1193     if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
1194       continue;
1195     }
1196     const OpDef* op_def;
1197     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
1198   }
1199   return Status::OK();
1200 }
1201 
1202 // If `node` is a PartitionedCall or StatefulPartitionedCall, returns the
1203 // function name from its "f" attr; otherwise returns node.type_string().
1204 // The returned pointer points to the internal string either in the node's
1205 // attributes or in its NodeDef, and is valid only as long as the node has not
1206 // been modified.
1207 Status GetPotentialFunctionName(const Node& node, const string** name) {
1208   if (node.IsPartitionedCall()) {
1209     const AttrValue* attr_value;
1210     TF_RETURN_IF_ERROR(
1211         node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
1212     if (!attr_value->has_func()) {
1213       return errors::InvalidArgument(
1214           "The attribute value for attribute 'f' in node ", node.DebugString(),
1215           " does not have 'func' field set");
1216     }
1217     *name = &attr_value->func().name();
1218     return Status::OK();
1219   }
1220   *name = &node.type_string();
1221   return Status::OK();
1222 }
1223 
1224 // Checks that the graph doesn't have any invalid nodes (e.g. incompatible with
1225 // the given device_type, invalid data types, or missing attributes).
1226 Status ValidateGraph(const Graph* graph,
1227                      const FunctionLibraryDefinition& flib_def,
1228                      const DeviceType& device_type, const string& name) {
1229   // Make sure the XLA compilation kernels are registered.  This operation is
1230   // idempotent so it is fine if someone called it already.
1231   XlaOpRegistry::RegisterCompilationKernels();
1232 
1233   auto maybe_error = [&](const Node* node, const Status& s) -> Status {
1234     if (!s.ok()) {
1235       return errors::InvalidArgument(absl::StrCat(
1236           "Detected unsupported operations when trying to compile graph ", name,
1237           " on ", device_type.type_string(), ": ", node->def().op(), " (",
1238           s.error_message(), ")", FormatNodeForError(*node),
1239           ". One approach is to outside-compile the unsupported ops to run "
1240           "on CPUs by enabling soft placement "
1241           "`tf.config.set_soft_device_placement(True)`."
1242           " This has a potential performance penalty."));
1243     }
1244     return Status::OK();
1245   };
1246 
1247   for (const Node* node : graph->nodes()) {
1248     if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
1249       continue;
1250     }
1251     const string* function_name;
1252     TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
1253     const FunctionDef* fdef = flib_def.Find(*function_name);
1254     Status s;
1255     if (fdef) {
1256       s = ValidateFunctionDef(fdef, flib_def);
1257       TF_RETURN_IF_ERROR(maybe_error(node, s));
1258       continue;
1259     }
1260     const OpDef* op_def;
1261     s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
1262     TF_RETURN_IF_ERROR(maybe_error(node, s));
1263     TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
1264     s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
1265     TF_RETURN_IF_ERROR(maybe_error(node, s));
1266   }
1267   return Status::OK();
1268 }
1269 
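// Materializes each kConstant expression in `expressions` as an XLA op on
// `builder`, replacing it with an equivalent kXlaOp expression.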
1270 void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
1271                                    absl::Span<XlaExpression> expressions) {
1272   for (XlaExpression& expression : expressions) {
1273     if (expression.kind() == XlaExpression::Kind::kConstant) {
1274       expression =
1275           XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
1276     }
1277   }
1278 }
1279 
1280 }  // namespace
1281 
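// Typical call sequence, as a minimal sketch (the `options`, `compile_options`,
// `graph`, and `args` values are hypothetical and built by the caller):
//
//   XlaCompiler compiler(options);
//   XlaCompiler::CompilationResult result;
//   TF_RETURN_IF_ERROR(compiler.CompileGraph(compile_options, "my_graph",
//                                            std::move(graph), args, &result));
//
// On success, result.computation holds the xla::XlaComputation, and
// result.xla_input_shapes / result.xla_output_shape describe its interface.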
1282 Status XlaCompiler::CompileGraph(
1283     const XlaCompiler::CompileOptions& options, string const& name,
1284     std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
1285     CompilationResult* result) {
1286   VLOG(1) << "Executing graph symbolically to populate XlaBuilder: " << name;
1287 
1288   TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
1289       graph.get(), options_.flib_def, local_flib_def_.get()));
1290   TF_RETURN_IF_ERROR(RearrangeFunctionArguments(
1291       [this](const NameAttrList& function, const FunctionBody** fbody) {
1292         return FindFunctionBody(function, fbody);
1293       },
1294       graph.get(), local_flib_def_.get(),
1295       pflr_->GetFunctionLibraryDefinition()));
1296 
1297   if (VLOG_IS_ON(2)) {
1298     VLOG(2) << "XlaCompiler::CompileGraph: "
1299             << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
1300                                flib_runtime_->GetFunctionLibraryDefinition());
1301   }
1302 
1303   // Report the error here if initialization failed.
1304   TF_RETURN_IF_ERROR(initialization_status_);
1305 
1306   // Detect invalid nodes.
1307   // FunctionalizeControlFlow may remove some nodes from the graph.
1308   TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
1309                                    options_.device_type, name));
1310 
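  // Build an XlaBuilder for this computation and an XlaContext through which
  // the ops' XLA kernels will record HLO during the symbolic execution below.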
1311   xla::XlaBuilder builder(name);
1312   XlaContext* context = new XlaContext(this, &builder, graph.get());
1313   core::ScopedUnref context_unref(context);
1314 
1315   std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
1316   int token_input_index = -1;
1317   std::unique_ptr<xla::XlaOp> token_output;
1318   if (options.add_token_input_output) {
1319     // Add extra token input.
1320     token_input_index = real_args.size();
1321 
1322     XlaCompiler::Argument token_arg;
1323     token_arg.kind = XlaCompiler::Argument::kToken;
1324     real_args.push_back(token_arg);
1325   }
1326 
1327   std::map<int, xla::OpSharding> arg_shardings;
1328   std::map<int, xla::OpSharding> retval_shardings;
1329   TF_ASSIGN_OR_RETURN(std::tie(arg_shardings, retval_shardings),
1330                       ComputeArgAndRetvalShardings(*graph));
1331 
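  // Lower the arguments into XLA parameters/handles; BuildArguments also fills
  // in result->input_mapping and result->xla_input_shapes.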
1332   std::vector<XlaExpression> arg_expressions;
1333   TF_RETURN_IF_ERROR(BuildArguments(
1334       *graph, real_args, options.use_tuple_arg, &builder, context,
1335       arg_shardings, &arg_expressions, &result->input_mapping,
1336       &result->xla_input_shapes, options.is_entry_computation));
1337   context->set_args(std::move(arg_expressions));
1338 
1339   PushNodeTokenMapping();
1340   // Use std::set instead of std::unordered_set to ensure determinism.
1341   std::set<std::string> output_node_token_inputs;
1342   if (token_input_index != -1) {
1343     // Original token comes from input.
1344     auto arg_expression = context->args()[token_input_index];
1345     TF_RETURN_IF_ERROR(
1346         SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
1347 
1348     // Calculate token inputs for output token.
1349     output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
1350 
1351     // If there's no side-effecting op in the graph, use token input as token
1352     // output.
1353     if (output_node_token_inputs.empty()) {
1354       output_node_token_inputs.insert(kXlaTokenArgNodeName);
1355     }
1356   } else if (options.is_entry_computation) {
1357     // Original token is manually created.
1358     if (HasSideEffectingNodes(*graph)) {
1359       TF_RETURN_IF_ERROR(
1360           SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
1361     }
1362   }
1363 
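  // Symbolically execute the graph on the XLA compilation device: each node's
  // kernel runs and records its output expressions into the XlaContext.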
1364   TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
1365                                   flib_runtime_, NextStepId()));
1366   if (token_input_index != -1) {
1367     // Add extra token output.
1368     std::vector<xla::XlaOp> token_inputs;
1369     for (const auto& node_name : output_node_token_inputs) {
1370       auto token_or = GetNodeToken(node_name);
1371       TF_RETURN_IF_ERROR(token_or.status());
1372       token_inputs.push_back(token_or.ValueOrDie());
1373     }
1374     token_output = absl::make_unique<xla::XlaOp>(xla::AfterAll(&builder, token_inputs));
1375   }
1376   TF_RETURN_IF_ERROR(PopNodeTokenMapping());
1377 
1378   int num_nonconst_outputs;
1379   int num_computation_outputs;
1380   result->computation = std::make_shared<xla::XlaComputation>();
1381   result->outputs.resize(context->retvals().size());
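  // Copy the retvals and materialize any compile-time constants among them as
  // XLA ops, so every return value reaches BuildComputation as a regular
  // expression rather than a kConstant.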
1382   std::vector<XlaExpression> retvals = context->retvals();
1383   ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
1384   TF_RETURN_IF_ERROR(BuildComputation(
1385       real_args, retvals, arg_shardings, retval_shardings, context->resources(),
1386       std::move(token_output),
1387       options.is_entry_computation ? options_.shape_representation_fn
1388                                    : ShapeRepresentationFn{},
1389       options.is_entry_computation,
1390       options.return_updated_values_for_all_resources,
1391       options.always_return_tuple, options.use_tuple_arg,
1392       options.alias_resource_update, &builder, result->computation.get(),
1393       &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
1394       &result->resource_updates, &result->xla_output_shape,
1395       result->input_mapping));
1396 
1397   VLOG(2) << "Outputs: total: " << context->retvals().size()
1398           << " nonconstant: " << num_nonconst_outputs;
1399   VLOG(2) << "XLA output shape: "
1400           << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape);
1401   return Status::OK();
1402 }
1403 
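// Channel handles are created lazily and memoized per key: the first call for
// a given key asks the client for a new handle, and later calls return the
// cached one. For example (the key name here is hypothetical):
//
//   xla::ChannelHandle channel;
//   TF_RETURN_IF_ERROR(compiler.GetChannelHandle("host_compute_0", &channel));
//
// The host-to-device and device-to-host variants below follow the same pattern
// with their respective client factory methods.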
1404 Status XlaCompiler::GetChannelHandle(const string& key,
1405                                      xla::ChannelHandle* channel) {
1406   auto result = channels_.emplace(key, xla::ChannelHandle());
1407   if (result.second) {
1408     TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
1409   }
1410   *channel = result.first->second;
1411   VLOG(1) << "Channel: " << key << " " << channel->DebugString();
1412   return Status::OK();
1413 }
1414 
1415 Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
1416                                                  xla::ChannelHandle* channel) {
1417   auto result = channels_.emplace(key, xla::ChannelHandle());
1418   if (result.second) {
1419     TF_ASSIGN_OR_RETURN(result.first->second,
1420                         client()->CreateHostToDeviceChannelHandle());
1421   }
1422   *channel = result.first->second;
1423   VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
1424   return Status::OK();
1425 }
1426 
1427 Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
1428                                                  xla::ChannelHandle* channel) {
1429   auto result = channels_.emplace(key, xla::ChannelHandle());
1430   if (result.second) {
1431     TF_ASSIGN_OR_RETURN(result.first->second,
1432                         client()->CreateDeviceToHostChannelHandle());
1433   }
1434   *channel = result.first->second;
1435   VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
1436   return Status::OK();
1437 }
1438 
1439 namespace {
1440 
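// Fills `transfer` with `key` and one TensorMetadata entry per (type, shape)
// pair; `types` and `shapes` must have the same length.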
1441 void SetTransfer(const string& key, absl::Span<const DataType> types,
1442                  absl::Span<const TensorShape> shapes,
1443                  tf2xla::HostTransferMetadata* transfer) {
1444   transfer->set_key(key);
1445   CHECK_EQ(types.size(), shapes.size());
1446   for (int i = 0, end = types.size(); i < end; ++i) {
1447     tf2xla::TensorMetadata* metadata = transfer->add_metadata();
1448     metadata->set_type(types[i]);
1449     shapes[i].AsProto(metadata->mutable_shape());
1450   }
1451 }
1452 
1453 }  // namespace
1454 
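// Registering metadata for a key is idempotent: repeating a call with
// identical types and shapes is a no-op, while re-registering a key with
// different metadata is rejected as an error. SetHostToDeviceMetadata below
// applies the same policy to the receive side.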
1455 Status XlaCompiler::SetDeviceToHostMetadata(
1456     const string& key, absl::Span<const DataType> types,
1457     absl::Span<const TensorShape> shapes) {
1458   if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
1459     tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key];
1460     tf2xla::HostTransferMetadata new_transfer;
1461     SetTransfer(key, types, shapes, &new_transfer);
1462     if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
1463       return Status::OK();
1464     } else {
1465       return errors::InvalidArgument(
1466           "Duplicate calls to SetDeviceToHostMetadata with key ", key);
1467     }
1468   }
1469   tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
1470   SetTransfer(key, types, shapes, &transfer);
1471   return Status::OK();
1472 }
1473 
1474 Status XlaCompiler::GetDeviceToHostShapes(
1475     const string& key, std::vector<TensorShape>* shapes) const {
1476   const auto iter = host_compute_sends_.find(key);
1477   if (iter == host_compute_sends_.end()) {
1478     return errors::InvalidArgument(
1479         "No host compute send shapes registered for key ", key);
1480   }
1481   shapes->clear();
1482   for (int i = 0; i < iter->second.metadata_size(); ++i) {
1483     TensorShape shape(iter->second.metadata(i).shape());
1484     shapes->push_back(shape);
1485   }
1486   return Status::OK();
1487 }
1488 
1489 Status XlaCompiler::SetHostToDeviceMetadata(
1490     const string& key, absl::Span<const DataType> types,
1491     absl::Span<const TensorShape> shapes) {
1492   if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
1493     tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key];
1494     tf2xla::HostTransferMetadata new_transfer;
1495     SetTransfer(key, types, shapes, &new_transfer);
1496     if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
1497       return Status::OK();
1498     } else {
1499       return errors::InvalidArgument(
1500           "Duplicate calls to SetHostToDeviceMetadata with key ", key);
1501     }
1502   }
1503   tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
1504   SetTransfer(key, types, shapes, &transfer);
1505   return Status::OK();
1506 }
1507 
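// The control-dependency handles below record, per host-compute op name, the
// XlaOp that other ops use to take a control dependency on that op:
// Set... registers the handle and Get... looks it up.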
1508 Status XlaCompiler::GetHostComputeControlDependency(
1509     const string& host_compute_name, xla::XlaOp* handle) {
1510   const auto iter = host_compute_control_output_.find(host_compute_name);
1511   if (iter == host_compute_control_output_.end()) {
1512     return errors::InvalidArgument(
1513         "No registered control handle for host compute Op '", host_compute_name,
1514         "'");
1515   } else {
1516     *handle = iter->second;
1517   }
1518   return Status::OK();
1519 }
1520 
1521 Status XlaCompiler::SetHostComputeControlDependency(
1522     const string& host_compute_name, const xla::XlaOp& handle) {
1523   if (host_compute_control_output_.find(host_compute_name) !=
1524       host_compute_control_output_.end()) {
1525     return errors::InvalidArgument(
1526         "Duplicate control handles registered for host compute Op ",
1527         host_compute_name);
1528   }
1529   host_compute_control_output_[host_compute_name] = handle;
1530   return Status::OK();
1531 }
1532 
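// Node-token mappings are kept as a stack of per-computation maps:
// PushNodeTokenMapping opens a fresh scope (e.g. around the symbolic execution
// in CompileGraph), SetNodeToken/GetNodeToken act on the innermost scope, and
// PopNodeTokenMapping closes it.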
1533 void XlaCompiler::PushNodeTokenMapping() {
1534   node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
1535 }
1536 
1537 Status XlaCompiler::PopNodeTokenMapping() {
1538   if (node_token_mapping_stack_.empty()) {
1539     return errors::FailedPrecondition(
1540         "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
1541         "empty.");
1542   }
1543   node_token_mapping_stack_.pop();
1544   return Status::OK();
1545 }
1546 
1547 Status XlaCompiler::SetNodeToken(const string& node_name,
1548                                  const xla::XlaOp& op) {
1549   if (node_token_mapping_stack_.empty()) {
1550     return errors::FailedPrecondition(
1551         "Calling SetNodeToken() when node_token_mapping_stack_ is "
1552         "empty.");
1553   }
1554   auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
1555   if (!insert_result.second) {
1556     return errors::FailedPrecondition("Token mapping already exists for node ",
1557                                       node_name);
1558   }
1559   return Status::OK();
1560 }
1561 
1562 xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
1563   if (node_token_mapping_stack_.empty()) {
1564     return errors::FailedPrecondition(
1565         "Calling GetNodeToken() when node_token_mapping_stack_ is "
1566         "empty.");
1567   }
1568   auto iter = node_token_mapping_stack_.top().find(node_name);
1569   if (iter == node_token_mapping_stack_.top().end()) {
1570     return errors::FailedPrecondition("Cannot find token mapping for node ",
1571                                       node_name);
1572   }
1573   return iter->second;
1574 }
1575 
1576 }  // namespace tensorflow
1577