• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
18 #include "ops/nn_op_name.h"
19 #include "ops/array_op_name.h"
20 #include "runtime/device/device_address_utils.h"
21 #include "runtime/graph_scheduler/actor/embedding_cache/device_dense_embedding_operation.h"
22 namespace mindspore {
23 namespace runtime {
24 using distributed::kInvalidIndexValue;
25 using kernel::Address;
26 using kernel::AddressPtrList;
27 
AnalyseCache(int * batch_ids,const size_t batch_ids_num,size_t data_step,const std::atomic_ulong * graph_running_step,bool * device_cache_need_wait_graph,bool * host_cache_need_wait_graph,int * indices,EmbeddingDeviceCache * embedding_device_cache,EmbeddingHostCache * embedding_host_cache,EmbeddingCacheStatisticsInfo * statistics_info)28 bool DeviceDenseEmbeddingOperation::AnalyseCache(int *batch_ids, const size_t batch_ids_num, size_t data_step,
29                                                  const std::atomic_ulong *graph_running_step,
30                                                  bool *device_cache_need_wait_graph, bool *host_cache_need_wait_graph,
31                                                  int *indices, EmbeddingDeviceCache *embedding_device_cache,
32                                                  EmbeddingHostCache *embedding_host_cache,
33                                                  EmbeddingCacheStatisticsInfo *statistics_info) {
34   MS_ERROR_IF_NULL(batch_ids);
35   MS_ERROR_IF_NULL(graph_running_step);
36   MS_ERROR_IF_NULL(embedding_device_cache);
37   MS_ERROR_IF_NULL(statistics_info);
38 
39   statistics_info_->batch_id_count_ = batch_ids_num;
40   std::unique_ptr<bool[]> out_range = std::make_unique<bool[]>(batch_ids_num);
41   auto ret = memset_s(out_range.get(), batch_ids_num * sizeof(bool), 0, batch_ids_num * sizeof(bool));
42   if (ret != EOK) {
43     MS_LOG(ERROR) << "Memset failed, errno[" << ret << "]";
44     return false;
45   }
46 
47   // 1. Analyze the hit/miss info of the local host cache and device cache.
48   RETURN_IF_FALSE_WITH_LOG(CheckCacheHitOrOutRange(batch_ids, batch_ids_num, indices, out_range.get(), data_step),
49                            "Check cache hit or out range failed.");
50   RETURN_IF_FALSE_WITH_LOG(actor_->ResetEmbeddingHashMap(), "Reset embedding hash map failed.");
51 
52   size_t cur_graph_running_step = graph_running_step->load();
53   // 2.calculate the swapping and mapping(feature id to cache index) information of the missing feature id that needs to
54   // be inserted into the cache.
55   for (size_t i = 0; i < batch_ids_num; i++) {
56     if (out_range[i]) {
57       continue;
58     }
59     (void)modified_ids_.insert(batch_ids[i]);
60     bool need_swap_host_to_device = true;
61     bool need_swap_device_to_host = true;
62     int index = kInvalidIndexValue;
63     RETURN_IF_FALSE_WITH_LOG(ParseDeviceData(batch_ids[i], &need_swap_device_to_host, &need_swap_host_to_device, &index,
64                                              data_step, &cur_graph_running_step, graph_running_step,
65                                              device_cache_need_wait_graph, embedding_device_cache, statistics_info),
66                              "Parse device cache data failed.");
67     indices[i] = index + local_device_cache_bounds_.first;
68     if (need_swap_host_to_device) {
69       RETURN_IF_FALSE_WITH_LOG(
70         ParseHostDataHostToDevice(batch_ids[i], data_step, &cur_graph_running_step, graph_running_step,
71                                   host_cache_need_wait_graph, embedding_host_cache, statistics_info),
72         "Parse local host cache data(swap local host cache to device) failed.");
73     }
74     if (need_swap_device_to_host) {
75       RETURN_IF_FALSE_WITH_LOG(
76         ParseHostDataDeviceToHost(data_step, &cur_graph_running_step, graph_running_step, host_cache_need_wait_graph,
77                                   embedding_device_cache, embedding_host_cache, statistics_info),
78         "Parse local host cache data(swap device cache to local host) failed.");
79     }
80   }
81   return true;
82 }
83 
PushCacheFromDeviceToLocalHost(const HashTableInfo & hash_info,const CacheAnalysis * cache_analysis)84 bool DeviceDenseEmbeddingOperation::PushCacheFromDeviceToLocalHost(const HashTableInfo &hash_info,
85                                                                    const CacheAnalysis *cache_analysis) {
86   MS_ERROR_IF_NULL(cache_analysis);
87   auto statistics_info = cache_analysis->statistics_info_;
88   auto embedding_device_cache = cache_analysis->embedding_device_cache_;
89   auto embedding_host_cache = cache_analysis->embedding_host_cache_;
90   MS_ERROR_IF_NULL(statistics_info);
91   MS_ERROR_IF_NULL(embedding_device_cache);
92   MS_ERROR_IF_NULL(embedding_host_cache);
93 
94   auto swap_indices_size = statistics_info->device_to_host_size_;
95   if (swap_indices_size == 0) {
96     return true;
97   }
98 
99   auto device_cache_device_to_host_index = embedding_device_cache->device_to_host_index.get();
100   auto host_cache_device_to_host_index = embedding_host_cache->device_to_host_index.get();
101   MS_ERROR_IF_NULL(device_cache_device_to_host_index);
102   MS_ERROR_IF_NULL(host_cache_device_to_host_index);
103   auto hash_table_addr = reinterpret_cast<float *>(hash_info.address.addr);
104   auto cache_vocab_size = hash_info.cache_vocab_size;
105   auto host_hash_table_addr = hash_info.host_address;
106   auto embedding_size = hash_info.embedding_size;
107   auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
108   if (swap_indices_size >
109       embedding_cache_table_manager.batch_ids_num_ * embedding_cache_table_manager.multi_batch_threshold_) {
110     MS_LOG(EXCEPTION) << "The swap size is greater than the size of batch id buffer.";
111   }
112   RETURN_IF_FALSE_WITH_LOG(
113     MemcpyHostToDeviceAsync(embedding_cache_table_manager.hash_swap_index_addr_, device_cache_device_to_host_index,
114                             swap_indices_size * sizeof(int), device_context_, stream_id_),
115     "Memcpy host to device asynchronously failed.");
116 
117   RETURN_IF_FALSE_WITH_LOG(
118     LookupDeviceCache(embedding_cache_table_manager.hash_swap_index_addr_, hash_table_addr, swap_indices_size,
119                       cache_vocab_size, embedding_size, embedding_cache_table_manager.hash_swap_value_addr_),
120     "Lookup device cache failed.");
121 
122   RETURN_IF_FALSE_WITH_LOG(
123     MemcpyDeviceToHostAsync(swap_out_data.get(), embedding_cache_table_manager.hash_swap_value_addr_,
124                             swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
125     "Memcpy device to host asynchronously failed.");
126 
127   MS_ERROR_IF_NULL(device_context_);
128   MS_ERROR_IF_NULL(device_context_->device_res_manager_);
129   RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
130   RETURN_IF_FALSE_WITH_LOG(
131     actor_->InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
132                                  swap_out_data.get(), host_hash_table_addr),
133     "Insert local host cache failed.");
134   return true;
135 }
136 
PullCacheFromLocalHostToDevice(const HashTableInfo & hash_info,const CacheAnalysis * cache_analysis)137 bool DeviceDenseEmbeddingOperation::PullCacheFromLocalHostToDevice(const HashTableInfo &hash_info,
138                                                                    const CacheAnalysis *cache_analysis) {
139   MS_ERROR_IF_NULL(cache_analysis);
140   auto statistics_info = cache_analysis->statistics_info_;
141   auto embedding_device_cache = cache_analysis->embedding_device_cache_;
142   auto embedding_host_cache = cache_analysis->embedding_host_cache_;
143   MS_ERROR_IF_NULL(statistics_info);
144   MS_ERROR_IF_NULL(embedding_device_cache);
145   MS_ERROR_IF_NULL(embedding_host_cache);
146 
147   auto swap_indices_size = statistics_info->host_to_device_size_;
148   if (swap_indices_size == 0) {
149     return true;
150   }
151 
152   auto host_cache_host_to_device_index = embedding_host_cache->host_to_device_index.get();
153   auto device_cache_host_to_device_index = embedding_device_cache->host_to_device_index.get();
154   MS_ERROR_IF_NULL(host_cache_host_to_device_index);
155   MS_ERROR_IF_NULL(device_cache_host_to_device_index);
156 
157   auto embedding_size = hash_info.embedding_size;
158   MS_ERROR_IF_NULL(hash_info.address.addr);
159   auto hash_table_addr = reinterpret_cast<float *>(hash_info.address.addr);
160   auto cache_vocab_size = hash_info.cache_vocab_size;
161   auto host_hash_table_addr = hash_info.host_address;
162   MS_ERROR_IF_NULL(host_hash_table_addr);
163   auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
164   RETURN_IF_FALSE_WITH_LOG(actor_->LookupLocalHostCache(embedding_size, swap_indices_size, host_hash_table_addr,
165                                                         host_cache_host_to_device_index, swap_out_data.get()),
166                            "Lookup local host cache failed.");
167 
168   if (swap_indices_size >
169       embedding_cache_table_manager.batch_ids_num_ * embedding_cache_table_manager.multi_batch_threshold_) {
170     MS_LOG(EXCEPTION) << "The swap size is greater than the size of batch value buffer.";
171   }
172   RETURN_IF_FALSE_WITH_LOG(
173     MemcpyHostToDeviceAsync(embedding_cache_table_manager.hash_swap_value_addr_, swap_out_data.get(),
174                             swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
175     "Memcpy host to device asynchronously failed.");
176   RETURN_IF_FALSE_WITH_LOG(
177     MemcpyHostToDeviceAsync(embedding_cache_table_manager.hash_swap_index_addr_, device_cache_host_to_device_index,
178                             swap_indices_size * sizeof(int), device_context_, stream_id_),
179     "Memcpy host to device asynchronously failed.");
180 
181   RETURN_IF_FALSE_WITH_LOG(UpdateDeviceCache(embedding_cache_table_manager.hash_swap_index_addr_,
182                                              embedding_cache_table_manager.hash_swap_value_addr_, swap_indices_size,
183                                              cache_vocab_size, embedding_size, hash_table_addr),
184                            "Update device embedding cache failed.");
185   MS_ERROR_IF_NULL(device_context_);
186   MS_ERROR_IF_NULL(device_context_->device_res_manager_);
187   RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
188   return true;
189 }
190 
GetRemoteEmbeddingSliceBound(size_t vocab_size,size_t server_num,std::vector<std::pair<size_t,size_t>> * remote_embedding_slice_bounds)191 void DeviceDenseEmbeddingOperation::GetRemoteEmbeddingSliceBound(
192   size_t vocab_size, size_t server_num, std::vector<std::pair<size_t, size_t>> *remote_embedding_slice_bounds) {
193   if (server_num == 0) {
194     MS_LOG(EXCEPTION) << "The number of servers is at least 1, but get 0";
195   }
196   size_t average_slice_size = vocab_size / server_num;
197   std::vector<size_t> remote_embedding_slice_sizes = std::vector<size_t>(server_num, average_slice_size);
198   size_t rest_vocab_size = vocab_size % server_num;
199   for (size_t i = 0; i < rest_vocab_size; i++) {
200     remote_embedding_slice_sizes[i] += 1;
201   }
202 
203   size_t begin;
204   size_t end;
205   for (size_t i = 0; i < server_num; i++) {
206     if (i == 0) {
207       begin = 0;
208       end = remote_embedding_slice_sizes[0] - 1;
209     } else {
210       MS_EXCEPTION_IF_NULL(remote_embedding_slice_bounds);
211       begin = remote_embedding_slice_bounds->at(i - 1).second + 1;
212       end = begin + remote_embedding_slice_sizes[i] - 1;
213     }
214     (void)remote_embedding_slice_bounds->emplace_back(begin, end);
215   }
216 }
217 
BuildEmbeddingCacheLookupKernel()218 void DeviceDenseEmbeddingOperation::BuildEmbeddingCacheLookupKernel() {
219   auto graph = std::make_shared<KernelGraph>();
220   MS_EXCEPTION_IF_NULL(graph);
221   graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
222   embedding_cache_graphs_.push_back(graph);
223 
224   // 1. Create parameter nodes which are inputs of embedding cache look up kernel(operator name: 'EmbeddingLookup').
225   ParameterPtr input_param = NewParameter(graph, kFloat32, kTwoDimensionalShape);
226   ParameterPtr input_indices = NewParameter(graph, kInt32, kOneDimensionalShape);
227   ValueNodePtr offset_value_node = NewValueNode(0, device_context_, stream_id_);
228 
229   // 2. Create a CNode for operator EmbeddingLookup.
230   PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
231   emb_lookup_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
232   emb_lookup_primitive->set_attr(kAttrOutputIsDynamicShape, MakeValue(true));
233 
234   std::vector<AnfNodePtr> emb_lookup_input_nodes{mindspore::NewValueNode(emb_lookup_primitive), input_param,
235                                                  input_indices, offset_value_node};
236   embedding_cache_lookup_node_ = graph->NewCNode(emb_lookup_input_nodes);
237   MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node_);
238   auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimensionalShape);
239   embedding_cache_lookup_node_->set_abstract(abstract);
240 
241   // 3. Kernel build process.
242   MS_EXCEPTION_IF_NULL(device_context_);
243   MS_EXCEPTION_IF_NULL(device_context_->GetKernelExecutor(false));
244   device_context_->GetKernelExecutor(false)->CreateKernel({embedding_cache_lookup_node_});
245   AnfAlgo::SetStreamId(stream_id_, embedding_cache_lookup_node_.get());
246 
247   DeviceAddressUtils::CreateParameterDeviceAddress(device_context_, graph);
248   DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context_, graph, false);
249   DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context_, graph);
250 }
251 
BuildEmbeddingCacheUpdateKernel()252 void DeviceDenseEmbeddingOperation::BuildEmbeddingCacheUpdateKernel() {
253   auto graph = std::make_shared<KernelGraph>();
254   MS_EXCEPTION_IF_NULL(graph);
255   graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
256   embedding_cache_graphs_.push_back(graph);
257 
258   // 1. Create parameter nodes which are inputs of embedding cache update kernel(operator name: 'ScatterUpdate').
259   ParameterPtr input_param = NewParameter(graph, kFloat32, kTwoDimensionalShape);
260   ParameterPtr input_indices = NewParameter(graph, kInt32, kOneDimensionalShape);
261   ParameterPtr update_values = NewParameter(graph, kFloat32, kTwoDimensionalShape);
262 
263   // 2. Create a CNode for operator ScatterUpdate.
264   PrimitivePtr embedding_cache_update_primitive = std::make_shared<Primitive>(kScatterUpdateOpName);
265   embedding_cache_update_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
266 
267   std::vector<AnfNodePtr> embedding_cache_update_input_nodes{mindspore::NewValueNode(embedding_cache_update_primitive),
268                                                              input_param, input_indices, update_values};
269   embedding_cache_update_node_ = graph->NewCNode(embedding_cache_update_input_nodes);
270   MS_EXCEPTION_IF_NULL(embedding_cache_update_node_);
271   auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimensionalShape);
272   embedding_cache_update_node_->set_abstract(abstract);
273 
274   // 3. Kernel build process.
275   MS_EXCEPTION_IF_NULL(device_context_);
276   MS_EXCEPTION_IF_NULL(device_context_->GetKernelExecutor(false));
277   device_context_->GetKernelExecutor(false)->CreateKernel({embedding_cache_update_node_});
278   AnfAlgo::SetStreamId(stream_id_, embedding_cache_update_node_.get());
279 
280   DeviceAddressUtils::CreateParameterDeviceAddress(device_context_, graph);
281   DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context_, graph, false);
282   DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context_, graph);
283 }
284 
LookupDeviceCache(void * indices,void * embedding_cache,size_t indices_num,size_t cache_size,size_t embedding_size,void * outputs)285 bool DeviceDenseEmbeddingOperation::LookupDeviceCache(void *indices, void *embedding_cache, size_t indices_num,
286                                                       size_t cache_size, size_t embedding_size, void *outputs) {
287   MS_ERROR_IF_NULL(indices);
288   MS_ERROR_IF_NULL(embedding_cache);
289   MS_ERROR_IF_NULL(outputs);
290   MS_ERROR_IF_NULL(embedding_cache_lookup_node_);
291 
292   // 1. Get input and output kernel tensors.
293   std::vector<kernel::KernelTensor *> input_kernel_tensors =
294     AnfAlgo::GetOrCreateAllInputKernelTensors(embedding_cache_lookup_node_);
295   std::vector<kernel::KernelTensor *> output_kernel_tensors =
296     AnfAlgo::GetOrCreateAllOutputKernelTensors(embedding_cache_lookup_node_);
297   MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex0]);
298   MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex1]);
299   MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex2]);
300 
301   MS_EXCEPTION_IF_CHECK_FAIL((input_kernel_tensors.size() == kCacheOpInputNum),
302                              "For op: " + embedding_cache_lookup_node_->fullname_with_scope() +
303                                " need 3 inputs, but got " + std::to_string(input_kernel_tensors.size()));
304   MS_EXCEPTION_IF_CHECK_FAIL((output_kernel_tensors.size() == kCacheOpOutputNum),
305                              "For op: " + embedding_cache_lookup_node_->fullname_with_scope() +
306                                " has 1 output, but got " + std::to_string(output_kernel_tensors.size()));
307 
308   std::vector<abstract::AbstractBasePtr> input_kernel_tensors_for_infer(kCacheOpInputNum, nullptr);
309   for (size_t i = 0; i < kCacheOpInputNum; i++) {
310     const auto &kernel_tensor = AnfAlgo::GetPrevNodeOutputKernelTensor(embedding_cache_lookup_node_, i);
311     MS_EXCEPTION_IF_NULL(kernel_tensor);
312     input_kernel_tensors_for_infer[i] = kernel_tensor;
313   }
314 
315   // 2. Update input shape.
316   ShapeVector input_param_shape = {SizeToLong(cache_size), SizeToLong(embedding_size)};
317   input_kernel_tensors[kIndex0]->SetShape(std::make_shared<abstract::TensorShape>(std::move(input_param_shape)));
318   ShapeVector input_indices_shape = {SizeToLong(indices_num)};
319   input_kernel_tensors[kIndex1]->SetShape(std::make_shared<abstract::TensorShape>(std::move(input_indices_shape)));
320 
321   // 3. Infer shape for embedding cache look up kernel(operator name: 'EmbeddingLookup') which is dynamic shape kernel.
322   if (!InferOpShape(embedding_cache_lookup_node_, input_kernel_tensors, output_kernel_tensors,
323                     input_kernel_tensors_for_infer)) {
324     MS_LOG(ERROR) << "Infer operator shape failed, op name: " << embedding_cache_lookup_node_->fullname_with_scope();
325     return false;
326   }
327 
328   // 4. Do embedding cache look up on device.
329   input_kernel_tensors[kIndex0]->set_device_ptr(embedding_cache);
330   input_kernel_tensors[kIndex1]->set_device_ptr(indices);
331   output_kernel_tensors[kIndex1]->set_device_ptr(outputs);
332 
333   MS_ERROR_IF_NULL(device_context_);
334   MS_ERROR_IF_NULL(device_context_->GetKernelExecutor(false));
335   auto kernel_mod = AnfAlgo::GetKernelMod(embedding_cache_lookup_node_);
336   auto stream = device_context_->device_res_manager_->GetStream(stream_id_);
337   auto ret = device_context_->GetKernelExecutor(false)->LaunchKernel(embedding_cache_lookup_node_, input_kernel_tensors,
338                                                                      {}, output_kernel_tensors, kernel_mod, stream);
339   if (!ret) {
340     MS_LOG(ERROR) << "Launch kernel: " << embedding_cache_lookup_node_->fullname_with_scope() << " failed.";
341     return false;
342   }
343   return true;
344 }
345 
UpdateDeviceCache(void * indices,void * update_value,size_t indices_num,size_t cache_size,size_t embedding_size,void * embedding_cache)346 bool DeviceDenseEmbeddingOperation::UpdateDeviceCache(void *indices, void *update_value, size_t indices_num,
347                                                       size_t cache_size, size_t embedding_size, void *embedding_cache) {
348   MS_ERROR_IF_NULL(indices);
349   MS_ERROR_IF_NULL(update_value);
350   MS_ERROR_IF_NULL(embedding_cache);
351   MS_ERROR_IF_NULL(embedding_cache_update_node_);
352 
353   // 1. Get input and output kernel tensors.
354   std::vector<kernel::KernelTensor *> input_kernel_tensors =
355     AnfAlgo::GetOrCreateAllInputKernelTensors(embedding_cache_update_node_);
356   std::vector<kernel::KernelTensor *> output_kernel_tensors =
357     AnfAlgo::GetOrCreateAllOutputKernelTensors(embedding_cache_update_node_);
358   MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex0]);
359   MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex1]);
360   MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex2]);
361 
362   MS_EXCEPTION_IF_CHECK_FAIL((input_kernel_tensors.size() == kCacheOpInputNum),
363                              "For op: " + embedding_cache_update_node_->fullname_with_scope() +
364                                " need 3 inputs, but got " + std::to_string(input_kernel_tensors.size()));
365   MS_EXCEPTION_IF_CHECK_FAIL((output_kernel_tensors.size() == kCacheOpOutputNum),
366                              "For op: " + embedding_cache_update_node_->fullname_with_scope() +
367                                " has 1 output, but got " + std::to_string(output_kernel_tensors.size()));
368 
369   std::vector<abstract::AbstractBasePtr> input_kernel_tensors_for_infer(kCacheOpInputNum, nullptr);
370   for (size_t i = 0; i < kCacheOpInputNum; i++) {
371     const auto &kernel_tensor = AnfAlgo::GetPrevNodeOutputKernelTensor(embedding_cache_update_node_, i);
372     MS_EXCEPTION_IF_NULL(kernel_tensor);
373     input_kernel_tensors_for_infer[i] = kernel_tensor;
374   }
375 
376   // 2. Update input shape.
377   ShapeVector input_param_shape = {SizeToLong(cache_size), SizeToLong(embedding_size)};
378   input_kernel_tensors[kIndex0]->SetShape(std::make_shared<abstract::TensorShape>(std::move(input_param_shape)));
379   const ShapeVector input_indices_shape = {SizeToLong(indices_num)};
380   input_kernel_tensors[kIndex1]->SetShape(std::make_shared<abstract::TensorShape>(std::move(input_indices_shape)));
381   const ShapeVector update_values_shape = {SizeToLong(indices_num), SizeToLong(embedding_size)};
382   input_kernel_tensors[kIndex2]->SetShape(std::make_shared<abstract::TensorShape>(std::move(update_values_shape)));
383 
384   // 3. Infer shape for embedding cache update kernel(operator name: 'ScatterUpdate') which is dynamic shape kernel.
385   if (!InferOpShape(embedding_cache_update_node_, input_kernel_tensors, output_kernel_tensors,
386                     input_kernel_tensors_for_infer)) {
387     MS_LOG(ERROR) << "Infer operator shape failed, op name: " << embedding_cache_update_node_->fullname_with_scope();
388     return false;
389   }
390 
391   // 4. Do update cache on device.
392   input_kernel_tensors[kIndex0]->set_device_ptr(embedding_cache);
393   input_kernel_tensors[kIndex1]->set_device_ptr(indices);
394   input_kernel_tensors[kIndex2]->set_device_ptr(update_value);
395   output_kernel_tensors[kIndex0]->set_device_ptr(embedding_cache);
396 
397   MS_ERROR_IF_NULL(device_context_);
398   MS_ERROR_IF_NULL(device_context_->GetKernelExecutor(false));
399   auto kernel_mod = AnfAlgo::GetKernelMod(embedding_cache_update_node_);
400   auto stream = device_context_->device_res_manager_->GetStream(stream_id_);
401   auto ret = device_context_->GetKernelExecutor(false)->LaunchKernel(embedding_cache_update_node_, input_kernel_tensors,
402                                                                      {}, output_kernel_tensors, kernel_mod, stream);
403   if (!ret) {
404     MS_LOG(ERROR) << "Launch kernel: " << embedding_cache_update_node_->fullname_with_scope() << " failed.";
405     return false;
406   }
407   return true;
408 }
409 
CheckCacheHitOrOutRange(const int * batch_ids,const size_t batch_ids_num,int * hash_index,bool * out_range,size_t data_step)410 bool DeviceDenseEmbeddingOperation::CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_num,
411                                                             int *hash_index, bool *out_range, size_t data_step) {
412   MS_ERROR_IF_NULL(batch_ids);
413   MS_ERROR_IF_NULL(hash_index);
414   MS_ERROR_IF_NULL(out_range);
415 
416   size_t thread_num = batch_ids_num / kMaxIdsPerThread + 1;
417   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
418   std::thread threads[kMaxThreadNum];
419   size_t hash_hit_count[kMaxThreadNum] = {0};
420   size_t i = 0;
421   size_t offset = 0;
422 
423   for (; i < thread_num; ++i) {
424     if (offset >= batch_ids_num) {
425       break;
426     }
427     size_t proc_len = batch_ids_num / thread_num + (i < (batch_ids_num % thread_num) ? 1 : 0);
428     threads[i] = std::thread(&DeviceDenseEmbeddingOperation::CheckCacheHitOrOutRangeFunc, this, batch_ids + offset,
429                              proc_len, hash_index + offset, out_range + offset, hash_hit_count + i, data_step);
430     offset += proc_len;
431   }
432   if (offset != batch_ids_num) {
433     MS_LOG(WARNING) << "Check id in device inadequate, total:" << batch_ids_num << " checked:" << offset;
434   }
435 
436   for (size_t j = 0; j < i; j++) {
437     threads[j].join();
438   }
439   for (size_t j = 0; j < i; j++) {
440     statistics_info_->hash_hit_count_ += hash_hit_count[j];
441   }
442   return true;
443 }
444 
CheckCacheHitOrOutRangeFunc(const int * batch_ids,const size_t batch_ids_num,int * hash_index,bool * out_range,size_t * hash_hit_count,size_t data_step)445 bool DeviceDenseEmbeddingOperation::CheckCacheHitOrOutRangeFunc(const int *batch_ids, const size_t batch_ids_num,
446                                                                 int *hash_index, bool *out_range,
447                                                                 size_t *hash_hit_count, size_t data_step) {
448   MS_ERROR_IF_NULL(batch_ids);
449   MS_ERROR_IF_NULL(hash_index);
450   MS_ERROR_IF_NULL(out_range);
451   MS_ERROR_IF_NULL(hash_hit_count);
452   auto &device_hash_map = embedding_cache_table_manager.device_hash_map_;
453   MS_ERROR_IF_NULL(device_hash_map);
454 
455   for (size_t i = 0; i < batch_ids_num; ++i) {
456     if (batch_ids[i] < local_embedding_slice_bounds_.first) {
457       hash_index[i] = batch_ids[i] - local_embedding_slice_bounds_.first + local_device_cache_bounds_.first;
458       out_range[i] = true;
459       continue;
460     }
461     if (batch_ids[i] >= local_embedding_slice_bounds_.second) {
462       hash_index[i] = batch_ids[i] + local_device_cache_bounds_.second;
463       out_range[i] = true;
464       continue;
465     }
466   }
467   return true;
468 }
469 
ParseDeviceData(int id,bool * need_swap_device_to_host,bool * need_swap_host_to_device,int * hash_index,size_t data_step,size_t * cur_graph_running_step,const std::atomic_ulong * latest_graph_running_step,bool * device_cache_need_wait_graph,EmbeddingDeviceCache * embedding_device_cache,EmbeddingCacheStatisticsInfo * statistics_info)470 bool DeviceDenseEmbeddingOperation::ParseDeviceData(int id, bool *need_swap_device_to_host,
471                                                     bool *need_swap_host_to_device, int *hash_index, size_t data_step,
472                                                     size_t *cur_graph_running_step,
473                                                     const std::atomic_ulong *latest_graph_running_step,
474                                                     bool *device_cache_need_wait_graph,
475                                                     EmbeddingDeviceCache *embedding_device_cache,
476                                                     EmbeddingCacheStatisticsInfo *statistics_info) {
477   MS_ERROR_IF_NULL(need_swap_device_to_host);
478   MS_ERROR_IF_NULL(need_swap_host_to_device);
479   MS_ERROR_IF_NULL(hash_index);
480   MS_ERROR_IF_NULL(cur_graph_running_step);
481   MS_ERROR_IF_NULL(latest_graph_running_step);
482   MS_ERROR_IF_NULL(embedding_device_cache);
483   MS_ERROR_IF_NULL(statistics_info);
484   auto &device_hash_map = embedding_cache_table_manager.device_hash_map_;
485   MS_ERROR_IF_NULL(device_hash_map);
486 
487   int index = kInvalidIndexValue;
488   if (device_hash_map->GetIndex(id, &index)) {
489     *need_swap_device_to_host = false;
490     *need_swap_host_to_device = false;
491     if (device_hash_map->hash_step(index) != data_step) {
492       statistics_info->hash_hit_count_++;
493       device_hash_map->set_hash_step(index, data_step);
494     }
495   } else {
496     int *device_to_host_index = embedding_device_cache->device_to_host_index.get();
497     int *device_to_host_ids = embedding_device_cache->device_to_host_ids.get();
498     int *host_to_device_index = embedding_device_cache->host_to_device_index.get();
499     int *host_to_device_ids = embedding_device_cache->host_to_device_ids.get();
500     MS_ERROR_IF_NULL(host_to_device_index);
501     MS_ERROR_IF_NULL(host_to_device_ids);
502     auto tmp_device_to_host_size = statistics_info->device_to_host_size_;
503     size_t retry_count = 0;
504     while (true) {
505       // Calculate the mapping of id to index.
506       index =
507         device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step, *cur_graph_running_step,
508                                    &statistics_info->device_to_host_size_, device_cache_need_wait_graph);
509       if (index == kInvalidIndexValue) {
510         const int64_t wait_interval = 10000;
511         *cur_graph_running_step = latest_graph_running_step->load();
512         std::this_thread::sleep_for(std::chrono::microseconds(wait_interval));
513         ++retry_count;
514         if (retry_count > kMaxRetryNum) {
515           MS_LOG(ERROR) << "Prefetch embedding cache timeout, please enlarge the vocab cache size.";
516           return false;
517         }
518         MS_LOG(DEBUG) << "There is no space in device cache, wait and retrying, current graph running step: "
519                       << *cur_graph_running_step << ", data step: " << data_step;
520         continue;
521       }
522       host_to_device_index[statistics_info->host_to_device_size_] = index;
523       host_to_device_ids[statistics_info->host_to_device_size_] = id;
524       statistics_info->host_to_device_size_++;
525       *need_swap_device_to_host = statistics_info->device_to_host_size_ > tmp_device_to_host_size;
526       break;
527     }
528   }
529 
530   *hash_index = index;
531   return true;
532 }
533 }  // namespace runtime
534 }  // namespace mindspore
535