/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp32/scatter_nd_fp32.h"
#include <algorithm>
#include <cstring>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_ScatterNd;

namespace mindspore::kernel {
namespace {
constexpr int kScatterShapeIndex = 0;
constexpr int kScatterIndicesIndex = 1;
constexpr int kScatterUpdateIndex = 2;
}  // namespace
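
// Prepare-time entry: Init() only validates the tensor counts; all
// shape-dependent setup happens in ReSize() once shapes have been inferred.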
int ScatterNDCPUKernel::Init() {
  CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
  CHECK_LESS_RETURN(out_tensors_.size(), 1);
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

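// ReSize() re-validates the shape/indices/update tensors and precomputes the
// per-slice data ScatterND() consumes: unit_size_ (elements per update slice),
// out_strides_ (output strides for the dimensions addressed by indices),
// output_unit_offsets_ (flat element offset of each target slice), and the
// per-thread partition of the slices.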
int ScatterNDCPUKernel::ReSize() {
  auto shape = in_tensors_.at(kScatterShapeIndex);
  auto indices = in_tensors_.at(kScatterIndicesIndex);
  auto update = in_tensors_.at(kScatterUpdateIndex);

  update_ptr_ = reinterpret_cast<float *>(update->MutableData());
  MS_ASSERT(update_ptr_ != nullptr);
  output_ptr_ = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
  MS_ASSERT(output_ptr_ != nullptr);

  // check indices shape: indices must be at least 2-D (checked before calling
  // back()), and its last dimension, the per-entry index tuple length, must
  // not exceed the output rank
  if (indices->shape().size() < 2) {
    MS_LOG(ERROR) << "Indices dimension smaller than 2.";
    return RET_ERROR;
  }
  auto shape_rank = shape->ElementsNum();
  auto shape_data = reinterpret_cast<int *>(shape->MutableData());
  MS_ASSERT(shape_data != nullptr);
  auto indice_unit_rank = indices->shape().back();
  if (indice_unit_rank > shape_rank) {
    MS_LOG(ERROR) << "Value of last dimension of indices is greater than shape rank.";
    return RET_ERROR;
  }

  // check consistency among the ranks of update, indices and shape:
  // update rank must equal (indices rank - 1) + (shape rank - indice_unit_rank)
  auto update_rank = static_cast<int>(update->shape().size());
  auto indices_shape = indices->shape();
  if (update_rank != static_cast<int>(indices->shape().size() - 1 + shape_rank - indice_unit_rank)) {
    MS_LOG(ERROR) << "Ranks of update, shape and indices are inconsistent.";
    return RET_ERROR;
  }
  // check update shape: its leading dimensions must match indices' leading
  // dimensions, and its trailing dimensions must match the scattered slice
  // dimensions of shape
  auto update_shape = update->shape();
  for (size_t i = 0; i < indices_shape.size() - 1; i++) {
    if (update_shape.at(i) != indices_shape.at(i)) {
      MS_LOG(ERROR) << "Value of " << i << " th dimension of update is not equal to that of indices.";
      return RET_ERROR;
    }
  }
  for (int i = 0; i < shape_rank - indice_unit_rank; i++) {
    if (update_shape.at(i + indices_shape.size() - 1) != shape_data[i + indice_unit_rank]) {
      MS_LOG(ERROR) << "Value of " << (i + indices_shape.size() - 1)
                    << " th dimension of update is not equal to the corresponding dimension of shape.";
      return RET_ERROR;
    }
  }

  // calculate unit_size_: the number of elements in one update slice,
  // i.e. the product of update's trailing (non-indexed) dimensions
  unit_size_ = 1;
  for (int i = static_cast<int>(indices_shape.size()) - 1; i < update_rank; i++) {
    unit_size_ *= update_shape.at(i);
  }

  // calculate out_strides_: row-major strides (in units of slices) for the
  // first indice_unit_rank dimensions of the output; built innermost-first,
  // then reversed so that out_strides_[j] is the stride of dimension j
  out_strides_.clear();
  int out_stride = 1;
  out_strides_.push_back(1);
  for (int i = indice_unit_rank - 2; i >= 0; i--) {
    out_stride *= shape_data[i + 1];
    out_strides_.push_back(out_stride);
  }
  std::reverse(out_strides_.begin(), out_strides_.end());
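  // e.g. for an output shape of [4, 5, 6] with indice_unit_rank = 2:
  // out_strides_ = {5, 1} and unit_size_ = 6, so an index tuple (i, j)
  // maps to flat element offset (i * 5 + j) * 6.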

  // calculate num_unit_: the number of update slices, i.e. the product of
  // indices' leading dimensions (all but the last)
  num_unit_ = 1;
  num_unit_ *= update_shape.at(indices_shape.size() - 2);
  for (int i = static_cast<int>(indices_shape.size()) - 3; i >= 0; i--) {
    num_unit_ *= update_shape.at(i);
  }

  // precompute the flat element offset of each target slice in the output:
  // offset_i = sum_j indices[i][j] * out_strides_[j] * unit_size_
  int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
  CHECK_NULL_RETURN(indices_ptr);
  output_unit_offsets_.clear();
  for (int i = 0; i < num_unit_; i++) {
    int tmp_stride = 0;
    for (int j = 0; j < indice_unit_rank; j++) {
      tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_.at(j) * unit_size_;
    }
    output_unit_offsets_.push_back(tmp_stride);
  }

  // split the slices across at most thread_num_ workers; each worker handles
  // a contiguous block of thread_n_stride_ slices
  thread_n_num_ = MSMIN(op_parameter_->thread_num_, num_unit_);
  if (thread_n_num_ == 0) {
    return RET_ERROR;
  }
  thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
  return RET_OK;
}

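// Worker body: copies this task's contiguous block of update slices into the
// output at their precomputed offsets.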
int ScatterNDCPUKernel::ScatterND(int task_id) {
  int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
  if (num_unit_thread <= 0) {
    return RET_OK;
  }
  int offset = task_id * thread_n_stride_;
  auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset,
                         unit_size_, num_unit_thread);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ScatterND error task_id[" << task_id << "] error_code[" << ret << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

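// ParallelLaunch callback: forwards one task to the kernel instance passed via
// cdata; lhs_scale and rhs_scale are part of the callback signature and are
// unused here.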
int ScatterNDRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  CHECK_NULL_RETURN(cdata);
  auto g_kernel = reinterpret_cast<ScatterNDCPUKernel *>(cdata);
  MS_ASSERT(g_kernel != nullptr);
  auto ret = g_kernel->ScatterND(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ScatterNDRun error task_id[" << task_id << "] error_code[" << ret << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

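// Launches thread_n_num_ scatter tasks on the context's thread pool.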
int ScatterNDCPUKernel::Run() {
  auto ret = ParallelLaunch(this->ms_context_, ScatterNDRun, this, thread_n_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "ScatterND error, error_code[" << ret << "]";
    return RET_ERROR;
  }

  return RET_OK;
}

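// Register this class as the fp32 ScatterNd kernel for the CPU backend.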
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterNd, LiteKernelCreator<ScatterNDCPUKernel>)
}  // namespace mindspore::kernel