#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/ATen.h>
#include <torch/library.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/native/quantized/cudnn/utils.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/core/QScheme.h>
#include <c10/util/irange.h>

#include <utility>

template <int kSpatialDim = 2>
int register_conv_params();

extern template int register_conv_params<2>();
extern template int register_conv_params<3>();

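// Packs the quantized weight (and optional bias) into a PackedConvWeightCudnn
// object for later use by the quantized cuDNN convolution kernels: the inputs
// are validated, the weight's input/output channel counts are zero-padded to
// multiples of 4 (a cuDNN int8 requirement), and the weight is stored in
// ChannelsLast memory format.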
template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightCudnn<
    kSpatialDim>::
    prepack(
        at::Tensor weight,
        std::optional<at::Tensor> bias,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose) {
  // TODO: need to check how to implement groups for conv operator in Conv.cpp
  TORCH_CHECK(groups == 1, "Quantized cudnn conv2d is currently limited to groups = 1; received groups = ", groups);
  TORCH_CHECK(weight.qscheme() == c10::kPerTensorAffine, "Unsupported qscheme: ", toString(weight.qscheme()));
  TORCH_CHECK(
      kSpatialDim == 2,  // 1D is packed as 2d, hence we don't need other checks
      "cuDNN packing only supports 2D convolution.");
  TORCH_CHECK(
      weight.ndimension() == kSpatialDim + 2,
      "Weights are expected to have ",
      kSpatialDim + 2,
      " dimensions");
  TORCH_CHECK(
      stride.size() == kSpatialDim,
      "stride should contain ",
      kSpatialDim,
      " elements for ",
      kSpatialDim,
      "D convolution.");
  TORCH_CHECK(
      padding.size() == kSpatialDim,
      "quantized::conv_prepack (cudnn): Specify front/top/left padding only. "
      "end/bottom/right padding assumed to be equal to front/top/left");
  TORCH_CHECK(
      !transpose || output_padding.size() == kSpatialDim,
      "quantized::conv_prepack: Specify top/left output padding "
      "only. bottom/right padding assumed to be equal to top/left");
  TORCH_CHECK(
      dilation.size() == kSpatialDim,
      "quantized::conv_prepack (cudnn): dilation should contain ",
      kSpatialDim,
      " elements for ",
      kSpatialDim,
      "D convolution.");
  TORCH_CHECK(!transpose, "cuDNN quantized conv prepack expects transpose = false");
  const auto num_unpadded_output_channels = weight.size(0);
  const auto qtype = weight.qscheme();
  if (bias.has_value()) {
    TORCH_CHECK(bias.value().dim() == 1, "bias should be a vector (1D Tensor)");
    TORCH_CHECK(
        bias.value().size(0) == num_unpadded_output_channels,
        "bias should have K elements: " + std::to_string(num_unpadded_output_channels));
    // TODO: we create a broadcasted_bias tensor later so I think we don't need to make this contiguous here.
    // we will revisit this when nvidia adds proper support for broadcasting
    // bias_contig = bias->contiguous();
  }

  // cudnn v8.4.0 expects conv2d's int8 weight tensor's input and output channels to be multiples of 4.
  // If they are not, we need to explicitly pad them to a multiple of 4 ourselves, as cudnn does not currently support padding.
  // TODO: when and if cudnn enables padding in their operators, we can remove padding on our end;
  // currently, limit padding support to groups=1 (ungrouped conv)
  // TODO: implement this for groups > 1
  auto num_input_channels = weight.size(1);
  auto num_output_slices2pad = (4 - num_unpadded_output_channels % 4) % 4;
  auto num_input_slices2pad = (4 - num_input_channels % 4) % 4;
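  // Example: with 3 output channels and 6 input channels,
  // num_output_slices2pad = (4 - 3 % 4) % 4 = 1 and
  // num_input_slices2pad = (4 - 6 % 4) % 4 = 2, so the weight is
  // zero-padded below to 4 output and 8 input channels.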
  if (num_output_slices2pad != 0 || num_input_slices2pad != 0) {
    // the second argument is an initializer list of padded values. there are 2 values for each dimension.
    // refer to https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html for more details
    weight = at::pad(weight, {0, 0, 0, 0, 0, num_input_slices2pad, 0, num_output_slices2pad}, "constant", 0);
    if (bias.has_value()) {
      bias.value() = at::pad(bias.value(), {0, num_output_slices2pad}, "constant", 0);
    }
  }

  auto ret_ptr = c10::make_intrusive<PackedConvWeightCudnn<kSpatialDim>>(
          weight.to(c10::MemoryFormat::ChannelsLast), // TODO: this assumes 2D I think. make it more general?
          std::move(bias),
          stride,
          padding,
          output_padding,
          dilation,
          groups,
          transpose,
          qtype,
          num_unpadded_output_channels);
  return ret_ptr;
}

template
c10::intrusive_ptr<ConvPackedParamsBase<2>> PackedConvWeightCudnn<
    2>::
    prepack(
        at::Tensor weight,
        std::optional<at::Tensor> bias_in,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups,
        bool transpose);


namespace at::native {
namespace {

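// Thin wrappers, registered below, that adapt the quantized::convNd_prepack
// op signatures to PackedConvWeightCudnn<kSpatialDim>::prepack.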
template <int kSpatialDim = 2>
class QConvPackWeightInt8Cudnn final {
 public:
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_conv(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    torch::List<int64_t> output_padding;
    output_padding.reserve(kSpatialDim);
    for (C10_UNUSED const auto idx : c10::irange(kSpatialDim)) {
      output_padding.push_back((int64_t)0);
    }
    return _run(weight, bias, stride, padding, output_padding, dilation, groups,
                /*transpose=*/false);
  }

 private:
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> _run(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose) {
    return PackedConvWeightCudnn<kSpatialDim>::prepack(
        weight, bias, stride, padding, output_padding, dilation, groups,
        transpose);
  }
};

class QConv1dPackWeightInt8Cudnn final {
 public:
  static c10::intrusive_ptr<ConvPackedParamsBase<2>> run_conv(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> dilation,
      int64_t groups) {
    const torch::List<int64_t> output_padding({0});
    return _run(std::move(weight), std::move(bias), stride, padding, output_padding, dilation, groups,
                /*transpose=*/false);
  }

 private:
  static c10::intrusive_ptr<ConvPackedParamsBase<2>> _run(
      Tensor weight,
      std::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose) {
    if (weight.dim() == 3) {
      // we currently use the conv2d kernel for conv1d by making the input and weight tensors
      // 4D rather than 3D. we add a dummy height dimension of size 1:
      // out channels, in channels / groups, L -> out channels, in channels / groups, 1, L
      weight = weight.unsqueeze(-2);
    }
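    // MakeArgForConv1d expands each single conv1d argument to the pair of
    // values the 2D kernel expects, filling the dummy height dimension with
    // the given neutral value (1 for stride/dilation, 0 for padding).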
    stride = quant_utils::MakeArgForConv1d(stride, 1);
    padding = quant_utils::MakeArgForConv1d(padding, 0);
    output_padding = quant_utils::MakeArgForConv1d(output_padding, 0);
    dilation = quant_utils::MakeArgForConv1d(dilation, 1);

    return PackedConvWeightCudnn<2>::prepack(
        weight, std::move(bias), stride, padding, output_padding, dilation, groups,
        transpose);
  }
};

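// Register the prepack wrappers (and the ConvPackedParams classes themselves)
// under the QuantizedCUDA dispatch key, so that e.g. quantized::conv2d_prepack
// dispatches here for quantized tensors on CUDA.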
TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
  register_conv_params<2>();
  register_conv_params<3>();
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv1d_prepack"), TORCH_FN(QConv1dPackWeightInt8Cudnn::run_conv));
  m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8Cudnn<2>::run_conv));
}

} // namespace
} // namespace at::native

#endif  // AT_CUDNN_ENABLED
#endif  // USE_CUDA