/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_request.h"
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
#include <sched.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#include <cstdlib>
#include <cstring>
#include <thread>
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
namespace mindspore {
namespace dataset {
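// Block until the server replies, turn the remote return code into a Status, and run any
// request-specific post-processing on the reply.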
Status BaseRequest::Wait() {
  RETURN_IF_NOT_OK(wp_.Wait());
  Status remote_rc(static_cast<StatusCode>(reply_.rc()), reply_.msg());
  RETURN_IF_NOT_OK(remote_rc);
  // Any extra work to do before we return to the client.
  RETURN_IF_NOT_OK(PostReply());
  return Status::OK();
}
Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row) {
  CHECK_FAIL_RETURN_UNEXPECTED(row.size() > 0, "Empty tensor row");
  CHECK_FAIL_RETURN_UNEXPECTED(cc->SupportLocalClient() == support_local_bypass_, "Local bypass mismatch");
  // Calculate how many bytes (not counting the cookie) we are sending to the server. We only
  // use shared memory (if supported) if we exceed a certain amount.
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
  RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
  sz_ += fbb->GetSize();
  for (const auto &ts : row) {
    sz_ += ts->SizeInBytes();
  }
  bool sent_using_local_bypass = support_local_bypass_ ? (sz_ >= kLocalByPassThreshold) : false;
  uint32_t flag = 0;
  if (support_local_bypass_) {
    BitSet(&flag, kLocalClientSupport);
  }
  if (sent_using_local_bypass) {
    BitSet(&flag, kDataIsInSharedMemory);
  }
  rq_.set_flag(flag);
  if (sent_using_local_bypass) {
    MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
    // Allocate shared memory from the server
    auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), cc->GetClientId(), sz_);
    RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
    RETURN_IF_NOT_OK(mem_rq->Wait());
    addr_ = mem_rq->GetAddr();
    // Now we need to add that to the base address of where we attach.
    auto base = cc->SharedMemoryBaseAddr();
    auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr_);
    // Now we copy the data onto shared memory.
    WritableSlice all(p, sz_);
    auto offset = fbb->GetSize();
    ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize());
    Status copy_rc = WritableSlice::Copy(&all, header);
    if (copy_rc.IsOk()) {
      for (const auto &ts : row) {
        WritableSlice row_data(all, offset, ts->SizeInBytes());
        ReadableSlice src(ts->GetBuffer(), ts->SizeInBytes());
        copy_rc = WritableSlice::Copy(&row_data, src);
        if (copy_rc.IsError()) {
          break;
        }
        offset += ts->SizeInBytes();
      }
      // Fill in where to find the data
      AddDataLocation();
    }
    if (copy_rc.IsError()) {
      // We need to return the memory back to the server
      auto mfree_req = GenerateFreeBlockRequest();
      Status rc = cc->PushRequest(mfree_req);
      // But we won't wait for the result for the sake of performance.
      if (rc.IsError()) {
        MS_LOG(ERROR) << "Push request for free memory failed.";
      }
      return copy_rc;
    }
  } else {
    // We have already filled the first buffer which is the cookie.
    sz_ += rq_.buf_data(0).size();
    rq_.add_buf_data(fbb->GetBufferPointer(), fbb->GetSize());
    for (const auto &ts : row) {
      rq_.add_buf_data(ts->GetBuffer(), ts->SizeInBytes());
    }
    MS_LOG(DEBUG) << "Sending " << sz_ << " bytes of tensor data in " << rq_.buf_data_size() << " segments";
  }
  return Status::OK();
}

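// The server replies with the row id it assigned to the cached row, encoded as a decimal string.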
Status CacheRowRequest::PostReply() {
  if (!reply_.result().empty()) {
    row_id_from_server_ = strtoll(reply_.result().data(), nullptr, kDecimal);
  }
  return Status::OK();
}

Status CacheRowRequest::Prepare() {
  if (BitTest(rq_.flag(), kDataIsInSharedMemory)) {
    // First one is cookie, followed by address and then size.
    constexpr int32_t kExpectedBufDataSize = 3;
    CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == kExpectedBufDataSize, "Incomplete rpc data");
  } else {
    // First one is cookie. 2nd one is the google flat buffers followed by a number of buffers.
    // But we are not going to decode them to verify.
    constexpr int32_t kMinBufDataSize = 3;
    CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= kMinBufDataSize, "Incomplete rpc data");
  }
  return Status::OK();
}

CacheRowRequest::CacheRowRequest(const CacheClient *cc)
    : BaseRequest(RequestType::kCacheRow),
      support_local_bypass_(cc->local_bypass_),
      addr_(-1),
      sz_(0),
      row_id_from_server_(-1) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.add_buf_data(cc->cookie_);
}

BatchFetchRequest::BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id)
    : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(cc->local_bypass_), row_id_(row_id) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0);
  // Convert the row id into a flatbuffer
  flatbuffers::FlatBufferBuilder fbb;
  auto off_t = fbb.CreateVector(row_id);
  TensorRowIdsBuilder bld(fbb);
  bld.add_row_id(off_t);
  auto off = bld.Finish();
  fbb.Finish(off);
  rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
}

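// Rebuild the tensor rows from the reply. The payload begins with an array of (num_elements + 1)
// offsets; each row is a serialized header followed by its tensor data. The payload lives either
// in shared memory or inline in the rpc result, depending on the reply flag.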
Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr) {
  RETURN_UNEXPECTED_IF_NULL(out);
  RETURN_UNEXPECTED_IF_NULL(out_addr);
  auto num_elements = row_id_.size();
  const char *ptr = nullptr;
  int64_t sz = 0;
  // Tap into the reply flag to see where we can find the data. Server may decide the amount is
  // so small that it doesn't use shared memory method.
  auto flag = reply_.flag();
  bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false;
  if (dataOnSharedMemory) {
    auto addr = strtoll(reply_.result().data(), nullptr, kDecimal);
    ptr = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(baseAddr) + addr);
    *out_addr = addr;
  } else {
    ptr = reply_.result().data();
    *out_addr = -1;
  }
  auto *offset_array = reinterpret_cast<const int64_t *>(ptr);
  sz = offset_array[num_elements];
  CHECK_FAIL_RETURN_UNEXPECTED(support_local_bypass_ || sz == reply_.result().length(), "Length mismatch");
  TensorTable tbl;
  tbl.reserve(num_elements);
  ReadableSlice all(ptr, sz);
  for (auto i = 0; i < num_elements; ++i) {
    auto len = offset_array[i + 1] - offset_array[i];
    TensorRow row;
    row.setId(row_id_.at(i));
    if (len > 0) {
      ReadableSlice row_data(all, offset_array[i], len);
      // Next we de-serialize flat buffer to get back each column
      auto msg = GetTensorRowHeaderMsg(row_data.GetPointer());
      auto msg_sz = msg->size_of_this();
      // Start of the tensor data
      auto ts_offset = msg_sz;
      row.reserve(msg->column()->size());
      for (auto k = 0; k < msg->column()->size(); ++k) {
        auto col_ts = msg->column()->Get(k);
        std::shared_ptr<Tensor> ts;
        ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
        RETURN_IF_NOT_OK(mindspore::dataset::RestoreOneTensor(col_ts, data, &ts));
        row.push_back(ts);
        ts_offset += data.GetSize();
      }
    } else {
      CHECK_FAIL_RETURN_UNEXPECTED(len == 0, "Data corruption detected.");
    }
    tbl.push_back(std::move(row));
  }
  *out = std::move(tbl);
  return Status::OK();
}

CreateCacheRequest::CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
                                       CreateCacheRequest::CreateCacheFlag flag)
    : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag), cc_(cc) {
  // Type has been set already in the base constructor. So we need to fill in the connection info.
  // On successful return, we will get the connection id
  rq_.mutable_connection_info()->operator=(cinfo);
}

Status CreateCacheRequest::Prepare() {
  try {
    flatbuffers::FlatBufferBuilder fbb;
    CreateCacheRequestMsgBuilder bld(fbb);
    bld.add_cache_mem_sz(cache_mem_sz_);
    bld.add_flag(static_cast<uint32_t>(flag_));
    auto off = bld.Finish();
    fbb.Finish(off);
    rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
  }
}

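// Unpack the CreateCacheReplyMsg: the assigned connection id, cookie, and client id, plus the
// list of cpu ids the server recommends this client bind itself to.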
Status CreateCacheRequest::PostReply() {
  auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
  cc_->server_connection_id_ = p->connection_id();
  cc_->cookie_ = p->cookie()->str();
  cc_->client_id_ = p->client_id();
  // Next is a set of cpu id that we should re-adjust ourselves for better affinity.
  auto sz = p->cpu_id()->size();
  cc_->cpu_list_.reserve(sz);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
  std::string c_list;
  cpu_set_t cpu_set;
  CPU_ZERO(&cpu_set);
#endif
  for (uint32_t i = 0; i < sz; ++i) {
    auto cpu_id = p->cpu_id()->Get(i);
    cc_->cpu_list_.push_back(cpu_id);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
    c_list += std::to_string(cpu_id) + " ";
    CPU_SET(cpu_id, &cpu_set);
#endif
  }

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
  if (sz > 0) {
    auto err = sched_setaffinity(getpid(), sizeof(cpu_set), &cpu_set);
    if (err == -1) {
      RETURN_STATUS_UNEXPECTED("Unable to set affinity. Errno = " + std::to_string(errno));
    }
    MS_LOG(INFO) << "Changing cpu affinity to the following list of cpu id: " + c_list;
  }
#endif

  return Status::OK();
}

Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
  try {
    flatbuffers::FlatBufferBuilder fbb;
    std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
    v.reserve(map.size());
    for (auto &column : map) {
      auto c = CreateColumnNameMsg(fbb, fbb.CreateString(column.first), column.second);
      v.push_back(c);
    }
    auto v_off = fbb.CreateVector(v);
    auto final_off = CreateSchemaMsg(fbb, v_off);
    fbb.Finish(final_off);
    rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
  }
}

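// Rebuild the column-name to column-id map from the SchemaMsg flatbuffer in the reply.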
Status FetchSchemaRequest::PostReply() {
  auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(reply_.result().data());
  auto v = map_msg->column();
  for (auto i = 0; i < v->size(); ++i) {
    auto col = map_msg->column()->Get(i);
    column_name_id_map_.emplace(col->name()->str(), col->id());
  }
  return Status::OK();
}

std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { return column_name_id_map_; }

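// Copy the cache service statistics out of the ServiceStatMsg reply.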
Status GetStatRequest::PostReply() {
  auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(reply_.result().data());
  stat_.num_disk_cached = msg->num_disk_cached();
  stat_.num_mem_cached = msg->num_mem_cached();
  stat_.avg_cache_sz = msg->avg_cache_sz();
  stat_.num_numa_hit = msg->num_numa_hit();
  stat_.max_row_id = msg->max_row_id();
  stat_.min_row_id = msg->min_row_id();
  stat_.cache_service_state = msg->state();
  return Status::OK();
}

Status GetCacheStateRequest::PostReply() {
  try {
    cache_service_state_ = std::stoi(reply_.result());
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

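// Walk the sessions in the reply, copying each session's id, connection id and statistics,
// then pick up the server configuration (number of workers, log level, spill directory).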
Status ListSessionsRequest::PostReply() {
  auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
  auto session_vector = msg->sessions();
  for (uint32_t i = 0; i < session_vector->size(); ++i) {
    SessionCacheInfo current_info{};
    CacheServiceStat stats{};
    auto current_session_info = session_vector->Get(i);
    current_info.session_id = current_session_info->session_id();
    current_info.connection_id = current_session_info->connection_id();
    stats.num_mem_cached = current_session_info->stats()->num_mem_cached();
    stats.num_disk_cached = current_session_info->stats()->num_disk_cached();
    stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz();
    stats.num_numa_hit = current_session_info->stats()->num_numa_hit();
    stats.min_row_id = current_session_info->stats()->min_row_id();
    stats.max_row_id = current_session_info->stats()->max_row_id();
    stats.cache_service_state = current_session_info->stats()->state();
    current_info.stats = stats;  // fixed length struct.  = operator is safe
    session_info_list_.push_back(current_info);
  }
  server_cfg_.num_workers = msg->num_workers();
  server_cfg_.log_level = msg->log_level();
  server_cfg_.spill_dir = msg->spill_dir()->str();
  return Status::OK();
}

Status ServerStopRequest::PostReply() {
  CHECK_FAIL_RETURN_UNEXPECTED(strcmp(reply_.result().data(), "OK") == 0, "Not the right response");
  return Status::OK();
}

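// Identify a block of rows for the server: the client cookie, the starting address (offset) of
// the block, and the number of rows in it, all carried as strings in the request buffers.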
BatchCacheRowsRequest::BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele)
    : BaseRequest(RequestType::kBatchCacheRows) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.add_buf_data(cc->cookie());
  rq_.add_buf_data(std::to_string(addr));
  rq_.add_buf_data(std::to_string(num_ele));
}
}  // namespace dataset
}  // namespace mindspore