• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "nnacl/fp32/strided_slice_fp32.h"
18 #include "nnacl/errorcode.h"
19 
/**
 * Normalises a strided-slice description to exactly DIMENSION_8D axes, in place.
 *
 * Axes 0 .. num_axes_-1 keep their begin and stride, and their end is clamped
 * to the axis extent with MSMIN. Axes num_axes_ .. in_shape_length_-1 are taken
 * whole (begin 0, end = extent, stride 1). The in_shape_length_ staged axes are
 * then right-aligned into the 8 slots of begins_/ends_/strides_/in_shape_.
 * Every leading slot left over becomes a size-1 axis sliced as begin 0, end 1,
 * stride 1. On return num_axes_ and in_shape_length_ are both DIMENSION_8D.
 *
 * NOTE(review): the staging arrays hold DIMENSION_8D entries, so this relies on
 * the caller guaranteeing num_axes_ <= DIMENSION_8D and
 * in_shape_length_ <= DIMENSION_8D.
 */
void PadStridedSliceParameterTo8D(StridedSliceParameter *param) {
  int32_t staged_begin[DIMENSION_8D];
  int32_t staged_end[DIMENSION_8D];
  int32_t staged_stride[DIMENSION_8D];
  int32_t staged_shape[DIMENSION_8D];

  /* Axes the caller sliced explicitly: keep begin/stride, clamp end to the extent. */
  for (int32_t axis = 0; axis < param->num_axes_; ++axis) {
    staged_shape[axis] = param->in_shape_[axis];
    staged_begin[axis] = param->begins_[axis];
    staged_stride[axis] = param->strides_[axis];
    staged_end[axis] = MSMIN(param->ends_[axis], param->in_shape_[axis]);
  }

  /* Remaining input axes: select the whole extent. */
  for (int32_t axis = param->num_axes_; axis < param->in_shape_length_; ++axis) {
    const int32_t extent = param->in_shape_[axis];
    staged_shape[axis] = extent;
    staged_begin[axis] = 0;
    staged_end[axis] = extent;
    staged_stride[axis] = 1;
  }

  /* Number of leading filler slots in front of the right-aligned real axes. */
  const int32_t filler = DIMENSION_8D - param->in_shape_length_;
  for (int32_t slot = 0; slot < DIMENSION_8D; ++slot) {
    if (slot < filler) {
      param->in_shape_[slot] = 1;
      param->begins_[slot] = 0;
      param->ends_[slot] = 1;
      param->strides_[slot] = 1;
      continue;
    }
    const int32_t src = slot - filler;
    param->in_shape_[slot] = staged_shape[src];
    param->begins_[slot] = staged_begin[src];
    param->ends_[slot] = staged_end[src];
    param->strides_[slot] = staged_stride[src];
  }

  param->num_axes_ = DIMENSION_8D;
  param->in_shape_length_ = DIMENSION_8D;
}
55 
/**
 * Loop-condition helper for the strided-slice iteration.
 *
 * @param stride Step applied to the index on each iteration.
 * @param i      Current index.
 * @param end    Exclusive bound.
 * @return For a positive stride, whether i is still below end; otherwise
 *         (zero or negative stride), whether i is still above end.
 */
bool LoopContinue(int stride, int i, int end) {
  if (stride <= 0) {
    return i > end;
  }
  return i < end;
}
57 
/**
 * Strided slice for int32, int8, bool and float64 tensors.
 *
 * The slice described by `param` is normalised to 8 dimensions with
 * PadStridedSliceParameterTo8D. The selected elements of `in_data` are then
 * written densely, in row-major order, to `out_data`.
 *
 * @param in_data  Input tensor data, laid out contiguously according to
 *                 param->in_shape_ with the last axis innermost.
 * @param out_data Output buffer; must have room for every selected element
 *                 (NOTE(review): its size is not checked here).
 * @param param    Slice description. It is padded to 8-D in place when
 *                 num_axes_ < DIMENSION_8D.
 * @return NNACL_OK on success;
 *         NNACL_NULL_PTR if any pointer argument is NULL;
 *         NNACL_PARAM_INVALID if num_axes_ exceeds DIMENSION_8D, if padding is
 *         needed but in_shape_length_ exceeds DIMENSION_8D, or if a sliced
 *         axis has a zero stride;
 *         NNACL_ERR if param->data_type is not kDataTypeInt, kDataTypeInt8,
 *         kDataTypeBool or kDataTypeFloat64.
 */
int DoStridedSliceIntFp64Bool(const void *in_data, void *out_data, StridedSliceParameter *param) {
  if (in_data == NULL || out_data == NULL || param == NULL) {
    return NNACL_NULL_PTR;
  }
  if (param->num_axes_ > DIMENSION_8D) {
    return NNACL_PARAM_INVALID;
  }
  /* Reject unsupported types before touching param or the buffers, so the
   * error is reported even when the slice selects no elements. */
  if (param->data_type != kDataTypeInt && param->data_type != kDataTypeInt8 && param->data_type != kDataTypeBool &&
      param->data_type != kDataTypeFloat64) {
    return NNACL_ERR;
  }
  /* A zero stride never advances the loop index, so an axis with begin > end
   * would loop forever while writing past out_data. Axes added by padding get
   * stride 1, so only the caller-supplied axes need checking. */
  for (int i = 0; i < param->num_axes_; ++i) {
    if (param->strides_[i] == 0) {
      return NNACL_PARAM_INVALID;
    }
  }
  if (param->num_axes_ < DIMENSION_8D) {
    /* The padding helper stages in_shape_length_ axes in DIMENSION_8D-sized
     * local arrays. */
    if (param->in_shape_length_ > DIMENSION_8D) {
      return NNACL_PARAM_INVALID;
    }
    PadStridedSliceParameterTo8D(param);
  }

  const int *begins = param->begins_;
  const int *ends = param->ends_;
  const int *strides = param->strides_;
  const int *in_shape = param->in_shape_;

  /* Element stride of each axis in the row-major 8-D input. 64-bit, so that
   * large tensors do not overflow int arithmetic. */
  int64_t dim_offset[DIMENSION_8D];
  dim_offset[DIMENSION_8D - 1] = 1;
  for (int d = DIMENSION_8D - 2; d >= 0; --d) {
    dim_offset[d] = dim_offset[d + 1] * in_shape[d + 1];
  }

  size_t out_offset = 0;
  int32_t dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7;
  for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) {
    for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) {
      for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) {
        for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) {
          for (dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) {
            for (dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) {
              for (dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) {
                for (dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) {
                  const int64_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] +
                                            dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] +
                                            dim6 * dim_offset[6] + dim7;
                  if (param->data_type == kDataTypeInt) {
                    ((int32_t *)out_data)[out_offset] = ((const int32_t *)in_data)[in_offset];
                  } else if (param->data_type == kDataTypeInt8) {
                    ((int8_t *)out_data)[out_offset] = ((const int8_t *)in_data)[in_offset];
                  } else if (param->data_type == kDataTypeBool) {
                    ((bool *)out_data)[out_offset] = ((const bool *)in_data)[in_offset];
                  } else {
                    /* kDataTypeFloat64: the only remaining type accepted by the check above. */
                    ((double *)out_data)[out_offset] = ((const double *)in_data)[in_offset];
                  }
                  out_offset++;
                }
              }
            }
          }
        }
      }
    }
  }
  return NNACL_OK;
}
116 
/**
 * Strided slice entry point.
 *
 * float32 and float16 inputs are handled here; every other data_type is
 * forwarded to DoStridedSliceIntFp64Bool. The slice described by `param` is
 * normalised to 8 dimensions with PadStridedSliceParameterTo8D. The selected
 * elements of `in_data` are then written densely, in row-major order, to
 * `out_data`.
 *
 * @param in_data  Input tensor data, laid out contiguously according to
 *                 param->in_shape_ with the last axis innermost.
 * @param out_data Output buffer; must have room for every selected element
 *                 (NOTE(review): its size is not checked here).
 * @param param    Slice description. It is padded to 8-D in place when
 *                 num_axes_ < DIMENSION_8D.
 * @return NNACL_OK on success;
 *         NNACL_NULL_PTR if any pointer argument is NULL;
 *         NNACL_PARAM_INVALID if num_axes_ exceeds DIMENSION_8D, if padding is
 *         needed but in_shape_length_ exceeds DIMENSION_8D, or if a sliced
 *         axis has a zero stride;
 *         NNACL_ERR for float16 input when ENABLE_ARM64 is not defined.
 *         Non-float types return whatever DoStridedSliceIntFp64Bool returns.
 */
int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) {
  if (in_data == NULL || out_data == NULL || param == NULL) {
    return NNACL_NULL_PTR;
  }
  if (param->data_type != kDataTypeFloat && param->data_type != kDataTypeFloat16) {
    return DoStridedSliceIntFp64Bool(in_data, out_data, param);
  }
  if (param->num_axes_ > DIMENSION_8D) {
    return NNACL_PARAM_INVALID;
  }
  const bool is_fp16 = param->data_type == kDataTypeFloat16;
#ifndef ENABLE_ARM64
  if (is_fp16) {
    /* The float16 copy below is only compiled when ENABLE_ARM64 is defined.
     * Report the unsupported type up front, even for an empty slice. */
    return NNACL_ERR;
  }
#endif
  /* A zero stride never advances the loop index, so an axis with begin > end
   * would loop forever while writing past out_data. Axes added by padding get
   * stride 1, so only the caller-supplied axes need checking. */
  for (int i = 0; i < param->num_axes_; ++i) {
    if (param->strides_[i] == 0) {
      return NNACL_PARAM_INVALID;
    }
  }
  if (param->num_axes_ < DIMENSION_8D) {
    /* The padding helper stages in_shape_length_ axes in DIMENSION_8D-sized
     * local arrays. */
    if (param->in_shape_length_ > DIMENSION_8D) {
      return NNACL_PARAM_INVALID;
    }
    PadStridedSliceParameterTo8D(param);
  }

  const int *begins = param->begins_;
  const int *ends = param->ends_;
  const int *strides = param->strides_;
  const int *in_shape = param->in_shape_;

  /* Element stride of each axis in the row-major 8-D input. 64-bit, so that
   * large tensors do not overflow int arithmetic. */
  int64_t dim_offset[DIMENSION_8D];
  dim_offset[DIMENSION_8D - 1] = 1;
  for (int d = DIMENSION_8D - 2; d >= 0; --d) {
    dim_offset[d] = dim_offset[d + 1] * in_shape[d + 1];
  }

  size_t out_offset = 0;
  int32_t dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7;
  for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) {
    for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) {
      for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) {
        for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) {
          for (dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) {
            for (dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) {
              for (dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) {
                for (dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) {
                  const int64_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] +
                                            dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] +
                                            dim6 * dim_offset[6] + dim7;
                  if (is_fp16) {
                    /* Only reachable on ENABLE_ARM64 builds; other builds returned above. */
#ifdef ENABLE_ARM64
                    ((float16_t *)out_data)[out_offset] = ((const float16_t *)in_data)[in_offset];
#endif
                  } else {
                    ((float *)out_data)[out_offset] = ((const float *)in_data)[in_offset];
                  }
                  out_offset++;
                }
              }
            }
          }
        }
      }
    }
  }
  return NNACL_OK;
}
176 
/**
 * Copies `outer` rows of `split_len` strided chunks into a dense output.
 *
 * Row i (0 <= i < outer) starts at input + i * in_offset bytes. Chunk j of
 * that row starts j * stride * inner_size bytes after the row start, and
 * inner_size bytes are copied from it. Chunks are appended to `output` back to
 * back, so output receives outer * split_len * inner_size bytes.
 *
 * @param input      Source bytes.
 * @param output     Destination; must have room for outer * split_len * inner_size bytes.
 * @param split_len  Number of chunks copied per row.
 * @param stride     Distance between consecutive chunks, in units of inner_size
 *                   bytes; a negative value steps backwards from the row start.
 * @param outer      Number of rows.
 * @param inner_size Bytes per chunk.
 * @param in_offset  Byte distance between consecutive source rows.
 */
void FastStride(const uint8_t *input, uint8_t *output, int split_len, int stride, size_t outer, size_t inner_size,
                size_t in_offset) {
  /* Signed byte step between chunks. Multiplying the int stride by the size_t
   * inner_size directly would turn a negative stride into a huge unsigned value. */
  const int64_t chunk_step = (int64_t)stride * (int64_t)inner_size;
  for (size_t i = 0; i < outer; ++i) {
    const uint8_t *row = input + i * in_offset;
    for (int j = 0; j < split_len; ++j) {
      /* Address each chunk from the row start, so no pointer past the last
       * chunk actually copied is ever formed (C11 6.5.6). */
      memcpy(output, row + (int64_t)j * chunk_step, inner_size);
      output += inner_size;
    }
  }
}
188