1 #pragma once 2 #include <ATen/core/Tensor.h> 3 4 namespace at::native::mps { 5 6 void _fused_adam_mps_impl_( 7 TensorList params, 8 TensorList grads, 9 TensorList exp_avgs, 10 TensorList exp_avg_sqs, 11 TensorList state_steps, 12 const double lr, 13 const double beta1, 14 const double beta2, 15 const double weight_decay, 16 const double eps, 17 const bool maximize, 18 const std::optional<Tensor>& grad_scale, 19 const std::optional<Tensor>& found_inf 20 ); 21 22 void _fused_adam_mps_impl_( 23 TensorList params, 24 TensorList grads, 25 TensorList exp_avgs, 26 TensorList exp_avg_sqs, 27 TensorList state_steps, 28 const Tensor& lr, 29 const double beta1, 30 const double beta2, 31 const double weight_decay, 32 const double eps, 33 const bool maximize, 34 const std::optional<Tensor>& grad_scale, 35 const std::optional<Tensor>& found_inf 36 ); 37 } // namespace at::native::mps 38