/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include <cutlass/aligned_buffer.h>
#include <cutlass/arch/cache_operation.h>
#include <cutlass/arch/memory.h>
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>

#include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
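///
/// Relative to a standard multistage mainloop, this variant additionally
/// supports an externally issued prologue (see set_prologue_done()), optional
/// zero-fill of out-of-bounds cp.async transfers (see
/// set_zero_outside_bounds()), and a compile-time bound on the K extent
/// (kMaxK). When kMaxK fits within Shape::kK * Stages, the whole operand is
/// resident in shared memory and the circular-buffer bookkeeping is skipped.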
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    /// Upper bound on the K dimension
    int kMaxK = cutlass::platform::numeric_limits<int>::max(),
    /// Used for partial specialization
    typename Enable = bool>
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
 public:
  ///< Base class
  using Base = CustomMmaBase<Shape_, Policy_, Stages>;
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA = IteratorA_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB = IteratorB_;
  ///< Data type of accumulator matrix
  using ElementC = ElementC_;
  ///< Layout of accumulator matrix
  using LayoutC = LayoutC_;
  ///< Policy describing tuning details
  using Policy = Policy_;

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

  //
  // Dependent types
  //

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Minimum architecture is Sm80 to support cp.async
  using ArchTag = arch::Sm80;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  /// Internal structure exposed for introspection.
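  ///
  /// The cp.async traffic for one stage is split into Base::kWarpGemmIterations
  /// groups (kAccessesPerGroupA/B below) so that global->shared copies can be
  /// interleaved with the warp-level MMAs of the mainloop. For example, with
  /// 8 copy iterations per stage and 4 warp GEMM iterations per tile, each
  /// group issues 2 cp.async instructions.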
  struct Detail {
    static_assert(
        Base::kWarpGemmIterations > 1,
        "The pipelined structure requires at least two warp-level "
        "GEMM operations.");

    /// Number of cp.async instructions to load one stage of operand A
    static int const AsyncCopyIterationsPerStageA =
        IteratorA::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one stage of operand B
    static int const AsyncCopyIterationsPerStageB =
        IteratorB::ThreadMap::Iterations::kCount;

    /// Number of stages
    static int const kStages = Stages;

    /// Number of cp.async instructions to load one group of operand A
    static int const kAccessesPerGroupA =
        (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
        Base::kWarpGemmIterations;

    /// Number of cp.async instructions to load one group of operand B
    static int const kAccessesPerGroupB =
        (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
        Base::kWarpGemmIterations;
  };

  static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages;
  static constexpr int kNumStagesConcurrentLoad =
      kSmemContainsEntireMat ? Stages : Stages - 1;

 private:
  using WarpLoadedFragmentA = typename Operator::FragmentA;
  using WarpLoadedFragmentB = typename Operator::FragmentB;
  using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
  using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;

 private:
  //
  // Data members
  //

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

  bool prologue_done_;

  // Set to `true` to ensure the accumulator will be zero outside the GEMM
  // footprint
  bool zero_outside_bounds_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  CustomMmaMultistage(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
        smem_iterator_A_(shared_storageA.ref(), thread_idx),
        smem_iterator_B_(shared_storageB.ref(), thread_idx),
        prologue_done_(false),
        zero_outside_bounds_(false) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
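
    // For example, with WarpCount::kM == 2 and WarpCount::kN == 2,
    // warp_idx == 3 yields warp_idx_mn == 3 and warp_idx_k == 0, so
    // (warp_idx_m, warp_idx_n) == (1, 1).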

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }

  CUTLASS_DEVICE
  CustomMmaMultistage(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::SharedStorage& st,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : CustomMmaMultistage(
            st.operand_A,
            st.operand_B,
            thread_idx,
            warp_idx,
            lane_idx) {}

  CUTLASS_DEVICE
  void set_prologue_done(bool value) {
    prologue_done_ = value;
  }

  CUTLASS_DEVICE
  void set_zero_outside_bounds(bool value) {
    zero_outside_bounds_ = value;
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorage& shared_storage,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    prologue<kLoadA, kLoadB>(
        shared_storage.operand_A,
        shared_storage.operand_B,
        iterator_A,
        iterator_B,
        thread_idx,
        problem_size_k);
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx);
    SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx);
    int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK;
    _prologue<kLoadA, kLoadB>(
        iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B);
  }

  CUTLASS_DEVICE
  void copy_tiles_and_advance(
      IteratorA& iterator_A,
      IteratorB& iterator_B,
      int group_start_A = 0,
      int group_start_B = 0) {
    iterator_A.set_iteration_index(
        group_start_A * IteratorA::kAccessesPerVector);
    this->smem_iterator_A_.set_iteration_index(group_start_A);

    // Async Copy for operand A
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
      if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
        typename IteratorA::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType*>(
                this->smem_iterator_A_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
            IteratorA::ThreadMap::kElementsPerAccess /
            IteratorA::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_A.get();

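          // cp_async_zfill writes zeros to shared memory when the access is
          // not valid, whereas the plain predicated cp_async leaves the
          // destination untouched.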
          if (zero_outside_bounds_ ||
              SharedMemoryClear == SharedMemoryClearOption::kZfill) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                dst_ptr + v, gmem_ptr, iterator_A.valid());
          } else {
            cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
                dst_ptr + v, gmem_ptr, iterator_A.valid());
          }

          ++iterator_A;
        }

        ++this->smem_iterator_A_;
      }
    }

    iterator_B.set_iteration_index(
        group_start_B * IteratorB::kAccessesPerVector);
    this->smem_iterator_B_.set_iteration_index(group_start_B);

    // Async Copy for operand B
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
      if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
        typename IteratorB::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB::AccessType*>(
                this->smem_iterator_B_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
            IteratorB::ThreadMap::kElementsPerAccess /
            IteratorB::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_B.get();

          if (zero_outside_bounds_ ||
              SharedMemoryClear == SharedMemoryClearOption::kZfill) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                dst_ptr + v, gmem_ptr, iterator_B.valid());
          } else {
            cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
                dst_ptr + v, gmem_ptr, iterator_B.valid());
          }

          ++iterator_B;
        }
        ++this->smem_iterator_B_;
      }
    }
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void _prologue(
      IteratorA& iterator_A,
      IteratorB& iterator_B,
      int32_t& gemm_k_iterations,
      SmemIteratorA& smem_iterator_A_,
      SmemIteratorB& smem_iterator_B_) {
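    // Stage the first kNumStagesConcurrentLoad k-tiles in shared memory:
    // every stage when the entire K extent fits (kSmemContainsEntireMat),
    // otherwise Stages - 1 so the last buffer is filled from inside the
    // mainloop. When kLoadA/kLoadB are disabled, only the iterators are
    // advanced and the per-stage cp.async fences are still issued.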
    // Issue several complete stages
    CUTLASS_PRAGMA_UNROLL
    for (int stage = 0; stage < kNumStagesConcurrentLoad;
         ++stage, --gemm_k_iterations) {
      iterator_A.clear_mask(gemm_k_iterations == 0);
      iterator_B.clear_mask(gemm_k_iterations == 0);

      iterator_A.set_iteration_index(0);
      smem_iterator_A_.set_iteration_index(0);

      // Async Copy for operand A
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
        typename IteratorA::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType*>(
                smem_iterator_A_.get());

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
          int const kSrcBytes =
              sizeof_bits<typename IteratorA::Element>::value *
              IteratorA::ThreadMap::kElementsPerAccess /
              IteratorA::kAccessesPerVector / 8;

          int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);

          if (kLoadA) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                dst_ptr + v, iterator_A.get(), iterator_A.valid());
          }

          ++iterator_A;
        }

        ++smem_iterator_A_;
      }

      iterator_B.set_iteration_index(0);
      smem_iterator_B_.set_iteration_index(0);

      // Async Copy for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
        typename IteratorB::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB::AccessType*>(
                smem_iterator_B_.get());

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
          int const kSrcBytes =
              sizeof_bits<typename IteratorB::Element>::value *
              IteratorB::ThreadMap::kElementsPerAccess /
              IteratorB::kAccessesPerVector / 8;

          if (kLoadB) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                dst_ptr + v, iterator_B.get(), iterator_B.valid());
          }

          ++iterator_B;
        }

        ++smem_iterator_B_;
      }

      // Move to the next stage
      iterator_A.add_tile_offset({0, 1});
      iterator_B.add_tile_offset({1, 0});

      smem_iterator_A_.add_tile_offset({0, 1});
      smem_iterator_B_.add_tile_offset({1, 0});

      // Defines the boundary of a stage of cp.async.
      cutlass::arch::cp_async_fence();
    }
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      ///< problem size of GEMM
      int gemm_k_iterations,
      ///< destination accumulator tile
      FragmentC& accum,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      ///< initial value of accumulator
      FragmentC const& src_accum) {
    //
    // Prologue
    //

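    // Three cases: (1) the prologue has not been issued yet, so load the first
    // stages here; (2) it was issued externally but shared memory acts as a
    // circular buffer, so only advance the iterators without loading
    // (kLoadA/kLoadB = false); (3) the whole operand is already resident in
    // shared memory and only the iteration count needs adjusting.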
    if (!prologue_done_) {
      _prologue<true, true>(
          iterator_A,
          iterator_B,
          gemm_k_iterations,
          smem_iterator_A_,
          smem_iterator_B_);
    } else if (!kSmemContainsEntireMat) {
      _prologue<false, false>(
          iterator_A,
          iterator_B,
          gemm_k_iterations,
          smem_iterator_A_,
          smem_iterator_B_);
    } else {
      gemm_k_iterations -= kNumStagesConcurrentLoad;
    }

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    //
    // Clear the remaining tiles of SMEM. This is a functional requirement for
    // some kernels so that all accumulator elements outside the GEMM footprint
    // are zero.
    //

    if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
      /// Iterator to write threadblock-scoped tile of A operand to shared
      /// memory
      SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);

      typename IteratorA::AccessType zero_A;
      zero_A.clear();

      last_smem_iterator_A.set_iteration_index(0);

      // Async Copy for operand A
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
        typename IteratorA::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType*>(
                last_smem_iterator_A.get());

        *dst_ptr = zero_A;

        ++last_smem_iterator_A;
      }

      /// Iterator to write threadblock-scoped tile of B operand to shared
      /// memory
      SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
      typename IteratorB::AccessType zero_B;

      zero_B.clear();
      last_smem_iterator_B.set_iteration_index(0);

      // Async Copy for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
        typename IteratorB::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB::AccessType*>(
                last_smem_iterator_B.get());

        *dst_ptr = zero_B;

        ++last_smem_iterator_B;
      }
    }

    // Waits until kStages-2 stages have committed.
    cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpLoadedFragmentA warp_loaded_frag_A[2];
    WarpLoadedFragmentB warp_loaded_frag_B[2];
    WarpTransformedFragmentA warp_transformed_frag_A[2];
    WarpTransformedFragmentB warp_transformed_frag_B[2];

    Operator warp_mma;

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    iterator_A.clear_mask(gemm_k_iterations == 0);
    iterator_B.clear_mask(gemm_k_iterations == 0);

    int smem_write_stage_idx = Base::kStages - 1;
    int smem_read_stage_idx = 0;

    warp_mma.transform(
        warp_transformed_frag_A[0],
        warp_transformed_frag_B[0],
        warp_loaded_frag_A[0],
        warp_loaded_frag_B[0]);

    // tf32x3 kernels use staging accumulation. warp_mma uses a temporary
    // accumulator and this temporary accumulator is added to the final
    // accumulator once in every mainloop iteration.
    plus<FragmentC> plus_accum;

    FragmentC tmp_accum;

    if (platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddFastF32>::value ||
        platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddComplexFastF32>::value) {
      tmp_accum.clear();
    }

    //
    // Mainloop
    //

    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) {
      //
      // Loop over GEMM K dimension
      //

      // Computes a warp-level GEMM on data held in shared memory
      // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
           ++warp_mma_k) {
        // Load warp-level tiles from shared memory, wrapping to k offset if
        // this is the last group as the case may be.
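        // The warp-level fragments are double buffered: fragment
        // [warp_mma_k % 2] feeds the MMA below while fragment
        // [(warp_mma_k + 1) % 2] is loaded for the next iteration.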

        this->warp_tile_iterator_A_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);

        // In case of a non-circular buffer ("kSmemContainsEntireMat")
        // make sure we don't load out of bounds data.
        if (!kSmemContainsEntireMat ||
            gemm_k_iterations > (-kNumStagesConcurrentLoad) ||
            warp_mma_k < Base::kWarpGemmIterations - 1) {
          this->warp_tile_iterator_A_.load(
              warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
          this->warp_tile_iterator_B_.load(
              warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
        }

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        if (warp_mma_k > 0)
          warp_mma.transform(
              warp_transformed_frag_A[warp_mma_k % 2],
              warp_transformed_frag_B[warp_mma_k % 2],
              warp_loaded_frag_A[warp_mma_k % 2],
              warp_loaded_frag_B[warp_mma_k % 2]);

        if (platform::is_same<
                typename Operator::MathOperator,
                arch::OpMultiplyAddFastF32>::value ||
            platform::is_same<
                typename Operator::MathOperator,
                arch::OpMultiplyAddComplexFastF32>::value) {
          warp_mma(
              tmp_accum,
              warp_transformed_frag_A[warp_mma_k % 2],
              warp_transformed_frag_B[warp_mma_k % 2],
              tmp_accum);

          if (warp_mma_k == 0) {
            accum = plus_accum(accum, tmp_accum);
            tmp_accum.clear();
          }
        } else {
          warp_mma(
              accum,
              warp_transformed_frag_A[warp_mma_k % 2],
              warp_transformed_frag_B[warp_mma_k % 2],
              accum);
        }

        // Issue global->shared copies for this stage
        if (!kSmemContainsEntireMat &&
            warp_mma_k < Base::kWarpGemmIterations - 1) {
          int group_start_iteration_A, group_start_iteration_B;

          group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
          group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;

          copy_tiles_and_advance(
              iterator_A,
              iterator_B,
              group_start_iteration_A,
              group_start_iteration_B);
        }

        if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
          if (!kSmemContainsEntireMat) {
            int group_start_iteration_A, group_start_iteration_B;
            group_start_iteration_A =
                (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
            group_start_iteration_B =
                (warp_mma_k + 1) * Detail::kAccessesPerGroupB;

            copy_tiles_and_advance(
                iterator_A,
                iterator_B,
                group_start_iteration_A,
                group_start_iteration_B);
          }

          // Inserts a memory fence between stages of cp.async instructions.
          cutlass::arch::cp_async_fence();

          // Waits until kStages-2 stages have committed.
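          // i.e. blocks until at most kNumStagesConcurrentLoad - 1 cp.async
          // groups are still in flight, so the k-tile about to be read next
          // has fully arrived in shared memory.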
          cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
          __syncthreads();

          // Move to the next stage
          iterator_A.add_tile_offset({0, 1});
          iterator_B.add_tile_offset({1, 0});

          this->smem_iterator_A_.add_tile_offset({0, 1});
          this->smem_iterator_B_.add_tile_offset({1, 0});

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == (Base::kStages - 1)) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
            smem_write_stage_idx = 0;
          } else {
            ++smem_write_stage_idx;
          }

          if (!kSmemContainsEntireMat &&
              smem_read_stage_idx == (Base::kStages - 1)) {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0,
                 -Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations,
                 0});
            smem_read_stage_idx = 0;
          } else {
            ++smem_read_stage_idx;
          }

          --gemm_k_iterations;
          iterator_A.clear_mask(gemm_k_iterations == 0);
          iterator_B.clear_mask(gemm_k_iterations == 0);
        }

        // Do any conversions feeding the first stage at the end of the loop so
        // we can start right away on mma instructions
        if (warp_mma_k + 1 == Base::kWarpGemmIterations)
          warp_mma.transform(
              warp_transformed_frag_A[(warp_mma_k + 1) % 2],
              warp_transformed_frag_B[(warp_mma_k + 1) % 2],
              warp_loaded_frag_A[(warp_mma_k + 1) % 2],
              warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
      }
    }

    if (platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddFastF32>::value ||
        platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddComplexFastF32>::value) {
      accum = plus_accum(accum, tmp_accum);
    }

    if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
      // Commit and drain all pending and predicated cp.async operations from
      // the GEMM mainloop.
      cutlass::arch::cp_async_fence();
      cutlass::arch::cp_async_wait<0>();
      __syncthreads();
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////