/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_compiler.h"

#include <deque>
#include <numeric>

#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {
namespace {

// Checks that arguments `args` match types `types`.
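// A function parameter of type DT_RESOURCE is allowed to mismatch, since a
// resource argument carries its element type on the Argument rather than in
// the function signature.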
Status CheckSignature(const DataTypeVector& types,
                      const std::vector<XlaCompiler::Argument>& args) {
  if (args.size() != types.size()) {
    return errors::Internal("Compilation arguments have ", args.size(),
                            " elements while function has ", types.size());
  }
  for (int i = 0; i < types.size(); ++i) {
    if (types[i] != args[i].type && types[i] != DT_RESOURCE) {
      return errors::Internal(
          "Argument ", i, " has declared type ", DataTypeString(args[i].type),
          " but function parameter has type ", DataTypeString(types[i]));
    }
  }
  return Status::OK();
}

}  // namespace

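// Arguments compare equal when all scalar fields, the shape, and the raw
// bytes of any constant value match. Comparing tensor_data() directly is safe
// here because the constant types and shapes (and hence byte counts) were
// already compared above.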
bool XlaCompiler::Argument::operator==(
    const XlaCompiler::Argument& other) const {
  if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size,
               tensor_array_gradients) !=
      std::tie(other.kind, other.resource_kind, other.type, other.name,
               other.initialized, other.tensor_array_size,
               other.tensor_array_gradients)) {
    return false;
  }
  if (shape != other.shape) {
    return false;
  }
  if (constant_value.shape() != other.constant_value.shape()) {
    return false;
  }
  return constant_value.tensor_data() == other.constant_value.tensor_data();
}

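// The compiler maintains two function library runtimes: local_flib_runtime_
// wraps a fresh, initially empty FunctionDefLibrary for functions the
// compiler itself creates during lowering, while flib_runtime_ wraps the
// caller-provided options.flib_def. Lookups try the local library first (see
// FindFunctionBody below).
//
// A minimal construction sketch; `client`, `device_type`, and `flib_def` are
// hypothetical names for values the caller already has:
//
//   XlaCompiler::Options options;
//   options.client = client;            // an xla::Client*
//   options.device_type = &device_type; // e.g. the XLA_CPU DeviceType
//   options.flib_def = flib_def;        // the graph's function library
//   XlaCompiler compiler(options);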
XlaCompiler::XlaCompiler(XlaCompiler::Options options)
    : options_(options),
      initialization_status_(Status::OK()),
      next_step_id_(1),
      device_(
          new XlaCompilationDevice(SessionOptions(), *options_.device_type)),
      device_mgr_({device_}) {
  // We no longer need the device_type.
  options_.device_type = nullptr;

  if (options_.populate_resource_manager) {
    initialization_status_ =
        (*options_.populate_resource_manager)(device_->resource_manager());
  }

  local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
                                                      FunctionDefLibrary{}));
  local_pflr_.reset(new ProcessFunctionLibraryRuntime(
      &device_mgr_, Env::Default(), options.graph_def_version,
      local_flib_def_.get(), OptimizerOptions(),
      nullptr /* custom_kernel_creator */));
  pflr_.reset(new ProcessFunctionLibraryRuntime(
      &device_mgr_, Env::Default(), options.graph_def_version, options.flib_def,
      OptimizerOptions(), nullptr /* custom_kernel_creator */));

  local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
  flib_runtime_ = pflr_->GetFLR(device_->name());

  // The default variable representation shape is the identity function.
  if (!options_.variable_representation_shape_fn) {
    options_.variable_representation_shape_fn =
        [](const TensorShape& shape, DataType type) { return shape; };
  }
}

XlaCompiler::~XlaCompiler() = default;

int64 XlaCompiler::NextStepId() { return next_step_id_++; }

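// Hashes only the function name; signatures that share a name but differ in
// their arguments collide here and are disambiguated by the cache's full
// (name, arguments) key equality.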
uint64 XlaCompiler::SignatureHash::operator()(
    const std::pair<string, std::vector<Argument>>& signature) const {
  return std::hash<string>()(signature.first);
}

static Status GetFunctionBody(const NameAttrList& function,
                              FunctionLibraryRuntime* flib_runtime,
                              const FunctionBody** fbody) {
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
      function.name(), AttrSlice(&function.attr()), &handle));

  *fbody = flib_runtime->GetFunctionBody(handle);
  TF_RET_CHECK(*fbody);
  return Status::OK();
}

Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
                                     const FunctionBody** fbody) {
  // The function may be in either local_flib_runtime_ or flib_runtime_. Look
  // it up in local_flib_runtime_ first; if it is not found there, fall back
  // to flib_runtime_.
  auto status = GetFunctionBody(function, local_flib_runtime_, fbody);
  if (!status.ok()) {
    if (!errors::IsNotFound(status)) {
      return status;
    }
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        GetFunctionBody(function, flib_runtime_, fbody),
        "Local lookup failed with: ", status.error_message());
  }
  return Status::OK();
}
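
// Clones the function body into a fresh graph and runs a conservative
// optimization pass over it: function inlining and constant folding are
// enabled, while common-subexpression elimination is left off (presumably to
// keep the lowered graph's structure predictable for the later passes).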
std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
  std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
  CopyGraph(*fbody->graph, graph.get());
  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_common_subexpression_elimination(false);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  GraphOptimizer optimizer(opts);
  optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
                     /*device=*/nullptr, &graph, /*shape_map=*/nullptr);

  return graph;
}
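
// Compiles a named TensorFlow function to an XLA computation, memoizing the
// result in cache_ keyed on the canonicalized function name and the argument
// signature.
//
// A minimal call sketch; the function name and argument values are
// hypothetical:
//
//   NameAttrList fn;
//   fn.set_name("my_function");
//   std::vector<XlaCompiler::Argument> args(1);
//   args[0].kind = XlaCompiler::Argument::kParameter;
//   args[0].type = DT_FLOAT;
//   args[0].shape = TensorShape({2, 3});
//   XlaCompiler::CompilationResult result;
//   TF_RETURN_IF_ERROR(compiler.CompileFunction(
//       XlaCompiler::CompileOptions(), fn, args, &result));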
Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
                                    const NameAttrList& function,
                                    std::vector<XlaCompiler::Argument> args,
                                    XlaCompiler::CompilationResult* result) {
  const string function_id =
      Canonicalize(function.name(), AttrSlice(&function.attr()));
  VLOG(1) << "XlaCompiler::CompileFunction " << function_id;

  auto it = cache_.find({function_id, args});
  if (it != cache_.end()) {
    *result = it->second;
    return Status::OK();
  }

  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(FindFunctionBody(function, &fbody));

  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      CheckSignature(fbody->arg_types, args),
      "Signature check failure while compiling: ", function.name());

  std::unique_ptr<Graph> graph = GetGraph(fbody);

  // _Arg and _Retval nodes don't exist in the stored subgraph for the
  // function; they are added when the function body is instantiated.
  // Therefore, they don't have core assignments here.
  // Attempt to assign a core to each _Retval and _Arg. We choose the
  // lowest-numbered core that consumes the argument so that the assignment
  // is deterministic.
  for (Node* n : graph->nodes()) {
    if (StringPiece(n->type_string()) == "_Arg") {
      TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
    }
  }
  // Do _Retval as a second loop, in case the retval's input is an _Arg (which
  // may have gotten a device assignment from the first loop).
  for (Node* n : graph->nodes()) {
    if (StringPiece(n->type_string()) == "_Retval") {
      TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "XlaCompiler::CompileFunction: "
            << dump_graph::DumpGraphToFile(
                   strings::StrCat("xla_compile_function_", function_id),
                   *graph);
  }

  VLOG(1) << "====================================================";
  TF_RETURN_IF_ERROR(
      CompileGraph(options, function_id, std::move(graph), args, result));
  VLOG(1) << "====================================================";

  cache_[{function_id, args}] = *result;
  return Status::OK();
}

// Computes the XLA shape for argument 'arg'.
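// For a constant, this is the shape of the constant value; for a parameter,
// the declared shape. For resources the shape depends on the resource kind:
// variables pass through variable_representation_shape_fn, tensor arrays
// prepend the array size (and wrap in a tuple when gradients are present),
// and stacks are a (buffer, S32 index) tuple.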
Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
                                        xla::Shape* xla_shape) {
  switch (arg.kind) {
    case XlaCompiler::Argument::kConstant:
      return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
                                   xla_shape);
    case XlaCompiler::Argument::kParameter:
      return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
    case XlaCompiler::Argument::kResource: {
      TF_RET_CHECK(arg.initialized);

      switch (arg.resource_kind) {
        case XlaResource::kVariable: {
          TensorShape representation_shape =
              options_.variable_representation_shape_fn(arg.shape, arg.type);
          return TensorShapeToXLAShape(arg.type, representation_shape,
                                       xla_shape);
        }
        case XlaResource::kTensorArray: {
          if (arg.tensor_array_size < 0) {
            return errors::InvalidArgument(
                "Negative tensor_array_size in XLAShapeForArgument");
          }
          TensorShape shape;
          shape.AddDim(arg.tensor_array_size);
          shape.AppendShape(arg.shape);
          TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));

          if (!arg.tensor_array_gradients.empty()) {
            std::vector<xla::Shape> tuple_shape(
                arg.tensor_array_gradients.size() + 1, *xla_shape);
            *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
          }
          return Status::OK();
        }
        case XlaResource::kStack: {
          if (arg.tensor_array_size < 0) {
            return errors::InvalidArgument(
                "Negative tensor_array_size in XLAShapeForArgument");
          }
          TensorShape shape;
          shape.AddDim(arg.tensor_array_size);
          shape.AppendShape(arg.shape);
          xla::Shape buffer_shape;
          TF_RETURN_IF_ERROR(
              TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
          *xla_shape = xla::ShapeUtil::MakeTupleShape(
              {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
          return Status::OK();
        }

        case XlaResource::kInvalid:
          return errors::Internal(
              "Invalid resource type in XLAShapeForArgument()");
      }
    }
    case XlaCompiler::Argument::kInvalid:
      return errors::Internal("Invalid argument type in XLAShapeForArgument()");
  }
}

namespace {

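// Symbolically executes `graph` on the XLA compilation device, recording the
// resulting XLA operations into `xla_context`.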
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
                    XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
                    int64 step_id) {
  // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
  // resource manager takes ownership via Create, and unrefs via Cleanup. We
  // explicitly add a reference to ensure the refcount at entry is maintained
  // at all exit points; Create and Cleanup are always called in this function.
  //
  // The Executor requires us to use ScopedStepContainer. We wrap it in a
  // unique_ptr so we can capture the cleanup status at the end.
  xla_context->Ref();
  Status status;
  auto step_container = xla::MakeUnique<ScopedStepContainer>(
      step_id, [&status, device](const string& name) {
        status = device->resource_manager()->Cleanup(name);
      });
  TF_RETURN_IF_ERROR(device->resource_manager()->Create(
      step_container->name(), XlaContext::kXlaContextResourceName,
      xla_context));

  GraphCompiler graph_compiler(xla_context, device, graph.get(), flib,
                               step_container.get());
  TF_RETURN_IF_ERROR(graph_compiler.Compile());
  // Explicitly clean up the step container, to capture the cleanup status.
  step_container.reset();
  return status;
}

// Builds the XLA computation.
//
// `retvals` is the list of retvals produced by _Retval operators, in index
// order. `resources` is the list of resources used by the computation,
// generated by the symbolic evaluation.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
// Sets `*resource_updates` to a description of resources whose values are
// written by the computation; the variable writes are the last
// `resource_updates.size()` return values from the computation. Each entry in
// `resource_updates` is a (input_index, type) pair, where `input_index` is the
// index of a resource variable argument to the computation, and `type` is the
// type of the final output.
Status BuildComputation(
    const std::vector<XlaCompiler::Argument>& args,
    const std::vector<int>& arg_cores,
    const std::vector<XlaExpression>& retvals,
    const std::vector<std::unique_ptr<XlaResource>>& resources,
    bool return_updated_values_for_all_resources,
    xla::ComputationBuilder* builder, xla::Computation* computation,
    int* num_computation_outputs, int* num_nonconst_outputs,
    std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
  std::vector<xla::ComputationDataHandle> elems;
  elems.reserve(retvals.size());
  for (const XlaExpression& retval : retvals) {
    if (!retval.has_constant_value()) {
      elems.push_back(retval.handle());
    }
  }
  *num_nonconst_outputs = elems.size();

  // Add return values for resources whose values have changed.
  std::vector<const XlaResource*> arg_resources;
  arg_resources.reserve(resources.size());
  for (const auto& resource : resources) {
    if (resource->arg_num() >= 0) {
      arg_resources.push_back(resource.get());
    }
  }
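  // Sort resources by argument number so that the trailing resource-update
  // outputs appear in a deterministic order.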
  std::sort(arg_resources.begin(), arg_resources.end(),
            [](const XlaResource* a, const XlaResource* b) {
              return a->arg_num() < b->arg_num();
            });

  for (const XlaResource* resource : arg_resources) {
    DCHECK_LT(resource->arg_num(), arg_cores.size());
    const XlaCompiler::Argument& arg = args[resource->arg_num()];
    const int core = arg_cores[resource->arg_num()];
    bool modified =
        resource->value().handle() != resource->initial_value().handle();
    // TensorArray gradients were modified if their values changed or there are
    // any newly created gradients.
    for (const auto& grad : resource->tensor_array_gradients()) {
      modified = modified ||
                 grad.second->value().handle() !=
                     grad.second->initial_value().handle() ||
                 arg.tensor_array_gradients.count(grad.first) == 0;
    }
    if (return_updated_values_for_all_resources || modified) {
      resource_updates->emplace_back();
      XlaCompiler::ResourceUpdate& update = resource_updates->back();
      update.input_index = resource->arg_num();
      update.type = resource->type();
      update.shape = resource->shape();
      update.modified = modified;
      for (const auto& grad : resource->tensor_array_gradients()) {
        update.tensor_array_gradients_accessed.insert(grad.first);
      }

      // Request that the value be returned on a specific core.
      xla::ScopedShardingAssignment assign_sharding(
          builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
                              : xla::sharding_builder::AssignDevice(core));

      xla::ComputationDataHandle handle;
      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));

      // Since we can't change the sharding metadata of <value> at this point,
      // create a tuple/get-tuple-element combination so that sharding
      // assignment will be placed on this value, which will cause the resource
      // update to be returned from the same device that provided the resource.
      handle = builder->GetTupleElement(builder->Tuple({handle}), 0);

      elems.push_back(handle);
    }
  }

  *num_computation_outputs = elems.size();

  // Builds the XLA computation.
  builder->Tuple(elems);
  xla::StatusOr<xla::Computation> computation_status = builder->Build();
  if (!computation_status.ok()) {
    return computation_status.status();
  }
  *computation = computation_status.ConsumeValueOrDie();
  return Status::OK();
}

}  // namespace

// Builds XLA computations for each of the arguments to the computation.
// `args` are the arguments to the computation.
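// On success, populates the out-parameters as follows: `arg_expressions`
// holds one expression per argument (a constant, a resource, or a parameter
// handle); `input_mapping` maps each XLA runtime parameter position to the
// index of the corresponding argument in `args`; `input_shapes` holds the
// XLA shape of each runtime parameter (a single tuple shape if
// `use_tuple_arg` is set); and `arg_cores` records the core each argument was
// assigned from _Arg sharding metadata, or -1 if unassigned.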
Status XlaCompiler::BuildArguments(
    const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
    bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context,
    std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions,
    std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes,
    bool is_entry_computation) {
  arg_expressions->resize(args.size());
  *arg_cores = std::vector<int>(args.size(), -1);

  // Argument numbers of arguments and resources that are to be passed to the
  // XLA computation as runtime parameters.
  input_mapping->clear();
  input_mapping->reserve(args.size());
  std::vector<int> resources;
  resources.reserve(args.size());

  // Fills in constant arguments, and computes non-constant argument order.
  for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
       ++i) {
    const XlaCompiler::Argument& arg = args[i];
    XlaExpression& arg_expression = (*arg_expressions)[i];
    switch (arg.kind) {
      case XlaCompiler::Argument::kResource:
        TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
        // TODO(phawkins): this code assumes that resource arguments do not
        // alias.
        XlaResource* resource;
        TF_RETURN_IF_ERROR(context->CreateResource(
            arg.resource_kind, i, arg.name, arg.type, arg.shape,
            xla::ComputationDataHandle(),
            /*tensor_array_size=*/arg.tensor_array_size,
            /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
        arg_expression.set_resource(resource);
        if (arg.initialized) {
          resources.push_back(i);
        }
        break;
      case XlaCompiler::Argument::kParameter: {
        input_mapping->push_back(i);
        break;
      }
      case XlaCompiler::Argument::kConstant:
        arg_expression.set_constant_value(arg.constant_value);
        break;
      case XlaCompiler::Argument::kInvalid:
        return errors::Internal("Unreachable case in BuildArguments()");
    }
  }

  // Append parameters containing variable values after the other runtime
  // parameters.
  input_mapping->insert(input_mapping->end(), resources.begin(),
                        resources.end());
  if (input_mapping->empty()) {
    return Status::OK();
  }

  std::vector<xla::Shape> arg_shapes(input_mapping->size());
  for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
    // Computes the shapes of non-constant arguments.
    TF_RETURN_IF_ERROR(
        XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i]));
  }

  if (use_tuple_arg) {
    input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
  } else {
    *input_shapes = arg_shapes;
  }

  // Use the _Arg nodes in the graph to resolve core assignments.
  for (const Node* n : graph.nodes()) {
    if (StringPiece(n->type_string()) != "_Arg") continue;
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    TF_RET_CHECK(index >= 0 && index < args.size())
        << "_Arg out of bounds: " << index << " vs " << args.size();
    TF_ASSIGN_OR_RETURN(
        auto sharding,
        ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
    if (sharding.has_value()) {
      TF_RET_CHECK(sharding.value().type() ==
                   xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
      const int core = sharding.value().tile_assignment_devices(0);
      if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) {
        (*arg_cores)[index] = core;
      }
    }
  }

  // Build parameter handles for non-constant arguments.
  std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
  if (use_tuple_arg) {
    xla::ComputationDataHandle tuple;
    if (is_entry_computation) {
      xla::OpSharding tuple_sharding;
      tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
      for (int64 parameter : *input_mapping) {
        const int core = (*arg_cores)[parameter];
        const int root_device = 0;
        *tuple_sharding.add_tuple_shardings() =
            core == -1 ? xla::sharding_builder::AssignDevice(root_device)
                       : xla::sharding_builder::AssignDevice(core);
      }
      xla::ScopedShardingAssignment assign_tuple_sharding(builder,
                                                          tuple_sharding);
      tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
    } else {
      tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
    }
    for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
      const int core = (*arg_cores)[input_mapping->at(i)];
      xla::ScopedShardingAssignment assign_sharding(
          builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
                              : xla::sharding_builder::AssignDevice(core));
      arg_handles[i] = builder->GetTupleElement(tuple, i);
    }
  } else {
    for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
      const int core = (*arg_cores)[input_mapping->at(i)];
      xla::ScopedShardingAssignment assign_sharding(
          builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
                              : xla::sharding_builder::AssignDevice(core));
      arg_handles[i] =
          builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i));
    }
  }

  // Fill in the handles in non-constant arguments.
  VLOG(2) << "XLA computation inputs:";
  for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
    const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
    VLOG(2) << " XLA arg " << i
            << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
            << " name: " << arg.name << " TF arg " << input_mapping->at(i);
    XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)];
    switch (arg.kind) {
      case XlaCompiler::Argument::kResource: {
        TF_RET_CHECK(arg.initialized);
        XlaResource* resource = arg_expression.resource();
        TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
                                                 arg_handles[i], builder));
        VLOG(2) << " resource: num_gradients: "
                << arg.tensor_array_gradients.size();
        break;
      }
      case XlaCompiler::Argument::kParameter:
        arg_expression.set_handle(arg_handles[i]);
        break;
      case XlaCompiler::Argument::kConstant:
      case XlaCompiler::Argument::kInvalid:
        return errors::Internal("Unreachable case in BuildArguments()");
    }
  }

  return Status::OK();
}

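// Compiles an entire graph to an XLA computation. The pipeline is:
// functionalize TensorFlow control flow, build the XLA parameters from `args`
// (BuildArguments), symbolically execute the graph to populate the builder
// (ExecuteGraph), emit the computation and resource updates
// (BuildComputation), and finally derive the TF output shapes from the
// computation's result shape.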
Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
                                 string const& name,
                                 std::unique_ptr<Graph> graph,
                                 const std::vector<XlaCompiler::Argument>& args,
                                 CompilationResult* result) {
  VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "XlaCompiler::CompileGraph: "
            << dump_graph::DumpGraphToFile(
                   strings::StrCat("xla_compile_graph_", name), *graph);
  }

  // Report the error here if initialization failed.
  TF_RETURN_IF_ERROR(initialization_status_);

  // Converts TensorFlow's graph control-flow constructs into functional
  // control-flow that can be compiled into XLA code.
  TF_RETURN_IF_ERROR(
      FunctionalizeControlFlow(graph.get(), local_flib_def_.get()));

  xla::ComputationBuilder builder(client(), name);
  XlaContext* context =
      new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
                     options.resolve_compile_time_constants,
                     &options_.variable_representation_shape_fn);
  core::ScopedUnref context_unref(context);

  std::vector<XlaExpression> arg_expressions;
  std::vector<int> arg_cores;
  TF_RETURN_IF_ERROR(
      BuildArguments(*graph, args, options.use_tuple_arg, &builder, context,
                     &arg_cores, &arg_expressions, &result->input_mapping,
                     &result->xla_input_shapes, options.is_entry_computation));
  context->set_args(std::move(arg_expressions));

  TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
                                  flib_runtime_, NextStepId()));

  int num_nonconst_outputs;
  int num_computation_outputs;
  result->computation = std::make_shared<xla::Computation>();
  TF_RETURN_IF_ERROR(BuildComputation(
      args, arg_cores, context->retvals(), context->resources(),
      options.return_updated_values_for_all_resources, &builder,
      result->computation.get(), &num_computation_outputs,
      &num_nonconst_outputs, &result->resource_updates));

  VLOG(2) << "Outputs: total: " << context->retvals().size()
          << " nonconstant: " << num_nonconst_outputs;
  result->outputs.resize(context->retvals().size());
  for (std::vector<XlaExpression>::size_type i = 0;
       i < context->retvals().size(); ++i) {
    const XlaExpression& retval = context->retvals()[i];
    if (retval.has_constant_value()) {
      OutputDescription& output = result->outputs[i];
      output.shape = retval.constant_value().shape();
      output.is_constant = true;
      output.constant_value = retval.constant_value();
    }
  }

  // Compute the output shapes, if there is a computation with non-constant
  // outputs.
  auto computation_shape = client()->GetComputationShape(*result->computation);
  if (!computation_shape.ok()) {
    return computation_shape.status();
  }

  result->xla_output_shape.Swap(
      computation_shape.ValueOrDie()->mutable_result());
  VLOG(2) << "XLA output shape: "
          << xla::ShapeUtil::HumanString(result->xla_output_shape);

  // TensorFlow expects a major-to-minor order of results.
  xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);

  // Converts the output shapes to TensorShapes.
  int computation_output = 0;
  for (std::vector<XlaExpression>::size_type i = 0;
       i < context->retvals().size(); ++i) {
    const XlaExpression& retval = context->retvals()[i];
    if (!retval.has_constant_value()) {
      TF_RET_CHECK(computation_output < num_computation_outputs)
          << "Computation has more outputs than expected";
      OutputDescription& output = result->outputs[i];
      output.is_constant = false;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
          xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
                                               computation_output),
          &output.shape));
      ++computation_output;
    }
  }
  return Status::OK();
}

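// Returns the ChannelHandle associated with `key`. Handles are created on
// first use and cached, so both endpoints of a send/recv pair that look up
// the same key share a single channel.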
Status XlaCompiler::GetChannelHandle(const string& key,
                                     xla::ChannelHandle* channel) {
  auto result = channels_.emplace(key, xla::ChannelHandle());
  if (result.second) {
    TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle());
  }
  *channel = result.first->second;
  VLOG(1) << "Channel: " << key << " " << channel->DebugString();
  return Status::OK();
}

}  // namespace tensorflow