1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2020-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
20 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
21
#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "utils/hash_map.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "ir/anf.h"
#include "ir/meta_func_graph.h"
#include "ir/func_graph_cloner.h"
#include "pipeline/jit/ps/resource.h"
#include "frontend/optimizer/ad/adjoint.h"
#include "frontend/operator/ops.h"
#include "pipeline/jit/ps/debug/trace.h"
#include "include/common/utils/utils.h"
39
40 namespace mindspore {
41 namespace ad {
42 using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
43 class KPrim;
44 extern KPrim g_k_prims;
45 class DFunctor;
46 using DFunctorPtr = std::shared_ptr<DFunctor>;
47
48 // Flag to control if fv should be lifted before grad. If this lift_fv feature is mature, then this flag can be removed.
49 extern bool lift_fv_before_grad;
50
51 // D Functor's rules to map closure object and morphisms.
52 class DFunctor : public std::enable_shared_from_this<DFunctor> {
53 public:
54 DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources, bool is_top);
55 ~DFunctor() = default;
56 // Map object in D category to K category.
57 void MapObject();
58 // Map morphism in D category to K category.
59 void MapMorphism();
60 FuncGraphPtr k_graph();
61 FuncGraphPtr tape();
62 // Construct user defined k object.
63 FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
64 // Register functor objects to form a global view.
65 void Init(bool is_top = false);
66 void Finish();
67
68 // Clear resources.
69 static void Clear();
70
71 friend class PynativeDFunctor;
72
73 private:
74 // Map one morphism.
75 AdjointPtr MapMorphism(const AnfNodePtr &morph);
76 bool IsFreeMorphism(const AnfNodePtr &node);
77 // Map morphism that's not attached to output.
78 void MapFreeMorphism();
79 void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
80 void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env);
81 void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint,
82 bool side_effect_bprop_app_propagate = false);
83 AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
84 AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
85 // Map CNode/Index of Primitive to K.
86 AnfNodePtr MapPrimitiveToK(const CNodePtr &primitive_user, size_t index);
87 // Map ValueNode of FuncGraph to K.
88 AnfNodePtr MapFuncGraphToK(const AnfNodePtr &primal);
89 // Map ValueNode of Parameter to K.
90 AnfNodePtr MapParameterToK(const AnfNodePtr &primal);
91 // MapObject impls.
92 void MapFvObject();
93 void MapValueObject();
94 void MapParamObject();
95 // Find adjoint with its primary k.
96 AdjointPtr FindAdjoint(const AnfNodePtr &primal) const;
97 // Broadcast stop flags.
98 void BroadCastStopFlag();
99 bool AllReferencesStopped(const CNodePtr &node);
100 // Update k hole with adjoint_definition, only applied in recursive case.
101 void UpdateAdjoint(const AdjointPtr &adjoint_definition);
102 void CallDoutHoleOnTape() const;
103 // Replace the primal graph with k graph
104 void EliminatePrimalGraph();
105 // Pynative specialize
106 ValueNodePtr GenNewTensor(const CNodePtr &forward_node);
107 tensor::TensorPtr GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem);
108 void GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node, FuncGraphPtr *bprop_graph,
109 FuncGraphPtr *fprop_graph);
110 std::vector<AnfNodePtr> RunOutputReplace(const CNodePtr &forward_node, const FuncGraphPtr &bprop_graph,
111 const FuncGraphPtr &fprop_graph, const CNodePtr &cnode_morph);
112 std::vector<AnfNodePtr> RunInputReplace(const FuncGraphPtr &bprop_graph, const FuncGraphPtr &fprop_graph,
113 const CNodePtr &cnode_morph);
114 void ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph);
115
116 mindspore::HashMap<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
117 // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
118 mindspore::HashMap<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_indirect_fv_;
119 // Cache for fv node -> pair<embed<fv_node>, zeros_like<fv_node>>, so EnvironGetTransform in optimizer
120 // can hit its cache if fv_node is same.
121 mindspore::HashMap<AnfNodePtr, std::pair<CNodePtr, CNodePtr>> anfnode_to_envitem_;
122 FuncGraphPtr primal_graph_;
123 // K object for primal_graph_;
124 FuncGraphPtr k_graph_;
125 // The Backprop part of k_graph_.
126 FuncGraphPtr tape_;
127 // Dout parameter for primal_graph_.
128 AnfNodePtr dout_;
129 pipeline::ResourceBasePtr resources_;
130 // Cut off stopped objects in category D.
131 bool need_cut_;
132 bool is_top_;
133 static mindspore::HashMap<FuncGraphPtr, std::shared_ptr<DFunctor>> func_graph_to_functor_;
134 static mindspore::HashMap<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_definition_;
135 };
136
137 // D Functor's rules to map primitive object.
138 class KPrim {
139 public:
140 KPrim() = default;
141 ~KPrim() = default;
142
143 FuncGraphPtr KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
144 const pipeline::ResourceBasePtr &resources);
145 MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim);
146 // bprop_fg and primal_fg in bprop_fg's transforms are FuncGraph just after convert.
147 // current_primal_fg is the specialized and AutoMonaded primal_fg.
148 FuncGraphPtr KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg);
149
150 bool CheckCustomVjp(const FuncGraphPtr &bprop_fg) const;
151 FuncGraphPtr GetCustomVjpBprop(const FuncGraphPtr &bprop_fg) const;
clear()152 void clear() {
153 bprop_registry_meta_.clear();
154 bprop_registry_.clear();
155 }
156
157 private:
158 FuncGraphPtr GetFprop(const PrimitivePtr &prim) const;
159 FuncGraphPtr GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node,
160 const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode = nullptr);
161 FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const;
162 FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const;
163 // Given a bprop rule, do the K mapping.
164 // current_primal_fg is only valid for user defined bprop for Cell, not for Primitive.
165 // Refer the comment in KUserDefinedCellBprop.
166 template <typename T>
167 FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg,
168 const CNodePtr &cnode, const mindspore::HashMap<std::string, ValuePtr> &primal_attrs,
169 const std::vector<NodeDebugInfoPtr> &primal_debug_infos);
170 AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg) const;
171 void TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
172 const PrimitivePtr &primitive, const FuncGraphPtr &outer,
173 std::vector<AnfNodePtr> *const transf_args) const;
174 template <typename T>
175 void TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
176 const T ¤t_primal_fg, const FuncGraphPtr &outer,
177 std::vector<AnfNodePtr> *const transf_args) const;
178 void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) const;
179
180 Registry bprop_registry_;
181 mindspore::HashMap<PrimitivePtr, MetaFuncGraphPtr> bprop_registry_meta_;
182 };
183
184 template <typename T>
BpropToK(const T & primal,const FuncGraphPtr & bprop_fg,const FuncGraphPtr & current_primal_fg,const CNodePtr & cnode,const mindspore::HashMap<std::string,ValuePtr> & primal_attrs,const std::vector<NodeDebugInfoPtr> & primal_debug_infos)185 FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg,
186 const CNodePtr &cnode, const mindspore::HashMap<std::string, ValuePtr> &primal_attrs,
187 const std::vector<NodeDebugInfoPtr> &primal_debug_infos) {
188 MS_EXCEPTION_IF_NULL(primal);
189 MS_EXCEPTION_IF_NULL(bprop_fg);
190 CheckBprop(bprop_fg, primal->ToString());
191
192 FuncGraphPtr cloned_bprop_fg;
193 {
194 PrimalAttrGuard primal_attr_guard(primal_attrs);
195 PrimalDebugInfoGuard primal_debug_info_guard(primal_debug_infos);
196 if (bprop_fg->has_flag(mindspore::kFuncGraphFlagMetaFuncGraphBprop) &&
197 (cnode == nullptr || !cnode->primal_attrs().empty())) {
198 cloned_bprop_fg = BasicClone(bprop_fg, true);
199 } else {
200 cloned_bprop_fg = BasicClone(bprop_fg);
201 }
202 }
203 MS_EXCEPTION_IF_NULL(cloned_bprop_fg);
204
205 GraphDebugInfoPtr debug_info = nullptr;
206 {
207 TraceGuard guard(std::make_shared<TraceCopy>(bprop_fg->debug_info()));
208 debug_info = std::make_shared<GraphDebugInfo>();
209 }
210 if (debug_info->trace_info() != nullptr && debug_info->trace_info()->debug_info() != nullptr) {
211 debug_info->trace_info()->debug_info()->set_name(primal->ToString());
212 }
213 cloned_bprop_fg->debug_info()->set_name("");
214 cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared<TraceGradBprop>(debug_info));
215
216 // Make sure (out, dout) provided.
217 constexpr auto number_two = 2;
218 if (cloned_bprop_fg->parameters().size() < number_two) {
219 MS_LOG(EXCEPTION)
220 << "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only "
221 << cloned_bprop_fg->parameters().size() << ".\n"
222 << trace::GetDebugInfoStr(cloned_bprop_fg->debug_info());
223 }
224 AnfNodePtr bout = BuildOutput(cloned_bprop_fg, current_primal_fg);
225 cloned_bprop_fg->set_output(bout);
226
227 FuncGraphPtr outer = nullptr;
228 {
229 auto outer_debug_info = std::make_shared<GraphDebugInfo>();
230 outer_debug_info->set_name(primal->ToString());
231 TraceGuard guard(std::make_shared<TraceGradFprop>(outer_debug_info));
232 outer = std::make_shared<FuncGraph>();
233 (void)outer->transforms().emplace("primal", FuncGraphTransform(primal));
234 outer->set_output(NewValueNode(kNone));
235 }
236
237 auto mng = Manage({cloned_bprop_fg, outer}, false);
238
239 // In a bprop definition, the last two param should be out and dout.
240 auto param_size = cloned_bprop_fg->parameters().size();
241 auto param_num = param_size - 1;
242 auto dout = cloned_bprop_fg->parameters()[param_num];
243 param_num--;
244 auto out_param = cloned_bprop_fg->parameters()[param_num];
245
246 std::vector<AnfNodePtr> transf_args;
247
248 if constexpr (std::is_same<T, PrimitivePtr>::value) {
249 PrimitivePtr primitive = primal;
250 auto prim_recompute_attr = primitive->GetAttr(kAttrRecompute);
251 if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>() && GetValue<bool>(prim_recompute_attr)) {
252 cloned_bprop_fg->set_flag(FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH, true);
253 }
254 TransformArgsForPrimitive(mng, cloned_bprop_fg, primal, outer, &transf_args);
255 (void)transf_args.insert(transf_args.cbegin(), NewValueNode(primal));
256 } else {
257 TransformArgsForFuncGraph(mng, cloned_bprop_fg, current_primal_fg, outer, &transf_args);
258 (void)transf_args.insert(transf_args.cbegin(), NewValueNode(current_primal_fg));
259 }
260 CNodePtr out_value = nullptr;
261 if (cnode != nullptr) { // Set equiv debug info. for Primitive CNode out.
262 TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info()));
263 out_value = outer->NewCNode(transf_args);
264 if constexpr (std::is_same<T, PrimitivePtr>::value) {
265 out_value->CloneCNodeInfo(cnode);
266 }
267 } else {
268 out_value = outer->NewCNode(transf_args);
269 }
270 (void)mng->Replace(out_param, out_value);
271
272 TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info()));
273 auto new_dout = cloned_bprop_fg->add_parameter();
274 (void)mng->Replace(dout, new_dout);
275 // We remove all parameters except new_dout.
276 std::vector<AnfNodePtr> newBpropParams = {new_dout};
277 cloned_bprop_fg->set_parameters(newBpropParams);
278 outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)}));
279 return BasicClone(outer);
280 }
281
282 // Handle bprob of op which input dtype is real number and output dtype is complex number.
283 // If the dtype of a gradient(din) is complex number and the input of that is real number,
284 // only the real part of the gradient make sense in back propagate. So we handle it by
285 // insert a Real() ops after the gradient.
286 // input: AnfNode with input of op which input dtype is real number and output dtype is complex number.
287 // din: CNodePtr with gradient of input.
288 // fg: Funcgraph witch input and din belong to.
289 // return: New din with inserted real op if necessarily.
290 AnfNodePtr HandleRealToComplex(const AnfNodePtr &input, const CNodePtr &din, const FuncGraphPtr &fg);
291 } // namespace ad
292 } // namespace mindspore
293
294 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
295