1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_META_GRAD_DATA_H_ 18 #define MINDSPORE_CORE_IR_META_GRAD_DATA_H_ 19 20 #include <memory> 21 #include <utility> 22 #include <map> 23 #include <vector> 24 #include <string> 25 #include "ir/anf.h" 26 #include "include/common/utils/utils.h" 27 28 namespace mindspore { 29 namespace pynative::autograd { 30 class Variable; 31 } // namespace pynative::autograd 32 33 class TensorBackwardHook; 34 using TensorBackwardHookPtr = std::shared_ptr<TensorBackwardHook>; 35 using VariablePtr = std::shared_ptr<pynative::autograd::Variable>; 36 using VariableWeakPtr = std::weak_ptr<pynative::autograd::Variable>; 37 38 class AutoGradMetaData { 39 public: 40 AutoGradMetaData() = default; 41 AutoGradMetaData(const VariablePtr &variable, const ParameterPtr ¶meter, 42 const InputType input_type = InputType::kConstant) variable_(variable)43 : variable_(variable), parameter_(parameter), input_type_(input_type) {} variable()44 VariablePtr variable() const { return variable_.lock(); } set_variable(const VariablePtr & variable)45 void set_variable(const VariablePtr &variable) { variable_ = variable; } parameter()46 ParameterPtr parameter() const { return parameter_.lock(); } set_parameter(const ParameterPtr & parameter)47 void set_parameter(const ParameterPtr ¶meter) { parameter_ = parameter; } set_k_node(const AnfNodePtr & k_node)48 void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; } k_node()49 AnfNodePtr k_node() const { return k_node_.lock(); } input_type()50 InputType input_type() const { return input_type_; } set_input_type(InputType input_type)51 void set_input_type(InputType input_type) { input_type_ = input_type; } op_index()52 size_t op_index() const { return op_index_; } set_op_index(size_t op_index)53 void set_op_index(size_t op_index) { op_index_ = op_index; } output_index()54 [[nodiscard]] size_t output_index() const { return output_index_; } set_output_index(size_t output_index)55 void set_output_index(size_t output_index) { output_index_ = output_index; } AddBackwardHook(uint64_t id,TensorBackwardHookPtr hook)56 void AddBackwardHook(uint64_t id, TensorBackwardHookPtr hook) { 57 (void)backward_hooks_.emplace(id, std::move(hook)); 58 is_register_hook_ = true; 59 } RemoveBackwardHook(uint64_t id)60 void RemoveBackwardHook(uint64_t id) { (void)backward_hooks_.erase(id); } is_register_hook()61 bool is_register_hook() const { return is_register_hook_; } backward_hooks()62 const std::map<uint64_t, TensorBackwardHookPtr> &backward_hooks() { return backward_hooks_; } ClearBackwardHooks()63 void ClearBackwardHooks() { backward_hooks_.clear(); } 64 65 private: 66 // Weakptr for variable, to avoid circular reference 67 VariableWeakPtr variable_; 68 // Weakptr to hold ir parameter of input or parameter 69 ParameterWeakPtr parameter_; 70 // Weakptr to k_node for tensor 71 AnfNodeWeakPtr k_node_; 72 // Type of grad tensor 73 InputType input_type_{InputType::kUnkown}; 74 // Optional for op output, represent index of op in execute order. 75 size_t op_index_{0}; 76 // Index of op output tensors. 77 size_t output_index_{0}; 78 bool is_register_hook_{false}; 79 // Tensor hooks 80 std::map<uint64_t, TensorBackwardHookPtr> backward_hooks_; 81 }; 82 using AutoGradMetaDataPtr = std::shared_ptr<AutoGradMetaData>; 83 } // namespace mindspore 84 #endif // MINDSPORE_CORE_IR_META_GRAD_DATA_H_ 85