1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifdef TENSORFLOW_USE_MPI
17
18 #include <queue>
19 #include <thread>
20 #include <unordered_map>
21
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/mutex.h"
26
27 #define EIGEN_USE_THREADS
28
29 #if GOOGLE_CUDA
30 #include <cuda_runtime.h>
31 #include "tensorflow/stream_executor/stream.h"
32 #endif
33
34 #include "tensorflow/stream_executor/lib/statusor.h"
35
36 #define OMPI_SKIP_MPICXX
37 #include "third_party/mpi/mpi.h"
38 #include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
39 #include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h"
40
41 /*
42 * MPI Allreduce and Allgather Ops for TensorFlow.
43 *
44 * TensorFlow natively provides inter-device communication through send and
45 * receive ops and inter-node communication through Distributed TensorFlow,
46 * based on the same send and receive abstractions. These end up being
47 * insufficient for synchronous data-parallel training on HPC clusters where
48 * InfiniBand or other high-speed interconnects are available. This module
49 * implements MPI ops for allgather and allreduce, which do bandwidth-optimal
50 * gathers and reductions and can take advantage of hardware-optimized
51 * communication libraries through the MPI implementation.
52 *
53 * The primary logic of the allreduce and allgather is in RingAllgather() and
54 * RingAllreduce(). The background thread which facilitates MPI operations is
55 * run in BackgroundThreadLoop(). The provided MPI ops are:
56 *   – MPIInit:
57 *       Initialize MPI on a given device (CPU or GPU).
58 *       Should only be run on a single device in every process.
59 *   – MPISize:
60 *       Get the number of MPI processes in the global communicator.
61 *   – MPIRank:
62 *       Get the rank of the current MPI process in the global communicator.
63 *   – MPILocalRank:
64 *       Get the local rank of the current MPI process within its node.
65 *   – MPIAllreduce:
66 *       Perform an allreduce on a Tensor, returning the sum
67 *       across all MPI processes in the global communicator.
68 *   – MPIAllgather:
69 *       Perform an allgather on a Tensor, returning the concatenation of
70 *       the tensor on the first dimension across all MPI processes in the
71 *       global communicator.
72 *
73 */
74
75 template <class T>
76 using StatusOr = stream_executor::port::StatusOr<T>;
77
78 using CPUDevice = Eigen::ThreadPoolDevice;
79 using GPUDevice = Eigen::GpuDevice;
80
81 namespace tensorflow {
82 namespace contrib {
83 namespace mpi_collectives {
84
85 // Make sure template specializations are generated in the ring.cu.cc and
86 // ring.cc files, not in this file.
87 extern template Status RingAllreduce<GPUDevice, int>(OpKernelContext*,
88 const Tensor*, Tensor*,
89 Tensor*);
90 extern template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
91 const Tensor*,
92 Tensor*, Tensor*);
93 extern template Status RingAllreduce<GPUDevice, float>(OpKernelContext*,
94 const Tensor*, Tensor*,
95 Tensor*);
96 extern template Status RingAllgather<GPUDevice, int>(OpKernelContext*,
97 const Tensor*,
98 const std::vector<size_t>&,
99 Tensor*);
100 extern template Status RingAllgather<GPUDevice, long long>(
101 OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
102 extern template Status RingAllgather<GPUDevice, float>(
103 OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
104 extern template Status RingAllreduce<CPUDevice, int>(OpKernelContext*,
105 const Tensor*, Tensor*,
106 Tensor*);
107 extern template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
108 const Tensor*,
109 Tensor*, Tensor*);
110 extern template Status RingAllreduce<CPUDevice, float>(OpKernelContext*,
111 const Tensor*, Tensor*,
112 Tensor*);
113 extern template Status RingAllgather<CPUDevice, int>(OpKernelContext*,
114 const Tensor*,
115 const std::vector<size_t>&,
116 Tensor*);
117 extern template Status RingAllgather<CPUDevice, long long>(
118 OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
119 extern template Status RingAllgather<CPUDevice, float>(
120 OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
121
122 namespace {
123
124 // Return true if the templated type is GPUDevice, otherwise false.
125 template <typename T>
126 bool IsGPUDevice();
127 template <>
128 bool IsGPUDevice<GPUDevice>() {
129 return true;
130 };
131 template <>
132 bool IsGPUDevice<CPUDevice>() {
133 return false;
134 };
135
136 // A callback to call after the MPI communication completes. Since the
137 // allreduce and allgather ops are asynchronous, this callback is what resumes
138 // computation after the reduction is completed.
139 typedef std::function<void(StatusOr<Tensor>)> CommunicationDoneCallback;
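// For illustration only (not part of the original file), a caller might
// supply a callback such as:
//   CommunicationDoneCallback cb = [](StatusOr<Tensor> result) {
//     // Propagate result.status() back to the op's kernel context.
//   };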
140
141 struct CollectiveOpRecord {
142 // The rank performing this piece of the op
143 int rank;
144
145 // The name of the op/tensor to be reduced
146 std::string name;
147
148 // The op's kernel context
149 OpKernelContext* context;
150
151 // Data type of the op
152 DataType dtype;
153
154 // The input tensor
155 const Tensor* in_t;
156
157 // Allgather: Vector of per-rank first-dimension sizes
158 std::vector<size_t> sizes_vec;
159
160 // The temp tensor for intermediate results
161 Tensor temp_t;
162
163 // The output tensor
164 Tensor* out_t;
165
166 // Whether to run this op on the gpu
167 bool on_gpu;
168
169 // The callback to call after the op has completed
170 CommunicationDoneCallback callback;
171 };
172
173 // Table storing Tensors to be reduced, keyed by unique name.
174 // This table contains everything necessary to do the reduction
175 typedef std::unordered_map<std::string, CollectiveOpRecord> TensorTable;
176
177 // Table for storing Tensor metadata on rank zero. This is used for error
178 // checking and size calculations, as well as determining when a reduction is
179 // ready to be done (when all nodes are ready to do it).
180 typedef std::unordered_map<std::string, std::vector<MPIRequest> > MessageTable;
181
182 // The global state required for the MPI ops.
183 //
184 // MPI is a library that stores a lot of global per-program state and often
185 // requires running on a single thread. As a result, we have to have a single
186 // background thread responsible for all MPI operations, and communicate with
187 // that background thread through global state.
188 struct MPIGlobalState {
189 // An atomic boolean which is set to true when MPI is initialized.
190 // This ensures that MPI_Init is never called twice.
191 std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;
192
193 // Condition variable to wait for initialization
194 condition_variable cv;
195
196 // Whether MPI_Init has been completed on the background thread.
197 bool initialization_done = false;
198
199 // Whether MPI_Init succeeded on the background thread.
200 Status init_status;
201
202 // A mutex that needs to be used whenever MPI operations touch
203 // shared structures.
204 mutex mu;
205
206 // Tensors waiting to be allreduced or allgathered.
207 TensorTable tensor_table;
208
209 // Queue of MPI requests waiting to be sent to the coordinator node.
210 std::queue<MPIRequest> message_queue;
211
212 // Background thread running MPI communication.
213 std::thread background_thread;
214
215 // Whether the background thread should shutdown.
216 bool shut_down = false;
217
218 // Only exists on the coordinator node (rank zero). Maintains a count of
219 // how many nodes are ready to allreduce every tensor (keyed by tensor
220 // name).
221 std::unique_ptr<MessageTable> message_table;
222
223 // The MPI rank, local rank, and size.
224 int rank = 0;
225 int local_rank = 0;
226 int size = 1;
227
228 // The device that MPI was initialized on. (-1 for no GPU)
229 int device = -1;
230
231 // The CUDA stream used for data transfers and within-allreduce operations.
232 // A naive implementation would use the TensorFlow StreamExecutor CUDA
233 // stream. However, the allreduce and allgather require doing memory copies
234 // and kernel executions (for accumulation of values on the GPU), and the
235 // subsequent MPI operations must wait for those copies and kernels to complete;
236 // otherwise MPI (which uses its own stream internally) will begin the data
237 // transfers before the CUDA calls are complete. In order to wait for those
238 // CUDA operations, if we were using the TensorFlow stream, we would have
239 // to synchronize that stream; however, other TensorFlow threads may be
240 // submitting more work to that stream, so synchronizing on it can cause
241 // the allreduce to be delayed, waiting for compute totally unrelated to it
242 // in other parts of the graph. Overlapping memory transfers and compute
243 // during backpropagation is crucial for good performance, so we cannot use
244 // the TensorFlow stream, and must use our own stream.
245 #if GOOGLE_CUDA
246 cudaStream_t stream;
247 std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT;
248 #endif
249
250 ~MPIGlobalState() {
251 // Make sure that the destructor of the background thread is safe to
252 // call. If a thread is still joinable (not detached or complete) its
253 // destructor cannot be called.
254 if (background_thread.joinable()) {
255 shut_down = true;
256 background_thread.join();
257 }
258 }
259 };
260
261 // All the MPI state that must be stored globally per-process.
262 static MPIGlobalState mpi_global;
263
264 // For clarity in argument lists.
265 #define RANK_ZERO 0
266
267 // A tag used for all coordinator messaging.
268 #define TAG_NOTIFY 1
269
270 // Store the MPIRequest for a name, and return whether the total count of
271 // MPIRequests for that tensor is now equal to the MPI size (and thus we are
272 // ready to reduce the tensor).
273 bool IncrementTensorCount(std::unique_ptr<MessageTable>& message_table,
274 MPIRequest msg, int mpi_size) {
275 auto name = msg.tensor_name();
276 auto table_iter = message_table->find(name);
277 if (table_iter == message_table->end()) {
278 message_table->emplace(name, std::vector<MPIRequest>({msg}));
279 table_iter = message_table->find(name);
280 } else {
281 table_iter->second.push_back(msg);
282 }
283
284 int count = table_iter->second.size();
285 return count == mpi_size;
286 }
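// Illustrative sketch (not called anywhere in this file): with an MPI world
// size of two, the first call for a given tensor name returns false and the
// second returns true, at which point the coordinator can construct and
// broadcast the response. Here `request` stands for an incoming MPIRequest:
//
//   bool ready = IncrementTensorCount(message_table, request, /*mpi_size=*/2);
//   if (ready) {
//     MPIResponse r = ConstructMPIResponse(message_table, request.tensor_name());
//   }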
287
288 // Once a tensor is ready to be reduced, the coordinator sends an MPIResponse
289 // to all ranks instructing them to start the reduction. The MPIResponse
290 // also contains error messages in case the submitted MPIRequests were not
291 // valid (for example, contained mismatched shapes or types).
292 //
293 // Constructing the MPIResponse, thus, requires a whole lot of error checking.
294 MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
295 std::string name) {
296 bool error = false;
297 auto it = message_table->find(name);
298 assert(it != message_table->end());
299
300 std::vector<MPIRequest> requests = it->second;
301 assert(requests.size() > 0);
302
303 std::ostringstream error_message_stream;
304
305 // Check that all data types being reduced or gathered are identical
306 auto data_type = requests[0].tensor_type();
307 for (unsigned int i = 1; i < requests.size(); i++) {
308 auto request_type = requests[i].tensor_type();
309 if (data_type != request_type) {
310 error = true;
311 error_message_stream << "Mismatched data types: One rank had type "
312 << DataType_Name(data_type)
313 << ", but another rank had type "
314 << DataType_Name(request_type) << ".";
315 break;
316 }
317 }
318
319 // Check that all requested operations are the same
320 auto message_type = requests[0].request_type();
321 for (unsigned int i = 1; i < requests.size(); i++) {
322 if (error) {
323 break;
324 }
325
326 auto request_type = requests[i].request_type();
327 if (message_type != request_type) {
328 error = true;
329 error_message_stream << "Mismatched MPI operations: One rank did an "
330 << message_type << ", but another rank did an "
331 << request_type << ".";
332 break;
333 }
334 }
335
336 // If we are doing an allreduce, check that all tensor shapes
337 // are identical
338 if (message_type == MPIRequest::ALLREDUCE) {
339 TensorShape tensor_shape = requests[0].tensor_shape();
340 for (unsigned int i = 1; i < requests.size(); i++) {
341 if (error) {
342 break;
343 }
344
345 TensorShape request_shape = requests[i].tensor_shape();
346 if (tensor_shape != request_shape) {
347 error = true;
348 error_message_stream << "Mismatched allreduce tensor shapes: "
349 << "One rank reduced a tensor of shape "
350 << tensor_shape.DebugString()
351 << ", but another rank sent a tensor of shape "
352 << request_shape.DebugString() << ".";
353 break;
354 }
355 }
356 }
357
358 // If we are doing an allgather, make sure all but the first dimension are
359 // the same. The first dimension may differ; the output tensor's first
360 // dimension is the sum of the per-rank first dimensions. Collect sizes by rank.
361 if (message_type == MPIRequest::ALLGATHER) {
362 TensorShape tensor_shape = requests[0].tensor_shape();
363
364 if (tensor_shape.dims() == 0) {
365 error = true;
366 error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
367 }
368
369 for (unsigned int i = 1; i < requests.size(); i++) {
370 if (error) {
371 break;
372 }
373
374 TensorShape request_shape = requests[i].tensor_shape();
375 if (tensor_shape.dims() != request_shape.dims()) {
376 error = true;
377 error_message_stream << "Mismatched allgather tensor shapes: "
378 << "One rank gathered a tensor of rank "
379 << tensor_shape.dims()
380 << ", but another rank sent a tensor of rank "
381 << request_shape.dims() << ".";
382 break;
383 }
384
385 for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) {
386 if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
387 error = true;
388 error_message_stream
389 << "Mismatched allgather tensor shapes: "
390 << "One rank gathered a tensor with dimension " << dim
391 << " equal to " << tensor_shape.dim_size(dim)
392 << ", but another rank sent a tensor with dimension " << dim
393 << " equal to " << request_shape.dim_size(dim) << ".";
394 break;
395 }
396 }
397 }
398 }
399
400 MPIResponse response;
401 response.set_tensor_name(name);
402 if (error) {
403 std::string error_message = error_message_stream.str();
404 response.set_response_type(MPIResponse::ERROR);
405 response.set_error_message(error_message);
406 } else {
407 auto response_type = MPIResponse::ERROR;
408 if (message_type == MPIRequest::ALLREDUCE) {
409 response_type = MPIResponse::ALLREDUCE;
410 } else {
411 response_type = MPIResponse::ALLGATHER;
412 }
413 response.set_response_type(response_type);
414 }
415
416 // Clear all queued up requests for this name. They are now taken care of
417 // by the constructed MPI response.
418 message_table->erase(it);
419
420 return response;
421 }
422
423 // Process an MPIResponse by doing a reduction, a gather, or raising an error.
424 void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
425 OpKernelContext* context;
426 const Tensor* input_tensor;
427 std::vector<size_t> sizes_vec;
428 Tensor temp_tensor;
429 Tensor* output_tensor;
430 CommunicationDoneCallback callback;
431 bool on_gpu;
432 {
433 // Lock on the tensor table.
434 mutex_lock guard(mpi_global.mu);
435
436 // We should never fail at finding this key in the tensor table.
437 auto name = response.tensor_name();
438 auto iter = tensor_table.find(name);
439 assert(iter != tensor_table.end());
440
441 assert(response.response_type() == MPIResponse::ALLREDUCE ||
442 response.response_type() == MPIResponse::ALLGATHER ||
443 response.response_type() == MPIResponse::ERROR);
444
445 CollectiveOpRecord record = iter->second;
446 context = record.context;
447 input_tensor = record.in_t;
448 sizes_vec = record.sizes_vec;
449 temp_tensor = record.temp_t;
450 output_tensor = record.out_t;
451 on_gpu = record.on_gpu;
452 callback = record.callback;
453
454 // Clear the tensor table of this tensor and its callbacks; the rest of
455 // this function takes care of it.
456 tensor_table.erase(iter);
457 }
458
459 // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
460 // link to non-existent symbols.
461 #if GOOGLE_CUDA
462 #define GPU_DEVICE_IF_CUDA GPUDevice
463 #else
464 #define GPU_DEVICE_IF_CUDA CPUDevice
465 #endif
466
467 Status status;
468 auto dtype = input_tensor->dtype();
469 if (response.response_type() == MPIResponse::ALLGATHER) {
470 if (dtype == DT_FLOAT) {
471 status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
472 context, input_tensor, sizes_vec, output_tensor)
473 : RingAllgather<CPUDevice, float>(
474 context, input_tensor, sizes_vec, output_tensor);
475 } else if (dtype == DT_INT32) {
476 status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
477 context, input_tensor, sizes_vec, output_tensor)
478 : RingAllgather<CPUDevice, int>(context, input_tensor,
479 sizes_vec, output_tensor);
480 } else if (dtype == DT_INT64) {
481 status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
482 context, input_tensor, sizes_vec, output_tensor)
483 : RingAllgather<CPUDevice, long long>(
484 context, input_tensor, sizes_vec, output_tensor);
485 } else {
486 status = errors::Unknown("Invalid tensor type for MPI allgather.");
487 }
488 } else if (response.response_type() == MPIResponse::ALLREDUCE) {
489 if (dtype == DT_FLOAT) {
490 status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
491 context, input_tensor, &temp_tensor, output_tensor)
492 : RingAllreduce<CPUDevice, float>(
493 context, input_tensor, &temp_tensor, output_tensor);
494 } else if (dtype == DT_INT32) {
495 status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
496 context, input_tensor, &temp_tensor, output_tensor)
497 : RingAllreduce<CPUDevice, int>(
498 context, input_tensor, &temp_tensor, output_tensor);
499 } else if (dtype == DT_INT64) {
500 status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
501 context, input_tensor, &temp_tensor, output_tensor)
502 : RingAllreduce<CPUDevice, long long>(
503 context, input_tensor, &temp_tensor, output_tensor);
504 } else {
505 status = errors::Unknown("Invalid tensor type for MPI allreduce.");
506 }
507 } else if (response.response_type() == MPIResponse::ERROR) {
508 status = errors::FailedPrecondition(response.error_message());
509 }
510
511 if (status.ok()) {
512 callback(StatusOr<Tensor>(*output_tensor));
513 } else {
514 callback(StatusOr<Tensor>(status));
515 }
516 }
517
518 // The MPI background thread loop coordinates all the MPI processes and the
519 // tensor reductions. The design of the communicator mechanism is limited by a
520 // few considerations:
521 //
522 // 1. Some MPI implementations require all MPI calls to happen from a
523 // single thread. Since TensorFlow may use several threads for graph
524 // processing, this means we must have our own dedicated thread for
525 // dealing with MPI.
526 // 2. We want to gracefully handle errors, when MPI processes do not
527 // properly agree upon what should happen (such as mismatched types or
528 // shapes). To do so requires the MPI processes to know about the shapes
529 // and types of the relevant tensors on the other processes.
530 // 3. The MPI reductions and gathers should be able to happen in parallel
531 // with other ongoing operations. Since MPI uses an internal
532 // (inaccessible) GPU stream separate from the TF GPUDevice streams, we
533 // cannot explicitly synchronize memcpys or kernels with it. As a result,
534 // MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper
535 // ordering of memcpys and kernels with respect to TF streams.
536 // 4. NOTE: We cannot guarantee that all the MPI processes reduce their
537 // tensors in the same order. Thus, there must be a way to ensure the
538 // reduction memcpys and kernels occur for correct tensors across all
539 // ranks at the same time. We choose to use a coordinator (rank ID 0) to
540 // gather and trigger the reduction operations that are ready to execute.
541 //
542 // The coordinator currently follows a master-worker paradigm. Rank zero acts
543 // as the master (the "coordinator"), whereas all other ranks are simply
544 // workers. Each rank runs its own background thread which progresses in ticks.
545 // In each tick, the following actions happen:
546 //
547 // a) The workers send any available MPIRequests to the coordinator. These
548 // MPIRequests indicate what the worker would like to do (i.e. which
549 // tensor they would like to gather or reduce, as well as their shape and
550 // type). They repeat this for every tensor that they would like to
551 // operate on after that tensor's collective op has executed ComputeAsync.
552 //
553 // b) The workers send an empty "DONE" message to the coordinator to
554 // indicate that there are no more tensors they wish to operate on.
555 //
556 // c) The coordinator receives the MPIRequests from the workers, as well
557 // as from its own TensorFlow ops, and stores them in a request table. The
558 // coordinator continues to receive MPIRequest messages until it has
559 // received MPI_SIZE number of empty "DONE" messages.
560 //
561 // d) The coordinator finds all tensors that are ready to be reduced or
562 // gathered, as well as any that resulted in an error. For each of those,
563 // it sends an MPIResponse to all the workers. When no more MPIResponses
564 // are available, it sends a "DONE" response to the workers. If the
565 // process is being shutdown, it instead sends a "SHUTDOWN" response.
566 //
567 // e) The workers listen for MPIResponse messages, processing each one by
568 // doing the required reduce or gather, until they receive a "DONE"
569 // response from the coordinator. At that point, the tick ends.
570 // If instead of "DONE" they receive "SHUTDOWN", they exit their
571 // background loop.
572 // TODO: Use the global mpi_global state variable instead of a local one
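// Note on the message framing used below: workers signal DONE to the
// coordinator with a zero-length MPI message, while the coordinator signals
// DONE or SHUTDOWN to the workers with a serialized MPIResponse whose
// response_type is DONE or SHUTDOWN. All control messages use TAG_NOTIFY.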
573 void BackgroundThreadLoop() {
574 #if GOOGLE_CUDA
575 // Set the device, so that this thread uses the same GPU context as the
576 // calling thread.
577 // TODO: Ensure that this is operating correctly. The background thread
578 // needs to be able to control all GPUs that the rank has access to, and
579 // might be more than 1 GPU. Tensors could be resident in any of the
580 // GPUs, so the background thread's accumulate and copy kernels might need
581 // to correctly set the device and it might be necessary for the background
582 // thread to manage multiple streams.
583 cudaSetDevice(mpi_global.device);
584 cudaStreamCreate(&mpi_global.stream);
585 #endif
586
587 // Initialize MPI. This must happen on the background thread, since not all
588 // MPI implementations support being called from multiple threads.
589 auto init_result = MPI_Init(NULL, NULL);
590 if (init_result != MPI_SUCCESS) {
591 mpi_global.init_status =
592 errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
593 mpi_global.initialization_done = true;
594 mpi_global.cv.notify_all();
595 return;
596 } else {
597 mpi_global.init_status = Status::OK();
598 }
599
600 // Get MPI rank to determine if we are rank zero.
601 int rank;
602 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
603 bool is_coordinator = rank == 0;
604
605 // Get MPI size to determine how many tensors to wait for before reducing.
606 int size;
607 MPI_Comm_size(MPI_COMM_WORLD, &size);
608
609 // Determine local rank by querying the local communicator.
610 MPI_Comm local_comm;
611 MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
612 &local_comm);
613 int local_rank;
614 MPI_Comm_rank(local_comm, &local_rank);
615
616 mpi_global.rank = rank;
617 mpi_global.local_rank = local_rank;
618 mpi_global.size = size;
619 mpi_global.initialization_done = true;
620
621 // Notify calling thread that initialization is complete
622 mpi_global.cv.notify_all();
623
624 // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
625 // Initialize the tensor count table. No tensors are available yet.
626 if (is_coordinator) {
627 mpi_global.message_table =
628 std::unique_ptr<MessageTable>(new MessageTable());
629 }
630
631 // The coordinator sends a SHUTDOWN message to trigger shutdown.
632 bool should_shut_down = false;
633 do {
634 // TODO: Eliminate the need for thread sleep by making all activity
635 // depend on other activity (e.g. condition or MPI waits).
636 std::this_thread::sleep_for(std::chrono::milliseconds(1));
637
638 // Copy the data structures from global state under this lock.
639 // However, don't keep the lock for the rest of the loop, so that
640 // enqueued stream callbacks can continue.
641 std::queue<MPIRequest> message_queue;
642 {
643 mutex_lock guard(mpi_global.mu);
644 while (!mpi_global.message_queue.empty()) {
645 MPIRequest message = mpi_global.message_queue.front();
646 mpi_global.message_queue.pop();
647 message_queue.push(message);
648 }
649 }
650
651 // Collect all tensors that are ready to be reduced. Record them in the
652 // tensor count table (rank zero) or send them to rank zero to be
653 // recorded (everyone else).
654 std::vector<std::string> ready_to_reduce;
655 while (!message_queue.empty()) {
656 // Pop the first available message
657 MPIRequest message = message_queue.front();
658 message_queue.pop();
659
660 if (is_coordinator) {
661 bool reduce =
662 IncrementTensorCount(mpi_global.message_table, message, size);
663 if (reduce) {
664 ready_to_reduce.push_back(message.tensor_name());
665 }
666 } else {
667 std::string encoded_message;
668 message.SerializeToString(&encoded_message);
669 MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
670 MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
671 }
672 }
673
674 // Rank zero has put all its own tensors in the tensor count table.
675 // Now, it should count all the tensors that are coming from other
676 // ranks at this tick. It should keep getting tensors until it gets a
677 // DONE message from all the other ranks.
678 if (is_coordinator) {
679 // Count of DONE messages. Keep receiving messages until the number
680 // of messages is equal to the number of processes. Initialize to
681 // one since the coordinator is effectively done.
682 int completed_ranks = 1;
683 while (completed_ranks != size) {
684 MPI_Status status;
685 MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);
686
687 // Find number of characters in message (including zero byte).
688 int source_rank = status.MPI_SOURCE;
689 int msg_length;
690 MPI_Get_count(&status, MPI_BYTE, &msg_length);
691
692 // If the length is zero, this is a DONE message.
693 if (msg_length == 0) {
694 completed_ranks++;
695 MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
696 &status);
697 continue;
698 }
699
700 // Read the serialized MPIRequest from MPI into an std::string.
701 char* buffer = new char[msg_length];
702 MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
703 MPI_COMM_WORLD, &status);
704 std::string received_data(buffer, msg_length - 1);  // exclude trailing zero byte
705 delete[] buffer;
706
707 MPIRequest received_message;
708 received_message.ParseFromString(received_data);
709 auto received_name = received_message.tensor_name();
710
711 bool reduce = IncrementTensorCount(mpi_global.message_table,
712 received_message, size);
713 if (reduce) {
714 ready_to_reduce.push_back(received_name);
715 }
716 }
717
718 // At this point, rank zero should have a fully updated tensor
719 // count table and should know all the tensors that need to be
720 // reduced or gathered, and everyone else should have sent all
721 // their information to rank zero. We can now do reductions and
722 // gathers; rank zero will choose which ones and in what order,
723 // and will notify the other ranks before doing each reduction.
724 for (int i = 0; i < ready_to_reduce.size(); i++) {
725 // Notify all nodes which tensor we'd like to reduce now
726 auto name = ready_to_reduce[i];
727 MPIResponse response =
728 ConstructMPIResponse(mpi_global.message_table, name);
729
730 std::string encoded_response;
731 response.SerializeToString(&encoded_response);
732 for (int r = 1; r < size; r++) {
733 MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
734 MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
735 }
736
737 // Perform the reduction. All nodes should end up performing
738 // the same reduction.
739 PerformCollectiveOp(mpi_global.tensor_table, response);
740 }
741
742 // Notify all nodes that we are done with the reductions for this
743 // tick.
744 MPIResponse done_response;
745 should_shut_down = mpi_global.shut_down;
746 done_response.set_response_type(
747 mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
748 std::string encoded_response;
749 done_response.SerializeToString(&encoded_response);
750 for (int r = 1; r < size; r++) {
751 MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
752 MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
753 }
754 } else {
755 // Notify the coordinator that this node is done sending messages.
756 // A DONE message is encoded as a zero-length message.
757 MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
758
759 // Receive names for tensors to reduce from rank zero. Once we
760 // receive an empty DONE message, stop waiting for more names.
761 while (true) {
762 MPI_Status status;
763 MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
764
765 // Find number of characters in message (including zero byte).
766 int msg_length;
767 MPI_Get_count(&status, MPI_BYTE, &msg_length);
768
769 // Read the serialized MPIResponse from MPI into an std::string.
770 char* buffer = new char[msg_length];
771 MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
772 &status);
773 std::string received_message(buffer, msg_length - 1);  // exclude trailing zero byte
774 delete[] buffer;
775
776 MPIResponse response;
777 response.ParseFromString(received_message);
778 if (response.response_type() == MPIResponse::DONE) {
779 // No more messages this tick
780 break;
781 } else if (response.response_type() == MPIResponse::SHUTDOWN) {
782 // No more messages this tick, and the background thread
783 // should shut down
784 should_shut_down = true;
785 break;
786 } else {
787 // Process the current message
788 PerformCollectiveOp(mpi_global.tensor_table, response);
789 }
790 }
791 }
792 } while (!should_shut_down);
793
794 MPI_Finalize();
795 }
796
797 // Initialize MPI and start the MPI background thread. Ensure that this is
798 // only done once no matter how many times this function is called.
799 Status InitializeMPIOnce(bool gpu) {
800 // Ensure MPI is only initialized once.
801 if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status;
802
803 mpi_global.device = -1;
804 #if GOOGLE_CUDA
805 if (gpu) {
806 cudaGetDevice(&mpi_global.device);
807 }
808 #endif
809
810 // Start the MPI background thread, which initializes MPI.
811 // TODO: Change this to a TensorFlow thread
812 mpi_global.background_thread = std::thread(BackgroundThreadLoop);
813
814 // Wait to ensure that the background thread has finished initializing MPI
815 mutex_lock guard(mpi_global.mu);
816 mpi_global.cv.wait(guard);
817 if (!mpi_global.initialization_done) {
818 mpi_global.init_status =
819 errors::Unknown("Failed to wait for MPI initialization.");
820 }
821
822 return mpi_global.init_status;
823 }
824
825 // Check that MPI is initialized.
826 Status IsMPIInitialized() {
827 if (!mpi_global.initialization_done) {
828 return errors::FailedPrecondition(
829 "MPI has not been initialized; use tf.contrib.mpi.Session.");
830 }
831 return Status::OK();
832 }
833
834 // This function (called from the callback set up in MPIAll*Op::ComputeAsync)
835 // only adds the op's record into the local op queue (to track the op's
836 // progress), and sends a message to the coordinator indicating that this rank
837 // is ready to begin. The MPI background thread will handle the MPI message.
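// Note that only the lightweight MPIRequest metadata ever travels over MPI;
// the CollectiveOpRecord (which holds pointers into the op's context) stays
// in this process's tensor_table until the collective completes.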
838 void EnqueueTensorCollective(CollectiveOpRecord record,
839 MPIRequest::RequestType rtype) {
840 const Tensor* input_tensor = record.in_t;
841 MPIRequest message;
842 message.set_request_rank(record.rank);
843 message.set_tensor_name(record.name);
844 message.set_tensor_type(record.dtype);
845 message.set_request_type(rtype);
846 input_tensor->shape().AsProto(message.mutable_tensor_shape());
847
848 mutex_lock guard(mpi_global.mu);
849 mpi_global.tensor_table.emplace(record.name, record);
850 mpi_global.message_queue.push(message);
851 }
852
853 } // namespace
854
855 #if GOOGLE_CUDA
856 cudaStream_t CudaStreamForMPI() { return mpi_global.stream; }
857 #endif
858
859 // Op to initialize MPI in the current process. The configuration settings
860 // used here must be the same ones used for all future MPI ops.
861 template <typename Device>
862 class MPIInitOp : public OpKernel {
863 public:
864 explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {}
865
866 void Compute(OpKernelContext* context) override {
867 bool on_gpu = IsGPUDevice<Device>();
868 OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu));
869 }
870 };
871
872 REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU),
873 MPIInitOp<CPUDevice>);
874 #if GOOGLE_CUDA
875 REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU),
876 MPIInitOp<GPUDevice>);
877 #endif
878
879 // Op to get the current MPI Size.
880 template <typename Device>
881 class MPISizeOp : public OpKernel {
882 public:
883 explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {}
884
885 void Compute(OpKernelContext* context) override {
886 OP_REQUIRES_OK(context, IsMPIInitialized());
887
888 // Write integer to output tensor
889 Tensor* output;
890 OP_REQUIRES_OK(context,
891 context->allocate_output(0, TensorShape({}), &output));
892
893 auto flat = output->flat<int>();
894 flat(0) = mpi_global.size;
895 }
896 };
897
898 REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU),
899 MPISizeOp<CPUDevice>);
900 #if GOOGLE_CUDA
901 REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"),
902 MPISizeOp<GPUDevice>);
903 #endif
904
905 // Op to get the current MPI Rank.
906 template <typename Device>
907 class MPIRankOp : public OpKernel {
908 public:
909 explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {}
910
911 void Compute(OpKernelContext* context) override {
912 OP_REQUIRES_OK(context, IsMPIInitialized());
913
914 // Write integer to output tensor
915 Tensor* output;
916 OP_REQUIRES_OK(context,
917 context->allocate_output(0, TensorShape({}), &output));
918
919 auto flat = output->flat<int>();
920 flat(0) = mpi_global.rank;
921 }
922 };
923
924 REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU),
925 MPIRankOp<CPUDevice>);
926 #if GOOGLE_CUDA
927 REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"),
928 MPIRankOp<GPUDevice>);
929 #endif
930
931 // Op to get the current local MPI Rank.
932 template <typename Device>
933 class MPILocalRankOp : public OpKernel {
934 public:
935 explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {}
936
937 void Compute(OpKernelContext* context) override {
938 OP_REQUIRES_OK(context, IsMPIInitialized());
939
940 // Write integer to output tensor
941 Tensor* output;
942 OP_REQUIRES_OK(context,
943 context->allocate_output(0, TensorShape({}), &output));
944
945 auto flat = output->flat<int>();
946 flat(0) = mpi_global.local_rank;
947 }
948 };
949
950 REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU),
951 MPILocalRankOp<CPUDevice>);
952 #if GOOGLE_CUDA
953 REGISTER_KERNEL_BUILDER(
954 Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"),
955 MPILocalRankOp<GPUDevice>);
956 #endif
957
958 template <typename Device>
959 class MPIAllreduceOp : public AsyncOpKernel {
960 public:
961 explicit MPIAllreduceOp(OpKernelConstruction* context)
962 : AsyncOpKernel(context) {}
963
964 // Although this op is handled asynchronously, the ComputeAsync call is
965 // very inexpensive. It only sets up a CollectiveOpRecord and places it
966 // in the table for the background thread to handle. Thus, we do not need
967 // a TF pool thread to perform the op.
968 bool IsExpensive() override { return false; }
969
970 void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
971 OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
972 const Tensor* input_tensor = &context->input(0);
973 Tensor* output_tensor;
974 OP_REQUIRES_OK_ASYNC(
975 context,
976 context->allocate_output(0, input_tensor->shape(), &output_tensor),
977 done);
978
979 // Record allocated on stack so op can fail without memory leak
980 CollectiveOpRecord record;
981 record.name = name();
982 record.context = context;
983 record.in_t = input_tensor;
984 record.out_t = output_tensor;
985 record.on_gpu = IsGPUDevice<Device>();
986 record.dtype = input_tensor->dtype();
987
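// The temp tensor allocated below is scratch space for RingAllreduce(); its
// length is the ceiling of NumElements / size, i.e. roughly one ring segment
// per rank. For example, 10 elements across 4 ranks gives a 3-element temp
// tensor.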
988 const size_t temp_size =
989 (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size;
990 TensorShape temp_shape;
991 temp_shape.AddDim(temp_size);
992 OP_REQUIRES_OK_ASYNC(context,
993 context->allocate_temp(input_tensor->dtype(),
994 temp_shape, &record.temp_t),
995 done);
996
997 auto allreduce_done_callback = [done, context](StatusOr<Tensor> status) {
998 context->SetStatus(status.status());
999 done();
1000 };
1001 record.callback = allreduce_done_callback;
1002
1003 auto allreduce_launch_callback = [record] {
1004 EnqueueTensorCollective(record, MPIRequest::ALLREDUCE);
1005 };
1006
1007 // If we are on a CPU, our device context will be null and we can't
1008 // get a stream to enqueue this on. On a CPU this op is called when the
1009 // data is already available, so we can just immediately do the
1010 // allreduce; we don't have to wait for the data to get populated.
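// On a GPU, the launch callback is instead enqueued on the op's compute
// stream via ThenDoHostCallback(), so the collective is not requested until
// work already submitted on that stream (including the producers of the
// input tensor) has completed.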
1011 #if GOOGLE_CUDA
1012 auto device_context = context->op_device_context();
1013 if (device_context == nullptr) {
1014 allreduce_launch_callback();
1015 } else {
1016 auto stream = device_context->stream();
1017 stream->ThenDoHostCallback(allreduce_launch_callback);
1018 }
1019 #else
1020 allreduce_launch_callback();
1021 #endif
1022 }
1023 };
1024
1025 REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU),
1026 MPIAllreduceOp<CPUDevice>);
1027 #if GOOGLE_CUDA
1028 REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU),
1029 MPIAllreduceOp<GPUDevice>);
1030 #endif
1031
1032 template <typename Device>
1033 class MPIAllgatherOp : public AsyncOpKernel {
1034 public:
1035 explicit MPIAllgatherOp(OpKernelConstruction* context)
1036 : AsyncOpKernel(context) {}
1037
1038 // Although this op is handled asynchronously, the ComputeAsync call is
1039 // very inexpensive. It only sets up a CollectiveOpRecord and places it
1040 // in the table for the background thread to handle. Thus, we do not need
1041 // a TF pool thread to perform the op.
1042 bool IsExpensive() override { return false; }
1043
1044 void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
1045 OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
1046 const Tensor* input_tensor = &context->input(0);
1047 const Tensor* sizing_tensor = &context->input(1);
1048
1049 // Record allocated on stack so op can fail without memory leak
1050 CollectiveOpRecord record;
1051 record.name = name();
1052 record.context = context;
1053 record.in_t = input_tensor;
1054 record.on_gpu = IsGPUDevice<Device>();
1055
1056 // Construct the output size from the sizing tensor
1057 size_t output_first_dim = 0;
1058 if (sizing_tensor->shape().dims() == 0) {
1059 // 0-dim sizing_tensor implies that the op is just gathering
1060 // a single element from each rank
1061 output_first_dim = mpi_global.size;
1062 for (int i = 0; i < mpi_global.size; i++) {
1063 record.sizes_vec.push_back(1);
1064 }
1065 } else {
1066 // Collect the total output tensor sizing from the sizing tensor
1067 // NOTE: The sizing tensor is forced to be placed on the CPU by
1068 // declaring the input as HostMemory, so it is valid to read it here.
1069 const int64* sizing_array =
1070 (const int64*)sizing_tensor->tensor_data().data();
1071 for (int i = 0; i < mpi_global.size; i++) {
1072 record.sizes_vec.push_back(sizing_array[i]);
1073 output_first_dim += sizing_array[i];
1074 }
1075 }
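// For example, with mpi_global.size == 3 and a sizes tensor of [2, 4, 3],
// sizes_vec becomes {2, 4, 3} and output_first_dim is 9, so a [2, 5] input
// on rank 0 is gathered into a [9, 5] output.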
1076
1077 TensorShape output_shape;
1078 output_shape.AddDim(output_first_dim);
1079 for (int i = 1; i < input_tensor->shape().dims(); i++) {
1080 output_shape.AddDim(input_tensor->shape().dim_size(i));
1081 }
1082
1083 Tensor* output_tensor;
1084 OP_REQUIRES_OK_ASYNC(
1085 context, context->allocate_output(0, output_shape, &output_tensor),
1086 done);
1087
1088 record.out_t = output_tensor;
1089 record.dtype = input_tensor->dtype();
1090
1091 auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
1092 context->SetStatus(status.status());
1093 done();
1094 };
1095 record.callback = allgather_done_callback;
1096
1097 auto allgather_launch_callback = [record] {
1098 EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
1099 };
1100
1101 // If we are on a CPU, our device context will be null and we can't
1102 // get a stream to enqueue this on. On a CPU this op is called when the
1103 // data is already available, so we can just immediately do the
1104 // allgather; we don't have to wait for the data to get populated.
1105 #if GOOGLE_CUDA
1106 auto device_context = context->op_device_context();
1107 if (device_context == nullptr) {
1108 allgather_launch_callback();
1109 } else {
1110 auto stream = device_context->stream();
1111 stream->ThenDoHostCallback(allgather_launch_callback);
1112 }
1113 #else
1114 allgather_launch_callback();
1115 #endif
1116 }
1117 };
1118
1119 REGISTER_KERNEL_BUILDER(
1120 Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"),
1121 MPIAllgatherOp<CPUDevice>);
1122 #if GOOGLE_CUDA
1123 REGISTER_KERNEL_BUILDER(
1124 Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"),
1125 MPIAllgatherOp<GPUDevice>);
1126 #endif
1127
1128 } // namespace mpi_collectives
1129 } // namespace contrib
1130 } // namespace tensorflow
1131
1132 #endif // TENSORFLOW_USE_MPI
1133