#pragma once
#include <ATen/native/mps/OperationUtils.h>

namespace at::native {
namespace mps {

static const char* FUSED_ADAM_OPS = R"METAL(
#include <metal_stdlib>

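// Multi-tensor-apply constants: every tensor in a dispatch is split into
// chunk_size-element chunks and each chunk is processed by one threadgroup;
// a single MetadataArguments instance describes at most kmaxTensors tensors
// and kmaxThreadGroups (tensor, chunk) pairs.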
#define kmaxThreadGroups 32
#define kmaxTensors 32
#define chunk_size 65536

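// Argument-buffer slot indices: each tensor list occupies kmaxTensors
// consecutive [[id(...)]] slots. kMomentumBufferListIdx aliases kExpAvgIdx
// (both equal kGradIdx + kmaxTensors); the SGD and Adam argument structs are
// distinct types, so the overlapping ranges never collide within one buffer.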
constexpr constant uint kParamIdx = 0;
constexpr constant uint kGradIdx = kParamIdx + kmaxTensors;
constexpr constant uint kExpAvgIdx = kGradIdx + kmaxTensors;
constexpr constant uint kMomentumBufferListIdx = kGradIdx + kmaxTensors;
constexpr constant uint kExpAvgSqIdx = kExpAvgIdx + kmaxTensors;
constexpr constant uint kMaxExpAvgSqIdx = kExpAvgSqIdx + kmaxTensors;
constexpr constant uint kStateStepsIdx = kExpAvgSqIdx + kmaxTensors;
constexpr constant uint kStateStepsIdxForAmsgrad = kMaxExpAvgSqIdx + kmaxTensors;

template<typename T, typename state_steps_t>
struct AdamArguments {
    metal::array<device T *,  kmaxTensors>   params        [[ id(kParamIdx) ]];
    metal::array<device T *,  kmaxTensors>   grads         [[ id(kGradIdx) ]];
    metal::array<device T *,  kmaxTensors>   exp_avgs      [[ id(kExpAvgIdx) ]];
    metal::array<device T *,  kmaxTensors>   exp_avg_sqs   [[ id(kExpAvgSqIdx) ]];
    metal::array<device state_steps_t *,  kmaxTensors>   state_steps   [[ id(kStateStepsIdx) ]];
};

template<typename T, typename state_steps_t>
struct AdamAmsgradArguments {
    metal::array<device T *,  kmaxTensors>   params        [[ id(kParamIdx) ]];
    metal::array<device T *,  kmaxTensors>   grads         [[ id(kGradIdx) ]];
    metal::array<device T *,  kmaxTensors>   exp_avgs      [[ id(kExpAvgIdx) ]];
    metal::array<device T *,  kmaxTensors>   exp_avg_sqs   [[ id(kExpAvgSqIdx) ]];
    metal::array<device T *,  kmaxTensors>   max_exp_avg_sqs   [[ id(kMaxExpAvgSqIdx) ]];
    metal::array<device state_steps_t *,  kmaxTensors>   state_steps   [[ id(kStateStepsIdxForAmsgrad) ]];
};

template<typename T>
struct SgdArguments {
    metal::array<device T *,  kmaxTensors>   params        [[ id(kParamIdx) ]];
    metal::array<device T *,  kmaxTensors>   grads         [[ id(kGradIdx) ]];
};

template<typename T>
struct SgdMomentumArguments {
    metal::array<device T *,  kmaxTensors>   params        [[ id(kParamIdx) ]];
    metal::array<device T *,  kmaxTensors>   grads         [[ id(kGradIdx) ]];
    metal::array<device T *,  kmaxTensors>   momentum_buffer_list       [[ id(kMomentumBufferListIdx) ]];
};

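// Per-dispatch bookkeeping filled in by the host: numels holds each tensor's
// element count, while threadgroup_to_tensor / threadgroup_to_chunk map every
// threadgroup to the tensor and the chunk within that tensor it should update.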
struct MetadataArguments {
    uint32_t numels[kmaxTensors];
    uint32_t threadgroup_to_tensor[kmaxThreadGroups];
    uint32_t threadgroup_to_chunk[kmaxThreadGroups];
};

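// ORIGINAL folds weight decay into the gradient (classic L2 regularization);
// ADAMW applies decoupled weight decay directly to the parameter.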
enum ADAM_MODE : uint8_t {
  ORIGINAL = 0,
  ADAMW = 1
};

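// Single-element AMSGrad update with bias correction (t = state_steps):
//   exp_avg        = beta1 * exp_avg    + (1 - beta1) * grad
//   exp_avg_sq     = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
//   max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq)
//   param         -= lr / (1 - beta1^t) * exp_avg /
//                    (sqrt(max_exp_avg_sq) / sqrt(1 - beta2^t) + eps)
// When maximize is set, grad is negated in place and restored on exit.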
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
inline void adam_math_amsgrad(
    device T & param,
    device T & grad,
    device T & exp_avg,
    device T & exp_avg_sq,
    device T & max_exp_avg_sq,
    device state_steps_t & state_steps,
    const float lr,
    const float beta1,
    const float beta2,
    const float weight_decay,
    const float eps,
    const uint8_t maximize
) {
  T grad_ = grad;

  if (maximize) {
    grad = -grad;
  }

  // Update param, grad, 1st and 2nd order momentum.
  if (weight_decay != 0) {
    switch (adam_mode) {
      case ADAM_MODE::ORIGINAL:
        grad += param * weight_decay;
        break;
      case ADAM_MODE::ADAMW:
        param -= lr * weight_decay * param;
        break;
    }
  }

  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
  const float casted_state_steps = static_cast<float>(state_steps);
  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
  const T step_size = lr / bias_correction1;
  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
  max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq);

  const T denom = (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
  param -= step_size * exp_avg / denom;
  grad = grad_;
}

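// Same update as adam_math_amsgrad, but the denominator uses exp_avg_sq
// directly instead of the running maximum max_exp_avg_sq.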
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
inline void adam_math(
    device T & param,
    device T & grad,
    device T & exp_avg,
    device T & exp_avg_sq,
    device state_steps_t & state_steps,
    const float lr,
    const float beta1,
    const float beta2,
    const float weight_decay,
    const float eps,
    const uint8_t maximize
) {
  T grad_ = grad;

  if (maximize) {
    grad = -grad;
  }

  // Update param, grad, 1st and 2nd order momentum.
  if (weight_decay != 0) {
    switch (adam_mode) {
      case ADAM_MODE::ORIGINAL:
        grad += param * weight_decay;
        break;
      case ADAM_MODE::ADAMW:
        param -= lr * weight_decay * param;
        break;
    }
  }

  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
  const float casted_state_steps = static_cast<float>(state_steps);
  const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
  const T step_size = lr / bias_correction1;
  const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
  const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
  const T denom = (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
  param -= step_size * exp_avg / denom;
  grad = grad_;
}

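// Each threadgroup looks up its (tensor, chunk) assignment in MetadataArguments
// and strides over that chunk with threads_per_threadgroup threads. fused_adam
// below is identical except that it has no max_exp_avg_sqs list and calls
// adam_math instead of adam_math_amsgrad.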
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
kernel void fused_adam_amsgrad(
    device   AdamAmsgradArguments<T, state_steps_t> & args    [[buffer(0)]],
    constant MetadataArguments & metadata_args [[buffer(1)]],
    constant float & lr             [[buffer(2)]],
    constant float & beta1          [[buffer(3)]],
    constant float & beta2          [[buffer(4)]],
    constant float & weight_decay   [[buffer(5)]],
    constant float & eps            [[buffer(6)]],
    constant uint8_t   & maximize       [[buffer(7)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]) {

    const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid];
    const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid];
    const uint32_t chunk_offset = chunk_idx * chunk_size;
    const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset;

    const auto step_count = args.state_steps[tensor_loc];

    // each chunk is a threadgroup
    auto param = args.params[tensor_loc] + chunk_offset;
    auto grad = args.grads[tensor_loc] + chunk_offset;
    auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset;
    auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset;
    auto max_exp_avg_sq = args.max_exp_avg_sqs[tensor_loc] + chunk_offset;

    for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) {
      adam_math_amsgrad<T, state_steps_t, adam_mode>(
        *(param + i_start),
        *(grad + i_start),
        *(exp_avg + i_start),
        *(exp_avg_sq + i_start),
        *(max_exp_avg_sq + i_start),
        *step_count,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize
      );
    }
}

template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
kernel void fused_adam(
    device   AdamArguments<T, state_steps_t> & args    [[buffer(0)]],
    constant MetadataArguments & metadata_args [[buffer(1)]],
    constant float & lr             [[buffer(2)]],
    constant float & beta1          [[buffer(3)]],
    constant float & beta2          [[buffer(4)]],
    constant float & weight_decay   [[buffer(5)]],
    constant float & eps            [[buffer(6)]],
    constant uint8_t   & maximize       [[buffer(7)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]) {

    const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid];
    const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid];
    const uint32_t chunk_offset = chunk_idx * chunk_size;
    const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset;

    const auto step_count = args.state_steps[tensor_loc];

    // each chunk is a threadgroup
    auto param = args.params[tensor_loc] + chunk_offset;
    auto grad = args.grads[tensor_loc] + chunk_offset;
    auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset;
    auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset;

    for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) {
      adam_math<T, state_steps_t, adam_mode>(
        *(param + i_start),
        *(grad + i_start),
        *(exp_avg + i_start),
        *(exp_avg_sq + i_start),
        *step_count,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize
      );
    }
}

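// Instantiates KERNEL_NAME for a given parameter/state_steps dtype pair and
// exports it under the Metal function name
// "<HOST_NAME>_<DTYPE>_<STATE_STEPS_DTYPE>", e.g. "fused_adam_float_float"
// or "fused_adamw_amsgrad_half_half".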
#define REGISTER_FUSED_ADAM_OP(DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE, HOST_NAME, KERNEL_NAME, ARGUMENTS_STRUCT)       \
template                                    \
[[host_name(#HOST_NAME "_" #DTYPE "_" #STATE_STEPS_DTYPE)]]        \
kernel void KERNEL_NAME<DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE>(             \
    device   ARGUMENTS_STRUCT<DTYPE, STATE_STEPS_DTYPE> & args    [[buffer(0)]],\
    constant MetadataArguments & metadata_args [[buffer(1)]],\
    constant float & lr             [[buffer(2)]],\
    constant float & beta1          [[buffer(3)]],\
    constant float & beta2          [[buffer(4)]],\
    constant float & weight_decay   [[buffer(5)]],\
    constant float & eps            [[buffer(6)]],\
    constant uint8_t   & maximize       [[buffer(7)]],\
    uint tid [[thread_position_in_threadgroup]],\
    uint tgid [[threadgroup_position_in_grid]],\
    uint tptg [[threads_per_threadgroup]])

REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);

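// Single-element SGD step with momentum:
//   buf   = is_first_step ? grad : momentum * buf + (1 - dampening) * grad
//   step  = nesterov ? grad + momentum * buf : buf
//   param -= lr * step
// Weight decay is folded into grad first; maximize flips the gradient sign.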
template <typename T>
inline void sgd_momentum_math(
    device T & param,
    device T & grad,
    device T & momentum_buffer,
    const float weight_decay,
    const float momentum,
    const float lr,
    const float dampening,
    const uint8_t nesterov,
    const uint8_t maximize,
    const uint8_t is_first_step
) {
  auto grad_ = grad;
  if (maximize) {
      grad_ *= -1.0;
  }
  if (weight_decay != 0) {
      grad_ += weight_decay * param;
  }

  momentum_buffer = is_first_step ? grad_ : (momentum * momentum_buffer + (1 - dampening) * grad_);
  if (nesterov) {
      grad_ += momentum * momentum_buffer;
  } else {
      grad_ = momentum_buffer;
  }

  param -= lr * grad_;
}

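// Plain SGD step: param -= lr * (grad + weight_decay * param), with the
// gradient sign flipped when maximize is set.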
template <typename T>
inline void sgd_math(
    device T & param,
    device T & grad,
    const float weight_decay,
    const float lr,
    const uint8_t maximize
) {
  auto grad_ = grad;
  if (maximize) {
      grad_ *= -1.0;
  }
  if (weight_decay != 0) {
      grad_ += weight_decay * param;
  }

  param -= lr * grad_;
}

template <typename T>
kernel void fused_sgd(
    device   SgdArguments<T> & args    [[buffer(0)]],
    constant MetadataArguments & metadata_args [[buffer(1)]],
    constant float & weight_decay   [[buffer(2)]],
    constant float & lr             [[buffer(3)]],
    constant uint8_t & maximize       [[buffer(4)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]) {

    const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid];
    const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid];
    const uint32_t chunk_offset = chunk_idx * chunk_size;
    const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset;

    // each chunk is a threadgroup
    auto param = args.params[tensor_loc] + chunk_offset;
    auto grad = args.grads[tensor_loc] + chunk_offset;

    for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) {
      sgd_math<T>(
        *(param + i_start),
        *(grad + i_start),
        weight_decay,
        lr,
        maximize
      );
    }
}

template <typename T>
kernel void fused_sgd(
    device   SgdMomentumArguments<T> & args    [[buffer(0)]],
    constant MetadataArguments & metadata_args [[buffer(1)]],
    constant float & weight_decay   [[buffer(2)]],
    constant float & momentum       [[buffer(3)]],
    constant float & lr             [[buffer(4)]],
    constant float & dampening          [[buffer(5)]],
    constant uint8_t & nesterov          [[buffer(6)]],
    constant uint8_t   & maximize       [[buffer(7)]],
    constant uint8_t   & is_first_step       [[buffer(8)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]) {

    const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid];
    const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid];
    const uint32_t chunk_offset = chunk_idx * chunk_size;
    const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset;

    // each chunk is a threadgroup
    auto param = args.params[tensor_loc] + chunk_offset;
    auto grad = args.grads[tensor_loc] + chunk_offset;
    auto momentum_buffer_list = args.momentum_buffer_list[tensor_loc] + chunk_offset;

    for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) {
      sgd_momentum_math<T>(
        *(param + i_start),
        *(grad + i_start),
        *(momentum_buffer_list + i_start),
        weight_decay,
        momentum,
        lr,
        dampening,
        nesterov,
        maximize,
        is_first_step
      );
    }
}

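// The SGD kernels are exported as "fused_sgd_<DTYPE>" (no momentum) and
// "fused_sgd_momentum_<DTYPE>"; the explicit instantiation's parameter list
// selects the matching fused_sgd overload on the argument struct type.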
#define REGISTER_FUSED_SGD_OP(DTYPE) \
template                                                            \
[[host_name("fused_sgd_" #DTYPE)]]                                  \
kernel void fused_sgd<DTYPE>(                                       \
    device   SgdArguments<DTYPE>     & args          [[buffer(0)]], \
    constant MetadataArguments       & metadata_args [[buffer(1)]], \
    constant float                   & weight_decay  [[buffer(2)]], \
    constant float                   & lr            [[buffer(3)]], \
    constant uint8_t                 & maximize      [[buffer(4)]], \
    uint tid  [[thread_position_in_threadgroup]], \
    uint tgid [[threadgroup_position_in_grid]],   \
    uint tptg [[threads_per_threadgroup]])

#define REGISTER_FUSED_SGD_MOMENTUM_OP(DTYPE) \
template                                                            \
[[host_name("fused_sgd_momentum_" #DTYPE)]]                         \
kernel void fused_sgd<DTYPE>(                                       \
    device   SgdMomentumArguments<DTYPE> & args      [[buffer(0)]], \
    constant MetadataArguments       & metadata_args [[buffer(1)]], \
    constant float                   & weight_decay  [[buffer(2)]], \
    constant float                   & momentum      [[buffer(3)]], \
    constant float                   & lr            [[buffer(4)]], \
    constant float                   & dampening     [[buffer(5)]], \
    constant uint8_t                 & nesterov      [[buffer(6)]], \
    constant uint8_t                 & maximize      [[buffer(7)]], \
    constant uint8_t                 & is_first_step [[buffer(8)]], \
    uint tid  [[thread_position_in_threadgroup]], \
    uint tgid [[threadgroup_position_in_grid]],   \
    uint tptg [[threads_per_threadgroup]])

REGISTER_FUSED_SGD_OP(float);
REGISTER_FUSED_SGD_OP(half);
REGISTER_FUSED_SGD_MOMENTUM_OP(float);
REGISTER_FUSED_SGD_MOMENTUM_OP(half);

)METAL";

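// Builds the Metal library from FUSED_ADAM_OPS on first use and returns the
// compute pipeline state plus MTLFunction for the requested kernel, looked up
// by its registered host name (e.g. "fused_adam_float_float" or
// "fused_sgd_momentum_half"). A minimal sketch of how a caller might use it
// (illustrative only; the actual encoding and dispatch are done by the
// host-side optimizer implementation):
//   auto [cpl, func] = getCPLState("fused_adam_float_float");
//   // bind the argument buffer at buffer(0), MetadataArguments at buffer(1),
//   // the scalar hyperparameters at buffers 2-7, then dispatch one
//   // threadgroup per (tensor, chunk) pair recorded in the metadata.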
static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(const std::string& fname) {
  static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0);
  return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
}

} // namespace mps
} // namespace at::native