/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/embedding_cache/device_embedding_operation.h"
#include <string>
#include "kernel/framework_utils.h"
#include "include/backend/optimizer/helper.h"
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"

namespace mindspore {
namespace runtime {
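// Build the kernels used to look up and update the embedding cache on the device.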
bool DeviceEmbeddingOperation::Initialize() {
  BuildEmbeddingCacheLookupKernel();
  BuildEmbeddingCacheUpdateKernel();
  return true;
}

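// Map a feature id that needs to be swapped from the local host cache into the device cache to a slot (index) in
// the host hash map. If the id is not resident in the host cache, allocate a slot (evicting a modified entry to the
// remote server when the cache is full, retrying while no slot is free) and mark the id either for local random
// initialization (a brand-new id with no checkpoint loaded) or for fetching its latest value from the remote server.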
bool DeviceEmbeddingOperation::ParseHostDataHostToDevice(int id, size_t data_step, size_t *cur_graph_running_step,
                                                         const std::atomic_ulong *latest_graph_running_step,
                                                         bool *host_cache_need_wait_graph,
                                                         EmbeddingHostCache *embedding_host_cache,
                                                         EmbeddingCacheStatisticsInfo *statistics_info) {
  MS_ERROR_IF_NULL(cur_graph_running_step);
  MS_ERROR_IF_NULL(latest_graph_running_step);
  MS_ERROR_IF_NULL(embedding_host_cache);
  MS_ERROR_IF_NULL(statistics_info);
  int *host_to_device_index = embedding_host_cache->host_to_device_index.get();
  MS_ERROR_IF_NULL(host_to_device_index);
  auto &host_hash_map = embedding_cache_table_manager.host_hash_map_;
  MS_ERROR_IF_NULL(host_hash_map);

  int index;
  if (host_hash_map->GetIndex(id, &index)) {
    if (host_hash_map->hash_step(index) != data_step) {
      host_hash_map->set_hash_step(index, data_step);
    }
    host_to_device_index[statistics_info->host_to_device_size_ - 1] = index;
  } else {
    int *host_to_server_index = embedding_host_cache->host_to_server_index.get();
    int *host_to_server_ids = embedding_host_cache->host_to_server_ids.get();
    auto tmp_host_to_server_size = statistics_info_->host_to_server_size_;
    size_t retry_count = 0;
    while (true) {
      // Calculate the mapping of id to index.
      index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step, *cur_graph_running_step,
                                       &(statistics_info->host_to_server_size_), host_cache_need_wait_graph);
      if (index == kInvalidIndexValue) {
        const int64_t wait_interval = 10000;
        *cur_graph_running_step = latest_graph_running_step->load();
        std::this_thread::sleep_for(std::chrono::microseconds(wait_interval));
        ++retry_count;
        if (retry_count > kMaxRetryNum) {
          MS_LOG(ERROR) << "Prefetch embedding cache timeout, please enlarge the vocab cache size.";
          return false;
        }
        MS_LOG(DEBUG) << "There is no space in the local host cache, waiting and retrying, current graph running step: "
                      << *cur_graph_running_step << ", data step: " << data_step;
        continue;
      }

      // The embedding vector of an id that has never been used before does not need to be evicted to the remote
      // server.
      if (tmp_host_to_server_size < statistics_info_->host_to_server_size_) {
        if (modified_ids_.find(host_to_server_ids[tmp_host_to_server_size]) == modified_ids_.end()) {
          statistics_info_->host_to_server_size_ = tmp_host_to_server_size;
        }
      }

      host_to_device_index[statistics_info->host_to_device_size_ - 1] = index;

      // This feature id has never been seen before, so its value is initialized by the local random generator.
      // Initialize with a random value only when no checkpoint has been loaded.
      if (!embedding_cache_table_manager.checkpoint_load_status() &&
          initialized_ids_.find(id) == initialized_ids_.end()) {
        int *new_id_index = embedding_host_cache->new_id_index.get();
        MS_ERROR_IF_NULL(new_id_index);
        new_id_index[statistics_info->new_id_size_++] = index;
        (void)initialized_ids_.insert(id);
        // This feature id has already been initialized, so its latest value is kept on the remote server.
      } else {
        int *server_to_host_index = embedding_host_cache->server_to_host_index.get();
        int *server_to_host_ids = embedding_host_cache->server_to_host_ids.get();
        MS_ERROR_IF_NULL(server_to_host_index);
        MS_ERROR_IF_NULL(server_to_host_ids);
        server_to_host_index[statistics_info->server_to_host_size_] = index;
        server_to_host_ids[statistics_info->server_to_host_size_++] = id;
      }
      break;
    }
  }

  return true;
}

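// Map a feature id that is being swapped out of the device cache to a slot (index) in the host hash map,
// allocating a slot (and evicting a modified entry to the remote server) when the id is not already resident
// in the local host cache.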
bool DeviceEmbeddingOperation::ParseHostDataDeviceToHost(size_t data_step, size_t *cur_graph_running_step,
                                                         const std::atomic_ulong *latest_graph_running_step,
                                                         bool *host_cache_need_wait_graph,
                                                         EmbeddingDeviceCache *embedding_device_cache,
                                                         EmbeddingHostCache *embedding_host_cache,
                                                         EmbeddingCacheStatisticsInfo *statistics_info) {
  MS_ERROR_IF_NULL(cur_graph_running_step);
  MS_ERROR_IF_NULL(latest_graph_running_step);
  MS_ERROR_IF_NULL(embedding_device_cache);
  MS_ERROR_IF_NULL(embedding_host_cache);
  MS_ERROR_IF_NULL(statistics_info);

  int *device_to_host_ids = embedding_device_cache->device_to_host_ids.get();
  int *device_to_host_index = embedding_host_cache->device_to_host_index.get();
  MS_ERROR_IF_NULL(device_to_host_ids);
  MS_ERROR_IF_NULL(device_to_host_index);

  auto &host_hash_map = embedding_cache_table_manager.host_hash_map_;
  MS_ERROR_IF_NULL(host_hash_map);
  int swap_device_to_host_id = device_to_host_ids[statistics_info->device_to_host_size_ - 1];
  int index;
  if (host_hash_map->GetIndex(swap_device_to_host_id, &index)) {
    if (host_hash_map->hash_step(index) != data_step) {
      host_hash_map->set_hash_step(index, data_step);
    }
    device_to_host_index[statistics_info->device_to_host_size_ - 1] = index;
  } else {
    int *host_to_server_index = embedding_host_cache->host_to_server_index.get();
    int *host_to_server_ids = embedding_host_cache->host_to_server_ids.get();
    auto tmp_host_to_server_size = statistics_info_->host_to_server_size_;
    size_t retry_count = 0;
    while (true) {
      // Calculate the mapping of id to index.
      index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, data_step,
                                       *cur_graph_running_step, &statistics_info->host_to_server_size_,
                                       host_cache_need_wait_graph);
      if (index == kInvalidIndexValue) {
        const int64_t wait_interval = 10000;
        *cur_graph_running_step = latest_graph_running_step->load();
        std::this_thread::sleep_for(std::chrono::microseconds(wait_interval));
        ++retry_count;
        if (retry_count > kMaxRetryNum) {
          MS_LOG(ERROR) << "Prefetch embedding cache timeout, please enlarge the vocab cache size.";
          return false;
        }
        MS_LOG(DEBUG) << "There is no space in the local host cache, waiting and retrying, current graph running step: "
                      << *cur_graph_running_step << ", data step: " << data_step;
        continue;
      }

      // The embedding vector of an id that has never been used before does not need to be evicted to the remote
      // server.
      if (tmp_host_to_server_size < statistics_info_->host_to_server_size_) {
        if (modified_ids_.find(host_to_server_ids[tmp_host_to_server_size]) == modified_ids_.end()) {
          statistics_info_->host_to_server_size_ = tmp_host_to_server_size;
        }
      }

      device_to_host_index[statistics_info->device_to_host_size_ - 1] = index;
      break;
    }
  }

  return true;
}

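// Asynchronously copy `size` bytes from host memory `src` to device memory `dst` on stream `stream_id`, by wrapping
// the destination in a temporary device address created from the given device context.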
bool DeviceEmbeddingOperation::MemcpyHostToDeviceAsync(void *dst, const void *src, size_t size,
                                                       const DeviceContext *device_context, size_t stream_id) {
  MS_ERROR_IF_NULL(dst);
  MS_ERROR_IF_NULL(src);
  MS_ERROR_IF_NULL(device_context);
  MS_ERROR_IF_NULL(device_context->device_res_manager_);

  void *device_ptr = dst;
  const void *host_ptr = src;

  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    device_ptr, size, Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_ERROR_IF_NULL(device_address);
  RETURN_IF_FALSE_WITH_LOG(device_address->AsyncHostToDevice({}, size, kTypeUnknown, host_ptr, stream_id),
                           "Async memcpy host to device failed.");

  return true;
}

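// Asynchronously copy `size` bytes from device memory `src` to host memory `dst` on stream `stream_id`, by wrapping
// the source in a temporary device address created from the given device context.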
bool DeviceEmbeddingOperation::MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size,
                                                       const DeviceContext *device_context, size_t stream_id) {
  MS_ERROR_IF_NULL(dst);
  MS_ERROR_IF_NULL(src);
  MS_ERROR_IF_NULL(device_context);
  MS_ERROR_IF_NULL(device_context->device_res_manager_);

  void *device_ptr = const_cast<void *>(src);
  void *host_ptr = dst;

  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    device_ptr, size, Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_ERROR_IF_NULL(device_address);
  RETURN_IF_FALSE_WITH_LOG(device_address->AsyncDeviceToHost({}, size, kTypeUnknown, host_ptr, stream_id),
                           "Async memcpy device to host failed.");

  return true;
}

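// Create a new parameter with the given type and shape for the kernel graph, attach a default-format kernel build
// info to it, and append it to the graph's mutable inputs.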
ParameterPtr DeviceEmbeddingOperation::NewParameter(const KernelGraphPtr &graph, TypePtr type,
                                                    const ShapeVector &shape) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(type);

  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape);
  param->set_abstract(abstract);

  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  std::vector<std::string> formats = {kOpFormat_DEFAULT};
  std::vector<TypeId> types = {type->type_id()};
  kernel_build_info_builder->SetOutputsFormat(formats);
  kernel_build_info_builder->SetOutputsDeviceType(types);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());

  auto mutable_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(mutable_inputs);
  mutable_inputs->push_back(param);

  return param;
}

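// Create an int32 value node holding `value`, attach kernel info and a default-format kernel build info, allocate
// persistent device memory for its output, and copy the value to the device on stream `stream_id`.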
ValueNodePtr DeviceEmbeddingOperation::NewValueNode(int64_t value, const DeviceContext *device_context,
                                                    size_t stream_id) {
  MS_EXCEPTION_IF_NULL(device_context);

  auto tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(value), kInt32);
  auto value_node = mindspore::NewValueNode(tensor);
  value_node->set_abstract(tensor->ToAbstract());

  // Create kernel build info.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  std::vector<std::string> formats = {kOpFormat_DEFAULT};
  std::vector<TypeId> types = {kInt32->type_id()};
  kernel_build_info_builder->SetOutputsFormat(formats);
  kernel_build_info_builder->SetOutputsDeviceType(types);

  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  value_node->set_kernel_info(kernel_info);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get());

  // Create device address.
  size_t output_idx = 0;
  size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);

  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  auto value_addr = device_context->device_res_manager_->AllocateMemory(tensor_size);
  MS_EXCEPTION_IF_NULL(value_addr);

  const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
    {value_node, output_idx}, value_addr, tensor_size, output_format, output_type_id,
    trans::GetRuntimePaddingShape(value_node, output_idx), device_context->device_context_key().device_name_,
    device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(address);

  // Sync tensor value.
  MS_EXCEPTION_IF_CHECK_FAIL(address->AsyncHostToDevice({}, tensor_size, output_type_id, tensor->data_c(), stream_id),
                             "Async memcpy host to device failed.");
  MS_EXCEPTION_IF_CHECK_FAIL(device_context->device_res_manager_->SyncStream(stream_id), "Synchronize stream failed.");

  address->set_from_persistent_mem(true);
  AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());

  return value_node;
}

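// Infer the output shape of a kernel from its input kernel tensors, update the output kernel tensors with the
// inferred shape, and resize the kernel mod; return false if the resize fails.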
bool DeviceEmbeddingOperation::InferOpShape(
  const CNodePtr &kernel, const std::vector<kernel::KernelTensor *> &input_kernel_tensors,
  const std::vector<kernel::KernelTensor *> &output_kernel_tensors,
  const std::vector<abstract::AbstractBasePtr> &input_kernel_tensors_for_infer) {
  MS_ERROR_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_ERROR_IF_NULL(kernel_mod);
  // 1. Infer the operator's output shape.
  auto base_shape = opt::dynamic_shape::InferShape(kernel_mod->primitive(), input_kernel_tensors_for_infer);
  MS_ERROR_IF_NULL(base_shape);
  MS_LOG(DEBUG) << "End InferShape for kernel: " << kernel->fullname_with_scope()
                << ", shape: " << base_shape->ToString();

  // 2. Update the shape of the output kernel tensors.
  opt::dynamic_shape::UpdateKernelTensorShape(base_shape, output_kernel_tensors);

  // 3. Resize the kernel mod.
  MS_LOG(DEBUG) << "Begin Resize kernel mod for kernel: " << kernel->fullname_with_scope();
  int ret = kernel_mod->Resize(input_kernel_tensors, output_kernel_tensors);
  MS_LOG(DEBUG) << "End Resize kernel mod for kernel: " << kernel->fullname_with_scope()
                << ", the output size list: " << kernel_mod->GetOutputSizeList();
  if (ret != kernel::KRET_OK) {
    MS_LOG(ERROR) << "Resize failed for kernel: " << kernel->fullname_with_scope();
    return false;
  }
  return true;
}
}  // namespace runtime
}  // namespace mindspore