• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include <algorithm>
18 #include <map>
19 #include <set>
20 #include <utility>
21 #include "plugin/device/cpu/kernel/sparse_matrix_mul_cpu_kernel.h"
22 #include "include/common/thread_pool.h"
23 #include "mindspore/core/ops/sparse_matrix_mul.h"
24 
25 namespace mindspore {
26 namespace kernel {
27 namespace {
28 constexpr size_t kInputNum = 6;
29 constexpr size_t kOutputNum = 5;
30 constexpr size_t kAShapeIdx = 0;
31 constexpr size_t kABatchPointersIdx = 1;
32 constexpr size_t kAIndptrIdx = 2;
33 constexpr size_t kAIndicesIdx = 3;
34 constexpr size_t kAValuesIdx = 4;
35 constexpr size_t kBDenseIdx = 5;
36 constexpr size_t kOutShapeIdx = 0;
37 constexpr size_t kOutBatchPointersIdx = 1;
38 constexpr size_t kOutIndptrIdx = 2;
39 constexpr size_t kOutIndicesIdx = 3;
40 constexpr size_t kOutValuesIdx = 4;
41 constexpr size_t bShapeNum1 = 1;
42 constexpr size_t bShapeNum2 = 2;
43 using KernelRunFunc = SparseMatrixMulCpuKernelMod::KernelRunFunc;
44 }  // namespace
Init(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)45 bool SparseMatrixMulCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
46                                        const std::vector<KernelTensor *> &outputs) {
47   if (!MatchKernelFunc(kernel_name_, inputs, outputs)) {
48     return false;
49   }
50   return true;
51 }
52 
Resize(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)53 int SparseMatrixMulCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
54                                         const std::vector<KernelTensor *> &outputs) {
55   int ret = KRET_OK;
56   if ((ret = KernelMod::Resize(inputs, outputs)) != 0) {
57     MS_LOG(ERROR) << kernel_name_ << " reinit failed.";
58     return ret;
59   }
60   std::vector<int64_t> b_shape = inputs[kBDenseIdx]->GetShapeVector();
61   size_t b_shape_num = b_shape.size();
62   if (b_shape_num == bShapeNum1) {
63     col_ = LongToSize(b_shape[0]);
64   } else if (b_shape_num == bShapeNum2) {
65     row_ = LongToSize(b_shape[0]);
66     col_ = LongToSize(b_shape[1]);
67   }
68   return ret;
69 }
70 
71 template <typename T, typename S>
LaunchKernel(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > &,const std::vector<KernelTensor * > & outputs)72 const bool SparseMatrixMulCpuKernelMod::LaunchKernel(const std::vector<KernelTensor *> &inputs,
73                                                      const std::vector<KernelTensor *> &,
74                                                      const std::vector<KernelTensor *> &outputs) {
75   CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
76   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
77   const auto a_shape = reinterpret_cast<T *>(inputs[kAShapeIdx]->device_ptr());
78   const auto a_batch_pointers = reinterpret_cast<T *>(inputs[kABatchPointersIdx]->device_ptr());
79   const auto a_indptr = reinterpret_cast<T *>(inputs[kAIndptrIdx]->device_ptr());
80   const auto a_indices = reinterpret_cast<T *>(inputs[kAIndicesIdx]->device_ptr());
81   const auto a_values = reinterpret_cast<S *>(inputs[kAValuesIdx]->device_ptr());
82   const auto b_dense = reinterpret_cast<S *>(inputs[kBDenseIdx]->device_ptr());
83 
84   auto c_shape = reinterpret_cast<T *>(outputs[kOutShapeIdx]->device_ptr());
85   auto c_batch_pointers = reinterpret_cast<T *>(outputs[kOutBatchPointersIdx]->device_ptr());
86   auto c_indptr = reinterpret_cast<T *>(outputs[kOutIndptrIdx]->device_ptr());
87   auto c_indices = reinterpret_cast<T *>(outputs[kOutIndicesIdx]->device_ptr());
88   auto c_values = reinterpret_cast<S *>(outputs[kOutValuesIdx]->device_ptr());
89   const int64_t a_indices_num = SizeToLong(inputs[kAIndicesIdx]->size() / (sizeof(T)));
90   const int64_t b_dense_num = SizeToLong(inputs[kBDenseIdx]->size() / (sizeof(S)));
91 
92   errno_t ret = memcpy_s(c_batch_pointers, inputs[kABatchPointersIdx]->size(), a_batch_pointers,
93                          inputs[kABatchPointersIdx]->size());
94   ret += memcpy_s(c_shape, inputs[kAShapeIdx]->size(), a_shape, inputs[kAShapeIdx]->size());
95   ret += memcpy_s(c_indptr, inputs[kAIndptrIdx]->size(), a_indptr, inputs[kAIndptrIdx]->size());
96   ret += memcpy_s(c_indices, inputs[kAIndicesIdx]->size(), a_indices, inputs[kAIndicesIdx]->size());
97   if (ret != EOK) {
98     MS_LOG(ERROR) << kernel_name_ << "memcpy_s failed.";
99   }
100 
101   int64_t index = 0;
102   for (int i = 0; i < a_indices_num; i++) {
103     int64_t col = a_indices[i];
104     int64_t row = 0;
105     while (true) {
106       if (i >= a_indptr[index] && i < a_indptr[index + 1]) {
107         row = index;
108         break;
109       } else {
110         index++;
111       }
112     }
113     int64_t absIndex = row * SizeToLong(col_) + col;
114     if (absIndex < b_dense_num) {
115       c_values[i] = a_values[i] * b_dense[absIndex];
116     } else {
117       c_values[i] = 0;
118     }
119   }
120   return true;
121 }
122 
123 #define CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(ms_index_type, ms_value_type, index_type, value_type) \
124   {                                                                                                 \
125     KernelAttr()                                                                                    \
126       .AddInputAttr(ms_index_type)                                                                  \
127       .AddInputAttr(ms_index_type)                                                                  \
128       .AddInputAttr(ms_index_type)                                                                  \
129       .AddInputAttr(ms_index_type)                                                                  \
130       .AddInputAttr(ms_value_type)                                                                  \
131       .AddInputAttr(ms_value_type)                                                                  \
132       .AddOutputAttr(ms_index_type)                                                                 \
133       .AddOutputAttr(ms_index_type)                                                                 \
134       .AddOutputAttr(ms_index_type)                                                                 \
135       .AddOutputAttr(ms_index_type)                                                                 \
136       .AddOutputAttr(ms_value_type),                                                                \
137       &SparseMatrixMulCpuKernelMod::LaunchKernel<index_type, value_type>                            \
138   }
139 
GetFuncList() const140 const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SparseMatrixMulCpuKernelMod::GetFuncList() const {
141   static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
142     // float values
143     CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat32, int, float),
144     CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat32, int64_t, float),
145     // double values
146     CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat64, int, double),
147     CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat64, int64_t, double),
148     // int values
149     CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int, int),
150     CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t, int),
151     // int64 values
152     CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int, int64_t),
153     CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t),
154     // int16 values
155     CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt16, int, int16_t),
156     CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt16, int64_t, int16_t),
157   };
158   return func_list;
159 }
160 
161 MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseMatrixMul, SparseMatrixMulCpuKernelMod);
162 }  // namespace kernel
163 }  // namespace mindspore
164