• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifdef USE_CUDA
2 #include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED
3 
4 #if AT_CUDNN_ENABLED()
5 
6 #include <ATen/ATen.h>
7 #include <torch/library.h>
8 #include <ATen/native/quantized/cudnn/utils.h>
9 #include <ATen/native/quantized/PackedParams.h>
10 #include <ATen/quantized/Quantizer.h>
11 #include <c10/core/QScheme.h>
12 #include <c10/util/irange.h>
13 #include <torch/library.h>
14 
15 int register_linear_params();
16 
prepack(at::Tensor weight,std::optional<at::Tensor> bias)17 c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightCudnn::prepack(
18         at::Tensor weight,
19         std::optional<at::Tensor> bias) {
20   TORCH_CHECK(weight.qscheme() == c10::kPerTensorAffine, "Unsupported qscheme: ", toString(weight.qscheme()));
21   const auto output_channels = weight.size(0);
22   const auto qtype = weight.qscheme();
23   if (bias.has_value()) {
24     TORCH_CHECK(bias.value().dim() == 1, "bias should be a vector (1D Tensor)");
25     TORCH_CHECK(
26         bias.value().size(0) == output_channels,
27         "bias should have K elements: " + std::to_string(output_channels));
28   }
29 
30   auto ret_ptr = c10::make_intrusive<PackedLinearWeightCudnn>(
31           std::move(weight),
32           std::move(bias),
33           qtype);
34   return ret_ptr;
35 }
36 
37 
38 namespace at::native {
39 namespace {
40 
41 class QLinearPackWeightInt8Cudnn final {
42  public:
run(at::Tensor weight,std::optional<Tensor> bias)43   static c10::intrusive_ptr<LinearPackedParamsBase> run(
44       at::Tensor weight,
45       std::optional<Tensor> bias) {
46       return PackedLinearWeightCudnn::prepack(std::move(weight), std::move(bias));
47   }
48 };
49 
// Register the cuDNN kernel for quantized::linear_prepack under the
// QuantizedCUDA dispatch key.
TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
  // register_linear_params() is only declared in this file (defined
  // elsewhere); presumably it registers the LinearPackedParamsBase custom
  // class so the packed result of linear_prepack can be boxed — confirm
  // against its definition.
  register_linear_params();
  m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8Cudnn::run));
}
54 
55 
56 } // namespace
57 } // namespace at::native
58 
59 
60 #endif  // AT_CUDNN_ENABLED
61 #endif  // USE_CUDA
62