1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ 16 #define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ 17 18 #include <string> 19 #include <vector> 20 21 #include "tensorflow/core/framework/device_attributes.pb.h" 22 #include "tensorflow/core/framework/device_base.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/lib/core/refcount.h" 25 #include "tensorflow/core/lib/core/status.h" 26 27 namespace tensorflow { 28 29 class BufRendezvous; 30 class CancellationManager; 31 class CompleteGroupRequest; 32 class CompleteGroupResponse; 33 class CompleteInstanceRequest; 34 class CompleteInstanceResponse; 35 class Device; 36 class DeviceMgr; 37 class GetStepSequenceRequest; 38 class GetStepSequenceResponse; 39 class NcclManager; 40 class Tensor; 41 42 // Types of supported collective operations. 43 enum CollectiveType { 44 REDUCTION_COLLECTIVE = 0, 45 BROADCAST_COLLECTIVE, 46 GATHER_COLLECTIVE, 47 PERMUTE_COLLECTIVE, 48 UNDEFINED_COLLECTIVE, 49 }; 50 51 // Some collective op implementations require runtime group configuration from 52 // the OpKernel. Currently, this struct is used to set communicator key for 53 // NCCL-based collective implementation. 54 struct CollGroupRuntimeDetails { 55 string communicator_key; // for communicator-based techniques e.g. NCCL 56 string ToString() const; 57 }; 58 59 // Data common to all members of a device group. 60 // All members share the same device set but its order is 61 // particular to an instance so it is stored there. 62 struct CollGroupParams { 63 int32 group_key; 64 int32 group_size; 65 DeviceType device_type; 66 // Fully qualified name of device for each member, in default rank order. 67 std::vector<string> device_names; 68 // Task name prefix of corresponding device name. 69 std::vector<string> task_names; 70 // True if every task has the same number of devices. 71 bool same_num_devices_per_task = false; 72 // Task -> number of devices on that task. 73 std::unordered_map<string, int32> num_devices_per_task; 74 // If passed in to GPUOptions in ConfigProto, defines a good ring order for 75 // GPUs. Assumes same GPU configuration at each worker. 76 string gpu_ring_order = ""; 77 int32 num_tasks; // number of distinct tasks in group 78 CollGroupRuntimeDetails runtime_details; 79 string ToString() const; CollGroupParamsCollGroupParams80 CollGroupParams() 81 : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {} 82 }; 83 84 // The best implementation of a collective op depends on many factors 85 // including the number of devices involved, the topology of 86 // interconnects between them and the sizes of inputs. This structure 87 // is used in generating and representing data movement choreography 88 // for each specific algorithm, hence it does not have a single, fixed 89 // interpretation. On first execution the runtime will update this 90 // structure with decisions that will guide all subsequent executions. 91 struct CollImplDetails { 92 string collective_name; 93 std::vector<std::vector<int>> subdiv_permutations; 94 std::vector<int> subdiv_offsets; 95 std::vector<int> subdiv_source_rank; // rank of source in each subdiv 96 std::vector<int32> 97 dependencies; // collective instances on which this node depends 98 string communication_hint; // user-supplied hint for implementation choice, 99 // e.g. ring or nccl 100 float timeout_seconds; // If non zero, set a completion timeout for the 101 // collective op to detect staleness. 102 }; 103 104 // Data common to all members of a collective instance. 105 // TODO(b/163171014) Refactor this struct to not be a union of all fields. 106 struct CollInstanceParams { 107 // Identifies all participating graph nodes. 108 int32 instance_key = -1; 109 CollectiveType type = UNDEFINED_COLLECTIVE; 110 DataType data_type = DT_FLOAT; 111 TensorShape shape = {0}; 112 CollImplDetails impl_details; 113 string ToString() const; 114 CollInstanceParams& operator=(const struct CollInstanceParams& other); 115 std::vector<string> devices; // permuter only 116 117 // For permuter only 118 // Each rank in the permutation is a receiver. 119 // Indices of each rank means a sender to that rank. 120 // Example: permutation = {2,0,1} means 121 // rank 0 sends to rank 2 122 // rank 1 sends to rank 0 123 // rank 2 sends to rank 1 124 std::vector<int> permutation; 125 }; 126 127 // Data common to all instance members in the same task. 128 struct CollTaskParams { 129 // True for devices that are local to the process, i.e. no RPC needed. 130 std::vector<bool> is_local; 131 string ToString() const; 132 }; 133 134 // Unique to a single CollectiveOp node. 135 struct CollectiveParams : public core::RefCounted { 136 CollGroupParams group; 137 CollInstanceParams instance; 138 CollTaskParams task; 139 140 string name = ""; // node name used only for log or error messages 141 int default_rank = -1; // index of this op within device_names 142 bool is_source = false; // broadcast only 143 int source_rank = -1; // broadcast only 144 // Rank of this device in each subdivision permutation. 145 std::vector<int> subdiv_rank; 146 OpKernel* merge_op = nullptr; // reduction only 147 OpKernel* final_op = nullptr; // reduction only 148 string ToString() const; 149 }; 150 151 class CollectiveExecutor; 152 153 // Interface that provides resolution of device localities. 154 class DeviceResolverInterface { 155 public: ~DeviceResolverInterface()156 virtual ~DeviceResolverInterface() {} 157 158 // Populates *attributes with the DeviceAttributes of the specified device. 159 virtual Status GetDeviceAttributes(const string& device, 160 DeviceAttributes* attributes) = 0; 161 162 // Returns all device attributes of a task. 163 virtual Status GetAllDeviceAttributes( 164 const string& task, std::vector<DeviceAttributes>* attributes) = 0; 165 166 // Updates device attributes. It returns error if any device already 167 // exists in the DeviceResolver and has a different incarnation. 168 virtual Status UpdateDeviceAttributes( 169 const std::vector<DeviceAttributes>& attributes) = 0; 170 }; 171 172 // Interface that provides resolution of shared CollectiveParams fields. 173 class ParamResolverInterface { 174 public: ~ParamResolverInterface()175 virtual ~ParamResolverInterface() {} 176 177 // Called by each collective op at first execution in order to fill out 178 // the CollectiveParams structure with data gathered from the full 179 // (maybe distributed) collection of peer nodes. 180 virtual void CompleteParamsAsync(const DeviceAttributes& device, 181 CollectiveParams* cp, 182 CancellationManager* cancel_mgr, 183 const StatusCallback& done) = 0; 184 185 // Used within a distributed implementation to discover/verify 186 // data shared across a device group. 187 virtual void CompleteGroupAsync(const CompleteGroupRequest* request, 188 CompleteGroupResponse* response, 189 CancellationManager* cancel_mgr, 190 const StatusCallback& done) = 0; 191 192 // Used within a distributed implementation to discover/verify data 193 // shared across an instance group. 194 virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request, 195 CompleteInstanceResponse* response, 196 CancellationManager* cancel_mgr, 197 const StatusCallback& done) = 0; 198 199 // Aborts the resolver. After abortion the resolver can no longer be used. 200 virtual void StartAbort(const Status& s) = 0; 201 }; 202 203 // Graphs which utilize Collective Ops in a common instance must 204 // execute with identical step_ids even if they are disjoint graphs 205 // run by otherwise independent tasks. This interface supplies 206 // coordinated step_ids to use in such cases. 207 class StepSequenceInterface { 208 public: ~StepSequenceInterface()209 virtual ~StepSequenceInterface() {} 210 211 // Used with a distributed implementation to coordinate step_id 212 // sequences across tasks. 213 virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request, 214 GetStepSequenceResponse* response, 215 const StatusCallback& done) = 0; 216 217 // Refresh the local per-graph_key step_id sequence from collective 218 // group leader, if applicable. 219 virtual void RefreshStepIdSequenceAsync(int64 graph_key, 220 const StatusCallback& done) = 0; 221 222 // Returns the step_id that should be used for initiating a new execution 223 // on the specified graph. May return the same step_id multiple times if 224 // RetireStepId or RefreshStepIdReservation is not called. 225 virtual int64 NextStepId(int64 graph_key) = 0; 226 227 // Reports that execution of the given step has completed successfully. 228 // Should be called immediately after a step completes with OK status, 229 // prior to calling NextStepId(). If the step fails, don't call. 230 virtual void RetireStepId(int64 graph_key, int64 step_id) = 0; 231 }; 232 233 class NcclCommunicatorInterface; 234 235 // Interface that provides access to per-step CollectiveExecutor 236 // instances and various distributed resolution capabilities. 237 class CollectiveExecutorMgrInterface : public StepSequenceInterface { 238 public: ~CollectiveExecutorMgrInterface()239 virtual ~CollectiveExecutorMgrInterface() {} 240 241 // Returns the step-specific CollectiveExecutor, creating if one does not 242 // already exist. The caller assumes ownership of one Ref on the object. 243 virtual CollectiveExecutor* FindOrCreate(int64 step_id) = 0; 244 245 // If there is a CollectiveExecutor for step_id, remove it from the 246 // table. 247 virtual void Cleanup(int64 step_id) = 0; 248 249 virtual ParamResolverInterface* GetParamResolver() const = 0; 250 251 virtual DeviceResolverInterface* GetDeviceResolver() const = 0; 252 253 virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0; 254 }; 255 256 // Interface that a Collective Op implementation uses to exchange data 257 // with peers. Note that data exchange is currently limited to types 258 // for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric 259 // types. 260 class CollectiveRemoteAccess { 261 public: ~CollectiveRemoteAccess()262 virtual ~CollectiveRemoteAccess() {} 263 264 virtual void RecvFromPeer(const string& peer_device, const string& peer_task, 265 bool peer_is_local, const string& key, 266 Device* to_device, DeviceContext* to_device_ctx, 267 const AllocatorAttributes& to_alloc_attr, 268 Tensor* to_tensor, 269 const DeviceLocality& client_locality, 270 int dev_to_dev_stream_index, 271 CancellationManager* cancellation_manager, 272 const StatusCallback& done) = 0; 273 274 virtual void PostToPeer(const string& peer_device, const string& peer_task, 275 const string& key, Device* from_device, 276 DeviceContext* from_device_ctx, 277 const AllocatorAttributes& from_alloc_attr, 278 const Tensor* from_tensor, 279 const DeviceLocality& client_locality, 280 CancellationManager* cancellation_manager, 281 const StatusCallback& done) = 0; 282 283 // Checks the health of a collective peer. It probes the peer to see if it is 284 // alive. Note that if a peer has restarted, it's considered a different one, 285 // so CheckPeerHealth fails. 286 virtual void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms, 287 const StatusCallback& done) = 0; 288 289 virtual BufRendezvous* buf_rendezvous() = 0; 290 291 virtual void StartAbort(const Status& s) = 0; 292 }; 293 294 // A step-specific object that can execute a collective operation completely 295 // described by a CollectiveParams object. 296 class CollectiveExecutor : public core::RefCounted { 297 public: StartAbort(const Status & s)298 virtual void StartAbort(const Status& s) {} 299 ExecuteAsync(OpKernelContext * ctx,const CollectiveParams * col_params,const string & exec_key,StatusCallback done)300 virtual void ExecuteAsync(OpKernelContext* ctx, 301 const CollectiveParams* col_params, 302 const string& exec_key, StatusCallback done) { 303 done(errors::Internal( 304 "A collective Op has been called in a context in which " 305 "a CollectiveExecutor has not been provided.")); 306 } 307 CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,StatusCallback done)308 virtual void CompleteParamsAsync(const DeviceAttributes& device, 309 CollectiveParams* cp, 310 CancellationManager* cancel_mgr, 311 StatusCallback done) { 312 done(errors::Internal( 313 "A collective Op has been called in a context in which " 314 "a CollectiveExecutor has not been provided.")); 315 } 316 317 // Runs the potentially-blocking closure/expensive callback. 318 virtual void RunClosure(std::function<void()> closure) = 0; 319 remote_access()320 virtual CollectiveRemoteAccess* remote_access() { return nullptr; } 321 322 // `WaitForDependencies` and `Launched` are used for fine-grained control of 323 // execution order between collective instances. These functions are intended 324 // to be called in `Run` function of collective implementations, and may be 325 // used to make part, or whole, of the collective execution ordered with 326 // respect to other collective instances. 327 // 328 // `WaitForDependencies` will block until it is safe to continue the callee's 329 // execution, where safety is defined as: ordered with respect to the 330 // collective instances defined in the callee's `wait_for` attribute. WaitForDependencies(const CollectiveParams & col_params)331 virtual void WaitForDependencies(const CollectiveParams& col_params) {} 332 // `UnblockDependencies` unblocks the dependent collective instances by 333 // recording that this caller's device has completed the critical portion of 334 // the collective execution. UnblockDependencies(const CollectiveParams & col_params)335 virtual void UnblockDependencies(const CollectiveParams& col_params) {} 336 337 // Used to designate an invalid group or instance key. 338 static int64 kInvalidId; 339 340 // Lexically scoped handle for Ref. 341 class Handle { 342 public: Handle(CollectiveExecutor * ce,bool inherit_ref)343 explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) { 344 if (!inherit_ref) ce->Ref(); 345 } ~Handle()346 ~Handle() { ce_->Unref(); } get()347 CollectiveExecutor* get() const { return ce_; } 348 349 private: 350 CollectiveExecutor* ce_; 351 }; 352 353 protected: CollectiveExecutor(CollectiveExecutorMgrInterface * cem)354 explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem) 355 : cem_(cem) {} 356 357 // For use only by derived classes 358 static OpKernelContext::Params* CtxParams(OpKernelContext* ctx); 359 CollectiveExecutorMgrInterface* cem_; 360 361 TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor); 362 }; 363 364 struct CollectiveContext { 365 CollectiveExecutor* col_exec; // Not owned 366 NcclCommunicatorInterface* nccl_communicator; // Not owned 367 const DeviceMgr* dev_mgr; // Not owned 368 OpKernelContext* op_ctx; // Not owned 369 OpKernelContext::Params* op_params; // Not owned 370 const CollectiveParams* col_params; // Not owned 371 const string exec_key; 372 const int64 step_id; 373 const Tensor* input; // Not owned 374 Tensor* output; // Not owned 375 Device* device; // The device for which this instance labors 376 const string device_name; 377 DeviceLocality device_locality; 378 379 CollectiveContext(CollectiveExecutor* col_exec, 380 NcclCommunicatorInterface* nccl_communicator, 381 const DeviceMgr* dev_mgr, OpKernelContext* ctx, 382 OpKernelContext::Params* op_params, 383 const CollectiveParams* col_params, const string& exec_key, 384 int64 step_id, const Tensor* input, Tensor* output); 385 }; 386 387 class NcclCommunicatorInterface { 388 public: 389 virtual ~NcclCommunicatorInterface() = default; 390 391 virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx, 392 StatusCallback done) = 0; 393 394 virtual void StartAbort(const Status& s) = 0; 395 }; 396 397 // Interface of a Collective Op implementation. Each specific CollectiveOp will 398 // implement this interface and register the implementation via the 399 // CollectiveRegistry detailed below. See common_runtime/ring_reducer and 400 // common_runtime/hierarchical_tree_broadcaster for examples. 401 class CollectiveImplementationInterface : public core::RefCounted { 402 public: 403 virtual ~CollectiveImplementationInterface() = default; 404 405 // Initializes the portions of `col_params` specific to this 406 // implementation. Called exactly once for every Collective instance during 407 // the CollectiveParams resolution process when the graph is first executed, 408 // at the end of `CompleteInstanceLocal()`. 409 // NOTE(ayushd): This is effectively a static function because it modifies the 410 // `col_params` passed in and should not manipulate any data members. However 411 // because it is virtual and needs to be implemented by every derived class we 412 // do not mark it as static. 413 virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0; 414 415 // Prepares the CollectiveContext for executing this CollectiveImplementation. 416 // Called from CollectiveExecutor right before calling Run(). The 417 // CollectiveContext passed in must outlive the CollectiveImplementation 418 // object. 419 virtual Status InitializeCollectiveContext( 420 std::shared_ptr<CollectiveContext> col_ctx) = 0; 421 422 // Performs collective implementation specific group initialization. The 423 // intention is to do group-specific initialization of runtime details for the 424 // collective implementation. Currently used only to set `communicator_key` 425 // in techniques which use a communicator for distributed collectives (NCCL). 426 virtual Status InitializeCollectiveGroupRuntimeDetails( 427 CollGroupRuntimeDetails* col_group_runtime_details) = 0; 428 429 // Processes and moves data according to the logic of this Collective 430 // implementation. Relies on appropriate initialization of op-specific 431 // CollectiveParams in InitializeCollectiveParams(), as well as appropriate 432 // context initialization in InitializeCollectiveContext(). 433 virtual void Run(StatusCallback done) = 0; 434 }; 435 436 // Static-methods only class for registering and looking up collective 437 // implementations. 438 class CollectiveRegistry { 439 public: 440 using Factory = std::function<CollectiveImplementationInterface*()>; 441 // Looks up a previously registered CollectiveImplementation under 442 // `collective_name`. If found, creates an instance of the implementation and 443 // assign to `implementation`. 444 static Status Lookup(const string& collective_name, 445 CollectiveImplementationInterface** implementation); 446 447 // Looks up a previously registered CollectiveImplementation under 448 // `collective_name`. If found, returns the static instance of this 449 // implementation via `implementation`. This instance should only be used to 450 // call InitializateCollectiveParams. 451 static Status LookupParamResolverInstance( 452 const string& collective_name, 453 CollectiveImplementationInterface** implementation); 454 455 // Returns all registered collective implementations. 456 static void GetAll( 457 std::vector<CollectiveImplementationInterface*>* implementations); 458 459 private: 460 friend class CollectiveRegistration; 461 // Registers a CollectiveImplementation with name `collective_name` and 462 // factory `factory`. The latter is a function used to create instances of 463 // the CollectiveImplementation. Also creates a static instance of the 464 // implementation - this instance is used during param resolution and should 465 // only be used to call InitializeCollectiveParams. 466 static Status Register(const string& collective_name, Factory factory); 467 468 static Status LookupHelper(const string& collective_name, 469 CollectiveImplementationInterface** implementation, 470 bool param_resolver); 471 }; 472 473 // Class used to call CollectiveRegistry::Register. This should only be used to 474 // create a global static object. 475 class CollectiveRegistration { 476 public: CollectiveRegistration(const string & collective_name,CollectiveRegistry::Factory factory)477 CollectiveRegistration(const string& collective_name, 478 CollectiveRegistry::Factory factory) { 479 TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory)); 480 } 481 }; 482 483 #define REGISTER_COLLECTIVE(name, implementation) \ 484 static CollectiveRegistration register_##name##_collective( \ 485 #name, []() { return new implementation; }); 486 487 } // namespace tensorflow 488 489 #endif // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ 490