1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2020-2022 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ 20 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ 21 22 #include <vector> 23 #include <string> 24 #include <tuple> 25 #include <utility> 26 #include <map> 27 #include <set> 28 #include <memory> 29 #include <algorithm> 30 #include "pipeline/jit/ps/static_analysis/static_analysis.h" 31 #include "utils/misc.h" 32 #include "ir/dtype.h" 33 #include "ir/meta_func_graph.h" 34 #include "pipeline/jit/ps/parse/parse_base.h" 35 36 namespace mindspore { 37 namespace prim { 38 class MultitypeFuncGraph : public MetaFuncGraph { 39 public: 40 explicit MultitypeFuncGraph(const std::string &name); 41 ~MultitypeFuncGraph() override = default; 42 MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) 43 44 using specialize_fn = FuncGraph *(*)(TypePtrList); 45 // Register a method which specialize based on types vectors. 46 virtual void Register(const TypePtrList &types, specialize_fn s_fn); 47 virtual void Register(const TypePtrList &types, const py::function &py_fn); 48 virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); 49 set_doc_url(const std::string & doc_url)50 void set_doc_url(const std::string &doc_url) { doc_url_ = doc_url; } set_need_raise()51 void set_need_raise() { need_raise_ = true; } set_meta_obj(const py::object & obj)52 void set_meta_obj(const py::object &obj) { meta_obj_ = obj; } 53 FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; GetPyFnCacheSize()54 size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } GetPyFunctions()55 const TypeListMap<py::function> &GetPyFunctions() const { return fn_cache_py_; } 56 57 private: 58 const std::tuple<py::function, bool, size_t> SignMatch(const TypePtrList &types); 59 const std::string PrintMatchFailLog(const TypeListMap<py::function>, const TypePtrList &types, size_t match_max_idx, 60 bool has_any); 61 TypeListMap<specialize_fn> fn_cache_; 62 TypeListMap<py::function> fn_cache_py_; 63 std::string doc_url_; 64 py::object meta_obj_ = py::none(); 65 bool need_raise_ = false; 66 }; 67 using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>; 68 bool CheckDictContainsAny(const std::vector<std::pair<mindspore::ValuePtr, mindspore::TypePtr>> &key_values); 69 bool CheckContainsAny(const TypePtrList &types); 70 } // namespace prim 71 } // namespace mindspore 72 73 #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ 74