1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "frontend/expander/bprop/bprop_meta_func_graph.h"
17
18 #include <unordered_set>
19 #include <vector>
20 #include "include/common/utils/utils.h"
21 #include "frontend/expander/bprop/bprop.h"
22 #include "include/common/utils/python_adapter.h"
23
24 namespace mindspore {
25 namespace expander {
26 namespace bprop {
NewGraph(const AbstractBasePtrList & abs_list)27 FuncGraphPtr NewGraph(const AbstractBasePtrList &abs_list) {
28 auto fg = std::make_shared<FuncGraph>();
29 for (const auto &abs : abs_list) {
30 auto para = fg->add_parameter();
31 para->set_abstract(abs);
32 }
33 return fg;
34 }
35
GenerateFuncGraph(const abstract::AbstractBasePtrList & input_abs)36 FuncGraphPtr BpropMetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &input_abs) {
37 auto fg = NewGraph(input_abs);
38 try {
39 if (!expander::bprop::ExpandBpropInGraphMode(handle_, primal_, fg)) {
40 return nullptr;
41 }
42 } catch (const py::type_error &ex) {
43 MS_EXCEPTION(TypeError) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << ex.what() << "]";
44 } catch (const py::value_error &ex) {
45 MS_EXCEPTION(ValueError) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << ex.what() << "]";
46 } catch (const std::exception &e) {
47 MS_LOG(EXCEPTION) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << e.what() << "]";
48 }
49 return fg;
50 }
51
// Names of primitives that CanExpand() always rejects.
// For these, GetBpropMetaFuncGraph returns nullptr and no C++-expanded bprop meta graph is built.
static const std::unordered_set<std::string> g_blacklist{
  "SparseGatherV2",
  "EmbeddingLookup",
  "AffineGrid",
  "ScatterAddWithAxis",
  "Expand",
  // Collective / parallel related.
  "AllReduce",
  "AllGather",
  "_MirrorOperator",
  // Side-effect bookkeeping primitives.
  "Load",
  "UpdateState",
  "Depend",
  "ParallelResizeBilinear",
  "MatrixSolve",
  "CholeskySolve",
  "CumulativeLogsumexp",
  "AvgPoolV1",
  "Eigh",
  // Sparse tensor operations and accessors.
  "SparseAdd",
  "CSRReduceSum",
  "CSRMV",
  "CSRMul",
  "CSRDiv",
  "COOTensorGetIndices",
  "COOTensorGetValues",
  "COOTensorGetDenseShape",
  "CSRTensorGetIndptr",
  "CSRTensorGetIndices",
  "CSRTensorGetValues",
  "CSRTensorGetDenseShape",
  "CSRSparseMatrixToDense",
  "DenseToCSRSparseMatrix",
  "SparseSegmentSqrtN",
  "SparseSegmentSqrtNWithNumSegments",
  "SparseSegmentSum",
  "SparseSegmentSumWithNumSegments",
  "SparseSegmentMeanWithNumSegments",
  "SparseReorder",
  "SparseDenseCwiseMul",
  "SparseDenseCwiseDiv",
  "RaggedTensorToSparse",
};
CanExpand(const std::string & name)92 bool CanExpand(const std::string &name) {
93 if (OpEnvManager::UsePyBprop(name)) {
94 return false;
95 }
96 if (g_blacklist.count(name) != 0) {
97 return false;
98 }
99 return true;
100 }
101
GetBpropMetaFuncGraph(const PrimitivePtr & primal,const CNodePtr & cnode)102 FuncGraphPtr GetBpropMetaFuncGraph(const PrimitivePtr &primal, const CNodePtr &cnode) {
103 auto prim_name = primal->name();
104 const BpropHandle *handle = BpropIRBuilderFactory::Instance().GetBuilder(prim_name);
105 if (!CanExpand(prim_name) || handle == nullptr) {
106 return nullptr;
107 }
108 size_t forward_inputs_size = 0;
109 if (cnode) {
110 std::vector<AnfNodePtr> node_lists = cnode->inputs();
111 forward_inputs_size = cnode->size() - 1;
112 for (size_t i = 1; i < node_lists.size(); i++) {
113 auto input_i = node_lists[i];
114 if (HasAbstractMonad(input_i)) {
115 --forward_inputs_size;
116 }
117 }
118 } else {
119 const auto &op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
120 const auto iter = op_primc_fns.find(prim_name);
121 if (iter == op_primc_fns.end()) {
122 MS_LOG(EXCEPTION) << "The " << prim_name << " operator is not registered";
123 }
124 auto primc = iter->second();
125 forward_inputs_size = GetValue<std::vector<std::string>>(primc->GetAttr(kAttrInputNames)).size();
126 }
127 auto fg = std::make_shared<FuncGraph>();
128 auto meta_graph = std::make_shared<BpropMetaFuncGraph>(primal, handle);
129 std::vector<AnfNodePtr> inputs{NewValueNode(meta_graph)};
130 for (size_t i = 0; i < forward_inputs_size; ++i) {
131 (void)inputs.emplace_back(fg->add_parameter());
132 }
133 (void)inputs.emplace_back(fg->add_parameter());
134 (void)inputs.emplace_back(fg->add_parameter());
135 fg->set_output(fg->NewCNode(inputs));
136 fg->set_flag(mindspore::kFuncGraphFlagMetaFuncGraphBprop, true);
137 if (GetPrimitiveFlag(primal, GRAPH_FLAG_SIDE_EFFECT_BACKPROP)) {
138 fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
139 }
140 return fg;
141 }
142 } // namespace bprop
143 } // namespace expander
144 } // namespace mindspore
145