1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/tpu/kernels/tpu_functional_ops.h"
17
18 #include <memory>
19
20 #include "tensorflow/core/framework/op_kernel.h"
21
22 #define EIGEN_USE_THREADS
23
24 #include "absl/base/call_once.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/synchronization/mutex.h"
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/compiler/tf2xla/sharding_util.h"
29 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/common_runtime/function_body.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/common_runtime/placer.h"
34 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
35 #include "tensorflow/core/framework/graph_to_functiondef.h"
36 #include "tensorflow/core/framework/metrics.h"
37 #include "tensorflow/core/framework/node_def.pb.h"
38 #include "tensorflow/core/framework/node_def_util.h"
39 #include "tensorflow/core/framework/resource_mgr.h"
40 #include "tensorflow/core/framework/resource_var.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor.pb.h"
43 #include "tensorflow/core/framework/tensor_shape.h"
44 #include "tensorflow/core/graph/graph_partition.h"
45 #include "tensorflow/core/graph/node_builder.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/hash/hash.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/platform/blocking_counter.h"
50 #include "tensorflow/core/platform/errors.h"
51 #include "tensorflow/core/platform/fingerprint.h"
52 #include "tensorflow/core/platform/refcount.h"
53 #include "tensorflow/core/profiler/lib/traceme.h"
54 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
55 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
56 #include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h"
57 #include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
58 #include "tensorflow/core/tpu/kernels/tpu_op_util.h"
59 #include "tensorflow/core/tpu/kernels/tpu_util.h"
60 #include "tensorflow/core/tpu/tpu_configuration.h"
61 #include "tensorflow/core/tpu/tpu_defs.h"
62 #include "tensorflow/core/util/dump_graph.h"
63
64 namespace tensorflow {
65 namespace {
66
67 constexpr char kTpuReplicateAttr[] = "_tpu_replicate";
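// These two thresholds gate the TPU input "fast path" used by the
// concat/split optimization below: a group of input tensors is considered
// fast-path friendly when the sum of their last dimensions is a multiple of
// kLastDimOfTpuInputFastPath and the product of the remaining dimensions is a
// multiple of kOtherDimOfTpuInputFastPath (see HashShapeAndType and
// GroupTensorsForInputPacking below).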
68 constexpr int kLastDimOfTpuInputFastPath = 128;
69 constexpr int kOtherDimOfTpuInputFastPath = 8;
70
71 constexpr char kXLAShardingAttrName[] = "sharding";
72 constexpr char kXLAShardingAttrAltName[] = "_XlaSharding";
73
74 Status GenerateDeviceNaturalOrder(int x_num_cores, int y_num_cores,
75 int z_num_cores, int num_cores_per_chip,
76 std::vector<int>* natural_order) {
77 for (int y = 0; y < y_num_cores; ++y) {
78 for (int x = 0; x < x_num_cores; ++x) {
79 for (int z = 0; z < z_num_cores; ++z) {
80 for (int c = 0; c < num_cores_per_chip; ++c) {
81 natural_order->push_back(x);
82 natural_order->push_back(y);
83 natural_order->push_back(z);
84 natural_order->push_back(c);
85 }
86 }
87 }
88 }
89
90 return Status::OK();
91 }
92
93 struct TPUVariableInfo {
94 TPUVariableInfo(int device_ordinal_id, bool use_fast_mem)
95 : device_ordinal(device_ordinal_id), fast_mem(use_fast_mem) {}
96 // The TPU core which the variable will be placed on.
97 int device_ordinal;
98 // If true, try to place the variable in fast memory space if the hardware
99 // supports it.
100 bool fast_mem;
101 };
102
103 // Check the descendants to parse the placement information for the input node.
104 // num_cores_per_replica describes how many cores a single model replica uses.
105 Status ParseTPUVariableInfor(const Node* node, const int num_cores_per_replica,
106 TPUVariableInfo* var_info) {
107 int core = 0;
108 bool use_fast_mem = false;
109 VLOG(3) << "Parse tpu variable information for " << node->name();
110 for (const Edge* edge : node->out_edges()) {
111 if (edge->IsControlEdge()) continue;
112 Node* next = edge->dst();
113 VLOG(3) << "Neighbor node " << next->name();
114 // Looking through Enter/Switch/ReadVariableOp nodes.
115 while (next->IsEnter() || next->IsSwitch() ||
116 next->type_string() == "ReadVariableOp") {
117 Node* new_node = nullptr;
118 for (const Edge* e : next->out_edges()) {
119 if (!e->IsControlEdge()) {
120 new_node = e->dst();
121 break;
122 }
123 }
124 if (new_node == nullptr) break;
125 next = new_node;
126 }
127 if (next != edge->dst()) {
128 VLOG(3) << "Looked through Enter/Switch node " << next->DebugString();
129 }
130 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
131 ParseShardingFromDevice(*next, num_cores_per_replica,
132 /*add_metadata=*/false));
133 if (sharding.has_value() && sharding->tile_assignment_devices_size() > 0) {
134 core = sharding->tile_assignment_devices(0);
135 VLOG(3) << next->name() << " is placed on core " << core;
136 }
137 if (next->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
138 use_fast_mem = true;
139 VLOG(3) << next->name() << " has " << TPU_FAST_MEM_ATTR << " attribute";
140 }
141 }
142 VLOG(1) << "Place " << node->name() << " to core: " << core
143 << " fast_mem: " << use_fast_mem;
144 var_info->device_ordinal = core;
145 var_info->fast_mem = use_fast_mem;
146
147 return Status::OK();
148 }
149
150 // Helper to instantiate function "func" in the library "lib".
151 Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
152 FunctionLibraryRuntime::Handle* handle) {
153 return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
154 }
155
156 static constexpr const char* const kDeviceOrdinalAttr = "device_ordinal";
157
158 static constexpr const char* const kTPUExecuteOp = "TPUExecute";
159 static constexpr const char* const kInfeedEnqueueOp = "InfeedEnqueue";
160 static constexpr const char* const kInfeedEnqueueTupleOp = "InfeedEnqueueTuple";
161 static constexpr const char* const kOutfeedDequeueOp = "OutfeedDequeue";
162 static constexpr const char* const kOutfeedDequeueTupleOp =
163 "OutfeedDequeueTuple";
164 static constexpr const char* const kOutfeedDequeueV2Op = "OutfeedDequeueV2";
165 static constexpr const char* const kOutfeedDequeueTupleV2Op =
166 "OutfeedDequeueTupleV2";
167 static constexpr const char* const kVarHandleOp = "VarHandleOp";
168
169 static constexpr const char* const kTPUDeviceNamePrefix = "/device:TPU:";
170 static constexpr const int kTPUDefaultDeviceOrdinal = 0;
171
172 bool IsSupportedTPUOp(const string& op_name) {
173 return op_name == kTPUExecuteOp || op_name == kInfeedEnqueueOp ||
174 op_name == kInfeedEnqueueTupleOp || op_name == kOutfeedDequeueOp ||
175 op_name == kOutfeedDequeueTupleOp || op_name == kOutfeedDequeueV2Op ||
176 op_name == kOutfeedDequeueTupleV2Op;
177 }
178
179 // Sets the sharding attributes for an XlaSharding node.
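// For example, with num_cores_per_replica = 4, rank = 2 and shard_dim = 1, the
// resulting OpSharding has type OTHER, tile_assignment_dimensions [1, 4] and
// tile_assignment_devices [0, 1, 2, 3].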
180 void SetXlaShardingNodeAttr(Node* xla_sharding_node, int num_cores_per_replica,
181 int rank, int shard_dim) {
182 auto sharding = absl::make_optional<xla::OpSharding>();
183 sharding->set_type(xla::OpSharding::OTHER);
184
185 std::vector<int64> dims(rank, 1LL);
186 dims[shard_dim] = num_cores_per_replica;
187 for (auto dim : dims) {
188 sharding->add_tile_assignment_dimensions(dim);
189 }
190
191 // Sets up tile_assignment_devices.
192 for (int d = 0; d < num_cores_per_replica; ++d) {
193 sharding->add_tile_assignment_devices(d);
194 }
195
196 xla_sharding_node->ClearAttr(kXLAShardingAttrName);
197 xla_sharding_node->ClearAttr(kXLAShardingAttrAltName);
198 xla_sharding_node->AddAttr(kXLAShardingAttrName,
199 sharding->SerializeAsString());
200 xla_sharding_node->AddAttr(kXLAShardingAttrAltName,
201 sharding->SerializeAsString());
202 }
203
204 // If 'device_name' is a TPU device, set its device_ordinal to 'device_ordinal'
205 // and set '*rewritten' to true. Otherwise, do nothing.
206 Status UpdateTPUDeviceOrdinal(int device_ordinal, string* device_name,
207 bool* rewritten) {
208 DeviceNameUtils::ParsedName device;
209 if (!DeviceNameUtils::ParseFullName(*device_name, &device)) {
210 return errors::InvalidArgument("Unable to parse device name ",
211 *device_name);
212 }
213 if (device.type == DEVICE_TPU_NODE) {
214 device.id = device_ordinal;
215 *rewritten = true;
216 }
217 *device_name = DeviceNameUtils::ParsedNameToString(device);
218 return Status::OK();
219 }
220
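// Starting from the outgoing edges of an _Arg node, walks forward through the
// graph until it reaches an edge whose source lacks the _tpu_replicate
// attribute while its destination has it, i.e. a host (CPU) to device (TPU)
// input edge. Returns the last such edge found, or nullptr if none exists;
// GuaranteeConst sources are not supported and stop the search.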
221 const Edge* FindHostToDeviceEdge(Node* arg_node) {
222 const Edge* candidate_edge = nullptr;
223 for (const Edge* edge : arg_node->out_edges())
224 if (!edge->IsControlEdge()) {
225 // Find CPU -> TPU input edge.
226 const Edge* original_edge;
227 while (edge->src()->attrs().Find(kTpuReplicateAttr) != nullptr ||
228 edge->dst()->attrs().Find(kTpuReplicateAttr) == nullptr) {
229 const Node* new_src = edge->dst();
230 original_edge = edge;
231 for (const Edge* new_edge : new_src->out_edges())
232 if (!new_edge->IsControlEdge()) {
233 original_edge = edge;
234 edge = new_edge;
235 break;
236 }
237 if (original_edge == edge) break;
238 }
239 // TPU input edge: src is on CPU and dest is on TPU.
240 if (edge->src()->attrs().Find(kTpuReplicateAttr) != nullptr ||
241 edge->dst()->attrs().Find(kTpuReplicateAttr) == nullptr)
242 continue;
243 // Won't work with GuaranteeConst.
244 if (edge->src()->type_string() == "GuaranteeConst") break;
245 candidate_edge = edge;
246 }
247 return candidate_edge;
248 }
249
250 Status CreateInputProxy(Graph* graph, const Edge* candidate_edge,
251 const Edge** tpu_input_edge) {
252 std::vector<const Edge*> edges_to_replace;
253 for (const Edge* input_edge : candidate_edge->src()->out_edges()) {
254 if (!input_edge->IsControlEdge() &&
255 input_edge->dst()->attrs().Find(kTpuReplicateAttr) != nullptr)
256 edges_to_replace.push_back(input_edge);
257 }
258 // Build an Identity node as the proxy of the original edge source.
259 Node* input_identity_node = nullptr;
260 TF_RETURN_IF_ERROR(
261 NodeBuilder(strings::StrCat(candidate_edge->src()->name(), "/proxy"),
262 "Identity")
263 .Input(candidate_edge->src())
264 .Attr("T", candidate_edge->src()->output_type(0))
265 .Attr(kTpuReplicateAttr,
266 candidate_edge->dst()->attrs().Find(kTpuReplicateAttr)->s())
267 .Finalize(graph, &input_identity_node));
268 // Find the tpu input edge from original source to proxy identity.
269 for (const Edge* input_edge : input_identity_node->in_edges())
270 if (input_edge->src() == candidate_edge->src()) {
271 *tpu_input_edge = input_edge;
272 break;
273 }
274 // Replace original input edges with proxy's output.
275 for (const Edge* input_edge : edges_to_replace) {
276 graph->RemoveEdge(input_edge);
277 graph->AddEdge(input_identity_node, 0, input_edge->dst(),
278 input_edge->dst_input());
279 }
280 return Status::OK();
281 }
282
283 Status GetClusterName(Graph* graph, string* cluster_name) {
284 *cluster_name = "";
285 for (const Node* node : graph->nodes()) {
286 if (node->attrs().Find(kTpuReplicateAttr) == nullptr) continue;
287 if (cluster_name->empty())
288 *cluster_name = node->attrs().Find(kTpuReplicateAttr)->s();
289 // When optimization is turned on, the graph should only have one TPU
290 // cluster.
291 if (*cluster_name != node->attrs().Find(kTpuReplicateAttr)->s())
292 return errors::FailedPrecondition(
293 "Only one cluster is allowed when optimization is turned on for "
294 "TPUPartitionedCall. Found ",
295 node->attrs().Find(kTpuReplicateAttr)->s(), " and ", *cluster_name);
296 }
297 return Status::OK();
298 }
299
300 // Removes no-op nodes that directly descend from _Arg nodes.
301 //
302 // This is currently used to remove TPUReplicatedInput and XlaSharding nodes,
303 // which are always descendants of _Arg nodes. During optimization, we try to
304 // insert nodes in between _Arg and _Arg's children, where some of the
305 // inserted nodes are TPU nodes. We add the TPUReplicatedInput and XlaSharding
306 // op nodes back where necessary.
307 //
308 // Returns the number of nodes that were removed.
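// For example, removing "TPUReplicatedInput" with must_be_child_of = {"_Arg"}
// rewrites
//   _Arg -> TPUReplicatedInput -> XlaSharding
// into
//   _Arg -> XlaSharding
// by reconnecting the input edge of the removed node to its output edges.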
309 int64 RemoveDescendantNodeOfArg(Graph* graph,
310 const std::string& node_type_to_remove,
311 const std::set<std::string>& must_be_child_of) {
312 int64_t nodes_removed = 0;
313 std::vector<std::pair<const Edge*, std::vector<const Edge*>>> edges_to_remove;
314
315 for (Node* node : graph->nodes()) {
316 if (node_type_to_remove != node->type_string()) continue;
317 if (!must_be_child_of.empty()) {
318 bool has_arg_parent = false;
319 for (const Edge* edge : node->in_edges()) {
320 if (must_be_child_of.count(edge->src()->type_string()) > 0) {
321 has_arg_parent = true;
322 }
323 }
324 if (!has_arg_parent) continue;
325 }
326 nodes_removed++;
327
328 const Edge* input_edge = nullptr;
329 std::vector<const Edge*> output_edges;
330 for (const Edge* edge : node->in_edges())
331 if (!edge->IsControlEdge()) {
332 input_edge = edge;
333 break;
334 }
335 for (const Edge* edge : node->out_edges())
336 if (!edge->IsControlEdge()) {
337 output_edges.push_back(edge);
338 }
339 if (input_edge != nullptr && !output_edges.empty())
340 edges_to_remove.push_back(std::make_pair(input_edge, output_edges));
341 }
342 for (const auto& it : edges_to_remove) {
343 for (const Edge* output_edge : it.second) {
344 graph->RemoveEdge(output_edge);
345 graph->AddEdge(it.first->src(), it.first->src_output(),
346 output_edge->dst(), output_edge->dst_input());
347 }
348 graph->RemoveNode(it.first->dst());
349 }
350 return nodes_removed;
351 }
352
353 uint64 GetInputHash(OpKernelContext* ctx) {
354 uint64 input_hash = 0; // initialization for determinism.
355 // Use the number of elements to compute the hash.
356 // TODO(chiachenc): use the full shape to compute the hash.
357 for (int i = 0; i < ctx->num_inputs(); ++i) {
358 VLOG(4) << "InputHash, combine input " << i
359 << ", NumElements: " << ctx->input(i).NumElements();
360 input_hash = Hash64Combine(input_hash, ctx->input(i).NumElements());
361 }
362 return input_hash;
363 }
364
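// Builds a grouping key from a dtype and a shape, ignoring the last dimension
// (tensors are later concatenated along it). For example, a DT_FLOAT input of
// shape [8, 256] hashes to something like "input_tensors_dtype_1_dims_8" (the
// dtype is printed as its enum value), with a "_last_128n" or "_last_other"
// suffix appended when input_shape_opt is true.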
365 string HashShapeAndType(const string prefix, const std::vector<int>& input_dims,
366 const DataType& dtype, const bool input_shape_opt) {
367 string hash = strings::StrCat(prefix, dtype, "_dims");
368 // We will concat at the last dimension.
369 for (int d = 0; d < input_dims.size() - 1; ++d) {
370 strings::StrAppend(&hash, "_", input_dims.at(d));
371 }
372
373 if (input_shape_opt) {
374 if (input_dims.back() % kLastDimOfTpuInputFastPath == 0) {
375 strings::StrAppend(&hash, "_last_", kLastDimOfTpuInputFastPath, "n");
376 } else {
377 strings::StrAppend(&hash, "_last_other");
378 }
379 }
380 return hash;
381 }
382
383 // Get the information for input and output tensors (shapes, dtypes, etc).
384 Status GetInputOutputInfo(
385 Graph* graph, GraphShapeInfo& tpu_inferred_info,
386 std::map<int, InferredShape>& arg_shapes, EdgeShapes& tpu_input_shapes,
387 absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
388 OpKernelContext* ctx) {
389 // Search for the device-to-host or tpu-to-cpu edges.
390 for (Node* node : graph->op_nodes()) {
391 if (!node->IsArg()) continue;
392 const DataType dtype = node->attrs().Find("T")->type();
393 const int arg_index = node->attrs().Find("index")->i();
394 if (dtype != DT_INT32 && dtype != DT_BFLOAT16 && dtype != DT_FLOAT &&
395 dtype != DT_BOOL && dtype != DT_QINT8 && dtype != DT_QUINT8)
396 continue;
397 VLOG(3) << "Argnode: " << node->DebugString();
398 const Tensor& tensor = ctx->input(arg_index);
399
400 // Search for the cross-device edge from arg node.
401 const Edge* candidate_edge = FindHostToDeviceEdge(node);
402 if (candidate_edge == nullptr) continue;
403
404 // Make a proxy and get the sole tpu_input_edge for transferring the input
405 // tensor corresponding to the current _Arg node.
406 const Edge* tpu_input_edge = nullptr;
407 TF_RETURN_IF_ERROR(
408 CreateInputProxy(graph, candidate_edge, &tpu_input_edge));
409 if (tpu_input_edge == nullptr)
410 return errors::NotFound("Couldn't find TPU input edge for ", node->name());
411
412 // Optimize edge: original source to proxy identity.
413 VLOG(3) << "Input: " << tpu_input_edge->src()->name();
414 std::vector<int>& input_shapes = tpu_input_shapes[tpu_input_edge];
415 input_shapes.clear();
416 for (int d = 0; d < tensor.dims(); ++d) {
417 input_shapes.push_back(tensor.dim_size(d));
418 VLOG(3) << "Input Tensor: Dim[" << d << "] = " << tensor.dim_size(d);
419 }
420 tpu_input_dtypes[tpu_input_edge] = tensor.dtype();
421
422 // Collect shapes for non-resource-variable args.
423 PartialTensorShape partial_tensor_shape;
424 auto partial_shape = PartialTensorShape::MakePartialShape(
425 input_shapes.data(), input_shapes.size(), &partial_tensor_shape);
426 InferredShape inferred_shape = {partial_tensor_shape};
427 arg_shapes[arg_index] = inferred_shape;
428 }
429 return Status::OK();
430 }
431
432 // Converts integer vectors that represent the shapes to TensorShapes.
433 Status ConvertEdgeShapesToTensorShapes(
434 const std::map<std::string, std::vector<int>>& named_input_shapes,
435 std::vector<TensorShape>* shapes) {
436 shapes->resize(named_input_shapes.size());
437 int32_t i = 0;
438 // keys in tpu_input_shapes may be stale.
439 for (const auto& iter : named_input_shapes) {
440 VLOG(2) << iter.first << ", rank: " << iter.second.size();
441 const int64_t rank = iter.second.size();
442 std::vector<int64> dims(rank);
443 for (int64_t d = 0; d < rank; ++d) {
444 VLOG(2) << " dim[" << d << "]: " << iter.second.at(d);
445 dims[d] = iter.second.at(d);
446 }
447 TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(dims, &(*shapes)[i]));
448 i++;
449 }
450 return Status::OK();
451 }
452
453 // Get the TF fingerprint with the information from the TPUCompileOp or
454 // _TPUCompileMlirOp.
455 Status MaybeRegisterFingerprint(
456 Graph* graph,
457 const std::map<std::string, std::vector<int>>& named_input_shapes,
458 uint64 input_hash) {
459 // Find the compiler metadata.
460 tpu::TPUCompileMetadataProto metadata_proto;
461 std::map<std::string, std::vector<int>> inputs_to_keep;
462 int num_dynamic_shapes = -1;
463 tensorflow::uint64 fingerprint = 0;
464
465 for (Node* node : graph->op_nodes()) {
466 if (node->type_string() == "TPUCompile" ||
467 node->type_string() == "_TPUCompileMlir") {
468 num_dynamic_shapes = node->attrs().Find("NumDynamicShapes")->i();
469 if (num_dynamic_shapes <= 0) {
470 break;
471 }
472 int visited = 0;
473 // TPUCompileOp/_TPUCompileMlirOp take Shape nodes as inputs.
474 // The number of Shape nodes matches the number of dynamic shaped inputs.
475 // The Shape nodes come from the input nodes:
476 // [TPU Input] --> [Input Shape] --> [TPUCompileOp]
477 for (auto in_node : node->in_nodes()) {
478 if (in_node->type_string() != "Shape") {
479 continue;
480 }
481 for (auto input_node : in_node->in_nodes()) {
482 auto iter = named_input_shapes.find(input_node->name());
483 if (iter != named_input_shapes.end()) {
484 inputs_to_keep[iter->first] = iter->second;
485 }
486 }
487 visited++;
488 if (visited == num_dynamic_shapes) {
489 break;
490 }
491 }
492 std::string metadata = node->attrs().Find("metadata")->s();
493 metadata_proto.ParseFromString(metadata);
494
495 if (node->type_string() == "_TPUCompileMlir") {
496 std::string mlir_module = node->attrs().Find("mlir_module")->s();
497 fingerprint = tensorflow::Fingerprint64(mlir_module);
498 } else {
499 fingerprint = metadata_proto.function_library_fingerprint();
500 }
501
502 break;
503 }
504 }
505 VLOG(2) << "inputs_to_keep size: " << inputs_to_keep.size();
506 if (inputs_to_keep.size() != num_dynamic_shapes) {
507 VLOG(2) << "Cannot match all inputs shapes. Skip fingerprint registration.";
508 return Status::OK();
509 }
510
511 std::vector<TensorShape> input_shapes;
512 TF_RETURN_IF_ERROR(
513 ConvertEdgeShapesToTensorShapes(inputs_to_keep, &input_shapes));
514
515 std::vector<TensorShape> arg_shapes;
516 auto status =
517 tpu::ComputeArgumentShapes(metadata_proto, input_shapes, &arg_shapes);
518 if (!status.ok()) {
519 VLOG(2) << status.error_message();
520 return Status::OK();
521 }
522 uint64 tf_fingerprint =
523 tpu::CreateFingerprintWithNameAndShapes(fingerprint, arg_shapes);
524 VLOG(2) << "fingerprint: " << fingerprint;
525 VLOG(2) << "TF fingerprint: " << tf_fingerprint;
526
527 ResourceMgr* rm = GetTPUConfigResourceMgr();
528 tpu::TpuFingerprintLookup* fingerprint_lookup;
529 TF_RETURN_IF_ERROR(rm->Lookup<tpu::TpuFingerprintLookup>(
530 rm->default_container(), tpu::kFingerprintLookupResourceName,
531 &fingerprint_lookup));
532 fingerprint_lookup->RegisterKeyAndIntermediatePair(input_hash,
533 tf_fingerprint);
534 return Status::OK();
535 }
536
537 bool FindTpuReplicatedInputAndXlaSharding(
538 const Graph* graph, XlaShardingInfoMap& xla_sharding_ops,
539 TpuReplicatedInputInfoMap& tpu_replicated_input_ops) {
540 bool xla_spmd_input_sharded = false;
541 // Detect whether there is XLA sharding on the inputs; if there is, then we
542 // cannot remove the replicated inputs or the XlaSharding ops.
543 for (Node* xla_sharding_node : graph->nodes()) {
544 if (xla_sharding_node->type_string() == "XlaSharding") {
545 for (const Edge* edge : xla_sharding_node->in_edges()) {
546 if (edge->src()->type_string() == "TPUReplicatedInput") {
547 Node* tpu_replicated_input_node = edge->src();
548 Node* tpu_replicated_metadata_node = nullptr;
549 for (const Edge* input_edge : tpu_replicated_input_node->in_edges()) {
550 if (input_edge->IsControlEdge()) {
551 tpu_replicated_metadata_node = input_edge->src();
552 }
553 }
554
555 for (const Edge* input_edge : tpu_replicated_input_node->in_edges()) {
556 if (input_edge->src()->type_string() == "_Arg") {
557 Node* arg_node = input_edge->src();
558
559 xla_sharding_ops[arg_node->name()] = std::make_tuple(
560 xla_sharding_node->attrs().Find("T")->type(),
561 xla_sharding_node->attrs().Find("sharding")->s(),
562 xla_sharding_node->attrs().Find("_tpu_replicate")->s());
563
564 tpu_replicated_input_ops[arg_node->name()] = std::make_tuple(
565 tpu_replicated_input_node->attrs().Find("T")->type(),
566 tpu_replicated_metadata_node);
567
568 VLOG(2) << "Detected input is sharded. XlaSharding node: "
569 << xla_sharding_node->DebugString()
570 << ", TPUReplicatedInput node: "
571 << edge->src()->DebugString()
572 << ", _Arg node: " << arg_node->DebugString();
573 xla_spmd_input_sharded = true;
574 break;
575 }
576 }
577 }
578 }
579 }
580 }
581 return xla_spmd_input_sharded;
582 }
583
584 } // end namespace
585
586 namespace tpu_functional_internal {
587
588 // An optimization pass that separates tensors to leverage the fast path in
589 // TPU input preparation. The algorithm is as follows:
590 // (1) Group all tensors that have the same dimensions except the last dimension. A
591 // group of tensors will be concatenated by the last dimension in a later pass.
592 // (2) Check all groups of tensors and find groups whose dimensions after concat
593 // cannot leverage the fast path.
594 // (3) For groups of tensors that don't leverage the fast path, split tensors
595 // into two sub-groups such that one sub-group of tensors can leverage the fast
596 // path.
597 // An exception in (2) is when the concatenated tensors are small, in which
598 // case separating them would only introduce data transfer overhead to the device.
599 // This optimization takes effect when both --input_shape_opt and
600 // --group_tensors_for_packing are true.
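// Illustrative example (shapes made up): two DT_FLOAT inputs of shapes
// [8, 128] and [8, 60] initially land in the same group. Their concatenated
// last dimension is 188 (not a multiple of kLastDimOfTpuInputFastPath) and
// 188 * 8 is larger than 128 * 8, so in step (3) the group is re-hashed with
// input_shape_opt enabled, which places the "_last_128n" tensor ([8, 128])
// and the "_last_other" tensor ([8, 60]) into separate sub-groups.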
601 GroupedEdges GroupTensorsForInputPacking(
602 const EdgeShapes& tpu_input_shapes,
603 const absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
604 bool input_shape_opt, bool group_tensors_for_packing) {
605 GroupedEdges grouped_input_edges;
606 for (const auto& iter : tpu_input_shapes) {
607 if (iter.second.empty()) continue;
608 DataType dtype = tpu_input_dtypes.find(iter.first)->second;
609 string hash_key = HashShapeAndType("input_tensors_dtype_", iter.second,
610 dtype, /*input_shape_opt*/ false);
611 grouped_input_edges[hash_key].push_back(iter.first);
612 }
613 // Apply grouping when both are true.
614 if (!input_shape_opt || !group_tensors_for_packing)
615 return grouped_input_edges;
616
617 GroupedEdges grouped_input_edges_opt;
618 for (const auto& iter : grouped_input_edges) {
619 int sum_last_dim = 0;
620 int product_other_dims = 0;
621 VLOG(3) << "group name: " << iter.first;
622 for (const auto& edge : iter.second) {
623 const std::vector<int>& input_shapes =
624 tpu_input_shapes.find(edge)->second;
625 sum_last_dim += input_shapes.back();
626 if (product_other_dims == 0) {
627 product_other_dims = 1;
628 for (int d = 0; d < input_shapes.size() - 1; ++d)
629 product_other_dims *= input_shapes.at(d);
630 }
631 }
632 VLOG(3) << "sum_last_dim: " << sum_last_dim;
633 VLOG(3) << "product_other_dims: " << product_other_dims;
634 // Already uses fast path, skip further grouping.
635 if ((sum_last_dim % kLastDimOfTpuInputFastPath) == 0 &&
636 (product_other_dims % kOtherDimOfTpuInputFastPath) == 0) {
637 grouped_input_edges_opt[iter.first] = iter.second;
638 continue;
639 }
640 // Tensors are small, skip further grouping.
641 if ((sum_last_dim * product_other_dims) <
642 (kLastDimOfTpuInputFastPath * kOtherDimOfTpuInputFastPath)) {
643 grouped_input_edges_opt[iter.first] = iter.second;
644 continue;
645 }
646 VLOG(3) << "Splitting tensors.";
647 for (const auto& edge : iter.second) {
648 auto tpu_input_shape = tpu_input_shapes.find(edge)->second;
649 string hash_key =
650 HashShapeAndType("input_tensors_dtype_", tpu_input_shape,
651 tpu_input_dtypes.find(edge)->second,
652 /*input_shape_opt*/ true);
653 grouped_input_edges_opt[hash_key].push_back(edge);
654 }
655 }
656 return grouped_input_edges_opt;
657 }
658
659 GroupedEdges GroupTensorsForOutputPacking(Graph* graph,
660 EdgeShapes& tpu_output_shapes,
661 GraphShapeInfo* shape_info) {
662 GroupedEdges shape_to_output;
663 for (const Edge* edge : graph->edges()) {
664 if (edge->IsControlEdge()) continue;
665
666 // TPU output edge: src is on TPU and dst is on CPU.
667 if (edge->dst()->type_string() != "TPUReplicatedOutput") continue;
668 if (!shape_info->count(edge->src()->name())) continue;
669
670 // output shapes for hashing
671 std::vector<int>& output_shapes = tpu_output_shapes[edge];
672 output_shapes.clear();
673
674 int output_id = edge->src_output();
675 auto inferred_shape_vec = shape_info->at(edge->src()->name());
676
677 for (int d : inferred_shape_vec.at(output_id).shape.dim_sizes()) {
678 output_shapes.push_back(d);
679 }
680
681 // Hash Shape and Types.
682 DataType dtype = edge->src()->input_type(output_id);
683 string hash_key =
684 HashShapeAndType("output_tensors_dtype_", output_shapes, dtype, false);
685
686 shape_to_output[hash_key].push_back(edge);
687 }
688 return shape_to_output;
689 }
690
691 // Concatenates input tensors on CPU along the last dimension if all other
692 // dimensions are the same, and splits them on TPU to reduce input overhead.
693 // `tpu_input_shapes` maps an edge to the shape of its output tensor.
694 // `grouped_input_edges` maps tensor name to all edges output from this tensor.
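// Illustrative rewrite (non-sharded case): several host-to-TPU edges
//   src_0:0 -> tpu_dst_0,  src_1:0 -> tpu_dst_1,  ...
// are replaced by
//   src_0:0, src_1:0, ... -> ConcatV2 (on CPU) -> SplitV (on TPU)
//   -> tpu_dst_0, tpu_dst_1, ...
// so that a single, larger tensor crosses the host-to-device boundary.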
695 Status CreateConcatAndSplitNodesForInputTensor(
696 Graph* graph, const string& cluster_name, EdgeShapes* tpu_input_shapes,
697 const absl::flat_hash_map<std::string, std::vector<const Edge*>>&
698 grouped_input_edges,
699 int32_t minimum_input_tensors_packing, bool xla_spmd_input_sharded,
700 const XlaShardingInfoMap& xla_sharding_info,
701 const TpuReplicatedInputInfoMap& tpu_replicated_input_info) {
702 for (const auto& iter : grouped_input_edges) {
703 std::vector<int> last_dim_vec;
704 std::vector<NodeBuilder::NodeOut> concat_nodeouts;
705 absl::flat_hash_map<std::string, int> tensor_to_split_output;
706 int rank;
707 DataType dtype = DT_INVALID;
708 std::string src_name;
709 for (const Edge* edge : iter.second) {
710 src_name = edge->src()->name();
711 string tensor_name =
712 absl::StrCat(edge->src()->name(), ":", edge->src_output());
713 // Create a Concat / Split pair for a tensor if one does not exist yet.
714 if (tensor_to_split_output.contains(tensor_name)) continue;
715 tensor_to_split_output[tensor_name] = concat_nodeouts.size();
716 concat_nodeouts.push_back(
717 NodeBuilder::NodeOut(edge->src(), edge->src_output()));
718 dtype = edge->src()->output_type(edge->src_output());
719 rank = tpu_input_shapes->at(edge).size();
720 last_dim_vec.push_back(tpu_input_shapes->at(edge).back());
721 }
722
723 const int num_tensors = tensor_to_split_output.size();
724 VLOG(3) << iter.first << " num_tensors: " << num_tensors;
725 if (num_tensors < minimum_input_tensors_packing) {
726 VLOG(3) << "skip concat/split " << iter.first;
727 continue;
728 }
729
730 Node* concat_axis_node = nullptr;
731 TensorShape t_shape;
732 Tensor dim_tensor(DT_INT32, t_shape);
733 // Concat and Split at the last dim.
734 dim_tensor.flat<int>()(0) = rank - 1;
735 TF_RETURN_IF_ERROR(
736 NodeBuilder(strings::StrCat(iter.first, "/concat/axis"), "Const")
737 .Attr("dtype", DT_INT32)
738 .Attr("value", dim_tensor)
739 .Finalize(graph, &concat_axis_node));
740
741 Node* concat_node = nullptr;
742 TF_RETURN_IF_ERROR(
743 NodeBuilder(strings::StrCat(iter.first, "/concat"), "ConcatV2")
744 .Input(concat_nodeouts)
745 .Input(concat_axis_node)
746 .Attr("T", dtype)
747 .Attr("Tidx", DT_INT32)
748 .Attr("N", num_tensors)
749 .Finalize(graph, &concat_node));
750
751 Node* split_dim_node = nullptr;
752 TF_RETURN_IF_ERROR(
753 NodeBuilder(strings::StrCat(iter.first, "/split/split_dim"), "Const")
754 .Attr("dtype", DT_INT32)
755 .Attr("value", dim_tensor)
756 .Attr(kTpuReplicateAttr, cluster_name)
757 .Finalize(graph, &split_dim_node));
758
759 Node* split_vec_node = nullptr;
760 TensorShape split_vec_shape;
761 split_vec_shape.AddDim(1);
762 split_vec_shape.set_dim(0, last_dim_vec.size());
763
764 Tensor split_vec_tensor(DT_INT32, split_vec_shape);
765 for (int i = 0; i < last_dim_vec.size(); ++i) {
766 split_vec_tensor.flat<int>()(i) = last_dim_vec[i];
767 }
768 VLOG(3) << "split_vec_tensor: " << split_vec_tensor.DebugString();
769
770 TF_RETURN_IF_ERROR(
771 NodeBuilder(strings::StrCat(iter.first, "/split/vec"), "Const")
772 .Attr("dtype", DT_INT32)
773 .Attr("value", split_vec_tensor)
774 .Attr(kTpuReplicateAttr, cluster_name)
775 .Finalize(graph, &split_vec_node));
776
777 Node* split_node = nullptr;
778 Node* input_to_split_node = concat_node;
779 Node* output_from_concat_node = nullptr;
780 if (xla_spmd_input_sharded &&
781 tpu_replicated_input_info.count(src_name) > 0 &&
782 xla_sharding_info.count(src_name) > 0) {
783 // Create new TPUReplicatedInput and XLAShardingOp nodes
784 //
785 // Rewrite the graph from:
786 // Concat -> Split
787 // to
788 // Concat -> TPUReplicatedInput -> XlaSharding -> Split
789 Node* tpu_replicated_input = nullptr;
790 Node* xla_sharding_op = nullptr;
791
792 std::vector<NodeBuilder::NodeOut> replicated_input;
793 replicated_input.push_back(NodeBuilder::NodeOut(concat_node));
794
795 // TODO(b/183060455): Add TPUReplicatedInput to all graphs.
796 TF_RETURN_IF_ERROR(
797 NodeBuilder(strings::StrCat(iter.first, "/TPUReplicatedInput"),
798 "TPUReplicatedInput")
799 .Input(replicated_input)
800 .ControlInput(std::get<1>(tpu_replicated_input_info.at(src_name)))
801 .Attr("N", 1)
802 .Attr("T", std::get<0>(tpu_replicated_input_info.at(src_name)))
803 .Attr("index", -1)
804 .Attr("is_mirrored_variable", false)
805 .Attr("is_packed", false)
806 .Finalize(graph, &tpu_replicated_input));
807 VLOG(2) << "Created new TPUReplicatedInput node "
808 << tpu_replicated_input->DebugString();
809
810 TF_RETURN_IF_ERROR(
811 NodeBuilder(strings::StrCat(iter.first, "/XlaSharding"),
812 "XlaSharding")
813 .Input(tpu_replicated_input)
814 .Attr("T", std::get<0>(xla_sharding_info.at(src_name)))
815 .Attr("sharding", std::get<1>(xla_sharding_info.at(src_name)))
816 .Attr("_XlaSharding", std::get<1>(xla_sharding_info.at(src_name)))
817 .Attr("_tpu_replicate",
818 std::get<2>(xla_sharding_info.at(src_name)))
819 .Finalize(graph, &xla_sharding_op));
820 VLOG(2) << "Created new XLA sharding node "
821 << xla_sharding_op->DebugString();
822
823 input_to_split_node = xla_sharding_op;
824 output_from_concat_node = tpu_replicated_input;
825 }
826 // Update the `tpu_input_shapes` mapping: Add the new edge
827 // from concat to split.
828 TF_RETURN_IF_ERROR(
829 NodeBuilder(strings::StrCat(iter.first, "/split"), "SplitV")
830 .Input(input_to_split_node)
831 .Input(split_vec_node)
832 .Input(split_dim_node)
833 .Attr("T", dtype)
834 .Attr("num_split", num_tensors)
835 .Attr(kTpuReplicateAttr, cluster_name)
836 .Finalize(graph, &split_node));
837
838 if (output_from_concat_node == nullptr)
839 output_from_concat_node = split_node;
840
841 const Edge* concat_to_split = nullptr;
842 for (const Edge* edge : concat_node->out_edges())
843 if (edge->dst() == output_from_concat_node) {
844 concat_to_split = edge;
845 break;
846 }
847 if (rank > 1) {
848 for (int d = 0; d < rank - 1; ++d)
849 (*tpu_input_shapes)[concat_to_split].push_back(
850 tpu_input_shapes->at(iter.second.back()).at(d));
851 }
852 (*tpu_input_shapes)[concat_to_split].push_back(
853 std::accumulate(last_dim_vec.begin(), last_dim_vec.end(), 0));
854
855 // Connect split node to original tensor output.
856 for (const Edge* edge : iter.second) {
857 string tensor_name =
858 absl::StrCat(edge->src()->name(), ":", edge->src_output());
859 int output_index = tensor_to_split_output.at(tensor_name);
860 graph->RemoveEdge(edge);
861 graph->AddEdge(split_node, output_index, edge->dst(), edge->dst_input());
862 // Update the `tpu_input_shapes` mapping: Remove old edges.
863 tpu_input_shapes->erase(edge);
864 }
865 VLOG(3) << "Concat node: " << concat_node->DebugString();
866 }
867 return Status::OK();
868 }
869
870 // Concatenates output tensors on TPU along the last dimension if all other
871 // dimensions are the same, and splits them on CPU to reduce outfeed overhead.
872 // `tpu_inferred_info` maps an edge to the inferred shape of its output tensor.
873 // `shape_to_output` maps tensor name to all edges output from this tensor.
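// Illustrative rewrite: the per-tensor TPUReplicatedOutput nodes are removed
// and replaced by
//   src_0:0, src_1:0, ... -> ConcatV2 (on TPU) -> TPUReplicatedOutput
//   -> SplitV (on CPU) -> original CPU consumers
// so that a single, larger tensor crosses the device-to-host boundary.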
874 Status CreateConcatAndSplitNodesForOutputTensor(
875 Graph* graph, const string& cluster_name, EdgeShapes* tpu_output_shapes,
876 GraphShapeInfo* tpu_inferred_info, GroupedEdges shape_to_output,
877 int32_t minimum_output_tensors_packing) {
878 for (const auto& iter : shape_to_output) {
879 std::vector<int> last_dim_vec;
880 std::vector<NodeBuilder::NodeOut> concat_nodeouts;
881 absl::flat_hash_map<std::string, int> tensor_to_split_output;
882 int rank;
883 DataType dtype = DT_INVALID;
884 for (const Edge* edge : iter.second) {
885 string tensor_name =
886 absl::StrCat(edge->src()->name(), ":", edge->src_output());
887
888 // Create a Concat / Split pair for a tensor if one does not exist yet.
889 if (tensor_to_split_output.contains(tensor_name)) continue;
890 tensor_to_split_output[tensor_name] = concat_nodeouts.size();
891
892 concat_nodeouts.push_back(
893 NodeBuilder::NodeOut(edge->src(), edge->src_output()));
894 dtype = edge->src()->output_type(edge->src_output());
895 rank = tpu_output_shapes->at(edge).size();
896 last_dim_vec.push_back(tpu_output_shapes->at(edge).back());
897 }
898
899 const int num_tensors = tensor_to_split_output.size();
900 if (num_tensors < minimum_output_tensors_packing) {
901 VLOG(3) << "skip concat/split " << iter.first;
902 continue;
903 }
904
905 Node* concat_axis_node = nullptr;
906 TensorShape t_shape;
907 Tensor dim_tensor(DT_INT32, t_shape);
908 // Concat and Split at the last dim.
909 dim_tensor.flat<int>()(0) = rank - 1;
910 TF_RETURN_IF_ERROR(
911 NodeBuilder(strings::StrCat(iter.first, "/concat/axis"), "Const")
912 .Attr("dtype", DT_INT32)
913 .Attr("value", dim_tensor)
914 .Attr(kTpuReplicateAttr, cluster_name)
915 .Finalize(graph, &concat_axis_node));
916
917 Node* concat_node = nullptr;
918 TF_RETURN_IF_ERROR(
919 NodeBuilder(strings::StrCat(iter.first, "/concat"), "ConcatV2")
920 .Input(concat_nodeouts)
921 .Input(concat_axis_node)
922 .Attr("T", dtype)
923 .Attr("Tidx", DT_INT32)
924 .Attr("N", num_tensors)
925 .Attr(kTpuReplicateAttr, cluster_name)
926 .Finalize(graph, &concat_node));
927
928 Node* tpu_replicated_output_node = nullptr;
929 TF_RETURN_IF_ERROR(
930 NodeBuilder(strings::StrCat(iter.first, "/tpu_replicated_output"),
931 "TPUReplicatedOutput")
932 .Input(concat_node)
933 .Attr("T", dtype)
934 .Attr("num_replicas", 1)
935 .Finalize(graph, &tpu_replicated_output_node));
936
937 Node* split_dim_node = nullptr;
938 TF_RETURN_IF_ERROR(
939 NodeBuilder(strings::StrCat(iter.first, "/split/split_dim"), "Const")
940 .Attr("dtype", DT_INT32)
941 .Attr("value", dim_tensor)
942 .Finalize(graph, &split_dim_node));
943
944 Node* split_vec_node = nullptr;
945 TensorShape split_vec_shape;
946 split_vec_shape.AddDim(1);
947 split_vec_shape.set_dim(0, last_dim_vec.size());
948
949 Tensor split_vec_tensor(DT_INT32, split_vec_shape);
950 for (int i = 0; i < last_dim_vec.size(); ++i) {
951 split_vec_tensor.flat<int>()(i) = last_dim_vec[i];
952 }
953 VLOG(3) << "split_vec_tensor: " << split_vec_tensor.DebugString();
954
955 TF_RETURN_IF_ERROR(
956 NodeBuilder(strings::StrCat(iter.first, "/split/vec"), "Const")
957 .Attr("dtype", DT_INT32)
958 .Attr("value", split_vec_tensor)
959 .Finalize(graph, &split_vec_node));
960
961 Node* split_node = nullptr;
962 TF_RETURN_IF_ERROR(
963 NodeBuilder(strings::StrCat(iter.first, "/split"), "SplitV")
964 .Input(tpu_replicated_output_node)
965 .Input(split_vec_node)
966 .Input(split_dim_node)
967 .Attr("T", dtype)
968 .Attr("num_split", num_tensors)
969 .Finalize(graph, &split_node));
970
971 // Update the `tpu_output_shapes` mapping: Add the new edge
972 // from concat to split.
973 const Edge* concat_to_split = nullptr;
974 for (const Edge* edge : concat_node->out_edges())
975 if (edge->dst() == split_node) {
976 concat_to_split = edge;
977 break;
978 }
979
980 if (rank > 1) (*tpu_output_shapes)[concat_to_split].push_back(-1);
981 for (int d = 1; d < rank - 1; ++d)
982 (*tpu_output_shapes)[concat_to_split].push_back(
983 tpu_output_shapes->at(iter.second.back()).at(d));
984 (*tpu_output_shapes)[concat_to_split].push_back(
985 std::accumulate(last_dim_vec.begin(), last_dim_vec.end(), 0));
986
987 for (const Edge* edge : iter.second) {
988 // 1. Find old TPUReplicatedOutput output edges
989 std::vector<const Edge*> output_edge_vec;
990 for (const Edge* output_edge : edge->dst()->out_edges())
991 output_edge_vec.push_back(output_edge);
992
993 string tensor_name =
994 absl::StrCat(edge->src()->name(), ":", edge->src_output());
995 int output_index = tensor_to_split_output.at(tensor_name);
996 VLOG(3) << "output_index: " << output_index;
997
998 // Connect split node to original tensor output.
999 for (const Edge* output_edge : output_edge_vec) {
1000 VLOG(3) << "output_edge" << output_edge->DebugString();
1001 graph->RemoveEdge(output_edge);
1002 graph->AddEdge(split_node, output_index, output_edge->dst(),
1003 output_edge->dst_input());
1004 // Update the `tpu_output_shapes` mapping: Remove old edges.
1005 tpu_output_shapes->erase(output_edge);
1006 }
1007 graph->RemoveNode(edge->dst());
1008 }
1009 VLOG(3) << "Concat node: " << concat_node->DebugString();
1010 }
1011 return Status::OK();
1012 }
1013
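// For each TPU input edge with a known shape, inserts a flattening Reshape on
// the host side and a recovering Reshape on the TPU side
// (src -> Reshape([-1]) -> ... -> Reshape(original shape) -> dst), updating
// `tpu_input_shapes` accordingly. When the input is consumed through
// TPUReplicatedInput + XlaSharding, the XlaSharding attribute is rewritten to
// shard the flattened (rank-1) tensor across num_cores_per_replica cores.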
1014 Status InsertReshapeNodePairs(Graph* graph, const string& cluster_name,
1015 EdgeShapes* tpu_input_shapes,
1016 int num_cores_per_replica) {
1017 std::vector<const Edge*> tpu_input_edges_original;
1018 for (const auto& it : *tpu_input_shapes)
1019 if (!it.second.empty()) tpu_input_edges_original.push_back(it.first);
1020 for (const Edge* edge : tpu_input_edges_original) {
1021 VLOG(3) << "Reshape input: " << edge->DebugString();
1022
1023 // Check if there is a TPUReplicatedInput and XlaSharding in the middle
1024 bool xla_sharded_input = false;
1025 Node* xla_sharding_node = nullptr;
1026 if (edge->dst()->type_string() == "TPUReplicatedInput" &&
1027 edge->dst()->out_nodes().begin()->type_string() == "XlaSharding") {
1028 VLOG(3) << "Detected TPUReplicatedInput " << edge->dst()->DebugString()
1029 << " and XlaSharding "
1030 << edge->dst()->out_nodes().begin()->DebugString()
1031 << ", setting xla_sharded_input = true";
1032 xla_sharded_input = true;
1033 xla_sharding_node = *(edge->dst()->out_nodes().begin());
1034 }
1035
1036 // 1. Build Reshape node for flatten.
1037
1038 // 1.1 Build Const node for shape
1039 Node* flatten_reshape_shape_node = nullptr;
1040 Tensor flattened_input_shape_tensor;
1041 flattened_input_shape_tensor =
1042 Tensor(DT_INT32, TensorShape({static_cast<int64>(1)}));
1043 flattened_input_shape_tensor.flat<int>()(0) = -1;
1044 TF_RETURN_IF_ERROR(
1045 NodeBuilder(absl::StrCat(edge->src()->name(), "/flatten/Reshape/shape"),
1046 "Const")
1047 .Attr("dtype", DT_INT32)
1048 .Attr("value", flattened_input_shape_tensor)
1049 .Finalize(graph, &flatten_reshape_shape_node));
1050
1051 // 1.2 Build Reshape node for flatten.
1052 Node* flatten_reshape_node = nullptr;
1053 TF_RETURN_IF_ERROR(
1054 NodeBuilder(absl::StrCat(edge->src()->name(), "/flatten/Reshape"),
1055 "Reshape")
1056 .Input(edge->src(), edge->src_output())
1057 .Input(flatten_reshape_shape_node)
1058 .Attr("T", edge->src()->output_type(edge->src_output()))
1059 .Attr("Tshape", DT_INT32)
1060 .Finalize(graph, &flatten_reshape_node));
1061
1062 // 2. Build Reshape node for recover.
1063
1064 // 2.1 Build Const node for shape.
1065 Node* recover_reshape_shape_node = nullptr;
1066 Tensor original_input_shape_tensor(
1067 DT_INT32,
1068 TensorShape({static_cast<int64>(tpu_input_shapes->at(edge).size())}));
1069 original_input_shape_tensor.flat<int>()(0) = -1;
1070 for (int d = 1; d < tpu_input_shapes->at(edge).size(); ++d)
1071 original_input_shape_tensor.flat<int>()(d) =
1072 tpu_input_shapes->at(edge).at(d);
1073 TF_RETURN_IF_ERROR(
1074 NodeBuilder(absl::StrCat(edge->src()->name(), "/recover/Reshape/shape"),
1075 "Const")
1076 .Attr("dtype", DT_INT32)
1077 .Attr("value", original_input_shape_tensor)
1078 .Attr(kTpuReplicateAttr, cluster_name) // This node is on TPU.
1079 .Finalize(graph, &recover_reshape_shape_node));
1080
1081 // 2.2 Build Reshape node for recover.
1082 Node* recover_reshape_input_node = flatten_reshape_node;
1083 const Edge* original_recover_reshape_input_edge = nullptr;
1084 if (xla_sharded_input) {
1085 // We want to find the node after the XlaSharding node
1086 original_recover_reshape_input_edge =
1087 *(edge->dst()->out_nodes().begin()->out_edges().begin());
1088 recover_reshape_input_node = *(edge->dst()->out_nodes().begin());
1089 VLOG(3) << "Recover reshape input node: "
1090 << recover_reshape_input_node->DebugString()
1091 << ", recover reshape input edge: "
1092 << original_recover_reshape_input_edge->DebugString();
1093 }
1094
1095 Node* recover_reshape_node = nullptr;
1096 TF_RETURN_IF_ERROR(
1097 NodeBuilder(absl::StrCat(edge->src()->name(), "/recover/Reshape"),
1098 "Reshape")
1099 .Input(recover_reshape_input_node)
1100 .Input(recover_reshape_shape_node)
1101 .Attr("T", edge->src()->output_type(edge->src_output()))
1102 .Attr("Tshape", DT_INT32)
1103 .Attr(kTpuReplicateAttr, cluster_name) // This node is on TPU.
1104 .Finalize(graph, &recover_reshape_node));
1105
1106 // 3. Rewrite XlaSharding attribute if necessary
1107 if (xla_sharding_node != nullptr) {
1108 // The flattened tensor always has rank = 1 and we want to shard the only
1109 // dimension (0).
1110 SetXlaShardingNodeAttr(xla_sharding_node, num_cores_per_replica, 1, 0);
1111 }
1112
1113 // 4. Connect / disconnect nodes.
1114 if (xla_sharded_input) {
1115 graph->AddEdge(flatten_reshape_node, 0, edge->dst(), edge->dst_input());
1116 }
1117
1118 if (original_recover_reshape_input_edge != nullptr) {
1119 graph->AddEdge(recover_reshape_node, 0,
1120 original_recover_reshape_input_edge->dst(),
1121 original_recover_reshape_input_edge->dst_input());
1122 } else {
1123 graph->AddEdge(recover_reshape_node, 0, edge->dst(), edge->dst_input());
1124 }
1125
1126 graph->RemoveEdge(edge);
1127 if (original_recover_reshape_input_edge != nullptr) {
1128 graph->RemoveEdge(original_recover_reshape_input_edge);
1129 }
1130
1131 // 5. Update EdgeShapes.
1132 int dimension = 1;
1133 for (auto& it : (*tpu_input_shapes)[edge]) {
1134 dimension *= it;
1135 }
1136 VLOG(3) << "Dimension after reshape: " << dimension;
1137 for (const Edge* out_edge : flatten_reshape_node->out_edges()) {
1138 if (out_edge->dst() == recover_reshape_node) {
1139 (*tpu_input_shapes)[out_edge].push_back(dimension);
1140 tpu_input_shapes->erase(edge);
1141 break;
1142 }
1143 }
1144 VLOG(3) << "Reshape optimization done for " << edge->src()->name();
1145 }
1146 return Status::OK();
1147 }
1148 } // namespace tpu_functional_internal
1149
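// Entry point of the op. On the first call it captures the function library
// and the set of TPU devices, then builds a TPU ordinal selector. For every
// request it hashes the input shapes, picks a TPU core, and looks up the
// (input hash, device ordinal) pair in partition_cache_; on a miss it rewrites
// the function graph (input/output packing, resource-arg replacement,
// placement, partitioning) and instantiates one function per device before
// executing the cached functions.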
1150 void TPUPartitionedCallOp::ComputeAsync(OpKernelContext* ctx,
1151 DoneCallback done) {
1152 Status init_status;
1153 absl::call_once(once_, [&]() {
1154 library_runtime_ = ctx->function_library();
1155 if (library_runtime_ == nullptr) {
1156 init_status = errors::Internal("No function library is provided.");
1157 return;
1158 }
1159 flib_def_ = std::make_unique<FunctionLibraryDefinition>(
1160 *library_runtime_->GetFunctionLibraryDefinition());
1161 device_mgr_ = library_runtime_->device_mgr();
1162 for (auto d : device_mgr_->ListDevices()) {
1163 device_set_.AddDevice(d);
1164 }
1165
1166 DeviceNameUtils::ParsedName tpu_device_name;
1167 tpu_device_name.has_type = true;
1168 tpu_device_name.type = "TPU";
1169 std::vector<Device*> tpu_devices;
1170 device_set_.FindMatchingDevices(tpu_device_name, &tpu_devices_);
1171 });
1172 OP_REQUIRES_OK_ASYNC(ctx, init_status, done);
1173
1174 // Initialize the ordinal selector with information from the graph if it is
1175 // the first time we are running this op.
1176 absl::call_once(ordinal_selector_once_, [&]() {
1177 std::unique_ptr<Graph> graph(new Graph(flib_def_.get()));
1178 int num_cores_per_replica = 1;
1179 bool enable_spmd_xla_partitioning = false;
1180 {
1181 absl::MutexLock l(&mu_);
1182 OP_REQUIRES_OK_ASYNC(
1183 ctx,
1184 GetGraphFromFunction(graph.get(), /*device_ordinal=*/0,
1185 &num_cores_per_replica,
1186 &enable_spmd_xla_partitioning),
1187 done);
1188 }
1189 if (enable_spmd_xla_partitioning) {
1190 ordinal_selector_ =
1191 std::make_shared<tpu::TPUOrdinalSelector>(num_cores_per_replica);
1192 } else {
1193 ordinal_selector_ = std::make_shared<tpu::TPUOrdinalSelector>();
1194 }
1195
1196 metrics::RecordTPUXlaSpmdCoresPerReplica(num_cores_per_replica);
1197 });
1198 uint64 input_hash = GetInputHash(ctx);
1199 int64_t ordinal_selector_req_id = -1;
1200 // Select a TPU core.
1201 absl::ReleasableMutexLock lock(&mu_);
1202 int32_t device_ordinal = 0;
1203 OP_REQUIRES_OK_ASYNC(
1204 ctx,
1205 GetTpuCoreOrdinal(ctx, input_hash, &ordinal_selector_req_id,
1206 &device_ordinal),
1207 done);
1208 uint64 cache_hash = Hash64Combine(input_hash, device_ordinal);
1209
1210 const std::vector<DeviceAndFHandle>* functions;
1211
1212 bool cache_miss = !partition_cache_.count(cache_hash);
1213 if (cache_miss) {
1214 VLOG(3) << "Cache Miss: partitioning function " << func_.name()
1215 << " cache_hash: " << cache_hash
1216 << " device_ordinal: " << device_ordinal;
1217
1218 profiler::TraceMe trace_me(
1219 "TPUPartitionedCallOp-RewriteAndInstantiateFunctions");
1220 std::unique_ptr<Graph> graph(new Graph(flib_def_.get()));
1221 int num_cores_per_replica = 1;
1222 bool enable_spmd_xla_partitioning = false;
1223 OP_REQUIRES_OK_ASYNC(ctx,
1224 GetGraphFromFunction(graph.get(), device_ordinal,
1225 &num_cores_per_replica,
1226 &enable_spmd_xla_partitioning),
1227 done);
1228
1229 VLOG(1) << DumpGraphToFile("before_input_output_optimizations", *graph,
1230 flib_def_.get());
1231
1232 std::map<std::string, std::vector<int>> named_input_shapes;
1233 OP_REQUIRES_OK_ASYNC(ctx,
1234 OptimizeTpuInputOutputTensors(
1235 graph.get(), enable_spmd_xla_partitioning,
1236 num_cores_per_replica, named_input_shapes, ctx),
1237 done);
1238
1239 VLOG(1) << DumpGraphToFile(
1240 "before_replace_resource_args_with_var_handle_ops", *graph,
1241 flib_def_.get());
1242 OP_REQUIRES_OK_ASYNC(
1243 ctx,
1244 ReplaceResourceArgsWithVarHandleOps(graph.get(), ctx, device_ordinal,
1245 num_cores_per_replica,
1246 enable_spmd_xla_partitioning),
1247 done);
1248
1249 VLOG(1) << DumpGraphToFile(
1250 "after_replace_resource_args_with_var_handle_ops", *graph,
1251 flib_def_.get());
1252
1253 // Graph rewrite passes.
1254 GraphOptimizationPassOptions optimization_options;
1255 // TODO(akshayka): Thread the SessionOptions into this kernel, or make
1256 // it possible to specify the relevant options via attributes.
1257 SessionOptions session_options;
1258 session_options.config.mutable_experimental()
1259 ->set_xla_fusion_autotuner_thresh(autotuner_thresh_);
1260
1261 session_options.env = ctx->env();
1262 optimization_options.session_handle = ctx->session_handle();
1263 optimization_options.session_options = &session_options;
1264 optimization_options.graph = &graph;
1265 optimization_options.flib_def = flib_def_.get();
1266 optimization_options.device_set = &device_set_;
1267 OP_REQUIRES_OK_ASYNC(
1268 ctx, PlacementHelper(device_set_, optimization_options, func_.name()),
1269 done);
1270
1271 if (!enable_spmd_xla_partitioning || num_cores_per_replica == 1) {
1272 OP_REQUIRES_OK_ASYNC(
1273 ctx,
1274 MaybeRegisterFingerprint(graph.get(), named_input_shapes, input_hash),
1275 done);
1276 }
1277 // `subgraphs` maps from device names to functions.
1278 std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs;
1279 optimization_options.graph = nullptr;
1280 optimization_options.device_set = nullptr;
1281 optimization_options.partition_graphs = &subgraphs;
1282 VLOG(1) << DumpGraphToFile("before_partition_helper.pbtxt", *graph,
1283 flib_def_.get());
1284 OP_REQUIRES_OK_ASYNC(ctx,
1285 PartitionHelper(device_set_, optimization_options,
1286 graph.get(), &subgraphs),
1287 done);
1288 OP_REQUIRES_OK_ASYNC(ctx,
1289 InstantiateFunctionsFromSubgraphs(
1290 device_set_, device_ordinal, cache_hash,
1291 num_cores_per_replica, std::move(subgraphs)),
1292 done);
1293 }
1294 functions = &partition_cache_[cache_hash];
1295 lock.Release();
1296
1297 ExecuteFunctions(*functions, ctx, device_ordinal, ordinal_selector_req_id,
1298 std::move(done));
1299 }
1300
1301 Status TPUPartitionedCallOp::GetTpuCoreOrdinal(OpKernelContext* ctx,
1302 uint64 input_hash,
1303 int64_t* ordinal_selector_req_id,
1304 int32_t* core_ordinal) {
1305 profiler::TraceMe trace_me("TPUPartitionedCallOp-GetTpuCoreOrdinal");
1306 const Tensor* device_ordinal_t;
1307 TF_RETURN_IF_ERROR(ctx->input(kDeviceOrdinalAttr, &device_ordinal_t));
1308 int device_ordinal = device_ordinal_t->scalar<int>()();
1309 if (device_ordinal == tpu::kDeferredCoreSelectionReserved) {
1310 device_ordinal =
1311 ordinal_selector_->GetOrdinal(input_hash, ordinal_selector_req_id);
1312 }
1313 *core_ordinal = device_ordinal;
1314 return Status::OK();
1315 }
1316
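// Copies the host value of `var` to the given TPU device by building and
// running a small init graph: the VarHandleOp described by `ndef`, a Const
// (or _TPUConst in fast memory) holding the current value, and an
// AssignVariableOp. The temporary init function is removed from the function
// library once the assignment has completed.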
1317 Status TPUPartitionedCallOp::InitializeVarOnTPU(
1318 OpKernelContext* ctx, const core::RefCountPtr<Var>& var, NodeDef* ndef,
1319 int device_ordinal, bool fast_mem) {
1320 const string device = strings::StrCat(kTPUDeviceNamePrefix, device_ordinal);
1321 Status status;
1322 std::unique_ptr<Graph> init_graph(new Graph(OpRegistry::Global()));
1323 Node* init_handle = init_graph->AddNode(*ndef, &status);
1324 TF_RETURN_IF_ERROR(status);
1325 init_handle->set_assigned_device_name(device);
1326
1327 NodeDef init_const_ndef;
1328 init_const_ndef.set_name("initial_value");
1329 if (fast_mem) {
1330 init_const_ndef.set_op("_TPUConst");
1331 AddNodeAttr("memory_space", "FastMem", &init_const_ndef);
1332 } else {
1333 init_const_ndef.set_op("Const");
1334 }
1335 init_const_ndef.set_device(device);
1336 AddNodeAttr("dtype", var->tensor()->dtype(), &init_const_ndef);
1337 AddNodeAttr("value", *var->tensor(), &init_const_ndef);
1338
1339 Node* init_const = init_graph->AddNode(init_const_ndef, &status);
1340 TF_RETURN_IF_ERROR(status);
1341
1342 NodeDef assign_node_def;
1343 assign_node_def.set_name("Assign");
1344 assign_node_def.set_op("AssignVariableOp");
1345 assign_node_def.set_device(device);
1346 AddNodeAttr("dtype", var->tensor()->dtype(), &assign_node_def);
1347 Node* init_assign = init_graph->AddNode(assign_node_def, &status);
1348 TF_RETURN_IF_ERROR(status);
1349
1350 init_graph->AddEdge(init_handle, 0, init_assign, 0);
1351 init_graph->AddEdge(init_const, 0, init_assign, 1);
1352 FHandle fhandle;
1353 const string fname =
1354 strings::StrCat(ndef->name(), "_init_ord_", device_ordinal);
1355
1356 TF_RETURN_IF_ERROR(
1357 InstantiatePartition(*init_graph, fname, device, &fhandle, nullptr));
1358
1359 FunctionLibraryRuntime::Options opts;
1360 opts.step_container = ctx->step_container();
1361 opts.cancellation_manager = ctx->cancellation_manager();
1362 opts.stats_collector = ctx->stats_collector();
1363
1364 // Blocking on threads in the same thread pool is disallowed because
1365 // concurrent warm-up requests can exhaust the default thread pool.
1366 // Create a new thread pool to initialize variables on TPU.
1367 std::function<void(std::function<void()>)> runner =
1368 [this](std::function<void()> fn) { pool_.Schedule(fn); };
1369 opts.runner = &runner;
1370
1371 opts.source_device = local_device_name_;
1372 PrivateIntraProcessRendezvous rendez(device_mgr_);
1373 opts.rendezvous = &rendez;
1374 opts.remote_execution = true;
1375
1376 std::vector<Tensor> dummy_args;
1377 std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
1378 Notification done;
1379 profiler::TraceMe trace_me("TPUPartitionedCallOp-InitializeVarOnTPU");
1380 library_runtime_->Run(opts, fhandle, dummy_args, dummy_rets,
1381 [dummy_rets, &done, ctx](const Status& status) {
1382 if (!status.ok()) {
1383 ctx->SetStatus(status);
1384 }
1385 delete dummy_rets;
1386 done.Notify();
1387 });
1388 done.WaitForNotification();
  // We don't actually want the variable initialization functions
  // in the function library definition and the function library
  // runtime, because flib_def_ is used for the graph rewrite passes.
  // The TPU distributed rewrite pass computes a fingerprint for
  // flib_def_, which will throw a length error if there are
  // many variables whose initialization functions are added
  // to the library definition.
1396 TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fname));
1397 TF_RETURN_IF_ERROR(library_runtime_->ReleaseHandle(fhandle));
1398 return Status::OK();
1399 }
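
// Illustrative sketch of the one-off initialization graph that
// InitializeVarOnTPU builds for a single (unsharded) variable, using the
// NodeDefs constructed above:
//
//   VarHandleOp (*ndef, assigned to "/device:TPU:<device_ordinal>")
//        \
//         AssignVariableOp ("Assign")
//        /
//   Const or _TPUConst ("initial_value", holding the variable's CPU value)
//
// The graph is wrapped into a temporary function, run synchronously, and then
// removed from flib_def_ so it does not affect the fingerprint used by the
// TPU rewrite passes.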
1400
Status TPUPartitionedCallOp::InitializeShardedVarOnTPU(
    OpKernelContext* ctx, const core::RefCountPtr<Var>& var,
    std::vector<NodeDef>& ndefs, int split_dim, int device_ordinal) {
1404 std::unique_ptr<Graph> init_graph(new Graph(OpRegistry::Global()));
1405 int num_cores = ndefs.size();
1406 string cpu_device = "/device:CPU:0";
1407
1408 Status status;
1409 std::vector<std::string> devices;
1410 std::vector<Node*> init_handles;
1411 for (int i = 0; i < num_cores; i++) {
1412 Node* init_handle = init_graph->AddNode(ndefs[i], &status);
1413 TF_RETURN_IF_ERROR(status);
1414 string device = strings::StrCat(kTPUDeviceNamePrefix, device_ordinal + i);
1415 init_handle->set_assigned_device_name(device);
1416 devices.push_back(device);
1417 init_handles.push_back(init_handle);
1418 }
1419
1420 NodeDef init_const_ndef;
1421 init_const_ndef.set_name("initial_value");
1422 init_const_ndef.set_op("Const");
1423 init_const_ndef.set_device(cpu_device);
1424 AddNodeAttr("dtype", var->tensor()->dtype(), &init_const_ndef);
1425 AddNodeAttr("value", *var->tensor(), &init_const_ndef);
1426 Node* init_const = init_graph->AddNode(init_const_ndef, &status);
1427 init_const->set_assigned_device_name(cpu_device);
1428 TF_RETURN_IF_ERROR(status);
1429
1430 Node* assign_value_node = init_const;
  // If the variable is sharded, we insert a "Split" node between the initial
  // value and the AssignVariableOps, so the variable on each TPU device gets
  // assigned its split of the value.
  //
  //   initial_value--Split--AssignVariableOp ("/device:TPU:0")
  //                    |
  //                    AssignVariableOp ("/device:TPU:1")
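  //
  // For example (illustrative only): a variable of shape [8, 128] split
  // across two cores along split_dim == 0 is fed through Split as two
  // [4, 128] tensors, one per AssignVariableOp.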
1438 if (split_dim >= 0) {
1439 // Add a split dimension node.
1440 NodeDef split_dim_def;
1441 split_dim_def.set_name("initial_value_split_dim");
1442 split_dim_def.set_op("Const");
1443 split_dim_def.set_device(cpu_device);
1444 AddNodeAttr("dtype", DT_INT32, &split_dim_def);
1445 TensorProto tensor_proto;
1446 tensor_proto.set_dtype(DT_INT32);
1447 tensor_proto.add_int_val(split_dim);
1448 TensorShape shape({});
1449 shape.AsProto(tensor_proto.mutable_tensor_shape());
1450 AddNodeAttr("value", tensor_proto, &split_dim_def);
1451 Node* split_dim_node = init_graph->AddNode(split_dim_def, &status);
1452 split_dim_node->set_assigned_device_name(cpu_device);
1453 TF_RETURN_IF_ERROR(status);
1454
1455 // Add a split node.
1456 NodeDef split_def;
1457 int split_num = ndefs.size();
1458 split_def.set_name("initial_value_split");
1459 split_def.set_op("Split");
1460 split_def.set_device(cpu_device);
1461 AddNodeAttr("num_split", split_num, &split_def);
1462 AddNodeAttr("T", var->tensor()->dtype(), &split_def);
1463 split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
1464 split_def.add_input(absl::StrCat(init_const->name(), ":0"));
1465 Node* split_node = init_graph->AddNode(split_def, &status);
1466 split_node->set_assigned_device_name(cpu_device);
1467 TF_RETURN_IF_ERROR(status);
1468
1469 init_graph->AddEdge(split_dim_node, 0, split_node, 0);
1470 init_graph->AddEdge(init_const, 0, split_node, 1);
1471
1472 assign_value_node = split_node;
1473 }
1474
1475 for (int i = 0; i < num_cores; i++) {
1476 NodeDef assign_node_def;
1477 assign_node_def.set_name(absl::StrCat("Assign_", i));
1478 assign_node_def.set_op("AssignVariableOp");
1479 assign_node_def.set_device(devices[i]);
1480 AddNodeAttr("dtype", var->tensor()->dtype(), &assign_node_def);
1481 Node* init_assign = init_graph->AddNode(assign_node_def, &status);
1482 init_assign->set_assigned_device_name(devices[i]);
1483 TF_RETURN_IF_ERROR(status);
1484
1485 init_graph->AddEdge(init_handles[i], 0, init_assign, 0);
1486 if (split_dim >= 0) {
1487 init_graph->AddEdge(assign_value_node, i, init_assign, 1);
1488 } else {
1489 init_graph->AddEdge(assign_value_node, 0, init_assign, 1);
1490 }
1491 }
1492
1493 GraphOptimizationPassOptions optimization_options;
1494 SessionOptions session_options;
1495 session_options.env = ctx->env();
1496 optimization_options.session_handle = ctx->session_handle();
1497 optimization_options.session_options = &session_options;
1498 optimization_options.flib_def = flib_def_.get();
1499 optimization_options.graph = nullptr;
1500 optimization_options.device_set = nullptr;
1501 std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs;
1502 optimization_options.partition_graphs = &subgraphs;
1503 TF_RETURN_IF_ERROR(PartitionHelper(device_set_, optimization_options,
1504 init_graph.get(), &subgraphs));
1505
1506 std::vector<DeviceAndFHandle> functions;
1507 std::vector<std::string> function_names;
1508 for (auto& pair : subgraphs) {
1509 string target = pair.first;
1510 Device* device;
1511 TF_RETURN_IF_ERROR(
1512 library_runtime_->device_mgr()->LookupDevice(target, &device));
1513 Graph* subgraph = pair.second.get();
1514 string function_name = flib_def_->UniqueFunctionName(
1515 strings::StrCat(func_.name(), "_hash_", pair.first));
1516 function_names.push_back(function_name);
1517 FHandle handle;
1518 TF_RETURN_IF_ERROR(InstantiatePartition(*subgraph, function_name, target,
1519 &handle, nullptr));
1520 functions.push_back(DeviceAndFHandle{.device = target, .handle = handle});
1521 }
1522
1523 FunctionLibraryRuntime::Options opts;
1524
1525 // Blocking on threads in the same thread pool is disallowed because
1526 // concurrent warm-up requests can exhaust the default thread pool.
1527 // Create a new thread pool to initialize variables on TPU.
1528 std::function<void(std::function<void()>)> runner =
1529 [this](std::function<void()> fn) { pool_.Schedule(fn); };
1530 opts.runner = &runner;
1531
1532 opts.step_container = ctx->step_container();
1533 opts.cancellation_manager = ctx->cancellation_manager();
1534 opts.stats_collector = ctx->stats_collector();
1535 opts.source_device = local_device_name_;
1536 opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
1537
1538 OpInputList arguments;
1539 TF_RETURN_IF_ERROR(ctx->input_list("args", &arguments));
1540
1541 auto* rendez = new PrivateIntraProcessRendezvous(device_mgr_);
1542 opts.rendezvous = rendez;
1543
1544 BlockingCounter bcount(functions.size());
1545 for (const DeviceAndFHandle& entry : functions) {
1546 const string& target_device = entry.device;
1547 FHandle handle = entry.handle;
1548
1549 TF_RETURN_IF_ERROR(
1550 ShouldUseRemoteExecutionForFn(target_device, &(opts.remote_execution)));
1551 std::vector<Tensor> dummy_args;
1552 std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
1553
1554 profiler::TraceMe trace_me(
1555 "TPUPartitionedCallOp-InitializeShardedVarOnTPU");
1556 library_runtime_->Run(opts, handle, dummy_args, dummy_rets,
1557 [dummy_rets, &bcount, ctx](const Status& status) {
1558 if (!status.ok()) {
1559 ctx->SetStatus(status);
1560 }
1561 delete dummy_rets;
1562 bcount.DecrementCount();
1563 });
1564 }
1565 bcount.Wait();
1566
1567 for (int i = 0; i < functions.size(); i++) {
1568 TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(function_names[i]));
1569 TF_RETURN_IF_ERROR(library_runtime_->ReleaseHandle(functions[i].handle));
1570 }
1571 return Status::OK();
1572 }
1573
bool TPUPartitionedCallOp::IsInputToTPUReplicate(Node* node) {
1575 for (Node* successor : node->out_nodes()) {
1576 if (successor->attrs().Find(kTpuReplicateAttr) != nullptr) {
1577 return true;
1578 }
1579 }
1580 return false;
1581 }
1582
Status TPUPartitionedCallOp::ReplaceResourceArgsWithVarHandleOps(
    Graph* graph, OpKernelContext* ctx, int device_ordinal,
    int num_cores_per_replica, bool enable_spmd_xla_partitioning) {
1586 // Currently variable deduplication is not supported for XLA SPMD
1587 // partitioning. It is possible that it could be supported in the future.
1588 bool enable_variable_deduplication =
1589 runtime_params_.enable_variable_deduplication;
1590 if (enable_spmd_xla_partitioning && num_cores_per_replica > 1) {
1591 // If enable_spmd_xla_partitioning is true, the user set the
1592 // enable_auto_xla_input_sharding flag. Warn them that only one of the flags
1593 // can be set safely when num_cores_per_replica > 1. If
1594 // num_cores_per_replica==1, enable_spmd_xla_partitioning is effectively a
1595 // no-op so we can skip this check.
1596 LOG(WARNING) << "Disabling variable deduplication because it is not "
1597 "compatible with enable_auto_xla_input_sharding.";
1598 enable_variable_deduplication = false;
1599 }
1600 std::vector<Node*> tpu_resource_args;
1601 std::vector<int> arg_indices;
1602 absl::flat_hash_map<const Node*, xla::OpSharding> variable_to_xla_sharding;
1603 for (Node* node : graph->op_nodes()) {
1604 if (node->IsArg()) {
1605 const AttrValue* attr_value;
1606 TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
1607 DataType dtype = attr_value->type();
1608 if (dtype == DT_RESOURCE && IsInputToTPUReplicate(node)) {
        // If this VarHandleOp is used by a TPU computation,
        // we need to create a TPU version of the variable.
1611 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
1612 int index = attr_value->i();
1613 tpu_resource_args.push_back(node);
1614 arg_indices.push_back(index);
1615 replaced_input_indices_[index] = true;
1616 }
1617 }
1618 }
1619
1620 VLOG(3) << "tpu_resource_args.size(): " << tpu_resource_args.size();
1621 // Create a mapping from ResourceHandle to variable node. When a
1622 // ResourceHandle backs several variable nodes, the variable nodes refer to
1623 // the same underlying resource. In that case, only one variable node needs
1624 // to be mirrored to the TPU for that resource.
1625 absl::flat_hash_map<uint64, Node*> tpu_variables;
1626 for (int i = 0; i < tpu_resource_args.size(); i++) {
1627 Node* node = tpu_resource_args[i];
1628 ResourceHandle handle = HandleFromInput(ctx, arg_indices[i]);
1629
1630 if (num_cores_per_replica > 1 && enable_spmd_xla_partitioning) {
1631 TF_RETURN_IF_ERROR(ReplaceAndPartitionXLAShardingVariable(
1632 graph, ctx, device_ordinal, handle, node, num_cores_per_replica));
1633 continue;
1634 }
1635 TPUVariableInfo var_info(/*device_ordinal_id=*/0, /*use_fast_mem=*/false);
1636 TF_RETURN_IF_ERROR(
1637 ParseTPUVariableInfor(node, num_cores_per_replica, &var_info));
1638 // Only respect graph's placement when model parallelism enabled.
1639 if (num_cores_per_replica > 1) device_ordinal = var_info.device_ordinal;
1640
1641 const uint64 handle_fp =
1642 Fingerprint64(strings::StrCat(handle.container(), handle.name()));
1643 if (enable_variable_deduplication && tpu_variables.contains(handle_fp) &&
1644 num_cores_per_replica == 1) {
1645 Node* tpu_variable = tpu_variables.at(handle_fp);
1646 std::vector<Node*> dst_nodes;
1647 std::vector<int> src_indices;
1648 std::vector<int> dst_indices;
1649 for (const Edge* edge : node->out_edges()) {
1650 dst_nodes.push_back(edge->dst());
1651 src_indices.push_back(edge->src_output());
1652 dst_indices.push_back(edge->dst_input());
1653 }
1654 graph->RemoveNode(node);
1655 for (int i = 0; i < dst_nodes.size(); i++) {
1656 graph->AddEdge(tpu_variable, src_indices[i], dst_nodes[i],
1657 dst_indices[i]);
1658 }
1659 } else {
1660 uint64 fp =
1661 Fingerprint64(strings::StrCat(handle.container(), handle.name(), i));
1662 NodeDef ndef;
1663 ndef.set_name(strings::StrCat(handle.name(), fp));
1664 ndef.set_op(kVarHandleOp);
1665 if (num_cores_per_replica > 1) {
1666 ndef.set_device(strings::StrCat(kTPUDeviceNamePrefix, device_ordinal));
1667 } else {
        // Assign this new VarHandleOp to TPU:0 so the partitioner only
        // partitions the graph into two subgraphs, one on CPU and one on TPU.
        // The actual device ordinal on which this VarHandleOp runs is assigned
        // after partitioning (in SetDeviceOrdinal).
1672 ndef.set_device(
1673 strings::StrCat(kTPUDeviceNamePrefix, kTPUDefaultDeviceOrdinal));
1674 }
1675
1676 // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node
1677 // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the
1678 // shared_name of the variable on CPU and "x" is the rewritten device
1679 // ordinal.
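        // For example (illustrative): a CPU variable whose handle name is
        // "w" and whose rewritten device ordinal is 3 gets a TPU VarHandleOp
        // with shared_name "w_tpu_3".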
1680 const string sname =
1681 strings::StrCat(handle.name(), "_tpu_", device_ordinal);
1682 AddNodeAttr("shared_name", sname, &ndef);
1683 const string cname = ctx->resource_manager()->default_container();
1684 AddNodeAttr("container", cname, &ndef);
1685 core::RefCountPtr<Var> var;
1686 TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
1687 AddNodeAttr("dtype", var->tensor()->dtype(), &ndef);
1688 TensorShapeProto proto;
1689 var->tensor()->shape().AsProto(&proto);
1690 AddNodeAttr("shape", proto, &ndef);
1691 Status status;
1692 Node* new_node = graph->AddNode(ndef, &status);
1693 TF_RETURN_IF_ERROR(status);
1694 std::vector<const Edge*> in_edges(node->in_edges().begin(),
1695 node->in_edges().end());
1696 for (const Edge* edge : in_edges) {
1697 graph->AddEdge(edge->src(), edge->src_output(), new_node,
1698 edge->dst_input());
1699 }
1700 std::vector<Node*> dst_nodes;
1701 std::vector<int> src_indices;
1702 std::vector<int> dst_indices;
1703 for (const Edge* edge : node->out_edges()) {
1704 dst_nodes.push_back(edge->dst());
1705 src_indices.push_back(edge->src_output());
1706 dst_indices.push_back(edge->dst_input());
1707 }
1708 graph->RemoveNode(node);
1709 for (int i = 0; i < dst_nodes.size(); i++) {
1710 graph->AddEdge(new_node, src_indices[i], dst_nodes[i], dst_indices[i]);
1711 }
      // Don't initialize the variable on TPU if this has already been done
      // for the ordinal.
1714 if (seen_ordinals_.contains(device_ordinal)) continue;
1715
1716 Device* d;
1717 TF_RETURN_IF_ERROR(library_runtime_->device_mgr()->LookupDevice(
1718 strings::StrCat(kTPUDeviceNamePrefix, device_ordinal), &d));
1719 Var* tpu_var;
1720 status = d->resource_manager()->Lookup(cname, sname, &tpu_var);
1721 if (!status.ok()) {
1722 TF_RETURN_IF_ERROR(InitializeVarOnTPU(ctx, var, &ndef, device_ordinal,
1723 var_info.fast_mem));
1724 }
1725 tpu_variables[handle_fp] = new_node;
1726 }
1727 }
1728
  // Adjust the index attr of the remaining (non-resource) arg nodes.
1730 int new_index = 0;
1731 for (Node* node : graph->op_nodes()) {
1732 if (node->IsArg()) {
1733 node->ClearAttr("index");
1734 node->AddAttr("index", new_index);
1735 new_index++;
1736 }
1737 }
1738
1739 seen_ordinals_.insert(device_ordinal);
1740
1741 return Status::OK();
1742 }
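
// Illustrative before/after sketch of the rewrite performed by
// ReplaceResourceArgsWithVarHandleOps for a DT_RESOURCE _Arg that feeds a
// TPU computation:
//
//   before:  _Arg(T=DT_RESOURCE) -----------------------> TPU cluster node
//   after:   VarHandleOp("<name><fingerprint>", on TPU) -> TPU cluster node
//
// The TPU copy of the variable is initialized at most once per device ordinal
// via InitializeVarOnTPU, and the remaining _Arg nodes are re-indexed
// afterwards.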
1743
Status TPUPartitionedCallOp::ReplaceAndPartitionXLAShardingVariable(
    Graph* graph, OpKernelContext* ctx, int device_ordinal,
    ResourceHandle& handle, Node* variable, int num_cores_per_replica) {
1747 TF_ASSIGN_OR_RETURN(
1748 auto sharding,
1749 GetShardingFromNodeDef(variable->def(), /*add_metadata=*/false));
1750 xla::OpSharding xla_sharding;
1751 bool is_var_sharded = false;
1752 if (sharding.has_value() &&
1753 sharding.value().type() == xla::OpSharding::OTHER) {
1754 xla_sharding = sharding.value();
1755 is_var_sharded = true;
1756 } else {
1757 xla_sharding.set_type(xla::OpSharding::REPLICATED);
1758 is_var_sharded = false;
1759 }
1760 VLOG(3) << "Replace and partition variable " << variable->name()
1761 << " with xla_sharding: " << xla_sharding.DebugString();
1762
1763 core::RefCountPtr<Var> var;
1764 TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
1765
1766 int split_dim = -1;
1767 int split_size = 0;
1768 if (is_var_sharded) {
1769 for (int dim = 0; dim < xla_sharding.tile_assignment_dimensions_size();
1770 dim++) {
1771 if (xla_sharding.tile_assignment_dimensions(dim) > 1) {
1772 if (split_dim != -1) {
1773 return errors::InvalidArgument(
1774 "Currently we only support inference with one split dimension, "
1775 "however got sharding: ",
1776 xla_sharding.DebugString());
1777 }
1778 split_dim = dim;
1779 split_size = xla_sharding.tile_assignment_dimensions(dim);
1780 }
1781 }
1782 }
1783 const string cname = ctx->resource_manager()->default_container();
1784 std::vector<Node*> per_core_vars;
1785 for (int core_index = device_ordinal;
1786 core_index < (device_ordinal + num_cores_per_replica); core_index++) {
1787 NodeDef ndef;
1788 uint64 fp = Fingerprint64(
1789 strings::StrCat(handle.container(), handle.name(), "_", core_index));
1790 ndef.set_name(strings::StrCat(handle.name(), fp));
1791 ndef.set_op(kVarHandleOp);
1792 ndef.set_device(strings::StrCat(kTPUDeviceNamePrefix, core_index));
1793
1794 // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node
1795 // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the
1796 // shared_name of the variable on CPU and "x" is the rewritten device
1797 // ordinal.
1798 const string sname = strings::StrCat(handle.name(), "_tpu_", core_index);
1799 AddNodeAttr("shared_name", sname, &ndef);
1800 AddNodeAttr("container", cname, &ndef);
1801 AddNodeAttr("dtype", var->tensor()->dtype(), &ndef);
1802
1803 TensorShapeProto proto;
1804 var->tensor()->shape().AsProto(&proto);
1805
1806 if (is_var_sharded) {
1807 int dim_size = proto.dim(split_dim).size();
1808 if (dim_size % split_size != 0) {
        return errors::InvalidArgument("dimension size ", dim_size,
                                       " is not divisible by split size ",
                                       split_size);
1812 }
1813 proto.mutable_dim(split_dim)->set_size(dim_size / split_size);
1814 }
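    // For example (illustrative): with num_cores_per_replica == 4 and a
    // variable of shape [1024, 128] sharded along split_dim == 0, each
    // per-core VarHandleOp is declared with shape [256, 128].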
1815 AddNodeAttr("shape", proto, &ndef);
1816
1817 Status status;
1818 Node* new_node = graph->AddNode(ndef, &status);
1819 TF_RETURN_IF_ERROR(status);
1820 per_core_vars.push_back(new_node);
1821 }
1822
1823 // Insert TPUPartitionedInput op.
1824 NodeDefBuilder builder(absl::StrCat(handle.name(), "/tpu_partitioned_input"),
1825 "TPUPartitionedInput");
1826 builder.Attr("N", num_cores_per_replica);
1827 builder.Attr("T", DT_RESOURCE);
1828 builder.Attr("partition_dim", split_dim);
1829 builder.Attr("_XlaSharding", xla_sharding.SerializeAsString());
1830 std::vector<NodeDefBuilder::NodeOut> inputs;
1831 inputs.reserve(num_cores_per_replica);
1832 for (int core_index = 0; core_index < num_cores_per_replica; core_index++) {
1833 inputs.push_back({per_core_vars[core_index]->name(), 0, DT_RESOURCE});
1834 }
1835 builder.Input(inputs);
1836 NodeDef node_def;
1837 TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
1838 Status s;
1839 Node* tpu_partitioned_input_node = graph->AddNode(node_def, &s);
1840 if (!s.ok()) {
1841 return s;
1842 }
1843
1844 for (int core_index = 0; core_index < num_cores_per_replica; core_index++) {
1845 graph->AddEdge(per_core_vars[core_index], 0, tpu_partitioned_input_node,
1846 core_index);
1847 }
1848
1849 // Insert TPUReplicatedInput op.
1850 NodeDefBuilder replicated_builder(
1851 absl::StrCat(handle.name(), "/tpu_replicated_input"),
1852 "TPUReplicatedInput");
1853 replicated_builder.Attr("N", 1);
1854 replicated_builder.Attr("T", DT_RESOURCE);
1855 replicated_builder.Attr("is_mirrored_variable", true);
1856 std::vector<NodeDefBuilder::NodeOut> replicated_inputs;
1857 replicated_inputs.push_back(
1858 {tpu_partitioned_input_node->name(), 0, DT_RESOURCE});
1859 replicated_builder.Input(replicated_inputs);
1860 NodeDef replicated_node_def;
1861 TF_RETURN_IF_ERROR(replicated_builder.Finalize(&replicated_node_def));
1862 Status replicated_s;
1863 Node* tpu_replicated_input_node =
1864 graph->AddNode(replicated_node_def, &replicated_s);
1865 if (!replicated_s.ok()) {
1866 return replicated_s;
1867 }
1868 graph->AddEdge(tpu_partitioned_input_node, 0, tpu_replicated_input_node, 0);
1869
1870 // Connect the TPUReplicatedInput node to the previous output nodes of the
1871 // variable, and remove the variable node.
1872 std::vector<Node*> dst_nodes;
1873 std::vector<int> src_indices;
1874 std::vector<int> dst_indices;
1875 for (const Edge* edge : variable->out_edges()) {
1876 dst_nodes.push_back(edge->dst());
1877 src_indices.push_back(edge->src_output());
1878 dst_indices.push_back(edge->dst_input());
1879 }
1880 for (int i = 0; i < dst_nodes.size(); i++) {
1881 graph->AddEdge(tpu_replicated_input_node, src_indices[i], dst_nodes[i],
1882 dst_indices[i]);
1883 }
1884
1885 graph->RemoveNode(variable);
1886
1887 std::vector<NodeDef> ndefs;
1888 Status status;
1889 for (int core_index = 0; core_index < num_cores_per_replica; core_index++) {
1890 Device* d;
1891 TF_RETURN_IF_ERROR(library_runtime_->device_mgr()->LookupDevice(
1892 strings::StrCat(kTPUDeviceNamePrefix, device_ordinal + core_index),
1893 &d));
1894 string sname;
1895 const NodeDef& ndef = per_core_vars[core_index]->def();
1896 TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &sname));
1897 ndefs.push_back(ndef);
1898 Var* tpu_var;
1899 status = d->resource_manager()->Lookup(cname, sname, &tpu_var);
1900 }
1901
1902 if (!status.ok()) {
1903 TF_RETURN_IF_ERROR(
1904 InitializeShardedVarOnTPU(ctx, var, ndefs, split_dim, device_ordinal));
1905 }
1906
1907 return Status::OK();
1908 }
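
// Illustrative sketch of the graph produced by
// ReplaceAndPartitionXLAShardingVariable for num_cores_per_replica == 2,
// with "v" standing in for the variable's handle name:
//
//   VarHandleOp("v_tpu_0", TPU:0) --\
//                                    TPUPartitionedInput
//   VarHandleOp("v_tpu_1", TPU:1) --/        |
//                                   TPUReplicatedInput --> original consumers
//
// The original variable node is removed, and the per-core variables are
// initialized with InitializeShardedVarOnTPU if they do not exist yet.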
1909
Status TPUPartitionedCallOp::InferShapesWithResourceVar(
    Graph* graph, OpKernelContext* ctx,
    std::map<int, InferredShape>& arg_shapes,
    GraphShapeInfo* tpu_inferred_info) {
1914 auto shape_inference_graph_interim =
1915 absl::make_unique<Graph>(graph->flib_def());
1916 CopyGraph(*graph, shape_inference_graph_interim.get());
1917
1918 for (Node* node : shape_inference_graph_interim->nodes()) {
1919 if (node->type_string() != "_Arg" ||
1920 node->attrs().Find("T")->type() != DT_RESOURCE)
1921 continue;
1922
1923 std::vector<std::function<void()>> to_remove;
1924
1925 for (const Edge* out_edge : node->out_edges()) {
1926 Node* read_node = out_edge->dst();
1927 if (read_node->type_string() != "ReadVariableOp") continue;
1928
1929 for (const Edge* variable_edge : read_node->out_edges()) {
1930 // We are delaying these modifications as we cannot do in-place
1931 // modification of EdgeSets.
1932 to_remove.push_back(
1933 [variable_edge, graph = shape_inference_graph_interim.get(), node] {
1934 Node* dst = variable_edge->dst();
1935 graph->RemoveEdge(variable_edge);
1936 graph->AddEdge(node, variable_edge->src_output(), dst,
1937 variable_edge->dst_input());
1938 });
1939 }
1940 to_remove.push_back(
1941 [graph = shape_inference_graph_interim.get(), out_edge, read_node] {
1942 graph->RemoveEdge(out_edge);
1943 graph->RemoveNode(read_node);
1944 });
1945 }
1946
1947 for (auto& func : to_remove) {
1948 func();
1949 }
1950
1951 int resource_arg_index = node->attrs().Find("index")->i();
1952
1953 // Get resource variable tensor
1954 core::RefCountPtr<Var> variable;
1955 const ResourceHandle& handle = HandleFromInput(ctx, resource_arg_index);
1956 TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &variable));
1957
1958 const Tensor* variable_tensor = variable->tensor();
1959 std::vector<int> variable_tensor_vec;
1960
1961 variable_tensor_vec.reserve(variable_tensor->dims());
1962 for (int d = 0; d < variable_tensor->dims(); ++d) {
1963 variable_tensor_vec.push_back(variable_tensor->dim_size(d));
1964 }
1965
1966 PartialTensorShape partial_tensor_shape;
1967 auto partial_shape = PartialTensorShape::MakePartialShape(
1968 variable_tensor_vec.data(), variable_tensor_vec.size(),
1969 &partial_tensor_shape);
1970 InferredShape inferred_shape = {partial_tensor_shape};
1971 arg_shapes.emplace(resource_arg_index, inferred_shape);
1972 }
1973
1974 TF_RETURN_IF_ERROR(tensorflow::InferShapes(
1975 shape_inference_graph_interim.get(), arg_shapes,
1976 &shape_inference_graph_interim->flib_def(), tpu_inferred_info));
1977 return Status::OK();
1978 }
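
// Illustrative sketch of the temporary rewiring done above for shape
// inference on each DT_RESOURCE _Arg:
//
//   before:  _Arg(resource) -> ReadVariableOp -> consumer
//   after:   _Arg(resource) -------------------> consumer
//
// The _Arg's shape is seeded from the live variable's tensor so that
// InferShapes can propagate concrete shapes through the consumers.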
1979
Status TPUPartitionedCallOp::ShardInputsWithXlaSharding(
    Graph* graph, int num_cores_per_replica, OpKernelContext* ctx) {
1982 for (Node* replicated_input_node : graph->nodes()) {
1983 if (replicated_input_node->type_string() != "TPUReplicatedInput") continue;
1984
1985 Node* arg_node;
1986 auto input_node_status = replicated_input_node->input_node(0, &arg_node);
1987 if (!input_node_status.ok()) {
1988 VLOG(2) << "Skip because cannot retrieve input node 0 of "
1989 << replicated_input_node->name() << " because "
1990 << input_node_status.ToString();
1991 continue;
1992 }
1993
    // Check whether this TPUReplicatedInput qualifies: it must have an _Arg
    // node as its input and must not already feed an XlaSharding op; if so,
    // try to shard the input automatically.
    //
    // In short, we want to see the following graph:
    //    _Arg -> TPUReplicatedInput -> (not XlaSharding op)
    // and transform it to:
    //    _Arg -> TPUReplicatedInput -> XlaSharding -> (not XlaSharding op)
2002 if (arg_node->IsArg() &&
2003 replicated_input_node->out_nodes().begin()->type_string() !=
2004 "XlaSharding") {
2005 int arg_id;
2006 if (!absl::SimpleAtoi(absl::StripPrefix(arg_node->name(), "arg_"),
2007 &arg_id)) {
2008 VLOG(3) << "Skip auto-sharding because we are unable to extract "
2009 "argument number from "
2010 << arg_node->name();
2011 continue;
2012 }
2013
2014 auto shape = ctx->input(arg_id).shape();
2015
2016 VLOG(3) << "Identified arg node " << arg_node->DebugString()
2017 << " for TPUReplicatedInput "
2018 << replicated_input_node->DebugString();
2019 VLOG(3) << "Shape within TPUReplicatedInput is: " << shape.DebugString();
2020
2021 int rank = shape.dims();
2022 int shard_dim =
2023 (runtime_params_.auto_xla_input_sharding_dim + rank) % rank;
2024
2025 if (shape.dim_size(shard_dim) % num_cores_per_replica != 0) {
2026 VLOG(3) << "Skip auto-sharding " << replicated_input_node->name()
2027 << " because the specified sharding dimension " << shard_dim
2028 << " cannot be evenly split by " << num_cores_per_replica;
2029 continue;
2030 }
2031
2032 auto sharding = absl::make_optional<xla::OpSharding>();
2033 sharding->set_type(xla::OpSharding::OTHER);
2034
2035 // Sets up tile_assignment_dimensions.
2036 std::vector<int64> dims(rank, 1LL);
2037 dims[shard_dim] = num_cores_per_replica;
2038 for (auto dim : dims) {
2039 sharding->add_tile_assignment_dimensions(dim);
2040 }
2041
2042 // Sets up tile_assignment_devices.
2043 for (int d = 0; d < num_cores_per_replica; ++d) {
2044 sharding->add_tile_assignment_devices(d);
2045 }
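      // For example (illustrative): a rank-2 input sharded across
      // num_cores_per_replica == 2 along shard_dim == 0 produces
      //   type: OTHER
      //   tile_assignment_dimensions: [2, 1]
      //   tile_assignment_devices: [0, 1]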
2046
2047 std::vector<const Edge*> edges_to_remove;
2048 for (const Edge* edge : replicated_input_node->out_edges()) {
2049 if (edge->IsControlEdge()) continue;
2050 edges_to_remove.push_back(edge);
2051 }
2052
2053 // Create XlaSharding Op.
2054 Node* sharding_op = nullptr;
2055 TF_RETURN_IF_ERROR(
2056 NodeBuilder(absl::StrCat(replicated_input_node->name(), "/sharding"),
2057 "XlaSharding")
2058 .Input(replicated_input_node, 0)
2059 .Attr("T", replicated_input_node->output_type(0))
2060 .Attr(kXLAShardingAttrName, sharding->SerializeAsString())
2061 .Attr(kXLAShardingAttrAltName, sharding->SerializeAsString())
2062 .Attr("_tpu_replicate", "cluster")
2063 .Finalize(graph, &sharding_op));
2064 for (const Edge* edge : edges_to_remove) {
2065 VLOG(3) << "XlaSharding op creation output edge "
2066 << edge->DebugString();
2067 graph->RemoveEdge(edge);
2068 graph->AddEdge(sharding_op, 0, edge->dst(), edge->dst_input());
2069 }
2070
2071 VLOG(3) << "Auto shard " << replicated_input_node->name() << " by dim "
2072 << shard_dim << " into " << num_cores_per_replica << " slices";
2073
2074 VLOG(3) << "Created XlaSharding Op " << sharding_op->DebugString();
2075 }
2076 }
2077
2078 return Status::OK();
2079 }
2080
// OptimizeTpuInputOutputTensors does the following things:
// (1) Detect input arguments, and add an XlaSharding op to the arguments if
// enable_auto_xla_input_sharding is turned on.
// (2) Pack multiple input tensors into one tensor by a concat to avoid PCIe
// transfer overheads for small tensors.
// (3) Reshape input tensors to R1 to leverage the fast path in TPU input
// preparation done by the runtime.
// (4) Pack multiple output tensors into one tensor by a concat.
//
// (1) is controlled by --enable_auto_xla_input_sharding and
// --auto_xla_input_sharding_dim.
// (2) and (3) are controlled by flags --minimum_input_tensors_packing
// and --input_shape_opt, respectively, while (4) is controlled by
// --minimum_output_tensors_packing.
Status TPUPartitionedCallOp::OptimizeTpuInputOutputTensors(
    Graph* graph, bool enable_spmd_xla_partitioning, int num_cores_per_replica,
    std::map<std::string, std::vector<int>>& named_input_shapes,
    OpKernelContext* ctx) {
2099 if (runtime_params_.enable_auto_xla_input_sharding) {
2100 VLOG(2) << DumpGraphToFile("before_enable_auto_xla_input_sharding", *graph,
2101 flib_def_.get());
2102
2103 TF_RETURN_IF_ERROR(
2104 ShardInputsWithXlaSharding(graph, num_cores_per_replica, ctx));
2105 }
2106
2107 GraphShapeInfo tpu_inferred_info;
2108 std::map<int, InferredShape> arg_shapes;
2109 EdgeShapes tpu_input_shapes;
2110 absl::flat_hash_map<const Edge*, DataType> tpu_input_dtypes;
2111
2112 // Contains attrs "T", "sharding", "_tpu_replicate" for each XlaSharding op.
2113 XlaShardingInfoMap xla_sharding_ops;
2114
2115 // Contains attrs "T", and a pointer to tpu_replicated_metadata for ctrl dep
2116 TpuReplicatedInputInfoMap tpu_replicated_input_ops;
2117
2118 bool xla_spmd_input_sharded = false;
2119
2120 if (enable_spmd_xla_partitioning) {
2121 xla_spmd_input_sharded = FindTpuReplicatedInputAndXlaSharding(
2122 graph, xla_sharding_ops, tpu_replicated_input_ops);
2123 }
2124
2125 VLOG(1) << "xla_spmd_input_sharded: " << xla_spmd_input_sharded;
2126 VLOG(2) << DumpGraphToFile("before_remove_descendant_nodes", *graph,
2127 flib_def_.get());
2128
2129 if (!xla_spmd_input_sharded ||
2130 runtime_params_.minimum_input_tensors_packing > 1 ||
2131 runtime_params_.enable_auto_xla_input_sharding) {
    // Currently we remove `TPUReplicatedInput` nodes when the input tensors
    // are not sharded, when input tensor packing is enabled, or when auto XLA
    // input sharding is enabled.
    //
    // In all these cases, we want to remove both the TPUReplicatedInput and
    // XlaSharding ops, or else downstream rewrites will be confused.
2138 RemoveDescendantNodeOfArg(graph, "TPUReplicatedInput",
2139 /*must_be_child_of=*/{});
2140 }
2141
2142 if (xla_spmd_input_sharded) {
    // We are setting must_be_child_of to {"_Arg"} because we do not want
    // to remove other XlaSharding ops that might be in the graph. We only
    // want the XlaSharding ops that are directly attached to the input
    // arguments to be removed.
2147 RemoveDescendantNodeOfArg(graph, "XlaSharding",
2148 /*must_be_child_of=*/{"_Arg"});
2149 }
2150
2151 VLOG(2) << DumpGraphToFile("before_get_input_output_info", *graph,
2152 flib_def_.get());
2153
2154 TF_RETURN_IF_ERROR(GetInputOutputInfo(graph, tpu_inferred_info, arg_shapes,
2155 tpu_input_shapes, tpu_input_dtypes,
2156 ctx));
2157
2158 VLOG(2) << DumpGraphToFile("before_optimize_tpu_input_output_tensors", *graph,
2159 flib_def_.get());
2160
2161 string cluster_name;
2162 TF_RETURN_IF_ERROR(GetClusterName(graph, &cluster_name));
2163
2164 if (runtime_params_.minimum_output_tensors_packing > 1) {
2165 // Copy graph to shape_inference_graph
2166 EdgeShapes tpu_output_shapes;
2167 TF_RETURN_IF_ERROR(
2168 InferShapesWithResourceVar(graph, ctx, arg_shapes, &tpu_inferred_info));
2169
2170 // Find TPU -> CPU output edges.
2171 GroupedEdges shape_to_output =
2172 tpu_functional_internal::GroupTensorsForOutputPacking(
2173 graph, tpu_output_shapes, &tpu_inferred_info);
2174
2175 TF_RETURN_IF_ERROR(
2176 tpu_functional_internal::CreateConcatAndSplitNodesForOutputTensor(
2177 graph, cluster_name, &tpu_output_shapes, &tpu_inferred_info,
2178 shape_to_output, runtime_params_.minimum_output_tensors_packing));
2179 }
2180
2181 if (runtime_params_.minimum_input_tensors_packing > 1) {
2182 GroupedEdges grouped_input_edges =
2183 tpu_functional_internal::GroupTensorsForInputPacking(
2184 tpu_input_shapes, tpu_input_dtypes, runtime_params_.input_shape_opt,
2185 runtime_params_.group_tensors_for_packing);
2186 TF_RETURN_IF_ERROR(
2187 tpu_functional_internal::CreateConcatAndSplitNodesForInputTensor(
2188 graph, cluster_name, &tpu_input_shapes, grouped_input_edges,
2189 runtime_params_.minimum_input_tensors_packing,
2190 xla_spmd_input_sharded, xla_sharding_ops,
2191 tpu_replicated_input_ops));
2192 }
2193 if (runtime_params_.input_shape_opt) {
2194 TF_RETURN_IF_ERROR(tpu_functional_internal::InsertReshapeNodePairs(
2195 graph, cluster_name, &tpu_input_shapes, num_cores_per_replica));
2196 }
2197 VLOG(1) << DumpGraphToFile("optim_result", *graph);
2198
2199 // With or without optimizations, collect the input names and shapes.
2200 for (const auto& iter : tpu_input_shapes) {
2201 std::string name = iter.first->src()->name();
2202 named_input_shapes[name] = iter.second;
2203 }
2204 return Status::OK();
2205 }
2206
Status TPUPartitionedCallOp::GetGraphFromFunction(
    Graph* graph, int device_ordinal, int* num_core_per_replica,
    bool* use_spmd_for_xla_partitioning) {
2210 FunctionLibraryRuntime::InstantiateOptions opts;
2211 FHandle handle;
2212 TF_RETURN_IF_ERROR(library_runtime_->Instantiate(
2213 func_.name(), AttrSlice(&func_.attr()), opts, &handle));
2214 const FunctionBody* fbody = library_runtime_->GetFunctionBody(handle);
2215 if (fbody == nullptr) {
2216 return errors::Internal("Could not find handle ", handle);
2217 }
2218 CopyGraph(*fbody->graph, graph);
2219
2220 // Pin the inputs and outputs to the local device to simplify the
2221 // function-dispatching logic.
2222 local_device_name_ = library_runtime_->device()->name();
2223 replaced_input_indices_.resize(fbody->arg_nodes.size(), false);
2224 for (Node* node : graph->op_nodes()) {
2225 if (node->IsArg() || node->IsRetval()) {
2226 node->set_assigned_device_name(local_device_name_);
2227 } else if (node->type_string() == "TPUReplicateMetadata") {
2228 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "num_cores_per_replica",
2229 num_core_per_replica));
2230 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(),
2231 "use_spmd_for_xla_partitioning",
2232 use_spmd_for_xla_partitioning));
2233 VLOG(1) << "num_core_per_replica = " << *num_core_per_replica
2234 << ", use_spmd_for_xla_partitioning = "
2235 << *use_spmd_for_xla_partitioning;
2236
2237 if (*num_core_per_replica > 1) {
2238 std::string topology_str;
2239 std::vector<int> device_assignment;
2240 TF_RETURN_IF_ERROR(
2241 GetNodeAttr(node->attrs(), "topology", &topology_str));
2242 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "device_assignment",
2243 &device_assignment));
2244
2245 tpu::TopologyProto topology;
2246 topology.ParseFromString(topology_str);
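        // Each core in the topology is described by four coordinates
        // (x, y, z, core index), so dividing the flattened coordinate list
        // by four gives the total number of cores.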
2247 int num_cores = topology.device_coordinates_size() / 4;
2248
2249 if (device_assignment.empty()) {
          // The number of devices matches the cores per replica, so we can
          // just use the device assignment from the existing topology instead
          // of generating our own.
          //
          // TODO(b/179292031): Add support for non-natural orders for pods.

          // Check that the device coordinates for a donut are always in
          // natural order.
2258 std::vector<int> natural_order;
2259 switch (num_cores) {
2260 case 2:
2261 TF_RETURN_IF_ERROR(GenerateDeviceNaturalOrder(
2262 /*x_num_cores=*/1, /*y_num_cores=*/1, /*z_num_cores=*/1,
2263 /*num_cores_per_chip=*/2, &natural_order));
2264 break;
2265 case 4: // we assume this is a device with one core per chip.
2266 TF_RETURN_IF_ERROR(GenerateDeviceNaturalOrder(
2267 /*x_num_cores=*/2, /*y_num_cores=*/2, /*z_num_cores=*/1,
2268 /*num_cores_per_chip=*/1, &natural_order));
2269 break;
2270 case 8:
2271 TF_RETURN_IF_ERROR(GenerateDeviceNaturalOrder(
2272 /*x_num_cores=*/2, /*y_num_cores=*/2, /*z_num_cores=*/1,
2273 /*num_cores_per_chip=*/2, &natural_order));
2274 break;
2275 default:
2276 return errors::Unimplemented(
2277 "You must specify a device assignment for all TPU "
2278 "configurations.");
2279 }
2280 if (*num_core_per_replica != num_cores &&
2281 !std::equal(natural_order.begin(), natural_order.end(),
2282 topology.device_coordinates().begin())) {
2283 return errors::InvalidArgument(
2284 "Topology device coordinates for XLA SPMD on donuts must be in "
2285 "natural order.");
2286 }
2287
2288 auto coordinates_start =
2289 topology.device_coordinates().begin() + device_ordinal * 4;
2290 auto coordinates_end = topology.device_coordinates().begin() +
2291 (device_ordinal + *num_core_per_replica) * 4;
2292
2293 node->ClearAttr("device_assignment");
2294 device_assignment.insert(device_assignment.begin(), coordinates_start,
2295 coordinates_end);
2296 node->AddAttr("device_assignment", device_assignment);
2297 }
2298 }
2299 }
2300 }
2301 return Status::OK();
2302 }
2303
Status TPUPartitionedCallOp::PlacementHelper(
    const DeviceSet& device_set,
    const GraphOptimizationPassOptions& optimization_options,
    const string& function_name) {
2308 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2309 OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
2310 Placer placer(optimization_options.graph->get(), function_name,
2311 optimization_options.flib_def, &device_set);
2312 TF_RETURN_IF_ERROR(placer.Run());
2313 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2314 OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
2315 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2316 OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
2317 return Status::OK();
2318 }
2319
Status TPUPartitionedCallOp::PartitionHelper(
    const DeviceSet& device_set,
    const GraphOptimizationPassOptions& optimization_options, Graph* graph,
    std::unordered_map<std::string, std::unique_ptr<Graph>>* subgraphs) {
2324 PartitionOptions partition_options;
2325 partition_options.node_to_loc = [](const Node* node) {
    // TODO(akshayka): To better support the distributed case, first split
    // the graph by worker (e.g., using the master session's
    // `SplitByWorker` policy), and then recursively partition the
    // per-worker shards at the remote worker(s).
2330 return node->assigned_device_name();
2331 };
2332 int64_t edge_name_counter = 0;
2333 partition_options.new_name = [&edge_name_counter](const string& prefix) {
2334 return strings::StrCat(prefix, "/_", ++edge_name_counter);
2335 };
2336 partition_options.get_incarnation = [&device_set](const string& name) {
2337 const Device* d = device_set.FindDeviceByName(name);
2338 if (d == nullptr) {
2339 return PartitionOptions::kIllegalIncarnation;
2340 } else {
2341 return d->attributes().incarnation();
2342 }
2343 };
2344 partition_options.control_flow_added = false;
2345 std::unordered_map<std::string, GraphDef> partitions;
2346 TF_RETURN_IF_ERROR(Partition(partition_options, graph, &partitions));
2347
2348 VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
2349 << partitions.size() << " shards.";
2350
2351 const FunctionLibraryDefinition* flib_def = &graph->flib_def();
2352 for (auto& partition : partitions) {
2353 std::unique_ptr<Graph> subgraph(new Graph(flib_def));
2354 GraphConstructorOptions opts;
2355 opts.allow_internal_ops = true;
2356 opts.expect_device_spec = true;
2357 const string& device = partition.first;
2358 GraphDef& graph_def = partition.second;
2359 TF_RETURN_IF_ERROR(
2360 ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
2361 subgraphs->emplace(device, std::move(subgraph));
2362 }
2363
2364 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
2365 OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
2366
2367 return Status::OK();
2368 }
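
// Illustrative sketch: after PartitionHelper, `subgraphs` maps each assigned
// device name to its shard, e.g. (device names are only examples)
//
//   "/job:worker/replica:0/task:0/device:CPU:0" -> Graph with the CPU ops
//   "/job:worker/replica:0/task:0/device:TPU:0" -> Graph with the TPU ops
//
// Each shard is then converted to a FunctionDef and instantiated by
// InstantiatePartition.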
2369
Status TPUPartitionedCallOp::InstantiatePartition(
    const Graph& graph, const string& function_name,
    const string& target_device, FHandle* handle,
    std::unique_ptr<FunctionLibraryDefinition>* out_flib_def) {
2374 FunctionDef shard;
2375 TF_RETURN_IF_ERROR(GraphToFunctionDef(graph, function_name, &shard));
2376 TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(shard));
2377 FunctionLibraryRuntime::InstantiateOptions opts;
2378 opts.target = target_device;
2379 if (out_flib_def) {
2380 *out_flib_def = std::make_unique<FunctionLibraryDefinition>(*flib_def_);
2381 opts.lib_def = out_flib_def->get();
2382 } else {
2383 opts.lib_def = flib_def_.get();
2384 }
2385 return library_runtime_->Instantiate(function_name, AttrSlice(&shard.attr()),
2386 opts, handle);
2387 }
2388
Status TPUPartitionedCallOp::SetDeviceOrdinal(const DeviceSet& device_set,
                                              int device_ordinal, Graph* graph,
                                              bool* modified) {
2392 int ordinal = -1;
2393 for (Node* node : graph->op_nodes()) {
2394 if (node->type_string() == kVarHandleOp) {
2395 if (IsInputToTPUReplicate(node)) {
2396 // If this VarHandleOp is going to a TPU computation,
2397 // it refers to the TPU variable that we created when replacing the
2398 // resource arguments with VarHandleOps.
2399 node->set_assigned_device_name(
2400 strings::StrCat(kTPUDeviceNamePrefix, device_ordinal));
2401 }
2402 continue;
2403 }
2404 if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
2405 // Outside compilation related node.
2406 TF_RETURN_IF_ERROR(
2407 SetDeviceOrdinalAttributeForNode(node, device_ordinal));
2408 *modified = true;
2409 continue;
2410 }
2411 const AttrValue* attr = node->attrs().Find(kDeviceOrdinalAttr);
2412 if (attr != nullptr) {
2413 if (!IsSupportedTPUOp(node->type_string())) {
2414 return errors::InvalidArgument("Node ", node->type_string(),
2415 " is not yet supported.");
2416 }
2417 if (ordinal == -1) {
2418 ordinal = attr->i();
2419 } else {
2420 if (ordinal != attr->i()) {
2421 return errors::InvalidArgument(
2422 "Can only partition graphs that use a single device ordinal.");
2423 }
2424 }
2425 node->ClearAttr(kDeviceOrdinalAttr);
2426 node->AddAttr(kDeviceOrdinalAttr, device_ordinal);
2427 VLOG(3) << "Set device ordinal of " << node->type_string() << " to "
2428 << device_ordinal;
2429 *modified = true;
2430 }
2431 if (node->IsSend() || node->IsRecv()) {
2432 static const char* kSendDevice = "send_device";
2433 static const char* kSendDeviceIncarnation = "send_device_incarnation";
2434 static const char* kRecvDevice = "recv_device";
2435 const AttrValue* attr = node->attrs().Find(kSendDevice);
2436 if (attr != nullptr) {
2437 string device = attr->s();
2438 TF_RETURN_IF_ERROR(
2439 UpdateTPUDeviceOrdinal(device_ordinal, &device, modified));
2440 node->ClearAttr(kSendDevice);
2441 node->AddAttr(kSendDevice, device);
2442 node->ClearAttr(kSendDeviceIncarnation);
2443 const Device* d = device_set.FindDeviceByName(device);
2444 int64_t send_incarnation = (d == nullptr)
2445 ? PartitionOptions::kIllegalIncarnation
2446 : d->attributes().incarnation();
2447 node->AddAttr(kSendDeviceIncarnation, send_incarnation);
2448 }
2449 attr = node->attrs().Find(kRecvDevice);
2450 if (attr != nullptr) {
2451 string device = attr->s();
2452 TF_RETURN_IF_ERROR(
2453 UpdateTPUDeviceOrdinal(device_ordinal, &device, modified));
2454 node->ClearAttr(kRecvDevice);
2455 node->AddAttr(kRecvDevice, device);
2456 }
2457 }
2458 }
2459 return Status::OK();
2460 }
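
// Illustrative example (the device strings below are assumptions, not taken
// from the code above): with device_ordinal == 3, a Send/Recv attribute such
// as
//
//   send_device: "/job:tpu_worker/replica:0/task:0/device:TPU:0"
//
// is rewritten by UpdateTPUDeviceOrdinal to
//
//   send_device: "/job:tpu_worker/replica:0/task:0/device:TPU:3"
//
// and send_device_incarnation is refreshed from the DeviceSet.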
2461
Status TPUPartitionedCallOp::InstantiateFunctionsFromSubgraphs(
    const DeviceSet& device_set, int replica_id, uint64 cache_hash,
    int num_cores_per_replica,
    std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs) {
2466 const Device* reference_device = nullptr;
2467 auto entry =
2468 partition_cache_.emplace(cache_hash, std::vector<DeviceAndFHandle>());
2469
2470 bool rewritten = false;
2471 for (auto& pair : subgraphs) {
2472 string target = pair.first;
2473 int device_ordinal = replica_id;
2474 if (num_cores_per_replica > 1) {
2475 DeviceNameUtils::ParsedName parsed_device;
2476 if (!DeviceNameUtils::ParseFullName(target, &parsed_device)) {
2477 return errors::InvalidArgument("Malformed assigned device '", target,
2478 "'");
2479 }
2480 device_ordinal = parsed_device.id;
2481 }
2482 Device* device;
2483 TF_RETURN_IF_ERROR(
2484 library_runtime_->device_mgr()->LookupDevice(target, &device));
2485 if (reference_device == nullptr) {
2486 reference_device = device;
2487 } else {
2488 if (!DeviceNameUtils::IsSameAddressSpace(
2489 device->parsed_name(), reference_device->parsed_name())) {
        return errors::InvalidArgument(
            "TPUPartitionedCallOp does not yet support inter-process "
            "execution.");
2493 }
2494 }
2495 TF_RETURN_IF_ERROR(device->MaybeRewriteGraph(&pair.second));
2496 Graph* subgraph = pair.second.get();
    // For model parallelism inference, we only support num_replica == 1, so
    // there is no need to update the device_ordinal.
2499 if (num_cores_per_replica == 1) {
2500 TF_RETURN_IF_ERROR(
2501 SetDeviceOrdinal(device_set, device_ordinal, subgraph, &rewritten));
2502 } else {
2503 VLOG(1) << "Skip SetDeviceOrdinal()";
2504 }
2505 string function_name = flib_def_->UniqueFunctionName(
2506 strings::StrCat(func_.name(), "_hash_", cache_hash));
2507 TF_RETURN_IF_ERROR(
2508 UpdateTPUDeviceOrdinal(device_ordinal, &target, &rewritten));
2509 FHandle handle;
2510 // Use a copy of the current `flib_def_` to instantiate the function to
2511 // avoid races.
2512 std::unique_ptr<FunctionLibraryDefinition> sub_flib_def;
2513 TF_RETURN_IF_ERROR(InstantiatePartition(*subgraph, function_name, target,
2514 &handle, &sub_flib_def));
2515 // Add handle to the cache entry.
2516 entry.first->second.push_back(
2517 DeviceAndFHandle{.device = target,
2518 .handle = handle,
2519 .flib_def = std::move(sub_flib_def)});
2520 }
2521
2522 if (!rewritten) {
    // For regular use cases, TPUPartitionedCallOp only works when the
    // function being called is rewritten for TPU. If we don't see any signs
    // of this rewriting, warn the user about it.
    // We don't raise an error because we want to support the use case of
    // running tpu.initialize_system eagerly. In this case, we can't use
    // tpu.rewrite because it will add compilation ops that require TPU
    // to be initialized, i.e., there is a chicken-and-egg problem.
    // We run tpu.initialize_system through TPUPartitionedCallOp because it
    // invokes graph rewrite passes that are necessary for initialization to
    // work.
2533 LOG(INFO) << "Function body was not rewritten for TPU. "
2534 << "This is probably a bug unless you are initializing "
2535 << "TPUs eagerly.";
2536 }
2537 return Status::OK();
2538 }
2539
void TPUPartitionedCallOp::ExecuteRemoteFunction(
    const FunctionLibraryRuntime::Options& opts, FHandle handle,
    OpKernelContext* ctx, ReffedStatusCallback* done) {
2543 std::vector<Tensor> dummy_args;
2544 std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
2545
2546 profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteRemote");
2547 absl::ReaderMutexLock l(&mu_);
2548 library_runtime_->Run(opts, handle, dummy_args, dummy_rets,
2549 [dummy_rets, done, ctx](const Status& status) {
2550 if (!status.ok()) {
2551 ctx->SetStatus(status);
2552 }
2553 delete dummy_rets;
2554 done->Unref();
2555 });
2556 }
2557
void TPUPartitionedCallOp::ExecuteLocalFunction(
    const FunctionLibraryRuntime::Options& opts, const OpInputList& arguments,
    FHandle handle, OpKernelContext* ctx, ReffedStatusCallback* done) {
2561 std::vector<Tensor> args;
2562
2563 for (int i = 0; i < arguments.size(); ++i) {
2564 if (!replaced_input_indices_[i]) {
2565 // _Arg nodes of type DT_RESOURCE that go into a TPU node have been
2566 // replaced by TPU VarHandleOp nodes. No longer need to pass them as
2567 // inputs.
2568 args.push_back(arguments[i]);
2569 }
2570 }
2571 auto* rets = new std::vector<Tensor>;
2572
2573 profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteLocal");
2574 absl::ReaderMutexLock l(&mu_);
2575 library_runtime_->Run(opts, handle, args, rets,
2576 [rets, done, ctx](const Status& status) {
2577 if (!status.ok()) {
2578 ctx->SetStatus(status);
2579 } else {
2580 for (int i = 0; i < rets->size(); ++i) {
2581 ctx->set_output(i, (*rets)[i]);
2582 }
2583 }
2584 delete rets;
2585 done->Unref();
2586 });
2587 }
2588
void TPUPartitionedCallOp::ExecuteFunctions(
    const std::vector<DeviceAndFHandle>& functions, OpKernelContext* ctx,
    int device_ordinal, int64_t ordinal_selector_req_id, DoneCallback done) {
2592 profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteFunctions");
2593 FunctionLibraryRuntime::Options opts;
2594 opts.step_container = ctx->step_container();
2595 opts.cancellation_manager = ctx->cancellation_manager();
2596 opts.stats_collector = ctx->stats_collector();
2597 // TODO(akshayka): Consider selecting a runner on a per-device basis,
2598 // i.e., using device-specific threadpools when available.
2599 opts.runner = ctx->runner();
2600 opts.source_device = local_device_name_;
2601 opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
2602
2603 OpInputList arguments;
2604 OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
2605
2606 auto* rendez = new PrivateIntraProcessRendezvous(device_mgr_);
2607 opts.rendezvous = rendez;
2608
2609 StatusCallback callback(
2610 [rendez = rendez, done = std::move(done), device_ordinal = device_ordinal,
2611 req_id = ordinal_selector_req_id,
2612 ordinal_selector = ordinal_selector_](const Status& status) {
2613 delete rendez;
2614 done();
2615 if (req_id >= 0) {
2616 ordinal_selector->DequeueFromCoreSelector(device_ordinal, req_id);
2617 }
2618 });
2619
2620 auto* refcounted_done = new ReffedStatusCallback(std::move(callback));
2621 for (int i = 1; i < functions.size(); ++i) {
2622 refcounted_done->Ref();
2623 }
2624 for (const DeviceAndFHandle& entry : functions) {
2625 const string& target_device = entry.device;
2626 FHandle handle = entry.handle;
2627 VLOG(3) << "Running function shard on device " << target_device
2628 << " with local device name " << local_device_name_;
2629 if (target_device == local_device_name_) {
2630 opts.remote_execution = false;
2631 ExecuteLocalFunction(opts, arguments, handle, ctx, refcounted_done);
2632 } else {
2633 opts.remote_execution = true;
2634 ExecuteRemoteFunction(opts, handle, ctx, refcounted_done);
2635 }
2636 }
2637 }
2638
2639 REGISTER_KERNEL_BUILDER(Name("TPUPartitionedCall").Device(DEVICE_CPU),
2640 TPUPartitionedCallOp);
2641
2642 } // end namespace tensorflow
2643