Home
last modified time | relevance | path

Searched defs:MHAParams (Results 1 – 1 of 1) sorted by relevance

/external/pytorch/aten/src/ATen/native/cudnn/
DMHA.cpp120 struct MHAParams { struct
121 c10::DeviceIndex device_id;
122 fe::DataType_t dataType;
123 std::array<int, MAX_MHA_DIM> q_dim;
124 std::array<int, MAX_MHA_DIM> k_dim;
125 std::array<int, MAX_MHA_DIM> v_dim;
126 std::array<int, MAX_MHA_DIM> q_stride;
127 std::array<int, MAX_MHA_DIM> k_stride;
128 std::array<int, MAX_MHA_DIM> v_stride;
129 int64_t b;
[all …]