/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/grad/jit/jit_dfunctor.h"

#include <memory>
#include <string>
#include <vector>

#include "ir/func_graph_cloner.h"
#include "pipeline/pynative/pynative_utils.h"

namespace mindspore {
namespace pynative {
namespace {
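// Create a placeholder tensor whose element type and shape match the given abstract
// type/shape pair; the real forward value is attached later via set_forward.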
tensor::TensorPtr GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem) {
  MS_EXCEPTION_IF_NULL(type_elem);
  MS_EXCEPTION_IF_NULL(shape_elem);
  auto shape = shape_elem->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(shape);
  auto tensor_type = type_elem->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(tensor_type);
  auto type = tensor_type->element();
  MS_EXCEPTION_IF_NULL(type);
  return std::make_shared<tensor::Tensor>(type->type_id(), shape->shape());
}

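// Recursively build a placeholder value for the given type/shape pair: tuples are expanded
// element-wise, tensors go through GenNewTensorInner, and a NoShape element becomes either
// kNone or a scalar tensor.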
ValuePtr NewValue(const TypePtr &type_elem, const BaseShapePtr &shape_elem) {
  MS_EXCEPTION_IF_NULL(type_elem);
  MS_EXCEPTION_IF_NULL(shape_elem);
  if (shape_elem->isa<abstract::TupleShape>()) {
    auto tuple_shape = shape_elem->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    auto tuple_type = type_elem->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    size_t output_num = tuple_type->elements().size();
    std::vector<ValuePtr> value_list;
    for (size_t i = 0; i < output_num; ++i) {
      auto sub_shape_elem = tuple_shape->shape()[i];
      auto sub_type_elem = tuple_type->elements()[i];
      ValuePtr new_value = NewValue(sub_type_elem, sub_shape_elem);
      value_list.push_back(new_value);
    }
    return std::make_shared<ValueTuple>(value_list);
  }
  if (type_elem->isa<TensorType>()) {
    return GenNewTensorInner(type_elem, shape_elem);
  }
  if (shape_elem->isa<abstract::NoShape>()) {
    ShapeVector no_shape;
    if (type_elem->type_id() == kMetaTypeNone) {
      return kNone;
    }
    return std::make_shared<tensor::Tensor>(type_elem->type_id(), no_shape);
  }
  MS_LOG(INTERNAL_EXCEPTION) << "Unknown shape: " << shape_elem->ToString() << ", type: " << type_elem->ToString();
}

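// Generate a value node that stands in for the output of a forward CNode: reuse the value
// already recorded on the node if present, emit a UMonad for UpdateState nodes, and
// otherwise build placeholder values from the node's inferred type and shape.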
ValueNodePtr GenNewTensor(const CNodePtr &cnode_morph) {
  MS_EXCEPTION_IF_NULL(cnode_morph);
  if (cnode_morph->forward().first != nullptr) {
    return cnode_morph->forward().first;
  }
  if (IsPrimitiveCNode(cnode_morph, prim::kPrimUpdateState)) {
    ValueNodePtr out_vnode = NewValueNode(std::make_shared<UMonad>());
    out_vnode->set_abstract(std::make_shared<abstract::AbstractUMonad>());
    return out_vnode;
  }
  // Helper used to wrap a value into a value node with a broadened abstract.
  auto gen_output_value_node = [](const ValuePtr &value) -> ValueNodePtr {
    MS_EXCEPTION_IF_NULL(value);
    auto v_node = NewValueNode(value);
    v_node->set_abstract(value->ToAbstract()->Broaden());
    return v_node;
  };
  // Create the output value node according to the CNode's inferred type and shape.
  auto cnode_shape = cnode_morph->Shape();
  MS_EXCEPTION_IF_NULL(cnode_shape);
  auto cnode_type = cnode_morph->Type();
  MS_EXCEPTION_IF_NULL(cnode_type);
  if (cnode_type->isa<Tuple>()) {
    auto tuple_shape = cnode_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    auto tuple_type = cnode_type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    size_t output_num = tuple_type->elements().size();
    MS_EXCEPTION_IF_CHECK_FAIL(output_num != 0, "No output value.");
    std::vector<ValuePtr> output_values;
    for (size_t i = 0; i < output_num; ++i) {
      auto shape_elem = tuple_shape->shape()[i];
      auto type_elem = tuple_type->elements()[i];
      output_values.push_back(NewValue(type_elem, shape_elem));
    }
    auto value_tuple = std::make_shared<ValueTuple>(output_values);
    return gen_output_value_node(value_tuple);
  } else if (cnode_type->isa<List>()) {
    auto list_shape = cnode_shape->cast<abstract::ListShapePtr>();
    MS_EXCEPTION_IF_NULL(list_shape);
    auto list_type = cnode_type->cast<ListPtr>();
    MS_EXCEPTION_IF_NULL(list_type);
    size_t output_num = list_type->elements().size();
    MS_EXCEPTION_IF_CHECK_FAIL(output_num != 0, "No output value.");
    std::vector<ValuePtr> output_values;
    for (size_t i = 0; i < output_num; ++i) {
      auto shape_elem = list_shape->shape()[i];
      auto type_elem = list_type->elements()[i];
      output_values.push_back(NewValue(type_elem, shape_elem));
    }
    auto value_list = std::make_shared<ValueList>(output_values);
    return gen_output_value_node(value_list);
  } else if (cnode_type->isa<TensorType>()) {
    auto tensor_value = GenNewTensorInner(cnode_type, cnode_shape);
    return gen_output_value_node(tensor_value);
  } else if (cnode_shape->isa<abstract::NoShape>()) {
    ShapeVector no_shape;
    auto tensor_value = std::make_shared<tensor::Tensor>(cnode_type->type_id(), no_shape);
    return gen_output_value_node(tensor_value);
  }
  MS_LOG(INTERNAL_EXCEPTION) << "Unknown shape: " << cnode_shape->ToString() << ", type: " << cnode_type->ToString();
}

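// For a call node k_app whose first input is an fprop graph, clone the graph (so each k_app
// gets its own copy), rebind k_app to the clone, then extract the forward output CNode and
// the bprop graph from the clone's {MakeTuple, forward_output, bprop_graph} output.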
void GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node, FuncGraphPtr *bprop_graph,
                                    FuncGraphPtr *fprop_graph) {
  MS_EXCEPTION_IF_NULL(k_app);
  MS_EXCEPTION_IF_NULL(fprop_graph);
  const auto &prim = k_app->input(0);
  if (!IsValueNode<FuncGraph>(prim)) {
    return;
  }
  // Clone a new fprop graph for each k_app.
  auto original_fprop = GetValueNode<FuncGraphPtr>(prim);
  MS_EXCEPTION_IF_NULL(original_fprop);
  *fprop_graph = BasicClone(original_fprop);
  k_app->set_input(0, NewValueNode(*fprop_graph));

  // The fprop graph output has the form {prim::kPrimMakeTuple, forward_output, bprop_graph}.
  auto output = (*fprop_graph)->output();
  MS_EXCEPTION_IF_NULL(output);
  if (!output->isa<CNode>()) {
    return;
  }
  auto make_tuple_node = output->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(make_tuple_node);
  constexpr size_t input_size = 3;
  if (make_tuple_node->size() != input_size) {
    MS_LOG(INTERNAL_EXCEPTION) << "The input size of make_tuple node " << make_tuple_node->DebugString()
                               << " is not equal to " << input_size;
  }

  // Get the forward CNode.
  constexpr size_t forward_output_index = 1;
  const auto &output_node = make_tuple_node->input(forward_output_index);
  MS_EXCEPTION_IF_NULL(output_node);
  if (!output_node->isa<CNode>()) {
    return;
  }

  // Get the bprop graph of the forward CNode.
  constexpr size_t bprop_graph_index = 2;
  const auto &bprop_vnode = make_tuple_node->input(bprop_graph_index);
  if (!IsValueNode<FuncGraph>(bprop_vnode)) {
    return;
  }

  MS_EXCEPTION_IF_NULL(forward_node);
  MS_EXCEPTION_IF_NULL(bprop_graph);
  *forward_node = output_node->cast<CNodePtr>();
  *bprop_graph = GetValueNode<FuncGraphPtr>(bprop_vnode);
}

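// Replace the forward output node inside the fprop/bprop graphs with a placeholder value
// node. If the node has more than one user, it is also referenced by the bprop graph, so it
// is recorded via set_forward and returned for later value binding.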
std::vector<AnfNodePtr> RunOutputReplace(const CNodePtr &forward_node, const FuncGraphPtr &bprop_graph,
                                         const FuncGraphPtr &fprop_graph, const CNodePtr &cnode_morph) {
  MS_EXCEPTION_IF_NULL(cnode_morph);
  if (!PyNativeAlgo::GradCommon::IsRealOp(cnode_morph)) {
    return {};
  }
  // Use a manager to get the link relations among nodes.
  MS_EXCEPTION_IF_NULL(bprop_graph);
  MS_EXCEPTION_IF_NULL(fprop_graph);
  auto manager = Manage({fprop_graph, bprop_graph}, false);

  // Replace the output node.
  MS_EXCEPTION_IF_NULL(forward_node);
  auto ref_size = manager->node_users().at(forward_node).size();
  MS_LOG(DEBUG) << "Ref size: " << ref_size;
  auto output_vnode = GenNewTensor(cnode_morph);
  MS_EXCEPTION_IF_NULL(output_vnode);
  output_vnode->set_has_new_value(true);
  (void)manager->Replace(forward_node, output_vnode);
  MS_LOG(DEBUG) << "Replace: " << forward_node->DebugString() << " with " << output_vnode->ToString();

  // Save the forward output node when it is used in its bprop graph.
  std::vector<AnfNodePtr> used_forward_nodes;
  if (ref_size > 1) {
    cnode_morph->set_forward(output_vnode, "");
    used_forward_nodes.push_back(cnode_morph);
    MS_LOG(DEBUG) << "Node has been used in grad graph: " << cnode_morph->DebugString()
                  << ", its output value: " << output_vnode->ToString();
  }
  return used_forward_nodes;
}

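// Replace the fprop graph parameters that correspond to real forward CNode inputs with
// placeholder value nodes, and record the inputs that are still referenced in the bprop
// graph so their forward values can be attached later.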
std::vector<AnfNodePtr> RunInputReplace(const FuncGraphPtr &bprop_graph, const FuncGraphPtr &fprop_graph,
                                        const CNodePtr &cnode_morph) {
  MS_EXCEPTION_IF_NULL(cnode_morph);
  if (!PyNativeAlgo::GradCommon::IsRealOp(cnode_morph)) {
    return {};
  }
  // Use a manager to get the link relations among nodes.
  MS_EXCEPTION_IF_NULL(bprop_graph);
  MS_EXCEPTION_IF_NULL(fprop_graph);
  auto manager = Manage({fprop_graph, bprop_graph}, false);

  const auto &paras = fprop_graph->parameters();
  if (cnode_morph->size() - 1 != paras.size() && !IsPrimitiveCNode(cnode_morph, prim::kPrimUpdateState)) {
    MS_LOG(EXCEPTION) << "The parameter size of the fprop graph: " << paras.size()
                      << " is not equal to the input tensor size of the forward node: " << cnode_morph->size() - 1;
  }

  std::vector<AnfNodePtr> used_input_nodes;
  for (size_t i = 0; i < paras.size(); ++i) {
    const auto &input_node = cnode_morph->input(i + 1);
    MS_EXCEPTION_IF_NULL(input_node);
    // Parameter, ValueNode and StopGradient CNode inputs do not need to be replaced.
    if (input_node->isa<Parameter>() || input_node->isa<ValueNode>() ||
        !PyNativeAlgo::GradCommon::IsRealOp(input_node)) {
      continue;
    }
    // Replace the forward input node by its output value.
    auto para_ref_size = manager->node_users()[paras[i]].size();
    CNodePtr cnode_i = input_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode_i);
    auto output_vnode_i = GenNewTensor(cnode_i);
    MS_EXCEPTION_IF_NULL(output_vnode_i);
    output_vnode_i->set_has_new_value(true);
    (void)manager->Replace(paras[i], output_vnode_i);
    if (IsPrimitiveCNode(cnode_i, prim::kPrimLoad)) {
      para_ref_size += 1;
    }
    MS_LOG(DEBUG) << "Replace: " << paras[i]->DebugString() << " with " << output_vnode_i->ToString();
    // Save the forward input node when it is used in the bprop graph.
    if (para_ref_size > 0 && !IsPrimitiveCNode(input_node, prim::kPrimUpdateState)) {
      cnode_i->set_forward(output_vnode_i, "");
      used_input_nodes.push_back(cnode_i);
      MS_LOG(DEBUG) << "Input CNode has been used in grad graph: " << cnode_i->DebugString()
                    << ", its output value: " << output_vnode_i->ToString();
    }
  }

  return used_input_nodes;
}
}  // namespace

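// Entry point: extract the forward CNode and its fprop/bprop graphs from k_app, replace the
// forward output and input nodes used by the bprop graph with placeholder values, and
// register the used nodes on the forward func_graph so their real values can be bound later.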
void ReplaceEquivOut(const CNodePtr &k_app, const CNodePtr &cnode_morph) {
  MS_EXCEPTION_IF_NULL(cnode_morph);
  MS_LOG(DEBUG) << "Run replace for cnode morph: " << cnode_morph->DebugString();
  // Get the forward node together with its fprop graph and bprop graph.
  MS_EXCEPTION_IF_NULL(k_app);
  CNodePtr forward_node = nullptr;
  FuncGraphPtr bprop_graph = nullptr;
  FuncGraphPtr fprop_graph = nullptr;
  GetForwardOutNodeAndBpropGraph(k_app, &forward_node, &bprop_graph, &fprop_graph);
  if (forward_node == nullptr || bprop_graph == nullptr || fprop_graph == nullptr) {
    return;
  }

  // Replace the forward node used in the bprop graph by its output tensors; do the same for
  // its input nodes.
  auto used_forward_nodes = RunOutputReplace(forward_node, bprop_graph, fprop_graph, cnode_morph);
  auto used_input_nodes = RunInputReplace(bprop_graph, fprop_graph, cnode_morph);

  // Save the used forward output and input nodes to the func_graph. Both calls are kept on
  // the assumption that set_used_forward_nodes accumulates nodes rather than overwriting.
  auto ms_func_graph = cnode_morph->func_graph();
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  ms_func_graph->set_used_forward_nodes(used_forward_nodes);
  ms_func_graph->set_used_forward_nodes(used_input_nodes);
}
}  // namespace pynative
}  // namespace mindspore