/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/embedding_cache_scheduler.h"

#include <string>
#include <memory>
#include <functional>
#include <algorithm>
#include <numeric>
#include <vector>
#include "ops/structure_op_name.h"
#include "ops/sparse_op_name.h"
#include "ops/math_op_name.h"
#include "ops/array_op_name.h"
#include "ops/framework_op_name.h"
#include "ops/other_op_name.h"
#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
#include "runtime/graph_scheduler/device_tensor_store.h"
#include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
#include "utils/ms_context.h"
#include "include/common/utils/parallel_context.h"

namespace mindspore {
namespace runtime {
using session::KernelGraph;
namespace {
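// Whether the embedding cache is enabled for this process as a worker in an initialized cluster.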
bool CheckEnableEmbeddingCache() {
  return ps::PSContext::instance()->cache_enable() && distributed::cluster::ClusterContext::instance()->initialized() &&
         ps::PSContext::instance()->is_worker();
}

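// Whether the embedding cache is enabled for this process as a server in an initialized cluster.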
bool CheckEmbeddingCacheServer() {
  return ps::PSContext::instance()->cache_enable() && distributed::cluster::ClusterContext::instance()->initialized() &&
         ps::PSContext::instance()->is_server();
}

// Whether the device address of the kernel output already exists on the given device context.
bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(device_context);
  if (AnfAlgo::OutputAddrExist(kernel, index)) {
    const auto &address = AnfAlgo::GetOutputAddr(kernel, index, false);
    MS_EXCEPTION_IF_NULL(address);
    return address->GetDeviceType() == device_context->GetDeviceType();
  }
  return false;
}

// Finalize the ps cache module before throwing an exception.
void FinalizeEmbeddingCachePrefetch(const std::string &exception) {
  MS_LOG(INFO) << "Begin finalize the EmbeddingCacheScheduler.";
  EmbeddingCacheScheduler::GetInstance().Finalize(false);
  MS_LOG(INFO) << "End finalize the EmbeddingCacheScheduler.";
  MS_LOG(EXCEPTION) << exception;
}

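// Find the first cache-enabled embedding lookup kernel in the graph and record the node which produces its input
// indices, the output index on that node and the hash table size of the cached parameter.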
void GetFirstEmbeddingCacheTableInfo(const KernelGraph &graph, AnfNodePtr *const first_cache_input_index,
                                     size_t *output_idx, size_t *const first_cache_size) {
  MS_EXCEPTION_IF_NULL(first_cache_input_index);
  MS_EXCEPTION_IF_NULL(first_cache_size);
  for (const auto &kernel : graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    const mindspore::HashSet<std::string> kNeedCheckNodes = {kGatherOpName, kSparseGatherV2OpName, kGatherV2OpName,
                                                             kGatherV2DOpName, kMapTensorGetOpName};
    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel);
    if (kNeedCheckNodes.find(kernel_name) == kNeedCheckNodes.end()) {
      continue;
    }
    auto input_param = common::AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
    auto input_index = common::AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
    MS_EXCEPTION_IF_NULL(input_param.first);
    MS_EXCEPTION_IF_NULL(input_index.first);
    auto param_name = input_param.first->fullname_with_scope();
    if (!embedding_cache_table_manager.IsEmbeddingCacheTable(param_name)) {
      continue;
    }
    auto size = embedding_cache_table_manager.QueryHashTableSize(param_name);
    while (input_index.first->isa<CNode>() &&
           ((common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName) ||
            (common::AnfAlgo::GetCNodeName(input_index.first) == kTensorMoveOpName))) {
      input_index = common::AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
      MS_EXCEPTION_IF_NULL(input_index.first);
    }
    auto cnode = common::AnfAlgo::IsGraphKernel(input_index.first)
                   ? common::AnfAlgo::GetOutputOfGraphkernel(input_index)
                   : input_index.first;
    MS_EXCEPTION_IF_NULL(cnode);
    if (!cnode->isa<CNode>()) {
      FinalizeEmbeddingCachePrefetch("The input index of the EmbeddingLookup should be a CNode, but got " +
                                     cnode->fullname_with_scope());
    }
    if (!common::AnfAlgo::IsGetNextNode(cnode)) {
      auto input_index_node_name = common::AnfAlgo::GetCNodeName(cnode);
      bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
      if ((!full_batch && (input_index_node_name != kUniqueOpName)) ||
          (full_batch && (input_index_node_name != kMinimumOpName))) {
        MS_LOG(ERROR) << "The input index of the EmbeddingLookup(" << kernel->fullname_with_scope()
                      << ") cache is from " << cnode->fullname_with_scope();
        FinalizeEmbeddingCachePrefetch(
          "The EmbeddingLookup whose input index isn't from dataset doesn't support cache in parameter server training "
          "mode.");
      }
    }
    *output_idx = input_index.second;
    *first_cache_input_index = cnode;
    *first_cache_size = size;
    MS_LOG(INFO) << "The input index of the first EmbeddingLookup cache is from " << cnode->fullname_with_scope()
                 << ", the cache size is " << size;
    return;
  }
}

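// Check that the input indices of the cached lookup kernel trace back to a Unique kernel, and that the Unique kernel
// reads its input directly from GetNext (only Cast/TensorMove nodes are allowed in between).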
void CheckSparseModeForEmbeddingCache(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto pre_node = common::AnfAlgo::GetPrevNodeOutput(node, 1, true);
  MS_EXCEPTION_IF_NULL(pre_node.first);
  while (pre_node.first->isa<CNode>() && (common::AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
    pre_node = common::AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
    MS_EXCEPTION_IF_NULL(pre_node.first);
  }
  if (!(pre_node.first->isa<CNode>()) || (common::AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
    FinalizeEmbeddingCachePrefetch(std::string("The input_indices of kernel[") + node->DebugString() +
                                   "] must be unique in parameter server cache mode");
  }

  pre_node = common::AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
  MS_EXCEPTION_IF_NULL(pre_node.first);
  while (pre_node.first->isa<CNode>() && ((common::AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName) ||
                                          (common::AnfAlgo::GetCNodeName(pre_node.first) == kTensorMoveOpName))) {
    pre_node = common::AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
    MS_EXCEPTION_IF_NULL(pre_node.first);
  }
  if (!(pre_node.first->isa<CNode>()) || (!common::AnfAlgo::IsGetNextNode(pre_node.first))) {
    FinalizeEmbeddingCachePrefetch(
      "The input indices of kernel[Unique] must be produced from dataset directly and the indices value can not be "
      "changed before delivering to kernel[Unique] in parameter server cache mode.");
  }
}

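// Whether the sparse-mode check should run: the parameter is an embedding cache table and the kernel is a sparse
// gather (either SparseGatherV2 or a Gather with the is_sparse attribute set).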
bool ShouldCheckSparseMode(const std::string &param_name, const std::string &kernel_name, bool is_sparse_gather) {
  if (embedding_cache_table_manager.IsEmbeddingCacheTable(param_name)) {
    if (kernel_name == kSparseGatherV2OpName || is_sparse_gather) {
      return true;
    }
  }
  return false;
}

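// Validate the graph for embedding cache: every lookup kernel whose input indices come from the dataset must enable
// cache, and all cache-enabled lookups must use the same cache size.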
void CheckGraphValidForEmbeddingCache(const KernelGraph &graph) {
  AnfNodePtr first_cache_input_index = nullptr;
  size_t first_cache_size = 0;
  size_t output_idx = 0;
  GetFirstEmbeddingCacheTableInfo(graph, &first_cache_input_index, &output_idx, &first_cache_size);
  MS_EXCEPTION_IF_NULL(first_cache_input_index);
  for (const auto &kernel : graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel);
    const mindspore::HashSet<std::string> kNeedCacheNodes = {kGatherOpName, kSparseGatherV2OpName, kGatherV2OpName,
                                                             kGatherV2DOpName, kMapTensorGetOpName};
    if (kNeedCacheNodes.count(kernel_name) == 0) {
      continue;
    }
    auto input_param = common::AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
    auto input_index = common::AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
    MS_EXCEPTION_IF_NULL(input_param.first);
    MS_EXCEPTION_IF_NULL(input_index.first);
    if (!input_param.first->isa<Parameter>()) {
      continue;
    }
    auto param_name = input_param.first->fullname_with_scope();
    // On Ascend, kSparseGatherV2OpName is changed to kGatherV2OpName and the is_sparse attribute is set to true.
    bool is_sparse_gather = false;
    if (kernel_name == kGatherV2OpName && common::AnfAlgo::HasNodeAttr(kAttrIsSparse, kernel)) {
      is_sparse_gather = common::AnfAlgo::GetNodeAttr<bool>(kernel, kAttrIsSparse);
    }
    if (ShouldCheckSparseMode(param_name, kernel_name, is_sparse_gather)) {
      CheckSparseModeForEmbeddingCache(kernel);
    }
    while (input_index.first->isa<CNode>() &&
           ((common::AnfAlgo::GetCNodeName(input_index.first) == kCastOpName) ||
            (common::AnfAlgo::GetCNodeName(input_index.first) == kTensorMoveOpName))) {
      input_index = common::AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
      MS_EXCEPTION_IF_NULL(input_index.first);
    }
    auto cnode = common::AnfAlgo::IsGraphKernel(input_index.first)
                   ? common::AnfAlgo::GetOutputOfGraphkernel(input_index)
                   : input_index.first;
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode == first_cache_input_index && input_index.second == output_idx) {
      if (!embedding_cache_table_manager.IsEmbeddingCacheTable(param_name)) {
        MS_LOG(ERROR) << "The EmbeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
        FinalizeEmbeddingCachePrefetch(
          "All the embeddingLookups whose input indices are from dataset must enable cache at the same time when one "
          "of them enables cache in parameter server training mode.");
      }
      auto size = embedding_cache_table_manager.QueryHashTableSize(param_name);
      if (size != first_cache_size) {
        MS_LOG(ERROR) << "The cache size(" << size << ") of EmbeddingLookup(" << kernel->fullname_with_scope()
                      << ") is not the same as other EmbeddingLookup cache size(" << first_cache_size << ").";
        FinalizeEmbeddingCachePrefetch(
          "The cache sizes of embeddingLookups are not the same in parameter server training mode.");
      }
    } else if (embedding_cache_table_manager.IsEmbeddingCacheTable(param_name)) {
      MS_LOG(ERROR) << "The input index of the EmbeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
                    << cnode->fullname_with_scope();
      FinalizeEmbeddingCachePrefetch(
        "The EmbeddingLookup whose input index isn't from dataset doesn't support cache in parameter server training "
        "mode.");
    } else if (cnode->isa<CNode>() && (common::AnfAlgo::IsGetNextNode(cnode)) && (input_index.second == output_idx)) {
      MS_LOG(ERROR) << "The EmbeddingLookup kernel(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
      FinalizeEmbeddingCachePrefetch(
        "All EmbeddingLookup kernels whose input indices are from dataset must enable cache at the same time.");
    }
  }
}
}  // namespace

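// Return the process-wide scheduler singleton, initializing it on first access.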
EmbeddingCacheScheduler &EmbeddingCacheScheduler::GetInstance() {
  static EmbeddingCacheScheduler instance{};
  if (!instance.initialized_) {
    instance.Initialize();
  }
  return instance;
}

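// Create the device context and the embedding cache prefetch actor when the embedding cache is enabled on a worker.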
void EmbeddingCacheScheduler::Initialize() {
  if (!CheckEnableEmbeddingCache()) {
    return;
  }
  if (initialized_) {
    return;
  }

  // Get or Create device context.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  DeviceContext *device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
  MS_EXCEPTION_IF_NULL(device_context);
  device_context->Initialize();

  // Create and initialize EmbeddingCachePrefetchActor.
  embedding_cache_prefetch_actor_ = std::make_shared<EmbeddingCachePrefetchActor>(device_context);
  MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);

  initialized_ = true;
}

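// Parse the number of ids in one batch (not the batch size) from the shape attribute of the InitDataSetQueue kernel
// and record it in the embedding cache table manager.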
void EmbeddingCacheScheduler::ParseBatchIdsNum(const KernelGraphPtr &graph) {
  if (parsed_batch_ids_num_) {
    return;
  }

  // 1. Find InitDataSetQueue kernel.
  MS_EXCEPTION_IF_NULL(graph);
  const auto &kernels = graph->execution_order();
  auto iter = find_if(kernels.begin(), kernels.end(), [](const CNodePtr &kernel) {
    MS_EXCEPTION_IF_NULL(kernel);
    return common::AnfAlgo::GetCNodeName(kernel) == kInitDatasetQueueOpName;
  });
  if (iter == kernels.end()) {
    MS_LOG(EXCEPTION) << "Can not find InitDataSetQueue kernel";
  }

  const auto &kernel = *iter;
  MS_EXCEPTION_IF_NULL(kernel);

  // 2. Get shape of InitDataSetQueue kernel.
  std::vector<std::vector<int64_t>> shapes;
  if (common::AnfAlgo::IsDynamicShape(kernel)) {
    shapes = common::AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel, "max_shapes");
  } else {
    shapes = common::AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel, "shapes");
  }
  auto types = common::AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel, "types");
  if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) {
    MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size "
                      << types.size();
  }

  const TypePtr &id_type = types.front();
  MS_EXCEPTION_IF_NULL(id_type);
  if (id_type->type_id() != kInt32->type_id() && id_type->type_id() != kInt->type_id()) {
    MS_LOG(EXCEPTION) << "Embedding cache mode needs input ids with data type[" << kInt32->ToString() << " or "
                      << kInt->ToString() << "], but got[" << id_type->ToString() << "]";
  }

  // 3. Get batch ids num(not batch size).
  const auto &shape = shapes[0];
  auto batch_ids_num = LongToSize(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
  embedding_cache_table_manager.Initialize();
  embedding_cache_table_manager.set_batch_ids_num(batch_ids_num);

  parsed_batch_ids_num_ = true;
}

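// Allocate device memory for the embedding cache table; only the first call takes effect.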
void EmbeddingCacheScheduler::AllocMemForEmbeddingCacheTable(const DeviceContext *device_context) {
  if (allocated_embed_cache_mem_) {
    return;
  }

  embedding_cache_table_manager.AllocMemForEmbedding(device_context);
  allocated_embed_cache_mem_ = true;
}

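// Bind the device addresses of cache-enabled parameters to the embedding cache table and allocate cache memory.
// The graph structure is validated first when the dense (non-sparse) format is used.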
void EmbeddingCacheScheduler::SetEmbedCachedParamAddress(const DeviceContext *device_context,
                                                         const KernelGraphPtr &graph) {
  if (!CheckEnableEmbeddingCache()) {
    return;
  }

  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  MS_EXCEPTION_IF_NULL(graph);

  // 1. Get batch ids number before allocating device memory for the embedding cache table.
  ParseBatchIdsNum(graph);

  const std::vector<AnfNodePtr> &input_nodes = graph->input_nodes();
  bool exist_embedding_cache_table = std::any_of(input_nodes.begin(), input_nodes.end(), [](const AnfNodePtr &node) {
    MS_EXCEPTION_IF_NULL(node);
    return embedding_cache_table_manager.IsEmbeddingCacheTable(node->fullname_with_scope());
  });
  if (!exist_embedding_cache_table) {
    return;
  }

  // Graph valid check.
  // The sparse mode does not perform graph structure verification currently.
  if (!embedding_cache_table_manager.is_sparse_format()) {
    CheckGraphValidForEmbeddingCache(*graph);
  }

  // 2. Set parameter device address to embedding cache table.
  for (const auto &node : input_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    const std::string &param_name = node->fullname_with_scope();
    if (!embedding_cache_table_manager.IsEmbeddingCacheTable(param_name)) {
      continue;
    }

    if (node->isa<Parameter>() && !NodeDeviceAddressExist(device_context, node, 0)) {
      MS_LOG_WITH_NODE(EXCEPTION, node) << "Device address not found for parameter: " << node->fullname_with_scope();
    }

    if (embedding_cache_table_manager.QueryEmbeddingDeviceAddress(param_name) == nullptr) {
      embedding_cache_table_manager.SetEmbeddingDeviceAddress(param_name, AnfAlgo::GetMutableOutputAddr(node, 0).get());
    }
  }

  // 3. Allocate device memory for embedding cache table.
  AllocMemForEmbeddingCacheTable(device_context);
}

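// Record the dataset channel (the 'shared_name' attribute of the GetNext kernel) used by the data prepare actor
// identified by actor_id.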
void EmbeddingCacheScheduler::SetDataSetChannel(const std::string &actor_id,
                                                const std::vector<KernelGraphPtr> &graphs) {
  if (!CheckEnableEmbeddingCache()) {
    return;
  }

  for (const auto &graph : graphs) {
    MS_EXCEPTION_IF_NULL(graph);
    for (const auto &kernel_node : graph->execution_order()) {
      if (!common::AnfAlgo::IsGetNextNode(kernel_node)) {
        continue;
      }

      if (!common::AnfAlgo::HasNodeAttr("shared_name", kernel_node)) {
        MS_LOG(EXCEPTION) << "Can not find attr[shared_name] of GetNext";
      }
      (void)data_prepare_aid_to_data_channel_.emplace(
        actor_id, common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "shared_name"));
      break;
    }
  }
}

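// On the embedding cache server, initialize persistent embedding storage for every parameter that enables it.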
void EmbeddingCacheScheduler::InitEmbeddingStorage(const std::vector<AnfNodePtr> &parameters) const {
  if (!CheckEmbeddingCacheServer()) {
    return;
  }

  for (size_t i = 0; i < parameters.size(); i++) {
    MS_EXCEPTION_IF_NULL(parameters[i]);
    if (!parameters[i]->isa<Parameter>()) {
      MS_LOG_WITH_NODE(EXCEPTION, parameters[i])
        << "The node with name: " << parameters[i]->fullname_with_scope() << " is not a Parameter.";
    }

    ParameterPtr param = parameters[i]->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    auto param_info = param->param_info();
    // Check whether embedding storage is enabled for the parameter.
    bool enable_embedding_storage =
      param_info && param_info->key() != -1 && param_info->cache_enable() && param_info->use_persistent_storage();
    if (!enable_embedding_storage) {
      continue;
    }

    auto embed_storage = embedding_storage_manager.Get(param_info->key());
    MS_EXCEPTION_IF_NULL(embed_storage);
    std::vector<DeviceTensorPtr> device_tensors = DeviceTensorStore::GetInstance().Fetch(param.get());
    if (device_tensors.size() != 1) {
      MS_LOG(EXCEPTION)
        << "The device tensor size for embedding table which enables embedding storage should be 1, but got:"
        << device_tensors.size();
    }

    // Initialize embedding storage instance.
    embed_storage->Initialize(device_tensors.front().get());
  }
}

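// Initialize, spawn and run the embedding cache prefetch actor; only the first call after Initialize() takes effect.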
void EmbeddingCacheScheduler::Schedule() {
  if (!initialized_ || scheduled_) {
    return;
  }

  // 1. Initialize embedding cache prefetch actor and build the inter-process network connection.
  MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
  embedding_cache_prefetch_actor_->Initialize();

  // 2. Spawn embedding cache prefetch actor.
  auto actor_manager = ActorMgr::GetActorMgrRef();
  // Bind single thread to execute embedding cache prefetch actor.
  (void)actor_manager->Spawn(embedding_cache_prefetch_actor_, false);

  // 3. Run embedding cache prefetch actor.
  ActorDispatcher::Send(embedding_cache_prefetch_actor_->GetAID(), &EmbeddingCachePrefetchActor::Run);

  scheduled_ = true;
}

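// Notify the prefetch actor that the data prepare actor identified by actor_id has advanced one graph step on its
// dataset channel.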
void EmbeddingCacheScheduler::IncreaseGraphStep(const std::string &actor_id) const {
  if (!CheckEnableEmbeddingCache()) {
    return;
  }

  auto iter = data_prepare_aid_to_data_channel_.find(actor_id);
  if (iter != data_prepare_aid_to_data_channel_.end()) {
    MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
    embedding_cache_prefetch_actor_->IncreaseGraphStep(iter->second);
  }
}

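// Synchronize the local embedding cache back to the remote embedding table via the prefetch actor.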
void EmbeddingCacheScheduler::SyncEmbeddingTable() const {
  MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
  embedding_cache_prefetch_actor_->SyncEmbeddingTable();
}

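// Finalize the scheduler: optionally sync the embedding table, stop the prefetch actor and release embedding cache
// and embedding storage resources.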
void EmbeddingCacheScheduler::Finalize(bool sync_embedding_table) {
  std::lock_guard<std::mutex> lock(finalize_mutex_);
  if (!initialized_ || finalized_) {
    return;
  }

  MS_LOG(INFO) << "Begin finalize EmbeddingCacheScheduler";
  if (sync_embedding_table) {
    SyncEmbeddingTable();
  }

  MS_EXCEPTION_IF_NULL(embedding_cache_prefetch_actor_);
  // Stop the embedding cache prefetch actor.
  bool finalize_remote = sync_embedding_table;
  embedding_cache_prefetch_actor_->Finalize(finalize_remote);

  // Get or Create device context.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  DeviceContext *device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
  MS_EXCEPTION_IF_NULL(device_context);
  device_context->Initialize();
  embedding_cache_table_manager.Finalize(device_context);

  embedding_storage_manager.Clear();

  initialized_ = false;
  finalized_ = true;
  MS_LOG(INFO) << "End finalize EmbeddingCacheScheduler";
}
}  // namespace runtime
}  // namespace mindspore