• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 
4 namespace at::native::mps {
5 
6 void _fused_adam_mps_impl_(
7     TensorList params,
8     TensorList grads,
9     TensorList exp_avgs,
10     TensorList exp_avg_sqs,
11     TensorList state_steps,
12     const double lr,
13     const double beta1,
14     const double beta2,
15     const double weight_decay,
16     const double eps,
17     const bool maximize,
18     const std::optional<Tensor>& grad_scale,
19     const std::optional<Tensor>& found_inf
20 );
21 
22 void _fused_adam_mps_impl_(
23     TensorList params,
24     TensorList grads,
25     TensorList exp_avgs,
26     TensorList exp_avg_sqs,
27     TensorList state_steps,
28     const Tensor& lr,
29     const double beta1,
30     const double beta2,
31     const double weight_decay,
32     const double eps,
33     const bool maximize,
34     const std::optional<Tensor>& grad_scale,
35     const std::optional<Tensor>& found_inf
36 );
37 } // namespace at::native::mps
38