/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <random>
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/slice.h"

namespace mindspore {
namespace dataset {
CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id)
    : root_(root),
      cache_mem_sz_(mem_sz * 1048576L),  // mem_sz is in MB unit
      cp_(nullptr),
      next_id_(0),
      generate_id_(generate_id),
      num_clients_(0),
      st_(generate_id ? CacheServiceState::kBuildPhase : CacheServiceState::kNone) {}

CacheService::~CacheService() { (void)ServiceStop(); }

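// Bring up the cache service: work out the memory cap (either from the user-requested cache size
// or from the server-wide memory cap ratio), create a NUMA-aware memory pool, and start the
// backing CachePool. The CachePool's name doubles as the cookie for exclusive connections.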
Status CacheService::DoServiceStart() {
  CacheServer &cs = CacheServer::GetInstance();
  float memory_cap_ratio = cs.GetMemoryCapRatio();
  if (cache_mem_sz_ > 0) {
    auto avail_mem = CacheServerHW::GetTotalSystemMemory();
    if (cache_mem_sz_ > avail_mem) {
      // Return an error if the requested cache size exceeds the available system memory.
      std::string errMsg = "Requested cache size " + std::to_string(cache_mem_sz_) +
                           " exceeds available system memory " + std::to_string(avail_mem);
      return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, errMsg);
    }
    memory_cap_ratio = static_cast<float>(cache_mem_sz_) / avail_mem;
  }
  numa_pool_ = std::make_shared<NumaMemoryPool>(cs.GetHWControl(), memory_cap_ratio);
  // It is possible we aren't able to allocate the pool for many reasons.
  std::vector<numa_id_t> avail_nodes = numa_pool_->GetAvailableNodes();
  if (avail_nodes.empty()) {
    RETURN_STATUS_UNEXPECTED("Unable to bring up numa memory pool");
  }
  // Put together a CachePool for backing up the Tensor.
  cp_ = std::make_shared<CachePool>(numa_pool_, root_);
  RETURN_IF_NOT_OK(cp_->ServiceStart());
  // Assign a name to this cache. It is used for exclusive connections, and we can simply reuse the CachePool's name.
  cookie_ = cp_->MyName();
  return Status::OK();
}

Status CacheService::DoServiceStop() {
  if (cp_ != nullptr) {
    RETURN_IF_NOT_OK(cp_->ServiceStop());
  }
  return Status::OK();
}

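// Cache one row that arrives as a set of separate buffers: the first buffer is the flatbuffer
// header describing the row, and each remaining buffer holds the data of one column. The pieces
// are handed to the CachePool as a vector of slices so it can consolidate them in memory or on disk.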
Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
    // For this kind of cache service, once we move from the build phase into the fetch phase,
    // we can't allow clients to cache any more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  if (st_ == CacheServiceState::kNoLocking) {
    // We ignore this write request once locking on the B+ tree has been turned off. So we will just
    // return out of memory from now on.
    return Status(StatusCode::kMDOutOfMemory);
  }
  try {
    // The first buffer is a flatbuffer which describes the buffers that follow.
    auto fb = buf.front();
    RETURN_UNEXPECTED_IF_NULL(fb);
    auto msg = GetTensorRowHeaderMsg(fb);
    // If the server side is designed to ignore the incoming row id, we generate one.
    if (generate_id_) {
      *row_id_generated = GetNextRowId();
      // Some debug information on how many rows we have generated so far.
      constexpr int32_t kDisplayInterval = 1000;
      if ((*row_id_generated) % kDisplayInterval == 0) {
        MS_LOG(DEBUG) << "Number of rows cached: " << ((*row_id_generated) + 1);
      }
    } else {
      if (msg->row_id() < 0) {
        std::string errMsg = "Expect non-negative row id: " + std::to_string(msg->row_id());
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      *row_id_generated = msg->row_id();
    }
    auto size_of_this = msg->size_of_this();
    size_t total_sz = size_of_this;
    auto column_hdr = msg->column();
    // The number of tensor buffers should match the number of columns plus one.
    if (buf.size() != column_hdr->size() + 1) {
      std::string errMsg = "Column count does not match. Expect " + std::to_string(column_hdr->size() + 1) +
                           " but got " + std::to_string(buf.size());
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    // Next we store it either in memory or on disk. The low-level code will consolidate everything into one piece.
    std::vector<ReadableSlice> all_data;
    all_data.reserve(column_hdr->size() + 1);
    all_data.emplace_back(fb, size_of_this);
    for (auto i = 0; i < column_hdr->size(); ++i) {
      all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i));
      total_sz += msg->data_sz()->Get(i);
    }
    // Now we cache the buffer.
    Status rc = cp_->Insert(*row_id_generated, all_data);
    if (rc == Status(StatusCode::kMDDuplicateKey)) {
      MS_LOG(DEBUG) << "Ignoring duplicate key.";
    } else {
      if (HasBuildPhase()) {
        // For a cache service that has a build phase, record the error in the state
        // so other clients can be aware of the new state. There is nothing one can
        // do to resume other than to drop the cache.
        if (rc == StatusCode::kMDNoSpace) {
          st_ = CacheServiceState::kNoSpace;
        } else if (rc == StatusCode::kMDOutOfMemory) {
          st_ = CacheServiceState::kOutOfMemory;
        }
      }
      RETURN_IF_NOT_OK(rc);
    }
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
}

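// Fast path for caching a row: the caller has already laid out the flatbuffer header and the
// column data as one contiguous buffer, so the row can be inserted as a single slice without
// reassembling the pieces.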
Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
    // For this kind of cache service, once we move from the build phase into the fetch phase,
    // we can't allow clients to cache any more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  if (st_ == CacheServiceState::kNoLocking) {
    // We ignore this write request once locking on the B+ tree has been turned off. So we will just
    // return out of memory from now on.
    return Status(StatusCode::kMDOutOfMemory);
  }
  try {
    // If we don't need to generate an id, we need to extract it from the buffer.
    if (generate_id_) {
      *row_id_generated = GetNextRowId();
      // Some debug information on how many rows we have generated so far.
      constexpr int32_t kDisplayInterval = 1000;
      if ((*row_id_generated) % kDisplayInterval == 0) {
        MS_LOG(DEBUG) << "Number of rows cached: " << ((*row_id_generated) + 1);
      }
    } else {
      auto msg = GetTensorRowHeaderMsg(src.GetPointer());
      if (msg->row_id() < 0) {
        std::string errMsg = "Expect non-negative row id: " + std::to_string(msg->row_id());
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      *row_id_generated = msg->row_id();
    }
    // Now we cache the buffer.
    Status rc = cp_->Insert(*row_id_generated, {src});
    if (rc == Status(StatusCode::kMDDuplicateKey)) {
      MS_LOG(DEBUG) << "Ignoring duplicate key.";
    } else {
      if (HasBuildPhase()) {
        // For a cache service that has a build phase, record the error in the state
        // so other clients can be aware of the new state. There is nothing one can
        // do to resume other than to drop the cache.
        if (rc == StatusCode::kMDNoSpace) {
          st_ = CacheServiceState::kNoSpace;
        } else if (rc == StatusCode::kMDOutOfMemory) {
          st_ = CacheServiceState::kOutOfMemory;
        }
      }
      RETURN_IF_NOT_OK(rc);
    }
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
}

std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
  // Show the fields specific to this cache service.
  out << "\nCache memory size: " << cs.cache_mem_sz_;
  out << "\nSpill path: ";
  if (cs.root_.empty()) {
    out << "None";
  } else {
    out << cs.GetSpillPath();
  }
  return out;
}

Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); }

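// Report the cache misses: the first two entries of the output vector are the minimum and maximum
// keys seen by the CachePool, followed by the row ids that fall in the gaps between them. The
// result is computed once and then reused for subsequent calls.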
Status CacheService::FindKeysMiss(std::vector<row_id_type> *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  std::unique_lock<std::mutex> lock(get_key_miss_mux_);
  if (key_miss_results_ == nullptr) {
    // Just do it once.
    key_miss_results_ = std::make_shared<std::vector<row_id_type>>();
    auto stat = cp_->GetStat(true);
    key_miss_results_->push_back(stat.min_key);
    key_miss_results_->push_back(stat.max_key);
    key_miss_results_->insert(key_miss_results_->end(), stat.gap.begin(), stat.gap.end());
  }
  out->insert(out->end(), key_miss_results_->begin(), key_miss_results_->end());
  return Status::OK();
}

Status CacheService::GetStat(CacheService::ServiceStat *out) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(out);
  out->stat_ = cp_->GetStat();
  out->state_ = static_cast<ServiceStat::state_type>(st_.load());
  return Status::OK();
}

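// Prepare a batch fetch: for each requested row id, ask the CachePool for a DataLocatorMsg and
// pack them all into a single BatchDataLocatorMsg flatbuffer that tells the client where each
// row's data can be found.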
Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
                                   const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) {
  SharedLock rw(&rw_lock_);
  if (HasBuildPhase() && st_ != CacheServiceState::kFetchPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v;
  datalocator_v.reserve(v.size());
  for (auto row_id : v) {
    flatbuffers::Offset<DataLocatorMsg> offset;
    RETURN_IF_NOT_OK(cp_->GetDataLocator(row_id, fbb, &offset));
    datalocator_v.push_back(offset);
  }
  auto offset_v = fbb->CreateVector(datalocator_v);
  BatchDataLocatorMsgBuilder bld(*fbb);
  bld.add_connection_id(connection_id);
  bld.add_rows(offset_v);
  auto offset_final = bld.Finish();
  fbb->Finish(offset_final);
  return Status::OK();
}

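// Fetch one row into the caller-provided destination buffer. If the request already carries a
// source address, copy directly from it and skip the tree lookup; otherwise read the row from the
// CachePool by key and verify that the number of bytes read matches the expected size.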
Status CacheService::InternalFetchRow(const FetchRowMsg *p) {
  RETURN_UNEXPECTED_IF_NULL(p);
  SharedLock rw(&rw_lock_);
  size_t bytesRead = 0;
  int64_t key = p->key();
  size_t sz = p->size();
  void *source_addr = reinterpret_cast<void *>(p->source_addr());
  void *dest_addr = reinterpret_cast<void *>(p->dest_addr());
  WritableSlice dest(dest_addr, sz);
  if (source_addr != nullptr) {
    // We do not check whether the row is still present; we simply use the information passed in.
    // This saves another tree lookup and is faster.
    ReadableSlice src(source_addr, sz);
    RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
  } else {
    RETURN_IF_NOT_OK(cp_->Read(key, &dest, &bytesRead));
    if (bytesRead != sz) {
      std::string errMsg = "Unexpected length. Read " + std::to_string(bytesRead) + ". Expected " + std::to_string(sz) +
                           "." + " Internal key: " + std::to_string(key);
      MS_LOG(ERROR) << errMsg;
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
  }
  return Status::OK();
}

Status CacheService::CacheSchema(const void *buf, int64_t len) {
  UniqueLock rw(&rw_lock_);
  // In case we are calling the same function from multiple threads, only
  // the first call takes effect; the rest are ignored.
  if (schema_.empty()) {
    schema_.assign(static_cast<const char *>(buf), len);
  } else {
    MS_LOG(DEBUG) << "Caching Schema already done";
  }
  return Status::OK();
}

Status CacheService::FetchSchema(std::string *out) const {
  SharedLock rw(&rw_lock_);
  if (st_ == CacheServiceState::kBuildPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  RETURN_UNEXPECTED_IF_NULL(out);
  // We use a std::string to allocate and hold the result, which will eventually be 'moved' into
  // the protobuf message (which underneath is also a std::string) to minimize memory copies.
  std::string mem(schema_);
  if (!mem.empty()) {
    *out = std::move(mem);
  } else {
    return Status(StatusCode::kMDFileNotExist, __LINE__, __FILE__, "No schema has been cached");
  }
  return Status::OK();
}

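// Mark the end of the build phase: switch the state to the fetch phase and turn off locking on
// the underlying B+ tree since no further writes are expected.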
Status CacheService::BuildPhaseDone() {
  if (HasBuildPhase()) {
    // Exclusive lock to switch phase
    UniqueLock rw(&rw_lock_);
    st_ = CacheServiceState::kFetchPhase;
    cp_->SetLocking(false);
    MS_LOG(WARNING) << "Locking mode is switched off.";
    return Status::OK();
  } else {
    RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");
  }
}

Status CacheService::ToggleWriteMode(bool on_off) {
  UniqueLock rw(&rw_lock_);
  if (HasBuildPhase()) {
    RETURN_STATUS_UNEXPECTED("Not applicable to non-mappable dataset");
  } else {
    // If we stop accepting write requests, we turn off locking for the
    // underlying B+ tree. All future write requests will return kOutOfMemory.
    if (st_ == CacheServiceState::kNone && !on_off) {
      st_ = CacheServiceState::kNoLocking;
      cp_->SetLocking(on_off);
      MS_LOG(WARNING) << "Locking mode is switched off.";
    } else if (st_ == CacheServiceState::kNoLocking && on_off) {
      st_ = CacheServiceState::kNone;
      cp_->SetLocking(on_off);
    }
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore