/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_OPTIMIZER_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_OPTIMIZER_CPU_KERNEL_H_

#include <vector>
#include <memory>
#include <unordered_map>
#include <algorithm>
#include <utility>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "common/thread_pool.h"
namespace mindspore {
namespace kernel {
// A non-owning view of one sparse gradient: value_ holds indices_size_ rows of
// value_stride_ floats each, and indices_ holds the corresponding row ids.
template <typename T>
struct SparseGradient {
  float *value_{nullptr};
  T *indices_{nullptr};
  size_t indices_size_{0};
};

template <typename T>
struct ReduceSparseGradientParam {
  SparseGradient<T> *input_grad_{nullptr};
  SparseGradient<T> *workspace_grad_{nullptr};
  SparseGradient<T> *output_grad_{nullptr};
  size_t max_index_{0};
  size_t value_stride_{0};
  bool use_sort_reduce_{false};
};

// Inputs shared by the per-thread optimizer update functions; only the fields
// used by a given optimizer (e.g. Adam-style or FTRL-style kernels) are populated.
template <typename T>
struct MultiThreadComputeParams {
  float *var_{nullptr};
  float *accum_{nullptr};
  float *linear_{nullptr};
  float *m_{nullptr};
  float *m_t_{nullptr};
  float *v_{nullptr};
  float lr_{0};
  float l1_{0};
  float l2_{0};
  float lr_power_{0};
  float beta1_{0};
  float beta2_{0};
  float epsilon_{0};
  SparseGradient<T> sparse_grad_;
  size_t var_first_dim_size_{0};
  size_t var_outer_dim_size_{0};
  bool use_nesterov_{false};
};

template <typename T>
using MultiThreadComputeFunc = std::function<void(MultiThreadComputeParams<T> *param, size_t start, size_t end)>;

// One bucket's slice of the output: local indices plus the position of each
// index in the original input (global_indices_).
template <typename T>
struct BucketSparseGradient {
  float *value_{nullptr};
  T *indices_{nullptr};
  T *global_indices_{nullptr};
  size_t indices_size_{0};
};

// Same fields as ReduceSparseGradientParam plus the resolved thread count.
template <typename T>
struct MultiThreadReduceSparseGradientParam {
  SparseGradient<T> *input_grad_{nullptr};
  SparseGradient<T> *workspace_grad_{nullptr};
  SparseGradient<T> *output_grad_{nullptr};
  size_t max_index_{0};
  size_t value_stride_{0};
  size_t thread_num_{0};
  bool use_sort_reduce_{false};
};

class SparseOptimizerCPUKernel : public CPUKernel {
 public:
  SparseOptimizerCPUKernel() = default;
  ~SparseOptimizerCPUKernel() override = default;
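
  // Deduplicates a sparse gradient by accumulating the value rows of repeated
  // indices. The work runs in four multi-threaded phases:
  //   1. split the indices into per-thread segments and count bucket sizes;
  //   2. scatter each segment's indices into per-thread output buckets;
  //   3. reduce every bucket (hash-based by default, sort-based when
  //      use_sort_reduce_ is set) into the workspace;
  //   4. merge the reduced buckets back into output_grad_.
  //
  // Illustrative call, a minimal sketch only: the buffer names and their
  // allocation are assumptions, not part of this header. Workspace and output
  // must each hold at least indices_size entries.
  //
  //   SparseGradient<int> input{grad_values, grad_indices, indices_size};
  //   SparseGradient<int> workspace{ws_values, ws_indices, indices_size};
  //   SparseGradient<int> output{out_values, out_indices, indices_size};
  //   ReduceSparseGradientParam<int> param;
  //   param.input_grad_ = &input;
  //   param.workspace_grad_ = &workspace;
  //   param.output_grad_ = &output;
  //   param.max_index_ = var_first_dim_size;     // rows in the variable
  //   param.value_stride_ = var_outer_dim_size;  // elements per row
  //   SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);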
  template <typename T>
  static void BucketReduceSparseGradient(const ReduceSparseGradientParam<T> &param) {
    MS_LOG(DEBUG) << "Start";
    MS_EXCEPTION_IF_NULL(param.input_grad_);
    size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
    if (param.input_grad_->indices_size_ < thread_num) {
      thread_num = param.input_grad_->indices_size_;
    }
    MultiThreadReduceSparseGradientParam<T> multi_thread_param(
      {param.input_grad_, param.workspace_grad_, param.output_grad_, param.max_index_, param.value_stride_, thread_num,
       param.use_sort_reduce_});
    std::vector<std::shared_ptr<SparseGradient<T>>> segments;
    std::vector<std::shared_ptr<std::vector<size_t>>> segment_bucket_sizes;
    SplitAndCalculateSegmentBucketSize(multi_thread_param, &segments, &segment_bucket_sizes);

    std::vector<std::shared_ptr<BucketSparseGradient<T>>> buckets;
    GatherSegmentIndicesToOutputBucket(multi_thread_param, segments, segment_bucket_sizes, &buckets);

    std::vector<std::shared_ptr<SparseGradient<T>>> reduced_buckets;
    ReduceBucketSparseGradientToWorkspace(multi_thread_param, buckets, &reduced_buckets);

    MergeReduceSparseGradient(multi_thread_param, reduced_buckets);
    MS_LOG(DEBUG) << "End";
  }

 protected:
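  // Splits [0, total_compute_size) into contiguous chunks of
  // ceil(total_compute_size / thread_num) elements and runs func on each chunk
  // through the shared thread pool, blocking until all chunks finish.
  //
  // A minimal sketch of a compute functor, e.g. from a derived kernel's Launch;
  // the update rule and the params/total_size variables are illustrative
  // placeholders, not one of the real optimizer kernels:
  //
  //   MultiThreadComputeFunc<int> func = [](MultiThreadComputeParams<int> *p, size_t start, size_t end) {
  //     for (size_t i = start; i < end; ++i) {
  //       p->var_[i] -= p->lr_ * p->m_[i];
  //     }
  //   };
  //   MultiThreadCompute<int>(func, &params, total_size);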
  template <typename T>
  void MultiThreadCompute(const MultiThreadComputeFunc<T> &func, MultiThreadComputeParams<T> *params,
                          size_t total_compute_size) const {
    std::vector<common::Task> tasks;
    auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
    tasks.reserve(max_thread_num);
    size_t start = 0;
    size_t once_compute_size = (total_compute_size + max_thread_num - 1) / max_thread_num;
    while (start < total_compute_size) {
      size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size);
      auto task = [&func, &params, start, end]() {
        func(params, start, end);
        return common::SUCCESS;
      };
      (void)tasks.emplace_back(task);
      start += once_compute_size;
    }
    (void)common::ThreadPool::GetInstance().SyncRun(tasks);
  }

 private:
  // Counts how many valid indices of one segment fall into each bucket
  // (bucket id = index % bucket_num); out-of-range indices are skipped.
  template <typename T>
  static void CalculateEachBucketSize(const std::shared_ptr<SparseGradient<T>> &sparse_grad, size_t max_index,
                                      std::vector<size_t> *each_bucket_size) {
    MS_LOG(DEBUG) << "Start";
    MS_EXCEPTION_IF_NULL(sparse_grad);
    MS_EXCEPTION_IF_NULL(sparse_grad->indices_);
    MS_EXCEPTION_IF_NULL(each_bucket_size);
    size_t bucket_num = each_bucket_size->size();
    if (bucket_num < 1) {
      MS_LOG(EXCEPTION) << "Bucket num must > 0!";
    }
    for (size_t i = 0; i < sparse_grad->indices_size_; ++i) {
      T index = sparse_grad->indices_[i];
      if (index >= 0 && LongToSize(index) < max_index) {
        auto bucket_id = index % bucket_num;
        each_bucket_size->at(bucket_id)++;
      }
    }
    MS_LOG(DEBUG) << "End";
  }

  // Splits the input indices into thread_num_ nearly equal segments and counts
  // each segment's per-bucket sizes in parallel.
  template <typename T>
  static void SplitAndCalculateSegmentBucketSize(
    const MultiThreadReduceSparseGradientParam<T> &param, std::vector<std::shared_ptr<SparseGradient<T>>> *segments_ptr,
    std::vector<std::shared_ptr<std::vector<size_t>>> *segment_bucket_sizes_ptr) {
    MS_EXCEPTION_IF_NULL(param.input_grad_);
    MS_EXCEPTION_IF_NULL(segment_bucket_sizes_ptr);
    MS_EXCEPTION_IF_NULL(segments_ptr);
    auto &segments = *segments_ptr;
    auto &segment_bucket_sizes = *segment_bucket_sizes_ptr;
    auto input_grad = param.input_grad_;
    if (param.thread_num_ < 1) {
      MS_EXCEPTION(ArgumentError) << "Input param thread num must > 0!";
    }
    size_t thread_indices_size = input_grad->indices_size_ / param.thread_num_;
    size_t left_indices_size = input_grad->indices_size_ % param.thread_num_;
    std::vector<common::Task> tasks;
    tasks.reserve(param.thread_num_);
    segments.reserve(param.thread_num_);

    size_t current_indices_offset = 0;
    for (size_t i = 0; i < param.thread_num_; ++i) {
      (void)segment_bucket_sizes.emplace_back(std::make_shared<std::vector<size_t>>(param.thread_num_, 0));
      size_t indices_size = thread_indices_size;
      if (i < left_indices_size) {
        indices_size += 1;
      }
      (void)segments.emplace_back(std::make_shared<SparseGradient<T>>());
      segments[i]->value_ = input_grad->value_ + current_indices_offset * param.value_stride_;
      segments[i]->indices_ = input_grad->indices_ + current_indices_offset;
      segments[i]->indices_size_ = indices_size;
      auto task = [&segments, &param, &segment_bucket_sizes, i]() {
        CalculateEachBucketSize<T>(segments[i], param.max_index_, segment_bucket_sizes[i].get());
        return common::SUCCESS;
      };
      (void)tasks.emplace_back(task);
      current_indices_offset += indices_size;
    }
    (void)common::ThreadPool::GetInstance().SyncRun(tasks);
  }
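
  // Scatters one segment's valid indices into the per-thread buckets prepared by
  // GatherSegmentIndicesToOutputBucket. An index goes to bucket
  // (index % thread_num_), so equal indices always land in the same bucket and
  // each bucket can later be reduced independently. global_indices_ records the
  // position of each index in the original input so its value row can be found
  // during reduction.
  //
  // For example, with thread_num_ = 2 and segment indices {4, 1, 4, 3}, bucket 0
  // receives {4, 4} and bucket 1 receives {1, 3}.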
  template <typename T>
  static void CopySegmentIndicesToBucket(const MultiThreadReduceSparseGradientParam<T> &param,
                                         const std::shared_ptr<SparseGradient<T>> &segment, size_t bucket_offset,
                                         const std::vector<std::shared_ptr<BucketSparseGradient<T>>> &buckets) {
    MS_LOG(DEBUG) << "Start";
    MS_EXCEPTION_IF_NULL(segment);
    MS_EXCEPTION_IF_NULL(segment->indices_);
    if (param.thread_num_ == 0) {
      MS_EXCEPTION(ArgumentError) << "Input param thread num must > 0!";
    }
    std::vector<size_t> bucket_data_num(param.thread_num_, 0);
    for (size_t i = 0; i < segment->indices_size_; ++i) {
      T index = segment->indices_[i];
      if (index >= 0 && LongToSize(index) < param.max_index_) {
        auto bucket_id = index % param.thread_num_;
        auto bucket_index = bucket_data_num[bucket_id];
        buckets[bucket_id]->indices_[bucket_index] = index;
        buckets[bucket_id]->global_indices_[bucket_index] = bucket_offset + i;
        bucket_data_num[bucket_id]++;
      }
    }
    MS_LOG(DEBUG) << "End";
  }

  // Lays the buckets out contiguously over the output and workspace buffers,
  // builds per-(segment, bucket) sub-views so threads never write overlapping
  // ranges, then copies every segment's indices into its buckets in parallel.
  template <typename T>
  static void GatherSegmentIndicesToOutputBucket(
    const MultiThreadReduceSparseGradientParam<T> &param,
    const std::vector<std::shared_ptr<SparseGradient<T>>> &segments,
    const std::vector<std::shared_ptr<std::vector<size_t>>> &segment_bucket_sizes,
    std::vector<std::shared_ptr<BucketSparseGradient<T>>> *buckets_ptr) {
    MS_EXCEPTION_IF_NULL(param.output_grad_);
    MS_EXCEPTION_IF_NULL(param.output_grad_->value_);
    MS_EXCEPTION_IF_NULL(param.output_grad_->indices_);
    MS_EXCEPTION_IF_NULL(param.workspace_grad_);
    MS_EXCEPTION_IF_NULL(param.workspace_grad_->indices_);
    MS_EXCEPTION_IF_NULL(buckets_ptr);
    auto &buckets = *buckets_ptr;
    size_t thread_num = param.thread_num_;
    if (thread_num != segment_bucket_sizes.size()) {
      MS_EXCEPTION(ArgumentError) << "Input param thread num not equal to segment size!";
    }
    std::vector<size_t> bucket_data_size(thread_num, 0);
    for (size_t i = 0; i < thread_num; ++i) {
      for (size_t j = 0; j < thread_num; ++j) {
        bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
      }
    }
    size_t current_indices_offset = 0;
    for (size_t i = 0; i < thread_num; ++i) {
      (void)buckets.emplace_back(std::make_shared<BucketSparseGradient<T>>());
      buckets[i]->value_ = param.output_grad_->value_ + current_indices_offset * param.value_stride_;
      buckets[i]->indices_ = param.output_grad_->indices_ + current_indices_offset;
      buckets[i]->global_indices_ = param.workspace_grad_->indices_ + current_indices_offset;
      buckets[i]->indices_size_ = bucket_data_size[i];
      current_indices_offset += bucket_data_size[i];
    }
    std::vector<size_t> tmp_bucket_data_size(thread_num, 0);
    std::vector<std::vector<std::shared_ptr<BucketSparseGradient<T>>>> each_thread_buckets;
    for (size_t i = 0; i < thread_num; ++i) {
      std::vector<std::shared_ptr<BucketSparseGradient<T>>> thread_buckets;
      for (size_t j = 0; j < thread_num; ++j) {
        (void)thread_buckets.emplace_back(std::make_shared<BucketSparseGradient<T>>());
        thread_buckets[j]->indices_ = buckets[j]->indices_ + tmp_bucket_data_size[j];
        thread_buckets[j]->global_indices_ = buckets[j]->global_indices_ + tmp_bucket_data_size[j];
        thread_buckets[j]->value_ = buckets[j]->value_ + tmp_bucket_data_size[j] * param.value_stride_;
        thread_buckets[j]->indices_size_ = segment_bucket_sizes[i]->at(j);
        tmp_bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
      }
      (void)each_thread_buckets.emplace_back(thread_buckets);
    }
    std::vector<common::Task> tasks;
    tasks.reserve(thread_num);
    current_indices_offset = 0;
    for (size_t i = 0; i < thread_num; ++i) {
      auto task = [&param, &segments, &each_thread_buckets, i, current_indices_offset]() {
        CopySegmentIndicesToBucket<T>(param, segments[i], current_indices_offset, each_thread_buckets[i]);
        return common::SUCCESS;
      };
      (void)tasks.emplace_back(task);
      current_indices_offset += segments[i]->indices_size_;
    }
    (void)common::ThreadPool::GetInstance().SyncRun(tasks);
  }
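
  // Reduces one bucket by sorting its (index, global_index) pairs and merging
  // runs of equal indices: the first occurrence copies its value row into the
  // reduced bucket, later occurrences accumulate into that row. Used when
  // use_sort_reduce_ is set; the hash-map variant below is the default.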
  template <typename T>
  static void SortAndReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam<T> &param,
                                                const std::shared_ptr<BucketSparseGradient<T>> &bucket,
                                                const std::shared_ptr<SparseGradient<T>> &reduced_bucket) {
    MS_LOG(DEBUG) << "Start";
    MS_EXCEPTION_IF_NULL(bucket);
    MS_EXCEPTION_IF_NULL(bucket->value_);
    MS_EXCEPTION_IF_NULL(bucket->indices_);
    MS_EXCEPTION_IF_NULL(reduced_bucket);
    MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
    MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
    std::vector<std::pair<T, T>> sorted_indices;
    sorted_indices.reserve(bucket->indices_size_);
    for (size_t i = 0; i < bucket->indices_size_; ++i) {
      T index = bucket->indices_[i];
      T global_index = bucket->global_indices_[i];
      (void)sorted_indices.emplace_back(std::pair<T, T>(index, global_index));
    }
    std::sort(sorted_indices.begin(), sorted_indices.end());

    float *global_value = param.input_grad_->value_;
    size_t unique_indices_size = 0;
    size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
    T last_index{0};
    size_t value_offset{0};
    for (size_t i = 0; i < sorted_indices.size(); ++i) {
      T index = sorted_indices[i].first;
      T global_index = sorted_indices[i].second;
      T global_value_offset = global_index * param.value_stride_;
      if (i == 0 || index != last_index) {
        if (i != 0) {
          unique_indices_size++;
        }
        reduced_bucket->indices_[unique_indices_size] = index;
        value_offset = unique_indices_size * param.value_stride_;
        auto ret_code = memcpy_s(reduced_bucket->value_ + value_offset, (max_length - value_offset) * sizeof(float),
                                 global_value + global_value_offset, param.value_stride_ * sizeof(float));
        if (ret_code != EOK) {
          MS_LOG(EXCEPTION) << "Failed to copy data!";
        }
      } else {
        for (size_t j = 0; j < param.value_stride_; ++j) {
          reduced_bucket->value_[value_offset + j] += global_value[global_value_offset + j];
        }
      }
      last_index = index;
    }
    // The loop above only increments on index transitions, so count the final
    // run here; otherwise a non-empty bucket would report one unique index too few.
    if (!sorted_indices.empty()) {
      unique_indices_size++;
    }
    reduced_bucket->indices_size_ = unique_indices_size;
    MS_LOG(DEBUG) << "End";
  }

  // Hash-map variant of the bucket reduction: the first occurrence of an index
  // copies its value row, later occurrences accumulate in place via index_map.
  template <typename T>
  static void ReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam<T> &param,
                                         const std::shared_ptr<BucketSparseGradient<T>> &bucket,
                                         const std::shared_ptr<SparseGradient<T>> &reduced_bucket) {
    MS_LOG(DEBUG) << "Start";
    MS_EXCEPTION_IF_NULL(bucket);
    MS_EXCEPTION_IF_NULL(bucket->value_);
    MS_EXCEPTION_IF_NULL(bucket->indices_);
    MS_EXCEPTION_IF_NULL(reduced_bucket);
    MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
    MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);

    float *global_value = param.input_grad_->value_;
    std::unordered_map<T, size_t> index_map;
    size_t unique_indices_size = 0;
    size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
    for (size_t i = 0; i < bucket->indices_size_; ++i) {
      T index = bucket->indices_[i];
      T global_index = bucket->global_indices_[i];
      auto iter = index_map.find(index);
      if (iter == index_map.end()) {
        reduced_bucket->indices_[unique_indices_size] = index;
        size_t start_index = unique_indices_size * param.value_stride_;
        index_map[index] = start_index;
        auto ret_code =
          memcpy_s(reduced_bucket->value_ + start_index, (max_length - start_index) * sizeof(float),
                   global_value + global_index * param.value_stride_, param.value_stride_ * sizeof(float));
        if (ret_code != EOK) {
          MS_LOG(EXCEPTION) << "Failed to copy data!";
        }
        unique_indices_size++;
      } else {
        size_t start_index = iter->second;
        size_t end_index = start_index + param.value_stride_;
        for (size_t j = start_index, k = global_index * param.value_stride_; j < end_index; ++j, ++k) {
          reduced_bucket->value_[j] += global_value[k];
        }
      }
    }
    reduced_bucket->indices_size_ = unique_indices_size;
    MS_LOG(DEBUG) << "End";
  }
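
  // Runs one reduce task per bucket, writing each bucket's deduplicated
  // (indices, values) into a contiguous slice of the workspace. Each reduced
  // slice is sized from the bucket's pre-reduction indices_size_, an upper
  // bound on its unique-index count; MergeReduceSparseGradient then packs the
  // reduced slices back into output_grad_ and sets the final indices_size_.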
  template <typename T>
  static void ReduceBucketSparseGradientToWorkspace(
    const MultiThreadReduceSparseGradientParam<T> &param,
    const std::vector<std::shared_ptr<BucketSparseGradient<T>>> &buckets,
    std::vector<std::shared_ptr<SparseGradient<T>>> *reduced_buckets_ptr) {
    MS_EXCEPTION_IF_NULL(param.workspace_grad_);
    MS_EXCEPTION_IF_NULL(param.workspace_grad_->value_);
    MS_EXCEPTION_IF_NULL(param.workspace_grad_->indices_);
    MS_EXCEPTION_IF_NULL(reduced_buckets_ptr);
    auto &reduced_buckets = *reduced_buckets_ptr;
    size_t thread_num = buckets.size();
    std::vector<common::Task> tasks;
    tasks.reserve(thread_num);

    size_t current_indices_offset = 0;
    for (size_t i = 0; i < thread_num; ++i) {
      (void)reduced_buckets.emplace_back(std::make_shared<SparseGradient<T>>());
      reduced_buckets[i]->value_ = param.workspace_grad_->value_ + current_indices_offset * param.value_stride_;
      reduced_buckets[i]->indices_ = param.workspace_grad_->indices_ + current_indices_offset;
      reduced_buckets[i]->indices_size_ = buckets[i]->indices_size_;
      auto task = [&param, &buckets, &reduced_buckets, i]() {
        if (param.use_sort_reduce_) {
          SortAndReduceBucketSparseGradient<T>(param, buckets[i], reduced_buckets[i]);
        } else {
          ReduceBucketSparseGradient<T>(param, buckets[i], reduced_buckets[i]);
        }
        return common::SUCCESS;
      };
      (void)tasks.emplace_back(task);
      current_indices_offset += buckets[i]->indices_size_;
    }
    (void)common::ThreadPool::GetInstance().SyncRun(tasks);
  }

  template <typename T>
  static void MergeReduceSparseGradient(const MultiThreadReduceSparseGradientParam<T> &param,
                                        const std::vector<std::shared_ptr<SparseGradient<T>>> &reduced_buckets) {
    MS_EXCEPTION_IF_NULL(param.output_grad_);
    auto output_grad = param.output_grad_;
    MS_EXCEPTION_IF_NULL(output_grad->value_);
    MS_EXCEPTION_IF_NULL(output_grad->indices_);
    size_t stride_data_size = param.value_stride_ * sizeof(float);
    size_t unique_indices_size = 0;
    for (size_t i = 0; i < reduced_buckets.size(); ++i) {
      auto &bucket = reduced_buckets[i];
      MS_EXCEPTION_IF_NULL(bucket);
      if (bucket->indices_size_ == 0) {
        continue;
      }
      auto ret_code = memcpy_s(output_grad->value_ + unique_indices_size * param.value_stride_,
                               (output_grad->indices_size_ - unique_indices_size) * stride_data_size, bucket->value_,
                               bucket->indices_size_ * stride_data_size);
      if (ret_code != EOK) {
        MS_LOG(EXCEPTION) << "Failed to copy data!";
      }
      ret_code = memcpy_s(output_grad->indices_ + unique_indices_size,
                          (output_grad->indices_size_ - unique_indices_size) * sizeof(T), bucket->indices_,
                          bucket->indices_size_ * sizeof(T));
      if (ret_code != EOK) {
        MS_LOG(EXCEPTION) << "Failed to copy data!";
      }
      unique_indices_size += bucket->indices_size_;
    }
    output_grad->indices_size_ = unique_indices_size;
  }

 protected:
  TypeId indices_data_type_{kNumberTypeInt32};
  size_t indices_size_{0};
  size_t var_first_dim_size_{0};
  size_t var_outer_dim_size_{1};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_OPTIMIZER_CPU_KERNEL_H_