• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "ops/ops_func_impl/grouped_matmul.h"
17 #include <string>
18 #include <map>
19 #include <set>
20 #include <vector>
21 #include <utility>
22 #include "ops/op_utils.h"
23 #include "utils/check_convert_utils.h"
24 #include "abstract/ops/primitive_infer_map.h"
25 #include "mindapi/src/helper.h"
26 #include "mindspore/core/ops/other_ops.h"
27 
28 namespace mindspore {
29 namespace ops {
/*
"separated" means the size of the tensor list is not equal to 1.
"integrated" means the size of the tensor list is 1.
split_item        inputs     weight      outputs
      0:      separated     separated    separated
      1:     integrated     b, k, n      separated
      2:      separated     separated    integrated
      3:     integrated     b, k, n      integrated
*/
namespace {
// Number of leading inputs that are tensor lists; their sizes are recorded in
// the "group_info" attr. NOTE(review): only x (0) and weight (1) are named in
// this file - confirm the identity of inputs 2-6 against the op definition.
constexpr size_t listInputNum = 7;
constexpr size_t kGmmInputX = 0;       // index of the x tensor list
constexpr size_t kGmmInputWeight = 1;  // index of the weight tensor list
// optional None
constexpr size_t kGmmInputGroupList = 7;  // index of the optional group_list input
// attr
constexpr size_t kGmmInputSplitItem = 8;  // index of the split_item scalar input
constexpr size_t kGmmInputGroupType = 9;  // index of the group_type scalar input
// TensorShape
constexpr size_t gmmTensor2D = 2;  // rank of a 2D tensor
constexpr size_t gmmTensor3D = 3;  // rank of a 3D tensor
constexpr size_t gmmTensor6D = 6;  // maximum rank supported for tensors in x
// split_item mode
constexpr size_t multiTensor = 0;   // split_item == 0: separated in/out lists
constexpr size_t singleTensor = 3;  // split_item == 3: integrated single tensor
}  // namespace

// Last split_item observed by InferShape; read by GetValueDependArgIndices.
// NOTE(review): mutable file-scope state shared by all instances - not
// thread-safe if infers run concurrently; confirm single-threaded use.
int64_t gGroupedMatmulSplitItem = 0;
58 
CheckSplitItem(const std::string & op_name,const int64_t split_item) const59 void GroupedMatmulFuncImpl::CheckSplitItem(const std::string &op_name, const int64_t split_item) const {
60   if (split_item != multiTensor && split_item != singleTensor) {
61     MS_EXCEPTION(ValueError) << "For '" << op_name << "', the split_item only support 0 or 3, but got " << split_item;
62   }
63 }
64 
CheckGroupType(const std::string & op_name,const int64_t group_type) const65 void GroupedMatmulFuncImpl::CheckGroupType(const std::string &op_name, const int64_t group_type) const {
66   if (group_type != -1 && group_type != 0) {
67     MS_EXCEPTION(ValueError) << "For '" << op_name << "', the group_type only support -1 or 0, but got " << group_type;
68   }
69 }
70 
CheckSplitItemAndGroupType(const std::string & op_name,const int64_t group_type,const int64_t split_item) const71 void GroupedMatmulFuncImpl::CheckSplitItemAndGroupType(const std::string &op_name, const int64_t group_type,
72                                                        const int64_t split_item) const {
73   if (group_type == -1 && split_item != 0) {
74     MS_EXCEPTION(ValueError) << "For '" << op_name
75                              << "', group_type is -1 (not grouped), the split_item only support 0(multi tensor)"
76                              << "but split_item got " << split_item;
77   }
78   if (group_type == 0 && split_item != singleTensor) {
79     MS_EXCEPTION(ValueError) << "For '" << op_name
80                              << "', group_type is 0 (group m-axis), the split_item only support 3(one tensor)"
81                              << "but split_item got " << split_item;
82   }
83 }
84 
CheckInputType(const std::vector<AbstractBasePtr> & input_args,const std::string & op_name,const std::string & input_name,const size_t input_idx,const std::set<TypePtr> & check_list) const85 void GroupedMatmulFuncImpl::CheckInputType(const std::vector<AbstractBasePtr> &input_args, const std::string &op_name,
86                                            const std::string &input_name, const size_t input_idx,
87                                            const std::set<TypePtr> &check_list) const {
88   // Optional input args must be TensorList. If optional, it is a TensorList which has only a empty Tensor.
89   if (input_args[input_idx]->GetType()->isa<TypeNone>()) {
90     MS_EXCEPTION(ShapeError) << "For " << op_name << ", the input {" << input_name
91                              << "}, should be TensorList. but got "
92                              << input_args[input_idx]->GetType()->isa<TypeNone>();
93   }
94   // Check Type
95   abstract::AbstractTuple optional_list = *(input_args[input_idx]->cast<abstract::AbstractTuplePtr>());
96   for (size_t i = 0; i < optional_list.size(); i++) {
97     (void)CheckAndConvertUtils::CheckTensorTypeValid(input_name, optional_list[i]->GetType(), check_list, op_name);
98   }
99 }
100 
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const101 BaseShapePtr GroupedMatmulFuncImpl::InferShape(const PrimitivePtr &primitive,
102                                                const std::vector<AbstractBasePtr> &input_args) const {
103   MS_EXCEPTION_IF_NULL(primitive);
104   const std::string op_name = primitive->name();
105 
106   // check split_item
107   MS_EXCEPTION_IF_NULL(input_args[kGmmInputSplitItem]);
108   auto split_type = input_args[kGmmInputSplitItem]->GetType();
109   if (split_type->isa<TypeNone>()) {
110     MS_EXCEPTION(ShapeError) << "For '" << op_name << "', the split_item must be a int. Current split_item is None";
111   }
112   ValuePtr split_ptr = input_args[kGmmInputSplitItem]->GetValue();
113   auto split_item = GetValue<int64_t>(split_ptr);
114   CheckSplitItem(op_name, split_item);
115   gGroupedMatmulSplitItem = split_item;
116 
117   // check group_type
118   MS_EXCEPTION_IF_NULL(input_args[kGmmInputGroupType]);
119   auto group_type_type = input_args[kGmmInputGroupType]->GetType();
120   if (group_type_type->isa<TypeNone>()) {
121     MS_EXCEPTION(ShapeError) << "For '" << op_name << "', the group_type must be a int. Current group_type is None";
122   }
123   ValuePtr group_type_ptr = input_args[kGmmInputGroupType]->GetValue();
124   auto group_type = GetValue<int64_t>(group_type_ptr);
125   CheckGroupType(op_name, group_type);
126 
127   CheckSplitItemAndGroupType(op_name, group_type, split_item);
128 
129   // x_list
130   auto x_ptr = input_args[kGmmInputX]->cast<abstract::AbstractTuplePtr>();
131   MS_EXCEPTION_IF_NULL(x_ptr);
132   abstract::AbstractTuple x_list = *x_ptr;
133 
134   // weight_list
135   auto weight_ptr = input_args[kGmmInputWeight]->cast<abstract::AbstractTuplePtr>();
136   MS_EXCEPTION_IF_NULL(weight_ptr);
137   abstract::AbstractTuple weight_list = *weight_ptr;
138 
139   // for tensorlist(input arg) in backend split. (AscendConvertTupleInputToDynamicInput pass)
140   std::vector<int64_t> dyn_input_sizes;
141   for (size_t i = 0; i < listInputNum; ++i) {
142     if (input_args[i]->GetType()->isa<TypeNone>()) {
143       dyn_input_sizes.push_back(0);
144     } else {
145       abstract::AbstractTuple list = *(input_args[i]->cast<abstract::AbstractTuplePtr>());
146       dyn_input_sizes.push_back(list.size());
147     }
148   }
149   primitive->set_attr("group_info", MakeValue(dyn_input_sizes));
150 
151   // calculate shape. split_item = 3, x[0](m, n) * w[0](e, n, k) = out(m, k)
152   if (split_item == singleTensor) {
153     if (x_list.size() == 1 && weight_list.size() == 1) {
154       std::vector<int64_t> x_shape = x_list[0]->GetShape()->GetShapeVector();
155       std::vector<int64_t> w_shape = weight_list[0]->GetShape()->GetShapeVector();
156       if (x_shape.size() != gmmTensor2D) {
157         MS_EXCEPTION(ShapeError) << "For '" << op_name
158                                  << "', when split_item is 3, the x[0] must be 2D Tensor. But x[0] shape :" << x_shape;
159       }
160       if (w_shape.size() != gmmTensor3D) {
161         MS_EXCEPTION(ShapeError) << "For '" << op_name
162                                  << "', when split_item is 3, the w[0] must be 3D Tensor. But w[0] shape :" << w_shape;
163       }
164       if (x_shape[1] != w_shape[1]) {
165         MS_EXCEPTION(ShapeError) << "For '" << op_name << "', x[0] shape should be (m, n), w[0] shape show be(e, n, k)."
166                                  << "But x[0] shape: " << x_shape << ", w[0] shape: " << w_shape;
167       }
168       std::vector<BaseShapePtr> outshape_merge = {};
169       std::vector<int64_t> res_shape = {x_shape[0], w_shape.back()};
170       outshape_merge.emplace_back(std::make_shared<abstract::TensorShape>(res_shape));
171       return std::make_shared<abstract::TupleShape>(outshape_merge);
172     } else {
173       MS_EXCEPTION(ShapeError) << "For '" << op_name << "', when split_item is 3. the size of x or weight only be 1."
174                                << "But x size: " << x_list.size() << ", weight size: " << weight_list.size();
175     }
176   }
177 
178   if (x_list.size() != weight_list.size()) {
179     MS_EXCEPTION(ShapeError) << "For '" << op_name << "', when split_item is 0, x.size() == w.size()."
180                              << "But x size: " << x_list.size() << ", weight size: " << weight_list.size();
181   }
182 
183   std::vector<BaseShapePtr> outshape_list = {};
184   for (size_t i = 0; i < x_list.size(); i++) {
185     std::vector<int64_t> x_shape = x_list[i]->GetShape()->GetShapeVector();
186     std::vector<int64_t> w_shape = weight_list[i]->GetShape()->GetShapeVector();
187     if (x_shape.size() < gmmTensor2D || x_shape.size() > gmmTensor6D) {
188       MS_EXCEPTION(ShapeError) << "For '" << op_name << "', when split_item is 0, the tensor in x must be 2-6D. But"
189                                << i << "th tensor in x, shape is : " << x_shape;
190     }
191     if (w_shape.size() != gmmTensor2D) {
192       MS_EXCEPTION(ShapeError) << "For '" << op_name << "', when split_item is 0, the tensor in x must be 2-6D. But"
193                                << i << "th tensor in x, shape is : " << x_shape;
194     }
195     if (x_shape.back() != w_shape[0]) {
196       MS_EXCEPTION(ShapeError) << "For '" << op_name
197                                << "' The back in x[i] shape should be equal to the first in w[i] shape. But x[" << i
198                                << "] shape : " << x_shape << ", w[" << i << "] shape : " << w_shape;
199     }
200     std::vector<int64_t> res_shape = x_shape;
201     res_shape.back() = w_shape[1];  // x[a,b,c,m,n] * w[n,k] = out[a,b,c,m,k]
202     outshape_list.emplace_back(std::make_shared<abstract::TensorShape>(res_shape));
203   }
204 
205   return std::make_shared<abstract::TupleShape>(outshape_list);
206 }
207 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const208 TypePtr GroupedMatmulFuncImpl::InferType(const PrimitivePtr &primitive,
209                                          const std::vector<AbstractBasePtr> &input_args) const {
210   MS_EXCEPTION_IF_NULL(primitive);
211   const auto &op_name = primitive->name();
212   const std::set<TypePtr> xw_input_type = {kFloat16, kBFloat16, kFloat32, kInt8};
213 
214   MS_EXCEPTION_IF_NULL(input_args[kGmmInputX]);
215   CheckInputType(input_args, op_name, "x", kGmmInputX, xw_input_type);
216 
217   MS_EXCEPTION_IF_NULL(input_args[kGmmInputWeight]);
218   CheckInputType(input_args, op_name, "weight", kGmmInputWeight, xw_input_type);
219 
220   // get split_item and check groups
221   MS_EXCEPTION_IF_NULL(input_args[kGmmInputSplitItem]);
222   auto split_type = input_args[kGmmInputSplitItem]->GetType();
223   if (split_type->isa<TypeNone>()) {
224     MS_EXCEPTION(TypeError) << "For '" << op_name << "', the group_type must be a int. Current split_item is None";
225   }
226   ValuePtr split_ptr = input_args[kGmmInputSplitItem]->GetValue();
227   auto split_item = GetValue<int64_t>(split_ptr);
228   CheckSplitItem(op_name, split_item);
229   if (split_item == singleTensor) {
230     (void)CheckAndConvertUtils::CheckTensorTypeValid("grouplist", input_args[kGmmInputGroupList]->GetType(), {kInt64},
231                                                      op_name);
232   }
233 
234   // check group_list
235   MS_EXCEPTION_IF_NULL(input_args[kGmmInputGroupList]);
236   auto group_list_type = input_args[kGmmInputGroupList]->GetType();
237   if (split_item == singleTensor && group_list_type->isa<TypeNone>()) {
238     MS_EXCEPTION(ShapeError) << "For '" << op_name
239                              << "', the group_type must be a int when split_item equal 3. Current group_type is None";
240   }
241   auto group_list_shape_map =
242     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kGmmInputGroupList]->GetShape());
243   auto group_list_shape = group_list_shape_map[kShape];
244   if (split_item == singleTensor && group_list_shape.size() != 1) {
245     MS_EXCEPTION(ShapeError) << "For '" << op_name << "', the grouplist must be 1D Tensor when split_item equal 3."
246                              << "Current groups_list shape is " << group_list_shape;
247   }
248 
249   // check group_type
250   MS_EXCEPTION_IF_NULL(input_args[kGmmInputGroupType]);
251   auto group_type_type = input_args[kGmmInputGroupType]->GetType();
252   if (split_item == singleTensor && group_type_type->isa<TypeNone>()) {
253     MS_EXCEPTION(ShapeError) << "For '" << op_name
254                              << "', the group_type must be a int when split_item equal 3. Current group_type is None";
255   }
256   ValuePtr group_type_ptr = input_args[kGmmInputGroupType]->GetValue();
257   auto group_type = GetValue<int64_t>(group_type_ptr);
258   if (split_item == singleTensor && group_type != 0) {
259     MS_EXCEPTION(ShapeError) << "For '" << op_name
260                              << "', the group_type must be 0(split axis m) when split_item equal 3."
261                              << "Current group_type is " << group_type;
262   }
263 
264   // support split_item 0 or 3
265   std::vector<TypePtr> type_tuple;
266   abstract::AbstractTuple x_list = *(input_args[kGmmInputX]->cast<abstract::AbstractTuplePtr>());
267   for (size_t i = 0; i < x_list.size(); i++) {
268     type_tuple.emplace_back(x_list[i]->GetType()->Clone());
269   }
270 
271   return std::make_shared<Tuple>(std::move(type_tuple));
272 }
273 
274 // In compiler get grouplist(not none) for resize
GetValueDependArgIndices() const275 std::set<int64_t> GroupedMatmulFuncImpl::GetValueDependArgIndices() const {
276   if (gGroupedMatmulSplitItem == singleTensor) {
277     return {kGmmInputGroupList};
278   } else {
279     return {};
280   }
281 }
282 }  // namespace ops
283 }  // namespace mindspore
284