1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2020-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
20 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
21
#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ir/anf.h"
#include "ir/meta_func_graph.h"
#include "ir/func_graph_cloner.h"
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/ad/adjoint.h"
#include "frontend/optimizer/ad/pynative_dfunctor.h"
#include "frontend/operator/ops.h"
#include "debug/trace.h"
#include "utils/utils.h"
38
39 namespace mindspore {
40 namespace ad {
41 using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
42 class KPrim;
43 extern KPrim g_k_prims;
44 class DFunctor;
45 using DFunctorPtr = std::shared_ptr<DFunctor>;
46
47 // Flag to control if fv should be lifted before grad. If this lift_fv feature is mature, then this flag can be removed.
48 extern bool lift_fv_before_grad;
49
50 // D Functor's rules to map closure object and morphisms.
51 class DFunctor : public std::enable_shared_from_this<DFunctor> {
52 public:
53 DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources);
54 ~DFunctor() = default;
55 // Map object in D category to K category.
56 void MapObject();
57 // Map morphism in D category to K category.
58 void MapMorphism();
59 FuncGraphPtr k_graph();
60 FuncGraphPtr tape();
61 // Construct user defined k object.
62 FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
63 // Register functor objects to form a global view.
64 void Init(bool is_top = false);
65 void Finish();
66
67 // Clear resources.
68 static void Clear();
69
70 friend class PynativeDFunctor;
71
72 private:
73 // Map one morphism.
74 AdjointPtr MapMorphism(const AnfNodePtr &morph);
75 bool IsFreeMorphism(const AnfNodePtr &node);
76 // Map morphism that's not attached to output.
77 void MapFreeMorphism();
78 void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
79 void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env);
80 void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
81 AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
82 AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
83 // Map CNode/Index of Primitive to K.
84 AnfNodePtr MapPrimitiveToK(const CNodePtr &primitive_user, size_t index);
85 // Map ValueNode of FuncGraph to K.
86 AnfNodePtr MapFuncGraphToK(const AnfNodePtr &primal);
87 // Map ValueNode of Parameter to K.
88 AnfNodePtr MapParameterToK(const AnfNodePtr &primal);
89 // MapObject impls.
90 void MapFvObject();
91 void MapValueObject();
92 void MapParamObject();
93 // Find adjoint with its primary k.
94 AdjointPtr FindAdjoint(const AnfNodePtr &primal);
95 // Broadcast stop flags.
96 void BroadCastStopFlag();
97 bool AllReferencesStopped(const CNodePtr &node);
98 // Update k hole with adjoint_definition, only applied in recursive case.
99 void UpdateAdjoint(const AdjointPtr &adjoint_definition);
100 void CallDoutHoleOnTape();
101 // Replace the primal graph with k graph
102 void EliminatePrimalGraph();
103 // Pynative specialize
104 ValueNodePtr GenNewTensor(const CNodePtr &forward_node);
105 tensor::TensorPtr GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem);
106 void GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node, FuncGraphPtr *bprop_graph,
107 FuncGraphPtr *fprop_graph);
108 std::vector<AnfNodePtr> RunOutputReplace(const CNodePtr &forward_node, const FuncGraphPtr &bprop_graph,
109 const FuncGraphPtr &fprop_graph, const CNodePtr &cnode_morph);
110 std::vector<AnfNodePtr> RunInputReplace(const FuncGraphPtr &bprop_graph, const FuncGraphPtr &fprop_graph,
111 const CNodePtr &cnode_morph);
112 void ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph);
113
114 std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
115 // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
116 std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_indirect_fv_;
117 // Cache for fv node -> pair<embed<fv_node>, zeros_like<fv_node>>, so EnvGetItemTransform in optimizer
118 // can hit its cache if fv_node is same.
119 std::unordered_map<AnfNodePtr, std::pair<CNodePtr, CNodePtr>> anfnode_to_envitem_;
120 FuncGraphPtr primal_graph_;
121 // K object for primal_graph_;
122 FuncGraphPtr k_graph_;
123 // The Backprop part of k_graph_.
124 FuncGraphPtr tape_;
125 // Dout parameter for primal_graph_.
126 AnfNodePtr dout_;
127 pipeline::ResourceBasePtr resources_;
128 // Cut off stopped objects in category D.
129 bool need_cut_;
130 bool is_top_;
131 static std::unordered_map<FuncGraphPtr, std::shared_ptr<DFunctor>> func_graph_to_functor_;
132 static std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_definition_;
133 };
134
135 // D Functor's rules to map primitive object.
136 class KPrim {
137 public:
138 KPrim() = default;
139 ~KPrim() = default;
140
141 FuncGraphPtr KPrimitive(const CNodePtr &primal_user, const ValueNodePtr &value_node,
142 const pipeline::ResourceBasePtr &resources);
143 MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim);
144 // bprop_fg and primal_fg in bprop_fg's transforms are FuncGraph just after convert.
145 // current_primal_fg is the specialized and AutoMonaded primal_fg.
146 FuncGraphPtr KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg);
147
clear()148 void clear() {
149 bprop_registry_meta_.clear();
150 bprop_registry_.clear();
151 }
152 FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim);
153
154 #ifndef _WIN32
155 static void ExportBpropMindir(const py::object &obj);
156 #endif
157
158 private:
159 FuncGraphPtr GetBprop(const CNodePtr &cnode, const ValueNodePtr &value_node,
160 const pipeline::ResourceBasePtr &resources, const PrimitivePtr &prim);
161 FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr);
162 FuncGraphPtr GetFprop(const PrimitivePtr &prim);
163 FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
164 FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
165 // Given a bprop rule, do the K mapping.
166 // current_primal_fg is only valid for user defined bprop for Cell, not for Primitive.
167 // Refer the comment in KUserDefinedCellBprop.
168 template <typename T>
169 FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const FuncGraphPtr ¤t_primal_fg,
170 const CNodePtr &cnode, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
171 const std::vector<NodeDebugInfoPtr> &primal_debug_infos);
172 AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg);
173 void TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
174 const PrimitivePtr &primitive, const FuncGraphPtr &outer,
175 std::vector<AnfNodePtr> *const transf_args);
176 template <typename T>
177 void TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
178 const T ¤t_primal_fg, const FuncGraphPtr &outer,
179 std::vector<AnfNodePtr> *const transf_args);
180 void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check);
181
182 Registry bprop_registry_;
183 std::unordered_map<PrimitivePtr, MetaFuncGraphPtr> bprop_registry_meta_;
184 };
185
186 template <typename T>
BpropToK(const T & primal,const FuncGraphPtr & bprop_fg,const FuncGraphPtr & current_primal_fg,const CNodePtr & cnode,const std::unordered_map<std::string,ValuePtr> & primal_attrs,const std::vector<NodeDebugInfoPtr> & primal_debug_infos)187 FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg,
188 const CNodePtr &cnode, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
189 const std::vector<NodeDebugInfoPtr> &primal_debug_infos) {
190 MS_EXCEPTION_IF_NULL(primal);
191 MS_EXCEPTION_IF_NULL(bprop_fg);
192 CheckBprop(bprop_fg, primal->ToString());
193
194 FuncGraphPtr cloned_bprop_fg;
195 {
196 PrimalAttrGuard primal_attr_guard(primal_attrs);
197 PrimalDebugInfoGuard primal_debug_info_guard(primal_debug_infos);
198 cloned_bprop_fg = BasicClone(bprop_fg);
199 }
200 MS_EXCEPTION_IF_NULL(cloned_bprop_fg);
201
202 GraphDebugInfoPtr debug_info = nullptr;
203 {
204 TraceGuard guard(std::make_shared<TraceCopy>(bprop_fg->debug_info()));
205 debug_info = std::make_shared<GraphDebugInfo>();
206 }
207 if (debug_info->trace_info() != nullptr && debug_info->trace_info()->debug_info() != nullptr) {
208 debug_info->trace_info()->debug_info()->set_name(primal->ToString());
209 }
210 cloned_bprop_fg->debug_info()->set_name("");
211 cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared<TraceGradBprop>(debug_info));
212
213 // Make sure (out, dout) provided.
214 if (cloned_bprop_fg->parameters().size() < 2) {
215 MS_LOG(EXCEPTION)
216 << "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only "
217 << cloned_bprop_fg->parameters().size() << ".\n"
218 << trace::GetDebugInfo(cloned_bprop_fg->debug_info());
219 }
220 AnfNodePtr bout = BuildOutput(cloned_bprop_fg, current_primal_fg);
221 cloned_bprop_fg->set_output(bout);
222
223 FuncGraphPtr outer = nullptr;
224 {
225 auto outer_debug_info = std::make_shared<GraphDebugInfo>();
226 outer_debug_info->set_name(primal->ToString());
227 TraceGuard guard(std::make_shared<TraceGradFprop>(outer_debug_info));
228 outer = std::make_shared<FuncGraph>();
229 (void)outer->transforms().emplace("primal", FuncGraphTransform(primal));
230 outer->set_output(NewValueNode(kNone));
231 }
232
233 auto mng = Manage({cloned_bprop_fg, outer}, false);
234
235 // In a bprop definition, the last two param should be out and dout.
236 auto param_size = cloned_bprop_fg->parameters().size();
237 auto param_num = param_size - 1;
238 auto dout = cloned_bprop_fg->parameters()[param_num];
239 param_num--;
240 auto out_param = cloned_bprop_fg->parameters()[param_num];
241
242 std::vector<AnfNodePtr> transf_args;
243
244 if constexpr (std::is_same<T, PrimitivePtr>::value) {
245 PrimitivePtr primitive = primal;
246 TransformArgsForPrimitive(mng, cloned_bprop_fg, primal, outer, &transf_args);
247 (void)transf_args.insert(transf_args.begin(), NewValueNode(primal));
248 } else {
249 TransformArgsForFuncGraph(mng, cloned_bprop_fg, current_primal_fg, outer, &transf_args);
250 (void)transf_args.insert(transf_args.begin(), NewValueNode(current_primal_fg));
251 }
252 CNodePtr out_value = nullptr;
253 if (cnode != nullptr) { // Set equiv debug info. for Primitive CNode out.
254 TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info()));
255 out_value = outer->NewCNode(transf_args);
256 if constexpr (std::is_same<T, PrimitivePtr>::value) {
257 out_value->CloneCNodeInfo(cnode);
258 }
259 } else {
260 out_value = outer->NewCNode(transf_args);
261 }
262 (void)mng->Replace(out_param, out_value);
263
264 TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info()));
265 auto new_dout = cloned_bprop_fg->add_parameter();
266 (void)mng->Replace(dout, new_dout);
267 // We remove all parameters except new_dout.
268 std::vector<AnfNodePtr> newBpropParams = {new_dout};
269 cloned_bprop_fg->set_parameters(newBpropParams);
270 outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)}));
271 return BasicClone(outer);
272 }
273 } // namespace ad
274 } // namespace mindspore
275
276 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
277