/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/ps_cache/ascend/ascend_ps_cache.h"
#include <google/protobuf/text_format.h>
#include <string>
#include <vector>
#include <memory>
#include "ps/ps_cache/ps_cache_factory.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
#include "utils/ms_context.h"
#include "proto/tensor.pb.h"
#include "proto/tensor_shape.pb.h"
#include "proto/attr.pb.h"
#include "proto/node_def.pb.h"
#include "runtime/rt.h"

using mindspore::kernel::Address;
using AddressPtr = std::shared_ptr<Address>;
using AddressPtrList = std::vector<AddressPtr>;

namespace mindspore {
namespace ps {
namespace ascend {
MS_REG_PS_CACHE(kAscendDevice, AscendPsCache);
namespace {
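// Fills the input descriptors of an AICPU NodeDef proto: one Tensor entry per element of
// data_shape/data_type, carrying the dimension sizes, the proto data type, and the "HBM"
// memory device tag.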
bool SetProtoInputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
                    mindspore::NodeDef *proto) {
  MS_ERROR_IF_NULL(proto);
  if (data_shape.size() != data_type.size()) {
    MS_LOG(ERROR) << "The size of data shape is not equal to the size of data type.";
    return false;
  }
  for (size_t input_index = 0; input_index < data_shape.size(); input_index++) {
    ::mindspore::Tensor *proto_inputs = proto->add_inputs();
    MS_ERROR_IF_NULL(proto_inputs);
    auto input_shape = data_shape[input_index];
    mindspore::TensorShape *tensorShape = proto_inputs->mutable_tensor_shape();
    MS_ERROR_IF_NULL(tensorShape);
    for (auto item : input_shape) {
      mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
      MS_ERROR_IF_NULL(dim);
      dim->set_size((::google::protobuf::int64)item);
    }
    auto input_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[input_index]);
    proto_inputs->set_tensor_type(input_type);
    proto_inputs->set_mem_device("HBM");
  }
  return true;
}

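// Fills the output descriptors of an AICPU NodeDef proto, mirroring SetProtoInputs.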
bool SetProtoOutputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
                     mindspore::NodeDef *proto) {
  MS_ERROR_IF_NULL(proto);
  if (data_shape.size() != data_type.size()) {
    MS_LOG(ERROR) << "The size of data shape is not equal to the size of data type.";
    return false;
  }
  for (size_t output_index = 0; output_index < data_shape.size(); output_index++) {
    ::mindspore::Tensor *proto_outputs = proto->add_outputs();
    MS_ERROR_IF_NULL(proto_outputs);
    auto output_shape = data_shape[output_index];
    mindspore::TensorShape *tensorShape = proto_outputs->mutable_tensor_shape();
    MS_ERROR_IF_NULL(tensorShape);
    for (auto item : output_shape) {
      mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
      MS_ERROR_IF_NULL(dim);
      dim->set_size((::google::protobuf::int64)item);
    }
    auto output_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[output_index]);
    proto_outputs->set_tensor_type(output_type);
    proto_outputs->set_mem_device("HBM");
  }
  return true;
}

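// Builds a NodeDef proto from the kernel node info, serializes it, and attaches the
// serialized string to the AICPU kernel mod before launch.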
bool SetNodedefProto(const std::shared_ptr<KernelNodeInfo> &op_info,
                     const std::shared_ptr<kernel::AicpuOpKernelMod> &kernel_mod_ptr) {
  MS_ERROR_IF_NULL(op_info);
  MS_ERROR_IF_NULL(kernel_mod_ptr);
  mindspore::NodeDef proto;
  proto.set_op(op_info->op_name_);
  RETURN_IF_FALSE(SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto));
  RETURN_IF_FALSE(SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto));
  std::string nodeDefStr;
  if (!proto.SerializeToString(&nodeDefStr)) {
    MS_LOG(ERROR) << "Serialize nodeDef to string failed.";
    return false;
  }
  MS_LOG(DEBUG) << "Set node def proto, node name:" << op_info->op_name_;
  kernel_mod_ptr->SetNodeDef(nodeDefStr);
  return true;
}
} // namespace

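// Binds the calling thread to the given device, adopts the caller's runtime context, and
// creates the stream on which all cache copies and kernels are issued.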
bool AscendPsCache::InitDevice(uint32_t device_id, const void *context) {
  MS_ERROR_IF_NULL(context);
  auto ret = rtSetDevice(UintToInt(device_id));
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Call rtSetDevice failed, ret[" << ret << "]";
    return false;
  }
  auto rt_context = const_cast<rtContext_t>(context);
  ret = rtCtxSetCurrent(rt_context);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Call rtCtxSetCurrent failed, ret[" << ret << "]";
    return false;
  }
  ret = rtStreamCreate(&stream_, 0);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Call rtStreamCreate failed, ret[" << ret << "]";
    return false;
  }
  return true;
}

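// Allocates device memory for the cache from the Ascend memory pool.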
void *AscendPsCache::MallocMemory(size_t size) {
  return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size);
}

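// Allocates two single-int device constants used by the swap kernels: a zero offset and
// the cache vocabulary size.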
bool AscendPsCache::MallocConstantMemory(size_t cache_vocab_size) {
  offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
  MS_ERROR_IF_NULL(offset_addr_);
  rtMemset(offset_addr_, sizeof(int), 0, sizeof(int));
  cache_vocab_size_addr_ =
    reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
  MS_ERROR_IF_NULL(cache_vocab_size_addr_);
  int copy_value = SizeToInt(cache_vocab_size);
  if (!CopyHostMemToDevice(cache_vocab_size_addr_, &copy_value, sizeof(int))) {
    return false;
  }
  return SynchronizeStream();
}

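// Creates a runtime event and records it on the cache stream.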
bool AscendPsCache::RecordEvent() {
  event_.reset(new rtEvent_t());
  MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
  auto ret = rtEventCreate(&(*event_));
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Create event failed";
    return false;
  }
  ret = rtEventRecord(*event_, stream_);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Record event failed";
    return false;
  }
  return true;
}

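// Waits for the recorded event to complete, then destroys it.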
bool AscendPsCache::SynchronizeEvent() {
  MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
  auto ret = rtEventSynchronize(*event_);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "rtEventSynchronize failed";
    return false;
  }
  ret = rtEventDestroy(*event_);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "rtEventDestroy failed";
    return false;
  }
  return true;
}

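// Blocks until all work queued on the cache stream has finished.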
bool AscendPsCache::SynchronizeStream() {
  MS_ERROR_IF_NULL_W_RET_VAL(stream_, false);
  auto ret = rtStreamSynchronize(stream_);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "rtStreamSynchronize failed";
    return false;
  }
  return true;
}

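// Asynchronous host-to-device copy on the cache stream; the caller must synchronize the
// stream before reusing the host buffer.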
bool AscendPsCache::CopyHostMemToDevice(void *dst, const void *src, size_t size) {
  MS_ERROR_IF_NULL(dst);
  MS_ERROR_IF_NULL(src);
  auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "rtMemcpyAsync failed, the error num is:" << ret;
    return false;
  }
  return true;
}

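// Asynchronous device-to-host copy on the cache stream; the caller must synchronize the
// stream before reading the host buffer.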
bool AscendPsCache::CopyDeviceMemToHost(void *dst, const void *src, size_t size) {
  MS_ERROR_IF_NULL(dst);
  MS_ERROR_IF_NULL(src);
  auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "rtMemcpyAsync failed, the error num is:" << ret;
    return false;
  }
  return true;
}

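// Swaps embeddings out of the device hash table: launches an EmbeddingLookup AICPU kernel
// that gathers the rows listed in swap_out_index_addr into swap_out_value_addr.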
bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
                                size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) {
  MS_ERROR_IF_NULL(hash_table_addr);
  MS_ERROR_IF_NULL(swap_out_value_addr);
  MS_ERROR_IF_NULL(swap_out_index_addr);
  auto hash_swap_out_mod = std::make_shared<kernel::AicpuOpKernelMod>();
  MS_ERROR_IF_NULL(hash_swap_out_mod);
  hash_swap_out_mod->SetNodeName(kEmbeddingLookupOpName);

  std::vector<size_t> hash_table_shape = {cache_vocab_size, embedding_size};
  std::vector<size_t> swap_out_index_shape = {swap_out_size};
  std::vector<size_t> offset_shape = {1};
  std::vector<std::vector<size_t>> input_shape = {hash_table_shape, swap_out_index_shape, offset_shape};

  std::vector<size_t> swap_out_value_shape = {swap_out_size, embedding_size};
  std::vector<std::vector<size_t>> output_shape = {swap_out_value_shape};

  std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32};
  std::vector<TypeId> output_type = {TypeId::kNumberTypeFloat32};
  auto op_info =
    std::make_shared<KernelNodeInfo>(kEmbeddingLookupOpName, input_shape, input_type, output_shape, output_type);
  MS_ERROR_IF_NULL_W_RET_VAL(op_info, false);
  RETURN_IF_FALSE(SetNodedefProto(op_info, hash_swap_out_mod));

  AddressPtrList kernel_inputs;
  AddressPtrList kernel_outputs = {
    std::make_shared<Address>(swap_out_value_addr, swap_out_size * embedding_size * sizeof(float))};
  AddressPtrList kernel_workspaces;
  (void)kernel_inputs.emplace_back(
    std::make_shared<Address>(hash_table_addr, cache_vocab_size * embedding_size * sizeof(float)));
  (void)kernel_inputs.emplace_back(std::make_shared<Address>(swap_out_index_addr, swap_out_size * sizeof(int)));
  (void)kernel_inputs.emplace_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
  auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  if (!ret) {
    MS_LOG(ERROR) << "Hash swap out launch failed.";
    return false;
  }
  return true;
}

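// Swaps embeddings into the device hash table: launches an UpdateCache AICPU kernel that
// scatters the rows in swap_in_value_addr to the positions listed in swap_in_index_addr.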
bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
                               size_t cache_vocab_size, size_t embedding_size, size_t swap_in_size) {
  MS_ERROR_IF_NULL(hash_table_addr);
  MS_ERROR_IF_NULL(swap_in_value_addr);
  MS_ERROR_IF_NULL(swap_in_index_addr);
  auto hash_swap_in_mod = std::make_shared<kernel::AicpuOpKernelMod>();
  MS_ERROR_IF_NULL(hash_swap_in_mod);
  hash_swap_in_mod->SetNodeName(kernel::kUpdateCache);

  std::vector<size_t> hash_table_shape = {cache_vocab_size, embedding_size};
  std::vector<size_t> swap_in_index_shape = {swap_in_size};
  std::vector<size_t> swap_in_value_shape = {swap_in_size, embedding_size};
  std::vector<size_t> offset_shape = {1};
  std::vector<std::vector<size_t>> input_shape = {hash_table_shape, swap_in_index_shape, swap_in_value_shape,
                                                  offset_shape};
  std::vector<std::vector<size_t>> output_shape = {offset_shape};

  std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32,
                                    TypeId::kNumberTypeInt32};
  std::vector<TypeId> output_type = {TypeId::kNumberTypeInt32};
  auto op_info =
    std::make_shared<KernelNodeInfo>(kernel::kUpdateCache, input_shape, input_type, output_shape, output_type);
  MS_ERROR_IF_NULL_W_RET_VAL(op_info, false);
  RETURN_IF_FALSE(SetNodedefProto(op_info, hash_swap_in_mod));

  AddressPtrList kernel_inputs;
  AddressPtrList kernel_outputs;
  AddressPtrList kernel_workspaces;
  (void)kernel_inputs.emplace_back(
    std::make_shared<Address>(hash_table_addr, cache_vocab_size * embedding_size * sizeof(float)));
  (void)kernel_inputs.emplace_back(std::make_shared<Address>(swap_in_index_addr, swap_in_size * sizeof(int)));
  (void)kernel_inputs.emplace_back(
    std::make_shared<Address>(swap_in_value_addr, swap_in_size * embedding_size * sizeof(float)));
  (void)kernel_inputs.emplace_back(std::make_shared<Address>(cache_vocab_size_addr_, sizeof(int)));
  // The output of the UpdateCache kernel is required but not used, so any valid address can be assigned.
  (void)kernel_outputs.emplace_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
  auto ret = hash_swap_in_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  if (!ret) {
    MS_LOG(ERROR) << "Hash swap in launch failed.";
    return false;
  }
  return true;
}
} // namespace ascend
} // namespace ps
} // namespace mindspore