/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <random>
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/slice.h"

namespace mindspore {
namespace dataset {
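// A CacheService backs one logical cache on the cache server. It owns the CachePool that holds the
// cached rows (in memory, with optional spill under `root`) and tracks the service state: a cache
// created with generate_id starts in the build phase and later switches to the fetch phase, while a
// cache without a build phase starts in the kNone state.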
CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id)
    : root_(root),
      cache_mem_sz_(mem_sz * 1048576L),  // mem_sz is in MB unit
      cp_(nullptr),
      next_id_(0),
      generate_id_(generate_id),
      num_clients_(0),
      st_(generate_id ? CacheServiceState::kBuildPhase : CacheServiceState::kNone) {}

CacheService::~CacheService() { (void)ServiceStop(); }

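// DoServiceStart sizes and brings up the backing storage. If an explicit cache size was given, it
// becomes the memory cap ratio (requested size / total system memory); otherwise the server-wide
// default ratio is used. A NUMA-aware memory pool is then created and handed to the CachePool,
// whose name doubles as this cache's cookie.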
Status CacheService::DoServiceStart() {
  CacheServer &cs = CacheServer::GetInstance();
  float memory_cap_ratio = cs.GetMemoryCapRatio();
  if (cache_mem_sz_ > 0) {
    auto avail_mem = CacheServerHW::GetTotalSystemMemory();
    if (cache_mem_sz_ > avail_mem) {
      // Return an error if the requested size exceeds the available system memory.
      std::string errMsg = "Requesting cache size " + std::to_string(cache_mem_sz_) +
                           " while available system memory is " + std::to_string(avail_mem);
      RETURN_STATUS_OOM(errMsg);
    }
    memory_cap_ratio = static_cast<float>(cache_mem_sz_) / avail_mem;
  }
  numa_pool_ = std::make_shared<NumaMemoryPool>(cs.GetHWControl(), memory_cap_ratio);
  // It is possible we aren't able to allocate the pool for many reasons.
  std::vector<numa_id_t> avail_nodes = numa_pool_->GetAvailableNodes();
  if (avail_nodes.empty()) {
    RETURN_STATUS_UNEXPECTED("Unable to bring up numa memory pool");
  }
  // Put together a CachePool for backing up the Tensor.
  cp_ = std::make_shared<CachePool>(numa_pool_, root_);
  RETURN_IF_NOT_OK(cp_->ServiceStart());
  // Assign a name to this cache, used for exclusive connections. We can simply reuse the CachePool's name.
  cookie_ = cp_->MyName();
  return Status::OK();
}

Status CacheService::DoServiceStop() {
  if (cp_ != nullptr) {
    RETURN_IF_NOT_OK(cp_->ServiceStop());
  }
  return Status::OK();
}

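// CacheRow writes one row into the cache from a scatter/gather list of buffers: buf[0] is the
// TensorRowHeaderMsg flatbuffer describing the row, and buf[1..n] are the column payloads whose
// sizes are taken from the header. The row id is either generated server side (generate_id_) or
// read from the header. Illustrative call only (the names below are hypothetical):
//   row_id_type id;
//   RETURN_IF_NOT_OK(svc->CacheRow({header_ptr, col0_ptr, col1_ptr}, &id));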
Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
    // For this kind of cache service, once we move from the build phase into the fetch phase, we can't
    // allow others to cache more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  if (st_ == CacheServiceState::kNoLocking) {
    // We ignore this write request once we turn off locking on the B+ tree. So we will just
    // return out of memory from now on.
    RETURN_STATUS_OOM("Out of memory.");
  }
  try {
    // The first buffer is a flatbuffer which describes the rest of the buffers that follow.
    auto fb = buf.front();
    RETURN_UNEXPECTED_IF_NULL(fb);
    auto msg = GetTensorRowHeaderMsg(fb);
    // If the server side is designed to ignore the incoming row id, we generate one.
    if (generate_id_) {
      *row_id_generated = GetNextRowId();
      // Some debug information on how many rows we have generated so far.
      constexpr int32_t kDisplayInterval = 1000;
      if ((*row_id_generated) % kDisplayInterval == 0) {
        MS_LOG(DEBUG) << "Number of rows cached: " << ((*row_id_generated) + 1);
      }
    } else {
      if (msg->row_id() < 0) {
        std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      *row_id_generated = msg->row_id();
    }
    auto size_of_this = msg->size_of_this();
    auto column_hdr = msg->column();
    RETURN_UNEXPECTED_IF_NULL(column_hdr);
    // The number of tensor buffers should match the number of columns plus one.
    if (buf.size() != column_hdr->size() + 1) {
      std::string errMsg = "Column count does not match. Expect " + std::to_string(column_hdr->size() + 1) +
                           " but get " + std::to_string(buf.size());
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    // Next we store in either memory or on disk. Low level code will consolidate everything in one piece.
    std::vector<ReadableSlice> all_data;
    all_data.reserve(column_hdr->size() + 1);
    all_data.emplace_back(fb, size_of_this);
    for (auto i = 0; i < column_hdr->size(); ++i) {
      all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i));
    }
    // Now we cache the buffer.
    Status rc = cp_->Insert(*row_id_generated, all_data);
    if (rc == Status(StatusCode::kMDDuplicateKey)) {
      MS_LOG(DEBUG) << "Ignoring duplicate key.";
    } else {
      if (HasBuildPhase()) {
        // For a cache service that has a build phase, record the error in the state
        // so other clients can be aware of the new state. There is nothing one can
        // do to resume other than to drop the cache.
        if (rc == StatusCode::kMDNoSpace) {
          st_ = CacheServiceState::kNoSpace;
        } else if (rc == StatusCode::kMDOutOfMemory) {
          st_ = CacheServiceState::kOutOfMemory;
        }
      }
      RETURN_IF_NOT_OK(rc);
    }
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
}

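// FastCacheRow is the single-buffer variant of CacheRow: the caller hands over the whole row as one
// pre-assembled ReadableSlice (header followed by data), so no per-column reassembly is needed
// before inserting it into the CachePool.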
Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
    // For this kind of cache service, once we move from the build phase into the fetch phase, we can't
    // allow others to cache more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  if (st_ == CacheServiceState::kNoLocking) {
    // We ignore this write request once we turn off locking on the B+ tree. So we will just
    // return out of memory from now on.
    RETURN_STATUS_OOM("Out of memory.");
  }
  try {
    // If we don't need to generate an id, we need to find it from the buffer.
    if (generate_id_) {
      *row_id_generated = GetNextRowId();
      // Some debug information on how many rows we have generated so far.
      constexpr int32_t kDisplayInterval = 1000;
      if ((*row_id_generated) % kDisplayInterval == 0) {
        MS_LOG(DEBUG) << "Number of rows cached: " << ((*row_id_generated) + 1);
      }
    } else {
      auto msg = GetTensorRowHeaderMsg(src.GetPointer());
      if (msg->row_id() < 0) {
        std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      *row_id_generated = msg->row_id();
    }
    // Now we cache the buffer.
    Status rc = cp_->Insert(*row_id_generated, {src});
    if (rc == Status(StatusCode::kMDDuplicateKey)) {
      MS_LOG(DEBUG) << "Ignoring duplicate key.";
    } else {
      if (HasBuildPhase()) {
        // For a cache service that has a build phase, record the error in the state
        // so other clients can be aware of the new state. There is nothing one can
        // do to resume other than to drop the cache.
        if (rc == StatusCode::kMDNoSpace) {
          st_ = CacheServiceState::kNoSpace;
        } else if (rc == StatusCode::kMDOutOfMemory) {
          st_ = CacheServiceState::kOutOfMemory;
        }
      }
      RETURN_IF_NOT_OK(rc);
    }
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
}

std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
  // Then show any custom derived-internal stuff
  out << "\nCache memory size: " << cs.cache_mem_sz_;
  out << "\nSpill path: ";
  if (cs.root_.empty()) {
    out << "None";
  } else {
    out << cs.GetSpillPath();
  }
  return out;
}

Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); }

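// FindKeysMiss reports which row ids are absent from the cache. The result layout is
// [min_key, max_key, missing ids...], computed once from the CachePool statistics and then
// memoized in key_miss_results_ for subsequent callers.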
Status CacheService::FindKeysMiss(std::vector<row_id_type> *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  std::unique_lock<std::mutex> lock(get_key_miss_mux_);
  if (key_miss_results_ == nullptr) {
    // Just do it once.
    key_miss_results_ = std::make_shared<std::vector<row_id_type>>();
    auto stat = cp_->GetStat(true);
    key_miss_results_->push_back(stat.min_key);
    key_miss_results_->push_back(stat.max_key);
    key_miss_results_->insert(key_miss_results_->end(), stat.gap.begin(), stat.gap.end());
  }
  out->insert(out->end(), key_miss_results_->begin(), key_miss_results_->end());
  return Status::OK();
}

Status CacheService::GetStat(CacheService::ServiceStat *out) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(out);
  out->stat_ = cp_->GetStat();
  out->state_ = static_cast<ServiceStat::state_type>(st_.load());
  return Status::OK();
}

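// PreBatchFetch prepares a batch fetch: for every requested row id it asks the CachePool for a
// DataLocatorMsg for that row and packs them, together with the connection id, into a
// BatchDataLocatorMsg flatbuffer that is finished off in the caller-supplied builder `fbb`.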
Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
                                   const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) {
  SharedLock rw(&rw_lock_);
  if (HasBuildPhase() && st_ != CacheServiceState::kFetchPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v;
  datalocator_v.reserve(v.size());
  for (auto row_id : v) {
    flatbuffers::Offset<DataLocatorMsg> offset;
    RETURN_IF_NOT_OK(cp_->GetDataLocator(row_id, fbb, &offset));
    datalocator_v.push_back(offset);
  }
  auto offset_v = fbb->CreateVector(datalocator_v);
  BatchDataLocatorMsgBuilder bld(*fbb);
  bld.add_connection_id(connection_id);
  bld.add_rows(offset_v);
  auto offset_final = bld.Finish();
  fbb->Finish(offset_final);
  return Status::OK();
}

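// InternalFetchRow copies one row into the caller-provided destination buffer. When the request
// carries a source address we trust it and copy directly via WritableSlice::Copy; otherwise we
// read the row from the CachePool by key and verify the length read.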
Status CacheService::InternalFetchRow(const FetchRowMsg *p) {
  RETURN_UNEXPECTED_IF_NULL(p);
  SharedLock rw(&rw_lock_);
  size_t bytesRead = 0;
  int64_t key = p->key();
  size_t sz = p->size();
  void *source_addr = reinterpret_cast<void *>(p->source_addr());
  void *dest_addr = reinterpret_cast<void *>(p->dest_addr());
  WritableSlice dest(dest_addr, sz);
  if (source_addr != nullptr) {
    // We are not checking if the row is still present but simply use the information passed in.
    // This saves another tree lookup and is faster.
    ReadableSlice src(source_addr, sz);
    RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
  } else {
    RETURN_IF_NOT_OK(cp_->Read(key, &dest, &bytesRead));
    if (bytesRead != sz) {
      std::string errMsg = "Unexpected length. Read " + std::to_string(bytesRead) + ". Expected " + std::to_string(sz) +
                           "." + " Internal key: " + std::to_string(key);
      MS_LOG(ERROR) << errMsg;
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
  }
  return Status::OK();
}

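// CacheSchema stores the serialized dataset schema alongside the cached rows. Only the first
// schema received is kept; later calls are no-ops.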
Status CacheService::CacheSchema(const void *buf, int64_t len) {
  UniqueLock rw(&rw_lock_);
  // In case we are calling the same function from multiple threads, only
  // the first call is considered; the rest are ignored.
  if (schema_.empty()) {
    schema_.assign(static_cast<const char *>(buf), len);
  } else {
    MS_LOG(DEBUG) << "Caching Schema already done";
  }
  return Status::OK();
}

Status CacheService::FetchSchema(std::string *out) const {
  SharedLock rw(&rw_lock_);
  if (st_ == CacheServiceState::kBuildPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  RETURN_UNEXPECTED_IF_NULL(out);
  // We use a std::string to allocate and hold the result, which will eventually be 'moved' to the
  // protobuf message (which is also a std::string underneath) in order to minimize memory copies.
  std::string mem(schema_);
  if (!mem.empty()) {
    *out = std::move(mem);
  } else {
    RETURN_STATUS_ERROR(StatusCode::kMDFileNotExist, "No schema has been cached");
  }
  return Status::OK();
}

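// BuildPhaseDone flips a build-phase cache into the fetch phase. Since no further writes are
// expected, locking on the underlying B+ tree is switched off as well.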
Status CacheService::BuildPhaseDone() {
  if (HasBuildPhase()) {
    // Exclusive lock to switch phase
    UniqueLock rw(&rw_lock_);
    st_ = CacheServiceState::kFetchPhase;
    cp_->SetLocking(false);
    MS_LOG(WARNING) << "Locking mode is switched off.";
    return Status::OK();
  } else {
    RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");
  }
}

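// ToggleWriteMode(false) stops accepting writes on a cache without a build phase: the state moves
// to kNoLocking, tree locking is disabled, and subsequent cache requests report out of memory.
// ToggleWriteMode(true) reverses the switch.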
Status CacheService::ToggleWriteMode(bool on_off) {
  UniqueLock rw(&rw_lock_);
  if (HasBuildPhase()) {
    RETURN_STATUS_UNEXPECTED("Not applicable to non-mappable dataset");
  } else {
    // If we stop accepting write requests, we turn off locking for the
    // underlying B+ tree. All future write requests will return kOutOfMemory.
    if (st_ == CacheServiceState::kNone && !on_off) {
      st_ = CacheServiceState::kNoLocking;
      cp_->SetLocking(on_off);
      MS_LOG(WARNING) << "Locking mode is switched off.";
    } else if (st_ == CacheServiceState::kNoLocking && on_off) {
      st_ = CacheServiceState::kNone;
      cp_->SetLocking(on_off);
    }
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore