• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <math.h>
7 #include <stddef.h>
8 #include <stdint.h>
9 
10 #include <xnnpack.h>
11 #include <xnnpack/log.h>
12 #include <xnnpack/params.h>
13 #include <xnnpack/subgraph.h>
14 
15 
create_hardswish_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)16 static enum xnn_status create_hardswish_operator(
17   const struct xnn_node* node,
18   const struct xnn_value* values,
19   size_t num_values,
20   struct xnn_operator_data* opdata)
21 {
22   assert(node->compute_type == xnn_compute_type_fp32);
23 
24   assert(node->num_inputs == 1);
25   const uint32_t input_id = node->inputs[0];
26   assert(input_id != XNN_INVALID_VALUE_ID);
27   assert(input_id < num_values);
28 
29   assert(node->num_outputs == 1);
30   const uint32_t output_id = node->outputs[0];
31   assert(output_id != XNN_INVALID_VALUE_ID);
32   assert(output_id < num_values);
33 
34   const size_t num_input_dims = values[input_id].shape.num_dims;
35   const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];
36 
37   enum xnn_status status;
38   switch (node->compute_type) {
39     case xnn_compute_type_fp32:
40       status = xnn_create_hardswish_nc_f32(
41         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
42         node->flags,
43         &opdata->operator_object);
44       break;
45 #ifndef XNN_NO_F16_OPERATORS
46     case xnn_compute_type_fp16:
47       status = xnn_create_hardswish_nc_f16(
48         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
49         node->flags,
50         &opdata->operator_object);
51       break;
52 #endif  // !defined(XNN_NO_F16_OPERATORS)
53     default:
54       XNN_UNREACHABLE;
55   }
56   if (status == xnn_status_success) {
57     opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
58     opdata->inputs[0] = input_id;
59     opdata->outputs[0] = output_id;
60   }
61   return status;
62 }
63 
// Binds the runtime input/output buffers to a previously created HardSwish
// operator.
//
//   opdata     - operator data holding the operator object, batch size and the
//                input/output Value IDs.
//   blobs      - blob table indexed by those IDs; both data pointers must be non-NULL.
//   num_blobs  - number of entries in blobs (used to bounds-check the IDs in asserts).
//   threadpool - forwarded unchanged to the operator setup call.
//
// Returns the status reported by xnn_setup_hardswish_nc_f32 or
// xnn_setup_hardswish_nc_f16, selected by the operator object's type.
static enum xnn_status setup_hardswish_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t src_id = opdata->inputs[0];
  assert(src_id != XNN_INVALID_VALUE_ID);
  assert(src_id < num_blobs);

  const uint32_t dst_id = opdata->outputs[0];
  assert(dst_id != XNN_INVALID_VALUE_ID);
  assert(dst_id < num_blobs);

  const void* src = blobs[src_id].data;
  assert(src != NULL);

  void* dst = blobs[dst_id].data;
  assert(dst != NULL);

  // Dispatch on the type of operator that was created for this node.
  enum xnn_status status;
  switch (opdata->operator_object->type) {
    case xnn_operator_type_hardswish_nc_f32:
      status = xnn_setup_hardswish_nc_f32(
        opdata->operator_object, opdata->batch_size, src, dst, threadpool);
      break;
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_hardswish_nc_f16:
      status = xnn_setup_hardswish_nc_f16(
        opdata->operator_object, opdata->batch_size, src, dst, threadpool);
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
  return status;
}
107 
// Appends a HardSwish node to a subgraph.
//
//   subgraph  - subgraph that owns the Values and receives the new node.
//   input_id  - Value ID of the input; must name an FP32 dense tensor.
//   output_id - Value ID of the output; must name an FP32 dense tensor.
//   flags     - stored on the node unchanged.
//
// Returns:
//   xnn_status_uninitialized     - XNN_INIT_FLAG_XNNPACK is not set in xnn_params.init_flags.
//   xnn_status_invalid_parameter - an ID is out of range, or a Value is not an FP32 dense tensor.
//   xnn_status_out_of_memory     - xnn_subgraph_new_node returned NULL.
//   xnn_status_success           - the node was added.
enum xnn_status xnn_define_hardswish(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_hardswish));
    return xnn_status_uninitialized;
  }

  // --- Input Value validation ---
  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_hardswish), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* src = &subgraph->values[input_id];
  if (src->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_hardswish), input_id, src->type);
    return xnn_status_invalid_parameter;
  }

  // Only FP32 inputs are accepted.
  if (src->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_hardswish), input_id,
      xnn_datatype_to_string(src->datatype), src->datatype);
    return xnn_status_invalid_parameter;
  }

  // --- Output Value validation ---
  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_hardswish), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* dst = &subgraph->values[output_id];
  if (dst->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_hardswish), output_id, dst->type);
    return xnn_status_invalid_parameter;
  }

  // Only FP32 outputs are accepted.
  if (dst->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_hardswish), output_id,
      xnn_datatype_to_string(dst->datatype), dst->datatype);
    return xnn_status_invalid_parameter;
  }

  // --- Node creation ---
  struct xnn_node* new_node = xnn_subgraph_new_node(subgraph);
  if (new_node == NULL) {
    return xnn_status_out_of_memory;
  }

  new_node->type = xnn_node_type_hardswish;
  new_node->compute_type = xnn_compute_type_fp32;
  new_node->flags = flags;

  new_node->num_inputs = 1;
  new_node->inputs[0] = input_id;
  new_node->num_outputs = 1;
  new_node->outputs[0] = output_id;

  // Callbacks used to create and set up the operator for this node.
  new_node->create = create_hardswish_operator;
  new_node->setup = setup_hardswish_operator;

  return xnn_status_success;
}
190