#pragma once #include #include #include #include #include namespace c10d { class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { public: // Takes a state and a callable hook. The inputs are Python objects. // The state is passed to the hook in runHook method, and it can be used to // maintain and update any state information during the execution of the hook. // The hook performs user-specified processing and returns a future indicating // asychronous communication of gradients. PythonCommHook(py::object state, py::object hook) : state_(std::move(state)), hook_(std::move(hook)) {} ~PythonCommHook() override; c10::intrusive_ptr runHook(GradBucket& bucket) override; at::Tensor parseHookResult(const c10::IValue& result) override; private: // Only needed for stateful communication. py::object state_; py::object hook_; }; } // namespace c10d