/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"

#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tpu/compilation_result.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_unloader.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_op_util.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {

namespace {

static constexpr char kArgOp[] = "_Arg";
static constexpr char kRetvalOp[] = "_Retval";

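// Returns the TensorFlow device name string for the given TPU replicated
// core.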
std::string CoreDevice(int core) {
  return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
}

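// Converts `graph_shape_info` into a map from node name to the inferred
// output shapes, keeping only entries for nodes that are present in `graph`.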
void ConvertGraphShapeInfoToShapeMap(
    const Graph& graph, const GraphShapeInfo& graph_shape_info,
    std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map) {
  // Builds a map from node name to Node* for `graph`.
  std::unordered_map<string, Node*> index;
  for (Node* node : graph.nodes()) {
    index[node->name()] = node;
  }
  // Discards the resource handle shape info while converting to the correct
  // map form.
  for (const auto& node_shape_info : graph_shape_info) {
    const string& node_name = node_shape_info.first;
    const std::vector<InferredShape>& output_shapes = node_shape_info.second;
    // Gets the vector of partial shapes, first converting node name to Node*
    // using index. graph is the subgraph of the original graph assigned to a
    // particular core, and we only add entries to shape_map for nodes in
    // graph_shape_info that are in the subgraph.
    const auto& node_iter = index.find(node_name);
    if (node_iter != index.end()) {
      auto& partial_shapes = (*shape_map)[node_name];
      for (const auto& inferred_shape : output_shapes) {
        partial_shapes.push_back(inferred_shape.shape);
      }
    }
  }
}

// Sets arg shape, arg core mapping, and per core arg shapes for a given
// argument, depending on its sharding.
Status SetPerCoreArgShapes(
    const tpu::TPUCompileMetadataProto::Arg& proto_arg, const int arg_index,
    xla::Shape* xla_arg_shape,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  if (proto_arg.unrestricted_layout()) {
    xla_arg_shape->clear_layout();
  }

  (*arg_core_mapping)[arg_index].sharding = proto_arg.sharding();
  if (proto_arg.sharding().type() == xla::OpSharding::MAXIMAL) {
    const int core = proto_arg.sharding().tile_assignment_devices(0);
    TF_RET_CHECK(0 <= core && core < per_core_arg_shapes->size());
    (*arg_core_mapping)[arg_index].indices.push_back(
        (*per_core_arg_shapes)[core].size());
    (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
  } else if (proto_arg.sharding().type() == xla::OpSharding::OTHER) {
    TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    for (int core : proto_arg.sharding().tile_assignment_devices()) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      xla::Shape per_core_shape =
          GetPerDeviceShape(*xla_arg_shape, hlo_sharding, core);
      if (proto_arg.unrestricted_layout()) {
        per_core_shape.clear_layout();
      }
      (*per_core_arg_shapes)[core].push_back(per_core_shape);
    }
  } else {
    TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED)
        << "Unsupported argument sharding: "
        << " proto_arg=" << proto_arg.DebugString();
    for (int core = 0; core < per_core_arg_shapes->size(); ++core) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
    }
  }

  return Status::OK();
}

}  // namespace

CompileOpImplFactory* CompileOpImplFactory::factory_ = nullptr;

/* static */
CompileOpImplFactory* CompileOpImplFactory::Get() { return factory_; }

/* static */
void CompileOpImplFactory::Register(CompileOpImplFactory* factory) {
  CHECK_EQ(factory_, nullptr)
      << "CompileOpImplFactory can only be registered "
         "once and there can only be one factory active and used.";
  factory_ = factory;
}

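// Assigns each return value to a core according to its sharding, recording
// the per-core output indices in `retval_core_mapping`.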
Status TpuCompileOpKernelCommon::AssignReturnValueToCore(
    std::vector<tpu::ShardingAndIndex>* retval_core_mapping) {
  std::vector<int> per_core_retval_counts(metadata_.num_cores_per_replica(), 0);
  for (int i = 0; i < metadata_.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& proto_retval =
        metadata_.retvals(i);
    (*retval_core_mapping)[i].sharding = proto_retval.sharding();
    if (proto_retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      int core = proto_retval.sharding().tile_assignment_devices(0);
      TF_RET_CHECK(0 <= core && core < per_core_retval_counts.size());
      (*retval_core_mapping)[i].indices.push_back(
          per_core_retval_counts[core]++);
    } else if (proto_retval.sharding().type() == xla::OpSharding::OTHER) {
      for (int64 core : proto_retval.sharding().tile_assignment_devices()) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    } else {
      TF_RET_CHECK(proto_retval.sharding().type() ==
                   xla::OpSharding::REPLICATED)
          << "Unsupported return value sharding: "
          << proto_retval.sharding().DebugString();
      for (int core = 0; core < per_core_retval_counts.size(); ++core) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    }
  }
  return Status::OK();
}

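// Builds the XlaCompiler::Argument descriptions for the computation from the
// compile metadata and the concrete `arg_shapes`, materializing guaranteed
// constants and recording per-core argument shapes and core mappings.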
Status TpuCompileOpKernelCommon::BuildComputationArgumentDescriptions(
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const XlaCompiler& compiler,
    std::vector<XlaCompiler::Argument>* args,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  // Builds a description of the computation's arguments.
  int constant_count = 0;
  size_t guaranteed_constants_size = 0;
  for (int i = 0; i < metadata_.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata_.args(i);
    args->push_back(XlaCompiler::Argument());
    XlaCompiler::Argument& arg = args->back();
    arg.type = proto_arg.dtype();
    arg.shape = arg_shapes[i];
    arg.node_name = proto_arg.name();
    switch (proto_arg.kind()) {
      case tpu::TPUCompileMetadataProto::Arg::PARAMETER:
        arg.kind = XlaCompiler::Argument::kParameter;
        break;
      case tpu::TPUCompileMetadataProto::Arg::VARIABLE:
        arg.kind = XlaCompiler::Argument::kResource;
        arg.resource_kind = XlaResource::kVariable;
        arg.initialized = true;
        arg.fast_mem = proto_arg.fast_mem();
        break;
      case tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT:
        arg.kind = XlaCompiler::Argument::kConstant;
        guaranteed_constants_size =
            guaranteed_constants.index() == 0
                ? absl::get<0>(guaranteed_constants).size()
                : absl::get<1>(guaranteed_constants)->size();
        TF_RET_CHECK(constant_count < guaranteed_constants_size)
            << "More constant args in TPUCompileMetadataProto than constant "
               "tensors.";
        if (guaranteed_constants.index() == 0) {
          // `guaranteed_constants` is of type `absl::Span<const TensorProto*
          // const>`.
          Tensor tensor;
          CHECK(tensor.FromProto(
              *absl::get<0>(guaranteed_constants)[constant_count++]))
              << "Failed to deserialize invalid `TensorProto` into `Tensor`.";
          arg.constant_value = tensor;
        } else {
          // `guaranteed_constants` is of type `const OpInputList* const`.
          arg.constant_value =
              (*absl::get<1>(guaranteed_constants))[constant_count++];
        }
        break;
      case tpu::TPUCompileMetadataProto::Arg::INVALID:
      default:
        break;
    }
    arg.is_same_data_across_replicas = proto_arg.is_same_data_across_replicas();
    if (arg.kind == XlaCompiler::Argument::kInvalid) {
      return errors::InvalidArgument("Invalid argument kind");
    }
    if (arg.kind == XlaCompiler::Argument::kConstant) {
      continue;
    }

    // Assign each argument a sharding.
    xla::Shape xla_arg_shape;
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_RETURN_IF_ERROR(compiler.XLAShapeForArgument(
        arg, /*is_entry_computation=*/true, arg_sharding, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  TF_RET_CHECK(constant_count == guaranteed_constants_size)
      << "Not all of the constant tensors were consumed.";

  return Status::OK();
}

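// Computes the per-core argument shapes and the argument-to-core mapping from
// the compile metadata, using `shape_representation_fn` to derive each XLA
// shape, without running a full compilation.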
Status TpuCompileOpKernelCommon::GetShardingInfo(
    absl::Span<const TensorShape> arg_shapes,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  int num_inputs = metadata_.args_size();
  for (int i = 0; i < num_inputs; ++i) {
    const auto& proto_arg = metadata_.args(i);
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_ASSIGN_OR_RETURN(
        auto xla_arg_shape,
        shape_representation_fn(arg_shapes[i], proto_arg.dtype(),
                                /*use_fast_memory=*/false));
    TF_RETURN_IF_ERROR(
        RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
                                      shape_representation_fn, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  return Status::OK();
}

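// Lowers the TensorFlow function `function` to HLO: builds the XLA argument
// descriptions, assigns arguments and return values to cores, applies the
// dynamic-padding metadata, optimizes the graph, and finally invokes
// XlaCompiler::CompileGraph to produce `compilation_result`.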
Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
    const FunctionLibraryDefinition& flib_def, int graph_def_version,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const NameAttrList& function,
    std::function<Status(ResourceMgr*)> populate_resource_manager_fn,
    xla::CompileOnlyClient* client,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
    XlaCompiler::CompilationResult* compilation_result) {
  XlaCompiler::Options compiler_options;
  compiler_options.device_type = DeviceType(DEVICE_TPU_XLA_JIT);
  compiler_options.client = client;
  compiler_options.flib_def = &flib_def;
  compiler_options.allow_cpu_custom_calls = false;
  compiler_options.populate_resource_manager = &populate_resource_manager_fn;
  compiler_options.graph_def_version = graph_def_version;
  compiler_options.shape_representation_fn = shape_representation_fn;

  auto compiler = absl::make_unique<XlaCompiler>(compiler_options);

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(BuildComputationArgumentDescriptions(
      arg_shapes, guaranteed_constants, *compiler, &args, arg_core_mapping,
      per_core_arg_shapes));

  // Assign each return value to a core.
  std::vector<tpu::ShardingAndIndex> retval_core_mapping(
      metadata_.retvals_size());
  TF_RETURN_IF_ERROR(
      TpuCompileOpKernelCommon::AssignReturnValueToCore(&retval_core_mapping));

  LOG(INFO) << "Instantiating function: " << function.name();
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(compiler->flib_runtime()->Instantiate(
      function.name(), AttrSlice(&function.attr()), &handle));
  const FunctionBody* fbody = compiler->flib_runtime()->GetFunctionBody(handle);
  const string function_id =
      Canonicalize(function.name(), AttrSlice(&function.attr()));

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  CopyGraph(*fbody->graph, graph.get());

  VLOG(2) << "metadata: " << metadata_.DebugString();
  std::vector<int> parameter_arg_mapping;
  for (int i = 0; i < args.size(); i++) {
    XlaCompiler::Argument& arg = args[i];
    if (arg.kind != XlaCompiler::Argument::kParameter) {
      continue;
    }
    parameter_arg_mapping.push_back(i);
  }
  TF_RET_CHECK(fbody->arg_nodes.size() == args.size());
  for (size_t i = 0; i < fbody->arg_nodes.size(); i++) {
    args[i].node_name = fbody->arg_nodes[i]->name();
  }

  std::vector<gtl::InlinedVector<int64, 4>> arg_shape_dims;
  arg_shape_dims.reserve(arg_shapes.size());
  std::vector<PartialTensorShape> partial_arg_shapes(arg_shapes.size());
  for (const TensorShape& shape : arg_shapes) {
    arg_shape_dims.push_back(shape.dim_sizes());
  }

  for (const auto& padding_mapping : metadata_.padding_maps()) {
    if (padding_mapping.padding_arg_index() >= parameter_arg_mapping.size()) {
      return errors::Internal(absl::StrCat(
          "TPUCompileMetadataProto `padding_maps` has `padding_arg_index` ",
          padding_mapping.padding_arg_index(),
          " which exceeds `parameter_arg_mapping` array bounds ",
          parameter_arg_mapping.size(),
          ". This usually indicates there are dynamic shape inputs fed into "
          "TPUs from outside compilation head extraction, which is not "
          "supported."));
    }
    int padding_arg_index =
        parameter_arg_mapping.at(padding_mapping.padding_arg_index());
    args[parameter_arg_mapping.at(padding_mapping.arg_index())]
        .dynamic_dim_to_arg_num_map[padding_mapping.shape_index()] =
        padding_arg_index;
    arg_shape_dims[parameter_arg_mapping.at(padding_mapping.arg_index())]
                  [padding_mapping.shape_index()] = -1;
    args[padding_arg_index].is_pad_arg = true;
  }

  for (int64 i = 0; i < arg_shape_dims.size(); ++i) {
    auto& dims = arg_shape_dims[i];
    TF_RETURN_IF_ERROR(PartialTensorShape::MakePartialShape(
        dims.data(), dims.size(), &partial_arg_shapes[i]));
  }

  // Adds device assignments to _Arg and _Retval nodes.
  TF_RETURN_IF_ERROR(AssignDevicesToArgsAndRetvals(
      absl::MakeSpan(*arg_core_mapping), absl::MakeSpan(retval_core_mapping),
      graph.get()));

  VLOG(1) << "Optimizing TensorFlow graph";
  FunctionLibraryDefinition flib_definition(flib_def);
  TF_RETURN_IF_ERROR(OptimizeGraph(metadata_, partial_arg_shapes, &graph,
                                   compiler->flib_runtime(), &flib_definition));

  VLOG(1) << "Compiling TensorFlow graph to HLO";
  XlaCompiler::CompileOptions compile_options;
  compile_options.return_updated_values_for_all_resources = false;
  compile_options.use_tuple_arg = true;
  compile_options.is_entry_computation = true;
  compile_options.alias_resource_update = true;
  return compiler->CompileGraph(compile_options, function_id, std::move(graph),
                                args, compilation_result);
}

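// Waits for a grace period after the TpuCompileOp is cancelled and, if
// compilation still has not finished, aborts the process to ensure a
// consistent state.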
/* static */ void TpuCompileOpKernelCommon::ExitCountdown(
    Env* env, std::shared_ptr<std::atomic<bool>> done) {
  const int kSleepSeconds = 300;
  LOG(INFO) << "TpuCompileOp was cancelled. Sleeping for " << kSleepSeconds
            << " seconds to give the TpuCompileOp time to finish.";
  env->SleepForMicroseconds(kSleepSeconds * 1000000);
  if (done->load()) {
    // If the TpuCompileOp has finished, then terminate peacefully.
    return;
  }

  LOG(ERROR) << "Aborting process due to cancelled TpuCompileOp. This "
             << "termination is to ensure a consistent state.";
  std::exit(42);
}

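// Reads the op's "dynamic_shapes" input list and converts each shape tensor
// into a TensorShape.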
/* static */ Status TpuCompileOpKernelCommon::GetDynamicShapes(
    OpKernelContext* ctx, std::vector<TensorShape>* shapes) {
  OpInputList dynamic_shapes;
  TF_RETURN_IF_ERROR(ctx->input_list("dynamic_shapes", &dynamic_shapes));

  shapes->resize(dynamic_shapes.size());
  for (int i = 0; i < dynamic_shapes.size(); ++i) {
    TF_RETURN_IF_ERROR(
        tpu::ShapeTensorToTensorShape(dynamic_shapes[i], &(*shapes)[i]));
  }
  return Status::OK();
}

// Function arguments and return values lose their device assignments, so we
// must recreate them.
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(
    absl::Span<const tpu::ShardingAndIndex> arg_core_mapping,
    absl::Span<const tpu::ShardingAndIndex> retval_core_mapping, Graph* graph) {
  auto assign = [&](Node* node, const xla::OpSharding& sharding) -> Status {
    if (sharding.type() == xla::OpSharding::MAXIMAL) {
      const string device = CoreDevice(sharding.tile_assignment_devices(0));
      node->set_assigned_device_name(device);
      node->set_requested_device(device);
    } else {
      TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED ||
                   sharding.type() == xla::OpSharding::OTHER)
          << "Unsupported sharding on parameter/retval: "
          << sharding.DebugString();
    }
    node->AddAttr("_XlaSharding", sharding.SerializeAsString());
    return Status::OK();
  };
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < arg_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, arg_core_mapping[index].sharding));
    } else if (node->type_string() == kRetvalOp) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < retval_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, retval_core_mapping[index].sharding));
    }
  }
  return Status::OK();
}

// Performs shape inference on the body of `graph`. Shapes for arguments
// are taken from `metadata` and `arg_shapes`.
/* static */ Status TpuCompileOpKernelCommon::RunShapeInferenceOnComputation(
    const tpu::TPUCompileMetadataProto& metadata,
    const std::vector<PartialTensorShape>& arg_shapes, Graph* graph,
    FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info) {
  int num_args = arg_shapes.size();
  CHECK_EQ(num_args, metadata.args_size());

  std::map<int, InferredShape> arg_shapes_for_inference;
  for (int i = 0; i < num_args; ++i) {
    const auto& arg = metadata.args(i);
    InferredShape& shape_for_inference = arg_shapes_for_inference[i];
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      // For resource variables, arg_shapes[] contains the shape of the
      // variable's value.
      shape_for_inference.handle_type = arg.dtype();
      shape_for_inference.handle_shape = arg_shapes[i];
      // The shape of the variable itself is always a scalar.
      shape_for_inference.shape = TensorShape();
    } else {
      if (arg.kind() ==
          tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
        VLOG(1) << "PromisedConstant shape: " << arg_shapes[i].DebugString();
      }
      shape_for_inference.shape = arg_shapes[i];
    }
  }
  return InferShapes(
      graph, arg_shapes_for_inference,
      flr != nullptr ? flr->GetFunctionLibraryDefinition() : nullptr,
      shape_info);
}

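// Optimizes `graph` before compilation: runs function inlining, shape
// inference, and constant folding (guided by the inferred shapes), then
// rewrites TensorList ops that carry constant elements.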
Status TpuCompileOpKernelCommon::OptimizeGraph(
    const tpu::TPUCompileMetadataProto& metadata,
    const std::vector<PartialTensorShape>& arg_shapes,
    std::unique_ptr<Graph>* graph, FunctionLibraryRuntime* flr,
    FunctionLibraryDefinition* fld) {
  // Sets up options for the optimization passes that need to be done. Notice
  // that CSE is not needed as XLA has its own CSE passes later in the
  // compilation stage.
  auto flags = GetBuildXlaOpsPassFlags();
  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_common_subexpression_elimination(false);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
  GraphOptimizer optimizer(opts);
  {
    // Performs a first function inlining pass before shape inference, since
    // otherwise shape inference can't see inside functions and a comprehensive
    // shape_map, including function ops, is needed to constant-propagate Shape
    // Ops below.
    GraphOptimizer::Options optimizer_opts;
    optimizer_opts.inline_multi_device_functions = true;
    optimizer_opts.inline_impl_selection_group_functions = true;
    optimizer_opts.inline_with_single_device_body_placer = true;
    // Infer shapes for each node in the computation. Shape inference can help
    // skip constant folding of large shapes.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    // Converts the GraphShapeInfo into the form needed by the constant-folding
    // pass of the optimizer.
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer_opts.shape_map = &shape_map;
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
  }

  {
    // Infer shapes for each node in the computation.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map);
  }

  TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld));

  return Status::OK();
}

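// Entry point of the op. Registers a cancellation callback that aborts the
// process if compilation is cancelled mid-flight, runs ComputeInternal, and
// attaches the compilation result as a status payload on failure.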
void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
  VLOG(1) << "Cloud TPU: TpuCompileOpKernelCommon::Compute";

  std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));

  CancellationToken token =
      ctx->cancellation_manager()->get_cancellation_token();
  const bool already_cancelled =
      !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
        if (OpsApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
          return;
        }

        // Sleep and exit in another thread so the cancellation manager can
        // continue running callbacks.
        Env* env = ctx->env();
        env->SchedClosure([env, done]() { ExitCountdown(env, done); });
      });

  // If the RPC was cancelled before we registered the cancellation callback,
  // don't compile the TPU program.
  OP_REQUIRES(ctx, !already_cancelled,
              errors::Cancelled("RPC cancelled, not compiling TPU program"));

  // We only want to abort the process if a cancellation actually occurs during
  // compilation; we must deregister the callback in the success case. It
  // doesn't hurt to also deregister the callback in the failure case; the
  // CancellationManager ensures that already-registered callbacks will be run
  // once cancellation has started.
  auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
    ctx->cancellation_manager()->DeregisterCallback(token);
    done->store(true);
  });

  Status compile_status = ComputeInternal(ctx);
  string status_payload;
  // Construct payload if compile_status is not ok and there's no payload for
  // compilation yet.
  if (!compile_status.ok() &&
      compile_status.GetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey)
          .empty()) {
    tpu::CompilationResultProto proto;
    proto.set_status_code(compile_status.code());
    proto.set_status_error_message(compile_status.error_message());
    status_payload = proto.SerializeAsString();
  }
  OP_REQUIRES_OK_OR_SET_PAYLOAD(ctx,
                                TpuCompileInterface::kTpuCompileErrorPayloadKey,
                                status_payload, compile_status);
}

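// Compiles the TPU program locally (via the MLIR or function-based bridge)
// into `tpu_program_group`, which backs the host-memory compilation cache,
// and records compilation metrics.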
Status TpuCompileOpKernelCommon::CompileLocallyAndFillHostCache(
    FunctionLibraryRuntime* flib_runtime,
    const SessionMetadata* session_metadata,
    const TpuMeshStateInterface* mesh_state,
    const std::vector<TensorShape>& dynamic_shapes,
    const OpInputList& guaranteed_constants, const TpuCompilationCacheKey& key,
    TpuProgramGroupInterface* tpu_program_group) {
  absl::Time start_time = absl::Now();
  std::vector<TensorShape> arg_shapes;
  TF_RETURN_IF_ERROR(
      ComputeArgumentShapes(metadata_, dynamic_shapes, &arg_shapes));
  Status compile_status;
  if (use_mlir_) {
    compile_status = Compile(MlirToHloArgs{mlir_module_}, mesh_state->data(),
                             arg_shapes, tpu_program_group);
  } else {
    compile_status =
        Compile(FunctionToHloArgs{&function_,
                                  flib_runtime->GetFunctionLibraryDefinition(),
                                  flib_runtime->graph_def_version(),
                                  {&guaranteed_constants}},
                mesh_state->data(), arg_shapes, tpu_program_group);
  }

  absl::Time end_time = absl::Now();
  auto duration = end_time - start_time;

  const std::string session_name = SessionNameFromMetadata(session_metadata);
  LOG(INFO) << "Compilation of " << key.prefix << " with session name "
            << session_name << " took " << duration << " and "
            << (compile_status.ok() ? "succeeded" : "failed");
  tpu_program_group->LogProgramMemorySummary();
  metrics::UpdateXlaCompilationTime(absl::ToInt64Microseconds(duration));
  TpuCompilationMetrics::IncrementCompilationCount(session_name);

  TF_RETURN_IF_ERROR(tpu_program_group->LogCompilationStats(key, duration));

  return compile_status;
}

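// Looks up the TPU mesh state and the process-wide compilation cache,
// compiles the program if the cache key is absent, and populates the op's
// outputs: a compilation-status proto, per-core program/rendezvous/sharding
// keys, and (for the non-MLIR bridge) the may-modify-variables flags.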
Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
  VLOG(1) << "Retrieving mesh state";
  // Retrieve the topology from the resource manager.
  ResourceMgr* rm = GetTPUConfigResourceMgr();

  TpuMeshStateInterface* mesh_state;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                kTpuMeshStateInterfaceResourceName,
                                &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  std::vector<TensorShape> dynamic_shapes;
  TF_RETURN_IF_ERROR(GetDynamicShapes(ctx, &dynamic_shapes));

  OpInputList guaranteed_constants;
  // TODO(ycao): Decide whether/how to support guaranteed constants in
  // MLIR-based TF-Compiler Bridge.
  if (!use_mlir_) {
    TF_RETURN_IF_ERROR(
        ctx->input_list("guaranteed_constants", &guaranteed_constants));
  }

  const TpuCompilationCacheKey key = CreateCompilationCacheKey(
      function_.name(), metadata_.function_library_fingerprint(),
      mlir_module_fingerprint_, guaranteed_constants, dynamic_shapes, metadata_,
      *mesh_state);

  // Process-wide cache of TPU executables.
  TpuCompilationCacheInterface* cache;
  TF_RETURN_IF_ERROR(rm->Lookup<TpuCompilationCacheInterface>(
      rm->default_container(), kCompilationCacheResourceName, &cache));
  core::ScopedUnref cache_unref(cache);

  // Per-step object that ensures that compilation cache entries aren't
  // evicted until the step completes. This mechanism ensures that the
  // downstream TPUExecute Ops in this step will be able to look up the
  // compiled executable even if it is marked for eviction before the step
  // ends.
  //
  // We can't use GetTPUConfigResourceMgr here because it may return the
  // global ResourceMgr, which is not associated with any device, and
  // GraphMgr's ScopedStepContainer only searches ResourceMgrs associated
  // with devices when deleting resources at step boundaries.
  CompilationRefHolder* ref_holder;
  if (ctx->step_container() == nullptr) {
    return errors::FailedPrecondition(
        "TPUCompileOp requires a step container.");
  }
  TF_RETURN_IF_ERROR(
      ctx->step_container()->LookupOrCreate<CompilationRefHolder>(
          ctx->resource_manager(), "ref_holder", &ref_holder,
          [cache](CompilationRefHolder** h) {
            *h = cache->MakePerStepRefHolder();
            return Status::OK();
          }));
  core::ScopedUnref ref_holder_unref(ref_holder);

  int64 uid;
  std::vector<std::string> proto_key;
  std::vector<std::string> sharding_key;
  std::vector<bool> may_modify_variables;
  absl::Span<const xla::HloProto* const> hlo_metadatas;
  Status status = cache->CompileIfKeyAbsent(
      key, ctx->session_metadata(), ref_holder, &uid, &proto_key, &sharding_key,
      &may_modify_variables, &hlo_metadatas,
      [&](TpuProgramGroupInterface* tpu_program_group) {
        VLOG(1) << "Cloud TPU: Compiling TPU program";
        // When this compile function is invoked, we know that the host-memory
        // cache (TpuCompilationCache) saw a cache miss. There are two
        // codepaths:
        // 1. If the persistent cache is disabled, compile locally and populate
        //    the host-memory cache.
        // 2. If the persistent cache is enabled, do an additional lookup on
        //    the persistent cache.
        //    - If the persistent cache also sees a cache miss, trigger
        //      compilation. Then, populate both the persistent cache and the
        //      host-memory cache.
        //    - If the persistent cache sees a cache hit, retrieve the cache
        //      entry from the persistent cache to populate the host-memory
        //      cache without recompilation. If retrieval fails, compile
        //      locally as a fallback and use the local compilation result to
        //      populate the host-memory cache.
        if (persistent_cache_ == nullptr) {
          VLOG(1) << "Persistent compilation cache not enabled. Compiling "
                     "TPU executable locally and populating host-memory cache.";
          return CompileLocallyAndFillHostCache(
              ctx->function_library(), ctx->session_metadata(), mesh_state,
              dynamic_shapes, guaranteed_constants, key, tpu_program_group);
        }
        return LookupPersistentCompilationCacheAndFillCaches(
            ctx->function_library(), ctx->session_metadata(), mesh_state,
            dynamic_shapes, guaranteed_constants, persistent_cache_.get(), key,
            tpu_program_group);
      });

  // `ref_holder` is provided to CompileIfKeyAbsent to ensure that the cache
  // entry does not get evicted before TpuExecuteOp runs it and discards
  // `ref_holder`. When the TpuCompilationCacheEntryUnloader gets destroyed
  // because the user closes the session while there are in-flight program
  // executions, it discards the cache's reference to the cache entry but does
  // not remove the entry until `ref_holder` drops the last reference to it.
  // This ensures that the guarantees of `ref_holder` are not violated when
  // this flag is true.
  if (unload_cache_entry_on_session_close_) {
    // Place `unloader` in the TPU_SYSTEM device resource manager. Note that
    // - the TPUConfigResourceMgr returned by GetTPUConfigResourceMgr() is a
    //   special process-global ResourceMgr. There is only one
    //   TPUConfigResourceMgr, and it is never destroyed.
    // - the TPU_SYSTEM device resource manager is a normal device ResourceMgr
    //   for the TPU_SYSTEM device. If DirectSession or isolate_session_state
    //   are used, there's one TPU_SYSTEM ResourceMgr for each session, and the
    //   ResourceMgrs will be destroyed when their corresponding session is
    //   closed. Otherwise there's one TPU_SYSTEM ResourceMgr that's only
    //   destroyed when the master-session is destroyed, not when the worker
    //   sessions are destroyed.
    TpuCompilationCacheEntryUnloader* unloader;
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()
            ->LookupOrCreate<TpuCompilationCacheEntryUnloader>(
                ctx->resource_manager()->default_container(),
                kCompilationCacheUnloaderResourceName, &unloader,
                [cache](TpuCompilationCacheEntryUnloader** new_unloader) {
                  *new_unloader = new TpuCompilationCacheEntryUnloader(cache);
                  return Status::OK();
                }));
    // Note that LookupOrCreate puts two refcounts on unloader.
    core::ScopedUnref unloader_unref(unloader);
    unloader->AddCacheEntryUid(uid);
  }

  int64 num_cores_with_compiled_programs = proto_key.size();
  if (proto_key.size() == 1) {
    // SPMD produces 1 program for all cores.
    num_cores_with_compiled_programs = metadata_.num_cores_per_replica();
  }
  if (status.ok() &&
      num_cores_with_compiled_programs +
              (may_modify_variables.size() * static_cast<int>(!use_mlir_)) !=
          ctx->num_outputs() - 1) {
    status = errors::Internal(
        "Number of cores with compiled programs (",
        num_cores_with_compiled_programs, ") + variable states (",
        may_modify_variables.size() * static_cast<int>(!use_mlir_),
        ") + compilation status output != number of compile op outputs (",
        ctx->num_outputs(), ")");
  }

  // TODO(jpienaar): status is not just due to the compilation. At this
  // point we should be failing the execution of the op in some cases and
  // returning a compilation error in others. For now, uniformly return an
  // error and fail in _TPUExecute if status failed here.

  // TODO(misard) the frame id will be wrong if this is ever called from
  // within a function. Consider whether to use the same hack as is
  // present in the rendezvous manager where the function call frame is
  // cast to a uint64, or do something better all around.
  std::string rendezvous_key_base = strings::StrCat(
      "host_compute_rendezvous:", ctx->op_kernel().name(), ":",
      ctx->frame_iter().frame_id, ":", ctx->frame_iter().iter_id, ":");

  // Return compilation status.
  {
    Tensor output(DT_STRING, TensorShape({}));
    tpu::CompilationResultProto proto;
    proto.set_status_code(status.code());
    if (!status.ok()) {
      proto.set_status_error_message(
          absl::StrCat("Compilation failure: ", status.error_message()));
    }
    if (return_hlo_protos_) {
      // Return the HloProtos as part of compilation status.
      for (const xla::HloProto* hlo_metadata : hlo_metadatas) {
        xla::HloProto* hlo_proto = proto.add_hlo_protos();
        *hlo_proto = *hlo_metadata;
      }
    }
    SerializeToTString(proto, &output.scalar<tstring>()());
    ctx->set_output(0, output);
    status.SetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey,
                      output.scalar<tstring>()());
  }

  if (status.ok()) {
    for (int i = 0; i < num_cores_with_compiled_programs; ++i) {
      Tensor output(DT_STRING, TensorShape({3}));
      if (proto_key.size() == 1) {
        output.vec<tstring>()(0) = proto_key[0];
      } else {
        output.vec<tstring>()(0) = proto_key[i];
      }
      output.vec<tstring>()(1) = rendezvous_key_base;
      if (sharding_key.empty()) {
        output.vec<tstring>()(2) = "";
      } else if (sharding_key.size() == 1) {
        output.vec<tstring>()(2) = sharding_key[0];
      } else {
        TF_RET_CHECK(sharding_key.size() == num_cores_with_compiled_programs);
        output.vec<tstring>()(2) = sharding_key[i];
      }
      ctx->set_output(i + 1, output);
    }
    if (!use_mlir_) {
      // If any of the programs may modify a variable, then report that all
      // do, as the only state currently tracked here is whether a model is
      // read-only or not.
      bool may_modify = false;
      for (bool m : may_modify_variables) {
        may_modify = may_modify || m;
      }
      for (int i = 0; i < may_modify_variables.size(); ++i) {
        Tensor output(DT_BOOL, TensorShape({}));
        output.scalar<bool>()() = may_modify;
        ctx->set_output(i + num_cores_with_compiled_programs + 1, output);
      }
    }
    VLOG(1) << "Cloud TPU: Compilation succeeded";
  } else {
    // Return an error in the invalid case.
    for (int i = 0; i < num_computations_; ++i) {
      Tensor output(DT_STRING, TensorShape({3}));
      output.vec<tstring>()(0) = "<<NO PROGRAM AS COMPILATION FAILED>>";
      output.vec<tstring>()(1) = "<<NO RENDEZVOUS KEY AS COMPILATION FAILED>>";
      output.vec<tstring>()(2) = "<<NO SHARDING KEY AS COMPILATION FAILED>>";
      ctx->set_output(i + 1, output);
    }
    if (!use_mlir_) {
      // The TPUCompileMLIR op does not have a MayModifyVariable output.
      for (int i = 0; i < num_computations_; ++i) {
        Tensor output(false);
        ctx->set_output(i + num_computations_ + 1, output);
      }
    }
  }
  return status;
}

}  // namespace tpu
}  // namespace tensorflow