• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17 
18 #include <deque>
19 #include <numeric>
20 
21 #include "tensorflow/compiler/tf2xla/const_analysis.h"
22 #include "tensorflow/compiler/tf2xla/dump_graph.h"
23 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
24 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
25 #include "tensorflow/compiler/tf2xla/shape_util.h"
26 #include "tensorflow/compiler/tf2xla/sharding_util.h"
27 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
28 #include "tensorflow/compiler/tf2xla/type_util.h"
29 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
30 #include "tensorflow/compiler/tf2xla/xla_context.h"
31 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
32 #include "tensorflow/compiler/xla/client/client_library.h"
33 #include "tensorflow/core/common_runtime/device.h"
34 #include "tensorflow/core/common_runtime/executor.h"
35 #include "tensorflow/core/common_runtime/function.h"
36 #include "tensorflow/core/common_runtime/graph_optimizer.h"
37 #include "tensorflow/core/framework/attr_value_util.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/graph_constructor.h"
40 #include "tensorflow/core/graph/node_builder.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/platform/logging.h"
43 #include "tensorflow/core/public/version.h"
44 
45 namespace tensorflow {
46 namespace {
47 
48 // Checks that arguments `args` match types `types`.
CheckSignature(const DataTypeVector & types,const std::vector<XlaCompiler::Argument> & args)49 Status CheckSignature(const DataTypeVector& types,
50                       const std::vector<XlaCompiler::Argument>& args) {
51   if (args.size() != types.size()) {
52     return errors::Internal("Compilation arguments have ", args.size(),
53                             " elements while function has ", types.size());
54   }
55   for (int i = 0; i < types.size(); ++i) {
56     if (types[i] != args[i].type && types[i] != DT_RESOURCE) {
57       return errors::Internal(
58           "Argument ", i, " has declared type ", DataTypeString(args[i].type),
59           " but function parameter has type ", DataTypeString(types[i]));
60     }
61   }
62   return Status::OK();
63 }
64 
65 }  // namespace
66 
operator ==(const XlaCompiler::Argument & other) const67 bool XlaCompiler::Argument::operator==(
68     const XlaCompiler::Argument& other) const {
69   if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size,
70                tensor_array_gradients) !=
71       std::tie(other.kind, other.resource_kind, other.type, other.name,
72                other.initialized, other.tensor_array_size,
73                other.tensor_array_gradients)) {
74     return false;
75   }
76   if (shape != other.shape) {
77     return false;
78   }
79   if (constant_value.shape() != other.constant_value.shape()) {
80     return false;
81   }
82   return constant_value.tensor_data() == other.constant_value.tensor_data();
83 }
84 
XlaCompiler(XlaCompiler::Options options)85 XlaCompiler::XlaCompiler(XlaCompiler::Options options)
86     : options_(options),
87       initialization_status_(Status::OK()),
88       next_step_id_(1),
89       device_(
90           new XlaCompilationDevice(SessionOptions(), *options_.device_type)),
91       device_mgr_({device_}) {
92   // We no longer need the device_type.
93   options_.device_type = nullptr;
94 
95   if (options_.populate_resource_manager) {
96     initialization_status_ =
97         (*options_.populate_resource_manager)(device_->resource_manager());
98   }
99 
100   local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
101                                                       FunctionDefLibrary{}));
102   local_pflr_.reset(new ProcessFunctionLibraryRuntime(
103       &device_mgr_, Env::Default(), options.graph_def_version,
104       local_flib_def_.get(), OptimizerOptions(),
105       nullptr /* custom_kernel_creator */));
106   pflr_.reset(new ProcessFunctionLibraryRuntime(
107       &device_mgr_, Env::Default(), options.graph_def_version, options.flib_def,
108       OptimizerOptions(), nullptr /* custom_kernel_creator */));
109 
110   local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
111   flib_runtime_ = pflr_->GetFLR(device_->name());
112 
113   // The default variable representation shape is the identity function.
114   if (!options_.variable_representation_shape_fn) {
115     options_.variable_representation_shape_fn =
__anon9239614f0202(const TensorShape& shape, DataType type) 116         [](const TensorShape& shape, DataType type) { return shape; };
117   }
118 }
119 
120 XlaCompiler::~XlaCompiler() = default;
121 
NextStepId()122 int64 XlaCompiler::NextStepId() { return next_step_id_++; }
123 
operator ()(const std::pair<string,std::vector<Argument>> & signature) const124 uint64 XlaCompiler::SignatureHash::operator()(
125     const std::pair<string, std::vector<Argument>>& signature) const {
126   return std::hash<string>()(signature.first);
127 }
128 
GetFunctionBody(const NameAttrList & function,FunctionLibraryRuntime * flib_runtime,const FunctionBody ** fbody)129 static Status GetFunctionBody(const NameAttrList& function,
130                               FunctionLibraryRuntime* flib_runtime,
131                               const FunctionBody** fbody) {
132   FunctionLibraryRuntime::Handle handle;
133   TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
134       function.name(), AttrSlice(&function.attr()), &handle));
135 
136   *fbody = flib_runtime->GetFunctionBody(handle);
137   TF_RET_CHECK(*fbody);
138   return Status::OK();
139 }
140 
FindFunctionBody(const NameAttrList & function,const FunctionBody ** fbody)141 Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
142                                      const FunctionBody** fbody) {
143   // The function may be in either the local_flib_runtime_ or flib_runtime_.
144   // Look up the function in local first and if it is not found then look up the
145   // function in flib_runtime_.
146   auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
147   if (!status.ok()) {
148     if (!errors::IsNotFound(status)) {
149       return status;
150     }
151     TF_RETURN_WITH_CONTEXT_IF_ERROR(
152         GetFunctionBody(function, flib_runtime_, fbody),
153         "Local lookup failed with: ", status.error_message());
154   }
155   return Status::OK();
156 }
157 
GetGraph(const FunctionBody * fbody)158 std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
159   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
160   CopyGraph(*fbody->graph, graph.get());
161   OptimizerOptions opts;
162   opts.set_opt_level(OptimizerOptions::L0);
163   opts.set_do_common_subexpression_elimination(false);
164   opts.set_do_function_inlining(true);
165   opts.set_do_constant_folding(true);
166   GraphOptimizer optimizer(opts);
167   optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
168                      /*device=*/nullptr, &graph, /*shape_map=*/nullptr);
169 
170   return graph;
171 }
172 
CompileFunction(const XlaCompiler::CompileOptions & options,const NameAttrList & function,std::vector<XlaCompiler::Argument> args,XlaCompiler::CompilationResult * result)173 Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
174                                     const NameAttrList& function,
175                                     std::vector<XlaCompiler::Argument> args,
176                                     XlaCompiler::CompilationResult* result) {
177   const string function_id =
178       Canonicalize(function.name(), AttrSlice(&function.attr()));
179   VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
180 
181   auto it = cache_.find({function_id, args});
182   if (it != cache_.end()) {
183     *result = it->second;
184     return Status::OK();
185   }
186 
187   const FunctionBody* fbody;
188   TF_RETURN_IF_ERROR(FindFunctionBody(function, &fbody));
189 
190   TF_RETURN_WITH_CONTEXT_IF_ERROR(
191       CheckSignature(fbody->arg_types, args),
192       "Signature check failure while compiling: ", function.name());
193 
194   std::unique_ptr<Graph> graph = GetGraph(fbody);
195 
196   // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
197   // they are added by the function body looked up.  Therefore, they don't have
198   // core assignments here.
199   // Attempt to assign a core to each _Retval and _Arg. Chooses the
200   // lowest-numbered core that consumes the argument. We choose the
201   // lowest-numbered core so the assignment is deterministic.
202   for (Node* n : graph->nodes()) {
203     if (StringPiece(n->type_string()) == "_Arg") {
204       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
205     }
206   }
207   // Do _Retval as a second loop, in case the retval's input is an _Arg (which
208   // may have gotten a device assignment from the first loop).
209   for (Node* n : graph->nodes()) {
210     if (StringPiece(n->type_string()) == "_Retval") {
211       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
212     }
213   }
214 
215   if (VLOG_IS_ON(2)) {
216     VLOG(2) << "XlaCompiler::CompileFunction: "
217             << dump_graph::DumpGraphToFile(
218                    strings::StrCat("xla_compile_function_", function_id),
219                    *graph);
220   }
221 
222   VLOG(1) << "====================================================";
223   TF_RETURN_IF_ERROR(
224       CompileGraph(options, function_id, std::move(graph), args, result));
225   VLOG(1) << "====================================================";
226 
227   cache_[{function_id, args}] = *result;
228   return Status::OK();
229 }
230 
231 // Computes the XLA shape for argument 'arg'.
XLAShapeForArgument(const XlaCompiler::Argument & arg,xla::Shape * xla_shape)232 Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
233                                         xla::Shape* xla_shape) {
234   switch (arg.kind) {
235     case XlaCompiler::Argument::kConstant:
236       return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
237                                    xla_shape);
238     case XlaCompiler::Argument::kParameter:
239       return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
240     case XlaCompiler::Argument::kResource: {
241       TF_RET_CHECK(arg.initialized);
242 
243       switch (arg.resource_kind) {
244         case XlaResource::kVariable: {
245           TensorShape representation_shape =
246               options_.variable_representation_shape_fn(arg.shape, arg.type);
247           return TensorShapeToXLAShape(arg.type, representation_shape,
248                                        xla_shape);
249         }
250         case XlaResource::kTensorArray: {
251           if (arg.tensor_array_size < 0) {
252             return errors::InvalidArgument(
253                 "Negative tensor_array_size in XLAShapeForArgument");
254           }
255           TensorShape shape;
256           shape.AddDim(arg.tensor_array_size);
257           shape.AppendShape(arg.shape);
258           TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
259 
260           if (!arg.tensor_array_gradients.empty()) {
261             std::vector<xla::Shape> tuple_shape(
262                 arg.tensor_array_gradients.size() + 1, *xla_shape);
263             *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
264           }
265           return Status::OK();
266         }
267         case XlaResource::kStack: {
268           if (arg.tensor_array_size < 0) {
269             return errors::InvalidArgument(
270                 "Negative tensor_array_size in XLAShapeForArgument");
271           }
272           TensorShape shape;
273           shape.AddDim(arg.tensor_array_size);
274           shape.AppendShape(arg.shape);
275           xla::Shape buffer_shape;
276           TF_RETURN_IF_ERROR(
277               TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
278           *xla_shape = xla::ShapeUtil::MakeTupleShape(
279               {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
280           return Status::OK();
281         }
282 
283         case XlaResource::kInvalid:
284           return errors::Internal(
285               "Invalid resource type in XLAShapeForArgument()");
286       }
287     }
288     case XlaCompiler::Argument::kInvalid:
289       return errors::Internal("Invalid argument type in XLAShapeForArgument()");
290   }
291 }
292 
293 namespace {
294 
ExecuteGraph(XlaContext * xla_context,std::unique_ptr<Graph> graph,XlaCompilationDevice * device,FunctionLibraryRuntime * flib,int64 step_id)295 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
296                     XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
297                     int64 step_id) {
298   // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the
299   // resource manager takes ownership via Create, and unrefs via Cleanup.  We
300   // explicitly add a reference to ensure the refcount at entry is maintained at
301   // all exit points; Create and Cleanup are always called in this function.
302   //
303   // The Executor requires us to use ScopedStepContainer. We wrap it in a
304   // unique_ptr so we can capture the cleanup status in the end.
305   xla_context->Ref();
306   Status status;
307   auto step_container = xla::MakeUnique<ScopedStepContainer>(
308       step_id, [&status, device](const string& name) {
309         status = device->resource_manager()->Cleanup(name);
310       });
311   TF_RETURN_IF_ERROR(device->resource_manager()->Create(
312       step_container->name(), XlaContext::kXlaContextResourceName,
313       xla_context));
314 
315   GraphCompiler graph_compiler(xla_context, device, graph.get(), flib,
316                                step_container.get());
317   TF_RETURN_IF_ERROR(graph_compiler.Compile());
318   // Explicitly clean up the step container, to capture the cleanup status.
319   step_container.reset();
320   return Status::OK();
321 }
322 
323 // Builds the XLA computation.
324 //
325 // `retvals` is the list of retvals produced by _Retval operators, in index
326 // order. `variable_map` is a map from variable ID numbers to XlaOpContext
327 // variable states, generated by the symbolic evaluation.
328 // If `return_updated_values_for_all_resources` is true, all resources will be
329 // included in `resource_updates`, regardless of whether their value changed.
330 // Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
331 // Sets `*resource_updates` to a description of resources whose values are
332 // written by the computation; the variable writes are the last
333 // `resource_updates.size()` return values from the computation. Each entry in
334 // `resource_updates` is a (input_index, type) pair, where `input_index` is the
335 // index of a resource variable argument to the computation, and `type` is the
336 // type of the final output.
BuildComputation(const std::vector<XlaCompiler::Argument> & args,const std::vector<int> & arg_cores,const std::vector<XlaExpression> & retvals,const std::vector<std::unique_ptr<XlaResource>> & resources,bool return_updated_values_for_all_resources,xla::ComputationBuilder * builder,xla::Computation * computation,int * num_computation_outputs,int * num_nonconst_outputs,std::vector<XlaCompiler::ResourceUpdate> * resource_updates)337 Status BuildComputation(
338     const std::vector<XlaCompiler::Argument>& args,
339     const std::vector<int>& arg_cores,
340     const std::vector<XlaExpression>& retvals,
341     const std::vector<std::unique_ptr<XlaResource>>& resources,
342     bool return_updated_values_for_all_resources,
343     xla::ComputationBuilder* builder, xla::Computation* computation,
344     int* num_computation_outputs, int* num_nonconst_outputs,
345     std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
346   std::vector<xla::ComputationDataHandle> elems;
347   elems.reserve(retvals.size());
348   for (const XlaExpression& retval : retvals) {
349     if (!retval.has_constant_value()) {
350       elems.push_back(retval.handle());
351     }
352   }
353   *num_nonconst_outputs = elems.size();
354 
355   // Add return values for resources whose values have changed.
356   std::vector<const XlaResource*> arg_resources;
357   arg_resources.reserve(resources.size());
358   for (const auto& resource : resources) {
359     if (resource->arg_num() >= 0) {
360       arg_resources.push_back(resource.get());
361     }
362   }
363   std::sort(arg_resources.begin(), arg_resources.end(),
364             [](const XlaResource* a, const XlaResource* b) {
365               return a->arg_num() < b->arg_num();
366             });
367 
368   for (const XlaResource* resource : arg_resources) {
369     const XlaCompiler::Argument& arg = args[resource->arg_num()];
370     const int core = arg_cores[resource->arg_num()];
371     DCHECK_LT(resource->arg_num(), arg_cores.size());
372     bool modified =
373         resource->value().handle() != resource->initial_value().handle();
374     // TensorArray gradients were modified if their values changed or there are
375     // any newly created gradients.
376     for (const auto& grad : resource->tensor_array_gradients()) {
377       modified = modified ||
378                  grad.second->value().handle() !=
379                      grad.second->initial_value().handle() ||
380                  arg.tensor_array_gradients.count(grad.first) == 0;
381     }
382     if (return_updated_values_for_all_resources || modified) {
383       resource_updates->emplace_back();
384       XlaCompiler::ResourceUpdate& update = resource_updates->back();
385       update.input_index = resource->arg_num();
386       update.type = resource->type();
387       update.shape = resource->shape();
388       update.modified = modified;
389       for (const auto& grad : resource->tensor_array_gradients()) {
390         update.tensor_array_gradients_accessed.insert(grad.first);
391       }
392 
393       // Request that the value be returned on a specific core.
394       xla::ScopedShardingAssignment assign_sharding(
395           builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
396                               : xla::sharding_builder::AssignDevice(core));
397 
398       xla::ComputationDataHandle handle;
399       TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
400 
401       // Since we can't change the sharding metadata of <value> as this point,
402       // create a tuple/get-tuple-element combination so that sharding
403       // assignment will be placed on this value, which will cause the resource
404       // update to be returned from the same device that provided the resource.
405       handle = builder->GetTupleElement(builder->Tuple({handle}), 0);
406 
407       elems.push_back(handle);
408     }
409   }
410 
411   *num_computation_outputs = elems.size();
412 
413   // Builds the XLA computation.
414   builder->Tuple(elems);
415   xla::StatusOr<xla::Computation> computation_status = builder->Build();
416   if (!computation_status.ok()) {
417     return computation_status.status();
418   }
419   *computation = computation_status.ConsumeValueOrDie();
420   return Status::OK();
421 }
422 
423 }  // namespace
424 
425 // Builds XLA computations for each of the arguments to the computation.
426 // `args` are the arguments to the computation.
BuildArguments(const Graph & graph,const std::vector<XlaCompiler::Argument> & args,bool use_tuple_arg,xla::ComputationBuilder * builder,XlaContext * context,std::vector<int> * arg_cores,std::vector<XlaExpression> * arg_expressions,std::vector<int> * input_mapping,std::vector<xla::Shape> * input_shapes,bool is_entry_computation)427 Status XlaCompiler::BuildArguments(
428     const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
429     bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context,
430     std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions,
431     std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes,
432     bool is_entry_computation) {
433   arg_expressions->resize(args.size());
434   *arg_cores = std::vector<int>(args.size(), -1);
435 
436   // Argument numbers of arguments and resources that are to be passed to the
437   // XLA computation as runtime parameters.
438   input_mapping->clear();
439   input_mapping->reserve(args.size());
440   std::vector<int> resources;
441   resources.reserve(args.size());
442 
443   // Fills in constant arguments, and computes non-constant argument order.
444   for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
445        ++i) {
446     const XlaCompiler::Argument& arg = args[i];
447     XlaExpression& arg_expression = (*arg_expressions)[i];
448     switch (arg.kind) {
449       case XlaCompiler::Argument::kResource:
450         TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
451         // TODO(phawkins): this code assumes that resource arguments do not
452         // alias.
453         XlaResource* resource;
454         TF_RETURN_IF_ERROR(context->CreateResource(
455             arg.resource_kind, i, arg.name, arg.type, arg.shape,
456             xla::ComputationDataHandle(),
457             /*tensor_array_size=*/arg.tensor_array_size,
458             /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
459         arg_expression.set_resource(resource);
460         if (arg.initialized) {
461           resources.push_back(i);
462         }
463         break;
464       case XlaCompiler::Argument::kParameter: {
465         input_mapping->push_back(i);
466         break;
467       }
468       case XlaCompiler::Argument::kConstant:
469         arg_expression.set_constant_value(arg.constant_value);
470         break;
471       case XlaCompiler::Argument::kInvalid:
472         return errors::Internal("Unreachable case in BuildArguments()");
473     }
474   }
475 
476   // Append parameters containing variable values after the other runtime
477   // parameters.
478   input_mapping->insert(input_mapping->end(), resources.begin(),
479                         resources.end());
480   if (input_mapping->empty()) {
481     return Status::OK();
482   }
483 
484   std::vector<xla::Shape> arg_shapes(input_mapping->size());
485   for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
486     // Computes the shapes of non-constant arguments.
487     TF_RETURN_IF_ERROR(
488         XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i]));
489   }
490 
491   if (use_tuple_arg) {
492     input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
493   } else {
494     *input_shapes = arg_shapes;
495   }
496 
497   // Use the _Arg nodes in the graph to resolve core assignments.
498   for (const Node* n : graph.nodes()) {
499     if (StringPiece(n->type_string()) != "_Arg") continue;
500     int index;
501     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
502     TF_RET_CHECK(index >= 0 && index < args.size())
503         << "_Arg out of bounds: " << index << " vs " << args.size();
504     TF_ASSIGN_OR_RETURN(
505         auto sharding,
506         ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
507     if (sharding.has_value()) {
508       TF_RET_CHECK(sharding.value().type() ==
509                    xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
510       const int core = sharding.value().tile_assignment_devices(0);
511       if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) {
512         (*arg_cores)[index] = core;
513       }
514     }
515   }
516 
517   // Build parameter handles for non-constant arguments.
518   std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
519   if (use_tuple_arg) {
520     xla::ComputationDataHandle tuple;
521     if (is_entry_computation) {
522       xla::OpSharding tuple_sharding;
523       tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
524       for (int64 parameter : *input_mapping) {
525         const int core = (*arg_cores)[parameter];
526         const int root_device = 0;
527         *tuple_sharding.add_tuple_shardings() =
528             core == -1 ? xla::sharding_builder::AssignDevice(root_device)
529                        : xla::sharding_builder::AssignDevice(core);
530       }
531       xla::ScopedShardingAssignment assign_tuple_sharding(builder,
532                                                           tuple_sharding);
533       tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
534     } else {
535       tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
536     }
537     for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
538       const int core = (*arg_cores)[input_mapping->at(i)];
539       xla::ScopedShardingAssignment assign_sharding(
540           builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
541                               : xla::sharding_builder::AssignDevice(core));
542       arg_handles[i] = builder->GetTupleElement(tuple, i);
543     }
544   } else {
545     for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
546       const int core = (*arg_cores)[input_mapping->at(i)];
547       xla::ScopedShardingAssignment assign_sharding(
548           builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
549                               : xla::sharding_builder::AssignDevice(core));
550       arg_handles[i] =
551           builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i));
552     }
553   }
554 
555   // Fill in the handles in non-constant arguments.
556   VLOG(2) << "XLA computation inputs:";
557   for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
558     const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
559     VLOG(2) << "  XLA arg " << i
560             << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
561             << " name: " << arg.name << " TF arg " << input_mapping->at(i);
562     XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)];
563     switch (arg.kind) {
564       case XlaCompiler::Argument::kResource: {
565         TF_RET_CHECK(arg.initialized);
566         XlaResource* resource = arg_expression.resource();
567         TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
568                                                  arg_handles[i], builder));
569         VLOG(2) << "    resource: num_gradients: "
570                 << arg.tensor_array_gradients.size();
571         break;
572       }
573       case XlaCompiler::Argument::kParameter:
574         arg_expression.set_handle(arg_handles[i]);
575         break;
576       case XlaCompiler::Argument::kConstant:
577       case XlaCompiler::Argument::kInvalid:
578         return errors::Internal("Unreachable case in BuildArguments()");
579     }
580   }
581 
582   return Status::OK();
583 }
584 
CompileGraph(const XlaCompiler::CompileOptions & options,string const & name,std::unique_ptr<Graph> graph,const std::vector<XlaCompiler::Argument> & args,CompilationResult * result)585 Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
586                                  string const& name,
587                                  std::unique_ptr<Graph> graph,
588                                  const std::vector<XlaCompiler::Argument>& args,
589                                  CompilationResult* result) {
590   VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
591 
592   if (VLOG_IS_ON(2)) {
593     VLOG(2) << "XlaCompiler::CompileGraph: "
594             << dump_graph::DumpGraphToFile(
595                    strings::StrCat("xla_compile_graph_", name), *graph);
596   }
597 
598   // Report the error here if initialization failed.
599   TF_RETURN_IF_ERROR(initialization_status_);
600 
601   // Converts Tensorflow's graph control-flow constructs into functional
602   // control-flow that can be compiled into XLA code.
603   TF_RETURN_IF_ERROR(
604       FunctionalizeControlFlow(graph.get(), local_flib_def_.get()));
605 
606   xla::ComputationBuilder builder(client(), name);
607   XlaContext* context =
608       new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
609                      options.resolve_compile_time_constants,
610                      &options_.variable_representation_shape_fn);
611   core::ScopedUnref context_unref(context);
612 
613   std::vector<XlaExpression> arg_expressions;
614   std::vector<int> arg_cores;
615   TF_RETURN_IF_ERROR(
616       BuildArguments(*graph, args, options.use_tuple_arg, &builder, context,
617                      &arg_cores, &arg_expressions, &result->input_mapping,
618                      &result->xla_input_shapes, options.is_entry_computation));
619   context->set_args(std::move(arg_expressions));
620 
621   TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
622                                   flib_runtime_, NextStepId()));
623 
624   int num_nonconst_outputs;
625   int num_computation_outputs;
626   result->computation = std::make_shared<xla::Computation>();
627   TF_RETURN_IF_ERROR(BuildComputation(
628       args, arg_cores, context->retvals(), context->resources(),
629       options.return_updated_values_for_all_resources, &builder,
630       result->computation.get(), &num_computation_outputs,
631       &num_nonconst_outputs, &result->resource_updates));
632 
633   VLOG(2) << "Outputs: total: " << context->retvals().size()
634           << " nonconstant: " << num_nonconst_outputs;
635   result->outputs.resize(context->retvals().size());
636   for (std::vector<XlaExpression>::size_type i = 0;
637        i < context->retvals().size(); ++i) {
638     const XlaExpression& retval = context->retvals()[i];
639     if (retval.has_constant_value()) {
640       OutputDescription& output = result->outputs[i];
641       output.shape = retval.constant_value().shape();
642       output.is_constant = true;
643       output.constant_value = retval.constant_value();
644     }
645   }
646 
647   // Compute the output shapes, if there is a computation with non-constant
648   // outputs.
649   auto computation_shape = client()->GetComputationShape(*result->computation);
650   if (!computation_shape.ok()) {
651     return computation_shape.status();
652   }
653 
654   result->xla_output_shape.Swap(
655       computation_shape.ValueOrDie()->mutable_result());
656   VLOG(2) << "XLA output shape: "
657           << xla::ShapeUtil::HumanString(result->xla_output_shape);
658 
659   // Tensorflow expects a major-to-minor order of results.
660   xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
661 
662   // Converts the output shapes to TensorShapes.
663   int computation_output = 0;
664   for (std::vector<XlaExpression>::size_type i = 0;
665        i < context->retvals().size(); ++i) {
666     const XlaExpression& retval = context->retvals()[i];
667     if (!retval.has_constant_value()) {
668       TF_RET_CHECK(computation_output < num_computation_outputs)
669           << "Computation has more outputs than expected";
670       OutputDescription& output = result->outputs[i];
671       output.is_constant = false;
672       TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
673           xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
674                                                computation_output),
675           &output.shape));
676       ++computation_output;
677     }
678   }
679   return Status::OK();
680 }
681 
GetChannelHandle(const string & key,xla::ChannelHandle * channel)682 Status XlaCompiler::GetChannelHandle(const string& key,
683                                      xla::ChannelHandle* channel) {
684   auto result = channels_.emplace(key, xla::ChannelHandle());
685   if (result.second) {
686     TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
687   }
688   *channel = result.first->second;
689   VLOG(1) << "Channel: " << key << " " << channel->DebugString();
690   return Status::OK();
691 }
692 
693 }  // namespace tensorflow
694