• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# encoding=utf-8
3# ============================================================================
4
5# Copyright (c) 2020 HiSilicon (Shanghai) Technologies CO., LIMITED.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17# Description: API format check \n
18#
19# ============================================================================
20
21import re
22import os
23import sys
24import logging
25
# Names looked up inside a header's path to decide which module it belongs to.
MODULE_LIST = ["driver", "bts", "wifi", "utils"]

# Substrings that flag a declaration as an error when they occur in it.
ERROR_LIST = ["td_void", "td_u8", "td_u32", " td_", "(td_"]

# Send log records to a result file; NOTSET on the root logger means no
# record is filtered out by level.
logging.basicConfig(
    level=logging.NOTSET,
    format='[%(asctime)s] [%(levelname)s] - %(message)s',
    filename="check_api_result.txt",
)
32
33
def file_path_list(folder_path):
    """Return the path of every file found under *folder_path*.

    The directory tree is walked recursively with ``os.walk``; each entry of
    the returned list is the walked directory joined with the file name.
    """
    logging.debug("Start to read file name list...")
    # Collect every file name under the directory, including its path.
    return [
        os.path.join(directory, name)
        for directory, _, names in os.walk(folder_path)
        for name in names
    ]
41
42
# Matches a "/** ... * @brief <text> ... */" comment followed by the text up
# to the next ';'. Group 1 is the rest of the @brief line, group 2 the
# declaration text.
_API_DOC_PATTERN = re.compile(
    r"/\*\*.+? \* @brief(.*?)\n.*?\*/.*?\n(.*?);", re.S | re.M)


def _read_header_text(file_path):
    """Return the text of *file_path*, or "" when the file cannot be read."""
    try:
        with open(file_path, 'r') as header:
            return header.read()
    except OSError:
        logging.error("OSError file name: %s", file_path)
        return ""
    except UnicodeDecodeError:
        logging.warning("UnicodeDecodeError file name: %s", file_path)

    # Retry with undecodable bytes dropped. Re-reading with the same strict
    # decoder would only raise the same error again and lose the whole file.
    try:
        with open(file_path, 'r', errors='ignore') as header:
            return header.read()
    except OSError:
        logging.error("OSError file name: %s", file_path)
        return ""


def api_read(file_path, api_dic):
    """Collect documented declarations from a ``.h`` file into *api_dic*.

    Files whose extension is not ``.h`` are skipped. For every match of the
    ``@brief`` comment pattern, the text following the comment (up to the
    next ``;``) is recorded when it contains ``(`` and contains neither ``{``
    nor ``#define``.

    Args:
        file_path: Path of the file to scan.
        api_dic: Dict updated in place; maps declaration text to the path of
            the file it was found in.

    Returns:
        The same *api_dic* object. It is unchanged when the file is not a
        header or cannot be read.
    """
    if os.path.splitext(file_path)[-1] != ".h":
        return api_dic

    txt = _read_header_text(file_path)
    for _brief, declaration in _API_DOC_PATTERN.findall(txt):
        looks_like_prototype = (
            "(" in declaration
            and "{" not in declaration
            and "#define" not in declaration
        )
        if looks_like_prototype:
            api_dic[declaration] = file_path
    return api_dic
71
72
def function_handle(name):
    """Split a declaration at its first space into two parts.

    Args:
        name: Declaration text, e.g. ``"int uapi_foo(int x)"``.

    Returns:
        A ``(class_name, function_name)`` tuple: the text before the first
        space and the text after it. When *name* contains no space, the
        result is ``("", name)`` instead of raising ``IndexError``.
    """
    class_name, separator, function_name = name.partition(" ")
    if not separator:
        # No leading token to split off: treat the whole text as the
        # function part so later name checks still see it.
        return "", name
    return class_name, function_name
77
78
def count_duplicate_module(count_dict, module, sub_module):
    """Add one occurrence of *sub_module* within *module* to *count_dict*.

    Three counters are incremented: the overall ``"total"``, the module's own
    ``"total"``, and the module's counter for *sub_module*. Missing entries
    are created on first use.

    Returns:
        The same *count_dict* object, updated in place.
    """
    count_dict["total"] = count_dict.get("total", 0) + 1

    per_module = count_dict.get(module)
    if per_module is None:
        # First sighting of this module: sub-module entry first, then total.
        count_dict[module] = {sub_module: 1, "total": 1}
        return count_dict

    per_module["total"] += 1
    per_module[sub_module] = per_module.get(sub_module, 0) + 1
    return count_dict
96
97
# Prefixes used to lay out the printed report.
g_tabs_indent_bar = 10 * '-'
g_tabs_indent_space = 10 * ' '


def report_module_api(count_dict, detail):
    """Print the API counts held in *count_dict*.

    Args:
        count_dict: Mapping with an overall ``"total"`` entry plus one dict
            per module; each module dict has its own ``"total"`` entry and
            one counter per sub-module. An empty mapping is reported as a
            total of 0.
        detail: When true, also print each sub-module's count and verify
            that they add up to the module's total.

    Raises:
        ImportError: If a module's sub-module counts do not add up to its
            total (exception type kept for existing callers).
    """
    print("CFBB total APIs:", str(count_dict.get("total", 0)))
    for module, per_module in count_dict.items():
        # Exact comparison: a substring test would also skip names such as
        # "t" or "" that merely occur inside the word "total".
        if module == "total":
            continue
        print(g_tabs_indent_bar, module, " total APIs:", per_module["total"])
        if not detail:
            continue
        module_sum = 0
        for sub_module, count in per_module.items():
            if sub_module == "total":
                continue
            module_sum += count
            print(g_tabs_indent_space, g_tabs_indent_bar, sub_module, " APIs:\t", count)

        if per_module["total"] != module_sum:
            raise ImportError("module_sum APIs is not correct")
119
120
def save_to_excel(dic, save2file):
    """Count the APIs in *dic* per module, print a report, optionally save.

    Args:
        dic: Mapping of declaration text to the path of the header it came
            from.
        save2file: When true, the rows are also written to ``api_list.xlsx``
            (sheet ``"log"``); this requires the ``openpyxl`` package.

    Raises:
        ImportError: If *save2file* is true and ``openpyxl`` is missing.

    For declarations attributed to the ``"driver"`` module whose first token
    is not ``typedef``, an error is logged when the remaining text does not
    start with ``uapi_``.
    """
    workbook = None
    sheet = None
    if save2file:
        try:
            import openpyxl
        except ImportError as err:
            # Keep the original cause attached to the friendlier message.
            raise ImportError("please install openpyxl") from err
        workbook = openpyxl.Workbook()
        sheet = workbook.create_sheet("log", 0)
        sheet.append(["API name", "module name", "file path"])

    count_dict = {}
    for declaration, path in dic.items():
        class_name, func_name = function_handle(declaration)
        # First MODULE_LIST entry occurring anywhere in the path wins.
        module_name = next((m for m in MODULE_LIST if m in path), "unknown")
        sub_module = os.path.basename(path)
        count_dict = count_duplicate_module(count_dict, module_name, sub_module)
        if (class_name != "typedef" and module_name == "driver"
                and not func_name.startswith("uapi_")):
            logging.error("API format is incorrect(don't start with uapi_): %s", func_name)
        if sheet is not None:
            sheet.append(("%s %s" % (class_name, func_name), module_name, path))

    report_module_api(count_dict, True)
    if workbook is not None:
        workbook.save("api_list.xlsx")
        # Only claim success when a file was actually written.
        logging.info("Saved successfully.")
153
154
def save_to_txt(dic):
    """Write *dic* to ``error_api.txt`` as one ``<value>:<key>`` line per entry.

    The file is created in the current working directory and overwritten if
    it already exists.
    """
    with open("error_api.txt", "w+") as report:
        report.writelines(
            path + ":" + declaration + "\n"
            for declaration, path in dic.items()
        )
159
160
def print_error_type(dic):
    """Return and print the entries of *dic* whose key hits ``ERROR_LIST``.

    An entry is selected when its key contains at least one of the
    ``ERROR_LIST`` substrings. Each selected entry is printed as
    ``error type: <key>:<value>``.

    Returns:
        A new dict holding only the selected entries, in their original order.
    """
    flagged = {
        declaration: path
        for declaration, path in dic.items()
        if any(marker in declaration for marker in ERROR_LIST)
    }
    for declaration, path in flagged.items():
        print("error type: " + declaration + ":" + path)
    return flagged
170
171
def main():
    """Scan a directory tree for header APIs and report flagged declarations.

    The directory is taken from the single command-line argument when exactly
    one is given, otherwise the current working directory is used. Flagged
    declarations are printed and written out via ``save_to_txt``.

    Returns:
        -1 when the path is not a directory or any declaration was flagged,
        otherwise 0.
    """
    curr_path = sys.argv[1] if len(sys.argv) == 2 else os.getcwd()
    print("curr_path:", curr_path)

    if not os.path.isdir(curr_path):
        logging.error("error path!")
        return -1

    result = {}
    for candidate in file_path_list(curr_path):
        result = api_read(candidate, result)
    logging.info("API count: %s", len(result))

    error_dic = print_error_type(result)
    print(error_dic)
    save_to_txt(error_dic)
    return -1 if error_dic else 0
194
195
if __name__ == '__main__':
    # Report the outcome on stdout: -1 for failure, 0 otherwise.
    ret = main()
    print(-1 if ret == -1 else 0)
203