// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>

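// Creates the XNNPACK operator object for a Global Average Pooling 2D node once
// tensor shapes are known, picking the NCW or NWC operator variant based on the
// input layout and the node's compute type.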
static enum xnn_status create_global_average_pooling_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  const size_t num_input_dims = values[input_id].shape.num_dims;
  assert(num_input_dims >= 1);
  const size_t channel_dim = values[input_id].shape.dim[num_input_dims - 1];

  enum xnn_status status;
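  // For NCHW inputs only FP32 is supported and the NCW operator variant is
  // created; otherwise the NWC variant matching the compute type is created.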
  if (values[node->inputs[0]].layout == xnn_layout_type_nchw) {
    assert(node->compute_type == xnn_compute_type_fp32);
    status = xnn_create_global_average_pooling_ncw_f32(
      channel_dim /* channels */,
      node->activation.output_min,
      node->activation.output_max,
      node->flags,
      &opdata->operator_object);
  } else {
    assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
    assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
    switch (node->compute_type) {
      case xnn_compute_type_fp32:
        status = xnn_create_global_average_pooling_nwc_f32(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          node->activation.output_min,
          node->activation.output_max,
          node->flags,
          &opdata->operator_object);
        break;
#ifndef XNN_NO_F16_OPERATORS
      case xnn_compute_type_fp16:
        status = xnn_create_global_average_pooling_nwc_f16(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          node->activation.output_min,
          node->activation.output_max,
          node->flags,
          &opdata->operator_object);
        break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
      case xnn_compute_type_qs8:
      {
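        // Convert the float output clamping bounds into the quantized domain of
        // the output tensor (divide by scale, add the zero point) and saturate
        // to the int8 range [-128, 127].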
        const float output_scale = values[output_id].quantization.scale;
        const int32_t output_zero_point = values[output_id].quantization.zero_point;
        const int8_t output_min =
          (int8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, -128.0f), 127.0f));
        const int8_t output_max =
          (int8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, -128.0f), 127.0f));
        status = xnn_create_global_average_pooling_nwc_qs8(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          (int8_t) values[input_id].quantization.zero_point, values[input_id].quantization.scale,
          (int8_t) values[output_id].quantization.zero_point, values[output_id].quantization.scale,
          output_min,
          output_max,
          node->flags,
          &opdata->operator_object);
        break;
      }
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
      case xnn_compute_type_qu8:
      {
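        // Same conversion as the QS8 case above, but saturating to the uint8
        // range [0, 255].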
        const float output_scale = values[output_id].quantization.scale;
        const int32_t output_zero_point = values[output_id].quantization.zero_point;
        const uint8_t output_min =
          (uint8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, 0.0f), 255.0f));
        const uint8_t output_max =
          (uint8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, 0.0f), 255.0f));
        status = xnn_create_global_average_pooling_nwc_qu8(
          channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
          (uint8_t) values[input_id].quantization.zero_point, values[input_id].quantization.scale,
          (uint8_t) values[output_id].quantization.zero_point, values[output_id].quantization.scale,
          output_min,
          output_max,
          node->flags,
          &opdata->operator_object);
        break;
      }
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      default:
        XNN_UNREACHABLE;
    }
  }
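  // On success, cache the batch size, the pooled spatial extent (H * W for an
  // N x H x W x C input), and the tensor IDs for use at setup time.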
  if (status == xnn_status_success) {
    opdata->batch_size = values[input_id].shape.dim[0];
    opdata->input_width = values[input_id].shape.dim[1] * values[input_id].shape.dim[2];
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

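// Binds the input and output blobs to the previously created operator object
// and dispatches to the setup routine that matches the operator type.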
static enum xnn_status setup_global_average_pooling_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_object->type) {
    case xnn_operator_type_global_average_pooling_ncw_f32:
      return xnn_setup_global_average_pooling_ncw_f32(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_global_average_pooling_nwc_f32:
      return xnn_setup_global_average_pooling_nwc_f32(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_f16:
      return xnn_setup_global_average_pooling_nwc_f16(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_qs8:
      return xnn_setup_global_average_pooling_nwc_qs8(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_qu8:
      return xnn_setup_global_average_pooling_nwc_qu8(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

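// Validates the clamping range and the input/output tensors, then records a
// Global Average Pooling 2D node in the subgraph; the operator object itself is
// created later through the node's create callback.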
enum xnn_status xnn_define_global_average_pooling_2d(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d));
    return xnn_status_uninitialized;
  }

  if (isnan(output_min)) {
    xnn_log_error(
      "failed to define %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to define %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to define %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

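  // The node's compute type is derived from the output datatype; the input
  // datatype is required to match below.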
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (input_value->datatype != output_value->datatype) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_global_average_pooling_2d;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_global_average_pooling_operator;
  node->setup = setup_global_average_pooling_operator;

  return xnn_status_success;
}