1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/kernel_compiler/cpu/stridedslice_cpu_kernel.h"
18 #include <utility>
19 #include <functional>
20 #include <algorithm>
21 #include <unordered_map>
22 #include "common/thread_pool.h"
23 #include "runtime/device/cpu/cpu_device_address.h"
24
25 namespace mindspore {
26 namespace kernel {
namespace {
// StridedSlice consumes exactly one input tensor and produces one output
// tensor; used by Launch() to validate the address lists.
constexpr size_t kStridedSliceInputsNum = 1;
constexpr size_t kStridedSliceOutputsNum = 1;
}  // namespace
31
// Distinguishes whether a position is the start or the end of a slice range,
// since the two have different valid bounds.
enum PosType { kBegin, kEnd };

// Map a possibly-negative slice position onto the valid range for a dimension
// of length dim_len. Non-negative positions are clamped from above (begin may
// reach at most dim_len - 1, end at most dim_len); negative positions are
// offset by dim_len and clamped from below (begin at 0, end at -1, so that a
// reverse slice can run past the first element).
int NormalizePos(int pos, int dim_len, PosType pos_type) {
  if (pos < 0) {
    const int lower_bound = (pos_type == kBegin) ? 0 : -1;
    return std::max(pos + dim_len, lower_bound);
  }
  const int upper_bound = (pos_type == kBegin) ? dim_len - 1 : dim_len;
  return std::min(pos, upper_bound);
}
42
InitKernel(const CNodePtr & kernel_node)43 void StridedSliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
44 MS_EXCEPTION_IF_NULL(kernel_node);
45 kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
46 input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
47 output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
48 if (input_shape_.size() > DIMENSION_8D || input_shape_.empty()) {
49 MS_LOG(EXCEPTION) << "StridedSlice only support 1D to 8D input tensor, but got " << input_shape_.size() << "D.";
50 }
51
52 auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
53 auto end = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
54 auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
55 if (begin.size() != end.size() || begin.size() != stride.size() || begin.size() > input_shape_.size()) {
56 MS_LOG(EXCEPTION)
57 << "StridedSLice requires the length of begin, stride and end must be equal and less than input dimension.";
58 }
59 dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
60 InitSliceParam(begin, end, stride);
61
62 parallel_ = MatchParallelPattern();
63 if (parallel_) {
64 InitParallelParam();
65 }
66 }
67
MatchParallelPattern()68 bool StridedSliceCPUKernel::MatchParallelPattern() {
69 // This function is seeking if that the number of only one dimension
70 // is different between input and output. If so, we can do some trick.
71 // Example 1:
72 // input shape info: [1, 80, 46, 40]
73 // output shape info: [1, 80, 20, 40]
74 // Example 2:
75 // input shape info: [1, 46, 40]
76 // output shape info: [1, 20, 40]
77 if (input_shape_.size() == output_shape_.size()) {
78 std::vector<int> axis_list;
79 for (size_t i = 0; i < input_shape_.size(); ++i) {
80 if (input_shape_[i] != output_shape_[i]) {
81 (void)axis_list.emplace_back(i);
82 }
83 }
84 if (axis_list.size() == 1) {
85 split_axis_ = axis_list.front();
86 return true;
87 }
88 }
89 return false;
90 }
91
InitParallelParam()92 void StridedSliceCPUKernel::InitParallelParam() {
93 outer_ = SizeToInt(
94 std::accumulate(input_shape_.begin(), input_shape_.begin() + split_axis_, size_t(1), std::multiplies<size_t>()));
95 inner_ = SizeToInt(
96 std::accumulate(input_shape_.begin() + split_axis_ + 1, input_shape_.end(), size_t(1), std::multiplies<size_t>()));
97
98 int max_thread_num = SizeToInt(common::ThreadPool::GetInstance().GetSyncRunThreadNum());
99 int thread_num = 1;
100 if (outer_ == 1) {
101 parallel_strategy_ = kOnSplitAxis;
102 thread_num = std::min(SizeToInt(output_shape_[split_axis_]), max_thread_num);
103 cal_num_per_thread_ = UP_DIV(output_shape_[split_axis_], thread_num);
104 } else {
105 parallel_strategy_ = kOnOuter;
106 thread_num = std::min(outer_, max_thread_num);
107 cal_num_per_thread_ = UP_DIV(outer_, thread_num);
108 }
109 slice_param_.op_parameter_.thread_num_ = thread_num;
110 }
111
// Fill the nnacl StridedSliceParameter (slice_param_) from the user-supplied
// begin/end/stride attributes. All three vectors have the same length, which
// is at most the input rank (validated in InitKernel). The parameter is
// always expanded to a full 8D description: user-sliced axes first, then the
// remaining real axes taken whole, then length-1 identity axes as padding.
void StridedSliceCPUKernel::InitSliceParam(const std::vector<int64_t> &begin, const std::vector<int64_t> &end,
                                           const std::vector<int64_t> &stride) {
  // Maps the device dtype to the nnacl type tag plus its element size in
  // bytes. Only these four dtypes are supported; anything else raises.
  static const std::unordered_map<TypeId, std::pair<LiteDataType, int>> type_convert_map = {
    {kNumberTypeBool, {kDataTypeBool, sizeof(bool)}},
    {kNumberTypeInt32, {kDataTypeInt, sizeof(int)}},
    {kNumberTypeFloat32, {kDataTypeFloat, sizeof(float)}},
    {kNumberTypeFloat64, {kDataTypeFloat64, sizeof(double)}}};

  auto type_pair = type_convert_map.find(dtype_);
  if (type_pair == type_convert_map.end()) {
    MS_LOG(EXCEPTION) << "StridedSlice supports bool, int32, float32 and float64 input tensor, but got "
                      << TypeIdToType(dtype_)->ToString();
  }
  data_size_ = type_pair->second.second;     // element size in bytes
  slice_param_.data_type = type_pair->second.first;

  for (size_t i = 0; i < DIMENSION_8D; i++) {
    int dim_len;
    if (i < begin.size()) {
      // Axis explicitly sliced by the user.
      dim_len = SizeToInt(input_shape_[i]);
      int begin_pos = LongToInt(begin[i]);
      int end_pos = LongToInt(end[i]);
      int stride_size = LongToInt(stride[i]);
      if (stride_size == 0) {
        MS_LOG(EXCEPTION) << "StridedSlice requires the each dimension slice stride can't be 0.";
      }
      slice_param_.in_shape_[i] = dim_len;
      slice_param_.strides_[i] = stride_size;
      // Clamp negative / out-of-range positions into valid bounds for this axis.
      slice_param_.begins_[i] = NormalizePos(begin_pos, dim_len, kBegin);
      slice_param_.ends_[i] = NormalizePos(end_pos, dim_len, kEnd);
      // A degenerate range (end on the wrong side of begin for the stride's
      // direction) is widened to cover exactly one element -- presumably so
      // the copy agrees with the inferred output shape. NOTE(review): confirm
      // this matches the framework's StridedSlice shape-inference rules.
      if (slice_param_.ends_[i] <= slice_param_.begins_[i] && slice_param_.strides_[i] > 0) {
        slice_param_.ends_[i] = slice_param_.begins_[i] + 1;
      }
      if (slice_param_.ends_[i] >= slice_param_.begins_[i] && slice_param_.strides_[i] < 0) {
        slice_param_.ends_[i] = slice_param_.begins_[i] - 1;
      }
    } else if (i < input_shape_.size()) {
      // Real input axis not mentioned in begin/end/stride: take it whole.
      dim_len = SizeToInt(input_shape_[i]);
      slice_param_.in_shape_[i] = dim_len;
      slice_param_.begins_[i] = 0;
      slice_param_.ends_[i] = dim_len;
      slice_param_.strides_[i] = 1;
    } else {
      // Padding axis beyond the input rank: identity slice of length 1.
      slice_param_.in_shape_[i] = 1;
      slice_param_.begins_[i] = 0;
      slice_param_.ends_[i] = 1;
      slice_param_.strides_[i] = 1;
    }
  }
  slice_param_.in_shape_length_ = DIMENSION_8D;
  slice_param_.num_axes_ = DIMENSION_8D;
}
164
RunTaskOnOuter(const uint8_t * input_addr,uint8_t * output_addr,int start_pos)165 int StridedSliceCPUKernel::RunTaskOnOuter(const uint8_t *input_addr, uint8_t *output_addr, int start_pos) {
166 int begin_index = slice_param_.begins_[split_axis_];
167 int inner_size = inner_ * data_size_;
168 const uint8_t *cur_in_ptr = input_addr + (start_pos * input_shape_[split_axis_] + begin_index) * inner_size;
169 uint8_t *cur_out_ptr = output_addr + start_pos * output_shape_[split_axis_] * inner_size;
170 int cur_outer = outer_ - start_pos;
171 if (cur_outer <= 0) {
172 return common::SUCCESS;
173 }
174 cur_outer = cur_outer > cal_num_per_thread_ ? cal_num_per_thread_ : cur_outer;
175 FastStride(cur_in_ptr, cur_out_ptr, output_shape_[split_axis_], slice_param_.strides_[split_axis_], cur_outer,
176 inner_size, input_shape_[split_axis_] * inner_size);
177 return common::SUCCESS;
178 }
179
RunTaskOnSplitAxis(const uint8_t * input_addr,uint8_t * output_addr,int start_pos)180 int StridedSliceCPUKernel::RunTaskOnSplitAxis(const uint8_t *input_addr, uint8_t *output_addr, int start_pos) {
181 int begin_index = slice_param_.begins_[split_axis_];
182 int inner_size = inner_ * data_size_;
183 const uint8_t *cur_in_ptr = input_addr + (start_pos * slice_param_.strides_[split_axis_] + begin_index) * inner_size;
184 uint8_t *cur_out_ptr = output_addr + start_pos * inner_size;
185 int cal_axis_num = output_shape_[split_axis_] - start_pos;
186 if (cal_axis_num <= 0) {
187 return common::SUCCESS;
188 }
189 cal_axis_num = cal_axis_num > cal_num_per_thread_ ? cal_num_per_thread_ : cal_axis_num;
190 FastStride(cur_in_ptr, cur_out_ptr, cal_axis_num, slice_param_.strides_[split_axis_], 1, inner_size, 0);
191 return common::SUCCESS;
192 }
193
ParallelRun(const uint8_t * input_addr,uint8_t * output_addr,int thread_num)194 void StridedSliceCPUKernel::ParallelRun(const uint8_t *input_addr, uint8_t *output_addr, int thread_num) {
195 int thread_index = 0;
196 std::vector<common::Task> tasks;
197 std::function<int(StridedSliceCPUKernel *, const uint8_t *, uint8_t *, int)> execute_func;
198 if (parallel_strategy_ == kOnOuter) {
199 execute_func = &StridedSliceCPUKernel::RunTaskOnOuter;
200 } else if (parallel_strategy_ == kOnSplitAxis) {
201 execute_func = &StridedSliceCPUKernel::RunTaskOnSplitAxis;
202 } else {
203 MS_LOG(EXCEPTION) << "Not supported parallel execute strategy for StridedSlice.";
204 }
205
206 while (thread_index < thread_num) {
207 (void)tasks.emplace_back(
208 std::bind(execute_func, this, input_addr, output_addr, thread_index * cal_num_per_thread_));
209 thread_index++;
210 }
211 (void)common::ThreadPool::GetInstance().SyncRun(tasks);
212 }
213
Launch(const std::vector<kernel::AddressPtr> & inputs,const std::vector<kernel::AddressPtr> &,const std::vector<kernel::AddressPtr> & outputs)214 bool StridedSliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
215 const std::vector<kernel::AddressPtr> & /* workspace */,
216 const std::vector<kernel::AddressPtr> &outputs) {
217 CHECK_KERNEL_INPUTS_NUM(inputs.size(), kStridedSliceInputsNum, kernel_name_);
218 CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceOutputsNum, kernel_name_);
219 if (outputs[0]->size == 0) {
220 MS_LOG(WARNING) << "StridedSlice output memory size should be greater than 0, but got 0.";
221 return true;
222 }
223 auto input_addr = reinterpret_cast<uint8_t *>(inputs[0]->addr);
224 auto output_addr = reinterpret_cast<uint8_t *>(outputs[0]->addr);
225 int thread_num = slice_param_.op_parameter_.thread_num_;
226 if (parallel_ && thread_num >= 2) {
227 ParallelRun(input_addr, output_addr, thread_num);
228 } else {
229 (void)DoStridedSlice(input_addr, output_addr, &slice_param_);
230 }
231 return true;
232 }
233 } // namespace kernel
234 } // namespace mindspore
235