1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 8 * http://www.apache.org/licenses/LICENSE-2.0 9 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_REQ_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_REQ_H_ 18 19 #include <algorithm> 20 #include <memory> 21 #include <iostream> 22 #include <string> 23 #include <unordered_map> 24 #include <utility> 25 #include <vector> 26 27 #ifdef ENABLE_CACHE 28 #include "proto/cache_grpc.grpc.pb.h" 29 #endif 30 #include "proto/cache_grpc.pb.h" 31 #include "minddata/dataset/core/tensor_row.h" 32 #include "minddata/dataset/engine/cache/de_tensor_generated.h" 33 #include "minddata/dataset/util/slice.h" 34 #include "minddata/dataset/util/wait_post.h" 35 36 namespace mindspore { 37 namespace dataset { 38 class CacheClient; 39 /// \brief Statistic structure for GetStat request 40 struct CacheServiceStat { 41 int64_t num_mem_cached; 42 int64_t num_disk_cached; 43 int64_t avg_cache_sz; 44 int64_t num_numa_hit; 45 row_id_type min_row_id; 46 row_id_type max_row_id; 47 int8_t cache_service_state; 48 }; 49 50 struct CacheServerCfgInfo { 51 int32_t num_workers; 52 int8_t log_level; 53 std::string spill_dir; 54 }; 55 56 /// \brief Info structure ListSessionsRequest 57 struct SessionCacheInfo { 58 session_id_type session_id; 59 connection_id_type connection_id; 60 CacheServiceStat stats; 61 }; 62 63 /// \brief CacheClient communicates with CacheServer using Requests. 64 class BaseRequest { 65 public: 66 // Request types 67 enum class RequestType : int16_t { 68 kCacheRow = 0, 69 kBatchFetchRows = 1, 70 kCreateCache = 2, 71 kGetCacheMissKeys = 3, 72 kDestroyCache = 4, 73 kGetStat = 5, 74 kCacheSchema = 6, 75 kFetchSchema = 7, 76 kBuildPhaseDone = 8, 77 kDropSession = 9, 78 kGenerateSessionId = 10, 79 kAllocateSharedBlock = 11, 80 kFreeSharedBlock = 12, 81 kStopService = 13, 82 kHeartBeat = 14, 83 kToggleWriteMode = 15, 84 kListSessions = 16, 85 kConnectReset = 17, 86 kInternalFetchRow = 18, 87 kBatchCacheRows = 19, 88 kInternalCacheRow = 20, 89 kGetCacheState = 21, 90 // Add new request before it. 91 kRequestUnknown = 32767 92 }; 93 94 friend class CacheServer; 95 friend class CacheServerRequest; 96 friend class CacheClientGreeter; 97 friend class CacheClientRequestTag; 98 friend class CacheClient; 99 friend class CacheService; 100 friend class CacheServerGreeterImpl; 101 102 /// \brief Base class of a cache server request 103 /// \param type Type of the request BaseRequest(RequestType type)104 explicit BaseRequest(RequestType type) : type_(type) { 105 rq_.set_type(static_cast<int16_t>(type_)); 106 rq_.set_client_id(-1); 107 rq_.set_flag(0); 108 } 109 virtual ~BaseRequest() = default; 110 111 /// \brief A print method for debugging 112 /// \param out The output stream to write output to Print(std::ostream & out)113 virtual void Print(std::ostream &out) const { out << "Request type: " << static_cast<int16_t>(type_); } 114 115 /// \brief << Stream output operator overload 116 /// \param out reference to the output stream 117 /// \param rq reference to the BaseRequest 118 /// \return the output stream 119 friend std::ostream &operator<<(std::ostream &out, const BaseRequest &rq) { 120 rq.Print(out); 121 return out; 122 } 123 124 /// \brief Derived class can implement extra work to be done before the request is sent to the server Prepare()125 virtual Status Prepare() { return Status::OK(); } 126 127 /// \brief Derived class can implement extra work to be done after the server sends the request PostReply()128 virtual Status PostReply() { return Status::OK(); } 129 130 /// \brief A method for the client to wait for the availability of the result back from the server. 131 /// \return Status object 132 Status Wait(); 133 134 /// \brief Return if the request is of row request type 135 /// \return True if the request is row-related request IsRowRequest()136 bool IsRowRequest() const { 137 return type_ == RequestType::kBatchCacheRows || type_ == RequestType::kBatchFetchRows || 138 type_ == RequestType::kInternalCacheRow || type_ == RequestType::kInternalFetchRow || 139 type_ == RequestType::kCacheRow; 140 } 141 142 /// \brief Return if the request is of admin request type 143 /// \return True if the request is admin-related request IsAdminRequest()144 bool IsAdminRequest() const { 145 return type_ == RequestType::kCreateCache || type_ == RequestType::kDestroyCache || 146 type_ == RequestType::kGetStat || type_ == RequestType::kGetCacheState || 147 type_ == RequestType::kAllocateSharedBlock || type_ == RequestType::kFreeSharedBlock || 148 type_ == RequestType::kCacheSchema || type_ == RequestType::kFetchSchema || 149 type_ == RequestType::kBuildPhaseDone || type_ == RequestType::kToggleWriteMode || 150 type_ == RequestType::kConnectReset || type_ == RequestType::kStopService || 151 type_ == RequestType::kHeartBeat || type_ == RequestType::kGetCacheMissKeys; 152 } 153 154 /// \brief Return if the request is of session request type 155 /// \return True if the request is session-related request IsSessionRequest()156 bool IsSessionRequest() const { 157 return type_ == RequestType::kGenerateSessionId || type_ == RequestType::kDropSession || 158 type_ == RequestType::kListSessions; 159 } 160 161 protected: 162 CacheRequest rq_; // This is what we send to the server 163 CacheReply reply_; // This is what the server send back 164 165 private: 166 RequestType type_; 167 WaitPost wp_; // A sync area used by the client side. 168 }; 169 170 class FreeSharedBlockRequest : public BaseRequest { 171 public: 172 friend class CacheServer; FreeSharedBlockRequest(connection_id_type connection_id,int32_t client_id,int64_t addr)173 explicit FreeSharedBlockRequest(connection_id_type connection_id, int32_t client_id, int64_t addr) 174 : BaseRequest(RequestType::kFreeSharedBlock) { 175 rq_.set_connection_id(connection_id); 176 rq_.add_buf_data(std::to_string(addr)); 177 rq_.set_client_id(client_id); 178 } 179 ~FreeSharedBlockRequest() override = default; 180 }; 181 182 /// \brief Request to cache a single TensorRow 183 class CacheRowRequest : public BaseRequest { 184 public: 185 friend class CacheServer; 186 friend class CacheClient; 187 explicit CacheRowRequest(const CacheClient *cc); 188 ~CacheRowRequest() override = default; 189 190 /// \brief Serialize a TensorRow for streaming to the cache server 191 /// \param row TensorRow 192 /// \return Status object 193 Status SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row); 194 195 /// \brief Sanity check before we send the row. 196 /// \return Status object 197 Status Prepare() override; 198 199 /// \brief Override the base function get the row id returned from the server 200 /// \return Status object 201 Status PostReply() override; 202 203 /// \brief Return the row id assigned to this row for non-mappable dataset 204 /// \return row id of the cached row GetRowIdAfterCache()205 row_id_type GetRowIdAfterCache() { return row_id_from_server_; } 206 207 /// \brief If we are doing local bypass, fill in extra request information of where the data is located. AddDataLocation()208 void AddDataLocation() { 209 if (support_local_bypass_) { 210 rq_.add_buf_data(std::to_string(addr_)); 211 rq_.add_buf_data(std::to_string(sz_)); 212 } 213 } 214 215 /// \brief If we fail to send the data to the server using shared memory method, we should release 216 /// the shared memory by sending another request. The following function will generate a suitable 217 /// request for the CacheClient to send. GenerateFreeBlockRequest()218 std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() { 219 return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), rq_.client_id(), addr_); 220 } 221 222 private: 223 bool support_local_bypass_; 224 int64_t addr_; 225 int64_t sz_; 226 row_id_type row_id_from_server_; 227 }; 228 229 /// \brief Request to fetch rows in batch 230 class BatchFetchRequest : public BaseRequest { 231 public: 232 friend class CacheServer; 233 friend class CacheService; 234 BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id); 235 ~BatchFetchRequest() override = default; 236 Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); 237 238 private: 239 bool support_local_bypass_; 240 std::vector<row_id_type> row_id_; 241 }; 242 243 /// \brief Request to create a cache for the current connection 244 class CreateCacheRequest : public BaseRequest { 245 public: 246 friend class CacheServer; 247 enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; 248 249 /// \brief Constructor 250 /// \param connection_id 251 /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited 252 /// \param flag Attributes of the cache. 253 explicit CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz, 254 CreateCacheFlag flag = CreateCacheFlag::kNone); 255 ~CreateCacheRequest() override = default; 256 257 /// Overload the base class Prepare/PostReply 258 Status Prepare() override; 259 Status PostReply() override; 260 261 private: 262 uint64_t cache_mem_sz_; 263 CreateCacheFlag flag_; 264 CacheClient *cc_; 265 }; 266 267 /// \brief Request to get all the keys not present at the server. 268 /// \note Only applicable to mappable case 269 class GetCacheMissKeysRequest : public BaseRequest { 270 public: 271 friend class CacheServer; GetCacheMissKeysRequest(connection_id_type connection_id)272 explicit GetCacheMissKeysRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetCacheMissKeys) { 273 rq_.set_connection_id(connection_id); 274 } 275 ~GetCacheMissKeysRequest() override = default; 276 }; 277 278 /// \brief Request to destroy a cache 279 class DestroyCacheRequest : public BaseRequest { 280 public: 281 friend class CacheServer; DestroyCacheRequest(connection_id_type connection_id)282 explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) { 283 rq_.set_connection_id(connection_id); 284 } 285 ~DestroyCacheRequest() override = default; 286 }; 287 288 /// \brief Obtain the statistics of the current connection 289 class GetStatRequest : public BaseRequest { 290 public: 291 friend class CacheServer; 292 friend class CacheService; GetStatRequest(connection_id_type connection_id)293 explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetStat) { 294 rq_.set_connection_id(connection_id); 295 } 296 297 ~GetStatRequest() override = default; 298 299 /// \brief Override base function to process the result. 300 Status PostReply() override; 301 GetStat(CacheServiceStat * stat)302 void GetStat(CacheServiceStat *stat) { 303 if (stat != nullptr) { 304 (*stat) = stat_; 305 } 306 } 307 308 private: 309 CacheServiceStat stat_{}; 310 }; 311 312 /// \brief Get the state of a cache service 313 class GetCacheStateRequest : public BaseRequest { 314 public: 315 friend class CacheServer; GetCacheStateRequest(connection_id_type connection_id)316 explicit GetCacheStateRequest(connection_id_type connection_id) 317 : BaseRequest(RequestType::kGetCacheState), cache_service_state_(0) { 318 rq_.set_connection_id(connection_id); 319 } 320 ~GetCacheStateRequest() override = default; 321 322 Status PostReply() override; 323 GetState()324 auto GetState() const { return cache_service_state_; } 325 326 private: 327 int8_t cache_service_state_; 328 }; 329 330 /// \brief Request to cache a schema 331 class CacheSchemaRequest : public BaseRequest { 332 public: 333 friend class CacheServer; CacheSchemaRequest(connection_id_type connection_id)334 explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) { 335 rq_.set_connection_id(connection_id); 336 } 337 ~CacheSchemaRequest() override = default; 338 339 Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map); 340 }; 341 342 /// \brief Request to fetch a schema 343 class FetchSchemaRequest : public BaseRequest { 344 public: 345 friend class CacheServer; FetchSchemaRequest(connection_id_type connection_id)346 explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) { 347 rq_.set_connection_id(connection_id); 348 } 349 ~FetchSchemaRequest() override = default; 350 351 Status PostReply() override; 352 353 std::unordered_map<std::string, int32_t> GetColumnMap(); 354 355 private: 356 std::unordered_map<std::string, int32_t> column_name_id_map_; 357 }; 358 359 /// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only. 360 class BuildPhaseDoneRequest : public BaseRequest { 361 public: 362 friend class CacheServer; BuildPhaseDoneRequest(connection_id_type connection_id,const std::string & cookie)363 BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie) 364 : BaseRequest(RequestType::kBuildPhaseDone), cookie_(cookie) { 365 rq_.set_connection_id(connection_id); 366 rq_.add_buf_data(cookie_); 367 } 368 ~BuildPhaseDoneRequest() override = default; 369 370 private: 371 std::string cookie_; 372 }; 373 374 /// \brief Request to drop all the caches in the current session 375 class DropSessionRequest : public BaseRequest { 376 public: 377 friend class CacheServer; DropSessionRequest(const CacheClientInfo & cinfo)378 explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) { 379 rq_.mutable_connection_info()->operator=(cinfo); 380 } 381 ~DropSessionRequest() override = default; 382 }; 383 384 class GenerateSessionIdRequest : public BaseRequest { 385 public: 386 friend class CacheServer; GenerateSessionIdRequest()387 GenerateSessionIdRequest() : BaseRequest(RequestType::kGenerateSessionId) { 388 // We don't have anything client info nor connection id to send. But we will manually 389 // set the connection id to 0. 390 rq_.set_connection_id(0); 391 } 392 393 ~GenerateSessionIdRequest() override = default; 394 GetSessionId()395 session_id_type GetSessionId() { return atoi(reply_.result().data()); } 396 }; 397 398 class ListSessionsRequest : public BaseRequest { 399 public: 400 friend class CacheServer; ListSessionsRequest()401 ListSessionsRequest() : BaseRequest(RequestType::kListSessions) { 402 // This request is not specific to any cache or session 403 rq_.set_connection_id(0); 404 } 405 406 ~ListSessionsRequest() override = default; 407 408 /// \brief Override base function to process the result. 409 Status PostReply() override; 410 GetSessionCacheInfo(std::vector<SessionCacheInfo> * info)411 void GetSessionCacheInfo(std::vector<SessionCacheInfo> *info) { 412 if (info != nullptr) { 413 (*info) = session_info_list_; 414 } 415 } 416 GetSessionCacheInfo()417 std::vector<SessionCacheInfo> GetSessionCacheInfo() { return session_info_list_; } 418 GetSessionIds()419 std::vector<session_id_type> GetSessionIds() { 420 std::vector<session_id_type> session_ids; 421 for (auto session_info : session_info_list_) { 422 session_ids.push_back(session_info.session_id); 423 } 424 return session_ids; 425 } 426 GetServerStat()427 CacheServerCfgInfo GetServerStat() { return server_cfg_; } 428 429 private: 430 std::vector<SessionCacheInfo> session_info_list_; 431 CacheServerCfgInfo server_cfg_{}; 432 }; 433 434 class AllocateSharedBlockRequest : public BaseRequest { 435 public: 436 friend class CacheServer; AllocateSharedBlockRequest(connection_id_type connection_id,int32_t client_id,size_t requestedSz)437 explicit AllocateSharedBlockRequest(connection_id_type connection_id, int32_t client_id, size_t requestedSz) 438 : BaseRequest(RequestType::kAllocateSharedBlock) { 439 rq_.set_connection_id(connection_id); 440 rq_.add_buf_data(std::to_string(requestedSz)); 441 rq_.set_client_id(client_id); 442 } 443 ~AllocateSharedBlockRequest() override = default; 444 445 /// \brief On return from the server, we get the (relative) address where 446 /// the free block is located. 447 /// \return GetAddr()448 int64_t GetAddr() { 449 auto addr = strtoll(reply_.result().data(), nullptr, kDecimal); 450 return addr; 451 } 452 }; 453 454 class ToggleWriteModeRequest : public BaseRequest { 455 public: 456 friend class CacheServer; ToggleWriteModeRequest(connection_id_type connection_id,bool on_off)457 explicit ToggleWriteModeRequest(connection_id_type connection_id, bool on_off) 458 : BaseRequest(RequestType::kToggleWriteMode) { 459 rq_.set_connection_id(connection_id); 460 rq_.add_buf_data(on_off ? "on" : "off"); 461 } 462 ~ToggleWriteModeRequest() override = default; 463 }; 464 465 class ServerStopRequest : public BaseRequest { 466 public: 467 friend class CacheServer; ServerStopRequest(int32_t qID)468 explicit ServerStopRequest(int32_t qID) : BaseRequest(RequestType::kStopService) { 469 rq_.add_buf_data(std::to_string(qID)); 470 } 471 ~ServerStopRequest() = default; 472 Status PostReply() override; 473 }; 474 475 class ConnectResetRequest : public BaseRequest { 476 public: 477 friend class CacheServer; ConnectResetRequest(connection_id_type connection_id,int32_t client_id)478 explicit ConnectResetRequest(connection_id_type connection_id, int32_t client_id) 479 : BaseRequest(RequestType::kConnectReset) { 480 rq_.set_connection_id(connection_id); 481 rq_.set_client_id(client_id); 482 } 483 ~ConnectResetRequest() override = default; 484 485 /// Override the base class function Prepare()486 Status Prepare() override { 487 CHECK_FAIL_RETURN_UNEXPECTED(rq_.client_id() != -1, "Invalid client id"); 488 return Status::OK(); 489 } 490 }; 491 492 class BatchCacheRowsRequest : public BaseRequest { 493 public: 494 friend class CacheServer; 495 explicit BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele); 496 ~BatchCacheRowsRequest() override = default; 497 }; 498 } // namespace dataset 499 } // namespace mindspore 500 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ 501