• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2020-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
20 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
21 
22 #include <memory>
23 #include <string>
24 #include <vector>
25 #include <iostream>
26 #include <utility>
27 #include <unordered_map>
28 
29 #include "utils/hash_map.h"
30 #include "mindspore/core/ops/sequence_ops.h"
31 #include "ir/anf.h"
32 #include "ir/meta_func_graph.h"
33 #include "ir/func_graph_cloner.h"
34 #include "pipeline/jit/ps/resource.h"
35 #include "frontend/optimizer/ad/adjoint.h"
36 #include "frontend/operator/ops.h"
37 #include "pipeline/jit/ps/debug/trace.h"
38 #include "include/common/utils/utils.h"
39 
40 namespace mindspore {
41 namespace ad {
// Primitive -> FuncGraph map; keys are hashed with PrimitiveHasher and compared with PrimitiveTotalEqual.
// Used as the type of KPrim::bprop_registry_.
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
class KPrim;
// Global KPrim instance; only declared here, the definition lives in another translation unit.
extern KPrim g_k_prims;
class DFunctor;
using DFunctorPtr = std::shared_ptr<DFunctor>;

// Flag to control if fv should be lifted before grad. If this lift_fv feature is mature, then this flag can be removed.
// Only declared here; the definition (and its initial value) lives in another translation unit.
extern bool lift_fv_before_grad;
50 
// D Functor's rules to map closure object and morphisms.
//
// A DFunctor is constructed for one primal FuncGraph (primal_graph_) and holds the K graph (k_graph_) and the
// backprop tape (tape_) that belong to it. The class derives from std::enable_shared_from_this, so an instance
// must already be owned by a std::shared_ptr (see DFunctorPtr) before shared_from_this() is used on it.
// Besides the per-instance caches, two static maps are shared by all functors; Clear() is the static entry
// point declared for releasing resources.
class DFunctor : public std::enable_shared_from_this<DFunctor> {
 public:
  // primal_graph: the primal graph this functor maps.
  // resources: pipeline resources kept in resources_.
  // is_top: NOTE(review): presumably marks the outermost graph of a grad request -- confirm in the .cc file.
  DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources, bool is_top);
  ~DFunctor() = default;
  // Map object in D category to K category.
  void MapObject();
  // Map morphism in D category to K category.
  void MapMorphism();
  // Accessors for the K graph and the tape (see k_graph_ / tape_ below).
  FuncGraphPtr k_graph();
  FuncGraphPtr tape();
  // Construct user defined k object.
  FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
  // Register functor objects to form a global view.
  void Init(bool is_top = false);
  void Finish();

  // Clear resources.
  static void Clear();

  // PynativeDFunctor is granted access to the private members below.
  friend class PynativeDFunctor;

 private:
  // Map one morphism.
  AdjointPtr MapMorphism(const AnfNodePtr &morph);
  bool IsFreeMorphism(const AnfNodePtr &node);
  // Map morphism that's not attached to output.
  void MapFreeMorphism();
  // Back propagation helpers; `fv` follows the file's abbreviation for a free variable node.
  void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
  void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env);
  void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint,
                     bool side_effect_bprop_app_propagate = false);
  AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
  AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
  // Map CNode/Index of Primitive to K.
  AnfNodePtr MapPrimitiveToK(const CNodePtr &primitive_user, size_t index);
  // Map ValueNode of FuncGraph to K.
  AnfNodePtr MapFuncGraphToK(const AnfNodePtr &primal);
  // Map ValueNode of Parameter to K.
  AnfNodePtr MapParameterToK(const AnfNodePtr &primal);
  // MapObject impls.
  void MapFvObject();
  void MapValueObject();
  void MapParamObject();
  // Find adjoint with its primary k.
  AdjointPtr FindAdjoint(const AnfNodePtr &primal) const;
  // Broadcast stop flags.
  void BroadCastStopFlag();
  bool AllReferencesStopped(const CNodePtr &node);
  // Update k hole with adjoint_definition, only applied in recursive case.
  void UpdateAdjoint(const AdjointPtr &adjoint_definition);
  void CallDoutHoleOnTape() const;
  // Replace the primal graph with k graph
  void EliminatePrimalGraph();
  // Pynative specialize
  ValueNodePtr GenNewTensor(const CNodePtr &forward_node);
  tensor::TensorPtr GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem);
  // forward_node, bprop_graph and fprop_graph are output parameters written through the given pointers.
  void GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node, FuncGraphPtr *bprop_graph,
                                      FuncGraphPtr *fprop_graph);
  std::vector<AnfNodePtr> RunOutputReplace(const CNodePtr &forward_node, const FuncGraphPtr &bprop_graph,
                                           const FuncGraphPtr &fprop_graph, const CNodePtr &cnode_morph);
  std::vector<AnfNodePtr> RunInputReplace(const FuncGraphPtr &bprop_graph, const FuncGraphPtr &fprop_graph,
                                          const CNodePtr &cnode_morph);
  void ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph);

  // Per-functor map from a node to its adjoint.
  mindspore::HashMap<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
  // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
  mindspore::HashMap<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_indirect_fv_;
  // Cache for fv node -> pair<embed<fv_node>, zeros_like<fv_node>>, so EnvironGetTransform in optimizer
  // can hit its cache if fv_node is same.
  mindspore::HashMap<AnfNodePtr, std::pair<CNodePtr, CNodePtr>> anfnode_to_envitem_;
  FuncGraphPtr primal_graph_;
  // K object for primal_graph_;
  FuncGraphPtr k_graph_;
  // The Backprop part of k_graph_.
  FuncGraphPtr tape_;
  // Dout parameter for primal_graph_.
  AnfNodePtr dout_;
  pipeline::ResourceBasePtr resources_;
  // Cut off stopped objects in category D.
  bool need_cut_;
  // No in-class initializer: the value is expected to be set by the constructor defined elsewhere.
  bool is_top_;
  // Class-wide (static) maps shared by every DFunctor: graph -> functor, and node -> adjoint definition.
  static mindspore::HashMap<FuncGraphPtr, std::shared_ptr<DFunctor>> func_graph_to_functor_;
  static mindspore::HashMap<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_definition_;
};
136 
137 // D Functor's rules to map primitive object.
138 class KPrim {
139  public:
140   KPrim() = default;
141   ~KPrim() = default;
142 
143   FuncGraphPtr KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
144                           const pipeline::ResourceBasePtr &resources);
145   MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim);
146   // bprop_fg and primal_fg in bprop_fg's transforms are FuncGraph just after convert.
147   // current_primal_fg is the specialized and AutoMonaded primal_fg.
148   FuncGraphPtr KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg);
149 
150   bool CheckCustomVjp(const FuncGraphPtr &bprop_fg) const;
151   FuncGraphPtr GetCustomVjpBprop(const FuncGraphPtr &bprop_fg) const;
clear()152   void clear() {
153     bprop_registry_meta_.clear();
154     bprop_registry_.clear();
155   }
156 
157  private:
158   FuncGraphPtr GetFprop(const PrimitivePtr &prim) const;
159   FuncGraphPtr GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node,
160                             const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode = nullptr);
161   FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const;
162   FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const;
163   // Given a bprop rule, do the K mapping.
164   // current_primal_fg is only valid for user defined bprop for Cell, not for Primitive.
165   // Refer the comment in KUserDefinedCellBprop.
166   template <typename T>
167   FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg,
168                         const CNodePtr &cnode, const mindspore::HashMap<std::string, ValuePtr> &primal_attrs,
169                         const std::vector<NodeDebugInfoPtr> &primal_debug_infos);
170   AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) const;
171   void TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
172                                  const PrimitivePtr &primitive, const FuncGraphPtr &outer,
173                                  std::vector<AnfNodePtr> *const transf_args) const;
174   template <typename T>
175   void TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
176                                  const T &current_primal_fg, const FuncGraphPtr &outer,
177                                  std::vector<AnfNodePtr> *const transf_args) const;
178   void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) const;
179 
180   Registry bprop_registry_;
181   mindspore::HashMap<PrimitivePtr, MetaFuncGraphPtr> bprop_registry_meta_;
182 };
183 
184 template <typename T>
BpropToK(const T & primal,const FuncGraphPtr & bprop_fg,const FuncGraphPtr & current_primal_fg,const CNodePtr & cnode,const mindspore::HashMap<std::string,ValuePtr> & primal_attrs,const std::vector<NodeDebugInfoPtr> & primal_debug_infos)185 FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg,
186                              const CNodePtr &cnode, const mindspore::HashMap<std::string, ValuePtr> &primal_attrs,
187                              const std::vector<NodeDebugInfoPtr> &primal_debug_infos) {
188   MS_EXCEPTION_IF_NULL(primal);
189   MS_EXCEPTION_IF_NULL(bprop_fg);
190   CheckBprop(bprop_fg, primal->ToString());
191 
192   FuncGraphPtr cloned_bprop_fg;
193   {
194     PrimalAttrGuard primal_attr_guard(primal_attrs);
195     PrimalDebugInfoGuard primal_debug_info_guard(primal_debug_infos);
196     if (bprop_fg->has_flag(mindspore::kFuncGraphFlagMetaFuncGraphBprop) &&
197         (cnode == nullptr || !cnode->primal_attrs().empty())) {
198       cloned_bprop_fg = BasicClone(bprop_fg, true);
199     } else {
200       cloned_bprop_fg = BasicClone(bprop_fg);
201     }
202   }
203   MS_EXCEPTION_IF_NULL(cloned_bprop_fg);
204 
205   GraphDebugInfoPtr debug_info = nullptr;
206   {
207     TraceGuard guard(std::make_shared<TraceCopy>(bprop_fg->debug_info()));
208     debug_info = std::make_shared<GraphDebugInfo>();
209   }
210   if (debug_info->trace_info() != nullptr && debug_info->trace_info()->debug_info() != nullptr) {
211     debug_info->trace_info()->debug_info()->set_name(primal->ToString());
212   }
213   cloned_bprop_fg->debug_info()->set_name("");
214   cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared<TraceGradBprop>(debug_info));
215 
216   // Make sure (out, dout) provided.
217   constexpr auto number_two = 2;
218   if (cloned_bprop_fg->parameters().size() < number_two) {
219     MS_LOG(EXCEPTION)
220       << "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only "
221       << cloned_bprop_fg->parameters().size() << ".\n"
222       << trace::GetDebugInfoStr(cloned_bprop_fg->debug_info());
223   }
224   AnfNodePtr bout = BuildOutput(cloned_bprop_fg, current_primal_fg);
225   cloned_bprop_fg->set_output(bout);
226 
227   FuncGraphPtr outer = nullptr;
228   {
229     auto outer_debug_info = std::make_shared<GraphDebugInfo>();
230     outer_debug_info->set_name(primal->ToString());
231     TraceGuard guard(std::make_shared<TraceGradFprop>(outer_debug_info));
232     outer = std::make_shared<FuncGraph>();
233     (void)outer->transforms().emplace("primal", FuncGraphTransform(primal));
234     outer->set_output(NewValueNode(kNone));
235   }
236 
237   auto mng = Manage({cloned_bprop_fg, outer}, false);
238 
239   // In a bprop definition, the last two param should be out and dout.
240   auto param_size = cloned_bprop_fg->parameters().size();
241   auto param_num = param_size - 1;
242   auto dout = cloned_bprop_fg->parameters()[param_num];
243   param_num--;
244   auto out_param = cloned_bprop_fg->parameters()[param_num];
245 
246   std::vector<AnfNodePtr> transf_args;
247 
248   if constexpr (std::is_same<T, PrimitivePtr>::value) {
249     PrimitivePtr primitive = primal;
250     auto prim_recompute_attr = primitive->GetAttr(kAttrRecompute);
251     if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>() && GetValue<bool>(prim_recompute_attr)) {
252       cloned_bprop_fg->set_flag(FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH, true);
253     }
254     TransformArgsForPrimitive(mng, cloned_bprop_fg, primal, outer, &transf_args);
255     (void)transf_args.insert(transf_args.cbegin(), NewValueNode(primal));
256   } else {
257     TransformArgsForFuncGraph(mng, cloned_bprop_fg, current_primal_fg, outer, &transf_args);
258     (void)transf_args.insert(transf_args.cbegin(), NewValueNode(current_primal_fg));
259   }
260   CNodePtr out_value = nullptr;
261   if (cnode != nullptr) {  // Set equiv debug info. for Primitive CNode out.
262     TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info()));
263     out_value = outer->NewCNode(transf_args);
264     if constexpr (std::is_same<T, PrimitivePtr>::value) {
265       out_value->CloneCNodeInfo(cnode);
266     }
267   } else {
268     out_value = outer->NewCNode(transf_args);
269   }
270   (void)mng->Replace(out_param, out_value);
271 
272   TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info()));
273   auto new_dout = cloned_bprop_fg->add_parameter();
274   (void)mng->Replace(dout, new_dout);
275   // We remove all parameters except new_dout.
276   std::vector<AnfNodePtr> newBpropParams = {new_dout};
277   cloned_bprop_fg->set_parameters(newBpropParams);
278   outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)}));
279   return BasicClone(outer);
280 }
281 
// Handle bprop of an op whose input dtype is a real number and whose output dtype is a complex number.
// If the dtype of a gradient (din) is a complex number while the corresponding input is a real number,
// only the real part of the gradient makes sense in back propagation, so it is handled by
// inserting a Real() op after the gradient.
// input: AnfNode that is an input of an op whose input dtype is real and whose output dtype is complex.
// din: CNodePtr with the gradient of input.
// fg: FuncGraph which input and din belong to.
// return: New din with a Real op inserted if necessary.
AnfNodePtr HandleRealToComplex(const AnfNodePtr &input, const CNodePtr &din, const FuncGraphPtr &fg);
291 }  // namespace ad
292 }  // namespace mindspore
293 
294 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
295