# Copyright 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import fnmatch
import importlib
import inspect
import os
import re
import sys

from py_utils import camel_case


def DiscoverModules(start_dir, top_level_dir, pattern='*'):
  """Discover all modules in |start_dir| which match |pattern|.

  Files whose names start with '.' or '_', or that do not end in '.py', are
  skipped. Each remaining file is imported by its dotted path relative to
  |top_level_dir|.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    pattern: Unix shell-style pattern for filtering the filenames to import.

  Returns:
    list of modules, ordered by directory path and then by filename.
  """
  # start_dir and top_level_dir must be consistent with each other.
  start_dir = os.path.realpath(start_dir)
  top_level_dir = os.path.realpath(top_level_dir)

  modules = []
  sub_paths = list(os.walk(start_dir))
  # We sort the directories & file paths to ensure a deterministic ordering
  # when traversing |start_dir|.
  sub_paths.sort(key=lambda paths_tuple: paths_tuple[0])
  for dir_path, _, filenames in sub_paths:
    # Sort the files within each directory by name.
    filenames.sort()
    for filename in filenames:
      # Filter out unwanted filenames.
      if filename.startswith(('.', '_')):
        continue
      if os.path.splitext(filename)[1] != '.py':
        continue
      if not fnmatch.fnmatch(filename, pattern):
        continue

      # Find the module.
      module_rel_path = os.path.relpath(
          os.path.join(dir_path, filename), top_level_dir)
      module_name = re.sub(r'[/\\]', '.', os.path.splitext(module_rel_path)[0])

      # Import the module.
      # Snapshot sys.path so that it is restored after every import, even if
      # the import raises or the imported module itself modifies sys.path.
      original_sys_path = sys.path[:]
      try:
        # Make sure that top_level_dir is the first path in the sys.path in
        # case there are naming conflict in module parts.
        sys.path.insert(0, top_level_dir)
        module = importlib.import_module(module_name)
        modules.append(module)
      finally:
        sys.path = original_sys_path
  return modules


def AssertNoKeyConflicts(classes_by_key_1, classes_by_key_2):
  """Assert that keys shared by both dicts map to the very same object.

  Args:
    classes_by_key_1: dict of {key: class}.
    classes_by_key_2: dict of {key: class}.

  Raises:
    AssertionError: if a key present in both dicts maps to different objects.
  """
  for k in classes_by_key_1:
    if k in classes_by_key_2:
      assert classes_by_key_1[k] is classes_by_key_2[k], (
          'Found conflicting classes for the same key: '
          'key=%s, class_1=%s, class_2=%s' % (
              k, classes_by_key_1[k], classes_by_key_2[k]))


# TODO(dtu): Normalize all discoverable classes to have corresponding module
# and class names, then always index by class name.
def DiscoverClasses(start_dir,
                    top_level_dir,
                    base_class,
                    pattern='*',
                    index_by_class_name=True,
                    directly_constructable=False):
  """Discover all classes in |start_dir| which subclass |base_class|.

  Base classes that contain subclasses are ignored by default.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    base_class: The base class to search for.
    pattern: Unix shell-style pattern for filtering the filenames to import.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
    directly_constructable: If True, will only return classes that can be
        constructed without arguments

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}

  Raises:
    AssertionError: if |index_by_class_name| is True and two different classes
        map to the same key.
  """
  modules = DiscoverModules(start_dir, top_level_dir, pattern)
  classes = {}
  for module in modules:
    new_classes = DiscoverClassesInModule(
        module, base_class, index_by_class_name, directly_constructable)
    # TODO(nednguyen): we should remove index_by_class_name once
    # benchmark_smoke_unittest in chromium/src/tools/perf no longer relied
    # naming collisions to reduce the number of smoked benchmark tests.
    # crbug.com/548652
    if index_by_class_name:
      AssertNoKeyConflicts(classes, new_classes)
    # Merge in place; entries from later modules override earlier ones.
    classes.update(new_classes)
  return classes


# TODO(nednguyen): we should remove index_by_class_name once
# benchmark_smoke_unittest in chromium/src/tools/perf no longer relied
# naming collisions to reduce the number of smoked benchmark tests.
# crbug.com/548652
def DiscoverClassesInModule(module,
                            base_class,
                            index_by_class_name=False,
                            directly_constructable=False):
  """Discover all classes in |module| which subclass |base_class|.

  Base classes that contain subclasses are ignored by default. Classes whose
  names start with '_' and classes that are merely imported into |module| are
  excluded.

  Args:
    module: The module to search.
    base_class: The base class to search for.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
    directly_constructable: If True, will only return classes that can be
        constructed without arguments

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}. When
    keyed by module name there is a single key, so a later matching class
    replaces an earlier one.

  Raises:
    AssertionError: if |index_by_class_name| is True and two different classes
        map to the same key.
  """
  classes = {}
  for _, obj in inspect.getmembers(module):
    # Ensure object is a class.
    if not inspect.isclass(obj):
      continue
    # Include only subclasses of base_class.
    if not issubclass(obj, base_class):
      continue
    # Exclude the base_class itself.
    if obj is base_class:
      continue
    # Exclude protected or private classes.
    if obj.__name__.startswith('_'):
      continue
    # Include only the module in which the class is defined.
    # If a class is imported by another module, exclude those duplicates.
    if obj.__module__ != module.__name__:
      continue

    if index_by_class_name:
      key_name = camel_case.ToUnderscore(obj.__name__)
    else:
      key_name = module.__name__.split('.')[-1]
    if not directly_constructable or IsDirectlyConstructable(obj):
      if key_name in classes and index_by_class_name:
        assert classes[key_name] is obj, (
            'Duplicate key_name with different objs detected: '
            'key=%s, obj1=%s, obj2=%s' % (key_name, classes[key_name], obj))
      else:
        classes[key_name] = obj

  return classes


def IsDirectlyConstructable(cls):
  """Returns True if instance of |cls| can be constructed without arguments.

  Args:
    cls: The class to inspect.

  Returns:
    True if |cls.__init__| takes no required arguments other than |self|.
  """
  assert inspect.isclass(cls)
  if not hasattr(cls, '__init__'):
    # Case |class A: pass|.
    return True
  if cls.__init__ is object.__init__:
    # Case |class A(object): pass|.
    return True
  # Case |class (object):| with |__init__| other than |object.__init__|.
  # inspect.getargspec was removed in Python 3.11 and rejects keyword-only
  # arguments and annotations on earlier Python 3 versions, so prefer
  # getfullargspec whenever it exists.
  get_full_arg_spec = getattr(inspect, 'getfullargspec', None)
  if get_full_arg_spec is None:
    args, _, _, defaults = inspect.getargspec(cls.__init__)
  else:
    spec = get_full_arg_spec(cls.__init__)
    args, defaults = spec.args, spec.defaults
    # A keyword-only argument without a default must be supplied by the
    # caller, so the class cannot be constructed without arguments.
    kwonly_defaults = spec.kwonlydefaults or {}
    if any(name not in kwonly_defaults for name in spec.kwonlyargs):
      return False
  if defaults is None:
    defaults = ()
  # Return true if |self| is only arg without a default.
  return len(args) == len(defaults) + 1


_COUNTER = [0]


def _GetUniqueModuleName():
  """Returns a new 'module_<N>' name, incrementing the module-level counter."""
  _COUNTER[0] += 1
  return "module_" + str(_COUNTER[0])