#pragma once #include #include #include #include #include #include namespace torch::autograd { using Variable = at::Tensor; struct Node; TORCH_API extern const char* ERR_BACKWARD_TWICE; /// A snapshot of a variable at a certain version. A `SavedVariable` stores /// enough information to reconstruct a variable from a certain point in time. class TORCH_API SavedVariable { public: SavedVariable() = default; SavedVariable( const Variable& variable, bool is_output, bool is_inplace_on_view = false); SavedVariable( const std::optional& variable, bool is_output, bool is_inplace_on_view = false); SavedVariable(SavedVariable&&) = default; SavedVariable& operator=(SavedVariable&&) = default; ~SavedVariable() { if (fw_grad_) { // See note [ Using ForwardGrad ] fw_grad_->clear(); } } /// Reconstructs the saved variable. Pass `saved_for` as the gradient /// function if constructing the `SavedVariable` with it would have caused a /// circular reference. Variable unpack(std::shared_ptr saved_for = nullptr) const; void register_hooks(std::unique_ptr&& hooks); void reset_data(); bool has_hooks() const { return (bool)hooks_; } private: // This field contains either: // 1. the variable to save // 2. or its tensor_data. // If storing the variable itself would create a circular reference, // we fall into the second case and its metadata is also saved separately. // In that case, the grad_fn must be passed in to the unpack function when // reconstructing the Variable (except when we are doing an inplace operation // on a view, see below). The field saved_original_ below reflects the two // cases: its value is true in the first case and false in the second case. // The value data_.defined() can be false in three cases: // 1. SavedVariable was constructed without a Tensor (the value to save is // None), in that case was_default_constructed_ will be kept at true // 2. The saved variable has been released by calling // SavedVariable::reset_data(), typically during the backward pass // 3. Hooks have been registered. In that case, hooks_ will be defined // instead. Note that the value of saved_original_ only reflects what happened // during the construction of the SavedVariable. If saved_original_ is true, // we saved the original tensor in data_, but if the user registers hooks, we // will no longer have it (despite the saved_original_ still being true) at::Tensor data_; // This field is used to store the forward AD gradients associated with // the saved Tensor. Note that this shared_ptr must never be shared with // either the saved Tensor or the unpacked Tensor. See note [ Using // ForwardGrad ] std::shared_ptr fw_grad_; // Weak version of grad_fn_ that prevents leaks in rebase_history() for // inplace views. // This variable is used when the user chooses to create a SavedVariable with // is_inplace_on_view = true. // In that case, the grad_fn passed in to the unpack function at unwrapping // time is unused. std::weak_ptr weak_grad_fn_; uint32_t saved_version_ = 0; uint32_t output_nr_ = 0; bool was_default_constructed_ = true; bool is_inplace_on_view_ = false; bool saved_original_ = false; bool is_leaf_ = false; bool is_output_ = false; // Hooks are a pair of functions pack_hook/unpack_hook that provides // fine-grained control over how the SavedVariable should save its data. // pack_hook is called upon registration, while unpack_hook is called when // unpacking. std::unique_ptr hooks_; // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if // hooks are defined. They are set before pack_hook is called and used after // unpack_hook is called. std::shared_ptr grad_fn_; // For the usual case where leaf tensors are the input, we expect its // grad_acc to be kept alive by the graph. The reason SavedVariable holds // a owning reference is to support the case where a custom autograd Function // saves an intermediate. std::shared_ptr grad_accumulator_; bool requires_grad_ = false; void save_metadata(const Variable& data); static std::unique_ptr get_default_hooks(); void set_hooks_and_pack_data( std::unique_ptr&& hooks, const Variable& data); }; } // namespace torch::autograd