# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""TBE implementation of the MinMaxUpdatePerLayer operator."""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# Operator registration record: three required inputs (x, min, max), two
# required outputs (min_up, max_up) and two optional attributes (ema,
# ema_decay).
minmax_update_perlayer_op_info = (
    TBERegOp("MinMaxUpdatePerLayer")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("minmax_update_perlayer.so")
    .compute_cost(10)
    .kernel_name("minmax_update_perlayer")
    .partial_flag(True)
    .attr("ema", "optional", "bool", "all")
    .attr("ema_decay", "optional", "float", "all")
    .input(0, "x", None, "required", None)
    .input(1, "min", None, "required", None)
    .input(2, "max", None, "required", None)
    .output(0, "min_up", True, "required", "all")
    .output(1, "max_up", True, "required", "all")
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD)
    .get_op_info()
)


@op_info_register(minmax_update_perlayer_op_info)
def _minmax_update_perlayer_tbe():
    """Register the MinMaxUpdatePerLayer op info; the body does nothing."""
    return


@fusion_manager.register("minmax_update_perlayer")
def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
    """Build the compute graph that updates the tracked min and max.

    The minimum and maximum of ``x`` are reduced over every axis and mixed
    with the incoming ``min_val`` / ``max_val`` as
    ``old * ema_decay + current * (1 - ema_decay)``.

    Args:
        x: Input tensor.
        min_val: Tensor holding the previous minimum.
        max_val: Tensor holding the previous maximum.
        ema: When falsy, the decay is forced to 0.0, so the previous values
            get zero weight.
        ema_decay: Weight applied to the previous min/max.

    Returns:
        A two-element list ``[min, max]`` of result tensors.
    """
    x_dims = te.lang.cce.util.shape_to_list(x.shape)
    min_dims = te.lang.cce.util.shape_to_list(min_val.shape)

    # Both running values are broadcast to the shape of `min_val`, with the
    # dtype of `x` passed through as the third argument.
    prev_min = te.lang.cce.broadcast(min_val, min_dims, x.dtype)
    prev_max = te.lang.cce.broadcast(max_val, min_dims, x.dtype)

    decay = ema_decay if ema else 0.0

    # CalMinMax: reduce over all axes of x, then broadcast back to the
    # shape used by the running values.
    all_axes = tuple(range(len(x_dims)))
    cur_min = te.lang.cce.reduce_min(x, axis=all_axes)
    cur_max = te.lang.cce.reduce_max(x, axis=all_axes)
    cur_min = te.lang.cce.broadcast(cur_min, min_dims)
    cur_max = te.lang.cce.broadcast(cur_max, min_dims)

    # Weighted mix of the previous and the current extrema.
    kept_min = te.lang.cce.vmuls(prev_min, decay)
    fresh_min = te.lang.cce.vmuls(cur_min, (1 - decay))
    new_min = te.lang.cce.vadd(kept_min, fresh_min)

    kept_max = te.lang.cce.vmuls(prev_max, decay)
    fresh_max = te.lang.cce.vmuls(cur_max, (1 - decay))
    new_max = te.lang.cce.vadd(kept_max, fresh_max)

    # NOTE(review): vmins/vmaxs against scalar 0 presumably keep zero inside
    # the [min, max] range - confirm against the TBE DSL documentation.
    new_min = te.lang.cce.vmins(new_min, 0)
    new_max = te.lang.cce.vmaxs(new_max, 0)

    return [new_min, new_max]


@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
                           ema, ema_decay, kernel_name="minmax_update_perlayer"):
    """Validate the operator arguments and build the kernel.

    Args:
        x: Dict describing the input; its "shape" and "dtype" are read.
        min_val: Dict describing the previous minimum; its "ori_shape" and
            "dtype" are read.
        max_val: Dict describing the previous maximum; its "ori_shape" and
            "dtype" are read.
        min_up: Dict for the first output (not read here).
        max_up: Dict for the second output (not read here).
        ema: Passed through to the compute function.
        ema_decay: Passed through to the compute function.
        kernel_name: Name given to the built kernel.
    """
    input_shape = x.get("shape")
    input_dtype = x.get("dtype")
    min_shape = util.scalar2tensor_one(min_val.get("ori_shape"))
    max_shape = util.scalar2tensor_one(max_val.get("ori_shape"))
    min_dtype = min_val.get("dtype")
    max_dtype = max_val.get("dtype")

    # Shape validation, in the same order as the individual checks were
    # originally issued.
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    for stat_shape in (min_shape, max_shape):
        util.check_shape_rule(stat_shape, 1, 1, 1)
    for any_shape in (input_shape, min_shape, max_shape):
        util.check_tensor_shape_size(any_shape)

    # Dtype validation against the accepted floating point types.
    check_list = ["float32", "float16"]
    x_dtype = input_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    for dtype_name in (x_dtype, min_dtype, max_dtype):
        util.check_dtype_rule(dtype_name, check_list)

    # Flatten x to one dimension holding the product of its dimensions.
    flat_shape = (functools_reduce(lambda acc, dim: acc * dim, input_shape[:]),)
    shape_min, _, _ = util.produce_shapes(min_shape, flat_shape)

    input_data = tvm.placeholder(flat_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
    max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
    res_list = minmax_update_perlayer_compute(
        input_data, min_data, max_data, ema, ema_decay)

    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)

    build_config = {
        "print_ir": False,
        "name": kernel_name,
        "tensor_list": [input_data, min_data, max_data, *res_list],
    }

    te.lang.cce.cce_build_code(sch, build_config)