#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else #include #include #endif namespace at::native { namespace mps { static void _fused_sgd_with_momentum_kernel_mps_(TensorList params, TensorList grads, TensorList momentum_buffer_list, const double weight_decay, const double momentum, const double lr, const double dampening, const bool nesterov, const bool maximize, const bool is_first_step, const std::optional& grad_scale, const std::optional& found_inf) { TORCH_CHECK_GT(momentum, 0); TORCH_CHECK(native::check_fast_path_restrictions({params, grads, momentum_buffer_list})); std::vector> tensor_lists{params.vec(), grads.vec(), momentum_buffer_list.vec()}; const std::string kernel_name = "fused_sgd_momentum_" + scalarToMetalTypeString(params[0].scalar_type()); TensorList state_steps; multi_tensor_apply_for_fused_optimizer<3, 512>(kernel_name, tensor_lists, state_steps, FusedSgdEncodingFunctor(), weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step); } static void _fused_sgd_with_momentum_kernel_mps_(TensorList params, TensorList grads, TensorList momentum_buffer_list, const double weight_decay, const double momentum, const Tensor& lr_tensor, const double dampening, const bool nesterov, const bool maximize, const bool is_first_step, const std::optional& grad_scale, const std::optional& found_inf) { TORCH_CHECK_GT(momentum, 0); TORCH_CHECK(native::check_fast_path_restrictions({params, grads, momentum_buffer_list})); std::vector> tensor_lists{params.vec(), grads.vec(), momentum_buffer_list.vec()}; const auto kernel_name = "fused_sgd_momentum_" + scalarToMetalTypeString(params[0].scalar_type()); TensorList state_steps; multi_tensor_apply_for_fused_optimizer<3, 512>(kernel_name, tensor_lists, state_steps, FusedSgdEncodingFunctor(), weight_decay, momentum, lr_tensor, dampening, nesterov, maximize, is_first_step); } } // namespace mps using namespace mps; void _fused_sgd_kernel_mps_(TensorList params, TensorList grads, TensorList momentum_buffer_list, const double weight_decay, const double momentum, const double lr, const double dampening, const bool nesterov, const bool maximize, const bool is_first_step, const std::optional& grad_scale, const std::optional& found_inf) { TORCH_CHECK(!grad_scale.has_value() && !found_inf.has_value(), "grad_scale and found_inf are not supported on MPS"); if (!momentum_buffer_list.empty()) { return _fused_sgd_with_momentum_kernel_mps_(params, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); } TORCH_CHECK_EQ(momentum, 0); TORCH_CHECK(native::check_fast_path_restrictions({params, grads})); if (is_first_step) { TORCH_WARN_ONCE("`is_first_step` argument has no effect when `momentum_buffer_list` is empty"); } std::vector> tensor_lists{params.vec(), grads.vec()}; const auto kernel_name = "fused_sgd_" + scalarToMetalTypeString(params[0].scalar_type()); TensorList state_steps; multi_tensor_apply_for_fused_optimizer<2, 512>(kernel_name, tensor_lists, state_steps, FusedSgdEncodingFunctor(), weight_decay, lr, maximize); } void _fused_sgd_kernel_mps_(TensorList params, TensorList grads, TensorList momentum_buffer_list, const double weight_decay, const double momentum, const Tensor& lr_tensor, const double dampening, const bool nesterov, const bool maximize, const bool is_first_step, const std::optional& grad_scale, const std::optional& found_inf) { TORCH_CHECK(!grad_scale.has_value() && !found_inf.has_value(), "grad_scale and found_inf are not supported on MPS"); if (!momentum_buffer_list.empty()) { return _fused_sgd_with_momentum_kernel_mps_(params, grads, momentum_buffer_list, weight_decay, momentum, lr_tensor, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); } if (lr_tensor.is_cpu()) { return _fused_sgd_kernel_mps_(params, grads, momentum_buffer_list, weight_decay, momentum, lr_tensor.item(), dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); } TORCH_CHECK_EQ(momentum, 0); TORCH_CHECK(native::check_fast_path_restrictions({params, grads})); if (is_first_step) { TORCH_WARN_ONCE("`is_first_step` argument has no effect when `momentum_buffer_list` is empty"); } TORCH_CHECK(lr_tensor.device() == params[0].device(), "lr must be on the same GPU device as the params"); std::vector> tensor_lists{params.vec(), grads.vec()}; const std::string kernel_name = "fused_sgd_" + mps::scalarToMetalTypeString(params[0].scalar_type()); TensorList state_steps; multi_tensor_apply_for_fused_optimizer<2, 512>(kernel_name, tensor_lists, state_steps, FusedSgdEncodingFunctor(), weight_decay, lr_tensor, maximize); } } // namespace at::native