#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h>

#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
#include <ATen/native/mps/operations/MultiTensorApply.h>
#include <vector>

namespace at::native::mps {

// Fused AdamW (AMSGrad variant) on MPS: a single multi-tensor Metal kernel
// updates every parameter in the list, rather than one dispatch per tensor.
// Overload taking the learning rate as a host-side double.
void _fused_adamw_amsgrad_mps_impl_(TensorList params,
                                    TensorList grads,
                                    TensorList exp_avgs,
                                    TensorList exp_avg_sqs,
                                    TensorList max_exp_avg_sqs,
                                    TensorList state_steps,
                                    const double lr,
                                    const double beta1,
                                    const double beta2,
                                    const double weight_decay,
                                    const double eps,
                                    const bool maximize,
                                    const std::optional<Tensor>& grad_scale,
                                    const std::optional<Tensor>& found_inf) {
  // Group the per-parameter optimizer state so the multi-tensor kernel can
  // chunk across all five tensor lists in one pass.
  std::vector<std::vector<Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()};

  // The Metal kernel is specialized on the parameter dtype and the
  // state-step dtype.
  const auto kernel_name =
      "fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);

  multi_tensor_apply_for_fused_optimizer<5, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedAdamEncodingFunctor(),
                                                 lr,
                                                 beta1,
                                                 beta2,
                                                 weight_decay,
                                                 eps,
                                                 maximize);
}

// Overload taking the learning rate as a Tensor (e.g. when the optimizer
// stores lr as a tensor), so the value can be consumed on-device without a
// device-to-host read.
void _fused_adamw_amsgrad_mps_impl_(TensorList params,
                                    TensorList grads,
                                    TensorList exp_avgs,
                                    TensorList exp_avg_sqs,
                                    TensorList max_exp_avg_sqs,
                                    TensorList state_steps,
                                    const Tensor& lr,
                                    const double beta1,
                                    const double beta2,
                                    const double weight_decay,
                                    const double eps,
                                    const bool maximize,
                                    const std::optional<Tensor>& grad_scale,
                                    const std::optional<Tensor>& found_inf) {
  std::vector<std::vector<Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()};

  const auto kernel_name =
      "fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);

  multi_tensor_apply_for_fused_optimizer<5, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedAdamEncodingFunctor(),
                                                 lr,
                                                 beta1,
                                                 beta2,
                                                 weight_decay,
                                                 eps,
                                                 maximize);
}

} // namespace at::native::mps