• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2012 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import fnmatch
6import importlib
7import inspect
8import os
9import re
10import sys
11
12from py_utils import camel_case
13
14
def DiscoverModules(start_dir, top_level_dir, pattern='*'):
  """Imports and returns every module under |start_dir| matching |pattern|.

  The tree rooted at |start_dir| is walked recursively. A file is skipped when
  its name starts with '.' or '_', when its extension is not '.py', or when it
  does not match |pattern|. Every remaining file is imported under the dotted
  name of its path relative to |top_level_dir|.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    pattern: Unix shell-style pattern for filtering the filenames to import.

  Returns:
    list of modules, ordered by directory path and then by filename.
  """
  # Resolve both paths the same way so the relative paths below line up.
  start_dir = os.path.realpath(start_dir)
  top_level_dir = os.path.realpath(top_level_dir)

  # Visit directories, and the files inside each one, in sorted order so the
  # result does not depend on the filesystem's listing order.
  walked = sorted(os.walk(start_dir), key=lambda entry: entry[0])

  discovered = []
  for directory, _, names in walked:
    for name in sorted(names):
      if name.startswith(('.', '_')):
        continue
      if os.path.splitext(name)[1] != '.py' or not fnmatch.fnmatch(name, pattern):
        continue

      relative = os.path.relpath(os.path.join(directory, name), top_level_dir)
      dotted_name = re.sub(r'[/\\]', '.', os.path.splitext(relative)[0])

      saved_path = sys.path[:]
      try:
        # Put top_level_dir first so it wins over any other sys.path entry
        # that provides a package with a clashing name.
        sys.path.insert(0, top_level_dir)
        discovered.append(importlib.import_module(dotted_name))
      finally:
        # Undo the insertion (and anything the import itself added).
        sys.path = saved_path
  return discovered
63
64
def AssertNoKeyConflicts(classes_by_key_1, classes_by_key_2):
  """Asserts that every key present in both dicts maps to the same object.

  Args:
    classes_by_key_1: dict of {key: class}.
    classes_by_key_2: dict of {key: class}.

  Raises:
    AssertionError: for the first shared key (in |classes_by_key_1| iteration
        order) whose two values are not identical (checked with |is|).
  """
  shared_keys = (key for key in classes_by_key_1 if key in classes_by_key_2)
  for key in shared_keys:
    first, second = classes_by_key_1[key], classes_by_key_2[key]
    assert first is second, (
        'Found conflicting classes for the same key: '
        'key=%s, class_1=%s, class_2=%s' % (key, first, second))
72
73
# TODO(dtu): Normalize all discoverable classes to have corresponding module
# and class names, then always index by class name.
def DiscoverClasses(start_dir,
                    top_level_dir,
                    base_class,
                    pattern='*',
                    index_by_class_name=True,
                    directly_constructable=False):
  """Discover all classes in |start_dir| which subclass |base_class|.

  Every module returned by DiscoverModules() is searched with
  DiscoverClassesInModule(), so the same filters apply: |base_class| itself,
  classes whose names start with '_', and classes that were merely imported
  into a module are excluded.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    base_class: The base class to search for.
    pattern: Unix shell-style pattern for filtering the filenames to import.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
    directly_constructable: If True, will only return classes that can be
        constructed without arguments.

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}. Entries
    from later modules overwrite earlier entries with the same key. When
    |index_by_class_name| is True, such collisions are first asserted to refer
    to the same class.

  Raises:
    AssertionError: if |index_by_class_name| is True and two modules map the
        same key to different classes (only when assertions are enabled).
  """
  modules = DiscoverModules(start_dir, top_level_dir, pattern)
  classes = {}
  for module in modules:
    new_classes = DiscoverClassesInModule(
        module, base_class, index_by_class_name, directly_constructable)
    # TODO(nednguyen): we should remove index_by_class_name once
    # benchmark_smoke_unittest in chromium/src/tools/perf no longer relied
    # naming collisions to reduce the number of smoked benchmark tests.
    # crbug.com/548652
    if index_by_class_name:
      AssertNoKeyConflicts(classes, new_classes)
    # Merge in place. Rebuilding a fresh dict for every module made the total
    # cost quadratic in the number of discovered classes. update() keeps the
    # same key order and the same "later module wins" semantics.
    classes.update(new_classes)
  return classes
112
113
# TODO(nednguyen): we should remove index_by_class_name once
# benchmark_smoke_unittest in chromium/src/tools/perf no longer relied
# naming collisions to reduce the number of smoked benchmark tests.
# crbug.com/548652
def DiscoverClassesInModule(module,
                            base_class,
                            index_by_class_name=False,
                            directly_constructable=False):
  """Collects the subclasses of |base_class| defined in |module|.

  A member of |module| is kept only if all of the following hold:
  - it is a class and a subclass of |base_class|, other than |base_class|
    itself;
  - its name does not start with '_';
  - its __module__ equals module.__name__, so classes imported from
    elsewhere are skipped.

  Args:
    module: The module to search.
    base_class: The base class to search for.
    index_by_class_name: If True, key each class by its name converted to
        lowercase_with_underscores. Otherwise key it by the last component of
        the module name; in that case a later match replaces an earlier one
        (inspect.getmembers order, i.e. sorted by attribute name).
    directly_constructable: If True, keep only classes for which
        IsDirectlyConstructable() returns True.

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}

  Raises:
    AssertionError: if |index_by_class_name| is True and two different
        classes map to the same key (only when assertions are enabled).
  """
  found = {}
  for _, member in inspect.getmembers(module):
    if not (inspect.isclass(member)
            and issubclass(member, base_class)
            and member is not base_class):
      continue
    # Skip non-public classes and classes merely imported into |module|.
    if member.__name__.startswith('_') or member.__module__ != module.__name__:
      continue

    if index_by_class_name:
      key = camel_case.ToUnderscore(member.__name__)
    else:
      key = module.__name__.rpartition('.')[2]

    if directly_constructable and not IsDirectlyConstructable(member):
      continue

    if index_by_class_name and key in found:
      # Aliases of one class share its __name__ and hence its key; only
      # distinct classes with colliding keys are an error. The first one wins.
      assert found[key] is member, (
          'Duplicate key_name with different objs detected: '
          'key=%s, obj1=%s, obj2=%s' % (key, found[key], member))
      continue
    found[key] = member
  return found
167
168
def IsDirectlyConstructable(cls):
  """Returns True if an instance of |cls| can be constructed without arguments.

  Args:
    cls: The class to inspect.

  Returns:
    True if |cls| has no custom __init__, or if every parameter of its
    __init__ other than |self| is variadic or has a default value (including
    keyword-only parameters).

  Raises:
    AssertionError: if |cls| is not a class.
  """
  assert inspect.isclass(cls)
  if not hasattr(cls, '__init__'):
    # Case |class A: pass| (Python 2 old-style class).
    return True
  if cls.__init__ is object.__init__:
    # Case |class A(object): pass|.
    return True
  # Case |class (object):| with |__init__| other than |object.__init__|.
  # inspect.getargspec was removed in Python 3.11. Before that it raised
  # ValueError for __init__ methods with annotations or keyword-only
  # parameters. Prefer getfullargspec and only fall back on Python 2.
  get_full_spec = getattr(inspect, 'getfullargspec', None)
  if get_full_spec is None:
    args, _, _, defaults = inspect.getargspec(cls.__init__)
    kwonly_args, kwonly_defaults = [], {}
  else:
    spec = get_full_spec(cls.__init__)
    args, defaults = spec.args, spec.defaults
    kwonly_args = spec.kwonlyargs
    kwonly_defaults = spec.kwonlydefaults or {}
  if defaults is None:
    defaults = ()
  # A keyword-only parameter without a default must be supplied by the caller.
  if any(name not in kwonly_defaults for name in kwonly_args):
    return False
  # Return true if |self| is the only positional arg without a default.
  return len(args) == len(defaults) + 1
184
185
# Counter for _GetUniqueModuleName(). Held in a one-element list so it can be
# mutated in place without a |global| statement.
_COUNTER = [0]


def _GetUniqueModuleName():
  """Returns 'module_<N>', where N is one more than on the previous call."""
  _COUNTER[0] = _COUNTER[0] + 1
  return 'module_%s' % _COUNTER[0]
192