// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>

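// Converts a subgraph Clamp node into a concrete NC-layout clamp operator,
// dispatching on the node's compute type (FP32, QS8, or QU8).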
static enum xnn_status create_clamp_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

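  // The last input dimension is treated as channels; all remaining dimensions
  // are folded into the batch size once the operator is created successfully.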
  const size_t num_input_dims = values[input_id].shape.num_dims;
  const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];

  enum xnn_status status;
  switch (node->compute_type) {
    case xnn_compute_type_fp32:
      status = xnn_create_clamp_nc_f32(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_object);
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_compute_type_qs8:
    {
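      // Convert the floating-point clamp bounds to the quantized domain:
      // q = round(bound / scale + zero_point), saturated to the representable
      // range of the output datatype ([-128, 127] for signed 8-bit,
      // [0, 255] for unsigned 8-bit in the case below).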
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const int8_t output_min =
        (int8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, -128.0f), 127.0f));
      const int8_t output_max =
        (int8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, -128.0f), 127.0f));
      status = xnn_create_clamp_nc_s8(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        output_min,
        output_max,
        node->flags,
        &opdata->operator_object);
      break;
    }
#endif // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_compute_type_qu8:
    {
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const uint8_t output_min =
        (uint8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, 0.0f), 255.0f));
      const uint8_t output_max =
        (uint8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, 0.0f), 255.0f));
      status = xnn_create_clamp_nc_u8(
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        output_min,
        output_max,
        node->flags,
        &opdata->operator_object);
      break;
    }
#endif // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
  if (status == xnn_status_success) {
    opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

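// Binds the input and output blobs to the previously created clamp operator
// and sets it up for the batch size recorded at creation time.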
static enum xnn_status setup_clamp_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_object->type) {
    case xnn_operator_type_clamp_nc_f32:
      return xnn_setup_clamp_nc_f32(
        opdata->operator_object,
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_S8_OPERATORS
    case xnn_operator_type_clamp_nc_s8:
      return xnn_setup_clamp_nc_s8(
        opdata->operator_object,
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_operator_type_clamp_nc_u8:
      return xnn_setup_clamp_nc_u8(
        opdata->operator_object,
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

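// Defines a Clamp node in a subgraph: validates the input and output Values
// (dense tensors of matching, supported datatypes and, for quantized
// datatypes, matching quantization parameters), then records the node along
// with its create/setup callbacks. For example, a ReLU6-style clamp on an
// FP32 Value could be defined as (hypothetical Value IDs):
//   xnn_define_clamp(subgraph, 0.0f, 6.0f, input_id, output_id, 0 /* flags */);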
enum xnn_status xnn_define_clamp(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_clamp));
    return xnn_status_uninitialized;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_clamp), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_clamp), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
#endif // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
#endif // !defined(XNN_NO_U8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_clamp), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_clamp), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif // !defined(XNN_NO_U8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_clamp), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (input_value->datatype != output_value->datatype) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

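  // The quantized clamp operators compute directly on quantized values, so the
  // input and output are required to share scale and zero point; mismatching
  // quantization parameters are rejected here.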
#if !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)
  if (compute_type == xnn_compute_type_qs8 || compute_type == xnn_compute_type_qu8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%" PRId32 ") and output (%" PRId32 ")",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif // !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_clamp;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_clamp_operator;
  node->setup = setup_clamp_operator;

  return xnn_status_success;
}