1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_PYBOOST_PYBOOST_UTILS_H_
18 #define MINDSPORE_MINDSPORE_CCSRC_KERNEL_PYBOOST_PYBOOST_UTILS_H_
19
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "include/common/utils/tensor_future.h"
#include "runtime/pynative/op_executor.h"
#include "mindspore/core/ops/view/view_strides_calculator.h"
#include "runtime/device/device_address_utils.h"
#include "kernel/pyboost/pyboost_kernel_extra_func.h"
#include "mindspore/core/utils/simple_info.h"
#include "include/common/pynative/abstract_converter.h"
31
32 namespace mindspore {
33 namespace kernel {
34 namespace pyboost {
35 using AbstractConverter = pynative::AbstractConverter;
36 using AddressInfoPair = std::pair<std::vector<kernel::KernelTensor *>, device::DeviceAddressPtrList>;
37 using BaseTensor = tensor::BaseTensor;
38 using BaseTensorPtr = tensor::BaseTensorPtr;
39 AbstractBasePtr BACKEND_EXPORT ToAbstractNoValue(const BaseTensorPtr &tensor);
40
41 class BACKEND_EXPORT PyBoostUtils {
42 public:
43 static AbstractBasePtr InferByOpDef(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_abs);
44 static void DispatchRun(const std::shared_ptr<runtime::PyBoostDeviceTask> &task);
45
46 static DeviceSyncPtr ContiguousByDeviceAddress(const DeviceSyncPtr &device_sync);
47
48 // Create device address
49 static device::DeviceAddressPtrList CreateWorkSpaceDeviceAddress(const KernelModPtr &kernel_mod,
50 const device::DeviceContext *device_context,
51 const std::string &op_name);
52
53 // Create output tensors
54 static void CreateOutputTensor(const AbstractBasePtr &abstract, std::vector<tensor::BaseTensorPtr> *outputs);
55 static void CreateOutputTensor(const DeviceContext *device_context, const tensor::BaseTensorPtr &input,
56 const TensorStorageInfoPtr &storage_info, std::vector<tensor::BaseTensorPtr> *outputs);
57 static void CreateOutputTensor(const DeviceContext *device_context, const tensor::BaseTensorPtr &input,
58 const TensorStorageInfoPtrList &storage_info_list,
59 std::vector<tensor::BaseTensorPtr> *outputs);
60 static void CreateOutputTensor(const ValueSimpleInfoPtr &output_value_simple_info,
61 std::vector<tensor::BaseTensorPtr> *outputs);
62
63 // Create input device address without kernel tensor
64 template <typename... Args>
PrepareOpInputs(const DeviceContext * device_context,size_t stream_id,const Args &...args)65 static void PrepareOpInputs(const DeviceContext *device_context, size_t stream_id, const Args &... args) {
66 size_t index = 0;
67 auto add_index = [&index]() { return index++; };
68 (runtime::DeviceAddressUtils::CreateInputTensorAddress(device_context, stream_id, add_index(), args), ...);
69 }
70
71 template <typename... T>
MallocOpInputs(const DeviceContext * device_context,const T &...args)72 static void MallocOpInputs(const DeviceContext *device_context, const T &... args) {
73 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostMallocInput,
74 runtime::ProfilerRecorder::kNoName, false);
75 (runtime::DeviceAddressUtils::MallocForInput(device_context, args, false), ...);
76 }
77
78 template <typename... T>
MallocOpInputsForView(const DeviceContext * device_context,const T &...args)79 static void MallocOpInputsForView(const DeviceContext *device_context, const T &... args) {
80 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostMallocInput,
81 runtime::ProfilerRecorder::kNoName, false);
82 (runtime::DeviceAddressUtils::MallocForInput(device_context, args, true), ...);
83 }
84
85 template <typename... T, std::size_t... Index>
GetAddressInfoHelper(const DeviceContext * device_context,size_t stream_id,const std::vector<AbstractBasePtr> & input_abs,std::vector<kernel::KernelTensor * > * kernel_tensor_list,device::DeviceAddressPtrList * device_address_list,std::index_sequence<Index...>,const T &...args)86 static void GetAddressInfoHelper(const DeviceContext *device_context, size_t stream_id,
87 const std::vector<AbstractBasePtr> &input_abs,
88 std::vector<kernel::KernelTensor *> *kernel_tensor_list,
89 device::DeviceAddressPtrList *device_address_list, std::index_sequence<Index...>,
90 const T &... args) {
91 (GetKernelTensor(device_context, stream_id, input_abs[Index], Index, kernel_tensor_list, device_address_list, args),
92 ...);
93 }
94
95 template <typename... T>
GetAddressInfo(const DeviceContext * device_context,size_t stream_id,const std::vector<AbstractBasePtr> & input_abs,const T &...args)96 static AddressInfoPair GetAddressInfo(const DeviceContext *device_context, size_t stream_id,
97 const std::vector<AbstractBasePtr> &input_abs, const T &... args) {
98 std::vector<kernel::KernelTensor *> kernel_tensor_list;
99 // Kernel tensor is a raw ppointer, device address need to be returned.
100 device::DeviceAddressPtrList device_address_list;
101 if (input_abs.empty()) {
102 std::vector<AbstractBasePtr> tmp_abs(sizeof...(args), nullptr);
103 GetAddressInfoHelper(device_context, stream_id, tmp_abs, &kernel_tensor_list, &device_address_list,
104 std::make_index_sequence<sizeof...(T)>(), args...);
105 } else {
106 GetAddressInfoHelper(device_context, stream_id, input_abs, &kernel_tensor_list, &device_address_list,
107 std::make_index_sequence<sizeof...(T)>(), args...);
108 }
109 return std::make_pair(kernel_tensor_list, device_address_list);
110 }
111
112 static void LaunchKernel(const PrimitivePtr &primitive, const device::DeviceContext *device_context,
113 const AddressInfoPair &input_address_info, const AddressInfoPair &output_address_info,
114 size_t stream_id = kDefaultStreamIndex);
115
GetKernelTensor(const DeviceContext * device_context,size_t stream_id,size_t index,std::vector<kernel::KernelTensor * > * kernel_tensor_list,device::DeviceAddressPtrList * device_address_list,const BaseTensorPtr & tensor)116 static void GetKernelTensor(const DeviceContext *device_context, size_t stream_id, size_t index,
117 std::vector<kernel::KernelTensor *> *kernel_tensor_list,
118 device::DeviceAddressPtrList *device_address_list, const BaseTensorPtr &tensor) {
119 GetKernelTensor(device_context, stream_id, nullptr, index, kernel_tensor_list, device_address_list, tensor);
120 }
121
122 static void GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
123 const abstract::AbstractBasePtr &input_abs, size_t index,
124 std::vector<kernel::KernelTensor *> *kernel_tensor_list,
125 device::DeviceAddressPtrList *device_address_list, const BaseTensorPtr &tensor);
126
127 template <typename T>
GetKernelTensor(const DeviceContext * device_context,size_t stream_id,const abstract::AbstractBasePtr & input_abs,size_t index,std::vector<kernel::KernelTensor * > * kernel_tensor_list,device::DeviceAddressPtrList * device_address_list,const std::optional<T> & val)128 static void GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
129 const abstract::AbstractBasePtr &input_abs, size_t index,
130 std::vector<kernel::KernelTensor *> *kernel_tensor_list,
131 device::DeviceAddressPtrList *device_address_list, const std::optional<T> &val) {
132 if (val.has_value()) {
133 GetKernelTensor(device_context, stream_id, input_abs, index, kernel_tensor_list, device_address_list,
134 val.value());
135 } else {
136 // Construct none kernel tensor
137 MS_EXCEPTION_IF_NULL(kernel_tensor_list);
138 MS_EXCEPTION_IF_NULL(device_address_list);
139
140 const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
141 std::make_shared<abstract::TensorShape>(ShapeVector()), kTypeNone, kNone, nullptr, 0, kOpFormat_DEFAULT,
142 kTypeNone->type_id(), ShapeVector(), device_context->device_context_key().device_name_,
143 device_context->device_context_key().device_id_);
144 kernel_tensor->set_stream_id(stream_id);
145 (void)kernel_tensor_list->emplace_back(kernel_tensor.get());
146 auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
147 (void)device_address_list->emplace_back(device_address);
148 }
149 }
150
151 static void GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
152 const abstract::AbstractBasePtr &input_abs, size_t index,
153 std::vector<kernel::KernelTensor *> *kernel_tensor_list,
154 device::DeviceAddressPtrList *device_address_list,
155 const std::vector<tensor::BaseTensorPtr> &tensors);
156
157 template <typename T>
GetKernelTensor(const DeviceContext * device_context,size_t stream_id,const abstract::AbstractBasePtr & input_abs,size_t index,std::vector<kernel::KernelTensor * > * kernel_tensor_list,device::DeviceAddressPtrList * device_address_list,const T & val)158 static void GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
159 const abstract::AbstractBasePtr &input_abs, size_t index,
160 std::vector<kernel::KernelTensor *> *kernel_tensor_list,
161 device::DeviceAddressPtrList *device_address_list, const T &val) {
162 // Value ptr alloc device address and malloc mem here
163 auto device_address =
164 runtime::DeviceAddressUtils::CreateInputAddress(device_context, stream_id, input_abs, index, val);
165 MS_EXCEPTION_IF_NULL(device_address);
166 MS_EXCEPTION_IF_NULL(device_address_list);
167 MS_EXCEPTION_IF_NULL(kernel_tensor_list);
168 (void)device_address_list->emplace_back(device_address);
169 (void)kernel_tensor_list->emplace_back(device_address->kernel_tensor().get());
170 }
171
172 // Create output tensor device address without kernel tensor
PrepareOpOutputs(const DeviceContext * device_context,size_t stream_id,const std::vector<tensor::BaseTensorPtr> & outputs)173 static void PrepareOpOutputs(const DeviceContext *device_context, size_t stream_id,
174 const std::vector<tensor::BaseTensorPtr> &outputs) {
175 runtime::DeviceAddressUtils::CreateOutputTensorAddress(device_context, stream_id, outputs);
176 }
177
178 // Create output tensor device address without kernel tensor
MallocOpOutputs(const DeviceContext * device_context,const std::vector<tensor::BaseTensorPtr> & outputs)179 static void MallocOpOutputs(const DeviceContext *device_context, const std::vector<tensor::BaseTensorPtr> &outputs) {
180 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostMallocOutput,
181 runtime::ProfilerRecorder::kNoName, false);
182 runtime::DeviceAddressUtils::MallocForOutputs(device_context, outputs);
183 }
184
185 // Create workspace device address with kernel tensor
186 static std::vector<kernel::KernelTensor *> GetKernelTensorFromAddress(
187 const device::DeviceAddressPtrList &input_device_address);
188
189 // Check kernel mod is reg
190 static bool IsKernelModRegistered(const std::string &device_name, const std::string &op_name);
191 static bool IsPyBoostCustomRegistered(const std::string &device_name, const std::string &op_name);
192
193 static kernel::KernelModPtr CreateKernelMod(const PrimitivePtr &prim, const std::string &op_name,
194 const DeviceContext *device_context,
195 const std::vector<KernelTensor *> &inputs,
196 const std::vector<KernelTensor *> &outputs);
197 // return IsStrictlyMatched and KernelAttr
198 static std::pair<bool, KernelAttr> SelectKernel(const std::vector<AbstractBasePtr> &inputs_abs,
199 const AbstractBasePtr &outputs_abs,
200 const DeviceContext *device_context, const std::string &op_name);
201 static std::optional<tensor::BaseTensorPtr> CastTensor(const std::optional<tensor::BaseTensorPtr> &tensor,
202 const TypeId &type_id, const std::string &device_target);
203 static tensor::BaseTensorPtr CastTensor(const tensor::BaseTensorPtr &tensor, const TypeId &type_id,
204 const std::string &device_target);
205 static std::vector<tensor::BaseTensorPtr> CastTensor(const std::vector<tensor::BaseTensorPtr> &tensors,
206 const std::vector<TypeId> &type_id_list,
207 const std::string &device_target);
208 // ValueTuple input
209 static std::vector<tensor::BaseTensorPtr> CastTensor(const std::vector<tensor::BaseTensorPtr> &tensors,
210 TypeId type_id, const std::string &device_target);
211 template <typename... T>
SelectKernel(AbstractConverter * converter,const DeviceContext * device_context,const std::string & op_name,const ValueSimpleInfoPtr & output_value_simple_info,const T &...args)212 static std::pair<bool, KernelAttr> SelectKernel(AbstractConverter *converter, const DeviceContext *device_context,
213 const std::string &op_name,
214 const ValueSimpleInfoPtr &output_value_simple_info,
215 const T &... args) {
216 // Get inputs abstract
217 std::vector<AbstractBasePtr> input_abs;
218 ((void)input_abs.emplace_back(converter->ConvertAbstract(args)), ...);
219
220 // Get output abstract
221 auto output_abs = TransformValueSimpleInfoToAbstract(*output_value_simple_info);
222 return SelectKernel(input_abs, output_abs, device_context, op_name);
223 }
ConvertTensorVectorToTuple(const std::vector<BaseTensorPtr> & tensor_list)224 static ValueTuplePtr ConvertTensorVectorToTuple(const std::vector<BaseTensorPtr> &tensor_list) {
225 vector<ValuePtr> value_vector;
226 for (const auto &tensor : tensor_list) {
227 (void)value_vector.emplace_back(tensor);
228 }
229 auto result = std::make_shared<ValueTuple>(value_vector);
230 MS_LOG(DEBUG) << "Convert TensorList to ValueTuple " << result->ToString();
231 return result;
232 }
233 static BaseTensorPtr ScalarToTensor(const ScalarPtr &scalar);
234
cur_stream_id()235 static uint32_t cur_stream_id() { return cur_stream_id_; }
236
237 // Set current stream for CREATE_PYBOOST_OP in front queue.
set_cur_stream_id(uint32_t cur_stream_id)238 static void set_cur_stream_id(uint32_t cur_stream_id) { cur_stream_id_ = cur_stream_id; }
239
240 private:
241 inline static uint32_t cur_stream_id_ = kDefaultStreamIndex;
242 };
243
244 template <typename T>
ConvertValueTupleToVector(const ValueTuplePtr & tuple)245 std::vector<T> ConvertValueTupleToVector(const ValueTuplePtr &tuple) {
246 std::vector<T> result;
247 const auto &values = tuple->value();
248 for (const auto &value : values) {
249 (void)result.emplace_back(GetValue<T>(value));
250 }
251 MS_LOG(DEBUG) << "Convert ValueTuple to vector " << result;
252 return result;
253 }
254
255 // Shield kernel hardware differences. Call some func of derived classes based on base classes.
256 // Just like SetThreadPool
257 class BACKEND_EXPORT PyboostKernelExtraFuncFactory {
258 public:
259 static PyboostKernelExtraFuncFactory &GetInstance();
260 PyboostKernelExtraFuncFactory() = default;
261 ~PyboostKernelExtraFuncFactory() = default;
AddPyboostKernelExtraFunc(const std::string & op_name,const PyboostKernelExtraFuncPtr & func)262 void AddPyboostKernelExtraFunc(const std::string &op_name, const PyboostKernelExtraFuncPtr &func) {
263 kernel_func_map_[op_name] = func;
264 }
265
SetThreadPool(const std::string & device_name,const kernel::KernelModPtr & kernel)266 void SetThreadPool(const std::string &device_name, const kernel::KernelModPtr &kernel) {
267 auto iter = kernel_func_map_.find(device_name);
268 if (iter == kernel_func_map_.end()) {
269 return;
270 }
271 iter->second->SetThreadPool(kernel);
272 }
273
IsKernelModRegistered(const std::string & device_name,const std::string & op_name)274 bool IsKernelModRegistered(const std::string &device_name, const std::string &op_name) {
275 auto iter = kernel_func_map_.find(device_name);
276 if (iter == kernel_func_map_.end()) {
277 return true;
278 }
279 return iter->second->IsKernelModRegistered(op_name);
280 }
281
IsPyBoostCustomRegistered(const std::string & device_name,const std::string & op_name)282 bool IsPyBoostCustomRegistered(const std::string &device_name, const std::string &op_name) {
283 auto iter = kernel_func_map_.find(device_name);
284 if (iter == kernel_func_map_.end()) {
285 return true;
286 }
287 return iter->second->IsPyBoostCustomRegistered(op_name);
288 }
289
IsEnableProfiler(const std::string & device_name)290 bool IsEnableProfiler(const std::string &device_name) {
291 auto iter = kernel_func_map_.find(device_name);
292 if (iter == kernel_func_map_.end()) {
293 return false;
294 }
295 return iter->second->IsEnableProfiler();
296 }
297
LaunchKernelWithProfiler(const std::string & device_name,const device::DeviceContext * device_context,const std::string & op_name,const std::vector<BaseShapePtr> & base_shape,const std::function<void ()> & func)298 void LaunchKernelWithProfiler(const std::string &device_name, const device::DeviceContext *device_context,
299 const std::string &op_name, const std::vector<BaseShapePtr> &base_shape,
300 const std::function<void()> &func) {
301 auto iter = kernel_func_map_.find(device_name);
302 if (iter == kernel_func_map_.end()) {
303 return;
304 }
305 iter->second->LaunchKernelWithProfiler(op_name, device_context, base_shape, func);
306 }
307
308 private:
309 mindspore::HashMap<std::string, PyboostKernelExtraFuncPtr> kernel_func_map_;
310 };
311
312 class PyboostKernelExtraFuncRegistrar {
313 public:
PyboostKernelExtraFuncRegistrar(const std::string & op_name,const PyboostKernelExtraFuncPtr & func)314 PyboostKernelExtraFuncRegistrar(const std::string &op_name, const PyboostKernelExtraFuncPtr &func) {
315 PyboostKernelExtraFuncFactory::GetInstance().AddPyboostKernelExtraFunc(op_name, func);
316 }
317
318 ~PyboostKernelExtraFuncRegistrar() = default;
319 };
320
321 #define REG_PYBOOST_KERNEL_EXTRA_FUN(op_name, func) \
322 static PyboostKernelExtraFuncRegistrar g_##op_name##PyboostKernelExtraFunc(#op_name, std::make_shared<func>());
323
324 } // namespace pyboost
325 } // namespace kernel
326 } // namespace mindspore
327 #endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_PYBOOST_PYBOOST_UTILS_H_
328