#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>

namespace at::native {
namespace mps {

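// Fused optimizers on MPS process their parameter lists in fixed-size pieces:
// each tensor is split into chunks of kChunkSize elements, one threadgroup
// handles one chunk, and a single dispatch covers at most kmaxTensors tensors
// and kmaxThreadGroups threadgroups.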
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kmaxThreadGroups = 32;
static constexpr int64_t kmaxTensors = 32;

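// Metadata for one dispatch, passed to the kernel via setBytes: the element
// count of every packed tensor plus, for each threadgroup, the tensor index
// and chunk index it should work on.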
struct MetadataArguments { // the size of this struct must be less than 4 kilobytes (the Metal setBytes limit)
  uint numels[kmaxTensors];
  uint threadgroup_to_tensor[kmaxThreadGroups];
  uint threadgroup_to_chunk[kmaxThreadGroups];
};

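// Encodes the Adam hyperparameters into the compute command encoder. The two
// overloads differ only in the learning rate: a host-side double is narrowed
// to float and passed with setBytes, while a tensor-valued lr is bound as a
// Metal buffer so the kernel reads the current value on device.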
struct FusedAdamEncodingFunctor {
  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const double lr,
      const double beta1,
      const double beta2,
      const double weight_decay,
      const double eps,
      const bool maximize
    ) const {

    float lr_lv = lr;
    float beta1_lv = beta1;
    float beta2_lv = beta2;
    float weight_decay_lv = weight_decay;
    float eps_lv = eps;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                                  offset:0
                                  atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                                  length:sizeof(MetadataArguments)
                                  atIndex:1];
    mtl_setBytes(computeEncoder, lr_lv, 2);
    mtl_setBytes(computeEncoder, beta1_lv, 3);
    mtl_setBytes(computeEncoder, beta2_lv, 4);
    mtl_setBytes(computeEncoder, weight_decay_lv, 5);
    mtl_setBytes(computeEncoder, eps_lv, 6);
    mtl_setBytes(computeEncoder, maximize_lv, 7);
  }

  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const at::Tensor& lr,
      const double beta1,
      const double beta2,
      const double weight_decay,
      const double eps,
      const bool maximize
    ) const {
    float beta1_lv = beta1;
    float beta2_lv = beta2;
    float weight_decay_lv = weight_decay;
    float eps_lv = eps;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                                  offset:0
                                  atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                                  length:sizeof(MetadataArguments)
                                  atIndex:1];
    mtl_setBuffer(computeEncoder, lr, 2);
    mtl_setBytes(computeEncoder, beta1_lv, 3);
    mtl_setBytes(computeEncoder, beta2_lv, 4);
    mtl_setBytes(computeEncoder, weight_decay_lv, 5);
    mtl_setBytes(computeEncoder, eps_lv, 6);
    mtl_setBytes(computeEncoder, maximize_lv, 7);
  }
};

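// Encodes the SGD hyperparameters. The boolean template parameter selects the
// variant that maintains a momentum buffer; each specialization again provides
// overloads for a scalar lr and a tensor-valued lr.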
template <bool momentum>
struct FusedSgdEncodingFunctor {};

template <>
struct FusedSgdEncodingFunctor<true> {
  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const double weight_decay,
      const double momentum,
      const double lr,
      const double dampening,
      const bool nesterov,
      const bool maximize,
      const bool is_first_step
    ) const {
      float weight_decay_lv = weight_decay;
      float momentum_lv = momentum;
      float lr_lv = lr;
      float dampening_lv = dampening;
      uint8_t nesterov_lv = nesterov;
      uint8_t maximize_lv = maximize;
      uint8_t is_first_step_lv = is_first_step;

      [computeEncoder setBuffer:tensorArgumentBuffer
                                  offset:0
                                  atIndex:0];
      [computeEncoder setBytes:&metadata_arguments
                                  length:sizeof(MetadataArguments)
                                  atIndex:1];
      mtl_setBytes(computeEncoder, weight_decay_lv, 2);
      mtl_setBytes(computeEncoder, momentum_lv, 3);
      mtl_setBytes(computeEncoder, lr_lv, 4);
      mtl_setBytes(computeEncoder, dampening_lv, 5);
      mtl_setBytes(computeEncoder, nesterov_lv, 6);
      mtl_setBytes(computeEncoder, maximize_lv, 7);
      mtl_setBytes(computeEncoder, is_first_step_lv, 8);
  }

  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const double weight_decay,
      const double momentum,
      const at::Tensor& lr,
      const double dampening,
      const bool nesterov,
      const bool maximize,
      const bool is_first_step
    ) const {
      float weight_decay_lv = weight_decay;
      float momentum_lv = momentum;
      float dampening_lv = dampening;
      uint8_t nesterov_lv = nesterov;
      uint8_t maximize_lv = maximize;
      uint8_t is_first_step_lv = is_first_step;

      [computeEncoder setBuffer:tensorArgumentBuffer
                                  offset:0
                                  atIndex:0];
      [computeEncoder setBytes:&metadata_arguments
                                  length:sizeof(MetadataArguments)
                                  atIndex:1];
      mtl_setBytes(computeEncoder, weight_decay_lv, 2);
      mtl_setBytes(computeEncoder, momentum_lv, 3);
      mtl_setBuffer(computeEncoder, lr, 4);
      mtl_setBytes(computeEncoder, dampening_lv, 5);
      mtl_setBytes(computeEncoder, nesterov_lv, 6);
      mtl_setBytes(computeEncoder, maximize_lv, 7);
      mtl_setBytes(computeEncoder, is_first_step_lv, 8);
  }
};

template <>
struct FusedSgdEncodingFunctor<false> {
  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const double weight_decay,
      const double lr,
      const bool maximize
    ) const {
      float weight_decay_lv = weight_decay;
      float lr_lv = lr;
      uint8_t maximize_lv = maximize;

      [computeEncoder setBuffer:tensorArgumentBuffer
                                  offset:0
                                  atIndex:0];
      [computeEncoder setBytes:&metadata_arguments
                                  length:sizeof(MetadataArguments)
                                  atIndex:1];
      mtl_setBytes(computeEncoder, weight_decay_lv, 2);
      mtl_setBytes(computeEncoder, lr_lv, 3);
      mtl_setBytes(computeEncoder, maximize_lv, 4);
  }

  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const double weight_decay,
      const at::Tensor& lr,
      const bool maximize
    ) const {
      float weight_decay_lv = weight_decay;
      uint8_t maximize_lv = maximize;

      [computeEncoder setBuffer:tensorArgumentBuffer
                                  offset:0
                                  atIndex:0];
      [computeEncoder setBytes:&metadata_arguments
                                  length:sizeof(MetadataArguments)
                                  atIndex:1];
      mtl_setBytes(computeEncoder, weight_decay_lv, 2);
      mtl_setBuffer(computeEncoder, lr, 3);
      mtl_setBytes(computeEncoder, maximize_lv, 4);
  }
};

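// Packs up to kmaxTensors tensors from each tensor list into a Metal argument
// buffer, splits every tensor into kChunkSize-element chunks, and launches the
// fused-optimizer kernel with one threadgroup per chunk. Whenever the argument
// buffer or the per-dispatch threadgroup table fills up, the accumulated work
// is dispatched and a fresh argument buffer is started. `encode` is one of the
// encoding functors above; it binds the optimizer hyperparameters (args...)
// for every dispatch.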
template <int depth, uint32_t kThreadGroupSize, typename encoder_func_t, typename... ArgTypes>
static void multi_tensor_apply_for_fused_optimizer(
    const std::string& kernel_name,
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::TensorList state_steps,
    encoder_func_t encode,
    ArgTypes... args
    ) {
  const auto num_tensors = tensor_lists[0].size();

  if (num_tensors == 0) {
    return;
  }

  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth");
  for (const auto& d : c10::irange(depth)) {
    TORCH_CHECK(
      tensor_lists[d][0].scalar_type() == at::ScalarType::Float ||
          tensor_lists[d][0].scalar_type() == at::ScalarType::Half,
      "Only float and half are supported");
  }

  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();

  // Uncomment to log command-buffer completion messages for debugging
  /*
  mpsStream->addCompletedHandler(^(id<MTLCommandBuffer> cb) {
    [cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) {
      NSLog(@"MPSStream: %@", log);
      }
    ];
  });
  */

  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      auto [fusedOptimizerPSO, fusedOptimizerFunc] = getCPLState(kernel_name);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]});

      [computeEncoder setComputePipelineState:fusedOptimizerPSO];

      // buffer index 0 here must match the argument buffer's index in the kernel function
      auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease];
      id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
      [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];

      int64_t tensor_loc = 0;
      int64_t threadgroup_loc = 0;
      MetadataArguments metadata_arguments;

      for (const auto tensor_index : c10::irange(num_tensors)) {
        // short-circuit to avoid adding empty tensors to the dispatch metadata
        if (tensor_lists[0][tensor_index].numel() == 0) {
          continue;
        }

        for (const auto& d : c10::irange(depth)) {
            mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc);
            [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageRead | MTLResourceUsageWrite];
        }
        if (state_steps.size() > 0) {
          mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc);
          [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
        }
        metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel();

        tensor_loc++;

        const auto numel = tensor_lists[0][tensor_index].numel();
        const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
        TORCH_CHECK(chunks > -1);

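        // Map one threadgroup to each chunk of this tensor, recording which
        // tensor and which chunk the threadgroup should process.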
        for (const auto& chunk : c10::irange(chunks)) {
            metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1;
            metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk;

            threadgroup_loc++;

            // The argument buffer is full: this is the last chunk of the kmaxTensors-th packed tensor
            const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1;
            // Reached the maximum threadgroups per dispatch
            const auto blocks_full = threadgroup_loc == kmaxThreadGroups;

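            // Flush: encode the hyperparameters plus the metadata gathered so
            // far, dispatch one threadgroup per recorded chunk, then start a
            // new argument buffer for the remaining work.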
            if (tensor_full || blocks_full) {
                encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...);
                MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
                uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
                MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
                [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];

                // Reset for the next dispatch
                threadgroup_loc = 0;
                if (chunk == chunks - 1) {
                  // last chunk of this tensor: the next dispatch starts from a clean slate
                  tensor_loc = 0;
                  tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
                  [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
                } else {
                  // the current tensor isn't finished: carry it over into slot 0 of the new argument buffer
                  metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];

                  tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
                  [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];

                  for (const auto& d : c10::irange(depth)) {
                      mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors);
                      [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageWrite | MTLResourceUsageRead];
                  }
                  if (state_steps.size() > 0) {
                    mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);
                    [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
                  }
                  tensor_loc = 1;
                }
            }
        }
      }

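      // Dispatch whatever partially filled batch is left after the loop.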
      if (threadgroup_loc != 0) {
        encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...);
        MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
        uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
        MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
        [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
      }

      getMPSProfiler().endProfileKernel(fusedOptimizerPSO);

    }
  });
}
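
// Illustrative usage sketch (not part of this header): a fused-Adam kernel
// launcher defined elsewhere would be expected to call the helper roughly as
// below. The kernel name, thread-group size, and exact list layout here are
// assumptions for illustration only.
//
//   std::vector<std::vector<at::Tensor>> tensor_lists{
//       params_vec, grads_vec, exp_avgs_vec, exp_avg_sqs_vec};
//   multi_tensor_apply_for_fused_optimizer</*depth=*/4, /*kThreadGroupSize=*/512>(
//       /*kernel_name=*/"fused_adam",   // hypothetical Metal kernel name
//       tensor_lists,
//       state_steps,
//       FusedAdamEncodingFunctor(),
//       lr, beta1, beta2, weight_decay, eps, maximize);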

} // namespace mps
} // namespace at::native