• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import os
2import subprocess
3import contextlib
4import functools
5import tempfile
6import shutil
7import operator
8
9
@contextlib.contextmanager
def pushd(dir):
    """
    Temporarily make `dir` the current working directory.

    Yields `dir` unchanged. Whatever directory was current on entry is
    restored on exit, whether the body finishes normally or raises.
    """
    saved_cwd = os.getcwd()
    os.chdir(dir)
    try:
        yield dir
    finally:
        # Always return to where we started, even on error.
        os.chdir(saved_cwd)
18
19
@contextlib.contextmanager
def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
    """
    Get a tarball, extract it, change to that directory, yield, then
    clean up.

    `url` is fetched with ``wget`` and piped into ``tar``. The tar
    compression flag is chosen by ``infer_compression(url)``.
    `target_dir` is the directory to extract into. If omitted, it is the
    basename of `url` with a trailing ``.tar.gz``, ``.tgz``, ``.tar.bz2``
    or ``.tar.xz`` suffix removed.
    `runner` is the function to invoke commands. It receives one shell
    command string per call and defaults to ``subprocess.check_call``
    with ``shell=True``.
    `pushd` is a context manager for changing the directory.

    Yields `target_dir`. The directory is removed with ``rm -Rf`` on exit,
    including when extraction or the ``with`` body raises.
    """
    # Local import keeps this block self-contained.
    import shlex

    if target_dir is None:
        target_dir = os.path.basename(url)
        # Strip only a *trailing* archive suffix. str.replace would also
        # remove these substrings from the middle of a name.
        for suffix in ('.tar.gz', '.tgz', '.tar.bz2', '.tar.xz'):
            if target_dir.endswith(suffix):
                target_dir = target_dir[: -len(suffix)]
                break
    if runner is None:
        runner = functools.partial(subprocess.check_call, shell=True)

    # Commands go through a shell, so every interpolated value is quoted.
    # This stops URLs containing '&', ';', '?', spaces, etc. from being
    # interpreted as shell syntax. '--' ends option parsing, so a name that
    # starts with '-' is not taken for a flag.
    quoted_dir = shlex.quote(target_dir)
    runner('mkdir -- {}'.format(quoted_dir))
    try:
        # --strip-components=1 drops the archive's top-level directory, and
        # -C extracts into target_dir. Together they guarantee we always
        # know where the files ended up.
        # NOTE: a POSIX pipeline's exit status is that of its last command
        # (tar), so a wget failure is only noticed if tar then fails too.
        cmd = 'wget {url} -O - | tar x{compression} --strip-components=1 -C {target}'.format(
            url=shlex.quote(url),
            compression=infer_compression(url),
            target=quoted_dir,
        )
        runner(cmd)
        with pushd(target_dir):
            yield target_dir
    finally:
        runner('rm -Rf -- {}'.format(quoted_dir))
46
47
def infer_compression(url):
    """
    Given a URL or filename, infer the compression code for tar.

    Returns ``'j'`` (bzip2) for names ending in ``bz2`` or ``bz``,
    ``'J'`` (xz) for names ending in ``xz``, and ``'z'`` (gzip) for
    anything else, including names ending in ``gz``.
    """
    # ".tar.bz2" ends in "z2", which the two-character check below would
    # miss, so recognize it explicitly first.
    if url.endswith('bz2'):
        return 'j'
    # Otherwise cheat and just look at the last two characters.
    compression_indicator = url[-2:]
    mapping = dict(gz='z', bz='j', xz='J')
    # Assume 'z' (gzip) if no match.
    return mapping.get(compression_indicator, 'z')
57
58
@contextlib.contextmanager
def temp_dir(remover=shutil.rmtree):
    """
    Yield the path of a freshly created temporary directory.

    On exit the path is handed to `remover` (``shutil.rmtree`` unless
    overridden), even if the body raised. Pass another callable to change
    or skip the cleanup.
    """
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        remover(path)
70
71
@contextlib.contextmanager
def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
    """
    Check out the repo indicated by url.

    If dest_ctx is supplied, it should be a context manager
    to yield the target directory for the check out.

    The clone is made with ``git`` if the substring ``'git'`` occurs
    anywhere in `url`, otherwise with ``hg``. If `branch` is truthy, it is
    passed as ``--branch``. If `quiet` is true, the clone's stdout is
    discarded (stderr is not). Yields the checkout directory. The
    directory's lifetime is governed by `dest_ctx`.

    Raises ``subprocess.CalledProcessError`` if the clone command fails.
    """
    # NOTE(review): substring heuristic; e.g. an hg URL containing "git"
    # somewhere in its path would be cloned with git.
    exe = 'git' if 'git' in url else 'hg'
    with dest_ctx() as repo_dir:
        cmd = [exe, 'clone', url, repo_dir]
        if branch:
            cmd.extend(['--branch', branch])
        # subprocess.DEVNULL avoids opening (and leaking) a handle on
        # os.devnull for every call.
        stdout = subprocess.DEVNULL if quiet else None
        subprocess.check_call(cmd, stdout=stdout)
        yield repo_dir
89
90
@contextlib.contextmanager
def null():
    """A do-nothing context manager; binds ``None`` for ``as`` targets."""
    yield None
94
95
class ExceptionTrap:
    """
    A context manager that will catch certain exceptions and provide an
    indication they occurred.

    `exceptions` is a class or tuple of classes, as accepted by
    ``issubclass``, and defaults to ``(Exception,)``. An exception from the
    ``with`` body that matches is suppressed and recorded. The trap then
    evaluates as true and exposes the exception via ``type``, ``value``
    and ``tb``. Non-matching exceptions propagate and leave the trap false.

    >>> with ExceptionTrap() as trap:
    ...     raise Exception()
    >>> bool(trap)
    True

    >>> with ExceptionTrap() as trap:
    ...     pass
    >>> bool(trap)
    False

    >>> with ExceptionTrap(ValueError) as trap:
    ...     raise ValueError("1 + 1 is not 3")
    >>> bool(trap)
    True

    >>> with ExceptionTrap(ValueError) as trap:
    ...     raise Exception()
    Traceback (most recent call last):
    ...
    Exception

    >>> bool(trap)
    False
    """

    # (type, value, traceback) of the trapped exception. All None until
    # __exit__ records a match on the instance.
    exc_info = None, None, None

    def __init__(self, exceptions=(Exception,)):
        self.exceptions = exceptions

    def __enter__(self):
        return self

    @property
    def type(self):
        """Class of the trapped exception, or None."""
        first, _, _ = self.exc_info
        return first

    @property
    def value(self):
        """The trapped exception instance, or None."""
        _, second, _ = self.exc_info
        return second

    @property
    def tb(self):
        """Traceback of the trapped exception, or None."""
        _, _, third = self.exc_info
        return third

    def __exit__(self, *exc_info):
        exc_type = exc_info[0]
        if not exc_type:
            # Clean exit: report the (falsy) exception type unchanged.
            return exc_type
        caught = issubclass(exc_type, self.exceptions)
        if caught:
            self.exc_info = exc_info
        # A true result tells the with-statement to swallow the exception.
        return caught

    def __bool__(self):
        return bool(self.type)

    def raises(self, func, *, _test=bool):
        """
        Wrap func and replace the result with the truth
        value of the trap (True if an exception occurred).

        Before Python 3.9 (PEP 614), a decorator had to be a dotted name,
        so first bind the method to a plain alias.

        >>> raises = ExceptionTrap(ValueError).raises

        Now decorate a function that always fails.

        >>> @raises
        ... def fail():
        ...     raise ValueError('failed')
        >>> fail()
        True
        """

        @functools.wraps(func)
        def _trapped_call(*call_args, **call_kwargs):
            # A fresh trap per call, so calls never share recorded state.
            with ExceptionTrap(self.exceptions) as per_call:
                func(*call_args, **call_kwargs)
            return _test(per_call)

        return _trapped_call

    def passes(self, func):
        """
        Wrap func and replace the result with the truth
        value of the trap (True if no exception).

        Before Python 3.9 (PEP 614), a decorator had to be a dotted name,
        so first bind the method to a plain alias.

        >>> passes = ExceptionTrap(ValueError).passes

        Now decorate a function that always fails.

        >>> @passes
        ... def fail():
        ...     raise ValueError('failed')

        >>> fail()
        False
        """
        invert = operator.not_
        return self.raises(func, _test=invert)
203
204
class suppress(contextlib.suppress, contextlib.ContextDecorator):
    """
    ``contextlib.suppress`` that can also be applied as a function decorator.

    When used as a decorator, each call of the wrapped function runs inside
    the suppressing context. A listed exception raised by the function is
    swallowed, and the call returns ``None``.

    >>> @suppress(KeyError)
    ... def key_error():
    ...     {}['']
    >>> key_error()
    """
214