/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h"
#include <algorithm>
#include <cstring>
#include "src/runtime/kernel/arm/fp32/scatter_nd_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
23 
24 using mindspore::kernel::KERNEL_ARCH;
25 using mindspore::lite::KernelRegistrar;
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_OK;
28 using mindspore::schema::PrimitiveType_ScatterNdUpdate;
29 
30 namespace mindspore::kernel {
namespace {
// Input tensor layout of the ScatterNdUpdate op:
//   in_tensors_[0] = data tensor to be updated (also aliased/copied to output)
//   in_tensors_[1] = int32 indices tensor, rank >= 2; last dim addresses input dims
//   in_tensors_[2] = updates tensor with the values to scatter
constexpr int kScatterUpdateInputIndex = 0;
constexpr int kScatterIndicesIndex = 1;
constexpr int kScatterUpdateIndex = 2;
// Minimum rank required of the indices tensor.
constexpr size_t kScatterIndicesDims = 2;
}  // namespace
Init()37 int ScatterNdUpdateCPUKernel::Init() {
38   CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
39   CHECK_LESS_RETURN(out_tensors_.size(), 1);
40   if (!InferShapeDone()) {
41     return RET_OK;
42   }
43   return ReSize();
44 }
45 
ReSize()46 int ScatterNdUpdateCPUKernel::ReSize() {
47   auto input = in_tensors_.at(kScatterUpdateInputIndex);
48   auto indices = in_tensors_.at(kScatterIndicesIndex);
49   auto update = in_tensors_.at(kScatterUpdateIndex);
50   auto output = out_tensors_.front();
51 
52   output_ptr_ = reinterpret_cast<float *>(output->MutableData());
53   MS_ASSERT(output_ptr_ != nullptr);
54 
55   // check indices shape
56   int input_rank = static_cast<int>(input->shape().size());
57   int indice_unit_rank = indices->shape().back();
58   if (indice_unit_rank > input_rank) {
59     MS_LOG(ERROR) << "Value of last dimension of indices is greater than input rank.";
60     return RET_ERROR;
61   }
62 
63   if (indices->shape().size() < kScatterIndicesDims) {
64     MS_LOG(ERROR) << "Indices dimension smaller than 2.";
65     return RET_ERROR;
66   }
67 
68   // check consistency of the shape indices and shape
69   int update_rank = static_cast<int>(update->shape().size());
70   auto indices_shape = indices->shape();
71   auto update_shape = update->shape();
72   unit_size_ = 1;
73   for (int i = indices_shape.size() - 1; i < update_rank; i++) {
74     unit_size_ *= update_shape.at(i);
75   }
76 
77   // calculate offsets
78   int out_stride = 1;
79   out_strides_.push_back(1);
80   for (int i = indice_unit_rank - 2; i >= 0; i--) {
81     out_stride *= input->shape()[i + 1];
82     out_strides_.push_back(out_stride);
83   }
84   std::reverse(out_strides_.begin(), out_strides_.end());
85 
86   num_unit_ = 1;
87   num_unit_ *= update_shape.at(indices_shape.size() - 2);
88   for (int i = indices_shape.size() - 3; i >= 0; i--) {
89     num_unit_ *= update_shape.at(i);
90   }
91 
92   int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
93   MS_ASSERT(indices_ptr != nullptr);
94   output_unit_offsets_.clear();
95   for (int i = 0; i < num_unit_; i++) {
96     int tmp_stride = 0;
97     for (int j = 0; j < indice_unit_rank; j++) {
98       tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_.at(j) * unit_size_;
99     }
100     output_unit_offsets_.push_back(tmp_stride);
101   }
102 
103   thread_n_num_ = MSMIN(op_parameter_->thread_num_, num_unit_);
104   if (thread_n_num_ == 0) {
105     return RET_ERROR;
106   }
107   thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
108   return RET_OK;
109 }
110 
ScatterNdUpdate(int task_id)111 int ScatterNdUpdateCPUKernel::ScatterNdUpdate(int task_id) {
112   int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
113   if (num_unit_thread <= 0) {
114     return RET_OK;
115   }
116   int offset = task_id * thread_n_stride_;
117   auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset,
118                          unit_size_, num_unit_thread);
119   if (ret != RET_OK) {
120     MS_LOG(ERROR) << "ScatterNdUpdate error task_id[" << task_id << "] error_code[" << ret << "]";
121     return RET_ERROR;
122   }
123   in_tensors_.at(kScatterUpdateInputIndex)->IncRefCount();
124   return RET_OK;
125 }
126 
ScatterNdUpdateRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)127 int ScatterNdUpdateRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
128   auto g_kernel = reinterpret_cast<ScatterNdUpdateCPUKernel *>(cdata);
129   MS_ASSERT(g_kernel != nullptr);
130   auto ret = g_kernel->ScatterNdUpdate(task_id);
131   if (ret != RET_OK) {
132     MS_LOG(ERROR) << "ScatterNdUpdateRun error task_id[" << task_id << "] error_code[" << ret << "]";
133     return RET_ERROR;
134   }
135   return RET_OK;
136 }
137 
Run()138 int ScatterNdUpdateCPUKernel::Run() {
139   auto in_tensor = in_tensors().front();
140   auto out_tensor = out_tensors().front();
141   if (in_tensor->allocator() == nullptr || in_tensor->allocator() != out_tensor->allocator() ||
142       op_parameter_->is_train_session_) {
143     memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size());
144   } else {
145     out_tensor->FreeData();
146     out_tensor->ResetRefCount();
147     in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count());
148     out_tensor->set_data(in_tensor->data());
149     out_tensor->set_own_data(in_tensor->own_data());
150     output_ptr_ = reinterpret_cast<float *>(out_tensor->data());
151   }
152   auto indices = in_tensors_.at(kScatterIndicesIndex);
153   if (!indices->IsConst() && ReSize() != RET_OK) {
154     MS_LOG(ERROR) << "ScatterNdUpdate resize failed.";
155     return RET_ERROR;
156   }
157   auto update = in_tensors_.at(kScatterUpdateIndex);
158   update_ptr_ = reinterpret_cast<float *>(update->MutableData());
159   MS_ASSERT(update_ptr_ != nullptr);
160 
161   auto ret = ParallelLaunch(this->ms_context_, ScatterNdUpdateRun, this, thread_n_num_);
162   if (ret != RET_OK) {
163     MS_LOG(ERROR) << "ScatterNdUpdate error error_code[" << ret << "]";
164     return RET_ERROR;
165   }
166 
167   return RET_OK;
168 }
169 
// Register this kernel as the fp32 ScatterNdUpdate implementation for CPU.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterNdUpdate, LiteKernelCreator<ScatterNdUpdateCPUKernel>)
171 }  // namespace mindspore::kernel
172