• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_CORE_EXPANDER_H_
17 #define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_CORE_EXPANDER_H_
#include <functional>
#include <memory>
#include <vector>
#include "ir/func_graph.h"
#include "include/backend/visible.h"
#include "backend/common/graph_kernel/core/graph_kernel_callback.h"
23 
24 namespace mindspore::graphkernel {
25 class BACKEND_EXPORT Expander {
26  public:
27   /**
28    * Expand input cnode to a funcgraph which composite same function with core ops,
29    * and return a cnode which input[0] is the funcgraph and input[1:-1] are inputs.
30    */
31   virtual AnfNodePtr Run(const AnfNodePtr &node) = 0;
32   virtual ~Expander() = default;
33 };
34 using ExpanderPtr = std::shared_ptr<Expander>;
35 
36 class DefaultExpander : public Expander {
37  public:
DefaultExpander(const CallbackPtr & cb)38   explicit DefaultExpander(const CallbackPtr &cb) : cb_(cb) {}
39   ~DefaultExpander() override = default;
40   AnfNodePtr Run(const AnfNodePtr &node) override;
41 
42  protected:
43   virtual FuncGraphPtr ExpandToGraph(const CNodePtr &node);
44   virtual AnfNodePtr CreateCallCNode(const FuncGraphPtr &fg, const CNodePtr &cnode);
45   CallbackPtr cb_;
46 };
47 
48 class LitegraphExpander : public DefaultExpander {
49  public:
LitegraphExpander(const CallbackPtr & cb)50   explicit LitegraphExpander(const CallbackPtr &cb) : DefaultExpander(cb) {}
51   ~LitegraphExpander() override = default;
52 
53  protected:
54   FuncGraphPtr ExpandToGraph(const CNodePtr &node) override;
55   AnfNodePtr CreateCallCNode(const FuncGraphPtr &sub_fg, const CNodePtr &cnode) override;
56 };
57 
58 class BACKEND_EXPORT ExpanderDecorator : public Expander {
59  public:
ExpanderDecorator(const ExpanderPtr & decorated)60   explicit ExpanderDecorator(const ExpanderPtr &decorated) : decorated_(decorated) {}
61   ~ExpanderDecorator() override = default;
62   /**
63    * Do something before or after decoreated run.
64    */
65   AnfNodePtr Run(const AnfNodePtr &node) override;
66 
67  protected:
68   // The expander cannot change the original node, this function clone the cnode with original info.
69   CNodePtr QuickCloneCNode(const AnfNodePtr &node, bool clone_prim = false) const;
70 
71   ExpanderPtr decorated_;
72 };
73 
74 using ExpanderCreatorFunc = std::function<ExpanderPtr(const ExpanderPtr &)>;
75 using ExpanderCreatorFuncList = std::vector<ExpanderCreatorFunc>;
76 
77 // This decorator is required if we need to get the value of one input parameter during expanding
78 class BACKEND_EXPORT DependValueDeco : public ExpanderDecorator {
79  public:
DependValueDeco(const ExpanderPtr & decorated,const HashSet<size_t> & input_idx)80   DependValueDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx)
81       : ExpanderDecorator(decorated), input_idx_(input_idx) {}
82   ~DependValueDeco() = default;
83 
GetCreator(const HashSet<size_t> & input_idx)84   static ExpanderCreatorFunc GetCreator(const HashSet<size_t> &input_idx) {
85     return [input_idx](const ExpanderPtr &decorated) {
86       return std::static_pointer_cast<Expander>(std::make_shared<DependValueDeco>(decorated, input_idx));
87     };
88   }
89   AnfNodePtr Run(const AnfNodePtr &node) override;
90 
91  protected:
92   HashSet<size_t> input_idx_;
93 };
94 /**
95  * Wrap Expander with decorators.
96  */
97 BACKEND_EXPORT ExpanderPtr WrapExpander(const ExpanderPtr &base, const ExpanderCreatorFuncList &deco_creators);
98 }  // namespace mindspore::graphkernel
99 #endif  // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_CORE_EXPANDER_H_
100