• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifdef USE_VULKAN_API
2 #include <ATen/native/vulkan/ops/Common.h>
3 #include <ATen/native/vulkan/ops/QuantizedFunctions.h>
4 #include <ATen/native/vulkan/ops/Utils.h>
5 #include <torch/library.h>
6 
7 namespace at {
8 namespace native {
9 namespace vulkan {
10 namespace ops {
11 
12 using namespace api::utils;
13 
get_quantize_per_tensor_shader(const c10::ScalarType dtype)14 static api::ShaderInfo get_quantize_per_tensor_shader(
15     const c10::ScalarType dtype) {
16   switch (dtype) {
17     case c10::ScalarType::QUInt8:
18       return VK_KERNEL(quantize_per_tensor_quint8);
19     case c10::ScalarType::QInt8:
20       return VK_KERNEL(quantize_per_tensor_qint8);
21     case c10::ScalarType::QInt32:
22       return VK_KERNEL(quantize_per_tensor_qint32);
23     default:
24       TORCH_CHECK(
25           false,
26           "Vulkan quantization currently not supported for dtype ",
27           dtype);
28   }
29 }
30 
quantize_per_tensor(const at::Tensor & input_arg,const double scale,const int64_t zero_point,const c10::ScalarType dtype)31 Tensor quantize_per_tensor(
32     const at::Tensor& input_arg,
33     const double scale,
34     const int64_t zero_point,
35     const c10::ScalarType dtype) {
36   api::ShaderInfo compute_shader = get_quantize_per_tensor_shader(dtype);
37 
38   api::Context* const context = api::context();
39 
40   const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
41   const vTensor& v_input = convert(input);
42 
43   vTensor v_output{
44       context,
45       v_input.sizes(),
46       scale,
47       zero_point,
48       convert_dtype(dtype),
49   };
50 
51   const struct Block final {
52     uvec3 extents;
53     uint32_t _;
54     float scale;
55     float _1;
56     int32_t zero_point;
57     int32_t _2;
58   } block{
59       v_output.extents(),
60       0u,
61       safe_downcast<float>(scale),
62       0.0f,
63       safe_downcast<int32_t>(zero_point),
64       0u,
65   };
66 
67   api::UniformParamsBuffer params(context, block);
68   api::PipelineBarrier pipeline_barrier{};
69 
70   context->submit_compute_job(
71       // shader descriptor
72       compute_shader,
73       // barrier
74       pipeline_barrier,
75       // global work group size
76       v_input.extents(),
77       // local work group size
78       adaptive_work_group_size(v_input.extents()),
79       // fence handle
80       VK_NULL_HANDLE,
81       // shader arguments
82       v_output.image(
83           pipeline_barrier,
84           api::PipelineStage::COMPUTE,
85           api::MemoryAccessType::WRITE),
86       v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
87       // params buffer
88       params.buffer());
89 
90   return convert_quantized(v_output);
91 }
92 
quantize_per_tensor_tensor_qparams(const at::Tensor & input_arg,const at::Tensor & scale,const at::Tensor & zero_point,const c10::ScalarType dtype)93 Tensor quantize_per_tensor_tensor_qparams(
94     const at::Tensor& input_arg,
95     const at::Tensor& scale,
96     const at::Tensor& zero_point,
97     const c10::ScalarType dtype) {
98   TORCH_CHECK(
99       (scale.numel() == 1 && zero_point.numel() == 1),
100       "Only 1 element expected in scale and zero_point");
101   return quantize_per_tensor(
102       input_arg, scale.item().toDouble(), zero_point.item().toLong(), dtype);
103 }
104 
105 // helper for dequantize function to use scale and zero_point
dequantize_helper(const at::Tensor & input_arg,const double scale,const int64_t zero_point,const c10::ScalarType dtype)106 Tensor dequantize_helper(
107     const at::Tensor& input_arg,
108     const double scale,
109     const int64_t zero_point,
110     const c10::ScalarType dtype) {
111   TORCH_CHECK(dtype == kFloat, "Expected type Float");
112 
113   api::Context* const context = api::context();
114 
115   const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
116   const vTensor& v_input = convert(input);
117 
118   vTensor v_output{
119       context,
120       v_input.sizes(),
121       api::kFloat,
122   };
123 
124   const struct Block final {
125     uvec3 extents;
126     uint32_t _;
127     float scale;
128     float _1;
129     int32_t zero_point;
130     int32_t _2;
131   } block{
132       v_output.extents(),
133       0u,
134       safe_downcast<float>(scale),
135       0.0f,
136       safe_downcast<int32_t>(zero_point),
137       0u,
138   };
139 
140   api::UniformParamsBuffer params(context, block);
141   api::PipelineBarrier pipeline_barrier{};
142   context->submit_compute_job(
143       // shader descriptor
144       VK_KERNEL(dequantize),
145       // pipeline barrier
146       pipeline_barrier,
147       // global work group size
148       v_input.extents(),
149       // local work group size
150       adaptive_work_group_size(v_input.extents()),
151       // fence handle
152       VK_NULL_HANDLE,
153       // shader arguments
154       v_output.image(
155           pipeline_barrier,
156           api::PipelineStage::COMPUTE,
157           api::MemoryAccessType::WRITE),
158       v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
159       // params buffer
160       params.buffer());
161 
162   return convert(v_output);
163 }
164 
q_scale(const Tensor & self)165 static double q_scale(const Tensor& self) {
166   TORCH_CHECK(self.is_vulkan(), "Expecting a vulkan tensor for q_scale");
167   const vTensor& v_input = convert(self);
168   return v_input.get_scale();
169 }
170 
q_zero_point(const Tensor & self)171 static int64_t q_zero_point(const Tensor& self) {
172   TORCH_CHECK(self.is_vulkan(), "Expecting a vulkan tensor for q_zero_point");
173   const vTensor& v_input = convert(self);
174   return v_input.get_zero_point();
175 }
176 
// Dequantizes `self` to float, using the scale and zero point recorded on the
// tensor itself.
Tensor dequantize(const Tensor& self) {
  const vTensor& v_self = convert(self);
  const double scale = v_self.get_scale();
  const int64_t zero_point = v_self.get_zero_point();
  return dequantize_helper(self, scale, zero_point, kFloat);
}
182 
// Registers the functions defined in this file as the Vulkan implementations
// of the corresponding aten operators.
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  // Quantization: scalar qparams, then tensor qparams.
  m.impl(
      TORCH_SELECTIVE_NAME("aten::quantize_per_tensor"),
      quantize_per_tensor);
  m.impl(
      TORCH_SELECTIVE_NAME("aten::quantize_per_tensor.tensor_qparams"),
      quantize_per_tensor_tensor_qparams);

  // Accessors for the quantization parameters stored on a tensor.
  m.impl(
      TORCH_SELECTIVE_NAME("aten::q_scale"),
      q_scale);
  m.impl(
      TORCH_SELECTIVE_NAME("aten::q_zero_point"),
      q_zero_point);

  // Dequantization back to float.
  m.impl(
      TORCH_SELECTIVE_NAME("aten::dequantize.self"),
      dequantize);
}
193 
194 } // namespace ops
195 } // namespace vulkan
196 } // namespace native
197 } // namespace at
198 #endif /* USE_VULKAN_API */
199