/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_request.h"
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
#include <sched.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <thread>
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#undef BitTest
namespace mindspore {
namespace dataset {
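// Block until the asynchronous reply for this request arrives, convert the server-side return code
// into a local Status, then give the subclass a chance to unpack the payload via PostReply().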
Status BaseRequest::Wait() {
  RETURN_IF_NOT_OK(wp_.Wait());
  Status remote_rc(static_cast<StatusCode>(reply_.rc()), reply_.msg());
  RETURN_IF_NOT_OK(remote_rc);
  // Any extra work to do before we return to the client.
  RETURN_IF_NOT_OK(PostReply());
  return Status::OK();
}
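
// Serialize one TensorRow into the request. Two transports are possible: if the client and server can
// share memory and the row is at least kLocalByPassThreshold bytes, the row header and tensor data are
// copied into a shared memory block allocated from the server; otherwise everything is appended to the
// rpc payload as additional buffers. A rough usage sketch (accessor names outside this file, e.g. one
// returning row_id_from_server_, are assumptions based on the surrounding cache client code):
//
//   auto rq = std::make_shared<CacheRowRequest>(cc);
//   RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(cc, row));
//   RETURN_IF_NOT_OK(cc->PushRequest(rq));  // hand off to the CacheClient's rpc queue
//   RETURN_IF_NOT_OK(rq->Wait());           // PostReply() then captures the row id assigned by the server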
Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row) {
  CHECK_FAIL_RETURN_UNEXPECTED(row.size() > 0, "Empty tensor row");
  CHECK_FAIL_RETURN_UNEXPECTED(cc->SupportLocalClient() == support_local_bypass_, "Local bypass mismatch");
  // Calculate how many bytes (not counting the cookie) we are sending to the server. We only
  // use shared memory (if supported) if we exceed a certain amount.
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
  RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
  sz_ += fbb->GetSize();
  for (const auto &ts : row) {
    sz_ += ts->SizeInBytes();
  }
  bool sent_using_local_bypass = support_local_bypass_ ? (sz_ >= kLocalByPassThreshold) : false;
  uint32_t flag = 0;
  if (support_local_bypass_) {
    BitSet(&flag, kLocalClientSupport);
  }
  if (sent_using_local_bypass) {
    BitSet(&flag, kDataIsInSharedMemory);
  }
  rq_.set_flag(flag);
  if (sent_using_local_bypass) {
    MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
    // Allocate shared memory from the server
    auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), cc->GetClientId(), sz_);
    RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
    RETURN_IF_NOT_OK(mem_rq->Wait());
    addr_ = mem_rq->GetAddr();
    // Now we need to add that to the base address of where we attach.
    auto base = cc->SharedMemoryBaseAddr();
    auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr_);
    // Now we copy the data onto shared memory.
    WritableSlice all(p, sz_);
    auto offset = fbb->GetSize();
    ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize());
    Status copy_rc = WritableSlice::Copy(&all, header);
    if (copy_rc.IsOk()) {
      for (const auto &ts : row) {
        WritableSlice row_data(all, offset, ts->SizeInBytes());
        ReadableSlice src(ts->GetBuffer(), ts->SizeInBytes());
        copy_rc = WritableSlice::Copy(&row_data, src);
        if (copy_rc.IsError()) {
          break;
        }
        offset += ts->SizeInBytes();
      }
      // Fill in where to find the data
      AddDataLocation();
    }
    if (copy_rc.IsError()) {
      // We need to return the memory back to the server
      auto mfree_req = GenerateFreeBlockRequest();
      Status rc = cc->PushRequest(mfree_req);
      // But we won't wait for the result for the sake of performance.
      if (rc.IsError()) {
        MS_LOG(ERROR) << "Push request for free memory failed.";
      }
      return copy_rc;
    }
  } else {
    // We have already filled the first buffer, which is the cookie.
    sz_ += rq_.buf_data(0).size();
    rq_.add_buf_data(fbb->GetBufferPointer(), fbb->GetSize());
    for (const auto &ts : row) {
      rq_.add_buf_data(ts->GetBuffer(), ts->SizeInBytes());
    }
    MS_LOG(DEBUG) << "Sending " << sz_ << " bytes of tensor data in " << rq_.buf_data_size() << " segments";
  }
  return Status::OK();
}

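// The server replies with the row id it assigned, encoded as decimal text in the result field.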
Status CacheRowRequest::PostReply() {
  if (!reply_.result().empty()) {
    row_id_from_server_ = strtoll(reply_.result().data(), nullptr, kDecimal);
  }
  return Status::OK();
}

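// Validate that the request carries the expected number of buffers for the chosen transport.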
Status CacheRowRequest::Prepare() {
  if (BitTest(rq_.flag(), kDataIsInSharedMemory)) {
    // First one is the cookie, followed by the address and then the size.
    constexpr int32_t kExpectedBufDataSize = 3;
    CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == kExpectedBufDataSize, "Incomplete rpc data");
  } else {
    // First one is the cookie. The second one is the flatbuffer header, followed by a number of buffers.
    // But we are not going to decode them to verify.
    constexpr int32_t kMinBufDataSize = 3;
    CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= kMinBufDataSize, "Incomplete rpc data");
  }
  return Status::OK();
}

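// The cookie is always the first buffer; SerializeCacheRowRequest() above accounts for its size separately.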
CacheRowRequest::CacheRowRequest(const CacheClient *cc)
    : BaseRequest(RequestType::kCacheRow),
      support_local_bypass_(cc->local_bypass_),
      addr_(-1),
      sz_(0),
      row_id_from_server_(-1) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.add_buf_data(cc->cookie_);
}

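// A batch fetch carries only the list of row ids to retrieve, packed into a TensorRowIds flatbuffer.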
BatchFetchRequest::BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id)
    : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(cc->local_bypass_), row_id_(row_id) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0);
  // Convert the row ids into a flatbuffer
  flatbuffers::FlatBufferBuilder fbb;
  auto rows_off = fbb.CreateVector(row_id);
  TensorRowIdsBuilder bld(fbb);
  bld.add_row_id(rows_off);
  auto off = bld.Finish();
  fbb.Finish(off);
  rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
}

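// Rebuild TensorRows from the reply. Whether the data sits in shared memory or in the rpc result string,
// it uses the same layout:
//   offset_array : (num_elements + 1) int64 values; row i occupies bytes [offset[i], offset[i+1])
//   each row     : a TensorRowHeaderMsg flatbuffer (size_of_this() bytes) followed by the raw tensor buffers
// An empty slot (offset[i] == offset[i+1]) produces an empty TensorRow that still carries its row id.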
Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr) {
  RETURN_UNEXPECTED_IF_NULL(out);
  RETURN_UNEXPECTED_IF_NULL(out_addr);
  auto num_elements = row_id_.size();
  const char *ptr = nullptr;
  int64_t sz = 0;
  // Tap into the reply flag to see where we can find the data. The server may decide the amount is
  // so small that it doesn't use the shared memory method.
  auto flag = reply_.flag();
  bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false;
  if (dataOnSharedMemory) {
    auto addr = strtoll(reply_.result().data(), nullptr, kDecimal);
    ptr = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(baseAddr) + addr);
    *out_addr = addr;
  } else {
    ptr = reply_.result().data();
    *out_addr = -1;
  }
  auto *offset_array = reinterpret_cast<const int64_t *>(ptr);
  sz = offset_array[num_elements];
  CHECK_FAIL_RETURN_UNEXPECTED(support_local_bypass_ || sz == reply_.result().length(), "Length mismatch");
  TensorTable tbl;
  tbl.reserve(num_elements);
  ReadableSlice all(ptr, sz);
  for (size_t i = 0; i < num_elements; ++i) {
    auto len = offset_array[i + 1] - offset_array[i];
    TensorRow row;
    row.setId(row_id_.at(i));
    if (len > 0) {
      ReadableSlice row_data(all, offset_array[i], len);
      // Next we de-serialize the flatbuffer to get back each column
      auto msg = GetTensorRowHeaderMsg(row_data.GetPointer());
      auto msg_sz = msg->size_of_this();
      // Start of the tensor data
      auto ts_offset = msg_sz;
      row.reserve(msg->column()->size());
      for (uint32_t k = 0; k < msg->column()->size(); ++k) {
        auto col_ts = msg->column()->Get(k);
        std::shared_ptr<Tensor> ts;
        ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
        RETURN_IF_NOT_OK(mindspore::dataset::RestoreOneTensor(col_ts, data, &ts));
        row.push_back(ts);
        ts_offset += data.GetSize();
      }
    } else {
      CHECK_FAIL_RETURN_UNEXPECTED(len == 0, "Data corruption detected.");
    }
    tbl.push_back(std::move(row));
  }
  *out = std::move(tbl);
  return Status::OK();
}

CreateCacheRequest::CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
                                       CreateCacheRequest::CreateCacheFlag flag)
    : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag), cc_(cc) {
  // Type has been set already in the base constructor. So we need to fill in the connection info.
  // On successful return, we will get the connection id.
  rq_.mutable_connection_info()->operator=(cinfo);
}

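// Pack the requested cache size and the creation flags into a CreateCacheRequestMsg flatbuffer.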
Status CreateCacheRequest::Prepare() {
  try {
    flatbuffers::FlatBufferBuilder fbb;
    CreateCacheRequestMsgBuilder bld(fbb);
    bld.add_cache_mem_sz(cache_mem_sz_);
    bld.add_flag(static_cast<uint32_t>(flag_));
    auto off = bld.Finish();
    fbb.Finish(off);
    rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    RETURN_STATUS_OOM("Out of memory.");
  }
}

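// The reply carries the connection id, the cookie, our client id, and the list of cpu ids the server
// suggests for this client. On Linux we also pin the process to those cpus.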
Status CreateCacheRequest::PostReply() {
  auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
  cc_->server_connection_id_ = p->connection_id();
  cc_->cookie_ = p->cookie()->str();
  cc_->client_id_ = p->client_id();
  // Next is a set of cpu ids to which we should re-adjust our affinity.
  auto sz = p->cpu_id()->size();
  cc_->cpu_list_.reserve(sz);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
  std::string c_list;
  cpu_set_t cpu_set;
  CPU_ZERO(&cpu_set);
#endif
  for (uint32_t i = 0; i < sz; ++i) {
    auto cpu_id = p->cpu_id()->Get(i);
    cc_->cpu_list_.push_back(cpu_id);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
    c_list += std::to_string(cpu_id) + " ";
    CPU_SET(cpu_id, &cpu_set);
#endif
  }

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
  if (sz > 0) {
    auto err = sched_setaffinity(getpid(), sizeof(cpu_set), &cpu_set);
    if (err == -1) {
      RETURN_STATUS_UNEXPECTED("Unable to set affinity. Errno = " + std::to_string(errno));
    }
    MS_LOG(INFO) << "Changing cpu affinity to the following list of cpu id: " + c_list;
  }
#endif

  return Status::OK();
}

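// Serialize the column-name-to-id map into a SchemaMsg flatbuffer.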
Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
  try {
    flatbuffers::FlatBufferBuilder fbb;
    std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
    v.reserve(map.size());
    for (const auto &column : map) {
      auto c = CreateColumnNameMsg(fbb, fbb.CreateString(column.first), column.second);
      v.push_back(c);
    }
    auto v_off = fbb.CreateVector(v);
    auto final_off = CreateSchemaMsg(fbb, v_off);
    fbb.Finish(final_off);
    rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    RETURN_STATUS_OOM("Out of memory.");
  }
}

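// The inverse of CacheSchemaRequest: unpack the SchemaMsg reply back into a column-name-to-id map.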
Status FetchSchemaRequest::PostReply() {
  auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(reply_.result().data());
  auto v = map_msg->column();
  for (uint32_t i = 0; i < v->size(); ++i) {
    auto col = map_msg->column()->Get(i);
    column_name_id_map_.emplace(col->name()->str(), col->id());
  }
  return Status::OK();
}

std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { return column_name_id_map_; }

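// Copy the ServiceStatMsg fields from the reply into the local stat structure.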
Status GetStatRequest::PostReply() {
  auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(reply_.result().data());
  stat_.num_disk_cached = msg->num_disk_cached();
  stat_.num_mem_cached = msg->num_mem_cached();
  stat_.avg_cache_sz = msg->avg_cache_sz();
  stat_.num_numa_hit = msg->num_numa_hit();
  stat_.max_row_id = msg->max_row_id();
  stat_.min_row_id = msg->min_row_id();
  stat_.cache_service_state = msg->state();
  return Status::OK();
}

Status GetCacheStateRequest::PostReply() {
  try {
    cache_service_state_ = std::stoi(reply_.result());
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

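// Unpack one SessionCacheInfo (plus its per-service stats) for each session in the reply, then pick up
// the server-wide configuration (number of workers, log level, spill directory).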
Status ListSessionsRequest::PostReply() {
  auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
  auto session_vector = msg->sessions();
  for (uint32_t i = 0; i < session_vector->size(); ++i) {
    SessionCacheInfo current_info{};
    CacheServiceStat stats{};
    auto current_session_info = session_vector->Get(i);
    current_info.session_id = current_session_info->session_id();
    current_info.connection_id = current_session_info->connection_id();
    stats.num_mem_cached = current_session_info->stats()->num_mem_cached();
    stats.num_disk_cached = current_session_info->stats()->num_disk_cached();
    stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz();
    stats.num_numa_hit = current_session_info->stats()->num_numa_hit();
    stats.min_row_id = current_session_info->stats()->min_row_id();
    stats.max_row_id = current_session_info->stats()->max_row_id();
    stats.cache_service_state = current_session_info->stats()->state();
    current_info.stats = stats;  // Fixed-length struct, so the assignment operator is safe.
    session_info_list_.push_back(current_info);
  }
  server_cfg_.num_workers = msg->num_workers();
  server_cfg_.log_level = msg->log_level();
  server_cfg_.spill_dir = msg->spill_dir()->str();
  return Status::OK();
}

Status ServerStopRequest::PostReply() {
  CHECK_FAIL_RETURN_UNEXPECTED(strcmp(reply_.result().data(), "OK") == 0, "Not the right response");
  return Status::OK();
}

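// The request itself sends only the cookie, an address, and the element count as text buffers; the row
// data is expected to be reachable at that address (presumably within the shared memory arena) on the
// server side, so nothing bulky travels over the rpc channel here.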
BatchCacheRowsRequest::BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele)
    : BaseRequest(RequestType::kBatchCacheRows) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.add_buf_data(cc->cookie());
  rq_.add_buf_data(std::to_string(addr));
  rq_.add_buf_data(std::to_string(num_ele));
}
}  // namespace dataset
}  // namespace mindspore