#ifdef USE_VULKAN_API
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/QuantizedFunctions.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

using namespace api::utils;

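// Maps a quantized scalar type to the Vulkan compute shader that implements
// quantize_per_tensor for it; only QUInt8, QInt8, and QInt32 are supported.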
static api::ShaderInfo get_quantize_per_tensor_shader(
    const c10::ScalarType dtype) {
  switch (dtype) {
    case c10::ScalarType::QUInt8:
      return VK_KERNEL(quantize_per_tensor_quint8);
    case c10::ScalarType::QInt8:
      return VK_KERNEL(quantize_per_tensor_qint8);
    case c10::ScalarType::QInt32:
      return VK_KERNEL(quantize_per_tensor_qint32);
    default:
      TORCH_CHECK(
          false,
          "Vulkan quantization currently not supported for dtype ",
          dtype);
  }
}

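// Quantizes a float Vulkan tensor into the requested quantized dtype using
// the affine mapping q = round(x / scale) + zero_point, clamped to the
// dtype's range, evaluated on the GPU by the dtype-specific shader.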
Tensor quantize_per_tensor(
    const at::Tensor& input_arg,
    const double scale,
    const int64_t zero_point,
    const c10::ScalarType dtype) {
  api::ShaderInfo compute_shader = get_quantize_per_tensor_shader(dtype);

  api::Context* const context = api::context();

  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  const vTensor& v_input = convert(input);

  vTensor v_output{
      context,
      v_input.sizes(),
      scale,
      zero_point,
      convert_dtype(dtype),
  };

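  // The unnamed fields are explicit padding so the host-side struct matches
  // the layout of the shader's uniform block.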
  const struct Block final {
    uvec3 extents;
    uint32_t _;
    float scale;
    float _1;
    int32_t zero_point;
    int32_t _2;
  } block{
      v_output.extents(),
      0u,
      safe_downcast<float>(scale),
      0.0f,
      safe_downcast<int32_t>(zero_point),
      0u,
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_input.extents(),
      // local work group size
      adaptive_work_group_size(v_input.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert_quantized(v_output);
}

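// Overload backing the aten::quantize_per_tensor.tensor_qparams schema, which
// passes scale and zero_point as one-element tensors instead of scalars.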
Tensor quantize_per_tensor_tensor_qparams(
    const at::Tensor& input_arg,
    const at::Tensor& scale,
    const at::Tensor& zero_point,
    const c10::ScalarType dtype) {
  TORCH_CHECK(
      (scale.numel() == 1 && zero_point.numel() == 1),
      "Only 1 element expected in scale and zero_point");
  return quantize_per_tensor(
      input_arg, scale.item().toDouble(), zero_point.item().toLong(), dtype);
}

// Helper for the dequantize functions below; applies the given scale and
// zero_point to recover a float tensor.
Tensor dequantize_helper(
    const at::Tensor& input_arg,
    const double scale,
    const int64_t zero_point,
    const c10::ScalarType dtype) {
  TORCH_CHECK(dtype == kFloat, "Expected type Float");

  api::Context* const context = api::context();

  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  const vTensor& v_input = convert(input);

  vTensor v_output{
      context,
      v_input.sizes(),
      api::kFloat,
  };

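  // Same uniform-block padding scheme as in quantize_per_tensor above.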
  const struct Block final {
    uvec3 extents;
    uint32_t _;
    float scale;
    float _1;
    int32_t zero_point;
    int32_t _2;
  } block{
      v_output.extents(),
      0u,
      safe_downcast<float>(scale),
      0.0f,
      safe_downcast<int32_t>(zero_point),
      0u,
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};
  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(dequantize),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_input.extents(),
      // local work group size
      adaptive_work_group_size(v_input.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

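// Accessors for the quantization parameters stored on a Vulkan quantized
// tensor, backing aten::q_scale and aten::q_zero_point.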
static double q_scale(const Tensor& self) {
  TORCH_CHECK(self.is_vulkan(), "Expecting a vulkan tensor for q_scale");
  const vTensor& v_input = convert(self);
  return v_input.get_scale();
}

static int64_t q_zero_point(const Tensor& self) {
  TORCH_CHECK(self.is_vulkan(), "Expecting a vulkan tensor for q_zero_point");
  const vTensor& v_input = convert(self);
  return v_input.get_zero_point();
}

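// Dequantizes a Vulkan quantized tensor back to float with its stored
// parameters, i.e. x = (q - zero_point) * scale.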
Tensor dequantize(const Tensor& self) {
  double q_scale = convert(self).get_scale();
  int64_t zero_point = convert(self).get_zero_point();
  return dequantize_helper(self, q_scale, zero_point, kFloat);
}

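// Registers the Vulkan kernels with the dispatcher. A minimal usage sketch,
// assuming a Vulkan-enabled build of PyTorch (the names below are standard
// ATen API, not defined in this file):
//
//   at::Tensor x = at::rand({1, 3, 8, 8});
//   at::Tensor q = at::quantize_per_tensor(
//       x.vulkan(), /*scale=*/0.1, /*zero_point=*/10, c10::kQUInt8);
//   at::Tensor y = at::dequantize(q).cpu();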
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::quantize_per_tensor"), quantize_per_tensor);
  m.impl(
      TORCH_SELECTIVE_NAME("aten::quantize_per_tensor.tensor_qparams"),
      quantize_per_tensor_tensor_qparams);
  m.impl(TORCH_SELECTIVE_NAME("aten::q_scale"), q_scale);
  m.impl(TORCH_SELECTIVE_NAME("aten::q_zero_point"), q_zero_point);
  m.impl(TORCH_SELECTIVE_NAME("aten::dequantize.self"), dequantize);
}

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at
#endif /* USE_VULKAN_API */