/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_server.h"
#include <algorithm>
#include <functional>
#include <limits>
#include <vector>
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/util/bit.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/random.h"
#ifdef CACHE_LOCAL_CLIENT
#include "minddata/dataset/util/sig_handler.h"
#endif
#undef BitTest

namespace mindspore {
namespace dataset {
CacheServer *CacheServer::instance_ = nullptr;
std::once_flag CacheServer::init_instance_flag_;
Status CacheServer::DoServiceStart() {
#ifdef CACHE_LOCAL_CLIENT
  // We need to destroy the shared memory if the user hits Control-C
  RegisterHandlers();
#endif
  if (!top_.empty()) {
    Path spill(top_);
    RETURN_IF_NOT_OK(spill.CreateDirectories());
    MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
  }
  RETURN_IF_NOT_OK(vg_.ServiceStart());
  auto num_numa_nodes = GetNumaNodeCount();
  // If we link with the numa library, set the default memory policy.
  // If we don't pin threads to cpus, then use up all memory controllers to maximize
  // memory bandwidth.
  RETURN_IF_NOT_OK(
    CacheServerHW::SetDefaultMemoryPolicy(numa_affinity_ ? CachePoolPolicy::kLocal : CachePoolPolicy::kInterleave));
  auto my_node = hw_info_->GetMyNode();
  MS_LOG(DEBUG) << "Cache server is running on numa node " << my_node;
  // There will be some threads working on the grpc queue and
  // some number of threads working on the CacheServerRequest queue.
  // Like a connector object we will set up the same number of queues but
  // we do not need to preserve any order. We will set the capacity of
  // each queue to be 64 since we are just pushing memory pointers which
  // are only 8 bytes each.
  const int32_t kQueCapacity = 64;
  // This is the request queue from the client
  cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>();
  cache_q_->Init(num_workers_, kQueCapacity);
  // We will match the number of grpc workers with the number of server workers.
  // But technically they don't have to be the same.
  num_grpc_workers_ = num_workers_;
  MS_LOG(DEBUG) << "Number of grpc workers is set to " << num_grpc_workers_;
  RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
  // Start the comm layer
  try {
    comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_);
    RETURN_IF_NOT_OK(comm_layer_->Run());
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
#ifdef CACHE_LOCAL_CLIENT
  RETURN_IF_NOT_OK(CachedSharedMemory::CreateArena(&shm_, port_, shared_memory_sz_in_gb_));
  // Bring up a thread to monitor the unix socket in case it is removed. But it must be done
  // after we have created the unix socket.
  auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
  RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
#endif
  // Spawn a few threads to serve the real requests.
  auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
  for (auto i = 0; i < num_workers_; ++i) {
    Task *pTask;
    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i), &pTask));
    // Save a copy of the pointer to the underlying Task object. We may dynamically change their affinity if needed.
    numa_tasks_.emplace(i, pTask);
    // Spread out all the threads to all the numa nodes if needed
    if (IsNumaAffinityOn()) {
      auto numa_id = i % num_numa_nodes;
      RETURN_IF_NOT_OK(SetAffinity(*pTask, numa_id));
    }
  }
  // Finally loop forever to handle the requests.
  auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1);
  for (auto i = 0; i < num_grpc_workers_; ++i) {
    Task *pTask;
    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i), &pTask));
    // All these grpc workers will be allocated to the same node, which is where we allocate all the free tag
    // memory.
    if (IsNumaAffinityOn()) {
      RETURN_IF_NOT_OK(SetAffinity(*pTask, i % num_numa_nodes));
    }
  }
  return Status::OK();
}

Status CacheServer::DoServiceStop() {
  Status rc;
  Status rc2;
  // First stop all the threads.
  RETURN_IF_NOT_OK(vg_.ServiceStop());
  // Clean up all the caches if any.
  UniqueLock lck(&rwLock_);
  auto it = all_caches_.begin();
  while (it != all_caches_.end()) {
    auto cs = std::move(it->second);
    rc2 = cs->ServiceStop();
    if (rc2.IsError()) {
      rc = rc2;
    }
    ++it;
  }
  // Also remove the path we use to generate ftok.
  Path p(PortToUnixSocketPath(port_));
  (void)p.Remove();
  // Finally wake up cache_admin if it is waiting
  for (int32_t qID : shutdown_qIDs_) {
    SharedMessage msg(qID);
    RETURN_IF_NOT_OK(msg.SendStatus(Status::OK()));
    msg.RemoveResourcesOnExit();
    // Let msg go out of scope, which will destroy the queue.
  }
  return rc;
}

CacheService *CacheServer::GetService(connection_id_type id) const {
  auto it = all_caches_.find(id);
  if (it != all_caches_.end()) {
    return it->second.get();
  }
  return nullptr;
}

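// Worked example (illustrative numbers, not constants from this file): with 8 GB of system memory and
// memory_cap_ratio_ = 0.8, avail_mem is 6.4 GB. If the existing caches report 2 GB consumed, max_avail
// drops to 4.4 GB, so a new cache asking for 5120 MB is rejected, and a request for 0 (unlimited) is
// accepted only while the remaining fraction max_avail / avail_mem stays above kMemoryBottomLineForNewService.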
// We would like to protect ourselves from over-allocating. We will go over the existing caches
// and calculate how much we have consumed so far.
Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
  auto end = all_caches_.end();
  auto it = all_caches_.begin();
  auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
  int64_t max_avail = avail_mem;
  while (it != end) {
    auto &cs = it->second;
    CacheService::ServiceStat stat;
    RETURN_IF_NOT_OK(cs->GetStat(&stat));
    int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
    max_avail -= mem_consumed;
    if (max_avail <= 0) {
      RETURN_STATUS_OOM("Out of memory, please destroy some sessions.");
    }
    ++it;
  }

  // If some caches are already using memory, make a reasonable decision on whether we should
  // return out of memory.
  if (max_avail < avail_mem) {
    int64_t req_mem = cache_mem_sz * 1048576L;  // It is in MB unit.
    if (req_mem > max_avail) {
      RETURN_STATUS_OOM("Out of memory, please destroy some sessions.");
    } else if (req_mem == 0) {
      // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
      // 85% of our limit, fail this request.
      if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= kMemoryBottomLineForNewService) {
        RETURN_STATUS_OOM("Out of memory, please destroy some sessions.");
      }
    }
  }
  return Status::OK();
}

Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
  CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info");
  std::string cookie;
  int32_t client_id;
  auto session_id = rq->connection_info().session_id();
  auto crc = rq->connection_info().crc();

  // Before allowing the creation, make sure the session has already been created by the user.
  // Our intention is to add this cache to the active sessions list, so leave the list locked during
  // this entire function.
  UniqueLock sess_lck(&sessions_lock_);
  auto session_it = active_sessions_.find(session_id);
  if (session_it == active_sessions_.end()) {
    RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!");
  }

  // We concatenate both numbers to form the internal connection id.
  auto connection_id = GetConnectionID(session_id, crc);
  CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache");
  auto &create_cache_buf = rq->buf_data(0);
  auto p = flatbuffers::GetRoot<CreateCacheRequestMsg>(create_cache_buf.data());
  auto flag = static_cast<CreateCacheRequest::CreateCacheFlag>(p->flag());
  auto cache_mem_sz = p->cache_mem_sz();
  // We can't do spilling unless this server is set up with a spill path in the first place.
  bool spill =
    (flag & CreateCacheRequest::CreateCacheFlag::kSpillToDisk) == CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
  bool generate_id =
    (flag & CreateCacheRequest::CreateCacheFlag::kGenerateRowId) == CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
  if (spill && top_.empty()) {
    RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
  }
  // Before creating the cache, first check if this is a request for shared usage of an existing cache.
  // If two CreateService calls come in with identical connection_id, we need to serialize the create.
  // The first create will be successful and be given a special cookie.
  UniqueLock lck(&rwLock_);
  bool duplicate = false;
  CacheService *curr_cs = GetService(connection_id);
  if (curr_cs != nullptr) {
    duplicate = true;
    client_id = curr_cs->num_clients_.fetch_add(1);
    MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " +
                      std::to_string(connection_id) + " to create cache service";
  }
  // Early exit if we are doing global shutdown
  if (global_shutdown_) {
    return Status::OK();
  }

  if (!duplicate) {
    RETURN_IF_NOT_OK(GlobalMemoryCheck(cache_mem_sz));
    std::unique_ptr<CacheService> cs;
    try {
      cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
      RETURN_IF_NOT_OK(cs->ServiceStart());
      cookie = cs->cookie();
      client_id = cs->num_clients_.fetch_add(1);
      all_caches_.emplace(connection_id, std::move(cs));
    } catch (const std::bad_alloc &e) {
      RETURN_STATUS_OOM("Out of memory.");
    }
  }

  // Shuffle the worker threads. But we need to release the locks or we will deadlock when calling
  // the following function.
  lck.Unlock();
  sess_lck.Unlock();
  auto numa_id = client_id % GetNumaNodeCount();
  std::vector<cpu_id_t> cpu_list = hw_info_->GetCpuList(numa_id);
  // Send back the data
  flatbuffers::FlatBufferBuilder fbb;
  flatbuffers::Offset<flatbuffers::String> off_cookie;
  flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list;
  off_cookie = fbb.CreateString(cookie);
  off_cpu_list = fbb.CreateVector(cpu_list);
  CreateCacheReplyMsgBuilder bld(fbb);
  bld.add_connection_id(connection_id);
  bld.add_cookie(off_cookie);
  bld.add_client_id(client_id);
  // The last thing we send back is a set of cpu ids that we suggest the client should bind itself to
  bld.add_cpu_id(off_cpu_list);
  auto off = bld.Finish();
  fbb.Finish(off);
  reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
  // We can return OK, but we return a duplicate key so the user can act accordingly and either ignore
  // it or treat it as OK.
  return duplicate ? Status(StatusCode::kMDDuplicateKey) : Status::OK();
}

Status CacheServer::DestroyCache(CacheRequest *rq) {
  // We need a strong lock to protect the map.
  UniqueLock lck(&rwLock_);
  auto id = rq->connection_id();
  CacheService *cs = GetService(id);
  // If it is already destroyed, ignore the request.
  if (cs != nullptr) {
    MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
    // std::map will invoke the destructor of CacheService. So we don't need to do anything here.
    auto n = all_caches_.erase(id);
    if (n == 0) {
      // It has been destroyed by another duplicate request.
      MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to destroy cache service";
    }
  }
  // We aren't touching the session list even though we may be dropping the last remaining cache of a session.
  // Leave that to be done by the drop session command.
  return Status::OK();
}

Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  } else {
    auto sz = rq->buf_data_size();
    std::vector<const void *> buffers;
    buffers.reserve(sz);
    // First piece of data is the cookie and is required
    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
    auto &cookie = rq->buf_data(0);
    // Only if the cookie matches can we accept inserts into a cache that has a build phase
    if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
      // Push the address of each buffer (in the form of std::string coming in from protobuf) into
      // a vector of buffers
      for (auto i = 1; i < sz; ++i) {
        buffers.push_back(rq->buf_data(i).data());
      }
      row_id_type id = -1;
      // We will allocate the memory on the same numa node this thread is bound to.
      RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id));
      reply->set_result(std::to_string(id));
    } else {
      RETURN_STATUS_UNEXPECTED("Cookie mismatch");
    }
  }
  return Status::OK();
}

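// FastCacheRow is the shared-memory fast path of CacheRow: instead of carrying the row in the protobuf
// payload, the client passes the cookie, a base-relative offset (as a decimal string) and the size of
// the serialized row. The offset only becomes a usable pointer after adding SharedMemoryBaseAddr(),
// since the client and the server generally map the arena at different virtual addresses.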
Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  auto client_id = rq->client_id();
  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  auto *base = SharedMemoryBaseAddr();
  // Ensure we got 3 pieces of data coming in
  constexpr int32_t kMinBufDataSize = 3;
  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
  // First one is the cookie, followed by the data address and then the size.
  enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
  // First piece of data is the cookie and is required
  auto &cookie = rq->buf_data(BufDataIndex::kCookie);
  // Second piece of data is the address where we can find the serialized data
  auto addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
  auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
  // Third piece of data is the size of the serialized data that we need to transfer
  auto sz = strtoll(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal);
  // Successful or not, we need to free the memory on exit.
  Status rc;
  if (cs == nullptr) {
    std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
    rc = STATUS_ERROR(StatusCode::kMDUnexpectedError, errMsg);
  } else {
    // Only if the cookie matches can we accept inserts into a cache that has a build phase
    if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
      row_id_type id = -1;
      ReadableSlice src(p, sz);
      // We will allocate the memory on the same numa node this thread is bound to.
      rc = cs->FastCacheRow(src, &id);
      reply->set_result(std::to_string(id));
    } else {
      auto state = cs->GetState();
      if (state != CacheServiceState::kFetchPhase) {
        rc = STATUS_ERROR(StatusCode::kMDUnexpectedError, "Cache service is not in fetch phase. The current phase is " +
                                                            std::to_string(static_cast<int8_t>(state)) +
                                                            ". Client id: " + std::to_string(client_id));
      } else {
        rc = STATUS_ERROR(StatusCode::kMDUnexpectedError, "Cookie mismatch. Client id: " + std::to_string(client_id));
      }
    }
  }
  // Return the block to the shared memory only if it is not an internal request.
  if (static_cast<BaseRequest::RequestType>(rq->type()) == BaseRequest::RequestType::kCacheRow) {
    DeallocateSharedMemory(client_id, p);
  }
  return rc;
}

Status CacheServer::InternalCacheRow(CacheRequest *rq, CacheReply *reply) {
  // Look into the flag to see where we can find the data and call the appropriate method.
  auto flag = rq->flag();
  Status rc;
  if (BitTest(flag, kDataIsInSharedMemory)) {
    rc = FastCacheRow(rq, reply);
    // This is an internal request and is not tied to rpc. But we need to post because there
    // is a thread waiting on the completion of this request.
    try {
      constexpr int32_t kBatchWaitIdx = 3;
      // Ensure we got 4 pieces of data coming in
      constexpr int32_t kMinBufDataSize = 4;
      CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
      // Fourth piece of the data is the address of the BatchWait ptr
      int64_t addr = strtol(rq->buf_data(kBatchWaitIdx).data(), nullptr, kDecimal);
      auto *bw = reinterpret_cast<BatchWait *>(addr);
      // Check if the object is still around.
      auto bwObj = bw->GetBatchWait();
      if (bwObj.lock()) {
        RETURN_IF_NOT_OK(bw->Set(rc));
      }
    } catch (const std::exception &e) {
      RETURN_STATUS_UNEXPECTED(e.what());
    }
  } else {
    rc = CacheRow(rq, reply);
  }
  return rc;
}

Status CacheServer::InternalFetchRow(CacheRequest *rq) {
  auto connection_id = rq->connection_id();
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  Status rc;
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  }
  // First piece is a flatbuffer containing row fetch information, second piece is the address of the BatchWait ptr
  enum BufDataIndex : uint8_t { kFetchRowMsg = 0, kBatchWait = 1 };
  // Ensure we got 2 pieces of data coming in
  constexpr int32_t kMinBufDataSize = 2;
  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
  rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(BufDataIndex::kFetchRowMsg).data()));
  // This is an internal request and is not tied to rpc. But we need to post because there
  // is a thread waiting on the completion of this request.
  try {
    int64_t addr = strtol(rq->buf_data(BufDataIndex::kBatchWait).data(), nullptr, kDecimal);
    auto *bw = reinterpret_cast<BatchWait *>(addr);
    // Check if the object is still around.
    auto bwObj = bw->GetBatchWait();
    if (bwObj.lock()) {
      RETURN_IF_NOT_OK(bw->Set(rc));
    }
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return rc;
}

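// Layout of the output buffer filled in by BatchFetch: an array of (num_elements + 1) int64_t offsets
// followed by the row data. offset_array[i] is where row i starts and offset_array[i + 1] is where it
// ends; every row is rounded up to a 4K boundary so that the worker threads copying different rows
// concurrently do not land on the same cache line (see the false-sharing comment in BatchFetchRows).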
Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
  RETURN_UNEXPECTED_IF_NULL(p);
  const auto num_elements = p->rows()->size();
  auto connection_id = p->connection_id();
  auto batch_wait = std::make_shared<BatchWait>(num_elements);
  int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
  auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
  offset_array[0] = data_offset;
  for (uint32_t i = 0; i < num_elements; ++i) {
    auto data_locator = p->rows()->Get(i);
    auto node_id = data_locator->node_id();
    size_t sz = data_locator->size();
    void *source_addr = reinterpret_cast<void *>(data_locator->addr());
    auto key = data_locator->key();
    // Please read the comment in CacheServer::BatchFetchRows where we allocate
    // the buffer big enough so each thread (which we are going to dispatch) will
    // not run into a false sharing problem. We are going to round up sz to 4k.
    auto sz_4k = round_up_4K(sz);
    offset_array[i + 1] = offset_array[i] + sz_4k;
    if (sz > 0) {
      WritableSlice row_data(*out, offset_array[i], sz);
      // Get a request and send to the proper worker (at some numa node) to do the fetch.
      worker_id_t worker_id = IsNumaAffinityOn() ? GetWorkerByNumaId(node_id) : GetRandomWorker();
      CacheServerRequest *cache_rq;
      RETURN_IF_NOT_OK(GetFreeRequestTag(&cache_rq));
      // Set up all the necessary fields.
      cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
      cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
      cache_rq->rq_.set_connection_id(connection_id);
      cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
      auto dest_addr = row_data.GetMutablePointer();
      flatbuffers::FlatBufferBuilder fb2;
      FetchRowMsgBuilder bld(fb2);
      bld.add_key(key);
      bld.add_size(sz);
      bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
      bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
      auto offset = bld.Finish();
      fb2.Finish(offset);
      cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
      cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(batch_wait.get())));
      RETURN_IF_NOT_OK(PushRequest(worker_id, cache_rq));
    } else {
      // Nothing to fetch but we still need to post something back into the wait area.
      RETURN_IF_NOT_OK(batch_wait->Set(Status::OK()));
    }
  }
  // Now wait for all of them to come back.
  RETURN_IF_NOT_OK(batch_wait->Wait());
  // Return the result
  return batch_wait->GetRc();
}

Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  auto client_id = rq->client_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  } else {
    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id");
    auto &row_id_buf = rq->buf_data(0);
    auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data());
    RETURN_UNEXPECTED_IF_NULL(p);
    std::vector<row_id_type> row_id;
    auto sz = p->row_id()->size();
    row_id.reserve(sz);
    for (uint32_t i = 0; i < sz; ++i) {
      row_id.push_back(p->row_id()->Get(i));
    }
    std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
    RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb));
    // Let go of the shared lock. We don't need to interact with the CacheService anymore.
    // We shouldn't be holding any lock while we may be waiting a long time for the rows to come back.
    lck.Unlock();
    auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
    int64_t mem_sz = sizeof(int64_t) * (sz + 1);
    for (auto i = 0; i < sz; ++i) {
      auto row_sz = locator->rows()->Get(i)->size();
      // row_sz is the size of the cached data. Later we will spawn multiple threads,
      // each of which will copy the data into either shared memory or protobuf concurrently, but
      // to different regions.
      // To avoid false sharing, we will bump up row_sz to be a multiple of 4k, i.e. 4096 bytes.
      row_sz = round_up_4K(row_sz);
      mem_sz += row_sz;
    }
    auto client_flag = rq->flag();
    bool local_client = BitTest(client_flag, kLocalClientSupport);
    // For a large amount of data to be sent back, we will use shared memory, provided it is a local
    // client that has local bypass support.
    bool local_bypass = local_client ? (mem_sz >= kLocalByPassThreshold) : false;
    reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0);
    if (local_bypass) {
      // We will use shared memory
      auto *base = SharedMemoryBaseAddr();
      void *q = nullptr;
      RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, mem_sz, &q));
      WritableSlice dest(q, mem_sz);
      Status rc = BatchFetch(fbb, &dest);
      if (rc.IsError()) {
        DeallocateSharedMemory(client_id, q);
        return rc;
      }
      // We can't return the absolute address which makes no sense to the client.
      // Instead we return the difference.
      auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base);
      reply->set_result(std::to_string(difference));
    } else {
      // We are going to use std::string to allocate and hold the result, which will eventually be
      // 'moved' to the protobuf message (which underneath is also a std::string), in order to
      // minimize memory copies.
      std::string mem;
      try {
        mem.resize(mem_sz);
        CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error");
      } catch (const std::bad_alloc &e) {
        RETURN_STATUS_OOM("Out of memory.");
      }
      WritableSlice dest(mem.data(), mem_sz);
      RETURN_IF_NOT_OK(BatchFetch(fbb, &dest));
      reply->set_result(std::move(mem));
    }
  }
  return Status::OK();
}

Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  } else {
    CacheService::ServiceStat svc_stat;
    RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
    flatbuffers::FlatBufferBuilder fbb;
    ServiceStatMsgBuilder bld(fbb);
    bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
    bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
    bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz);
    bld.add_num_numa_hit(svc_stat.stat_.num_numa_hit);
    bld.add_max_row_id(svc_stat.stat_.max_key);
    bld.add_min_row_id(svc_stat.stat_.min_key);
    bld.add_state(svc_stat.state_);
    auto offset = bld.Finish();
    fbb.Finish(offset);
    reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
  }
  return Status::OK();
}

Status CacheServer::CacheSchema(CacheRequest *rq) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  } else {
    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information");
    auto &create_schema_buf = rq->buf_data(0);
    RETURN_IF_NOT_OK(cs->CacheSchema(create_schema_buf.data(), create_schema_buf.size()));
  }
  return Status::OK();
}

Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  } else {
    // We are going to use std::string to allocate and hold the result, which will eventually be
    // 'moved' to the protobuf message (which underneath is also a std::string), in order to
    // minimize memory copies.
    std::string mem;
    RETURN_IF_NOT_OK(cs->FetchSchema(&mem));
    reply->set_result(std::move(mem));
  }
  return Status::OK();
}

Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  } else {
    // First piece of data is the cookie
    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
    auto &cookie = rq->buf_data(0);
    // We only allow the phase to switch if the cookie matches.
    if (cookie == cs->cookie()) {
      RETURN_IF_NOT_OK(cs->BuildPhaseDone());
    } else {
      RETURN_STATUS_UNEXPECTED("Cookie mismatch");
    }
  }
  return Status::OK();
}

Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  } else {
    std::vector<row_id_type> gap;
    RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap));
    flatbuffers::FlatBufferBuilder fbb;
    auto off_t = fbb.CreateVector(gap);
    TensorRowIdsBuilder bld(fbb);
    bld.add_row_id(off_t);
    auto off = bld.Finish();
    fbb.Finish(off);
    reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
  }
  return Status::OK();
}

inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *reply) {
  reply->set_result(std::to_string(session_id));
  MS_LOG(INFO) << "Server generated new session id " << session_id;
  return Status::OK();
}

Status CacheServer::ToggleWriteMode(CacheRequest *rq) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  } else {
    // First piece of data is the on/off flag
    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag");
    const auto &action = rq->buf_data(0);
    bool on_off = false;
    if (strcmp(action.data(), "on") == 0) {
      on_off = true;
    } else if (strcmp(action.data(), "off") == 0) {
      on_off = false;
    } else {
      RETURN_STATUS_UNEXPECTED("Unknown request: " + action);
    }
    RETURN_IF_NOT_OK(cs->ToggleWriteMode(on_off));
  }
  return Status::OK();
}

Status CacheServer::ListSessions(CacheReply *reply) {
  SharedLock sess_lck(&sessions_lock_);
  SharedLock lck(&rwLock_);
  flatbuffers::FlatBufferBuilder fbb;
  std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector;
  for (auto const &current_session_id : active_sessions_) {
    bool found = false;
    for (auto const &it : all_caches_) {
      auto current_conn_id = it.first;
      if (GetSessionID(current_conn_id) == current_session_id) {
        found = true;
        auto &cs = it.second;
        CacheService::ServiceStat svc_stat;
        RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
        auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached,
                                                  svc_stat.stat_.average_cache_sz, svc_stat.stat_.num_numa_hit,
                                                  svc_stat.stat_.min_key, svc_stat.stat_.max_key, svc_stat.state_);
        auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats);
        session_msgs_vector.push_back(current_session_info);
      }
    }
    if (!found) {
      // If there is no cache created yet, assign a connection id of 0 along with empty stats
      auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0);
      auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats);
      session_msgs_vector.push_back(current_session_info);
    }
  }
  flatbuffers::Offset<flatbuffers::String> spill_dir;
  spill_dir = fbb.CreateString(top_);
  auto session_msgs = fbb.CreateVector(session_msgs_vector);
  ListSessionsMsgBuilder s_builder(fbb);
  s_builder.add_sessions(session_msgs);
  s_builder.add_num_workers(num_workers_);
  s_builder.add_log_level(log_level_);
  s_builder.add_spill_dir(spill_dir);
  auto offset = s_builder.Finish();
  fbb.Finish(offset);
  reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
  return Status::OK();
}

Status CacheServer::ConnectReset(CacheRequest *rq) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  } else {
    auto client_id = rq->client_id();
    MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects";
    cs->num_clients_--;
  }
  return Status::OK();
}

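// Wire format consumed by BatchCacheRows: the client writes num_elem rows back to back into the shared
// memory arena starting at offset_addr. Each row is a TensorRowHeaderMsg flatbuffer followed immediately
// by the data of its columns, which is why the loop below advances p by size_of_this() plus the sum of
// the column sizes to locate the next row before handing each row off as a kInternalCacheRow request.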
Status CacheServer::BatchCacheRows(CacheRequest *rq) {
  // First one is the cookie, followed by the address and then the number of rows.
  enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
  constexpr int32_t kExpectedBufDataSize = 3;
  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == kExpectedBufDataSize, "Expect three pieces of data");
  try {
    auto &cookie = rq->buf_data(BufDataIndex::kCookie);
    auto connection_id = rq->connection_id();
    auto client_id = rq->client_id();
    int64_t offset_addr;
    int32_t num_elem;
    auto *base = SharedMemoryBaseAddr();
    offset_addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
    auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
    num_elem = static_cast<int32_t>(strtol(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal));
    auto batch_wait = std::make_shared<BatchWait>(num_elem);
    // Get a set of free requests and push them into the queues.
    for (auto i = 0; i < num_elem; ++i) {
      auto start = reinterpret_cast<int64_t>(p);
      auto msg = GetTensorRowHeaderMsg(p);
      p += msg->size_of_this();
      for (auto k = 0; k < msg->column()->size(); ++k) {
        p += msg->data_sz()->Get(k);
      }
      CacheServerRequest *cache_rq;
      RETURN_IF_NOT_OK(GetFreeRequestTag(&cache_rq));
      // Fill in the details.
      cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
      cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
      cache_rq->rq_.set_connection_id(connection_id);
      cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
      cache_rq->rq_.set_client_id(client_id);
      cache_rq->rq_.set_flag(kDataIsInSharedMemory);
      cache_rq->rq_.add_buf_data(cookie);
      cache_rq->rq_.add_buf_data(std::to_string(start - reinterpret_cast<int64_t>(base)));
      cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(p - start)));
      cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(batch_wait.get())));
      RETURN_IF_NOT_OK(PushRequest(GetRandomWorker(), cache_rq));
    }
    // Now wait for all of them to come back.
    RETURN_IF_NOT_OK(batch_wait->Wait());
    // Return the result
    return batch_wait->GetRc();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request) {
  auto &rq = cache_req->rq_;
  auto &reply = cache_req->reply_;
  switch (cache_req->type_) {
    case BaseRequest::RequestType::kCacheRow: {
      // Look into the flag to see where we can find the data and call the appropriate method.
      if (BitTest(rq.flag(), kDataIsInSharedMemory)) {
        cache_req->rc_ = FastCacheRow(&rq, &reply);
      } else {
        cache_req->rc_ = CacheRow(&rq, &reply);
      }
      break;
    }
    case BaseRequest::RequestType::kInternalCacheRow: {
      *internal_request = true;
      cache_req->rc_ = InternalCacheRow(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kBatchCacheRows: {
      cache_req->rc_ = BatchCacheRows(&rq);
      break;
    }
    case BaseRequest::RequestType::kBatchFetchRows: {
      cache_req->rc_ = BatchFetchRows(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kInternalFetchRow: {
      *internal_request = true;
      cache_req->rc_ = InternalFetchRow(&rq);
      break;
    }
    default:
      std::string errMsg("Internal error, request type is not a row request: ");
      errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
      cache_req->rc_ = STATUS_ERROR(StatusCode::kMDUnexpectedError, errMsg);
  }
  return Status::OK();
}

Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) {
  auto &rq = cache_req->rq_;
  auto &reply = cache_req->reply_;
  switch (cache_req->type_) {
    case BaseRequest::RequestType::kDropSession: {
      cache_req->rc_ = DestroySession(&rq);
      break;
    }
    case BaseRequest::RequestType::kGenerateSessionId: {
      cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
      break;
    }
    case BaseRequest::RequestType::kListSessions: {
      cache_req->rc_ = ListSessions(&reply);
      break;
    }
    default:
      std::string errMsg("Internal error, request type is not a session request: ");
      errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
      cache_req->rc_ = STATUS_ERROR(StatusCode::kMDUnexpectedError, errMsg);
  }
  return Status::OK();
}

Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) {
  auto &rq = cache_req->rq_;
  auto &reply = cache_req->reply_;
  switch (cache_req->type_) {
    case BaseRequest::RequestType::kCreateCache: {
      cache_req->rc_ = CreateService(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kGetCacheMissKeys: {
      cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kDestroyCache: {
      cache_req->rc_ = DestroyCache(&rq);
      break;
    }
    case BaseRequest::RequestType::kGetStat: {
      cache_req->rc_ = GetStat(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kCacheSchema: {
      cache_req->rc_ = CacheSchema(&rq);
      break;
    }
    case BaseRequest::RequestType::kFetchSchema: {
      cache_req->rc_ = FetchSchema(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kBuildPhaseDone: {
      cache_req->rc_ = BuildPhaseDone(&rq);
      break;
    }
    case BaseRequest::RequestType::kAllocateSharedBlock: {
      cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kFreeSharedBlock: {
      cache_req->rc_ = FreeSharedMemory(&rq);
      break;
    }
    case BaseRequest::RequestType::kStopService: {
      // This command shuts down everything.
      // But we first reply back to the client that we received the request.
      // The real shutdown work will be done by the caller.
      cache_req->rc_ = AcknowledgeShutdown(cache_req);
      break;
    }
    case BaseRequest::RequestType::kHeartBeat: {
      cache_req->rc_ = Status::OK();
      break;
    }
    case BaseRequest::RequestType::kToggleWriteMode: {
      cache_req->rc_ = ToggleWriteMode(&rq);
      break;
    }
    case BaseRequest::RequestType::kConnectReset: {
      cache_req->rc_ = ConnectReset(&rq);
      break;
    }
    case BaseRequest::RequestType::kGetCacheState: {
      cache_req->rc_ = GetCacheState(&rq, &reply);
      break;
    }
    default:
      std::string errMsg("Internal error, request type is not an admin request: ");
      errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
      cache_req->rc_ = STATUS_ERROR(StatusCode::kMDUnexpectedError, errMsg);
  }
  return Status::OK();
}

Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
  bool internal_request = false;

  // Except for creating a new session, we expect cs is not null.
  if (cache_req->IsRowRequest()) {
    RETURN_IF_NOT_OK(ProcessRowRequest(cache_req, &internal_request));
  } else if (cache_req->IsSessionRequest()) {
    RETURN_IF_NOT_OK(ProcessSessionRequest(cache_req));
  } else if (cache_req->IsAdminRequest()) {
    RETURN_IF_NOT_OK(ProcessAdminRequest(cache_req));
  } else {
    std::string errMsg("Unknown request type : ");
    errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
    cache_req->rc_ = STATUS_ERROR(StatusCode::kMDUnexpectedError, errMsg);
  }

  // Notify it is done, and move on to the next request.
  Status2CacheReply(cache_req->rc_, &cache_req->reply_);
  cache_req->st_ = CacheServerRequest::STATE::FINISH;
  // We will re-tag the request back to the grpc queue. Once it comes back from the client,
  // the CacheServerRequest, i.e. the pointer cache_req, will be freed.
  if (!internal_request && !global_shutdown_) {
    cache_req->responder_.Finish(cache_req->reply_, grpc::Status::OK, cache_req);
  } else {
    // We can free up the request now.
    RETURN_IF_NOT_OK(ReturnRequestTag(cache_req));
  }
  return Status::OK();
}

/// \brief This is the main loop the cache server thread(s) are running.
/// Each thread will pop a request and send the result back to the client using grpc
/// \return Status object
Status CacheServer::ServerRequest(worker_id_t worker_id) {
  TaskManager::FindMe()->Post();
  MS_LOG(DEBUG) << "Worker id " << worker_id << " is running on node " << hw_info_->GetMyNode();
  auto &my_que = cache_q_->operator[](worker_id);
  // Loop forever until we are interrupted or shut down.
  while (!global_shutdown_) {
    CacheServerRequest *cache_req = nullptr;
    RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
    RETURN_IF_NOT_OK(ProcessRequest(cache_req));
  }
  return Status::OK();
}

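// The connection id packs the session id into the upper 32 bits and the crc into the lower 32 bits.
// For example, session id 2 combined with crc 0xABCD gives connection id 0x00000002'0000ABCD, and
// GetSessionID() below recovers the session id by shifting right by 32 bits.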
connection_id_type CacheServer::GetConnectionID(session_id_type session_id, uint32_t crc) const {
  connection_id_type connection_id =
    (static_cast<connection_id_type>(session_id) << 32u) | static_cast<connection_id_type>(crc);
  return connection_id;
}

session_id_type CacheServer::GetSessionID(connection_id_type connection_id) const {
  return static_cast<session_id_type>(connection_id >> 32u);
}

CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port,
                         int32_t shared_memory_sz_in_gb, float memory_cap_ratio, int8_t log_level,
                         std::shared_ptr<CacheServerHW> hw_info)
    : top_(spill_path),
      num_workers_(num_workers),
      num_grpc_workers_(num_workers_),
      port_(port),
      shared_memory_sz_in_gb_(shared_memory_sz_in_gb),
      global_shutdown_(false),
      memory_cap_ratio_(memory_cap_ratio),
      numa_affinity_(true),
      log_level_(log_level),
      hw_info_(std::move(hw_info)) {
  // If we are not linked with the numa library (i.e. NUMA_ENABLED is false), turn off cpu
  // affinity, which can make performance worse.
  if (!CacheServerHW::numa_enabled()) {
    numa_affinity_ = false;
    MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build "
                       "that is compiled with numa support for more optimal performance";
  }
  // We create the shared memory and we will destroy it. All other clients just detach.
  if (shared_memory_sz_in_gb_ > kDefaultSharedMemorySize) {
    MS_LOG(INFO) << "Shared memory size is readjusted to " << kDefaultSharedMemorySize << " GB.";
    shared_memory_sz_in_gb_ = kDefaultSharedMemorySize;
  }
}

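/// \brief Entry point used by the cache server main(). It reports the startup status through the
/// optional message queue and then blocks in join_all(), so it only returns on shutdown or on a
/// startup error.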
Status CacheServer::Run(int msg_qid) {
  Status rc = ServiceStart();
  // If there is a message queue, return the status now before we call join_all which will never return
  if (msg_qid != -1) {
    SharedMessage msg(msg_qid);
    RETURN_IF_NOT_OK(msg.SendStatus(rc));
  }
  if (rc.IsError()) {
    return rc;
  }
  // This is called by the main function and we shouldn't exit. Otherwise the main thread
  // will just shut down. So we will call some function that never returns unless there is an error.
  // One good choice is simply to wait for all threads to return.
  // Note that after we have sent the initial status using the msg_qid, the parent process will exit and
  // remove it. So we can't use it again.
  RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking));
  // Shutdown the grpc queue. No longer accept any newcomers.
  comm_layer_->Shutdown();
  // The next thing to do is to drop all the caches.
  RETURN_IF_NOT_OK(ServiceStop());
  return Status::OK();
}

Status CacheServer::GetFreeRequestTag(CacheServerRequest **q) {
  RETURN_UNEXPECTED_IF_NULL(q);
  auto *p = new (std::nothrow) CacheServerRequest();
  if (p == nullptr) {
    RETURN_STATUS_OOM("Out of memory.");
  }
  *q = p;
  return Status::OK();
}

Status CacheServer::ReturnRequestTag(CacheServerRequest *p) {
  RETURN_UNEXPECTED_IF_NULL(p);
  delete p;
  return Status::OK();
}

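// Note the lock order below: sessions_lock_ is acquired before rwLock_, the same order used by
// CreateService(), so a concurrent create and drop of the same session cannot deadlock each other.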
Status CacheServer::DestroySession(CacheRequest *rq) {
  CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id");
  auto drop_session_id = rq->connection_info().session_id();
  // Grab the locks in the correct order to avoid deadlock.
  UniqueLock sess_lck(&sessions_lock_);
  UniqueLock lck(&rwLock_);
  // Iterate over the set of connection ids for this session that we're dropping and erase each one.
  bool found = false;
  for (auto it = all_caches_.begin(); it != all_caches_.end();) {
    auto connection_id = it->first;
    auto session_id = GetSessionID(connection_id);
    // We could just call DestroyCache(), but we are holding a lock already and doing so would cause
    // deadlock. So we will just do it manually.
    if (session_id == drop_session_id) {
      found = true;
      it = all_caches_.erase(it);
      MS_LOG(INFO) << "Destroy cache with id " << connection_id;
    } else {
      ++it;
    }
  }
  // Finally remove the session itself
  auto n = active_sessions_.erase(drop_session_id);
  if (n > 0) {
    MS_LOG(INFO) << "Session destroyed with id " << drop_session_id;
    return Status::OK();
  } else {
    if (found) {
      std::string errMsg =
        "A destroy cache request has been completed but it had a stale session id " + std::to_string(drop_session_id);
      RETURN_STATUS_UNEXPECTED(errMsg);
    } else {
      std::string errMsg =
        "Session id " + std::to_string(drop_session_id) + " not found in server on port " + std::to_string(port_) + ".";
      RETURN_STATUS_ERROR(StatusCode::kMDFileNotExist, errMsg);
    }
  }
}

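// Session ids are drawn uniformly at random rather than sequentially. The insert into active_sessions_
// doubles as the uniqueness check, so we simply redraw on the (rare) collision.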
session_id_type CacheServer::GenerateSessionID() {
  UniqueLock sess_lck(&sessions_lock_);
  auto mt = GetRandomDevice();
  std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max());
  session_id_type session_id;
  bool duplicate = false;
  do {
    session_id = distribution(mt);
    auto r = active_sessions_.insert(session_id);
    duplicate = !r.second;
  } while (duplicate);
  return session_id;
}

Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) {
  auto client_id = rq->client_id();
  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
  CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing requested size");
  try {
    auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, kDecimal);
    void *p = nullptr;
    RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, requestedSz, &p));
    auto *base = SharedMemoryBaseAddr();
    // We can't return the absolute address which makes no sense to the client.
    // Instead we return the difference.
    auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
    reply->set_result(std::to_string(difference));
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
  auto client_id = rq->client_id();
  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
  CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing addr");
  auto *base = SharedMemoryBaseAddr();
  try {
    auto addr = strtoll(rq->buf_data(0).data(), nullptr, kDecimal);
    auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
    DeallocateSharedMemory(client_id, p);
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    RETURN_STATUS_UNEXPECTED(errMsg);
  } else {
    auto state = cs->GetState();
    reply->set_result(std::to_string(static_cast<int8_t>(state)));
    return Status::OK();
  }
}

Status CacheServer::RpcRequest(worker_id_t worker_id) {
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id));
  return Status::OK();
}

Status CacheServer::AcknowledgeShutdown(CacheServerRequest *cache_req) {
  auto *rq = &cache_req->rq_;
  auto *reply = &cache_req->reply_;
  if (!rq->buf_data().empty()) {
    // cache_admin sends us a message qID; we will destroy the
    // queue in our destructor and this will wake up cache_admin.
    // But we don't want cache_admin to just block itself blindly.
    // So we will send back an ack before shutting down the comm layer.
    try {
      int32_t qID = std::stoi(rq->buf_data(0));
      shutdown_qIDs_.push_back(qID);
    } catch (const std::exception &e) {
      // ignore it.
    }
  }
  reply->set_result("OK");
  return Status::OK();
}

void CacheServer::GlobalShutdown() {
  // Let's shutdown in proper order.
  bool expected = false;
  if (global_shutdown_.compare_exchange_strong(expected, true)) {
    MS_LOG(WARNING) << "Shutting down server.";
    // Interrupt all the threads and queues. We will leave the shutdown of the comm layer
    // until after we have joined all the threads; it will be done by the master thread.
    vg_.interrupt_all();
  }
}

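// DoServiceStart stripes the workers across numa nodes (worker i is bound to node i % num_numa_nodes),
// so the workers on node numa_id are numa_id, numa_id + num_numa_nodes, and so on. For example, with
// 4 nodes and 16 workers, node 1 owns workers 1, 5, 9 and 13, and we pick one of them at random.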
worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const {
  auto num_numa_nodes = GetNumaNodeCount();
  MS_ASSERT(numa_id < num_numa_nodes);
  auto num_workers_per_node = GetNumWorkers() / num_numa_nodes;
  std::mt19937 gen = GetRandomDevice();
  std::uniform_int_distribution<worker_id_t> dist(0, num_workers_per_node - 1);
  auto n = dist(gen);
  worker_id_t worker_id = n * num_numa_nodes + numa_id;
  MS_ASSERT(worker_id < GetNumWorkers());
  return worker_id;
}

worker_id_t CacheServer::GetRandomWorker() const {
  std::mt19937 gen = GetRandomDevice();
  std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1);
  return dist(gen);
}

Status CacheServer::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
  return shm_->AllocateSharedMemory(client_id, sz, p);
}

void CacheServer::DeallocateSharedMemory(int32_t client_id, void *p) { shm_->DeallocateSharedMemory(client_id, p); }

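// Cleanup policy for IPC resources left behind by a previous run on the same port: if the ftok key
// can't be derived or the shared memory segment can't be attached, there is nothing to clean up; if
// the segment exists but nothing is attached to it, it is stale and is removed together with the unix
// socket path; if something is still attached, another server instance owns it and we report a duplicate.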
Status CacheServer::Builder::IpcResourceCleanup() {
  Status rc;
  SharedMemory::shm_key_t shm_key;
  auto unix_socket = PortToUnixSocketPath(port_);
  rc = PortToFtok(port_, &shm_key);
  // We are expecting that the unix path doesn't exist.
  if (rc.IsError()) {
    return Status::OK();
  }
  // Attach to the shared memory, which we expect doesn't exist
  SharedMemory mem(shm_key);
  rc = mem.Attach();
  if (rc.IsError()) {
    return Status::OK();
  } else {
    RETURN_IF_NOT_OK(mem.Detach());
  }
  int32_t num_attached;
  RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached));
  if (num_attached == 0) {
    // Stale shared memory from last time.
    // Remove both the memory and the socket path
    RETURN_IF_NOT_OK(mem.Destroy());
    Path p(unix_socket);
    (void)p.Remove();
  } else {
    // Server is already up.
    std::string errMsg = "Cache server is already up and running";
    // We return a duplicate error. The main() will intercept
    // and output a proper message
    RETURN_STATUS_ERROR(StatusCode::kMDDuplicateKey, errMsg);
  }
  return Status::OK();
}

Status CacheServer::Builder::SanityCheck() {
  if (shared_memory_sz_in_gb_ <= 0) {
    RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive");
  }
  if (num_workers_ <= 0) {
    RETURN_STATUS_UNEXPECTED("Number of parallel workers must be positive");
  }
  if (!top_.empty()) {
    auto p = top_.data();
    if (p[0] != '/') {
      RETURN_STATUS_UNEXPECTED("Spilling directory must be an absolute path");
    }
    // Check if the spill directory is writable
    Path spill(top_);
    auto t = spill / Services::GetUniqueID();
    Status rc = t.CreateDirectory();
    if (rc.IsOk()) {
      rc = t.Remove();
    }
    if (rc.IsError()) {
      RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString());
    }
  }
  if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) {
    RETURN_STATUS_UNEXPECTED("Memory cap ratio should be positive and no greater than 1");
  }

  // Check the shared memory and clean up any stale IPC resources.
  RETURN_IF_NOT_OK(IpcResourceCleanup());
  return Status::OK();
}

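// Example (illustrative numbers): with 4 numa nodes and 16 cpus, a request for 2 workers is first
// raised to 4 (at least one per node) and a request for 10 is rounded up to 12 so that every node
// gets the same number of workers; the cap of kThreadsPerCore * num_cpus = 32 is not hit in either case.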
int32_t CacheServer::Builder::AdjustNumWorkers(int32_t num_workers) {
  int32_t num_numa_nodes = hw_info_->GetNumaNodeCount();
  // Bump up num_workers_ to at least the number of numa nodes
  num_workers = std::max(num_numa_nodes, num_workers);
  // But also it shouldn't be too many more than the hardware concurrency
  int32_t num_cpus = hw_info_->GetCpuCount();
  constexpr int32_t kThreadsPerCore = 2;
  num_workers = std::min(kThreadsPerCore * num_cpus, num_workers);
  // Round up num_workers to a multiple of numa nodes.
  auto remainder = num_workers % num_numa_nodes;
  if (remainder > 0) {
    num_workers += (num_numa_nodes - remainder);
  }
  return num_workers;
}

CacheServer::Builder::Builder()
    : top_(""),
      num_workers_(kDefaultNumWorkers),
      port_(kCfgDefaultCachePort),
      shared_memory_sz_in_gb_(kDefaultSharedMemorySize),
      memory_cap_ratio_(kDefaultMemoryCapRatio),
      log_level_(kDefaultLogLevel) {
  if (num_workers_ == 0) {
    num_workers_ = 1;
  }
}
}  // namespace dataset
}  // namespace mindspore