• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""MatMul op"""
17from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
18
# Registration info for the "MatMul" TBE operator, assembled through the
# TBERegOp fluent builder and finalized by get_op_info().
matmul_op_info = (
    TBERegOp("MatMul")
    # Kernel binary and general registration flags.
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("mat_mul.so")
    .compute_cost(10)
    .kernel_name("mat_mul")
    .partial_flag(True)
    .need_check_supported(True)
    # Operator attributes: two required transpose switches and an optional
    # integer offset registered with the string "0" as its last argument.
    .attr("transpose_x1", "required", "bool", "all")
    .attr("transpose_x2", "required", "bool", "all")
    .attr("offset_x", "optional", "int", "all", "0")
    # Two required inputs, two optional inputs, one required output.
    .input(0, "x1", False, "required", "all")
    .input(1, "x2", False, "required", "all")
    .input(2, "bias", False, "optional", "all")
    .input(3, "offset_w", False, "optional", "all")
    .output(0, "y", False, "required", "all")
    # Supported dtype/format combinations. Each call passes five entries;
    # presumably ordered as (x1, x2, bias, offset_w, y) to match the inputs
    # and output declared above -- confirm against TBERegOp.dtype_format.
    .dtype_format(
        DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default,
        DataType.I32_Default,
    )
    .dtype_format(
        DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default,
        DataType.F16_FracNZ,
    )
    .dtype_format(
        DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default,
        DataType.F32_FracNZ,
    )
    .dtype_format(
        DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I8_Default,
        DataType.F32_NHWC,
    )
    .dtype_format(
        DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I8_Default,
        DataType.F32_Default,
    )
    .dtype_format(
        DataType.I32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, DataType.I8_Default,
        DataType.I32_NHWC,
    )
    .get_op_info()
)
48
49
@op_info_register(matmul_op_info)
def _matmul_tbe():
    """MatMul TBE register.

    The function body is intentionally empty: it exists only so that the
    ``op_info_register`` decorator can be applied with ``matmul_op_info``.
    """
54