#include <ATen/native/cuda/fused_adamw_impl.cuh>

#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/native/cuda/fused_adam_utils.cuh>
#include <vector>

namespace at {
namespace native {

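// Fused AdamW step with a scalar (double) learning rate. Applies the update
// in place across the given tensor lists by launching a single fused kernel
// via multi_tensor_apply_for_fused_optimizer.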
void _fused_adamw_cuda_impl_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

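  // Optional AMP tensors: forward their float data pointers when provided,
  // nullptr otherwise.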
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  const float* lr_ptr = nullptr;

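  // Dispatch on the parameter dtype (float, double, half, bfloat16) and run
  // the fused kernel over the four tensor lists; ADAM_MODE::ADAMW selects the
  // decoupled weight-decay update.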
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_adamw_kernel_cuda",
      [&]() {
        multi_tensor_apply_for_fused_optimizer<4>(
            tensor_lists,
            state_steps,
            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
            lr_ptr, // unused
            lr,
            beta1,
            beta2,
            weight_decay,
            eps,
            maximize,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

// The following overload is identical except that lr is passed as a Tensor.
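// The learning rate is read through lr_ptr inside the kernel (e.g. so a
// device-resident lr can be used without a host sync); the scalar lr slot is
// filled with 1.0 and ignored.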
void _fused_adamw_cuda_impl_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  const float* lr_ptr = lr.const_data_ptr<float>();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_adamw_kernel_cuda",
      [&]() {
        multi_tensor_apply_for_fused_optimizer<4>(
            tensor_lists,
            state_steps,
            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
            lr_ptr,
            1.0, // unused
            beta1,
            beta2,
            weight_decay,
            eps,
            maximize,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

} // namespace native
} // namespace at