#include <ATen/native/vulkan/ops/Common.h>
#include <c10/util/irange.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

namespace {
inline int64_t normalize_dim(int64_t d, int64_t n) {
  return (d % n + n) % n;
}
} // namespace

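// Concatenation along the batch dimension. Each input's packed texture is
// copied into the output in full, and the destination offset advances along
// the texture depth axis by the number of texel layers the input occupies.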
Tensor cat_batch(const MaterializedITensorListRef& tensors, vTensor& v_output) {
  api::Context* const context = api::context();

  uvec3 src_offset{};
  uvec3 dst_offset{};

  for (const at::Tensor& tensor : tensors) {
    const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan();
    const vTensor& v_self = convert(self);

    api::PipelineBarrier pipeline_barrier{};

    context->submit_copy<api::VulkanImage, api::VulkanImage>(
        // pipeline barrier
        pipeline_barrier,
        // images
        v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER),
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::TRANSFER,
            api::MemoryAccessType::WRITE),
        // copy details
        v_self.extents(),
        src_offset,
        dst_offset,
        // fence handle
        VK_NULL_HANDLE);

    // Increment by the number of texels in the depth dimension
    dst_offset.data[2u] += v_self.extents().data[2u];
  }

  return convert(v_output);
}

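// Concatenation along the channel (feature) dimension in the general case,
// where input channel counts need not be multiples of 4. A compute shader
// repacks the data, since an input's channels may start partway through an
// output texel.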
Tensor cat_feature(
    const MaterializedITensorListRef& tensors,
    vTensor& v_output) {
  api::Context* const context = api::context();

  // Determine the channels of the output tensor
  uint32_t ch_total = 0;
  for (const at::Tensor& tensor : tensors) {
    ch_total += get_dim<Dim4D::Channel>(tensor);
  }

  // Running counter of the number of channels already appended.
  uint32_t ch_current = 0;
  for (const at::Tensor& tensor : tensors) {
    const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan();
    const vTensor& v_self = convert(self);

    // Determine the number of channel texels that will be modified by
    // appending this input tensor
    uint32_t start_ch4 = ch_current / 4;

    uint32_t end_ch4 =
        api::utils::div_up(ch_current + get_dim<Dim4D::Channel>(v_self), 4u);

    uint32_t ch4_range = end_ch4 - start_ch4;
    uint32_t nc4_range = ch4_range * get_dim<Dim4D::Batch>(v_self);

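    // Parameters passed to the cat_feature shader: the output and input
    // texture extents (with int32_t padding fields for alignment), the
    // output's total channel count and channel-texel count, this input's
    // channel count (raw and aligned up to a multiple of 4), and the channel
    // offset plus channel-texel range appended by this input.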
    const struct Block final {
      ivec3 outExtents;
      int32_t fill0;
      ivec3 inExtents;
      int32_t fill1;
      uvec2 outChInfo;
      uvec2 inChInfo;
      uvec4 appendedChInfo;
    } block{
        api::utils::make_ivec3(v_output.extents()),
        0,
        api::utils::make_ivec3(v_self.extents()),
        0,
        {
            ch_total,
            api::utils::div_up(ch_total, 4u),
        },
        {
            get_dim<Dim4D::Channel>(v_self),
            api::utils::align_up(get_dim<Dim4D::Channel>(v_self), 4u),
        },
        {
            ch_current,
            start_ch4,
            ch4_range,
            0u,
        },
    };

    api::UniformParamsBuffer params(context, block);
    api::PipelineBarrier pipeline_barrier{};

    context->submit_compute_job(
        // shader descriptor
        VK_KERNEL(cat_feature),
        // pipeline barrier
        pipeline_barrier,
        // global work group size
        {
            get_dim<Dim4D::Width>(v_output),
            get_dim<Dim4D::Height>(v_output),
            nc4_range,
        },
        // local work group size
        adaptive_work_group_size(v_self.extents()),
        // fence handle
        VK_NULL_HANDLE,
        // shader arguments
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::COMPUTE,
            api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
        v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        // params buffer
        params.buffer());

    ch_current += get_dim<Dim4D::Channel>(v_self);
  }

  return convert(v_output);
}

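// Fast path for channel concatenation when every input's channel count is a
// multiple of 4: channel boundaries coincide with texel boundaries, so each
// input can be placed with plain image-to-image copies, one per batch.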
Tensor cat_feature_mult4ch(
    const MaterializedITensorListRef& tensors,
    vTensor& v_output) {
  api::Context* const context = api::context();

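  // depth_size_allprior tracks how many channel texel layers the previously
  // processed inputs occupy within each batch; ch_interval is the total
  // channel count of all inputs, which determines the per-batch depth stride
  // (depth_interval) of the output texture.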
  int64_t depth_size_allprior = 0;
  int64_t ch_interval = 0;
  for (const at::Tensor& tensor : tensors) {
    ch_interval += get_dim<Dim4D::Channel>(tensor);
  }
  const int64_t depth_interval = ch_interval / 4;

  uvec3 src_offset{};
  uvec3 dst_offset{};

  for (const at::Tensor& tensor_arg : tensors) {
    const Tensor tensor =
        tensor_arg.is_vulkan() ? tensor_arg : tensor_arg.vulkan();
    const vTensor& v_self = convert(tensor);

    const uint32_t depth_slice =
        safe_downcast<uint32_t>(get_dim<Dim4D::Channel>(tensor) / 4);

    uvec3 copy_extents{
        v_self.extents().data[0u], v_self.extents().data[1u], depth_slice};

    for (const auto b : c10::irange(get_dim<Dim4D::Batch>(tensor))) {
      src_offset.data[2u] = safe_downcast<uint32_t>(depth_slice * b);
      dst_offset.data[2u] =
          depth_size_allprior + safe_downcast<uint32_t>(depth_interval * b);

      api::PipelineBarrier pipeline_barrier{};

      context->submit_copy<api::VulkanImage, api::VulkanImage>(
          // pipeline barrier
          pipeline_barrier,
          // images
          v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER),
          v_output.image(
              pipeline_barrier,
              api::PipelineStage::TRANSFER,
              api::MemoryAccessType::WRITE),
          // copy details
          copy_extents,
          src_offset,
          dst_offset,
          // fence handle
          VK_NULL_HANDLE);
    }

    depth_size_allprior += depth_slice;
  }

  return convert(v_output);
}

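// Concatenation along the width dimension: copy each input into the output
// at an increasing x offset.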
Tensor cat_width(const MaterializedITensorListRef& tensors, vTensor& v_output) {
  api::Context* const context = api::context();

  uvec3 src_offset{};
  uvec3 dst_offset{};

  for (const at::Tensor& tensor : tensors) {
    const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan();
    const vTensor& v_self = convert(self);

    api::PipelineBarrier pipeline_barrier{};

    context->submit_copy<api::VulkanImage, api::VulkanImage>(
        // pipeline barrier
        pipeline_barrier,
        // images
        v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER),
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::TRANSFER,
            api::MemoryAccessType::WRITE),
        // copy details
        v_self.extents(),
        src_offset,
        dst_offset,
        // fence handle
        VK_NULL_HANDLE);

    // Increment by width
    dst_offset.data[0u] += v_self.extents().data[0u];
  }

  return convert(v_output);
}

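// Concatenation along the height dimension: copy each input into the output
// at an increasing y offset.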
Tensor cat_height(
    const MaterializedITensorListRef& tensors,
    vTensor& v_output) {
  api::Context* const context = api::context();

  uvec3 src_offset{};
  uvec3 dst_offset{};

  for (const at::Tensor& tensor : tensors) {
    const Tensor self = tensor.is_vulkan() ? tensor : tensor.vulkan();
    const vTensor& v_self = convert(self);

    api::PipelineBarrier pipeline_barrier{};

    context->submit_copy<api::VulkanImage, api::VulkanImage>(
        // pipeline barrier
        pipeline_barrier,
        // images
        v_self.image(pipeline_barrier, api::PipelineStage::TRANSFER),
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::TRANSFER,
            api::MemoryAccessType::WRITE),
        // copy details
        v_self.extents(),
        src_offset,
        dst_offset,
        // fence handle
        VK_NULL_HANDLE);

    // Increment by height
    dst_offset.data[1u] += v_self.extents().data[1u];
  }

  return convert(v_output);
}

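// Implementation of aten::cat for Vulkan tensors. Validates the inputs,
// allocates the output tensor, and dispatches to the specialized routine for
// the concatenated dimension.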
Tensor cat(const at::ITensorListRef& tensors, const int64_t in_dim) {
  TORCH_CHECK(!tensors.empty(), "Vulkan cat expects at least one tensor");
  auto materialized = tensors.materialize();
  TORCH_INTERNAL_ASSERT(!materialized.empty(), "Accessing empty array");
  const at::Tensor& tensor = materialized[0];
  auto ndim = safe_downcast<uint32_t>(tensor.dim());
  const int64_t dim = normalize_dim(in_dim, ndim);
  int64_t cat_dim_size = 0;
  bool is_mult4ch = true;

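  // Validate the inputs: at most 4 dimensions each, with matching sizes in
  // every dimension except the one being concatenated. Also check whether the
  // inputs have at least 3 dimensions and channel counts that are all
  // multiples of 4, which enables the copy-based channel-concatenation path.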
  for (const at::Tensor& t : materialized) {
    TORCH_INTERNAL_ASSERT(
        t.dim() <= 4,
        "Vulkan cat expects inputs to have at most 4 dimensions, but got ",
        t.dim(),
        "d");

    if (ndim < 3 || get_dim<Dim4D::Channel>(t) % 4 != 0) {
      is_mult4ch = false;
    }

    for (const auto d : c10::irange(ndim)) {
      if (d == dim) {
        continue;
      }
      TORCH_INTERNAL_ASSERT(
          t.size(d) == tensor.size(d),
          "Vulkan cat inputs must have matching sizes except in the concatenated dimension");
    }
    cat_dim_size += t.size(dim);
  }

  auto result_size = tensor.sizes().vec();
  TORCH_INTERNAL_ASSERT(!result_size.empty(), "Accessing empty array");
  result_size[dim] = cat_dim_size;

  vTensor v_output{
      api::context(), result_size, convert_dtype(tensor.scalar_type())};

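  // Dispatch on the concatenated dimension, counted from the innermost:
  // width, height, channels (with a fast path when all channel counts are
  // multiples of 4), or batch.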
  if (dim == ndim - 1) {
    return cat_width(materialized, v_output);
  }
  if (dim == ndim - 2) {
    return cat_height(materialized, v_output);
  }
  if (dim == ndim - 3) {
    if (is_mult4ch) {
      return cat_feature_mult4ch(materialized, v_output);
    }
    return cat_feature(materialized, v_output);
  }
  return cat_batch(materialized, v_output);
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::cat"), TORCH_FN(cat));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at