#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/MultiTensorApply.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_sgd.h>
#include <ATen/ops/_fused_sgd_native.h>
#endif

namespace at::native {

namespace mps {

static void _fused_sgd_with_momentum_kernel_mps_(TensorList params,
                                                 TensorList grads,
                                                 TensorList momentum_buffer_list,
                                                 const double weight_decay,
                                                 const double momentum,
                                                 const double lr,
                                                 const double dampening,
                                                 const bool nesterov,
                                                 const bool maximize,
                                                 const bool is_first_step,
                                                 const std::optional<Tensor>& grad_scale,
                                                 const std::optional<Tensor>& found_inf) {
  TORCH_CHECK_GT(momentum, 0);
  TORCH_CHECK(native::check_fast_path_restrictions({params, grads, momentum_buffer_list}));

  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec(), momentum_buffer_list.vec()};

  const auto kernel_name = "fused_sgd_momentum_" + scalarToMetalTypeString(params[0].scalar_type());

  // SGD keeps no step count, so an empty state_steps list is passed through.
  TensorList state_steps;

  // Three tensor lists: params, grads, momentum buffers.
  multi_tensor_apply_for_fused_optimizer<3, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedSgdEncodingFunctor<true /*momentum*/>(),
                                                 weight_decay,
                                                 momentum,
                                                 lr,
                                                 dampening,
                                                 nesterov,
                                                 maximize,
                                                 is_first_step);
}

static void _fused_sgd_with_momentum_kernel_mps_(TensorList params,
                                                 TensorList grads,
                                                 TensorList momentum_buffer_list,
                                                 const double weight_decay,
                                                 const double momentum,
                                                 const Tensor& lr_tensor,
                                                 const double dampening,
                                                 const bool nesterov,
                                                 const bool maximize,
                                                 const bool is_first_step,
                                                 const std::optional<Tensor>& grad_scale,
                                                 const std::optional<Tensor>& found_inf) {
  TORCH_CHECK_GT(momentum, 0);
  TORCH_CHECK(native::check_fast_path_restrictions({params, grads, momentum_buffer_list}));

  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec(), momentum_buffer_list.vec()};

  const auto kernel_name = "fused_sgd_momentum_" + scalarToMetalTypeString(params[0].scalar_type());

  TensorList state_steps;

  multi_tensor_apply_for_fused_optimizer<3, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedSgdEncodingFunctor<true /*momentum*/>(),
                                                 weight_decay,
                                                 momentum,
                                                 lr_tensor,
                                                 dampening,
                                                 nesterov,
                                                 maximize,
                                                 is_first_step);
}

} // namespace mps

using namespace mps;
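// NOTE: a minimal per-element sketch of the update these entry points are
// expected to encode, assuming the Metal kernels follow the standard
// torch.optim.SGD semantics (the authoritative math lives in the fused_sgd
// shaders, not here):
//
//   g = maximize ? -grad : grad;
//   g += weight_decay * param;
//   if (momentum != 0) {
//     buf = is_first_step ? g : momentum * buf + (1 - dampening) * g;
//     g = nesterov ? g + momentum * buf : buf;
//   }
//   param -= lr * g;
//
// These kernels back torch.optim.SGD(..., fused=True) for parameters that
// live on the MPS device.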
void _fused_sgd_kernel_mps_(TensorList params,
                            TensorList grads,
                            TensorList momentum_buffer_list,
                            const double weight_decay,
                            const double momentum,
                            const double lr,
                            const double dampening,
                            const bool nesterov,
                            const bool maximize,
                            const bool is_first_step,
                            const std::optional<Tensor>& grad_scale,
                            const std::optional<Tensor>& found_inf) {
  TORCH_CHECK(!grad_scale.has_value() && !found_inf.has_value(), "grad_scale and found_inf are not supported on MPS");

  // Momentum buffers provided: dispatch to the momentum variant.
  if (!momentum_buffer_list.empty()) {
    return _fused_sgd_with_momentum_kernel_mps_(params,
                                                grads,
                                                momentum_buffer_list,
                                                weight_decay,
                                                momentum,
                                                lr,
                                                dampening,
                                                nesterov,
                                                maximize,
                                                is_first_step,
                                                grad_scale,
                                                found_inf);
  }
  TORCH_CHECK_EQ(momentum, 0);
  TORCH_CHECK(native::check_fast_path_restrictions({params, grads}));
  if (is_first_step) {
    TORCH_WARN_ONCE("`is_first_step` argument has no effect when `momentum_buffer_list` is empty");
  }

  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec()};

  const auto kernel_name = "fused_sgd_" + scalarToMetalTypeString(params[0].scalar_type());

  TensorList state_steps;

  // Two tensor lists: params, grads.
  multi_tensor_apply_for_fused_optimizer<2, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedSgdEncodingFunctor<false /*momentum*/>(),
                                                 weight_decay,
                                                 lr,
                                                 maximize);
}

void _fused_sgd_kernel_mps_(TensorList params,
                            TensorList grads,
                            TensorList momentum_buffer_list,
                            const double weight_decay,
                            const double momentum,
                            const Tensor& lr_tensor,
                            const double dampening,
                            const bool nesterov,
                            const bool maximize,
                            const bool is_first_step,
                            const std::optional<Tensor>& grad_scale,
                            const std::optional<Tensor>& found_inf) {
  TORCH_CHECK(!grad_scale.has_value() && !found_inf.has_value(), "grad_scale and found_inf are not supported on MPS");

  if (!momentum_buffer_list.empty()) {
    return _fused_sgd_with_momentum_kernel_mps_(params,
                                                grads,
                                                momentum_buffer_list,
                                                weight_decay,
                                                momentum,
                                                lr_tensor,
                                                dampening,
                                                nesterov,
                                                maximize,
                                                is_first_step,
                                                grad_scale,
                                                found_inf);
  }
  // A CPU-resident lr tensor can be read on the host, so reuse the double-lr overload.
  if (lr_tensor.is_cpu()) {
    return _fused_sgd_kernel_mps_(params,
                                  grads,
                                  momentum_buffer_list,
                                  weight_decay,
                                  momentum,
                                  lr_tensor.item<double>(),
                                  dampening,
                                  nesterov,
                                  maximize,
                                  is_first_step,
                                  grad_scale,
                                  found_inf);
  }
  TORCH_CHECK_EQ(momentum, 0);
  TORCH_CHECK(native::check_fast_path_restrictions({params, grads}));
  if (is_first_step) {
    TORCH_WARN_ONCE("`is_first_step` argument has no effect when `momentum_buffer_list` is empty");
  }

  TORCH_CHECK(lr_tensor.device() == params[0].device(), "lr must be on the same device as the params");

  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec()};

  const auto kernel_name = "fused_sgd_" + scalarToMetalTypeString(params[0].scalar_type());

  TensorList state_steps;

  multi_tensor_apply_for_fused_optimizer<2, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedSgdEncodingFunctor<false /*momentum*/>(),
                                                 weight_decay,
                                                 lr_tensor,
                                                 maximize);
}

} // namespace at::native