• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/grad/jit/jit_dfunctor.h"
18 
19 #include <memory>
20 #include <string>
21 
22 #include "ir/func_graph_cloner.h"
23 #include "pipeline/pynative/pynative_utils.h"
24 
25 namespace mindspore {
26 namespace pynative {
27 namespace {
GenNewTensorInner(const TypePtr & type_elem,const BaseShapePtr & shape_elem)28 tensor::TensorPtr GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem) {
29   MS_EXCEPTION_IF_NULL(type_elem);
30   MS_EXCEPTION_IF_NULL(shape_elem);
31   auto shape = shape_elem->cast<abstract::ShapePtr>();
32   MS_EXCEPTION_IF_NULL(shape);
33   auto tensor_type = type_elem->cast<TensorTypePtr>();
34   MS_EXCEPTION_IF_NULL(tensor_type);
35   auto type = tensor_type->element();
36   MS_EXCEPTION_IF_NULL(type);
37   return std::make_shared<tensor::Tensor>(type->type_id(), shape->shape());
38 }
39 
NewValue(const TypePtr & type_elem,const BaseShapePtr & shape_elem)40 ValuePtr NewValue(const TypePtr &type_elem, const BaseShapePtr &shape_elem) {
41   MS_EXCEPTION_IF_NULL(type_elem);
42   MS_EXCEPTION_IF_NULL(shape_elem);
43   if (shape_elem->isa<abstract::TupleShape>()) {
44     auto tuple_shape = shape_elem->cast<abstract::TupleShapePtr>();
45     MS_EXCEPTION_IF_NULL(tuple_shape);
46     auto tuple_type = type_elem->cast<TuplePtr>();
47     MS_EXCEPTION_IF_NULL(tuple_type);
48     size_t output_num = tuple_type->elements().size();
49     std::vector<ValuePtr> value_list;
50     for (size_t i = 0; i < output_num; ++i) {
51       auto sub_shape_elem = tuple_shape->shape()[i];
52       auto sub_type_elem = tuple_type->elements()[i];
53       ValuePtr new_value = NewValue(sub_type_elem, sub_shape_elem);
54       value_list.push_back(new_value);
55     }
56     return std::make_shared<ValueTuple>(value_list);
57   }
58   if (type_elem->isa<TensorType>()) {
59     return GenNewTensorInner(type_elem, shape_elem);
60   }
61   if (shape_elem->isa<abstract::NoShape>()) {
62     ShapeVector NoShape;
63     if (type_elem->type_id() == kMetaTypeNone) {
64       return kNone;
65     }
66     return std::make_shared<tensor::Tensor>(type_elem->type_id(), NoShape);
67   }
68   MS_LOG(INTERNAL_EXCEPTION) << "Unknown shape: " << shape_elem->ToString() << ", type: " << type_elem->ToString();
69 }
70 
GenNewTensor(const CNodePtr & cnode_morph)71 ValueNodePtr GenNewTensor(const CNodePtr &cnode_morph) {
72   MS_EXCEPTION_IF_NULL(cnode_morph);
73   if (cnode_morph->forward().first != nullptr) {
74     return cnode_morph->forward().first;
75   }
76   if (IsPrimitiveCNode(cnode_morph, prim::kPrimUpdateState)) {
77     ValueNodePtr out_vnode = NewValueNode(std::make_shared<UMonad>());
78     out_vnode->set_abstract(std::make_shared<abstract::AbstractUMonad>());
79     return out_vnode;
80   }
81   // Function used to generate value node
82   auto gen_output_value_node = [](const ValuePtr &value) -> ValueNodePtr {
83     MS_EXCEPTION_IF_NULL(value);
84     auto v_node = NewValueNode(value);
85     v_node->set_abstract(value->ToAbstract()->Broaden());
86     return v_node;
87   };
88   // Create output value node for CNode
89   auto cnode_shape = cnode_morph->Shape();
90   MS_EXCEPTION_IF_NULL(cnode_shape);
91   auto cnode_type = cnode_morph->Type();
92   MS_EXCEPTION_IF_NULL(cnode_type);
93   if (cnode_type->isa<Tuple>()) {
94     auto tuple_shape = cnode_shape->cast<abstract::TupleShapePtr>();
95     MS_EXCEPTION_IF_NULL(tuple_shape);
96     auto tuple_type = cnode_type->cast<TuplePtr>();
97     MS_EXCEPTION_IF_NULL(tuple_type);
98     size_t output_num = tuple_type->elements().size();
99     MS_EXCEPTION_IF_CHECK_FAIL(output_num != 0, "No output value.");
100     std::vector<ValuePtr> output_values;
101     for (size_t i = 0; i < output_num; ++i) {
102       auto shape_elem = tuple_shape->shape()[i];
103       auto type_elem = tuple_type->elements()[i];
104       output_values.push_back(NewValue(type_elem, shape_elem));
105     }
106     auto value_tuple = std::make_shared<ValueTuple>(output_values);
107     return gen_output_value_node(value_tuple);
108   } else if (cnode_type->isa<List>()) {
109     auto list_shape = cnode_shape->cast<abstract::ListShapePtr>();
110     MS_EXCEPTION_IF_NULL(list_shape);
111     auto list_type = cnode_type->cast<ListPtr>();
112     MS_EXCEPTION_IF_NULL(list_type);
113     size_t output_num = list_type->elements().size();
114     MS_EXCEPTION_IF_CHECK_FAIL(output_num != 0, "No output value.");
115     std::vector<ValuePtr> output_values;
116     for (size_t i = 0; i < output_num; ++i) {
117       auto shape_elem = list_shape->shape()[i];
118       auto type_elem = list_type->elements()[i];
119       output_values.push_back(NewValue(type_elem, shape_elem));
120     }
121     auto value_tuple = std::make_shared<ValueList>(output_values);
122     return gen_output_value_node(value_tuple);
123   } else if (cnode_type->isa<TensorType>()) {
124     auto tensor_value = GenNewTensorInner(cnode_type, cnode_shape);
125     return gen_output_value_node(tensor_value);
126   } else if (cnode_shape->isa<abstract::NoShape>()) {
127     ShapeVector NoShape;
128     auto tensor_value = std::make_shared<tensor::Tensor>(cnode_type->type_id(), NoShape);
129     return gen_output_value_node(tensor_value);
130   }
131   MS_LOG(INTERNAL_EXCEPTION) << "Unknown shape: " << cnode_shape->ToString() << ", type: " << cnode_type->ToString();
132 }
133 
GetForwardOutNodeAndBpropGraph(const CNodePtr & k_app,CNodePtr * forward_node,FuncGraphPtr * bprop_graph,FuncGraphPtr * fprop_graph)134 void GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node, FuncGraphPtr *bprop_graph,
135                                     FuncGraphPtr *fprop_graph) {
136   MS_EXCEPTION_IF_NULL(k_app);
137   MS_EXCEPTION_IF_NULL(fprop_graph);
138   const auto &prim = k_app->input(0);
139   if (!IsValueNode<FuncGraph>(prim)) {
140     return;
141   }
142   // Clone a new fprop graph for different k_app.
143   auto original_fprop = GetValueNode<FuncGraphPtr>(prim);
144   MS_EXCEPTION_IF_NULL(original_fprop);
145   *fprop_graph = BasicClone(original_fprop);
146   k_app->set_input(0, NewValueNode(*fprop_graph));
147 
148   // {prim::maketuple, forward_output, bprop_graph}
149   auto output = (*fprop_graph)->output();
150   MS_EXCEPTION_IF_NULL(output);
151   if (!output->isa<CNode>()) {
152     return;
153   }
154   auto make_tuple_node = output->cast<CNodePtr>();
155   MS_EXCEPTION_IF_NULL(make_tuple_node);
156   constexpr size_t input_size = 3;
157   if (make_tuple_node->size() != input_size) {
158     MS_LOG(INTERNAL_EXCEPTION) << "The inputs size of make tuple node " << make_tuple_node->DebugString()
159                                << " is not equal to " << input_size;
160   }
161 
162   // Get forward CNode.
163   constexpr size_t forward_output_index = 1;
164   const auto &output_node = make_tuple_node->input(forward_output_index);
165   MS_EXCEPTION_IF_NULL(output_node);
166   if (!output_node->isa<CNode>()) {
167     return;
168   }
169 
170   // Get bprop graph of forward CNode.
171   constexpr size_t bprop_graph_index = 2;
172   const auto &bprop_vnode = make_tuple_node->input(bprop_graph_index);
173   if (!IsValueNode<FuncGraph>(bprop_vnode)) {
174     return;
175   }
176 
177   MS_EXCEPTION_IF_NULL(forward_node);
178   MS_EXCEPTION_IF_NULL(bprop_graph);
179   *forward_node = output_node->cast<CNodePtr>();
180   *bprop_graph = GetValueNode<FuncGraphPtr>(bprop_vnode);
181 }
182 
RunOutputReplace(const CNodePtr & forward_node,const FuncGraphPtr & bprop_graph,const FuncGraphPtr & fprop_graph,const CNodePtr & cnode_morph)183 std::vector<AnfNodePtr> RunOutputReplace(const CNodePtr &forward_node, const FuncGraphPtr &bprop_graph,
184                                          const FuncGraphPtr &fprop_graph, const CNodePtr &cnode_morph) {
185   MS_EXCEPTION_IF_NULL(cnode_morph);
186   if (!PyNativeAlgo::GradCommon::IsRealOp(cnode_morph)) {
187     return {};
188   }
189   // Use manager to get the link relation among nodes.
190   MS_EXCEPTION_IF_NULL(bprop_graph);
191   MS_EXCEPTION_IF_NULL(fprop_graph);
192   auto manager = Manage({fprop_graph, bprop_graph}, false);
193 
194   // Replace output node.
195   MS_EXCEPTION_IF_NULL(forward_node);
196   auto ref_size = manager->node_users().at(forward_node).size();
197   MS_LOG(DEBUG) << "Ref size: " << ref_size;
198   auto output_vnode = GenNewTensor(cnode_morph);
199   MS_EXCEPTION_IF_NULL(output_vnode);
200   output_vnode->set_has_new_value(true);
201   (void)manager->Replace(forward_node, output_vnode);
202   MS_LOG(DEBUG) << "Replace: " << forward_node->DebugString() << " with " << output_vnode->ToString();
203 
204   // Save forward output node when it used in its bprop graph.
205   std::vector<AnfNodePtr> used_forward_nodes;
206   if (ref_size > 1) {
207     cnode_morph->set_forward(output_vnode, "");
208     used_forward_nodes.push_back(cnode_morph);
209     MS_LOG(DEBUG) << "node has been used in grad graph: " << cnode_morph->DebugString()
210                   << ", its output value: " << output_vnode->ToString();
211   }
212   return used_forward_nodes;
213 }
214 
RunInputReplace(const FuncGraphPtr & bprop_graph,const FuncGraphPtr & fprop_graph,const CNodePtr & cnode_morph)215 std::vector<AnfNodePtr> RunInputReplace(const FuncGraphPtr &bprop_graph, const FuncGraphPtr &fprop_graph,
216                                         const CNodePtr &cnode_morph) {
217   MS_EXCEPTION_IF_NULL(cnode_morph);
218   if (!PyNativeAlgo::GradCommon::IsRealOp(cnode_morph)) {
219     return {};
220   }
221   // Use manager to get the link relation among nodes.
222   MS_EXCEPTION_IF_NULL(bprop_graph);
223   MS_EXCEPTION_IF_NULL(fprop_graph);
224   auto manager = Manage({fprop_graph, bprop_graph}, false);
225 
226   const auto &paras = fprop_graph->parameters();
227   if (cnode_morph->size() - 1 != paras.size() && !IsPrimitiveCNode(cnode_morph, prim::kPrimUpdateState)) {
228     MS_LOG(EXCEPTION) << "The size of parameters in fprop graph:" << paras.size()
229                       << ", but the size of input tensors of forward node: " << cnode_morph->size() - 1;
230   }
231 
232   std::vector<AnfNodePtr> used_input_nodes;
233   for (size_t i = 0; i < paras.size(); ++i) {
234     const auto &input_node = cnode_morph->input(i + 1);
235     MS_EXCEPTION_IF_NULL(input_node);
236     // Parameter, ValueNode and StopGradient CNode no need to replace.
237     if (input_node->isa<Parameter>() || input_node->isa<ValueNode>() ||
238         !PyNativeAlgo::GradCommon::IsRealOp(input_node)) {
239       continue;
240     }
241     // Replace forward input node by its output value.
242     auto para_ref_size = manager->node_users()[paras[i]].size();
243     CNodePtr cnode_i = input_node->cast<CNodePtr>();
244     MS_EXCEPTION_IF_NULL(cnode_i);
245     auto output_vnode_i = GenNewTensor(cnode_i);
246     MS_EXCEPTION_IF_NULL(output_vnode_i);
247     output_vnode_i->set_has_new_value(true);
248     (void)manager->Replace(paras[i], output_vnode_i);
249     if (IsPrimitiveCNode(cnode_i, prim::kPrimLoad)) {
250       para_ref_size += 1;
251     }
252     MS_LOG(DEBUG) << "Replace: " << paras[i]->DebugString() << " with " << output_vnode_i->ToString();
253     // Save forward input node when it used in bprop graph.
254     if (para_ref_size > 0 && !IsPrimitiveCNode(input_node, prim::kPrimUpdateState)) {
255       cnode_i->set_forward(output_vnode_i, "");
256       used_input_nodes.push_back(cnode_i);
257       MS_LOG(DEBUG) << "Input CNode has been used in grad graph: " << cnode_i->DebugString()
258                     << ", its output value: " << output_vnode_i->ToString();
259     }
260   }
261 
262   return used_input_nodes;
263 }
264 }  // namespace
265 
ReplaceEquivOut(const CNodePtr & k_app,const CNodePtr & cnode_morph)266 void ReplaceEquivOut(const CNodePtr &k_app, const CNodePtr &cnode_morph) {
267   MS_EXCEPTION_IF_NULL(cnode_morph);
268   MS_LOG(DEBUG) << "Run replace for cnode morph: " << cnode_morph->DebugString();
269   // Get forward node and its fprop graph, bprop graph.
270   MS_EXCEPTION_IF_NULL(k_app);
271   CNodePtr forward_node = nullptr;
272   FuncGraphPtr bprop_graph = nullptr;
273   FuncGraphPtr fprop_graph = nullptr;
274   GetForwardOutNodeAndBpropGraph(k_app, &forward_node, &bprop_graph, &fprop_graph);
275   if (forward_node == nullptr || bprop_graph == nullptr || fprop_graph == nullptr) {
276     return;
277   }
278 
279   // Replace forward node used in bprop graph by its output tensors. The same process for its input node.
280   auto used_forward_nodes = RunOutputReplace(forward_node, bprop_graph, fprop_graph, cnode_morph);
281   auto used_input_nodes = RunInputReplace(bprop_graph, fprop_graph, cnode_morph);
282 
283   // Save used forward input and output nodes to func_graph.
284   auto ms_func_graph = cnode_morph->func_graph();
285   MS_EXCEPTION_IF_NULL(ms_func_graph);
286   ms_func_graph->set_used_forward_nodes(used_forward_nodes);
287   ms_func_graph->set_used_forward_nodes(used_input_nodes);
288 }
289 }  // namespace pynative
290 }  // namespace mindspore
291