/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/jit/pi/auto_grad/function_node.h"
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <utility>
#include "ops/sequence_ops.h"
#include "pipeline/jit/pi/auto_grad/edge.h"
#include "pipeline/jit/pi/auto_grad/grad_executor.h"
#include "pipeline/jit/pi/auto_grad/native_backward_function.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace pijit {
namespace grad {
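/// \brief Release the resources held by this node: the tensor's grad_fn attribute, the backward function, the edges
/// and the dependences.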
void FunctionNode::CleanResource() {
  if (HasGradFunc(tensor_)) {
    py::setattr(tensor_, "grad_fn", py::none());
  }
  tensor_ = py::none();
  backward_func_ = nullptr;
  edges_.clear();
  dependences_.clear();
}

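/// \brief Generate the bprop function for the tensor's grad_fn node and for every node on its next edges.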
void PostBpropFunctionToEdges(const py::object &tensor) {
  py::object grad_fn = python_adapter::GetPyObjAttr(tensor, "grad_fn");
  if (py::isinstance<py::none>(grad_fn)) {
    return;
  }
  auto func_node = grad_fn.cast<grad::FunctionNodePtr>();
  func_node->GenerateBropFunction();
  auto edges = func_node->GetNextEdges();
  std::for_each(edges.begin(), edges.end(), [](const EdgePtr &edge) { edge->GetNode()->GenerateBropFunction(); });
}

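/// \brief Convert the python inputs to a list of values, stopping at the first None entry.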
ValuePtrList ConvertTupleToValueList(const py::list &inputs) {
  ValuePtrList value_list;
  for (const auto input : inputs) {
    ValuePtr value = Convert::PyObjToValue(py::cast<py::object>(input));
    if (value->template isa<None>()) {
      return value_list;
    }
    value_list.push_back(value);
  }
  return value_list;
}

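/// \brief Convert a python argument to a value according to the op argument definition, trying the declared cast
/// dtypes and falling back to a direct conversion.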
ValuePtr ConvertArgByCastDtype(const py::object &arg, const ops::OpInputArg &op_arg) {
  parse::OpDefConvertFunc convert_func = parse::GetConverterByType(static_cast<int32_t>(op_arg.arg_dtype_));
  MS_EXCEPTION_IF_NULL(convert_func);
  ValuePtr value = convert_func(arg);
  if (value != nullptr) {
    return value;
  }
  for (auto cast_dtype : op_arg.cast_dtype_) {
    convert_func = parse::GetConverterByType(parse::CombineTypesForTypeCast(cast_dtype, op_arg.arg_dtype_));
    MS_EXCEPTION_IF_NULL(convert_func);
    value = convert_func(arg);
    if (value != nullptr) {
      return value;
    }
  }
  if (!py::isinstance<py::none>(arg) && value == nullptr) {
    value = Convert::PyObjToValue(arg);
  }
  return value;
}

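/// \brief Parse the input values following the op definition: positional args are taken from the inputs, init args
/// are read from the primitive's attributes.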
ValuePtrList ParseInputsByOpDef(const PrimitivePyPtr &prim, const ops::OpDefPtr &op_def, const py::list &inputs) {
  ValuePtrList input_values;
  MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() <= op_def->args_.size(), "The arguments do not match the op definition.");
  size_t index = 0;
  std::for_each(op_def->args_.begin(), op_def->args_.end(), [&prim, &inputs, &input_values, &index](const auto &arg) {
    if (!arg.as_init_arg_) {
      input_values.push_back(ConvertArgByCastDtype(inputs[index], arg));
    } else {
      auto value = py::getattr(prim->GetPyObj(), common::SafeCStr(arg.arg_name_));
      if (!py::isinstance<py::none>(value)) {
        input_values.push_back(ConvertArgByCastDtype(value, arg));
      }
    }
    index++;
  });
  return input_values;
}

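/// \brief Parse the inputs of the primitive, using the registered op definition when one exists.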
ValuePtrList ParseInputs(const PrimitivePyPtr &prim, const py::list &inputs) {
  auto op_def = mindspore::ops::GetOpDef(prim->name());
  if (op_def == nullptr) {
    return ConvertTupleToValueList(inputs);
  }
  return ParseInputsByOpDef(prim, op_def, inputs);
}

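/// \brief Check whether the python object requires gradient.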
bool FunctionNode::IsRequiresGradient(const py::handle &obj) {
  if (!HasAttrReqGrad(obj)) {
    return false;
  }
  auto requires_grad = obj.attr("requires_grad");
  return py::isinstance<py::bool_>(requires_grad) && py::bool_(requires_grad);
}

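/// \brief Check whether the python object already holds a grad_fn of type FunctionNode.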
bool FunctionNode::HasGradFunc(const py::handle &obj) {
  auto grad_fn = python_adapter::GetPyObjAttr(py::cast<py::object>(obj), "grad_fn");
  return py::isinstance<grad::FunctionNode>(grad_fn);
}

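/// \brief Return the tensor's existing grad_fn node, or create a new function node and attach it to the tensor when
/// it has next edges.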
FunctionNodePtr GetOrCreateFunctionNode(const py::object &tensor, const py::object &prim, const py::object &out,
                                        const py::list &inputs) {
  py::object grad_fn = python_adapter::GetPyObjAttr(tensor, "grad_fn");
  if (py::isinstance<grad::FunctionNode>(grad_fn)) {
    return grad_fn.cast<grad::FunctionNodePtr>();
  }
  auto func_node = FunctionNode::CreateFunctionNode(tensor, prim, out, inputs);
  if (!func_node->GetNextEdges().empty()) {
    py::setattr(func_node->GetTensor(), "grad_fn", py::cast(func_node));
    py::setattr(func_node->GetTensor(), "requires_grad", py::bool_(true));
  }
  return func_node;
}

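/// \brief Create a function node for the output tensor and add an edge to every input that requires gradient.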
FunctionNodePtr FunctionNode::CreateFunctionNode(const py::object &tensor, const py::object &prim,
                                                 const py::object &out, const py::list &inputs) {
  auto func_node = std::make_shared<FunctionNode>(tensor, prim, out);
  MS_LOG(DEBUG) << "Create a function node(" << func_node.get() << ") for " << tensor.ptr();
  func_node->SetInputs(inputs);
  std::for_each(inputs.begin(), inputs.end(), [func_node, &inputs](const auto &obj) {
    if (!FunctionNode::IsRequiresGradient(obj) && !FunctionNode::HasGradFunc(obj)) {
      return;
    }
    auto input = py::cast<py::object>(obj);
    auto node = GetOrCreateFunctionNode(input, py::none(), input, py::list());
    func_node->AddNextEdge(node, std::distance(inputs.begin(), std::find(inputs.begin(), inputs.end(), input)));
  });
  return func_node;
}

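/// \brief Record the execution of a primitive for auto gradient; for a tuple output, each element node is linked to
/// the shared function node.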
void FunctionNode::RecordPrimitive(const py::object &prim, const py::object &out, const py::list &inputs) {
  MS_LOG(DEBUG) << "Record " << out.ptr() << " for auto gradient.";
  if (!py::isinstance<py::tuple>(out)) {
    (void)GetOrCreateFunctionNode(out, prim, out, inputs);
  } else {
    auto func_node = CreateFunctionNode(out, prim, out, inputs);
    std::for_each(out.begin(), out.end(), [&out, &func_node](const auto &obj) {
      if (!HasAttrReqGrad(obj) || HasGradFunc(obj)) {
        return;
      }
      auto tensor = py::cast<py::object>(obj);
      auto temp_node = GetOrCreateFunctionNode(tensor, py::none(), tensor, py::list());
      temp_node->index_ = std::distance(out.begin(), std::find(out.begin(), out.end(), tensor));
      temp_node->AddNextEdge(func_node, 0);
      py::setattr(tensor, "grad_fn", py::cast(temp_node));
      py::setattr(tensor, "requires_grad", py::bool_(true));
    });
  }
}

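/// \brief Set the forward inputs of the node, parsing them with the primitive when inputs are provided.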
void FunctionNode::SetInputs(const py::list &inputs) {
  if (py::isinstance<py::none>(inputs) || inputs.empty()) {
    FunctionContext::SetInputs({});
  } else {
    FunctionContext::SetInputs(ParseInputs(GetFunction()->cast<PrimitivePyPtr>(), inputs));
  }
}

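/// \brief Propagate the gradient values to the nodes on the next edges.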
void FunctionNode::ApplyEdges(const ValuePtrList &grad_values) {
  MS_EXCEPTION_IF_CHECK_FAIL((grad_values.size() == edges_.size()), "The gradient values do not match the edges.");
  for (size_t index = 0; index < edges_.size(); index++) {
    Notify(edges_[index]->GetNode(), grad_values[index]);
  }
}

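/// \brief Run the native backward function on the accumulated gradient and forward the results along the edges.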
void FunctionNode::ApplyNative() {
  ValuePtrList flatten_values = ValuePtrList(edges_.size(), GetGrad()[0]);
  if (backward_func_ != nullptr) {
    backward_func_->SetGradientIndexes({});
    std::for_each(edges_.begin(), edges_.end(),
                  [this](const auto &edge) { backward_func_->AddGradientIndex(edge->GetIndex()); });
    if (GetOutput()->isa<ValueTuple>()) {
      flatten_values = backward_func_->Run(GetInputs(), GetOutput(), MakeValue(GetGrad()));
    } else {
      flatten_values = backward_func_->Run(GetInputs(), GetOutput(), GetGrad()[index_]);
    }
  }
  ApplyEdges(flatten_values);
}

/// \brief Generate the bprop function.
void FunctionNode::GenerateBropFunction() {
  auto generate_task = std::make_shared<RunGenerateBpropTask>([this]() {
    MS_LOG(DEBUG) << "Generate bprop function for node " << tensor_.ptr() << ", tensor is " << tensor_.ptr();
    auto output = GetOutput();
    auto executor = GradExecutor::GetInstance();
    {
      // gil for PyObject accessing
      py::gil_scoped_acquire gil_acquire;
      auto acc_fn = executor->GetAccumulateGraph(output);
      acc_fn_ = executor->PrimBpropGraphPass(acc_fn);
    }

    auto func = GetFunction();
    if (func->isa<None>()) {
      return;
    }
    try {
      // gil for PyObject accessing
      py::gil_scoped_acquire gil_acquire;
      grad_fn_ = executor->GetBpropGraph(NewValueNode(func), GetInputs(), output, output);
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "Prim : " << func->ToString() << " Output : " << output->ToString();
      MS_LOG(ERROR) << e.what();
    }
  });
  GradExecutor::GetInstance()->DispatchGenerateTask(generate_task);
  {
    py::gil_scoped_release release;
    GradExecutor::GetInstance()->GetAsyncTaskManager()->GetGenerateTaskQueue()->Wait();
  }
}

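/// \brief Traverse the function node graph in breadth-first order and invoke the callback on every visited node.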
void Visit(const FunctionNodePtr &node, const std::function<void(const FunctionNodePtr &)> &callback) {
  std::queue<FunctionNodePtr> nodes;
  nodes.push(node);
  while (!nodes.empty()) {
    auto fn = nodes.front();
    nodes.pop();
    std::for_each(fn->GetNextEdges().begin(), fn->GetNextEdges().end(),
                  [&nodes](const auto &edge) { nodes.push(edge->GetNode()); });
    callback(fn);
  }
}

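/// \brief Write the accumulated gradient back to the tensor's grad attribute for leaf nodes and tensors that retain
/// grad, then reset the stored gradients.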
void FunctionNode::SyncGradToPyObject() {
  auto sync_func = [](const FunctionNodePtr &node) {
    auto retains_grad = python_adapter::GetPyObjAttr(node->tensor_, "retains_grad");
    if (node->edges_.empty() || (py::isinstance<py::bool_>(retains_grad) && py::bool_(retains_grad))) {
      auto _grad = python_adapter::GetPyObjAttr(node->tensor_, "grad");
      if (!py::isinstance<py::none>(_grad)) {
        node->AccumulateGradient(Convert::PyObjToValue(_grad), node->index_);
      } else {
        if (node->GetGrad()[node->index_]->isa<None>()) {
          auto func = node->backward_func_;
          if (func == nullptr) {
            func = NativeBackwardFunc::GetInstance(prim::kPrimAdd);
          }
          node->SetGrad(func->Zeros(node->GetOutput()), node->index_);
        }
      }
      auto value = Convert::ValueToPyObj(node->GetGrad()[node->index_]);
      auto grad = python_adapter::CallPyFn("mindspore.common.api", "_convert_python_data", value);
      py::setattr(node->tensor_, "grad", grad);
    }
    node->SetGrad(ValuePtrList(node->GetGrad().size(), kNone));
  };
  Visit(shared_from_base<FunctionNode>(), sync_func);
}

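/// \brief Entry of the backward execution: seed the gradient, propagate it through the graph, sync the results to
/// python and release the resources.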
void FunctionNode::Apply(const py::object &grad) {
  UpdateDependence();
  Notify(shared_from_base<FunctionNode>(), Convert::PyObjToValue(grad));
  SyncGradToPyObject();
  auto release_func = [](const FunctionNodePtr &node) { node->CleanResource(); };
  Visit(shared_from_base<FunctionNode>(), release_func);
}

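/// \brief Accumulate the incoming gradient, run the bprop graph and push the resulting gradients to the next edges.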
void FunctionNode::ApplyInner(const ValuePtr &dout) {
  MS_LOG(DEBUG) << "Start run apply() of " << tensor_.ptr() << ", tensor is " << tensor_.ptr();
  MS_LOG(DEBUG) << "Prim is " << GetFunction()->ToString() << ", dout is " << dout->ToString();
  auto run_task = std::make_shared<RunBpropTask>(
    [this](const ValuePtr &dout) {
      AccumulateGradient(dout, index_);
      if (grad_fn_ == nullptr) {
        return;
      }
      // gil for PyObject accessing
      py::gil_scoped_acquire gil_acquire;
      auto ret = GradExecutor::GetInstance()->RunGraph(grad_fn_, GetInputs(), GetOutput(), dout);
      if (!ret->isa<ValueTuple>()) {
        return;
      }
      auto tuple = ret->cast<ValueTuplePtr>();
      std::for_each(edges_.begin(), edges_.end(),
                    [&tuple](const auto &edge) { edge->GetNode()->ApplyInner(tuple->value()[edge->GetIndex()]); });
    },
    dout);

  GradExecutor::GetInstance()->DispatchRunTask(run_task);
  {
    py::gil_scoped_release release;
    GradExecutor::GetInstance()->GetAsyncTaskManager()->GetRunTaskQueue()->Wait();
  }
}

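/// \brief Mark the nodes reachable from this node as part of the reverse chain, then prune and release the
/// dependences that are not on it.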
void FunctionNode::UpdateDependence() {
  auto mark_func = [](const FunctionNodePtr &node) { node->is_in_reverse_chain_ = true; };
  Visit(shared_from_base<FunctionNode>(), mark_func);
  auto update_func = [](const FunctionNodePtr &node) {
    for (auto iter = node->dependences_.begin(); iter != node->dependences_.end();) {
      if (!(*iter)->is_in_reverse_chain_) {
        std::queue<FunctionNodePtr> nodes;
        nodes.push((*iter));
        while (!nodes.empty()) {
          auto fn = nodes.front();
          nodes.pop();
          std::for_each((*iter)->dependences_.begin(), (*iter)->dependences_.end(),
                        [&nodes](const auto &n) { nodes.push(n); });
          fn->CleanResource();
        }
        iter = node->dependences_.erase(iter);
      } else {
        iter++;
      }
    }
  };
  dependences_.clear();
  dependences_.insert(shared_from_base<FunctionNode>());
  Visit(shared_from_base<FunctionNode>(), update_func);
}

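/// \brief Accumulate the gradient on the target node and trigger its backward execution once all of its dependences
/// have reported.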
void FunctionNode::Notify(const FunctionNodePtr &node, const ValuePtr &dout) {
  node->AccumulateGradient(dout, node->index_);
  node->depend_cnt_.fetch_add(1);
  if (!node->IsReady()) {
    return;
  }
  node->ApplyNative();
  node->depend_cnt_.store(0);
  node->dependences_.clear();
}

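/// \brief Accumulate dout into the gradient slot at index, adding it to the existing value when one is already
/// present.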
void FunctionNode::AccumulateGradient(const ValuePtr &dout, size_t index) {
  if (dout->isa<None>()) {
    return;
  }
  auto func = backward_func_;
  if (func == nullptr) {
    func = NativeBackwardFunc::GetInstance(prim::kPrimAdd);
  }
  std::unique_lock<std::mutex> lock(mutex_);
  auto value = GetGrad()[index];
  if (value->isa<None>()) {
    SetGrad(dout, index);
  } else {
    SetGrad(func->Add(dout, GetGrad()[index]), index);
  }
}

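/// \brief Return a human readable description of the node and its edges.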
std::string FunctionNode::ToString() const {
  std::stringstream ss;
  Dump(ss, "");
  return ss.str();
}

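/// \brief Recursively dump this node and the nodes on its next edges into the stream.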
void FunctionNode::Dump(std::stringstream &ss, const std::string &prefix) const {
  if (!prefix.empty()) {
    ss << prefix << "-->";
  }
  auto prim = GetFunction();
  ss << "FunctionNode(" << tensor_.ptr() << "(" << this << "), depend(" << dependences_.size() << "), "
     << (prim->isa<None>() ? "None" : prim->ToString()) << ", "
     << py::bool_(python_adapter::GetPyObjAttr(tensor_, "is_leaf")) << ")\n";
  std::for_each(edges_.begin(), edges_.end(),
                [&ss, &prefix](const auto &edge) { edge->GetNode()->Dump(ss, prefix + "   "); });
}
}  // namespace grad
}  // namespace pijit
}  // namespace mindspore