/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <map>
#include <set>
#include <utility>
#include "plugin/device/cpu/kernel/sparse_matrix_mul_cpu_kernel.h"
#include "include/common/thread_pool.h"
#include "mindspore/core/ops/sparse_matrix_mul.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputNum = 6;
constexpr size_t kOutputNum = 5;
constexpr size_t kAShapeIdx = 0;
constexpr size_t kABatchPointersIdx = 1;
constexpr size_t kAIndptrIdx = 2;
constexpr size_t kAIndicesIdx = 3;
constexpr size_t kAValuesIdx = 4;
constexpr size_t kBDenseIdx = 5;
constexpr size_t kOutShapeIdx = 0;
constexpr size_t kOutBatchPointersIdx = 1;
constexpr size_t kOutIndptrIdx = 2;
constexpr size_t kOutIndicesIdx = 3;
constexpr size_t kOutValuesIdx = 4;
constexpr size_t bShapeNum1 = 1;
constexpr size_t bShapeNum2 = 2;
using KernelRunFunc = SparseMatrixMulCpuKernelMod::KernelRunFunc;
}  // namespace
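// Init only resolves the typed launch function from the registration table; shape-dependent state is set in Resize.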
bool SparseMatrixMulCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
                                       const std::vector<KernelTensor *> &outputs) {
  if (!MatchKernelFunc(kernel_name_, inputs, outputs)) {
    return false;
  }
  return true;
}

int SparseMatrixMulCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
                                        const std::vector<KernelTensor *> &outputs) {
  int ret = KRET_OK;
  if ((ret = KernelMod::Resize(inputs, outputs)) != 0) {
    MS_LOG(ERROR) << kernel_name_ << " reinit failed.";
    return ret;
  }
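  // Cache the dimensions of the dense operand (it may be 1-D or 2-D) for flat indexing in LaunchKernel.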
  std::vector<int64_t> b_shape = inputs[kBDenseIdx]->GetShapeVector();
  size_t b_shape_num = b_shape.size();
  if (b_shape_num == bShapeNum1) {
    col_ = LongToSize(b_shape[0]);
  } else if (b_shape_num == bShapeNum2) {
    row_ = LongToSize(b_shape[0]);
    col_ = LongToSize(b_shape[1]);
  }
  return ret;
}

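// Illustrative example of the element-wise CSR x dense product computed below:
//   A (2x3 CSR): indptr = [0, 2, 3], indices = [0, 2, 1], values = [1, 2, 3]
//   B (2x3 dense, row-major): [10, 20, 30, 40, 50, 60]
//   C keeps A's structure, with values = [1 * 10, 2 * 30, 3 * 50] = [10, 60, 150].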
template <typename T, typename S>
const bool SparseMatrixMulCpuKernelMod::LaunchKernel(const std::vector<KernelTensor *> &inputs,
                                                     const std::vector<KernelTensor *> &,
                                                     const std::vector<KernelTensor *> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
  const auto a_shape = reinterpret_cast<T *>(inputs[kAShapeIdx]->device_ptr());
  const auto a_batch_pointers = reinterpret_cast<T *>(inputs[kABatchPointersIdx]->device_ptr());
  const auto a_indptr = reinterpret_cast<T *>(inputs[kAIndptrIdx]->device_ptr());
  const auto a_indices = reinterpret_cast<T *>(inputs[kAIndicesIdx]->device_ptr());
  const auto a_values = reinterpret_cast<S *>(inputs[kAValuesIdx]->device_ptr());
  const auto b_dense = reinterpret_cast<S *>(inputs[kBDenseIdx]->device_ptr());

  auto c_shape = reinterpret_cast<T *>(outputs[kOutShapeIdx]->device_ptr());
  auto c_batch_pointers = reinterpret_cast<T *>(outputs[kOutBatchPointersIdx]->device_ptr());
  auto c_indptr = reinterpret_cast<T *>(outputs[kOutIndptrIdx]->device_ptr());
  auto c_indices = reinterpret_cast<T *>(outputs[kOutIndicesIdx]->device_ptr());
  auto c_values = reinterpret_cast<S *>(outputs[kOutValuesIdx]->device_ptr());
  const int64_t a_indices_num = SizeToLong(inputs[kAIndicesIdx]->size() / (sizeof(T)));
  const int64_t b_dense_num = SizeToLong(inputs[kBDenseIdx]->size() / (sizeof(S)));

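  // The output reuses A's sparse structure (shape, batch pointers, indptr, indices); only the values change,
  // so the structural buffers are copied through unmodified.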
  errno_t ret = memcpy_s(c_batch_pointers, inputs[kABatchPointersIdx]->size(), a_batch_pointers,
                         inputs[kABatchPointersIdx]->size());
  ret += memcpy_s(c_shape, inputs[kAShapeIdx]->size(), a_shape, inputs[kAShapeIdx]->size());
  ret += memcpy_s(c_indptr, inputs[kAIndptrIdx]->size(), a_indptr, inputs[kAIndptrIdx]->size());
  ret += memcpy_s(c_indices, inputs[kAIndicesIdx]->size(), a_indices, inputs[kAIndicesIdx]->size());
  if (ret != EOK) {
    MS_LOG(ERROR) << kernel_name_ << " memcpy_s failed.";
  }

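  // For every stored element i, recover its row from the CSR indptr (row `index` satisfies
  // a_indptr[index] <= i < a_indptr[index + 1]); the column comes from a_indices[i]. The value is then
  // multiplied with the matching dense entry; positions outside the dense buffer yield 0.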
  int64_t index = 0;
  for (int i = 0; i < a_indices_num; i++) {
    int64_t col = a_indices[i];
    int64_t row = 0;
    while (true) {
      if (i >= a_indptr[index] && i < a_indptr[index + 1]) {
        row = index;
        break;
      } else {
        index++;
      }
    }
    int64_t absIndex = row * SizeToLong(col_) + col;
    if (absIndex < b_dense_num) {
      c_values[i] = a_values[i] * b_dense[absIndex];
    } else {
      c_values[i] = 0;
    }
  }
  return true;
}

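// Builds the KernelAttr for one (index type, value type) pair: four index-typed inputs (shape, batch pointers,
// indptr, indices) followed by the value-typed values and dense operand, plus the corresponding five outputs.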
#define CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(ms_index_type, ms_value_type, index_type, value_type) \
  {                                                                                                  \
    KernelAttr()                                                                                     \
      .AddInputAttr(ms_index_type)                                                                   \
      .AddInputAttr(ms_index_type)                                                                   \
      .AddInputAttr(ms_index_type)                                                                   \
      .AddInputAttr(ms_index_type)                                                                   \
      .AddInputAttr(ms_value_type)                                                                   \
      .AddInputAttr(ms_value_type)                                                                   \
      .AddOutputAttr(ms_index_type)                                                                  \
      .AddOutputAttr(ms_index_type)                                                                  \
      .AddOutputAttr(ms_index_type)                                                                  \
      .AddOutputAttr(ms_index_type)                                                                  \
      .AddOutputAttr(ms_value_type),                                                                 \
      &SparseMatrixMulCpuKernelMod::LaunchKernel<index_type, value_type>                            \
  }

const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SparseMatrixMulCpuKernelMod::GetFuncList() const {
  static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
    // float values
    CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat32, int, float),
    CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat32, int64_t, float),
    // double values
    CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeFloat64, int, double),
    CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat64, int64_t, double),
    // int32 values
    CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int, int),
    CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t, int),
    // int64 values
    CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int, int64_t),
    CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t),
    // int16 values
    CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt32, kNumberTypeInt16, int, int16_t),
    CPU_SPARSE_MATRIX_MUL_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt16, int64_t, int16_t),
  };
  return func_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseMatrixMul, SparseMatrixMulCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore