• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_OPTIMIZER_CPU_KERNEL_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_OPTIMIZER_CPU_KERNEL_H_
19 
#include <algorithm>
#include <functional>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "common/thread_pool.h"
28 namespace mindspore {
29 namespace kernel {
// Non-owning view of a sparse gradient: a flat buffer of gradient values plus
// the row index each value slice belongs to.
template <typename T>
struct SparseGradient {
  float *value_{nullptr};   // flattened gradient values (not owned)
  T *indices_{nullptr};     // row index per value slice (not owned)
  size_t indices_size_{0};  // number of entries in indices_
};
36 
37 template <typename T>
38 struct ReduceSparseGradientParam {
39   SparseGradient<T> *input_grad_{nullptr};
40   SparseGradient<T> *workspace_grad_{nullptr};
41   SparseGradient<T> *output_grad_{nullptr};
42   size_t max_index_{0};
43   size_t value_stride_{0};
44   bool use_sort_reduce_{false};
45 };
46 
47 template <typename T>
48 struct MultiThreadComputeParams {
49   float *var_{nullptr};
50   float *accum_{nullptr};
51   float *linear_{nullptr};
52   float *m_{nullptr};
53   float *m_t_{nullptr};
54   float *v_{nullptr};
55   float lr_{0};
56   float l1_{0};
57   float l2_{0};
58   float lr_power_{0};
59   float beta1_{0};
60   float beta2_{0};
61   float epsilon_{0};
62   SparseGradient<T> sparse_grad_;
63   size_t var_first_dim_size_{0};
64   size_t var_outer_dim_size_{0};
65   bool use_nesterov_;
66 };
67 
68 template <typename T>
69 using MultiThreadComputeFunc = std::function<void(MultiThreadComputeParams<T> *param, size_t start, size_t end)>;
70 
// Non-owning view of one bucket built during the bucket reduce: the indices
// scattered into this bucket, their positions in the original input gradient,
// and the bucket's value slice. Default member initializers added so a
// default-constructed bucket is in a defined state (consistent with
// SparseGradient above).
template <typename T>
struct BucketSparseGradient {
  float *value_{nullptr};       // value slice for this bucket (not owned)
  T *indices_{nullptr};         // indices that hash into this bucket
  T *global_indices_{nullptr};  // position of each index in the input gradient
  size_t indices_size_{0};      // number of entries in this bucket
};
78 
79 template <typename T>
80 struct MultiThreadReduceSparseGradientParam {
81   SparseGradient<T> *input_grad_{nullptr};
82   SparseGradient<T> *workspace_grad_{nullptr};
83   SparseGradient<T> *output_grad_{nullptr};
84   size_t max_index_{0};
85   size_t value_stride_{0};
86   size_t thread_num_{0};
87   bool use_sort_reduce_{false};
88 };
89 
90 class SparseOptimizerCPUKernel : public CPUKernel {
91  public:
92   SparseOptimizerCPUKernel() = default;
93   ~SparseOptimizerCPUKernel() override = default;
94 
95   template <typename T>
BucketReduceSparseGradient(const ReduceSparseGradientParam<T> & param)96   static void BucketReduceSparseGradient(const ReduceSparseGradientParam<T> &param) {
97     MS_LOG(DEBUG) << "Start";
98     MS_EXCEPTION_IF_NULL(param.input_grad_);
99     size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
100     if (param.input_grad_->indices_size_ < thread_num) {
101       thread_num = param.input_grad_->indices_size_;
102     }
103     MultiThreadReduceSparseGradientParam<T> multi_thread_param(
104       {param.input_grad_, param.workspace_grad_, param.output_grad_, param.max_index_, param.value_stride_, thread_num,
105        param.use_sort_reduce_});
106     std::vector<std::shared_ptr<SparseGradient<T>>> segments;
107     std::vector<std::shared_ptr<std::vector<size_t>>> segment_bucket_sizes;
108     SplitAndCalculateSegmentBucketSize(multi_thread_param, &segments, &segment_bucket_sizes);
109 
110     std::vector<std::shared_ptr<BucketSparseGradient<T>>> buckets;
111     GatherSegmentIndicesToOutputBucket(multi_thread_param, segments, segment_bucket_sizes, &buckets);
112 
113     std::vector<std::shared_ptr<SparseGradient<T>>> reduced_buckets;
114     ReduceBucketSparseGradientToWorkspace(multi_thread_param, buckets, &reduced_buckets);
115 
116     MergeReduceSparseGradient(multi_thread_param, reduced_buckets);
117     MS_LOG(DEBUG) << "End";
118   }
119 
120  protected:
121   template <typename T>
MultiThreadCompute(const MultiThreadComputeFunc<T> & func,MultiThreadComputeParams<T> * params,size_t total_compute_size)122   void MultiThreadCompute(const MultiThreadComputeFunc<T> &func, MultiThreadComputeParams<T> *params,
123                           size_t total_compute_size) const {
124     std::vector<common::Task> tasks;
125     auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
126     tasks.reserve(max_thread_num);
127     size_t start = 0;
128     size_t once_compute_size = (total_compute_size + max_thread_num - 1) / max_thread_num;
129     while (start < total_compute_size) {
130       size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size);
131       auto task = [&func, &params, start, end]() {
132         func(params, start, end);
133         return common::SUCCESS;
134       };
135       (void)tasks.emplace_back(task);
136       start += once_compute_size;
137     }
138     (void)common::ThreadPool::GetInstance().SyncRun(tasks);
139   }
140 
141  private:
142   template <typename T>
CalculateEachBucketSize(const std::shared_ptr<SparseGradient<T>> & sparse_grad,size_t max_index,std::vector<size_t> * each_bucket_size)143   static void CalculateEachBucketSize(const std::shared_ptr<SparseGradient<T>> &sparse_grad, size_t max_index,
144                                       std::vector<size_t> *each_bucket_size) {
145     MS_LOG(DEBUG) << "Start";
146     MS_EXCEPTION_IF_NULL(sparse_grad);
147     MS_EXCEPTION_IF_NULL(sparse_grad->indices_);
148     MS_EXCEPTION_IF_NULL(each_bucket_size);
149     size_t bucket_num = each_bucket_size->size();
150     if (bucket_num < 1) {
151       MS_LOG(EXCEPTION) << "Bucket num must > 0!";
152     }
153     for (size_t i = 0; i < sparse_grad->indices_size_; ++i) {
154       T index = sparse_grad->indices_[i];
155       if (index >= 0 && LongToSize(index) < max_index) {
156         auto bucket_id = index % bucket_num;
157         each_bucket_size->at(bucket_id)++;
158       }
159     }
160     MS_LOG(DEBUG) << "End";
161   }
162 
163   template <typename T>
SplitAndCalculateSegmentBucketSize(const MultiThreadReduceSparseGradientParam<T> & param,std::vector<std::shared_ptr<SparseGradient<T>>> * segments_ptr,std::vector<std::shared_ptr<std::vector<size_t>>> * segment_bucket_sizes_ptr)164   static void SplitAndCalculateSegmentBucketSize(
165     const MultiThreadReduceSparseGradientParam<T> &param, std::vector<std::shared_ptr<SparseGradient<T>>> *segments_ptr,
166     std::vector<std::shared_ptr<std::vector<size_t>>> *segment_bucket_sizes_ptr) {
167     MS_EXCEPTION_IF_NULL(param.input_grad_);
168     MS_EXCEPTION_IF_NULL(segment_bucket_sizes_ptr);
169     MS_EXCEPTION_IF_NULL(segments_ptr);
170     auto &segments = *segments_ptr;
171     auto &segment_bucket_sizes = *segment_bucket_sizes_ptr;
172     auto input_grad = param.input_grad_;
173     if (param.thread_num_ < 1) {
174       MS_EXCEPTION(ArgumentError) << "Input param thread num must > 0!";
175     }
176     size_t thread_indices_size = input_grad->indices_size_ / param.thread_num_;
177     size_t left_indices_size = input_grad->indices_size_ % param.thread_num_;
178     std::vector<common::Task> tasks;
179     tasks.reserve(param.thread_num_);
180     segments.reserve(param.thread_num_);
181 
182     size_t current_indices_offset = 0;
183     for (size_t i = 0; i < param.thread_num_; ++i) {
184       (void)segment_bucket_sizes.emplace_back(std::make_shared<std::vector<size_t>>(param.thread_num_, 0));
185       size_t indices_size = thread_indices_size;
186       if (i < left_indices_size) {
187         indices_size += 1;
188       }
189       (void)segments.emplace_back(std::make_shared<SparseGradient<T>>());
190       segments[i]->value_ = input_grad->value_ + current_indices_offset * param.value_stride_;
191       segments[i]->indices_ = input_grad->indices_ + current_indices_offset;
192       segments[i]->indices_size_ = indices_size;
193       auto task = [&segments, &param, &segment_bucket_sizes, i]() {
194         CalculateEachBucketSize<T>(segments[i], param.max_index_, segment_bucket_sizes[i].get());
195         return common::SUCCESS;
196       };
197       (void)tasks.emplace_back(task);
198       current_indices_offset += indices_size;
199     }
200     (void)common::ThreadPool::GetInstance().SyncRun(tasks);
201   }
202 
203   template <typename T>
CopySegmentIndicesToBucket(const MultiThreadReduceSparseGradientParam<T> & param,const std::shared_ptr<SparseGradient<T>> & segment,size_t bucket_offset,const std::vector<std::shared_ptr<BucketSparseGradient<T>>> & buckets)204   static void CopySegmentIndicesToBucket(const MultiThreadReduceSparseGradientParam<T> &param,
205                                          const std::shared_ptr<SparseGradient<T>> &segment, size_t bucket_offset,
206                                          const std::vector<std::shared_ptr<BucketSparseGradient<T>>> &buckets) {
207     MS_LOG(DEBUG) << "Start";
208     MS_EXCEPTION_IF_NULL(segment);
209     MS_EXCEPTION_IF_NULL(segment->indices_);
210     if (param.thread_num_ == 0) {
211       MS_EXCEPTION(ArgumentError) << "Input param thread num must > 0!";
212     }
213     std::vector<size_t> bucket_data_num(param.thread_num_, 0);
214     for (size_t i = 0; i < segment->indices_size_; ++i) {
215       T index = segment->indices_[i];
216       if (index >= 0 && LongToSize(index) < param.max_index_) {
217         auto bucket_id = index % param.thread_num_;
218         auto bucket_index = bucket_data_num[bucket_id];
219         buckets[bucket_id]->indices_[bucket_index] = index;
220         buckets[bucket_id]->global_indices_[bucket_index] = bucket_offset + i;
221         bucket_data_num[bucket_id]++;
222       }
223     }
224     MS_LOG(DEBUG) << "End";
225   }
226 
227   template <typename T>
GatherSegmentIndicesToOutputBucket(const MultiThreadReduceSparseGradientParam<T> & param,const std::vector<std::shared_ptr<SparseGradient<T>>> & segments,const std::vector<std::shared_ptr<std::vector<size_t>>> & segment_bucket_sizes,std::vector<std::shared_ptr<BucketSparseGradient<T>>> * buckets_ptr)228   static void GatherSegmentIndicesToOutputBucket(
229     const MultiThreadReduceSparseGradientParam<T> &param,
230     const std::vector<std::shared_ptr<SparseGradient<T>>> &segments,
231     const std::vector<std::shared_ptr<std::vector<size_t>>> &segment_bucket_sizes,
232     std::vector<std::shared_ptr<BucketSparseGradient<T>>> *buckets_ptr) {
233     MS_EXCEPTION_IF_NULL(param.output_grad_);
234     MS_EXCEPTION_IF_NULL(param.output_grad_->value_);
235     MS_EXCEPTION_IF_NULL(param.output_grad_->indices_);
236     MS_EXCEPTION_IF_NULL(buckets_ptr);
237     auto &buckets = *buckets_ptr;
238     size_t thread_num = param.thread_num_;
239     if (thread_num != segment_bucket_sizes.size()) {
240       MS_EXCEPTION(ArgumentError) << "Input param thread num not equal to segment size!";
241     }
242     std::vector<size_t> bucket_data_size(thread_num, 0);
243     for (size_t i = 0; i < thread_num; ++i) {
244       for (size_t j = 0; j < thread_num; ++j) {
245         bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
246       }
247     }
248     size_t current_indices_offset = 0;
249     for (size_t i = 0; i < thread_num; ++i) {
250       (void)buckets.emplace_back(std::make_shared<BucketSparseGradient<T>>());
251       buckets[i]->value_ = param.output_grad_->value_ + current_indices_offset * param.value_stride_;
252       buckets[i]->indices_ = param.output_grad_->indices_ + current_indices_offset;
253       buckets[i]->global_indices_ = param.workspace_grad_->indices_ + current_indices_offset;
254       buckets[i]->indices_size_ = bucket_data_size[i];
255       current_indices_offset += bucket_data_size[i];
256     }
257     std::vector<size_t> tmp_bucket_data_size(thread_num, 0);
258     std::vector<std::vector<std::shared_ptr<BucketSparseGradient<T>>>> each_thread_buckets;
259     for (size_t i = 0; i < thread_num; ++i) {
260       std::vector<std::shared_ptr<BucketSparseGradient<T>>> thread_buckets;
261       for (size_t j = 0; j < thread_num; ++j) {
262         (void)thread_buckets.emplace_back(std::make_shared<BucketSparseGradient<T>>());
263         thread_buckets[j]->indices_ = buckets[j]->indices_ + tmp_bucket_data_size[j];
264         thread_buckets[j]->global_indices_ = buckets[j]->global_indices_ + tmp_bucket_data_size[j];
265         thread_buckets[j]->value_ = buckets[j]->value_ + tmp_bucket_data_size[j] * param.value_stride_;
266         thread_buckets[j]->indices_size_ = segment_bucket_sizes[i]->at(j);
267         tmp_bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
268       }
269       (void)each_thread_buckets.emplace_back(thread_buckets);
270     }
271     std::vector<common::Task> tasks;
272     tasks.reserve(thread_num);
273     current_indices_offset = 0;
274     for (size_t i = 0; i < thread_num; ++i) {
275       auto task = [&param, &segments, &each_thread_buckets, i, current_indices_offset]() {
276         CopySegmentIndicesToBucket<T>(param, segments[i], current_indices_offset, each_thread_buckets[i]);
277         return common::SUCCESS;
278       };
279       (void)tasks.emplace_back(task);
280       current_indices_offset += segments[i]->indices_size_;
281     }
282     (void)common::ThreadPool::GetInstance().SyncRun(tasks);
283   }
284 
285   template <typename T>
SortAndReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam<T> & param,const std::shared_ptr<BucketSparseGradient<T>> & bucket,const std::shared_ptr<SparseGradient<T>> & reduced_bucket)286   static void SortAndReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam<T> &param,
287                                                 const std::shared_ptr<BucketSparseGradient<T>> &bucket,
288                                                 const std::shared_ptr<SparseGradient<T>> &reduced_bucket) {
289     MS_LOG(DEBUG) << "Start";
290     MS_EXCEPTION_IF_NULL(bucket);
291     MS_EXCEPTION_IF_NULL(bucket->value_);
292     MS_EXCEPTION_IF_NULL(bucket->indices_);
293     MS_EXCEPTION_IF_NULL(reduced_bucket);
294     MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
295     MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
296     std::vector<std::pair<T, T>> sorted_indices;
297     sorted_indices.reserve(bucket->indices_size_);
298     for (size_t i = 0; i < bucket->indices_size_; ++i) {
299       T index = bucket->indices_[i];
300       T global_index = bucket->global_indices_[i];
301       (void)sorted_indices.emplace_back(std::pair<T, T>(index, global_index));
302     }
303     std::sort(sorted_indices.begin(), sorted_indices.end());
304 
305     float *global_value = param.input_grad_->value_;
306     size_t unique_indices_size = 0;
307     size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
308     T last_index{0};
309     size_t value_offset{0};
310     for (size_t i = 0; i < sorted_indices.size(); ++i) {
311       T index = sorted_indices[i].first;
312       T global_index = sorted_indices[i].second;
313       T global_value_offset = global_index * param.value_stride_;
314       if (i == 0 || index != last_index) {
315         if (i != 0) {
316           unique_indices_size++;
317         }
318         reduced_bucket->indices_[unique_indices_size] = index;
319         value_offset = unique_indices_size * param.value_stride_;
320         auto ret_code = memcpy_s(reduced_bucket->value_ + value_offset, (max_length - value_offset) * sizeof(float),
321                                  global_value + global_value_offset, param.value_stride_ * sizeof(float));
322         if (ret_code != EOK) {
323           MS_LOG(EXCEPTION) << "Failed to copy data!";
324         }
325       } else {
326         for (size_t j = 0; j < param.value_stride_; ++j) {
327           reduced_bucket->value_[value_offset + j] += global_value[global_value_offset + j];
328         }
329       }
330       last_index = index;
331     }
332     reduced_bucket->indices_size_ = unique_indices_size;
333     MS_LOG(DEBUG) << "End";
334   }
335 
336   template <typename T>
ReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam<T> & param,const std::shared_ptr<BucketSparseGradient<T>> & bucket,const std::shared_ptr<SparseGradient<T>> & reduced_bucket)337   static void ReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam<T> &param,
338                                          const std::shared_ptr<BucketSparseGradient<T>> &bucket,
339                                          const std::shared_ptr<SparseGradient<T>> &reduced_bucket) {
340     MS_LOG(DEBUG) << "Start";
341     MS_EXCEPTION_IF_NULL(bucket);
342     MS_EXCEPTION_IF_NULL(bucket->value_);
343     MS_EXCEPTION_IF_NULL(bucket->indices_);
344     MS_EXCEPTION_IF_NULL(reduced_bucket);
345     MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
346     MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
347 
348     float *global_value = param.input_grad_->value_;
349     std::unordered_map<T, size_t> index_map;
350     size_t unique_indices_size = 0;
351     size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
352     for (size_t i = 0; i < bucket->indices_size_; ++i) {
353       T index = bucket->indices_[i];
354       T global_index = bucket->global_indices_[i];
355       auto iter = index_map.find(index);
356       if (iter == index_map.end()) {
357         reduced_bucket->indices_[unique_indices_size] = index;
358         size_t start_index = unique_indices_size * param.value_stride_;
359         index_map[index] = start_index;
360         auto ret_code =
361           memcpy_s(reduced_bucket->value_ + start_index, (max_length - start_index) * sizeof(float),
362                    global_value + global_index * param.value_stride_, param.value_stride_ * sizeof(float));
363         if (ret_code != EOK) {
364           MS_LOG(EXCEPTION) << "Failed to copy data!";
365         }
366         unique_indices_size++;
367       } else {
368         size_t start_index = iter->second;
369         size_t end_index = start_index + param.value_stride_;
370         for (size_t j = start_index, k = global_index * param.value_stride_; j < end_index; ++j, ++k) {
371           reduced_bucket->value_[j] += global_value[k];
372         }
373       }
374     }
375     reduced_bucket->indices_size_ = unique_indices_size;
376     MS_LOG(DEBUG) << "End";
377   }
378 
379   template <typename T>
ReduceBucketSparseGradientToWorkspace(const MultiThreadReduceSparseGradientParam<T> & param,const std::vector<std::shared_ptr<BucketSparseGradient<T>>> & buckets,std::vector<std::shared_ptr<SparseGradient<T>>> * reduced_buckets_ptr)380   static void ReduceBucketSparseGradientToWorkspace(
381     const MultiThreadReduceSparseGradientParam<T> &param,
382     const std::vector<std::shared_ptr<BucketSparseGradient<T>>> &buckets,
383     std::vector<std::shared_ptr<SparseGradient<T>>> *reduced_buckets_ptr) {
384     MS_EXCEPTION_IF_NULL(param.workspace_grad_);
385     MS_EXCEPTION_IF_NULL(param.workspace_grad_->value_);
386     MS_EXCEPTION_IF_NULL(param.workspace_grad_->indices_);
387     MS_EXCEPTION_IF_NULL(reduced_buckets_ptr);
388     auto &reduced_buckets = *reduced_buckets_ptr;
389     size_t thread_num = buckets.size();
390     std::vector<common::Task> tasks;
391     tasks.reserve(thread_num);
392 
393     size_t current_indices_offset = 0;
394     for (size_t i = 0; i < thread_num; ++i) {
395       (void)reduced_buckets.emplace_back(std::make_shared<SparseGradient<T>>());
396       reduced_buckets[i]->value_ = param.workspace_grad_->value_ + current_indices_offset * param.value_stride_;
397       reduced_buckets[i]->indices_ = param.workspace_grad_->indices_ + current_indices_offset;
398       reduced_buckets[i]->indices_size_ = buckets[i]->indices_size_;
399       auto task = [&param, &buckets, &reduced_buckets, i]() {
400         if (param.use_sort_reduce_) {
401           SortAndReduceBucketSparseGradient<T>(param, buckets[i], reduced_buckets[i]);
402         } else {
403           ReduceBucketSparseGradient<T>(param, buckets[i], reduced_buckets[i]);
404         }
405         return common::SUCCESS;
406       };
407       (void)tasks.emplace_back(task);
408       current_indices_offset += buckets[i]->indices_size_;
409     }
410     (void)common::ThreadPool::GetInstance().SyncRun(tasks);
411   }
412 
413   template <typename T>
MergeReduceSparseGradient(const MultiThreadReduceSparseGradientParam<T> & param,const std::vector<std::shared_ptr<SparseGradient<T>>> & reduced_buckets)414   static void MergeReduceSparseGradient(const MultiThreadReduceSparseGradientParam<T> &param,
415                                         const std::vector<std::shared_ptr<SparseGradient<T>>> &reduced_buckets) {
416     MS_EXCEPTION_IF_NULL(param.output_grad_);
417     auto output_grad = param.output_grad_;
418     MS_EXCEPTION_IF_NULL(output_grad->value_);
419     MS_EXCEPTION_IF_NULL(output_grad->indices_);
420     size_t stride_data_size = param.value_stride_ * sizeof(float);
421     size_t unique_indices_size = 0;
422     for (size_t i = 0; i < reduced_buckets.size(); ++i) {
423       auto &bucket = reduced_buckets[i];
424       MS_EXCEPTION_IF_NULL(bucket);
425       if (bucket->indices_size_ == 0) {
426         continue;
427       }
428       auto ret_code = memcpy_s(output_grad->value_ + unique_indices_size * param.value_stride_,
429                                (output_grad->indices_size_ - unique_indices_size) * stride_data_size, bucket->value_,
430                                bucket->indices_size_ * stride_data_size);
431       if (ret_code != EOK) {
432         MS_LOG(EXCEPTION) << "Failed to copy data!";
433       }
434       ret_code = memcpy_s(output_grad->indices_ + unique_indices_size,
435                           (output_grad->indices_size_ - unique_indices_size) * sizeof(T), bucket->indices_,
436                           bucket->indices_size_ * sizeof(T));
437       if (ret_code != EOK) {
438         MS_LOG(EXCEPTION) << "Failed to copy data!";
439       }
440       unique_indices_size += bucket->indices_size_;
441     }
442     output_grad->indices_size_ = unique_indices_size;
443   }
444 
445  protected:
446   TypeId indices_data_type_{kNumberTypeInt32};
447   size_t indices_size_{0};
448   size_t var_first_dim_size_{0};
449   size_t var_outer_dim_size_{1};
450 };
451 }  // namespace kernel
452 }  // namespace mindspore
453 
454 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_OPTIMIZER_CPU_KERNEL_H_
455