/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/jit/pi/auto_grad/function_node.h"
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <utility>
#include "ops/sequence_ops.h"
#include "pipeline/jit/pi/auto_grad/edge.h"
#include "pipeline/jit/pi/auto_grad/grad_executor.h"
#include "pipeline/jit/pi/auto_grad/native_backward_function.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace pijit {
namespace grad {
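/// \brief Release the resources held by this node after backward propagation: detach the tensor's
///        `grad_fn`, then drop the tensor, the backward function, the next edges and the dependences.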
void FunctionNode::CleanResource() {
  if (HasGradFunc(tensor_)) {
    py::setattr(tensor_, "grad_fn", py::none());
  }
  tensor_ = py::none();
  backward_func_ = nullptr;
  edges_.clear();
  dependences_.clear();
}

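/// \brief Generate the bprop function for the node bound to `tensor` (if any) and for the nodes of its next edges.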
void PostBpropFunctionToEdges(const py::object &tensor) {
  py::object grad_fn = python_adapter::GetPyObjAttr(tensor, "grad_fn");
  if (py::isinstance<py::none>(grad_fn)) {
    return;
  }
  auto func_node = grad_fn.cast<grad::FunctionNodePtr>();
  func_node->GenerateBropFunction();
  auto edges = func_node->GetNextEdges();
  std::for_each(edges.begin(), edges.end(), [](const EdgePtr &edge) { edge->GetNode()->GenerateBropFunction(); });
}

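/// \brief Convert the python argument list into values, stopping at the first argument that converts to `None`.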
ValuePtrList ConvertTupleToValueList(const py::list &inputs) {
  ValuePtrList value_list;
  for (const auto &input : inputs) {
    ValuePtr value = Convert::PyObjToValue(py::cast<py::object>(input));
    if (value->isa<None>()) {
      return value_list;
    }
    value_list.push_back(value);
  }
  return value_list;
}

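/// \brief Convert a python argument with the converter registered for the expected dtype of the op argument,
///        falling back to the cast dtypes and finally to a generic conversion.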
ValuePtr ConvertArgByCastDtype(const py::object &arg, const ops::OpInputArg &op_arg) {
  parse::OpDefConvertFunc convert_func = parse::GetConverterByType(static_cast<int32_t>(op_arg.arg_dtype_));
  MS_EXCEPTION_IF_NULL(convert_func);
  ValuePtr value = convert_func(arg);
  if (value != nullptr) {
    return value;
  }
  for (auto cast_dtype : op_arg.cast_dtype_) {
    convert_func = parse::GetConverterByType(parse::CombineTypesForTypeCast(cast_dtype, op_arg.arg_dtype_));
    MS_EXCEPTION_IF_NULL(convert_func);
    value = convert_func(arg);
    if (value != nullptr) {
      return value;
    }
  }
  if (!py::isinstance<py::none>(arg) && value == nullptr) {
    value = Convert::PyObjToValue(arg);
  }
  return value;
}

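/// \brief Parse the call arguments according to the operator definition; init arguments are read from the
///        primitive's python attributes instead of the argument list.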
ValuePtrList ParseInputsByOpDef(const PrimitivePyPtr &prim, const ops::OpDefPtr &op_def, const py::list &inputs) {
  ValuePtrList input_values;
  MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() <= op_def->args_.size(), "The arguments do not match the op definition.");
  size_t index = 0;
  std::for_each(op_def->args_.begin(), op_def->args_.end(), [&prim, &inputs, &input_values, &index](const auto &arg) {
    if (!arg.as_init_arg_) {
      input_values.push_back(ConvertArgByCastDtype(inputs[index], arg));
    } else {
      auto value = py::getattr(prim->GetPyObj(), common::SafeCStr(arg.arg_name_));
      if (!py::isinstance<py::none>(value)) {
        input_values.push_back(ConvertArgByCastDtype(value, arg));
      }
    }
    index++;
  });
  return input_values;
}

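/// \brief Parse the call arguments of the primitive, using its operator definition when one is registered.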
ValuePtrList ParseInputs(const PrimitivePyPtr &prim, const py::list &inputs) {
  auto op_def = mindspore::ops::GetOpDef(prim->name());
  if (op_def == nullptr) {
    return ConvertTupleToValueList(inputs);
  }
  return ParseInputsByOpDef(prim, op_def, inputs);
}

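/// \brief Check whether the python object has a `requires_grad` attribute that evaluates to `True`.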
bool FunctionNode::IsRequiresGradient(const py::handle &obj) {
  if (!HasAttrReqGrad(obj)) {
    return false;
  }
  auto requires_grad = obj.attr("requires_grad");
  return py::isinstance<py::bool_>(requires_grad) && py::bool_(requires_grad);
}

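/// \brief Check whether the python object already owns a `grad_fn` attribute that is a function node.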
bool FunctionNode::HasGradFunc(const py::handle &obj) {
  auto grad_fn = python_adapter::GetPyObjAttr(py::cast<py::object>(obj), "grad_fn");
  return py::isinstance<grad::FunctionNode>(grad_fn);
}

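/// \brief Return the function node bound to `tensor`, or create one and bind it as the tensor's `grad_fn`.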
FunctionNodePtr GetOrCreateFunctionNode(const py::object &tensor, const py::object &prim, const py::object &out,
                                        const py::list &inputs) {
  py::object grad_fn = python_adapter::GetPyObjAttr(tensor, "grad_fn");
  if (py::isinstance<grad::FunctionNode>(grad_fn)) {
    return grad_fn.cast<grad::FunctionNodePtr>();
  }
  auto func_node = FunctionNode::CreateFunctionNode(tensor, prim, out, inputs);
  if (!func_node->GetNextEdges().empty()) {
    py::setattr(func_node->GetTensor(), "grad_fn", py::cast(func_node));
    py::setattr(func_node->GetTensor(), "requires_grad", py::bool_(true));
  }
  return func_node;
}

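/// \brief Create a function node for `tensor` produced by `prim`, adding a next edge for every input that
///        requires gradient or already has a grad function.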
FunctionNodePtr FunctionNode::CreateFunctionNode(const py::object &tensor, const py::object &prim,
                                                 const py::object &out, const py::list &inputs) {
  auto func_node = std::make_shared<FunctionNode>(tensor, prim, out);
  MS_LOG(DEBUG) << "Create a function node(" << func_node.get() << ") for " << tensor.ptr();
  func_node->SetInputs(inputs);
  std::for_each(inputs.begin(), inputs.end(), [func_node, &inputs](const auto &obj) {
    if (!FunctionNode::IsRequiresGradient(obj) && !FunctionNode::HasGradFunc(obj)) {
      return;
    }
    auto input = py::cast<py::object>(obj);
    auto node = GetOrCreateFunctionNode(input, py::none(), input, py::list());
    func_node->AddNextEdge(node, std::distance(inputs.begin(), std::find(inputs.begin(), inputs.end(), input)));
  });
  return func_node;
}

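/// \brief Record a primitive call for automatic differentiation. For a tuple output a node is created for
///        each output tensor and chained to the node of the whole tuple.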
void FunctionNode::RecordPrimitive(const py::object &prim, const py::object &out, const py::list &inputs) {
  MS_LOG(DEBUG) << "Record " << out.ptr() << " for auto gradient.";
  if (!py::isinstance<py::tuple>(out)) {
    (void)GetOrCreateFunctionNode(out, prim, out, inputs);
  } else {
    auto func_node = CreateFunctionNode(out, prim, out, inputs);
    std::for_each(out.begin(), out.end(), [&out, &func_node](const auto &obj) {
      if (!HasAttrReqGrad(obj) || HasGradFunc(obj)) {
        return;
      }
      auto tensor = py::cast<py::object>(obj);
      auto temp_node = GetOrCreateFunctionNode(tensor, py::none(), tensor, py::list());
      temp_node->index_ = std::distance(out.begin(), std::find(out.begin(), out.end(), tensor));
      temp_node->AddNextEdge(func_node, 0);
      py::setattr(tensor, "grad_fn", py::cast(temp_node));
      py::setattr(tensor, "requires_grad", py::bool_(true));
    });
  }
}

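/// \brief Parse the python call arguments and store them as the input values of this node.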
void FunctionNode::SetInputs(const py::list &inputs) {
  if (py::isinstance<py::none>(inputs) || inputs.empty()) {
    FunctionContext::SetInputs({});
  } else {
    FunctionContext::SetInputs(ParseInputs(GetFunction()->cast<PrimitivePyPtr>(), inputs));
  }
}

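/// \brief Send one gradient value to the node of each next edge.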
void FunctionNode::ApplyEdges(const ValuePtrList &grad_values) {
  MS_EXCEPTION_IF_CHECK_FAIL((grad_values.size() == edges_.size()), "The gradient values do not match the edges.");
  for (size_t index = 0; index < edges_.size(); index++) {
    Notify(edges_[index]->GetNode(), grad_values[index]);
  }
}

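/// \brief Compute the gradients of the inputs with the backward function and propagate them along the next
///        edges; without a backward function the received gradient is forwarded unchanged.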
void FunctionNode::ApplyNative() {
  ValuePtrList flatten_values = ValuePtrList(edges_.size(), GetGrad()[0]);
  if (backward_func_ != nullptr) {
    backward_func_->SetGradientIndexes({});
    std::for_each(edges_.begin(), edges_.end(),
                  [this](const auto &edge) { backward_func_->AddGradientIndex(edge->GetIndex()); });
    if (GetOutput()->isa<ValueTuple>()) {
      flatten_values = backward_func_->Run(GetInputs(), GetOutput(), MakeValue(GetGrad()));
    } else {
      flatten_values = backward_func_->Run(GetInputs(), GetOutput(), GetGrad()[index_]);
    }
  }
  ApplyEdges(flatten_values);
}

/// \brief Generate the bprop function.
void FunctionNode::GenerateBropFunction() {
  auto generate_task = std::make_shared<RunGenerateBpropTask>([this]() {
    MS_LOG(DEBUG) << "Generate bprop function for node(" << this << "), tensor is " << tensor_.ptr();
    auto output = GetOutput();
    auto executor = GradExecutor::GetInstance();
    {
      // gil for PyObject accessing
      py::gil_scoped_acquire gil_acquire;
      auto acc_fn = executor->GetAccumulateGraph(output);
      acc_fn_ = executor->PrimBpropGraphPass(acc_fn);
    }

    auto func = GetFunction();
    if (func->isa<None>()) {
      return;
    }
    try {
      // gil for PyObject accessing
      py::gil_scoped_acquire gil_acquire;
      grad_fn_ = executor->GetBpropGraph(NewValueNode(func), GetInputs(), output, output);
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "Prim : " << func->ToString() << " Output : " << output->ToString();
      MS_LOG(ERROR) << e.what();
    }
  });
  GradExecutor::GetInstance()->DispatchGenerateTask(generate_task);
  {
    py::gil_scoped_release release;
    GradExecutor::GetInstance()->GetAsyncTaskManager()->GetGenerateTaskQueue()->Wait();
  }
}

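/// \brief Traverse the graph rooted at `node` breadth-first along the next edges and invoke `callback` on
///        every visited node.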
void Visit(const FunctionNodePtr &node, const std::function<void(const FunctionNodePtr &)> &callback) {
  std::queue<FunctionNodePtr> nodes;
  nodes.push(node);
  while (!nodes.empty()) {
    auto fn = nodes.front();
    nodes.pop();
    std::for_each(fn->GetNextEdges().begin(), fn->GetNextEdges().end(),
                  [&nodes](const auto &edge) { nodes.push(edge->GetNode()); });
    callback(fn);
  }
}

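/// \brief Write the accumulated gradients back to the python tensors for leaf tensors and tensors that
///        retain gradient, then reset the cached gradients.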
void FunctionNode::SyncGradToPyObject() {
  auto sync_func = [](const FunctionNodePtr &node) {
    auto retains_grad = python_adapter::GetPyObjAttr(node->tensor_, "retains_grad");
    if (node->edges_.empty() || (py::isinstance<py::bool_>(retains_grad) && py::bool_(retains_grad))) {
      auto _grad = python_adapter::GetPyObjAttr(node->tensor_, "grad");
      if (!py::isinstance<py::none>(_grad)) {
        node->AccumulateGradient(Convert::PyObjToValue(_grad), node->index_);
      } else {
        if (node->GetGrad()[node->index_]->isa<None>()) {
          auto func = node->backward_func_;
          if (func == nullptr) {
            func = NativeBackwardFunc::GetInstance(prim::kPrimAdd);
          }
          node->SetGrad(func->Zeros(node->GetOutput()), node->index_);
        }
      }
      auto value = Convert::ValueToPyObj(node->GetGrad()[node->index_]);
      auto grad = python_adapter::CallPyFn("mindspore.common.api", "_convert_python_data", value);
      py::setattr(node->tensor_, "grad", grad);
    }
    node->SetGrad(ValuePtrList(node->GetGrad().size(), kNone));
  };
  Visit(shared_from_base<FunctionNode>(), sync_func);
}

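/// \brief Entry point of backward propagation: prune the dependences, propagate the initial gradient through
///        the graph, sync the results to the python tensors and release resources.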
void FunctionNode::Apply(const py::object &grad) {
  UpdateDependence();
  Notify(shared_from_base<FunctionNode>(), Convert::PyObjToValue(grad));
  SyncGradToPyObject();
  auto release_func = [](const FunctionNodePtr &node) { node->CleanResource(); };
  Visit(shared_from_base<FunctionNode>(), release_func);
}

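/// \brief Dispatch an asynchronous task that accumulates `dout`, runs the bprop graph of this node and
///        recursively applies the resulting gradients to the next edges.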
void FunctionNode::ApplyInner(const ValuePtr &dout) {
  MS_LOG(DEBUG) << "Start run apply() of node(" << this << "), tensor is " << tensor_.ptr();
  MS_LOG(DEBUG) << "Prim is " << GetFunction()->ToString() << ", dout is " << dout->ToString();
  auto run_task = std::make_shared<RunBpropTask>(
    [this](const ValuePtr &dout) {
      AccumulateGradient(dout, index_);
      if (grad_fn_ == nullptr) {
        return;
      }
      // gil for PyObject accessing
      py::gil_scoped_acquire gil_acquire;
      auto ret = GradExecutor::GetInstance()->RunGraph(grad_fn_, GetInputs(), GetOutput(), dout);
      if (!ret->isa<ValueTuple>()) {
        return;
      }
      auto tuple = ret->cast<ValueTuplePtr>();
      std::for_each(edges_.begin(), edges_.end(),
                    [&tuple](const auto &edge) { edge->GetNode()->ApplyInner(tuple->value()[edge->GetIndex()]); });
    },
    dout);

  GradExecutor::GetInstance()->DispatchRunTask(run_task);
  {
    py::gil_scoped_release release;
    GradExecutor::GetInstance()->GetAsyncTaskManager()->GetRunTaskQueue()->Wait();
  }
}

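/// \brief Mark every node reachable from this one as part of the reverse chain, then remove (and clean up)
///        the recorded dependences that are not on that chain.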
void FunctionNode::UpdateDependence() {
  auto mark_func = [](const FunctionNodePtr &node) { node->is_in_reverse_chain_ = true; };
  Visit(shared_from_base<FunctionNode>(), mark_func);
  auto update_func = [](const FunctionNodePtr &node) {
    for (auto iter = node->dependences_.begin(); iter != node->dependences_.end();) {
      if (!(*iter)->is_in_reverse_chain_) {
        std::queue<FunctionNodePtr> nodes;
        nodes.push((*iter));
        while (!nodes.empty()) {
          auto fn = nodes.front();
          nodes.pop();
          std::for_each(fn->dependences_.begin(), fn->dependences_.end(),
                        [&nodes](const auto &n) { nodes.push(n); });
          fn->CleanResource();
        }
        iter = node->dependences_.erase(iter);
      } else {
        iter++;
      }
    }
  };
  dependences_.clear();
  dependences_.insert(shared_from_base<FunctionNode>());
  Visit(shared_from_base<FunctionNode>(), update_func);
}

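/// \brief Deliver a gradient to `node`; once all of its dependences have reported, the node runs its
///        backward function and resets its dependence bookkeeping.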
void FunctionNode::Notify(const FunctionNodePtr &node, const ValuePtr &dout) {
  node->AccumulateGradient(dout, node->index_);
  node->depend_cnt_.fetch_add(1);
  if (!node->IsReady()) {
    return;
  }
  node->ApplyNative();
  node->depend_cnt_.store(0);
  node->dependences_.clear();
}

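/// \brief Accumulate `dout` into the gradient stored at `index`; `None` gradients are ignored and the update
///        is guarded by the node's mutex.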
void FunctionNode::AccumulateGradient(const ValuePtr &dout, size_t index) {
  if (dout->isa<None>()) {
    return;
  }
  auto func = backward_func_;
  if (func == nullptr) {
    func = NativeBackwardFunc::GetInstance(prim::kPrimAdd);
  }
  std::unique_lock<std::mutex> lock(mutex_);
  auto value = GetGrad()[index];
  if (value->isa<None>()) {
    SetGrad(dout, index);
  } else {
    SetGrad(func->Add(dout, GetGrad()[index]), index);
  }
}

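/// \brief Return a printable description of this node and the nodes reachable from it.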
std::string FunctionNode::ToString() const {
  std::stringstream ss;
  Dump(ss, "");
  return ss.str();
}

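/// \brief Recursively dump this node and its next edges into `ss`, prefixing each level with `prefix`.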
void FunctionNode::Dump(std::stringstream &ss, const std::string &prefix) const {
  if (!prefix.empty()) {
    ss << prefix << "-->";
  }
  auto prim = GetFunction();
  ss << "FunctionNode(" << tensor_.ptr() << "(" << this << "), depend(" << dependences_.size() << "), "
     << (prim->isa<None>() ? "None" : prim->ToString()) << ", "
     << py::bool_(python_adapter::GetPyObjAttr(tensor_, "is_leaf")) << ")\n";
  std::for_each(edges_.begin(), edges_.end(),
                [&ss, &prefix](const auto &edge) { edge->GetNode()->Dump(ss, prefix + " "); });
}
}  // namespace grad
}  // namespace pijit
}  // namespace mindspore