1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "runtime/graph_scheduler/embedding_cache_scheduler.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "ops/structure_op_name.h"
#include "ops/sparse_op_name.h"
#include "ops/math_op_name.h"
#include "ops/array_op_name.h"
#include "ops/framework_op_name.h"
#include "ops/other_op_name.h"
#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
#include "runtime/graph_scheduler/device_tensor_store.h"
#include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
#include "utils/ms_context.h"
#include "include/common/utils/parallel_context.h"
33
34 namespace mindspore {
35 namespace runtime {
36 using session::KernelGraph;
37 namespace {
CheckEnableEmbeddingCache()38 bool CheckEnableEmbeddingCache() {
39 return ps::PSContext::instance()->cache_enable() && distributed::cluster::ClusterContext::instance()->initialized() &&
40 ps::PSContext::instance()->is_worker();
41 }
42
CheckEmbeddingCacheServer()43 bool CheckEmbeddingCacheServer() {
44 return ps::PSContext::instance()->cache_enable() && distributed::cluster::ClusterContext::instance()->initialized() &&
45 ps::PSContext::instance()->is_server();
46 }
47
48 // Whether device address exist.
NodeDeviceAddressExist(const DeviceContext * device_context,const AnfNodePtr & kernel,size_t index)49 bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) {
50 MS_EXCEPTION_IF_NULL(kernel);
51 MS_EXCEPTION_IF_NULL(device_context);
52 if (AnfAlgo::OutputAddrExist(kernel, index)) {
53 const auto &address = AnfAlgo::GetOutputAddr(kernel, index, false);
54 MS_EXCEPTION_IF_NULL(address);
55 return address->GetDeviceType() == device_context->GetDeviceType();
56 }
57 return false;
58 }
59
60 // Finalize ps cache module before throw an exception.
FinalizeEmbeddingCachePrefetch(const std::string & exception)61 void FinalizeEmbeddingCachePrefetch(const std::string &exception) {
62 MS_LOG(INFO) << "Begin finalize the EmbeddingCacheScheduler.";
63 EmbeddingCacheScheduler::GetInstance().Finalize(false);
64 MS_LOG(INFO) << "End finalize the EmbeddingCacheScheduler.";
65 MS_LOG(EXCEPTION) << exception;
66 }
67
GetFirstEmbeddingCacheTableInfo(const KernelGraph & graph,AnfNodePtr * const first_cache_input_index,size_t * output_idx,size_t * const first_cache_size)68 void GetFirstEmbeddingCacheTableInfo(const KernelGraph &graph, AnfNodePtr *const first_cache_input_index,
69 size_t *output_idx, size_t *const first_cache_size) {
70 MS_EXCEPTION_IF_NULL(first_cache_input_index);
71 MS_EXCEPTION_IF_NULL(first_cache_size);
72 for (const auto &kernel : graph.execution_order()) {
73 MS_EXCEPTION_IF_NULL(kernel);
74 const mindspore::HashSet<std::string> kNeedCheckNodes = {kGatherOpName, kSparseGatherV2OpName, kGatherV2OpName,
75 kGatherV2DOpName, kMapTensorGetOpName};
76 auto kernel_name = common::AnfAlgo::GetCNodeName(kernel);
77 if (kNeedCheckNodes.find(kernel_name) == kNeedCheckNodes.end()) {
78 continue;
79 }
80 auto input_param = common::AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
81 auto input_index = common::AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
82 MS_EXCEPTION_IF_NULL(input_param.first);
83 MS_EXCEPTION_IF_NULL(input_index.first);
84 auto param_name = input_param.first->fullname_with_scope();
85 if (!embedding_cache_table_manager.IsEmbeddingCacheTable(param_name)) {
86 continue;
87 }
88 auto size = embedding_cache_table_manager.QueryHashTableSize(param_name);
89 while (input_index.first->isa<CNode>() &&
90 ((common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName) ||
91 (common::AnfAlgo::GetCNodeName(input_index.first) == kTensorMoveOpName))) {
92 input_index = common::AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
93 MS_EXCEPTION_IF_NULL(input_index.first);
94 }
95 auto cnode = common::AnfAlgo::IsGraphKernel(input_index.first)
96 ? common::AnfAlgo::GetOutputOfGraphkernel(input_index)
97 : input_index.first;
98 MS_EXCEPTION_IF_NULL(cnode);
99 if (!cnode->isa<CNode>()) {
100 FinalizeEmbeddingCachePrefetch("The EmbeddingLookup whose input index should be a CNode but got " +
101 cnode->fullname_with_scope());
102 }
103 if (!common::AnfAlgo::IsGetNextNode(cnode)) {
104 auto input_index_node_name = common::AnfAlgo::GetCNodeName(cnode);
105 bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
106 if ((!full_batch && (input_index_node_name != kUniqueOpName)) ||
107 (full_batch && (input_index_node_name != kMinimumOpName))) {
108 MS_LOG(ERROR) << "The input index of the EmbeddingLookup(" << kernel->fullname_with_scope()
109 << ") cache is from " << cnode->fullname_with_scope();
110 FinalizeEmbeddingCachePrefetch(
111 "The EmbeddingLookup whose input index isn't from dataset doesn't support cache in parameter server training "
112 "mode.");
113 }
114 }
115 *output_idx = input_index.second;
116 *first_cache_input_index = cnode;
117 *first_cache_size = size;
118 MS_LOG(INFO) << "The input index of the first EmbeddingLookup cache is from " << cnode->fullname_with_scope()
119 << ", the cache size is " << size;
120 return;
121 }
122 }
123
CheckSparseModeForEmbeddingCache(const CNodePtr & node)124 void CheckSparseModeForEmbeddingCache(const CNodePtr &node) {
125 MS_EXCEPTION_IF_NULL(node);
126 auto pre_node = common::AnfAlgo::GetPrevNodeOutput(node, 1, true);
127 MS_EXCEPTION_IF_NULL(pre_node.first);
128 while (pre_node.first->isa<CNode>() && (common::AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
129 pre_node = common::AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
130 MS_EXCEPTION_IF_NULL(pre_node.first);
131 }
132 if (!(pre_node.first->isa<CNode>()) || (common::AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
133 FinalizeEmbeddingCachePrefetch(std::string("The input_indices of kernel[") + node->DebugString() +
134 "] must be unique in parameter server cache mode");
135 }
136
137 pre_node = common::AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
138 MS_EXCEPTION_IF_NULL(pre_node.first);
139 while (pre_node.first->isa<CNode>() && ((common::AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName) ||
140 (common::AnfAlgo::GetCNodeName(pre_node.first) == kTensorMoveOpName))) {
141 pre_node = common::AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
142 MS_EXCEPTION_IF_NULL(pre_node.first);
143 }
144 if (!(pre_node.first->isa<CNode>()) || (!common::AnfAlgo::IsGetNextNode(pre_node.first))) {
145 FinalizeEmbeddingCachePrefetch(
146 "The input indices of kernel[Unique] must be produced from dataset directly and the indices value can not be "
147 "changed before delivering to kernel[Unique] in parameter server cache mode.");
148 }
149 }
150
ShouldCheckSparseMode(const std::string & param_name,const std::string & kernel_name,bool is_sparse_gather)151 bool ShouldCheckSparseMode(const std::string ¶m_name, const std::string &kernel_name, bool is_sparse_gather) {
152 if (embedding_cache_table_manager.IsEmbeddingCacheTable(param_name)) {
153 if (kernel_name == kSparseGatherV2OpName || is_sparse_gather) {
154 return true;
155 }
156 }
157 return false;
158 }
159
CheckGraphValidForEmbeddingCache(const KernelGraph & graph)160 void CheckGraphValidForEmbeddingCache(const KernelGraph &graph) {
161 AnfNodePtr first_cache_input_index = nullptr;
162 size_t first_cache_size = 0;
163 size_t output_idx = 0;
164 GetFirstEmbeddingCacheTableInfo(graph, &first_cache_input_index, &output_idx, &first_cache_size);
165 MS_EXCEPTION_IF_NULL(first_cache_input_index);
166 for (const auto &kernel : graph.execution_order()) {
167 MS_EXCEPTION_IF_NULL(kernel);
168 auto kernel_name = common::AnfAlgo::GetCNodeName(kernel);
169 const mindspore::HashSet<std::string> kNeedCacheNodes = {kGatherOpName, kSparseGatherV2OpName, kGatherV2OpName,
170 kGatherV2DOpName, kMapTensorGetOpName};
171 if (kNeedCacheNodes.count(kernel_name) == 0) {
172 continue;
173 }
174 auto input_param = common::AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
175 auto input_index = common::AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
176 MS_EXCEPTION_IF_NULL(input_param.first);
177 MS_EXCEPTION_IF_NULL(input_index.first);
178 if (!input_param.first->isa<Parameter>()) {
179 continue;
180 }
181 auto param_name = input_param.first->fullname_with_scope();
182 // In ascend, change kSparseGatherV2OpName to kGatherV2OpName & set attr sparse: true
183 bool is_sparse_gather = false;
184 if (kernel_name == kGatherV2OpName && common::AnfAlgo::HasNodeAttr(kAttrIsSparse, kernel)) {
185 is_sparse_gather = common::AnfAlgo::GetNodeAttr<bool>(kernel, kAttrIsSparse);
186 }
187 if (ShouldCheckSparseMode(param_name, kernel_name, is_sparse_gather)) {
188 CheckSparseModeForEmbeddingCache(kernel);
189 }
190 while (input_index.first->isa<CNode>() &&
191 ((common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName) ||
192 (common::AnfAlgo::GetCNodeName(input_index.first) == kTensorMoveOpName))) {
193 input_index = common::AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
194 MS_EXCEPTION_IF_NULL(input_index.first);
195 }
196 auto cnode = common::AnfAlgo::IsGraphKernel(input_index.first)
197 ? common::AnfAlgo::GetOutputOfGraphkernel(input_index)
198 : input_index.first;
199 MS_EXCEPTION_IF_NULL(cnode);
200 if (cnode == first_cache_input_index && input_index.second == output_idx) {
201 if (!embedding_cache_table_manager.IsEmbeddingCacheTable(param_name)) {
202 MS_LOG(ERROR) << "The EmbeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
203 FinalizeEmbeddingCachePrefetch(
204 "All the embeddingLookups whose input indices are from dataset must enable cache at the same time when one "
205 "of them enables cache in parameter server training mode.");
206 }
207 auto size = embedding_cache_table_manager.QueryHashTableSize(param_name);
208 if (size != first_cache_size) {
209 MS_LOG(ERROR) << "The cache size(" << size << ") of EmbeddingLookup(" << kernel->fullname_with_scope()
210 << ") is not the same as other EmbeddingLookup cache size(" << first_cache_size << ").";
211 FinalizeEmbeddingCachePrefetch(
212 "The cache sizes of embeddingLookups are not the same in parameter server training mode.");
213 }
214 } else if (embedding_cache_table_manager.IsEmbeddingCacheTable(param_name)) {
215 MS_LOG(ERROR) << "The input index of the EmbeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
216 << cnode->fullname_with_scope();
217 FinalizeEmbeddingCachePrefetch(
218 "The EmbeddingLookup whose input index isn't from dataset doesn't support cache in parameter server training "
219 "mode.");
220 } else if (cnode->isa<CNode>() && (common::AnfAlgo::IsGetNextNode(cnode)) && (input_index.second == output_idx)) {
221 MS_LOG(ERROR) << "The EmbeddingLookup kernel(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
222 FinalizeEmbeddingCachePrefetch(
223 "All EmbeddingLookup kernels whose input indices are from dataset must enable cache at the same time.");
224 }
225 }
226 }
227 } // namespace
228
GetInstance()229 EmbeddingCacheScheduler &EmbeddingCacheScheduler::GetInstance() {
230 static EmbeddingCacheScheduler instance{};
231 if (!instance.initialized_) {
232 instance.Initialize();
233 }
234 return instance;
235 }
236
Initialize()237 void EmbeddingCacheScheduler::Initialize() {
238 if (!CheckEnableEmbeddingCache()) {
239 return;
240 }
241 if (initialized_) {
242 return;
243 }
244
245 // Get or Create device context.
246 auto ms_context = MsContext::GetInstance();
247 MS_EXCEPTION_IF_NULL(ms_context);
248 std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
249 uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
250 DeviceContext *device_context =
251 device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
252 MS_EXCEPTION_IF_NULL(device_context);
253 device_context->Initialize();
254
255 // Create and initialize EmbeddingCachePrefetchActor.
256 embedding_cache_prefetch_actor_ = std::make_shared<EmbeddingCachePrefetchActor>(device_context);
257 MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
258
259 initialized_ = true;
260 }
261
ParseBatchIdsNum(const KernelGraphPtr & graph)262 void EmbeddingCacheScheduler::ParseBatchIdsNum(const KernelGraphPtr &graph) {
263 if (parsed_batch_ids_num_) {
264 return;
265 }
266
267 // 1. Find InitDataSetQueue kernel.
268 MS_EXCEPTION_IF_NULL(graph);
269 const auto &kernels = graph->execution_order();
270 auto iter = find_if(kernels.begin(), kernels.end(), [](const CNodePtr &kernel) {
271 MS_EXCEPTION_IF_NULL(kernel);
272 return common::AnfAlgo::GetCNodeName(kernel) == kInitDatasetQueueOpName;
273 });
274 if (iter == kernels.end()) {
275 MS_LOG(EXCEPTION) << "Can not find InitDataSetQueue kernel";
276 }
277
278 const auto &kernel = *iter;
279 MS_EXCEPTION_IF_NULL(kernel);
280
281 // 2. Get shape of InitDataSetQueue kernel.
282 std::vector<std::vector<int64_t>> shapes;
283 if (common::AnfAlgo::IsDynamicShape(kernel)) {
284 shapes = common::AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel, "max_shapes");
285 } else {
286 shapes = common::AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel, "shapes");
287 }
288 auto types = common::AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel, "types");
289 if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) {
290 MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size "
291 << types;
292 }
293
294 const TypePtr &id_type = types.front();
295 MS_EXCEPTION_IF_NULL(id_type);
296 if (id_type->type_id() != kInt32->type_id() && id_type->type_id() != kInt->type_id()) {
297 MS_LOG(EXCEPTION) << "Embedding cache mode need input ids with data type[" << kInt32->ToString() << " or "
298 << kInt->ToString() << "], but got[" << id_type->ToString() << "]";
299 }
300
301 // 3. Get batch ids num(not batch size).
302 const auto &shape = shapes[0];
303 auto batch_ids_num = LongToSize(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
304 embedding_cache_table_manager.Initialize();
305 embedding_cache_table_manager.set_batch_ids_num(batch_ids_num);
306
307 parsed_batch_ids_num_ = true;
308 }
309
AllocMemForEmbeddingCacheTable(const DeviceContext * device_context)310 void EmbeddingCacheScheduler::AllocMemForEmbeddingCacheTable(const DeviceContext *device_context) {
311 if (allocated_embed_cache_mem_) {
312 return;
313 }
314
315 embedding_cache_table_manager.AllocMemForEmbedding(device_context);
316 allocated_embed_cache_mem_ = true;
317 }
318
SetEmbedCachedParamAddress(const DeviceContext * device_context,const KernelGraphPtr & graph)319 void EmbeddingCacheScheduler::SetEmbedCachedParamAddress(const DeviceContext *device_context,
320 const KernelGraphPtr &graph) {
321 if (!CheckEnableEmbeddingCache()) {
322 return;
323 }
324
325 MS_EXCEPTION_IF_NULL(device_context);
326 MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
327 MS_EXCEPTION_IF_NULL(graph);
328
329 // 1. Get batch ids number before allocate device memory for embedding cache table.
330 ParseBatchIdsNum(graph);
331
332 const std::vector<AnfNodePtr> &input_nodes = graph->input_nodes();
333 bool exist_embedding_cache_table = std::any_of(input_nodes.begin(), input_nodes.end(), [](const AnfNodePtr &node) {
334 MS_EXCEPTION_IF_NULL(node);
335 return embedding_cache_table_manager.IsEmbeddingCacheTable(node->fullname_with_scope());
336 });
337 if (!exist_embedding_cache_table) {
338 return;
339 }
340
341 // Graph valid check.
342 // The sparse mode does not perform graph structure verification currently.
343 if (!embedding_cache_table_manager.is_sparse_format()) {
344 CheckGraphValidForEmbeddingCache(*graph);
345 }
346
347 // 2. Set parameter device address to embedding cache table.
348 for (const auto &node : input_nodes) {
349 MS_EXCEPTION_IF_NULL(node);
350 const std::string ¶m_name = node->fullname_with_scope();
351 if (!embedding_cache_table_manager.IsEmbeddingCacheTable(param_name)) {
352 continue;
353 }
354
355 if (node->isa<Parameter>() && !NodeDeviceAddressExist(device_context, node, 0)) {
356 MS_LOG_WITH_NODE(EXCEPTION, node) << "Not found device address for parameter: " << node->fullname_with_scope();
357 }
358
359 if (embedding_cache_table_manager.QueryEmbeddingDeviceAddress(param_name) == nullptr) {
360 embedding_cache_table_manager.SetEmbeddingDeviceAddress(param_name, AnfAlgo::GetMutableOutputAddr(node, 0).get());
361 }
362 }
363
364 // 3. Allocate device memory for embedding cache table.
365 AllocMemForEmbeddingCacheTable(device_context);
366 }
367
SetDataSetChannel(const std::string & actor_id,const std::vector<KernelGraphPtr> & graphs)368 void EmbeddingCacheScheduler::SetDataSetChannel(const std::string &actor_id,
369 const std::vector<KernelGraphPtr> &graphs) {
370 if (!CheckEnableEmbeddingCache()) {
371 return;
372 }
373
374 for (const auto &graph : graphs) {
375 MS_EXCEPTION_IF_NULL(graph);
376 for (const auto &kernel_node : graph->execution_order()) {
377 if (!common::AnfAlgo::IsGetNextNode(kernel_node)) {
378 continue;
379 }
380
381 if (!common::AnfAlgo::HasNodeAttr("shared_name", kernel_node)) {
382 MS_LOG(EXCEPTION) << "Can not find attr[shared_name] of GetNext";
383 }
384 (void)data_prepare_aid_to_data_channel_.emplace(
385 actor_id, common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "shared_name"));
386 break;
387 }
388 }
389 }
390
InitEmbeddingStorage(const std::vector<AnfNodePtr> & parameters) const391 void EmbeddingCacheScheduler::InitEmbeddingStorage(const std::vector<AnfNodePtr> ¶meters) const {
392 if (!CheckEmbeddingCacheServer()) {
393 return;
394 }
395
396 for (size_t i = 0; i < parameters.size(); i++) {
397 MS_EXCEPTION_IF_NULL(parameters[i]);
398 if (!parameters[i]->isa<Parameter>()) {
399 MS_LOG_WITH_NODE(EXCEPTION, parameters[i])
400 << "The node with name: " << parameters[i]->fullname_with_scope() << "is not a Parameter.";
401 }
402
403 ParameterPtr param = parameters[i]->cast<ParameterPtr>();
404 MS_EXCEPTION_IF_NULL(param);
405 auto param_info = param->param_info();
406 // Check whether enable embedding storage for the parameter.
407 bool enable_embedding_storage =
408 param_info && param_info->key() != -1 && param_info->cache_enable() && param_info->use_persistent_storage();
409 if (!enable_embedding_storage) {
410 continue;
411 }
412
413 auto embed_storage = embedding_storage_manager.Get(param_info->key());
414 MS_EXCEPTION_IF_NULL(embed_storage);
415 std::vector<DeviceTensorPtr> device_tensors = DeviceTensorStore::GetInstance().Fetch(param.get());
416 if (device_tensors.size() != 1) {
417 MS_LOG(EXCEPTION)
418 << "The device tensor size for embedding table which enables embedding storage should be 1, but got:"
419 << device_tensors.size();
420 }
421
422 // Initialize embedding storage instance.
423 embed_storage->Initialize(device_tensors.front().get());
424 }
425 }
426
Schedule()427 void EmbeddingCacheScheduler::Schedule() {
428 if (!initialized_ || scheduled_) {
429 return;
430 }
431
432 // 1. Initialize embedding cache prefetch actor and build network connection inter process.
433 MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
434 embedding_cache_prefetch_actor_->Initialize();
435
436 // 2. Spawn embedding cache prefetch actor.
437 auto actor_manager = ActorMgr::GetActorMgrRef();
438 // Bind single thread to execute embedding cache prefetch actor.
439 (void)actor_manager->Spawn(embedding_cache_prefetch_actor_, false);
440
441 // 3. Run embedding cache prefetch actor.
442 ActorDispatcher::Send(embedding_cache_prefetch_actor_->GetAID(), &EmbeddingCachePrefetchActor::Run);
443
444 scheduled_ = true;
445 }
446
IncreaseGraphStep(const std::string & actor_id) const447 void EmbeddingCacheScheduler::IncreaseGraphStep(const std::string &actor_id) const {
448 if (!CheckEnableEmbeddingCache()) {
449 return;
450 }
451
452 auto iter = data_prepare_aid_to_data_channel_.find(actor_id);
453 if (iter != data_prepare_aid_to_data_channel_.end()) {
454 MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
455 embedding_cache_prefetch_actor_->IncreaseGraphStep(iter->second);
456 }
457 }
458
SyncEmbeddingTable() const459 void EmbeddingCacheScheduler::SyncEmbeddingTable() const {
460 MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
461 embedding_cache_prefetch_actor_->SyncEmbeddingTable();
462 }
463
Finalize(bool sync_embedding_table)464 void EmbeddingCacheScheduler::Finalize(bool sync_embedding_table) {
465 std::lock_guard<std::mutex> lock(finalize_mutex_);
466 if (!initialized_ || finalized_) {
467 return;
468 }
469
470 MS_LOG(INFO) << "Begin finalize EmbeddingCacheScheduler";
471 if (sync_embedding_table) {
472 SyncEmbeddingTable();
473 }
474
475 MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
476 // Stop the embedding cache prefetch_actor.
477 bool finalize_remote = sync_embedding_table;
478 embedding_cache_prefetch_actor_->Finalize(finalize_remote);
479
480 // Get or Create device context.
481 auto ms_context = MsContext::GetInstance();
482 MS_EXCEPTION_IF_NULL(ms_context);
483 std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
484 uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
485 DeviceContext *device_context =
486 device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
487 MS_EXCEPTION_IF_NULL(device_context);
488 device_context->Initialize();
489 embedding_cache_table_manager.Finalize(device_context);
490
491 embedding_storage_manager.Clear();
492
493 initialized_ = false;
494 finalized_ = true;
495 MS_LOG(INFO) << "End finalize EmbeddingCacheScheduler";
496 }
497 } // namespace runtime
498 } // namespace mindspore
499