• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7 
8  * http://www.apache.org/licenses/LICENSE-2.0
9 
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15 */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_REQ_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_REQ_H_
18 
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
26 
27 #ifdef ENABLE_CACHE
28 #include "proto/cache_grpc.grpc.pb.h"
29 #endif
30 #include "proto/cache_grpc.pb.h"
31 #include "minddata/dataset/core/tensor_row.h"
32 #include "minddata/dataset/engine/cache/de_tensor_generated.h"
33 #include "minddata/dataset/util/slice.h"
34 #include "minddata/dataset/util/wait_post.h"
35 
36 namespace mindspore {
37 namespace dataset {
38 class CacheClient;
39 /// \brief Statistic structure for GetStat request
40 struct CacheServiceStat {
41   int64_t num_mem_cached;
42   int64_t num_disk_cached;
43   int64_t avg_cache_sz;
44   int64_t num_numa_hit;
45   row_id_type min_row_id;
46   row_id_type max_row_id;
47   int8_t cache_service_state;
48 };
49 
/// \brief Server configuration values reported back to the client
///     (returned by ListSessionsRequest::GetServerStat).
/// \note The scalar fields have default member initializers so that a default-initialized
///     instance (e.g. `CacheServerCfgInfo cfg;`) never carries indeterminate values. The struct
///     remains an aggregate under C++14 and later.
struct CacheServerCfgInfo {
  int32_t num_workers{0};
  int8_t log_level{0};
  std::string spill_dir;
};
55 
56 /// \brief Info structure ListSessionsRequest
57 struct SessionCacheInfo {
58   session_id_type session_id;
59   connection_id_type connection_id;
60   CacheServiceStat stats;
61 };
62 
63 /// \brief CacheClient communicates with CacheServer using Requests.
64 class BaseRequest {
65  public:
66   // Request types
67   enum class RequestType : int16_t {
68     kCacheRow = 0,
69     kBatchFetchRows = 1,
70     kCreateCache = 2,
71     kGetCacheMissKeys = 3,
72     kDestroyCache = 4,
73     kGetStat = 5,
74     kCacheSchema = 6,
75     kFetchSchema = 7,
76     kBuildPhaseDone = 8,
77     kDropSession = 9,
78     kGenerateSessionId = 10,
79     kAllocateSharedBlock = 11,
80     kFreeSharedBlock = 12,
81     kStopService = 13,
82     kHeartBeat = 14,
83     kToggleWriteMode = 15,
84     kListSessions = 16,
85     kConnectReset = 17,
86     kInternalFetchRow = 18,
87     kBatchCacheRows = 19,
88     kInternalCacheRow = 20,
89     kGetCacheState = 21,
90     // Add new request before it.
91     kRequestUnknown = 32767
92   };
93 
94   friend class CacheServer;
95   friend class CacheServerRequest;
96   friend class CacheClientGreeter;
97   friend class CacheClientRequestTag;
98   friend class CacheClient;
99   friend class CacheService;
100   friend class CacheServerGreeterImpl;
101 
102   /// \brief Base class of a cache server request
103   /// \param type Type of the request
BaseRequest(RequestType type)104   explicit BaseRequest(RequestType type) : type_(type) {
105     rq_.set_type(static_cast<int16_t>(type_));
106     rq_.set_client_id(-1);
107     rq_.set_flag(0);
108   }
109   virtual ~BaseRequest() = default;
110 
111   /// \brief A print method for debugging
112   /// \param out The output stream to write output to
Print(std::ostream & out)113   virtual void Print(std::ostream &out) const { out << "Request type: " << static_cast<int16_t>(type_); }
114 
115   /// \brief << Stream output operator overload
116   /// \param out reference to the output stream
117   /// \param rq reference to the BaseRequest
118   /// \return the output stream
119   friend std::ostream &operator<<(std::ostream &out, const BaseRequest &rq) {
120     rq.Print(out);
121     return out;
122   }
123 
124   /// \brief Derived class can implement extra work to be done before the request is sent to the server
Prepare()125   virtual Status Prepare() { return Status::OK(); }
126 
127   /// \brief Derived class can implement extra work to be done after the server sends the request
PostReply()128   virtual Status PostReply() { return Status::OK(); }
129 
130   /// \brief A method for the client to wait for the availability of the result back from the server.
131   /// \return Status object
132   Status Wait();
133 
134   /// \brief Return if the request is of row request type
135   /// \return True if the request is row-related request
IsRowRequest()136   bool IsRowRequest() const {
137     return type_ == RequestType::kBatchCacheRows || type_ == RequestType::kBatchFetchRows ||
138            type_ == RequestType::kInternalCacheRow || type_ == RequestType::kInternalFetchRow ||
139            type_ == RequestType::kCacheRow;
140   }
141 
142   /// \brief Return if the request is of admin request type
143   /// \return True if the request is admin-related request
IsAdminRequest()144   bool IsAdminRequest() const {
145     return type_ == RequestType::kCreateCache || type_ == RequestType::kDestroyCache ||
146            type_ == RequestType::kGetStat || type_ == RequestType::kGetCacheState ||
147            type_ == RequestType::kAllocateSharedBlock || type_ == RequestType::kFreeSharedBlock ||
148            type_ == RequestType::kCacheSchema || type_ == RequestType::kFetchSchema ||
149            type_ == RequestType::kBuildPhaseDone || type_ == RequestType::kToggleWriteMode ||
150            type_ == RequestType::kConnectReset || type_ == RequestType::kStopService ||
151            type_ == RequestType::kHeartBeat || type_ == RequestType::kGetCacheMissKeys;
152   }
153 
154   /// \brief Return if the request is of session request type
155   /// \return True if the request is session-related request
IsSessionRequest()156   bool IsSessionRequest() const {
157     return type_ == RequestType::kGenerateSessionId || type_ == RequestType::kDropSession ||
158            type_ == RequestType::kListSessions;
159   }
160 
161  protected:
162   CacheRequest rq_;   // This is what we send to the server
163   CacheReply reply_;  // This is what the server send back
164 
165  private:
166   RequestType type_;
167   WaitPost wp_;  // A sync area used by the client side.
168 };
169 
170 class FreeSharedBlockRequest : public BaseRequest {
171  public:
172   friend class CacheServer;
FreeSharedBlockRequest(connection_id_type connection_id,int32_t client_id,int64_t addr)173   explicit FreeSharedBlockRequest(connection_id_type connection_id, int32_t client_id, int64_t addr)
174       : BaseRequest(RequestType::kFreeSharedBlock) {
175     rq_.set_connection_id(connection_id);
176     rq_.add_buf_data(std::to_string(addr));
177     rq_.set_client_id(client_id);
178   }
179   ~FreeSharedBlockRequest() override = default;
180 };
181 
182 /// \brief Request to cache a single TensorRow
183 class CacheRowRequest : public BaseRequest {
184  public:
185   friend class CacheServer;
186   friend class CacheClient;
187   explicit CacheRowRequest(const CacheClient *cc);
188   ~CacheRowRequest() override = default;
189 
190   /// \brief Serialize a TensorRow for streaming to the cache server
191   /// \param row TensorRow
192   /// \return Status object
193   Status SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row);
194 
195   /// \brief Sanity check before we send the row.
196   /// \return Status object
197   Status Prepare() override;
198 
199   /// \brief Override the base function get the row id returned from the server
200   /// \return Status object
201   Status PostReply() override;
202 
203   /// \brief Return the row id assigned to this row for non-mappable dataset
204   /// \return row id of the cached row
GetRowIdAfterCache()205   row_id_type GetRowIdAfterCache() { return row_id_from_server_; }
206 
207   /// \brief If we are doing local bypass, fill in extra request information of where the data is located.
AddDataLocation()208   void AddDataLocation() {
209     if (support_local_bypass_) {
210       rq_.add_buf_data(std::to_string(addr_));
211       rq_.add_buf_data(std::to_string(sz_));
212     }
213   }
214 
215   /// \brief If we fail to send the data to the server using shared memory method, we should release
216   /// the shared memory by sending another request. The following function will generate a suitable
217   /// request for the CacheClient to send.
GenerateFreeBlockRequest()218   std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() {
219     return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), rq_.client_id(), addr_);
220   }
221 
222  private:
223   bool support_local_bypass_;
224   int64_t addr_;
225   int64_t sz_;
226   row_id_type row_id_from_server_;
227 };
228 
229 /// \brief Request to fetch rows in batch
230 class BatchFetchRequest : public BaseRequest {
231  public:
232   friend class CacheServer;
233   friend class CacheService;
234   BatchFetchRequest(const CacheClient *cc, const std::vector<row_id_type> &row_id);
235   ~BatchFetchRequest() override = default;
236   Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr);
237 
238  private:
239   bool support_local_bypass_;
240   std::vector<row_id_type> row_id_;
241 };
242 
243 /// \brief Request to create a cache for the current connection
244 class CreateCacheRequest : public BaseRequest {
245  public:
246   friend class CacheServer;
247   enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };
248 
249   /// \brief Constructor
250   /// \param connection_id
251   /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
252   /// \param flag Attributes of the cache.
253   explicit CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
254                               CreateCacheFlag flag = CreateCacheFlag::kNone);
255   ~CreateCacheRequest() override = default;
256 
257   /// Overload the base class Prepare/PostReply
258   Status Prepare() override;
259   Status PostReply() override;
260 
261  private:
262   uint64_t cache_mem_sz_;
263   CreateCacheFlag flag_;
264   CacheClient *cc_;
265 };
266 
267 /// \brief Request to get all the keys not present at the server.
268 /// \note Only applicable to mappable case
269 class GetCacheMissKeysRequest : public BaseRequest {
270  public:
271   friend class CacheServer;
GetCacheMissKeysRequest(connection_id_type connection_id)272   explicit GetCacheMissKeysRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetCacheMissKeys) {
273     rq_.set_connection_id(connection_id);
274   }
275   ~GetCacheMissKeysRequest() override = default;
276 };
277 
278 /// \brief Request to destroy a cache
279 class DestroyCacheRequest : public BaseRequest {
280  public:
281   friend class CacheServer;
DestroyCacheRequest(connection_id_type connection_id)282   explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) {
283     rq_.set_connection_id(connection_id);
284   }
285   ~DestroyCacheRequest() override = default;
286 };
287 
288 /// \brief Obtain the statistics of the current connection
289 class GetStatRequest : public BaseRequest {
290  public:
291   friend class CacheServer;
292   friend class CacheService;
GetStatRequest(connection_id_type connection_id)293   explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetStat) {
294     rq_.set_connection_id(connection_id);
295   }
296 
297   ~GetStatRequest() override = default;
298 
299   /// \brief Override base function to process the result.
300   Status PostReply() override;
301 
GetStat(CacheServiceStat * stat)302   void GetStat(CacheServiceStat *stat) {
303     if (stat != nullptr) {
304       (*stat) = stat_;
305     }
306   }
307 
308  private:
309   CacheServiceStat stat_{};
310 };
311 
312 /// \brief Get the state of a cache service
313 class GetCacheStateRequest : public BaseRequest {
314  public:
315   friend class CacheServer;
GetCacheStateRequest(connection_id_type connection_id)316   explicit GetCacheStateRequest(connection_id_type connection_id)
317       : BaseRequest(RequestType::kGetCacheState), cache_service_state_(0) {
318     rq_.set_connection_id(connection_id);
319   }
320   ~GetCacheStateRequest() override = default;
321 
322   Status PostReply() override;
323 
GetState()324   auto GetState() const { return cache_service_state_; }
325 
326  private:
327   int8_t cache_service_state_;
328 };
329 
330 /// \brief Request to cache a schema
331 class CacheSchemaRequest : public BaseRequest {
332  public:
333   friend class CacheServer;
CacheSchemaRequest(connection_id_type connection_id)334   explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) {
335     rq_.set_connection_id(connection_id);
336   }
337   ~CacheSchemaRequest() override = default;
338 
339   Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
340 };
341 
342 /// \brief Request to fetch a schema
343 class FetchSchemaRequest : public BaseRequest {
344  public:
345   friend class CacheServer;
FetchSchemaRequest(connection_id_type connection_id)346   explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) {
347     rq_.set_connection_id(connection_id);
348   }
349   ~FetchSchemaRequest() override = default;
350 
351   Status PostReply() override;
352 
353   std::unordered_map<std::string, int32_t> GetColumnMap();
354 
355  private:
356   std::unordered_map<std::string, int32_t> column_name_id_map_;
357 };
358 
359 /// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only.
360 class BuildPhaseDoneRequest : public BaseRequest {
361  public:
362   friend class CacheServer;
BuildPhaseDoneRequest(connection_id_type connection_id,const std::string & cookie)363   BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
364       : BaseRequest(RequestType::kBuildPhaseDone), cookie_(cookie) {
365     rq_.set_connection_id(connection_id);
366     rq_.add_buf_data(cookie_);
367   }
368   ~BuildPhaseDoneRequest() override = default;
369 
370  private:
371   std::string cookie_;
372 };
373 
374 /// \brief Request to drop all the caches in the current session
375 class DropSessionRequest : public BaseRequest {
376  public:
377   friend class CacheServer;
DropSessionRequest(const CacheClientInfo & cinfo)378   explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) {
379     rq_.mutable_connection_info()->operator=(cinfo);
380   }
381   ~DropSessionRequest() override = default;
382 };
383 
384 class GenerateSessionIdRequest : public BaseRequest {
385  public:
386   friend class CacheServer;
GenerateSessionIdRequest()387   GenerateSessionIdRequest() : BaseRequest(RequestType::kGenerateSessionId) {
388     // We don't have anything client info nor connection id to send. But we will manually
389     // set the connection id to 0.
390     rq_.set_connection_id(0);
391   }
392 
393   ~GenerateSessionIdRequest() override = default;
394 
GetSessionId()395   session_id_type GetSessionId() { return atoi(reply_.result().data()); }
396 };
397 
398 class ListSessionsRequest : public BaseRequest {
399  public:
400   friend class CacheServer;
ListSessionsRequest()401   ListSessionsRequest() : BaseRequest(RequestType::kListSessions) {
402     // This request is not specific to any cache or session
403     rq_.set_connection_id(0);
404   }
405 
406   ~ListSessionsRequest() override = default;
407 
408   /// \brief Override base function to process the result.
409   Status PostReply() override;
410 
GetSessionCacheInfo(std::vector<SessionCacheInfo> * info)411   void GetSessionCacheInfo(std::vector<SessionCacheInfo> *info) {
412     if (info != nullptr) {
413       (*info) = session_info_list_;
414     }
415   }
416 
GetSessionCacheInfo()417   std::vector<SessionCacheInfo> GetSessionCacheInfo() { return session_info_list_; }
418 
GetSessionIds()419   std::vector<session_id_type> GetSessionIds() {
420     std::vector<session_id_type> session_ids;
421     for (auto session_info : session_info_list_) {
422       session_ids.push_back(session_info.session_id);
423     }
424     return session_ids;
425   }
426 
GetServerStat()427   CacheServerCfgInfo GetServerStat() { return server_cfg_; }
428 
429  private:
430   std::vector<SessionCacheInfo> session_info_list_;
431   CacheServerCfgInfo server_cfg_{};
432 };
433 
434 class AllocateSharedBlockRequest : public BaseRequest {
435  public:
436   friend class CacheServer;
AllocateSharedBlockRequest(connection_id_type connection_id,int32_t client_id,size_t requestedSz)437   explicit AllocateSharedBlockRequest(connection_id_type connection_id, int32_t client_id, size_t requestedSz)
438       : BaseRequest(RequestType::kAllocateSharedBlock) {
439     rq_.set_connection_id(connection_id);
440     rq_.add_buf_data(std::to_string(requestedSz));
441     rq_.set_client_id(client_id);
442   }
443   ~AllocateSharedBlockRequest() override = default;
444 
445   /// \brief On return from the server, we get the (relative) address where
446   /// the free block is located.
447   /// \return
GetAddr()448   int64_t GetAddr() {
449     auto addr = strtoll(reply_.result().data(), nullptr, kDecimal);
450     return addr;
451   }
452 };
453 
454 class ToggleWriteModeRequest : public BaseRequest {
455  public:
456   friend class CacheServer;
ToggleWriteModeRequest(connection_id_type connection_id,bool on_off)457   explicit ToggleWriteModeRequest(connection_id_type connection_id, bool on_off)
458       : BaseRequest(RequestType::kToggleWriteMode) {
459     rq_.set_connection_id(connection_id);
460     rq_.add_buf_data(on_off ? "on" : "off");
461   }
462   ~ToggleWriteModeRequest() override = default;
463 };
464 
465 class ServerStopRequest : public BaseRequest {
466  public:
467   friend class CacheServer;
ServerStopRequest(int32_t qID)468   explicit ServerStopRequest(int32_t qID) : BaseRequest(RequestType::kStopService) {
469     rq_.add_buf_data(std::to_string(qID));
470   }
471   ~ServerStopRequest() = default;
472   Status PostReply() override;
473 };
474 
475 class ConnectResetRequest : public BaseRequest {
476  public:
477   friend class CacheServer;
ConnectResetRequest(connection_id_type connection_id,int32_t client_id)478   explicit ConnectResetRequest(connection_id_type connection_id, int32_t client_id)
479       : BaseRequest(RequestType::kConnectReset) {
480     rq_.set_connection_id(connection_id);
481     rq_.set_client_id(client_id);
482   }
483   ~ConnectResetRequest() override = default;
484 
485   /// Override the base class function
Prepare()486   Status Prepare() override {
487     CHECK_FAIL_RETURN_UNEXPECTED(rq_.client_id() != -1, "Invalid client id");
488     return Status::OK();
489   }
490 };
491 
492 class BatchCacheRowsRequest : public BaseRequest {
493  public:
494   friend class CacheServer;
495   explicit BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele);
496   ~BatchCacheRowsRequest() override = default;
497 };
498 }  // namespace dataset
499 }  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_REQ_H_
501