# Copyright 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import fnmatch
import importlib
import inspect
import os
import re
import sys

from py_utils import camel_case


def DiscoverModules(start_dir, top_level_dir, pattern='*'):
  """Discover all modules in |start_dir| which match |pattern|.

  Files whose names start with '.' or '_' (this includes __init__.py) and
  files without a '.py' extension are skipped. Modules are imported in a
  deterministic order: directories sorted by path, then files sorted by name
  within each directory.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    pattern: Unix shell-style pattern for filtering the filenames to import.

  Returns:
    list of modules.
  """
  # start_dir and top_level_dir must be consistent with each other.
  start_dir = os.path.realpath(start_dir)
  top_level_dir = os.path.realpath(top_level_dir)

  modules = []
  sub_paths = list(os.walk(start_dir))
  # We sort the directories & file paths to ensure a deterministic ordering when
  # traversing |top_level_dir|.
  sub_paths.sort(key=lambda paths_tuple: paths_tuple[0])
  for dir_path, _, filenames in sub_paths:
    # Sort the filenames within each directory for a deterministic order.
    filenames.sort()
    for filename in filenames:
      # Filter out unwanted filenames.
      if filename.startswith('.') or filename.startswith('_'):
        continue
      if os.path.splitext(filename)[1] != '.py':
        continue
      if not fnmatch.fnmatch(filename, pattern):
        continue

      # Find the module: turn the path relative to top_level_dir into a dotted
      # module name (handles both '/' and '\\' separators).
      module_rel_path = os.path.relpath(
          os.path.join(dir_path, filename), top_level_dir)
      module_name = re.sub(r'[/\\]', '.', os.path.splitext(module_rel_path)[0])

      # Import the module. Snapshot sys.path *before* entering the try block so
      # the finally clause can always restore it.
      original_sys_path = sys.path[:]
      try:
        # Make sure that top_level_dir is the first path in the sys.path in case
        # there are naming conflict in module parts.
        sys.path.insert(0, top_level_dir)
        module = importlib.import_module(module_name)
        modules.append(module)
      finally:
        sys.path = original_sys_path
  return modules


def AssertNoKeyConflicts(classes_by_key_1, classes_by_key_2):
  """Asserts that keys shared by both dicts map to the identical class object.

  Args:
    classes_by_key_1: dict of {key: class}.
    classes_by_key_2: dict of {key: class}.

  Raises:
    AssertionError: if some key maps to different objects in the two dicts.
  """
  for k in classes_by_key_1:
    if k in classes_by_key_2:
      assert classes_by_key_1[k] is classes_by_key_2[k], (
          'Found conflicting classes for the same key: '
          'key=%s, class_1=%s, class_2=%s' % (
              k, classes_by_key_1[k], classes_by_key_2[k]))


# TODO(dtu): Normalize all discoverable classes to have corresponding module
# and class names, then always index by class name.
def DiscoverClasses(start_dir,
                    top_level_dir,
                    base_class,
                    pattern='*',
                    index_by_class_name=True,
                    directly_constructable=False):
  """Discover all classes in |start_dir| which subclass |base_class|.

  Base classes that contain subclasses are ignored by default.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    base_class: The base class to search for.
    pattern: Unix shell-style pattern for filtering the filenames to import.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
    directly_constructable: If True, will only return classes that can be
        constructed without arguments

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}. When two
    modules produce the same key, the one discovered later wins (only possible
    when index_by_class_name is False; otherwise conflicts assert).
  """
  modules = DiscoverModules(start_dir, top_level_dir, pattern)
  classes = {}
  for module in modules:
    new_classes = DiscoverClassesInModule(
        module, base_class, index_by_class_name, directly_constructable)
    # TODO(crbug.com/548652): we should remove index_by_class_name once
    # benchmark_smoke_unittest in chromium/src/tools/perf no longer relied
    # naming collisions to reduce the number of smoked benchmark tests.
    if index_by_class_name:
      AssertNoKeyConflicts(classes, new_classes)
    # In-place merge; later entries override earlier ones, same as rebuilding
    # the dict from concatenated items but without the quadratic copying.
    classes.update(new_classes)
  return classes


# TODO(crbug.com/548652): we should remove index_by_class_name once
# benchmark_smoke_unittest in chromium/src/tools/perf no longer relied
# naming collisions to reduce the number of smoked benchmark tests.
def DiscoverClassesInModule(module,
                            base_class,
                            index_by_class_name=False,
                            directly_constructable=False):
  """Discover all classes in |module| which subclass |base_class|.

  Base classes that contain subclasses are ignored by default.

  Args:
    module: The module to search.
    base_class: The base class to search for.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
    directly_constructable: If True, only include classes for which
        IsDirectlyConstructable() returns True.

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}. When keyed
    by module name and the module defines several matching classes, the one
    visited last by inspect.getmembers() wins.
  """
  classes = {}
  for _, obj in inspect.getmembers(module):
    # Ensure object is a class.
    if not inspect.isclass(obj):
      continue
    # Include only subclasses of base_class.
    if not issubclass(obj, base_class):
      continue
    # Exclude the base_class itself.
    if obj is base_class:
      continue
    # Exclude protected or private classes.
    if obj.__name__.startswith('_'):
      continue
    # Include only the module in which the class is defined.
    # If a class is imported by another module, exclude those duplicates.
    if obj.__module__ != module.__name__:
      continue

    if index_by_class_name:
      key_name = camel_case.ToUnderscore(obj.__name__)
    else:
      key_name = module.__name__.split('.')[-1]
    if not directly_constructable or IsDirectlyConstructable(obj):
      if key_name in classes and index_by_class_name:
        assert classes[key_name] is obj, (
            'Duplicate key_name with different objs detected: '
            'key=%s, obj1=%s, obj2=%s' % (key_name, classes[key_name], obj))
      else:
        classes[key_name] = obj

  return classes


def IsDirectlyConstructable(cls):
  """Returns True if instance of |cls| can be construct without arguments."""
  assert inspect.isclass(cls)
  if not hasattr(cls, '__init__'):
    # Case |class A: pass|.
    return True
  if cls.__init__ is object.__init__:
    # Case |class A(object): pass|.
    return True
  # Case |class (object):| with |__init__| other than |object.__init__|.
  if hasattr(inspect, 'getfullargspec'):
    # inspect.getargspec was removed in Python 3.11 and rejected annotated or
    # keyword-only signatures before that; getfullargspec handles both.
    spec = inspect.getfullargspec(cls.__init__)
    args = spec.args
    defaults = spec.defaults
    kwonly_defaults = spec.kwonlydefaults or {}
    # A keyword-only parameter without a default must be supplied by callers.
    if any(arg not in kwonly_defaults for arg in spec.kwonlyargs):
      return False
  else:
    # Python 2 fallback.
    args, _, _, defaults = inspect.getargspec(cls.__init__)
  if defaults is None:
    defaults = ()
  # Return true if |self| is only arg without a default.
  return len(args) == len(defaults) + 1


# Single-element list so _GetUniqueModuleName can mutate it without `global`.
_COUNTER = [0]


def _GetUniqueModuleName():
  """Returns a new name of the form 'module_<N>', N increasing per call."""
  _COUNTER[0] += 1
  return "module_" + str(_COUNTER[0])