• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_INFER_H
17 #define MINDSPORE_PI_JIT_INFER_H
18 
19 #include <memory>
20 #include <vector>
21 #include <string>
22 #include <unordered_map>
23 #include "pybind11/pybind11.h"
24 
25 namespace mindspore {
26 namespace pijit {
27 
28 class InferEngine : public std::enable_shared_from_this<InferEngine> {
29  public:
30   static std::shared_ptr<InferEngine> GetInstance();
31   PyObject *InferPrimitive(PyObject *primitive, const std::vector<PyObject *> &args, bool *is_abstract);
32   PyObject *InferSpecialPrimitive(PyObject *primitive, const std::vector<PyObject *> &arglist);
33   bool SupportInfer(PyObject *primitive);
34   bool Init();
35   bool Deinit();
36 
37  protected:
38   InferEngine();
39   bool bInit_ = false;
40 };
41 
42 using InferEnginePtr = std::shared_ptr<InferEngine>;
43 
44 namespace py = pybind11;
45 
46 template <class T>
GetPybindType()47 PyTypeObject *GetPybindType() {
48   py::handle mapped_type = py::detail::get_type_handle(typeid(T), false);
49   return reinterpret_cast<PyTypeObject *>(mapped_type.ptr());
50 }
51 
52 template <class T, bool sub = true>
IsPybindType(PyTypeObject * tp)53 bool IsPybindType(PyTypeObject *tp) {
54   PyTypeObject *tar = GetPybindType<T>();
55   if (tar == nullptr || tp == nullptr) {
56     return false;
57   }
58   return tp == tar || (sub ? PyType_IsSubtype(tp, tar) : false);
59 }
60 
61 template <bool sub>
62 bool IsGradOperationType(PyTypeObject *tp);
63 template <bool sub>
64 bool IsVmapOperationType(PyTypeObject *tp);
65 template <bool sub>
66 bool IsShardType(PyTypeObject *tp);
67 template <bool sub>
68 bool IsStubTensorType(PyTypeObject *tp);
69 template <bool sub>
70 bool IsTensorType(PyTypeObject *tp);
71 template <bool sub>
72 bool IsCellListType(PyTypeObject *tp);
73 template <bool sub>
74 bool IsCellType(PyTypeObject *tp);
75 template <bool sub>
76 bool IsPrimitiveType(PyTypeObject *tp);
77 template <bool sub>
78 bool IsPrimitiveFunctionType(PyTypeObject *tp);
79 template <bool sub>
80 bool IsMetaFuncGraphType(PyTypeObject *tp);
81 template <bool sub>
82 bool IsMSDTypeType(PyTypeObject *tp);
83 
84 bool FindTensorName(const std::string &name);
85 
86 bool CheckTensorDataInitialized(const py::object &tensor);
87 py::object EvalMSAPIValue(const py::object &ms_api, const py::object &args, const py::object &key_words);
88 
89 using SpecialPrimitiveInferFuncMap =
90   std::unordered_map<std::string, PyObject *(*)(PyObject *, const std::vector<PyObject *> &)>;
91 const SpecialPrimitiveInferFuncMap &GetSpecialPrimitiveInferFunc();
92 
93 }  // namespace pijit
94 }  // namespace mindspore
95 
96 #endif  // MINDSPORE_PI_JIT_INFER_H
97