# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Registration info and kernel build entry point for the ``CusAdd3`` op."""
from __future__ import absolute_import
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType


@fusion_manager.register("add3")
def add3_compute(input1, input2, const_bias):
    """Build the compute expression ``input1 + input2 + const_bias``.

    Args:
        input1: First tensor operand; its ``dtype`` is also used for the
            bias constant.
        input2: Second tensor operand, passed to ``vadd`` together with
            ``input1``.
        const_bias: Scalar bias, wrapped in a ``tvm.const`` and applied with
            ``vadds``.

    Returns:
        The result of ``vadds(vadd(input1, input2), const_bias)``.
    """
    sum2 = te.lang.cce.vadd(input1, input2)
    # The bias constant is created with input1's dtype so that it matches the
    # dtype of the tensor it is added to.
    bias = tvm.const(const_bias, dtype=input1.dtype)
    return te.lang.cce.vadds(sum2, bias)


# Operator description handed to ``op_info_register`` below: two required
# inputs, one required output, one required float attribute ``const_bias``,
# and float32 / float16 dtype combinations.
cus_add3_op_info = TBERegOp("CusAdd3") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("add3.so") \
    .compute_cost(10) \
    .kernel_name("CusAdd3Impl") \
    .partial_flag(True) \
    .attr("const_bias", "required", "float", "all") \
    .input(0, "input1", False, "required", "all") \
    .input(1, "input2", False, "required", "all") \
    .output(0, "sum", False, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .get_op_info()


@op_info_register(cus_add3_op_info)
def CusAdd3Impl(input1, inptu2, sum1, const_bias, kernel_name="CusAdd3Impl"):
    """Schedule and build the CusAdd3 kernel.

    Args:
        input1: Mapping describing the first input; its ``"shape"`` and
            ``"dtype"`` entries are read.
        inptu2: Descriptor of the second input. It is not read: the second
            placeholder reuses the shape and dtype of ``input1``. The
            misspelled name (sic) is kept so the signature stays unchanged.
        sum1: Descriptor of the output. It is not read.
        const_bias: Scalar bias forwarded to ``add3_compute``.
        kernel_name: Name stored under ``"name"`` in the build config.
    """
    shape = util.shape_refine(input1.get("shape"))
    # Lower-case once; the result is reused for both placeholders.
    dtype = input1.get("dtype").lower()

    # Both placeholders share the shape and dtype taken from ``input1``.
    data1 = tvm.placeholder(shape, name="input1", dtype=dtype)
    data2 = tvm.placeholder(shape, name="input2", dtype=dtype)

    with tvm.target.cce():
        res = add3_compute(data1, data2, const_bias)
        sch = generic.auto_schedule(res)

    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": [data1, data2, res]}

    te.lang.cce.cce_build_code(sch, config)