/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYFUNC_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYFUNC_KERNEL_H_

#include <memory>
#include <string>
#include <vector>
#include <Python.h>
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"

namespace py = pybind11;
namespace mindspore {
namespace kernel {
// Indicates the Python object type. The input/output of PyFunc must be either a Scalar or a Numpy Array.
enum class PythonOjectType : char { kScalar, kNumpyArray };
// Describes PyFunc input/output information.
struct PyFuncArgumentInfo {
  // An empty shape vector indicates the Python object is a Scalar; a non-empty one means a Numpy Array.
  std::vector<std::vector<int64_t>> shapes;
  // Data types, such as int, float, bool.
  std::vector<TypeId> dtypes;
  // Python object types.
  std::vector<PythonOjectType> object_types;
};
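// Illustrative sketch (the shapes and dtypes below are assumed values for documentation only, not taken
// from any particular operator): a PyFunc that receives one float32 tensor of shape [2, 3] and returns
// one int32 scalar could be described as:
//
//   PyFuncArgumentInfo input_info;
//   input_info.shapes = {{2, 3}};                             // non-empty shape -> Numpy Array
//   input_info.dtypes = {kNumberTypeFloat32};
//   input_info.object_types = {PythonOjectType::kNumpyArray};
//
//   PyFuncArgumentInfo output_info;
//   output_info.shapes = {{}};                                // empty shape -> Scalar
//   output_info.dtypes = {kNumberTypeInt32};
//   output_info.object_types = {PythonOjectType::kScalar};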

class PyFuncCpuKernelMod : public NativeCpuKernelMod {
 public:
  PyFuncCpuKernelMod() : init_(false), fake_output_(false), single_scalar_output_(false), func_id_(-1) {}
  ~PyFuncCpuKernelMod() = default;

  bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
    return true;
  }
  int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;
  // Construct the arguments from raw memory, invoke the Python function, and then convert the result back to raw memory.
  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &,
              const std::vector<KernelTensor *> &outputs) override;
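  // Illustrative sketch of the conversion pattern described above (the buffer names, shape, and dtype here
  // are assumptions for documentation only, not the actual implementation):
  //
  //   // Wrap an input buffer as a NumPy array without copying, call the function, copy the result back.
  //   py::array arg(py::dtype("float32"), std::vector<ssize_t>{2, 3}, input_ptr);
  //   py::object ret = py_func_(arg);
  //   py::array result = py::cast<py::array>(ret);
  //   std::memcpy(output_ptr, result.data(), result.nbytes());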

 protected:
  // Analyse the PyFunc input/output specification.
  void BuildFuncInfo(const PrimitivePtr &primitive, const std::vector<KernelTensor *> &inputs,
                     const std::vector<KernelTensor *> &outputs);
  // Get the Python function registered under `func_id_`.
  py::function GetPythonFunc() const;
  bool ExecuteKernel(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs);

  bool init_;
  bool fake_output_;
  bool single_scalar_output_;
  // A Python object is not an acceptable `Primitive` attribute, so we pass a unique key instead of the Python function.
  // mindspore.ops.operations.PyFunc stores the Python function in a dict and passes the key to the backend kernel.
  // The kernel gets the Python function by that key from the dict when it is first invoked.
  int64_t func_id_;
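  // Illustrative sketch of the key-based lookup described above (the registry module path and helper name
  // below are assumptions for documentation only; the real registry lives on the Python side):
  //
  //   py::gil_scoped_acquire gil;
  //   py::module registry = py::module::import("mindspore.ops.operations._pyfunc_registry");  // hypothetical path
  //   py::function fn = registry.attr("get_pyfunc")(py::int_(func_id_));                      // hypothetical helper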
  py::function py_func_;
  // Input and output specifications.
  PyFuncArgumentInfo input_infos_;
  PyFuncArgumentInfo output_infos_;
  // The kernel holds the input tensors during execution to avoid dynamically allocating/freeing host memory.
  std::vector<std::shared_ptr<tensor::Tensor>> input_tensors_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYFUNC_KERNEL_H_