1 #include <ATen/native/vulkan/impl/Common.h> 2 3 namespace at { 4 namespace native { 5 namespace vulkan { 6 adaptive_work_group_size(const api::utils::uvec3 & global_work_group)7api::utils::uvec3 adaptive_work_group_size( 8 const api::utils::uvec3& global_work_group) { 9 api::utils::uvec3 local_group_size = {4, 4, 4}; 10 if (global_work_group.data[2u] == 1) { 11 if (global_work_group.data[1u] < 8) { 12 local_group_size.data[0u] = 16; 13 local_group_size.data[1u] = 4; 14 local_group_size.data[2u] = 1; 15 } else { 16 local_group_size.data[0u] = 8; 17 local_group_size.data[1u] = 8; 18 local_group_size.data[2u] = 1; 19 } 20 } 21 return local_group_size; 22 } 23 24 } // namespace vulkan 25 } // namespace native 26 } // namespace at 27