1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14
15 #include <xnnpack.h>
16 #include <xnnpack/allocator.h>
17 #include <xnnpack/log.h>
18 #include <xnnpack/operator.h>
19 #include <xnnpack/params-init.h>
20 #include <xnnpack/params.h>
21
22
// Creates an Add operator that sums two quantized 8-bit (Q8) tensors laid out
// as NC (batch rows of `channels` elements).
//
// Parameters:
//   channels                  - elements per row; must be non-zero.
//   a_stride                  - element stride between rows of A; must be >= channels.
//   b_stride                  - element stride between rows of B; must be >= channels.
//   sum_stride                - element stride between rows of the output; must be >= channels.
//   a_zero_point, a_scale     - quantization of A; the scale must be positive and normal.
//   b_zero_point, b_scale     - quantization of B; the scale must be positive and normal.
//   sum_zero_point, sum_scale - quantization of the output; the scale must be positive and normal.
//   sum_min, sum_max          - output clamping range; sum_min must be strictly below sum_max.
//   flags                     - not read by this function.
//   add_op_out                - receives the new operator on success.
//
// Returns xnn_status_success on success. Otherwise returns:
//   - xnn_status_uninitialized if XNNPACK is not initialized;
//   - xnn_status_invalid_parameter for bad shape, scale or range arguments;
//   - xnn_status_unsupported_parameter when an input-to-output scale ratio
//     falls outside [2**-14, 2**8);
//   - xnn_status_out_of_memory if the descriptor cannot be allocated.
// *add_op_out is only written on success.
enum xnn_status xnn_create_add_nc_q8(
    size_t channels,
    size_t a_stride,
    size_t b_stride,
    size_t sum_stride,
    uint8_t a_zero_point,
    float a_scale,
    uint8_t b_zero_point,
    float b_scale,
    uint8_t sum_zero_point,
    float sum_scale,
    uint8_t sum_min,
    uint8_t sum_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  xnn_operator_t add_op = NULL;
  // `status` is advanced before each validation phase so that `goto error`
  // reports the category of the check that failed.
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Add operator: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (channels == 0) {
    xnn_log_error(
      "failed to create Add operator with %zu channels: number of channels must be non-zero", channels);
    goto error;
  }

  if (a_stride < channels) {
    xnn_log_error(
      "failed to create Add operator with A element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      a_stride, channels);
    goto error;
  }

  if (b_stride < channels) {
    xnn_log_error(
      "failed to create Add operator with B element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      b_stride, channels);
    goto error;
  }

  if (sum_stride < channels) {
    xnn_log_error(
      "failed to create Add operator with Sum element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      sum_stride, channels);
    goto error;
  }

  // isnormal() rejects zero, subnormals, infinities and NaN; the explicit
  // comparison additionally rejects negative scales.
  if (a_scale <= 0.0f || !isnormal(a_scale)) {
    xnn_log_error(
      "failed to create Add operator with %.7g A scale: scale must be finite, normalized, and positive", a_scale);
    goto error;
  }

  if (b_scale <= 0.0f || !isnormal(b_scale)) {
    xnn_log_error(
      "failed to create Add operator with %.7g B scale: scale must be finite, normalized, and positive", b_scale);
    goto error;
  }

  if (sum_scale <= 0.0f || !isnormal(sum_scale)) {
    xnn_log_error(
      "failed to create Add operator with %.7g output scale: scale must be finite, normalized, and positive",
      sum_scale);
    goto error;
  }

  if (sum_min >= sum_max) {
    xnn_log_error(
      "failed to create Add operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      sum_min, sum_max);
    goto error;
  }

  status = xnn_status_unsupported_parameter;

  // The requantization parameters are derived from these ratios; values
  // outside [2**-14, 2**8) are rejected as unsupported.
  const float a_output_scale = a_scale / sum_scale;
  if (a_output_scale < 0x1.0p-14f || a_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create Add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
      a_output_scale);
    goto error;
  }

  const float b_output_scale = b_scale / sum_scale;
  if (b_output_scale < 0x1.0p-14f || b_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create Add operator with %.7g B-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
      b_output_scale);
    goto error;
  }

  status = xnn_status_out_of_memory;

  add_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (add_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Add operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  add_op->channels = channels;
  add_op->input_pixel_stride = a_stride;
  add_op->input2_pixel_stride = b_stride;
  add_op->output_pixel_stride = sum_stride;
  // Reuse the ratios validated above rather than recomputing them.
  add_op->q8_add_params =
    xnn_init_q8_add_params(
      a_zero_point, b_zero_point, sum_zero_point,
      a_output_scale, b_output_scale,
      sum_min, sum_max);

  add_op->type = xnn_operator_type_add_nc_q8;
  add_op->ukernel.type = xnn_ukernel_type_add;

  // The operator must be set up (xnn_setup_add_nc_q8) before it can run.
  add_op->state = xnn_run_state_invalid;

  *add_op_out = add_op;
  return xnn_status_success;

error:
  // add_op is still NULL on every path except a post-allocation failure.
  xnn_delete_operator(add_op);
  return status;
}
153
// Creates an Add operator that sums two 32-bit float tensors laid out as NC
// (batch rows of `channels` elements), clamping the result to
// [sum_min, sum_max].
//
// Parameters:
//   channels         - elements per row; must be non-zero.
//   a_stride         - element stride between rows of A; must be >= channels.
//   b_stride         - element stride between rows of B; must be >= channels.
//   sum_stride       - element stride between rows of the output; must be >= channels.
//   sum_min, sum_max - output clamping bounds; neither may be NaN, and
//                      sum_min must be strictly below sum_max.
//   flags            - not read by this function.
//   add_op_out       - receives the new operator on success.
//
// Returns xnn_status_success on success. Otherwise returns:
//   - xnn_status_uninitialized if XNNPACK is not initialized;
//   - xnn_status_invalid_parameter for bad shape or range arguments;
//   - xnn_status_out_of_memory if the descriptor cannot be allocated.
// *add_op_out is only written on success.
enum xnn_status xnn_create_add_nc_f32(
    size_t channels,
    size_t a_stride,
    size_t b_stride,
    size_t sum_stride,
    float sum_min,
    float sum_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  xnn_operator_t add_op = NULL;
  // `status` is advanced before each validation phase so that `goto error`
  // reports the category of the check that failed.
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Add operator: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (channels == 0) {
    xnn_log_error(
      "failed to create Add operator with %zu channels: number of channels must be non-zero", channels);
    goto error;
  }

  if (a_stride < channels) {
    xnn_log_error(
      "failed to create Add operator with A element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      a_stride, channels);
    goto error;
  }

  if (b_stride < channels) {
    xnn_log_error(
      "failed to create Add operator with B element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      b_stride, channels);
    goto error;
  }

  if (sum_stride < channels) {
    xnn_log_error(
      "failed to create Add operator with Sum element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      sum_stride, channels);
    goto error;
  }

  // NaN bounds are rejected explicitly: any comparison with NaN is false, so
  // the ordering check below would not catch them.
  if (isnan(sum_min)) {
    xnn_log_error(
      "failed to create Add operator with NaN output lower bound: lower bound must be non-NaN");
    goto error;
  }

  if (isnan(sum_max)) {
    xnn_log_error(
      "failed to create Add operator with NaN output upper bound: upper bound must be non-NaN");
    goto error;
  }

  if (sum_min >= sum_max) {
    xnn_log_error(
      "failed to create Add operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      sum_min, sum_max);
    goto error;
  }

  status = xnn_status_out_of_memory;

  add_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (add_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Add operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  add_op->channels = channels;
  add_op->input_pixel_stride = a_stride;
  add_op->input2_pixel_stride = b_stride;
  add_op->output_pixel_stride = sum_stride;
  add_op->f32_output_params = xnn_init_f32_output_params(sum_min, sum_max);

  add_op->type = xnn_operator_type_add_nc_f32;
  add_op->ukernel.type = xnn_ukernel_type_add;

  // The operator must be set up (xnn_setup_add_nc_f32) before it can run.
  add_op->state = xnn_run_state_invalid;

  *add_op_out = add_op;
  return xnn_status_success;

error:
  // add_op is still NULL on every path except a post-allocation failure.
  xnn_delete_operator(add_op);
  return status;
}
249
// Binds input/output buffers to a Q8 Add operator and chooses how the work is
// partitioned.
//
// When every row is densely packed (all strides equal `channels`), or there is
// only one row, the tensors are treated as one flat byte buffer processed in
// 4096-byte tiles. Otherwise one task is issued per batch row using the
// per-tensor strides.
//
// `threadpool` is not read here. Returns xnn_status_invalid_parameter on an
// operator type mismatch, xnn_status_uninitialized if XNNPACK is not
// initialized, and xnn_status_success otherwise (including batch_size == 0,
// which marks the operator to be skipped).
enum xnn_status xnn_setup_add_nc_q8(
    xnn_operator_t add_op,
    size_t batch_size,
    const uint8_t* a,
    const uint8_t* b,
    uint8_t* sum,
    pthreadpool_t threadpool)
{
  if (add_op->type != xnn_operator_type_add_nc_q8) {
    xnn_log_error("failed to setup Add (NC, Q8) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }
  // Any earlier setup is void until this call completes successfully.
  add_op->state = xnn_run_state_invalid;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to setup Add operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    // Nothing to add: the run step is told to skip this operator.
    add_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  const size_t row_elements = add_op->channels;
  const size_t a_row_stride = add_op->input_pixel_stride;
  const size_t b_row_stride = add_op->input2_pixel_stride;
  const size_t sum_row_stride = add_op->output_pixel_stride;
  const size_t element_bytes = sizeof(uint8_t);

  const int rows_are_packed =
    a_row_stride == row_elements &&
    b_row_stride == row_elements &&
    sum_row_stride == row_elements;

  if (batch_size == 1 || rows_are_packed) {
    // Flat layout: range and tile are measured in bytes.
    const size_t tile_bytes = 4096;
    add_op->context.add_contiguous = (struct add_contiguous_context) {
      .ukernel = xnn_params.q8.vadd,
      .params.q8 = add_op->q8_add_params,
      .a = a,
      .b = b,
      .y = sum,
    };
    add_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_add_contiguous;
    add_op->compute.range[0] = batch_size * row_elements * element_bytes;
    add_op->compute.tile[0] = tile_bytes;
  } else {
    // Padded rows: one task per row; strides and row length are in bytes.
    add_op->context.add_strided = (struct add_strided_context) {
      .ukernel = xnn_params.q8.vadd,
      .params.q8 = add_op->q8_add_params,
      .n = row_elements * element_bytes,
      .a = a,
      .a_stride = a_row_stride * element_bytes,
      .b = b,
      .b_stride = b_row_stride * element_bytes,
      .y = sum,
      .y_stride = sum_row_stride * element_bytes,
    };
    add_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_add_strided;
    add_op->compute.range[0] = batch_size;
    add_op->compute.tile[0] = 1;
  }
  // Both partitionings run as a 1D tiled parallel loop.
  add_op->compute.type = xnn_parallelization_type_1d_tile_1d;
  add_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
312
// Binds input/output buffers to an F32 Add operator and chooses how the work
// is partitioned.
//
// When every row is densely packed (all strides equal `channels`), or there is
// only one row, the tensors are treated as one flat buffer processed in
// 4096-byte tiles. Otherwise one task is issued per batch row using the
// per-tensor strides.
//
// `threadpool` is not read here. Returns xnn_status_invalid_parameter on an
// operator type mismatch, xnn_status_uninitialized if XNNPACK is not
// initialized, and xnn_status_success otherwise (including batch_size == 0,
// which marks the operator to be skipped).
enum xnn_status xnn_setup_add_nc_f32(
    xnn_operator_t add_op,
    size_t batch_size,
    const float* a,
    const float* b,
    float* sum,
    pthreadpool_t threadpool)
{
  if (add_op->type != xnn_operator_type_add_nc_f32) {
    xnn_log_error("failed to setup Add (NC, F32) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }
  // Any earlier setup is void until this call completes successfully.
  add_op->state = xnn_run_state_invalid;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to setup Add operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    // Nothing to add: the run step is told to skip this operator.
    add_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  const size_t row_elements = add_op->channels;
  const size_t a_row_stride = add_op->input_pixel_stride;
  const size_t b_row_stride = add_op->input2_pixel_stride;
  const size_t sum_row_stride = add_op->output_pixel_stride;
  const size_t element_bytes = sizeof(float);

  const int rows_are_packed =
    a_row_stride == row_elements &&
    b_row_stride == row_elements &&
    sum_row_stride == row_elements;

  if (batch_size == 1 || rows_are_packed) {
    // Flat layout: range and tile are measured in bytes.
    const size_t tile_bytes = 4096;
    add_op->context.add_contiguous = (struct add_contiguous_context) {
      .ukernel = xnn_params.f32.vadd.op_ukernel,
      .params.f32 = add_op->f32_output_params,
      .a = a,
      .b = b,
      .y = sum,
    };
    add_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_add_contiguous;
    add_op->compute.range[0] = batch_size * row_elements * element_bytes;
    add_op->compute.tile[0] = tile_bytes;
  } else {
    // Padded rows: one task per row; strides and row length are in bytes.
    add_op->context.add_strided = (struct add_strided_context) {
      .ukernel = xnn_params.f32.vadd.op_ukernel,
      .params.f32 = add_op->f32_output_params,
      .n = row_elements * element_bytes,
      .a = a,
      .a_stride = a_row_stride * element_bytes,
      .b = b,
      .b_stride = b_row_stride * element_bytes,
      .y = sum,
      .y_stride = sum_row_stride * element_bytes,
    };
    add_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_add_strided;
    add_op->compute.range[0] = batch_size;
    add_op->compute.tile[0] = 1;
  }
  // Both partitionings run as a 1D tiled parallel loop.
  add_op->compute.type = xnn_parallelization_type_1d_tile_1d;
  add_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
375