# Copyright (c) 2021, Alliance for Open Media. All rights reserved. # # This source code is subject to the terms of the BSD 2 Clause License and # the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License # was not distributed with this source code in the LICENSE file, you can # obtain it at www.aomedia.org/license/software. If the Alliance for Open # Media Patent License 1.0 was not distributed with this source code in the # PATENTS file, you can obtain it at www.aomedia.org/license/patent. # from __future__ import print_function import sys import os import operator from pycparser import c_parser, c_ast, parse_file from math import * from inspect import currentframe, getframeinfo from collections import deque def debug_print(frameinfo): print('******** ERROR:', frameinfo.filename, frameinfo.lineno, '********') class StructItem(): def __init__(self, typedef_name=None, struct_name=None, struct_node=None, is_union=False): self.typedef_name = typedef_name self.struct_name = struct_name self.struct_node = struct_node self.is_union = is_union self.child_decl_map = None def __str__(self): return str(self.typedef_name) + ' ' + str(self.struct_name) + ' ' + str( self.is_union) def compute_child_decl_map(self, struct_info): self.child_decl_map = {} if self.struct_node != None and self.struct_node.decls != None: for decl_node in self.struct_node.decls: if decl_node.name == None: for sub_decl_node in decl_node.type.decls: sub_decl_status = parse_decl_node(struct_info, sub_decl_node) self.child_decl_map[sub_decl_node.name] = sub_decl_status else: decl_status = parse_decl_node(struct_info, decl_node) self.child_decl_map[decl_status.name] = decl_status def get_child_decl_status(self, decl_name): if self.child_decl_map == None: debug_print(getframeinfo(currentframe())) print('child_decl_map is None') return None if decl_name not in self.child_decl_map: debug_print(getframeinfo(currentframe())) print(decl_name, 'does not exist ') return None return self.child_decl_map[decl_name] class StructInfo(): def __init__(self): self.struct_name_dic = {} self.typedef_name_dic = {} self.enum_value_dic = {} # enum value -> enum_node self.enum_name_dic = {} # enum name -> enum_node self.struct_item_list = [] def get_struct_by_typedef_name(self, typedef_name): if typedef_name in self.typedef_name_dic: return self.typedef_name_dic[typedef_name] else: return None def get_struct_by_struct_name(self, struct_name): if struct_name in self.struct_name_dic: return self.struct_name_dic[struct_name] else: debug_print(getframeinfo(currentframe())) print('Cant find', struct_name) return None def update_struct_item_list(self): # Collect all struct_items from struct_name_dic and typedef_name_dic # Compute child_decl_map for each struct item. for struct_name in self.struct_name_dic.keys(): struct_item = self.struct_name_dic[struct_name] struct_item.compute_child_decl_map(self) self.struct_item_list.append(struct_item) for typedef_name in self.typedef_name_dic.keys(): struct_item = self.typedef_name_dic[typedef_name] if struct_item.struct_name not in self.struct_name_dic: struct_item.compute_child_decl_map(self) self.struct_item_list.append(struct_item) def update_enum(self, enum_node): if enum_node.name != None: self.enum_name_dic[enum_node.name] = enum_node if enum_node.values != None: enumerator_list = enum_node.values.enumerators for enumerator in enumerator_list: self.enum_value_dic[enumerator.name] = enum_node def update(self, typedef_name=None, struct_name=None, struct_node=None, is_union=False): """T: typedef_name S: struct_name N: struct_node T S N case 1: o o o typedef struct P { int u; } K; T S N case 2: o o x typedef struct P K; T S N case 3: x o o struct P { int u; }; T S N case 4: o x o typedef struct { int u; } K; """ struct_item = None # Check whether struct_name or typedef_name is already in the dictionary if struct_name in self.struct_name_dic: struct_item = self.struct_name_dic[struct_name] if typedef_name in self.typedef_name_dic: struct_item = self.typedef_name_dic[typedef_name] if struct_item == None: struct_item = StructItem(typedef_name, struct_name, struct_node, is_union) if struct_node.decls != None: struct_item.struct_node = struct_node if struct_name != None: self.struct_name_dic[struct_name] = struct_item if typedef_name != None: self.typedef_name_dic[typedef_name] = struct_item class StructDefVisitor(c_ast.NodeVisitor): def __init__(self): self.struct_info = StructInfo() def visit_Struct(self, node): if node.decls != None: self.struct_info.update(None, node.name, node) self.generic_visit(node) def visit_Union(self, node): if node.decls != None: self.struct_info.update(None, node.name, node, True) self.generic_visit(node) def visit_Enum(self, node): self.struct_info.update_enum(node) self.generic_visit(node) def visit_Typedef(self, node): if node.type.__class__.__name__ == 'TypeDecl': typedecl = node.type if typedecl.type.__class__.__name__ == 'Struct': struct_node = typedecl.type typedef_name = node.name struct_name = struct_node.name self.struct_info.update(typedef_name, struct_name, struct_node) elif typedecl.type.__class__.__name__ == 'Union': union_node = typedecl.type typedef_name = node.name union_name = union_node.name self.struct_info.update(typedef_name, union_name, union_node, True) # TODO(angiebird): Do we need to deal with enum here? self.generic_visit(node) def build_struct_info(ast): v = StructDefVisitor() v.visit(ast) struct_info = v.struct_info struct_info.update_struct_item_list() return v.struct_info class DeclStatus(): def __init__(self, name, struct_item=None, is_ptr_decl=False): self.name = name self.struct_item = struct_item self.is_ptr_decl = is_ptr_decl def get_child_decl_status(self, decl_name): if self.struct_item != None: return self.struct_item.get_child_decl_status(decl_name) else: #TODO(angiebird): 2. Investigage the situation when a struct's definition can't be found. return None def __str__(self): return str(self.struct_item) + ' ' + str(self.name) + ' ' + str( self.is_ptr_decl) def peel_ptr_decl(decl_type_node): """ Remove PtrDecl and ArrayDecl layer """ is_ptr_decl = False peeled_decl_type_node = decl_type_node while peeled_decl_type_node.__class__.__name__ == 'PtrDecl' or peeled_decl_type_node.__class__.__name__ == 'ArrayDecl': is_ptr_decl = True peeled_decl_type_node = peeled_decl_type_node.type return is_ptr_decl, peeled_decl_type_node def parse_peeled_decl_type_node(struct_info, node): struct_item = None if node.__class__.__name__ == 'TypeDecl': if node.type.__class__.__name__ == 'IdentifierType': identifier_type_node = node.type typedef_name = identifier_type_node.names[0] struct_item = struct_info.get_struct_by_typedef_name(typedef_name) elif node.type.__class__.__name__ == 'Struct': struct_node = node.type if struct_node.name != None: struct_item = struct_info.get_struct_by_struct_name(struct_node.name) else: struct_item = StructItem(None, None, struct_node, False) struct_item.compute_child_decl_map(struct_info) elif node.type.__class__.__name__ == 'Union': # TODO(angiebird): Special treatment for Union? struct_node = node.type if struct_node.name != None: struct_item = struct_info.get_struct_by_struct_name(struct_node.name) else: struct_item = StructItem(None, None, struct_node, True) struct_item.compute_child_decl_map(struct_info) elif node.type.__class__.__name__ == 'Enum': # TODO(angiebird): Special treatment for Union? struct_node = node.type struct_item = None else: print('Unrecognized peeled_decl_type_node.type', node.type.__class__.__name__) else: # debug_print(getframeinfo(currentframe())) # print(node.__class__.__name__) #TODO(angiebird): Do we need to take care of this part? pass return struct_item def parse_decl_node(struct_info, decl_node): # struct_item is None if this decl_node is not a struct_item decl_node_name = decl_node.name decl_type_node = decl_node.type is_ptr_decl, peeled_decl_type_node = peel_ptr_decl(decl_type_node) struct_item = parse_peeled_decl_type_node(struct_info, peeled_decl_type_node) return DeclStatus(decl_node_name, struct_item, is_ptr_decl) def get_lvalue_lead(lvalue_node): """return '&' or '*' of lvalue if available""" if lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '&': return '&' elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '*': return '*' return None def parse_lvalue(lvalue_node): """get id_chain from lvalue""" id_chain = parse_lvalue_recursive(lvalue_node, []) return id_chain def parse_lvalue_recursive(lvalue_node, id_chain): """cpi->rd->u -> (cpi->rd)->u""" if lvalue_node.__class__.__name__ == 'ID': id_chain.append(lvalue_node.name) id_chain.reverse() return id_chain elif lvalue_node.__class__.__name__ == 'StructRef': id_chain.append(lvalue_node.field.name) return parse_lvalue_recursive(lvalue_node.name, id_chain) elif lvalue_node.__class__.__name__ == 'ArrayRef': return parse_lvalue_recursive(lvalue_node.name, id_chain) elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '&': return parse_lvalue_recursive(lvalue_node.expr, id_chain) elif lvalue_node.__class__.__name__ == 'UnaryOp' and lvalue_node.op == '*': return parse_lvalue_recursive(lvalue_node.expr, id_chain) else: return None class FuncDefVisitor(c_ast.NodeVisitor): func_dictionary = {} def visit_FuncDef(self, node): func_name = node.decl.name self.func_dictionary[func_name] = node def build_func_dictionary(ast): v = FuncDefVisitor() v.visit(ast) return v.func_dictionary def get_func_start_coord(func_node): return func_node.coord def find_end_node(node): node_list = [] for c in node: node_list.append(c) if len(node_list) == 0: return node else: return find_end_node(node_list[-1]) def get_func_end_coord(func_node): return find_end_node(func_node).coord def get_func_size(func_node): start_coord = get_func_start_coord(func_node) end_coord = get_func_end_coord(func_node) if start_coord.file == end_coord.file: return end_coord.line - start_coord.line + 1 else: return None def save_object(obj, filename): with open(filename, 'wb') as obj_fp: pickle.dump(obj, obj_fp, protocol=-1) def load_object(filename): obj = None with open(filename, 'rb') as obj_fp: obj = pickle.load(obj_fp) return obj def get_av1_ast(gen_ast=False): # TODO(angiebird): Generalize this path c_filename = './av1_pp.c' print('generate ast') ast = parse_file(c_filename) #save_object(ast, ast_file) print('finished generate ast') return ast def get_func_param_id_map(func_def_node): param_id_map = {} func_decl = func_def_node.decl.type param_list = func_decl.args.params for decl in param_list: param_id_map[decl.name] = decl return param_id_map class IDTreeStack(): def __init__(self, global_id_tree): self.stack = deque() self.global_id_tree = global_id_tree def add_link_node(self, node, link_id_chain): link_node = self.add_id_node(link_id_chain) node.link_node = link_node node.link_id_chain = link_id_chain def push_id_tree(self, id_tree=None): if id_tree == None: id_tree = IDStatusNode() self.stack.append(id_tree) return id_tree def pop_id_tree(self): return self.stack.pop() def add_id_seed_node(self, id_seed, decl_status): return self.stack[-1].add_child(id_seed, decl_status) def get_id_seed_node(self, id_seed): idx = len(self.stack) - 1 while idx >= 0: id_node = self.stack[idx].get_child(id_seed) if id_node != None: return id_node idx -= 1 id_node = self.global_id_tree.get_child(id_seed) if id_node != None: return id_node return None def add_id_node(self, id_chain): id_seed = id_chain[0] id_seed_node = self.get_id_seed_node(id_seed) if id_seed_node == None: return None if len(id_chain) == 1: return id_seed_node return id_seed_node.add_descendant(id_chain[1:]) def get_id_node(self, id_chain): id_seed = id_chain[0] id_seed_node = self.get_id_seed_node(id_seed) if id_seed_node == None: return None if len(id_chain) == 1: return id_seed_node return id_seed_node.get_descendant(id_chain[1:]) def top(self): return self.stack[-1] class IDStatusNode(): def __init__(self, name=None, root=None): if root is None: self.root = self else: self.root = root self.name = name self.parent = None self.children = {} self.assign = False self.last_assign_coord = None self.refer = False self.last_refer_coord = None self.decl_status = None self.link_id_chain = None self.link_node = None self.visit = False def set_link_id_chain(self, link_id_chain): self.set_assign(False) self.link_id_chain = link_id_chain self.link_node = self.root.get_descendant(link_id_chain) def set_link_node(self, link_node): self.set_assign(False) self.link_id_chain = ['*'] self.link_node = link_node def get_link_id_chain(self): return self.link_id_chain def get_concrete_node(self): if self.visit == True: # return None when there is a loop return None self.visit = True if self.link_node == None: self.visit = False return self else: concrete_node = self.link_node.get_concrete_node() self.visit = False if concrete_node == None: return self return concrete_node def set_assign(self, assign, coord=None): concrete_node = self.get_concrete_node() concrete_node.assign = assign concrete_node.last_assign_coord = coord def get_assign(self): concrete_node = self.get_concrete_node() return concrete_node.assign def set_refer(self, refer, coord=None): concrete_node = self.get_concrete_node() concrete_node.refer = refer concrete_node.last_refer_coord = coord def get_refer(self): concrete_node = self.get_concrete_node() return concrete_node.refer def set_parent(self, parent): concrete_node = self.get_concrete_node() concrete_node.parent = parent def add_child(self, name, decl_status=None): concrete_node = self.get_concrete_node() if name not in concrete_node.children: child_id_node = IDStatusNode(name, concrete_node.root) concrete_node.children[name] = child_id_node if decl_status == None: # Check if the child decl_status can be inferred from its parent's # decl_status if self.decl_status != None: decl_status = self.decl_status.get_child_decl_status(name) child_id_node.set_decl_status(decl_status) return concrete_node.children[name] def get_child(self, name): concrete_node = self.get_concrete_node() if name in concrete_node.children: return concrete_node.children[name] else: return None def add_descendant(self, id_chain): current_node = self.get_concrete_node() for name in id_chain: current_node.add_child(name) parent_node = current_node current_node = current_node.get_child(name) current_node.set_parent(parent_node) return current_node def get_descendant(self, id_chain): current_node = self.get_concrete_node() for name in id_chain: current_node = current_node.get_child(name) if current_node == None: return None return current_node def get_children(self): current_node = self.get_concrete_node() return current_node.children def set_decl_status(self, decl_status): current_node = self.get_concrete_node() current_node.decl_status = decl_status def get_decl_status(self): current_node = self.get_concrete_node() return current_node.decl_status def __str__(self): if self.link_id_chain is None: return str(self.name) + ' a: ' + str(int(self.assign)) + ' r: ' + str( int(self.refer)) else: return str(self.name) + ' -> ' + ' '.join(self.link_id_chain) def collect_assign_refer_status(self, id_chain=None, assign_ls=None, refer_ls=None): if id_chain == None: id_chain = [] if assign_ls == None: assign_ls = [] if refer_ls == None: refer_ls = [] id_chain.append(self.name) if self.assign: info_str = ' '.join([ ' '.join(id_chain[1:]), 'a:', str(int(self.assign)), 'r:', str(int(self.refer)), str(self.last_assign_coord) ]) assign_ls.append(info_str) if self.refer: info_str = ' '.join([ ' '.join(id_chain[1:]), 'a:', str(int(self.assign)), 'r:', str(int(self.refer)), str(self.last_refer_coord) ]) refer_ls.append(info_str) for c in self.children: self.children[c].collect_assign_refer_status(id_chain, assign_ls, refer_ls) id_chain.pop() return assign_ls, refer_ls def show(self): assign_ls, refer_ls = self.collect_assign_refer_status() print('---- assign ----') for item in assign_ls: print(item) print('---- refer ----') for item in refer_ls: print(item) class FuncInOutVisitor(c_ast.NodeVisitor): def __init__(self, func_def_node, struct_info, func_dictionary, keep_body_id_tree=True, call_param_map=None, global_id_tree=None, func_history=None, unknown=None): self.func_dictionary = func_dictionary self.struct_info = struct_info self.param_id_map = get_func_param_id_map(func_def_node) self.parent_node = None self.global_id_tree = global_id_tree self.body_id_tree = None self.keep_body_id_tree = keep_body_id_tree if func_history == None: self.func_history = {} else: self.func_history = func_history if unknown == None: self.unknown = [] else: self.unknown = unknown self.id_tree_stack = IDTreeStack(global_id_tree) self.id_tree_stack.push_id_tree() #TODO move this part into a function for param in self.param_id_map: decl_node = self.param_id_map[param] decl_status = parse_decl_node(self.struct_info, decl_node) descendant = self.id_tree_stack.add_id_seed_node(decl_status.name, decl_status) if call_param_map is not None and param in call_param_map: # This is a function call. # Map the input parameter to the caller's nodes # TODO(angiebird): Can we use add_link_node here? descendant.set_link_node(call_param_map[param]) def get_id_tree_stack(self): return self.id_tree_stack def generic_visit(self, node): prev_parent = self.parent_node self.parent_node = node for c in node: self.visit(c) self.parent_node = prev_parent # TODO rename def add_new_id_tree(self, node): self.id_tree_stack.push_id_tree() self.generic_visit(node) id_tree = self.id_tree_stack.pop_id_tree() if self.parent_node == None and self.keep_body_id_tree == True: # this is function body self.body_id_tree = id_tree def visit_For(self, node): self.add_new_id_tree(node) def visit_Compound(self, node): self.add_new_id_tree(node) def visit_Decl(self, node): if node.type.__class__.__name__ != 'FuncDecl': decl_status = parse_decl_node(self.struct_info, node) descendant = self.id_tree_stack.add_id_seed_node(decl_status.name, decl_status) if node.init is not None: init_id_chain = self.process_lvalue(node.init) if init_id_chain != None: if decl_status.struct_item is None: init_descendant = self.id_tree_stack.add_id_node(init_id_chain) if init_descendant != None: init_descendant.set_refer(True, node.coord) else: self.unknown.append(node) descendant.set_assign(True, node.coord) else: self.id_tree_stack.add_link_node(descendant, init_id_chain) else: self.unknown.append(node) else: descendant.set_assign(True, node.coord) self.generic_visit(node) def is_lvalue(self, node): if self.parent_node is None: # TODO(angiebird): Do every lvalue has parent_node != None? return False if self.parent_node.__class__.__name__ == 'StructRef': return False if self.parent_node.__class__.__name__ == 'ArrayRef' and node == self.parent_node.name: # if node == self.parent_node.subscript, the node could be lvalue return False if self.parent_node.__class__.__name__ == 'UnaryOp' and self.parent_node.op == '&': return False if self.parent_node.__class__.__name__ == 'UnaryOp' and self.parent_node.op == '*': return False return True def process_lvalue(self, node): id_chain = parse_lvalue(node) if id_chain == None: return id_chain elif id_chain[0] in self.struct_info.enum_value_dic: return None else: return id_chain def process_possible_lvalue(self, node): if self.is_lvalue(node): id_chain = self.process_lvalue(node) lead_char = get_lvalue_lead(node) # make sure the id is not an enum value if id_chain == None: self.unknown.append(node) return descendant = self.id_tree_stack.add_id_node(id_chain) if descendant == None: self.unknown.append(node) return decl_status = descendant.get_decl_status() if decl_status == None: descendant.set_assign(True, node.coord) descendant.set_refer(True, node.coord) self.unknown.append(node) return if self.parent_node.__class__.__name__ == 'Assignment': if node is self.parent_node.lvalue: if decl_status.struct_item != None: if len(id_chain) > 1: descendant.set_assign(True, node.coord) elif len(id_chain) == 1: if lead_char == '*': descendant.set_assign(True, node.coord) else: right_id_chain = self.process_lvalue(self.parent_node.rvalue) if right_id_chain != None: self.id_tree_stack.add_link_node(descendant, right_id_chain) else: #TODO(angiebird): 1.Find a better way to deal with this case. descendant.set_assign(True, node.coord) else: debug_print(getframeinfo(currentframe())) else: descendant.set_assign(True, node.coord) elif node is self.parent_node.rvalue: if decl_status.struct_item is None: descendant.set_refer(True, node.coord) if lead_char == '&': descendant.set_assign(True, node.coord) else: left_id_chain = self.process_lvalue(self.parent_node.lvalue) left_lead_char = get_lvalue_lead(self.parent_node.lvalue) if left_id_chain != None: if len(left_id_chain) > 1: descendant.set_refer(True, node.coord) elif len(left_id_chain) == 1: if left_lead_char == '*': descendant.set_refer(True, node.coord) else: #TODO(angiebird): Check whether the other node is linked to this node. pass else: self.unknown.append(self.parent_node.lvalue) debug_print(getframeinfo(currentframe())) else: self.unknown.append(self.parent_node.lvalue) debug_print(getframeinfo(currentframe())) else: debug_print(getframeinfo(currentframe())) elif self.parent_node.__class__.__name__ == 'UnaryOp': # TODO(angiebird): Consider +=, *=, -=, /= etc if self.parent_node.op == '--' or self.parent_node.op == '++' or\ self.parent_node.op == 'p--' or self.parent_node.op == 'p++': descendant.set_assign(True, node.coord) descendant.set_refer(True, node.coord) else: descendant.set_refer(True, node.coord) elif self.parent_node.__class__.__name__ == 'Decl': #The logic is at visit_Decl pass elif self.parent_node.__class__.__name__ == 'ExprList': #The logic is at visit_FuncCall pass else: descendant.set_refer(True, node.coord) def visit_ID(self, node): # If the parent is a FuncCall, this ID is a function name. if self.parent_node.__class__.__name__ != 'FuncCall': self.process_possible_lvalue(node) self.generic_visit(node) def visit_StructRef(self, node): self.process_possible_lvalue(node) self.generic_visit(node) def visit_ArrayRef(self, node): self.process_possible_lvalue(node) self.generic_visit(node) def visit_UnaryOp(self, node): if node.op == '&' or node.op == '*': self.process_possible_lvalue(node) self.generic_visit(node) def visit_FuncCall(self, node): if node.name.__class__.__name__ == 'ID': if node.name.name in self.func_dictionary: if node.name.name not in self.func_history: self.func_history[node.name.name] = True func_def_node = self.func_dictionary[node.name.name] call_param_map = self.process_func_call(node, func_def_node) visitor = FuncInOutVisitor(func_def_node, self.struct_info, self.func_dictionary, False, call_param_map, self.global_id_tree, self.func_history, self.unknown) visitor.visit(func_def_node.body) else: self.unknown.append(node) self.generic_visit(node) def process_func_call(self, func_call_node, func_def_node): # set up a refer/assign for func parameters # return call_param_map call_param_ls = func_call_node.args.exprs call_param_map = {} func_decl = func_def_node.decl.type decl_param_ls = func_decl.args.params for param_node, decl_node in zip(call_param_ls, decl_param_ls): id_chain = self.process_lvalue(param_node) if id_chain != None: descendant = self.id_tree_stack.add_id_node(id_chain) if descendant == None: self.unknown.append(param_node) else: decl_status = descendant.get_decl_status() if decl_status != None: if decl_status.struct_item == None: if decl_status.is_ptr_decl == True: descendant.set_assign(True, param_node.coord) descendant.set_refer(True, param_node.coord) else: descendant.set_refer(True, param_node.coord) else: call_param_map[decl_node.name] = descendant else: self.unknown.append(param_node) else: self.unknown.append(param_node) return call_param_map def build_global_id_tree(ast, struct_info): global_id_tree = IDStatusNode() for node in ast.ext: if node.__class__.__name__ == 'Decl': # id tree is for tracking assign/refer status # we don't care about function id because they can't be changed if node.type.__class__.__name__ != 'FuncDecl': decl_status = parse_decl_node(struct_info, node) descendant = global_id_tree.add_child(decl_status.name, decl_status) return global_id_tree class FuncAnalyzer(): def __init__(self): self.ast = get_av1_ast() self.struct_info = build_struct_info(self.ast) self.func_dictionary = build_func_dictionary(self.ast) self.global_id_tree = build_global_id_tree(self.ast, self.struct_info) def analyze(self, func_name): if func_name in self.func_dictionary: func_def_node = self.func_dictionary[func_name] visitor = FuncInOutVisitor(func_def_node, self.struct_info, self.func_dictionary, True, None, self.global_id_tree) visitor.visit(func_def_node.body) root = visitor.get_id_tree_stack() root.top().show() else: print(func_name, "doesn't exist") if __name__ == '__main__': fa = FuncAnalyzer() fa.analyze('tpl_get_satd_cost') pass