• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The RE2 Authors.  All Rights Reserved.
2# Use of this source code is governed by a BSD-style
3# license that can be found in the LICENSE file.
4r"""A drop-in replacement for the re module.
5
6It uses RE2 under the hood, of course, so various PCRE features
7(e.g. backreferences, look-around assertions) are not supported.
8See https://github.com/google/re2/wiki/Syntax for the canonical
9reference, but known syntactic "gotchas" relative to Python are:
10
11  * PCRE supports \Z and \z; RE2 supports \z; Python supports \z,
12    but calls it \Z. You must rewrite \Z to \z in pattern strings.
13
14Known differences between this module's API and the re module's API:
15
16  * The error class does not provide any error information as attributes.
17  * The Options class replaces the re module's flags with RE2's options as
18    gettable/settable properties. Please see re2.h for their documentation.
19  * The pattern string and the input string do not have to be the same type.
20    Any str will be encoded to UTF-8.
21  * The pattern string cannot be str if the options specify Latin-1 encoding.
22
23This module's LRU cache contains a maximum of 128 regular expression objects.
24Each regular expression object's underlying RE2 object uses a maximum of 8MiB
25of memory (by default). Hence, this module's LRU cache uses a maximum of 1GiB
26of memory (by default), but in most cases, it should use much less than that.
27"""
28
29import codecs
30import functools
31import itertools
32
33import _re2
34
35
class error(Exception):
  """Raised by this module when a pattern or an operation is rejected.

  Only the message is carried; no further error details are exposed as
  attributes.
  """
38
39
class Options(_re2.RE2.Options):
  """RE2's options, exposed as gettable/settable properties.

  Takes the place of the re module's flags. See re2.h for the documentation
  of each individual option.
  """

  # This subclass declares no instance attributes of its own.
  __slots__ = ()

  # Option names in a fixed order. compile() flattens an Options object into a
  # tuple of values in this order and _Regexp._make() rebuilds an Options
  # object from such a tuple, so both sides rely on the order staying put.
  NAMES = (
      'max_mem',
      'encoding',
      'posix_syntax',
      'longest_match',
      'log_errors',
      'literal',
      'never_nl',
      'dot_nl',
      'never_capture',
      'case_sensitive',
      'perl_classes',
      'word_boundary',
      'one_line',
  )
59
60
def compile(pattern, options=None):
  """Compiles pattern into a regular expression object.

  Args:
    pattern: A str or bytes pattern, or an already compiled regular expression
      object.
    options: An optional Options object; Options() is used when omitted. Must
      not be specified when pattern is already compiled.

  Returns:
    A regular expression object. An already compiled pattern is returned
    unchanged, so the options it was compiled with stay in effect.

  Raises:
    error: If options are specified for an already compiled pattern, or if
      the pattern cannot be compiled.
  """
  if isinstance(pattern, _Regexp):
    if options:
      raise error('pattern is already compiled, so '
                  'options may not be specified')
    # Recompiling from the pattern string would silently replace the options
    # that the object was built with by the defaults, so return it as is.
    return pattern
  options = options or Options()
  # _Regexp._make() is LRU-cached; the option values are passed as a plain
  # tuple, ordered as in Options.NAMES, and form part of the cache key.
  values = tuple(getattr(options, name) for name in Options.NAMES)
  return _Regexp._make(pattern, values)
70
71
def search(pattern, text, options=None):
  """Returns the first match of pattern anywhere in text, or None."""
  regexp = compile(pattern, options=options)
  return regexp.search(text)
74
75
def match(pattern, text, options=None):
  """Returns a match of pattern anchored at the start of text, or None."""
  regexp = compile(pattern, options=options)
  return regexp.match(text)
78
79
def fullmatch(pattern, text, options=None):
  """Returns a match of pattern anchored at both ends of text, or None."""
  regexp = compile(pattern, options=options)
  return regexp.fullmatch(text)
82
83
def finditer(pattern, text, options=None):
  """Returns an iterator over the successive matches of pattern in text."""
  regexp = compile(pattern, options=options)
  return regexp.finditer(text)
86
87
def findall(pattern, text, options=None):
  """Returns a list with one item per match of pattern in text.

  Each item is the whole match when the pattern has no groups, the single
  group when it has one, and a tuple of all groups otherwise.
  """
  regexp = compile(pattern, options=options)
  return regexp.findall(text)
90
91
def split(pattern, text, maxsplit=0, options=None):
  """Splits text around the matches of pattern.

  At most maxsplit splits are made when maxsplit is positive; zero means no
  limit and a negative value means no splitting at all.
  """
  regexp = compile(pattern, options=options)
  return regexp.split(text, maxsplit)
94
95
def subn(pattern, repl, text, count=0, options=None):
  """Replaces matches of pattern in text and counts the replacements.

  repl is either a template or a callable taking a match. Returns a pair of
  the resulting text and the number of replacements made.
  """
  regexp = compile(pattern, options=options)
  return regexp.subn(repl, text, count)
98
99
def sub(pattern, repl, text, count=0, options=None):
  """Returns text with the matches of pattern replaced according to repl.

  repl is either a template or a callable taking a match.
  """
  regexp = compile(pattern, options=options)
  return regexp.sub(repl, text, count)
102
103
104def _encode(t):
105  return t.encode(encoding='utf-8')
106
107
108def _decode(b):
109  return b.decode(encoding='utf-8')
110
111
def escape(pattern):
  """Returns pattern as quoted by RE2's QuoteMeta.

  A str pattern is encoded to UTF-8 for QuoteMeta and the result is decoded
  back to str; any other pattern is handed to QuoteMeta unchanged.
  """
  if not isinstance(pattern, str):
    return _re2.RE2.QuoteMeta(pattern)
  quoted = _re2.RE2.QuoteMeta(_encode(pattern))
  return _decode(quoted)
121
122
def purge():
  """Empties the LRU cache of compiled regular expression objects."""
  compiled_cache = _Regexp._make
  return compiled_cache.cache_clear()
125
126
127_Anchor = _re2.RE2.Anchor
128_NULL_SPAN = (-1, -1)
129
130
131class _Regexp(object):
132
133  __slots__ = ('_pattern', '_regexp')
134
135  @classmethod
136  @functools.lru_cache(typed=True)
137  def _make(cls, pattern, values):
138    options = Options()
139    for name, value in zip(Options.NAMES, values):
140      setattr(options, name, value)
141    return cls(pattern, options)
142
143  def __init__(self, pattern, options):
144    self._pattern = pattern
145    if isinstance(self._pattern, str):
146      if options.encoding == Options.Encoding.LATIN1:
147        raise error('string type of pattern is str, but '
148                    'encoding specified in options is LATIN1')
149      encoded_pattern = _encode(self._pattern)
150      self._regexp = _re2.RE2(encoded_pattern, options)
151    else:
152      self._regexp = _re2.RE2(self._pattern, options)
153    if not self._regexp.ok():
154      raise error(self._regexp.error())
155
156  def __getstate__(self):
157    options = {name: getattr(self.options, name) for name in Options.NAMES}
158    return self._pattern, options
159
160  def __setstate__(self, state):
161    pattern, options = state
162    values = tuple(options[name] for name in Options.NAMES)
163    other = _Regexp._make(pattern, values)
164    self._pattern = other._pattern
165    self._regexp = other._regexp
166
167  def _match(self, anchor, text, pos=None, endpos=None):
168    pos = 0 if pos is None else max(0, min(pos, len(text)))
169    endpos = len(text) if endpos is None else max(0, min(endpos, len(text)))
170    if pos > endpos:
171      return
172    if isinstance(text, str):
173      encoded_text = _encode(text)
174      encoded_pos = _re2.CharLenToBytes(encoded_text, 0, pos)
175      if endpos == len(text):
176        # This is the common case.
177        encoded_endpos = len(encoded_text)
178      else:
179        encoded_endpos = encoded_pos + _re2.CharLenToBytes(
180            encoded_text, encoded_pos, endpos - pos)
181      decoded_offsets = {0: 0}
182      last_offset = 0
183      while True:
184        spans = self._regexp.Match(anchor, encoded_text, encoded_pos,
185                                   encoded_endpos)
186        if spans[0] == _NULL_SPAN:
187          break
188
189        # This algorithm is linear in the length of encoded_text. Specifically,
190        # no matter how many groups there are for a given regular expression or
191        # how many iterations through the loop there are for a given generator,
192        # this algorithm uses a single, straightforward pass over encoded_text.
193        offsets = sorted(set(itertools.chain(*spans)))
194        if offsets[0] == -1:
195          offsets = offsets[1:]
196        # Discard the rest of the items because they are useless now - and we
197        # could accumulate one item per str offset in the pathological case!
198        decoded_offsets = {last_offset: decoded_offsets[last_offset]}
199        for offset in offsets:
200          decoded_offsets[offset] = (
201              decoded_offsets[last_offset] +
202              _re2.BytesToCharLen(encoded_text, last_offset, offset))
203          last_offset = offset
204
205        def decode(span):
206          if span == _NULL_SPAN:
207            return span
208          return decoded_offsets[span[0]], decoded_offsets[span[1]]
209
210        decoded_spans = [decode(span) for span in spans]
211        yield _Match(self, text, pos, endpos, decoded_spans)
212        if encoded_pos == encoded_endpos:
213          break
214        elif encoded_pos == spans[0][1]:
215          # We matched the empty string at encoded_pos and would be stuck, so
216          # in order to make forward progress, increment the str offset.
217          encoded_pos += _re2.CharLenToBytes(encoded_text, encoded_pos, 1)
218        else:
219          encoded_pos = spans[0][1]
220    else:
221      while True:
222        spans = self._regexp.Match(anchor, text, pos, endpos)
223        if spans[0] == _NULL_SPAN:
224          break
225        yield _Match(self, text, pos, endpos, spans)
226        if pos == endpos:
227          break
228        elif pos == spans[0][1]:
229          # We matched the empty string at pos and would be stuck, so in order
230          # to make forward progress, increment the bytes offset.
231          pos += 1
232        else:
233          pos = spans[0][1]
234
235  def search(self, text, pos=None, endpos=None):
236    return next(self._match(_Anchor.UNANCHORED, text, pos, endpos), None)
237
238  def match(self, text, pos=None, endpos=None):
239    return next(self._match(_Anchor.ANCHOR_START, text, pos, endpos), None)
240
241  def fullmatch(self, text, pos=None, endpos=None):
242    return next(self._match(_Anchor.ANCHOR_BOTH, text, pos, endpos), None)
243
244  def finditer(self, text, pos=None, endpos=None):
245    return self._match(_Anchor.UNANCHORED, text, pos, endpos)
246
247  def findall(self, text, pos=None, endpos=None):
248    empty = type(text)()
249    items = []
250    for match in self.finditer(text, pos, endpos):
251      if not self.groups:
252        item = match.group()
253      elif self.groups == 1:
254        item = match.groups(default=empty)[0]
255      else:
256        item = match.groups(default=empty)
257      items.append(item)
258    return items
259
260  def _split(self, cb, text, maxsplit=0):
261    if maxsplit < 0:
262      return [text], 0
263    elif maxsplit > 0:
264      matchiter = itertools.islice(self.finditer(text), maxsplit)
265    else:
266      matchiter = self.finditer(text)
267    pieces = []
268    end = 0
269    numsplit = 0
270    for match in matchiter:
271      pieces.append(text[end:match.start()])
272      pieces.extend(cb(match))
273      end = match.end()
274      numsplit += 1
275    pieces.append(text[end:])
276    return pieces, numsplit
277
278  def split(self, text, maxsplit=0):
279    cb = lambda match: [match[group] for group in range(1, self.groups + 1)]
280    pieces, _ = self._split(cb, text, maxsplit)
281    return pieces
282
283  def subn(self, repl, text, count=0):
284    cb = lambda match: [repl(match) if callable(repl) else match.expand(repl)]
285    empty = type(text)()
286    pieces, numsplit = self._split(cb, text, count)
287    joined_pieces = empty.join(pieces)
288    return joined_pieces, numsplit
289
290  def sub(self, repl, text, count=0):
291    joined_pieces, _ = self.subn(repl, text, count)
292    return joined_pieces
293
294  @property
295  def pattern(self):
296    return self._pattern
297
298  @property
299  def options(self):
300    return self._regexp.options()
301
302  @property
303  def groups(self):
304    return self._regexp.NumberOfCapturingGroups()
305
306  @property
307  def groupindex(self):
308    groups = self._regexp.NamedCapturingGroups()
309    if isinstance(self._pattern, str):
310      decoded_groups = [(_decode(group), index) for group, index in groups]
311      return dict(decoded_groups)
312    else:
313      return dict(groups)
314
315  @property
316  def programsize(self):
317    return self._regexp.ProgramSize()
318
319  @property
320  def reverseprogramsize(self):
321    return self._regexp.ReverseProgramSize()
322
323  @property
324  def programfanout(self):
325    return self._regexp.ProgramFanout()
326
327  @property
328  def reverseprogramfanout(self):
329    return self._regexp.ReverseProgramFanout()
330
331  def possiblematchrange(self, maxlen):
332    ok, min, max = self._regexp.PossibleMatchRange(maxlen)
333    if not ok:
334      raise error('failed to compute match range')
335    return min, max
336
337
class _Match(object):
  """The result of one match of a regular expression object against a text."""

  __slots__ = ('_regexp', '_text', '_pos', '_endpos', '_spans')

  def __init__(self, regexp, text, pos, endpos, spans):
    """Stores the details of a match.

    Args:
      regexp: The regular expression object that produced the match.
      text: The str or bytes that was matched against.
      pos: The offset at which matching was allowed to start.
      endpos: The offset at which matching had to stop.
      spans: One (start, end) pair per group, with group 0 first. Groups that
        did not participate hold (-1, -1).
    """
    self._regexp = regexp
    self._text = text
    self._pos = pos
    self._endpos = endpos
    self._spans = spans

  # Python prioritises three-digit octal numbers over group escapes.
  # For example, \100 should not be handled the same way as \g<10>0.
  _OCTAL_RE = compile('\\\\[0-7][0-7][0-7]')

  # Python supports \1 through \99 (inclusive) and \g<...> syntax.
  _GROUP_RE = compile('\\\\[1-9][0-9]?|\\\\g<\\w+>')

  @classmethod
  @functools.lru_cache(typed=True)
  def _split(cls, template):
    """Splits a template into literal pieces and group references.

    Returns:
      A list that alternates between literal text (even indices, escapes not
      yet resolved) and group references such as \\1 or \\g<name> (odd
      indices). The list is cached, so callers must not modify it.
    """
    if isinstance(template, str):
      backslash = '\\'
    else:
      backslash = b'\\'
    empty = type(template)()
    pieces = [empty]
    index = template.find(backslash)
    while index != -1:
      piece, template = template[:index], template[index:]
      pieces[-1] += piece
      octal_match = cls._OCTAL_RE.match(template)
      group_match = cls._GROUP_RE.match(template)
      if (not octal_match) and group_match:
        index = group_match.end()
        piece, template = template[:index], template[index:]
        pieces.extend((piece, empty))
      else:
        # 2 isn't enough for \o, \x, \N, \u and \U escapes, but none of those
        # should contain backslashes, so break them here and then fix them at
        # the beginning of the next loop iteration or right before returning.
        index = 2
        piece, template = template[:index], template[index:]
        pieces[-1] += piece
      index = template.find(backslash)
    pieces[-1] += template
    return pieces

  @staticmethod
  def _unescape(piece):
    """Resolves the backslash escapes in a literal piece of a template."""
    if isinstance(piece, str):
      # codecs.unicode_escape_decode() would take the UTF-8 encoding of a str
      # and read every byte as one Latin-1 character, which mangles each
      # non-ASCII character. Encoding to Latin-1 keeps U+0080 through U+00FF
      # as single bytes, and backslashreplace turns anything beyond that into
      # a \uXXXX or \UXXXXXXXX escape, which the decoder then restores.
      # NOTE: a backslash placed directly before a character above U+00FF is
      # not a valid escape and receives no special treatment here.
      return piece.encode('latin-1', 'backslashreplace').decode(
          'unicode_escape')
    unescaped, _ = codecs.escape_decode(piece)
    return unescaped

  def expand(self, template):
    """Returns template with escapes resolved and group references filled in.

    Args:
      template: A str or bytes template. \\1 through \\99 and \\g<...> refer
        to groups; a group that did not participate expands to nothing.

    Raises:
      IndexError: If the template refers to a group that does not exist.
    """
    empty = type(template)()
    # Make a copy so that we don't clobber the cached pieces!
    pieces = list(self._split(template))
    for index, piece in enumerate(pieces):
      if not index % 2:
        pieces[index] = self._unescape(piece)
      else:
        if len(piece) <= 3:  # \1 through \99 (inclusive)
          group = int(piece[1:])
        else:  # \g<...>
          group = piece[3:-1]
          try:
            group = int(group)
          except ValueError:
            # Not a number, so it is looked up as a group name instead.
            pass
        pieces[index] = self[group] or empty
    return empty.join(pieces)

  def _index(self, group):
    """Returns the number of the group given by number or by name.

    Raises:
      IndexError: If there is no such group.
    """
    if not isinstance(group, int):
      try:
        group = self._regexp.groupindex[group]
      except KeyError:
        raise IndexError('bad group name') from None
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    return group

  def __getitem__(self, group):
    """Returns the text matched by group, or None if it did not participate.

    Args:
      group: A group number or a group name.

    Raises:
      IndexError: If there is no such group.
    """
    span = self._spans[self._index(group)]
    if span == _NULL_SPAN:
      return None
    return self._text[span[0]:span[1]]

  def group(self, *groups):
    """Returns the text of one or more groups.

    With no arguments, returns the whole match. With one argument, returns
    the text of that group. With several, returns a tuple with one item per
    argument.
    """
    if not groups:
      groups = (0,)
    if len(groups) == 1:
      return self[groups[0]]
    return tuple(self[group] for group in groups)

  def groups(self, default=None):
    """Returns a tuple of the text of every group, from group 1 onwards.

    Groups that did not participate are reported as default.
    """
    items = []
    for group in range(1, self._regexp.groups + 1):
      item = self[group]
      items.append(default if item is None else item)
    return tuple(items)

  def groupdict(self, default=None):
    """Returns a dict mapping every group name to the text of that group.

    Groups that did not participate are reported as default.
    """
    items = {}
    for name, index in self._regexp.groupindex.items():
      item = self[index]
      items[name] = default if item is None else item
    return items

  def start(self, group=0):
    """Returns the start offset of group (a number or a name).

    The offset is -1 if the group did not participate in the match.

    Raises:
      IndexError: If there is no such group.
    """
    return self._spans[self._index(group)][0]

  def end(self, group=0):
    """Returns the end offset of group (a number or a name).

    The offset is -1 if the group did not participate in the match.

    Raises:
      IndexError: If there is no such group.
    """
    return self._spans[self._index(group)][1]

  def span(self, group=0):
    """Returns the (start, end) offsets of group (a number or a name).

    The pair is (-1, -1) if the group did not participate in the match.

    Raises:
      IndexError: If there is no such group.
    """
    return self._spans[self._index(group)]

  @property
  def re(self):
    """The regular expression object that produced this match."""
    return self._regexp

  @property
  def string(self):
    """The text that was matched against."""
    return self._text

  @property
  def pos(self):
    """The offset at which matching was allowed to start."""
    return self._pos

  @property
  def endpos(self):
    """The offset at which matching had to stop."""
    return self._endpos

  @property
  def lastindex(self):
    """The number of the last matched group, or None if no group matched."""
    max_end = -1
    max_group = None
    # We look for the rightmost right parenthesis by keeping the first group
    # that ends at max_end because that is the leftmost/outermost group when
    # there are nested groups!
    for group in range(1, self._regexp.groups + 1):
      end = self._spans[group][1]
      if max_end < end:
        max_end = end
        max_group = group
    return max_group

  @property
  def lastgroup(self):
    """The name of the last matched group, or None if it has no name."""
    max_group = self.lastindex
    if not max_group:
      return None
    for group, index in self._regexp.groupindex.items():
      if max_group == index:
        return group
    return None
497
498
class Set(object):
  """A Pythonic wrapper around RE2::Set."""

  # A one-element tuple; without the trailing comma this is merely a string
  # in parentheses.
  __slots__ = ('_set',)

  def __init__(self, anchor, options=None):
    """Creates an empty set.

    Args:
      anchor: The _Anchor value that every pattern in the set is matched with.
      options: An optional Options object; Options() is used when omitted.
    """
    options = options or Options()
    self._set = _re2.Set(anchor, options)

  @classmethod
  def SearchSet(cls, options=None):
    """Returns a set whose patterns are matched unanchored."""
    return cls(_Anchor.UNANCHORED, options=options)

  @classmethod
  def MatchSet(cls, options=None):
    """Returns a set whose patterns are anchored at the start."""
    return cls(_Anchor.ANCHOR_START, options=options)

  @classmethod
  def FullMatchSet(cls, options=None):
    """Returns a set whose patterns are anchored at both ends."""
    return cls(_Anchor.ANCHOR_BOTH, options=options)

  def Add(self, pattern):
    """Adds pattern to the set and returns its index.

    Args:
      pattern: A str or bytes pattern; str is encoded to UTF-8.

    Raises:
      error: If the underlying set rejects the pattern.
    """
    if isinstance(pattern, str):
      index = self._set.Add(_encode(pattern))
    else:
      index = self._set.Add(pattern)
    if index == -1:
      raise error('failed to add %r to Set' % pattern)
    return index

  def Compile(self):
    """Compiles the set.

    Raises:
      error: If the underlying set reports failure.
    """
    if not self._set.Compile():
      raise error('failed to compile Set')

  def Match(self, text):
    """Matches text against the set.

    Args:
      text: A str or bytes text; str is encoded to UTF-8.

    Returns:
      Whatever the underlying set reports as matches, or None if that is
      empty.
    """
    if isinstance(text, str):
      matches = self._set.Match(_encode(text))
    else:
      matches = self._set.Match(text)
    return matches or None
541
542
class Filter(object):
  """A Pythonic wrapper around FilteredRE2."""

  __slots__ = ('_filter', '_patterns')

  def __init__(self):
    """Creates an empty filter."""
    self._filter = _re2.Filter()
    # The patterns exactly as they were added, in index order; re() hands
    # them back as the pattern of the objects that it returns.
    self._patterns = []

  def Add(self, pattern, options=None):
    """Adds pattern to the filter and returns its index.

    A str pattern is encoded to UTF-8 first. Options() is used when options
    is omitted. Raises error if the underlying filter rejects the pattern.
    """
    options = options or Options()
    raw_pattern = _encode(pattern) if isinstance(pattern, str) else pattern
    index = self._filter.Add(raw_pattern, options)
    if index == -1:
      raise error('failed to add %r to Filter' % pattern)
    self._patterns.append(pattern)
    return index

  def Compile(self):
    """Compiles the filter, raising error if that fails."""
    compiled = self._filter.Compile()
    if not compiled:
      raise error('failed to compile Filter')

  def Match(self, text, potential=False):
    """Matches text against the filter.

    A str text is encoded to UTF-8 first; potential is passed through to the
    underlying filter. Returns what the underlying filter reports as matches,
    or None if that is empty.
    """
    raw_text = _encode(text) if isinstance(text, str) else text
    matches = self._filter.Match(raw_text, potential)
    return matches or None

  def re(self, index):
    """Returns a regular expression object for the pattern at index.

    Raises IndexError if no pattern was added at that index.
    """
    if not 0 <= index < len(self._patterns):
      raise IndexError('bad index')
    # Skip _Regexp.__init__() because the RE2 object already exists inside
    # the filter and must not be compiled again.
    regexp = object.__new__(_Regexp)
    regexp._pattern = self._patterns[index]
    regexp._regexp = self._filter.GetRE2(index)
    return regexp
583