/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/ops_func_impl/matmul_ext.h"
#include <algorithm>
#include <set>
#include <map>
#include <vector>
#include <memory>
#include <string>
#include <utility>
#include "ops/op_name.h"
#include "utils/shape_utils.h"
#include "abstract/dshape.h"
#include "ir/primitive.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/ms_context.h"
#include "ops/ops_func_impl/simple_infer.h"

namespace mindspore {
namespace ops {
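// Validates a (batched) matmul shape pair and returns the broadcast batch
// shape, i.e. every dimension except the two trailing matrix dimensions.
// For example, shape1 = {2, 1, 4, 3} and shape2 = {5, 3, 6} broadcast to
// the batch backbone {2, 5}.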
ShapeVector CheckMatMulShapes(const ShapeVector &shape1, const ShapeVector &shape2) {
  ShapeVector shape_out;
  if (shape1.size() == 0 || shape2.size() == 0) {
    MS_EXCEPTION(ValueError) << "For 'MatMulExt' op, inputs must all be tensors with rank >= 1";
  }
  if (shape2.size() >= kDim2 && shape1.back() != shape2[shape2.size() - kDim2]) {
    MS_EXCEPTION(RuntimeError) << "For 'MatMulExt' op, shape1[-1] must be equal to shape2[-2], but got "
                               << shape1.back() << " and " << shape2[shape2.size() - kDim2] << ".";
  }
  int len_diff = std::abs(static_cast<int>(shape1.size()) - static_cast<int>(shape2.size()));
  ShapeVector shape1_padded;
  ShapeVector shape2_padded;
  // Left-pad the shorter shape with 1s so both shapes have the same rank.
  if (shape1.size() < shape2.size()) {
    shape1_padded = ShapeVector(len_diff, 1);
    shape1_padded.insert(shape1_padded.end(), shape1.begin(), shape1.end());
    shape2_padded = shape2;
  } else {
    shape2_padded = ShapeVector(len_diff, 1);
    shape2_padded.insert(shape2_padded.end(), shape2.begin(), shape2.end());
    shape1_padded = shape1;
  }
  // Broadcast every batch dimension (all but the trailing two) pairwise.
  int max_len = std::max(static_cast<int>(shape1_padded.size()) - kInputIndex2,
                         static_cast<int>(shape2_padded.size()) - kInputIndex2);
  for (int i = 0; i < max_len; ++i) {
    int64_t dim1 = i < static_cast<int>(shape1_padded.size() - kInputIndex2) ? shape1_padded[i] : 1;
    int64_t dim2 = i < static_cast<int>(shape2_padded.size() - kInputIndex2) ? shape2_padded[i] : 1;
    if (dim1 != 1 && dim2 != 1 && dim1 != dim2) {
      MS_EXCEPTION(RuntimeError) << "For 'MatMulExt' op, shape1 and shape2 must be broadcastable, but got "
                                 << shape1_padded << " and " << shape2_padded;
    }
    shape_out.push_back(std::max(dim1, dim2));
  }
  return shape_out;
}
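// Appends the matrix dimensions of input_shape to the broadcast batch shape
// base_shape; a rank-1 input is promoted to a 1 x N matrix.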
ShapeVector GetMatMulExtBroadcastShape(const ShapeVector &base_shape, const ShapeVector &input_shape) {
  const size_t kNum2 = 2;
  ShapeVector broadcast_shape = base_shape;
  if (input_shape.size() == 1) {
    broadcast_shape.push_back(1);
    broadcast_shape.push_back(input_shape[0]);
  } else {
    broadcast_shape.push_back(input_shape[input_shape.size() - kNum2]);
    broadcast_shape.push_back(input_shape[input_shape.size() - 1]);
  }
  return broadcast_shape;
}
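// Appends the output matrix dimensions (rows of shape1, columns of shape2)
// to the broadcast batch shape. Callers pass transpose_b = (shape2 is
// rank 1), so a vector operand contributes no trailing dimension. Continuing
// the example above, backbone {2, 5} with shape1 = {2, 1, 4, 3} and
// shape2 = {5, 3, 6} yields {2, 5, 4, 6}.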
ShapeVector InferShapeRem(const ShapeVector &shape_backbone, const ShapeVector &shape1, const ShapeVector &shape2,
                          bool transpose_b) {
  int ndim1 = SizeToInt(shape1.size());
  int ndim2 = SizeToInt(shape2.size());
  ShapeVector shape_rem(shape_backbone);
  if (ndim1 >= SizeToInt(kDim2)) {
    shape_rem.push_back(shape1[ndim1 - SizeToInt(kDim2)]);
  }
  if (transpose_b) {
    if (ndim2 >= SizeToInt(kDim2)) {
      shape_rem.push_back(shape2[ndim2 - SizeToInt(kDim2)]);
    }
  } else {
    if (ndim2 >= 1) {
      shape_rem.push_back(shape2.back());
    }
  }
  return shape_rem;
}
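// Builds a best-effort output shape when some dimensions are dynamic:
// unknown dimensions become kShapeDimAny, each pair of known batch
// dimensions contributes the larger value, and the trailing dimensions are
// the rows of xshp and the columns of yshp. For example, xshp = {-1, 4, 3}
// and yshp = {2, -1, 3, 6} produce {2, -1, 4, 6}.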
void MatMulMakeShape(ShapeVector *output, const ShapeVector xshp, const ShapeVector yshp) {
  size_t offset = kDim2;
  if (xshp.empty() || yshp.empty()) {
    return;
  }
  auto x_rank = xshp.size();
  auto y_rank = yshp.size();
  if (x_rank == 1 && y_rank == 1) {
    return;
  }

  auto max_rank = x_rank > y_rank ? x_rank : y_rank;

  if (x_rank == 1 || y_rank == 1) {
    for (size_t i = 0; i < max_rank - 1; i++) {
      output->push_back(abstract::Shape::kShapeDimAny);
    }
    return;
  }

  ShapeVector long_input = xshp.size() > yshp.size() ? xshp : yshp;
  ShapeVector short_input = xshp.size() > yshp.size() ? yshp : xshp;
  size_t size_diff = long_input.size() - short_input.size();
  for (size_t i = 0; i < long_input.size() - offset; i++) {
    if (long_input[i] < 0) {
      output->push_back(abstract::Shape::kShapeDimAny);
    } else if (i >= size_diff) {
      output->push_back(long_input[i] > short_input[i - size_diff] ? long_input[i] : short_input[i - size_diff]);
    } else {
      output->push_back(long_input[i]);
    }
  }
  size_t x_offset = xshp.size() - offset;
  size_t y_offset = yshp.size() - offset;
  output->push_back(xshp[x_offset]);
  output->push_back(yshp[y_offset + 1]);
}
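// Abstract-stage shape inference: dynamic-rank inputs yield an unknown-rank
// shape, fully static inputs go through CheckMatMulShapes/InferShapeRem, and
// partially dynamic inputs fall back to MatMulMakeShape.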
BaseShapePtr MatMulExtFuncImpl::InferShape(const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) const {
  MS_EXCEPTION_IF_NULL(primitive);
  auto constexpr kMatMulExtInputNum = 2;
  (void)CheckAndConvertUtils::CheckInteger("input num", SizeToLong(input_args.size()), kEqual, kMatMulExtInputNum,
                                           primitive->name());
  auto x_shp = input_args[kInputIndex0]->GetShape()->GetShapeVector();
  auto y_shp = input_args[kInputIndex1]->GetShape()->GetShapeVector();
  if (IsDynamicRank(x_shp) || IsDynamicRank(y_shp)) {
    return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
  }

  bool dynamic_shape = IsDynamic(x_shp) || IsDynamic(y_shp);
  if (!dynamic_shape) {
    bool transpose_b = y_shp.size() == 1;
    ShapeVector shape_backbone = CheckMatMulShapes(x_shp, y_shp);
    ShapeVector ret_shape = InferShapeRem(shape_backbone, x_shp, y_shp, transpose_b);
    return std::make_shared<abstract::Shape>(std::move(ret_shape));
  }

  ShapeVector ret_shape;
  MatMulMakeShape(&ret_shape, x_shp, y_shp);
  return std::make_shared<abstract::Shape>(std::move(ret_shape));
}
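// Abstract-stage type inference: checks that both inputs share one of the
// supported numeric dtypes and propagates the dtype of the first input.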
TypePtr MatMulExtFuncImpl::InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const {
  MS_EXCEPTION_IF_NULL(prim);
  const std::set<TypePtr> valid_types = {kInt8,   kInt16,   kInt32,   kInt64,   kUInt8,     kUInt16,     kUInt32,
                                         kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBFloat16};
  std::map<std::string, TypePtr> types;
  (void)types.emplace("x", input_args[0]->GetType());
  (void)types.emplace("w", input_args[1]->GetType());
  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
  TypePtr x_type = input_args[0]->GetType();
  return x_type;
}
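// Simple-infer (value-based) type path: requires identical dtypes on both
// tensors and returns the dtype of the first.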
TypePtrList MatMulExtFuncImpl::InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
  const auto &x_tensor = input_values[kInputIndex0]->cast<tensor::BaseTensorPtr>();
  const auto &y_tensor = input_values[kInputIndex1]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(x_tensor);
  MS_EXCEPTION_IF_NULL(y_tensor);
  TypePtr ret_type = x_tensor->Dtype();
  const auto x_type = x_tensor->Dtype();
  const auto y_type = y_tensor->Dtype();
  auto op_name = primitive->name();
  if (x_type->type_id() != y_type->type_id()) {
    MS_EXCEPTION(TypeError) << "For '" << op_name
                            << "', the type of 'x2' should be the same as 'x1', but got 'x1' with type Tensor["
                            << x_type->ToString() << "] and 'x2' with type Tensor[" << y_type->ToString() << "].";
  }
  return {ret_type};
}
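// Simple-infer (value-based) shape path: shapes are concrete here, so the
// static helpers above apply directly.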
ShapeArray MatMulExtFuncImpl::InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
  const auto &x_tensor = input_values[kInputIndex0]->cast<tensor::BaseTensorPtr>();
  const auto &y_tensor = input_values[kInputIndex1]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(x_tensor);
  MS_EXCEPTION_IF_NULL(y_tensor);

  const auto &x_shp = x_tensor->shape();
  const auto &y_shp = y_tensor->shape();

  bool transpose_b = y_shp.size() == 1;
  ShapeVector shape_backbone = CheckMatMulShapes(x_shp, y_shp);
  ShapeVector ret_shape = InferShapeRem(shape_backbone, x_shp, y_shp, transpose_b);
  return {ret_shape};
}
REGISTER_SIMPLE_INFER(kNameMatMulExt, MatMulExtFuncImpl)
}  // namespace ops
}  // namespace mindspore