/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/stridedslice_cpu_kernel.h"
#include <utility>
#include <functional>
#include <algorithm>
#include <unordered_map>
#include "common/thread_pool.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kStridedSliceInputsNum = 1;
constexpr size_t kStridedSliceOutputsNum = 1;
}  // namespace

enum PosType { kBegin, kEnd };

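// Clamp a (possibly negative) slice index into the valid range for a dimension of length dim_len.
// Negative indices count from the end of the dimension; kBegin positions are clamped to [0, dim_len - 1]
// and kEnd positions to [-1, dim_len] so that both forward and backward strides stay in bounds.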
int NormalizePos(int pos, int dim_len, PosType pos_type) {
  if (pos >= 0) {
    int max_pos = pos_type == kBegin ? dim_len - 1 : dim_len;
    return std::min(pos, max_pos);
  }
  int min_pos = pos_type == kBegin ? 0 : -1;
  return std::max(pos + dim_len, min_pos);
}

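// Read the input/output shapes and the begin/end/strides attributes from the kernel node,
// validate them, build the slice parameters, and decide whether the slice can run in parallel.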
void StridedSliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
  if (input_shape_.size() > DIMENSION_8D || input_shape_.empty()) {
    MS_LOG(EXCEPTION) << "StridedSlice only supports 1D to 8D input tensors, but got " << input_shape_.size() << "D.";
  }

  auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
  auto end = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
  auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
  if (begin.size() != end.size() || begin.size() != stride.size() || begin.size() > input_shape_.size()) {
    MS_LOG(EXCEPTION)
      << "StridedSlice requires begin, end and strides to have equal lengths, no greater than the input dimension.";
  }
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  InitSliceParam(begin, end, stride);

  parallel_ = MatchParallelPattern();
  if (parallel_) {
    InitParallelParam();
  }
}

bool StridedSliceCPUKernel::MatchParallelPattern() {
  // Check whether exactly one dimension differs between the input and output shapes.
  // If so, the copy can be parallelized along that axis.
  // Example 1:
  // input shape info:  [1, 80, 46, 40]
  // output shape info: [1, 80, 20, 40]
  // Example 2:
  // input shape info:  [1, 46, 40]
  // output shape info: [1, 20, 40]
  if (input_shape_.size() == output_shape_.size()) {
    std::vector<int> axis_list;
    for (size_t i = 0; i < input_shape_.size(); ++i) {
      if (input_shape_[i] != output_shape_[i]) {
        (void)axis_list.emplace_back(i);
      }
    }
    if (axis_list.size() == 1) {
      split_axis_ = axis_list.front();
      return true;
    }
  }
  return false;
}

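// Compute the element counts before (outer_) and after (inner_) the split axis, then pick a parallel
// strategy: split the work over the outer dimensions when possible, otherwise over the split axis
// itself, and derive the per-thread workload from the available thread count.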
void StridedSliceCPUKernel::InitParallelParam() {
  outer_ = SizeToInt(
    std::accumulate(input_shape_.begin(), input_shape_.begin() + split_axis_, size_t(1), std::multiplies<size_t>()));
  inner_ = SizeToInt(
    std::accumulate(input_shape_.begin() + split_axis_ + 1, input_shape_.end(), size_t(1), std::multiplies<size_t>()));

  int max_thread_num = SizeToInt(common::ThreadPool::GetInstance().GetSyncRunThreadNum());
  int thread_num = 1;
  if (outer_ == 1) {
    parallel_strategy_ = kOnSplitAxis;
    thread_num = std::min(SizeToInt(output_shape_[split_axis_]), max_thread_num);
    cal_num_per_thread_ = UP_DIV(output_shape_[split_axis_], thread_num);
  } else {
    parallel_strategy_ = kOnOuter;
    thread_num = std::min(outer_, max_thread_num);
    cal_num_per_thread_ = UP_DIV(outer_, thread_num);
  }
  slice_param_.op_parameter_.thread_num_ = thread_num;
}

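// Map the input dtype to the lite slice parameter type and fill slice_param_ with normalized
// begin/end/stride values for every dimension, padding the shape up to 8D with unit-length axes.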
void StridedSliceCPUKernel::InitSliceParam(const std::vector<int64_t> &begin, const std::vector<int64_t> &end,
                                           const std::vector<int64_t> &stride) {
  static const std::unordered_map<TypeId, std::pair<LiteDataType, int>> type_convert_map = {
    {kNumberTypeBool, {kDataTypeBool, sizeof(bool)}},
    {kNumberTypeInt32, {kDataTypeInt, sizeof(int)}},
    {kNumberTypeFloat32, {kDataTypeFloat, sizeof(float)}},
    {kNumberTypeFloat64, {kDataTypeFloat64, sizeof(double)}}};

  auto type_pair = type_convert_map.find(dtype_);
  if (type_pair == type_convert_map.end()) {
    MS_LOG(EXCEPTION) << "StridedSlice supports bool, int32, float32 and float64 input tensor, but got "
                      << TypeIdToType(dtype_)->ToString();
  }
  data_size_ = type_pair->second.second;
  slice_param_.data_type = type_pair->second.first;

  for (size_t i = 0; i < DIMENSION_8D; i++) {
    int dim_len;
    if (i < begin.size()) {
      dim_len = SizeToInt(input_shape_[i]);
      int begin_pos = LongToInt(begin[i]);
      int end_pos = LongToInt(end[i]);
      int stride_size = LongToInt(stride[i]);
      if (stride_size == 0) {
        MS_LOG(EXCEPTION) << "StridedSlice requires that the slice stride of each dimension must not be 0.";
      }
      slice_param_.in_shape_[i] = dim_len;
      slice_param_.strides_[i] = stride_size;
      slice_param_.begins_[i] = NormalizePos(begin_pos, dim_len, kBegin);
      slice_param_.ends_[i] = NormalizePos(end_pos, dim_len, kEnd);
      if (slice_param_.ends_[i] <= slice_param_.begins_[i] && slice_param_.strides_[i] > 0) {
        slice_param_.ends_[i] = slice_param_.begins_[i] + 1;
      }
      if (slice_param_.ends_[i] >= slice_param_.begins_[i] && slice_param_.strides_[i] < 0) {
        slice_param_.ends_[i] = slice_param_.begins_[i] - 1;
      }
    } else if (i < input_shape_.size()) {
      dim_len = SizeToInt(input_shape_[i]);
      slice_param_.in_shape_[i] = dim_len;
      slice_param_.begins_[i] = 0;
      slice_param_.ends_[i] = dim_len;
      slice_param_.strides_[i] = 1;
    } else {
      slice_param_.in_shape_[i] = 1;
      slice_param_.begins_[i] = 0;
      slice_param_.ends_[i] = 1;
      slice_param_.strides_[i] = 1;
    }
  }
  slice_param_.in_shape_length_ = DIMENSION_8D;
  slice_param_.num_axes_ = DIMENSION_8D;
}

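// One parallel task for the kOnOuter strategy: handle up to cal_num_per_thread_ outer rows starting
// at row start_pos; for each row, gather the strided elements along the split axis with FastStride.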
int StridedSliceCPUKernel::RunTaskOnOuter(const uint8_t *input_addr, uint8_t *output_addr, int start_pos) {
  int begin_index = slice_param_.begins_[split_axis_];
  int inner_size = inner_ * data_size_;
  const uint8_t *cur_in_ptr = input_addr + (start_pos * input_shape_[split_axis_] + begin_index) * inner_size;
  uint8_t *cur_out_ptr = output_addr + start_pos * output_shape_[split_axis_] * inner_size;
  int cur_outer = outer_ - start_pos;
  if (cur_outer <= 0) {
    return common::SUCCESS;
  }
  cur_outer = cur_outer > cal_num_per_thread_ ? cal_num_per_thread_ : cur_outer;
  FastStride(cur_in_ptr, cur_out_ptr, output_shape_[split_axis_], slice_param_.strides_[split_axis_], cur_outer,
             inner_size, input_shape_[split_axis_] * inner_size);
  return common::SUCCESS;
}

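// One parallel task for the kOnSplitAxis strategy: copy up to cal_num_per_thread_ elements along the
// split axis starting at output index start_pos, each element being one contiguous inner block.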
int StridedSliceCPUKernel::RunTaskOnSplitAxis(const uint8_t *input_addr, uint8_t *output_addr, int start_pos) {
  int begin_index = slice_param_.begins_[split_axis_];
  int inner_size = inner_ * data_size_;
  const uint8_t *cur_in_ptr = input_addr + (start_pos * slice_param_.strides_[split_axis_] + begin_index) * inner_size;
  uint8_t *cur_out_ptr = output_addr + start_pos * inner_size;
  int cal_axis_num = output_shape_[split_axis_] - start_pos;
  if (cal_axis_num <= 0) {
    return common::SUCCESS;
  }
  cal_axis_num = cal_axis_num > cal_num_per_thread_ ? cal_num_per_thread_ : cal_axis_num;
  FastStride(cur_in_ptr, cur_out_ptr, cal_axis_num, slice_param_.strides_[split_axis_], 1, inner_size, 0);
  return common::SUCCESS;
}

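// Split the slice into thread_num tasks according to the chosen parallel strategy and run them
// synchronously on the common thread pool.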
void StridedSliceCPUKernel::ParallelRun(const uint8_t *input_addr, uint8_t *output_addr, int thread_num) {
  int thread_index = 0;
  std::vector<common::Task> tasks;
  std::function<int(StridedSliceCPUKernel *, const uint8_t *, uint8_t *, int)> execute_func;
  if (parallel_strategy_ == kOnOuter) {
    execute_func = &StridedSliceCPUKernel::RunTaskOnOuter;
  } else if (parallel_strategy_ == kOnSplitAxis) {
    execute_func = &StridedSliceCPUKernel::RunTaskOnSplitAxis;
  } else {
    MS_LOG(EXCEPTION) << "Unsupported parallel execution strategy for StridedSlice.";
  }

  while (thread_index < thread_num) {
    (void)tasks.emplace_back(
      std::bind(execute_func, this, input_addr, output_addr, thread_index * cal_num_per_thread_));
    thread_index++;
  }
  (void)common::ThreadPool::GetInstance().SyncRun(tasks);
}

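// Kernel entry point: validate the input/output counts, skip empty outputs, and dispatch either to
// the parallel path or to the single-threaded DoStridedSlice implementation.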
bool StridedSliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                   const std::vector<kernel::AddressPtr> & /* workspace */,
                                   const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kStridedSliceInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceOutputsNum, kernel_name_);
  if (outputs[0]->size == 0) {
    MS_LOG(WARNING) << "StridedSlice output memory size should be greater than 0, but got 0.";
    return true;
  }
  auto input_addr = reinterpret_cast<uint8_t *>(inputs[0]->addr);
  auto output_addr = reinterpret_cast<uint8_t *>(outputs[0]->addr);
  int thread_num = slice_param_.op_parameter_.thread_num_;
  if (parallel_ && thread_num >= 2) {
    ParallelRun(input_addr, output_addr, thread_num);
  } else {
    (void)DoStridedSlice(input_addr, output_addr, &slice_param_);
  }
  return true;
}
}  // namespace kernel
}  // namespace mindspore