• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#pragma once

#include <cstdint>

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
3 
4 namespace at::native {
5 
6 enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };
7 
8 using fused_adam_fn = void (*)(
9     const at::Tensor& param,
10     const at::Tensor& grad,
11     const at::Tensor& exp_avg,
12     const at::Tensor& exp_avg_sq,
13     const at::Tensor& max_exp_avg_sq,
14     const at::Tensor& state_step,
15     const double lr,
16     const double beta1,
17     const double beta2,
18     const double weight_decay,
19     const double eps,
20     const bool amsgrad,
21     const bool maximize,
22     const float* grad_scale_ptr,
23     const ADAM_MODE);
24 
25 DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub);
26 
27 } // namespace at::native
28