/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/pyboost/pyboost_utils.h"
#include <algorithm>
#include <utility>
#include <unordered_map>
#include "kernel/common_utils.h"
#include "kernel/kernel_mod_cache.h"
#include "mindapi/base/type_id.h"
#include "runtime/device/device_address_utils.h"
#include "ops/ops_frontend_func_impl.h"
#include "ops/op_def.h"
#include "runtime/pynative/op_executor.h"
#include "pybind_api/gil_scoped_long_running.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/cpu_kernel.h"
#include "kernel/pyboost/auto_generate/cast.h"
#include "mindspore/core/ops/array_ops.h"

namespace mindspore {
namespace kernel {
namespace pyboost {
namespace {
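// Local helpers that wrap BaseTensor construction. Both variants mark the new
// tensor with set_need_pipeline_sync(true) so a later host-side read waits for
// the asynchronous device pipeline to flush before touching the data.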
void CreateTensor(const TypePtr &type, const ShapeVector &shape_vector, const AbstractBasePtr &abstract_tensor,
                  std::vector<tensor::BaseTensorPtr> *outputs) {
  auto output_tensor = std::make_shared<tensor::BaseTensor>(type->type_id(), shape_vector);
  output_tensor->set_abstract(abstract_tensor);
  output_tensor->set_need_pipeline_sync(true);
  (void)outputs->emplace_back(output_tensor);
  MS_LOG(DEBUG) << "Create output tensor " << output_tensor->ToString();
}

void CreateTensor(const TypePtr &type, const ShapeVector &shape_vector, std::vector<tensor::BaseTensorPtr> *outputs) {
  auto output_tensor = std::make_shared<tensor::BaseTensor>(type->type_id(), shape_vector);
  output_tensor->set_need_pipeline_sync(true);
  (void)outputs->emplace_back(output_tensor);
  MS_LOG(DEBUG) << "Create output tensor " << output_tensor->ToString();
}
} // namespace

AbstractBasePtr ToAbstractNoValue(const tensor::BaseTensorPtr &tensor) {
  auto abs = tensor->GetAbstractCache();
  abs->set_value(kValueAny);
  return abs;
}

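// Recursively materializes output tensors from an inferred abstract: sequence
// abstracts are flattened element by element, tensor abstracts use their
// inferred shape/dtype, and scalar abstracts become 0-D tensors. Illustrative
// sketch (op_abstract is a hypothetical inferred abstract, e.g. from
// InferByOpDef below):
//   std::vector<tensor::BaseTensorPtr> outputs;
//   PyBoostUtils::CreateOutputTensor(op_abstract, &outputs);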
void PyBoostUtils::CreateOutputTensor(const AbstractBasePtr &abstract, std::vector<tensor::BaseTensorPtr> *outputs) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
                                     runtime::ProfilerEvent::kPyBoostCreateOutputTensor,
                                     runtime::ProfilerRecorder::kNoName, false);
  MS_EXCEPTION_IF_NULL(abstract);
  if (abstract->isa<abstract::AbstractSequence>()) {
    const auto &seq = abstract->cast<abstract::AbstractSequencePtr>();
    const auto &elements = seq->elements();
    for (const auto &element : elements) {
      CreateOutputTensor(element, outputs);
    }
  } else if (abstract->isa<abstract::AbstractTensor>()) {
    const auto &abstract_tensor = abstract->cast<abstract::AbstractTensorPtr>();
    const auto &shape = abstract_tensor->GetShapeTrack();
    const auto &type = abstract_tensor->element()->GetTypeTrack();
    MS_LOG(DEBUG) << "Get abstract tensor shape " << shape->ToString() << " type " << type->ToString();
    if (!shape->isa<abstract::Shape>()) {
      MS_LOG(EXCEPTION) << "AbstractTensor shape is not valid " << shape->ToString();
    }
    const auto &shape_vector = shape->cast<abstract::ShapePtr>()->shape();
    CreateTensor(type, shape_vector, abstract_tensor, outputs);
  } else if (abstract->isa<abstract::AbstractScalar>()) {
    const auto &scalar = abstract->cast<abstract::AbstractScalarPtr>();
    const auto &type = scalar->GetTypeTrack();
    MS_LOG(DEBUG) << "Create scalar tensor type " << type->ToString();
    CreateTensor(type, {}, nullptr, outputs);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported abstract " << abstract->ToString();
  }
}

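// Converts a Scalar into the corresponding 0-D tensor. Narrow integer types
// are widened to int64_t/uint64_t for the BaseTensor constructor, while the
// original dtype is preserved through `data_type`. Illustrative sketch:
//   auto t = PyBoostUtils::ScalarToTensor(std::make_shared<Int64Imm>(42));
//   // t is a 0-D int64 tensor holding 42.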
tensor::BaseTensorPtr PyBoostUtils::ScalarToTensor(const ScalarPtr &scalar) {
  if (scalar == nullptr) {
    MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
  }
  TypePtr data_type = scalar->type();
  MS_EXCEPTION_IF_NULL(data_type);
  TypeId type_id = data_type->type_id();
  switch (type_id) {
    case kNumberTypeBool:
      return std::make_shared<tensor::BaseTensor>(GetValue<bool>(scalar), data_type);
    case kNumberTypeInt8:
      return std::make_shared<tensor::BaseTensor>(static_cast<int64_t>(GetValue<int8_t>(scalar)), data_type);
    case kNumberTypeInt16:
      return std::make_shared<tensor::BaseTensor>(static_cast<int64_t>(GetValue<int16_t>(scalar)), data_type);
    case kNumberTypeInt32:
      return std::make_shared<tensor::BaseTensor>(static_cast<int64_t>(GetValue<int32_t>(scalar)), data_type);
    case kNumberTypeInt64:
      return std::make_shared<tensor::BaseTensor>(GetValue<int64_t>(scalar), data_type);
    case kNumberTypeUInt8:
      return std::make_shared<tensor::BaseTensor>(static_cast<uint64_t>(GetValue<uint8_t>(scalar)), data_type);
    case kNumberTypeUInt16:
      return std::make_shared<tensor::BaseTensor>(static_cast<uint64_t>(GetValue<uint16_t>(scalar)), data_type);
    case kNumberTypeUInt32:
      return std::make_shared<tensor::BaseTensor>(static_cast<uint64_t>(GetValue<uint32_t>(scalar)), data_type);
    case kNumberTypeUInt64:
      return std::make_shared<tensor::BaseTensor>(GetValue<uint64_t>(scalar), data_type);
    case kNumberTypeFloat32:
      return std::make_shared<tensor::BaseTensor>(GetValue<float>(scalar), data_type);
    case kNumberTypeFloat64:
      return std::make_shared<tensor::BaseTensor>(GetValue<double>(scalar), data_type);
    default:
      MS_LOG(EXCEPTION) << "When converting a scalar to a tensor, the scalar type: " << data_type << " is invalid.";
  }
}

void PyBoostUtils::CreateOutputTensor(const ValueSimpleInfoPtr &output_value_simple_info,
                                      std::vector<tensor::BaseTensorPtr> *outputs) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
                                     runtime::ProfilerEvent::kPyBoostCreateOutputTensor,
                                     runtime::ProfilerRecorder::kNoName, false);
  MS_EXCEPTION_IF_NULL(output_value_simple_info);
  size_t elem_size = output_value_simple_info->dtype_vector_.size();
  for (size_t i = 0; i < elem_size; ++i) {
    MS_LOG(DEBUG) << "Get tensor shape " << output_value_simple_info->shape_vector_[i] << ", type "
                  << TypeIdToType(output_value_simple_info->dtype_vector_[i]->type_id())->ToString();
    CreateTensor(output_value_simple_info->dtype_vector_[i], output_value_simple_info->shape_vector_[i], outputs);
  }
}

bool PyBoostUtils::IsKernelModRegistered(const std::string &device_name, const std::string &op_name) {
  return PyboostKernelExtraFuncFactory::GetInstance().IsKernelModRegistered(device_name, op_name);
}

bool PyBoostUtils::IsPyBoostCustomRegistered(const std::string &device_name, const std::string &op_name) {
  return PyboostKernelExtraFuncFactory::GetInstance().IsPyBoostCustomRegistered(device_name, op_name);
}

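// Fetches a KernelMod from the process-wide KernelModCache, keyed by op name,
// device and input kernel tensors; on a cache miss it creates the KernelMod
// via the device's kernel executor, runs Init, caches the result and binds the
// device thread pool. Repeated launches of the same op thus skip construction.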
kernel::KernelModPtr PyBoostUtils::CreateKernelMod(const PrimitivePtr &prim, const std::string &op_name,
                                                   const DeviceContext *device_context,
                                                   const std::vector<KernelTensor *> &inputs,
                                                   const std::vector<KernelTensor *> &outputs) {
  MS_EXCEPTION_IF_NULL(device_context);
  const auto &device_name = device_context->device_context_key().device_name_;

  auto &cache_helper = kernel::KernelModCache::GetInstance();
  const auto &key = cache_helper.GetKernelModKey(op_name, device_name, inputs);
  auto kernel_mod = cache_helper.GetKernelMod(key);
  if (kernel_mod == nullptr) {
    kernel_mod = device_context->GetKernelExecutor(false)->CreateKernelMod(op_name);
    if (kernel_mod == nullptr) {
      MS_LOG(EXCEPTION) << "Create KernelMod for op " << op_name << " failed";
    }
    if (!kernel_mod->Init(prim, inputs, outputs)) {
      MS_LOG(EXCEPTION) << "KernelMod Init Failed: " << op_name;
    }
    cache_helper.SetCache(key, kernel_mod);
    PyboostKernelExtraFuncFactory::GetInstance().SetThreadPool(device_name, kernel_mod);
  }

  return kernel_mod;
}

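// Turns a non-contiguous view address (one carrying TensorStorageInfo) into a
// freshly allocated contiguous address by dispatching a kCONTIGUOUS_TASK to
// the device, then waits for the op queue to drain. Addresses without storage
// info are already contiguous and are returned unchanged.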
DeviceSyncPtr PyBoostUtils::ContiguousByDeviceAddress(const DeviceSyncPtr &device_sync) {
  auto &storage_info = device_sync->GetTensorStorageInfo();
  if (storage_info == nullptr) {
    return device_sync;
  }

  auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);
  MS_EXCEPTION_IF_NULL(old_device_address);
  MS_EXCEPTION_IF_NULL(storage_info);
  GilReleaseWithCheck gil_release;

  const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {old_device_address->device_name(), old_device_address->device_id()});
  MS_EXCEPTION_IF_NULL(device_context);

  auto stream_id = device_context->device_res_manager_->GetCurrentStreamId();
  auto address_size = GetTypeByte(TypeIdToType(old_device_address->type_id())) * SizeOf(storage_info->shape);
  auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, address_size, storage_info->shape, DEFAULT_FORMAT, old_device_address->type_id(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, stream_id);
  new_device_address->set_device_shape(storage_info->shape);
  new_device_address->set_original_ref_count(SIZE_MAX);
  new_device_address->ResetRefCount();

  if (!device_context->GetKernelExecutor(false)->ExecuteKernelTask(
        runtime::KernelTaskType::kCONTIGUOUS_TASK, {old_device_address}, {new_device_address}, stream_id)) {
    MS_LOG(EXCEPTION) << "ExecuteKernelTask failed, task_type:" << runtime::KernelTaskType::kCONTIGUOUS_TASK;
  }
  runtime::OpExecutor::GetInstance().WaitAll();
  return new_device_address;
}

void PyBoostUtils::CreateOutputTensor(const DeviceContext *device_context, const tensor::BaseTensorPtr &input,
                                      const TensorStorageInfoPtrList &storage_info_list,
                                      std::vector<tensor::BaseTensorPtr> *outputs) {
  for (auto &storage_info : storage_info_list) {
    CreateOutputTensor(device_context, input, storage_info, outputs);
  }
}

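// View-output construction: the output tensor receives its own DeviceAddress
// that shares the input's pointer ref count (no data copy) but carries the
// view's TensorStorageInfo. ContiguousByDeviceAddress is installed as the
// contiguous callback for consumers that later need dense memory.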
void PyBoostUtils::CreateOutputTensor(const DeviceContext *device_context, const tensor::BaseTensorPtr &input,
                                      const TensorStorageInfoPtr &storage_info,
                                      std::vector<tensor::BaseTensorPtr> *outputs) {
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(storage_info);
  MS_EXCEPTION_IF_NULL(device_context);

  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
                                     runtime::ProfilerEvent::kPyBoostCreateOutputTensor,
                                     runtime::ProfilerRecorder::kNoName, false);
  auto output_tensor = std::make_shared<tensor::BaseTensor>(input->data_type(), storage_info->shape);
  output_tensor->set_need_pipeline_sync(true);
  output_tensor->set_contiguous_callback(
    [](const DeviceSyncPtr &device_address) -> DeviceSyncPtr { return ContiguousByDeviceAddress(device_address); });

  auto input_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input->device_address());
  MS_EXCEPTION_IF_NULL(input_device_address);
  input_device_address->set_is_view(true);

  // Create view output address
  auto output_device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, input_device_address->GetSize(), output_tensor->shape(), DEFAULT_FORMAT, output_tensor->data_type(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_,
    input_device_address->stream_id());
  MS_EXCEPTION_IF_NULL(output_device_address);
  output_device_address->set_tensor_storage_info(storage_info);
  output_device_address->set_pointer_ref_count(input_device_address->pointer_ref_count());
  output_tensor->set_device_address(output_device_address);
  (void)outputs->emplace_back(output_tensor);
  MS_LOG(DEBUG) << "Create output tensor " << output_tensor->ToString() << " with " << storage_info->ToString();
}

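// Infers the output abstract of a primitive, preferring the OpDef func_impl_
// path and falling back to the legacy C++ PrimitiveInferMap. Illustrative
// sketch (x_abs/y_abs are hypothetical input abstracts):
//   auto out_abs = PyBoostUtils::InferByOpDef(prim::kPrimAdd, {x_abs, y_abs});
//   PyBoostUtils::CreateOutputTensor(out_abs, &outputs);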
AbstractBasePtr PyBoostUtils::InferByOpDef(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_abs) {
  MS_EXCEPTION_IF_NULL(prim);
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostInferByOpDef,
                                     prim->name(), false);
  auto op_def = mindspore::ops::GetOpDef(prim->name());
  if (op_def) {
    (void)op_def->func_impl_.CheckValidation(prim, input_abs);
    auto shape = op_def->func_impl_.InferShape(prim, input_abs);
    auto type = op_def->func_impl_.InferType(prim, input_abs);
    auto output_abs = mindspore::abstract::MakeAbstract(shape, type);
    MS_LOG(DEBUG) << "Pynative Infer " << prim->name() << " by OpDef, got abstract: " << output_abs->ToString();
    return output_abs;
  } else {
    const auto &infer_map = abstract::GetPrimitiveInferMapPtr();
    const auto &iter = infer_map->find(prim);
    if (iter != infer_map->end()) {
      auto output_abs = iter->second.InferShapeAndType(nullptr, prim, input_abs);
      MS_LOG(DEBUG) << "Pynative Infer " << prim->name()
                    << " by C++ PrimitiveInferMap, got abstract: " << output_abs->ToString();
      return output_abs;
    } else {
      MS_LOG(EXCEPTION) << "Cannot find infer function for Op " << prim->name();
    }
  }
}

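// Runs a device task either inline (when the executor demands sync and
// graph-async is off: wait for in-flight tasks, then Run) or by pushing it
// onto the async op executor queue. Note that `need_sync` is captured once in
// a function-local static, so the dispatch mode is fixed per process.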
void PyBoostUtils::DispatchRun(const std::shared_ptr<runtime::PyBoostDeviceTask> &task) {
  static auto need_sync = runtime::OpExecutor::NeedSync();
  if (need_sync && !runtime::OpExecutor::GetInstance().async_for_graph()) {
    MS_LOG(INFO) << "PyBoost sync run device task";
    runtime::OpExecutor::GetInstance().WaitAll();
    task->Run();
  } else {
    runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(task->task_id());
    runtime::OpExecutor::GetInstance().PushOpRunTask(task);
  }
}

std::vector<kernel::KernelTensor *> PyBoostUtils::GetKernelTensorFromAddress(
  const device::DeviceAddressPtrList &input_device_address) {
  std::vector<kernel::KernelTensor *> input_kernel_tensors;
  std::transform(input_device_address.begin(), input_device_address.end(), std::back_inserter(input_kernel_tensors),
                 [](const auto &item) { return item->kernel_tensor().get(); });
  return input_kernel_tensors;
}

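// The GetKernelTensor overloads below collect, for each input tensor, both its
// DeviceAddress (for lifetime/memory management) and the raw KernelTensor
// pointer (for kernel launch); the tensor must already hold a device address.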
void PyBoostUtils::GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
                                   const abstract::AbstractBasePtr &input_abs, size_t index,
                                   std::vector<kernel::KernelTensor *> *kernel_tensor_list,
                                   device::DeviceAddressPtrList *device_address_list, const BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(kernel_tensor_list);
  MS_EXCEPTION_IF_NULL(device_address_list);

  const auto &device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_EXCEPTION_IF_NULL(device_address);
  (void)device_address_list->emplace_back(device_address);
  const auto &kernel_tensor = device_address->kernel_tensor();
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  (void)kernel_tensor_list->emplace_back(kernel_tensor.get());
}

void PyBoostUtils::GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
                                   const abstract::AbstractBasePtr &input_abs, size_t index,
                                   std::vector<kernel::KernelTensor *> *kernel_tensor_list,
                                   device::DeviceAddressPtrList *device_address_list,
                                   const std::vector<tensor::BaseTensorPtr> &tensors) {
  for (const auto &tensor : tensors) {
    // input_abs is not used in GetKernelTensor when value is TensorPtr.
    GetKernelTensor(device_context, stream_id, input_abs, index, kernel_tensor_list, device_address_list, tensor);
  }
}

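// Creates one device address per workspace size reported by the KernelMod
// (valid after Resize) and eagerly backs each with device memory. A minimal
// sketch mirroring its use in LaunchKernel below:
//   auto ws = PyBoostUtils::CreateWorkSpaceDeviceAddress(kernel_mod, device_context, "Add");
//   auto ws_tensors = PyBoostUtils::GetKernelTensorFromAddress(ws);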
device::DeviceAddressPtrList PyBoostUtils::CreateWorkSpaceDeviceAddress(const KernelModPtr &kernel_mod,
                                                                        const device::DeviceContext *device_context,
                                                                        const std::string &op_name) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  MS_EXCEPTION_IF_NULL(kernel_mod);

  const auto &workspace_sizes = kernel_mod->GetWorkspaceSizeList();
  device::DeviceAddressPtrList workspaces_address;
  for (const auto workspace_size : workspace_sizes) {
    auto kernel_tensor = std::make_shared<KernelTensor>(
      nullptr, workspace_size, Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    MS_LOG(DEBUG) << "Create workspace for op: " << op_name << " addr: " << device_address;
    MS_EXCEPTION_IF_NULL(device_address);
    (void)workspaces_address.emplace_back(device_address);
  }

  for (size_t i = 0; i < workspace_sizes.size(); ++i) {
    auto device_address = workspaces_address[i];
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetPtr() == nullptr &&
        !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate workspace memory failed";
    }
    MS_LOG(DEBUG) << "workspace[" << i << "]:" << device_address->kernel_tensor()->device_ptr()
                  << " size:" << device_address->kernel_tensor()->size();
  }
  return workspaces_address;
}

PyboostKernelExtraFuncFactory &PyboostKernelExtraFuncFactory::GetInstance() {
  static PyboostKernelExtraFuncFactory instance;
  return instance;
}

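// Single-op launch pipeline: create/Init the KernelMod, Resize it for the
// current input shapes, allocate workspaces, Launch (optionally wrapped by the
// per-device profiler callback), update dynamic output shapes if the kernel
// requires it, and finally record cross-stream addresses so memory reuse stays
// safe across streams.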
void PyBoostUtils::LaunchKernel(const PrimitivePtr &primitive, const DeviceContext *device_context,
                                const AddressInfoPair &input_address_info, const AddressInfoPair &output_address_info,
                                size_t stream_id) {
  const auto &real_name = primitive->name();
  // KernelMod init
  auto kernel_mod = PyBoostUtils::CreateKernelMod(primitive, real_name, device_context, input_address_info.first,
                                                  output_address_info.first);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  // KernelMod resize
  if (kernel_mod->Resize(input_address_info.first, output_address_info.first) == kernel::KRET_RESIZE_FAILED) {
    MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Kernel build failed:#dmsg#CPU kernel op [" << real_name << "] resize failed.";
  }
  // Get workspace address
  const auto &workspace_device_address =
    PyBoostUtils::CreateWorkSpaceDeviceAddress(kernel_mod, device_context, primitive->name());
  const auto &workspace_kernel_tensors = PyBoostUtils::GetKernelTensorFromAddress(workspace_device_address);

  const auto &device_name = device_context->device_context_key().device_name_;
  void *stream_ptr = device_context->device_res_manager_->GetStream(stream_id);
  if (!PyboostKernelExtraFuncFactory::GetInstance().IsEnableProfiler(device_name)) {
    if (!kernel_mod->Launch(input_address_info.first, workspace_kernel_tensors, output_address_info.first,
                            stream_ptr)) {
      MS_LOG(EXCEPTION) << "Launch kernel failed, name: " << real_name;
    }
  } else {
    const auto &input_kts = input_address_info.first;
    std::vector<BaseShapePtr> input_shapes;
    for (auto kt : input_kts) {
      MS_EXCEPTION_IF_NULL(kt);
      input_shapes.push_back(kt->GetShape());
    }
    // Pass the collected input shapes to the profiler wrapper.
    PyboostKernelExtraFuncFactory::GetInstance().LaunchKernelWithProfiler(
      device_name, device_context, real_name, input_shapes, [&]() {
        if (!kernel_mod->Launch(input_address_info.first, workspace_kernel_tensors, output_address_info.first,
                                stream_ptr)) {
          MS_LOG(EXCEPTION) << "Launch kernel failed, name: " << real_name;
        }
      });
  }
  if (kernel_mod->IsNeedUpdateOutputShapeAndSize()) {
    kernel_mod->UpdateOutputShapeAndSize(input_address_info.first, output_address_info.first);
  }
  runtime::DeviceAddressUtils::ProcessCrossStreamAddress(real_name, device_context, stream_id,
                                                         input_address_info.first, output_address_info.first);
  MS_LOG(DEBUG) << real_name << " Launch end";
}

namespace {
TypeId GetTypeIdFromAbstractTensor(const AbstractBasePtr &abs_base) {
  if (abs_base->isa<abstract::AbstractTensor>()) {
    auto abs_tensor = std::dynamic_pointer_cast<abstract::AbstractTensor>(abs_base);
    return abs_tensor->element()->BuildType()->type_id();
  }
  return abs_base->BuildType()->type_id();
}

TypeId GetAbstractObjectType(const AbstractBasePtr &abstract) {
  if (abstract == nullptr) {
    return kTypeUnknown;
  }
  if (abstract->isa<abstract::AbstractTensor>()) {
    return kObjectTypeTensorType;
  }
  if (abstract->isa<abstract::AbstractTuple>()) {
    return kObjectTypeTuple;
  }
  if (abstract->isa<abstract::AbstractList>()) {
    return kObjectTypeList;
  }
  if (abstract->isa<abstract::AbstractScalar>()) {
    // A scalar input may not have been converted to a tensor.
    return kObjectTypeNumber;
  }
  if (abstract->isa<abstract::AbstractNone>()) {
    return kMetaTypeNone;
  }

  return kTypeUnknown;
}

std::pair<std::vector<TypeId>, std::vector<TypeId>> GetOutputTypeFromAbstractBase(const AbstractBasePtr &abs_base) {
  std::vector<TypeId> output_dtype;
  std::vector<TypeId> output_type;
  if (abs_base->isa<abstract::AbstractTuple>()) {
    auto abs_tuple = std::dynamic_pointer_cast<abstract::AbstractTuple>(abs_base);
    for (auto &abs : abs_tuple->elements()) {
      (void)output_dtype.emplace_back(GetTypeIdFromAbstractTensor(abs));
      (void)output_type.emplace_back(GetAbstractObjectType(abs));
    }
  } else {
    (void)output_type.emplace_back(GetAbstractObjectType(abs_base));
    (void)output_dtype.emplace_back(GetTypeIdFromAbstractTensor(abs_base));
  }
  return std::make_pair(output_type, output_dtype);
}

std::pair<std::vector<TypeId>, std::vector<TypeId>> GetInputTypeFromAbstractBase(
  const std::vector<AbstractBasePtr> &abs_vec) {
  std::vector<TypeId> input_dtype;
  std::vector<TypeId> input_type;
  for (auto &abs : abs_vec) {
    if (abs->isa<abstract::AbstractTuple>()) {
      // Tensors in a tuple are assumed to share the same type.
      auto abs_tuple = std::dynamic_pointer_cast<abstract::AbstractTuple>(abs);
      if (abs_tuple->elements().empty()) {
        input_dtype.emplace_back(kTypeUnknown);
        // Still record the object type so input_type and input_dtype stay the same length.
        input_type.emplace_back(GetAbstractObjectType(abs));
        continue;
      }
      input_dtype.emplace_back(abs_tuple->elements()[0]->BuildType()->type_id());
    } else {
      input_dtype.emplace_back(GetTypeIdFromAbstractTensor(abs));
    }
    input_type.emplace_back(GetAbstractObjectType(abs));
  }
  return std::make_pair(input_type, input_dtype);
}

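// Weak dtype matching used by the CPU fallback pass: an exact match or an
// unknown actual dtype passes, and kernels registered for int32/float32 also
// accept the neighboring integer/floating widths, e.g.
//   InputDtypeMatch(kNumberTypeFloat32, kNumberTypeFloat16) -> true
//   InputDtypeMatch(kNumberTypeInt32, kNumberTypeInt8)      -> false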
bool InputDtypeMatch(TypeId input_attr, TypeId input_type) {
  if (input_attr == input_type || kTypeUnknown == input_type) {
    return true;
  }
  if (input_attr == kNumberTypeInt32 && (input_type == kNumberTypeInt16 || input_type == kNumberTypeInt64)) {
    return true;
  }
  if (input_attr == kNumberTypeFloat32 && (input_type == kNumberTypeFloat16 || input_type == kNumberTypeFloat64)) {
    return true;
  }
  return false;
}

bool IsObjectDtypeWeaklyMatched(const std::vector<TypeId> &object_dtypes,
                                const std::vector<DataType> &kernel_data_types) {
  // Only supported on CPU.
  // Guard against out-of-range access when a kernel attr declares a different arity.
  if (object_dtypes.size() != kernel_data_types.size()) {
    return false;
  }
  for (size_t i = 0; i < object_dtypes.size(); i++) {
    // For an optional input, the real input object type can be None.
    if (!InputDtypeMatch(kernel_data_types[i].dtype, object_dtypes[i])) {
      return false;
    }
  }
  return true;
}

bool IsObjectStrictlyMatched(const std::vector<TypeId> &object_types, const std::vector<TypeId> &object_dtypes,
                             const std::vector<DataType> &kernel_data_types) {
  if (object_dtypes.size() != kernel_data_types.size()) {
    return false;
  }

  for (size_t i = 0; i < object_dtypes.size(); i++) {
    auto is_tuple = (kernel_data_types[i].object_type == kObjectTypeTuple);
    // For an optional input, the real input object type can be None. A tuple dtype of
    // kTypeUnknown means an empty tuple.
    if (object_dtypes[i] != kernel_data_types[i].dtype) {
      if ((!is_tuple || object_dtypes[i] != kTypeUnknown) &&
          !(object_types[i] == kMetaTypeNone && kernel_data_types[i].is_optional)) {
        return false;
      }
    }
  }

  return true;
}

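// Kernel attr selection runs a strict pass and then a weak pass over the
// kernel's support list: strict requires every input/output dtype to line up
// (with carve-outs for empty tuples and optional None inputs), weak only
// checks dtypes via InputDtypeMatch. The returned bool reports whether the
// match was exact, i.e. whether the caller can skip implicit casts.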
std::pair<bool, KernelAttr> GetKernelAttr(
  const std::string &op_name, const kernel::KernelModPtr &kernel_mod,
  const std::pair<std::vector<TypeId>, std::vector<TypeId>> &inputs_types_dtypes,
  const std::pair<std::vector<TypeId>, std::vector<TypeId>> &outputs_types_dtypes) {
  const auto &support_list = kernel_mod->GetOpSupport();
  for (auto &cur_kernel_attr : support_list) {
    if (cur_kernel_attr.GetSkipCheck()) {
      return {true, cur_kernel_attr};
    }
    const auto &[input_data_types, output_data_types] = kernel::GetInOutDataTypesFromKernelAttr(cur_kernel_attr);
    if (IsObjectStrictlyMatched(inputs_types_dtypes.first, inputs_types_dtypes.second, input_data_types) &&
        IsObjectStrictlyMatched(outputs_types_dtypes.first, outputs_types_dtypes.second, output_data_types)) {
      return std::make_pair(true, cur_kernel_attr);
    }
  }

  for (auto &cur_kernel_attr : support_list) {
    const auto &[input_data_types, output_data_types] = kernel::GetInOutDataTypesFromKernelAttr(cur_kernel_attr);
    if (IsObjectDtypeWeaklyMatched(inputs_types_dtypes.second, input_data_types) &&
        IsObjectDtypeWeaklyMatched(outputs_types_dtypes.second, output_data_types)) {
      return std::make_pair(false, cur_kernel_attr);
    }
  }
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
  for (auto &input_type : inputs_types_dtypes.second) {
    (void)inputs.emplace_back(TypeIdToString(input_type));
  }
  for (auto &output_type : outputs_types_dtypes.second) {
    (void)outputs.emplace_back(TypeIdToString(output_type));
  }
  MS_EXCEPTION(TypeError)
    << "Unsupported op [" << op_name << "] on CPU, input_type: " << inputs << ", output_type: " << outputs
    << ". Please confirm whether the device target setting is correct, "
    << "or refer to 'mindspore.ops' at https://www.mindspore.cn to query the operator support list.";
}
} // namespace

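// CPU-only kernel selection: instantiate the KernelMod for `op_name` and match
// the inferred abstracts against its support list. Illustrative sketch
// (hypothetical abstracts and device context):
//   auto [exact, attr] = PyBoostUtils::SelectKernel({x_abs}, out_abs, device_context, "Sin");
//   if (!exact) { /* cast inputs to the dtypes in `attr` before launching */ }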
std::pair<bool, KernelAttr> PyBoostUtils::SelectKernel(const std::vector<AbstractBasePtr> &inputs_abs,
                                                       const AbstractBasePtr &outputs_abs,
                                                       const DeviceContext *device_context,
                                                       const std::string &op_name) {
  // Only supported on CPU.
  const auto &kernel_mod = device_context->GetKernelExecutor(false)->CreateKernelMod(op_name);
  if (kernel_mod == nullptr) {
    MS_LOG(EXCEPTION) << "The kernel " << op_name << " is not registered.";
  }
  return GetKernelAttr(op_name, kernel_mod, GetInputTypeFromAbstractBase(inputs_abs),
                       GetOutputTypeFromAbstractBase(outputs_abs));
}

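// Dtype-normalization helpers built on the pyboost Cast op. Each overload is a
// no-op when the tensor already has the target dtype; the optional overload
// passes std::nullopt through untouched. Illustrative sketch ("CPU" here is an
// assumed device target):
//   auto f32 = PyBoostUtils::CastTensor(t, kNumberTypeFloat32, "CPU");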
std::optional<tensor::BaseTensorPtr> PyBoostUtils::CastTensor(const std::optional<tensor::BaseTensorPtr> &tensor,
                                                              const TypeId &type_id,
                                                              const std::string &device_target) {
  if (!tensor.has_value()) {
    return tensor;
  }
  if (tensor.value()->Dtype()->type_id() == type_id) {
    return tensor;
  }
  auto type_id64 = std::make_shared<Int64Imm>(static_cast<int64_t>(type_id));
  const auto &cast_op = CREATE_PYBOOST_OP(Cast, device_target);
  cast_op->set_primitive(prim::kPrimCast);
  return cast_op->Call(tensor.value(), type_id64);
}

tensor::BaseTensorPtr PyBoostUtils::CastTensor(const tensor::BaseTensorPtr &tensor, const TypeId &type_id,
                                               const std::string &device_target) {
  if (tensor->Dtype()->type_id() == type_id) {
    return tensor;
  }
  auto type_id64 = std::make_shared<Int64Imm>(static_cast<int64_t>(type_id));
  const auto &cast_op = CREATE_PYBOOST_OP(Cast, device_target);
  return cast_op->Call(tensor, type_id64);
}

std::vector<tensor::BaseTensorPtr> PyBoostUtils::CastTensor(const std::vector<tensor::BaseTensorPtr> &tensors,
                                                            const std::vector<TypeId> &type_id_list,
                                                            const std::string &device_target) {
  if (tensors.size() != type_id_list.size()) {
    MS_LOG(EXCEPTION) << "The number of tensors to cast (" << tensors.size()
                      << ") does not match the number of target types (" << type_id_list.size() << ").";
  }
  std::vector<tensor::BaseTensorPtr> output_tensors;
  for (size_t i = 0; i < tensors.size(); ++i) {
    const auto &output = CastTensor(tensors[i], type_id_list[i], device_target);
    (void)output_tensors.emplace_back(output);
  }
  return output_tensors;
}

std::vector<tensor::BaseTensorPtr> PyBoostUtils::CastTensor(const std::vector<tensor::BaseTensorPtr> &tensors,
                                                            TypeId type_id, const std::string &device_target) {
  // tuple input
  std::vector<tensor::BaseTensorPtr> output_tensors;
  for (size_t i = 0; i < tensors.size(); ++i) {
    const auto &output = CastTensor(tensors[i], type_id, device_target);
    (void)output_tensors.emplace_back(output);
  }
  return output_tensors;
}
} // namespace pyboost
} // namespace kernel
} // namespace mindspore