/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"

#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tpu/compilation_result.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_unloader.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_op_util.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {

namespace {

static constexpr char kArgOp[] = "_Arg";
static constexpr char kRetvalOp[] = "_Retval";

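// Returns the device string for the given TPU replicated core index, e.g.
// "/device:TPU_REPLICATED_CORE:0".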
std::string CoreDevice(int core) {
  return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
}

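// Converts a GraphShapeInfo into the node-name-to-output-shapes map that the
// GraphOptimizer's constant-folding pass consumes, keeping only nodes that
// are present in `graph`.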
void ConvertGraphShapeInfoToShapeMap(
    const Graph& graph, const GraphShapeInfo& graph_shape_info,
    std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map) {
  // Builds a map from node name to Node* for `graph`.
  std::unordered_map<string, Node*> index;
  for (Node* node : graph.nodes()) {
    index[node->name()] = node;
  }
  // Discards the resource handle shape info while converting to the correct map
  // form.
  for (const auto& node_shape_info : graph_shape_info) {
    const string& node_name = node_shape_info.first;
    const std::vector<InferredShape>& output_shapes = node_shape_info.second;
    // Gets the vector of partial shapes, first converting node name to Node*
    // using index. graph is the subgraph of the original graph assigned to a
    // particular core, and we only add entries to shape_map for nodes in
    // graph_shape_info that are in the subgraph.
    const auto& node_iter = index.find(node_name);
    if (node_iter != index.end()) {
      auto& partial_shapes = (*shape_map)[node_name];
      for (const auto& inferred_shape : output_shapes) {
        partial_shapes.push_back(inferred_shape.shape);
      }
    }
  }
}

// Sets arg shape, arg core mapping, and per core arg shapes for a given
// argument, depending on its sharding.
Status SetPerCoreArgShapes(
    const tpu::TPUCompileMetadataProto::Arg& proto_arg, const int arg_index,
    xla::Shape* xla_arg_shape,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  if (proto_arg.unrestricted_layout()) {
    xla_arg_shape->clear_layout();
  }

  (*arg_core_mapping)[arg_index].sharding = proto_arg.sharding();
  if (proto_arg.sharding().type() == xla::OpSharding::MAXIMAL) {
    const int core = proto_arg.sharding().tile_assignment_devices(0);
    TF_RET_CHECK(0 <= core && core < per_core_arg_shapes->size());
    (*arg_core_mapping)[arg_index].indices.push_back(
        (*per_core_arg_shapes)[core].size());
    (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
  } else if (proto_arg.sharding().type() == xla::OpSharding::OTHER) {
    TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    for (int core : proto_arg.sharding().tile_assignment_devices()) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      xla::Shape per_core_shape =
          GetPerDeviceShape(*xla_arg_shape, hlo_sharding, core);
      if (proto_arg.unrestricted_layout()) {
        per_core_shape.clear_layout();
      }
      (*per_core_arg_shapes)[core].push_back(per_core_shape);
    }
  } else {
    TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED)
        << "Unsupported argument sharding: "
        << " proto_arg=" << proto_arg.DebugString();
    for (int core = 0; core < per_core_arg_shapes->size(); ++core) {
      (*arg_core_mapping)[arg_index].indices.push_back(
          (*per_core_arg_shapes)[core].size());
      (*per_core_arg_shapes)[core].push_back(*xla_arg_shape);
    }
  }

  return Status::OK();
}

}  // namespace

CompileOpImplFactory* CompileOpImplFactory::factory_ = nullptr;

/* static */
CompileOpImplFactory* CompileOpImplFactory::Get() { return factory_; }

/* static */
void CompileOpImplFactory::Register(CompileOpImplFactory* factory) {
  CHECK_EQ(factory_, nullptr)
      << "CompileOpImplFactory can only be registered "
         "once and there can only be one factory active and used.";
  factory_ = factory;
}

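// Assigns each return value in `metadata_` to one or more logical cores
// according to its sharding (maximal, tiled, or replicated), recording the
// per-core output index in `retval_core_mapping`.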
Status TpuCompileOpKernelCommon::AssignReturnValueToCore(
    std::vector<tpu::ShardingAndIndex>* retval_core_mapping) {
  std::vector<int> per_core_retval_counts(metadata_.num_cores_per_replica(), 0);
  for (int i = 0; i < metadata_.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& proto_retval =
        metadata_.retvals(i);
    (*retval_core_mapping)[i].sharding = proto_retval.sharding();
    if (proto_retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      int core = proto_retval.sharding().tile_assignment_devices(0);
      TF_RET_CHECK(0 <= core && core < per_core_retval_counts.size());
      (*retval_core_mapping)[i].indices.push_back(
          per_core_retval_counts[core]++);
    } else if (proto_retval.sharding().type() == xla::OpSharding::OTHER) {
      for (int64 core : proto_retval.sharding().tile_assignment_devices()) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    } else {
      TF_RET_CHECK(proto_retval.sharding().type() ==
                   xla::OpSharding::REPLICATED)
          << "Unsupported return value sharding: "
          << proto_retval.sharding().DebugString();
      for (int core = 0; core < per_core_retval_counts.size(); ++core) {
        (*retval_core_mapping)[i].indices.push_back(
            per_core_retval_counts[core]++);
      }
    }
  }
  return Status::OK();
}

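// Builds the XlaCompiler::Argument descriptions for the computation from the
// compile metadata and input shapes, materializing guaranteed constants into
// constant arguments, and records per-core shapes and core mappings via
// SetPerCoreArgShapes.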
Status TpuCompileOpKernelCommon::BuildComputationArgumentDescriptions(
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const XlaCompiler& compiler,
    std::vector<XlaCompiler::Argument>* args,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  // Builds a description of the computation's arguments.
  int constant_count = 0;
  size_t guaranteed_constants_size = 0;
  for (int i = 0; i < metadata_.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata_.args(i);
    args->push_back(XlaCompiler::Argument());
    XlaCompiler::Argument& arg = args->back();
    arg.type = proto_arg.dtype();
    arg.shape = arg_shapes[i];
    arg.node_name = proto_arg.name();
    switch (proto_arg.kind()) {
      case tpu::TPUCompileMetadataProto::Arg::PARAMETER:
        arg.kind = XlaCompiler::Argument::kParameter;
        break;
      case tpu::TPUCompileMetadataProto::Arg::VARIABLE:
        arg.kind = XlaCompiler::Argument::kResource;
        arg.resource_kind = XlaResource::kVariable;
        arg.initialized = true;
        arg.fast_mem = proto_arg.fast_mem();
        break;
      case tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT:
        arg.kind = XlaCompiler::Argument::kConstant;
        guaranteed_constants_size =
            guaranteed_constants.index() == 0
                ? absl::get<0>(guaranteed_constants).size()
                : absl::get<1>(guaranteed_constants)->size();
        TF_RET_CHECK(constant_count < guaranteed_constants_size)
            << "More constant args in TPUCompileMetadataProto than constant "
               "tensors.";
        if (guaranteed_constants.index() == 0) {
          // `guaranteed_constants` is of type `absl::Span<const TensorProto*
          // const>`.
          Tensor tensor;
          CHECK(tensor.FromProto(
              *absl::get<0>(guaranteed_constants)[constant_count++]))
              << "Failed to deserialize invalid `TensorProto` into `Tensor`.";
          arg.constant_value = tensor;
        } else {
          // `guaranteed_constants` is of type `const OpInputList* const`.
          arg.constant_value =
              (*absl::get<1>(guaranteed_constants))[constant_count++];
        }
        break;
      case tpu::TPUCompileMetadataProto::Arg::INVALID:
      default:
        break;
    }
    arg.is_same_data_across_replicas = proto_arg.is_same_data_across_replicas();
    if (arg.kind == XlaCompiler::Argument::kInvalid) {
      return errors::InvalidArgument("Invalid argument kind");
    }
    if (arg.kind == XlaCompiler::Argument::kConstant) {
      continue;
    }

    // Assign each argument a sharding.
    xla::Shape xla_arg_shape;
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_RETURN_IF_ERROR(compiler.XLAShapeForArgument(
        arg, /*is_entry_computation=*/true, arg_sharding, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  TF_RET_CHECK(constant_count == guaranteed_constants_size)
      << "Not all of the constant tensors were consumed.";

  return Status::OK();
}

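// Computes per-core argument shapes and core mappings from the compile
// metadata, using `shape_representation_fn` to derive each argument's XLA
// shape before applying its sharding.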
Status TpuCompileOpKernelCommon::GetShardingInfo(
    absl::Span<const TensorShape> arg_shapes,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  int num_inputs = metadata_.args_size();
  for (int i = 0; i < num_inputs; ++i) {
    const auto& proto_arg = metadata_.args(i);
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_ASSIGN_OR_RETURN(
        auto xla_arg_shape,
        shape_representation_fn(arg_shapes[i], proto_arg.dtype(),
                                /*use_fast_memory=*/false));
    TF_RETURN_IF_ERROR(
        RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
                                      shape_representation_fn, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  return Status::OK();
}

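// Compiles the TensorFlow function `function` to HLO: builds argument
// descriptions, assigns arguments and return values to cores, applies the
// dynamic padding mappings from the metadata, optimizes the graph, and then
// invokes XlaCompiler::CompileGraph.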
Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
    const FunctionLibraryDefinition& flib_def, int graph_def_version,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const NameAttrList& function,
    std::function<Status(ResourceMgr*)> populate_resource_manager_fn,
    xla::CompileOnlyClient* client,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
    XlaCompiler::CompilationResult* compilation_result) {
  XlaCompiler::Options compiler_options;
  compiler_options.device_type = DeviceType(DEVICE_TPU_XLA_JIT);
  compiler_options.client = client;
  compiler_options.flib_def = &flib_def;
  compiler_options.allow_cpu_custom_calls = false;
  compiler_options.populate_resource_manager = &populate_resource_manager_fn;
  compiler_options.graph_def_version = graph_def_version;
  compiler_options.shape_representation_fn = shape_representation_fn;

  auto compiler = absl::make_unique<XlaCompiler>(compiler_options);

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(BuildComputationArgumentDescriptions(
      arg_shapes, guaranteed_constants, *compiler, &args, arg_core_mapping,
      per_core_arg_shapes));

  // Assign each return value to a core.
  std::vector<tpu::ShardingAndIndex> retval_core_mapping(
      metadata_.retvals_size());
  TF_RETURN_IF_ERROR(
      TpuCompileOpKernelCommon::AssignReturnValueToCore(&retval_core_mapping));

  LOG(INFO) << "Instantiating function: " << function.name();
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(compiler->flib_runtime()->Instantiate(
      function.name(), AttrSlice(&function.attr()), &handle));
  const FunctionBody* fbody = compiler->flib_runtime()->GetFunctionBody(handle);
  const string function_id =
      Canonicalize(function.name(), AttrSlice(&function.attr()));

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  CopyGraph(*fbody->graph, graph.get());

  VLOG(2) << "metadata: " << metadata_.DebugString();
  std::vector<int> parameter_arg_mapping;
  for (int i = 0; i < args.size(); i++) {
    XlaCompiler::Argument& arg = args[i];
    if (arg.kind != XlaCompiler::Argument::kParameter) {
      continue;
    }
    parameter_arg_mapping.push_back(i);
  }
  TF_RET_CHECK(fbody->arg_nodes.size() == args.size());
  for (size_t i = 0; i < fbody->arg_nodes.size(); i++) {
    args[i].node_name = fbody->arg_nodes[i]->name();
  }

  std::vector<gtl::InlinedVector<int64, 4>> arg_shape_dims;
  arg_shape_dims.reserve(arg_shapes.size());
  std::vector<PartialTensorShape> partial_arg_shapes(arg_shapes.size());
  for (const TensorShape& shape : arg_shapes) {
    arg_shape_dims.push_back(shape.dim_sizes());
  }

  for (const auto& padding_mapping : metadata_.padding_maps()) {
    if (padding_mapping.padding_arg_index() >= parameter_arg_mapping.size()) {
      return errors::Internal(absl::StrCat(
          "TPUCompileMetadataProto `padding_maps` has `padding_arg_index` ",
          padding_mapping.padding_arg_index(),
          " which exceeds `parameter_arg_mapping` array bounds ",
          parameter_arg_mapping.size(),
          ". This usually indicates there are dynamic shape inputs fed into "
          "TPUs from outside compilation head extraction, which is not "
          "supported."));
    }
    int padding_arg_index =
        parameter_arg_mapping.at(padding_mapping.padding_arg_index());
    args[parameter_arg_mapping.at(padding_mapping.arg_index())]
        .dynamic_dim_to_arg_num_map[padding_mapping.shape_index()] =
        padding_arg_index;
    arg_shape_dims[parameter_arg_mapping.at(padding_mapping.arg_index())]
                  [padding_mapping.shape_index()] = -1;
    args[padding_arg_index].is_pad_arg = true;
  }

  for (int64 i = 0; i < arg_shape_dims.size(); ++i) {
    auto& dims = arg_shape_dims[i];
    TF_RETURN_IF_ERROR(PartialTensorShape::MakePartialShape(
        dims.data(), dims.size(), &partial_arg_shapes[i]));
  }

  // Adds device assignments to _Arg and _Retval nodes.
  TF_RETURN_IF_ERROR(AssignDevicesToArgsAndRetvals(
      absl::MakeSpan(*arg_core_mapping), absl::MakeSpan(retval_core_mapping),
      graph.get()));

  VLOG(1) << "Optimizing TensorFlow graph";
  FunctionLibraryDefinition flib_definition(flib_def);
  TF_RETURN_IF_ERROR(OptimizeGraph(metadata_, partial_arg_shapes, &graph,
                                   compiler->flib_runtime(), &flib_definition));

  VLOG(1) << "Compiling TensorFlow graph to HLO";
  XlaCompiler::CompileOptions compile_options;
  compile_options.return_updated_values_for_all_resources = false;
  compile_options.use_tuple_arg = true;
  compile_options.is_entry_computation = true;
  compile_options.alias_resource_update = true;
  return compiler->CompileGraph(compile_options, function_id, std::move(graph),
                                args, compilation_result);
}

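// Waits for a grace period after a cancellation; if compilation still has not
// finished by then, aborts the process to return to a consistent state.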
/* static */ void TpuCompileOpKernelCommon::ExitCountdown(
    Env* env, std::shared_ptr<std::atomic<bool>> done) {
  const int kSleepSeconds = 300;
  LOG(INFO) << "TpuCompileOp was cancelled. Sleeping for " << kSleepSeconds
            << " seconds to give the TpuCompileOp time to finish.";
  env->SleepForMicroseconds(kSleepSeconds * 1000000);
  if (done->load()) {
    // If the TpuCompileOp has finished, then terminate peacefully.
    return;
  }

  LOG(ERROR) << "Aborting process due to cancelled TpuCompileOp. This "
             << "termination is to ensure a consistent state.";
  std::exit(42);
}

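// Reads the "dynamic_shapes" input list and converts each shape tensor into a
// TensorShape.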
/* static */ Status TpuCompileOpKernelCommon::GetDynamicShapes(
    OpKernelContext* ctx, std::vector<TensorShape>* shapes) {
  OpInputList dynamic_shapes;
  TF_RETURN_IF_ERROR(ctx->input_list("dynamic_shapes", &dynamic_shapes));

  shapes->resize(dynamic_shapes.size());
  for (int i = 0; i < dynamic_shapes.size(); ++i) {
    TF_RETURN_IF_ERROR(
        tpu::ShapeTensorToTensorShape(dynamic_shapes[i], &(*shapes)[i]));
  }
  return Status::OK();
}

// Function arguments and return values lose their device assignments, so we
// must recreate them.
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(
    absl::Span<const tpu::ShardingAndIndex> arg_core_mapping,
    absl::Span<const tpu::ShardingAndIndex> retval_core_mapping, Graph* graph) {
  auto assign = [&](Node* node, const xla::OpSharding& sharding) -> Status {
    if (sharding.type() == xla::OpSharding::MAXIMAL) {
      const string device = CoreDevice(sharding.tile_assignment_devices(0));
      node->set_assigned_device_name(device);
      node->set_requested_device(device);
    } else {
      TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED ||
                   sharding.type() == xla::OpSharding::OTHER)
          << "Unsupported sharding on parameter/retval: "
          << sharding.DebugString();
    }
    node->AddAttr("_XlaSharding", sharding.SerializeAsString());
    return Status::OK();
  };
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < arg_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, arg_core_mapping[index].sharding));
    } else if (node->type_string() == kRetvalOp) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
      TF_RET_CHECK(index >= 0 && index < retval_core_mapping.size());
      TF_RETURN_IF_ERROR(assign(node, retval_core_mapping[index].sharding));
    }
  }
  return Status::OK();
}

// Performs shape inference on the body of `graph`. Shapes for arguments
// are taken from `metadata` and `arg_shapes`.
/* static */ Status TpuCompileOpKernelCommon::RunShapeInferenceOnComputation(
    const tpu::TPUCompileMetadataProto& metadata,
    const std::vector<PartialTensorShape>& arg_shapes, Graph* graph,
    FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info) {
  int num_args = arg_shapes.size();
  CHECK_EQ(num_args, metadata.args_size());

  std::map<int, InferredShape> arg_shapes_for_inference;
  for (int i = 0; i < num_args; ++i) {
    const auto& arg = metadata.args(i);
    InferredShape& shape_for_inference = arg_shapes_for_inference[i];
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      // For resource variables, arg_shapes[] contains the shape of the
      // variable's value.
      shape_for_inference.handle_type = arg.dtype();
      shape_for_inference.handle_shape = arg_shapes[i];
      // The shape of the variable itself is always a scalar.
      shape_for_inference.shape = TensorShape();
    } else {
      if (arg.kind() ==
          tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
        VLOG(1) << "PromisedConstant shape: " << arg_shapes[i].DebugString();
      }
      shape_for_inference.shape = arg_shapes[i];
    }
  }
  return InferShapes(
      graph, arg_shapes_for_inference,
      flr != nullptr ? flr->GetFunctionLibraryDefinition() : nullptr,
      shape_info);
}

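// Optimizes `graph` before XLA compilation: runs shape inference, applies the
// TensorFlow GraphOptimizer with function inlining and (optionally) constant
// folding, and rewrites TensorList ops that have constant elements.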
Status TpuCompileOpKernelCommon::OptimizeGraph(
    const tpu::TPUCompileMetadataProto& metadata,
    const std::vector<PartialTensorShape>& arg_shapes,
    std::unique_ptr<Graph>* graph, FunctionLibraryRuntime* flr,
    FunctionLibraryDefinition* fld) {
  // Sets up options for the optimization passes that need to be done. Notice
  // that CSE is not needed as XLA has its own CSE passes later in the
  // compilation stage.
  auto flags = GetBuildXlaOpsPassFlags();
  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_common_subexpression_elimination(false);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
  GraphOptimizer optimizer(opts);
  {
    // Performs a first function inlining pass before shape inference, since
    // otherwise shape inference can't see inside functions and a comprehensive
    // shape_map, including function ops, is needed to constant-propagate Shape
    // Ops below.
    GraphOptimizer::Options optimizer_opts;
    optimizer_opts.inline_multi_device_functions = true;
    optimizer_opts.inline_impl_selection_group_functions = true;
    optimizer_opts.inline_with_single_device_body_placer = true;
    // Infer shapes for each node in the computation. Shape inference can help
    // skip constant folding of large shapes.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    // Converts the GraphShapeInfo into the form needed by the constant-folding
    // pass of the optimizer.
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer_opts.shape_map = &shape_map;
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
  }

  {
    // Infer shapes for each node in the computation.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map);
  }

  TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld));

  return Status::OK();
}

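// Kernel entry point: registers a cancellation callback that aborts the
// process if a cancelled compilation does not finish in time, runs
// ComputeInternal, and attaches a CompilationResultProto payload to the
// status when compilation fails without one.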
void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
  VLOG(1) << "Cloud TPU: TpuCompileOpKernelCommon::Compute";

  std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));

  CancellationToken token =
      ctx->cancellation_manager()->get_cancellation_token();
  const bool already_cancelled =
      !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
        if (OpsApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
          return;
        }

        // Sleep and exit in another thread so the cancellation manager can
        // continue running callbacks.
        Env* env = ctx->env();
        env->SchedClosure([env, done]() { ExitCountdown(env, done); });
      });

  // If the RPC was cancelled before we registered the cancellation callback,
  // don't compile the TPU program.
  OP_REQUIRES(ctx, !already_cancelled,
              errors::Cancelled("RPC cancelled, not compiling TPU program"));

  // We only want to abort the process if a cancellation actually occurs during
  // compilation; we must deregister the callback in the success case. It
  // doesn't hurt to also deregister the callback in the failure case; the
  // CancellationManager ensures that already-registered callbacks will be run
  // once cancellation has started.
  auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
    ctx->cancellation_manager()->DeregisterCallback(token);
    done->store(true);
  });

  Status compile_status = ComputeInternal(ctx);
  string status_payload;
  // Construct payload if compile_status is not ok and there's no payload for
  // compilation yet.
  if (!compile_status.ok() &&
      compile_status.GetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey)
          .empty()) {
    tpu::CompilationResultProto proto;
    proto.set_status_code(compile_status.code());
    proto.set_status_error_message(compile_status.error_message());
    status_payload = proto.SerializeAsString();
  }
  OP_REQUIRES_OK_OR_SET_PAYLOAD(ctx,
                                TpuCompileInterface::kTpuCompileErrorPayloadKey,
                                status_payload, compile_status);
}

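// Compiles the program on this host, either from the MLIR module or via the
// function-to-HLO path, records compilation metrics, and leaves the result in
// `tpu_program_group` so the caller can populate the host-memory cache.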
Status TpuCompileOpKernelCommon::CompileLocallyAndFillHostCache(
    FunctionLibraryRuntime* flib_runtime,
    const SessionMetadata* session_metadata,
    const TpuMeshStateInterface* mesh_state,
    const std::vector<TensorShape>& dynamic_shapes,
    const OpInputList& guaranteed_constants, const TpuCompilationCacheKey& key,
    TpuProgramGroupInterface* tpu_program_group) {
  absl::Time start_time = absl::Now();
  std::vector<TensorShape> arg_shapes;
  TF_RETURN_IF_ERROR(
      ComputeArgumentShapes(metadata_, dynamic_shapes, &arg_shapes));
  Status compile_status;
  if (use_mlir_) {
    compile_status = Compile(MlirToHloArgs{mlir_module_}, mesh_state->data(),
                             arg_shapes, tpu_program_group);
  } else {
    compile_status =
        Compile(FunctionToHloArgs{&function_,
                                  flib_runtime->GetFunctionLibraryDefinition(),
                                  flib_runtime->graph_def_version(),
                                  {&guaranteed_constants}},
                mesh_state->data(), arg_shapes, tpu_program_group);
  }

  absl::Time end_time = absl::Now();
  auto duration = end_time - start_time;

  const std::string session_name = SessionNameFromMetadata(session_metadata);
  LOG(INFO) << "Compilation of " << key.prefix << " with session name "
            << session_name << " took " << duration << " and "
            << (compile_status.ok() ? "succeeded" : "failed");
  tpu_program_group->LogProgramMemorySummary();
  metrics::UpdateXlaCompilationTime(absl::ToInt64Microseconds(duration));
  TpuCompilationMetrics::IncrementCompilationCount(session_name);

  TF_RETURN_IF_ERROR(tpu_program_group->LogCompilationStats(key, duration));

  return compile_status;
}

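// Looks up the TPU mesh state and the process-wide compilation cache, builds
// the compilation cache key, compiles on a cache miss (locally or via the
// persistent cache), and fills the op's outputs with program keys,
// variable-modification flags, and the compilation status.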
Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
  VLOG(1) << "Retrieving mesh state";
  // Retrieve the topology from the resource manager
  ResourceMgr* rm = GetTPUConfigResourceMgr();

  TpuMeshStateInterface* mesh_state;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                kTpuMeshStateInterfaceResourceName,
                                &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  std::vector<TensorShape> dynamic_shapes;
  TF_RETURN_IF_ERROR(GetDynamicShapes(ctx, &dynamic_shapes));

  OpInputList guaranteed_constants;
  // TODO(ycao): Decide whether/how to support guaranteed constants in
  // MLIR-based TF-Compiler Bridge.
  if (!use_mlir_) {
    TF_RETURN_IF_ERROR(
        ctx->input_list("guaranteed_constants", &guaranteed_constants));
  }

  const TpuCompilationCacheKey key = CreateCompilationCacheKey(
      function_.name(), metadata_.function_library_fingerprint(),
      mlir_module_fingerprint_, guaranteed_constants, dynamic_shapes, metadata_,
      *mesh_state);

  // Process-wide cache of TPU executables.
  TpuCompilationCacheInterface* cache;
  TF_RETURN_IF_ERROR(rm->Lookup<TpuCompilationCacheInterface>(
      rm->default_container(), kCompilationCacheResourceName, &cache));
  core::ScopedUnref cache_unref(cache);

  // Per-step object that ensures that compilation cache entries aren't
  // evicted until the step completes. This mechanism ensures that the
  // downstream TPUExecute Ops in this step will be able to look up the
  // compiled executable even if it is marked for eviction before the step
  // ends.
  //
  // We can't use GetTPUConfigResourceMgr here because it may return the
  // global ResourceMgr, which is not associated with any device, and
  // GraphMgr's ScopedStepContainer only searches ResourceMgrs associated
  // with devices when deleting resources at step boundaries.
  CompilationRefHolder* ref_holder;
  if (ctx->step_container() == nullptr) {
    return errors::FailedPrecondition(
        "TPUCompileOp requires a step container.");
  }
  TF_RETURN_IF_ERROR(
      ctx->step_container()->LookupOrCreate<CompilationRefHolder>(
          ctx->resource_manager(), "ref_holder", &ref_holder,
          [cache](CompilationRefHolder** h) {
            *h = cache->MakePerStepRefHolder();
            return Status::OK();
          }));
  core::ScopedUnref ref_holder_unref(ref_holder);

  int64 uid;
  std::vector<std::string> proto_key;
  std::vector<std::string> sharding_key;
  std::vector<bool> may_modify_variables;
  absl::Span<const xla::HloProto* const> hlo_metadatas;
  Status status = cache->CompileIfKeyAbsent(
      key, ctx->session_metadata(), ref_holder, &uid, &proto_key, &sharding_key,
      &may_modify_variables, &hlo_metadatas,
      [&](TpuProgramGroupInterface* tpu_program_group) {
        VLOG(1) << "Cloud TPU: Compiling TPU program";
        // When this compile function is invoked, we know that host-memory
        // cache TpuCompilationCache saw a cache miss. There are two codepaths:
        // 1. If persistent cache is disabled, compile locally and populate
        //    host-memory cache.
        // 2. If persistent cache is enabled, we do an additional lookup on
        //    the persistent cache.
        //    - If persistent cache also sees a cache miss, trigger
        //      compilation. Then, populate both persistent cache and
        //      host-memory cache.
        //    - If persistent cache sees a cache hit, retrieve cache entry from
        //      persistent cache to populate host-memory cache without
        //      recompilation. If retrieval failed, compile locally as a
        //      fallback and use the local compilation result to populate
        //      host-memory cache.
        if (persistent_cache_ == nullptr) {
          VLOG(1) << "Persistent compilation cache not enabled. Compiling "
                     "TPU executable locally and populating host-memory cache.";
          return CompileLocallyAndFillHostCache(
              ctx->function_library(), ctx->session_metadata(), mesh_state,
              dynamic_shapes, guaranteed_constants, key, tpu_program_group);
        }
        return LookupPersistentCompilationCacheAndFillCaches(
            ctx->function_library(), ctx->session_metadata(), mesh_state,
            dynamic_shapes, guaranteed_constants, persistent_cache_.get(), key,
            tpu_program_group);
      });

  // `ref_holder` is provided to CompileIfKeyAbsent to ensure that the cache
  // entry does not get evicted before TpuExecuteOp runs it and discards
  // `ref_holder`. When the TpuCompilationCacheEntryUnloader gets destroyed
  // (e.g. when the user closes the session while there are in-flight program
  // executions), it discards the cache's reference to the cache entry, but
  // the entry is not removed until `ref_holder` drops the last reference to
  // it. This ensures that the guarantees of `ref_holder` are not violated
  // when this flag is true.
  if (unload_cache_entry_on_session_close_) {
    // Place `unloader` in TPU_SYSTEM device resource manager. Note that
    // - TPUConfigResourceMgr returned by GetTPUConfigResourceMgr() is a special
    //   process-global ResourceMgr. There is only one TPUConfigResourceMgr, and
    //   it is never destroyed.
    // - TPU_SYSTEM device resource manager is a normal device ResourceMgr for
    //   TPU_SYSTEM device. If DirectSession or isolate_session_state are used,
    //   there's one TPU_SYSTEM ResourceMgr for each session, and the
    //   ResourceMgrs will be destroyed when their corresponding session is
    //   closed. Otherwise there's one TPU_SYSTEM ResourceMgr that's only
    //   destroyed when the master-session is destroyed, not when the worker
    //   sessions are destroyed.
    TpuCompilationCacheEntryUnloader* unloader;
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()
            ->LookupOrCreate<TpuCompilationCacheEntryUnloader>(
                ctx->resource_manager()->default_container(),
                kCompilationCacheUnloaderResourceName, &unloader,
                [cache](TpuCompilationCacheEntryUnloader** new_unloader) {
                  *new_unloader = new TpuCompilationCacheEntryUnloader(cache);
                  return Status::OK();
                }));
    // Note that LookupOrCreate puts two refcounts on unloader.
    core::ScopedUnref unloader_unref(unloader);
    unloader->AddCacheEntryUid(uid);
  }

  int64 num_cores_with_compiled_programs = proto_key.size();
  if (proto_key.size() == 1) {
    // SPMD produces 1 program for all cores.
    num_cores_with_compiled_programs = metadata_.num_cores_per_replica();
  }
  if (status.ok() &&
      num_cores_with_compiled_programs +
              (may_modify_variables.size() * static_cast<int>(!use_mlir_)) !=
          ctx->num_outputs() - 1) {
    status = errors::Internal(
        "Number of cores with compiled programs (",
        num_cores_with_compiled_programs, ") + variable states (",
        may_modify_variables.size() * static_cast<int>(!use_mlir_),
        ") + compilation status output != number of compile op outputs (",
        ctx->num_outputs(), ")");
  }

  // TODO(jpienaar): status is not just due to the compilation. At this
  // point we should be failing the execution of the op in some cases and
  // returning a compilation error in others. For now, uniformly return an
  // error and fail in _TPUExecute if status failed here.

  // TODO(misard) the frame id will be wrong if this is ever called from
  // within a function. Consider whether to use the same hack as is
  // present in the rendezvous manager where the function call frame is
  // cast to a uint64, or do something better all around.
  std::string rendezvous_key_base = strings::StrCat(
      "host_compute_rendezvous:", ctx->op_kernel().name(), ":",
      ctx->frame_iter().frame_id, ":", ctx->frame_iter().iter_id, ":");

  // Return compilation status.
  {
    Tensor output(DT_STRING, TensorShape({}));
    tpu::CompilationResultProto proto;
    proto.set_status_code(status.code());
    if (!status.ok()) {
      proto.set_status_error_message(
          absl::StrCat("Compilation failure: ", status.error_message()));
    }
    if (return_hlo_protos_) {
      // Return the HloProtos as part of compilation status.
      for (const xla::HloProto* hlo_metadata : hlo_metadatas) {
        xla::HloProto* hlo_proto = proto.add_hlo_protos();
        *hlo_proto = *hlo_metadata;
      }
    }
    SerializeToTString(proto, &output.scalar<tstring>()());
    ctx->set_output(0, output);
    status.SetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey,
                      output.scalar<tstring>()());
  }

  if (status.ok()) {
    for (int i = 0; i < num_cores_with_compiled_programs; ++i) {
      Tensor output(DT_STRING, TensorShape({3}));
      if (proto_key.size() == 1) {
        output.vec<tstring>()(0) = proto_key[0];
      } else {
        output.vec<tstring>()(0) = proto_key[i];
      }
      output.vec<tstring>()(1) = rendezvous_key_base;
      if (sharding_key.empty()) {
        output.vec<tstring>()(2) = "";
      } else if (sharding_key.size() == 1) {
        output.vec<tstring>()(2) = sharding_key[0];
      } else {
        TF_RET_CHECK(sharding_key.size() == num_cores_with_compiled_programs);
        output.vec<tstring>()(2) = sharding_key[i];
      }
      ctx->set_output(i + 1, output);
    }
    if (!use_mlir_) {
      // If any of the programs may modify a variable, then report that all of
      // them do, since the only state tracked here is whether a model is
      // read-only or not.
      bool may_modify = false;
      for (bool m : may_modify_variables) {
        may_modify = may_modify || m;
      }
      for (int i = 0; i < may_modify_variables.size(); ++i) {
        Tensor output(DT_BOOL, TensorShape({}));
        output.scalar<bool>()() = may_modify;
        ctx->set_output(i + num_cores_with_compiled_programs + 1, output);
      }
    }
    VLOG(1) << "Cloud TPU: Compilation succeeded";
  } else {
    // Return error in the invalid case.
    for (int i = 0; i < num_computations_; ++i) {
      Tensor output(DT_STRING, TensorShape({3}));
      output.vec<tstring>()(0) = "<<NO PROGRAM AS COMPILATION FAILED>>";
      output.vec<tstring>()(1) = "<<NO RENDEZVOUS KEY AS COMPILATION FAILED>>";
      output.vec<tstring>()(2) = "<<NO SHARDING KEY AS COMPILATION FAILED>>";
      ctx->set_output(i + 1, output);
    }
    if (!use_mlir_) {
      // The TPUCompileMLIR op does not have a MayModifyVariable output.
      for (int i = 0; i < num_computations_; ++i) {
        Tensor output(false);
        ctx->set_output(i + num_computations_ + 1, output);
      }
    }
  }
  return status;
}
}  // namespace tpu
}  // namespace tensorflow