1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
18 #include <limits>
19 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
20 #include "kernel/common_utils.h"
21 #include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"
22 #include "proto/topology.pb.h"
23 #include "include/backend/distributed/constants.h"
24 #include "include/backend/distributed/rpc/tcp/constants.h"
25 #include "runtime/graph_scheduler/actor/embedding_cache/device_dense_embedding_operation.h"
26 #include "runtime/graph_scheduler/actor/embedding_cache/device_sparse_embedding_operation.h"
27 #include "include/backend/distributed/embedding_cache/data_queue_manager.h"
28 
29 namespace mindspore {
30 namespace runtime {
31 using distributed::IdDataInfo;
32 using distributed::IndexDataInfo;
33 
34 using distributed::DataQueueManager;
35 using distributed::cluster::ClusterContext;
36 using mindspore::session::KernelGraph;
37 constexpr size_t kDefaultQueueCapacity = 128;
38 
39 namespace {
40 // Generate unique inter process edge name, format:
41 // src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
42 std::string GenerateInterProcessEdge(const std::string &src_role, uint32_t src_rank, const std::string &dst_role,
43                                      uint32_t dst_rank, const std::string &cache_operation, int32_t param_key) {
44   std::string edge = src_role + std::to_string(src_rank) + "->" + dst_role + std::to_string(dst_rank) + "_" +
45                      cache_operation + "_" + distributed::kParameterKey + std::to_string(param_key);
46   return edge;
47 }
48 
49 ActorRouteTableProxyPtr CreateRouteTableProxy() {
50   auto cgn = std::dynamic_pointer_cast<distributed::cluster::topology::ComputeGraphNode>(
51     ClusterContext::instance()->node_base());
52   ActorRouteTableProxyPtr actor_route_table_proxy = std::make_shared<ActorRouteTableProxy>(cgn);
53   MS_EXCEPTION_IF_NULL(actor_route_table_proxy);
54   return actor_route_table_proxy;
55 }
56 
57 // Create a paired sender and receiver.
58 // When creating a sender, the receiver paired with it must be created and specified in advance.
59 SendRecvPair CreateSenderReceiverPair(uint32_t worker_rank, uint32_t server_rank, const std::string &cache_operation,
60                                       int32_t param_key, device::DeviceContext *cpu_device_context) {
61   // Create sender and receiver pair.
62   ReceiverPtr receiver = std::make_shared<Receiver>(cpu_device_context);
63   MS_EXCEPTION_IF_NULL(receiver);
64   SenderPtr sender = std::make_shared<Sender>(cpu_device_context);
65   MS_EXCEPTION_IF_NULL(sender);
66   sender->set_receiver(receiver);
67 
68   // Set inter process edge
69   receiver->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfPServer, server_rank,
70                                                                  distributed::kEnvRoleOfWorker, worker_rank,
71                                                                  cache_operation, param_key));
72   sender->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfWorker, worker_rank,
73                                                                distributed::kEnvRoleOfPServer, server_rank,
74                                                                cache_operation, param_key));
75 
76   // Set route table proxy.
77   receiver->set_actor_route_table_proxy(CreateRouteTableProxy());
78   sender->set_actor_route_table_proxy(CreateRouteTableProxy());
79 
80   return std::make_pair(sender, receiver);
81 }
82 
83 // Get cache operation service id which is used to decide which set of cache services to request.
84 // The server side executes the corresponding service according to this id.
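// For example, with kEmbeddingCacheOps containing N operations, the id for (cache_operation, param_key) is
// N * param_key + operation_index, so every (parameter key, operation) pair maps to a distinct service id.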
85 int32_t GetCacheOpsServiceId(const std::string &cache_operation, int32_t param_key) {
86   static mindspore::HashMap<std::string, int32_t> cache_ops_to_index;
87   if (cache_ops_to_index.empty()) {
88     int32_t cnt = 0;
89     for (const auto &cache_op : distributed::kEmbeddingCacheOps) {
90       cache_ops_to_index[cache_op] = cnt++;
91     }
92   }
93 
94   auto iter = cache_ops_to_index.find(cache_operation);
95   if (iter == cache_ops_to_index.end()) {
96     MS_LOG(EXCEPTION) << "Can not find index for cache operation: " << cache_operation;
97   }
98 
99   int32_t id = SizeToInt(distributed::kEmbeddingCacheOps.size()) * param_key + iter->second;
100   return id;
101 }
102 
103 // Generate fixed or random numbers continuously in parallel using the specified algorithm.
104 template <typename T, typename Generator, typename Distribution, typename... Args>
105 void GenerateDistributionParallel(size_t size, T *output, Args... args) {
106   std::thread threads[kMaxThreadNum];
107   std::random_device rd;
108   const std::uint64_t seed = rd();
109 
110   // 1. Compute the total number of threads needed to generate the distribution in parallel and the number of
111   // values each thread needs to generate.
112   // A single evaluation of the normal distribution may produce two random values, so each thread should be
113   // responsible for producing an even number of random numbers, except for the last thread.
114   auto [thread_num, size_per_thread] = random::ComputeTaskNumSize(size, kMaxThreadNum);
115 
116   // For performance, multi-thread concurrency is not required when the total size is small.
117   if (thread_num == 1) {
118     random::GenerateRandoms<T, Generator, Distribution, Args...>(seed, 0, output, size, args...);
119     return;
120   }
121 
122   // 2. Generate the specified distribution in parallel using the specified algorithm.
123   // Note that an offset needs to be set on 'Generator' to prevent each thread from generating the same sequence.
124   size_t offset = 0;
125   for (size_t i = 0; i < thread_num; ++i) {
126     size_t task_len = ((i < (thread_num - 1)) ? size_per_thread : (size - ((thread_num - 1) * size_per_thread)));
127     threads[i] = std::thread(&random::GenerateRandoms<T, Generator, Distribution, Args...>, seed, offset,
128                              output + offset, task_len, args...);
129     offset += task_len;
130   }
131 
132   for (size_t j = 0; j < thread_num; j++) {
133     threads[j].join();
134   }
135 }
136 
137 void DeduplicateId(UniqueIds *unique_ids) {
138   MS_EXCEPTION_IF_NULL(unique_ids);
139 
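  // Deduplicate in three phases: build a per-batch hash set in parallel, merge all sets into the first set, and
  // finally flatten the merged set into the newly allocated ids_ array.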
140   constexpr size_t kMaxParallelNum = 32;
141   size_t parallel_num = unique_ids->multi_batch_data_.size();
142   if (parallel_num > kMaxParallelNum) {
143     MS_LOG(EXCEPTION) << "The parallel num: " << parallel_num
144                       << " can not be greater than max parallel num: " << kMaxParallelNum;
145   }
146   std::thread threads[kMaxParallelNum];
147 
148   std::vector<mindspore::HashSet<int>> unique_batch_ids_sets(parallel_num);
149   auto unique_task = [&](int *origin_batch_ids, size_t proc_len, mindspore::HashSet<int> *unique_set) {
150     (void)std::for_each(origin_batch_ids, origin_batch_ids + proc_len,
151                         [&unique_set](int id) { (void)unique_set->insert(id); });
152   };
153 
154   size_t i = 0;
155   for (; i < parallel_num; ++i) {
156     threads[i] = std::thread(unique_task, reinterpret_cast<int *>(unique_ids->multi_batch_data_.at(i)),
157                              unique_ids->multi_batch_size_.at(i), &unique_batch_ids_sets[i]);
158   }
159 
160   for (size_t j = 0; j < i; j++) {
161     threads[j].join();
162   }
163 
164   for (size_t k = 1; k < parallel_num; ++k) {
165     auto end_iter = unique_batch_ids_sets[k].end();
166     for (auto iter = unique_batch_ids_sets[k].begin(); iter != end_iter; ++iter) {
167       unique_batch_ids_sets[0].insert(*iter);
168     }
169   }
170   const auto &unique_ids_set = unique_batch_ids_sets.front();
171   unique_ids->ids_num_ = unique_ids_set.size();
172   unique_ids->ids_ = new (std::nothrow) int[unique_ids->ids_num_];
173   MS_EXCEPTION_IF_NULL(unique_ids->ids_);
174   size_t index = 0;
175   auto unique_ids_ptr = unique_ids->ids_;
176   (void)std::for_each(unique_ids_set.begin(), unique_ids_set.end(), [&](int id) { unique_ids_ptr[index++] = id; });
177 }
178 
179 void TransformIdsToIndices(mindspore::HashMap<int, int> *unique_ids_to_indices, size_t batch_ids_num, int *batch_ids) {
180   auto change_id_to_index_func = [&](int *batch_ids_ptr, size_t proc_len) {
181     for (size_t i = 0; i < proc_len; i++) {
182       batch_ids_ptr[i] = (*unique_ids_to_indices)[batch_ids_ptr[i]];
183     }
184   };
185 
186   size_t thread_num = batch_ids_num / kMaxIdsPerThread + 1;
187   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
188   std::thread threads[kMaxThreadNum];
189   size_t i = 0;
190   size_t offset = 0;
191 
192   for (; i < thread_num; ++i) {
193     size_t proc_len = batch_ids_num / thread_num + (i < (batch_ids_num % thread_num) ? 1 : 0);
194     threads[i] = std::thread(change_id_to_index_func, batch_ids + offset, proc_len);
195     offset += proc_len;
196   }
197   if (offset != batch_ids_num) {
198     MS_LOG(WARNING) << "Not all ids were assigned to transform threads, total:" << batch_ids_num << " assigned:" << offset;
199   }
200 
201   for (size_t j = 0; j < i; j++) {
202     threads[j].join();
203   }
204 }
205 }  // namespace
206 
207 void EmbeddingCachePrefetchActor::Initialize() {
208   if (initialized_) {
209     return;
210   }
211   MS_EXCEPTION_IF_NULL(device_context_);
212   MS_EXCEPTION_IF_NULL(device_context_->device_res_manager_);
213   if (!device_context_->device_res_manager_->CreateStream(&stream_id_)) {
214     MS_LOG(EXCEPTION) << "Create stream failed.";
215   }
216 
217   // Get embedding cache table info.
218   local_host_cache_size_ = embedding_cache_table_manager.host_cache_size_;
219   vocab_size_ = embedding_cache_table_manager.vocab_size_;
220   local_embedding_slice_bounds_ = embedding_cache_table_manager.local_embedding_slice_bounds_;
221   local_device_cache_bounds_ = embedding_cache_table_manager.local_device_cache_bounds_;
222 
223   // Initialize the CPU device context. The original device context for the embedding cache prefetch actor is GPU or
224   // NPU, but we still need the CPU device context to allocate host memory.
225   device::DeviceContextKey host_key = {"CPU", 0};
226   cpu_device_context_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
227   MS_EXCEPTION_IF_NULL(cpu_device_context_);
228   cpu_device_context_->Initialize();
229 
230   server_num_ = PSContext::instance()->server_num();
231   if (server_num_ == 0) {
232     MS_LOG(EXCEPTION) << "The number of servers must be at least 1, but got 0";
233   }
234 
235   // Build and link rpc operators.
236   BuildRpcOperators();
237   LinkRpcOperators();
238 
239   // Create the device embedding operation.
240   emb_ops_ = new (std::nothrow) DeviceDenseEmbeddingOperation(
241     this, device_context_, local_embedding_slice_bounds_, local_device_cache_bounds_, &statistics_info_, stream_id_);
242   MS_EXCEPTION_IF_NULL(emb_ops_);
243   if (!emb_ops_->Initialize()) {
244     MS_LOG(ERROR) << "Failed to initialize the device embedding operation.";
245   }
246 
247   // Get the id range of each server's embedding table slice.
248   emb_ops_->GetRemoteEmbeddingSliceBound(vocab_size_, server_num_, &remote_embedding_slice_bounds_);
249 
250   initialized_ = true;
251 }
252 
253 void EmbeddingCachePrefetchActor::Finalize(bool finalize_remote) {
254   std::lock_guard<std::mutex> lock(finalize_mutex_);
255   if (!initialized_ || finalized_) {
256     return;
257   }
258 
259   StopPrefetchCachePipeline();
260   for (const auto &item : channel_locks_) {
261     const auto &channel_ptr = item.second;
262     MS_EXCEPTION_IF_NULL(channel_ptr);
263     channel_ptr->TryWakeChannel(true);
264   }
265   WaitPrefetchCacheFinish();
266 
267   PsDataPrefetch::GetInstance().NotifyFinalize();
268 
269   if (finalize_remote) {
270     (void)FinalizeRemote();
271   }
272 
273   data_parser_.notify_all();
274 
275   if (emb_ops_ != nullptr) {
276     delete emb_ops_;
277     emb_ops_ = nullptr;
278   }
279 
280   rpc_operators_.clear();
281   finalized_ = true;
282   initialized_ = false;
283 }
284 
285 void EmbeddingCachePrefetchActor::IncreaseGraphStep(const std::string &channel_name) {
286   if (!running_) {
287     std::string error_info =
288       !error_info_.empty() ? error_info_ : "Embedding cache prefetch actor is finalized abnormally.";
289     MS_LOG(EXCEPTION) << error_info;
290   }
291   if (graph_step_ >= UINT64_MAX) {
292     MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
293   }
294   if (graph_step_ == 0) {
295     MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameters_on_remote_;
296     std::unique_lock<std::mutex> locker(data_mutex_);
297     data_parser_.wait(locker, [this] { return ((finish_init_parameters_on_remote_ == true) || (running_ == false)); });
298     if (!running_) {
299       std::string error_info =
300         !error_info_.empty() ? error_info_ : "Embedding cache prefetch actor is finalized abnormally.";
301       MS_LOG(EXCEPTION) << error_info;
302     }
303     distributed::EmbeddingCacheTableManager::GetInstance().WaitForWarmUpHostCacheComplete();
304     MS_LOG(INFO) << "Graph running waiting embedding table init end.";
305   }
306   graph_step_++;
307   if (embedding_cache_table_manager.enable_pipeline()) {
308     if (channel_name != channel_name_) {
309       set_channel_name(channel_name);
310       // Create pipeline tasks for this channel
311       StartPrefetchCachePipeline(channel_name);
312     }
313     const auto &iter = channel_locks_.find(channel_name);
314     if (iter == channel_locks_.end()) {
315       MS_LOG(EXCEPTION) << "Can not find channel lock for channel: " << channel_name;
316     }
317     MS_EXCEPTION_IF_NULL(iter->second);
318     iter->second->TryWakeChannel();
319   } else {
320     set_channel_name(channel_name);
321     if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) {
322       MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name;
323     }
324   }
325   data_parser_.notify_one();
326 }
327 
328 void EmbeddingCachePrefetchActor::Run() {
329   running_ = true;
330 
331   // Bind device to current thread to gain device control privileges
332   MS_EXCEPTION_IF_NULL(device_context_);
333   MS_EXCEPTION_IF_NULL(device_context_->device_res_manager_);
334   if (!device_context_->device_res_manager_->BindDeviceToCurrentThread(false)) {
335     MS_LOG(ERROR) << "Failed to bind device to current thread.";
336     running_ = false;
337     PsDataPrefetch::GetInstance().NotifyFinalize();
338     return;
339   }
340 
341   // Wait initialize parameters on remote.
342   // Prevents the subsequent prefetch cache from failing due to the long initialization time of the large parameter on
343   // the remote side.
344   WaitInitParametersOnRemote();
345 
346   // Wait data channel ready.
347   WaitDataChannelInit();
348 }
349 
350 void EmbeddingCachePrefetchActor::CreateChannelLock(const std::string &channel_name) {
351   if (channel_locks_.find(channel_name) != channel_locks_.end()) {
352     return;
353   }
354   auto sink_size = DataQueueManager::GetInstance().GetSinkSize(channel_name);
355   channel_locks_.emplace(channel_name, std::make_shared<PsDataChannel>(channel_name, sink_size));
356 }
357 
358 void EmbeddingCachePrefetchActor::CreateBlockQueue(const std::string &channel_name) {
359   auto unique_ids_queue = std::make_shared<BlockingQueue<UniqueIds>>(kDefaultQueueCapacity);
360   auto cache_analysis_queue = std::make_shared<BlockingQueue<CacheAnalysis>>(kDefaultQueueCapacity);
361   auto ids_and_indices_queue = std::make_shared<BlockingQueue<IdsAndIndices>>(kDefaultQueueCapacity);
362   (void)channel_to_queues_.emplace(channel_name,
363                                    std::make_tuple(unique_ids_queue, cache_analysis_queue, ids_and_indices_queue));
364 }
365 
366 void EmbeddingCachePrefetchActor::StartPrefetchCachePipeline(const std::string &channel_name) {
367   MS_LOG(INFO) << "Begin StartPrefetchCachePipeline for channel name: " << channel_name;
368   std::lock_guard<std::mutex> lock(pipeline_mutex_);
369   if (pipeline_stages_.find(channel_name) != pipeline_stages_.end()) {
370     return;
371   }
372 
373   CreateChannelLock(channel_name);
374   CreateBlockQueue(channel_name);
375 
376   auto thread_list = std::make_shared<std::vector<std::thread>>(kPipelineStageNum);
377   pipeline_stages_.emplace(channel_name, thread_list);
378 
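  // The prefetch pipeline consists of four stages connected by the blocking queues created above:
  // deduplicate ids -> analyse cache hits/misses -> update device/host caches -> transform ids to cache indices.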
379   thread_list->at(kIndex0) = std::thread(&EmbeddingCachePrefetchActor::UniqueIdsTask, this, channel_name);
380   thread_list->at(kIndex1) = std::thread(&EmbeddingCachePrefetchActor::AnalyseCacheTask, this, channel_name);
381   thread_list->at(kIndex2) = std::thread(&EmbeddingCachePrefetchActor::UpdateCacheTask, this, channel_name);
382   thread_list->at(kIndex3) = std::thread(&EmbeddingCachePrefetchActor::TransformIdsToIndicesTask, this, channel_name);
383   MS_LOG(INFO) << "End StartPrefetchCachePipeline for channel name: " << channel_name;
384 }
385 
386 void EmbeddingCachePrefetchActor::StopPrefetchCachePipeline() {
387   MS_LOG(INFO) << "Begin StopPrefetchCachePipeline";
388   std::lock_guard<std::mutex> lock(pipeline_mutex_);
389   running_ = false;
390   DataQueueManager::GetInstance().CloseAllQueues();
391   for (const auto &item : channel_to_queues_) {
392     const BlockingQueueTuple &queues_tuple = item.second;
393     const auto &unique_ids_queue = std::get<std::shared_ptr<BlockingQueue<UniqueIds>>>(queues_tuple);
394     MS_EXCEPTION_IF_NULL(unique_ids_queue);
395     unique_ids_queue->Close();
396 
397     const auto &cache_analysis_queue = std::get<std::shared_ptr<BlockingQueue<CacheAnalysis>>>(queues_tuple);
398     MS_EXCEPTION_IF_NULL(cache_analysis_queue);
399     cache_analysis_queue->Close();
400 
401     const auto &ids_and_indices_queue = std::get<std::shared_ptr<BlockingQueue<IdsAndIndices>>>(queues_tuple);
402     MS_EXCEPTION_IF_NULL(ids_and_indices_queue);
403     ids_and_indices_queue->Close();
404   }
405   MS_LOG(INFO) << "End StopPrefetchCachePipeline";
406 }
407 
408 void EmbeddingCachePrefetchActor::WaitPrefetchCacheFinish() {
409   std::lock_guard<std::mutex> lock(pipeline_mutex_);
410   for (auto &item : pipeline_stages_) {
411     const std::string &channel_name = item.first;
412     MS_LOG(INFO) << "Begin stop pipeline for channel: " << channel_name;
413     auto stage_threads = item.second;
414     MS_EXCEPTION_IF_NULL(stage_threads);
415     for (size_t i = 0; i < kPipelineStageNum; ++i) {
416       if (stage_threads->at(i).joinable()) {
417         stage_threads->at(i).join();
418       }
419     }
420     MS_LOG(INFO) << "End stop pipeline for channel: " << channel_name;
421   }
422 }
423 
424 void EmbeddingCachePrefetchActor::UniqueIdsTask(const std::string &channel_name) {
425   const auto &iter = channel_locks_.find(channel_name);
426   if (iter == channel_locks_.end()) {
427     MS_LOG(EXCEPTION) << "Can not find channel lock for channel: " << channel_name;
428   }
429   auto channel_lock = iter->second;
430   MS_EXCEPTION_IF_NULL(channel_lock);
431 
432   const auto &id_data_queue = DataQueueManager::GetInstance().GetDataQueue(channel_name).first;
433   MS_EXCEPTION_IF_NULL(id_data_queue);
434 
435   const auto &queue_iter = channel_to_queues_.find(channel_name);
436   if (queue_iter == channel_to_queues_.end()) {
437     MS_LOG(EXCEPTION) << "Can not find queue for channel: " << channel_name;
438   }
439   const auto &unique_ids_queue = std::get<std::shared_ptr<BlockingQueue<UniqueIds>>>(queue_iter->second);
440   MS_EXCEPTION_IF_NULL(unique_ids_queue);
441 
442   size_t sink_size = DataQueueManager::GetInstance().GetSinkSize(channel_name);
443   size_t multi_batch_counter = 0;
444   UniqueIds *unique_ids = nullptr;
445   while (running_) {
446     IdDataInfo *data = id_data_queue->Pop();
447     if (!running_) {
448       break;
449     }
450     MS_EXCEPTION_IF_NULL(data);
451     int *batch_ids = reinterpret_cast<int *>(data->data_);
452     if (batch_ids) {
453       // Lock in first stage to support multi-channel case for real input data.
454       channel_lock->TryLockChannel();
455       MS_EXCEPTION_IF_CHECK_FAIL(IncreaseStep(), "Increase step failed.");
456     }
457 
458     if (data->end_of_file_ || data->end_of_epoch_) {
459       // Push empty data for epoch or file end flag.
460       UniqueIds *empty_unique_ids = new (std::nothrow) UniqueIds();
461       MS_EXCEPTION_IF_NULL(empty_unique_ids);
462       empty_unique_ids->end_of_epoch_ = data->end_of_epoch_;
463       empty_unique_ids->end_of_file_ = data->end_of_file_;
464       unique_ids_queue->Push(empty_unique_ids);
465       delete data;
466       continue;
467     }
468 
469     if (unique_ids == nullptr) {
470       unique_ids = new (std::nothrow) UniqueIds();
471       MS_EXCEPTION_IF_NULL(unique_ids);
472     }
473 
474     ++multi_batch_counter;
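    // Accumulate batches until the multi-batch threshold (or the sink size boundary) is reached, then deduplicate
    // all accumulated batches at once.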
475     if (multi_batch_counter < sink_size &&
476         multi_batch_counter % embedding_cache_table_manager.multi_batch_threshold_ != 0) {
477       unique_ids->multi_batch_data_.push_back(batch_ids);
478       unique_ids->multi_batch_size_.push_back(data->size_ / sizeof(int));
479       unique_ids->multi_batch_items_.push_back(data->items_);
480       continue;
481     }
482     unique_ids->multi_batch_data_.push_back(batch_ids);
483     unique_ids->multi_batch_size_.push_back(data->size_ / sizeof(int));
484     unique_ids->multi_batch_items_.push_back(data->items_);
485 
486     if (multi_batch_counter == sink_size) {
487       multi_batch_counter = 0;
488     }
489 
490     // Deduplicate the accumulated batches and store the unique ids.
491     DeduplicateId(unique_ids);
492     // Push to next stage pipeline queue.
493     unique_ids->data_step_ = data_step_;
494 
495     unique_ids_queue->Push(unique_ids);
496     unique_ids = nullptr;
497     delete data;
498   }
499 }
500 
501 void EmbeddingCachePrefetchActor::AnalyseCacheTask(const std::string &channel_name) {
502   const auto &queue_iter = channel_to_queues_.find(channel_name);
503   if (queue_iter == channel_to_queues_.end()) {
504     MS_LOG(EXCEPTION) << "Can not find queue for channel: " << channel_name;
505   }
506 
507   const auto &unique_ids_queue = std::get<std::shared_ptr<BlockingQueue<UniqueIds>>>(queue_iter->second);
508   MS_EXCEPTION_IF_NULL(unique_ids_queue);
509   const auto &cache_analysis_queue = std::get<std::shared_ptr<BlockingQueue<CacheAnalysis>>>(queue_iter->second);
510   MS_EXCEPTION_IF_NULL(cache_analysis_queue);
511 
512   while (running_) {
513     UniqueIds *unique_ids = unique_ids_queue->Pop();
514     if (!running_) {
515       break;
516     }
517     MS_EXCEPTION_IF_NULL(unique_ids);
518     if (unique_ids->end_of_file_ || unique_ids->end_of_epoch_) {
519       // Push empty data for epoch or file end flag.
520       CacheAnalysis *cache_analysis = new (std::nothrow) CacheAnalysis();
521       MS_EXCEPTION_IF_NULL(cache_analysis);
522       cache_analysis->end_of_epoch_ = unique_ids->end_of_epoch_;
523       cache_analysis->end_of_file_ = unique_ids->end_of_file_;
524       cache_analysis_queue->Push(cache_analysis);
525       delete unique_ids;
526       continue;
527     }
528     size_t unique_ids_num = unique_ids->ids_num_;
529     int *indices = new (std::nothrow) int[unique_ids_num];
530     MS_EXCEPTION_IF_NULL(indices);
531 
532     EmbeddingDeviceCache *embedding_device_cache = new (std::nothrow) EmbeddingDeviceCache(unique_ids_num);
533     MS_EXCEPTION_IF_NULL(embedding_device_cache);
534     EmbeddingHostCache *embedding_host_cache = new (std::nothrow) EmbeddingHostCache(unique_ids_num);
535     MS_EXCEPTION_IF_NULL(embedding_host_cache);
536     EmbeddingCacheStatisticsInfo *statistics_info = new (std::nothrow) EmbeddingCacheStatisticsInfo();
537     MS_EXCEPTION_IF_NULL(statistics_info);
538 
539     // Analyse cache hit/miss
540     if (!emb_ops_->AnalyseCache(unique_ids->ids_, unique_ids_num, unique_ids->data_step_, &graph_step_,
541                                 &device_cache_need_wait_graph_, &host_cache_need_wait_graph_, indices,
542                                 embedding_device_cache, embedding_host_cache, statistics_info)) {
543       MS_LOG(ERROR) << "Analyse cache failed.";
544       StopPrefetchCachePipeline();
545       return;
546     }
547 
548     // Push analyse result to update cache queue
549     CacheAnalysis *cache_analysis =
550       new (std::nothrow) CacheAnalysis(embedding_device_cache, embedding_host_cache, statistics_info, unique_ids,
551                                        indices, unique_ids->end_of_epoch_, unique_ids->end_of_file_);
552     MS_EXCEPTION_IF_NULL(cache_analysis);
553     cache_analysis_queue->Push(cache_analysis);
554   }
555 }
556 
557 void EmbeddingCachePrefetchActor::UpdateCacheTask(const std::string &channel_name) {
558   const auto &queue_iter = channel_to_queues_.find(channel_name);
559   if (queue_iter == channel_to_queues_.end()) {
560     MS_LOG(EXCEPTION) << "Can not find queue for channel: " << channel_name;
561   }
562 
563   const auto &cache_analysis_queue = std::get<std::shared_ptr<BlockingQueue<CacheAnalysis>>>(queue_iter->second);
564   MS_EXCEPTION_IF_NULL(cache_analysis_queue);
565 
566   const auto &ids_and_indices_queue = std::get<std::shared_ptr<BlockingQueue<IdsAndIndices>>>(queue_iter->second);
567   MS_EXCEPTION_IF_NULL(ids_and_indices_queue);
568 
569   while (running_) {
570     CacheAnalysis *cache_analysis = cache_analysis_queue->Pop();
571     if (!running_) {
572       break;
573     }
574     MS_EXCEPTION_IF_NULL(cache_analysis);
575     if (cache_analysis->end_of_file_ || cache_analysis->end_of_epoch_) {
576       // Push empty data for epoch or file end flag.
577       IdsAndIndices *ids_and_indices = new (std::nothrow) IdsAndIndices();
578       MS_EXCEPTION_IF_NULL(ids_and_indices);
579       ids_and_indices->end_of_epoch_ = cache_analysis->end_of_epoch_;
580       ids_and_indices->end_of_file_ = cache_analysis->end_of_file_;
581       ids_and_indices_queue->Push(ids_and_indices);
582       delete cache_analysis;
583       continue;
584     }
585 
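    // For each embedding table: evict host entries to the server, evict device entries to the host, initialize
    // values for ids never seen before, then fill the host cache from the server and the device cache from the host.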
586     for (const auto &item : embedding_cache_table_manager.hash_tables_) {
587       const auto &hash_info = item.second;
588       MS_EXCEPTION_IF_CHECK_FAIL(PushCacheFromLocalHostToRemote(hash_info, cache_analysis),
589                                  "Push cache from local host to remote failed.");
590       MS_EXCEPTION_IF_CHECK_FAIL(emb_ops_->PushCacheFromDeviceToLocalHost(hash_info, cache_analysis),
591                                  "Push cache from device to local host failed.");
592       MS_EXCEPTION_IF_CHECK_FAIL(InitLocalCacheForNewIds(hash_info, cache_analysis),
593                                  "Initialize local cache values for new ids failed.");
594       MS_EXCEPTION_IF_CHECK_FAIL(PullCacheFromRemoteToLocalHost(hash_info, cache_analysis),
595                                  "Pull cache from remote to local host failed.");
596       MS_EXCEPTION_IF_CHECK_FAIL(emb_ops_->PullCacheFromLocalHostToDevice(hash_info, cache_analysis),
597                                  "Pull cache from local host to device failed.");
598     }
599 
600     IdsAndIndices *ids_and_indices =
601       new (std::nothrow) IdsAndIndices(cache_analysis->unique_ids_, cache_analysis->indices_,
602                                        cache_analysis->end_of_epoch_, cache_analysis->end_of_file_);
603 
604     ids_and_indices_queue->Push(ids_and_indices);
605 
606     delete cache_analysis->embedding_host_cache_;
607     delete cache_analysis->embedding_device_cache_;
608     delete cache_analysis->statistics_info_;
609     delete cache_analysis;
610   }
611 }
612 
613 void EmbeddingCachePrefetchActor::TransformIdsToIndicesTask(const std::string &channel_name) {
614   const auto &queue_iter = channel_to_queues_.find(channel_name);
615   if (queue_iter == channel_to_queues_.end()) {
616     MS_LOG(EXCEPTION) << "Can not find queue for channel: " << channel_name;
617   }
618   const auto &ids_and_indices_queue = std::get<std::shared_ptr<BlockingQueue<IdsAndIndices>>>(queue_iter->second);
619   MS_EXCEPTION_IF_NULL(ids_and_indices_queue);
620 
621   const auto &index_data_queue = DataQueueManager::GetInstance().GetDataQueue(channel_name).second;
622   MS_EXCEPTION_IF_NULL(index_data_queue);
623   while (running_) {
624     IdsAndIndices *ids_and_indices = ids_and_indices_queue->Pop();
625     if (!running_) {
626       break;
627     }
628     MS_EXCEPTION_IF_NULL(ids_and_indices);
629     // Push empty data for epoch or file end flag.
630     if (ids_and_indices->end_of_file_ || ids_and_indices->end_of_epoch_) {
631       IndexDataInfo *indices_info = new (std::nothrow) IndexDataInfo();
632       MS_EXCEPTION_IF_NULL(indices_info);
633       indices_info->end_of_epoch_ = ids_and_indices->end_of_epoch_;
634       indices_info->end_of_file_ = ids_and_indices->end_of_file_;
635       index_data_queue->Push(indices_info);
636       delete ids_and_indices;
637       continue;
638     }
639 
640     auto *unique_ids = ids_and_indices->unique_ids_;
641     MS_EXCEPTION_IF_NULL(unique_ids);
642     auto *unique_ids_ptr = unique_ids->ids_;
643     auto unique_ids_num = unique_ids->ids_num_;
644     auto *unique_indices_ptr = ids_and_indices->indices_;
645 
646     mindspore::HashMap<int, int> unique_ids_to_indices;
647     for (size_t i = 0; i < unique_ids_num; i++) {
648       (void)unique_ids_to_indices.try_emplace(unique_ids_ptr[i], unique_indices_ptr[i]);
649     }
650 
651     for (size_t i = 0; i < unique_ids->multi_batch_data_.size(); ++i) {
652       if (!embedding_cache_table_manager.is_sparse_format()) {
653         TransformIdsToIndices(&unique_ids_to_indices, unique_ids->multi_batch_size_.at(i),
654                               reinterpret_cast<int *>(unique_ids->multi_batch_data_.at(i)));
655       }
656       IndexDataInfo *indices_info =
657         new IndexDataInfo(unique_ids->multi_batch_data_.at(i), unique_ids->multi_batch_items_.at(i),
658                           ids_and_indices->end_of_epoch_, ids_and_indices->end_of_file_);
659 
660       if (unique_ids->multi_batch_data_.at(i) != unique_ids->multi_batch_items_.at(i)->at(0).data_ptr) {
661         MS_LOG(EXCEPTION) << "The id data ptr is invalid.";
662       }
663 
664       index_data_queue->Push(indices_info);
665     }
666 
667     delete[] ids_and_indices->unique_ids_->ids_;
668     delete ids_and_indices->unique_ids_;
669     delete[] ids_and_indices->indices_;
670     delete ids_and_indices;
671   }
672 }
673 
674 bool EmbeddingCachePrefetchActor::IncreaseStep() {
675   if (data_step_ >= UINT64_MAX) {
676     MS_LOG(ERROR) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t.";
677     return false;
678   }
679   data_step_++;
680   set_current_graph_step();
681   if (graph_running_step_ > data_step_) {
682     MS_LOG(ERROR) << "The graph running step (" << graph_running_step_ << ") exceeds the data step (" << data_step_
683                   << ").";
684     return false;
685   }
686   return true;
687 }
688 
689 bool EmbeddingCachePrefetchActor::WaitGraphRun() {
690   MS_LOG(INFO) << "Hash table has no space to insert new data, will retry for up to 2 minutes.";
691   std::unique_lock<std::mutex> locker(data_mutex_);
692   const int64_t longest_time_to_wait = 120;
693   if (!data_parser_.wait_for(locker, std::chrono::seconds(longest_time_to_wait),
694                              [this] { return graph_step_ > graph_running_step_; })) {
695     std::string err_info = "Prefetch embedding cache timeout, please enlarge the vocab cache size(graph step:" +
696                            std::to_string(graph_step_) + ", graph running step:" + std::to_string(graph_running_step_) +
697                            ").";
698     SetErrorInfo(err_info);
699     MS_LOG(ERROR) << err_info;
700     return false;
701   }
702   set_current_graph_step();
703   return true;
704 }
705 
706 bool EmbeddingCachePrefetchActor::ResetEmbeddingHashMap() {
707   const auto &device_hash_map = embedding_cache_table_manager.device_hash_map_;
708   MS_ERROR_IF_NULL(device_hash_map);
709   const auto &host_hash_map = embedding_cache_table_manager.host_hash_map_;
710   MS_ERROR_IF_NULL(host_hash_map);
711   device_hash_map->Reset();
712   host_hash_map->Reset();
713   device_cache_need_wait_graph_ = false;
714   host_cache_need_wait_graph_ = false;
715   return true;
716 }
717 
718 void EmbeddingCachePrefetchActor::LookupEmbeddingTable(size_t indices_num, size_t embedding_size, size_t first_dim_size,
719                                                        const float *input_addr, const int *indices_addr,
720                                                        float *output_addr) {
721   MS_ERROR_IF_NULL_WO_RET_VAL(input_addr);
722   MS_ERROR_IF_NULL_WO_RET_VAL(indices_addr);
723   MS_ERROR_IF_NULL_WO_RET_VAL(output_addr);
724 
725   auto type_size = sizeof(float);
726   size_t lens = embedding_size * type_size;
727   for (size_t i = 0; i < indices_num; ++i) {
728     int index = indices_addr[i];
729     if (index >= 0 && index < SizeToInt(first_dim_size)) {
730       size_t pos = IntToSize(index) * embedding_size;
731       auto ret = memcpy_s(output_addr, (indices_num - i) * lens, input_addr + pos, lens);
732       if (ret != EOK) {
733         MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
734         StopPrefetchCachePipeline();
735         return;
736       }
737     } else {
738       auto ret = memset_s(output_addr, (indices_num - i) * lens, 0, lens);
739       if (ret != EOK) {
740         MS_LOG(ERROR) << "Memset failed, errno[" << ret << "]";
741         StopPrefetchCachePipeline();
742         return;
743       }
744     }
745     output_addr += embedding_size;
746   }
747 }
748 
749 bool EmbeddingCachePrefetchActor::PushCacheFromLocalHostToRemote(const HashTableInfo &hash_info,
750                                                                  const CacheAnalysis *cache_analysis) {
751   MS_ERROR_IF_NULL(cache_analysis);
752   auto statistics_info = cache_analysis->statistics_info_;
753   auto embedding_host_cache = cache_analysis->embedding_host_cache_;
754   MS_ERROR_IF_NULL(statistics_info);
755   MS_ERROR_IF_NULL(embedding_host_cache);
756 
757   auto swap_indices_size = statistics_info->host_to_server_size_;
758   if (swap_indices_size == 0) {
759     return true;
760   }
761 
762   auto host_to_server_ids = embedding_host_cache->host_to_server_ids.get();
763   MS_ERROR_IF_NULL(host_to_server_ids);
764   auto host_to_server_index = embedding_host_cache->host_to_server_index.get();
765   MS_ERROR_IF_NULL(host_to_server_index);
766 
767   std::vector<float> swap_out_data;
768   auto embedding_size = hash_info.embedding_size;
769   swap_out_data.resize(swap_indices_size * embedding_size);
770   auto host_hash_table_addr = hash_info.host_address;
771 
772   RETURN_IF_FALSE_WITH_LOG(LookupLocalHostCache(embedding_size, swap_indices_size, host_hash_table_addr,
773                                                 host_to_server_index, swap_out_data.data()),
774                            "Lookup local host cache failed.");
775   RETURN_IF_FALSE_WITH_LOG(PushEmbeddingsToRemote(hash_info.param_key_, host_to_server_ids, swap_indices_size,
776                                                   swap_out_data.data(), swap_out_data.size() * sizeof(float)),
777                            "Push embeddings to remote failed.");
778   return true;
779 }
780 
781 bool EmbeddingCachePrefetchActor::PullCacheFromRemoteToLocalHost(const HashTableInfo &hash_info,
782                                                                  const CacheAnalysis *cache_analysis) {
783   MS_ERROR_IF_NULL(cache_analysis);
784   auto statistics_info = cache_analysis->statistics_info_;
785   auto embedding_host_cache = cache_analysis->embedding_host_cache_;
786   MS_ERROR_IF_NULL(statistics_info);
787   MS_ERROR_IF_NULL(embedding_host_cache);
788 
789   auto swap_indices_size = statistics_info->server_to_host_size_;
790   if (swap_indices_size == 0) {
791     return true;
792   }
793 
794   auto server_to_host_ids = embedding_host_cache->server_to_host_ids.get();
795   MS_ERROR_IF_NULL(server_to_host_ids);
796   auto server_to_host_index = embedding_host_cache->server_to_host_index.get();
797   MS_ERROR_IF_NULL(server_to_host_index);
798 
799   auto host_hash_table_addr = hash_info.host_address;
800   MS_ERROR_IF_NULL(host_hash_table_addr);
801   auto embedding_size = hash_info.embedding_size;
802   std::vector<float> lookup_result(swap_indices_size * embedding_size, 0);
803 
804   RETURN_IF_FALSE_WITH_LOG(
805     PullEembeddingsFromRemote(hash_info.param_key_, server_to_host_ids, swap_indices_size, &lookup_result),
806     "Pull embedding from remote failed.");
807   RETURN_IF_FALSE_WITH_LOG(InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), server_to_host_index,
808                                                 lookup_result.data(), host_hash_table_addr),
809                            "Insert local host cache failed.");
810   return true;
811 }
812 
813 bool EmbeddingCachePrefetchActor::InitLocalCacheForNewIds(const HashTableInfo &hash_info,
814                                                           const CacheAnalysis *cache_analysis) {
815   MS_ERROR_IF_NULL(cache_analysis);
816   auto statistics_info = cache_analysis->statistics_info_;
817   auto embedding_host_cache = cache_analysis->embedding_host_cache_;
818   MS_ERROR_IF_NULL(statistics_info);
819   MS_ERROR_IF_NULL(embedding_host_cache);
820 
821   auto new_id_size = statistics_info->new_id_size_;
822   if (new_id_size == 0) {
823     return true;
824   }
825 
826   auto new_id_index = embedding_host_cache->new_id_index.get();
827   MS_ERROR_IF_NULL(new_id_index);
828 
829   // Compute the feature values size needed to be initialized.
830   auto embedding_size = hash_info.embedding_size;
831   auto total_size = new_id_size * embedding_size;
832   std::vector<float> init_result(total_size, 0);
833 
834   // Initialize accumulation values with the configured constant value.
835   if (hash_info.param_init_info_.param_type_ == distributed::ParamType::kAccumulation) {
836     auto init_value = hash_info.param_init_info_.init_val_;
837     GenerateDistributionParallel<DataType, Generator, ConstantDistribution>(total_size, init_result.data(), init_value);
838   } else {
839     // Initialize embedding values from local random generator for feature ids that have never been seen before.
840     const double mean = 0.0;
841     const double sigma = 0.01;
842     GenerateDistributionParallel<DataType, Generator, NormalDistribution>(total_size, init_result.data(), mean, sigma);
843   }
844 
845   // Insert initialized feature values into the local hash cache.
846   auto host_hash_table_addr = hash_info.host_address;
847   MS_ERROR_IF_NULL(host_hash_table_addr);
848   RETURN_IF_FALSE_WITH_LOG(InsertLocalHostCache(embedding_size, IntToSize(new_id_size), new_id_index,
849                                                 init_result.data(), host_hash_table_addr),
850                            "Insert local host cache failed.");
851   return true;
852 }
853 
854 bool EmbeddingCachePrefetchActor::InsertLocalHostCache(size_t embedding_size, size_t insert_indices_size,
855                                                        const int *insert_indices, const float *insert_data,
856                                                        float *hash_table_addr) {
857   MS_ERROR_IF_NULL(insert_indices);
858   MS_ERROR_IF_NULL(insert_data);
859   MS_ERROR_IF_NULL(hash_table_addr);
860 
861   size_t first_dim_size = local_host_cache_size_;
862   size_t thread_num = insert_indices_size / kMaxIdsPerThread + 1;
863   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
864   std::thread threads[kMaxThreadNum];
865   size_t proc_len = (insert_indices_size + thread_num - 1) / thread_num;
866   size_t i = 0;
867   size_t offset = 0;
868 
869   auto insert_cache_func = [this](size_t insert_indices_size, size_t embedding_size, size_t first_dim_size,
870                                   const int *insert_indices, const float *insert_data, float *hash_table_addr) {
871     auto type_size = sizeof(float);
872     size_t copy_len = embedding_size * type_size;
873     size_t dest_len = copy_len;
874     for (size_t i = 0; i < insert_indices_size; ++i) {
875       int index = insert_indices[i];
876       if (index >= 0 && index < SizeToInt(first_dim_size)) {
877         auto ret =
878           memcpy_s(hash_table_addr + index * embedding_size, dest_len, insert_data + i * embedding_size, copy_len);
879         if (ret != EOK) {
880           MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
881           StopPrefetchCachePipeline();
882           return;
883         }
884       }
885     }
886   };
887 
888   for (; i < thread_num; i++) {
889     if (offset >= insert_indices_size) {
890       break;
891     }
892     threads[i] = std::thread(insert_cache_func, proc_len, embedding_size, first_dim_size, insert_indices + offset,
893                              insert_data + offset * embedding_size, hash_table_addr);
894     offset += proc_len;
895     if (offset + proc_len > insert_indices_size) {
896       proc_len = insert_indices_size - offset;
897     }
898   }
899 
900   for (size_t j = 0; j < i; j++) {
901     threads[j].join();
902   }
903   return running_;
904 }
905 
906 bool EmbeddingCachePrefetchActor::LookupLocalHostCache(size_t embedding_size, size_t indices_num,
907                                                        const float *hash_table_addr, const int *indices_addr,
908                                                        float *output_addr) {
909   MS_ERROR_IF_NULL(hash_table_addr);
910   MS_ERROR_IF_NULL(indices_addr);
911   MS_ERROR_IF_NULL(output_addr);
912 
913   size_t first_dim_size = local_host_cache_size_;
914   size_t thread_num = indices_num / kMaxIdsPerThread + 1;
915   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
916   std::thread threads[kMaxThreadNum];
917   size_t proc_len = (indices_num + thread_num - 1) / thread_num;
918   size_t i = 0;
919   size_t offset = 0;
920 
921   for (; i < thread_num; i++) {
922     if (offset >= indices_num) {
923       break;
924     }
925     threads[i] =
926       std::thread(&EmbeddingCachePrefetchActor::LookupEmbeddingTable, this, proc_len, embedding_size, first_dim_size,
927                   hash_table_addr, indices_addr + offset, output_addr + offset * embedding_size);
928     offset += proc_len;
929     if (offset + proc_len > indices_num) {
930       proc_len = indices_num - offset;
931     }
932   }
933 
934   for (size_t j = 0; j < i; j++) {
935     threads[j].join();
936   }
937   return running_;
938 }
939 
940 bool EmbeddingCachePrefetchActor::PullEembeddingsFromRemote(int32_t param_key, const int *ids, size_t ids_num,
941                                                             std::vector<float> *outputs) {
942   MS_ERROR_IF_NULL(ids);
943   MS_ERROR_IF_NULL(outputs);
944 
945   if (ids_num == 0) {
946     MS_LOG(WARNING) << "The ids number is 0";
947     return true;
948   }
949 
950   std::vector<std::vector<int>> slice_ids_list(server_num_);
951   // 1. Partition ids by remote embedding slice bound and get unique ids.
952   RETURN_IF_FALSE_WITH_LOG(PartitionIds(ids, ids_num, &slice_ids_list), "Partition ids failed.");
953 
954   size_t embedding_dim = outputs->size() / ids_num;
955   for (size_t i = 0; i < server_num_; i++) {
956     auto &slice_ids = slice_ids_list[i];
957     if (slice_ids.empty()) {
958       continue;
959     }
960 
961     // 2. Send unique ids to remote to do embedding lookup.
962     RETURN_IF_FALSE_WITH_LOG(SendToRemote(distributed::kLookupEmbeddingCache, param_key, i, embedding_dim,
963                                           slice_ids.data(), slice_ids.size() * sizeof(int), nullptr, 0, false, false),
964                              "Send ids to server failed.");
965   }
966 
967   std::vector<std::unique_ptr<std::vector<char>>> slice_embeddings_list(server_num_);
968   for (size_t i = 0; i < server_num_; i++) {
969     if (slice_ids_list[i].empty()) {
970       continue;
971     }
972 
973     // 3. Wait embeddings result.
974     slice_embeddings_list[i] = ReceiveFromRemote(distributed::kLookupEmbeddingCache, param_key, i);
975     MS_ERROR_IF_NULL(slice_embeddings_list[i]);
976     // Received embedding integrity check.
977     size_t expected_embedding_size = SizetMulWithOverflowCheck(slice_ids_list[i].size(), embedding_dim);
978     size_t received_embedding_size = slice_embeddings_list[i]->size() / sizeof(float);
979     if (received_embedding_size != expected_embedding_size) {
980       MS_LOG(ERROR) << "Received embedding data from remote is incomplete, expected embedding size: "
981                     << expected_embedding_size << ", but received embedding size: " << received_embedding_size;
982       return false;
983     }
984   }
985 
986   // 4. Retrieve embeddings by input ids order.
987   RETURN_IF_FALSE_WITH_LOG(RetrieveEmbeddings(ids, ids_num, slice_ids_list, slice_embeddings_list, outputs),
988                            "Retrieve embeddings failed.");
989 
990   return true;
991 }
992 
993 bool EmbeddingCachePrefetchActor::DoPushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num,
994                                                            const float *embeddings, size_t embeddings_len) {
995   MS_LOG(DEBUG) << "Enter DoPushEmbeddingsToRemote - param_key : " << param_key << ", ids : " << ids
996                 << ", ids_num : " << ids_num << ", embeddings : " << embeddings
997                 << ", embeddings_len : " << embeddings_len << ".";
998   if (ids_num == 0) {
999     MS_LOG(ERROR) << "Invalid ids num: 0.";
1000     return false;
1001   }
1002   std::vector<std::vector<int>> slice_ids_list(server_num_);
1003   std::vector<std::vector<float>> slice_embeddings_list(server_num_);
1004   // 1. Partition ids and embeddings by remote embedding slice bound.
1005   RETURN_IF_FALSE_WITH_LOG(
1006     PartitionIdsAndEmbeddings(ids, ids_num, embeddings, embeddings_len, &slice_ids_list, &slice_embeddings_list),
1007     "Partition ids and embeddings failed.");
1008 
1009   size_t embedding_dim = (embeddings_len / ids_num) / sizeof(float);
1010   for (size_t i = 0; i < server_num_; i++) {
1011     auto &slice_ids = slice_ids_list[i];
1012     if (slice_ids.empty()) {
1013       continue;
1014     }
1015 
1016     // 2. Send embeddings to remote.
1017     auto &slice_embeddings = slice_embeddings_list[i];
1018     RETURN_IF_FALSE_WITH_LOG(
1019       SendToRemote(distributed::kUpdateEmbeddingCache, param_key, i, embedding_dim, slice_ids.data(),
1020                    slice_ids.size() * sizeof(int), slice_embeddings.data(), slice_embeddings.size() * sizeof(float)),
1021       "Send ids and embeddings to server failed.");
1022   }
1023   MS_LOG(DEBUG) << "Exit DoPushEmbeddingsToRemote.";
1024   return true;
1025 }
1026 
1027 bool EmbeddingCachePrefetchActor::PushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num,
1028                                                          const float *embeddings, size_t embeddings_len) {
1029   MS_EXCEPTION_IF_NULL(ids);
1030   MS_EXCEPTION_IF_NULL(embeddings);
1031   MS_EXCEPTION_IF_CHECK_FAIL(ids_num != 0, "The ids_num is 0.");
1032 
1033   const auto ids_boundary = ids + ids_num;
1034   const auto embeddings_boundary = embeddings + embeddings_len / sizeof(float);
1035   MS_LOG(DEBUG) << "Enter PushEmbeddingsToRemote - param_key : " << param_key << ", ids : " << ids
1036                 << ", ids_num : " << ids_num << ", embeddings : " << embeddings
1037                 << ", embeddings_len : " << embeddings_len << ", ids_boundary : " << ids_boundary
1038                 << ", embeddings_boundary : " << embeddings_boundary << ".";
1039   const size_t embeddings_num = embeddings_len / sizeof(float);
1040   const size_t embeddings_dim = embeddings_num / ids_num;
1041   MS_EXCEPTION_IF_CHECK_FAIL(embeddings_dim != 0, "The embeddings_dim is 0.");
1042   // Max batch size: 1 << 27 embedding values per request.
1043   const size_t max_batch_size = 1 << 27;
1044   const size_t batch_num = max_batch_size / embeddings_dim;
1045   MS_EXCEPTION_IF_CHECK_FAIL(batch_num != 0, "The batch_num is 0.");
1046   size_t batch_size = ids_num / batch_num;
1047   size_t batch_remainder = ids_num % batch_num;
1048   if (batch_remainder != 0) {
1049     batch_size++;
1050   }
1051   MS_LOG(DEBUG) << "batch_size : " << batch_size << ", batch_num : " << batch_num
1052                 << ", batch_remainder : " << batch_remainder << ".";
1053   for (size_t count = 0; count != batch_size; count++) {
1054     auto batch_ids = ids + batch_num * count;
1055     // The last batch carries the remainder; when the ids divide evenly into full batches, it is a full batch.
1056     size_t batch_ids_num = (count != batch_size - 1 || batch_remainder == 0) ? batch_num : batch_remainder;
1056     auto batch_embeddings = embeddings + batch_num * count * embeddings_dim;
1057     size_t batch_embeddings_len = batch_ids_num * embeddings_dim * sizeof(float);
1058     (void)DoPushEmbeddingsToRemote(param_key, batch_ids, batch_ids_num, batch_embeddings, batch_embeddings_len);
1059   }
1060   MS_LOG(DEBUG) << "Exit PushEmbeddingsToRemote.";
1061   return true;
1062 }
1063 
1064 bool EmbeddingCachePrefetchActor::PartitionIds(const int *ids, size_t ids_num,
1065                                                std::vector<std::vector<int>> *slice_ids_list) {
1066   MS_ERROR_IF_NULL(ids);
1067   MS_ERROR_IF_NULL(slice_ids_list);
1068 
1069   size_t partition_num = slice_ids_list->size();
1070   // There is no need to partition ids for one server case.
1071   if (partition_num == 1) {
1072     std::vector<int> &slice_ids = slice_ids_list->front();
1073     slice_ids.resize(ids_num);
1074     auto ret = memcpy_s(slice_ids.data(), slice_ids.size() * sizeof(int), ids, ids_num * sizeof(int));
1075     if (ret != EOK) {
1076       MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
1077       return false;
1078     }
1079     return true;
1080   }
1081 
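  // Each server owns an inclusive id range [begin, end]; an id is routed to the server whose slice contains it.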
1082   for (size_t i = 0; i < partition_num; i++) {
1083     int begin = SizeToInt(remote_embedding_slice_bounds_[i].first);
1084     int end = SizeToInt(remote_embedding_slice_bounds_[i].second);
1085 
1086     std::vector<int> &slice_ids = slice_ids_list->at(i);
1087     (void)std::for_each(ids, ids + ids_num, [&](int id) {
1088       if (id >= begin && id <= end) {
1089         slice_ids.push_back(id);
1090       }
1091     });
1092   }
1093 
1094   return true;
1095 }
1096 
1097 bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size_t ids_num, const float *embeddings,
1098                                                             size_t embeddings_len,
1099                                                             std::vector<std::vector<int>> *slice_ids_list,
1100                                                             std::vector<std::vector<float>> *slice_embeddings_list) {
1101   MS_ERROR_IF_NULL(ids);
1102   MS_ERROR_IF_NULL(embeddings);
1103   MS_ERROR_IF_NULL(slice_ids_list);
1104   MS_ERROR_IF_NULL(slice_embeddings_list);
1105 
1106   if (ids_num == 0) {
1107     MS_LOG(WARNING) << "The ids number is 0";
1108     return true;
1109   }
1110 
1111   size_t partition_num = slice_ids_list->size();
1112   // There is no need to partition ids and embeddings for one server case.
1113   if (partition_num == 1) {
1114     std::vector<int> &slice_ids = slice_ids_list->front();
1115     std::vector<float> &slice_embeddings = slice_embeddings_list->front();
1116     slice_ids.resize(ids_num);
1117     slice_embeddings.resize(embeddings_len / sizeof(float));
1118     auto ret = memcpy_s(slice_ids.data(), slice_ids.size() * sizeof(int), ids, ids_num * sizeof(int));
1119     if (ret != EOK) {
1120       MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
1121       return false;
1122     }
1123     ret = memcpy_s(slice_embeddings.data(), slice_embeddings.size() * sizeof(float), embeddings, embeddings_len);
1124     if (ret != EOK) {
1125       MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
1126       return false;
1127     }
1128     return true;
1129   }
1130 
1131   size_t embedding_dim = (embeddings_len / ids_num) / sizeof(float);
1132   for (size_t i = 0; i < partition_num; i++) {
1133     int begin = SizeToInt(remote_embedding_slice_bounds_[i].first);
1134     int end = SizeToInt(remote_embedding_slice_bounds_[i].second);
1135 
1136     std::vector<int> &slice_ids = slice_ids_list->at(i);
1137     std::vector<float> &slice_embeddings = slice_embeddings_list->at(i);
1138     // Ids range offset for multi server.
1139     int offset = SizeToInt(remote_embedding_slice_bounds_.at(i).first);
1140     for (size_t j = 0; j < ids_num; j++) {
1141       if (ids[j] >= begin && ids[j] <= end) {
1142         slice_ids.push_back(ids[j] - offset);
1143         (void)slice_embeddings.insert(slice_embeddings.end(), embeddings + (j * embedding_dim),
1144                                       embeddings + (j * embedding_dim) + embedding_dim);
1145       }
1146     }
1147   }
1148   return true;
1149 }
1150 
1151 bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operation, int32_t param_key,
1152                                                size_t server_rank_id, size_t embedding_dim, const void *keys,
1153                                                size_t keys_len, const void *values, size_t values_len,
1154                                                bool finalize_remote, bool sync) {
1155   MS_ERROR_IF_NULL(keys);
1156   // Find sender corresponding to cache operation and parameter key.
1157   auto iter = rpc_operators_.find(cache_operation);
1158   if (iter == rpc_operators_.end()) {
1159     MS_LOG(ERROR) << "Cannot find rpc operator for cache operation: " << cache_operation;
1160     return false;
1161   }
1162 
1163   const std::vector<SendRecvPairList> &send_recv_pair_lists = iter->second;
1164   const SenderPtr &sender = send_recv_pair_lists[server_rank_id][param_key].first;
1165   MS_ERROR_IF_NULL(sender);
1166 
1167   int64_t ids_num = SizeToLong(keys_len / sizeof(int));
1168   ShapeVector ids_shape = {ids_num};
1169   ShapeVector values_shape;
1170   float fake_value = 0.0;
1171 
1172   if (values == nullptr && values_len == 0) {
1173     values_shape = {1, 1};
1174     values = &fake_value;
1175     values_len = sizeof(fake_value);
1176   } else {
1177     MS_EXCEPTION_IF_ZERO("embedding_dim", embedding_dim);
1178     int64_t embed_vec_num = SizeToLong(values_len / sizeof(float) / embedding_dim);
1179     if (embed_vec_num != ids_num) {
1180       MS_LOG(EXCEPTION) << "The embedding vector number[" << embed_vec_num << "] should be equal to ids number["
1181                         << ids_num << "] which will be sent to remote.";
1182     }
1183     values_shape = {embed_vec_num, SizeToLong(embedding_dim)};
1184   }
1185 
1186   std::vector<ShapeVector> shapes = {ids_shape, values_shape, {static_cast<int64_t>(1)}};
1187   std::vector<TypeId> data_types = {kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt32};
1188 
1189   int32_t service_id = GetCacheOpsServiceId(cache_operation, param_key);
1190   AddressPtrList data_list = {std::make_shared<Address>(const_cast<void *>(keys), keys_len),
1191                               std::make_shared<Address>(const_cast<void *>(values), values_len),
1192                               std::make_shared<Address>(&service_id, sizeof(int32_t))};
1193 
1194   // Send data.
1195   return sender->Send(shapes, data_types, data_list, finalize_remote, sync);
1196 }
1197 
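// Block until the receiver paired with (cache_operation, server_rank_id, param_key) returns the raw
// response buffer from the remote server; return nullptr if the rpc operator cannot be found or the
// receive times out.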
1198 std::unique_ptr<std::vector<char>> EmbeddingCachePrefetchActor::ReceiveFromRemote(const std::string &cache_operation,
1199                                                                                   int32_t param_key,
1200                                                                                   size_t server_rank_id) const {
1201   // Find receiver corresponding to cache operation and parameter key.
1202   auto iter = rpc_operators_.find(cache_operation);
1203   if (iter == rpc_operators_.end()) {
1204     MS_LOG(ERROR) << "Cannot find rpc operator for cache operation: " << cache_operation;
1205     return nullptr;
1206   }
1207 
1208   const std::vector<SendRecvPairList> &send_recv_pair_lists = iter->second;
1209   const ReceiverPtr &receiver = send_recv_pair_lists[server_rank_id][param_key].second;
1210   MS_EXCEPTION_IF_NULL(receiver);
1211   // Receive data.
1212   return receiver->Receive();
1213 }
1214 
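// Restore the original ordering of the lookup results: build an id -> embedding address map from the
// per-server slices, then copy the embedding of every input id into outputs following the input order.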
1215 bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(
1216   const int *ids, size_t ids_num, const std::vector<std::vector<int>> &slice_ids_list,
1217   const std::vector<std::unique_ptr<std::vector<char>>> &slice_embeddings_list, std::vector<float> *outputs) const {
1218   MS_ERROR_IF_NULL(ids);
1219   MS_ERROR_IF_NULL(outputs);
1220 
1221   if (ids_num == 0) {
1222     MS_LOG(WARNING) << "The number of ids is 0";
1223     return true;
1224   }
1225 
1226   // Merge all slice ids and embedding data addresses into the ids_to_addrs map.
1227   mindspore::HashMap<int, const float *> ids_to_addrs;
1228   size_t embedding_dim = outputs->size() / ids_num;
1229   size_t offset = 0;
1230   for (size_t i = 0; i < slice_ids_list.size(); i++) {
1231     const std::vector<int> &slice_ids = slice_ids_list[i];
1232     if (slice_ids.empty()) {
1233       continue;
1234     }
1235     const std::unique_ptr<std::vector<char>> &slice_embeddings = slice_embeddings_list[i];
1236     MS_ERROR_IF_NULL(slice_embeddings);
1237     const float *embeddings_data = reinterpret_cast<float *>(slice_embeddings->data());
1238     for (size_t j = 0; j < slice_ids.size(); j++) {
1239       (void)ids_to_addrs.emplace(slice_ids[j], embeddings_data + offset);
1240       offset += embedding_dim;
1241     }
1242     offset = 0;
1243   }
1244 
1245   float *outputs_data = outputs->data();
1246   size_t dst_size = embedding_dim * sizeof(float);
1247   size_t src_size = dst_size;
1248   offset = 0;
1249 
1250   // Retrieve embeddings in the order of the input ids.
1251   for (size_t i = 0; i < ids_num; i++) {
1252     auto id = ids[i];
1253     auto iter = ids_to_addrs.find(id);
1254     if (iter == ids_to_addrs.end()) {
1255       MS_LOG(WARNING) << "Cannot find id[" << id << "]";
1256       continue;
1257     }
1258 
1259     auto ret = memcpy_s(outputs_data + offset, dst_size, iter->second, src_size);
1260     if (ret != 0) {
1261       MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
1262       return false;
1263     }
1264     offset += embedding_dim;
1265   }
1266   return true;
1267 }
1268 
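// Flush the host side and device side embedding caches back to the remote servers. The synchronization
// is performed at most once and is skipped when the actor is not running or not initialized.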
1269 void EmbeddingCachePrefetchActor::SyncEmbeddingTable() {
1270   std::lock_guard<std::mutex> locker(sync_embedding_table_mutex_);
1271   // Do not synchronize when finalizing abnormally.
1272   if (!running_) {
1273     return;
1274   }
1275 
1276   if (finish_sync_embedding_table_) {
1277     return;
1278   }
1279   if (!initialized_) {
1280     return;
1281   }
1282   if (!SyncHostEmbeddingTable()) {
1283     MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
1284   }
1285   if (!SyncDeviceEmbeddingTable()) {
1286     MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
1287   }
1288   finish_sync_embedding_table_ = true;
1289 }
1290 
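// Export the id -> index mapping of the host cache, keep only the ids that were modified locally, look
// up their embeddings in the local host cache and push them to the remote servers table by table.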
1291 bool EmbeddingCachePrefetchActor::SyncHostEmbeddingTable() {
1292   MS_ERROR_IF_NULL(embedding_cache_table_manager.host_hash_map_);
1293   const auto &ids_indices_pairs = embedding_cache_table_manager.host_hash_map_->Export();
1294   size_t swap_indices_lens = ids_indices_pairs.size();
1295   if (swap_indices_lens == 0) {
1296     return true;
1297   }
1298 
1299   std::unique_ptr<int[]> host_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens);
1300   MS_ERROR_IF_NULL(host_to_server_ids_ptr);
1301   std::unique_ptr<int[]> host_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens);
1302   MS_ERROR_IF_NULL(host_to_server_indices_ptr);
1303   size_t idx = 0;
1304   MS_EXCEPTION_IF_NULL(emb_ops_);
1305   const auto &modified_ids = emb_ops_->modified_ids();
1306   for (const auto &item : ids_indices_pairs) {
1307     if (modified_ids.find(item.first) != modified_ids.end()) {
1308       host_to_server_ids_ptr[idx] = item.first;
1309       host_to_server_indices_ptr[idx++] = item.second;
1310     }
1311   }
1312   swap_indices_lens = idx;
1313   if (swap_indices_lens == 0) {
1314     return true;
1315   }
1316   for (const auto &item : embedding_cache_table_manager.hash_tables_) {
1317     const auto &hash_info = item.second;
1318     std::vector<float> swap_out_data;
1319     auto embedding_size = hash_info.embedding_size;
1320     swap_out_data.resize(swap_indices_lens * embedding_size);
1321     auto host_hash_table_addr = hash_info.host_address;
1322     MS_ERROR_IF_NULL(host_hash_table_addr);
1323     RETURN_IF_FALSE(LookupLocalHostCache(embedding_size, swap_indices_lens, host_hash_table_addr,
1324                                          host_to_server_indices_ptr.get(), swap_out_data.data()));
1325 
1326     RETURN_IF_FALSE_WITH_LOG(
1327       PushEmbeddingsToRemote(hash_info.param_key_, host_to_server_ids_ptr.get(), swap_indices_lens,
1328                              swap_out_data.data(), swap_out_data.size() * sizeof(float)),
1329       "Push embeddings to remote failed.");
1330   }
1331   return true;
1332 }
1333 
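// Copy the whole device hash table into a temporary host buffer, look up the embeddings of all exported
// ids from that buffer and push them to the remote servers table by table.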
1334 bool EmbeddingCachePrefetchActor::SyncDeviceEmbeddingTable() {
1335   const auto &device_hash_map = embedding_cache_table_manager.device_hash_map_;
1336   MS_ERROR_IF_NULL(device_hash_map);
1337   const auto &ids_indices_pairs = device_hash_map->Export();
1338   size_t swap_indices_lens = ids_indices_pairs.size();
1339   if (swap_indices_lens == 0) {
1340     return true;
1341   }
1342   MS_ERROR_IF_NULL(device_context_);
1343   MS_ERROR_IF_NULL(device_context_->device_res_manager_);
1344   std::unique_ptr<int[]> device_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens);
1345   MS_ERROR_IF_NULL(device_to_server_ids_ptr);
1346   std::unique_ptr<int[]> device_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens);
1347   MS_ERROR_IF_NULL(device_to_server_indices_ptr);
1348   size_t idx = 0;
1349   for (const auto &item : ids_indices_pairs) {
1350     device_to_server_ids_ptr[idx] = item.first;
1351     device_to_server_indices_ptr[idx++] = item.second;
1352   }
1353   for (const auto &item : embedding_cache_table_manager.hash_tables_) {
1354     const auto &hash_info = item.second;
1355     std::vector<float> swap_out_data;
1356     auto embedding_size = hash_info.embedding_size;
1357     swap_out_data.resize(swap_indices_lens * embedding_size);
1358     std::unique_ptr<float[]> device_hash_table_addr_tmp =
1359       std::make_unique<float[]>(device_hash_map->hash_capacity() * embedding_size);
1360     MS_ERROR_IF_NULL(device_hash_table_addr_tmp);
1361 
1362     auto hash_table_addr = reinterpret_cast<float *>(hash_info.address.addr);
1363     MS_ERROR_IF_NULL(hash_table_addr);
1364     auto hash_table_size = hash_info.address.size;
1365     RETURN_IF_FALSE_WITH_LOG(
1366       DeviceEmbeddingOperation::MemcpyDeviceToHostAsync(device_hash_table_addr_tmp.get(), hash_table_addr,
1367                                                         hash_table_size, device_context_, stream_id_),
1368       "Memcpy device to host asynchronously failed.");
1369     RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_),
1370                              "Synchronize stream failed.");
1371     RETURN_IF_FALSE(LookupLocalHostCache(embedding_size, swap_indices_lens, device_hash_table_addr_tmp.get(),
1372                                          device_to_server_indices_ptr.get(), swap_out_data.data()));
1373 
1374     RETURN_IF_FALSE_WITH_LOG(
1375       PushEmbeddingsToRemote(hash_info.param_key_, device_to_server_ids_ptr.get(), swap_indices_lens,
1376                              swap_out_data.data(), swap_out_data.size() * sizeof(float)),
1377       "Push embeddings to remote failed.");
1378   }
1379   return true;
1380 }
1381 
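// Send a minimal lookup message with the finalize flag set to every server so that the remote recv
// actors can be finalized.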
1382 bool EmbeddingCachePrefetchActor::FinalizeRemote() {
1383   for (size_t i = 0; i < server_num_; i++) {
1384     size_t embedding_dim = 1;
1385     int id = 0;
1386     float value = 0.0;
1387     RETURN_IF_FALSE_WITH_LOG(SendToRemote(distributed::kLookupEmbeddingCache, 0, i, embedding_dim, &id, sizeof(int),
1388                                           &value, sizeof(float), true),
1389                              "Send finalize request to remote failed.");
1390   }
1391 
1392   return true;
1393 }
1394 
1395 const std::string &EmbeddingCachePrefetchActor::channel_name() {
1396   std::lock_guard<std::mutex> locker(channel_mutex_);
1397   return channel_name_;
1398 }
1399 
1400 void EmbeddingCachePrefetchActor::set_channel_name(const std::string &channel_name) {
1401   if (channel_name_ == channel_name) {
1402     return;
1403   }
1404   std::lock_guard<std::mutex> locker(channel_mutex_);
1405   channel_name_ = channel_name;
1406 }
1407 
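// Block until set_channel_name() assigns a non-empty channel name or the actor stops running.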
1408 void EmbeddingCachePrefetchActor::WaitDataChannelInit() {
1409   MS_LOG(INFO) << "Begin wait embedding cache data channel init.";
1410   auto channel = channel_name();
1411   if (channel.empty()) {
1412     std::unique_lock<std::mutex> locker(data_mutex_);
1413     data_parser_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
1414     if (!running_) {
1415       return;
1416     }
1417   }
1418   MS_LOG(INFO) << "End wait embedding cache data channel init.";
1419 }
1420 
1421 void EmbeddingCachePrefetchActor::WaitInitParametersOnRemote() {
1422   std::unique_lock<std::mutex> locker(data_mutex_);
1423   // Note: wait to finish embedding lookup from remote.
1424   finish_init_parameters_on_remote_ = true;
1425   data_parser_.notify_one();
1426 }
1427 
1428 void EmbeddingCachePrefetchActor::SetErrorInfo(const std::string &error_info) {
1429   static std::mutex mtx;
1430   std::lock_guard<std::mutex> lock(mtx);
1431   error_info_ = error_info;
1432 }
1433 
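// Create one sender/receiver pair for every combination of cache operation, server rank and parameter
// key, and index the pairs by parameter key inside each server's pair list.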
1434 void EmbeddingCachePrefetchActor::BuildRpcOperators() {
1435   // The cache operations currently supported are LookupEmbeddingCache and UpdateEmbeddingCache.
1436   for (const auto &cache_op : distributed::kEmbeddingCacheOps) {
1437     rpc_operators_[cache_op] = std::vector<SendRecvPairList>();
1438     rpc_operators_[cache_op].resize(server_num_);
1439   }
1440 
1441   auto node = distributed::cluster::ClusterContext::instance()->node();
1442   MS_EXCEPTION_IF_NULL(node);
1443   uint32_t worker_rank_id = node->rank_id();
1444 
1445   // Create sender and receiver pairs for each combination of cache operation, server and parameter key.
1446   for (auto &item : rpc_operators_) {
1447     const std::string &cache_op = item.first;
1448     std::vector<SendRecvPairList> &send_recv_pair_lists = item.second;
1449     for (uint32_t i = 0; i < server_num_; i++) {
1450       SendRecvPairList &send_recv_pair_list = send_recv_pair_lists[i];
1451       send_recv_pair_list.resize(embedding_cache_table_manager.hash_tables_.size());
1452 
1453       for (const auto &table : embedding_cache_table_manager.hash_tables_) {
1454         int32_t key = table.second.param_key_;
1455         if (key >= SizeToInt(embedding_cache_table_manager.hash_tables_.size()) || key < 0) {
1456           MS_LOG(EXCEPTION) << "Invalid parameter key: " << key;
1457         }
1458 
1459         send_recv_pair_list[key] = CreateSenderReceiverPair(worker_rank_id, i, cache_op, key, cpu_device_context_);
1460       }
1461     }
1462   }
1463 }
1464 
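// Start the TCP server of every receiver (which registers its address to the route table) and then
// connect every sender to its peer receiver.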
1465 void EmbeddingCachePrefetchActor::LinkRpcOperators() {
1466   std::vector<SenderPtr> senders;
1467   std::vector<ReceiverPtr> receivers;
1468   for (const auto &item : rpc_operators_) {
1469     const std::vector<SendRecvPairList> &send_recv_pair_lists = item.second;
1470     for (const SendRecvPairList &send_recv_pair_list : send_recv_pair_lists) {
1471       for (const SendRecvPair &pair : send_recv_pair_list) {
1472         senders.push_back(pair.first);
1473         receivers.push_back(pair.second);
1474       }
1475     }
1476   }
1477 
1478   // The servers must be started and registered to the route table before looking up routes and connecting.
1479   // Start the servers of the receivers and register them to the route table.
1480   for (auto &receiver : receivers) {
1481     MS_EXCEPTION_IF_NULL(receiver);
1482     if (!receiver->StartServer()) {
1483       MS_LOG(EXCEPTION) << "Failed to start server for the receiver.";
1484     }
1485   }
1486 
1487   // Look up the routes and connect the senders to the servers.
1488   for (auto &sender : senders) {
1489     MS_EXCEPTION_IF_NULL(sender);
1490     if (!sender->ConnectServer()) {
1491       MS_LOG(EXCEPTION) << "Failed to connect servers for the sender.";
1492     }
1493   }
1494 }
1495 
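// Serialize the shapes, data types and data buffers into one RPC message and send it over the TCP
// client, either synchronously or asynchronously according to the sync flag.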
1496 bool Sender::Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
1497                   const AddressPtrList &data_list, bool finalize_remote, bool sync) const {
1498   MS_ERROR_IF_NULL(receiver_);
1499   auto message = BuildRpcMessage(shapes, data_types, data_list, receiver_->get_url(), server_url_, finalize_remote);
1500   MS_ERROR_IF_NULL(message);
1501   MS_ERROR_IF_NULL(client_);
1502   if (sync) {
1503     return client_->SendSync(std::move(message));
1504   }
1505 
1506   client_->SendAsync(std::move(message));
1507   return true;
1508 }
1509 
1510 Sender::~Sender() {
1511   if (client_) {
1512     try {
1513       if (!client_->Disconnect(server_url_)) {
1514         MS_LOG(ERROR) << "Failed to disconnect tcp client.";
1515       }
1516       client_->Finalize();
1517       client_ = nullptr;
1518     } catch (const std::exception &e) {
1519       MS_LOG(ERROR) << "Failed to disconnect and finalize tcp client, error message: " << e.what();
1520     }
1521   }
1522   receiver_ = nullptr;
1523 }
1524 
1525 bool Sender::ConnectServer() {
1526   client_ = std::make_unique<TCPClient>();
1527   MS_ERROR_IF_NULL(client_);
1528   if (!client_->Initialize()) {
1529     MS_LOG(ERROR) << "Failed to initialize tcp client for send actor.";
1530     return false;
1531   }
1532 
1533   // Look up the peer receiver's address.
1534   MS_ERROR_IF_NULL(route_table_proxy_);
1535   auto peer_actor_address = route_table_proxy_->LookupRoute(inter_process_edge_);
1536   server_url_ = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
1537 
1538   auto free_callback = std::bind(&Sender::FreeMessage, this, std::placeholders::_1);
1539   size_t retry_count = 60;
1540 
1541   bool ret = client_->Connect(server_url_, retry_count, free_callback);
1542   if (!ret) {
1543     MS_LOG(ERROR) << "Failed to connect to server of edge: " << inter_process_edge_ << ", server_url: " << server_url_;
1544     return false;
1545   }
1546 
1547   MS_LOG(INFO) << "Successfully connect to server " << server_url_
1548                << ", inter process edge name: " << inter_process_edge_;
1549   return true;
1550 }
1551 
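// Build the flattened message buffer. For each tensor the buffer contains the kRpcDynamicShapeData
// header, the size of the serialized DynamicShapeMessage, the serialized message itself and the raw
// data; when finalize_remote is true the kFinalizeMuxRecvActor command is appended at the end.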
1552 std::unique_ptr<MessageBase> Sender::BuildRpcMessage(const std::vector<ShapeVector> &shapes,
1553                                                      const std::vector<TypeId> data_types,
1554                                                      const AddressPtrList &data_list, const std::string &from_url,
1555                                                      const std::string &to_url, bool finalize_remote) const {
1556   std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
1557   MS_ERROR_IF_NULL_W_RET_VAL(message, nullptr);
1558   message->from = AID("", from_url);
1559   message->to = AID("", to_url);
1560 
1561   if (shapes.size() != data_list.size()) {
1562     MS_LOG(ERROR) << "The shape list size[" << shapes.size() << "] should be equal to data list size["
1563                   << data_list.size() << "]";
1564   }
1565 
1566   if (data_types.size() != data_list.size()) {
1567     MS_LOG(ERROR) << "The data type list size[" << data_types.size() << "] should be equal to data list size["
1568                   << data_list.size() << "]";
1569   }
1570 
1571   RpcDataPtr rpc_data = nullptr;
1572   size_t data_size = CalDataSize(shapes, data_types, data_list, finalize_remote);
1573   MS_EXCEPTION_IF_NULL(cpu_device_context_);
1574   MS_EXCEPTION_IF_NULL(cpu_device_context_->device_res_manager_);
1575   rpc_data = static_cast<RpcDataPtr>(cpu_device_context_->device_res_manager_->AllocateMemory(data_size));
1576   MS_EXCEPTION_IF_NULL(rpc_data);
1577   message->data = rpc_data;
1578   message->size = data_size;
1579 
1580   errno_t ret;
1581   size_t offset = 0;
1582   for (size_t i = 0; i < data_list.size(); i++) {
1583     const ShapeVector &shape = shapes[i];
1584     const AddressPtr &data = data_list[i];
1585     const TypeId &type_id = data_types[i];
1586 
1587     rpc::DynamicShapeMessage ds_pb_msg;
1588     ds_pb_msg.set_type_id(type_id);
1589     *ds_pb_msg.mutable_shape_vector() = {shape.begin(), shape.end()};
1590     std::string ds_pb_msg_str = ds_pb_msg.SerializeAsString();
1591 
1592     // Message format:
1593     // |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|---real data----|
1594     // 1. The dynamic shape header.
1595     if ((ret = memcpy_s(rpc_data + offset, strlen(kRpcDynamicShapeData), kRpcDynamicShapeData,
1596                         strlen(kRpcDynamicShapeData))) != EOK) {
1597       MS_LOG(EXCEPTION) << "Failed to memcpy_s for kRpcDynamicShapeData, errno[" << ret << "].";
1598     }
1599     offset += strlen(kRpcDynamicShapeData);
1600 
1601     // 2. The size of the protobuf DynamicShapeMessage.
1602     size_t ds_pb_msg_size = ds_pb_msg_str.size();
1603     if ((ret = memcpy_s(rpc_data + offset, sizeof(ds_pb_msg_size), &ds_pb_msg_size, sizeof(ds_pb_msg_size))) != EOK) {
1604       MS_LOG(EXCEPTION) << "Failed to memcpy_s for pb message size, errno[" << ret << "].";
1605     }
1606     offset += sizeof(ds_pb_msg_size);
1607 
1608     // 3. Protobuf DynamicShapeMessage.
1609     if ((ret = memcpy_s(rpc_data + offset, ds_pb_msg_str.size(), ds_pb_msg_str.c_str(), ds_pb_msg_str.size())) != EOK) {
1610       MS_LOG(EXCEPTION) << "Failed to memcpy_s for pb message, errno[" << ret << "].";
1611     }
1612     offset += ds_pb_msg_str.size();
1613 
1614     // 4. The real data buffer that needs to be sent.
1615     MS_EXCEPTION_IF_NULL(data);
1616     if ((ret = memcpy_s(rpc_data + offset, data->size, data->addr, data->size)) != EOK) {
1617       MS_LOG(EXCEPTION) << "Failed to memcpy_s for real data, errno[" << ret << "].";
1618     }
1619     offset += data->size;
1620   }
1621 
1622   // 5. Finalize remote command.
1623   if (finalize_remote) {
1624     size_t header_len = strlen(distributed::kFinalizeMuxRecvActor);
1625     if ((ret = memcpy_s(rpc_data + offset, header_len, distributed::kFinalizeMuxRecvActor, header_len)) != EOK) {
1626       MS_LOG(EXCEPTION) << "Failed to memcpy_s for kFinalizeMuxRecvActor, errno[" << ret << "].";
1627     }
1628     offset += header_len;
1629 
1630     if ((ret = memcpy_s(rpc_data + offset, sizeof(finalize_remote), &finalize_remote, sizeof(finalize_remote))) !=
1631         EOK) {
1632       MS_LOG(EXCEPTION) << "Failed to memcpy_s for finalize_remote, errno[" << ret << "].";
1633     }
1634   }
1635 
1636   return message;
1637 }
1638 
1639 bool Sender::FreeMessage(void *data) {
1640   MS_EXCEPTION_IF_NULL(cpu_device_context_);
1641   MS_EXCEPTION_IF_NULL(cpu_device_context_->device_res_manager_);
1642   MS_ERROR_IF_NULL_W_RET_VAL(data, false);
1643   cpu_device_context_->device_res_manager_->FreeMemory(data);
1644   return true;
1645 }
1646 
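// Compute the total buffer size required by BuildRpcMessage for the same message layout.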
1647 size_t Sender::CalDataSize(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
1648                            const AddressPtrList &data_list, bool finalize_remote) const {
1649   size_t data_size = 0;
1650   for (size_t i = 0; i < data_list.size(); i++) {
1651     const ShapeVector &shape = shapes[i];
1652     const AddressPtr &data = data_list[i];
1653     const TypeId &type_id = data_types[i];
1654 
1655     rpc::DynamicShapeMessage ds_pb_msg;
1656     ds_pb_msg.set_type_id(type_id);
1657     *ds_pb_msg.mutable_shape_vector() = {shape.begin(), shape.end()};
1658     std::string ds_pb_msg_str = ds_pb_msg.SerializeAsString();
1659     data_size += strlen(kRpcDynamicShapeData);
1660     data_size += sizeof(size_t);
1661     data_size += ds_pb_msg_str.size();
1662     MS_EXCEPTION_IF_NULL(data);
1663     data_size += data->size;
1664   }
1665   if (finalize_remote) {
1666     data_size += strlen(distributed::kFinalizeMuxRecvActor);
1667     data_size += sizeof(finalize_remote);
1668   }
1669   return data_size;
1670 }
1671 
1672 Receiver::~Receiver() {
1673   if (server_) {
1674     try {
1675       server_->Finalize();
1676       server_ = nullptr;
1677     } catch (const std::exception &e) {
1678       MS_LOG(ERROR) << "Failed to finalize tcp server, error message: " << e.what();
1679     }
1680   }
1681   received_buffer_ = nullptr;
1682 }
1683 
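// Wait (up to 300 seconds) for HandleMessage to fill received_buffer_ and hand the buffer over to the
// caller; return nullptr on timeout.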
1684 std::unique_ptr<std::vector<char>> Receiver::Receive() {
1685   std::unique_lock<std::mutex> locker(received_msg_mtx_);
1686   // The maximum time (300 seconds) to wait for receiving a message.
1687   const int64_t longest_time_to_wait = 300;
1688   auto ret = received_msg_cv_.wait_for(locker, std::chrono::seconds(longest_time_to_wait),
1689                                        [this] { return received_msg_.load(); });
1690   if (!ret) {
1691     MS_LOG(ERROR) << "Receive message timeout";
1692     return nullptr;
1693   }
1694 
1695   std::unique_ptr<std::vector<char>> output = std::move(received_buffer_);
1696   MS_EXCEPTION_IF_NULL(output);
1697   received_msg_ = false;
1698   return output;
1699 }
1700 
1701 bool Receiver::StartServer() {
1702   // 1. Create a tcp server and start listening.
1703   server_ = std::make_unique<TCPServer>();
1704   MS_EXCEPTION_IF_NULL(server_);
1705 
1706   std::function<void *(size_t size)> allocate_callback =
1707     std::bind(&Receiver::AllocateMessage, this, std::placeholders::_1);
1708   if (!server_->Initialize(allocate_callback)) {
1709     MS_LOG(EXCEPTION) << "Failed to initialize tcp server for recv actor";
1710   }
1711   ip_ = server_->GetIP();
1712   port_ = server_->GetPort();
1713   std::string server_url = ip_ + ":" + std::to_string(port_);
1714 
1715   // 2. Set the message handler of the server.
1716   server_->SetMessageHandler(std::bind(&Receiver::HandleMessage, this, std::placeholders::_1));
1717 
1718   // 3. Register the server address to the route table. The server should not be connected to before this step is done.
1719   MS_LOG(INFO) << "Start server for receiver. Server address: " << server_url
1720                << ", inter process edge name: " << inter_process_edge_;
1721   distributed::cluster::topology::ActorAddress recv_actor_address;
1722   recv_actor_address.set_actor_id(inter_process_edge_);
1723   recv_actor_address.set_ip(ip_);
1724   recv_actor_address.set_port(port_);
1725   MS_EXCEPTION_IF_NULL(route_table_proxy_);
1726   if (!route_table_proxy_->RegisterRoute(inter_process_edge_, recv_actor_address)) {
1727     MS_LOG(EXCEPTION) << "Failed to register route for " << inter_process_edge_ << " " << server_url
1728                       << " when starting server.";
1729   }
1730   return true;
1731 }
1732 
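// Validate the dynamic shape header, read the serialized DynamicShapeMessage and check that the size of
// the remaining payload matches the shape and type it describes; on success *data points into msg_body.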
1733 bool Receiver::ParseDynamicShapeData(const char *msg_body, size_t msg_len,
1734                                      std::pair<const void *, size_t> *data) const {
1735   MS_ERROR_IF_NULL(msg_body);
1736   MS_ERROR_IF_NULL(data);
1737   // 1. Check whether received data is valid dynamic shape data.
1738   size_t dynamic_shape_header_size = strlen(kRpcDynamicShapeData);
1739   if (msg_len <= dynamic_shape_header_size) {
1740     MS_LOG(ERROR) << "Received data is not dynamic shape, data length: " << msg_len;
1741     return false;
1742   }
1743   std::string msg_dynamic_shape_header(msg_body, dynamic_shape_header_size);
1744   if (msg_dynamic_shape_header != kRpcDynamicShapeData) {
1745     MS_LOG(ERROR) << "Received data is not dynamic shape, cannot find dynamic shape header: " << kRpcDynamicShapeData;
1746     return false;
1747   }
1748 
1749   size_t offset = dynamic_shape_header_size;
1750   // 2. Parse the size of dynamic shape serialized protobuf message.
1751   if (offset + sizeof(size_t) >= msg_len) {
1752     MS_LOG(ERROR) << "Received data is incomplete";
1753     return false;
1754   }
1755   size_t dynamic_shape_pb_size = *(reinterpret_cast<const size_t *>(msg_body + offset));
1756   offset += sizeof(size_t);
1757   if (offset + dynamic_shape_pb_size >= msg_len) {
1758     MS_LOG(ERROR) << "The dynamic shape pb data is incomplete";
1759     return false;
1760   }
1761 
1762   // 3. Deserialize the dynamic shape serialized protobuf message.
1763   rpc::DynamicShapeMessage pb_msg;
1764   (void)pb_msg.ParseFromArray(msg_body + offset, dynamic_shape_pb_size);
1765   offset += dynamic_shape_pb_size;
1766   size_t received_data_len = msg_len - offset;
1767 
1768   // 4. The data integrity check.
1769   ShapeVector shapes(pb_msg.shape_vector().begin(), pb_msg.shape_vector().end());
1770   TypeId data_type = static_cast<TypeId>(pb_msg.type_id());
1771   int64_t expected_data_len = 1;
1772   if (!kernel::GetShapeSize(shapes, TypeIdToType(data_type), &expected_data_len)) {
1773     MS_LOG(ERROR) << "Getting shape size for shape " << shapes << " failed.";
1774     return false;
1775   }
1776   if (LongToSize(expected_data_len) != received_data_len) {
1777     MS_LOG(ERROR) << "Received data is incomplete, expected size: " << expected_data_len
1778                   << ", but received data size: " << received_data_len;
1779     return false;
1780   }
1781   // 5. Get real data addr and size.
1782   *data = std::make_pair(msg_body + offset, received_data_len);
1783   return true;
1784 }
1785 
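// Message handler running on the TCP server thread: parse the dynamic shape payload, copy the raw data
// into received_buffer_, wake up Receive() and free the RPC message buffer.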
1786 MessageBase *Receiver::HandleMessage(MessageBase *const msg) {
1787   if (msg == nullptr) {
1788     MS_LOG(WARNING) << "Received message pointer is nullptr";
1789     return distributed::rpc::NULL_MSG;
1790   }
1791 
1792   RpcDataPtr data = static_cast<RpcDataPtr>(msg->data);
1793   size_t data_size = msg->size;
1794   // The data pair: <addr of data, size of data>.
1795   std::pair<const void *, size_t> real_data;
1796   // Get real data addr and size.
1797   if (!ParseDynamicShapeData(data, data_size, &real_data)) {
1798     MS_LOG(EXCEPTION) << "Parse dynamic shape data failed.";
1799   }
1800 
1801   std::unique_lock<std::mutex> locker(received_msg_mtx_);
1802   received_buffer_ = std::make_unique<std::vector<char>>();
1803   received_buffer_->resize(real_data.second);
1804   MS_EXCEPTION_IF_NULL(real_data.first);
1805 
1806   int ret = memcpy_s(received_buffer_->data(), received_buffer_->size(), real_data.first, real_data.second);
1807   if (ret != 0) {
1808     MS_LOG(EXCEPTION) << "Memcpy for received data failed, errno[" << ret << "]";
1809   }
1810 
1811   received_msg_ = true;
1812   received_msg_cv_.notify_one();
1813 
1814   MS_EXCEPTION_IF_NULL(cpu_device_context_);
1815   MS_EXCEPTION_IF_NULL(cpu_device_context_->device_res_manager_);
1816   cpu_device_context_->device_res_manager_->FreeMemory(data);
1817 
1818   delete msg;
1819   return distributed::rpc::NULL_MSG;
1820 }
1821 
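// Allocation callback passed to the TCP server; the returned buffer is freed in HandleMessage after the
// payload has been copied out.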
1822 void *Receiver::AllocateMessage(size_t size) {
1823   MS_EXCEPTION_IF_NULL(cpu_device_context_);
1824   MS_EXCEPTION_IF_NULL(cpu_device_context_->device_res_manager_);
1825   void *data = cpu_device_context_->device_res_manager_->AllocateMemory(size);
1826   MS_EXCEPTION_IF_NULL(data);
1827   return data;
1828 }
1829 }  // namespace runtime
1830 }  // namespace mindspore
1831