• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
17 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
26 
27 namespace tensorflow {
28 
// Forward declarations of types this header refers to without needing their
// full definitions.
class BufRendezvous;
class CancellationManager;
class Device;
class DeviceMgr;
class NcclManager;
class Tensor;

// Request/response types used by the distributed resolution interfaces below.
class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
41 
// The kinds of collective operation the runtime supports.  Values are spelled
// out explicitly; UNDEFINED_COLLECTIVE is the value a CollInstanceParams
// carries until its type is set.
enum CollectiveType {
  REDUCTION_COLLECTIVE = 0,
  BROADCAST_COLLECTIVE = 1,
  GATHER_COLLECTIVE = 2,
  PERMUTE_COLLECTIVE = 3,
  UNDEFINED_COLLECTIVE = 4,
};
50 
51 // Some collective op implementations require runtime group configuration from
52 // the OpKernel.  Currently, this struct is used to set communicator key for
53 // NCCL-based collective implementation.
54 struct CollGroupRuntimeDetails {
55   string communicator_key;  // for communicator-based techniques e.g. NCCL
56   string ToString() const;
57 };
58 
59 // Data common to all members of a device group.
60 // All members share the same device set but its order is
61 // particular to an instance so it is stored there.
62 struct CollGroupParams {
63   int32 group_key;
64   int32 group_size;
65   DeviceType device_type;
66   // Fully qualified name of device for each member, in default rank order.
67   std::vector<string> device_names;
68   // Task name prefix of corresponding device name.
69   std::vector<string> task_names;
70   // True if every task has the same number of devices.
71   bool same_num_devices_per_task = false;
72   // Task -> number of devices on that task.
73   std::unordered_map<string, int32> num_devices_per_task;
74   // If passed in to GPUOptions in ConfigProto, defines a good ring order for
75   // GPUs.  Assumes same GPU configuration at each worker.
76   string gpu_ring_order = "";
77   int32 num_tasks;  // number of distinct tasks in group
78   CollGroupRuntimeDetails runtime_details;
79   string ToString() const;
CollGroupParamsCollGroupParams80   CollGroupParams()
81       : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {}
82 };
83 
84 // The best implementation of a collective op depends on many factors
85 // including the number of devices involved, the topology of
86 // interconnects between them and the sizes of inputs.  This structure
87 // is used in generating and representing data movement choreography
88 // for each specific algorithm, hence it does not have a single, fixed
89 // interpretation.  On first execution the runtime will update this
90 // structure with decisions that will guide all subsequent executions.
91 struct CollImplDetails {
92   string collective_name;
93   std::vector<std::vector<int>> subdiv_permutations;
94   std::vector<int> subdiv_offsets;
95   std::vector<int> subdiv_source_rank;  // rank of source in each subdiv
96   std::vector<int32>
97       dependencies;           // collective instances on which this node depends
98   string communication_hint;  // user-supplied hint for implementation choice,
99                               // e.g. ring or nccl
100   float timeout_seconds;      // If non zero, set a completion timeout for the
101                               // collective op to detect staleness.
102 };
103 
104 // Data common to all members of a collective instance.
105 // TODO(b/163171014) Refactor this struct to not be a union of all fields.
106 struct CollInstanceParams {
107   // Identifies all participating graph nodes.
108   int32 instance_key = -1;
109   CollectiveType type = UNDEFINED_COLLECTIVE;
110   DataType data_type = DT_FLOAT;
111   TensorShape shape = {0};
112   CollImplDetails impl_details;
113   string ToString() const;
114   CollInstanceParams& operator=(const struct CollInstanceParams& other);
115   std::vector<string> devices;  // permuter only
116 
117   // For permuter only
118   // Each rank in the permutation is a receiver.
119   // Indices of each rank means a sender to that rank.
120   // Example: permutation = {2,0,1} means
121   //   rank 0 sends to rank 2
122   //   rank 1 sends to rank 0
123   //   rank 2 sends to rank 1
124   std::vector<int> permutation;
125 };
126 
127 // Data common to all instance members in the same task.
128 struct CollTaskParams {
129   // True for devices that are local to the process, i.e. no RPC needed.
130   std::vector<bool> is_local;
131   string ToString() const;
132 };
133 
134 // Unique to a single CollectiveOp node.
135 struct CollectiveParams : public core::RefCounted {
136   CollGroupParams group;
137   CollInstanceParams instance;
138   CollTaskParams task;
139 
140   string name = "";        // node name used only for log or error messages
141   int default_rank = -1;   // index of this op within device_names
142   bool is_source = false;  // broadcast only
143   int source_rank = -1;    // broadcast only
144   // Rank of this device in each subdivision permutation.
145   std::vector<int> subdiv_rank;
146   OpKernel* merge_op = nullptr;  // reduction only
147   OpKernel* final_op = nullptr;  // reduction only
148   string ToString() const;
149 };
150 
// Declared early: CollectiveExecutorMgrInterface::FindOrCreate returns it by
// pointer before the class body below.
class CollectiveExecutor;
152 
153 // Interface that provides resolution of device localities.
154 class DeviceResolverInterface {
155  public:
~DeviceResolverInterface()156   virtual ~DeviceResolverInterface() {}
157 
158   // Populates *attributes with the DeviceAttributes of the specified device.
159   virtual Status GetDeviceAttributes(const string& device,
160                                      DeviceAttributes* attributes) = 0;
161 
162   // Returns all device attributes of a task.
163   virtual Status GetAllDeviceAttributes(
164       const string& task, std::vector<DeviceAttributes>* attributes) = 0;
165 
166   // Updates device attributes. It returns error if any device already
167   // exists in the DeviceResolver and has a different incarnation.
168   virtual Status UpdateDeviceAttributes(
169       const std::vector<DeviceAttributes>& attributes) = 0;
170 };
171 
172 // Interface that provides resolution of shared CollectiveParams fields.
173 class ParamResolverInterface {
174  public:
~ParamResolverInterface()175   virtual ~ParamResolverInterface() {}
176 
177   // Called by each collective op at first execution in order to fill out
178   // the CollectiveParams structure with data gathered from the full
179   // (maybe distributed) collection of peer nodes.
180   virtual void CompleteParamsAsync(const DeviceAttributes& device,
181                                    CollectiveParams* cp,
182                                    CancellationManager* cancel_mgr,
183                                    const StatusCallback& done) = 0;
184 
185   // Used within a distributed implementation to discover/verify
186   // data shared across a device group.
187   virtual void CompleteGroupAsync(const CompleteGroupRequest* request,
188                                   CompleteGroupResponse* response,
189                                   CancellationManager* cancel_mgr,
190                                   const StatusCallback& done) = 0;
191 
192   // Used within a distributed implementation to discover/verify data
193   // shared across an instance group.
194   virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
195                                      CompleteInstanceResponse* response,
196                                      CancellationManager* cancel_mgr,
197                                      const StatusCallback& done) = 0;
198 
199   // Aborts the resolver. After abortion the resolver can no longer be used.
200   virtual void StartAbort(const Status& s) = 0;
201 };
202 
203 // Graphs which utilize Collective Ops in a common instance must
204 // execute with identical step_ids even if they are disjoint graphs
205 // run by otherwise independent tasks.  This interface supplies
206 // coordinated step_ids to use in such cases.
207 class StepSequenceInterface {
208  public:
~StepSequenceInterface()209   virtual ~StepSequenceInterface() {}
210 
211   // Used with a distributed implementation to coordinate step_id
212   // sequences across tasks.
213   virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
214                                     GetStepSequenceResponse* response,
215                                     const StatusCallback& done) = 0;
216 
217   // Refresh the local per-graph_key step_id sequence from collective
218   // group leader, if applicable.
219   virtual void RefreshStepIdSequenceAsync(int64 graph_key,
220                                           const StatusCallback& done) = 0;
221 
222   // Returns the step_id that should be used for initiating a new execution
223   // on the specified graph. May return the same step_id multiple times if
224   // RetireStepId or RefreshStepIdReservation is not called.
225   virtual int64 NextStepId(int64 graph_key) = 0;
226 
227   // Reports that execution of the given step has completed successfully.
228   // Should be called immediately after a step completes with OK status,
229   // prior to calling NextStepId().  If the step fails, don't call.
230   virtual void RetireStepId(int64 graph_key, int64 step_id) = 0;
231 };
232 
// Declared early: CollectiveExecutorMgrInterface and CollectiveContext refer
// to it by pointer before its class body appears below.
class NcclCommunicatorInterface;
234 
235 // Interface that provides access to per-step CollectiveExecutor
236 // instances and various distributed resolution capabilities.
237 class CollectiveExecutorMgrInterface : public StepSequenceInterface {
238  public:
~CollectiveExecutorMgrInterface()239   virtual ~CollectiveExecutorMgrInterface() {}
240 
241   // Returns the step-specific CollectiveExecutor, creating if one does not
242   // already exist.  The caller assumes ownership of one Ref on the object.
243   virtual CollectiveExecutor* FindOrCreate(int64 step_id) = 0;
244 
245   // If there is a CollectiveExecutor for step_id, remove it from the
246   // table.
247   virtual void Cleanup(int64 step_id) = 0;
248 
249   virtual ParamResolverInterface* GetParamResolver() const = 0;
250 
251   virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
252 
253   virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0;
254 };
255 
256 // Interface that a Collective Op implementation uses to exchange data
257 // with peers.  Note that data exchange is currently limited to types
258 // for which DMAHelper::CanUseDMA() returns true, i.e.  dense numeric
259 // types.
260 class CollectiveRemoteAccess {
261  public:
~CollectiveRemoteAccess()262   virtual ~CollectiveRemoteAccess() {}
263 
264   virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
265                             bool peer_is_local, const string& key,
266                             Device* to_device, DeviceContext* to_device_ctx,
267                             const AllocatorAttributes& to_alloc_attr,
268                             Tensor* to_tensor,
269                             const DeviceLocality& client_locality,
270                             int dev_to_dev_stream_index,
271                             CancellationManager* cancellation_manager,
272                             const StatusCallback& done) = 0;
273 
274   virtual void PostToPeer(const string& peer_device, const string& peer_task,
275                           const string& key, Device* from_device,
276                           DeviceContext* from_device_ctx,
277                           const AllocatorAttributes& from_alloc_attr,
278                           const Tensor* from_tensor,
279                           const DeviceLocality& client_locality,
280                           CancellationManager* cancellation_manager,
281                           const StatusCallback& done) = 0;
282 
283   // Checks the health of a collective peer. It probes the peer to see if it is
284   // alive. Note that if a peer has restarted, it's considered a different one,
285   // so CheckPeerHealth fails.
286   virtual void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
287                                const StatusCallback& done) = 0;
288 
289   virtual BufRendezvous* buf_rendezvous() = 0;
290 
291   virtual void StartAbort(const Status& s) = 0;
292 };
293 
294 // A step-specific object that can execute a collective operation completely
295 // described by a CollectiveParams object.
296 class CollectiveExecutor : public core::RefCounted {
297  public:
StartAbort(const Status & s)298   virtual void StartAbort(const Status& s) {}
299 
ExecuteAsync(OpKernelContext * ctx,const CollectiveParams * col_params,const string & exec_key,StatusCallback done)300   virtual void ExecuteAsync(OpKernelContext* ctx,
301                             const CollectiveParams* col_params,
302                             const string& exec_key, StatusCallback done) {
303     done(errors::Internal(
304         "A collective Op has been called in a context in which "
305         "a CollectiveExecutor has not been provided."));
306   }
307 
CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,StatusCallback done)308   virtual void CompleteParamsAsync(const DeviceAttributes& device,
309                                    CollectiveParams* cp,
310                                    CancellationManager* cancel_mgr,
311                                    StatusCallback done) {
312     done(errors::Internal(
313         "A collective Op has been called in a context in which "
314         "a CollectiveExecutor has not been provided."));
315   }
316 
317   // Runs the potentially-blocking closure/expensive callback.
318   virtual void RunClosure(std::function<void()> closure) = 0;
319 
remote_access()320   virtual CollectiveRemoteAccess* remote_access() { return nullptr; }
321 
322   // `WaitForDependencies` and `Launched` are used for fine-grained control of
323   // execution order between collective instances.  These functions are intended
324   // to be called in `Run` function of collective implementations, and may be
325   // used to make part, or whole, of the collective execution ordered with
326   // respect to other collective instances.
327   //
328   // `WaitForDependencies` will block until it is safe to continue the callee's
329   // execution, where safety is defined as: ordered with respect to the
330   // collective instances defined in the callee's `wait_for` attribute.
WaitForDependencies(const CollectiveParams & col_params)331   virtual void WaitForDependencies(const CollectiveParams& col_params) {}
332   // `UnblockDependencies` unblocks the dependent collective instances by
333   // recording that this caller's device has completed the critical portion of
334   // the collective execution.
UnblockDependencies(const CollectiveParams & col_params)335   virtual void UnblockDependencies(const CollectiveParams& col_params) {}
336 
337   // Used to designate an invalid group or instance key.
338   static int64 kInvalidId;
339 
340   // Lexically scoped handle for Ref.
341   class Handle {
342    public:
Handle(CollectiveExecutor * ce,bool inherit_ref)343     explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
344       if (!inherit_ref) ce->Ref();
345     }
~Handle()346     ~Handle() { ce_->Unref(); }
get()347     CollectiveExecutor* get() const { return ce_; }
348 
349    private:
350     CollectiveExecutor* ce_;
351   };
352 
353  protected:
CollectiveExecutor(CollectiveExecutorMgrInterface * cem)354   explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
355       : cem_(cem) {}
356 
357   // For use only by derived classes
358   static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
359   CollectiveExecutorMgrInterface* cem_;
360 
361   TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
362 };
363 
364 struct CollectiveContext {
365   CollectiveExecutor* col_exec;                  // Not owned
366   NcclCommunicatorInterface* nccl_communicator;  // Not owned
367   const DeviceMgr* dev_mgr;                      // Not owned
368   OpKernelContext* op_ctx;                       // Not owned
369   OpKernelContext::Params* op_params;            // Not owned
370   const CollectiveParams* col_params;            // Not owned
371   const string exec_key;
372   const int64 step_id;
373   const Tensor* input;  // Not owned
374   Tensor* output;       // Not owned
375   Device* device;       // The device for which this instance labors
376   const string device_name;
377   DeviceLocality device_locality;
378 
379   CollectiveContext(CollectiveExecutor* col_exec,
380                     NcclCommunicatorInterface* nccl_communicator,
381                     const DeviceMgr* dev_mgr, OpKernelContext* ctx,
382                     OpKernelContext::Params* op_params,
383                     const CollectiveParams* col_params, const string& exec_key,
384                     int64 step_id, const Tensor* input, Tensor* output);
385 };
386 
387 class NcclCommunicatorInterface {
388  public:
389   virtual ~NcclCommunicatorInterface() = default;
390 
391   virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
392                        StatusCallback done) = 0;
393 
394   virtual void StartAbort(const Status& s) = 0;
395 };
396 
397 // Interface of a Collective Op implementation.  Each specific CollectiveOp will
398 // implement this interface and register the implementation via the
399 // CollectiveRegistry detailed below.  See common_runtime/ring_reducer and
400 // common_runtime/hierarchical_tree_broadcaster for examples.
401 class CollectiveImplementationInterface : public core::RefCounted {
402  public:
403   virtual ~CollectiveImplementationInterface() = default;
404 
405   // Initializes the portions of `col_params` specific to this
406   // implementation.  Called exactly once for every Collective instance during
407   // the CollectiveParams resolution process when the graph is first executed,
408   // at the end of `CompleteInstanceLocal()`.
409   // NOTE(ayushd): This is effectively a static function because it modifies the
410   // `col_params` passed in and should not manipulate any data members.  However
411   // because it is virtual and needs to be implemented by every derived class we
412   // do not mark it as static.
413   virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
414 
415   // Prepares the CollectiveContext for executing this CollectiveImplementation.
416   // Called from CollectiveExecutor right before calling Run().  The
417   // CollectiveContext passed in must outlive the CollectiveImplementation
418   // object.
419   virtual Status InitializeCollectiveContext(
420       std::shared_ptr<CollectiveContext> col_ctx) = 0;
421 
422   // Performs collective implementation specific group initialization.  The
423   // intention is to do group-specific initialization of runtime details for the
424   // collective implementation.  Currently used only to set `communicator_key`
425   // in techniques which use a communicator for distributed collectives (NCCL).
426   virtual Status InitializeCollectiveGroupRuntimeDetails(
427       CollGroupRuntimeDetails* col_group_runtime_details) = 0;
428 
429   // Processes and moves data according to the logic of this Collective
430   // implementation.  Relies on appropriate initialization of op-specific
431   // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
432   // context initialization in InitializeCollectiveContext().
433   virtual void Run(StatusCallback done) = 0;
434 };
435 
436 // Static-methods only class for registering and looking up collective
437 // implementations.
438 class CollectiveRegistry {
439  public:
440   using Factory = std::function<CollectiveImplementationInterface*()>;
441   // Looks up a previously registered CollectiveImplementation under
442   // `collective_name`.  If found, creates an instance of the implementation and
443   // assign to `implementation`.
444   static Status Lookup(const string& collective_name,
445                        CollectiveImplementationInterface** implementation);
446 
447   // Looks up a previously registered CollectiveImplementation under
448   // `collective_name`.  If found, returns the static instance of this
449   // implementation via `implementation`.  This instance should only be used to
450   // call InitializateCollectiveParams.
451   static Status LookupParamResolverInstance(
452       const string& collective_name,
453       CollectiveImplementationInterface** implementation);
454 
455   // Returns all registered collective implementations.
456   static void GetAll(
457       std::vector<CollectiveImplementationInterface*>* implementations);
458 
459  private:
460   friend class CollectiveRegistration;
461   // Registers a CollectiveImplementation with name `collective_name` and
462   // factory `factory`.  The latter is a function used to create instances of
463   // the CollectiveImplementation.  Also creates a static instance of the
464   // implementation - this instance is used during param resolution and should
465   // only be used to call InitializeCollectiveParams.
466   static Status Register(const string& collective_name, Factory factory);
467 
468   static Status LookupHelper(const string& collective_name,
469                              CollectiveImplementationInterface** implementation,
470                              bool param_resolver);
471 };
472 
473 // Class used to call CollectiveRegistry::Register.  This should only be used to
474 // create a global static object.
475 class CollectiveRegistration {
476  public:
CollectiveRegistration(const string & collective_name,CollectiveRegistry::Factory factory)477   CollectiveRegistration(const string& collective_name,
478                          CollectiveRegistry::Factory factory) {
479     TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
480   }
481 };
482 
// Defines a static CollectiveRegistration that registers `implementation`
// under the stringized `name`; its factory returns `new implementation`.
// The expansion already ends in ';'.
#define REGISTER_COLLECTIVE(name, implementation)             \
  static CollectiveRegistration register_##name##_collective( \
      #name, [] { return new implementation; });
486 
487 }  // namespace tensorflow
488 
489 #endif  // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
490