/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_

#include <memory>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#endif
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

#if TENSORFLOW_USE_ROCM
// Local hipify of CUDA symbols.
#define cudaError_t hipError_t
#define cudaStream_t hipStream_t
#define cudaGetErrorString hipGetErrorString
#define cudaGetDevice hipGetDevice
#define cudaSetDevice hipSetDevice
#define cudaSuccess hipSuccess
#endif

namespace xla {
namespace gpu {

ncclRedOp_t ToNcclReduction(ReductionKind kind);
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type);

bool IsGlobalNcclConfig();

Status ToStatus(ncclResult_t s, const char* file, int64 line, const char* expr);
Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr);

// Macros to return or warn on CUDA/NCCL errors. (The same macro works for
// both NCCL and CUDA errors.)
//
// It's tempting to say these macros belong in an XLA header somewhere, but in
// practice we don't do much direct-to-CUDA-API stuff outside of this file.
#define XLA_CUDA_STATUS(expr) \
  xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr)

#define XLA_CUDA_RETURN_IF_ERROR(expr) \
  do {                                 \
    Status s = XLA_CUDA_STATUS(expr);  \
    if (!s.ok()) {                     \
      return s;                        \
    }                                  \
  } while (0)

#define XLA_CUDA_WARN_IF_ERROR(expr)  \
  do {                                \
    Status s = XLA_CUDA_STATUS(expr); \
    if (!s.ok()) {                    \
      LOG(ERROR) << s.ToString();     \
    }                                 \
  } while (0)
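
// Example: propagating and logging errors with these macros (a minimal,
// illustrative sketch; `DoAllReduce` and its parameters are hypothetical and
// not part of this header):
//
//   Status DoAllReduce(ncclComm_t comm, const float* sendbuff,
//                      float* recvbuff, size_t count, cudaStream_t stream) {
//     // Converts a failing ncclResult_t into a Status and returns early.
//     XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(sendbuff, recvbuff, count,
//                                            ncclFloat, ncclSum, comm,
//                                            stream));
//     // Converts a failing cudaError_t into a Status, but merely logs it.
//     XLA_CUDA_WARN_IF_ERROR(cudaStreamSynchronize(stream));
//     return Status::OK();
//   }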

// RAII type for NCCL communicators.
using NcclComm = std::unique_ptr<ncclComm, void (*)(ncclComm_t)>;

// Owns a clique of NCCL comms which can be used for collective operations
// among a particular set of GPUs.
//
// Note that if you want to do a collective operation among a subset of these
// GPUs, you'll need a different clique.
class NcclClique {
 public:
  explicit NcclClique(
      absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal);

  ncclComm_t GetCommForDeviceOrdinal(int device_ordinal) const;
  absl::Mutex* mu() { return &mu_; }

 private:
  absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal_;
  absl::Mutex mu_;
};

struct LocalParticipant {
  int device_ordinal;
  int rank;
};

StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
    const std::vector<GlobalDeviceId>& participants,
    const std::vector<GlobalDeviceId>* local_devices);  // may be null

class LockedNcclClique {
 public:
  LockedNcclClique(NcclClique& clique, std::unique_ptr<absl::MutexLock> lock,
                   absl::BlockingCounter* counter);
  LockedNcclClique(LockedNcclClique&&);
  ~LockedNcclClique();

  NcclClique& clique;

 private:
  // Must come after clique, so it is destroyed first.
  // One thread holds a lock (it is null in the others).
  std::unique_ptr<absl::MutexLock> lock_;
  absl::BlockingCounter* counter_;
};

// Thread-safe, leaky map from NcclCliqueKey to NcclClique.
class NcclCliqueMap {
 public:
  StatusOr<NcclClique*> GetOrTryCreateIfAbsent(
      const NcclCliqueKey& key,
      const std::function<StatusOr<std::unique_ptr<NcclClique>>(
          const NcclCliqueKey&)>& value_factory) ABSL_LOCKS_EXCLUDED(mu_);

  // Runs a function over every key/value in the map.
  void ForEach(
      const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn)
      ABSL_LOCKS_EXCLUDED(mu_);

 private:
  absl::Mutex mu_;
  absl::flat_hash_map<NcclCliqueKey, std::unique_ptr<NcclClique>> map_
      ABSL_GUARDED_BY(mu_);
};

NcclCliqueMap& NcclCliqueCache();

// Acquires a locked NCCL clique for use in NCCL collective operations.
StatusOr<LockedNcclClique> AcquireNcclClique(
    const RendezvousKey& rendezvous_key, int local_device_ordinal,
    se::Stream* stream,
    const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback);  // may be null
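
// Example: acquiring a clique and looking up the comm for the local device
// (a minimal sketch; `rendezvous_key`, `local_device_ordinal`, `stream`,
// `local_participants`, and `id_callback` are hypothetical locals, not part
// of this header):
//
//   TF_ASSIGN_OR_RETURN(
//       LockedNcclClique locked_clique,
//       AcquireNcclClique(rendezvous_key, local_device_ordinal, stream,
//                         local_participants, id_callback));
//   ncclComm_t comm =
//       locked_clique.clique.GetCommForDeviceOrdinal(local_device_ordinal);
//   // ... enqueue NCCL collectives on `comm`. The clique is released when
//   // `locked_clique` goes out of scope.

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_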