/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_VERBS

#include <fcntl.h>
#include <cstdlib>

#include "tensorflow/contrib/verbs/rdma.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#endif
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {

#define RoCE_V2 "RoCE v2"

namespace {

// convenience function for printing message
string MessageTypeToString(RdmaMessageType rmt) {
  switch (rmt) {
    case RDMA_MESSAGE_META_DATA_UPDATE:
      return "RDMA_MESSAGE_META_DATA_UPDATE";
      break;
    case RDMA_MESSAGE_TENSOR_RE_REQUEST:
      return "RDMA_MESSAGE_TENSOR_RE_REQUEST";
      break;
    case RDMA_MESSAGE_TENSOR_REQUEST:
      return "RDMA_MESSAGE_TENSOR_REQUEST";
      break;
    default:
      return "UNKNOWN MESSAGE";
  }
}
}  // namespace

// Function to get an environment variable
// Args:
//   var_name - the name of the environment variable
// Returns:
//   string with its value, or an empty string if not set
string get_env_var(char const* var_name) {
  char const* var_temp = getenv(var_name);

  return (var_temp == NULL) ? string() : string(var_temp);
}

// Function to open device
// Args:
//   ibv_dev device to open
// Returns:
//   context of the opened device
ibv_context* open_device(ibv_device* ibv_dev) {
  ibv_context* context = ibv_open_device(ibv_dev);

  CHECK(context) << "Open context failed for " << ibv_get_device_name(ibv_dev);
  return context;
}

// Function to count the number of active ports for device
// Args:
//   device - to check active ports
// Returns:
//   number of active ports of the given device
int get_dev_active_port_count(ibv_device* device) {
  ibv_device_attr device_att;
  ibv_port_attr port_attr;
  ibv_context* context = NULL;
  int rc, port_index, active_ports = 0;

  context = ibv_open_device(device);
  CHECK(context) << "Open context failed for " << ibv_get_device_name(device);
  rc = ibv_query_device(context, &device_att);
  CHECK(!rc) << "Failed to query the device";

  for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
    rc = ibv_query_port(context, port_index, &port_attr);
    CHECK(!rc) << "Failed to query the port" << port_index;
    if (port_attr.state == IBV_PORT_ACTIVE) {
      active_ports++;
    }
  }
  ibv_close_device(context);
  return active_ports;
}

// Function to set the device. If RDMA_DEVICE is not set, searches for a
// device with an active port.
// Fails if more than one device with an active port is found.
// Returns:
//   device to use
ibv_device* set_device() {
  ibv_device** dev_list;
  int dev_num, device_index, device_to_open = 0;
  int num_devs_with_active_port = 0;
  string env_p_rdma_device, str_port_num;

  dev_list = ibv_get_device_list(&dev_num);
  CHECK(dev_list) << "No InfiniBand device found";

  env_p_rdma_device = get_env_var("RDMA_DEVICE");
  if (!env_p_rdma_device.empty()) {
    for (device_index = 0; device_index < dev_num; device_index++) {
      if (!env_p_rdma_device.compare(
              ibv_get_device_name(dev_list[device_index]))) {
        CHECK(get_dev_active_port_count(dev_list[device_index]) != 0)
            << "Device " << ibv_get_device_name(dev_list[device_index])
            << " has no active ports";
        return dev_list[device_index];
      }
    }
    // check validity of input device
    CHECK(false) << "The device " << env_p_rdma_device << " wasn't found";
  } else {
    // set default device
    str_port_num = get_env_var("RDMA_DEVICE_PORT");
    CHECK(str_port_num.empty())
        << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user";
    for (device_index = 0; device_index < dev_num; device_index++) {
      // get port_num
      if (get_dev_active_port_count(dev_list[device_index]) > 0) {
        num_devs_with_active_port++;
        CHECK(num_devs_with_active_port <= 1) << ". More than one device with "
                                                 "active port in the system. "
                                                 "Please enter RDMA_DEVICE";
        // found device with at least 1 active port
        device_to_open = device_index;
      }
    }
    CHECK(num_devs_with_active_port > 0)
        << "There is no active port in the system";
    return dev_list[device_to_open];
  }
  CHECK(false) << "No device was set!";
  return NULL;  // never happens
}

// Function to set the port for the device.
// If RDMA_DEVICE_PORT is not set, the first active port of the device is used.
// Args:
//   context of the device
// Returns:
//   port to use
uint8_t set_port(ibv_context* context) {
  uint8_t port_num = 0;  // 0 is illegal port number
  string str_port_num;
  ibv_device_attr device_att;
  ibv_port_attr port_attr;
  int rc, port_index;

  rc = ibv_query_device(context, &device_att);
  CHECK(!rc) << "Failed to query the device\n";

  str_port_num = get_env_var("RDMA_DEVICE_PORT");
  // user defined port
  if (!str_port_num.empty()) {
    port_num = stoi(str_port_num);
    CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive";
    CHECK(port_num <= device_att.phys_port_cnt) << "RDMA_DEVICE_PORT should be "
                                                   "less than or equal to the "
                                                   "number of available ports";
    rc = ibv_query_port(context, port_num, &port_attr);
    CHECK(!rc) << "Failed to query the port" << port_num;
    // check if the port is active
    CHECK(port_attr.state == IBV_PORT_ACTIVE)
        << "Selected RDMA_DEVICE_PORT is not active";
  } else {  // set default port
    for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
      rc = ibv_query_port(context, port_index, &port_attr);
      CHECK(!rc) << "Failed to query the port" << port_index;
      if (port_attr.state == IBV_PORT_ACTIVE) {
        port_num = port_index;
        break;
      }
    }
    CHECK_GT(port_num, 0) << "No active ports";
  }
  return port_num;
}

// Function to read from a sysfs file
// Args:
//   dir - directory
//   file - file
//   buff - buffer for the result
//   size - buffer size
// Returns:
//   number of bytes read, or -1 on failure
int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) {
  char* path;
  int fd;
  int len;

  if (asprintf(&path, "%s/%s", dir, file) < 0) return -1;

  fd = open(path, O_RDONLY);
  if (fd < 0) {
    free(path);
    return -1;
  }

  len = read(fd, buf, size);

  close(fd);
  free(path);

  if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0';

  return len;
}

// Function to check whether a GID index supports RoCE V2
// Args:
//   context - device context
//   port_num - port number
//   index - GID index
// Returns:
//   true if the GID supports RoCE V2, otherwise false.
bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
                         uint8_t index) {
  char name[32];
  char buff[41];

  snprintf(name, sizeof(name), "ports/%d/gid_attrs/types/%d", port_num, index);
  if (read_sysfs_file(context->device->ibdev_path, name, buff, sizeof(buff)) <=
      0) {
    return false;
  }
  return !strcmp(buff, RoCE_V2);
}

// Function to set the GID index.
// If the port link layer is IB, no GID index needs to be selected.
// If it is Ethernet and RDMA_GID_INDEX is not set, a GID index that supports
// RoCE V2 is chosen (fails if more than one IP is configured).
// Args:
//   context - device context
//   port_num - port number
// Returns:
//   GID index to use
uint8_t set_gid(uint8_t port_num, ibv_context* context) {
  ibv_port_attr port_attr;
  string gid_str;
  int rc, i, gids_num = 0, v2_ip_num = 0;
  union ibv_gid gid;
  uint8_t gid_index = 0;

  rc = ibv_query_port(context, port_num, &port_attr);
  CHECK(!rc) << "Failed to query the port" << port_num;

  for (i = 0; i < port_attr.gid_tbl_len; i++) {
    rc = ibv_query_gid(context, port_num, i, &gid);
    CHECK(!rc) << "Failed to query gid to port " << (int)port_num << " index "
               << i;
    if (gid.global.interface_id) {
      gids_num++;
      if (gid.global.subnet_prefix == 0 &&
          is_gid_type_roce_v2(context, port_num, i)) {
        if (v2_ip_num == 0) {
          // can be overwritten by RDMA_GID_INDEX later
          gid_index = i;
        }
        v2_ip_num++;
      }
    }
  }
  switch (port_attr.link_layer) {
    case (IBV_LINK_LAYER_ETHERNET):
      gid_str = get_env_var("RDMA_GID_INDEX");
      if (!gid_str.empty()) {
        gid_index = stoi(gid_str);
        CHECK(gid_index < gids_num)
            << "RDMA_GID_INDEX should be less than GIDs amount" << gids_num;
      } else {
        CHECK(v2_ip_num <= 1)
            << "More than one IP is available, please specify GID_INDEX";
      }
      break;
    case (IBV_LINK_LAYER_INFINIBAND):  // no need for a GID index
      break;
    default:
      LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and "
                   "InfiniBand only. ";
  }
  if (!is_gid_type_roce_v2(context, port_num, gid_index)) {
    LOG(INFO) << "RoCE v2 is not configured for GID_INDEX " << (int)gid_index;
  }
  return gid_index;
}

// Set a configuration parameter from the environment, falling back to the
// given default value.
// Args:
//   default_val - the default value for this parameter
//   env_param - the environment variable's name
// Returns:
//   32-bit value
uint32_t set_param(uint32_t default_val, const char* env_param) {
  uint32_t val = default_val;
  string val_s;

  val_s = get_env_var(env_param);

  if (!val_s.empty()) {
    val = stoi(val_s);
  }
  return val;
}

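// Set the path MTU for the QPs.
// If RDMA_MTU is set, it must be one of 256/512/1024/2048/4096 and pass the
// check against the port's active MTU below; otherwise the port's active MTU
// is used.
// Args:
//   port_num - port number
//   context - device context
// Returns:
//   MTU to use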
enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
  ibv_port_attr port_attr;
  enum ibv_mtu mtu = IBV_MTU_512;
  string mtu_s;
  int rc, mtu_i;

  rc = ibv_query_port(context, port_num, &port_attr);
  CHECK(!rc) << "Failed to query the port" << port_num;

  mtu_s = get_env_var("RDMA_MTU");

  if (!mtu_s.empty()) {
    mtu_i = stoi(mtu_s);
    switch (mtu_i) {
      case 256:
        mtu = IBV_MTU_256;
        break;
      case 512:
        mtu = IBV_MTU_512;
        break;
      case 1024:
        mtu = IBV_MTU_1024;
        break;
      case 2048:
        mtu = IBV_MTU_2048;
        break;
      case 4096:
        mtu = IBV_MTU_4096;
        break;
      default:
        CHECK(0) << "Error: MTU input value must be one of the following: 256, "
                    "512, 1024, 2048, 4096. MTU "
                 << mtu << " is invalid\n";
        break;
    }
    CHECK(mtu < port_attr.active_mtu)
        << "MTU configuration for the QPs is larger than active MTU";
  } else {
    mtu = port_attr.active_mtu;
  }
  return mtu;
}

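// Initialize the RDMA configuration parameters for the given device context,
// combining compile-time defaults with the RDMA_* environment variable
// overrides handled above.
// Args:
//   context - device context
// Returns:
//   populated RdmaParams structure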
RdmaParams params_init(ibv_context* context) {
  RdmaParams params;

  params.port_num = set_port(context);
  params.sgid_index = set_gid(params.port_num, context);
  params.pkey_index = (uint8_t)set_param(PKEY_DEFAULT, "RDMA_PKEY");
  params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH");
  params.timeout = (uint8_t)set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT");
  params.retry_cnt = (uint8_t)set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT");
  params.sl = (uint8_t)set_param(SL_DEFAULT, "RDMA_SL");
  CHECK(params.sl <= 7) << "SL value is " << (int)params.sl
                        << ". Valid values are 0-7.";
  params.mtu = set_mtu(params.port_num, context);
  params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS");
  return params;
}

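// Allocate an ibv protection domain for the given device context.
// Aborts if the allocation fails.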
ibv_pd* alloc_protection_domain(ibv_context* context) {
  ibv_pd* pd = ibv_alloc_pd(context);
  CHECK(pd) << "Failed to allocate protection domain";
  return pd;
}

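// RdmaAdapter holds the device context, protection domain, completion channel
// and completion queue that the RDMA channels of this worker share. The
// completion queue is sized at twice MAX_CONCURRENT_WRITES.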
RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
    : context_(open_device(set_device())),
      params_(params_init(context_)),
      pd_(alloc_protection_domain(context_)),
      worker_env_(worker_env) {
  event_channel_ = ibv_create_comp_channel(context_);
  CHECK(event_channel_) << "Failed to create completion channel";
  cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_,
                      0);
  CHECK(cq_) << "Failed to create completion queue";
  CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
}

RdmaAdapter::~RdmaAdapter() {
  polling_thread_.reset();
  CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
  CHECK(!ibv_destroy_comp_channel(event_channel_))
      << "Failed to destroy channel";
  CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD";
  CHECK(!ibv_close_device(context_)) << "Failed to release context";
}

void RdmaAdapter::StartPolling() {
  polling_thread_.reset(Env::Default()->StartThread(
      ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
  VLOG(2) << "Start RdmaAdapter: " << name();
}

string RdmaAdapter::name() const { return string(context_->device->name); }

// Function to process incoming messages on the completion queue.
// There are two types of completions:
//   1. IBV_WC_RECV_RDMA_WITH_IMM (receive)
//   2. IBV_WC_RDMA_WRITE (send)
void RdmaAdapter::Process_CQ() {
  while (true) {
    ibv_cq* cq;
    void* cq_context;
    CHECK(!ibv_get_cq_event(event_channel_, &cq, &cq_context));
    CHECK(cq == cq_);
    ibv_ack_cq_events(cq, 1);
    CHECK(!ibv_req_notify_cq(cq_, 0));

    int ne =
        ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
    CHECK_GE(ne, 0);
    for (int i = 0; i < ne; ++i) {
      CHECK(wc_[i].status == IBV_WC_SUCCESS)
          << "Failed status \n"
          << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
          << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
      if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
        RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
        // put back a recv wr.
        rc->Recv();
        // imm_data is the index of RX buffer in the buffer table.
        uint32_t imm_data = wc_[i].imm_data;
        RdmaMessageBuffer* rb;
        RdmaMessage rm;

        if (imm_data == RDMA_IMM_DATA_ACK) {
          // receive an ack to a message
          rb = rc->tx_message_buffer_;
          rb->SetBufferStatus(remote, idle);
          rb->SendNextItem();
          continue;
        }

        if (imm_data <= RDMA_IMM_MAX_REQUEST_ID) {
          // receive a tensor RDMA write
          uint32_t request_index = imm_data;
          RdmaTensorRequest* request = rc->GetTensorRequest(request_index);
          request->RecvTensorContent();
          continue;
        }

        // receive a control message
        rb = rc->rx_message_buffer_;
        RdmaMessage::ParseMessage(rm, rb->buffer_);
        RdmaMessageBuffer::SendAck(rc);
        RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
                    << ": Received " << MessageTypeToString(rm.type_) << " "
                    << "#" << rm.request_index_ << ": " << rm.name_;

        if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
          RdmaTensorResponse* response = rc->AddTensorResponse(rm);
          response->Start();
        } else if (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) {
          RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
          request->RecvTensorMetaData(rm.data_type_, rm.tensor_shape_,
                                      rm.is_dead_, rm.tensor_bytes_);
#ifdef RDMA_DATA_VALIDATION
          request->RecvTensorChecksum(rm.checksum_);
#endif
        } else if (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST) {
          RdmaTensorResponse* response = rc->UpdateTensorResponse(rm);
          response->Resume();
        } else if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
          RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
          request->RecvErrorStatus(rm.status_);
        }
      } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
        RdmaWriteID* wr_id = reinterpret_cast<RdmaWriteID*>(wc_[i].wr_id);
        RDMA_LOG(2) << "Write complete of type " << wr_id->write_type;
        switch (wr_id->write_type) {
          case RDMA_WRITE_ID_ACK:
            break;
          case RDMA_WRITE_ID_MESSAGE: {
            RdmaMessageBuffer* rb =
                reinterpret_cast<RdmaMessageBuffer*>(wr_id->write_context);
            rb->SetBufferStatus(local, idle);
            rb->SendNextItem();
            break;
          }
          case RDMA_WRITE_ID_TENSOR_WRITE: {
            RdmaTensorResponse* response =
                reinterpret_cast<RdmaTensorResponse*>(wr_id->write_context);
            response->Destroy();
          }
        }
        delete wr_id;
      }
    }
  }
}

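// Post a receive work request for the ping buffer on this channel's QP.
// Returns the ibv_post_recv() result (0 on success).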
int RdmaChannel::PingPostRecv() {
  struct ibv_recv_wr wr, *bad_wr;
  memset(&wr, 0, sizeof(wr));
  wr.sg_list = &ping_sge_list_;
  wr.num_sge = 1;
  wr.wr_id = kPingRecvWrid;

  return ibv_post_recv(qp_, &wr, &bad_wr);
}

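// Post a signaled send of the ping buffer on this channel's QP.
// Returns the ibv_post_send() result (0 on success).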
int RdmaChannel::PingPostSend() {
  struct ibv_send_wr wr, *bad_wr;
  memset(&wr, 0, sizeof(wr));
  wr.wr_id = (uint64_t)this;
  wr.sg_list = &ping_sge_list_;
  wr.num_sge = 1;
  wr.opcode = IBV_WR_SEND;
  wr.send_flags = IBV_SEND_SIGNALED;

  return ibv_post_send(qp_, &wr, &bad_wr);
}

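// RdmaChannel: one RC queue pair between this worker (local_name) and a
// remote worker (remote_name). The constructor registers the ping buffer,
// creates the QP and moves it to INIT, queries the local address
// (LID/QPN/PSN/GID), allocates the TX/RX message buffers, and posts an
// initial ping receive.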
RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
                         const string remote_name)
    : adapter_(adapter),
      local_name_(local_name),
      remote_name_(remote_name),
      request_serial_(0) {
  struct ibv_sge list;

  mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize,
                   IBV_ACCESS_LOCAL_WRITE);
  CHECK(mr_) << "Failed to register memory region";

  memset(&list, 0, sizeof(list));
  list.addr = (uintptr_t)ping_buff_;
  list.length = kPingBuffSize;
  list.lkey = mr_->lkey;

  ping_sge_list_ = list;
  // Create queue pair
  {
    struct ibv_qp_init_attr attr;
    memset(&attr, 0, sizeof(ibv_qp_init_attr));
    attr.send_cq = adapter_->cq_;
    attr.recv_cq = adapter_->cq_;
    attr.cap.max_send_wr = adapter_->params_.queue_depth;
    attr.cap.max_recv_wr = adapter_->params_.queue_depth;
    attr.cap.max_send_sge = 1;
    attr.cap.max_recv_sge = 1;
    attr.qp_type = IBV_QPT_RC;

    qp_ = ibv_create_qp(adapter_->pd_, &attr);
    CHECK(qp_) << "Failed to create queue pair";
  }

  // Init queue pair
  {
    struct ibv_qp_attr attr;
    memset(&attr, 0, sizeof(ibv_qp_attr));
    attr.qp_state = IBV_QPS_INIT;
    attr.pkey_index = adapter_->params_.pkey_index;
    attr.port_num = adapter_->params_.port_num;
    attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;

    int mask =
        IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
    CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT";
  }

  // Local address
  {
    struct ibv_port_attr attr;
    CHECK(
        !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr))
        << "Query port";
    self_.lid = attr.lid;
    self_.qpn = qp_->qp_num;
    self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
    union ibv_gid gid;
    CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num,
                         adapter_->params_.sgid_index, &gid))
        << "Query gid";
    self_.snp = gid.global.subnet_prefix;
    self_.iid = gid.global.interface_id;
  }

  // create message and ack buffers, then initialize the tables.
  {
    const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer"};
    tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
    rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
    message_buffers_.reserve(kNumMessageBuffers);
    message_buffers_.push_back(tx_message_buffer_);
    message_buffers_.push_back(rx_message_buffer_);
    // create buffer on host
    tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
    rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
  }
  CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_
                             << " with error " << std::strerror(errno);
}

RdmaChannel::~RdmaChannel() {
  ibv_dereg_mr(mr_);
  CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
  delete tx_message_buffer_;
  delete rx_message_buffer_;
}

void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
  mutex_lock lock{mu_};
  if ((override) || (!remote_set_)) {
    remote_.lid = ra.lid;
    remote_.qpn = ra.qpn;
    remote_.psn = ra.psn;
    remote_.snp = ra.snp;
    remote_.iid = ra.iid;
    remote_set_ = true;
  } else {
    CHECK(remote_.lid == ra.lid);
    CHECK(remote_.qpn == ra.qpn);
    CHECK(remote_.psn == ra.psn);
    CHECK(remote_.snp == ra.snp);
    CHECK(remote_.iid == ra.iid);
  }
}

// Post a receive work request on the queue pair.
// A posted receive is required for each future incoming message.
void RdmaChannel::Recv() {
  struct ibv_recv_wr wr;
  memset(&wr, 0, sizeof(wr));
  wr.wr_id = (uint64_t)this;
  struct ibv_recv_wr* bad_wr;
  CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}

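// Create a new tensor request and insert it into the channel's request table
// under the next request index (wrapping at RDMA_IMM_MAX_REQUEST_ID).
// Returns a pointer to the request stored in the table.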
RdmaTensorRequest* RdmaChannel::InsertTensorRequest(
    const string& key, int64 step_id, Device* dst_dev,
    const Rendezvous::Args recv_args,
    const RdmaTensorRequest::RecvDoneCallback& done) {
  mutex_lock lock{ct_mu_};
  uint32_t request_index = request_serial_++;
  if (request_serial_ > RDMA_IMM_MAX_REQUEST_ID) {
    request_serial_ = 0;
  }
  RdmaTensorRequest request(request_index, key, step_id, this, dst_dev,
                            recv_args, done);
  auto it = request_table_.emplace(request_index, request);
  return &it.first->second;
}

void RdmaChannel::RemoveTensorRequest(uint32_t request_index) {
  mutex_lock lock{ct_mu_};
  request_table_.erase(request_index);
}

RdmaTensorRequest* RdmaChannel::GetTensorRequest(uint32_t request_index) {
  mutex_lock lock{ct_mu_};
  RequestTable::iterator iter = request_table_.find(request_index);
  CHECK(iter != request_table_.end());
  return &iter->second;
}

void RdmaChannel::Connect() {
  {
    mutex_lock lock{mu_};
    CHECK(remote_set_) << "remote channel is not set";
  }
  Connect(remote_);
}

// Setup channel to a remote node
// Args:
//   remoteAddr: the rdma address of a remote channel.
// Returns:
//   None
void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
  mutex_lock lock{mu_};
  if (!connected_) {
    struct ibv_qp_attr attr;
    memset(&attr, 0, sizeof(ibv_qp_attr));
    attr.qp_state = IBV_QPS_RTR;

    // This assumes both QP's ports are configured with the same MTU
    attr.path_mtu = adapter_->params_.mtu;
    attr.dest_qp_num = remoteAddr.qpn;
    attr.rq_psn = remoteAddr.psn;
    attr.max_dest_rd_atomic = 1;
    attr.min_rnr_timer = 12;
    attr.ah_attr.is_global = 1;
    attr.ah_attr.grh.dgid.global.subnet_prefix = remoteAddr.snp;
    attr.ah_attr.grh.dgid.global.interface_id = remoteAddr.iid;
    attr.ah_attr.grh.flow_label = 0;
    attr.ah_attr.grh.hop_limit = 255;
    attr.ah_attr.dlid = remoteAddr.lid;
    attr.ah_attr.sl = adapter_->params_.sl;
    attr.ah_attr.src_path_bits = 0;
    attr.ah_attr.port_num = adapter_->params_.port_num;
    attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index;
    attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;

    int r;
    CHECK(!(r = ibv_modify_qp(qp_, &attr,
                              IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
                                  IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
                                  IBV_QP_MAX_DEST_RD_ATOMIC |
                                  IBV_QP_MIN_RNR_TIMER)))
        << "QP to Ready to Receive " << r;

    memset(&attr, 0, sizeof(ibv_qp_attr));
    attr.qp_state = IBV_QPS_RTS;
    attr.sq_psn = self_.psn;
    attr.timeout = adapter_->params_.timeout;
    attr.retry_cnt = adapter_->params_.retry_cnt;
    attr.rnr_retry = 7; /* infinite */
    attr.max_rd_atomic = 1;

    CHECK(!(r = ibv_modify_qp(qp_, &attr,
                              IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
                                  IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
                                  IBV_QP_MAX_QP_RD_ATOMIC)))
        << "QP to Ready to Send " << r;

    connected_ = true;
  } else {
    RDMA_LOG(2) << "channel already connected";
  }
}

RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
    : channel_(channel), name_(name) {}

RdmaMessageBuffer::~RdmaMessageBuffer() {
  CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
  FreeBuffer();
}

void RdmaMessageBuffer::FreeBuffer() {
  if ((buffer_ != nullptr) && buffer_on_host_) {
    free(buffer_);
  }
}

// Allocate CPU memory for the RDMA buffer
// Args:
//   size: to-be-allocated memory size
//   lock: whether to acquire the mutex to protect against concurrent access
// Returns:
//   None
void RdmaMessageBuffer::CreateCPUBuffer(size_t size, bool lock) {
  CHECK(size > 0);
  if (lock) {
    mu_.lock();
  }
  if (local_status_ != none) {
    // delete existing buffer
    CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
    FreeBuffer();
  }
  size_ = size;
  buffer_ = malloc(size_);
  self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_,
                     IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
  CHECK(self_) << "Failed to register memory region";
  buffer_on_host_ = true;
  local_status_ = idle;
  if (lock) {
    mu_.unlock();
  }
}

// Set address of remote memory region
// Args:
//   rmr: address of remote memory region
//   override: whether to override existing information
// Returns:
//   None
void RdmaMessageBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
  mutex_lock lock{mu_};
  if ((override) || (remote_status_ == none)) {
    remote_.remote_addr = rmr.remote_addr;
    remote_.rkey = rmr.rkey;
    remote_status_ = idle;
  } else {
    CHECK(remote_.remote_addr == rmr.remote_addr);
    CHECK(remote_.rkey == rmr.rkey);
  }
}

// Put a task in the buffer's job queue
void RdmaMessageBuffer::EnqueueItem(string item) {
  mutex_lock lock{mu_};
  queue_.push(item);
}

// Rdma-Write the content of the buffer
void RdmaMessageBuffer::Write(uint32_t imm_data, size_t buffer_size) {
  Write(channel_, imm_data, buffer_size, (uint64_t)buffer_, self_->lkey,
        remote_.remote_addr, remote_.rkey, RDMA_WRITE_ID_MESSAGE, this);
}

// Generalized Write method
void RdmaMessageBuffer::Write(const RdmaChannel* channel, uint32_t imm_data,
                              size_t buffer_size, uint64_t src_addr,
                              uint32_t lkey, uint64_t remote_addr,
                              uint32_t rkey, RdmaWriteIDType write_type,
                              void* write_context) {
  struct ibv_sge list;
  list.addr = src_addr;
  list.length = buffer_size;
  list.lkey = lkey;

  struct ibv_send_wr wr;
  memset(&wr, 0, sizeof(wr));
  wr.wr_id = (uint64_t) new RdmaWriteID(write_type, write_context);
  wr.sg_list = &list;
  wr.num_sge = 1;
  wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
  wr.send_flags = IBV_SEND_SIGNALED;
  wr.imm_data = imm_data;
  wr.wr.rdma.remote_addr = remote_addr;
  wr.wr.rdma.rkey = rkey;

  struct ibv_send_wr* bad_wr;
  CHECK(!ibv_post_send(channel->qp_, &wr, &bad_wr)) << "Failed to post send";
}

// Send an ACK over the given channel (a zero-length write carrying the ACK
// immediate value).
void RdmaMessageBuffer::SendAck(const RdmaChannel* channel) {
  Write(channel, RDMA_IMM_DATA_ACK, 0, 0, 0, 0, 0, RDMA_WRITE_ID_ACK, nullptr);
}

// Send the next message from the buffer's job queue.
void RdmaMessageBuffer::SendNextItem() {
  uint32_t imm_data = RDMA_IMM_DATA_MESSAGE;
  mu_.lock();
  if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
    local_status_ = busy;
    remote_status_ = busy;
    string message = queue_.front();
    queue_.pop();
    // local/remote_status_ won't be set back to idle
    // until Write() is successful
    mu_.unlock();
    memcpy(buffer_, message.data(), message.size());
    Write(imm_data, message.size());
  } else {
    mu_.unlock();
  }
}

#if GOOGLE_CUDA
static void CountCopies(const std::string& key, void* src_addr, void* dst_addr,
                        size_t tensor_bytes, bool is_gpu_to_cpu) {
#ifdef RDMA_COUNT_COPIES
  static uint64_t numGPUToCPUCopies = 0;
  static uint64_t numGPUToCPUCopiedBytes = 0;
  static uint64_t numCPUToGPUCopies = 0;
  static uint64_t numCPUToGPUCopiedBytes = 0;
  static uint64_t numTotalCopies = 0;

  if (is_gpu_to_cpu) {
    ++numGPUToCPUCopies;
    numGPUToCPUCopiedBytes += tensor_bytes;
  } else {
    ++numCPUToGPUCopies;
    numCPUToGPUCopiedBytes += tensor_bytes;
  }
  if ((++numTotalCopies % 0x400) == 0) {
    RDMA_LOG(0) << "Tensor copies:"
                << " GPU to CPU: " << numGPUToCPUCopies << " ("
                << numGPUToCPUCopiedBytes << " Bytes)"
                << " CPU to GPU: " << numCPUToGPUCopies << " ("
                << numCPUToGPUCopiedBytes << " Bytes)";
  }
  RDMA_LOG(2) << "Copying tensor " << key << " From: " << src_addr
              << " To: " << dst_addr;
#endif  // RDMA_COUNT_COPIES
}
#endif  // GOOGLE_CUDA

#ifdef RDMA_DATA_VALIDATION
static uint64_t Checksum(Device* device, const DeviceContext* device_context,
                         const Tensor& in) {
  uint64 checksum = 0;
  if (DataTypeCanUseMemcpy(in.dtype())) {
#if GOOGLE_CUDA
    if (in.TotalBytes() == 0) {
      return 0;
    }
    checksum = (device_context != nullptr)
                   ? GPUUtil::Checksum(device, device_context, in)
                   : GPUUtil::Checksum(in);
#endif  // GOOGLE_CUDA
  } else {
    string s = in.SummarizeValue(999999);
    checksum = Hash64(s.c_str(), s.size(), 0);
  }
  return checksum;
}

static void ValidateChecksum(uint64_t expected, uint64_t actual,
                             const Tensor& in, uint32_t request_index,
                             const std::string& key, const std::string& msg) {
  RDMA_LOG(2) << "Request #" << request_index << ": " << key
              << ": Checksum: " << std::hex << " Expected = 0x" << expected
              << ". Actual = 0x" << actual << ".";

  if (expected != actual) {
    // Checksum failed. There is one case where this is allowed - if the
    // tensor is an AssignAdd of the global step. Since the data-validation
    // always postpones the Tensor response in order to send a checksum message,
    // it is possible that the global-step was updated while the response was
    // still in queue.
    if ((in.TotalBytes() == 8) && (in.dtype() == DT_INT64)) {
      int64_t prev_val = *(int64_t*)DMAHelper::base(&in) - 1;
      actual = Hash64((const char*)&prev_val, 8, 0);
    }
    if (expected != actual) {
      LOG(FATAL) << "[" << msg << "]: Checksum validation failed for request #"
                 << request_index << ": " << key << std::hex << " "
                 << DataTypeString(in.dtype()) << " "
                 << in.shape().DebugString() << " (0x" << in.TotalBytes()
                 << " bytes): "
                 << " Expected 0x" << expected << ". Got 0x" << actual << ".";
    }
  }
}
#endif  // RDMA_DATA_VALIDATION

#if GOOGLE_CUDA
// Sync the 'done' operation on the GPU stream, but without all the data
// copying.
static void StreamGPUOp(Device* gpu_device, const DeviceContext* device_context,
                        StatusCallback done) {
  Tensor dummy1, dummy2;
  GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context, &dummy1, &dummy2,
                              done);
}
#endif  // GOOGLE_CUDA

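// Create a tensor response for the given request message and insert it into
// the channel's response table, keyed by the request index.
// Aborts if a response with that index already exists.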
RdmaTensorResponse* RdmaChannel::AddTensorResponse(const RdmaMessage& rm) {
  mutex_lock lock{mu_};
  auto it =
      responses_table_.emplace(rm.request_index_, RdmaTensorResponse(this, rm));
  CHECK(it.second) << "Response with the ID " << rm.request_index_
                   << " already exists.";
  return &it.first->second;
}

RdmaTensorResponse* RdmaChannel::UpdateTensorResponse(const RdmaMessage& rm) {
  mutex_lock lock{mu_};
  auto it = responses_table_.find(rm.request_index_);
  CHECK(it != responses_table_.end()) << "No response found.";
  RdmaTensorResponse* response = &it->second;
  response->Update(rm);
  return response;
}

void RdmaChannel::RemoveTensorResponse(uint32_t request_index) {
  mutex_lock lock{mu_};
  responses_table_.erase(request_index);
}

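// Start the response: parse the rendezvous key and receive the requested
// tensor from the local rendezvous. The result is handled by RecvHandler().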
void RdmaTensorResponse::Start() {
  Rendezvous::ParsedKey parsed;
  Status s = Rendezvous::ParseKey(rm_.name_, &parsed);
  if (!s.ok()) {
    SendErrorStatus(s);
    return;
  }

  channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(
      rm_.step_id_, parsed,
      [this, parsed](const Status& status, const Rendezvous::Args& send_args,
                     const Rendezvous::Args& recv_args, const Tensor& in,
                     bool is_dead) {
        CHECK(status.ok()) << "RecvLocalAsync was not ok."
                           << " error message: " << status.error_message();
        RecvHandler(parsed, send_args, recv_args, in, is_dead);
      });
}

void RdmaTensorResponse::Resume() { SendContent(*tensor_, *proto_, is_dead_); }

// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status RdmaTensorResponse::PrepareRecvTensor(
    const Rendezvous::ParsedKey& parsed, Device** src_dev) {
  // Figures out which device the tensor is hosted on.
  string local_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(channel_->adapter_->worker_env_->device_mgr->LookupDevice(
      local_name, src_dev));

  // Does the device have the right incarnation number we expect?
  if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
    return errors::Aborted(
        "RecvTensor expects a different device incarnation: ",
        parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
        ". Your worker job (\"",
        channel_->adapter_->worker_env_->session_mgr->LegacySession()
            ->worker_name,
        "\") was probably restarted. Check your "
        "worker job for the reason why it was restarted.");
  }

  return Status::OK();
}

void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed,
                                     const Rendezvous::Args& send_args,
                                     const Rendezvous::Args& recv_args,
                                     const Tensor& in, bool is_dead) {
  Status s = PrepareRecvTensor(parsed, &src_dev_);
  if (!s.ok()) {
    SendErrorStatus(s);
    return;
  }

  meta_data_changed_ = TensorMetaDataChanged(in, is_dead);
#ifdef RDMA_DATA_VALIDATION
  // Always send a meta data message with the source checksum
  meta_data_changed_ = rm_.type_ == RDMA_MESSAGE_TENSOR_REQUEST;
  checksum_ = Checksum(src_dev_, send_args.device_context, in);
#endif
  bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
  // string tensor needs to be serialized
  Tensor copy;
  TensorProto proto;
  const bool on_host = send_args.alloc_attrs.on_host();
  if (src_dev_->tensorflow_gpu_device_info() && !on_host) {
#if GOOGLE_CUDA
    DeviceContext* send_dev_context = send_args.device_context;
    CHECK(send_dev_context)
        << "send dev name: " << src_dev_->name()
        << " gpu_info: " << src_dev_->tensorflow_gpu_device_info();

    if (can_memcpy) {
      // If the tensor is located on a GDR compatible GPU, there is no need to
      // copy it. We can send directly from the source, just need to make sure
      // we are in sync with the GPU stream.
      // If the tensor's meta-data changed however, we will need to clone it,
      // so anyway we'll have to copy it from GPU to CPU first. If at some
      // point in time Clone() is changed to only save a shallow copy, we can
      // skip the copy here as well.
      if ((in.TotalBytes() > 0) && !meta_data_changed_ &&
          (RdmaMemoryMgr::Singleton().FindMemoryRegion(
               (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) {
        StreamGPUOp(src_dev_, send_dev_context,
                    [this, in, proto, is_dead](const Status& s) {
                      Send(in, proto, is_dead, s);
                    });
        return;
      }

      // The tensor must be copied from GPU to CPU, because either:
      // 1. The tensor is located on a non GDR compatible GPU.
      // 2. The tensor's meta-data has changed.
      Allocator* alloc = GPUProcessState::singleton()->GetGpuHostAllocator(0);
      copy = Tensor(alloc, in.dtype(), in.shape());
      CountCopies(rm_.name_, (void*)DMAHelper::base(&in),
                  (void*)DMAHelper::base(&copy), in.TotalBytes(), true);
      GPUUtil::CopyGPUTensorToCPU(
          src_dev_, send_dev_context, &in, &copy,
          [this, copy, proto, is_dead](const Status& s) {
            Send(copy, proto, is_dead, s);
          });
    } else {
      GPUUtil::SetProtoFromGPU(
          in, src_dev_, send_args.device_context, &proto, is_dead,
          [this, in, proto, is_dead](const Status& s) mutable {
            Send(in, proto, is_dead, s);
          });
    }
#else
    SendErrorStatus(errors::Internal("No GPU device in process"));
#endif  // GOOGLE_CUDA
  } else {
    // tensor is in CPU memory.
    if (!can_memcpy) {
      in.AsProtoTensorContent(&proto);
    }
    Send(in, proto, is_dead, Status::OK());
  }
}

void RdmaTensorResponse::Send(const Tensor& in, const TensorProto& proto,
                              bool is_dead, const Status& status) {
  if (!status.ok()) {
    SendErrorStatus(status);
    return;
  }
  bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
  bool proto_size_changed =
      (!can_memcpy) && (proto.ByteSize() != rm_.tensor_bytes_);
  if (meta_data_changed_ || proto_size_changed) {
    Clone(in, proto, is_dead);
    SendMetaData(in, proto, is_dead);
  } else {
    SendContent(in, proto, is_dead);
  }
}

bool RdmaTensorResponse::TensorMetaDataChanged(const Tensor& in, bool is_dead) {
  return (rm_.data_type_ != in.dtype()) || (rm_.tensor_shape_ != in.shape()) ||
         (rm_.is_dead_ != is_dead);
}

void RdmaTensorResponse::Clone(const Tensor& in, const TensorProto& proto,
                               bool is_dead) {
  // Clone the data to be sent later. For simplicity, we clone the tensor's
  // data even if it is already a copy. Performance is less of a concern here
  // since the meta-data hardly ever changes. The reason we create a copy is
  // that some tensors share their buffer between different step-ids, so the
  // tensor content may change before the re-request is completed.
  bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
  if (can_memcpy && (in.TotalBytes() > 0)) {
    AllocatorAttributes host_alloc_attrs;
    host_alloc_attrs.set_nic_compatible(true);
    host_alloc_attrs.set_on_host(true);
    Allocator* allocator = src_dev_->GetAllocator(host_alloc_attrs);
    tensor_ = new Tensor(allocator, in.dtype(), in.shape());
    memcpy(DMAHelper::base(tensor_), DMAHelper::base(&in), in.TotalBytes());
  } else {
    tensor_ = new Tensor(in.dtype(), in.shape());
  }
  if (!can_memcpy) {
    proto_ = new TensorProto(proto);
  }
  is_dead_ = is_dead;
}

void RdmaTensorResponse::SendMetaData(const Tensor& in,
                                      const TensorProto& proto, bool is_dead) {
  RDMA_LOG(2) << "Request #" << rm_.request_index_
              << ": Meta data changed: " << rm_.name_;
  bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
  size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();

  // Send meta-data update:
  RdmaMessage rm;
  rm.type_ = RDMA_MESSAGE_META_DATA_UPDATE;
  rm.name_size_ = rm_.name_.size();
  rm.name_ = rm_.name_;
  rm.tensor_shape_ = in.shape();
  rm.data_type_ = in.dtype();
  rm.step_id_ = rm_.step_id_;
  rm.is_dead_ = is_dead;
  rm.tensor_bytes_ = tensor_bytes;
  rm.request_index_ = rm_.request_index_;
#ifdef RDMA_DATA_VALIDATION
  rm.checksum_ = checksum_;
#endif
  RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
              << ": Sending RDMA_MESSAGE_META_DATA_UPDATE #"
              << rm.request_index_ << ": " << rm.name_
              << " (shape = " << rm.tensor_shape_.DebugString() << "."
              << " data-type = " << DataTypeString(rm.data_type_) << "."
              << " is-dead = " << rm.is_dead_ << ")";

  string message = RdmaMessage::CreateMessage(rm);
  channel_->tx_message_buffer_->EnqueueItem(message);
  channel_->tx_message_buffer_->SendNextItem();
}

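// Write the tensor content to the requester's registered memory.
// For memcpy-able tensors the content is written directly from the tensor
// buffer; otherwise the TensorProto is serialized into a freshly registered
// scratch buffer. Dead tensors are sent with zero bytes.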
void RdmaTensorResponse::SendContent(const Tensor& in, const TensorProto& proto,
                                     bool is_dead) {
  bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
  size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();
  uint32_t imm_data = rm_.request_index_;
  if (!is_dead) {
    if (can_memcpy) {
      src_buffer_ = const_cast<TensorBuffer*>(DMAHelper::buffer(&in));
      if (src_buffer_ != nullptr) {
        src_buffer_->Ref();  // Keep buffer alive until write is complete
        src_addr_ = src_buffer_->data();
        mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(src_addr_,
                                                          tensor_bytes);
      }
    } else {
      RDMA_LOG(2) << "Encoding proto: " << rm_.name_
                  << " (Size: " << tensor_bytes << ") " << in.DebugString();
      src_addr_ = malloc(tensor_bytes);
      mr_ = ibv_reg_mr(channel_->adapter_->pd_, src_addr_, tensor_bytes,
                       IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
      proto.SerializeToArray(src_addr_, tensor_bytes);
    }
  } else {
    tensor_bytes = 0;
  }

  uint32_t lkey = (mr_ == nullptr) ? 0 : mr_->lkey;
  RDMA_LOG(1) << "Step 0x" << std::hex << rm_.step_id_ << std::dec
              << ": Sending tensor content #" << rm_.request_index_ << " from "
              << std::hex << src_addr_ << " (0x" << lkey << ")"
              << " to " << rm_.remote_addr_ << " (0x" << rm_.rkey_
              << "): " << rm_.name_ << " (size: 0x" << std::hex << tensor_bytes
              << ")";

  RdmaMessageBuffer::Write(channel_, imm_data, tensor_bytes,
                           (uint64_t)src_addr_, lkey, rm_.remote_addr_,
                           rm_.rkey_, RDMA_WRITE_ID_TENSOR_WRITE, this);
}

void RdmaTensorResponse::SendErrorStatus(const Status& status) {
  RdmaMessage rm;
  rm.type_ = RDMA_MESSAGE_ERROR_STATUS;
  rm.name_size_ = rm_.name_.size();
  rm.name_ = rm_.name_;
  rm.step_id_ = rm_.step_id_;
  rm.request_index_ = rm_.request_index_;
  rm.status_ = status;
  LOG(ERROR) << "Step 0x" << std::hex << rm.step_id_ << std::dec
             << ": Sending RDMA_MESSAGE_ERROR_STATUS #" << rm.request_index_
             << ": " << rm.name_ << ". Status: " << status.ToString();

  string message = RdmaMessage::CreateMessage(rm);
  channel_->tx_message_buffer_->EnqueueItem(message);
  channel_->tx_message_buffer_->SendNextItem();

  // Destroy the response.
  Destroy();
}

void RdmaTensorResponse::Destroy() {
  if (src_buffer_ != nullptr) {
    src_buffer_->Unref();
  }
  if (tensor_ != nullptr) {
    delete tensor_;
  }
  if (proto_ != nullptr) {
    ibv_dereg_mr(mr_);
    free(src_addr_);
    delete proto_;
  }
  // Remove response from the pending list:
  channel_->RemoveTensorResponse(rm_.request_index_);
}

// Create a RdmaMessage according to the pre-defined format
// Args:
//   rm: the message structure
// Returns:
//   message in string format
string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
  // Rdma Message format
  // type|name_size|name|step_id|request_index|remote_addr|rkey|is_dead|...
  //   1B|    2B   | 512|   8B  |      8B     |     8B    | 4B |   1B  |...
  // ...|data_type|tensor_shape|tensor_bytes|error_status          |
  // ...|    XB   |     XB     |     8B     |size - 4B, proto - XB |
  //
  // ACK:               Imm-type: ACK
  // TENSOR_REQUEST:    Imm-type: MESSAGE
  //    Fields: type, request_index, name, step_id, remote_addr,
  //            rkey, is_dead, data_type, tensor_shape, tensor_bytes
  // META_DATA_UPDATE:  Imm-type: MESSAGE
  //    Fields: type, request_index, is_dead, data_type,
  //            tensor_shape, tensor_bytes
  // TENSOR_RE_REQUEST: Imm-type: MESSAGE
  //    Fields: type, request_index, name, step_id, remote_addr,
  //            rkey, is_dead, data_type, tensor_shape, tensor_bytes
  // ERROR_STATUS:      Imm-type: MESSAGE
  //    Fields: type, request_index, name, step_id, error_status
  // Tensor content:    Imm-type: request_index
  size_t message_size = kMessageTotalBytes;
  char message[kMessageTotalBytes + kErrorStatusMaxSize];
  // type
  message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
  // request index
  memcpy(&message[kRequestIndexStartIndex], &rm.request_index_,
         sizeof(rm.request_index_));
  // name, step_id, remote_addr, rkey
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&message[kNameSizeStartIndex], &rm.name_size_,
           sizeof(rm.name_size_));
    memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
    memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
           sizeof(rm.remote_addr_));
    memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
    memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
  }
  // is_dead, data_type, tensor_shape, tensor_bytes
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));

    memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
           sizeof(rm.data_type_));
    memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
           sizeof(rm.tensor_shape_));
    memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
           sizeof(rm.tensor_bytes_));
  }
  // checksum
#ifdef RDMA_DATA_VALIDATION
  memcpy(&message[kChecksumStartIndex], &rm.checksum_, sizeof(rm.checksum_));
#endif
  // error status
  if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
    ::grpc::Status gs = ToGrpcStatus(rm.status_);
    ErrorStatusProto gsProto;
    gsProto.set_error_code(gs.error_code());
    gsProto.set_error_message(gs.error_message());
    gsProto.set_error_details(gs.error_details());
    uint32_t gsProtoSize = gsProto.ByteSize();
    if (gsProtoSize + 4 > kErrorStatusMaxSize) {
      LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) "
                 << "is too big to fit in RDMA message (" << kErrorStatusMaxSize
                 << " bytes). Truncated.";
      gsProtoSize = kErrorStatusMaxSize - 4;
    }
    uint32_t* proto_size = (uint32_t*)&message[kErrorStatusStartIndex];
    *proto_size = gsProtoSize;
    gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4], gsProtoSize);
    message_size += gsProtoSize + 4;
  }
  return string(message, message_size);
}

// Parse a RdmaMessage according to the pre-defined format
// Args:
//   rm: the message structure where the parsed message will be saved
//   buffer: the place where the raw message is stored
// Returns:
//   None
void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
  char* message = static_cast<char*>(buffer);
  // type
  rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]);
  // request index
  memcpy(&rm.request_index_, &message[kRequestIndexStartIndex],
         sizeof(rm.request_index_));
  // name, step_id, remote_addr, rkey
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&rm.name_size_, &message[kNameSizeStartIndex],
           sizeof(rm.name_size_));
    rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
    memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
           sizeof(rm.remote_addr_));
    memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
    memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
  }
  // data_type, tensor_bytes, tensor_shape, is_dead
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
    memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
           sizeof(rm.data_type_));
    memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex],
           sizeof(rm.tensor_shape_));
    memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
           sizeof(rm.tensor_bytes_));
  }
  // checksum
#ifdef RDMA_DATA_VALIDATION
  memcpy(&rm.checksum_, &message[kChecksumStartIndex], sizeof(rm.checksum_));
#endif
  // error status
  if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
    ErrorStatusProto gsProto;
    uint32_t gsProtoSize = *(uint32_t*)&message[kErrorStatusStartIndex];
    CHECK(ParseProtoUnlimited(&gsProto, &message[kErrorStatusStartIndex + 4],
                              gsProtoSize))
        << "Failed to parse error status proto from message. Aborting.";
    ::grpc::Status gs((::grpc::StatusCode)gsProto.error_code(),
                      gsProto.error_message(), gsProto.error_details());
    rm.status_ = FromGrpcStatus(gs);
  }
}

//*****************************************************************************
// RdmaMemoryMgr
//*****************************************************************************

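// Look up a registered memory region covering the given address.
// Returns nullptr if the address is not covered by any registered region.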
ibv_mr* RdmaMemoryMgr::FindMemoryRegion(void* addr, size_t length) {
  mutex_lock l(mrs_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter == std::end(mrs_) || iter->get()->addr > addr) {
    return nullptr;
  } else {
    return iter->get();
  }
}

void RdmaMemoryMgr::InsertMemoryRegion(void* addr, size_t length,
                                       const std::string& allocator_name) {
  if (length == 0) return;
  ibv_mr* mr = ibv_reg_mr(pd_, addr, length,
                          IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
  if (mr != nullptr) {
    // Log only after a successful registration, so mr is never dereferenced
    // when the registration failed.
    RDMA_LOG(1) << "Insert memory region 0x" << std::hex << mr->rkey << ". ["
                << addr << "-" << (void*)((uint64_t)addr + length - 1) << "]"
                << " SIZE: 0x" << length << " (" << allocator_name << ").";
    mutex_lock l(mrs_mu_);
    auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
    mrs_.insert(iter, {mr, &MRDeleter});
  } else {
    LOG(WARNING) << "Cannot register memory region";
  }
}

void RdmaMemoryMgr::EvictMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  mutex_lock l(mrs_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter != std::end(mrs_) && iter->get()->addr == addr) {
    // Log before erasing; the iterator is invalid after the erase.
    RDMA_LOG(1) << "Evict memory region 0x" << std::hex << iter->get()->rkey;
    mrs_.erase(iter);
  } else {
    LOG(WARNING) << "Failed to de-register memory region";
  }
}

const TensorMetaData* RdmaMemoryMgr::GetTensorMetaData(
    const std::string& tensor_name) {
  mutex_lock l(tensor_meta_data_mu_);
  auto it = tensors_meta_data_.find(tensor_name);
  if (it == tensors_meta_data_.end()) {
    return nullptr;
  }
  return &it->second;
}

const TensorMetaData* RdmaMemoryMgr::SetTensorMetaData(
    const std::string& tensor_name, DataType dtype, const TensorShape& shape,
    bool is_dead, size_t proto_size) {
  mutex_lock l(tensor_meta_data_mu_);
  TensorMetaData& meta_data = tensors_meta_data_[tensor_name];
  meta_data.data_type_ = dtype;
  meta_data.tensor_shape_ = shape;
  meta_data.proto_size_ = proto_size;
  meta_data.is_dead_ = is_dead;
  return &meta_data;
}

//*****************************************************************************
// RdmaTensorRequest
//*****************************************************************************

RdmaTensorRequest::RdmaTensorRequest(
    uint32_t index, const string& key, int64 step_id, RdmaChannel* channel,
    Device* dst_dev, const Rendezvous::Args recv_args,
    const RdmaTensorRequest::RecvDoneCallback& done)
    : index_(index),
      key_(key),
      step_id_(step_id),
      channel_(channel),
      dst_dev_(dst_dev),
      recv_args_(recv_args),
      meta_data_(RdmaMemoryMgr::Singleton().GetTensorMetaData(key)),
      result_tensor_(nullptr),
      proxy_tensor_(nullptr),
      rdma_addr_(nullptr),
      mr_(nullptr),
      done_(done) {}

RdmaTensorRequest::~RdmaTensorRequest() { DeallocateTensors(); }

void RdmaTensorRequest::Done(const Status& s) {
  Tensor val = std::move(*result_tensor_);

#ifdef RDMA_DATA_VALIDATION
  // Validate checksum
  // Unfortunately we can't always do a Checksum directly on the result tensor.
  // If the result tensor is on GPU, then we need to copy it back to CPU. If
  // we happen to be in the midst of a proxy callback, then the copying will
  // get stuck.
  uint64_t checksum = (proxy_tensor_ != nullptr)
                          ? Checksum(nullptr, nullptr, *proxy_tensor_)
                          : Checksum(dst_dev_, recv_args_.device_context, val);
  ValidateChecksum(checksum_, checksum, val, index_, key_, "RDMA");
#endif

  Rendezvous::Args recv_args = std::move(recv_args_);
  bool is_dead = (meta_data_ == nullptr) ? false : meta_data_->is_dead_;
  RecvDoneCallback done = done_;
  DeallocateTensors();
  channel_->RemoveTensorRequest(index_);
  done(s, Rendezvous::Args(), recv_args, val, is_dead);
}

void RdmaTensorRequest::DeallocateTensors() {
  if (result_tensor_ != nullptr) {
    delete result_tensor_;
    result_tensor_ = nullptr;
  }
  if (proxy_tensor_ != nullptr) {
    delete proxy_tensor_;
    proxy_tensor_ = nullptr;
  }
}

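// Allocate the result tensor (and, if needed, a CPU proxy tensor or a proto
// scratch buffer) according to the known meta-data, and look up / register
// the memory region that the remote side will RDMA-write into.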
bool RdmaTensorRequest::AllocateTensors() {
  result_tensor_ =
      new Tensor(dst_dev_->GetAllocator(recv_args_.alloc_attrs),
                 meta_data_->data_type_, meta_data_->tensor_shape_);

  size_t tensor_size = result_tensor_->TotalBytes();
  bool can_memcpy = DataTypeCanUseMemcpy(result_tensor_->dtype());
  if (can_memcpy) {
    if (tensor_size == 0) {
      return true;
    }
    rdma_addr_ = DMAHelper::base(result_tensor_);
    mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
#if GOOGLE_CUDA
    if (mr_ == nullptr) {
      // Can't RDMA directly to result. Use a proxy.
      proxy_tensor_ =
          new Tensor(GPUProcessState::singleton()->GetGpuHostAllocator(0),
                     result_tensor_->dtype(), result_tensor_->shape());
      rdma_addr_ = DMAHelper::base(proxy_tensor_);
      mr_ =
          RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
    }
#endif
  } else {
    uint32_t proto_size = meta_data_->proto_size_;
    rdma_addr_ = malloc(proto_size);
    mr_ = ibv_reg_mr(RdmaMemoryMgr::Singleton().pd_, rdma_addr_, proto_size,
                     IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
  }
  CHECK(mr_ != nullptr) << " No memory region found for address " << rdma_addr_
                        << ": " << key_;
  return true;
}

void RdmaTensorRequest::AllocateTensorsAsync(StatusCallback done) {
  AllocateTensors();
  bool on_host = recv_args_.alloc_attrs.on_host();
  if (dst_dev_->tensorflow_gpu_device_info() && !on_host &&
      (proxy_tensor_ == nullptr)) {
#if GOOGLE_CUDA
    // We need to sync the memory allocation on the GPU:
    StreamGPUOp(dst_dev_, recv_args_.device_context, done);
#endif
  } else {
    done(Status::OK());
  }
}

void RdmaTensorRequest::Send(RdmaMessageType message_type) {
  RdmaMessageBuffer* rb = channel_->tx_message_buffer_;
  RdmaMessage rm;
  rm.type_ = message_type;
  rm.request_index_ = index_;
  rm.name_size_ = key_.size();
  rm.name_ = key_;
  rm.step_id_ = step_id_;
  rm.remote_addr_ = (uint64_t)rdma_addr_;
  if (meta_data_ != nullptr) {
    rm.data_type_ = meta_data_->data_type_;
    rm.tensor_shape_ = meta_data_->tensor_shape_;
    rm.is_dead_ = meta_data_->is_dead_;
    rm.tensor_bytes_ = meta_data_->proto_size_;
  } else {
    rm.data_type_ = DT_INVALID;
  }
  rm.rkey_ = (mr_ == nullptr) ? 0 : mr_->rkey;

  RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
              << ": Sending " << MessageTypeToString(message_type) << " #"
              << index_ << ": " << rm.name_ << " on " << rdma_addr_
              << " (rkey: 0x" << std::hex << rm.rkey_ << ")";

  string message = RdmaMessage::CreateMessage(rm);
  rb->EnqueueItem(message);
  rb->SendNextItem();
}

void RdmaTensorRequest::RecvTensorMetaData(DataType dtype, TensorShape shape,
                                           bool is_dead, size_t proto_size) {
  meta_data_ = RdmaMemoryMgr::Singleton().SetTensorMetaData(
      key_, dtype, shape, is_dead, proto_size);

  DeallocateTensors();
  AllocateTensorsAsync(
      [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_RE_REQUEST); });
}

void RdmaTensorRequest::RecvTensorContent() {
  bool can_memcpy = DataTypeCanUseMemcpy(meta_data_->data_type_);
  size_t message_size =
      can_memcpy ? result_tensor_->TotalBytes() : meta_data_->proto_size_;
  RDMA_LOG(1) << "Step 0x" << std::hex << step_id_ << std::dec
              << ": Received tensor content #" << index_ << ": " << key_
              << " (Size: 0x" << std::hex << message_size << ")";

  Tensor val;

#if GOOGLE_CUDA
  if (proxy_tensor_ != nullptr) {
    CountCopies(key_, (void*)DMAHelper::base(proxy_tensor_),
                (void*)DMAHelper::base(result_tensor_),
                result_tensor_->TotalBytes(), false);
    GPUUtil::CopyCPUTensorToGPU(proxy_tensor_, recv_args_.device_context,
                                dst_dev_, result_tensor_,
                                [this](const Status& s) {
                                  CHECK(s.ok()) << "copy tensor to gpu sync";
                                  Done(s);
                                });
    return;
  }
#endif

  if (can_memcpy) {
    Done(Status::OK());
  } else {
    RDMA_LOG(2) << "Decoding proto: " << key_
                << " (Size: " << meta_data_->proto_size_ << ")";
    TensorProto proto;
    CHECK(ParseProtoUnlimited(&proto, rdma_addr_, meta_data_->proto_size_))
        << "fail to parse proto from array";
    ibv_dereg_mr(mr_);
    free(rdma_addr_);
    Status s = dst_dev_->MakeTensorFromProto(proto, recv_args_.alloc_attrs,
                                             result_tensor_);
    Done(s);
  }
}

void RdmaTensorRequest::RecvErrorStatus(const Status& status) {
  if (result_tensor_ == nullptr) {
    result_tensor_ = new Tensor();
  }
  LOG(ERROR) << "Received RDMA_MESSAGE_ERROR_STATUS: " << status.ToString();
  Done(status);
}

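// Start the request: if the tensor's meta-data is already known locally,
// allocate the receive buffers first and then send the request; otherwise
// send the request immediately and wait for a META_DATA_UPDATE from the
// responder before allocating and re-requesting.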
void RdmaTensorRequest::Start() {
  meta_data_ = RdmaMemoryMgr::Singleton().GetTensorMetaData(key_);
  if (meta_data_ != nullptr) {
    AllocateTensorsAsync(
        [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_REQUEST); });
  } else {
    Send(RDMA_MESSAGE_TENSOR_REQUEST);
  }
}

}  // end namespace tensorflow

#endif