1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h"
18 #include <cstring>
19 #include "src/runtime/kernel/arm/fp32/scatter_nd_fp32.h"
20 #include "schema/model_generated.h"
21 #include "src/kernel_registry.h"
22 #include "include/errorcode.h"
23
24 using mindspore::kernel::KERNEL_ARCH;
25 using mindspore::lite::KernelRegistrar;
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_OK;
28 using mindspore::schema::PrimitiveType_ScatterNdUpdate;
29
30 namespace mindspore::kernel {
namespace {
// Fixed positions of the operator's input tensors in in_tensors_.
constexpr int kScatterUpdateInputIndex = 0;  // tensor to be updated in place
constexpr int kScatterIndicesIndex = 1;      // int32 indices, shape [..., unit_rank]
constexpr int kScatterUpdateIndex = 2;       // update values
// Indices tensor must be at least 2-D (outer dims select units, last dim addresses input).
constexpr size_t kScatterIndicesDims = 2;
}  // namespace
Init()37 int ScatterNdUpdateCPUKernel::Init() {
38 CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
39 CHECK_LESS_RETURN(out_tensors_.size(), 1);
40 if (!InferShapeDone()) {
41 return RET_OK;
42 }
43 return ReSize();
44 }
45
ReSize()46 int ScatterNdUpdateCPUKernel::ReSize() {
47 auto input = in_tensors_.at(kScatterUpdateInputIndex);
48 auto indices = in_tensors_.at(kScatterIndicesIndex);
49 auto update = in_tensors_.at(kScatterUpdateIndex);
50 auto output = out_tensors_.front();
51
52 output_ptr_ = reinterpret_cast<float *>(output->MutableData());
53 MS_ASSERT(output_ptr_ != nullptr);
54
55 // check indices shape
56 int input_rank = static_cast<int>(input->shape().size());
57 int indice_unit_rank = indices->shape().back();
58 if (indice_unit_rank > input_rank) {
59 MS_LOG(ERROR) << "Value of last dimension of indices is greater than input rank.";
60 return RET_ERROR;
61 }
62
63 if (indices->shape().size() < kScatterIndicesDims) {
64 MS_LOG(ERROR) << "Indices dimension smaller than 2.";
65 return RET_ERROR;
66 }
67
68 // check consistency of the shape indices and shape
69 int update_rank = static_cast<int>(update->shape().size());
70 auto indices_shape = indices->shape();
71 auto update_shape = update->shape();
72 unit_size_ = 1;
73 for (int i = indices_shape.size() - 1; i < update_rank; i++) {
74 unit_size_ *= update_shape.at(i);
75 }
76
77 // calculate offsets
78 int out_stride = 1;
79 out_strides_.push_back(1);
80 for (int i = indice_unit_rank - 2; i >= 0; i--) {
81 out_stride *= input->shape()[i + 1];
82 out_strides_.push_back(out_stride);
83 }
84 std::reverse(out_strides_.begin(), out_strides_.end());
85
86 num_unit_ = 1;
87 num_unit_ *= update_shape.at(indices_shape.size() - 2);
88 for (int i = indices_shape.size() - 3; i >= 0; i--) {
89 num_unit_ *= update_shape.at(i);
90 }
91
92 int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
93 MS_ASSERT(indices_ptr != nullptr);
94 output_unit_offsets_.clear();
95 for (int i = 0; i < num_unit_; i++) {
96 int tmp_stride = 0;
97 for (int j = 0; j < indice_unit_rank; j++) {
98 tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_.at(j) * unit_size_;
99 }
100 output_unit_offsets_.push_back(tmp_stride);
101 }
102
103 thread_n_num_ = MSMIN(op_parameter_->thread_num_, num_unit_);
104 if (thread_n_num_ == 0) {
105 return RET_ERROR;
106 }
107 thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
108 return RET_OK;
109 }
110
ScatterNdUpdate(int task_id)111 int ScatterNdUpdateCPUKernel::ScatterNdUpdate(int task_id) {
112 int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
113 if (num_unit_thread <= 0) {
114 return RET_OK;
115 }
116 int offset = task_id * thread_n_stride_;
117 auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset,
118 unit_size_, num_unit_thread);
119 if (ret != RET_OK) {
120 MS_LOG(ERROR) << "ScatterNdUpdate error task_id[" << task_id << "] error_code[" << ret << "]";
121 return RET_ERROR;
122 }
123 in_tensors_.at(kScatterUpdateInputIndex)->IncRefCount();
124 return RET_OK;
125 }
126
ScatterNdUpdateRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)127 int ScatterNdUpdateRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
128 auto g_kernel = reinterpret_cast<ScatterNdUpdateCPUKernel *>(cdata);
129 MS_ASSERT(g_kernel != nullptr);
130 auto ret = g_kernel->ScatterNdUpdate(task_id);
131 if (ret != RET_OK) {
132 MS_LOG(ERROR) << "ScatterNdUpdateRun error task_id[" << task_id << "] error_code[" << ret << "]";
133 return RET_ERROR;
134 }
135 return RET_OK;
136 }
137
// Executes the kernel: makes the output hold the input's data (by copy or by
// aliasing the same buffer), refreshes sizing if indices are dynamic, then
// scatters the update values into the output in parallel.
int ScatterNdUpdateCPUKernel::Run() {
  auto in_tensor = in_tensors().front();
  auto out_tensor = out_tensors().front();
  // Copy path: buffers come from different allocators (or none), or we are in
  // a training session where in-place aliasing is not allowed.
  // NOTE(review): assumes both tensors' data is already allocated and the
  // output is at least in_tensor->Size() bytes — TODO confirm with callers.
  if (in_tensor->allocator() == nullptr || in_tensor->allocator() != out_tensor->allocator() ||
      op_parameter_->is_train_session_) {
    memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size());
  } else {
    // In-place path: drop the output's own buffer and alias the input's.
    // Order matters: free/reset before transferring the ref count and pointer.
    out_tensor->FreeData();
    out_tensor->ResetRefCount();
    in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count());
    out_tensor->set_data(in_tensor->data());
    out_tensor->set_own_data(in_tensor->own_data());
    output_ptr_ = reinterpret_cast<float *>(out_tensor->data());
  }
  // Non-const indices can change between runs, so recompute offsets/strides.
  auto indices = in_tensors_.at(kScatterIndicesIndex);
  if (!indices->IsConst() && ReSize() != RET_OK) {
    MS_LOG(ERROR) << "ScatterNdUpdate resize failed.";
    return RET_ERROR;
  }
  auto update = in_tensors_.at(kScatterUpdateIndex);
  update_ptr_ = reinterpret_cast<float *>(update->MutableData());
  MS_ASSERT(update_ptr_ != nullptr);

  // Fan the per-unit scatter work out over thread_n_num_ tasks.
  auto ret = ParallelLaunch(this->ms_context_, ScatterNdUpdateRun, this, thread_n_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ScatterNdUpdate error error_code[" << ret << "]";
    return RET_ERROR;
  }

  return RET_OK;
}
169
// Register this kernel class as the fp32 ScatterNdUpdate implementation on CPU.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterNdUpdate, LiteKernelCreator<ScatterNdUpdateCPUKernel>)
171 } // namespace mindspore::kernel
172