• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023-2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Generate operator utils function
17"""
18import os
19import glob
20import hashlib
21import stat
22import yaml
23
24
# Apache-2.0 license header text for Python-style (``#``) comments,
# presumably written at the top of generated Python files.
# It contains no placeholders, so it is a plain string literal rather than an
# f-string (the former ``f`` prefix was redundant).
py_licence_str = """# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
40
# Apache-2.0 license header as a C/C++ block comment, presumably written at
# the top of generated C++ files. Note there is no trailing newline after
# the closing "*/". No placeholders, so a plain (non-f) string literal.
cc_license_str = """/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */"""
56
57
def convert_dtype_str(dtype_str):
    """
    Build the ``DT_*`` identifier used for a dtype string in generated ops files.

    Every '[' becomes '_', every ']' is removed, and the result is
    upper-cased, e.g. ``'tuple[int]'`` -> ``'DT_TUPLE_INT'``.
    """
    bracket_table = str.maketrans({'[': '_', ']': None})
    normalized = dtype_str.translate(bracket_table)
    return f"DT_{normalized.upper()}"
63
64
def get_type_str(type_str):
    """
    Return the ``OpDtype`` expression string for an operator argument type.

    Args:
        type_str: type name such as ``'int'``, ``'tensor'`` or ``'tuple[int]'``.

    Returns:
        A string like ``'OpDtype.DT_TUPLE_INT'``.

    Raises:
        TypeError: if ``type_str`` is not one of the supported kinds below
            (the check is exact and case-sensitive).
    """
    # Extend this set when a new argument type becomes supported.
    supported_kinds = {
        'int', 'float', 'bool', 'number',
        'tuple[int]', 'tuple[float]', 'tuple[bool]', 'tuple[tensor]', 'tuple[str]',
        'list[int]', 'list[float]', 'list[bool]', 'list[tensor]', 'list[str]',
        'tensor', 'type',
    }
    if type_str not in supported_kinds:
        raise TypeError(f"Unsupported type {type_str} for args.")
    return f"OpDtype.{convert_dtype_str(type_str)}"
91
92
def get_file_md5(file_path):
    """
    Return the hexadecimal MD5 digest of the file at ``file_path``.

    Returns an empty string when ``file_path`` does not exist or is a
    directory, so callers can compare digests without checking first.
    """
    if not os.path.exists(file_path) or os.path.isdir(file_path):
        return ""
    digest = hashlib.md5()
    with open(file_path, 'rb') as stream:
        # Feed the hash in fixed-size chunks; the digest equals hashing the
        # whole content at once.
        for chunk in iter(lambda: stream.read(1 << 16), b''):
            digest.update(chunk)
    return digest.hexdigest()
105
106
def check_change_and_replace_file(last_file_path, tmp_file_path):
    """
    Keep the previously generated file when the new output is identical.

    The MD5 digests of ``last_file_path`` and ``tmp_file_path`` are compared.
    If they match, ``tmp_file_path`` is deleted and ``last_file_path`` is left
    untouched. Otherwise ``tmp_file_path`` is moved over ``last_file_path``,
    creating it if it did not exist.

    Raises:
        OSError: e.g. FileNotFoundError when ``tmp_file_path`` does not exist.
            In that case ``last_file_path`` is no longer deleted first.
    """
    if get_file_md5(last_file_path) == get_file_md5(tmp_file_path):
        os.remove(tmp_file_path)
        return
    # os.replace overwrites an existing destination in a single step (atomic
    # on POSIX, and it also works on Windows). The old file is therefore never
    # removed before the new one is in place, unlike remove-then-rename.
    os.replace(tmp_file_path, last_file_path)
122
123
def merge_files_to_one_file(file_paths, merged_file_path):
    """
    Merge multiple files into one file.

    The inputs are read in sorted path order. Each file's content is followed
    by a newline, and the concatenation overwrites ``merged_file_path``.

    Args:
        file_paths: iterable of input file paths. It is not modified (the
            previous implementation sorted the caller's list in place).
        merged_file_path: destination path, created or truncated.
    """
    parts = []
    # sorted() rather than list.sort(): leave the caller's sequence untouched.
    for file_path in sorted(file_paths):
        with open(file_path, 'r') as src:
            parts.append(src.read())
            parts.append('\n')
    # A single join avoids quadratic string concatenation for many inputs.
    with open(merged_file_path, 'w') as dst:
        dst.write(''.join(parts))
136
137
def merge_files(origin_dir, merged_file_path, file_format):
    """
    Merge all files in ``origin_dir`` that match a glob pattern into one file.

    Args:
        origin_dir: directory in which to look for input files.
        merged_file_path: destination path for the merged content.
        file_format: ``glob`` pattern (not a regular expression), e.g.
            ``'*.yaml'``, joined onto ``origin_dir``.

    The matched files are handed to ``merge_files_to_one_file``, which
    overwrites ``merged_file_path``.
    """
    search_pattern = os.path.join(origin_dir, file_format)
    matched_paths = glob.glob(search_pattern)
    merge_files_to_one_file(matched_paths, merged_file_path)
148
149
def merge_files_append(origin_dir, merged_file_path, file_format):
    """
    Append all files in ``origin_dir`` that match a glob pattern to one file.

    Args:
        origin_dir: directory in which to look for input files.
        merged_file_path: destination path. It is opened in append mode, so
            its existing content is kept and the merged text goes after it.
        file_format: ``glob`` pattern (not a regular expression), e.g.
            ``'*.yaml'``, joined onto ``origin_dir``.

    Matched files are read in sorted path order, and each file's content is
    followed by a newline.
    """
    matched_paths = sorted(glob.glob(os.path.join(origin_dir, file_format)))
    chunks = []
    for path in matched_paths:
        with open(path, 'r') as src:
            chunks.append(src.read() + '\n')
    with open(merged_file_path, 'a') as dst:
        dst.write(''.join(chunks))
167
168
def safe_load_yaml(yaml_file_path):
    """
    Load a YAML file into a new dictionary using ``yaml.safe_load``.

    Args:
        yaml_file_path: path of the YAML file to read.

    Returns:
        A ``dict`` filled from the top-level mapping of the document, or
        ``{}`` when the document is empty.
    """
    data = dict()
    with open(yaml_file_path, 'r') as yaml_file:
        loaded = yaml.safe_load(yaml_file)
    # safe_load returns None for an empty document. dict.update(None) would
    # raise an opaque "'NoneType' object is not iterable" error, so an empty
    # file is treated as an empty mapping instead.
    if loaded is not None:
        data.update(loaded)
    return data
177
178
def get_assign_str_by_type_it(class_name, arg_info, arg_name, dtype):
    """
    Build the Python expression used to assign an operator argument.

    When ``arg_info`` has a non-None ``'type_cast'`` entry (a comma-separated
    list of source type names), the result is
    ``type_it('<class_name>', '<arg_name>', <arg_name>, <src>, <dst>)``.
    Here ``<src>`` is one ``OpDtype`` expression, or a parenthesised tuple of
    them when several source types are listed, and ``<dst>`` is the
    ``OpDtype`` expression for ``dtype``. Otherwise ``arg_name`` is returned
    unchanged.
    """
    type_cast = arg_info.get('type_cast')
    if type_cast is None:
        return arg_name

    src_kinds = [kind.strip() for kind in type_cast.split(",")]
    if len(src_kinds) > 1:
        src_expr = '(' + ', '.join(get_type_str(kind) for kind in src_kinds) + ')'
    else:
        src_expr = get_type_str(src_kinds[0])
    dst_expr = get_type_str(dtype)
    return f"type_it('{class_name}', '{arg_name}', {arg_name}, {src_expr}, {dst_expr})"
196
197
def write_file(path, data):
    """
    Write ``data`` to ``path``, replacing any previous content.

    A newly created file is requested with owner read/write permission only
    (further restricted by the process umask). An existing file keeps its
    current permission bits.

    :param path: destination file path.
    :param data: text to write.
    :return: None
    """
    # O_TRUNC is required. Without it, writing shorter data over an existing,
    # longer file leaves stale trailing bytes from the old content.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    mode = stat.S_IWUSR | stat.S_IRUSR
    fd = os.open(path, flags, mode)
    with os.fdopen(fd, "w") as f:
        f.write(data)
210