/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ops/ops_func_impl/matmul.h"
#include <set>
#include <map>
#include <string>
#include "utils/check_convert_utils.h"
#include "utils/ms_context.h"
#include "ops/op_name.h"
#include "ops/op_utils.h"
#include "ops/ops_func_impl/simple_infer.h"

namespace mindspore {
namespace ops {
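// Infers the output shape of MatMul from the abstract inputs:
// input_args = {x, y, transpose_a, transpose_b}.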
BaseShapePtr MatMulFuncImpl::InferShape(const PrimitivePtr &primitive,
                                        const std::vector<AbstractBasePtr> &input_args) const {
  constexpr auto kMatMulInputNum = 4;
  const std::string op_name = primitive->name();
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual,
                                           kMatMulInputNum, op_name);
  auto x = CheckAndConvertUtils::CheckArgsType(op_name, input_args, 0, kObjectTypeTensorType);
  auto y = CheckAndConvertUtils::CheckArgsType(op_name, input_args, 1, kObjectTypeTensorType);
  const auto &x_shp = x->GetShape()->GetShapeVector();
  const auto &y_shp = y->GetShape()->GetShapeVector();

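  // Either rank is still unknown, so the output rank is unknown as well.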
  if (IsDynamicRank(x_shp) || IsDynamicRank(y_shp)) {
    ShapeVector ret_shape{abstract::Shape::kShapeRankAny};
    return std::make_shared<abstract::Shape>(ret_shape);
  }

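  // The transpose flags may not be constant yet; in that case GetScalarValue
  // returns an empty optional and the final shape cannot be computed here.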
  auto transpose_a_op = GetScalarValue<bool>(input_args[2]->GetValue());
  auto transpose_b_op = GetScalarValue<bool>(input_args[3]->GetValue());

  if (!transpose_a_op.has_value()) {
    return x->GetShape()->Clone();
  }

  if (!transpose_b_op.has_value()) {
    return y->GetShape()->Clone();
  }

  auto transpose_a = transpose_a_op.value();
  auto transpose_b = transpose_b_op.value();

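  // Two empty 1-D inputs: the dot product of zero-length vectors is a scalar,
  // represented here by an empty shape.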
  if (x_shp.size() == 1 && y_shp.size() == 1 && x_shp[0] == 0 && y_shp[0] == 0) {
    ShapeVector ret_shape;
    return std::make_shared<abstract::Shape>(ret_shape);
  }

  return InferShape2D(x_shp, y_shp, transpose_a, transpose_b);
}

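// Computes the output shape for 2-D x 2-D MatMul, honoring the transpose flags.
// The inner dimensions must match unless one of them is dynamic (negative).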
BaseShapePtr MatMulFuncImpl::InferShape2D(const ShapeVector &x_shp, const ShapeVector &y_shp, bool transpose_a,
                                          bool transpose_b) {
  const size_t SHAPE_SIZE = 2;

  if (x_shp.size() != SHAPE_SIZE || y_shp.size() != SHAPE_SIZE) {
    MS_EXCEPTION(ValueError) << "MatMul inputs must both be 2-dimensional.";
  }
  auto x_col = x_shp[(transpose_a ? 0 : 1)];
  auto y_row = y_shp[(transpose_b ? 1 : 0)];
  if (x_col != y_row && x_col >= 0 && y_row >= 0) {
    MS_EXCEPTION(ValueError) << "For 'MatMul', the input dimensions must be equal, but got 'x1_col': " << x_col
                             << " and 'x2_row': " << y_row << ".";
  }

  ShapeVector ret_shape;
  auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector &xshp,
                                                 const ShapeVector &yshp) -> void {
    if (!xshp.empty() && !yshp.empty()) {
      output.push_back(xshp[(transpose_a ? 1 : 0)]);
      output.push_back(yshp[(transpose_b ? 0 : 1)]);
    }
  };
  make_shape(ret_shape, x_shp, y_shp);
  return std::make_shared<abstract::Shape>(ret_shape);
}

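// Infers the output dtype of MatMul from the abstract inputs. Both inputs must
// share the same dtype; an optional "cast_type" attribute overrides the result.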
TypePtr MatMulFuncImpl::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
  constexpr auto kMatMulInputNum = 2;
  auto op_name = primitive->name();
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual,
                                           kMatMulInputNum, op_name);
  auto x = CheckAndConvertUtils::CheckArgsType(op_name, input_args, 0, kObjectTypeTensorType);
  auto y = CheckAndConvertUtils::CheckArgsType(op_name, input_args, 1, kObjectTypeTensorType);

  auto x_tensor_type = x->GetType()->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(x_tensor_type);
  auto y_tensor_type = y->GetType()->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(y_tensor_type);
  TypePtr x_type = x_tensor_type->element();
  TypePtr y_type = y_tensor_type->element();
  if (x_type->type_id() != y_type->type_id()) {
    MS_EXCEPTION(TypeError) << "For '" << op_name
                            << "', the type of 'x2' should be the same as 'x1', but got 'x1' with type Tensor["
                            << x_type->ToString() << "] and 'x2' with type Tensor[" << y_type->ToString() << "].";
  }
  if (primitive->HasAttr("cast_type")) {
    auto out_type = primitive->GetAttr("cast_type");
    MS_EXCEPTION_IF_NULL(out_type);
    if (!out_type->isa<Type>()) {
      MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
    }
    x_type = out_type->cast<TypePtr>();
  }

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  const std::set<TypePtr> valid_types = {kUInt8,   kInt8,    kInt16,     kInt32,      kInt64,   kFloat16,
                                         kFloat32, kFloat64, kComplex64, kComplex128, kBFloat16};
  std::map<std::string, TypePtr> types;
  (void)types.emplace("x", input_args[kInputIndex0]->GetType());
  (void)types.emplace("y", input_args[kInputIndex1]->GetType());
  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
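  // On Ascend, int8 inputs produce an int32 output.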
  if (x_type->type_id() == TypeId::kNumberTypeInt8 && device_target == kAscendDevice) {
    return kInt32;
  }
  return x_type;
}

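// Simple-infer variant of InferType: applies the same dtype rules as above,
// but works directly on concrete tensor values.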
TypePtrList MatMulFuncImpl::InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
  const auto &x_tensor = input_values[kInputIndex0]->cast<tensor::BaseTensorPtr>();
  const auto &y_tensor = input_values[kInputIndex1]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(x_tensor);
  MS_EXCEPTION_IF_NULL(y_tensor);
  TypePtr ret_type = x_tensor->Dtype();
  if (primitive->HasAttr("cast_type")) {
    auto out_type = primitive->GetAttr("cast_type");
    MS_EXCEPTION_IF_NULL(out_type);
    if (!out_type->isa<Type>()) {
      MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
    }
    ret_type = out_type->cast<TypePtr>();
  }
  const auto x_type = x_tensor->Dtype();
  const auto y_type = y_tensor->Dtype();
  auto op_name = primitive->name();
  if (x_type->type_id() != y_type->type_id()) {
    MS_EXCEPTION(TypeError) << "For '" << op_name
                            << "', the type of 'x2' should be the same as 'x1', but got 'x1' with type Tensor["
                            << x_type->ToString() << "] and 'x2' with type Tensor[" << y_type->ToString() << "].";
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (x_type->type_id() == TypeId::kNumberTypeInt8 && device_target == kAscendDevice) {
    ret_type = kInt32;
  }
  return {ret_type};
}

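// Simple-infer variant of InferShape: the transpose flags are expected to be
// concrete values here, so the 2-D shape rule is applied directly.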
ShapeArray MatMulFuncImpl::InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
  const auto &x_tensor = input_values[kInputIndex0]->cast<tensor::BaseTensorPtr>();
  const auto &y_tensor = input_values[kInputIndex1]->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(x_tensor);
  MS_EXCEPTION_IF_NULL(y_tensor);

  const auto &x_shp = x_tensor->shape();
  const auto &y_shp = y_tensor->shape();

  auto transpose_a_op = GetScalarValue<bool>(input_values[kInputIndex2]);
  auto transpose_b_op = GetScalarValue<bool>(input_values[kInputIndex3]);

  auto transpose_a = transpose_a_op.value();
  auto transpose_b = transpose_b_op.value();

  if (x_shp.size() == 1 && y_shp.size() == 1 && x_shp[0] == 0 && y_shp[0] == 0) {
    ShapeVector ret_shape;
    return {ret_shape};
  }

  const size_t SHAPE_SIZE = 2;

  if (x_shp.size() != SHAPE_SIZE || y_shp.size() != SHAPE_SIZE) {
    MS_EXCEPTION(ValueError) << "MatMul inputs must both be 2-dimensional.";
  }
  auto x_col = x_shp[(transpose_a ? 0 : 1)];
  auto y_row = y_shp[(transpose_b ? 1 : 0)];
  if (x_col != y_row && x_col >= 0 && y_row >= 0) {
    MS_EXCEPTION(ValueError) << "For 'MatMul', the input dimensions must be equal, but got 'x1_col': " << x_col
                             << " and 'x2_row': " << y_row << ".";
  }

  ShapeVector ret_shape;
  if (!x_shp.empty() && !y_shp.empty()) {
    ret_shape.push_back(x_shp[(transpose_a ? 1 : 0)]);
    ret_shape.push_back(y_shp[(transpose_b ? 0 : 1)]);
  }
  return {ret_shape};
}
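// Registers the simple-infer implementation for MatMul.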
REGISTER_SIMPLE_INFER(kNameMatMul, MatMulFuncImpl)
}  // namespace ops
}  // namespace mindspore