1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/tpu/kernels/tpu_functional_ops.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/core/framework/op_kernel.h"
21 
22 #define EIGEN_USE_THREADS
23 
24 #include "absl/base/call_once.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/synchronization/mutex.h"
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/compiler/tf2xla/sharding_util.h"
29 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/common_runtime/function_body.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/common_runtime/placer.h"
34 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
35 #include "tensorflow/core/framework/graph_to_functiondef.h"
36 #include "tensorflow/core/framework/metrics.h"
37 #include "tensorflow/core/framework/node_def.pb.h"
38 #include "tensorflow/core/framework/node_def_util.h"
39 #include "tensorflow/core/framework/resource_mgr.h"
40 #include "tensorflow/core/framework/resource_var.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor.pb.h"
43 #include "tensorflow/core/framework/tensor_shape.h"
44 #include "tensorflow/core/graph/graph_partition.h"
45 #include "tensorflow/core/graph/node_builder.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/hash/hash.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/platform/blocking_counter.h"
50 #include "tensorflow/core/platform/errors.h"
51 #include "tensorflow/core/platform/fingerprint.h"
52 #include "tensorflow/core/platform/refcount.h"
53 #include "tensorflow/core/profiler/lib/traceme.h"
54 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
55 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
56 #include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h"
57 #include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
58 #include "tensorflow/core/tpu/kernels/tpu_op_util.h"
59 #include "tensorflow/core/tpu/kernels/tpu_util.h"
60 #include "tensorflow/core/tpu/tpu_configuration.h"
61 #include "tensorflow/core/tpu/tpu_defs.h"
62 #include "tensorflow/core/util/dump_graph.h"
63 
64 namespace tensorflow {
65 namespace {
66 
67 constexpr char kTpuReplicateAttr[] = "_tpu_replicate";
68 constexpr int kLastDimOfTpuInputFastPath = 128;
69 constexpr int kOtherDimOfTpuInputFastPath = 8;
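// These multiples gate the TPU input fast path used by the packing passes
// below: a group qualifies when the concatenated last dimension is a multiple
// of kLastDimOfTpuInputFastPath and the product of the remaining dimensions is
// a multiple of kOtherDimOfTpuInputFastPath (presumably matching the TPU's
// native tiling; the exact hardware rationale is not stated here).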
70 
71 constexpr char kXLAShardingAttrName[] = "sharding";
72 constexpr char kXLAShardingAttrAltName[] = "_XlaSharding";
73 
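// Fills `natural_order` with the flattened (x, y, z, core) coordinate of every
// core in the topology, iterating over y, then x, then z, then core-per-chip.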
74 Status GenerateDeviceNaturalOrder(int x_num_cores, int y_num_cores,
75                                   int z_num_cores, int num_cores_per_chip,
76                                   std::vector<int>* natural_order) {
77   for (int y = 0; y < y_num_cores; ++y) {
78     for (int x = 0; x < x_num_cores; ++x) {
79       for (int z = 0; z < z_num_cores; ++z) {
80         for (int c = 0; c < num_cores_per_chip; ++c) {
81           natural_order->push_back(x);
82           natural_order->push_back(y);
83           natural_order->push_back(z);
84           natural_order->push_back(c);
85         }
86       }
87     }
88   }
89 
90   return Status::OK();
91 }
92 
93 struct TPUVariableInfo {
94   TPUVariableInfo(int device_ordinal_id, bool use_fast_mem)
95       : device_ordinal(device_ordinal_id), fast_mem(use_fast_mem) {}
96   // The TPU core which the variable will be placed on.
97   int device_ordinal;
98   // If true, try to place the variable in fast memory space if the hardware
99   // supports it.
100   bool fast_mem;
101 };
102 
103 // Check the descendants to parse the placement information for the input node.
104 // num_cores_per_replica describes how many cores a single model replica uses.
105 Status ParseTPUVariableInfor(const Node* node, const int num_cores_per_replica,
106                              TPUVariableInfo* var_info) {
107   int core = 0;
108   bool use_fast_mem = false;
109   VLOG(3) << "Parse tpu variable information for " << node->name();
110   for (const Edge* edge : node->out_edges()) {
111     if (edge->IsControlEdge()) continue;
112     Node* next = edge->dst();
113     VLOG(3) << "Neighbor node " << next->name();
114     // Looking through Enter/Switch/ReadVariableOp nodes.
115     while (next->IsEnter() || next->IsSwitch() ||
116            next->type_string() == "ReadVariableOp") {
117       Node* new_node = nullptr;
118       for (const Edge* e : next->out_edges()) {
119         if (!e->IsControlEdge()) {
120           new_node = e->dst();
121           break;
122         }
123       }
124       if (new_node == nullptr) break;
125       next = new_node;
126     }
127     if (next != edge->dst()) {
128       VLOG(3) << "Looked through Enter/Switch node " << next->DebugString();
129     }
130     TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
131                         ParseShardingFromDevice(*next, num_cores_per_replica,
132                                                 /*add_metadata=*/false));
133     if (sharding.has_value() && sharding->tile_assignment_devices_size() > 0) {
134       core = sharding->tile_assignment_devices(0);
135       VLOG(3) << next->name() << " is placed on core " << core;
136     }
137     if (next->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
138       use_fast_mem = true;
139       VLOG(3) << next->name() << " has " << TPU_FAST_MEM_ATTR << " attribute";
140     }
141   }
142   VLOG(1) << "Place " << node->name() << " to core: " << core
143           << " fast_mem: " << use_fast_mem;
144   var_info->device_ordinal = core;
145   var_info->fast_mem = use_fast_mem;
146 
147   return Status::OK();
148 }
149 
150 // Helper to instantiate function "func" in the library "lib".
151 Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
152                    FunctionLibraryRuntime::Handle* handle) {
153   return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
154 }
155 
156 static constexpr const char* const kDeviceOrdinalAttr = "device_ordinal";
157 
158 static constexpr const char* const kTPUExecuteOp = "TPUExecute";
159 static constexpr const char* const kInfeedEnqueueOp = "InfeedEnqueue";
160 static constexpr const char* const kInfeedEnqueueTupleOp = "InfeedEnqueueTuple";
161 static constexpr const char* const kOutfeedDequeueOp = "OutfeedDequeue";
162 static constexpr const char* const kOutfeedDequeueTupleOp =
163     "OutfeedDequeueTuple";
164 static constexpr const char* const kOutfeedDequeueV2Op = "OutfeedDequeueV2";
165 static constexpr const char* const kOutfeedDequeueTupleV2Op =
166     "OutfeedDequeueTupleV2";
167 static constexpr const char* const kVarHandleOp = "VarHandleOp";
168 
169 static constexpr const char* const kTPUDeviceNamePrefix = "/device:TPU:";
170 static constexpr const int kTPUDefaultDeviceOrdinal = 0;
171 
172 bool IsSupportedTPUOp(const string& op_name) {
173   return op_name == kTPUExecuteOp || op_name == kInfeedEnqueueOp ||
174          op_name == kInfeedEnqueueTupleOp || op_name == kOutfeedDequeueOp ||
175          op_name == kOutfeedDequeueTupleOp || op_name == kOutfeedDequeueV2Op ||
176          op_name == kOutfeedDequeueTupleV2Op;
177 }
178 
179 // Sets the sharding attributes for an XlaSharding node.
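// For example, with num_cores_per_replica = 2, rank = 2 and shard_dim = 1 this
// writes an OTHER sharding with tile_assignment_dimensions = [1, 2] and
// tile_assignment_devices = [0, 1], i.e. the tensor is split across two cores
// along its last dimension.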
180 void SetXlaShardingNodeAttr(Node* xla_sharding_node, int num_cores_per_replica,
181                             int rank, int shard_dim) {
182   auto sharding = absl::make_optional<xla::OpSharding>();
183   sharding->set_type(xla::OpSharding::OTHER);
184 
185   std::vector<int64> dims(rank, 1LL);
186   dims[shard_dim] = num_cores_per_replica;
187   for (auto dim : dims) {
188     sharding->add_tile_assignment_dimensions(dim);
189   }
190 
191   // Sets up tile_assignment_devices.
192   for (int d = 0; d < num_cores_per_replica; ++d) {
193     sharding->add_tile_assignment_devices(d);
194   }
195 
196   xla_sharding_node->ClearAttr(kXLAShardingAttrName);
197   xla_sharding_node->ClearAttr(kXLAShardingAttrAltName);
198   xla_sharding_node->AddAttr(kXLAShardingAttrName,
199                              sharding->SerializeAsString());
200   xla_sharding_node->AddAttr(kXLAShardingAttrAltName,
201                              sharding->SerializeAsString());
202 }
203 
204 // If 'device_name' is a TPU device, set its device_ordinal to 'device_ordinal'
205 // and set '*rewritten' to true. Otherwise, do nothing.
206 Status UpdateTPUDeviceOrdinal(int device_ordinal, string* device_name,
207                               bool* rewritten) {
208   DeviceNameUtils::ParsedName device;
209   if (!DeviceNameUtils::ParseFullName(*device_name, &device)) {
210     return errors::InvalidArgument("Unable to parse device name ",
211                                    *device_name);
212   }
213   if (device.type == DEVICE_TPU_NODE) {
214     device.id = device_ordinal;
215     *rewritten = true;
216   }
217   *device_name = DeviceNameUtils::ParsedNameToString(device);
218   return Status::OK();
219 }
220 
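// Starting from `arg_node`, walks forward through nodes that are not yet part
// of the TPU cluster and returns an edge whose source lacks the _tpu_replicate
// attribute while its destination carries it (i.e. the host-to-TPU input
// edge). Returns nullptr if no such edge is found; GuaranteeConst sources are
// not handled, so the search stops there.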
221 const Edge* FindHostToDeviceEdge(Node* arg_node) {
222   const Edge* candidate_edge = nullptr;
223   for (const Edge* edge : arg_node->out_edges())
224     if (!edge->IsControlEdge()) {
225       // Find CPU -> TPU input edge.
226       const Edge* original_edge;
227       while (edge->src()->attrs().Find(kTpuReplicateAttr) != nullptr ||
228              edge->dst()->attrs().Find(kTpuReplicateAttr) == nullptr) {
229         const Node* new_src = edge->dst();
230         original_edge = edge;
231         for (const Edge* new_edge : new_src->out_edges())
232           if (!new_edge->IsControlEdge()) {
233             original_edge = edge;
234             edge = new_edge;
235             break;
236           }
237         if (original_edge == edge) break;
238       }
239       // TPU input edge: src is on CPU and dest is on TPU.
240       if (edge->src()->attrs().Find(kTpuReplicateAttr) != nullptr ||
241           edge->dst()->attrs().Find(kTpuReplicateAttr) == nullptr)
242         continue;
243       // Won't work with GuaranteeConst.
244       if (edge->src()->type_string() == "GuaranteeConst") break;
245       candidate_edge = edge;
246     }
247   return candidate_edge;
248 }
249 
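// Inserts an Identity node, tagged with the consumer's _tpu_replicate
// attribute, between `candidate_edge`'s source and every TPU-bound consumer of
// that source, so all cluster inputs flow through a single proxy. The new edge
// from the original source to the proxy is returned in `tpu_input_edge`.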
250 Status CreateInputProxy(Graph* graph, const Edge* candidate_edge,
251                         const Edge** tpu_input_edge) {
252   std::vector<const Edge*> edges_to_replace;
253   for (const Edge* input_edge : candidate_edge->src()->out_edges()) {
254     if (!input_edge->IsControlEdge() &&
255         input_edge->dst()->attrs().Find(kTpuReplicateAttr) != nullptr)
256       edges_to_replace.push_back(input_edge);
257   }
258   // Build an Identity node as the proxy of the original edge source.
259   Node* input_identity_node = nullptr;
260   TF_RETURN_IF_ERROR(
261       NodeBuilder(strings::StrCat(candidate_edge->src()->name(), "/proxy"),
262                   "Identity")
263           .Input(candidate_edge->src())
264           .Attr("T", candidate_edge->src()->output_type(0))
265           .Attr(kTpuReplicateAttr,
266                 candidate_edge->dst()->attrs().Find(kTpuReplicateAttr)->s())
267           .Finalize(graph, &input_identity_node));
268   // Find the tpu input edge from original source to proxy identity.
269   for (const Edge* input_edge : input_identity_node->in_edges())
270     if (input_edge->src() == candidate_edge->src()) {
271       *tpu_input_edge = input_edge;
272       break;
273     }
274   // Replace original input edges with proxy's output.
275   for (const Edge* input_edge : edges_to_replace) {
276     graph->RemoveEdge(input_edge);
277     graph->AddEdge(input_identity_node, 0, input_edge->dst(),
278                    input_edge->dst_input());
279   }
280   return Status::OK();
281 }
282 
283 Status GetClusterName(Graph* graph, string* cluster_name) {
284   *cluster_name = "";
285   for (const Node* node : graph->nodes()) {
286     if (node->attrs().Find(kTpuReplicateAttr) == nullptr) continue;
287     if (cluster_name->empty())
288       *cluster_name = node->attrs().Find(kTpuReplicateAttr)->s();
289     // When optimization is turned on, the graph should only have one TPU
290     // cluster.
291     if (*cluster_name != node->attrs().Find(kTpuReplicateAttr)->s())
292       return errors::FailedPrecondition(
293           "Only one cluster is allowed when optimization is turned on for "
294           "TPUPartitionedCall. Found ",
295           node->attrs().Find(kTpuReplicateAttr)->s(), " and ", *cluster_name);
296   }
297   return Status::OK();
298 }
299 
300 // Removes nodes that have no effect and that directly descend from _Arg nodes.
301 //
302 // This is currently used to remove TPUReplicatedInput and XlaSharding nodes,
303 // which are always descendants of _Arg nodes. During optimization, we try to insert
304 // nodes in between _Arg and _Arg's children, where some of the nodes inserted
305 // are TPU nodes. We will add the TPUReplicatedInput and XlaSharding op nodes
306 // back where necessary.
307 //
308 // Returns the number of nodes that were removed.
309 int64 RemoveDescendantNodeOfArg(Graph* graph,
310                                 const std::string& node_type_to_remove,
311                                 const std::set<std::string>& must_be_child_of) {
312   int64_t nodes_removed = 0;
313   std::vector<std::pair<const Edge*, std::vector<const Edge*>>> edges_to_remove;
314 
315   for (Node* node : graph->nodes()) {
316     if (node_type_to_remove != node->type_string()) continue;
317     if (!must_be_child_of.empty()) {
318       bool has_arg_parent = false;
319       for (const Edge* edge : node->in_edges()) {
320         if (must_be_child_of.count(edge->src()->type_string()) > 0) {
321           has_arg_parent = true;
322         }
323       }
324       if (!has_arg_parent) continue;
325     }
326     nodes_removed++;
327 
328     const Edge* input_edge = nullptr;
329     std::vector<const Edge*> output_edges;
330     for (const Edge* edge : node->in_edges())
331       if (!edge->IsControlEdge()) {
332         input_edge = edge;
333         break;
334       }
335     for (const Edge* edge : node->out_edges())
336       if (!edge->IsControlEdge()) {
337         output_edges.push_back(edge);
338       }
339     if (input_edge != nullptr && !output_edges.empty())
340       edges_to_remove.push_back(std::make_pair(input_edge, output_edges));
341   }
342   for (const auto& it : edges_to_remove) {
343     for (const Edge* output_edge : it.second) {
344       graph->RemoveEdge(output_edge);
345       graph->AddEdge(it.first->src(), it.first->src_output(),
346                      output_edge->dst(), output_edge->dst_input());
347     }
348     graph->RemoveNode(it.first->dst());
349   }
350   return nodes_removed;
351 }
352 
353 uint64 GetInputHash(OpKernelContext* ctx) {
354   uint64 input_hash = 0;  // initialization for determinism.
355   // Use the number of elements to compute hash.
356   // TODO(chiachenc): use the full shape to compute the hash.
357   for (int i = 0; i < ctx->num_inputs(); ++i) {
358     VLOG(4) << "InputHash, combine input " << i
359             << ", NumElements: " << ctx->input(i).NumElements();
360     input_hash = Hash64Combine(input_hash, ctx->input(i).NumElements());
361   }
362   return input_hash;
363 }
364 
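// Builds a grouping key from the dtype and every dimension except the last
// (tensors in the same group are concatenated along the last dimension). When
// `input_shape_opt` is set, the key also records whether the last dimension is
// a multiple of kLastDimOfTpuInputFastPath. For example, a DT_FLOAT [8, 100]
// input with input_shape_opt = true hashes to "<prefix><dtype>_dims_8_last_other",
// since 100 is not a multiple of 128.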
365 string HashShapeAndType(const string prefix, const std::vector<int>& input_dims,
366                         const DataType& dtype, const bool input_shape_opt) {
367   string hash = strings::StrCat(prefix, dtype, "_dims");
368   // We will concat at the last dimension.
369   for (int d = 0; d < input_dims.size() - 1; ++d) {
370     strings::StrAppend(&hash, "_", input_dims.at(d));
371   }
372 
373   if (input_shape_opt) {
374     if (input_dims.back() % kLastDimOfTpuInputFastPath == 0) {
375       strings::StrAppend(&hash, "_last_", kLastDimOfTpuInputFastPath, "n");
376     } else {
377       strings::StrAppend(&hash, "_last_other");
378     }
379   }
380   return hash;
381 }
382 
383 // Get the information for input and output tensors (shapes, dtypes, etc).
384 Status GetInputOutputInfo(
385     Graph* graph, GraphShapeInfo& tpu_inferred_info,
386     std::map<int, InferredShape>& arg_shapes, EdgeShapes& tpu_input_shapes,
387     absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
388     OpKernelContext* ctx) {
389   // Search for the host-to-device (cpu-to-tpu) edges.
390   for (Node* node : graph->op_nodes()) {
391     if (!node->IsArg()) continue;
392     const DataType dtype = node->attrs().Find("T")->type();
393     const int arg_index = node->attrs().Find("index")->i();
394     if (dtype != DT_INT32 && dtype != DT_BFLOAT16 && dtype != DT_FLOAT &&
395         dtype != DT_BOOL && dtype != DT_QINT8 && dtype != DT_QUINT8)
396       continue;
397     VLOG(3) << "Argnode: " << node->DebugString();
398     const Tensor& tensor = ctx->input(arg_index);
399 
400     // Search for the cross-device edge from arg node.
401     const Edge* candidate_edge = FindHostToDeviceEdge(node);
402     if (candidate_edge == nullptr) continue;
403 
404     // Make a proxy and get the sole tpu_input_edge used to transfer the input
405     // tensor corresponding to the current _Arg node.
406     const Edge* tpu_input_edge = nullptr;
407     TF_RETURN_IF_ERROR(
408         CreateInputProxy(graph, candidate_edge, &tpu_input_edge));
409     if (tpu_input_edge == nullptr)
410       return errors::NotFound("Couldn't find TPU input edge for ", node->name());
411 
412     // Optimize edge: original source to proxy identity.
413     VLOG(3) << "Input: " << tpu_input_edge->src()->name();
414     std::vector<int>& input_shapes = tpu_input_shapes[tpu_input_edge];
415     input_shapes.clear();
416     for (int d = 0; d < tensor.dims(); ++d) {
417       input_shapes.push_back(tensor.dim_size(d));
418       VLOG(3) << "Input Tensor: Dim[" << d << "] = " << tensor.dim_size(d);
419     }
420     tpu_input_dtypes[tpu_input_edge] = tensor.dtype();
421 
422     // Collect shapes for non-resource-variable args.
423     PartialTensorShape partial_tensor_shape;
424     auto partial_shape = PartialTensorShape::MakePartialShape(
425         input_shapes.data(), input_shapes.size(), &partial_tensor_shape);
426     InferredShape inferred_shape = {partial_tensor_shape};
427     arg_shapes[arg_index] = inferred_shape;
428   }
429   return Status::OK();
430 }
431 
432 // Converts an integer vector that represents the shapes to a TensorShape.
433 Status ConvertEdgeShapesToTensorShapes(
434     const std::map<std::string, std::vector<int>>& named_input_shapes,
435     std::vector<TensorShape>* shapes) {
436   shapes->resize(named_input_shapes.size());
437   int32_t i = 0;
438   // keys in tpu_input_shapes may be stale.
439   for (const auto& iter : named_input_shapes) {
440     VLOG(2) << iter.first << ", rank: " << iter.second.size();
441     const int64_t rank = iter.second.size();
442     std::vector<int64> dims(rank);
443     for (int64_t d = 0; d < rank; ++d) {
444       VLOG(2) << " dim[" << d << "]: " << iter.second.at(d);
445       dims[d] = iter.second.at(d);
446     }
447     TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(dims, &(*shapes)[i]));
448     i++;
449   }
450   return Status::OK();
451 }
452 
453 // Get the TF fingerprint with the information from the TPUCompileOp or
454 // _TPUCompileMlirOp.
455 Status MaybeRegisterFingerprint(
456     Graph* graph,
457     const std::map<std::string, std::vector<int>>& named_input_shapes,
458     uint64 input_hash) {
459   // Find the compiler metadata.
460   tpu::TPUCompileMetadataProto metadata_proto;
461   std::map<std::string, std::vector<int>> inputs_to_keep;
462   int num_dynamic_shapes = -1;
463   tensorflow::uint64 fingerprint = 0;
464 
465   for (Node* node : graph->op_nodes()) {
466     if (node->type_string() == "TPUCompile" ||
467         node->type_string() == "_TPUCompileMlir") {
468       num_dynamic_shapes = node->attrs().Find("NumDynamicShapes")->i();
469       if (num_dynamic_shapes <= 0) {
470         break;
471       }
472       int visited = 0;
473       // TPUCompileOp/_TPUCompileMlirOp take Shape nodes as inputs.
474       // The number of Shape nodes matches the number of dynamic shaped inputs.
475       // The Shape nodes come from the input nodes:
476       //   [TPU Input] --> [Input Shape] --> [TPUCompileOp]
477       for (auto in_node : node->in_nodes()) {
478         if (in_node->type_string() != "Shape") {
479           continue;
480         }
481         for (auto input_node : in_node->in_nodes()) {
482           auto iter = named_input_shapes.find(input_node->name());
483           if (iter != named_input_shapes.end()) {
484             inputs_to_keep[iter->first] = iter->second;
485           }
486         }
487         visited++;
488         if (visited == num_dynamic_shapes) {
489           break;
490         }
491       }
492       std::string metadata = node->attrs().Find("metadata")->s();
493       metadata_proto.ParseFromString(metadata);
494 
495       if (node->type_string() == "_TPUCompileMlir") {
496         std::string mlir_module = node->attrs().Find("mlir_module")->s();
497         fingerprint = tensorflow::Fingerprint64(mlir_module);
498       } else {
499         fingerprint = metadata_proto.function_library_fingerprint();
500       }
501 
502       break;
503     }
504   }
505   VLOG(2) << "inputs_to_keep size: " << inputs_to_keep.size();
506   if (inputs_to_keep.size() != num_dynamic_shapes) {
507     VLOG(2) << "Cannot match all inputs shapes. Skip fingerprint registration.";
508     return Status::OK();
509   }
510 
511   std::vector<TensorShape> input_shapes;
512   TF_RETURN_IF_ERROR(
513       ConvertEdgeShapesToTensorShapes(inputs_to_keep, &input_shapes));
514 
515   std::vector<TensorShape> arg_shapes;
516   auto status =
517       tpu::ComputeArgumentShapes(metadata_proto, input_shapes, &arg_shapes);
518   if (!status.ok()) {
519     VLOG(2) << status.error_message();
520     return Status::OK();
521   }
522   uint64 tf_fingerprint =
523       tpu::CreateFingerprintWithNameAndShapes(fingerprint, arg_shapes);
524   VLOG(2) << "fingerprint: " << fingerprint;
525   VLOG(2) << "TF fingerprint: " << tf_fingerprint;
526 
527   ResourceMgr* rm = GetTPUConfigResourceMgr();
528   tpu::TpuFingerprintLookup* fingerprint_lookup;
529   TF_RETURN_IF_ERROR(rm->Lookup<tpu::TpuFingerprintLookup>(
530       rm->default_container(), tpu::kFingerprintLookupResourceName,
531       &fingerprint_lookup));
532   fingerprint_lookup->RegisterKeyAndIntermediatePair(input_hash,
533                                                      tf_fingerprint);
534   return Status::OK();
535 }
536 
537 bool FindTpuReplicatedInputAndXlaSharding(
538     const Graph* graph, XlaShardingInfoMap& xla_sharding_ops,
539     TpuReplicatedInputInfoMap& tpu_replicated_input_ops) {
540   bool xla_spmd_input_sharded = false;
541   // Detect whether there is XLA sharding on the inputs; if there is, then
542   // we cannot remove the replicated inputs or the XlaSharding ops.
543   for (Node* xla_sharding_node : graph->nodes()) {
544     if (xla_sharding_node->type_string() == "XlaSharding") {
545       for (const Edge* edge : xla_sharding_node->in_edges()) {
546         if (edge->src()->type_string() == "TPUReplicatedInput") {
547           Node* tpu_replicated_input_node = edge->src();
548           Node* tpu_replicated_metadata_node = nullptr;
549           for (const Edge* input_edge : tpu_replicated_input_node->in_edges()) {
550             if (input_edge->IsControlEdge()) {
551               tpu_replicated_metadata_node = input_edge->src();
552             }
553           }
554 
555           for (const Edge* input_edge : tpu_replicated_input_node->in_edges()) {
556             if (input_edge->src()->type_string() == "_Arg") {
557               Node* arg_node = input_edge->src();
558 
559               xla_sharding_ops[arg_node->name()] = std::make_tuple(
560                   xla_sharding_node->attrs().Find("T")->type(),
561                   xla_sharding_node->attrs().Find("sharding")->s(),
562                   xla_sharding_node->attrs().Find("_tpu_replicate")->s());
563 
564               tpu_replicated_input_ops[arg_node->name()] = std::make_tuple(
565                   tpu_replicated_input_node->attrs().Find("T")->type(),
566                   tpu_replicated_metadata_node);
567 
568               VLOG(2) << "Detected input is sharded. XlaSharding node: "
569                       << xla_sharding_node->DebugString()
570                       << ", TPUReplicatedInput node: "
571                       << edge->src()->DebugString()
572                       << ", _Arg node: " << arg_node->DebugString();
573               xla_spmd_input_sharded = true;
574               break;
575             }
576           }
577         }
578       }
579     }
580   }
581   return xla_spmd_input_sharded;
582 }
583 
584 }  // end namespace
585 
586 namespace tpu_functional_internal {
587 
588 // An optimization pass that separates tensors to leverage the fast path in
589 // TPU input preparation. The algorithm is as follows:
590 // (1) Group all tensors that have the same dimensions except the last dimension. A
591 // group of tensors will be concatenated by the last dimension in a later pass.
592 // (2) Check all groups of tensors and find groups whose dimensions after concat
593 // cannot leverage the fast path.
594 // (3) For groups of tensors that don't leverage the fast path, split tensors
595 // into two sub-groups such that one sub-group of tensors can leverage the fast
596 // path.
597 // An exception in (2) is when the concatenated tensors are small: splitting
598 // such a group would only add data-transfer overhead, so it is left as-is.
599 // This optimization takes effect when both --input_shape_opt and
600 // --group_tensors_for_packing are true.
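// Illustrative example (not from the code): given float inputs of shapes
// [8, 100], [8, 20] and [8, 8], step (1) places all three in one group since
// their leading dimensions match. The concatenated last dimension is 128 and
// the leading product is 8, so the group already satisfies the fast-path
// multiples and is kept as-is. Had the shapes been [8, 100], [8, 30] and
// [8, 128], step (3) would keep [8, 128] in a fast-path sub-group of its own
// and regroup the other two separately.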
601 GroupedEdges GroupTensorsForInputPacking(
602     const EdgeShapes& tpu_input_shapes,
603     const absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
604     bool input_shape_opt, bool group_tensors_for_packing) {
605   GroupedEdges grouped_input_edges;
606   for (const auto& iter : tpu_input_shapes) {
607     if (iter.second.empty()) continue;
608     DataType dtype = tpu_input_dtypes.find(iter.first)->second;
609     string hash_key = HashShapeAndType("input_tensors_dtype_", iter.second,
610                                        dtype, /*input_shape_opt*/ false);
611     grouped_input_edges[hash_key].push_back(iter.first);
612   }
613   // Apply grouping when both are true.
614   if (!input_shape_opt || !group_tensors_for_packing)
615     return grouped_input_edges;
616 
617   GroupedEdges grouped_input_edges_opt;
618   for (const auto& iter : grouped_input_edges) {
619     int sum_last_dim = 0;
620     int product_other_dims = 0;
621     VLOG(3) << "group name: " << iter.first;
622     for (const auto& edge : iter.second) {
623       const std::vector<int>& input_shapes =
624           tpu_input_shapes.find(edge)->second;
625       sum_last_dim += input_shapes.back();
626       if (product_other_dims == 0) {
627         product_other_dims = 1;
628         for (int d = 0; d < input_shapes.size() - 1; ++d)
629           product_other_dims *= input_shapes.at(d);
630       }
631     }
632     VLOG(3) << "sum_last_dim: " << sum_last_dim;
633     VLOG(3) << "product_other_dims: " << product_other_dims;
634     // Already uses fast path, skip further grouping.
635     if ((sum_last_dim % kLastDimOfTpuInputFastPath) == 0 &&
636         (product_other_dims % kOtherDimOfTpuInputFastPath) == 0) {
637       grouped_input_edges_opt[iter.first] = iter.second;
638       continue;
639     }
640     // Tensors are small, skip further grouping.
641     if ((sum_last_dim * product_other_dims) <
642         (kLastDimOfTpuInputFastPath * kOtherDimOfTpuInputFastPath)) {
643       grouped_input_edges_opt[iter.first] = iter.second;
644       continue;
645     }
646     VLOG(3) << "Splitting tensors.";
647     for (const auto& edge : iter.second) {
648       auto tpu_input_shape = tpu_input_shapes.find(edge)->second;
649       string hash_key =
650           HashShapeAndType("input_tensors_dtype_", tpu_input_shape,
651                            tpu_input_dtypes.find(edge)->second,
652                            /*input_shape_opt*/ true);
653       grouped_input_edges_opt[hash_key].push_back(edge);
654     }
655   }
656   return grouped_input_edges_opt;
657 }
658 
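// Groups TPU-to-CPU output edges (edges feeding a TPUReplicatedOutput node) by
// a hash of their inferred shape and dtype, recording each edge's output shape
// in `tpu_output_shapes` along the way.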
659 GroupedEdges GroupTensorsForOutputPacking(Graph* graph,
660                                           EdgeShapes& tpu_output_shapes,
661                                           GraphShapeInfo* shape_info) {
662   GroupedEdges shape_to_output;
663   for (const Edge* edge : graph->edges()) {
664     if (edge->IsControlEdge()) continue;
665 
666     // TPU output edge: src is on TPU and dest is on CPU.
667     if (edge->dst()->type_string() != "TPUReplicatedOutput") continue;
668     if (!shape_info->count(edge->src()->name())) continue;
669 
670     // Output shapes for hashing.
671     std::vector<int>& output_shapes = tpu_output_shapes[edge];
672     output_shapes.clear();
673 
674     int output_id = edge->src_output();
675     auto inferred_shape_vec = shape_info->at(edge->src()->name());
676 
677     for (int d : inferred_shape_vec.at(output_id).shape.dim_sizes()) {
678       output_shapes.push_back(d);
679     }
680 
681     // Hash Shape and Types.
682     DataType dtype = edge->src()->input_type(output_id);
683     string hash_key =
684         HashShapeAndType("output_tensors_dtype_", output_shapes, dtype, false);
685 
686     shape_to_output[hash_key].push_back(edge);
687   }
688   return shape_to_output;
689 }
690 
691 // Concatenates input tensors on CPU along the last dimension if all other
692 // dimensions are the same, and splits them on TPU to reduce input overhead.
693 // `tpu_input_shapes` maps an edge to the shape of its output tensor.
694 // `grouped_input_edges` maps tensor name to all edges output from this tensor.
695 Status CreateConcatAndSplitNodesForInputTensor(
696     Graph* graph, const string& cluster_name, EdgeShapes* tpu_input_shapes,
697     const absl::flat_hash_map<std::string, std::vector<const Edge*>>&
698         grouped_input_edges,
699     int32_t minimum_input_tensors_packing, bool xla_spmd_input_sharded,
700     const XlaShardingInfoMap& xla_sharding_info,
701     const TpuReplicatedInputInfoMap& tpu_replicated_input_info) {
702   for (const auto& iter : grouped_input_edges) {
703     std::vector<int> last_dim_vec;
704     std::vector<NodeBuilder::NodeOut> concat_nodeouts;
705     absl::flat_hash_map<std::string, int> tensor_to_split_output;
706     int rank;
707     DataType dtype = DT_INVALID;
708     std::string src_name;
709     for (const Edge* edge : iter.second) {
710       src_name = edge->src()->name();
711       string tensor_name =
712           absl::StrCat(edge->src()->name(), ":", edge->src_output());
713       // Create a Concat / Split pair for a tensor if one does not exist yet.
714       if (tensor_to_split_output.contains(tensor_name)) continue;
715       tensor_to_split_output[tensor_name] = concat_nodeouts.size();
716       concat_nodeouts.push_back(
717           NodeBuilder::NodeOut(edge->src(), edge->src_output()));
718       dtype = edge->src()->output_type(edge->src_output());
719       rank = tpu_input_shapes->at(edge).size();
720       last_dim_vec.push_back(tpu_input_shapes->at(edge).back());
721     }
722 
723     const int num_tensors = tensor_to_split_output.size();
724     VLOG(3) << iter.first << " num_tensors: " << num_tensors;
725     if (num_tensors < minimum_input_tensors_packing) {
726       VLOG(3) << "skip concat/split " << iter.first;
727       continue;
728     }
729 
730     Node* concat_axis_node = nullptr;
731     TensorShape t_shape;
732     Tensor dim_tensor(DT_INT32, t_shape);
733     // Concat and Split at the last dim.
734     dim_tensor.flat<int>()(0) = rank - 1;
735     TF_RETURN_IF_ERROR(
736         NodeBuilder(strings::StrCat(iter.first, "/concat/axis"), "Const")
737             .Attr("dtype", DT_INT32)
738             .Attr("value", dim_tensor)
739             .Finalize(graph, &concat_axis_node));
740 
741     Node* concat_node = nullptr;
742     TF_RETURN_IF_ERROR(
743         NodeBuilder(strings::StrCat(iter.first, "/concat"), "ConcatV2")
744             .Input(concat_nodeouts)
745             .Input(concat_axis_node)
746             .Attr("T", dtype)
747             .Attr("Tidx", DT_INT32)
748             .Attr("N", num_tensors)
749             .Finalize(graph, &concat_node));
750 
751     Node* split_dim_node = nullptr;
752     TF_RETURN_IF_ERROR(
753         NodeBuilder(strings::StrCat(iter.first, "/split/split_dim"), "Const")
754             .Attr("dtype", DT_INT32)
755             .Attr("value", dim_tensor)
756             .Attr(kTpuReplicateAttr, cluster_name)
757             .Finalize(graph, &split_dim_node));
758 
759     Node* split_vec_node = nullptr;
760     TensorShape split_vec_shape;
761     split_vec_shape.AddDim(1);
762     split_vec_shape.set_dim(0, last_dim_vec.size());
763 
764     Tensor split_vec_tensor(DT_INT32, split_vec_shape);
765     for (int i = 0; i < last_dim_vec.size(); ++i) {
766       split_vec_tensor.flat<int>()(i) = last_dim_vec[i];
767     }
768     VLOG(3) << "split_vec_tensor: " << split_vec_tensor.DebugString();
769 
770     TF_RETURN_IF_ERROR(
771         NodeBuilder(strings::StrCat(iter.first, "/split/vec"), "Const")
772             .Attr("dtype", DT_INT32)
773             .Attr("value", split_vec_tensor)
774             .Attr(kTpuReplicateAttr, cluster_name)
775             .Finalize(graph, &split_vec_node));
776 
777     Node* split_node = nullptr;
778     Node* input_to_split_node = concat_node;
779     Node* output_from_concat_node = nullptr;
780     if (xla_spmd_input_sharded &&
781         tpu_replicated_input_info.count(src_name) > 0 &&
782         xla_sharding_info.count(src_name) > 0) {
783       // Create new TPUReplicatedInput and XlaSharding nodes.
784       //
785       // Rewrite the graph from:
786       //   Concat -> Split
787       // to
788       //   Concat -> TPUReplicatedInput -> XlaSharding -> Split
789       Node* tpu_replicated_input = nullptr;
790       Node* xla_sharding_op = nullptr;
791 
792       std::vector<NodeBuilder::NodeOut> replicated_input;
793       replicated_input.push_back(NodeBuilder::NodeOut(concat_node));
794 
795       // TODO(b/183060455): Add TPUReplicatedInput to all graphs.
796       TF_RETURN_IF_ERROR(
797           NodeBuilder(strings::StrCat(iter.first, "/TPUReplicatedInput"),
798                       "TPUReplicatedInput")
799               .Input(replicated_input)
800               .ControlInput(std::get<1>(tpu_replicated_input_info.at(src_name)))
801               .Attr("N", 1)
802               .Attr("T", std::get<0>(tpu_replicated_input_info.at(src_name)))
803               .Attr("index", -1)
804               .Attr("is_mirrored_variable", false)
805               .Attr("is_packed", false)
806               .Finalize(graph, &tpu_replicated_input));
807       VLOG(2) << "Created new TPUReplicatedInput node "
808               << tpu_replicated_input->DebugString();
809 
810       TF_RETURN_IF_ERROR(
811           NodeBuilder(strings::StrCat(iter.first, "/XlaSharding"),
812                       "XlaSharding")
813               .Input(tpu_replicated_input)
814               .Attr("T", std::get<0>(xla_sharding_info.at(src_name)))
815               .Attr("sharding", std::get<1>(xla_sharding_info.at(src_name)))
816               .Attr("_XlaSharding", std::get<1>(xla_sharding_info.at(src_name)))
817               .Attr("_tpu_replicate",
818                     std::get<2>(xla_sharding_info.at(src_name)))
819               .Finalize(graph, &xla_sharding_op));
820       VLOG(2) << "Created new XLA sharding node "
821               << xla_sharding_op->DebugString();
822 
823       input_to_split_node = xla_sharding_op;
824       output_from_concat_node = tpu_replicated_input;
825     }
826     // Update the `tpu_input_shapes` mapping: Add the new edge
827     // from concat to split.
828     TF_RETURN_IF_ERROR(
829         NodeBuilder(strings::StrCat(iter.first, "/split"), "SplitV")
830             .Input(input_to_split_node)
831             .Input(split_vec_node)
832             .Input(split_dim_node)
833             .Attr("T", dtype)
834             .Attr("num_split", num_tensors)
835             .Attr(kTpuReplicateAttr, cluster_name)
836             .Finalize(graph, &split_node));
837 
838     if (output_from_concat_node == nullptr)
839       output_from_concat_node = split_node;
840 
841     const Edge* concat_to_split;
842     for (const Edge* edge : concat_node->out_edges())
843       if (edge->dst() == output_from_concat_node) {
844         concat_to_split = edge;
845         break;
846       }
847     if (rank > 1) {
848       for (int d = 0; d < rank - 1; ++d)
849         (*tpu_input_shapes)[concat_to_split].push_back(
850             tpu_input_shapes->at(iter.second.back()).at(d));
851     }
852     (*tpu_input_shapes)[concat_to_split].push_back(
853         std::accumulate(last_dim_vec.begin(), last_dim_vec.end(), 0));
854 
855     // Connect split node to original tensor output.
856     for (const Edge* edge : iter.second) {
857       string tensor_name =
858           absl::StrCat(edge->src()->name(), ":", edge->src_output());
859       int output_index = tensor_to_split_output.at(tensor_name);
860       graph->RemoveEdge(edge);
861       graph->AddEdge(split_node, output_index, edge->dst(), edge->dst_input());
862       // Update the `tpu_input_shapes` mapping: Remove old edges.
863       tpu_input_shapes->erase(edge);
864     }
865     VLOG(3) << "Concat node: " << concat_node->DebugString();
866   }
867   return Status::OK();
868 }
869 
870 // Concatenates output tensors on TPU along the last dimension if all other
871 // dimensions are the same, and splits them on CPU to reduce outfeed overhead.
872 // `tpu_inferred_info` maps an edge to the inferred shape of its output tensor.
873 // `shape_to_output` maps tensor name to all edges output from this tensor.
874 Status CreateConcatAndSplitNodesForOutputTensor(
875     Graph* graph, const string& cluster_name, EdgeShapes* tpu_output_shapes,
876     GraphShapeInfo* tpu_inferred_info, GroupedEdges shape_to_output,
877     int32_t minimum_output_tensors_packing) {
878   for (const auto& iter : shape_to_output) {
879     std::vector<int> last_dim_vec;
880     std::vector<NodeBuilder::NodeOut> concat_nodeouts;
881     absl::flat_hash_map<std::string, int> tensor_to_split_output;
882     int rank;
883     DataType dtype = DT_INVALID;
884     for (const Edge* edge : iter.second) {
885       string tensor_name =
886           absl::StrCat(edge->src()->name(), ":", edge->src_output());
887 
888       // Create a Concat / Split pair for a tensor if one does not exist yet.
889       if (tensor_to_split_output.contains(tensor_name)) continue;
890       tensor_to_split_output[tensor_name] = concat_nodeouts.size();
891 
892       concat_nodeouts.push_back(
893           NodeBuilder::NodeOut(edge->src(), edge->src_output()));
894       dtype = edge->src()->output_type(edge->src_output());
895       rank = tpu_output_shapes->at(edge).size();
896       last_dim_vec.push_back(tpu_output_shapes->at(edge).back());
897     }
898 
899     const int num_tensors = tensor_to_split_output.size();
900     if (num_tensors < minimum_output_tensors_packing) {
901       VLOG(3) << "skip concat/split " << iter.first;
902       continue;
903     }
904 
905     Node* concat_axis_node = nullptr;
906     TensorShape t_shape;
907     Tensor dim_tensor(DT_INT32, t_shape);
908     // Concat and Split at the last dim.
909     dim_tensor.flat<int>()(0) = rank - 1;
910     TF_RETURN_IF_ERROR(
911         NodeBuilder(strings::StrCat(iter.first, "/concat/axis"), "Const")
912             .Attr("dtype", DT_INT32)
913             .Attr("value", dim_tensor)
914             .Attr(kTpuReplicateAttr, cluster_name)
915             .Finalize(graph, &concat_axis_node));
916 
917     Node* concat_node = nullptr;
918     TF_RETURN_IF_ERROR(
919         NodeBuilder(strings::StrCat(iter.first, "/concat"), "ConcatV2")
920             .Input(concat_nodeouts)
921             .Input(concat_axis_node)
922             .Attr("T", dtype)
923             .Attr("Tidx", DT_INT32)
924             .Attr("N", num_tensors)
925             .Attr(kTpuReplicateAttr, cluster_name)
926             .Finalize(graph, &concat_node));
927 
928     Node* tpu_replicated_output_node = nullptr;
929     TF_RETURN_IF_ERROR(
930         NodeBuilder(strings::StrCat(iter.first, "/tpu_replicated_output"),
931                     "TPUReplicatedOutput")
932             .Input(concat_node)
933             .Attr("T", dtype)
934             .Attr("num_replicas", 1)
935             .Finalize(graph, &tpu_replicated_output_node));
936 
937     Node* split_dim_node = nullptr;
938     TF_RETURN_IF_ERROR(
939         NodeBuilder(strings::StrCat(iter.first, "/split/split_dim"), "Const")
940             .Attr("dtype", DT_INT32)
941             .Attr("value", dim_tensor)
942             .Finalize(graph, &split_dim_node));
943 
944     Node* split_vec_node = nullptr;
945     TensorShape split_vec_shape;
946     split_vec_shape.AddDim(1);
947     split_vec_shape.set_dim(0, last_dim_vec.size());
948 
949     Tensor split_vec_tensor(DT_INT32, split_vec_shape);
950     for (int i = 0; i < last_dim_vec.size(); ++i) {
951       split_vec_tensor.flat<int>()(i) = last_dim_vec[i];
952     }
953     VLOG(3) << "split_vec_tensor: " << split_vec_tensor.DebugString();
954 
955     TF_RETURN_IF_ERROR(
956         NodeBuilder(strings::StrCat(iter.first, "/split/vec"), "Const")
957             .Attr("dtype", DT_INT32)
958             .Attr("value", split_vec_tensor)
959             .Finalize(graph, &split_vec_node));
960 
961     Node* split_node = nullptr;
962     TF_RETURN_IF_ERROR(
963         NodeBuilder(strings::StrCat(iter.first, "/split"), "SplitV")
964             .Input(tpu_replicated_output_node)
965             .Input(split_vec_node)
966             .Input(split_dim_node)
967             .Attr("T", dtype)
968             .Attr("num_split", num_tensors)
969             .Finalize(graph, &split_node));
970 
971     // Update the `tpu_output_shapes` mapping: Add the new edge
972     // from concat to split.
973     const Edge* concat_to_split;
974     for (const Edge* edge : concat_node->out_edges())
975       if (edge->dst() == split_node) {
976         concat_to_split = edge;
977         break;
978       }
979 
980     if (rank > 1) (*tpu_output_shapes)[concat_to_split].push_back(-1);
981     for (int d = 1; d < rank - 1; ++d)
982       (*tpu_output_shapes)[concat_to_split].push_back(
983           tpu_output_shapes->at(iter.second.back()).at(d));
984     (*tpu_output_shapes)[concat_to_split].push_back(
985         std::accumulate(last_dim_vec.begin(), last_dim_vec.end(), 0));
986 
987     for (const Edge* edge : iter.second) {
988       // 1. Find old TPUReplicatedOutput output edges.
989       std::vector<const Edge*> output_edge_vec;
990       for (const Edge* output_edge : edge->dst()->out_edges())
991         output_edge_vec.push_back(output_edge);
992 
993       string tensor_name =
994           absl::StrCat(edge->src()->name(), ":", edge->src_output());
995       int output_index = tensor_to_split_output.at(tensor_name);
996       VLOG(3) << "output_index: " << output_index;
997 
998       // Connect split node to original tensor output.
999       for (const Edge* output_edge : output_edge_vec) {
1000         VLOG(3) << "output_edge" << output_edge->DebugString();
1001         graph->RemoveEdge(output_edge);
1002         graph->AddEdge(split_node, output_index, output_edge->dst(),
1003                        output_edge->dst_input());
1004         // Update the `tpu_output_shapes` mapping: Remove old edges.
1005         tpu_output_shapes->erase(output_edge);
1006       }
1007       graph->RemoveNode(edge->dst());
1008     }
1009     VLOG(3) << "Concat node: " << concat_node->DebugString();
1010   }
1011   return Status::OK();
1012 }
1013 
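// For every shaped TPU input edge, inserts a host-side Reshape that flattens
// the tensor to rank 1 and a matching TPU-side Reshape that restores the
// original shape, so only the flattened tensor crosses the host-to-TPU
// boundary. Updates `tpu_input_shapes` and, for sharded inputs, rewrites the
// XlaSharding attribute to shard the single flattened dimension.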
1014 Status InsertReshapeNodePairs(Graph* graph, const string& cluster_name,
1015                               EdgeShapes* tpu_input_shapes,
1016                               int num_cores_per_replica) {
1017   std::vector<const Edge*> tpu_input_edges_original;
1018   for (const auto& it : *tpu_input_shapes)
1019     if (!it.second.empty()) tpu_input_edges_original.push_back(it.first);
1020   for (const Edge* edge : tpu_input_edges_original) {
1021     VLOG(3) << "Reshape input: " << edge->DebugString();
1022 
1023     // Check if there is a TPUReplicatedInput and XlaSharding in the middle
1024     bool xla_sharded_input = false;
1025     Node* xla_sharding_node = nullptr;
1026     if (edge->dst()->type_string() == "TPUReplicatedInput" &&
1027         edge->dst()->out_nodes().begin()->type_string() == "XlaSharding") {
1028       VLOG(3) << "Detected TPUReplicatedInput " << edge->dst()->DebugString()
1029               << " and XlaSharding "
1030               << edge->dst()->out_nodes().begin()->DebugString()
1031               << ", setting xla_sharded_input = true";
1032       xla_sharded_input = true;
1033       xla_sharding_node = *(edge->dst()->out_nodes().begin());
1034     }
1035 
1036     // 1. Build Reshape node for flatten.
1037 
1038     // 1.1 Build Const node for shape
1039     Node* flatten_reshape_shape_node = nullptr;
1040     Tensor flattened_input_shape_tensor;
1041     flattened_input_shape_tensor =
1042         Tensor(DT_INT32, TensorShape({static_cast<int64>(1)}));
1043     flattened_input_shape_tensor.flat<int>()(0) = -1;
1044     TF_RETURN_IF_ERROR(
1045         NodeBuilder(absl::StrCat(edge->src()->name(), "/flatten/Reshape/shape"),
1046                     "Const")
1047             .Attr("dtype", DT_INT32)
1048             .Attr("value", flattened_input_shape_tensor)
1049             .Finalize(graph, &flatten_reshape_shape_node));
1050 
1051     // 1.2 Build Reshape node for flatten.
1052     Node* flatten_reshape_node = nullptr;
1053     TF_RETURN_IF_ERROR(
1054         NodeBuilder(absl::StrCat(edge->src()->name(), "/flatten/Reshape"),
1055                     "Reshape")
1056             .Input(edge->src(), edge->src_output())
1057             .Input(flatten_reshape_shape_node)
1058             .Attr("T", edge->src()->output_type(edge->src_output()))
1059             .Attr("Tshape", DT_INT32)
1060             .Finalize(graph, &flatten_reshape_node));
1061 
1062     // 2. Build Reshape node for recover.
1063 
1064     // 2.1 Build Const node for shape.
1065     Node* recover_reshape_shape_node = nullptr;
1066     Tensor original_input_shape_tensor(
1067         DT_INT32,
1068         TensorShape({static_cast<int64>(tpu_input_shapes->at(edge).size())}));
1069     original_input_shape_tensor.flat<int>()(0) = -1;
1070     for (int d = 1; d < tpu_input_shapes->at(edge).size(); ++d)
1071       original_input_shape_tensor.flat<int>()(d) =
1072           tpu_input_shapes->at(edge).at(d);
1073     TF_RETURN_IF_ERROR(
1074         NodeBuilder(absl::StrCat(edge->src()->name(), "/recover/Reshape/shape"),
1075                     "Const")
1076             .Attr("dtype", DT_INT32)
1077             .Attr("value", original_input_shape_tensor)
1078             .Attr(kTpuReplicateAttr, cluster_name)  // This node is on TPU.
1079             .Finalize(graph, &recover_reshape_shape_node));
1080 
1081     // 2.2 Build Reshape node for recover.
1082     Node* recover_reshape_input_node = flatten_reshape_node;
1083     const Edge* original_recover_reshape_input_edge = nullptr;
1084     if (xla_sharded_input) {
1085       // We want to find the node after the XlaSharding node
1086       original_recover_reshape_input_edge =
1087           *(edge->dst()->out_nodes().begin()->out_edges().begin());
1088       recover_reshape_input_node = *(edge->dst()->out_nodes().begin());
1089       VLOG(3) << "Recover reshape input node: "
1090               << recover_reshape_input_node->DebugString()
1091               << ", recover reshape input edge: "
1092               << original_recover_reshape_input_edge->DebugString();
1093     }
1094 
1095     Node* recover_reshape_node = nullptr;
1096     TF_RETURN_IF_ERROR(
1097         NodeBuilder(absl::StrCat(edge->src()->name(), "/recover/Reshape"),
1098                     "Reshape")
1099             .Input(recover_reshape_input_node)
1100             .Input(recover_reshape_shape_node)
1101             .Attr("T", edge->src()->output_type(edge->src_output()))
1102             .Attr("Tshape", DT_INT32)
1103             .Attr(kTpuReplicateAttr, cluster_name)  // This node is on TPU.
1104             .Finalize(graph, &recover_reshape_node));
1105 
1106     // 3. Rewrite XlaSharding attribute if necessary
1107     if (xla_sharding_node != nullptr) {
1108       // The flattened tensor always has rank = 1 and we want to shard the only
1109       // dimension (0).
1110       SetXlaShardingNodeAttr(xla_sharding_node, num_cores_per_replica, 1, 0);
1111     }
1112 
1113     // 4. Connect / disconnect nodes.
1114     if (xla_sharded_input) {
1115       graph->AddEdge(flatten_reshape_node, 0, edge->dst(), edge->dst_input());
1116     }
1117 
1118     if (original_recover_reshape_input_edge != nullptr) {
1119       graph->AddEdge(recover_reshape_node, 0,
1120                      original_recover_reshape_input_edge->dst(),
1121                      original_recover_reshape_input_edge->dst_input());
1122     } else {
1123       graph->AddEdge(recover_reshape_node, 0, edge->dst(), edge->dst_input());
1124     }
1125 
1126     graph->RemoveEdge(edge);
1127     if (original_recover_reshape_input_edge != nullptr) {
1128       graph->RemoveEdge(original_recover_reshape_input_edge);
1129     }
1130 
1131     // 5. Update EdgeShapes.
1132     int dimension = 1;
1133     for (auto& it : (*tpu_input_shapes)[edge]) {
1134       dimension *= it;
1135     }
1136     VLOG(3) << "Dimension after reshape: " << dimension;
1137     for (const Edge* out_edge : flatten_reshape_node->out_edges()) {
1138       if (out_edge->dst() == recover_reshape_node) {
1139         (*tpu_input_shapes)[out_edge].push_back(dimension);
1140         tpu_input_shapes->erase(edge);
1141         break;
1142       }
1143     }
1144     VLOG(3) << "Reshape optimization done for " << edge->src()->name();
1145   }
1146   return Status::OK();
1147 }
1148 }  // namespace tpu_functional_internal
1149 
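// Per-invocation entry point: lazily captures the function library, builds the
// TPU ordinal selector on first use, hashes the inputs and selects a device
// ordinal, then on a cache miss rewrites the function graph (input/output
// tensor optimizations, resource-arg replacement), runs placement, and
// partitions it per device before execution.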
1150 void TPUPartitionedCallOp::ComputeAsync(OpKernelContext* ctx,
1151                                         DoneCallback done) {
1152   Status init_status;
1153   absl::call_once(once_, [&]() {
1154     library_runtime_ = ctx->function_library();
1155     if (library_runtime_ == nullptr) {
1156       init_status = errors::Internal("No function library is provided.");
1157       return;
1158     }
1159     flib_def_ = std::make_unique<FunctionLibraryDefinition>(
1160         *library_runtime_->GetFunctionLibraryDefinition());
1161     device_mgr_ = library_runtime_->device_mgr();
1162     for (auto d : device_mgr_->ListDevices()) {
1163       device_set_.AddDevice(d);
1164     }
1165 
1166     DeviceNameUtils::ParsedName tpu_device_name;
1167     tpu_device_name.has_type = true;
1168     tpu_device_name.type = "TPU";
1169     std::vector<Device*> tpu_devices;
1170     device_set_.FindMatchingDevices(tpu_device_name, &tpu_devices_);
1171   });
1172   OP_REQUIRES_OK_ASYNC(ctx, init_status, done);
1173 
1174   // Initialize the ordinal selector with information from the graph if it is
1175   // the first time we are running this op.
1176   absl::call_once(ordinal_selector_once_, [&]() {
1177     std::unique_ptr<Graph> graph(new Graph(flib_def_.get()));
1178     int num_cores_per_replica = 1;
1179     bool enable_spmd_xla_partitioning = false;
1180     {
1181       absl::MutexLock l(&mu_);
1182       OP_REQUIRES_OK_ASYNC(
1183           ctx,
1184           GetGraphFromFunction(graph.get(), /*device_ordinal=*/0,
1185                                &num_cores_per_replica,
1186                                &enable_spmd_xla_partitioning),
1187           done);
1188     }
1189     if (enable_spmd_xla_partitioning) {
1190       ordinal_selector_ =
1191           std::make_shared<tpu::TPUOrdinalSelector>(num_cores_per_replica);
1192     } else {
1193       ordinal_selector_ = std::make_shared<tpu::TPUOrdinalSelector>();
1194     }
1195 
1196     metrics::RecordTPUXlaSpmdCoresPerReplica(num_cores_per_replica);
1197   });
1198   uint64 input_hash = GetInputHash(ctx);
1199   int64_t ordinal_selector_req_id = -1;
1200   // Select a TPU core.
1201   absl::ReleasableMutexLock lock(&mu_);
1202   int32_t device_ordinal = 0;
1203   OP_REQUIRES_OK_ASYNC(
1204       ctx,
1205       GetTpuCoreOrdinal(ctx, input_hash, &ordinal_selector_req_id,
1206                         &device_ordinal),
1207       done);
1208   uint64 cache_hash = Hash64Combine(input_hash, device_ordinal);
1209 
1210   const std::vector<DeviceAndFHandle>* functions;
1211 
1212   bool cache_miss = !partition_cache_.count(cache_hash);
1213   if (cache_miss) {
1214     VLOG(3) << "Cache Miss: partitioning function " << func_.name()
1215             << " cache_hash: " << cache_hash
1216             << " device_ordinal: " << device_ordinal;
1217 
1218     profiler::TraceMe trace_me(
1219         "TPUPartitionedCallOp-RewriteAndInstantiateFunctions");
1220     std::unique_ptr<Graph> graph(new Graph(flib_def_.get()));
1221     int num_cores_per_replica = 1;
1222     bool enable_spmd_xla_partitioning = false;
1223     OP_REQUIRES_OK_ASYNC(ctx,
1224                          GetGraphFromFunction(graph.get(), device_ordinal,
1225                                               &num_cores_per_replica,
1226                                               &enable_spmd_xla_partitioning),
1227                          done);
1228 
1229     VLOG(1) << DumpGraphToFile("before_input_output_optimizations", *graph,
1230                                flib_def_.get());
1231 
1232     std::map<std::string, std::vector<int>> named_input_shapes;
1233     OP_REQUIRES_OK_ASYNC(ctx,
1234                          OptimizeTpuInputOutputTensors(
1235                              graph.get(), enable_spmd_xla_partitioning,
1236                              num_cores_per_replica, named_input_shapes, ctx),
1237                          done);
1238 
1239     VLOG(1) << DumpGraphToFile(
1240         "before_replace_resource_args_with_var_handle_ops", *graph,
1241         flib_def_.get());
1242     OP_REQUIRES_OK_ASYNC(
1243         ctx,
1244         ReplaceResourceArgsWithVarHandleOps(graph.get(), ctx, device_ordinal,
1245                                             num_cores_per_replica,
1246                                             enable_spmd_xla_partitioning),
1247         done);
1248 
1249     VLOG(1) << DumpGraphToFile(
1250         "after_replace_resource_args_with_var_handle_ops", *graph,
1251         flib_def_.get());
1252 
1253     // Graph rewrite passes.
1254     GraphOptimizationPassOptions optimization_options;
1255     // TODO(akshayka): Thread the SessionOptions into this kernel, or make
1256     // it possible to specify the relevant options via attributes.
1257     SessionOptions session_options;
1258     session_options.config.mutable_experimental()
1259         ->set_xla_fusion_autotuner_thresh(autotuner_thresh_);
1260 
1261     session_options.env = ctx->env();
1262     optimization_options.session_handle = ctx->session_handle();
1263     optimization_options.session_options = &session_options;
1264     optimization_options.graph = &graph;
1265     optimization_options.flib_def = flib_def_.get();
1266     optimization_options.device_set = &device_set_;
1267     OP_REQUIRES_OK_ASYNC(
1268         ctx, PlacementHelper(device_set_, optimization_options, func_.name()),
1269         done);
1270 
1271     if (!enable_spmd_xla_partitioning || num_cores_per_replica == 1) {
1272       OP_REQUIRES_OK_ASYNC(
1273           ctx,
1274           MaybeRegisterFingerprint(graph.get(), named_input_shapes, input_hash),
1275           done);
1276     }
1277     // `subgraphs` maps from device names to functions.
1278     std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs;
1279     optimization_options.graph = nullptr;
1280     optimization_options.device_set = nullptr;
1281     optimization_options.partition_graphs = &subgraphs;
1282     VLOG(1) << DumpGraphToFile("before_partition_helper.pbtxt", *graph,
1283                                flib_def_.get());
1284     OP_REQUIRES_OK_ASYNC(ctx,
1285                          PartitionHelper(device_set_, optimization_options,
1286                                          graph.get(), &subgraphs),
1287                          done);
1288     OP_REQUIRES_OK_ASYNC(ctx,
1289                          InstantiateFunctionsFromSubgraphs(
1290                              device_set_, device_ordinal, cache_hash,
1291                              num_cores_per_replica, std::move(subgraphs)),
1292                          done);
1293   }
1294   functions = &partition_cache_[cache_hash];
1295   lock.Release();
1296 
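       // Release the lock before executing so that concurrent requests are not
       // serialized behind the (potentially long-running) TPU execution.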
1297   ExecuteFunctions(*functions, ctx, device_ordinal, ordinal_selector_req_id,
1298                    std::move(done));
1299 }
1300 
1301 Status TPUPartitionedCallOp::GetTpuCoreOrdinal(OpKernelContext* ctx,
1302                                                uint64 input_hash,
1303                                                int64_t* ordinal_selector_req_id,
1304                                                int32_t* core_ordinal) {
1305   profiler::TraceMe trace_me("TPUPartitionedCallOp-GetTpuCoreOrdinal");
1306   const Tensor* device_ordinal_t;
1307   TF_RETURN_IF_ERROR(ctx->input(kDeviceOrdinalAttr, &device_ordinal_t));
1308   int device_ordinal = device_ordinal_t->scalar<int>()();
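       // A reserved sentinel value means the caller did not pin a specific
       // core; defer the choice to the ordinal selector (see TPUOrdinalSelector
       // for the selection policy).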
1309   if (device_ordinal == tpu::kDeferredCoreSelectionReserved) {
1310     device_ordinal =
1311         ordinal_selector_->GetOrdinal(input_hash, ordinal_selector_req_id);
1312   }
1313   *core_ordinal = device_ordinal;
1314   return Status::OK();
1315 }
1316 
1317 Status TPUPartitionedCallOp::InitializeVarOnTPU(
1318     OpKernelContext* ctx, const core::RefCountPtr<Var>& var, NodeDef* ndef,
1319     int device_ordinal, bool fast_mem) {
1320   const string device = strings::StrCat(kTPUDeviceNamePrefix, device_ordinal);
1321   Status status;
1322   std::unique_ptr<Graph> init_graph(new Graph(OpRegistry::Global()));
1323   Node* init_handle = init_graph->AddNode(*ndef, &status);
1324   TF_RETURN_IF_ERROR(status);
1325   init_handle->set_assigned_device_name(device);
1326 
1327   NodeDef init_const_ndef;
1328   init_const_ndef.set_name("initial_value");
1329   if (fast_mem) {
1330     init_const_ndef.set_op("_TPUConst");
1331     AddNodeAttr("memory_space", "FastMem", &init_const_ndef);
1332   } else {
1333     init_const_ndef.set_op("Const");
1334   }
1335   init_const_ndef.set_device(device);
1336   AddNodeAttr("dtype", var->tensor()->dtype(), &init_const_ndef);
1337   AddNodeAttr("value", *var->tensor(), &init_const_ndef);
1338 
1339   Node* init_const = init_graph->AddNode(init_const_ndef, &status);
1340   TF_RETURN_IF_ERROR(status);
1341 
1342   NodeDef assign_node_def;
1343   assign_node_def.set_name("Assign");
1344   assign_node_def.set_op("AssignVariableOp");
1345   assign_node_def.set_device(device);
1346   AddNodeAttr("dtype", var->tensor()->dtype(), &assign_node_def);
1347   Node* init_assign = init_graph->AddNode(assign_node_def, &status);
1348   TF_RETURN_IF_ERROR(status);
1349 
1350   init_graph->AddEdge(init_handle, 0, init_assign, 0);
1351   init_graph->AddEdge(init_const, 0, init_assign, 1);
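       // The initialization graph built above is simply:
       //
       //   VarHandleOp (copy of `ndef`) --> AssignVariableOp <-- Const / _TPUConst
       //                                                         (initial value)
       //
       // It is instantiated as a one-off function, run once on the target TPU
       // device below, and then removed from the function library.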
1352   FHandle fhandle;
1353   const string fname =
1354       strings::StrCat(ndef->name(), "_init_ord_", device_ordinal);
1355 
1356   TF_RETURN_IF_ERROR(
1357       InstantiatePartition(*init_graph, fname, device, &fhandle, nullptr));
1358 
1359   FunctionLibraryRuntime::Options opts;
1360   opts.step_container = ctx->step_container();
1361   opts.cancellation_manager = ctx->cancellation_manager();
1362   opts.stats_collector = ctx->stats_collector();
1363 
1364   // Blocking on threads in the same thread pool is disallowed because
1365   // concurrent warm-up requests can exhaust the default thread pool.
1366   // Create a new thread pool to initialize variables on TPU.
1367   std::function<void(std::function<void()>)> runner =
1368       [this](std::function<void()> fn) { pool_.Schedule(fn); };
1369   opts.runner = &runner;
1370 
1371   opts.source_device = local_device_name_;
1372   PrivateIntraProcessRendezvous rendez(device_mgr_);
1373   opts.rendezvous = &rendez;
1374   opts.remote_execution = true;
1375 
1376   std::vector<Tensor> dummy_args;
1377   std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
1378   Notification done;
1379   profiler::TraceMe trace_me("TPUPartitionedCallOp-InitializeVarOnTPU");
1380   library_runtime_->Run(opts, fhandle, dummy_args, dummy_rets,
1381                         [dummy_rets, &done, ctx](const Status& status) {
1382                           if (!status.ok()) {
1383                             ctx->SetStatus(status);
1384                           }
1385                           delete dummy_rets;
1386                           done.Notify();
1387                         });
1388   done.WaitForNotification();
1389   // We don't actually want the variable initialization functions
1390   // in the function library definition and the function library
1391   // runtime, because flib_def_ is used for the graph rewrite passes.
1392   // The TPU distributed rewrite pass computes a fingerprint for
1393   // flib_def_, which will throw a length error if there are
1394   // many variables whose initialization functions are added
1395   // to the library definition.
1396   TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fname));
1397   TF_RETURN_IF_ERROR(library_runtime_->ReleaseHandle(fhandle));
1398   return Status::OK();
1399 }
1400 
1401 Status TPUPartitionedCallOp::InitializeShardedVarOnTPU(
1402     OpKernelContext* ctx, const core::RefCountPtr<Var>& var,
1403     std::vector<NodeDef>& ndefs, int split_dim, int device_ordinal) {
1404   std::unique_ptr<Graph> init_graph(new Graph(OpRegistry::Global()));
1405   int num_cores = ndefs.size();
1406   string cpu_device = "/device:CPU:0";
1407 
1408   Status status;
1409   std::vector<std::string> devices;
1410   std::vector<Node*> init_handles;
1411   for (int i = 0; i < num_cores; i++) {
1412     Node* init_handle = init_graph->AddNode(ndefs[i], &status);
1413     TF_RETURN_IF_ERROR(status);
1414     string device = strings::StrCat(kTPUDeviceNamePrefix, device_ordinal + i);
1415     init_handle->set_assigned_device_name(device);
1416     devices.push_back(device);
1417     init_handles.push_back(init_handle);
1418   }
1419 
1420   NodeDef init_const_ndef;
1421   init_const_ndef.set_name("initial_value");
1422   init_const_ndef.set_op("Const");
1423   init_const_ndef.set_device(cpu_device);
1424   AddNodeAttr("dtype", var->tensor()->dtype(), &init_const_ndef);
1425   AddNodeAttr("value", *var->tensor(), &init_const_ndef);
1426   Node* init_const = init_graph->AddNode(init_const_ndef, &status);
1427   TF_RETURN_IF_ERROR(status);
1428   init_const->set_assigned_device_name(cpu_device);
1429 
1430   Node* assign_value_node = init_const;
1431   // If the variable is sharded, we will insert a "Split" node between the
1432   // initial value and the AssignVariableOps, so the variable on each TPU
1433   // device gets assigned its slice of the split value.
1434   //
1435   // initial_value--Split--AssignVariableOp ("/device:TPU:0")
1436   //                  |
1437   //            AssignVariableOp ("/device:TPU:1")
1438   if (split_dim >= 0) {
1439     // Add a split dimension node.
1440     NodeDef split_dim_def;
1441     split_dim_def.set_name("initial_value_split_dim");
1442     split_dim_def.set_op("Const");
1443     split_dim_def.set_device(cpu_device);
1444     AddNodeAttr("dtype", DT_INT32, &split_dim_def);
1445     TensorProto tensor_proto;
1446     tensor_proto.set_dtype(DT_INT32);
1447     tensor_proto.add_int_val(split_dim);
1448     TensorShape shape({});
1449     shape.AsProto(tensor_proto.mutable_tensor_shape());
1450     AddNodeAttr("value", tensor_proto, &split_dim_def);
1451     Node* split_dim_node = init_graph->AddNode(split_dim_def, &status);
1452     TF_RETURN_IF_ERROR(status);
1453     split_dim_node->set_assigned_device_name(cpu_device);
1454 
1455     // Add a split node.
1456     NodeDef split_def;
1457     int split_num = ndefs.size();
1458     split_def.set_name("initial_value_split");
1459     split_def.set_op("Split");
1460     split_def.set_device(cpu_device);
1461     AddNodeAttr("num_split", split_num, &split_def);
1462     AddNodeAttr("T", var->tensor()->dtype(), &split_def);
1463     split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
1464     split_def.add_input(absl::StrCat(init_const->name(), ":0"));
1465     Node* split_node = init_graph->AddNode(split_def, &status);
1466     TF_RETURN_IF_ERROR(status);
1467     split_node->set_assigned_device_name(cpu_device);
1468 
1469     init_graph->AddEdge(split_dim_node, 0, split_node, 0);
1470     init_graph->AddEdge(init_const, 0, split_node, 1);
1471 
1472     assign_value_node = split_node;
1473   }
1474 
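       // Illustrative example: a [8, 128] variable sharded along dim 0 across
       // 2 cores is Split into two [4, 128] slices, and output i of the Split
       // feeds the AssignVariableOp on core (device_ordinal + i). For an
       // unsharded variable every core is assigned the full initial value.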
1475   for (int i = 0; i < num_cores; i++) {
1476     NodeDef assign_node_def;
1477     assign_node_def.set_name(absl::StrCat("Assign_", i));
1478     assign_node_def.set_op("AssignVariableOp");
1479     assign_node_def.set_device(devices[i]);
1480     AddNodeAttr("dtype", var->tensor()->dtype(), &assign_node_def);
1481     Node* init_assign = init_graph->AddNode(assign_node_def, &status);
1482     TF_RETURN_IF_ERROR(status);
1483     init_assign->set_assigned_device_name(devices[i]);
1484 
1485     init_graph->AddEdge(init_handles[i], 0, init_assign, 0);
1486     if (split_dim >= 0) {
1487       init_graph->AddEdge(assign_value_node, i, init_assign, 1);
1488     } else {
1489       init_graph->AddEdge(assign_value_node, 0, init_assign, 1);
1490     }
1491   }
1492 
1493   GraphOptimizationPassOptions optimization_options;
1494   SessionOptions session_options;
1495   session_options.env = ctx->env();
1496   optimization_options.session_handle = ctx->session_handle();
1497   optimization_options.session_options = &session_options;
1498   optimization_options.flib_def = flib_def_.get();
1499   optimization_options.graph = nullptr;
1500   optimization_options.device_set = nullptr;
1501   std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs;
1502   optimization_options.partition_graphs = &subgraphs;
1503   TF_RETURN_IF_ERROR(PartitionHelper(device_set_, optimization_options,
1504                                      init_graph.get(), &subgraphs));
1505 
1506   std::vector<DeviceAndFHandle> functions;
1507   std::vector<std::string> function_names;
1508   for (auto& pair : subgraphs) {
1509     string target = pair.first;
1510     Device* device;
1511     TF_RETURN_IF_ERROR(
1512         library_runtime_->device_mgr()->LookupDevice(target, &device));
1513     Graph* subgraph = pair.second.get();
1514     string function_name = flib_def_->UniqueFunctionName(
1515         strings::StrCat(func_.name(), "_hash_", pair.first));
1516     function_names.push_back(function_name);
1517     FHandle handle;
1518     TF_RETURN_IF_ERROR(InstantiatePartition(*subgraph, function_name, target,
1519                                             &handle, nullptr));
1520     functions.push_back(DeviceAndFHandle{.device = target, .handle = handle});
1521   }
1522 
1523   FunctionLibraryRuntime::Options opts;
1524 
1525   // Blocking on threads in the same thread pool is disallowed because
1526   // concurrent warm-up requests can exhaust the default thread pool.
1527   // Create a new thread pool to initialize variables on TPU.
1528   std::function<void(std::function<void()>)> runner =
1529       [this](std::function<void()> fn) { pool_.Schedule(fn); };
1530   opts.runner = &runner;
1531 
1532   opts.step_container = ctx->step_container();
1533   opts.cancellation_manager = ctx->cancellation_manager();
1534   opts.stats_collector = ctx->stats_collector();
1535   opts.source_device = local_device_name_;
1536   opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
1537 
1538   OpInputList arguments;
1539   TF_RETURN_IF_ERROR(ctx->input_list("args", &arguments));
1540 
1541   auto* rendez = new PrivateIntraProcessRendezvous(device_mgr_);
1542   opts.rendezvous = rendez;
1543 
1544   BlockingCounter bcount(functions.size());
1545   for (const DeviceAndFHandle& entry : functions) {
1546     const string& target_device = entry.device;
1547     FHandle handle = entry.handle;
1548 
1549     TF_RETURN_IF_ERROR(
1550         ShouldUseRemoteExecutionForFn(target_device, &(opts.remote_execution)));
1551     std::vector<Tensor> dummy_args;
1552     std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
1553 
1554     profiler::TraceMe trace_me(
1555         "TPUPartitionedCallOp-InitializeShardedVarOnTPU");
1556     library_runtime_->Run(opts, handle, dummy_args, dummy_rets,
1557                           [dummy_rets, &bcount, ctx](const Status& status) {
1558                             if (!status.ok()) {
1559                               ctx->SetStatus(status);
1560                             }
1561                             delete dummy_rets;
1562                             bcount.DecrementCount();
1563                           });
1564   }
1565   bcount.Wait();
1566 
1567   for (int i = 0; i < functions.size(); i++) {
1568     TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(function_names[i]));
1569     TF_RETURN_IF_ERROR(library_runtime_->ReleaseHandle(functions[i].handle));
1570   }
1571   return Status::OK();
1572 }
1573 
1574 bool TPUPartitionedCallOp::IsInputToTPUReplicate(Node* node) {
1575   for (Node* successor : node->out_nodes()) {
1576     if (successor->attrs().Find(kTpuReplicateAttr) != nullptr) {
1577       return true;
1578     }
1579   }
1580   return false;
1581 }
1582 
1583 Status TPUPartitionedCallOp::ReplaceResourceArgsWithVarHandleOps(
1584     Graph* graph, OpKernelContext* ctx, int device_ordinal,
1585     int num_cores_per_replica, bool enable_spmd_xla_partitioning) {
1586   // Currently variable deduplication is not supported for XLA SPMD
1587   // partitioning. It may be supported in the future.
1588   bool enable_variable_deduplication =
1589       runtime_params_.enable_variable_deduplication;
1590   if (enable_spmd_xla_partitioning && num_cores_per_replica > 1) {
1591     // If enable_spmd_xla_partitioning is true, the user set the
1592     // enable_auto_xla_input_sharding flag. Warn them that only one of the flags
1593     // can be set safely when num_cores_per_replica > 1. If
1594     // num_cores_per_replica==1, enable_spmd_xla_partitioning is effectively a
1595     // no-op so we can skip this check.
1596     LOG(WARNING) << "Disabling variable deduplication because it is not "
1597                     "compatible with enable_auto_xla_input_sharding.";
1598     enable_variable_deduplication = false;
1599   }
1600   std::vector<Node*> tpu_resource_args;
1601   std::vector<int> arg_indices;
1602   absl::flat_hash_map<const Node*, xla::OpSharding> variable_to_xla_sharding;
1603   for (Node* node : graph->op_nodes()) {
1604     if (node->IsArg()) {
1605       const AttrValue* attr_value;
1606       TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
1607       DataType dtype = attr_value->type();
1608       if (dtype == DT_RESOURCE && IsInputToTPUReplicate(node)) {
1609         // If this VarHandleOp is used by a TPU computation,
1610         // we need to create a TPU version of the variable.
1611         TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
1612         int index = attr_value->i();
1613         tpu_resource_args.push_back(node);
1614         arg_indices.push_back(index);
1615         replaced_input_indices_[index] = true;
1616       }
1617     }
1618   }
1619 
1620   VLOG(3) << "tpu_resource_args.size(): " << tpu_resource_args.size();
1621   // Create a mapping from ResourceHandle to variable node. When a
1622   // ResourceHandle backs several variable nodes, the variable nodes refer to
1623   // the same underlying resource. In that case, only one variable node needs
1624   // to be mirrored to the TPU for that resource.
1625   absl::flat_hash_map<uint64, Node*> tpu_variables;
1626   for (int i = 0; i < tpu_resource_args.size(); i++) {
1627     Node* node = tpu_resource_args[i];
1628     ResourceHandle handle = HandleFromInput(ctx, arg_indices[i]);
1629 
1630     if (num_cores_per_replica > 1 && enable_spmd_xla_partitioning) {
1631       TF_RETURN_IF_ERROR(ReplaceAndPartitionXLAShardingVariable(
1632           graph, ctx, device_ordinal, handle, node, num_cores_per_replica));
1633       continue;
1634     }
1635     TPUVariableInfo var_info(/*device_ordinal_id=*/0, /*use_fast_mem=*/false);
1636     TF_RETURN_IF_ERROR(
1637         ParseTPUVariableInfor(node, num_cores_per_replica, &var_info));
1638     // Only respect the graph's placement when model parallelism is enabled.
1639     if (num_cores_per_replica > 1) device_ordinal = var_info.device_ordinal;
1640 
1641     const uint64 handle_fp =
1642         Fingerprint64(strings::StrCat(handle.container(), handle.name()));
1643     if (enable_variable_deduplication && tpu_variables.contains(handle_fp) &&
1644         num_cores_per_replica == 1) {
1645       Node* tpu_variable = tpu_variables.at(handle_fp);
1646       std::vector<Node*> dst_nodes;
1647       std::vector<int> src_indices;
1648       std::vector<int> dst_indices;
1649       for (const Edge* edge : node->out_edges()) {
1650         dst_nodes.push_back(edge->dst());
1651         src_indices.push_back(edge->src_output());
1652         dst_indices.push_back(edge->dst_input());
1653       }
1654       graph->RemoveNode(node);
1655       for (int i = 0; i < dst_nodes.size(); i++) {
1656         graph->AddEdge(tpu_variable, src_indices[i], dst_nodes[i],
1657                        dst_indices[i]);
1658       }
1659     } else {
1660       uint64 fp =
1661           Fingerprint64(strings::StrCat(handle.container(), handle.name(), i));
1662       NodeDef ndef;
1663       ndef.set_name(strings::StrCat(handle.name(), fp));
1664       ndef.set_op(kVarHandleOp);
1665       if (num_cores_per_replica > 1) {
1666         ndef.set_device(strings::StrCat(kTPUDeviceNamePrefix, device_ordinal));
1667       } else {
1668         // Assign this new VarHandleOp to TPU:0 so the partitioner only
1669         // partitions the graph into two subgraphs, one on CPU and one on TPU.
1670         // The actual device ordinal on which this VarHandleOp runs is assigned
1671         // after partitioning (in SetDeviceOrdinal).
1672         ndef.set_device(
1673             strings::StrCat(kTPUDeviceNamePrefix, kTPUDefaultDeviceOrdinal));
1674       }
1675 
1676       // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node
1677       // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the
1678       // shared_name of the variable on CPU and "x" is the rewritten device
1679       // ordinal.
1680       const string sname =
1681           strings::StrCat(handle.name(), "_tpu_", device_ordinal);
1682       AddNodeAttr("shared_name", sname, &ndef);
1683       const string cname = ctx->resource_manager()->default_container();
1684       AddNodeAttr("container", cname, &ndef);
1685       core::RefCountPtr<Var> var;
1686       TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
1687       AddNodeAttr("dtype", var->tensor()->dtype(), &ndef);
1688       TensorShapeProto proto;
1689       var->tensor()->shape().AsProto(&proto);
1690       AddNodeAttr("shape", proto, &ndef);
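           // Sketch of the resulting NodeDef (assuming kTPUDeviceNamePrefix is
           // the usual "/device:TPU:" prefix): op: "VarHandleOp",
           // device: "/device:TPU:0", shared_name: "<cpu_name>_tpu_<ordinal>",
           // with container/dtype/shape copied from the CPU variable.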
1691       Status status;
1692       Node* new_node = graph->AddNode(ndef, &status);
1693       TF_RETURN_IF_ERROR(status);
1694       std::vector<const Edge*> in_edges(node->in_edges().begin(),
1695                                         node->in_edges().end());
1696       for (const Edge* edge : in_edges) {
1697         graph->AddEdge(edge->src(), edge->src_output(), new_node,
1698                        edge->dst_input());
1699       }
1700       std::vector<Node*> dst_nodes;
1701       std::vector<int> src_indices;
1702       std::vector<int> dst_indices;
1703       for (const Edge* edge : node->out_edges()) {
1704         dst_nodes.push_back(edge->dst());
1705         src_indices.push_back(edge->src_output());
1706         dst_indices.push_back(edge->dst_input());
1707       }
1708       graph->RemoveNode(node);
1709       for (int i = 0; i < dst_nodes.size(); i++) {
1710         graph->AddEdge(new_node, src_indices[i], dst_nodes[i], dst_indices[i]);
1711       }
1712       // Don't initialize variables on TPU if this has already been done
1713       // for this ordinal.
1714       if (seen_ordinals_.contains(device_ordinal)) continue;
1715 
1716       Device* d;
1717       TF_RETURN_IF_ERROR(library_runtime_->device_mgr()->LookupDevice(
1718           strings::StrCat(kTPUDeviceNamePrefix, device_ordinal), &d));
1719       Var* tpu_var;
1720       status = d->resource_manager()->Lookup(cname, sname, &tpu_var);
1721       if (!status.ok()) {
1722         TF_RETURN_IF_ERROR(InitializeVarOnTPU(ctx, var, &ndef, device_ordinal,
1723                                               var_info.fast_mem));
1724       }
1725       tpu_variables[handle_fp] = new_node;
1726     }
1727   }
1728 
1729   // Adjust the index attr of the remaining (non-resource) arg nodes.
1730   int new_index = 0;
1731   for (Node* node : graph->op_nodes()) {
1732     if (node->IsArg()) {
1733       node->ClearAttr("index");
1734       node->AddAttr("index", new_index);
1735       new_index++;
1736     }
1737   }
1738 
1739   seen_ordinals_.insert(device_ordinal);
1740 
1741   return Status::OK();
1742 }
1743 
1744 Status TPUPartitionedCallOp::ReplaceAndPartitionXLAShardingVariable(
1745     Graph* graph, OpKernelContext* ctx, int device_ordinal,
1746     ResourceHandle& handle, Node* variable, int num_cores_per_replica) {
1747   TF_ASSIGN_OR_RETURN(
1748       auto sharding,
1749       GetShardingFromNodeDef(variable->def(), /*add_metadata=*/false));
1750   xla::OpSharding xla_sharding;
1751   bool is_var_sharded = false;
1752   if (sharding.has_value() &&
1753       sharding.value().type() == xla::OpSharding::OTHER) {
1754     xla_sharding = sharding.value();
1755     is_var_sharded = true;
1756   } else {
1757     xla_sharding.set_type(xla::OpSharding::REPLICATED);
1758     is_var_sharded = false;
1759   }
1760   VLOG(3) << "Replace and partition variable " << variable->name()
1761           << " with xla_sharding: " << xla_sharding.DebugString();
1762 
1763   core::RefCountPtr<Var> var;
1764   TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
1765 
1766   int split_dim = -1;
1767   int split_size = 0;
1768   if (is_var_sharded) {
1769     for (int dim = 0; dim < xla_sharding.tile_assignment_dimensions_size();
1770          dim++) {
1771       if (xla_sharding.tile_assignment_dimensions(dim) > 1) {
1772         if (split_dim != -1) {
1773           return errors::InvalidArgument(
1774               "Currently we only support inference with one split dimension, "
1775               "however got sharding: ",
1776               xla_sharding.DebugString());
1777         }
1778         split_dim = dim;
1779         split_size = xla_sharding.tile_assignment_dimensions(dim);
1780       }
1781     }
1782   }
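       // Illustrative example: tile_assignment_dimensions [2, 1] over a
       // [8, 128] variable gives split_dim = 0 and split_size = 2, so each
       // per-core VarHandleOp created below gets shape [4, 128].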
1783   const string cname = ctx->resource_manager()->default_container();
1784   std::vector<Node*> per_core_vars;
1785   for (int core_index = device_ordinal;
1786        core_index < (device_ordinal + num_cores_per_replica); core_index++) {
1787     NodeDef ndef;
1788     uint64 fp = Fingerprint64(
1789         strings::StrCat(handle.container(), handle.name(), "_", core_index));
1790     ndef.set_name(strings::StrCat(handle.name(), fp));
1791     ndef.set_op(kVarHandleOp);
1792     ndef.set_device(strings::StrCat(kTPUDeviceNamePrefix, core_index));
1793 
1794     // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node
1795     // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the
1796     // shared_name of the variable on CPU and "x" is the rewritten device
1797     // ordinal.
1798     const string sname = strings::StrCat(handle.name(), "_tpu_", core_index);
1799     AddNodeAttr("shared_name", sname, &ndef);
1800     AddNodeAttr("container", cname, &ndef);
1801     AddNodeAttr("dtype", var->tensor()->dtype(), &ndef);
1802 
1803     TensorShapeProto proto;
1804     var->tensor()->shape().AsProto(&proto);
1805 
1806     if (is_var_sharded) {
1807       int dim_size = proto.dim(split_dim).size();
1808       if (dim_size % split_size != 0) {
1809         return errors::InvalidArgument("dimension size ", dim_size,
1810                                        " cannot be divisible by split size ",
1811                                        split_size);
1812       }
1813       proto.mutable_dim(split_dim)->set_size(dim_size / split_size);
1814     }
1815     AddNodeAttr("shape", proto, &ndef);
1816 
1817     Status status;
1818     Node* new_node = graph->AddNode(ndef, &status);
1819     TF_RETURN_IF_ERROR(status);
1820     per_core_vars.push_back(new_node);
1821   }
1822 
1823   // Insert TPUPartitionedInput op.
1824   NodeDefBuilder builder(absl::StrCat(handle.name(), "/tpu_partitioned_input"),
1825                          "TPUPartitionedInput");
1826   builder.Attr("N", num_cores_per_replica);
1827   builder.Attr("T", DT_RESOURCE);
1828   builder.Attr("partition_dim", split_dim);
1829   builder.Attr("_XlaSharding", xla_sharding.SerializeAsString());
1830   std::vector<NodeDefBuilder::NodeOut> inputs;
1831   inputs.reserve(num_cores_per_replica);
1832   for (int core_index = 0; core_index < num_cores_per_replica; core_index++) {
1833     inputs.push_back({per_core_vars[core_index]->name(), 0, DT_RESOURCE});
1834   }
1835   builder.Input(inputs);
1836   NodeDef node_def;
1837   TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
1838   Status s;
1839   Node* tpu_partitioned_input_node = graph->AddNode(node_def, &s);
1840   if (!s.ok()) {
1841     return s;
1842   }
1843 
1844   for (int core_index = 0; core_index < num_cores_per_replica; core_index++) {
1845     graph->AddEdge(per_core_vars[core_index], 0, tpu_partitioned_input_node,
1846                    core_index);
1847   }
1848 
1849   // Insert TPUReplicatedInput op.
1850   NodeDefBuilder replicated_builder(
1851       absl::StrCat(handle.name(), "/tpu_replicated_input"),
1852       "TPUReplicatedInput");
1853   replicated_builder.Attr("N", 1);
1854   replicated_builder.Attr("T", DT_RESOURCE);
1855   replicated_builder.Attr("is_mirrored_variable", true);
1856   std::vector<NodeDefBuilder::NodeOut> replicated_inputs;
1857   replicated_inputs.push_back(
1858       {tpu_partitioned_input_node->name(), 0, DT_RESOURCE});
1859   replicated_builder.Input(replicated_inputs);
1860   NodeDef replicated_node_def;
1861   TF_RETURN_IF_ERROR(replicated_builder.Finalize(&replicated_node_def));
1862   Status replicated_s;
1863   Node* tpu_replicated_input_node =
1864       graph->AddNode(replicated_node_def, &replicated_s);
1865   if (!replicated_s.ok()) {
1866     return replicated_s;
1867   }
1868   graph->AddEdge(tpu_partitioned_input_node, 0, tpu_replicated_input_node, 0);
1869 
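       // At this point the per-variable subgraph looks like:
       //
       //   VarHandleOp (TPU:o)       --\
       //   ...                          TPUPartitionedInput --> TPUReplicatedInput
       //   VarHandleOp (TPU:o+N-1)   --/
       //
       // where N = num_cores_per_replica and o = device_ordinal.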
1870   // Connect the TPUReplicatedInput node to the previous output nodes of the
1871   // variable, and remove the variable node.
1872   std::vector<Node*> dst_nodes;
1873   std::vector<int> src_indices;
1874   std::vector<int> dst_indices;
1875   for (const Edge* edge : variable->out_edges()) {
1876     dst_nodes.push_back(edge->dst());
1877     src_indices.push_back(edge->src_output());
1878     dst_indices.push_back(edge->dst_input());
1879   }
1880   for (int i = 0; i < dst_nodes.size(); i++) {
1881     graph->AddEdge(tpu_replicated_input_node, src_indices[i], dst_nodes[i],
1882                    dst_indices[i]);
1883   }
1884 
1885   graph->RemoveNode(variable);
1886 
1887   std::vector<NodeDef> ndefs;
1888   Status status;
1889   for (int core_index = 0; core_index < num_cores_per_replica; core_index++) {
1890     Device* d;
1891     TF_RETURN_IF_ERROR(library_runtime_->device_mgr()->LookupDevice(
1892         strings::StrCat(kTPUDeviceNamePrefix, device_ordinal + core_index),
1893         &d));
1894     string sname;
1895     const NodeDef& ndef = per_core_vars[core_index]->def();
1896     TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &sname));
1897     ndefs.push_back(ndef);
1898     Var* tpu_var;
1899     status = d->resource_manager()->Lookup(cname, sname, &tpu_var);
1900   }
1901 
1902   if (!status.ok()) {
1903     TF_RETURN_IF_ERROR(
1904         InitializeShardedVarOnTPU(ctx, var, ndefs, split_dim, device_ordinal));
1905   }
1906 
1907   return Status::OK();
1908 }
1909 
1910 Status TPUPartitionedCallOp::InferShapesWithResourceVar(
1911     Graph* graph, OpKernelContext* ctx,
1912     std::map<int, InferredShape>& arg_shapes,
1913     GraphShapeInfo* tpu_inferred_info) {
1914   auto shape_inference_graph_interim =
1915       absl::make_unique<Graph>(graph->flib_def());
1916   CopyGraph(*graph, shape_inference_graph_interim.get());
1917 
1918   for (Node* node : shape_inference_graph_interim->nodes()) {
1919     if (node->type_string() != "_Arg" ||
1920         node->attrs().Find("T")->type() != DT_RESOURCE)
1921       continue;
1922 
1923     std::vector<std::function<void()>> to_remove;
1924 
1925     for (const Edge* out_edge : node->out_edges()) {
1926       Node* read_node = out_edge->dst();
1927       if (read_node->type_string() != "ReadVariableOp") continue;
1928 
1929       for (const Edge* variable_edge : read_node->out_edges()) {
1930         // We are delaying these modifications as we cannot do in-place
1931         // modification of EdgeSets.
1932         to_remove.push_back(
1933             [variable_edge, graph = shape_inference_graph_interim.get(), node] {
1934               Node* dst = variable_edge->dst();
1935               graph->RemoveEdge(variable_edge);
1936               graph->AddEdge(node, variable_edge->src_output(), dst,
1937                              variable_edge->dst_input());
1938             });
1939       }
1940       to_remove.push_back(
1941           [graph = shape_inference_graph_interim.get(), out_edge, read_node] {
1942             graph->RemoveEdge(out_edge);
1943             graph->RemoveNode(read_node);
1944           });
1945     }
1946 
1947     for (auto& func : to_remove) {
1948       func();
1949     }
1950 
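         // The edits above collapse each `_Arg(resource) -> ReadVariableOp ->
         // consumer` chain into `_Arg -> consumer`, so the shape inference
         // below can treat the resource arg as if it carried the variable's
         // value directly.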
1951     int resource_arg_index = node->attrs().Find("index")->i();
1952 
1953     // Get resource variable tensor
1954     core::RefCountPtr<Var> variable;
1955     const ResourceHandle& handle = HandleFromInput(ctx, resource_arg_index);
1956     TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &variable));
1957 
1958     const Tensor* variable_tensor = variable->tensor();
1959     std::vector<int> variable_tensor_vec;
1960 
1961     variable_tensor_vec.reserve(variable_tensor->dims());
1962     for (int d = 0; d < variable_tensor->dims(); ++d) {
1963       variable_tensor_vec.push_back(variable_tensor->dim_size(d));
1964     }
1965 
1966     PartialTensorShape partial_tensor_shape;
1967     auto partial_shape = PartialTensorShape::MakePartialShape(
1968         variable_tensor_vec.data(), variable_tensor_vec.size(),
1969         &partial_tensor_shape);
1970     InferredShape inferred_shape = {partial_tensor_shape};
1971     arg_shapes.emplace(resource_arg_index, inferred_shape);
1972   }
1973 
1974   TF_RETURN_IF_ERROR(tensorflow::InferShapes(
1975       shape_inference_graph_interim.get(), arg_shapes,
1976       &shape_inference_graph_interim->flib_def(), tpu_inferred_info));
1977   return Status::OK();
1978 }
1979 
1980 Status TPUPartitionedCallOp::ShardInputsWithXlaSharding(
1981     Graph* graph, int num_cores_per_replica, OpKernelContext* ctx) {
1982   for (Node* replicated_input_node : graph->nodes()) {
1983     if (replicated_input_node->type_string() != "TPUReplicatedInput") continue;
1984 
1985     Node* arg_node;
1986     auto input_node_status = replicated_input_node->input_node(0, &arg_node);
1987     if (!input_node_status.ok()) {
1988       VLOG(2) << "Skip because cannot retrieve input node 0 of "
1989               << replicated_input_node->name() << " because "
1990               << input_node_status.ToString();
1991       continue;
1992     }
1993 
1994     // Check whether this TPUReplicatedInput qualifies: it has an _Arg as
1995     // input and does not already have an XlaSharding op as an output. If so,
1996     // try to shard the input automatically.
1997     //
1998     // In short, we want to see the following graph:
1999     //    _Arg -> TPUReplicatedInput -> (not XlaSharding op)
2000     // and transform it to:
2001     //    _Arg -> TPUReplicatedInput -> XlaSharding -> (not XlaSharding op)
2002     if (arg_node->IsArg() &&
2003         replicated_input_node->out_nodes().begin()->type_string() !=
2004             "XlaSharding") {
2005       int arg_id;
2006       if (!absl::SimpleAtoi(absl::StripPrefix(arg_node->name(), "arg_"),
2007                             &arg_id)) {
2008         VLOG(3) << "Skip auto-sharding because we are unable to extract "
2009                    "argument number from "
2010                 << arg_node->name();
2011         continue;
2012       }
2013 
2014       auto shape = ctx->input(arg_id).shape();
2015 
2016       VLOG(3) << "Identified arg node " << arg_node->DebugString()
2017               << " for TPUReplicatedInput "
2018               << replicated_input_node->DebugString();
2019       VLOG(3) << "Shape within TPUReplicatedInput is: " << shape.DebugString();
2020 
2021       int rank = shape.dims();
2022       int shard_dim =
2023           (runtime_params_.auto_xla_input_sharding_dim + rank) % rank;
2024 
2025       if (shape.dim_size(shard_dim) % num_cores_per_replica != 0) {
2026         VLOG(3) << "Skip auto-sharding " << replicated_input_node->name()
2027                 << " because the specified sharding dimension " << shard_dim
2028                 << " cannot be evenly split by " << num_cores_per_replica;
2029         continue;
2030       }
2031 
2032       auto sharding = absl::make_optional<xla::OpSharding>();
2033       sharding->set_type(xla::OpSharding::OTHER);
2034 
2035       // Sets up tile_assignment_dimensions.
2036       std::vector<int64> dims(rank, 1LL);
2037       dims[shard_dim] = num_cores_per_replica;
2038       for (auto dim : dims) {
2039         sharding->add_tile_assignment_dimensions(dim);
2040       }
2041 
2042       // Sets up tile_assignment_devices.
2043       for (int d = 0; d < num_cores_per_replica; ++d) {
2044         sharding->add_tile_assignment_devices(d);
2045       }
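           // Example (illustrative): a rank-2 input sharded on dim 0 across 2
           // cores gets an OpSharding of type OTHER with
           // tile_assignment_dimensions {2, 1} and tile_assignment_devices
           // {0, 1}.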
2046 
2047       std::vector<const Edge*> edges_to_remove;
2048       for (const Edge* edge : replicated_input_node->out_edges()) {
2049         if (edge->IsControlEdge()) continue;
2050         edges_to_remove.push_back(edge);
2051       }
2052 
2053       // Create XlaSharding Op.
2054       Node* sharding_op = nullptr;
2055       TF_RETURN_IF_ERROR(
2056           NodeBuilder(absl::StrCat(replicated_input_node->name(), "/sharding"),
2057                       "XlaSharding")
2058               .Input(replicated_input_node, 0)
2059               .Attr("T", replicated_input_node->output_type(0))
2060               .Attr(kXLAShardingAttrName, sharding->SerializeAsString())
2061               .Attr(kXLAShardingAttrAltName, sharding->SerializeAsString())
2062               .Attr("_tpu_replicate", "cluster")
2063               .Finalize(graph, &sharding_op));
2064       for (const Edge* edge : edges_to_remove) {
2065         VLOG(3) << "XlaSharding op creation output edge "
2066                 << edge->DebugString();
2067         graph->RemoveEdge(edge);
2068         graph->AddEdge(sharding_op, 0, edge->dst(), edge->dst_input());
2069       }
2070 
2071       VLOG(3) << "Auto shard " << replicated_input_node->name() << " by dim "
2072               << shard_dim << " into " << num_cores_per_replica << " slices";
2073 
2074       VLOG(3) << "Created XlaSharding Op " << sharding_op->DebugString();
2075     }
2076   }
2077 
2078   return Status::OK();
2079 }
2080 
2081 // OptimizeTpuInputOutputTensors does the following things:
2082 //  (1) Detect input arguments, and add an XlaSharding op to the arguments if
2083 //  enable_auto_xla_input_sharding is turned on.
2084 //  (2) Pack multiple input tensors into one tensor by a concat to avoid PCIe
2085 //  transfer overheads for small tensors.
2086 //  (3) Reshape input tensors to R1 to leverage the fast path in TPU input
2087 //  preparation done by runtime.
2088 //  (4) Pack multiple output tensors into one tensor by a concat.
2089 //
2090 // (1) is controlled by --enable_auto_xla_input_sharding and
2091 // --auto_xla_input_sharding_dim
2092 // (2) and (3) are controlled by flags --minimum_input_tensors_packing
2093 // and --input_shape_opt, respectively, while (4) is controlled by
2094 // --minimum_output_tensors_packing.
2095 Status TPUPartitionedCallOp::OptimizeTpuInputOutputTensors(
2096     Graph* graph, bool enable_spmd_xla_partitioning, int num_cores_per_replica,
2097     std::map<std::string, std::vector<int>>& named_input_shapes,
2098     OpKernelContext* ctx) {
2099   if (runtime_params_.enable_auto_xla_input_sharding) {
2100     VLOG(2) << DumpGraphToFile("before_enable_auto_xla_input_sharding", *graph,
2101                                flib_def_.get());
2102 
2103     TF_RETURN_IF_ERROR(
2104         ShardInputsWithXlaSharding(graph, num_cores_per_replica, ctx));
2105   }
2106 
2107   GraphShapeInfo tpu_inferred_info;
2108   std::map<int, InferredShape> arg_shapes;
2109   EdgeShapes tpu_input_shapes;
2110   absl::flat_hash_map<const Edge*, DataType> tpu_input_dtypes;
2111 
2112   // Contains attrs "T", "sharding", "_tpu_replicate" for each XlaSharding op.
2113   XlaShardingInfoMap xla_sharding_ops;
2114 
2115   // Contains attrs "T" and a pointer to tpu_replicated_metadata for control deps.
2116   TpuReplicatedInputInfoMap tpu_replicated_input_ops;
2117 
2118   bool xla_spmd_input_sharded = false;
2119 
2120   if (enable_spmd_xla_partitioning) {
2121     xla_spmd_input_sharded = FindTpuReplicatedInputAndXlaSharding(
2122         graph, xla_sharding_ops, tpu_replicated_input_ops);
2123   }
2124 
2125   VLOG(1) << "xla_spmd_input_sharded: " << xla_spmd_input_sharded;
2126   VLOG(2) << DumpGraphToFile("before_remove_descendant_nodes", *graph,
2127                              flib_def_.get());
2128 
2129   if (!xla_spmd_input_sharded ||
2130       runtime_params_.minimum_input_tensors_packing > 1 ||
2131       runtime_params_.enable_auto_xla_input_sharding) {
2132     // Currently we remove `TPUReplicatedInput` nodes when the input tensors
2133     // are not sharded, when input tensor packing is enabled, or when auto
2134     // XLA input sharding is enabled.
2135     //
2136     // In all these cases, we want to remove both the TPUReplicatedInput and
2137     // XlaSharding ops, or else downstream rewrites will be confused.
2138     RemoveDescendantNodeOfArg(graph, "TPUReplicatedInput",
2139                               /*must_be_child_of=*/{});
2140   }
2141 
2142   if (xla_spmd_input_sharded) {
2143     // We are setting must_be_child_of to {"_Arg"} because we do not want
2144     // to remove other XlaSharding ops that might be in the graph. We only
2145     // want the XlaSharding ops that are directly attached to the input
2146     // arguments to be removed.
2147     RemoveDescendantNodeOfArg(graph, "XlaSharding",
2148                               /*must_be_child_of=*/{"_Arg"});
2149   }
2150 
2151   VLOG(2) << DumpGraphToFile("before_get_input_output_info", *graph,
2152                              flib_def_.get());
2153 
2154   TF_RETURN_IF_ERROR(GetInputOutputInfo(graph, tpu_inferred_info, arg_shapes,
2155                                         tpu_input_shapes, tpu_input_dtypes,
2156                                         ctx));
2157 
2158   VLOG(2) << DumpGraphToFile("before_optimize_tpu_input_output_tensors", *graph,
2159                              flib_def_.get());
2160 
2161   string cluster_name;
2162   TF_RETURN_IF_ERROR(GetClusterName(graph, &cluster_name));
2163 
2164   if (runtime_params_.minimum_output_tensors_packing > 1) {
2165     // Copy graph to shape_inference_graph
2166     EdgeShapes tpu_output_shapes;
2167     TF_RETURN_IF_ERROR(
2168         InferShapesWithResourceVar(graph, ctx, arg_shapes, &tpu_inferred_info));
2169 
2170     // Find TPU -> CPU output edges.
2171     GroupedEdges shape_to_output =
2172         tpu_functional_internal::GroupTensorsForOutputPacking(
2173             graph, tpu_output_shapes, &tpu_inferred_info);
2174 
2175     TF_RETURN_IF_ERROR(
2176         tpu_functional_internal::CreateConcatAndSplitNodesForOutputTensor(
2177             graph, cluster_name, &tpu_output_shapes, &tpu_inferred_info,
2178             shape_to_output, runtime_params_.minimum_output_tensors_packing));
2179   }
2180 
2181   if (runtime_params_.minimum_input_tensors_packing > 1) {
2182     GroupedEdges grouped_input_edges =
2183         tpu_functional_internal::GroupTensorsForInputPacking(
2184             tpu_input_shapes, tpu_input_dtypes, runtime_params_.input_shape_opt,
2185             runtime_params_.group_tensors_for_packing);
2186     TF_RETURN_IF_ERROR(
2187         tpu_functional_internal::CreateConcatAndSplitNodesForInputTensor(
2188             graph, cluster_name, &tpu_input_shapes, grouped_input_edges,
2189             runtime_params_.minimum_input_tensors_packing,
2190             xla_spmd_input_sharded, xla_sharding_ops,
2191             tpu_replicated_input_ops));
2192   }
2193   if (runtime_params_.input_shape_opt) {
2194     TF_RETURN_IF_ERROR(tpu_functional_internal::InsertReshapeNodePairs(
2195         graph, cluster_name, &tpu_input_shapes, num_cores_per_replica));
2196   }
2197   VLOG(1) << DumpGraphToFile("optim_result", *graph);
2198 
2199   // With or without optimizations, collect the input names and shapes.
2200   for (const auto& iter : tpu_input_shapes) {
2201     std::string name = iter.first->src()->name();
2202     named_input_shapes[name] = iter.second;
2203   }
2204   return Status::OK();
2205 }
2206 
2207 Status TPUPartitionedCallOp::GetGraphFromFunction(
2208     Graph* graph, int device_ordinal, int* num_core_per_replica,
2209     bool* use_spmd_for_xla_partitioning) {
2210   FunctionLibraryRuntime::InstantiateOptions opts;
2211   FHandle handle;
2212   TF_RETURN_IF_ERROR(library_runtime_->Instantiate(
2213       func_.name(), AttrSlice(&func_.attr()), opts, &handle));
2214   const FunctionBody* fbody = library_runtime_->GetFunctionBody(handle);
2215   if (fbody == nullptr) {
2216     return errors::Internal("Could not find handle ", handle);
2217   }
2218   CopyGraph(*fbody->graph, graph);
2219 
2220   // Pin the inputs and outputs to the local device to simplify the
2221   // function-dispatching logic.
2222   local_device_name_ = library_runtime_->device()->name();
2223   replaced_input_indices_.resize(fbody->arg_nodes.size(), false);
2224   for (Node* node : graph->op_nodes()) {
2225     if (node->IsArg() || node->IsRetval()) {
2226       node->set_assigned_device_name(local_device_name_);
2227     } else if (node->type_string() == "TPUReplicateMetadata") {
2228       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "num_cores_per_replica",
2229                                      num_core_per_replica));
2230       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(),
2231                                      "use_spmd_for_xla_partitioning",
2232                                      use_spmd_for_xla_partitioning));
2233       VLOG(1) << "num_core_per_replica = " << *num_core_per_replica
2234               << ", use_spmd_for_xla_partitioning = "
2235               << *use_spmd_for_xla_partitioning;
2236 
2237       if (*num_core_per_replica > 1) {
2238         std::string topology_str;
2239         std::vector<int> device_assignment;
2240         TF_RETURN_IF_ERROR(
2241             GetNodeAttr(node->attrs(), "topology", &topology_str));
2242         TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "device_assignment",
2243                                        &device_assignment));
2244 
2245         tpu::TopologyProto topology;
2246         topology.ParseFromString(topology_str);
2247         int num_cores = topology.device_coordinates_size() / 4;
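             // Each device coordinate in the topology is a 4-tuple
             // (x, y, z, core), so the number of cores it describes is the
             // flat list length divided by 4.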
2248 
2249         if (device_assignment.empty()) {
2250           // The number of devices matches the cores per replica, so we can
2251           // just use the device assignment from the existing topology instead
2252           // of generating our own.
2253           //
2254           // TODO(b/179292031): Add support for non-natural orders for pods.
2255 
2256           // Check that the device coordinates for a donut are always in
2257           // natural order.
2258           std::vector<int> natural_order;
2259           switch (num_cores) {
2260             case 2:
2261               TF_RETURN_IF_ERROR(GenerateDeviceNaturalOrder(
2262                   /*x_num_cores=*/1, /*y_num_cores=*/1, /*z_num_cores=*/1,
2263                   /*num_cores_per_chip=*/2, &natural_order));
2264               break;
2265             case 4:  // we assume this is a device with one core per chip.
2266               TF_RETURN_IF_ERROR(GenerateDeviceNaturalOrder(
2267                   /*x_num_cores=*/2, /*y_num_cores=*/2, /*z_num_cores=*/1,
2268                   /*num_cores_per_chip=*/1, &natural_order));
2269               break;
2270             case 8:
2271               TF_RETURN_IF_ERROR(GenerateDeviceNaturalOrder(
2272                   /*x_num_cores=*/2, /*y_num_cores=*/2, /*z_num_cores=*/1,
2273                   /*num_cores_per_chip=*/2, &natural_order));
2274               break;
2275             default:
2276               return errors::Unimplemented(
2277                   "You must specify a device assignment for all TPU "
2278                   "configurations.");
2279           }
2280           if (*num_core_per_replica != num_cores &&
2281               !std::equal(natural_order.begin(), natural_order.end(),
2282                           topology.device_coordinates().begin())) {
2283             return errors::InvalidArgument(
2284                 "Topology device coordinates for XLA SPMD on donuts must be in "
2285                 "natural order.");
2286           }
2287 
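               // The replica's device assignment is taken as the contiguous
               // slice of topology coordinates for cores [device_ordinal,
               // device_ordinal + num_core_per_replica), 4 ints per core.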
2288           auto coordinates_start =
2289               topology.device_coordinates().begin() + device_ordinal * 4;
2290           auto coordinates_end = topology.device_coordinates().begin() +
2291                                  (device_ordinal + *num_core_per_replica) * 4;
2292 
2293           node->ClearAttr("device_assignment");
2294           device_assignment.insert(device_assignment.begin(), coordinates_start,
2295                                    coordinates_end);
2296           node->AddAttr("device_assignment", device_assignment);
2297         }
2298       }
2299     }
2300   }
2301   return Status::OK();
2302 }
2303 
2304 Status TPUPartitionedCallOp::PlacementHelper(
2305     const DeviceSet& device_set,
2306     const GraphOptimizationPassOptions& optimization_options,
2307     const string& function_name) {
2308   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2309       OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
2310   Placer placer(optimization_options.graph->get(), function_name,
2311                 optimization_options.flib_def, &device_set);
2312   TF_RETURN_IF_ERROR(placer.Run());
2313   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2314       OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
2315   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2316       OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
2317   return Status::OK();
2318 }
2319 
2320 Status TPUPartitionedCallOp::PartitionHelper(
2321     const DeviceSet& device_set,
2322     const GraphOptimizationPassOptions& optimization_options, Graph* graph,
2323     std::unordered_map<std::string, std::unique_ptr<Graph>>* subgraphs) {
2324   PartitionOptions partition_options;
2325   partition_options.node_to_loc = [](const Node* node) {
2326     // TODO(akshayka): To better support the distributed case, first split
2327     // the graph by worker (e.g., using the master session's
2328     // `SplitByWorker` policy), and then recursively partition the
2329     // per-worker shards at the remote worker(s).
2330     return node->assigned_device_name();
2331   };
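       // Grouping by assigned device name yields one subgraph per device (the
       // host CPU plus each TPU core used); Partition() reconnects the pieces
       // with _Send/_Recv pairs across the cut edges.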
2332   int64_t edge_name_counter = 0;
2333   partition_options.new_name = [&edge_name_counter](const string& prefix) {
2334     return strings::StrCat(prefix, "/_", ++edge_name_counter);
2335   };
2336   partition_options.get_incarnation = [&device_set](const string& name) {
2337     const Device* d = device_set.FindDeviceByName(name);
2338     if (d == nullptr) {
2339       return PartitionOptions::kIllegalIncarnation;
2340     } else {
2341       return d->attributes().incarnation();
2342     }
2343   };
2344   partition_options.control_flow_added = false;
2345   std::unordered_map<std::string, GraphDef> partitions;
2346   TF_RETURN_IF_ERROR(Partition(partition_options, graph, &partitions));
2347 
2348   VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
2349           << partitions.size() << " shards.";
2350 
2351   const FunctionLibraryDefinition* flib_def = &graph->flib_def();
2352   for (auto& partition : partitions) {
2353     std::unique_ptr<Graph> subgraph(new Graph(flib_def));
2354     GraphConstructorOptions opts;
2355     opts.allow_internal_ops = true;
2356     opts.expect_device_spec = true;
2357     const string& device = partition.first;
2358     GraphDef& graph_def = partition.second;
2359     TF_RETURN_IF_ERROR(
2360         ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
2361     subgraphs->emplace(device, std::move(subgraph));
2362   }
2363 
2364   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2365       OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
2366 
2367   return Status::OK();
2368 }
2369 
2370 Status TPUPartitionedCallOp::InstantiatePartition(
2371     const Graph& graph, const string& function_name,
2372     const string& target_device, FHandle* handle,
2373     std::unique_ptr<FunctionLibraryDefinition>* out_flib_def) {
2374   FunctionDef shard;
2375   TF_RETURN_IF_ERROR(GraphToFunctionDef(graph, function_name, &shard));
2376   TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(shard));
2377   FunctionLibraryRuntime::InstantiateOptions opts;
2378   opts.target = target_device;
2379   if (out_flib_def) {
2380     *out_flib_def = std::make_unique<FunctionLibraryDefinition>(*flib_def_);
2381     opts.lib_def = out_flib_def->get();
2382   } else {
2383     opts.lib_def = flib_def_.get();
2384   }
2385   return library_runtime_->Instantiate(function_name, AttrSlice(&shard.attr()),
2386                                        opts, handle);
2387 }
2388 
2389 Status TPUPartitionedCallOp::SetDeviceOrdinal(const DeviceSet& device_set,
2390                                               int device_ordinal, Graph* graph,
2391                                               bool* modified) {
2392   int ordinal = -1;
2393   for (Node* node : graph->op_nodes()) {
2394     if (node->type_string() == kVarHandleOp) {
2395       if (IsInputToTPUReplicate(node)) {
2396         // If this VarHandleOp is going to a TPU computation,
2397         // it refers to the TPU variable that we created when replacing the
2398         // resource arguments with VarHandleOps.
2399         node->set_assigned_device_name(
2400             strings::StrCat(kTPUDeviceNamePrefix, device_ordinal));
2401       }
2402       continue;
2403     }
2404     if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
2405       // Outside compilation related node.
2406       TF_RETURN_IF_ERROR(
2407           SetDeviceOrdinalAttributeForNode(node, device_ordinal));
2408       *modified = true;
2409       continue;
2410     }
2411     const AttrValue* attr = node->attrs().Find(kDeviceOrdinalAttr);
2412     if (attr != nullptr) {
2413       if (!IsSupportedTPUOp(node->type_string())) {
2414         return errors::InvalidArgument("Node ", node->type_string(),
2415                                        " is not yet supported.");
2416       }
2417       if (ordinal == -1) {
2418         ordinal = attr->i();
2419       } else {
2420         if (ordinal != attr->i()) {
2421           return errors::InvalidArgument(
2422               "Can only partition graphs that use a single device ordinal.");
2423         }
2424       }
2425       node->ClearAttr(kDeviceOrdinalAttr);
2426       node->AddAttr(kDeviceOrdinalAttr, device_ordinal);
2427       VLOG(3) << "Set device ordinal of " << node->type_string() << " to "
2428               << device_ordinal;
2429       *modified = true;
2430     }
2431     if (node->IsSend() || node->IsRecv()) {
2432       static const char* kSendDevice = "send_device";
2433       static const char* kSendDeviceIncarnation = "send_device_incarnation";
2434       static const char* kRecvDevice = "recv_device";
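      // Point the Send/Recv device names at `device_ordinal` and refresh the
      // send incarnation from `device_set` so it matches the device that will
      // actually be used.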
      const AttrValue* attr = node->attrs().Find(kSendDevice);
      if (attr != nullptr) {
        string device = attr->s();
        TF_RETURN_IF_ERROR(
            UpdateTPUDeviceOrdinal(device_ordinal, &device, modified));
        node->ClearAttr(kSendDevice);
        node->AddAttr(kSendDevice, device);
        node->ClearAttr(kSendDeviceIncarnation);
        const Device* d = device_set.FindDeviceByName(device);
        int64_t send_incarnation = (d == nullptr)
                                       ? PartitionOptions::kIllegalIncarnation
                                       : d->attributes().incarnation();
        node->AddAttr(kSendDeviceIncarnation, send_incarnation);
      }
      attr = node->attrs().Find(kRecvDevice);
      if (attr != nullptr) {
        string device = attr->s();
        TF_RETURN_IF_ERROR(
            UpdateTPUDeviceOrdinal(device_ordinal, &device, modified));
        node->ClearAttr(kRecvDevice);
        node->AddAttr(kRecvDevice, device);
      }
    }
  }
  return Status::OK();
}

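// Instantiates one function per per-device subgraph and records the resulting
// (device, handle, flib_def) entries in `partition_cache_` under `cache_hash`.
// All shards must live in the same address space; inter-process execution is
// not supported yet.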
Status TPUPartitionedCallOp::InstantiateFunctionsFromSubgraphs(
    const DeviceSet& device_set, int replica_id, uint64 cache_hash,
    int num_cores_per_replica,
    std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs) {
  const Device* reference_device = nullptr;
  auto entry =
      partition_cache_.emplace(cache_hash, std::vector<DeviceAndFHandle>());

  bool rewritten = false;
  for (auto& pair : subgraphs) {
    string target = pair.first;
    int device_ordinal = replica_id;
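    // With model parallelism (num_cores_per_replica > 1) each shard's ordinal
    // comes from its assigned device name; otherwise every shard uses the
    // selected replica's ordinal.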
    if (num_cores_per_replica > 1) {
      DeviceNameUtils::ParsedName parsed_device;
      if (!DeviceNameUtils::ParseFullName(target, &parsed_device)) {
        return errors::InvalidArgument("Malformed assigned device '", target,
                                       "'");
      }
      device_ordinal = parsed_device.id;
    }
    Device* device;
    TF_RETURN_IF_ERROR(
        library_runtime_->device_mgr()->LookupDevice(target, &device));
    if (reference_device == nullptr) {
      reference_device = device;
    } else {
      if (!DeviceNameUtils::IsSameAddressSpace(
              device->parsed_name(), reference_device->parsed_name())) {
        return errors::InvalidArgument(
            "TPUPartitionedCallOp does not yet support inter-process "
            "execution.");
      }
    }
    TF_RETURN_IF_ERROR(device->MaybeRewriteGraph(&pair.second));
    Graph* subgraph = pair.second.get();
    // For model parallelism inference, we only support num_replicas == 1, so
    // there is no need to update the device_ordinal here.
    if (num_cores_per_replica == 1) {
      TF_RETURN_IF_ERROR(
          SetDeviceOrdinal(device_set, device_ordinal, subgraph, &rewritten));
    } else {
      VLOG(1) << "Skip SetDeviceOrdinal()";
    }
    string function_name = flib_def_->UniqueFunctionName(
        strings::StrCat(func_.name(), "_hash_", cache_hash));
    TF_RETURN_IF_ERROR(
        UpdateTPUDeviceOrdinal(device_ordinal, &target, &rewritten));
    FHandle handle;
    // Use a copy of the current `flib_def_` to instantiate the function to
    // avoid races.
    std::unique_ptr<FunctionLibraryDefinition> sub_flib_def;
    TF_RETURN_IF_ERROR(InstantiatePartition(*subgraph, function_name, target,
                                            &handle, &sub_flib_def));
    // Add handle to the cache entry.
    entry.first->second.push_back(
        DeviceAndFHandle{.device = target,
                         .handle = handle,
                         .flib_def = std::move(sub_flib_def)});
  }

  if (!rewritten) {
    // For regular use cases, TPUPartitionedCallOp only works when the
    // function being called is rewritten for TPU. If we don't see any signs
    // of this rewriting, warn the user about it.
    // We don't raise an error because we want to support the use case of
    // running tpu.initialize_system eagerly. In this case, we can't use
    // tpu.rewrite because it will add compilation ops that require TPU
    // to be initialized, i.e. there is a chicken and egg problem.
    // We run tpu.initialize_system through TPUPartitionedCallOp because it
    // invokes graph rewrite passes that are necessary for initialization to
    // work.
    LOG(INFO) << "Function body was not rewritten for TPU. "
              << "This is probably a bug unless you are initializing "
              << "TPUs eagerly.";
  }
  return Status::OK();
}

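// Runs `handle` on a remote device. The shard takes no explicit arguments or
// results; data flows through the Send/Recv ops added during partitioning, so
// only the completion status is propagated back to `ctx`.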
void TPUPartitionedCallOp::ExecuteRemoteFunction(
    const FunctionLibraryRuntime::Options& opts, FHandle handle,
    OpKernelContext* ctx, ReffedStatusCallback* done) {
  std::vector<Tensor> dummy_args;
  std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;

  profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteRemote");
  absl::ReaderMutexLock l(&mu_);
  library_runtime_->Run(opts, handle, dummy_args, dummy_rets,
                        [dummy_rets, done, ctx](const Status& status) {
                          if (!status.ok()) {
                            ctx->SetStatus(status);
                          }
                          delete dummy_rets;
                          done->Unref();
                        });
}

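// Runs `handle` on the local device, forwarding the op's inputs (minus any
// resource arguments that were replaced by VarHandleOps) and populating the
// op's outputs from the function's results.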
void TPUPartitionedCallOp::ExecuteLocalFunction(
    const FunctionLibraryRuntime::Options& opts, const OpInputList& arguments,
    FHandle handle, OpKernelContext* ctx, ReffedStatusCallback* done) {
  std::vector<Tensor> args;

  for (int i = 0; i < arguments.size(); ++i) {
    // _Arg nodes of type DT_RESOURCE that feed a TPU node have been replaced
    // by TPU VarHandleOp nodes and no longer need to be passed in, so only
    // forward the arguments that were not replaced.
    if (!replaced_input_indices_[i]) {
      args.push_back(arguments[i]);
    }
  }
  auto* rets = new std::vector<Tensor>;

  profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteLocal");
  absl::ReaderMutexLock l(&mu_);
  library_runtime_->Run(opts, handle, args, rets,
                        [rets, done, ctx](const Status& status) {
                          if (!status.ok()) {
                            ctx->SetStatus(status);
                          } else {
                            for (int i = 0; i < rets->size(); ++i) {
                              ctx->set_output(i, (*rets)[i]);
                            }
                          }
                          delete rets;
                          done->Unref();
                        });
}

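// Dispatches every per-device shard of one invocation. All shards share a
// single PrivateIntraProcessRendezvous; the reference-counted callback below
// deletes the rendezvous and releases the core-selector reservation once the
// last shard has finished.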
void TPUPartitionedCallOp::ExecuteFunctions(
    const std::vector<DeviceAndFHandle>& functions, OpKernelContext* ctx,
    int device_ordinal, int64_t ordinal_selector_req_id, DoneCallback done) {
  profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteFunctions");
  FunctionLibraryRuntime::Options opts;
  opts.step_container = ctx->step_container();
  opts.cancellation_manager = ctx->cancellation_manager();
  opts.stats_collector = ctx->stats_collector();
  // TODO(akshayka): Consider selecting a runner on a per-device basis,
  // i.e., using device-specific threadpools when available.
  opts.runner = ctx->runner();
  opts.source_device = local_device_name_;
  opts.run_all_kernels_inline = ctx->run_all_kernels_inline();

  OpInputList arguments;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);

  auto* rendez = new PrivateIntraProcessRendezvous(device_mgr_);
  opts.rendezvous = rendez;

  StatusCallback callback(
      [rendez = rendez, done = std::move(done), device_ordinal = device_ordinal,
       req_id = ordinal_selector_req_id,
       ordinal_selector = ordinal_selector_](const Status& status) {
        delete rendez;
        done();
        if (req_id >= 0) {
          ordinal_selector->DequeueFromCoreSelector(device_ordinal, req_id);
        }
      });

  auto* refcounted_done = new ReffedStatusCallback(std::move(callback));
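  // `refcounted_done` starts with one reference; take one extra reference per
  // additional shard so `callback` runs only after every shard has completed
  // and called Unref().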
  for (int i = 1; i < functions.size(); ++i) {
    refcounted_done->Ref();
  }
  for (const DeviceAndFHandle& entry : functions) {
    const string& target_device = entry.device;
    FHandle handle = entry.handle;
    VLOG(3) << "Running function shard on device " << target_device
            << " with local device name " << local_device_name_;
    if (target_device == local_device_name_) {
      opts.remote_execution = false;
      ExecuteLocalFunction(opts, arguments, handle, ctx, refcounted_done);
    } else {
      opts.remote_execution = true;
      ExecuteRemoteFunction(opts, handle, ctx, refcounted_done);
    }
  }
}

REGISTER_KERNEL_BUILDER(Name("TPUPartitionedCall").Device(DEVICE_CPU),
                        TPUPartitionedCallOp);

}  // end namespace tensorflow