• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2#include <ATen/native/ForeachUtils.h>
3#include <ATen/native/mps/operations/MultiTensorApply.h>
4
5#ifndef AT_PER_OPERATOR_HEADERS
6#include <ATen/Functions.h>
7#include <ATen/NativeFunctions.h>
8#else
9#include <ATen/ops/_fused_sgd.h>
10#include <ATen/ops/_fused_sgd_native.h>
11#endif
12
13namespace at::native {
14
15namespace mps {
16
17static void _fused_sgd_with_momentum_kernel_mps_(TensorList params,
18                                                 TensorList grads,
19                                                 TensorList momentum_buffer_list,
20                                                 const double weight_decay,
21                                                 const double momentum,
22                                                 const double lr,
23                                                 const double dampening,
24                                                 const bool nesterov,
25                                                 const bool maximize,
26                                                 const bool is_first_step,
27                                                 const std::optional<Tensor>& grad_scale,
28                                                 const std::optional<Tensor>& found_inf) {
29  TORCH_CHECK_GT(momentum, 0);
30  TORCH_CHECK(native::check_fast_path_restrictions({params, grads, momentum_buffer_list}));
31
32  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec(), momentum_buffer_list.vec()};
33
34  const std::string kernel_name = "fused_sgd_momentum_" + scalarToMetalTypeString(params[0].scalar_type());
35
36  TensorList state_steps;
37
38  multi_tensor_apply_for_fused_optimizer<3, 512>(kernel_name,
39                                                 tensor_lists,
40                                                 state_steps,
41                                                 FusedSgdEncodingFunctor<true /*momentum*/>(),
42                                                 weight_decay,
43                                                 momentum,
44                                                 lr,
45                                                 dampening,
46                                                 nesterov,
47                                                 maximize,
48                                                 is_first_step);
49}
50
51static void _fused_sgd_with_momentum_kernel_mps_(TensorList params,
52                                                 TensorList grads,
53                                                 TensorList momentum_buffer_list,
54                                                 const double weight_decay,
55                                                 const double momentum,
56                                                 const Tensor& lr_tensor,
57                                                 const double dampening,
58                                                 const bool nesterov,
59                                                 const bool maximize,
60                                                 const bool is_first_step,
61                                                 const std::optional<Tensor>& grad_scale,
62                                                 const std::optional<Tensor>& found_inf) {
63  TORCH_CHECK_GT(momentum, 0);
64  TORCH_CHECK(native::check_fast_path_restrictions({params, grads, momentum_buffer_list}));
65
66  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec(), momentum_buffer_list.vec()};
67
68  const auto kernel_name = "fused_sgd_momentum_" + scalarToMetalTypeString(params[0].scalar_type());
69
70  TensorList state_steps;
71
72  multi_tensor_apply_for_fused_optimizer<3, 512>(kernel_name,
73                                                 tensor_lists,
74                                                 state_steps,
75                                                 FusedSgdEncodingFunctor<true /*momentum*/>(),
76                                                 weight_decay,
77                                                 momentum,
78                                                 lr_tensor,
79                                                 dampening,
80                                                 nesterov,
81                                                 maximize,
82                                                 is_first_step);
83}
84
85} // namespace mps
86
87using namespace mps;
88
89void _fused_sgd_kernel_mps_(TensorList params,
90                            TensorList grads,
91                            TensorList momentum_buffer_list,
92                            const double weight_decay,
93                            const double momentum,
94                            const double lr,
95                            const double dampening,
96                            const bool nesterov,
97                            const bool maximize,
98                            const bool is_first_step,
99                            const std::optional<Tensor>& grad_scale,
100                            const std::optional<Tensor>& found_inf) {
101  TORCH_CHECK(!grad_scale.has_value() && !found_inf.has_value(), "grad_scale and found_inf are not supported on MPS");
102
103  if (!momentum_buffer_list.empty()) {
104    return _fused_sgd_with_momentum_kernel_mps_(params,
105                                                grads,
106                                                momentum_buffer_list,
107                                                weight_decay,
108                                                momentum,
109                                                lr,
110                                                dampening,
111                                                nesterov,
112                                                maximize,
113                                                is_first_step,
114                                                grad_scale,
115                                                found_inf);
116  }
117  TORCH_CHECK_EQ(momentum, 0);
118  TORCH_CHECK(native::check_fast_path_restrictions({params, grads}));
119  if (is_first_step) {
120    TORCH_WARN_ONCE("`is_first_step` argument has no effect when `momentum_buffer_list` is empty");
121  }
122
123  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec()};
124
125  const auto kernel_name = "fused_sgd_" + scalarToMetalTypeString(params[0].scalar_type());
126
127  TensorList state_steps;
128
129  multi_tensor_apply_for_fused_optimizer<2, 512>(kernel_name,
130                                                 tensor_lists,
131                                                 state_steps,
132                                                 FusedSgdEncodingFunctor<false /*momentum*/>(),
133                                                 weight_decay,
134                                                 lr,
135                                                 maximize);
136}
137
138void _fused_sgd_kernel_mps_(TensorList params,
139                            TensorList grads,
140                            TensorList momentum_buffer_list,
141                            const double weight_decay,
142                            const double momentum,
143                            const Tensor& lr_tensor,
144                            const double dampening,
145                            const bool nesterov,
146                            const bool maximize,
147                            const bool is_first_step,
148                            const std::optional<Tensor>& grad_scale,
149                            const std::optional<Tensor>& found_inf) {
150  TORCH_CHECK(!grad_scale.has_value() && !found_inf.has_value(), "grad_scale and found_inf are not supported on MPS");
151
152  if (!momentum_buffer_list.empty()) {
153    return _fused_sgd_with_momentum_kernel_mps_(params,
154                                                grads,
155                                                momentum_buffer_list,
156                                                weight_decay,
157                                                momentum,
158                                                lr_tensor,
159                                                dampening,
160                                                nesterov,
161                                                maximize,
162                                                is_first_step,
163                                                grad_scale,
164                                                found_inf);
165  }
166  if (lr_tensor.is_cpu()) {
167    return _fused_sgd_kernel_mps_(params,
168                                  grads,
169                                  momentum_buffer_list,
170                                  weight_decay,
171                                  momentum,
172                                  lr_tensor.item<double>(),
173                                  dampening,
174                                  nesterov,
175                                  maximize,
176                                  is_first_step,
177                                  grad_scale,
178                                  found_inf);
179  }
180  TORCH_CHECK_EQ(momentum, 0);
181  TORCH_CHECK(native::check_fast_path_restrictions({params, grads}));
182  if (is_first_step) {
183    TORCH_WARN_ONCE("`is_first_step` argument has no effect when `momentum_buffer_list` is empty");
184  }
185
186  TORCH_CHECK(lr_tensor.device() == params[0].device(), "lr must be on the same GPU device as the params");
187
188  std::vector<std::vector<Tensor>> tensor_lists{params.vec(), grads.vec()};
189
190  const std::string kernel_name = "fused_sgd_" + mps::scalarToMetalTypeString(params[0].scalar_type());
191
192  TensorList state_steps;
193
194  multi_tensor_apply_for_fused_optimizer<2, 512>(kernel_name,
195                                                 tensor_lists,
196                                                 state_steps,
197                                                 FusedSgdEncodingFunctor<false /*momentum*/>(),
198                                                 weight_decay,
199                                                 lr_tensor,
200                                                 maximize);
201}
202
203} // namespace at::native
204