/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
#include "ops/nn_op_name.h"
#include "ops/array_op_name.h"
#include "runtime/device/device_address_utils.h"
#include "runtime/graph_scheduler/actor/embedding_cache/device_dense_embedding_operation.h"
namespace mindspore {
namespace runtime {
using distributed::kInvalidIndexValue;
using kernel::Address;
using kernel::AddressPtrList;

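// Analyse the embedding cache state for one batch of feature ids: filter out ids that are out of the local
// embedding slice, compute the hit/miss information of the device cache and the local host cache, and record the
// swap (host <-> device) and id -> cache index mapping information for the missing ids.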
bool DeviceDenseEmbeddingOperation::AnalyseCache(int *batch_ids, const size_t batch_ids_num, size_t data_step,
                                                 const std::atomic_ulong *graph_running_step,
                                                 bool *device_cache_need_wait_graph, bool *host_cache_need_wait_graph,
                                                 int *indices, EmbeddingDeviceCache *embedding_device_cache,
                                                 EmbeddingHostCache *embedding_host_cache,
                                                 EmbeddingCacheStatisticsInfo *statistics_info) {
  MS_ERROR_IF_NULL(batch_ids);
  MS_ERROR_IF_NULL(graph_running_step);
  MS_ERROR_IF_NULL(embedding_device_cache);
  MS_ERROR_IF_NULL(statistics_info);

  statistics_info_->batch_id_count_ = batch_ids_num;
  std::unique_ptr<bool[]> out_range = std::make_unique<bool[]>(batch_ids_num);
  auto ret = memset_s(out_range.get(), batch_ids_num * sizeof(bool), 0, batch_ids_num * sizeof(bool));
  if (ret != EOK) {
    MS_LOG(ERROR) << "Memset failed, errno[" << ret << "]";
    return false;
  }

  // 1. Analyze the hit/miss info of the local host cache and device cache.
  RETURN_IF_FALSE_WITH_LOG(CheckCacheHitOrOutRange(batch_ids, batch_ids_num, indices, out_range.get(), data_step),
                           "Check cache hit or out range failed.");
  RETURN_IF_FALSE_WITH_LOG(actor_->ResetEmbeddingHashMap(), "Reset embedding hash map failed.");

  size_t cur_graph_running_step = graph_running_step->load();
  // 2. Calculate the swap and mapping (feature id -> cache index) information for the missing feature ids that need
  // to be inserted into the cache.
  for (size_t i = 0; i < batch_ids_num; i++) {
    if (out_range[i]) {
      continue;
    }
    (void)modified_ids_.insert(batch_ids[i]);
    bool need_swap_host_to_device = true;
    bool need_swap_device_to_host = true;
    int index = kInvalidIndexValue;
    RETURN_IF_FALSE_WITH_LOG(ParseDeviceData(batch_ids[i], &need_swap_device_to_host, &need_swap_host_to_device,
                                             &index, data_step, &cur_graph_running_step, graph_running_step,
                                             device_cache_need_wait_graph, embedding_device_cache, statistics_info),
                             "Parse device cache data failed.");
    indices[i] = index + local_device_cache_bounds_.first;
    if (need_swap_host_to_device) {
      RETURN_IF_FALSE_WITH_LOG(
        ParseHostDataHostToDevice(batch_ids[i], data_step, &cur_graph_running_step, graph_running_step,
                                  host_cache_need_wait_graph, embedding_host_cache, statistics_info),
        "Parse local host cache data (swap local host cache to device) failed.");
    }
    if (need_swap_device_to_host) {
      RETURN_IF_FALSE_WITH_LOG(
        ParseHostDataDeviceToHost(data_step, &cur_graph_running_step, graph_running_step, host_cache_need_wait_graph,
                                  embedding_device_cache, embedding_host_cache, statistics_info),
        "Parse local host cache data (swap device cache to local host) failed.");
    }
  }
  return true;
}

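// Swap evicted embeddings out of the device cache: look up the rows recorded in device_to_host_index on the
// device, copy them back to the host, and insert them into the local host cache.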
bool DeviceDenseEmbeddingOperation::PushCacheFromDeviceToLocalHost(const HashTableInfo &hash_info,
                                                                   const CacheAnalysis *cache_analysis) {
  MS_ERROR_IF_NULL(cache_analysis);
  auto statistics_info = cache_analysis->statistics_info_;
  auto embedding_device_cache = cache_analysis->embedding_device_cache_;
  auto embedding_host_cache = cache_analysis->embedding_host_cache_;
  MS_ERROR_IF_NULL(statistics_info);
  MS_ERROR_IF_NULL(embedding_device_cache);
  MS_ERROR_IF_NULL(embedding_host_cache);

  auto swap_indices_size = statistics_info->device_to_host_size_;
  if (swap_indices_size == 0) {
    return true;
  }

  auto device_cache_device_to_host_index = embedding_device_cache->device_to_host_index.get();
  auto host_cache_device_to_host_index = embedding_host_cache->device_to_host_index.get();
  MS_ERROR_IF_NULL(device_cache_device_to_host_index);
  MS_ERROR_IF_NULL(host_cache_device_to_host_index);
  auto hash_table_addr = reinterpret_cast<float *>(hash_info.address.addr);
  auto cache_vocab_size = hash_info.cache_vocab_size;
  auto host_hash_table_addr = hash_info.host_address;
  auto embedding_size = hash_info.embedding_size;
  auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
  if (swap_indices_size >
      embedding_cache_table_manager.batch_ids_num_ * embedding_cache_table_manager.multi_batch_threshold_) {
    MS_LOG(EXCEPTION) << "The swap size is greater than the size of the batch id buffer.";
  }
  RETURN_IF_FALSE_WITH_LOG(
    MemcpyHostToDeviceAsync(embedding_cache_table_manager.hash_swap_index_addr_, device_cache_device_to_host_index,
                            swap_indices_size * sizeof(int), device_context_, stream_id_),
    "Memcpy host to device asynchronously failed.");

  RETURN_IF_FALSE_WITH_LOG(
    LookupDeviceCache(embedding_cache_table_manager.hash_swap_index_addr_, hash_table_addr, swap_indices_size,
                      cache_vocab_size, embedding_size, embedding_cache_table_manager.hash_swap_value_addr_),
    "Lookup device cache failed.");

  RETURN_IF_FALSE_WITH_LOG(
    MemcpyDeviceToHostAsync(swap_out_data.get(), embedding_cache_table_manager.hash_swap_value_addr_,
                            swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
    "Memcpy device to host asynchronously failed.");

  MS_ERROR_IF_NULL(device_context_);
  MS_ERROR_IF_NULL(device_context_->device_res_manager_);
  RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
  RETURN_IF_FALSE_WITH_LOG(
    actor_->InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
                                 swap_out_data.get(), host_hash_table_addr),
    "Insert local host cache failed.");
  return true;
}

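// Swap missing embeddings into the device cache: look up the rows recorded in host_to_device_index in the local
// host cache, copy them to the device, and scatter them into the device hash table.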
bool DeviceDenseEmbeddingOperation::PullCacheFromLocalHostToDevice(const HashTableInfo &hash_info,
                                                                   const CacheAnalysis *cache_analysis) {
  MS_ERROR_IF_NULL(cache_analysis);
  auto statistics_info = cache_analysis->statistics_info_;
  auto embedding_device_cache = cache_analysis->embedding_device_cache_;
  auto embedding_host_cache = cache_analysis->embedding_host_cache_;
  MS_ERROR_IF_NULL(statistics_info);
  MS_ERROR_IF_NULL(embedding_device_cache);
  MS_ERROR_IF_NULL(embedding_host_cache);

  auto swap_indices_size = statistics_info->host_to_device_size_;
  if (swap_indices_size == 0) {
    return true;
  }

  auto host_cache_host_to_device_index = embedding_host_cache->host_to_device_index.get();
  auto device_cache_host_to_device_index = embedding_device_cache->host_to_device_index.get();
  MS_ERROR_IF_NULL(host_cache_host_to_device_index);
  MS_ERROR_IF_NULL(device_cache_host_to_device_index);

  auto embedding_size = hash_info.embedding_size;
  MS_ERROR_IF_NULL(hash_info.address.addr);
  auto hash_table_addr = reinterpret_cast<float *>(hash_info.address.addr);
  auto cache_vocab_size = hash_info.cache_vocab_size;
  auto host_hash_table_addr = hash_info.host_address;
  MS_ERROR_IF_NULL(host_hash_table_addr);
  auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
  RETURN_IF_FALSE_WITH_LOG(actor_->LookupLocalHostCache(embedding_size, swap_indices_size, host_hash_table_addr,
                                                        host_cache_host_to_device_index, swap_out_data.get()),
                           "Lookup local host cache failed.");

  if (swap_indices_size >
      embedding_cache_table_manager.batch_ids_num_ * embedding_cache_table_manager.multi_batch_threshold_) {
    MS_LOG(EXCEPTION) << "The swap size is greater than the size of the batch value buffer.";
  }
  RETURN_IF_FALSE_WITH_LOG(
    MemcpyHostToDeviceAsync(embedding_cache_table_manager.hash_swap_value_addr_, swap_out_data.get(),
                            swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
    "Memcpy host to device asynchronously failed.");
  RETURN_IF_FALSE_WITH_LOG(
    MemcpyHostToDeviceAsync(embedding_cache_table_manager.hash_swap_index_addr_, device_cache_host_to_device_index,
                            swap_indices_size * sizeof(int), device_context_, stream_id_),
    "Memcpy host to device asynchronously failed.");

  RETURN_IF_FALSE_WITH_LOG(UpdateDeviceCache(embedding_cache_table_manager.hash_swap_index_addr_,
                                             embedding_cache_table_manager.hash_swap_value_addr_, swap_indices_size,
                                             cache_vocab_size, embedding_size, hash_table_addr),
                           "Update device embedding cache failed.");
  MS_ERROR_IF_NULL(device_context_);
  MS_ERROR_IF_NULL(device_context_->device_res_manager_);
  RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
  return true;
}

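// Evenly split the embedding table of size vocab_size across server_num servers, recording the inclusive
// [begin, end] row bounds of each slice; the first (vocab_size % server_num) slices each get one extra row.
// For example, vocab_size = 10 and server_num = 3 yields the slices [0, 3], [4, 6] and [7, 9].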
void DeviceDenseEmbeddingOperation::GetRemoteEmbeddingSliceBound(
  size_t vocab_size, size_t server_num, std::vector<std::pair<size_t, size_t>> *remote_embedding_slice_bounds) {
  if (server_num == 0) {
    MS_LOG(EXCEPTION) << "The number of servers is at least 1, but got 0";
  }
  MS_EXCEPTION_IF_NULL(remote_embedding_slice_bounds);
  size_t average_slice_size = vocab_size / server_num;
  std::vector<size_t> remote_embedding_slice_sizes = std::vector<size_t>(server_num, average_slice_size);
  size_t rest_vocab_size = vocab_size % server_num;
  for (size_t i = 0; i < rest_vocab_size; i++) {
    remote_embedding_slice_sizes[i] += 1;
  }

  size_t begin;
  size_t end;
  for (size_t i = 0; i < server_num; i++) {
    if (i == 0) {
      begin = 0;
      end = remote_embedding_slice_sizes[0] - 1;
    } else {
      begin = remote_embedding_slice_bounds->at(i - 1).second + 1;
      end = begin + remote_embedding_slice_sizes[i] - 1;
    }
    (void)remote_embedding_slice_bounds->emplace_back(begin, end);
  }
}

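// Build a single-operator kernel graph holding a dynamic shape 'EmbeddingLookup' kernel, which is used to read
// rows from the device embedding cache by index.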
void DeviceDenseEmbeddingOperation::BuildEmbeddingCacheLookupKernel() {
  auto graph = std::make_shared<KernelGraph>();
  MS_EXCEPTION_IF_NULL(graph);
  graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
  embedding_cache_graphs_.push_back(graph);

  // 1. Create parameter nodes which are inputs of the embedding cache lookup kernel (operator name:
  // 'EmbeddingLookup').
  ParameterPtr input_param = NewParameter(graph, kFloat32, kTwoDimensionalShape);
  ParameterPtr input_indices = NewParameter(graph, kInt32, kOneDimensionalShape);
  ValueNodePtr offset_value_node = NewValueNode(0, device_context_, stream_id_);

  // 2. Create a CNode for operator EmbeddingLookup.
  PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
  emb_lookup_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
  emb_lookup_primitive->set_attr(kAttrOutputIsDynamicShape, MakeValue(true));

  std::vector<AnfNodePtr> emb_lookup_input_nodes{mindspore::NewValueNode(emb_lookup_primitive), input_param,
                                                 input_indices, offset_value_node};
  embedding_cache_lookup_node_ = graph->NewCNode(emb_lookup_input_nodes);
  MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node_);
  auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimensionalShape);
  embedding_cache_lookup_node_->set_abstract(abstract);

  // 3. Kernel build process.
  MS_EXCEPTION_IF_NULL(device_context_);
  MS_EXCEPTION_IF_NULL(device_context_->GetKernelExecutor(false));
  device_context_->GetKernelExecutor(false)->CreateKernel({embedding_cache_lookup_node_});
  AnfAlgo::SetStreamId(stream_id_, embedding_cache_lookup_node_.get());

  DeviceAddressUtils::CreateParameterDeviceAddress(device_context_, graph);
  DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context_, graph, false);
  DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context_, graph);
}

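// Build a single-operator kernel graph holding a dynamic shape 'ScatterUpdate' kernel, which is used to write
// rows into the device embedding cache by index.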
void DeviceDenseEmbeddingOperation::BuildEmbeddingCacheUpdateKernel() {
  auto graph = std::make_shared<KernelGraph>();
  MS_EXCEPTION_IF_NULL(graph);
  graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
  embedding_cache_graphs_.push_back(graph);

  // 1. Create parameter nodes which are inputs of the embedding cache update kernel (operator name: 'ScatterUpdate').
  ParameterPtr input_param = NewParameter(graph, kFloat32, kTwoDimensionalShape);
  ParameterPtr input_indices = NewParameter(graph, kInt32, kOneDimensionalShape);
  ParameterPtr update_values = NewParameter(graph, kFloat32, kTwoDimensionalShape);

  // 2. Create a CNode for operator ScatterUpdate.
  PrimitivePtr embedding_cache_update_primitive = std::make_shared<Primitive>(kScatterUpdateOpName);
  embedding_cache_update_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));

  std::vector<AnfNodePtr> embedding_cache_update_input_nodes{mindspore::NewValueNode(embedding_cache_update_primitive),
                                                             input_param, input_indices, update_values};
  embedding_cache_update_node_ = graph->NewCNode(embedding_cache_update_input_nodes);
  MS_EXCEPTION_IF_NULL(embedding_cache_update_node_);
  auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimensionalShape);
  embedding_cache_update_node_->set_abstract(abstract);

  // 3. Kernel build process.
  MS_EXCEPTION_IF_NULL(device_context_);
  MS_EXCEPTION_IF_NULL(device_context_->GetKernelExecutor(false));
  device_context_->GetKernelExecutor(false)->CreateKernel({embedding_cache_update_node_});
  AnfAlgo::SetStreamId(stream_id_, embedding_cache_update_node_.get());

  DeviceAddressUtils::CreateParameterDeviceAddress(device_context_, graph);
  DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context_, graph, false);
  DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context_, graph);
}

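// Launch the prebuilt 'EmbeddingLookup' kernel: gather indices_num rows of width embedding_size from
// embedding_cache (a cache_size x embedding_size table on the device) into outputs.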
bool DeviceDenseEmbeddingOperation::LookupDeviceCache(void *indices, void *embedding_cache, size_t indices_num,
                                                      size_t cache_size, size_t embedding_size, void *outputs) {
  MS_ERROR_IF_NULL(indices);
  MS_ERROR_IF_NULL(embedding_cache);
  MS_ERROR_IF_NULL(outputs);
  MS_ERROR_IF_NULL(embedding_cache_lookup_node_);

  // 1. Get input and output kernel tensors.
  std::vector<kernel::KernelTensor *> input_kernel_tensors =
    AnfAlgo::GetOrCreateAllInputKernelTensors(embedding_cache_lookup_node_);
  std::vector<kernel::KernelTensor *> output_kernel_tensors =
    AnfAlgo::GetOrCreateAllOutputKernelTensors(embedding_cache_lookup_node_);
  MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex0]);
  MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex1]);
  MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex2]);

  MS_EXCEPTION_IF_CHECK_FAIL((input_kernel_tensors.size() == kCacheOpInputNum),
                             "For op: " + embedding_cache_lookup_node_->fullname_with_scope() +
                               " needs 3 inputs, but got " + std::to_string(input_kernel_tensors.size()));
  MS_EXCEPTION_IF_CHECK_FAIL((output_kernel_tensors.size() == kCacheOpOutputNum),
                             "For op: " + embedding_cache_lookup_node_->fullname_with_scope() +
                               " should have 1 output, but got " + std::to_string(output_kernel_tensors.size()));

  std::vector<abstract::AbstractBasePtr> input_kernel_tensors_for_infer(kCacheOpInputNum, nullptr);
  for (size_t i = 0; i < kCacheOpInputNum; i++) {
    const auto &kernel_tensor = AnfAlgo::GetPrevNodeOutputKernelTensor(embedding_cache_lookup_node_, i);
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    input_kernel_tensors_for_infer[i] = kernel_tensor;
  }

  // 2. Update input shapes.
  ShapeVector input_param_shape = {SizeToLong(cache_size), SizeToLong(embedding_size)};
  input_kernel_tensors[kIndex0]->SetShape(std::make_shared<abstract::TensorShape>(std::move(input_param_shape)));
  ShapeVector input_indices_shape = {SizeToLong(indices_num)};
  input_kernel_tensors[kIndex1]->SetShape(std::make_shared<abstract::TensorShape>(std::move(input_indices_shape)));

  // 3. Infer shape for the embedding cache lookup kernel (operator name: 'EmbeddingLookup'), which is a dynamic
  // shape kernel.
  if (!InferOpShape(embedding_cache_lookup_node_, input_kernel_tensors, output_kernel_tensors,
                    input_kernel_tensors_for_infer)) {
    MS_LOG(ERROR) << "Infer operator shape failed, op name: " << embedding_cache_lookup_node_->fullname_with_scope();
    return false;
  }

  // 4. Do the embedding cache lookup on device. The op has exactly one output (checked above), so its device
  // pointer is set at kIndex0.
  input_kernel_tensors[kIndex0]->set_device_ptr(embedding_cache);
  input_kernel_tensors[kIndex1]->set_device_ptr(indices);
  output_kernel_tensors[kIndex0]->set_device_ptr(outputs);

  MS_ERROR_IF_NULL(device_context_);
  MS_ERROR_IF_NULL(device_context_->GetKernelExecutor(false));
  auto kernel_mod = AnfAlgo::GetKernelMod(embedding_cache_lookup_node_);
  auto stream = device_context_->device_res_manager_->GetStream(stream_id_);
  auto ret = device_context_->GetKernelExecutor(false)->LaunchKernel(embedding_cache_lookup_node_,
                                                                     input_kernel_tensors, {}, output_kernel_tensors,
                                                                     kernel_mod, stream);
  if (!ret) {
    MS_LOG(ERROR) << "Launch kernel: " << embedding_cache_lookup_node_->fullname_with_scope() << " failed.";
    return false;
  }
  return true;
}

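// Launch the prebuilt 'ScatterUpdate' kernel: scatter indices_num rows of width embedding_size from update_value
// into embedding_cache (a cache_size x embedding_size table on the device).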
bool DeviceDenseEmbeddingOperation::UpdateDeviceCache(void *indices, void *update_value, size_t indices_num,
                                                      size_t cache_size, size_t embedding_size,
                                                      void *embedding_cache) {
  MS_ERROR_IF_NULL(indices);
  MS_ERROR_IF_NULL(update_value);
  MS_ERROR_IF_NULL(embedding_cache);
  MS_ERROR_IF_NULL(embedding_cache_update_node_);

  // 1. Get input and output kernel tensors.
  std::vector<kernel::KernelTensor *> input_kernel_tensors =
    AnfAlgo::GetOrCreateAllInputKernelTensors(embedding_cache_update_node_);
  std::vector<kernel::KernelTensor *> output_kernel_tensors =
    AnfAlgo::GetOrCreateAllOutputKernelTensors(embedding_cache_update_node_);
  MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex0]);
  MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex1]);
  MS_EXCEPTION_IF_NULL(input_kernel_tensors[kIndex2]);

  MS_EXCEPTION_IF_CHECK_FAIL((input_kernel_tensors.size() == kCacheOpInputNum),
                             "For op: " + embedding_cache_update_node_->fullname_with_scope() +
                               " needs 3 inputs, but got " + std::to_string(input_kernel_tensors.size()));
  MS_EXCEPTION_IF_CHECK_FAIL((output_kernel_tensors.size() == kCacheOpOutputNum),
                             "For op: " + embedding_cache_update_node_->fullname_with_scope() +
                               " should have 1 output, but got " + std::to_string(output_kernel_tensors.size()));

  std::vector<abstract::AbstractBasePtr> input_kernel_tensors_for_infer(kCacheOpInputNum, nullptr);
  for (size_t i = 0; i < kCacheOpInputNum; i++) {
    const auto &kernel_tensor = AnfAlgo::GetPrevNodeOutputKernelTensor(embedding_cache_update_node_, i);
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    input_kernel_tensors_for_infer[i] = kernel_tensor;
  }

  // 2. Update input shapes.
  ShapeVector input_param_shape = {SizeToLong(cache_size), SizeToLong(embedding_size)};
  input_kernel_tensors[kIndex0]->SetShape(std::make_shared<abstract::TensorShape>(std::move(input_param_shape)));
  ShapeVector input_indices_shape = {SizeToLong(indices_num)};
  input_kernel_tensors[kIndex1]->SetShape(std::make_shared<abstract::TensorShape>(std::move(input_indices_shape)));
  ShapeVector update_values_shape = {SizeToLong(indices_num), SizeToLong(embedding_size)};
  input_kernel_tensors[kIndex2]->SetShape(std::make_shared<abstract::TensorShape>(std::move(update_values_shape)));

  // 3. Infer shape for the embedding cache update kernel (operator name: 'ScatterUpdate'), which is a dynamic
  // shape kernel.
  if (!InferOpShape(embedding_cache_update_node_, input_kernel_tensors, output_kernel_tensors,
                    input_kernel_tensors_for_infer)) {
    MS_LOG(ERROR) << "Infer operator shape failed, op name: " << embedding_cache_update_node_->fullname_with_scope();
    return false;
  }

390
391 // 4. Do update cache on device.
392 input_kernel_tensors[kIndex0]->set_device_ptr(embedding_cache);
393 input_kernel_tensors[kIndex1]->set_device_ptr(indices);
394 input_kernel_tensors[kIndex2]->set_device_ptr(update_value);
395 output_kernel_tensors[kIndex0]->set_device_ptr(embedding_cache);
396
397 MS_ERROR_IF_NULL(device_context_);
398 MS_ERROR_IF_NULL(device_context_->GetKernelExecutor(false));
399 auto kernel_mod = AnfAlgo::GetKernelMod(embedding_cache_update_node_);
400 auto stream = device_context_->device_res_manager_->GetStream(stream_id_);
401 auto ret = device_context_->GetKernelExecutor(false)->LaunchKernel(embedding_cache_update_node_, input_kernel_tensors,
402 {}, output_kernel_tensors, kernel_mod, stream);
403 if (!ret) {
404 MS_LOG(ERROR) << "Launch kernel: " << embedding_cache_update_node_->fullname_with_scope() << " failed.";
405 return false;
406 }
407 return true;
408 }
409
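// Check the cache hit or out-of-range state for a batch of ids in parallel: the batch is sliced across up to
// kMaxThreadNum threads, each running CheckCacheHitOrOutRangeFunc on its own slice.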
bool DeviceDenseEmbeddingOperation::CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_num,
                                                            int *hash_index, bool *out_range, size_t data_step) {
  MS_ERROR_IF_NULL(batch_ids);
  MS_ERROR_IF_NULL(hash_index);
  MS_ERROR_IF_NULL(out_range);

  size_t thread_num = batch_ids_num / kMaxIdsPerThread + 1;
  thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
  std::thread threads[kMaxThreadNum];
  size_t hash_hit_count[kMaxThreadNum] = {0};
  size_t i = 0;
  size_t offset = 0;

  for (; i < thread_num; ++i) {
    if (offset >= batch_ids_num) {
      break;
    }
    size_t proc_len = batch_ids_num / thread_num + (i < (batch_ids_num % thread_num) ? 1 : 0);
    threads[i] = std::thread(&DeviceDenseEmbeddingOperation::CheckCacheHitOrOutRangeFunc, this, batch_ids + offset,
                             proc_len, hash_index + offset, out_range + offset, hash_hit_count + i, data_step);
    offset += proc_len;
  }
  if (offset != batch_ids_num) {
    MS_LOG(WARNING) << "Not all ids were checked, total: " << batch_ids_num << ", checked: " << offset;
  }

  for (size_t j = 0; j < i; j++) {
    threads[j].join();
  }
  for (size_t j = 0; j < i; j++) {
    statistics_info_->hash_hit_count_ += hash_hit_count[j];
  }
  return true;
}

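// Per-thread worker of CheckCacheHitOrOutRange: ids outside the embedding slice owned by this process are mapped
// directly to indices outside the local device cache bounds and marked as out_range; in-range ids are handled by
// the subsequent per-id cache analysis.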
bool DeviceDenseEmbeddingOperation::CheckCacheHitOrOutRangeFunc(const int *batch_ids, const size_t batch_ids_num,
                                                                int *hash_index, bool *out_range,
                                                                size_t *hash_hit_count, size_t data_step) {
  MS_ERROR_IF_NULL(batch_ids);
  MS_ERROR_IF_NULL(hash_index);
  MS_ERROR_IF_NULL(out_range);
  MS_ERROR_IF_NULL(hash_hit_count);
  auto &device_hash_map = embedding_cache_table_manager.device_hash_map_;
  MS_ERROR_IF_NULL(device_hash_map);

  for (size_t i = 0; i < batch_ids_num; ++i) {
    if (batch_ids[i] < local_embedding_slice_bounds_.first) {
      hash_index[i] = batch_ids[i] - local_embedding_slice_bounds_.first + local_device_cache_bounds_.first;
      out_range[i] = true;
      continue;
    }
    if (batch_ids[i] >= local_embedding_slice_bounds_.second) {
      hash_index[i] = batch_ids[i] + local_device_cache_bounds_.second;
      out_range[i] = true;
      continue;
    }
  }
  return true;
}

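// Map a single feature id to a device cache index. On a cache hit the existing index is reused; on a miss an
// index is allocated (possibly evicting a stale id, which is recorded for a device-to-host swap) and the id is
// recorded for a host-to-device swap. If the cache is full, retry until the graph steps forward, failing after
// kMaxRetryNum retries.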
bool DeviceDenseEmbeddingOperation::ParseDeviceData(int id, bool *need_swap_device_to_host,
                                                    bool *need_swap_host_to_device, int *hash_index, size_t data_step,
                                                    size_t *cur_graph_running_step,
                                                    const std::atomic_ulong *latest_graph_running_step,
                                                    bool *device_cache_need_wait_graph,
                                                    EmbeddingDeviceCache *embedding_device_cache,
                                                    EmbeddingCacheStatisticsInfo *statistics_info) {
  MS_ERROR_IF_NULL(need_swap_device_to_host);
  MS_ERROR_IF_NULL(need_swap_host_to_device);
  MS_ERROR_IF_NULL(hash_index);
  MS_ERROR_IF_NULL(cur_graph_running_step);
  MS_ERROR_IF_NULL(latest_graph_running_step);
  MS_ERROR_IF_NULL(embedding_device_cache);
  MS_ERROR_IF_NULL(statistics_info);
  auto &device_hash_map = embedding_cache_table_manager.device_hash_map_;
  MS_ERROR_IF_NULL(device_hash_map);

  int index = kInvalidIndexValue;
  if (device_hash_map->GetIndex(id, &index)) {
    *need_swap_device_to_host = false;
    *need_swap_host_to_device = false;
    if (device_hash_map->hash_step(index) != data_step) {
      statistics_info->hash_hit_count_++;
      device_hash_map->set_hash_step(index, data_step);
    }
  } else {
    int *device_to_host_index = embedding_device_cache->device_to_host_index.get();
    int *device_to_host_ids = embedding_device_cache->device_to_host_ids.get();
    int *host_to_device_index = embedding_device_cache->host_to_device_index.get();
    int *host_to_device_ids = embedding_device_cache->host_to_device_ids.get();
    MS_ERROR_IF_NULL(host_to_device_index);
    MS_ERROR_IF_NULL(host_to_device_ids);
    auto tmp_device_to_host_size = statistics_info->device_to_host_size_;
    size_t retry_count = 0;
    while (true) {
      // Calculate the mapping of id to index.
      index =
        device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step, *cur_graph_running_step,
                                   &statistics_info->device_to_host_size_, device_cache_need_wait_graph);
      if (index == kInvalidIndexValue) {
        const int64_t wait_interval = 10000;
        *cur_graph_running_step = latest_graph_running_step->load();
        std::this_thread::sleep_for(std::chrono::microseconds(wait_interval));
        ++retry_count;
        if (retry_count > kMaxRetryNum) {
          MS_LOG(ERROR) << "Prefetch embedding cache timeout, please enlarge the vocab cache size.";
          return false;
        }
        MS_LOG(DEBUG) << "There is no space in the device cache, waiting and retrying, current graph running step: "
                      << *cur_graph_running_step << ", data step: " << data_step;
        continue;
      }
      host_to_device_index[statistics_info->host_to_device_size_] = index;
      host_to_device_ids[statistics_info->host_to_device_size_] = id;
      statistics_info->host_to_device_size_++;
      *need_swap_device_to_host = statistics_info->device_to_host_size_ > tmp_device_to_host_size;
      break;
    }
  }

  *hash_index = index;
  return true;
}
}  // namespace runtime
}  // namespace mindspore