1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_H_ 17 #define TENSORFLOW_CONTRIB_VERBS_RDMA_H_ 18 19 #ifdef TENSORFLOW_USE_VERBS 20 21 #include <infiniband/verbs.h> 22 #include <cstring> // for memset 23 #include <functional> 24 #include <memory> // for shared_ptr 25 #include <queue> 26 #include <string> 27 #include <unordered_map> 28 #include <vector> 29 30 #include "tensorflow/contrib/verbs/verbs_util.h" 31 #include "tensorflow/core/distributed_runtime/worker_env.h" 32 #include "tensorflow/core/framework/rendezvous.h" 33 #include "tensorflow/core/framework/tensor.h" 34 #include "tensorflow/core/framework/tensor_shape.h" 35 #include "tensorflow/core/framework/types.h" 36 #include "tensorflow/core/platform/env.h" 37 #include "tensorflow/core/platform/mutex.h" 38 39 namespace tensorflow { 40 #define PKEY_DEFAULT 0 41 #define QUEUE_DEPTH_DEFAULT 1024 42 #define TIMEOUT_DEFAULT 14 43 #define RETRY_CNT_DEFAULT 7 44 #define SL_DEFAULT 0 45 #define TRAFFIC_CLASS 0 46 47 #define RDMA_LOG_0 LOG(INFO) 48 #define RDMA_LOG_1 VLOG(1) 49 #define RDMA_LOG_2 VLOG(2) 50 #define RDMA_LOG(LEVEL) RDMA_LOG_##LEVEL 51 52 struct RdmaParams { 53 uint8_t port_num; 54 uint8_t sgid_index; 55 uint8_t pkey_index; 56 uint32_t queue_depth; 57 uint8_t timeout; 58 uint8_t retry_cnt; 59 uint8_t sl; 60 enum ibv_mtu mtu; 61 uint8_t traffic_class; 62 }; 63 // structure to save the address of remote channels. 64 struct RdmaAddress { 65 uint32_t lid; 66 uint32_t qpn; 67 uint32_t psn; 68 uint64_t snp; 69 uint64_t iid; 70 }; 71 // structure to save information for remote memory regions. 72 struct RemoteMR { 73 uint64_t remote_addr; 74 uint32_t rkey; 75 }; 76 enum BufferStatus { none, idle, busy }; 77 enum Location { local, remote }; 78 79 enum RdmaMessageType { 80 RDMA_MESSAGE_META_DATA_UPDATE, 81 RDMA_MESSAGE_TENSOR_RE_REQUEST, 82 RDMA_MESSAGE_TENSOR_REQUEST, 83 RDMA_MESSAGE_ERROR_STATUS, 84 }; 85 86 struct RdmaMessage { 87 RdmaMessageType type_; 88 uint16_t name_size_; 89 string name_; 90 int64 step_id_; 91 uint64_t request_index_; 92 union { 93 uint64_t remote_addr_; 94 #ifdef RDMA_DATA_VALIDATION 95 uint64_t checksum_; 96 #endif 97 }; 98 uint32_t rkey_; 99 bool is_dead_; 100 DataType data_type_; 101 TensorShape tensor_shape_; 102 size_t tensor_bytes_; 103 104 // For error status: 105 Status status_; 106 107 // type|name_size|name|step_id|request_index|remote_addr/checksum|rkey|... 108 // 1B| 2B | 512| 8B | 8B | 8B | 4B |... 109 // ...|is_dead|data_type|tensor_shape|tensor_bytes|error_status | 110 // ...| 1B | XB | XB | 8B |size - 4B, proto - XB | 111 static const size_t kNameCapacity = 512; 112 static const size_t kTypeStartIndex = 0; 113 static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_); 114 static const size_t kNameStartIndex = 115 kNameSizeStartIndex + sizeof(name_size_); 116 static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity; 117 static const size_t kRequestIndexStartIndex = 118 kStepIdStartIndex + sizeof(step_id_); 119 static const size_t kRemoteAddrStartIndex = 120 kRequestIndexStartIndex + sizeof(request_index_); 121 static const size_t kChecksumStartIndex = kRemoteAddrStartIndex; 122 static const size_t kRkeyStartIndex = 123 kRemoteAddrStartIndex + sizeof(remote_addr_); 124 static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_); 125 static const size_t kDataTypeStartIndex = 126 kIsDeadStartIndex + sizeof(is_dead_); 127 static const size_t kTensorShapeStartIndex = 128 kDataTypeStartIndex + sizeof(data_type_); 129 static const size_t kTensorBytesStartIndex = 130 kTensorShapeStartIndex + sizeof(TensorShape); 131 static const size_t kErrorStatusStartIndex = 132 kTensorBytesStartIndex + sizeof(tensor_bytes_); 133 static const size_t kErrorStatusMaxSize = 4096; 134 135 static const size_t kMessageTotalBytes = kErrorStatusStartIndex; 136 static const size_t kRdmaMessageBufferSize = 137 kMessageTotalBytes + kErrorStatusMaxSize; 138 static string CreateMessage(const RdmaMessage& rm); 139 static void ParseMessage(RdmaMessage& rm, void* buffer); 140 }; 141 142 // Immediate types for RDMA write 143 enum RdmaImmDataType { 144 RDMA_IMM_MAX_REQUEST_ID = 0xFFFFFFFD, 145 RDMA_IMM_DATA_ACK = 0xFFFFFFFE, 146 RDMA_IMM_DATA_MESSAGE = 0xFFFFFFFF 147 }; 148 149 // Write types for RDMA write-complete events 150 enum RdmaWriteIDType { 151 RDMA_WRITE_ID_ACK, 152 RDMA_WRITE_ID_MESSAGE, 153 RDMA_WRITE_ID_TENSOR_WRITE 154 }; 155 156 // Context for RDMA write-complete events 157 class RdmaWriteID { 158 public: RdmaWriteID(RdmaWriteIDType write_type,void * write_context)159 RdmaWriteID(RdmaWriteIDType write_type, void* write_context) 160 : write_type(write_type), write_context(write_context) {} 161 162 RdmaWriteIDType write_type; 163 void* write_context; 164 }; 165 166 // Tensor meta-data 167 class TensorMetaData { 168 public: 169 TensorShape tensor_shape_; 170 DataType data_type_; 171 size_t proto_size_; 172 bool is_dead_; 173 print(std::ostream & out)174 std::ostream& print(std::ostream& out) const { 175 out << "Dtype = " << DataTypeString(data_type_) 176 << ", Shape = " << tensor_shape_.DebugString() << ", Proto size = 0x" 177 << std::hex << proto_size_ << ", Is dead = " << is_dead_; 178 return out; 179 } 180 }; 181 182 inline std::ostream& operator<<(std::ostream& out, 183 const TensorMetaData& meta_data) { 184 return meta_data.print(out); 185 } 186 187 class RdmaChannel; 188 189 void MRDeleter(ibv_mr* mr); 190 using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>; 191 192 // RdmaMemoryMgr 193 // Manages the local meta-data cache, and the registered RDMA memory regions. 194 class RdmaMemoryMgr { 195 public: Singleton()196 static RdmaMemoryMgr& Singleton() { 197 static RdmaMemoryMgr instance; 198 return instance; 199 } 200 201 // Memory regions 202 ibv_mr* FindMemoryRegion(void* addr, size_t length); 203 void InsertMemoryRegion(void* addr, size_t length, 204 const std::string& allocator_name); 205 void EvictMemoryRegion(void* addr, size_t length); 206 207 // Tensor meta-data cache 208 const TensorMetaData* GetTensorMetaData(const std::string& tensor_name); 209 const TensorMetaData* SetTensorMetaData(const std::string& tensor_name, 210 DataType dtype, 211 const TensorShape& shape, 212 bool is_dead, size_t proto_size); 213 214 struct ibv_pd* pd_; 215 216 protected: RdmaMemoryMgr()217 RdmaMemoryMgr() : pd_(nullptr) {} 218 Comparator(const void * ptr,const MemoryRegionPtr & other)219 static bool Comparator(const void* ptr, const MemoryRegionPtr& other) { 220 return ptr < reinterpret_cast<char*>(other->addr) + other->length; 221 } 222 223 private: 224 mutex tensor_meta_data_mu_; 225 std::unordered_map<std::string, TensorMetaData> tensors_meta_data_; 226 227 // Managed memory regions 228 mutex mrs_mu_; 229 std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(mrs_mu_); 230 }; 231 232 // RdmaTensorRequest 233 // Represents a single tensor request. 234 class RdmaTensorRequest { 235 public: 236 typedef Rendezvous::DoneCallback RecvDoneCallback; 237 238 // Creates a tensor request identified by index. 239 RdmaTensorRequest(uint32_t index, const string& key, int64 step_id, 240 RdmaChannel* channel, Device* dst_dev, 241 const Rendezvous::Args recv_args, 242 const RecvDoneCallback& done); 243 ~RdmaTensorRequest(); 244 245 // Request unique index. index()246 uint32_t index() { return index_; } 247 248 // Start the tensor request sequence. 249 // 250 // 1. Allocate the result tensor (and proxy tensor if required). 251 // 2. Send RDMA_MESSAGE_TENSOR_REQUEST to the remote side. 252 void Start(); 253 254 // Receive tensor meta-data. 255 // 256 // 1. Update the local meta-data cache. 257 // 2. Reallocate the result tensor (and proxy tensor if required). 258 // 3. Re-send the request to the remote side. 259 void RecvTensorMetaData(DataType dtype, TensorShape shape, bool is_dead, 260 size_t proto_size); 261 262 // Receive tensor content (RDMA write was completed). 263 // 264 // Decode proto if required and/or move to GPU if the content was not 265 // written to it directly (GPU direct is not available). Afterwards, 266 // invoke Done(). 267 void RecvTensorContent(); 268 269 // Receive error status (in case of a remote error). 270 // Invoke Done() with the status code. 271 void RecvErrorStatus(const Status& status); 272 273 #ifdef RDMA_DATA_VALIDATION 274 // Receive tensor checksum 275 // 276 // For validation: Get and store the Tensor's expected checksum for the 277 // current request. Compare the result Tensor's checksum with the stored 278 // checksum right before invoking Done(). RecvTensorChecksum(uint64_t checksum)279 void RecvTensorChecksum(uint64_t checksum) { checksum_ = checksum; } 280 #endif 281 282 private: 283 void Done(const Status& s); 284 void Send(RdmaMessageType message_type); 285 bool AllocateTensors(); 286 void AllocateTensorsAsync(StatusCallback done); 287 void DeallocateTensors(); 288 289 uint32_t index_; 290 string key_; 291 int64 step_id_; 292 RdmaChannel* channel_; 293 Device* dst_dev_; 294 Rendezvous::Args recv_args_; 295 const TensorMetaData* meta_data_; 296 Tensor* result_tensor_; 297 Tensor* proxy_tensor_; 298 void* rdma_addr_; 299 ibv_mr* mr_; 300 RecvDoneCallback done_; 301 #ifdef RDMA_DATA_VALIDATION 302 uint64_t checksum_; 303 #endif 304 }; 305 306 // RdmaTensorResponse 307 // Represents a single tensor response. 308 class RdmaTensorResponse { 309 public: 310 // Creates a response for request message. RdmaTensorResponse(RdmaChannel * channel,const RdmaMessage & rm)311 RdmaTensorResponse(RdmaChannel* channel, const RdmaMessage& rm) 312 : channel_(channel), rm_(rm) {} 313 Update(const RdmaMessage & rm)314 void Update(const RdmaMessage& rm) { rm_ = rm; } 315 316 // Start the tensor response sequence. 317 // 318 // 1. Find the tensor in the local tag-match table and invoke RecvHandler. 319 // (Using RecvLocalAsync()). 320 // 2. Compare the tensor's meta-data to the meta-data in the message (taken 321 // from the requester's local cache). 322 // If meta-data changed: 323 // a. Clone the tensor to be sent later. 324 // b. Send a meta-data update message and wait for re-request. 325 // Else: 326 // a. Send the tensor's content (using direct RDMA write). 327 void Start(); 328 329 // Resume the response sequence, after a re-request. 330 // 331 // 1. Send the tensor's content that was cloned earlier. 332 void Resume(); 333 334 // Destroy the response's resources and remove it from the pending list. 335 void Destroy(); 336 337 private: 338 void RecvHandler(Rendezvous::ParsedKey parsed, 339 const Rendezvous::Args& send_args, 340 const Rendezvous::Args& recv_args, const Tensor& in, 341 bool is_dead); 342 void Clone(const Tensor& in, const TensorProto& proto, bool is_dead); 343 void Send(const Tensor& in, const TensorProto& proto, bool is_dead, 344 const Status& status); 345 bool TensorMetaDataChanged(const Tensor& in, bool is_dead); 346 Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, 347 Device** src_dev); 348 void SendMetaData(const Tensor& in, const TensorProto& proto, bool is_dead); 349 void SendContent(const Tensor& in, const TensorProto& proto, bool is_dead); 350 void SendErrorStatus(const Status& status); 351 352 RdmaChannel* channel_; 353 RdmaMessage rm_; // The request message 354 Device* src_dev_ = nullptr; 355 TensorBuffer* src_buffer_ = nullptr; 356 void* src_addr_ = nullptr; 357 ibv_mr* mr_ = nullptr; 358 uint64_t checksum_ = 0; 359 bool meta_data_changed_ = false; 360 361 // Re-item: 362 TensorProto* proto_ = nullptr; 363 Tensor* tensor_ = nullptr; 364 bool is_dead_ = false; 365 }; 366 367 class RdmaMessageBuffer; 368 // Class that represents the Rdma Adapter. 369 // Responsible for creation of the completion queue, and handling 370 // of work completions. 371 class RdmaAdapter { 372 friend class RdmaChannel; 373 friend class RdmaMessageBuffer; 374 friend class RdmaTensorResponse; 375 friend class RdmaMgr; 376 friend class RdmaRemoteRendezvous; 377 378 public: 379 RdmaAdapter(const WorkerEnv* worker_env); 380 ~RdmaAdapter(); 381 // Adapter name, e.g. mlx5_0. 382 string name() const; 383 void StartPolling(); 384 void Process_CQ(); 385 386 protected: 387 static const int MAX_CONCURRENT_WRITES = 1000; 388 ibv_context* context_; 389 // RDMA configuration parameters 390 RdmaParams params_; 391 // ibverbs protection domain 392 ibv_pd* pd_; 393 // Completion event channel, to wait for work completions 394 ibv_comp_channel* event_channel_; 395 // Completion queue, to poll on work completions 396 ibv_cq* cq_; 397 // Pre-allocated work completions array used for polling 398 ibv_wc wc_[MAX_CONCURRENT_WRITES * 2]; 399 // worker env for thread 400 const WorkerEnv* worker_env_; 401 // thread for cq. 402 std::unique_ptr<Thread> polling_thread_; 403 }; 404 405 // Class that represents a connection to a remote Rdma peer. 406 // Responsible for connecting queue pairs. 407 class RdmaChannel { 408 friend class RdmaAdapter; 409 friend class RdmaMessageBuffer; 410 friend class RdmaTensorBuffer; 411 friend class RdmaTensorRequest; 412 friend class RdmaTensorResponse; 413 friend class RdmaMgr; 414 friend class RdmaRemoteRendezvous; 415 416 public: 417 explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name, 418 const string remote_name_); 419 ~RdmaChannel(); self()420 inline const RdmaAddress& self() { return self_; } 421 RdmaAddress address() const; message_buffers()422 inline const std::vector<RdmaMessageBuffer*>& message_buffers() const { 423 return message_buffers_; 424 } 425 void Connect(const RdmaAddress& remoteAddr); 426 void Connect(); 427 void Recv(); 428 void SetRemoteAddress(const RdmaAddress& ra, bool override); 429 430 // Requests: 431 RdmaTensorRequest* InsertTensorRequest( 432 const string& key, int64 step_id, Device* dst_dev, 433 const Rendezvous::Args recv_args, 434 const RdmaTensorRequest::RecvDoneCallback& done); 435 void RemoveTensorRequest(uint32_t request_index); 436 RdmaTensorRequest* GetTensorRequest(uint32_t request_index); 437 438 // Responses: 439 RdmaTensorResponse* AddTensorResponse(const RdmaMessage& rm); 440 RdmaTensorResponse* UpdateTensorResponse(const RdmaMessage& rm); 441 void RemoveTensorResponse(uint32_t request_index); 442 443 static const int kNumMessageBuffers = 2; 444 static const int kPingRecvWrid = 0; 445 446 private: 447 static const int kPingBuffSize = 1024; 448 char ping_buff_[kPingBuffSize]; 449 struct ibv_mr* mr_; 450 struct ibv_sge ping_sge_list_; 451 int PingPostRecv(); 452 int PingPostSend(); 453 454 protected: 455 const RdmaAdapter* adapter_; 456 RdmaAddress self_; 457 string local_name_; 458 string remote_name_; 459 ibv_qp* qp_; 460 mutex mu_; 461 bool connected_ GUARDED_BY(mu_) = false; 462 RdmaAddress remote_ GUARDED_BY(mu_); 463 bool remote_set_ GUARDED_BY(mu_) = false; 464 mutex ct_mu_; 465 typedef std::unordered_map<uint32_t, RdmaTensorRequest> RequestTable; 466 RequestTable request_table_ GUARDED_BY(ct_mu_); 467 uint32_t request_serial_ GUARDED_BY(ct_mu_); 468 mutex responses_mu_; 469 typedef std::unordered_map<uint32_t, RdmaTensorResponse> ResponsesTable; 470 ResponsesTable responses_table_ GUARDED_BY(responses_mu_); 471 RdmaMessageBuffer* tx_message_buffer_; 472 RdmaMessageBuffer* rx_message_buffer_; 473 std::vector<RdmaMessageBuffer*> message_buffers_; 474 }; 475 476 // Class that represents a buffer for Rdma message sending. 477 class RdmaMessageBuffer { 478 friend class RdmaChannel; 479 friend class RdmaAdapter; 480 friend class RdmaMgr; 481 friend class RdmaRemoteRendezvous; 482 483 public: 484 explicit RdmaMessageBuffer(RdmaChannel* channel, string name); 485 ~RdmaMessageBuffer(); 486 buffer()487 inline void* buffer() const { return buffer_; } self()488 inline ibv_mr* self() const { return self_; } SetBufferStatus(Location loc,BufferStatus status)489 inline void SetBufferStatus(Location loc, BufferStatus status) { 490 mu_.lock(); 491 if (loc == local) { 492 local_status_ = status; 493 } else { 494 remote_status_ = status; 495 } 496 mu_.unlock(); 497 } 498 void FreeBuffer(); 499 void EnqueueItem(string Item); 500 void SendNextItem(); 501 void CreateCPUBuffer(size_t size, bool lock = true); 502 void SetRemoteMR(RemoteMR rmi, bool override); 503 void Write(uint32_t imm_data, size_t buffer_size); 504 static void Write(const RdmaChannel* channel, uint32_t imm_data, 505 size_t buffer_size, uint64_t src_addr, uint32_t lkey, 506 uint64_t remote_addr, uint32_t rkey, 507 RdmaWriteIDType write_type, void* write_context); 508 static void SendAck(const RdmaChannel* channel); 509 510 protected: 511 const RdmaChannel* channel_; 512 void* buffer_ = nullptr; 513 bool buffer_on_host_ = true; 514 size_t size_ = 0; 515 const string name_; 516 ibv_mr* self_ = nullptr; 517 mutex mu_; 518 RemoteMR remote_; 519 std::queue<string> queue_ GUARDED_BY(mu_); 520 BufferStatus local_status_ GUARDED_BY(mu_) = none; 521 BufferStatus remote_status_ GUARDED_BY(mu_) = none; 522 }; 523 524 } // namespace tensorflow 525 526 #endif // TENSORFLOW_USE_VERBS 527 #endif // TENSORFLOW_CONTRIB_VERBS_RDMA_H_ 528