// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
14 
15 
create_global_average_pooling_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)16 static enum xnn_status create_global_average_pooling_operator(
17   const struct xnn_node* node,
18   const struct xnn_value* values,
19   size_t num_values,
20   struct xnn_operator_data* opdata)
21 {
22   assert(node->num_inputs == 1);
23   const uint32_t input_id = node->inputs[0];
24   assert(input_id != XNN_INVALID_VALUE_ID);
25   assert(input_id < num_values);
26 
27   assert(node->num_outputs == 1);
28   const uint32_t output_id = node->outputs[0];
29   assert(output_id != XNN_INVALID_VALUE_ID);
30   assert(output_id < num_values);
31 
32   const size_t num_input_dims = values[input_id].shape.num_dims;
33   assert(num_input_dims >= 1);
34   const size_t channel_dim = values[input_id].shape.dim[num_input_dims - 1];
35 
36   enum xnn_status status;
37   if (values[node->inputs[0]].layout == xnn_layout_type_nchw) {
38     assert(node->compute_type == xnn_compute_type_fp32);
39     status = xnn_create_global_average_pooling_ncw_f32(
40       channel_dim /* channels */,
41       node->activation.output_min,
42       node->activation.output_max,
43       node->flags,
44       &opdata->operator_object);
45   } else {
46     assert(values[node->inputs[0]].layout == xnn_layout_type_nhwc);
47     assert(values[node->outputs[0]].layout == xnn_layout_type_nhwc);
48     switch (node->compute_type) {
49       case xnn_compute_type_fp32:
50         status = xnn_create_global_average_pooling_nwc_f32(
51           channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
52           node->activation.output_min,
53           node->activation.output_max,
54           node->flags,
55           &opdata->operator_object);
56         break;
57 #ifndef XNN_NO_F16_OPERATORS
58       case xnn_compute_type_fp16:
59         status = xnn_create_global_average_pooling_nwc_f16(
60           channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
61           node->activation.output_min,
62           node->activation.output_max,
63           node->flags,
64           &opdata->operator_object);
65         break;
66 #endif  // !defined(XNN_NO_F16_OPERATORS)
67 #ifndef XNN_NO_QS8_OPERATORS
68       case xnn_compute_type_qs8:
69       {
70         const float output_scale = values[output_id].quantization.scale;
71         const int32_t output_zero_point = values[output_id].quantization.zero_point;
72         const int8_t output_min =
73           (int8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, -128.0f), 127.0f));
74         const int8_t output_max =
75           (int8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, -128.0f), 127.0f));
76         status = xnn_create_global_average_pooling_nwc_qs8(
77           channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
78           (int8_t) values[input_id].quantization.zero_point, values[input_id].quantization.scale,
79           (int8_t) values[output_id].quantization.zero_point, values[output_id].quantization.scale,
80           output_min,
81           output_max,
82           node->flags,
83           &opdata->operator_object);
84         break;
85       }
86 #endif  // !defined(XNN_NO_QS8_OPERATORS)
87 #ifndef XNN_NO_QU8_OPERATORS
88       case xnn_compute_type_qu8:
89       {
90         const float output_scale = values[output_id].quantization.scale;
91         const int32_t output_zero_point = values[output_id].quantization.zero_point;
92         const uint8_t output_min =
93           (uint8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, 0.0f), 255.0f));
94         const uint8_t output_max =
95           (uint8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, 0.0f), 255.0f));
96         status = xnn_create_global_average_pooling_nwc_qu8(
97           channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
98           (uint8_t) values[input_id].quantization.zero_point, values[input_id].quantization.scale,
99           (uint8_t) values[output_id].quantization.zero_point, values[output_id].quantization.scale,
100           output_min,
101           output_max,
102           node->flags,
103           &opdata->operator_object);
104         break;
105       }
106 #endif  // !defined(XNN_NO_QU8_OPERATORS)
107       default:
108         XNN_UNREACHABLE;
109     }
110   }
111   if (status == xnn_status_success) {
112     opdata->batch_size = values[input_id].shape.dim[0];
113     opdata->input_width = values[input_id].shape.dim[1] * values[input_id].shape.dim[2];
114     opdata->inputs[0] = input_id;
115     opdata->outputs[0] = output_id;
116   }
117   return status;
118 }
119 
// Binds runtime tensor buffers to the previously created global average
// pooling operator and invokes the setup routine matching its concrete type.
//
// opdata carries the batch size, input width, and blob IDs recorded by
// create_global_average_pooling_operator(); blobs supplies the actual data
// pointers for this execution. Returns the status of the dispatched
// xnn_setup_*() call.
static enum xnn_status setup_global_average_pooling_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  // Every case returns directly, so no break statements are needed
  // (the originals after each return were unreachable dead code).
  switch (opdata->operator_object->type) {
    case xnn_operator_type_global_average_pooling_ncw_f32:
      return xnn_setup_global_average_pooling_ncw_f32(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_global_average_pooling_nwc_f32:
      return xnn_setup_global_average_pooling_nwc_f32(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_f16:
      return xnn_setup_global_average_pooling_nwc_f16(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_qs8:
      return xnn_setup_global_average_pooling_nwc_qs8(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_global_average_pooling_nwc_qu8:
      return xnn_setup_global_average_pooling_nwc_qu8(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
198 
// Defines a Global Average Pooling 2D node in a subgraph.
//
// Validates every argument before mutating the subgraph, in a fixed order
// (each check logs a specific error and returns early):
//   1. XNNPACK initialized;
//   2. output_min/output_max are non-NaN and form a proper range (min < max);
//   3. input_id is a valid dense-tensor Value of a supported datatype
//      (fp32, and qint8/quint8 when the corresponding operators are built);
//   4. output_id likewise, with its datatype determining the compute type;
//   5. input and output datatypes match.
// Only then is a node allocated and populated; its create/setup callbacks
// are the static functions defined above in this file.
//
// Returns xnn_status_success, or xnn_status_uninitialized /
// xnn_status_invalid_parameter / xnn_status_out_of_memory on failure.
enum xnn_status xnn_define_global_average_pooling_2d(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d));
    return xnn_status_uninitialized;
  }

  if (isnan(output_min)) {
    xnn_log_error(
      "failed to define %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to define %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d));
    return xnn_status_invalid_parameter;
  }

  // Note: equal bounds are rejected as well (strict <).
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to define %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  // Supported input datatypes; all accepted cases fall through to the break.
  // fp16 is absent here: the fp16 compute path is reached via fp32 values
  // elsewhere, not via an fp16-typed input Value.
  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  // The output datatype selects the node's compute type.
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (input_value->datatype != output_value->datatype) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_global_average_pooling_2d), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_global_average_pooling_2d;
  node->compute_type = compute_type;
  // Bounds are stored as floats; quantized clamping limits are derived from
  // them at operator-creation time.
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_global_average_pooling_operator;
  node->setup = setup_global_average_pooling_operator;

  return xnn_status_success;
}
334