1 #ifdef USE_VULKAN_API
2
3 #include <ATen/native/quantized/PackedParams.h>
4 #include <ATen/native/vulkan/ops/Batchnorm.h>
5 #include <ATen/native/vulkan/ops/Common.h>
6 #include <ATen/native/vulkan/ops/Convolution.h>
7 #include <ATen/native/vulkan/ops/Gru.h>
8 #include <ATen/native/vulkan/ops/Layernorm.h>
9 #include <ATen/native/vulkan/ops/Lstm.h>
10 #include <ATen/native/vulkan/ops/Mm.h>
11 #include <ATen/native/vulkan/ops/QuantizedFunctions.h>
12 #include <ATen/native/vulkan/ops/Register.h>
13 #include <torch/custom_class.h>
14 #include <torch/library.h>
15
16 namespace at {
17 namespace native {
18 namespace vulkan {
19 namespace ops {
20
register_vulkan_conv2d_packed_context()21 int register_vulkan_conv2d_packed_context() {
22 static auto register_vulkan_conv2d_context =
23 torch::selective_class_<Conv2dPackedContext>(
24 "vulkan", TORCH_SELECTIVE_CLASS("Conv2dPackedContext"))
25 .def_pickle(
26 // __getstate__
27 [](const c10::intrusive_ptr<Conv2dPackedContext>& context) {
28 // context is packed
29 return context->unpack();
30 },
31 // __setstate__
32 [](c10::impl::GenericList state) {
33 // state is unpacked
34 return c10::make_intrusive<Conv2dPackedContext>(
35 Conv2dPackedContext::pack(state));
36 });
37 return 0;
38 }
39
register_vulkan_conv1d_packed_context()40 int register_vulkan_conv1d_packed_context() {
41 static auto register_vulkan_conv1d_context =
42 torch::selective_class_<Conv1dPackedContext>(
43 "vulkan", TORCH_SELECTIVE_CLASS("Conv1dPackedContext"))
44 .def_pickle(
45 // __getstate__
46 [](const c10::intrusive_ptr<Conv1dPackedContext>& context) {
47 // context is packed
48 return context->unpack();
49 },
50 // __setstate__
51 [](c10::impl::GenericList state) {
52 // state is unpacked
53 return c10::make_intrusive<Conv1dPackedContext>(
54 Conv1dPackedContext::pack(state));
55 });
56 return 0;
57 }
58
register_vulkan_linear_packed_context()59 int register_vulkan_linear_packed_context() {
60 static auto register_vulkan_linear_context =
61 torch::selective_class_<LinearPackedContext>(
62 "vulkan", TORCH_SELECTIVE_CLASS("LinearPackedContext"))
63 .def_pickle(
64 // __getstate__
65 [](const c10::intrusive_ptr<LinearPackedContext>& context) {
66 // context is packed
67 return context->unpack();
68 },
69 // __setstate__
70 [](c10::impl::GenericList state) {
71 // state is unpacked
72 return c10::make_intrusive<LinearPackedContext>(
73 LinearPackedContext::pack(state));
74 });
75 return 0;
76 }
77
register_vulkan_layernorm_packed_context()78 int register_vulkan_layernorm_packed_context() {
79 static auto register_vulkan_layernorm_context =
80 torch::selective_class_<LayernormPackedContext>(
81 "vulkan", TORCH_SELECTIVE_CLASS("LayernormPackedContext"))
82 .def_pickle(
83 // __getstate__
84 [](const c10::intrusive_ptr<LayernormPackedContext>& context) {
85 // context is packed
86 return context->unpack();
87 },
88 // __setstate__
89 [](c10::impl::GenericList state) {
90 // state is unpacked
91 return c10::make_intrusive<LayernormPackedContext>(
92 LayernormPackedContext::pack(state));
93 });
94 return 0;
95 }
96
97 namespace {
98
// Custom-class registrations for the `vulkan` library.
//
// BatchNorm, Gru and Lstm packed contexts are registered directly on `m`.
// The conv2d, conv1d, linear and layernorm packed contexts are registered
// through the register_vulkan_*_packed_context() helpers. The legacy
// Conv2dOpContext class is registered last.
TORCH_LIBRARY(vulkan, m) {
  // Each *PackedContext below pickles to the result of unpack() and is
  // rebuilt by feeding that list back into the class's static pack().
  m.class_<BatchNormPackedContext>("BatchNormPackedContext")
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<BatchNormPackedContext>& packed) {
            return packed->unpack();
          },
          // __setstate__
          [](c10::impl::GenericList unpacked) {
            return c10::make_intrusive<BatchNormPackedContext>(
                BatchNormPackedContext::pack(unpacked));
          });

  m.class_<GruPackedContext>("GruPackedContext")
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<GruPackedContext>& packed) {
            return packed->unpack();
          },
          // __setstate__
          [](c10::impl::GenericList unpacked) {
            return c10::make_intrusive<GruPackedContext>(
                GruPackedContext::pack(unpacked));
          });

  m.class_<LstmPackedContext>("LstmPackedContext")
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<LstmPackedContext>& packed) {
            return packed->unpack();
          },
          // __setstate__
          [](c10::impl::GenericList unpacked) {
            return c10::make_intrusive<LstmPackedContext>(
                LstmPackedContext::pack(unpacked));
          });

  // These four classes are registered by helper functions; the int return
  // values are intentionally ignored.
  register_vulkan_conv2d_packed_context();
  register_vulkan_conv1d_packed_context();
  register_vulkan_linear_packed_context();
  register_vulkan_layernorm_packed_context();

  // Kept for backwards compatibility: the older Conv2dOpContext class.
  m.class_<Conv2dOpContext>("Conv2dOpContext")
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<Conv2dOpContext>& op_context) {
            return op_context->unpack();
          },
          // __setstate__: forwards the eight state elements, in order, to
          // conv2d_clamp_prepack. The first five are moved out of the
          // by-value state; the last three are passed as lvalues.
          [](Conv2dOpContext::State saved) {
            return conv2d_clamp_prepack(
                std::move(std::get<0>(saved)),
                std::move(std::get<1>(saved)),
                std::move(std::get<2>(saved)),
                std::move(std::get<3>(saved)),
                std::move(std::get<4>(saved)),
                std::get<5>(saved),
                std::get<6>(saved),
                std::get<7>(saved));
          });
}
163
// Operator schemas for the `vulkan_prepack` library.
//
// Only schemas are declared here; kernels are attached by the
// TORCH_LIBRARY_IMPL(vulkan_prepack, ...) blocks in this file.
TORCH_LIBRARY(vulkan_prepack, m) {
  // ---- conv2d ----
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::create_conv2d_context(Tensor W, Tensor? B, int[2] stride, "
      "int[2] padding, int[2] dilation, int groups, "
      "Scalar? output_min=None, Scalar? output_max=None) "
      "-> __torch__.torch.classes.vulkan.Conv2dPackedContext"));
  // Backwards compatibility
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, "
      "int[2] padding, int[2] dilation, int groups, "
      "Scalar? output_min=None, Scalar? output_max=None) "
      "-> __torch__.torch.classes.vulkan.Conv2dOpContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::run_conv2d_context(Tensor X, "
      "__torch__.torch.classes.vulkan.Conv2dPackedContext W_prepack) -> Tensor Y"));
  // Backwards compatibility
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::conv2d_clamp_run(Tensor X, "
      "__torch__.torch.classes.vulkan.Conv2dOpContext W_prepack) -> Tensor Y"));

  // ---- tconv2d ----
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::create_tconv2d_context(Tensor W, Tensor? B, int[2] stride, "
      "int[2] padding, int[2] output_padding, int[2] dilation, int groups, "
      "Scalar? output_min=None, Scalar? output_max=None) "
      "-> __torch__.torch.classes.vulkan.Conv2dPackedContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::run_tconv2d_context(Tensor X, "
      "__torch__.torch.classes.vulkan.Conv2dPackedContext W_prepack) -> Tensor Y"));

  // ---- qconv2d ----
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::create_qconv2d_context(Tensor W, Tensor? B, "
      "int[2] stride, int[2] padding, int[2] dilation, int groups, "
      "Scalar? output_min=None, Scalar? output_max=None) "
      "-> __torch__.torch.classes.vulkan.Conv2dPackedContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::run_qconv2d_context(Tensor X, float scale, int zero_point, "
      "__torch__.torch.classes.vulkan.Conv2dPackedContext vk_context) -> Tensor Y"));

  // ---- conv1d ----
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::create_conv1d_context(Tensor W, Tensor? B, int[2] stride, "
      "int[2] padding, int[2] dilation, int groups) "
      "-> __torch__.torch.classes.vulkan.Conv1dPackedContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::run_conv1d_context(Tensor X, "
      "__torch__.torch.classes.vulkan.Conv1dPackedContext W_prepack) -> Tensor Y"));

  // ---- qtconv2d (only a create schema is declared in this block) ----
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::create_qtconv2d_context(Tensor W, Tensor? B, int[2] stride, "
      "int[2] padding, int[2] output_padding, int[2] dilation, int groups, "
      "Scalar? output_min=None, Scalar? output_max=None) "
      "-> __torch__.torch.classes.vulkan.Conv2dPackedContext"));

  // ---- linear / qlinear ----
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::create_linear_context(Tensor W, Tensor? B) "
      "-> __torch__.torch.classes.vulkan.LinearPackedContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::run_linear_context(Tensor X, "
      "__torch__.torch.classes.vulkan.LinearPackedContext BW_prepack) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::run_qlinear_context(Tensor X, float scale, int zero_point, "
      "__torch__.torch.classes.vulkan.LinearPackedContext vk_context) -> Tensor Y"));

  // ---- layernorm ----
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::create_layernorm_context(Tensor? W, Tensor? B, float eps) "
      "-> __torch__.torch.classes.vulkan.LayernormPackedContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::run_layernorm_context(Tensor X, SymInt[] normalized_shape, "
      "__torch__.torch.classes.vulkan.LayernormPackedContext BW_prepack) -> Tensor Y"));

  // ---- gru ----
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::create_gru_context(Tensor[] params_cpu, "
      "bool has_biases, "
      "int num_layers, "
      "float dropout, "
      "bool train, "
      "bool bidirectional, "
      "bool batch_first) "
      "-> __torch__.torch.classes.vulkan.GruPackedContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::run_gru_context(Tensor input_vk, "
      "Tensor hx_vk, "
      "__torch__.torch.classes.vulkan.GruPackedContext G_prepack) -> (Tensor next_input, Tensor hidden_layer)"));

  // ---- lstm ----
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::create_lstm_context(Tensor[] params_cpu, "
      "bool has_biases, "
      "int num_layers, "
      "float dropout, "
      "bool train, "
      "bool bidirectional, "
      "bool batch_first) "
      "-> __torch__.torch.classes.vulkan.LstmPackedContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::run_lstm_context(Tensor input_vk, "
      "Tensor hx_vk, "
      "Tensor cx_vk, "
      "__torch__.torch.classes.vulkan.LstmPackedContext L_prepack) -> (Tensor next_input, Tensor hidden_state, Tensor cell_state)"));

  // ---- batchnorm ----
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::create_batchnorm_context("
      "Tensor? weight_opt, "
      "Tensor? bias_opt, "
      "Tensor? running_mean_opt, "
      "Tensor? running_var_opt, "
      "bool training, "
      "float momentum, "
      "float eps, "
      "bool cudnn_enable) "
      "-> __torch__.torch.classes.vulkan.BatchNormPackedContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_prepack::run_batchnorm_context("
      "Tensor input_vk, "
      "__torch__.torch.classes.vulkan.BatchNormPackedContext context) "
      "-> Tensor out"));
}
268
// CPU-key kernels for `vulkan_prepack`: every non-quantized create_*_context
// op, plus the legacy conv2d_clamp_prepack.
TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) {
  // conv2d
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_conv2d_context"),
         TORCH_FN(create_conv2d_context));
  // Backwards compatibility
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_clamp_prepack"),
         TORCH_FN(conv2d_clamp_prepack));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_tconv2d_context"),
         TORCH_FN(create_tconv2d_context));

  // conv1d / linear / layernorm
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_conv1d_context"),
         TORCH_FN(create_conv1d_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_linear_context"),
         TORCH_FN(create_linear_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_layernorm_context"),
         TORCH_FN(create_layernorm_context));

  // gru / lstm / batchnorm
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_gru_context"),
         TORCH_FN(create_gru_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_lstm_context"),
         TORCH_FN(create_lstm_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_batchnorm_context"),
         TORCH_FN(create_batchnorm_context));
}
298
// QuantizedCPU-key kernels for `vulkan_prepack`: the two quantized conv2d
// context constructors.
TORCH_LIBRARY_IMPL(vulkan_prepack, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_qconv2d_context"),
         TORCH_FN(create_qconv2d_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_qtconv2d_context"),
         TORCH_FN(create_qtconv2d_context));
}
307
// Vulkan-key kernels for `vulkan_prepack`: the run_*_context ops, plus the
// legacy conv2d_clamp_run.
TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) {
  // conv2d family
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_conv2d_context"),
         TORCH_FN(run_conv2d_context));
  // Backwards compatibility
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_clamp_run"),
         TORCH_FN(conv2d_clamp_run));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_tconv2d_context"),
         TORCH_FN(run_tconv2d_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_qconv2d_context"),
         TORCH_FN(run_qconv2d_context));

  // conv1d / linear / layernorm
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_conv1d_context"),
         TORCH_FN(run_conv1d_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_linear_context"),
         TORCH_FN(run_linear_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_layernorm_context"),
         TORCH_FN(run_layernorm_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_qlinear_context"),
         TORCH_FN(run_qlinear_context));

  // gru / lstm / batchnorm
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_gru_context"),
         TORCH_FN(run_gru_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_lstm_context"),
         TORCH_FN(run_lstm_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_batchnorm_context"),
         TORCH_FN(run_batchnorm_context));
}
343
// Operator schemas for the `vulkan_quantized` library: add, sub, mul and div.
// All four take (Tensor qa, Tensor qb, float scale, int zero_point) and
// return a single tensor named qc.
TORCH_LIBRARY(vulkan_quantized, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_quantized::add(Tensor qa, "
      "Tensor qb, "
      "float scale, "
      "int zero_point) -> Tensor qc"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_quantized::sub(Tensor qa, "
      "Tensor qb, "
      "float scale, "
      "int zero_point)-> Tensor qc"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_quantized::mul(Tensor qa, "
      "Tensor qb, "
      "float scale, "
      "int zero_point)-> Tensor qc"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_quantized::div(Tensor qa, "
      "Tensor qb, "
      "float scale, "
      "int zero_point)-> Tensor qc"));
}
366
// Vulkan-key kernels for the four `vulkan_quantized` arithmetic ops.
TORCH_LIBRARY_IMPL(vulkan_quantized, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("vulkan_quantized::add"),
         TORCH_FN(quantized_add));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_quantized::sub"),
         TORCH_FN(quantized_sub));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_quantized::mul"),
         TORCH_FN(quantized_mul));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_quantized::div"),
         TORCH_FN(quantized_div));
}
377
378 } // namespace
379 } // namespace ops
380 } // namespace vulkan
381 } // namespace native
382 } // namespace at
383
384 #endif /* USE_VULKAN_API */
385