• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""MinMaxUpdatePerLayer op"""
17from functools import reduce as functools_reduce
18import te.lang.cce
19from te import tvm
20from te.platform.fusion_manager import fusion_manager
21from topi import generic
22from topi.cce import util
23from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
24
# Operator registration info for MinMaxUpdatePerLayer: two optional attributes
# (ema, ema_decay), three inputs (x, min, max), two outputs (min_up, max_up),
# and a single supported dtype/format combination shared by all five tensors.
# NOTE(review): minmax_update_perlayer() below also accepts float16, but only
# float32 (5HD) is registered here — confirm whether float16 should be added.
minmax_update_perlayer_op_info = (
    TBERegOp("MinMaxUpdatePerLayer")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("minmax_update_perlayer.so")
    .compute_cost(10)
    .kernel_name("minmax_update_perlayer")
    .partial_flag(True)
    # When ema is False the compute ignores ema_decay and uses the batch min/max.
    .attr("ema", "optional", "bool", "all")
    .attr("ema_decay", "optional", "float", "all")
    .input(0, "x", None, "required", None)
    .input(1, "min", None, "required", None)
    .input(2, "max", None, "required", None)
    .output(0, "min_up", True, "required", "all")
    .output(1, "max_up", True, "required", "all")
    # Order: x, min, max, min_up, max_up.
    .dtype_format(
        DataType.F32_5HD,
        DataType.F32_5HD,
        DataType.F32_5HD,
        DataType.F32_5HD,
        DataType.F32_5HD,
    )
    .get_op_info()
)
42
43
@op_info_register(minmax_update_perlayer_op_info)
def _minmax_update_perlayer_tbe():
    """Hook for registering the MinMaxUpdatePerLayer TBE op info.

    The registration itself is done by the ``op_info_register`` decorator;
    the function body performs no work and returns None.
    """
48
49
@fusion_manager.register("minmax_update_perlayer")
def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
    """MinMaxUpdatePerLayer compute: blend running min/max with batch min/max.

    ``x`` is reduced over all of its axes to a single minimum and maximum,
    which are broadcast to the shape of ``min_val``. Each running statistic
    is then updated as ``decay * old + (1 - decay) * batch``. Here ``decay``
    is ``ema_decay`` when ``ema`` is truthy, and 0.0 otherwise (the batch
    value simply replaces the old one). Finally the new minimum is clamped
    to be <= 0 and the new maximum to be >= 0.

    Args:
        x: input tensor placeholder.
        min_val: running-minimum tensor.
        max_val: running-maximum tensor (same shape as ``min_val`` is assumed).
        ema: whether to use exponential moving averaging.
        ema_decay: weight of the previous min/max when ``ema`` is truthy.

    Returns:
        list: ``[updated_min, updated_max]`` tensors.
    """
    cce = te.lang.cce
    x_shape = cce.util.shape_to_list(x.shape)
    stat_shape = cce.util.shape_to_list(min_val.shape)
    old_min = cce.broadcast(min_val, stat_shape, x.dtype)
    old_max = cce.broadcast(max_val, stat_shape, x.dtype)

    decay = ema_decay if ema else 0.0
    keep = 1 - decay

    # Batch statistics: reduce across every axis of x, then expand back to
    # the running-statistic shape so they can be combined element-wise.
    reduce_axes = tuple(range(len(x_shape)))
    batch_min = cce.reduce_min(x, axis=reduce_axes)
    batch_max = cce.reduce_max(x, axis=reduce_axes)
    batch_min = cce.broadcast(batch_min, stat_shape)
    batch_max = cce.broadcast(batch_max, stat_shape)

    def _blend(running, current):
        """Return decay * running + keep * current."""
        return cce.vadd(cce.vmuls(running, decay), cce.vmuls(current, keep))

    new_min = _blend(old_min, batch_min)
    new_max = _blend(old_max, batch_max)

    # Clamp so that min <= 0 <= max; presumably this keeps 0 inside the
    # quantization range — confirm against the quantization scheme.
    return [cce.vmins(new_min, 0), cce.vmaxs(new_max, 0)]
74
75
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
                           ema, ema_decay, kernel_name="minmax_update_perlayer"):
    """MinMaxUpdatePerLayer op: validate inputs, build and compile the kernel.

    Args:
        x (dict): input tensor descriptor; its "shape" and "dtype" are read.
        min_val (dict): running-minimum descriptor; its "ori_shape" and
            "dtype" are read.
        max_val (dict): running-maximum descriptor; its "ori_shape" and
            "dtype" are read.
        min_up (dict): updated-minimum output descriptor (not read here).
        max_up (dict): updated-maximum output descriptor (not read here).
        ema (bool): when False, the batch min/max replace the running values.
        ema_decay (float): weight of the previous min/max when ``ema`` is True.
        kernel_name (str): name of the generated kernel.

    Raises:
        RuntimeError: if ``x``, ``min_val`` and ``max_val`` do not share a
            dtype. Shape, dtype-whitelist and kernel-name validation is
            delegated to ``topi.cce.util`` helpers, which report their own
            errors.
    """
    input_shape = x.get("shape")
    input_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    # Promote scalar (empty) min/max shapes to a one-element shape.
    min_shape = util.scalar2tensor_one(min_shape)
    max_shape = util.scalar2tensor_one(max_shape)
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    # min/max are expected to be single-element 1-D tensors — assumes the
    # positional args are (min_dim, max_dim, max_shape_num); TODO confirm.
    util.check_shape_rule(min_shape, 1, 1, 1)
    util.check_shape_rule(max_shape, 1, 1, 1)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", "float16"]
    x_dtype = input_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)
    # The compute combines min/max with reductions of x using element-wise
    # vadd, so all three must share a dtype; report a mismatch up front.
    if min_dtype != x_dtype or max_dtype != x_dtype:
        raise RuntimeError(
            "min and max dtype must match x dtype %s, but got min %s and max %s"
            % (x_dtype, min_dtype, max_dtype))

    # Flatten x to 1-D: the compute reduces over every axis, so only the
    # element count matters. The initializer 1 keeps an empty (scalar) shape
    # from making reduce() raise TypeError; the lambda's argument names avoid
    # shadowing the ``x`` parameter.
    input_shape = (functools_reduce(lambda acc, dim: acc * dim, input_shape, 1),)
    shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

    input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
    # max was validated to the same single-element shape as min, so shape_min
    # is reused for its placeholder.
    max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
    res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay)

    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)

    tensor_list = [input_data, min_data, max_data] + list(res_list)
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}

    te.lang.cce.cce_build_code(sch, config)
122