• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifdef TENSORFLOW_USE_GDR
17 
18 #include "tensorflow/contrib/gdr/gdr_memory_manager.h"
19 
20 #include <atomic>
21 #include <cerrno>
22 #include <fstream>
23 #include <list>
24 #include <map>
25 
26 #include <fcntl.h>
27 #include <rdma/rdma_cma.h>
28 #include <rdma/rdma_verbs.h>
29 
30 #include "tensorflow/contrib/gdr/gdr.pb.h"
31 #include "tensorflow/core/common_runtime/device.h"
32 #include "tensorflow/core/common_runtime/dma_helper.h"
33 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
34 #include "tensorflow/core/common_runtime/process_state.h"
35 #include "tensorflow/core/lib/core/status.h"
36 #include "tensorflow/core/lib/random/random.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/numa.h"
40 
41 namespace tensorflow {
42 
43 namespace {
44 
// Reports whether the nv_peer_mem kernel module is loaded. Init() only
// instruments the GPU allocator for RDMA registration when this returns true.
//
// On Linux the check scans /proc/modules, where each line starts with the
// module name followed by a space. A line without that layout is skipped
// instead of aborting the process: this is an availability probe, so an
// unexpected file format should simply not count as a match.
bool IsGDRAvailable() {
#if defined(__APPLE__)
  return false;
#elif defined(PLATFORM_WINDOWS)
  return false;
#else
  static const char kPeerMemModule[] = "nv_peer_mem";
  std::ifstream ifs("/proc/modules");
  std::string line;
  while (std::getline(ifs, line)) {
    const std::string::size_type sep = line.find(' ');
    if (sep == std::string::npos) {
      continue;  // Malformed line; it cannot name a module.
    }
    // Compare the leading module-name field in place (no substring copy).
    if (line.compare(0, sep, kPeerMemModule) == 0) {
      return true;
    }
  }
  return false;
#endif
}
63 
TryToReadNumaNode(ibv_device * device)64 int TryToReadNumaNode(ibv_device* device) {
65 #if defined(__APPLE__)
66   LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
67   return port::kNUMANoAffinity;
68 #elif defined(PLATFORM_WINDOWS)
69   // Windows support for NUMA is not currently implemented. Return node 0.
70   return port::kNUMANoAffinity;
71 #else
72   auto filename = string(device->ibdev_path) + "/device/numa_node";
73 
74   std::ifstream ifs(filename.c_str());
75   string content;
76   const auto& ret = std::getline(ifs, content);
77   if (!ret) {
78     return port::kNUMANoAffinity;
79   }
80 
81   int32 value;
82   if (strings::safe_strto32(content, &value)) {
83     if (value < 0) {
84       return port::kNUMANoAffinity;
85     }
86     LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
87     return value;
88   }
89   return port::kNUMANoAffinity;
90 #endif
91 }
92 
EndpointDeleter(rdma_cm_id * id)93 void EndpointDeleter(rdma_cm_id* id) {
94   if (id) {
95     rdma_destroy_ep(id);
96   }
97 }
98 
MRDeleter(ibv_mr * mr)99 void MRDeleter(ibv_mr* mr) {
100   if (mr) {
101     rdma_dereg_mr(mr);
102   }
103 }
104 
105 using RdmaEndpointPtr = std::unique_ptr<rdma_cm_id, decltype(&EndpointDeleter)>;
106 
107 using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>;
108 
109 class GdrMemoryManager : public RemoteMemoryManager {
110  public:
111   GdrMemoryManager(const string& host, const string& port);
112 
~GdrMemoryManager()113   virtual ~GdrMemoryManager() {}
114 
115   virtual Status Init() override;
116 
117   virtual void Run() override;
118 
119   virtual void Stop() override;
120 
121   virtual void TransportOptionsFromTensor(
122       ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
123       Device* device, DeviceContext* device_context, bool on_host,
124       StatusCallback done) override;
125 
126   virtual void TensorFromTransportOptions(
127       Tensor* tensor, const ::google::protobuf::Any& transport_options,
128       Device* device, DeviceContext* device_context, bool on_host,
129       StatusCallback done) override;
130 
131  protected:
132   Status CreateEndpoint(const string& host, const string& port,
133                         RdmaEndpointPtr& endpoint);
134 
Comparator(const void * ptr,const MemoryRegionPtr & other)135   static bool Comparator(const void* ptr, const MemoryRegionPtr& other) {
136     return ptr < reinterpret_cast<char*>(other->addr) + other->length;
137   }
138 
139   ibv_mr* FindMemoryRegion(const Tensor* tensor);
140 
141   void InsertMemoryRegion(void* addr, size_t length,
142                           const std::string& allocator_name);
143 
144   void EvictMemoryRegion(void* addr, size_t length);
145 
146  private:
147   const string host_;
148   const string port_;
149   RdmaEndpointPtr listening_;
150   std::atomic<bool> stopped_;
151   int numa_node_;
152 
153   // Server side endpoints
154   // Accessed sequentially in Run() so not protected by lock
155   std::list<RdmaEndpointPtr> server_clients_;
156 
157   using TensorKey = uint32_t;
158   std::atomic<TensorKey> next_key_;
159 
160   // Server side on-the-fly tensor buffers
161   mutex buf_mu_;
162   std::map<TensorKey, const TensorBuffer*> tensor_buffers_ GUARDED_BY(buf_mu_);
163 
164   // Client side endpoints
165   mutex client_mu_;
166   std::map<std::pair<string, string>, RdmaEndpointPtr> clients_
167       GUARDED_BY(client_mu_);
168 
169   // Client side callbacks
170   mutex callback_mu_;
171   std::map<TensorKey, StatusCallback> tensor_callbacks_
172       GUARDED_BY(callback_mu_);
173 
174   // Managed memory regions
175   mutex alloc_mu_;
176   std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(alloc_mu_);
177 
178   TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager);
179 };
180 
GdrMemoryManager(const string & host,const string & port)181 GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
182     : host_(host),
183       port_(port),
184       listening_(nullptr, EndpointDeleter),
185       stopped_(true),
186       next_key_(static_cast<uint32_t>(random::New64())) {}
187 
Init()188 Status GdrMemoryManager::Init() {
189   rdma_addrinfo* addrinfo;
190   rdma_addrinfo hints = {};
191   hints.ai_port_space = RDMA_PS_TCP;
192   hints.ai_flags = RAI_PASSIVE;
193   if (rdma_getaddrinfo(const_cast<char*>(host_.c_str()),
194                        const_cast<char*>(port_.c_str()), &hints, &addrinfo)) {
195     return errors::Unavailable(strerror(errno), ": ", "cannot resolve rdma://",
196                                host_, ":", port_);
197   }
198 
199   ibv_qp_init_attr init_attr = {};
200   init_attr.qp_type = IBV_QPT_RC;
201   init_attr.cap.max_recv_wr = 1024;
202   init_attr.cap.max_send_wr = 1;
203   init_attr.cap.max_recv_sge = 1;
204   init_attr.cap.max_send_sge = 1;
205 
206   // Create listening endpoint
207   rdma_cm_id* id;
208   if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
209     return errors::Unavailable(strerror(errno), ": ", "cannot bind to rdma://",
210                                host_, ":", port_);
211   }
212   listening_.reset(id);
213   rdma_freeaddrinfo(addrinfo);
214 
215   // Listen without backlog
216   if (rdma_listen(listening_.get(), 0)) {
217     return errors::Unavailable(strerror(errno), ": ",
218                                "cannot listen on rdma://", host_, ":", port_);
219   }
220   LOG(INFO) << "RDMA server is listening on " << host_ << ":" << port_;
221 
222   if (listening_->verbs == nullptr) {
223     return errors::Unimplemented(
224         "Unsupported address ", host_, ":", port_,
225         " as it does not bind to a particular RDMA device");
226   }
227 
228   int flags = fcntl(listening_->channel->fd, F_GETFL, 0);
229   if (fcntl(listening_->channel->fd, F_SETFL, flags | O_NONBLOCK)) {
230     return errors::Unavailable(strerror(errno), ": ",
231                                "cannot set server to non-blocking mode");
232   }
233 
234   numa_node_ = TryToReadNumaNode(listening_->verbs->device);
235 
236   SubAllocator::Visitor alloc_visitor = [this](void* ptr, int numa_node,
237                                                size_t num_bytes) {
238     VLOG(2) << "Registering RDMA capable memory region on numa_node "
239             << numa_node;
240     InsertMemoryRegion(ptr, num_bytes, strings::StrCat("CPU:", numa_node));
241   };
242   SubAllocator::Visitor free_visitor = [this](void* ptr, int numa_node,
243                                               size_t num_bytes) {
244     VLOG(2) << "De-registering RDMA capable memory region on numa_node "
245             << numa_node;
246     EvictMemoryRegion(ptr, num_bytes);
247   };
248   ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
249   ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
250   LOG(INFO) << "Instrumenting CPU allocator(s)";
251 
252   for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) {
253     GPUProcessState::singleton()->AddGpuHostAllocVisitor(numa_idx,
254                                                          alloc_visitor);
255     GPUProcessState::singleton()->AddGpuHostFreeVisitor(numa_idx, free_visitor);
256   }
257 
258   if (IsGDRAvailable()) {
259     SubAllocator::Visitor cuda_alloc_visitor = [this](void* ptr, int gpu_id,
260                                                       size_t num_bytes) {
261       VLOG(2) << "Registering RDMA capable memory region on GPU " << gpu_id;
262       InsertMemoryRegion(ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
263     };
264     GPUProcessState::singleton()->AddGPUAllocVisitor(numa_node_,
265                                                      cuda_alloc_visitor);
266     LOG(INFO) << "Instrumenting GPU allocator for NUMA " << numa_node_;
267   }
268 
269   return Status::OK();
270 }
271 
Run()272 void GdrMemoryManager::Run() {
273   stopped_ = false;
274   while (!stopped_) {
275     rdma_cm_id* id = nullptr;
276     // Accept incoming connections
277     if (!rdma_get_request(listening_.get(), &id)) {
278       if (!rdma_accept(id, nullptr)) {
279         LOG(INFO) << "Accepted new RDMA connection";
280         for (int i = 0; i < 1024; i++) {
281           if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
282             LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed";
283             EndpointDeleter(id);
284             continue;
285           }
286         }
287         server_clients_.push_back({id, EndpointDeleter});
288       }
289     }
290     // Polling server side work completions
291     for (const auto& client : server_clients_) {
292       ibv_wc wc[32];
293       int ret = ibv_poll_cq(client->recv_cq, 32, wc);
294       if (ret < 0) {
295         LOG(ERROR) << "ibv_poll_cq failed";
296         continue;
297       }
298       for (int i = 0; i < ret; i++) {
299         if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) {
300           LOG(ERROR) << "Received unknown operation " << wc[i].opcode;
301         }
302         if (wc[i].status != 0) {
303           LOG(ERROR) << ibv_wc_status_str(wc[i].status);
304         }
305         TensorKey tensor_key = ntohl(wc[i].imm_data);
306 
307         if (rdma_post_recvv(client.get(), nullptr, nullptr, 0)) {
308           perror("rdma_post_recvv");
309           LOG(ERROR) << "rdma_post_recvv failed";
310         }
311 
312         mutex_lock l(buf_mu_);
313         auto iter = tensor_buffers_.find(tensor_key);
314         if (iter == std::end(tensor_buffers_)) {
315           LOG(ERROR) << "Cannot find tensor buffer for tensor key "
316                      << tensor_key;
317         } else {
318           const TensorBuffer* buffer = iter->second;
319           buffer->Unref();
320           tensor_buffers_.erase(iter);
321         }
322       }
323     }
324     // Polling client side work completions
325     if (client_mu_.try_lock()) {
326       for (const auto& client : clients_) {
327         ibv_wc wc[32];
328         int ret = ibv_poll_cq(client.second->send_cq, 32, wc);
329         for (int i = 0; i < ret; i++) {
330           Status s;
331           if (wc[i].status) {
332             s = errors::Unavailable(ibv_wc_status_str(wc[i].status));
333           } else {
334             s = Status::OK();
335           }
336           TensorKey key = wc[i].wr_id;
337 
338           ibv_send_wr wr = {};
339           wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
340           wr.imm_data = htonl(key);
341           ibv_send_wr* bad_wr;
342           if (ibv_post_send(client.second->qp, &wr, &bad_wr)) {
343             LOG(ERROR) << strerror(errno)
344                        << ": ibv_post_send failed for tensor_key " << key;
345           }
346 
347           mutex_lock l(callback_mu_);
348           auto iter = tensor_callbacks_.find(key);
349           if (iter != std::end(tensor_callbacks_)) {
350             iter->second(s);
351             tensor_callbacks_.erase(iter);
352           } else {
353             LOG(WARNING) << "Cannot find client callback with tensor key "
354                          << key;
355           }
356         }
357       }
358       client_mu_.unlock();
359     }
360   }
361 }
362 
Stop()363 void GdrMemoryManager::Stop() { stopped_ = true; }
364 
TransportOptionsFromTensor(::google::protobuf::Any * mutable_transport_options,const Tensor & tensor,Device * device,DeviceContext * device_context,bool on_host,StatusCallback done)365 void GdrMemoryManager::TransportOptionsFromTensor(
366     ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
367     Device* device, DeviceContext* device_context, bool on_host,
368     StatusCallback done) {
369   ibv_mr* mr = FindMemoryRegion(&tensor);
370   const TensorBuffer* buffer = DMAHelper::buffer(&tensor);
371 
372   Tensor* copy = nullptr;
373 
374   if (mr == nullptr) {
375     AllocatorAttributes alloc_attrs;
376     alloc_attrs.set_gpu_compatible(true);
377     alloc_attrs.set_nic_compatible(true);
378     alloc_attrs.set_on_host(true);
379     Allocator* alloc = device->GetAllocator(alloc_attrs);
380     copy = new Tensor(alloc, tensor.dtype(), tensor.shape());
381 
382     mr = FindMemoryRegion(copy);
383     buffer = DMAHelper::buffer(copy);
384     if (mr == nullptr) {
385       done(errors::Unavailable("Cannot find pinned memory region"));
386       delete copy;
387       return;
388     }
389   }
390 
391   TensorKey tensor_key = next_key_++;
392   buffer->Ref();
393   {
394     mutex_lock l(buf_mu_);
395     tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
396   }
397 
398   RemoteMemoryRegion remote_mr;
399   remote_mr.set_host(host_);
400   remote_mr.set_port(port_);
401   remote_mr.set_addr(reinterpret_cast<uint64_t>(buffer->data()));
402   remote_mr.set_rkey(mr->rkey);
403   remote_mr.set_tensor_key(tensor_key);
404   mutable_transport_options->PackFrom(remote_mr);
405 
406   if (copy && device->tensorflow_gpu_device_info() && !on_host) {
407     device_context->CopyDeviceTensorToCPU(&tensor, "" /* tensor_name */, device,
408                                           copy, [done, copy](const Status& s) {
409                                             done(s);
410                                             delete copy;
411                                           });
412     return;
413   } else if (copy) {
414     std::memcpy(buffer->data(), DMAHelper::buffer(&tensor)->data(),
415                 buffer->size());
416     done(Status::OK());
417     delete copy;  // OK to delete; we have reffed the underlying TensorBuffer
418   } else {
419     done(Status::OK());
420   }
421 }
422 
TensorFromTransportOptions(Tensor * tensor,const::google::protobuf::Any & transport_options,Device * device,DeviceContext * device_context,bool on_host,StatusCallback done)423 void GdrMemoryManager::TensorFromTransportOptions(
424     Tensor* tensor, const ::google::protobuf::Any& transport_options,
425     Device* device, DeviceContext* device_context, bool on_host,
426     StatusCallback done) {
427   RemoteMemoryRegion remote_mr;
428   if (!transport_options.UnpackTo(&remote_mr)) {
429     done(errors::NotFound("No RDMA transport options found"));
430     return;
431   }
432 
433   rdma_cm_id* id = nullptr;
434   {
435     decltype(clients_)::iterator iter;
436     bool success;
437     mutex_lock l(client_mu_);
438     std::tie(iter, success) = clients_.insert(
439         std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()),
440                        RdmaEndpointPtr(nullptr, EndpointDeleter)));
441     if (success || iter->second.get() == nullptr) {
442       Status s =
443           CreateEndpoint(remote_mr.host(), remote_mr.port(), iter->second);
444       if (!s.ok()) {
445         done(s);
446         return;
447       }
448     }
449     id = iter->second.get();
450   }
451 
452   ibv_mr* mr = FindMemoryRegion(tensor);
453   const TensorBuffer* buffer = DMAHelper::buffer(tensor);
454 
455   const Tensor* copy = nullptr;
456 
457   if (mr == nullptr) {
458     AllocatorAttributes alloc_attrs;
459     alloc_attrs.set_gpu_compatible(true);
460     alloc_attrs.set_nic_compatible(true);
461     alloc_attrs.set_on_host(true);
462     Allocator* alloc = device->GetAllocator(alloc_attrs);
463     copy = new Tensor(alloc, tensor->dtype(), tensor->shape());
464 
465     mr = FindMemoryRegion(copy);
466     buffer = DMAHelper::buffer(copy);
467     if (mr == nullptr) {
468       done(errors::Unavailable("Cannot find pinned memory region"));
469       delete copy;
470       return;
471     }
472   }
473 
474   uint64_t start = Env::Default()->NowMicros();
475 
476   TensorKey tensor_key = remote_mr.tensor_key();
477 
478   StatusCallback callback = [done, copy, device, device_context, on_host,
479                              tensor, start, tensor_key](const Status& s) {
480     if (!s.ok()) {
481       done(s);
482       if (copy) {
483         delete copy;
484       }
485       return;
486     }
487 
488     VLOG(2) << "RDMA of tensor " << tensor_key << " of size "
489             << DMAHelper::buffer(tensor)->size() << " took "
490             << (Env::Default()->NowMicros() - start) << " micros";
491 
492     if (copy && device->tensorflow_gpu_device_info() && !on_host) {
493       device_context->CopyCPUTensorToDevice(copy, device, tensor,
494                                             [done, copy](const Status& s) {
495                                               done(s);
496                                               delete copy;
497                                             });
498     } else if (copy) {
499       std::memcpy(DMAHelper::buffer(tensor)->data(),
500                   DMAHelper::buffer(copy)->data(),
501                   DMAHelper::buffer(copy)->size());
502       done(s);
503       delete copy;
504     } else {
505       done(s);
506     }
507   };
508 
509   {
510     mutex_lock l(callback_mu_);
511     if (tensor_callbacks_.find(tensor_key) == std::end(tensor_callbacks_)) {
512       tensor_callbacks_.insert(std::make_pair(tensor_key, std::move(callback)));
513     } else {
514       done(errors::Unavailable("Received duplicated tensor key"));
515       if (copy) {
516         delete copy;
517       }
518       return;
519     }
520   }
521 
522   if (rdma_post_read(id, reinterpret_cast<void*>(tensor_key), buffer->data(),
523                      buffer->size(), mr, IBV_SEND_SIGNALED, remote_mr.addr(),
524                      remote_mr.rkey())) {
525     done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed"));
526     {
527       mutex_lock l(callback_mu_);
528       auto iter = tensor_callbacks_.find(tensor_key);
529       if (iter != std::end(tensor_callbacks_)) {
530         tensor_callbacks_.erase(iter);
531       }
532     }
533     if (copy) {
534       delete copy;
535     }
536   }
537 }
538 
CreateEndpoint(const string & host,const string & port,RdmaEndpointPtr & endpoint)539 Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port,
540                                         RdmaEndpointPtr& endpoint) {
541   rdma_addrinfo* addrinfo;
542   rdma_addrinfo hints = {};
543   hints.ai_port_space = RDMA_PS_TCP;
544   if (rdma_getaddrinfo(const_cast<char*>(host.c_str()),
545                        const_cast<char*>(port.c_str()), &hints, &addrinfo)) {
546     return errors::InvalidArgument(
547         strerror(errno), ": ", "cannot connect to rdma://", host, ":", port);
548   }
549 
550   ibv_qp_init_attr init_attr = {};
551   init_attr.qp_type = IBV_QPT_RC;
552   init_attr.cap.max_recv_wr = 1;
553   init_attr.cap.max_send_wr = 1024;
554   init_attr.cap.max_recv_sge = 1;
555   init_attr.cap.max_send_sge = 1;
556 
557   rdma_cm_id* id;
558   if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
559     rdma_freeaddrinfo(addrinfo);
560     return errors::Unavailable(strerror(errno), ": ",
561                                "cannot create endpoint to rdma://", host, ":",
562                                port);
563   }
564   rdma_freeaddrinfo(addrinfo);
565 
566   if (rdma_connect(id, nullptr)) {
567     rdma_destroy_ep(id);
568     return errors::Unavailable(strerror(errno), ": ",
569                                "cannot connect to rdma://", host, ":", port);
570   }
571 
572   LOG(INFO) << "RDMA endpoint connected to rdma://" << host << ":" << port;
573   endpoint = RdmaEndpointPtr(id, EndpointDeleter);
574   return Status::OK();
575 }
576 
FindMemoryRegion(const Tensor * tensor)577 ibv_mr* GdrMemoryManager::FindMemoryRegion(const Tensor* tensor) {
578   const void* addr = DMAHelper::buffer(tensor)->data();
579   mutex_lock l(alloc_mu_);
580   auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
581   if (iter == std::end(mrs_) || iter->get()->addr > addr) {
582     return nullptr;
583   } else {
584     return iter->get();
585   }
586 }
587 
InsertMemoryRegion(void * addr,size_t length,const std::string & allocator_name)588 void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length,
589                                           const std::string& allocator_name) {
590   if (length == 0) return;
591   ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length);
592   if (mr != nullptr) {
593     mutex_lock l(alloc_mu_);
594     auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
595     mrs_.insert(iter, {mr, &MRDeleter});
596   } else {
597     LOG(WARNING) << "Cannot register memory region allocated by "
598                  << allocator_name;
599   }
600 }
601 
EvictMemoryRegion(void * addr,size_t length)602 void GdrMemoryManager::EvictMemoryRegion(void* addr, size_t length) {
603   if (length == 0) return;
604   mutex_lock l(alloc_mu_);
605   auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
606   if (iter != std::end(mrs_) && iter->get()->addr == addr) {
607     mrs_.erase(iter);
608   } else {
609     LOG(WARNING) << "Failed to de-register memory region";
610   }
611 }
612 
613 }  // namespace
614 
CreateRemoteMemoryManager(const string & host,const string & port)615 RemoteMemoryManager* CreateRemoteMemoryManager(const string& host,
616                                                const string& port) {
617   return new GdrMemoryManager(host, port);
618 }
619 
620 }  // namespace tensorflow
621 
622 #endif  // TENSORFLOW_USE_GDR
623