/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp32/gather_fp32.h"
#include <limits>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"

using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gather;

namespace mindspore::kernel {
namespace {
constexpr int kSecondInput = 2;
}
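// Init() caches the gather axis from the third input tensor (index kSecondInput).
// If shape inference has not run yet, resizing is deferred until it has.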
int GatherCPUKernel::Init() {
  CHECK_LESS_RETURN(in_tensors_.size(), kInputSize2);
  CHECK_LESS_RETURN(out_tensors_.size(), 1);
  CHECK_NULL_RETURN(in_tensors_.at(kSecondInput)->data());
  axis_ = *(reinterpret_cast<int *>(in_tensors_.at(kSecondInput)->data()));
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int GatherCPUKernel::ReSize() { return RET_OK; }

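// DoGather() copies this thread's share of the outer slices. The copy is done on raw
// bytes (data_size per element), so the same kernel body serves float32 and int32 data.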
int GatherCPUKernel::DoGather(int task_id) {
  auto input_tensor = in_tensors_.at(0);
  auto indices_tensor = in_tensors_.at(1);
  auto out_tensor = out_tensors_.at(0);

  auto in_shape = input_tensor->shape();
  int in_rank = in_shape.size();
  int indices_element_size = indices_tensor->ElementsNum();
  MS_CHECK_LT(axis_, in_rank, RET_ERROR);
  const int limit = in_shape.at(axis_);

  int outer_size = 1, inner_size = 1;
  for (int i = 0; i < axis_; ++i) {
    outer_size *= in_shape.at(i);
  }
  for (int i = axis_ + 1; i < in_rank; ++i) {
    inner_size *= in_shape.at(i);
  }
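  // Split the outer dimension evenly across threads: each task handles `stride` outer
  // slices starting at `stride * task_id`; the last task may get fewer (or none).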
  int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
  int count = MSMIN(stride, outer_size - stride * task_id);
  if (count <= 0) {
    return RET_OK;
  }
  auto thread_stride = stride * task_id;

  int8_t *int8_in = reinterpret_cast<int8_t *>(input_tensor->data());
  CHECK_NULL_RETURN(int8_in);
  int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data());
  CHECK_NULL_RETURN(int8_out);

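  // Advance both pointers to this thread's first outer slice, in bytes: the input stride
  // per outer slice is limit * inner_size, the output stride is indices_element_size * inner_size.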
  int data_size = static_cast<int>(lite::DataTypeSize(input_tensor->data_type()));
  int8_in += thread_stride * limit * inner_size * data_size;
  int8_out += thread_stride * indices_element_size * inner_size * data_size;

  int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size);

  return error_code;
}

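// GatherRun() is the ParallelLaunch callback: each worker thread runs DoGather on its
// own task_id. lhs_scale and rhs_scale are unused here.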
int GatherRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  auto gather_kernel = reinterpret_cast<GatherCPUKernel *>(cdata);
  auto error_code = gather_kernel->DoGather(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
  }
  return error_code;
}

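// Run() converts the indices to int32 if necessary, launches the gather across
// thread_num_ workers, and then releases any temporary indices buffer it allocated.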
int GatherCPUKernel::Run() {
  auto indices_tensor = in_tensors_.at(1);
  int indices_num = indices_tensor->ElementsNum();
  bool isIndicesInt32 = indices_tensor->data_type() == kNumberTypeInt32;
  int ret = AssignIndicesData(isIndicesInt32, indices_num, indices_tensor);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
    return ret;
  }

  ret = ParallelLaunch(this->ms_context_, GatherRun, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
  }
  if (!isIndicesInt32) {
    ms_context_->allocator->Free(indices_data_);
    indices_data_ = nullptr;
  }
  return ret;
}

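// AssignIndicesData() makes indices_data_ point at int32 indices. Int32 tensors are used
// in place; int64 and float indices are copied into a freshly allocated int32 buffer,
// which Run() frees after the gather completes.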
int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, lite::Tensor *indices_tensor) {
  if (!isIndicesInt32) {
    if (indices_num >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
      MS_LOG(ERROR) << "Input indices_num is invalid, indices_num: " << indices_num;
      return RET_ERROR;
    }
    indices_data_ = reinterpret_cast<int32_t *>(ms_context_->allocator->Malloc(sizeof(int32_t) * indices_num));
    if (indices_data_ == nullptr) {
      MS_LOG(ERROR) << "Memory allocation failed";
      return RET_ERROR;
    }
    if (indices_tensor->data_type() == kNumberTypeInt64) {
      for (int i = 0; i < indices_num; i++) {
        indices_data_[i] = reinterpret_cast<int64_t *>(indices_tensor->MutableData())[i];
      }
    } else {
      for (int i = 0; i < indices_num; i++) {
        indices_data_[i] = static_cast<int>(reinterpret_cast<float *>(indices_tensor->MutableData())[i]);
      }
    }
  } else {
    indices_data_ = reinterpret_cast<int32_t *>(indices_tensor->MutableData());
  }
  return RET_OK;
}

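// The gather copy itself is type-agnostic (it operates on raw bytes), so one kernel is
// registered for both float32 and int32 input tensors.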
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gather, LiteKernelCreator<GatherCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Gather, LiteKernelCreator<GatherCPUKernel>)
}  // namespace mindspore::kernel