• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CONVERT_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CONVERT_H_
19 
20 #include <dlfcn.h>
21 #include <vector>
22 #include <string>
23 #include <map>
24 #include <memory>
25 #include <algorithm>
26 #include <functional>
27 #include <regex>
28 #include <utility>
29 #include <tuple>
30 #include "acl/acl_base.h"
31 #include "ir/tensor.h"
32 #include "transform/acl_ir/acl_convert.h"
33 #include "plugin/device/ascend/hal/common/ascend_utils.h"
34 #include "plugin/device/ascend/hal/device/ascend_device_address.h"
35 #include "transform/acl_ir/acl_helper.h"
36 #include "runtime/device/ms_device_shape_transfer.h"
37 
38 namespace mindspore::transform {
// Api data struct.
// Opaque op-api (aclnn) object types. They are only forward-declared here; this header never looks inside them
// and handles them exclusively through pointers returned by / passed to the op-api functions typed below.
typedef struct aclOpExecutor aclOpExecutor;
typedef struct aclTensor aclTensor;
typedef struct aclScalar aclScalar;
typedef struct aclIntArray aclIntArray;
typedef struct aclFloatArray aclFloatArray;
typedef struct aclBoolArray aclBoolArray;
typedef struct aclTensorList aclTensorList;
47 
// Create operator.
// Function-pointer types of the op-api constructors. GET_OP_API_FUNC(name) casts the symbol it resolves at runtime
// to the alias `_name`, so each alias must match the library's real signature exactly.
// For _aclCreateTensor: view_dims/stride/offset describe the logical view, storage_dims describe the buffer at
// tensor_data (callers in this header pass either the view shape or the storage/original shape there). The unit
// of `offset` is not visible here — presumably elements; confirm against the aclnn documentation.
using _aclCreateTensor = aclTensor *(*)(const int64_t *view_dims, uint64_t view_dims_num, aclDataType data_type,
                                        const int64_t *stride, int64_t offset, aclFormat format,
                                        const int64_t *storage_dims, uint64_t storage_dims_num, void *tensor_data);
using _aclCreateScalar = aclScalar *(*)(void *value, aclDataType data_type);
using _aclCreateIntArray = aclIntArray *(*)(const int64_t *value, uint64_t size);
using _aclCreateFloatArray = aclFloatArray *(*)(const float *value, uint64_t size);
using _aclCreateBoolArray = aclBoolArray *(*)(const bool *value, uint64_t size);
using _aclCreateTensorList = aclTensorList *(*)(const aclTensor *const *value, uint64_t size);
// Destroy operator.
// Matching destructors. They return an int (presumably an aclnn status code — confirm); the Release helpers in
// this header ignore it.
using _aclDestroyTensor = int (*)(const aclTensor *tensor);
using _aclDestroyScalar = int (*)(const aclScalar *scalar);
using _aclDestroyIntArray = int (*)(const aclIntArray *array);
using _aclDestroyFloatArray = int (*)(const aclFloatArray *array);
using _aclDestroyBoolArray = int (*)(const aclBoolArray *array);
using _aclDestroyTensorList = int (*)(const aclTensorList *array);
64 
// Opened op-api libraries as (dlopen handle, library name) pairs. GetOpApiFunc searches them in this order.
extern std::vector<std::pair<void *, std::string>> opapi_lib_handle;
// Expected to fill opapi_lib_handle; GetOpApiFunc calls it lazily when the list is still empty.
extern void LoadOpApiLib();
67 
// Get op api func.
// Path of the built-in op-api library. Presumably relative to the CANN installation root — confirm in LoadOpApiLib.
inline std::string GetOpApiLibName() {
  const std::string builtin_lib_path = "/lib64/libopapi.so";
  return builtin_lib_path;
}
70 
// Path of the custom-operator op-api library. Presumably relative to a custom-op package root — confirm in
// LoadOpApiLib.
inline std::string GetCustOpApiLibName() {
  const std::string custom_lib_path = "/op_api/lib/libcust_opapi.so";
  return custom_lib_path;
}
72 
GetOpApiFuncFromLib(void * handler,const char * lib_name,const char * api_name)73 inline void *GetOpApiFuncFromLib(void *handler, const char *lib_name, const char *api_name) {
74   MS_EXCEPTION_IF_NULL(handler);
75   auto func = dlsym(handler, api_name);
76   if (func == nullptr) {
77     MS_LOG(DEBUG) << "Dlsym " << api_name << " from " << lib_name << " failed!" << dlerror();
78   }
79   return func;
80 }
81 
GetOpApiLibHandler(const std::string & lib_path)82 inline void *GetOpApiLibHandler(const std::string &lib_path) {
83   auto handler = dlopen(lib_path.c_str(), RTLD_LAZY);
84   if (handler == nullptr) {
85     MS_LOG(INFO) << "Dlopen " << lib_path << " failed!" << dlerror();
86   }
87   return handler;
88 }
89 
GetOpApiFunc(const char * api_name)90 inline void *GetOpApiFunc(const char *api_name) {
91   static std::map<string, void *> opapi_cache;
92   auto res = opapi_cache.find(string(api_name));
93   if (res != opapi_cache.end()) {
94     MS_LOG(DEBUG) << "OpApi " << api_name << " hit cache.";
95     return res->second;
96   }
97   if (opapi_lib_handle.size() == 0) {
98     LoadOpApiLib();
99   }
100   for (const auto &handle : opapi_lib_handle) {
101     const auto api_func = GetOpApiFuncFromLib(handle.first, handle.second.c_str(), api_name);
102     if (api_func != nullptr) {
103       (void)opapi_cache.emplace(string(api_name), api_func);
104       MS_LOG(DEBUG) << "Get OpApiFunc [" << api_name << "] from " << handle.second;
105       return api_func;
106     }
107   }
108   MS_LOG(WARNING) << "Dlsym " << api_name << " failed!";
109   return nullptr;
110 }
111 
// Resolves the op-api function `func_name` through GetOpApiFunc and casts it to the `_func_name` pointer alias
// (e.g. GET_OP_API_FUNC(aclCreateTensor) yields a `_aclCreateTensor`). Evaluates to nullptr when the symbol was not
// found in any loaded library.
#define GET_OP_API_FUNC(func_name) reinterpret_cast<_##func_name>(GetOpApiFunc(#func_name))
113 
// Reinterprets the raw symbol address `opApiAddr` as `int (*)(Args...)`, where Args are the decayed element types
// of `params` (typically the tuple built by ConvertTypes). Only the types of `params` matter; its values are unused.
template <typename Tuple, size_t... I>
auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr, std::index_sequence<I...>) {
  using FuncPtr = int (*)(std::decay_t<decltype(std::get<I>(params))>...);
  return reinterpret_cast<FuncPtr>(opApiAddr);
}

// Convenience overload that derives the index sequence from the tuple's size.
template <typename Tuple>
auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr) {
  return ConvertToOpApiFunc(params, opApiAddr, std::make_index_sequence<std::tuple_size_v<Tuple>>{});
}
126 
127 // Convert Value
128 class OpApiTensorConverter : public AttrHelper<OpApiTensorConverter> {
129  public:
130   OpApiTensorConverter() = default;
131   ~OpApiTensorConverter() = default;
132 
133   template <typename T>
ConvertValue(const ValuePtr & value,const AttrDeclType<T> &,aclScalar ** scalar)134   void ConvertValue(const ValuePtr &value, const AttrDeclType<T> &, aclScalar **scalar) {
135     auto real_val = GetValue<T>(value);
136     MS_EXCEPTION_IF_NULL(scalar);
137     *scalar = CreateAclScalar(&real_val, GetDataType(value));
138   }
139 
ConvertValue(const ValuePtr & value,const AttrDeclType<int32_t> &,aclScalar ** scalar)140   void ConvertValue(const ValuePtr &value, const AttrDeclType<int32_t> &, aclScalar **scalar) {
141     auto real_val = static_cast<int64_t>(GetValue<int32_t>(value));
142     MS_EXCEPTION_IF_NULL(scalar);
143     *scalar = CreateAclScalar(&real_val, ACL_INT64);
144   }
145 
ConvertValue(const ValuePtr & value,const AttrDeclType<std::vector<int64_t>> &,aclIntArray * array)146   void ConvertValue(const ValuePtr &value, const AttrDeclType<std::vector<int64_t>> &, aclIntArray *array) {
147     std::vector<int64_t> array_list;
148     ConvertValueSequenceToList(value, &array_list);
149     array = CreateIntArray(array_list);
150   }
151 
ConvertValue(const ValuePtr & value,const AttrDeclType<std::vector<int32_t>> &,aclIntArray * array)152   void ConvertValue(const ValuePtr &value, const AttrDeclType<std::vector<int32_t>> &, aclIntArray *array) {
153     std::vector<int32_t> array_list;
154     ConvertValueSequenceToList(value, &array_list);
155     std::vector<int64_t> array_list_int64;
156     (void)std::transform(array_list.begin(), array_list.end(), std::back_inserter(array_list_int64),
157                          [](const int val) { return IntToLong(val); });
158     array = CreateIntArray(array_list_int64);
159   }
160 
ConvertValue(const ValuePtr & value,const AttrDeclType<std::vector<uint8_t>> &,aclBoolArray * array)161   void ConvertValue(const ValuePtr &value, const AttrDeclType<std::vector<uint8_t>> &, aclBoolArray *array) {
162     std::vector<uint8_t> array_list;
163     ConvertValueSequenceToList(value, &array_list);
164     array = CreateBoolArray(array_list);
165   }
166 
ConvertValue(const ValuePtr & value,const AttrDeclType<std::vector<float>> &,aclFloatArray * array)167   void ConvertValue(const ValuePtr &value, const AttrDeclType<std::vector<float>> &, aclFloatArray *array) {
168     std::vector<float> array_list;
169     ConvertValueSequenceToList(value, &array_list);
170     array = CreateFloatArray(array_list);
171   }
172 
173   template <typename T>
CreateAclScalar(T * val,aclDataType dtype)174   aclScalar *CreateAclScalar(T *val, aclDataType dtype) {
175     static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
176     if (aclCreateScalar == nullptr) {
177       MS_LOG(EXCEPTION) << "Failed to get `aclCreateScalar` func.";
178     }
179     return aclCreateScalar(val, dtype);
180   }
181 
CreateIntArray(const std::vector<int64_t> & val)182   aclIntArray *CreateIntArray(const std::vector<int64_t> &val) {
183     static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
184     if (aclCreateIntArray == nullptr) {
185       MS_LOG(EXCEPTION) << "Failed to get `aclCreateIntArray` func.";
186     }
187     return aclCreateIntArray(val.data(), val.size());
188   }
189 
CreateBoolArray(const std::vector<uint8_t> & val)190   aclBoolArray *CreateBoolArray(const std::vector<uint8_t> &val) {
191     static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
192     if (aclCreateBoolArray == nullptr) {
193       MS_LOG(EXCEPTION) << "Failed to get `aclCreateBoolArray` func.";
194     }
195     return aclCreateBoolArray(reinterpret_cast<const bool *>(val.data()), val.size());
196   }
197 
CreateFloatArray(const std::vector<float> & val)198   aclFloatArray *CreateFloatArray(const std::vector<float> &val) {
199     static const auto aclCreateFloatArray = GET_OP_API_FUNC(aclCreateFloatArray);
200     if (aclCreateFloatArray == nullptr) {
201       MS_LOG(EXCEPTION) << "Failed to get `aclCreateFloatArray` func.";
202     }
203     return aclCreateFloatArray(val.data(), val.size());
204   }
205 
206  private:
GetDataType(const ValuePtr & value)207   inline aclDataType GetDataType(const ValuePtr &value) { return AclConverter::ConvertType(value->type()->type_id()); }
208 };
209 
ConvertType(const mindspore::kernel::KernelTensor * tensor)210 inline aclTensor *ConvertType(const mindspore::kernel::KernelTensor *tensor) {
211   static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
212   if (aclCreateTensor == nullptr) {
213     return nullptr;
214   }
215   if (tensor == nullptr || tensor->type_id() == kMetaTypeNone) {
216     return nullptr;
217   }
218 
219   auto acl_data_type = AclConverter::ConvertType(tensor->dtype_id());
220   const auto &shape = tensor->GetShapeVector();
221   const auto shape_size = shape.size();
222   aclFormat format = ACL_FORMAT_ND;
223   switch (shape_size) {
224     case 3:
225       format = ACL_FORMAT_NCL;
226       break;
227     case 4:
228       format = ACL_FORMAT_NCHW;
229       break;
230     case 5:
231       format = ACL_FORMAT_NCDHW;
232       break;
233     default:
234       format = ACL_FORMAT_ND;
235   }
236 
237   aclTensor *acl_tensor = nullptr;
238   const auto &storage_info = tensor->tensor_storage_info();
239   if (storage_info == nullptr) {
240     // Create strides.
241     auto strides = shape;
242     if (!strides.empty()) {
243       strides.erase(strides.begin());
244     }
245     strides.push_back(1);
246     for (int i = static_cast<int>(strides.size()) - 2; i >= 0; i--) {
247       strides[i] = strides[i] * strides[i + 1];
248     }
249     acl_tensor = aclCreateTensor(shape.data(), shape_size, acl_data_type, strides.data(), 0, format, shape.data(),
250                                  shape.size(), tensor->device_ptr());
251   } else {
252     const auto &strides = storage_info->strides;
253     const auto &storage_shape = storage_info->ori_shape;
254     acl_tensor =
255       aclCreateTensor(shape.data(), shape_size, acl_data_type, strides.data(), SizeToLong(storage_info->storage_offset),
256                       format, storage_shape.data(), storage_shape.size(), tensor->device_ptr());
257   }
258 
259   return acl_tensor;
260 }
261 
ConvertType(const device::DeviceAddressPtr & device_address)262 inline aclTensor *ConvertType(const device::DeviceAddressPtr &device_address) {
263   static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
264   if (aclCreateTensor == nullptr) {
265     return nullptr;
266   }
267   if (device_address == nullptr) {
268     return nullptr;
269   }
270 
271   auto acl_data_type = AclConverter::ConvertType(device_address->type_id());
272   const auto &shape = device_address->GetShapeVector();
273   const auto shape_size = shape.size();
274   aclFormat format = ACL_FORMAT_ND;
275   switch (shape_size) {
276     case 3:
277       format = ACL_FORMAT_NCL;
278       break;
279     case 4:
280       format = ACL_FORMAT_NCHW;
281       break;
282     case 5:
283       format = ACL_FORMAT_NCDHW;
284       break;
285     default:
286       format = ACL_FORMAT_ND;
287   }
288 
289   aclTensor *acl_tensor = nullptr;
290   const auto &storage_info = device_address->address_common()->tensor_storage_info_;
291   if (storage_info == nullptr) {
292     // Create strides.
293     auto strides = shape;
294     if (!strides.empty()) {
295       strides.erase(strides.begin());
296     }
297     strides.push_back(1);
298     for (int i = static_cast<int>(strides.size()) - 2; i >= 0; i--) {
299       strides[i] = strides[i] * strides[i + 1];
300     }
301     acl_tensor = aclCreateTensor(shape.data(), shape_size, acl_data_type, strides.data(), 0, format, shape.data(),
302                                  shape.size(), device_address->GetMutablePtr());
303   } else {
304     const auto &strides = storage_info->strides;
305     const auto &storage_shape = storage_info->ori_shape;
306     acl_tensor =
307       aclCreateTensor(shape.data(), shape_size, acl_data_type, strides.data(), SizeToLong(storage_info->storage_offset),
308                       format, storage_shape.data(), storage_shape.size(), device_address->GetMutablePtr());
309   }
310 
311   return acl_tensor;
312 }
313 
ConvertType(mindspore::kernel::KernelTensor * tensor)314 inline aclTensor *ConvertType(mindspore::kernel::KernelTensor *tensor) {
315   return ConvertType(reinterpret_cast<const mindspore::kernel::KernelTensor *>(tensor));
316 }
317 
ConvertType(std::pair<mindspore::kernel::KernelTensor *,bool> tensor_and_trans)318 inline aclTensor *ConvertType(std::pair<mindspore::kernel::KernelTensor *, bool> tensor_and_trans) {
319   static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
320   if (aclCreateTensor == nullptr) {
321     return nullptr;
322   }
323   auto tensor = tensor_and_trans.first;
324   auto trans = tensor_and_trans.second;
325   auto acl_data_type = AclConverter::ConvertType(tensor->dtype_id());
326   auto shape = tensor->GetShapeVector();
327   const auto shape_size = shape.size();
328   aclFormat format = ACL_FORMAT_ND;
329   switch (shape_size) {
330     case 3:
331       format = ACL_FORMAT_NCL;
332       break;
333     case 4:
334       format = ACL_FORMAT_NCHW;
335       break;
336     case 5:
337       format = ACL_FORMAT_NCDHW;
338       break;
339     default:
340       format = ACL_FORMAT_ND;
341   }
342 
343   // Create strides.
344   auto strides = shape;
345   if (!strides.empty()) {
346     strides.erase(strides.begin());
347   }
348   strides.push_back(1);
349   for (int i = static_cast<int>(strides.size()) - 2; i >= 0; i--) {
350     strides[i] = strides[i] * strides[i + 1];
351   }
352   // Check if shape need transpose.
353   if (trans) {
354     std::swap(shape[shape.size() - 1], shape[shape.size() - 2]);
355     std::swap(strides[strides.size() - 1], strides[strides.size() - 2]);
356   }
357   auto acl_tensor = aclCreateTensor(shape.data(), shape_size, acl_data_type, strides.data(), 0, format, shape.data(),
358                                     shape_size, tensor->device_ptr());
359   return acl_tensor;
360 }
361 
GetViewShapeAndStride(const tensor::BaseTensorPtr & tensor,const device::DeviceAddressPtr & device_address)362 inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, int64_t, std::vector<int64_t>> GetViewShapeAndStride(
363   const tensor::BaseTensorPtr &tensor, const device::DeviceAddressPtr &device_address) {
364   MS_EXCEPTION_IF_NULL(tensor);
365   MS_EXCEPTION_IF_NULL(device_address);
366 
367   const auto &storage_info = tensor->storage_info();
368   // Get dev shape
369   auto get_dev_shape = [device_address, tensor](const std::string &tensor_format, const auto &tensor_shape) {
370     if (transform::AclHelper::CheckDefaultSupportFormat(tensor_format)) {
371       return tensor_shape;
372     }
373     int64_t groups = 1;
374     auto node_idx = device_address->GetNodeIndex();
375     if (node_idx.first != nullptr) {
376       groups = common::AnfAlgo::GetAttrGroups(node_idx.first, node_idx.second);
377     }
378     return trans::TransShapeToDevice(tensor_shape, tensor_format, tensor->data_type(), groups);
379   };
380 
381   const auto &tensor_shape = tensor->shape();
382   const auto &tensor_format = device_address->format();
383   if (storage_info == nullptr) {
384     const auto &dev_shape = get_dev_shape(tensor_format, tensor_shape);
385 
386     // Get contiguous strides
387     std::vector<int64_t> strides(tensor_shape.size(), 1);
388     for (int i = static_cast<int>(strides.size()) - 2; i >= 0; i--) {
389       strides[i] = tensor_shape[i + 1] * strides[i + 1];
390     }
391 
392     return std::make_tuple(tensor_shape, strides, 0, dev_shape);
393   } else {
394     const auto &dev_shape = get_dev_shape(tensor_format, storage_info->ori_shape);
395     return std::make_tuple(tensor_shape, storage_info->strides, storage_info->storage_offset, dev_shape);
396   }
397 }
398 
ConvertType(const tensor::BaseTensorPtr & tensor)399 inline aclTensor *ConvertType(const tensor::BaseTensorPtr &tensor) {
400   MS_EXCEPTION_IF_NULL(tensor);
401   static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
402   if (aclCreateTensor == nullptr) {
403     return nullptr;
404   }
405   auto shape = tensor->shape();
406   const auto shape_size = shape.size();
407   aclFormat format = ACL_FORMAT_ND;
408   switch (shape_size) {
409     case kSizeThree:
410       format = ACL_FORMAT_NCL;
411       break;
412     case kSizeFour:
413       format = ACL_FORMAT_NCHW;
414       break;
415     case kSizeFive:
416       format = ACL_FORMAT_NCDHW;
417       break;
418     default:
419       format = ACL_FORMAT_ND;
420   }
421   auto acl_data_type = AclConverter::ConvertType(tensor->data_type());
422   auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
423   if (device_address->GetMutablePtr() == nullptr) {
424     MS_LOG(EXCEPTION) << "The device memory is null, please allocate the device memory for tensor "
425                       << tensor->ToString();
426   }
427   auto [view_shape, strides, offset, ori_dev_shape] = GetViewShapeAndStride(tensor, device_address);
428   auto acl_tensor = aclCreateTensor(view_shape.data(), view_shape.size(), acl_data_type, strides.data(), offset, format,
429                                     ori_dev_shape.data(), ori_dev_shape.size(), device_address->GetMutablePtr());
430   return acl_tensor;
431 }
432 
ConvertType(const std::optional<tensor::BaseTensorPtr> & value)433 inline aclTensor *ConvertType(const std::optional<tensor::BaseTensorPtr> &value) {
434   if (value.has_value()) {
435     return ConvertType(value.value());
436   }
437   return nullptr;
438 }
439 
ConvertType(const std::vector<int64_t> & int_array)440 inline aclIntArray *ConvertType(const std::vector<int64_t> &int_array) {
441   if (int_array.empty()) {
442     MS_LOG(DEBUG) << "int array is empty!";
443   }
444   static OpApiTensorConverter converter;
445   return converter.CreateIntArray(int_array);
446 }
447 
ConvertType(const std::vector<float> & float_array)448 inline aclFloatArray *ConvertType(const std::vector<float> &float_array) {
449   if (float_array.empty()) {
450     MS_LOG(ERROR) << "float array is empty!";
451   }
452   static OpApiTensorConverter converter;
453   return converter.CreateFloatArray(float_array);
454 }
455 
ConvertType(const std::vector<uint8_t> & bool_array)456 inline aclBoolArray *ConvertType(const std::vector<uint8_t> &bool_array) {
457   if (bool_array.empty()) {
458     MS_LOG(ERROR) << "bool array is empty!";
459   }
460   static OpApiTensorConverter converter;
461   return converter.CreateBoolArray(bool_array);
462 }
463 
ConvertType(const std::vector<tensor::BaseTensorPtr> & tensor_list)464 inline aclTensorList *ConvertType(const std::vector<tensor::BaseTensorPtr> &tensor_list) {
465   if (tensor_list.empty()) {
466     MS_LOG(DEBUG) << "tensor list is empty!";
467   }
468   static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
469   std::vector<aclTensor *> tmp;
470   std::transform(tensor_list.begin(), tensor_list.end(), std::back_inserter(tmp),
471                  [](const tensor::BaseTensorPtr &tensor) { return ConvertType(tensor); });
472   return aclCreateTensorList(tmp.data(), tmp.size());
473 }
474 
ConvertType(const tensor::TensorPtr & tensor)475 inline aclTensor *ConvertType(const tensor::TensorPtr &tensor) {
476   return ConvertType(tensor->cast<tensor::BaseTensorPtr>());
477 }
478 
ConvertType(const std::optional<tensor::TensorPtr> & value)479 inline aclTensor *ConvertType(const std::optional<tensor::TensorPtr> &value) {
480   if (value.has_value()) {
481     return ConvertType(value.value());
482   }
483   return nullptr;
484 }
485 
ConvertType(const std::vector<tensor::TensorPtr> & tensor_list)486 inline aclTensorList *ConvertType(const std::vector<tensor::TensorPtr> &tensor_list) {
487   if (tensor_list.empty()) {
488     MS_LOG(DEBUG) << "tensor list is empty!";
489   }
490   static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
491   std::vector<aclTensor *> tmp;
492   std::transform(tensor_list.begin(), tensor_list.end(), std::back_inserter(tmp),
493                  [](const tensor::TensorPtr &tensor) { return ConvertType(tensor); });
494   return aclCreateTensorList(tmp.data(), tmp.size());
495 }
496 
ConvertType(const std::vector<mindspore::kernel::KernelTensor * > & tensor_list)497 inline aclTensorList *ConvertType(const std::vector<mindspore::kernel::KernelTensor *> &tensor_list) {
498   if (tensor_list.empty()) {
499     MS_LOG(DEBUG) << "tensor list is empty!";
500   }
501   static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
502   std::vector<aclTensor *> tmp;
503   std::transform(tensor_list.begin(), tensor_list.end(), std::back_inserter(tmp),
504                  [](mindspore::kernel::KernelTensor *tensor) { return ConvertType(tensor); });
505   return aclCreateTensorList(tmp.data(), tmp.size());
506 }
507 
ConvertType(const ScalarPtr & value)508 inline aclScalar *ConvertType(const ScalarPtr &value) {
509   if (value == nullptr) {
510     // for None
511     return nullptr;
512   }
513   aclScalar *acl_scalar;
514   static OpApiTensorConverter converter;
515   if (value->isa<BoolImm>()) {
516     converter.ConvertValue(value, AttrDeclType<bool>(), &acl_scalar);
517   } else if (value->isa<Int64Imm>()) {
518     converter.ConvertValue(value, AttrDeclType<int64_t>(), &acl_scalar);
519   } else if (value->isa<FP64Imm>()) {
520     converter.ConvertValue(value, AttrDeclType<double>(), &acl_scalar);
521   } else if (value->isa<FP32Imm>()) {
522     converter.ConvertValue(value, AttrDeclType<float>(), &acl_scalar);
523   } else if (value->isa<Int32Imm>()) {
524     converter.ConvertValue(value, AttrDeclType<int32_t>(), &acl_scalar);
525   } else if (value->isa<Int8Imm>()) {
526     converter.ConvertValue(value, AttrDeclType<int8_t>(), &acl_scalar);
527   } else if (value->isa<Int16Imm>()) {
528     converter.ConvertValue(value, AttrDeclType<int16_t>(), &acl_scalar);
529   } else if (value->isa<UInt8Imm>()) {
530     converter.ConvertValue(value, AttrDeclType<uint8_t>(), &acl_scalar);
531   } else if (value->isa<FP64Imm>()) {
532     converter.ConvertValue(value, AttrDeclType<double>(), &acl_scalar);
533   } else if (value->isa<BF16Imm>()) {
534     converter.ConvertValue(value, AttrDeclType<bfloat16>(), &acl_scalar);
535   } else {
536     MS_LOG(EXCEPTION) << "Currently not support value: " << value->ToString();
537   }
538   return acl_scalar;
539 }
540 
ConvertType(const std::optional<ScalarPtr> & value)541 inline aclScalar *ConvertType(const std::optional<ScalarPtr> &value) {
542   if (value.has_value()) {
543     return ConvertType(value.value());
544   }
545   return nullptr;
546 }
547 
ConvertType(TypeId type_id)548 inline aclDataType ConvertType(TypeId type_id) { return AclConverter::ConvertType(type_id); }
549 
ConvertType(const TypePtr & type)550 inline aclDataType ConvertType(const TypePtr &type) { return AclConverter::ConvertType(type->type_id()); }
551 
// Exposes a string argument as a C string. The returned pointer aliases `value`'s buffer, so it is only valid while
// `value` is alive and unmodified.
inline const char *ConvertType(const std::string &value) {
  const char *const raw = value.data();
  return raw;
}
553 
// Pass-through for scalar arguments (arithmetic, enum and pointer types): they are returned unchanged.
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
T ConvertType(T value) {
  T unchanged = value;
  return unchanged;
}
558 
// Applies the matching ConvertType overload to every argument and returns the results as a tuple whose elements
// follow argument order.
template <typename... Ts>
constexpr auto ConvertTypes(const Ts &... args) {
  auto converted = std::make_tuple(ConvertType(args)...);
  return converted;
}
563 
564 template <typename T>
ConvertKernelTensor(mindspore::kernel::KernelTensor * tensor)565 T ConvertKernelTensor(mindspore::kernel::KernelTensor *tensor) {
566   MS_EXCEPTION_IF_NULL(tensor);
567   return tensor->GetValueWithCheck<T>();
568 }
569 
570 template <>
571 inline ScalarPtr ConvertKernelTensor<ScalarPtr>(mindspore::kernel::KernelTensor *tensor) {
572   MS_EXCEPTION_IF_NULL(tensor);
573   if (tensor->dtype_id() == kMetaTypeNone) {
574     // for None
575     return nullptr;
576   }
577   auto value_ptr = tensor->GetValueTrack();
578   if (value_ptr == nullptr) {
579     if (tensor->dtype_id() == kNumberTypeBool) {
580       auto value = tensor->GetValueWithCheck<bool>();
581       value_ptr = std::make_shared<BoolImm>(value);
582     } else if (tensor->dtype_id() == kNumberTypeInt64) {
583       auto value = tensor->GetValueWithCheck<int64_t>();
584       value_ptr = std::make_shared<Int64Imm>(value);
585     } else if (tensor->dtype_id() == kNumberTypeDouble || tensor->dtype_id() == kNumberTypeFloat64) {
586       auto value = tensor->GetValueWithCheck<double>();
587       value_ptr = std::make_shared<FP64Imm>(value);
588     } else if (tensor->dtype_id() == kNumberTypeFloat32) {
589       auto value = tensor->GetValueWithCheck<float>();
590       value_ptr = std::make_shared<FP32Imm>(value);
591     } else if (tensor->dtype_id() == kNumberTypeInt32) {
592       auto value = tensor->GetValueWithCheck<int32_t>();
593       value_ptr = std::make_shared<Int32Imm>(value);
594     } else if (tensor->dtype_id() == kNumberTypeInt8) {
595       auto value = tensor->GetValueWithCheck<int8_t>();
596       value_ptr = std::make_shared<Int8Imm>(value);
597     } else if (tensor->dtype_id() == kNumberTypeInt16) {
598       auto value = tensor->GetValueWithCheck<int16_t>();
599       value_ptr = std::make_shared<Int16Imm>(value);
600     } else if (tensor->dtype_id() == kNumberTypeUInt8) {
601       auto value = tensor->GetValueWithCheck<uint8_t>();
602       value_ptr = std::make_shared<UInt8Imm>(value);
603     } else if (tensor->dtype_id() == kNumberTypeBFloat16) {
604       auto value = tensor->GetValueWithCheck<bfloat16>();
605       value_ptr = std::make_shared<BF16Imm>(value);
606     } else {
607       MS_LOG(EXCEPTION) << "Currently not support value type: " << tensor->dtype_id();
608     }
609   }
610 
611   MS_EXCEPTION_IF_NULL(value_ptr);
612 
613   if (!value_ptr->isa<Scalar>()) {
614     MS_LOG(EXCEPTION) << "Current tensor's must be a scalar, please check!";
615   }
616   auto scalar_ptr = value_ptr->cast<ScalarPtr>();
617   MS_EXCEPTION_IF_NULL(scalar_ptr);
618   return scalar_ptr;
619 }
620 
621 template <>
622 inline std::vector<int64_t> ConvertKernelTensor<std::vector<int64_t>>(mindspore::kernel::KernelTensor *tensor) {
623   MS_EXCEPTION_IF_NULL(tensor);
624   return tensor->GetValueWithCheck<std::vector<int64_t>>();
625 }
626 
627 template <>
628 inline std::vector<float> ConvertKernelTensor<std::vector<float>>(mindspore::kernel::KernelTensor *tensor) {
629   MS_EXCEPTION_IF_NULL(tensor);
630   return tensor->GetValueWithCheck<std::vector<float>>();
631 }
632 
633 template <>
634 inline std::vector<uint8_t> ConvertKernelTensor<std::vector<uint8_t>>(mindspore::kernel::KernelTensor *tensor) {
635   MS_EXCEPTION_IF_NULL(tensor);
636   return tensor->GetValueWithCheck<std::vector<uint8_t>>();
637 }
638 
639 template <>
640 inline TypeId ConvertKernelTensor<TypeId>(mindspore::kernel::KernelTensor *tensor) {
641   MS_EXCEPTION_IF_NULL(tensor);
642   return tensor->dtype_id();
643 }
644 
// Splits a tuple/list KernelTensor into one KernelTensor per element along dimension 0. Any other tensor is returned
// as a single-element vector containing `tensor` itself.
// Each element aliases the source device memory. Element i starts at byte offset i * (size() / shape[0]), since the
// offset is applied to a uint8_t pointer. It is re-typed as a tensor of the source dtype with shape shape[1:] and
// byte size size() / shape[0]. Any remainder of that division is ignored — size() is presumably divisible by
// shape[0]; confirm.
// NOTE(review): the element KernelTensors are allocated with `new` and returned as raw pointers. Nothing in this
// header frees them, so they leak on every call unless the caller deletes them — confirm ownership at the call
// sites.
template <>
inline std::vector<mindspore::kernel::KernelTensor *>
ConvertKernelTensor<std::vector<mindspore::kernel::KernelTensor *>>(mindspore::kernel::KernelTensor *tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  if (tensor->type_id() != kObjectTypeTuple && tensor->type_id() != kObjectTypeList) {
    return {tensor};
  }
  auto shape = tensor->GetShapeVector();
  if (shape.empty()) {
    MS_LOG(EXCEPTION) << "Current tensor is a tuple of tensor, but get a empty shape!";
  }
  if (shape[kIndex0] <= 0) {
    MS_LOG(EXCEPTION) << shape << " is an invalid shape, please check op infer!";
  }

  std::vector<mindspore::kernel::KernelTensor *> res;

  // Number of elements, and the byte size of each element's slice of the source buffer.
  auto split_num = shape[kIndex0];
  auto offset = tensor->size() / split_num;
  // Element shape: the source shape without its leading (element-count) dimension.
  auto new_shape = shape;
  new_shape.erase(new_shape.begin());

  for (int i = 0; i < split_num; ++i) {
    // Copy of the source KernelTensor, re-typed as a plain tensor of the source dtype.
    auto new_tensor = new KernelTensor(*tensor);
    new_tensor->SetType(std::make_shared<TensorType>(TypeIdToType(tensor->dtype_id())));
    auto tensor_shape = std::make_shared<abstract::TensorShape>();
    tensor_shape->SetShapeVector(new_shape);
    new_tensor->SetShape(tensor_shape);
    // Point at the i-th slice of the source memory (no data is copied).
    new_tensor->set_device_ptr(
      reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor->device_ptr()) + offset * i));
    new_tensor->set_size(offset);
    (void)res.emplace_back(new_tensor);
  }
  return res;
}
680 
Release(aclTensor * p)681 inline void Release(aclTensor *p) {
682   static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
683   if (aclDestroyTensor == nullptr) {
684     return;
685   }
686   aclDestroyTensor(p);
687 }
688 
Release(aclScalar * p)689 inline void Release(aclScalar *p) {
690   static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar);
691   if (aclDestroyScalar == nullptr) {
692     return;
693   }
694   aclDestroyScalar(p);
695 }
696 
Release(aclIntArray * p)697 inline void Release(aclIntArray *p) {
698   static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray);
699   if (aclDestroyIntArray == nullptr) {
700     return;
701   }
702 
703   aclDestroyIntArray(p);
704 }
705 
Release(aclBoolArray * p)706 inline void Release(aclBoolArray *p) {
707   static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray);
708   if (aclDestroyBoolArray == nullptr) {
709     return;
710   }
711 
712   aclDestroyBoolArray(p);
713 }
714 
Release(aclTensorList * p)715 inline void Release(aclTensorList *p) {
716   static const auto aclDestroyTensorList = GET_OP_API_FUNC(aclDestroyTensorList);
717   if (aclDestroyTensorList == nullptr) {
718     return;
719   }
720 
721   aclDestroyTensorList(p);
722 }
723 
// Fallback for converted values that own no op-api object (plain scalars, C strings, ...): nothing to release.
template <typename T>
void Release([[maybe_unused]] T value) {}

// Calls the matching Release overload on every element of `t`, from the first element to the last.
template <typename Tuple, size_t... I>
void CallRelease(Tuple t, std::index_sequence<I...>) {
  (Release(std::get<I>(t)), ...);
}

// Releases every value produced by ConvertTypes.
template <typename Tuple>
void ReleaseConvertTypes(const Tuple &t) {
  CallRelease(t, std::make_index_sequence<std::tuple_size_v<Tuple>>{});
}
739 
// Builds a scalar Value whose immediate type matches `type_id` and assigns it to `out`.
//   num     - value to wrap; it is static_cast to the C++ type of the selected immediate.
//   type_id - TypeId choosing the immediate kind; kNumberTypeFloat16 is stored as an FP32Imm.
//   out     - assignable target receiving the std::make_shared result.
// Any other TypeId raises through MS_LOG(EXCEPTION).
// Note: the parameter is deliberately not called `typeid`. That is a C++ keyword, and using it as a
// macro parameter would silently replace every `typeid` token inside the macro body.
#define MAKE_SCALAR(num, type_id, out)                                        \
  switch (type_id) {                                                          \
    case kNumberTypeFloat32:                                                  \
    case kNumberTypeFloat16: {                                                \
      (out) = std::make_shared<FP32Imm>(static_cast<float>(num));             \
      break;                                                                  \
    }                                                                         \
    case kNumberTypeFloat64: {                                                \
      (out) = std::make_shared<FP64Imm>(static_cast<double>(num));            \
      break;                                                                  \
    }                                                                         \
    case kNumberTypeInt8: {                                                   \
      (out) = std::make_shared<Int8Imm>(static_cast<int8_t>(num));            \
      break;                                                                  \
    }                                                                         \
    case kNumberTypeInt16: {                                                  \
      (out) = std::make_shared<Int16Imm>(static_cast<int16_t>(num));          \
      break;                                                                  \
    }                                                                         \
    case kNumberTypeInt32: {                                                  \
      (out) = std::make_shared<Int32Imm>(static_cast<int>(num));              \
      break;                                                                  \
    }                                                                         \
    case kNumberTypeInt64: {                                                  \
      (out) = std::make_shared<Int64Imm>(static_cast<int64_t>(num));          \
      break;                                                                  \
    }                                                                         \
    case kNumberTypeBool: {                                                   \
      (out) = std::make_shared<BoolImm>(static_cast<bool>(num));              \
      break;                                                                  \
    }                                                                         \
    case kNumberTypeUInt8: {                                                  \
      (out) = std::make_shared<UInt8Imm>(static_cast<uint8_t>(num));          \
      break;                                                                  \
    }                                                                         \
    case kNumberTypeBFloat16: {                                               \
      (out) = std::make_shared<BF16Imm>(static_cast<bfloat16>(num));          \
      break;                                                                  \
    }                                                                         \
    default: {                                                                \
      MS_LOG(EXCEPTION) << "Not support typeid " << TypeIdToString(type_id);  \
    }                                                                         \
  }
787 
788 }  // namespace mindspore::transform
789 #endif  // MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CONVERT_H_
790