1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifdef TENSORFLOW_USE_VERBS
17 
18 #include <fcntl.h>
19 #include <cstdlib>
20 
21 #include "tensorflow/contrib/verbs/rdma.h"
22 #include "tensorflow/contrib/verbs/verbs_service.pb.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/dma_helper.h"
25 #include "tensorflow/core/common_runtime/process_util.h"
26 #if GOOGLE_CUDA
27 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
28 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
29 #endif
30 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
32 #include "tensorflow/core/distributed_runtime/session_mgr.h"
33 #include "tensorflow/core/framework/rendezvous.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/lib/core/status.h"
36 #include "tensorflow/core/lib/core/stringpiece.h"
37 #include "tensorflow/core/lib/core/threadpool.h"
38 #include "tensorflow/core/lib/hash/hash.h"
39 #include "tensorflow/core/lib/random/random.h"
40 
41 namespace tensorflow {
42 
43 #define RoCE_V2 "RoCE v2"
44 
45 namespace {
46 
47 // Convenience function for converting a message type to a printable string.
48 string MessageTypeToString(RdmaMessageType rmt) {
49   switch (rmt) {
50     case RDMA_MESSAGE_META_DATA_UPDATE:
51       return "RDMA_MESSAGE_META_DATA_UPDATE";
52       break;
53     case RDMA_MESSAGE_TENSOR_RE_REQUEST:
54       return "RDMA_MESSAGE_TENSOR_RE_REQUEST";
55       break;
56     case RDMA_MESSAGE_TENSOR_REQUEST:
57       return "RDMA_MESSAGE_TENSOR_REQUEST";
58       break;
59     default:
60       return "UNKNOWN MESSAGE";
61   }
62 }
63 }  // namespace
64 
65 // Function to get environment variable
66 // Args:
67 //    var_name - the name of the environment variable
68 // Returns:
69 //    string with its value, or an empty string if not set
70 string get_env_var(char const* var_name) {
71   char const* var_temp = getenv(var_name);
72 
73   return (var_temp == NULL) ? string() : string(var_temp);
74 }
75 
76 // Function to open device
77 // Args:
78 //   ibv_dev device to open
79 // Returns:
80 //   context of the opened device
81 ibv_context* open_device(ibv_device* ibv_dev) {
82   ibv_context* context = ibv_open_device(ibv_dev);
83 
84   CHECK(context) << "Open context failed for " << ibv_get_device_name(ibv_dev);
85   return context;
86 }
87 
88 // Function to count the number of active ports for device
89 // Args:
90 //   device - to check active ports
91 // Returns:
92 //   number of active ports of the given device
93 int get_dev_active_port_count(ibv_device* device) {
94   ibv_device_attr device_att;
95   ibv_port_attr port_attr;
96   ibv_context* context = NULL;
97   int rc, port_index, active_ports = 0;
98 
99   context = ibv_open_device(device);
100   CHECK(context) << "Open context failed for " << ibv_get_device_name(device);
101   rc = ibv_query_device(context, &device_att);
102   CHECK(!rc) << "Failed to query the device";
103 
104   for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
105     rc = ibv_query_port(context, port_index, &port_attr);
106     CHECK(!rc) << "Failed to query the port " << port_index;
107     if (port_attr.state == IBV_PORT_ACTIVE) {
108       active_ports++;
109     }
110   }
111   ibv_close_device(context);
112   return active_ports;
113 }
114 
115 // Function to set the device. If RDMA_DEVICE is not set, search for a device
116 // with an active port.
117 // Fails if more than one device with an active port is found.
118 // Returns:
119 //   device to use
120 ibv_device* set_device() {
121   ibv_device** dev_list;
122   int dev_num, device_index, device_to_open = 0;
123   int num_devs_with_active_port = 0;
124   string env_p_rdma_device, str_port_num;
125 
126   dev_list = ibv_get_device_list(&dev_num);
127   CHECK(dev_list) << "No InfiniBand device found";
128 
129   env_p_rdma_device = get_env_var("RDMA_DEVICE");
130   if (!env_p_rdma_device.empty()) {
131     for (device_index = 0; device_index < dev_num; device_index++) {
132       if (!env_p_rdma_device.compare(
133               ibv_get_device_name(dev_list[device_index]))) {
134         CHECK(get_dev_active_port_count(dev_list[device_index]) != 0)
135             << "Device " << ibv_get_device_name(dev_list[device_index])
136             << " has no active ports";
137         return dev_list[device_index];
138       }
139     }
140     // check validity of input device
141     CHECK(false) << "The device " << env_p_rdma_device << " wasn't found";
142   } else {
143     // set default device
144     str_port_num = get_env_var("RDMA_DEVICE_PORT");
145     CHECK(str_port_num.empty())
146         << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user";
147     for (device_index = 0; device_index < dev_num; device_index++) {
148       // get port_num
149       if (get_dev_active_port_count(dev_list[device_index]) > 0) {
150         num_devs_with_active_port++;
151         CHECK(num_devs_with_active_port <= 1) << "More than one device with "
152                                                  "active port in the system. "
153                                                  "Please enter RDMA_DEVICE";
154         // found device with at least 1 active port
155         device_to_open = device_index;
156       }
157     }
158     CHECK(num_devs_with_active_port > 0)
159         << "There is no active port in the system";
160     return dev_list[device_to_open];
161   }
162   CHECK(false) << "No device was set!";
163   return NULL;  // never happens
164 }
165 
166 // Function to set port for device.
167 // If RDMA_DEVICE_PORT not set, first active port of the device will be set.
168 // Args:
169 //   context of the device
170 // Returns:
171 //   port to use
172 uint8_t set_port(ibv_context* context) {
173   uint8_t port_num = 0;  // 0 is illegal port number
174   string str_port_num;
175   ibv_device_attr device_att;
176   ibv_port_attr port_attr;
177   int rc, port_index;
178 
179   rc = ibv_query_device(context, &device_att);
180   CHECK(!rc) << "Failed to query the device\n";
181 
182   str_port_num = get_env_var("RDMA_DEVICE_PORT");
183   // user defined port
184   if (!str_port_num.empty()) {
185     port_num = stoi(str_port_num);
186     CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive";
187     CHECK(port_num <= device_att.phys_port_cnt) << "RDMA_DEVICE_PORT should be "
188                                                    "less than or equal to the "
189                                                    "number of available ports";
190     rc = ibv_query_port(context, port_num, &port_attr);
191     CHECK(!rc) << "Failed to query the port " << port_num;
192     // check if the port is active
193     CHECK(port_attr.state == IBV_PORT_ACTIVE)
194         << "Selected RDMA_DEVICE_PORT is not active";
195   } else {  // set default port
196     for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
197       rc = ibv_query_port(context, port_index, &port_attr);
198       CHECK(!rc) << "Failed to query the port " << port_index;
199       if (port_attr.state == IBV_PORT_ACTIVE) {
200         port_num = port_index;
201         break;
202       }
203     }
204     CHECK_GT(port_num, 0) << "No active ports";
205   }
206   return port_num;
207 }
208 
209 // Function to read from a sysfs file
210 // Args:
211 //   dir - directory
212 //   file - file
213 //   buff - buffer for the result
214 //   size - buffer size
215 // Returns:
216 //   number of bytes read, or -1 on failure
217 int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) {
218   char* path;
219   int fd;
220   int len;
221 
222   if (asprintf(&path, "%s/%s", dir, file) < 0) return -1;
223 
224   fd = open(path, O_RDONLY);
225   if (fd < 0) {
226     free(path);
227     return -1;
228   }
229 
230   len = read(fd, buf, size);
231 
232   close(fd);
233   free(path);
234 
235   if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0';
236 
237   return len;
238 }
239 
240 // Function to check if a GID index supports RoCE V2
241 // Args:
242 //   context - device context
243 //   port_num - port number
244 //   index -  GID index
245 // Returns:
246 //   true if the GID supports RoCE V2, otherwise false.
247 bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
248                          uint8_t index) {
249   char name[32];
250   char buff[41];
251 
252   snprintf(name, sizeof(name), "ports/%d/gid_attrs/types/%d", port_num, index);
253   if (read_sysfs_file(context->device->ibdev_path, name, buff, sizeof(buff)) <=
254       0) {
255     return false;
256   }
257   return !strcmp(buff, RoCE_V2);
258 }
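// For example (sysfs layout per the kernel's RDMA convention; the device path
// comes from context->device->ibdev_path): for port 1 and GID index 3 this
// reads /sys/class/infiniband/<dev>/ports/1/gid_attrs/types/3 and compares the
// contents against "RoCE v2".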
259 
260 // Function to set GID index.
261 // If the port link layer is InfiniBand, no GID index needs to be selected.
262 // If it is Ethernet and RDMA_GID_INDEX is not set, a GID index that supports
263 //   RoCE V2 will be chosen (fails if more than one IP is configured).
264 // Args:
265 //   context - device context
266 //   port_num - port number
267 // Returns:
268 //   GID index to use
269 uint8_t set_gid(uint8_t port_num, ibv_context* context) {
270   ibv_port_attr port_attr;
271   string gid_str;
272   int rc, i, gids_num = 0, v2_ip_num = 0;
273   union ibv_gid gid;
274   uint8_t gid_index = 0;
275 
276   rc = ibv_query_port(context, port_num, &port_attr);
277   CHECK(!rc) << "Failed to query the port " << port_num;
278 
279   for (i = 0; i < port_attr.gid_tbl_len; i++) {
280     rc = ibv_query_gid(context, port_num, i, &gid);
281     CHECK(!rc) << "Failed to query gid to port " << (int)port_num << " index "
282                << i;
283     if (gid.global.interface_id) {
284       gids_num++;
285       if (gid.global.subnet_prefix == 0 &&
286           is_gid_type_roce_v2(context, port_num, i)) {
287         if (v2_ip_num == 0) {
288           // can be overwritten by RDMA_GID_INDEX later
289           gid_index = i;
290         }
291         v2_ip_num++;
292       }
293     }
294   }
295   switch (port_attr.link_layer) {
296     case (IBV_LINK_LAYER_ETHERNET):
297       gid_str = get_env_var("RDMA_GID_INDEX");
298       if (!gid_str.empty()) {
299         gid_index = stoi(gid_str);
300         CHECK(gid_index < gids_num)
301             << "RDMA_GID_INDEX should be less than the number of GIDs: " << gids_num;
302       } else {
303         CHECK(v2_ip_num <= 1)
304             << "More than one IP is available, please specify GID_INDEX";
305       }
306       break;
307     case (IBV_LINK_LAYER_INFINIBAND):  // no need in GID index
308       break;
309     default:
310       LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and "
311                    "InfiniBand only. ";
312   }
313   if (!is_gid_type_roce_v2(context, port_num, gid_index)) {
314     LOG(INFO) << "RoCE v2 is not configured for GID_INDEX " << (int)gid_index;
315   }
316   return gid_index;
317 }
318 
319 // Set a configuration parameter from its environment variable, or use the default.
320 // Args:
321 //   default_val- the default value for this parameter
322 //   env_param- the environment parameter's name
323 // Returns:
324 //   32-bit value
325 uint32_t set_param(uint32_t default_val, const char* env_param) {
326   uint32_t val = default_val;
327   string val_s;
328 
329   val_s = get_env_var(env_param);
330 
331   if (!val_s.empty()) {
332     val = stoi(val_s);
333   }
334   return val;
335 }
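// Usage sketch (hypothetical values): with RDMA_QUEUE_DEPTH unset, set_param()
// returns the compiled-in default; with RDMA_QUEUE_DEPTH=2048 exported in the
// worker's environment, the override wins:
//   uint32_t depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH");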
336 
337 enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
338   ibv_port_attr port_attr;
339   enum ibv_mtu mtu = IBV_MTU_512;
340   string mtu_s;
341   int rc, mtu_i;
342 
343   rc = ibv_query_port(context, port_num, &port_attr);
344   CHECK(!rc) << "Failed to query the port" << port_num;
345 
346   mtu_s = get_env_var("RDMA_MTU");
347 
348   if (!mtu_s.empty()) {
349     mtu_i = stoi(mtu_s);
350     switch (mtu_i) {
351       case 256:
352         mtu = IBV_MTU_256;
353         break;
354       case 512:
355         mtu = IBV_MTU_512;
356         break;
357       case 1024:
358         mtu = IBV_MTU_1024;
359         break;
360       case 2048:
361         mtu = IBV_MTU_2048;
362         break;
363       case 4096:
364         mtu = IBV_MTU_4096;
365         break;
366       default:
367         CHECK(0) << "Error: MTU input value must be one of the following: 256, "
368                     "512, 1024, 2048, 4096. MTU "
369                  << mtu_i << " is invalid\n";
370         break;
371     }
372     CHECK(mtu < port_attr.active_mtu)
373         << "MTU configuration for the QPs is larger than active MTU";
374   } else {
375     mtu = port_attr.active_mtu;
376   }
377   return mtu;
378 }
379 
380 RdmaParams params_init(ibv_context* context) {
381   RdmaParams params;
382 
383   params.port_num = set_port(context);
384   params.sgid_index = set_gid(params.port_num, context);
385   params.pkey_index = (uint8_t)set_param(PKEY_DEFAULT, "RDMA_PKEY");
386   params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH");
387   params.timeout = (uint8_t)set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT");
388   params.retry_cnt = (uint8_t)set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT");
389   params.sl = (uint8_t)set_param(SL_DEFAULT, "RDMA_SL");
390   CHECK(params.sl <= 7) << "SL value is " << (int)params.sl
391                         << ". Valid values are 0-7.";
392   params.mtu = set_mtu(params.port_num, context);
393   params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS");
394   return params;
395 }
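// Deployment sketch (assumed launch flow; values are examples only): the verbs
// transport is tuned entirely through the environment variables read by the
// helpers above, e.g.
//   RDMA_DEVICE=mlx5_0 RDMA_DEVICE_PORT=1 RDMA_GID_INDEX=3 \
//   RDMA_MTU=4096 RDMA_TRAFFIC_CLASS=96 python train.py
// Any variable left unset falls back to the default passed to set_param().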
396 
397 ibv_pd* alloc_protection_domain(ibv_context* context) {
398   ibv_pd* pd = ibv_alloc_pd(context);
399   CHECK(pd) << "Failed to allocate protection domain";
400   return pd;
401 }
402 
403 RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
404     : context_(open_device(set_device())),
405       params_(params_init(context_)),
406       pd_(alloc_protection_domain(context_)),
407       worker_env_(worker_env) {
408   event_channel_ = ibv_create_comp_channel(context_);
409   CHECK(event_channel_) << "Failed to create completion channel";
410   cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_,
411                       0);
412   CHECK(cq_) << "Failed to create completion queue";
413   CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
414 }
415 
416 RdmaAdapter::~RdmaAdapter() {
417   polling_thread_.reset();
418   CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
419   CHECK(!ibv_destroy_comp_channel(event_channel_))
420       << "Failed to destroy channel";
421   CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD";
422   CHECK(!ibv_close_device(context_)) << "Failed to release context";
423 }
424 
425 void RdmaAdapter::StartPolling() {
426   polling_thread_.reset(Env::Default()->StartThread(
427       ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
428   VLOG(2) << "Start RdmaAdapter: " << name();
429 }
430 
431 string RdmaAdapter::name() const { return string(context_->device->name); }
432 
433 // Function to process incoming messages
434 // There are two types of messages:
435 // 1. IBV_WC_RECV_RDMA_WITH_IMM (receive)
436 // 2. IBV_WC_RDMA_WRITE (send)
437 void RdmaAdapter::Process_CQ() {
438   while (true) {
439     ibv_cq* cq;
440     void* cq_context;
441     CHECK(!ibv_get_cq_event(event_channel_, &cq, &cq_context));
442     CHECK(cq == cq_);
443     ibv_ack_cq_events(cq, 1);
444     CHECK(!ibv_req_notify_cq(cq_, 0));
445 
446     int ne =
447         ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
448     CHECK_GE(ne, 0);
449     for (int i = 0; i < ne; ++i) {
450       CHECK(wc_[i].status == IBV_WC_SUCCESS)
451           << "Failed status \n"
452           << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
453           << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
454       if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
455         RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
456         // put back a recv wr.
457         rc->Recv();
458         // imm_data is the index of RX buffer in the buffer table.
459         uint32_t imm_data = wc_[i].imm_data;
460         RdmaMessageBuffer* rb;
461         RdmaMessage rm;
462 
463         if (imm_data == RDMA_IMM_DATA_ACK) {
464           // receive an ack to a message
465           rb = rc->tx_message_buffer_;
466           rb->SetBufferStatus(remote, idle);
467           rb->SendNextItem();
468           continue;
469         }
470 
471         if (imm_data <= RDMA_IMM_MAX_REQUEST_ID) {
472           // receive a tensor RDMA write
473           uint32_t request_index = imm_data;
474           RdmaTensorRequest* request = rc->GetTensorRequest(request_index);
475           request->RecvTensorContent();
476           continue;
477         }
478 
479         // receive a control message
480         rb = rc->rx_message_buffer_;
481         RdmaMessage::ParseMessage(rm, rb->buffer_);
482         RdmaMessageBuffer::SendAck(rc);
483         RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
484                     << ": Received " << MessageTypeToString(rm.type_) << " "
485                     << "#" << rm.request_index_ << ": " << rm.name_;
486 
487         if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
488           RdmaTensorResponse* response = rc->AddTensorResponse(rm);
489           response->Start();
490         } else if (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) {
491           RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
492           request->RecvTensorMetaData(rm.data_type_, rm.tensor_shape_,
493                                       rm.is_dead_, rm.tensor_bytes_);
494 #ifdef RDMA_DATA_VALIDATION
495           request->RecvTensorChecksum(rm.checksum_);
496 #endif
497         } else if (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST) {
498           RdmaTensorResponse* response = rc->UpdateTensorResponse(rm);
499           response->Resume();
500         } else if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
501           RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
502           request->RecvErrorStatus(rm.status_);
503         }
504       } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
505         RdmaWriteID* wr_id = reinterpret_cast<RdmaWriteID*>(wc_[i].wr_id);
506         RDMA_LOG(2) << "Write complete of type " << wr_id->write_type;
507         switch (wr_id->write_type) {
508           case RDMA_WRITE_ID_ACK:
509             break;
510           case RDMA_WRITE_ID_MESSAGE: {
511             RdmaMessageBuffer* rb =
512                 reinterpret_cast<RdmaMessageBuffer*>(wr_id->write_context);
513             rb->SetBufferStatus(local, idle);
514             rb->SendNextItem();
515             break;
516           }
517           case RDMA_WRITE_ID_TENSOR_WRITE: {
518             RdmaTensorResponse* response =
519                 reinterpret_cast<RdmaTensorResponse*>(wr_id->write_context);
520             response->Destroy();
521           }
522         }
523         delete wr_id;
524       }
525     }
526   }
527 }
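// imm_data dispatch summary (as implemented above): RDMA_IMM_DATA_ACK marks an
// ack that frees the tx message buffer, values up to RDMA_IMM_MAX_REQUEST_ID
// identify the tensor request whose content was just written, and any other
// value is a control message parsed out of the rx message buffer.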
528 
529 int RdmaChannel::PingPostRecv() {
530   struct ibv_recv_wr wr, *bad_wr;
531   memset(&wr, 0, sizeof(wr));
532   wr.sg_list = &ping_sge_list_;
533   wr.num_sge = 1;
534   wr.wr_id = kPingRecvWrid;
535 
536   return ibv_post_recv(qp_, &wr, &bad_wr);
537 }
538 
539 int RdmaChannel::PingPostSend() {
540   struct ibv_send_wr wr, *bad_wr;
541   memset(&wr, 0, sizeof(wr));
542   wr.wr_id = (uint64_t)this;
543   wr.sg_list = &ping_sge_list_;
544   wr.num_sge = 1;
545   wr.opcode = IBV_WR_SEND;
546   wr.send_flags = IBV_SEND_SIGNALED;
547 
548   return ibv_post_send(qp_, &wr, &bad_wr);
549 }
550 
551 RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
552                          const string remote_name)
553     : adapter_(adapter),
554       local_name_(local_name),
555       remote_name_(remote_name),
556       request_serial_(0) {
557   struct ibv_sge list;
558 
559   mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize,
560                    IBV_ACCESS_LOCAL_WRITE);
561   CHECK(mr_) << "Failed to register memory region";
562 
563   memset(&list, 0, sizeof(list));
564   list.addr = (uintptr_t)ping_buff_;
565   list.length = kPingBuffSize;
566   list.lkey = mr_->lkey;
567 
568   ping_sge_list_ = list;
569   // Create queue pair
570   {
571     struct ibv_qp_init_attr attr;
572     memset(&attr, 0, sizeof(ibv_qp_init_attr));
573     attr.send_cq = adapter_->cq_;
574     attr.recv_cq = adapter_->cq_;
575     attr.cap.max_send_wr = adapter_->params_.queue_depth;
576     attr.cap.max_recv_wr = adapter_->params_.queue_depth;
577     attr.cap.max_send_sge = 1;
578     attr.cap.max_recv_sge = 1;
579     attr.qp_type = IBV_QPT_RC;
580 
581     qp_ = ibv_create_qp(adapter_->pd_, &attr);
582     CHECK(qp_) << "Failed to create queue pair";
583   }
584 
585   // Init queue pair
586   {
587     struct ibv_qp_attr attr;
588     memset(&attr, 0, sizeof(ibv_qp_attr));
589     attr.qp_state = IBV_QPS_INIT;
590     attr.pkey_index = adapter_->params_.pkey_index;
591     attr.port_num = adapter_->params_.port_num;
592     attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
593 
594     int mask =
595         IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
596     CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT";
597   }
598 
599   // Local address
600   {
601     struct ibv_port_attr attr;
602     CHECK(
603         !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr))
604         << "Query port";
605     self_.lid = attr.lid;
606     self_.qpn = qp_->qp_num;
607     self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
608     union ibv_gid gid;
609     CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num,
610                          adapter_->params_.sgid_index, &gid))
611         << "Query gid";
612     self_.snp = gid.global.subnet_prefix;
613     self_.iid = gid.global.interface_id;
614   }
615 
616   // create message and ack buffers, then initialize the tables.
617   {
618     const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer"};
619     tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
620     rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
621     message_buffers_.reserve(kNumMessageBuffers);
622     message_buffers_.push_back(tx_message_buffer_);
623     message_buffers_.push_back(rx_message_buffer_);
624     // create buffer on host
625     tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
626     rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
627   }
628   CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_
629                              << " with error " << std::strerror(errno);
630 }
631 
632 RdmaChannel::~RdmaChannel() {
633   ibv_dereg_mr(mr_);
634   CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
635   delete tx_message_buffer_;
636   delete rx_message_buffer_;
637 }
638 
639 void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
640   mutex_lock lock{mu_};
641   if ((override) || (!remote_set_)) {
642     remote_.lid = ra.lid;
643     remote_.qpn = ra.qpn;
644     remote_.psn = ra.psn;
645     remote_.snp = ra.snp;
646     remote_.iid = ra.iid;
647     remote_set_ = true;
648   } else {
649     CHECK(remote_.lid == ra.lid);
650     CHECK(remote_.qpn == ra.qpn);
651     CHECK(remote_.psn == ra.psn);
652     CHECK(remote_.snp == ra.snp);
653     CHECK(remote_.iid == ra.iid);
654   }
655 }
656 
657 // Post a receive work request ("token") on the queue pair.
658 // Such tokens are needed to accept future incoming messages.
659 void RdmaChannel::Recv() {
660   struct ibv_recv_wr wr;
661   memset(&wr, 0, sizeof(wr));
662   wr.wr_id = (uint64_t)this;
663   struct ibv_recv_wr* bad_wr;
664   CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
665 }
666 
667 RdmaTensorRequest* RdmaChannel::InsertTensorRequest(
668     const string& key, int64 step_id, Device* dst_dev,
669     const Rendezvous::Args recv_args,
670     const RdmaTensorRequest::RecvDoneCallback& done) {
671   mutex_lock lock{ct_mu_};
672   uint32_t request_index = request_serial_++;
673   if (request_serial_ > RDMA_IMM_MAX_REQUEST_ID) {
674     request_serial_ = 0;
675   }
676   RdmaTensorRequest request(request_index, key, step_id, this, dst_dev,
677                             recv_args, done);
678   auto it = request_table_.emplace(request_index, request);
679   return &it.first->second;
680 }
681 
682 void RdmaChannel::RemoveTensorRequest(uint32_t request_index) {
683   mutex_lock lock{ct_mu_};
684   request_table_.erase(request_index);
685 }
686 
687 RdmaTensorRequest* RdmaChannel::GetTensorRequest(uint32_t request_index) {
688   mutex_lock lock{ct_mu_};
689   RequestTable::iterator iter = request_table_.find(request_index);
690   CHECK(iter != request_table_.end());
691   return &iter->second;
692 }
693 
694 void RdmaChannel::Connect() {
695   {
696     mutex_lock lock{mu_};
697     CHECK(remote_set_) << "remote channel is not set";
698   }
699   Connect(remote_);
700 }
701 
702 // Set up the channel to a remote node.
703 // Args:
704 //   remoteAddr: the rdma address of a remote channel.
705 // Returns:
706 //   None
707 void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
708   mutex_lock lock{mu_};
709   if (!connected_) {
710     struct ibv_qp_attr attr;
711     memset(&attr, 0, sizeof(ibv_qp_attr));
712     attr.qp_state = IBV_QPS_RTR;
713 
714     // This assumes both QP's ports are configured with the same MTU
715     attr.path_mtu = adapter_->params_.mtu;
716     attr.dest_qp_num = remoteAddr.qpn;
717     attr.rq_psn = remoteAddr.psn;
718     attr.max_dest_rd_atomic = 1;
719     attr.min_rnr_timer = 12;
720     attr.ah_attr.is_global = 1;
721     attr.ah_attr.grh.dgid.global.subnet_prefix = remoteAddr.snp;
722     attr.ah_attr.grh.dgid.global.interface_id = remoteAddr.iid;
723     attr.ah_attr.grh.flow_label = 0;
724     attr.ah_attr.grh.hop_limit = 255;
725     attr.ah_attr.dlid = remoteAddr.lid;
726     attr.ah_attr.sl = adapter_->params_.sl;
727     attr.ah_attr.src_path_bits = 0;
728     attr.ah_attr.port_num = adapter_->params_.port_num;
729     attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index;
730     attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;
731 
732     int r;
733     CHECK(!(r = ibv_modify_qp(qp_, &attr,
734                               IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
735                                   IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
736                                   IBV_QP_MAX_DEST_RD_ATOMIC |
737                                   IBV_QP_MIN_RNR_TIMER)))
738         << "QP to Ready to Receive " << r;
739 
740     memset(&attr, 0, sizeof(ibv_qp_attr));
741     attr.qp_state = IBV_QPS_RTS;
742     attr.sq_psn = self_.psn;
743     attr.timeout = adapter_->params_.timeout;
744     attr.retry_cnt = adapter_->params_.retry_cnt;
745     attr.rnr_retry = 7; /* infinite */
746     attr.max_rd_atomic = 1;
747 
748     CHECK(!(r = ibv_modify_qp(qp_, &attr,
749                               IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
750                                   IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
751                                   IBV_QP_MAX_QP_RD_ATOMIC)))
752         << "QP to Ready to Send " << r;
753 
754     connected_ = true;
755   } else {
756     RDMA_LOG(2) << "channel already connected";
757   }
758 }
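// The two ibv_modify_qp() calls above walk the standard verbs state machine
// for a reliable-connected QP: INIT -> RTR (ready to receive, which needs the
// remote address vector, QPN and PSN) -> RTS (ready to send, which needs the
// local PSN and the timeout/retry policy).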
759 
760 RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
761     : channel_(channel), name_(name) {}
762 
763 RdmaMessageBuffer::~RdmaMessageBuffer() {
764   CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
765   FreeBuffer();
766 }
767 
768 void RdmaMessageBuffer::FreeBuffer() {
769   if ((buffer_ != nullptr) && buffer_on_host_) {
770     free(buffer_);
771   }
772 }
773 
774 // Allocate CPU memory for the Rdma buffer
775 // Args:
776 //   size: to-be-allocated memory size
777 //   lock: whether to take the mutex to protect against concurrent access
778 // Returns:
779 //   None
780 void RdmaMessageBuffer::CreateCPUBuffer(size_t size, bool lock) {
781   CHECK(size > 0);
782   if (lock) {
783     mu_.lock();
784   }
785   if (local_status_ != none) {
786     // delete existing buffer
787     CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
788     FreeBuffer();
789   }
790   size_ = size;
791   buffer_ = malloc(size_);
792   self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_,
793                      IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
794   CHECK(self_) << "Failed to register memory region";
795   buffer_on_host_ = true;
796   local_status_ = idle;
797   if (lock) {
798     mu_.unlock();
799   }
800 }
801 
802 // Set address of remote memory region
803 // Args:
804 //   rmr: address of remote memory region
805 //   override: whether override existing information
806 // Returns:
807 //   None
808 void RdmaMessageBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
809   mutex_lock lock{mu_};
810   if ((override) || (remote_status_ == none)) {
811     remote_.remote_addr = rmr.remote_addr;
812     remote_.rkey = rmr.rkey;
813     remote_status_ = idle;
814   } else {
815     CHECK(remote_.remote_addr == rmr.remote_addr);
816     CHECK(remote_.rkey == rmr.rkey);
817   }
818 }
819 
820 // Put a task in the buffer's job queue
821 void RdmaMessageBuffer::EnqueueItem(string item) {
822   mutex_lock lock{mu_};
823   queue_.push(item);
824 }
825 
826 // Rdma-Write the content of the buffer
827 void RdmaMessageBuffer::Write(uint32_t imm_data, size_t buffer_size) {
828   Write(channel_, imm_data, buffer_size, (uint64_t)buffer_, self_->lkey,
829         remote_.remote_addr, remote_.rkey, RDMA_WRITE_ID_MESSAGE, this);
830 }
831 
832 // Generalized Write method
833 void RdmaMessageBuffer::Write(const RdmaChannel* channel, uint32_t imm_data,
834                               size_t buffer_size, uint64_t src_addr,
835                               uint32_t lkey, uint64_t remote_addr,
836                               uint32_t rkey, RdmaWriteIDType write_type,
837                               void* write_context) {
838   struct ibv_sge list;
839   list.addr = src_addr;
840   list.length = buffer_size;
841   list.lkey = lkey;
842 
843   struct ibv_send_wr wr;
844   memset(&wr, 0, sizeof(wr));
845   wr.wr_id = (uint64_t) new RdmaWriteID(write_type, write_context);
846   wr.sg_list = &list;
847   wr.num_sge = 1;
848   wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
849   wr.send_flags = IBV_SEND_SIGNALED;
850   wr.imm_data = imm_data;
851   wr.wr.rdma.remote_addr = remote_addr;
852   wr.wr.rdma.rkey = rkey;
853 
854   struct ibv_send_wr* bad_wr;
855   CHECK(!ibv_post_send(channel->qp_, &wr, &bad_wr)) << "Failed to post send";
856 }
857 
858 // Send an ack to the remote side (a zero-length write carrying the ACK immediate).
859 void RdmaMessageBuffer::SendAck(const RdmaChannel* channel) {
860   Write(channel, RDMA_IMM_DATA_ACK, 0, 0, 0, 0, 0, RDMA_WRITE_ID_ACK, nullptr);
861 }
862 
863 // Send the next message from the buffer's job queue.
864 void RdmaMessageBuffer::SendNextItem() {
865   uint32_t imm_data = RDMA_IMM_DATA_MESSAGE;
866   mu_.lock();
867   if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
868     local_status_ = busy;
869     remote_status_ = busy;
870     string message = queue_.front();
871     queue_.pop();
872     // local/remote_status_ won't be set back to idle
873     // until Write() is successful
874     mu_.unlock();
875     memcpy(buffer_, message.data(), message.size());
876     Write(imm_data, message.size());
877   } else {
878     mu_.unlock();
879   }
880 }
881 
882 #if GOOGLE_CUDA
883 static void CountCopies(const std::string& key, void* src_addr, void* dst_addr,
884                         size_t tensor_bytes, bool is_gpu_to_cpu) {
885 #ifdef RDMA_COUNT_COPIES
886   static uint64_t numGPUToCPUCopies = 0;
887   static uint64_t numGPUToCPUCopiedBytes = 0;
888   static uint64_t numCPUToGPUCopies = 0;
889   static uint64_t numCPUToGPUCopiedBytes = 0;
890   static uint64_t numTotalCopies = 0;
891 
892   if (is_gpu_to_cpu) {
893     ++numGPUToCPUCopies;
894     numGPUToCPUCopiedBytes += tensor_bytes;
895   } else {
896     ++numCPUToGPUCopies;
897     numCPUToGPUCopiedBytes += tensor_bytes;
898   }
899   if ((++numTotalCopies % 0x400) == 0) {
900     RDMA_LOG(0) << "Tensor copies:"
901                 << " GPU to CPU: " << numGPUToCPUCopies << " ("
902                 << numGPUToCPUCopiedBytes << " Bytes)"
903                 << " CPU to GPU: " << numCPUToGPUCopies << " ("
904                 << numCPUToGPUCopiedBytes << " Bytes)";
905   }
906   RDMA_LOG(2) << "Copying tensor " << key << " From: " << src_addr
907               << " To: " << dst_addr;
908 #endif  // RDMA_COUNT_COPIES
909 }
910 #endif  // GOOGLE_CUDA
911 
912 #ifdef RDMA_DATA_VALIDATION
913 static uint64_t Checksum(Device* device, const DeviceContext* device_context,
914                          const Tensor& in) {
915   uint64 checksum = 0;
916   if (DataTypeCanUseMemcpy(in.dtype())) {
917 #if GOOGLE_CUDA
918     if (in.TotalBytes() == 0) {
919       return 0;
920     }
921     checksum = (device_context != nullptr)
922                    ? GPUUtil::Checksum(device, device_context, in)
923                    : GPUUtil::Checksum(in);
924 #endif  // GOOGLE_CUDA
925   } else {
926     string s = in.SummarizeValue(999999);
927     checksum = Hash64(s.c_str(), s.size(), 0);
928   }
929   return checksum;
930 }
931 
932 static void ValidateChecksum(uint64_t expected, uint64_t actual,
933                              const Tensor& in, uint32_t request_index,
934                              const std::string& key, const std::string& msg) {
935   RDMA_LOG(2) << "Request #" << request_index << ": " << key
936               << ": Checksum: " << std::hex << " Expected = 0x" << expected
937               << ". Actual = 0x" << actual << ".";
938 
939   if (expected != actual) {
940     // Checksum failed. There is one case where this is allowed - if the
941     // tensor is an AssignAdd of the global step. Since the data-validation
942     // always postpones the Tensor response in order to send a checksum message,
943     // it is possible that the global-step was updated while the response was
944     // still in queue.
945     if ((in.TotalBytes() == 8) && (in.dtype() == DT_INT64)) {
946       int64_t prev_val = *(int64_t*)DMAHelper::base(&in) - 1;
947       actual = Hash64((const char*)&prev_val, 8, 0);
948     }
949     if (expected != actual) {
950       LOG(FATAL) << "[" << msg << "]: Checksum validation failed for request #"
951                  << request_index << ": " << key << std::hex << " "
952                  << DataTypeString(in.dtype()) << " "
953                  << in.shape().DebugString() << " (0x" << in.TotalBytes()
954                  << " bytes): "
955                  << " Expected 0x" << expected << ". Got 0x" << actual << ".";
956     }
957   }
958 }
959 #endif  // RDMA_DATA_VALIDATION
960 
961 #if GOOGLE_CUDA
962 // Sync the 'done' operation on the GPU stream, but without all the data
963 // copying.
964 static void StreamGPUOp(Device* gpu_device, const DeviceContext* device_context,
965                         StatusCallback done) {
966   Tensor dummy1, dummy2;
967   GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context, &dummy1, &dummy2,
968                               done);
969 }
970 #endif  // GOOGLE_CUDA
971 
972 RdmaTensorResponse* RdmaChannel::AddTensorResponse(const RdmaMessage& rm) {
973   mutex_lock lock{mu_};
974   auto it =
975       responses_table_.emplace(rm.request_index_, RdmaTensorResponse(this, rm));
976   CHECK(it.second) << "Response with the ID " << rm.request_index_
977                    << " already exists.";
978   return &it.first->second;
979 }
980 
981 RdmaTensorResponse* RdmaChannel::UpdateTensorResponse(const RdmaMessage& rm) {
982   mutex_lock lock{mu_};
983   auto it = responses_table_.find(rm.request_index_);
984   CHECK(it != responses_table_.end()) << "No response found.";
985   RdmaTensorResponse* response = &it->second;
986   response->Update(rm);
987   return response;
988 }
989 
990 void RdmaChannel::RemoveTensorResponse(uint32_t request_index) {
991   mutex_lock lock{mu_};
992   responses_table_.erase(request_index);
993 }
994 
995 void RdmaTensorResponse::Start() {
996   Rendezvous::ParsedKey parsed;
997   Status s = Rendezvous::ParseKey(rm_.name_, &parsed);
998   if (!s.ok()) {
999     SendErrorStatus(s);
1000     return;
1001   }
1002 
1003   channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(
1004       rm_.step_id_, parsed,
1005       [this, parsed](const Status& status, const Rendezvous::Args& send_args,
1006                      const Rendezvous::Args& recv_args, const Tensor& in,
1007                      bool is_dead) {
1008         CHECK(status.ok()) << "RecvLocalAsync was not ok."
1009                            << " error message: " << status.error_message();
1010         RecvHandler(parsed, send_args, recv_args, in, is_dead);
1011       });
1012 }
1013 
1014 void RdmaTensorResponse::Resume() { SendContent(*tensor_, *proto_, is_dead_); }
1015 
1016 // Helper for RecvTensor. Validates "key" and returns the source
1017 // device in "*src_dev".
1018 Status RdmaTensorResponse::PrepareRecvTensor(
1019     const Rendezvous::ParsedKey& parsed, Device** src_dev) {
1020   // Figures out which device the tensor is hosted on.
1021   string local_name = DeviceNameUtils::LocalName(parsed.src_device);
1022   TF_RETURN_IF_ERROR(channel_->adapter_->worker_env_->device_mgr->LookupDevice(
1023       local_name, src_dev));
1024 
1025   // Does the device have the right incarnation number we expect?
1026   if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
1027     return errors::Aborted(
1028         "RecvTensor expects a different device incarnation: ",
1029         parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
1030         ". Your worker job (\"",
1031         channel_->adapter_->worker_env_->session_mgr->LegacySession()
1032             ->worker_name,
1033         "\") was probably restarted. Check your "
1034         "worker job for the reason why it was restarted.");
1035   }
1036 
1037   return Status::OK();
1038 }
1039 
1040 void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed,
1041                                      const Rendezvous::Args& send_args,
1042                                      const Rendezvous::Args& recv_args,
1043                                      const Tensor& in, bool is_dead) {
1044   Status s = PrepareRecvTensor(parsed, &src_dev_);
1045   if (!s.ok()) {
1046     SendErrorStatus(s);
1047     return;
1048   }
1049 
1050   meta_data_changed_ = TensorMetaDataChanged(in, is_dead);
1051 #ifdef RDMA_DATA_VALIDATION
1052   // Always send a meta data message with the source checksum
1053   meta_data_changed_ = rm_.type_ == RDMA_MESSAGE_TENSOR_REQUEST;
1054   checksum_ = Checksum(src_dev_, send_args.device_context, in);
1055 #endif
1056   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
1057   // string tensor needs to be serialized
1058   Tensor copy;
1059   TensorProto proto;
1060   const bool on_host = send_args.alloc_attrs.on_host();
1061   if (src_dev_->tensorflow_gpu_device_info() && !on_host) {
1062 #if GOOGLE_CUDA
1063     DeviceContext* send_dev_context = send_args.device_context;
1064     CHECK(send_dev_context)
1065         << "send dev name: " << src_dev_->name()
1066         << " gpu_info: " << src_dev_->tensorflow_gpu_device_info();
1067 
1068     if (can_memcpy) {
1069       // If the tensor is located on a GDR compatible GPU, there is no need to
1070       // copy it. We can send directly from the source, just need to make sure
1071       // we are in sync with the GPU stream.
1072       // If the tensor's meta-data changed however, we will need to clone it,
1073       // so anyway we'll have to copy it from GPU to CPU first. If at some
1074       // point in time Clone() is changed to only save a shallow copy, we can
1075       // skip the copy here as well.
1076       if ((in.TotalBytes() > 0) && !meta_data_changed_ &&
1077           (RdmaMemoryMgr::Singleton().FindMemoryRegion(
1078                (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) {
1079         StreamGPUOp(src_dev_, send_dev_context,
1080                     [this, in, proto, is_dead](const Status& s) {
1081                       Send(in, proto, is_dead, s);
1082                     });
1083         return;
1084       }
1085 
1086       // The tensor must be copied from GPU to CPU, because either:
1087       // 1. The tensor is located on a non GDR compatible GPU.
1088       // 2. The tensor's meta-data has changed.
1089       Allocator* alloc = GPUProcessState::singleton()->GetGpuHostAllocator(0);
1090       copy = Tensor(alloc, in.dtype(), in.shape());
1091       CountCopies(rm_.name_, (void*)DMAHelper::base(&in),
1092                   (void*)DMAHelper::base(&copy), in.TotalBytes(), true);
1093       GPUUtil::CopyGPUTensorToCPU(
1094           src_dev_, send_dev_context, &in, &copy,
1095           [this, copy, proto, is_dead](const Status& s) {
1096             Send(copy, proto, is_dead, s);
1097           });
1098     } else {
1099       GPUUtil::SetProtoFromGPU(
1100           in, src_dev_, send_args.device_context, &proto, is_dead,
1101           [this, in, proto, is_dead](const Status& s) mutable {
1102             Send(in, proto, is_dead, s);
1103           });
1104     }
1105 #else
1106     SendErrorStatus(errors::Internal("No GPU device in process"));
1107 #endif  // GOOGLE_CUDA
1108   } else {
1109     // tensor is in CPU memory.
1110     if (!can_memcpy) {
1111       in.AsProtoTensorContent(&proto);
1112     }
1113     Send(in, proto, is_dead, Status::OK());
1114   }
1115 }
1116 
1117 void RdmaTensorResponse::Send(const Tensor& in, const TensorProto& proto,
1118                               bool is_dead, const Status& status) {
1119   if (!status.ok()) {
1120     SendErrorStatus(status);
1121     return;
1122   }
1123   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
1124   bool proto_size_changed =
1125       (!can_memcpy) && (proto.ByteSize() != rm_.tensor_bytes_);
1126   if (meta_data_changed_ || proto_size_changed) {
1127     Clone(in, proto, is_dead);
1128     SendMetaData(in, proto, is_dead);
1129   } else {
1130     SendContent(in, proto, is_dead);
1131   }
1132 }
1133 
1134 bool RdmaTensorResponse::TensorMetaDataChanged(const Tensor& in, bool is_dead) {
1135   return (rm_.data_type_ != in.dtype()) || (rm_.tensor_shape_ != in.shape()) ||
1136          (rm_.is_dead_ != is_dead);
1137 }
1138 
1139 void RdmaTensorResponse::Clone(const Tensor& in, const TensorProto& proto,
1140                                bool is_dead) {
1141   // Clone the data to be sent later. For simplicity, we clone the tensor's
1142   // data even if it is already a copy. Performance is less of a concern here
1143 // since the meta-data hardly ever changes. The reason we create a copy is
1144 // that some tensors share their buffer between different step-ids, so the
1145 // tensor content may change before the re-request is completed.
1146   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
1147   if (can_memcpy && (in.TotalBytes() > 0)) {
1148     AllocatorAttributes host_alloc_attrs;
1149     host_alloc_attrs.set_nic_compatible(true);
1150     host_alloc_attrs.set_on_host(true);
1151     Allocator* allocator = src_dev_->GetAllocator(host_alloc_attrs);
1152     tensor_ = new Tensor(allocator, in.dtype(), in.shape());
1153     memcpy(DMAHelper::base(tensor_), DMAHelper::base(&in), in.TotalBytes());
1154   } else {
1155     tensor_ = new Tensor(in.dtype(), in.shape());
1156   }
1157   if (!can_memcpy) {
1158     proto_ = new TensorProto(proto);
1159   }
1160   is_dead_ = is_dead;
1161 }
1162 
1163 void RdmaTensorResponse::SendMetaData(const Tensor& in,
1164                                       const TensorProto& proto, bool is_dead) {
1165   RDMA_LOG(2) << "Request #" << rm_.request_index_
1166               << ": Meta data changed: " << rm_.name_;
1167   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
1168   size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();
1169 
1170   // Send meta-data update:
1171   RdmaMessage rm;
1172   rm.type_ = RDMA_MESSAGE_META_DATA_UPDATE;
1173   rm.name_size_ = rm_.name_.size();
1174   rm.name_ = rm_.name_;
1175   rm.tensor_shape_ = in.shape();
1176   rm.data_type_ = in.dtype();
1177   rm.step_id_ = rm_.step_id_;
1178   rm.is_dead_ = is_dead;
1179   rm.tensor_bytes_ = tensor_bytes;
1180   rm.request_index_ = rm_.request_index_;
1181 #ifdef RDMA_DATA_VALIDATION
1182   rm.checksum_ = checksum_;
1183 #endif
1184   RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
1185               << ": Sending RDMA_MESSAGE_META_DATA_UPDATE #"
1186               << rm.request_index_ << ": " << rm.name_
1187               << " (shape = " << rm.tensor_shape_.DebugString() << "."
1188               << " data-type = " << DataTypeString(rm.data_type_) << "."
1189               << " is-dead = " << rm.is_dead_ << ")";
1190 
1191   string message = RdmaMessage::CreateMessage(rm);
1192   channel_->tx_message_buffer_->EnqueueItem(message);
1193   channel_->tx_message_buffer_->SendNextItem();
1194 }
1195 
1196 void RdmaTensorResponse::SendContent(const Tensor& in, const TensorProto& proto,
1197                                      bool is_dead) {
1198   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
1199   size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();
1200   uint32_t imm_data = rm_.request_index_;
1201   if (!is_dead) {
1202     if (can_memcpy) {
1203       src_buffer_ = const_cast<TensorBuffer*>(DMAHelper::buffer(&in));
1204       if (src_buffer_ != nullptr) {
1205         src_buffer_->Ref();  // Keep buffer alive until write is complete
1206         src_addr_ = src_buffer_->data();
1207         mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(src_addr_,
1208                                                           tensor_bytes);
1209       }
1210     } else {
1211       RDMA_LOG(2) << "Encoding proto: " << rm_.name_
1212                   << " (Size: " << tensor_bytes << ") " << in.DebugString();
1213       src_addr_ = malloc(tensor_bytes);
1214       mr_ = ibv_reg_mr(channel_->adapter_->pd_, src_addr_, tensor_bytes,
1215                        IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
1216       proto.SerializeToArray(src_addr_, tensor_bytes);
1217     }
1218   } else {
1219     tensor_bytes = 0;
1220   }
1221 
1222   uint32_t lkey = (mr_ == nullptr) ? 0 : mr_->lkey;
1223   RDMA_LOG(1) << "Step 0x" << std::hex << rm_.step_id_ << std::dec
1224               << ": Sending tensor content #" << rm_.request_index_ << " from "
1225               << std::hex << src_addr_ << " (0x" << lkey << ")"
1226               << " to " << rm_.remote_addr_ << " (0x" << rm_.rkey_
1227               << "): " << rm_.name_ << " (size: 0x" << std::hex << tensor_bytes
1228               << ")";
1229 
1230   RdmaMessageBuffer::Write(channel_, imm_data, tensor_bytes,
1231                            (uint64_t)src_addr_, lkey, rm_.remote_addr_,
1232                            rm_.rkey_, RDMA_WRITE_ID_TENSOR_WRITE, this);
1233 }
1234 
1235 void RdmaTensorResponse::SendErrorStatus(const Status& status) {
1236   RdmaMessage rm;
1237   rm.type_ = RDMA_MESSAGE_ERROR_STATUS;
1238   rm.name_size_ = rm_.name_.size();
1239   rm.name_ = rm_.name_;
1240   rm.step_id_ = rm_.step_id_;
1241   rm.request_index_ = rm_.request_index_;
1242   rm.status_ = status;
1243   LOG(ERROR) << "Step 0x" << std::hex << rm.step_id_ << std::dec
1244              << ": Sending RDMA_MESSAGE_ERROR_STATUS #" << rm.request_index_
1245              << ": " << rm.name_ << ". Status: " << status.ToString();
1246 
1247   string message = RdmaMessage::CreateMessage(rm);
1248   channel_->tx_message_buffer_->EnqueueItem(message);
1249   channel_->tx_message_buffer_->SendNextItem();
1250 
1251   // Destroy the response.
1252   Destroy();
1253 }
1254 
1255 void RdmaTensorResponse::Destroy() {
1256   if (src_buffer_ != nullptr) {
1257     src_buffer_->Unref();
1258   }
1259   if (tensor_ != nullptr) {
1260     delete tensor_;
1261   }
1262   if (proto_ != nullptr) {
1263     ibv_dereg_mr(mr_);
1264     free(src_addr_);
1265     delete proto_;
1266   }
1267   // Remove response from the pending list:
1268   channel_->RemoveTensorResponse(rm_.request_index_);
1269 }
1270 
1271 // Create a RdmaMessage according to the pre-defined format
1272 // Args:
1273 //   rm: the message structure
1274 // Returns:
1275 //   message in string format
1276 string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
1277   // Rdma Message format
1278   // type|name_size|name|step_id|request_index|remote_addr|rkey|is_dead|...
1279   //   1B|    2B   | 512|  8B   |     8B      |       8B  | 4B |    1B |...
1280   // ...|data_type|tensor_shape|tensor_bytes|error_status          |
1281   // ...|   XB    |    XB      |    8B      |size - 4B, proto - XB |
1282   //
1283   // ACK:             Imm-type: ACK
1284   // TENSOR_REQUEST:  Imm-type: MESSAGE
1285   //                  Fields: type, request_index, name, step_id, remote_addr,
1286   //                      rkey, is_dead, data_type, tensor_shape, tensor_bytes
1287   // META_DATA_UPDATE: Imm-type: MESSAGE
1288   //                  Fields: type, request_index, is_dead, data_type,
1289   //                      tensor_shape, tensor_bytes
1290   // TENSOR_RE_REQUEST: Imm-type: MESSAGE
1291   //                  Fields: type, request_index, name, step_id, remote_addr,
1292   //                      rkey, is_dead, data_type, tensor_shape, tensor_bytes
1293   // ERROR_STATUS:    Imm-type: MESSAGE
1294   //                  Fields: type, request_index, name, step_id, error_status
1295   // Tensor content:  Imm-type: request_index
1296   size_t message_size = kMessageTotalBytes;
1297   char message[kMessageTotalBytes + kErrorStatusMaxSize];
1298   // type
1299   message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
1300   // request index
1301   memcpy(&message[kRequestIndexStartIndex], &rm.request_index_,
1302          sizeof(rm.request_index_));
1303   // name, step_id, remote_addr, rkey
1304   if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
1305       (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
1306     memcpy(&message[kNameSizeStartIndex], &rm.name_size_,
1307            sizeof(rm.name_size_));
1308     memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
1309     memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
1310            sizeof(rm.remote_addr_));
1311     memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
1312     memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
1313   }
1314   // is_dead, data_type, tensor_shape, tensor_bytes
1315   if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
1316       (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
1317       (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
1318     memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
1319 
1320     memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
1321            sizeof(rm.data_type_));
1322     memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
1323            sizeof(rm.tensor_shape_));
1324     memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
1325            sizeof(rm.tensor_bytes_));
1326   }
1327   // checksum
1328 #ifdef RDMA_DATA_VALIDATION
1329   memcpy(&message[kChecksumStartIndex], &rm.checksum_, sizeof(rm.checksum_));
1330 #endif
1331   // error status
1332   if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
1333     ::grpc::Status gs = ToGrpcStatus(rm.status_);
1334     ErrorStatusProto gsProto;
1335     gsProto.set_error_code(gs.error_code());
1336     gsProto.set_error_message(gs.error_message());
1337     gsProto.set_error_details(gs.error_details());
1338     uint32_t gsProtoSize = gsProto.ByteSize();
1339     if (gsProtoSize + 4 > kErrorStatusMaxSize) {
1340       LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) "
1341                  << "is too big to fit in RDMA message (" << kErrorStatusMaxSize
1342                  << " bytes). Truncated.";
1343       gsProtoSize = kErrorStatusMaxSize - 4;
1344     }
1345     uint32_t* proto_size = (uint32_t*)&message[kErrorStatusStartIndex];
1346     *proto_size = gsProtoSize;
1347     gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4], gsProtoSize);
1348     message_size += gsProtoSize + 4;
1349   }
1350   return string(message, message_size);
1351 }
1352 
1353 // Parse a RdmaMessage according to the pre-defined format
1354 // Args:
1355 //   rm: the message structure where the parsed message will be saved
1356 //   buffer: the place where the raw message is stored
1357 // Returns:
1358 //   None
1359 void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
1360   char* message = static_cast<char*>(buffer);
1361   // type
1362   rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]);
1363   // request index
1364   memcpy(&rm.request_index_, &message[kRequestIndexStartIndex],
1365          sizeof(rm.request_index_));
1366   // name, step_id, remote_addr, rkey
1367   if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
1368       (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
1369     memcpy(&rm.name_size_, &message[kNameSizeStartIndex],
1370            sizeof(rm.name_size_));
1371     rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
1372     memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
1373            sizeof(rm.remote_addr_));
1374     memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
1375     memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
1376   }
1377   // data_type, tensor_bytes, tensor_shape, is_dead
1378   if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
1379       (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
1380       (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
1381     memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
1382     memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
1383            sizeof(rm.data_type_));
1384     memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex],
1385            sizeof(rm.tensor_shape_));
1386     memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
1387            sizeof(rm.tensor_bytes_));
1388   }
1389   // checksum
1390 #ifdef RDMA_DATA_VALIDATION
1391   memcpy(&rm.checksum_, &message[kChecksumStartIndex], sizeof(rm.checksum_));
1392 #endif
1393   // error status
1394   if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
1395     ErrorStatusProto gsProto;
1396     uint32_t gsProtoSize = *(uint32_t*)&message[kErrorStatusStartIndex];
1397     CHECK(ParseProtoUnlimited(&gsProto, &message[kErrorStatusStartIndex + 4],
1398                               gsProtoSize))
1399         << "Failed to parse error status proto from message. Aborting.";
1400     ::grpc::Status gs((::grpc::StatusCode)gsProto.error_code(),
1401                       gsProto.error_message(), gsProto.error_details());
1402     rm.status_ = FromGrpcStatus(gs);
1403   }
1404 }
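
// A minimal round-trip sketch of the two routines above: build a tensor
// request, flatten it with CreateMessage(), and recover the fields with
// ParseMessage(). The helper below is hypothetical and compiled out; it only
// illustrates the intended symmetry of the wire format.
#if 0
static void ExampleMessageRoundTrip() {
  RdmaMessage out;
  out.type_ = RDMA_MESSAGE_TENSOR_REQUEST;
  out.request_index_ = 7;
  out.name_ = "example/tensor:0";
  out.name_size_ = out.name_.size();
  out.step_id_ = 42;
  out.remote_addr_ = 0;
  out.rkey_ = 0;
  out.is_dead_ = false;
  out.tensor_bytes_ = 0;
  out.data_type_ = DT_INVALID;  // No cached meta-data, as in Send().

  string wire = RdmaMessage::CreateMessage(out);

  RdmaMessage in;
  RdmaMessage::ParseMessage(in, const_cast<char*>(wire.data()));
  CHECK(in.type_ == RDMA_MESSAGE_TENSOR_REQUEST);
  CHECK(in.request_index_ == out.request_index_);
  CHECK(in.name_ == out.name_);
  CHECK(in.step_id_ == out.step_id_);
}
#endif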

//*****************************************************************************
// RdmaMemoryMgr
//*****************************************************************************

ibv_mr* RdmaMemoryMgr::FindMemoryRegion(void* addr, size_t length) {
  mutex_lock l(mrs_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter == std::end(mrs_) || iter->get()->addr > addr) {
    return nullptr;
  } else {
    return iter->get();
  }
}

void RdmaMemoryMgr::InsertMemoryRegion(void* addr, size_t length,
                                       const std::string& allocator_name) {
  if (length == 0) return;
  ibv_mr* mr = ibv_reg_mr(pd_, addr, length,
                          IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
  if (mr != nullptr) {
    // Log only after the null check, so a failed registration does not
    // dereference a null ibv_mr.
    RDMA_LOG(1) << "Insert memory region 0x" << std::hex << mr->rkey << ". ["
                << addr << "-" << (void*)((uint64_t)addr + length - 1) << "]"
                << " SIZE: 0x" << length << " (" << allocator_name << ").";
    mutex_lock l(mrs_mu_);
    auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
    mrs_.insert(iter, {mr, &MRDeleter});
  } else {
    LOG(WARNING) << "Cannot register memory region";
  }
}

void RdmaMemoryMgr::EvictMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  mutex_lock l(mrs_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter != std::end(mrs_) && iter->get()->addr == addr) {
    // Log before erasing; the iterator is invalidated by erase().
    RDMA_LOG(1) << "Evict memory region 0x" << std::hex << iter->get()->rkey;
    mrs_.erase(iter);
  } else {
    LOG(WARNING) << "Failed to de-register memory region";
  }
}
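
// The three members above keep mrs_ sorted so that the region containing an
// address can be located with a single std::upper_bound. A self-contained
// sketch of that lookup pattern follows; ExampleRegion, ExampleComparator and
// ExampleFind are hypothetical stand-ins (the real element type and
// Comparator are presumably declared alongside RdmaMemoryMgr in rdma.h), and
// the block is compiled out.
#if 0
struct ExampleRegion {
  void* addr;     // Start of the registered range.
  size_t length;  // Length of the registered range in bytes.
};

// Orders a probe address against a region's end, so upper_bound returns the
// first region whose end lies beyond the probe.
static bool ExampleComparator(const void* ptr, const ExampleRegion& r) {
  return ptr < static_cast<const char*>(r.addr) + r.length;
}

static const ExampleRegion* ExampleFind(const std::vector<ExampleRegion>& v,
                                        void* addr) {
  auto it = std::upper_bound(v.begin(), v.end(), addr, &ExampleComparator);
  // The candidate must also start at or before the probe address.
  if (it == v.end() || it->addr > addr) return nullptr;
  return &*it;
}
#endif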

const TensorMetaData* RdmaMemoryMgr::GetTensorMetaData(
    const std::string& tensor_name) {
  mutex_lock l(tensor_meta_data_mu_);
  auto it = tensors_meta_data_.find(tensor_name);
  if (it == tensors_meta_data_.end()) {
    return nullptr;
  }
  return &it->second;
}

const TensorMetaData* RdmaMemoryMgr::SetTensorMetaData(
    const std::string& tensor_name, DataType dtype, const TensorShape& shape,
    bool is_dead, size_t proto_size) {
  mutex_lock l(tensor_meta_data_mu_);
  TensorMetaData& meta_data = tensors_meta_data_[tensor_name];
  meta_data.data_type_ = dtype;
  meta_data.tensor_shape_ = shape;
  meta_data.proto_size_ = proto_size;
  meta_data.is_dead_ = is_dead;
  return &meta_data;
}
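
// Usage sketch for the meta-data cache above: the first request for a tensor
// finds no entry and has to ask the remote side, while later requests can
// pre-allocate the result from the cached shape and dtype. The helper is
// hypothetical and compiled out.
#if 0
static void ExampleMetaDataFlow(const string& tensor_name) {
  const TensorMetaData* m =
      RdmaMemoryMgr::Singleton().GetTensorMetaData(tensor_name);
  if (m == nullptr) {
    // Unknown tensor: a request without local pre-allocation would go out,
    // and the responder's RDMA_MESSAGE_META_DATA_UPDATE would populate the
    // cache via SetTensorMetaData().
    return;
  }
  // Known tensor: shape and dtype are available before any data arrives.
  Tensor pre_allocated(m->data_type_, m->tensor_shape_);
  LOG(INFO) << tensor_name << ": " << pre_allocated.TotalBytes() << " bytes";
}
#endif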

//*****************************************************************************
// RdmaTensorRequest
//*****************************************************************************

RdmaTensorRequest::RdmaTensorRequest(
    uint32_t index, const string& key, int64 step_id, RdmaChannel* channel,
    Device* dst_dev, const Rendezvous::Args recv_args,
    const RdmaTensorRequest::RecvDoneCallback& done)
    : index_(index),
      key_(key),
      step_id_(step_id),
      channel_(channel),
      dst_dev_(dst_dev),
      recv_args_(recv_args),
      meta_data_(RdmaMemoryMgr::Singleton().GetTensorMetaData(key)),
      result_tensor_(nullptr),
      proxy_tensor_(nullptr),
      rdma_addr_(nullptr),
      mr_(nullptr),
      done_(done) {}

RdmaTensorRequest::~RdmaTensorRequest() { DeallocateTensors(); }

void RdmaTensorRequest::Done(const Status& s) {
  Tensor val = std::move(*result_tensor_);

#ifdef RDMA_DATA_VALIDATION
  // Validate checksum
  // Unfortunately we can't always do a Checksum directly on the result tensor.
  // If the result tensor is on GPU, then we need to copy it back to CPU. If
  // we happen to be in the midst of a proxy callback, then the copying will
  // get stuck.
  uint64_t checksum = (proxy_tensor_ != nullptr)
                          ? Checksum(nullptr, nullptr, *proxy_tensor_)
                          : Checksum(dst_dev_, recv_args_.device_context, val);
  ValidateChecksum(checksum_, checksum, val, index_, key_, "RDMA");
#endif

  Rendezvous::Args recv_args = std::move(recv_args_);
  bool is_dead = (meta_data_ == nullptr) ? false : meta_data_->is_dead_;
  RecvDoneCallback done = done_;
  DeallocateTensors();
  channel_->RemoveTensorRequest(index_);
  done(s, Rendezvous::Args(), recv_args, val, is_dead);
}

void RdmaTensorRequest::DeallocateTensors() {
  if (result_tensor_ != nullptr) {
    delete result_tensor_;
    result_tensor_ = nullptr;
  }
  if (proxy_tensor_ != nullptr) {
    delete proxy_tensor_;
    proxy_tensor_ = nullptr;
  }
}

bool RdmaTensorRequest::AllocateTensors() {
  result_tensor_ =
      new Tensor(dst_dev_->GetAllocator(recv_args_.alloc_attrs),
                 meta_data_->data_type_, meta_data_->tensor_shape_);

  size_t tensor_size = result_tensor_->TotalBytes();
  bool can_memcpy = DataTypeCanUseMemcpy(result_tensor_->dtype());
  if (can_memcpy) {
    if (tensor_size == 0) {
      return true;
    }
    rdma_addr_ = DMAHelper::base(result_tensor_);
    mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
#if GOOGLE_CUDA
    if (mr_ == nullptr) {
      // Can't RDMA directly to result. Use a proxy.
      proxy_tensor_ =
          new Tensor(GPUProcessState::singleton()->GetGpuHostAllocator(0),
                     result_tensor_->dtype(), result_tensor_->shape());
      rdma_addr_ = DMAHelper::base(proxy_tensor_);
      mr_ =
          RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
    }
#endif
  } else {
    uint32_t proto_size = meta_data_->proto_size_;
    rdma_addr_ = malloc(proto_size);
    mr_ = ibv_reg_mr(RdmaMemoryMgr::Singleton().pd_, rdma_addr_, proto_size,
                     IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
  }
  CHECK(mr_ != nullptr) << "No memory region found for address " << rdma_addr_
                        << ": " << key_;
  return true;
}
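
// AllocateTensors() above selects one of two receive paths: memcpy-able
// dtypes are written by the remote side directly into the result tensor (or
// into a pinned host proxy when the GPU buffer is not RDMA-registered), while
// all other dtypes arrive as a serialized TensorProto in a temporary buffer.
// A compiled-out sketch of that decision, with logging in place of the
// registration logic:
#if 0
static void ExampleAllocationPath(const Tensor& result) {
  if (DataTypeCanUseMemcpy(result.dtype())) {
    // Raw tensor bytes can be RDMA-written straight to this address.
    LOG(INFO) << "memcpy path: " << result.TotalBytes() << " bytes at "
              << DMAHelper::base(&result);
  } else {
    // The content will arrive as a TensorProto and be decoded in
    // RecvTensorContent().
    LOG(INFO) << "proto path";
  }
}
#endif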

void RdmaTensorRequest::AllocateTensorsAsync(StatusCallback done) {
  AllocateTensors();
  bool on_host = recv_args_.alloc_attrs.on_host();
  if (dst_dev_->tensorflow_gpu_device_info() && !on_host &&
      (proxy_tensor_ == nullptr)) {
#if GOOGLE_CUDA
    // We need to sync the memory allocation on the GPU:
    StreamGPUOp(dst_dev_, recv_args_.device_context, done);
#endif
  } else {
    done(Status::OK());
  }
}

void RdmaTensorRequest::Send(RdmaMessageType message_type) {
  RdmaMessageBuffer* rb = channel_->tx_message_buffer_;
  RdmaMessage rm;
  rm.type_ = message_type;
  rm.request_index_ = index_;
  rm.name_size_ = key_.size();
  rm.name_ = key_;
  rm.step_id_ = step_id_;
  rm.remote_addr_ = (uint64_t)rdma_addr_;
  if (meta_data_ != nullptr) {
    rm.data_type_ = meta_data_->data_type_;
    rm.tensor_shape_ = meta_data_->tensor_shape_;
    rm.is_dead_ = meta_data_->is_dead_;
    rm.tensor_bytes_ = meta_data_->proto_size_;
  } else {
    rm.data_type_ = DT_INVALID;
  }
  rm.rkey_ = (mr_ == nullptr) ? 0 : mr_->rkey;

  RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
              << ": Sending " << MessageTypeToString(message_type) << " #"
              << index_ << ": " << rm.name_ << " on " << rdma_addr_
              << " (rkey: 0x" << std::hex << rm.rkey_ << ")";

  string message = RdmaMessage::CreateMessage(rm);
  rb->EnqueueItem(message);
  rb->SendNextItem();
}

void RdmaTensorRequest::RecvTensorMetaData(DataType dtype, TensorShape shape,
                                           bool is_dead, size_t proto_size) {
  meta_data_ = RdmaMemoryMgr::Singleton().SetTensorMetaData(
      key_, dtype, shape, is_dead, proto_size);

  DeallocateTensors();
  AllocateTensorsAsync(
      [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_RE_REQUEST); });
}

void RdmaTensorRequest::RecvTensorContent() {
  bool can_memcpy = DataTypeCanUseMemcpy(meta_data_->data_type_);
  size_t message_size =
      can_memcpy ? result_tensor_->TotalBytes() : meta_data_->proto_size_;
  RDMA_LOG(1) << "Step 0x" << std::hex << step_id_ << std::dec
              << ": Received tensor content #" << index_ << ": " << key_
              << " (Size: 0x" << std::hex << message_size << ")";

  Tensor val;

#if GOOGLE_CUDA
  if (proxy_tensor_ != nullptr) {
    CountCopies(key_, (void*)DMAHelper::base(proxy_tensor_),
                (void*)DMAHelper::base(result_tensor_),
                result_tensor_->TotalBytes(), false);
    GPUUtil::CopyCPUTensorToGPU(proxy_tensor_, recv_args_.device_context,
                                dst_dev_, result_tensor_,
                                [this](const Status& s) {
                                  CHECK(s.ok()) << "copy tensor to gpu sync";
                                  Done(s);
                                });
    return;
  }
#endif

  if (can_memcpy) {
    Done(Status::OK());
  } else {
    RDMA_LOG(2) << "Decoding proto: " << key_
                << " (Size: " << meta_data_->proto_size_ << ")";
    TensorProto proto;
    CHECK(ParseProtoUnlimited(&proto, rdma_addr_, meta_data_->proto_size_))
        << "Failed to parse proto from array";
    ibv_dereg_mr(mr_);
    free(rdma_addr_);
    Status s = dst_dev_->MakeTensorFromProto(proto, recv_args_.alloc_attrs,
                                             result_tensor_);
    Done(s);
  }
}

void RdmaTensorRequest::RecvErrorStatus(const Status& status) {
  if (result_tensor_ == nullptr) {
    result_tensor_ = new Tensor();
  }
  LOG(ERROR) << "Received RDMA_MESSAGE_ERROR_STATUS: " << status.ToString();
  Done(status);
}

void RdmaTensorRequest::Start() {
  meta_data_ = RdmaMemoryMgr::Singleton().GetTensorMetaData(key_);
  if (meta_data_ != nullptr) {
    AllocateTensorsAsync(
        [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_REQUEST); });
  } else {
    Send(RDMA_MESSAGE_TENSOR_REQUEST);
  }
}
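
// The functions above implement the client-side request life cycle: Start()
// sends the initial request (pre-allocating when meta-data is cached), a
// meta-data mismatch on the responder comes back through RecvTensorMetaData()
// and triggers a re-request, and the transfer finishes through
// RecvTensorContent() or RecvErrorStatus(), both of which end in Done().
// A compiled-out pseudo-driver; in reality the channel invokes these
// callbacks when the corresponding RDMA completions arrive:
#if 0
static void ExampleRequestLifeCycle(RdmaTensorRequest* req, DataType dtype,
                                    const TensorShape& shape,
                                    size_t proto_size) {
  req->Start();  // Sends RDMA_MESSAGE_TENSOR_REQUEST.
  // Responder's meta-data differs from (or is missing in) the local cache:
  req->RecvTensorMetaData(dtype, shape, /*is_dead=*/false, proto_size);
  // Tensor bytes have been RDMA-written into rdma_addr_:
  req->RecvTensorContent();  // Finalizes the tensor and invokes Done().
}
#endif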

}  // end namespace tensorflow

#endif