1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "ops/ops_func_impl/matmul.h"
17 #include <set>
18 #include <map>
19 #include <string>
20 #include "utils/check_convert_utils.h"
21 #include "utils/ms_context.h"
22 #include "ops/op_name.h"
23 #include "ops/op_utils.h"
24 #include "ops/ops_func_impl/simple_infer.h"
25
26 namespace mindspore {
27 namespace ops {
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const28 BaseShapePtr MatMulFuncImpl::InferShape(const PrimitivePtr &primitive,
29 const std::vector<AbstractBasePtr> &input_args) const {
30 constexpr auto kMatMulInputNum = 4;
31 const std::string op_name = primitive->name();
32 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual,
33 kMatMulInputNum, op_name);
34 auto x = CheckAndConvertUtils::CheckArgsType(op_name, input_args, 0, kObjectTypeTensorType);
35
36 auto y = CheckAndConvertUtils::CheckArgsType(op_name, input_args, 1, kObjectTypeTensorType);
37 const auto &x_shp = x->GetShape()->GetShapeVector();
38 const auto &y_shp = y->GetShape()->GetShapeVector();
39
40 if (IsDynamicRank(x_shp) || IsDynamicRank(y_shp)) {
41 ShapeVector ret_shape{abstract::Shape::kShapeRankAny};
42 return std::make_shared<abstract::Shape>(ret_shape);
43 }
44
45 auto transpose_a_op = GetScalarValue<bool>(input_args[2]->GetValue());
46 auto transpose_b_op = GetScalarValue<bool>(input_args[3]->GetValue());
47
48 if (!transpose_a_op.has_value()) {
49 return x->GetShape()->Clone();
50 }
51
52 if (!transpose_b_op.has_value()) {
53 return y->GetShape()->Clone();
54 }
55
56 auto transpose_a = transpose_a_op.value();
57 auto transpose_b = transpose_b_op.value();
58
59 if (x_shp.size() == 1 && y_shp.size() == 1 && x_shp[0] == 0 && y_shp[0] == 0) {
60 ShapeVector ret_shape;
61 return std::make_shared<abstract::Shape>(ret_shape);
62 }
63
64 return InferShape2D(x_shp, y_shp, transpose_a, transpose_b);
65 }
66
InferShape2D(const ShapeVector & x_shp,const ShapeVector & y_shp,bool transpose_a,bool transpose_b)67 BaseShapePtr MatMulFuncImpl::InferShape2D(const ShapeVector &x_shp, const ShapeVector &y_shp, bool transpose_a,
68 bool transpose_b) {
69 const size_t SHAPE_SIZE = 2;
70
71 if (x_shp.size() != SHAPE_SIZE || y_shp.size() != SHAPE_SIZE) {
72 MS_EXCEPTION(ValueError) << "MatMul inputs should have the same dimension size and equal to 2.";
73 }
74 auto x_col = x_shp[(transpose_a ? 0 : 1)];
75 auto y_row = y_shp[(transpose_b ? 1 : 0)];
76 if (x_col != y_row && x_col >= 0 && y_row >= 0) {
77 MS_EXCEPTION(ValueError) << "For 'MatMul' the input dimensions must be equal, but got 'x1_col': " << x_col
78 << " and 'x2_row': " << y_row << ".";
79 }
80
81 ShapeVector ret_shape;
82 auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp,
83 const ShapeVector yshp) -> void {
84 if (!xshp.empty() && !yshp.empty()) {
85 output.push_back(xshp[(transpose_a ? 1 : 0)]);
86 output.push_back(yshp[(transpose_b ? 0 : 1)]);
87 }
88 return;
89 };
90 make_shape(ret_shape, x_shp, y_shp);
91 return std::make_shared<abstract::Shape>(ret_shape);
92 }
93
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const94 TypePtr MatMulFuncImpl::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
95 constexpr auto kMatMulInputNum = 2;
96 auto op_name = primitive->name();
97 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual,
98 kMatMulInputNum, op_name);
99 auto x = CheckAndConvertUtils::CheckArgsType(op_name, input_args, 0, kObjectTypeTensorType);
100 auto y = CheckAndConvertUtils::CheckArgsType(op_name, input_args, 1, kObjectTypeTensorType);
101
102 auto x_tensor_type = x->GetType()->cast<TensorTypePtr>();
103 MS_EXCEPTION_IF_NULL(x_tensor_type);
104 auto y_tensor_type = y->GetType()->cast<TensorTypePtr>();
105 MS_EXCEPTION_IF_NULL(y_tensor_type);
106 TypePtr x_type = x_tensor_type->element();
107 TypePtr y_type = y_tensor_type->element();
108 if (x_type->type_id() != y_type->type_id()) {
109 MS_EXCEPTION(TypeError) << "For '" << op_name
110 << "', the type of 'x2' should be same as 'x1', but got 'x1' with type Tensor["
111 << x_type->ToString() << "] and 'x2' with type Tensor[" << y_type->ToString() << "].";
112 }
113 if (primitive->HasAttr("cast_type")) {
114 auto out_type = primitive->GetAttr("cast_type");
115 MS_EXCEPTION_IF_NULL(out_type);
116 if (!out_type->isa<Type>()) {
117 MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
118 }
119 x_type = out_type->cast<TypePtr>();
120 }
121
122 auto context_ptr = MsContext::GetInstance();
123 MS_EXCEPTION_IF_NULL(context_ptr);
124 std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
125 std::set<TypePtr> valid_types;
126 valid_types = {kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat16,
127 kFloat32, kFloat64, kComplex64, kComplex128, kBFloat16};
128 std::map<std::string, TypePtr> types;
129 (void)types.emplace("x", input_args[kInputIndex0]->GetType());
130 (void)types.emplace("y", input_args[kInputIndex1]->GetType());
131 (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
132 if (x_type->type_id() == TypeId::kNumberTypeInt8 && device_target == kAscendDevice) {
133 return kInt32;
134 }
135 return x_type;
136 }
137
InferType(const PrimitivePtr & primitive,const ValuePtrList & input_values) const138 TypePtrList MatMulFuncImpl::InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
139 const auto &x_tensor = input_values[kInputIndex0]->cast<tensor::BaseTensorPtr>();
140 const auto &y_tensor = input_values[kInputIndex1]->cast<tensor::BaseTensorPtr>();
141 MS_EXCEPTION_IF_NULL(x_tensor);
142 MS_EXCEPTION_IF_NULL(y_tensor);
143 TypePtr ret_type = x_tensor->Dtype();
144 if (primitive->HasAttr("cast_type")) {
145 auto out_type = primitive->GetAttr("cast_type");
146 MS_EXCEPTION_IF_NULL(out_type);
147 if (!out_type->isa<Type>()) {
148 MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
149 }
150 ret_type = out_type->cast<TypePtr>();
151 }
152 const auto x_type = x_tensor->Dtype();
153 const auto y_type = y_tensor->Dtype();
154 auto op_name = primitive->name();
155 if (x_type->type_id() != y_type->type_id()) {
156 MS_EXCEPTION(TypeError) << "For '" << op_name
157 << "', the type of 'x2' should be same as 'x1', but got 'x1' with type Tensor["
158 << x_type->ToString() << "] and 'x2' with type Tensor[" << y_type->ToString() << "].";
159 }
160 auto context_ptr = MsContext::GetInstance();
161 MS_EXCEPTION_IF_NULL(context_ptr);
162 std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
163 if (x_type->type_id() == TypeId::kNumberTypeInt8 && device_target == kAscendDevice) {
164 ret_type = kInt32;
165 }
166 return {ret_type};
167 }
168
InferShape(const PrimitivePtr & primitive,const ValuePtrList & input_values) const169 ShapeArray MatMulFuncImpl::InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
170 const auto &x_tensor = input_values[kInputIndex0]->cast<tensor::BaseTensorPtr>();
171 const auto &y_tensor = input_values[kInputIndex1]->cast<tensor::BaseTensorPtr>();
172 MS_EXCEPTION_IF_NULL(x_tensor);
173 MS_EXCEPTION_IF_NULL(y_tensor);
174
175 const auto &x_shp = x_tensor->shape();
176 const auto &y_shp = y_tensor->shape();
177
178 auto transpose_a_op = GetScalarValue<bool>(input_values[kInputIndex2]);
179 auto transpose_b_op = GetScalarValue<bool>(input_values[kInputIndex3]);
180
181 auto transpose_a = transpose_a_op.value();
182 auto transpose_b = transpose_b_op.value();
183
184 if (x_shp.size() == 1 && y_shp.size() == 1 && x_shp[0] == 0 && y_shp[0] == 0) {
185 ShapeVector ret_shape;
186 return {ret_shape};
187 }
188
189 const size_t SHAPE_SIZE = 2;
190
191 if (x_shp.size() != SHAPE_SIZE || y_shp.size() != SHAPE_SIZE) {
192 MS_EXCEPTION(ValueError) << "MatMul inputs should have the same dimension size and equal to 2.";
193 }
194 auto x_col = x_shp[(transpose_a ? 0 : 1)];
195 auto y_row = y_shp[(transpose_b ? 1 : 0)];
196 if (x_col != y_row && x_col >= 0 && y_row >= 0) {
197 MS_EXCEPTION(ValueError) << "For 'MatMul' the input dimensions must be equal, but got 'x1_col': " << x_col
198 << " and 'x2_row': " << y_row << ".";
199 }
200
201 ShapeVector ret_shape;
202 if (!x_shp.empty() && !y_shp.empty()) {
203 ret_shape.push_back(x_shp[(transpose_a ? 1 : 0)]);
204 ret_shape.push_back(y_shp[(transpose_b ? 0 : 1)]);
205 }
206 return {ret_shape};
207 }
208 REGISTER_SIMPLE_INFER(kNameMatMul, MatMulFuncImpl)
209 } // namespace ops
210 } // namespace mindspore
211