/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include namespace vkcompute { void check_args( const api::vTensor& t_in, int64_t dim, int64_t index, const api::vTensor& t_out) { VK_CHECK_COND(check_packed_dim_is(t_in, WHCN::kChannelsDim)); VK_CHECK_COND(check_packed_dim_is(t_out, WHCN::kChannelsDim)); const int64_t in_dim = t_in.dim(); VK_CHECK_COND( in_dim == 3 || in_dim == 4, "Vulkan select only support 3d or 4d tensors!"); const int64_t in_size = t_in.size(dim); if (index < -in_size || index >= in_size) { VK_CHECK_COND( false, "select(): index ", index, " t_outof range for tensor of size ", in_size, " at dimension ", dim); } } void add_select_int_node( ComputeGraph& graph, const ValueRef in, const ValueRef dim_ref, const ValueRef index_ref, const ValueRef out) { vTensorPtr t_in = graph.get_tensor(in); vTensorPtr t_out = graph.get_tensor(out); int64_t dim = graph.extract_scalar(dim_ref); int64_t index = graph.extract_scalar(index_ref); check_args(*t_in, dim, index, *t_out); const int64_t in_size = t_in->size(dim); if (index < 0) { index += in_size; } std::string kernel_name; // for 3d tensors, these values are not used by the shader. int32_t num_texel_per_batch = 1; int32_t num_batches = 1; int64_t in_dim = t_in->dim(); if (in_dim == 3) { if (dim == 0) { kernel_name = "select_channel_3d"; } else if (dim == 1) { kernel_name = "select_height_3d"; } else if (dim == 2) { kernel_name = "select_width_3d"; } else { VK_CHECK_COND( false, "Unexpected dim value=", dim, "for the input 3d tensor"); } } else { // self.dim() == 4 num_texel_per_batch = static_cast(std::ceil(static_cast(t_in->size(1)) / 4)); num_batches = t_in->size(0); if (dim == 0) { kernel_name = "select_batch_4d"; } else if (dim == 1) { kernel_name = "select_channel_4d"; } else if (dim == 2) { kernel_name = "select_height_4d"; } else if (dim == 3) { kernel_name = "select_width_4d"; } else { VK_CHECK_COND( false, "Unexpected dim value=", dim, "for the input 4d tensor"); } } kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, *t_out); // TODO: add resizing to support dynamic shapes. graph.execute_nodes().emplace_back(new DispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), graph.create_global_wg_size(out), graph.create_local_wg_size(out), // Inputs and Outputs {{out, vkapi::MemoryAccessType::WRITE}, {in, vkapi::MemoryAccessType::READ}}, // Parameter buffers {t_out->logical_limits_ubo(), t_out->sizes_ubo(), // TODO: num_batches and num_texel_per_batch are provided by // t_out->sizes. Can change the following to reduce params // created. graph.create_params_buffer( utils::make_ivec4({index, num_batches, num_texel_per_batch, 0}))}, // Specialization Constants {})); } void select_int(ComputeGraph& graph, const std::vector& args) { return add_select_int_node(graph, args[0], args[1], args[2], args[3]); } REGISTER_OPERATORS { VK_REGISTER_OP(aten.select.int, select_int); VK_REGISTER_OP(aten.select_copy.int, select_int); } } // namespace vkcompute