1 /**
2 * markupsafe._speedups
3 * ~~~~~~~~~~~~~~~~~~~~
4 *
5 * This module implements functions for automatic escaping in C for better
6 * performance.
7 *
8 * :copyright: (c) 2010 by Armin Ronacher.
9 * :license: BSD.
10 */
11
12 #include <Python.h>
13
14 #define ESCAPED_CHARS_TABLE_SIZE 63
15 #define UNICHR(x) (PyUnicode_AS_UNICODE((PyUnicodeObject*)PyUnicode_DecodeASCII(x, strlen(x), NULL)));
16
17 #if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
18 typedef int Py_ssize_t;
19 #define PY_SSIZE_T_MAX INT_MAX
20 #define PY_SSIZE_T_MIN INT_MIN
21 #endif
22
23
24 static PyObject* markup;
25 static Py_ssize_t escaped_chars_delta_len[ESCAPED_CHARS_TABLE_SIZE];
26 static Py_UNICODE *escaped_chars_repl[ESCAPED_CHARS_TABLE_SIZE];
27
28 static int
init_constants(void)29 init_constants(void)
30 {
31 PyObject *module;
32 /* happing of characters to replace */
33 escaped_chars_repl['"'] = UNICHR(""");
34 escaped_chars_repl['\''] = UNICHR("'");
35 escaped_chars_repl['&'] = UNICHR("&");
36 escaped_chars_repl['<'] = UNICHR("<");
37 escaped_chars_repl['>'] = UNICHR(">");
38
39 /* lengths of those characters when replaced - 1 */
40 memset(escaped_chars_delta_len, 0, sizeof (escaped_chars_delta_len));
41 escaped_chars_delta_len['"'] = escaped_chars_delta_len['\''] = \
42 escaped_chars_delta_len['&'] = 4;
43 escaped_chars_delta_len['<'] = escaped_chars_delta_len['>'] = 3;
44
45 /* import markup type so that we can mark the return value */
46 module = PyImport_ImportModule("markupsafe");
47 if (!module)
48 return 0;
49 markup = PyObject_GetAttrString(module, "Markup");
50 Py_DECREF(module);
51
52 return 1;
53 }
54
55 static PyObject*
escape_unicode(PyUnicodeObject * in)56 escape_unicode(PyUnicodeObject *in)
57 {
58 PyUnicodeObject *out;
59 Py_UNICODE *inp = PyUnicode_AS_UNICODE(in);
60 const Py_UNICODE *inp_end = PyUnicode_AS_UNICODE(in) + PyUnicode_GET_SIZE(in);
61 Py_UNICODE *next_escp;
62 Py_UNICODE *outp;
63 Py_ssize_t delta=0, erepl=0, delta_len=0;
64
65 /* First we need to figure out how long the escaped string will be */
66 while (*(inp) || inp < inp_end) {
67 if (*inp < ESCAPED_CHARS_TABLE_SIZE) {
68 delta += escaped_chars_delta_len[*inp];
69 erepl += !!escaped_chars_delta_len[*inp];
70 }
71 ++inp;
72 }
73
74 /* Do we need to escape anything at all? */
75 if (!erepl) {
76 Py_INCREF(in);
77 return (PyObject*)in;
78 }
79
80 out = (PyUnicodeObject*)PyUnicode_FromUnicode(NULL, PyUnicode_GET_SIZE(in) + delta);
81 if (!out)
82 return NULL;
83
84 outp = PyUnicode_AS_UNICODE(out);
85 inp = PyUnicode_AS_UNICODE(in);
86 while (erepl-- > 0) {
87 /* look for the next substitution */
88 next_escp = inp;
89 while (next_escp < inp_end) {
90 if (*next_escp < ESCAPED_CHARS_TABLE_SIZE &&
91 (delta_len = escaped_chars_delta_len[*next_escp])) {
92 ++delta_len;
93 break;
94 }
95 ++next_escp;
96 }
97
98 if (next_escp > inp) {
99 /* copy unescaped chars between inp and next_escp */
100 Py_UNICODE_COPY(outp, inp, next_escp-inp);
101 outp += next_escp - inp;
102 }
103
104 /* escape 'next_escp' */
105 Py_UNICODE_COPY(outp, escaped_chars_repl[*next_escp], delta_len);
106 outp += delta_len;
107
108 inp = next_escp + 1;
109 }
110 if (inp < inp_end)
111 Py_UNICODE_COPY(outp, inp, PyUnicode_GET_SIZE(in) - (inp - PyUnicode_AS_UNICODE(in)));
112
113 return (PyObject*)out;
114 }
115
116
117 static PyObject*
escape(PyObject * self,PyObject * text)118 escape(PyObject *self, PyObject *text)
119 {
120 PyObject *s = NULL, *rv = NULL, *html;
121
122 /* we don't have to escape integers, bools or floats */
123 if (PyLong_CheckExact(text) ||
124 #if PY_MAJOR_VERSION < 3
125 PyInt_CheckExact(text) ||
126 #endif
127 PyFloat_CheckExact(text) || PyBool_Check(text) ||
128 text == Py_None)
129 return PyObject_CallFunctionObjArgs(markup, text, NULL);
130
131 /* if the object has an __html__ method that performs the escaping */
132 html = PyObject_GetAttrString(text, "__html__");
133 if (html) {
134 rv = PyObject_CallObject(html, NULL);
135 Py_DECREF(html);
136 return rv;
137 }
138
139 /* otherwise make the object unicode if it isn't, then escape */
140 PyErr_Clear();
141 if (!PyUnicode_Check(text)) {
142 #if PY_MAJOR_VERSION < 3
143 PyObject *unicode = PyObject_Unicode(text);
144 #else
145 PyObject *unicode = PyObject_Str(text);
146 #endif
147 if (!unicode)
148 return NULL;
149 s = escape_unicode((PyUnicodeObject*)unicode);
150 Py_DECREF(unicode);
151 }
152 else
153 s = escape_unicode((PyUnicodeObject*)text);
154
155 /* convert the unicode string into a markup object. */
156 rv = PyObject_CallFunctionObjArgs(markup, (PyObject*)s, NULL);
157 Py_DECREF(s);
158 return rv;
159 }
160
161
162 static PyObject*
escape_silent(PyObject * self,PyObject * text)163 escape_silent(PyObject *self, PyObject *text)
164 {
165 if (text != Py_None)
166 return escape(self, text);
167 return PyObject_CallFunctionObjArgs(markup, NULL);
168 }
169
170
171 static PyObject*
soft_unicode(PyObject * self,PyObject * s)172 soft_unicode(PyObject *self, PyObject *s)
173 {
174 if (!PyUnicode_Check(s))
175 #if PY_MAJOR_VERSION < 3
176 return PyObject_Unicode(s);
177 #else
178 return PyObject_Str(s);
179 #endif
180 Py_INCREF(s);
181 return s;
182 }
183
184
185 static PyMethodDef module_methods[] = {
186 {"escape", (PyCFunction)escape, METH_O,
187 "escape(s) -> markup\n\n"
188 "Convert the characters &, <, >, ', and \" in string s to HTML-safe\n"
189 "sequences. Use this if you need to display text that might contain\n"
190 "such characters in HTML. Marks return value as markup string."},
191 {"escape_silent", (PyCFunction)escape_silent, METH_O,
192 "escape_silent(s) -> markup\n\n"
193 "Like escape but converts None to an empty string."},
194 {"soft_unicode", (PyCFunction)soft_unicode, METH_O,
195 "soft_unicode(object) -> string\n\n"
196 "Make a string unicode if it isn't already. That way a markup\n"
197 "string is not converted back to unicode."},
198 {NULL, NULL, 0, NULL} /* Sentinel */
199 };
200
201
202 #if PY_MAJOR_VERSION < 3
203
204 #ifndef PyMODINIT_FUNC /* declarations for DLL import/export */
205 #define PyMODINIT_FUNC void
206 #endif
207 PyMODINIT_FUNC
init_speedups(void)208 init_speedups(void)
209 {
210 if (!init_constants())
211 return;
212
213 Py_InitModule3("markupsafe._speedups", module_methods, "");
214 }
215
216 #else /* Python 3.x module initialization */
217
218 static struct PyModuleDef module_definition = {
219 PyModuleDef_HEAD_INIT,
220 "markupsafe._speedups",
221 NULL,
222 -1,
223 module_methods,
224 NULL,
225 NULL,
226 NULL,
227 NULL
228 };
229
230 PyMODINIT_FUNC
PyInit__speedups(void)231 PyInit__speedups(void)
232 {
233 if (!init_constants())
234 return NULL;
235
236 return PyModule_Create(&module_definition);
237 }
238
239 #endif
240