#include <ATen/native/vulkan/ops/Common.h>
#include <c10/util/irange.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

namespace {
// Wrap a possibly negative dimension index into the valid range [0, n).
inline int64_t normalize_dim(int64_t d, int64_t n) {
  return (d % n + n) % n;
}
} // namespace

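// Concatenate along the batch dimension by copying each input image into the
// output, advancing the destination offset along the texture depth axis
// (where batches are packed) by each input's depth extent.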
Tensor cat_batch(const MaterializedITensorListRef& tensors, vTensor& v_output) {
  api::Context* const context = api::context();

  uvec3 src_offset{};
  uvec3 dst_offset{};

  for (const at::Tensor& tensor : tensors) {
    const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan();
    const vTensor& v_self = convert(self);

    api::PipelineBarrier pipeline_barrier{};

    context->submit_copy<api::VulkanImage, api::VulkanImage>(
        // pipeline barrier
        pipeline_barrier,
        // images
        v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER),
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::TRANSFER,
            api::MemoryAccessType::WRITE),
        // copy details
        v_self.extents(),
        src_offset,
        dst_offset,
        // fence handle
        VK_NULL_HANDLE);

    // Increment by the number of texels in the depth dimension
    dst_offset.data[2u] += v_self.extents().data[2u];
  }

  return convert(v_output);
}

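// General path for concatenation along the channel dimension: a compute
// shader (cat_feature) repacks texels, since each texel holds four
// consecutive channels and inputs may start at arbitrary channel offsets
// within the output.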
Tensor cat_feature(
    const MaterializedITensorListRef& tensors,
    vTensor& v_output) {
  api::Context* const context = api::context();

  // Determine the channels of the output tensor
  uint32_t ch_total = 0;
  for (const at::Tensor& tensor : tensors) {
    ch_total += get_dim<Dim4D::Channel>(tensor);
  }

  // Running counter of the number of channels already appended.
  uint32_t ch_current = 0;
  for (const at::Tensor& tensor : tensors) {
    const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan();
    const vTensor& v_self = convert(self);

    // Determine the range of channel texels that will be modified by
    // appending this input tensor
    uint32_t start_ch4 = ch_current / 4;

    uint32_t end_ch4 =
        api::utils::div_up(ch_current + get_dim<Dim4D::Channel>(v_self), 4u);

    uint32_t ch4_range = end_ch4 - start_ch4;
    uint32_t nc4_range = ch4_range * get_dim<Dim4D::Batch>(v_self);

    // Shader parameters, passed to the cat_feature kernel as a uniform buffer.
    const struct Block final {
      ivec3 outExtents;
      int32_t fill0;
      ivec3 inExtents;
      int32_t fill1;
      uvec2 outChInfo;
      uvec2 inChInfo;
      uvec4 appendedChInfo;
    } block{
        api::utils::make_ivec3(v_output.extents()),
        0,
        api::utils::make_ivec3(v_self.extents()),
        0,
        {
            ch_total,
            api::utils::div_up(ch_total, 4u),
        },
        {
            get_dim<Dim4D::Channel>(v_self),
            api::utils::align_up(get_dim<Dim4D::Channel>(v_self), 4u),
        },
        {
            ch_current,
            start_ch4,
            ch4_range,
            0u,
        },
    };

    api::UniformParamsBuffer params(context, block);
    api::PipelineBarrier pipeline_barrier{};

    context->submit_compute_job(
        // shader descriptor
        VK_KERNEL(cat_feature),
        // pipeline barrier
        pipeline_barrier,
        // global work group size
        {
            get_dim<Dim4D::Width>(v_output),
            get_dim<Dim4D::Height>(v_output),
            nc4_range,
        },
        // local work group size
        adaptive_work_group_size(v_self.extents()),
        // fence handle
        VK_NULL_HANDLE,
        // shader arguments
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::COMPUTE,
            api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
        v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        // params buffer
        params.buffer());

    ch_current += get_dim<Dim4D::Channel>(v_self);
  }

  return convert(v_output);
}

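// Fast path for channel concatenation when every input's channel count is a
// multiple of 4: whole texels can be copied directly, one batch slice at a
// time, with no repacking shader required.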
Tensor cat_feature_mult4ch(
    const MaterializedITensorListRef& tensors,
    vTensor& v_output) {
  api::Context* const context = api::context();

  // Depth (in texels) occupied by the inputs already appended, and the total
  // channel count of the output, which determines the per-batch depth stride.
  int64_t depth_size_allprior = 0;
  int64_t ch_interval = 0;
  for (const at::Tensor& tensor : tensors) {
    ch_interval += get_dim<Dim4D::Channel>(tensor);
  }
  const int64_t depth_interval = ch_interval / 4;

  uvec3 src_offset{};
  uvec3 dst_offset{};

  for (const at::Tensor& tensor_arg : tensors) {
    const Tensor tensor =
        tensor_arg.is_vulkan() ? tensor_arg : tensor_arg.vulkan();
    const vTensor& v_self = convert(tensor);

    const uint32_t depth_slice =
        safe_downcast<uint32_t>(get_dim<Dim4D::Channel>(tensor) / 4);

    uvec3 copy_extents{
        v_self.extents().data[0u], v_self.extents().data[1u], depth_slice};

    // Copy each batch of the input into its slot in the corresponding batch
    // of the output.
    for (const auto b : c10::irange(get_dim<Dim4D::Batch>(tensor))) {
      src_offset.data[2u] = safe_downcast<uint32_t>(depth_slice * b);
      dst_offset.data[2u] =
          depth_size_allprior + safe_downcast<uint32_t>(depth_interval * b);

      api::PipelineBarrier pipeline_barrier{};

      context->submit_copy<api::VulkanImage, api::VulkanImage>(
          // pipeline barrier
          pipeline_barrier,
          // images
          v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER),
          v_output.image(
              pipeline_barrier,
              api::PipelineStage::TRANSFER,
              api::MemoryAccessType::WRITE),
          // copy details
          copy_extents,
          src_offset,
          dst_offset,
          // fence handle
          VK_NULL_HANDLE);
    }

    depth_size_allprior += depth_slice;
  }

  return convert(v_output);
}

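// Concatenate along the width dimension by copying each input into the output
// at an increasing x offset.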
Tensor cat_width(const MaterializedITensorListRef& tensors, vTensor& v_output) {
  api::Context* const context = api::context();

  uvec3 src_offset{};
  uvec3 dst_offset{};

  for (const at::Tensor& tensor : tensors) {
    const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan();
    const vTensor& v_self = convert(self);

    api::PipelineBarrier pipeline_barrier{};

    context->submit_copy<api::VulkanImage, api::VulkanImage>(
        // pipeline barrier
        pipeline_barrier,
        // images
        v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER),
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::TRANSFER,
            api::MemoryAccessType::WRITE),
        // copy details
        v_self.extents(),
        src_offset,
        dst_offset,
        // fence handle
        VK_NULL_HANDLE);

    // Increment by width
    dst_offset.data[0u] += v_self.extents().data[0u];
  }

  return convert(v_output);
}

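// Concatenate along the height dimension by copying each input into the
// output at an increasing y offset.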
Tensor cat_height(
    const MaterializedITensorListRef& tensors,
    vTensor& v_output) {
  api::Context* const context = api::context();

  uvec3 src_offset{};
  uvec3 dst_offset{};

  for (const at::Tensor& tensor : tensors) {
    const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan();
    const vTensor& v_self = convert(self);

    api::PipelineBarrier pipeline_barrier{};

    context->submit_copy<api::VulkanImage, api::VulkanImage>(
        // pipeline barrier
        pipeline_barrier,
        // images
        v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER),
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::TRANSFER,
            api::MemoryAccessType::WRITE),
        // copy details
        v_self.extents(),
        src_offset,
        dst_offset,
        // fence handle
        VK_NULL_HANDLE);

    // Increment by height
    dst_offset.data[1u] += v_self.extents().data[1u];
  }

  return convert(v_output);
}

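// Entry point for aten::cat on the Vulkan backend. Validates the inputs,
// allocates the output tensor, and dispatches to the specialized routine for
// the (normalized) concatenation dimension.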
Tensor cat(const at::ITensorListRef& tensors, const int64_t in_dim) {
  TORCH_CHECK(!tensors.empty(), "Vulkan cat expects at least one tensor");
  auto materialized = tensors.materialize();
  TORCH_INTERNAL_ASSERT(!materialized.empty(), "Accessing empty array");
  const at::Tensor& tensor = materialized[0];
  auto ndim = safe_downcast<uint32_t>(tensor.dim());
  const int64_t dim = normalize_dim(in_dim, ndim);
  int64_t cat_dim_size = 0;
  bool is_mult4ch = true;

  for (const at::Tensor& t : materialized) {
    TORCH_INTERNAL_ASSERT(
        t.dim() <= 4,
        "Vulkan cat expects inputs to have at most 4 dimensions, but got ",
        t.dim(),
        "d");

    if (ndim < 3 || get_dim<Dim4D::Channel>(t) % 4 != 0) {
      is_mult4ch = false;
    }

    for (const auto d : c10::irange(ndim)) {
      if (d == dim) {
        continue;
      }
      TORCH_INTERNAL_ASSERT(
          t.size(d) == tensor.size(d),
          "Vulkan cat inputs must have matching sizes except concatenated dimension");
    }
    cat_dim_size += t.size(dim);
  }

  auto result_size = tensor.sizes().vec();
  TORCH_INTERNAL_ASSERT(!result_size.empty(), "Accessing empty array");
  result_size[dim] = cat_dim_size;

  vTensor v_output{
      api::context(), result_size, convert_dtype(tensor.scalar_type())};

  if (dim == ndim - 1) {
    return cat_width(materialized, v_output);
  }
  if (dim == ndim - 2) {
    return cat_height(materialized, v_output);
  }
  if (dim == ndim - 3) {
    if (is_mult4ch) {
      return cat_feature_mult4ch(materialized, v_output);
    }
    return cat_feature(materialized, v_output);
  }
  return cat_batch(materialized, v_output);
}

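// Register this implementation for aten::cat under the Vulkan dispatch key.
// Usage sketch (C++; assumes a Vulkan-enabled build and a supported device):
//
//   at::Tensor a = at::rand({1, 4, 8, 8}).vulkan();
//   at::Tensor b = at::rand({1, 4, 8, 8}).vulkan();
//   at::Tensor c = at::cat({a, b}, /*dim=*/1).cpu();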
#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::cat"), TORCH_FN(cat));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at