• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 
4 namespace at::native::mps {
5 
6 void _fused_adam_amsgrad_mps_impl_(
7     TensorList params,
8     TensorList grads,
9     TensorList exp_avgs,
10     TensorList exp_avg_sqs,
11     TensorList max_exp_avg_sqs,
12     TensorList state_steps,
13     const double lr,
14     const double beta1,
15     const double beta2,
16     const double weight_decay,
17     const double eps,
18     const bool maximize,
19     const std::optional<Tensor>& grad_scale,
20     const std::optional<Tensor>& found_inf
21 );
22 
23 void _fused_adam_amsgrad_mps_impl_(
24     TensorList params,
25     TensorList grads,
26     TensorList exp_avgs,
27     TensorList exp_avg_sqs,
28     TensorList max_exp_avg_sqs,
29     TensorList state_steps,
30     const at::Tensor& lr,
31     const double beta1,
32     const double beta2,
33     const double weight_decay,
34     const double eps,
35     const bool maximize,
36     const std::optional<at::Tensor>& grad_scale,
37     const std::optional<at::Tensor>& found_inf
38 );
39 
40 } // namespace at::native::mps
41