/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_request.h"
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
#include <sched.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#include <cstdlib>
#include <cstring>
#include <thread>
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
namespace mindspore {
namespace dataset {
Status BaseRequest::Wait() {
  RETURN_IF_NOT_OK(wp_.Wait());
  Status remote_rc(static_cast<StatusCode>(reply_.rc()), reply_.msg());
  RETURN_IF_NOT_OK(remote_rc);
  // Any extra work to do before we return back to the client.
  RETURN_IF_NOT_OK(PostReply());
  return Status::OK();
}
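// Serialize a TensorRow for the cache server. Depending on the total payload size and whether local
// bypass is available, the row is either copied into a shared memory block allocated by the server or
// embedded directly into the rpc request buffers.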
Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row) {
  CHECK_FAIL_RETURN_UNEXPECTED(row.size() > 0, "Empty tensor row");
  CHECK_FAIL_RETURN_UNEXPECTED(cc->SupportLocalClient() == support_local_bypass_, "Local bypass mismatch");
  // Calculate how many bytes (not counting the cookie) we are sending to the server. We only
  // use shared memory (if supported) if we exceed a certain amount.
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
  RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
  sz_ += fbb->GetSize();
  for (const auto &ts : row) {
    sz_ += ts->SizeInBytes();
  }
  bool sent_using_local_bypass = support_local_bypass_ ? (sz_ >= kLocalByPassThreshold) : false;
  uint32_t flag = 0;
  if (support_local_bypass_) {
    BitSet(&flag, kLocalClientSupport);
  }
  if (sent_using_local_bypass) {
    BitSet(&flag, kDataIsInSharedMemory);
  }
  rq_.set_flag(flag);
  if (sent_using_local_bypass) {
    MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
    // Allocate shared memory from the server
    auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), cc->GetClientId(), sz_);
    RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
    RETURN_IF_NOT_OK(mem_rq->Wait());
    addr_ = mem_rq->GetAddr();
    // Now we need to add that to the base address of where we attach.
    auto base = cc->SharedMemoryBaseAddr();
    auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr_);
    // Now we copy the data onto shared memory.
    WritableSlice all(p, sz_);
    auto offset = fbb->GetSize();
    ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize());
    Status copy_rc = WritableSlice::Copy(&all, header);
    if (copy_rc.IsOk()) {
      for (const auto &ts : row) {
        WritableSlice row_data(all, offset, ts->SizeInBytes());
        ReadableSlice src(ts->GetBuffer(), ts->SizeInBytes());
        copy_rc = WritableSlice::Copy(&row_data, src);
        if (copy_rc.IsError()) {
          break;
        }
        offset += ts->SizeInBytes();
      }
      // Fill in where to find the data
      AddDataLocation();
    }
    if (copy_rc.IsError()) {
      // We need to return the memory back to the server
      auto mfree_req = GenerateFreeBlockRequest();
      Status rc = cc->PushRequest(mfree_req);
      // But we won't wait for the result for the sake of performance.
      if (rc.IsError()) {
        MS_LOG(ERROR) << "Push request for free memory failed.";
      }
      return copy_rc;
    }
  } else {
    // We have already filled the first buffer which is the cookie.
    sz_ += rq_.buf_data(0).size();
    rq_.add_buf_data(fbb->GetBufferPointer(), fbb->GetSize());
    for (const auto &ts : row) {
      rq_.add_buf_data(ts->GetBuffer(), ts->SizeInBytes());
    }
    MS_LOG(DEBUG) << "Sending " << sz_ << " bytes of tensor data in " << rq_.buf_data_size() << " segments";
  }
  return Status::OK();
}

Status CacheRowRequest::PostReply() {
  if (!reply_.result().empty()) {
    row_id_from_server_ = strtoll(reply_.result().data(), nullptr, kDecimal);
  }
  return Status::OK();
}

Status CacheRowRequest::Prepare() {
  if (BitTest(rq_.flag(), kDataIsInSharedMemory)) {
    // The first buffer is the cookie, followed by the shared memory address and then the size.
    constexpr int32_t kExpectedBufDataSize = 3;
    CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == kExpectedBufDataSize, "Incomplete rpc data");
  } else {
    // The first buffer is the cookie and the second is the flatbuffer header, followed by a number of
    // data buffers. But we are not going to decode them to verify.
    constexpr int32_t kMinBufDataSize = 3;
    CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= kMinBufDataSize, "Incomplete rpc data");
  }
  return Status::OK();
}

CacheRowRequest::CacheRowRequest(const CacheClient *cc)
    : BaseRequest(RequestType::kCacheRow),
      support_local_bypass_(cc->local_bypass_),
      addr_(-1),
      sz_(0),
      row_id_from_server_(-1) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.add_buf_data(cc->cookie_);
}

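// A batch fetch request carries the list of row ids to fetch, serialized into a flatbuffer and appended
// to the rpc buffers.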
BatchFetchRequest::BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id)
    : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(cc->local_bypass_), row_id_(row_id) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0);
  // Convert the row id into a flatbuffer
  flatbuffers::FlatBufferBuilder fbb;
  auto off_t = fbb.CreateVector(row_id);
  TensorRowIdsBuilder bld(fbb);
  bld.add_row_id(off_t);
  auto off = bld.Finish();
  fbb.Finish(off);
  rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
}

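// Rebuild the TensorTable from the reply. The server either returns the serialized rows inline in the
// rpc result or, when local bypass is in effect, returns the offset of a shared memory block. In both
// cases the payload starts with an array of (num_elements + 1) offsets delimiting each row.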
Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr) {
  RETURN_UNEXPECTED_IF_NULL(out);
  RETURN_UNEXPECTED_IF_NULL(out_addr);
  auto num_elements = row_id_.size();
  const char *ptr = nullptr;
  int64_t sz = 0;
  // Tap into the reply flag to see where we can find the data. The server may decide the amount is
  // so small that it doesn't use the shared memory method.
  auto flag = reply_.flag();
  bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false;
  if (dataOnSharedMemory) {
    auto addr = strtoll(reply_.result().data(), nullptr, kDecimal);
    ptr = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(baseAddr) + addr);
    *out_addr = addr;
  } else {
    ptr = reply_.result().data();
    *out_addr = -1;
  }
  auto *offset_array = reinterpret_cast<const int64_t *>(ptr);
  sz = offset_array[num_elements];
  CHECK_FAIL_RETURN_UNEXPECTED(support_local_bypass_ || sz == reply_.result().length(), "Length mismatch");
  TensorTable tbl;
  tbl.reserve(num_elements);
  ReadableSlice all(ptr, sz);
  for (auto i = 0; i < num_elements; ++i) {
    auto len = offset_array[i + 1] - offset_array[i];
    TensorRow row;
    row.setId(row_id_.at(i));
    if (len > 0) {
      ReadableSlice row_data(all, offset_array[i], len);
      // Next we de-serialize the flatbuffer to get back each column
      auto msg = GetTensorRowHeaderMsg(row_data.GetPointer());
      auto msg_sz = msg->size_of_this();
      // Start of the tensor data
      auto ts_offset = msg_sz;
      row.reserve(msg->column()->size());
      for (auto k = 0; k < msg->column()->size(); ++k) {
        auto col_ts = msg->column()->Get(k);
        std::shared_ptr<Tensor> ts;
        ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
        RETURN_IF_NOT_OK(mindspore::dataset::RestoreOneTensor(col_ts, data, &ts));
        row.push_back(ts);
        ts_offset += data.GetSize();
      }
    } else {
      CHECK_FAIL_RETURN_UNEXPECTED(len == 0, "Data corruption detected.");
    }
    tbl.push_back(std::move(row));
  }
  *out = std::move(tbl);
  return Status::OK();
}

CreateCacheRequest::CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
                                       CreateCacheRequest::CreateCacheFlag flag)
    : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag), cc_(cc) {
  // Type has been set already in the base constructor. So we need to fill in the connection info.
  // On successful return, we will get the connection id.
  rq_.mutable_connection_info()->operator=(cinfo);
}

Status CreateCacheRequest::Prepare() {
  try {
    flatbuffers::FlatBufferBuilder fbb;
    CreateCacheRequestMsgBuilder bld(fbb);
    bld.add_cache_mem_sz(cache_mem_sz_);
    bld.add_flag(static_cast<uint32_t>(flag_));
    auto off = bld.Finish();
    fbb.Finish(off);
    rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
  }
}

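// On a successful create, the server replies with the connection id, cookie, client id and a list of
// cpu ids. On platforms that support it, we pin the process to those cpus via sched_setaffinity.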
Status CreateCacheRequest::PostReply() {
  auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
  cc_->server_connection_id_ = p->connection_id();
  cc_->cookie_ = p->cookie()->str();
  cc_->client_id_ = p->client_id();
  // Next is a set of cpu ids to which we should re-adjust ourselves for better affinity.
  auto sz = p->cpu_id()->size();
  cc_->cpu_list_.reserve(sz);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
  std::string c_list;
  cpu_set_t cpu_set;
  CPU_ZERO(&cpu_set);
#endif
  for (uint32_t i = 0; i < sz; ++i) {
    auto cpu_id = p->cpu_id()->Get(i);
    cc_->cpu_list_.push_back(cpu_id);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
    c_list += std::to_string(cpu_id) + " ";
    CPU_SET(cpu_id, &cpu_set);
#endif
  }

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
  if (sz > 0) {
    auto err = sched_setaffinity(getpid(), sizeof(cpu_set), &cpu_set);
    if (err == -1) {
      RETURN_STATUS_UNEXPECTED("Unable to set affinity. Errno = " + std::to_string(errno));
    }
    MS_LOG(INFO) << "Changing cpu affinity to the following list of cpu id: " + c_list;
  }
#endif

  return Status::OK();
}

Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
  try {
    flatbuffers::FlatBufferBuilder fbb;
    std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
    v.reserve(map.size());
    for (auto &column : map) {
      auto c = CreateColumnNameMsg(fbb, fbb.CreateString(column.first), column.second);
      v.push_back(c);
    }
    auto v_off = fbb.CreateVector(v);
    auto final_off = CreateSchemaMsg(fbb, v_off);
    fbb.Finish(final_off);
    rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__);
  }
}

Status FetchSchemaRequest::PostReply() {
  auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(reply_.result().data());
  auto v = map_msg->column();
  for (auto i = 0; i < v->size(); ++i) {
    auto col = map_msg->column()->Get(i);
    column_name_id_map_.emplace(col->name()->str(), col->id());
  }
  return Status::OK();
}

std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { return column_name_id_map_; }

Status GetStatRequest::PostReply() {
  auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(reply_.result().data());
  stat_.num_disk_cached = msg->num_disk_cached();
  stat_.num_mem_cached = msg->num_mem_cached();
  stat_.avg_cache_sz = msg->avg_cache_sz();
  stat_.num_numa_hit = msg->num_numa_hit();
  stat_.max_row_id = msg->max_row_id();
  stat_.min_row_id = msg->min_row_id();
  stat_.cache_service_state = msg->state();
  return Status::OK();
}

Status GetCacheStateRequest::PostReply() {
  try {
    cache_service_state_ = std::stoi(reply_.result());
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

Status ListSessionsRequest::PostReply() {
  auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
  auto session_vector = msg->sessions();
  for (uint32_t i = 0; i < session_vector->size(); ++i) {
    SessionCacheInfo current_info{};
    CacheServiceStat stats{};
    auto current_session_info = session_vector->Get(i);
    current_info.session_id = current_session_info->session_id();
    current_info.connection_id = current_session_info->connection_id();
    stats.num_mem_cached = current_session_info->stats()->num_mem_cached();
    stats.num_disk_cached = current_session_info->stats()->num_disk_cached();
    stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz();
    stats.num_numa_hit = current_session_info->stats()->num_numa_hit();
    stats.min_row_id = current_session_info->stats()->min_row_id();
    stats.max_row_id = current_session_info->stats()->max_row_id();
    stats.cache_service_state = current_session_info->stats()->state();
    current_info.stats = stats;  // fixed length struct; operator= is safe
    session_info_list_.push_back(current_info);
  }
  server_cfg_.num_workers = msg->num_workers();
  server_cfg_.log_level = msg->log_level();
  server_cfg_.spill_dir = msg->spill_dir()->str();
  return Status::OK();
}

Status ServerStopRequest::PostReply() {
  CHECK_FAIL_RETURN_UNEXPECTED(strcmp(reply_.result().data(), "OK") == 0, "Not the right response");
  return Status::OK();
}

BatchCacheRowsRequest::BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele)
    : BaseRequest(RequestType::kBatchCacheRows) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.add_buf_data(cc->cookie());
  rq_.add_buf_data(std::to_string(addr));
  rq_.add_buf_data(std::to_string(num_ele));
}
}  // namespace dataset
}  // namespace mindspore