• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
17 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
26 
27 namespace tensorflow {
28 
// Forward declarations.  Every type below is only used through a pointer or
// reference in this header, so its full definition is not required here.

// Request/response types used by the distributed resolution interfaces.
class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
class GetStepSequenceRequest;
class GetStepSequenceResponse;

// Runtime classes.
class BufRendezvous;
class CancellationManager;
class Device;
class DeviceMgr;
class Tensor;
40 
// Kinds of collective operation that are supported.  Each enumerator's value
// is written out explicitly.  UNDEFINED_COLLECTIVE marks a type that has not
// been set yet.
enum CollectiveType {
  REDUCTION_COLLECTIVE = 0,
  BROADCAST_COLLECTIVE = 1,
  GATHER_COLLECTIVE = 2,
  UNDEFINED_COLLECTIVE = 3,
};
48 
49 // Some collective op implementations require runtime group configuration from
50 // the OpKernel.  Currently, this struct is used to set communicator key for
51 // NCCL-based collective implementation.
52 struct CollGroupRuntimeDetails {
53   string communicator_key;  // for communicator-based techniques e.g. NCCL
54   string ToString() const;
55 };
56 
57 // Data common to all members of a device group.
58 // All members share the same device set but its order is
59 // particular to an instance so it is stored there.
60 struct CollGroupParams {
61   int32 group_key;
62   int32 group_size;
63   DeviceType device_type;
64   int32 num_tasks;  // number of distinct tasks in group
65   CollGroupRuntimeDetails runtime_details;
66   string ToString() const;
CollGroupParamsCollGroupParams67   CollGroupParams()
68       : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {}
69 };
70 
71 // The best implementation of a collective op depends on many factors
72 // including the number of devices involved, the topology of
73 // interconnects between them and the sizes of inputs.  This structure
74 // is used in generating and representing data movement choreography
75 // for each specific algorithm, hence it does not have a single, fixed
76 // interpretation.  On first execution the runtime will update this
77 // structure with decisions that will guide all subsequent executions.
78 struct CollImplDetails {
79   string collective_name;
80   std::vector<std::vector<int>> subdiv_permutations;
81   std::vector<int> subdiv_offsets;
82   std::vector<int> subdiv_source_rank;  // rank of source in each subdiv
83   std::vector<int32>
84       dependencies;           // collective instances on which this node depends
85   string communication_hint;  // user-supplied hint for implementation choice,
86                               // e.g. ring or nccl
87 };
88 
89 // Data common to all members of a collective instance.
90 struct CollInstanceParams {
91   // Identifies all participating graph nodes.
92   int32 instance_key = -1;
93   CollectiveType type = UNDEFINED_COLLECTIVE;
94   DataType data_type = DT_FLOAT;
95   TensorShape shape = {0};
96   // Fully qualified name of device for each member, in default rank order.
97   std::vector<string> device_names;
98   // Task name prefix of corresponding device name.
99   std::vector<string> task_names;
100   // True if every task has the same number of devices.
101   bool same_num_devices_per_task = false;
102   // Task -> number of devices on that task.
103   std::unordered_map<string, int32> num_devices_per_task;
104   // If passed in to GPUOptions in ConfigProto, defines a good ring order for
105   // GPUs.  Assumes same GPU configuration at each worker.
106   string gpu_ring_order = "";
107   CollImplDetails impl_details;
108   string ToString() const;
109   CollInstanceParams& operator=(const struct CollInstanceParams& other);
110 };
111 
112 // Data common to all instance members in the same task.
113 struct CollTaskParams {
114   // True for devices that are local to the process, i.e. no RPC needed.
115   std::vector<bool> is_local;
116   string ToString() const;
117 };
118 
119 // Unique to a single CollectiveOp node.
120 struct CollectiveParams {
121   CollGroupParams group;
122   CollInstanceParams instance;
123   CollTaskParams task;
124 
125   string name = "";        // node name used only for log or error messages
126   int default_rank = -1;   // index of this op within device_names
127   bool is_source = false;  // broadcast only
128   int source_rank = -1;    // broadcast only
129   // Rank of this device in each subdivision permutation.
130   std::vector<int> subdiv_rank;
131   std::unique_ptr<OpKernel> merge_op;  // reduction only
132   std::unique_ptr<OpKernel> final_op;  // reduction only
133   string ToString() const;
134 };
135 
136 class CollectiveExecutor;
137 
138 // Interface that provides resolution of device localities.
139 class DeviceResolverInterface {
140  public:
~DeviceResolverInterface()141   virtual ~DeviceResolverInterface() {}
142 
143   // Collects DeviceAttributes protobufs from all of the devices identified
144   // in 'col_params'.
145   virtual void GetAllDeviceAttributesAsync(
146       const std::vector<string>& devices, const std::vector<string>& tasks,
147       std::vector<DeviceAttributes>* attributes,
148       const StatusCallback& done) = 0;
149 
150   // Populate *attributes with the DeviceAttributes of the specified
151   // device.
152   virtual void GetDeviceAttributesAsync(const string& device,
153                                         const string& task,
154                                         DeviceAttributes* attributes,
155                                         const StatusCallback& done) = 0;
156 
157   // Clear the cache of device data belonging to the specified task.
158   virtual void ClearTask(const string& task) = 0;
159 
160   // Clear the cache of all device data.
161   virtual void ClearCache() = 0;
162 };
163 
164 // Interface that provides resolution of shared CollectiveParams fields.
165 class ParamResolverInterface {
166  public:
~ParamResolverInterface()167   virtual ~ParamResolverInterface() {}
168 
169   // Called by each collective op at first execution in order to fill out
170   // the CollectiveParams structure with data gathered from the full
171   // (maybe distributed) collection of peer nodes.
172   virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
173                                    CancellationManager* cancel_mgr,
174                                    const StatusCallback& done) = 0;
175 
176   // Used within a distributed implementation to discover/verify
177   // data shared across a device group.
178   virtual void CompleteGroupAsync(const CompleteGroupRequest* request,
179                                   CompleteGroupResponse* response,
180                                   CancellationManager* cancel_mgr,
181                                   const StatusCallback& done) = 0;
182 
183   // Used within a distributed implementation to discover/verify data
184   // shared across an instance group.
185   virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
186                                      CompleteInstanceResponse* response,
187                                      CancellationManager* cancel_mgr,
188                                      const StatusCallback& done) = 0;
189 };
190 
191 // Graphs which utilize Collective Ops in a common instance must
192 // execute with identical step_ids even if they are disjoint graphs
193 // run by otherwise independent tasks.  This interface supplies
194 // coordinated step_ids to use in such cases.
195 class StepSequenceInterface {
196  public:
~StepSequenceInterface()197   virtual ~StepSequenceInterface() {}
198 
199   // Used with a distributed implementation to coordinate step_id
200   // sequences across tasks.
201   virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
202                                     GetStepSequenceResponse* response,
203                                     const StatusCallback& done) = 0;
204 
205   // Refresh the local per-graph_key step_id sequence from collective
206   // group leader, if applicable.
207   virtual void RefreshStepIdSequenceAsync(int64 graph_key,
208                                           const StatusCallback& done) = 0;
209 
210   // Returns the step_id that should be used for initiating a new execution
211   // on the specified graph. May return the same step_id multiple times if
212   // RetireStepId or RefreshStepIdReservation is not called.
213   virtual int64 NextStepId(int64 graph_key) = 0;
214 
215   // Reports that execution of the given step has completed successfully.
216   // Should be called immediately after a step completes with OK status,
217   // prior to calling NextStepId().  If the step fails, don't call.
218   virtual void RetireStepId(int64 graph_key, int64 step_id) = 0;
219 };
220 
221 // Interface that provides access to per-step CollectiveExecutor
222 // instances and various distributed resolution capabilities.
223 class CollectiveExecutorMgrInterface : public StepSequenceInterface {
224  public:
~CollectiveExecutorMgrInterface()225   virtual ~CollectiveExecutorMgrInterface() {}
226 
227   // Returns the step-specific CollectiveExecutor, creating if one does not
228   // already exist.  The caller assumes ownership of one Ref on the object.
229   virtual CollectiveExecutor* FindOrCreate(int64 step_id) = 0;
230 
231   // If there is a CollectiveExecutor for step_id, remove it from the
232   // table.
233   virtual void Cleanup(int64 step_id) = 0;
234 
235   virtual ParamResolverInterface* GetParamResolver() const = 0;
236 
237   virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
238 };
239 
240 // Interface that a Collective Op implementation uses to exchange data
241 // with peers.  Note that data exchange is currently limited to types
242 // for which DMAHelper::CanUseDMA() returns true, i.e.  dense numeric
243 // types.
244 class PeerAccessInterface {
245  public:
~PeerAccessInterface()246   virtual ~PeerAccessInterface() {}
247 
248   virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
249                             bool peer_is_local, const string& key,
250                             Device* to_device, DeviceContext* to_device_ctx,
251                             const AllocatorAttributes& to_alloc_attr,
252                             Tensor* to_tensor,
253                             const DeviceLocality& client_locality,
254                             int dev_to_dev_stream_index,
255                             const StatusCallback& done) = 0;
256 
257   virtual void PostToPeer(const string& peer_device, const string& peer_task,
258                           const string& key, Device* from_device,
259                           DeviceContext* from_device_ctx,
260                           const AllocatorAttributes& from_alloc_attr,
261                           const Tensor* from_tensor,
262                           const DeviceLocality& client_locality,
263                           const StatusCallback& done) = 0;
264 
265   // Runs the potentially-blocking closure/expensive callback.
266   virtual void RunClosure(std::function<void()> closure) = 0;
267 };
268 
269 class PerStepCollectiveRemoteAccess;
270 
271 // A step-specific object that can execute a collective operation completely
272 // described by a CollectiveParams object.
273 class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted {
274  public:
StartAbort(const Status & s)275   virtual void StartAbort(const Status& s) {}
276 
ExecuteAsync(OpKernelContext * ctx,const CollectiveParams & col_params,const string & exec_key,StatusCallback done)277   virtual void ExecuteAsync(OpKernelContext* ctx,
278                             const CollectiveParams& col_params,
279                             const string& exec_key, StatusCallback done) {
280     done(errors::Internal(
281         "A collective Op has been called in a context in which "
282         "a CollectiveExecutor has not been provided."));
283   }
284 
CompleteParamsAsync(const string & device,CollectiveParams * cp,CancellationManager * cancel_mgr,StatusCallback done)285   virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
286                                    CancellationManager* cancel_mgr,
287                                    StatusCallback done) {
288     done(errors::Internal(
289         "A collective Op has been called in a context in which "
290         "a CollectiveExecutor has not been provided."));
291   }
292 
remote_access()293   virtual PerStepCollectiveRemoteAccess* remote_access() { return nullptr; }
294 
295   // `WaitForDependencies` and `Launched` are used for fine-grained control of
296   // execution order between collective instances.  These functions are intended
297   // to be called in `Run` function of collective implementations, and may be
298   // used to make part, or whole, of the collective execution ordered with
299   // respect to other collective instances.
300   //
301   // `WaitForDependencies` will block until it is safe to continue the callee's
302   // execution, where safety is defined as: ordered with respect to the
303   // collective instances defined in the callee's `wait_for` attribute.
WaitForDependencies(const CollectiveParams & col_params)304   virtual void WaitForDependencies(const CollectiveParams& col_params) {}
305   // `UnblockDependencies` unblocks the dependent collective instances by
306   // recording that this caller's device has completed the critical portion of
307   // the collective execution.
UnblockDependencies(const CollectiveParams & col_params)308   virtual void UnblockDependencies(const CollectiveParams& col_params) {}
309 
310   // Used to designate an invalid group or instance key.
311   static int64 kInvalidId;
312 
313   // Lexically scoped handle for Ref.
314   class Handle {
315    public:
Handle(CollectiveExecutor * ce,bool inherit_ref)316     explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
317       if (!inherit_ref) ce->Ref();
318     }
~Handle()319     ~Handle() { ce_->Unref(); }
get()320     CollectiveExecutor* get() const { return ce_; }
321 
322    private:
323     CollectiveExecutor* ce_;
324   };
325 
326  protected:
CollectiveExecutor(CollectiveExecutorMgrInterface * cem)327   explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
328       : cem_(cem) {}
329 
330   // For use only by derived classes
331   static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
332   CollectiveExecutorMgrInterface* cem_;
333 
334   TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
335 };
336 
337 // Interface of a helper object that provides a CollectiveExecutor with
338 // all of the remote access it needs.
339 class CollectiveRemoteAccess : public PeerAccessInterface,
340                                public DeviceResolverInterface {
341  public:
~CollectiveRemoteAccess()342   virtual ~CollectiveRemoteAccess() {}
343 
344   virtual BufRendezvous* buf_rendezvous() = 0;
345 };
346 
347 // A per-step version of CollectiveRemoteAccess that cleans up outstanding
348 // communications in case step execution is abandoned.
349 class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess {
350  public:
~PerStepCollectiveRemoteAccess()351   virtual ~PerStepCollectiveRemoteAccess() {}
352   virtual void StartAbort(const Status& s) = 0;
353 };
354 
355 class CollectiveContext {
356  public:
357   CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
358                     OpKernelContext* ctx, OpKernelContext::Params* op_params,
359                     const CollectiveParams& col_params, const string& exec_key,
360                     int64 step_id, const Tensor* input, Tensor* output);
361 
362   virtual ~CollectiveContext() = default;
363 
364   CollectiveExecutor* col_exec;        // Not owned
365   const DeviceMgr* dev_mgr;            // Not owned
366   OpKernelContext* op_ctx;             // Not owned
367   OpKernelContext::Params* op_params;  // Not owned
368   const CollectiveParams& col_params;
369   const string exec_key;
370   const int64 step_id;
371   const Tensor* input;  // Not owned
372   Tensor* output;       // Not owned
373   Device* device;       // The device for which this instance labors
374   const string device_name;
375   DeviceLocality device_locality;
376 };
377 
378 // Interface of a Collective Op implementation.  Each specific CollectiveOp will
379 // implement this interface and register the implementation via the
380 // CollectiveRegistry detailed below.  See common_runtime/ring_reducer and
381 // common_runtime/hierarchical_tree_broadcaster for examples.
382 class CollectiveImplementationInterface {
383  public:
384   virtual ~CollectiveImplementationInterface() = default;
385 
386   // Initializes the portions of `col_params` specific to this
387   // implementation.  Called exactly once for every Collective instance during
388   // the CollectiveParams resolution process when the graph is first executed,
389   // at the end of `CompleteInstanceLocal()`.
390   // NOTE(ayushd): This is effectively a static function because it modifies the
391   // `col_params` passed in and should not manipulate any data members.  However
392   // because it is virtual and needs to be implemented by every derived class we
393   // do not mark it as static.
394   virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
395 
396   // Prepares the CollectiveContext for executing this CollectiveImplementation.
397   // Called from CollectiveExecutor right before calling Run().  The
398   // CollectiveContext passed in must outlive the CollectiveImplementation
399   // object.
400   virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
401 
402   // Performs collective implementation specific group initialization.  The
403   // intention is to do group-specific initialization of runtime details for the
404   // collective implementation.  Currently used only to set `communicator_key`
405   // in techniques which use a communicator for distributed collectives (NCCL).
406   virtual Status InitializeCollectiveGroupRuntimeDetails(
407       CollGroupRuntimeDetails* col_group_runtime_details) = 0;
408 
409   // Processes and moves data according to the logic of this Collective
410   // implementation.  Relies on appropriate initialization of op-specific
411   // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
412   // context initialization in InitializeCollectiveContext().
413   virtual void Run(StatusCallback done) = 0;
414 };
415 
416 // Static-methods only class for registering and looking up collective
417 // implementations.
418 class CollectiveRegistry {
419  public:
420   using Factory = std::function<CollectiveImplementationInterface*()>;
421   // Looks up a previously registered CollectiveImplementation under
422   // `collective_name`.  If found, creates an instance of the implementation and
423   // assign to `implementation`.
424   static Status Lookup(const string& collective_name,
425                        CollectiveImplementationInterface** implementation);
426 
427   // Looks up a previously registered CollectiveImplementation under
428   // `collective_name`.  If found, returns the static instance of this
429   // implementation via `implementation`.  This instance should only be used to
430   // call InitializateCollectiveParams.
431   static Status LookupParamResolverInstance(
432       const string& collective_name,
433       CollectiveImplementationInterface** implementation);
434 
435   // Returns all registered collective implementations.
436   static void GetAll(
437       std::vector<CollectiveImplementationInterface*>* implementations);
438 
439  private:
440   friend class CollectiveRegistration;
441   // Registers a CollectiveImplementation with name `collective_name` and
442   // factory `factory`.  The latter is a function used to create instances of
443   // the CollectiveImplementation.  Also creates a static instance of the
444   // implementation - this instance is used during param resolution and should
445   // only be used to call InitializeCollectiveParams.
446   static Status Register(const string& collective_name, Factory factory);
447 
448   static Status LookupHelper(const string& collective_name,
449                              CollectiveImplementationInterface** implementation,
450                              bool param_resolver);
451 };
452 
453 // Class used to call CollectiveRegistry::Register.  This should only be used to
454 // create a global static object.
455 class CollectiveRegistration {
456  public:
CollectiveRegistration(const string & collective_name,CollectiveRegistry::Factory factory)457   CollectiveRegistration(const string& collective_name,
458                          CollectiveRegistry::Factory factory) {
459     TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
460   }
461 };
462 
// Defines a static CollectiveRegistration object named
// register_<name>_collective.  The object registers `implementation` under
// the string "<name>".  The registered factory returns a fresh
// `new implementation` on every call.
#define REGISTER_COLLECTIVE(name, implementation)                           \
  static CollectiveRegistration register_##name##_collective(#name, [] {    \
    return new implementation;                                              \
  });
466 
467 }  // namespace tensorflow
468 
469 #endif  // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
470