• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "kernel/pyboost/pyboost_utils.h"
18 #include <algorithm>
19 #include <utility>
20 #include <unordered_map>
21 #include "kernel/common_utils.h"
22 #include "kernel/kernel_mod_cache.h"
23 #include "mindapi/base/type_id.h"
24 #include "runtime/device/device_address_utils.h"
25 #include "ops/ops_frontend_func_impl.h"
26 #include "ops/op_def.h"
27 #include "runtime/pynative/op_executor.h"
28 #include "pybind_api/gil_scoped_long_running.h"
29 #include "mindspore/ccsrc/plugin/device/cpu/kernel/cpu_kernel.h"
30 #include "kernel/pyboost/auto_generate/cast.h"
31 #include "mindspore/core/ops/array_ops.h"
32 
33 namespace mindspore {
34 namespace kernel {
35 namespace pyboost {
36 namespace {
CreateTensor(const TypePtr & type,const ShapeVector & shape_vector,const AbstractBasePtr & abstract_tensor,std::vector<tensor::BaseTensorPtr> * outputs)37 void CreateTensor(const TypePtr &type, const ShapeVector &shape_vector, const AbstractBasePtr &abstract_tensor,
38                   std::vector<tensor::BaseTensorPtr> *outputs) {
39   auto output_tensor = std::make_shared<tensor::BaseTensor>(type->type_id(), shape_vector);
40   output_tensor->set_abstract(abstract_tensor);
41   output_tensor->set_need_pipeline_sync(true);
42   (void)outputs->emplace_back(output_tensor);
43   MS_LOG(DEBUG) << "Create output tensor " << output_tensor->ToString();
44 }
45 
CreateTensor(const TypePtr & type,const ShapeVector & shape_vector,std::vector<tensor::BaseTensorPtr> * outputs)46 void CreateTensor(const TypePtr &type, const ShapeVector &shape_vector, std::vector<tensor::BaseTensorPtr> *outputs) {
47   auto output_tensor = std::make_shared<tensor::BaseTensor>(type->type_id(), shape_vector);
48   output_tensor->set_need_pipeline_sync(true);
49   (void)outputs->emplace_back(output_tensor);
50   MS_LOG(DEBUG) << "Create output tensor " << output_tensor->ToString();
51 }
52 }  // namespace
53 
ToAbstractNoValue(const tensor::BaseTensorPtr & tensor)54 AbstractBasePtr ToAbstractNoValue(const tensor::BaseTensorPtr &tensor) {
55   auto abs = tensor->GetAbstractCache();
56   abs->set_value(kValueAny);
57   return abs;
58 }
59 
CreateOutputTensor(const AbstractBasePtr & abstract,std::vector<tensor::BaseTensorPtr> * outputs)60 void PyBoostUtils::CreateOutputTensor(const AbstractBasePtr &abstract, std::vector<tensor::BaseTensorPtr> *outputs) {
61   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
62                                      runtime::ProfilerEvent::kPyBoostCreateOutputTensor,
63                                      runtime::ProfilerRecorder::kNoName, false);
64   MS_EXCEPTION_IF_NULL(abstract);
65   if (abstract->isa<abstract::AbstractSequence>()) {
66     const auto &seq = abstract->cast<abstract::AbstractSequencePtr>();
67     const auto &elements = seq->elements();
68     for (const auto &element : elements) {
69       CreateOutputTensor(element, outputs);
70     }
71   } else if (abstract->isa<abstract::AbstractTensor>()) {
72     const auto &abstract_tensor = abstract->cast<abstract::AbstractTensorPtr>();
73     const auto &shape = abstract_tensor->GetShapeTrack();
74     const auto &type = abstract_tensor->element()->GetTypeTrack();
75     MS_LOG(DEBUG) << "get abstract tensor shape " << shape->ToString() << " type " << type->ToString();
76     if (!shape->isa<abstract::Shape>()) {
77       MS_LOG(EXCEPTION) << "AbstractTensor shape is valid " << shape->ToString();
78     }
79     const auto &shape_vector = shape->cast<abstract::ShapePtr>()->shape();
80     CreateTensor(type, shape_vector, abstract_tensor, outputs);
81   } else if (abstract->isa<abstract::AbstractScalar>()) {
82     const auto &scalar = abstract->cast<abstract::AbstractScalarPtr>();
83     const auto &type = scalar->GetTypeTrack();
84     MS_LOG(DEBUG) << "Create scalar tensor type " << type->ToString();
85     CreateTensor(type, {}, nullptr, outputs);
86   } else {
87     MS_LOG(EXCEPTION) << "Not support abstract " << abstract->ToString();
88   }
89 }
90 
ScalarToTensor(const ScalarPtr & scalar)91 tensor::BaseTensorPtr PyBoostUtils::ScalarToTensor(const ScalarPtr &scalar) {
92   if (scalar == nullptr) {
93     MS_EXCEPTION(ArgumentError) << "Nullptr Error!";
94   }
95   TypePtr data_type = scalar->type();
96   MS_EXCEPTION_IF_NULL(data_type);
97   TypeId type_id = data_type->type_id();
98   switch (type_id) {
99     case kNumberTypeBool:
100       return std::make_shared<tensor::BaseTensor>(GetValue<bool>(scalar), data_type);
101     case kNumberTypeInt8:
102       return std::make_shared<tensor::BaseTensor>(static_cast<int64_t>(GetValue<int8_t>(scalar)), data_type);
103     case kNumberTypeInt16:
104       return std::make_shared<tensor::BaseTensor>(static_cast<int64_t>(GetValue<int16_t>(scalar)), data_type);
105     case kNumberTypeInt32:
106       return std::make_shared<tensor::BaseTensor>(static_cast<int64_t>(GetValue<int32_t>(scalar)), data_type);
107     case kNumberTypeInt64:
108       return std::make_shared<tensor::BaseTensor>(GetValue<int64_t>(scalar), data_type);
109     case kNumberTypeUInt8:
110       return std::make_shared<tensor::BaseTensor>(static_cast<uint64_t>(GetValue<uint8_t>(scalar)), data_type);
111     case kNumberTypeUInt16:
112       return std::make_shared<tensor::BaseTensor>(static_cast<uint64_t>(GetValue<uint16_t>(scalar)), data_type);
113     case kNumberTypeUInt32:
114       return std::make_shared<tensor::BaseTensor>(static_cast<uint64_t>(GetValue<uint32_t>(scalar)), data_type);
115     case kNumberTypeUInt64:
116       return std::make_shared<tensor::BaseTensor>(GetValue<uint64_t>(scalar), data_type);
117     case kNumberTypeFloat32:
118       return std::make_shared<tensor::BaseTensor>(GetValue<float>(scalar), data_type);
119     case kNumberTypeFloat64:
120       return std::make_shared<tensor::BaseTensor>(GetValue<double>(scalar), data_type);
121     default:
122       MS_LOG(EXCEPTION) << "When convert scalar to tensor, the scalar type: " << data_type << " is invalid.";
123   }
124 }
125 
CreateOutputTensor(const ValueSimpleInfoPtr & output_value_simple_info,std::vector<tensor::BaseTensorPtr> * outputs)126 void PyBoostUtils::CreateOutputTensor(const ValueSimpleInfoPtr &output_value_simple_info,
127                                       std::vector<tensor::BaseTensorPtr> *outputs) {
128   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
129                                      runtime::ProfilerEvent::kPyBoostCreateOutputTensor,
130                                      runtime::ProfilerRecorder::kNoName, false);
131   MS_EXCEPTION_IF_NULL(output_value_simple_info);
132   size_t elem_size = output_value_simple_info->dtype_vector_.size();
133   for (size_t i = 0; i < elem_size; ++i) {
134     MS_LOG(DEBUG) << "Get tensor shape " << output_value_simple_info->shape_vector_[i] << ", type "
135                   << TypeIdToType(output_value_simple_info->dtype_vector_[i]->type_id())->ToString();
136     CreateTensor(output_value_simple_info->dtype_vector_[i], output_value_simple_info->shape_vector_[i], outputs);
137   }
138 }
139 
IsKernelModRegistered(const std::string & device_name,const std::string & op_name)140 bool PyBoostUtils::IsKernelModRegistered(const std::string &device_name, const std::string &op_name) {
141   return PyboostKernelExtraFuncFactory::GetInstance().IsKernelModRegistered(device_name, op_name);
142 }
143 
IsPyBoostCustomRegistered(const std::string & device_name,const std::string & op_name)144 bool PyBoostUtils::IsPyBoostCustomRegistered(const std::string &device_name, const std::string &op_name) {
145   return PyboostKernelExtraFuncFactory::GetInstance().IsPyBoostCustomRegistered(device_name, op_name);
146 }
147 
CreateKernelMod(const PrimitivePtr & prim,const std::string & op_name,const DeviceContext * device_context,const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)148 kernel::KernelModPtr PyBoostUtils::CreateKernelMod(const PrimitivePtr &prim, const std::string &op_name,
149                                                    const DeviceContext *device_context,
150                                                    const std::vector<KernelTensor *> &inputs,
151                                                    const std::vector<KernelTensor *> &outputs) {
152   MS_EXCEPTION_IF_NULL(device_context);
153   const auto &device_name = device_context->device_context_key().device_name_;
154 
155   auto &cache_helper = kernel::KernelModCache::GetInstance();
156   const auto &key = cache_helper.GetKernelModKey(op_name, device_name, inputs);
157   auto kernel_mod = cache_helper.GetKernelMod(key);
158   if (kernel_mod == nullptr) {
159     kernel_mod = device_context->GetKernelExecutor(false)->CreateKernelMod(op_name);
160     if (kernel_mod == nullptr) {
161       MS_LOG(EXCEPTION) << "Create kernelmod for op " << op_name << " failed";
162     }
163     if (!kernel_mod->Init(prim, inputs, outputs)) {
164       MS_LOG(EXCEPTION) << "KernelMod Init Failed: " << op_name;
165     }
166     cache_helper.SetCache(key, kernel_mod);
167     PyboostKernelExtraFuncFactory::GetInstance().SetThreadPool(device_name, kernel_mod);
168   }
169 
170   return kernel_mod;
171 }
172 
// Turn a strided (view) device address into a contiguous one by running a
// device-side contiguous task. Returns the input unchanged when it carries no
// TensorStorageInfo (i.e. it is already contiguous).
DeviceSyncPtr PyBoostUtils::ContiguousByDeviceAddress(const DeviceSyncPtr &device_sync) {
  auto &storage_info = device_sync->GetTensorStorageInfo();
  if (storage_info == nullptr) {
    return device_sync;
  }

  auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);

  MS_EXCEPTION_IF_NULL(old_device_address);
  MS_EXCEPTION_IF_NULL(storage_info);
  // Release the Python GIL for the duration of the device work (see
  // gil_scoped_long_running.h) so Python threads are not blocked.
  GilReleaseWithCheck gil_release;

  const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {old_device_address->device_name(), old_device_address->device_id()});
  MS_EXCEPTION_IF_NULL(device_context);

  auto stream_id = device_context->device_res_manager_->GetCurrentStreamId();
  // Size the new buffer for the dense layout of storage_info->shape.
  auto address_size = GetTypeByte(TypeIdToType(old_device_address->type_id())) * SizeOf(storage_info->shape);
  auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, address_size, storage_info->shape, DEFAULT_FORMAT, old_device_address->type_id(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, stream_id);
  new_device_address->set_device_shape(storage_info->shape);
  // SIZE_MAX original ref count: exempt this address from ref-count release.
  new_device_address->set_original_ref_count(SIZE_MAX);
  new_device_address->ResetRefCount();

  // Launch the copy that materializes the contiguous layout from the view.
  if (!device_context->GetKernelExecutor(false)->ExecuteKernelTask(
        runtime::KernelTaskType::kCONTIGUOUS_TASK, {old_device_address}, {new_device_address}, stream_id)) {
    MS_LOG(EXCEPTION) << "ExecuteKernelTask failed, task_type:" << runtime::KernelTaskType::kCONTIGUOUS_TASK;
  }
  // Drain the op executor so the new address is ready before returning.
  runtime::OpExecutor::GetInstance().WaitAll();
  return new_device_address;
}
205 
CreateOutputTensor(const DeviceContext * device_context,const tensor::BaseTensorPtr & input,const TensorStorageInfoPtrList & storage_info_list,std::vector<tensor::BaseTensorPtr> * outputs)206 void PyBoostUtils::CreateOutputTensor(const DeviceContext *device_context, const tensor::BaseTensorPtr &input,
207                                       const TensorStorageInfoPtrList &storage_info_list,
208                                       std::vector<tensor::BaseTensorPtr> *outputs) {
209   for (auto &storage_info : storage_info_list) {
210     CreateOutputTensor(device_context, input, storage_info, outputs);
211   }
212 }
213 
// Create a view output tensor that aliases `input`'s storage through
// `storage_info` — no data copy: the new device address shares the input
// address's pointer ref count.
void PyBoostUtils::CreateOutputTensor(const DeviceContext *device_context, const tensor::BaseTensorPtr &input,
                                      const TensorStorageInfoPtr &storage_info,
                                      std::vector<tensor::BaseTensorPtr> *outputs) {
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(storage_info);
  MS_EXCEPTION_IF_NULL(device_context);

  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
                                     runtime::ProfilerEvent::kPyBoostCreateOutputTensor,
                                     runtime::ProfilerRecorder::kNoName, false);
  auto output_tensor = std::make_shared<tensor::BaseTensor>(input->data_type(), storage_info->shape);
  output_tensor->set_need_pipeline_sync(true);
  // If a contiguous buffer is later required, convert on demand.
  output_tensor->set_contiguous_callback(
    [](const DeviceSyncPtr &device_address) -> DeviceSyncPtr { return ContiguousByDeviceAddress(device_address); });

  auto input_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input->device_address());
  MS_EXCEPTION_IF_NULL(input_device_address);
  input_device_address->set_is_view(true);

  // Create view output address
  auto output_device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, input_device_address->GetSize(), output_tensor->shape(), DEFAULT_FORMAT, output_tensor->data_type(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_,
    input_device_address->stream_id());
  MS_EXCEPTION_IF_NULL(output_device_address);
  output_device_address->set_tensor_storage_info(storage_info);
  // Share the underlying memory with the input (view semantics, no copy).
  output_device_address->set_pointer_ref_count(input_device_address->pointer_ref_count());
  output_tensor->set_device_address(output_device_address);
  (void)outputs->emplace_back(output_tensor);
  MS_LOG(DEBUG) << "Create output tensor " << output_tensor->ToString() << " with " << storage_info->ToString();
}
245 
InferByOpDef(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_abs)246 AbstractBasePtr PyBoostUtils::InferByOpDef(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_abs) {
247   MS_EXCEPTION_IF_NULL(prim);
248   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostInferByOpDef,
249                                      prim->name(), false);
250   auto op_def = mindspore::ops::GetOpDef(prim->name());
251   if (op_def) {
252     (void)op_def->func_impl_.CheckValidation(prim, input_abs);
253     auto shape = op_def->func_impl_.InferShape(prim, input_abs);
254     auto type = op_def->func_impl_.InferType(prim, input_abs);
255     auto output_abs = mindspore::abstract::MakeAbstract(shape, type);
256     MS_LOG(DEBUG) << "Pynative Infer " << prim->name() << " by OpDef, got abstract: " << output_abs->ToString();
257     return output_abs;
258   } else {
259     const auto &infer_map = abstract::GetPrimitiveInferMapPtr();
260     const auto &iter = infer_map->find(prim);
261     if (iter != infer_map->end()) {
262       auto output_abs = iter->second.InferShapeAndType(nullptr, prim, input_abs);
263       MS_LOG(DEBUG) << "Pynative Infer " << prim->name()
264                     << " by C++ PrimitiveInferMap, got abstract: " << output_abs->ToString();
265       return output_abs;
266     } else {
267       MS_LOG(EXCEPTION) << "Cannot found infer function for Op " << prim->name();
268     }
269   }
270 }
271 
DispatchRun(const std::shared_ptr<runtime::PyBoostDeviceTask> & task)272 void PyBoostUtils::DispatchRun(const std::shared_ptr<runtime::PyBoostDeviceTask> &task) {
273   static auto need_sync = runtime::OpExecutor::NeedSync();
274   if (need_sync && !runtime::OpExecutor::GetInstance().async_for_graph()) {
275     MS_LOG(INFO) << "PyBoost sync run device task";
276     runtime::OpExecutor::GetInstance().WaitAll();
277     task->Run();
278   } else {
279     runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(task->task_id());
280     runtime::OpExecutor::GetInstance().PushOpRunTask(task);
281   }
282 }
283 
GetKernelTensorFromAddress(const device::DeviceAddressPtrList & input_device_address)284 std::vector<kernel::KernelTensor *> PyBoostUtils::GetKernelTensorFromAddress(
285   const device::DeviceAddressPtrList &input_device_address) {
286   std::vector<kernel::KernelTensor *> input_kernel_tensors;
287   std::transform(input_device_address.begin(), input_device_address.end(), std::back_inserter(input_kernel_tensors),
288                  [](const auto &item) { return item->kernel_tensor().get(); });
289   return input_kernel_tensors;
290 }
291 
// Append `tensor`'s existing device address and kernel tensor to the output
// lists. Note: device_context, stream_id, input_abs and index are unused in
// this overload; they keep the signature uniform with the other
// GetKernelTensor overloads.
void PyBoostUtils::GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
                                   const abstract::AbstractBasePtr &input_abs, size_t index,
                                   std::vector<kernel::KernelTensor *> *kernel_tensor_list,
                                   device::DeviceAddressPtrList *device_address_list, const BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(kernel_tensor_list);
  MS_EXCEPTION_IF_NULL(device_address_list);

  // The tensor must already carry a device address (set by an earlier step).
  const auto &device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_EXCEPTION_IF_NULL(device_address);
  (void)device_address_list->emplace_back(device_address);
  const auto &kernel_tensor = device_address->kernel_tensor();
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  (void)kernel_tensor_list->emplace_back(kernel_tensor.get());
}
307 
GetKernelTensor(const DeviceContext * device_context,size_t stream_id,const abstract::AbstractBasePtr & input_abs,size_t index,std::vector<kernel::KernelTensor * > * kernel_tensor_list,device::DeviceAddressPtrList * device_address_list,const std::vector<tensor::BaseTensorPtr> & tensors)308 void PyBoostUtils::GetKernelTensor(const DeviceContext *device_context, size_t stream_id,
309                                    const abstract::AbstractBasePtr &input_abs, size_t index,
310                                    std::vector<kernel::KernelTensor *> *kernel_tensor_list,
311                                    device::DeviceAddressPtrList *device_address_list,
312                                    const std::vector<tensor::BaseTensorPtr> &tensors) {
313   for (const auto &tensor : tensors) {
314     // input_abs is not used in GetKernelTensor when value is TensorPtr.
315     GetKernelTensor(device_context, stream_id, input_abs, index, kernel_tensor_list, device_address_list, tensor);
316   }
317 }
318 
// Create one device address per workspace size reported by `kernel_mod`, then
// back each with real device memory. Throws when allocation fails.
device::DeviceAddressPtrList PyBoostUtils::CreateWorkSpaceDeviceAddress(const KernelModPtr &kernel_mod,
                                                                        const device::DeviceContext *device_context,
                                                                        const std::string &op_name) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  MS_EXCEPTION_IF_NULL(kernel_mod);

  const auto &workspace_sizes = kernel_mod->GetWorkspaceSizeList();
  device::DeviceAddressPtrList workspaces_address;
  for (const auto workspace_size : workspace_sizes) {
    // Workspace buffers carry no host data, an unknown dtype and an empty shape.
    auto kernel_tensor = std::make_shared<KernelTensor>(
      nullptr, workspace_size, Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    MS_LOG(DEBUG) << "Create workspace for op: " << op_name << " addr: " << device_address;
    MS_EXCEPTION_IF_NULL(device_address);
    (void)workspaces_address.emplace_back(device_address);
  }

  // Second pass: allocate device memory for every address not already backed.
  for (size_t i = 0; i < workspace_sizes.size(); ++i) {
    auto device_address = workspaces_address[i];
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetPtr() == nullptr &&
        !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate workspace memory failed";
    }
    MS_LOG(DEBUG) << "workspace[" << i << "]:" << device_address->kernel_tensor()->device_ptr()
                  << " size:" << device_address->kernel_tensor()->size();
  }
  return workspaces_address;
}
350 
// Process-wide singleton accessor (Meyers singleton: `instance` is initialized
// on first call; initialization is thread-safe since C++11).
PyboostKernelExtraFuncFactory &PyboostKernelExtraFuncFactory::GetInstance() {
  static PyboostKernelExtraFuncFactory instance;
  return instance;
}
355 
LaunchKernel(const PrimitivePtr & primitive,const DeviceContext * device_context,const AddressInfoPair & input_address_info,const AddressInfoPair & output_address_info,size_t stream_id)356 void PyBoostUtils::LaunchKernel(const PrimitivePtr &primitive, const DeviceContext *device_context,
357                                 const AddressInfoPair &input_address_info, const AddressInfoPair &output_address_info,
358                                 size_t stream_id) {
359   const auto &real_name = primitive->name();
360   // KernelMod init
361   auto kernel_mod = PyBoostUtils::CreateKernelMod(primitive, real_name, device_context, input_address_info.first,
362                                                   output_address_info.first);
363   MS_EXCEPTION_IF_NULL(kernel_mod);
364   // KernelMod resize
365   if (kernel_mod->Resize(input_address_info.first, output_address_info.first) == kernel::KRET_RESIZE_FAILED) {
366     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Kernel build failed:#dmsg#CPU kernel op [" << real_name << "] resize failed.";
367   }
368   // Get workspace address
369   const auto &workspace_device_address =
370     PyBoostUtils::CreateWorkSpaceDeviceAddress(kernel_mod, device_context, primitive->name());
371   const auto &workspace_kernel_tensors = PyBoostUtils::GetKernelTensorFromAddress(workspace_device_address);
372 
373   const auto &device_name = device_context->device_context_key().device_name_;
374   void *stream_ptr = device_context->device_res_manager_->GetStream(stream_id);
375   if (!PyboostKernelExtraFuncFactory::GetInstance().IsEnableProfiler(device_name)) {
376     if (!kernel_mod->Launch(input_address_info.first, workspace_kernel_tensors, output_address_info.first,
377                             stream_ptr)) {
378       MS_LOG(EXCEPTION) << "Launch kernel failed, name: " << real_name;
379     }
380   } else {
381     const auto &input_kts = input_address_info.first;
382     std::vector<BaseShapePtr> input_shapes;
383     for (auto kt : input_kts) {
384       MS_EXCEPTION_IF_NULL(kt);
385       input_shapes.push_back(kt->GetShape());
386     }
387     PyboostKernelExtraFuncFactory::GetInstance().LaunchKernelWithProfiler(
388       device_name, device_context, real_name, {}, [&]() {
389         if (!kernel_mod->Launch(input_address_info.first, workspace_kernel_tensors, output_address_info.first,
390                                 stream_ptr)) {
391           MS_LOG(EXCEPTION) << "Launch kernel failed, name: " << real_name;
392         }
393       });
394   }
395   if (kernel_mod->IsNeedUpdateOutputShapeAndSize()) {
396     kernel_mod->UpdateOutputShapeAndSize(input_address_info.first, output_address_info.first);
397   }
398   runtime::DeviceAddressUtils::ProcessCrossStreamAddress(real_name, device_context, stream_id, input_address_info.first,
399                                                          output_address_info.first);
400   MS_LOG(DEBUG) << real_name << " Launch end";
401 }
402 
403 namespace {
GetTypeIdFromAbstractTensor(const AbstractBasePtr & abs_base)404 TypeId GetTypeIdFromAbstractTensor(const AbstractBasePtr &abs_base) {
405   if (abs_base->isa<abstract::AbstractTensor>()) {
406     auto abs_tensor = std::dynamic_pointer_cast<abstract::AbstractTensor>(abs_base);
407     return abs_tensor->element()->BuildType()->type_id();
408   }
409   return abs_base->BuildType()->type_id();
410 }
411 
// Map an abstract to its coarse object-type id (tensor / tuple / list /
// number / none). Returns kTypeUnknown for null or unrecognized abstracts.
TypeId GetAbstractObjectType(const AbstractBasePtr &abstract) {
  if (abstract == nullptr) {
    return kTypeUnknown;
  }
  if (abstract->isa<abstract::AbstractTensor>()) {
    return kObjectTypeTensorType;
  }
  if (abstract->isa<abstract::AbstractTuple>()) {
    return kObjectTypeTuple;
  }
  if (abstract->isa<abstract::AbstractList>()) {
    return kObjectTypeList;
  }
  if (abstract->isa<abstract::AbstractScalar>()) {
    // scalar input may not converted to tensor
    return kObjectTypeNumber;
  }
  if (abstract->isa<abstract::AbstractNone>()) {
    return kMetaTypeNone;
  }

  return kTypeUnknown;
}
435 
GetOutputTypeFromAbstractBase(const AbstractBasePtr & abs_base)436 std::pair<std::vector<TypeId>, std::vector<TypeId>> GetOutputTypeFromAbstractBase(const AbstractBasePtr &abs_base) {
437   std::vector<TypeId> output_dtype;
438   std::vector<TypeId> output_type;
439   if (abs_base->isa<abstract::AbstractTuple>()) {
440     auto abs_tuple = std::dynamic_pointer_cast<abstract::AbstractTuple>(abs_base);
441     for (auto &abs : abs_tuple->elements()) {
442       (void)output_dtype.emplace_back(GetTypeIdFromAbstractTensor(abs));
443       (void)output_type.emplace_back(GetAbstractObjectType(abs));
444     }
445   } else {
446     (void)output_type.emplace_back(GetAbstractObjectType(abs_base));
447     (void)output_dtype.emplace_back(GetTypeIdFromAbstractTensor(abs_base));
448   }
449   return std::make_pair(output_type, output_dtype);
450 }
451 
GetInputTypeFromAbstractBase(const std::vector<AbstractBasePtr> & abs_vec)452 std::pair<std::vector<TypeId>, std::vector<TypeId>> GetInputTypeFromAbstractBase(
453   const std::vector<AbstractBasePtr> &abs_vec) {
454   std::vector<TypeId> input_dtype;
455   std::vector<TypeId> input_type;
456   for (auto &abs : abs_vec) {
457     if (abs->isa<abstract::AbstractTuple>()) {
458       // a tuple tensors have same type
459       auto abs_tuple = std::dynamic_pointer_cast<abstract::AbstractTuple>(abs);
460       if (abs_tuple->elements().empty()) {
461         input_dtype.emplace_back(kTypeUnknown);
462         continue;
463       }
464       input_dtype.emplace_back(abs_tuple->elements()[0]->BuildType()->type_id());
465     } else {
466       input_dtype.emplace_back(GetTypeIdFromAbstractTensor(abs));
467     }
468     input_type.emplace_back(GetAbstractObjectType(abs));
469   }
470   return std::make_pair(input_type, input_dtype);
471 }
472 
InputDtypeMatch(TypeId input_attr,TypeId input_type)473 bool InputDtypeMatch(TypeId input_attr, TypeId input_type) {
474   if (input_attr == input_type || kTypeUnknown == input_type) {
475     return true;
476   }
477   if (input_attr == kNumberTypeInt32 && (input_type == kNumberTypeInt16 || input_type == kNumberTypeInt64)) {
478     return true;
479   }
480   if (input_attr == kNumberTypeFloat32 && (input_type == kNumberTypeFloat16 || input_type == kNumberTypeFloat64)) {
481     return true;
482   }
483   return false;
484 }
485 
IsObjectDtypeWeaklyMatched(const std::vector<TypeId> & object_dtypes,const std::vector<DataType> & kernel_data_types)486 bool IsObjectDtypeWeaklyMatched(const std::vector<TypeId> &object_dtypes,
487                                 const std::vector<DataType> &kernel_data_types) {
488   // only support CPU
489   for (size_t i = 0; i < object_dtypes.size(); i++) {
490     // For optional input, the real input object type can be a None.
491     if (!InputDtypeMatch(kernel_data_types[i].dtype, object_dtypes[i])) {
492       return false;
493     }
494   }
495   return true;
496 }
497 
// Strict match: each position's dtype must equal the kernel attr's declared
// dtype. Two exemptions: an empty tuple (dtype kTypeUnknown) matches a tuple
// slot, and a None actual matches an optional slot.
bool IsObjectStrictlyMatched(const std::vector<TypeId> &object_types, const std::vector<TypeId> &object_dtypes,
                             const std::vector<DataType> &kernel_data_types) {
  if (object_dtypes.size() != kernel_data_types.size()) {
    return false;
  }

  for (size_t i = 0; i < object_dtypes.size(); i++) {
    auto is_tuple = (kernel_data_types[i].object_type == kObjectTypeTuple);
    // For optional input, the real input object type can be a None.Tuple data-type unknown means empty tuple.
    if (object_dtypes[i] != kernel_data_types[i].dtype) {
      if ((!is_tuple || object_dtypes[i] != kTypeUnknown) &&
          !(object_types[i] == kMetaTypeNone && kernel_data_types[i].is_optional)) {
        return false;
      }
    }
  }

  return true;
}
517 
GetKernelAttr(const std::string & op_name,const kernel::KernelModPtr & kernel_mod,const std::pair<std::vector<TypeId>,std::vector<TypeId>> & inputs_types_dtypes,const std::pair<std::vector<TypeId>,std::vector<TypeId>> & outputs_types_dtypes)518 std::pair<bool, KernelAttr> GetKernelAttr(
519   const std::string &op_name, const kernel::KernelModPtr &kernel_mod,
520   const std::pair<std::vector<TypeId>, std::vector<TypeId>> &inputs_types_dtypes,
521   const std::pair<std::vector<TypeId>, std::vector<TypeId>> &outputs_types_dtypes) {
522   const auto &support_list = kernel_mod->GetOpSupport();
523   for (auto &cur_kernel_attr : support_list) {
524     if (cur_kernel_attr.GetSkipCheck()) {
525       return {true, cur_kernel_attr};
526     }
527     auto data_pair = kernel::GetInOutDataTypesFromKernelAttr(cur_kernel_attr);
528     const auto &[input_data_types, output_data_types] = kernel::GetInOutDataTypesFromKernelAttr(cur_kernel_attr);
529     if (IsObjectStrictlyMatched(inputs_types_dtypes.first, inputs_types_dtypes.second, input_data_types) &&
530         IsObjectStrictlyMatched(outputs_types_dtypes.first, outputs_types_dtypes.second, output_data_types)) {
531       return std::make_pair(true, cur_kernel_attr);
532     }
533   }
534 
535   for (auto &cur_kernel_attr : support_list) {
536     auto data_pair = kernel::GetInOutDataTypesFromKernelAttr(cur_kernel_attr);
537     const auto &[input_data_types, output_data_types] = kernel::GetInOutDataTypesFromKernelAttr(cur_kernel_attr);
538     if (IsObjectDtypeWeaklyMatched(inputs_types_dtypes.second, input_data_types) &&
539         IsObjectDtypeWeaklyMatched(outputs_types_dtypes.second, output_data_types)) {
540       return std::make_pair(false, cur_kernel_attr);
541     }
542   }
543   std::vector<std::string> inputs;
544   std::vector<std::string> outputs;
545   for (auto &input_type : inputs_types_dtypes.second) {
546     (void)inputs.emplace_back(TypeIdToString(input_type));
547   }
548   for (auto &output_type : outputs_types_dtypes.second) {
549     (void)outputs.emplace_back(TypeIdToString(output_type));
550   }
551   MS_EXCEPTION(TypeError)
552     << "Unsupported op [" << op_name << "] on CPU, input_type:" << inputs << " ,output_type:" << outputs
553     << ". Please confirm whether the device target setting is correct, "
554     << "or refer to 'mindspore.ops' at https://www.mindspore.cn to query the operator support list.";
555 }
556 }  // namespace
557 
SelectKernel(const std::vector<AbstractBasePtr> & inputs_abs,const AbstractBasePtr & outputs_abs,const DeviceContext * device_context,const std::string & op_name)558 std::pair<bool, KernelAttr> PyBoostUtils::SelectKernel(const std::vector<AbstractBasePtr> &inputs_abs,
559                                                        const AbstractBasePtr &outputs_abs,
560                                                        const DeviceContext *device_context,
561                                                        const std::string &op_name) {
562   // only support CPU
563   const auto &kernel_mod = device_context->GetKernelExecutor(false)->CreateKernelMod(op_name);
564   if (kernel_mod == nullptr) {
565     MS_LOG(EXCEPTION) << "The kernel " << op_name << " unregistered.";
566   }
567   return GetKernelAttr(op_name, kernel_mod, GetInputTypeFromAbstractBase(inputs_abs),
568                        GetOutputTypeFromAbstractBase(outputs_abs));
569 }
570 
CastTensor(const std::optional<tensor::BaseTensorPtr> & tensor,const TypeId & type_id,const std::string & device_target)571 std::optional<tensor::BaseTensorPtr> PyBoostUtils::CastTensor(const std::optional<tensor::BaseTensorPtr> &tensor,
572                                                               const TypeId &type_id, const std::string &device_target) {
573   if (!tensor.has_value()) {
574     return tensor;
575   }
576   if (tensor.value()->Dtype()->type_id() == type_id) {
577     return tensor;
578   }
579   auto type_id64 = std::make_shared<Int64Imm>(static_cast<int64_t>(type_id));
580   const auto &cast_op = CREATE_PYBOOST_OP(Cast, device_target);
581   cast_op->set_primitive(prim::kPrimCast);
582   return cast_op->Call(tensor.value(), type_id64);
583 }
584 
// Cast a tensor to type_id on device_target, returning the input untouched when
// its dtype already matches; otherwise dispatches a pyboost Cast op.
tensor::BaseTensorPtr PyBoostUtils::CastTensor(const tensor::BaseTensorPtr &tensor, const TypeId &type_id,
                                               const std::string &device_target) {
  // Fast path: dtype already matches.
  if (tensor->Dtype()->type_id() == type_id) {
    return tensor;
  }
  const auto &cast_op = CREATE_PYBOOST_OP(Cast, device_target);
  // The Cast op takes the destination dtype as an Int64 scalar value.
  auto dst_type = std::make_shared<Int64Imm>(static_cast<int64_t>(type_id));
  return cast_op->Call(tensor, dst_type);
}
594 
CastTensor(const std::vector<tensor::BaseTensorPtr> & tensors,const std::vector<TypeId> & type_id_list,const std::string & device_target)595 std::vector<tensor::BaseTensorPtr> PyBoostUtils::CastTensor(const std::vector<tensor::BaseTensorPtr> &tensors,
596                                                             const std::vector<TypeId> &type_id_list,
597                                                             const std::string &device_target) {
598   if (tensors.size() != type_id_list.size()) {
599     MS_LOG(EXCEPTION) << "before cast tensor output size is not equal after cast";
600   }
601   std::vector<tensor::BaseTensorPtr> output_tensors;
602   for (size_t i = 0; i < tensors.size(); ++i) {
603     const auto &output = CastTensor(tensors[i], type_id_list[i], device_target);
604     (void)output_tensors.emplace_back(output);
605   }
606   return output_tensors;
607 }
608 
CastTensor(const std::vector<tensor::BaseTensorPtr> & tensors,TypeId type_id,const std::string & device_target)609 std::vector<tensor::BaseTensorPtr> PyBoostUtils::CastTensor(const std::vector<tensor::BaseTensorPtr> &tensors,
610                                                             TypeId type_id, const std::string &device_target) {
611   // tuple input
612   std::vector<tensor::BaseTensorPtr> output_tensors;
613   for (size_t i = 0; i < tensors.size(); ++i) {
614     const auto &output = CastTensor(tensors[i], type_id, device_target);
615     (void)output_tensors.emplace_back(output);
616   }
617   return output_tensors;
618 }
619 }  // namespace pyboost
620 }  // namespace kernel
621 }  // namespace mindspore
622