/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_EMBEDDING_CACHE_PREFETCH_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_EMBEDDING_CACHE_PREFETCH_ACTOR_H_

#include <map>
#include <set>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include <tuple>
#include <random>

#include "runtime/graph_scheduler/actor/actor_common.h"
#include "ir/anf.h"
#include "include/backend/kernel_graph.h"
#include "include/backend/distributed/cluster/cluster_context.h"
#include "distributed/cluster/actor_route_table_proxy.h"
#include "include/backend/distributed/rpc/tcp/tcp_client.h"
#include "include/backend/distributed/rpc/tcp/tcp_server.h"
#include "utils/hash_map.h"
#include "include/common/random.h"
#include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
#include "include/backend/distributed/embedding_cache/blocking_queue.h"

// Note: After the code in ps/ps_cache is moved into runtime/addons/embedding_cache/,
// the following include files and using declarations of ps will be removed.
#include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
#include "include/backend/distributed/ps/ps_context.h"
using mindspore::ps::PSContext;
using mindspore::ps::PsDataChannel;
using mindspore::ps::PsDataPrefetch;

namespace mindspore {
namespace runtime {
using kernel::Address;
using kernel::AddressPtr;
using kernel::AddressPtrList;

class DeviceEmbeddingOperation;
class Sender;
class Receiver;
using SenderPtr = std::shared_ptr<Sender>;
using ReceiverPtr = std::shared_ptr<Receiver>;
using SendRecvPair = std::pair<SenderPtr, ReceiverPtr>;
using SendRecvPairList = std::vector<SendRecvPair>;

using distributed::EmbeddingCacheStatisticsInfo;
using distributed::EmbeddingDeviceCache;
using distributed::EmbeddingHostCache;
using distributed::HashTableInfo;
using distributed::kInvalidIndexValue;

using distributed::BlockingQueue;
using distributed::CacheAnalysis;
using distributed::IdsAndIndices;
using distributed::UniqueIds;
using BlockingQueueTuple =
  std::tuple<std::shared_ptr<BlockingQueue<UniqueIds>>, std::shared_ptr<BlockingQueue<CacheAnalysis>>,
             std::shared_ptr<BlockingQueue<IdsAndIndices>>>;

using distributed::cluster::ActorRouteTableProxy;
using distributed::cluster::ActorRouteTableProxyPtr;
using distributed::rpc::TCPClient;
using distributed::rpc::TCPServer;

using DataType = float;
using Generator = random::Philox;
using NormalDistribution = random::NormalDistribution<double>;
using ConstantDistribution = random::ConstantDistribution<DataType>;

constexpr size_t kPipelineStageNum = 4;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;

// The EmbeddingCachePrefetchActor is used in large embedding table scenarios. The cache hierarchy is: Device
// Cache -> Local Host Cache -> Remote Cache. This actor performs local and device cache hit analysis and cache
// prefetching (the feature weights corresponding to the ids of subsequent batches are prefetched into the device
// cache in advance, so that prefetching is pipelined with the computation on the device side). Cache prefetching may
// involve RPC communication with the server side.
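//
// A rough usage sketch (illustrative only; the actual orchestration is done elsewhere in the runtime, and the exact
// call sequence below is an assumption rather than a documented protocol):
//   EmbeddingCachePrefetchActorPtr actor = std::make_shared<EmbeddingCachePrefetchActor>(device_context);
//   actor->Initialize();          // Build and link rpc operators, connect to the remote cache.
//   actor->Run();                 // Start hit analysis and the prefetch pipeline.
//   actor->SyncEmbeddingTable();  // Push the latest embeddings from the local caches back to the remote.
//   actor->Finalize(true);        // Finalize this actor and request finalization of the remote side.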
class EmbeddingCachePrefetchActor : public ActorBase {
 public:
  explicit EmbeddingCachePrefetchActor(device::DeviceContext *device_context)
      : ActorBase("EmbeddingCachePrefetchActor"), device_context_(device_context), cpu_device_context_(nullptr) {}

  ~EmbeddingCachePrefetchActor() override = default;

  // Initialize the embedding cache prefetch actor.
  // 1. Build and link rpc operators between the local cache and the remote cache.
  // 2. Build network connections for the rpc operators.
  void Initialize();

  // Perform local cache hit analysis and prefetch the feature vectors corresponding to the next batch into the cache.
  void Run();

  // Increase the global step of the compute graph.
  void IncreaseGraphStep(const std::string &channel_name);

  // Sync the latest embedding table to the remote.
  void SyncEmbeddingTable();

  // Finalize the embedding cache prefetch actor and push the latest embeddings from the local cache to the remote
  // cache.
  void Finalize(bool finalize_remote);

  // Wait for the compute graph to finish the current step when there is not enough free memory space in the cache,
  // in order to delete the feature vectors used by the current step from the cache.
  bool WaitGraphRun();

  // Reset the EmbeddingHashMap for the device and local host cache.
  bool ResetEmbeddingHashMap();

  // Insert weights into the local host embedding cache.
  bool InsertLocalHostCache(size_t embedding_size, size_t insert_indices_size, const int *insert_indices,
                            const float *insert_data, float *hash_table_addr);

  // Look up embeddings from the local host embedding cache.
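  // A hedged reading of the parameters (an interpretation for readers, not a contract stated by this header):
  // for each i in [0, indices_num), the row of length embedding_size starting at
  // hash_table_addr + indices_addr[i] * embedding_size is copied to output_addr + i * embedding_size.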
  bool LookupLocalHostCache(size_t embedding_size, size_t indices_num, const float *hash_table_addr,
                            const int *indices_addr, float *output_addr);

 private:
  // Increase the current global step of the cache prefetching operation.
  bool IncreaseStep();

  // Update the current compute graph's step to the real global step at the time when this actor starts to prefetch
  // cache for a batch of ids.
  void set_current_graph_step() { graph_running_step_ = graph_step_.load(); }

  // Push non-hotspot embeddings in the local host cache to the remote.
  bool PushCacheFromLocalHostToRemote(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis);

  // Pull embeddings missing from the local cache from the remote.
  bool PullCacheFromRemoteToLocalHost(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis);

  // Initialize local cache values using the random number generator.
  bool InitLocalCacheForNewIds(const HashTableInfo &hash_info);
  bool InitLocalCacheForNewIds(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis);

  // Look up embeddings on the remote and get them via RPC.
  bool PullEembeddingsFromRemote(int32_t param_key, const int *ids, size_t ids_num, std::vector<float> *outputs);
  // Push the local embedding cache entries that require eviction to the remote.
  bool PushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num, const float *embeddings,
                              size_t embeddings_len);
  bool DoPushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num, const float *embeddings,
                                size_t embeddings_len);

  // In a multi-server scenario, the embeddings need to be segmented, and each server saves the embeddings of
  // different feature id ranges. Therefore, when the local side performs a push or pull embeddings operation, the
  // embeddings and ids need to be divided and then communicated to the corresponding remote: partition ids by
  // remote embedding slice bound and get unique ids.
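  // An illustrative example (the values are assumptions, not taken from a real configuration): with two servers and
  // remote_embedding_slice_bounds_ = {{0, 500}, {500, 1000}}, the ids {3, 712, 42, 999} would be partitioned into
  // slice_ids_list = {{3, 42}, {712, 999}}, and the embeddings for a push would be sliced in the same order.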
  bool PartitionIds(const int *ids, size_t ids_num, std::vector<std::vector<int>> *slice_ids_list);
  // Partition ids and embeddings by remote embedding slice bound.
  bool PartitionIdsAndEmbeddings(const int *ids, size_t ids_num, const float *embeddings, size_t embeddings_len,
                                 std::vector<std::vector<int>> *slice_ids_list,
                                 std::vector<std::vector<float>> *slice_embeddings_list);

  // Send content to the remote, such as ids or embeddings.
  // The parameter 'cache_operation' is the cache operation name, such as LookupEmbeddingCache and
  // UpdateEmbeddingCache.
  bool SendToRemote(const std::string &cache_operation, int32_t param_key, size_t server_rank_id, size_t embedding_dim,
                    const void *keys, size_t keys_len, const void *values = nullptr, size_t values_len = 0,
                    bool finalize_remote = false, bool sync = true);
  // Wait for the response of the remote and get the returned result.
  // The parameter 'cache_operation' is the cache operation name, such as LookupEmbeddingCache and
  // UpdateEmbeddingCache.
  std::unique_ptr<std::vector<char>> ReceiveFromRemote(const std::string &cache_operation, int32_t param_key,
                                                       size_t server_rank_id) const;
  // Retrieve embeddings in the order of the input ids.
  bool RetrieveEmbeddings(const int *ids, size_t ids_num, const std::vector<std::vector<int>> &slice_ids_list,
                          const std::vector<std::unique_ptr<std::vector<char>>> &slice_embeddings_list,
                          std::vector<float> *outputs) const;

  // Send a finalize request to the remote and finalize it.
  bool FinalizeRemote();

  // Sync the latest local host embedding cache to the remote.
  bool SyncHostEmbeddingTable();
  // Sync the latest device embedding cache to the remote.
  bool SyncDeviceEmbeddingTable();

  // The cache prefetch phase may involve RPC communication with the server, implemented through Sender and
  // Receiver.
  // Build rpc operators.
  void BuildRpcOperators();
  // Link rpc operators and build network connections.
  void LinkRpcOperators();

  // Get the dataset channel name.
  const std::string &channel_name();
  // Set the dataset channel name.
  void set_channel_name(const std::string &channel_name);

  // When the device cache does not reach a 100% hit rate, the cache needs to be updated, which involves cache
  // insertion and deletion: push the non-hotspot embeddings on the local side to the remote, and pull the missing
  // embeddings on the local side from the remote.
  bool UpdateCache();

  // Do the lookup embedding table operation.
  void LookupEmbeddingTable(size_t indices_num, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
                            const int *indices_addr, float *output_addr);

  // Wait for the data channel to be ready.
  void WaitDataChannelInit();

  // Wait for parameter initialization on the remote side.
  // Prevents the subsequent prefetch cache from failing due to the long initialization time of large parameters on
  // the remote side.
  void WaitInitParametersOnRemote();

  void CreateChannelLock(const std::string &channel_name);
  void CreateBlockQueue(const std::string &channel_name);

  // Perform local and device cache hit/miss analysis and prefetch cache for missing embeddings via a multi-stage
  // pipeline. Data flow: unique ids queue -> cache analysis queue -> ids and indices queue.
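  // Roughly (an interpretation of the data flow above and the task names below, not a statement from this header):
  // UniqueIdsTask produces unique ids, AnalyseCacheTask consumes them and produces cache analysis results,
  // UpdateCacheTask consumes the analysis to update the host/device caches, and TransformIdsToIndicesTask converts
  // ids into cache indices for the compute graph.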
  void StartPrefetchCachePipeline(const std::string &channel_name);
  void StopPrefetchCachePipeline();
  void WaitPrefetchCacheFinish();

  // The four-stage pipeline tasks.
  void UniqueIdsTask(const std::string &channel_name);
  void AnalyseCacheTask(const std::string &channel_name);
  void UpdateCacheTask(const std::string &channel_name);
  void TransformIdsToIndicesTask(const std::string &channel_name);

  // Set the current error information before finalizing the actor.
  void SetErrorInfo(const std::string &error_info);

  mindspore::HashMap<std::string, std::shared_ptr<PsDataChannel>> channel_locks_;
  mindspore::HashMap<std::string, std::shared_ptr<std::vector<std::thread>>> pipeline_stages_;
  mindspore::HashMap<std::string, BlockingQueueTuple> channel_to_queues_;

  // The operations for the embedding on the device.
  DeviceEmbeddingOperation *emb_ops_{nullptr};

  // Record sender and receiver pairs for different cache operations, servers and parameter keys.
  // key: cache operation (such as LookupEmbeddingCache and UpdateEmbeddingCache)
  // value: sender and receiver pairs for this kind of cache operation.
  mindspore::HashMap<std::string, std::vector<SendRecvPairList>> rpc_operators_;

  // The device interface.
  device::DeviceContext *device_context_;
  // The CPU device context used for allocating rpc message data.
  device::DeviceContext *cpu_device_context_;
  // The device stream used for async memcpy operations and launching device kernels, such as the embedding cache
  // lookup and update kernels.
  size_t stream_id_{0};

  // Full embedding table row number, not less than the total number of feature ids.
  size_t vocab_size_{0};

  // Embedding cache size (row number of the embedding cache) of the local host cache.
  size_t local_host_cache_size_{0};

  // Statistics on the cache hit rate of the host and device, and the information used to update the cache.
  EmbeddingCacheStatisticsInfo statistics_info_;

  // Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range
  // corresponding to the embedding table slice of this process.
  std::pair<int, int> local_embedding_slice_bounds_;

  // Model parallelism is used between multiple workers, and local_device_cache_bounds_ records the local device cache
  // range corresponding to the embedding table slice of this process.
  std::pair<int, int> local_device_cache_bounds_;

  // In a multi-server scenario, the embeddings need to be segmented, and each server saves the embeddings of
  // different feature id ranges. remote_embedding_slice_bounds_ records the feature range of the embedding table
  // slice on each server.
  std::vector<std::pair<size_t, size_t>> remote_embedding_slice_bounds_;

  // Total number of servers in the cluster.
  size_t server_num_{0};

  // The flag which indicates whether this actor is running to prefetch cache.
  std::atomic_bool running_{false};

  // The flag which indicates whether this actor is initialized.
  bool initialized_{false};
  // The flag which indicates whether this actor is finalized.
  bool finalized_{false};

  // Ensure that the Finalize function is thread safe.
  std::mutex finalize_mutex_;

  // The flag which indicates whether the embedding table sync has finished.
  bool finish_sync_embedding_table_{false};
  std::mutex sync_embedding_table_mutex_;

  // The current global step of the compute graph.
  std::atomic_ulong graph_step_{0};
  // The compute graph's global step at the time when this actor starts to prefetch cache for a batch of ids.
  std::atomic_ulong graph_running_step_{0};
  // The current global step of the cache prefetching operation.
  size_t data_step_{0};

  // Dataset channel name, used in dataset switching scenarios.
  std::string channel_name_{""};
  // The mutex to access channel_name_.
  std::mutex channel_mutex_;

  // The flag which indicates whether parameter initialization on the remote side has finished.
  std::atomic_bool finish_init_parameters_on_remote_{false};

  // Data parser condition variable for prefetching cache, used to start and synchronize the intermediate state of
  // cache prefetching.
  std::condition_variable data_parser_;
  // Data parser mutex for prefetching cache.
  std::mutex data_mutex_;

  // Whether the device cache prefetching process needs to wait for the compute graph to finish the current step when
  // there is not enough free memory space in the cache.
  bool device_cache_need_wait_graph_{false};
  // Whether the local host cache prefetching process needs to wait for the compute graph to finish the current step
  // when there is not enough free memory space in the cache.
  bool host_cache_need_wait_graph_{false};

  std::mutex pipeline_mutex_;
  // Record the latest user-related error information.
  std::string error_info_{""};
};

// RpcOperator is used to do rpc with other processes in distributed execution.
// RpcOperator uses an inter-process edge to uniquely identify paired rpc operators.
class RpcOperator {
 public:
  RpcOperator() : inter_process_edge_(""), route_table_proxy_(nullptr) {}
  virtual ~RpcOperator() = default;

  // Set the inter-process edge name for rpc operators.
  void set_inter_process_edge_name(const std::string &edge_name) { inter_process_edge_ = edge_name; }

  // Set the route table proxy for rpc operators.
  void set_actor_route_table_proxy(const ActorRouteTableProxyPtr &route_table_proxy) {
    route_table_proxy_ = route_table_proxy;
  }

 protected:
  // Unique edge name between rpc operators. Format:
  // src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
  std::string inter_process_edge_;

  // Route table proxy for building network connections between nodes such as workers and servers.
  ActorRouteTableProxyPtr route_table_proxy_;
};

// Sender is used to send data to another process.
class Sender : public RpcOperator {
 public:
  explicit Sender(device::DeviceContext *cpu_device_context)
      : server_url_(""), client_(nullptr), cpu_device_context_(cpu_device_context) {}
  ~Sender() override;

  // Send the buffer to the peer.
  bool Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
            const AddressPtrList &data_list, bool finalize_remote = false, bool sync = true) const;

  // Set the receiver paired with the sender to get the 'from url' from the receiver.
  void set_receiver(const ReceiverPtr &receiver) { receiver_ = receiver; }

  // Look up the peer receiver's route and build the network connection.
  bool ConnectServer();

 private:
  // Build the MessageBase, including the dynamic shape protobuf, which will be sent to the peer receiver.
  // The message format is as below:
  // |--------22 bytes-------|-------sizeof(size_t)-------|-dynamic shape PB data size-| real data size |
  // |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|---real data----|
  // The message.from (from url) must be set.
  std::unique_ptr<MessageBase> BuildRpcMessage(const std::vector<ShapeVector> &shapes,
                                               const std::vector<TypeId> data_types, const AddressPtrList &data_list,
                                               const std::string &from_url, const std::string &to_url,
                                               bool finalize_remote) const;

  // Free the message after it's sent to the remote.
  bool FreeMessage(void *data);

  // Calculate the dynamic shape message size.
  size_t CalDataSize(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
                     const AddressPtrList &data_list, bool finalize_remote) const;

  // The url of the peer receiver's tcp server.
  std::string server_url_;

  std::unique_ptr<TCPClient> client_;

  // The sender and the receiver are used in pairs. The information sent by the sender contains the url of the
  // corresponding receiver, so a reference to the receiver is maintained in the sender.
  ReceiverPtr receiver_;

  // The CPU device context used for allocating rpc message data.
  device::DeviceContext *cpu_device_context_;
};

// Receiver is used to receive data from another process.
class Receiver : public RpcOperator {
 public:
  explicit Receiver(device::DeviceContext *cpu_device_context)
      : ip_(""),
        port_(0),
        server_(nullptr),
        received_buffer_(nullptr),
        received_msg_(false),
        cpu_device_context_(cpu_device_context) {}
  ~Receiver() override;

  // Receive a message from the peer sender. This interface is synchronous and will wait for the message until the
  // timeout period is reached.
  std::unique_ptr<std::vector<char>> Receive();

  // Start the receiver server and register this server's address to the route table in the scheduler via the proxy.
  bool StartServer();

  // Get the url of this receiver, format: ip:port.
  std::string get_url() const { return ip_ + ":" + std::to_string(port_); }

 private:
  // The message callback of the tcp server.
  MessageBase *HandleMessage(MessageBase *const msg);

  // Parse the dynamic shape protobuf message. The format is as below:
  // |--------22 bytes-------|-------sizeof(size_t)-------|-dynamic shape PB data size-| real data size |
  // |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|---real data----|
  // The output parameter 'data' contains the real data addr and size.
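  // A parsing sketch under the layout above (illustrative only; the variable names and the exact arithmetic are
  // assumptions derived from the diagram, not guaranteed by this header):
  //   size_t offset = 22;  // The 22-byte "RPC_DYNAMIC_SHAPE_DATA" magic header.
  //   size_t pb_size = *reinterpret_cast<const size_t *>(msg_body + offset);
  //   offset += sizeof(size_t);
  //   ... parse pb_size bytes of dynamic shape protobuf at msg_body + offset ...
  //   *data = {msg_body + offset + pb_size, msg_len - offset - pb_size};  // The remaining bytes are the real data.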
  bool ParseDynamicShapeData(const char *msg_body, size_t msg_len, std::pair<const void *, size_t> *data) const;

  // The callback set to the rpc module to allocate the message (raw pointer).
  void *AllocateMessage(size_t size);

  // The network address of this receiver. It's generated automatically by the rpc module.
  std::string ip_;
  uint32_t port_;

  std::unique_ptr<TCPServer> server_;

  // The buffer used to save the received message content.
  std::unique_ptr<std::vector<char>> received_buffer_;

  // The flag which indicates whether a message was received successfully.
  std::atomic_bool received_msg_;

  // The interface 'Receive' is synchronous; a condition variable is used to block the thread and wait for the
  // message.
  std::condition_variable received_msg_cv_;
  std::mutex received_msg_mtx_;

  // The CPU device context used for allocating rpc message data.
  device::DeviceContext *cpu_device_context_;
};

using EmbeddingCachePrefetchActorPtr = std::shared_ptr<EmbeddingCachePrefetchActor>;
}  // namespace runtime
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_EMBEDDING_CACHE_PREFETCH_ACTOR_H_