Searched defs:nsplits (Results 1 – 3 of 3) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
flash_bwd_launch_template.h
  64: __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {  in flash_bwd_convert_dq_kernel()
flash_bwd_preprocess_kernel.h
  181: inline __device__ void convert_dQ(const Params &params, const int nsplits) {  in convert_dQ()
flash_api.cpp
  941: …const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num…  in mha_bwd() (local)
  1172: …const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num…  in mha_varlen_bwd() (local)
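The two flash_api.cpp hits compute nsplits with what looks like the integer ceiling-division idiom (a + b - 1) / b, with a = multiProcessorCount. The listing truncates the denominator, so the sketch below assumes it is batch_size * num_heads. The helper names ceil_div and compute_nsplits are hypothetical and are not part of the PyTorch sources.

```cpp
// Hypothetical helper: integer ceiling division for positive operands,
// i.e. the smallest q such that q * b >= a.
constexpr int ceil_div(int a, int b) {
    return (a + b - 1) / b;
}

// Hypothetical sketch of the nsplits expression seen in mha_bwd() and
// mha_varlen_bwd(). ASSUMPTION: the truncated denominator in the search
// listing is (batch_size * num_heads).
constexpr int compute_nsplits(int multiprocessor_count, int batch_size, int num_heads) {
    return ceil_div(multiprocessor_count, batch_size * num_heads);
}

// Compile-time checks of the arithmetic.
static_assert(compute_nsplits(108, 2, 8) == 7, "123 / 16 == 7");
static_assert(compute_nsplits(108, 4, 32) == 1, "235 / 128 == 1");
```

Under that assumption, nsplits is the smallest value for which nsplits * batch_size * num_heads >= multiProcessorCount. For example, 108 multiprocessors and 16 (batch, head) pairs give 7. Once batch_size * num_heads reaches the multiprocessor count, nsplits drops to 1.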