• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifdef USE_VULKAN_API
2 
3 #include <ATen/native/quantized/PackedParams.h>
4 #include <ATen/native/vulkan/ops/Batchnorm.h>
5 #include <ATen/native/vulkan/ops/Common.h>
6 #include <ATen/native/vulkan/ops/Convolution.h>
7 #include <ATen/native/vulkan/ops/Gru.h>
8 #include <ATen/native/vulkan/ops/Layernorm.h>
9 #include <ATen/native/vulkan/ops/Lstm.h>
10 #include <ATen/native/vulkan/ops/Mm.h>
11 #include <ATen/native/vulkan/ops/QuantizedFunctions.h>
12 #include <ATen/native/vulkan/ops/Register.h>
13 #include <torch/custom_class.h>
14 #include <torch/library.h>
15 
16 namespace at {
17 namespace native {
18 namespace vulkan {
19 namespace ops {
20 
register_vulkan_conv2d_packed_context()21 int register_vulkan_conv2d_packed_context() {
22   static auto register_vulkan_conv2d_context =
23       torch::selective_class_<Conv2dPackedContext>(
24           "vulkan", TORCH_SELECTIVE_CLASS("Conv2dPackedContext"))
25           .def_pickle(
26               // __getstate__
27               [](const c10::intrusive_ptr<Conv2dPackedContext>& context) {
28                 // context is packed
29                 return context->unpack();
30               },
31               // __setstate__
32               [](c10::impl::GenericList state) {
33                 // state is unpacked
34                 return c10::make_intrusive<Conv2dPackedContext>(
35                     Conv2dPackedContext::pack(state));
36               });
37   return 0;
38 }
39 
register_vulkan_conv1d_packed_context()40 int register_vulkan_conv1d_packed_context() {
41   static auto register_vulkan_conv1d_context =
42       torch::selective_class_<Conv1dPackedContext>(
43           "vulkan", TORCH_SELECTIVE_CLASS("Conv1dPackedContext"))
44           .def_pickle(
45               // __getstate__
46               [](const c10::intrusive_ptr<Conv1dPackedContext>& context) {
47                 // context is packed
48                 return context->unpack();
49               },
50               // __setstate__
51               [](c10::impl::GenericList state) {
52                 // state is unpacked
53                 return c10::make_intrusive<Conv1dPackedContext>(
54                     Conv1dPackedContext::pack(state));
55               });
56   return 0;
57 }
58 
register_vulkan_linear_packed_context()59 int register_vulkan_linear_packed_context() {
60   static auto register_vulkan_linear_context =
61       torch::selective_class_<LinearPackedContext>(
62           "vulkan", TORCH_SELECTIVE_CLASS("LinearPackedContext"))
63           .def_pickle(
64               // __getstate__
65               [](const c10::intrusive_ptr<LinearPackedContext>& context) {
66                 // context is packed
67                 return context->unpack();
68               },
69               // __setstate__
70               [](c10::impl::GenericList state) {
71                 // state is unpacked
72                 return c10::make_intrusive<LinearPackedContext>(
73                     LinearPackedContext::pack(state));
74               });
75   return 0;
76 }
77 
register_vulkan_layernorm_packed_context()78 int register_vulkan_layernorm_packed_context() {
79   static auto register_vulkan_layernorm_context =
80       torch::selective_class_<LayernormPackedContext>(
81           "vulkan", TORCH_SELECTIVE_CLASS("LayernormPackedContext"))
82           .def_pickle(
83               // __getstate__
84               [](const c10::intrusive_ptr<LayernormPackedContext>& context) {
85                 // context is packed
86                 return context->unpack();
87               },
88               // __setstate__
89               [](c10::impl::GenericList state) {
90                 // state is unpacked
91                 return c10::make_intrusive<LayernormPackedContext>(
92                     LayernormPackedContext::pack(state));
93               });
94   return 0;
95 }
96 
97 namespace {
98 
TORCH_LIBRARY(vulkan,m)99 TORCH_LIBRARY(vulkan, m) {
100   m.class_<BatchNormPackedContext>("BatchNormPackedContext")
101       .def_pickle(
102           // __getstate__
103           [](const c10::intrusive_ptr<BatchNormPackedContext>& context) {
104             // context is packed
105             return context->unpack();
106           },
107           // __setstate__
108           [](c10::impl::GenericList state) {
109             // state is unpacked
110             return c10::make_intrusive<BatchNormPackedContext>(
111                 BatchNormPackedContext::pack(state));
112           });
113   m.class_<GruPackedContext>("GruPackedContext")
114       .def_pickle(
115           // __getstate__
116           [](const c10::intrusive_ptr<GruPackedContext>& context) {
117             // context is packed
118             return context->unpack();
119           },
120           // __setstate__
121           [](c10::impl::GenericList state) {
122             // state is unpacked
123             return c10::make_intrusive<GruPackedContext>(
124                 GruPackedContext::pack(state));
125           });
126   m.class_<LstmPackedContext>("LstmPackedContext")
127       .def_pickle(
128           // __getstate__
129           [](const c10::intrusive_ptr<LstmPackedContext>& context) {
130             // context is packed
131             return context->unpack();
132           },
133           // __setstate__
134           [](c10::impl::GenericList state) {
135             // state is unpacked
136             return c10::make_intrusive<LstmPackedContext>(
137                 LstmPackedContext::pack(state));
138           });
139   register_vulkan_conv2d_packed_context();
140   register_vulkan_conv1d_packed_context();
141   register_vulkan_linear_packed_context();
142   register_vulkan_layernorm_packed_context();
143   // To maintain backwards compatibility.
144   m.class_<Conv2dOpContext>("Conv2dOpContext")
145       .def_pickle(
146           // __getstate__
147           [](const c10::intrusive_ptr<Conv2dOpContext>& context) {
148             return context->unpack();
149           },
150           // __setstate__
151           [](Conv2dOpContext::State state) {
152             return conv2d_clamp_prepack(
153                 std::move(std::get<0>(state)),
154                 std::move(std::get<1>(state)),
155                 std::move(std::get<2>(state)),
156                 std::move(std::get<3>(state)),
157                 std::move(std::get<4>(state)),
158                 std::get<5>(state),
159                 std::get<6>(state),
160                 std::get<7>(state));
161           });
162 }
163 
TORCH_LIBRARY(vulkan_prepack,m)164 TORCH_LIBRARY(vulkan_prepack, m) {
165   m.def(TORCH_SELECTIVE_SCHEMA(
166       "vulkan_prepack::create_conv2d_context(Tensor W, Tensor? B, int[2] stride, "
167       "int[2] padding, int[2] dilation, int groups, "
168       "Scalar? output_min=None, Scalar? output_max=None) "
169       "-> __torch__.torch.classes.vulkan.Conv2dPackedContext"));
170   m.def(TORCH_SELECTIVE_SCHEMA( // Backwards compatibility
171       "vulkan_prepack::conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, "
172       "int[2] padding, int[2] dilation, int groups, "
173       "Scalar? output_min=None, Scalar? output_max=None) "
174       "-> __torch__.torch.classes.vulkan.Conv2dOpContext"));
175   m.def(TORCH_SELECTIVE_SCHEMA(
176       "vulkan_prepack::run_conv2d_context(Tensor X, "
177       "__torch__.torch.classes.vulkan.Conv2dPackedContext W_prepack) -> Tensor Y"));
178   m.def(TORCH_SELECTIVE_SCHEMA( // Backwards compatibility
179       "vulkan_prepack::conv2d_clamp_run(Tensor X, "
180       "__torch__.torch.classes.vulkan.Conv2dOpContext W_prepack) -> Tensor Y"));
181   m.def(TORCH_SELECTIVE_SCHEMA(
182       "vulkan_prepack::create_tconv2d_context(Tensor W, Tensor? B, int[2] stride, "
183       "int[2] padding, int[2] output_padding, int[2] dilation, int groups, "
184       "Scalar? output_min=None, Scalar? output_max=None) "
185       "-> __torch__.torch.classes.vulkan.Conv2dPackedContext"));
186   m.def(TORCH_SELECTIVE_SCHEMA(
187       "vulkan_prepack::run_tconv2d_context(Tensor X, "
188       "__torch__.torch.classes.vulkan.Conv2dPackedContext W_prepack) -> Tensor Y"));
189   m.def(TORCH_SELECTIVE_SCHEMA(
190       "vulkan_prepack::create_qconv2d_context(Tensor W, Tensor? B, "
191       "int[2] stride, int[2] padding, int[2] dilation, int groups, "
192       "Scalar? output_min=None, Scalar? output_max=None) "
193       "-> __torch__.torch.classes.vulkan.Conv2dPackedContext"));
194   m.def(TORCH_SELECTIVE_SCHEMA(
195       "vulkan_prepack::run_qconv2d_context(Tensor X, float scale, int zero_point, "
196       "__torch__.torch.classes.vulkan.Conv2dPackedContext vk_context) -> Tensor Y"));
197   m.def(TORCH_SELECTIVE_SCHEMA(
198       "vulkan_prepack::create_conv1d_context(Tensor W, Tensor? B, int[2] stride, "
199       "int[2] padding, int[2] dilation, int groups) "
200       "-> __torch__.torch.classes.vulkan.Conv1dPackedContext"));
201   m.def(TORCH_SELECTIVE_SCHEMA(
202       "vulkan_prepack::run_conv1d_context(Tensor X, "
203       "__torch__.torch.classes.vulkan.Conv1dPackedContext W_prepack) -> Tensor Y"));
204   m.def(TORCH_SELECTIVE_SCHEMA(
205       "vulkan_prepack::create_qtconv2d_context(Tensor W, Tensor? B, int[2] stride, "
206       "int[2] padding, int[2] output_padding, int[2] dilation, int groups, "
207       "Scalar? output_min=None, Scalar? output_max=None) "
208       "-> __torch__.torch.classes.vulkan.Conv2dPackedContext"));
209   m.def(TORCH_SELECTIVE_SCHEMA(
210       "vulkan_prepack::create_linear_context(Tensor W, Tensor? B) "
211       "-> __torch__.torch.classes.vulkan.LinearPackedContext"));
212   m.def(TORCH_SELECTIVE_SCHEMA(
213       "vulkan_prepack::run_linear_context(Tensor X, "
214       "__torch__.torch.classes.vulkan.LinearPackedContext BW_prepack) -> Tensor Y"));
215   m.def(TORCH_SELECTIVE_SCHEMA(
216       "vulkan_prepack::run_qlinear_context(Tensor X, float scale, int zero_point, "
217       "__torch__.torch.classes.vulkan.LinearPackedContext vk_context) -> Tensor Y"));
218   m.def(TORCH_SELECTIVE_SCHEMA(
219       "vulkan_prepack::create_layernorm_context(Tensor? W, Tensor? B, float eps) "
220       "-> __torch__.torch.classes.vulkan.LayernormPackedContext"));
221   m.def(TORCH_SELECTIVE_SCHEMA(
222       "vulkan_prepack::run_layernorm_context(Tensor X, SymInt[] normalized_shape, "
223       "__torch__.torch.classes.vulkan.LayernormPackedContext BW_prepack) -> Tensor Y"));
224   m.def(TORCH_SELECTIVE_SCHEMA(
225       "vulkan_prepack::create_gru_context(Tensor[] params_cpu, "
226       "bool has_biases, "
227       "int num_layers, "
228       "float dropout, "
229       "bool train, "
230       "bool bidirectional, "
231       "bool batch_first) "
232       "-> __torch__.torch.classes.vulkan.GruPackedContext"));
233   m.def(TORCH_SELECTIVE_SCHEMA(
234       "vulkan_prepack::run_gru_context(Tensor input_vk, "
235       "Tensor hx_vk, "
236       "__torch__.torch.classes.vulkan.GruPackedContext G_prepack) -> (Tensor next_input, Tensor hidden_layer)"));
237   m.def(TORCH_SELECTIVE_SCHEMA(
238       "vulkan_prepack::create_lstm_context(Tensor[] params_cpu, "
239       "bool has_biases, "
240       "int num_layers, "
241       "float dropout, "
242       "bool train, "
243       "bool bidirectional, "
244       "bool batch_first) "
245       "-> __torch__.torch.classes.vulkan.LstmPackedContext"));
246   m.def(TORCH_SELECTIVE_SCHEMA(
247       "vulkan_prepack::run_lstm_context(Tensor input_vk, "
248       "Tensor hx_vk, "
249       "Tensor cx_vk, "
250       "__torch__.torch.classes.vulkan.LstmPackedContext L_prepack) -> (Tensor next_input, Tensor hidden_state, Tensor cell_state)"));
251   m.def(TORCH_SELECTIVE_SCHEMA(
252       "vulkan_prepack::create_batchnorm_context("
253       "Tensor? weight_opt, "
254       "Tensor? bias_opt, "
255       "Tensor? running_mean_opt, "
256       "Tensor? running_var_opt, "
257       "bool training, "
258       "float momentum, "
259       "float eps, "
260       "bool cudnn_enable) "
261       "-> __torch__.torch.classes.vulkan.BatchNormPackedContext"));
262   m.def(TORCH_SELECTIVE_SCHEMA(
263       "vulkan_prepack::run_batchnorm_context("
264       "Tensor input_vk, "
265       "__torch__.torch.classes.vulkan.BatchNormPackedContext context) "
266       "-> Tensor out"));
267 }
268 
// Attaches the CPU-keyed kernels of `vulkan_prepack`: the non-quantized
// context-creation operators, plus the legacy conv2d_clamp_prepack.
TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) {
  // Convolutions.
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_conv2d_context"),
         TORCH_FN(create_conv2d_context));
  // Backwards compatibility
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_clamp_prepack"),
         TORCH_FN(conv2d_clamp_prepack));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_tconv2d_context"),
         TORCH_FN(create_tconv2d_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_conv1d_context"),
         TORCH_FN(create_conv1d_context));

  // Linear and layernorm.
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_linear_context"),
         TORCH_FN(create_linear_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_layernorm_context"),
         TORCH_FN(create_layernorm_context));

  // Recurrent layers.
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_gru_context"),
         TORCH_FN(create_gru_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_lstm_context"),
         TORCH_FN(create_lstm_context));

  // Batch normalization.
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_batchnorm_context"),
         TORCH_FN(create_batchnorm_context));
}
298 
// Attaches the QuantizedCPU-keyed kernels of `vulkan_prepack`: creation of
// the quantized conv2d and quantized transposed-conv2d contexts.
TORCH_LIBRARY_IMPL(vulkan_prepack, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_qconv2d_context"),
         TORCH_FN(create_qconv2d_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::create_qtconv2d_context"),
         TORCH_FN(create_qtconv2d_context));
}
307 
// Attaches the Vulkan-keyed kernels of `vulkan_prepack`: the operators that
// run a previously created context, plus the legacy conv2d_clamp_run.
TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) {
  // Convolutions.
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_conv2d_context"),
         TORCH_FN(run_conv2d_context));
  // Backwards compatibility
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::conv2d_clamp_run"),
         TORCH_FN(conv2d_clamp_run));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_tconv2d_context"),
         TORCH_FN(run_tconv2d_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_qconv2d_context"),
         TORCH_FN(run_qconv2d_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_conv1d_context"),
         TORCH_FN(run_conv1d_context));

  // Linear and layernorm.
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_linear_context"),
         TORCH_FN(run_linear_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_layernorm_context"),
         TORCH_FN(run_layernorm_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_qlinear_context"),
         TORCH_FN(run_qlinear_context));

  // Recurrent layers.
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_gru_context"),
         TORCH_FN(run_gru_context));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_lstm_context"),
         TORCH_FN(run_lstm_context));

  // Batch normalization.
  m.impl(TORCH_SELECTIVE_NAME("vulkan_prepack::run_batchnorm_context"),
         TORCH_FN(run_batchnorm_context));
}
343 
// Declares the operator schemas of the `vulkan_quantized` library: four
// binary operators, each taking two tensors plus an output scale and
// zero point and returning one tensor.
//
// Only schemas are defined here; kernels are attached in
// TORCH_LIBRARY_IMPL(vulkan_quantized, Vulkan, m).
TORCH_LIBRARY(vulkan_quantized, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_quantized::add("
      "Tensor qa, "
      "Tensor qb, "
      "float scale, "
      "int zero_point) -> Tensor qc"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_quantized::sub("
      "Tensor qa, "
      "Tensor qb, "
      "float scale, "
      "int zero_point)-> Tensor qc"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_quantized::mul("
      "Tensor qa, "
      "Tensor qb, "
      "float scale, "
      "int zero_point)-> Tensor qc"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "vulkan_quantized::div("
      "Tensor qa, "
      "Tensor qb, "
      "float scale, "
      "int zero_point)-> Tensor qc"));
}
366 
// Attaches the Vulkan-keyed kernels of `vulkan_quantized`: each operator is
// bound to the quantized_* function of the same name.
TORCH_LIBRARY_IMPL(vulkan_quantized, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("vulkan_quantized::add"),
         TORCH_FN(quantized_add));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_quantized::sub"),
         TORCH_FN(quantized_sub));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_quantized::mul"),
         TORCH_FN(quantized_mul));
  m.impl(TORCH_SELECTIVE_NAME("vulkan_quantized::div"),
         TORCH_FN(quantized_div));
}
377 
378 } // namespace
379 } // namespace ops
380 } // namespace vulkan
381 } // namespace native
382 } // namespace at
383 
384 #endif /* USE_VULKAN_API */
385