• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <ATen/native/vulkan/ops/Common.h>
2 #include <ATen/native/vulkan/ops/Utils.h>
3 #include <torch/library.h>
4 
5 namespace at {
6 namespace native {
7 namespace vulkan {
8 namespace ops {
9 namespace {
10 
11 using namespace api::utils;
12 
13 struct Block final {
14   ivec2 info;
15 };
16 
unsqueeze(const at::Tensor & self,int64_t dim)17 Tensor unsqueeze(const at::Tensor& self, int64_t dim) {
18   TORCH_CHECK(
19       self.dim() <= 3,
20       "Vulkan unsqueeze only supports up to 3d tensors as input!");
21   TORCH_CHECK(
22       dim >= -self.dim() - 1 && dim <= self.dim(),
23       "Vulkan unsqueeze dimension out of range expected to be in range of [",
24       -self.dim() - 1,
25       ",",
26       self.dim(),
27       "], but got ",
28       dim);
29 
30   // Get the global Vulkan context
31   api::Context* const context = api::context();
32 
33   // Cast the input Tensor to a vTensor
34   const Tensor input = self.is_vulkan() ? self : self.vulkan();
35   const vTensor& v_input = convert(input);
36 
37   // Create the output texture. For unsqueeze, add a dimension.
38   std::vector<int64_t> output_size = v_input.sizes();
39   if (dim < 0) {
40     dim += (self.dim() + 1);
41   }
42   output_size.insert(output_size.begin() + dim, 1);
43   // Create the output texture
44   vTensor v_output{
45       context,
46       output_size,
47       convert_dtype(self.scalar_type()),
48   };
49 
50   // Required to determine how to insert memory barriers in the command buffer
51   api::PipelineBarrier pipeline_barrier{};
52 
53   // Total number of work items is equal to the size of the output texture
54   uvec3 global_size = v_output.extents();
55   // Adaptively determine local work group size, will usually be {4, 4, 4}
56   uvec3 local_size = adaptive_work_group_size(global_size);
57 
58   // When unsqueezing in the 0th dimension, only the metadata changes.
59   // So we can perform a copy.
60   if (dim == 0) {
61     const vTensor& v_self = convert(self);
62     uvec3 src_offset{};
63     uvec3 dst_offset{};
64     context->submit_copy<api::VulkanImage, api::VulkanImage>(
65         // pipeline barrier
66         pipeline_barrier,
67         // images
68         v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER),
69         v_output.image(
70             pipeline_barrier,
71             api::PipelineStage::TRANSFER,
72             api::MemoryAccessType::WRITE),
73         // copy details
74         v_self.extents(),
75         src_offset,
76         dst_offset,
77         // fence handle
78         VK_NULL_HANDLE);
79     return convert(v_output);
80   }
81 
82   else {
83     int channel_index = 1; // Channel dimension in a 3D tensor
84     // Shift dim and channel_index for 1D, 2D tensors
85     if (self.dim() < 3) {
86       dim += (3 - self.dim());
87       channel_index = 0;
88     }
89 
90     // Create the params buffer
91     struct Block block {
92       {
93         // Dimension to unsqueeze
94         static_cast<int32_t>(dim),
95             // Keep track of the channel in Image3D
96             static_cast<int32_t>(
97                 std::ceil(static_cast<float>(output_size[channel_index]) / 4)),
98       }
99     };
100 
101     api::UniformParamsBuffer params(context, block);
102 
103     context->submit_compute_job(
104         // shader descriptor
105         VK_KERNEL(unsqueeze),
106         // pipeline barrier
107         pipeline_barrier,
108         // global work group size
109         global_size,
110         // local work group size
111         local_size,
112         // fence handle
113         VK_NULL_HANDLE,
114         // shader arguments
115         v_output.image(
116             pipeline_barrier,
117             api::PipelineStage::COMPUTE,
118             api::MemoryAccessType::WRITE),
119         v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
120         // params buffer
121         params.buffer());
122     return convert(v_output);
123   }
124 }
125 
#if defined(USE_VULKAN_API)

// Bind aten::unsqueeze to the Vulkan implementation above. This registration
// is compiled only when USE_VULKAN_API is defined.
TORCH_LIBRARY_IMPL(aten, Vulkan, lib) {
  lib.impl(TORCH_SELECTIVE_NAME("aten::unsqueeze"), TORCH_FN(unsqueeze));
}

#endif // USE_VULKAN_API
133 
134 } // namespace
135 } // namespace ops
136 } // namespace vulkan
137 } // namespace native
138 } // namespace at
139