Home
last modified time | relevance | path

Searched defs:MatmulDOIVJ (Results 1 – 1 of 1) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/
Dkernel_backward.h426 struct MatmulDOIVJ { struct
431 using ThreadblockShape =
433 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
435 using ElementC = output_t;
436 using ElementAccum = accum_t;
439 using BiasGradEpilogueOutputOp =
447 using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm<
473 using Mma = typename MakeCustomMma<typename DefaultGemm::Mma, kMaxK>::Mma;
474 using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
481 using BiasGradEpilogue = typename DefaultGemm::Epilogue;
[all …]