/**
 * Copyright 2023-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CONVERT_H_
#define MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CONVERT_H_

#include <dlfcn.h>
#include <vector>
#include <string>
#include <map>
#include <memory>
#include <algorithm>
#include <functional>
#include <regex>
#include <utility>
#include <tuple>
#include "acl/acl_base.h"
#include "ir/tensor.h"
#include "transform/acl_ir/acl_convert.h"
#include "plugin/device/ascend/hal/common/ascend_utils.h"
#include "plugin/device/ascend/hal/device/ascend_device_address.h"
#include "transform/acl_ir/acl_helper.h"
#include "runtime/device/ms_device_shape_transfer.h"

namespace mindspore::transform {
// ACL API data structures.
typedef struct aclOpExecutor aclOpExecutor;
typedef struct aclTensor aclTensor;
typedef struct aclScalar aclScalar;
typedef struct aclIntArray aclIntArray;
typedef struct aclFloatArray aclFloatArray;
typedef struct aclBoolArray aclBoolArray;
typedef struct aclTensorList aclTensorList;
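// These types are opaque handles: they are only forward-declared here and are created/destroyed exclusively
// through the aclCreate*/aclDestroy* entry points resolved below.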

// Create operator.
using _aclCreateTensor = aclTensor *(*)(const int64_t *view_dims, uint64_t view_dims_num, aclDataType data_type,
                                        const int64_t *stride, int64_t offset, aclFormat format,
                                        const int64_t *storage_dims, uint64_t storage_dims_num, void *tensor_data);
using _aclCreateScalar = aclScalar *(*)(void *value, aclDataType data_type);
using _aclCreateIntArray = aclIntArray *(*)(const int64_t *value, uint64_t size);
using _aclCreateFloatArray = aclFloatArray *(*)(const float *value, uint64_t size);
using _aclCreateBoolArray = aclBoolArray *(*)(const bool *value, uint64_t size);
using _aclCreateTensorList = aclTensorList *(*)(const aclTensor *const *value, uint64_t size);
// Destroy operator.
using _aclDestroyTensor = int (*)(const aclTensor *tensor);
using _aclDestroyScalar = int (*)(const aclScalar *scalar);
using _aclDestroyIntArray = int (*)(const aclIntArray *array);
using _aclDestroyFloatArray = int (*)(const aclFloatArray *array);
using _aclDestroyBoolArray = int (*)(const aclBoolArray *array);
using _aclDestroyTensorList = int (*)(const aclTensorList *array);

extern std::vector<std::pair<void *, std::string>> opapi_lib_handle;
extern void LoadOpApiLib();

// Get op api func.
inline std::string GetOpApiLibName() { return "/lib64/libopapi.so"; }

inline std::string GetCustOpApiLibName() { return "/op_api/lib/libcust_opapi.so"; }

inline void *GetOpApiFuncFromLib(void *handler, const char *lib_name, const char *api_name) {
  MS_EXCEPTION_IF_NULL(handler);
  auto func = dlsym(handler, api_name);
  if (func == nullptr) {
    MS_LOG(DEBUG) << "Dlsym " << api_name << " from " << lib_name << " failed!" << dlerror();
  }
  return func;
}

inline void *GetOpApiLibHandler(const std::string &lib_path) {
  auto handler = dlopen(lib_path.c_str(), RTLD_LAZY);
  if (handler == nullptr) {
    MS_LOG(INFO) << "Dlopen " << lib_path << " failed!" << dlerror();
  }
  return handler;
}

inline void *GetOpApiFunc(const char *api_name) {
  static std::map<string, void *> opapi_cache;
  auto res = opapi_cache.find(string(api_name));
  if (res != opapi_cache.end()) {
    MS_LOG(DEBUG) << "OpApi " << api_name << " hit cache.";
    return res->second;
  }
  if (opapi_lib_handle.size() == 0) {
    LoadOpApiLib();
  }
  for (const auto &handle : opapi_lib_handle) {
    const auto api_func = GetOpApiFuncFromLib(handle.first, handle.second.c_str(), api_name);
    if (api_func != nullptr) {
      (void)opapi_cache.emplace(string(api_name), api_func);
      MS_LOG(DEBUG) << "Get OpApiFunc [" << api_name << "] from " << handle.second;
      return api_func;
    }
  }
  MS_LOG(WARNING) << "Dlsym " << api_name << " failed!";
  return nullptr;
}

#define GET_OP_API_FUNC(func_name) reinterpret_cast<_##func_name>(GetOpApiFunc(#func_name))
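// Typical usage, as done throughout this header: resolve the symbol once, cache it in a function-local static,
// and check for nullptr before calling, e.g.
//   static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
//   if (aclCreateTensor == nullptr) { /* the symbol is missing from the op api libraries */ }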

template <typename Tuple, size_t... I>
auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr, std::index_sequence<I...>) {
  using OpApiFunc = int (*)(typename std::decay<decltype(std::get<I>(params))>::type...);
  auto func = reinterpret_cast<OpApiFunc>(opApiAddr);
  return func;
}

template <typename Tuple>
auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr) {
  static constexpr auto size = std::tuple_size<Tuple>::value;
  return ConvertToOpApiFunc(params, opApiAddr, std::make_index_sequence<size>{});
}
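// ConvertToOpApiFunc casts the raw symbol address to a function pointer whose parameter types are deduced from
// the converted-argument tuple. A sketch with a hypothetical aclnn symbol name:
//   auto converted = ConvertTypes(input_tensor, alpha_scalar);            // std::tuple<aclTensor *, aclScalar *>
//   auto func = ConvertToOpApiFunc(converted, GetOpApiFunc("aclnnFoo"));  // int (*)(aclTensor *, aclScalar *)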

// Convert Value
class OpApiTensorConverter : public AttrHelper<OpApiTensorConverter> {
 public:
  OpApiTensorConverter() = default;
  ~OpApiTensorConverter() = default;

  template <typename T>
  void ConvertValue(const ValuePtr &value, const AttrDeclType<T> &, aclScalar **scalar) {
    auto real_val = GetValue<T>(value);
    MS_EXCEPTION_IF_NULL(scalar);
    *scalar = CreateAclScalar(&real_val, GetDataType(value));
  }

  void ConvertValue(const ValuePtr &value, const AttrDeclType<int32_t> &, aclScalar **scalar) {
    auto real_val = static_cast<int64_t>(GetValue<int32_t>(value));
    MS_EXCEPTION_IF_NULL(scalar);
    *scalar = CreateAclScalar(&real_val, ACL_INT64);
  }

  void ConvertValue(const ValuePtr &value, const AttrDeclType<std::vector<int64_t>> &, aclIntArray **array) {
    MS_EXCEPTION_IF_NULL(array);
    std::vector<int64_t> array_list;
    ConvertValueSequenceToList(value, &array_list);
    *array = CreateIntArray(array_list);
  }

  void ConvertValue(const ValuePtr &value, const AttrDeclType<std::vector<int32_t>> &, aclIntArray **array) {
    MS_EXCEPTION_IF_NULL(array);
    std::vector<int32_t> array_list;
    ConvertValueSequenceToList(value, &array_list);
    std::vector<int64_t> array_list_int64;
    (void)std::transform(array_list.begin(), array_list.end(), std::back_inserter(array_list_int64),
                         [](const int val) { return IntToLong(val); });
    *array = CreateIntArray(array_list_int64);
  }

  void ConvertValue(const ValuePtr &value, const AttrDeclType<std::vector<uint8_t>> &, aclBoolArray **array) {
    MS_EXCEPTION_IF_NULL(array);
    std::vector<uint8_t> array_list;
    ConvertValueSequenceToList(value, &array_list);
    *array = CreateBoolArray(array_list);
  }

  void ConvertValue(const ValuePtr &value, const AttrDeclType<std::vector<float>> &, aclFloatArray **array) {
    MS_EXCEPTION_IF_NULL(array);
    std::vector<float> array_list;
    ConvertValueSequenceToList(value, &array_list);
    *array = CreateFloatArray(array_list);
  }

  template <typename T>
  aclScalar *CreateAclScalar(T *val, aclDataType dtype) {
    static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
    if (aclCreateScalar == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get `aclCreateScalar` func.";
    }
    return aclCreateScalar(val, dtype);
  }

  aclIntArray *CreateIntArray(const std::vector<int64_t> &val) {
    static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
    if (aclCreateIntArray == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get `aclCreateIntArray` func.";
    }
    return aclCreateIntArray(val.data(), val.size());
  }

  aclBoolArray *CreateBoolArray(const std::vector<uint8_t> &val) {
    static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
    if (aclCreateBoolArray == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get `aclCreateBoolArray` func.";
    }
    return aclCreateBoolArray(reinterpret_cast<const bool *>(val.data()), val.size());
  }

  aclFloatArray *CreateFloatArray(const std::vector<float> &val) {
    static const auto aclCreateFloatArray = GET_OP_API_FUNC(aclCreateFloatArray);
    if (aclCreateFloatArray == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get `aclCreateFloatArray` func.";
    }
    return aclCreateFloatArray(val.data(), val.size());
  }

 private:
  inline aclDataType GetDataType(const ValuePtr &value) { return AclConverter::ConvertType(value->type()->type_id()); }
};
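// Example (a sketch mirroring ConvertType(const ScalarPtr &) below): turning an int64 ValuePtr into an aclScalar.
//   static OpApiTensorConverter converter;
//   aclScalar *acl_scalar = nullptr;
//   converter.ConvertValue(value, AttrDeclType<int64_t>(), &acl_scalar);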

inline aclTensor *ConvertType(const mindspore::kernel::KernelTensor *tensor) {
  static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
  if (aclCreateTensor == nullptr) {
    return nullptr;
  }
  if (tensor == nullptr || tensor->type_id() == kMetaTypeNone) {
    return nullptr;
  }

  auto acl_data_type = AclConverter::ConvertType(tensor->dtype_id());
  const auto &shape = tensor->GetShapeVector();
  const auto shape_size = shape.size();
  aclFormat format = ACL_FORMAT_ND;
  switch (shape_size) {
    case 3:
      format = ACL_FORMAT_NCL;
      break;
    case 4:
      format = ACL_FORMAT_NCHW;
      break;
    case 5:
      format = ACL_FORMAT_NCDHW;
      break;
    default:
      format = ACL_FORMAT_ND;
  }

  aclTensor *acl_tensor = nullptr;
  const auto &storage_info = tensor->tensor_storage_info();
  if (storage_info == nullptr) {
    // Create strides.
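    // For a contiguous tensor the strides are row-major, e.g. shape {2, 3, 4} -> strides {12, 4, 1}.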
    auto strides = shape;
    if (!strides.empty()) {
      strides.erase(strides.begin());
    }
    strides.push_back(1);
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; i--) {
      strides[i] = strides[i] * strides[i + 1];
    }
    acl_tensor = aclCreateTensor(shape.data(), shape_size, acl_data_type, strides.data(), 0, format, shape.data(),
                                 shape.size(), tensor->device_ptr());
  } else {
    const auto &strides = storage_info->strides;
    const auto &storage_shape = storage_info->ori_shape;
    acl_tensor =
      aclCreateTensor(shape.data(), shape_size, acl_data_type, strides.data(), SizeToLong(storage_info->storage_offset),
                      format, storage_shape.data(), storage_shape.size(), tensor->device_ptr());
  }

  return acl_tensor;
}

inline aclTensor *ConvertType(const device::DeviceAddressPtr &device_address) {
  static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
  if (aclCreateTensor == nullptr) {
    return nullptr;
  }
  if (device_address == nullptr) {
    return nullptr;
  }

  auto acl_data_type = AclConverter::ConvertType(device_address->type_id());
  const auto &shape = device_address->GetShapeVector();
  const auto shape_size = shape.size();
  aclFormat format = ACL_FORMAT_ND;
  switch (shape_size) {
    case 3:
      format = ACL_FORMAT_NCL;
      break;
    case 4:
      format = ACL_FORMAT_NCHW;
      break;
    case 5:
      format = ACL_FORMAT_NCDHW;
      break;
    default:
      format = ACL_FORMAT_ND;
  }

  aclTensor *acl_tensor = nullptr;
  const auto &storage_info = device_address->address_common()->tensor_storage_info_;
  if (storage_info == nullptr) {
    // Create strides.
    auto strides = shape;
    if (!strides.empty()) {
      strides.erase(strides.begin());
    }
    strides.push_back(1);
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; i--) {
      strides[i] = strides[i] * strides[i + 1];
    }
    acl_tensor = aclCreateTensor(shape.data(), shape_size, acl_data_type, strides.data(), 0, format, shape.data(),
                                 shape.size(), device_address->GetMutablePtr());
  } else {
    const auto &strides = storage_info->strides;
    const auto &storage_shape = storage_info->ori_shape;
    acl_tensor =
      aclCreateTensor(shape.data(), shape_size, acl_data_type, strides.data(), SizeToLong(storage_info->storage_offset),
                      format, storage_shape.data(), storage_shape.size(), device_address->GetMutablePtr());
  }

  return acl_tensor;
}

inline aclTensor *ConvertType(mindspore::kernel::KernelTensor *tensor) {
  return ConvertType(reinterpret_cast<const mindspore::kernel::KernelTensor *>(tensor));
}

inline aclTensor *ConvertType(std::pair<mindspore::kernel::KernelTensor *, bool> tensor_and_trans) {
  static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
  if (aclCreateTensor == nullptr) {
    return nullptr;
  }
  auto tensor = tensor_and_trans.first;
  auto trans = tensor_and_trans.second;
  auto acl_data_type = AclConverter::ConvertType(tensor->dtype_id());
  auto shape = tensor->GetShapeVector();
  const auto shape_size = shape.size();
  aclFormat format = ACL_FORMAT_ND;
  switch (shape_size) {
    case 3:
      format = ACL_FORMAT_NCL;
      break;
    case 4:
      format = ACL_FORMAT_NCHW;
      break;
    case 5:
      format = ACL_FORMAT_NCDHW;
      break;
    default:
      format = ACL_FORMAT_ND;
  }

  // Create strides.
  auto strides = shape;
  if (!strides.empty()) {
    strides.erase(strides.begin());
  }
  strides.push_back(1);
  for (int i = static_cast<int>(strides.size()) - 2; i >= 0; i--) {
    strides[i] = strides[i] * strides[i + 1];
  }
  // If the tensor is marked as transposed, swap the last two dims and strides to describe the transposed view
  // without copying data.
  if (trans) {
    std::swap(shape[shape.size() - 1], shape[shape.size() - 2]);
    std::swap(strides[strides.size() - 1], strides[strides.size() - 2]);
  }
  auto acl_tensor = aclCreateTensor(shape.data(), shape_size, acl_data_type, strides.data(), 0, format, shape.data(),
                                    shape_size, tensor->device_ptr());
  return acl_tensor;
}

inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, int64_t, std::vector<int64_t>> GetViewShapeAndStride(
  const tensor::BaseTensorPtr &tensor, const device::DeviceAddressPtr &device_address) {
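  // Returns {view_shape, strides, storage_offset, device_storage_shape}; for tensors without storage info the
  // strides are the contiguous row-major ones and the offset is 0.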
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(device_address);

  const auto &storage_info = tensor->storage_info();
  // Get dev shape
  auto get_dev_shape = [device_address, tensor](const std::string &tensor_format, const auto &tensor_shape) {
    if (transform::AclHelper::CheckDefaultSupportFormat(tensor_format)) {
      return tensor_shape;
    }
    int64_t groups = 1;
    auto node_idx = device_address->GetNodeIndex();
    if (node_idx.first != nullptr) {
      groups = common::AnfAlgo::GetAttrGroups(node_idx.first, node_idx.second);
    }
    return trans::TransShapeToDevice(tensor_shape, tensor_format, tensor->data_type(), groups);
  };

  const auto &tensor_shape = tensor->shape();
  const auto &tensor_format = device_address->format();
  if (storage_info == nullptr) {
    const auto &dev_shape = get_dev_shape(tensor_format, tensor_shape);

    // Get contiguous strides
    std::vector<int64_t> strides(tensor_shape.size(), 1);
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; i--) {
      strides[i] = tensor_shape[i + 1] * strides[i + 1];
    }

    return std::make_tuple(tensor_shape, strides, 0, dev_shape);
  } else {
    const auto &dev_shape = get_dev_shape(tensor_format, storage_info->ori_shape);
    return std::make_tuple(tensor_shape, storage_info->strides, storage_info->storage_offset, dev_shape);
  }
}

inline aclTensor *ConvertType(const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
  if (aclCreateTensor == nullptr) {
    return nullptr;
  }
  auto shape = tensor->shape();
  const auto shape_size = shape.size();
  aclFormat format = ACL_FORMAT_ND;
  switch (shape_size) {
    case kSizeThree:
      format = ACL_FORMAT_NCL;
      break;
    case kSizeFour:
      format = ACL_FORMAT_NCHW;
      break;
    case kSizeFive:
      format = ACL_FORMAT_NCDHW;
      break;
    default:
      format = ACL_FORMAT_ND;
  }
  auto acl_data_type = AclConverter::ConvertType(tensor->data_type());
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->GetMutablePtr() == nullptr) {
    MS_LOG(EXCEPTION) << "The device memory is null, please allocate the device memory for tensor "
                      << tensor->ToString();
  }
  auto [view_shape, strides, offset, ori_dev_shape] = GetViewShapeAndStride(tensor, device_address);
  auto acl_tensor = aclCreateTensor(view_shape.data(), view_shape.size(), acl_data_type, strides.data(), offset, format,
                                    ori_dev_shape.data(), ori_dev_shape.size(), device_address->GetMutablePtr());
  return acl_tensor;
}

inline aclTensor *ConvertType(const std::optional<tensor::BaseTensorPtr> &value) {
  if (value.has_value()) {
    return ConvertType(value.value());
  }
  return nullptr;
}

inline aclIntArray *ConvertType(const std::vector<int64_t> &int_array) {
  if (int_array.empty()) {
    MS_LOG(DEBUG) << "int array is empty!";
  }
  static OpApiTensorConverter converter;
  return converter.CreateIntArray(int_array);
}

inline aclFloatArray *ConvertType(const std::vector<float> &float_array) {
  if (float_array.empty()) {
    MS_LOG(ERROR) << "float array is empty!";
  }
  static OpApiTensorConverter converter;
  return converter.CreateFloatArray(float_array);
}

inline aclBoolArray *ConvertType(const std::vector<uint8_t> &bool_array) {
  if (bool_array.empty()) {
    MS_LOG(ERROR) << "bool array is empty!";
  }
  static OpApiTensorConverter converter;
  return converter.CreateBoolArray(bool_array);
}

inline aclTensorList *ConvertType(const std::vector<tensor::BaseTensorPtr> &tensor_list) {
  if (tensor_list.empty()) {
    MS_LOG(DEBUG) << "tensor list is empty!";
  }
  static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
  if (aclCreateTensorList == nullptr) {
    return nullptr;
  }
  std::vector<aclTensor *> tmp;
  std::transform(tensor_list.begin(), tensor_list.end(), std::back_inserter(tmp),
                 [](const tensor::BaseTensorPtr &tensor) { return ConvertType(tensor); });
  return aclCreateTensorList(tmp.data(), tmp.size());
}

inline aclTensor *ConvertType(const tensor::TensorPtr &tensor) {
  return ConvertType(tensor->cast<tensor::BaseTensorPtr>());
}

inline aclTensor *ConvertType(const std::optional<tensor::TensorPtr> &value) {
  if (value.has_value()) {
    return ConvertType(value.value());
  }
  return nullptr;
}

inline aclTensorList *ConvertType(const std::vector<tensor::TensorPtr> &tensor_list) {
  if (tensor_list.empty()) {
    MS_LOG(DEBUG) << "tensor list is empty!";
  }
  static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
  if (aclCreateTensorList == nullptr) {
    return nullptr;
  }
  std::vector<aclTensor *> tmp;
  std::transform(tensor_list.begin(), tensor_list.end(), std::back_inserter(tmp),
                 [](const tensor::TensorPtr &tensor) { return ConvertType(tensor); });
  return aclCreateTensorList(tmp.data(), tmp.size());
}

inline aclTensorList *ConvertType(const std::vector<mindspore::kernel::KernelTensor *> &tensor_list) {
  if (tensor_list.empty()) {
    MS_LOG(DEBUG) << "tensor list is empty!";
  }
  static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
  if (aclCreateTensorList == nullptr) {
    return nullptr;
  }
  std::vector<aclTensor *> tmp;
  std::transform(tensor_list.begin(), tensor_list.end(), std::back_inserter(tmp),
                 [](mindspore::kernel::KernelTensor *tensor) { return ConvertType(tensor); });
  return aclCreateTensorList(tmp.data(), tmp.size());
}

inline aclScalar *ConvertType(const ScalarPtr &value) {
  if (value == nullptr) {
    // for None
    return nullptr;
  }
  aclScalar *acl_scalar = nullptr;
  static OpApiTensorConverter converter;
  if (value->isa<BoolImm>()) {
    converter.ConvertValue(value, AttrDeclType<bool>(), &acl_scalar);
  } else if (value->isa<Int64Imm>()) {
    converter.ConvertValue(value, AttrDeclType<int64_t>(), &acl_scalar);
  } else if (value->isa<FP64Imm>()) {
    converter.ConvertValue(value, AttrDeclType<double>(), &acl_scalar);
  } else if (value->isa<FP32Imm>()) {
    converter.ConvertValue(value, AttrDeclType<float>(), &acl_scalar);
  } else if (value->isa<Int32Imm>()) {
    converter.ConvertValue(value, AttrDeclType<int32_t>(), &acl_scalar);
  } else if (value->isa<Int8Imm>()) {
    converter.ConvertValue(value, AttrDeclType<int8_t>(), &acl_scalar);
  } else if (value->isa<Int16Imm>()) {
    converter.ConvertValue(value, AttrDeclType<int16_t>(), &acl_scalar);
  } else if (value->isa<UInt8Imm>()) {
    converter.ConvertValue(value, AttrDeclType<uint8_t>(), &acl_scalar);
  } else if (value->isa<BF16Imm>()) {
    converter.ConvertValue(value, AttrDeclType<bfloat16>(), &acl_scalar);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported value: " << value->ToString();
  }
  return acl_scalar;
}

inline aclScalar *ConvertType(const std::optional<ScalarPtr> &value) {
  if (value.has_value()) {
    return ConvertType(value.value());
  }
  return nullptr;
}

inline aclDataType ConvertType(TypeId type_id) { return AclConverter::ConvertType(type_id); }

inline aclDataType ConvertType(const TypePtr &type) { return AclConverter::ConvertType(type->type_id()); }

inline const char *ConvertType(const std::string &value) { return value.c_str(); }

template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
T ConvertType(T value) {
  return value;
}

template <typename... Ts>
constexpr auto ConvertTypes(const Ts &... args) {
  return std::make_tuple(ConvertType(args)...);
}
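// Sketch of the intended pairing (variable names are illustrative): every ACL object created by ConvertTypes
// should be released with ReleaseConvertTypes (defined below) once the resolved op-api function has been called.
//   auto converted = ConvertTypes(input_tensor, axis_vector, output_tensor);
//   ...invoke the op-api function obtained via ConvertToOpApiFunc with the elements of `converted`...
//   ReleaseConvertTypes(converted);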

template <typename T>
T ConvertKernelTensor(mindspore::kernel::KernelTensor *tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  return tensor->GetValueWithCheck<T>();
}

template <>
inline ScalarPtr ConvertKernelTensor<ScalarPtr>(mindspore::kernel::KernelTensor *tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  if (tensor->dtype_id() == kMetaTypeNone) {
    // for None
    return nullptr;
  }
  auto value_ptr = tensor->GetValueTrack();
  if (value_ptr == nullptr) {
    if (tensor->dtype_id() == kNumberTypeBool) {
      auto value = tensor->GetValueWithCheck<bool>();
      value_ptr = std::make_shared<BoolImm>(value);
    } else if (tensor->dtype_id() == kNumberTypeInt64) {
      auto value = tensor->GetValueWithCheck<int64_t>();
      value_ptr = std::make_shared<Int64Imm>(value);
    } else if (tensor->dtype_id() == kNumberTypeDouble || tensor->dtype_id() == kNumberTypeFloat64) {
      auto value = tensor->GetValueWithCheck<double>();
      value_ptr = std::make_shared<FP64Imm>(value);
    } else if (tensor->dtype_id() == kNumberTypeFloat32) {
      auto value = tensor->GetValueWithCheck<float>();
      value_ptr = std::make_shared<FP32Imm>(value);
    } else if (tensor->dtype_id() == kNumberTypeInt32) {
      auto value = tensor->GetValueWithCheck<int32_t>();
      value_ptr = std::make_shared<Int32Imm>(value);
    } else if (tensor->dtype_id() == kNumberTypeInt8) {
      auto value = tensor->GetValueWithCheck<int8_t>();
      value_ptr = std::make_shared<Int8Imm>(value);
    } else if (tensor->dtype_id() == kNumberTypeInt16) {
      auto value = tensor->GetValueWithCheck<int16_t>();
      value_ptr = std::make_shared<Int16Imm>(value);
    } else if (tensor->dtype_id() == kNumberTypeUInt8) {
      auto value = tensor->GetValueWithCheck<uint8_t>();
      value_ptr = std::make_shared<UInt8Imm>(value);
    } else if (tensor->dtype_id() == kNumberTypeBFloat16) {
      auto value = tensor->GetValueWithCheck<bfloat16>();
      value_ptr = std::make_shared<BF16Imm>(value);
    } else {
      MS_LOG(EXCEPTION) << "Unsupported value type: " << tensor->dtype_id();
    }
  }

  MS_EXCEPTION_IF_NULL(value_ptr);

  if (!value_ptr->isa<Scalar>()) {
    MS_LOG(EXCEPTION) << "The current tensor's value must be a scalar, please check!";
  }
  auto scalar_ptr = value_ptr->cast<ScalarPtr>();
  MS_EXCEPTION_IF_NULL(scalar_ptr);
  return scalar_ptr;
}

template <>
inline std::vector<int64_t> ConvertKernelTensor<std::vector<int64_t>>(mindspore::kernel::KernelTensor *tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  return tensor->GetValueWithCheck<std::vector<int64_t>>();
}

template <>
inline std::vector<float> ConvertKernelTensor<std::vector<float>>(mindspore::kernel::KernelTensor *tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  return tensor->GetValueWithCheck<std::vector<float>>();
}

template <>
inline std::vector<uint8_t> ConvertKernelTensor<std::vector<uint8_t>>(mindspore::kernel::KernelTensor *tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  return tensor->GetValueWithCheck<std::vector<uint8_t>>();
}

template <>
inline TypeId ConvertKernelTensor<TypeId>(mindspore::kernel::KernelTensor *tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  return tensor->dtype_id();
}

template <>
inline std::vector<mindspore::kernel::KernelTensor *>
ConvertKernelTensor<std::vector<mindspore::kernel::KernelTensor *>>(mindspore::kernel::KernelTensor *tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  if (tensor->type_id() != kObjectTypeTuple && tensor->type_id() != kObjectTypeList) {
    return {tensor};
  }
  auto shape = tensor->GetShapeVector();
  if (shape.empty()) {
    MS_LOG(EXCEPTION) << "The current tensor is a tuple of tensors, but its shape is empty!";
  }
  if (shape[kIndex0] <= 0) {
    MS_LOG(EXCEPTION) << shape << " is an invalid shape, please check op infer!";
  }
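  // The tuple/list elements are laid out contiguously in one buffer: the leading dimension is the element count,
  // so each element views a slice of size tensor->size() / split_num with that leading dimension dropped.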

  std::vector<mindspore::kernel::KernelTensor *> res;

  auto split_num = shape[kIndex0];
  auto offset = tensor->size() / split_num;
  auto new_shape = shape;
  new_shape.erase(new_shape.begin());

  for (int i = 0; i < split_num; ++i) {
    auto new_tensor = new KernelTensor(*tensor);
    new_tensor->SetType(std::make_shared<TensorType>(TypeIdToType(tensor->dtype_id())));
    auto tensor_shape = std::make_shared<abstract::TensorShape>();
    tensor_shape->SetShapeVector(new_shape);
    new_tensor->SetShape(tensor_shape);
    new_tensor->set_device_ptr(
      reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor->device_ptr()) + offset * i));
    new_tensor->set_size(offset);
    (void)res.emplace_back(new_tensor);
  }
  return res;
}

inline void Release(aclTensor *p) {
  static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
  if (aclDestroyTensor == nullptr) {
    return;
  }
  aclDestroyTensor(p);
}

inline void Release(aclScalar *p) {
  static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar);
  if (aclDestroyScalar == nullptr) {
    return;
  }
  aclDestroyScalar(p);
}

inline void Release(aclIntArray *p) {
  static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray);
  if (aclDestroyIntArray == nullptr) {
    return;
  }

  aclDestroyIntArray(p);
}

inline void Release(aclFloatArray *p) {
  static const auto aclDestroyFloatArray = GET_OP_API_FUNC(aclDestroyFloatArray);
  if (aclDestroyFloatArray == nullptr) {
    return;
  }

  aclDestroyFloatArray(p);
}

inline void Release(aclBoolArray *p) {
  static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray);
  if (aclDestroyBoolArray == nullptr) {
    return;
  }

  aclDestroyBoolArray(p);
}

inline void Release(aclTensorList *p) {
  static const auto aclDestroyTensorList = GET_OP_API_FUNC(aclDestroyTensorList);
  if (aclDestroyTensorList == nullptr) {
    return;
  }

  aclDestroyTensorList(p);
}

template <typename T>
void Release(T value) {
  (void)value;
}

template <typename Tuple, size_t... I>
void CallRelease(Tuple t, std::index_sequence<I...>) {
  (void)std::initializer_list<int>{(Release(std::get<I>(t)), 0)...};
}

template <typename Tuple>
void ReleaseConvertTypes(const Tuple &t) {
  static constexpr auto size = std::tuple_size<Tuple>::value;
  CallRelease(t, std::make_index_sequence<size>{});
}
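// Release() has dedicated overloads only for the ACL object types created above; for pass-through values
// (plain scalars, strings, enums) the generic template is a no-op, so ReleaseConvertTypes can safely be applied
// to any tuple produced by ConvertTypes.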

// return a Scalar with the input type
#define MAKE_SCALAR(num, typeid, out)                                        \
  switch (typeid) {                                                          \
    case kNumberTypeFloat32: {                                               \
      out = std::make_shared<FP32Imm>(static_cast<float>(num));              \
      break;                                                                 \
    }                                                                        \
    case kNumberTypeFloat16: {                                               \
      out = std::make_shared<FP32Imm>(static_cast<float>(num));              \
      break;                                                                 \
    }                                                                        \
    case kNumberTypeFloat64: {                                               \
      out = std::make_shared<FP64Imm>(static_cast<double>(num));             \
      break;                                                                 \
    }                                                                        \
    case kNumberTypeInt8: {                                                  \
      out = std::make_shared<Int8Imm>(static_cast<int8_t>(num));             \
      break;                                                                 \
    }                                                                        \
    case kNumberTypeInt16: {                                                 \
      out = std::make_shared<Int16Imm>(static_cast<int16_t>(num));           \
      break;                                                                 \
    }                                                                        \
    case kNumberTypeInt32: {                                                 \
      out = std::make_shared<Int32Imm>(static_cast<int>(num));               \
      break;                                                                 \
    }                                                                        \
    case kNumberTypeInt64: {                                                 \
      out = std::make_shared<Int64Imm>(static_cast<int64_t>(num));           \
      break;                                                                 \
    }                                                                        \
    case kNumberTypeBool: {                                                  \
      out = std::make_shared<BoolImm>(static_cast<bool>(num));               \
      break;                                                                 \
    }                                                                        \
    case kNumberTypeUInt8: {                                                 \
      out = std::make_shared<UInt8Imm>(static_cast<uint8_t>(num));           \
      break;                                                                 \
    }                                                                        \
    case kNumberTypeBFloat16: {                                              \
      out = std::make_shared<BF16Imm>(static_cast<bfloat16>(num));           \
      break;                                                                 \
    }                                                                        \
    default: {                                                               \
      MS_LOG(EXCEPTION) << "Not support typeid " << TypeIdToString(typeid);  \
    }                                                                        \
  }
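// Sketch of intended usage (variable names are illustrative only):
//   ScalarPtr scalar_out;
//   MAKE_SCALAR(1.5, kNumberTypeFloat32, scalar_out);  // scalar_out now holds FP32Imm(1.5f)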

} // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CONVERT_H_