// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6 #include <math.h>
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <string.h>
10
11 #include <fp16.h>
12
13 #include <xnnpack.h>
14 #include <xnnpack/log.h>
15 #include <xnnpack/params.h>
16 #include <xnnpack/subgraph.h>
17
18
create_constant_pad_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)19 static enum xnn_status create_constant_pad_operator(
20 const struct xnn_node* node,
21 const struct xnn_value* values,
22 size_t num_values,
23 struct xnn_operator_data* opdata)
24 {
25 assert(node->num_inputs == 1);
26 const uint32_t input_id = node->inputs[0];
27 assert(input_id != XNN_INVALID_VALUE_ID);
28 assert(input_id < num_values);
29
30 assert(node->num_outputs == 1);
31 const uint32_t output_id = node->outputs[0];
32 assert(output_id != XNN_INVALID_VALUE_ID);
33 assert(output_id < num_values);
34
35 enum xnn_status status;
36 switch (node->compute_type) {
37 #ifndef XNN_NO_F16_OPERATORS
38 case xnn_compute_type_fp16:
39 status = xnn_create_constant_pad_nd_x16(
40 &node->params.static_pad.padding_value,
41 node->flags,
42 &opdata->operator_object);
43 break;
44 #endif // !defined(XNN_NO_F16_OPERATORS)
45 case xnn_compute_type_fp32:
46 status = xnn_create_constant_pad_nd_x32(
47 &node->params.static_pad.padding_value,
48 node->flags,
49 &opdata->operator_object);
50 break;
51 #ifndef XNN_NO_QS8_OPERATORS
52 case xnn_compute_type_qs8:
53 #endif // !defined(XNN_NO_QS8_OPERATORS)
54 #ifndef XNN_NO_QU8_OPERATORS
55 case xnn_compute_type_qu8:
56 #endif // !defined(XNN_NO_QU8_OPERATORS)
57 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
58 status = xnn_create_constant_pad_nd_x8(
59 &node->params.static_pad.padding_value,
60 node->flags,
61 &opdata->operator_object);
62 break;
63 #endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
64 default:
65 XNN_UNREACHABLE;
66 }
67 if (status == xnn_status_success) {
68 opdata->shape1 = values[input_id].shape;
69 memcpy(opdata->pre_paddings, node->params.static_pad.pre_paddings, sizeof(size_t) * XNN_MAX_TENSOR_DIMS);
70 memcpy(opdata->post_paddings, node->params.static_pad.post_paddings, sizeof(size_t) * XNN_MAX_TENSOR_DIMS);
71 opdata->inputs[0] = input_id;
72 opdata->outputs[0] = output_id;
73 }
74 return status;
75 }
76
// Binds the runtime input/output blobs to the previously created constant-pad
// operator and dispatches to the element-size-specific setup function.
static enum xnn_status setup_constant_pad_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_object->type) {
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_constant_pad_nd_x8:
      return xnn_setup_constant_pad_nd_x8(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->pre_paddings,
        opdata->post_paddings,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_constant_pad_nd_x16:
      return xnn_setup_constant_pad_nd_x16(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->pre_paddings,
        opdata->post_paddings,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_constant_pad_nd_x32:
      return xnn_setup_constant_pad_nd_x32(
        opdata->operator_object,
        opdata->shape1.num_dims,
        opdata->shape1.dim,
        opdata->pre_paddings,
        opdata->post_paddings,
        input_data,
        output_data,
        threadpool);
    default:
      XNN_UNREACHABLE;
  }
  // Not reached when XNN_UNREACHABLE traps; guarantees every control path
  // returns a value when the macro compiles to a no-op (falling off the end
  // of a non-void function whose result is used is undefined behavior).
  return xnn_status_invalid_state;
}
141
// Defines a static constant-pad Node in the subgraph: validates the input and
// output Values (type, datatype, and — for quantized tensors — matching
// quantization parameters), converts the float padding value into the
// output datatype's bit representation, and records the paddings on the Node.
enum xnn_status xnn_define_static_constant_pad(
  xnn_subgraph_t subgraph,
  const size_t* pre_paddings,
  const size_t* post_paddings,
  float padding_value,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad));
    return xnn_status_uninitialized;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (input_value->datatype != output_value->datatype) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
  // Pad doesn't requantize, so quantized input and output must share the same
  // quantization parameters.
  if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      // Fixed copy-paste bug: this message previously said "zero point" while
      // reporting mismatching scales.
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  const size_t num_dims = subgraph->values[input_id].shape.num_dims;
  memcpy(&node->params.static_pad.pre_paddings, pre_paddings, num_dims * sizeof(size_t));
  memcpy(&node->params.static_pad.post_paddings, post_paddings, num_dims * sizeof(size_t));
  // Store the padding value in the output datatype's bit representation.
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      node->params.static_pad.padding_value = fp32_to_bits(padding_value);
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
    {
      // Quantize the padding value with the output's scale/zero-point and
      // saturate to the int8 range.
      const float output_scale = output_value->quantization.scale;
      const int32_t output_zero_point = output_value->quantization.zero_point;
      node->params.static_pad.padding_value = (int8_t)
        lrintf(fminf(fmaxf(padding_value / output_scale + (float) output_zero_point, -128.0f), 127.0f));
      break;
    }
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
    {
      // Quantize the padding value with the output's scale/zero-point and
      // saturate to the uint8 range.
      const float output_scale = output_value->quantization.scale;
      const int32_t output_zero_point = output_value->quantization.zero_point;
      node->params.static_pad.padding_value = (uint8_t)
        lrintf(fminf(fmaxf(padding_value / output_scale + (float) output_zero_point, 0.0f), 255.0f));
      break;
    }
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }

  node->type = xnn_node_type_static_constant_pad;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_constant_pad_operator;
  node->setup = setup_constant_pad_operator;

  return xnn_status_success;
}
307