• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
14 
15 
create_clamp_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)16 static enum xnn_status create_clamp_operator(
17   const struct xnn_node* node,
18   const struct xnn_value* values,
19   size_t num_values,
20   struct xnn_operator_data* opdata)
21 {
22   assert(node->num_inputs == 1);
23   const uint32_t input_id = node->inputs[0];
24   assert(input_id != XNN_INVALID_VALUE_ID);
25   assert(input_id < num_values);
26 
27   assert(node->num_outputs == 1);
28   const uint32_t output_id = node->outputs[0];
29   assert(output_id != XNN_INVALID_VALUE_ID);
30   assert(output_id < num_values);
31 
32   const size_t num_input_dims = values[input_id].shape.num_dims;
33   const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];
34 
35   enum xnn_status status;
36   switch (node->compute_type) {
37     case xnn_compute_type_fp32:
38       status = xnn_create_clamp_nc_f32(
39         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
40         node->activation.output_min,
41         node->activation.output_max,
42         node->flags,
43         &opdata->operator_object);
44       break;
45 #ifndef XNN_NO_S8_OPERATORS
46     case xnn_compute_type_qs8:
47     {
48       const float output_scale = values[output_id].quantization.scale;
49       const int32_t output_zero_point = values[output_id].quantization.zero_point;
50       const int8_t output_min =
51         (int8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, -128.0f), 127.0f));
52       const int8_t output_max =
53         (int8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, -128.0f), 127.0f));
54       status = xnn_create_clamp_nc_s8(
55         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
56         output_min,
57         output_max,
58         node->flags,
59         &opdata->operator_object);
60       break;
61     }
62 #endif  // !defined(XNN_NO_S8_OPERATORS)
63 #ifndef XNN_NO_U8_OPERATORS
64     case xnn_compute_type_qu8:
65     {
66       const float output_scale = values[output_id].quantization.scale;
67       const int32_t output_zero_point = values[output_id].quantization.zero_point;
68       const uint8_t output_min =
69         (uint8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, 0.0f), 255.0f));
70       const uint8_t output_max =
71         (uint8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, 0.0f), 255.0f));
72       status = xnn_create_clamp_nc_u8(
73         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
74         output_min,
75         output_max,
76         node->flags,
77         &opdata->operator_object);
78       break;
79     }
80 #endif  // !defined(XNN_NO_U8_OPERATORS)
81     default:
82       XNN_UNREACHABLE;
83   }
84   if (status == xnn_status_success) {
85     opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
86     opdata->inputs[0] = input_id;
87     opdata->outputs[0] = output_id;
88   }
89   return status;
90 }
91 
// Binds runtime tensor pointers to a previously created clamp operator and
// dispatches to the setup routine matching the operator's type.
static enum xnn_status setup_clamp_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input_data = blobs[input_id].data;
  assert(input_data != NULL);

  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  switch (opdata->operator_object->type) {
    case xnn_operator_type_clamp_nc_f32:
      return xnn_setup_clamp_nc_f32(
        opdata->operator_object,
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_S8_OPERATORS
    case xnn_operator_type_clamp_nc_s8:
      return xnn_setup_clamp_nc_s8(
        opdata->operator_object,
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_operator_type_clamp_nc_u8:
      return xnn_setup_clamp_nc_u8(
        opdata->operator_object,
        opdata->batch_size,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
145 
// Defines a Clamp node in a subgraph:
//   output[i] = min(max(input[i], output_min), output_max)
//
// Validates that XNNPACK is initialized, that input_id and output_id name
// dense tensors of a supported datatype, that input and output datatypes
// match, and — for quantized datatypes — that input and output share the
// same quantization parameters (clamp must not rescale values).
//
// Returns xnn_status_success, or an error status describing the first
// validation failure.
enum xnn_status xnn_define_clamp(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_clamp));
    return xnn_status_uninitialized;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_clamp), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_clamp), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_U8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_clamp), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_clamp), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_clamp), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (input_value->datatype != output_value->datatype) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": mismatching datatypes across input (%s) and output (%s)",
      xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
      xnn_datatype_to_string(input_value->datatype),
      xnn_datatype_to_string(output_value->datatype));
    return xnn_status_invalid_parameter;
  }

#if !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)
  // Bug fix: compute_type is an enum xnn_compute_type, so it must be compared
  // against xnn_compute_type_qs8/xnn_compute_type_qu8 — the original compared
  // it against xnn_datatype_* constants from a different enum, so this
  // quantization-consistency check could silently never run.
  if (compute_type == xnn_compute_type_qs8 || compute_type == xnn_compute_type_qu8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      // Bug fix: this diagnostic previously said "zero point" — it reports a
      // scale mismatch.
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_clamp;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_clamp_operator;
  node->setup = setup_clamp_operator;

  return xnn_status_success;
}
281