• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/expander/utils.h"
18 
19 #include <memory>
20 #include <string>
21 #include <vector>
22 #include <map>
23 #include <set>
24 #include "ops/nn_op_name.h"
25 #include "ops/structure_ops.h"
26 #include "ops/op_def.h"
27 #include "ops/math_ops.h"
28 #include "ops/array_ops.h"
29 #include "mindspore/core/utils/anf_utils.h"
30 #include "frontend/parallel/auto_parallel/costmodel.h"
31 #include "frontend/parallel/graph_util/generate_graph.h"
32 #include "frontend/operator/ops_front_infer_function.h"
33 #include "frontend/expander/bprop/bprop.h"
34 #include "pybind_api/ir/primitive_py.h"
35 #include "backend/common/graph_kernel/adapter/expander.h"
36 #include "utils/ms_context.h"
37 #include "include/common/utils/utils.h"
38 #include "include/common/debug/anf_ir_dump.h"
39 #include "ir/func_graph_cloner.h"
40 
41 namespace mindspore {
42 /* namespace to support expander */
43 namespace expander {
44 namespace {
45 const std::map<std::string, std::vector<std::string>> op2attrs = {
46   {kBroadcastOpName, {kAttrShape}},
47   {kReduceMaxOpName, {kAttrKeepDims}},
48   {kReduceMinOpName, {kAttrKeepDims}},
49   {kReduceSumOpName, {kAttrKeepDims}},
50   {kMatMulOpName, {kTransposeA, kTransposeB}},
51   {kConcatOpName, {kAttrAxis}},
52   {kSqueezeOpName, {kAttrAxis}},
53   {kOneHotOpName, {kAttrAxis}},
54   {kSoftmaxOpName, {kAttrAxis}},
55   {kSplitOpName, {kAttrAxis}},
56   {kLayerNormOpName, {kAttrBeginNormAxis, kAttrBeginParamsAxis, kAttrEpsilon}},
57   {kStridedSliceOpName, {kAttrBeginMask, kAttrEndMask, kAttrEllipsisMask, kAttrNewAxisMask, kAttrShrinkAxisMask}},
58   {kLayerNormGradOpName, {kAttrBeginNormAxis, kAttrBeginParamsAxis}},
59   {kLayerNormGradGradOpName, {kAttrBeginNormAxis, kAttrBeginParamsAxis}},
60   {kBiasAddOpName, {kAttrDataFormat}},
61   {kBiasAddGradOpName, {kAttrDataFormat}},
62   {kStackOpName, {kAttrAxis}},
63   {kBatchMatMulOpName, {kTransposeA, kTransposeB}}};
64 }  // namespace
65 
ConvertPrimToPrimPy(const PrimitivePtr & primc)66 ValuePtr ConvertPrimToPrimPy(const PrimitivePtr &primc) {
67   if (primc == nullptr || primc->isa<PrimitivePy>()) {
68     return nullptr;
69   }
70   // If it is primitive function, no need convert because primitive function are all C++ infer.
71   if (mindspore::ops::IsPrimitiveFunction(primc->name())) {
72     return nullptr;
73   }
74   if (abstract::GetFrontendPrimitiveInferImpl(primc).has_value()) {
75     return nullptr;
76   }
77   if (primc->isa<prim::DoSignaturePrimitive>()) {
78     return nullptr;
79   }
80   const auto &primpy_cache = OpPrimPyRegister::GetInstance().GetPrimPyMap();
81   if (auto it = primpy_cache.find(primc->name()); it != primpy_cache.end()) {
82     return it->second;
83   }
84   parallel::OperatorAttrs attrs;
85   const auto iter = op2attrs.find(primc->name());
86   if (iter != op2attrs.end()) {
87     for (auto &attr : iter->second) {
88       if (primc->HasAttr(attr)) {
89         (void)attrs.emplace_back(std::pair{attr, primc->GetAttr(attr)});
90       } else {
91         MS_LOG(WARNING) << primc->name() << " op do not have attr: " << attr;
92         return nullptr;
93       }
94     }
95   }
96   auto new_prim = parallel::CreateOpInstance(attrs, primc->name(), "");
97   MS_EXCEPTION_IF_NULL(new_prim);
98   (void)new_prim->cast<PrimitivePtr>()->SetAttrs(primc->attrs());
99   // prim can be cached when prim has no attrs
100   constexpr size_t kOnlyIONames = 2;
101   if ((primc->attrs().size() == kOnlyIONames) && primc->HasAttr("input_names") && primc->HasAttr("output_names")) {
102     OpPrimPyRegister::GetInstance().SetPrimPyMap(primc->name(), new_prim);
103   }
104   return new_prim;
105 }
106 
107 class PrimpyConverter {
108  public:
Run(const FuncGraphPtr & graph)109   bool Run(const FuncGraphPtr &graph) {
110     MS_EXCEPTION_IF_NULL(graph);
111     (void)visited_graphs_.insert(graph);
112     auto todos = TopoSort(graph->get_return());
113     auto mng = Manage({graph}, false);
114     for (const auto &node : todos) {
115       if (node->isa<ValueNode>()) {
116         auto sub_graph = node->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
117         if (sub_graph != nullptr && visited_graphs_.count(sub_graph) == 0) {
118           (void)Run(sub_graph);
119           continue;
120         }
121       }
122       if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
123         continue;
124       }
125       auto primitive = GetCNodePrimitive(node);
126       auto new_prim = ConvertPrimToPrimPy(primitive);
127       AnfNodePtrList inputs = {NewValueNode(new_prim)};
128       auto cnode = dyn_cast_ptr<CNode>(node);
129       auto cnode_inputs = cnode->inputs();
130       (void)inputs.insert(inputs.cend(), cnode_inputs.cbegin() + 1, cnode_inputs.cend());
131       auto new_cnode = graph->NewCNodeInOrder(inputs);
132       (void)mng->Replace(node, new_cnode);
133     }
134     return true;
135   }
136 
137  private:
138   std::set<FuncGraphPtr> visited_graphs_;
139 };
140 
ConvertPrimToPrimPy(const FuncGraphPtr & graph)141 bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
142   PrimpyConverter c;
143   return c.Run(graph);
144 }
145 
146 using graphkernel::ExpanderDecorator;
147 using graphkernel::ExpanderPtr;
148 class PrimToPrimPyDecorator : public ExpanderDecorator {
149  public:
PrimToPrimPyDecorator(const ExpanderPtr & decorated)150   explicit PrimToPrimPyDecorator(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
151   ~PrimToPrimPyDecorator() override = default;
Creator(const ExpanderPtr & decorated)152   static ExpanderPtr Creator(const ExpanderPtr &decorated) {
153     return std::static_pointer_cast<Expander>(std::make_shared<PrimToPrimPyDecorator>(decorated));
154   }
Run(const AnfNodePtr & node)155   AnfNodePtr Run(const AnfNodePtr &node) override {
156     auto new_node = decorated_->Run(node);
157     if (new_node == nullptr) {
158       return nullptr;
159     }
160     auto new_cnode = dyn_cast<CNode>(new_node);
161     auto expand_fg = GetCNodeFuncGraph(new_cnode);
162     if (!ConvertPrimToPrimPy(expand_fg)) {
163       return nullptr;
164     }
165     new_cnode->set_input(0, NewValueNode(expand_fg));
166     return new_cnode;
167   }
168 };
169 
// Attempt a frontend expansion of `node` via the graph-kernel fallback
// expander, converting the primitives of the expanded graph to python ones.
// Returns the expanded node, or nullptr when expansion is not possible.
AnfNodePtr TryExpandCNodeFE(const AnfNodePtr &node) {
  if (!graphkernel::CanExpandFallback(node)) {
    return nullptr;
  }
  auto primitive = GetCNodePrimitive(node);
  if (primitive == nullptr) {
    return nullptr;
  }
  // Wrap the base expander so the expanded graph also gets prim-to-primpy conversion.
  auto expander = PrimToPrimPyDecorator::Creator(graphkernel::GetExpander(node));
  auto new_node = expander->Run(node);
  auto expand_fg = GetCNodeFuncGraph(new_node);
  if (expand_fg == nullptr) {
    return nullptr;
  }
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kIntroductory)) {
    DumpIR("expand_fe_" + GetCNodeFuncName(node->cast<CNodePtr>()) + ".ir", expand_fg);
  }
#endif
  return new_node;
}
194 
ClearAllCache()195 void ClearAllCache() { bprop::ClearBpropOpGraphMap(); }
196 }  // namespace expander
197 }  // namespace mindspore
198