/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_request.h"
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
#include <sched.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#include <cstdlib>
#include <cstring>
#include <thread>
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#undef BitTest
namespace mindspore {
namespace dataset {
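// Block until the server has sent back a reply, then surface the remote return code (plus any
// request-specific post-processing) as a local Status.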
Status BaseRequest::Wait() {
  RETURN_IF_NOT_OK(wp_.Wait());
  Status remote_rc(static_cast<StatusCode>(reply_.rc()), reply_.msg());
  RETURN_IF_NOT_OK(remote_rc);
  // Any extra work to do before we return back to the client.
  RETURN_IF_NOT_OK(PostReply());
  return Status::OK();
}
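// Serialize a TensorRow into the outgoing request. Small rows travel inline in the rpc buffers;
// large rows are copied into shared memory obtained from the server when local bypass is supported.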
Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row) {
  CHECK_FAIL_RETURN_UNEXPECTED(row.size() > 0, "Empty tensor row");
  CHECK_FAIL_RETURN_UNEXPECTED(cc->SupportLocalClient() == support_local_bypass_, "Local bypass mismatch");
  // Calculate how many bytes (not counting the cookie) we are sending to the server. We only
  // use shared memory (if supported) if we exceed a certain amount.
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
  RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
  sz_ += fbb->GetSize();
  for (const auto &ts : row) {
    sz_ += ts->SizeInBytes();
  }
  bool sent_using_local_bypass = support_local_bypass_ ? (sz_ >= kLocalByPassThreshold) : false;
  uint32_t flag = 0;
  if (support_local_bypass_) {
    BitSet(&flag, kLocalClientSupport);
  }
  if (sent_using_local_bypass) {
    BitSet(&flag, kDataIsInSharedMemory);
  }
  rq_.set_flag(flag);
  if (sent_using_local_bypass) {
    MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
    // Allocate shared memory from the server.
    auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), cc->GetClientId(), sz_);
    RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
    RETURN_IF_NOT_OK(mem_rq->Wait());
    addr_ = mem_rq->GetAddr();
    // Now we need to add that to the base address of where we attach.
    auto base = cc->SharedMemoryBaseAddr();
    auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr_);
    // Now we copy the data onto shared memory.
    WritableSlice all(p, sz_);
    auto offset = fbb->GetSize();
    ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize());
    Status copy_rc = WritableSlice::Copy(&all, header);
    if (copy_rc.IsOk()) {
      for (const auto &ts : row) {
        WritableSlice row_data(all, offset, ts->SizeInBytes());
        ReadableSlice src(ts->GetBuffer(), ts->SizeInBytes());
        copy_rc = WritableSlice::Copy(&row_data, src);
        if (copy_rc.IsError()) {
          break;
        }
        offset += ts->SizeInBytes();
      }
      // Fill in where to find the data.
      AddDataLocation();
    }
    if (copy_rc.IsError()) {
      // We need to return the memory back to the server.
      auto mfree_req = GenerateFreeBlockRequest();
      Status rc = cc->PushRequest(mfree_req);
      // But we won't wait for the result for the sake of performance.
      if (rc.IsError()) {
        MS_LOG(ERROR) << "Push request for free memory failed.";
      }
      return copy_rc;
    }
  } else {
    // We have already filled the first buffer which is the cookie.
    sz_ += rq_.buf_data(0).size();
    rq_.add_buf_data(fbb->GetBufferPointer(), fbb->GetSize());
    for (const auto &ts : row) {
      rq_.add_buf_data(ts->GetBuffer(), ts->SizeInBytes());
    }
    MS_LOG(DEBUG) << "Sending " << sz_ << " bytes of tensor data in " << rq_.buf_data_size() << " segments";
  }
  return Status::OK();
}

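// On success the server sends back the row id it assigned to the cached row.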
Status CacheRowRequest::PostReply() {
  if (!reply_.result().empty()) {
    row_id_from_server_ = strtoll(reply_.result().data(), nullptr, kDecimal);
  }
  return Status::OK();
}

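// Sanity-check the request buffers before sending: the expected layout differs depending on whether
// the row data was placed in shared memory or sent inline.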
Status CacheRowRequest::Prepare() {
  if (BitTest(rq_.flag(), kDataIsInSharedMemory)) {
    // First one is cookie, followed by address and then size.
    constexpr int32_t kExpectedBufDataSize = 3;
    CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == kExpectedBufDataSize, "Incomplete rpc data");
  } else {
    // First one is cookie. 2nd one is the google flat buffers followed by a number of buffers.
    // But we are not going to decode them to verify.
    constexpr int32_t kMinBufDataSize = 3;
    CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= kMinBufDataSize, "Incomplete rpc data");
  }
  return Status::OK();
}

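// The first rpc buffer always carries the cookie the server handed out when the cache was created.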
CacheRowRequest::CacheRowRequest(const CacheClient *cc)
    : BaseRequest(RequestType::kCacheRow),
      support_local_bypass_(cc->local_bypass_),
      addr_(-1),
      sz_(0),
      row_id_from_server_(-1) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.add_buf_data(cc->cookie_);
}

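// Ask the server for a batch of rows identified by row id. The ids are packed into a flatbuffer
// that rides in the first rpc buffer.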
BatchFetchRequest::BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id)
    : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(cc->local_bypass_), row_id_(row_id) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0);
  // Convert the row id into a flatbuffer
  flatbuffers::FlatBufferBuilder fbb;
  auto off_t = fbb.CreateVector(row_id);
  TensorRowIdsBuilder bld(fbb);
  bld.add_row_id(off_t);
  auto off = bld.Finish();
  fbb.Finish(off);
  rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
}

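// Unpack the server reply into a TensorTable. Depending on the reply flag, the serialized rows live
// either in shared memory (local bypass) or directly in the rpc result buffer.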
Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr) {
  RETURN_UNEXPECTED_IF_NULL(out);
  auto num_elements = row_id_.size();
  const char *ptr = nullptr;
  int64_t sz = 0;
  // Tap into the reply flag to see where we can find the data. The server may decide the amount is
  // so small that it doesn't use the shared memory method.
  auto flag = reply_.flag();
  bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false;
  if (dataOnSharedMemory) {
    auto addr = strtoll(reply_.result().data(), nullptr, kDecimal);
    ptr = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(baseAddr) + addr);
    // Guard against a null out_addr before writing the shared memory address back to the caller.
    RETURN_UNEXPECTED_IF_NULL(out_addr);
    *out_addr = addr;
  } else {
    ptr = reply_.result().data();
    *out_addr = -1;
  }
  auto *offset_array = reinterpret_cast<const int64_t *>(ptr);
  sz = offset_array[num_elements];
  CHECK_FAIL_RETURN_UNEXPECTED(support_local_bypass_ || sz == reply_.result().length(), "Length mismatch");
  TensorTable tbl;
  tbl.reserve(num_elements);
  ReadableSlice all(ptr, sz);
  for (auto i = 0; i < num_elements; ++i) {
    auto len = offset_array[i + 1] - offset_array[i];
    TensorRow row;
    row.setId(row_id_.at(i));
    if (len > 0) {
      ReadableSlice row_data(all, offset_array[i], len);
      // Next we de-serialize flat buffer to get back each column
      auto msg = GetTensorRowHeaderMsg(row_data.GetPointer());
      auto msg_sz = msg->size_of_this();
      // Start of the tensor data
      auto ts_offset = msg_sz;
      row.reserve(msg->column()->size());
      for (auto k = 0; k < msg->column()->size(); ++k) {
        auto col_ts = msg->column()->Get(k);
        std::shared_ptr<Tensor> ts;
        ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
        RETURN_IF_NOT_OK(mindspore::dataset::RestoreOneTensor(col_ts, data, &ts));
        row.push_back(ts);
        ts_offset += data.GetSize();
      }
    } else {
      CHECK_FAIL_RETURN_UNEXPECTED(len == 0, "Data corruption detected.");
    }
    tbl.push_back(std::move(row));
  }
  *out = std::move(tbl);
  return Status::OK();
}

CreateCacheRequest::CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
                                       CreateCacheRequest::CreateCacheFlag flag)
    : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag), cc_(cc) {
  // Type has been set already in the base constructor. So we need to fill in the connection info.
  // On successful return, we will get the connection id.
  rq_.mutable_connection_info()->operator=(cinfo);
}

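// Pack the requested cache memory size and the creation flags into a flatbuffer before the rpc goes out.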
Status CreateCacheRequest::Prepare() {
  try {
    flatbuffers::FlatBufferBuilder fbb;
    CreateCacheRequestMsgBuilder bld(fbb);
    bld.add_cache_mem_sz(cache_mem_sz_);
    bld.add_flag(static_cast<uint32_t>(flag_));
    auto off = bld.Finish();
    fbb.Finish(off);
    rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    RETURN_STATUS_OOM("Out of memory.");
  }
}

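// Record the connection id, cookie and client id assigned by the server, and (on Linux) re-pin this
// process to the cpu list the server suggests for better affinity.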
Status CreateCacheRequest::PostReply() {
  auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
  cc_->server_connection_id_ = p->connection_id();
  cc_->cookie_ = p->cookie()->str();
  cc_->client_id_ = p->client_id();
  // Next is a set of cpu id that we should re-adjust ourselves for better affinity.
  auto sz = p->cpu_id()->size();
  cc_->cpu_list_.reserve(sz);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
  std::string c_list;
  cpu_set_t cpu_set;
  CPU_ZERO(&cpu_set);
#endif
  for (uint32_t i = 0; i < sz; ++i) {
    auto cpu_id = p->cpu_id()->Get(i);
    cc_->cpu_list_.push_back(cpu_id);
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
    c_list += std::to_string(cpu_id) + " ";
    CPU_SET(cpu_id, &cpu_set);
#endif
  }

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
  if (sz > 0) {
    auto err = sched_setaffinity(getpid(), sizeof(cpu_set), &cpu_set);
    if (err == -1) {
      RETURN_STATUS_UNEXPECTED("Unable to set affinity. Errno = " + std::to_string(errno));
    }
    MS_LOG(INFO) << "Changing cpu affinity to the following list of cpu id: " + c_list;
  }
#endif

  return Status::OK();
}

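// Serialize the column-name-to-id map into a SchemaMsg flatbuffer carried in the request.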
Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
  try {
    flatbuffers::FlatBufferBuilder fbb;
    std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
    v.reserve(map.size());
    for (auto &column : map) {
      auto c = CreateColumnNameMsg(fbb, fbb.CreateString(column.first), column.second);
      v.push_back(c);
    }
    auto v_off = fbb.CreateVector(v);
    auto final_off = CreateSchemaMsg(fbb, v_off);
    fbb.Finish(final_off);
    rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    RETURN_STATUS_OOM("Out of memory.");
  }
}

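// Decode the SchemaMsg reply back into the column-name-to-id map.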
Status FetchSchemaRequest::PostReply() {
  auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(reply_.result().data());
  auto v = map_msg->column();
  for (auto i = 0; i < v->size(); ++i) {
    auto col = map_msg->column()->Get(i);
    column_name_id_map_.emplace(col->name()->str(), col->id());
  }
  return Status::OK();
}

std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { return column_name_id_map_; }

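// Copy the per-service statistics out of the ServiceStatMsg reply.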
Status GetStatRequest::PostReply() {
  auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(reply_.result().data());
  stat_.num_disk_cached = msg->num_disk_cached();
  stat_.num_mem_cached = msg->num_mem_cached();
  stat_.avg_cache_sz = msg->avg_cache_sz();
  stat_.num_numa_hit = msg->num_numa_hit();
  stat_.max_row_id = msg->max_row_id();
  stat_.min_row_id = msg->min_row_id();
  stat_.cache_service_state = msg->state();
  return Status::OK();
}

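// The server returns the cache service state as a plain integer string.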
Status GetCacheStateRequest::PostReply() {
  try {
    cache_service_state_ = std::stoi(reply_.result());
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

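// Unpack the list of active sessions, their per-connection statistics, and the server configuration
// from the ListSessionsMsg reply.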
Status ListSessionsRequest::PostReply() {
  auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
  auto session_vector = msg->sessions();
  for (uint32_t i = 0; i < session_vector->size(); ++i) {
    SessionCacheInfo current_info{};
    CacheServiceStat stats{};
    auto current_session_info = session_vector->Get(i);
    current_info.session_id = current_session_info->session_id();
    current_info.connection_id = current_session_info->connection_id();
    stats.num_mem_cached = current_session_info->stats()->num_mem_cached();
    stats.num_disk_cached = current_session_info->stats()->num_disk_cached();
    stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz();
    stats.num_numa_hit = current_session_info->stats()->num_numa_hit();
    stats.min_row_id = current_session_info->stats()->min_row_id();
    stats.max_row_id = current_session_info->stats()->max_row_id();
    stats.cache_service_state = current_session_info->stats()->state();
    current_info.stats = stats;  // Fixed-length struct; the assignment operator is safe.
    session_info_list_.push_back(current_info);
  }
  server_cfg_.num_workers = msg->num_workers();
  server_cfg_.log_level = msg->log_level();
  server_cfg_.spill_dir = msg->spill_dir()->str();
  return Status::OK();
}

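// The stop request simply expects the literal string "OK" back from the server.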
Status ServerStopRequest::PostReply() {
  CHECK_FAIL_RETURN_UNEXPECTED(strcmp(reply_.result().data(), "OK") == 0, "Not the right response");
  return Status::OK();
}

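// Cache a batch of rows the caller has already staged in shared memory: the request carries the
// cookie, the starting address, and the number of elements, each encoded as a string.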
BatchCacheRowsRequest::BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele)
    : BaseRequest(RequestType::kBatchCacheRows) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.add_buf_data(cc->cookie());
  rq_.add_buf_data(std::to_string(addr));
  rq_.add_buf_data(std::to_string(num_ele));
}
}  // namespace dataset
}  // namespace mindspore