#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h>

#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
#include <ATen/native/mps/operations/MultiTensorApply.h>
#include <vector>

namespace at::native::mps {

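// Fused AdamW step with AMSGrad on MPS, taking the learning rate as a double.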
void _fused_adamw_amsgrad_mps_impl_(TensorList params,
                                    TensorList grads,
                                    TensorList exp_avgs,
                                    TensorList exp_avg_sqs,
                                    TensorList max_exp_avg_sqs,
                                    TensorList state_steps,
                                    const double lr,
                                    const double beta1,
                                    const double beta2,
                                    const double weight_decay,
                                    const double eps,
                                    const bool maximize,
                                    const std::optional<Tensor>& grad_scale,
                                    const std::optional<Tensor>& found_inf) {
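  // Group the parameters, gradients, and optimizer state so the fused kernel
  // can walk all five lists together.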
  std::vector<std::vector<Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()};

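  // The Metal kernel is specialized on the dtype of the parameters and of the step counters.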
  const auto kernel_name =
      "fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);

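  // Launch the fused kernel over the grouped tensor lists.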
  multi_tensor_apply_for_fused_optimizer<5, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedAdamEncodingFunctor(),
                                                 lr,
                                                 beta1,
                                                 beta2,
                                                 weight_decay,
                                                 eps,
                                                 maximize);
}
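
// Overload taking the learning rate as a Tensor; otherwise identical to the double-lr version above.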
void _fused_adamw_amsgrad_mps_impl_(TensorList params,
                                    TensorList grads,
                                    TensorList exp_avgs,
                                    TensorList exp_avg_sqs,
                                    TensorList max_exp_avg_sqs,
                                    TensorList state_steps,
                                    const Tensor& lr,
                                    const double beta1,
                                    const double beta2,
                                    const double weight_decay,
                                    const double eps,
                                    const bool maximize,
                                    const std::optional<Tensor>& grad_scale,
                                    const std::optional<Tensor>& found_inf) {
  std::vector<std::vector<Tensor>> tensor_lists{
      params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()};

  const auto kernel_name =
      "fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);

  multi_tensor_apply_for_fused_optimizer<5, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedAdamEncodingFunctor(),
                                                 lr,
                                                 beta1,
                                                 beta2,
                                                 weight_decay,
                                                 eps,
                                                 maximize);
}

} // namespace at::native::mps