/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_GDR

#include "tensorflow/contrib/gdr/gdr_memory_manager.h"

#include <atomic>
#include <cerrno>
#include <fstream>
#include <list>
#include <map>
#include <set>

#include <fcntl.h>
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
#include <sys/epoll.h>

#include "tensorflow/contrib/gdr/gdr.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif  // GOOGLE_CUDA
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

namespace {

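// Returns true if the nv_peer_mem kernel module (GPUDirect RDMA support) is
// loaded, determined by scanning /proc/modules; always false on macOS and
// Windows.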
bool IsGDRAvailable() {
#if defined(__APPLE__)
  return false;
#elif defined(PLATFORM_WINDOWS)
  return false;
#else
  std::ifstream ifs("/proc/modules");
  string line;
  while (std::getline(ifs, line)) {
    auto sep = line.find(' ');
    CHECK_NE(sep, std::string::npos);
    if (line.substr(0, sep) == "nv_peer_mem") {
      return true;
    }
  }
  return false;
#endif
}

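// Attempts to read the NUMA node of the given RDMA device from sysfs.
// Returns 0 on platforms without NUMA support and -1 if the node is unknown.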
int TryToReadNumaNode(ibv_device* device) {
#if defined(__APPLE__)
  LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
  return 0;
#elif defined(PLATFORM_WINDOWS)
  // Windows support for NUMA is not currently implemented. Return node 0.
  return 0;
#else
  VLOG(2) << "Trying to read NUMA node for device: " << device->name;
  static const int kUnknownNumaNode = -1;

  auto filename = string(device->ibdev_path) + "/device/numa_node";

  std::ifstream ifs(filename.c_str());
  string content;
  CHECK(std::getline(ifs, content));

  int32 value;
  if (strings::safe_strto32(content, &value)) {
    if (value < 0) {
      LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
                << value
                << "), but there must be at least one NUMA node"
                   ", so returning NUMA node zero";
      return 0;
    }
    LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
    return value;
  }
  return kUnknownNumaNode;
#endif
}

void EndpointDeleter(rdma_cm_id* id) {
  if (id) {
    rdma_destroy_ep(id);
  }
}

void MRDeleter(ibv_mr* mr) {
  if (mr) {
    rdma_dereg_mr(mr);
  }
}

using RdmaEndpointPtr = std::unique_ptr<rdma_cm_id, decltype(&EndpointDeleter)>;

using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>;

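// RemoteMemoryManager backed by RDMA: registers memory obtained from the
// instrumented CPU (and, when GPUDirect RDMA is available, GPU) allocators,
// publishes local tensor buffers to remote readers, and fetches remote
// tensors via one-sided RDMA reads.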
class GdrMemoryManager : public RemoteMemoryManager {
 public:
  GdrMemoryManager(const string& host, const string& port);

  virtual ~GdrMemoryManager();

  virtual Status Init() override;

  virtual void Run() override;

  virtual void Stop() override;

  virtual void TransportOptionsFromTensor(
      ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
      Device* device, DeviceContext* device_context, bool on_host,
      StatusCallback done) override;

  virtual void TensorFromTransportOptions(
      Tensor* tensor, const ::google::protobuf::Any& transport_options,
      Device* device, DeviceContext* device_context, bool on_host,
      StatusCallback done) override;

 protected:
  Status CreateEndpoint(const string& host, const string& port,
                        RdmaEndpointPtr& endpoint);

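  // Orders regions by their end address so that std::upper_bound over mrs_
  // yields the first registered region that could contain ptr.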
  static bool Comparator(const void* ptr, const MemoryRegionPtr& other) {
    return ptr < reinterpret_cast<char*>(other->addr) + other->length;
  }

  ibv_mr* FindMemoryRegion(void* addr, size_t length);

  void InsertMemoryRegion(void* addr, size_t length);

  void EvictMemoryRegion(void* addr, size_t length);

 private:
  const string host_;
  const string port_;
  RdmaEndpointPtr listening_;
  std::atomic<bool> stopped_;
  int epfd_;

  // Server side endpoints
  // Accessed sequentially in Run() so not protected by lock
  std::list<RdmaEndpointPtr> server_clients_;

  using TensorKey = uint32_t;
  std::atomic<TensorKey> next_key_;

  // Server side on-the-fly tensor buffers
  mutex server_mu_;
  std::map<TensorKey, const TensorBuffer*> tensor_buffers_
      GUARDED_BY(server_mu_);

  // Client side endpoints
  mutex client_mu_;
  std::map<std::pair<string, string>, RdmaEndpointPtr> clients_
      GUARDED_BY(client_mu_);

  // Managed memory regions
  mutex alloc_mu_;
  std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(alloc_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager);
};

// TODO(byronyi): remove this class duplicated from the one in
// common/runtime/gpu/pool_allocator.h when it is available in common_runtime
class BasicCPUAllocator : public SubAllocator {
 public:
  ~BasicCPUAllocator() override {}

  void* Alloc(size_t alignment, size_t num_bytes) override {
    return port::AlignedMalloc(num_bytes, alignment);
  }
  void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
};

// TODO(byronyi): remove this class and its registration when the default
// cpu_allocator() returns visitable allocator
class BFCRdmaAllocator : public BFCAllocator {
 public:
  BFCRdmaAllocator()
      : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
  }
};

REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);

GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
    : host_(host),
      port_(port),
      listening_(nullptr, EndpointDeleter),
      stopped_(true),
      next_key_(0) {}

GdrMemoryManager::~GdrMemoryManager() { close(epfd_); }

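// Creates the listening RDMA endpoint, registers its event channel with
// epoll, and instruments the host (and, with CUDA, GPU) allocators so that
// their allocations are registered as RDMA memory regions.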
Status GdrMemoryManager::Init() {
  epfd_ = epoll_create1(0);
  if (epfd_ == -1) {
    return errors::Unavailable(strerror(errno), ": ", "epoll_create");
  }

  rdma_addrinfo* addrinfo;
  rdma_addrinfo hints = {};
  hints.ai_port_space = RDMA_PS_TCP;
  hints.ai_flags = RAI_PASSIVE;
  if (rdma_getaddrinfo(const_cast<char*>(host_.c_str()),
                       const_cast<char*>(port_.c_str()), &hints, &addrinfo)) {
    return errors::Unavailable(strerror(errno), ": ", "cannot resolve rdma://",
                               host_, ":", port_);
  }

  ibv_qp_init_attr init_attr = {};
  init_attr.qp_type = IBV_QPT_RC;
  init_attr.cap.max_recv_wr = 32;
  init_attr.cap.max_send_wr = 1;
  init_attr.cap.max_recv_sge = 1;
  init_attr.cap.max_send_sge = 1;

  // Create listening endpoint
  rdma_cm_id* id;
  if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
    return errors::Unavailable(strerror(errno), ": ", "cannot bind to rdma://",
                               host_, ":", port_);
  }
  listening_.reset(id);
  rdma_freeaddrinfo(addrinfo);

  // Listen without backlog
  if (rdma_listen(listening_.get(), 0)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot listen on rdma://", host_, ":", port_);
  }
  LOG(INFO) << "RDMA server is listening on " << host_ << ":" << port_;

  if (listening_->verbs == nullptr) {
    return errors::Unimplemented(
        "Unsupported address ", host_, ":", port_,
        " as it does not bind to a particular RDMA device");
  }

  int flags = fcntl(listening_->channel->fd, F_GETFL, 0);
  if (fcntl(listening_->channel->fd, F_SETFL, flags | O_NONBLOCK)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot set server to non-blocking mode");
  }

  epoll_event event = {};
  event.events = EPOLLIN | EPOLLPRI;
  event.data.ptr = listening_.get();
  if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot add server to epoll");
  }

  Allocator* allocators[] = {
#if GOOGLE_CUDA
    ProcessState::singleton()->GetCUDAHostAllocator(0),
    ProcessState::singleton()->GetCPUAllocator(0),
#endif  // GOOGLE_CUDA
    cpu_allocator(),
  };

  using namespace std::placeholders;
  VisitableAllocator::Visitor alloc_visitor =
      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
  VisitableAllocator::Visitor free_visitor =
      std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2);

  std::set<Allocator*> instrumented_;

  // Host memory allocators
  for (Allocator* allocator : allocators) {
    auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
    CHECK(visitable_allocator)
        << allocator->Name() << " is not visitable for instrumentation";
    // Make sure we don't instrument the same allocator twice
    if (instrumented_.find(allocator) == std::end(instrumented_)) {
      visitable_allocator->AddAllocVisitor(alloc_visitor);
      visitable_allocator->AddFreeVisitor(free_visitor);
      instrumented_.insert(allocator);
      LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
    }
  }

#if GOOGLE_CUDA
  VisitableAllocator::Visitor cuda_alloc_visitor =
      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
  if (IsGDRAvailable()) {
    // Note we don't free allocated GPU memory so there is no free visitor
    int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1;
    ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
    LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
  }
#endif  // GOOGLE_CUDA

  return Status::OK();
}

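// Server-side event loop: accepts incoming RDMA connections and drains the
// receive completion queues; each RDMA-write-with-immediate completion
// carries the tensor key whose buffer can now be unreferenced and forgotten.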
void GdrMemoryManager::Run() {
  stopped_ = false;
  while (!stopped_) {
    epoll_event events[32];
    int ret = epoll_wait(epfd_, events, 32, 1);
    if (ret == -1) {
      LOG(ERROR) << "epoll_wait: " << strerror(errno);
      return;
    }
    for (int i = 0; i < ret; i++) {
      rdma_cm_id* id = static_cast<rdma_cm_id*>(events[i].data.ptr);
      if (id == listening_.get()) {
        // Accept incoming connections
        if (!rdma_get_request(listening_.get(), &id)) {
          if (!rdma_accept(id, nullptr)) {
            LOG(INFO) << "Accepted new RDMA connection";
            if (ibv_req_notify_cq(id->recv_cq, 0)) {
              LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
              EndpointDeleter(id);
              continue;
            }
            for (int i = 0; i < 32; i++) {
              if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
                LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed";
                EndpointDeleter(id);
                continue;
              }
            }
            int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0);
            if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags | O_NONBLOCK)) {
              LOG(ERROR) << strerror(errno)
                         << ": cannot set server_client to non-blocking mode";
              EndpointDeleter(id);
              continue;
            }
            epoll_event event = {};
            event.events = EPOLLIN | EPOLLPRI;
            event.data.ptr = id;
            if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd,
                          &event)) {
              LOG(ERROR) << strerror(errno)
                         << ": cannot add server client to epoll";
              EndpointDeleter(id);
              continue;
            }
            server_clients_.push_back({id, EndpointDeleter});
          }
        }
      } else {
        // Polling work completions
        ibv_cq* cq;
        void* context;
        if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) {
          ibv_ack_cq_events(id->recv_cq, 1);
          if (ibv_req_notify_cq(id->recv_cq, 0)) {
            LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
            continue;
          }
          ibv_wc wc[32];
          int ret = ibv_poll_cq(id->recv_cq, 32, wc);
          if (ret < 0) {
            LOG(ERROR) << "ibv_poll_cq failed";
            continue;
          }
          for (int i = 0; i < ret; i++) {
            if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) {
              LOG(ERROR) << "Received unknown operation " << wc[i].opcode;
            }
            if (wc[i].status != 0) {
              LOG(ERROR) << ibv_wc_status_str(wc[i].status);
            }
            TensorKey tensor_key = ntohl(wc[i].imm_data);
            {
              mutex_lock l(server_mu_);
              auto iter = tensor_buffers_.find(tensor_key);
              if (iter == std::end(tensor_buffers_)) {
                LOG(ERROR) << "Cannot find tensor buffer for tensor key "
                           << tensor_key;
              } else {
                const TensorBuffer* buffer = iter->second;
                buffer->Unref();
                tensor_buffers_.erase(iter);
              }
            }
            if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
              perror("rdma_post_recvv");
              LOG(ERROR) << "rdma_post_recvv failed";
              continue;
            }
          }
        }
      }
    }
  }
}

void GdrMemoryManager::Stop() { stopped_ = true; }

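// Publishes a tensor buffer for remote RDMA read: takes a reference on the
// buffer, assigns it a tensor key, and packs host/port, address, rkey, key
// and an optional checksum into a RemoteMemoryRegion. GPU tensors are first
// copied into CUDA host memory when they are not already on the host.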
void GdrMemoryManager::TransportOptionsFromTensor(
    ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
    Device* device, DeviceContext* device_context, bool on_host,
    StatusCallback done) {
  auto buffer = DMAHelper::buffer(&tensor);
  void* addr = buffer->data();
  size_t length = buffer->size();
  if (length == 0) {
    done(errors::Unavailable("Cannot register tensor buffer of size 0"));
    return;
  }

  ibv_mr* mr = FindMemoryRegion(addr, length);

#if GOOGLE_CUDA
  if (!on_host) {
    Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
    Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape());
    GPUUtil::CopyGPUTensorToCPU(
        device, device_context, &tensor, host_copy,
        [done, host_copy, mutable_transport_options, this](const Status& s) {
          if (!s.ok()) {
            done(s);
            delete host_copy;
            return;
          }
          auto buffer = DMAHelper::buffer(host_copy);
          void* addr = buffer->data();
          size_t length = buffer->size();
          ibv_mr* mr = FindMemoryRegion(addr, length);

          if (mr == nullptr) {
            done(errors::Unavailable("Cannot find pinned memory region"));
            delete host_copy;
            return;
          }

          buffer->Ref();
          TensorKey tensor_key = next_key_++;
          {
            mutex_lock l(server_mu_);
            tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
          }

          uint64_t checksum = 0;
          if (VLOG_IS_ON(2)) {
            checksum = GPUUtil::Checksum(*host_copy);
          }

          RemoteMemoryRegion remote_mr;
          remote_mr.set_host(host_);
          remote_mr.set_port(port_);
          remote_mr.set_addr(reinterpret_cast<uint64_t>(addr));
          remote_mr.set_rkey(mr->rkey);
          remote_mr.set_tensor_key(tensor_key);
          remote_mr.set_checksum(checksum);
          mutable_transport_options->PackFrom(remote_mr);

          done(Status::OK());
          delete host_copy;
        });
    return;
  }
#endif

  if (mr == nullptr) {
    done(errors::Unavailable("Cannot find pinned memory region"));
    return;
  }

  buffer->Ref();
  TensorKey tensor_key = next_key_++;
  {
    mutex_lock l(server_mu_);
    tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
  }

  uint64_t checksum = 0;
  if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
    if (!on_host) {
      checksum = GPUUtil::Checksum(device, device_context, tensor);
    } else {
      checksum = GPUUtil::Checksum(tensor);
    }
#endif
  }

  RemoteMemoryRegion remote_mr;
  remote_mr.set_host(host_);
  remote_mr.set_port(port_);
  remote_mr.set_addr(reinterpret_cast<uint64_t>(addr));
  remote_mr.set_rkey(mr->rkey);
  remote_mr.set_tensor_key(tensor_key);
  remote_mr.set_checksum(checksum);
  mutable_transport_options->PackFrom(remote_mr);

  done(Status::OK());
}

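// Client-side receive path: unpacks the RemoteMemoryRegion, issues a
// one-sided RDMA read into the destination buffer (or a CUDA host staging
// tensor), then notifies the server with an RDMA write-with-immediate that
// carries the tensor key so the remote buffer reference can be released.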
void GdrMemoryManager::TensorFromTransportOptions(
    Tensor* tensor, const ::google::protobuf::Any& transport_options,
    Device* device, DeviceContext* device_context, bool on_host,
    StatusCallback done) {
  RemoteMemoryRegion remote_mr;
  if (!transport_options.UnpackTo(&remote_mr)) {
    done(errors::NotFound("No RDMA transport options found"));
    return;
  }

  auto buffer = DMAHelper::buffer(tensor);
  void* addr = buffer->data();
  size_t length = buffer->size();
  ibv_mr* mr = FindMemoryRegion(addr, length);

  Tensor host_copy;
#if GOOGLE_CUDA
  if (mr == nullptr && !on_host) {
    Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
    host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
    buffer = DMAHelper::buffer(&host_copy);
    addr = buffer->data();
    length = buffer->size();
    mr = FindMemoryRegion(addr, length);
  }
#endif  // GOOGLE_CUDA

  if (mr == nullptr) {
    done(errors::Unavailable("Cannot find pinned memory region"));
    return;
  }

  decltype(clients_)::iterator iter;
  bool success;
  {
    mutex_lock l(client_mu_);
    std::tie(iter, success) = clients_.insert(
        std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()),
                       RdmaEndpointPtr(nullptr, EndpointDeleter)));
    if (success || iter->second.get() == nullptr) {
      Status s =
          CreateEndpoint(remote_mr.host(), remote_mr.port(), iter->second);
      if (!s.ok()) {
        done(s);
        return;
      }
    }
  }
  rdma_cm_id* id = iter->second.get();

  uint64_t start = Env::Default()->NowMicros();

  if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0,
                     remote_mr.addr(), remote_mr.rkey())) {
    done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed"));
    return;
  }

  ibv_send_wr wr = {};
  wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
  wr.imm_data = htonl(remote_mr.tensor_key());
  wr.send_flags = IBV_SEND_SIGNALED;
  ibv_send_wr* bad_wr;
  if (ibv_post_send(id->qp, &wr, &bad_wr)) {
    done(errors::Unavailable(strerror(errno), ": ", "ibv_post_send failed"));
    return;
  }

  ibv_wc wc = {};
  int ret;
  while ((ret = ibv_poll_cq(id->send_cq, 1, &wc)) == 0)
    ;
  if (ret < 0 || wc.status) {
    done(errors::Unavailable(ibv_wc_status_str(wc.status)));
    return;
  }

#if GOOGLE_CUDA
  if (host_copy.NumElements() > 0) {
    uint64_t checksum = 0;
    if (VLOG_IS_ON(2)) {
      checksum = GPUUtil::Checksum(host_copy);
      CHECK(checksum == remote_mr.checksum())
          << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum();
    }
    Tensor* ref = new Tensor;
    std::swap(host_copy, *ref);
    GPUUtil::CopyCPUTensorToGPU(
        ref, device_context, device, tensor,
        [ref, done, buffer, remote_mr, start](const Status& s) {
          if (!s.ok()) {
            done(s);
            delete ref;
            return;
          }
          uint64_t end = Env::Default()->NowMicros();

          VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
                  << " of size " << buffer->size() << " with tensor key "
                  << remote_mr.tensor_key() << " took " << (end - start)
                  << " micros";
          done(Status::OK());
          delete ref;
        });
    return;
  }
#endif  // GOOGLE_CUDA

  uint64_t end = Env::Default()->NowMicros();

  VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
          << " of size " << buffer->size() << " with tensor key "
          << remote_mr.tensor_key() << " took " << (end - start) << " micros";

  uint64_t checksum = 0;
  if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
    if (device->tensorflow_gpu_device_info() && (!on_host)) {
      checksum = GPUUtil::Checksum(device, device_context, *tensor);
    } else {
      checksum = GPUUtil::Checksum(*tensor);
    }
    CHECK(checksum == remote_mr.checksum())
        << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum();
#endif
  }
  done(Status::OK());
}

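// Resolves the remote rdma:// address and establishes a connected reliable
// (RC) endpoint to it, transferring ownership of the cm_id to the caller.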
Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port,
                                        RdmaEndpointPtr& endpoint) {
  rdma_addrinfo* addrinfo;
  rdma_addrinfo hints = {};
  hints.ai_port_space = RDMA_PS_TCP;
  if (rdma_getaddrinfo(const_cast<char*>(host.c_str()),
                       const_cast<char*>(port.c_str()), &hints, &addrinfo)) {
    return errors::InvalidArgument(
        strerror(errno), ": ", "cannot connect to rdma://", host, ":", port);
  }

  ibv_qp_init_attr init_attr = {};
  init_attr.qp_type = IBV_QPT_RC;
  init_attr.cap.max_recv_wr = 1;
  init_attr.cap.max_send_wr = 32;
  init_attr.cap.max_recv_sge = 1;
  init_attr.cap.max_send_sge = 1;

  rdma_cm_id* id;
  if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
    rdma_freeaddrinfo(addrinfo);
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot create endpoint to rdma://", host, ":",
                               port);
  }
  rdma_freeaddrinfo(addrinfo);

  if (rdma_connect(id, nullptr)) {
    rdma_destroy_ep(id);
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot connect to rdma://", host, ":", port);
  }

  LOG(INFO) << "RDMA endpoint connected to rdma://" << host << ":" << port;
  endpoint = RdmaEndpointPtr(id, EndpointDeleter);
  return Status::OK();
}

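// Looks up the registered memory region containing addr, if any. mrs_ is
// kept sorted by region end address, so a single upper_bound suffices.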
ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) {
  if (length == 0) return nullptr;
  mutex_lock l(alloc_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter == std::end(mrs_) || iter->get()->addr > addr) {
    return nullptr;
  } else {
    return iter->get();
  }
}

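// Allocation visitor: registers the newly allocated chunk for remote RDMA
// reads and records the region in mrs_ in sorted order.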
void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length);
  if (mr != nullptr) {
    mutex_lock l(alloc_mu_);
    auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
    mrs_.insert(iter, {mr, &MRDeleter});
  } else {
    LOG(WARNING) << "Cannot register memory region";
  }
}

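// Free visitor: drops the region starting at addr; MRDeleter de-registers it
// with the RDMA device when the MemoryRegionPtr is destroyed.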
void GdrMemoryManager::EvictMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  mutex_lock l(alloc_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter != std::end(mrs_) && iter->get()->addr == addr) {
    mrs_.erase(iter);
  } else {
    LOG(WARNING) << "Failed to de-register memory region";
  }
}

}  // namespace

RemoteMemoryManager* CreateRemoteMemoryManager(const string& host,
                                               const string& port) {
  return new GdrMemoryManager(host, port);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_GDR