/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/tpu_compile.h"

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace tpu {
namespace {

std::string CoreDevice(int core) {
  return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
}

static constexpr char kArgOp[] = "_Arg";
static constexpr char kRetvalOp[] = "_Retval";

// Sets arg shape, arg core mapping, and per core arg shapes for a given
// argument, depending on its sharding.
Status SetPerCoreArgShapes(
    const tpu::TPUCompileMetadataProto::Arg& proto_arg, const int arg_index,
    xla::Shape* xla_arg_shape,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  if (proto_arg.unrestricted_layout()) {
    xla_arg_shape->clear_layout();
  }

  (*arg_core_mapping)[arg_index].sharding = proto_arg.sharding();
  if (proto_arg.sharding().type() == xla::OpSharding::MAXIMAL) {
    const int core = proto_arg.sharding().tile_assignment_devices(0);
    TF_RET_CHECK(0 <= core && core < per_core_arg_shapes->size());
    (*arg_core_mapping)[arg_index].indices.push_back(
        (*per_core_arg_shapes)[core].size());
    (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
  } else if (proto_arg.sharding().type() == xla::OpSharding::OTHER) {
    TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    for (int core : proto_arg.sharding().tile_assignment_devices()) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      xla::Shape per_core_shape =
          GetPerDeviceShape(*xla_arg_shape, hlo_sharding, core);
      if (proto_arg.unrestricted_layout()) {
        per_core_shape.clear_layout();
      }
      (*per_core_arg_shapes)[core].push_back(per_core_shape);
    }
  } else {
    TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED)
        << "Unsupported argument sharding: proto_arg="
        << proto_arg.DebugString();
    for (int core = 0; core < per_core_arg_shapes->size(); ++core) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
    }
  }

  return OkStatus();
}
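
// Illustrative sketch only (not used by this file): for a MAXIMAL sharding
// that pins a single f32[8,128] argument to core 1 of a two-core replica,
// SetPerCoreArgShapes records the argument's slot on that core and leaves the
// other core untouched. The setup below is hypothetical.
//
//   tpu::TPUCompileMetadataProto::Arg proto_arg;
//   proto_arg.mutable_sharding()->set_type(xla::OpSharding::MAXIMAL);
//   proto_arg.mutable_sharding()->add_tile_assignment_devices(1);
//   xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {8, 128});
//   std::vector<tpu::ShardingAndIndex> arg_core_mapping(1);
//   std::vector<std::vector<xla::Shape>> per_core_arg_shapes(2);
//   TF_CHECK_OK(SetPerCoreArgShapes(proto_arg, /*arg_index=*/0, &shape,
//                                   &arg_core_mapping, &per_core_arg_shapes));
//   // Now arg_core_mapping[0].indices == {0}, per_core_arg_shapes[0] is
//   // empty, and per_core_arg_shapes[1] holds the full f32[8,128] shape.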

// Adds TPU_REPLICATED_CORE device assignments to the _Arg and _Retval
// nodes in `graph`, using the sharding/index assignments in
// `arg_core_mapping` and `retval_core_mapping`. The mappings are maps from
// original argument/return index to (sharding, per-core argument/return
// index) pairs. Node attributes, such as device assignments, are not
// preserved on function argument and return value nodes, so we must recreate
// them from the compilation metadata.
Status AssignDevicesToArgsAndRetvals(
    absl::Span<const tpu::ShardingAndIndex> arg_core_mapping,
    absl::Span<const tpu::ShardingAndIndex> retval_core_mapping, Graph* graph) {
  auto assign = [&](Node* node, const xla::OpSharding& sharding) -> Status {
    if (sharding.type() == xla::OpSharding::MAXIMAL) {
      const string device = CoreDevice(sharding.tile_assignment_devices(0));
      node->set_assigned_device_name(device);
      node->set_requested_device(device);
    } else {
      TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED ||
                   sharding.type() == xla::OpSharding::OTHER)
          << "Unsupported sharding on parameter/retval: "
          << sharding.DebugString();
    }
    node->AddAttr("_XlaSharding", sharding.SerializeAsString());
    return OkStatus();
  };
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(
          tensorflow::GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < arg_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, arg_core_mapping[index].sharding));
    } else if (node->type_string() == kRetvalOp) {
      int index;
      TF_RETURN_IF_ERROR(
          tensorflow::GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < retval_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, retval_core_mapping[index].sharding));
    }
  }
  return OkStatus();
}
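
// Illustrative note (behavioral sketch, not exercised here): after this pass,
// an _Arg node whose sharding is MAXIMAL on core 3 ends up with both
// assigned_device_name() and requested_device() equal to
// "/device:TPU_REPLICATED_CORE:3", plus a "_XlaSharding" attribute holding
// the serialized xla::OpSharding proto. Nodes with REPLICATED or OTHER
// sharding are left with their device fields unchanged and only receive the
// "_XlaSharding" attribute.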

void ConvertGraphShapeInfoToShapeMap(
    const Graph& graph, const GraphShapeInfo& graph_shape_info,
    std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map) {
  // Builds a map from node name to Node* for `graph`.
  std::unordered_map<string, Node*> index;
  for (Node* node : graph.nodes()) {
    index[node->name()] = node;
  }
  // Discards the resource handle shape info while converting to the correct
  // map form.
  for (const auto& node_shape_info : graph_shape_info) {
    const string& node_name = node_shape_info.first;
    const std::vector<InferredShape>& output_shapes = node_shape_info.second;
    // Gets the vector of partial shapes, first converting node name to Node*
    // using index. graph is the subgraph of the original graph assigned to a
    // particular core, and we only add entries to shape_map for nodes in
    // graph_shape_info that are in the subgraph.
    const auto& node_iter = index.find(node_name);
    if (node_iter != index.end()) {
      auto& partial_shapes = (*shape_map)[node_name];
      for (const auto& inferred_shape : output_shapes) {
        partial_shapes.push_back(inferred_shape.shape);
      }
    }
  }
}

// Optimizes `graph`, given the argument descriptions in `metadata` and
// `arg_shapes`.
Status OptimizeGraph(const tpu::TPUCompileMetadataProto& metadata,
                     const std::vector<PartialTensorShape>& arg_shapes,
                     std::unique_ptr<Graph>* graph, FunctionLibraryRuntime* flr,
                     FunctionLibraryDefinition* fld) {
  // Sets up options for the optimization passes that need to be done. Notice
  // that CSE is not needed as XLA has its own CSE passes later in the
  // compilation stage.
  auto flags = GetBuildXlaOpsPassFlags();
  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_common_subexpression_elimination(false);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
  GraphOptimizer optimizer(opts);
  {
    // Performs a first function inlining pass before shape inference, since
    // otherwise shape inference can't see inside functions and a comprehensive
    // shape_map, including function ops, is needed to constant-propagate Shape
    // Ops below.
    GraphOptimizer::Options optimizer_opts;
    optimizer_opts.inline_multi_device_functions = true;
    optimizer_opts.inline_impl_selection_group_functions = true;
    optimizer_opts.inline_with_single_device_body_placer = true;
    // Infer shapes for each node in the computation. Shape inference can help
    // skip constant folding of large shapes.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(internal::RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    // Converts the GraphShapeInfo into the form needed by the constant-folding
    // pass of the optimizer.
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer_opts.shape_map = &shape_map;
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
  }

  {
    // Infer shapes for each node in the computation.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(internal::RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    GraphOptimizer::Options optimizer_opts;
    optimizer_opts.shape_map = &shape_map;
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
  }

  TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld));

  return OkStatus();
}

// Populates the mapping from return value to ShardingAndIndex.
Status AssignReturnValueToCore(
    const tpu::TPUCompileMetadataProto& metadata,
    std::vector<tpu::ShardingAndIndex>* retval_core_mapping) {
  std::vector<int> per_core_retval_counts(metadata.num_cores_per_replica(), 0);
  for (int i = 0; i < metadata.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& proto_retval =
        metadata.retvals(i);
    (*retval_core_mapping)[i].sharding = proto_retval.sharding();
    if (proto_retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      int core = proto_retval.sharding().tile_assignment_devices(0);
      TF_RET_CHECK(0 <= core && core < per_core_retval_counts.size());
      (*retval_core_mapping)[i].indices.push_back(
          per_core_retval_counts[core]++);
    } else if (proto_retval.sharding().type() == xla::OpSharding::OTHER) {
      for (int64_t core : proto_retval.sharding().tile_assignment_devices()) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    } else {
      TF_RET_CHECK(proto_retval.sharding().type() ==
                   xla::OpSharding::REPLICATED)
          << "Unsupported return value sharding: "
          << proto_retval.sharding().DebugString();
      for (int core = 0; core < per_core_retval_counts.size(); ++core) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    }
  }
  return OkStatus();
}
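
// Illustrative note (behavioral sketch, not exercised here): with two cores
// per replica, the first REPLICATED return value is assigned a slot on every
// core, so (*retval_core_mapping)[0].indices becomes {0, 0}; a MAXIMAL return
// value pinned to core 1 instead gets the single entry {0}, and only core 1's
// retval count advances.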

// Populates the arguments, core mapping, and per-core argument shapes for the
// computation.
Status BuildComputationArgumentDescriptions(
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const XlaCompiler& compiler,
    const tpu::TPUCompileMetadataProto& metadata,
    std::vector<XlaCompiler::Argument>* args,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  arg_core_mapping->clear();
  arg_core_mapping->resize(metadata.args_size());

  per_core_arg_shapes->clear();
  per_core_arg_shapes->resize(metadata.num_cores_per_replica());

  // Builds a description of the computation's arguments.
  int constant_count = 0;
  size_t guaranteed_constants_size = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
    args->push_back(XlaCompiler::Argument());
    XlaCompiler::Argument& arg = args->back();
    arg.type = proto_arg.dtype();
    arg.shape = arg_shapes[i];
    arg.node_name = proto_arg.name();
    switch (proto_arg.kind()) {
      case tpu::TPUCompileMetadataProto::Arg::PARAMETER:
        arg.kind = XlaCompiler::Argument::kParameter;
        break;
      case tpu::TPUCompileMetadataProto::Arg::VARIABLE:
        arg.kind = XlaCompiler::Argument::kResource;
        arg.resource_kind = XlaResource::kVariable;
        arg.initialized = true;
        arg.fast_mem = proto_arg.fast_mem();
        break;
      case tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT:
        arg.kind = XlaCompiler::Argument::kConstant;
        guaranteed_constants_size =
            guaranteed_constants.index() == 0
                ? absl::get<0>(guaranteed_constants).size()
                : absl::get<1>(guaranteed_constants)->size();
        TF_RET_CHECK(constant_count < guaranteed_constants_size)
            << "More constant args in TPUCompileMetadataProto than constant "
               "tensors.";
        if (guaranteed_constants.index() == 0) {
          // `guaranteed_constants` is of type `absl::Span<const TensorProto*
          // const>`.
          Tensor tensor;
          CHECK(tensor.FromProto(
              *absl::get<0>(guaranteed_constants)[constant_count++]))
              << "Failed to deserialize invalid `TensorProto` into `Tensor`.";
          arg.constant_value = tensor;
        } else {
          // `guaranteed_constants` is of type `const OpInputList* const`.
          arg.constant_value =
              (*absl::get<1>(guaranteed_constants))[constant_count++];
        }
        break;
      case tpu::TPUCompileMetadataProto::Arg::INVALID:
      default:
        break;
    }
    arg.is_same_data_across_replicas =
        proto_arg.is_same_data_across_replicas();
    arg.requires_broadcast = proto_arg.requires_xla_broadcast();
    if (arg.kind == XlaCompiler::Argument::kInvalid) {
      return errors::InvalidArgument("Invalid argument kind");
    }
    if (arg.kind == XlaCompiler::Argument::kConstant) {
      continue;
    }

    // Assign each argument a sharding.
    xla::Shape xla_arg_shape;
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_RETURN_IF_ERROR(compiler.XLAShapeForArgument(
        arg, /*is_entry_computation=*/true, arg_sharding, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  TF_RET_CHECK(constant_count == guaranteed_constants_size)
      << "Not all of the constant tensors were consumed.";

  return OkStatus();
}
}  // namespace

namespace internal {
Status RunShapeInferenceOnComputation(
    const tpu::TPUCompileMetadataProto& metadata,
    const std::vector<PartialTensorShape>& arg_shapes, Graph* graph,
    FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info) {
  int num_args = arg_shapes.size();
  CHECK_EQ(num_args, metadata.args_size());

  std::map<int, InferredShape> arg_shapes_for_inference;
  for (int i = 0; i < num_args; ++i) {
    const auto& arg = metadata.args(i);
    InferredShape& shape_for_inference = arg_shapes_for_inference[i];
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      // For resource variables, arg_shapes[] contains the shape of the
      // variable's value.
      shape_for_inference.handle_type = arg.dtype();
      shape_for_inference.handle_shape = arg_shapes[i];
      // The shape of the variable itself is always a scalar.
      shape_for_inference.shape = TensorShape();
    } else {
      if (arg.kind() ==
          tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
        VLOG(1) << "PromisedConstant shape: " << arg_shapes[i].DebugString();
      }
      shape_for_inference.shape = arg_shapes[i];
    }
  }
  return InferShapes(
      graph, arg_shapes_for_inference,
      flr != nullptr ? flr->GetFunctionLibraryDefinition() : nullptr,
      shape_info);
}
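
// Illustrative note (hypothetical values, for orientation only): for a
// VARIABLE argument whose value is a float matrix of shape [1024, 256], the
// InferredShape fed to shape inference has
//   shape        == TensorShape()           // the resource handle is a scalar
//   handle_type  == DT_FLOAT
//   handle_shape == TensorShape({1024, 256})
// whereas a PARAMETER of the same shape simply sets shape to [1024, 256].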
}  // namespace internal

Status CompileTFFunctionToHlo(
    const FunctionLibraryDefinition& flib_def, int graph_def_version,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const NameAttrList& function,
    const tpu::TPUCompileMetadataProto& metadata,
    xla::CompileOnlyClient* client,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
    bool use_tuple_args, XlaCompiler::CompilationResult* compilation_result) {
  XlaCompiler::Options compiler_options;
  FunctionLibraryDefinition flib_definition(flib_def);
  compiler_options.device_type = DeviceType(DEVICE_TPU_XLA_JIT);
  compiler_options.client = client;
  compiler_options.flib_def = &flib_definition;
  compiler_options.allow_cpu_custom_calls = false;
  compiler_options.graph_def_version = graph_def_version;
  compiler_options.shape_determination_fns = shape_determination_fns;

  auto compiler = std::make_unique<XlaCompiler>(compiler_options);

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(BuildComputationArgumentDescriptions(
      arg_shapes, guaranteed_constants, *compiler, metadata, &args,
      arg_core_mapping, per_core_arg_shapes));

  // Assign each return value to a core.
  std::vector<tpu::ShardingAndIndex> retval_core_mapping(
      metadata.retvals_size());
  TF_RETURN_IF_ERROR(AssignReturnValueToCore(metadata, &retval_core_mapping));

  LOG(INFO) << "Instantiating function: " << function.name();
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(compiler->flib_runtime()->Instantiate(
      function.name(), AttrSlice(&function.attr()), &handle));
  const FunctionBody* fbody = compiler->flib_runtime()->GetFunctionBody(handle);
  const string function_id =
      Canonicalize(function.name(), AttrSlice(&function.attr()));

  std::unique_ptr<Graph> graph(new Graph(&flib_definition));
  CopyGraph(*fbody->graph, graph.get());

  VLOG(2) << "metadata: " << metadata.DebugString();
  TF_RET_CHECK(fbody->arg_nodes.size() == args.size());
  for (size_t i = 0; i < fbody->arg_nodes.size(); i++) {
    args[i].node_name = fbody->arg_nodes[i]->name();
  }

  std::vector<gtl::InlinedVector<int64_t, 4>> arg_shape_dims;
  arg_shape_dims.reserve(arg_shapes.size());
  std::vector<PartialTensorShape> partial_arg_shapes(arg_shapes.size());
  for (const TensorShape& shape : arg_shapes) {
    arg_shape_dims.push_back(shape.dim_sizes());
  }

  for (int64_t i = 0; i < arg_shape_dims.size(); ++i) {
    auto& dims = arg_shape_dims[i];
    TF_RETURN_IF_ERROR(PartialTensorShape::MakePartialShape(
        dims.data(), dims.size(), &partial_arg_shapes[i]));
  }

  // Adds device assignments to _Arg and _Retval nodes.
  TF_RETURN_IF_ERROR(AssignDevicesToArgsAndRetvals(
      absl::MakeSpan(*arg_core_mapping), absl::MakeSpan(retval_core_mapping),
      graph.get()));

  VLOG(1) << "Optimizing TensorFlow graph";
  TF_RETURN_IF_ERROR(OptimizeGraph(metadata, partial_arg_shapes, &graph,
                                   compiler->flib_runtime(), &flib_definition));

  VLOG(1) << "Compiling TensorFlow graph to HLO";
  XlaCompiler::CompileOptions compile_options;
  compile_options.return_updated_values_for_all_resources = false;
  compile_options.use_tuple_arg = use_tuple_args;
  compile_options.is_entry_computation = true;
  compile_options.alias_resource_update = true;
  return compiler->CompileGraph(compile_options, function_id, std::move(graph),
                                args, compilation_result);
}
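
// Illustrative usage sketch (hypothetical caller; the surrounding variables
// are assumptions, not defined in this file):
//
//   std::vector<tpu::ShardingAndIndex> arg_core_mapping;
//   std::vector<std::vector<xla::Shape>> per_core_arg_shapes;
//   XlaCompiler::CompilationResult compilation_result;
//   TF_RETURN_IF_ERROR(CompileTFFunctionToHlo(
//       *flib_def, graph_def_version, shape_determination_fns, arg_shapes,
//       guaranteed_constants, function, metadata, client, &arg_core_mapping,
//       &per_core_arg_shapes, /*use_tuple_args=*/true, &compilation_result));
//   // compilation_result then holds the compiled XLA computation, while
//   // arg_core_mapping / per_core_arg_shapes describe how inputs are split
//   // across the cores of one replica.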

Status GetShardingInfo(
    const tpu::TPUCompileMetadataProto& metadata,
    absl::Span<const TensorShape> arg_shapes,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  arg_core_mapping->clear();
  arg_core_mapping->resize(metadata.args_size());

  per_core_arg_shapes->clear();
  per_core_arg_shapes->resize(metadata.num_cores_per_replica());

  int num_inputs = metadata.args_size();
  for (int i = 0; i < num_inputs; ++i) {
    const auto& proto_arg = metadata.args(i);
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    auto layout_preference = shape_determination_fns.layout_preference_fn(
        arg_shapes[i], proto_arg.dtype(), absl::nullopt);
    TF_ASSIGN_OR_RETURN(auto xla_arg_shape,
                        shape_determination_fns.shape_representation_fn(
                            arg_shapes[i], proto_arg.dtype(),
                            /*use_fast_memory=*/false, layout_preference));
    TF_RETURN_IF_ERROR(
        RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
                                      shape_determination_fns, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  return OkStatus();
}
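
// Illustrative usage sketch (hypothetical caller variables): compute only the
// sharding and per-core shape bookkeeping, without compiling the function
// body:
//
//   std::vector<tpu::ShardingAndIndex> arg_core_mapping;
//   std::vector<std::vector<xla::Shape>> per_core_arg_shapes;
//   TF_RETURN_IF_ERROR(GetShardingInfo(metadata, arg_shapes,
//                                      shape_determination_fns,
//                                      &arg_core_mapping,
//                                      &per_core_arg_shapes));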

}  // namespace tpu
}  // namespace tensorflow