• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /***************************************************************************************************
2  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
3  *reserved. SPDX-License-Identifier: BSD-3-Clause
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  * 1. Redistributions of source code must retain the above copyright notice,
9  *this list of conditions and the following disclaimer.
10  *
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * 3. Neither the name of the copyright holder nor the names of its
16  * contributors may be used to endorse or promote products derived from
17  * this software without specific prior written permission.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23  *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24  *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25  *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26  *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27  *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28  *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29  *POSSIBILITY OF SUCH DAMAGE.
30  *
31  **************************************************************************************************/
32 /*! \file
33   \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
34 
35   File copied from <cutlass/epilogue/threadblock/epilogue.h>
36   then modified to:
37   (1) load 2 source fragments at the same time (pipelining)
38   (2) support reading from a different dtype
39   (3) pass the row id to the OutputOp if it takes it
40     (see MemoryEfficientAttentionNormalize)
  Note that, in general, the fragment passed to the OutputOp could
  span multiple rows, but this does not happen with the configurations we use.
43 */
44 
45 #pragma once
46 
47 #if defined(__CUDACC_RTC__)
48 #include <cuda/std/cassert>
49 #else
50 #include <cassert>
51 #endif
52 
53 #include <cutlass/aligned_buffer.h>
54 #include <cutlass/array.h>
55 #include <cutlass/cutlass.h>
56 #include <cutlass/functional.h>
57 #include <cutlass/layout/tensor.h>
58 #include <cutlass/layout/vector.h>
59 #include <cutlass/numeric_types.h>
60 #include <cutlass/tensor_coord.h>
61 
62 #include <cutlass/gemm/gemm.h>
63 
64 #include <cutlass/transform/pitch_linear_thread_map.h>
65 #include <cutlass/transform/threadblock/regular_tile_iterator.h>
66 
67 #include <cutlass/epilogue/threadblock/epilogue_base.h>
68 #include <cutlass/epilogue/threadblock/predicated_tile_iterator.h>
69 #include <cutlass/numeric_types.h>
70 
71 ////////////////////////////////////////////////////////////////////////////////
72 
73 namespace cutlass {
74 namespace epilogue {
75 namespace threadblock {
76 
/// Adapter through which EpiloguePipelined invokes its output functor.
///
/// The epilogue always supplies `row_id` (the row of the output access being
/// computed). This primary template discards it and forwards only the
/// accumulator fragment (and the source fragment, when there is one) to `Op`.
/// NOTE(review): functors that consume the row id (the file header mentions
/// MemoryEfficientAttentionNormalize) are presumably served by a
/// specialization declared elsewhere - confirm.
template <typename Op>
struct ApplyEpilogueOp {
  /// Source-free form: returns `output_op(accum)`.
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum) {
    (void)row_id; // not needed by the generic adapter
    return output_op(accum);
  }

  /// Form with a source fragment: returns `output_op(accum, source)`.
  static CUTLASS_DEVICE typename Op::FragmentOutput apply(
      Op const& output_op,
      int row_id,
      typename Op::FragmentAccumulator const& accum,
      typename Op::FragmentOutput const& source) {
    (void)row_id; // not needed by the generic adapter
    return output_op(accum, source);
  }
};
93 
94 ////////////////////////////////////////////////////////////////////////////////
95 
/// Epilogue operator
///
/// Threadblock-scoped epilogue that stages the warp-level accumulator tile
/// through shared memory, reduces across K partitions when kPartitionsK > 1,
/// applies `OutputOp_` (through `ApplyEpilogueOp`, which also receives a row
/// index) and writes the result with `OutputTileIterator_`.
///
/// The optional source tile is read with `OutputTileSourceIterator_`, whose
/// element type may differ from the output's. When a source is needed, the
/// fragment for the next iteration is loaded one iteration ahead.
template <
    typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
    typename WarpMmaOperator_, ///< Warp-level MMA operator (concept:
                               ///< gemm::warp::MmaTensorOp)
    int PartitionsK, ///< Number of partitions of the K dimension
    typename OutputTileIterator_, ///< Tile iterator writing output tensors
    typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting
                                           ///< accumulators
    typename WarpTileIterator_, ///< Warp-scoped tile iterator writing
                                ///< accumulators to SMEM
    typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading
                                  ///< from SMEM
    typename OutputOp_, ///< Output operator
    typename Padding_, ///< Padding added to SMEM allocation to avoid bank
                       ///< conflicts (concept: MatrixShape)
    int FragmentsPerPartition =
        1, ///< Used to coarsen the epilogue granularity
    int IterationsUnroll = ///< Used to reduce binary size when epilogue op is
                           ///< large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value),
    typename OutputTileSourceIterator_ =
        OutputTileIterator_ ///< Tile iterator reading tensors
    >
class EpiloguePipelined : public EpilogueBase<
                              Shape_,
                              typename WarpMmaOperator_::Shape,
                              PartitionsK,
                              AccumulatorFragmentIterator_,
                              WarpTileIterator_,
                              Padding_,
                              FragmentsPerPartition> {
 public:
  using Base = EpilogueBase<
      Shape_,
      typename WarpMmaOperator_::Shape,
      PartitionsK,
      AccumulatorFragmentIterator_,
      WarpTileIterator_,
      Padding_,
      FragmentsPerPartition>;

  using Shape = Shape_;
  using WarpMmaOperator = WarpMmaOperator_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using OutputTileSourceIterator = OutputTileSourceIterator_;
  using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
  using WarpTileIterator = WarpTileIterator_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;
  using Padding = Padding_;

  using Layout = layout::RowMajor;
  using LongIndex = typename Layout::LongIndex;

  /// The complete warp-level accumulator tile
  using AccumulatorTile = typename Base::AccumulatorTile;

  /// Accumulator element
  using ElementAccumulator = typename WarpTileIterator::Element;

  /// Output element
  using ElementOutput = typename OutputTileIterator::Element;
  /// Source element (may differ from ElementOutput)
  using ElementSource = typename OutputTileSourceIterator::Element;

  /// Output access size
  static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

  /// Tensor reference to destination tensor
  using TensorRef = typename OutputTileIterator::TensorRef;

  /// Tensor reference to sync tensor
  using SyncTensorRef =
      typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

  /// Const tensor reference to source tensor
  using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

  /// Array type used to output
  using OutputAccessType = Array<
      typename OutputTileIterator::Element,
      OutputTileIterator::kElementsPerAccess>;
  /// Array type used to read the source fragment
  using SourceAccessType = Array<
      typename OutputTileSourceIterator::Element,
      OutputTileSourceIterator::kElementsPerAccess>;

  /// Array type used by output functor
  using AccumulatorAccessType = Array<
      typename WarpTileIterator::Element,
      OutputTileIterator::kElementsPerAccess>;

  /// Number of warps
  using WarpCount = typename Base::WarpCount;

  /// Number of tiles the shared-memory staging buffer is split into:
  /// kFragmentsPerIteration when several fragments are staged per iteration,
  /// otherwise one per K partition.
  static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1
      ? Base::kFragmentsPerIteration
      : kPartitionsK;
  /// Pointer offset between consecutive shared-memory tiles.
  static int constexpr kSmemPointerOffset =
      Base::SharedStorage::StorageShape::kCount / kSmemTiles;

 public:
  static_assert(
      OutputTileSourceIterator::Fragment::kElements ==
          OutputTileIterator::Fragment::kElements,
      "Mismatch between input tile and output tile iterator (kElements)");
  static_assert(
      OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations,
      "Mismatch between input tile and output tile iterator (kIterations)");
  // apply_output_operator_ walks the source fragment in SourceAccessType
  // chunks, but its trip count is derived from the *output* access size;
  // both access sizes must agree or the source reads are misaligned or
  // out of bounds.
  static_assert(
      OutputTileSourceIterator::kElementsPerAccess ==
          OutputTileIterator::kElementsPerAccess,
      "Mismatch between input tile and output tile iterator (kElementsPerAccess)");
  static_assert(
      SharedLoadIterator::Fragment::kElements ==
          OutputTileIterator::Fragment::kElements,
      "Mismatch between shared load iterator and output tile iterator.");

  static_assert(
      OutputTileIterator::kElementsPerAccess,
      "OutputTileIterator::kElementsPerAccess must not be zero.");

  static_assert(
      !(OutputTileIterator::Fragment::kElements %
        OutputTileIterator::kElementsPerAccess),
      "Divisibility");

 private:
  /// Loads fragment from shared memory aligned with output tensor
  SharedLoadIterator shared_load_iterator_;

 public:
  /// Constructor: initializes the base epilogue and binds the shared-memory
  /// load iterator to the epilogue's shared storage.
  CUTLASS_DEVICE
  EpiloguePipelined(
      typename Base::SharedStorage& shared_storage, ///< Shared storage object
      int thread_idx, ///< ID of a thread within the threadblock
      int warp_idx, ///< ID of warp within threadblock
      int lane_idx ///< Id of thread within warp
      )
      : Base(shared_storage, thread_idx, warp_idx, lane_idx),
        shared_load_iterator_(shared_storage.reference(), thread_idx) {}

  /// Streams the result to global memory, reading the source tile only when
  /// `output_op.is_source_needed()` returns true.
  CUTLASS_DEVICE
  void operator()(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators, ///< Complete warp-level accumulator tile
      OutputTileSourceIterator
          source_iterator) { ///< Tile iterator reading the source tensor

    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(output_op, destination_iterator, accumulators);
    } else {
      compute_source_needed_(
          output_op, destination_iterator, accumulators, source_iterator);
    }
  }

  /// Streams the result to global memory without reading any source tile.
  CUTLASS_DEVICE
  void operator()(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators) { ///< Complete warp-level accumulator tile
    compute_source_not_needed_(output_op, destination_iterator, accumulators);
  }

 private:
  /// Writes kFragmentsPerIteration accumulator fragments to shared memory for
  /// a runtime iteration index, dispatching to a compile-time helper.
  template <class Seq>
  struct acc2smem_source_not_needed;

  template <size_t... Seq>
  struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
    /// Skips `Advance` fragments, then stores kFragmentsPerIteration
    /// fragments into consecutive shared-memory tiles and rewinds the warp
    /// tile iterator back to the first tile.
    template <int Advance>
    CUTLASS_DEVICE static void helper(
        AccumulatorFragmentIterator accum_fragment_iterator,
        WarpTileIterator& warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename AccumulatorFragmentIterator::Fragment accum_fragment;

        accum_fragment_iterator.load(accum_fragment);
        ++accum_fragment_iterator;

        warp_tile_iterator.store(accum_fragment);
        if (p < Base::kFragmentsPerIteration - 1) {
          warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
        }
      }

      if (Base::kFragmentsPerIteration > 1) {
        warp_tile_iterator.add_pointer_offset(
            kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }

    /// Invokes the helper whose compile-time offset equals `pos`; the
    /// expansion over Seq turns the runtime index into a compile-time one.
    CUTLASS_DEVICE
    static void push(
        size_t pos,
        AccumulatorFragmentIterator const& iterator_begin,
        WarpTileIterator& warp_tile_iterator) {
      int dummy[] = {
          (pos == (Seq * Base::kFragmentsPerIteration)) &&
          (helper<Seq * Base::kFragmentsPerIteration>(
               iterator_begin, warp_tile_iterator),
           0)...};

      CUTLASS_UNUSED(dummy[0]);
    }
  };

  static_assert(
      kPartitionsK == 1 || Base::kFragmentsPerIteration == 1,
      "One of these must be exactly 1.");

  /// Streams the result to global memory (no source tile).
  ///
  /// Each outer iteration: barrier, write kFragmentsPerIteration accumulator
  /// fragments to shared memory, barrier, then for each staged fragment read
  /// it back (summing the K-partition tiles when kPartitionsK > 1), apply the
  /// output operator and store to the destination.
  CUTLASS_DEVICE
  void compute_source_not_needed_(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators ///< Complete warp-level accumulator tile
  ) {
    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

#pragma unroll(                                                          \
    IterationsUnroll                                                     \
        ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \
        : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations;
         iter += Base::kFragmentsPerIteration) {
      //
      // Convert and store fragment
      //

      __syncthreads();

      acc2smem_source_not_needed<cutlass::make_index_sequence<
          OutputTileIterator::kIterations / Base::kFragmentsPerIteration>>::
          push(iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      CUTLASS_PRAGMA_UNROLL
      for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
        typename SharedLoadIterator::Fragment
            aligned_accum_fragment[kPartitionsK];

        shared_load_iterator_.load(aligned_accum_fragment[0]);

        if (p < Base::kFragmentsPerIteration - 1) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        } else if (kPartitionsK > 1) {
          plus<typename SharedLoadIterator::Fragment> add_fragments;

          CUTLASS_PRAGMA_UNROLL
          for (int i = 1; i < kPartitionsK; ++i) {
            shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
            shared_load_iterator_.load(aligned_accum_fragment[i]);
            aligned_accum_fragment[0] = add_fragments(
                aligned_accum_fragment[0], aligned_accum_fragment[i]);
          }

          shared_load_iterator_.add_pointer_offset(
              (1 - kPartitionsK) * kSmemPointerOffset);
        }

        //
        // Compute the output result
        //

        typename OutputTileIterator::Fragment output_fragment;

        apply_output_operator_source_not_needed_(
            destination_iterator.thread_start_row(),
            output_fragment,
            output_op,
            aligned_accum_fragment[0]);

        //
        // Store the final result
        //

        destination_iterator.store(output_fragment);
        ++destination_iterator;
      }

      // Rewind the shared-memory load iterator to the first staged tile.
      if (Base::kFragmentsPerIteration > 1) {
        shared_load_iterator_.add_pointer_offset(
            kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
      }
    }
  }

  /// Writes one accumulator fragment to shared memory for a runtime
  /// iteration index, dispatching to a compile-time helper.
  template <class Seq>
  struct acc2smem_source_needed;

  template <size_t... Seq>
  struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
    /// Skips `Advance` fragments, then stores the next one to shared memory.
    template <int Advance>
    CUTLASS_DEVICE static void helper(
        AccumulatorFragmentIterator accum_fragment_iterator,
        WarpTileIterator& warp_tile_iterator) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Advance; i++) {
        ++accum_fragment_iterator;
      }

      typename AccumulatorFragmentIterator::Fragment accum_fragment;
      accum_fragment_iterator.load(accum_fragment);
      warp_tile_iterator.store(accum_fragment);
    }

    /// Invokes the helper whose compile-time offset equals `pos`.
    CUTLASS_DEVICE
    static void push(
        size_t pos,
        AccumulatorFragmentIterator const& iterator_begin,
        WarpTileIterator& warp_tile_iterator) {
      int dummy[] = {
          (pos == Seq) &&
          (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};

      // The array only exists to drive the pack expansion; silence the
      // unused-variable warning (matches acc2smem_source_not_needed::push).
      CUTLASS_UNUSED(dummy[0]);
    }
  };

  /// Streams the result to global memory, combining it with the source tile.
  ///
  /// The source is double-buffered: the fragment for iteration 0 is loaded
  /// before the loop, and iteration `iter` issues the load for `iter + 1`
  /// into the other slot before consuming `source_fragment[iter % 2]`.
  CUTLASS_DEVICE
  void compute_source_needed_(
      OutputOp const& output_op, ///< Output operator
      OutputTileIterator
          destination_iterator, ///< Tile iterator for destination
      AccumulatorTile const&
          accumulators, ///< Complete warp-level accumulator tile
      OutputTileSourceIterator
          source_iterator ///< Tile iterator reading the source tensor
  ) {
    typename OutputTileSourceIterator::Fragment source_fragment[2];

    source_fragment[0].clear();
    source_iterator.load(source_fragment[0]);
    ++source_iterator;
    source_fragment[1].clear();

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
      // Barrier before overwriting the shared-memory tile that the previous
      // iteration read back. Unlike compute_source_not_needed_, there is no
      // leading barrier on the first iteration.
      // NOTE(review): this presumably relies on the caller having
      // synchronized before the shared storage is reused by the epilogue -
      // confirm at the call sites.
      if (iter > 0) {
        __syncthreads();
      }
      //
      // Load the source for next iteration (pipelining)
      //

      if (iter + 1 < OutputTileIterator::kIterations) {
        source_iterator.load(source_fragment[(iter + 1) % 2]);
      }
      ++source_iterator;
      acc2smem_source_needed<
          cutlass::make_index_sequence<OutputTileIterator::kIterations>>::
          push(iter, accum_fragment_iterator, this->warp_tile_iterator_);

      __syncthreads();

      //
      // Load fragments from shared memory
      //

      typename SharedLoadIterator::Fragment
          aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);

      // If the number of k-slices is > 1 - perform a reduction amongst the
      // k-slices
      if (kPartitionsK > 1) {
        plus<typename SharedLoadIterator::Fragment> add_fragments;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 1; i < kPartitionsK; ++i) {
          shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
          shared_load_iterator_.load(aligned_accum_fragment[i]);
          aligned_accum_fragment[0] = add_fragments(
              aligned_accum_fragment[0], aligned_accum_fragment[i]);
        }

        shared_load_iterator_.add_pointer_offset(
            (1 - kPartitionsK) * kSmemPointerOffset);
      }

      //
      // Compute the output result
      //

      typename OutputTileIterator::Fragment output_fragment;

      apply_output_operator_(
          destination_iterator.thread_start_row(),
          output_fragment,
          output_op,
          aligned_accum_fragment[0],
          source_fragment[iter % 2]);

      //
      // Store the final result
      //

      destination_iterator.store(output_fragment);
      ++destination_iterator;
    }
  }

  /// Helper to invoke the output functor over each vector of output.
  ///
  /// `begin_row` is the thread's first output row; each access gets
  /// `begin_row + getRowOffset(first element of the access)`.
  CUTLASS_DEVICE
  void apply_output_operator_(
      int begin_row,
      typename OutputTileIterator::Fragment& output_fragment,
      OutputOp const& output_op, ///< Output operator
      typename SharedLoadIterator::Fragment const& aligned_accum_fragment,
      typename OutputTileSourceIterator::Fragment const& source_fragment) {
    OutputAccessType* output_frag_ptr =
        reinterpret_cast<OutputAccessType*>(&output_fragment);

    AccumulatorAccessType const* compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);

    SourceAccessType const* source_frag_ptr =
        reinterpret_cast<SourceAccessType const*>(&source_fragment);

    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
        OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
          output_op,
          begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
          compute_frag_ptr[i],
          source_frag_ptr[i]);
    }
  }

  /// Helper to invoke the output functor over each vector of output
  /// (no source fragment).
  CUTLASS_DEVICE
  void apply_output_operator_source_not_needed_(
      int begin_row,
      typename OutputTileIterator::Fragment& output_fragment,
      OutputOp const& output_op, ///< Output operator
      typename SharedLoadIterator::Fragment const& aligned_accum_fragment) {
    OutputAccessType* output_frag_ptr =
        reinterpret_cast<OutputAccessType*>(&output_fragment);

    AccumulatorAccessType const* compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);

    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
        OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
          output_op,
          begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
          compute_frag_ptr[i]);
    }
  }

  /// Returns the row offset (relative to the thread's start row) of the
  /// access that contains element `i` of the thread's output fragment, using
  /// the OutputTileIterator's ThreadMap. Returns -1 if `i` lies past the
  /// fragment.
  ///
  /// Depends only on compile-time ThreadMap constants, so it is static. The
  /// loops and early return in a constexpr function require C++14.
  CUTLASS_HOST_DEVICE
  static constexpr int getRowOffset(int i) {
    using ThreadMap = typename OutputTileIterator::ThreadMap;

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));
          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            int frag_idx = ThreadMap::kElementsPerAccess *
                (frag_row_idx * ThreadMap::Iterations::kColumn + column);
            if (i < frag_idx + ThreadMap::kElementsPerAccess) {
              return row_offset;
            }
          }
        }
      }
    }
    return -1;
  }
};
625 
626 ////////////////////////////////////////////////////////////////////////////////
627 
628 } // namespace threadblock
629 } // namespace epilogue
630 } // namespace cutlass
631 
632 ////////////////////////////////////////////////////////////////////////////////
633