/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Rewrites TPUReplicate nodes into replicated computations on TPU.
//
// To represent a distributed TPU computation, we use the TPUReplicate
// operator, which describes a subgraph (represented as a TensorFlow function)
// to replicate across a TPU pod.
//
// Model parallelism and data parallelism:
// ---------------------------------------
// We support two different kinds of parallelism on TPU:
// * data parallelism (replication), or parallelization across batches, and
// * model parallelism, or parallelization within a batch.
//
// The function passed to a TPUReplicate operator is replicated many times
// across a TPU pod (data parallelism). The `num_replicas` attribute controls
// how many replicas of the computation to create. Replicas are mostly
// independent; replicas can only communicate using the CrossReplicaSum
// operator, which is typically used to communicate gradients during training.
//
// Each replica may optionally use more than one TPU core (model parallelism).
// The `num_cores_per_replica` attribute controls how many cores there are per
// replica. For each core, there is a virtual TPU_REPLICATED_CORE device that
// is only valid within replicated TPU computations (e.g.,
// TPU_REPLICATED_CORE:0, TPU_REPLICATED_CORE:1, etc.); each
// TPU_REPLICATED_CORE device corresponds to one TPU core in every replica.
// Each replica runs its own copy of the computation assigned to each
// TPU_REPLICATED_CORE device.
//
// The Python code is responsible for providing a device_assignment that
// describes how the replicated logical cores map to physical cores on the TPU
// topology.
//
// Inputs to TPUReplicate:
// -----------------------
// The TPUReplicate operator takes four kinds of inputs, in the following
// order:
// * per-replica inputs. If there are three per-replica inputs (A, B, C) and
//   two replicas, the first six arguments to TPUReplicate will be:
//     A0 B0 C0 A1 B1 C1
//   where Ai is the A input to the i-th replica.
// * distributed inputs. These inputs follow the per-replica inputs. If there
//   are two distributed inputs (E, F) and two replicas, the next arguments to
//   TPUReplicate will be: E F. Each replica receives its own local copy of E
//   and F.
// * broadcast inputs. These inputs follow the distributed inputs. All
//   replicas receive a copy of each of these inputs.
// * variables. Resource variables accessed by the computation follow the
//   broadcast inputs.
//
// For example, for a computation with two replicas, three per-replica inputs
// (A, B, C), two distributed inputs (E, F), two broadcast inputs (X, Y), and
// two variables (V, W), the arguments to TPUReplicate will be:
//   A0 B0 C0 A1 B1 C1 E F X Y V W
// and each replica will receive the following arguments:
//   A B C E F X Y V W
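//
// A minimal sketch of the flattened host-side indexing this implies (the
// helper below is hypothetical, shown for illustration only; the pass itself
// uses ParameterInfo for this arithmetic): with R replicas and P per-replica
// inputs, per-replica input i of replica r sits at
//
//   int64 PerReplicaArgIndex(int64 r, int64 P, int64 i) { return r * P + i; }
//
// so in the example above (R = 2, P = 3), A1 is argument 1 * 3 + 0 = 3 and
// the distributed inputs E, F begin at R * P = 6.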
//
// Distributed TPU compilation requires that the shapes of all operators
// be known statically at compilation time, before any nodes have executed.
// Shapes are determined using shape information emitted by InferShapes. It
// is not possible to replicate TensorFlow operators with unknown or dynamic
// shapes for TPU at present.
//
// Graph rewrite:
// --------------
// Compilation replaces TPUReplicate operators with:
// * a single TPUCompile node that compiles the computations,
// * one TPUExecute node for each TPU device in the system that
//   executes the relevant computation,
// * one ReadVariableOp for each variable accessed by the replicated
//   computation,
// * one AssignVariableOp for each variable accessed by the replicated
//   computation. An assignment is built even if a variable is only read by
//   the computation. We do not know which variables are written until we
//   apply the XlaCompiler to the computation, but that does not happen until
//   after the rewrite. Conservatively, we write back the values of all
//   variables after the computation completes.
//   TODO(phawkins): only write back variables that the computation may write.
// * one Shape node for each Tensor or Variable input to the computation whose
//   shape is not statically known at rewrite time. The input shapes are fed
//   to the TPUCompile node.
//
// To ensure that the reads and writes seem to happen at the right time in the
// graph execution, we add control edges from all predecessors of the original
// TPUReplicate operator to each of the ReadVariableOp operators. Similarly,
// we add control edges from all of the AssignVariableOp operators to all of
// the successors of the TPUReplicate operator.
//
// The TPUReplicate rewrite must run before placement, since resource variable
// inputs will have DT_RESOURCE, which cannot be sent across devices, leading
// to objections from the placer. The rewrite turns the resource accesses into
// explicit ReadVariableOp and AssignVariableOp operators that the placer is
// free to colocate with the variables.
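//
// Schematically, the rewritten subgraph has the following shape (an
// illustrative sketch of the structure described above, not literal pass
// output; --> denotes the edges added by the rewrite):
//
//   predecessors --> ReadVariableOp(s) --> TPUExecute(s) --> AssignVariableOp(s) --> successors
//                                               ^
//   Shape node(s) --> TPUCompile ---------------+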

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/tpu/tpu_topology.h"

namespace tensorflow {

// Replaces clusters assigned to TPU_SYSTEM devices with TPUCompile and
// TPUExecute nodes assigned to the corresponding TPU devices.
class DistributedTPURewritePass : public GraphOptimizationPass {
 public:
  static void SetDistributedTpuRewritePassOptions(
      bool distribute_vars, bool allow_xla_spmd_partition,
      bool replicate_inputs_outputs_by_default_for_xla_spmd,
      bool enable_cross_replica_sharding_mirrored_variables,
      bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast);

  Status Run(const GraphOptimizationPassOptions& options) override;

  // The following methods are public only for the use of unit tests.

  // See the comment at the top of the file for how the inputs are ordered.
  // Encapsulates the different TPU replicated node input and output
  // information, and provides common APIs over them.
  class ParameterInfo {
   public:
    ParameterInfo() {}
    ParameterInfo(int64 num_replicas, int64 num_per_replica_args,
                  int64 num_distributed_args, int64 num_broadcast_args,
                  int64 num_variables, int64 num_guaranteed_constants,
                  int64 num_retvals_per_replica)
        : num_replicas_(num_replicas),
          num_per_replica_args_(num_per_replica_args),
          num_distributed_args_(num_distributed_args),
          num_broadcast_args_(num_broadcast_args),
          num_variables_(num_variables),
          num_guaranteed_constants_(num_guaranteed_constants),
          num_retvals_per_replica_(num_retvals_per_replica) {}

    int64 NumReplicas() const { return num_replicas_; }

    int64 NumPerReplicaArgs() const { return num_per_replica_args_; }

    int64 NumDistributedArgs() const { return num_distributed_args_; }

    int64 NumBroadcastArgs() const { return num_broadcast_args_; }

    int64 NumVariables() const { return num_variables_; }

    int64 NumGuaranteedConstants() const { return num_guaranteed_constants_; }

    int64 NumRetvalsPerReplica() const { return num_retvals_per_replica_; }

    bool IsPerReplicaArg(int64 index) const {
      return index < num_per_replica_args_;
    }

    bool IsDistributedArg(int64 index) const {
      return index >= num_per_replica_args_ &&
             index < (num_per_replica_args_ + num_distributed_args_);
    }

    bool IsBroadcastArg(int64 index) const {
      return index >= (num_per_replica_args_ + num_distributed_args_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_);
    }

    bool IsVariableArg(int64 index) const {
      return index >= (num_per_replica_args_ + num_distributed_args_ +
                       num_broadcast_args_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_ + num_variables_);
    }

    bool IsConstantArg(int64 index) const {
      return index >= (num_per_replica_args_ + num_distributed_args_ +
                       num_broadcast_args_ + num_variables_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_ + num_variables_ +
                      num_guaranteed_constants_);
    }
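
    // For example (an illustrative sketch; the retval count is hypothetical),
    // the two-replica example from the header comment corresponds to
    //
    //   ParameterInfo(/*num_replicas=*/2, /*num_per_replica_args=*/3,
    //                 /*num_distributed_args=*/2, /*num_broadcast_args=*/2,
    //                 /*num_variables=*/2, /*num_guaranteed_constants=*/0,
    //                 /*num_retvals_per_replica=*/1);
    //
    // Per-replica argument indices are then 0..2, distributed 3..4, broadcast
    // 5..6, and variables 7..8, so IsBroadcastArg(5) and IsVariableArg(7)
    // both return true, and NumInputsFromHost() below is
    // 2 * 3 + 2 + 2 + 2 + 0 = 12, matching A0 B0 C0 A1 B1 C1 E F X Y V W.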

    // Returns the number of inputs which have been received by the host.
    int64 NumInputsFromHost() const {
      return num_replicas_ * num_per_replica_args_ + num_distributed_args_ +
             num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
    }

    // Returns the number of inputs which will be sent to each replica.
    int64 NumInputsToEachReplica() const {
      return num_per_replica_args_ + num_distributed_args_ +
             num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
    }

    // Returns the total number of output values returned to the host (for
    // all replicas).
    int64 NumOutputsToHost() const {
      return num_replicas_ * num_retvals_per_replica_;
    }

    // Returns the position of the first broadcast argument within the set of
    // all host arguments; broadcast arguments follow the distributed
    // arguments.
    int64 FirstBroadcastArgFromHost() const {
      return num_replicas_ * num_per_replica_args_ + num_distributed_args_;
    }

    // Indices of mirrored variables across replicas, which should be
    // categorized as per_replica_args.
    const std::set<int64>& mirrored_variable_indices() const {
      return mirrored_variable_indices_;
    }
    std::set<int64>* mutable_mirrored_variable_indices() {
      return &mirrored_variable_indices_;
    }

   private:
    int64 num_replicas_ = 1;
    int64 num_per_replica_args_ = 0;
    int64 num_distributed_args_ = 0;
    int64 num_broadcast_args_ = 0;
    int64 num_variables_ = 0;
    int64 num_guaranteed_constants_ = 0;
    int64 num_retvals_per_replica_ = 0;
    std::set<int64> mirrored_variable_indices_;
  };

  // Mapping from TPUReplicate cluster name to TPU device names. The value is
  // a mapping from [replica][core] to a TF device name.
  typedef absl::flat_hash_map<string, std::vector<std::vector<string>>>
      TPUReplicateDeviceNamesMapping;

  // Determines which devices to use to run the computation.
  // Inputs:
  // * num_tpus_per_task: the number of TPU devices attached to each task
  // * tpu_devices: a [task][device] collection of TPU devices
  // * num_replicas: the number of replicas requested
  // * num_cores_per_replica: the number of cores in each computation instance
  // * topology_attr: the topology TPUReplicate attribute
  // * device_assignment_attr: the device_assignment TPUReplicate attribute
  // Outputs:
  // * tf_device_assignment: a mapping from [replica][core] to a TF device
  //   name
  // * xla_device_assignment: a mapping from [replica][core] to a linearized
  //   TPU coordinate.
  // TODO(phawkins): change tf_device_assignment to an xla::Array2D.
  static Status BuildDeviceAssignment(
      const tpu::TpuTopologyExternal& topology, int num_tpus_per_task,
      const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
      int num_cores_per_replica, const string& topology_attr,
      absl::Span<const int> device_assignment_attr,
      std::vector<std::vector<string>>* tf_device_assignment,
      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment);
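
  // For example (an illustrative sketch; actual names depend on the job and
  // task layout of the cluster), with num_replicas = 2 and
  // num_cores_per_replica = 2, tf_device_assignment might be:
  //
  //   {{"/job:worker/replica:0/task:0/device:TPU:0",
  //     "/job:worker/replica:0/task:0/device:TPU:1"},
  //    {"/job:worker/replica:0/task:0/device:TPU:2",
  //     "/job:worker/replica:0/task:0/device:TPU:3"}}
  //
  // i.e., tf_device_assignment[replica][core] names the TF device that runs
  // that core of that replica.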

  // Returns the `computation` graph attached to TPUReplicate operator `node`.
  // `flr` is a FunctionLibraryRuntime to use when instantiating the function
  // body. Sets `*arg_types` and `*retval_types` to the argument/return types
  // of the function.
  static Status GetComputationForTPUReplicateOp(const NameAttrList& function,
                                                FunctionLibraryRuntime* flr,
                                                Graph* computation,
                                                DataTypeVector* arg_types,
                                                DataTypeVector* retval_types);

  // Returns the shapes of the argument tensors and return values of the
  // TPUReplicate operator `node` using the _output_shapes,
  // _output_handle_shapes, and _output_handle_types annotations on the input
  // nodes. Expects inputs in the following order (see the comment at the top
  // of the file):
  // * num_replicas * num_per_replica_args per-replica inputs,
  // * num_distributed_args distributed inputs,
  // * num_broadcast_args broadcast inputs,
  // * num_variables variable inputs.
  // Returns an error if the input shapes to `node` are not statically known.
  // Also verifies that all replicas have identical input shapes for their
  // per-replica inputs.
  static Status GetArgAndRetvalShapes(
      const GraphShapeInfo& shape_info, const Node& node,
      const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
      std::vector<InferredShape>* retval_shapes);

  // Assigns arguments and return values to cores. The assignment is
  // represented as an XLA op sharding, so that an argument can be replicated
  // across cores. `arg_sharding` and `retval_sharding` are vectors of
  // shardings indexed by argument/retval number; `arg_fast_mem` is a vector
  // of fast_mem indications, also indexed by argument number.
  static Status AssignArgsAndRetvalsToCores(
      int num_cores_per_replica, const ParameterInfo& params_info,
      const DataTypeVector& arg_types,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& retval_types,
      const std::vector<InferredShape>& retval_shapes, const Graph& graph,
      const Node* replicate_node, FunctionLibraryRuntime* flr,
      bool allow_parameter_replication_for_spmd,
      std::vector<::xla::OpSharding>* arg_sharding,
      std::vector<bool>* arg_fast_mem,
      std::vector<::xla::OpSharding>* retval_sharding,
      std::vector<std::string>* arg_names);
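
  // The shardings above are plain xla::OpSharding protos. A minimal sketch of
  // the two most common cases (standard proto usage, not an API of this
  // pass):
  //
  //   xla::OpSharding replicated;  // the same value lives on every core
  //   replicated.set_type(xla::OpSharding::REPLICATED);
  //
  //   xla::OpSharding on_core_zero;  // placed entirely on one core
  //   on_core_zero.set_type(xla::OpSharding::MAXIMAL);
  //   on_core_zero.add_tile_assignment_devices(0);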

  struct VariableInput {
    Node* node;
    int index;

    // Type of the variable's value. Note that this is different from the
    // type of the output of 'variable', which is always DT_RESOURCE.
    DataType dtype;
  };

  // Populates `*variables` with the "variables" inputs to `node`.
  static Status FindVariableInputs(const Node& node,
                                   const NameRangeMap& input_range_map,
                                   std::vector<VariableInput>* variables);

  // Populates '*guaranteed_constants' with the "guaranteed_constants" inputs
  // to 'node'.
  static Status FindGuaranteedConstantInputs(
      const Node& node, const NameRangeMap& input_range_map,
      std::vector<Node*>* guaranteed_constants);

  // Builds Shape nodes that compute the shapes of arguments whose shapes are
  // not statically known.
  static Status BuildDynamicShapeNodes(
      const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
      const ParameterInfo& params_info,
      const std::vector<Node*>& variable_reads, Graph* graph,
      std::vector<Node*>* dynamic_shape_nodes);

  // Builds a TPUCompile node that compiles the computation in `function`.
  // TODO(b/33943292): at present, for model parallelism with Send/Recv to
  // work, the per-core computations must correspond to TPU:0, TPU:1, ... in
  // order, since XLA hard-codes the chip IDs in the generated executables.
  static Status BuildCompileNode(
      const Node* replicate_node, const NameAttrList& function,
      uint64 library_fingerprint, const ParameterInfo& params_info,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& arg_types,
      const std::vector<Node*>& guaranteed_constant_nodes,
      const string& session_handle,
      const std::vector<::xla::OpSharding>& arg_sharding,
      const std::vector<bool>& arg_fast_mem,
      const std::vector<std::string>& arg_names,
      const std::vector<::xla::OpSharding>& retval_sharding,
      int num_cores_per_replica, const string& compile_device,
      const xla::DeviceAssignment* xla_device_assignment,
      const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
      Node** compile_node, int64 autotuner_thresh);

  // Builds a TPUCompileSucceededAssert node that verifies that compilation
  // succeeded and replaces the TPUCompilationStatus node in the graph.
  static Status BuildCompilationStatusReturnNodes(
      Node* replicate_node, Node* compile_node,
      Node** control_after_compilation, Graph* graph);

  // Builds ReadVariableOp nodes that read `variables`, with control edges
  // that ensure they happen after `control_predecessor`.
  static Status BuildVariableReads(absl::Span<const VariableInput> variables,
                                   Node* control_predecessor, Graph* graph,
                                   std::vector<Node*>* variable_reads);

  // Returns true if the graph or functions contain a resource write op,
  // otherwise returns false.
  // TODO(b/137048563): Recognize unused resource rewrite op.
  static bool ContainsResourceWriteOp(const Graph& graph,
                                      const FunctionLibraryDefinition& fld);

  // Struct that describes a variable value to be written back from TPUExecute.
  struct VariableWrite {
    // A node:output pair containing a boolean tensor that determines whether
    // the value should be written back.
    Node* predicate;
    int predicate_output;

    // A node:output pair containing the value to be written back.
    Node* value;
    int value_output;
  };

  // Builds AssignVariableOp nodes that write `variables` with the values from
  // `variable_writes`, with control edges that ensure the writes happen before
  // `control_successor`.
  static Status BuildVariableWrites(
      absl::Span<const VariableInput> variables, Node* control_successor,
      absl::Span<const VariableWrite> variable_writes, Graph* graph);
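
  // Conceptually, each VariableWrite above describes a conditional
  // write-back (a pseudocode sketch only; `output(n, i)` is shorthand for
  // the i-th output of node n, and the pass realizes this with
  // AssignVariableOp nodes rather than executing anything directly):
  //
  //   if (output(write.predicate, write.predicate_output)) {
  //     variable = output(write.value, write.value_output);
  //   }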

  // Builds TPUExecute operators assigned to each TPU device
  // involved in the computation.
  // Arguments:
  // * `params_info` is the structure containing the information about the
  //   TPUReplicate node inputs and outputs.
  // * `num_tasks` is the number of TensorFlow tasks in the slice.
  // * `num_cores_per_replica` is the number of cores which are dedicated to
  //   each replica.
  // * `replicate_node` is the original TPUReplicate node.
  // * `arg_names` are the names of the arguments to the computation function
  //   passed as argument to TPUReplicate, including per-replica,
  //   broadcast, and variable arguments.
  // * `arg_types` are the corresponding types of the arguments.
  // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
  //   applicable).
  // * `arg_shardings` and `retval_shardings` are mappings from
  //   arguments/return indices to shardings, as returned by
  //   `AssignArgsAndRetvalsToCores`.
  // * `pod_devices` lists the devices to assign to each core of each replica.
  // * `variable_reads` is a vector of ReadVariableOp operators, one for each
  //   variable argument to the computation.
  // * The execute operators will have a control edge from
  //   `control_predecessor` and another control edge to `control_successor`.
  // Populates '*variable_writes' with information about variable values to
  // write back.
  static Status BuildExecuteNodes(
      const ParameterInfo& params_info, int num_tasks,
      int num_cores_per_replica, const Node& replicate_node,
      const std::vector<std::string>& arg_names,
      const DataTypeVector& arg_types,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& retval_types,
      const std::vector<::xla::OpSharding>& arg_shardings,
      const std::vector<::xla::OpSharding>& retval_shardings,
      const std::vector<std::vector<string>>& tpu_device_names,
      Node* compile_node, const std::vector<Node*>& variable_reads,
      Node* control_predecessor, Node* control_successor,
      std::vector<VariableWrite>* variable_writes, Graph* graph);

  // Connects the compile node to all the host transfer nodes, and removes the
  // key placeholder node that was previously standing in for it.
  // Arguments:
  // * `compile_node` is the TPUCompile node that has been added to the graph.
  // * `key_placeholder_node` is the placeholder node that sends the key to
  //   all the host transfer nodes in the original graph.
  // * `graph` is the graph being rewritten.
  static Status ConnectHostComputeNodes(Node* compile_node,
                                        Node* key_placeholder_node,
                                        Graph* graph);

  // Map from a Node in an outside_compilation cluster in the original graph
  // to the list of Nodes, one for each replica, that it is expanded into
  // during replication.
  typedef absl::node_hash_map<Node*, std::vector<Node*>> NodeToNodeReplicasMap;

  // Map from the name of an outside_compilation cluster to the model-parallel
  // core index that the HostCompute Op should be placed on in that cluster.
  typedef std::map<string, int> HostComputeCoreMap;

  // Map from the name of an outside_compilation cluster to the list of Nodes
  // that should run on the host for that cluster.
  typedef std::map<string, std::vector<Node*>> OutsideCompilationNodeMap;

  // Copies the outside_compilation nodes in a cluster to create replica
  // `replica_index`.
  static Status CopyOutsideCompilationNodes(
      int replica_index, const std::vector<Node*>& outside_compilation_nodes,
      const DeviceNameUtils::ParsedName& tpu_device,
      const DeviceNameUtils::ParsedName& partial_device,
      NodeToNodeReplicasMap* node_images, Graph* graph);

  // Replicates all the nodes in outside_compilation clusters in a compiled
  // computation.
  static Status ReplicateOutsideCompilationNodes(
      const std::vector<std::vector<string>>& tf_device_assignment,
      const HostComputeCoreMap& host_compute_core,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      NodeToNodeReplicasMap* node_images, Graph* graph);

  // Lifts the edges between original outside_compilation nodes in a cluster
  // onto their replicas.
  static Status CopyOutsideCompilationEdges(
      const std::vector<Node*>& outside_compilation_nodes,
      const NodeToNodeReplicasMap& node_images,
      const std::unordered_map<string, Node*> outside_compilation_inputs,
      Graph* graph);
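
  // For example (an illustrative sketch): after replicating an
  // outside_compilation cluster for two replicas, `node_images` maps each
  // original node to its per-replica copies, indexed by replica:
  //
  //   std::vector<Node*>& images = (*node_images)[original_node];
  //   Node* copy_for_replica_0 = images[0];
  //   Node* copy_for_replica_1 = images[1];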

  // Lifts all the edges in outside_compilation clusters in a compiled
  // computation to their replicas.
  static Status ReplicateOutsideCompilationEdges(
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      const NodeToNodeReplicasMap& node_images,
      const std::unordered_map<string, Node*> outside_compilation_inputs,
      Graph* graph);

  // Removes all the original outside_compilation nodes from the graph,
  // following replication.
  static Status RemoveOutsideCompilationNodes(
      const NodeToNodeReplicasMap& node_images, Graph* graph);

  // Lowers outside compilation functional nodes (If/While/function call).
  // Without this, when there are multiple workers, the device placer cannot
  // place nodes if outside compilation has DT_RESOURCE inputs (e.g., a
  // DT_RESOURCE input fed into multiple While nodes on different devices).
  static Status LowerOutsideCompilationFunctionalNodes(
      Graph* g, const FunctionLibraryDefinition& flib_def,
      const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping);

  // Parses the 'host_compute_core' attribute on `replicate_node` to get the
  // replicated core id of each outside_compilation cluster.
  static Status ParseHostComputeCores(
      const Node& replicate_node,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      HostComputeCoreMap* host_compute_core);

  // Gets the physical topology information about the TPU system.
  static Status GetDeviceTopology(
      const DeviceSet& device_set, const Node& replicate_node,
      int* num_replicas, int* num_cores_per_replica, int* num_tasks,
      std::vector<std::vector<string>>* tf_device_assignment,
      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
      string* tpu_compilation_device);

  // Gets the types of args, retvals, and parameters.
  static Status GetIOTypes(
      int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
      Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
      std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
      DataTypeVector* retval_types, ParameterInfo* params_info);

  // Finds known constants and deals with variable reads.
  static Status DealWithConstantsAndVariables(
      const Node& replicate_node, const NameRangeMap& input_name_map,
      Graph* graph, Node* host_transfer_sequencer, Node* control_before,
      Node* control_after, absl::Span<const VariableInput> variable_nodes,
      std::vector<Node*>* guaranteed_constant_nodes,
      std::vector<Node*>* variable_reads);

  // Adds NoOp nodes for sequencing computation and variable reads/writes.
  static Status BuildSequencingNodes(const string& tpu_compilation_device,
                                     const Node& replicate_node, Graph* graph,
                                     Node** host_transfer_sequencer,
                                     Node** control_before,
                                     Node** control_after);

  // Performs the pass's rewrite on a TPUReplicate node `replicate_node`.
  static Status RewriteTPUReplicateNode(
      const string& session_handle, const DeviceSet& device_set,
      Node* replicate_node, FunctionLibraryDefinition* flib_def,
      FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      const std::vector<Node*>& head_tail_outside_compilation_nodes,
      NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
      const GraphShapeInfo& shape_info,
      TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
      int64 autotuner_thresh);

  // Performs host training loop optimization. For example, when the
  // TPUExecute node is inside a while loop, the model weight variables can be
  // sharded in XLA's preferred layout and then unsharded only at the very
  // last iteration, to reduce the number of all-gathers.
  static Status PerformHostTrainingLoopOptimization(
      Graph* graph, FunctionLibraryDefinition* flib_def,
      FunctionLibraryRuntime* flr);

  // Heuristically places some nodes with unassigned devices on TPUs for
  // performance reasons.
  static Status PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph* graph);

  // Updates the head and tail outside compiled nodes so that nodes have the
  // correct device, and removes the replication and outside compilation
  // attributes so that these nodes do not trigger further graph optimization
  // passes.
  static Status UpdateHeadTailOutsideCompilation(
      const std::vector<std::vector<string>>& tf_device_assignment,
      const std::vector<Node*>& head_tail_outside_compilation_nodes);

 private:
  static bool distribute_vars_;
  static bool allow_xla_spmd_partition_;
  static bool replicate_inputs_outputs_by_default_for_xla_spmd_;
  static bool enable_cross_replica_sharding_mirrored_variables_;
  static bool enable_automatic_model_parallelism_;
  static bool enable_xla_param_broadcast_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_