#include <ATen/native/cuda/fused_adam_impl.cuh>

#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/native/cuda/fused_adam_utils.cuh>
#include <vector>

namespace at::native {

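// In-place fused Adam step over lists of CUDA tensors, with the learning rate
// passed as a host-side double. The per-element math is performed by
// FusedAdamMathFunctor, launched through multi_tensor_apply_for_fused_optimizer.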
void _fused_adam_cuda_impl_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
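  // Group the per-parameter tensors into the depth-4 layout
  // (params, grads, exp_avgs, exp_avg_sqs) expected by
  // multi_tensor_apply_for_fused_optimizer<4>.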
  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

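  // Optional AMP tensors: when they are absent, nullptr is passed so the
  // kernel can skip gradient unscaling and the found_inf check.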
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  const float* lr_ptr = nullptr;

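  // Instantiate and launch the kernel for every supported parameter dtype:
  // float, double, half, and bfloat16.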
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_adam_kernel_cuda",
      [&]() {
        multi_tensor_apply_for_fused_optimizer<4>(
            tensor_lists,
            state_steps,
            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
            lr_ptr, // unused
            lr,
            beta1,
            beta2,
            weight_decay,
            eps,
            maximize,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

// The following overload simply has a Tensor lr
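// Here the value of lr is dereferenced from the tensor's float data inside
// the kernel (via lr_ptr); the scalar lr argument passed to the functor is
// unused.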
void _fused_adam_cuda_impl_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
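  // Read lr directly from the tensor; const_data_ptr<float>() requires the
  // tensor to have float dtype.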
  const float* lr_ptr = lr.const_data_ptr<float>();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_adam_kernel_cuda",
      [&]() {
        multi_tensor_apply_for_fused_optimizer<4>(
            tensor_lists,
            state_steps,
            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
            lr_ptr,
            1.0, // unused
            beta1,
            beta2,
            weight_decay,
            eps,
            maximize,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

} // namespace at::native