1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/ad/pynative_dfunctor.h"
18
19 #include <memory>
20 #include <vector>
21
22 namespace mindspore {
23 namespace ad {
GenNewTensorInner(const TypePtr & type_elem,const BaseShapePtr & shape_elem)24 tensor::TensorPtr PynativeDFunctor::GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem) {
25 MS_EXCEPTION_IF_NULL(type_elem);
26 MS_EXCEPTION_IF_NULL(shape_elem);
27 auto shape = shape_elem->cast<abstract::ShapePtr>();
28 MS_EXCEPTION_IF_NULL(shape);
29 auto tensor_type = type_elem->cast<TensorTypePtr>();
30 MS_EXCEPTION_IF_NULL(tensor_type);
31 auto type = tensor_type->element();
32 MS_EXCEPTION_IF_NULL(type);
33 return std::make_shared<tensor::Tensor>(type->type_id(), shape->shape());
34 }
35
GenNewTensor(const CNodePtr & cnode_morph)36 ValueNodePtr PynativeDFunctor::GenNewTensor(const CNodePtr &cnode_morph) {
37 MS_EXCEPTION_IF_NULL(cnode_morph);
38 if (cnode_morph->forward().first != nullptr) {
39 return cnode_morph->forward().first;
40 }
41 if (IsPrimitiveCNode(cnode_morph, prim::kPrimUpdateState)) {
42 ValueNodePtr out_vnode = NewValueNode(std::make_shared<UMonad>());
43 out_vnode->set_abstract(std::make_shared<abstract::AbstractUMonad>());
44 return out_vnode;
45 }
46
47 auto cnode_shape = cnode_morph->Shape();
48 MS_EXCEPTION_IF_NULL(cnode_shape);
49 auto cnode_type = cnode_morph->Type();
50 MS_EXCEPTION_IF_NULL(cnode_type);
51 // Create output values.
52 if (cnode_type->isa<Tuple>()) {
53 auto tuple_shape = cnode_shape->cast<abstract::TupleShapePtr>();
54 MS_EXCEPTION_IF_NULL(tuple_shape);
55 auto tuple_type = cnode_type->cast<TuplePtr>();
56 MS_EXCEPTION_IF_NULL(tuple_type);
57 size_t output_num = tuple_type->elements().size();
58 std::vector<ValuePtr> output_values;
59 for (size_t i = 0; i < output_num; ++i) {
60 auto shape_elem = tuple_shape->shape()[i];
61 auto type_elem = tuple_type->elements()[i];
62 output_values.push_back(GenNewTensorInner(type_elem, shape_elem));
63 }
64 if (output_values.empty()) {
65 MS_LOG(EXCEPTION) << "The output values is empty, cnode morph: " << cnode_morph->DebugString();
66 }
67 return NewValueNode(std::make_shared<ValueTuple>(output_values));
68 } else if (cnode_type->isa<TensorType>()) {
69 return NewValueNode(GenNewTensorInner(cnode_type, cnode_shape));
70 } else if (cnode_shape->isa<abstract::NoShape>()) {
71 ShapeVector NoShape;
72 return NewValueNode(std::make_shared<tensor::Tensor>(cnode_type->type_id(), NoShape));
73 }
74
75 MS_LOG(EXCEPTION) << "Unknown shape: " << cnode_shape->ToString() << ", type: " << cnode_type->ToString();
76 }
77
GetForwardOutNodeAndBpropGraph(const CNodePtr & k_app,CNodePtr * forward_node,FuncGraphPtr * bprop_graph,FuncGraphPtr * fprop_graph)78 void PynativeDFunctor::GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node,
79 FuncGraphPtr *bprop_graph, FuncGraphPtr *fprop_graph) {
80 MS_EXCEPTION_IF_NULL(k_app);
81 MS_EXCEPTION_IF_NULL(fprop_graph);
82 const auto &prim = k_app->input(0);
83 if (!IsValueNode<FuncGraph>(prim)) {
84 return;
85 }
86 // Clone a new fprop graph for different k_app.
87 auto original_fprop = GetValueNode<FuncGraphPtr>(prim);
88 MS_EXCEPTION_IF_NULL(original_fprop);
89 *fprop_graph = BasicClone(original_fprop);
90 k_app->set_input(0, NewValueNode(*fprop_graph));
91
92 // {prim::maketuple, forward_output, bprop_graph}
93 auto output = (*fprop_graph)->output();
94 MS_EXCEPTION_IF_NULL(output);
95 if (!output->isa<CNode>()) {
96 return;
97 }
98 auto make_tuple_node = output->cast<CNodePtr>();
99 MS_EXCEPTION_IF_NULL(make_tuple_node);
100 constexpr size_t input_size = 3;
101 if (make_tuple_node->size() != input_size) {
102 MS_LOG(EXCEPTION) << "The inputs size of make tuple node " << make_tuple_node->DebugString() << " is not equal to "
103 << input_size;
104 }
105
106 // Get forward CNode.
107 const size_t forward_output_index = 1;
108 const auto &output_node = make_tuple_node->input(forward_output_index);
109 MS_EXCEPTION_IF_NULL(output_node);
110 if (!output_node->isa<CNode>()) {
111 return;
112 }
113
114 // Get bprop graph of forward CNode.
115 const size_t bprop_graph_index = 2;
116 const auto &bprop_vnode = make_tuple_node->input(bprop_graph_index);
117 if (!IsValueNode<FuncGraph>(bprop_vnode)) {
118 return;
119 }
120
121 MS_EXCEPTION_IF_NULL(forward_node);
122 MS_EXCEPTION_IF_NULL(bprop_graph);
123 *forward_node = output_node->cast<CNodePtr>();
124 *bprop_graph = GetValueNode<FuncGraphPtr>(bprop_vnode);
125 }
126
RunOutputReplace(const CNodePtr & forward_node,const FuncGraphPtr & bprop_graph,const FuncGraphPtr & fprop_graph,const CNodePtr & cnode_morph)127 std::vector<AnfNodePtr> PynativeDFunctor::RunOutputReplace(const CNodePtr &forward_node,
128 const FuncGraphPtr &bprop_graph,
129 const FuncGraphPtr &fprop_graph,
130 const CNodePtr &cnode_morph) {
131 MS_EXCEPTION_IF_NULL(cnode_morph);
132 if (IsPrimitiveCNode(cnode_morph, prim::kPrimStopGradient)) {
133 return {};
134 }
135 // Use manager to get the link relation among nodes.
136 MS_EXCEPTION_IF_NULL(bprop_graph);
137 MS_EXCEPTION_IF_NULL(fprop_graph);
138 auto manager = Manage({fprop_graph, bprop_graph}, false);
139
140 // Replace output node.
141 MS_EXCEPTION_IF_NULL(forward_node);
142 auto ref_size = manager->node_users().at(forward_node).size();
143 MS_LOG(DEBUG) << "Ref size: " << ref_size;
144 auto output_vnode = GenNewTensor(cnode_morph);
145 MS_EXCEPTION_IF_NULL(output_vnode);
146 output_vnode->set_has_new_value(true);
147 manager->Replace(forward_node, output_vnode);
148 MS_LOG(DEBUG) << "Replace: " << forward_node->DebugString() << " with " << output_vnode->ToString();
149
150 // Save forward output node when it used in its bprop graph.
151 std::vector<AnfNodePtr> used_forward_nodes;
152 constexpr size_t ref_twice = 2;
153 if (ref_size >= ref_twice) {
154 cnode_morph->set_forward(output_vnode, "");
155 used_forward_nodes.push_back(cnode_morph);
156 MS_LOG(DEBUG) << "node has been used in grad graph: " << cnode_morph->DebugString()
157 << ", its output value: " << output_vnode->ToString();
158 }
159 return used_forward_nodes;
160 }
161
RunInputReplace(const FuncGraphPtr & bprop_graph,const FuncGraphPtr & fprop_graph,const CNodePtr & cnode_morph)162 std::vector<AnfNodePtr> PynativeDFunctor::RunInputReplace(const FuncGraphPtr &bprop_graph,
163 const FuncGraphPtr &fprop_graph,
164 const CNodePtr &cnode_morph) {
165 // Use manager to get the link relation among nodes.
166 MS_EXCEPTION_IF_NULL(bprop_graph);
167 MS_EXCEPTION_IF_NULL(fprop_graph);
168 auto manager = Manage({fprop_graph, bprop_graph}, false);
169
170 MS_EXCEPTION_IF_NULL(cnode_morph);
171 const auto ¶s = fprop_graph->parameters();
172 if (cnode_morph->size() - 1 != paras.size() && !IsPrimitiveCNode(cnode_morph, prim::kPrimUpdateState)) {
173 MS_LOG(EXCEPTION) << "The size of parameters in fprop graph:" << paras.size()
174 << ", but the size of input tensors of forward node: " << cnode_morph->inputs().size() - 1;
175 }
176
177 std::vector<AnfNodePtr> used_input_nodes;
178 for (size_t i = 0; i < paras.size(); ++i) {
179 const auto &input_node = cnode_morph->input(i + 1);
180 MS_EXCEPTION_IF_NULL(input_node);
181 // Parameter, ValueNode and StopGradient CNode no need to replace.
182 if (input_node->isa<Parameter>() || input_node->isa<ValueNode>() ||
183 IsPrimitiveCNode(input_node, prim::kPrimStopGradient)) {
184 continue;
185 }
186 // Replace forward input node by its output value.
187 auto para_ref_size = manager->node_users()[paras[i]].size();
188 CNodePtr cnode_i = input_node->cast<CNodePtr>();
189 MS_EXCEPTION_IF_NULL(cnode_i);
190 auto output_vnode_i = GenNewTensor(cnode_i);
191 MS_EXCEPTION_IF_NULL(output_vnode_i);
192 output_vnode_i->set_has_new_value(true);
193 manager->Replace(paras[i], output_vnode_i);
194 MS_LOG(DEBUG) << "Replace: " << paras[i]->DebugString() << " with " << output_vnode_i->ToString();
195 // Save forward input node when it used in bprop graph.
196 if (para_ref_size > 0 && !IsPrimitiveCNode(input_node, prim::kPrimUpdateState)) {
197 cnode_i->set_forward(output_vnode_i, "");
198 used_input_nodes.push_back(cnode_i);
199 MS_LOG(DEBUG) << "Input CNode has been used in grad graph: " << cnode_i->DebugString()
200 << ", its output value: " << output_vnode_i->ToString();
201 }
202 }
203
204 return used_input_nodes;
205 }
206
ReplaceEquivdout(const CNodePtr & k_app,const CNodePtr & cnode_morph)207 void PynativeDFunctor::ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph) {
208 // The process of replacing forward node only works in pynative mode, when @ms_function is used.
209 MS_EXCEPTION_IF_NULL(cnode_morph);
210 MS_LOG(DEBUG) << "Run replace for cnode morph: " << cnode_morph->DebugString(2);
211 // Get forward node and its fprop graph, bprop graph.
212 MS_EXCEPTION_IF_NULL(k_app);
213 CNodePtr forward_node = nullptr;
214 FuncGraphPtr bprop_graph = nullptr;
215 FuncGraphPtr fprop_graph = nullptr;
216 GetForwardOutNodeAndBpropGraph(k_app, &forward_node, &bprop_graph, &fprop_graph);
217 if (forward_node == nullptr || bprop_graph == nullptr || fprop_graph == nullptr) {
218 return;
219 }
220
221 // Replace forward node used in bprop graph by its output tensors. The same process for its input node.
222 auto used_forward_nodes = RunOutputReplace(forward_node, bprop_graph, fprop_graph, cnode_morph);
223 auto used_input_nodes = RunInputReplace(bprop_graph, fprop_graph, cnode_morph);
224
225 // Save used forward input and output nodes to func_graph.
226 auto ms_func_graph = cnode_morph->func_graph();
227 MS_EXCEPTION_IF_NULL(ms_func_graph);
228 ms_func_graph->set_used_forward_nodes(used_forward_nodes);
229 ms_func_graph->set_used_forward_nodes(used_input_nodes);
230 }
231 } // namespace ad
232 } // namespace mindspore
233