1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
17
18 #include <chrono> // NOLINT (required by TF interfaces)
19 #include <cstdlib>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_format.h"
27 #include "absl/synchronization/mutex.h"
28 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
29 #include "tensorflow/compiler/xla/service/global_device_id.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
33
34 namespace xla {
35 namespace gpu {
36
37 // This file runs collective ops (i.e. ops that communicate between multiple
38 // GPUs) using NCCL.
39 //
40 // Here's a high-level overview of how running an op works.
41 //
42 // - Multiple threads call ExecuteOnStream.
43 // - All threads that "go together" (i.e. are participating in the "same"
44 // collective op) choose the same Rendezvous object from a global map.
45 // - Once all threads have arrived at the Rendezvous, we know exactly which
46 // GPUs are participating in the op, so we get or create a NcclClique
47 // containing those GPUs.
48 // - We perform the NCCL operation using the clique.
49
50 NcclCollectiveConfig::NcclCollectiveConfig() = default;
51 NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig&&) = default;
52 NcclCollectiveConfig::~NcclCollectiveConfig() = default;
53 NcclCollectiveConfig& NcclCollectiveConfig::operator=(NcclCollectiveConfig&&) =
54 default;
55
NcclIsEnabled()56 /* static */ bool NcclCollectiveThunk::NcclIsEnabled() {
57 #if XLA_ENABLE_XCCL
58 return true;
59 #else
60 return false;
61 #endif
62 }
63
ExecuteOnStream(const ExecuteParams & params)64 Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
65 #if XLA_ENABLE_XCCL
66 VLOG(1) << absl::StreamFormat("Starting %s.", ThunkKindToString(kind()));
67 auto op_profiler =
68 params.profiler->MakeScopedInstructionProfiler(profile_index());
69
70 TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
71 params.GetGlobalDeviceId());
72
73 TF_ASSIGN_OR_RETURN(
74 std::vector<GlobalDeviceId> participants,
75 GetParticipatingDevices(global_device_id, *params.device_assn,
76 config().replica_count, config().replica_groups));
77
78 if (IsGlobalNcclConfig() && (participants.size() != config().replica_count)) {
79 return InvalidArgument(
80 "Partial replica groups are not allowed when using NCCL_COMM_ID "
81 "environment configuration.");
82 }
83
84 TF_ASSIGN_OR_RETURN(
85 std::vector<LocalParticipant> local_participants,
86 GetLocalParticipants(participants, params.gpu_global_device_ids));
87
88 // Create the rendezvous for this collective operation.
89 RendezvousKey rendezvous_key(params.run_id, std::move(participants),
90 local_participants.size(),
91 config().collective_op_kind, config().op_id);
92
93 int device_ordinal = params.stream->parent()->device_ordinal();
94
95 TF_ASSIGN_OR_RETURN(
96 LockedNcclClique locked_clique,
97 AcquireNcclClique(rendezvous_key, device_ordinal, params.stream,
98 local_participants, params.nccl_unique_id_callback));
99 ncclComm_t comm =
100 locked_clique.clique.GetCommForDeviceOrdinal(device_ordinal);
101
102 se::StreamExecutor* executor = params.stream->parent();
103 se::gpu::ScopedActivateExecutorContext scoped_context(executor);
104
105 TF_RETURN_IF_ERROR(RunNcclCollective(params, comm));
106 return Status::OK();
107 #else // XLA_ENABLE_XCCL
108 return Unimplemented(
109 "NCCL support is not available: this binary was not built with a CUDA "
110 "compiler, which is necessary to build the NCCL source library.");
111 #endif // XLA_ENABLE_XCCL
112 }
113
IsTypeSupportedByNccl(PrimitiveType element_type)114 bool IsTypeSupportedByNccl(PrimitiveType element_type) {
115 switch (element_type) {
116 case S8:
117 case PRED:
118 case U8:
119 case S32:
120 case U32:
121 case S64:
122 case U64:
123 case F16:
124 case F32:
125 case F64:
126 return true;
127 default:
128 return false;
129 }
130 }
131
132 } // namespace gpu
133 } // namespace xla
134