#include <ATen/native/cuda/fused_adamw_impl.cuh>

#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/native/cuda/fused_adam_utils.cuh>
#include <vector>

namespace at {
namespace native {

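// Performs one fused AdamW step in-place over the given tensor lists. The four
// lists (params, grads, exp_avgs, exp_avg_sqs) are grouped and handed to
// multi_tensor_apply_for_fused_optimizer together with FusedAdamMathFunctor in
// ADAM_MODE::ADAMW; AT_DISPATCH_FLOATING_TYPES_AND2 covers float, double,
// half, and bfloat16 parameters. In this overload the scalar lr is forwarded
// by value, so lr_ptr stays nullptr.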
void _fused_adamw_cuda_impl_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

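  // grad_scale / found_inf are optional tensors, typically supplied by the AMP
  // grad scaler: when present, the kernel uses grad_scale to unscale gradients
  // and skips the parameter update if found_inf flags an inf/NaN.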
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  const float* lr_ptr = nullptr;

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_adamw_kernel_cuda",
      [&]() {
        multi_tensor_apply_for_fused_optimizer<4>(
            tensor_lists,
            state_steps,
            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
            lr_ptr, // unused
            lr,
            beta1,
            beta2,
            weight_decay,
            eps,
            maximize,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

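// A minimal, illustrative call of the scalar-lr overload above (the values are
// hypothetical and shown only as a sketch, not taken from an actual call site):
//
//   _fused_adamw_cuda_impl_(
//       params, grads, exp_avgs, exp_avg_sqs, state_steps,
//       /*lr=*/1e-3, /*beta1=*/0.9, /*beta2=*/0.999, /*weight_decay=*/1e-2,
//       /*eps=*/1e-8, /*maximize=*/false, grad_scale, found_inf);
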
// The following overload is identical except that lr is passed as a Tensor.
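// Because the learning rate is read through lr_ptr rather than forwarded as a
// host double, it can live on the GPU (e.g. for capturable / CUDA-graph use),
// which is presumably the motivation for this Tensor variant.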
void _fused_adamw_cuda_impl_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  const float* lr_ptr = lr.const_data_ptr<float>();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_adamw_kernel_cuda",
      [&]() {
        multi_tensor_apply_for_fused_optimizer<4>(
            tensor_lists,
            state_steps,
            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
            lr_ptr,
            1.0, // unused
            beta1,
            beta2,
            weight_decay,
            eps,
            maximize,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

} // namespace native
} // namespace at