• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
17 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
26 
27 namespace tensorflow {
28 
// Forward declarations.  Within this header these types are used, if at all,
// only through pointers or references, so their definitions are not required.
class BufRendezvous;
class CancellationManager;
class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
class Device;
class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class NcclManager;
class Tensor;
41 
// Types of supported collective operations.
//
// Enumerators are numbered consecutively starting at 0, so their order is part
// of the numeric contract: append new kinds rather than reordering.
// UNDEFINED_COLLECTIVE is the last enumerator.
enum CollectiveType {
  REDUCTION_COLLECTIVE = 0,
  BROADCAST_COLLECTIVE,
  GATHER_COLLECTIVE,
  PERMUTE_COLLECTIVE,
  ALL_TO_ALL_COLLECTIVE,
  UNDEFINED_COLLECTIVE,
};
51 
52 // Some collective op implementations require runtime group configuration from
53 // the OpKernel.  Currently, this struct is used to set communicator key for
54 // NCCL-based collective implementation.
55 struct CollGroupRuntimeDetails {
56   string communicator_key;  // for communicator-based techniques e.g. NCCL
57   string ToString() const;
58 };
59 
60 // Data common to all members of a device group.
61 // All members share the same device set but its order is
62 // particular to an instance so it is stored there.
63 struct CollGroupParams {
64   int32 group_key;
65   int32 group_size;
66   DeviceType device_type;
67   // Devices in this group, in default rank order.
68   std::vector<DeviceAttributes> devices;
69   // Task name prefix of corresponding device name.
70   std::vector<string> task_names;
71   // True if every task has the same number of devices.
72   bool same_num_devices_per_task = false;
73   // Task -> number of devices on that task.
74   std::unordered_map<string, int32> num_devices_per_task;
75   int32 num_tasks;  // number of distinct tasks in group
76   CollGroupRuntimeDetails runtime_details;
77   string ToString() const;
CollGroupParamsCollGroupParams78   CollGroupParams()
79       : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {}
80 };
81 
82 // The best implementation of a collective op depends on many factors
83 // including the number of devices involved, the topology of
84 // interconnects between them and the sizes of inputs.  This structure
85 // is used in generating and representing data movement choreography
86 // for each specific algorithm, hence it does not have a single, fixed
87 // interpretation.  On first execution the runtime will update this
88 // structure with decisions that will guide all subsequent executions.
89 struct CollImplDetails {
90   string collective_name;
91   std::vector<std::vector<int>> subdiv_permutations;
92   // subdiv_offsets and max_subdivs_per_device are used together as follows:
93   // When subdiv_offsets is provided (non-empty) it is used as is. When
94   // subdiv_offsets is not provided subdivisons are generated dynamically
95   // constrained by max_subdivs_per_device. When subdiv_offsets is empty AND
96   // max_subdivs_per_device = 0 an internal default kMaxSubdivsPerDeviceDefault
97   // is used. When max_subdivs_per_device = -1, no subivision is done.
98   int max_subdivs_per_device = -1;  // Upper bound on subdivisions per device.
99   std::vector<int> subdiv_offsets;
100   std::vector<int> subdiv_source_rank;  // rank of source in each subdiv
101   std::vector<int32>
102       dependencies;           // collective instances on which this node depends
103   string communication_hint;  // user-supplied hint for implementation choice,
104                               // e.g. ring or nccl
105   float timeout_seconds;      // If non zero, set a completion timeout for the
106                               // collective op to detect staleness.
107 };
108 
109 // Data common to all members of a collective instance.
110 // TODO(b/163171014) Refactor this struct to not be a union of all fields.
111 struct CollInstanceParams {
112   // Identifies all participating graph nodes.
113   int32 instance_key = -1;
114   CollectiveType type = UNDEFINED_COLLECTIVE;
115   DataType data_type = DT_FLOAT;
116   TensorShape shape = {0};
117   CollImplDetails impl_details;
118   string ToString() const;
119   CollInstanceParams& operator=(const struct CollInstanceParams& other);
120   std::vector<string> devices;  // permuter only
121 
122   // For permuter only
123   // Each rank in the permutation is a receiver.
124   // Indices of each rank means a sender to that rank.
125   // Example: permutation = {2,0,1} means
126   //   rank 0 sends to rank 2
127   //   rank 1 sends to rank 0
128   //   rank 2 sends to rank 1
129   std::vector<int> permutation;
130 };
131 
132 // Data common to all instance members in the same task.
133 struct CollTaskParams {
134   // True for devices that are local to the process, i.e. no RPC needed.
135   std::vector<bool> is_local;
136   string ToString() const;
137 };
138 
139 // Unique to a single CollectiveOp node.
140 struct CollectiveParams : public core::RefCounted {
141   CollGroupParams group;
142   CollInstanceParams instance;
143   CollTaskParams task;
144 
145   string name = "";        // node name used only for log or error messages
146   int default_rank = -1;   // index of this op within device_names
147   bool is_source = false;  // broadcast only
148   int source_rank = -1;    // broadcast only
149   // Rank of this device in each subdivision permutation.
150   std::vector<int> subdiv_rank;
151   OpKernel* merge_op = nullptr;  // reduction only
152   OpKernel* final_op = nullptr;  // reduction only
153   string ToString() const;
154 };
155 
class CollectiveExecutor;
157 
158 // Interface that provides resolution of device localities.
159 class DeviceResolverInterface {
160  public:
~DeviceResolverInterface()161   virtual ~DeviceResolverInterface() {}
162 
163   // Populates *attributes with the DeviceAttributes of the specified device.
164   virtual Status GetDeviceAttributes(const string& device,
165                                      DeviceAttributes* attributes) = 0;
166 
167   // Returns all device attributes of a task.
168   virtual Status GetAllDeviceAttributes(
169       const string& task, std::vector<DeviceAttributes>* attributes) = 0;
170 
171   // Updates device attributes. It returns error if any device already
172   // exists in the DeviceResolver and has a different incarnation.
173   virtual Status UpdateDeviceAttributes(
174       const std::vector<DeviceAttributes>& attributes) = 0;
175 };
176 
177 // Interface that provides resolution of shared CollectiveParams fields.
178 class ParamResolverInterface {
179  public:
~ParamResolverInterface()180   virtual ~ParamResolverInterface() {}
181 
182   // Called by each collective op at first execution in order to fill out
183   // the CollectiveParams structure with data gathered from the full
184   // (maybe distributed) collection of peer nodes.
185   virtual void CompleteParamsAsync(const DeviceAttributes& device,
186                                    CollectiveParams* cp,
187                                    CancellationManager* cancel_mgr,
188                                    const StatusCallback& done) = 0;
189 
190   // Completes group_params with data gathered from all devices in the group.
191   // This blocks until all devices are there.
192   virtual void CompleteGroupAsync(const DeviceAttributes& device,
193                                   CollGroupParams* group_params,
194                                   CancellationManager* cancel_mgr,
195                                   const StatusCallback& done) = 0;
196 
197   // Used within a distributed implementation to discover/verify data
198   // shared across an instance group.
199   // Note: this works differently from CompleteGroupAsync as a refactor is in
200   // progress.
201   virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
202                                      CompleteInstanceResponse* response,
203                                      CancellationManager* cancel_mgr,
204                                      const StatusCallback& done) = 0;
205 
206   // Aborts the resolver. After abortion the resolver can no longer be used.
207   virtual void StartAbort(const Status& s) = 0;
208 };
209 
210 // Graphs which utilize Collective Ops in a common instance must
211 // execute with identical step_ids even if they are disjoint graphs
212 // run by otherwise independent tasks.  This interface supplies
213 // coordinated step_ids to use in such cases.
214 class StepSequenceInterface {
215  public:
~StepSequenceInterface()216   virtual ~StepSequenceInterface() {}
217 
218   // Used with a distributed implementation to coordinate step_id
219   // sequences across tasks.
220   virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
221                                     GetStepSequenceResponse* response,
222                                     const StatusCallback& done) = 0;
223 
224   // Refresh the local per-graph_key step_id sequence from collective
225   // group leader, if applicable.
226   virtual void RefreshStepIdSequenceAsync(int64_t graph_key,
227                                           const StatusCallback& done) = 0;
228 
229   // Returns the step_id that should be used for initiating a new execution
230   // on the specified graph. May return the same step_id multiple times if
231   // RetireStepId or RefreshStepIdReservation is not called.
232   virtual int64 NextStepId(int64_t graph_key) = 0;
233 
234   // Reports that execution of the given step has completed successfully.
235   // Should be called immediately after a step completes with OK status,
236   // prior to calling NextStepId().  If the step fails, don't call.
237   virtual void RetireStepId(int64_t graph_key, int64_t step_id) = 0;
238 };
239 
class NcclCommunicatorInterface;
241 
242 // Interface that provides access to per-step CollectiveExecutor
243 // instances and various distributed resolution capabilities.
244 class CollectiveExecutorMgrInterface : public StepSequenceInterface {
245  public:
~CollectiveExecutorMgrInterface()246   virtual ~CollectiveExecutorMgrInterface() {}
247 
248   // Returns the step-specific CollectiveExecutor, creating if one does not
249   // already exist.  The caller assumes ownership of one Ref on the object.
250   virtual CollectiveExecutor* FindOrCreate(int64_t step_id) = 0;
251 
252   // If there is a CollectiveExecutor for step_id, remove it from the
253   // table.
254   virtual void Cleanup(int64_t step_id) = 0;
255 
256   virtual ParamResolverInterface* GetParamResolver() const = 0;
257 
258   virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
259 
260   virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0;
261 };
262 
263 // Interface that a Collective Op implementation uses to exchange data
264 // with peers.  Note that data exchange is currently limited to types
265 // for which DMAHelper::CanUseDMA() returns true, i.e.  dense numeric
266 // types.
267 class CollectiveRemoteAccess {
268  public:
~CollectiveRemoteAccess()269   virtual ~CollectiveRemoteAccess() {}
270 
271   virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
272                             bool peer_is_local, const string& key,
273                             Device* to_device, DeviceContext* to_device_ctx,
274                             const AllocatorAttributes& to_alloc_attr,
275                             Tensor* to_tensor,
276                             const DeviceLocality& client_locality,
277                             int dev_to_dev_stream_index,
278                             CancellationManager* cancellation_manager,
279                             const StatusCallback& done) = 0;
280 
281   virtual void PostToPeer(const string& peer_device, const string& peer_task,
282                           const string& key, Device* from_device,
283                           DeviceContext* from_device_ctx,
284                           const AllocatorAttributes& from_alloc_attr,
285                           const Tensor* from_tensor,
286                           const DeviceLocality& client_locality,
287                           CancellationManager* cancellation_manager,
288                           const StatusCallback& done) = 0;
289 
290   // Checks the health of a collective peer. It probes the peer to see if it is
291   // alive. Note that if a peer has restarted, it's considered a different one,
292   // so CheckPeerHealth fails.
293   virtual void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms,
294                                const StatusCallback& done) = 0;
295 
296   virtual BufRendezvous* buf_rendezvous() = 0;
297 
298   virtual void StartAbort(const Status& s) = 0;
299 };
300 
301 // A step-specific object that can execute a collective operation completely
302 // described by a CollectiveParams object.
303 class CollectiveExecutor : public core::RefCounted {
304  public:
StartAbort(const Status & s)305   virtual void StartAbort(const Status& s) {}
306 
ExecuteAsync(OpKernelContext * ctx,const CollectiveParams * col_params,const string & exec_key,StatusCallback done)307   virtual void ExecuteAsync(OpKernelContext* ctx,
308                             const CollectiveParams* col_params,
309                             const string& exec_key, StatusCallback done) {
310     done(errors::Internal(
311         "A collective Op has been called in a context in which "
312         "a CollectiveExecutor has not been provided."));
313   }
314 
CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,StatusCallback done)315   virtual void CompleteParamsAsync(const DeviceAttributes& device,
316                                    CollectiveParams* cp,
317                                    CancellationManager* cancel_mgr,
318                                    StatusCallback done) {
319     done(errors::Internal(
320         "A collective Op has been called in a context in which "
321         "a CollectiveExecutor has not been provided."));
322   }
323 
324   // Runs the potentially-blocking closure/expensive callback.
325   virtual void RunClosure(std::function<void()> closure) = 0;
326 
remote_access()327   virtual CollectiveRemoteAccess* remote_access() { return nullptr; }
328 
329   // `WaitForDependencies` and `Launched` are used for fine-grained control of
330   // execution order between collective instances.  These functions are intended
331   // to be called in `Run` function of collective implementations, and may be
332   // used to make part, or whole, of the collective execution ordered with
333   // respect to other collective instances.
334   //
335   // `WaitForDependencies` will block until it is safe to continue the callee's
336   // execution, where safety is defined as: ordered with respect to the
337   // collective instances defined in the callee's `wait_for` attribute.
WaitForDependencies(const CollectiveParams & col_params)338   virtual void WaitForDependencies(const CollectiveParams& col_params) {}
339   // `UnblockDependencies` unblocks the dependent collective instances by
340   // recording that this caller's device has completed the critical portion of
341   // the collective execution.
UnblockDependencies(const CollectiveParams & col_params)342   virtual void UnblockDependencies(const CollectiveParams& col_params) {}
343 
344   // Used to designate an invalid group or instance key.
345   static int64_t kInvalidId;
346 
347   // Lexically scoped handle for Ref.
348   class Handle {
349    public:
Handle(CollectiveExecutor * ce,bool inherit_ref)350     explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
351       if (!inherit_ref) ce->Ref();
352     }
~Handle()353     ~Handle() { ce_->Unref(); }
get()354     CollectiveExecutor* get() const { return ce_; }
355 
356    private:
357     CollectiveExecutor* ce_;
358   };
359 
360  protected:
CollectiveExecutor(CollectiveExecutorMgrInterface * cem)361   explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
362       : cem_(cem) {}
363 
364   // For use only by derived classes
365   static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
366   CollectiveExecutorMgrInterface* cem_;
367 
368   TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
369 };
370 
371 struct CollectiveContext {
372   CollectiveExecutor* col_exec;                  // Not owned
373   NcclCommunicatorInterface* nccl_communicator;  // Not owned
374   const DeviceMgr* dev_mgr;                      // Not owned
375   OpKernelContext* op_ctx;                       // Not owned
376   OpKernelContext::Params* op_params;            // Not owned
377   const CollectiveParams* col_params;            // Not owned
378   const string exec_key;
379   const int64 step_id;
380   const Tensor* input;  // Not owned
381   Tensor* output;       // Not owned
382   Device* device;       // The device for which this instance labors
383   const string device_name;
384   DeviceLocality device_locality;
385 
386   CollectiveContext(CollectiveExecutor* col_exec,
387                     NcclCommunicatorInterface* nccl_communicator,
388                     const DeviceMgr* dev_mgr, OpKernelContext* ctx,
389                     OpKernelContext::Params* op_params,
390                     const CollectiveParams* col_params, const string& exec_key,
391                     int64_t step_id, const Tensor* input, Tensor* output);
392 };
393 
394 class NcclCommunicatorInterface {
395  public:
396   virtual ~NcclCommunicatorInterface() = default;
397 
398   virtual string GenerateCommunicatorKey() = 0;
399 
400   virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
401                        StatusCallback done) = 0;
402 
403   virtual void StartAbort(const Status& s) = 0;
404 };
405 
406 // Interface of a Collective Op implementation.  Each specific CollectiveOp will
407 // implement this interface and register the implementation via the
408 // CollectiveRegistry detailed below.  See common_runtime/ring_reducer and
409 // common_runtime/hierarchical_tree_broadcaster for examples.
410 class CollectiveImplementationInterface : public core::RefCounted {
411  public:
412   virtual ~CollectiveImplementationInterface() = default;
413 
414   // Initializes the portions of `col_params` specific to this
415   // implementation.  Called exactly once for every Collective instance during
416   // the CollectiveParams resolution process when the graph is first executed,
417   // at the end of `CompleteInstanceLocal()`.
418   // NOTE(ayushd): This is effectively a static function because it modifies the
419   // `col_params` passed in and should not manipulate any data members.  However
420   // because it is virtual and needs to be implemented by every derived class we
421   // do not mark it as static.
422   virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
423 
424   // Prepares the CollectiveContext for executing this CollectiveImplementation.
425   // Called from CollectiveExecutor right before calling Run().  The
426   // CollectiveContext passed in must outlive the CollectiveImplementation
427   // object.
428   virtual Status InitializeCollectiveContext(
429       std::shared_ptr<CollectiveContext> col_ctx) = 0;
430 
431   // Processes and moves data according to the logic of this Collective
432   // implementation.  Relies on appropriate initialization of op-specific
433   // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
434   // context initialization in InitializeCollectiveContext().
435   virtual void Run(StatusCallback done) = 0;
436 };
437 
438 // Static-methods only class for registering and looking up collective
439 // implementations.
440 class CollectiveRegistry {
441  public:
442   using Factory = std::function<CollectiveImplementationInterface*()>;
443   // Looks up a previously registered CollectiveImplementation under
444   // `collective_name`.  If found, creates an instance of the implementation and
445   // assign to `implementation`.
446   static Status Lookup(const string& collective_name,
447                        CollectiveImplementationInterface** implementation);
448 
449   // Looks up a previously registered CollectiveImplementation under
450   // `collective_name`.  If found, returns the static instance of this
451   // implementation via `implementation`.  This instance should only be used to
452   // call InitializateCollectiveParams.
453   static Status LookupParamResolverInstance(
454       const string& collective_name,
455       CollectiveImplementationInterface** implementation);
456 
457   // Returns all registered collective implementations.
458   static void GetAll(
459       std::vector<CollectiveImplementationInterface*>* implementations);
460 
461  private:
462   friend class CollectiveRegistration;
463   // Registers a CollectiveImplementation with name `collective_name` and
464   // factory `factory`.  The latter is a function used to create instances of
465   // the CollectiveImplementation.  Also creates a static instance of the
466   // implementation - this instance is used during param resolution and should
467   // only be used to call InitializeCollectiveParams.
468   static Status Register(const string& collective_name, Factory factory);
469 
470   static Status LookupHelper(const string& collective_name,
471                              CollectiveImplementationInterface** implementation,
472                              bool param_resolver);
473 };
474 
475 // Class used to call CollectiveRegistry::Register.  This should only be used to
476 // create a global static object.
477 class CollectiveRegistration {
478  public:
CollectiveRegistration(const string & collective_name,CollectiveRegistry::Factory factory)479   CollectiveRegistration(const string& collective_name,
480                          CollectiveRegistry::Factory factory) {
481     TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
482   }
483 };
484 
// Defines a file-static CollectiveRegistration object named
// register_<name>_collective, registered under the stringified `name` with a
// factory that returns `new implementation`.  The expansion already ends with
// a semicolon.
#define REGISTER_COLLECTIVE(name, implementation)             \
  static CollectiveRegistration register_##name##_collective( \
      #name, []() { return new implementation; });
488 
489 }  // namespace tensorflow
490 
491 #endif  // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
492