#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/operations/FusedAdamWKernelImpl.h>

#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
#include <ATen/native/mps/operations/MultiTensorApply.h>
#include <vector>

namespace at::native::mps {

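// Scalar-lr overload: performs one fused AdamW step over all tensors in
// `params`, updating params, exp_avgs and exp_avg_sqs in-place via the
// multi_tensor_apply Metal path.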
void _fused_adamw_mps_impl_(TensorList params,
                            TensorList grads,
                            TensorList exp_avgs,
                            TensorList exp_avg_sqs,
                            TensorList state_steps,
                            const double lr,
                            const double beta1,
                            const double beta2,
                            const double weight_decay,
                            const double eps,
                            const bool maximize,
                            const std::optional<Tensor>& grad_scale,
                            const std::optional<Tensor>& found_inf) {
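  // grad_scale and found_inf are accepted for API parity with the other
  // backends; this implementation does not read them.
  // Pack the four equal-length lists so multi_tensor_apply can walk them
  // together; state_steps is forwarded to the helper separately.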
  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

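  // The Metal kernel is specialized per dtype: the name embeds the type of
  // the params and of state_steps (e.g. "fused_adamw_float_float").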
  const auto kernel_name =
      "fused_adamw_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);

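  // <4, 512>: four tensor lists travel together through each chunk; 512 is
  // presumably the threadgroup size the kernel is encoded with.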
  multi_tensor_apply_for_fused_optimizer<4, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedAdamEncodingFunctor(),
                                                 lr,
                                                 beta1,
                                                 beta2,
                                                 weight_decay,
                                                 eps,
                                                 maximize);
}

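// Tensor-lr overload, for optimizers that keep the learning rate as a tensor
// (e.g. so lr can be updated without a host round-trip); otherwise it mirrors
// the scalar overload above.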
void _fused_adamw_mps_impl_(TensorList params,
                            TensorList grads,
                            TensorList exp_avgs,
                            TensorList exp_avg_sqs,
                            TensorList state_steps,
                            const Tensor& lr,
                            const double beta1,
                            const double beta2,
                            const double weight_decay,
                            const double eps,
                            const bool maximize,
                            const std::optional<Tensor>& grad_scale,
                            const std::optional<Tensor>& found_inf) {
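  // Same body as the scalar overload; FusedAdamEncodingFunctor presumably has
  // an overload that binds the lr tensor when encoding the kernel.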
  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};

  const auto kernel_name =
      "fused_adamw_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);

  multi_tensor_apply_for_fused_optimizer<4, 512>(kernel_name,
                                                 tensor_lists,
                                                 state_steps,
                                                 FusedAdamEncodingFunctor(),
                                                 lr,
                                                 beta1,
                                                 beta2,
                                                 weight_decay,
                                                 eps,
                                                 maximize);
}

} // namespace at::native::mps