#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/ATen.h>
#include <torch/library.h>
#include <ATen/native/quantized/cudnn/utils.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/core/QScheme.h>
#include <c10/util/irange.h>

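// Defined elsewhere in the quantized backend; registers the
// LinearPackedParamsBase custom class used to serialize packed weights.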
int register_linear_params();

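// Packs a quantized weight tensor (plus an optional float bias) into a
// PackedLinearWeightCudnn that the quantized linear op can consume later.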
c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightCudnn::prepack(
    at::Tensor weight,
    std::optional<at::Tensor> bias) {
  // The cuDNN path currently accepts only per-tensor affine quantized weights.
  TORCH_CHECK(weight.qscheme() == c10::kPerTensorAffine, "Unsupported qscheme: ", toString(weight.qscheme()));
  // weight is (output_channels, input_channels); the optional bias must be a
  // 1-D tensor with one element per output channel.
  const auto output_channels = weight.size(0);
  const auto qtype = weight.qscheme();
  if (bias.has_value()) {
    TORCH_CHECK(bias.value().dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias.value().size(0) == output_channels,
        "bias should have K elements: ", output_channels);
  }

  // Stash the weight, bias, and qscheme in the packed-params object.
  return c10::make_intrusive<PackedLinearWeightCudnn>(
      std::move(weight),
      std::move(bias),
      qtype);
}

namespace at::native {
namespace {

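// Thin dispatcher-facing wrapper: quantized::linear_prepack on CUDA forwards
// directly to PackedLinearWeightCudnn::prepack.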
class QLinearPackWeightInt8Cudnn final {
 public:
  static c10::intrusive_ptr<LinearPackedParamsBase> run(
      at::Tensor weight,
      std::optional<at::Tensor> bias) {
    return PackedLinearWeightCudnn::prepack(std::move(weight), std::move(bias));
  }
};

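// Bind the cuDNN prepack routine to quantized::linear_prepack under the
// QuantizedCUDA dispatch key.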
TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8Cudnn::run));
}

} // namespace
} // namespace at::native

#endif // AT_CUDNN_ENABLED
#endif // USE_CUDA