• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import functools
2import re
3import string
4import typing as t
5
if t.TYPE_CHECKING:
    import typing_extensions as te

    class HasHTML(te.Protocol):
        """Structural type: any object exposing an ``__html__`` method
        that returns ``str``.

        Only defined while a type checker is running
        (``typing.TYPE_CHECKING``); it does not exist at runtime.
        """

        def __html__(self) -> str:
            ...
12
13
__version__ = "2.1.1"

# Patterns for removing HTML comments and tags. Both are non-greedy so each
# match ends at the first closing delimiter. ``re.DOTALL`` lets ``.`` match
# newlines as well, so a comment or tag that spans several lines (for
# example ``<a\nhref="/">``) is matched instead of being left in the text.
_strip_comments_re = re.compile(r"<!--.*?-->", re.DOTALL)
_strip_tags_re = re.compile(r"<.*?>", re.DOTALL)
18
19
def _simple_escaping_wrapper(name: str) -> t.Callable[..., "Markup"]:
    """Build a ``Markup`` method that delegates to ``str.<name>``.

    Before delegating, positional and keyword arguments are passed through
    :func:`_escape_argspec` with the instance's ``escape``. The ``str``
    result is then wrapped in the instance's class. Name and docstring are
    copied from the ``str`` method via :func:`functools.wraps`.
    """
    str_method = getattr(str, name)

    def wrapped(self: "Markup", *args: t.Any, **kwargs: t.Any) -> "Markup":
        # Positional args are copied into a list so they can be escaped in
        # place; keyword args are escaped directly in the kwargs dict.
        safe_args = _escape_argspec(list(args), enumerate(args), self.escape)
        _escape_argspec(kwargs, kwargs.items(), self.escape)
        result = str_method(self, *safe_args, **kwargs)
        return self.__class__(result)

    return functools.wraps(str_method)(wrapped)
30
31
class Markup(str):
    """A string that is ready to be safely inserted into an HTML or XML
    document, either because it was escaped or because it was marked
    safe.

    Passing an object to the constructor converts it to text and wraps
    it to mark it safe without escaping. To escape the text, use the
    :meth:`escape` class method instead.

    >>> Markup("Hello, <em>World</em>!")
    Markup('Hello, <em>World</em>!')
    >>> Markup(42)
    Markup('42')
    >>> Markup.escape("Hello, <em>World</em>!")
    Markup('Hello, &lt;em&gt;World&lt;/em&gt;!')

    This implements the ``__html__()`` interface that some frameworks
    use. Passing an object that implements ``__html__()`` will wrap the
    output of that method, marking it safe.

    >>> class Foo:
    ...     def __html__(self):
    ...         return '<a href="/foo">foo</a>'
    ...
    >>> Markup(Foo())
    Markup('<a href="/foo">foo</a>')

    This is a subclass of :class:`str`. It has the same methods, but
    escapes their arguments and returns a ``Markup`` instance.

    >>> Markup("<em>%s</em>") % ("foo & bar",)
    Markup('<em>foo &amp; bar</em>')
    >>> Markup("<em>Hello</em> ") + "<foo>"
    Markup('<em>Hello</em> &lt;foo&gt;')
    """

    __slots__ = ()

    def __new__(
        cls, base: t.Any = "", encoding: t.Optional[str] = None, errors: str = "strict"
    ) -> "Markup":
        # Objects that know how to render themselves as HTML are trusted
        # as-is: their ``__html__()`` output is wrapped without escaping.
        if hasattr(base, "__html__"):
            base = base.__html__()

        if encoding is None:
            return super().__new__(cls, base)

        return super().__new__(cls, base, encoding, errors)

    def __html__(self) -> "Markup":
        return self

    def __add__(self, other: t.Union[str, "HasHTML"]) -> "Markup":
        # Only string-like operands are escaped and concatenated; anything
        # else is left to the other operand's reflected method.
        if isinstance(other, str) or hasattr(other, "__html__"):
            return self.__class__(super().__add__(self.escape(other)))

        return NotImplemented

    def __radd__(self, other: t.Union[str, "HasHTML"]) -> "Markup":
        if isinstance(other, str) or hasattr(other, "__html__"):
            return self.escape(other).__add__(self)

        return NotImplemented

    def __mul__(self, num: "te.SupportsIndex") -> "Markup":
        if isinstance(num, int):
            return self.__class__(super().__mul__(num))

        return NotImplemented

    __rmul__ = __mul__

    def __mod__(self, arg: t.Any) -> "Markup":
        if isinstance(arg, tuple):
            # a tuple of arguments, each wrapped
            arg = tuple(_MarkupEscapeHelper(x, self.escape) for x in arg)
        elif hasattr(type(arg), "__getitem__") and not isinstance(arg, str):
            # a mapping of arguments, wrapped
            arg = _MarkupEscapeHelper(arg, self.escape)
        else:
            # a single argument, wrapped with the helper and a tuple
            arg = (_MarkupEscapeHelper(arg, self.escape),)

        return self.__class__(super().__mod__(arg))

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({super().__repr__()})"

    def join(self, seq: t.Iterable[t.Union[str, "HasHTML"]]) -> "Markup":
        return self.__class__(super().join(map(self.escape, seq)))

    join.__doc__ = str.join.__doc__

    def split(  # type: ignore
        self, sep: t.Optional[str] = None, maxsplit: int = -1
    ) -> t.List["Markup"]:
        return [self.__class__(v) for v in super().split(sep, maxsplit)]

    split.__doc__ = str.split.__doc__

    def rsplit(  # type: ignore
        self, sep: t.Optional[str] = None, maxsplit: int = -1
    ) -> t.List["Markup"]:
        return [self.__class__(v) for v in super().rsplit(sep, maxsplit)]

    rsplit.__doc__ = str.rsplit.__doc__

    def splitlines(self, keepends: bool = False) -> t.List["Markup"]:  # type: ignore
        return [self.__class__(v) for v in super().splitlines(keepends)]

    splitlines.__doc__ = str.splitlines.__doc__

    def unescape(self) -> str:
        """Convert escaped markup back into a text string. This replaces
        HTML entities with the characters they represent.

        >>> Markup("Main &raquo; <em>About</em>").unescape()
        'Main » <em>About</em>'
        """
        from html import unescape

        return unescape(str(self))

    def striptags(self) -> str:
        """:meth:`unescape` the markup, remove tags, and normalize
        whitespace to single spaces.

        Comments and tags are removed even when they span several lines.

        >>> Markup("Main &raquo;\t<em>About</em>").striptags()
        'Main » About'
        """

        def remove_spans(text: str, opener: str, closer: str) -> str:
            # Drop every ``opener ... closer`` span in a single left-to-right
            # pass. The result matches substituting the non-greedy pattern
            # ``opener.*?closer`` (newlines included), but runs in linear
            # time instead of rescanning after every unmatched opener.
            kept = []
            pos = 0

            while True:
                start = text.find(opener, pos)

                if start == -1:
                    break

                end = text.find(closer, start + len(opener))

                if end == -1:
                    # No closer after this opener means none after any later
                    # opener either, so the remaining text is kept unchanged.
                    break

                kept.append(text[pos:start])
                pos = end + len(closer)

            kept.append(text[pos:])
            return "".join(kept)

        # Remove comments before tags: otherwise a ">" inside a comment
        # would end the tag removal early and leave part of the comment.
        value = remove_spans(str(self), "<!--", "-->")
        value = remove_spans(value, "<", ">")
        value = " ".join(value.split())
        return Markup(value).unescape()

    @classmethod
    def escape(cls, s: t.Any) -> "Markup":
        """Escape a string. Calls :func:`escape` and ensures that for
        subclasses the correct type is returned.
        """
        rv = escape(s)

        if rv.__class__ is not cls:
            return cls(rv)

        return rv

    # Wrap ``str`` methods so string arguments are escaped and the result
    # stays ``Markup``. In a class body ``locals()`` is the namespace the
    # class is created from, so assigning into it defines the methods.
    for method in (
        "__getitem__",
        "capitalize",
        "title",
        "lower",
        "upper",
        "replace",
        "ljust",
        "rjust",
        "lstrip",
        "rstrip",
        "center",
        "strip",
        "translate",
        "expandtabs",
        "swapcase",
        "zfill",
        "casefold",
        # ``removeprefix``/``removesuffix`` exist only on Python 3.9+.
        "removeprefix",
        "removesuffix",
    ):
        if hasattr(str, method):
            locals()[method] = _simple_escaping_wrapper(method)

    del method

    def partition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]:
        l, s, r = super().partition(self.escape(sep))
        cls = self.__class__
        return cls(l), cls(s), cls(r)

    def rpartition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]:
        l, s, r = super().rpartition(self.escape(sep))
        cls = self.__class__
        return cls(l), cls(s), cls(r)

    def format(self, *args: t.Any, **kwargs: t.Any) -> "Markup":
        """Like :meth:`str.format`, but every substituted field is escaped
        and the result is a ``Markup`` instance.
        """
        formatter = EscapeFormatter(self.escape)
        return self.__class__(formatter.vformat(self, args, kwargs))

    def format_map(  # type: ignore[override]
        self, mapping: t.Mapping[str, t.Any]
    ) -> "Markup":
        """Like :meth:`str.format_map`, but every substituted field is
        escaped and the result is a ``Markup`` instance.

        Without this override the inherited ``str.format_map`` would return
        a plain ``str`` with nothing escaped, unlike :meth:`format`.
        """
        formatter = EscapeFormatter(self.escape)
        return self.__class__(formatter.vformat(self, (), mapping))

    def __html_format__(self, format_spec: str) -> "Markup":
        if format_spec:
            raise ValueError("Unsupported format specification for Markup.")

        return self
221
222
class EscapeFormatter(string.Formatter):
    """A :class:`string.Formatter` whose replacement fields are escaped.

    Each field is rendered by :meth:`format_field` and the rendered text
    is passed through the *escape* callable given to the constructor.
    """

    __slots__ = ("escape",)

    def __init__(self, escape: t.Callable[[t.Any], Markup]) -> None:
        self.escape = escape
        super().__init__()

    def format_field(self, value: t.Any, format_spec: str) -> str:
        """Render *value* with *format_spec*, then escape the result.

        Objects with ``__html_format__`` render themselves. Objects with
        only ``__html__`` cannot take a format spec and raise
        :exc:`ValueError` if one is given. Anything else is formatted by
        the base :class:`string.Formatter`.
        """
        if hasattr(value, "__html_format__"):
            rendered = value.__html_format__(format_spec)
        elif not hasattr(value, "__html__"):
            # Coerce the spec to ``str`` so the base implementation invokes
            # the ordinary formatting callbacks rather than the wrong ones.
            rendered = string.Formatter.format_field(self, value, str(format_spec))
        elif format_spec:
            raise ValueError(
                f"Format specifier {format_spec} given, but {type(value)} does not"
                " define __html_format__. A class that defines __html__ must define"
                " __html_format__ to work with format specifiers."
            )
        else:
            rendered = value.__html__()

        return str(self.escape(rendered))
246
247
_ListOrDict = t.TypeVar("_ListOrDict", list, dict)


def _escape_argspec(
    obj: _ListOrDict, iterable: t.Iterable[t.Any], escape: t.Callable[[t.Any], Markup]
) -> _ListOrDict:
    """Escape the string-like entries of *obj* in place and return *obj*.

    *iterable* yields ``(key, value)`` pairs addressing *obj*. Callers in
    this module pass ``enumerate(...)`` for a list and ``dict.items()`` for
    a dict. A value that is a ``str`` or has an ``__html__`` method is
    replaced at ``obj[key]`` with ``escape(value)``; other values are left
    untouched.
    """
    for slot, item in iterable:
        if not isinstance(item, str) and not hasattr(item, "__html__"):
            continue

        obj[slot] = escape(item)

    return obj
260
261
class _MarkupEscapeHelper:
    """Wrap one ``%``-formatting argument for :meth:`Markup.__mod__`.

    Text conversions (``__str__`` and ``__repr__``) are escaped with the
    supplied *escape* callable. Numeric conversions (``int()``/``float()``)
    are forwarded to the wrapped object. Item lookups return another helper
    around the looked-up value, so nested values are escaped too.
    """

    __slots__ = ("obj", "escape")

    def __init__(self, obj: t.Any, escape: t.Callable[[t.Any], Markup]) -> None:
        self.obj = obj
        self.escape = escape

    def __getitem__(self, item: t.Any) -> "_MarkupEscapeHelper":
        looked_up = self.obj[item]
        return _MarkupEscapeHelper(looked_up, self.escape)

    def __str__(self) -> str:
        escaped = self.escape(self.obj)
        return str(escaped)

    def __repr__(self) -> str:
        # Escape the repr text, not the object, so ``%r`` output is safe.
        escaped = self.escape(repr(self.obj))
        return str(escaped)

    def __int__(self) -> int:
        return int(self.obj)

    def __float__(self) -> float:
        return float(self.obj)
285
286
287# circular import
288try:
289    from ._speedups import escape as escape
290    from ._speedups import escape_silent as escape_silent
291    from ._speedups import soft_str as soft_str
292except ImportError:
293    from ._native import escape as escape
294    from ._native import escape_silent as escape_silent  # noqa: F401
295    from ._native import soft_str as soft_str  # noqa: F401
296