1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17 
18 #include <numeric>
19 #include <vector>
20 
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/sharding_util.h"
25 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
27 #include "tensorflow/compiler/tf2xla/type_util.h"
28 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
29 #include "tensorflow/compiler/tf2xla/xla_context.h"
30 #include "tensorflow/compiler/xla/client/client_library.h"
31 #include "tensorflow/compiler/xla/client/xla_builder.h"
32 #include "tensorflow/compiler/xla/client/xla_computation.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/common_runtime/device.h"
35 #include "tensorflow/core/common_runtime/executor.h"
36 #include "tensorflow/core/common_runtime/function.h"
37 #include "tensorflow/core/common_runtime/graph_optimizer.h"
38 #include "tensorflow/core/framework/attr_value_util.h"
39 #include "tensorflow/core/framework/function.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/types.h"
42 #include "tensorflow/core/graph/algorithm.h"
43 #include "tensorflow/core/graph/graph_constructor.h"
44 #include "tensorflow/core/graph/node_builder.h"
45 #include "tensorflow/core/lib/core/error_codes.pb.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/gtl/cleanup.h"
48 #include "tensorflow/core/lib/hash/hash.h"
49 #include "tensorflow/core/platform/logging.h"
50 #include "tensorflow/core/util/dump_graph.h"
51 
52 namespace tensorflow {
53 namespace {
54 
55 // Checks that arguments `args` match types `types`.
56 Status CheckSignature(const DataTypeVector& types,
57                       absl::Span<const XlaCompiler::Argument> args) {
58   if (args.size() != types.size()) {
59     return errors::Internal("Compilation arguments have ", args.size(),
60                             " elements while function has ", types.size());
61   }
62   for (int i = 0; i < types.size(); ++i) {
63     // Don't perform type checks on resource variables and tensor
64     // lists (DT_VARIANT) as we have to trick the type system in order to
65     // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor.
66     if (types[i] != args[i].type && types[i] != DT_RESOURCE &&
67         types[i] != DT_VARIANT) {
68       return errors::Internal(
69           "Argument ", i, " has declared type ", DataTypeString(args[i].type),
70           " but function parameter has type ", DataTypeString(types[i]));
71     }
72   }
73   return Status::OK();
74 }
75 
76 // Uses the _Arg and _Retval nodes in the graph to determine a core assignment
77 // for each argument and return value.
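// For example, an _Arg node annotated with a maximal sharding on device 2 and
// an "index" attr of 0 yields arg_cores[0] = 2; nodes without a sharding
// annotation are skipped (get_sharding_for_node returns -1 for them).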
78 xla::StatusOr<std::pair<std::map<int, int>, std::map<int, int>>>
79 ComputeArgAndRetvalCores(const Graph& graph) {
80   auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr<int> {
81     TF_ASSIGN_OR_RETURN(
82         auto sharding,
83         ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
84     if (sharding.has_value()) {
85       TF_RET_CHECK(sharding.value().type() ==
86                    xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
87       return sharding.value().tile_assignment_devices(0);
88     } else {
89       return -1;
90     }
91   };
92   std::map<int, int> arg_cores;
93   std::map<int, int> retval_cores;
94   for (const Node* n : graph.nodes()) {
95     if (n->IsArg()) {
96       TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
97       if (core < 0) continue;
98       int index;
99       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
100       TF_RET_CHECK(index >= 0) << "Negative _Arg index";
101       arg_cores[index] = core;
102     } else if (n->IsRetval()) {
103       TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
104       if (core < 0) continue;
105       int index;
106       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
107       TF_RET_CHECK(index >= 0) << "Negative _Retval index";
108       retval_cores[index] = core;
110     }
111   }
112   return std::make_pair(std::move(arg_cores), std::move(retval_cores));
113 }
114 
115 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
116                     XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
117                     int64 step_id) {
118   // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
119   // resource manager takes ownership via Create, and unrefs via Cleanup.  We
120   // explicitly add a reference to ensure the refcount at entry is maintained at
121   // all exit points; Create and Cleanup are always called in this function.
122   //
123   // The Executor requires us to use ScopedStepContainer. We wrap it in a
124   // unique_ptr so we can capture the cleanup status in the end.
125   xla_context->Ref();
126   Status status;
127   auto step_container = absl::make_unique<ScopedStepContainer>(
128       step_id, [&status, device](const string& name) {
129         status = device->resource_manager()->Cleanup(name);
130       });
131   TF_RETURN_IF_ERROR(device->resource_manager()->Create(
132       step_container->name(), XlaContext::kXlaContextResourceName,
133       xla_context));
134 
135   GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
136   TF_RETURN_IF_ERROR(graph_compiler.Compile());
137   // Explicitly clean up the step container, to capture the cleanup status.
138   step_container.reset();
139   return status;
140 }
141 
142 // Builds the XLA computation.
143 // - `args` is the list of input arguments
144 // - `retvals` is the list of retvals produced by _Retval operators, in index
145 //   order.
146 // - `arg_cores` and `retval_cores` are mappings from arg/return indices to core
147 //   assignments.
148 // - If `return_updated_values_for_all_resources` is true, all resources will be
149 //   included in `resource_updates`, regardless of whether their value changed.
150 // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
151 // - Sets `*resource_updates` to a description of resources whose values are
152 //   written by the computation; the variable writes are the last
153 //   `resource_updates.size()` return values from the computation. Each entry in
154 //   `resource_updates` is a ResourceUpdate, whose `index` is the index of a
155 //   resource variable argument to the computation to be updated, and `type` is
156 //   the type of the final output.
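// Roughly, the computation's outputs end up ordered as: the non-constant
// retvals, then the updated resources in ascending arg_num order, then (if
// present) the token output.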
157 Status BuildComputation(
158     const std::vector<XlaCompiler::Argument>& args,
159     const std::vector<XlaExpression>& retvals,
160     const std::map<int, int>& arg_cores, const std::map<int, int>& retval_cores,
161     const std::vector<std::unique_ptr<XlaResource>>& resources,
162     std::unique_ptr<xla::XlaOp> token_output,
163     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
164     bool return_updated_values_for_all_resources, bool always_return_tuple,
165     xla::XlaBuilder* builder, xla::XlaComputation* computation,
166     int* num_computation_outputs, int* num_nonconst_outputs,
167     std::vector<XlaCompiler::OutputDescription>* outputs,
168     std::vector<XlaCompiler::ResourceUpdate>* resource_updates,
169     xla::Shape* output_shape) {
170   // Attach a common operator name as metadata. This has no semantic effect — it
171   // merely makes the HLO graph more readable when visualized via TensorBoard,
172   // since TensorBoard forms groups out of operators with similar names.
173   xla::OpMetadata retval_metadata;
174   retval_metadata.set_op_name("XLA_Retvals");
175   builder->SetOpMetadata(retval_metadata);
176   auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
177 
178   // Builds an identity wrapper (Tuple + GetTupleElement) around an op. We need
179   // to set the sharding of outputs but cannot change the sharding of an existing
180   // output op; instead, shardings are applied to this fresh identity op.
181   auto identity_op = [builder](xla::XlaOp op) {
182     return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
183   };
184 
185   std::vector<xla::XlaOp> elems;
186   elems.reserve(retvals.size());
187 
188   // Keeps track of the layout of each retval. If a retval is not in this list,
189   // a descending layout is used. The first element is the output index, second
190   // element is the new layout.
191   std::vector<std::pair<int64, xla::Layout>> retval_index_and_layout;
192   for (int i = 0; i < retvals.size(); ++i) {
193     XlaCompiler::OutputDescription& output = (*outputs)[i];
194     const XlaExpression& retval = retvals[i];
195     output.type = retval.dtype();
196     switch (retval.kind()) {
197       case XlaExpression::Kind::kConstant:
198         output.is_constant = true;
199         output.constant_value = retval.constant_value();
200         output.shape = output.constant_value.shape();
201         break;
202 
203       case XlaExpression::Kind::kTensorList:
204         TF_FALLTHROUGH_INTENDED;
205       case XlaExpression::Kind::kXlaOp: {
206         output.is_constant = false;
207         TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
208         xla::XlaOp value = retval.handle();
209         auto it = retval_cores.find(i);
210         xla::XlaScopedShardingAssignment assign_sharding(
211             builder, it == retval_cores.end()
212                          ? absl::optional<xla::OpSharding>()
213                          : xla::sharding_builder::AssignDevice(it->second));
214         if (shape_representation_fn) {
215           // If there is a shape representation function, reshape the output
216           // tensor to the shape given by the representation shape function.
217           TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
218                                                     output.shape, output.type));
219           value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
220           retval_index_and_layout.emplace_back(elems.size(), shape.layout());
221         } else if (it != retval_cores.end()) {
222           // Apply the sharding to the output, if there is a core assignment.
223           value = identity_op(value);
224         }
225 
226         elems.push_back(value);
227         break;
228       }
229 
230       case XlaExpression::Kind::kResource:
231         output.is_constant = false;
232         output.input_index = retval.resource()->arg_num();
233         output.shape = retval.resource()->shape();
234         break;
235 
236       case XlaExpression::Kind::kInvalid:
237         return errors::InvalidArgument(
238             "Invalid expression returned by computation. "
239             "This probably means a return value was not set.");
240     }
241   }
242   *num_nonconst_outputs = elems.size();
243 
244   // Add return values for resources whose values have changed.
245   std::vector<const XlaResource*> arg_resources;
246   arg_resources.reserve(resources.size());
247   for (const auto& resource : resources) {
248     if (resource->arg_num() >= 0) {
249       arg_resources.push_back(resource.get());
250     }
251   }
252   std::sort(arg_resources.begin(), arg_resources.end(),
253             [](const XlaResource* a, const XlaResource* b) {
254               return a->arg_num() < b->arg_num();
255             });
256 
257   for (const XlaResource* resource : arg_resources) {
258     DCHECK_LT(resource->arg_num(), args.size());
259     const XlaCompiler::Argument& arg = args[resource->arg_num()];
260     auto it = arg_cores.find(resource->arg_num());
261     const int core = it == arg_cores.end() ? -1 : it->second;
262     bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
263     // TensorArray gradients were modified if their values changed or there are
264     // any newly created gradients.
265     for (const auto& grad : resource->tensor_array_gradients()) {
266       modified =
267           modified ||
268           !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
269           arg.tensor_array_gradients.count(grad.first) == 0;
270     }
271     if (return_updated_values_for_all_resources || modified) {
272       resource_updates->emplace_back();
273       XlaCompiler::ResourceUpdate& update = resource_updates->back();
274       update.input_index = resource->arg_num();
275       update.type = resource->type();
276       update.shape = resource->shape();
277       update.modified = modified;
278       for (const auto& grad : resource->tensor_array_gradients()) {
279         update.tensor_array_gradients_accessed.insert(grad.first);
280       }
281 
282       // Request that the value be returned on a specific core.
283       xla::XlaScopedShardingAssignment assign_sharding(
284           builder, core == -1 ? absl::optional<xla::OpSharding>()
285                               : xla::sharding_builder::AssignDevice(core));
286 
287       xla::XlaOp handle;
288       TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
289 
290       // Ensures the correct sharding is applied to the output.
291       handle = identity_op(handle);
292 
293       // Set layout of the retval to device representation layout.
294       if (resource->representation_shape().has_value()) {
295         retval_index_and_layout.emplace_back(
296             elems.size(), resource->representation_shape()->layout());
297       }
298       elems.push_back(handle);
299     }
300   }
301 
302   // If we have a token output, append it as the last element.
303   if (token_output) {
304     elems.push_back(*token_output);
305   }
306 
307   *num_computation_outputs = elems.size();
308 
309   // Builds the XLA computation. We *always* form a tuple here to ensure that
310   // the output value is the last thing added into the XLA computation, even
311   // if there is only one output value.
312   auto tuple = xla::Tuple(builder, elems);
313   if (!always_return_tuple && elems.size() == 1) {
314     xla::GetTupleElement(tuple, 0);
315   }
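  // The result of GetTupleElement above is intentionally unused: the builder
  // treats the most recently added instruction as the computation root, so
  // adding it makes the single element, rather than the tuple, the root value.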
316 
317   xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
318   if (!computation_status.ok()) {
319     return computation_status.status();
320   }
321   *computation = computation_status.ConsumeValueOrDie();
322 
323   TF_ASSIGN_OR_RETURN(const auto& program_shape,
324                       computation->GetProgramShape());
325   *output_shape = program_shape.result();
326   // Update the output layout to the layout of retval.
327   for (auto& index_and_layout : retval_index_and_layout) {
328     if (!always_return_tuple && elems.size() == 1) {
329       *output_shape->mutable_layout() = index_and_layout.second;
330       continue;
331     }
332 
333     xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape(
334         output_shape, {index_and_layout.first});
335     *output_sub_shape->mutable_layout() = index_and_layout.second;
336   }
337   return Status::OK();
338 }
339 
340 }  // namespace
341 
342 bool XlaCompiler::Argument::operator==(
343     const XlaCompiler::Argument& other) const {
344   if (std::tie(kind, resource_kind, type, name, initialized, max_array_size,
345                tensor_array_gradients) !=
346       std::tie(other.kind, other.resource_kind, other.type, other.name,
347                other.initialized, other.max_array_size,
348                other.tensor_array_gradients)) {
349     return false;
350   }
351   if (absl::holds_alternative<xla::Shape>(shape)) {
352     if (!absl::holds_alternative<xla::Shape>(other.shape)) {
353       return false;
354     }
355     if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape),
356                              absl::get<xla::Shape>(other.shape))) {
357       return false;
358     }
359   } else {
360     if (!absl::holds_alternative<TensorShape>(other.shape)) {
361       return false;
362     }
363     if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) {
364       return false;
365     }
366   }
367   if (constant_value.shape() != other.constant_value.shape()) {
368     return false;
369   }
370   return constant_value.tensor_data() == other.constant_value.tensor_data();
371 }
372 
373 string XlaCompiler::Argument::HumanString() const {
374   string common;
375   if (!name.empty()) {
376     common = absl::StrCat(" name=", name);
377   }
378   absl::StrAppend(&common, " type=", DataTypeString(type),
379                   " shape=", ShapeHumanString());
380   switch (kind) {
381     case kInvalid:
382       return "invalid";
383     case kConstant:
384       return absl::StrCat("kind=constant", common,
385                           " value=", constant_value.DebugString());
386     case kResource: {
387       string output = absl::StrCat("kind=resource", common, " resource_kind=",
388                                    XlaResource::KindToString(resource_kind),
389                                    " initialized=", initialized);
390       if (max_array_size >= 0) {
391         absl::StrAppend(&output, " max_array_size=", max_array_size);
392       }
393       if (!tensor_array_gradients.empty()) {
394         absl::StrAppend(&output, " tensor_array_gradients=",
395                         absl::StrJoin(tensor_array_gradients, ","));
396       }
397       return output;
398     }
399     case kParameter:
400       return absl::StrCat("kind=parameter", common);
401     case kToken:
402       return absl::StrCat("token", common);
403   }
404 }
405 
406 std::vector<int64> XlaCompiler::Argument::DimensionSizes() const {
407   if (absl::holds_alternative<TensorShape>(shape)) {
408     return xla::InlinedVectorToVector(
409         absl::get<TensorShape>(shape).dim_sizes());
410   } else {
411     return absl::get<xla::Shape>(shape).dimensions();
412   }
413 }
414 
415 string XlaCompiler::Argument::ShapeHumanString() const {
416   if (absl::holds_alternative<TensorShape>(shape)) {
417     return absl::get<TensorShape>(shape).DebugString();
418   } else {
419     return absl::get<xla::Shape>(shape).DebugString();
420   }
421 }
422 
423 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
424     : options_(options),
425       initialization_status_(Status::OK()),
426       next_step_id_(1),
427       device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
428       device_mgr_(absl::WrapUnique(device_)) {
429   CHECK(!options_.device_type.type_string().empty());
430   if (options_.populate_resource_manager) {
431     initialization_status_ =
432         (*options_.populate_resource_manager)(device_->resource_manager());
433   }
434 
435   local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
436                                                       FunctionDefLibrary{}));
437   local_pflr_.reset(new ProcessFunctionLibraryRuntime(
438       &device_mgr_, Env::Default(), options.graph_def_version,
439       local_flib_def_.get(), OptimizerOptions(),
440       nullptr /* custom_kernel_creator */));
441   pflr_.reset(new ProcessFunctionLibraryRuntime(
442       &device_mgr_, Env::Default(), options.graph_def_version, options.flib_def,
443       OptimizerOptions(), nullptr /* custom_kernel_creator */));
444 
445   local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
446   flib_runtime_ = pflr_->GetFLR(device_->name());
447 
448   // The default shape representation function is the identity.
449   if (!options_.shape_representation_fn) {
450     options_.shape_representation_fn =
451         [](const TensorShape& shape,
452            DataType dtype) -> xla::StatusOr<xla::Shape> {
453       xla::Shape xla_shape;
454       TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
455       return xla_shape;
456     };
457   }
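  // Callers that want device-specific layouts (e.g. TPU-friendly tilings) can
  // supply their own shape_representation_fn in Options; the default above
  // simply maps the TensorShape and dtype to the equivalent XLA shape.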
458 }
459 
460 XlaCompiler::~XlaCompiler() = default;
461 
462 int64 XlaCompiler::NextStepId() { return next_step_id_++; }
463 
464 uint64 XlaCompiler::SignatureHash::operator()(
465     const std::pair<string, std::vector<Argument>>& signature) const {
466   return std::hash<string>()(signature.first);
467 }
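// Note that only the function name feeds the hash, so argument lists that
// differ under the same name collide; the cache's key equality still
// distinguishes them, at the cost of extra comparisons within a bucket.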
468 
469 static Status GetFunctionBody(const NameAttrList& function,
470                               FunctionLibraryRuntime* flib_runtime,
471                               const FunctionBody** fbody) {
472   FunctionLibraryRuntime::Handle handle;
473   TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
474       function.name(), AttrSlice(&function.attr()), &handle));
475 
476   *fbody = flib_runtime->GetFunctionBody(handle);
477   TF_RET_CHECK(*fbody);
478   return Status::OK();
479 }
480 
481 Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
482                                      const FunctionBody** fbody) {
483   // The function may be in either the local_flib_runtime_ or flib_runtime_.
484   // Look it up in local_flib_runtime_ first; if it is not found there, fall back
485   // to flib_runtime_.
486   auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
487   if (!status.ok()) {
488     if (!errors::IsNotFound(status)) {
489       return status;
490     }
491     TF_RETURN_WITH_CONTEXT_IF_ERROR(
492         GetFunctionBody(function, flib_runtime_, fbody),
493         "Local lookup failed with: ", status.error_message());
494     VLOG(4) << "Function " << function.name() << " in flib_runtime_";
495   } else {
496     VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
497   }
498   return Status::OK();
499 }
500 
501 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
502   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
503   CopyGraph(*fbody->graph, graph.get());
504   OptimizerOptions opts;
505   opts.set_opt_level(OptimizerOptions::L0);
506   opts.set_do_common_subexpression_elimination(false);
507   opts.set_do_function_inlining(true);
508   opts.set_do_constant_folding(true);
509   GraphOptimizer optimizer(opts);
510   // Do not constant fold nodes that output DT_VARIANT type tensors.
511   // XLA does not support Const nodes of Variant type since it needs
512   // to know the original ops to be able to compile them to the relevant
513   // XLA form.
514   // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
515   // the form:
516   //                          Const
517   //                            |
518   // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
519   //                                                  |
520   //                                        (Discard popped list)
521   //
522   // Would have been reduced to "Const -> Op" without this filter.
523   // However since we are only allowed to specify the filter at the "Node"
524   // level there is no good way to allow the above behavior. So we
525   // disallow any sort of constant folding on Variant nodes for now.
526   auto cf_consider_fn = [](const Node* n) {
527     for (const auto& output_arg : n->op_def().output_arg()) {
528       if (output_arg.type() == DT_VARIANT) {
529         return false;
530       }
531     }
532     return true;
533   };
534   GraphOptimizer::Options graph_optimizer_options;
535   graph_optimizer_options.cf_consider_fn = cf_consider_fn;
536   optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
537                      /*device=*/nullptr, &graph, graph_optimizer_options);
538 
539   return graph;
540 }
541 
542 Status XlaCompiler::CompileFunction(
543     const XlaCompiler::CompileOptions& options, const NameAttrList& function,
544     absl::Span<const XlaCompiler::Argument> args,
545     XlaCompiler::CompilationResult* result) {
546   const string function_id =
547       Canonicalize(function.name(), AttrSlice(&function.attr()));
548   VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
549 
550   const std::vector<XlaCompiler::Argument> arg_vector(args.begin(), args.end());
551   auto it = cache_.find({function_id, arg_vector});
552   if (it != cache_.end()) {
553     *result = it->second;
554     return Status::OK();
555   }
556 
557   const FunctionBody* fbody;
558   TF_RETURN_IF_ERROR(FindFunctionBody(function, &fbody));
559 
560   TF_RETURN_WITH_CONTEXT_IF_ERROR(
561       CheckSignature(fbody->arg_types, args),
562       "Signature check failure while compiling: ", function.name());
563 
564   std::unique_ptr<Graph> graph = GetGraph(fbody);
565 
566   // Clear the "_kernel" attribute if it is set to "host". This is used to
567   // indicate that a computation should happen on the host instead of the
568   // accelerator, but doesn't make sense in XLA.
569   const char* const kKernelAttr = "_kernel";
570   for (Node* n : graph->nodes()) {
571     string value;
572     if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") {
573       n->ClearAttr(kKernelAttr);
574     }
575   }
576 
577   // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
578   // they are added when the function body is looked up. Therefore, they don't
579   // have core assignments here.
580   // Attempt to assign a core to each _Retval and _Arg. Chooses the
581   // lowest-numbered core that consumes the argument. We choose the
582   // lowest-numbered core so the assignment is deterministic.
583   for (Node* n : graph->nodes()) {
584     if (n->IsArg()) {
585       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
586     }
587   }
588   // Do _Retval as a second loop, in case the retval's input is an _Arg (which
589   // may have gotten a device assignment from the first loop).
590   for (Node* n : graph->nodes()) {
591     if (n->IsRetval()) {
592       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
593     }
594   }
595 
596   if (VLOG_IS_ON(2)) {
597     VLOG(2) << "XlaCompiler::CompileFunction: "
598             << DumpGraphToFile(
599                    absl::StrCat("xla_compile_function_", function_id), *graph);
600   }
601 
602   VLOG(1) << "====================================================";
603   TF_RETURN_IF_ERROR(
604       CompileGraph(options, function_id, std::move(graph), args, {}, result));
605   VLOG(1) << "====================================================";
606 
607   cache_[{function_id, arg_vector}] = *result;
608   return Status::OK();
609 }
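// Sketch of a typical (hypothetical) caller: populate a vector of
// XlaCompiler::Argument describing each input, then call
//   compiler.CompileFunction(compile_options, function, args, &result);
// Identical (function, argument-signature) pairs are served from cache_.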
610 
611 // Computes the XLA shape for argument 'arg'.
612 Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
613                                         bool is_entry_computation,
614                                         xla::Shape* xla_shape) const {
615   switch (arg.kind) {
616     case XlaCompiler::Argument::kConstant:
617       LOG(FATAL) << "Unreachable case";
618     case XlaCompiler::Argument::kParameter: {
619       if (is_entry_computation) {
620         TensorShape shape;
621         if (absl::holds_alternative<TensorShape>(arg.shape)) {
622           shape = absl::get<TensorShape>(arg.shape);
623         } else {
624           TF_RETURN_IF_ERROR(
625               XLAShapeToTensorShape(absl::get<xla::Shape>(arg.shape), &shape));
626         }
627         TF_ASSIGN_OR_RETURN(*xla_shape,
628                             options_.shape_representation_fn(shape, arg.type));
629       } else {
630         if (absl::holds_alternative<xla::Shape>(arg.shape)) {
631           *xla_shape = absl::get<xla::Shape>(arg.shape);
632         } else {
633           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
634               arg.type, absl::get<TensorShape>(arg.shape), xla_shape));
635         }
636       }
637       return Status::OK();
638     }
639     case XlaCompiler::Argument::kResource: {
640       TF_RET_CHECK(arg.initialized);
641 
642       switch (arg.resource_kind) {
643         case XlaResource::kVariable: {
644           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
645           TF_ASSIGN_OR_RETURN(*xla_shape,
646                               options_.shape_representation_fn(
647                                   absl::get<TensorShape>(arg.shape), arg.type));
648 
649           return Status::OK();
650         }
651         case XlaResource::kTensorArray: {
652           if (arg.max_array_size < 0) {
653             return errors::InvalidArgument(
654                 "Negative max_array_size in XLAShapeForArgument");
655           }
656           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
657           TensorShape shape;
658           shape.AddDim(arg.max_array_size);
659           shape.AppendShape(absl::get<TensorShape>(arg.shape));
660           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
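          // For example, a float TensorArray with max_array_size=10 and
          // element shape [3] maps to the XLA shape f32[10,3] here.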
661 
662           if (!arg.tensor_array_gradients.empty()) {
663             std::vector<xla::Shape> tuple_shape(
664                 arg.tensor_array_gradients.size() + 1, *xla_shape);
665             *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
666           }
667           return Status::OK();
668         }
669         case XlaResource::kStack: {
670           if (arg.max_array_size < 0) {
671             return errors::InvalidArgument(
672                 "Negative max_array_size in XLAShapeForArgument");
673           }
674           TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
675           TensorShape shape;
676           shape.AddDim(arg.max_array_size);
677           shape.AppendShape(absl::get<TensorShape>(arg.shape));
678           xla::Shape buffer_shape;
679           TF_RETURN_IF_ERROR(
680               TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
681           *xla_shape = xla::ShapeUtil::MakeTupleShape(
682               {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
683           return Status::OK();
684         }
685 
686         case XlaResource::kInvalid:
687           return errors::Internal(
688               "Invalid resource type in XLAShapeForArgument()");
689       }
690     }
691     case XlaCompiler::Argument::kToken: {
692       *xla_shape = xla::ShapeUtil::MakeTokenShape();
693       return Status::OK();
694     }
695     case XlaCompiler::Argument::kInvalid:
696       return errors::Internal("Invalid argument type in XLAShapeForArgument()");
697   }
698 }
699 
700 // Builds XLA computations for each of the arguments to the computation.
701 // `args` are the arguments to the computation.
702 Status XlaCompiler::BuildArguments(
703     const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
704     bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
705     const std::map<int, int>& arg_cores,
706     std::vector<XlaExpression>* arg_expressions,
707     std::vector<int>* input_to_args, std::vector<xla::Shape>* input_shapes,
708     bool is_entry_computation) {
709   arg_expressions->resize(args.size());
710 
711   // Argument numbers of arguments and resources that are to be passed to the
712   // XLA computation as runtime parameters. `input_to_args[a] = b` means that
713   // the a'th XLA input corresponds to the b'th original arg index.
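  // For example, if args = {constant, parameter, initialized resource}, only
  // the parameter and the resource become XLA inputs, so input_to_args would
  // be {1, 2}.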
714   input_to_args->clear();
715   input_to_args->reserve(args.size());
716 
717   // Fills in constant arguments, and computes non-constant argument order.
718   for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
719        ++i) {
720     const XlaCompiler::Argument& arg = args[i];
721     XlaExpression& arg_expression = (*arg_expressions)[i];
722     switch (arg.kind) {
723       case XlaCompiler::Argument::kResource: {
724         TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
725         TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
726         // TODO(phawkins): this code assumes that resource arguments do not
727         // alias.
728         XlaResource* resource =
729             context->AddResource(absl::make_unique<XlaResource>(
730                 arg.resource_kind, i, arg.name, arg.type,
731                 absl::get<TensorShape>(arg.shape), xla::XlaOp(),
732                 /*max_array_size=*/arg.max_array_size,
733                 /*tensor_array_gradients=*/arg.tensor_array_gradients,
734                 /*tensor_array_multiple_writes_aggregate=*/true));
735         arg_expression = XlaExpression::Resource(resource);
736         if (arg.initialized) {
737           input_to_args->push_back(i);
738         }
739         break;
740       }
741       case XlaCompiler::Argument::kParameter:
742       case XlaCompiler::Argument::kToken: {
743         input_to_args->push_back(i);
744         break;
745       }
746       case XlaCompiler::Argument::kConstant:
747         arg_expression = XlaExpression::Constant(arg.constant_value);
748         break;
749       case XlaCompiler::Argument::kInvalid:
750         return errors::Internal(
751             "Unreachable case in BuildArguments() while filling constant args");
752     }
753   }
754 
755   if (input_to_args->empty()) {
756     return Status::OK();
757   }
758 
759   // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds
760   // to the d'th XLA input. Note that the value -1 corresponds to constants, or
761   // other args that don't correspond to an input.
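  // Continuing the example above, arg_to_inputs would be {-1, 0, 1}: the
  // constant has no XLA input, the parameter is XLA input 0, and the resource
  // is XLA input 1.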
762   std::vector<int> arg_to_inputs(args.size(), -1);
763   for (int i = 0; i < input_to_args->size(); i++) {
764     arg_to_inputs[input_to_args->at(i)] = i;
765   }
766 
767   std::vector<xla::Shape> arg_shapes(input_to_args->size());
768   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
769     // Computes the shapes of non-constant arguments.
770     TF_RETURN_IF_ERROR(XLAShapeForArgument(
771         args[(*input_to_args)[i]], is_entry_computation, &arg_shapes[i]));
772   }
773 
774   if (use_tuple_arg) {
775     input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
776   } else {
777     *input_shapes = arg_shapes;
778   }
779 
780   // Attach a common operator name as metadata. This has no semantic effect — it
781   // merely makes the HLO graph more readable when visualized via TensorBoard,
782   // since TensorBoard forms groups out of operators with similar names.
783   xla::OpMetadata arg_metadata;
784   arg_metadata.set_op_name("XLA_Args");
785   builder->SetOpMetadata(arg_metadata);
786 
787   // Build parameter handles for non-constant arguments.
788   std::vector<xla::XlaOp> arg_handles(input_to_args->size());
789   if (use_tuple_arg) {
790     xla::XlaOp tuple;
791     if (is_entry_computation) {
792       xla::OpSharding tuple_sharding;
793       tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
794       for (int64 parameter : *input_to_args) {
795         auto it = arg_cores.find(parameter);
796         const int core = it == arg_cores.end() ? 0 : it->second;
797         *tuple_sharding.add_tuple_shardings() =
798             xla::sharding_builder::AssignDevice(core);
799       }
800       xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
801                                                              tuple_sharding);
802       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
803     } else {
804       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
805     }
806 
807     for (int i = 0; i < input_to_args->size(); ++i) {
808       const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
809       for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
810         int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
811         TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
812             /*dynamic_size_param_num=*/0, {dynamic_size_param_index},
813             /*target_param_num=*/0, /*target_param_index=*/{i},
814             dim_and_arg_num.first));
815       }
816     }
817 
818     for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
819       auto it = arg_cores.find(i);
820       const int core = it == arg_cores.end() ? -1 : it->second;
821       xla::XlaScopedShardingAssignment assign_sharding(
822           builder, core == -1 ? absl::optional<xla::OpSharding>()
823                               : xla::sharding_builder::AssignDevice(core));
824       arg_handles[i] = xla::GetTupleElement(tuple, i);
825     }
826   } else {
827     for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
828       auto it = arg_cores.find(i);
829       const int core = it == arg_cores.end() ? -1 : it->second;
830       xla::XlaScopedShardingAssignment assign_sharding(
831           builder, core == -1 ? absl::optional<xla::OpSharding>()
832                               : xla::sharding_builder::AssignDevice(core));
833       arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
834                                       absl::StrCat("arg", i));
835     }
836 
837     for (int i = 0; i < input_to_args->size(); ++i) {
838       const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
839       for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
840         int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
841         TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
842             /*dynamic_size_param_num=*/dynamic_size_param_index, {},
843             /*target_param_num=*/i, /*target_param_index=*/{},
844             dim_and_arg_num.first));
845       }
846     }
847   }
848 
849   builder->ClearOpMetadata();
850 
851   // Fill in the handles in non-constant arguments, and reshape parameters
852   // back to their correct shapes.
853   VLOG(2) << "XLA computation inputs:";
854   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
855     const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
856     VLOG(2) << "  XLA arg " << i
857             << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
858             << " name: " << arg.name << " TF arg " << input_to_args->at(i);
859     XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
860     switch (arg.kind) {
861       case XlaCompiler::Argument::kResource: {
862         TF_RET_CHECK(arg.initialized);
863         XlaResource* resource = arg_expression.resource();
864         TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
865                                                  arg_handles[i], builder));
866         VLOG(2) << "    resource: num_gradients: "
867                 << arg.tensor_array_gradients.size();
868         break;
869       }
870       case XlaCompiler::Argument::kParameter:
871         // Reshape parameters back to their correct shapes.
872         // TODO(b/76097077): propagate device assignments onto arguments and
873         // return values of functions, and then reshape unconditionally.
874         if (is_entry_computation) {
875           arg_expression = XlaExpression::XlaOp(
876               xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type);
877         } else {
878           arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
879         }
880         break;
881       case XlaCompiler::Argument::kToken: {
882         arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
883         break;
884       }
885       case XlaCompiler::Argument::kConstant:
886       case XlaCompiler::Argument::kInvalid:
887         return errors::Internal(
888             "Unreachable case in BuildArguments() while filling handles");
889     }
890   }
891 
892   return Status::OK();
893 }
894 
895 Status XlaCompiler::CompileSingleOp(
896     const XlaCompiler::CompileOptions& options, const NodeDef& node_def,
897     absl::Span<const XlaCompiler::Argument> args,
898     absl::Span<const DataType> result_types, CompilationResult* result) {
899   // TODO(b/74182462): We implement this by creating a new dummy Graph including
900   // _Arg nodes, and let CompileGraph walk it. This could be optimized.
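  // The dummy graph built below looks roughly like:
  //   _SOURCE --(control)--> _arg0.._argN --> <node_def's op> --> _retval0.._retvalM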
901   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
902 
903   Status status;
904   // First create the actual node we care about computing.
905   Node* main_node = graph->AddNode(node_def, &status);
906   TF_RETURN_IF_ERROR(status);
907 
908   // Create dummy _Arg nodes. Link these to `main_node` and also, via a control
909   // dependency edge, to the _SOURCE node.
910   for (int64 i = 0; i < args.size(); ++i) {
911     Node* node;
912     string arg_name = absl::StrCat("_arg", i);
913     Status status =
914         NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
915             .ControlInput(graph->source_node())
916             .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
917                                                            : args[i].type)
918             .Attr("index", i)
919             .Finalize(graph.get(), &node);
920     TF_RETURN_IF_ERROR(status);
921     graph->AddEdge(node, 0, main_node, i);
922   }
923 
924   // Similarly with return values, create dummy _Retval nodes fed by `main_node`.
925   for (int64 i = 0; i < result_types.size(); ++i) {
926     Node* node;
927     string retval_name = absl::StrCat("_retval", i);
928     Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
929                         .Input(main_node, i)
930                         .Attr("T", result_types[i])
931                         .Attr("index", i)
932                         .Finalize(graph.get(), &node);
933     TF_RETURN_IF_ERROR(status);
934   }
935   FixupSourceAndSinkEdges(graph.get());
936 
937   return CompileGraph(options, node_def.name(), std::move(graph), args, {},
938                       result);
939 }
940 
941 namespace {
942 
943 // Check that the ops of all non-functional nodes have been registered.
944 Status ValidateFunctionDef(const FunctionDef* fdef,
945                            const FunctionLibraryDefinition& flib_def) {
946   for (const NodeDef& node : fdef->node_def()) {
947     const string& op = node.op();
948     if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
949       continue;
950     }
951     const OpDef* op_def;
952     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
953   }
954   return Status::OK();
955 }
956 
957 // If node is PartitionedCall or StatefulPartitionedCall, returns the
958 // name from the "f" attr, else returns node.def().op().
959 // Returned pointer points to the internal string either in node's attributes
960 // or in its NodeDef. This pointer is valid as long as the node has not been
961 // modified.
962 Status GetPotentialFunctionName(const Node& node, const string** name) {
963   if (node.IsPartitionedCall()) {
964     const AttrValue* attr_value;
965     TF_RETURN_IF_ERROR(
966         node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
967     if (!attr_value->has_func()) {
968       return errors::InvalidArgument(
969           "The attribute value for attribute 'f' in node ", node.DebugString(),
970           " does not have 'func' field set");
971     }
972     *name = &attr_value->func().name();
973     return Status::OK();
974   }
975   *name = &node.type_string();
976   return Status::OK();
977 }
978 
979 // Check that the graph doesn't have any invalid nodes (e.g. incompatible with
980 // given device_type, invalid data type, missing attributes...)
981 Status ValidateGraph(const Graph* graph,
982                      const FunctionLibraryDefinition& flib_def,
983                      const DeviceType& device_type, const string& name) {
984   auto maybe_error = [&](const Node* node, const Status& s) -> Status {
985     if (!s.ok()) {
986       return errors::InvalidArgument(absl::StrCat(
987           "Detected unsupported operations when trying to compile graph ", name,
988           " on ", device_type.type_string(), ": ", node->def().op(), " (",
989           s.error_message(), ")", FormatNodeForError(*node)));
990     }
991     return Status::OK();
992   };
993 
994   for (const Node* node : graph->nodes()) {
995     if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
996       continue;
997     }
998     const string* function_name;
999     TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
1000     const FunctionDef* fdef = flib_def.Find(*function_name);
1001     Status s;
1002     if (fdef) {
1003       s = ValidateFunctionDef(fdef, flib_def);
1004       TF_RETURN_IF_ERROR(maybe_error(node, s));
1005       continue;
1006     }
1007     const OpDef* op_def;
1008     s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
1009     TF_RETURN_IF_ERROR(maybe_error(node, s));
1010     TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
1011     s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
1012     TF_RETURN_IF_ERROR(maybe_error(node, s));
1013   }
1014   return Status::OK();
1015 }
1016 
1017 // Converts the value of any expressions whose values are known at compile-time
1018 // to constants.
1019 Status ResolveConstantExpressionsToConstants(
1020     xla::Client* client, absl::Span<XlaExpression> expressions) {
1021   for (XlaExpression& expression : expressions) {
1022     if (expression.kind() == XlaExpression::Kind::kXlaOp) {
1023       TF_ASSIGN_OR_RETURN(absl::optional<Tensor> constant,
1024                           expression.ResolveConstant(client));
1025       if (constant.has_value()) {
1026         expression = XlaExpression::Constant(*constant);
1027       }
1028     }
1029   }
1030   return Status::OK();
1031 }
1032 
1033 void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
1034                                    absl::Span<XlaExpression> expressions) {
1035   for (XlaExpression& expression : expressions) {
1036     if (expression.kind() == XlaExpression::Kind::kConstant) {
1037       expression =
1038           XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype());
1039     }
1040   }
1041 }
1042 
1043 }  // namespace
1044 
1045 Status XlaCompiler::CompileGraph(
1046     const XlaCompiler::CompileOptions& options, string const& name,
1047     std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
1048     absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
1049     CompilationResult* result) {
1050   VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";
1051 
1052   TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
1053       graph.get(), options_.flib_def, local_flib_def_.get()));
1054   if (VLOG_IS_ON(2)) {
1055     VLOG(2) << "XlaCompiler::CompileGraph: "
1056             << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph,
1057                                flib_runtime_->GetFunctionLibraryDefinition());
1058   }
1059 
1060   // Report the error here if initialization failed.
1061   TF_RETURN_IF_ERROR(initialization_status_);
1062 
1063   // Detect invalid nodes.
1064   // FunctionalizeControlFlow may remove some nodes from the graph.
1065   TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
1066                                    options_.device_type, name));
1067 
1068   xla::XlaBuilder builder(name);
1069   XlaContext* context = new XlaContext(this, &builder);
1070   core::ScopedUnref context_unref(context);
1071 
1072   std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
1073   int token_input_index = -1;
1074   std::unique_ptr<xla::XlaOp> token_output;
1075   if (options.add_token_input_output) {
1076     // Add extra token input.
1077     token_input_index = real_args.size();
1078 
1079     XlaCompiler::Argument token_arg;
1080     token_arg.kind = XlaCompiler::Argument::kToken;
1081     real_args.push_back(token_arg);
1082   }
1083 
1084   std::map<int, int> arg_cores;
1085   std::map<int, int> retval_cores;
1086   TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores),
1087                       ComputeArgAndRetvalCores(*graph));
1088 
1089   std::vector<XlaExpression> arg_expressions;
1090   TF_RETURN_IF_ERROR(BuildArguments(
1091       *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores,
1092       &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
1093       options.is_entry_computation));
1094   context->set_args(std::move(arg_expressions));
1095 
1096   // Propagate any aliases given to us by the user.
1097   for (const xla::XlaBuilder::InputOutputAlias& alias : user_aliases) {
1098     builder.SetUpAlias(alias.output_index, alias.param_number,
1099                        alias.param_index);
1100   }
1101 
1102   PushNodeTokenMapping();
1103   // Use std::set instead of std::unordered_set to ensure determinism.
1104   std::set<std::string> output_node_token_inputs;
1105   if (token_input_index != -1) {
1106     // Original token comes from input.
1107     auto arg_expression = context->args()[token_input_index];
1108     TF_RETURN_IF_ERROR(
1109         SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
1110 
1111     // Calculate token inputs for output token.
1112     output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
1113 
1114     // If there's no side-effecting op in the graph, use token input as token
1115     // output.
1116     if (output_node_token_inputs.empty()) {
1117       output_node_token_inputs.insert(kXlaTokenArgNodeName);
1118     }
1119   } else if (options.is_entry_computation) {
1120     // Original token is manually created.
1121     if (HasSideEffectingNodes(*graph)) {
1122       TF_RETURN_IF_ERROR(
1123           SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
1124     }
1125   }
1126 
1127   TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
1128                                   flib_runtime_, NextStepId()));
1129   if (token_input_index != -1) {
1130     // Add extra token output.
1131     std::vector<xla::XlaOp> token_inputs;
1132     for (const auto& node_name : output_node_token_inputs) {
1133       auto token_or = GetNodeToken(node_name);
1134       TF_RETURN_IF_ERROR(token_or.status());
1135       token_inputs.push_back(token_or.ValueOrDie());
1136     }
1137     token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs)));
1138   }
1139   TF_RETURN_IF_ERROR(PopNodeTokenMapping());
1140 
1141   int num_nonconst_outputs;
1142   int num_computation_outputs;
1143   result->computation = std::make_shared<xla::XlaComputation>();
1144   result->outputs.resize(context->retvals().size());
1145   std::vector<XlaExpression> retvals = context->retvals();
1146   if (options.resolve_compile_time_constants) {
1147     Status status = ResolveConstantExpressionsToConstants(
1148         client(), absl::Span<XlaExpression>(retvals));
1149 
1150     // If the HloEvaluator has not implemented an expression, just evaluate it
1151     // at runtime.
1152     if (status.code() == error::UNIMPLEMENTED) {
1153       ConvertConstantsToExpressions(&builder,
1154                                     absl::Span<XlaExpression>(retvals));
1155     } else {
1156       TF_RETURN_IF_ERROR(status);
1157     }
1158   } else {
1159     ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
1160   }
1161   TF_RETURN_IF_ERROR(BuildComputation(
1162       real_args, retvals, arg_cores, retval_cores, context->resources(),
1163       std::move(token_output),
1164       options.is_entry_computation ? options_.shape_representation_fn
1165                                    : ShapeRepresentationFn{},
1166       options.return_updated_values_for_all_resources,
1167       options.always_return_tuple, &builder, result->computation.get(),
1168       &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
1169       &result->resource_updates, &result->xla_output_shape));
1170 
1171   VLOG(2) << "Outputs: total: " << context->retvals().size()
1172           << " nonconstant: " << num_nonconst_outputs;
1173   VLOG(2) << "XLA output shape: "
1174           << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape);
1175   return Status::OK();
1176 }
1177 
1178 Status XlaCompiler::GetChannelHandle(const string& key,
1179                                      xla::ChannelHandle* channel) {
1180   auto result = channels_.emplace(key, xla::ChannelHandle());
1181   if (result.second) {
1182     TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
1183   }
1184   *channel = result.first->second;
1185   VLOG(1) << "Channel: " << key << " " << channel->DebugString();
1186   return Status::OK();
1187 }
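// Note: repeated calls with the same `key` (here and in the two variants
// below) return the cached handle; a new channel is only created through the
// client on the first use of a key.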
1188 
1189 Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
1190                                                  xla::ChannelHandle* channel) {
1191   auto result = channels_.emplace(key, xla::ChannelHandle());
1192   if (result.second) {
1193     TF_ASSIGN_OR_RETURN(result.first->second,
1194                         client()->CreateHostToDeviceChannelHandle());
1195   }
1196   *channel = result.first->second;
1197   VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
1198   return Status::OK();
1199 }
1200 
1201 Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
1202                                                  xla::ChannelHandle* channel) {
1203   auto result = channels_.emplace(key, xla::ChannelHandle());
1204   if (result.second) {
1205     TF_ASSIGN_OR_RETURN(result.first->second,
1206                         client()->CreateDeviceToHostChannelHandle());
1207   }
1208   *channel = result.first->second;
1209   VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
1210   return Status::OK();
1211 }
1212 
1213 namespace {
1214 
1215 void SetTransfer(const string& key, absl::Span<const DataType> types,
1216                  absl::Span<const TensorShape> shapes,
1217                  tf2xla::HostTransferMetadata* transfer) {
1218   transfer->set_key(key);
1219   CHECK(types.size() == shapes.size());
1220   for (int i = 0; i < types.size(); ++i) {
1221     tf2xla::TensorMetadata* metadata = transfer->add_metadata();
1222     metadata->set_type(types[i]);
1223     shapes[i].AsProto(metadata->mutable_shape());
1224   }
1225 }
1226 
1227 }  // namespace
1228 
1229 Status XlaCompiler::SetDeviceToHostMetadata(
1230     const string& key, absl::Span<const DataType> types,
1231     absl::Span<const TensorShape> shapes) {
1232   if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
1233     return errors::InvalidArgument(
1234         "Duplicate calls to SetDeviceToHostMetadata with key ", key);
1235   }
1236   tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
1237   SetTransfer(key, types, shapes, &transfer);
1238   return Status::OK();
1239 }
1240 
1241 Status XlaCompiler::GetDeviceToHostShapes(
1242     const string& key, std::vector<TensorShape>* shapes) const {
1243   const auto iter = host_compute_sends_.find(key);
1244   if (iter == host_compute_sends_.end()) {
1245     return errors::InvalidArgument(
1246         "No host compute send shapes registered for key ", key);
1247   }
1248   shapes->clear();
1249   for (int i = 0; i < iter->second.metadata_size(); ++i) {
1250     TensorShape shape(iter->second.metadata(i).shape());
1251     shapes->push_back(shape);
1252   }
1253   return Status::OK();
1254 }
1255 
1256 Status XlaCompiler::SetHostToDeviceMetadata(
1257     const string& key, absl::Span<const DataType> types,
1258     absl::Span<const TensorShape> shapes) {
1259   if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
1260     return errors::InvalidArgument(
1261         "Duplicate calls to SetHostToDeviceMetadata with key ", key);
1262   }
1263   tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
1264   SetTransfer(key, types, shapes, &transfer);
1265   return Status::OK();
1266 }
1267 
1268 Status XlaCompiler::GetHostComputeControlDependency(
1269     const string& host_compute_name, xla::XlaOp* handle) {
1270   const auto iter = host_compute_control_output_.find(host_compute_name);
1271   if (iter == host_compute_control_output_.end()) {
1272     return errors::InvalidArgument(
1273         "No registered control handle for host compute Op '", host_compute_name,
1274         "'");
1275   } else {
1276     *handle = iter->second;
1277   }
1278   return Status::OK();
1279 }
1280 
1281 Status XlaCompiler::SetHostComputeControlDependency(
1282     const string& host_compute_name, const xla::XlaOp& handle) {
1283   if (host_compute_control_output_.find(host_compute_name) !=
1284       host_compute_control_output_.end()) {
1285     return errors::InvalidArgument(
1286         "Duplicate control handles registered for host compute Op ",
1287         host_compute_name);
1288   }
1289   host_compute_control_output_[host_compute_name] = handle;
1290   return Status::OK();
1291 }
1292 
1293 void XlaCompiler::PushNodeTokenMapping() {
1294   node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
1295 }
1296 
1297 Status XlaCompiler::PopNodeTokenMapping() {
1298   if (node_token_mapping_stack_.empty()) {
1299     return errors::FailedPrecondition(
1300         "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
1301         "empty.");
1302   }
1303   node_token_mapping_stack_.pop();
1304   return Status::OK();
1305 }
1306 
1307 Status XlaCompiler::SetNodeToken(const string& node_name,
1308                                  const xla::XlaOp& op) {
1309   if (node_token_mapping_stack_.empty()) {
1310     return errors::FailedPrecondition(
1311         "Calling SetNodeToken() when node_token_mapping_stack_ is "
1312         "empty.");
1313   }
1314   auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
1315   if (!insert_result.second) {
1316     return errors::FailedPrecondition("Token mapping already exists for node ",
1317                                       node_name);
1318   }
1319   return Status::OK();
1320 }
1321 
1322 xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
1323   if (node_token_mapping_stack_.empty()) {
1324     return errors::FailedPrecondition(
1325         "Calling GetNodeToken() when node_token_mapping_stack_ is "
1326         "empty.");
1327   }
1328   auto iter = node_token_mapping_stack_.top().find(node_name);
1329   if (iter == node_token_mapping_stack_.top().end()) {
1330     return errors::FailedPrecondition("Cannot find token mapping for node ",
1331                                       node_name);
1332   }
1333   return iter->second;
1334 }
1335 
1336 }  // namespace tensorflow
1337