/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"

namespace xla {
namespace gpu {

// This file runs collective ops (i.e. ops that communicate between multiple
// GPUs) using NCCL.
//
// Here's a high-level overview of how running an op works.
//
//  - Multiple threads call ExecuteOnStream.
//  - All threads that "go together" (i.e. are participating in the "same"
//    collective op) choose the same Rendezvous object from a global map.
//  - Once all threads have arrived at the Rendezvous, we know exactly which
//    GPUs are participating in the op, so we get or create a NcclClique
//    containing those GPUs.
//  - We perform the NCCL operation using the clique.
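//
// Illustrative example (not from the original file): an AllReduce over four
// GPUs in one process is driven by four threads, one per device ordinal. All
// four compute the same RendezvousKey, meet at the same Rendezvous, and end
// up sharing a single NcclClique that holds one ncclComm_t per device.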

NcclCollectiveConfig::NcclCollectiveConfig() = default;
NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig&&) = default;
NcclCollectiveConfig::~NcclCollectiveConfig() = default;
NcclCollectiveConfig& NcclCollectiveConfig::operator=(NcclCollectiveConfig&&) =
    default;

// Returns true if the collective communication operation is degenerate
// because all the groups formed by the operation are singleton. A given op
// can be degenerate under several conditions, corresponding to the modes
// supported in GetParticipatingDevices().
//   1. no channel_id, use_global_device_ids = false:
//         degenerate if replica_groups are singleton, or groups empty and
//         replica_count == 1.
//   2. channel_id is set, use_global_device_ids = false:
//         degenerate if replica_groups are singleton and num_partitions == 1,
//         or groups empty and num_replicas == 1 && num_partitions == 1.
//   3. channel_id is set, use_global_device_ids = true (flattened-ids):
//         degenerate if replica_groups are singleton (groups cannot be empty).
//   4. no channel_id, no use_global_device_ids:
//         identical to 1.
//   5. channel_id is set, no use_global_device_ids:
//         degenerate if replica_groups are singleton, or groups empty and
//         num_partitions == 1 (since replica groups contain partition ids).
//
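// For example (illustrative): with no channel_id and
// replica_groups = {{0}, {1}, {2}}, every group is a singleton, so each
// device would only "communicate" with itself and the op is degenerate.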
bool NcclCollectiveConfig::IsDegenerate(int64_t replica_count,
                                        int64_t partition_count) const {
  bool groups_empty = replica_groups.empty();

  // Check if all replica_groups are singleton. If not, then the operation is
  // not degenerate.
  bool all_groups_singleton =
      !groups_empty &&
      absl::c_all_of(replica_groups, [](const ReplicaGroup& group) {
        return group.replica_ids_size() == 1;
      });

  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica:
      return all_groups_singleton || (groups_empty && replica_count == 1);
    case CollectiveOpGroupMode::kCrossPartition:
      return all_groups_singleton || (groups_empty && partition_count == 1);
    case CollectiveOpGroupMode::kCrossReplicaAndPartition:
      return (all_groups_singleton && partition_count == 1) ||
             (groups_empty && replica_count == 1 && partition_count == 1);
    case CollectiveOpGroupMode::kFlattenedID:
      CHECK(!groups_empty)
          << "replica groups cannot be empty if use_global_device_ids = true";
      return all_groups_singleton;
    default:
      CHECK(0) << "Invalid collective op mode";
      return false;
  }
}
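
// Illustrative use (a sketch, not from the original file): for a
// kCrossReplica op with empty replica_groups,
//   config.IsDegenerate(/*replica_count=*/1, /*partition_count=*/8)
// returns true, while any replica_count > 1 makes it return false.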

/* static */ bool NcclCollectiveThunk::NcclIsEnabled() {
#if XLA_ENABLE_XCCL
  return true;
#else
  return false;
#endif
}

Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
#if XLA_ENABLE_XCCL
  VLOG(1) << absl::StreamFormat("Starting %s.", Thunk::KindToString(kind()));
  TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());

  TF_ASSIGN_OR_RETURN(
      std::vector<GlobalDeviceId> participants,
      GetParticipatingDevices(global_device_id, *params.device_assn,
                              config().replica_groups, config().group_mode));
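  // `participants` now holds the global ids of every device taking part in
  // this instance of the collective, as determined by the replica groups and
  // the group mode.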

  if (IsGlobalNcclConfig() &&
      (participants.size() != params.device_assn->replica_count())) {
    return InvalidArgument(
        "Partial replica groups are not allowed when using NCCL_COMM_ID "
        "environment configuration.");
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<LocalParticipant> local_participants,
      GetLocalParticipants(participants, params.gpu_global_device_ids));
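  // `local_participants` is the subset of participants living on devices
  // visible to this process.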

  // Create the rendezvous for this collective operation.
  const RendezvousKey rendezvous_key(
      params.run_id, std::move(participants), local_participants.size(),
      config().collective_op_kind, config().op_id);
  VLOG(2) << GetDeviceString(params) << ": key " << rendezvous_key.ToString()
          << "\n";

  int device_ordinal = params.stream->parent()->device_ordinal();
  NcclCliqueParticipantData participant(rendezvous_key, device_ordinal,
                                        params.stream);

  TF_ASSIGN_OR_RETURN(LockedNcclClique locked_clique,
                      AcquireNcclClique(participant, local_participants,
                                        params.nccl_unique_id_callback));
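  // The clique holds one NCCL communicator per local device; fetch the one
  // belonging to this thread's device ordinal.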
  ncclComm_t comm =
      locked_clique.clique.GetCommForDeviceOrdinal(device_ordinal);

  se::StreamExecutor* executor = params.stream->parent();
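  // Keep the executor's GPU context active on this thread while the
  // collective is issued.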
  se::gpu::ScopedActivateExecutorContext scoped_context(executor);

  TF_RETURN_IF_ERROR(RunNcclCollective(params, comm));
  return Status::OK();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

std::string NcclCollectiveThunk::GetDeviceString(
    const ExecuteParams& params) const {
  int device_ordinal = params.stream->parent()->device_ordinal();
  GlobalDeviceId global_device_id = params.GetGlobalDeviceId().ValueOrDie();
  DeviceAssignment::LogicalID logical_id =
      params.device_assn->LogicalIdForDevice(global_device_id).ValueOrDie();
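  // Example output (illustrative values): "(r0, p1) : GlobalID 5, ord 1".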
  return absl::StrFormat("(r%d, p%d) : GlobalID %d, ord %d",
                         logical_id.replica_id, logical_id.computation_id,
                         global_device_id.value(), device_ordinal);
}

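// Returns true for element types that XLA can lower to NCCL collectives.
// BF16 additionally requires a CUDA toolkit whose headers define
// __CUDA_BF16_TYPES_EXIST__.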
bool IsTypeSupportedByNccl(PrimitiveType element_type) {
  switch (element_type) {
    case S8:
    case PRED:
    case U8:
    case S32:
    case U32:
    case S64:
    case U64:
    case F16:
    case F32:
    case F64:
#if defined(__CUDA_BF16_TYPES_EXIST__)
    case BF16:
#endif
    case C64:
    case C128:
      return true;
    default:
      return false;
  }
}

}  // namespace gpu
}  // namespace xla