# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom TBE operator CusSquare: element-wise square, y = x * x."""
from __future__ import absolute_import

import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util

from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType


@fusion_manager.register("square")
def square_compute(input_x, output_y):
    """
    algorithm: square
    calculating data's square, y = x * x

    Parameters
    ----------
    input_x: TVM tensor
        the placeholder of input data
    output_y: dict
        shape and dtype of output, should be same shape and type as input;
        not read by this function, kept for the compute-function signature

    Returns
    -------
    res : tvm.tensor
        the result of square
    """
    # Squaring is expressed as an element-wise multiply of the input by itself.
    res = te.lang.cce.vmul(input_x, input_x)
    return res


# Registration info for the "CusSquare" operator: one required input "x",
# one required output "y", with float32 and float16 default-format variants.
cus_square_op_info = TBERegOp("CusSquare") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("square.so") \
    .compute_cost(10) \
    .kernel_name("CusSquareImpl") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .get_op_info()


@op_info_register(cus_square_op_info)
def CusSquareImpl(input_x, output_y, kernel_name="CusSquareImpl"):
    """
    algorithm: square
    calculating data's square, y = x * x

    Parameters
    ----------
    input_x : dict
        shape and dtype of input; the registration above lists float32 and
        float16
    output_y: dict
        shape and dtype of output, should be same shape and type as input
    kernel_name : str
        kernel name, default value is "CusSquareImpl"

    Returns
    -------
    None
    """
    shape = input_x.get("shape")
    # Normalise the dtype spelling once; the lowered value is reused below.
    dtype = input_x.get("dtype").lower()

    shape = util.shape_refine(shape)
    data = tvm.placeholder(shape, name="data", dtype=dtype)

    # Compute definition and schedule are created under the CCE target.
    with tvm.target.cce():
        res = square_compute(data, output_y)
        sch = generic.auto_schedule(res)

    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": [data, res]}

    te.lang.cce.cce_build_code(sch, config)