/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/ops_func_impl/matmul_ext.h"
#include <algorithm>
#include <set>
#include <map>
#include <vector>
#include <memory>
#include <string>
#include <utility>
#include "ops/op_name.h"
#include "utils/shape_utils.h"
#include "abstract/dshape.h"
#include "ir/primitive.h"
#include "ops/op_utils.h"
#include "utils/ms_context.h"
#include "ops/ops_func_impl/simple_infer.h"

namespace mindspore {
namespace ops {
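// Validates that two operand shapes are legal for a batched matmul and returns
// the broadcast batch dimensions (all dimensions except the trailing matrix dims).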
ShapeVector CheckMatMulShapes(const ShapeVector &shape1, const ShapeVector &shape2) {
  ShapeVector shape_out;
  if (shape1.size() == 0 || shape2.size() == 0) {
    MS_EXCEPTION(ValueError) << "For 'MatMulExt' op, inputs must all be tensors with rank >= 1";
  }
  // The reduction-dimension check applies only when shape2 has a matrix (rank >= 2) form.
  if (shape2.size() >= kDim2 && shape1.back() != shape2[shape2.size() - kDim2]) {
    MS_EXCEPTION(RuntimeError) << "For 'MatMulExt' op, shape1[-1] must be equal to shape2[-2], but got "
                               << shape1.back() << " and " << shape2[shape2.size() - kDim2] << ".";
  }
  // Left-pad the shorter shape with 1s so both shapes have the same rank.
  int len_diff = std::abs(static_cast<int>(shape1.size()) - static_cast<int>(shape2.size()));
  ShapeVector shape1_padded;
  ShapeVector shape2_padded;
  if (shape1.size() < shape2.size()) {
    shape1_padded = ShapeVector(len_diff, 1);
    shape1_padded.insert(shape1_padded.end(), shape1.begin(), shape1.end());
    shape2_padded = shape2;
  } else {
    shape2_padded = ShapeVector(len_diff, 1);
    shape2_padded.insert(shape2_padded.end(), shape2.begin(), shape2.end());
    shape1_padded = shape1;
  }
  // Broadcast every batch dimension (all but the trailing two matrix dimensions).
  int max_len = std::max(static_cast<int>(shape1_padded.size()) - kInputIndex2,
                         static_cast<int>(shape2_padded.size()) - kInputIndex2);
  for (int i = 0; i < max_len; ++i) {
    int64_t dim1 = i < static_cast<int>(shape1_padded.size() - kInputIndex2) ? shape1_padded[i] : 1;
    int64_t dim2 = i < static_cast<int>(shape2_padded.size() - kInputIndex2) ? shape2_padded[i] : 1;
    if (dim1 != 1 && dim2 != 1 && dim1 != dim2) {
      MS_EXCEPTION(RuntimeError) << "For 'MatMulExt' op, shape1 and shape2 must be broadcastable, but got "
                                 << shape1_padded << " and " << shape2_padded;
    }
    shape_out.push_back(std::max(dim1, dim2));
  }
  return shape_out;
}

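// Appends the trailing two matrix dimensions of input_shape to base_shape; a
// rank-1 input is treated as a 1 x N row vector.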
ShapeVector GetMatMulExtBroadcastShape(const ShapeVector &base_shape, const ShapeVector &input_shape) {
  const size_t kNum2 = 2;
  ShapeVector broadcast_shape = base_shape;
  if (input_shape.size() == 1) {
    broadcast_shape.push_back(1);
    broadcast_shape.push_back(input_shape[0]);
  } else {
    broadcast_shape.push_back(input_shape[input_shape.size() - kNum2]);
    broadcast_shape.push_back(input_shape[input_shape.size() - 1]);
  }
  return broadcast_shape;
}

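// Appends the non-batch matrix dimensions to the broadcast backbone: shape1's
// row count (when it has rank >= 2) and shape2's column count (or its
// second-to-last dimension when transpose_b is set). Rank-1 operands
// contribute no output dimension, matching matmul's vector semantics.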
ShapeVector InferShapeRem(const ShapeVector &shape_backbone, const ShapeVector &shape1, const ShapeVector &shape2,
                          bool transpose_b) {
  int ndim1 = SizeToInt(shape1.size());
  int ndim2 = SizeToInt(shape2.size());
  ShapeVector shape_rem(shape_backbone);
  if (ndim1 >= SizeToInt(kDim2)) {
    shape_rem.push_back(shape1[ndim1 - SizeToInt(kDim2)]);
  }
  if (transpose_b) {
    if (ndim2 >= SizeToInt(kDim2)) {
      shape_rem.push_back(shape2[ndim2 - SizeToInt(kDim2)]);
    }
  } else {
    if (ndim2 >= 1) {
      shape_rem.push_back(shape2.back());
    }
  }
  return shape_rem;
}

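// Builds the output shape when either input has dynamic dimensions: broadcasts
// the batch dimensions, emitting kShapeDimAny where a dimension is unknown, and
// appends the matrix dimensions taken directly from the two inputs.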
void MatMulMakeShape(ShapeVector *output, const ShapeVector xshp, const ShapeVector yshp) {
  size_t offset = kDim2;
  if (xshp.empty() || yshp.empty()) {
    return;
  }
  auto x_rank = xshp.size();
  auto y_rank = yshp.size();
  // Two vectors produce a scalar, i.e. an empty output shape.
  if (x_rank == 1 && y_rank == 1) {
    return;
  }

  auto max_rank = x_rank > y_rank ? x_rank : y_rank;

  // With one vector operand the squeezed output rank is known but the
  // dimensions are not, so mark them all as dynamic.
  if (x_rank == 1 || y_rank == 1) {
    for (size_t i = 0; i < max_rank - 1; i++) {
      output->push_back(abstract::Shape::kShapeDimAny);
    }
    return;
  }

  // Broadcast the batch dimensions, right-aligning the shorter shape.
  ShapeVector long_input = xshp.size() > yshp.size() ? xshp : yshp;
  ShapeVector short_input = xshp.size() > yshp.size() ? yshp : xshp;
  size_t size_diff = long_input.size() - short_input.size();
  for (size_t i = 0; i < long_input.size() - offset; i++) {
    if (long_input[i] < 0) {
      output->push_back(abstract::Shape::kShapeDimAny);
    } else if (i >= size_diff) {
      output->push_back(long_input[i] > short_input[i - size_diff] ? long_input[i] : short_input[i - size_diff]);
    } else {
      output->push_back(long_input[i]);
    }
  }
  // Append the matrix dimensions: rows from x, columns from y.
  size_t x_offset = xshp.size() - offset;
  size_t y_offset = yshp.size() - offset;
  output->push_back(xshp[x_offset]);
  output->push_back(yshp[y_offset + 1]);
}
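
// Frontend shape inference: dispatches to the static-shape path when both
// inputs are fully known, otherwise builds a shape with dynamic dimensions.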
BaseShapePtr MatMulExtFuncImpl::InferShape(const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) const {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr auto kMatMulExtInputNum = 2;
  (void)CheckAndConvertUtils::CheckInteger("input num", SizeToLong(input_args.size()), kEqual, kMatMulExtInputNum,
                                           primitive->name());
  auto x_shp = input_args[kInputIndex0]->GetShape()->GetShapeVector();
  auto y_shp = input_args[kInputIndex1]->GetShape()->GetShapeVector();
  if (IsDynamicRank(x_shp) || IsDynamicRank(y_shp)) {
    return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
  }

  bool dynamic_shape = IsDynamic(x_shp) || IsDynamic(y_shp);
  if (!dynamic_shape) {
    bool transpose_b = y_shp.size() == 1;
    ShapeVector shape_backbone = CheckMatMulShapes(x_shp, y_shp);
    ShapeVector ret_shape = InferShapeRem(shape_backbone, x_shp, y_shp, transpose_b);
    return std::make_shared<abstract::Shape>(std::move(ret_shape));
  }

  ShapeVector ret_shape;
  MatMulMakeShape(&ret_shape, x_shp, y_shp);
  return std::make_shared<abstract::Shape>(std::move(ret_shape));
}

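// Frontend type inference: both inputs must be tensors of the same valid
// numeric type; the output shares the type of the first input.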
TypePtr MatMulExtFuncImpl::InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const {
  MS_EXCEPTION_IF_NULL(prim);
  const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
                                         kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128, kBFloat16};
  std::map<std::string, TypePtr> types;
  (void)types.emplace("x", input_args[0]->GetType());
  (void)types.emplace("w", input_args[1]->GetType());
  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
  return input_args[0]->GetType();
}

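// Simple-infer type path (value-based): requires both tensor inputs to share a
// dtype and propagates the dtype of the first input.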
TypePtrList MatMulExtFuncImpl::InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
  const auto &x_tensor = input_values[kInputIndex0]->cast<tensor::BaseTensorPtr>();
  const auto &y_tensor = input_values[kInputIndex1]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(x_tensor);
  MS_EXCEPTION_IF_NULL(y_tensor);
  const auto x_type = x_tensor->Dtype();
  const auto y_type = y_tensor->Dtype();
  if (x_type->type_id() != y_type->type_id()) {
    MS_EXCEPTION(TypeError) << "For '" << primitive->name()
                            << "', the type of 'x2' should be the same as 'x1', but got 'x1' with type Tensor["
                            << x_type->ToString() << "] and 'x2' with type Tensor[" << y_type->ToString() << "].";
  }
  return {x_type};
}

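// Simple-infer shape path (value-based): mirrors the static-shape branch of
// the frontend InferShape using the concrete tensor shapes.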
ShapeArray MatMulExtFuncImpl::InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
  const auto &x_tensor = input_values[kInputIndex0]->cast<tensor::BaseTensorPtr>();
  const auto &y_tensor = input_values[kInputIndex1]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(x_tensor);
  MS_EXCEPTION_IF_NULL(y_tensor);

  const auto &x_shp = x_tensor->shape();
  const auto &y_shp = y_tensor->shape();

  bool transpose_b = y_shp.size() == 1;
  ShapeVector shape_backbone = CheckMatMulShapes(x_shp, y_shp);
  ShapeVector ret_shape = InferShapeRem(shape_backbone, x_shp, y_shp, transpose_b);
  return {ret_shape};
}
REGISTER_SIMPLE_INFER(kNameMatMulExt, MatMulExtFuncImpl)
}  // namespace ops
}  // namespace mindspore