1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/runtime/kernel/arm/base/strided_slice.h"
17 #include <vector>
18 #include "schema/model_generated.h"
19 #include "src/kernel_registry.h"
20 #include "include/errorcode.h"
21
22 using mindspore::kernel::KERNEL_ARCH;
23 using mindspore::lite::KernelRegistrar;
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_NULL_PTR;
26 using mindspore::lite::RET_OK;
27 using mindspore::schema::PrimitiveType_StridedSlice;
28
namespace {
// Minimum tensor counts this kernel requires: the input data tensor plus the
// begin-indices tensor, and a single output tensor.
constexpr int kNumInputSize = 2;
constexpr int kNumOutputSize = 1;
}  // namespace
33 namespace mindspore::kernel {
int StridedSliceCPUKernel::Init() {
  // Validate tensor counts and non-null tensor handles up front; each CHECK_*
  // macro returns an error code immediately on failure.
  CHECK_LESS_RETURN(in_tensors_.size(), kNumInputSize);
  CHECK_LESS_RETURN(out_tensors_.size(), kNumOutputSize);
  CHECK_NULL_RETURN(in_tensors_[0]);
  CHECK_NULL_RETURN(in_tensors_[1]);
  CHECK_NULL_RETURN(out_tensors_[0]);
  // Shapes may not be known yet (dynamic-shape graphs); defer ReSize until
  // shape inference has completed.
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}
45
InitFastRunParam()46 void StridedSliceCPUKernel::InitFastRunParam() {
47 auto in_shape = in_tensors_.front()->shape();
48 auto out_shape = out_tensors_.front()->shape();
49 // reset && cal inner, outer
50 outer_ = 1;
51 inner_ = 1;
52 for (int i = 0; i < split_axis_; ++i) {
53 outer_ *= in_shape[i];
54 }
55 for (size_t i = split_axis_ + 1; i < in_shape.size(); i++) {
56 inner_ *= in_shape[i];
57 }
58 // decide multi-thread launch strategy
59 if (op_parameter_->thread_num_ == 0) {
60 MS_LOG(ERROR) << "thread num is zero.";
61 return;
62 }
63 if (outer_ == 1) {
64 parallel_on_split_axis_ = true;
65 cal_num_per_thread_ = UP_DIV(out_shape[split_axis_], op_parameter_->thread_num_);
66 } else {
67 parallel_on_outer_ = true;
68 cal_num_per_thread_ = UP_DIV(outer_, op_parameter_->thread_num_);
69 }
70 }
71
ReSize()72 int StridedSliceCPUKernel::ReSize() {
73 auto input_tensor = in_tensors_.at(0);
74 auto begin_tensor = in_tensors_.at(1);
75 if (input_tensor->shape().size() > DIMENSION_8D || begin_tensor->shape().size() > DIMENSION_8D) {
76 MS_LOG(ERROR) << "StridedSlice not support input rank or begin num exceeds " << DIMENSION_8D;
77 return RET_ERROR;
78 }
79 fast_run_ = MatchFastPattern();
80 if (fast_run_) {
81 InitFastRunParam();
82 }
83 return RET_OK;
84 }
85
MatchFastPattern()86 bool StridedSliceCPUKernel::MatchFastPattern() {
87 // This function is seeking if that the number of only one dimension
88 // is different between input and output. If so, we can do some trick.
89 // Example 1:
90 // input shape info: [1, 80, 46, 40]
91 // output shape info: [1, 80, 20, 40]
92 // Example 2:
93 // input shape info: [1, 46, 40]
94 // output shape info: [1, 20, 40]
95 auto in_shape = in_tensors_.front()->shape();
96 auto out_shape = out_tensors_.front()->shape();
97 if (in_shape.size() != out_shape.size()) {
98 return false;
99 }
100 std::vector<int> axis_list;
101 for (size_t i = 0; i < in_shape.size(); ++i) {
102 if (in_shape[i] != out_shape[i]) {
103 axis_list.emplace_back(i);
104 }
105 }
106 if (axis_list.size() == 1) {
107 split_axis_ = axis_list.front();
108 return true;
109 }
110 return false;
111 }
112
int StridedSliceCPUKernel::FastRunImpl(int task_id) {
  auto in_shape = in_tensors_.front()->shape();
  auto out_shape = out_tensors_.front()->shape();
  int begin_index = param_->begins_[split_axis_];
  // Number of units (outer rows, or split-axis elements) already assigned to
  // lower-numbered tasks; this task starts right after them.
  int caled_num = task_id * cal_num_per_thread_;
  if (parallel_on_outer_) {
    // Each task handles a contiguous chunk of outer rows; the begin offset
    // applies along the split axis inside every row.
    uint8_t *cur_in_ptr = input_ptr_ + (caled_num * in_shape[split_axis_] + begin_index) * inner_size_;
    uint8_t *cur_out_ptr = output_ptr_ + caled_num * out_shape[split_axis_] * inner_size_;
    int cur_outer = outer_ - caled_num;
    if (cur_outer <= 0) {
      // Thread count exceeds the number of outer rows: nothing left to do.
      return RET_OK;
    }
    if (cur_outer > cal_num_per_thread_) {
      // Clamp to this task's fair share.
      cur_outer = cal_num_per_thread_;
    }
    FastStride(cur_in_ptr, cur_out_ptr, out_shape[split_axis_], param_->strides_[split_axis_], cur_outer, inner_size_,
               in_shape[split_axis_] * inner_size_);
  } else {
    MS_ASSERT(parallel_on_split_axis_);
    // Tasks split the output extent of the split axis; the input pointer
    // advances by the stride for every produced element.
    uint8_t *cur_in_ptr = input_ptr_ + (caled_num * param_->strides_[split_axis_] + begin_index) * inner_size_;
    uint8_t *cur_out_ptr = output_ptr_ + caled_num * inner_size_;
    int cal_axis_num = out_shape[split_axis_] - caled_num;
    if (cal_axis_num <= 0) {
      return RET_OK;
    }
    if (cal_axis_num > cal_num_per_thread_) {
      cal_axis_num = cal_num_per_thread_;
    }
    FastStride(cur_in_ptr, cur_out_ptr, cal_axis_num, param_->strides_[split_axis_], 1, inner_size_, 0);
  }
  return RET_OK;
}
145
StrideRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)146 int StrideRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
147 CHECK_NULL_RETURN(cdata);
148 auto stride = reinterpret_cast<StridedSliceCPUKernel *>(cdata);
149 auto ret = stride->FastRunImpl(task_id);
150 if (ret != RET_OK) {
151 MS_LOG(ERROR) << "StrideRun error task_id[" << task_id << "] error_code[" << ret << "]";
152 return ret;
153 }
154 return RET_OK;
155 }
156
FastRun()157 int StridedSliceCPUKernel::FastRun() {
158 // Update length of inner size, because data type of tensor may be changed
159 // from float32 to float16 during fp16 sub-graph partition process.
160 auto input = in_tensors_.front();
161 switch (input->data_type()) {
162 case kNumberTypeInt8:
163 inner_size_ = inner_ * sizeof(int8_t);
164 break;
165 case kNumberTypeFloat32:
166 inner_size_ = inner_ * sizeof(float);
167 break;
168 case kNumberTypeFloat16:
169 inner_size_ = inner_ * sizeof(int16_t);
170 break;
171 case kNumberTypeInt32:
172 inner_size_ = inner_ * sizeof(int32_t);
173 break;
174 default:
175 MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
176 return RET_ERROR;
177 }
178 input_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_.front()->data());
179 CHECK_NULL_RETURN(input_ptr_);
180 output_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_.front()->data());
181 CHECK_NULL_RETURN(output_ptr_);
182 if (input_ptr_ == nullptr || output_ptr_ == nullptr) {
183 return RET_NULL_PTR;
184 }
185 auto ret = ParallelLaunch(this->ms_context_, StrideRun, this, op_parameter_->thread_num_);
186 if (ret != RET_OK) {
187 MS_LOG(ERROR) << "Stride run error error_code[" << ret << "]";
188 return ret;
189 }
190 return RET_OK;
191 }
192
NormalRun()193 int StridedSliceCPUKernel::NormalRun() {
194 auto input = in_tensors_.at(0);
195 switch (input->data_type()) {
196 case kNumberTypeInt8:
197 param_->data_type = kDataTypeInt8;
198 break;
199 case kNumberTypeFloat32:
200 param_->data_type = kDataTypeFloat;
201 break;
202 case kNumberTypeFloat16:
203 param_->data_type = kDataTypeFloat16;
204 break;
205 case kNumberTypeInt32:
206 param_->data_type = kDataTypeInt;
207 break;
208 default:
209 MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
210 return RET_ERROR;
211 }
212 auto output = out_tensors_.at(0);
213 CHECK_NULL_RETURN(input->data());
214 CHECK_NULL_RETURN(output->data());
215 auto ret = DoStridedSlice(input->data(), output->data(), param_);
216 if (ret != RET_OK) {
217 MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]";
218 return RET_ERROR;
219 }
220 return RET_OK;
221 }
222
Run()223 int StridedSliceCPUKernel::Run() {
224 if (fast_run_) {
225 return FastRun();
226 }
227 return NormalRun();
228 }
229
// Register this kernel with the lite kernel registry for every CPU data type
// it supports.
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_StridedSlice, LiteKernelCreator<StridedSliceCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_StridedSlice, LiteKernelCreator<StridedSliceCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_StridedSlice, LiteKernelCreator<StridedSliceCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_StridedSlice, LiteKernelCreator<StridedSliceCPUKernel>)
234 } // namespace mindspore::kernel
235