/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"

namespace xla {
namespace gpu {

// This file runs collective ops (i.e. ops that communicate between multiple
// GPUs) using NCCL.
//
// Here's a high-level overview of how running an op works.
//
//  - Multiple threads call ExecuteOnStream.
//  - All threads that "go together" (i.e. are participating in the "same"
//    collective op) choose the same Rendezvous object from a global map.
//  - Once all threads have arrived at the Rendezvous, we know exactly which
//    GPUs are participating in the op, so we get or create a NcclClique
//    containing those GPUs.
//  - We perform the NCCL operation using the clique.
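//
// A minimal sketch of that flow, assuming two GPUs on one host take part in
// the same collective (the sequence below is illustrative, not a literal
// trace):
//
//  - The thread driving GPU 0 and the thread driving GPU 1 each call
//    thunk->ExecuteOnStream(params).
//  - Both compute the same RendezvousKey from (run_id, participants,
//    collective_op_kind, op_id) and block until the other arrives.
//  - A single NcclClique holding one ncclComm_t per local device ordinal is
//    acquired, and each thread enqueues the NCCL call on its own stream.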

NcclCollectiveConfig::NcclCollectiveConfig() = default;
NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig&&) = default;
NcclCollectiveConfig::~NcclCollectiveConfig() = default;
NcclCollectiveConfig& NcclCollectiveConfig::operator=(NcclCollectiveConfig&&) =
    default;

// Returns true if the collective communication operation is degenerate because
// all the groups formed by the operation are singleton. A given op can be
// degenerate under several conditions, corresponding to the modes supported
// in GetParticipatingDevices().
//   1. no channel id, use_global_device_ids = false:
//         degenerate if replica_groups are singleton, or groups empty and
//         replica_count == 1.
//   2. channel_id is set, use_global_device_ids = false:
//         degenerate if replica_groups are singleton and num_partitions == 1,
//         or groups empty and num_replicas == 1 && num_partitions == 1.
//   3. channel_id is set, use_global_device_ids = true (flattened-ids):
//         degenerate if replica_groups are singleton (groups cannot be empty).
//   4. no channel_id, no use_global_device_ids:
//         identical to 1.
//   5. channel_id is set, no use_global_device_ids:
//         degenerate if replica_groups are singleton, or groups empty and
//         num_partitions == 1 (since replica groups contain partition ids).
//
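// For example (an illustrative case of mode 1): with no channel_id and
// replica_groups = {{0}, {1}, {2}, {3}}, every group is a singleton, so each
// device would only "communicate" with itself and the operation is
// degenerate.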
bool NcclCollectiveConfig::IsDegenerate(int64_t replica_count,
                                        int64_t partition_count) const {
  bool groups_empty = replica_groups.empty();

  // Check if all replica_groups are singleton. If not, then the operation is
  // not degenerate.
  bool all_groups_singleton =
      !groups_empty &&
      absl::c_all_of(replica_groups, [](const ReplicaGroup& group) {
        return group.replica_ids_size() == 1;
      });

  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica:
      return all_groups_singleton || (groups_empty && replica_count == 1);
    case CollectiveOpGroupMode::kCrossPartition:
      return all_groups_singleton || (groups_empty && partition_count == 1);
    case CollectiveOpGroupMode::kCrossReplicaAndPartition:
      return (all_groups_singleton && partition_count == 1) ||
             (groups_empty && replica_count == 1 && partition_count == 1);
    case CollectiveOpGroupMode::kFlattenedID:
      CHECK(!groups_empty)
          << "replica groups cannot be empty if use_global_device_ids = true";
      return all_groups_singleton;
    default:
      CHECK(0) << "Invalid collective op mode";
      return false;
  }
}

/* static */ bool NcclCollectiveThunk::NcclIsEnabled() {
#if XLA_ENABLE_XCCL
  return true;
#else
  return false;
#endif
}

Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
#if XLA_ENABLE_XCCL
  VLOG(1) << absl::StreamFormat("Starting %s.", Thunk::KindToString(kind()));
  TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());

  TF_ASSIGN_OR_RETURN(
      std::vector<GlobalDeviceId> participants,
      GetParticipatingDevices(global_device_id, *params.device_assn,
                              config().replica_groups, config().group_mode));

  if (IsGlobalNcclConfig() &&
      (participants.size() != params.device_assn->replica_count())) {
    return InvalidArgument(
        "Partial replica groups are not allowed when using NCCL_COMM_ID "
        "environment configuration.");
  }

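  // Determine which of the participating devices are local to this process.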
  TF_ASSIGN_OR_RETURN(
      std::vector<LocalParticipant> local_participants,
      GetLocalParticipants(participants, params.gpu_global_device_ids));

  // Create the rendezvous for this collective operation.
  const RendezvousKey rendezvous_key(
      params.run_id, std::move(participants), local_participants.size(),
      config().collective_op_kind, config().op_id);
  VLOG(2) << GetDeviceString(params) << ": key " << rendezvous_key.ToString()
          << "\n";

  int device_ordinal = params.stream->parent()->device_ordinal();
  NcclCliqueParticipantData participant(rendezvous_key, device_ordinal,
                                        params.stream);

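  // Acquire the NCCL clique shared by all participants, then look up the
  // communicator assigned to this device ordinal.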
  TF_ASSIGN_OR_RETURN(LockedNcclClique locked_clique,
                      AcquireNcclClique(participant, local_participants,
                                        params.nccl_unique_id_callback));
  ncclComm_t comm =
      locked_clique.clique.GetCommForDeviceOrdinal(device_ordinal);

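  // Activate this executor's GPU context on the current thread before issuing
  // the NCCL call on its stream.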
  se::StreamExecutor* executor = params.stream->parent();
  se::gpu::ScopedActivateExecutorContext scoped_context(executor);

  TF_RETURN_IF_ERROR(RunNcclCollective(params, comm));
  return Status::OK();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

std::string NcclCollectiveThunk::GetDeviceString(
    const ExecuteParams& params) const {
  int device_ordinal = params.stream->parent()->device_ordinal();
  GlobalDeviceId global_device_id = params.GetGlobalDeviceId().ValueOrDie();
  DeviceAssignment::LogicalID logical_id =
      params.device_assn->LogicalIdForDevice(global_device_id).ValueOrDie();
  return absl::StrFormat("(r%d, p%d) : GlobalID %d, ord %d",
                         logical_id.replica_id, logical_id.computation_id,
                         global_device_id.value(), device_ordinal);
}

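// Returns true for element types that the NCCL-based collective thunks know
// how to communicate; all other types must be handled elsewhere.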
bool IsTypeSupportedByNccl(PrimitiveType element_type) {
  switch (element_type) {
    case S8:
    case PRED:
    case U8:
    case S32:
    case U32:
    case S64:
    case U64:
    case F16:
    case F32:
    case F64:
#if defined(__CUDA_BF16_TYPES_EXIST__)
    case BF16:
#endif
    case C64:
    case C128:
      return true;
    default:
      return false;
  }
}

}  // namespace gpu
}  // namespace xla