1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/fp32/strided_slice_fp32.h"
18 #include "nnacl/errorcode.h"
19
/**
 * Expands a strided-slice description to DIMENSION_8D axes, in place.
 *
 * Axes in [0, num_axes_) keep their begin and stride; their end is passed through MSMIN together with the
 * axis extent (i.e. capped at the extent, assuming MSMIN is the usual minimum macro). Axes in
 * [num_axes_, in_shape_length_) are taken whole: begin 0, end = extent, stride 1. The first in_shape_length_
 * staged entries are then right-aligned into the 8-slot arrays of param, and every leading slot that is left
 * over becomes a unit axis (begin 0, end 1, stride 1, extent 1). On completion num_axes_ and in_shape_length_
 * are both set to DIMENSION_8D.
 *
 * @param param  Parameter block to rewrite. The call leaves it untouched when param is NULL, when num_axes_
 *               is negative or larger than DIMENSION_8D, or when in_shape_length_ is larger than
 *               DIMENSION_8D, because those values would index past the DIMENSION_8D-sized staging arrays.
 */
void PadStridedSliceParameterTo8D(StridedSliceParameter *param) {
  if (param == NULL) {
    return;
  }
  const int32_t axis_count = param->num_axes_;
  const int32_t rank = param->in_shape_length_;
  // Both counts are used as indices into the stack arrays below; reject anything that would overrun them.
  if (axis_count < 0 || axis_count > DIMENSION_8D || rank > DIMENSION_8D) {
    return;
  }

  int32_t begins[DIMENSION_8D];
  int32_t ends[DIMENSION_8D];
  int32_t strides[DIMENSION_8D];
  int32_t input_shape[DIMENSION_8D];

  // Explicitly sliced axes.
  for (int32_t i = 0; i < axis_count; ++i) {
    begins[i] = param->begins_[i];
    ends[i] = MSMIN(param->ends_[i], param->in_shape_[i]);
    strides[i] = param->strides_[i];
    input_shape[i] = param->in_shape_[i];
  }
  // Remaining tensor axes are taken in full.
  for (int32_t i = axis_count; i < rank; ++i) {
    input_shape[i] = param->in_shape_[i];
    begins[i] = 0;
    ends[i] = param->in_shape_[i];
    strides[i] = 1;
  }

  // Right-align the staged axes into the 8D arrays; leading slots become unit axes.
  int32_t src = rank - 1;
  for (int32_t dst = DIMENSION_8D - 1; dst >= 0; --dst) {
    if (src >= 0) {
      param->begins_[dst] = begins[src];
      param->ends_[dst] = ends[src];
      param->strides_[dst] = strides[src];
      param->in_shape_[dst] = input_shape[src];
      --src;
    } else {
      param->begins_[dst] = 0;
      param->ends_[dst] = 1;
      param->strides_[dst] = 1;
      param->in_shape_[dst] = 1;
    }
  }
  param->num_axes_ = DIMENSION_8D;
  param->in_shape_length_ = DIMENSION_8D;
}
55
/**
 * Loop condition for walking one axis of a strided slice.
 *
 * @param stride  Step applied to the index on every iteration.
 * @param i       Current index.
 * @param end     Exclusive bound: an upper bound for positive strides, a lower bound for negative ones.
 * @return true while i has not reached end in the direction of travel. A zero stride always yields false:
 *         the index would never move, so a loop driven by this predicate could otherwise never terminate.
 */
bool LoopContinue(int stride, int i, int end) {
  if (stride == 0) {
    return false;
  }
  return stride > 0 ? i < end : i > end;
}
57
/**
 * Copies a strided slice out of an int32, int8, bool or float64 tensor.
 *
 * The slice description is first expanded to DIMENSION_8D axes (param is modified in place). The selected
 * elements are then visited with axis 0 outermost and axis 7 innermost and written to consecutive positions
 * of out_data.
 *
 * @param in_data   Source tensor, addressed with param->in_shape_ as a row-major layout.
 * @param out_data  Destination buffer; receives the selected elements back to back. Its capacity is not
 *                  known here, so the caller has to size it for the whole slice.
 * @param param     Slice description (begins_, ends_, strides_, in_shape_, data_type); rewritten to 8D.
 * @return NNACL_OK on success; NNACL_NULL_PTR when any argument is NULL; NNACL_PARAM_INVALID when num_axes_
 *         or in_shape_length_ exceeds DIMENSION_8D or when any stride is zero; NNACL_ERR when an element has
 *         to be copied and param->data_type is none of kDataTypeInt, kDataTypeInt8, kDataTypeBool,
 *         kDataTypeFloat64.
 */
int DoStridedSliceIntFp64Bool(const void *in_data, void *out_data, StridedSliceParameter *param) {
  if (in_data == NULL || out_data == NULL || param == NULL) {
    return NNACL_NULL_PTR;
  }
  // The padding helper stages the axes in DIMENSION_8D-sized stack arrays, so bound both counts up front.
  if (param->num_axes_ > DIMENSION_8D || param->in_shape_length_ > DIMENSION_8D) {
    return NNACL_PARAM_INVALID;
  }
  if (param->num_axes_ < DIMENSION_8D) {
    PadStridedSliceParameterTo8D(param);
  }

  const int *begins = param->begins_;
  const int *ends = param->ends_;
  const int *strides = param->strides_;
  const int *in_shape = param->in_shape_;

  // A zero stride never advances its loop index; refuse it instead of spinning and overrunning out_data.
  for (int axis = 0; axis < DIMENSION_8D; ++axis) {
    if (strides[axis] == 0) {
      return NNACL_PARAM_INVALID;
    }
  }

  // dim_offset[k] is the element distance between neighbours along axis k (product of the extents of
  // axes k+1..7). Kept in 64 bits so that large tensors do not overflow a 32-bit int.
  int64_t dim_offset[DIMENSION_8D - 1];
  dim_offset[DIMENSION_8D - 2] = in_shape[DIMENSION_8D - 1];
  for (int axis = DIMENSION_8D - 3; axis >= 0; --axis) {
    dim_offset[axis] = (int64_t)in_shape[axis + 1] * dim_offset[axis + 1];
  }

  size_t out_offset = 0;
  for (int32_t dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) {
    for (int32_t dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) {
      for (int32_t dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) {
        for (int32_t dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) {
          for (int32_t dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) {
            for (int32_t dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) {
              for (int32_t dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) {
                for (int32_t dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) {
                  const int64_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] +
                                            dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] +
                                            dim6 * dim_offset[6] + dim7;
                  if (param->data_type == kDataTypeInt) {
                    ((int32_t *)out_data)[out_offset] = ((const int32_t *)in_data)[in_offset];
                  } else if (param->data_type == kDataTypeInt8) {
                    ((int8_t *)out_data)[out_offset] = ((const int8_t *)in_data)[in_offset];
                  } else if (param->data_type == kDataTypeBool) {
                    ((bool *)out_data)[out_offset] = ((const bool *)in_data)[in_offset];
                  } else if (param->data_type == kDataTypeFloat64) {
                    ((double *)out_data)[out_offset] = ((const double *)in_data)[in_offset];
                  } else {
                    return NNACL_ERR;
                  }
                  out_offset++;
                }
              }
            }
          }
        }
      }
    }
  }
  return NNACL_OK;
}
116
/**
 * Copies a strided slice of a tensor into a densely packed output buffer.
 *
 * Data types other than kDataTypeFloat / kDataTypeFloat16 are forwarded to DoStridedSliceIntFp64Bool. For the
 * float types the slice description is expanded to DIMENSION_8D axes (param is modified in place) and the
 * selected elements are visited with axis 0 outermost and axis 7 innermost, then written to consecutive
 * positions of out_data. The kDataTypeFloat16 copy is only compiled in when ENABLE_ARM64 is defined.
 *
 * @param in_data   Source tensor, addressed with param->in_shape_ as a row-major layout.
 * @param out_data  Destination buffer; receives the selected elements back to back. Its capacity is not
 *                  known here, so the caller has to size it for the whole slice.
 * @param param     Slice description (begins_, ends_, strides_, in_shape_, data_type); rewritten to 8D.
 * @return NNACL_OK on success; NNACL_NULL_PTR when any argument is NULL; NNACL_PARAM_INVALID when num_axes_
 *         or in_shape_length_ exceeds DIMENSION_8D or when any stride is zero; NNACL_ERR when an element has
 *         to be copied and its type is not handled in this build; otherwise whatever
 *         DoStridedSliceIntFp64Bool returns for the non-float types.
 */
int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) {
  if (in_data == NULL || out_data == NULL || param == NULL) {
    return NNACL_NULL_PTR;
  }
  if (param->data_type != kDataTypeFloat && param->data_type != kDataTypeFloat16) {
    return DoStridedSliceIntFp64Bool(in_data, out_data, param);
  }
  // The padding helper stages the axes in DIMENSION_8D-sized stack arrays, so bound both counts up front.
  if (param->num_axes_ > DIMENSION_8D || param->in_shape_length_ > DIMENSION_8D) {
    return NNACL_PARAM_INVALID;
  }
  if (param->num_axes_ < DIMENSION_8D) {
    PadStridedSliceParameterTo8D(param);
  }

  const int *begins = param->begins_;
  const int *ends = param->ends_;
  const int *strides = param->strides_;
  const int *in_shape = param->in_shape_;

  // A zero stride never advances its loop index; refuse it instead of spinning and overrunning out_data.
  for (int axis = 0; axis < DIMENSION_8D; ++axis) {
    if (strides[axis] == 0) {
      return NNACL_PARAM_INVALID;
    }
  }

  // dim_offset[k] is the element distance between neighbours along axis k (product of the extents of
  // axes k+1..7). Kept in 64 bits so that large tensors do not overflow a 32-bit int.
  int64_t dim_offset[DIMENSION_8D - 1];
  dim_offset[DIMENSION_8D - 2] = in_shape[DIMENSION_8D - 1];
  for (int axis = DIMENSION_8D - 3; axis >= 0; --axis) {
    dim_offset[axis] = (int64_t)in_shape[axis + 1] * dim_offset[axis + 1];
  }

  size_t out_offset = 0;
  for (int32_t dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) {
    for (int32_t dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) {
      for (int32_t dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) {
        for (int32_t dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) {
          for (int32_t dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) {
            for (int32_t dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) {
              for (int32_t dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) {
                for (int32_t dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) {
                  const int64_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] +
                                            dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] +
                                            dim6 * dim_offset[6] + dim7;
                  if (param->data_type == kDataTypeFloat) {
                    ((float *)out_data)[out_offset] = ((const float *)in_data)[in_offset];
#ifdef ENABLE_ARM64
                  } else if (param->data_type == kDataTypeFloat16) {
                    ((float16_t *)out_data)[out_offset] = ((const float16_t *)in_data)[in_offset];
#endif
                  } else {
                    return NNACL_ERR;
                  }
                  out_offset++;
                }
              }
            }
          }
        }
      }
    }
  }
  return NNACL_OK;
}
176
/**
 * Gathers equally spaced byte chunks from input into a contiguous output.
 *
 * For each of the `outer` groups, reading starts at input + group * in_offset. From there split_len chunks of
 * inner_size bytes are copied; consecutive source chunks lie inner_size * stride bytes apart, while the
 * destination advances by inner_size after every chunk so the output is densely packed.
 *
 * @param input       Source bytes.
 * @param output      Destination bytes; outer * split_len * inner_size bytes are written.
 * @param split_len   Number of chunks copied per group.
 * @param stride      Distance between consecutive source chunks, in units of inner_size bytes.
 * @param outer       Number of groups.
 * @param inner_size  Size of one chunk in bytes.
 * @param in_offset   Byte distance between the starts of consecutive groups in input.
 */
void FastStride(const uint8_t *input, uint8_t *output, int split_len, int stride, size_t outer, size_t inner_size,
                size_t in_offset) {
  const size_t src_step = inner_size * stride;
  uint8_t *dst = output;
  size_t group = 0;
  while (group < outer) {
    const uint8_t *src = input + group * in_offset;
    int chunk = 0;
    while (chunk < split_len) {
      memcpy(dst, src, inner_size);
      dst += inner_size;
      src += src_step;
      ++chunk;
    }
    ++group;
  }
}
188