#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/operations/FusedOptimizerOps.h>

namespace at::native {
namespace mps {

static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kmaxThreadGroups = 32;
static constexpr int64_t kmaxTensors = 32;

struct MetadataArguments { // the size of this struct must be less than 4 kilobytes (the limit for data passed via setBytes)
  uint numels[kmaxTensors];
  uint threadgroup_to_tensor[kmaxThreadGroups];
  uint threadgroup_to_chunk[kmaxThreadGroups];
};

struct FusedAdamEncodingFunctor {
  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const double lr,
      const double beta1,
      const double beta2,
      const double weight_decay,
      const double eps,
      const bool maximize) const {
    float lr_lv = lr;
    float beta1_lv = beta1;
    float beta2_lv = beta2;
    float weight_decay_lv = weight_decay;
    float eps_lv = eps;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    mtl_setBytes(computeEncoder, lr_lv, 2);
    mtl_setBytes(computeEncoder, beta1_lv, 3);
    mtl_setBytes(computeEncoder, beta2_lv, 4);
    mtl_setBytes(computeEncoder, weight_decay_lv, 5);
    mtl_setBytes(computeEncoder, eps_lv, 6);
    mtl_setBytes(computeEncoder, maximize_lv, 7);
  }

  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const at::Tensor& lr,
      const double beta1,
      const double beta2,
      const double weight_decay,
      const double eps,
      const bool maximize) const {
    float beta1_lv = beta1;
    float beta2_lv = beta2;
    float weight_decay_lv = weight_decay;
    float eps_lv = eps;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    mtl_setBuffer(computeEncoder, lr, 2);
    mtl_setBytes(computeEncoder, beta1_lv, 3);
    mtl_setBytes(computeEncoder, beta2_lv, 4);
    mtl_setBytes(computeEncoder, weight_decay_lv, 5);
    mtl_setBytes(computeEncoder, eps_lv, 6);
    mtl_setBytes(computeEncoder, maximize_lv, 7);
  }
};
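// A minimal usage sketch (illustrative values, not taken from any call site): the dispatch
// loop below invokes the functor once per flushed dispatch, forwarding the hyperparameters
// it was given, e.g.
//
//   FusedAdamEncodingFunctor encode;
//   encode(computeEncoder, tensorArgumentBuffer, metadata_arguments,
//          /*lr=*/1e-3, /*beta1=*/0.9, /*beta2=*/0.999,
//          /*weight_decay=*/1e-2, /*eps=*/1e-8, /*maximize=*/false);
//
// The at::Tensor overload differs only in binding lr as a device buffer at index 2
// (mtl_setBuffer instead of mtl_setBytes), so a tensor-valued learning rate can be read by
// the kernel without first being copied back to the host.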
template <bool momentum>
struct FusedSgdEncodingFunctor {};

template <>
struct FusedSgdEncodingFunctor<true> {
  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const double weight_decay,
      const double momentum,
      const double lr,
      const double dampening,
      const bool nesterov,
      const bool maximize,
      const bool is_first_step) const {
    float weight_decay_lv = weight_decay;
    float momentum_lv = momentum;
    float lr_lv = lr;
    float dampening_lv = dampening;
    uint8_t nesterov_lv = nesterov;
    uint8_t maximize_lv = maximize;
    uint8_t is_first_step_lv = is_first_step;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBytes(computeEncoder, momentum_lv, 3);
    mtl_setBytes(computeEncoder, lr_lv, 4);
    mtl_setBytes(computeEncoder, dampening_lv, 5);
    mtl_setBytes(computeEncoder, nesterov_lv, 6);
    mtl_setBytes(computeEncoder, maximize_lv, 7);
    mtl_setBytes(computeEncoder, is_first_step_lv, 8);
  }

  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const double weight_decay,
      const double momentum,
      const at::Tensor& lr,
      const double dampening,
      const bool nesterov,
      const bool maximize,
      const bool is_first_step) const {
    float weight_decay_lv = weight_decay;
    float momentum_lv = momentum;
    float dampening_lv = dampening;
    uint8_t nesterov_lv = nesterov;
    uint8_t maximize_lv = maximize;
    uint8_t is_first_step_lv = is_first_step;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBytes(computeEncoder, momentum_lv, 3);
    mtl_setBuffer(computeEncoder, lr, 4);
    mtl_setBytes(computeEncoder, dampening_lv, 5);
    mtl_setBytes(computeEncoder, nesterov_lv, 6);
    mtl_setBytes(computeEncoder, maximize_lv, 7);
    mtl_setBytes(computeEncoder, is_first_step_lv, 8);
  }
};

template <>
struct FusedSgdEncodingFunctor<false> {
  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const double weight_decay,
      const double lr,
      const bool maximize) const {
    float weight_decay_lv = weight_decay;
    float lr_lv = lr;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBytes(computeEncoder, lr_lv, 3);
    mtl_setBytes(computeEncoder, maximize_lv, 4);
  }

  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
      id<MTLBuffer>& tensorArgumentBuffer,
      const MetadataArguments& metadata_arguments,
      const double weight_decay,
      const at::Tensor& lr,
      const bool maximize) const {
    float weight_decay_lv = weight_decay;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBuffer(computeEncoder, lr, 3);
    mtl_setBytes(computeEncoder, maximize_lv, 4);
  }
};
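// multi_tensor_apply_for_fused_optimizer packs up to kmaxTensors tensors per list and up to
// kmaxThreadGroups (tensor, chunk) pairs into one argument buffer, flushing a dispatch
// whenever either limit is reached. A sketch of how a fused-optimizer entry point might call
// it; the kernel name, depth, threadgroup size, and caller-side variables (params, grads,
// exp_avgs, exp_avg_sqs, state_steps, lr, ...) below are illustrative assumptions, not the
// values used by the actual call sites:
//
//   std::vector<std::vector<at::Tensor>> tensor_lists{
//       params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};
//   mps::FusedAdamEncodingFunctor encode;
//   mps::multi_tensor_apply_for_fused_optimizer</*depth=*/4, /*kThreadGroupSize=*/512>(
//       "fused_adam", tensor_lists, state_steps, encode,
//       lr, beta1, beta2, weight_decay, eps, maximize);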
template <int depth, uint32_t kThreadGroupSize, typename encoder_func_t, typename... ArgTypes>
static void multi_tensor_apply_for_fused_optimizer(
    const std::string& kernel_name,
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::TensorList state_steps,
    encoder_func_t encode,
    ArgTypes... args) {
  const auto num_tensors = tensor_lists[0].size();

  if (num_tensors == 0) {
    return;
  }

  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth");
  for (const auto& d : c10::irange(depth)) {
    TORCH_CHECK(
        tensor_lists[d][0].scalar_type() == at::ScalarType::Float ||
            tensor_lists[d][0].scalar_type() == at::ScalarType::Half,
        "Only float and half are supported");
  }

  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();

  // Uncomment this block to dump the command buffer's logs for debugging
  /*
  mpsStream->addCompletedHandler(^(id<MTLCommandBuffer> cb) {
    [cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) {
      NSLog(@"MPSStream: %@", log);
    }];
  });
  */

  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      auto [fusedOptimizerPSO, fusedOptimizerFunc] = getCPLState(kernel_name);

      // this function call is a no-op if the MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]});

      [computeEncoder setComputePipelineState:fusedOptimizerPSO];

      // BufferIndex 0 corresponds to the argument buffer index declared in the kernel function
      auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease];
      id<MTLBuffer> tensorArgumentBuffer =
          [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
      [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];

      int64_t tensor_loc = 0;
      int64_t threadgroup_loc = 0;
      MetadataArguments metadata_arguments;

      for (const auto tensor_index : c10::irange(num_tensors)) {
        // short-circuit to avoid adding empty tensors to the metadata
        if (tensor_lists[0][tensor_index].numel() == 0) {
          continue;
        }

        for (const auto& d : c10::irange(depth)) {
          mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc);
          [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
                                usage:MTLResourceUsageRead | MTLResourceUsageWrite];
        }
        if (state_steps.size() > 0) {
          mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc);
          [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
        }
        metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel();

        tensor_loc++;

        const auto numel = tensor_lists[0][tensor_index].numel();
        // ceil(numel / kChunkSize): the number of kChunkSize-element chunks this tensor needs
        const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
        TORCH_CHECK(chunks > -1);

        for (const auto& chunk : c10::irange(chunks)) {
          metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1;
          metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk;

          threadgroup_loc++;

          // All kmaxTensors slots are in use and this is the tensor's last chunk
          const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1;
          // Reached the maximum number of threadgroups per dispatch
          const auto blocks_full = threadgroup_loc == kmaxThreadGroups;
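          // Flush a dispatch as soon as either capacity is exhausted. The grid encoded below
          // has one threadgroup per (tensor, chunk) pair recorded so far; the kernel is
          // expected to read this mapping from MetadataArguments so that each threadgroup
          // updates one kChunkSize-element slice of its tensor.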
          if (tensor_full || blocks_full) {
            encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...);
            MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
            uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
            MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
            [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];

            // Reset for the next dispatch
            threadgroup_loc = 0;
            if (chunk == chunks - 1) {
              // last chunk of the current tensor: start the next dispatch from a clean slate
              tensor_loc = 0;
              tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
              [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
            } else {
              // the current tensor still has chunks left, so carry it over into the next
              // dispatch: re-bind its buffers at slot 0 and keep its numel in numels[0]
              metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];

              tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
              [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];

              for (const auto& d : c10::irange(depth)) {
                mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors);
                [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
                                      usage:MTLResourceUsageWrite | MTLResourceUsageRead];
              }
              if (state_steps.size() > 0) {
                mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);
                [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
              }
              tensor_loc = 1;
            }
          }
        }
      }

      // Flush whatever remains after the loop
      if (threadgroup_loc != 0) {
        encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...);
        MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
        uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
        MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
        [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
      }

      getMPSProfiler().endProfileKernel(fusedOptimizerPSO);
    }
  });
}

} // namespace mps
} // namespace at::native