/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/reference/kernels/kernels.h>

#include <executorch/runtime/kernel/kernel_includes.h>

namespace impl {
namespace reference {
namespace native {

using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;

// This implements a generic 2d conv kernel that operates on raw pointers.
// This version handles both quantized and fp32 convolutions.
// The input is of shape [n x c x h x w]
// The weight is of shape [oc x wc x wh x ww], where wc == c
// The output is of shape [n x oc x oh x ow]
// The bias is of shape [oc]
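// Note (assumption, for reference only): oh and ow are taken as inputs and are
// not recomputed here; the caller is expected to have derived them with the
// usual convolution output-size relation:
//   oh = (h + 2 * p0 - d0 * (wh - 1) - 1) / s0 + 1
//   ow = (w + 2 * p1 - d1 * (ww - 1) - 1) / s1 + 1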
template <typename IT, typename WT, typename BT, typename OT, bool quantized>
__attribute__((noinline)) void conv2d_nchw_core_generic(
    // All the arrays
    const IT* __restrict__ p_in,
    const WT* __restrict__ p_weight,
    const BT* __restrict__ p_bias,
    OT* __restrict__ p_out,
    // The array sizes
    int32_t n,
    int32_t c,
    int32_t h,
    int32_t w,
    int32_t oc,
    int32_t wc,
    int32_t wh,
    int32_t ww,
    int32_t oh,
    int32_t ow,
    // Stride
    int16_t s0,
    int16_t s1,
    // Padding
    int16_t p0,
    int16_t p1,
    // Dilation
    int16_t d0,
    int16_t d1,
    // Group for depthwise conv
    int16_t groups,
    // Optional args that are only relevant for quantized convolution
    // input zero point
    IT in_zero_point = 0,
    // weight zero point
    const int32_t* __restrict__ weight_zero_point = nullptr,
    const float* __restrict__ bias_scale = nullptr,
    float out_scale = 1,
    OT out_zero_point = 0,
    bool per_tensor_quantized = true) {
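  // Quantization note (assuming the usual affine convention
  // real_value = scale * (q - zero_point)): the stencil below accumulates
  // zero-point-corrected products, bias_scale (the product of the input and
  // weight scales) rescales that accumulator to real values, and
  // kernels::quantize maps the result into the output quantized domain using
  // 1 / out_scale and out_zero_point.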
  float inv_out_scale = 1. / out_scale;
  bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0;

  // Compute the number of in and out channels per group
  const int ocpg = oc / groups;
  const int icpg = c / groups;

  // Iterate over all the output batches (i.e., n)
  for (int _n = 0; _n < n; ++_n) {
    const IT* in_batch = p_in + _n * c * h * w;
    OT* out_batch = p_out + _n * oc * oh * ow;
    // Compute the convolution for each group
    for (int _g = 0; _g < groups; ++_g) {
      // Identify the input and output channels involved in the computation
      // of this group
      int sic = _g * icpg;
      int soc = _g * ocpg;
      // Populate all the output channels in the group
      for (int _oc = soc; _oc < soc + ocpg; ++_oc) {
        OT* out_plane = out_batch + _oc * oh * ow;
        const WT* weight_batch = p_weight + _oc * wc * wh * ww;
        // We compute one output channel at a time. The computation can be
        // thought of as a stencil computation: we iterate over an input of
        // size icpg x h x w, with a stencil of size icpg x wh x ww, to compute
        // an output channel of size 1 x oh x ow.
        for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) {
          for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) {
            float acc = p_bias[_oc];
            // Below is the stencil computation that performs the Hadamard
            // product and accumulation of each input channel (contributing to
            // the output channel being computed) with the corresponding weight
            // channel.
            // If the padding is 0 and the dilation is 1, we can remove the
            // unnecessary bounds checks and simplify the code so that it can
            // be vectorized by the Tensilica compiler.
            if (zero_pad_unit_dilation) {
              for (int _ic = sic; _ic < sic + icpg; ++_ic) {
                const IT* in_plane = in_batch + _ic * h * w;
                const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww;
                for (int _wh = 0; _wh < wh; ++_wh) {
                  for (int _ww = 0; _ww < ww; ++_ww) {
                    int ioff = (_h + _wh) * w + (_w + _ww);
                    int woff = _wh * ww + _ww;
                    float lhs = in_plane[ioff] - in_zero_point;
                    float rhs = weight_plane[woff] -
                        (quantized ? weight_zero_point[0] : 0);
                    acc += lhs * rhs;
                  }
                }
              }
            } else {
              for (int _ic = sic; _ic < sic + icpg; ++_ic) {
                const IT* in_plane = in_batch + _ic * h * w;
                const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww;
                for (int _wh = 0; _wh < wh; ++_wh) {
                  for (int _ww = 0; _ww < ww; ++_ww) {
                    if (((_h + d0 * _wh - p0) >= 0) &&
                        ((_h + d0 * _wh - p0) < h) &&
                        ((_w + d1 * _ww - p1) >= 0) &&
                        ((_w + d1 * _ww - p1) < w)) {
                      int ioff =
                          (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1);
                      int woff = _wh * ww + _ww;
                      float lhs = in_plane[ioff] - in_zero_point;
                      float rhs = weight_plane[woff] -
                          (quantized ? weight_zero_point[0] : 0);
                      acc += lhs * rhs;
                    }
                  }
                }
              }
            }
            if (quantized) {
              float val =
                  (per_tensor_quantized ? bias_scale[0] : bias_scale[_oc]) *
                  acc;
              out_plane[_oh * ow + _ow] =
                  kernels::quantize<OT>(val, inv_out_scale, out_zero_point);
            } else {
              out_plane[_oh * ow + _ow] = acc;
            }
          }
        }
      }
    }
  }
}
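
// Illustrative usage sketch (not part of the build, shapes are hypothetical):
// calling the generic kernel directly for a small fp32 convolution, just to
// show how the raw-pointer arguments line up. This would live inside a test
// or caller function.
//
//   float in[1 * 3 * 8 * 8];   // n=1, c=3, h=8, w=8
//   float wgt[4 * 3 * 3 * 3];  // oc=4, wc=3 (== c), wh=3, ww=3
//   float bias[4] = {0.f, 0.f, 0.f, 0.f};
//   float out[1 * 4 * 6 * 6];  // oh = ow = (8 - 3) / 1 + 1 = 6
//   conv2d_nchw_core_generic<float, float, float, float, /*quantized=*/false>(
//       in, wgt, bias, out,
//       /*n=*/1, /*c=*/3, /*h=*/8, /*w=*/8,
//       /*oc=*/4, /*wc=*/3, /*wh=*/3, /*ww=*/3,
//       /*oh=*/6, /*ow=*/6,
//       /*s0=*/1, /*s1=*/1, /*p0=*/0, /*p1=*/0,
//       /*d0=*/1, /*d1=*/1, /*groups=*/1);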

// The quantized convolution kernel. in_scale and weight_scale are implicit in
// bias_scale, since it is a product of the two. The kernel will branch to
// quantized::conv1d or quantized::conv2d based on the dimensionality of the
// activation tensor.
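// Requantization sketch (hypothetical numbers, assuming kernels::quantize
// rounds val * inv_out_scale, adds the zero point, and saturates to the output
// type): with in_scale = 0.5 and weight_scale = 0.02, bias_scale = 0.01. If
// the integer-domain accumulator (bias included) is acc = 1234, the real value
// is 0.01 * 1234 = 12.34; with output_scale = 0.1 and output_zero_point = 3,
// the stored output is round(12.34 / 0.1) + 3 = 126.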
void quantized_conv_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    executorch::aten::IntArrayRef stride,
    executorch::aten::IntArrayRef padding,
    executorch::aten::IntArrayRef dilation,
    int64_t groups,
    int64_t in_zero_point,
    const Tensor& weight_zero_point,
    const Tensor& bias_scale,
    double output_scale,
    int64_t output_zero_point,
    const Tensor& out_multiplier,
    const Tensor& out_shift,
    bool channel_last,
    Tensor& out) {
  bool conv1d = input.dim() == 3;
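  // For the 1d case the activation is [n, c, w]; below it is treated as a 2d
  // activation with a unit height, i.e. [n, c, 1, w], and the weight
  // [oc, wc, ww] as [oc, wc, 1, ww].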
  // input = [n, c, h, w]
  const int n = input.size(0);
  const int c = input.size(1);
  const int h = conv1d ? 1 : input.size(2);
  const int w = conv1d ? input.size(2) : input.size(3);
  // weight = [oc, wc, wh, ww]
  const int oc = weight.size(0);
  const int wc = weight.size(1);
  const int wh = conv1d ? 1 : weight.size(2);
  const int ww = conv1d ? weight.size(2) : weight.size(3);
  // output = [n, oc, oh, ow]
  const int oh = conv1d ? 1 : out.size(2);
  const int ow = conv1d ? out.size(2) : out.size(3);

  // Bool flag to check if weight tensor is quantized per-tensor or
  // per-channel
  bool per_tensor_quantized = bias_scale.numel() == 1;

  if (out.scalar_type() == exec_aten::ScalarType::Byte) {
    conv2d_nchw_core_generic<uint8_t, uint8_t, int32_t, uint8_t, true>(
        input.const_data_ptr<uint8_t>(),
        weight.const_data_ptr<uint8_t>(),
        bias.const_data_ptr<int32_t>(),
        out.mutable_data_ptr<uint8_t>(),
        n,
        c,
        h,
        w,
        oc,
        wc,
        wh,
        ww,
        oh,
        ow,
        stride[0],
        stride[1],
        padding[0],
        padding[1],
        dilation[0],
        dilation[1],
        groups,
        in_zero_point,
        weight_zero_point.const_data_ptr<int32_t>(),
        bias_scale.const_data_ptr<float>(),
        output_scale,
        (uint8_t)output_zero_point,
        per_tensor_quantized);
  } else if (out.scalar_type() == exec_aten::ScalarType::Char) {
    conv2d_nchw_core_generic<int8_t, int8_t, int32_t, int8_t, true>(
        input.const_data_ptr<int8_t>(),
        weight.const_data_ptr<int8_t>(),
        bias.const_data_ptr<int32_t>(),
        out.mutable_data_ptr<int8_t>(),
        n,
        c,
        h,
        w,
        oc,
        wc,
        wh,
        ww,
        oh,
        ow,
        stride[0],
        stride[1],
        padding[0],
        padding[1],
        dilation[0],
        dilation[1],
        groups,
        in_zero_point,
        weight_zero_point.const_data_ptr<int32_t>(),
        bias_scale.const_data_ptr<float>(),
        output_scale,
        (int8_t)output_zero_point,
        per_tensor_quantized);
  } else {
    ET_CHECK_MSG(
        false,
        "Unhandled input dtype %hhd",
        static_cast<int8_t>(input.scalar_type()));
  }
}

}; // namespace native
}; // namespace reference
}; // namespace impl