/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/tpu_compile.h"

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace tpu {
namespace {

std::string CoreDevice(int core) {
  return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
}
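// For example, with DEVICE_TPU_REPLICATED_CORE equal to
// "TPU_REPLICATED_CORE", CoreDevice(1) yields
// "/device:TPU_REPLICATED_CORE:1".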

static constexpr char kArgOp[] = "_Arg";
static constexpr char kRetvalOp[] = "_Retval";

// Sets arg shape, arg core mapping, and per core arg shapes for a given
// argument, depending on its sharding.
Status SetPerCoreArgShapes(
    const tpu::TPUCompileMetadataProto::Arg& proto_arg, const int arg_index,
    xla::Shape* xla_arg_shape,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  if (proto_arg.unrestricted_layout()) {
    xla_arg_shape->clear_layout();
  }

  (*arg_core_mapping)[arg_index].sharding = proto_arg.sharding();
  if (proto_arg.sharding().type() == xla::OpSharding::MAXIMAL) {
    const int core = proto_arg.sharding().tile_assignment_devices(0);
    TF_RET_CHECK(0 <= core && core < per_core_arg_shapes->size());
    (*arg_core_mapping)[arg_index].indices.push_back(
        (*per_core_arg_shapes)[core].size());
    (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
  } else if (proto_arg.sharding().type() == xla::OpSharding::OTHER) {
    TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    for (int core : proto_arg.sharding().tile_assignment_devices()) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      xla::Shape per_core_shape =
          GetPerDeviceShape(*xla_arg_shape, hlo_sharding, core);
      if (proto_arg.unrestricted_layout()) {
        per_core_shape.clear_layout();
      }
      (*per_core_arg_shapes)[core].push_back(per_core_shape);
    }
  } else {
    TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED)
        << "Unsupported argument sharding: "
        << " proto_arg=" << proto_arg.DebugString();
    for (int core = 0; core < per_core_arg_shapes->size(); ++core) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
    }
  }

  return OkStatus();
}
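// Illustration: with two cores per replica, a MAXIMAL argument assigned to
// core 1 is appended only to (*per_core_arg_shapes)[1], an OTHER (tiled)
// argument contributes one per-device sub-shape to each core listed in its
// tile assignment, and a REPLICATED argument is appended, unchanged, to the
// shape list of every core.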

// Adds TPU_REPLICATED_CORE device assignments to the _Arg and _Retval
// nodes in `graph`, using the sharding/index assignments in
// `arg_core_mapping` and `retval_core_mapping`. The mappings map each
// original argument/return index to a (sharding, per-core argument/return
// index) pair. Node attributes, such as device assignments, are not
// preserved on function argument and return value nodes, so we must
// recreate them from the compilation metadata.
Status AssignDevicesToArgsAndRetvals(
    absl::Span<const tpu::ShardingAndIndex> arg_core_mapping,
    absl::Span<const tpu::ShardingAndIndex> retval_core_mapping, Graph* graph) {
  auto assign = [&](Node* node, const xla::OpSharding& sharding) -> Status {
    if (sharding.type() == xla::OpSharding::MAXIMAL) {
      const string device = CoreDevice(sharding.tile_assignment_devices(0));
      node->set_assigned_device_name(device);
      node->set_requested_device(device);
    } else {
      TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED ||
                   sharding.type() == xla::OpSharding::OTHER)
          << "Unsupported sharding on parameter/retval: "
          << sharding.DebugString();
    }
    node->AddAttr("_XlaSharding", sharding.SerializeAsString());
    return OkStatus();
  };
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(
          tensorflow::GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < arg_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, arg_core_mapping[index].sharding));
    } else if (node->type_string() == kRetvalOp) {
      int index;
      TF_RETURN_IF_ERROR(
          tensorflow::GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < retval_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, retval_core_mapping[index].sharding));
    }
  }
  return OkStatus();
}
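// For instance, an _Arg node whose sharding is MAXIMAL on core 3 gets both
// its assigned and requested device set to CoreDevice(3) and its serialized
// sharding attached as the "_XlaSharding" attribute; REPLICATED and OTHER
// shardings only receive the attribute.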

void ConvertGraphShapeInfoToShapeMap(
    const Graph& graph, const GraphShapeInfo& graph_shape_info,
    std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map) {
  // Builds a map from node name to Node* for `graph`.
  std::unordered_map<string, Node*> index;
  for (Node* node : graph.nodes()) {
    index[node->name()] = node;
  }
  // Discards the resource handle shape info while converting to the correct
  // map form.
  for (const auto& node_shape_info : graph_shape_info) {
    const string& node_name = node_shape_info.first;
    const std::vector<InferredShape>& output_shapes = node_shape_info.second;
    // Gets the vector of partial shapes, first converting node name to Node*
    // using index. graph is the subgraph of the original graph assigned to a
    // particular core, and we only add entries to shape_map for nodes in
    // graph_shape_info that are in the subgraph.
    const auto& node_iter = index.find(node_name);
    if (node_iter != index.end()) {
      auto& partial_shapes = (*shape_map)[node_name];
      for (const auto& inferred_shape : output_shapes) {
        partial_shapes.push_back(inferred_shape.shape);
      }
    }
  }
}
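// Example: a node named "dense/MatMul" (hypothetical) with two inferred
// outputs yields (*shape_map)["dense/MatMul"] == {shape of output 0, shape of
// output 1}; nodes in `graph_shape_info` absent from `graph` are skipped.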

// Optimizes `graph`, given the argument descriptions in `metadata` and
// `arg_shapes`.
Status OptimizeGraph(const tpu::TPUCompileMetadataProto& metadata,
                     const std::vector<PartialTensorShape>& arg_shapes,
                     std::unique_ptr<Graph>* graph, FunctionLibraryRuntime* flr,
                     FunctionLibraryDefinition* fld) {
  // Sets up options for the optimization passes that need to be done. Notice
  // that CSE is not needed as XLA has its own CSE passes later in the
  // compilation stage.
  auto flags = GetBuildXlaOpsPassFlags();
  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_common_subexpression_elimination(false);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
  GraphOptimizer optimizer(opts);
  {
    // Performs a first function inlining pass before shape inference, since
    // otherwise shape inference can't see inside functions and a comprehensive
    // shape_map, including function ops, is needed to constant-propagate Shape
    // Ops below.
    GraphOptimizer::Options optimizer_opts;
    optimizer_opts.inline_multi_device_functions = true;
    optimizer_opts.inline_impl_selection_group_functions = true;
    optimizer_opts.inline_with_single_device_body_placer = true;
    // Infer shapes for each node in the computation. Shape inference can help
    // skip constant folding of large shapes.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(internal::RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    // Converts the GraphShapeInfo into the form needed by the constant-folding
    // pass of the optimizer.
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer_opts.shape_map = &shape_map;
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
  }

  {
    // Infer shapes for each node in the computation.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(internal::RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    GraphOptimizer::Options optimizer_opts;
    optimizer_opts.shape_map = &shape_map;
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
  }

  TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld));

  return OkStatus();
}
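// The two optimization rounds above are intentional: the first inlines
// function calls so that shape inference can see their bodies, and the second
// re-runs shape inference on the inlined graph so that constant folding works
// with up-to-date shapes before the tensor-list rewrite.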

// Populates the mapping from return value to ShardingAndIndex.
Status AssignReturnValueToCore(
    const tpu::TPUCompileMetadataProto& metadata,
    std::vector<tpu::ShardingAndIndex>* retval_core_mapping) {
  std::vector<int> per_core_retval_counts(metadata.num_cores_per_replica(), 0);
  for (int i = 0; i < metadata.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& proto_retval =
        metadata.retvals(i);
    (*retval_core_mapping)[i].sharding = proto_retval.sharding();
    if (proto_retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      int core = proto_retval.sharding().tile_assignment_devices(0);
      TF_RET_CHECK(0 <= core && core < per_core_retval_counts.size());
      (*retval_core_mapping)[i].indices.push_back(
          per_core_retval_counts[core]++);
    } else if (proto_retval.sharding().type() == xla::OpSharding::OTHER) {
      for (int64_t core : proto_retval.sharding().tile_assignment_devices()) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    } else {
      TF_RET_CHECK(proto_retval.sharding().type() ==
                   xla::OpSharding::REPLICATED)
          << "Unsupported return value sharding: "
          << proto_retval.sharding().DebugString();
      for (int core = 0; core < per_core_retval_counts.size(); ++core) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    }
  }
  return OkStatus();
}
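// Illustration: with two cores per replica and retvals {r0: MAXIMAL on core 0,
// r1: REPLICATED}, r0 gets indices {0} and r1 gets indices {1, 0}, i.e. r1 is
// the second value returned from core 0 and the first value returned from
// core 1.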

// Populates the arguments, core mapping and per core argument shape for the
// computation.
Status BuildComputationArgumentDescriptions(
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const XlaCompiler& compiler,
    const tpu::TPUCompileMetadataProto& metadata,
    std::vector<XlaCompiler::Argument>* args,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  arg_core_mapping->clear();
  arg_core_mapping->resize(metadata.args_size());

  per_core_arg_shapes->clear();
  per_core_arg_shapes->resize(metadata.num_cores_per_replica());

  // Builds a description of the computation's arguments.
  int constant_count = 0;
  size_t guaranteed_constants_size = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
    args->push_back(XlaCompiler::Argument());
    XlaCompiler::Argument& arg = args->back();
    arg.type = proto_arg.dtype();
    arg.shape = arg_shapes[i];
    arg.node_name = proto_arg.name();
    switch (proto_arg.kind()) {
      case tpu::TPUCompileMetadataProto::Arg::PARAMETER:
        arg.kind = XlaCompiler::Argument::kParameter;
        break;
      case tpu::TPUCompileMetadataProto::Arg::VARIABLE:
        arg.kind = XlaCompiler::Argument::kResource;
        arg.resource_kind = XlaResource::kVariable;
        arg.initialized = true;
        arg.fast_mem = proto_arg.fast_mem();
        break;
      case tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT:
        arg.kind = XlaCompiler::Argument::kConstant;
        guaranteed_constants_size =
            guaranteed_constants.index() == 0
                ? absl::get<0>(guaranteed_constants).size()
                : absl::get<1>(guaranteed_constants)->size();
        TF_RET_CHECK(constant_count < guaranteed_constants_size)
            << "More constant args in TPUCompileMetadataProto than constant "
               "tensors.";
        if (guaranteed_constants.index() == 0) {
          // `guaranteed_constants` is of type `absl::Span<const TensorProto*
          // const>`.
          Tensor tensor;
          CHECK(tensor.FromProto(
              *absl::get<0>(guaranteed_constants)[constant_count++]))
              << "Failed to deserialize invalid `TensorProto` into `Tensor`.";
          arg.constant_value = tensor;
        } else {
          // `guaranteed_constants` is of type `const OpInputList* const`.
          arg.constant_value =
              (*absl::get<1>(guaranteed_constants))[constant_count++];
        }
        break;
      case tpu::TPUCompileMetadataProto::Arg::INVALID:
      default:
        break;
    }
    arg.is_same_data_across_replicas = proto_arg.is_same_data_across_replicas();
    arg.requires_broadcast = proto_arg.requires_xla_broadcast();
    if (arg.kind == XlaCompiler::Argument::kInvalid) {
      return errors::InvalidArgument("Invalid argument kind");
    }
    if (arg.kind == XlaCompiler::Argument::kConstant) {
      continue;
    }

    // Assign each argument a sharding.
    xla::Shape xla_arg_shape;
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_RETURN_IF_ERROR(compiler.XLAShapeForArgument(
        arg, /*is_entry_computation=*/true, arg_sharding, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  TF_RET_CHECK(constant_count == guaranteed_constants_size)
      << "Not all of the constant tensors were consumed.";

  return OkStatus();
}
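// In summary: PARAMETER args become kParameter, VARIABLE args become
// initialized kResource variables (with fast_mem taken from the proto), and
// GUARANTEED_CONSTANT args become kConstant with their value taken from
// `guaranteed_constants`; constant args skip the per-core shape and sharding
// assignment since they are compile-time constants rather than runtime
// parameters.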
}  // namespace

namespace internal {
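// Runs TensorFlow shape inference over `graph`, seeding each argument with the
// corresponding entry of `arg_shapes`. For VARIABLE args the value shape is
// recorded as the handle shape and the argument itself is treated as a scalar
// resource handle. Results are written to `shape_info`.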
Status RunShapeInferenceOnComputation(
    const tpu::TPUCompileMetadataProto& metadata,
    const std::vector<PartialTensorShape>& arg_shapes, Graph* graph,
    FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info) {
  int num_args = arg_shapes.size();
  CHECK_EQ(num_args, metadata.args_size());

  std::map<int, InferredShape> arg_shapes_for_inference;
  for (int i = 0; i < num_args; ++i) {
    const auto& arg = metadata.args(i);
    InferredShape& shape_for_inference = arg_shapes_for_inference[i];
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      // For resource variables, arg_shapes[] contains the shape of the
      // variable's value.
      shape_for_inference.handle_type = arg.dtype();
      shape_for_inference.handle_shape = arg_shapes[i];
      // The shape of the variable itself is always a scalar.
      shape_for_inference.shape = TensorShape();
    } else {
      if (arg.kind() ==
          tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
        VLOG(1) << "PromisedConstant shape: " << arg_shapes[i].DebugString();
      }
      shape_for_inference.shape = arg_shapes[i];
    }
  }
  return InferShapes(
      graph, arg_shapes_for_inference,
      flr != nullptr ? flr->GetFunctionLibraryDefinition() : nullptr,
      shape_info);
}
}  // namespace internal

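// Compiles the TensorFlow function `function` to HLO: builds the XlaCompiler
// argument descriptions, assigns arguments and return values to cores, copies
// and optimizes the function body graph, and finally lowers it with
// XlaCompiler::CompileGraph into `compilation_result`.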
Status CompileTFFunctionToHlo(
    const FunctionLibraryDefinition& flib_def, int graph_def_version,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const NameAttrList& function,
    const tpu::TPUCompileMetadataProto& metadata,
    xla::CompileOnlyClient* client,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
    bool use_tuple_args, XlaCompiler::CompilationResult* compilation_result) {
  XlaCompiler::Options compiler_options;
  FunctionLibraryDefinition flib_definition(flib_def);
  compiler_options.device_type = DeviceType(DEVICE_TPU_XLA_JIT);
  compiler_options.client = client;
  compiler_options.flib_def = &flib_definition;
  compiler_options.allow_cpu_custom_calls = false;
  compiler_options.graph_def_version = graph_def_version;
  compiler_options.shape_determination_fns = shape_determination_fns;

  auto compiler = std::make_unique<XlaCompiler>(compiler_options);

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(BuildComputationArgumentDescriptions(
      arg_shapes, guaranteed_constants, *compiler, metadata, &args,
      arg_core_mapping, per_core_arg_shapes));

  // Assign each return value to a core.
  std::vector<tpu::ShardingAndIndex> retval_core_mapping(
      metadata.retvals_size());
  TF_RETURN_IF_ERROR(AssignReturnValueToCore(metadata, &retval_core_mapping));

  LOG(INFO) << "Instantiating function:" << function.name();
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(compiler->flib_runtime()->Instantiate(
      function.name(), AttrSlice(&function.attr()), &handle));
  const FunctionBody* fbody = compiler->flib_runtime()->GetFunctionBody(handle);
  const string function_id =
      Canonicalize(function.name(), AttrSlice(&function.attr()));

  std::unique_ptr<Graph> graph(new Graph(&flib_definition));
  CopyGraph(*fbody->graph, graph.get());

  VLOG(2) << "metadata: " << metadata.DebugString();
  TF_RET_CHECK(fbody->arg_nodes.size() == args.size());
  for (size_t i = 0; i < fbody->arg_nodes.size(); i++) {
    args[i].node_name = fbody->arg_nodes[i]->name();
  }

  std::vector<gtl::InlinedVector<int64_t, 4>> arg_shape_dims;
  arg_shape_dims.reserve(arg_shapes.size());
  std::vector<PartialTensorShape> partial_arg_shapes(arg_shapes.size());
  for (const TensorShape& shape : arg_shapes) {
    arg_shape_dims.push_back(shape.dim_sizes());
  }

  for (int64_t i = 0; i < arg_shape_dims.size(); ++i) {
    auto& dims = arg_shape_dims[i];
    TF_RETURN_IF_ERROR(PartialTensorShape::MakePartialShape(
        dims.data(), dims.size(), &partial_arg_shapes[i]));
  }

  // Adds device assignments to _Arg and _Retval nodes.
  TF_RETURN_IF_ERROR(AssignDevicesToArgsAndRetvals(
      absl::MakeSpan(*arg_core_mapping), absl::MakeSpan(retval_core_mapping),
      graph.get()));

  VLOG(1) << "Optimizing TensorFlow graph";
  TF_RETURN_IF_ERROR(OptimizeGraph(metadata, partial_arg_shapes, &graph,
                                   compiler->flib_runtime(), &flib_definition));

  VLOG(1) << "Compiling TensorFlow graph to HLO";
  XlaCompiler::CompileOptions compile_options;
  compile_options.return_updated_values_for_all_resources = false;
  compile_options.use_tuple_arg = use_tuple_args;
  compile_options.is_entry_computation = true;
  compile_options.alias_resource_update = true;
  return compiler->CompileGraph(compile_options, function_id, std::move(graph),
                                args, compilation_result);
}

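// Computes `arg_core_mapping` and `per_core_arg_shapes` for the arguments in
// `metadata` without building full XlaCompiler argument descriptions: each
// argument's representation shape comes from `shape_determination_fns`, its
// layout is rewritten to match its sharding, and SetPerCoreArgShapes then
// distributes the shape across cores.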
Status GetShardingInfo(
    const tpu::TPUCompileMetadataProto& metadata,
    absl::Span<const TensorShape> arg_shapes,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  arg_core_mapping->clear();
  arg_core_mapping->resize(metadata.args_size());

  per_core_arg_shapes->clear();
  per_core_arg_shapes->resize(metadata.num_cores_per_replica());

  int num_inputs = metadata.args_size();
  for (int i = 0; i < num_inputs; ++i) {
    const auto& proto_arg = metadata.args(i);
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    auto layout_preference = shape_determination_fns.layout_preference_fn(
        arg_shapes[i], proto_arg.dtype(), absl::nullopt);
    TF_ASSIGN_OR_RETURN(auto xla_arg_shape,
                        shape_determination_fns.shape_representation_fn(
                            arg_shapes[i], proto_arg.dtype(),
                            /*use_fast_memory=*/false, layout_preference));
    TF_RETURN_IF_ERROR(
        RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
                                      shape_determination_fns, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  return OkStatus();
}

}  // namespace tpu
}  // namespace tensorflow