// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>

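// HardSwish computes f(x) = x * min(max(x + 3, 0), 6) / 6 element-wise.
// In the NC layout, each batch element is a row of `channels` floats, and
// consecutive rows are `input_stride` (respectively `output_stride`) floats
// apart in memory.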
enum xnn_status xnn_create_hardswish_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* hardswish_op_out)
{
  xnn_operator_t hardswish_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create HardSwish operator: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (channels == 0) {
    xnn_log_error(
      "failed to create HardSwish operator with %zu channels: number of channels must be non-zero", channels);
    goto error;
  }

  if (input_stride < channels) {
    xnn_log_error(
      "failed to create HardSwish operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      input_stride, channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create HardSwish operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      output_stride, channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  hardswish_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (hardswish_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for xnn_operator structure", sizeof(struct xnn_operator));
    goto error;
  }

  hardswish_op->channels = channels;
  hardswish_op->input_pixel_stride = input_stride;
  hardswish_op->output_pixel_stride = output_stride;
  hardswish_op->f32_hswish_params = xnn_init_f32_hswish_params();

  hardswish_op->type = xnn_operator_type_hardswish_nc_f32;
  hardswish_op->ukernel.type = xnn_ukernel_type_hswish;

  hardswish_op->state = xnn_run_state_invalid;

  *hardswish_op_out = hardswish_op;
  return xnn_status_success;

error:
  xnn_delete_operator(hardswish_op);
  return status;
}

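// Setup binds a batch size and concrete input/output pointers to the operator
// and picks the parallelization strategy; the computation itself runs later,
// typically via xnn_run_operator(). The threadpool argument is not referenced
// in this setup path.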
enum xnn_status xnn_setup_hardswish_nc_f32(
    xnn_operator_t hardswish_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (hardswish_op->type != xnn_operator_type_hardswish_nc_f32) {
    xnn_log_error("failed to setup HardSwish (F32) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }
  hardswish_op->state = xnn_run_state_invalid;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to setup HardSwish operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    hardswish_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  const size_t channels = hardswish_op->channels;
  const size_t input_stride = hardswish_op->input_pixel_stride;
  const size_t output_stride = hardswish_op->output_pixel_stride;
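  // `(input_stride ^ channels) | (output_stride ^ channels)` is zero only when
  // both strides equal the channel count, i.e. rows are packed back to back.
  // A batch of one is also treated as contiguous, since the row stride is then
  // irrelevant.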
  if ((((input_stride ^ channels) | (output_stride ^ channels)) == 0) || batch_size == 1) {
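    // Contiguous case: view the whole batch as one flat array of
    // batch_size * channels floats and split it into blocks of up to
    // 4096 bytes, each processed as an independent parallel task.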
    const size_t block_size = 4096;
    hardswish_op->context.univector_contiguous = (struct univector_contiguous_context) {
      .x = input,
      .x_stride = input_stride * sizeof(float),
      .y = output,
      .y_stride = output_stride * sizeof(float),
      .ukernel = xnn_params.f32.hswish,
      .params.f32_hswish = hardswish_op->f32_hswish_params,
    };
    hardswish_op->compute.type = xnn_parallelization_type_1d_tile_1d;
    hardswish_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
    hardswish_op->compute.range[0] = batch_size * channels * sizeof(float);
    hardswish_op->compute.tile[0] = block_size;
  } else {
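    // Strided case: rows are not packed, so each batch element becomes its own
    // task that reads and writes `channels` floats at the configured row
    // strides.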
    hardswish_op->context.univector_strided = (struct univector_strided_context) {
      .n = channels * sizeof(float),
      .x = input,
      .x_stride = input_stride * sizeof(float),
      .y = output,
      .y_stride = output_stride * sizeof(float),
      .ukernel = xnn_params.f32.hswish,
      .params.f32_hswish = hardswish_op->f32_hswish_params,
    };
    hardswish_op->compute.type = xnn_parallelization_type_1d_tile_1d;
    hardswish_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
    hardswish_op->compute.range[0] = batch_size;
    hardswish_op->compute.tile[0] = 1;
  }
  hardswish_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
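
// A minimal usage sketch (illustration only, not part of this file): the
// variable names below are placeholders and error checking is omitted.
// Assuming XNNPACK has already been initialized with xnn_initialize():
//
//   xnn_operator_t op = NULL;
//   xnn_create_hardswish_nc_f32(channels, channels, channels, 0 /* flags */, &op);
//   xnn_setup_hardswish_nc_f32(op, batch_size, input, output, threadpool);
//   xnn_run_operator(op, threadpool);
//   xnn_delete_operator(op);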