/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/reffed_status_callback.h"

namespace tensorflow {

// Holds each node's shape information for Concat/Split.
using EdgeShapes = absl::flat_hash_map<const Edge*, std::vector<int>>;
using GroupedEdges =
    absl::flat_hash_map<std::string, std::vector<const Edge*>>;

// Contains the attrs "T", "sharding", and "_tpu_replicate" for each
// XlaSharding op found while searching for the inputs of replicated models.
using XlaShardingInfoMap = absl::flat_hash_map<
    std::string, std::tuple<tensorflow::DataType, std::string, std::string>>;

// Contains the attr "T" and a pointer to the tpu_replicated_metadata node
// (used as a control dependency) for each TpuReplicatedInput op found while
// searching for the inputs of replicated models.
using TpuReplicatedInputInfoMap =
    absl::flat_hash_map<std::string, std::tuple<tensorflow::DataType, Node*>>;
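
// Illustrative sketch only (not part of this header): one way same-shaped
// input edges could be grouped into a `GroupedEdges` map before packing. The
// "<dtype>_<dims>" key format below is a hypothetical example, not
// necessarily what the implementation uses.
//
//   EdgeShapes tpu_input_shapes = ...;   // filled in by shape inference
//   absl::flat_hash_map<const Edge*, DataType> tpu_input_dtypes = ...;
//   GroupedEdges grouped;
//   for (const auto& edge_and_dims : tpu_input_shapes) {
//     std::string key = absl::StrCat(
//         DataTypeString(tpu_input_dtypes.at(edge_and_dims.first)), "_",
//         absl::StrJoin(edge_and_dims.second, "x"));
//     grouped[key].push_back(edge_and_dims.first);
//   }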

namespace tpu_functional_internal {

// Helper functions for graph rewrites.

GroupedEdges GroupTensorsForInputPacking(
    const EdgeShapes& tpu_input_shapes,
    const absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
    bool input_shape_opt, bool group_tensors_for_packing);
GroupedEdges GroupTensorsForOutputPacking(Graph* graph,
                                          EdgeShapes& tpu_output_shapes,
                                          GraphShapeInfo* shape_info);

Status CreateConcatAndSplitNodesForInputTensor(
    Graph* graph, const string& cluster_name, EdgeShapes* tpu_input_shapes,
    const absl::flat_hash_map<std::string, std::vector<const Edge*>>&
        grouped_input_edges,
    int32_t minimum_input_tensors_packing, bool xla_spmd_input_sharded,
    const XlaShardingInfoMap& xla_sharding_info,
    const TpuReplicatedInputInfoMap& tpu_replicated_input_info);
Status CreateConcatAndSplitNodesForOutputTensor(
    Graph* graph, const string& cluster_name, EdgeShapes* tpu_output_shapes,
    GraphShapeInfo* tpu_inferred_info, GroupedEdges shape_to_output,
    int32_t minimum_output_tensors_packing);

Status InsertReshapeNodePairs(Graph* graph, const string& cluster_name,
                              EdgeShapes* tpu_input_shapes,
                              int num_cores_per_replica);

}  // namespace tpu_functional_internal
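
// Illustrative sketch only: input packing conceptually rewrites a group of
// same-shaped host-to-TPU edges into a single host-side ConcatV2 plus a
// device-side Split, so fewer (larger) tensors cross the host/TPU boundary.
// A hypothetical host-side concat node might be built roughly like this; the
// node name, `src_nodes`, and `axis_node` are placeholders, not the actual
// logic of CreateConcatAndSplitNodesForInputTensor above.
//
//   std::vector<NodeBuilder::NodeOut> src_nodes = ...;  // grouped edge sources
//   Node* axis_node = ...;                              // scalar axis, e.g. 0
//   Node* concat = nullptr;
//   TF_RETURN_IF_ERROR(NodeBuilder("packed_input/concat", "ConcatV2")
//                          .Input(src_nodes)   // "values" list input
//                          .Input(axis_node)   // "axis" scalar input
//                          .Finalize(graph, &concat));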

typedef FunctionLibraryRuntime::Handle FHandle;

// A `TPUPartitionedCallOp` asynchronously executes a function on exactly one
// TPU core and potentially across multiple other devices, but within a single
// process. The kernel places and partitions the function's underlying graph,
// executing each of the partitioned subgraphs as a function.
//
// The core on which the TPU computation is executed must be specified via the
// `device_ordinal` input. Different invocations of this op may specify
// different device ordinals, making it possible to map TPU computations to
// different cores at runtime. Currently, macro-substitution of device ordinals
// is only supported for the following whitelisted ops:
//   * TPUExecute
//   * InfeedEnqueue
//   * InfeedEnqueueTuple
//
// Attempting to compute a TPUPartitionedCallOp whose function body has a
// non-whitelisted node bearing an attribute named "device_ordinal" will result
// in an error.
//
// TODO(akshayka): This class duplicates most of the logic of
// `PartitionedCallOp`; once that class and this one have evolved to stable
// states, and if at that time they remain sufficiently similar, either unify
// them in one op or set up an inheritance structure that allows for code
// reuse.
class TPUPartitionedCallOp : public AsyncOpKernel {
 public:
  explicit TPUPartitionedCallOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        pool_(ctx->env(), "InitializeVarOnTPUPool", 1),
        library_runtime_(nullptr) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
    auto status = ctx->GetAttr("autotuner_thresh", &autotuner_thresh_);
    if (!status.ok()) {
      autotuner_thresh_ = 0;
    }
    tensorflow::tpu::OpsApiFn()->TfTpu_GetTpuPartitionedCallParamsFn(
        &runtime_params_);
  }

  ~TPUPartitionedCallOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;

 private:
  struct DeviceAndFHandle {
    std::string device;
    FHandle handle;

    // The FLD passed to `library_runtime_` as an overlay function library for
    // instantiation of function `handle`. This is a snapshot of the current
    // `flib_def_`. Since `flib_def_` can be changed concurrently by another
    // graph rewrite while `handle` is executing, we need to make sure each
    // `handle` uses a different FLD to avoid races. See b/181149591.
    std::unique_ptr<FunctionLibraryDefinition> flib_def;
  };
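
  // Illustrative sketch only: how a (device, handle) pair like the one above
  // is typically produced via the FunctionLibraryRuntime API. The function
  // name and option values below are placeholders, not this kernel's exact
  // instantiation logic (see InstantiatePartition further down).
  //
  //   DeviceAndFHandle entry;
  //   entry.device = "/job:localhost/replica:0/task:0/device:CPU:0";
  //   entry.flib_def =
  //       std::make_unique<FunctionLibraryDefinition>(*flib_def_);
  //   FunctionLibraryRuntime::InstantiateOptions opts;
  //   opts.target = entry.device;
  //   opts.lib_def = entry.flib_def.get();  // overlay function library
  //   TF_RETURN_IF_ERROR(library_runtime_->Instantiate(
  //       "partition_fn", AttrSlice(), opts, &entry.handle));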

  Status GetTpuCoreOrdinal(OpKernelContext* ctx, uint64 input_hash,
                           int64_t* ordinal_selector_req_id,
                           int32_t* core_ordinal)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper to create and initialize a TPU variable given a CPU variable.
  //   var: the CPU variable created by the user.
  //   ndef: the node def of the corresponding TPU var handle that we created.
  //   device_ordinal: TPU device ordinal on which to initialize this variable.
  Status InitializeVarOnTPU(OpKernelContext* ctx,
                            const core::RefCountPtr<Var>& var, NodeDef* ndef,
                            int device_ordinal, bool fast_mem)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper to create and initialize partitioned TPU variables given a CPU
  // variable with an XLA sharding annotation.
  //   var: the CPU variable created by the user.
  //   ndefs: the node defs of the corresponding TPU var handles on all the
  //     logical cores.
  //   split_dim: the partition dimension of the variable. If -1, the variable
  //     is replicated.
  //   device_ordinal: the index of the TPU core that is scheduled to run the
  //     computation. In the case of XLA SPMD, it is the "primary" core, i.e.,
  //     the smallest index of all the cores.
  Status InitializeShardedVarOnTPU(OpKernelContext* ctx,
                                   const core::RefCountPtr<Var>& var,
                                   std::vector<NodeDef>& ndefs, int split_dim,
                                   int device_ordinal)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Checks whether any of the immediate successors of `node` has the
  // attribute "_tpu_replicate".
  bool IsInputToTPUReplicate(Node* node) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Replaces an _Arg node of type DT_RESOURCE with a VarHandleOp on TPU.
  Status ReplaceResourceArgsWithVarHandleOps(Graph* graph, OpKernelContext* ctx,
                                             int device_ordinal,
                                             int num_cores_per_replica,
                                             bool enable_spmd_xla_partitioning)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Replaces an _Arg node that refers to a variable on the CPU host with
  // sharded/replicated variables on all logical TPU devices.
  Status ReplaceAndPartitionXLAShardingVariable(
      Graph* graph, OpKernelContext* ctx, int device_ordinal,
      ResourceHandle& handle, Node* variable, int num_cores_per_replica)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status ShardInputsWithXlaSharding(Graph* graph, int num_cores_per_replica,
                                    OpKernelContext* ctx)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Rewrites the graph for input and output optimizations.
  // TODO(ylc): Move this function to a graph optimization pass.
  Status OptimizeTpuInputOutputTensors(
      Graph* graph, bool enable_spmd_xla_partitioning,
      int num_cores_per_replica,
      std::map<std::string, std::vector<int>>& named_input_shapes,
      OpKernelContext* ctx) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status InferShapesWithResourceVar(Graph* graph, OpKernelContext* ctx,
                                    std::map<int, InferredShape>& arg_shapes,
                                    GraphShapeInfo* tpu_inferred_info);

  // Copies the graph backing `func_` into `graph`.
  Status GetGraphFromFunction(Graph* graph, int device_ordinal,
                              int* num_core_per_replica,
                              bool* use_spmd_for_xla_partitioning)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Places the graph carried by `optimization_options` and runs graph
  // optimization passes (pre-placement, post-placement, and post-rewrite).
  Status PlacementHelper(
      const DeviceSet& device_set,
      const GraphOptimizationPassOptions& optimization_options,
      const string& function_name);

  // Partitions `graph`, populates `subgraphs` with the partitions, and runs
  // the post-partitioning graph optimization passes.
  Status PartitionHelper(
      const DeviceSet& device_set,
      const GraphOptimizationPassOptions& optimization_options, Graph* graph,
      std::unordered_map<std::string, std::unique_ptr<Graph>>* subgraphs);

  // Adds and instantiates a function backed by `graph` with name
  // `function_name` on device `target_device`, storing the handle in `handle`.
  // If `out_flib_def` is not null, it will be set to a copy of `flib_def_` and
  // used for instantiation.
  Status InstantiatePartition(
      const Graph& graph, const string& function_name,
      const string& target_device, FHandle* handle,
      std::unique_ptr<FunctionLibraryDefinition>* out_flib_def)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds and instantiates functions for each subgraph in `subgraphs` after
  // rewriting nodes' `device_ordinal` attributes to match `replica_id` when
  // num_cores_per_replica == 1.
  Status InstantiateFunctionsFromSubgraphs(
      const DeviceSet& device_set, int replica_id, uint64 cache_hash,
      int num_cores_per_replica,
      std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
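
  // Illustrative sketch only: the partitioning step above conceptually relies
  // on the graph partitioner in tensorflow/core/graph/graph_partition.h, keyed
  // by each node's assigned device. The option values shown are placeholders,
  // not PartitionHelper's exact configuration.
  //
  //   PartitionOptions part_opts;
  //   part_opts.node_to_loc = [](const Node* n) {
  //     return n->assigned_device_name();
  //   };
  //   part_opts.new_name = [graph](const string& prefix) {
  //     return graph->NewName(prefix);
  //   };
  //   part_opts.get_incarnation = [](const string& device) { return 1; };
  //   std::unordered_map<std::string, GraphDef> partitions;
  //   TF_RETURN_IF_ERROR(Partition(part_opts, graph, &partitions));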

  // Rewrites `graph` such that the device ordinal attributes of all
  // whitelisted nodes (see `IsSupportedTPUOp`) are set to `device_ordinal`.
  // `*modified` is set to true if the graph is modified in the process (i.e.,
  // if it contains a whitelisted node); otherwise the graph is left
  // unmodified.
  //
  // Returns an error if
  //   (1) the graph contains a non-whitelisted node that carries an attribute
  //       named "device_ordinal", or
  //   (2) the set of device ordinals found among the graph's nodes has
  //       cardinality greater than 1.
  Status SetDeviceOrdinal(const DeviceSet& device_set, int device_ordinal,
                          Graph* graph, bool* modified)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void ExecuteRemoteFunction(const FunctionLibraryRuntime::Options& opts,
                             FHandle handle, OpKernelContext* ctx,
                             ReffedStatusCallback* done)
      ABSL_LOCKS_EXCLUDED(mu_);
  void ExecuteLocalFunction(const FunctionLibraryRuntime::Options& opts,
                            const OpInputList& arguments, FHandle handle,
                            OpKernelContext* ctx, ReffedStatusCallback* done)
      ABSL_LOCKS_EXCLUDED(mu_);
  void ExecuteFunctions(const std::vector<DeviceAndFHandle>& functions,
                        OpKernelContext* ctx, int device_ordinal,
                        int64_t ordinal_selector_req_id, DoneCallback done)
      ABSL_LOCKS_EXCLUDED(mu_);

  // Sets `*remote_execution` to true iff `target_device` is not compatible
  // with this kernel's local device.
  Status ShouldUseRemoteExecutionForFn(const std::string& target_device,
                                       bool* remote_execution) {
    DeviceNameUtils::ParsedName target_device_parsed;
    DeviceNameUtils::ParsedName local_device_parsed;

    if (!DeviceNameUtils::ParseFullOrLocalName(target_device,
                                               &target_device_parsed)) {
      return errors::InvalidArgument("Cannot parse target device ",
                                     target_device);
    }
    if (!DeviceNameUtils::ParseFullOrLocalName(local_device_name_,
                                               &local_device_parsed)) {
      return errors::InvalidArgument("Cannot parse local device ",
                                     local_device_name_);
    }

    if (DeviceNameUtils::AreCompatibleDevNames(target_device_parsed,
                                               local_device_parsed)) {
      *remote_execution = false;
    } else {
      *remote_execution = true;
    }
    return Status::OK();
  }

  // Init-once flags.
  absl::once_flag once_;
  absl::once_flag ordinal_selector_once_;

  // Device manager and device set.
  const DeviceMgr* device_mgr_;
  DeviceSet device_set_;

  // Threadpool.
  thread::ThreadPool pool_;

  // `func_` is the original function supplied to this OpKernel.
  NameAttrList func_;
  string local_device_name_;

  // Maps from cache keys to their corresponding functions, which are
  // represented as (device, handle) pairs.
  gtl::FlatMap<uint64, std::vector<DeviceAndFHandle>> partition_cache_
      ABSL_GUARDED_BY(mu_);

  // The set of device ordinals seen so far. Used by variable initialization
  // on TPU.
  absl::flat_hash_set<int> seen_ordinals_;

  // Records the indices of the _Arg nodes with type DT_RESOURCE that feed
  // into a TPU op.
  std::vector<bool> replaced_input_indices_;

  absl::Mutex mu_;
  // Function shards are added to `flib_def_`; later, a copy of `flib_def_` is
  // passed to `library_runtime_` as an overlay function library for
  // instantiation.
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  FunctionLibraryRuntime* library_runtime_;

  // Used to uniquify function names in `flib_def_`.
  uint32 suffix_ = 0;

  // Minimum number of run steps (batches) necessary to trigger the XLA
  // autotuner.
  int autotuner_thresh_ = 0;

  // TPU core selection.
  std::shared_ptr<tpu::TPUOrdinalSelector> ordinal_selector_;

  // Maps input hashes to TF fingerprints.
  absl::flat_hash_map<uint64, uint64> inputs_to_fingerprint_;

  // List of TPU devices.
  std::vector<Device*> tpu_devices_;

  TpuPartitionedCall_Params runtime_params_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_