1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_CORE_EXPANDER_H_ 17 #define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_CORE_EXPANDER_H_ 18 #include <vector> 19 #include <memory> 20 #include "ir/func_graph.h" 21 #include "include/backend/visible.h" 22 #include "backend/common/graph_kernel/core/graph_kernel_callback.h" 23 24 namespace mindspore::graphkernel { 25 class BACKEND_EXPORT Expander { 26 public: 27 /** 28 * Expand input cnode to a funcgraph which composite same function with core ops, 29 * and return a cnode which input[0] is the funcgraph and input[1:-1] are inputs. 30 */ 31 virtual AnfNodePtr Run(const AnfNodePtr &node) = 0; 32 virtual ~Expander() = default; 33 }; 34 using ExpanderPtr = std::shared_ptr<Expander>; 35 36 class DefaultExpander : public Expander { 37 public: DefaultExpander(const CallbackPtr & cb)38 explicit DefaultExpander(const CallbackPtr &cb) : cb_(cb) {} 39 ~DefaultExpander() override = default; 40 AnfNodePtr Run(const AnfNodePtr &node) override; 41 42 protected: 43 virtual FuncGraphPtr ExpandToGraph(const CNodePtr &node); 44 virtual AnfNodePtr CreateCallCNode(const FuncGraphPtr &fg, const CNodePtr &cnode); 45 CallbackPtr cb_; 46 }; 47 48 class LitegraphExpander : public DefaultExpander { 49 public: LitegraphExpander(const CallbackPtr & cb)50 explicit LitegraphExpander(const CallbackPtr &cb) : DefaultExpander(cb) {} 51 ~LitegraphExpander() override = default; 52 53 protected: 54 FuncGraphPtr ExpandToGraph(const CNodePtr &node) override; 55 AnfNodePtr CreateCallCNode(const FuncGraphPtr &sub_fg, const CNodePtr &cnode) override; 56 }; 57 58 class BACKEND_EXPORT ExpanderDecorator : public Expander { 59 public: ExpanderDecorator(const ExpanderPtr & decorated)60 explicit ExpanderDecorator(const ExpanderPtr &decorated) : decorated_(decorated) {} 61 ~ExpanderDecorator() override = default; 62 /** 63 * Do something before or after decoreated run. 64 */ 65 AnfNodePtr Run(const AnfNodePtr &node) override; 66 67 protected: 68 // The expander cannot change the original node, this function clone the cnode with original info. 69 CNodePtr QuickCloneCNode(const AnfNodePtr &node, bool clone_prim = false) const; 70 71 ExpanderPtr decorated_; 72 }; 73 74 using ExpanderCreatorFunc = std::function<ExpanderPtr(const ExpanderPtr &)>; 75 using ExpanderCreatorFuncList = std::vector<ExpanderCreatorFunc>; 76 77 // This decorator is required if we need to get the value of one input parameter during expanding 78 class BACKEND_EXPORT DependValueDeco : public ExpanderDecorator { 79 public: DependValueDeco(const ExpanderPtr & decorated,const HashSet<size_t> & input_idx)80 DependValueDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx) 81 : ExpanderDecorator(decorated), input_idx_(input_idx) {} 82 ~DependValueDeco() = default; 83 GetCreator(const HashSet<size_t> & input_idx)84 static ExpanderCreatorFunc GetCreator(const HashSet<size_t> &input_idx) { 85 return [input_idx](const ExpanderPtr &decorated) { 86 return std::static_pointer_cast<Expander>(std::make_shared<DependValueDeco>(decorated, input_idx)); 87 }; 88 } 89 AnfNodePtr Run(const AnfNodePtr &node) override; 90 91 protected: 92 HashSet<size_t> input_idx_; 93 }; 94 /** 95 * Wrap Expander with decorators. 96 */ 97 BACKEND_EXPORT ExpanderPtr WrapExpander(const ExpanderPtr &base, const ExpanderCreatorFuncList &deco_creators); 98 } // namespace mindspore::graphkernel 99 #endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_CORE_EXPANDER_H_ 100