1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_PI_JIT_INFER_H
17 #define MINDSPORE_PI_JIT_INFER_H
18
19 #include <memory>
20 #include <vector>
21 #include <string>
22 #include <unordered_map>
23 #include "pybind11/pybind11.h"
24
25 namespace mindspore {
26 namespace pijit {
27
28 class InferEngine : public std::enable_shared_from_this<InferEngine> {
29 public:
30 static std::shared_ptr<InferEngine> GetInstance();
31 PyObject *InferPrimitive(PyObject *primitive, const std::vector<PyObject *> &args, bool *is_abstract);
32 PyObject *InferSpecialPrimitive(PyObject *primitive, const std::vector<PyObject *> &arglist);
33 bool SupportInfer(PyObject *primitive);
34 bool Init();
35 bool Deinit();
36
37 protected:
38 InferEngine();
39 bool bInit_ = false;
40 };
41
42 using InferEnginePtr = std::shared_ptr<InferEngine>;
43
44 namespace py = pybind11;
45
46 template <class T>
GetPybindType()47 PyTypeObject *GetPybindType() {
48 py::handle mapped_type = py::detail::get_type_handle(typeid(T), false);
49 return reinterpret_cast<PyTypeObject *>(mapped_type.ptr());
50 }
51
52 template <class T, bool sub = true>
IsPybindType(PyTypeObject * tp)53 bool IsPybindType(PyTypeObject *tp) {
54 PyTypeObject *tar = GetPybindType<T>();
55 if (tar == nullptr || tp == nullptr) {
56 return false;
57 }
58 return tp == tar || (sub ? PyType_IsSubtype(tp, tar) : false);
59 }
60
61 template <bool sub>
62 bool IsGradOperationType(PyTypeObject *tp);
63 template <bool sub>
64 bool IsVmapOperationType(PyTypeObject *tp);
65 template <bool sub>
66 bool IsShardType(PyTypeObject *tp);
67 template <bool sub>
68 bool IsStubTensorType(PyTypeObject *tp);
69 template <bool sub>
70 bool IsTensorType(PyTypeObject *tp);
71 template <bool sub>
72 bool IsCellListType(PyTypeObject *tp);
73 template <bool sub>
74 bool IsCellType(PyTypeObject *tp);
75 template <bool sub>
76 bool IsPrimitiveType(PyTypeObject *tp);
77 template <bool sub>
78 bool IsPrimitiveFunctionType(PyTypeObject *tp);
79 template <bool sub>
80 bool IsMetaFuncGraphType(PyTypeObject *tp);
81 template <bool sub>
82 bool IsMSDTypeType(PyTypeObject *tp);
83
84 bool FindTensorName(const std::string &name);
85
86 bool CheckTensorDataInitialized(const py::object &tensor);
87 py::object EvalMSAPIValue(const py::object &ms_api, const py::object &args, const py::object &key_words);
88
89 using SpecialPrimitiveInferFuncMap =
90 std::unordered_map<std::string, PyObject *(*)(PyObject *, const std::vector<PyObject *> &)>;
91 const SpecialPrimitiveInferFuncMap &GetSpecialPrimitiveInferFunc();
92
93 } // namespace pijit
94 } // namespace mindspore
95
96 #endif // MINDSPORE_PI_JIT_INFER_H
97