• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
17 
18 
create_constant_pad_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)19 static enum xnn_status create_constant_pad_operator(
20   const struct xnn_node* node,
21   const struct xnn_value* values,
22   size_t num_values,
23   struct xnn_operator_data* opdata)
24 {
25   assert(node->num_inputs == 1);
26   const uint32_t input_id = node->inputs[0];
27   assert(input_id != XNN_INVALID_VALUE_ID);
28   assert(input_id < num_values);
29 
30   assert(node->num_outputs == 1);
31   const uint32_t output_id = node->outputs[0];
32   assert(output_id != XNN_INVALID_VALUE_ID);
33   assert(output_id < num_values);
34 
35   enum xnn_status status;
36   switch (node->compute_type) {
37 #ifndef XNN_NO_F16_OPERATORS
38     case xnn_compute_type_fp16:
39       status = xnn_create_constant_pad_nd_x16(
40         &node->params.static_pad.padding_value,
41         node->flags,
42         &opdata->operator_object);
43       break;
44 #endif  // !defined(XNN_NO_F16_OPERATORS)
45     case xnn_compute_type_fp32:
46       status = xnn_create_constant_pad_nd_x32(
47         &node->params.static_pad.padding_value,
48         node->flags,
49         &opdata->operator_object);
50       break;
51 #ifndef XNN_NO_QS8_OPERATORS
52     case xnn_compute_type_qs8:
53 #endif  // !defined(XNN_NO_QS8_OPERATORS)
54 #ifndef XNN_NO_QU8_OPERATORS
55     case xnn_compute_type_qu8:
56 #endif  // !defined(XNN_NO_QU8_OPERATORS)
57 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
58       status = xnn_create_constant_pad_nd_x8(
59         &node->params.static_pad.padding_value,
60         node->flags,
61         &opdata->operator_object);
62       break;
63 #endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
64     default:
65       XNN_UNREACHABLE;
66   }
67   if (status == xnn_status_success) {
68     opdata->shape1 = values[input_id].shape;
69     memcpy(opdata->pre_paddings, node->params.static_pad.pre_paddings, sizeof(size_t) * XNN_MAX_TENSOR_DIMS);
70     memcpy(opdata->post_paddings, node->params.static_pad.post_paddings, sizeof(size_t) * XNN_MAX_TENSOR_DIMS);
71     opdata->inputs[0] = input_id;
72     opdata->outputs[0] = output_id;
73   }
74   return status;
75 }
76 
// Binds the input and output buffers of a constant-pad node to the operator
// previously created by create_constant_pad_operator().
//
// Dispatches on the operator type chosen at creation time and forwards the
// input shape and the pre/post padding arrays recorded in `opdata`, together
// with the blob data pointers and `threadpool`, to the matching
// xnn_setup_constant_pad_nd_x* function. Returns that function's status.
static enum xnn_status setup_constant_pad_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_object->type) {
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_constant_pad_nd_x8:
      return xnn_setup_constant_pad_nd_x8(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->pre_paddings,
        opdata->post_paddings,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_constant_pad_nd_x16:
      return xnn_setup_constant_pad_nd_x16(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->pre_paddings,
        opdata->post_paddings,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_constant_pad_nd_x32:
      return xnn_setup_constant_pad_nd_x32(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->pre_paddings,
        opdata->post_paddings,
        input_data,
        output_data,
        threadpool);
    default:
      // NOTE(review): every case above returns, so this function relies on
      // XNN_UNREACHABLE never returning. If the macro expands to a no-op on
      // some compiler, control would fall off the end of this non-void
      // function; confirm its definition.
      XNN_UNREACHABLE;
  }
}
141 
// Appends a static_constant_pad node to `subgraph`.
//
// The node pads tensor `input_id` with `padding_value` and writes the result to
// tensor `output_id`. For dimension i, pre_paddings[i] elements are inserted
// before the input data and post_paddings[i] after it.
//
// Parameters:
//   subgraph       target subgraph.
//   pre_paddings   array with one entry per input dimension (read for
//                  input->shape.num_dims entries).
//   post_paddings  array with one entry per input dimension (read for
//                  input->shape.num_dims entries).
//   padding_value  fill value in real (dequantized) units. For fp32 it is stored
//                  as its bit pattern. For qint8/quint8 it is quantized with the
//                  output scale and zero point, then clamped to the type's range.
//   input_id       ID of a dense-tensor Value (fp32, qint8 or quint8).
//   output_id      ID of a dense-tensor Value with the same datatype and, for
//                  quantized types, the same scale and zero point as the input.
//   flags          stored on the node unchanged.
//
// Returns:
//   xnn_status_success on success.
//   xnn_status_uninitialized if XNNPACK has not been initialized.
//   xnn_status_invalid_parameter if validation of the input/output Values
//   fails.
//   xnn_status_out_of_memory if the node cannot be allocated.
enum xnn_status xnn_define_static_constant_pad(
  xnn_subgraph_t subgraph,
  const size_t* pre_paddings,
  const size_t* post_paddings,
  float padding_value,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad));
    return xnn_status_uninitialized;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  // The output datatype selects the compute type of the node.
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (input_value->datatype != output_value->datatype) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
  // Padding only moves bytes, so quantized input and output must share the
  // same quantization parameters.
  if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  const size_t num_dims = input_value->shape.num_dims;
  // memcpy requires valid pointers even for a zero-byte copy. Skip the copies
  // for 0-D inputs, where the padding arrays may be NULL.
  if (num_dims != 0) {
    memcpy(node->params.static_pad.pre_paddings, pre_paddings, num_dims * sizeof(size_t));
    memcpy(node->params.static_pad.post_paddings, post_paddings, num_dims * sizeof(size_t));
  }

  // Store the padding value in the output's element representation.
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      node->params.static_pad.padding_value = fp32_to_bits(padding_value);
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
    {
      const float output_scale = output_value->quantization.scale;
      const int32_t output_zero_point = output_value->quantization.zero_point;
      node->params.static_pad.padding_value = (int8_t)
        lrintf(fminf(fmaxf(padding_value / output_scale + (float) output_zero_point, -128.0f), 127.0f));
      break;
    }
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
    {
      const float output_scale = output_value->quantization.scale;
      const int32_t output_zero_point = output_value->quantization.zero_point;
      node->params.static_pad.padding_value = (uint8_t)
        lrintf(fminf(fmaxf(padding_value / output_scale + (float) output_zero_point, 0.0f), 255.0f));
      break;
    }
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      // Datatype was validated above, so no other value can reach here.
      XNN_UNREACHABLE;
  }

  node->type = xnn_node_type_static_constant_pad;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_constant_pad_operator;
  node->setup = setup_constant_pad_operator;

  return xnn_status_success;
}
307