/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
#define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_

#include <string>
#include <vector>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

class BufRendezvous;
class CancellationManager;
class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
class Device;
class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class NcclManager;
class Tensor;

// Types of supported collective operations.
enum CollectiveType {
  REDUCTION_COLLECTIVE = 0,
  BROADCAST_COLLECTIVE,
  GATHER_COLLECTIVE,
  PERMUTE_COLLECTIVE,
  ALL_TO_ALL_COLLECTIVE,
  UNDEFINED_COLLECTIVE,
};

// Some collective op implementations require runtime group configuration from
// the OpKernel. Currently, this struct is used to set the communicator key for
// the NCCL-based collective implementation.
struct CollGroupRuntimeDetails {
  string communicator_key;  // for communicator-based techniques e.g. NCCL
  string ToString() const;
};

// Data common to all members of a device group.
// All members share the same device set, but the order of devices is
// particular to an instance, so it is stored there.
struct CollGroupParams {
  int32 group_key;
  int32 group_size;
  DeviceType device_type;
  // Devices in this group, in default rank order.
  std::vector<DeviceAttributes> devices;
  // Task name prefix of corresponding device name.
  std::vector<string> task_names;
  // True if every task has the same number of devices.
  bool same_num_devices_per_task = false;
  // Task -> number of devices on that task.
  std::unordered_map<string, int32> num_devices_per_task;
  int32 num_tasks;  // number of distinct tasks in group
  CollGroupRuntimeDetails runtime_details;
  string ToString() const;
  CollGroupParams()
      : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {}
};
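// A minimal illustration of a populated group (field values are hypothetical;
// in practice the runtime fills these in during group resolution):
//
//   CollGroupParams group;
//   group.group_key = 7;
//   group.group_size = 2;            // two devices participate
//   group.device_type = DEVICE_GPU;
//   group.num_tasks = 1;             // both devices belong to a single task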
// The best implementation of a collective op depends on many factors
// including the number of devices involved, the topology of
// interconnects between them and the sizes of inputs. This structure
// is used in generating and representing data movement choreography
// for each specific algorithm, hence it does not have a single, fixed
// interpretation. On first execution the runtime will update this
// structure with decisions that will guide all subsequent executions.
struct CollImplDetails {
  string collective_name;
  std::vector<std::vector<int>> subdiv_permutations;
  // subdiv_offsets and max_subdivs_per_device are used together as follows:
  // When subdiv_offsets is provided (non-empty) it is used as is. When
  // subdiv_offsets is not provided, subdivisions are generated dynamically,
  // constrained by max_subdivs_per_device. When subdiv_offsets is empty AND
  // max_subdivs_per_device = 0, an internal default
  // kMaxSubdivsPerDeviceDefault is used. When max_subdivs_per_device = -1,
  // no subdivision is done.
  int max_subdivs_per_device = -1;  // Upper bound on subdivisions per device.
  std::vector<int> subdiv_offsets;
  std::vector<int> subdiv_source_rank;  // rank of source in each subdiv
  std::vector<int32>
      dependencies;           // collective instances on which this node depends
  string communication_hint;  // user-supplied hint for implementation choice,
                              // e.g. ring or nccl
  float timeout_seconds;      // If non-zero, set a completion timeout for the
                              // collective op to detect staleness.
};

// Data common to all members of a collective instance.
// TODO(b/163171014) Refactor this struct to not be a union of all fields.
struct CollInstanceParams {
  // Identifies all participating graph nodes.
  int32 instance_key = -1;
  CollectiveType type = UNDEFINED_COLLECTIVE;
  DataType data_type = DT_FLOAT;
  TensorShape shape = {0};
  CollImplDetails impl_details;
  string ToString() const;
  CollInstanceParams& operator=(const struct CollInstanceParams& other);
  std::vector<string> devices;  // permuter only

  // For permuter only.
  // permutation[i] is the rank that receives the data sent by rank i.
  // Example: permutation = {2,0,1} means
  //   rank 0 sends to rank 2
  //   rank 1 sends to rank 0
  //   rank 2 sends to rank 1
  std::vector<int> permutation;
};

// Data common to all instance members in the same task.
struct CollTaskParams {
  // True for devices that are local to the process, i.e. no RPC needed.
  std::vector<bool> is_local;
  string ToString() const;
};

// Unique to a single CollectiveOp node.
struct CollectiveParams : public core::RefCounted {
  CollGroupParams group;
  CollInstanceParams instance;
  CollTaskParams task;

  string name = "";        // node name, used only for log or error messages
  int default_rank = -1;   // index of this op's device in group.devices
  bool is_source = false;  // broadcast only
  int source_rank = -1;    // broadcast only
  // Rank of this device in each subdivision permutation.
  std::vector<int> subdiv_rank;
  OpKernel* merge_op = nullptr;  // reduction only
  OpKernel* final_op = nullptr;  // reduction only
  string ToString() const;
};
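// A hedged sketch of the params a ring all-reduce instance might carry
// (values are illustrative; in practice most of these fields are completed
// by the param resolver rather than set by hand):
//
//   CollectiveParams* cp = new CollectiveParams();
//   cp->group.group_key = 7;
//   cp->group.group_size = 4;
//   cp->instance.instance_key = 42;
//   cp->instance.type = REDUCTION_COLLECTIVE;
//   cp->instance.data_type = DT_FLOAT;
//   cp->instance.impl_details.communication_hint = "ring";
//   // ... use cp ...
//   cp->Unref();  // CollectiveParams is RefCounted.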
class CollectiveExecutor;

// Interface that provides resolution of device localities.
class DeviceResolverInterface {
 public:
  virtual ~DeviceResolverInterface() {}

  // Populates *attributes with the DeviceAttributes of the specified device.
  virtual Status GetDeviceAttributes(const string& device,
                                     DeviceAttributes* attributes) = 0;

  // Returns all device attributes of a task.
  virtual Status GetAllDeviceAttributes(
      const string& task, std::vector<DeviceAttributes>* attributes) = 0;

  // Updates device attributes. Returns an error if any device already exists
  // in the DeviceResolver with a different incarnation.
  virtual Status UpdateDeviceAttributes(
      const std::vector<DeviceAttributes>& attributes) = 0;
};

// Interface that provides resolution of shared CollectiveParams fields.
class ParamResolverInterface {
 public:
  virtual ~ParamResolverInterface() {}

  // Called by each collective op at first execution in order to fill out
  // the CollectiveParams structure with data gathered from the full
  // (maybe distributed) collection of peer nodes.
  virtual void CompleteParamsAsync(const DeviceAttributes& device,
                                   CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   const StatusCallback& done) = 0;

  // Completes group_params with data gathered from all devices in the group.
  // This does not complete until all devices in the group have joined.
  virtual void CompleteGroupAsync(const DeviceAttributes& device,
                                  CollGroupParams* group_params,
                                  CancellationManager* cancel_mgr,
                                  const StatusCallback& done) = 0;

  // Used within a distributed implementation to discover/verify data
  // shared across an instance group.
  // Note: this works differently from CompleteGroupAsync as a refactor is in
  // progress.
  virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
                                     CompleteInstanceResponse* response,
                                     CancellationManager* cancel_mgr,
                                     const StatusCallback& done) = 0;

  // Aborts the resolver. Once aborted, the resolver can no longer be used.
  virtual void StartAbort(const Status& s) = 0;
};

// Graphs which utilize Collective Ops in a common instance must
// execute with identical step_ids even if they are disjoint graphs
// run by otherwise independent tasks. This interface supplies
// coordinated step_ids to use in such cases.
class StepSequenceInterface {
 public:
  virtual ~StepSequenceInterface() {}

  // Used with a distributed implementation to coordinate step_id
  // sequences across tasks.
  virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                    GetStepSequenceResponse* response,
                                    const StatusCallback& done) = 0;

  // Refreshes the local per-graph_key step_id sequence from the collective
  // group leader, if applicable.
  virtual void RefreshStepIdSequenceAsync(int64_t graph_key,
                                          const StatusCallback& done) = 0;

  // Returns the step_id that should be used for initiating a new execution
  // on the specified graph. May return the same step_id multiple times if
  // RetireStepId or RefreshStepIdSequenceAsync is not called.
  virtual int64 NextStepId(int64_t graph_key) = 0;

  // Reports that execution of the given step has completed successfully.
  // Should be called immediately after a step completes with OK status,
  // prior to calling NextStepId(). If the step fails, do not call.
  virtual void RetireStepId(int64_t graph_key, int64_t step_id) = 0;
};
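// Typical call pattern (a sketch; error handling elided):
//
//   int64_t step_id = step_sequence->NextStepId(graph_key);
//   // ... execute the graph under step_id ...
//   if (run_status.ok()) {
//     step_sequence->RetireStepId(graph_key, step_id);
//   }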
class NcclCommunicatorInterface;

// Interface that provides access to per-step CollectiveExecutor
// instances and various distributed resolution capabilities.
class CollectiveExecutorMgrInterface : public StepSequenceInterface {
 public:
  virtual ~CollectiveExecutorMgrInterface() {}

  // Returns the step-specific CollectiveExecutor, creating one if it does not
  // already exist. The caller assumes ownership of one Ref on the object.
  virtual CollectiveExecutor* FindOrCreate(int64_t step_id) = 0;

  // If there is a CollectiveExecutor for step_id, removes it from the
  // table.
  virtual void Cleanup(int64_t step_id) = 0;

  virtual ParamResolverInterface* GetParamResolver() const = 0;

  virtual DeviceResolverInterface* GetDeviceResolver() const = 0;

  virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0;
};

// Interface that a Collective Op implementation uses to exchange data
// with peers. Note that data exchange is currently limited to types
// for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric
// types.
class CollectiveRemoteAccess {
 public:
  virtual ~CollectiveRemoteAccess() {}

  virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
                            bool peer_is_local, const string& key,
                            Device* to_device, DeviceContext* to_device_ctx,
                            const AllocatorAttributes& to_alloc_attr,
                            Tensor* to_tensor,
                            const DeviceLocality& client_locality,
                            int dev_to_dev_stream_index,
                            CancellationManager* cancellation_manager,
                            const StatusCallback& done) = 0;

  virtual void PostToPeer(const string& peer_device, const string& peer_task,
                          const string& key, Device* from_device,
                          DeviceContext* from_device_ctx,
                          const AllocatorAttributes& from_alloc_attr,
                          const Tensor* from_tensor,
                          const DeviceLocality& client_locality,
                          CancellationManager* cancellation_manager,
                          const StatusCallback& done) = 0;

  // Checks the health of a collective peer by probing it to see whether it is
  // alive. Note that if a peer has restarted, it is considered a different
  // peer, so CheckPeerHealth fails.
  virtual void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms,
                               const StatusCallback& done) = 0;

  virtual BufRendezvous* buf_rendezvous() = 0;

  virtual void StartAbort(const Status& s) = 0;
};
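// A send and the matching receive rendezvous on the same `key` string. A
// rough sketch of one exchange (most arguments elided; variable names are
// hypothetical):
//
//   // On the producing side:
//   remote_access->PostToPeer(receiver_device, receiver_task, key,
//                             /*from_device=*/..., ..., done);
//   // On the consuming side, using the same key:
//   remote_access->RecvFromPeer(sender_device, sender_task,
//                               /*peer_is_local=*/false, key,
//                               /*to_device=*/..., ..., done);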
// A step-specific object that can execute a collective operation completely
// described by a CollectiveParams object.
class CollectiveExecutor : public core::RefCounted {
 public:
  virtual void StartAbort(const Status& s) {}

  virtual void ExecuteAsync(OpKernelContext* ctx,
                            const CollectiveParams* col_params,
                            const string& exec_key, StatusCallback done) {
    done(errors::Internal(
        "A collective Op has been called in a context in which "
        "a CollectiveExecutor has not been provided."));
  }

  virtual void CompleteParamsAsync(const DeviceAttributes& device,
                                   CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   StatusCallback done) {
    done(errors::Internal(
        "A collective Op has been called in a context in which "
        "a CollectiveExecutor has not been provided."));
  }

  // Runs the potentially-blocking closure/expensive callback.
  virtual void RunClosure(std::function<void()> closure) = 0;

  virtual CollectiveRemoteAccess* remote_access() { return nullptr; }

  // `WaitForDependencies` and `UnblockDependencies` are used for fine-grained
  // control of execution order between collective instances. These functions
  // are intended to be called in the `Run` function of collective
  // implementations, and may be used to make part, or all, of the collective
  // execution ordered with respect to other collective instances.
  //
  // `WaitForDependencies` will block until it is safe to continue the callee's
  // execution, where safety is defined as: ordered with respect to the
  // collective instances defined in the callee's `wait_for` attribute.
  virtual void WaitForDependencies(const CollectiveParams& col_params) {}
  // `UnblockDependencies` unblocks the dependent collective instances by
  // recording that this caller's device has completed the critical portion of
  // the collective execution.
  virtual void UnblockDependencies(const CollectiveParams& col_params) {}

  // Used to designate an invalid group or instance key.
  static int64_t kInvalidId;

  // Lexically scoped handle for Ref.
  class Handle {
   public:
    explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
      if (!inherit_ref) ce->Ref();
    }
    ~Handle() { ce_->Unref(); }
    CollectiveExecutor* get() const { return ce_; }

   private:
    CollectiveExecutor* ce_;
  };

 protected:
  explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
      : cem_(cem) {}

  // For use only by derived classes
  static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
  CollectiveExecutorMgrInterface* cem_;

  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
};

struct CollectiveContext {
  CollectiveExecutor* col_exec;                  // Not owned
  NcclCommunicatorInterface* nccl_communicator;  // Not owned
  const DeviceMgr* dev_mgr;                      // Not owned
  OpKernelContext* op_ctx;                       // Not owned
  OpKernelContext::Params* op_params;            // Not owned
  const CollectiveParams* col_params;            // Not owned
  const string exec_key;
  const int64 step_id;
  const Tensor* input;  // Not owned
  Tensor* output;       // Not owned
  Device* device;       // The device for which this instance labors
  const string device_name;
  DeviceLocality device_locality;

  CollectiveContext(CollectiveExecutor* col_exec,
                    NcclCommunicatorInterface* nccl_communicator,
                    const DeviceMgr* dev_mgr, OpKernelContext* ctx,
                    OpKernelContext::Params* op_params,
                    const CollectiveParams* col_params, const string& exec_key,
                    int64_t step_id, const Tensor* input, Tensor* output);
};

class NcclCommunicatorInterface {
 public:
  virtual ~NcclCommunicatorInterface() = default;

  virtual string GenerateCommunicatorKey() = 0;

  virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
                       StatusCallback done) = 0;

  virtual void StartAbort(const Status& s) = 0;
};
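// Sketch of how a NCCL-backed collective might be enqueued (the wiring here
// is illustrative; in practice the runtime constructs the context):
//
//   auto col_ctx = std::make_shared<CollectiveContext>(
//       col_exec, nccl_communicator, dev_mgr, op_ctx, op_params, col_params,
//       exec_key, step_id, input, output);
//   nccl_communicator->Enqueue(col_ctx, [](const Status& s) {
//     // ... react to completion status ...
//   });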
// Interface of a Collective Op implementation. Each specific CollectiveOp will
// implement this interface and register the implementation via the
// CollectiveRegistry detailed below. See common_runtime/ring_reducer and
// common_runtime/hierarchical_tree_broadcaster for examples.
class CollectiveImplementationInterface : public core::RefCounted {
 public:
  virtual ~CollectiveImplementationInterface() = default;

  // Initializes the portions of `col_params` specific to this
  // implementation. Called exactly once for every Collective instance during
  // the CollectiveParams resolution process when the graph is first executed,
  // at the end of `CompleteInstanceLocal()`.
  // NOTE(ayushd): This is effectively a static function because it modifies
  // the `col_params` passed in and should not manipulate any data members.
  // However, because it is virtual and needs to be implemented by every
  // derived class, we do not mark it as static.
  virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;

  // Prepares the CollectiveContext for executing this
  // CollectiveImplementation. Called from CollectiveExecutor right before
  // calling Run(). The CollectiveContext passed in must outlive the
  // CollectiveImplementation object.
  virtual Status InitializeCollectiveContext(
      std::shared_ptr<CollectiveContext> col_ctx) = 0;

  // Processes and moves data according to the logic of this Collective
  // implementation. Relies on appropriate initialization of op-specific
  // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
  // context initialization in InitializeCollectiveContext().
  virtual void Run(StatusCallback done) = 0;
};

// Static-methods-only class for registering and looking up collective
// implementations.
class CollectiveRegistry {
 public:
  using Factory = std::function<CollectiveImplementationInterface*()>;
  // Looks up a previously registered CollectiveImplementation under
  // `collective_name`. If found, creates an instance of the implementation
  // and assigns it to `implementation`.
  static Status Lookup(const string& collective_name,
                       CollectiveImplementationInterface** implementation);

  // Looks up a previously registered CollectiveImplementation under
  // `collective_name`. If found, returns the static instance of this
  // implementation via `implementation`. This instance should only be used to
  // call InitializeCollectiveParams.
  static Status LookupParamResolverInstance(
      const string& collective_name,
      CollectiveImplementationInterface** implementation);

  // Returns all registered collective implementations.
  static void GetAll(
      std::vector<CollectiveImplementationInterface*>* implementations);

 private:
  friend class CollectiveRegistration;
  // Registers a CollectiveImplementation with name `collective_name` and
  // factory `factory`. The latter is a function used to create instances of
  // the CollectiveImplementation. Also creates a static instance of the
  // implementation - this instance is used during param resolution and should
  // only be used to call InitializeCollectiveParams.
  static Status Register(const string& collective_name, Factory factory);

  static Status LookupHelper(const string& collective_name,
                             CollectiveImplementationInterface** implementation,
                             bool param_resolver);
};
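// A typical lookup-and-run flow (a sketch only; "RingReduce" is shown as an
// illustrative collective name, and ref-count management is elided):
//
//   CollectiveImplementationInterface* impl = nullptr;
//   TF_RETURN_IF_ERROR(CollectiveRegistry::Lookup("RingReduce", &impl));
//   TF_RETURN_IF_ERROR(impl->InitializeCollectiveContext(col_ctx));
//   impl->Run([](const Status& s) { /* handle completion */ });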
// Class used to call CollectiveRegistry::Register. This should only be used
// to create a global static object.
class CollectiveRegistration {
 public:
  CollectiveRegistration(const string& collective_name,
                         CollectiveRegistry::Factory factory) {
    TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
  }
};

#define REGISTER_COLLECTIVE(name, implementation)             \
  static CollectiveRegistration register_##name##_collective( \
      #name, []() { return new implementation; });

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_