• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""tbe compiler"""
16import json
17import os
18import sys
19from te.platform.cce_conf import te_set_version
20from te_fusion.fusion_util import fusion_op
21from te_fusion.fusion_manager import set_context_parameter
22import tbe.common.context.op_info as operator_info
23sys.path.append(os.path.abspath(os.path.dirname(__file__)))
24# pylint: disable=wrong-import-position
25from tbe_common import check_kernel_info, get_args, get_built_in_impl_path
26
# Default directory of operator implementations, resolved once when this
# module is imported (presumably the implementations shipped with the TBE
# toolchain — see get_built_in_impl_path in tbe_common).
build_in_impl_path = get_built_in_impl_path()

# op function list: the only build type that build_op() accepts.
op_build = "compile"
31
32
33def _initialize(impl_path):
34    """Initialize"""
35    if impl_path == "":
36        op_module_name = build_in_impl_path
37    else:
38        op_module_name = impl_path
39    if not op_module_name:
40        raise ValueError("Can not find the env TBE_IMPL_PATH")
41
42    sys.path.insert(0, op_module_name)
43
44
45def _replace_range(args):
46    for arg in args:
47        if not arg or not arg.__contains__('range'):
48            continue
49        shape_range = arg["range"]
50        for range_item in shape_range:
51            for index, value in enumerate(range_item):
52                if value < 0:
53                    range_item[index] = None
54
55
def build_op(build_type, json_str, tune_mode=None):
    """
    call op functions with function name and input args json_str

    Args:
        build_type : op function name; only ``op_build`` ("compile") is accepted.
        json_str (str): JSON-encoded kernel description. It holds an
            "op_info" object and a top-level "reset_op_info" entry, and
            optionally "impl_path" pointing at a custom implementation file.
        tune_mode (str): if use auto_tune. When not None, the call arguments
            and module name are returned alongside the compile result.

    Returns:
        For dynamic-shape ops: the compile info collected in the op context.
        When ``tune_mode`` is set, this is
        ``(compile_info, (inputs, outputs, attrs), module_name)`` instead.

        For static-shape ops: the result of ``fusion_manager.build_single_op``.
        When ``tune_mode`` is set, this is
        ``(None, (inputs, outputs, attrs), module_name)`` instead.

    Raises:
        Exception: If specific keyword is not found while parsing ``json_str``
            or reading the required "op_info" fields (propagated unwrapped).
        RuntimeError: Wrapping, and chained from, any exception raised while
            locating, importing or compiling the op. This includes an
            unsupported ``build_type``.
    """
    import importlib

    kernel_info = json.loads(json_str)
    check_kernel_info(kernel_info)
    op_desc = kernel_info['op_info']
    te_set_version(op_desc["socVersion"])
    op_name = op_desc['name']
    op_type = op_desc['Type']
    rl_tune_switch = op_desc['rl_tune_switch']
    rl_tune_list = op_desc['rl_tune_list']
    reset_op_info = kernel_info["reset_op_info"]
    op_tune_switch = op_desc['op_tune_switch']
    op_tune_list = op_desc['op_tune_list']
    pass_list = op_desc['pass_list']

    try:
        # "" means "use the built-in implementations". This was previously
        # left unbound when "impl_path" was absent or null.
        impl_path = ""
        custom_flag = False
        if kernel_info.get('impl_path') is not None:
            real_path = os.path.realpath(kernel_info['impl_path'])
            # Only a file path selects a custom implementation: its directory
            # goes on sys.path and its basename becomes the module name.
            # Anything else (e.g. a directory) falls back to the built-ins.
            if os.path.isfile(real_path):
                impl_path, file_name = os.path.split(real_path)
                op_name, _ = os.path.splitext(file_name)
                custom_flag = True
        _initialize(impl_path)

        inputs_args = get_args(op_desc, 'inputs')
        outputs_args = get_args(op_desc, 'outputs')
        attrs_args = get_args(op_desc, 'attrs')
        kernel_name = op_desc['kernel_name']
        is_dynamic_shape = op_desc['is_dynamic_shape']
        if is_dynamic_shape:
            _replace_range(inputs_args)
            _replace_range(outputs_args)

        if custom_flag:
            op_module_name = op_name
        elif is_dynamic_shape:
            op_module_name = "impl.dynamic." + op_name
        else:
            op_module_name = "impl." + op_name

        # get function
        if build_type != op_build:
            raise ValueError("function {} is not supported by Tbe op {}.".format(build_type, op_name))
        # A custom module exposes the kernel under the JSON op name, whereas
        # a built-in module exposes it under the module's own name.
        py_fn_name = op_desc['name'] if custom_flag else op_name

        # call function
        if is_dynamic_shape:
            # Import through op_module_name so custom dynamic-shape ops work
            # too. Previously op_module was bound only for built-in dynamic
            # ops, so a custom dynamic op failed with UnboundLocalError.
            op_module = importlib.import_module(op_module_name)
            op_func = getattr(op_module, py_fn_name, None)
            if op_func is None:
                raise ValueError("Op:{} function {} is not supported by Tbe.".format(op_name, build_type))
            import tbe.common.context.op_context as op_context
            with op_context.OpContext("dynamic"):
                op_info = operator_info.OpInfo(op_type, op_type)
                context = op_context.get_context()
                context.add_op_info(op_info)
                set_context_parameter(context, None, None, reset_op_info)
                op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
                compile_info = op_context.get_context().get_compile_info()
                if tune_mode is not None:
                    return compile_info, (inputs_args, outputs_args, attrs_args), op_module_name
                return compile_info

        # The static build path passes kernel_name as the trailing attr.
        attrs_args.append(kernel_name)
        import te_fusion.fusion_manager as fusion_manager
        res = fusion_manager.build_single_op(op_module_name, py_fn_name, op_type, "build",
                                             inputs=inputs_args,
                                             outputs=outputs_args,
                                             attrs=attrs_args,
                                             unknown_shape=False,
                                             int64_mode=False,
                                             dynamic_compile_static=False,
                                             op_pattern=None,
                                             auto_tiling_mode=None,
                                             device_id=None,
                                             fuzz_build_info=None,
                                             reset_op_info=reset_op_info,
                                             switch_str=rl_tune_switch,
                                             lic_opt_list=rl_tune_list,
                                             pass_opt_list=pass_list,
                                             switch_optune=op_tune_switch,
                                             optune_opt_list=op_tune_list)
        if tune_mode is not None:
            return None, (inputs_args, outputs_args, attrs_args), op_module_name
        return res

    except Exception as e:
        # Keep the RuntimeError type callers expect, but preserve the cause.
        raise RuntimeError(e) from e
163
164
def compile_fusion_op(json_str):
    """
    compile fusion op with input args json_str

    Args:
        json_str (str): JSON-encoded request. It needs a non-empty "fusion_op"
            object (with "socVersion" and the tuning/pass options read below),
            plus top-level "reset_op_info" and "SocInfo" entries.

    Returns:
        The result of ``fusion_op`` for the serialized fusion description.

    Raises:
        ValueError: If "fusion_op" is missing or empty.
        KeyError: If any other required key is missing.
    """
    args = json.loads(json_str)
    # Validate before touching args['fusion_op']. Previously te_set_version()
    # read args['fusion_op']["socVersion"] first, so a missing or empty
    # fusion_op surfaced as a KeyError instead of this ValueError.
    if 'fusion_op' not in args or not args['fusion_op']:
        raise ValueError("Json string Errors, key:fusion_op not found.")
    reset_op_info = args["reset_op_info"]
    fusion_op_arg = args['fusion_op']
    te_set_version(fusion_op_arg["socVersion"])
    # SocInfo is forwarded inside the serialized fusion op description.
    fusion_op_arg['SocInfo'] = args['SocInfo']
    return fusion_op(json.dumps(fusion_op_arg), reset_op_info=reset_op_info,
                     switch_str=fusion_op_arg['rl_tune_switch'],
                     lic_opt_list=fusion_op_arg['rl_tune_list'],
                     pass_opt_list=fusion_op_arg['pass_list'],
                     switch_optune=fusion_op_arg['op_tune_switch'],
                     optune_opt_list=fusion_op_arg['op_tune_list'])
190
191
def compile_with_json(json_str):
    """
    Compile tbe with json.

    Dispatches on the request:
    - If it has a top-level "fusion_op" key, it goes to ``compile_fusion_op``.
    - Otherwise it goes to ``build_op`` with the "compile" build type.

    Args:
        json_str (str): JSON-encoded compile request. This is the JSON text
            itself, not a path to a file.

    Returns:
        The result of ``compile_fusion_op`` or ``build_op``.
    """
    json_info = json.loads(json_str)
    if "fusion_op" in json_info:
        return compile_fusion_op(json_str)
    return build_op(op_build, json_str, None)
206
207
if __name__ == "__main__":
    # Script mode: compile the single JSON request read from the first line
    # of stdin. Dict results are emitted as JSON on stderr (presumably parsed
    # by the invoking process — confirm).
    compiled = compile_with_json(sys.stdin.readline())
    if isinstance(compiled, dict):
        sys.stderr.write(json.dumps(compiled))
        sys.stderr.flush()
214