1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2020 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ 20 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ 21 22 #include <vector> 23 #include <string> 24 #include <unordered_map> 25 #include <utility> 26 #include <map> 27 #include <set> 28 #include <memory> 29 #include "pipeline/jit/static_analysis/static_analysis.h" 30 #include "utils/misc.h" 31 #include "ir/dtype.h" 32 #include "ir/meta_func_graph.h" 33 34 namespace mindspore { 35 // namespace to support composite operators definition 36 namespace prim { 37 class MultitypeFuncGraph : public MetaFuncGraph { 38 public: 39 explicit MultitypeFuncGraph(const std::string &name); 40 ~MultitypeFuncGraph() override = default; 41 MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) 42 43 using specialize_fn = FuncGraph *(*)(TypePtrList); 44 // Register a method which specialize based on types vectors; 45 virtual void Register(const TypePtrList &types, specialize_fn s_fn); 46 virtual void Register(const TypePtrList &types, const py::function &py_fn); 47 virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); 48 49 FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; GetPyFnCacheSize()50 size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } GetPyFunctions()51 const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const { 52 return fn_cache_py_; 53 } 54 55 private: 56 const std::pair<py::function, bool> SignMatch(const TypePtrList &types); 57 std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_; 58 std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_; 59 }; 60 using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>; 61 } // namespace prim 62 } // namespace mindspore 63 64 #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ 65