1 #pragma once 2 #include <ATen/core/Tensor.h> 3 4 namespace at::native::mps { 5 6 void _fused_adam_amsgrad_mps_impl_( 7 TensorList params, 8 TensorList grads, 9 TensorList exp_avgs, 10 TensorList exp_avg_sqs, 11 TensorList max_exp_avg_sqs, 12 TensorList state_steps, 13 const double lr, 14 const double beta1, 15 const double beta2, 16 const double weight_decay, 17 const double eps, 18 const bool maximize, 19 const std::optional<Tensor>& grad_scale, 20 const std::optional<Tensor>& found_inf 21 ); 22 23 void _fused_adam_amsgrad_mps_impl_( 24 TensorList params, 25 TensorList grads, 26 TensorList exp_avgs, 27 TensorList exp_avg_sqs, 28 TensorList max_exp_avg_sqs, 29 TensorList state_steps, 30 const at::Tensor& lr, 31 const double beta1, 32 const double beta2, 33 const double weight_decay, 34 const double eps, 35 const bool maximize, 36 const std::optional<at::Tensor>& grad_scale, 37 const std::optional<at::Tensor>& found_inf 38 ); 39 40 } // namespace at::native::mps 41