/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/embedding_cache/device_embedding_operation.h"
#include <string>
#include "kernel/framework_utils.h"
#include "include/backend/optimizer/helper.h"
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
#include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"

namespace mindspore {
namespace runtime {
bool DeviceEmbeddingOperation::Initialize() {
  BuildEmbeddingCacheLookupKernel();
  BuildEmbeddingCacheUpdateKernel();
  return true;
}

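// Parse a feature id that needs to be swapped from the local host cache into the device cache: look up or allocate
// the id's index in the host hash map, record ids that must be evicted to the remote server, and mark ids seen for
// the first time so that their embeddings are initialized locally.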
bool DeviceEmbeddingOperation::ParseHostDataHostToDevice(int id, size_t data_step, size_t *cur_graph_running_step,
                                                         const std::atomic_ulong *latest_graph_running_step,
                                                         bool *host_cache_need_wait_graph,
                                                         EmbeddingHostCache *embedding_host_cache,
                                                         EmbeddingCacheStatisticsInfo *statistics_info) {
  MS_ERROR_IF_NULL(cur_graph_running_step);
  MS_ERROR_IF_NULL(latest_graph_running_step);
  MS_ERROR_IF_NULL(embedding_host_cache);
  MS_ERROR_IF_NULL(statistics_info);
  int *host_to_device_index = embedding_host_cache->host_to_device_index.get();
  MS_ERROR_IF_NULL(host_to_device_index);
  auto &host_hash_map = embedding_cache_table_manager.host_hash_map_;
  MS_ERROR_IF_NULL(host_hash_map);

  int index;
  if (host_hash_map->GetIndex(id, &index)) {
    if (host_hash_map->hash_step(index) != data_step) {
      host_hash_map->set_hash_step(index, data_step);
    }
    host_to_device_index[statistics_info->host_to_device_size_ - 1] = index;
  } else {
    int *host_to_server_index = embedding_host_cache->host_to_server_index.get();
    int *host_to_server_ids = embedding_host_cache->host_to_server_ids.get();
    auto tmp_host_to_server_size = statistics_info_->host_to_server_size_;
    size_t retry_count = 0;
    while (true) {
      // Calculate the mapping of id to index.
      index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step,
                                       *cur_graph_running_step, &(statistics_info->host_to_server_size_),
                                       host_cache_need_wait_graph);
      if (index == kInvalidIndexValue) {
        const int64_t wait_interval = 10000;
        *cur_graph_running_step = latest_graph_running_step->load();
        std::this_thread::sleep_for(std::chrono::microseconds(wait_interval));
        ++retry_count;
        if (retry_count > kMaxRetryNum) {
          MS_LOG(ERROR) << "Prefetch embedding cache timeout, please enlarge the vocab cache size.";
          return false;
        }
        MS_LOG(DEBUG) << "There is no space in the local host cache, wait and retry, current graph running step: "
                      << *cur_graph_running_step << ", data step: " << data_step;
        continue;
      }

      // The embedding vector of an id that has never been used before need not be evicted to the remote server.
      if (tmp_host_to_server_size < statistics_info_->host_to_server_size_) {
        if (modified_ids_.find(host_to_server_ids[tmp_host_to_server_size]) == modified_ids_.end()) {
          statistics_info_->host_to_server_size_ = tmp_host_to_server_size;
        }
      }

      host_to_device_index[statistics_info->host_to_device_size_ - 1] = index;

      // This feature id has never been seen before, so its value is initialized using the local random generator.
      // Initialize with a random value only when no checkpoint has been loaded.
      if (!embedding_cache_table_manager.checkpoint_load_status() &&
          initialized_ids_.find(id) == initialized_ids_.end()) {
        int *new_id_index = embedding_host_cache->new_id_index.get();
        MS_ERROR_IF_NULL(new_id_index);
        new_id_index[statistics_info->new_id_size_++] = index;
        (void)initialized_ids_.insert(id);
        // This feature id has already been initialized, so its latest value is kept on the remote server.
      } else {
        int *server_to_host_index = embedding_host_cache->server_to_host_index.get();
        int *server_to_host_ids = embedding_host_cache->server_to_host_ids.get();
        MS_ERROR_IF_NULL(server_to_host_index);
        MS_ERROR_IF_NULL(server_to_host_ids);
        server_to_host_index[statistics_info->server_to_host_size_] = index;
        server_to_host_ids[statistics_info->server_to_host_size_++] = id;
      }
      break;
    }
  }

  return true;
}

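// Parse a feature id that is being swapped out of the device cache into the local host cache: look up or allocate
// the id's index in the host hash map, evicting stale host cache entries to the remote server when the host cache
// is full.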
bool DeviceEmbeddingOperation::ParseHostDataDeviceToHost(size_t data_step, size_t *cur_graph_running_step,
                                                         const std::atomic_ulong *latest_graph_running_step,
                                                         bool *host_cache_need_wait_graph,
                                                         EmbeddingDeviceCache *embedding_device_cache,
                                                         EmbeddingHostCache *embedding_host_cache,
                                                         EmbeddingCacheStatisticsInfo *statistics_info) {
  MS_ERROR_IF_NULL(cur_graph_running_step);
  MS_ERROR_IF_NULL(latest_graph_running_step);
  MS_ERROR_IF_NULL(embedding_device_cache);
  MS_ERROR_IF_NULL(embedding_host_cache);
  MS_ERROR_IF_NULL(statistics_info);

  int *device_to_host_ids = embedding_device_cache->device_to_host_ids.get();
  int *device_to_host_index = embedding_host_cache->device_to_host_index.get();
  MS_ERROR_IF_NULL(device_to_host_ids);
  MS_ERROR_IF_NULL(device_to_host_index);

  auto &host_hash_map = embedding_cache_table_manager.host_hash_map_;
  MS_ERROR_IF_NULL(host_hash_map);
  int swap_device_to_host_id = device_to_host_ids[statistics_info->device_to_host_size_ - 1];
  int index;
  if (host_hash_map->GetIndex(swap_device_to_host_id, &index)) {
    if (host_hash_map->hash_step(index) != data_step) {
      host_hash_map->set_hash_step(index, data_step);
    }
    device_to_host_index[statistics_info->device_to_host_size_ - 1] = index;
  } else {
    int *host_to_server_index = embedding_host_cache->host_to_server_index.get();
    int *host_to_server_ids = embedding_host_cache->host_to_server_ids.get();
    auto tmp_host_to_server_size = statistics_info_->host_to_server_size_;
    size_t retry_count = 0;
    while (true) {
      // Calculate the mapping of id to index.
      index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, data_step,
                                       *cur_graph_running_step, &statistics_info->host_to_server_size_,
                                       host_cache_need_wait_graph);
      if (index == kInvalidIndexValue) {
        const int64_t wait_interval = 10000;
        *cur_graph_running_step = latest_graph_running_step->load();
        std::this_thread::sleep_for(std::chrono::microseconds(wait_interval));
        ++retry_count;
        if (retry_count > kMaxRetryNum) {
          MS_LOG(ERROR) << "Prefetch embedding cache timeout, please enlarge the vocab cache size.";
          return false;
        }
        MS_LOG(DEBUG) << "There is no space in the local host cache, wait and retry, current graph running step: "
                      << *cur_graph_running_step << ", data step: " << data_step;
        continue;
      }

      // The embedding vector of an id that has never been used before need not be evicted to the remote server.
      if (tmp_host_to_server_size < statistics_info_->host_to_server_size_) {
        if (modified_ids_.find(host_to_server_ids[tmp_host_to_server_size]) == modified_ids_.end()) {
          statistics_info_->host_to_server_size_ = tmp_host_to_server_size;
        }
      }

      device_to_host_index[statistics_info->device_to_host_size_ - 1] = index;
      break;
    }
  }

  return true;
}

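// Asynchronously copy 'size' bytes from host memory 'src' to device memory 'dst' on stream 'stream_id', using a
// temporary device address created on the given device context.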
bool DeviceEmbeddingOperation::MemcpyHostToDeviceAsync(void *dst, const void *src, size_t size,
                                                       const DeviceContext *device_context, size_t stream_id) {
  MS_ERROR_IF_NULL(dst);
  MS_ERROR_IF_NULL(src);
  MS_ERROR_IF_NULL(device_context);
  MS_ERROR_IF_NULL(device_context->device_res_manager_);

  void *device_ptr = dst;
  const void *host_ptr = src;

  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    device_ptr, size, Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_ERROR_IF_NULL(device_address);
  RETURN_IF_FALSE_WITH_LOG(device_address->AsyncHostToDevice({}, size, kTypeUnknown, host_ptr, stream_id),
                           "Async memcpy host to device failed.");

  return true;
}

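// Asynchronously copy 'size' bytes from device memory 'src' to host memory 'dst' on stream 'stream_id'.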
bool DeviceEmbeddingOperation::MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size,
                                                       const DeviceContext *device_context, size_t stream_id) {
  MS_ERROR_IF_NULL(dst);
  MS_ERROR_IF_NULL(src);
  MS_ERROR_IF_NULL(device_context);
  MS_ERROR_IF_NULL(device_context->device_res_manager_);

  void *device_ptr = const_cast<void *>(src);
  void *host_ptr = dst;

  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    device_ptr, size, Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_ERROR_IF_NULL(device_address);
  RETURN_IF_FALSE_WITH_LOG(device_address->AsyncDeviceToHost({}, size, kTypeUnknown, host_ptr, stream_id),
                           "Async memcpy device to host failed.");

  return true;
}

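// Create a new Parameter node with the given type and shape on the kernel graph, set its abstract and kernel build
// info (default format), and append it to the graph's mutable inputs.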
ParameterPtr DeviceEmbeddingOperation::NewParameter(const KernelGraphPtr &graph, TypePtr type,
                                                    const ShapeVector &shape) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(type);

  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape);
  param->set_abstract(abstract);

  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  std::vector<std::string> formats = {kOpFormat_DEFAULT};
  std::vector<TypeId> types = {type->type_id()};
  kernel_build_info_builder->SetOutputsFormat(formats);
  kernel_build_info_builder->SetOutputsDeviceType(types);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());

  auto mutable_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(mutable_inputs);
  mutable_inputs->push_back(param);

  return param;
}

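// Create an int32 scalar ValueNode holding 'value': build its kernel info, allocate device memory for its output,
// copy the value to the device on 'stream_id' and synchronize the stream.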
ValueNodePtr DeviceEmbeddingOperation::NewValueNode(int64_t value, const DeviceContext *device_context,
                                                    size_t stream_id) {
  MS_EXCEPTION_IF_NULL(device_context);

  auto tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(value), kInt32);
  auto value_node = mindspore::NewValueNode(tensor);
  value_node->set_abstract(tensor->ToAbstract());

  // Create kernel build info.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  std::vector<std::string> formats = {kOpFormat_DEFAULT};
  std::vector<TypeId> types = {kInt32->type_id()};
  kernel_build_info_builder->SetOutputsFormat(formats);
  kernel_build_info_builder->SetOutputsDeviceType(types);

  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  value_node->set_kernel_info(kernel_info);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get());

  // Create device address.
  size_t output_idx = 0;
  size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);

  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  auto value_addr = device_context->device_res_manager_->AllocateMemory(tensor_size);
  MS_EXCEPTION_IF_NULL(value_addr);

  const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
    {value_node, output_idx}, value_addr, tensor_size, output_format, output_type_id,
    trans::GetRuntimePaddingShape(value_node, output_idx), device_context->device_context_key().device_name_,
    device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(address);

  // Sync tensor value.
  MS_EXCEPTION_IF_CHECK_FAIL(address->AsyncHostToDevice({}, tensor_size, output_type_id, tensor->data_c(), stream_id),
                             "Async memcpy host to device failed.");
  MS_EXCEPTION_IF_CHECK_FAIL(device_context->device_res_manager_->SyncStream(stream_id), "Synchronize stream failed.");

  address->set_from_persistent_mem(true);
  AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());

  return value_node;
}

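// Infer the output shape of a dynamic-shape kernel from its input kernel tensors, update the shape of the output
// kernel tensors, and resize the kernel mod; return false if Resize fails.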
bool DeviceEmbeddingOperation::InferOpShape(
  const CNodePtr &kernel, const std::vector<kernel::KernelTensor *> &input_kernel_tensors,
  const std::vector<kernel::KernelTensor *> &output_kernel_tensors,
  const std::vector<abstract::AbstractBasePtr> &input_kernel_tensors_for_infer) {
  MS_ERROR_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_ERROR_IF_NULL(kernel_mod);
  // 1. Infer the shape of the operator's output.
  auto base_shape = opt::dynamic_shape::InferShape(kernel_mod->primitive(), input_kernel_tensors_for_infer);
  MS_ERROR_IF_NULL(base_shape);
  MS_LOG(DEBUG) << "End InferShape for kernel: " << kernel->fullname_with_scope()
                << ", shape: " << base_shape->ToString();

  // 2. Update the shape of the output kernel tensors.
  opt::dynamic_shape::UpdateKernelTensorShape(base_shape, output_kernel_tensors);

  // 3. Resize the kernel mod.
  MS_LOG(DEBUG) << "Begin Resize kernel mod for kernel: " << kernel->fullname_with_scope();
  int ret = kernel_mod->Resize(input_kernel_tensors, output_kernel_tensors);
  MS_LOG(DEBUG) << "End Resize kernel mod for kernel: " << kernel->fullname_with_scope()
                << ", the output size list: " << kernel_mod->GetOutputSizeList();
  if (ret != kernel::KRET_OK) {
    MS_LOG(ERROR) << "Resize failed for kernel: " << kernel->fullname_with_scope();
    return false;
  }
  return true;
}
}  // namespace runtime
}  // namespace mindspore