/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <string>
#include <limits>
#include "runtime/graph_scheduler/actor/embedding_cache/device_sparse_embedding_operation.h"
#include "ops/sparse_op_name.h"
#include "ir/map_tensor.h"

namespace mindspore {
namespace runtime {
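// Initialize the base device embedding operation and additionally build the erase kernel, which is used to evict
// swapped-out ids from the device hash table in sparse mode.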
bool DeviceSparseEmbeddingOperation::Initialize() {
  RETURN_IF_FALSE_WITH_LOG(DeviceEmbeddingOperation::Initialize(), "Initialize device embedding operation failed.");
  BuildEmbeddingCacheEraseKernel();
  return true;
}

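// Swap the embeddings evicted from the device cache back to the local host cache: look up the embeddings of the
// swapped-out ids from the device hash table, erase those ids from the table, then copy the embeddings back to the
// host and insert them into the local host cache.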
bool DeviceSparseEmbeddingOperation::PushCacheFromDeviceToLocalHost(const HashTableInfo &hash_info,
                                                                    const CacheAnalysis *cache_analysis) {
  MS_ERROR_IF_NULL(cache_analysis);
  auto statistics_info = cache_analysis->statistics_info_;
  auto embedding_device_cache = cache_analysis->embedding_device_cache_;
  auto embedding_host_cache = cache_analysis->embedding_host_cache_;
  MS_ERROR_IF_NULL(statistics_info);
  MS_ERROR_IF_NULL(embedding_device_cache);
  MS_ERROR_IF_NULL(embedding_host_cache);

  auto swap_indices_size = statistics_info->device_to_host_size_;
  if (swap_indices_size == 0) {
    return true;
  }

  auto device_cache_device_to_host_ids = embedding_device_cache->device_to_host_ids.get();
  auto host_cache_device_to_host_index = embedding_host_cache->device_to_host_index.get();
  MS_ERROR_IF_NULL(device_cache_device_to_host_ids);
  MS_ERROR_IF_NULL(host_cache_device_to_host_index);
  auto hash_table_addr = reinterpret_cast<float *>(hash_info.address.addr);
  auto host_hash_table_addr = hash_info.host_address;
  auto embedding_size = hash_info.embedding_size;
  auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);

  // Copy the swapped-out origin ids into the temporary device buffer of indices.
  int *tmp_swap_ids = embedding_cache_table_manager.hash_swap_index_addr_;
  RETURN_IF_FALSE_WITH_LOG(MemcpyHostToDeviceAsync(tmp_swap_ids, device_cache_device_to_host_ids,
                                                   swap_indices_size * sizeof(int), device_context_, stream_id_),
                           "Memcpy host to device asynchronously failed.");

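  // Look up the embeddings of the swapped-out ids from the device hash table into the temporary swap value buffer.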
  RETURN_IF_FALSE_WITH_LOG(LookupDeviceCache(hash_info.device_address, tmp_swap_ids, hash_table_addr, swap_indices_size,
                                             embedding_size, embedding_cache_table_manager.hash_swap_value_addr_),
                           "Lookup device cache failed.");

  // Erase the swapped-out ids from the device hash table.
  RETURN_IF_FALSE_WITH_LOG(EraseDeviceCache(tmp_swap_ids, swap_indices_size, hash_table_addr, hash_info.device_address),
                           "Erase device cache failed.");

  RETURN_IF_FALSE_WITH_LOG(
    MemcpyDeviceToHostAsync(swap_out_data.get(), embedding_cache_table_manager.hash_swap_value_addr_,
                            swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
    "Memcpy device to host asynchronously failed.");

  MS_ERROR_IF_NULL(device_context_);
  MS_ERROR_IF_NULL(device_context_->device_res_manager_);
  RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
  RETURN_IF_FALSE_WITH_LOG(
    actor_->InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
                                 swap_out_data.get(), host_hash_table_addr),
    "Insert local host cache failed.");
  return true;
}

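// Swap the embeddings missed in the device cache from the local host cache to the device: look up the embeddings of
// the swapped-in ids from the local host cache, copy the ids and embeddings to the device, and insert them into the
// device hash table through the update kernel.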
bool DeviceSparseEmbeddingOperation::PullCacheFromLocalHostToDevice(const HashTableInfo &hash_info,
                                                                    const CacheAnalysis *cache_analysis) {
  MS_ERROR_IF_NULL(cache_analysis);
  auto statistics_info = cache_analysis->statistics_info_;
  auto embedding_device_cache = cache_analysis->embedding_device_cache_;
  auto embedding_host_cache = cache_analysis->embedding_host_cache_;
  MS_ERROR_IF_NULL(statistics_info);
  MS_ERROR_IF_NULL(embedding_device_cache);
  MS_ERROR_IF_NULL(embedding_host_cache);

  auto swap_indices_size = statistics_info->host_to_device_size_;
  if (swap_indices_size == 0) {
    return true;
  }

  auto host_cache_host_to_device_index = embedding_host_cache->host_to_device_index.get();
  auto device_cache_host_to_device_ids = embedding_device_cache->host_to_device_ids.get();
  MS_ERROR_IF_NULL(host_cache_host_to_device_index);
  MS_ERROR_IF_NULL(device_cache_host_to_device_ids);

  auto embedding_size = hash_info.embedding_size;
  MS_ERROR_IF_NULL(hash_info.address.addr);
  auto hash_table_addr = reinterpret_cast<float *>(hash_info.address.addr);
  MS_ERROR_IF_NULL(hash_info.host_address);
  auto host_hash_table_addr = hash_info.host_address;
  auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
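  // Look up the embeddings of the swapped-in ids from the local host cache into a temporary host buffer.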
  RETURN_IF_FALSE_WITH_LOG(actor_->LookupLocalHostCache(embedding_size, swap_indices_size, host_hash_table_addr,
                                                        host_cache_host_to_device_index, swap_out_data.get()),
                           "Lookup local host cache failed.");

  RETURN_IF_FALSE_WITH_LOG(
    MemcpyHostToDeviceAsync(embedding_cache_table_manager.hash_swap_value_addr_, swap_out_data.get(),
                            swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
    "Memcpy host to device asynchronously failed.");
  // Copy the swapped-in origin ids into the temporary device buffer of indices.
  RETURN_IF_FALSE_WITH_LOG(
    MemcpyHostToDeviceAsync(embedding_cache_table_manager.hash_swap_index_addr_, device_cache_host_to_device_ids,
                            swap_indices_size * sizeof(int), device_context_, stream_id_),
    "Memcpy host to device asynchronously failed.");

  RETURN_IF_FALSE_WITH_LOG(UpdateDeviceCache(embedding_cache_table_manager.hash_swap_index_addr_,
                                             embedding_cache_table_manager.hash_swap_value_addr_, swap_indices_size,
                                             embedding_size, hash_table_addr, hash_info.device_address),
                           "Update device embedding cache failed.");
  MS_ERROR_IF_NULL(device_context_);
  MS_ERROR_IF_NULL(device_context_->device_res_manager_);
  RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
  return true;
}

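// Compute the id range served by each remote server. Sparse mode currently supports only a single server, so the
// whole id range [0, INT32_MAX] is assigned to it without splitting.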
void DeviceSparseEmbeddingOperation::GetRemoteEmbeddingSliceBound(
  size_t vocab_size, size_t server_num, std::vector<std::pair<size_t, size_t>> *remote_embedding_slice_bounds) {
  if (server_num != 1) {
    MS_LOG(EXCEPTION)
      << "Sparse mode does not support multiple servers currently, so the server number should be 1, but got: "
      << server_num;
  }

  MS_EXCEPTION_IF_NULL(remote_embedding_slice_bounds);
  // Sparse mode does not support multiple servers currently, so the ids do not need to be split and the id range is
  // specified as [0, INT32_MAX].
  (void)remote_embedding_slice_bounds->emplace_back(0, INT32_MAX);
}

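// Build a single-operator kernel graph for looking up embeddings from the device hash table. The graph wraps the
// 'MapTensorGet' operator, whose inputs are the map tensor parameter (the device hash table) and the ids to look up.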
void DeviceSparseEmbeddingOperation::BuildEmbeddingCacheLookupKernel() {
  auto graph = std::make_shared<KernelGraph>();
  MS_EXCEPTION_IF_NULL(graph);
  graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
  embedding_cache_graphs_.push_back(graph);

  // 1. Create parameter nodes which are inputs of the embedding cache lookup kernel (operator name: 'MapTensorGet').
  ParameterPtr input_param = NewMapParameter(graph, kNumberTypeInt32, kNumberTypeFloat32, kOneDimensionalShape);
  ParameterPtr input_ids = NewParameter(graph, kInt32, kOneDimensionalShape);

  // 2. Create a CNode for operator MapTensorGet.
  PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kMapTensorGetOpName);
  emb_lookup_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
  emb_lookup_primitive->set_attr(kAttrOutputIsDynamicShape, MakeValue(true));
  emb_lookup_primitive->set_attr(kAttrInsertDefaultValue, MakeValue(false));

  std::vector<AnfNodePtr> emb_lookup_input_nodes{mindspore::NewValueNode(emb_lookup_primitive), input_param, input_ids};
  embedding_cache_lookup_node_ = graph->NewCNode(emb_lookup_input_nodes);
  MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node_);
  auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimensionalShape);
  embedding_cache_lookup_node_->set_abstract(abstract);

  // 3. Kernel build process.
  MS_EXCEPTION_IF_NULL(device_context_);
  MS_EXCEPTION_IF_NULL(device_context_->GetKernelExecutor(false));
  device_context_->GetKernelExecutor(false)->CreateKernel({embedding_cache_lookup_node_});
  AnfAlgo::SetStreamId(stream_id_, embedding_cache_lookup_node_.get());
}

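// Build a single-operator kernel graph for updating the device hash table. The graph wraps the 'MapTensorPut'
// operator, whose inputs are the map tensor parameter, the ids to update and the corresponding embedding values.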
void DeviceSparseEmbeddingOperation::BuildEmbeddingCacheUpdateKernel() {
  auto graph = std::make_shared<KernelGraph>();
  MS_EXCEPTION_IF_NULL(graph);
  graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
  embedding_cache_graphs_.push_back(graph);

  // 1. Create parameter nodes which are inputs of the embedding cache update kernel (operator name: 'MapTensorPut').
  ParameterPtr input_param = NewMapParameter(graph, kNumberTypeInt32, kNumberTypeFloat32, kOneDimensionalShape);
  ParameterPtr input_ids = NewParameter(graph, kInt32, kOneDimensionalShape);
  ParameterPtr update_values = NewParameter(graph, kFloat32, kTwoDimensionalShape);

  // 2. Create a CNode for operator MapTensorPut.
  PrimitivePtr embedding_cache_update_primitive = std::make_shared<Primitive>(kMapTensorPutOpName);
  embedding_cache_update_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));

  std::vector<AnfNodePtr> embedding_cache_update_input_nodes{mindspore::NewValueNode(embedding_cache_update_primitive),
                                                             input_param, input_ids, update_values};
  embedding_cache_update_node_ = graph->NewCNode(embedding_cache_update_input_nodes);
  MS_EXCEPTION_IF_NULL(embedding_cache_update_node_);
  embedding_cache_update_node_->set_abstract(input_param->abstract());

  // 3. Kernel build process.
  MS_EXCEPTION_IF_NULL(device_context_);
  MS_EXCEPTION_IF_NULL(device_context_->GetKernelExecutor(false));
  device_context_->GetKernelExecutor(false)->CreateKernel({embedding_cache_update_node_});
  AnfAlgo::SetStreamId(stream_id_, embedding_cache_update_node_.get());
}

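// Build a single-operator kernel graph for erasing ids from the device hash table. The graph wraps the
// 'MapTensorErase' operator, whose inputs are the map tensor parameter and the ids to erase.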
void DeviceSparseEmbeddingOperation::BuildEmbeddingCacheEraseKernel() {
  auto graph = std::make_shared<KernelGraph>();
  MS_EXCEPTION_IF_NULL(graph);
  graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
  embedding_cache_graphs_.push_back(graph);

  // 1. Create parameter nodes which are inputs of the embedding cache erase kernel (operator name: 'MapTensorErase').
  ParameterPtr input_param = NewMapParameter(graph, kNumberTypeInt32, kNumberTypeFloat32, kOneDimensionalShape);
  ParameterPtr input_ids = NewParameter(graph, kInt32, kOneDimensionalShape);

  // 2. Create a CNode for operator MapTensorErase.
  PrimitivePtr embedding_cache_erase_primitive = std::make_shared<Primitive>(kMapTensorEraseOpName);
  embedding_cache_erase_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
  embedding_cache_erase_primitive->set_attr(kAttrOutputIsDynamicShape, MakeValue(true));

  std::vector<AnfNodePtr> embedding_cache_erase_input_nodes{mindspore::NewValueNode(embedding_cache_erase_primitive),
                                                            input_param, input_ids};
  embedding_cache_erase_node_ = graph->NewCNode(embedding_cache_erase_input_nodes);
  MS_EXCEPTION_IF_NULL(embedding_cache_erase_node_);
  embedding_cache_erase_node_->set_abstract(input_param->abstract());

  // 3. Kernel build process.
  MS_EXCEPTION_IF_NULL(device_context_);
  MS_EXCEPTION_IF_NULL(device_context_->GetKernelExecutor(false));
  device_context_->GetKernelExecutor(false)->CreateKernel({embedding_cache_erase_node_});
  AnfAlgo::SetStreamId(stream_id_, embedding_cache_erase_node_.get());
}

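// Create a new map tensor parameter for the given kernel graph. Its abstract is built from a MapTensor with the
// given key/value types and value shape, its kernel build info marks the output as a map tensor object, and the
// parameter is appended to the graph's mutable inputs.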
ParameterPtr DeviceSparseEmbeddingOperation::NewMapParameter(const KernelGraphPtr &graph, TypeId key_type,
                                                             TypeId value_type, const ShapeVector &value_shape) {
  MS_EXCEPTION_IF_NULL(graph);

  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  auto map_tensor = std::make_shared<tensor::MapTensor>(key_type, value_type, value_shape, nullptr);
  auto abstract = std::make_shared<abstract::AbstractMapTensor>(map_tensor);
  param->set_abstract(abstract);

  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  std::vector<std::string> formats = {kOpFormat_DEFAULT};
  std::vector<TypeId> types = {kObjectTypeMapTensorType};
  kernel_build_info_builder->SetOutputsFormat(formats);
  kernel_build_info_builder->SetOutputsDeviceType(types);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());

  auto mutable_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(mutable_inputs);
  mutable_inputs->push_back(param);

  return param;
}
}  // namespace runtime
}  // namespace mindspore