/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_server.h"
#include <algorithm>
#include <functional>
#include <limits>
#include <vector>
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/util/bit.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/random.h"
#ifdef CACHE_LOCAL_CLIENT
#include "minddata/dataset/util/sig_handler.h"
#endif

namespace mindspore {
namespace dataset {
CacheServer *CacheServer::instance_ = nullptr;
std::once_flag CacheServer::init_instance_flag_;
Status CacheServer::DoServiceStart() {
#ifdef CACHE_LOCAL_CLIENT
  // We need to destroy the shared memory if the user hits Control-C.
  RegisterHandlers();
#endif
  if (!top_.empty()) {
    Path spill(top_);
    RETURN_IF_NOT_OK(spill.CreateDirectories());
    MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
  }
  RETURN_IF_NOT_OK(vg_.ServiceStart());
  auto num_numa_nodes = GetNumaNodeCount();
  // If we are linked with the numa library, set the default memory policy.
  // If we don't pin threads to cpus, then use all the memory controllers to maximize
  // memory bandwidth.
  RETURN_IF_NOT_OK(
    CacheServerHW::SetDefaultMemoryPolicy(numa_affinity_ ? CachePoolPolicy::kLocal : CachePoolPolicy::kInterleave));
  auto my_node = hw_info_->GetMyNode();
  MS_LOG(DEBUG) << "Cache server is running on numa node " << my_node;
  // There will be some threads working on the grpc queue and
  // some number of threads working on the CacheServerRequest queue.
  // Like a connector object we will set up the same number of queues but
  // we do not need to preserve any order. We will set the capacity of
  // each queue to be 64 since we are just pushing memory pointers which
  // are only 8 bytes each.
  const int32_t kQueCapacity = 64;
  // This is the request queue from the client
  cache_q_ = std::make_shared<QueueList<CacheServerRequest *>>();
  cache_q_->Init(num_workers_, kQueCapacity);
  // We will match the number of grpc workers with the number of server workers.
  // But technically they don't have to be the same.
  num_grpc_workers_ = num_workers_;
  MS_LOG(DEBUG) << "Number of grpc workers is set to " << num_grpc_workers_;
  RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
  // Start the comm layer
  try {
    comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_);
    RETURN_IF_NOT_OK(comm_layer_->Run());
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
#ifdef CACHE_LOCAL_CLIENT
  RETURN_IF_NOT_OK(CachedSharedMemory::CreateArena(&shm_, port_, shared_memory_sz_in_gb_));
  // Bring up a thread to monitor the unix socket in case it is removed. But it must be done
  // after we have created the unix socket.
  auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
  RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
#endif
  // Spawn a few threads to serve the real requests.
  auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
  for (auto i = 0; i < num_workers_; ++i) {
    Task *pTask;
    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i), &pTask));
    // Save a copy of the pointer to the underlying Task object. We may dynamically change their affinity if needed.
    numa_tasks_.emplace(i, pTask);
    // Spread out all the threads to all the numa nodes if needed
    if (IsNumaAffinityOn()) {
      auto numa_id = i % num_numa_nodes;
      RETURN_IF_NOT_OK(SetAffinity(*pTask, numa_id));
    }
  }
  // Finally loop forever to handle the requests.
  auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1);
  for (auto i = 0; i < num_grpc_workers_; ++i) {
    Task *pTask;
    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i), &pTask));
    // All these grpc workers will be allocated to the same node, which is where we allocate
    // all the free-tag memory.
    if (IsNumaAffinityOn()) {
      RETURN_IF_NOT_OK(SetAffinity(*pTask, i % num_numa_nodes));
    }
  }
  return Status::OK();
}

Status CacheServer::DoServiceStop() {
  Status rc;
  Status rc2;
  // First stop all the threads.
  RETURN_IF_NOT_OK(vg_.ServiceStop());
  // Clean up all the caches if any.
  UniqueLock lck(&rwLock_);
  auto it = all_caches_.begin();
  while (it != all_caches_.end()) {
    auto cs = std::move(it->second);
    rc2 = cs->ServiceStop();
    if (rc2.IsError()) {
      rc = rc2;
    }
    ++it;
  }
  // Also remove the path we use to generate ftok.
  Path p(PortToUnixSocketPath(port_));
  (void)p.Remove();
  // Finally wake up cache_admin if it is waiting.
  for (int32_t qID : shutdown_qIDs_) {
    SharedMessage msg(qID);
    RETURN_IF_NOT_OK(msg.SendStatus(Status::OK()));
    msg.RemoveResourcesOnExit();
    // Let msg go out of scope, which will destroy the queue.
  }
  return rc;
}

CacheService *CacheServer::GetService(connection_id_type id) const {
  auto it = all_caches_.find(id);
  if (it != all_caches_.end()) {
    return it->second.get();
  }
  return nullptr;
}

// We would like to protect ourselves from over-allocating. We will go over the existing caches
// and calculate how much memory we have consumed so far.
Status CacheServer::GlobalMemoryCheck(uint64_t cache_mem_sz) {
  auto end = all_caches_.end();
  auto it = all_caches_.begin();
  auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_;
  int64_t max_avail = avail_mem;
  while (it != end) {
    auto &cs = it->second;
    CacheService::ServiceStat stat;
    RETURN_IF_NOT_OK(cs->GetStat(&stat));
    int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz;
    max_avail -= mem_consumed;
    if (max_avail <= 0) {
      return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
    }
    ++it;
  }

  // If we have some cache using some memory already, make a reasonable decision if we should return
  // out of memory.
  if (max_avail < avail_mem) {
    int64_t req_mem = cache_mem_sz * 1048576L;  // cache_mem_sz is in MB units.
    if (req_mem > max_avail) {
      return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
    } else if (req_mem == 0) {
      // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than
      // 85% of our limit, fail this request.
      if (static_cast<float>(max_avail) / static_cast<float>(avail_mem) <= kMemoryBottomLineForNewService) {
        return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions");
      }
    }
  }
  return Status::OK();
}

Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) {
  CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info");
  std::string cookie;
  int32_t client_id;
  auto session_id = rq->connection_info().session_id();
  auto crc = rq->connection_info().crc();

  // Before allowing the creation, make sure the session had already been created by the user.
  // Our intention is to add this cache to the active sessions list, so leave the list locked during
  // this entire function.
  UniqueLock sess_lck(&sessions_lock_);
  auto session_it = active_sessions_.find(session_id);
  if (session_it == active_sessions_.end()) {
    RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!");
  }

  // We concat both numbers to form the internal connection id.
  auto connection_id = GetConnectionID(session_id, crc);
  CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing info to create cache");
  auto &create_cache_buf = rq->buf_data(0);
  auto p = flatbuffers::GetRoot<CreateCacheRequestMsg>(create_cache_buf.data());
  auto flag = static_cast<CreateCacheRequest::CreateCacheFlag>(p->flag());
  auto cache_mem_sz = p->cache_mem_sz();
  // We can't do spilling unless this server is set up with a spill path in the first place.
  bool spill =
    (flag & CreateCacheRequest::CreateCacheFlag::kSpillToDisk) == CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
  bool generate_id =
    (flag & CreateCacheRequest::CreateCacheFlag::kGenerateRowId) == CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
  if (spill && top_.empty()) {
    RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
  }
  // Before creating the cache, first check if this is a request for shared usage of an existing cache.
  // If two CreateService requests come in with an identical connection_id, we need to serialize the creation.
  // The first create will be successful and be given a special cookie.
  UniqueLock lck(&rwLock_);
  bool duplicate = false;
  CacheService *curr_cs = GetService(connection_id);
  if (curr_cs != nullptr) {
    duplicate = true;
    client_id = curr_cs->num_clients_.fetch_add(1);
    MS_LOG(INFO) << "Duplicate request from client " + std::to_string(client_id) + " for " +
                      std::to_string(connection_id) + " to create cache service";
  }
  // Early exit if we are doing global shutdown
  if (global_shutdown_) {
    return Status::OK();
  }

  if (!duplicate) {
    RETURN_IF_NOT_OK(GlobalMemoryCheck(cache_mem_sz));
    std::unique_ptr<CacheService> cs;
    try {
      cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
      RETURN_IF_NOT_OK(cs->ServiceStart());
      cookie = cs->cookie();
      client_id = cs->num_clients_.fetch_add(1);
      all_caches_.emplace(connection_id, std::move(cs));
    } catch (const std::bad_alloc &e) {
      return Status(StatusCode::kMDOutOfMemory);
    }
  }

  // Shuffle the worker threads. But we need to release the locks or we will deadlock when calling
  // the following function.
  lck.Unlock();
  sess_lck.Unlock();
  auto numa_id = client_id % GetNumaNodeCount();
  std::vector<cpu_id_t> cpu_list = hw_info_->GetCpuList(numa_id);
  // Send back the data
  flatbuffers::FlatBufferBuilder fbb;
  flatbuffers::Offset<flatbuffers::String> off_cookie;
  flatbuffers::Offset<flatbuffers::Vector<cpu_id_t>> off_cpu_list;
  off_cookie = fbb.CreateString(cookie);
  off_cpu_list = fbb.CreateVector(cpu_list);
  CreateCacheReplyMsgBuilder bld(fbb);
  bld.add_connection_id(connection_id);
  bld.add_cookie(off_cookie);
  bld.add_client_id(client_id);
  // The last thing we send back is a set of cpu ids that we suggest the client should bind itself to.
  bld.add_cpu_id(off_cpu_list);
  auto off = bld.Finish();
  fbb.Finish(off);
  reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
  // We can return OK, but we will return a duplicate key status so the user can either ignore it
  // or treat it as OK.
  return duplicate ? Status(StatusCode::kMDDuplicateKey) : Status::OK();
}

Status CacheServer::DestroyCache(CacheRequest *rq) {
  // We need a strong lock to protect the map.
  UniqueLock lck(&rwLock_);
  auto id = rq->connection_id();
  CacheService *cs = GetService(id);
  // If it is already destroyed, ignore it.
  if (cs != nullptr) {
    MS_LOG(WARNING) << "Dropping cache with connection id " << std::to_string(id);
    // std::map will invoke the destructor of CacheService. So we don't need to do anything here.
    auto n = all_caches_.erase(id);
    if (n == 0) {
      // It has been destroyed by another duplicate request.
      MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to destroy cache service";
    }
  }
  // We aren't touching the session list even though we may be dropping the last remaining cache of a session.
  // Leave that to be done by the drop session command.
  return Status::OK();
}

Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    auto sz = rq->buf_data_size();
    std::vector<const void *> buffers;
    buffers.reserve(sz);
    // First piece of data is the cookie and is required.
    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
    auto &cookie = rq->buf_data(0);
    // Only if the cookie matches can we accept inserts into a cache that has a build phase.
    if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
      // Push the address of each buffer (in the form of std::string coming in from protobuf) into
      // a vector of buffers.
      for (auto i = 1; i < sz; ++i) {
        buffers.push_back(rq->buf_data(i).data());
      }
      row_id_type id = -1;
      // We will allocate the memory on the same numa node this thread is bound to.
      RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id));
      reply->set_result(std::to_string(id));
    } else {
      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
    }
  }
  return Status::OK();
}

Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  auto client_id = rq->client_id();
  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  auto *base = SharedMemoryBaseAddr();
  // Ensure we got 3 pieces of data coming in
  constexpr int32_t kMinBufDataSize = 3;
  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= kMinBufDataSize, "Incomplete data");
  // First one is cookie, followed by data address and then size.
  enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
  // First piece of data is the cookie and is required
  auto &cookie = rq->buf_data(BufDataIndex::kCookie);
  // Second piece of data is the address where we can find the serialized data
  auto addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
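  // Note that addr is an offset relative to the shared memory base rather than an absolute
  // pointer, since the client and the server generally map the arena at different virtual
  // addresses. (See AllocateSharedMemory below, where we hand the difference back to the client.)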
  auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
  // Third piece of data is the size of the serialized data that we need to transfer
  auto sz = strtoll(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal);
  // Successful or not, we need to free the memory on exit.
  Status rc;
  if (cs == nullptr) {
    std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
    rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    // Only if the cookie matches can we accept inserts into a cache that has a build phase.
    if (!cs->HasBuildPhase() || cookie == cs->cookie()) {
      row_id_type id = -1;
      ReadableSlice src(p, sz);
      // We will allocate the memory on the same numa node this thread is bound to.
      rc = cs->FastCacheRow(src, &id);
      reply->set_result(std::to_string(id));
    } else {
      auto state = cs->GetState();
      if (state != CacheServiceState::kFetchPhase) {
        rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                    "Cache service is not in fetch phase. The current phase is " +
                      std::to_string(static_cast<int8_t>(state)) + ". Client id: " + std::to_string(client_id));
      } else {
        rc = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                    "Cookie mismatch. Client id: " + std::to_string(client_id));
      }
    }
  }
  // Return the block to the shared memory only if it is not an internal request.
  if (static_cast<BaseRequest::RequestType>(rq->type()) == BaseRequest::RequestType::kCacheRow) {
    DeallocateSharedMemory(client_id, p);
  }
  return rc;
}

Status CacheServer::InternalCacheRow(CacheRequest *rq, CacheReply *reply) {
  // Look into the flag to see where we can find the data and call the appropriate method.
  auto flag = rq->flag();
  Status rc;
  if (BitTest(flag, kDataIsInSharedMemory)) {
    rc = FastCacheRow(rq, reply);
    // This is an internal request and is not tied to rpc. But we need to post because there
    // is a thread waiting on the completion of this request.
    try {
      constexpr int32_t kBatchWaitIdx = 3;
      // Fourth piece of the data is the address of the BatchWait ptr.
      int64_t addr = strtol(rq->buf_data(kBatchWaitIdx).data(), nullptr, kDecimal);
      auto *bw = reinterpret_cast<BatchWait *>(addr);
      // Check if the object is still around.
      auto bwObj = bw->GetBatchWait();
      if (bwObj.lock()) {
        RETURN_IF_NOT_OK(bw->Set(rc));
      }
    } catch (const std::exception &e) {
      RETURN_STATUS_UNEXPECTED(e.what());
    }
  } else {
    rc = CacheRow(rq, reply);
  }
  return rc;
}

Status CacheServer::InternalFetchRow(CacheRequest *rq) {
  auto connection_id = rq->connection_id();
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  Status rc;
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  }
  // First piece is a flatbuffer containing row fetch information, second piece is the address of the BatchWait ptr.
  enum BufDataIndex : uint8_t { kFetchRowMsg = 0, kBatchWait = 1 };
  rc = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq->buf_data(BufDataIndex::kFetchRowMsg).data()));
  // This is an internal request and is not tied to rpc. But we need to post because there
  // is a thread waiting on the completion of this request.
  try {
    int64_t addr = strtol(rq->buf_data(BufDataIndex::kBatchWait).data(), nullptr, kDecimal);
    auto *bw = reinterpret_cast<BatchWait *>(addr);
    // Check if the object is still around.
    auto bwObj = bw->GetBatchWait();
    if (bwObj.lock()) {
      RETURN_IF_NOT_OK(bw->Set(rc));
    }
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return rc;
}

Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
  const auto num_elements = p->rows()->size();
  auto connection_id = p->connection_id();
  auto batch_wait = std::make_shared<BatchWait>(num_elements);
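  // The output buffer begins with an array of (num_elements + 1) int64 offsets, followed by the
  // row data itself. Entry i is the offset where row i's data starts and the final entry marks
  // the end of the used region, so consecutive entries bound each row.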
  int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
  auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
  offset_array[0] = data_offset;
  for (uint32_t i = 0; i < num_elements; ++i) {
    auto data_locator = p->rows()->Get(i);
    auto node_id = data_locator->node_id();
    size_t sz = data_locator->size();
    void *source_addr = reinterpret_cast<void *>(data_locator->addr());
    auto key = data_locator->key();
    // Please read the comment in CacheServer::BatchFetchRows where we allocate the buffer
    // big enough so each thread (which we are going to dispatch) will not run into the
    // false sharing problem. We are going to round sz up to 4K.
    auto sz_4k = round_up_4K(sz);
    offset_array[i + 1] = offset_array[i] + sz_4k;
    if (sz > 0) {
      WritableSlice row_data(*out, offset_array[i], sz);
      // Get a request and send to the proper worker (at some numa node) to do the fetch.
      worker_id_t worker_id = IsNumaAffinityOn() ? GetWorkerByNumaId(node_id) : GetRandomWorker();
      CacheServerRequest *cache_rq;
      RETURN_IF_NOT_OK(GetFreeRequestTag(&cache_rq));
      // Set up all the necessary fields.
      cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
      cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
      cache_rq->rq_.set_connection_id(connection_id);
      cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
      auto dest_addr = row_data.GetMutablePointer();
      flatbuffers::FlatBufferBuilder fb2;
      FetchRowMsgBuilder bld(fb2);
      bld.add_key(key);
      bld.add_size(sz);
      bld.add_source_addr(reinterpret_cast<int64_t>(source_addr));
      bld.add_dest_addr(reinterpret_cast<int64_t>(dest_addr));
      auto offset = bld.Finish();
      fb2.Finish(offset);
      cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
      cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(batch_wait.get())));
      RETURN_IF_NOT_OK(PushRequest(worker_id, cache_rq));
    } else {
      // Nothing to fetch, but we still need to post something back into the wait area.
      RETURN_IF_NOT_OK(batch_wait->Set(Status::OK()));
    }
  }
  // Now wait for all of them to come back.
  RETURN_IF_NOT_OK(batch_wait->Wait());
  // Return the result.
  return batch_wait->GetRc();
}

Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  auto client_id = rq->client_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Cache id " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing row id");
    auto &row_id_buf = rq->buf_data(0);
    auto p = flatbuffers::GetRoot<TensorRowIds>(row_id_buf.data());
    std::vector<row_id_type> row_id;
    auto sz = p->row_id()->size();
    row_id.reserve(sz);
    for (uint32_t i = 0; i < sz; ++i) {
      row_id.push_back(p->row_id()->Get(i));
    }
    std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
    RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb));
    // Let go of the shared lock. We don't need to interact with the CacheService anymore.
    // We shouldn't be holding any lock while we may wait a long time for the rows to come back.
    lck.Unlock();
    auto locator = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
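    // Reserve room for the leading offset array (sz + 1 entries); see BatchFetch above for the
    // layout of the reply buffer.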
    int64_t mem_sz = sizeof(int64_t) * (sz + 1);
    for (uint32_t i = 0; i < sz; ++i) {
      auto row_sz = locator->rows()->Get(i)->size();
      // row_sz is the size of the cached data. Later we will spawn multiple threads,
      // each of which will copy the data concurrently into either shared memory or protobuf,
      // but to a different region.
      // To avoid false sharing, we will bump up row_sz to be a multiple of 4k, i.e. 4096 bytes.
      row_sz = round_up_4K(row_sz);
      mem_sz += row_sz;
    }
    auto client_flag = rq->flag();
    bool local_client = BitTest(client_flag, kLocalClientSupport);
    // For a large amount of data to be sent back, we will use shared memory, provided it is a
    // local client that has local bypass support.
    bool local_bypass = local_client ? (mem_sz >= kLocalByPassThreshold) : false;
    reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0);
    if (local_bypass) {
      // We will use shared memory
      auto *base = SharedMemoryBaseAddr();
      void *q = nullptr;
      RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, mem_sz, &q));
      WritableSlice dest(q, mem_sz);
      Status rc = BatchFetch(fbb, &dest);
      if (rc.IsError()) {
        DeallocateSharedMemory(client_id, q);
        return rc;
      }
      // We can't return the absolute address, which would make no sense to the client.
      // Instead we return the difference.
      auto difference = reinterpret_cast<int64_t>(q) - reinterpret_cast<int64_t>(base);
      reply->set_result(std::to_string(difference));
    } else {
      // We are going to use std::string to allocate and hold the result, which will eventually
      // be 'moved' to the protobuf message (which underneath is also a std::string), in order
      // to minimize memory copying.
      std::string mem;
      try {
        mem.resize(mem_sz);
        CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= mem_sz, "Programming error");
      } catch (const std::bad_alloc &e) {
        return Status(StatusCode::kMDOutOfMemory);
      }
      WritableSlice dest(mem.data(), mem_sz);
      RETURN_IF_NOT_OK(BatchFetch(fbb, &dest));
      reply->set_result(std::move(mem));
    }
  }
  return Status::OK();
}

Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    CacheService::ServiceStat svc_stat;
    RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
    flatbuffers::FlatBufferBuilder fbb;
    ServiceStatMsgBuilder bld(fbb);
    bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
    bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
    bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz);
    bld.add_num_numa_hit(svc_stat.stat_.num_numa_hit);
    bld.add_max_row_id(svc_stat.stat_.max_key);
    bld.add_min_row_id(svc_stat.stat_.min_key);
    bld.add_state(svc_stat.state_);
    auto offset = bld.Finish();
    fbb.Finish(offset);
    reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
  }
  return Status::OK();
}

Status CacheServer::CacheSchema(CacheRequest *rq) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing schema information");
    auto &create_schema_buf = rq->buf_data(0);
    RETURN_IF_NOT_OK(cs->CacheSchema(create_schema_buf.data(), create_schema_buf.size()));
  }
  return Status::OK();
}

Status CacheServer::FetchSchema(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    // We are going to use std::string to allocate and hold the result, which will eventually
    // be 'moved' to the protobuf message (which underneath is also a std::string), in order
    // to minimize memory copying.
    std::string mem;
    RETURN_IF_NOT_OK(cs->FetchSchema(&mem));
    reply->set_result(std::move(mem));
  }
  return Status::OK();
}

Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    // First piece of data is the cookie.
    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
    auto &cookie = rq->buf_data(0);
    // We can only allow the phase to switch if the cookie matches.
    if (cookie == cs->cookie()) {
      RETURN_IF_NOT_OK(cs->BuildPhaseDone());
    } else {
      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
    }
  }
  return Status::OK();
}

Status CacheServer::GetCacheMissKeys(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    std::vector<row_id_type> gap;
    RETURN_IF_NOT_OK(cs->FindKeysMiss(&gap));
    flatbuffers::FlatBufferBuilder fbb;
    auto off_t = fbb.CreateVector(gap);
    TensorRowIdsBuilder bld(fbb);
    bld.add_row_id(off_t);
    auto off = bld.Finish();
    fbb.Finish(off);
    reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
  }
  return Status::OK();
}

inline Status GenerateClientSessionID(session_id_type session_id, CacheReply *reply) {
  reply->set_result(std::to_string(session_id));
  MS_LOG(INFO) << "Server generated new session id " << session_id;
  return Status::OK();
}

Status CacheServer::ToggleWriteMode(CacheRequest *rq) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    // First piece of data is the on/off flag
    CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing action flag");
    const auto &action = rq->buf_data(0);
    bool on_off = false;
    if (strcmp(action.data(), "on") == 0) {
      on_off = true;
    } else if (strcmp(action.data(), "off") == 0) {
      on_off = false;
    } else {
      RETURN_STATUS_UNEXPECTED("Unknown request: " + action);
    }
    RETURN_IF_NOT_OK(cs->ToggleWriteMode(on_off));
  }
  return Status::OK();
}

Status CacheServer::ListSessions(CacheReply *reply) {
  SharedLock sess_lck(&sessions_lock_);
  SharedLock lck(&rwLock_);
  flatbuffers::FlatBufferBuilder fbb;
  std::vector<flatbuffers::Offset<ListSessionMsg>> session_msgs_vector;
  for (auto const &current_session_id : active_sessions_) {
    bool found = false;
    for (auto const &it : all_caches_) {
      auto current_conn_id = it.first;
      if (GetSessionID(current_conn_id) == current_session_id) {
        found = true;
        auto &cs = it.second;
        CacheService::ServiceStat svc_stat;
        RETURN_IF_NOT_OK(cs->GetStat(&svc_stat));
        auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached,
                                                  svc_stat.stat_.average_cache_sz, svc_stat.stat_.num_numa_hit,
                                                  svc_stat.stat_.min_key, svc_stat.stat_.max_key, svc_stat.state_);
        auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats);
        session_msgs_vector.push_back(current_session_info);
      }
    }
    if (!found) {
      // If there is no cache created yet, assign a connection id of 0 along with empty stats
      auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0);
      auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats);
      session_msgs_vector.push_back(current_session_info);
    }
  }
  flatbuffers::Offset<flatbuffers::String> spill_dir;
  spill_dir = fbb.CreateString(top_);
  auto session_msgs = fbb.CreateVector(session_msgs_vector);
  ListSessionsMsgBuilder s_builder(fbb);
  s_builder.add_sessions(session_msgs);
  s_builder.add_num_workers(num_workers_);
  s_builder.add_log_level(log_level_);
  s_builder.add_spill_dir(spill_dir);
  auto offset = s_builder.Finish();
  fbb.Finish(offset);
  reply->set_result(fbb.GetBufferPointer(), fbb.GetSize());
  return Status::OK();
}

Status CacheServer::ConnectReset(CacheRequest *rq) {
  auto connection_id = rq->connection_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    auto client_id = rq->client_id();
    MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects";
    cs->num_clients_--;
  }
  return Status::OK();
}

Status CacheServer::BatchCacheRows(CacheRequest *rq) {
  // First one is cookie, followed by address and then size.
  enum BufDataIndex : uint8_t { kCookie = 0, kAddr = 1, kSize = 2 };
  constexpr int32_t kExpectedBufDataSize = 3;
  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == kExpectedBufDataSize, "Expect three pieces of data");
  try {
    auto &cookie = rq->buf_data(BufDataIndex::kCookie);
    auto connection_id = rq->connection_id();
    auto client_id = rq->client_id();
    int64_t offset_addr;
    int32_t num_elem;
    auto *base = SharedMemoryBaseAddr();
    offset_addr = strtoll(rq->buf_data(BufDataIndex::kAddr).data(), nullptr, kDecimal);
    auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
    num_elem = static_cast<int32_t>(strtol(rq->buf_data(BufDataIndex::kSize).data(), nullptr, kDecimal));
    auto batch_wait = std::make_shared<BatchWait>(num_elem);
    // Get a set of free requests and push them into the queues.
    for (auto i = 0; i < num_elem; ++i) {
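      // Each row sits in shared memory as a TensorRowHeaderMsg followed immediately by the data
      // of every column, so walking p past the header and all the column data yields the extent
      // (start and size) of this row.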
      auto start = reinterpret_cast<int64_t>(p);
      auto msg = GetTensorRowHeaderMsg(p);
      p += msg->size_of_this();
      for (auto k = 0; k < msg->column()->size(); ++k) {
        p += msg->data_sz()->Get(k);
      }
      CacheServerRequest *cache_rq;
      RETURN_IF_NOT_OK(GetFreeRequestTag(&cache_rq));
      // Fill in details.
      cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
      cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
      cache_rq->rq_.set_connection_id(connection_id);
      cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
      cache_rq->rq_.set_client_id(client_id);
      cache_rq->rq_.set_flag(kDataIsInSharedMemory);
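      // The buf_data layout below must match what InternalCacheRow/FastCacheRow expect:
      // cookie, offset from the shared memory base, size, and the address of the BatchWait object.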
      cache_rq->rq_.add_buf_data(cookie);
      cache_rq->rq_.add_buf_data(std::to_string(start - reinterpret_cast<int64_t>(base)));
      cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(p) - start));
      cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(batch_wait.get())));
      RETURN_IF_NOT_OK(PushRequest(GetRandomWorker(), cache_rq));
    }
    // Now wait for all of them to come back.
    RETURN_IF_NOT_OK(batch_wait->Wait());
    // Return the result.
    return batch_wait->GetRc();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

Status CacheServer::ProcessRowRequest(CacheServerRequest *cache_req, bool *internal_request) {
  auto &rq = cache_req->rq_;
  auto &reply = cache_req->reply_;
  switch (cache_req->type_) {
    case BaseRequest::RequestType::kCacheRow: {
      // Look into the flag to see where we can find the data and call the appropriate method.
      if (BitTest(rq.flag(), kDataIsInSharedMemory)) {
        cache_req->rc_ = FastCacheRow(&rq, &reply);
      } else {
        cache_req->rc_ = CacheRow(&rq, &reply);
      }
      break;
    }
    case BaseRequest::RequestType::kInternalCacheRow: {
      *internal_request = true;
      cache_req->rc_ = InternalCacheRow(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kBatchCacheRows: {
      cache_req->rc_ = BatchCacheRows(&rq);
      break;
    }
    case BaseRequest::RequestType::kBatchFetchRows: {
      cache_req->rc_ = BatchFetchRows(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kInternalFetchRow: {
      *internal_request = true;
      cache_req->rc_ = InternalFetchRow(&rq);
      break;
    }
    default:
      std::string errMsg("Internal error, request type is not row request: ");
      errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
      cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  }
  return Status::OK();
}

Status CacheServer::ProcessSessionRequest(CacheServerRequest *cache_req) {
  auto &rq = cache_req->rq_;
  auto &reply = cache_req->reply_;
  switch (cache_req->type_) {
    case BaseRequest::RequestType::kDropSession: {
      cache_req->rc_ = DestroySession(&rq);
      break;
    }
    case BaseRequest::RequestType::kGenerateSessionId: {
      cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
      break;
    }
    case BaseRequest::RequestType::kListSessions: {
      cache_req->rc_ = ListSessions(&reply);
      break;
    }
    default:
      std::string errMsg("Internal error, request type is not session request: ");
      errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
      cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  }
  return Status::OK();
}

Status CacheServer::ProcessAdminRequest(CacheServerRequest *cache_req) {
  auto &rq = cache_req->rq_;
  auto &reply = cache_req->reply_;
  switch (cache_req->type_) {
    case BaseRequest::RequestType::kCreateCache: {
      cache_req->rc_ = CreateService(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kGetCacheMissKeys: {
      cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kDestroyCache: {
      cache_req->rc_ = DestroyCache(&rq);
      break;
    }
    case BaseRequest::RequestType::kGetStat: {
      cache_req->rc_ = GetStat(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kCacheSchema: {
      cache_req->rc_ = CacheSchema(&rq);
      break;
    }
    case BaseRequest::RequestType::kFetchSchema: {
      cache_req->rc_ = FetchSchema(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kBuildPhaseDone: {
      cache_req->rc_ = BuildPhaseDone(&rq);
      break;
    }
    case BaseRequest::RequestType::kAllocateSharedBlock: {
      cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
      break;
    }
    case BaseRequest::RequestType::kFreeSharedBlock: {
      cache_req->rc_ = FreeSharedMemory(&rq);
      break;
    }
    case BaseRequest::RequestType::kStopService: {
      // This command shuts down everything.
      // But we first reply to the client that we received the request.
      // The real shutdown work will be done by the caller.
      cache_req->rc_ = AcknowledgeShutdown(cache_req);
      break;
    }
    case BaseRequest::RequestType::kHeartBeat: {
      cache_req->rc_ = Status::OK();
      break;
    }
    case BaseRequest::RequestType::kToggleWriteMode: {
      cache_req->rc_ = ToggleWriteMode(&rq);
      break;
    }
    case BaseRequest::RequestType::kConnectReset: {
      cache_req->rc_ = ConnectReset(&rq);
      break;
    }
    case BaseRequest::RequestType::kGetCacheState: {
      cache_req->rc_ = GetCacheState(&rq, &reply);
      break;
    }
    default:
      std::string errMsg("Internal error, request type is not admin request: ");
      errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
      cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  }
  return Status::OK();
}

Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
  bool internal_request = false;

  // Except for creating a new session, we expect cs is not null.
  if (cache_req->IsRowRequest()) {
    RETURN_IF_NOT_OK(ProcessRowRequest(cache_req, &internal_request));
  } else if (cache_req->IsSessionRequest()) {
    RETURN_IF_NOT_OK(ProcessSessionRequest(cache_req));
  } else if (cache_req->IsAdminRequest()) {
    RETURN_IF_NOT_OK(ProcessAdminRequest(cache_req));
  } else {
    std::string errMsg("Unknown request type : ");
    errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
    cache_req->rc_ = Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  }

  // Notify it is done, and move on to the next request.
  Status2CacheReply(cache_req->rc_, &cache_req->reply_);
  cache_req->st_ = CacheServerRequest::STATE::FINISH;
  // We will re-tag the request back to the grpc queue. Once it comes back from the client,
  // the CacheServerRequest, i.e. the pointer cache_req, will be freed.
  if (!internal_request && !global_shutdown_) {
    cache_req->responder_.Finish(cache_req->reply_, grpc::Status::OK, cache_req);
  } else {
    // We can free up the request now.
    RETURN_IF_NOT_OK(ReturnRequestTag(cache_req));
  }
  return Status::OK();
}

/// \brief This is the main loop the cache server thread(s) are running.
/// Each thread will pop a request and send the result back to the client using grpc.
/// \return Status object
Status CacheServer::ServerRequest(worker_id_t worker_id) {
  TaskManager::FindMe()->Post();
  MS_LOG(DEBUG) << "Worker id " << worker_id << " is running on node " << hw_info_->GetMyNode();
  auto &my_que = cache_q_->operator[](worker_id);
  // Loop forever until we are interrupted or shutdown.
  while (!global_shutdown_) {
    CacheServerRequest *cache_req = nullptr;
    RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
    RETURN_IF_NOT_OK(ProcessRequest(cache_req));
  }
  return Status::OK();
}
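// A connection id is a single 64-bit value: the session id occupies the upper 32 bits and the
// client-supplied crc the lower 32 bits. For example, session id 0x12345678 combined with
// crc 0x9ABCDEF0 yields connection id 0x123456789ABCDEF0.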
connection_id_type CacheServer::GetConnectionID(session_id_type session_id, uint32_t crc) const {
  connection_id_type connection_id =
    (static_cast<connection_id_type>(session_id) << 32u) | static_cast<connection_id_type>(crc);
  return connection_id;
}

session_id_type CacheServer::GetSessionID(connection_id_type connection_id) const {
  return static_cast<session_id_type>(connection_id >> 32u);
}

CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port,
                         int32_t shared_memory_sz_in_gb, float memory_cap_ratio, int8_t log_level,
                         std::shared_ptr<CacheServerHW> hw_info)
    : top_(spill_path),
      num_workers_(num_workers),
      num_grpc_workers_(num_workers_),
      port_(port),
      shared_memory_sz_in_gb_(shared_memory_sz_in_gb),
      global_shutdown_(false),
      memory_cap_ratio_(memory_cap_ratio),
      numa_affinity_(true),
      log_level_(log_level),
      hw_info_(std::move(hw_info)) {
  // If we are not linked with the numa library (i.e. NUMA_ENABLED is false), turn off cpu
  // affinity, which would otherwise make performance worse.
  if (!CacheServerHW::numa_enabled()) {
    numa_affinity_ = false;
    MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build "
                       "that is compiled with numa support for better performance";
  }
  // We create the shared memory and we will destroy it. All other clients just detach.
  if (shared_memory_sz_in_gb_ > kDefaultSharedMemorySize) {
    MS_LOG(INFO) << "Shared memory size is readjusted to " << kDefaultSharedMemorySize << " GB.";
    shared_memory_sz_in_gb_ = kDefaultSharedMemorySize;
  }
}

Status CacheServer::Run(int msg_qid) {
  Status rc = ServiceStart();
  // If there is a message queue, return the status now before we call join_all, which will never return.
  if (msg_qid != -1) {
    SharedMessage msg(msg_qid);
    RETURN_IF_NOT_OK(msg.SendStatus(rc));
  }
  if (rc.IsError()) {
    return rc;
  }
  // This is called by the main function and we shouldn't exit. Otherwise the main thread
  // will just shut down. So we will call some function that never returns unless there is an error.
  // One good choice is simply to wait for all threads to return.
  // Note that after we have sent the initial status using the msg_qid, the parent process will exit and
  // remove it. So we can't use it again.
  RETURN_IF_NOT_OK(vg_.join_all(Task::WaitFlag::kBlocking));
  // Shutdown the grpc queue. No longer accept any newcomers.
  comm_layer_->Shutdown();
  // The next thing to do is to drop all the caches.
  RETURN_IF_NOT_OK(ServiceStop());
  return Status::OK();
}

Status CacheServer::GetFreeRequestTag(CacheServerRequest **q) {
  RETURN_UNEXPECTED_IF_NULL(q);
  auto *p = new (std::nothrow) CacheServerRequest();
  if (p == nullptr) {
    return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
  }
  *q = p;
  return Status::OK();
}

Status CacheServer::ReturnRequestTag(CacheServerRequest *p) {
  RETURN_UNEXPECTED_IF_NULL(p);
  delete p;
  return Status::OK();
}

Status CacheServer::DestroySession(CacheRequest *rq) {
  CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id");
  auto drop_session_id = rq->connection_info().session_id();
  // Grab the locks in the correct order to avoid deadlock.
  UniqueLock sess_lck(&sessions_lock_);
  UniqueLock lck(&rwLock_);
  // Iterate over the set of connection id's for this session that we're dropping and erase each one.
  bool found = false;
  for (auto it = all_caches_.begin(); it != all_caches_.end();) {
    auto connection_id = it->first;
    auto session_id = GetSessionID(connection_id);
    // We could just call DestroyCache(), but we are already holding a lock and doing so
    // would cause deadlock. So we will just do it manually.
    if (session_id == drop_session_id) {
      found = true;
      it = all_caches_.erase(it);
      MS_LOG(INFO) << "Destroy cache with id " << connection_id;
    } else {
      ++it;
    }
  }
  // Finally remove the session itself.
  auto n = active_sessions_.erase(drop_session_id);
  if (n > 0) {
    MS_LOG(INFO) << "Session destroyed with id " << drop_session_id;
    return Status::OK();
  } else {
    if (found) {
      std::string errMsg =
        "A destroy cache request has been completed but it had a stale session id " + std::to_string(drop_session_id);
      RETURN_STATUS_UNEXPECTED(errMsg);
    } else {
      std::string errMsg =
        "Session id " + std::to_string(drop_session_id) + " not found in server on port " + std::to_string(port_) + ".";
      return Status(StatusCode::kMDFileNotExist, errMsg);
    }
  }
}

session_id_type CacheServer::GenerateSessionID() {
  UniqueLock sess_lck(&sessions_lock_);
  auto mt = GetRandomDevice();
  std::uniform_int_distribution<session_id_type> distribution(0, std::numeric_limits<session_id_type>::max());
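  // Keep drawing random ids until we find one not already in use; set::insert reports through
  // .second whether the id was actually new.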
  session_id_type session_id;
  bool duplicate = false;
  do {
    session_id = distribution(mt);
    auto r = active_sessions_.insert(session_id);
    duplicate = !r.second;
  } while (duplicate);
  return session_id;
}

Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) {
  auto client_id = rq->client_id();
  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
  try {
    auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, kDecimal);
    void *p = nullptr;
    RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, requestedSz, &p));
    auto *base = SharedMemoryBaseAddr();
    // We can't return the absolute address, which would make no sense to the client.
    // Instead we return the difference.
    auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
    reply->set_result(std::to_string(difference));
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
  auto client_id = rq->client_id();
  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
  auto *base = SharedMemoryBaseAddr();
  try {
    auto addr = strtoll(rq->buf_data(0).data(), nullptr, kDecimal);
    auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
    DeallocateSharedMemory(client_id, p);
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

Status CacheServer::GetCacheState(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  if (cs == nullptr) {
    std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, errMsg);
  } else {
    auto state = cs->GetState();
    reply->set_result(std::to_string(static_cast<int8_t>(state)));
    return Status::OK();
  }
}

Status CacheServer::RpcRequest(worker_id_t worker_id) {
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id));
  return Status::OK();
}

Status CacheServer::AcknowledgeShutdown(CacheServerRequest *cache_req) {
  auto *rq = &cache_req->rq_;
  auto *reply = &cache_req->reply_;
  if (!rq->buf_data().empty()) {
    // cache_admin sends us a message qID; we will destroy the queue in our destructor,
    // and this will wake up cache_admin. But we don't want cache_admin to just block
    // itself blindly, so we send back an ack before shutting down the comm layer.
    try {
      int32_t qID = std::stoi(rq->buf_data(0));
      shutdown_qIDs_.push_back(qID);
    } catch (const std::exception &e) {
      // ignore it.
    }
  }
  reply->set_result("OK");
  return Status::OK();
}

void CacheServer::GlobalShutdown() {
  // Let's shut down in the proper order.
  bool expected = false;
  if (global_shutdown_.compare_exchange_strong(expected, true)) {
    MS_LOG(WARNING) << "Shutting down server.";
    // Interrupt all the threads and queues. We will leave the shutdown
    // of the comm layer until after we have joined all the threads; it will
    // be done by the master thread.
    vg_.interrupt_all();
  }
}

worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const {
  auto num_numa_nodes = GetNumaNodeCount();
  MS_ASSERT(numa_id < num_numa_nodes);
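  // Workers were created in DoServiceStart with affinity i % num_numa_nodes, so the workers
  // serving node numa_id are exactly those whose id is congruent to numa_id modulo the node
  // count. Pick one of them at random to spread the load.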
  auto num_workers_per_node = GetNumWorkers() / num_numa_nodes;
  std::mt19937 gen = GetRandomDevice();
  std::uniform_int_distribution<worker_id_t> dist(0, num_workers_per_node - 1);
  auto n = dist(gen);
  worker_id_t worker_id = n * num_numa_nodes + numa_id;
  MS_ASSERT(worker_id < GetNumWorkers());
  return worker_id;
}

worker_id_t CacheServer::GetRandomWorker() const {
  std::mt19937 gen = GetRandomDevice();
  std::uniform_int_distribution<worker_id_t> dist(0, num_workers_ - 1);
  return dist(gen);
}

Status CacheServer::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
  return shm_->AllocateSharedMemory(client_id, sz, p);
}

void CacheServer::DeallocateSharedMemory(int32_t client_id, void *p) { shm_->DeallocateSharedMemory(client_id, p); }

Status CacheServer::Builder::IpcResourceCleanup() {
  Status rc;
  SharedMemory::shm_key_t shm_key;
  auto unix_socket = PortToUnixSocketPath(port_);
  rc = PortToFtok(port_, &shm_key);
  // We are expecting the unix socket path not to exist. If ftok fails, there is nothing to clean up.
  if (rc.IsError()) {
    return Status::OK();
  }
  // Attach to the shared memory, which we expect doesn't exist.
  SharedMemory mem(shm_key);
  rc = mem.Attach();
  if (rc.IsError()) {
    return Status::OK();
  } else {
    RETURN_IF_NOT_OK(mem.Detach());
  }
  int32_t num_attached;
  RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached));
  if (num_attached == 0) {
    // Stale shared memory from last time.
    // Remove both the memory and the socket path.
    RETURN_IF_NOT_OK(mem.Destroy());
    Path p(unix_socket);
    (void)p.Remove();
  } else {
    // Server is already up.
    std::string errMsg = "Cache server is already up and running";
    // We return a duplicate error. The main() will intercept it
    // and output a proper message.
    return Status(StatusCode::kMDDuplicateKey, errMsg);
  }
  return Status::OK();
}

Status CacheServer::Builder::SanityCheck() {
  if (shared_memory_sz_in_gb_ <= 0) {
    RETURN_STATUS_UNEXPECTED("Shared memory size (in GB unit) must be positive");
  }
  if (num_workers_ <= 0) {
    RETURN_STATUS_UNEXPECTED("Number of parallel workers must be positive");
  }
  if (!top_.empty()) {
    auto p = top_.data();
    if (p[0] != '/') {
      RETURN_STATUS_UNEXPECTED("Spilling directory must be an absolute path");
    }
    // Check if the spill directory is writable.
    Path spill(top_);
    auto t = spill / Services::GetUniqueID();
    Status rc = t.CreateDirectory();
    if (rc.IsOk()) {
      rc = t.Remove();
    }
    if (rc.IsError()) {
      RETURN_STATUS_UNEXPECTED("Spilling directory is not writable\n" + rc.ToString());
    }
  }
  if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) {
    RETURN_STATUS_UNEXPECTED("Memory cap ratio should be positive and no greater than 1");
  }

  // Check the shared memory and clean up any stale IPC resources.
  RETURN_IF_NOT_OK(IpcResourceCleanup());
  return Status::OK();
}

int32_t CacheServer::Builder::AdjustNumWorkers(int32_t num_workers) {
  int32_t num_numa_nodes = hw_info_->GetNumaNodeCount();
  // Bump up num_workers_ to at least the number of numa nodes.
  num_workers = std::max(num_numa_nodes, num_workers);
  // But also it shouldn't be too many more than the hardware concurrency.
  int32_t num_cpus = hw_info_->GetCpuCount();
  constexpr int32_t kThreadsPerCore = 2;
  num_workers = std::min(kThreadsPerCore * num_cpus, num_workers);
  // Round up num_workers to a multiple of numa nodes.
  auto remainder = num_workers % num_numa_nodes;
  if (remainder > 0) num_workers += (num_numa_nodes - remainder);
  return num_workers;
}

CacheServer::Builder::Builder()
    : top_(""),
      num_workers_(kDefaultNumWorkers),
      port_(kCfgDefaultCachePort),
      shared_memory_sz_in_gb_(kDefaultSharedMemorySize),
      memory_cap_ratio_(kDefaultMemoryCapRatio),
      log_level_(kDefaultLogLevel) {
  if (num_workers_ == 0) {
    num_workers_ = 1;
  }
}
}  // namespace dataset
}  // namespace mindspore