1 #include <ATen/native/vulkan/ops/Factory.h>
2 #include <torch/library.h>
3
4 namespace at {
5 namespace native {
6 namespace vulkan {
7 namespace ops {
8
_empty_affine_quantized(const IntArrayRef sizes,const std::optional<ScalarType> dtype,const std::optional<c10::Layout> layout,const std::optional<Device> device,const std::optional<bool> pin_memory,const double scale,const int64_t zero_point,const std::optional<MemoryFormat> memory_format)9 Tensor _empty_affine_quantized(
10 const IntArrayRef sizes,
11 const std::optional<ScalarType> dtype,
12 const std::optional<c10::Layout> layout,
13 const std::optional<Device> device,
14 const std::optional<bool> pin_memory,
15 const double scale,
16 const int64_t zero_point,
17 const std::optional<MemoryFormat> memory_format) {
18 api::StorageType storage_type = api::StorageType::TEXTURE_3D;
19 return convert_quantized(vTensor{
20 api::context(),
21 sizes.vec(),
22 scale,
23 zero_point,
24 convert_dtype(dtype ? *dtype : c10::kFloat),
25 storage_type,
26 memory_format ? get_gpu_memory_layout(storage_type, *memory_format)
27 : api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED,
28 });
29 }
30
empty_memory_format(const IntArrayRef sizes,const std::optional<ScalarType> dtype,const std::optional<c10::Layout> layout,const std::optional<Device> device,const std::optional<bool> pin_memory,const std::optional<MemoryFormat> memory_format)31 static Tensor empty_memory_format(
32 const IntArrayRef sizes,
33 const std::optional<ScalarType> dtype,
34 const std::optional<c10::Layout> layout,
35 const std::optional<Device> device,
36 const std::optional<bool> pin_memory,
37 const std::optional<MemoryFormat> memory_format) {
38 api::StorageType storage_type = api::StorageType::TEXTURE_3D;
39 return convert(vTensor{
40 api::context(),
41 sizes.vec(),
42 convert_dtype(dtype ? *dtype : c10::kFloat),
43 storage_type,
44 memory_format ? get_gpu_memory_layout(storage_type, *memory_format)
45 : api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED,
46 });
47 }
48
empty_strided(const IntArrayRef sizes,const IntArrayRef,const std::optional<ScalarType> dtype,const std::optional<c10::Layout> layout,const std::optional<Device> device,const std::optional<bool> pin_memory)49 static Tensor empty_strided(
50 const IntArrayRef sizes,
51 const IntArrayRef /* strides */,
52 const std::optional<ScalarType> dtype,
53 const std::optional<c10::Layout> layout,
54 const std::optional<Device> device,
55 const std::optional<bool> pin_memory) {
56 return empty_memory_format(
57 sizes, dtype, layout, device, pin_memory, c10::MemoryFormat::Contiguous);
58 }
59
60 #ifdef USE_VULKAN_API
61
// Registers the Vulkan-backend kernels for the aten factory operators
// `empty.memory_format`, `_empty_affine_quantized` and `empty_strided`.
//
// Every kernel is wrapped in TORCH_FN. This hands the dispatcher the function
// as a compile-time constant instead of a runtime function pointer, and keeps
// all three registrations in the same form.
// TORCH_FN takes decltype of the name, so each kernel must remain a single,
// non-overloaded function.
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::empty.memory_format"),
      TORCH_FN(at::native::vulkan::ops::empty_memory_format));
  m.impl(
      TORCH_SELECTIVE_NAME("aten::_empty_affine_quantized"),
      TORCH_FN(at::native::vulkan::ops::_empty_affine_quantized));
  m.impl(
      TORCH_SELECTIVE_NAME("aten::empty_strided"),
      TORCH_FN(at::native::vulkan::ops::empty_strided));
}
73
74 #endif /* USE_VULKAN_API */
75
76 } // namespace ops
77 } // namespace vulkan
78 } // namespace native
79 } // namespace at
80