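// Factory functions that create uninitialized Vulkan tensors for the ATen
// empty, _empty_affine_quantized, and empty_strided operators.
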
#include <ATen/native/vulkan/ops/Factory.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

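// Creates an uninitialized, affine-quantized Vulkan tensor with the given
// scale and zero point. Storage defaults to a 3D texture, the dtype to float,
// and the layout to channels-packed when no memory_format is provided.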
Tensor _empty_affine_quantized(
    const IntArrayRef sizes,
    const std::optional<ScalarType> dtype,
    const std::optional<c10::Layout> layout,
    const std::optional<Device> device,
    const std::optional<bool> pin_memory,
    const double scale,
    const int64_t zero_point,
    const std::optional<MemoryFormat> memory_format) {
  api::StorageType storage_type = api::StorageType::TEXTURE_3D;
  return convert_quantized(vTensor{
      api::context(),
      sizes.vec(),
      scale,
      zero_point,
      convert_dtype(dtype ? *dtype : c10::kFloat),
      storage_type,
      memory_format ? get_gpu_memory_layout(storage_type, *memory_format)
                    : api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED,
  });
}

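// Creates an uninitialized (non-quantized) Vulkan tensor. As above, storage
// defaults to a 3D texture and the layout to channels-packed unless a
// memory_format is supplied.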
static Tensor empty_memory_format(
    const IntArrayRef sizes,
    const std::optional<ScalarType> dtype,
    const std::optional<c10::Layout> layout,
    const std::optional<Device> device,
    const std::optional<bool> pin_memory,
    const std::optional<MemoryFormat> memory_format) {
  api::StorageType storage_type = api::StorageType::TEXTURE_3D;
  return convert(vTensor{
      api::context(),
      sizes.vec(),
      convert_dtype(dtype ? *dtype : c10::kFloat),
      storage_type,
      memory_format ? get_gpu_memory_layout(storage_type, *memory_format)
                    : api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED,
  });
}

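// The strides argument is not used: the tensor is texture-backed, so this
// simply forwards to empty_memory_format with a contiguous memory format.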
static Tensor empty_strided(
    const IntArrayRef sizes,
    const IntArrayRef /* strides */,
    const std::optional<ScalarType> dtype,
    const std::optional<c10::Layout> layout,
    const std::optional<Device> device,
    const std::optional<bool> pin_memory) {
  return empty_memory_format(
      sizes, dtype, layout, device, pin_memory, c10::MemoryFormat::Contiguous);
}

#ifdef USE_VULKAN_API

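// Register the Vulkan implementations of the factory operators with the ATen
// dispatcher.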
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::empty.memory_format"),
      at::native::vulkan::ops::empty_memory_format);
  m.impl(
      TORCH_SELECTIVE_NAME("aten::_empty_affine_quantized"),
      at::native::vulkan::ops::_empty_affine_quantized);
  m.impl(
      TORCH_SELECTIVE_NAME("aten::empty_strided"),
      TORCH_FN(at::native::vulkan::ops::empty_strided));
}

#endif /* USE_VULKAN_API */

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at