• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/runtime/kernel/arm/fp32/scatter_nd_fp32.h"
18 #include <cstring>
19 #include <vector>
20 #include "schema/model_generated.h"
21 #include "src/kernel_registry.h"
22 #include "include/errorcode.h"
23 
24 using mindspore::kernel::KERNEL_ARCH;
25 using mindspore::lite::KernelRegistrar;
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_OK;
28 using mindspore::schema::PrimitiveType_ScatterNd;
29 
30 namespace mindspore::kernel {
namespace {
// Positions of the ScatterNd inputs inside in_tensors_: target shape, indices, update values.
constexpr int kScatterShapeIndex{0};
constexpr int kScatterIndicesIndex{1};
constexpr int kScatterUpdateIndex{2};
}  // namespace
Init()36 int ScatterNDCPUKernel::Init() {
37   CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
38   CHECK_LESS_RETURN(out_tensors_.size(), 1);
39   if (!InferShapeDone()) {
40     return RET_OK;
41   }
42   return ReSize();
43 }
44 
ReSize()45 int ScatterNDCPUKernel::ReSize() {
46   auto shape = in_tensors_.at(kScatterShapeIndex);
47   auto indices = in_tensors_.at(kScatterIndicesIndex);
48   auto update = in_tensors_.at(kScatterUpdateIndex);
49 
50   update_ptr_ = reinterpret_cast<float *>(update->MutableData());
51   MS_ASSERT(update_ptr_ != nullptr);
52   output_ptr_ = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
53   MS_ASSERT(output_ptr_ != nullptr);
54 
55   // check indices shape
56   auto shape_rank = shape->ElementsNum();
57   auto shape_data = reinterpret_cast<int *>(shape->MutableData());
58   MS_ASSERT(shape_data != nullptr);
59   auto indice_unit_rank = indices->shape().back();
60   if (indice_unit_rank > shape_rank) {
61     MS_LOG(ERROR) << "Value of last dimension of indices is greater than shape rank.";
62     return RET_ERROR;
63   }
64 
65   if (indices->shape().size() < 2) {
66     MS_LOG(ERROR) << "Indices dimension smaller than 2.";
67     return RET_ERROR;
68   }
69 
70   // check consistency of the shape indices and shape
71   auto update_rank = static_cast<int>(update->shape().size());
72   auto indices_shape = indices->shape();
73   if (update_rank != static_cast<int>(indices->shape().size() - 1 + shape_rank - indice_unit_rank)) {
74     MS_LOG(ERROR) << "Update, shape rank and indices rank inconsistent.";
75     return RET_ERROR;
76   }
77   // check update shape
78   auto update_shape = update->shape();
79   for (size_t i = 0; i < indices_shape.size() - 1; i++) {
80     if (update_shape.at(i) != indices_shape.at(i)) {
81       MS_LOG(ERROR) << "Value of " << i << " th dimension of indices is not equal to that of update.";
82       return RET_ERROR;
83     }
84   }
85   for (size_t i = 0; i < shape->ElementsNum() - (indices_shape.size() - 1); i++) {
86     if (update_shape.at(i + indices_shape.size() - 1) != shape_data[i + indices_shape.size() - 1]) {
87       MS_LOG(ERROR) << "Value of " << i + indices_shape.size() - 1
88                     << " th dimension of indices is not equal to the corresbonding dimension of shape.";
89       return RET_ERROR;
90     }
91   }
92 
93   // calculate unit_size_
94   unit_size_ = 1;
95   for (int i = indices_shape.size() - 1; i < update_rank; i++) {
96     unit_size_ *= update_shape.at(i);
97   }
98 
99   // calculate offsets
100   int out_stride = 1;
101   out_strides_.push_back(1);
102   for (int i = indice_unit_rank - 2; i >= 0; i--) {
103     out_stride *= shape_data[i + 1];
104     out_strides_.push_back(out_stride);
105   }
106 
107   num_unit_ = 1;
108   num_unit_ *= update_shape.at(indices_shape.size() - 2);
109   for (int i = indices_shape.size() - 3; i >= 0; i--) {
110     num_unit_ *= update_shape.at(i);
111   }
112 
113   int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
114   CHECK_NULL_RETURN(indices_ptr);
115   output_unit_offsets_.clear();
116   for (int i = 0; i < num_unit_; i++) {
117     int tmp_stride = 0;
118     for (int j = 0; j < indice_unit_rank; j++) {
119       tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_.at(j) * unit_size_;
120     }
121     output_unit_offsets_.push_back(tmp_stride);
122   }
123 
124   thread_n_num_ = MSMIN(op_parameter_->thread_num_, num_unit_);
125   if (thread_n_num_ == 0) {
126     return RET_ERROR;
127   }
128   thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
129   return RET_OK;
130 }
131 
ScatterND(int task_id)132 int ScatterNDCPUKernel::ScatterND(int task_id) {
133   int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
134   if (num_unit_thread <= 0) {
135     return RET_OK;
136   }
137   int offset = task_id * thread_n_stride_;
138   MS_LOG(ERROR) << "offset " << offset;
139   auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset,
140                          unit_size_, num_unit_thread);
141   if (ret != RET_OK) {
142     MS_LOG(ERROR) << "ScatterND error task_id[" << task_id << "] error_code[" << ret << "]";
143     return RET_ERROR;
144   }
145   return RET_OK;
146 }
147 
ScatterNDRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)148 int ScatterNDRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
149   CHECK_NULL_RETURN(cdata);
150   auto g_kernel = reinterpret_cast<ScatterNDCPUKernel *>(cdata);
151   MS_ASSERT(g_kernel != nullptr);
152   auto ret = g_kernel->ScatterND(task_id);
153   if (ret != RET_OK) {
154     MS_LOG(ERROR) << "ScatterNDRun error task_id[" << task_id << "] error_code[" << ret << "]";
155     return RET_ERROR;
156   }
157   return RET_OK;
158 }
159 
Run()160 int ScatterNDCPUKernel::Run() {
161   auto ret = ParallelLaunch(this->ms_context_, ScatterNDRun, this, thread_n_num_);
162   if (ret != RET_OK) {
163     MS_LOG(ERROR) << "ScatterND error error_code[" << ret << "]";
164     return RET_ERROR;
165   }
166 
167   return RET_OK;
168 }
169 
170 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterNd, LiteKernelCreator<ScatterNDCPUKernel>)
171 }  // namespace mindspore::kernel
172