/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#endif
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"
37 
38 #if TENSORFLOW_USE_ROCM
39 // Local hipify of cuda symbols
40 #define cudaError_t hipError_t
41 #define cudaStream_t hipStream_t
42 #define cudaGetErrorString hipGetErrorString
43 #define cudaGetDevice hipGetDevice
44 #define cudaSetDevice hipSetDevice
45 #define cudaSuccess hipSuccess
46 #endif

namespace xla {
namespace gpu {

51 ncclRedOp_t ToNcclReduction(ReductionKind kind);
52 StatusOr<std::pair<ncclDataType_t, int>> ToNcclDataTypeAndCountMultiplier(
53     PrimitiveType element_type);
54 
55 bool IsGlobalNcclConfig();
56 bool IsNcclLaunchModeParallel();
57 
58 Status ToStatus(ncclResult_t s, const char* file, int64_t line,
59                 const char* expr);
60 Status ToStatus(cudaError_t s, const char* file, int64_t line,
61                 const char* expr);
62 
63 // Macros to return or warn on CUDA/NCCL errors.  (The same macro works for both
64 // NCCL and CUDA errors.)
65 //
66 // It's tempting to say these macros belong in an XLA header somewhere, but in
67 // practice we don't do much direct-to-CUDA-API stuff outside of this file.
68 #define XLA_CUDA_STATUS(expr) \
69   xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr)
70 
71 #define XLA_CUDA_RETURN_IF_ERROR(expr) \
72   do {                                 \
73     Status s = XLA_CUDA_STATUS(expr);  \
74     if (!s.ok()) {                     \
75       return s;                        \
76     }                                  \
77   } while (0)
78 
79 #define XLA_CUDA_WARN_IF_ERROR(expr)  \
80   do {                                \
81     Status s = XLA_CUDA_STATUS(expr); \
82     if (!s.ok()) {                    \
83       LOG(ERROR) << s.ToString();     \
84     }                                 \
85   } while (0)
86 
87 // RAII type for NCCL communicators.
88 using NcclComm = std::unique_ptr<ncclComm, void (*)(ncclComm_t)>;
89 
90 // Owns a clique of NCCL comms which can be used for collective operations among
91 // a particular set of GPUs.
92 //
93 // Note that if you want to do a collective operation among a subset of these
94 // GPUs, you'll need a different clique.
95 class NcclClique {
96  public:
97   explicit NcclClique(
98       absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal);
99 
100   ncclComm_t GetCommForDeviceOrdinal(int device_ordinal) const;
mu()101   absl::Mutex* mu() { return &mu_; }
102 
103  private:
104   absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal_;
105   absl::Mutex mu_;
106 };
107 
108 struct LocalParticipant {
109   int device_ordinal;
110   int rank;
111 };
112 
113 StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
114     const std::vector<GlobalDeviceId>& participants,
115     const std::vector<GlobalDeviceId>* local_devices);  // may be null
116 
117 class LockedNcclClique {
118  public:
119   LockedNcclClique(NcclClique& clique, std::unique_ptr<absl::MutexLock> lock,
120                    absl::BlockingCounter* counter);
121   LockedNcclClique(LockedNcclClique&&);
122   ~LockedNcclClique();
123 
124   NcclClique& clique;
125 
126  private:
127   // Must come after clique, so it is destroyed first.
128   // One thread holds a lock (it is null in the others).
129   std::unique_ptr<absl::MutexLock> lock_;
130   absl::BlockingCounter* counter_;
131 };
132 
133 // Threadsafe leaky map from NcclCliqueKeys to NcclCliques.
134 class NcclCliqueMap {
135  public:
136   StatusOr<NcclClique*> GetOrTryCreateIfAbsent(
137       const NcclCliqueKey& key,
138       const std::function<StatusOr<std::unique_ptr<NcclClique>>(
139           const NcclCliqueKey&)>& value_factory) ABSL_LOCKS_EXCLUDED(mu_);
140 
141   // Runs a function over every key/value in the map.
142   void ForEach(
143       const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn)
144       ABSL_LOCKS_EXCLUDED(mu_);
145 
146  private:
147   absl::Mutex mu_;
148   absl::flat_hash_map<NcclCliqueKey, std::unique_ptr<NcclClique>> map_
149       ABSL_GUARDED_BY(mu_);
150 };
151 
152 NcclCliqueMap& NcclCliqueCache();
153 
154 struct NcclCliqueParticipantData : public ParticipantData {
155   // For running in StreamExecutor. To be deprecated after transitioning to
156   // TFRT.
NcclCliqueParticipantDataNcclCliqueParticipantData157   NcclCliqueParticipantData(const RendezvousKey& rendezvous_key,
158                             int64_t device_ordinal, se::Stream* stream)
159       : ParticipantData(rendezvous_key),
160         device_ordinal(device_ordinal),
161         stream(stream) {}
162 
163   // For running in TFRT.
NcclCliqueParticipantDataNcclCliqueParticipantData164   NcclCliqueParticipantData(const RendezvousKey& rendezvous_key,
165                             se::gpu::GpuContextHandle context)
166       : ParticipantData(rendezvous_key), stream(nullptr), context(context) {}
167 
168   int64 device_ordinal;
169   se::Stream* stream;
170   se::gpu::GpuContextHandle context;
171 
ToStringNcclCliqueParticipantData172   std::string ToString() const override {
173     if (stream != nullptr) {
174       return absl::StrFormat(
175           "NcclCliqueParticipantData{rendezvous_key=%s, "
176           "device_ordinal=%d, stream=%p}",
177           rendezvous_key.ToString(), device_ordinal, stream);
178     }
179     return absl::StrFormat(
180         "NcclCliqueParticipantData{rendezvous_key=%s, context=%p}",
181         rendezvous_key.ToString(), context);
182   }
183 };
184 
185 // Acquires a locked NCCL clique for use in NCCL collective operations.
186 StatusOr<LockedNcclClique> AcquireNcclClique(
187     const NcclCliqueParticipantData& participant,
188     const std::vector<LocalParticipant>& local_participants,
189     const NcclUniqueIdCallback* callback);  // may be null

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_