/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include <cutlass/aligned_buffer.h>
#include <cutlass/arch/cache_operation.h>
#include <cutlass/arch/memory.h>
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>

#include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    /// Upper bound on the K dimension
    int kMaxK = cutlass::platform::numeric_limits<int>::max(),
    /// Used for partial specialization
    typename Enable = bool>
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
 public:
  ///< Base class
  using Base = CustomMmaBase<Shape_, Policy_, Stages>;
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA = IteratorA_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB = IteratorB_;
  ///< Data type of accumulator matrix
  using ElementC = ElementC_;
  ///< Layout of accumulator matrix
  using LayoutC = LayoutC_;
  ///< Policy describing tuning details
  using Policy = Policy_;

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

  //
  // Dependent types
  //

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Minimum architecture is Sm80 to support cp.async
  using ArchTag = arch::Sm80;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  /// Internal structure exposed for introspection.
  struct Detail {
    static_assert(
        Base::kWarpGemmIterations > 1,
        "The pipelined structure requires at least two warp-level "
        "GEMM operations.");

    /// Number of cp.async instructions to load one stage of operand A
    static int const AsyncCopyIterationsPerStageA =
        IteratorA::ThreadMap::Iterations::kCount;

    /// Number of cp.async instructions to load one stage of operand B
    static int const AsyncCopyIterationsPerStageB =
        IteratorB::ThreadMap::Iterations::kCount;

    /// Number of stages
    static int const kStages = Stages;

    /// Number of cp.async instructions to load one group of operand A
    static int const kAccessesPerGroupA =
        (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
        Base::kWarpGemmIterations;

    /// Number of cp.async instructions to load one group of operand B
    static int const kAccessesPerGroupB =
        (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
        Base::kWarpGemmIterations;
  };
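
  // Illustrative (hypothetical) numbers for the "group" split above: with
  // Detail::AsyncCopyIterationsPerStageA == 8 and
  // Base::kWarpGemmIterations == 4, kAccessesPerGroupA == (8 + 4 - 1) / 4 == 2,
  // i.e. each warp-level MMA iteration of the mainloop issues two of the
  // stage's cp.async copies, spreading one full stage across the K iterations.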

  static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages;
  static constexpr int kNumStagesConcurrentLoad =
      kSmemContainsEntireMat ? Stages : Stages - 1;
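  // Illustrative (hypothetical) numbers: with Shape::kK == 32, Stages == 3,
  // and kMaxK == 64, 64 <= 32 * 3 holds, so shared memory covers the whole K
  // extent and all Stages stages can be loaded up front
  // (kNumStagesConcurrentLoad == Stages). With an unbounded kMaxK, only
  // Stages - 1 == 2 stages are prefetched and the buffer is used circularly.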

 private:
  using WarpLoadedFragmentA = typename Operator::FragmentA;
  using WarpLoadedFragmentB = typename Operator::FragmentB;
  using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
  using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;

 private:
  //
  // Data members
  //

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

  bool prologue_done_;

  // Set to `true` to ensure the accumulator will be zero outside the GEMM
  // footprint
  bool zero_outside_bounds_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  CustomMmaMultistage(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
        smem_iterator_A_(shared_storageA.ref(), thread_idx),
        smem_iterator_B_(shared_storageB.ref(), thread_idx),
        prologue_done_(false),
        zero_outside_bounds_(false) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
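    // Worked example (illustrative, hypothetical numbers): with
    // WarpCount == {kM = 2, kN = 2, kK = 1} and warp_idx == 3,
    // warp_idx_mn == 3 % 4 == 3 and warp_idx_k == 3 / 4 == 0, so
    // warp_idx_m == 3 % 2 == 1 and warp_idx_n == 3 / 2 == 1: this warp
    // computes the warp-level tile at (m = 1, n = 1) of the threadblock tile.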

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }
  CUTLASS_DEVICE
  CustomMmaMultistage(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::SharedStorage& st,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : CustomMmaMultistage(
            st.operand_A,
            st.operand_B,
            thread_idx,
            warp_idx,
            lane_idx) {}

  CUTLASS_DEVICE
  void set_prologue_done(bool value) {
    prologue_done_ = value;
  }

  CUTLASS_DEVICE
  void set_zero_outside_bounds(bool value) {
    zero_outside_bounds_ = value;
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorage& shared_storage,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    prologue<kLoadA, kLoadB>(
        shared_storage.operand_A,
        shared_storage.operand_B,
        iterator_A,
        iterator_B,
        thread_idx,
        problem_size_k);
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx);
    SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx);
    int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK;
    _prologue<kLoadA, kLoadB>(
        iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B);
  }
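
  // A possible calling pattern for the split prologue (sketch only; the names
  // `shared_storage`, `iterator_A`, `iterator_B`, `tb_tile_k`, and the index
  // variables are hypothetical and supplied by the surrounding kernel):
  //
  //   CustomMmaMultistage::prologue(
  //       shared_storage, iterator_A, iterator_B, thread_idx, tb_tile_k);
  //   // ... other work may overlap with the in-flight cp.async copies ...
  //   CustomMmaMultistage mma(shared_storage, thread_idx, warp_idx, lane_idx);
  //   mma.set_prologue_done(true);
  //   mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);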

  CUTLASS_DEVICE
  void copy_tiles_and_advance(
      IteratorA& iterator_A,
      IteratorB& iterator_B,
      int group_start_A = 0,
      int group_start_B = 0) {
    iterator_A.set_iteration_index(
        group_start_A * IteratorA::kAccessesPerVector);
    this->smem_iterator_A_.set_iteration_index(group_start_A);

    // Async Copy for operand A
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
      if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
        typename IteratorA::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType*>(
                this->smem_iterator_A_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
            IteratorA::ThreadMap::kElementsPerAccess /
            IteratorA::kAccessesPerVector / 8;
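        // For kSrcBytes above, with illustrative (hypothetical) numbers:
        // 16-bit elements, 8 elements per access, and one access per vector
        // give 16 * 8 / 1 / 8 == 16 bytes per cp.async instruction.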

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_A.get();

          if (zero_outside_bounds_ ||
              SharedMemoryClear == SharedMemoryClearOption::kZfill) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                dst_ptr + v, gmem_ptr, iterator_A.valid());
          } else {
            cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
                dst_ptr + v, gmem_ptr, iterator_A.valid());
          }

          ++iterator_A;
        }

        ++this->smem_iterator_A_;
      }
    }

    iterator_B.set_iteration_index(
        group_start_B * IteratorB::kAccessesPerVector);
    this->smem_iterator_B_.set_iteration_index(group_start_B);

    // Async Copy for operand B
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
      if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
        typename IteratorB::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB::AccessType*>(
                this->smem_iterator_B_.get());

        int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
            IteratorB::ThreadMap::kElementsPerAccess /
            IteratorB::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_B.get();

          if (zero_outside_bounds_ ||
              SharedMemoryClear == SharedMemoryClearOption::kZfill) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                dst_ptr + v, gmem_ptr, iterator_B.valid());
          } else {
            cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
                dst_ptr + v, gmem_ptr, iterator_B.valid());
          }

          ++iterator_B;
        }
        ++this->smem_iterator_B_;
      }
    }
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void _prologue(
      IteratorA& iterator_A,
      IteratorB& iterator_B,
      int32_t& gemm_k_iterations,
      SmemIteratorA& smem_iterator_A_,
      SmemIteratorB& smem_iterator_B_) {
    // Issue several complete stages
    CUTLASS_PRAGMA_UNROLL
    for (int stage = 0; stage < kNumStagesConcurrentLoad;
         ++stage, --gemm_k_iterations) {
      iterator_A.clear_mask(gemm_k_iterations == 0);
      iterator_B.clear_mask(gemm_k_iterations == 0);

      iterator_A.set_iteration_index(0);
      smem_iterator_A_.set_iteration_index(0);

      // Async Copy for operand A
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
        typename IteratorA::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType*>(
                smem_iterator_A_.get());

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
          int const kSrcBytes =
              sizeof_bits<typename IteratorA::Element>::value *
              IteratorA::ThreadMap::kElementsPerAccess /
              IteratorA::kAccessesPerVector / 8;

          int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);

          if (kLoadA) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                dst_ptr + v, iterator_A.get(), iterator_A.valid());
          }

          ++iterator_A;
        }

        ++smem_iterator_A_;
      }

      iterator_B.set_iteration_index(0);
      smem_iterator_B_.set_iteration_index(0);

      // Async Copy for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
        typename IteratorB::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB::AccessType*>(
                smem_iterator_B_.get());

        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
          int const kSrcBytes =
              sizeof_bits<typename IteratorB::Element>::value *
              IteratorB::ThreadMap::kElementsPerAccess /
              IteratorB::kAccessesPerVector / 8;

          if (kLoadB) {
            cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                dst_ptr + v, iterator_B.get(), iterator_B.valid());
          }

          ++iterator_B;
        }

        ++smem_iterator_B_;
      }

      // Move to the next stage
      iterator_A.add_tile_offset({0, 1});
      iterator_B.add_tile_offset({1, 0});

      smem_iterator_A_.add_tile_offset({0, 1});
      smem_iterator_B_.add_tile_offset({1, 0});

      // Defines the boundary of a stage of cp.async.
      cutlass::arch::cp_async_fence();
    }
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      ///< problem size of GEMM
      int gemm_k_iterations,
      ///< destination accumulator tile
      FragmentC& accum,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      ///< initial value of accumulator
      FragmentC const& src_accum) {
    //
    // Prologue
    //

    if (!prologue_done_) {
      _prologue<true, true>(
          iterator_A,
          iterator_B,
          gemm_k_iterations,
          smem_iterator_A_,
          smem_iterator_B_);
    } else if (!kSmemContainsEntireMat) {
      _prologue<false, false>(
          iterator_A,
          iterator_B,
          gemm_k_iterations,
          smem_iterator_A_,
          smem_iterator_B_);
    } else {
      gemm_k_iterations -= kNumStagesConcurrentLoad;
    }

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    //
    // Clear the remaining tiles of SMEM. This is a functional requirement for
    // some kernels so that all accumulator elements outside the GEMM footprint
    // are zero.
    //

    if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
      /// Iterator to write threadblock-scoped tile of A operand to shared
      /// memory
      SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);

      typename IteratorA::AccessType zero_A;
      zero_A.clear();

      last_smem_iterator_A.set_iteration_index(0);

      // Async Copy for operand A
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
        typename IteratorA::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorA::AccessType*>(
                last_smem_iterator_A.get());

        *dst_ptr = zero_A;

        ++last_smem_iterator_A;
      }

      /// Iterator to write threadblock-scoped tile of B operand to shared
      /// memory
      SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
      typename IteratorB::AccessType zero_B;

      zero_B.clear();
      last_smem_iterator_B.set_iteration_index(0);

      // Async Copy for operand B
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
        typename IteratorB::AccessType* dst_ptr =
            reinterpret_cast<typename IteratorB::AccessType*>(
                last_smem_iterator_B.get());

        *dst_ptr = zero_B;

        ++last_smem_iterator_B;
      }
    }

    // Wait until at most kNumStagesConcurrentLoad - 1 cp.async stages are
    // still in flight.
    cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpLoadedFragmentA warp_loaded_frag_A[2];
    WarpLoadedFragmentB warp_loaded_frag_B[2];
    WarpTransformedFragmentA warp_transformed_frag_A[2];
    WarpTransformedFragmentB warp_transformed_frag_B[2];

    Operator warp_mma;

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    iterator_A.clear_mask(gemm_k_iterations == 0);
    iterator_B.clear_mask(gemm_k_iterations == 0);

    int smem_write_stage_idx = Base::kStages - 1;
    int smem_read_stage_idx = 0;

    warp_mma.transform(
        warp_transformed_frag_A[0],
        warp_transformed_frag_B[0],
        warp_loaded_frag_A[0],
        warp_loaded_frag_B[0]);

    // tf32x3 kernels use staging accumulation. warp_mma uses a temporary
    // accumulator and this temporary accumulator is added to the final
    // accumulator once in every mainloop iteration.
    plus<FragmentC> plus_accum;

    FragmentC tmp_accum;

    if (platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddFastF32>::value ||
        platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddComplexFastF32>::value) {
      tmp_accum.clear();
    }

    //
    // Mainloop
    //

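    // The loop bound lets gemm_k_iterations go negative so that the stages
    // already issued by the prologue are drained. Illustrative (hypothetical)
    // numbers: with 8 total K iterations and kNumStagesConcurrentLoad == 2,
    // the prologue leaves gemm_k_iterations == 6 and the loop body runs for
    // the values 6, 5, ..., -1, i.e. 8 times in total.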
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) {
      //
      // Loop over GEMM K dimension
      //

      // Computes a warp-level GEMM on data held in shared memory
      // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
           ++warp_mma_k) {
        // Load warp-level tiles from shared memory, wrapping to the k offset
        // if this is the last group.

        this->warp_tile_iterator_A_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);

        // In case of a non-circular buffer ("kSmemContainsEntireMat")
        // make sure we don't load out of bounds data.
        if (!kSmemContainsEntireMat ||
            gemm_k_iterations > (-kNumStagesConcurrentLoad) ||
            warp_mma_k < Base::kWarpGemmIterations - 1) {
          this->warp_tile_iterator_A_.load(
              warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
          this->warp_tile_iterator_B_.load(
              warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
        }

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        if (warp_mma_k > 0)
          warp_mma.transform(
              warp_transformed_frag_A[warp_mma_k % 2],
              warp_transformed_frag_B[warp_mma_k % 2],
              warp_loaded_frag_A[warp_mma_k % 2],
              warp_loaded_frag_B[warp_mma_k % 2]);

        if (platform::is_same<
                typename Operator::MathOperator,
                arch::OpMultiplyAddFastF32>::value ||
            platform::is_same<
                typename Operator::MathOperator,
                arch::OpMultiplyAddComplexFastF32>::value) {
          warp_mma(
              tmp_accum,
              warp_transformed_frag_A[warp_mma_k % 2],
              warp_transformed_frag_B[warp_mma_k % 2],
              tmp_accum);

          if (warp_mma_k == 0) {
            accum = plus_accum(accum, tmp_accum);
            tmp_accum.clear();
          }
        } else {
          warp_mma(
              accum,
              warp_transformed_frag_A[warp_mma_k % 2],
              warp_transformed_frag_B[warp_mma_k % 2],
              accum);
        }

        // Issue global->shared copies for this stage
        if (!kSmemContainsEntireMat &&
            warp_mma_k < Base::kWarpGemmIterations - 1) {
          int group_start_iteration_A, group_start_iteration_B;

          group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
          group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;

          copy_tiles_and_advance(
              iterator_A,
              iterator_B,
              group_start_iteration_A,
              group_start_iteration_B);
        }

        if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
          if (!kSmemContainsEntireMat) {
            int group_start_iteration_A, group_start_iteration_B;
            group_start_iteration_A =
                (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
            group_start_iteration_B =
                (warp_mma_k + 1) * Detail::kAccessesPerGroupB;

            copy_tiles_and_advance(
                iterator_A,
                iterator_B,
                group_start_iteration_A,
                group_start_iteration_B);
          }

          // Inserts a memory fence between stages of cp.async instructions.
          cutlass::arch::cp_async_fence();

          // Wait until at most kNumStagesConcurrentLoad - 1 cp.async stages
          // are still in flight.
          cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
          __syncthreads();

          // Move to the next stage
          iterator_A.add_tile_offset({0, 1});
          iterator_B.add_tile_offset({1, 0});

          this->smem_iterator_A_.add_tile_offset({0, 1});
          this->smem_iterator_B_.add_tile_offset({1, 0});

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == (Base::kStages - 1)) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
            smem_write_stage_idx = 0;
          } else {
            ++smem_write_stage_idx;
          }

          if (!kSmemContainsEntireMat &&
              smem_read_stage_idx == (Base::kStages - 1)) {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0,
                 -Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations,
                 0});
            smem_read_stage_idx = 0;
          } else {
            ++smem_read_stage_idx;
          }

          --gemm_k_iterations;
          iterator_A.clear_mask(gemm_k_iterations == 0);
          iterator_B.clear_mask(gemm_k_iterations == 0);
        }

        // Do any conversions feeding the first stage at the end of the loop so
        // we can start right away on mma instructions
        if (warp_mma_k + 1 == Base::kWarpGemmIterations)
          warp_mma.transform(
              warp_transformed_frag_A[(warp_mma_k + 1) % 2],
              warp_transformed_frag_B[(warp_mma_k + 1) % 2],
              warp_loaded_frag_A[(warp_mma_k + 1) % 2],
              warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
      }
    }

    if (platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddFastF32>::value ||
        platform::is_same<
            typename Operator::MathOperator,
            arch::OpMultiplyAddComplexFastF32>::value) {
      accum = plus_accum(accum, tmp_accum);
    }

    if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
      // Commit and drain all pending and predicated cp.async operations from
      // the GEMM mainloop
      cutlass::arch::cp_async_fence();
      cutlass::arch::cp_async_wait<0>();
      __syncthreads();
    }
  }
};
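
// A minimal instantiation sketch (the iterator, policy, and shape types below
// are hypothetical placeholders; in practice they come from DefaultMma-style
// trait classes configured by the calling kernel):
//
//   using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<
//       cutlass::gemm::GemmShape<64, 64, 32>, // threadblock tile (M, N, K)
//       MyIteratorA, MySmemIteratorA, cutlass::arch::CacheOperation::Global,
//       MyIteratorB, MySmemIteratorB, cutlass::arch::CacheOperation::Global,
//       float, cutlass::layout::RowMajor, // accumulator element and layout
//       MyMmaPolicy,
//       /*Stages=*/3>;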

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////