#pragma once #include #include namespace c10d { enum class BuiltinCommHookType : uint8_t { ALLREDUCE = 1, FP16_COMPRESS = 2, }; class AllReduceCommHook : public CppCommHookInterface> { public: explicit AllReduceCommHook(const c10::intrusive_ptr& state) : CppCommHookInterface>(state) {} ~AllReduceCommHook() override = default; c10::intrusive_ptr runHook(GradBucket& bucket) override; }; class FP16CompressCommHook : public CppCommHookInterface> { public: explicit FP16CompressCommHook(const c10::intrusive_ptr& state) : CppCommHookInterface>(state) {} ~FP16CompressCommHook() override = default; c10::intrusive_ptr runHook(GradBucket& bucket) override; }; // Almost same as AllReduceCommHook, but without division inside the hook. // This enables the optimization of fusing copy and division and saves one scan // over all the input parameters, when no communication hook is provided by the // user. Only used internally and not released as a public built-in // communication hook. class _AllReduceBySumCommHook : public CppCommHookInterface> { public: explicit _AllReduceBySumCommHook( const c10::intrusive_ptr& state) : CppCommHookInterface>(state) {} ~_AllReduceBySumCommHook() override = default; c10::intrusive_ptr runHook(GradBucket& bucket) override; }; } // namespace c10d