/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/ad/pynative_dfunctor.h"

#include <memory>
#include <vector>

namespace mindspore {
namespace ad {
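// Create a placeholder tensor whose dtype and shape match the given tensor type and shape.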
tensor::TensorPtr PynativeDFunctor::GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem) {
  MS_EXCEPTION_IF_NULL(type_elem);
  MS_EXCEPTION_IF_NULL(shape_elem);
  auto shape = shape_elem->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(shape);
  auto tensor_type = type_elem->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(tensor_type);
  auto type = tensor_type->element();
  MS_EXCEPTION_IF_NULL(type);
  return std::make_shared<tensor::Tensor>(type->type_id(), shape->shape());
}

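// Generate a placeholder output value node for cnode_morph: reuse the recorded forward value if present,
// otherwise build a UMonad, a tensor, or a tuple of tensors from the node's inferred type and shape.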
ValueNodePtr PynativeDFunctor::GenNewTensor(const CNodePtr &cnode_morph) {
  MS_EXCEPTION_IF_NULL(cnode_morph);
  if (cnode_morph->forward().first != nullptr) {
    return cnode_morph->forward().first;
  }
  if (IsPrimitiveCNode(cnode_morph, prim::kPrimUpdateState)) {
    ValueNodePtr out_vnode = NewValueNode(std::make_shared<UMonad>());
    out_vnode->set_abstract(std::make_shared<abstract::AbstractUMonad>());
    return out_vnode;
  }

  auto cnode_shape = cnode_morph->Shape();
  MS_EXCEPTION_IF_NULL(cnode_shape);
  auto cnode_type = cnode_morph->Type();
  MS_EXCEPTION_IF_NULL(cnode_type);
  // Create output values.
  if (cnode_type->isa<Tuple>()) {
    auto tuple_shape = cnode_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    auto tuple_type = cnode_type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    size_t output_num = tuple_type->elements().size();
    std::vector<ValuePtr> output_values;
    for (size_t i = 0; i < output_num; ++i) {
      auto shape_elem = tuple_shape->shape()[i];
      auto type_elem = tuple_type->elements()[i];
      output_values.push_back(GenNewTensorInner(type_elem, shape_elem));
    }
    if (output_values.empty()) {
      MS_LOG(EXCEPTION) << "The output values are empty, cnode morph: " << cnode_morph->DebugString();
    }
    return NewValueNode(std::make_shared<ValueTuple>(output_values));
  } else if (cnode_type->isa<TensorType>()) {
    return NewValueNode(GenNewTensorInner(cnode_type, cnode_shape));
  } else if (cnode_shape->isa<abstract::NoShape>()) {
    ShapeVector no_shape;
    return NewValueNode(std::make_shared<tensor::Tensor>(cnode_type->type_id(), no_shape));
  }

  MS_LOG(EXCEPTION) << "Unknown shape: " << cnode_shape->ToString() << ", type: " << cnode_type->ToString();
}

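// Clone the fprop graph referenced by k_app and extract the forward output CNode and the bprop graph
// from its {forward_output, bprop_graph} make_tuple output.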
void PynativeDFunctor::GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node,
                                                      FuncGraphPtr *bprop_graph, FuncGraphPtr *fprop_graph) {
  MS_EXCEPTION_IF_NULL(k_app);
  MS_EXCEPTION_IF_NULL(fprop_graph);
  const auto &prim = k_app->input(0);
  if (!IsValueNode<FuncGraph>(prim)) {
    return;
  }
  // Clone a new fprop graph for different k_app.
  auto original_fprop = GetValueNode<FuncGraphPtr>(prim);
  MS_EXCEPTION_IF_NULL(original_fprop);
  *fprop_graph = BasicClone(original_fprop);
  k_app->set_input(0, NewValueNode(*fprop_graph));

  // {prim::maketuple, forward_output, bprop_graph}
  auto output = (*fprop_graph)->output();
  MS_EXCEPTION_IF_NULL(output);
  if (!output->isa<CNode>()) {
    return;
  }
  auto make_tuple_node = output->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(make_tuple_node);
  constexpr size_t input_size = 3;
  if (make_tuple_node->size() != input_size) {
    MS_LOG(EXCEPTION) << "The number of inputs of the make tuple node " << make_tuple_node->DebugString()
                      << " is not equal to " << input_size;
  }

  // Get forward CNode.
  const size_t forward_output_index = 1;
  const auto &output_node = make_tuple_node->input(forward_output_index);
  MS_EXCEPTION_IF_NULL(output_node);
  if (!output_node->isa<CNode>()) {
    return;
  }

  // Get bprop graph of forward CNode.
  const size_t bprop_graph_index = 2;
  const auto &bprop_vnode = make_tuple_node->input(bprop_graph_index);
  if (!IsValueNode<FuncGraph>(bprop_vnode)) {
    return;
  }

  MS_EXCEPTION_IF_NULL(forward_node);
  MS_EXCEPTION_IF_NULL(bprop_graph);
  *forward_node = output_node->cast<CNodePtr>();
  *bprop_graph = GetValueNode<FuncGraphPtr>(bprop_vnode);
}

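// Replace the forward output node in the cloned fprop/bprop graphs with a placeholder value node, and
// return cnode_morph if its output value is still referenced by the bprop graph.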
std::vector<AnfNodePtr> PynativeDFunctor::RunOutputReplace(const CNodePtr &forward_node,
                                                           const FuncGraphPtr &bprop_graph,
                                                           const FuncGraphPtr &fprop_graph,
                                                           const CNodePtr &cnode_morph) {
  MS_EXCEPTION_IF_NULL(cnode_morph);
  if (IsPrimitiveCNode(cnode_morph, prim::kPrimStopGradient)) {
    return {};
  }
  // Use manager to get the link relation among nodes.
  MS_EXCEPTION_IF_NULL(bprop_graph);
  MS_EXCEPTION_IF_NULL(fprop_graph);
  auto manager = Manage({fprop_graph, bprop_graph}, false);

  // Replace output node.
  MS_EXCEPTION_IF_NULL(forward_node);
  auto ref_size = manager->node_users().at(forward_node).size();
  MS_LOG(DEBUG) << "Ref size: " << ref_size;
  auto output_vnode = GenNewTensor(cnode_morph);
  MS_EXCEPTION_IF_NULL(output_vnode);
  output_vnode->set_has_new_value(true);
  manager->Replace(forward_node, output_vnode);
  MS_LOG(DEBUG) << "Replace: " << forward_node->DebugString() << " with " << output_vnode->ToString();

  // Save the forward output node when it is used in its bprop graph.
  std::vector<AnfNodePtr> used_forward_nodes;
  constexpr size_t ref_twice = 2;
  if (ref_size >= ref_twice) {
    cnode_morph->set_forward(output_vnode, "");
    used_forward_nodes.push_back(cnode_morph);
    MS_LOG(DEBUG) << "node has been used in grad graph: " << cnode_morph->DebugString()
                  << ", its output value: " << output_vnode->ToString();
  }
  return used_forward_nodes;
}

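// Replace the fprop graph parameters that correspond to forward input CNodes with placeholder value
// nodes, and return the input CNodes whose output values are used in the bprop graph.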
std::vector<AnfNodePtr> PynativeDFunctor::RunInputReplace(const FuncGraphPtr &bprop_graph,
                                                          const FuncGraphPtr &fprop_graph,
                                                          const CNodePtr &cnode_morph) {
  // Use manager to get the link relation among nodes.
  MS_EXCEPTION_IF_NULL(bprop_graph);
  MS_EXCEPTION_IF_NULL(fprop_graph);
  auto manager = Manage({fprop_graph, bprop_graph}, false);

  MS_EXCEPTION_IF_NULL(cnode_morph);
  const auto &paras = fprop_graph->parameters();
  if (cnode_morph->size() - 1 != paras.size() && !IsPrimitiveCNode(cnode_morph, prim::kPrimUpdateState)) {
    MS_LOG(EXCEPTION) << "The number of parameters of the fprop graph is " << paras.size()
                      << ", but the number of input tensors of the forward node is "
                      << cnode_morph->inputs().size() - 1;
  }

  std::vector<AnfNodePtr> used_input_nodes;
  for (size_t i = 0; i < paras.size(); ++i) {
    const auto &input_node = cnode_morph->input(i + 1);
    MS_EXCEPTION_IF_NULL(input_node);
    // Parameter, ValueNode and StopGradient CNode inputs do not need to be replaced.
    if (input_node->isa<Parameter>() || input_node->isa<ValueNode>() ||
        IsPrimitiveCNode(input_node, prim::kPrimStopGradient)) {
      continue;
    }
    // Replace forward input node by its output value.
    auto para_ref_size = manager->node_users()[paras[i]].size();
    CNodePtr cnode_i = input_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode_i);
    auto output_vnode_i = GenNewTensor(cnode_i);
    MS_EXCEPTION_IF_NULL(output_vnode_i);
    output_vnode_i->set_has_new_value(true);
    manager->Replace(paras[i], output_vnode_i);
    MS_LOG(DEBUG) << "Replace: " << paras[i]->DebugString() << " with " << output_vnode_i->ToString();
    // Save the forward input node when it is used in the bprop graph.
    if (para_ref_size > 0 && !IsPrimitiveCNode(input_node, prim::kPrimUpdateState)) {
      cnode_i->set_forward(output_vnode_i, "");
      used_input_nodes.push_back(cnode_i);
      MS_LOG(DEBUG) << "Input CNode has been used in grad graph: " << cnode_i->DebugString()
                    << ", its output value: " << output_vnode_i->ToString();
    }
  }

  return used_input_nodes;
}

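// Replace the forward node and its inputs used in the bprop graph with placeholder output values, and
// record the used forward nodes on the ms_function graph.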
void PynativeDFunctor::ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph) {
  // Replacing forward nodes only takes effect in PyNative mode when @ms_function is used.
  MS_EXCEPTION_IF_NULL(cnode_morph);
  MS_LOG(DEBUG) << "Run replace for cnode morph: " << cnode_morph->DebugString(2);
  // Get forward node and its fprop graph, bprop graph.
  MS_EXCEPTION_IF_NULL(k_app);
  CNodePtr forward_node = nullptr;
  FuncGraphPtr bprop_graph = nullptr;
  FuncGraphPtr fprop_graph = nullptr;
  GetForwardOutNodeAndBpropGraph(k_app, &forward_node, &bprop_graph, &fprop_graph);
  if (forward_node == nullptr || bprop_graph == nullptr || fprop_graph == nullptr) {
    return;
  }

  // Replace the forward node used in the bprop graph with its output tensors. The same is done for its input nodes.
  auto used_forward_nodes = RunOutputReplace(forward_node, bprop_graph, fprop_graph, cnode_morph);
  auto used_input_nodes = RunInputReplace(bprop_graph, fprop_graph, cnode_morph);

  // Save used forward input and output nodes to func_graph.
  auto ms_func_graph = cnode_morph->func_graph();
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  ms_func_graph->set_used_forward_nodes(used_forward_nodes);
  ms_func_graph->set_used_forward_nodes(used_input_nodes);
}
}  // namespace ad
}  // namespace mindspore