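// Fused Adam optimizer step for the MPS backend: scalar-lr and tensor-lr
// entry points that forward to the kernel implementations in
// FusedAdamKernelImpl.h and FusedAdamAmsgradKernelImpl.h.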
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h>
#include <ATen/native/mps/operations/FusedAdamKernelImpl.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_adam_native.h>
#endif

namespace at::native {
using namespace mps;

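// Scalar-lr variant: validate that the tensor lists satisfy the fast-path
// restrictions (same dtype, device, and layout) and dispatch to the AMSGrad
// or plain fused Adam kernel.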
void _fused_adam_kernel_mps_(TensorList params,
                             TensorList grads,
                             TensorList exp_avgs,
                             TensorList exp_avg_sqs,
                             TensorList max_exp_avg_sqs,
                             TensorList state_steps,
                             const double lr,
                             const double beta1,
                             const double beta2,
                             const double weight_decay,
                             const double eps,
                             const bool amsgrad,
                             const bool maximize,
                             const std::optional<Tensor>& grad_scale,
                             const std::optional<Tensor>& found_inf) {
  if (amsgrad) {
    TORCH_CHECK(native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
                "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_amsgrad_mps_impl_(params,
                                  grads,
                                  exp_avgs,
                                  exp_avg_sqs,
                                  max_exp_avg_sqs,
                                  state_steps,
                                  lr,
                                  beta1,
                                  beta2,
                                  weight_decay,
                                  eps,
                                  maximize,
                                  grad_scale,
                                  found_inf);
  } else {
    TORCH_CHECK(native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
                "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_mps_impl_(params,
                          grads,
                          exp_avgs,
                          exp_avg_sqs,
                          state_steps,
                          lr,
                          beta1,
                          beta2,
                          weight_decay,
                          eps,
                          maximize,
                          grad_scale,
                          found_inf);
  }
}

// The following overload simply has a Tensor lr.
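// A CPU lr tensor is unwrapped with item<double>() and forwarded to the
// scalar-lr overload above; otherwise lr must live on the same device as the
// params and is passed through to the kernels as a tensor.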
void _fused_adam_kernel_mps_(TensorList params,
                             TensorList grads,
                             TensorList exp_avgs,
                             TensorList exp_avg_sqs,
                             TensorList max_exp_avg_sqs,
                             TensorList state_steps,
                             const Tensor& lr,
                             const double beta1,
                             const double beta2,
                             const double weight_decay,
                             const double eps,
                             const bool amsgrad,
                             const bool maximize,
                             const std::optional<Tensor>& grad_scale,
                             const std::optional<Tensor>& found_inf) {
  if (lr.is_cpu()) {
    return _fused_adam_kernel_mps_(params,
                                   grads,
                                   exp_avgs,
                                   exp_avg_sqs,
                                   max_exp_avg_sqs,
                                   state_steps,
                                   lr.item<double>(),
                                   beta1,
                                   beta2,
                                   weight_decay,
                                   eps,
                                   amsgrad,
                                   maximize,
                                   grad_scale,
                                   found_inf);
  }

  // Manually check devices since we specify no device check in
  // native_functions.yaml
  Device param_device = params[0].device();
  TORCH_CHECK(lr.device() == param_device, "lr must be on the same GPU device as the params");

  if (amsgrad) {
    TORCH_CHECK(native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
                "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_amsgrad_mps_impl_(params,
                                  grads,
                                  exp_avgs,
                                  exp_avg_sqs,
                                  max_exp_avg_sqs,
                                  state_steps,
                                  lr,
                                  beta1,
                                  beta2,
                                  weight_decay,
                                  eps,
                                  maximize,
                                  grad_scale,
                                  found_inf);
  } else {
    TORCH_CHECK(native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
                "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_mps_impl_(params,
                          grads,
                          exp_avgs,
                          exp_avg_sqs,
                          state_steps,
                          lr,
                          beta1,
                          beta2,
                          weight_decay,
                          eps,
                          maximize,
                          grad_scale,
                          found_inf);
  }
}
} // namespace at::native