/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/tpu_compile.h"

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace tpu {
namespace {

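// Returns the device name for replicated TPU core `core`, e.g.
// "/device:TPU_REPLICATED_CORE:0".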
std::string CoreDevice(int core) {
  return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
}

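// Op type names of the placeholder nodes that a function body uses for its
// arguments and return values.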
static constexpr char kArgOp[] = "_Arg";
static constexpr char kRetvalOp[] = "_Retval";

// Sets arg shape, arg core mapping, and per core arg shapes for a given
// argument, depending on its sharding.
Status SetPerCoreArgShapes(
    const tpu::TPUCompileMetadataProto::Arg& proto_arg, const int arg_index,
    xla::Shape* xla_arg_shape,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  if (proto_arg.unrestricted_layout()) {
    xla_arg_shape->clear_layout();
  }

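  // Record the argument's sharding and derive its per-core shape(s): MAXIMAL
  // sharding places the whole argument on a single core, OTHER (tiled)
  // sharding splits it across the cores in its tile assignment, and
  // REPLICATED sharding gives every core a full copy.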
  (*arg_core_mapping)[arg_index].sharding = proto_arg.sharding();
  if (proto_arg.sharding().type() == xla::OpSharding::MAXIMAL) {
    const int core = proto_arg.sharding().tile_assignment_devices(0);
    TF_RET_CHECK(0 <= core && core < per_core_arg_shapes->size());
    (*arg_core_mapping)[arg_index].indices.push_back(
        (*per_core_arg_shapes)[core].size());
    (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
  } else if (proto_arg.sharding().type() == xla::OpSharding::OTHER) {
    TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    for (int core : proto_arg.sharding().tile_assignment_devices()) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      xla::Shape per_core_shape =
          GetPerDeviceShape(*xla_arg_shape, hlo_sharding, core);
      if (proto_arg.unrestricted_layout()) {
        per_core_shape.clear_layout();
      }
      (*per_core_arg_shapes)[core].push_back(per_core_shape);
    }
  } else {
    TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED)
        << "Unsupported argument sharding: "
        << " proto_arg=" << proto_arg.DebugString();
    for (int core = 0; core < per_core_arg_shapes->size(); ++core) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
    }
  }

  return Status::OK();
}

// Adds TPU_REPLICATED_CORE device assignments to the _Arg and _Retval
// nodes in `graph`, using the sharding/index assignments in
// `arg_core_mapping` and `retval_core_mapping`. The mappings are maps from
// original argument/return index to (sharding, per-core argument/return
// index) pairs. Node attributes, such as device assignments, are not
// preserved on function argument and return value nodes, so we must
// recreate them from the compilation metadata.
Status AssignDevicesToArgsAndRetvals(
    absl::Span<const tpu::ShardingAndIndex> arg_core_mapping,
    absl::Span<const tpu::ShardingAndIndex> retval_core_mapping, Graph* graph) {
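  // Records `sharding` on the node via the _XlaSharding attribute and, for
  // MAXIMAL sharding, pins the node to the corresponding TPU_REPLICATED_CORE
  // device.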
  auto assign = [&](Node* node, const xla::OpSharding& sharding) -> Status {
    if (sharding.type() == xla::OpSharding::MAXIMAL) {
      const string device = CoreDevice(sharding.tile_assignment_devices(0));
      node->set_assigned_device_name(device);
      node->set_requested_device(device);
    } else {
      TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED ||
                   sharding.type() == xla::OpSharding::OTHER)
          << "Unsupported sharding on parameter/retval: "
          << sharding.DebugString();
    }
    node->AddAttr("_XlaSharding", sharding.SerializeAsString());
    return Status::OK();
  };
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(
          tensorflow::GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < arg_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, arg_core_mapping[index].sharding));
    } else if (node->type_string() == kRetvalOp) {
      int index;
      TF_RETURN_IF_ERROR(
          tensorflow::GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < retval_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, retval_core_mapping[index].sharding));
    }
  }
  return Status::OK();
}

void ConvertGraphShapeInfoToShapeMap(
    const Graph& graph, const GraphShapeInfo& graph_shape_info,
    std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map) {
  // Builds a map from node name to Node* for `graph`.
  std::unordered_map<string, Node*> index;
  for (Node* node : graph.nodes()) {
    index[node->name()] = node;
  }
  // Discards the resource handle shape info while converting to the correct
  // map form.
  for (const auto& node_shape_info : graph_shape_info) {
    const string& node_name = node_shape_info.first;
    const std::vector<InferredShape>& output_shapes = node_shape_info.second;
    // Gets the vector of partial shapes, first converting node name to Node*
    // using index. graph is the subgraph of the original graph assigned to a
    // particular core, and we only add entries to shape_map for nodes in
    // graph_shape_info that are in the subgraph.
    const auto& node_iter = index.find(node_name);
    if (node_iter != index.end()) {
      auto& partial_shapes = (*shape_map)[node_name];
      for (const auto& inferred_shape : output_shapes) {
        partial_shapes.push_back(inferred_shape.shape);
      }
    }
  }
}

// Optimizes `graph`, given the argument descriptions in `metadata` and
// `arg_shapes`.
Status OptimizeGraph(const tpu::TPUCompileMetadataProto& metadata,
                     const std::vector<PartialTensorShape>& arg_shapes,
                     std::unique_ptr<Graph>* graph, FunctionLibraryRuntime* flr,
                     FunctionLibraryDefinition* fld) {
  // Sets up options for the optimization passes that need to be done. Notice
  // that CSE is not needed as XLA has its own CSE passes later in the
  // compilation stage.
  auto flags = GetBuildXlaOpsPassFlags();
  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_common_subexpression_elimination(false);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
  GraphOptimizer optimizer(opts);
  {
    // Performs a first function inlining pass before shape inference, since
    // otherwise shape inference can't see inside functions and a comprehensive
    // shape_map, including function ops, is needed to constant-propagate Shape
    // Ops below.
    GraphOptimizer::Options optimizer_opts;
    optimizer_opts.inline_multi_device_functions = true;
    optimizer_opts.inline_impl_selection_group_functions = true;
    optimizer_opts.inline_with_single_device_body_placer = true;
    // Infer shapes for each node in the computation. Shape inference can help
    // skip constant folding of large shapes.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(internal::RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    // Converts the GraphShapeInfo into the form needed by the constant-folding
    // pass of the optimizer.
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer_opts.shape_map = &shape_map;
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
  }

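  // Second pass: re-infer shapes on the now-inlined graph so the optimizer
  // (in particular constant folding) sees the refined shape information.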
  {
    // Infer shapes for each node in the computation.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(internal::RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map);
  }

  TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld));

  return Status::OK();
}

// Populates the mapping from return value to ShardingAndIndex.
Status AssignReturnValueToCore(
    const tpu::TPUCompileMetadataProto& metadata,
    std::vector<tpu::ShardingAndIndex>* retval_core_mapping) {
  std::vector<int> per_core_retval_counts(metadata.num_cores_per_replica(), 0);
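  // For each return value, record its sharding and assign it the next output
  // index on every core that produces it: one core for MAXIMAL sharding, the
  // tile-assignment cores for OTHER, and all cores for REPLICATED.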
  for (int i = 0; i < metadata.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& proto_retval =
        metadata.retvals(i);
    (*retval_core_mapping)[i].sharding = proto_retval.sharding();
    if (proto_retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      int core = proto_retval.sharding().tile_assignment_devices(0);
      TF_RET_CHECK(0 <= core && core < per_core_retval_counts.size());
      (*retval_core_mapping)[i].indices.push_back(
          per_core_retval_counts[core]++);
    } else if (proto_retval.sharding().type() == xla::OpSharding::OTHER) {
      for (int64_t core : proto_retval.sharding().tile_assignment_devices()) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    } else {
      TF_RET_CHECK(proto_retval.sharding().type() ==
                   xla::OpSharding::REPLICATED)
          << "Unsupported return value sharding: "
          << proto_retval.sharding().DebugString();
      for (int core = 0; core < per_core_retval_counts.size(); ++core) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    }
  }
  return Status::OK();
}

// Populates the arguments, core mapping and per core argument shape for the
// computation.
Status BuildComputationArgumentDescriptions(
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const XlaCompiler& compiler,
    const tpu::TPUCompileMetadataProto& metadata,
    std::vector<XlaCompiler::Argument>* args,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  arg_core_mapping->clear();
  arg_core_mapping->resize(metadata.args_size());

  per_core_arg_shapes->clear();
  per_core_arg_shapes->resize(metadata.num_cores_per_replica());

  // Builds a description of the computation's arguments.
  int constant_count = 0;
  size_t guaranteed_constants_size = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
    args->push_back(XlaCompiler::Argument());
    XlaCompiler::Argument& arg = args->back();
    arg.type = proto_arg.dtype();
    arg.shape = arg_shapes[i];
    arg.node_name = proto_arg.name();
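    // Translate the argument kind from the compilation metadata into the
    // corresponding XlaCompiler argument kind; guaranteed constants also pick
    // up their constant value here.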
    switch (proto_arg.kind()) {
      case tpu::TPUCompileMetadataProto::Arg::PARAMETER:
        arg.kind = XlaCompiler::Argument::kParameter;
        break;
      case tpu::TPUCompileMetadataProto::Arg::VARIABLE:
        arg.kind = XlaCompiler::Argument::kResource;
        arg.resource_kind = XlaResource::kVariable;
        arg.initialized = true;
        arg.fast_mem = proto_arg.fast_mem();
        break;
      case tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT:
        arg.kind = XlaCompiler::Argument::kConstant;
        guaranteed_constants_size =
            guaranteed_constants.index() == 0
                ? absl::get<0>(guaranteed_constants).size()
                : absl::get<1>(guaranteed_constants)->size();
        TF_RET_CHECK(constant_count < guaranteed_constants_size)
            << "More constant args in TPUCompileMetadataProto than constant "
               "tensors.";
        if (guaranteed_constants.index() == 0) {
          // `guaranteed_constants` is of type `absl::Span<const TensorProto*
          // const>`.
          Tensor tensor;
          CHECK(tensor.FromProto(
              *absl::get<0>(guaranteed_constants)[constant_count++]))
              << "Failed to deserialize invalid `TensorProto` into `Tensor`.";
          arg.constant_value = tensor;
        } else {
          // `guaranteed_constants` is of type `const OpInputList* const`.
          arg.constant_value =
              (*absl::get<1>(guaranteed_constants))[constant_count++];
        }
        break;
      case tpu::TPUCompileMetadataProto::Arg::INVALID:
      default:
        break;
    }
    arg.is_same_data_across_replicas =
        proto_arg.is_same_data_across_replicas();
    arg.requires_broadcast = proto_arg.requires_xla_broadcast();
    if (arg.kind == XlaCompiler::Argument::kInvalid) {
      return errors::InvalidArgument("Invalid argument kind");
    }
    if (arg.kind == XlaCompiler::Argument::kConstant) {
      continue;
    }

    // Assign each argument a sharding.
    xla::Shape xla_arg_shape;
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_RETURN_IF_ERROR(compiler.XLAShapeForArgument(
        arg, /*is_entry_computation=*/true, arg_sharding, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  TF_RET_CHECK(constant_count == guaranteed_constants_size)
      << "Not all of the constant tensors were consumed.";

  return Status::OK();
}
}  // namespace

namespace internal {
Status RunShapeInferenceOnComputation(
    const tpu::TPUCompileMetadataProto& metadata,
    const std::vector<PartialTensorShape>& arg_shapes, Graph* graph,
    FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info) {
  int num_args = arg_shapes.size();
  CHECK_EQ(num_args, metadata.args_size());

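  // Seed shape inference with the shape of every argument. For resource
  // variables the provided shape describes the variable's value, which is
  // recorded as the handle shape below.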
  std::map<int, InferredShape> arg_shapes_for_inference;
  for (int i = 0; i < num_args; ++i) {
    const auto& arg = metadata.args(i);
    InferredShape& shape_for_inference = arg_shapes_for_inference[i];
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      // For resource variables, arg_shapes[] contains the shape of the
      // variable's value.
      shape_for_inference.handle_type = arg.dtype();
      shape_for_inference.handle_shape = arg_shapes[i];
      // The shape of the variable itself is always a scalar.
      shape_for_inference.shape = TensorShape();
    } else {
      if (arg.kind() ==
          tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
        VLOG(1) << "PromisedConstant shape: " << arg_shapes[i].DebugString();
      }
      shape_for_inference.shape = arg_shapes[i];
    }
  }
  return InferShapes(
      graph, arg_shapes_for_inference,
      flr != nullptr ? flr->GetFunctionLibraryDefinition() : nullptr,
      shape_info);
}
}  // namespace internal

Status CompileTFFunctionToHlo(
    const FunctionLibraryDefinition& flib_def, int graph_def_version,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const NameAttrList& function,
    const tpu::TPUCompileMetadataProto& metadata,
    std::function<Status(ResourceMgr*)> populate_resource_manager_fn,
    xla::CompileOnlyClient* client,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
    XlaCompiler::CompilationResult* compilation_result) {
  XlaCompiler::Options compiler_options;
  FunctionLibraryDefinition flib_definition(flib_def);
  compiler_options.device_type = DeviceType(DEVICE_TPU_XLA_JIT);
  compiler_options.client = client;
  compiler_options.flib_def = &flib_definition;
  compiler_options.allow_cpu_custom_calls = false;
  compiler_options.populate_resource_manager = &populate_resource_manager_fn;
  compiler_options.graph_def_version = graph_def_version;
  compiler_options.shape_representation_fn = shape_representation_fn;

  auto compiler = absl::make_unique<XlaCompiler>(compiler_options);

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(BuildComputationArgumentDescriptions(
      arg_shapes, guaranteed_constants, *compiler, metadata, &args,
      arg_core_mapping, per_core_arg_shapes));

  // Assign each return value to a core.
  std::vector<tpu::ShardingAndIndex> retval_core_mapping(
      metadata.retvals_size());
  TF_RETURN_IF_ERROR(AssignReturnValueToCore(metadata, &retval_core_mapping));

  LOG(INFO) << "Instantiating function:" << function.name();
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(compiler->flib_runtime()->Instantiate(
      function.name(), AttrSlice(&function.attr()), &handle));
  const FunctionBody* fbody =
      compiler->flib_runtime()->GetFunctionBody(handle);
  const string function_id =
      Canonicalize(function.name(), AttrSlice(&function.attr()));

  std::unique_ptr<Graph> graph(new Graph(&flib_definition));
  CopyGraph(*fbody->graph, graph.get());

  VLOG(2) << "metadata: " << metadata.DebugString();
  TF_RET_CHECK(fbody->arg_nodes.size() == args.size());
  for (size_t i = 0; i < fbody->arg_nodes.size(); i++) {
    args[i].node_name = fbody->arg_nodes[i]->name();
  }

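  // Convert the concrete argument shapes to PartialTensorShapes, the form
  // expected by shape inference during graph optimization.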
  std::vector<gtl::InlinedVector<int64, 4>> arg_shape_dims;
  arg_shape_dims.reserve(arg_shapes.size());
  std::vector<PartialTensorShape> partial_arg_shapes(arg_shapes.size());
  for (const TensorShape& shape : arg_shapes) {
    arg_shape_dims.push_back(shape.dim_sizes());
  }

  for (int64_t i = 0; i < arg_shape_dims.size(); ++i) {
    auto& dims = arg_shape_dims[i];
    TF_RETURN_IF_ERROR(PartialTensorShape::MakePartialShape(
        dims.data(), dims.size(), &partial_arg_shapes[i]));
  }

  // Adds device assignments to _Arg and _Retval nodes.
  TF_RETURN_IF_ERROR(AssignDevicesToArgsAndRetvals(
      absl::MakeSpan(*arg_core_mapping), absl::MakeSpan(retval_core_mapping),
      graph.get()));

  VLOG(1) << "Optimizing TensorFlow graph";
  TF_RETURN_IF_ERROR(OptimizeGraph(metadata, partial_arg_shapes, &graph,
                                   compiler->flib_runtime(), &flib_definition));

  VLOG(1) << "Compiling TensorFlow graph to HLO";
  XlaCompiler::CompileOptions compile_options;
  compile_options.return_updated_values_for_all_resources = false;
  compile_options.use_tuple_arg = true;
  compile_options.is_entry_computation = true;
  compile_options.alias_resource_update = true;
  return compiler->CompileGraph(compile_options, function_id, std::move(graph),
                                args, compilation_result);
}

Status GetShardingInfo(
    const tpu::TPUCompileMetadataProto& metadata,
    absl::Span<const TensorShape> arg_shapes,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  arg_core_mapping->clear();
  arg_core_mapping->resize(metadata.args_size());

  per_core_arg_shapes->clear();
  per_core_arg_shapes->resize(metadata.num_cores_per_replica());

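  // For each argument, derive its XLA shape from the shape representation
  // function, adjust the layout for the argument's sharding, and record the
  // per-core shapes and core mapping.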
  int num_inputs = metadata.args_size();
  for (int i = 0; i < num_inputs; ++i) {
    const auto& proto_arg = metadata.args(i);
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_ASSIGN_OR_RETURN(
        auto xla_arg_shape,
        shape_representation_fn(arg_shapes[i], proto_arg.dtype(),
                                /*use_fast_memory=*/false));
    TF_RETURN_IF_ERROR(
        RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
                                      shape_representation_fn, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow