#include <ATen/native/cuda/fused_adam_impl.cuh>

#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/native/cuda/fused_adam_utils.cuh>
#include <vector>

namespace at::native {

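// Overload taking the learning rate as a host-side double: lr_ptr stays
// nullptr and the scalar lr value is forwarded to the kernel directly.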
void _fused_adam_cuda_impl_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

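  // Optional gradient-scaling and inf/nan-check scalars (e.g. produced by
  // AMP gradient scaling) are forwarded to the kernel as raw float
  // pointers, or nullptr when absent.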
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  const float* lr_ptr = nullptr;

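  // Dispatch on the parameter dtype (float/double plus Half and BFloat16)
  // and launch the fused Adam kernel across the four tensor lists.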
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_adam_kernel_cuda",
      [&]() {
        multi_tensor_apply_for_fused_optimizer<4>(
            tensor_lists,
            state_steps,
            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
            lr_ptr, // unused
            lr,
            beta1,
            beta2,
            weight_decay,
            eps,
            maximize,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

// The following overload simply has a Tensor lr
void _fused_adam_cuda_impl_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
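  // lr is a Tensor in this overload, so the kernel reads the learning rate
  // through lr_ptr at run time; the double lr slot passed below is a
  // placeholder.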
  const float* lr_ptr = lr.const_data_ptr<float>();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_adam_kernel_cuda",
      [&]() {
        multi_tensor_apply_for_fused_optimizer<4>(
            tensor_lists,
            state_steps,
            FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
            lr_ptr,
            1.0, // unused
            beta1,
            beta2,
            weight_decay,
            eps,
            maximize,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

} // namespace at::native