/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
#define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_

#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/intrusive_ptr.h"

namespace tensorflow {

class BufRendezvous;
class CancellationManager;
class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
class Device;
class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class NcclManager;
class Tensor;

// Types of supported collective operations.
enum CollectiveType {
  REDUCTION_COLLECTIVE = 0,
  BROADCAST_COLLECTIVE,
  GATHER_COLLECTIVE,
  PERMUTE_COLLECTIVE,
  ALL_TO_ALL_COLLECTIVE,
  UNDEFINED_COLLECTIVE,
};

// Some collective op implementations require runtime group configuration from
// the OpKernel.  Currently, this struct is used to set the communicator key
// for the NCCL-based collective implementation.
struct CollGroupRuntimeDetails {
  string communicator_key;  // for communicator-based techniques, e.g. NCCL
  string ToString() const;
};

struct CollGroupMember {
  DeviceAttributes device;
  string task;
  bool is_local;
  // User-provided rank.
  int32 rank = -1;
};

// Data common to all members of a device group.
// All members share the same device set, but the ordering of that set is
// particular to each instance, so the ordering is stored there.
struct CollGroupParams {
  // Inputs from Collective ops:
  int32 group_key;
  int32 group_size;
  DeviceType device_type;
  int user_specified_rank = -1;  // rank provided by the user.
  // Generated from Collective Group Resolver:
  // Members in this group, in default rank order.
  std::vector<CollGroupMember> members;
  // True if every task has the same number of devices.
  bool same_num_devices_per_task = false;
  // Task -> number of devices on that task.
  std::unordered_map<string, int32> num_devices_per_task;
  int32 num_tasks;  // number of distinct tasks in group
  CollGroupRuntimeDetails runtime_details;
  string ToString() const;
  CollGroupParams()
      : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {}
};

// The best implementation of a collective op depends on many factors
// including the number of devices involved, the topology of
// interconnects between them and the sizes of inputs.  This structure
// is used in generating and representing data movement choreography
// for each specific algorithm, hence it does not have a single, fixed
// interpretation.  On first execution the runtime will update this
// structure with decisions that will guide all subsequent executions.
struct CollImplDetails {
  string collective_name;
  std::vector<std::vector<int>> subdiv_permutations;
  // subdiv_offsets and max_subdivs_per_device are used together as follows
  // (see also the sketch after this struct):
  // When subdiv_offsets is provided (non-empty) it is used as is.  When
  // subdiv_offsets is not provided, subdivisions are generated dynamically,
  // constrained by max_subdivs_per_device.  When subdiv_offsets is empty AND
  // max_subdivs_per_device = 0, an internal default kMaxSubdivsPerDeviceDefault
  // is used.  When max_subdivs_per_device = -1, no subdivision is done.
  int max_subdivs_per_device = -1;  // Upper bound on subdivisions per device.
  std::vector<int> subdiv_offsets;
  std::vector<int> subdiv_source_rank;  // rank of source in each subdiv
  std::vector<int32>
      dependencies;           // collective instances on which this node depends
  string communication_hint;  // user-supplied hint for implementation choice,
                              // e.g. ring or nccl
  float timeout_seconds;      // If non-zero, set a completion timeout for the
                              // collective op to detect staleness.
};
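
// A minimal sketch of how the two subdivision fields above interact; this is
// illustrative pseudo-logic, not the actual resolver code:
//
//   if (!impl.subdiv_offsets.empty()) {
//     // use impl.subdiv_offsets exactly as given
//   } else if (impl.max_subdivs_per_device == -1) {
//     // do not subdivide
//   } else if (impl.max_subdivs_per_device == 0) {
//     // generate offsets, capped at kMaxSubdivsPerDeviceDefault
//   } else {
//     // generate offsets, capped at impl.max_subdivs_per_device
//   }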

// Data common to all members of a collective instance.
// TODO(b/163171014) Refactor this struct to not be a union of all fields.
struct CollInstanceParams {
  // Identifies all participating graph nodes.
  int32 instance_key = -1;
  CollectiveType type = UNDEFINED_COLLECTIVE;
  DataType data_type = DT_FLOAT;
  TensorShape shape = {0};
  CollImplDetails impl_details;
  string ToString() const;
  CollInstanceParams& operator=(const struct CollInstanceParams& other);
  std::vector<string> devices;  // permuter only

  // For permuter only.
  // Each element of the permutation names a receiver; the element's index is
  // the rank that sends to it.
  // Example: permutation = {2,0,1} means
  //   rank 0 sends to rank 2
  //   rank 1 sends to rank 0
  //   rank 2 sends to rank 1
  std::vector<int> permutation;
};

// Unique to a single CollectiveOp node.
struct CollectiveParams : public core::RefCounted {
  CollGroupParams group;
  CollInstanceParams instance;

  string name = "";        // node name used only for log or error messages
  int default_rank = -1;   // index of this op within device_names
  bool is_source = false;  // broadcast only
  int source_rank = -1;    // broadcast only
  // Rank of this device in each subdivision permutation.
  std::vector<int> subdiv_rank;
  OpKernel* merge_op = nullptr;  // reduction only
  OpKernel* final_op = nullptr;  // reduction only
  string ToString() const;
  bool run_group_initialization = true;
};

class CollectiveExecutor;

// Interface that provides resolution of device localities.
class DeviceResolverInterface {
 public:
  virtual ~DeviceResolverInterface() {}

  // Populates *attributes with the DeviceAttributes of the specified device.
  virtual Status GetDeviceAttributes(const string& device,
                                     DeviceAttributes* attributes) = 0;

  // Returns all device attributes of a task.
  virtual Status GetAllDeviceAttributes(
      const string& task, std::vector<DeviceAttributes>* attributes) = 0;

  // Updates device attributes. Returns an error if any device already
  // exists in the DeviceResolver and has a different incarnation.
  virtual Status UpdateDeviceAttributes(
      const std::vector<DeviceAttributes>& attributes) = 0;
};

// Interface that provides resolution of shared CollectiveParams fields.
class ParamResolverInterface {
 public:
  virtual ~ParamResolverInterface() {}

  // Called by each collective op at first execution in order to fill out
  // the CollectiveParams structure with data gathered from the full
  // (maybe distributed) collection of peer nodes.
  virtual void CompleteParamsAsync(const DeviceAttributes& device,
                                   CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   const StatusCallback& done) = 0;

  // Completes group_params with data gathered from all devices in the group.
  // This blocks until all devices are there.
  virtual void CompleteGroupAsync(const DeviceAttributes& device,
                                  CollGroupParams* group_params,
                                  CancellationManager* cancel_mgr,
                                  const StatusCallback& done) = 0;

  // Used within a distributed implementation to discover/verify data
  // shared across an instance group.
  // Note: this works differently from CompleteGroupAsync because a refactor
  // is in progress.
  virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
                                     CompleteInstanceResponse* response,
                                     CancellationManager* cancel_mgr,
                                     const StatusCallback& done) = 0;

  // Looks up a group. Returns an error if the group is not ready or not
  // found.
  virtual Status LookupGroup(int32_t group_key, CollGroupParams* group) = 0;

  // Aborts the resolver. After abortion the resolver can no longer be used.
  virtual void StartAbort(const Status& s) = 0;
};

// Graphs which utilize Collective Ops in a common instance must
// execute with identical step_ids even if they are disjoint graphs
// run by otherwise independent tasks.  This interface supplies
// coordinated step_ids to use in such cases.
class StepSequenceInterface {
 public:
  virtual ~StepSequenceInterface() {}

  // Used with a distributed implementation to coordinate step_id
  // sequences across tasks.
  virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                    GetStepSequenceResponse* response,
                                    const StatusCallback& done) = 0;

  // Refreshes the local per-graph_key step_id sequence from the collective
  // group leader, if applicable.
  virtual void RefreshStepIdSequenceAsync(int64_t graph_key,
                                          const StatusCallback& done) = 0;

  // Returns the step_id that should be used for initiating a new execution
  // on the specified graph. May return the same step_id multiple times if
  // RetireStepId or RefreshStepIdReservation is not called.
  virtual int64_t NextStepId(int64_t graph_key) = 0;

  // Reports that execution of the given step has completed successfully.
  // Should be called immediately after a step completes with OK status,
  // prior to calling NextStepId().  If the step fails, don't call.
  virtual void RetireStepId(int64_t graph_key, int64_t step_id) = 0;
};
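
// A hedged usage sketch of the step-id protocol described above; `seq`,
// `graph_key`, and RunOneStep() are hypothetical caller-side names:
//
//   int64_t step_id = seq->NextStepId(graph_key);
//   Status s = RunOneStep(graph_key, step_id);
//   if (s.ok()) seq->RetireStepId(graph_key, step_id);  // don't retire on error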

class NcclCommunicatorInterface;

// Interface that provides access to per-step CollectiveExecutor
// instances and various distributed resolution capabilities.
class CollectiveExecutorMgrInterface : public StepSequenceInterface {
 public:
  virtual ~CollectiveExecutorMgrInterface() {}

  // Returns the step-specific CollectiveExecutor, creating it if one does
  // not already exist.  The caller assumes ownership of one Ref on the
  // object.
  virtual CollectiveExecutor* FindOrCreate(int64_t step_id) = 0;

  // If there is a CollectiveExecutor for step_id, removes it from the
  // table.
  virtual void Cleanup(int64_t step_id) = 0;

  virtual ParamResolverInterface* GetParamResolver() const = 0;

  virtual DeviceResolverInterface* GetDeviceResolver() const = 0;

  virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0;
};
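
// A minimal per-step usage sketch, under the ownership rule documented on
// FindOrCreate(); `mgr` and `step_id` are hypothetical caller-side names:
//
//   CollectiveExecutor* ce = mgr->FindOrCreate(step_id);  // caller holds a Ref
//   // ... run collectives for this step via `ce` ...
//   ce->Unref();            // release the Ref acquired by FindOrCreate
//   mgr->Cleanup(step_id);  // drop the executor from the per-step table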

// Interface that a Collective Op implementation uses to exchange data
// with peers.  Note that data exchange is currently limited to types
// for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric
// types.
class CollectiveRemoteAccess {
 public:
  virtual ~CollectiveRemoteAccess() {}

  virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
                            bool peer_is_local, const string& key,
                            Device* to_device, DeviceContext* to_device_ctx,
                            const AllocatorAttributes& to_alloc_attr,
                            Tensor* to_tensor,
                            const DeviceLocality& client_locality,
                            int dev_to_dev_stream_index,
                            CancellationManager* cancellation_manager,
                            const StatusCallback& done) = 0;

  virtual void PostToPeer(const string& peer_device, const string& peer_task,
                          const string& key, Device* from_device,
                          DeviceContext* from_device_ctx,
                          const AllocatorAttributes& from_alloc_attr,
                          const Tensor* from_tensor,
                          const DeviceLocality& client_locality,
                          CancellationManager* cancellation_manager,
                          const StatusCallback& done) = 0;

  // Checks the health of a collective peer. It probes the peer to see if it
  // is alive. Note that if a peer has restarted, it's considered a different
  // one, so CheckPeerHealth fails.
  virtual void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms,
                               const StatusCallback& done) = 0;

  virtual BufRendezvous* buf_rendezvous() = 0;

  virtual void StartAbort(const Status& s) = 0;
};
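
// A hedged sketch of how a send/recv pair lines up: the producer posts a
// tensor under a `key` that both sides agree on, and the consumer receives it
// under the same key.  All variable names below are illustrative:
//
//   // producer side
//   ra->PostToPeer(peer_device, peer_task, key, from_device, from_device_ctx,
//                  from_alloc_attr, &tensor, client_locality, cancel_mgr,
//                  done);
//   // consumer side (same `key`)
//   ra->RecvFromPeer(peer_device, peer_task, /*peer_is_local=*/false, key,
//                    to_device, to_device_ctx, to_alloc_attr, &tensor,
//                    client_locality, /*dev_to_dev_stream_index=*/0,
//                    cancel_mgr, done);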

// A step-specific object that can execute a collective operation completely
// described by a CollectiveParams object.
class CollectiveExecutor : public core::RefCounted {
 public:
  virtual void StartAbort(const Status& s) {}

  virtual void ExecuteAsync(OpKernelContext* ctx,
                            const CollectiveParams* col_params,
                            const string& exec_key, StatusCallback done) {
    done(errors::Internal(
        "A collective Op has been called in a context in which "
        "a CollectiveExecutor has not been provided."));
  }

  virtual void CompleteParamsAsync(const DeviceAttributes& device,
                                   CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   StatusCallback done) {
    done(errors::Internal(
        "A collective Op has been called in a context in which "
        "a CollectiveExecutor has not been provided."));
  }

  virtual void CompleteGroupAsync(const DeviceAttributes& device,
                                  CollGroupParams* group_params,
                                  CancellationManager* cancel_mgr,
                                  StatusCallback done) {
    return cem_->GetParamResolver()->CompleteGroupAsync(device, group_params,
                                                        cancel_mgr, done);
  }

  virtual Status LookupGroup(int32_t group_key, CollGroupParams* group) {
    return cem_->GetParamResolver()->LookupGroup(group_key, group);
  }

  // Runs the potentially-blocking closure/expensive callback.
  virtual void RunClosure(std::function<void()> closure) = 0;

  virtual CollectiveRemoteAccess* remote_access() { return nullptr; }

  // `WaitForDependencies` and `UnblockDependencies` are used for fine-grained
  // control of execution order between collective instances.  These functions
  // are intended to be called from the `Run` function of collective
  // implementations, and may be used to make part, or all, of the collective
  // execution ordered with respect to other collective instances.
  //
  // `WaitForDependencies` will block until it is safe to continue the callee's
  // execution, where safety is defined as: ordered with respect to the
  // collective instances defined in the callee's `wait_for` attribute.
  virtual void WaitForDependencies(const CollectiveParams& col_params) {}
  // `UnblockDependencies` unblocks the dependent collective instances by
  // recording that this caller's device has completed the critical portion of
  // the collective execution.
  virtual void UnblockDependencies(const CollectiveParams& col_params) {}

  // Used to designate an invalid group or instance key.
  static int64_t kInvalidId;

  // Lexically scoped handle for Ref.
  class Handle {
   public:
    explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
      if (!inherit_ref) ce->Ref();
    }
    ~Handle() { ce_->Unref(); }
    CollectiveExecutor* get() const { return ce_; }

   private:
    CollectiveExecutor* ce_;
  };
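
  // A hedged usage sketch for Handle: it ties one Ref to a lexical scope.
  // GetExecutor() below is a hypothetical accessor that may or may not
  // transfer an existing Ref to the caller:
  //
  //   {
  //     CollectiveExecutor::Handle handle(GetExecutor(),
  //                                       /*inherit_ref=*/true);
  //     handle.get()->ExecuteAsync(ctx, col_params, exec_key, done);
  //   }  // ~Handle() calls Unref() here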

 protected:
  explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
      : cem_(cem) {}

  // For use only by derived classes.
  static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
  CollectiveExecutorMgrInterface* cem_;

  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
};

struct CollectiveContext {
  CollectiveExecutor* col_exec;                  // Not owned
  NcclCommunicatorInterface* nccl_communicator;  // Not owned
  const DeviceMgr* dev_mgr;                      // Not owned
  OpKernelContext* op_ctx;                       // Not owned
  OpKernelContext::Params* op_params;            // Not owned
  core::IntrusivePtr<const CollectiveParams> col_params;
  const string exec_key;
  const int64_t step_id;
  const Tensor* input;  // Not owned
  Tensor* output;       // Not owned
  Device* device;       // The device for which this instance labors
  const string device_name;
  DeviceLocality device_locality;

  CollectiveContext(CollectiveExecutor* col_exec,
                    NcclCommunicatorInterface* nccl_communicator,
                    const DeviceMgr* dev_mgr, OpKernelContext* ctx,
                    OpKernelContext::Params* op_params,
                    const CollectiveParams* col_params, const string& exec_key,
                    int64_t step_id, const Tensor* input, Tensor* output);
};

class NcclCommunicatorInterface {
 public:
  virtual ~NcclCommunicatorInterface() = default;

  virtual string GenerateCommunicatorKey() = 0;

  virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
                       StatusCallback done) = 0;

  virtual void StartAbort(const Status& s) = 0;
};

// Interface of a Collective Op implementation.  Each specific CollectiveOp
// will implement this interface and register the implementation via the
// CollectiveRegistry detailed below.  See common_runtime/ring_reducer and
// common_runtime/hierarchical_tree_broadcaster for examples.
class CollectiveImplementationInterface : public core::RefCounted {
 public:
  virtual ~CollectiveImplementationInterface() = default;

  // Initializes the portions of `col_params` specific to this
  // implementation.  Called exactly once for every Collective instance during
  // the CollectiveParams resolution process when the graph is first executed,
  // at the end of `CompleteInstanceLocal()`.
  // NOTE(ayushd): This is effectively a static function because it modifies
  // the `col_params` passed in and should not manipulate any data members.
  // However, because it is virtual and needs to be implemented by every
  // derived class, we do not mark it as static.
  virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;

  // Prepares the CollectiveContext for executing this
  // CollectiveImplementation.  Called from CollectiveExecutor right before
  // calling Run().  The CollectiveContext passed in must outlive the
  // CollectiveImplementation object.
  virtual Status InitializeCollectiveContext(
      std::shared_ptr<CollectiveContext> col_ctx) = 0;

  // Processes and moves data according to the logic of this Collective
  // implementation.  Relies on appropriate initialization of op-specific
  // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
  // context initialization in InitializeCollectiveContext().
  virtual void Run(StatusCallback done) = 0;
};
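
// A hedged lifecycle sketch for a CollectiveImplementationInterface, per the
// comments above; `impl`, `col_params`, and `col_ctx` are hypothetical
// caller-side names:
//
//   impl->InitializeCollectiveParams(col_params);  // once, at param resolution
//   impl->InitializeCollectiveContext(col_ctx);    // right before Run()
//   impl->Run([](const Status& s) { /* collective finished */ });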

// Static-methods only class for registering and looking up collective
// implementations.
class CollectiveRegistry {
 public:
  using Factory = std::function<CollectiveImplementationInterface*()>;
  // Looks up a previously registered CollectiveImplementation under
  // `collective_name`.  If found, creates an instance of the implementation
  // and assigns it to `implementation`.
  static Status Lookup(const string& collective_name,
                       CollectiveImplementationInterface** implementation);

  // Looks up a previously registered CollectiveImplementation under
  // `collective_name`.  If found, returns the static instance of this
  // implementation via `implementation`.  This instance should only be used
  // to call InitializeCollectiveParams.
  static Status LookupParamResolverInstance(
      const string& collective_name,
      CollectiveImplementationInterface** implementation);

  // Returns all registered collective implementations.
  static void GetAll(
      std::vector<CollectiveImplementationInterface*>* implementations);

 private:
  friend class CollectiveRegistration;
  // Registers a CollectiveImplementation with name `collective_name` and
  // factory `factory`.  The latter is a function used to create instances of
  // the CollectiveImplementation.  Also creates a static instance of the
  // implementation - this instance is used during param resolution and should
  // only be used to call InitializeCollectiveParams.
  static Status Register(const string& collective_name, Factory factory);

  static Status LookupHelper(const string& collective_name,
                             CollectiveImplementationInterface** implementation,
                             bool param_resolver);
};

// Class used to call CollectiveRegistry::Register.  This should only be used
// to create a global static object.
class CollectiveRegistration {
 public:
  CollectiveRegistration(const string& collective_name,
                         CollectiveRegistry::Factory factory) {
    TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
  }
};

#define REGISTER_COLLECTIVE(name, implementation)             \
  static CollectiveRegistration register_##name##_collective( \
      #name, []() { return new implementation; });
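
// A hedged registration sketch: RingReduce/RingReducer below are illustrative
// names for a collective and its CollectiveImplementationInterface subclass
// (the real registrations live in files such as common_runtime/ring_reducer):
//
//   REGISTER_COLLECTIVE(RingReduce, RingReducer);
//
// After registration, CollectiveRegistry::Lookup("RingReduce", &impl) would
// return a new instance created by the registered factory.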

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_