# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""tbe compiler"""
import json
import os
import sys
from te.platform.cce_conf import te_set_version
from te_fusion.fusion_util import fusion_op
from te_fusion.fusion_manager import set_context_parameter
import tbe.common.context.op_info as operator_info
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
# pylint: disable=wrong-import-position
from tbe_common import check_kernel_info, get_args, get_built_in_impl_path

build_in_impl_path = get_built_in_impl_path()

# op function list
op_build = "compile"


def _initialize(impl_path):
    """
    Put the operator implementation directory at the front of ``sys.path``.

    Args:
        impl_path (str): directory holding a custom operator implementation;
            an empty string selects ``build_in_impl_path``.

    Raises:
        ValueError: If neither ``impl_path`` nor ``build_in_impl_path`` is set.
    """
    op_module_name = build_in_impl_path if impl_path == "" else impl_path
    if not op_module_name:
        raise ValueError("Can not find the env TBE_IMPL_PATH")

    sys.path.insert(0, op_module_name)


def _replace_range(args):
    """
    Replace every negative bound in each argument's "range" entry with None, in place.

    Args:
        args (list): op arguments; falsy entries and entries without a 'range' key are skipped.
    """
    for arg in args:
        if not arg or 'range' not in arg:
            continue
        for range_item in arg["range"]:
            for index, value in enumerate(range_item):
                # A bound that is already None must not be compared with an int.
                if value is not None and value < 0:
                    range_item[index] = None


def build_op(build_type, json_str, tune_mode=None):
    """
    call op functions with function name and input args json_str

    Args:
        build_type : op function name, only ``op_build`` ("compile") is supported
        json_str (str): op function input args
        tune_mode (str): if use auto_tune

    Returns:
        When ``tune_mode`` is None: the compile info taken from the op context (dynamic shape)
        or the result of ``fusion_manager.build_single_op`` (static shape).
        Otherwise a tuple ``(compile_info_or_None, (inputs_args, outputs_args, attrs_args), op_module_name)``.

    Raises:
        RuntimeError: wraps any exception raised while resolving or building the op.
        Exception: If specific keyword is not found.
    """
    kernel_info = json.loads(json_str)
    check_kernel_info(kernel_info)
    op_info_json = kernel_info['op_info']
    te_set_version(op_info_json["socVersion"])
    op_name = op_info_json['name']
    op_type = op_info_json['Type']
    rl_tune_switch = op_info_json['rl_tune_switch']
    rl_tune_list = op_info_json['rl_tune_list']
    reset_op_info = kernel_info["reset_op_info"]
    op_tune_switch = op_info_json['op_tune_switch']
    op_tune_list = op_info_json['op_tune_list']
    pass_list = op_info_json['pass_list']

    try:
        custom_flag = False
        if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
            impl_path = os.path.realpath(kernel_info['impl_path'])
            if os.path.isfile(impl_path):
                # A file path names a custom op: import it by file name from its directory.
                path, file_name = os.path.split(impl_path)
                op_name, _ = os.path.splitext(file_name)
                impl_path = path
                custom_flag = True
        else:
            impl_path = ""
        _initialize(impl_path)

        inputs_args = get_args(op_info_json, 'inputs')
        outputs_args = get_args(op_info_json, 'outputs')
        attrs_args = get_args(op_info_json, 'attrs')
        kernel_name = op_info_json['kernel_name']
        is_dynamic_shape = op_info_json['is_dynamic_shape']
        if is_dynamic_shape:
            _replace_range(inputs_args)
            _replace_range(outputs_args)

        if custom_flag:
            op_module_name = op_name
        elif is_dynamic_shape:
            op_module_name = "impl.dynamic." + op_name
        else:
            op_module_name = "impl." + op_name

        op_module = None
        if is_dynamic_shape:
            # Import for every dynamic build, custom ops included; previously a custom
            # dynamic op reached getattr(op_module, ...) with op_module never bound.
            # The non-empty fromlist makes __import__ return the leaf module.
            op_module = __import__(op_module_name, globals(), locals(), [op_name], 0)

        # get function
        if build_type != op_build:
            raise ValueError("function {} is not supported by Tbe op {}.".format(build_type, op_name))
        # Custom ops are looked up by the name given in the json, not by the file name.
        py_fn_name = op_info_json['name'] if custom_flag else op_name

        # call function
        if is_dynamic_shape:
            op_func = getattr(op_module, py_fn_name, None)
            if op_func is None:
                raise ValueError("Op:{} function {} is not supported by Tbe.".format(op_name, build_type))
            import tbe.common.context.op_context as op_context
            with op_context.OpContext("dynamic"):
                op_info = operator_info.OpInfo(op_type, op_type)
                context = op_context.get_context()
                context.add_op_info(op_info)
                set_context_parameter(context, None, None, reset_op_info)
                op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
                compile_info = op_context.get_context().get_compile_info()
                if tune_mode is not None:
                    return compile_info, (inputs_args, outputs_args, attrs_args), op_module_name
                return compile_info

        # Static shape: the kernel name is passed as the last attribute.
        attrs_args.append(kernel_name)
        import te_fusion.fusion_manager as fusion_manager
        res = fusion_manager.build_single_op(op_module_name, py_fn_name, op_type, "build",
                                             inputs=inputs_args,
                                             outputs=outputs_args,
                                             attrs=attrs_args,
                                             unknown_shape=False,
                                             int64_mode=False,
                                             dynamic_compile_static=False,
                                             op_pattern=None,
                                             auto_tiling_mode=None,
                                             device_id=None,
                                             fuzz_build_info=None,
                                             reset_op_info=reset_op_info,
                                             switch_str=rl_tune_switch,
                                             lic_opt_list=rl_tune_list,
                                             pass_opt_list=pass_list,
                                             switch_optune=op_tune_switch,
                                             optune_opt_list=op_tune_list)
        if tune_mode is not None:
            return None, (inputs_args, outputs_args, attrs_args), op_module_name
        return res

    except Exception as e:
        # Callers see a single exception type; keep the original as the explicit cause.
        raise RuntimeError(e) from e


def compile_fusion_op(json_str):
    """
    compile fusion op with input args json_str

    Args:
        json_str (str): op function input args

    Returns:
        The value returned by ``fusion_op``.

    Raises:
        ValueError: If the key 'fusion_op' is missing or empty.
        Exception: If specific keyword is not found.
    """
    args = json.loads(json_str)
    # Validate before the first access to args['fusion_op']; otherwise a missing key
    # surfaces as a bare KeyError and this check can never fire.
    if 'fusion_op' not in args or not args['fusion_op']:
        raise ValueError("Json string Errors, key:fusion_op not found.")
    fusion_op_arg = args['fusion_op']
    reset_op_info = args["reset_op_info"]
    te_set_version(fusion_op_arg["socVersion"])
    fusion_op_arg['SocInfo'] = args['SocInfo']
    rl_tune_switch = fusion_op_arg['rl_tune_switch']
    rl_tune_list = fusion_op_arg['rl_tune_list']
    op_tune_switch = fusion_op_arg['op_tune_switch']
    op_tune_list = fusion_op_arg['op_tune_list']
    pass_list = fusion_op_arg['pass_list']
    return fusion_op(json.dumps(fusion_op_arg), reset_op_info=reset_op_info, switch_str=rl_tune_switch,
                     lic_opt_list=rl_tune_list, pass_opt_list=pass_list,
                     switch_optune=op_tune_switch, optune_opt_list=op_tune_list)


def compile_with_json(json_str):
    """
    Compile tbe with json.

    Dispatches to ``compile_fusion_op`` when the json has a "fusion_op" key,
    otherwise to ``build_op`` with the "compile" build type.

    Args:
        json_str (str): json string describing the op to compile.

    Returns:
        The result of the selected compile function.
    """
    json_info = json.loads(json_str)
    if "fusion_op" in json_info:
        return compile_fusion_op(json_str)
    return build_op(op_build, json_str, None)


if __name__ == "__main__":
    # One line of json is read from stdin; a dict result is emitted on stderr.
    # NOTE(review): presumably the parent process reads the result from stderr - confirm before changing.
    in_args = sys.stdin.readline()
    result = compile_with_json(in_args)
    if isinstance(result, dict):
        sys.stderr.write(json.dumps(result))
        sys.stderr.flush()