/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_
#define TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <vector>

// TODO(rmlarsen): Get rid of this workaround. "gpu_assert" is defined when
// setting EIGEN_USE_THREADS. But when defining EIGEN_USE_THREADS here,
// incAtomic and other CUDA specific symbols are no longer recognized.
#ifndef gpu_assert
#define gpu_assert(x)
#endif

#include "absl/container/flat_hash_map.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#endif
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

// NCCL manager is used to make the asynchronous communicator calls and to
// manage the per-device streams used for communication.
//
// See nccl_ops.cc for example usage, including description of memory
// management and stream synchronization.
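//
// A minimal, hedged usage sketch for a single-node all-reduce (the names
// `executors`, `streams`, `infos`, `inputs`, `outputs`, and `num_gpus` are
// hypothetical placeholders, not symbols defined in this file; the -1 values
// are illustrative since `global_rank` and `source_rank` are not required for
// single-node collectives; see nccl_ops.cc for the real usage):
//
//   NcclManager* nccl = NcclManager::instance();
//   const string collective_key = "allreduce_step_0";  // any unique key
//   const string communicator_key = "";  // empty: single-node collective
//   for (int i = 0; i < num_gpus; ++i) {
//     auto participant = std::make_unique<NcclManager::Participant>(
//         executors[i], streams[i], infos[i], &inputs[i], &outputs[i],
//         /*global_rank=*/-1, [](Status s) { TF_CHECK_OK(s); });
//     NcclManager::Context context(collective_key,
//                                  /*num_local_devices=*/num_gpus,
//                                  /*num_global_devices=*/num_gpus,
//                                  communicator_key, /*source_rank=*/-1);
//     nccl->AddToAllReduce(std::move(participant), context, ncclSum);
//   }
//   // Each `done_callback` fires once that participant's result is visible to
//   // subsequent work on its `tensor_stream`.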
class NcclManager {
 public:
  typedef std::function<void(Status)> DoneCallback;
  NcclManager();
  ~NcclManager();

  static NcclManager* instance();

#if TENSORFLOW_USE_ROCM
  static int instance_count;
#endif

  // Calls `ncclGetUniqueId` and returns the id as a string.  The returned value
  // may be shared with other participants on different nodes and passed in to
  // multi-node collective invocations.
  string GenerateCommunicatorKey();
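  //
  // A hedged sketch of the multi-node handshake. How the key is distributed is
  // up to the caller; `ShareWithAllNodes` below is a hypothetical placeholder
  // (e.g. the cluster's own RPC layer), not an API defined in this file:
  //
  //   string communicator_key =
  //       NcclManager::instance()->GenerateCommunicatorKey();  // on one node
  //   ShareWithAllNodes(communicator_key);
  //   // Every node then passes `communicator_key` in the Context of each
  //   // multi-node collective, adds its local participants via the Add*
  //   // functions, and finally calls SignalMultiNodeReady(collective_key).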

  // A participant in a Collective.
  struct Participant {
    Participant(se::StreamExecutor* executor, se::Stream* tensor_stream,
                const DeviceBase::GpuDeviceInfo* info, const Tensor* input,
                Tensor* output, int global_rank, DoneCallback done_callback)
        : executor(executor),
          tensor_stream(tensor_stream),
          event_mgr(info->event_mgr),
          gpu_device_id(info->gpu_id),
#if TENSORFLOW_USE_ROCM
          context(static_cast<GPUDeviceContext*>(info->default_context)),
#endif
          input(input),
          output(output),
          global_rank(global_rank),
          done_callback(std::move(done_callback)),
          root(false) {
      DCHECK(executor != nullptr);
      DCHECK(event_mgr != nullptr);
      DCHECK(tensor_stream != nullptr);
    }

    // StreamExecutor for the device. Expected to be live for process lifetime.
    se::StreamExecutor* const executor = nullptr;

    // `tensor_stream` is the stream that should be waited on to ensure
    // `input`'s data is available on the GPU for the communication stream to
    // access. It is also the stream that will use the produced data;
    // `done_callback` is not called until the next kernel launched on
    // `tensor_stream` would see the data. Owned by the caller, who must keep
    // it live until `done_callback` is called.
    se::Stream* const tensor_stream;

    // EventMgr which polls on executor.
    // Owned by the caller, who must keep it live until `done_callback` is
    // called.
    EventMgr* const event_mgr;

    const int gpu_device_id;

#if TENSORFLOW_USE_ROCM
    GPUDeviceContext* const context;
#endif

    // Owned by the caller, who must keep it live until `done_callback` is
    // called. Is NULL for participants that only receive data.
    const Tensor* input;

    // Owned by the caller, who must keep it live until `done_callback` is
    // called. Is NULL for participants that only send data.
    Tensor* output;

    // Rank across all devices and all nodes.
    // `global_rank` is not required for single-node collectives.
    const int global_rank;

    // The callback which is called at the completion of the NCCL operation.
    // When called, `output` has been set to the result of the operation. (note:
    // the stream may not yet have been synced)
    DoneCallback done_callback;

    // True if this is the root of the collective, e.g. source of broadcast.
    bool root;
  };

  // Data that provides context for the collective operation, including the
  // operation key, number of participants, and communicator key.
  struct Context {
    Context(const string& collective_key, int num_local_devices,
            int num_global_devices, const string& communicator_key,
            int source_rank)
        : collective_key(collective_key),
          num_local_devices(num_local_devices),
          num_global_devices(num_global_devices),
          communicator_key(communicator_key),
          source_rank(source_rank) {}

    // Unique key for this collective instance
    const string& collective_key;

    // Devices local to this node
    int num_local_devices;

    // Devices across all nodes
    int num_global_devices;

    // In order to use NCCL across nodes, the caller first has to generate a
    // `communicator_key` via `GenerateCommunicatorKey()` and share it with all
    // the other nodes.  Each node should pass in this `communicator_key` to
    // the `NcclManager` functions.
    // `communicator_key` is not required for single-node collectives and can
    // be empty.
    const string& communicator_key;

    // Rank of broadcast source.
    int source_rank;
  };
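
  // A hedged example of a Context for a 2-node, 8-GPUs-per-node broadcast
  // where global rank 0 is the source (all concrete values are illustrative,
  // and the named strings must outlive the Context because it holds
  // references):
  //
  //   const string collective_key = "bcast_step_0";
  //   NcclManager::Context context(collective_key,
  //                                /*num_local_devices=*/8,
  //                                /*num_global_devices=*/16,
  //                                communicator_key,  // from GenerateCommunicatorKey()
  //                                /*source_rank=*/0);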

  // Adds one participant to an all-reduce.
  void AddToAllReduce(std::unique_ptr<Participant> participant,
                      const Context& context, ncclRedOp_t reduction_op);

  // Adds one participant to an all-gather.
  void AddToAllGather(std::unique_ptr<Participant> participant,
                      const Context& context);

  // AddBroadcastSend and AddBroadcastRecv combine to send data from one sender
  // to all receivers.
  void AddBroadcastSend(std::unique_ptr<Participant> participant,
                        const Context& context);
  void AddBroadcastRecv(std::unique_ptr<Participant> participant,
                        const Context& context);
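  //
  // A hedged sketch of how the two calls pair up (participants and contexts
  // are built as in the examples above; `send_participant`, `recv_participant`,
  // and `nccl` are hypothetical placeholders):
  //
  //   // On the sending device: `input` is set; `output` may be null.
  //   nccl->AddBroadcastSend(std::move(send_participant), context);
  //   // On each receiving device: `input` is null; `output` is set.
  //   nccl->AddBroadcastRecv(std::move(recv_participant), context);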

  // AddReduceSend and AddReduceRecv combine to send data from all senders
  // to one receiver.
  void AddReduceSend(std::unique_ptr<Participant> participant,
                     const Context& context, ncclRedOp_t reduction_op);
  void AddReduceRecv(std::unique_ptr<Participant> participant,
                     const Context& context, ncclRedOp_t reduction_op);

  // Signals that the `Collective` corresponding to `collective_key` is ready
  // to launch across all nodes participating in this multi-node collective
  // operation.
  //
  // This should only be called for multi-node collectives; single-node
  // collectives are implicitly ready when all participants have called an Add*
  // function.
  void SignalMultiNodeReady(const string& collective_key);

  // Aborts all collectives. After abortion, no further collectives can be
  // launched with this NcclManager.
  void StartAbort(const Status& s);

  // Resets a previously aborted NcclManager, making it available for future
  // collectives.
  void Reset();
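  //
  // A hedged sketch of the abort-and-recover flow (the Status passed to
  // StartAbort is illustrative):
  //
  //   NcclManager::instance()->StartAbort(errors::Aborted("step cancelled"));
  //   // No further collectives can be launched with this NcclManager until:
  //   NcclManager::instance()->Reset();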

 private:
  enum CollectiveType {
    kAllReduce = 1,
    kBroadcast = 2,
    kReduce = 3,
    kAllGather = 4,
  };
  struct Collective;
  struct Communicator;
  struct CommunicatorMember;
  struct NcclStream;

  // Gets the `Communicator` object that will be used to enqueue NCCL kernels
  // for `collective`, and returns it via `communicator`.
  //
  // This may involve creating CUDA streams and NCCL initialization.  If an
  // NCCL or CUDA error occurs in the process, this returns an INTERNAL error
  // with the corresponding NCCL/CUDA error string.
  Status GetCommunicator(Collective* collective, Communicator** communicator);

  // Adds a participant device to the local `Collective` instance corresponding
  // to `collective_key`.  Launches the `Collective` if it is ready, which it
  // checks by calling `CheckReady()`.  Also performs consistency and sanity
  // checks before launching.
  void AddParticipant(std::unique_ptr<Participant> participant,
                      const Context& context, CollectiveType collective_type,
                      ncclRedOp_t reduction_op);

  // If `collective` is ready to run, removes it from the `collectives_` map and
  // returns true.  Otherwise returns false.
  // Assumes `collective_key` corresponds to `collective`.
  //
  // A collective is ready to run when all local participants have called an
  // Add* function, and the collective is signalled globally ready via
  // `SignalMultiNodeReady`.
  bool CheckReady(const string& collective_key, Collective* collective)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Run <collective>.  This call takes ownership of <collective>.
  void RunCollective(Collective* collective);
  void LoopKernelLaunches(NcclStream* stream);

  mutex mu_;

  // Maps key to collectives currently being assembled or run.
  absl::flat_hash_map<string, Collective*> collectives_ TF_GUARDED_BY(mu_);

  // Maps a device to the communication streams that make up its collective.
  // This is used to share the stream across different communicators that
  // include the same device.
  absl::flat_hash_map<se::StreamExecutor*, std::vector<NcclStream*>>
      device_to_comm_streams_ TF_GUARDED_BY(mu_);

  std::vector<std::unique_ptr<Communicator>> communicators_ TF_GUARDED_BY(mu_);

  Status status_ TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(NcclManager);
};

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_