• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/grad/variable.h"
18 #include <memory>
19 #include "pipeline/pynative/pynative_utils.h"
20 
21 namespace mindspore::pynative::autograd {
UpdateNextEdges(const ValuePtrList & inputs)22 void BackwardNode::UpdateNextEdges(const ValuePtrList &inputs) {
23   MS_LOG(DEBUG) << "Get input size " << inputs.size();
24   next_edges_.reserve(inputs.size());
25   gradient_index_.reserve(inputs.size());
26   for (size_t i = 0; i < inputs.size(); ++i) {
27     const auto &value = inputs[i];
28     if (value->isa<tensor::BaseTensor>()) {
29       auto tensor = value->cast<tensor::BaseTensorPtr>();
30       auto auto_grad_meta_data = tensor->auto_grad_meta_data();
31       MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
32       auto variable = auto_grad_meta_data->variable();
33       if (variable == nullptr || !variable->is_need_grad()) {
34         continue;
35       }
36       MS_LOG(DEBUG) << "Add next edge for tensor " << tensor->id();
37       (void)next_edges_.emplace_back(variable, auto_grad_meta_data->output_index());
38       (void)gradient_index_.emplace_back(i);
39     }
40     // to do sparse tensor.
41   }
42 }
43 
// Selects, from the flattened gradient list, only the gradients that belong to
// the inputs recorded by UpdateNextEdges (gradient_index_), preserving order.
// Throws if a recorded index is out of range of the flattened gradients.
ValuePtrList BackwardNode::PostProcess(const ValuePtrList &gradient_value) {
  ValuePtrList flatten_values = PyNativeAlgo::DataConvert::FlattenTensorSeqInValueSeq(gradient_value, false);
  ValuePtrList gradients;
  // Fix: reserve the number of elements actually appended (one per recorded
  // index), not the flattened size, which can over-allocate.
  gradients.reserve(gradient_index_.size());
  for (const auto index : gradient_index_) {
    if (index >= flatten_values.size()) {
      // Fix: grammar of the original message ("should smaller than").
      MS_LOG(EXCEPTION) << "Inputs gradient index should be smaller than flatten_values size!";
    }
    (void)gradients.emplace_back(flatten_values[index]);
  }
  return gradients;
}
57 
// For multi-output ops, replaces every placeholder (None) gradient in `dout`
// with an explicit zeros-like gradient built from the matching forward output.
// A single-output node is returned unchanged. Throws if the flattened output
// count does not match the gradient count.
ValuePtrList BackwardNode::LazeUpdateZeroGradient(const ValuePtrList &dout, FuncBuilder *func_builder,
                                                  const ValuePtr &output) {
  // Single gradient: nothing to fill, hand it back as-is.
  if (dout.size() == kSizeOne) {
    return dout;
  }
  ValuePtrList flatten_outputs;
  PyNativeAlgo::DataConvert::FlattenValueSeqArg(output, true, false, &flatten_outputs);
  if (flatten_outputs.size() != dout.size()) {
    MS_LOG(EXCEPTION) << "Gradients size should be same as output size! But got output size: "
                      << flatten_outputs.size() << ", gradients size: " << dout.size();
  }
  ValuePtrList real_dout(dout.size());
  for (size_t idx = 0; idx < dout.size(); ++idx) {
    const auto &grad = dout[idx];
    if (!grad->isa<None>()) {
      real_dout[idx] = grad;
      continue;
    }
    MS_LOG(DEBUG) << "Op " << name() << " has multi outputs, and exist null dout, now do emit zeros";
    // Materialize a zeros-like gradient shaped after the corresponding output.
    auto zero_value = PyNativeAlgo::AutoGrad::BuildSpecialValueGrad(flatten_outputs[idx], nullptr, func_builder,
                                                                    SpecialType::kZerosLikeType);
    MS_EXCEPTION_IF_NULL(zero_value);
    real_dout[idx] = zero_value;
  }
  return real_dout;
}
83 
ToString() const84 std::string FuncVariable::ToString() const {
85   std::ostringstream buf;
86   buf << "Variable name: " << func_node()->name() << ", is_need_grad: " << is_need_grad()
87       << ", is_need_propagate: " << is_need_propagate() << " is_leaf: " << is_leaf() << "\n";
88   for (size_t i = 0; i < func_node()->next_edges().size(); ++i) {
89     auto last_variable = func_node()->next_edges()[i].variable;
90     auto index = func_node()->next_edges()[i].input_index;
91     buf << "Last edge: " << i << ", variable name: " << last_variable->func_node()->name()
92         << ", output index: " << index << "\n";
93   }
94   return buf.str();
95 }
96 
ToString() const97 std::string IrVariable::ToString() const {
98   std::ostringstream buf;
99   buf << "Variable id: " << PyNativeAlgo::Common::GetIdByValue(out_value()) << ", is_need_grad: " << is_need_grad()
100       << ", is_need_propagate: " << is_need_propagate() << ", is_leaf: " << is_leaf();
101   for (size_t i = 0; i < ir_function_node()->next_edges().size(); ++i) {
102     auto last_variable = ir_function_node()->next_edges()[i].first;
103     auto din = ir_function_node()->next_edges()[i].second;
104     buf << ", next edge variable id: " << PyNativeAlgo::Common::GetIdByValue(last_variable->out_value())
105         << " din: " << din->DebugString();
106   }
107   return buf.str();
108 }
109 
RealDout()110 AnfNodePtr IrVariable::RealDout() {
111   if (static_cast<bool>(MS_UNLIKELY(PyNativeAlgo::AutoGrad::IsZerosLikeNode(ir_function_node()->accumulate_dout())))) {
112     ir_function_node()->set_accumulate_dout(PyNativeAlgo::AutoGrad::BuildSpecialNode(
113       ir_function_node()->tape(), out_value(), ir_function_node()->accumulate_dout()->abstract(),
114       SpecialType::kZerosLikeType));
115   }
116   const auto &accumulate_dout = ir_function_node()->accumulate_dout();
117   const auto &dout_abs = accumulate_dout->abstract();
118   MS_EXCEPTION_IF_NULL(dout_abs);
119   // For input, if it is a sparsetensor, we need return a sparsetensor.
120   if (out_value()->isa<tensor::BaseTensor>() || dout_abs->isa<abstract::AbstractSparseTensor>()) {
121     return accumulate_dout;
122   } else if (out_value()->isa<tensor::MetaSparseTensor>()) {
123     return PyNativeAlgo::AutoGrad::BuildSparseTensorNode(ir_function_node()->tape(), out_value(), accumulate_dout);
124   }
125   return accumulate_dout;
126 }
127 }  // namespace mindspore::pynative::autograd
128