#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <torch/library.h>

#include <cmath> // for std::ceil

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

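// Parameters passed to the unsqueeze compute shader as a uniform buffer:
// the (shifted) dimension to unsqueeze and the number of packed channel
// texels (see where the Block fields are filled in below).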
struct Block final {
  ivec2 info;
};

Tensor unsqueeze(const at::Tensor& self, int64_t dim) {
  TORCH_CHECK(
      self.dim() <= 3,
      "Vulkan unsqueeze only supports up to 3d tensors as input!");
  TORCH_CHECK(
      dim >= -self.dim() - 1 && dim <= self.dim(),
      "Vulkan unsqueeze dimension out of range, expected to be in range of [",
      -self.dim() - 1,
      ",",
      self.dim(),
      "], but got ",
      dim);
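  // e.g. for a 3-d input the accepted range is [-4, 3]: -4 and 0 both
  // prepend a new leading dimension, while 3 and -1 append a trailing one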

  // Get the global Vulkan context
  api::Context* const context = api::context();

  // Cast the input Tensor to a vTensor
  const Tensor input = self.is_vulkan() ? self : self.vulkan();
  const vTensor& v_input = convert(input);
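  // Note: self.vulkan() uploads a CPU tensor to Vulkan storage first, so
  // v_input is always backed by a Vulkan image from here on.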

  // Compute the output sizes: unsqueeze inserts a dimension of size 1 at dim
  std::vector<int64_t> output_size = v_input.sizes();
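  // Wrap a negative dim into [0, self.dim()]; e.g. dim = -1 on a 3-d input
  // becomes 3, the append position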
  if (dim < 0) {
    dim += (self.dim() + 1);
  }
  output_size.insert(output_size.begin() + dim, 1);
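  // e.g. a {2, 3} input unsqueezed at dim 1 yields output_size {2, 1, 3}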
  // Create the output texture
  vTensor v_output{
      context,
      output_size,
      convert_dtype(self.scalar_type()),
  };

  // Required to determine how to insert memory barriers in the command buffer
  api::PipelineBarrier pipeline_barrier{};

  // Total number of work items is equal to the size of the output texture
  uvec3 global_size = v_output.extents();
  // Adaptively determine local work group size, will usually be {4, 4, 4}
  uvec3 local_size = adaptive_work_group_size(global_size);

  // Unsqueezing at dim 0 only changes the size metadata; the texel layout of
  // the backing image is unchanged, so a plain image-to-image copy suffices
  // and no compute shader is needed.
  if (dim == 0) {
    uvec3 src_offset{};
    uvec3 dst_offset{};
    context->submit_copy<api::VulkanImage, api::VulkanImage>(
        // pipeline barrier
        pipeline_barrier,
        // images
        v_input.image(pipeline_barrier, api::PipelineStage::TRANSFER),
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::TRANSFER,
            api::MemoryAccessType::WRITE),
        // copy details
        v_input.extents(),
        src_offset,
        dst_offset,
        // fence handle
        VK_NULL_HANDLE);
    return convert(v_output);
  } else {
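    // Tensors are stored as rank-3 image textures, so a 1-d or 2-d input
    // behaves as if padded with leading size-1 dimensions (e.g. {W} acts as
    // {1, 1, W}). Shift dim into that rank-3 space and track which index of
    // output_size holds the channels.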
    int channel_index = 1; // Channel dimension in a 3D tensor
    // Shift dim and channel_index for 1D, 2D tensors
    if (self.dim() < 3) {
      dim += (3 - self.dim());
      channel_index = 0;
    }
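    // e.g. for a 2-d input {H, W}, unsqueeze(x, 0) takes the copy path
    // above, while unsqueeze(x, 1) reaches here with dim shifted from 1 to 2
    // and channel_index set to 0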

    // Create the params buffer
    struct Block block {
      {
          // Dimension to unsqueeze
          static_cast<int32_t>(dim),
          // Number of packed channel texels in the Image3D; channels are
          // stored four to a texel, hence the ceil-division by 4
          static_cast<int32_t>(
              std::ceil(static_cast<float>(output_size[channel_index]) / 4)),
      }
    };

    api::UniformParamsBuffer params(context, block);

    context->submit_compute_job(
        // shader descriptor
        VK_KERNEL(unsqueeze),
        // pipeline barrier
        pipeline_barrier,
        // global work group size
        global_size,
        // local work group size
        local_size,
        // fence handle
        VK_NULL_HANDLE,
        // shader arguments
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::COMPUTE,
            api::MemoryAccessType::WRITE),
        v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        // params buffer
        params.buffer());
    return convert(v_output);
  }
}
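
// Example: for a 2-d Vulkan tensor x of shape {3, 4}, at::unsqueeze(x, 0)
// returns a {1, 3, 4} tensor via the copy path above, and at::unsqueeze(x, 1)
// a {3, 1, 4} tensor via the compute shader.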

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::unsqueeze"), TORCH_FN(unsqueeze));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at