• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_EXPANDER_H_
17 #define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_EXPANDER_H_
#include <functional>
#include <memory>
#include <nlohmann/json.hpp>
#include "backend/common/graph_kernel/core/expander.h"
#include "ir/func_graph.h"
#include "include/backend/visible.h"
23 
24 namespace mindspore::graphkernel {
25 class ComplexOpDecorator : public ExpanderDecorator {
26  public:
ComplexOpDecorator(const ExpanderPtr & decorated)27   explicit ComplexOpDecorator(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
28   ~ComplexOpDecorator() override = default;
Creator(const ExpanderPtr & decorated)29   static ExpanderPtr Creator(const ExpanderPtr &decorated) {
30     return std::static_pointer_cast<Expander>(std::make_shared<ComplexOpDecorator>(decorated));
31   }
32   AnfNodePtr Run(const AnfNodePtr &node) override;
33 };
34 
35 class ArgWithValueDeco : public ExpanderDecorator {
36  public:
ArgWithValueDeco(const ExpanderPtr & decorated)37   explicit ArgWithValueDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
38   ~ArgWithValueDeco() override = default;
Creator(const ExpanderPtr & decorated)39   static ExpanderPtr Creator(const ExpanderPtr &decorated) {
40     return std::static_pointer_cast<Expander>(std::make_shared<ArgWithValueDeco>(decorated));
41   }
42   AnfNodePtr Run(const AnfNodePtr &node) override;
43 };
44 
45 class UnfoldMakeTupleDeco : public ExpanderDecorator {
46  public:
UnfoldMakeTupleDeco(const ExpanderPtr & decorated)47   explicit UnfoldMakeTupleDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
48   ~UnfoldMakeTupleDeco() override = default;
Creator(const ExpanderPtr & decorated)49   static ExpanderPtr Creator(const ExpanderPtr &decorated) {
50     return std::static_pointer_cast<Expander>(std::make_shared<UnfoldMakeTupleDeco>(decorated));
51   }
52   AnfNodePtr Run(const AnfNodePtr &node) override;
53 };
54 
55 class ProcessCustomOpDeco : public ExpanderDecorator {
56  public:
ProcessCustomOpDeco(const ExpanderPtr & decorated)57   explicit ProcessCustomOpDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
58   ~ProcessCustomOpDeco() override = default;
Creator(const ExpanderPtr & decorated)59   static ExpanderPtr Creator(const ExpanderPtr &decorated) {
60     return std::static_pointer_cast<Expander>(std::make_shared<ProcessCustomOpDeco>(decorated));
61   }
62   AnfNodePtr Run(const AnfNodePtr &node) override;
63 };
64 
65 class SetDynamicShapeAttrDeco : public ExpanderDecorator {
66  public:
SetDynamicShapeAttrDeco(const ExpanderPtr & decorated)67   explicit SetDynamicShapeAttrDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
68   ~SetDynamicShapeAttrDeco() override = default;
Creator(const ExpanderPtr & decorated)69   static ExpanderPtr Creator(const ExpanderPtr &decorated) {
70     return std::static_pointer_cast<Expander>(std::make_shared<SetDynamicShapeAttrDeco>(decorated));
71   }
72   AnfNodePtr Run(const AnfNodePtr &node) override;
73 };
74 
75 /**
76  * Get the Expander which is used to expand a cnode to a funcgraph which composite same function with core ops.
77  */
78 BACKEND_EXPORT ExpanderPtr GetExpander(const AnfNodePtr &node, const ExpanderPtr &init);
79 
80 /**
81  * Get the Expander which is used to expand a cnode to a funcgraph which composite same function with core ops.
82  */
83 BACKEND_EXPORT ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract = true);
84 
85 /**
86  * Inline the expanded func graph to main graph.
87  */
88 BACKEND_EXPORT void InlineExpandFuncGraph(const AnfNodePtr &expanding_node, const FuncGraphPtr &expanded_graph);
89 
90 /**
91  * Try Expand cnode with check func.
92  */
93 BACKEND_EXPORT AnfNodePtr TryExpandCNode(const AnfNodePtr &node, const std::function<bool(const CNodePtr &)> &func);
94 
95 /**
96  * Check if node can be expanded fallback.
97  */
98 BACKEND_EXPORT bool CanExpandFallback(const AnfNodePtr &node);
99 
100 bool IsComplexOp(const AnfNodePtr &node);
101 }  // namespace mindspore::graphkernel
102 #endif  // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_EXPANDER_H_
103