/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/tpu_compile.h"

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace tpu {
namespace {

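// Returns the device string for a replicated TPU core, e.g. core 0 maps to
// "/device:TPU_REPLICATED_CORE:0".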
std::string CoreDevice(int core) {
  return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
}

static constexpr char kArgOp[] = "_Arg";
static constexpr char kRetvalOp[] = "_Retval";

// Sets arg shape, arg core mapping, and per core arg shapes for a given
// argument, depending on its sharding.
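// MAXIMAL sharding places the full shape on a single core, OTHER (tiled)
// sharding gives each core in the tile assignment its per-device slice of the
// shape, and REPLICATED sharding places the full shape on every core.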
Status SetPerCoreArgShapes(
    const tpu::TPUCompileMetadataProto::Arg& proto_arg, const int arg_index,
    xla::Shape* xla_arg_shape,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  if (proto_arg.unrestricted_layout()) {
    xla_arg_shape->clear_layout();
  }

  (*arg_core_mapping)[arg_index].sharding = proto_arg.sharding();
  if (proto_arg.sharding().type() == xla::OpSharding::MAXIMAL) {
    const int core = proto_arg.sharding().tile_assignment_devices(0);
    TF_RET_CHECK(0 <= core && core < per_core_arg_shapes->size());
    (*arg_core_mapping)[arg_index].indices.push_back(
        (*per_core_arg_shapes)[core].size());
    (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
  } else if (proto_arg.sharding().type() == xla::OpSharding::OTHER) {
    TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    for (int core : proto_arg.sharding().tile_assignment_devices()) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      xla::Shape per_core_shape =
          GetPerDeviceShape(*xla_arg_shape, hlo_sharding, core);
      if (proto_arg.unrestricted_layout()) {
        per_core_shape.clear_layout();
      }
      (*per_core_arg_shapes)[core].push_back(per_core_shape);
    }
  } else {
    TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED)
        << "Unsupported argument sharding: "
        << " proto_arg=" << proto_arg.DebugString();
    for (int core = 0; core < per_core_arg_shapes->size(); ++core) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
    }
  }

  return Status::OK();
}

// Adds TPU_REPLICATED_CORE device assignments to the _Arg and _Retval
// nodes in `graph', using the sharding/index assignments in
// `arg_core_mapping` and `retval_core_mapping`. The mappings are maps from
// original argument/return index to (sharding, per-core argument/return
// index) pairs. Node attributes, such as device assignments, are not
// preserved on function argument and return value nodes, so we must recreate
// them here from the compilation metadata.
Status AssignDevicesToArgsAndRetvals(
    absl::Span<const tpu::ShardingAndIndex> arg_core_mapping,
    absl::Span<const tpu::ShardingAndIndex> retval_core_mapping, Graph* graph) {
  auto assign = [&](Node* node, const xla::OpSharding& sharding) -> Status {
    if (sharding.type() == xla::OpSharding::MAXIMAL) {
      const string device = CoreDevice(sharding.tile_assignment_devices(0));
      node->set_assigned_device_name(device);
      node->set_requested_device(device);
    } else {
      TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED ||
                   sharding.type() == xla::OpSharding::OTHER)
          << "Unsupported sharding on parameter/retval: "
          << sharding.DebugString();
    }
    node->AddAttr("_XlaSharding", sharding.SerializeAsString());
    return Status::OK();
  };
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(
          tensorflow::GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < arg_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, arg_core_mapping[index].sharding));
    } else if (node->type_string() == kRetvalOp) {
      int index;
      TF_RETURN_IF_ERROR(
          tensorflow::GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < retval_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, retval_core_mapping[index].sharding));
    }
  }
  return Status::OK();
}

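// Flattens `graph_shape_info` into a map from node name to output shapes,
// keeping only nodes that are present in `graph` and dropping the resource
// handle shape information.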
void ConvertGraphShapeInfoToShapeMap(
    const Graph& graph, const GraphShapeInfo& graph_shape_info,
    std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map) {
  // Builds a map from node name to Node* for `graph`.
  std::unordered_map<string, Node*> index;
  for (Node* node : graph.nodes()) {
    index[node->name()] = node;
  }
  // Discards the resource handle shape info while converting to the correct
  // map form.
  for (const auto& node_shape_info : graph_shape_info) {
    const string& node_name = node_shape_info.first;
    const std::vector<InferredShape>& output_shapes = node_shape_info.second;
    // Gets the vector of partial shapes, first converting node name to Node*
    // using index. graph is the subgraph of the original graph assigned to a
    // particular core, and we only add entries to shape_map for nodes in
    // graph_shape_info that are in the subgraph.
    const auto& node_iter = index.find(node_name);
    if (node_iter != index.end()) {
      auto& partial_shapes = (*shape_map)[node_name];
      for (const auto& inferred_shape : output_shapes) {
        partial_shapes.push_back(inferred_shape.shape);
      }
    }
  }
}

// Optimizes `graph`, given the argument descriptions in `metadata` and
// `arg_shapes`.
Status OptimizeGraph(const tpu::TPUCompileMetadataProto& metadata,
                     const std::vector<PartialTensorShape>& arg_shapes,
                     std::unique_ptr<Graph>* graph, FunctionLibraryRuntime* flr,
                     FunctionLibraryDefinition* fld) {
  // Sets up options for the optimization passes that need to be done. Notice
  // that CSE is not needed as XLA has its own CSE passes later in the
  // compilation stage.
  auto flags = GetBuildXlaOpsPassFlags();
  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_common_subexpression_elimination(false);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
  GraphOptimizer optimizer(opts);
  {
    // Performs a first function inlining pass before shape inference, since
    // otherwise shape inference can't see inside functions and a comprehensive
    // shape_map, including function ops, is needed to constant-propagate Shape
    // Ops below.
    GraphOptimizer::Options optimizer_opts;
    optimizer_opts.inline_multi_device_functions = true;
    optimizer_opts.inline_impl_selection_group_functions = true;
    optimizer_opts.inline_with_single_device_body_placer = true;
    // Infer shapes for each node in the computation. Shape inference can help
    // skip constant folding of large shapes.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(internal::RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    // Converts the GraphShapeInfo into the form needed by the constant-folding
    // pass of the optimizer.
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer_opts.shape_map = &shape_map;
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
  }

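  // Re-runs shape inference on the inlined graph and optimizes again with the
  // refreshed shape map (shapes that only become visible after inlining can
  // now be folded).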
  {
    // Infer shapes for each node in the computation.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(internal::RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map);
  }

  TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld));

  return Status::OK();
}

// Populates the mapping from return value to ShardingAndIndex.
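// Analogous to SetPerCoreArgShapes: each return value's `indices` entry
// records its position within the output list of every core it is assigned
// to.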
Status AssignReturnValueToCore(
    const tpu::TPUCompileMetadataProto& metadata,
    std::vector<tpu::ShardingAndIndex>* retval_core_mapping) {
  std::vector<int> per_core_retval_counts(metadata.num_cores_per_replica(), 0);
  for (int i = 0; i < metadata.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& proto_retval =
        metadata.retvals(i);
    (*retval_core_mapping)[i].sharding = proto_retval.sharding();
    if (proto_retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      int core = proto_retval.sharding().tile_assignment_devices(0);
      TF_RET_CHECK(0 <= core && core < per_core_retval_counts.size());
      (*retval_core_mapping)[i].indices.push_back(
          per_core_retval_counts[core]++);
    } else if (proto_retval.sharding().type() == xla::OpSharding::OTHER) {
      for (int64_t core : proto_retval.sharding().tile_assignment_devices()) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    } else {
      TF_RET_CHECK(proto_retval.sharding().type() ==
                   xla::OpSharding::REPLICATED)
          << "Unsupported return value sharding: "
          << proto_retval.sharding().DebugString();
      for (int core = 0; core < per_core_retval_counts.size(); ++core) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    }
  }
  return Status::OK();
}

// Populates the arguments, core mapping and per core argument shape for the
// computation.
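// GUARANTEED_CONSTANT arguments are matched, in order, against the tensors in
// `guaranteed_constants`; the final check verifies that every provided
// constant tensor is consumed.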
Status BuildComputationArgumentDescriptions(
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const XlaCompiler& compiler,
    const tpu::TPUCompileMetadataProto& metadata,
    std::vector<XlaCompiler::Argument>* args,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  arg_core_mapping->clear();
  arg_core_mapping->resize(metadata.args_size());

  per_core_arg_shapes->clear();
  per_core_arg_shapes->resize(metadata.num_cores_per_replica());

  // Builds a description of the computation's arguments.
  int constant_count = 0;
  size_t guaranteed_constants_size = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
    args->push_back(XlaCompiler::Argument());
    XlaCompiler::Argument& arg = args->back();
    arg.type = proto_arg.dtype();
    arg.shape = arg_shapes[i];
    arg.node_name = proto_arg.name();
    switch (proto_arg.kind()) {
      case tpu::TPUCompileMetadataProto::Arg::PARAMETER:
        arg.kind = XlaCompiler::Argument::kParameter;
        break;
      case tpu::TPUCompileMetadataProto::Arg::VARIABLE:
        arg.kind = XlaCompiler::Argument::kResource;
        arg.resource_kind = XlaResource::kVariable;
        arg.initialized = true;
        arg.fast_mem = proto_arg.fast_mem();
        break;
      case tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT:
        arg.kind = XlaCompiler::Argument::kConstant;
        guaranteed_constants_size =
            guaranteed_constants.index() == 0
                ? absl::get<0>(guaranteed_constants).size()
                : absl::get<1>(guaranteed_constants)->size();
        TF_RET_CHECK(constant_count < guaranteed_constants_size)
            << "More constant args in TPUCompileMetadataProto than constant "
               "tensors.";
        if (guaranteed_constants.index() == 0) {
          // `guaranteed_constants` is of type `absl::Span<const TensorProto*
          // const>`.
          Tensor tensor;
          CHECK(tensor.FromProto(
              *absl::get<0>(guaranteed_constants)[constant_count++]))
              << "Failed to deserialize invalid `TensorProto` into `Tensor`.";
          arg.constant_value = tensor;
        } else {
          // `guaranteed_constants` is of type `const OpInputList* const`.
          arg.constant_value =
              (*absl::get<1>(guaranteed_constants))[constant_count++];
        }
        break;
      case tpu::TPUCompileMetadataProto::Arg::INVALID:
      default:
        break;
    }
    arg.is_same_data_across_replicas = proto_arg.is_same_data_across_replicas();
    arg.requires_broadcast = proto_arg.requires_xla_broadcast();
    if (arg.kind == XlaCompiler::Argument::kInvalid) {
      return errors::InvalidArgument("Invalid argument kind");
    }
    if (arg.kind == XlaCompiler::Argument::kConstant) {
      continue;
    }

    // Assign each argument a sharding.
    xla::Shape xla_arg_shape;
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_RETURN_IF_ERROR(compiler.XLAShapeForArgument(
        arg, /*is_entry_computation=*/true, arg_sharding, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  TF_RET_CHECK(constant_count == guaranteed_constants_size)
      << "Not all of the constant tensors were consumed.";

  return Status::OK();
}
}  // namespace

namespace internal {
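// Runs shape inference over `graph`, seeding the inference with the argument
// shapes in `arg_shapes`. For VARIABLE arguments, the given shape is the
// shape of the variable's value and the resource handle itself is treated as
// a scalar.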
Status RunShapeInferenceOnComputation(
    const tpu::TPUCompileMetadataProto& metadata,
    const std::vector<PartialTensorShape>& arg_shapes, Graph* graph,
    FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info) {
  int num_args = arg_shapes.size();
  CHECK_EQ(num_args, metadata.args_size());

  std::map<int, InferredShape> arg_shapes_for_inference;
  for (int i = 0; i < num_args; ++i) {
    const auto& arg = metadata.args(i);
    InferredShape& shape_for_inference = arg_shapes_for_inference[i];
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      // For resource variables, arg_shapes[] contains the shape of the
      // variable's value.
      shape_for_inference.handle_type = arg.dtype();
      shape_for_inference.handle_shape = arg_shapes[i];
      // The shape of the variable itself is always a scalar.
      shape_for_inference.shape = TensorShape();
    } else {
      if (arg.kind() ==
          tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
        VLOG(1) << "PromisedConstant shape: " << arg_shapes[i].DebugString();
      }
      shape_for_inference.shape = arg_shapes[i];
    }
  }
  return InferShapes(
      graph, arg_shapes_for_inference,
      flr != nullptr ? flr->GetFunctionLibraryDefinition() : nullptr,
      shape_info);
}
}  // namespace internal

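// Compiles the TensorFlow function `function` to HLO for a single TPU
// program: builds the XlaCompiler argument descriptions, assigns return
// values to cores, copies the function body into a Graph, attaches
// TPU_REPLICATED_CORE device assignments to the _Arg/_Retval nodes, optimizes
// the graph, and finally invokes XlaCompiler::CompileGraph to produce
// `compilation_result`.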
Status CompileTFFunctionToHlo(
    const FunctionLibraryDefinition& flib_def, int graph_def_version,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const NameAttrList& function,
    const tpu::TPUCompileMetadataProto& metadata,
    std::function<Status(ResourceMgr*)> populate_resource_manager_fn,
    xla::CompileOnlyClient* client,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
    XlaCompiler::CompilationResult* compilation_result) {
  XlaCompiler::Options compiler_options;
  FunctionLibraryDefinition flib_definition(flib_def);
  compiler_options.device_type = DeviceType(DEVICE_TPU_XLA_JIT);
  compiler_options.client = client;
  compiler_options.flib_def = &flib_definition;
  compiler_options.allow_cpu_custom_calls = false;
  compiler_options.populate_resource_manager = &populate_resource_manager_fn;
  compiler_options.graph_def_version = graph_def_version;
  compiler_options.shape_representation_fn = shape_representation_fn;

  auto compiler = absl::make_unique<XlaCompiler>(compiler_options);

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(BuildComputationArgumentDescriptions(
      arg_shapes, guaranteed_constants, *compiler, metadata, &args,
      arg_core_mapping, per_core_arg_shapes));

  // Assign each return value to a core.
  std::vector<tpu::ShardingAndIndex> retval_core_mapping(
      metadata.retvals_size());
  TF_RETURN_IF_ERROR(AssignReturnValueToCore(metadata, &retval_core_mapping));

  LOG(INFO) << "Instantiating function:" << function.name();
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(compiler->flib_runtime()->Instantiate(
      function.name(), AttrSlice(&function.attr()), &handle));
  const FunctionBody* fbody = compiler->flib_runtime()->GetFunctionBody(handle);
  const string function_id =
      Canonicalize(function.name(), AttrSlice(&function.attr()));

  std::unique_ptr<Graph> graph(new Graph(&flib_definition));
  CopyGraph(*fbody->graph, graph.get());

  VLOG(2) << "metadata: " << metadata.DebugString();
  TF_RET_CHECK(fbody->arg_nodes.size() == args.size());
  for (size_t i = 0; i < fbody->arg_nodes.size(); i++) {
    args[i].node_name = fbody->arg_nodes[i]->name();
  }

  std::vector<gtl::InlinedVector<int64, 4>> arg_shape_dims;
  arg_shape_dims.reserve(arg_shapes.size());
  std::vector<PartialTensorShape> partial_arg_shapes(arg_shapes.size());
  for (const TensorShape& shape : arg_shapes) {
    arg_shape_dims.push_back(shape.dim_sizes());
  }

  for (int64_t i = 0; i < arg_shape_dims.size(); ++i) {
    auto& dims = arg_shape_dims[i];
    TF_RETURN_IF_ERROR(PartialTensorShape::MakePartialShape(
        dims.data(), dims.size(), &partial_arg_shapes[i]));
  }

  // Adds device assignments to _Arg and _Retval nodes.
  TF_RETURN_IF_ERROR(AssignDevicesToArgsAndRetvals(
      absl::MakeSpan(*arg_core_mapping), absl::MakeSpan(retval_core_mapping),
      graph.get()));

  VLOG(1) << "Optimizing TensorFlow graph";
  TF_RETURN_IF_ERROR(OptimizeGraph(metadata, partial_arg_shapes, &graph,
                                   compiler->flib_runtime(), &flib_definition));

  VLOG(1) << "Compiling TensorFlow graph to HLO";
  XlaCompiler::CompileOptions compile_options;
  compile_options.return_updated_values_for_all_resources = false;
  compile_options.use_tuple_arg = true;
  compile_options.is_entry_computation = true;
  compile_options.alias_resource_update = true;
  return compiler->CompileGraph(compile_options, function_id, std::move(graph),
                                args, compilation_result);
}

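// Computes `arg_core_mapping` and `per_core_arg_shapes` for `metadata`
// without compiling the function: argument shapes come from
// `shape_representation_fn` and their layouts are rewritten to account for
// sharding before being distributed to cores via SetPerCoreArgShapes.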
Status GetShardingInfo(
    const tpu::TPUCompileMetadataProto& metadata,
    absl::Span<const TensorShape> arg_shapes,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  arg_core_mapping->clear();
  arg_core_mapping->resize(metadata.args_size());

  per_core_arg_shapes->clear();
  per_core_arg_shapes->resize(metadata.num_cores_per_replica());

  int num_inputs = metadata.args_size();
  for (int i = 0; i < num_inputs; ++i) {
    const auto& proto_arg = metadata.args(i);
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_ASSIGN_OR_RETURN(
        auto xla_arg_shape,
        shape_representation_fn(arg_shapes[i], proto_arg.dtype(),
                                /*use_fast_memory=*/false));
    TF_RETURN_IF_ERROR(
        RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
                                      shape_representation_fn, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow