• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/runtime/kernel/arm/base/strided_slice.h"
17 #include <vector>
18 #include "schema/model_generated.h"
19 #include "src/kernel_registry.h"
20 #include "include/errorcode.h"
21 
22 using mindspore::kernel::KERNEL_ARCH;
23 using mindspore::lite::KernelRegistrar;
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_NULL_PTR;
26 using mindspore::lite::RET_OK;
27 using mindspore::schema::PrimitiveType_StridedSlice;
28 
29 namespace {
30 constexpr int kNumInputSize = 2;
31 constexpr int kNumOutputSize = 1;
32 }  // namespace
33 namespace mindspore::kernel {
int StridedSliceCPUKernel::Init() {
  // Validate tensor counts and pointers up front: input[0] is the data
  // tensor, input[1] the begin-indices tensor (see ReSize), output[0] the
  // sliced result. Each CHECK_* macro returns an error code on failure.
  CHECK_LESS_RETURN(in_tensors_.size(), kNumInputSize);
  CHECK_LESS_RETURN(out_tensors_.size(), kNumOutputSize);
  CHECK_NULL_RETURN(in_tensors_[0]);
  CHECK_NULL_RETURN(in_tensors_[1]);
  CHECK_NULL_RETURN(out_tensors_[0]);
  // Shapes may not be known yet at init time; ReSize will be called again
  // once inference completes, so succeeding here is safe.
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}
45 
InitFastRunParam()46 void StridedSliceCPUKernel::InitFastRunParam() {
47   auto in_shape = in_tensors_.front()->shape();
48   auto out_shape = out_tensors_.front()->shape();
49   // reset && cal inner, outer
50   outer_ = 1;
51   inner_ = 1;
52   for (int i = 0; i < split_axis_; ++i) {
53     outer_ *= in_shape[i];
54   }
55   for (size_t i = split_axis_ + 1; i < in_shape.size(); i++) {
56     inner_ *= in_shape[i];
57   }
58   // decide multi-thread launch strategy
59   if (op_parameter_->thread_num_ == 0) {
60     MS_LOG(ERROR) << "thread num is zero.";
61     return;
62   }
63   if (outer_ == 1) {
64     parallel_on_split_axis_ = true;
65     cal_num_per_thread_ = UP_DIV(out_shape[split_axis_], op_parameter_->thread_num_);
66   } else {
67     parallel_on_outer_ = true;
68     cal_num_per_thread_ = UP_DIV(outer_, op_parameter_->thread_num_);
69   }
70 }
71 
ReSize()72 int StridedSliceCPUKernel::ReSize() {
73   auto input_tensor = in_tensors_.at(0);
74   auto begin_tensor = in_tensors_.at(1);
75   if (input_tensor->shape().size() > DIMENSION_8D || begin_tensor->shape().size() > DIMENSION_8D) {
76     MS_LOG(ERROR) << "StridedSlice not support input rank or begin num exceeds " << DIMENSION_8D;
77     return RET_ERROR;
78   }
79   fast_run_ = MatchFastPattern();
80   if (fast_run_) {
81     InitFastRunParam();
82   }
83   return RET_OK;
84 }
85 
MatchFastPattern()86 bool StridedSliceCPUKernel::MatchFastPattern() {
87   // This function is seeking if that the number of only one dimension
88   // is different between input and output. If so, we can do some trick.
89   // Example 1:
90   // input shape info:  [1, 80, 46, 40]
91   // output shape info: [1, 80, 20, 40]
92   // Example 2:
93   // input shape info:  [1, 46, 40]
94   // output shape info: [1, 20, 40]
95   auto in_shape = in_tensors_.front()->shape();
96   auto out_shape = out_tensors_.front()->shape();
97   if (in_shape.size() != out_shape.size()) {
98     return false;
99   }
100   std::vector<int> axis_list;
101   for (size_t i = 0; i < in_shape.size(); ++i) {
102     if (in_shape[i] != out_shape[i]) {
103       axis_list.emplace_back(i);
104     }
105   }
106   if (axis_list.size() == 1) {
107     split_axis_ = axis_list.front();
108     return true;
109   }
110   return false;
111 }
112 
// Per-task worker for the fast path: copies this task's share of the data,
// where the share was sized by InitFastRunParam (cal_num_per_thread_ units
// per task, in either outer rows or split-axis elements).
int StridedSliceCPUKernel::FastRunImpl(int task_id) {
  auto in_shape = in_tensors_.front()->shape();
  auto out_shape = out_tensors_.front()->shape();
  int begin_index = param_->begins_[split_axis_];
  // Units already assigned to lower-numbered tasks; this task starts after.
  int caled_num = task_id * cal_num_per_thread_;
  if (parallel_on_outer_) {
    // Each task handles a contiguous band of outer rows; within each row the
    // slice starts begin_index elements in along the split axis.
    uint8_t *cur_in_ptr = input_ptr_ + (caled_num * in_shape[split_axis_] + begin_index) * inner_size_;
    uint8_t *cur_out_ptr = output_ptr_ + caled_num * out_shape[split_axis_] * inner_size_;
    int cur_outer = outer_ - caled_num;
    if (cur_outer <= 0) {
      // Earlier tasks consumed all rows (thread count exceeded the work).
      return RET_OK;
    }
    if (cur_outer > cal_num_per_thread_) {
      cur_outer = cal_num_per_thread_;
    }
    FastStride(cur_in_ptr, cur_out_ptr, out_shape[split_axis_], param_->strides_[split_axis_], cur_outer, inner_size_,
               in_shape[split_axis_] * inner_size_);
  } else {
    MS_ASSERT(parallel_on_split_axis_);
    // Tasks divide the split axis itself: the input offset advances by the
    // stride per element already processed, the output stays dense.
    uint8_t *cur_in_ptr = input_ptr_ + (caled_num * param_->strides_[split_axis_] + begin_index) * inner_size_;
    uint8_t *cur_out_ptr = output_ptr_ + caled_num * inner_size_;
    int cal_axis_num = out_shape[split_axis_] - caled_num;
    if (cal_axis_num <= 0) {
      // Nothing left on the split axis for this task.
      return RET_OK;
    }
    if (cal_axis_num > cal_num_per_thread_) {
      cal_axis_num = cal_num_per_thread_;
    }
    FastStride(cur_in_ptr, cur_out_ptr, cal_axis_num, param_->strides_[split_axis_], 1, inner_size_, 0);
  }
  return RET_OK;
}
145 
StrideRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)146 int StrideRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
147   CHECK_NULL_RETURN(cdata);
148   auto stride = reinterpret_cast<StridedSliceCPUKernel *>(cdata);
149   auto ret = stride->FastRunImpl(task_id);
150   if (ret != RET_OK) {
151     MS_LOG(ERROR) << "StrideRun error task_id[" << task_id << "] error_code[" << ret << "]";
152     return ret;
153   }
154   return RET_OK;
155 }
156 
FastRun()157 int StridedSliceCPUKernel::FastRun() {
158   // Update length of inner size, because data type of tensor may be changed
159   // from float32 to float16 during fp16 sub-graph partition process.
160   auto input = in_tensors_.front();
161   switch (input->data_type()) {
162     case kNumberTypeInt8:
163       inner_size_ = inner_ * sizeof(int8_t);
164       break;
165     case kNumberTypeFloat32:
166       inner_size_ = inner_ * sizeof(float);
167       break;
168     case kNumberTypeFloat16:
169       inner_size_ = inner_ * sizeof(int16_t);
170       break;
171     case kNumberTypeInt32:
172       inner_size_ = inner_ * sizeof(int32_t);
173       break;
174     default:
175       MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
176       return RET_ERROR;
177   }
178   input_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_.front()->data());
179   CHECK_NULL_RETURN(input_ptr_);
180   output_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_.front()->data());
181   CHECK_NULL_RETURN(output_ptr_);
182   if (input_ptr_ == nullptr || output_ptr_ == nullptr) {
183     return RET_NULL_PTR;
184   }
185   auto ret = ParallelLaunch(this->ms_context_, StrideRun, this, op_parameter_->thread_num_);
186   if (ret != RET_OK) {
187     MS_LOG(ERROR) << "Stride run error error_code[" << ret << "]";
188     return ret;
189   }
190   return RET_OK;
191 }
192 
NormalRun()193 int StridedSliceCPUKernel::NormalRun() {
194   auto input = in_tensors_.at(0);
195   switch (input->data_type()) {
196     case kNumberTypeInt8:
197       param_->data_type = kDataTypeInt8;
198       break;
199     case kNumberTypeFloat32:
200       param_->data_type = kDataTypeFloat;
201       break;
202     case kNumberTypeFloat16:
203       param_->data_type = kDataTypeFloat16;
204       break;
205     case kNumberTypeInt32:
206       param_->data_type = kDataTypeInt;
207       break;
208     default:
209       MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
210       return RET_ERROR;
211   }
212   auto output = out_tensors_.at(0);
213   CHECK_NULL_RETURN(input->data());
214   CHECK_NULL_RETURN(output->data());
215   auto ret = DoStridedSlice(input->data(), output->data(), param_);
216   if (ret != RET_OK) {
217     MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]";
218     return RET_ERROR;
219   }
220   return RET_OK;
221 }
222 
Run()223 int StridedSliceCPUKernel::Run() {
224   if (fast_run_) {
225     return FastRun();
226   }
227   return NormalRun();
228 }
229 
230 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_StridedSlice, LiteKernelCreator<StridedSliceCPUKernel>)
231 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_StridedSlice, LiteKernelCreator<StridedSliceCPUKernel>)
232 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_StridedSlice, LiteKernelCreator<StridedSliceCPUKernel>)
233 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_StridedSlice, LiteKernelCreator<StridedSliceCPUKernel>)
234 }  // namespace mindspore::kernel
235