/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 
19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
20 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
21 
#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ir/anf.h"
#include "ir/meta_func_graph.h"
#include "ir/func_graph_cloner.h"
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/ad/adjoint.h"
#include "frontend/optimizer/ad/pynative_dfunctor.h"
#include "frontend/operator/ops.h"
#include "debug/trace.h"
#include "utils/utils.h"
38 
39 namespace mindspore {
40 namespace ad {
41 using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
42 class KPrim;
43 extern KPrim g_k_prims;
44 class DFunctor;
45 using DFunctorPtr = std::shared_ptr<DFunctor>;
46 
47 // Flag to control if fv should be lifted before grad. If this lift_fv feature is mature, then this flag can be removed.
48 extern bool lift_fv_before_grad;
49 
50 // D Functor's rules to map closure object and morphisms.
51 class DFunctor : public std::enable_shared_from_this<DFunctor> {
52  public:
53   DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources);
54   ~DFunctor() = default;
55   // Map object in D category to K category.
56   void MapObject();
57   // Map morphism in D category to K category.
58   void MapMorphism();
59   FuncGraphPtr k_graph();
60   FuncGraphPtr tape();
61   // Construct user defined k object.
62   FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
63   // Register functor objects to form a global view.
64   void Init(bool is_top = false);
65   void Finish();
66 
67   // Clear resources.
68   static void Clear();
69 
70   friend class PynativeDFunctor;
71 
72  private:
73   // Map one morphism.
74   AdjointPtr MapMorphism(const AnfNodePtr &morph);
75   bool IsFreeMorphism(const AnfNodePtr &node);
76   // Map morphism that's not attached to output.
77   void MapFreeMorphism();
78   void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
79   void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env);
80   void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
81   AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
82   AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
83   // Map CNode/Index of Primitive to K.
84   AnfNodePtr MapPrimitiveToK(const CNodePtr &primitive_user, size_t index);
85   // Map ValueNode of FuncGraph to K.
86   AnfNodePtr MapFuncGraphToK(const AnfNodePtr &primal);
87   // Map ValueNode of Parameter to K.
88   AnfNodePtr MapParameterToK(const AnfNodePtr &primal);
89   // MapObject impls.
90   void MapFvObject();
91   void MapValueObject();
92   void MapParamObject();
93   // Find adjoint with its primary k.
94   AdjointPtr FindAdjoint(const AnfNodePtr &primal);
95   // Broadcast stop flags.
96   void BroadCastStopFlag();
97   bool AllReferencesStopped(const CNodePtr &node);
98   // Update k hole with adjoint_definition, only applied in recursive case.
99   void UpdateAdjoint(const AdjointPtr &adjoint_definition);
100   void CallDoutHoleOnTape();
101   // Replace the primal graph with k graph
102   void EliminatePrimalGraph();
103   // Pynative specialize
104   ValueNodePtr GenNewTensor(const CNodePtr &forward_node);
105   tensor::TensorPtr GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem);
106   void GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node, FuncGraphPtr *bprop_graph,
107                                       FuncGraphPtr *fprop_graph);
108   std::vector<AnfNodePtr> RunOutputReplace(const CNodePtr &forward_node, const FuncGraphPtr &bprop_graph,
109                                            const FuncGraphPtr &fprop_graph, const CNodePtr &cnode_morph);
110   std::vector<AnfNodePtr> RunInputReplace(const FuncGraphPtr &bprop_graph, const FuncGraphPtr &fprop_graph,
111                                           const CNodePtr &cnode_morph);
112   void ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph);
113 
114   std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
115   // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
116   std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_indirect_fv_;
117   // Cache for fv node -> pair<embed<fv_node>, zeros_like<fv_node>>, so EnvGetItemTransform in optimizer
118   // can hit its cache if fv_node is same.
119   std::unordered_map<AnfNodePtr, std::pair<CNodePtr, CNodePtr>> anfnode_to_envitem_;
120   FuncGraphPtr primal_graph_;
121   // K object for primal_graph_;
122   FuncGraphPtr k_graph_;
123   // The Backprop part of k_graph_.
124   FuncGraphPtr tape_;
125   // Dout parameter for primal_graph_.
126   AnfNodePtr dout_;
127   pipeline::ResourceBasePtr resources_;
128   // Cut off stopped objects in category D.
129   bool need_cut_;
130   bool is_top_;
131   static std::unordered_map<FuncGraphPtr, std::shared_ptr<DFunctor>> func_graph_to_functor_;
132   static std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_definition_;
133 };
134 
135 // D Functor's rules to map primitive object.
136 class KPrim {
137  public:
138   KPrim() = default;
139   ~KPrim() = default;
140 
141   FuncGraphPtr KPrimitive(const CNodePtr &primal_user, const ValueNodePtr &value_node,
142                           const pipeline::ResourceBasePtr &resources);
143   MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim);
144   // bprop_fg and primal_fg in bprop_fg's transforms are FuncGraph just after convert.
145   // current_primal_fg is the specialized and AutoMonaded primal_fg.
146   FuncGraphPtr KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg);
147 
clear()148   void clear() {
149     bprop_registry_meta_.clear();
150     bprop_registry_.clear();
151   }
152   FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim);
153 
154 #ifndef _WIN32
155   static void ExportBpropMindir(const py::object &obj);
156 #endif
157 
158  private:
159   FuncGraphPtr GetBprop(const CNodePtr &cnode, const ValueNodePtr &value_node,
160                         const pipeline::ResourceBasePtr &resources, const PrimitivePtr &prim);
161   FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr);
162   FuncGraphPtr GetFprop(const PrimitivePtr &prim);
163   FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
164   FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
165   // Given a bprop rule, do the K mapping.
166   // current_primal_fg is only valid for user defined bprop for Cell, not for Primitive.
167   // Refer the comment in KUserDefinedCellBprop.
168   template <typename T>
169   FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const FuncGraphPtr &current_primal_fg,
170                         const CNodePtr &cnode, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
171                         const std::vector<NodeDebugInfoPtr> &primal_debug_infos);
172   AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg);
173   void TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
174                                  const PrimitivePtr &primitive, const FuncGraphPtr &outer,
175                                  std::vector<AnfNodePtr> *const transf_args);
176   template <typename T>
177   void TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
178                                  const T &current_primal_fg, const FuncGraphPtr &outer,
179                                  std::vector<AnfNodePtr> *const transf_args);
180   void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check);
181 
182   Registry bprop_registry_;
183   std::unordered_map<PrimitivePtr, MetaFuncGraphPtr> bprop_registry_meta_;
184 };
185 
186 template <typename T>
BpropToK(const T & primal,const FuncGraphPtr & bprop_fg,const FuncGraphPtr & current_primal_fg,const CNodePtr & cnode,const std::unordered_map<std::string,ValuePtr> & primal_attrs,const std::vector<NodeDebugInfoPtr> & primal_debug_infos)187 FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg,
188                              const CNodePtr &cnode, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
189                              const std::vector<NodeDebugInfoPtr> &primal_debug_infos) {
190   MS_EXCEPTION_IF_NULL(primal);
191   MS_EXCEPTION_IF_NULL(bprop_fg);
192   CheckBprop(bprop_fg, primal->ToString());
193 
194   FuncGraphPtr cloned_bprop_fg;
195   {
196     PrimalAttrGuard primal_attr_guard(primal_attrs);
197     PrimalDebugInfoGuard primal_debug_info_guard(primal_debug_infos);
198     cloned_bprop_fg = BasicClone(bprop_fg);
199   }
200   MS_EXCEPTION_IF_NULL(cloned_bprop_fg);
201 
202   GraphDebugInfoPtr debug_info = nullptr;
203   {
204     TraceGuard guard(std::make_shared<TraceCopy>(bprop_fg->debug_info()));
205     debug_info = std::make_shared<GraphDebugInfo>();
206   }
207   if (debug_info->trace_info() != nullptr && debug_info->trace_info()->debug_info() != nullptr) {
208     debug_info->trace_info()->debug_info()->set_name(primal->ToString());
209   }
210   cloned_bprop_fg->debug_info()->set_name("");
211   cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared<TraceGradBprop>(debug_info));
212 
213   // Make sure (out, dout) provided.
214   if (cloned_bprop_fg->parameters().size() < 2) {
215     MS_LOG(EXCEPTION)
216       << "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only "
217       << cloned_bprop_fg->parameters().size() << ".\n"
218       << trace::GetDebugInfo(cloned_bprop_fg->debug_info());
219   }
220   AnfNodePtr bout = BuildOutput(cloned_bprop_fg, current_primal_fg);
221   cloned_bprop_fg->set_output(bout);
222 
223   FuncGraphPtr outer = nullptr;
224   {
225     auto outer_debug_info = std::make_shared<GraphDebugInfo>();
226     outer_debug_info->set_name(primal->ToString());
227     TraceGuard guard(std::make_shared<TraceGradFprop>(outer_debug_info));
228     outer = std::make_shared<FuncGraph>();
229     (void)outer->transforms().emplace("primal", FuncGraphTransform(primal));
230     outer->set_output(NewValueNode(kNone));
231   }
232 
233   auto mng = Manage({cloned_bprop_fg, outer}, false);
234 
235   // In a bprop definition, the last two param should be out and dout.
236   auto param_size = cloned_bprop_fg->parameters().size();
237   auto param_num = param_size - 1;
238   auto dout = cloned_bprop_fg->parameters()[param_num];
239   param_num--;
240   auto out_param = cloned_bprop_fg->parameters()[param_num];
241 
242   std::vector<AnfNodePtr> transf_args;
243 
244   if constexpr (std::is_same<T, PrimitivePtr>::value) {
245     PrimitivePtr primitive = primal;
246     TransformArgsForPrimitive(mng, cloned_bprop_fg, primal, outer, &transf_args);
247     (void)transf_args.insert(transf_args.begin(), NewValueNode(primal));
248   } else {
249     TransformArgsForFuncGraph(mng, cloned_bprop_fg, current_primal_fg, outer, &transf_args);
250     (void)transf_args.insert(transf_args.begin(), NewValueNode(current_primal_fg));
251   }
252   CNodePtr out_value = nullptr;
253   if (cnode != nullptr) {  // Set equiv debug info. for Primitive CNode out.
254     TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info()));
255     out_value = outer->NewCNode(transf_args);
256     if constexpr (std::is_same<T, PrimitivePtr>::value) {
257       out_value->CloneCNodeInfo(cnode);
258     }
259   } else {
260     out_value = outer->NewCNode(transf_args);
261   }
262   (void)mng->Replace(out_param, out_value);
263 
264   TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info()));
265   auto new_dout = cloned_bprop_fg->add_parameter();
266   (void)mng->Replace(dout, new_dout);
267   // We remove all parameters except new_dout.
268   std::vector<AnfNodePtr> newBpropParams = {new_dout};
269   cloned_bprop_fg->set_parameters(newBpropParams);
270   outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)}));
271   return BasicClone(outer);
272 }
273 }  // namespace ad
274 }  // namespace mindspore
275 
276 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_D_FUNCTOR_H_
277