• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "frontend/expander/bprop/bprop_meta_func_graph.h"
17 
18 #include <unordered_set>
19 #include <vector>
20 #include "include/common/utils/utils.h"
21 #include "frontend/expander/bprop/bprop.h"
22 #include "include/common/utils/python_adapter.h"
23 
24 namespace mindspore {
25 namespace expander {
26 namespace bprop {
NewGraph(const AbstractBasePtrList & abs_list)27 FuncGraphPtr NewGraph(const AbstractBasePtrList &abs_list) {
28   auto fg = std::make_shared<FuncGraph>();
29   for (const auto &abs : abs_list) {
30     auto para = fg->add_parameter();
31     para->set_abstract(abs);
32   }
33   return fg;
34 }
35 
GenerateFuncGraph(const abstract::AbstractBasePtrList & input_abs)36 FuncGraphPtr BpropMetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &input_abs) {
37   auto fg = NewGraph(input_abs);
38   try {
39     if (!expander::bprop::ExpandBpropInGraphMode(handle_, primal_, fg)) {
40       return nullptr;
41     }
42   } catch (const py::type_error &ex) {
43     MS_EXCEPTION(TypeError) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << ex.what() << "]";
44   } catch (const py::value_error &ex) {
45     MS_EXCEPTION(ValueError) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << ex.what() << "]";
46   } catch (const std::exception &e) {
47     MS_LOG(EXCEPTION) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << e.what() << "]";
48   }
49   return fg;
50 }
51 
// Primitives that must never be expanded through the C++ bprop expander, even when a
// builder is registered for them. CanExpand() rejects every name listed here.
static const std::unordered_set<std::string> g_blacklist = {
    "SparseGatherV2", "EmbeddingLookup", "AffineGrid", "ScatterAddWithAxis",
    "Expand", "AllReduce", "AllGather", "_MirrorOperator",
    "Load", "UpdateState", "Depend", "ParallelResizeBilinear",
    "MatrixSolve", "CholeskySolve", "CumulativeLogsumexp", "AvgPoolV1",
    "Eigh", "SparseAdd", "CSRReduceSum", "CSRMV",
    "CSRMul", "CSRDiv", "COOTensorGetIndices", "COOTensorGetValues",
    "COOTensorGetDenseShape", "CSRTensorGetIndptr", "CSRTensorGetIndices", "CSRTensorGetValues",
    "CSRTensorGetDenseShape", "CSRSparseMatrixToDense", "DenseToCSRSparseMatrix", "SparseSegmentSqrtN",
    "SparseSegmentSqrtNWithNumSegments", "SparseSegmentSum", "SparseSegmentSumWithNumSegments",
    "SparseSegmentMeanWithNumSegments", "SparseReorder", "SparseDenseCwiseMul", "SparseDenseCwiseDiv",
    "RaggedTensorToSparse"};
CanExpand(const std::string & name)92 bool CanExpand(const std::string &name) {
93   if (OpEnvManager::UsePyBprop(name)) {
94     return false;
95   }
96   if (g_blacklist.count(name) != 0) {
97     return false;
98   }
99   return true;
100 }
101 
GetBpropMetaFuncGraph(const PrimitivePtr & primal,const CNodePtr & cnode)102 FuncGraphPtr GetBpropMetaFuncGraph(const PrimitivePtr &primal, const CNodePtr &cnode) {
103   auto prim_name = primal->name();
104   const BpropHandle *handle = BpropIRBuilderFactory::Instance().GetBuilder(prim_name);
105   if (!CanExpand(prim_name) || handle == nullptr) {
106     return nullptr;
107   }
108   size_t forward_inputs_size = 0;
109   if (cnode) {
110     std::vector<AnfNodePtr> node_lists = cnode->inputs();
111     forward_inputs_size = cnode->size() - 1;
112     for (size_t i = 1; i < node_lists.size(); i++) {
113       auto input_i = node_lists[i];
114       if (HasAbstractMonad(input_i)) {
115         --forward_inputs_size;
116       }
117     }
118   } else {
119     const auto &op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
120     const auto iter = op_primc_fns.find(prim_name);
121     if (iter == op_primc_fns.end()) {
122       MS_LOG(EXCEPTION) << "The " << prim_name << " operator is not registered";
123     }
124     auto primc = iter->second();
125     forward_inputs_size = GetValue<std::vector<std::string>>(primc->GetAttr(kAttrInputNames)).size();
126   }
127   auto fg = std::make_shared<FuncGraph>();
128   auto meta_graph = std::make_shared<BpropMetaFuncGraph>(primal, handle);
129   std::vector<AnfNodePtr> inputs{NewValueNode(meta_graph)};
130   for (size_t i = 0; i < forward_inputs_size; ++i) {
131     (void)inputs.emplace_back(fg->add_parameter());
132   }
133   (void)inputs.emplace_back(fg->add_parameter());
134   (void)inputs.emplace_back(fg->add_parameter());
135   fg->set_output(fg->NewCNode(inputs));
136   fg->set_flag(mindspore::kFuncGraphFlagMetaFuncGraphBprop, true);
137   if (GetPrimitiveFlag(primal, GRAPH_FLAG_SIDE_EFFECT_BACKPROP)) {
138     fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
139   }
140   return fg;
141 }
142 }  // namespace bprop
143 }  // namespace expander
144 }  // namespace mindspore
145