/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
#define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_

#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/intrusive_ptr.h"

namespace tensorflow {

class BufRendezvous;
class CancellationManager;
class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
class Device;
class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class NcclManager;
class Tensor;

// Types of supported collective operations.
enum CollectiveType {
  REDUCTION_COLLECTIVE = 0,
  BROADCAST_COLLECTIVE,
  GATHER_COLLECTIVE,
  PERMUTE_COLLECTIVE,
  ALL_TO_ALL_COLLECTIVE,
  UNDEFINED_COLLECTIVE,
};

// Some collective op implementations require runtime group configuration from
// the OpKernel. Currently this struct is used to set the communicator key for
// the NCCL-based collective implementation.
struct CollGroupRuntimeDetails {
  string communicator_key;  // for communicator-based techniques, e.g. NCCL
  string ToString() const;
};

struct CollGroupMember {
  DeviceAttributes device;
  string task;
  bool is_local;
  // User-provided rank.
  int32 rank = -1;
};

// Data common to all members of a device group.
// All members share the same device set, but the order of that set is
// particular to each instance, so it is stored there.
struct CollGroupParams {
  // Inputs from Collective ops:
  int32 group_key;
  int32 group_size;
  DeviceType device_type;
  int user_specified_rank = -1;  // rank provided by the user
  // Generated by the Collective Group Resolver:
  // Members in this group, in default rank order.
  std::vector<CollGroupMember> members;
  // True if every task has the same number of devices.
  bool same_num_devices_per_task = false;
  // Task -> number of devices on that task.
  std::unordered_map<string, int32> num_devices_per_task;
  int32 num_tasks;  // number of distinct tasks in the group
  CollGroupRuntimeDetails runtime_details;
  string ToString() const;
  CollGroupParams()
      : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {}
};

// The best implementation of a collective op depends on many factors,
// including the number of devices involved, the topology of the interconnects
// between them, and the sizes of the inputs. This structure is used in
// generating and representing the data-movement choreography for each
// specific algorithm, hence it does not have a single, fixed interpretation.
// On first execution the runtime will update this structure with decisions
// that will guide all subsequent executions.
struct CollImplDetails {
  string collective_name;
  std::vector<std::vector<int>> subdiv_permutations;
  // subdiv_offsets and max_subdivs_per_device are used together as follows:
  // when subdiv_offsets is provided (non-empty), it is used as is; when
  // subdiv_offsets is not provided, subdivisions are generated dynamically,
  // constrained by max_subdivs_per_device; when subdiv_offsets is empty AND
  // max_subdivs_per_device == 0, an internal default,
  // kMaxSubdivsPerDeviceDefault, is used; and when
  // max_subdivs_per_device == -1, no subdivision is done.
  int max_subdivs_per_device = -1;  // Upper bound on subdivisions per device.
  std::vector<int> subdiv_offsets;
  std::vector<int> subdiv_source_rank;  // rank of the source in each subdiv
  std::vector<int32>
      dependencies;  // collective instances on which this node depends
  string communication_hint;  // user-supplied hint for implementation choice,
                              // e.g. ring or nccl
  float timeout_seconds;  // If non-zero, sets a completion timeout for the
                          // collective op, to detect staleness.
};
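
// An illustrative sketch (not part of this header's API): how the four
// subdivision cases described above combine. All numeric values here are
// assumptions chosen only for the example.
//
//   CollImplDetails d;
//   d.subdiv_offsets = {0, 4};          // non-empty: used as is
//
//   d.subdiv_offsets.clear();
//   d.max_subdivs_per_device = 2;       // offsets generated dynamically, at
//                                       // most 2 subdivisions per device
//
//   d.subdiv_offsets.clear();
//   d.max_subdivs_per_device = 0;       // dynamic, bounded by the internal
//                                       // kMaxSubdivsPerDeviceDefault
//
//   d.max_subdivs_per_device = -1;      // no subdivision is done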

// Data common to all members of a collective instance.
// TODO(b/163171014): Refactor this struct to not be a union of all fields.
struct CollInstanceParams {
  // Identifies all participating graph nodes.
  int32 instance_key = -1;
  CollectiveType type = UNDEFINED_COLLECTIVE;
  DataType data_type = DT_FLOAT;
  TensorShape shape = {0};
  CollImplDetails impl_details;
  string ToString() const;
  CollInstanceParams& operator=(const struct CollInstanceParams& other);
  std::vector<string> devices;  // permuter only

  // For permuter only.
  // Each value in the permutation is a receiver rank, and the index of that
  // value is the rank of its sender. Example: permutation = {2,0,1} means
  //   rank 0 sends to rank 2
  //   rank 1 sends to rank 0
  //   rank 2 sends to rank 1
  std::vector<int> permutation;
};

// Unique to a single CollectiveOp node.
struct CollectiveParams : public core::RefCounted {
  CollGroupParams group;
  CollInstanceParams instance;

  string name = "";        // node name, used only for log or error messages
  int default_rank = -1;   // index of this op within the group members
  bool is_source = false;  // broadcast only
  int source_rank = -1;    // broadcast only
  // Rank of this device in each subdivision permutation.
  std::vector<int> subdiv_rank;
  OpKernel* merge_op = nullptr;  // reduction only
  OpKernel* final_op = nullptr;  // reduction only
  string ToString() const;
  bool run_group_initialization = true;
};

class CollectiveExecutor;

// Interface that provides resolution of device localities.
class DeviceResolverInterface {
 public:
  virtual ~DeviceResolverInterface() {}

  // Populates *attributes with the DeviceAttributes of the specified device.
  virtual Status GetDeviceAttributes(const string& device,
                                     DeviceAttributes* attributes) = 0;

  // Returns all device attributes of a task.
  virtual Status GetAllDeviceAttributes(
      const string& task, std::vector<DeviceAttributes>* attributes) = 0;

  // Updates device attributes. Returns an error if any device already exists
  // in the DeviceResolver with a different incarnation.
  virtual Status UpdateDeviceAttributes(
      const std::vector<DeviceAttributes>& attributes) = 0;
};

// Interface that provides resolution of shared CollectiveParams fields.
class ParamResolverInterface {
 public:
  virtual ~ParamResolverInterface() {}

  // Called by each collective op at first execution in order to fill out
  // the CollectiveParams structure with data gathered from the full
  // (maybe distributed) collection of peer nodes.
  virtual void CompleteParamsAsync(const DeviceAttributes& device,
                                   CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   const StatusCallback& done) = 0;

  // Completes group_params with data gathered from all devices in the group.
  // This blocks until all devices are present.
  virtual void CompleteGroupAsync(const DeviceAttributes& device,
                                  CollGroupParams* group_params,
                                  CancellationManager* cancel_mgr,
                                  const StatusCallback& done) = 0;

  // Used within a distributed implementation to discover/verify data
  // shared across an instance group.
  // Note: this works differently from CompleteGroupAsync because a refactor
  // is in progress.
  virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
                                     CompleteInstanceResponse* response,
                                     CancellationManager* cancel_mgr,
                                     const StatusCallback& done) = 0;

  // Looks up a group. Returns an error if the group is not ready or not
  // found.
  virtual Status LookupGroup(int32_t group_key, CollGroupParams* group) = 0;

  // Aborts the resolver. After abortion the resolver can no longer be used.
  virtual void StartAbort(const Status& s) = 0;
};

// Graphs which utilize Collective Ops in a common instance must
// execute with identical step_ids, even if they are disjoint graphs
// run by otherwise independent tasks. This interface supplies
// coordinated step_ids to use in such cases.
class StepSequenceInterface {
 public:
  virtual ~StepSequenceInterface() {}

  // Used with a distributed implementation to coordinate step_id
  // sequences across tasks.
  virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                    GetStepSequenceResponse* response,
                                    const StatusCallback& done) = 0;

  // Refreshes the local per-graph_key step_id sequence from the collective
  // group leader, if applicable.
  virtual void RefreshStepIdSequenceAsync(int64_t graph_key,
                                          const StatusCallback& done) = 0;

  // Returns the step_id that should be used for initiating a new execution
  // on the specified graph. May return the same step_id multiple times if
  // RetireStepId or RefreshStepIdSequenceAsync is not called.
  virtual int64_t NextStepId(int64_t graph_key) = 0;

  // Reports that execution of the given step has completed successfully.
  // Should be called immediately after a step completes with OK status,
  // prior to calling NextStepId(). If the step fails, don't call.
  virtual void RetireStepId(int64_t graph_key, int64_t step_id) = 0;
};
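
// A minimal sketch of the intended step_id protocol, assuming `seq` points at
// some StepSequenceInterface implementation and `RunGraph` is a hypothetical
// function that executes one step:
//
//   int64_t step_id = seq->NextStepId(graph_key);
//   Status status = RunGraph(graph_key, step_id);
//   if (status.ok()) {
//     seq->RetireStepId(graph_key, step_id);  // only after an OK step
//   }
//   // If RetireStepId is not called (e.g. because the step failed), a
//   // subsequent NextStepId(graph_key) may return the same step_id again.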

class NcclCommunicatorInterface;

// Interface that provides access to per-step CollectiveExecutor
// instances and various distributed resolution capabilities.
class CollectiveExecutorMgrInterface : public StepSequenceInterface {
 public:
  virtual ~CollectiveExecutorMgrInterface() {}

  // Returns the step-specific CollectiveExecutor, creating one if it does
  // not already exist. The caller assumes ownership of one Ref on the object.
  virtual CollectiveExecutor* FindOrCreate(int64_t step_id) = 0;

  // If there is a CollectiveExecutor for step_id, removes it from the table.
  virtual void Cleanup(int64_t step_id) = 0;

  virtual ParamResolverInterface* GetParamResolver() const = 0;

  virtual DeviceResolverInterface* GetDeviceResolver() const = 0;

  virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0;
};
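
// A minimal usage sketch, assuming `cem` points at some concrete
// CollectiveExecutorMgrInterface implementation:
//
//   CollectiveExecutor* ce = cem->FindOrCreate(step_id);
//   // ... use `ce` to execute collectives for this step ...
//   ce->Unref();            // release the Ref acquired via FindOrCreate
//   cem->Cleanup(step_id);  // drop the per-step table entry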

// Interface that a Collective Op implementation uses to exchange data
// with peers. Note that data exchange is currently limited to types
// for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric
// types.
class CollectiveRemoteAccess {
 public:
  virtual ~CollectiveRemoteAccess() {}

  virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
                            bool peer_is_local, const string& key,
                            Device* to_device, DeviceContext* to_device_ctx,
                            const AllocatorAttributes& to_alloc_attr,
                            Tensor* to_tensor,
                            const DeviceLocality& client_locality,
                            int dev_to_dev_stream_index,
                            CancellationManager* cancellation_manager,
                            const StatusCallback& done) = 0;

  virtual void PostToPeer(const string& peer_device, const string& peer_task,
                          const string& key, Device* from_device,
                          DeviceContext* from_device_ctx,
                          const AllocatorAttributes& from_alloc_attr,
                          const Tensor* from_tensor,
                          const DeviceLocality& client_locality,
                          CancellationManager* cancellation_manager,
                          const StatusCallback& done) = 0;

  // Checks the health of a collective peer by probing it to see whether it
  // is alive. Note that if a peer has restarted, it is considered a different
  // peer, so CheckPeerHealth fails.
  virtual void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms,
                               const StatusCallback& done) = 0;

  virtual BufRendezvous* buf_rendezvous() = 0;

  virtual void StartAbort(const Status& s) = 0;
};
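
// A rough sketch of the intended pairing (the buffer and locality arguments
// are deliberately elided here): the sender and receiver must agree on `key`
// for a given exchange.
//
//   // On the producing device:
//   remote_access->PostToPeer(peer_device, peer_task, key, from_device,
//                             /* ...buffer and locality args... */
//                             cancellation_manager, done);
//   // On the consuming device, with the same `key`:
//   remote_access->RecvFromPeer(peer_device, peer_task, peer_is_local, key,
//                               to_device,
//                               /* ...buffer and locality args... */
//                               cancellation_manager, done);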

// A step-specific object that can execute a collective operation completely
// described by a CollectiveParams object.
class CollectiveExecutor : public core::RefCounted {
 public:
  virtual void StartAbort(const Status& s) {}

  virtual void ExecuteAsync(OpKernelContext* ctx,
                            const CollectiveParams* col_params,
                            const string& exec_key, StatusCallback done) {
    done(errors::Internal(
        "A collective Op has been called in a context in which "
        "a CollectiveExecutor has not been provided."));
  }

  virtual void CompleteParamsAsync(const DeviceAttributes& device,
                                   CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   StatusCallback done) {
    done(errors::Internal(
        "A collective Op has been called in a context in which "
        "a CollectiveExecutor has not been provided."));
  }

  virtual void CompleteGroupAsync(const DeviceAttributes& device,
                                  CollGroupParams* group_params,
                                  CancellationManager* cancel_mgr,
                                  StatusCallback done) {
    return cem_->GetParamResolver()->CompleteGroupAsync(device, group_params,
                                                        cancel_mgr, done);
  }

  virtual Status LookupGroup(int32_t group_key, CollGroupParams* group) {
    return cem_->GetParamResolver()->LookupGroup(group_key, group);
  }

  // Runs the potentially blocking closure / expensive callback.
  virtual void RunClosure(std::function<void()> closure) = 0;

  virtual CollectiveRemoteAccess* remote_access() { return nullptr; }

  // `WaitForDependencies` and `UnblockDependencies` are used for fine-grained
  // control of execution order between collective instances. These functions
  // are intended to be called from the `Run` function of collective
  // implementations, and may be used to make part, or all, of the collective
  // execution ordered with respect to other collective instances.
  //
  // `WaitForDependencies` will block until it is safe to continue the
  // caller's execution, where safety is defined as: ordered with respect to
  // the collective instances named in the caller's `wait_for` attribute.
  virtual void WaitForDependencies(const CollectiveParams& col_params) {}
  // `UnblockDependencies` unblocks the dependent collective instances by
  // recording that this caller's device has completed the critical portion
  // of the collective execution.
  virtual void UnblockDependencies(const CollectiveParams& col_params) {}

  // Used to designate an invalid group or instance key.
  static int64_t kInvalidId;

  // Lexically scoped handle for Ref.
  class Handle {
   public:
    explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
      if (!inherit_ref) ce->Ref();
    }
    ~Handle() { ce_->Unref(); }
    CollectiveExecutor* get() const { return ce_; }

   private:
    CollectiveExecutor* ce_;
  };

 protected:
  explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
      : cem_(cem) {}

  // For use only by derived classes.
  static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
  CollectiveExecutorMgrInterface* cem_;

  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
};
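
// A minimal sketch of the Handle class above as a scoped Ref, assuming `cem`
// is a CollectiveExecutorMgrInterface*. Since FindOrCreate already transfers
// one Ref to the caller, the Handle inherits that Ref rather than adding
// another:
//
//   {
//     CollectiveExecutor::Handle handle(cem->FindOrCreate(step_id),
//                                       /*inherit_ref=*/true);
//     handle.get()->RunClosure([] { /* ... */ });
//   }  // ~Handle() calls Unref() here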

struct CollectiveContext {
  CollectiveExecutor* col_exec;                  // Not owned
  NcclCommunicatorInterface* nccl_communicator;  // Not owned
  const DeviceMgr* dev_mgr;                      // Not owned
  OpKernelContext* op_ctx;                       // Not owned
  OpKernelContext::Params* op_params;            // Not owned
  core::IntrusivePtr<const CollectiveParams> col_params;
  const string exec_key;
  const int64_t step_id;
  const Tensor* input;  // Not owned
  Tensor* output;       // Not owned
  Device* device;       // The device for which this instance labors
  const string device_name;
  DeviceLocality device_locality;

  CollectiveContext(CollectiveExecutor* col_exec,
                    NcclCommunicatorInterface* nccl_communicator,
                    const DeviceMgr* dev_mgr, OpKernelContext* ctx,
                    OpKernelContext::Params* op_params,
                    const CollectiveParams* col_params, const string& exec_key,
                    int64_t step_id, const Tensor* input, Tensor* output);
};

class NcclCommunicatorInterface {
 public:
  virtual ~NcclCommunicatorInterface() = default;

  virtual string GenerateCommunicatorKey() = 0;

  virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
                       StatusCallback done) = 0;

  virtual void StartAbort(const Status& s) = 0;
};

// Interface of a Collective Op implementation. Each specific CollectiveOp
// will implement this interface and register the implementation via the
// CollectiveRegistry detailed below. See common_runtime/ring_reducer and
// common_runtime/hierarchical_tree_broadcaster for examples.
class CollectiveImplementationInterface : public core::RefCounted {
 public:
  virtual ~CollectiveImplementationInterface() = default;

  // Initializes the portions of `col_params` specific to this
  // implementation. Called exactly once for every Collective instance during
  // the CollectiveParams resolution process when the graph is first executed,
  // at the end of `CompleteInstanceLocal()`.
  // NOTE(ayushd): This is effectively a static function because it modifies
  // the `col_params` passed in and should not manipulate any data members.
  // However, because it is virtual and needs to be implemented by every
  // derived class, we do not mark it as static.
  virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;

  // Prepares the CollectiveContext for executing this
  // CollectiveImplementation. Called from CollectiveExecutor right before
  // calling Run(). The CollectiveContext passed in must outlive the
  // CollectiveImplementation object.
  virtual Status InitializeCollectiveContext(
      std::shared_ptr<CollectiveContext> col_ctx) = 0;

  // Processes and moves data according to the logic of this Collective
  // implementation. Relies on appropriate initialization of op-specific
  // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
  // context initialization in InitializeCollectiveContext().
  virtual void Run(StatusCallback done) = 0;
};

// Static-methods-only class for registering and looking up collective
// implementations.
class CollectiveRegistry {
 public:
  using Factory = std::function<CollectiveImplementationInterface*()>;
  // Looks up a previously registered CollectiveImplementation under
  // `collective_name`. If found, creates an instance of the implementation
  // and assigns it to `implementation`.
  static Status Lookup(const string& collective_name,
                       CollectiveImplementationInterface** implementation);

  // Looks up a previously registered CollectiveImplementation under
  // `collective_name`. If found, returns the static instance of this
  // implementation via `implementation`. This instance should only be used
  // to call InitializeCollectiveParams.
  static Status LookupParamResolverInstance(
      const string& collective_name,
      CollectiveImplementationInterface** implementation);

  // Returns all registered collective implementations.
  static void GetAll(
      std::vector<CollectiveImplementationInterface*>* implementations);

 private:
  friend class CollectiveRegistration;
  // Registers a CollectiveImplementation with name `collective_name` and
  // factory `factory`. The latter is a function used to create instances of
  // the CollectiveImplementation. Also creates a static instance of the
  // implementation - this instance is used during param resolution and
  // should only be used to call InitializeCollectiveParams.
  static Status Register(const string& collective_name, Factory factory);

  static Status LookupHelper(
      const string& collective_name,
      CollectiveImplementationInterface** implementation, bool param_resolver);
};

// Class used to call CollectiveRegistry::Register. This should only be used
// to create a global static object.
class CollectiveRegistration {
 public:
  CollectiveRegistration(const string& collective_name,
                         CollectiveRegistry::Factory factory) {
    TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
  }
};

#define REGISTER_COLLECTIVE(name, implementation)             \
  static CollectiveRegistration register_##name##_collective( \
      #name, []() { return new implementation; });
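
// For example, a hypothetical implementation `MyAllReduce` (the class body is
// elided; see common_runtime/ring_reducer for a real one) would be registered
// as a global static like so:
//
//   class MyAllReduce : public CollectiveImplementationInterface {
//     // InitializeCollectiveParams, InitializeCollectiveContext, Run...
//   };
//   REGISTER_COLLECTIVE(MyAllReduce, MyAllReduce);
//
// After registration, CollectiveRegistry::Lookup("MyAllReduce", &impl)
// creates a new instance via the registered factory.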

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_