• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/runtime/kernel/arm/base/slice_base.h"
17 #include "src/kernel_registry.h"
18 #include "nnacl/base/slice_base.h"
19 #include "src/tensor.h"
20 
21 using mindspore::lite::KernelRegistrar;
22 using mindspore::lite::RET_ERROR;
23 using mindspore::lite::RET_NULL_PTR;
24 using mindspore::lite::RET_OK;
25 using mindspore::schema::PrimitiveType_SliceFusion;
26 
namespace {
// Positions of the slice operands inside in_tensors_.
constexpr int kNumInput0{0};  // tensor that is sliced
constexpr int kNumInput1{1};  // tensor holding the per-dimension begin offsets
constexpr int kNumInput2{2};  // tensor holding the per-dimension slice sizes
// Minimum number of input tensors Init() requires.
constexpr int kNumInputSize{3};
}  // namespace
33 namespace mindspore::kernel {
SliceLaunch(void * cdata,int task_id,float lhs_scale,float rhs_scale)34 int SliceLaunch(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
35   if (cdata == nullptr) {
36     MS_LOG(ERROR) << "Input cdata is nullptr!";
37     return RET_ERROR;
38   }
39   auto kernel = reinterpret_cast<SliceCPUKernel *>(cdata);
40   return kernel->SliceParallelRun(task_id);
41 }
42 
ReSize()43 int SliceCPUKernel::ReSize() {
44   auto in_tensor = in_tensors_[kNumInput0];
45   auto begin_tensor = in_tensors_[kNumInput1];
46   auto size_tensor = in_tensors_[kNumInput2];
47   MS_ASSERT(in_tensor->shape().size() == static_cast<size_t>(begin_tensor->ElementsNum()));
48   MS_ASSERT(in_tensor->shape().size() == static_cast<size_t>(size_tensor->ElementsNum()));
49   MS_ASSERT(in_tensor->shape().size() <= DIMENSION_8D);
50   auto begin = reinterpret_cast<int32_t *>(begin_tensor->data());
51   CHECK_NULL_RETURN(begin);
52   auto size = reinterpret_cast<int32_t *>(size_tensor->data());
53   CHECK_NULL_RETURN(size);
54 
55   param_->param_length_ = in_tensor->shape().size();
56   if (param_->param_length_ > DIMENSION_8D) {
57     MS_LOG(ERROR) << "input dimension num should <= " << DIMENSION_8D;
58     return RET_ERROR;
59   }
60   for (int i = 0; i < param_->param_length_; ++i) {
61     param_->shape_[i] = in_tensor->DimensionSize(i);
62     param_->begin_[i] = begin[i];
63     param_->size_[i] = size[i] < 0 ? param_->shape_[i] - param_->begin_[i] : size[i];
64     param_->end_[i] = param_->begin_[i] + param_->size_[i];
65   }
66   if (param_->param_length_ < DIMENSION_8D) {
67     PadSliceParameterTo8D(param_);
68   }
69   return RET_OK;
70 }
71 
Init()72 int SliceCPUKernel::Init() {
73   CHECK_LESS_RETURN(in_tensors_.size(), kNumInputSize);
74   CHECK_LESS_RETURN(out_tensors_.size(), 1);
75   CHECK_NULL_RETURN(in_tensors_[kNumInput0]);
76   CHECK_NULL_RETURN(in_tensors_[kNumInput1]);
77   CHECK_NULL_RETURN(in_tensors_[kNumInput2]);
78   CHECK_NULL_RETURN(out_tensors_[0]);
79   CHECK_NULL_RETURN(op_parameter_);
80   if (!InferShapeDone()) {
81     return RET_OK;
82   }
83   return ReSize();
84 }
85 
SliceParallelRun(int thread_id)86 int SliceCPUKernel::SliceParallelRun(int thread_id) {
87   DoSlice(in_tensors_.at(0)->data(), out_tensors_.at(0)->data(), param_, thread_id,
88           lite::DataTypeSize(in_tensors_.at(0)->data_type()));
89   return RET_OK;
90 }
91 
Run()92 int SliceCPUKernel::Run() {
93   auto input_data = reinterpret_cast<float *>(in_tensors_.at(0)->data());
94   auto output_data = reinterpret_cast<float *>(out_tensors_.at(0)->data());
95   if (input_data == nullptr || output_data == nullptr) {
96     return RET_NULL_PTR;
97   }
98   // param_ shape info has already been extended to 8d
99   constexpr size_t kDimHUnder8D = 5;
100   if (param_->size_[kDimHUnder8D] < op_parameter_->thread_num_) {
101     DoSliceNoParallel(input_data, output_data, param_, lite::DataTypeSize(in_tensors_.at(0)->data_type()));
102     return RET_OK;
103   }
104   auto ret = ParallelLaunch(this->ms_context_, SliceLaunch, this, op_parameter_->thread_num_);
105   if (ret != RET_OK) {
106     MS_LOG(ERROR) << "slice launch fail!ret: " << ret;
107     return RET_ERROR;
108   }
109   return RET_OK;
110 }
111 
112 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
113 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SliceFusion, LiteKernelCreator<SliceCPUKernel>)
114 }  // namespace mindspore::kernel
115