• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_H_
17 #define TENSORFLOW_CONTRIB_VERBS_RDMA_H_
18 
19 #ifdef TENSORFLOW_USE_VERBS
20 
#include <infiniband/verbs.h>
#include <cstdint>
#include <cstring>  // for memset
#include <functional>
#include <memory>  // for shared_ptr
#include <ostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
38 
39 namespace tensorflow {
// Default values for the RDMA configuration.
// NOTE(review): presumably the defaults for the identically named RdmaParams
// fields (pkey_index, queue_depth, timeout, retry_cnt, sl, traffic_class) --
// confirm against the code that fills RdmaParams.
#define PKEY_DEFAULT 0
#define QUEUE_DEPTH_DEFAULT 1024
#define TIMEOUT_DEFAULT 14
#define RETRY_CNT_DEFAULT 7
#define SL_DEFAULT 0
#define TRAFFIC_CLASS 0

// Level-selectable logging. RDMA_LOG(0) expands to LOG(INFO), RDMA_LOG(1) to
// VLOG(1) and RDMA_LOG(2) to VLOG(2). LEVEL is token-pasted into the macro
// name, so it must be the literal token 0, 1 or 2 (not an expression).
#define RDMA_LOG_0 LOG(INFO)
#define RDMA_LOG_1 VLOG(1)
#define RDMA_LOG_2 VLOG(2)
#define RDMA_LOG(LEVEL) RDMA_LOG_##LEVEL
51 
52 struct RdmaParams {
53   uint8_t port_num;
54   uint8_t sgid_index;
55   uint8_t pkey_index;
56   uint32_t queue_depth;
57   uint8_t timeout;
58   uint8_t retry_cnt;
59   uint8_t sl;
60   enum ibv_mtu mtu;
61   uint8_t traffic_class;
62 };
// Address of an RDMA channel endpoint. RdmaChannel keeps one for the local
// side (self_) and one for the remote peer (remote_).
// NOTE(review): field meanings are inferred from the names -- lid: local
// identifier, qpn: queue pair number, psn: packet sequence number,
// snp: subnet prefix, iid: interface id. Confirm against the connection code.
struct RdmaAddress {
  uint32_t lid;
  uint32_t qpn;
  uint32_t psn;
  uint64_t snp;
  uint64_t iid;
};
// Information needed to address a memory region on the remote peer:
// the remote virtual address and the remote key that goes with it.
struct RemoteMR {
  uint64_t remote_addr;
  uint32_t rkey;
};
// Status of a message buffer; RdmaMessageBuffer tracks one value for the
// local side and one for the remote side (both start as `none`).
enum BufferStatus { none, idle, busy };
// Selects which side's status RdmaMessageBuffer::SetBufferStatus() updates.
enum Location { local, remote };
78 
// Kind of control message carried by an RdmaMessage (RdmaMessage::type_).
// The numeric values are part of the serialized message, so the order of the
// enumerators must not change.
enum RdmaMessageType {
  RDMA_MESSAGE_META_DATA_UPDATE,   // tensor meta-data changed on the sender
  RDMA_MESSAGE_TENSOR_RE_REQUEST,  // request re-sent after a meta-data update
  RDMA_MESSAGE_TENSOR_REQUEST,     // initial request for a tensor
  RDMA_MESSAGE_ERROR_STATUS,       // remote error; carries a Status
};
85 
86 struct RdmaMessage {
87   RdmaMessageType type_;
88   uint16_t name_size_;
89   string name_;
90   int64 step_id_;
91   uint64_t request_index_;
92   union {
93     uint64_t remote_addr_;
94 #ifdef RDMA_DATA_VALIDATION
95     uint64_t checksum_;
96 #endif
97   };
98   uint32_t rkey_;
99   bool is_dead_;
100   DataType data_type_;
101   TensorShape tensor_shape_;
102   size_t tensor_bytes_;
103 
104   // For error status:
105   Status status_;
106 
107   // type|name_size|name|step_id|request_index|remote_addr/checksum|rkey|...
108   //   1B|    2B   | 512|  8B   |     8B      |       8B           | 4B |...
109   // ...|is_dead|data_type|tensor_shape|tensor_bytes|error_status          |
110   // ...|    1B |   XB    |    XB      |    8B      |size - 4B, proto - XB |
111   static const size_t kNameCapacity = 512;
112   static const size_t kTypeStartIndex = 0;
113   static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_);
114   static const size_t kNameStartIndex =
115       kNameSizeStartIndex + sizeof(name_size_);
116   static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity;
117   static const size_t kRequestIndexStartIndex =
118       kStepIdStartIndex + sizeof(step_id_);
119   static const size_t kRemoteAddrStartIndex =
120       kRequestIndexStartIndex + sizeof(request_index_);
121   static const size_t kChecksumStartIndex = kRemoteAddrStartIndex;
122   static const size_t kRkeyStartIndex =
123       kRemoteAddrStartIndex + sizeof(remote_addr_);
124   static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_);
125   static const size_t kDataTypeStartIndex =
126       kIsDeadStartIndex + sizeof(is_dead_);
127   static const size_t kTensorShapeStartIndex =
128       kDataTypeStartIndex + sizeof(data_type_);
129   static const size_t kTensorBytesStartIndex =
130       kTensorShapeStartIndex + sizeof(TensorShape);
131   static const size_t kErrorStatusStartIndex =
132       kTensorBytesStartIndex + sizeof(tensor_bytes_);
133   static const size_t kErrorStatusMaxSize = 4096;
134 
135   static const size_t kMessageTotalBytes = kErrorStatusStartIndex;
136   static const size_t kRdmaMessageBufferSize =
137       kMessageTotalBytes + kErrorStatusMaxSize;
138   static string CreateMessage(const RdmaMessage& rm);
139   static void ParseMessage(RdmaMessage& rm, void* buffer);
140 };
141 
// Immediate types for RDMA write.
//
// The underlying type is fixed to uint32_t: the enumerators do not fit in an
// int, and the value is carried as 32-bit immediate data (see the
// `uint32_t imm_data` parameters of RdmaMessageBuffer::Write).
// NOTE(review): immediate values up to RDMA_IMM_MAX_REQUEST_ID presumably
// encode a tensor request index -- confirm in the completion-queue handler.
enum RdmaImmDataType : uint32_t {
  RDMA_IMM_MAX_REQUEST_ID = 0xFFFFFFFD,
  RDMA_IMM_DATA_ACK = 0xFFFFFFFE,
  RDMA_IMM_DATA_MESSAGE = 0xFFFFFFFF
};
148 
// Write types for RDMA write-complete events. Stored in RdmaWriteID to tell
// the completion handler what kind of write has finished.
enum RdmaWriteIDType {
  RDMA_WRITE_ID_ACK,
  RDMA_WRITE_ID_MESSAGE,
  RDMA_WRITE_ID_TENSOR_WRITE
};
155 
156 // Context for RDMA write-complete events
157 class RdmaWriteID {
158  public:
RdmaWriteID(RdmaWriteIDType write_type,void * write_context)159   RdmaWriteID(RdmaWriteIDType write_type, void* write_context)
160       : write_type(write_type), write_context(write_context) {}
161 
162   RdmaWriteIDType write_type;
163   void* write_context;
164 };
165 
166 // Tensor meta-data
167 class TensorMetaData {
168  public:
169   TensorShape tensor_shape_;
170   DataType data_type_;
171   size_t proto_size_;
172   bool is_dead_;
173 
print(std::ostream & out)174   std::ostream& print(std::ostream& out) const {
175     out << "Dtype = " << DataTypeString(data_type_)
176         << ", Shape = " << tensor_shape_.DebugString() << ", Proto size = 0x"
177         << std::hex << proto_size_ << ", Is dead = " << is_dead_;
178     return out;
179   }
180 };
181 
182 inline std::ostream& operator<<(std::ostream& out,
183                                 const TensorMetaData& meta_data) {
184   return meta_data.print(out);
185 }
186 
187 class RdmaChannel;
188 
189 void MRDeleter(ibv_mr* mr);
190 using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>;
191 
192 // RdmaMemoryMgr
193 // Manages the local meta-data cache, and the registered RDMA memory regions.
194 class RdmaMemoryMgr {
195  public:
Singleton()196   static RdmaMemoryMgr& Singleton() {
197     static RdmaMemoryMgr instance;
198     return instance;
199   }
200 
201   // Memory regions
202   ibv_mr* FindMemoryRegion(void* addr, size_t length);
203   void InsertMemoryRegion(void* addr, size_t length,
204                           const std::string& allocator_name);
205   void EvictMemoryRegion(void* addr, size_t length);
206 
207   // Tensor meta-data cache
208   const TensorMetaData* GetTensorMetaData(const std::string& tensor_name);
209   const TensorMetaData* SetTensorMetaData(const std::string& tensor_name,
210                                           DataType dtype,
211                                           const TensorShape& shape,
212                                           bool is_dead, size_t proto_size);
213 
214   struct ibv_pd* pd_;
215 
216  protected:
RdmaMemoryMgr()217   RdmaMemoryMgr() : pd_(nullptr) {}
218 
Comparator(const void * ptr,const MemoryRegionPtr & other)219   static bool Comparator(const void* ptr, const MemoryRegionPtr& other) {
220     return ptr < reinterpret_cast<char*>(other->addr) + other->length;
221   }
222 
223  private:
224   mutex tensor_meta_data_mu_;
225   std::unordered_map<std::string, TensorMetaData> tensors_meta_data_;
226 
227   // Managed memory regions
228   mutex mrs_mu_;
229   std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(mrs_mu_);
230 };
231 
232 // RdmaTensorRequest
233 // Represents a single tensor request.
234 class RdmaTensorRequest {
235  public:
236   typedef Rendezvous::DoneCallback RecvDoneCallback;
237 
238   // Creates a tensor request identified by index.
239   RdmaTensorRequest(uint32_t index, const string& key, int64 step_id,
240                     RdmaChannel* channel, Device* dst_dev,
241                     const Rendezvous::Args recv_args,
242                     const RecvDoneCallback& done);
243   ~RdmaTensorRequest();
244 
245   // Request unique index.
index()246   uint32_t index() { return index_; }
247 
248   // Start the tensor request sequence.
249   //
250   // 1. Allocate the result tensor (and proxy tensor if required).
251   // 2. Send RDMA_MESSAGE_TENSOR_REQUEST to the remote side.
252   void Start();
253 
254   // Receive tensor meta-data.
255   //
256   // 1. Update the local meta-data cache.
257   // 2. Reallocate the result tensor (and proxy tensor if required).
258   // 3. Re-send the request to the remote side.
259   void RecvTensorMetaData(DataType dtype, TensorShape shape, bool is_dead,
260                           size_t proto_size);
261 
262   // Receive tensor content (RDMA write was completed).
263   //
264   // Decode proto if required and/or move to GPU if the content was not
265   // written to it directly (GPU direct is not available). Afterwards,
266   // invoke Done().
267   void RecvTensorContent();
268 
269   // Receive error status (in case of a remote error).
270   // Invoke Done() with the status code.
271   void RecvErrorStatus(const Status& status);
272 
273 #ifdef RDMA_DATA_VALIDATION
274   // Receive tensor checksum
275   //
276   // For validation: Get and store the Tensor's expected checksum for the
277   // current request. Compare the result Tensor's checksum with the stored
278   // checksum right before invoking Done().
RecvTensorChecksum(uint64_t checksum)279   void RecvTensorChecksum(uint64_t checksum) { checksum_ = checksum; }
280 #endif
281 
282  private:
283   void Done(const Status& s);
284   void Send(RdmaMessageType message_type);
285   bool AllocateTensors();
286   void AllocateTensorsAsync(StatusCallback done);
287   void DeallocateTensors();
288 
289   uint32_t index_;
290   string key_;
291   int64 step_id_;
292   RdmaChannel* channel_;
293   Device* dst_dev_;
294   Rendezvous::Args recv_args_;
295   const TensorMetaData* meta_data_;
296   Tensor* result_tensor_;
297   Tensor* proxy_tensor_;
298   void* rdma_addr_;
299   ibv_mr* mr_;
300   RecvDoneCallback done_;
301 #ifdef RDMA_DATA_VALIDATION
302   uint64_t checksum_;
303 #endif
304 };
305 
306 // RdmaTensorResponse
307 // Represents a single tensor response.
308 class RdmaTensorResponse {
309  public:
310   // Creates a response for request message.
RdmaTensorResponse(RdmaChannel * channel,const RdmaMessage & rm)311   RdmaTensorResponse(RdmaChannel* channel, const RdmaMessage& rm)
312       : channel_(channel), rm_(rm) {}
313 
Update(const RdmaMessage & rm)314   void Update(const RdmaMessage& rm) { rm_ = rm; }
315 
316   // Start the tensor response sequence.
317   //
318   // 1. Find the tensor in the local tag-match table and invoke RecvHandler.
319   //    (Using RecvLocalAsync()).
320   // 2. Compare the tensor's meta-data to the meta-data in the message (taken
321   //    from the requester's local cache).
322   //    If meta-data changed:
323   //    a. Clone the tensor to be sent later.
324   //    b. Send a meta-data update message and wait for re-request.
325   //    Else:
326   //    a. Send the tensor's content (using direct RDMA write).
327   void Start();
328 
329   // Resume the response sequence, after a re-request.
330   //
331   // 1. Send the tensor's content that was cloned earlier.
332   void Resume();
333 
334   // Destroy the response's resources and remove it from the pending list.
335   void Destroy();
336 
337  private:
338   void RecvHandler(Rendezvous::ParsedKey parsed,
339                    const Rendezvous::Args& send_args,
340                    const Rendezvous::Args& recv_args, const Tensor& in,
341                    bool is_dead);
342   void Clone(const Tensor& in, const TensorProto& proto, bool is_dead);
343   void Send(const Tensor& in, const TensorProto& proto, bool is_dead,
344             const Status& status);
345   bool TensorMetaDataChanged(const Tensor& in, bool is_dead);
346   Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
347                            Device** src_dev);
348   void SendMetaData(const Tensor& in, const TensorProto& proto, bool is_dead);
349   void SendContent(const Tensor& in, const TensorProto& proto, bool is_dead);
350   void SendErrorStatus(const Status& status);
351 
352   RdmaChannel* channel_;
353   RdmaMessage rm_;  // The request message
354   Device* src_dev_ = nullptr;
355   TensorBuffer* src_buffer_ = nullptr;
356   void* src_addr_ = nullptr;
357   ibv_mr* mr_ = nullptr;
358   uint64_t checksum_ = 0;
359   bool meta_data_changed_ = false;
360 
361   // Re-item:
362   TensorProto* proto_ = nullptr;
363   Tensor* tensor_ = nullptr;
364   bool is_dead_ = false;
365 };
366 
367 class RdmaMessageBuffer;
368 // Class that represents the Rdma Adapter.
369 // Responsible for creation of the completion queue, and handling
370 // of work completions.
371 class RdmaAdapter {
372   friend class RdmaChannel;
373   friend class RdmaMessageBuffer;
374   friend class RdmaTensorResponse;
375   friend class RdmaMgr;
376   friend class RdmaRemoteRendezvous;
377 
378  public:
379   RdmaAdapter(const WorkerEnv* worker_env);
380   ~RdmaAdapter();
381   // Adapter name, e.g. mlx5_0.
382   string name() const;
383   void StartPolling();
384   void Process_CQ();
385 
386  protected:
387   static const int MAX_CONCURRENT_WRITES = 1000;
388   ibv_context* context_;
389   // RDMA configuration parameters
390   RdmaParams params_;
391   // ibverbs protection domain
392   ibv_pd* pd_;
393   // Completion event channel, to wait for work completions
394   ibv_comp_channel* event_channel_;
395   // Completion queue, to poll on work completions
396   ibv_cq* cq_;
397   // Pre-allocated work completions array used for polling
398   ibv_wc wc_[MAX_CONCURRENT_WRITES * 2];
399   // worker env for thread
400   const WorkerEnv* worker_env_;
401   // thread for cq.
402   std::unique_ptr<Thread> polling_thread_;
403 };
404 
405 // Class that represents a connection to a remote Rdma peer.
406 // Responsible for connecting queue pairs.
407 class RdmaChannel {
408   friend class RdmaAdapter;
409   friend class RdmaMessageBuffer;
410   friend class RdmaTensorBuffer;
411   friend class RdmaTensorRequest;
412   friend class RdmaTensorResponse;
413   friend class RdmaMgr;
414   friend class RdmaRemoteRendezvous;
415 
416  public:
417   explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name,
418                        const string remote_name_);
419   ~RdmaChannel();
self()420   inline const RdmaAddress& self() { return self_; }
421   RdmaAddress address() const;
message_buffers()422   inline const std::vector<RdmaMessageBuffer*>& message_buffers() const {
423     return message_buffers_;
424   }
425   void Connect(const RdmaAddress& remoteAddr);
426   void Connect();
427   void Recv();
428   void SetRemoteAddress(const RdmaAddress& ra, bool override);
429 
430   // Requests:
431   RdmaTensorRequest* InsertTensorRequest(
432       const string& key, int64 step_id, Device* dst_dev,
433       const Rendezvous::Args recv_args,
434       const RdmaTensorRequest::RecvDoneCallback& done);
435   void RemoveTensorRequest(uint32_t request_index);
436   RdmaTensorRequest* GetTensorRequest(uint32_t request_index);
437 
438   // Responses:
439   RdmaTensorResponse* AddTensorResponse(const RdmaMessage& rm);
440   RdmaTensorResponse* UpdateTensorResponse(const RdmaMessage& rm);
441   void RemoveTensorResponse(uint32_t request_index);
442 
443   static const int kNumMessageBuffers = 2;
444   static const int kPingRecvWrid = 0;
445 
446  private:
447   static const int kPingBuffSize = 1024;
448   char ping_buff_[kPingBuffSize];
449   struct ibv_mr* mr_;
450   struct ibv_sge ping_sge_list_;
451   int PingPostRecv();
452   int PingPostSend();
453 
454  protected:
455   const RdmaAdapter* adapter_;
456   RdmaAddress self_;
457   string local_name_;
458   string remote_name_;
459   ibv_qp* qp_;
460   mutex mu_;
461   bool connected_ GUARDED_BY(mu_) = false;
462   RdmaAddress remote_ GUARDED_BY(mu_);
463   bool remote_set_ GUARDED_BY(mu_) = false;
464   mutex ct_mu_;
465   typedef std::unordered_map<uint32_t, RdmaTensorRequest> RequestTable;
466   RequestTable request_table_ GUARDED_BY(ct_mu_);
467   uint32_t request_serial_ GUARDED_BY(ct_mu_);
468   mutex responses_mu_;
469   typedef std::unordered_map<uint32_t, RdmaTensorResponse> ResponsesTable;
470   ResponsesTable responses_table_ GUARDED_BY(responses_mu_);
471   RdmaMessageBuffer* tx_message_buffer_;
472   RdmaMessageBuffer* rx_message_buffer_;
473   std::vector<RdmaMessageBuffer*> message_buffers_;
474 };
475 
476 // Class that represents a buffer for Rdma message sending.
477 class RdmaMessageBuffer {
478   friend class RdmaChannel;
479   friend class RdmaAdapter;
480   friend class RdmaMgr;
481   friend class RdmaRemoteRendezvous;
482 
483  public:
484   explicit RdmaMessageBuffer(RdmaChannel* channel, string name);
485   ~RdmaMessageBuffer();
486 
buffer()487   inline void* buffer() const { return buffer_; }
self()488   inline ibv_mr* self() const { return self_; }
SetBufferStatus(Location loc,BufferStatus status)489   inline void SetBufferStatus(Location loc, BufferStatus status) {
490     mu_.lock();
491     if (loc == local) {
492       local_status_ = status;
493     } else {
494       remote_status_ = status;
495     }
496     mu_.unlock();
497   }
498   void FreeBuffer();
499   void EnqueueItem(string Item);
500   void SendNextItem();
501   void CreateCPUBuffer(size_t size, bool lock = true);
502   void SetRemoteMR(RemoteMR rmi, bool override);
503   void Write(uint32_t imm_data, size_t buffer_size);
504   static void Write(const RdmaChannel* channel, uint32_t imm_data,
505                     size_t buffer_size, uint64_t src_addr, uint32_t lkey,
506                     uint64_t remote_addr, uint32_t rkey,
507                     RdmaWriteIDType write_type, void* write_context);
508   static void SendAck(const RdmaChannel* channel);
509 
510  protected:
511   const RdmaChannel* channel_;
512   void* buffer_ = nullptr;
513   bool buffer_on_host_ = true;
514   size_t size_ = 0;
515   const string name_;
516   ibv_mr* self_ = nullptr;
517   mutex mu_;
518   RemoteMR remote_;
519   std::queue<string> queue_ GUARDED_BY(mu_);
520   BufferStatus local_status_ GUARDED_BY(mu_) = none;
521   BufferStatus remote_status_ GUARDED_BY(mu_) = none;
522 };
523 
524 }  // namespace tensorflow
525 
526 #endif  // TENSORFLOW_USE_VERBS
527 #endif  // TENSORFLOW_CONTRIB_VERBS_RDMA_H_
528