/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
#define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_

#include <string>
#include <vector>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

class BufRendezvous;
class CancellationManager;
class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
class Device;
class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class Tensor;

// Types of supported collective operations.
enum CollectiveType {
  REDUCTION_COLLECTIVE = 0,
  BROADCAST_COLLECTIVE,
  GATHER_COLLECTIVE,
  UNDEFINED_COLLECTIVE,
};

// Some collective op implementations require runtime group configuration from
// the OpKernel. Currently this struct is used to set the communicator key for
// the NCCL-based collective implementation.
struct CollGroupRuntimeDetails {
  string communicator_key;  // for communicator-based techniques e.g. NCCL
  string ToString() const;
};

// Data common to all members of a device group.
// All members share the same device set, but the order of that set is
// particular to each instance, so it is stored there.
struct CollGroupParams {
  int32 group_key;
  int32 group_size;
  DeviceType device_type;
  int32 num_tasks;  // number of distinct tasks in group
  CollGroupRuntimeDetails runtime_details;
  string ToString() const;
  CollGroupParams()
      : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {}
};

// The best implementation of a collective op depends on many factors
// including the number of devices involved, the topology of
// interconnects between them and the sizes of inputs. This structure
// is used in generating and representing data movement choreography
// for each specific algorithm, hence it does not have a single, fixed
// interpretation. On first execution the runtime will update this
// structure with decisions that will guide all subsequent executions.
struct CollImplDetails {
  string collective_name;
  std::vector<std::vector<int>> subdiv_permutations;
  std::vector<int> subdiv_offsets;
  std::vector<int> subdiv_source_rank;  // rank of source in each subdiv
  std::vector<int32>
      dependencies;  // collective instances on which this node depends
  string communication_hint;  // user-supplied hint for implementation choice,
                              // e.g. ring or nccl
};
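// For example, a ring algorithm running over four devices with two
// subdivisions might (schematically, with illustrative values only) record
// one ring order per subdivision:
//
//   CollImplDetails details;
//   details.collective_name = "RingReduce";
//   details.subdiv_offsets = {0, 2};               // rotation of each ring
//   details.subdiv_permutations = {{0, 1, 2, 3},   // subdiv 0 ring order
//                                  {2, 3, 0, 1}};  // subdiv 1 ring order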
// Data common to all members of a collective instance.
struct CollInstanceParams {
  // Identifies all participating graph nodes.
  int32 instance_key = -1;
  CollectiveType type = UNDEFINED_COLLECTIVE;
  DataType data_type = DT_FLOAT;
  TensorShape shape = {0};
  // Fully qualified name of device for each member, in default rank order.
  std::vector<string> device_names;
  // Task name prefix of corresponding device name.
  std::vector<string> task_names;
  // True if every task has the same number of devices.
  bool same_num_devices_per_task = false;
  // Task -> number of devices on that task.
  std::unordered_map<string, int32> num_devices_per_task;
  // If passed in to GPUOptions in ConfigProto, defines a good ring order for
  // GPUs. Assumes the same GPU configuration at each worker.
  string gpu_ring_order = "";
  CollImplDetails impl_details;
  string ToString() const;
  CollInstanceParams& operator=(const struct CollInstanceParams& other);
};

// Data common to all instance members in the same task.
struct CollTaskParams {
  // True for devices that are local to the process, i.e. no RPC needed.
  std::vector<bool> is_local;
  string ToString() const;
};

// Unique to a single CollectiveOp node.
struct CollectiveParams {
  CollGroupParams group;
  CollInstanceParams instance;
  CollTaskParams task;

  string name = "";        // node name used only for log or error messages
  int default_rank = -1;   // index of this op within device_names
  bool is_source = false;  // broadcast only
  int source_rank = -1;    // broadcast only
  // Rank of this device in each subdivision permutation.
  std::vector<int> subdiv_rank;
  std::unique_ptr<OpKernel> merge_op;  // reduction only
  std::unique_ptr<OpKernel> final_op;  // reduction only
  string ToString() const;
};
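// As a concrete (purely illustrative) sketch, a four-device all-reduce
// spanning two tasks might resolve to parameters like:
//
//   CollectiveParams col_params;
//   col_params.group.group_key = 1;
//   col_params.group.group_size = 4;
//   col_params.group.device_type = DEVICE_GPU;
//   col_params.group.num_tasks = 2;
//   col_params.instance.instance_key = 17;
//   col_params.instance.type = REDUCTION_COLLECTIVE;
//   col_params.instance.data_type = DT_FLOAT;
//   col_params.instance.shape = TensorShape({1024});
//   col_params.instance.impl_details.collective_name = "RingReduce";
//
// The remaining fields (device_names, task_names, default_rank, etc.) are
// filled in during parameter resolution; see ParamResolverInterface below.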
class CollectiveExecutor;

// Interface that provides resolution of device localities.
class DeviceResolverInterface {
 public:
  virtual ~DeviceResolverInterface() {}

  // Collects DeviceAttributes protobufs for each of the specified devices.
  virtual void GetAllDeviceAttributesAsync(
      const std::vector<string>& devices, const std::vector<string>& tasks,
      std::vector<DeviceAttributes>* attributes,
      const StatusCallback& done) = 0;

  // Populates *attributes with the DeviceAttributes of the specified
  // device.
  virtual void GetDeviceAttributesAsync(const string& device,
                                        const string& task,
                                        DeviceAttributes* attributes,
                                        const StatusCallback& done) = 0;

  // Clears the cache of device data belonging to the specified task.
  virtual void ClearTask(const string& task) = 0;

  // Clears the cache of all device data.
  virtual void ClearCache() = 0;
};

// Interface that provides resolution of shared CollectiveParams fields.
class ParamResolverInterface {
 public:
  virtual ~ParamResolverInterface() {}

  // Called by each collective op at first execution in order to fill out
  // the CollectiveParams structure with data gathered from the full
  // (maybe distributed) collection of peer nodes.
  virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   const StatusCallback& done) = 0;

  // Used within a distributed implementation to discover/verify data
  // shared across a device group.
  virtual void CompleteGroupAsync(const CompleteGroupRequest* request,
                                  CompleteGroupResponse* response,
                                  CancellationManager* cancel_mgr,
                                  const StatusCallback& done) = 0;

  // Used within a distributed implementation to discover/verify data
  // shared across an instance group.
  virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
                                     CompleteInstanceResponse* response,
                                     CancellationManager* cancel_mgr,
                                     const StatusCallback& done) = 0;
};

// Graphs that use Collective Ops in a common instance must execute with
// identical step_ids, even if they are disjoint graphs run by otherwise
// independent tasks. This interface supplies coordinated step_ids to use
// in such cases.
class StepSequenceInterface {
 public:
  virtual ~StepSequenceInterface() {}

  // Used with a distributed implementation to coordinate step_id
  // sequences across tasks.
  virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                    GetStepSequenceResponse* response,
                                    const StatusCallback& done) = 0;

  // Refreshes the local per-graph_key step_id sequence from the collective
  // group leader, if applicable.
  virtual void RefreshStepIdSequenceAsync(int64 graph_key,
                                          const StatusCallback& done) = 0;

  // Returns the step_id that should be used for initiating a new execution
  // on the specified graph. May return the same step_id multiple times if
  // RetireStepId or RefreshStepIdSequenceAsync is not called.
  virtual int64 NextStepId(int64 graph_key) = 0;

  // Reports that execution of the given step has completed successfully.
  // Should be called immediately after a step completes with OK status,
  // prior to calling NextStepId(). If the step fails, don't call.
  virtual void RetireStepId(int64 graph_key, int64 step_id) = 0;
};

// Interface that provides access to per-step CollectiveExecutor
// instances and various distributed resolution capabilities.
class CollectiveExecutorMgrInterface : public StepSequenceInterface {
 public:
  virtual ~CollectiveExecutorMgrInterface() {}

  // Returns the step-specific CollectiveExecutor, creating it if one does
  // not already exist. The caller assumes ownership of one Ref on the
  // object.
  virtual CollectiveExecutor* FindOrCreate(int64 step_id) = 0;

  // If there is a CollectiveExecutor for step_id, removes it from the
  // table.
  virtual void Cleanup(int64 step_id) = 0;

  virtual ParamResolverInterface* GetParamResolver() const = 0;

  virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
};
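// A minimal sketch (illustrative only, assuming `cem` points to some
// CollectiveExecutorMgrInterface) of the per-step protocol implied by the
// two interfaces above:
//
//   int64 step_id = cem->NextStepId(graph_key);
//   CollectiveExecutor* ce = cem->FindOrCreate(step_id);  // caller owns a Ref
//   // ... run the step's collectives using ce ...
//   ce->Unref();
//   cem->RetireStepId(graph_key, step_id);  // only after an OK status
//   cem->Cleanup(step_id);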
// Interface that a Collective Op implementation uses to exchange data
// with peers. Note that data exchange is currently limited to types
// for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric
// types.
class PeerAccessInterface {
 public:
  virtual ~PeerAccessInterface() {}

  virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
                            bool peer_is_local, const string& key,
                            Device* to_device, DeviceContext* to_device_ctx,
                            const AllocatorAttributes& to_alloc_attr,
                            Tensor* to_tensor,
                            const DeviceLocality& client_locality,
                            int dev_to_dev_stream_index,
                            const StatusCallback& done) = 0;

  virtual void PostToPeer(const string& peer_device, const string& peer_task,
                          const string& key, Device* from_device,
                          DeviceContext* from_device_ctx,
                          const AllocatorAttributes& from_alloc_attr,
                          const Tensor* from_tensor,
                          const DeviceLocality& client_locality,
                          const StatusCallback& done) = 0;

  // Runs a closure that may block or be expensive.
  virtual void RunClosure(std::function<void()> closure) = 0;
};

class PerStepCollectiveRemoteAccess;
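// The two calls above are paired through `key`: a sender posts a tensor
// under a key and the matching receiver recovers it under the same key.
// Schematically (illustrative variable names only; in practice keys are
// derived from the exec_key plus rank information so both sides agree):
//
//   // On the producing side:
//   peer_access->PostToPeer(dst_device, dst_task, key, src_dev, src_dev_ctx,
//                           alloc_attr, &tensor, locality, done);
//   // On the consuming side:
//   peer_access->RecvFromPeer(src_device, src_task, /*peer_is_local=*/true,
//                             key, dst_dev, dst_dev_ctx, alloc_attr, &tensor,
//                             locality, /*dev_to_dev_stream_index=*/0, done);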
// A step-specific object that can execute a collective operation completely
// described by a CollectiveParams object.
class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted {
 public:
  virtual void StartAbort(const Status& s) {}

  virtual void ExecuteAsync(OpKernelContext* ctx,
                            const CollectiveParams& col_params,
                            const string& exec_key, StatusCallback done) {
    done(errors::Internal(
        "A collective Op has been called in a context in which "
        "a CollectiveExecutor has not been provided."));
  }

  virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   StatusCallback done) {
    done(errors::Internal(
        "A collective Op has been called in a context in which "
        "a CollectiveExecutor has not been provided."));
  }

  virtual PerStepCollectiveRemoteAccess* remote_access() { return nullptr; }

  // `WaitForDependencies` and `UnblockDependencies` are used for fine-grained
  // control of execution order between collective instances. These functions
  // are intended to be called from the `Run` function of collective
  // implementations, and may be used to make part, or all, of the collective
  // execution ordered with respect to other collective instances.
  //
  // `WaitForDependencies` will block until it is safe to continue the
  // callee's execution, where safe means ordered with respect to the
  // collective instances named in the callee's `wait_for` attribute.
  virtual void WaitForDependencies(const CollectiveParams& col_params) {}
  // `UnblockDependencies` unblocks the dependent collective instances by
  // recording that this caller's device has completed the critical portion
  // of the collective execution.
  virtual void UnblockDependencies(const CollectiveParams& col_params) {}

  // Used to designate an invalid group or instance key.
  static int64 kInvalidId;

  // Lexically scoped handle for Ref.
  class Handle {
   public:
    explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
      if (!inherit_ref) ce->Ref();
    }
    ~Handle() { ce_->Unref(); }
    CollectiveExecutor* get() const { return ce_; }

   private:
    CollectiveExecutor* ce_;
  };

 protected:
  explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
      : cem_(cem) {}

  // For use only by derived classes.
  static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
  CollectiveExecutorMgrInterface* cem_;

  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
};

// Interface of a helper object that provides a CollectiveExecutor with
// all of the remote access it needs.
class CollectiveRemoteAccess : public PeerAccessInterface,
                               public DeviceResolverInterface {
 public:
  virtual ~CollectiveRemoteAccess() {}

  virtual BufRendezvous* buf_rendezvous() = 0;
};

// A per-step version of CollectiveRemoteAccess that cleans up outstanding
// communications in case step execution is abandoned.
class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess {
 public:
  virtual ~PerStepCollectiveRemoteAccess() {}
  virtual void StartAbort(const Status& s) = 0;
};

class CollectiveContext {
 public:
  CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
                    OpKernelContext* ctx, OpKernelContext::Params* op_params,
                    const CollectiveParams& col_params, const string& exec_key,
                    int64 step_id, const Tensor* input, Tensor* output);

  virtual ~CollectiveContext() = default;

  CollectiveExecutor* col_exec;        // Not owned
  const DeviceMgr* dev_mgr;            // Not owned
  OpKernelContext* op_ctx;             // Not owned
  OpKernelContext::Params* op_params;  // Not owned
  const CollectiveParams& col_params;
  const string exec_key;
  const int64 step_id;
  const Tensor* input;  // Not owned
  Tensor* output;       // Not owned
  Device* device;       // The device for which this instance labors
  const string device_name;
  DeviceLocality device_locality;
};

// Interface of a Collective Op implementation. Each specific CollectiveOp
// will implement this interface and register the implementation via the
// CollectiveRegistry detailed below. See common_runtime/ring_reducer and
// common_runtime/hierarchical_tree_broadcaster for examples.
class CollectiveImplementationInterface {
 public:
  virtual ~CollectiveImplementationInterface() = default;

  // Initializes the portions of `col_params` specific to this
  // implementation. Called exactly once for every Collective instance during
  // the CollectiveParams resolution process when the graph is first executed,
  // at the end of `CompleteInstanceLocal()`.
  // NOTE(ayushd): This is effectively a static function because it modifies
  // the `col_params` passed in and should not manipulate any data members.
  // However, because it is virtual and needs to be implemented by every
  // derived class, we do not mark it as static.
  virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;

  // Prepares the CollectiveContext for executing this
  // CollectiveImplementation. Called from CollectiveExecutor right before
  // calling Run(). The CollectiveContext passed in must outlive the
  // CollectiveImplementation object.
  virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;

  // Performs collective-implementation-specific group initialization. The
  // intention is to do group-specific initialization of runtime details for
  // the collective implementation. Currently used only to set
  // `communicator_key` in techniques which use a communicator for distributed
  // collectives (NCCL).
  virtual Status InitializeCollectiveGroupRuntimeDetails(
      CollGroupRuntimeDetails* col_group_runtime_details) = 0;

  // Processes and moves data according to the logic of this Collective
  // implementation. Relies on appropriate initialization of op-specific
  // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
  // context initialization in InitializeCollectiveContext().
  virtual void Run(StatusCallback done) = 0;
};
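// A minimal sketch of a conforming implementation (hypothetical names; see
// common_runtime/ring_reducer for a real one). It would typically be
// registered with the REGISTER_COLLECTIVE macro defined below:
//
//   class MyAllReduce : public CollectiveImplementationInterface {
//    public:
//     Status InitializeCollectiveParams(CollectiveParams* cp) override {
//       // Fill in cp->instance.impl_details for this algorithm.
//       return Status::OK();
//     }
//     Status InitializeCollectiveContext(CollectiveContext* col_ctx) override {
//       col_ctx_ = col_ctx;  // must outlive this object
//       return Status::OK();
//     }
//     Status InitializeCollectiveGroupRuntimeDetails(
//         CollGroupRuntimeDetails* details) override {
//       return Status::OK();  // nothing group-specific in this sketch
//     }
//     void Run(StatusCallback done) override {
//       // Exchange data with peers via col_ctx_->col_exec, then signal
//       // completion exactly once.
//       done(Status::OK());
//     }
//
//    private:
//     CollectiveContext* col_ctx_ = nullptr;
//   };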
// Static-methods-only class for registering and looking up collective
// implementations.
class CollectiveRegistry {
 public:
  using Factory = std::function<CollectiveImplementationInterface*()>;

  // Looks up a previously registered CollectiveImplementation under
  // `collective_name`. If found, creates an instance of the implementation
  // and assigns it to `implementation`.
  static Status Lookup(const string& collective_name,
                       CollectiveImplementationInterface** implementation);

  // Looks up a previously registered CollectiveImplementation under
  // `collective_name`. If found, returns the static instance of this
  // implementation via `implementation`. This instance should only be used
  // to call InitializeCollectiveParams.
  static Status LookupParamResolverInstance(
      const string& collective_name,
      CollectiveImplementationInterface** implementation);

  // Returns all registered collective implementations.
  static void GetAll(
      std::vector<CollectiveImplementationInterface*>* implementations);

 private:
  friend class CollectiveRegistration;

  // Registers a CollectiveImplementation with name `collective_name` and
  // factory `factory`. The latter is a function used to create instances of
  // the CollectiveImplementation. Also creates a static instance of the
  // implementation; this instance is used during param resolution and should
  // only be used to call InitializeCollectiveParams.
  static Status Register(const string& collective_name, Factory factory);

  static Status LookupHelper(
      const string& collective_name,
      CollectiveImplementationInterface** implementation, bool param_resolver);
};

// Class used to call CollectiveRegistry::Register. This should only be used
// to create a global static object.
class CollectiveRegistration {
 public:
  CollectiveRegistration(const string& collective_name,
                         CollectiveRegistry::Factory factory) {
    TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
  }
};

#define REGISTER_COLLECTIVE(name, implementation)             \
  static CollectiveRegistration register_##name##_collective( \
      #name, []() { return new implementation; });

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
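// Usage note (illustrative sketch; `MyAllReduce` is the hypothetical
// implementation sketched earlier). Registration normally happens at
// static-initialization time in the implementation's .cc file:
//
//   REGISTER_COLLECTIVE(MyAllReduce, MyAllReduce);
//
// At runtime a fresh instance can then be created by name; the caller takes
// ownership of the instance returned through `implementation`:
//
//   CollectiveImplementationInterface* impl = nullptr;
//   TF_RETURN_IF_ERROR(CollectiveRegistry::Lookup("MyAllReduce", &impl));
//   TF_RETURN_IF_ERROR(impl->InitializeCollectiveContext(&col_ctx));
//   impl->Run([](const Status& s) { /* handle completion */ });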