/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <random>
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/slice.h"

namespace mindspore {
namespace dataset {
CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id)
    : root_(root),
      cache_mem_sz_(mem_sz * 1048576L),  // mem_sz is in MB unit
      cp_(nullptr),
      next_id_(0),
      generate_id_(generate_id),
      num_clients_(0),
      st_(generate_id ? CacheServiceState::kBuildPhase : CacheServiceState::kNone) {}

CacheService::~CacheService() { (void)ServiceStop(); }

Status CacheService::DoServiceStart() {
  CacheServer &cs = CacheServer::GetInstance();
  float memory_cap_ratio = cs.GetMemoryCapRatio();
  if (cache_mem_sz_ > 0) {
    auto avail_mem = CacheServerHW::GetTotalSystemMemory();
    if (cache_mem_sz_ > avail_mem) {
      // Return an error if the requested cache size exceeds the available system memory.
      std::string errMsg = "Requesting cache size " + std::to_string(cache_mem_sz_) +
                           " while available system memory is only " + std::to_string(avail_mem);
      RETURN_STATUS_OOM(errMsg);
    }
    memory_cap_ratio = static_cast<float>(cache_mem_sz_) / avail_mem;
  }
  numa_pool_ = std::make_shared<NumaMemoryPool>(cs.GetHWControl(), memory_cap_ratio);
  // It is possible we aren't able to allocate the pool for many reasons.
  std::vector<numa_id_t> avail_nodes = numa_pool_->GetAvailableNodes();
  if (avail_nodes.empty()) {
    RETURN_STATUS_UNEXPECTED("Unable to bring up numa memory pool");
  }
  // Put together a CachePool for backing up the Tensor.
  cp_ = std::make_shared<CachePool>(numa_pool_, root_);
  RETURN_IF_NOT_OK(cp_->ServiceStart());
  // Assign a name to this cache. Used for exclusive connection. We can simply reuse the CachePool's name.
  cookie_ = cp_->MyName();
  return Status::OK();
}
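
// Illustrative arithmetic only (assumed numbers, not taken from any real deployment): a cache
// created with mem_sz = 512 caps cache_mem_sz_ at 512 * 1048576 = 536870912 bytes. If
// GetTotalSystemMemory() reports 16 GiB (17179869184 bytes), the computed
// memory_cap_ratio = 536870912.0 / 17179869184 = 0.03125, i.e. the NumaMemoryPool is limited
// to roughly 3% of system memory instead of the server-wide default ratio.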

Status CacheService::DoServiceStop() {
  if (cp_ != nullptr) {
    RETURN_IF_NOT_OK(cp_->ServiceStop());
  }
  return Status::OK();
}

Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
    // For this kind of cache service, once the build phase is done and we move into the fetch
    // phase, we can't allow others to cache more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  if (st_ == CacheServiceState::kNoLocking) {
    // We ignore this write request once locking on the B+ tree has been turned off. So we will
    // just return out of memory from now on.
    RETURN_STATUS_OOM("Out of memory.");
  }
  try {
    // The first buffer is a flatbuffer which describes the rest of the buffers that follow.
    auto fb = buf.front();
    RETURN_UNEXPECTED_IF_NULL(fb);
    auto msg = GetTensorRowHeaderMsg(fb);
    // If the server side is designed to ignore the incoming row id, we generate one.
    if (generate_id_) {
      *row_id_generated = GetNextRowId();
      // Some debug information on how many rows we have generated so far.
      constexpr int32_t kDisplayInterval = 1000;
      if ((*row_id_generated) % kDisplayInterval == 0) {
        MS_LOG(DEBUG) << "Number of rows cached: " << ((*row_id_generated) + 1);
      }
    } else {
      if (msg->row_id() < 0) {
        std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      *row_id_generated = msg->row_id();
    }
    auto size_of_this = msg->size_of_this();
    auto column_hdr = msg->column();
    RETURN_UNEXPECTED_IF_NULL(column_hdr);
    // The number of tensor buffers should match the number of columns plus one.
    if (buf.size() != column_hdr->size() + 1) {
      std::string errMsg = "Column count does not match. Expect " + std::to_string(column_hdr->size() + 1) +
                           " but get " + std::to_string(buf.size());
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    // Next we store it either in memory or on disk. The low level code will consolidate everything into one piece.
    std::vector<ReadableSlice> all_data;
    all_data.reserve(column_hdr->size() + 1);
    all_data.emplace_back(fb, size_of_this);
    for (auto i = 0; i < column_hdr->size(); ++i) {
      all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i));
    }
    // Now we cache the buffer.
    Status rc = cp_->Insert(*row_id_generated, all_data);
    if (rc == Status(StatusCode::kMDDuplicateKey)) {
      MS_LOG(DEBUG) << "Ignoring duplicate key.";
    } else {
      if (HasBuildPhase()) {
        // For a cache service that has a build phase, record the error in the state
        // so other clients can be aware of the new state. There is nothing one can
        // do to resume other than to drop the cache.
        if (rc == StatusCode::kMDNoSpace) {
          st_ = CacheServiceState::kNoSpace;
        } else if (rc == StatusCode::kMDOutOfMemory) {
          st_ = CacheServiceState::kOutOfMemory;
        }
      }
      RETURN_IF_NOT_OK(rc);
    }
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
}
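
// A minimal sketch (hypothetical client-side assembly, for illustration only) of the buffer
// layout CacheRow() expects: the TensorRowHeaderMsg flatbuffer comes first, followed by one
// raw buffer per column in the order the header describes them.
//
//   std::vector<const void *> buf;
//   buf.push_back(header_fbb.GetBufferPointer());  // [0] the flatbuffer header
//   for (const auto &col : column_buffers) {       // [1..n] one entry per column
//     buf.push_back(col.data());
//   }
//   row_id_type id = -1;
//   RETURN_IF_NOT_OK(svc->CacheRow(buf, &id));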

Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
    // For this kind of cache service, once the build phase is done and we move into the fetch
    // phase, we can't allow others to cache more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  if (st_ == CacheServiceState::kNoLocking) {
    // We ignore this write request once locking on the B+ tree has been turned off. So we will
    // just return out of memory from now on.
    RETURN_STATUS_OOM("Out of memory.");
  }
  try {
    // If we don't need to generate an id, we need to find it in the buffer.
    if (generate_id_) {
      *row_id_generated = GetNextRowId();
      // Some debug information on how many rows we have generated so far.
      constexpr int32_t kDisplayInterval = 1000;
      if ((*row_id_generated) % kDisplayInterval == 0) {
        MS_LOG(DEBUG) << "Number of rows cached: " << ((*row_id_generated) + 1);
      }
    } else {
      auto msg = GetTensorRowHeaderMsg(src.GetPointer());
      if (msg->row_id() < 0) {
        std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      *row_id_generated = msg->row_id();
    }
    // Now we cache the buffer.
    Status rc = cp_->Insert(*row_id_generated, {src});
    if (rc == Status(StatusCode::kMDDuplicateKey)) {
      MS_LOG(DEBUG) << "Ignoring duplicate key.";
    } else {
      if (HasBuildPhase()) {
        // For a cache service that has a build phase, record the error in the state
        // so other clients can be aware of the new state. There is nothing one can
        // do to resume other than to drop the cache.
        if (rc == StatusCode::kMDNoSpace) {
          st_ = CacheServiceState::kNoSpace;
        } else if (rc == StatusCode::kMDOutOfMemory) {
          st_ = CacheServiceState::kOutOfMemory;
        }
      }
      RETURN_IF_NOT_OK(rc);
    }
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
}
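
// Unlike CacheRow(), FastCacheRow() takes a single ReadableSlice that already holds the
// header and all column data consolidated back to back (for example, a block that was
// written directly into a shared memory arena). A minimal sketch (hypothetical caller):
//
//   ReadableSlice src(base_addr, total_sz);  // header + columns in one contiguous block
//   row_id_type id = -1;
//   RETURN_IF_NOT_OK(svc->FastCacheRow(src, &id));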

std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
  // Then show any custom derived-internal stuff
  out << "\nCache memory size: " << cs.cache_mem_sz_;
  out << "\nSpill path: ";
  if (cs.root_.empty()) {
    out << "None";
  } else {
    out << cs.GetSpillPath();
  }
  return out;
}

Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); }

Status CacheService::FindKeysMiss(std::vector<row_id_type> *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  std::unique_lock<std::mutex> lock(get_key_miss_mux_);
  if (key_miss_results_ == nullptr) {
    // Just do it once.
    key_miss_results_ = std::make_shared<std::vector<row_id_type>>();
    auto stat = cp_->GetStat(true);
    key_miss_results_->push_back(stat.min_key);
    key_miss_results_->push_back(stat.max_key);
    key_miss_results_->insert(key_miss_results_->end(), stat.gap.begin(), stat.gap.end());
  }
  out->insert(out->end(), key_miss_results_->begin(), key_miss_results_->end());
  return Status::OK();
}
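
// A minimal sketch (hypothetical caller) of how the FindKeysMiss() output is laid out, based
// on the code above: element 0 is the smallest cached row id, element 1 the largest, and any
// remaining elements are the row ids missing in between.
//
//   std::vector<row_id_type> miss;
//   RETURN_IF_NOT_OK(svc->FindKeysMiss(&miss));
//   auto min_key = miss[0];
//   auto max_key = miss[1];
//   // miss[2..] lists the gaps, i.e. row ids in [min_key, max_key] that were never cached.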

Status CacheService::GetStat(CacheService::ServiceStat *out) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(out);
  out->stat_ = cp_->GetStat();
  out->state_ = static_cast<ServiceStat::state_type>(st_.load());
  return Status::OK();
}

Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
                                   const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) {
  SharedLock rw(&rw_lock_);
  if (HasBuildPhase() && st_ != CacheServiceState::kFetchPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v;
  datalocator_v.reserve(v.size());
  for (auto row_id : v) {
    flatbuffers::Offset<DataLocatorMsg> offset;
    RETURN_IF_NOT_OK(cp_->GetDataLocator(row_id, fbb, &offset));
    datalocator_v.push_back(offset);
  }
  auto offset_v = fbb->CreateVector(datalocator_v);
  BatchDataLocatorMsgBuilder bld(*fbb);
  bld.add_connection_id(connection_id);
  bld.add_rows(offset_v);
  auto offset_final = bld.Finish();
  fbb->Finish(offset_final);
  return Status::OK();
}
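
// A minimal sketch (hypothetical server-side caller) of how PreBatchFetch() might be driven:
//
//   auto fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
//   RETURN_IF_NOT_OK(svc->PreBatchFetch(connection_id, row_ids, fbb));
//   // fbb->GetBufferPointer() / fbb->GetSize() now hold a finished BatchDataLocatorMsg whose
//   // rows appear in the same order as row_ids.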

Status CacheService::InternalFetchRow(const FetchRowMsg *p) {
  RETURN_UNEXPECTED_IF_NULL(p);
  SharedLock rw(&rw_lock_);
  size_t bytesRead = 0;
  int64_t key = p->key();
  size_t sz = p->size();
  void *source_addr = reinterpret_cast<void *>(p->source_addr());
  void *dest_addr = reinterpret_cast<void *>(p->dest_addr());
  WritableSlice dest(dest_addr, sz);
  if (source_addr != nullptr) {
    // We are not checking if the row is still present but simply use the information passed in.
    // This saves another tree lookup and is faster.
    ReadableSlice src(source_addr, sz);
    RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
  } else {
    RETURN_IF_NOT_OK(cp_->Read(key, &dest, &bytesRead));
    if (bytesRead != sz) {
      std::string errMsg = "Unexpected length. Read " + std::to_string(bytesRead) + ". Expected " + std::to_string(sz) +
                           "." + " Internal key: " + std::to_string(key);
      MS_LOG(ERROR) << errMsg;
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
  }
  return Status::OK();
}

Status CacheService::CacheSchema(const void *buf, int64_t len) {
  UniqueLock rw(&rw_lock_);
  // In case we are calling the same function from multiple threads, only
  // the first one is considered. The rest are ignored.
  if (schema_.empty()) {
    schema_.assign(static_cast<const char *>(buf), len);
  } else {
    MS_LOG(DEBUG) << "Caching Schema already done";
  }
  return Status::OK();
}

Status CacheService::FetchSchema(std::string *out) const {
  SharedLock rw(&rw_lock_);
  if (st_ == CacheServiceState::kBuildPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  RETURN_UNEXPECTED_IF_NULL(out);
  // We use a std::string to allocate and hold the result, which will eventually be
  // 'moved' to the protobuf message (which underneath is also a std::string) in order
  // to minimize memory copies.
  std::string mem(schema_);
  if (!mem.empty()) {
    *out = std::move(mem);
  } else {
    RETURN_STATUS_ERROR(StatusCode::kMDFileNotExist, "No schema has been cached");
  }
  return Status::OK();
}
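
// A minimal sketch (hypothetical reply type and field name) of handing the fetched schema to
// a protobuf reply without an extra copy, as the comment above suggests:
//
//   std::string schema_blob;
//   RETURN_IF_NOT_OK(svc->FetchSchema(&schema_blob));
//   reply.set_result(std::move(schema_blob));  // moves into the protobuf-owned std::string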

Status CacheService::BuildPhaseDone() {
  if (HasBuildPhase()) {
    // Exclusive lock to switch phase
    UniqueLock rw(&rw_lock_);
    st_ = CacheServiceState::kFetchPhase;
    cp_->SetLocking(false);
    MS_LOG(WARNING) << "Locking mode is switched off.";
    return Status::OK();
  } else {
    RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");
  }
}

Status CacheService::ToggleWriteMode(bool on_off) {
  UniqueLock rw(&rw_lock_);
  if (HasBuildPhase()) {
    RETURN_STATUS_UNEXPECTED("Not applicable to non-mappable dataset");
  } else {
    // If we stop accepting write requests, we turn off locking for the
    // underlying B+ tree. All future write requests will return kOutOfMemory.
    if (st_ == CacheServiceState::kNone && !on_off) {
      st_ = CacheServiceState::kNoLocking;
      cp_->SetLocking(on_off);
      MS_LOG(WARNING) << "Locking mode is switched off.";
    } else if (st_ == CacheServiceState::kNoLocking && on_off) {
      st_ = CacheServiceState::kNone;
      cp_->SetLocking(on_off);
    }
  }
  return Status::OK();
}
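
// Illustrative state transitions driven by ToggleWriteMode() for a cache without a build
// phase (sketch only; svc is a hypothetical handle to this service):
//
//   svc->ToggleWriteMode(false);  // kNone      -> kNoLocking : B+ tree locking off,
//                                 //                            further writes return OOM
//   svc->ToggleWriteMode(true);   // kNoLocking -> kNone      : locking re-enabled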
}  // namespace dataset
}  // namespace mindspore