#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h>
#include <ATen/native/mps/operations/FusedAdamWKernelImpl.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adamw_native.h>
#endif

namespace at::native {

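// Scalar-lr overload: verifies that every tensor list shares dtype, device,
// and layout (a requirement of the fused fast path), then dispatches to the
// AMSGrad or plain fused AdamW MPS implementation.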
void _fused_adamw_kernel_mps_(at::TensorList params,
                              at::TensorList grads,
                              at::TensorList exp_avgs,
                              at::TensorList exp_avg_sqs,
                              at::TensorList max_exp_avg_sqs,
                              at::TensorList state_steps,
                              const double lr,
                              const double beta1,
                              const double beta2,
                              const double weight_decay,
                              const double eps,
                              const bool amsgrad,
                              const bool maximize,
                              const std::optional<at::Tensor>& grad_scale,
                              const std::optional<at::Tensor>& found_inf) {
  if (amsgrad) {
    TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
                "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    mps::_fused_adamw_amsgrad_mps_impl_(params,
                                        grads,
                                        exp_avgs,
                                        exp_avg_sqs,
                                        max_exp_avg_sqs,
                                        state_steps,
                                        lr,
                                        beta1,
                                        beta2,
                                        weight_decay,
                                        eps,
                                        maximize,
                                        grad_scale,
                                        found_inf);
  } else {
    TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
                "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    mps::_fused_adamw_mps_impl_(params,
                                grads,
                                exp_avgs,
                                exp_avg_sqs,
                                state_steps,
                                lr,
                                beta1,
                                beta2,
                                weight_decay,
                                eps,
                                maximize,
                                grad_scale,
                                found_inf);
  }
}

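// Tensor-lr overload, used when the optimizer keeps lr as a 0-dim tensor
// instead of a plain double.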
void _fused_adamw_kernel_mps_(at::TensorList params,
                              at::TensorList grads,
                              at::TensorList exp_avgs,
                              at::TensorList exp_avg_sqs,
                              at::TensorList max_exp_avg_sqs,
                              at::TensorList state_steps,
                              const at::Tensor& lr,
                              const double beta1,
                              const double beta2,
                              const double weight_decay,
                              const double eps,
                              const bool amsgrad,
                              const bool maximize,
                              const std::optional<at::Tensor>& grad_scale,
                              const std::optional<at::Tensor>& found_inf) {
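  // A CPU-resident lr can be read on the host, so unwrap it and defer to
  // the scalar-lr overload above.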
  if (lr.is_cpu()) {
    return _fused_adamw_kernel_mps_(params,
                                    grads,
                                    exp_avgs,
                                    exp_avg_sqs,
                                    max_exp_avg_sqs,
                                    state_steps,
                                    lr.item<double>(),
                                    beta1,
                                    beta2,
                                    weight_decay,
                                    eps,
                                    amsgrad,
                                    maximize,
                                    grad_scale,
                                    found_inf);
  }

  // Manually check devices since we specify no device check in
  // native_functions.yaml
  Device param_device = params[0].device();

  TORCH_CHECK(lr.device() == param_device, "lr must be on the same GPU device as the params");

  if (amsgrad) {
    TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
                "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    mps::_fused_adamw_amsgrad_mps_impl_(params,
                                        grads,
                                        exp_avgs,
                                        exp_avg_sqs,
                                        max_exp_avg_sqs,
                                        state_steps,
                                        lr,
                                        beta1,
                                        beta2,
                                        weight_decay,
                                        eps,
                                        maximize,
                                        grad_scale,
                                        found_inf);
  } else {
    TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
                "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    mps::_fused_adamw_mps_impl_(params,
                                grads,
                                exp_avgs,
                                exp_avg_sqs,
                                state_steps,
                                lr,
                                beta1,
                                beta2,
                                weight_decay,
                                eps,
                                maximize,
                                grad_scale,
                                found_inf);
  }
}

} // namespace at::native