#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h>
#include <ATen/native/mps/operations/FusedAdamKernelImpl.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adam_native.h>
#endif

namespace at::native {
using namespace mps;

// Applies one fused Adam step on MPS to every tensor in `params`, dispatching to
// the AMSGrad variant when `amsgrad` is true.
void _fused_adam_kernel_mps_(TensorList params,
                             TensorList grads,
                             TensorList exp_avgs,
                             TensorList exp_avg_sqs,
                             TensorList max_exp_avg_sqs,
                             TensorList state_steps,
                             const double lr,
                             const double beta1,
                             const double beta2,
                             const double weight_decay,
                             const double eps,
                             const bool amsgrad,
                             const bool maximize,
                             const std::optional<Tensor>& grad_scale,
                             const std::optional<Tensor>& found_inf) {
  if (amsgrad) {
    TORCH_CHECK(native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
                "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_amsgrad_mps_impl_(params,
                                  grads,
                                  exp_avgs,
                                  exp_avg_sqs,
                                  max_exp_avg_sqs,
                                  state_steps,
                                  lr,
                                  beta1,
                                  beta2,
                                  weight_decay,
                                  eps,
                                  maximize,
                                  grad_scale,
                                  found_inf);
  } else {
    TORCH_CHECK(native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
                "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_mps_impl_(params,
                          grads,
                          exp_avgs,
                          exp_avg_sqs,
                          state_steps,
                          lr,
                          beta1,
                          beta2,
                          weight_decay,
                          eps,
                          maximize,
                          grad_scale,
                          found_inf);
  }
}

// The following overload simply has a Tensor lr: a CPU lr is read out and
// forwarded to the double-lr overload above; otherwise lr must live on the same
// device as the params.
void _fused_adam_kernel_mps_(TensorList params,
                             TensorList grads,
                             TensorList exp_avgs,
                             TensorList exp_avg_sqs,
                             TensorList max_exp_avg_sqs,
                             TensorList state_steps,
                             const Tensor& lr,
                             const double beta1,
                             const double beta2,
                             const double weight_decay,
                             const double eps,
                             const bool amsgrad,
                             const bool maximize,
                             const std::optional<Tensor>& grad_scale,
                             const std::optional<Tensor>& found_inf) {
  if (lr.is_cpu()) {
    return _fused_adam_kernel_mps_(params,
                                   grads,
                                   exp_avgs,
                                   exp_avg_sqs,
                                   max_exp_avg_sqs,
                                   state_steps,
                                   lr.item<double>(),
                                   beta1,
                                   beta2,
                                   weight_decay,
                                   eps,
                                   amsgrad,
                                   maximize,
                                   grad_scale,
                                   found_inf);
  }

  // Manually check devices since we specify no device check in
  // native_functions.yaml
  Device param_device = params[0].device();
  TORCH_CHECK(lr.device() == param_device, "lr must be on the same GPU device as the params");

  if (amsgrad) {
    TORCH_CHECK(native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
                "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_amsgrad_mps_impl_(params,
                                  grads,
                                  exp_avgs,
                                  exp_avg_sqs,
                                  max_exp_avg_sqs,
                                  state_steps,
                                  lr,
                                  beta1,
                                  beta2,
                                  weight_decay,
                                  eps,
                                  maximize,
                                  grad_scale,
                                  found_inf);
  } else {
    TORCH_CHECK(native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
                "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_mps_impl_(params,
                          grads,
                          exp_avgs,
                          exp_avg_sqs,
                          state_steps,
                          lr,
                          beta1,
                          beta2,
                          weight_decay,
                          eps,
                          maximize,
                          grad_scale,
                          found_inf);
  }
}

} // namespace at::native
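
// -----------------------------------------------------------------------------
// Hedged usage sketch (illustrative, not part of this file's implementation):
// the kernels above back the `_fused_adam_` native op, so from C++ they are
// normally reached through the generated `at::_fused_adam_` wrapper with MPS
// tensors rather than being called directly. The wrapper signature is assumed
// here to mirror the kernel signature above; passing an empty `max_exp_avg_sqs`
// list when `amsgrad == false`, and keeping the step counter as a float tensor
// on the param device, are assumptions borrowed from the Python fused-optimizer
// path (torch.optim.Adam(fused=True)).
//
//   #include <torch/torch.h>
//
//   void run_one_fused_adam_step() {
//     auto opts = torch::TensorOptions().device(torch::kMPS).dtype(torch::kFloat);
//     auto param = torch::randn({16}, opts);
//     auto grad = torch::randn({16}, opts);
//     auto exp_avg = torch::zeros({16}, opts);
//     auto exp_avg_sq = torch::zeros({16}, opts);
//     auto step = torch::ones({}, opts);  // 0-dim float step counter (assumption, see above)
//
//     at::_fused_adam_({param}, {grad}, {exp_avg}, {exp_avg_sq},
//                      /*max_exp_avg_sqs=*/{}, /*state_steps=*/{step},
//                      /*lr=*/1e-3, /*beta1=*/0.9, /*beta2=*/0.999,
//                      /*weight_decay=*/0.0, /*eps=*/1e-8,
//                      /*amsgrad=*/false, /*maximize=*/false,
//                      /*grad_scale=*/{}, /*found_inf=*/{});
//   }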