• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7 
8  * http://www.apache.org/licenses/LICENSE-2.0
9 
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15 */
16 #include "minddata/dataset/engine/cache/cache_server.h"
17 #include <algorithm>
18 #include <functional>
19 #include <limits>
20 #include <vector>
21 #include "minddata/dataset/include/dataset/constants.h"
22 #include "minddata/dataset/engine/cache/cache_ipc.h"
23 #include "minddata/dataset/engine/cache/cache_service.h"
24 #include "minddata/dataset/engine/cache/cache_request.h"
25 #include "minddata/dataset/util/bit.h"
26 #include "minddata/dataset/util/path.h"
27 #include "minddata/dataset/util/random.h"
28 #ifdef CACHE_LOCAL_CLIENT
29 #include "minddata/dataset/util/sig_handler.h"
30 #endif
31 
32 namespace mindspore {
33 namespace dataset {
34 CacheServer *CacheServer::instance_ = nullptr;
35 std::once_flag CacheServer::init_instance_flag_;
DoServiceStart()36 Status CacheServer::DoServiceStart() {
37 #ifdef CACHE_LOCAL_CLIENT
38   // We need to destroy the shared memory if user hits Control-C
39   RegisterHandlers();
40 #endif
41   if (!top_.empty()) {
42     Path spill(top_);
43     RETURN_IF_NOT_OK(spill.CreateDirectories());
44     MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
45   }
46   RETURN_IF_NOT_OK(vg_.ServiceStart());
47   auto num_numa_nodes = GetNumaNodeCount();
48   // If we link with numa library. Set default memory policy.
49   // If we don't pin thread to cpu, then use up all memory controllers to maximize
50   // memory bandwidth.
51   RETURN_IF_NOT_OK(
52     CacheServerHW::SetDefaultMemoryPolicy(numa_affinity_ ? CachePoolPolicy::kLocal : CachePoolPolicy::kInterleave));
53   auto my_node = hw_info_->GetMyNode();
54   MS_LOG(DEBUG) << "Cache server is running on numa node " << my_node;
55   // There will be some threads working on the grpc queue and
56   // some number of threads working on the CacheServerRequest queue.
57   // Like a connector object we will set up the same number of queues but
58   // we do not need to preserve any order. We will set the capacity of
59   // each queue to be 64 since we are just pushing memory pointers which
60   // is only 8 byte each.
61   const int32_t kQueCapacity = 64;
62   // This is the request queue from the client
63   cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>();
64   cache_q_->Init(num_workers_, kQueCapacity);
65   // We will match the number of grpc workers with the number of server workers.
66   // But technically they don't have to be the same.
67   num_grpc_workers_ = num_workers_;
68   MS_LOG(DEBUG) << "Number of gprc workers is set to " << num_grpc_workers_;
69   RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
70   // Start the comm layer
71   try {
72     comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_);
73     RETURN_IF_NOT_OK(comm_layer_->Run());
74   } catch (const std::exception &e) {
75     RETURN_STATUS_UNEXPECTED(e.what());
76   }
77 #ifdef CACHE_LOCAL_CLIENT
78   RETURN_IF_NOT_OK(CachedSharedMemory::CreateArena(&shm_, port_, shared_memory_sz_in_gb_));
79   // Bring up a thread to monitor the unix socket in case it is removed. But it must be done
80   // after we have created the unix socket.
81   auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
82   RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
83 #endif
84   // Spawn a few threads to serve the real request.
85   auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
86   for (auto i = 0; i < num_workers_; ++i) {
87     Task *pTask;
88     RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i), &pTask));
89     // Save a copy of the pointer to the underlying Task object. We may dynamically change their affinity if needed.
90     numa_tasks_.emplace(i, pTask);
91     // Spread out all the threads to all the numa nodes if needed
92     if (IsNumaAffinityOn()) {
93       auto numa_id = i % num_numa_nodes;
94       RETURN_IF_NOT_OK(SetAffinity(*pTask, numa_id));
95     }
96   }
97   // Finally loop forever to handle the request.
98   auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1);
99   for (auto i = 0; i < num_grpc_workers_; ++i) {
100     Task *pTask;
101     RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i), &pTask));
102     // All these grpc workers will be allocated to the same node which is where we allocate all those free tag
103     // memory.
104     if (IsNumaAffinityOn()) {
105       RETURN_IF_NOT_OK(SetAffinity(*pTask, i % num_numa_nodes));
106     }
107   }
108   return Status::OK();
109 }
110 
DoServiceStop()111 Status CacheServer::DoServiceStop() {
112   Status rc;
113   Status rc2;
114   // First stop all the threads.
115   RETURN_IF_NOT_OK(vg_.ServiceStop());
116   // Clean up all the caches if any.
117   UniqueLock lck(&rwLock_);
118   auto it = all_caches_.begin();
119   while (it != all_caches_.end()) {
120     auto cs = std::move(it->second);
121     rc2 = cs->ServiceStop();
122     if (rc2.IsError()) {
123       rc = rc2;
124     }
125     ++it;
126   }
127   // Also remove the path we use to generate ftok.
128   Path p(PortToUnixSocketPath(port_));
129   (void)p.Remove();
130   // Finally wake up cache_admin if it is waiting
131   for (int32_t qID : shutdown_qIDs_) {
132     SharedMessage msg(qID);
133     RETURN_IF_NOT_OK(msg.SendStatus(Status::OK()));
134     msg.RemoveResourcesOnExit();
135     // Let msg goes out of scope which will destroy the queue.
136   }
137   return rc;
138 }
139 
GetService(connection_id_type id) const140 CacheService *CacheServer::GetService(connection_id_type id) const {
141   auto it = all_caches_.find(id);
142   if (it != all_caches_.end()) {
143     return it->second.get();
144   }
145   return nullptr;
146 }
147 
148 // We would like to protect ourselves from over allocating too much. We will go over existing cache
149 // and calculate how much we have consumed so far.
GlobalMemoryCheck(uint64_t cache_mem_sz)150 Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
151   auto end = all_caches_.end();
152   auto it = all_caches_.begin();
153   auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
154   int64_t max_avail = avail_mem;
155   while (it != end) {
156     auto &cs = it->second;
157     CacheService::ServiceStat stat;
158     RETURN_IF_NOT_OK(cs->GetStat(&stat));
159     int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
160     max_avail -= mem_consumed;
161     if (max_avail <= 0) {
162       return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
163     }
164     ++it;
165   }
166 
167   // If we have some cache using some memory already, make a reasonable decision if we should return
168   // out of memory.
169   if (max_avail < avail_mem) {
170     int64_t req_mem = cache_mem_sz * 1048576L;  // It is in MB unit.
171     if (req_mem > max_avail) {
172       return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
173     } else if (req_mem == 0) {
174       // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
175       // 85% of our limit, fail this request.
176       if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= kMemoryBottomLineForNewService) {
177         return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
178       }
179     }
180   }
181   return Status::OK();
182 }
183 
CreateService(CacheRequest * rq,CacheReply * reply)184 Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
185   CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info");
186   std::string cookie;
187   int32_t client_id;
188   auto session_id = rq->connection_info().session_id();
189   auto crc = rq->connection_info().crc();
190 
191   // Before allowing the creation, make sure the session had already been created by the user
192   // Our intention is to add this cache to the active sessions list so leave the list locked during
193   // this entire function.
194   UniqueLock sess_lck(&sessions_lock_);
195   auto session_it = active_sessions_.find(session_id);
196   if (session_it == active_sessions_.end()) {
197     RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!");
198   }
199 
200   // We concat both numbers to form the internal connection id.
201   auto connection_id = GetConnectionID(session_id, crc);
202   CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache");
203   auto &create_cache_buf = rq->buf_data(0);
204   auto p = flatbuffers::GetRoot<CreateCacheRequestMsg>(create_cache_buf.data());
205   auto flag = static_cast<CreateCacheRequest::CreateCacheFlag>(p->flag());
206   auto cache_mem_sz = p->cache_mem_sz();
207   // We can't do spilling unless this server is setup with a spill path in the first place
208   bool spill =
209     (flag & CreateCacheRequest::CreateCacheFlag::kSpillToDisk) == CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
210   bool generate_id =
211     (flag & CreateCacheRequest::CreateCacheFlag::kGenerateRowId) == CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
212   if (spill && top_.empty()) {
213     RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
214   }
215   // Before creating the cache, first check if this is a request for a shared usage of an existing cache
216   // If two CreateService come in with identical connection_id, we need to serialize the create.
217   // The first create will be successful and be given a special cookie.
218   UniqueLock lck(&rwLock_);
219   bool duplicate = false;
220   CacheService *curr_cs = GetService(connection_id);
221   if (curr_cs != nullptr) {
222     duplicate = true;
223     client_id = curr_cs->num_clients_.fetch_add(1);
224     MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " +
225                       std::to_string(connection_id) + " to create cache service";
226   }
227   // Early exit if we are doing global shutdown
228   if (global_shutdown_) {
229     return Status::OK();
230   }
231 
232   if (!duplicate) {
233     RETURN_IF_NOT_OK(GlobalMemoryCheck(cache_mem_sz));
234     std::unique_ptr<CacheService> cs;
235     try {
236       cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
237       RETURN_IF_NOT_OK(cs->ServiceStart());
238       cookie = cs->cookie();
239       client_id = cs->num_clients_.fetch_add(1);
240       all_caches_.emplace(connection_id, std::move(cs));
241     } catch (const std::bad_alloc &e) {
242       return Status(StatusCode::kMDOutOfMemory);
243     }
244   }
245 
246   // Shuffle the worker threads. But we need to release the locks or we will deadlock when calling
247   // the following function
248   lck.Unlock();
249   sess_lck.Unlock();
250   auto numa_id = client_id % GetNumaNodeCount();
251   std::vector<cpu_id_t> cpu_list = hw_info_->GetCpuList(numa_id);
252   // Send back the data
253   flatbuffers::FlatBufferBuilder fbb;
254   flatbuffers::Offset<flatbuffers::String> off_cookie;
255   flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list;
256   off_cookie = fbb.CreateString(cookie);
257   off_cpu_list = fbb.CreateVector(cpu_list);
258   CreateCacheReplyMsgBuilder bld(fbb);
259   bld.add_connection_id(connection_id);
260   bld.add_cookie(off_cookie);
261   bld.add_client_id(client_id);
262   // The last thing we send back is a set of cpu id that we suggest the client should bind itself to
263   bld.add_cpu_id(off_cpu_list);
264   auto off = bld.Finish();
265   fbb.Finish(off);
266   reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
267   // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
268   // treat it as OK.
269   return duplicate ? Status(StatusCode::kMDDuplicateKey) : Status::OK();
270 }
271 
DestroyCache(CacheRequest * rq)272 Status CacheServer::DestroyCache(CacheRequest *rq) {
273   // We need a strong lock to protect the map.
274   UniqueLock lck(&rwLock_);
275   auto id = rq->connection_id();
276   CacheService *cs = GetService(id);
277   // it is already destroyed. Ignore it.
278   if (cs != nullptr) {
279     MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
280     // std::map will invoke the destructor of CacheService. So we don't need to do anything here.
281     auto n = all_caches_.erase(id);
282     if (n == 0) {
283       // It has been destroyed by another duplicate request.
284       MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service";
285     }
286   }
287   // We aren't touching the session list even though we may be dropping the last remaining cache of a session.
288   // Leave that to be done by the drop session command.
289   return Status::OK();
290 }
291 
CacheRow(CacheRequest * rq,CacheReply * reply)292 Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
293   auto connection_id = rq->connection_id();
294   // Hold the shared lock to prevent the cache from being dropped.
295   SharedLock lck(&rwLock_);
296   CacheService *cs = GetService(connection_id);
297   if (cs == nullptr) {
298     std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
299     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
300   } else {
301     auto sz = rq->buf_data_size();
302     std::vector<const void *> buffers;
303     buffers.reserve(sz);
304     // First piece of data is the cookie and is required
305     CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
306     auto &cookie = rq->buf_data(0);
307     // Only if the cookie matches, we can accept insert into this cache that has a build phase
308     if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
309       // Push the address of each buffer (in the form of std::string coming in from protobuf) into
310       // a vector of buffer
311       for (auto i = 1; i < sz; ++i) {
312         buffers.push_back(rq->buf_data(i).data());
313       }
314       row_id_type id = -1;
315       // We will allocate the memory the same numa node this thread is bound to.
316       RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id));
317       reply->set_result(std::to_string(id));
318     } else {
319       return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
320     }
321   }
322   return Status::OK();
323 }
324 
FastCacheRow(CacheRequest * rq,CacheReply * reply)325 Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
326   auto connection_id = rq->connection_id();
327   auto client_id = rq->client_id();
328   CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
329   // Hold the shared lock to prevent the cache from being dropped.
330   SharedLock lck(&rwLock_);
331   CacheService *cs = GetService(connection_id);
332   auto *base = SharedMemoryBaseAddr();
333   // Ensure we got 3 pieces of data coming in
334   constexpr int32_t kMinBufDataSize = 3;
335   CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
336   // First one is cookie, followed by data address and then size.
337   enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
338   // First piece of data is the cookie and is required
339   auto &cookie = rq->buf_data(BufDataIndex::kCookie);
340   // Second piece of data is the address where we can find the serialized data
341   auto addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
342   auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
343   // Third piece of data is the size of the serialized data that we need to transfer
344   auto sz = strtoll(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal);
345   // Successful or not, we need to free the memory on exit.
346   Status rc;
347   if (cs == nullptr) {
348     std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
349     rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
350   } else {
351     // Only if the cookie matches, we can accept insert into this cache that has a build phase
352     if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
353       row_id_type id = -1;
354       ReadableSlice src(p, sz);
355       // We will allocate the memory the same numa node this thread is bound to.
356       rc = cs->FastCacheRow(src, &id);
357       reply->set_result(std::to_string(id));
358     } else {
359       auto state = cs->GetState();
360       if (state != CacheServiceState::kFetchPhase) {
361         rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
362                     "Cache service is not in fetch phase. The current phase is " +
363                       std::to_string(static_cast<int8_t>(state)) + ". Client id: " + std::to_string(client_id));
364       } else {
365         rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
366                     "Cookie mismatch. Client id: " + std::to_string(client_id));
367       }
368     }
369   }
370   // Return the block to the shared memory only if it is not internal request.
371   if (static_cast<BaseRequest::RequestType>(rq->type()) == BaseRequest::RequestType::kCacheRow) {
372     DeallocateSharedMemory(client_id, p);
373   }
374   return rc;
375 }
376 
InternalCacheRow(CacheRequest * rq,CacheReply * reply)377 Status CacheServer::InternalCacheRow(CacheRequest *rq, CacheReply *reply) {
378   // Look into the flag to see where we can find the data and call the appropriate method.
379   auto flag = rq->flag();
380   Status rc;
381   if (BitTest(flag, kDataIsInSharedMemory)) {
382     rc = FastCacheRow(rq, reply);
383     // This is an internal request and is not tied to rpc. But need to post because there
384     // is a thread waiting on the completion of this request.
385     try {
386       constexpr int32_t kBatchWaitIdx = 3;
387       // Fourth piece of the data is the address of the BatchWait ptr
388       int64_t addr = strtol(rq->buf_data(kBatchWaitIdx).data(), nullptr, kDecimal);
389       auto *bw = reinterpret_cast<BatchWait *>(addr);
390       // Check if the object is still around.
391       auto bwObj = bw->GetBatchWait();
392       if (bwObj.lock()) {
393         RETURN_IF_NOT_OK(bw->Set(rc));
394       }
395     } catch (const std::exception &e) {
396       RETURN_STATUS_UNEXPECTED(e.what());
397     }
398   } else {
399     rc = CacheRow(rq, reply);
400   }
401   return rc;
402 }
403 
InternalFetchRow(CacheRequest * rq)404 Status CacheServer::InternalFetchRow(CacheRequest *rq) {
405   auto connection_id = rq->connection_id();
406   SharedLock lck(&rwLock_);
407   CacheService *cs = GetService(connection_id);
408   Status rc;
409   if (cs == nullptr) {
410     std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
411     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
412   }
413   // First piece is a flatbuffer containing row fetch information, second piece is the address of the BatchWait ptr
414   enum BufDataIndex : uint8_t { kFetchRowMsg = 0, kBatchWait = 1 };
415   rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(BufDataIndex::kFetchRowMsg).data()));
416   // This is an internal request and is not tied to rpc. But need to post because there
417   // is a thread waiting on the completion of this request.
418   try {
419     int64_t addr = strtol(rq->buf_data(BufDataIndex::kBatchWait).data(), nullptr, kDecimal);
420     auto *bw = reinterpret_cast<BatchWait *>(addr);
421     // Check if the object is still around.
422     auto bwObj = bw->GetBatchWait();
423     if (bwObj.lock()) {
424       RETURN_IF_NOT_OK(bw->Set(rc));
425     }
426   } catch (const std::exception &e) {
427     RETURN_STATUS_UNEXPECTED(e.what());
428   }
429   return rc;
430 }
431 
BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> & fbb,WritableSlice * out)432 Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
433   RETURN_UNEXPECTED_IF_NULL(out);
434   auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
435   const auto num_elements = p->rows()->size();
436   auto connection_id = p->connection_id();
437   auto batch_wait = std::make_shared<BatchWait>(num_elements);
438   int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
439   auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
440   offset_array[0] = data_offset;
441   for (uint32_t i = 0; i < num_elements; ++i) {
442     auto data_locator = p->rows()->Get(i);
443     auto node_id = data_locator->node_id();
444     size_t sz = data_locator->size();
445     void *source_addr = reinterpret_cast<void *>(data_locator->addr());
446     auto key = data_locator->key();
447     // Please read the comment in CacheServer::BatchFetchRows where we allocate
448     // the buffer big enough so each thread (which we are going to dispatch) will
449     // not run into false sharing problem. We are going to round up sz to 4k.
450     auto sz_4k = round_up_4K(sz);
451     offset_array[i + 1] = offset_array[i] + sz_4k;
452     if (sz > 0) {
453       WritableSlice row_data(*out, offset_array[i], sz);
454       // Get a request and send to the proper worker (at some numa node) to do the fetch.
455       worker_id_t worker_id = IsNumaAffinityOn() ? GetWorkerByNumaId(node_id) : GetRandomWorker();
456       CacheServerRequest *cache_rq;
457       RETURN_IF_NOT_OK(GetFreeRequestTag(&cache_rq));
458       // Set up all the necessarily field.
459       cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
460       cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
461       cache_rq->rq_.set_connection_id(connection_id);
462       cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
463       auto dest_addr = row_data.GetMutablePointer();
464       flatbuffers::FlatBufferBuilder fb2;
465       FetchRowMsgBuilder bld(fb2);
466       bld.add_key(key);
467       bld.add_size(sz);
468       bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
469       bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
470       auto offset = bld.Finish();
471       fb2.Finish(offset);
472       cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
473       cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(batch_wait.get())));
474       RETURN_IF_NOT_OK(PushRequest(worker_id, cache_rq));
475     } else {
476       // Nothing to fetch but we still need to post something back into the wait area.
477       RETURN_IF_NOT_OK(batch_wait->Set(Status::OK()));
478     }
479   }
480   // Now wait for all of them to come back.
481   RETURN_IF_NOT_OK(batch_wait->Wait());
482   // Return the result
483   return batch_wait->GetRc();
484 }
485 
BatchFetchRows(CacheRequest * rq,CacheReply * reply)486 Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
487   auto connection_id = rq->connection_id();
488   auto client_id = rq->client_id();
489   // Hold the shared lock to prevent the cache from being dropped.
490   SharedLock lck(&rwLock_);
491   CacheService *cs = GetService(connection_id);
492   if (cs == nullptr) {
493     std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
494     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
495   } else {
496     CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id");
497     auto &row_id_buf = rq->buf_data(0);
498     auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data());
499     std::vector<row_id_type> row_id;
500     auto sz = p->row_id()->size();
501     row_id.reserve(sz);
502     for (uint32_t i = 0; i < sz; ++i) {
503       row_id.push_back(p->row_id()->Get(i));
504     }
505     std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
506     RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb));
507     // Let go of the shared lock. We don't need to interact with the CacheService anymore.
508     // We shouldn't be holding any lock while we can wait for a long time for the rows to come back.
509     lck.Unlock();
510     auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
511     int64_t mem_sz = sizeof(int64_t) * (sz + 1);
512     for (auto i = 0; i < sz; ++i) {
513       auto row_sz = locator->rows()->Get(i)->size();
514       // row_sz is the size of the cached data. Later we will spawn multiple threads
515       // each of which will copy the data into either shared memory or protobuf concurrently but
516       // to different region.
517       // To avoid false sharing, we will bump up row_sz to be a multiple of 4k, i.e. 4096 bytes
518       row_sz = round_up_4K(row_sz);
519       mem_sz += row_sz;
520     }
521     auto client_flag = rq->flag();
522     bool local_client = BitTest(client_flag, kLocalClientSupport);
523     // For large amount data to be sent back, we will use shared memory provided it is a local
524     // client that has local bypass support
525     bool local_bypass = local_client ? (mem_sz >= kLocalByPassThreshold) : false;
526     reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0);
527     if (local_bypass) {
528       // We will use shared memory
529       auto *base = SharedMemoryBaseAddr();
530       void *q = nullptr;
531       RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, mem_sz, &q));
532       WritableSlice dest(q, mem_sz);
533       Status rc = BatchFetch(fbb, &dest);
534       if (rc.IsError()) {
535         DeallocateSharedMemory(client_id, q);
536         return rc;
537       }
538       // We can't return the absolute address which makes no sense to the client.
539       // Instead we return the difference.
540       auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base);
541       reply->set_result(std::to_string(difference));
542     } else {
543       // We are going to use std::string to allocate and hold the result which will be eventually
544       // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
545       // to minimize memory copy.
546       std::string mem;
547       try {
548         mem.resize(mem_sz);
549         CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error");
550       } catch (const std::bad_alloc &e) {
551         return Status(StatusCode::kMDOutOfMemory);
552       }
553       WritableSlice dest(mem.data(), mem_sz);
554       RETURN_IF_NOT_OK(BatchFetch(fbb, &dest));
555       reply->set_result(std::move(mem));
556     }
557   }
558   return Status::OK();
559 }
560 
GetStat(CacheRequest * rq,CacheReply * reply)561 Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) {
562   auto connection_id = rq->connection_id();
563   // Hold the shared lock to prevent the cache from being dropped.
564   SharedLock lck(&rwLock_);
565   CacheService *cs = GetService(connection_id);
566   if (cs == nullptr) {
567     std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
568     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
569   } else {
570     CacheService::ServiceStat svc_stat;
571     RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
572     flatbuffers::FlatBufferBuilder fbb;
573     ServiceStatMsgBuilder bld(fbb);
574     bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
575     bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
576     bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz);
577     bld.add_num_numa_hit(svc_stat.stat_.num_numa_hit);
578     bld.add_max_row_id(svc_stat.stat_.max_key);
579     bld.add_min_row_id(svc_stat.stat_.min_key);
580     bld.add_state(svc_stat.state_);
581     auto offset = bld.Finish();
582     fbb.Finish(offset);
583     reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
584   }
585   return Status::OK();
586 }
587 
CacheSchema(CacheRequest * rq)588 Status CacheServer::CacheSchema(CacheRequest *rq) {
589   auto connection_id = rq->connection_id();
590   // Hold the shared lock to prevent the cache from being dropped.
591   SharedLock lck(&rwLock_);
592   CacheService *cs = GetService(connection_id);
593   if (cs == nullptr) {
594     std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
595     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
596   } else {
597     CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information");
598     auto &create_schema_buf = rq->buf_data(0);
599     RETURN_IF_NOT_OK(cs->CacheSchema(create_schema_buf.data(), create_schema_buf.size()));
600   }
601   return Status::OK();
602 }
603 
FetchSchema(CacheRequest * rq,CacheReply * reply)604 Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) {
605   auto connection_id = rq->connection_id();
606   // Hold the shared lock to prevent the cache from being dropped.
607   SharedLock lck(&rwLock_);
608   CacheService *cs = GetService(connection_id);
609   if (cs == nullptr) {
610     std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
611     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
612   } else {
613     // We are going to use std::string to allocate and hold the result which will be eventually
614     // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
615     // to minimize memory copy.
616     std::string mem;
617     RETURN_IF_NOT_OK(cs->FetchSchema(&mem));
618     reply->set_result(std::move(mem));
619   }
620   return Status::OK();
621 }
622 
BuildPhaseDone(CacheRequest * rq)623 Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
624   auto connection_id = rq->connection_id();
625   // Hold the shared lock to prevent the cache from being dropped.
626   SharedLock lck(&rwLock_);
627   CacheService *cs = GetService(connection_id);
628   if (cs == nullptr) {
629     std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
630     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
631   } else {
632     // First piece of data is the cookie
633     CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
634     auto &cookie = rq->buf_data(0);
635     // We can only allow to switch phase if the cookie match.
636     if (cookie == cs->cookie()) {
637       RETURN_IF_NOT_OK(cs->BuildPhaseDone());
638     } else {
639       return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
640     }
641   }
642   return Status::OK();
643 }
644 
GetCacheMissKeys(CacheRequest * rq,CacheReply * reply)645 Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) {
646   auto connection_id = rq->connection_id();
647   // Hold the shared lock to prevent the cache from being dropped.
648   SharedLock lck(&rwLock_);
649   CacheService *cs = GetService(connection_id);
650   if (cs == nullptr) {
651     std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
652     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
653   } else {
654     std::vector<row_id_type> gap;
655     RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap));
656     flatbuffers::FlatBufferBuilder fbb;
657     auto off_t = fbb.CreateVector(gap);
658     TensorRowIdsBuilder bld(fbb);
659     bld.add_row_id(off_t);
660     auto off = bld.Finish();
661     fbb.Finish(off);
662     reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
663   }
664   return Status::OK();
665 }
666 
GenerateClientSessionID(session_id_type session_id,CacheReply * reply)667 inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *reply) {
668   reply->set_result(std::to_string(session_id));
669   MS_LOG(INFO) << "Server generated new session id " << session_id;
670   return Status::OK();
671 }
672 
ToggleWriteMode(CacheRequest * rq)673 Status CacheServer::ToggleWriteMode(CacheRequest *rq) {
674   auto connection_id = rq->connection_id();
675   // Hold the shared lock to prevent the cache from being dropped.
676   SharedLock lck(&rwLock_);
677   CacheService *cs = GetService(connection_id);
678   if (cs == nullptr) {
679     std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
680     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
681   } else {
682     // First piece of data is the on/off flag
683     CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag");
684     const auto &action = rq->buf_data(0);
685     bool on_off = false;
686     if (strcmp(action.data(), "on") == 0) {
687       on_off = true;
688     } else if (strcmp(action.data(), "off") == 0) {
689       on_off = false;
690     } else {
691       RETURN_STATUS_UNEXPECTED("Unknown request: " + action);
692     }
693     RETURN_IF_NOT_OK(cs->ToggleWriteMode(on_off));
694   }
695   return Status::OK();
696 }
697 
ListSessions(CacheReply * reply)698 Status CacheServer::ListSessions(CacheReply *reply) {
699   SharedLock sess_lck(&sessions_lock_);
700   SharedLock lck(&rwLock_);
701   flatbuffers::FlatBufferBuilder fbb;
702   std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector;
703   for (auto const &current_session_id : active_sessions_) {
704     bool found = false;
705     for (auto const &it : all_caches_) {
706       auto current_conn_id = it.first;
707       if (GetSessionID(current_conn_id) == current_session_id) {
708         found = true;
709         auto &cs = it.second;
710         CacheService::ServiceStat svc_stat;
711         RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
712         auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached,
713                                                   svc_stat.stat_.average_cache_sz, svc_stat.stat_.num_numa_hit,
714                                                   svc_stat.stat_.min_key, svc_stat.stat_.max_key, svc_stat.state_);
715         auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats);
716         session_msgs_vector.push_back(current_session_info);
717       }
718     }
719     if (!found) {
720       // If there is no cache created yet, assign a connection id of 0 along with empty stats
721       auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0);
722       auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats);
723       session_msgs_vector.push_back(current_session_info);
724     }
725   }
726   flatbuffers::Offset<flatbuffers::String> spill_dir;
727   spill_dir = fbb.CreateString(top_);
728   auto session_msgs = fbb.CreateVector(session_msgs_vector);
729   ListSessionsMsgBuilder s_builder(fbb);
730   s_builder.add_sessions(session_msgs);
731   s_builder.add_num_workers(num_workers_);
732   s_builder.add_log_level(log_level_);
733   s_builder.add_spill_dir(spill_dir);
734   auto offset = s_builder.Finish();
735   fbb.Finish(offset);
736   reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
737   return Status::OK();
738 }
739 
ConnectReset(CacheRequest * rq)740 Status CacheServer::ConnectReset(CacheRequest *rq) {
741   auto connection_id = rq->connection_id();
742   // Hold the shared lock to prevent the cache from being dropped.
743   SharedLock lck(&rwLock_);
744   CacheService *cs = GetService(connection_id);
745   if (cs == nullptr) {
746     std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
747     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
748   } else {
749     auto client_id = rq->client_id();
750     MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects";
751     cs->num_clients_--;
752   }
753   return Status::OK();
754 }
755 
BatchCacheRows(CacheRequest * rq)756 Status CacheServer::BatchCacheRows(CacheRequest *rq) {
757   // First one is cookie, followed by address and then size.
758   enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
759   constexpr int32_t kExpectedBufDataSize = 3;
760   CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == kExpectedBufDataSize, "Expect three pieces of data");
761   try {
762     auto &cookie = rq->buf_data(BufDataIndex::kCookie);
763     auto connection_id = rq->connection_id();
764     auto client_id = rq->client_id();
765     int64_t offset_addr;
766     int32_t num_elem;
767     auto *base = SharedMemoryBaseAddr();
768     offset_addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
769     auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
770     num_elem = static_cast<int32_t>(strtol(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal));
771     auto batch_wait = std::make_shared<BatchWait>(num_elem);
772     // Get a set of free request and push into the queues.
773     for (auto i = 0; i < num_elem; ++i) {
774       auto start = reinterpret_cast<int64_t>(p);
775       auto msg = GetTensorRowHeaderMsg(p);
776       p += msg->size_of_this();
777       for (auto k = 0; k < msg->column()->size(); ++k) {
778         p += msg->data_sz()->Get(k);
779       }
780       CacheServerRequest *cache_rq;
781       RETURN_IF_NOT_OK(GetFreeRequestTag(&cache_rq));
782       // Fill in details.
783       cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
784       cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
785       cache_rq->rq_.set_connection_id(connection_id);
786       cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
787       cache_rq->rq_.set_client_id(client_id);
788       cache_rq->rq_.set_flag(kDataIsInSharedMemory);
789       cache_rq->rq_.add_buf_data(cookie);
790       cache_rq->rq_.add_buf_data(std::to_string(start - reinterpret_cast<int64_t>(base)));
791       cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(p - start)));
792       cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(batch_wait.get())));
793       RETURN_IF_NOT_OK(PushRequest(GetRandomWorker(), cache_rq));
794     }
795     // Now wait for all of them to come back.
796     RETURN_IF_NOT_OK(batch_wait->Wait());
797     // Return the result
798     return batch_wait->GetRc();
799   } catch (const std::exception &e) {
800     RETURN_STATUS_UNEXPECTED(e.what());
801   }
802   return Status::OK();
803 }
804 
ProcessRowRequest(CacheServerRequest * cache_req,bool * internal_request)805 Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request) {
806   auto &rq = cache_req->rq_;
807   auto &reply = cache_req->reply_;
808   switch (cache_req->type_) {
809     case BaseRequest::RequestType::kCacheRow: {
810       // Look into the flag to see where we can find the data and call the appropriate method.
811       if (BitTest(rq.flag(), kDataIsInSharedMemory)) {
812         cache_req->rc_ = FastCacheRow(&rq, &reply);
813       } else {
814         cache_req->rc_ = CacheRow(&rq, &reply);
815       }
816       break;
817     }
818     case BaseRequest::RequestType::kInternalCacheRow: {
819       *internal_request = true;
820       cache_req->rc_ = InternalCacheRow(&rq, &reply);
821       break;
822     }
823     case BaseRequest::RequestType::kBatchCacheRows: {
824       cache_req->rc_ = BatchCacheRows(&rq);
825       break;
826     }
827     case BaseRequest::RequestType::kBatchFetchRows: {
828       cache_req->rc_ = BatchFetchRows(&rq, &reply);
829       break;
830     }
831     case BaseRequest::RequestType::kInternalFetchRow: {
832       *internal_request = true;
833       cache_req->rc_ = InternalFetchRow(&rq);
834       break;
835     }
836     default:
837       std::string errMsg("Internal error, request type is not row request: ");
838       errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
839       cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
840   }
841   return Status::OK();
842 }
843 
ProcessSessionRequest(CacheServerRequest * cache_req)844 Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) {
845   auto &rq = cache_req->rq_;
846   auto &reply = cache_req->reply_;
847   switch (cache_req->type_) {
848     case BaseRequest::RequestType::kDropSession: {
849       cache_req->rc_ = DestroySession(&rq);
850       break;
851     }
852     case BaseRequest::RequestType::kGenerateSessionId: {
853       cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
854       break;
855     }
856     case BaseRequest::RequestType::kListSessions: {
857       cache_req->rc_ = ListSessions(&reply);
858       break;
859     }
860     default:
861       std::string errMsg("Internal error, request type is not session request: ");
862       errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
863       cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
864   }
865   return Status::OK();
866 }
867 
ProcessAdminRequest(CacheServerRequest * cache_req)868 Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) {
869   auto &rq = cache_req->rq_;
870   auto &reply = cache_req->reply_;
871   switch (cache_req->type_) {
872     case BaseRequest::RequestType::kCreateCache: {
873       cache_req->rc_ = CreateService(&rq, &reply);
874       break;
875     }
876     case BaseRequest::RequestType::kGetCacheMissKeys: {
877       cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
878       break;
879     }
880     case BaseRequest::RequestType::kDestroyCache: {
881       cache_req->rc_ = DestroyCache(&rq);
882       break;
883     }
884     case BaseRequest::RequestType::kGetStat: {
885       cache_req->rc_ = GetStat(&rq, &reply);
886       break;
887     }
888     case BaseRequest::RequestType::kCacheSchema: {
889       cache_req->rc_ = CacheSchema(&rq);
890       break;
891     }
892     case BaseRequest::RequestType::kFetchSchema: {
893       cache_req->rc_ = FetchSchema(&rq, &reply);
894       break;
895     }
896     case BaseRequest::RequestType::kBuildPhaseDone: {
897       cache_req->rc_ = BuildPhaseDone(&rq);
898       break;
899     }
900     case BaseRequest::RequestType::kAllocateSharedBlock: {
901       cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
902       break;
903     }
904     case BaseRequest::RequestType::kFreeSharedBlock: {
905       cache_req->rc_ = FreeSharedMemory(&rq);
906       break;
907     }
908     case BaseRequest::RequestType::kStopService: {
909       // This command shutdowns everything.
910       // But we first reply back to the client that we receive the request.
911       // The real shutdown work will be done by the caller.
912       cache_req->rc_ = AcknowledgeShutdown(cache_req);
913       break;
914     }
915     case BaseRequest::RequestType::kHeartBeat: {
916       cache_req->rc_ = Status::OK();
917       break;
918     }
919     case BaseRequest::RequestType::kToggleWriteMode: {
920       cache_req->rc_ = ToggleWriteMode(&rq);
921       break;
922     }
923     case BaseRequest::RequestType::kConnectReset: {
924       cache_req->rc_ = ConnectReset(&rq);
925       break;
926     }
927     case BaseRequest::RequestType::kGetCacheState: {
928       cache_req->rc_ = GetCacheState(&rq, &reply);
929       break;
930     }
931     default:
932       std::string errMsg("Internal error, request type is not admin request: ");
933       errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
934       cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
935   }
936   return Status::OK();
937 }
938 
ProcessRequest(CacheServerRequest * cache_req)939 Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
940   bool internal_request = false;
941 
942   // Except for creating a new session, we expect cs is not null.
943   if (cache_req->IsRowRequest()) {
944     RETURN_IF_NOT_OK(ProcessRowRequest(cache_req, &internal_request));
945   } else if (cache_req->IsSessionRequest()) {
946     RETURN_IF_NOT_OK(ProcessSessionRequest(cache_req));
947   } else if (cache_req->IsAdminRequest()) {
948     RETURN_IF_NOT_OK(ProcessAdminRequest(cache_req));
949   } else {
950     std::string errMsg("Unknown request type : ");
951     errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
952     cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
953   }
954 
955   // Notify it is done, and move on to the next request.
956   Status2CacheReply(cache_req->rc_, &cache_req->reply_);
957   cache_req->st_ = CacheServerRequest::STATE::FINISH;
958   // We will re-tag the request back to the grpc queue. Once it comes back from the client,
959   // the CacheServerRequest, i.e. the pointer cache_req, will be free
960   if (!internal_request && !global_shutdown_) {
961     cache_req->responder_.Finish(cache_req->reply_, grpc::Status::OK, cache_req);
962   } else {
963     // We can free up the request now.
964     RETURN_IF_NOT_OK(ReturnRequestTag(cache_req));
965   }
966   return Status::OK();
967 }
968 
969 /// \brief This is the main loop the cache server thread(s) are running.
970 /// Each thread will pop a request and send the result back to the client using grpc
971 /// \return
ServerRequest(worker_id_t worker_id)972 Status CacheServer::ServerRequest(worker_id_t worker_id) {
973   TaskManager::FindMe()->Post();
974   MS_LOG(DEBUG) << "Worker id " << worker_id << " is running on node " << hw_info_->GetMyNode();
975   auto &my_que = cache_q_->operator[](worker_id);
976   // Loop forever until we are interrupted or shutdown.
977   while (!global_shutdown_) {
978     CacheServerRequest *cache_req = nullptr;
979     RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
980     RETURN_IF_NOT_OK(ProcessRequest(cache_req));
981   }
982   return Status::OK();
983 }
984 
GetConnectionID(session_id_type session_id,uint32_t crc) const985 connection_id_type CacheServer::GetConnectionID(session_id_type session_id, uint32_t crc) const {
986   connection_id_type connection_id =
987     (static_cast<connection_id_type>(session_id) << 32u) | static_cast<connection_id_type>(crc);
988   return connection_id;
989 }
990 
GetSessionID(connection_id_type connection_id) const991 session_id_type CacheServer::GetSessionID(connection_id_type connection_id) const {
992   return static_cast<session_id_type>(connection_id >> 32u);
993 }
994 
CacheServer(const std::string & spill_path,int32_t num_workers,int32_t port,int32_t shared_meory_sz_in_gb,float memory_cap_ratio,int8_t log_level,std::shared_ptr<CacheServerHW> hw_info)995 CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port,
996                          int32_t shared_meory_sz_in_gb, float memory_cap_ratio, int8_t log_level,
997                          std::shared_ptr<CacheServerHW> hw_info)
998     : top_(spill_path),
999       num_workers_(num_workers),
1000       num_grpc_workers_(num_workers_),
1001       port_(port),
1002       shared_memory_sz_in_gb_(shared_meory_sz_in_gb),
1003       global_shutdown_(false),
1004       memory_cap_ratio_(memory_cap_ratio),
1005       numa_affinity_(true),
1006       log_level_(log_level),
1007       hw_info_(std::move(hw_info)) {
1008   // If we are not linked with numa library (i.e. NUMA_ENABLED is false), turn off cpu
1009   // affinity which can make performance worse.
1010   if (!CacheServerHW::numa_enabled()) {
1011     numa_affinity_ = false;
1012     MS_LOG(WARNING) << "Warning: This build is not compiled with numa support.  Install libnuma-devel and use a build "
1013                        "that is compiled with numa support for more optimal performance";
1014   }
1015   // We create the shared memory and we will destroy it. All other client just detach only.
1016   if (shared_memory_sz_in_gb_ > kDefaultSharedMemorySize) {
1017     MS_LOG(INFO) << "Shared memory size is readjust to " << kDefaultSharedMemorySize << " GB.";
1018     shared_memory_sz_in_gb_ = kDefaultSharedMemorySize;
1019   }
1020 }
1021 
Run(int msg_qid)1022 Status CacheServer::Run(int msg_qid) {
1023   Status rc = ServiceStart();
1024   // If there is a message que, return the status now before we call join_all which will never return
1025   if (msg_qid != -1) {
1026     SharedMessage msg(msg_qid);
1027     RETURN_IF_NOT_OK(msg.SendStatus(rc));
1028   }
1029   if (rc.IsError()) {
1030     return rc;
1031   }
1032   // This is called by the main function and we shouldn't exit. Otherwise the main thread
1033   // will just shutdown. So we will call some function that never return unless error.
1034   // One good case will be simply to wait for all threads to return.
1035   // note that after we have sent the initial status using the msg_qid, parent process will exit and
1036   // remove it. So we can't use it again.
1037   RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking));
1038   // Shutdown the grpc queue. No longer accept any new comer.
1039   comm_layer_->Shutdown();
1040   // The next thing to do drop all the caches.
1041   RETURN_IF_NOT_OK(ServiceStop());
1042   return Status::OK();
1043 }
1044 
GetFreeRequestTag(CacheServerRequest ** q)1045 Status CacheServer::GetFreeRequestTag(CacheServerRequest **q) {
1046   RETURN_UNEXPECTED_IF_NULL(q);
1047   auto *p = new (std::nothrow) CacheServerRequest();
1048   if (p == nullptr) {
1049     return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
1050   }
1051   *q = p;
1052   return Status::OK();
1053 }
1054 
ReturnRequestTag(CacheServerRequest * p)1055 Status CacheServer::ReturnRequestTag(CacheServerRequest *p) {
1056   RETURN_UNEXPECTED_IF_NULL(p);
1057   delete p;
1058   return Status::OK();
1059 }
1060 
DestroySession(CacheRequest * rq)1061 Status CacheServer::DestroySession(CacheRequest *rq) {
1062   CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id");
1063   auto drop_session_id = rq->connection_info().session_id();
1064   // Grab the locks in the correct order to avoid deadlock.
1065   UniqueLock sess_lck(&sessions_lock_);
1066   UniqueLock lck(&rwLock_);
1067   // Iterate over the set of connection id's for this session that we're dropping and erase each one.
1068   bool found = false;
1069   for (auto it = all_caches_.begin(); it != all_caches_.end();) {
1070     auto connection_id = it->first;
1071     auto session_id = GetSessionID(connection_id);
1072     // We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock.
1073     // So we will just manually do it.
1074     if (session_id == drop_session_id) {
1075       found = true;
1076       it = all_caches_.erase(it);
1077       MS_LOG(INFO) << "Destroy cache with id " << connection_id;
1078     } else {
1079       ++it;
1080     }
1081   }
1082   // Finally remove the session itself
1083   auto n = active_sessions_.erase(drop_session_id);
1084   if (n > 0) {
1085     MS_LOG(INFO) << "Session destroyed with id " << drop_session_id;
1086     return Status::OK();
1087   } else {
1088     if (found) {
1089       std::string errMsg =
1090         "A destroy cache request has been completed but it had a stale session id " + std::to_string(drop_session_id);
1091       RETURN_STATUS_UNEXPECTED(errMsg);
1092     } else {
1093       std::string errMsg =
1094         "Session id " + std::to_string(drop_session_id) + " not found in server on port " + std::to_string(port_) + ".";
1095       return Status(StatusCode::kMDFileNotExist, errMsg);
1096     }
1097   }
1098 }
1099 
GenerateSessionID()1100 session_id_type CacheServer::GenerateSessionID() {
1101   UniqueLock sess_lck(&sessions_lock_);
1102   auto mt = GetRandomDevice();
1103   std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max());
1104   session_id_type session_id;
1105   bool duplicate = false;
1106   do {
1107     session_id = distribution(mt);
1108     auto r = active_sessions_.insert(session_id);
1109     duplicate = !r.second;
1110   } while (duplicate);
1111   return session_id;
1112 }
1113 
AllocateSharedMemory(CacheRequest * rq,CacheReply * reply)1114 Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) {
1115   auto client_id = rq->client_id();
1116   CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
1117   try {
1118     auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, kDecimal);
1119     void *p = nullptr;
1120     RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, requestedSz, &p));
1121     auto *base = SharedMemoryBaseAddr();
1122     // We can't return the absolute address which makes no sense to the client.
1123     // Instead we return the difference.
1124     auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
1125     reply->set_result(std::to_string(difference));
1126   } catch (const std::exception &e) {
1127     RETURN_STATUS_UNEXPECTED(e.what());
1128   }
1129   return Status::OK();
1130 }
1131 
FreeSharedMemory(CacheRequest * rq)1132 Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
1133   auto client_id = rq->client_id();
1134   CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
1135   auto *base = SharedMemoryBaseAddr();
1136   try {
1137     auto addr = strtoll(rq->buf_data(0).data(), nullptr, kDecimal);
1138     auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
1139     DeallocateSharedMemory(client_id, p);
1140   } catch (const std::exception &e) {
1141     RETURN_STATUS_UNEXPECTED(e.what());
1142   }
1143   return Status::OK();
1144 }
1145 
GetCacheState(CacheRequest * rq,CacheReply * reply)1146 Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) {
1147   auto connection_id = rq->connection_id();
1148   SharedLock lck(&rwLock_);
1149   CacheService *cs = GetService(connection_id);
1150   if (cs == nullptr) {
1151     std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
1152     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
1153   } else {
1154     auto state = cs->GetState();
1155     reply->set_result(std::to_string(static_cast<int8_t>(state)));
1156     return Status::OK();
1157   }
1158 }
1159 
RpcRequest(worker_id_t worker_id)1160 Status CacheServer::RpcRequest(worker_id_t worker_id) {
1161   TaskManager::FindMe()->Post();
1162   RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id));
1163   return Status::OK();
1164 }
1165 
AcknowledgeShutdown(CacheServerRequest * cache_req)1166 Status CacheServer::AcknowledgeShutdown(CacheServerRequest *cache_req) {
1167   auto *rq = &cache_req->rq_;
1168   auto *reply = &cache_req->reply_;
1169   if (!rq->buf_data().empty()) {
1170     // cache_admin sends us a message qID and we will destroy the
1171     // queue in our destructor and this will wake up cache_admin.
1172     // But we don't want the cache_admin blindly just block itself.
1173     // So we will send back an ack before shutdown the comm layer.
1174     try {
1175       int32_t qID = std::stoi(rq->buf_data(0));
1176       shutdown_qIDs_.push_back(qID);
1177     } catch (const std::exception &e) {
1178       // ignore it.
1179     }
1180   }
1181   reply->set_result("OK");
1182   return Status::OK();
1183 }
1184 
GlobalShutdown()1185 void CacheServer::GlobalShutdown() {
1186   // Let's shutdown in proper order.
1187   bool expected = false;
1188   if (global_shutdown_.compare_exchange_strong(expected, true)) {
1189     MS_LOG(WARNING) << "Shutting down server.";
1190     // Interrupt all the threads and queues. We will leave the shutdown
1191     // of the comm layer after we have joined all the threads and will
1192     // be done by the master thread.
1193     vg_.interrupt_all();
1194   }
1195 }
1196 
GetWorkerByNumaId(numa_id_t numa_id) const1197 worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const {
1198   auto num_numa_nodes = GetNumaNodeCount();
1199   MS_ASSERT(numa_id < num_numa_nodes);
1200   auto num_workers_per_node = GetNumWorkers() / num_numa_nodes;
1201   std::mt19937 gen = GetRandomDevice();
1202   std::uniform_int_distribution<worker_id_t> dist(0, num_workers_per_node - 1);
1203   auto n = dist(gen);
1204   worker_id_t worker_id = n * num_numa_nodes + numa_id;
1205   MS_ASSERT(worker_id < GetNumWorkers());
1206   return worker_id;
1207 }
1208 
GetRandomWorker() const1209 worker_id_t CacheServer::GetRandomWorker() const {
1210   std::mt19937 gen = GetRandomDevice();
1211   std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1);
1212   return dist(gen);
1213 }
1214 
AllocateSharedMemory(int32_t client_id,size_t sz,void ** p)1215 Status CacheServer::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
1216   return shm_->AllocateSharedMemory(client_id, sz, p);
1217 }
1218 
DeallocateSharedMemory(int32_t client_id,void * p)1219 void CacheServer::DeallocateSharedMemory(int32_t client_id, void *p) { shm_->DeallocateSharedMemory(client_id, p); }
1220 
IpcResourceCleanup()1221 Status CacheServer::Builder::IpcResourceCleanup() {
1222   Status rc;
1223   SharedMemory::shm_key_t shm_key;
1224   auto unix_socket = PortToUnixSocketPath(port_);
1225   rc = PortToFtok(port_, &shm_key);
1226   // We are expecting the unix path doesn't exist.
1227   if (rc.IsError()) {
1228     return Status::OK();
1229   }
1230   // Attach to the shared memory which we expect don't exist
1231   SharedMemory mem(shm_key);
1232   rc = mem.Attach();
1233   if (rc.IsError()) {
1234     return Status::OK();
1235   } else {
1236     RETURN_IF_NOT_OK(mem.Detach());
1237   }
1238   int32_t num_attached;
1239   RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached));
1240   if (num_attached == 0) {
1241     // Stale shared memory from last time.
1242     // Remove both the memory and the socket path
1243     RETURN_IF_NOT_OK(mem.Destroy());
1244     Path p(unix_socket);
1245     (void)p.Remove();
1246   } else {
1247     // Server is already up.
1248     std::string errMsg = "Cache server is already up and running";
1249     // We return a duplicate error. The main() will intercept
1250     // and output a proper message
1251     return Status(StatusCode::kMDDuplicateKey, errMsg);
1252   }
1253   return Status::OK();
1254 }
1255 
SanityCheck()1256 Status CacheServer::Builder::SanityCheck() {
1257   if (shared_memory_sz_in_gb_ <= 0) {
1258     RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive");
1259   }
1260   if (num_workers_ <= 0) {
1261     RETURN_STATUS_UNEXPECTED("Number of parallel workers must be positive");
1262   }
1263   if (!top_.empty()) {
1264     auto p = top_.data();
1265     if (p[0] != '/') {
1266       RETURN_STATUS_UNEXPECTED("Spilling directory must be an absolute path");
1267     }
1268     // Check if the spill directory is writable
1269     Path spill(top_);
1270     auto t = spill / Services::GetUniqueID();
1271     Status rc = t.CreateDirectory();
1272     if (rc.IsOk()) {
1273       rc = t.Remove();
1274     }
1275     if (rc.IsError()) {
1276       RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString());
1277     }
1278   }
1279   if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) {
1280     RETURN_STATUS_UNEXPECTED("Memory cap ratio should be positive and no greater than 1");
1281   }
1282 
1283   // Check if the shared memory.
1284   RETURN_IF_NOT_OK(IpcResourceCleanup());
1285   return Status::OK();
1286 }
1287 
AdjustNumWorkers(int32_t num_workers)1288 int32_t CacheServer::Builder::AdjustNumWorkers(int32_t num_workers) {
1289   int32_t num_numa_nodes = hw_info_->GetNumaNodeCount();
1290   // Bump up num_workers_ to at least the number of numa nodes
1291   num_workers = std::max(num_numa_nodes, num_workers);
1292   // But also it shouldn't be too many more than the hardware concurrency
1293   int32_t num_cpus = hw_info_->GetCpuCount();
1294   constexpr int32_t kThreadsPerCore = 2;
1295   num_workers = std::min(kThreadsPerCore * num_cpus, num_workers);
1296   // Round up num_workers to a multiple of numa nodes.
1297   auto remainder = num_workers % num_numa_nodes;
1298   if (remainder > 0) num_workers += (num_numa_nodes - remainder);
1299   return num_workers;
1300 }
1301 
Builder()1302 CacheServer::Builder::Builder()
1303     : top_(""),
1304       num_workers_(kDefaultNumWorkers),
1305       port_(kCfgDefaultCachePort),
1306       shared_memory_sz_in_gb_(kDefaultSharedMemorySize),
1307       memory_cap_ratio_(kDefaultMemoryCapRatio),
1308       log_level_(kDefaultLogLevel) {
1309   if (num_workers_ == 0) {
1310     num_workers_ = 1;
1311   }
1312 }
1313 }  // namespace dataset
1314 }  // namespace mindspore
1315