1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/expander/utils.h"
18
19 #include <memory>
20 #include <string>
21 #include <vector>
22 #include <map>
23 #include <set>
24 #include "ops/nn_op_name.h"
25 #include "ops/structure_ops.h"
26 #include "ops/op_def.h"
27 #include "ops/math_ops.h"
28 #include "ops/array_ops.h"
29 #include "mindspore/core/utils/anf_utils.h"
30 #include "frontend/parallel/auto_parallel/costmodel.h"
31 #include "frontend/parallel/graph_util/generate_graph.h"
32 #include "frontend/operator/ops_front_infer_function.h"
33 #include "frontend/expander/bprop/bprop.h"
34 #include "pybind_api/ir/primitive_py.h"
35 #include "backend/common/graph_kernel/adapter/expander.h"
36 #include "utils/ms_context.h"
37 #include "include/common/utils/utils.h"
38 #include "include/common/debug/anf_ir_dump.h"
39 #include "ir/func_graph_cloner.h"
40
41 namespace mindspore {
42 /* namespace to support expander */
43 namespace expander {
44 namespace {
45 const std::map<std::string, std::vector<std::string>> op2attrs = {
46 {kBroadcastOpName, {kAttrShape}},
47 {kReduceMaxOpName, {kAttrKeepDims}},
48 {kReduceMinOpName, {kAttrKeepDims}},
49 {kReduceSumOpName, {kAttrKeepDims}},
50 {kMatMulOpName, {kTransposeA, kTransposeB}},
51 {kConcatOpName, {kAttrAxis}},
52 {kSqueezeOpName, {kAttrAxis}},
53 {kOneHotOpName, {kAttrAxis}},
54 {kSoftmaxOpName, {kAttrAxis}},
55 {kSplitOpName, {kAttrAxis}},
56 {kLayerNormOpName, {kAttrBeginNormAxis, kAttrBeginParamsAxis, kAttrEpsilon}},
57 {kStridedSliceOpName, {kAttrBeginMask, kAttrEndMask, kAttrEllipsisMask, kAttrNewAxisMask, kAttrShrinkAxisMask}},
58 {kLayerNormGradOpName, {kAttrBeginNormAxis, kAttrBeginParamsAxis}},
59 {kLayerNormGradGradOpName, {kAttrBeginNormAxis, kAttrBeginParamsAxis}},
60 {kBiasAddOpName, {kAttrDataFormat}},
61 {kBiasAddGradOpName, {kAttrDataFormat}},
62 {kStackOpName, {kAttrAxis}},
63 {kBatchMatMulOpName, {kTransposeA, kTransposeB}}};
64 } // namespace
65
ConvertPrimToPrimPy(const PrimitivePtr & primc)66 ValuePtr ConvertPrimToPrimPy(const PrimitivePtr &primc) {
67 if (primc == nullptr || primc->isa<PrimitivePy>()) {
68 return nullptr;
69 }
70 // If it is primitive function, no need convert because primitive function are all C++ infer.
71 if (mindspore::ops::IsPrimitiveFunction(primc->name())) {
72 return nullptr;
73 }
74 if (abstract::GetFrontendPrimitiveInferImpl(primc).has_value()) {
75 return nullptr;
76 }
77 if (primc->isa<prim::DoSignaturePrimitive>()) {
78 return nullptr;
79 }
80 const auto &primpy_cache = OpPrimPyRegister::GetInstance().GetPrimPyMap();
81 if (auto it = primpy_cache.find(primc->name()); it != primpy_cache.end()) {
82 return it->second;
83 }
84 parallel::OperatorAttrs attrs;
85 const auto iter = op2attrs.find(primc->name());
86 if (iter != op2attrs.end()) {
87 for (auto &attr : iter->second) {
88 if (primc->HasAttr(attr)) {
89 (void)attrs.emplace_back(std::pair{attr, primc->GetAttr(attr)});
90 } else {
91 MS_LOG(WARNING) << primc->name() << " op do not have attr: " << attr;
92 return nullptr;
93 }
94 }
95 }
96 auto new_prim = parallel::CreateOpInstance(attrs, primc->name(), "");
97 MS_EXCEPTION_IF_NULL(new_prim);
98 (void)new_prim->cast<PrimitivePtr>()->SetAttrs(primc->attrs());
99 // prim can be cached when prim has no attrs
100 constexpr size_t kOnlyIONames = 2;
101 if ((primc->attrs().size() == kOnlyIONames) && primc->HasAttr("input_names") && primc->HasAttr("output_names")) {
102 OpPrimPyRegister::GetInstance().SetPrimPyMap(primc->name(), new_prim);
103 }
104 return new_prim;
105 }
106
107 class PrimpyConverter {
108 public:
Run(const FuncGraphPtr & graph)109 bool Run(const FuncGraphPtr &graph) {
110 MS_EXCEPTION_IF_NULL(graph);
111 (void)visited_graphs_.insert(graph);
112 auto todos = TopoSort(graph->get_return());
113 auto mng = Manage({graph}, false);
114 for (const auto &node : todos) {
115 if (node->isa<ValueNode>()) {
116 auto sub_graph = node->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
117 if (sub_graph != nullptr && visited_graphs_.count(sub_graph) == 0) {
118 (void)Run(sub_graph);
119 continue;
120 }
121 }
122 if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
123 continue;
124 }
125 auto primitive = GetCNodePrimitive(node);
126 auto new_prim = ConvertPrimToPrimPy(primitive);
127 AnfNodePtrList inputs = {NewValueNode(new_prim)};
128 auto cnode = dyn_cast_ptr<CNode>(node);
129 auto cnode_inputs = cnode->inputs();
130 (void)inputs.insert(inputs.cend(), cnode_inputs.cbegin() + 1, cnode_inputs.cend());
131 auto new_cnode = graph->NewCNodeInOrder(inputs);
132 (void)mng->Replace(node, new_cnode);
133 }
134 return true;
135 }
136
137 private:
138 std::set<FuncGraphPtr> visited_graphs_;
139 };
140
ConvertPrimToPrimPy(const FuncGraphPtr & graph)141 bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
142 PrimpyConverter c;
143 return c.Run(graph);
144 }
145
146 using graphkernel::ExpanderDecorator;
147 using graphkernel::ExpanderPtr;
148 class PrimToPrimPyDecorator : public ExpanderDecorator {
149 public:
PrimToPrimPyDecorator(const ExpanderPtr & decorated)150 explicit PrimToPrimPyDecorator(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
151 ~PrimToPrimPyDecorator() override = default;
Creator(const ExpanderPtr & decorated)152 static ExpanderPtr Creator(const ExpanderPtr &decorated) {
153 return std::static_pointer_cast<Expander>(std::make_shared<PrimToPrimPyDecorator>(decorated));
154 }
Run(const AnfNodePtr & node)155 AnfNodePtr Run(const AnfNodePtr &node) override {
156 auto new_node = decorated_->Run(node);
157 if (new_node == nullptr) {
158 return nullptr;
159 }
160 auto new_cnode = dyn_cast<CNode>(new_node);
161 auto expand_fg = GetCNodeFuncGraph(new_cnode);
162 if (!ConvertPrimToPrimPy(expand_fg)) {
163 return nullptr;
164 }
165 new_cnode->set_input(0, NewValueNode(expand_fg));
166 return new_cnode;
167 }
168 };
169
/**
 * Expands `node` in the frontend with the graph-kernel expander, converting the primitives of the expanded graph
 * into Python primitives.
 *
 * @param node The node to expand.
 * @return The expanded node (a call of the expanded graph), or nullptr in any of these cases: the node cannot be
 *         expanded, it has no primitive, or the expansion does not yield a graph call.
 */
AnfNodePtr TryExpandCNodeFE(const AnfNodePtr &node) {
  const bool expandable = graphkernel::CanExpandFallback(node) && GetCNodePrimitive(node) != nullptr;
  if (!expandable) {
    return nullptr;
  }
  auto decorated_expander = PrimToPrimPyDecorator::Creator(graphkernel::GetExpander(node));
  auto expanded = decorated_expander->Run(node);
  auto expanded_graph = GetCNodeFuncGraph(expanded);
  if (expanded_graph == nullptr) {
    return nullptr;
  }
#ifdef ENABLE_DUMP_IR
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->CanDump(kIntroductory)) {
    DumpIR("expand_fe_" + GetCNodeFuncName(node->cast<CNodePtr>()) + ".ir", expanded_graph);
  }
#endif
  return expanded;
}
194
ClearAllCache()195 void ClearAllCache() { bprop::ClearBpropOpGraphMap(); }
196 } // namespace expander
197 } // namespace mindspore
198