/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_PYBOOST_PYBOOST_UTILS_H_
#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_PYBOOST_PYBOOST_UTILS_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/common/utils/tensor_future.h"
#include "runtime/pynative/op_executor.h"
#include "mindspore/core/ops/view/view_strides_calculator.h"
#include "runtime/device/device_address_utils.h"
#include "kernel/pyboost/pyboost_kernel_extra_func.h"
#include "mindspore/core/utils/simple_info.h"
#include "include/common/pynative/abstract_converter.h"

namespace mindspore {
namespace kernel {
namespace pyboost {
using AbstractConverter = pynative::AbstractConverter;
using AddressInfoPair = std::pair<std::vector<kernel::KernelTensor *>, device::DeviceAddressPtrList>;
using BaseTensor = tensor::BaseTensor;
using BaseTensorPtr = tensor::BaseTensorPtr;
AbstractBasePtr BACKEND_EXPORT ToAbstractNoValue(const BaseTensorPtr &tensor);

class BACKEND_EXPORT PyBoostUtils {
 public:
  static AbstractBasePtr InferByOpDef(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_abs);
  static void DispatchRun(const std::shared_ptr<runtime::PyBoostDeviceTask> &task);

  static DeviceSyncPtr ContiguousByDeviceAddress(const DeviceSyncPtr &device_sync);

  // Create workspace device addresses
  static device::DeviceAddressPtrList CreateWorkSpaceDeviceAddress(const KernelModPtr &kernel_mod,
                                                                   const device::DeviceContext *device_context,
                                                                   const std::string &op_name);

  // Create output tensors
  static void CreateOutputTensor(const AbstractBasePtr &abstract, std::vector<tensor::BaseTensorPtr> *outputs);
  static void CreateOutputTensor(const DeviceContext *device_context, const tensor::BaseTensorPtr &input,
                                 const TensorStorageInfoPtr &storage_info, std::vector<tensor::BaseTensorPtr> *outputs);
  static void CreateOutputTensor(const DeviceContext *device_context, const tensor::BaseTensorPtr &input,
                                 const TensorStorageInfoPtrList &storage_info_list,
                                 std::vector<tensor::BaseTensorPtr> *outputs);
  static void CreateOutputTensor(const ValueSimpleInfoPtr &output_value_simple_info,
                                 std::vector<tensor::BaseTensorPtr> *outputs);

  // Create input device addresses without kernel tensors
  template <typename... Args>
  static void PrepareOpInputs(const DeviceContext *device_context, size_t stream_id, const Args &... args) {
    size_t index = 0;
    auto add_index = [&index]() { return index++; };
    (runtime::DeviceAddressUtils::CreateInputTensorAddress(device_context, stream_id, add_index(), args), ...);
  }

  template <typename... T>
  static void MallocOpInputs(const DeviceContext *device_context, const T &... args) {
    runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostMallocInput,
                                       runtime::ProfilerRecorder::kNoName, false);
    (runtime::DeviceAddressUtils::MallocForInput(device_context, args, false), ...);
  }

  template <typename... T>
  static void MallocOpInputsForView(const DeviceContext *device_context, const T &... args) {
    runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostMallocInput,
                                       runtime::ProfilerRecorder::kNoName, false);
    (runtime::DeviceAddressUtils::MallocForInput(device_context, args, true), ...);
  }
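
  // A minimal usage sketch (illustrative, not part of this header): a PyBoost op typically creates
  // the input and output device addresses on the host thread and defers the actual memory
  // allocation to the dispatched device task. The names x, y and outputs are placeholders.
  //   PyBoostUtils::PrepareOpInputs(device_context, stream_id, x, y);
  //   PyBoostUtils::PrepareOpOutputs(device_context, stream_id, outputs);
  //   PyBoostUtils::DispatchRun(std::make_shared<runtime::PyBoostDeviceTask>([=]() {
  //     PyBoostUtils::MallocOpInputs(device_context, x, y);
  //     PyBoostUtils::MallocOpOutputs(device_context, outputs);
  //     // ... build address info and launch the kernel ...
  //   }));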

  template <typename... T, std::size_t... Index>
  static void GetAddressInfoHelper(const DeviceContext *device_context, size_t stream_id,
                                   const std::vector<AbstractBasePtr> &input_abs,
                                   std::vector<kernel::KernelTensor *> *kernel_tensor_list,
                                   device::DeviceAddressPtrList *device_address_list, std::index_sequence<Index...>,
                                   const T &... args) {
    (GetKernelTensor(device_context, stream_id, input_abs[Index], Index, kernel_tensor_list, device_address_list, args),
     ...);
  }

  template <typename... T>
  static AddressInfoPair GetAddressInfo(const DeviceContext *device_context, size_t stream_id,
                                        const std::vector<AbstractBasePtr> &input_abs, const T &... args) {
    std::vector<kernel::KernelTensor *> kernel_tensor_list;
    // Kernel tensors are raw pointers, so the owning device addresses must be returned alongside them.
    device::DeviceAddressPtrList device_address_list;
    if (input_abs.empty()) {
      std::vector<AbstractBasePtr> tmp_abs(sizeof...(args), nullptr);
      GetAddressInfoHelper(device_context, stream_id, tmp_abs, &kernel_tensor_list, &device_address_list,
                           std::make_index_sequence<sizeof...(T)>(), args...);
    } else {
      GetAddressInfoHelper(device_context, stream_id, input_abs, &kernel_tensor_list, &device_address_list,
                           std::make_index_sequence<sizeof...(T)>(), args...);
    }
    return std::make_pair(kernel_tensor_list, device_address_list);
  }

  static void LaunchKernel(const PrimitivePtr &primitive, const device::DeviceContext *device_context,
                           const AddressInfoPair &input_address_info, const AddressInfoPair &output_address_info,
                           size_t stream_id = kDefaultStreamIndex);
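
  // Sketch of how GetAddressInfo and LaunchKernel are typically combined inside the device task
  // (illustrative only; explicit input abstracts may be passed instead of the empty vector):
  //   const auto &input_address_info = PyBoostUtils::GetAddressInfo(device_context, stream_id, {}, x, y);
  //   const auto &output_address_info = PyBoostUtils::GetAddressInfo(device_context, stream_id, {}, outputs);
  //   PyBoostUtils::LaunchKernel(primitive, device_context, input_address_info, output_address_info, stream_id);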

  static void GetKernelTensor(const DeviceContext *device_context, size_t stream_id, size_t index,
                              std::vector<kernel::KernelTensor *> *kernel_tensor_list,
                              device::DeviceAddressPtrList *device_address_list, const BaseTensorPtr &tensor) {
    GetKernelTensor(device_context, stream_id, nullptr, index, kernel_tensor_list, device_address_list, tensor);
  }

  static void GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
                              const abstract::AbstractBasePtr &input_abs, size_t index,
                              std::vector<kernel::KernelTensor *> *kernel_tensor_list,
                              device::DeviceAddressPtrList *device_address_list, const BaseTensorPtr &tensor);

  template <typename T>
  static void GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
                              const abstract::AbstractBasePtr &input_abs, size_t index,
                              std::vector<kernel::KernelTensor *> *kernel_tensor_list,
                              device::DeviceAddressPtrList *device_address_list, const std::optional<T> &val) {
    if (val.has_value()) {
      GetKernelTensor(device_context, stream_id, input_abs, index, kernel_tensor_list, device_address_list,
                      val.value());
    } else {
      // Construct a none kernel tensor for the absent optional input
      MS_EXCEPTION_IF_NULL(kernel_tensor_list);
      MS_EXCEPTION_IF_NULL(device_address_list);

      const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
        std::make_shared<abstract::TensorShape>(ShapeVector()), kTypeNone, kNone, nullptr, 0, kOpFormat_DEFAULT,
        kTypeNone->type_id(), ShapeVector(), device_context->device_context_key().device_name_,
        device_context->device_context_key().device_id_);
      kernel_tensor->set_stream_id(stream_id);
      (void)kernel_tensor_list->emplace_back(kernel_tensor.get());
      auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
      (void)device_address_list->emplace_back(device_address);
    }
  }

  static void GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
                              const abstract::AbstractBasePtr &input_abs, size_t index,
                              std::vector<kernel::KernelTensor *> *kernel_tensor_list,
                              device::DeviceAddressPtrList *device_address_list,
                              const std::vector<tensor::BaseTensorPtr> &tensors);

  template <typename T>
  static void GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
                              const abstract::AbstractBasePtr &input_abs, size_t index,
                              std::vector<kernel::KernelTensor *> *kernel_tensor_list,
                              device::DeviceAddressPtrList *device_address_list, const T &val) {
    // For value inputs, allocate the device address and malloc the memory here
    auto device_address =
      runtime::DeviceAddressUtils::CreateInputAddress(device_context, stream_id, input_abs, index, val);
    MS_EXCEPTION_IF_NULL(device_address);
    MS_EXCEPTION_IF_NULL(device_address_list);
    MS_EXCEPTION_IF_NULL(kernel_tensor_list);
    (void)device_address_list->emplace_back(device_address);
    (void)kernel_tensor_list->emplace_back(device_address->kernel_tensor().get());
  }

  // Create output tensor device addresses without kernel tensors
  static void PrepareOpOutputs(const DeviceContext *device_context, size_t stream_id,
                               const std::vector<tensor::BaseTensorPtr> &outputs) {
    runtime::DeviceAddressUtils::CreateOutputTensorAddress(device_context, stream_id, outputs);
  }

  // Malloc device memory for the output tensors
  static void MallocOpOutputs(const DeviceContext *device_context, const std::vector<tensor::BaseTensorPtr> &outputs) {
    runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostMallocOutput,
                                       runtime::ProfilerRecorder::kNoName, false);
    runtime::DeviceAddressUtils::MallocForOutputs(device_context, outputs);
  }
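
  // Note (illustrative, mirrors the input helpers above): PrepareOpOutputs is expected to run on the
  // host thread right after the output tensors are created, while MallocOpOutputs runs inside the
  // dispatched device task just before the kernel is launched.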

  // Get raw kernel tensor pointers from device addresses
  static std::vector<kernel::KernelTensor *> GetKernelTensorFromAddress(
    const device::DeviceAddressPtrList &input_device_address);

  // Check whether a kernel mod is registered
  static bool IsKernelModRegistered(const std::string &device_name, const std::string &op_name);
  static bool IsPyBoostCustomRegistered(const std::string &device_name, const std::string &op_name);

  static kernel::KernelModPtr CreateKernelMod(const PrimitivePtr &prim, const std::string &op_name,
                                              const DeviceContext *device_context,
                                              const std::vector<KernelTensor *> &inputs,
                                              const std::vector<KernelTensor *> &outputs);
  // Returns whether the kernel is strictly matched, together with the selected KernelAttr
  static std::pair<bool, KernelAttr> SelectKernel(const std::vector<AbstractBasePtr> &inputs_abs,
                                                  const AbstractBasePtr &outputs_abs,
                                                  const DeviceContext *device_context, const std::string &op_name);
  static std::optional<tensor::BaseTensorPtr> CastTensor(const std::optional<tensor::BaseTensorPtr> &tensor,
                                                         const TypeId &type_id, const std::string &device_target);
  static tensor::BaseTensorPtr CastTensor(const tensor::BaseTensorPtr &tensor, const TypeId &type_id,
                                          const std::string &device_target);
  static std::vector<tensor::BaseTensorPtr> CastTensor(const std::vector<tensor::BaseTensorPtr> &tensors,
                                                       const std::vector<TypeId> &type_id_list,
                                                       const std::string &device_target);
  // ValueTuple input
  static std::vector<tensor::BaseTensorPtr> CastTensor(const std::vector<tensor::BaseTensorPtr> &tensors,
                                                       TypeId type_id, const std::string &device_target);
  template <typename... T>
  static std::pair<bool, KernelAttr> SelectKernel(AbstractConverter *converter, const DeviceContext *device_context,
                                                  const std::string &op_name,
                                                  const ValueSimpleInfoPtr &output_value_simple_info,
                                                  const T &... args) {
    // Get the input abstracts
    std::vector<AbstractBasePtr> input_abs;
    ((void)input_abs.emplace_back(converter->ConvertAbstract(args)), ...);

    // Get the output abstract
    auto output_abs = TransformValueSimpleInfoToAbstract(*output_value_simple_info);
    return SelectKernel(input_abs, output_abs, device_context, op_name);
  }
  static ValueTuplePtr ConvertTensorVectorToTuple(const std::vector<BaseTensorPtr> &tensor_list) {
    std::vector<ValuePtr> value_vector;
    for (const auto &tensor : tensor_list) {
      (void)value_vector.emplace_back(tensor);
    }
    auto result = std::make_shared<ValueTuple>(value_vector);
    MS_LOG(DEBUG) << "Convert TensorList to ValueTuple " << result->ToString();
    return result;
  }
  static BaseTensorPtr ScalarToTensor(const ScalarPtr &scalar);

  static uint32_t cur_stream_id() { return cur_stream_id_; }

  // Set the current stream for CREATE_PYBOOST_OP in the frontend queue.
  static void set_cur_stream_id(uint32_t cur_stream_id) { cur_stream_id_ = cur_stream_id; }

 private:
  inline static uint32_t cur_stream_id_ = kDefaultStreamIndex;
};

template <typename T>
std::vector<T> ConvertValueTupleToVector(const ValueTuplePtr &tuple) {
  std::vector<T> result;
  const auto &values = tuple->value();
  for (const auto &value : values) {
    (void)result.emplace_back(GetValue<T>(value));
  }
  MS_LOG(DEBUG) << "Convert ValueTuple to vector " << result;
  return result;
}
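
// Usage sketch (illustrative): unpack a ValueTuple of int64_t elements, e.g. an axis list,
// into a plain vector:
//   auto axes = ConvertValueTupleToVector<int64_t>(axis_tuple);
// where axis_tuple is a hypothetical ValueTuplePtr holding Int64Imm values.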

// Shields kernel hardware differences: device-specific hooks (e.g. SetThreadPool) are dispatched to
// derived PyboostKernelExtraFunc implementations through this common factory.
class BACKEND_EXPORT PyboostKernelExtraFuncFactory {
 public:
  static PyboostKernelExtraFuncFactory &GetInstance();
  PyboostKernelExtraFuncFactory() = default;
  ~PyboostKernelExtraFuncFactory() = default;
  void AddPyboostKernelExtraFunc(const std::string &op_name, const PyboostKernelExtraFuncPtr &func) {
    kernel_func_map_[op_name] = func;
  }

  void SetThreadPool(const std::string &device_name, const kernel::KernelModPtr &kernel) {
    auto iter = kernel_func_map_.find(device_name);
    if (iter == kernel_func_map_.end()) {
      return;
    }
    iter->second->SetThreadPool(kernel);
  }

  bool IsKernelModRegistered(const std::string &device_name, const std::string &op_name) {
    auto iter = kernel_func_map_.find(device_name);
    if (iter == kernel_func_map_.end()) {
      return true;
    }
    return iter->second->IsKernelModRegistered(op_name);
  }

  bool IsPyBoostCustomRegistered(const std::string &device_name, const std::string &op_name) {
    auto iter = kernel_func_map_.find(device_name);
    if (iter == kernel_func_map_.end()) {
      return true;
    }
    return iter->second->IsPyBoostCustomRegistered(op_name);
  }

  bool IsEnableProfiler(const std::string &device_name) {
    auto iter = kernel_func_map_.find(device_name);
    if (iter == kernel_func_map_.end()) {
      return false;
    }
    return iter->second->IsEnableProfiler();
  }

  void LaunchKernelWithProfiler(const std::string &device_name, const device::DeviceContext *device_context,
                                const std::string &op_name, const std::vector<BaseShapePtr> &base_shape,
                                const std::function<void()> &func) {
    auto iter = kernel_func_map_.find(device_name);
    if (iter == kernel_func_map_.end()) {
      return;
    }
    iter->second->LaunchKernelWithProfiler(op_name, device_context, base_shape, func);
  }

 private:
  mindspore::HashMap<std::string, PyboostKernelExtraFuncPtr> kernel_func_map_;
};

class PyboostKernelExtraFuncRegistrar {
 public:
  PyboostKernelExtraFuncRegistrar(const std::string &op_name, const PyboostKernelExtraFuncPtr &func) {
    PyboostKernelExtraFuncFactory::GetInstance().AddPyboostKernelExtraFunc(op_name, func);
  }

  ~PyboostKernelExtraFuncRegistrar() = default;
};

#define REG_PYBOOST_KERNEL_EXTRA_FUN(op_name, func) \
  static PyboostKernelExtraFuncRegistrar g_##op_name##PyboostKernelExtraFunc(#op_name, std::make_shared<func>());
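
// Registration sketch (illustrative): a backend provides a PyboostKernelExtraFunc subclass and
// registers it with the macro above, e.g.
//   REG_PYBOOST_KERNEL_EXTRA_FUN(Ascend, PyboostKernelExtraFuncAscend)
// Since the factory lookups above are keyed by device name, the first macro argument is expected to
// be the device target; the class name here is a placeholder for a concrete subclass.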

}  // namespace pyboost
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_KERNEL_PYBOOST_PYBOOST_UTILS_H_