#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/operations/FusedAdamWKernelImpl.h>

#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
#include <ATen/native/mps/operations/MultiTensorApply.h>
#include <vector>

namespace at::native::mps {

// Fused AdamW step with a scalar (double) learning rate. The four tensor
// lists are packed together and handed to a single multi-tensor Metal kernel
// dispatch. grad_scale and found_inf are accepted for interface parity with
// the other fused-optimizer backends but are not consumed by this dispatch.
void _fused_adamw_mps_impl_(TensorList params,
                            TensorList grads,
                            TensorList exp_avgs,
                            TensorList exp_avg_sqs,
                            TensorList state_steps,
                            const double lr,
                            const double beta1,
                            const double beta2,
                            const double weight_decay,
                            const double eps,
                            const bool maximize,
                            const std::optional<Tensor>& grad_scale,
                            const std::optional<Tensor>& found_inf) {
  // Order matters: the kernel expects params, grads, exp_avgs, exp_avg_sqs.
  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

  // The Metal kernel is specialized on the parameter dtype and the
  // state_steps dtype, e.g. "fused_adamw_float_float".
  const auto kernel_name =
      "fused_adamw_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);

  multi_tensor_apply_for_fused_optimizer<4, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedAdamEncodingFunctor(),
                                                 lr,
                                                 beta1,
                                                 beta2,
                                                 weight_decay,
                                                 eps,
                                                 maximize);
}

// Overload for a tensor-valued learning rate (e.g. a 0-dim lr tensor); it
// differs from the overload above only in the type of lr.
void _fused_adamw_mps_impl_(TensorList params,
                            TensorList grads,
                            TensorList exp_avgs,
                            TensorList exp_avg_sqs,
                            TensorList state_steps,
                            const Tensor& lr,
                            const double beta1,
                            const double beta2,
                            const double weight_decay,
                            const double eps,
                            const bool maximize,
                            const std::optional<Tensor>& grad_scale,
                            const std::optional<Tensor>& found_inf) {
  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

  const auto kernel_name =
      "fused_adamw_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);

  multi_tensor_apply_for_fused_optimizer<4, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedAdamEncodingFunctor(),
                                                 lr,
                                                 beta1,
                                                 beta2,
                                                 weight_decay,
                                                 eps,
                                                 maximize);
}

} // namespace at::native::mps
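
// Usage sketch (illustrative only, not part of this file): how a caller might
// invoke the two overloads above. Real call sites live in the fused AdamW MPS
// dispatch code; the setup and hyperparameter values below are assumptions
// chosen for illustration.
//
//   std::vector<at::Tensor> params, grads, exp_avgs, exp_avg_sqs, state_steps;
//   // ... populate with same-dtype MPS tensors, one step counter per param ...
//
//   // A scalar learning rate selects the double overload:
//   at::native::mps::_fused_adamw_mps_impl_(
//       params, grads, exp_avgs, exp_avg_sqs, state_steps,
//       /*lr=*/1e-3, /*beta1=*/0.9, /*beta2=*/0.999,
//       /*weight_decay=*/1e-2, /*eps=*/1e-8, /*maximize=*/false,
//       /*grad_scale=*/std::nullopt, /*found_inf=*/std::nullopt);
//
//   // A 0-dim lr tensor selects the Tensor& overload instead, with the
//   // remaining arguments unchanged:
//   at::native::mps::_fused_adamw_mps_impl_(
//       params, grads, exp_avgs, exp_avg_sqs, state_steps,
//       /*lr=*/lr_tensor, /*beta1=*/0.9, /*beta2=*/0.999,
//       /*weight_decay=*/1e-2, /*eps=*/1e-8, /*maximize=*/false,
//       /*grad_scale=*/std::nullopt, /*found_inf=*/std::nullopt);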