1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/runtime/kernel/arm/base/slice_base.h"
17 #include "src/kernel_registry.h"
18 #include "nnacl/base/slice_base.h"
19 #include "src/tensor.h"
20
21 using mindspore::lite::KernelRegistrar;
22 using mindspore::lite::RET_ERROR;
23 using mindspore::lite::RET_NULL_PTR;
24 using mindspore::lite::RET_OK;
25 using mindspore::schema::PrimitiveType_SliceFusion;
26
namespace {
// Positions of the kernel operands inside in_tensors_, as consumed by
// SliceCPUKernel::ReSize(): the tensor being sliced, the per-dimension begin
// offsets, and the per-dimension slice sizes.
constexpr int kNumInput0{0};
constexpr int kNumInput1{1};
constexpr int kNumInput2{2};
// Minimum number of input tensors SliceCPUKernel::Init() requires.
constexpr int kNumInputSize{3};
}  // namespace
33 namespace mindspore::kernel {
SliceLaunch(void * cdata,int task_id,float lhs_scale,float rhs_scale)34 int SliceLaunch(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
35 if (cdata == nullptr) {
36 MS_LOG(ERROR) << "Input cdata is nullptr!";
37 return RET_ERROR;
38 }
39 auto kernel = reinterpret_cast<SliceCPUKernel *>(cdata);
40 return kernel->SliceParallelRun(task_id);
41 }
42
ReSize()43 int SliceCPUKernel::ReSize() {
44 auto in_tensor = in_tensors_[kNumInput0];
45 auto begin_tensor = in_tensors_[kNumInput1];
46 auto size_tensor = in_tensors_[kNumInput2];
47 MS_ASSERT(in_tensor->shape().size() == static_cast<size_t>(begin_tensor->ElementsNum()));
48 MS_ASSERT(in_tensor->shape().size() == static_cast<size_t>(size_tensor->ElementsNum()));
49 MS_ASSERT(in_tensor->shape().size() <= DIMENSION_8D);
50 auto begin = reinterpret_cast<int32_t *>(begin_tensor->data());
51 CHECK_NULL_RETURN(begin);
52 auto size = reinterpret_cast<int32_t *>(size_tensor->data());
53 CHECK_NULL_RETURN(size);
54
55 param_->param_length_ = in_tensor->shape().size();
56 if (param_->param_length_ > DIMENSION_8D) {
57 MS_LOG(ERROR) << "input dimension num should <= " << DIMENSION_8D;
58 return RET_ERROR;
59 }
60 for (int i = 0; i < param_->param_length_; ++i) {
61 param_->shape_[i] = in_tensor->DimensionSize(i);
62 param_->begin_[i] = begin[i];
63 param_->size_[i] = size[i] < 0 ? param_->shape_[i] - param_->begin_[i] : size[i];
64 param_->end_[i] = param_->begin_[i] + param_->size_[i];
65 }
66 if (param_->param_length_ < DIMENSION_8D) {
67 PadSliceParameterTo8D(param_);
68 }
69 return RET_OK;
70 }
71
Init()72 int SliceCPUKernel::Init() {
73 CHECK_LESS_RETURN(in_tensors_.size(), kNumInputSize);
74 CHECK_LESS_RETURN(out_tensors_.size(), 1);
75 CHECK_NULL_RETURN(in_tensors_[kNumInput0]);
76 CHECK_NULL_RETURN(in_tensors_[kNumInput1]);
77 CHECK_NULL_RETURN(in_tensors_[kNumInput2]);
78 CHECK_NULL_RETURN(out_tensors_[0]);
79 CHECK_NULL_RETURN(op_parameter_);
80 if (!InferShapeDone()) {
81 return RET_OK;
82 }
83 return ReSize();
84 }
85
SliceParallelRun(int thread_id)86 int SliceCPUKernel::SliceParallelRun(int thread_id) {
87 DoSlice(in_tensors_.at(0)->data(), out_tensors_.at(0)->data(), param_, thread_id,
88 lite::DataTypeSize(in_tensors_.at(0)->data_type()));
89 return RET_OK;
90 }
91
Run()92 int SliceCPUKernel::Run() {
93 auto input_data = reinterpret_cast<float *>(in_tensors_.at(0)->data());
94 auto output_data = reinterpret_cast<float *>(out_tensors_.at(0)->data());
95 if (input_data == nullptr || output_data == nullptr) {
96 return RET_NULL_PTR;
97 }
98 // param_ shape info has already been extended to 8d
99 constexpr size_t kDimHUnder8D = 5;
100 if (param_->size_[kDimHUnder8D] < op_parameter_->thread_num_) {
101 DoSliceNoParallel(input_data, output_data, param_, lite::DataTypeSize(in_tensors_.at(0)->data_type()));
102 return RET_OK;
103 }
104 auto ret = ParallelLaunch(this->ms_context_, SliceLaunch, this, op_parameter_->thread_num_);
105 if (ret != RET_OK) {
106 MS_LOG(ERROR) << "slice launch fail!ret: " << ret;
107 return RET_ERROR;
108 }
109 return RET_OK;
110 }
111
112 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
113 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
114 } // namespace mindspore::kernel
115