1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
18 #include <limits>
19 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
20 #include "kernel/common_utils.h"
21 #include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"
22 #include "proto/topology.pb.h"
23 #include "include/backend/distributed/constants.h"
24 #include "include/backend/distributed/rpc/tcp/constants.h"
25 #include "runtime/graph_scheduler/actor/embedding_cache/device_dense_embedding_operation.h"
26 #include "runtime/graph_scheduler/actor/embedding_cache/device_sparse_embedding_operation.h"
27 #include "include/backend/distributed/embedding_cache/data_queue_manager.h"
28
29 namespace mindspore {
30 namespace runtime {
31 using distributed::IdDataInfo;
32 using distributed::IndexDataInfo;
33
34 using distributed::DataQueueManager;
35 using distributed::cluster::ClusterContext;
36 using mindspore::session::KernelGraph;
37 constexpr size_t kDefaultQueueCapacity = 128;
38
39 namespace {
40 // Generate unique inter process edge name, format:
41 // src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
GenerateInterProcessEdge(const std::string & src_role,uint32_t src_rank,const std::string & dst_role,uint32_t dst_rank,const std::string & cache_operation,int32_t param_key)42 std::string GenerateInterProcessEdge(const std::string &src_role, uint32_t src_rank, const std::string &dst_role,
43 uint32_t dst_rank, const std::string &cache_operation, int32_t param_key) {
44 std::string edge = src_role + std::to_string(src_rank) + "->" + dst_role + std::to_string(dst_rank) + "_" +
45 cache_operation + "_" + distributed::kParameterKey + std::to_string(param_key);
46 return edge;
47 }
48
CreateRouteTableProxy()49 ActorRouteTableProxyPtr CreateRouteTableProxy() {
50 auto cgn = std::dynamic_pointer_cast<distributed::cluster::topology::ComputeGraphNode>(
51 ClusterContext::instance()->node_base());
52 ActorRouteTableProxyPtr actor_route_table_proxy = std::make_shared<ActorRouteTableProxy>(cgn);
53 MS_EXCEPTION_IF_NULL(actor_route_table_proxy);
54 return actor_route_table_proxy;
55 }
56
57 // Create a sender and receiver pair,The sender and receiver are paired.
58 // When creating a sender, need to create and specify the receiver paired with it in advance.
CreateSenderReceiverPair(uint32_t worker_rank,uint32_t server_rank,const std::string & cache_operation,int32_t param_key,device::DeviceContext * cpu_device_context)59 SendRecvPair CreateSenderReceiverPair(uint32_t worker_rank, uint32_t server_rank, const std::string &cache_operation,
60 int32_t param_key, device::DeviceContext *cpu_device_context) {
61 // Create sender and receiver pair.
62 ReceiverPtr receiver = std::make_shared<Receiver>(cpu_device_context);
63 MS_EXCEPTION_IF_NULL(receiver);
64 SenderPtr sender = std::make_shared<Sender>(cpu_device_context);
65 MS_EXCEPTION_IF_NULL(sender);
66 sender->set_receiver(receiver);
67
68 // Set inter process edge
69 receiver->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfPServer, server_rank,
70 distributed::kEnvRoleOfWorker, worker_rank,
71 cache_operation, param_key));
72 sender->set_inter_process_edge_name(GenerateInterProcessEdge(distributed::kEnvRoleOfWorker, worker_rank,
73 distributed::kEnvRoleOfPServer, server_rank,
74 cache_operation, param_key));
75
76 // Set route table proxy.
77 receiver->set_actor_route_table_proxy(CreateRouteTableProxy());
78 sender->set_actor_route_table_proxy(CreateRouteTableProxy());
79
80 return std::make_pair(sender, receiver);
81 }
82
83 // Get cache operation service id which is used to decide which set of cache services to request.
84 // The server side executes the corresponding service according to this id.
GetCacheOpsServiceId(const std::string & cache_operation,int32_t param_key)85 int32_t GetCacheOpsServiceId(const std::string &cache_operation, int32_t param_key) {
86 static mindspore::HashMap<std::string, int32_t> cache_ops_to_index;
87 if (cache_ops_to_index.empty()) {
88 int32_t cnt = 0;
89 for (const auto &cache_op : distributed::kEmbeddingCacheOps) {
90 cache_ops_to_index[cache_op] = cnt++;
91 }
92 }
93
94 auto iter = cache_ops_to_index.find(cache_operation);
95 if (iter == cache_ops_to_index.end()) {
96 MS_LOG(EXCEPTION) << "Can not find index for cache operation: " << cache_operation;
97 }
98
99 int32_t id = SizeToInt(distributed::kEmbeddingCacheOps.size()) * param_key + iter->second;
100 return id;
101 }
102
103 // Parallelly generate fixed or random numbers continuously using specified algorithm.
104 template <typename T, typename Generator, typename Distribution, typename... Args>
GenerateDistributionParallel(size_t size,T * output,Args...args)105 void GenerateDistributionParallel(size_t size, T *output, Args... args) {
106 std::thread threads[kMaxThreadNum];
107 std::random_device rd;
108 const std::uint64_t seed = rd();
109
110 // 1. Compute total thread number need to parallel generate distribution and the size of new numbers that each
111 // thread need to generate.
112 // Once calculation of the normal distribution may produce two random values, so each thread should be responsible for
113 // producing an even number of random numbers, except for the last thread.
114 auto [thread_num, size_per_thread] = random::ComputeTaskNumSize(size, kMaxThreadNum);
115
116 // For performance, multi-thread concurrency is not required when the total size is small.
117 if (thread_num == 1) {
118 random::GenerateRandoms<T, Generator, Distribution, Args...>(seed, 0, output, size, args...);
119 return;
120 }
121
122 // 2. Parallelly generate specified distribution using specified algorithm.
123 // Note that the offset need to be set to 'Generator' to prevent generating same sequence of each thread.
124 size_t offset = 0;
125 for (size_t i = 0; i < thread_num; ++i) {
126 size_t task_len = ((i < (thread_num - 1)) ? size_per_thread : (size - ((thread_num - 1) * size_per_thread)));
127 threads[i] = std::thread(&random::GenerateRandoms<T, Generator, Distribution, Args...>, seed, offset,
128 output + offset, task_len, args...);
129 offset += task_len;
130 }
131
132 for (size_t j = 0; j < thread_num; j++) {
133 threads[j].join();
134 }
135 }
136
DeduplicateId(UniqueIds * unique_ids)137 void DeduplicateId(UniqueIds *unique_ids) {
138 MS_EXCEPTION_IF_NULL(unique_ids);
139
140 constexpr size_t kMaxParallelNum = 32;
141 size_t parallel_num = unique_ids->multi_batch_data_.size();
142 if (parallel_num > kMaxParallelNum) {
143 MS_LOG(EXCEPTION) << "The parallel num: " << parallel_num
144 << " can not be greater than max parallel num: " << kMaxParallelNum;
145 }
146 std::thread threads[kMaxParallelNum];
147
148 std::vector<mindspore::HashSet<int>> unique_batch_ids_sets(parallel_num);
149 auto unique_task = [&](int *origin_batch_ids, size_t proc_len, mindspore::HashSet<int> *unique_set) {
150 (void)std::for_each(origin_batch_ids, origin_batch_ids + proc_len,
151 [&unique_set](int id) { (void)unique_set->insert(id); });
152 };
153
154 size_t i = 0;
155 for (; i < parallel_num; ++i) {
156 threads[i] = std::thread(unique_task, reinterpret_cast<int *>(unique_ids->multi_batch_data_.at(i)),
157 unique_ids->multi_batch_size_.at(i), &unique_batch_ids_sets[i]);
158 }
159
160 for (size_t j = 0; j < i; j++) {
161 threads[j].join();
162 }
163
164 for (size_t k = 1; k < parallel_num; ++k) {
165 auto end_iter = unique_batch_ids_sets[k].end();
166 for (auto iter = unique_batch_ids_sets[k].begin(); iter != end_iter; ++iter) {
167 unique_batch_ids_sets[0].insert(*iter);
168 }
169 }
170 const auto &unique_ids_set = unique_batch_ids_sets.front();
171 unique_ids->ids_num_ = unique_ids_set.size();
172 unique_ids->ids_ = new (std::nothrow) int[unique_ids->ids_num_];
173 MS_EXCEPTION_IF_NULL(unique_ids->ids_);
174 size_t index = 0;
175 auto unique_ids_ptr = unique_ids->ids_;
176 (void)std::for_each(unique_ids_set.begin(), unique_ids_set.end(), [&](int id) { unique_ids_ptr[index++] = id; });
177 }
178
TransformIdsToIndices(mindspore::HashMap<int,int> * unique_ids_to_indices,size_t batch_ids_num,int * batch_ids)179 void TransformIdsToIndices(mindspore::HashMap<int, int> *unique_ids_to_indices, size_t batch_ids_num, int *batch_ids) {
180 auto change_id_to_index_func = [&](int *batch_ids_ptr, size_t proc_len) {
181 for (size_t i = 0; i < proc_len; i++) {
182 batch_ids_ptr[i] = (*unique_ids_to_indices)[batch_ids_ptr[i]];
183 }
184 };
185
186 size_t thread_num = batch_ids_num / kMaxIdsPerThread + 1;
187 thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
188 std::thread threads[kMaxThreadNum];
189 size_t i = 0;
190 size_t offset = 0;
191
192 for (; i < thread_num; ++i) {
193 size_t proc_len = batch_ids_num / thread_num + (i < (batch_ids_num % thread_num) ? 1 : 0);
194 threads[i] = std::thread(change_id_to_index_func, batch_ids + offset, proc_len);
195 offset += proc_len;
196 }
197 if (offset != batch_ids_num) {
198 MS_LOG(WARNING) << "Check id in device inadequate, total:" << batch_ids_num << " checked:" << offset;
199 }
200
201 for (size_t j = 0; j < i; j++) {
202 threads[j].join();
203 }
204 }
205 } // namespace
206
Initialize()207 void EmbeddingCachePrefetchActor::Initialize() {
208 if (initialized_) {
209 return;
210 }
211 MS_EXCEPTION_IF_NULL(device_context_);
212 MS_EXCEPTION_IF_NULL(device_context_->device_res_manager_);
213 if (!device_context_->device_res_manager_->CreateStream(&stream_id_)) {
214 MS_LOG(EXCEPTION) << "Create stream failed.";
215 }
216
217 // Get embedding cache table info.
218 local_host_cache_size_ = embedding_cache_table_manager.host_cache_size_;
219 vocab_size_ = embedding_cache_table_manager.vocab_size_;
220 local_embedding_slice_bounds_ = embedding_cache_table_manager.local_embedding_slice_bounds_;
221 local_device_cache_bounds_ = embedding_cache_table_manager.local_device_cache_bounds_;
222
223 // Initialize CPU device context. The origin device context for embedding cache prefetch actor is GPU or NPU. But we
224 // still need the CPU device context to allocate host memory.
225 device::DeviceContextKey host_key = {"CPU", 0};
226 cpu_device_context_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
227 MS_EXCEPTION_IF_NULL(cpu_device_context_);
228 cpu_device_context_->Initialize();
229
230 server_num_ = PSContext::instance()->server_num();
231 if (server_num_ == 0) {
232 MS_LOG(EXCEPTION) << "The number of servers is at least 1, but get 0";
233 }
234
235 // Build and link rpc operators.
236 BuildRpcOperators();
237 LinkRpcOperators();
238
239 // Create the device embedding operation.
240 emb_ops_ = new (std::nothrow) DeviceDenseEmbeddingOperation(
241 this, device_context_, local_embedding_slice_bounds_, local_device_cache_bounds_, &statistics_info_, stream_id_);
242 MS_EXCEPTION_IF_NULL(emb_ops_);
243 if (!emb_ops_->Initialize()) {
244 MS_LOG(ERROR) << "Failed to initialize the device embedding operation.";
245 }
246
247 // Get the id range of each server's embedding table slice.
248 emb_ops_->GetRemoteEmbeddingSliceBound(vocab_size_, server_num_, &remote_embedding_slice_bounds_);
249
250 initialized_ = true;
251 }
252
Finalize(bool finalize_remote)253 void EmbeddingCachePrefetchActor::Finalize(bool finalize_remote) {
254 std::lock_guard<std::mutex> lock(finalize_mutex_);
255 if (!initialized_ || finalized_) {
256 return;
257 }
258
259 StopPrefetchCachePipeline();
260 for (const auto &item : channel_locks_) {
261 const auto &channel_ptr = item.second;
262 MS_EXCEPTION_IF_NULL(channel_ptr);
263 channel_ptr->TryWakeChannel(true);
264 }
265 WaitPrefetchCacheFinish();
266
267 PsDataPrefetch::GetInstance().NotifyFinalize();
268
269 if (finalize_remote) {
270 (void)FinalizeRemote();
271 }
272
273 data_parser_.notify_all();
274
275 if (emb_ops_ != nullptr) {
276 delete emb_ops_;
277 emb_ops_ = nullptr;
278 }
279
280 rpc_operators_.clear();
281 finalized_ = true;
282 initialized_ = false;
283 }
284
IncreaseGraphStep(const std::string & channel_name)285 void EmbeddingCachePrefetchActor::IncreaseGraphStep(const std::string &channel_name) {
286 if (!running_) {
287 std::string error_info =
288 !error_info_.empty() ? error_info_ : "Embedding cache prefetch actor is finalized abnormally.";
289 MS_LOG(EXCEPTION) << error_info;
290 }
291 if (graph_step_ >= UINT64_MAX) {
292 MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
293 }
294 if (graph_step_ == 0) {
295 MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameters_on_remote_;
296 std::unique_lock<std::mutex> locker(data_mutex_);
297 data_parser_.wait(locker, [this] { return ((finish_init_parameters_on_remote_ == true) || (running_ == false)); });
298 if (!running_) {
299 std::string error_info =
300 !error_info_.empty() ? error_info_ : "Embedding cache prefetch actor is finalized abnormally.";
301 MS_LOG(EXCEPTION) << error_info;
302 }
303 distributed::EmbeddingCacheTableManager::GetInstance().WaitForWarmUpHostCacheComplete();
304 MS_LOG(INFO) << "Graph running waiting embedding table init end.";
305 }
306 graph_step_++;
307 if (embedding_cache_table_manager.enable_pipeline()) {
308 if (channel_name != channel_name_) {
309 set_channel_name(channel_name);
310 // Create pipeline tasks for this channel
311 StartPrefetchCachePipeline(channel_name);
312 }
313 const auto &iter = channel_locks_.find(channel_name);
314 if (iter == channel_locks_.end()) {
315 MS_LOG(EXCEPTION) << "Can not find channel lock for channel: " << channel_name;
316 }
317 MS_EXCEPTION_IF_NULL(iter->second);
318 iter->second->TryWakeChannel();
319 } else {
320 set_channel_name(channel_name);
321 if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) {
322 MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name;
323 }
324 }
325 data_parser_.notify_one();
326 }
327
Run()328 void EmbeddingCachePrefetchActor::Run() {
329 running_ = true;
330
331 // Bind device to current thread to gain device control privileges
332 MS_EXCEPTION_IF_NULL(device_context_);
333 MS_EXCEPTION_IF_NULL(device_context_->device_res_manager_);
334 if (!device_context_->device_res_manager_->BindDeviceToCurrentThread(false)) {
335 MS_LOG(ERROR) << "Failed to bind device to current thread.";
336 running_ = false;
337 PsDataPrefetch::GetInstance().NotifyFinalize();
338 return;
339 }
340
341 // Wait initialize parameters on remote.
342 // Prevents the subsequent prefetch cache from failing due to the long initialization time of the large parameter on
343 // the remote side.
344 WaitInitParametersOnRemote();
345
346 // Wait data channel ready.
347 WaitDataChannelInit();
348 }
349
CreateChannelLock(const std::string & channel_name)350 void EmbeddingCachePrefetchActor::CreateChannelLock(const std::string &channel_name) {
351 if (channel_locks_.find(channel_name) != channel_locks_.end()) {
352 return;
353 }
354 auto sink_size = DataQueueManager::GetInstance().GetSinkSize(channel_name);
355 channel_locks_.emplace(channel_name, std::make_shared<PsDataChannel>(channel_name, sink_size));
356 }
357
CreateBlockQueue(const std::string & channel_name)358 void EmbeddingCachePrefetchActor::CreateBlockQueue(const std::string &channel_name) {
359 auto unique_ids_queue = std::make_shared<BlockingQueue<UniqueIds>>(kDefaultQueueCapacity);
360 auto cache_analysis_queue = std::make_shared<BlockingQueue<CacheAnalysis>>(kDefaultQueueCapacity);
361 auto ids_and_indices_queue = std::make_shared<BlockingQueue<IdsAndIndices>>(kDefaultQueueCapacity);
362 (void)channel_to_queues_.emplace(channel_name,
363 std::make_tuple(unique_ids_queue, cache_analysis_queue, ids_and_indices_queue));
364 }
365
StartPrefetchCachePipeline(const std::string & channel_name)366 void EmbeddingCachePrefetchActor::StartPrefetchCachePipeline(const std::string &channel_name) {
367 MS_LOG(INFO) << "Begin StartPrefetchCachePipeline for channel name: " << channel_name;
368 std::lock_guard<std::mutex> lock(pipeline_mutex_);
369 if (pipeline_stages_.find(channel_name) != pipeline_stages_.end()) {
370 return;
371 }
372
373 CreateChannelLock(channel_name);
374 CreateBlockQueue(channel_name);
375
376 auto thread_list = std::make_shared<std::vector<std::thread>>(kPipelineStageNum);
377 pipeline_stages_.emplace(channel_name, thread_list);
378
379 thread_list->at(kIndex0) = std::thread(&EmbeddingCachePrefetchActor::UniqueIdsTask, this, channel_name);
380 thread_list->at(kIndex1) = std::thread(&EmbeddingCachePrefetchActor::AnalyseCacheTask, this, channel_name);
381 thread_list->at(kIndex2) = std::thread(&EmbeddingCachePrefetchActor::UpdateCacheTask, this, channel_name);
382 thread_list->at(kIndex3) = std::thread(&EmbeddingCachePrefetchActor::TransformIdsToIndicesTask, this, channel_name);
383 MS_LOG(INFO) << "End StartPrefetchCachePipeline for channel name: " << channel_name;
384 }
385
StopPrefetchCachePipeline()386 void EmbeddingCachePrefetchActor::StopPrefetchCachePipeline() {
387 MS_LOG(INFO) << "Begin StopPrefetchCachePipeline";
388 std::lock_guard<std::mutex> lock(pipeline_mutex_);
389 running_ = false;
390 DataQueueManager::GetInstance().CloseAllQueues();
391 for (const auto &item : channel_to_queues_) {
392 const BlockingQueueTuple &queues_tuple = item.second;
393 const auto &unique_ids_queue = std::get<std::shared_ptr<BlockingQueue<UniqueIds>>>(queues_tuple);
394 MS_EXCEPTION_IF_NULL(unique_ids_queue);
395 unique_ids_queue->Close();
396
397 const auto &cache_analysis_queue = std::get<std::shared_ptr<BlockingQueue<CacheAnalysis>>>(queues_tuple);
398 MS_EXCEPTION_IF_NULL(cache_analysis_queue);
399 cache_analysis_queue->Close();
400
401 const auto &ids_and_indices_queue = std::get<std::shared_ptr<BlockingQueue<IdsAndIndices>>>(queues_tuple);
402 MS_EXCEPTION_IF_NULL(ids_and_indices_queue);
403 ids_and_indices_queue->Close();
404 }
405 MS_LOG(INFO) << "End StopPrefetchCachePipeline";
406 }
407
WaitPrefetchCacheFinish()408 void EmbeddingCachePrefetchActor::WaitPrefetchCacheFinish() {
409 std::lock_guard<std::mutex> lock(pipeline_mutex_);
410 for (auto &item : pipeline_stages_) {
411 const std::string &channel_name = item.first;
412 MS_LOG(INFO) << "Begin stop pipeline for channel: " << channel_name;
413 auto stage_threads = item.second;
414 MS_EXCEPTION_IF_NULL(stage_threads);
415 for (size_t i = 0; i < kPipelineStageNum; ++i) {
416 if (stage_threads->at(i).joinable()) {
417 stage_threads->at(i).join();
418 }
419 }
420 MS_LOG(INFO) << "End stop pipeline for channel: " << channel_name;
421 }
422 }
423
UniqueIdsTask(const std::string & channel_name)424 void EmbeddingCachePrefetchActor::UniqueIdsTask(const std::string &channel_name) {
425 const auto &iter = channel_locks_.find(channel_name);
426 if (iter == channel_locks_.end()) {
427 MS_LOG(EXCEPTION) << "Can not find channel lock for channel: " << channel_name;
428 }
429 auto channel_lock = iter->second;
430 MS_EXCEPTION_IF_NULL(channel_lock);
431
432 const auto &id_data_queue = DataQueueManager::GetInstance().GetDataQueue(channel_name).first;
433 MS_EXCEPTION_IF_NULL(id_data_queue);
434
435 const auto &queue_iter = channel_to_queues_.find(channel_name);
436 if (queue_iter == channel_to_queues_.end()) {
437 MS_LOG(EXCEPTION) << "Can not find queue for channel: " << channel_name;
438 }
439 const auto &unique_ids_queue = std::get<std::shared_ptr<BlockingQueue<UniqueIds>>>(queue_iter->second);
440 MS_EXCEPTION_IF_NULL(unique_ids_queue);
441
442 size_t sink_size = DataQueueManager::GetInstance().GetSinkSize(channel_name);
443 size_t multi_batch_counter = 0;
444 UniqueIds *unique_ids = nullptr;
445 while (running_) {
446 IdDataInfo *data = id_data_queue->Pop();
447 if (!running_) {
448 break;
449 }
450 MS_EXCEPTION_IF_NULL(data);
451 int *batch_ids = reinterpret_cast<int *>(data->data_);
452 if (batch_ids) {
453 // Lock in first stage to support multi-channel case for real input data.
454 channel_lock->TryLockChannel();
455 MS_EXCEPTION_IF_CHECK_FAIL(IncreaseStep(), "Increase step failed.");
456 }
457
458 if (data->end_of_file_ || data->end_of_epoch_) {
459 // Push empty data for epoch or file end flag.
460 UniqueIds *empty_unique_ids = new (std::nothrow) UniqueIds();
461 MS_EXCEPTION_IF_NULL(empty_unique_ids);
462 empty_unique_ids->end_of_epoch_ = data->end_of_epoch_;
463 empty_unique_ids->end_of_file_ = data->end_of_file_;
464 unique_ids_queue->Push(empty_unique_ids);
465 delete data;
466 continue;
467 }
468
469 if (unique_ids == nullptr) {
470 unique_ids = new (std::nothrow) UniqueIds();
471 MS_EXCEPTION_IF_NULL(unique_ids);
472 }
473
474 ++multi_batch_counter;
475 if (multi_batch_counter < sink_size &&
476 multi_batch_counter % embedding_cache_table_manager.multi_batch_threshold_ != 0) {
477 unique_ids->multi_batch_data_.push_back(batch_ids);
478 unique_ids->multi_batch_size_.push_back(data->size_ / sizeof(int));
479 unique_ids->multi_batch_items_.push_back(data->items_);
480 continue;
481 }
482 unique_ids->multi_batch_data_.push_back(batch_ids);
483 unique_ids->multi_batch_size_.push_back(data->size_ / sizeof(int));
484 unique_ids->multi_batch_items_.push_back(data->items_);
485
486 if (multi_batch_counter == sink_size) {
487 multi_batch_counter = 0;
488 }
489
490 // Unique for each batch and store unique ids
491 DeduplicateId(unique_ids);
492 // Push to next stage pipeline queue.
493 unique_ids->data_step_ = data_step_;
494
495 unique_ids_queue->Push(unique_ids);
496 unique_ids = nullptr;
497 delete data;
498 }
499 }
500
AnalyseCacheTask(const std::string & channel_name)501 void EmbeddingCachePrefetchActor::AnalyseCacheTask(const std::string &channel_name) {
502 const auto &queue_iter = channel_to_queues_.find(channel_name);
503 if (queue_iter == channel_to_queues_.end()) {
504 MS_LOG(EXCEPTION) << "Can not find queue for channel: " << channel_name;
505 }
506
507 const auto &unique_ids_queue = std::get<std::shared_ptr<BlockingQueue<UniqueIds>>>(queue_iter->second);
508 MS_EXCEPTION_IF_NULL(unique_ids_queue);
509 const auto &cache_analysis_queue = std::get<std::shared_ptr<BlockingQueue<CacheAnalysis>>>(queue_iter->second);
510 MS_EXCEPTION_IF_NULL(cache_analysis_queue);
511
512 while (running_) {
513 UniqueIds *unique_ids = unique_ids_queue->Pop();
514 if (!running_) {
515 break;
516 }
517 MS_EXCEPTION_IF_NULL(unique_ids);
518 if (unique_ids->end_of_file_ || unique_ids->end_of_epoch_) {
519 // Push empty data for epoch or file end flag.
520 CacheAnalysis *cache_analysis = new (std::nothrow) CacheAnalysis();
521 MS_EXCEPTION_IF_NULL(cache_analysis);
522 cache_analysis->end_of_epoch_ = unique_ids->end_of_epoch_;
523 cache_analysis->end_of_file_ = unique_ids->end_of_file_;
524 cache_analysis_queue->Push(cache_analysis);
525 delete unique_ids;
526 continue;
527 }
528 size_t unique_ids_num = unique_ids->ids_num_;
529 int *indices = new (std::nothrow) int[unique_ids_num];
530 MS_EXCEPTION_IF_NULL(indices);
531
532 EmbeddingDeviceCache *embedding_device_cache = new (std::nothrow) EmbeddingDeviceCache(unique_ids_num);
533 MS_EXCEPTION_IF_NULL(embedding_device_cache);
534 EmbeddingHostCache *embedding_host_cache = new (std::nothrow) EmbeddingHostCache(unique_ids_num);
535 MS_EXCEPTION_IF_NULL(embedding_host_cache);
536 EmbeddingCacheStatisticsInfo *statistics_info = new (std::nothrow) EmbeddingCacheStatisticsInfo();
537 MS_EXCEPTION_IF_NULL(statistics_info);
538
539 // Analyse cache hit/miss
540 if (!emb_ops_->AnalyseCache(unique_ids->ids_, unique_ids_num, unique_ids->data_step_, &graph_step_,
541 &device_cache_need_wait_graph_, &host_cache_need_wait_graph_, indices,
542 embedding_device_cache, embedding_host_cache, statistics_info)) {
543 MS_LOG(ERROR) << "Analyse cache failed.";
544 StopPrefetchCachePipeline();
545 return;
546 }
547
548 // Push analyse result to update cache queue
549 CacheAnalysis *cache_analysis =
550 new (std::nothrow) CacheAnalysis(embedding_device_cache, embedding_host_cache, statistics_info, unique_ids,
551 indices, unique_ids->end_of_epoch_, unique_ids->end_of_file_);
552 MS_EXCEPTION_IF_NULL(cache_analysis);
553 cache_analysis_queue->Push(cache_analysis);
554 }
555 }
556
UpdateCacheTask(const std::string & channel_name)557 void EmbeddingCachePrefetchActor::UpdateCacheTask(const std::string &channel_name) {
558 const auto &queue_iter = channel_to_queues_.find(channel_name);
559 if (queue_iter == channel_to_queues_.end()) {
560 MS_LOG(EXCEPTION) << "Can not find queue for channel: " << channel_name;
561 }
562
563 const auto &cache_analysis_queue = std::get<std::shared_ptr<BlockingQueue<CacheAnalysis>>>(queue_iter->second);
564 MS_EXCEPTION_IF_NULL(cache_analysis_queue);
565
566 const auto &ids_and_indices_queue = std::get<std::shared_ptr<BlockingQueue<IdsAndIndices>>>(queue_iter->second);
567 MS_EXCEPTION_IF_NULL(ids_and_indices_queue);
568
569 while (running_) {
570 CacheAnalysis *cache_analysis = cache_analysis_queue->Pop();
571 if (!running_) {
572 break;
573 }
574 MS_EXCEPTION_IF_NULL(cache_analysis);
575 if (cache_analysis->end_of_file_ || cache_analysis->end_of_epoch_) {
576 // Push empty data for epoch end flag.
577 IdsAndIndices *ids_and_indices = new (std::nothrow) IdsAndIndices();
578 MS_EXCEPTION_IF_NULL(ids_and_indices);
579 ids_and_indices->end_of_epoch_ = cache_analysis->end_of_epoch_;
580 ids_and_indices->end_of_file_ = cache_analysis->end_of_file_;
581 ids_and_indices_queue->Push(ids_and_indices);
582 delete cache_analysis;
583 continue;
584 }
585
586 for (const auto &item : embedding_cache_table_manager.hash_tables_) {
587 const auto &hash_info = item.second;
588 MS_EXCEPTION_IF_CHECK_FAIL(PushCacheFromLocalHostToRemote(hash_info, cache_analysis),
589 "Push cache from local host to remote failed.");
590 MS_EXCEPTION_IF_CHECK_FAIL(emb_ops_->PushCacheFromDeviceToLocalHost(hash_info, cache_analysis),
591 "Push cache from device to local host failed.");
592 MS_EXCEPTION_IF_CHECK_FAIL(InitLocalCacheForNewIds(hash_info, cache_analysis),
593 "Initialize the local cache values using random generator.");
594 MS_EXCEPTION_IF_CHECK_FAIL(PullCacheFromRemoteToLocalHost(hash_info, cache_analysis),
595 "Pull cache from remote to local host failed.");
596 MS_EXCEPTION_IF_CHECK_FAIL(emb_ops_->PullCacheFromLocalHostToDevice(hash_info, cache_analysis),
597 "Pull cache from local host to device failed.");
598 }
599
600 IdsAndIndices *ids_and_indices =
601 new (std::nothrow) IdsAndIndices(cache_analysis->unique_ids_, cache_analysis->indices_,
602 cache_analysis->end_of_epoch_, cache_analysis->end_of_file_);
603
604 ids_and_indices_queue->Push(ids_and_indices);
605
606 delete cache_analysis->embedding_host_cache_;
607 delete cache_analysis->embedding_device_cache_;
608 delete cache_analysis->statistics_info_;
609 delete cache_analysis;
610 }
611 }
612
TransformIdsToIndicesTask(const std::string & channel_name)613 void EmbeddingCachePrefetchActor::TransformIdsToIndicesTask(const std::string &channel_name) {
614 const auto &queue_iter = channel_to_queues_.find(channel_name);
615 if (queue_iter == channel_to_queues_.end()) {
616 MS_LOG(EXCEPTION) << "Can not find queue for channel: " << channel_name;
617 }
618 const auto &ids_and_indices_queue = std::get<std::shared_ptr<BlockingQueue<IdsAndIndices>>>(queue_iter->second);
619 MS_EXCEPTION_IF_NULL(ids_and_indices_queue);
620
621 const auto &index_data_queue = DataQueueManager::GetInstance().GetDataQueue(channel_name).second;
622 MS_EXCEPTION_IF_NULL(index_data_queue);
623 while (running_) {
624 IdsAndIndices *ids_and_indices = ids_and_indices_queue->Pop();
625 if (!running_) {
626 break;
627 }
628 MS_EXCEPTION_IF_NULL(ids_and_indices);
629 // Push empty data for epoch end flag.
630 if (ids_and_indices->end_of_file_ || ids_and_indices->end_of_epoch_) {
631 IndexDataInfo *indices_info = new (std::nothrow) IndexDataInfo();
632 MS_EXCEPTION_IF_NULL(indices_info);
633 indices_info->end_of_epoch_ = ids_and_indices->end_of_epoch_;
634 indices_info->end_of_file_ = ids_and_indices->end_of_file_;
635 index_data_queue->Push(indices_info);
636 delete ids_and_indices;
637 continue;
638 }
639
640 auto *unique_ids = ids_and_indices->unique_ids_;
641 MS_EXCEPTION_IF_NULL(unique_ids);
642 auto *unique_ids_ptr = unique_ids->ids_;
643 auto unique_ids_num = unique_ids->ids_num_;
644 auto *unique_indices_ptr = ids_and_indices->indices_;
645
646 mindspore::HashMap<int, int> unique_ids_to_indices;
647 for (size_t i = 0; i < unique_ids_num; i++) {
648 (void)unique_ids_to_indices.try_emplace(unique_ids_ptr[i], unique_indices_ptr[i]);
649 }
650
651 for (size_t i = 0; i < unique_ids->multi_batch_data_.size(); ++i) {
652 if (!embedding_cache_table_manager.is_sparse_format()) {
653 TransformIdsToIndices(&unique_ids_to_indices, unique_ids->multi_batch_size_.at(i),
654 reinterpret_cast<int *>(unique_ids->multi_batch_data_.at(i)));
655 }
656 IndexDataInfo *indices_info =
657 new IndexDataInfo(unique_ids->multi_batch_data_.at(i), unique_ids->multi_batch_items_.at(i),
658 ids_and_indices->end_of_epoch_, ids_and_indices->end_of_file_);
659
660 if (unique_ids->multi_batch_data_.at(i) != unique_ids->multi_batch_items_.at(i)->at(0).data_ptr) {
661 MS_LOG(EXCEPTION) << "The id data ptr is valid";
662 }
663
664 index_data_queue->Push(indices_info);
665 }
666
667 delete[] ids_and_indices->unique_ids_->ids_;
668 delete ids_and_indices->unique_ids_;
669 delete[] ids_and_indices->indices_;
670 delete ids_and_indices;
671 }
672 }
673
IncreaseStep()674 bool EmbeddingCachePrefetchActor::IncreaseStep() {
675 if (data_step_ >= UINT64_MAX) {
676 MS_LOG(ERROR) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t.";
677 return false;
678 }
679 data_step_++;
680 set_current_graph_step();
681 if (graph_running_step_ > data_step_) {
682 MS_LOG(ERROR) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_
683 << ").";
684 return false;
685 }
686 return true;
687 }
688
WaitGraphRun()689 bool EmbeddingCachePrefetchActor::WaitGraphRun() {
690 MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes.";
691 std::unique_lock<std::mutex> locker(data_mutex_);
692 const int64_t longest_time_to_wait = 120;
693 if (!data_parser_.wait_for(locker, std::chrono::seconds(longest_time_to_wait),
694 [this] { return graph_step_ > graph_running_step_; })) {
695 std::string err_info = "Prefetch embedding cache timeout, please enlarge the vocab cache size(graph step:" +
696 std::to_string(graph_step_) + ", graph running step:" + std::to_string(graph_running_step_) +
697 ").";
698 SetErrorInfo(err_info);
699 MS_LOG(ERROR) << err_info;
700 return false;
701 }
702 set_current_graph_step();
703 return true;
704 }
705
ResetEmbeddingHashMap()706 bool EmbeddingCachePrefetchActor::ResetEmbeddingHashMap() {
707 const auto &device_hash_map = embedding_cache_table_manager.device_hash_map_;
708 MS_ERROR_IF_NULL(device_hash_map);
709 const auto &host_hash_map = embedding_cache_table_manager.host_hash_map_;
710 MS_ERROR_IF_NULL(host_hash_map);
711 device_hash_map->Reset();
712 host_hash_map->Reset();
713 device_cache_need_wait_graph_ = false;
714 host_cache_need_wait_graph_ = false;
715 return true;
716 }
717
LookupEmbeddingTable(size_t indices_num,size_t embedding_size,size_t first_dim_size,const float * input_addr,const int * indices_addr,float * output_addr)718 void EmbeddingCachePrefetchActor::LookupEmbeddingTable(size_t indices_num, size_t embedding_size, size_t first_dim_size,
719 const float *input_addr, const int *indices_addr,
720 float *output_addr) {
721 MS_ERROR_IF_NULL_WO_RET_VAL(input_addr);
722 MS_ERROR_IF_NULL_WO_RET_VAL(indices_addr);
723 MS_ERROR_IF_NULL_WO_RET_VAL(output_addr);
724
725 auto type_size = sizeof(float);
726 size_t lens = embedding_size * type_size;
727 for (size_t i = 0; i < indices_num; ++i) {
728 int index = indices_addr[i];
729 if (index >= 0 && index < SizeToInt(first_dim_size)) {
730 size_t pos = IntToSize(index) * embedding_size;
731 auto ret = memcpy_s(output_addr, (indices_num - i) * lens, input_addr + pos, lens);
732 if (ret != EOK) {
733 MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
734 StopPrefetchCachePipeline();
735 return;
736 }
737 } else {
738 auto ret = memset_s(output_addr, (indices_num - i) * lens, 0, lens);
739 if (ret != EOK) {
740 MS_LOG(ERROR) << "Memset failed, errno[" << ret << "]";
741 StopPrefetchCachePipeline();
742 return;
743 }
744 }
745 output_addr += embedding_size;
746 }
747 }
748
PushCacheFromLocalHostToRemote(const HashTableInfo & hash_info,const CacheAnalysis * cache_analysis)749 bool EmbeddingCachePrefetchActor::PushCacheFromLocalHostToRemote(const HashTableInfo &hash_info,
750 const CacheAnalysis *cache_analysis) {
751 MS_ERROR_IF_NULL(cache_analysis);
752 auto statistics_info = cache_analysis->statistics_info_;
753 auto embedding_host_cache = cache_analysis->embedding_host_cache_;
754 MS_ERROR_IF_NULL(statistics_info);
755 MS_ERROR_IF_NULL(embedding_host_cache);
756
757 auto swap_indices_size = statistics_info->host_to_server_size_;
758 if (swap_indices_size == 0) {
759 return true;
760 }
761
762 auto host_to_server_ids = embedding_host_cache->host_to_server_ids.get();
763 MS_ERROR_IF_NULL(host_to_server_ids);
764 auto host_to_server_index = embedding_host_cache->host_to_server_index.get();
765 MS_ERROR_IF_NULL(host_to_server_index);
766
767 std::vector<float> swap_out_data;
768 auto embedding_size = hash_info.embedding_size;
769 swap_out_data.resize(swap_indices_size * embedding_size);
770 auto host_hash_table_addr = hash_info.host_address;
771
772 RETURN_IF_FALSE_WITH_LOG(LookupLocalHostCache(embedding_size, swap_indices_size, host_hash_table_addr,
773 host_to_server_index, swap_out_data.data()),
774 "Lookup local host cache failed.");
775 RETURN_IF_FALSE_WITH_LOG(PushEmbeddingsToRemote(hash_info.param_key_, host_to_server_ids, swap_indices_size,
776 swap_out_data.data(), swap_out_data.size() * sizeof(float)),
777 "Push embeddings to remote failed.");
778 return true;
779 }
780
PullCacheFromRemoteToLocalHost(const HashTableInfo & hash_info,const CacheAnalysis * cache_analysis)781 bool EmbeddingCachePrefetchActor::PullCacheFromRemoteToLocalHost(const HashTableInfo &hash_info,
782 const CacheAnalysis *cache_analysis) {
783 MS_ERROR_IF_NULL(cache_analysis);
784 auto statistics_info = cache_analysis->statistics_info_;
785 auto embedding_host_cache = cache_analysis->embedding_host_cache_;
786 MS_ERROR_IF_NULL(statistics_info);
787 MS_ERROR_IF_NULL(embedding_host_cache);
788
789 auto swap_indices_size = statistics_info->server_to_host_size_;
790 if (swap_indices_size == 0) {
791 return true;
792 }
793
794 auto server_to_host_ids = embedding_host_cache->server_to_host_ids.get();
795 MS_ERROR_IF_NULL(server_to_host_ids);
796 auto server_to_host_index = embedding_host_cache->server_to_host_index.get();
797 MS_ERROR_IF_NULL(server_to_host_index);
798
799 auto host_hash_table_addr = hash_info.host_address;
800 MS_ERROR_IF_NULL(host_hash_table_addr);
801 auto embedding_size = hash_info.embedding_size;
802 std::vector<float> lookup_result(swap_indices_size * embedding_size, 0);
803
804 RETURN_IF_FALSE_WITH_LOG(
805 PullEembeddingsFromRemote(hash_info.param_key_, server_to_host_ids, swap_indices_size, &lookup_result),
806 "Pull embedding from remote failed.");
807 RETURN_IF_FALSE_WITH_LOG(InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), server_to_host_index,
808 lookup_result.data(), host_hash_table_addr),
809 "Insert local host cache failed.");
810 return true;
811 }
812
InitLocalCacheForNewIds(const HashTableInfo & hash_info,const CacheAnalysis * cache_analysis)813 bool EmbeddingCachePrefetchActor::InitLocalCacheForNewIds(const HashTableInfo &hash_info,
814 const CacheAnalysis *cache_analysis) {
815 MS_ERROR_IF_NULL(cache_analysis);
816 auto statistics_info = cache_analysis->statistics_info_;
817 auto embedding_host_cache = cache_analysis->embedding_host_cache_;
818 MS_ERROR_IF_NULL(statistics_info);
819 MS_ERROR_IF_NULL(embedding_host_cache);
820
821 auto new_id_size = statistics_info->new_id_size_;
822 if (new_id_size == 0) {
823 return true;
824 }
825
826 auto new_id_index = embedding_host_cache->new_id_index.get();
827 MS_ERROR_IF_NULL(new_id_index);
828
829 // Compute the feature values size needed to be initialized.
830 auto embedding_size = hash_info.embedding_size;
831 auto total_size = new_id_size * embedding_size;
832 std::vector<float> init_result(total_size, 0);
833
834 // Initialize accumulate values with the configured constant value.
835 if (hash_info.param_init_info_.param_type_ == distributed::ParamType::kAccumulation) {
836 auto init_value = hash_info.param_init_info_.init_val_;
837 GenerateDistributionParallel<DataType, Generator, ConstantDistribution>(total_size, init_result.data(), init_value);
838 } else {
839 // Initialize embedding values from local random generator for feature ids that have never been seen before.
840 const double mean = 0.0;
841 const double sigma = 0.01;
842 GenerateDistributionParallel<DataType, Generator, NormalDistribution>(total_size, init_result.data(), mean, sigma);
843 }
844
845 // Insert initialized feature values into the local hash cache.
846 auto host_hash_table_addr = hash_info.host_address;
847 MS_ERROR_IF_NULL(host_hash_table_addr);
848 RETURN_IF_FALSE_WITH_LOG(InsertLocalHostCache(embedding_size, IntToSize(new_id_size), new_id_index,
849 init_result.data(), host_hash_table_addr),
850 "Insert local host cache failed.");
851 return true;
852 }
853
InsertLocalHostCache(size_t embedding_size,size_t insert_indices_size,const int * insert_indices,const float * insert_data,float * hash_table_addr)854 bool EmbeddingCachePrefetchActor::InsertLocalHostCache(size_t embedding_size, size_t insert_indices_size,
855 const int *insert_indices, const float *insert_data,
856 float *hash_table_addr) {
857 MS_ERROR_IF_NULL(insert_indices);
858 MS_ERROR_IF_NULL(insert_data);
859 MS_ERROR_IF_NULL(hash_table_addr);
860
861 size_t first_dim_size = local_host_cache_size_;
862 size_t thread_num = insert_indices_size / kMaxIdsPerThread + 1;
863 thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
864 std::thread threads[kMaxThreadNum];
865 size_t proc_len = (insert_indices_size + thread_num - 1) / thread_num;
866 size_t i = 0;
867 size_t offset = 0;
868
869 auto insert_cache_func = [this](size_t insert_indices_size, size_t embedding_size, size_t first_dim_size,
870 const int *insert_indices, const float *insert_data, float *hash_table_addr) {
871 auto type_size = sizeof(float);
872 size_t copy_len = embedding_size * type_size;
873 size_t dest_len = copy_len;
874 for (size_t i = 0; i < insert_indices_size; ++i) {
875 int index = insert_indices[i];
876 if (index >= 0 && index < SizeToInt(first_dim_size)) {
877 auto ret =
878 memcpy_s(hash_table_addr + index * embedding_size, dest_len, insert_data + i * embedding_size, copy_len);
879 if (ret != EOK) {
880 MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
881 StopPrefetchCachePipeline();
882 return;
883 }
884 }
885 }
886 };
887
888 for (; i < thread_num; i++) {
889 if (offset >= insert_indices_size) {
890 break;
891 }
892 threads[i] = std::thread(insert_cache_func, proc_len, embedding_size, first_dim_size, insert_indices + offset,
893 insert_data + offset * embedding_size, hash_table_addr);
894 offset += proc_len;
895 if (offset + proc_len > insert_indices_size) {
896 proc_len = insert_indices_size - offset;
897 }
898 }
899
900 for (size_t j = 0; j < i; j++) {
901 threads[j].join();
902 }
903 return running_;
904 }
905
LookupLocalHostCache(size_t embedding_size,size_t indices_num,const float * hash_table_addr,const int * indices_addr,float * output_addr)906 bool EmbeddingCachePrefetchActor::LookupLocalHostCache(size_t embedding_size, size_t indices_num,
907 const float *hash_table_addr, const int *indices_addr,
908 float *output_addr) {
909 MS_ERROR_IF_NULL(hash_table_addr);
910 MS_ERROR_IF_NULL(indices_addr);
911 MS_ERROR_IF_NULL(output_addr);
912
913 size_t first_dim_size = local_host_cache_size_;
914 size_t thread_num = indices_num / kMaxIdsPerThread + 1;
915 thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
916 std::thread threads[kMaxThreadNum];
917 size_t proc_len = (indices_num + thread_num - 1) / thread_num;
918 size_t i = 0;
919 size_t offset = 0;
920
921 for (; i < thread_num; i++) {
922 if (offset >= indices_num) {
923 break;
924 }
925 threads[i] =
926 std::thread(&EmbeddingCachePrefetchActor::LookupEmbeddingTable, this, proc_len, embedding_size, first_dim_size,
927 hash_table_addr, indices_addr + offset, output_addr + offset * embedding_size);
928 offset += proc_len;
929 if (offset + proc_len > indices_num) {
930 proc_len = indices_num - offset;
931 }
932 }
933
934 for (size_t j = 0; j < i; j++) {
935 threads[j].join();
936 }
937 return running_;
938 }
939
PullEembeddingsFromRemote(int32_t param_key,const int * ids,size_t ids_num,std::vector<float> * outputs)940 bool EmbeddingCachePrefetchActor::PullEembeddingsFromRemote(int32_t param_key, const int *ids, size_t ids_num,
941 std::vector<float> *outputs) {
942 MS_ERROR_IF_NULL(ids);
943 MS_ERROR_IF_NULL(outputs);
944
945 if (ids_num == 0) {
946 MS_LOG(WARNING) << "The ids number is 0";
947 return true;
948 }
949
950 std::vector<std::vector<int>> slice_ids_list(server_num_);
951 // 1. Partition ids by remote embedding slice bound and get unique ids.
952 RETURN_IF_FALSE_WITH_LOG(PartitionIds(ids, ids_num, &slice_ids_list), "Partition ids failed.");
953
954 size_t embedding_dim = outputs->size() / ids_num;
955 for (size_t i = 0; i < server_num_; i++) {
956 auto &slice_ids = slice_ids_list[i];
957 if (slice_ids.empty()) {
958 continue;
959 }
960
961 // 2. Send unique ids to remote to do embedding lookup.
962 RETURN_IF_FALSE_WITH_LOG(SendToRemote(distributed::kLookupEmbeddingCache, param_key, i, embedding_dim,
963 slice_ids.data(), slice_ids.size() * sizeof(int), nullptr, 0, false, false),
964 "Send ids to server failed.");
965 }
966
967 std::vector<std::unique_ptr<std::vector<char>>> slice_embeddings_list(server_num_);
968 for (size_t i = 0; i < server_num_; i++) {
969 if (slice_ids_list[i].empty()) {
970 continue;
971 }
972
973 // 3. Wait embeddings result.
974 slice_embeddings_list[i] = ReceiveFromRemote(distributed::kLookupEmbeddingCache, param_key, i);
975 MS_ERROR_IF_NULL(slice_embeddings_list[i]);
976 // Received embedding integrity check.
977 size_t expected_embedding_size = SizetMulWithOverflowCheck(slice_ids_list[i].size(), embedding_dim);
978 size_t received_embedding_size = slice_embeddings_list[i]->size() / sizeof(float);
979 if (received_embedding_size != expected_embedding_size) {
980 MS_LOG(ERROR) << "Received embedding data from remote is incomplete, expected embedding size: "
981 << expected_embedding_size << ", but received embedding size: " << received_embedding_size;
982 return false;
983 }
984 }
985
986 // 4. Retrieve embeddings by input ids order.
987 RETURN_IF_FALSE_WITH_LOG(RetrieveEmbeddings(ids, ids_num, slice_ids_list, slice_embeddings_list, outputs),
988 "Retrieve embeddings failed.");
989
990 return true;
991 }
992
DoPushEmbeddingsToRemote(int32_t param_key,const int * ids,size_t ids_num,const float * embeddings,size_t embeddings_len)993 bool EmbeddingCachePrefetchActor::DoPushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num,
994 const float *embeddings, size_t embeddings_len) {
995 MS_LOG(DEBUG) << "Enter DoPushEmbeddingsToRemote - param_key : " << param_key << ", ids : " << ids
996 << ", ids_num : " << ids_num << ", embeddings : " << embeddings
997 << ", embeddings_len : " << embeddings_len << ".";
998 if (ids_num == 0) {
999 MS_LOG(ERROR) << "Invalidate ids num : 0.";
1000 return false;
1001 }
1002 std::vector<std::vector<int>> slice_ids_list(server_num_);
1003 std::vector<std::vector<float>> slice_embeddings_list(server_num_);
1004 // 1. Partition ids end embeddings by remote embedding slice bound.
1005 RETURN_IF_FALSE_WITH_LOG(
1006 PartitionIdsAndEmbeddings(ids, ids_num, embeddings, embeddings_len, &slice_ids_list, &slice_embeddings_list),
1007 "Partition ids and embeddings failed.");
1008
1009 size_t embedding_dim = (embeddings_len / ids_num) / sizeof(float);
1010 for (size_t i = 0; i < server_num_; i++) {
1011 auto &slice_ids = slice_ids_list[i];
1012 if (slice_ids.empty()) {
1013 continue;
1014 }
1015
1016 // 2. Send embeddings to remote.
1017 auto &slice_embeddings = slice_embeddings_list[i];
1018 RETURN_IF_FALSE_WITH_LOG(
1019 SendToRemote(distributed::kUpdateEmbeddingCache, param_key, i, embedding_dim, slice_ids.data(),
1020 slice_ids.size() * sizeof(int), slice_embeddings.data(), slice_embeddings.size() * sizeof(float)),
1021 "Send ids and embeddings to server failed.");
1022 }
1023 MS_LOG(DEBUG) << "Exit DoPushEmbeddingsToRemote.";
1024 return true;
1025 }
1026
PushEmbeddingsToRemote(int32_t param_key,const int * ids,size_t ids_num,const float * embeddings,size_t embeddings_len)1027 bool EmbeddingCachePrefetchActor::PushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num,
1028 const float *embeddings, size_t embeddings_len) {
1029 MS_EXCEPTION_IF_NULL(ids);
1030 MS_EXCEPTION_IF_NULL(embeddings);
1031 MS_EXCEPTION_IF_CHECK_FAIL(ids_num != 0, "The ids_num is 0.");
1032
1033 const auto ids_boundary = ids + ids_num;
1034 const auto embeddings_boundary = embeddings + embeddings_len / sizeof(float);
1035 MS_LOG(DEBUG) << "Enter PushEmbeddingsToRemote - param_key : " << param_key << ", ids : " << ids
1036 << ", ids_num : " << ids_num << ", embeddings : " << embeddings
1037 << ", embeddings_len : " << embeddings_len << ", ids_boundary : " << ids_boundary
1038 << ", embeddings_boundary : " << embeddings_boundary << ".";
1039 const size_t embeddings_num = embeddings_len / sizeof(float);
1040 const size_t embeddings_dim = embeddings_num / ids_num;
1041 MS_EXCEPTION_IF_CHECK_FAIL(embeddings_dim != 0, "The embeddings_dim is 0.");
1042 // Max batch size : 128Mb.
1043 const size_t max_batch_size = 1 << 27;
1044 const size_t batch_num = max_batch_size / embeddings_dim;
1045 MS_EXCEPTION_IF_CHECK_FAIL(batch_num != 0, "The batch_num is 0.");
1046 size_t batch_size = ids_num / batch_num;
1047 size_t batch_remainder = ids_num % batch_num;
1048 if (batch_remainder != 0) {
1049 batch_size++;
1050 }
1051 MS_LOG(DEBUG) << "batch_size : " << batch_size << ", batch_num << : " << batch_num
1052 << ", batch_remainder : " << batch_remainder << ".";
1053 for (size_t count = 0; count != batch_size; count++) {
1054 auto batch_ids = ids + batch_num * count;
1055 size_t batch_ids_num = (count != batch_size - 1) ? batch_num : batch_remainder;
1056 auto batch_embeddings = embeddings + batch_num * count * embeddings_dim;
1057 size_t batch_embeddings_len = batch_ids_num * embeddings_dim * sizeof(float);
1058 (void)DoPushEmbeddingsToRemote(param_key, batch_ids, batch_ids_num, batch_embeddings, batch_embeddings_len);
1059 }
1060 MS_LOG(DEBUG) << "Exit PushEmbeddingsToRemote.";
1061 return true;
1062 }
1063
PartitionIds(const int * ids,size_t ids_num,std::vector<std::vector<int>> * slice_ids_list)1064 bool EmbeddingCachePrefetchActor::PartitionIds(const int *ids, size_t ids_num,
1065 std::vector<std::vector<int>> *slice_ids_list) {
1066 MS_ERROR_IF_NULL(ids);
1067 MS_ERROR_IF_NULL(slice_ids_list);
1068
1069 size_t partition_num = slice_ids_list->size();
1070 // There is no need to partition ids for one server case.
1071 if (partition_num == 1) {
1072 std::vector<int> &slice_ids = slice_ids_list->front();
1073 slice_ids.resize(ids_num);
1074 auto ret = memcpy_s(slice_ids.data(), slice_ids.size() * sizeof(int), ids, ids_num * sizeof(int));
1075 if (ret != EOK) {
1076 MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
1077 return false;
1078 }
1079 return true;
1080 }
1081
1082 for (size_t i = 0; i < partition_num; i++) {
1083 int begin = SizeToInt(remote_embedding_slice_bounds_[i].first);
1084 int end = SizeToInt(remote_embedding_slice_bounds_[i].second);
1085
1086 std::vector<int> &slice_ids = slice_ids_list->at(i);
1087 (void)std::for_each(ids, ids + ids_num, [&](int id) {
1088 if (id >= begin && id <= end) {
1089 slice_ids.push_back(id);
1090 }
1091 });
1092 }
1093
1094 return true;
1095 }
1096
PartitionIdsAndEmbeddings(const int * ids,size_t ids_num,const float * embeddings,size_t embeddings_len,std::vector<std::vector<int>> * slice_ids_list,std::vector<std::vector<float>> * slice_embeddings_list)1097 bool EmbeddingCachePrefetchActor::PartitionIdsAndEmbeddings(const int *ids, size_t ids_num, const float *embeddings,
1098 size_t embeddings_len,
1099 std::vector<std::vector<int>> *slice_ids_list,
1100 std::vector<std::vector<float>> *slice_embeddings_list) {
1101 MS_ERROR_IF_NULL(ids);
1102 MS_ERROR_IF_NULL(embeddings);
1103 MS_ERROR_IF_NULL(slice_ids_list);
1104 MS_ERROR_IF_NULL(slice_embeddings_list);
1105
1106 if (ids_num == 0) {
1107 MS_LOG(WARNING) << "The ids number is 0";
1108 return true;
1109 }
1110
1111 size_t partition_num = slice_ids_list->size();
1112 // There is no need to partition ids and embeddings for one server case.
1113 if (partition_num == 1) {
1114 std::vector<int> &slice_ids = slice_ids_list->front();
1115 std::vector<float> &slice_embeddings = slice_embeddings_list->front();
1116 slice_ids.resize(ids_num);
1117 slice_embeddings.resize(embeddings_len / sizeof(float));
1118 auto ret = memcpy_s(slice_ids.data(), slice_ids.size() * sizeof(int), ids, ids_num * sizeof(int));
1119 if (ret != EOK) {
1120 MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
1121 return false;
1122 }
1123 ret = memcpy_s(slice_embeddings.data(), slice_embeddings.size() * sizeof(float), embeddings, embeddings_len);
1124 if (ret != EOK) {
1125 MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
1126 return false;
1127 }
1128 return true;
1129 }
1130
1131 size_t embedding_dim = (embeddings_len / ids_num) / sizeof(float);
1132 for (size_t i = 0; i < partition_num; i++) {
1133 int begin = SizeToInt(remote_embedding_slice_bounds_[i].first);
1134 int end = SizeToInt(remote_embedding_slice_bounds_[i].second);
1135
1136 std::vector<int> &slice_ids = slice_ids_list->at(i);
1137 std::vector<float> &slice_embeddings = slice_embeddings_list->at(i);
1138 // Ids range offset for multi server.
1139 int offset = SizeToInt(remote_embedding_slice_bounds_.at(i).first);
1140 for (size_t j = 0; j < ids_num; j++) {
1141 if (ids[j] >= begin && ids[j] <= end) {
1142 slice_ids.push_back(ids[j] - offset);
1143 (void)slice_embeddings.insert(slice_embeddings.end(), embeddings + (j * embedding_dim),
1144 embeddings + (j * embedding_dim) + embedding_dim);
1145 }
1146 }
1147 }
1148 return true;
1149 }
1150
SendToRemote(const std::string & cache_operation,int32_t param_key,size_t server_rank_id,size_t embedding_dim,const void * keys,size_t keys_len,const void * values,size_t values_len,bool finalize_remote,bool sync)1151 bool EmbeddingCachePrefetchActor::SendToRemote(const std::string &cache_operation, int32_t param_key,
1152 size_t server_rank_id, size_t embedding_dim, const void *keys,
1153 size_t keys_len, const void *values, size_t values_len,
1154 bool finalize_remote, bool sync) {
1155 MS_ERROR_IF_NULL(keys);
1156 // Find sender corresponding to cache operation and parameter key.
1157 auto iter = rpc_operators_.find(cache_operation);
1158 if (iter == rpc_operators_.end()) {
1159 MS_LOG(ERROR) << "Can not find rpc operator for cache operation: " << cache_operation;
1160 return false;
1161 }
1162
1163 const std::vector<SendRecvPairList> &send_recv_pair_lists = iter->second;
1164 const SenderPtr &sender = send_recv_pair_lists[server_rank_id][param_key].first;
1165 MS_ERROR_IF_NULL(sender);
1166
1167 int64_t ids_num = SizeToLong(keys_len / sizeof(int));
1168 ShapeVector ids_shape = {ids_num};
1169 ShapeVector values_shape;
1170 float fake_value = 0.0;
1171
1172 if (values == nullptr && values_len == 0) {
1173 values_shape = {1, 1};
1174 values = &fake_value;
1175 values_len = sizeof(fake_value);
1176 } else {
1177 MS_EXCEPTION_IF_ZERO("embedding_dim", embedding_dim);
1178 int64_t embed_vec_num = SizeToLong(values_len / sizeof(float) / embedding_dim);
1179 if (embed_vec_num != ids_num) {
1180 MS_LOG(EXCEPTION) << "The embedding vector number[" << embed_vec_num << "] shouled be equal to ids number["
1181 << ids_num << "] which will be send to remote.";
1182 }
1183 values_shape = {embed_vec_num, SizeToLong(embedding_dim)};
1184 }
1185
1186 std::vector<ShapeVector> shapes = {ids_shape, values_shape, {static_cast<int64_t>(1)}};
1187 std::vector<TypeId> data_types = {kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeInt32};
1188
1189 int32_t service_id = GetCacheOpsServiceId(cache_operation, param_key);
1190 AddressPtrList data_list = {std::make_shared<Address>(const_cast<void *>(keys), keys_len),
1191 std::make_shared<Address>(const_cast<void *>(values), values_len),
1192 std::make_shared<Address>(&service_id, sizeof(int32_t))};
1193
1194 // Send data.
1195 return sender->Send(shapes, data_types, data_list, finalize_remote, sync);
1196 }
1197
ReceiveFromRemote(const std::string & cache_operation,int32_t param_key,size_t server_rank_id) const1198 std::unique_ptr<std::vector<char>> EmbeddingCachePrefetchActor::ReceiveFromRemote(const std::string &cache_operation,
1199 int32_t param_key,
1200 size_t server_rank_id) const {
1201 // Find receiver corresponding to cache operation and parameter key.
1202 auto iter = rpc_operators_.find(cache_operation);
1203 if (iter == rpc_operators_.end()) {
1204 MS_LOG(ERROR) << "Can not find rpc operator for cache operation: " << cache_operation;
1205 return nullptr;
1206 }
1207
1208 const std::vector<SendRecvPairList> &send_recv_pair_lists = iter->second;
1209 const ReceiverPtr &receiver = send_recv_pair_lists[server_rank_id][param_key].second;
1210 MS_EXCEPTION_IF_NULL(receiver);
1211 // Receive data.
1212 return receiver->Receive();
1213 }
1214
RetrieveEmbeddings(const int * ids,size_t ids_num,const std::vector<std::vector<int>> & slice_ids_list,const std::vector<std::unique_ptr<std::vector<char>>> & slice_embeddings_list,std::vector<float> * outputs) const1215 bool EmbeddingCachePrefetchActor::RetrieveEmbeddings(
1216 const int *ids, size_t ids_num, const std::vector<std::vector<int>> &slice_ids_list,
1217 const std::vector<std::unique_ptr<std::vector<char>>> &slice_embeddings_list, std::vector<float> *outputs) const {
1218 MS_ERROR_IF_NULL(ids);
1219 MS_ERROR_IF_NULL(outputs);
1220
1221 if (ids_num == 0) {
1222 MS_LOG(WARNING) << "The ids number is 0";
1223 return true;
1224 }
1225
1226 // Merge all slice ids and embedding data address into ids_to_addrs map.
1227 mindspore::HashMap<int, const float *> ids_to_addrs;
1228 size_t embedding_dim = outputs->size() / ids_num;
1229 size_t offset = 0;
1230 for (size_t i = 0; i < slice_ids_list.size(); i++) {
1231 const std::vector<int> &slice_ids = slice_ids_list[i];
1232 if (slice_ids.empty()) {
1233 continue;
1234 }
1235 const std::unique_ptr<std::vector<char>> &slice_embeddings = slice_embeddings_list[i];
1236 MS_ERROR_IF_NULL(slice_embeddings);
1237 const float *embeddings_data = reinterpret_cast<float *>(slice_embeddings->data());
1238 for (size_t j = 0; j < slice_ids.size(); j++) {
1239 (void)ids_to_addrs.emplace(slice_ids[j], embeddings_data + offset);
1240 offset += embedding_dim;
1241 }
1242 offset = 0;
1243 }
1244
1245 float *outputs_data = outputs->data();
1246 size_t dst_size = embedding_dim * sizeof(float);
1247 size_t src_size = dst_size;
1248 offset = 0;
1249
1250 // Retrieve embeddings by input ids order.
1251 for (size_t i = 0; i < ids_num; i++) {
1252 auto id = ids[i];
1253 auto iter = ids_to_addrs.find(id);
1254 if (iter == ids_to_addrs.end()) {
1255 MS_LOG(WARNING) << "Can not find id[" << id << "]";
1256 continue;
1257 }
1258
1259 auto ret = memcpy_s(outputs_data + offset, dst_size, iter->second, src_size);
1260 if (ret != 0) {
1261 MS_LOG(ERROR) << "Memcpy failed, errno[" << ret << "]";
1262 return false;
1263 }
1264 offset += embedding_dim;
1265 }
1266 return true;
1267 }
1268
SyncEmbeddingTable()1269 void EmbeddingCachePrefetchActor::SyncEmbeddingTable() {
1270 std::lock_guard<std::mutex> locker(sync_embedding_table_mutex_);
1271 // Do not synchronize in case of abnormally finalizing.
1272 if (!running_) {
1273 return;
1274 }
1275
1276 if (finish_sync_embedding_table_) {
1277 return;
1278 }
1279 if (!initialized_) {
1280 return;
1281 }
1282 if (!SyncHostEmbeddingTable()) {
1283 MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
1284 }
1285 if (!SyncDeviceEmbeddingTable()) {
1286 MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
1287 }
1288 finish_sync_embedding_table_ = true;
1289 }
1290
SyncHostEmbeddingTable()1291 bool EmbeddingCachePrefetchActor::SyncHostEmbeddingTable() {
1292 MS_ERROR_IF_NULL(embedding_cache_table_manager.host_hash_map_);
1293 const auto &ids_indices_pairs = embedding_cache_table_manager.host_hash_map_->Export();
1294 size_t swap_indices_lens = ids_indices_pairs.size();
1295 if (swap_indices_lens == 0) {
1296 return true;
1297 }
1298
1299 std::unique_ptr<int[]> host_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens);
1300 MS_ERROR_IF_NULL(host_to_server_ids_ptr);
1301 std::unique_ptr<int[]> host_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens);
1302 MS_ERROR_IF_NULL(host_to_server_indices_ptr);
1303 size_t idx = 0;
1304 MS_EXCEPTION_IF_NULL(emb_ops_);
1305 const auto &modified_ids = emb_ops_->modified_ids();
1306 for (const auto &item : ids_indices_pairs) {
1307 if (modified_ids.find(item.first) != modified_ids.end()) {
1308 host_to_server_ids_ptr[idx] = item.first;
1309 host_to_server_indices_ptr[idx++] = item.second;
1310 }
1311 }
1312 swap_indices_lens = idx;
1313 if (swap_indices_lens == 0) {
1314 return true;
1315 }
1316 for (const auto &item : embedding_cache_table_manager.hash_tables_) {
1317 const auto &hash_info = item.second;
1318 std::vector<float> swap_out_data;
1319 auto embedding_size = hash_info.embedding_size;
1320 swap_out_data.resize(swap_indices_lens * embedding_size);
1321 auto host_hash_table_addr = hash_info.host_address;
1322 MS_ERROR_IF_NULL(host_hash_table_addr);
1323 RETURN_IF_FALSE(LookupLocalHostCache(embedding_size, swap_indices_lens, host_hash_table_addr,
1324 host_to_server_indices_ptr.get(), swap_out_data.data()));
1325
1326 RETURN_IF_FALSE_WITH_LOG(
1327 PushEmbeddingsToRemote(hash_info.param_key_, host_to_server_ids_ptr.get(), swap_indices_lens,
1328 swap_out_data.data(), swap_out_data.size() * sizeof(float)),
1329 "Push embeddings to remote failed.");
1330 }
1331 return true;
1332 }
1333
SyncDeviceEmbeddingTable()1334 bool EmbeddingCachePrefetchActor::SyncDeviceEmbeddingTable() {
1335 const auto &device_hash_map = embedding_cache_table_manager.device_hash_map_;
1336 MS_ERROR_IF_NULL(device_hash_map);
1337 const auto &ids_indices_pairs = device_hash_map->Export();
1338 size_t swap_indices_lens = ids_indices_pairs.size();
1339 if (swap_indices_lens == 0) {
1340 return true;
1341 }
1342 MS_ERROR_IF_NULL(device_context_);
1343 MS_ERROR_IF_NULL(device_context_->device_res_manager_);
1344 std::unique_ptr<int[]> device_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens);
1345 MS_ERROR_IF_NULL(device_to_server_ids_ptr);
1346 std::unique_ptr<int[]> device_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens);
1347 MS_ERROR_IF_NULL(device_to_server_indices_ptr);
1348 size_t idx = 0;
1349 for (const auto &item : ids_indices_pairs) {
1350 device_to_server_ids_ptr[idx] = item.first;
1351 device_to_server_indices_ptr[idx++] = item.second;
1352 }
1353 for (const auto &item : embedding_cache_table_manager.hash_tables_) {
1354 const auto &hash_info = item.second;
1355 std::vector<float> swap_out_data;
1356 auto embedding_size = hash_info.embedding_size;
1357 swap_out_data.resize(swap_indices_lens * embedding_size);
1358 std::unique_ptr<float[]> device_hash_table_addr_tmp =
1359 std::make_unique<float[]>(device_hash_map->hash_capacity() * embedding_size);
1360 MS_ERROR_IF_NULL(device_hash_table_addr_tmp);
1361
1362 auto hash_table_addr = reinterpret_cast<float *>(hash_info.address.addr);
1363 MS_ERROR_IF_NULL(hash_table_addr);
1364 auto hash_table_size = hash_info.address.size;
1365 RETURN_IF_FALSE_WITH_LOG(
1366 DeviceEmbeddingOperation::MemcpyDeviceToHostAsync(device_hash_table_addr_tmp.get(), hash_table_addr,
1367 hash_table_size, device_context_, stream_id_),
1368 "Memcpy device to host asynchronously failed.");
1369 RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_),
1370 "Synchronize stream failed.");
1371 RETURN_IF_FALSE(LookupLocalHostCache(embedding_size, swap_indices_lens, device_hash_table_addr_tmp.get(),
1372 device_to_server_indices_ptr.get(), swap_out_data.data()));
1373
1374 RETURN_IF_FALSE_WITH_LOG(
1375 PushEmbeddingsToRemote(hash_info.param_key_, device_to_server_ids_ptr.get(), swap_indices_lens,
1376 swap_out_data.data(), swap_out_data.size() * sizeof(float)),
1377 "Push embeddings to remote failed.");
1378 }
1379 return true;
1380 }
1381
FinalizeRemote()1382 bool EmbeddingCachePrefetchActor::FinalizeRemote() {
1383 for (size_t i = 0; i < server_num_; i++) {
1384 size_t embedding_dim = 1;
1385 int id = 0;
1386 float value = 0.0;
1387 RETURN_IF_FALSE_WITH_LOG(SendToRemote(distributed::kLookupEmbeddingCache, 0, i, embedding_dim, &id, sizeof(int),
1388 &value, sizeof(float), true),
1389 "Send finalize request to remote failed.");
1390 }
1391
1392 return true;
1393 }
1394
channel_name()1395 const std::string &EmbeddingCachePrefetchActor::channel_name() {
1396 std::lock_guard<std::mutex> locker(channel_mutex_);
1397 return channel_name_;
1398 }
1399
set_channel_name(const std::string & channel_name)1400 void EmbeddingCachePrefetchActor::set_channel_name(const std::string &channel_name) {
1401 if (channel_name_ == channel_name) {
1402 return;
1403 }
1404 std::lock_guard<std::mutex> locker(channel_mutex_);
1405 channel_name_ = channel_name;
1406 }
1407
WaitDataChannelInit()1408 void EmbeddingCachePrefetchActor::WaitDataChannelInit() {
1409 MS_LOG(INFO) << "Begin wait embedding cache data channel init.";
1410 auto channel = channel_name();
1411 if (channel.empty()) {
1412 std::unique_lock<std::mutex> locker(data_mutex_);
1413 data_parser_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
1414 if (!running_) {
1415 return;
1416 }
1417 }
1418 MS_LOG(INFO) << "End wait embedding cache data channel init.";
1419 }
1420
WaitInitParametersOnRemote()1421 void EmbeddingCachePrefetchActor::WaitInitParametersOnRemote() {
1422 std::unique_lock<std::mutex> locker(data_mutex_);
1423 // Note: wait to finish embedding lookup from remote.
1424 finish_init_parameters_on_remote_ = true;
1425 data_parser_.notify_one();
1426 }
1427
SetErrorInfo(const std::string & error_info)1428 void EmbeddingCachePrefetchActor::SetErrorInfo(const std::string &error_info) {
1429 static std::mutex mtx;
1430 std::lock_guard<std::mutex> lock(mtx);
1431 error_info_ = error_info;
1432 }
1433
BuildRpcOperators()1434 void EmbeddingCachePrefetchActor::BuildRpcOperators() {
1435 // The cache operation support LookupEmbeddingCache and UpdateEmbeddingCache currently.
1436 for (const auto &cache_op : distributed::kEmbeddingCacheOps) {
1437 rpc_operators_[cache_op] = std::vector<SendRecvPairList>();
1438 rpc_operators_[cache_op].resize(server_num_);
1439 }
1440
1441 auto node = distributed::cluster::ClusterContext::instance()->node();
1442 MS_EXCEPTION_IF_NULL(node);
1443 uint32_t worker_rank_id = node->rank_id();
1444
1445 // Create sender and receiver pairs for different cache operation, server and parameter key.
1446 for (auto &item : rpc_operators_) {
1447 const std::string &cache_op = item.first;
1448 std::vector<SendRecvPairList> &send_recv_pair_lists = item.second;
1449 for (uint32_t i = 0; i < server_num_; i++) {
1450 SendRecvPairList &send_recv_pair_list = send_recv_pair_lists[i];
1451 send_recv_pair_list.resize(embedding_cache_table_manager.hash_tables_.size());
1452
1453 for (const auto &table : embedding_cache_table_manager.hash_tables_) {
1454 int32_t key = table.second.param_key_;
1455 if (key >= SizeToInt(embedding_cache_table_manager.hash_tables_.size()) || key < 0) {
1456 MS_LOG(EXCEPTION) << "Invalid parameter key: " << key;
1457 }
1458
1459 send_recv_pair_list[key] = CreateSenderReceiverPair(worker_rank_id, i, cache_op, key, cpu_device_context_);
1460 }
1461 }
1462 }
1463 }
1464
LinkRpcOperators()1465 void EmbeddingCachePrefetchActor::LinkRpcOperators() {
1466 std::vector<SenderPtr> senders;
1467 std::vector<ReceiverPtr> receivers;
1468 for (const auto &item : rpc_operators_) {
1469 const std::vector<SendRecvPairList> &send_recv_pair_lists = item.second;
1470 for (const SendRecvPairList &send_recv_pair_list : send_recv_pair_lists) {
1471 for (const SendRecvPair &pair : send_recv_pair_list) {
1472 senders.push_back(pair.first);
1473 receivers.push_back(pair.second);
1474 }
1475 }
1476 }
1477
1478 // Must start server and register route table before looking up route and connecting.
1479 // Start servers of receiver and register route table.
1480 for (auto &receiver : receivers) {
1481 MS_EXCEPTION_IF_NULL(receiver);
1482 if (!receiver->StartServer()) {
1483 MS_LOG(EXCEPTION) << "Failed to start server for the receiver.";
1484 }
1485 }
1486
1487 // Lookup route and connect to servers for sender.
1488 for (auto &sender : senders) {
1489 MS_EXCEPTION_IF_NULL(sender);
1490 if (!sender->ConnectServer()) {
1491 MS_LOG(EXCEPTION) << "Failed to connect servers for the sender.";
1492 }
1493 }
1494 }
1495
Send(const std::vector<ShapeVector> & shapes,const std::vector<TypeId> data_types,const AddressPtrList & data_list,bool finalize_remote,bool sync) const1496 bool Sender::Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
1497 const AddressPtrList &data_list, bool finalize_remote, bool sync) const {
1498 MS_ERROR_IF_NULL(receiver_);
1499 auto message = BuildRpcMessage(shapes, data_types, data_list, receiver_->get_url(), server_url_, finalize_remote);
1500 MS_ERROR_IF_NULL(message);
1501 MS_ERROR_IF_NULL(client_);
1502 if (sync) {
1503 return client_->SendSync(std::move(message));
1504 }
1505
1506 client_->SendAsync(std::move(message));
1507 return true;
1508 }
1509
~Sender()1510 Sender::~Sender() {
1511 if (client_) {
1512 try {
1513 if (!client_->Disconnect(server_url_)) {
1514 MS_LOG(ERROR) << "Failed to disconnect tcp client.";
1515 }
1516 client_->Finalize();
1517 client_ = nullptr;
1518 } catch (const std::exception &e) {
1519 MS_LOG(ERROR) << "Failed to disconnect and finalize tcp client, error message: " << e.what();
1520 }
1521 }
1522 receiver_ = nullptr;
1523 }
1524
ConnectServer()1525 bool Sender::ConnectServer() {
1526 client_ = std::make_unique<TCPClient>();
1527 MS_ERROR_IF_NULL(client_);
1528 if (!client_->Initialize()) {
1529 MS_LOG(ERROR) << "Failed to initialize tcp server for send actor.";
1530 return false;
1531 }
1532
1533 // Lookup peer receiver addresses.
1534 MS_ERROR_IF_NULL(route_table_proxy_);
1535 auto peer_actor_address = route_table_proxy_->LookupRoute(inter_process_edge_);
1536 server_url_ = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
1537
1538 auto free_callback = std::bind(&Sender::FreeMessage, this, std::placeholders::_1);
1539 size_t retry_count = 60;
1540
1541 bool ret = client_->Connect(server_url_, retry_count, free_callback);
1542 if (!ret) {
1543 MS_LOG(ERROR) << "Failed to connect to server of edge: " << inter_process_edge_ << ", server_url: " << server_url_;
1544 return false;
1545 }
1546
1547 MS_LOG(INFO) << "Successfully connect to server " << server_url_
1548 << ", inter process edge name: " << inter_process_edge_;
1549 return true;
1550 }
1551
BuildRpcMessage(const std::vector<ShapeVector> & shapes,const std::vector<TypeId> data_types,const AddressPtrList & data_list,const std::string & from_url,const std::string & to_url,bool finalize_remote) const1552 std::unique_ptr<MessageBase> Sender::BuildRpcMessage(const std::vector<ShapeVector> &shapes,
1553 const std::vector<TypeId> data_types,
1554 const AddressPtrList &data_list, const std::string &from_url,
1555 const std::string &to_url, bool finalize_remote) const {
1556 std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
1557 MS_ERROR_IF_NULL_W_RET_VAL(message, nullptr);
1558 message->from = AID("", from_url);
1559 message->to = AID("", to_url);
1560
1561 if (shapes.size() != data_list.size()) {
1562 MS_LOG(ERROR) << "The shape list size[" << shapes.size() << "] should be equal to data list size["
1563 << data_list.size() << "]";
1564 }
1565
1566 if (data_types.size() != data_list.size()) {
1567 MS_LOG(ERROR) << "The date type list size[" << data_types.size() << "] should be equal to data list size["
1568 << data_list.size() << "]";
1569 }
1570
1571 RpcDataPtr rpc_data = nullptr;
1572 size_t data_size = CalDataSize(shapes, data_types, data_list, finalize_remote);
1573 MS_EXCEPTION_IF_NULL(cpu_device_context_);
1574 MS_EXCEPTION_IF_NULL(cpu_device_context_->device_res_manager_);
1575 rpc_data = static_cast<RpcDataPtr>(cpu_device_context_->device_res_manager_->AllocateMemory(data_size));
1576 MS_EXCEPTION_IF_NULL(rpc_data);
1577 message->data = rpc_data;
1578 message->size = data_size;
1579
1580 errno_t ret;
1581 size_t offset = 0;
1582 for (size_t i = 0; i < data_list.size(); i++) {
1583 const ShapeVector &shape = shapes[i];
1584 const AddressPtr &data = data_list[i];
1585 const TypeId &type_id = data_types[i];
1586
1587 rpc::DynamicShapeMessage ds_pb_msg;
1588 ds_pb_msg.set_type_id(type_id);
1589 *ds_pb_msg.mutable_shape_vector() = {shape.begin(), shape.end()};
1590 std::string ds_pb_msg_str = ds_pb_msg.SerializeAsString();
1591
1592 // Message format:
1593 // |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|---real data----|
1594 // 1. The dynamic shape header.
1595 if ((ret = memcpy_s(rpc_data + offset, strlen(kRpcDynamicShapeData), kRpcDynamicShapeData,
1596 strlen(kRpcDynamicShapeData))) != EOK) {
1597 MS_LOG(EXCEPTION) << "Failed to memcpy_s for kRpcDynamicShapeData, errno[" << ret << "].";
1598 }
1599 offset += strlen(kRpcDynamicShapeData);
1600
1601 // 2. The size of the protobuf DynamicShapeMessage.
1602 size_t ds_pb_msg_size = ds_pb_msg_str.size();
1603 if ((ret = memcpy_s(rpc_data + offset, sizeof(ds_pb_msg_size), &ds_pb_msg_size, sizeof(ds_pb_msg_size))) != EOK) {
1604 MS_LOG(EXCEPTION) << "Failed to memcpy_s for pb message size, errno[" << ret << "].";
1605 }
1606 offset += sizeof(ds_pb_msg_size);
1607
1608 // 3. Protobuf DynamicShapeMessage.
1609 if ((ret = memcpy_s(rpc_data + offset, ds_pb_msg_str.size(), ds_pb_msg_str.c_str(), ds_pb_msg_str.size())) != EOK) {
1610 MS_LOG(EXCEPTION) << "Failed to memcpy_s for pb message, errno[" << ret << "].";
1611 }
1612 offset += ds_pb_msg_str.size();
1613
1614 // 4. The real data buffer need to be sent.
1615 MS_EXCEPTION_IF_NULL(data);
1616 if ((ret = memcpy_s(rpc_data + offset, data->size, data->addr, data->size)) != EOK) {
1617 MS_LOG(EXCEPTION) << "Failed to memcpy_s for real data, errno[" << ret << "].";
1618 }
1619 offset += data->size;
1620 }
1621
1622 // 5. Finalize remote command.
1623 if (finalize_remote) {
1624 size_t header_len = strlen(distributed::kFinalizeMuxRecvActor);
1625 if ((ret = memcpy_s(rpc_data + offset, header_len, distributed::kFinalizeMuxRecvActor, header_len)) != EOK) {
1626 MS_LOG(EXCEPTION) << "Failed to memcpy_s for kFinalizeMuxRecvActor, errno[" << ret << "].";
1627 }
1628 offset += header_len;
1629
1630 if ((ret = memcpy_s(rpc_data + offset, sizeof(finalize_remote), &finalize_remote, sizeof(finalize_remote))) !=
1631 EOK) {
1632 MS_LOG(EXCEPTION) << "Failed to memcpy_s for finalize_remote, errno[" << ret << "].";
1633 }
1634 }
1635
1636 return message;
1637 }
1638
FreeMessage(void * data)1639 bool Sender::FreeMessage(void *data) {
1640 MS_EXCEPTION_IF_NULL(cpu_device_context_);
1641 MS_EXCEPTION_IF_NULL(cpu_device_context_->device_res_manager_);
1642 MS_ERROR_IF_NULL_W_RET_VAL(data, false);
1643 cpu_device_context_->device_res_manager_->FreeMemory(data);
1644 return true;
1645 }
1646
CalDataSize(const std::vector<ShapeVector> & shapes,const std::vector<TypeId> data_types,const AddressPtrList & data_list,bool finalize_remote) const1647 size_t Sender::CalDataSize(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
1648 const AddressPtrList &data_list, bool finalize_remote) const {
1649 size_t data_size = 0;
1650 for (size_t i = 0; i < data_list.size(); i++) {
1651 const ShapeVector &shape = shapes[i];
1652 const AddressPtr &data = data_list[i];
1653 const TypeId &type_id = data_types[i];
1654
1655 rpc::DynamicShapeMessage ds_pb_msg;
1656 ds_pb_msg.set_type_id(type_id);
1657 *ds_pb_msg.mutable_shape_vector() = {shape.begin(), shape.end()};
1658 std::string ds_pb_msg_str = ds_pb_msg.SerializeAsString();
1659 data_size += strlen(kRpcDynamicShapeData);
1660 data_size += sizeof(size_t);
1661 data_size += ds_pb_msg_str.size();
1662 MS_EXCEPTION_IF_NULL(data);
1663 data_size += data->size;
1664 }
1665 if (finalize_remote) {
1666 data_size += strlen(distributed::kFinalizeMuxRecvActor);
1667 data_size += sizeof(finalize_remote);
1668 }
1669 return data_size;
1670 }
1671
~Receiver()1672 Receiver::~Receiver() {
1673 if (server_) {
1674 try {
1675 server_->Finalize();
1676 server_ = nullptr;
1677 } catch (const std::exception &e) {
1678 MS_LOG(ERROR) << "Failed to finalize tcp server, error message: " << e.what();
1679 }
1680 }
1681 received_buffer_ = nullptr;
1682 }
1683
Receive()1684 std::unique_ptr<std::vector<char>> Receiver::Receive() {
1685 std::unique_lock<std::mutex> locker(received_msg_mtx_);
1686 // The maximum time(300 seconds) to wait to receive message.
1687 const int64_t longest_time_to_wait = 300;
1688 auto ret = received_msg_cv_.wait_for(locker, std::chrono::seconds(longest_time_to_wait),
1689 [this] { return received_msg_.load(); });
1690 if (!ret) {
1691 MS_LOG(ERROR) << "Receive message timeout";
1692 return nullptr;
1693 }
1694
1695 std::unique_ptr<std::vector<char>> output = std::move(received_buffer_);
1696 MS_EXCEPTION_IF_NULL(output);
1697 received_msg_ = false;
1698 return output;
1699 }
1700
StartServer()1701 bool Receiver::StartServer() {
1702 // 1. Create a tcp server and start listening.
1703 server_ = std::make_unique<TCPServer>();
1704 MS_EXCEPTION_IF_NULL(server_);
1705
1706 std::function<void *(size_t size)> allocate_callback =
1707 std::bind(&Receiver::AllocateMessage, this, std::placeholders::_1);
1708 if (!server_->Initialize(allocate_callback)) {
1709 MS_LOG(EXCEPTION) << "Failed to initialize tcp server for recv actor";
1710 }
1711 ip_ = server_->GetIP();
1712 port_ = server_->GetPort();
1713 std::string server_url = ip_ + ":" + std::to_string(port_);
1714
1715 // 2. Set the message handler of the server.
1716 server_->SetMessageHandler(std::bind(&Receiver::HandleMessage, this, std::placeholders::_1));
1717
1718 // 3. Register the server address to route table. The server should not be connected before this step is done.
1719 MS_LOG(INFO) << "Start server for receiver. Server address: " << server_url
1720 << ", inter process edge name: " << inter_process_edge_;
1721 distributed::cluster::topology::ActorAddress recv_actor_addresss;
1722 recv_actor_addresss.set_actor_id(inter_process_edge_);
1723 recv_actor_addresss.set_ip(ip_);
1724 recv_actor_addresss.set_port(port_);
1725 MS_EXCEPTION_IF_NULL(route_table_proxy_);
1726 if (!route_table_proxy_->RegisterRoute(inter_process_edge_, recv_actor_addresss)) {
1727 MS_LOG(EXCEPTION) << "Failed to register route for " << inter_process_edge_ << " " << server_url
1728 << " when starting server.";
1729 }
1730 return true;
1731 }
1732
ParseDynamicShapeData(const char * msg_body,size_t msg_len,std::pair<const void *,size_t> * data) const1733 bool Receiver::ParseDynamicShapeData(const char *msg_body, size_t msg_len,
1734 std::pair<const void *, size_t> *data) const {
1735 MS_ERROR_IF_NULL(msg_body);
1736 MS_ERROR_IF_NULL(data);
1737 // 1. Check whether received data is valid dynamic shape data.
1738 size_t dynamic_shape_header_size = strlen(kRpcDynamicShapeData);
1739 if (msg_len <= dynamic_shape_header_size) {
1740 MS_LOG(ERROR) << "Received data is not dynamic shape, data length: " << msg_len;
1741 return false;
1742 }
1743 std::string msg_dynamic_shape_header(msg_body, dynamic_shape_header_size);
1744 if (msg_dynamic_shape_header != kRpcDynamicShapeData) {
1745 MS_LOG(ERROR) << "Received data is not dynamic shape, not find dynamic shape header: " << kRpcDynamicShapeData;
1746 return false;
1747 }
1748
1749 size_t offset = dynamic_shape_header_size;
1750 // 2. Parse the size of dynamic shape serialized protobuf message.
1751 if (offset + sizeof(size_t) >= msg_len) {
1752 MS_LOG(ERROR) << "Received data is incomplete";
1753 return false;
1754 }
1755 size_t dynamic_shape_pb_size = *(reinterpret_cast<const size_t *>(msg_body + offset));
1756 offset += sizeof(size_t);
1757 if (offset + dynamic_shape_pb_size >= msg_len) {
1758 MS_LOG(ERROR) << "The dynamic shape pb data is incomplete";
1759 return false;
1760 }
1761
1762 // 3. Deserialize the dynamic shape serialized protobuf message.
1763 rpc::DynamicShapeMessage pb_msg;
1764 (void)pb_msg.ParseFromArray(msg_body + offset, dynamic_shape_pb_size);
1765 offset += dynamic_shape_pb_size;
1766 size_t received_data_len = msg_len - offset;
1767
1768 // 4. The data integrity check.
1769 ShapeVector shapes(pb_msg.shape_vector().begin(), pb_msg.shape_vector().end());
1770 TypeId data_type = static_cast<TypeId>(pb_msg.type_id());
1771 int64_t expected_data_len = 1;
1772 if (!kernel::GetShapeSize(shapes, TypeIdToType(data_type), &expected_data_len)) {
1773 MS_LOG(ERROR) << "Getting shape size for shape " << shapes << " failed.";
1774 return false;
1775 }
1776 if (LongToSize(expected_data_len) != received_data_len) {
1777 MS_LOG(ERROR) << "Received data is incomplete, expected size: " << expected_data_len
1778 << ", but received data size: " << received_data_len;
1779 return false;
1780 }
1781 // 5. Get real data addr and size.
1782 *data = std::make_pair(msg_body + offset, received_data_len);
1783 return true;
1784 }
1785
HandleMessage(MessageBase * const msg)1786 MessageBase *Receiver::HandleMessage(MessageBase *const msg) {
1787 if (msg == nullptr) {
1788 MS_LOG(WARNING) << "Received message pointer is nullptr";
1789 return distributed::rpc::NULL_MSG;
1790 }
1791
1792 RpcDataPtr data = static_cast<RpcDataPtr>(msg->data);
1793 size_t data_size = msg->size;
1794 // The data pair: <addr of data, size of data>.
1795 std::pair<const void *, size_t> real_data;
1796 // Get real data addr and size.
1797 if (!ParseDynamicShapeData(data, data_size, &real_data)) {
1798 MS_LOG(EXCEPTION) << "Parse dynamic shape data failed.";
1799 }
1800
1801 std::unique_lock<std::mutex> locker(received_msg_mtx_);
1802 received_buffer_ = std::make_unique<std::vector<char>>();
1803 received_buffer_->resize(real_data.second);
1804 MS_EXCEPTION_IF_NULL(real_data.first);
1805
1806 int ret = memcpy_s(received_buffer_->data(), received_buffer_->size(), real_data.first, real_data.second);
1807 if (ret != 0) {
1808 MS_LOG(EXCEPTION) << "Memcpy for received data failed, errno[" << ret << "]";
1809 }
1810
1811 received_msg_ = true;
1812 received_msg_cv_.notify_one();
1813
1814 MS_EXCEPTION_IF_NULL(cpu_device_context_);
1815 MS_EXCEPTION_IF_NULL(cpu_device_context_->device_res_manager_);
1816 cpu_device_context_->device_res_manager_->FreeMemory(data);
1817
1818 delete msg;
1819 return distributed::rpc::NULL_MSG;
1820 }
1821
AllocateMessage(size_t size)1822 void *Receiver::AllocateMessage(size_t size) {
1823 MS_EXCEPTION_IF_NULL(cpu_device_context_);
1824 MS_EXCEPTION_IF_NULL(cpu_device_context_->device_res_manager_);
1825 void *data = cpu_device_context_->device_res_manager_->AllocateMemory(size);
1826 MS_EXCEPTION_IF_NULL(data);
1827 return data;
1828 }
1829 } // namespace runtime
1830 } // namespace mindspore
1831