/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "minddata/dataset/engine/cache/cache_server.h"
17 #include <algorithm>
18 #include <functional>
19 #include <limits>
20 #include <vector>
21 #include "minddata/dataset/include/dataset/constants.h"
22 #include "minddata/dataset/engine/cache/cache_ipc.h"
23 #include "minddata/dataset/engine/cache/cache_service.h"
24 #include "minddata/dataset/engine/cache/cache_request.h"
25 #include "minddata/dataset/util/bit.h"
26 #include "minddata/dataset/util/path.h"
27 #include "minddata/dataset/util/random.h"
28 #ifdef CACHE_LOCAL_CLIENT
29 #include "minddata/dataset/util/sig_handler.h"
30 #endif
31 #undef BitTest
32
33 namespace mindspore {
34 namespace dataset {
35 CacheServer *CacheServer::instance_ = nullptr;
36 std::once_flag CacheServer::init_instance_flag_;
DoServiceStart()37 Status CacheServer::DoServiceStart() {
38 #ifdef CACHE_LOCAL_CLIENT
39 // We need to destroy the shared memory if user hits Control-C
40 RegisterHandlers();
41 #endif
42 if (!top_.empty()) {
43 Path spill(top_);
44 RETURN_IF_NOT_OK(spill.CreateDirectories());
45 MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
46 }
47 RETURN_IF_NOT_OK(vg_.ServiceStart());
48 auto num_numa_nodes = GetNumaNodeCount();
49 // If we link with numa library. Set default memory policy.
50 // If we don't pin thread to cpu, then use up all memory controllers to maximize
51 // memory bandwidth.
52 RETURN_IF_NOT_OK(
53 CacheServerHW::SetDefaultMemoryPolicy(numa_affinity_ ? CachePoolPolicy::kLocal : CachePoolPolicy::kInterleave));
54 auto my_node = hw_info_->GetMyNode();
55 MS_LOG(DEBUG) << "Cache server is running on numa node " << my_node;
56 // There will be some threads working on the grpc queue and
57 // some number of threads working on the CacheServerRequest queue.
58 // Like a connector object we will set up the same number of queues but
59 // we do not need to preserve any order. We will set the capacity of
60 // each queue to be 64 since we are just pushing memory pointers which
61 // is only 8 byte each.
62 const int32_t kQueCapacity = 64;
63 // This is the request queue from the client
64 cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>();
65 cache_q_->Init(num_workers_, kQueCapacity);
66 // We will match the number of grpc workers with the number of server workers.
67 // But technically they don't have to be the same.
68 num_grpc_workers_ = num_workers_;
69 MS_LOG(DEBUG) << "Number of gprc workers is set to " << num_grpc_workers_;
70 RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
71 // Start the comm layer
72 try {
73 comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_);
74 RETURN_IF_NOT_OK(comm_layer_->Run());
75 } catch (const std::exception &e) {
76 RETURN_STATUS_UNEXPECTED(e.what());
77 }
78 #ifdef CACHE_LOCAL_CLIENT
79 RETURN_IF_NOT_OK(CachedSharedMemory::CreateArena(&shm_, port_, shared_memory_sz_in_gb_));
80 // Bring up a thread to monitor the unix socket in case it is removed. But it must be done
81 // after we have created the unix socket.
82 auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
83 RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
84 #endif
85 // Spawn a few threads to serve the real request.
86 auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
87 for (auto i = 0; i < num_workers_; ++i) {
88 Task *pTask;
89 RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i), &pTask));
90 // Save a copy of the pointer to the underlying Task object. We may dynamically change their affinity if needed.
91 numa_tasks_.emplace(i, pTask);
92 // Spread out all the threads to all the numa nodes if needed
93 if (IsNumaAffinityOn()) {
94 auto numa_id = i % num_numa_nodes;
95 RETURN_IF_NOT_OK(SetAffinity(*pTask, numa_id));
96 }
97 }
98 // Finally loop forever to handle the request.
99 auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1);
100 for (auto i = 0; i < num_grpc_workers_; ++i) {
101 Task *pTask;
102 RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i), &pTask));
103 // All these grpc workers will be allocated to the same node which is where we allocate all those free tag
104 // memory.
105 if (IsNumaAffinityOn()) {
106 RETURN_IF_NOT_OK(SetAffinity(*pTask, i % num_numa_nodes));
107 }
108 }
109 return Status::OK();
110 }
111
DoServiceStop()112 Status CacheServer::DoServiceStop() {
113 Status rc;
114 Status rc2;
115 // First stop all the threads.
116 RETURN_IF_NOT_OK(vg_.ServiceStop());
117 // Clean up all the caches if any.
118 UniqueLock lck(&rwLock_);
119 auto it = all_caches_.begin();
120 while (it != all_caches_.end()) {
121 auto cs = std::move(it->second);
122 rc2 = cs->ServiceStop();
123 if (rc2.IsError()) {
124 rc = rc2;
125 }
126 ++it;
127 }
128 // Also remove the path we use to generate ftok.
129 Path p(PortToUnixSocketPath(port_));
130 (void)p.Remove();
131 // Finally wake up cache_admin if it is waiting
132 for (int32_t qID : shutdown_qIDs_) {
133 SharedMessage msg(qID);
134 RETURN_IF_NOT_OK(msg.SendStatus(Status::OK()));
135 msg.RemoveResourcesOnExit();
136 // Let msg goes out of scope which will destroy the queue.
137 }
138 return rc;
139 }
140
GetService(connection_id_type id) const141 CacheService *CacheServer::GetService(connection_id_type id) const {
142 auto it = all_caches_.find(id);
143 if (it != all_caches_.end()) {
144 return it->second.get();
145 }
146 return nullptr;
147 }
148
149 // We would like to protect ourselves from over allocating too much. We will go over existing cache
150 // and calculate how much we have consumed so far.
GlobalMemoryCheck(uint64_t cache_mem_sz)151 Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
152 auto end = all_caches_.end();
153 auto it = all_caches_.begin();
154 auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
155 int64_t max_avail = avail_mem;
156 while (it != end) {
157 auto &cs = it->second;
158 CacheService::ServiceStat stat;
159 RETURN_IF_NOT_OK(cs->GetStat(&stat));
160 int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
161 max_avail -= mem_consumed;
162 if (max_avail <= 0) {
163 RETURN_STATUS_OOM("Out of memory, please destroy some sessions.");
164 }
165 ++it;
166 }
167
168 // If we have some cache using some memory already, make a reasonable decision if we should return
169 // out of memory.
170 if (max_avail < avail_mem) {
171 int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit.
172 if (req_mem > max_avail) {
173 RETURN_STATUS_OOM("Out of memory, please destroy some sessions");
174 } else if (req_mem == 0) {
175 // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
176 // 85% of our limit, fail this request.
177 if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= kMemoryBottomLineForNewService) {
178 RETURN_STATUS_OOM("Out of memory, please destroy some sessions");
179 }
180 }
181 }
182 return Status::OK();
183 }
184
CreateService(CacheRequest * rq,CacheReply * reply)185 Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
186 CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info");
187 std::string cookie;
188 int32_t client_id;
189 auto session_id = rq->connection_info().session_id();
190 auto crc = rq->connection_info().crc();
191
192 // Before allowing the creation, make sure the session had already been created by the user
193 // Our intention is to add this cache to the active sessions list so leave the list locked during
194 // this entire function.
195 UniqueLock sess_lck(&sessions_lock_);
196 auto session_it = active_sessions_.find(session_id);
197 if (session_it == active_sessions_.end()) {
198 RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!");
199 }
200
201 // We concat both numbers to form the internal connection id.
202 auto connection_id = GetConnectionID(session_id, crc);
203 CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache");
204 auto &create_cache_buf = rq->buf_data(0);
205 auto p = flatbuffers::GetRoot<CreateCacheRequestMsg>(create_cache_buf.data());
206 auto flag = static_cast<CreateCacheRequest::CreateCacheFlag>(p->flag());
207 auto cache_mem_sz = p->cache_mem_sz();
208 // We can't do spilling unless this server is setup with a spill path in the first place
209 bool spill =
210 (flag & CreateCacheRequest::CreateCacheFlag::kSpillToDisk) == CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
211 bool generate_id =
212 (flag & CreateCacheRequest::CreateCacheFlag::kGenerateRowId) == CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
213 if (spill && top_.empty()) {
214 RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
215 }
216 // Before creating the cache, first check if this is a request for a shared usage of an existing cache
217 // If two CreateService come in with identical connection_id, we need to serialize the create.
218 // The first create will be successful and be given a special cookie.
219 UniqueLock lck(&rwLock_);
220 bool duplicate = false;
221 CacheService *curr_cs = GetService(connection_id);
222 if (curr_cs != nullptr) {
223 duplicate = true;
224 client_id = curr_cs->num_clients_.fetch_add(1);
225 MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " +
226 std::to_string(connection_id) + " to create cache service";
227 }
228 // Early exit if we are doing global shutdown
229 if (global_shutdown_) {
230 return Status::OK();
231 }
232
233 if (!duplicate) {
234 RETURN_IF_NOT_OK(GlobalMemoryCheck(cache_mem_sz));
235 std::unique_ptr<CacheService> cs;
236 try {
237 cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
238 RETURN_IF_NOT_OK(cs->ServiceStart());
239 cookie = cs->cookie();
240 client_id = cs->num_clients_.fetch_add(1);
241 all_caches_.emplace(connection_id, std::move(cs));
242 } catch (const std::bad_alloc &e) {
243 RETURN_STATUS_OOM("Out of memory.");
244 }
245 }
246
247 // Shuffle the worker threads. But we need to release the locks or we will deadlock when calling
248 // the following function
249 lck.Unlock();
250 sess_lck.Unlock();
251 auto numa_id = client_id % GetNumaNodeCount();
252 std::vector<cpu_id_t> cpu_list = hw_info_->GetCpuList(numa_id);
253 // Send back the data
254 flatbuffers::FlatBufferBuilder fbb;
255 flatbuffers::Offset<flatbuffers::String> off_cookie;
256 flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list;
257 off_cookie = fbb.CreateString(cookie);
258 off_cpu_list = fbb.CreateVector(cpu_list);
259 CreateCacheReplyMsgBuilder bld(fbb);
260 bld.add_connection_id(connection_id);
261 bld.add_cookie(off_cookie);
262 bld.add_client_id(client_id);
263 // The last thing we send back is a set of cpu id that we suggest the client should bind itself to
264 bld.add_cpu_id(off_cpu_list);
265 auto off = bld.Finish();
266 fbb.Finish(off);
267 reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
268 // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it
269 // treat it as OK.
270 return duplicate ? Status(StatusCode::kMDDuplicateKey) : Status::OK();
271 }
272
DestroyCache(CacheRequest * rq)273 Status CacheServer::DestroyCache(CacheRequest *rq) {
274 // We need a strong lock to protect the map.
275 UniqueLock lck(&rwLock_);
276 auto id = rq->connection_id();
277 CacheService *cs = GetService(id);
278 // it is already destroyed. Ignore it.
279 if (cs != nullptr) {
280 MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
281 // std::map will invoke the destructor of CacheService. So we don't need to do anything here.
282 auto n = all_caches_.erase(id);
283 if (n == 0) {
284 // It has been destroyed by another duplicate request.
285 MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service";
286 }
287 }
288 // We aren't touching the session list even though we may be dropping the last remaining cache of a session.
289 // Leave that to be done by the drop session command.
290 return Status::OK();
291 }
292
CacheRow(CacheRequest * rq,CacheReply * reply)293 Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
294 auto connection_id = rq->connection_id();
295 // Hold the shared lock to prevent the cache from being dropped.
296 SharedLock lck(&rwLock_);
297 CacheService *cs = GetService(connection_id);
298 if (cs == nullptr) {
299 std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
300 RETURN_STATUS_UNEXPECTED(errMsg);
301 } else {
302 auto sz = rq->buf_data_size();
303 std::vector<const void *> buffers;
304 buffers.reserve(sz);
305 // First piece of data is the cookie and is required
306 CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
307 auto &cookie = rq->buf_data(0);
308 // Only if the cookie matches, we can accept insert into this cache that has a build phase
309 if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
310 // Push the address of each buffer (in the form of std::string coming in from protobuf) into
311 // a vector of buffer
312 for (auto i = 1; i < sz; ++i) {
313 buffers.push_back(rq->buf_data(i).data());
314 }
315 row_id_type id = -1;
316 // We will allocate the memory the same numa node this thread is bound to.
317 RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id));
318 reply->set_result(std::to_string(id));
319 } else {
320 RETURN_STATUS_UNEXPECTED("Cookie mismatch");
321 }
322 }
323 return Status::OK();
324 }
325
FastCacheRow(CacheRequest * rq,CacheReply * reply)326 Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
327 auto connection_id = rq->connection_id();
328 auto client_id = rq->client_id();
329 CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
330 // Hold the shared lock to prevent the cache from being dropped.
331 SharedLock lck(&rwLock_);
332 CacheService *cs = GetService(connection_id);
333 auto *base = SharedMemoryBaseAddr();
334 // Ensure we got 3 pieces of data coming in
335 constexpr int32_t kMinBufDataSize = 3;
336 CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
337 // First one is cookie, followed by data address and then size.
338 enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
339 // First piece of data is the cookie and is required
340 auto &cookie = rq->buf_data(BufDataIndex::kCookie);
341 // Second piece of data is the address where we can find the serialized data
342 auto addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
343 auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
344 // Third piece of data is the size of the serialized data that we need to transfer
345 auto sz = strtoll(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal);
346 // Successful or not, we need to free the memory on exit.
347 Status rc;
348 if (cs == nullptr) {
349 std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
350 rc = STATUS_ERROR(StatusCode::kMDUnexpectedError, errMsg);
351 } else {
352 // Only if the cookie matches, we can accept insert into this cache that has a build phase
353 if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
354 row_id_type id = -1;
355 ReadableSlice src(p, sz);
356 // We will allocate the memory the same numa node this thread is bound to.
357 rc = cs->FastCacheRow(src, &id);
358 reply->set_result(std::to_string(id));
359 } else {
360 auto state = cs->GetState();
361 if (state != CacheServiceState::kFetchPhase) {
362 rc = STATUS_ERROR(StatusCode::kMDUnexpectedError, "Cache service is not in fetch phase. The current phase is " +
363 std::to_string(static_cast<int8_t>(state)) +
364 ". Client id: " + std::to_string(client_id));
365 } else {
366 rc = STATUS_ERROR(StatusCode::kMDUnexpectedError, "Cookie mismatch. Client id: " + std::to_string(client_id));
367 }
368 }
369 }
370 // Return the block to the shared memory only if it is not internal request.
371 if (static_cast<BaseRequest::RequestType>(rq->type()) == BaseRequest::RequestType::kCacheRow) {
372 DeallocateSharedMemory(client_id, p);
373 }
374 return rc;
375 }
376
InternalCacheRow(CacheRequest * rq,CacheReply * reply)377 Status CacheServer::InternalCacheRow(CacheRequest *rq, CacheReply *reply) {
378 // Look into the flag to see where we can find the data and call the appropriate method.
379 auto flag = rq->flag();
380 Status rc;
381 if (BitTest(flag, kDataIsInSharedMemory)) {
382 rc = FastCacheRow(rq, reply);
383 // This is an internal request and is not tied to rpc. But need to post because there
384 // is a thread waiting on the completion of this request.
385 try {
386 constexpr int32_t kBatchWaitIdx = 3;
387 // Ensure we got 4 pieces of data coming in
388 constexpr int32_t kMinBufDataSize = 4;
389 CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
390 // Fourth piece of the data is the address of the BatchWait ptr
391 int64_t addr = strtol(rq->buf_data(kBatchWaitIdx).data(), nullptr, kDecimal);
392 auto *bw = reinterpret_cast<BatchWait *>(addr);
393 // Check if the object is still around.
394 auto bwObj = bw->GetBatchWait();
395 if (bwObj.lock()) {
396 RETURN_IF_NOT_OK(bw->Set(rc));
397 }
398 } catch (const std::exception &e) {
399 RETURN_STATUS_UNEXPECTED(e.what());
400 }
401 } else {
402 rc = CacheRow(rq, reply);
403 }
404 return rc;
405 }
406
InternalFetchRow(CacheRequest * rq)407 Status CacheServer::InternalFetchRow(CacheRequest *rq) {
408 auto connection_id = rq->connection_id();
409 SharedLock lck(&rwLock_);
410 CacheService *cs = GetService(connection_id);
411 Status rc;
412 if (cs == nullptr) {
413 std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
414 RETURN_STATUS_UNEXPECTED(errMsg);
415 }
416 // First piece is a flatbuffer containing row fetch information, second piece is the address of the BatchWait ptr
417 enum BufDataIndex : uint8_t { kFetchRowMsg = 0, kBatchWait = 1 };
418 // Ensure we got 2 pieces of data coming in
419 constexpr int32_t kMinBufDataSize = 2;
420 CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
421 rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(BufDataIndex::kFetchRowMsg).data()));
422 // This is an internal request and is not tied to rpc. But need to post because there
423 // is a thread waiting on the completion of this request.
424 try {
425 int64_t addr = strtol(rq->buf_data(BufDataIndex::kBatchWait).data(), nullptr, kDecimal);
426 auto *bw = reinterpret_cast<BatchWait *>(addr);
427 // Check if the object is still around.
428 auto bwObj = bw->GetBatchWait();
429 if (bwObj.lock()) {
430 RETURN_IF_NOT_OK(bw->Set(rc));
431 }
432 } catch (const std::exception &e) {
433 RETURN_STATUS_UNEXPECTED(e.what());
434 }
435 return rc;
436 }
437
BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> & fbb,WritableSlice * out)438 Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
439 RETURN_UNEXPECTED_IF_NULL(out);
440 auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
441 RETURN_UNEXPECTED_IF_NULL(p);
442 const auto num_elements = p->rows()->size();
443 auto connection_id = p->connection_id();
444 auto batch_wait = std::make_shared<BatchWait>(num_elements);
445 int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
446 auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
447 offset_array[0] = data_offset;
448 for (uint32_t i = 0; i < num_elements; ++i) {
449 auto data_locator = p->rows()->Get(i);
450 auto node_id = data_locator->node_id();
451 size_t sz = data_locator->size();
452 void *source_addr = reinterpret_cast<void *>(data_locator->addr());
453 auto key = data_locator->key();
454 // Please read the comment in CacheServer::BatchFetchRows where we allocate
455 // the buffer big enough so each thread (which we are going to dispatch) will
456 // not run into false sharing problem. We are going to round up sz to 4k.
457 auto sz_4k = round_up_4K(sz);
458 offset_array[i + 1] = offset_array[i] + sz_4k;
459 if (sz > 0) {
460 WritableSlice row_data(*out, offset_array[i], sz);
461 // Get a request and send to the proper worker (at some numa node) to do the fetch.
462 worker_id_t worker_id = IsNumaAffinityOn() ? GetWorkerByNumaId(node_id) : GetRandomWorker();
463 CacheServerRequest *cache_rq;
464 RETURN_IF_NOT_OK(GetFreeRequestTag(&cache_rq));
465 // Set up all the necessarily field.
466 cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
467 cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
468 cache_rq->rq_.set_connection_id(connection_id);
469 cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
470 auto dest_addr = row_data.GetMutablePointer();
471 flatbuffers::FlatBufferBuilder fb2;
472 FetchRowMsgBuilder bld(fb2);
473 bld.add_key(key);
474 bld.add_size(sz);
475 bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
476 bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
477 auto offset = bld.Finish();
478 fb2.Finish(offset);
479 cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
480 cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(batch_wait.get())));
481 RETURN_IF_NOT_OK(PushRequest(worker_id, cache_rq));
482 } else {
483 // Nothing to fetch but we still need to post something back into the wait area.
484 RETURN_IF_NOT_OK(batch_wait->Set(Status::OK()));
485 }
486 }
487 // Now wait for all of them to come back.
488 RETURN_IF_NOT_OK(batch_wait->Wait());
489 // Return the result
490 return batch_wait->GetRc();
491 }
492
BatchFetchRows(CacheRequest * rq,CacheReply * reply)493 Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
494 auto connection_id = rq->connection_id();
495 auto client_id = rq->client_id();
496 // Hold the shared lock to prevent the cache from being dropped.
497 SharedLock lck(&rwLock_);
498 CacheService *cs = GetService(connection_id);
499 if (cs == nullptr) {
500 std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
501 RETURN_STATUS_UNEXPECTED(errMsg);
502 } else {
503 CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id");
504 auto &row_id_buf = rq->buf_data(0);
505 auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data());
506 RETURN_UNEXPECTED_IF_NULL(p);
507 std::vector<row_id_type> row_id;
508 auto sz = p->row_id()->size();
509 row_id.reserve(sz);
510 for (uint32_t i = 0; i < sz; ++i) {
511 row_id.push_back(p->row_id()->Get(i));
512 }
513 std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
514 RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb));
515 // Let go of the shared lock. We don't need to interact with the CacheService anymore.
516 // We shouldn't be holding any lock while we can wait for a long time for the rows to come back.
517 lck.Unlock();
518 auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
519 int64_t mem_sz = sizeof(int64_t) * (sz + 1);
520 for (auto i = 0; i < sz; ++i) {
521 auto row_sz = locator->rows()->Get(i)->size();
522 // row_sz is the size of the cached data. Later we will spawn multiple threads
523 // each of which will copy the data into either shared memory or protobuf concurrently but
524 // to different region.
525 // To avoid false sharing, we will bump up row_sz to be a multiple of 4k, i.e. 4096 bytes
526 row_sz = round_up_4K(row_sz);
527 mem_sz += row_sz;
528 }
529 auto client_flag = rq->flag();
530 bool local_client = BitTest(client_flag, kLocalClientSupport);
531 // For large amount data to be sent back, we will use shared memory provided it is a local
532 // client that has local bypass support
533 bool local_bypass = local_client ? (mem_sz >= kLocalByPassThreshold) : false;
534 reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0);
535 if (local_bypass) {
536 // We will use shared memory
537 auto *base = SharedMemoryBaseAddr();
538 void *q = nullptr;
539 RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, mem_sz, &q));
540 WritableSlice dest(q, mem_sz);
541 Status rc = BatchFetch(fbb, &dest);
542 if (rc.IsError()) {
543 DeallocateSharedMemory(client_id, q);
544 return rc;
545 }
546 // We can't return the absolute address which makes no sense to the client.
547 // Instead we return the difference.
548 auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base);
549 reply->set_result(std::to_string(difference));
550 } else {
551 // We are going to use std::string to allocate and hold the result which will be eventually
552 // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
553 // to minimize memory copy.
554 std::string mem;
555 try {
556 mem.resize(mem_sz);
557 CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error");
558 } catch (const std::bad_alloc &e) {
559 RETURN_STATUS_OOM("Out of memory.");
560 }
561 WritableSlice dest(mem.data(), mem_sz);
562 RETURN_IF_NOT_OK(BatchFetch(fbb, &dest));
563 reply->set_result(std::move(mem));
564 }
565 }
566 return Status::OK();
567 }
568
GetStat(CacheRequest * rq,CacheReply * reply)569 Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) {
570 auto connection_id = rq->connection_id();
571 // Hold the shared lock to prevent the cache from being dropped.
572 SharedLock lck(&rwLock_);
573 CacheService *cs = GetService(connection_id);
574 if (cs == nullptr) {
575 std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
576 RETURN_STATUS_UNEXPECTED(errMsg);
577 } else {
578 CacheService::ServiceStat svc_stat;
579 RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
580 flatbuffers::FlatBufferBuilder fbb;
581 ServiceStatMsgBuilder bld(fbb);
582 bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
583 bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
584 bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz);
585 bld.add_num_numa_hit(svc_stat.stat_.num_numa_hit);
586 bld.add_max_row_id(svc_stat.stat_.max_key);
587 bld.add_min_row_id(svc_stat.stat_.min_key);
588 bld.add_state(svc_stat.state_);
589 auto offset = bld.Finish();
590 fbb.Finish(offset);
591 reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
592 }
593 return Status::OK();
594 }
595
CacheSchema(CacheRequest * rq)596 Status CacheServer::CacheSchema(CacheRequest *rq) {
597 auto connection_id = rq->connection_id();
598 // Hold the shared lock to prevent the cache from being dropped.
599 SharedLock lck(&rwLock_);
600 CacheService *cs = GetService(connection_id);
601 if (cs == nullptr) {
602 std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
603 RETURN_STATUS_UNEXPECTED(errMsg);
604 } else {
605 CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information");
606 auto &create_schema_buf = rq->buf_data(0);
607 RETURN_IF_NOT_OK(cs->CacheSchema(create_schema_buf.data(), create_schema_buf.size()));
608 }
609 return Status::OK();
610 }
611
FetchSchema(CacheRequest * rq,CacheReply * reply)612 Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) {
613 auto connection_id = rq->connection_id();
614 // Hold the shared lock to prevent the cache from being dropped.
615 SharedLock lck(&rwLock_);
616 CacheService *cs = GetService(connection_id);
617 if (cs == nullptr) {
618 std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
619 RETURN_STATUS_UNEXPECTED(errMsg);
620 } else {
621 // We are going to use std::string to allocate and hold the result which will be eventually
622 // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
623 // to minimize memory copy.
624 std::string mem;
625 RETURN_IF_NOT_OK(cs->FetchSchema(&mem));
626 reply->set_result(std::move(mem));
627 }
628 return Status::OK();
629 }
630
BuildPhaseDone(CacheRequest * rq)631 Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
632 auto connection_id = rq->connection_id();
633 // Hold the shared lock to prevent the cache from being dropped.
634 SharedLock lck(&rwLock_);
635 CacheService *cs = GetService(connection_id);
636 if (cs == nullptr) {
637 std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
638 RETURN_STATUS_UNEXPECTED(errMsg);
639 } else {
640 // First piece of data is the cookie
641 CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
642 auto &cookie = rq->buf_data(0);
643 // We can only allow to switch phase if the cookie match.
644 if (cookie == cs->cookie()) {
645 RETURN_IF_NOT_OK(cs->BuildPhaseDone());
646 } else {
647 RETURN_STATUS_UNEXPECTED("Cookie mismatch");
648 }
649 }
650 return Status::OK();
651 }
652
GetCacheMissKeys(CacheRequest * rq,CacheReply * reply)653 Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) {
654 auto connection_id = rq->connection_id();
655 // Hold the shared lock to prevent the cache from being dropped.
656 SharedLock lck(&rwLock_);
657 CacheService *cs = GetService(connection_id);
658 if (cs == nullptr) {
659 std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
660 RETURN_STATUS_UNEXPECTED(errMsg);
661 } else {
662 std::vector<row_id_type> gap;
663 RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap));
664 flatbuffers::FlatBufferBuilder fbb;
665 auto off_t = fbb.CreateVector(gap);
666 TensorRowIdsBuilder bld(fbb);
667 bld.add_row_id(off_t);
668 auto off = bld.Finish();
669 fbb.Finish(off);
670 reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
671 }
672 return Status::OK();
673 }
674
GenerateClientSessionID(session_id_type session_id,CacheReply * reply)675 inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *reply) {
676 reply->set_result(std::to_string(session_id));
677 MS_LOG(INFO) << "Server generated new session id " << session_id;
678 return Status::OK();
679 }
680
ToggleWriteMode(CacheRequest * rq)681 Status CacheServer::ToggleWriteMode(CacheRequest *rq) {
682 auto connection_id = rq->connection_id();
683 // Hold the shared lock to prevent the cache from being dropped.
684 SharedLock lck(&rwLock_);
685 CacheService *cs = GetService(connection_id);
686 if (cs == nullptr) {
687 std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
688 RETURN_STATUS_UNEXPECTED(errMsg);
689 } else {
690 // First piece of data is the on/off flag
691 CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag");
692 const auto &action = rq->buf_data(0);
693 bool on_off = false;
694 if (strcmp(action.data(), "on") == 0) {
695 on_off = true;
696 } else if (strcmp(action.data(), "off") == 0) {
697 on_off = false;
698 } else {
699 RETURN_STATUS_UNEXPECTED("Unknown request: " + action);
700 }
701 RETURN_IF_NOT_OK(cs->ToggleWriteMode(on_off));
702 }
703 return Status::OK();
704 }
705
ListSessions(CacheReply * reply)706 Status CacheServer::ListSessions(CacheReply *reply) {
707 SharedLock sess_lck(&sessions_lock_);
708 SharedLock lck(&rwLock_);
709 flatbuffers::FlatBufferBuilder fbb;
710 std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector;
711 for (auto const ¤t_session_id : active_sessions_) {
712 bool found = false;
713 for (auto const &it : all_caches_) {
714 auto current_conn_id = it.first;
715 if (GetSessionID(current_conn_id) == current_session_id) {
716 found = true;
717 auto &cs = it.second;
718 CacheService::ServiceStat svc_stat;
719 RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
720 auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached,
721 svc_stat.stat_.average_cache_sz, svc_stat.stat_.num_numa_hit,
722 svc_stat.stat_.min_key, svc_stat.stat_.max_key, svc_stat.state_);
723 auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats);
724 session_msgs_vector.push_back(current_session_info);
725 }
726 }
727 if (!found) {
728 // If there is no cache created yet, assign a connection id of 0 along with empty stats
729 auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0);
730 auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats);
731 session_msgs_vector.push_back(current_session_info);
732 }
733 }
734 flatbuffers::Offset<flatbuffers::String> spill_dir;
735 spill_dir = fbb.CreateString(top_);
736 auto session_msgs = fbb.CreateVector(session_msgs_vector);
737 ListSessionsMsgBuilder s_builder(fbb);
738 s_builder.add_sessions(session_msgs);
739 s_builder.add_num_workers(num_workers_);
740 s_builder.add_log_level(log_level_);
741 s_builder.add_spill_dir(spill_dir);
742 auto offset = s_builder.Finish();
743 fbb.Finish(offset);
744 reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
745 return Status::OK();
746 }
747
ConnectReset(CacheRequest * rq)748 Status CacheServer::ConnectReset(CacheRequest *rq) {
749 auto connection_id = rq->connection_id();
750 // Hold the shared lock to prevent the cache from being dropped.
751 SharedLock lck(&rwLock_);
752 CacheService *cs = GetService(connection_id);
753 if (cs == nullptr) {
754 std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
755 RETURN_STATUS_UNEXPECTED(errMsg);
756 } else {
757 auto client_id = rq->client_id();
758 MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects";
759 cs->num_clients_--;
760 }
761 return Status::OK();
762 }
763
BatchCacheRows(CacheRequest * rq)764 Status CacheServer::BatchCacheRows(CacheRequest *rq) {
765 // First one is cookie, followed by address and then size.
766 enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
767 constexpr int32_t kExpectedBufDataSize = 3;
768 CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == kExpectedBufDataSize, "Expect three pieces of data");
769 try {
770 auto &cookie = rq->buf_data(BufDataIndex::kCookie);
771 auto connection_id = rq->connection_id();
772 auto client_id = rq->client_id();
773 int64_t offset_addr;
774 int32_t num_elem;
775 auto *base = SharedMemoryBaseAddr();
776 offset_addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
777 auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
778 num_elem = static_cast<int32_t>(strtol(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal));
779 auto batch_wait = std::make_shared<BatchWait>(num_elem);
780 // Get a set of free request and push into the queues.
781 for (auto i = 0; i < num_elem; ++i) {
782 auto start = reinterpret_cast<int64_t>(p);
783 auto msg = GetTensorRowHeaderMsg(p);
784 p += msg->size_of_this();
785 for (auto k = 0; k < msg->column()->size(); ++k) {
786 p += msg->data_sz()->Get(k);
787 }
788 CacheServerRequest *cache_rq;
789 RETURN_IF_NOT_OK(GetFreeRequestTag(&cache_rq));
790 // Fill in detail.
791 cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
792 cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
793 cache_rq->rq_.set_connection_id(connection_id);
794 cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
795 cache_rq->rq_.set_client_id(client_id);
796 cache_rq->rq_.set_flag(kDataIsInSharedMemory);
797 cache_rq->rq_.add_buf_data(cookie);
798 cache_rq->rq_.add_buf_data(std::to_string(start - reinterpret_cast<int64_t>(base)));
799 cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(p - start)));
800 cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(batch_wait.get())));
801 RETURN_IF_NOT_OK(PushRequest(GetRandomWorker(), cache_rq));
802 }
803 // Now wait for all of them to come back.
804 RETURN_IF_NOT_OK(batch_wait->Wait());
805 // Return the result
806 return batch_wait->GetRc();
807 } catch (const std::exception &e) {
808 RETURN_STATUS_UNEXPECTED(e.what());
809 }
810 return Status::OK();
811 }
812
ProcessRowRequest(CacheServerRequest * cache_req,bool * internal_request)813 Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request) {
814 auto &rq = cache_req->rq_;
815 auto &reply = cache_req->reply_;
816 switch (cache_req->type_) {
817 case BaseRequest::RequestType::kCacheRow: {
818 // Look into the flag to see where we can find the data and call the appropriate method.
819 if (BitTest(rq.flag(), kDataIsInSharedMemory)) {
820 cache_req->rc_ = FastCacheRow(&rq, &reply);
821 } else {
822 cache_req->rc_ = CacheRow(&rq, &reply);
823 }
824 break;
825 }
826 case BaseRequest::RequestType::kInternalCacheRow: {
827 *internal_request = true;
828 cache_req->rc_ = InternalCacheRow(&rq, &reply);
829 break;
830 }
831 case BaseRequest::RequestType::kBatchCacheRows: {
832 cache_req->rc_ = BatchCacheRows(&rq);
833 break;
834 }
835 case BaseRequest::RequestType::kBatchFetchRows: {
836 cache_req->rc_ = BatchFetchRows(&rq, &reply);
837 break;
838 }
839 case BaseRequest::RequestType::kInternalFetchRow: {
840 *internal_request = true;
841 cache_req->rc_ = InternalFetchRow(&rq);
842 break;
843 }
844 default:
845 std::string errMsg("Internal error, request type is not row request: ");
846 errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
847 cache_req->rc_ = STATUS_ERROR(StatusCode::kMDUnexpectedError, errMsg);
848 }
849 return Status::OK();
850 }
851
ProcessSessionRequest(CacheServerRequest * cache_req)852 Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) {
853 auto &rq = cache_req->rq_;
854 auto &reply = cache_req->reply_;
855 switch (cache_req->type_) {
856 case BaseRequest::RequestType::kDropSession: {
857 cache_req->rc_ = DestroySession(&rq);
858 break;
859 }
860 case BaseRequest::RequestType::kGenerateSessionId: {
861 cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
862 break;
863 }
864 case BaseRequest::RequestType::kListSessions: {
865 cache_req->rc_ = ListSessions(&reply);
866 break;
867 }
868 default:
869 std::string errMsg("Internal error, request type is not session request: ");
870 errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
871 cache_req->rc_ = STATUS_ERROR(StatusCode::kMDUnexpectedError, errMsg);
872 }
873 return Status::OK();
874 }
875
ProcessAdminRequest(CacheServerRequest * cache_req)876 Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) {
877 auto &rq = cache_req->rq_;
878 auto &reply = cache_req->reply_;
879 switch (cache_req->type_) {
880 case BaseRequest::RequestType::kCreateCache: {
881 cache_req->rc_ = CreateService(&rq, &reply);
882 break;
883 }
884 case BaseRequest::RequestType::kGetCacheMissKeys: {
885 cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
886 break;
887 }
888 case BaseRequest::RequestType::kDestroyCache: {
889 cache_req->rc_ = DestroyCache(&rq);
890 break;
891 }
892 case BaseRequest::RequestType::kGetStat: {
893 cache_req->rc_ = GetStat(&rq, &reply);
894 break;
895 }
896 case BaseRequest::RequestType::kCacheSchema: {
897 cache_req->rc_ = CacheSchema(&rq);
898 break;
899 }
900 case BaseRequest::RequestType::kFetchSchema: {
901 cache_req->rc_ = FetchSchema(&rq, &reply);
902 break;
903 }
904 case BaseRequest::RequestType::kBuildPhaseDone: {
905 cache_req->rc_ = BuildPhaseDone(&rq);
906 break;
907 }
908 case BaseRequest::RequestType::kAllocateSharedBlock: {
909 cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
910 break;
911 }
912 case BaseRequest::RequestType::kFreeSharedBlock: {
913 cache_req->rc_ = FreeSharedMemory(&rq);
914 break;
915 }
916 case BaseRequest::RequestType::kStopService: {
917 // This command shutdowns everything.
918 // But we first reply back to the client that we receive the request.
919 // The real shutdown work will be done by the caller.
920 cache_req->rc_ = AcknowledgeShutdown(cache_req);
921 break;
922 }
923 case BaseRequest::RequestType::kHeartBeat: {
924 cache_req->rc_ = Status::OK();
925 break;
926 }
927 case BaseRequest::RequestType::kToggleWriteMode: {
928 cache_req->rc_ = ToggleWriteMode(&rq);
929 break;
930 }
931 case BaseRequest::RequestType::kConnectReset: {
932 cache_req->rc_ = ConnectReset(&rq);
933 break;
934 }
935 case BaseRequest::RequestType::kGetCacheState: {
936 cache_req->rc_ = GetCacheState(&rq, &reply);
937 break;
938 }
939 default:
940 std::string errMsg("Internal error, request type is not admin request: ");
941 errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
942 cache_req->rc_ = STATUS_ERROR(StatusCode::kMDUnexpectedError, errMsg);
943 }
944 return Status::OK();
945 }
946
ProcessRequest(CacheServerRequest * cache_req)947 Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
948 bool internal_request = false;
949
950 // Except for creating a new session, we expect cs is not null.
951 if (cache_req->IsRowRequest()) {
952 RETURN_IF_NOT_OK(ProcessRowRequest(cache_req, &internal_request));
953 } else if (cache_req->IsSessionRequest()) {
954 RETURN_IF_NOT_OK(ProcessSessionRequest(cache_req));
955 } else if (cache_req->IsAdminRequest()) {
956 RETURN_IF_NOT_OK(ProcessAdminRequest(cache_req));
957 } else {
958 std::string errMsg("Unknown request type : ");
959 errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
960 cache_req->rc_ = STATUS_ERROR(StatusCode::kMDUnexpectedError, errMsg);
961 }
962
963 // Notify it is done, and move on to the next request.
964 Status2CacheReply(cache_req->rc_, &cache_req->reply_);
965 cache_req->st_ = CacheServerRequest::STATE::FINISH;
966 // We will re-tag the request back to the grpc queue. Once it comes back from the client,
967 // the CacheServerRequest, i.e. the pointer cache_req, will be free
968 if (!internal_request && !global_shutdown_) {
969 cache_req->responder_.Finish(cache_req->reply_, grpc::Status::OK, cache_req);
970 } else {
971 // We can free up the request now.
972 RETURN_IF_NOT_OK(ReturnRequestTag(cache_req));
973 }
974 return Status::OK();
975 }
976
977 /// \brief This is the main loop the cache server thread(s) are running.
978 /// Each thread will pop a request and send the result back to the client using grpc
979 /// \return
ServerRequest(worker_id_t worker_id)980 Status CacheServer::ServerRequest(worker_id_t worker_id) {
981 TaskManager::FindMe()->Post();
982 MS_LOG(DEBUG) << "Worker id " << worker_id << " is running on node " << hw_info_->GetMyNode();
983 auto &my_que = cache_q_->operator[](worker_id);
984 // Loop forever until we are interrupted or shutdown.
985 while (!global_shutdown_) {
986 CacheServerRequest *cache_req = nullptr;
987 RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
988 RETURN_IF_NOT_OK(ProcessRequest(cache_req));
989 }
990 return Status::OK();
991 }
992
GetConnectionID(session_id_type session_id,uint32_t crc) const993 connection_id_type CacheServer::GetConnectionID(session_id_type session_id, uint32_t crc) const {
994 connection_id_type connection_id =
995 (static_cast<connection_id_type>(session_id) << 32u) | static_cast<connection_id_type>(crc);
996 return connection_id;
997 }
998
GetSessionID(connection_id_type connection_id) const999 session_id_type CacheServer::GetSessionID(connection_id_type connection_id) const {
1000 return static_cast<session_id_type>(connection_id >> 32u);
1001 }
1002
CacheServer(const std::string & spill_path,int32_t num_workers,int32_t port,int32_t shared_meory_sz_in_gb,float memory_cap_ratio,int8_t log_level,std::shared_ptr<CacheServerHW> hw_info)1003 CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port,
1004 int32_t shared_meory_sz_in_gb, float memory_cap_ratio, int8_t log_level,
1005 std::shared_ptr<CacheServerHW> hw_info)
1006 : top_(spill_path),
1007 num_workers_(num_workers),
1008 num_grpc_workers_(num_workers_),
1009 port_(port),
1010 shared_memory_sz_in_gb_(shared_meory_sz_in_gb),
1011 global_shutdown_(false),
1012 memory_cap_ratio_(memory_cap_ratio),
1013 numa_affinity_(true),
1014 log_level_(log_level),
1015 hw_info_(std::move(hw_info)) {
1016 // If we are not linked with numa library (i.e. NUMA_ENABLED is false), turn off cpu
1017 // affinity which can make performance worse.
1018 if (!CacheServerHW::numa_enabled()) {
1019 numa_affinity_ = false;
1020 MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build "
1021 "that is compiled with numa support for more optimal performance";
1022 }
1023 // We create the shared memory and we will destroy it. All other client just detach only.
1024 if (shared_memory_sz_in_gb_ > kDefaultSharedMemorySize) {
1025 MS_LOG(INFO) << "Shared memory size is readjust to " << kDefaultSharedMemorySize << " GB.";
1026 shared_memory_sz_in_gb_ = kDefaultSharedMemorySize;
1027 }
1028 }
1029
Run(int msg_qid)1030 Status CacheServer::Run(int msg_qid) {
1031 Status rc = ServiceStart();
1032 // If there is a message que, return the status now before we call join_all which will never return
1033 if (msg_qid != -1) {
1034 SharedMessage msg(msg_qid);
1035 RETURN_IF_NOT_OK(msg.SendStatus(rc));
1036 }
1037 if (rc.IsError()) {
1038 return rc;
1039 }
1040 // This is called by the main function and we shouldn't exit. Otherwise the main thread
1041 // will just shutdown. So we will call some function that never return unless error.
1042 // One good case will be simply to wait for all threads to return.
1043 // note that after we have sent the initial status using the msg_qid, parent process will exit and
1044 // remove it. So we can't use it again.
1045 RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking));
1046 // Shutdown the grpc queue. No longer accept any new comer.
1047 comm_layer_->Shutdown();
1048 // The next thing to do drop all the caches.
1049 RETURN_IF_NOT_OK(ServiceStop());
1050 return Status::OK();
1051 }
1052
GetFreeRequestTag(CacheServerRequest ** q)1053 Status CacheServer::GetFreeRequestTag(CacheServerRequest **q) {
1054 RETURN_UNEXPECTED_IF_NULL(q);
1055 auto *p = new (std::nothrow) CacheServerRequest();
1056 if (p == nullptr) {
1057 RETURN_STATUS_OOM("Out of memory.");
1058 }
1059 *q = p;
1060 return Status::OK();
1061 }
1062
ReturnRequestTag(CacheServerRequest * p)1063 Status CacheServer::ReturnRequestTag(CacheServerRequest *p) {
1064 RETURN_UNEXPECTED_IF_NULL(p);
1065 delete p;
1066 return Status::OK();
1067 }
1068
DestroySession(CacheRequest * rq)1069 Status CacheServer::DestroySession(CacheRequest *rq) {
1070 CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id");
1071 auto drop_session_id = rq->connection_info().session_id();
1072 // Grab the locks in the correct order to avoid deadlock.
1073 UniqueLock sess_lck(&sessions_lock_);
1074 UniqueLock lck(&rwLock_);
1075 // Iterate over the set of connection id's for this session that we're dropping and erase each one.
1076 bool found = false;
1077 for (auto it = all_caches_.begin(); it != all_caches_.end();) {
1078 auto connection_id = it->first;
1079 auto session_id = GetSessionID(connection_id);
1080 // We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock.
1081 // So we will just manually do it.
1082 if (session_id == drop_session_id) {
1083 found = true;
1084 it = all_caches_.erase(it);
1085 MS_LOG(INFO) << "Destroy cache with id " << connection_id;
1086 } else {
1087 ++it;
1088 }
1089 }
1090 // Finally remove the session itself
1091 auto n = active_sessions_.erase(drop_session_id);
1092 if (n > 0) {
1093 MS_LOG(INFO) << "Session destroyed with id " << drop_session_id;
1094 return Status::OK();
1095 } else {
1096 if (found) {
1097 std::string errMsg =
1098 "A destroy cache request has been completed but it had a stale session id " + std::to_string(drop_session_id);
1099 RETURN_STATUS_UNEXPECTED(errMsg);
1100 } else {
1101 std::string errMsg =
1102 "Session id " + std::to_string(drop_session_id) + " not found in server on port " + std::to_string(port_) + ".";
1103 RETURN_STATUS_ERROR(StatusCode::kMDFileNotExist, errMsg);
1104 }
1105 }
1106 }
1107
GenerateSessionID()1108 session_id_type CacheServer::GenerateSessionID() {
1109 UniqueLock sess_lck(&sessions_lock_);
1110 auto mt = GetRandomDevice();
1111 std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max());
1112 session_id_type session_id;
1113 bool duplicate = false;
1114 do {
1115 session_id = distribution(mt);
1116 auto r = active_sessions_.insert(session_id);
1117 duplicate = !r.second;
1118 } while (duplicate);
1119 return session_id;
1120 }
1121
AllocateSharedMemory(CacheRequest * rq,CacheReply * reply)1122 Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) {
1123 auto client_id = rq->client_id();
1124 CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
1125 CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing requested size");
1126 try {
1127 auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, kDecimal);
1128 void *p = nullptr;
1129 RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, requestedSz, &p));
1130 auto *base = SharedMemoryBaseAddr();
1131 // We can't return the absolute address which makes no sense to the client.
1132 // Instead we return the difference.
1133 auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
1134 reply->set_result(std::to_string(difference));
1135 } catch (const std::exception &e) {
1136 RETURN_STATUS_UNEXPECTED(e.what());
1137 }
1138 return Status::OK();
1139 }
1140
FreeSharedMemory(CacheRequest * rq)1141 Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
1142 auto client_id = rq->client_id();
1143 CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
1144 CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing addr");
1145 auto *base = SharedMemoryBaseAddr();
1146 try {
1147 auto addr = strtoll(rq->buf_data(0).data(), nullptr, kDecimal);
1148 auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
1149 DeallocateSharedMemory(client_id, p);
1150 } catch (const std::exception &e) {
1151 RETURN_STATUS_UNEXPECTED(e.what());
1152 }
1153 return Status::OK();
1154 }
1155
GetCacheState(CacheRequest * rq,CacheReply * reply)1156 Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) {
1157 auto connection_id = rq->connection_id();
1158 SharedLock lck(&rwLock_);
1159 CacheService *cs = GetService(connection_id);
1160 if (cs == nullptr) {
1161 std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
1162 RETURN_STATUS_UNEXPECTED(errMsg);
1163 } else {
1164 auto state = cs->GetState();
1165 reply->set_result(std::to_string(static_cast<int8_t>(state)));
1166 return Status::OK();
1167 }
1168 }
1169
RpcRequest(worker_id_t worker_id)1170 Status CacheServer::RpcRequest(worker_id_t worker_id) {
1171 TaskManager::FindMe()->Post();
1172 RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id));
1173 return Status::OK();
1174 }
1175
AcknowledgeShutdown(CacheServerRequest * cache_req)1176 Status CacheServer::AcknowledgeShutdown(CacheServerRequest *cache_req) {
1177 auto *rq = &cache_req->rq_;
1178 auto *reply = &cache_req->reply_;
1179 if (!rq->buf_data().empty()) {
1180 // cache_admin sends us a message qID and we will destroy the
1181 // queue in our destructor and this will wake up cache_admin.
1182 // But we don't want the cache_admin blindly just block itself.
1183 // So we will send back an ack before shutdown the comm layer.
1184 try {
1185 int32_t qID = std::stoi(rq->buf_data(0));
1186 shutdown_qIDs_.push_back(qID);
1187 } catch (const std::exception &e) {
1188 // ignore it.
1189 }
1190 }
1191 reply->set_result("OK");
1192 return Status::OK();
1193 }
1194
GlobalShutdown()1195 void CacheServer::GlobalShutdown() {
1196 // Let's shutdown in proper order.
1197 bool expected = false;
1198 if (global_shutdown_.compare_exchange_strong(expected, true)) {
1199 MS_LOG(WARNING) << "Shutting down server.";
1200 // Interrupt all the threads and queues. We will leave the shutdown
1201 // of the comm layer after we have joined all the threads and will
1202 // be done by the master thread.
1203 vg_.interrupt_all();
1204 }
1205 }
1206
GetWorkerByNumaId(numa_id_t numa_id) const1207 worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const {
1208 auto num_numa_nodes = GetNumaNodeCount();
1209 MS_ASSERT(numa_id < num_numa_nodes);
1210 auto num_workers_per_node = GetNumWorkers() / num_numa_nodes;
1211 std::mt19937 gen = GetRandomDevice();
1212 std::uniform_int_distribution<worker_id_t> dist(0, num_workers_per_node - 1);
1213 auto n = dist(gen);
1214 worker_id_t worker_id = n * num_numa_nodes + numa_id;
1215 MS_ASSERT(worker_id < GetNumWorkers());
1216 return worker_id;
1217 }
1218
GetRandomWorker() const1219 worker_id_t CacheServer::GetRandomWorker() const {
1220 std::mt19937 gen = GetRandomDevice();
1221 std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1);
1222 return dist(gen);
1223 }
1224
AllocateSharedMemory(int32_t client_id,size_t sz,void ** p)1225 Status CacheServer::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
1226 return shm_->AllocateSharedMemory(client_id, sz, p);
1227 }
1228
DeallocateSharedMemory(int32_t client_id,void * p)1229 void CacheServer::DeallocateSharedMemory(int32_t client_id, void *p) { shm_->DeallocateSharedMemory(client_id, p); }
1230
IpcResourceCleanup()1231 Status CacheServer::Builder::IpcResourceCleanup() {
1232 Status rc;
1233 SharedMemory::shm_key_t shm_key;
1234 auto unix_socket = PortToUnixSocketPath(port_);
1235 rc = PortToFtok(port_, &shm_key);
1236 // We are expecting the unix path doesn't exist.
1237 if (rc.IsError()) {
1238 return Status::OK();
1239 }
1240 // Attach to the shared memory which we expect don't exist
1241 SharedMemory mem(shm_key);
1242 rc = mem.Attach();
1243 if (rc.IsError()) {
1244 return Status::OK();
1245 } else {
1246 RETURN_IF_NOT_OK(mem.Detach());
1247 }
1248 int32_t num_attached;
1249 RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached));
1250 if (num_attached == 0) {
1251 // Stale shared memory from last time.
1252 // Remove both the memory and the socket path
1253 RETURN_IF_NOT_OK(mem.Destroy());
1254 Path p(unix_socket);
1255 (void)p.Remove();
1256 } else {
1257 // Server is already up.
1258 std::string errMsg = "Cache server is already up and running";
1259 // We return a duplicate error. The main() will intercept
1260 // and output a proper message
1261 RETURN_STATUS_ERROR(StatusCode::kMDDuplicateKey, errMsg);
1262 }
1263 return Status::OK();
1264 }
1265
SanityCheck()1266 Status CacheServer::Builder::SanityCheck() {
1267 if (shared_memory_sz_in_gb_ <= 0) {
1268 RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive");
1269 }
1270 if (num_workers_ <= 0) {
1271 RETURN_STATUS_UNEXPECTED("Number of parallel workers must be positive");
1272 }
1273 if (!top_.empty()) {
1274 auto p = top_.data();
1275 if (p[0] != '/') {
1276 RETURN_STATUS_UNEXPECTED("Spilling directory must be an absolute path");
1277 }
1278 // Check if the spill directory is writable
1279 Path spill(top_);
1280 auto t = spill / Services::GetUniqueID();
1281 Status rc = t.CreateDirectory();
1282 if (rc.IsOk()) {
1283 rc = t.Remove();
1284 }
1285 if (rc.IsError()) {
1286 RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString());
1287 }
1288 }
1289 if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) {
1290 RETURN_STATUS_UNEXPECTED("Memory cap ratio should be positive and no greater than 1");
1291 }
1292
1293 // Check if the shared memory.
1294 RETURN_IF_NOT_OK(IpcResourceCleanup());
1295 return Status::OK();
1296 }
1297
AdjustNumWorkers(int32_t num_workers)1298 int32_t CacheServer::Builder::AdjustNumWorkers(int32_t num_workers) {
1299 int32_t num_numa_nodes = hw_info_->GetNumaNodeCount();
1300 // Bump up num_workers_ to at least the number of numa nodes
1301 num_workers = std::max(num_numa_nodes, num_workers);
1302 // But also it shouldn't be too many more than the hardware concurrency
1303 int32_t num_cpus = hw_info_->GetCpuCount();
1304 constexpr int32_t kThreadsPerCore = 2;
1305 num_workers = std::min(kThreadsPerCore * num_cpus, num_workers);
1306 // Round up num_workers to a multiple of numa nodes.
1307 auto remainder = num_workers % num_numa_nodes;
1308 if (remainder > 0) {
1309 num_workers += (num_numa_nodes - remainder);
1310 }
1311 return num_workers;
1312 }
1313
Builder()1314 CacheServer::Builder::Builder()
1315 : top_(""),
1316 num_workers_(kDefaultNumWorkers),
1317 port_(kCfgDefaultCachePort),
1318 shared_memory_sz_in_gb_(kDefaultSharedMemorySize),
1319 memory_cap_ratio_(kDefaultMemoryCapRatio),
1320 log_level_(kDefaultLogLevel) {
1321 if (num_workers_ == 0) {
1322 num_workers_ = 1;
1323 }
1324 }
1325 } // namespace dataset
1326 } // namespace mindspore
1327