• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#
#   Cython -- Things that don't belong
#            anywhere else in particular
#
5
6import os, sys, re, codecs
7
# Alias for os.path.getmtime (returns a path's last-modification timestamp).
from os.path import getmtime as modification_time
9
def cached_function(f):
    """Memoise *f* on its positional arguments.

    Results are kept in an unbounded dict keyed by the argument tuple, so
    the arguments must be hashable and entries are never evicted. A private
    sentinel distinguishes "not computed yet" from a cached ``None`` result.
    """
    cache = {}
    uncomputed = object()

    def wrapper(*args):
        res = cache.get(args, uncomputed)
        if res is uncomputed:
            res = cache[args] = f(*args)
        return res

    # Carry over the wrapped function's identity so tracebacks, profiles and
    # help() show the real function instead of "wrapper". This is done by
    # hand because functools.wraps needs Python 2.5, while this module still
    # supports older interpreters.
    for attr in ('__module__', '__name__', '__doc__'):
        try:
            setattr(wrapper, attr, getattr(f, attr))
        except AttributeError:
            pass
    return wrapper
19
def cached_method(f):
    """Memoise the method *f* per instance, keyed on its positional arguments.

    Each instance gets its own dict, stored as the attribute
    ``'__<method name>_cache'`` (set with setattr, so no name mangling
    applies). Entries are never evicted, and the arguments must be hashable.
    """
    cache_name = '__%s_cache' % f.__name__
    # Sentinel so that a cached None is distinguishable from "not computed".
    uncomputed = object()

    def wrapper(self, *args):
        cache = getattr(self, cache_name, None)
        if cache is None:
            cache = {}
            setattr(self, cache_name, cache)
        res = cache.get(args, uncomputed)
        if res is uncomputed:
            res = cache[args] = f(self, *args)
        return res

    # Preserve the method's identity for introspection (manual copy because
    # functools.wraps needs Python 2.5).
    for attr in ('__module__', '__name__', '__doc__'):
        try:
            setattr(wrapper, attr, getattr(f, attr))
        except AttributeError:
            pass
    return wrapper
32
def replace_suffix(path, newsuf):
    """Return *path* with its extension (as split by os.path.splitext)
    replaced by *newsuf*; *newsuf* is appended if there is no extension."""
    stem = os.path.splitext(path)[0]
    return stem + newsuf
36
def open_new_file(path):
    """Open *path* for writing as a brand-new file and return the stream.

    Any existing directory entry at *path* is removed first, so the write
    creates a fresh file instead of overwriting one that may be hard linked
    elsewhere. The stream encodes with ISO-8859-1 (see below).
    """
    # Make sure to create a new file here so we can safely hard link the
    # output files. lexists() (unlike exists()) is also true for a dangling
    # symlink; without it such a link would be written *through*, creating
    # its target instead of a new file at *path*.
    if os.path.lexists(path):
        os.unlink(path)

    # we use the ISO-8859-1 encoding here because we only write pure
    # ASCII strings or (e.g. for file names) byte encoded strings as
    # Unicode, so we need a direct mapping from the first 256 Unicode
    # characters to a byte sequence, which ISO-8859-1 provides
    return codecs.open(path, "w", encoding="ISO-8859-1")
48
def castrate_file(path, st):
    """Replace the contents of an output file after a failed compilation.

    The file at *path* is recreated (via open_new_file) containing only an
    ``#error`` line, so it cannot be compiled by accident. If *st* (a stat
    result) is given, the access time is set to ``st.st_atime`` and the
    modification time to one second before ``st.st_mtime``.

    Recreating the file is best effort: if it cannot be opened
    (EnvironmentError), nothing is changed and no error is raised.
    """
    try:
        f = open_new_file(path)
    except EnvironmentError:
        return
    # Close the handle even if the write fails.
    try:
        f.write(
            "#error Do not use this file, it is the result of a failed Cython compilation.\n")
    finally:
        f.close()
    if st:
        # NOTE(review): backdating by one second presumably makes the file
        # look out of date so it gets regenerated -- confirm with callers.
        os.utime(path, (st.st_atime, st.st_mtime - 1))
64
def file_newer_than(path, time):
    """Return True if *path* was last modified strictly after *time*
    (a timestamp comparable with modification_time())."""
    return modification_time(path) > time
68
@cached_function
def search_include_directories(dirs, qualified_name, suffix, pos,
                               include=False, sys_path=False):
    """Search the tuple of directories *dirs* for the file of *qualified_name*.

    If a source position *pos* is given, the directory of that file is
    searched first (with include=True) or its root package directory
    (otherwise). If *sys_path* is true, sys.path is searched after *dirs*.

    For each directory, ``qualified_name + suffix`` is tried literally. Unless
    *include* is set (which disables package dereferencing), the dotted name
    is then also resolved as a package path: ``pkg/sub/mod<suffix>`` and
    ``pkg/sub/mod/__init__<suffix>``, provided every package level is a
    package directory.

    Returns the first existing path, or None if not found (no error is
    reported). Raises RuntimeError if *pos* does not refer to a file source.
    """
    if sys_path:
        dirs = dirs + tuple(sys.path)
    if pos:
        file_desc = pos[0]
        from Cython.Compiler.Scanning import FileSourceDescriptor
        if not isinstance(file_desc, FileSourceDescriptor):
            raise RuntimeError("Only file sources for code supported")
        if include:
            dirs = (os.path.dirname(file_desc.filename),) + dirs
        else:
            dirs = (find_root_package_dir(file_desc.filename),) + dirs

    dotted_filename = qualified_name
    if suffix:
        dotted_filename += suffix
    if not include:
        names = qualified_name.split('.')
        package_names = tuple(names[:-1])
        module_name = names[-1]
        module_filename = module_name + suffix
        package_filename = "__init__" + suffix

    for search_dir in dirs:
        path = os.path.join(search_dir, dotted_filename)
        if path_exists(path):
            return path
        if not include:
            package_dir = check_package_dir(search_dir, package_names)
            if package_dir is not None:
                path = os.path.join(package_dir, module_filename)
                if path_exists(path):
                    return path
                # package_dir already starts with search_dir; joining
                # search_dir in front again would double a relative prefix
                # (e.g. "include/include/pkg/...").
                path = os.path.join(package_dir, module_name,
                                    package_filename)
                if path_exists(path):
                    return path
    return None
115
116
@cached_function
def find_root_package_dir(file_path):
    """Return the topmost directory above *file_path* that is not a package.

    Starting from the directory containing *file_path*, walk upwards while
    the directory is a package (is_package_dir); stop at the filesystem root.
    """
    parent = os.path.dirname(file_path)
    if parent != file_path and is_package_dir(parent):
        return find_root_package_dir(parent)
    return parent
126
@cached_function
def check_package_dir(dir, package_names):
    """Join *package_names* onto *dir* one level at a time.

    Returns the resulting directory if every intermediate level is a package
    directory, otherwise None.
    """
    current = dir
    for name in package_names:
        current = os.path.join(current, name)
        if not is_package_dir(current):
            return None
    return current
134
@cached_function
def is_package_dir(dir_path):
    """Return True if *dir_path* contains ``__init__.py``, ``__init__.pyx``
    or ``__init__.pxd`` (checked with path_exists), else False."""
    for filename in ("__init__.py",
                     "__init__.pyx",
                     "__init__.pxd"):
        path = os.path.join(dir_path, filename)
        if path_exists(path):
            return True
    # Explicit False instead of falling off the end with an implicit None.
    return False
143
@cached_function
def path_exists(path):
    """Return True if *path* exists on disk or inside the loader's archive.

    The filesystem is checked first. Otherwise, if this module was loaded by
    a PEP 302 loader exposing an ``archive`` attribute (module-level
    ``__loader__``) and *path* lies under that archive, the member is probed
    with ``loader.get_data``. Results are cached per path.
    """
    if os.path.exists(path):
        return True
    try:
        loader = __loader__
        # XXX the code below assumes a 'zipimport.zipimporter' instance
        # XXX should be easy to generalize, but too lazy right now to write it
        archive = getattr(loader, 'archive', None)
        if not archive:
            return False
        candidate = os.path.normpath(path)
        if not candidate.startswith(archive):
            return False
        member = candidate[len(archive) + 1:]
        try:
            loader.get_data(member)
        except IOError:
            return False
        else:
            return True
    except NameError:
        # No __loader__ defined: regular filesystem import.
        return False
167
# file name encodings

def decode_filename(filename):
    """Return *filename* as unicode, decoded with the file system encoding.

    Unicode input is returned unchanged. If the file system encoding is
    reported as None, the interpreter's default encoding is used. If decoding
    raises UnicodeDecodeError, the original value is returned as-is.
    """
    if isinstance(filename, unicode):
        return filename
    encoding = sys.getfilesystemencoding()
    if encoding is None:
        encoding = sys.getdefaultencoding()
    try:
        return filename.decode(encoding)
    except UnicodeDecodeError:
        return filename
181
# support for source file encoding detection

# search() for a "coding: <name>" / "coding=<name>" declaration; group(1) is
# the encoding name. Backslashes are doubled because this is not a raw
# literal ("\s" and "\w" are invalid string escapes that warn on modern
# Python 3). The u'' prefix keeps it a unicode pattern on Python 2.
_match_file_encoding = re.compile(u"coding[:=]\\s*([-\\w.]+)").search
185
def detect_file_encoding(source_filename):
    """Return the source encoding declared in *source_filename*.

    The file is opened as UTF-8 with decoding errors ignored, inspected with
    detect_opened_file_encoding(), and always closed afterwards.
    """
    stream = open_source_file(source_filename, encoding="UTF-8",
                              error_handling='ignore')
    try:
        encoding = detect_opened_file_encoding(stream)
    finally:
        stream.close()
    return encoding
192
def detect_opened_file_encoding(f):
    """Return the encoding declared in the first two lines of the stream *f*.

    Looks for a ``coding[:=]<name>`` declaration (PEPs 263 and 3120) and
    returns the name, or "UTF-8" if none is found. *f* must be a text stream
    positioned at its start and supporting ``read(n)`` and ``seek(0)``.
    The stream position afterwards is unspecified.
    """
    # Most of the time the first two lines fall in the first 250 chars,
    # and this bulk read/split is much faster.
    lines = f.read(250).split("\n")
    if len(lines) > 2:
        m = _match_file_encoding(lines[0]) or _match_file_encoding(lines[1])
        if m:
            return m.group(1)
        return "UTF-8"

    # Fallback to one-char-at-a-time detection (very long first line, or a
    # file with fewer than two newlines in its first 250 chars).
    f.seek(0)
    for i in range(2):
        # A fresh buffer per line: otherwise the second search would run on
        # line 1 and line 2 glued together without a separator and could
        # match a declaration that spans the line break.
        chars = []
        c = f.read(1)
        while c and c != u'\n':
            chars.append(c)
            c = f.read(1)
        encoding = _match_file_encoding(u''.join(chars))
        if encoding:
            return encoding.group(1)
    return "UTF-8"
217
218
def skip_bom(f):
    """
    Consume a leading U+FEFF byte order mark from the text stream *f*.

    If the first character is not a BOM, the stream is rewound to the start.
    Doing this here is *substantially* easier than teaching the scanner
    about BOMs.
    """
    first = f.read(1)
    if first == u'\uFEFF':
        return
    f.seek(0)
227
228
# normalise_newlines(repl, text): replace every CR LF pair and every lone CR
# in *text* with *repl* (normally u'\n'). A plain LF needs no rewriting, so
# it is not part of the pattern; matching it would only do no-op work.
normalise_newlines = re.compile(u'\r\n?').sub
230
231
class NormalisedNewlineStream(object):
    """Read-only stream wrapper that converts CR LF and lone CR to LF.

    The codecs module doesn't provide universal newline support, so this
    wrapper adds it. (The 'io' module in Py2.6+/3.x supports this out of
    the box.) Only read(), readlines(), seek(0) and close() are provided.
    """

    def __init__(self, stream):
        # let's assume .read() doesn't change
        self.stream = stream
        self._read = stream.read
        self.close = stream.close
        self.encoding = getattr(stream, 'encoding', 'UTF-8')

    def read(self, count=-1):
        """Read up to *count* characters (everything if negative).

        May return a few characters more than *count*: when the chunk ends
        in CR, further characters are read so that a CR LF pair split across
        the chunk boundary is translated as a single newline.
        """
        data = self._read(count)
        if u'\r' not in data:
            return data
        # Keep pulling single characters while the chunk ends in CR. One
        # look-ahead is not enough for runs like "\r\r\n": the extra char is
        # another CR, and the LF that pairs with it would otherwise start
        # the next chunk and be counted as an extra newline.
        while data.endswith(u'\r'):
            extra = self._read(1)
            if not extra:
                break
            data += extra
        return normalise_newlines(u'\n', data)

    def readlines(self):
        """Return the remaining content as a list of lines, each keeping its
        trailing LF (the last one may lack it)."""
        content = []
        data = self.read(0x1000)
        while data:
            content.append(data)
            data = self.read(0x1000)
        text = u''.join(content)

        # Split on LF only. unicode.splitlines() would also split on \f, \v,
        # \x1c-\x1e, \x85, \u2028, ..., shifting line numbers and disagreeing
        # with io's readlines().
        parts = text.split(u'\n')
        tail = parts.pop()  # text after the last LF ('' if text ends in LF)
        lines = [part + u'\n' for part in parts]
        if tail:
            lines.append(tail)
        return lines

    def seek(self, pos):
        """Rewind to the start; any other position is unsupported."""
        if pos == 0:
            self.stream.seek(0)
        else:
            raise NotImplementedError
269
270
# The 'io' module (Python 2.6+) handles decoding and universal newlines
# itself. Where it is unavailable, 'io' is bound to None and callers fall
# back to codecs plus NormalisedNewlineStream.
if sys.version_info < (2, 6):
    io = None
else:
    try:
        import io
    except ImportError:
        io = None
277
278
def open_source_file(source_filename, mode="r",
                     encoding=None, error_handling=None,
                     require_normalised_newlines=True):
    """Open a source file as a text stream positioned after any leading BOM.

    If *encoding* is None, the file is first read as UTF-8 (errors ignored)
    to detect a PEP 263 coding declaration. That probe stream is returned
    directly when the result is "UTF-8", *error_handling* is 'ignore' and
    *require_normalised_newlines* is true. Otherwise it is closed and the
    file is reopened with the detected encoding.

    Files missing from disk are looked up in the archive of this module's
    PEP 302 ``__loader__`` when the path lies under it. On disk, io.open is
    used when available. Otherwise codecs.open is used, wrapped in
    NormalisedNewlineStream if *require_normalised_newlines* is true.
    """
    if encoding is None:
        # Most of the time the coding is unspecified, so be optimistic that
        # it's UTF-8.
        f = open_source_file(source_filename, encoding="UTF-8", mode=mode, error_handling='ignore')
        encoding = detect_opened_file_encoding(f)
        if (encoding == "UTF-8"
                and error_handling == 'ignore'
                and require_normalised_newlines):
            f.seek(0)
            skip_bom(f)
            return f
        else:
            f.close()

    # Not on disk: maybe inside the zip archive this module was loaded from.
    if not os.path.exists(source_filename):
        stream = None
        try:
            loader = __loader__
            if source_filename.startswith(loader.archive):
                stream = open_source_from_loader(
                    loader, source_filename,
                    encoding, error_handling,
                    require_normalised_newlines)
        except (NameError, AttributeError):
            pass
        if stream is not None:
            # Archive members get the same BOM handling as files on disk;
            # previously a leading U+FEFF leaked through to the caller.
            skip_bom(stream)
            return stream

    if io is not None:
        stream = io.open(source_filename, mode=mode,
                         encoding=encoding, errors=error_handling)
    else:
        # codecs module doesn't have universal newline support
        stream = codecs.open(source_filename, mode=mode,
                             encoding=encoding, errors=error_handling)
        if require_normalised_newlines:
            stream = NormalisedNewlineStream(stream)
    skip_bom(stream)
    return stream
318
319
def open_source_from_loader(loader,
                            source_filename,
                            encoding=None, error_handling=None,
                            require_normalised_newlines=True):
    """Open *source_filename* from the archive behind a PEP 302 *loader*.

    *source_filename* must start with ``loader.archive``. The member name is
    the remainder of the normalised path after the archive path and its
    separator. The bytes are fetched with ``loader.get_data`` and wrapped in
    a decoding text stream. The BOM is not skipped here.
    """
    nrmpath = os.path.normpath(source_filename)
    arcname = nrmpath[len(loader.archive)+1:]
    data = loader.get_data(arcname)
    if io is not None:
        return io.TextIOWrapper(io.BytesIO(data),
                                encoding=encoding,
                                errors=error_handling)
    try:
        import cStringIO as StringIO
    except ImportError:
        import StringIO
    reader = codecs.getreader(encoding)
    # Honour error_handling as the io branch does (it was silently ignored
    # before); None means the codec default, 'strict'.
    stream = reader(StringIO.StringIO(data), error_handling or 'strict')
    if require_normalised_newlines:
        stream = NormalisedNewlineStream(stream)
    return stream
341
def str_to_number(value):
    """Convert an integer literal string to an int.

    Accepts hex ('0x1AF'), Py3 octal ('0o136'), binary ('0b101'),
    Py2 octal ('0136') and decimal literals, upper or lower case prefixes.
    """
    # note: this expects a string as input that was accepted by the
    # parser already
    prefix_bases = {'x': 16, 'X': 16, 'o': 8, 'O': 8, 'b': 2, 'B': 2}
    if len(value) > 1 and value[0] == '0':
        base = prefix_bases.get(value[1])
        if base is None:
            # Py2 octal notation ('0136')
            return int(value, 8)
        return int(value[2:], base)
    return int(value, 0)
363
def long_literal(value):
    """Return True if *value* lies outside the signed 32-bit range.

    *value* may be a number or an integer literal string (converted with
    str_to_number()).
    """
    if isinstance(value, basestring):
        value = str_to_number(value)
    lower = -2 ** 31
    upper = 2 ** 31
    return not (lower <= value < upper)
368
# all() and any() are new in Python 2.5; define fallbacks for older versions.
try:
    # Make sure to bind them on the module, as they will be accessed as
    # attributes
    all = all
    any = any
except NameError:
    def all(items):
        """Return True if no element of *items* is false (True when empty)."""
        for element in items:
            if not element:
                break
        else:
            return True
        return False

    def any(items):
        """Return True if some element of *items* is true (False when empty)."""
        for element in items:
            if element:
                break
        else:
            return False
        return True
387
@cached_function
def get_cython_cache_dir():
    """Return the directory Cython should use for its cache.

    Priority:

    1. CYTHON_CACHE_DIR
    2. (OS X): ~/Library/Caches/cython
       (posix not OS X): XDG_CACHE_HOME/cython if XDG_CACHE_HOME defined
    3. ~/.cython

    Step 2 only applies if the parent directory already exists. The result
    is computed once and cached.
    """
    try:
        return os.environ['CYTHON_CACHE_DIR']
    except KeyError:
        pass

    if os.name != 'posix':
        parent = None
    elif sys.platform == 'darwin':
        parent = os.path.expanduser('~/Library/Caches')
    else:
        # this could fallback on ~/.cache
        parent = os.environ.get('XDG_CACHE_HOME')

    if parent and os.path.isdir(parent):
        return os.path.join(parent, 'cython')

    # last fallback: ~/.cython
    return os.path.expanduser(os.path.join('~', '.cython'))
416