• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_META_GRAD_DATA_H_
18 #define MINDSPORE_CORE_IR_META_GRAD_DATA_H_
19 
20 #include <memory>
21 #include <utility>
22 #include <map>
23 #include <vector>
24 #include <string>
25 #include "ir/anf.h"
26 #include "include/common/utils/utils.h"
27 
28 namespace mindspore {
29 namespace pynative::autograd {
30 class Variable;
31 }  // namespace pynative::autograd
32 
33 class TensorBackwardHook;
34 using TensorBackwardHookPtr = std::shared_ptr<TensorBackwardHook>;
35 using VariablePtr = std::shared_ptr<pynative::autograd::Variable>;
36 using VariableWeakPtr = std::weak_ptr<pynative::autograd::Variable>;
37 
38 class AutoGradMetaData {
39  public:
40   AutoGradMetaData() = default;
41   AutoGradMetaData(const VariablePtr &variable, const ParameterPtr &parameter,
42                    const InputType input_type = InputType::kConstant)
variable_(variable)43       : variable_(variable), parameter_(parameter), input_type_(input_type) {}
variable()44   VariablePtr variable() const { return variable_.lock(); }
set_variable(const VariablePtr & variable)45   void set_variable(const VariablePtr &variable) { variable_ = variable; }
parameter()46   ParameterPtr parameter() const { return parameter_.lock(); }
set_parameter(const ParameterPtr & parameter)47   void set_parameter(const ParameterPtr &parameter) { parameter_ = parameter; }
set_k_node(const AnfNodePtr & k_node)48   void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
k_node()49   AnfNodePtr k_node() const { return k_node_.lock(); }
input_type()50   InputType input_type() const { return input_type_; }
set_input_type(InputType input_type)51   void set_input_type(InputType input_type) { input_type_ = input_type; }
op_index()52   size_t op_index() const { return op_index_; }
set_op_index(size_t op_index)53   void set_op_index(size_t op_index) { op_index_ = op_index; }
output_index()54   [[nodiscard]] size_t output_index() const { return output_index_; }
set_output_index(size_t output_index)55   void set_output_index(size_t output_index) { output_index_ = output_index; }
AddBackwardHook(uint64_t id,TensorBackwardHookPtr hook)56   void AddBackwardHook(uint64_t id, TensorBackwardHookPtr hook) {
57     (void)backward_hooks_.emplace(id, std::move(hook));
58     is_register_hook_ = true;
59   }
RemoveBackwardHook(uint64_t id)60   void RemoveBackwardHook(uint64_t id) { (void)backward_hooks_.erase(id); }
is_register_hook()61   bool is_register_hook() const { return is_register_hook_; }
backward_hooks()62   const std::map<uint64_t, TensorBackwardHookPtr> &backward_hooks() { return backward_hooks_; }
ClearBackwardHooks()63   void ClearBackwardHooks() { backward_hooks_.clear(); }
64 
65  private:
66   // Weakptr for variable, to avoid circular reference
67   VariableWeakPtr variable_;
68   // Weakptr to hold ir parameter of input or parameter
69   ParameterWeakPtr parameter_;
70   // Weakptr to k_node for tensor
71   AnfNodeWeakPtr k_node_;
72   // Type of grad tensor
73   InputType input_type_{InputType::kUnkown};
74   // Optional for op output, represent index of op in execute order.
75   size_t op_index_{0};
76   // Index of op output tensors.
77   size_t output_index_{0};
78   bool is_register_hook_{false};
79   // Tensor hooks
80   std::map<uint64_t, TensorBackwardHookPtr> backward_hooks_;
81 };
82 using AutoGradMetaDataPtr = std::shared_ptr<AutoGradMetaData>;
83 }  // namespace mindspore
84 #endif  // MINDSPORE_CORE_IR_META_GRAD_DATA_H_
85