1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "ops/ops_func_impl/grouped_matmul.h"
17 #include <string>
18 #include <map>
19 #include <set>
20 #include <vector>
21 #include <utility>
22 #include "ops/op_utils.h"
23 #include "utils/check_convert_utils.h"
24 #include "abstract/ops/primitive_infer_map.h"
25 #include "mindapi/src/helper.h"
26 #include "mindspore/core/ops/other_ops.h"
27
28 namespace mindspore {
29 namespace ops {
30 /*
31 separated means the size of tensorlist not equal 1.
32 integrated means the size of tensorlist is 1.
33 split_item inputs weight outputs
34 0: separated separated separated
35 1: integrated b, k, n separated
36 2: separated separated integrated
37 3: integrated b, k, n integrated
38 */
namespace {
// The first 7 inputs are tensor lists (or None); see the dyn_input_sizes loop
// in InferShape, which records each list's length for the backend split pass.
constexpr size_t listInputNum = 7;
// Input argument indices.
constexpr size_t kGmmInputX = 0;
constexpr size_t kGmmInputWeight = 1;
// optional None
constexpr size_t kGmmInputGroupList = 7;
// attr
constexpr size_t kGmmInputSplitItem = 8;
constexpr size_t kGmmInputGroupType = 9;
// TensorShape ranks.
constexpr size_t gmmTensor2D = 2;
constexpr size_t gmmTensor3D = 3;
constexpr size_t gmmTensor6D = 6;
// split_item mode
constexpr size_t multiTensor = 0;   // split_item 0: separated tensor lists
constexpr size_t singleTensor = 3;  // split_item 3: one integrated tensor
}  // namespace

// Caches the split_item seen by the latest InferShape call; read by
// GetValueDependArgIndices() to decide whether group_list is value-depend.
// NOTE(review): mutable file-scope global shared across ops — presumably the
// infer stage is single-threaded per primitive; confirm before concurrent use.
int64_t gGroupedMatmulSplitItem = 0;
58
CheckSplitItem(const std::string & op_name,const int64_t split_item) const59 void GroupedMatmulFuncImpl::CheckSplitItem(const std::string &op_name, const int64_t split_item) const {
60 if (split_item != multiTensor && split_item != singleTensor) {
61 MS_EXCEPTION(ValueError) << "For '" << op_name << "', the split_item only support 0 or 3, but got " << split_item;
62 }
63 }
64
CheckGroupType(const std::string & op_name,const int64_t group_type) const65 void GroupedMatmulFuncImpl::CheckGroupType(const std::string &op_name, const int64_t group_type) const {
66 if (group_type != -1 && group_type != 0) {
67 MS_EXCEPTION(ValueError) << "For '" << op_name << "', the group_type only support -1 or 0, but got " << group_type;
68 }
69 }
70
CheckSplitItemAndGroupType(const std::string & op_name,const int64_t group_type,const int64_t split_item) const71 void GroupedMatmulFuncImpl::CheckSplitItemAndGroupType(const std::string &op_name, const int64_t group_type,
72 const int64_t split_item) const {
73 if (group_type == -1 && split_item != 0) {
74 MS_EXCEPTION(ValueError) << "For '" << op_name
75 << "', group_type is -1 (not grouped), the split_item only support 0(multi tensor)"
76 << "but split_item got " << split_item;
77 }
78 if (group_type == 0 && split_item != singleTensor) {
79 MS_EXCEPTION(ValueError) << "For '" << op_name
80 << "', group_type is 0 (group m-axis), the split_item only support 3(one tensor)"
81 << "but split_item got " << split_item;
82 }
83 }
84
CheckInputType(const std::vector<AbstractBasePtr> & input_args,const std::string & op_name,const std::string & input_name,const size_t input_idx,const std::set<TypePtr> & check_list) const85 void GroupedMatmulFuncImpl::CheckInputType(const std::vector<AbstractBasePtr> &input_args, const std::string &op_name,
86 const std::string &input_name, const size_t input_idx,
87 const std::set<TypePtr> &check_list) const {
88 // Optional input args must be TensorList. If optional, it is a TensorList which has only a empty Tensor.
89 if (input_args[input_idx]->GetType()->isa<TypeNone>()) {
90 MS_EXCEPTION(ShapeError) << "For " << op_name << ", the input {" << input_name
91 << "}, should be TensorList. but got "
92 << input_args[input_idx]->GetType()->isa<TypeNone>();
93 }
94 // Check Type
95 abstract::AbstractTuple optional_list = *(input_args[input_idx]->cast<abstract::AbstractTuplePtr>());
96 for (size_t i = 0; i < optional_list.size(); i++) {
97 (void)CheckAndConvertUtils::CheckTensorTypeValid(input_name, optional_list[i]->GetType(), check_list, op_name);
98 }
99 }
100
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const101 BaseShapePtr GroupedMatmulFuncImpl::InferShape(const PrimitivePtr &primitive,
102 const std::vector<AbstractBasePtr> &input_args) const {
103 MS_EXCEPTION_IF_NULL(primitive);
104 const std::string op_name = primitive->name();
105
106 // check split_item
107 MS_EXCEPTION_IF_NULL(input_args[kGmmInputSplitItem]);
108 auto split_type = input_args[kGmmInputSplitItem]->GetType();
109 if (split_type->isa<TypeNone>()) {
110 MS_EXCEPTION(ShapeError) << "For '" << op_name << "', the split_item must be a int. Current split_item is None";
111 }
112 ValuePtr split_ptr = input_args[kGmmInputSplitItem]->GetValue();
113 auto split_item = GetValue<int64_t>(split_ptr);
114 CheckSplitItem(op_name, split_item);
115 gGroupedMatmulSplitItem = split_item;
116
117 // check group_type
118 MS_EXCEPTION_IF_NULL(input_args[kGmmInputGroupType]);
119 auto group_type_type = input_args[kGmmInputGroupType]->GetType();
120 if (group_type_type->isa<TypeNone>()) {
121 MS_EXCEPTION(ShapeError) << "For '" << op_name << "', the group_type must be a int. Current group_type is None";
122 }
123 ValuePtr group_type_ptr = input_args[kGmmInputGroupType]->GetValue();
124 auto group_type = GetValue<int64_t>(group_type_ptr);
125 CheckGroupType(op_name, group_type);
126
127 CheckSplitItemAndGroupType(op_name, group_type, split_item);
128
129 // x_list
130 auto x_ptr = input_args[kGmmInputX]->cast<abstract::AbstractTuplePtr>();
131 MS_EXCEPTION_IF_NULL(x_ptr);
132 abstract::AbstractTuple x_list = *x_ptr;
133
134 // weight_list
135 auto weight_ptr = input_args[kGmmInputWeight]->cast<abstract::AbstractTuplePtr>();
136 MS_EXCEPTION_IF_NULL(weight_ptr);
137 abstract::AbstractTuple weight_list = *weight_ptr;
138
139 // for tensorlist(input arg) in backend split. (AscendConvertTupleInputToDynamicInput pass)
140 std::vector<int64_t> dyn_input_sizes;
141 for (size_t i = 0; i < listInputNum; ++i) {
142 if (input_args[i]->GetType()->isa<TypeNone>()) {
143 dyn_input_sizes.push_back(0);
144 } else {
145 abstract::AbstractTuple list = *(input_args[i]->cast<abstract::AbstractTuplePtr>());
146 dyn_input_sizes.push_back(list.size());
147 }
148 }
149 primitive->set_attr("group_info", MakeValue(dyn_input_sizes));
150
151 // calculate shape. split_item = 3, x[0](m, n) * w[0](e, n, k) = out(m, k)
152 if (split_item == singleTensor) {
153 if (x_list.size() == 1 && weight_list.size() == 1) {
154 std::vector<int64_t> x_shape = x_list[0]->GetShape()->GetShapeVector();
155 std::vector<int64_t> w_shape = weight_list[0]->GetShape()->GetShapeVector();
156 if (x_shape.size() != gmmTensor2D) {
157 MS_EXCEPTION(ShapeError) << "For '" << op_name
158 << "', when split_item is 3, the x[0] must be 2D Tensor. But x[0] shape :" << x_shape;
159 }
160 if (w_shape.size() != gmmTensor3D) {
161 MS_EXCEPTION(ShapeError) << "For '" << op_name
162 << "', when split_item is 3, the w[0] must be 3D Tensor. But w[0] shape :" << w_shape;
163 }
164 if (x_shape[1] != w_shape[1]) {
165 MS_EXCEPTION(ShapeError) << "For '" << op_name << "', x[0] shape should be (m, n), w[0] shape show be(e, n, k)."
166 << "But x[0] shape: " << x_shape << ", w[0] shape: " << w_shape;
167 }
168 std::vector<BaseShapePtr> outshape_merge = {};
169 std::vector<int64_t> res_shape = {x_shape[0], w_shape.back()};
170 outshape_merge.emplace_back(std::make_shared<abstract::TensorShape>(res_shape));
171 return std::make_shared<abstract::TupleShape>(outshape_merge);
172 } else {
173 MS_EXCEPTION(ShapeError) << "For '" << op_name << "', when split_item is 3. the size of x or weight only be 1."
174 << "But x size: " << x_list.size() << ", weight size: " << weight_list.size();
175 }
176 }
177
178 if (x_list.size() != weight_list.size()) {
179 MS_EXCEPTION(ShapeError) << "For '" << op_name << "', when split_item is 0, x.size() == w.size()."
180 << "But x size: " << x_list.size() << ", weight size: " << weight_list.size();
181 }
182
183 std::vector<BaseShapePtr> outshape_list = {};
184 for (size_t i = 0; i < x_list.size(); i++) {
185 std::vector<int64_t> x_shape = x_list[i]->GetShape()->GetShapeVector();
186 std::vector<int64_t> w_shape = weight_list[i]->GetShape()->GetShapeVector();
187 if (x_shape.size() < gmmTensor2D || x_shape.size() > gmmTensor6D) {
188 MS_EXCEPTION(ShapeError) << "For '" << op_name << "', when split_item is 0, the tensor in x must be 2-6D. But"
189 << i << "th tensor in x, shape is : " << x_shape;
190 }
191 if (w_shape.size() != gmmTensor2D) {
192 MS_EXCEPTION(ShapeError) << "For '" << op_name << "', when split_item is 0, the tensor in x must be 2-6D. But"
193 << i << "th tensor in x, shape is : " << x_shape;
194 }
195 if (x_shape.back() != w_shape[0]) {
196 MS_EXCEPTION(ShapeError) << "For '" << op_name
197 << "' The back in x[i] shape should be equal to the first in w[i] shape. But x[" << i
198 << "] shape : " << x_shape << ", w[" << i << "] shape : " << w_shape;
199 }
200 std::vector<int64_t> res_shape = x_shape;
201 res_shape.back() = w_shape[1]; // x[a,b,c,m,n] * w[n,k] = out[a,b,c,m,k]
202 outshape_list.emplace_back(std::make_shared<abstract::TensorShape>(res_shape));
203 }
204
205 return std::make_shared<abstract::TupleShape>(outshape_list);
206 }
207
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const208 TypePtr GroupedMatmulFuncImpl::InferType(const PrimitivePtr &primitive,
209 const std::vector<AbstractBasePtr> &input_args) const {
210 MS_EXCEPTION_IF_NULL(primitive);
211 const auto &op_name = primitive->name();
212 const std::set<TypePtr> xw_input_type = {kFloat16, kBFloat16, kFloat32, kInt8};
213
214 MS_EXCEPTION_IF_NULL(input_args[kGmmInputX]);
215 CheckInputType(input_args, op_name, "x", kGmmInputX, xw_input_type);
216
217 MS_EXCEPTION_IF_NULL(input_args[kGmmInputWeight]);
218 CheckInputType(input_args, op_name, "weight", kGmmInputWeight, xw_input_type);
219
220 // get split_item and check groups
221 MS_EXCEPTION_IF_NULL(input_args[kGmmInputSplitItem]);
222 auto split_type = input_args[kGmmInputSplitItem]->GetType();
223 if (split_type->isa<TypeNone>()) {
224 MS_EXCEPTION(TypeError) << "For '" << op_name << "', the group_type must be a int. Current split_item is None";
225 }
226 ValuePtr split_ptr = input_args[kGmmInputSplitItem]->GetValue();
227 auto split_item = GetValue<int64_t>(split_ptr);
228 CheckSplitItem(op_name, split_item);
229 if (split_item == singleTensor) {
230 (void)CheckAndConvertUtils::CheckTensorTypeValid("grouplist", input_args[kGmmInputGroupList]->GetType(), {kInt64},
231 op_name);
232 }
233
234 // check group_list
235 MS_EXCEPTION_IF_NULL(input_args[kGmmInputGroupList]);
236 auto group_list_type = input_args[kGmmInputGroupList]->GetType();
237 if (split_item == singleTensor && group_list_type->isa<TypeNone>()) {
238 MS_EXCEPTION(ShapeError) << "For '" << op_name
239 << "', the group_type must be a int when split_item equal 3. Current group_type is None";
240 }
241 auto group_list_shape_map =
242 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kGmmInputGroupList]->GetShape());
243 auto group_list_shape = group_list_shape_map[kShape];
244 if (split_item == singleTensor && group_list_shape.size() != 1) {
245 MS_EXCEPTION(ShapeError) << "For '" << op_name << "', the grouplist must be 1D Tensor when split_item equal 3."
246 << "Current groups_list shape is " << group_list_shape;
247 }
248
249 // check group_type
250 MS_EXCEPTION_IF_NULL(input_args[kGmmInputGroupType]);
251 auto group_type_type = input_args[kGmmInputGroupType]->GetType();
252 if (split_item == singleTensor && group_type_type->isa<TypeNone>()) {
253 MS_EXCEPTION(ShapeError) << "For '" << op_name
254 << "', the group_type must be a int when split_item equal 3. Current group_type is None";
255 }
256 ValuePtr group_type_ptr = input_args[kGmmInputGroupType]->GetValue();
257 auto group_type = GetValue<int64_t>(group_type_ptr);
258 if (split_item == singleTensor && group_type != 0) {
259 MS_EXCEPTION(ShapeError) << "For '" << op_name
260 << "', the group_type must be 0(split axis m) when split_item equal 3."
261 << "Current group_type is " << group_type;
262 }
263
264 // support split_item 0 or 3
265 std::vector<TypePtr> type_tuple;
266 abstract::AbstractTuple x_list = *(input_args[kGmmInputX]->cast<abstract::AbstractTuplePtr>());
267 for (size_t i = 0; i < x_list.size(); i++) {
268 type_tuple.emplace_back(x_list[i]->GetType()->Clone());
269 }
270
271 return std::make_shared<Tuple>(std::move(type_tuple));
272 }
273
274 // In compiler get grouplist(not none) for resize
GetValueDependArgIndices() const275 std::set<int64_t> GroupedMatmulFuncImpl::GetValueDependArgIndices() const {
276 if (gGroupedMatmulSplitItem == singleTensor) {
277 return {kGmmInputGroupList};
278 } else {
279 return {};
280 }
281 }
282 } // namespace ops
283 } // namespace mindspore
284