#pragma once #include #include #include #include #include #include #include namespace py = pybind11; namespace torch::autograd { struct PySavedVariableHooks : public SavedVariableHooks { PySavedVariableHooks(py::function& pack_hook, py::function& unpack_hook); void call_pack_hook(const at::Tensor& tensor) override; at::Tensor call_unpack_hook() override; ~PySavedVariableHooks() override; private: PyObject* pack_hook_; PyObject* unpack_hook_; PyObject* data_ = nullptr; }; struct PyDefaultSavedVariableHooks { static void push_hooks(py::function& pack_hook, py::function& unpack_hook); static void pop_hooks(); static std::unique_ptr get_hooks(); }; } // namespace torch::autograd