# Copyright 2023-2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Utility functions used when generating operator code: dtype string conversion,
file hashing/merging/writing helpers and yaml loading.
"""
import os
import glob
import hashlib
import stat
import yaml


# Apache license header in Python comment syntax (ends with a newline).
py_licence_str = """# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""

# Apache license header in C/C++ block-comment syntax (no trailing newline).
cc_license_str = """/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */"""

# Argument type strings accepted by get_type_str. Add more supported types here.
_SUPPORTED_ARG_TYPES = frozenset({
    'int',
    'float',
    'bool',
    'number',
    'tuple[int]',
    'tuple[float]',
    'tuple[bool]',
    'tuple[tensor]',
    'tuple[str]',
    'list[int]',
    'list[float]',
    'list[bool]',
    'list[tensor]',
    'list[str]',
    'tensor',
    'type',
})


def convert_dtype_str(dtype_str):
    """
    Convert a dtype str to its expression in the ops file.

    '[' becomes '_', ']' is dropped and the result is upper-cased with a 'DT_'
    prefix, e.g. 'tuple[int]' -> 'DT_TUPLE_INT'.
    """
    return 'DT_' + dtype_str.replace('[', '_').replace(']', '').upper()


def get_type_str(type_str):
    """
    Get the unified type str for an operator arg dtype.

    Args:
        type_str (str): One of the type strings in _SUPPORTED_ARG_TYPES.

    Returns:
        str: 'OpDtype.' followed by the converted dtype, e.g. 'OpDtype.DT_INT'.

    Raises:
        TypeError: If `type_str` is not a supported type.
    """
    if type_str in _SUPPORTED_ARG_TYPES:
        return "OpDtype." + convert_dtype_str(type_str)
    raise TypeError(f"""Unsupported type {type_str} for args.""")


def get_file_md5(file_path):
    """
    Get the md5 hex digest of a file.

    Args:
        file_path (str): Path of the file to hash.

    Returns:
        str: The md5 hex digest of the file content, or "" if `file_path` does
        not exist or is a directory.
    """
    if not os.path.exists(file_path) or os.path.isdir(file_path):
        return ""
    md5 = hashlib.md5()
    with open(file_path, 'rb') as f:
        # Hash in fixed-size chunks instead of reading the whole file at once.
        for chunk in iter(lambda: f.read(65536), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_change_and_replace_file(last_file_path, tmp_file_path):
    """
    Compare tmp_file with the md5 value of the last generated file.

    If the md5 values are the same, the last generated file is retained and
    tmp_file is deleted. Otherwise, the last generated file is replaced by
    tmp_file. (Presumably this keeps unchanged outputs untouched so they are not
    needlessly rebuilt - confirm against the build setup.)

    Args:
        last_file_path (str): Path of the previously generated file (may not exist).
        tmp_file_path (str): Path of the freshly generated temporary file.

    Raises:
        FileNotFoundError: If `tmp_file_path` does not exist.
    """
    if get_file_md5(last_file_path) == get_file_md5(tmp_file_path):
        os.remove(tmp_file_path)
    else:
        # os.replace overwrites an existing destination in a single step, so the
        # last generated file is never deleted before the new one is in place.
        os.replace(tmp_file_path, last_file_path)


def _concat_files(file_paths):
    """
    Return the contents of `file_paths`, in sorted path order, each followed by
    a newline. The caller's sequence is not modified.
    """
    parts = []
    for file_path in sorted(file_paths):
        with open(file_path, 'r', encoding='utf-8') as file:
            parts.append(file.read())
        parts.append('\n')
    return ''.join(parts)


def merge_files_to_one_file(file_paths, merged_file_path):
    """
    Merge multiple files into one file.

    The files are read in sorted path order and each file's content is followed
    by a newline. `merged_file_path` is overwritten.

    Args:
        file_paths (Iterable[str]): Paths of the files to merge.
        merged_file_path (str): Path of the output file.
    """
    merged_content = _concat_files(file_paths)
    with open(merged_file_path, 'w', encoding='utf-8') as file:
        file.write(merged_content)


def merge_files(origin_dir, merged_file_path, file_format):
    """
    Merge multiple files into one file, overwriting the merged file.

    origin_dir: indicates the origin file directory.
    merged_file_path: indicates the merged file path.
    file_format: a glob pattern (e.g. '*.yaml'), not a regular expression.
    Files in 'origin_dir' whose names match the glob pattern are merged, in
    sorted path order, into one file.
    """
    op_yaml_file_names = glob.glob(os.path.join(origin_dir, file_format))
    merge_files_to_one_file(op_yaml_file_names, merged_file_path)


def merge_files_append(origin_dir, merged_file_path, file_format):
    """
    Merge multiple files and append the result to the merged file.

    origin_dir: indicates the origin file directory.
    merged_file_path: indicates the merged file path; it is created if missing,
        otherwise the merged content is appended to its end.
    file_format: a glob pattern (e.g. '*.yaml'), not a regular expression.
    Files in 'origin_dir' whose names match the glob pattern are merged, in
    sorted path order, into one file.
    """
    file_paths = glob.glob(os.path.join(origin_dir, file_format))
    merged_content = _concat_files(file_paths)
    with open(merged_file_path, 'a', encoding='utf-8') as file:
        file.write(merged_content)


def safe_load_yaml(yaml_file_path):
    """
    Load a yaml dictionary from file with yaml.safe_load.

    Args:
        yaml_file_path (str): Path of the yaml file.

    Returns:
        dict: A new dict with the top-level entries of the yaml document, or an
        empty dict if the file holds no document (empty or comments only).
    """
    with open(yaml_file_path, 'r', encoding='utf-8') as yaml_file:
        content = yaml.safe_load(yaml_file)
    # safe_load returns None for an empty document; dict.update(None) would raise.
    if content is None:
        return {}
    return dict(content)


def get_assign_str_by_type_it(class_name, arg_info, arg_name, dtype):
    """
    Make the python expression used to assign an operator argument.

    If `arg_info` has a 'type_cast' entry (a comma-separated string of source
    types), returns
    "type_it('<class_name>', '<arg_name>', <arg_name>, <src_types>, <dst_type>)",
    where <src_types> is a single OpDtype expression or a parenthesized tuple of
    them and <dst_type> is the OpDtype expression of `dtype`.
    Otherwise returns `arg_name` unchanged.

    Raises:
        TypeError: If a source type or `dtype` is not supported by get_type_str.
    """
    type_cast = arg_info.get('type_cast')
    if type_cast is None:
        return arg_name
    src_types = [get_type_str(ct.strip()) for ct in type_cast.split(",")]
    if len(src_types) == 1:
        src_types_str = src_types[0]
    else:
        src_types_str = '(' + ', '.join(src_types) + ')'
    return (f"type_it('{class_name}', '{arg_name}', {arg_name}, "
            f"{src_types_str}, {get_type_str(dtype)})")


def write_file(path, data):
    """
    Write `data` to `path`.

    The file is created with owner read/write permission (S_IRUSR | S_IWUSR,
    further masked by the process umask) if it does not exist, and truncated
    if it does.

    Args:
        path (str): Destination file path.
        data (str): Text to write.
    """
    # O_TRUNC is required: without it, rewriting an existing file with shorter
    # data leaves the tail of the old contents in place.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    mode = stat.S_IWUSR | stat.S_IRUSR
    fd = os.open(path, flags, mode)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(data)