#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h>
#include <ATen/native/mps/operations/FusedAdamWKernelImpl.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adamw_native.h>
#endif

namespace at::native {

void _fused_adamw_kernel_mps_(at::TensorList params,
                              at::TensorList grads,
                              at::TensorList exp_avgs,
                              at::TensorList exp_avg_sqs,
                              at::TensorList max_exp_avg_sqs,
                              at::TensorList state_steps,
                              const double lr,
                              const double beta1,
                              const double beta2,
                              const double weight_decay,
                              const double eps,
                              const bool amsgrad,
                              const bool maximize,
                              const std::optional<at::Tensor>& grad_scale,
                              const std::optional<at::Tensor>& found_inf) {
  if (amsgrad) {
    TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
                "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    mps::_fused_adamw_amsgrad_mps_impl_(params,
                                        grads,
                                        exp_avgs,
                                        exp_avg_sqs,
                                        max_exp_avg_sqs,
                                        state_steps,
                                        lr,
                                        beta1,
                                        beta2,
                                        weight_decay,
                                        eps,
                                        maximize,
                                        grad_scale,
                                        found_inf);
  } else {
    TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
                "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    mps::_fused_adamw_mps_impl_(params,
                                grads,
                                exp_avgs,
                                exp_avg_sqs,
                                state_steps,
                                lr,
                                beta1,
                                beta2,
                                weight_decay,
                                eps,
                                maximize,
                                grad_scale,
                                found_inf);
  }
}

void _fused_adamw_kernel_mps_(at::TensorList params,
                              at::TensorList grads,
                              at::TensorList exp_avgs,
                              at::TensorList exp_avg_sqs,
                              at::TensorList max_exp_avg_sqs,
                              at::TensorList state_steps,
                              const at::Tensor& lr,
                              const double beta1,
                              const double beta2,
                              const double weight_decay,
                              const double eps,
                              const bool amsgrad,
                              const bool maximize,
                              const std::optional<at::Tensor>& grad_scale,
                              const std::optional<at::Tensor>& found_inf) {
  if (lr.is_cpu()) {
    return _fused_adamw_kernel_mps_(params,
                                    grads,
                                    exp_avgs,
                                    exp_avg_sqs,
                                    max_exp_avg_sqs,
                                    state_steps,
                                    lr.item<double>(),
                                    beta1,
                                    beta2,
                                    weight_decay,
                                    eps,
                                    amsgrad,
                                    maximize,
                                    grad_scale,
                                    found_inf);
  }

  // Manually check devices since we specify no device check in
  // native_functions.yaml
  Device param_device = params[0].device();

  TORCH_CHECK(lr.device() == param_device, "lr must be on the same GPU device as the params");

  if (amsgrad) {
    TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
                "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    mps::_fused_adamw_amsgrad_mps_impl_(params,
                                        grads,
                                        exp_avgs,
                                        exp_avg_sqs,
                                        max_exp_avg_sqs,
                                        state_steps,
                                        lr,
                                        beta1,
                                        beta2,
                                        weight_decay,
                                        eps,
                                        maximize,
                                        grad_scale,
                                        found_inf);
  } else {
    TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
                "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    mps::_fused_adamw_mps_impl_(params,
                                grads,
                                exp_avgs,
                                exp_avg_sqs,
                                state_steps,
                                lr,
                                beta1,
                                beta2,
                                weight_decay,
                                eps,
                                maximize,
                                grad_scale,
                                found_inf);
  }
}

} // namespace at::native