/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_

#include <memory>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#endif
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"

#if TENSORFLOW_USE_ROCM
// Local hipify of cuda symbols
#define cudaError_t hipError_t
#define cudaStream_t hipStream_t
#define cudaGetErrorString hipGetErrorString
#define cudaGetDevice hipGetDevice
#define cudaSetDevice hipSetDevice
#define cudaSuccess hipSuccess
#endif

namespace xla {
namespace gpu {

ncclRedOp_t ToNcclReduction(ReductionKind kind);
StatusOr<std::pair<ncclDataType_t, int>> ToNcclDataTypeAndCountMultiplier(
    PrimitiveType element_type);

bool IsGlobalNcclConfig();
bool IsNcclLaunchModeParallel();

Status ToStatus(ncclResult_t s, const char* file, int64_t line,
                const char* expr);
Status ToStatus(cudaError_t s, const char* file, int64_t line,
                const char* expr);

// Macros to return or warn on CUDA/NCCL errors. (The same macro works for both
// NCCL and CUDA errors.)
//
// It's tempting to say these macros belong in an XLA header somewhere, but in
// practice we don't do much direct-to-CUDA-API stuff outside of this file.
#define XLA_CUDA_STATUS(expr) \
  xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr)

#define XLA_CUDA_RETURN_IF_ERROR(expr) \
  do {                                 \
    Status s = XLA_CUDA_STATUS(expr);  \
    if (!s.ok()) {                     \
      return s;                        \
    }                                  \
  } while (0)

#define XLA_CUDA_WARN_IF_ERROR(expr)  \
  do {                                \
    Status s = XLA_CUDA_STATUS(expr); \
    if (!s.ok()) {                    \
      LOG(ERROR) << s.ToString();     \
    }                                 \
  } while (0)
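// A minimal usage sketch (RunAllReduce is a hypothetical helper, not part of
// this header): because ToStatus is overloaded for both ncclResult_t and
// cudaError_t, the same macros wrap NCCL and CUDA calls alike.
//
//   Status RunAllReduce(ncclComm_t comm, float* buffer, size_t count,
//                       cudaStream_t stream) {
//     XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(buffer, buffer, count, ncclFloat,
//                                            ncclSum, comm, stream));
//     XLA_CUDA_WARN_IF_ERROR(cudaStreamSynchronize(stream));
//     return Status::OK();
//   }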
// RAII type for NCCL communicators.
using NcclComm = std::unique_ptr<ncclComm, void (*)(ncclComm_t)>;

// Owns a clique of NCCL comms which can be used for collective operations
// among a particular set of GPUs.
//
// Note that if you want to do a collective operation among a subset of these
// GPUs, you'll need a different clique.
class NcclClique {
 public:
  explicit NcclClique(
      absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal);

  ncclComm_t GetCommForDeviceOrdinal(int device_ordinal) const;
  absl::Mutex* mu() { return &mu_; }

 private:
  absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal_;
  absl::Mutex mu_;
};

struct LocalParticipant {
  int device_ordinal;
  int rank;
};

StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
    const std::vector<GlobalDeviceId>& participants,
    const std::vector<GlobalDeviceId>* local_devices);  // may be null

class LockedNcclClique {
 public:
  LockedNcclClique(NcclClique& clique, std::unique_ptr<absl::MutexLock> lock,
                   absl::BlockingCounter* counter);
  LockedNcclClique(LockedNcclClique&&);
  ~LockedNcclClique();

  NcclClique& clique;

 private:
  // Must come after clique, so it is destroyed first.
  // One thread holds a lock (it is null in the others).
  std::unique_ptr<absl::MutexLock> lock_;
  absl::BlockingCounter* counter_;
};

// Threadsafe leaky map from NcclCliqueKeys to NcclCliques.
class NcclCliqueMap {
 public:
  StatusOr<NcclClique*> GetOrTryCreateIfAbsent(
      const NcclCliqueKey& key,
      const std::function<StatusOr<std::unique_ptr<NcclClique>>(
          const NcclCliqueKey&)>& value_factory) ABSL_LOCKS_EXCLUDED(mu_);

  // Runs a function over every key/value in the map.
  void ForEach(
      const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn)
      ABSL_LOCKS_EXCLUDED(mu_);

 private:
  absl::Mutex mu_;
  absl::flat_hash_map<NcclCliqueKey, std::unique_ptr<NcclClique>> map_
      ABSL_GUARDED_BY(mu_);
};

NcclCliqueMap& NcclCliqueCache();

struct NcclCliqueParticipantData : public ParticipantData {
  // For running in StreamExecutor. To be deprecated after transitioning to
  // TFRT.
  NcclCliqueParticipantData(const RendezvousKey& rendezvous_key,
                            int64_t device_ordinal, se::Stream* stream)
      : ParticipantData(rendezvous_key),
        device_ordinal(device_ordinal),
        stream(stream) {}

  // For running in TFRT.
  NcclCliqueParticipantData(const RendezvousKey& rendezvous_key,
                            se::gpu::GpuContextHandle context)
      : ParticipantData(rendezvous_key), stream(nullptr), context(context) {}

  int64_t device_ordinal;
  se::Stream* stream;
  se::gpu::GpuContextHandle context;

  std::string ToString() const override {
    if (stream != nullptr) {
      return absl::StrFormat(
          "NcclCliqueParticipantData{rendezvous_key=%s, "
          "device_ordinal=%d, stream=%p}",
          rendezvous_key.ToString(), device_ordinal, stream);
    }
    return absl::StrFormat(
        "NcclCliqueParticipantData{rendezvous_key=%s, context=%p}",
        rendezvous_key.ToString(), context);
  }
};

// Acquires a locked NCCL clique for use in NCCL collective operations.
StatusOr<LockedNcclClique> AcquireNcclClique(
    const NcclCliqueParticipantData& participant,
    const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback);  // may be null

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_