1#define TORCH_ASSERT_ONLY_METHOD_OPERATORS 2#include <ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h> 3 4#include <ATen/Dispatch.h> 5#include <ATen/native/ForeachUtils.h> 6#include <ATen/native/mps/operations/FusedOptimizerOps.h> 7#include <ATen/native/mps/operations/MultiTensorApply.h> 8#include <vector> 9 10namespace at::native::mps { 11 12void _fused_adam_amsgrad_mps_impl_(TensorList params, 13 TensorList grads, 14 TensorList exp_avgs, 15 TensorList exp_avg_sqs, 16 TensorList max_exp_avg_sqs, 17 TensorList state_steps, 18 const double lr, 19 const double beta1, 20 const double beta2, 21 const double weight_decay, 22 const double eps, 23 const bool maximize, 24 const std::optional<Tensor>& grad_scale, 25 const std::optional<Tensor>& found_inf) { 26 std::vector<std::vector<Tensor>> tensor_lists{ 27 params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()}; 28 29 const auto kernel_name = 30 "fused_adam_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]); 31 32 multi_tensor_apply_for_fused_optimizer<5, 512>(kernel_name, 33 tensor_lists, 34 state_steps, 35 FusedAdamEncodingFunctor(), 36 lr, 37 beta1, 38 beta2, 39 weight_decay, 40 eps, 41 maximize); 42} 43 44void _fused_adam_amsgrad_mps_impl_(TensorList params, 45 TensorList grads, 46 TensorList exp_avgs, 47 TensorList exp_avg_sqs, 48 TensorList max_exp_avg_sqs, 49 TensorList state_steps, 50 const Tensor& lr, 51 const double beta1, 52 const double beta2, 53 const double weight_decay, 54 const double eps, 55 const bool maximize, 56 const std::optional<Tensor>& grad_scale, 57 const std::optional<Tensor>& found_inf) { 58 std::vector<std::vector<Tensor>> tensor_lists{ 59 params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()}; 60 61 const std::string kernel_name = 62 "fused_adam_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]); 63 64 multi_tensor_apply_for_fused_optimizer<5, 512>(kernel_name, 65 tensor_lists, 66 state_steps, 67 FusedAdamEncodingFunctor(), 68 lr, 69 beta1, 70 beta2, 71 weight_decay, 72 eps, 73 maximize); 74} 75 76} // namespace at::native::mps 77