1 #pragma once
2 #include <ATen/native/mps/OperationUtils.h>
3
4 namespace at::native {
5 namespace mps {
6
static const char* FUSED_ADAM_OPS = R"METAL(
#include <metal_stdlib>

// Per-dispatch limits: at most kmaxTensors tensors per argument buffer,
// each tensor split into chunk_size-element chunks, with one threadgroup
// assigned per chunk (at most kmaxThreadGroups per launch).
#define kmaxThreadGroups 32
#define kmaxTensors 32
#define chunk_size 65536

// Argument-buffer slot offsets used by the [[id(...)]] attributes below.
// Each tensor list occupies kmaxTensors consecutive slots. Two pairs of
// offsets coincide on purpose:
//   - kMomentumBufferListIdx == kExpAvgIdx: SgdMomentumArguments has no
//     exp_avg lists, so the momentum buffers reuse that slot range.
//   - kStateStepsIdx == kMaxExpAvgSqIdx: AdamArguments has no
//     max_exp_avg_sq list, so its state_steps reuse that slot range
//     (the amsgrad variant uses kStateStepsIdxForAmsgrad instead).
constexpr constant uint kParamIdx = 0;
constexpr constant uint kGradIdx = kParamIdx + kmaxTensors;
constexpr constant uint kExpAvgIdx = kGradIdx + kmaxTensors;
constexpr constant uint kMomentumBufferListIdx = kGradIdx + kmaxTensors;
constexpr constant uint kExpAvgSqIdx = kExpAvgIdx + kmaxTensors;
constexpr constant uint kMaxExpAvgSqIdx = kExpAvgSqIdx + kmaxTensors;
constexpr constant uint kStateStepsIdx = kExpAvgSqIdx + kmaxTensors;
constexpr constant uint kStateStepsIdxForAmsgrad = kMaxExpAvgSqIdx + kmaxTensors;
22
// Argument buffer for the plain (non-amsgrad) Adam kernels: parallel arrays
// of per-tensor device pointers, bound at the slot offsets declared above.
// T is the parameter/gradient dtype; state_steps_t is the step-count dtype.
template<typename T, typename state_steps_t>
struct AdamArguments {
  metal::array<device T *, kmaxTensors> params [[ id(kParamIdx) ]];
  metal::array<device T *, kmaxTensors> grads [[ id(kGradIdx) ]];
  metal::array<device T *, kmaxTensors> exp_avgs [[ id(kExpAvgIdx) ]];
  metal::array<device T *, kmaxTensors> exp_avg_sqs [[ id(kExpAvgSqIdx) ]];
  metal::array<device state_steps_t *, kmaxTensors> state_steps [[ id(kStateStepsIdx) ]];
};
31
// Argument buffer for the amsgrad Adam kernels. Identical to AdamArguments
// plus the max_exp_avg_sqs lists; state_steps therefore lives at the shifted
// kStateStepsIdxForAmsgrad offset.
template<typename T, typename state_steps_t>
struct AdamAmsgradArguments {
  metal::array<device T *, kmaxTensors> params [[ id(kParamIdx) ]];
  metal::array<device T *, kmaxTensors> grads [[ id(kGradIdx) ]];
  metal::array<device T *, kmaxTensors> exp_avgs [[ id(kExpAvgIdx) ]];
  metal::array<device T *, kmaxTensors> exp_avg_sqs [[ id(kExpAvgSqIdx) ]];
  metal::array<device T *, kmaxTensors> max_exp_avg_sqs [[ id(kMaxExpAvgSqIdx) ]];
  metal::array<device state_steps_t *, kmaxTensors> state_steps [[ id(kStateStepsIdxForAmsgrad) ]];
};
41
// Argument buffer for the momentum-free fused SGD kernel: just the
// per-tensor parameter and gradient pointer lists.
template<typename T>
struct SgdArguments {
  metal::array<device T *, kmaxTensors> params [[ id(kParamIdx) ]];
  metal::array<device T *, kmaxTensors> grads [[ id(kGradIdx) ]];
};
47
// Argument buffer for the momentum variant of fused SGD. The momentum
// buffers reuse the slot range of the (absent) exp_avg lists — see the
// kMomentumBufferListIdx comment above.
template<typename T>
struct SgdMomentumArguments {
  metal::array<device T *, kmaxTensors> params [[ id(kParamIdx) ]];
  metal::array<device T *, kmaxTensors> grads [[ id(kGradIdx) ]];
  metal::array<device T *, kmaxTensors> momentum_buffer_list [[ id(kMomentumBufferListIdx) ]];
};
54
// Host-filled dispatch metadata: the element count of every tensor, and for
// each threadgroup in the launch, which tensor slot and which chunk of that
// tensor it services.
struct MetadataArguments {
  uint32_t numels[kmaxTensors];                     // element count per tensor slot
  uint32_t threadgroup_to_tensor[kmaxThreadGroups]; // tgid -> tensor slot
  uint32_t threadgroup_to_chunk[kmaxThreadGroups];  // tgid -> chunk index within that tensor
};
60
// Selects how weight decay is applied: ORIGINAL folds it into the gradient
// (classic Adam, L2 penalty); ADAMW decays the parameter directly
// (decoupled weight decay).
enum ADAM_MODE : uint8_t {
  ORIGINAL = 0,
  ADAMW = 1
};
65
// Single-element Adam/AdamW update with amsgrad (keeps the running maximum
// of the second moment and uses it in the denominator).
//
// Writes param, exp_avg, exp_avg_sq and max_exp_avg_sq; grad and
// state_steps are read-only. Fix over the previous revision: the gradient
// adjustments (sign flip for `maximize`, L2 weight decay in ORIGINAL mode)
// are applied to a thread-local copy `grad_` instead of temporarily mutating
// device `grad` and restoring it at the end. The final memory state is
// identical (grad was restored to its original value anyway), but the
// redundant global-memory writes per element are gone, matching the
// local-copy style already used by sgd_math below.
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
inline void adam_math_amsgrad(
    device T & param,
    device T & grad,
    device T & exp_avg,
    device T & exp_avg_sq,
    device T & max_exp_avg_sq,
    device state_steps_t & state_steps,
    const float lr,
    const float beta1,
    const float beta2,
    const float weight_decay,
    const float eps,
    const uint8_t maximize
) {
  T grad_ = grad;
  if (maximize) {
    // Gradient ascent: flip the sign of the (local) gradient.
    grad_ = -grad_;
  }

  if (weight_decay != 0) {
    switch (adam_mode) {
      case ADAM_MODE::ORIGINAL:
        // Classic Adam: decay enters through the gradient (L2 penalty).
        grad_ += param * weight_decay;
        break;
      case ADAM_MODE::ADAMW:
        // AdamW: decoupled decay applied straight to the parameter.
        param -= lr * weight_decay * param;
        break;
    }
  }

  // First and second moment running averages.
  exp_avg = beta1 * exp_avg + (1 - beta1) * grad_;
  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad_ * grad_;

  // Bias corrections from the per-tensor step count.
  const float casted_state_steps = static_cast<float>(state_steps);
  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
  const T step_size = lr / bias_correction1;
  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);

  // amsgrad: the denominator uses the historical maximum of exp_avg_sq.
  max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq);

  const T denom = (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
  param -= step_size * exp_avg / denom;
}
112
// Single-element Adam/AdamW update (no amsgrad).
//
// Writes param, exp_avg and exp_avg_sq; grad and state_steps are read-only.
// Fix over the previous revision: the gradient adjustments (sign flip for
// `maximize`, L2 weight decay in ORIGINAL mode) happen on a thread-local
// copy `grad_` instead of temporarily mutating device `grad` and restoring
// it at the end. The final memory state is identical (grad was restored
// anyway), but the redundant global-memory writes per element are gone,
// matching the local-copy style already used by sgd_math below.
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
inline void adam_math(
    device T & param,
    device T & grad,
    device T & exp_avg,
    device T & exp_avg_sq,
    device state_steps_t & state_steps,
    const float lr,
    const float beta1,
    const float beta2,
    const float weight_decay,
    const float eps,
    const uint8_t maximize
) {
  T grad_ = grad;
  if (maximize) {
    // Gradient ascent: flip the sign of the (local) gradient.
    grad_ = -grad_;
  }

  if (weight_decay != 0) {
    switch (adam_mode) {
      case ADAM_MODE::ORIGINAL:
        // Classic Adam: decay enters through the gradient (L2 penalty).
        grad_ += param * weight_decay;
        break;
      case ADAM_MODE::ADAMW:
        // AdamW: decoupled decay applied straight to the parameter.
        param -= lr * weight_decay * param;
        break;
    }
  }

  // First and second moment running averages.
  exp_avg = beta1 * exp_avg + (1 - beta1) * grad_;
  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad_ * grad_;

  // Bias corrections from the per-tensor step count.
  const float casted_state_steps = static_cast<float>(state_steps);
  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
  const T step_size = lr / bias_correction1;
  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);

  const T denom = (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
  param -= step_size * exp_avg / denom;
}
156
// Fused amsgrad Adam kernel: each threadgroup updates one chunk_size-element
// chunk of one tensor, with its threads striding cooperatively through the
// chunk and applying adam_math_amsgrad element-wise.
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
kernel void fused_adam_amsgrad(
    device AdamAmsgradArguments<T, state_steps_t> & args [[buffer(0)]],
    constant MetadataArguments & metadata_args [[buffer(1)]],
    constant float & lr [[buffer(2)]],
    constant float & beta1 [[buffer(3)]],
    constant float & beta2 [[buffer(4)]],
    constant float & weight_decay [[buffer(5)]],
    constant float & eps [[buffer(6)]],
    constant uint8_t & maximize [[buffer(7)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]) {

  // Resolve which tensor slot and which chunk this threadgroup owns.
  const uint32_t tensor_idx = metadata_args.threadgroup_to_tensor[tgid];
  const uint32_t offset = metadata_args.threadgroup_to_chunk[tgid] * chunk_size;
  const uint32_t remaining = metadata_args.numels[tensor_idx] - offset;
  // The last chunk of a tensor may be shorter than chunk_size.
  const uint32_t bound = remaining < chunk_size ? remaining : chunk_size;

  device state_steps_t* step_count = args.state_steps[tensor_idx];
  device T* param = args.params[tensor_idx] + offset;
  device T* grad = args.grads[tensor_idx] + offset;
  device T* exp_avg = args.exp_avgs[tensor_idx] + offset;
  device T* exp_avg_sq = args.exp_avg_sqs[tensor_idx] + offset;
  device T* max_exp_avg_sq = args.max_exp_avg_sqs[tensor_idx] + offset;

  for (uint32_t i = tid; i < bound; i += tptg) {
    adam_math_amsgrad<T, state_steps_t, adam_mode>(
        param[i],
        grad[i],
        exp_avg[i],
        exp_avg_sq[i],
        max_exp_avg_sq[i],
        *step_count,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize);
  }
}
202
// Fused Adam kernel (no amsgrad): each threadgroup updates one
// chunk_size-element chunk of one tensor, its threads striding cooperatively
// through the chunk and applying adam_math element-wise.
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
kernel void fused_adam(
    device AdamArguments<T, state_steps_t> & args [[buffer(0)]],
    constant MetadataArguments & metadata_args [[buffer(1)]],
    constant float & lr [[buffer(2)]],
    constant float & beta1 [[buffer(3)]],
    constant float & beta2 [[buffer(4)]],
    constant float & weight_decay [[buffer(5)]],
    constant float & eps [[buffer(6)]],
    constant uint8_t & maximize [[buffer(7)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]) {

  // Resolve which tensor slot and which chunk this threadgroup owns.
  const uint32_t tensor_idx = metadata_args.threadgroup_to_tensor[tgid];
  const uint32_t offset = metadata_args.threadgroup_to_chunk[tgid] * chunk_size;
  const uint32_t remaining = metadata_args.numels[tensor_idx] - offset;
  // The last chunk of a tensor may be shorter than chunk_size.
  const uint32_t bound = remaining < chunk_size ? remaining : chunk_size;

  device state_steps_t* step_count = args.state_steps[tensor_idx];
  device T* param = args.params[tensor_idx] + offset;
  device T* grad = args.grads[tensor_idx] + offset;
  device T* exp_avg = args.exp_avgs[tensor_idx] + offset;
  device T* exp_avg_sq = args.exp_avg_sqs[tensor_idx] + offset;

  for (uint32_t i = tid; i < bound; i += tptg) {
    adam_math<T, state_steps_t, adam_mode>(
        param[i],
        grad[i],
        exp_avg[i],
        exp_avg_sq[i],
        *step_count,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize);
  }
}
246
// Explicitly instantiates one host-visible specialization of an Adam kernel.
// HOST_NAME, DTYPE and STATE_STEPS_DTYPE are stringized into the
// [[host_name]] the host code looks up (e.g. "fused_adamw_float_half").
// NOTE: comments cannot appear on the backslash-continued lines below — a
// line comment would swallow the continuation — hence this header comment.
#define REGISTER_FUSED_ADAM_OP(DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE, HOST_NAME, KERNEL_NAME, ARGUMENTS_STRUCT) \
template \
[[host_name(#HOST_NAME "_" #DTYPE "_" #STATE_STEPS_DTYPE)]] \
kernel void KERNEL_NAME<DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE>( \
    device ARGUMENTS_STRUCT<DTYPE, STATE_STEPS_DTYPE> & args [[buffer(0)]],\
    constant MetadataArguments & metadata_args [[buffer(1)]],\
    constant float & lr [[buffer(2)]],\
    constant float & beta1 [[buffer(3)]],\
    constant float & beta2 [[buffer(4)]],\
    constant float & weight_decay [[buffer(5)]],\
    constant float & eps [[buffer(6)]],\
    constant uint8_t & maximize [[buffer(7)]],\
    uint tid [[thread_position_in_threadgroup]],\
    uint tgid [[threadgroup_position_in_grid]],\
    uint tptg [[threads_per_threadgroup]])

// Adam / AdamW over all (param dtype, step-count dtype) combinations.
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
// amsgrad variants of the same grid of specializations.
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
279
// Single-element SGD-with-momentum update. param and momentum_buffer are
// written in place; grad is only read (adjustments happen on a local copy).
// NOTE: momentum_buffer is updated before it is consumed by the nesterov
// branch below, so the statement order is load-bearing.
template <typename T>
inline void sgd_momentum_math(
    device T & param,
    device T & grad,
    device T & momentum_buffer,
    const float weight_decay,
    const float momentum,
    const float lr,
    const float dampening,
    const uint8_t nesterov,
    const uint8_t maximize,
    const uint8_t is_first_step
) {
  auto grad_ = grad;
  if (maximize) {
    // Gradient ascent: flip the sign of the local gradient.
    grad_ *= -1.0;
  }
  if (weight_decay != 0) {
    // L2 penalty folded into the gradient.
    grad_ += weight_decay * param;
  }

  // First step seeds the buffer with the gradient; afterwards it is a
  // momentum-weighted average with `dampening` applied to the new gradient.
  momentum_buffer = is_first_step ? grad_ : (momentum * momentum_buffer + (1 - dampening) * grad_);
  if (nesterov) {
    // Nesterov: look ahead using the freshly updated buffer.
    grad_ += momentum * momentum_buffer;
  } else {
    grad_ = momentum_buffer;
  }

  param -= lr * grad_;
}
310
// Single-element plain SGD update: param -= lr * g, where g is the gradient
// optionally negated (maximize) and with an L2 weight-decay term folded in.
// Only param is written; the stored gradient is left untouched.
template <typename T>
inline void sgd_math(
    device T & param,
    device T & grad,
    const float weight_decay,
    const float lr,
    const uint8_t maximize
) {
  // Build the effective gradient in a thread-local value.
  T effective_grad = grad;
  if (maximize) {
    effective_grad = -effective_grad;
  }
  if (weight_decay != 0) {
    effective_grad += weight_decay * param;
  }
  param -= lr * effective_grad;
}
329
// Fused momentum-free SGD kernel: each threadgroup updates one
// chunk_size-element chunk of one tensor, its threads striding cooperatively
// through the chunk and applying sgd_math element-wise.
template <typename T>
kernel void fused_sgd(
    device SgdArguments<T> & args [[buffer(0)]],
    constant MetadataArguments & metadata_args [[buffer(1)]],
    constant float & weight_decay [[buffer(2)]],
    constant float & lr [[buffer(3)]],
    constant uint8_t & maximize [[buffer(4)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]) {

  // Resolve which tensor slot and which chunk this threadgroup owns.
  const uint32_t tensor_idx = metadata_args.threadgroup_to_tensor[tgid];
  const uint32_t offset = metadata_args.threadgroup_to_chunk[tgid] * chunk_size;
  const uint32_t remaining = metadata_args.numels[tensor_idx] - offset;
  // The last chunk of a tensor may be shorter than chunk_size.
  const uint32_t bound = remaining < chunk_size ? remaining : chunk_size;

  device T* param = args.params[tensor_idx] + offset;
  device T* grad = args.grads[tensor_idx] + offset;

  for (uint32_t i = tid; i < bound; i += tptg) {
    sgd_math<T>(param[i], grad[i], weight_decay, lr, maximize);
  }
}
360
// Fused SGD-with-momentum kernel (overload distinguished by the
// SgdMomentumArguments argument buffer): each threadgroup updates one
// chunk_size-element chunk of one tensor via sgd_momentum_math.
template <typename T>
kernel void fused_sgd(
    device SgdMomentumArguments<T> & args [[buffer(0)]],
    constant MetadataArguments & metadata_args [[buffer(1)]],
    constant float & weight_decay [[buffer(2)]],
    constant float & momentum [[buffer(3)]],
    constant float & lr [[buffer(4)]],
    constant float & dampening [[buffer(5)]],
    constant uint8_t & nesterov [[buffer(6)]],
    constant uint8_t & maximize [[buffer(7)]],
    constant uint8_t & is_first_step [[buffer(8)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]) {

  // Resolve which tensor slot and which chunk this threadgroup owns.
  const uint32_t tensor_idx = metadata_args.threadgroup_to_tensor[tgid];
  const uint32_t offset = metadata_args.threadgroup_to_chunk[tgid] * chunk_size;
  const uint32_t remaining = metadata_args.numels[tensor_idx] - offset;
  // The last chunk of a tensor may be shorter than chunk_size.
  const uint32_t bound = remaining < chunk_size ? remaining : chunk_size;

  device T* param = args.params[tensor_idx] + offset;
  device T* grad = args.grads[tensor_idx] + offset;
  device T* momentum_buf = args.momentum_buffer_list[tensor_idx] + offset;

  for (uint32_t i = tid; i < bound; i += tptg) {
    sgd_momentum_math<T>(
        param[i],
        grad[i],
        momentum_buf[i],
        weight_decay,
        momentum,
        lr,
        dampening,
        nesterov,
        maximize,
        is_first_step);
  }
}
401
// Explicitly instantiates the momentum-free fused_sgd specialization under
// the host-visible name "fused_sgd_<DTYPE>". Comments cannot appear on the
// backslash-continued lines below, hence this header comment.
#define REGISTER_FUSED_SGD_OP(DTYPE) \
template \
[[host_name("fused_sgd_" #DTYPE)]] \
kernel void fused_sgd<DTYPE>( \
    device SgdArguments<DTYPE> & args [[buffer(0)]], \
    constant MetadataArguments & metadata_args [[buffer(1)]], \
    constant float & weight_decay [[buffer(2)]], \
    constant float & lr [[buffer(3)]], \
    constant uint8_t & maximize [[buffer(4)]], \
    uint tid [[thread_position_in_threadgroup]], \
    uint tgid [[threadgroup_position_in_grid]], \
    uint tptg [[threads_per_threadgroup]])

// Explicitly instantiates the momentum variant of fused_sgd under the
// host-visible name "fused_sgd_momentum_<DTYPE>".
#define REGISTER_FUSED_SGD_MOMENTUM_OP(DTYPE) \
template \
[[host_name("fused_sgd_momentum_" #DTYPE)]] \
kernel void fused_sgd<DTYPE>( \
    device SgdMomentumArguments<DTYPE> & args [[buffer(0)]], \
    constant MetadataArguments & metadata_args [[buffer(1)]], \
    constant float & weight_decay [[buffer(2)]], \
    constant float & momentum [[buffer(3)]], \
    constant float & lr [[buffer(4)]], \
    constant float & dampening [[buffer(5)]], \
    constant uint8_t & nesterov [[buffer(6)]], \
    constant uint8_t & maximize [[buffer(7)]], \
    constant uint8_t & is_first_step [[buffer(8)]], \
    uint tid [[thread_position_in_threadgroup]], \
    uint tgid [[threadgroup_position_in_grid]], \
    uint tptg [[threads_per_threadgroup]])

// Both SGD variants for float and half parameters.
REGISTER_FUSED_SGD_OP(float);
REGISTER_FUSED_SGD_OP(half);
REGISTER_FUSED_SGD_MOMENTUM_OP(float);
REGISTER_FUSED_SGD_MOMENTUM_OP(half);
436
437 )METAL";
438
// Returns the compute pipeline state and MTLFunction for the kernel named
// `fname` (one of the [[host_name]] strings registered above). The
// MetalShaderLibrary is a function-local static, so the FUSED_ADAM_OPS
// source is compiled into a library only once per process.
// (Fix: the previous revision had a stray duplicated signature fragment
// fused onto this line, which did not compile.)
static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(const std::string& fname) {
  static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0);
  return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
}
443
444 } //namespace mps
445 } // namespace at::native