• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
18 
#include <functional>
#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#endif
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
36 
#if TENSORFLOW_USE_ROCM
// ROCm builds only: spell the handful of CUDA runtime names used by this
// header and its users (e.g. the cudaError_t overload of ToStatus) in terms of
// their HIP equivalents. Because these are preprocessor macros in a header,
// they are visible in every translation unit that includes this file.
#define cudaError_t hipError_t
#define cudaGetDevice hipGetDevice
#define cudaGetErrorString hipGetErrorString
#define cudaSetDevice hipSetDevice
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#endif  // TENSORFLOW_USE_ROCM
46 
47 namespace xla {
48 namespace gpu {
49 
50 ncclRedOp_t ToNcclReduction(ReductionKind kind);
51 StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type);
52 
53 bool IsGlobalNcclConfig();
54 
55 Status ToStatus(ncclResult_t s, const char* file, int64 line, const char* expr);
56 Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr);
57 
// Macros to turn a CUDA or NCCL result code into a Status and then either
// return it (XLA_CUDA_RETURN_IF_ERROR) or log it (XLA_CUDA_WARN_IF_ERROR).
// The same macros work for both NCCL and CUDA errors because ToStatus is
// overloaded on both result types.
//
// It's tempting to say these macros belong in an XLA header somewhere, but in
// practice we don't do much direct-to-CUDA-API stuff outside of this file.
//
// Hygiene notes:
//  * The local Status uses a deliberately obscure name. With a plain name such
//    as `s`, an argument like `foo(s)` would refer to the macro's own local
//    (which is in scope, but not yet initialized, inside its own initializer)
//    instead of the caller's variable.
//  * Every name is fully qualified so the macros expand correctly from any
//    namespace, not only from inside xla::gpu.
#define XLA_CUDA_STATUS(expr) \
  ::xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr)

#define XLA_CUDA_RETURN_IF_ERROR(expr)                      \
  do {                                                      \
    ::xla::Status _xla_cuda_status = XLA_CUDA_STATUS(expr); \
    if (!_xla_cuda_status.ok()) {                           \
      return _xla_cuda_status;                              \
    }                                                       \
  } while (0)

#define XLA_CUDA_WARN_IF_ERROR(expr)                        \
  do {                                                      \
    ::xla::Status _xla_cuda_status = XLA_CUDA_STATUS(expr); \
    if (!_xla_cuda_status.ok()) {                           \
      LOG(ERROR) << _xla_cuda_status.ToString();            \
    }                                                       \
  } while (0)
81 
82 // RAII type for NCCL communicators.
83 using NcclComm = std::unique_ptr<ncclComm, void (*)(ncclComm_t)>;
84 
85 // Owns a clique of NCCL comms which can be used for collective operations among
86 // a particular set of GPUs.
87 //
88 // Note that if you want to do a collective operation among a subset of these
89 // GPUs, you'll need a different clique.
90 class NcclClique {
91  public:
92   explicit NcclClique(
93       absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal);
94 
95   ncclComm_t GetCommForDeviceOrdinal(int device_ordinal) const;
mu()96   absl::Mutex* mu() { return &mu_; }
97 
98  private:
99   absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal_;
100   absl::Mutex mu_;
101 };
102 
// One participant of a collective, as seen from this process.
//
// Members carry default initializers so a default-constructed instance never
// holds indeterminate values. Zero matches what value-initialization
// (`LocalParticipant{}`) already produced, and the type remains an aggregate,
// so `LocalParticipant{ordinal, rank}` keeps working unchanged.
struct LocalParticipant {
  int device_ordinal = 0;  // Device ordinal of the participant's GPU.
  int rank = 0;            // Presumably its rank within the collective.
};
107 
108 StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
109     const std::vector<GlobalDeviceId>& participants,
110     const std::vector<GlobalDeviceId>* local_devices);  // may be null
111 
112 class LockedNcclClique {
113  public:
114   LockedNcclClique(NcclClique& clique, std::unique_ptr<absl::MutexLock> lock,
115                    absl::BlockingCounter* counter);
116   LockedNcclClique(LockedNcclClique&&);
117   ~LockedNcclClique();
118 
119   NcclClique& clique;
120 
121  private:
122   // Must come after clique, so it is destroyed first.
123   // One thread holds a lock (it is null in the others).
124   std::unique_ptr<absl::MutexLock> lock_;
125   absl::BlockingCounter* counter_;
126 };
127 
128 // Threadsafe leaky map from NcclCliqueKeys to NcclCliques.
129 class NcclCliqueMap {
130  public:
131   StatusOr<NcclClique*> GetOrTryCreateIfAbsent(
132       const NcclCliqueKey& key,
133       const std::function<StatusOr<std::unique_ptr<NcclClique>>(
134           const NcclCliqueKey&)>& value_factory) ABSL_LOCKS_EXCLUDED(mu_);
135 
136   // Runs a function over every key/value in the map.
137   void ForEach(
138       const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn)
139       ABSL_LOCKS_EXCLUDED(mu_);
140 
141  private:
142   absl::Mutex mu_;
143   absl::flat_hash_map<NcclCliqueKey, std::unique_ptr<NcclClique>> map_
144       ABSL_GUARDED_BY(mu_);
145 };
146 
147 NcclCliqueMap& NcclCliqueCache();
148 
149 // Acquires a locked NCCL clique for use in NCCL collective operations.
150 StatusOr<LockedNcclClique> AcquireNcclClique(
151     const RendezvousKey& rendezvous_key, int local_device_ordinal,
152     se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
153     const NcclUniqueIdCallback* callback);  // may be null
154 
155 }  // namespace gpu
156 }  // namespace xla
157 
158 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
159