/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_GDR

#include "tensorflow/contrib/gdr/gdr_memory_manager.h"

#include <atomic>
#include <cerrno>
#include <fstream>
#include <list>
#include <map>
#include <set>

#include <fcntl.h>
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
#include <sys/epoll.h>

#include "tensorflow/contrib/gdr/gdr.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif  // GOOGLE_CUDA
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

namespace {

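// Returns true when the nv_peer_mem kernel module is loaded (checked via
// /proc/modules), i.e. GPUDirect RDMA is available. Always false on macOS
// and Windows.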
bool IsGDRAvailable() {
#if defined(__APPLE__)
  return false;
#elif defined(PLATFORM_WINDOWS)
  return false;
#else
  std::ifstream ifs("/proc/modules");
  string line;
  while (std::getline(ifs, line)) {
    auto sep = line.find(' ');
    CHECK_NE(sep, std::string::npos);
    if (line.substr(0, sep) == "nv_peer_mem") {
      return true;
    }
  }
  return false;
#endif
}

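// Best-effort read of the RDMA device's NUMA node from sysfs
// (<ibdev_path>/device/numa_node). Returns 0 on platforms without NUMA
// information (and for negative sysfs values), or -1 if the value cannot
// be parsed.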
int TryToReadNumaNode(ibv_device* device) {
#if defined(__APPLE__)
  LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
  return 0;
#elif defined(PLATFORM_WINDOWS)
  // Windows support for NUMA is not currently implemented. Return node 0.
  return 0;
#else
  VLOG(2) << "Trying to read NUMA node for device: " << device->name;
  static const int kUnknownNumaNode = -1;

  auto filename = string(device->ibdev_path) + "/device/numa_node";

  std::ifstream ifs(filename.c_str());
  string content;
  CHECK(std::getline(ifs, content));

  int32 value;
  if (strings::safe_strto32(content, &value)) {
    if (value < 0) {
      LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
                << value
                << "), but there must be at least one NUMA node"
                   ", so returning NUMA node zero";
      return 0;
    }
    LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
    return value;
  }
  return kUnknownNumaNode;
#endif
}

void EndpointDeleter(rdma_cm_id* id) {
  if (id) {
    rdma_destroy_ep(id);
  }
}

void MRDeleter(ibv_mr* mr) {
  if (mr) {
    rdma_dereg_mr(mr);
  }
}

using RdmaEndpointPtr = std::unique_ptr<rdma_cm_id, decltype(&EndpointDeleter)>;

using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>;

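// GdrMemoryManager registers allocator-backed host (and, with GPUDirect,
// GPU) memory with the RDMA NIC and exchanges tensors by one-sided RDMA
// reads. The serving side packs a RemoteMemoryRegion (addr, rkey, tensor key)
// into the transport options and keeps a reference on the tensor buffer; the
// receiving side reads the buffer remotely and acknowledges with the tensor
// key so the serving side can drop its reference.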
class GdrMemoryManager : public RemoteMemoryManager {
 public:
  GdrMemoryManager(const string& host, const string& port);

  virtual ~GdrMemoryManager();

  virtual Status Init() override;

  virtual void Run() override;

  virtual void Stop() override;

  virtual void TransportOptionsFromTensor(
      ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
      Device* device, DeviceContext* device_context, bool on_host,
      StatusCallback done) override;

  virtual void TensorFromTransportOptions(
      Tensor* tensor, const ::google::protobuf::Any& transport_options,
      Device* device, DeviceContext* device_context, bool on_host,
      StatusCallback done) override;

 protected:
  Status CreateEndpoint(const string& host, const string& port,
                        RdmaEndpointPtr& endpoint);

  static bool Comparator(const void* ptr, const MemoryRegionPtr& other) {
    return ptr < reinterpret_cast<char*>(other->addr) + other->length;
  }

  ibv_mr* FindMemoryRegion(void* addr, size_t length);

  void InsertMemoryRegion(void* addr, size_t length);

  void EvictMemoryRegion(void* addr, size_t length);

 private:
  const string host_;
  const string port_;
  RdmaEndpointPtr listening_;
  std::atomic<bool> stopped_;
  int epfd_;

  // Server side endpoints
  // Accessed sequentially in Run() so not protected by lock
  std::list<RdmaEndpointPtr> server_clients_;

  using TensorKey = uint32_t;
  std::atomic<TensorKey> next_key_;

  // Server side on-the-fly tensor buffers
  mutex server_mu_;
  std::map<TensorKey, const TensorBuffer*> tensor_buffers_
      GUARDED_BY(server_mu_);

  // Client side endpoints
  mutex client_mu_;
  std::map<std::pair<string, string>, RdmaEndpointPtr> clients_
      GUARDED_BY(client_mu_);

  // Managed memory regions
  mutex alloc_mu_;
  std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(alloc_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager);
};

// TODO(byronyi): remove this class duplicated from the one in
// common/runtime/gpu/pool_allocator.h when it is available in common_runtime
class BasicCPUAllocator : public SubAllocator {
 public:
  ~BasicCPUAllocator() override {}

  void* Alloc(size_t alignment, size_t num_bytes) override {
    return port::AlignedMalloc(num_bytes, alignment);
  }
  void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
};

// TODO(byronyi): remove this class and its registration when the default
// cpu_allocator() returns visitable allocator
class BFCRdmaAllocator : public BFCAllocator {
 public:
  BFCRdmaAllocator()
      : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
  }
};

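// Registered at a priority above the default CPU allocator so that host
// tensor buffers are carved out of this visitable BFC pool, whose allocations
// can be observed (and pinned for RDMA) via the visitor interface.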
REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);

GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
    : host_(host),
      port_(port),
      listening_(nullptr, EndpointDeleter),
      stopped_(true),
      epfd_(-1),
      next_key_(0) {}

GdrMemoryManager::~GdrMemoryManager() {
  if (epfd_ >= 0) {
    close(epfd_);
  }
}

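// Creates the epoll descriptor, binds and listens on a passive RDMA endpoint,
// and instruments the host (and, when available, CUDA host/GPU) allocators so
// that their allocations are registered with the NIC as they are made.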
Status GdrMemoryManager::Init() {
  epfd_ = epoll_create1(0);
  if (epfd_ == -1) {
    return errors::Unavailable(strerror(errno), ": ", "epoll_create");
  }

  rdma_addrinfo* addrinfo;
  rdma_addrinfo hints = {};
  hints.ai_port_space = RDMA_PS_TCP;
  hints.ai_flags = RAI_PASSIVE;
  if (rdma_getaddrinfo(const_cast<char*>(host_.c_str()),
                       const_cast<char*>(port_.c_str()), &hints, &addrinfo)) {
    return errors::Unavailable(strerror(errno), ": ", "cannot resolve rdma://",
                               host_, ":", port_);
  }

  ibv_qp_init_attr init_attr = {};
  init_attr.qp_type = IBV_QPT_RC;
  init_attr.cap.max_recv_wr = 32;
  init_attr.cap.max_send_wr = 1;
  init_attr.cap.max_recv_sge = 1;
  init_attr.cap.max_send_sge = 1;

  // Create listening endpoint
  rdma_cm_id* id;
  if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
    return errors::Unavailable(strerror(errno), ": ", "cannot bind to rdma://",
                               host_, ":", port_);
  }
  listening_.reset(id);
  rdma_freeaddrinfo(addrinfo);

  // Listen without backlog
  if (rdma_listen(listening_.get(), 0)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot listen on rdma://", host_, ":", port_);
  }
  LOG(INFO) << "RDMA server is listening on " << host_ << ":" << port_;

  if (listening_->verbs == nullptr) {
    return errors::Unimplemented(
        "Unsupported address ", host_, ":", port_,
        " as it does not bind to a particular RDMA device");
  }

  int flags = fcntl(listening_->channel->fd, F_GETFL, 0);
  if (fcntl(listening_->channel->fd, F_SETFL, flags | O_NONBLOCK)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot set server to non-blocking mode");
  }

  epoll_event event = {};
  event.events = EPOLLIN | EPOLLPRI;
  event.data.ptr = listening_.get();
  if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot add server to epoll");
  }

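  // Attach alloc/free visitors so that memory handed out by these allocators
  // is registered with (and deregistered from) the RDMA device automatically.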
  Allocator* allocators[] = {
#if GOOGLE_CUDA
      ProcessState::singleton()->GetCUDAHostAllocator(0),
      ProcessState::singleton()->GetCPUAllocator(0),
#endif  // GOOGLE_CUDA
      cpu_allocator(),
  };

  using namespace std::placeholders;
  VisitableAllocator::Visitor alloc_visitor =
      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
  VisitableAllocator::Visitor free_visitor =
      std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2);

  std::set<Allocator*> instrumented_;

  // Host memory allocators
  for (Allocator* allocator : allocators) {
    auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
    CHECK(visitable_allocator) << "Allocator " << allocator->Name()
                               << " is not visitable for instrumentation";
    // Make sure we don't instrument the same allocator twice
    if (instrumented_.find(allocator) == std::end(instrumented_)) {
      visitable_allocator->AddAllocVisitor(alloc_visitor);
      visitable_allocator->AddFreeVisitor(free_visitor);
      instrumented_.insert(allocator);
      LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
    }
  }

#if GOOGLE_CUDA
  VisitableAllocator::Visitor cuda_alloc_visitor =
      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
  if (IsGDRAvailable()) {
    // Note we don't free allocated GPU memory so there is no free visitor
    int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1;
    ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
    LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
  }
#endif  // GOOGLE_CUDA

  return Status::OK();
}

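// Server-side event loop: multiplexes the listening channel and the
// per-client completion channels on a single epoll descriptor, accepts new
// connections, and releases tensor buffers when a client acknowledges a
// completed read via the tensor key carried in the immediate data.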
void GdrMemoryManager::Run() {
  stopped_ = false;
  while (!stopped_) {
    epoll_event events[32];
    int ret = epoll_wait(epfd_, events, 32, 1);
    if (ret == -1) {
      LOG(ERROR) << "epoll_wait: " << strerror(errno);
      return;
    }
    for (int i = 0; i < ret; i++) {
      rdma_cm_id* id = static_cast<rdma_cm_id*>(events[i].data.ptr);
      if (id == listening_.get()) {
        // Accept incoming connections
        if (!rdma_get_request(listening_.get(), &id)) {
          if (!rdma_accept(id, nullptr)) {
            LOG(INFO) << "Accepted new RDMA connection";
            if (ibv_req_notify_cq(id->recv_cq, 0)) {
              LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
              EndpointDeleter(id);
              continue;
            }
            bool recv_posted = true;
            for (int j = 0; j < 32; j++) {
              if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
                LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed";
                EndpointDeleter(id);
                recv_posted = false;
                break;
              }
            }
            if (!recv_posted) {
              continue;
            }
            int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0);
            if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags | O_NONBLOCK)) {
              LOG(ERROR) << strerror(errno)
                         << ": cannot set server_client to non-blocking mode";
              EndpointDeleter(id);
              continue;
            }
            epoll_event event = {};
            event.events = EPOLLIN | EPOLLPRI;
            event.data.ptr = id;
            if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd,
                          &event)) {
              LOG(ERROR) << strerror(errno)
                         << ": cannot add server client to epoll";
              EndpointDeleter(id);
              continue;
            }
            server_clients_.push_back({id, EndpointDeleter});
          }
        }
      } else {
        // Polling work completions
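        // Each completed RECV carries, as immediate data, the tensor key of a
        // buffer the remote side has finished reading; drop the server-side
        // reference and re-post the RECV.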
        ibv_cq* cq;
        void* context;
        if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) {
          ibv_ack_cq_events(id->recv_cq, 1);
          if (ibv_req_notify_cq(id->recv_cq, 0)) {
            LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
            continue;
          }
          ibv_wc wc[32];
          int ret = ibv_poll_cq(id->recv_cq, 32, wc);
          if (ret < 0) {
            LOG(ERROR) << "ibv_poll_cq failed";
            continue;
          }
          for (int i = 0; i < ret; i++) {
            if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) {
              LOG(ERROR) << "Received unknown operation " << wc[i].opcode;
            }
            if (wc[i].status != 0) {
              LOG(ERROR) << ibv_wc_status_str(wc[i].status);
            }
            TensorKey tensor_key = ntohl(wc[i].imm_data);
            {
              mutex_lock l(server_mu_);
              auto iter = tensor_buffers_.find(tensor_key);
              if (iter == std::end(tensor_buffers_)) {
                LOG(ERROR) << "Cannot find tensor buffer for tensor key "
                           << tensor_key;
              } else {
                const TensorBuffer* buffer = iter->second;
                buffer->Unref();
                tensor_buffers_.erase(iter);
              }
            }
            if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
              perror("rdma_post_recvv");
              LOG(ERROR) << "rdma_post_recvv failed";
              continue;
            }
          }
        }
      }
    }
  }
}

void GdrMemoryManager::Stop() { stopped_ = true; }

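// Serving side: locate the registered memory region backing the tensor, take
// a reference on its buffer, assign a tensor key, and pack a
// RemoteMemoryRegion describing where the remote peer should read from.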
void GdrMemoryManager::TransportOptionsFromTensor(
    ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
    Device* device, DeviceContext* device_context, bool on_host,
    StatusCallback done) {
  auto buffer = DMAHelper::buffer(&tensor);
  void* addr = buffer->data();
  size_t length = buffer->size();
  if (length == 0) {
    done(errors::Unavailable("Cannot register tensor buffer of size 0"));
    return;
  }

  ibv_mr* mr = FindMemoryRegion(addr, length);

#if GOOGLE_CUDA
  if (!on_host) {
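    // The tensor lives in GPU memory; stage it through a CUDA host (pinned)
    // copy and serve the copy over RDMA.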
    Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
    Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape());
    GPUUtil::CopyGPUTensorToCPU(
        device, device_context, &tensor, host_copy,
        [done, host_copy, mutable_transport_options, this](const Status& s) {
          if (!s.ok()) {
            done(s);
            delete host_copy;
            return;
          }
          auto buffer = DMAHelper::buffer(host_copy);
          void* addr = buffer->data();
          size_t length = buffer->size();
          ibv_mr* mr = FindMemoryRegion(addr, length);

          if (mr == nullptr) {
            done(errors::Unavailable("Cannot find pinned memory region"));
            delete host_copy;
            return;
          }

          buffer->Ref();
          TensorKey tensor_key = next_key_++;
          {
            mutex_lock l(server_mu_);
            tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
          }

          uint64_t checksum = 0;
          if (VLOG_IS_ON(2)) {
            checksum = GPUUtil::Checksum(*host_copy);
          }

          RemoteMemoryRegion remote_mr;
          remote_mr.set_host(host_);
          remote_mr.set_port(port_);
          remote_mr.set_addr(reinterpret_cast<uint64_t>(addr));
          remote_mr.set_rkey(mr->rkey);
          remote_mr.set_tensor_key(tensor_key);
          remote_mr.set_checksum(checksum);
          mutable_transport_options->PackFrom(remote_mr);

          done(Status::OK());
          delete host_copy;
        });
    return;
  }
#endif

  if (mr == nullptr) {
    done(errors::Unavailable("Cannot find pinned memory region"));
    return;
  }

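  // Hold a reference on the buffer until the remote side acknowledges the
  // read with this tensor key (the reference is released in Run()).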
  buffer->Ref();
  TensorKey tensor_key = next_key_++;
  {
    mutex_lock l(server_mu_);
    tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
  }

  uint64_t checksum = 0;
  if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
    if (!on_host) {
      checksum = GPUUtil::Checksum(device, device_context, tensor);
    } else {
      checksum = GPUUtil::Checksum(tensor);
    }
#endif
  }

  RemoteMemoryRegion remote_mr;
  remote_mr.set_host(host_);
  remote_mr.set_port(port_);
  remote_mr.set_addr(reinterpret_cast<uint64_t>(addr));
  remote_mr.set_rkey(mr->rkey);
  remote_mr.set_tensor_key(tensor_key);
  remote_mr.set_checksum(checksum);
  mutable_transport_options->PackFrom(remote_mr);

  done(Status::OK());
}

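// Receiving side: unpack the RemoteMemoryRegion, connect (or reuse a
// connection) to the remote host, RDMA-read the tensor bytes into a
// registered local buffer, and acknowledge with the tensor key.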
void GdrMemoryManager::TensorFromTransportOptions(
    Tensor* tensor, const ::google::protobuf::Any& transport_options,
    Device* device, DeviceContext* device_context, bool on_host,
    StatusCallback done) {
  RemoteMemoryRegion remote_mr;
  if (!transport_options.UnpackTo(&remote_mr)) {
    done(errors::NotFound("No RDMA transport options found"));
    return;
  }

  auto buffer = DMAHelper::buffer(tensor);
  void* addr = buffer->data();
  size_t length = buffer->size();
  ibv_mr* mr = FindMemoryRegion(addr, length);

  Tensor host_copy;
#if GOOGLE_CUDA
  if (mr == nullptr && !on_host) {
    Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
    host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
    buffer = DMAHelper::buffer(&host_copy);
    addr = buffer->data();
    length = buffer->size();
    mr = FindMemoryRegion(addr, length);
  }
#endif  // GOOGLE_CUDA

  if (mr == nullptr) {
    done(errors::Unavailable("Cannot find pinned memory region"));
    return;
  }

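  // Reuse an existing endpoint to the remote host:port, or create one lazily.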
  decltype(clients_)::iterator iter;
  bool success;
  {
    mutex_lock l(client_mu_);
    std::tie(iter, success) = clients_.insert(
        std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()),
                       RdmaEndpointPtr(nullptr, EndpointDeleter)));
    if (success || iter->second.get() == nullptr) {
      Status s =
          CreateEndpoint(remote_mr.host(), remote_mr.port(), iter->second);
      if (!s.ok()) {
        done(s);
        return;
      }
    }
  }
  rdma_cm_id* id = iter->second.get();

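  // Read the remote buffer directly into the local registered buffer, then
  // notify the remote side with a zero-length RDMA write carrying the tensor
  // key as immediate data.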
  uint64_t start = Env::Default()->NowMicros();

  if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0,
                     remote_mr.addr(), remote_mr.rkey())) {
    done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed"));
    return;
  }

  ibv_send_wr wr = {};
  wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
  wr.imm_data = htonl(remote_mr.tensor_key());
  wr.send_flags = IBV_SEND_SIGNALED;
  ibv_send_wr* bad_wr;
  if (ibv_post_send(id->qp, &wr, &bad_wr)) {
    done(errors::Unavailable(strerror(errno), ": ", "ibv_post_send failed"));
    return;
  }

  ibv_wc wc = {};
  int ret;
  while ((ret = ibv_poll_cq(id->send_cq, 1, &wc)) == 0)
    ;
  if (ret < 0 || wc.status) {
    done(errors::Unavailable(ibv_wc_status_str(wc.status)));
    return;
  }

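  // If the read was staged through pinned host memory, copy the result to the
  // GPU before reporting completion.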
#if GOOGLE_CUDA
  if (host_copy.NumElements() > 0) {
    uint64_t checksum = 0;
    if (VLOG_IS_ON(2)) {
      checksum = GPUUtil::Checksum(host_copy);
      CHECK(checksum == remote_mr.checksum())
          << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum();
    }
    Tensor* ref = new Tensor;
    std::swap(host_copy, *ref);
    GPUUtil::CopyCPUTensorToGPU(
        ref, device_context, device, tensor,
        [ref, done, buffer, remote_mr, start](const Status& s) {
          if (!s.ok()) {
            done(s);
            delete ref;
            return;
          }
          uint64_t end = Env::Default()->NowMicros();

          VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
                  << " of size " << buffer->size() << " with tensor key "
                  << remote_mr.tensor_key() << " took " << (end - start)
                  << " micros";
          done(Status::OK());
          delete ref;
        });
    return;
  }
#endif  // GOOGLE_CUDA

  uint64_t end = Env::Default()->NowMicros();

  VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
          << " of size " << buffer->size() << " with tensor key "
          << remote_mr.tensor_key() << " took " << (end - start) << " micros";

  uint64_t checksum = 0;
  if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
    if (device->tensorflow_gpu_device_info() && (!on_host)) {
      checksum = GPUUtil::Checksum(device, device_context, *tensor);
    } else {
      checksum = GPUUtil::Checksum(*tensor);
    }
    CHECK(checksum == remote_mr.checksum())
        << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum();
#endif
  }
  done(Status::OK());
}

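// Resolves the remote address, creates a reliable-connection (RC) queue pair,
// and connects to the remote listener.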
Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port,
                                        RdmaEndpointPtr& endpoint) {
  rdma_addrinfo* addrinfo;
  rdma_addrinfo hints = {};
  hints.ai_port_space = RDMA_PS_TCP;
  if (rdma_getaddrinfo(const_cast<char*>(host.c_str()),
                       const_cast<char*>(port.c_str()), &hints, &addrinfo)) {
    return errors::InvalidArgument(
        strerror(errno), ": ", "cannot connect to rdma://", host, ":", port);
  }

  ibv_qp_init_attr init_attr = {};
  init_attr.qp_type = IBV_QPT_RC;
  init_attr.cap.max_recv_wr = 1;
  init_attr.cap.max_send_wr = 32;
  init_attr.cap.max_recv_sge = 1;
  init_attr.cap.max_send_sge = 1;

  rdma_cm_id* id;
  if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
    rdma_freeaddrinfo(addrinfo);
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot create endpoint to rdma://", host, ":",
                               port);
  }
  rdma_freeaddrinfo(addrinfo);

  if (rdma_connect(id, nullptr)) {
    rdma_destroy_ep(id);
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot connect to rdma://", host, ":", port);
  }

  LOG(INFO) << "RDMA endpoint connected to rdma://" << host << ":" << port;
  endpoint = RdmaEndpointPtr(id, EndpointDeleter);
  return Status::OK();
}

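// mrs_ is kept sorted by each region's end address; lookup, insertion and
// eviction all use std::upper_bound with Comparator to find the first region
// whose end lies past the given address.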
ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) {
  if (length == 0) return nullptr;
  mutex_lock l(alloc_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter == std::end(mrs_) || iter->get()->addr > addr) {
    return nullptr;
  } else {
    return iter->get();
  }
}

void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length);
  if (mr != nullptr) {
    mutex_lock l(alloc_mu_);
    auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
    mrs_.insert(iter, {mr, &MRDeleter});
  } else {
    LOG(WARNING) << "Cannot register memory region";
  }
}

void GdrMemoryManager::EvictMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  mutex_lock l(alloc_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter != std::end(mrs_) && iter->get()->addr == addr) {
    mrs_.erase(iter);
  } else {
    LOG(WARNING) << "Failed to de-register memory region";
  }
}

}  // namespace

RemoteMemoryManager* CreateRemoteMemoryManager(const string& host,
                                               const string& port) {
  return new GdrMemoryManager(host, port);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_GDR