/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/reference/kernels/kernels.h>

#include <executorch/runtime/kernel/kernel_includes.h>

namespace impl {
namespace reference {
namespace native {

using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;

// This implements a generic 2d conv kernel that operates on raw pointers.
// This version handles both quantized and fp32 convolutions.
// The input is of shape [n x c x h x w]
// The weight is of shape [oc x wc x wh x ww], where wc == c / groups
// The output is of shape [n x oc x oh x ow]
// The bias is of shape [oc]
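// The output spatial sizes are assumed to follow the standard convolution
// relation: oh == (h + 2 * p0 - d0 * (wh - 1) - 1) / s0 + 1, and similarly
// ow == (w + 2 * p1 - d1 * (ww - 1) - 1) / s1 + 1 (integer division).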
template <typename IT, typename WT, typename BT, typename OT, bool quantized>
__attribute__((noinline)) void conv2d_nchw_core_generic(
    // All the arrays
    const IT* __restrict__ p_in,
    const WT* __restrict__ p_weight,
    const BT* __restrict__ p_bias,
    OT* __restrict__ p_out,
    // The array sizes
    int32_t n,
    int32_t c,
    int32_t h,
    int32_t w,
    int32_t oc,
    int32_t wc,
    int32_t wh,
    int32_t ww,
    int32_t oh,
    int32_t ow,
    // Stride
    int16_t s0,
    int16_t s1,
    // Padding
    int16_t p0,
    int16_t p1,
    // Dilation
    int16_t d0,
    int16_t d1,
    // Group for depthwise conv
    int16_t groups,
    // Optional args that are only relevant for quantized convolution
    // input zero point
    IT in_zero_point = 0,
    // weight zero point
    const int32_t* __restrict__ weight_zero_point = nullptr,
    const float* __restrict__ bias_scale = nullptr,
    float out_scale = 1,
    OT out_zero_point = 0,
    bool per_tensor_quantized = true) {
  float inv_out_scale = 1. / out_scale;
  bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0;

  // Compute the number of in and out channels per group
  const int ocpg = oc / groups;
  const int icpg = c / groups;

  // Iterate over all the output batches (i.e., n)
  for (int _n = 0; _n < n; ++_n) {
    const IT* in_batch = p_in + _n * c * h * w;
    OT* out_batch = p_out + _n * oc * oh * ow;
    // Compute separable convolution for each group
    for (int _g = 0; _g < groups; ++_g) {
      // Identify the input and output channels involved in the computation
      // of this group
      int sic = _g * icpg;
      int soc = _g * ocpg;
      // Populate all the output channels in the group
      for (int _oc = soc; _oc < soc + ocpg; ++_oc) {
        OT* out_plane = out_batch + _oc * oh * ow;
        const WT* weight_batch = p_weight + _oc * wc * wh * ww;
        // We compute one output channel at a time. The computation can be
        // thought of as a stencil computation: we iterate over an input of size
        // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an
        // output channel of size 1 x oh x ow.
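        // Concretely, ignoring the boundary checks, each output element is
        //   out[_n][_oc][_oh][_ow] = bias[_oc] +
        //     sum over (_ic, _wh, _ww) of
        //       (in[_n][_ic][_oh * s0 + _wh * d0 - p0][_ow * s1 + _ww * d1 - p1]
        //            - in_zero_point) *
        //       (weight[_oc][_ic - sic][_wh][_ww] - weight_zero_point)
        // with the zero points subtracted only in the quantized case.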
        for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) {
          for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) {
            float acc = p_bias[_oc];
            // Below is the stencil computation that performs the Hadamard
            // product+accumulation of each input channel (contributing to the
            // output channel being computed) with the corresponding weight
            // channel.
            // If the padding is 0 and the dilation is 1, then we can remove
            // the unnecessary bounds checks and simplify the code so that it
            // can be vectorized by the Tensilica compiler.
            if (zero_pad_unit_dilation) {
              for (int _ic = sic; _ic < sic + icpg; ++_ic) {
                const IT* in_plane = in_batch + _ic * h * w;
                const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww;
                for (int _wh = 0; _wh < wh; ++_wh) {
                  for (int _ww = 0; _ww < ww; ++_ww) {
                    int ioff = (_h + _wh) * w + (_w + _ww);
                    int woff = _wh * ww + _ww;
                    float lhs = in_plane[ioff] - in_zero_point;
                    float rhs = weight_plane[woff] -
                        (quantized ? weight_zero_point[0] : 0);
                    acc += lhs * rhs;
                  }
                }
              }
            } else {
              for (int _ic = sic; _ic < sic + icpg; ++_ic) {
                const IT* in_plane = in_batch + _ic * h * w;
                const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww;
                for (int _wh = 0; _wh < wh; ++_wh) {
                  for (int _ww = 0; _ww < ww; ++_ww) {
                    if (((_h + d0 * _wh - p0) >= 0) &&
                        ((_h + d0 * _wh - p0) < h) &&
                        ((_w + d1 * _ww - p1) >= 0) &&
                        ((_w + d1 * _ww - p1) < w)) {
                      int ioff =
                          (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1);
                      int woff = _wh * ww + _ww;
                      float lhs = in_plane[ioff] - in_zero_point;
                      float rhs = weight_plane[woff] -
                          (quantized ? weight_zero_point[0] : 0);
                      acc += lhs * rhs;
                    }
                  }
                }
              }
            }
            if (quantized) {
              float val =
                  (per_tensor_quantized ? bias_scale[0] : bias_scale[_oc]) *
                  acc;
              out_plane[_oh * ow + _ow] =
                  kernels::quantize<OT>(val, inv_out_scale, out_zero_point);
            } else {
              out_plane[_oh * ow + _ow] = acc;
            }
          }
        }
      }
    }
  }
}

// The quantized convolution kernel. in_scale and weight_scale are implicit in
// bias_scale, since it is the product of the two. The kernel will branch to
// quantized::conv1d or quantized::conv2d based on the dimensionality of the
// activation tensor.
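// For each accumulated value acc, the quantized path produces
//   out = quantize(acc * bias_scale, 1 / output_scale, output_zero_point),
// so the input and weight scales enter only through bias_scale.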
void quantized_conv_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    executorch::aten::IntArrayRef stride,
    executorch::aten::IntArrayRef padding,
    executorch::aten::IntArrayRef dilation,
    int64_t groups,
    int64_t in_zero_point,
    const Tensor& weight_zero_point,
    const Tensor& bias_scale,
    double output_scale,
    int64_t output_zero_point,
    const Tensor& out_multiplier,
    const Tensor& out_shift,
    bool channel_last,
    Tensor& out) {
  bool conv1d = input.dim() == 3;
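  // For conv1d, the single spatial dimension is mapped onto w/ww/ow and the
  // height-related sizes are fixed to 1, so the same 2d kernel handles both.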
  // input = [n, c, h, w]
  const int n = input.size(0);
  const int c = input.size(1);
  const int h = conv1d ? 1 : input.size(2);
  const int w = conv1d ? input.size(2) : input.size(3);
  // weight = [oc, wc, wh, ww]
  const int oc = weight.size(0);
  const int wc = weight.size(1);
  const int wh = conv1d ? 1 : weight.size(2);
  const int ww = conv1d ? weight.size(2) : weight.size(3);
  // output = [n, oc, oh, ow]
  const int oh = conv1d ? 1 : out.size(2);
  const int ow = conv1d ? out.size(2) : out.size(3);

  // Flag indicating whether the weights are quantized per-tensor or
  // per-channel; in the per-channel case, bias_scale holds one scale per
  // output channel.
  bool per_tensor_quantized = bias_scale.numel() == 1;

  if (out.scalar_type() == exec_aten::ScalarType::Byte) {
    conv2d_nchw_core_generic<uint8_t, uint8_t, int32_t, uint8_t, true>(
        input.const_data_ptr<uint8_t>(),
        weight.const_data_ptr<uint8_t>(),
        bias.const_data_ptr<int32_t>(),
        out.mutable_data_ptr<uint8_t>(),
        n,
        c,
        h,
        w,
        oc,
        wc,
        wh,
        ww,
        oh,
        ow,
        stride[0],
        stride[1],
        padding[0],
        padding[1],
        dilation[0],
        dilation[1],
        groups,
        in_zero_point,
        weight_zero_point.const_data_ptr<int32_t>(),
        bias_scale.const_data_ptr<float>(),
        output_scale,
        (uint8_t)output_zero_point,
        per_tensor_quantized);
  } else if (out.scalar_type() == exec_aten::ScalarType::Char) {
    conv2d_nchw_core_generic<int8_t, int8_t, int32_t, int8_t, true>(
        input.const_data_ptr<int8_t>(),
        weight.const_data_ptr<int8_t>(),
        bias.const_data_ptr<int32_t>(),
        out.mutable_data_ptr<int8_t>(),
        n,
        c,
        h,
        w,
        oc,
        wc,
        wh,
        ww,
        oh,
        ow,
        stride[0],
        stride[1],
        padding[0],
        padding[1],
        dilation[0],
        dilation[1],
        groups,
        in_zero_point,
        weight_zero_point.const_data_ptr<int32_t>(),
        bias_scale.const_data_ptr<float>(),
        output_scale,
        (int8_t)output_zero_point,
        per_tensor_quantized);
  } else {
    ET_CHECK_MSG(
        false,
        "Unhandled input dtype %hhd",
        static_cast<int8_t>(input.scalar_type()));
  }
}

}; // namespace native
}; // namespace reference
}; // namespace impl