Searched defs:MHAParams (Results 1 – 1 of 1) sorted by relevance
120 struct MHAParams { struct121 c10::DeviceIndex device_id;122 fe::DataType_t dataType;123 std::array<int, MAX_MHA_DIM> q_dim;124 std::array<int, MAX_MHA_DIM> k_dim;125 std::array<int, MAX_MHA_DIM> v_dim;126 std::array<int, MAX_MHA_DIM> q_stride;127 std::array<int, MAX_MHA_DIM> k_stride;128 std::array<int, MAX_MHA_DIM> v_stride;129 int64_t b;[all …]