• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2020-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_
20 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_
21 
22 #include <vector>
23 #include <string>
24 #include <tuple>
25 #include <utility>
26 #include <map>
27 #include <set>
28 #include <memory>
29 #include <algorithm>
30 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
31 #include "utils/misc.h"
32 #include "ir/dtype.h"
33 #include "ir/meta_func_graph.h"
34 #include "pipeline/jit/ps/parse/parse_base.h"
35 
36 namespace mindspore {
37 namespace prim {
38 class MultitypeFuncGraph : public MetaFuncGraph {
39  public:
40   explicit MultitypeFuncGraph(const std::string &name);
41   ~MultitypeFuncGraph() override = default;
42   MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph)
43 
44   using specialize_fn = FuncGraph *(*)(TypePtrList);
45   // Register a method which specialize based on types vectors.
46   virtual void Register(const TypePtrList &types, specialize_fn s_fn);
47   virtual void Register(const TypePtrList &types, const py::function &py_fn);
48   virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn);
49 
set_doc_url(const std::string & doc_url)50   void set_doc_url(const std::string &doc_url) { doc_url_ = doc_url; }
set_need_raise()51   void set_need_raise() { need_raise_ = true; }
set_meta_obj(const py::object & obj)52   void set_meta_obj(const py::object &obj) { meta_obj_ = obj; }
53   FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override;
GetPyFnCacheSize()54   size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); }
GetPyFunctions()55   const TypeListMap<py::function> &GetPyFunctions() const { return fn_cache_py_; }
56 
57  private:
58   const std::tuple<py::function, bool, size_t> SignMatch(const TypePtrList &types);
59   const std::string PrintMatchFailLog(const TypeListMap<py::function>, const TypePtrList &types, size_t match_max_idx,
60                                       bool has_any);
61   TypeListMap<specialize_fn> fn_cache_;
62   TypeListMap<py::function> fn_cache_py_;
63   std::string doc_url_;
64   py::object meta_obj_ = py::none();
65   bool need_raise_ = false;
66 };
67 using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>;
68 bool CheckDictContainsAny(const std::vector<std::pair<mindspore::ValuePtr, mindspore::TypePtr>> &key_values);
69 bool CheckContainsAny(const TypePtrList &types);
70 }  // namespace prim
71 }  // namespace mindspore
72 
#endif  // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_
74