• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mako/codegen.py
2# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
3#
4# This module is part of Mako and is released under
5# the MIT License: http://www.opensource.org/licenses/mit-license.php
6
7"""provides functionality for rendering a parsetree constructing into module
8source code."""
9
10import time
11import re
12from mako.pygen import PythonPrinter
13from mako import util, ast, parsetree, filters, exceptions
14from mako import compat
15
16
# Written verbatim into every generated module as ``_magic_number``.
MAGIC_NUMBER = 10

# Names that are hardwired into the generated template module itself
# rather than being looked up through the context object.
RESERVED_NAMES = set(('context', 'loop', 'UNDEFINED'))
23
def compile(node,
            uri,
            filename=None,
            default_filters=None,
            buffer_filters=None,
            imports=None,
            future_imports=None,
            source_encoding=None,
            generate_magic_comment=True,
            disable_unicode=False,
            strict_undefined=False,
            enable_loop=True,
            reserved_names=frozenset()):
    """Generate module source code given a parsetree node,
    uri, and optional source filename.

    All options are stored on a ``_CompileContext`` that the code
    generator consults while walking ``node``. Returns whatever the
    underlying ``util.FastEncodingBuffer`` yields from ``getvalue()``.
    """
    # Under Python 2 a unicode ``source_encoding`` is turned into a byte
    # string: it is embedded into the generated source, and in
    # "disable_unicode" mode we must not coerce the result to unicode.
    if not compat.py3k and isinstance(source_encoding, compat.text_type):
        source_encoding = source_encoding.encode(source_encoding)

    settings = _CompileContext(uri, filename, default_filters,
                               buffer_filters, imports, future_imports,
                               source_encoding, generate_magic_comment,
                               disable_unicode, strict_undefined,
                               enable_loop, reserved_names)

    output = util.FastEncodingBuffer()
    _GenerateRenderMethod(PythonPrinter(output), settings, node)
    return output.getvalue()
67
68class _CompileContext(object):
69    def __init__(self,
70                    uri,
71                    filename,
72                    default_filters,
73                    buffer_filters,
74                    imports,
75                    future_imports,
76                    source_encoding,
77                    generate_magic_comment,
78                    disable_unicode,
79                    strict_undefined,
80                    enable_loop,
81                    reserved_names):
82        self.uri = uri
83        self.filename = filename
84        self.default_filters = default_filters
85        self.buffer_filters = buffer_filters
86        self.imports = imports
87        self.future_imports = future_imports
88        self.source_encoding = source_encoding
89        self.generate_magic_comment = generate_magic_comment
90        self.disable_unicode = disable_unicode
91        self.strict_undefined = strict_undefined
92        self.enable_loop = enable_loop
93        self.reserved_names = reserved_names
94
95class _GenerateRenderMethod(object):
96    """A template visitor object which generates the
97       full module source for a template.
98
99    """
100    def __init__(self, printer, compiler, node):
101        self.printer = printer
102        self.compiler = compiler
103        self.node = node
104        self.identifier_stack = [None]
105        self.in_def = isinstance(node, (parsetree.DefTag, parsetree.BlockTag))
106
107        if self.in_def:
108            name = "render_%s" % node.funcname
109            args = node.get_argument_expressions()
110            filtered = len(node.filter_args.args) > 0
111            buffered = eval(node.attributes.get('buffered', 'False'))
112            cached = eval(node.attributes.get('cached', 'False'))
113            defs = None
114            pagetag = None
115            if node.is_block and not node.is_anonymous:
116                args += ['**pageargs']
117        else:
118            defs = self.write_toplevel()
119            pagetag = self.compiler.pagetag
120            name = "render_body"
121            if pagetag is not None:
122                args = pagetag.body_decl.get_argument_expressions()
123                if not pagetag.body_decl.kwargs:
124                    args += ['**pageargs']
125                cached = eval(pagetag.attributes.get('cached', 'False'))
126                self.compiler.enable_loop = self.compiler.enable_loop or eval(
127                                        pagetag.attributes.get(
128                                                'enable_loop', 'False')
129                                    )
130            else:
131                args = ['**pageargs']
132                cached = False
133            buffered = filtered = False
134        if args is None:
135            args = ['context']
136        else:
137            args = [a for a in ['context'] + args]
138
139        self.write_render_callable(
140                            pagetag or node,
141                            name, args,
142                            buffered, filtered, cached)
143
144        if defs is not None:
145            for node in defs:
146                _GenerateRenderMethod(printer, compiler, node)
147
148        if not self.in_def:
149            self.write_metadata_struct()
150
151    def write_metadata_struct(self):
152        self.printer.source_map[self.printer.lineno] = \
153                    max(self.printer.source_map)
154        struct = {
155            "filename": self.compiler.filename,
156            "uri": self.compiler.uri,
157            "source_encoding": self.compiler.source_encoding,
158            "line_map": self.printer.source_map,
159        }
160        self.printer.writelines(
161            '"""',
162            '__M_BEGIN_METADATA',
163            compat.json.dumps(struct),
164            '__M_END_METADATA\n'
165            '"""'
166        )
167
168    @property
169    def identifiers(self):
170        return self.identifier_stack[-1]
171
172    def write_toplevel(self):
173        """Traverse a template structure for module-level directives and
174        generate the start of module-level code.
175
176        """
177        inherit = []
178        namespaces = {}
179        module_code = []
180
181        self.compiler.pagetag = None
182
183        class FindTopLevel(object):
184            def visitInheritTag(s, node):
185                inherit.append(node)
186            def visitNamespaceTag(s, node):
187                namespaces[node.name] = node
188            def visitPageTag(s, node):
189                self.compiler.pagetag = node
190            def visitCode(s, node):
191                if node.ismodule:
192                    module_code.append(node)
193
194        f = FindTopLevel()
195        for n in self.node.nodes:
196            n.accept_visitor(f)
197
198        self.compiler.namespaces = namespaces
199
200        module_ident = set()
201        for n in module_code:
202            module_ident = module_ident.union(n.declared_identifiers())
203
204        module_identifiers = _Identifiers(self.compiler)
205        module_identifiers.declared = module_ident
206
207        # module-level names, python code
208        if self.compiler.generate_magic_comment and \
209                self.compiler.source_encoding:
210            self.printer.writeline("# -*- coding:%s -*-" %
211                                    self.compiler.source_encoding)
212
213        if self.compiler.future_imports:
214            self.printer.writeline("from __future__ import %s" %
215                                   (", ".join(self.compiler.future_imports),))
216        self.printer.writeline("from mako import runtime, filters, cache")
217        self.printer.writeline("UNDEFINED = runtime.UNDEFINED")
218        self.printer.writeline("__M_dict_builtin = dict")
219        self.printer.writeline("__M_locals_builtin = locals")
220        self.printer.writeline("_magic_number = %r" % MAGIC_NUMBER)
221        self.printer.writeline("_modified_time = %r" % time.time())
222        self.printer.writeline("_enable_loop = %r" % self.compiler.enable_loop)
223        self.printer.writeline(
224                            "_template_filename = %r" % self.compiler.filename)
225        self.printer.writeline("_template_uri = %r" % self.compiler.uri)
226        self.printer.writeline(
227                    "_source_encoding = %r" % self.compiler.source_encoding)
228        if self.compiler.imports:
229            buf = ''
230            for imp in self.compiler.imports:
231                buf += imp + "\n"
232                self.printer.writeline(imp)
233            impcode = ast.PythonCode(
234                            buf,
235                            source='', lineno=0,
236                            pos=0,
237                            filename='template defined imports')
238        else:
239            impcode = None
240
241        main_identifiers = module_identifiers.branch(self.node)
242        module_identifiers.topleveldefs = \
243            module_identifiers.topleveldefs.\
244                union(main_identifiers.topleveldefs)
245        module_identifiers.declared.add("UNDEFINED")
246        if impcode:
247            module_identifiers.declared.update(impcode.declared_identifiers)
248
249        self.compiler.identifiers = module_identifiers
250        self.printer.writeline("_exports = %r" %
251                            [n.name for n in
252                            main_identifiers.topleveldefs.values()]
253                        )
254        self.printer.write_blanks(2)
255
256        if len(module_code):
257            self.write_module_code(module_code)
258
259        if len(inherit):
260            self.write_namespaces(namespaces)
261            self.write_inherit(inherit[-1])
262        elif len(namespaces):
263            self.write_namespaces(namespaces)
264
265        return list(main_identifiers.topleveldefs.values())
266
267    def write_render_callable(self, node, name, args, buffered, filtered,
268            cached):
269        """write a top-level render callable.
270
271        this could be the main render() method or that of a top-level def."""
272
273        if self.in_def:
274            decorator = node.decorator
275            if decorator:
276                self.printer.writeline(
277                                "@runtime._decorate_toplevel(%s)" % decorator)
278
279        self.printer.start_source(node.lineno)
280        self.printer.writelines(
281            "def %s(%s):" % (name, ','.join(args)),
282                # push new frame, assign current frame to __M_caller
283                "__M_caller = context.caller_stack._push_frame()",
284                "try:"
285        )
286        if buffered or filtered or cached:
287            self.printer.writeline("context._push_buffer()")
288
289        self.identifier_stack.append(
290                                self.compiler.identifiers.branch(self.node))
291        if (not self.in_def or self.node.is_block) and '**pageargs' in args:
292            self.identifier_stack[-1].argument_declared.add('pageargs')
293
294        if not self.in_def and (
295                                len(self.identifiers.locally_assigned) > 0 or
296                                len(self.identifiers.argument_declared) > 0
297                                ):
298            self.printer.writeline("__M_locals = __M_dict_builtin(%s)" %
299                                    ','.join([
300                                            "%s=%s" % (x, x) for x in
301                                            self.identifiers.argument_declared
302                                            ]))
303
304        self.write_variable_declares(self.identifiers, toplevel=True)
305
306        for n in self.node.nodes:
307            n.accept_visitor(self)
308
309        self.write_def_finish(self.node, buffered, filtered, cached)
310        self.printer.writeline(None)
311        self.printer.write_blanks(2)
312        if cached:
313            self.write_cache_decorator(
314                                node, name,
315                                args, buffered,
316                                self.identifiers, toplevel=True)
317
318    def write_module_code(self, module_code):
319        """write module-level template code, i.e. that which
320        is enclosed in <%! %> tags in the template."""
321        for n in module_code:
322            self.printer.start_source(n.lineno)
323            self.printer.write_indented_block(n.text)
324
325    def write_inherit(self, node):
326        """write the module-level inheritance-determination callable."""
327
328        self.printer.writelines(
329            "def _mako_inherit(template, context):",
330                "_mako_generate_namespaces(context)",
331                "return runtime._inherit_from(context, %s, _template_uri)" %
332                (node.parsed_attributes['file']),
333                None
334        )
335
336    def write_namespaces(self, namespaces):
337        """write the module-level namespace-generating callable."""
338        self.printer.writelines(
339            "def _mako_get_namespace(context, name):",
340                "try:",
341                    "return context.namespaces[(__name__, name)]",
342                "except KeyError:",
343                    "_mako_generate_namespaces(context)",
344                "return context.namespaces[(__name__, name)]",
345            None, None
346        )
347        self.printer.writeline("def _mako_generate_namespaces(context):")
348
349
350        for node in namespaces.values():
351            if 'import' in node.attributes:
352                self.compiler.has_ns_imports = True
353            self.printer.start_source(node.lineno)
354            if len(node.nodes):
355                self.printer.writeline("def make_namespace():")
356                export = []
357                identifiers = self.compiler.identifiers.branch(node)
358                self.in_def = True
359                class NSDefVisitor(object):
360                    def visitDefTag(s, node):
361                        s.visitDefOrBase(node)
362
363                    def visitBlockTag(s, node):
364                        s.visitDefOrBase(node)
365
366                    def visitDefOrBase(s, node):
367                        if node.is_anonymous:
368                            raise exceptions.CompileException(
369                                "Can't put anonymous blocks inside "
370                                "<%namespace>",
371                                **node.exception_kwargs
372                            )
373                        self.write_inline_def(node, identifiers, nested=False)
374                        export.append(node.funcname)
375                vis = NSDefVisitor()
376                for n in node.nodes:
377                    n.accept_visitor(vis)
378                self.printer.writeline("return [%s]" % (','.join(export)))
379                self.printer.writeline(None)
380                self.in_def = False
381                callable_name = "make_namespace()"
382            else:
383                callable_name = "None"
384
385            if 'file' in node.parsed_attributes:
386                self.printer.writeline(
387                                "ns = runtime.TemplateNamespace(%r,"
388                                " context._clean_inheritance_tokens(),"
389                                " templateuri=%s, callables=%s, "
390                                " calling_uri=_template_uri)" %
391                                (
392                                    node.name,
393                                    node.parsed_attributes.get('file', 'None'),
394                                    callable_name,
395                                )
396                            )
397            elif 'module' in node.parsed_attributes:
398                self.printer.writeline(
399                                "ns = runtime.ModuleNamespace(%r,"
400                                " context._clean_inheritance_tokens(),"
401                                " callables=%s, calling_uri=_template_uri,"
402                                " module=%s)" %
403                                (
404                                    node.name,
405                                    callable_name,
406                                    node.parsed_attributes.get(
407                                                'module', 'None')
408                                )
409                            )
410            else:
411                self.printer.writeline(
412                                "ns = runtime.Namespace(%r,"
413                                " context._clean_inheritance_tokens(),"
414                                " callables=%s, calling_uri=_template_uri)" %
415                                (
416                                    node.name,
417                                    callable_name,
418                                )
419                            )
420            if eval(node.attributes.get('inheritable', "False")):
421                self.printer.writeline("context['self'].%s = ns" % (node.name))
422
423            self.printer.writeline(
424                "context.namespaces[(__name__, %s)] = ns" % repr(node.name))
425            self.printer.write_blanks(1)
426        if not len(namespaces):
427            self.printer.writeline("pass")
428        self.printer.writeline(None)
429
430    def write_variable_declares(self, identifiers, toplevel=False, limit=None):
431        """write variable declarations at the top of a function.
432
433        the variable declarations are in the form of callable
434        definitions for defs and/or name lookup within the
435        function's context argument. the names declared are based
436        on the names that are referenced in the function body,
437        which don't otherwise have any explicit assignment
438        operation. names that are assigned within the body are
439        assumed to be locally-scoped variables and are not
440        separately declared.
441
442        for def callable definitions, if the def is a top-level
443        callable then a 'stub' callable is generated which wraps
444        the current Context into a closure. if the def is not
445        top-level, it is fully rendered as a local closure.
446
447        """
448
449        # collection of all defs available to us in this scope
450        comp_idents = dict([(c.funcname, c) for c in identifiers.defs])
451        to_write = set()
452
453        # write "context.get()" for all variables we are going to
454        # need that arent in the namespace yet
455        to_write = to_write.union(identifiers.undeclared)
456
457        # write closure functions for closures that we define
458        # right here
459        to_write = to_write.union(
460                        [c.funcname for c in identifiers.closuredefs.values()])
461
462        # remove identifiers that are declared in the argument
463        # signature of the callable
464        to_write = to_write.difference(identifiers.argument_declared)
465
466        # remove identifiers that we are going to assign to.
467        # in this way we mimic Python's behavior,
468        # i.e. assignment to a variable within a block
469        # means that variable is now a "locally declared" var,
470        # which cannot be referenced beforehand.
471        to_write = to_write.difference(identifiers.locally_declared)
472
473        if self.compiler.enable_loop:
474            has_loop = "loop" in to_write
475            to_write.discard("loop")
476        else:
477            has_loop = False
478
479        # if a limiting set was sent, constraint to those items in that list
480        # (this is used for the caching decorator)
481        if limit is not None:
482            to_write = to_write.intersection(limit)
483
484        if toplevel and getattr(self.compiler, 'has_ns_imports', False):
485            self.printer.writeline("_import_ns = {}")
486            self.compiler.has_imports = True
487            for ident, ns in self.compiler.namespaces.items():
488                if 'import' in ns.attributes:
489                    self.printer.writeline(
490                            "_mako_get_namespace(context, %r)."
491                                    "_populate(_import_ns, %r)" %
492                            (
493                                ident,
494                                re.split(r'\s*,\s*', ns.attributes['import'])
495                            ))
496
497        if has_loop:
498            self.printer.writeline(
499                'loop = __M_loop = runtime.LoopStack()'
500            )
501
502        for ident in to_write:
503            if ident in comp_idents:
504                comp = comp_idents[ident]
505                if comp.is_block:
506                    if not comp.is_anonymous:
507                        self.write_def_decl(comp, identifiers)
508                    else:
509                        self.write_inline_def(comp, identifiers, nested=True)
510                else:
511                    if comp.is_root():
512                        self.write_def_decl(comp, identifiers)
513                    else:
514                        self.write_inline_def(comp, identifiers, nested=True)
515
516            elif ident in self.compiler.namespaces:
517                self.printer.writeline(
518                            "%s = _mako_get_namespace(context, %r)" %
519                                (ident, ident)
520                            )
521            else:
522                if getattr(self.compiler, 'has_ns_imports', False):
523                    if self.compiler.strict_undefined:
524                        self.printer.writelines(
525                        "%s = _import_ns.get(%r, UNDEFINED)" %
526                        (ident, ident),
527                        "if %s is UNDEFINED:" % ident,
528                            "try:",
529                                "%s = context[%r]" % (ident, ident),
530                            "except KeyError:",
531                                "raise NameError(\"'%s' is not defined\")" %
532                                    ident,
533                            None, None
534                        )
535                    else:
536                        self.printer.writeline(
537                        "%s = _import_ns.get(%r, context.get(%r, UNDEFINED))" %
538                        (ident, ident, ident))
539                else:
540                    if self.compiler.strict_undefined:
541                        self.printer.writelines(
542                            "try:",
543                                "%s = context[%r]" % (ident, ident),
544                            "except KeyError:",
545                                "raise NameError(\"'%s' is not defined\")" %
546                                    ident,
547                            None
548                        )
549                    else:
550                        self.printer.writeline(
551                            "%s = context.get(%r, UNDEFINED)" % (ident, ident)
552                        )
553
554        self.printer.writeline("__M_writer = context.writer()")
555
556    def write_def_decl(self, node, identifiers):
557        """write a locally-available callable referencing a top-level def"""
558        funcname = node.funcname
559        namedecls = node.get_argument_expressions()
560        nameargs = node.get_argument_expressions(as_call=True)
561
562        if not self.in_def and (
563                                len(self.identifiers.locally_assigned) > 0 or
564                                len(self.identifiers.argument_declared) > 0):
565            nameargs.insert(0, 'context._locals(__M_locals)')
566        else:
567            nameargs.insert(0, 'context')
568        self.printer.writeline("def %s(%s):" % (funcname, ",".join(namedecls)))
569        self.printer.writeline(
570                    "return render_%s(%s)" % (funcname, ",".join(nameargs)))
571        self.printer.writeline(None)
572
573    def write_inline_def(self, node, identifiers, nested):
574        """write a locally-available def callable inside an enclosing def."""
575
576        namedecls = node.get_argument_expressions()
577
578        decorator = node.decorator
579        if decorator:
580            self.printer.writeline(
581                        "@runtime._decorate_inline(context, %s)" % decorator)
582        self.printer.writeline(
583                        "def %s(%s):" % (node.funcname, ",".join(namedecls)))
584        filtered = len(node.filter_args.args) > 0
585        buffered = eval(node.attributes.get('buffered', 'False'))
586        cached = eval(node.attributes.get('cached', 'False'))
587        self.printer.writelines(
588            # push new frame, assign current frame to __M_caller
589            "__M_caller = context.caller_stack._push_frame()",
590            "try:"
591        )
592        if buffered or filtered or cached:
593            self.printer.writelines(
594                "context._push_buffer()",
595            )
596
597        identifiers = identifiers.branch(node, nested=nested)
598
599        self.write_variable_declares(identifiers)
600
601        self.identifier_stack.append(identifiers)
602        for n in node.nodes:
603            n.accept_visitor(self)
604        self.identifier_stack.pop()
605
606        self.write_def_finish(node, buffered, filtered, cached)
607        self.printer.writeline(None)
608        if cached:
609            self.write_cache_decorator(node, node.funcname,
610                                        namedecls, False, identifiers,
611                                        inline=True, toplevel=False)
612
613    def write_def_finish(self, node, buffered, filtered, cached,
614            callstack=True):
615        """write the end section of a rendering function, either outermost or
616        inline.
617
618        this takes into account if the rendering function was filtered,
619        buffered, etc.  and closes the corresponding try: block if any, and
620        writes code to retrieve captured content, apply filters, send proper
621        return value."""
622
623        if not buffered and not cached and not filtered:
624            self.printer.writeline("return ''")
625            if callstack:
626                self.printer.writelines(
627                    "finally:",
628                        "context.caller_stack._pop_frame()",
629                    None
630                )
631
632        if buffered or filtered or cached:
633            if buffered or cached:
634                # in a caching scenario, don't try to get a writer
635                # from the context after popping; assume the caching
636                # implemenation might be using a context with no
637                # extra buffers
638                self.printer.writelines(
639                    "finally:",
640                        "__M_buf = context._pop_buffer()"
641                )
642            else:
643                self.printer.writelines(
644                    "finally:",
645                    "__M_buf, __M_writer = context._pop_buffer_and_writer()"
646                )
647
648            if callstack:
649                self.printer.writeline("context.caller_stack._pop_frame()")
650
651            s = "__M_buf.getvalue()"
652            if filtered:
653                s = self.create_filter_callable(node.filter_args.args, s,
654                                                False)
655            self.printer.writeline(None)
656            if buffered and not cached:
657                s = self.create_filter_callable(self.compiler.buffer_filters,
658                                                s, False)
659            if buffered or cached:
660                self.printer.writeline("return %s" % s)
661            else:
662                self.printer.writelines(
663                    "__M_writer(%s)" % s,
664                    "return ''"
665                )
666
667    def write_cache_decorator(self, node_or_pagetag, name,
668                                    args, buffered, identifiers,
669                                    inline=False, toplevel=False):
670        """write a post-function decorator to replace a rendering
671            callable with a cached version of itself."""
672
673        self.printer.writeline("__M_%s = %s" % (name, name))
674        cachekey = node_or_pagetag.parsed_attributes.get('cache_key',
675                                                         repr(name))
676
677        cache_args = {}
678        if self.compiler.pagetag is not None:
679            cache_args.update(
680                (
681                    pa[6:],
682                    self.compiler.pagetag.parsed_attributes[pa]
683                )
684                for pa in self.compiler.pagetag.parsed_attributes
685                if pa.startswith('cache_') and pa != 'cache_key'
686            )
687        cache_args.update(
688            (
689                pa[6:],
690                node_or_pagetag.parsed_attributes[pa]
691            ) for pa in node_or_pagetag.parsed_attributes
692            if pa.startswith('cache_') and pa != 'cache_key'
693        )
694        if 'timeout' in cache_args:
695            cache_args['timeout'] = int(eval(cache_args['timeout']))
696
697        self.printer.writeline("def %s(%s):" % (name, ','.join(args)))
698
699        # form "arg1, arg2, arg3=arg3, arg4=arg4", etc.
700        pass_args = [
701                        "%s=%s" % ((a.split('=')[0],) * 2) if '=' in a else a
702                        for a in args
703                    ]
704
705        self.write_variable_declares(
706                            identifiers,
707                            toplevel=toplevel,
708                            limit=node_or_pagetag.undeclared_identifiers()
709                        )
710        if buffered:
711            s = "context.get('local')."\
712                "cache._ctx_get_or_create("\
713                "%s, lambda:__M_%s(%s),  context, %s__M_defname=%r)" % (
714                                cachekey, name, ','.join(pass_args),
715                                ''.join(["%s=%s, " % (k, v)
716                                for k, v in cache_args.items()]),
717                                name
718                            )
719            # apply buffer_filters
720            s = self.create_filter_callable(self.compiler.buffer_filters, s,
721                                            False)
722            self.printer.writelines("return " + s, None)
723        else:
724            self.printer.writelines(
725                    "__M_writer(context.get('local')."
726                    "cache._ctx_get_or_create("
727                    "%s, lambda:__M_%s(%s), context, %s__M_defname=%r))" %
728                    (
729                        cachekey, name, ','.join(pass_args),
730                        ''.join(["%s=%s, " % (k, v)
731                        for k, v in cache_args.items()]),
732                        name,
733                    ),
734                    "return ''",
735                None
736            )
737
738    def create_filter_callable(self, args, target, is_expression):
739        """write a filter-applying expression based on the filters
740        present in the given filter names, adjusting for the global
741        'default' filter aliases as needed."""
742
743        def locate_encode(name):
744            if re.match(r'decode\..+', name):
745                return "filters." + name
746            elif self.compiler.disable_unicode:
747                return filters.NON_UNICODE_ESCAPES.get(name, name)
748            else:
749                return filters.DEFAULT_ESCAPES.get(name, name)
750
751        if 'n' not in args:
752            if is_expression:
753                if self.compiler.pagetag:
754                    args = self.compiler.pagetag.filter_args.args + args
755                if self.compiler.default_filters:
756                    args = self.compiler.default_filters + args
757        for e in args:
758            # if filter given as a function, get just the identifier portion
759            if e == 'n':
760                continue
761            m = re.match(r'(.+?)(\(.*\))', e)
762            if m:
763                ident, fargs = m.group(1, 2)
764                f = locate_encode(ident)
765                e = f + fargs
766            else:
767                e = locate_encode(e)
768                assert e is not None
769            target = "%s(%s)" % (e, target)
770        return target
771
772    def visitExpression(self, node):
773        self.printer.start_source(node.lineno)
774        if len(node.escapes) or \
775                (
776                    self.compiler.pagetag is not None and
777                    len(self.compiler.pagetag.filter_args.args)
778                ) or \
779                len(self.compiler.default_filters):
780
781            s = self.create_filter_callable(node.escapes_code.args,
782                                            "%s" % node.text, True)
783            self.printer.writeline("__M_writer(%s)" % s)
784        else:
785            self.printer.writeline("__M_writer(%s)" % node.text)
786
787    def visitControlLine(self, node):
788        if node.isend:
789            self.printer.writeline(None)
790            if node.has_loop_context:
791                self.printer.writeline('finally:')
792                self.printer.writeline("loop = __M_loop._exit()")
793                self.printer.writeline(None)
794        else:
795            self.printer.start_source(node.lineno)
796            if self.compiler.enable_loop and node.keyword == 'for':
797                text = mangle_mako_loop(node, self.printer)
798            else:
799                text = node.text
800            self.printer.writeline(text)
801            children = node.get_children()
802            # this covers the three situations where we want to insert a pass:
803            #    1) a ternary control line with no children,
804            #    2) a primary control line with nothing but its own ternary
805            #          and end control lines, and
806            #    3) any control line with no content other than comments
807            if not children or (
808                    compat.all(isinstance(c, (parsetree.Comment,
809                                            parsetree.ControlLine))
810                             for c in children) and
811                    compat.all((node.is_ternary(c.keyword) or c.isend)
812                             for c in children
813                             if isinstance(c, parsetree.ControlLine))):
814                self.printer.writeline("pass")
815
816    def visitText(self, node):
817        self.printer.start_source(node.lineno)
818        self.printer.writeline("__M_writer(%s)" % repr(node.content))
819
820    def visitTextTag(self, node):
821        filtered = len(node.filter_args.args) > 0
822        if filtered:
823            self.printer.writelines(
824                "__M_writer = context._push_writer()",
825                "try:",
826            )
827        for n in node.nodes:
828            n.accept_visitor(self)
829        if filtered:
830            self.printer.writelines(
831                "finally:",
832                "__M_buf, __M_writer = context._pop_buffer_and_writer()",
833                "__M_writer(%s)" %
834                self.create_filter_callable(
835                                node.filter_args.args,
836                                "__M_buf.getvalue()",
837                                False),
838                None
839            )
840
841    def visitCode(self, node):
842        if not node.ismodule:
843            self.printer.start_source(node.lineno)
844            self.printer.write_indented_block(node.text)
845
846            if not self.in_def and len(self.identifiers.locally_assigned) > 0:
847                # if we are the "template" def, fudge locally
848                # declared/modified variables into the "__M_locals" dictionary,
849                # which is used for def calls within the same template,
850                # to simulate "enclosing scope"
851                self.printer.writeline(
852                    '__M_locals_builtin_stored = __M_locals_builtin()')
853                self.printer.writeline(
854                    '__M_locals.update(__M_dict_builtin([(__M_key,'
855                    ' __M_locals_builtin_stored[__M_key]) for __M_key in'
856                    ' [%s] if __M_key in __M_locals_builtin_stored]))' %
857                    ','.join([repr(x) for x in node.declared_identifiers()]))
858
859    def visitIncludeTag(self, node):
860        self.printer.start_source(node.lineno)
861        args = node.attributes.get('args')
862        if args:
863            self.printer.writeline(
864                    "runtime._include_file(context, %s, _template_uri, %s)" %
865                    (node.parsed_attributes['file'], args))
866        else:
867            self.printer.writeline(
868                        "runtime._include_file(context, %s, _template_uri)" %
869                        (node.parsed_attributes['file']))
870
871    def visitNamespaceTag(self, node):
872        pass
873
874    def visitDefTag(self, node):
875        pass
876
877    def visitBlockTag(self, node):
878        if node.is_anonymous:
879            self.printer.writeline("%s()" % node.funcname)
880        else:
881            nameargs = node.get_argument_expressions(as_call=True)
882            nameargs += ['**pageargs']
883            self.printer.writeline("if 'parent' not in context._data or "
884                                  "not hasattr(context._data['parent'], '%s'):"
885                                  % node.funcname)
886            self.printer.writeline(
887                "context['self'].%s(%s)" % (node.funcname, ",".join(nameargs)))
888            self.printer.writeline("\n")
889
890    def visitCallNamespaceTag(self, node):
891        # TODO: we can put namespace-specific checks here, such
892        # as ensure the given namespace will be imported,
893        # pre-import the namespace, etc.
894        self.visitCallTag(node)
895
896    def visitCallTag(self, node):
897        self.printer.writeline("def ccall(caller):")
898        export = ['body']
899        callable_identifiers = self.identifiers.branch(node, nested=True)
900        body_identifiers = callable_identifiers.branch(node, nested=False)
901        # we want the 'caller' passed to ccall to be used
902        # for the body() function, but for other non-body()
903        # <%def>s within <%call> we want the current caller
904        # off the call stack (if any)
905        body_identifiers.add_declared('caller')
906
907        self.identifier_stack.append(body_identifiers)
908        class DefVisitor(object):
909            def visitDefTag(s, node):
910                s.visitDefOrBase(node)
911
912            def visitBlockTag(s, node):
913                s.visitDefOrBase(node)
914
915            def visitDefOrBase(s, node):
916                self.write_inline_def(node, callable_identifiers, nested=False)
917                if not node.is_anonymous:
918                    export.append(node.funcname)
919                # remove defs that are within the <%call> from the
920                # "closuredefs" defined in the body, so they dont render twice
921                if node.funcname in body_identifiers.closuredefs:
922                    del body_identifiers.closuredefs[node.funcname]
923
924        vis = DefVisitor()
925        for n in node.nodes:
926            n.accept_visitor(vis)
927        self.identifier_stack.pop()
928
929        bodyargs = node.body_decl.get_argument_expressions()
930        self.printer.writeline("def body(%s):" % ','.join(bodyargs))
931
932        # TODO: figure out best way to specify
933        # buffering/nonbuffering (at call time would be better)
934        buffered = False
935        if buffered:
936            self.printer.writelines(
937                "context._push_buffer()",
938                "try:"
939            )
940        self.write_variable_declares(body_identifiers)
941        self.identifier_stack.append(body_identifiers)
942
943        for n in node.nodes:
944            n.accept_visitor(self)
945        self.identifier_stack.pop()
946
947        self.write_def_finish(node, buffered, False, False, callstack=False)
948        self.printer.writelines(
949            None,
950            "return [%s]" % (','.join(export)),
951            None
952        )
953
954        self.printer.writelines(
955            # push on caller for nested call
956            "context.caller_stack.nextcaller = "
957                "runtime.Namespace('caller', context, "
958                                "callables=ccall(__M_caller))",
959            "try:")
960        self.printer.start_source(node.lineno)
961        self.printer.writelines(
962                "__M_writer(%s)" % self.create_filter_callable(
963                                                    [], node.expression, True),
964            "finally:",
965                "context.caller_stack.nextcaller = None",
966            None
967        )
968
969class _Identifiers(object):
970    """tracks the status of identifier names as template code is rendered."""
971
972    def __init__(self, compiler, node=None, parent=None, nested=False):
973        if parent is not None:
974            # if we are the branch created in write_namespaces(),
975            # we don't share any context from the main body().
976            if isinstance(node, parsetree.NamespaceTag):
977                self.declared = set()
978                self.topleveldefs = util.SetLikeDict()
979            else:
980                # things that have already been declared
981                # in an enclosing namespace (i.e. names we can just use)
982                self.declared = set(parent.declared).\
983                        union([c.name for c in parent.closuredefs.values()]).\
984                        union(parent.locally_declared).\
985                        union(parent.argument_declared)
986
987                # if these identifiers correspond to a "nested"
988                # scope, it means whatever the parent identifiers
989                # had as undeclared will have been declared by that parent,
990                # and therefore we have them in our scope.
991                if nested:
992                    self.declared = self.declared.union(parent.undeclared)
993
994                # top level defs that are available
995                self.topleveldefs = util.SetLikeDict(**parent.topleveldefs)
996        else:
997            self.declared = set()
998            self.topleveldefs = util.SetLikeDict()
999
1000        self.compiler = compiler
1001
1002        # things within this level that are referenced before they
1003        # are declared (e.g. assigned to)
1004        self.undeclared = set()
1005
1006        # things that are declared locally.  some of these things
1007        # could be in the "undeclared" list as well if they are
1008        # referenced before declared
1009        self.locally_declared = set()
1010
1011        # assignments made in explicit python blocks.
1012        # these will be propagated to
1013        # the context of local def calls.
1014        self.locally_assigned = set()
1015
1016        # things that are declared in the argument
1017        # signature of the def callable
1018        self.argument_declared = set()
1019
1020        # closure defs that are defined in this level
1021        self.closuredefs = util.SetLikeDict()
1022
1023        self.node = node
1024
1025        if node is not None:
1026            node.accept_visitor(self)
1027
1028        illegal_names = self.compiler.reserved_names.intersection(
1029                                                        self.locally_declared)
1030        if illegal_names:
1031            raise exceptions.NameConflictError(
1032                "Reserved words declared in template: %s" %
1033                ", ".join(illegal_names))
1034
1035
1036    def branch(self, node, **kwargs):
1037        """create a new Identifiers for a new Node, with
1038          this Identifiers as the parent."""
1039
1040        return _Identifiers(self.compiler, node, self, **kwargs)
1041
1042    @property
1043    def defs(self):
1044        return set(self.topleveldefs.union(self.closuredefs).values())
1045
1046    def __repr__(self):
1047        return "Identifiers(declared=%r, locally_declared=%r, "\
1048                "undeclared=%r, topleveldefs=%r, closuredefs=%r, "\
1049                "argumentdeclared=%r)" %\
1050                (
1051                    list(self.declared),
1052                    list(self.locally_declared),
1053                    list(self.undeclared),
1054                    [c.name for c in self.topleveldefs.values()],
1055                    [c.name for c in self.closuredefs.values()],
1056                    self.argument_declared)
1057
1058    def check_declared(self, node):
1059        """update the state of this Identifiers with the undeclared
1060            and declared identifiers of the given node."""
1061
1062        for ident in node.undeclared_identifiers():
1063            if ident != 'context' and\
1064                    ident not in self.declared.union(self.locally_declared):
1065                self.undeclared.add(ident)
1066        for ident in node.declared_identifiers():
1067            self.locally_declared.add(ident)
1068
1069    def add_declared(self, ident):
1070        self.declared.add(ident)
1071        if ident in self.undeclared:
1072            self.undeclared.remove(ident)
1073
1074    def visitExpression(self, node):
1075        self.check_declared(node)
1076
1077    def visitControlLine(self, node):
1078        self.check_declared(node)
1079
1080    def visitCode(self, node):
1081        if not node.ismodule:
1082            self.check_declared(node)
1083            self.locally_assigned = self.locally_assigned.union(
1084                                                node.declared_identifiers())
1085
1086    def visitNamespaceTag(self, node):
1087        # only traverse into the sub-elements of a
1088        # <%namespace> tag if we are the branch created in
1089        # write_namespaces()
1090        if self.node is node:
1091            for n in node.nodes:
1092                n.accept_visitor(self)
1093
1094    def _check_name_exists(self, collection, node):
1095        existing = collection.get(node.funcname)
1096        collection[node.funcname] = node
1097        if existing is not None and \
1098            existing is not node and \
1099            (node.is_block or existing.is_block):
1100            raise exceptions.CompileException(
1101                    "%%def or %%block named '%s' already "
1102                    "exists in this template." %
1103                    node.funcname, **node.exception_kwargs)
1104
1105    def visitDefTag(self, node):
1106        if node.is_root() and not node.is_anonymous:
1107            self._check_name_exists(self.topleveldefs, node)
1108        elif node is not self.node:
1109            self._check_name_exists(self.closuredefs, node)
1110
1111        for ident in node.undeclared_identifiers():
1112            if ident != 'context' and \
1113                    ident not in self.declared.union(self.locally_declared):
1114                self.undeclared.add(ident)
1115
1116        # visit defs only one level deep
1117        if node is self.node:
1118            for ident in node.declared_identifiers():
1119                self.argument_declared.add(ident)
1120
1121            for n in node.nodes:
1122                n.accept_visitor(self)
1123
1124    def visitBlockTag(self, node):
1125        if node is not self.node and not node.is_anonymous:
1126
1127            if isinstance(self.node, parsetree.DefTag):
1128                raise exceptions.CompileException(
1129                        "Named block '%s' not allowed inside of def '%s'"
1130                        % (node.name, self.node.name), **node.exception_kwargs)
1131            elif isinstance(self.node,
1132                            (parsetree.CallTag, parsetree.CallNamespaceTag)):
1133                raise exceptions.CompileException(
1134                        "Named block '%s' not allowed inside of <%%call> tag"
1135                        % (node.name, ), **node.exception_kwargs)
1136
1137        for ident in node.undeclared_identifiers():
1138            if ident != 'context' and \
1139                    ident not in self.declared.union(self.locally_declared):
1140                self.undeclared.add(ident)
1141
1142        if not node.is_anonymous:
1143            self._check_name_exists(self.topleveldefs, node)
1144            self.undeclared.add(node.funcname)
1145        elif node is not self.node:
1146            self._check_name_exists(self.closuredefs, node)
1147        for ident in node.declared_identifiers():
1148            self.argument_declared.add(ident)
1149        for n in node.nodes:
1150            n.accept_visitor(self)
1151
1152    def visitTextTag(self, node):
1153        for ident in node.undeclared_identifiers():
1154            if ident != 'context' and \
1155                    ident not in self.declared.union(self.locally_declared):
1156                self.undeclared.add(ident)
1157
1158    def visitIncludeTag(self, node):
1159        self.check_declared(node)
1160
1161    def visitPageTag(self, node):
1162        for ident in node.declared_identifiers():
1163            self.argument_declared.add(ident)
1164        self.check_declared(node)
1165
1166    def visitCallNamespaceTag(self, node):
1167        self.visitCallTag(node)
1168
1169    def visitCallTag(self, node):
1170        if node is self.node:
1171            for ident in node.undeclared_identifiers():
1172                if ident != 'context' and \
1173                        ident not in self.declared.union(
1174                                                self.locally_declared):
1175                    self.undeclared.add(ident)
1176            for ident in node.declared_identifiers():
1177                self.argument_declared.add(ident)
1178            for n in node.nodes:
1179                n.accept_visitor(self)
1180        else:
1181            for ident in node.undeclared_identifiers():
1182                if ident != 'context' and \
1183                        ident not in self.declared.union(
1184                                                self.locally_declared):
1185                    self.undeclared.add(ident)
1186
1187
1188_FOR_LOOP = re.compile(
1189        r'^for\s+((?:\(?)\s*[A-Za-z_][A-Za-z_0-9]*'
1190        r'(?:\s*,\s*(?:[A-Za-z_][A-Za-z0-9_]*),??)*\s*(?:\)?))\s+in\s+(.*):'
1191)
1192
def mangle_mako_loop(node, printer):
    """Return the source line to emit for a ``% for`` control line.

    If nothing under *node* references ``loop``, the original text is
    returned unchanged.  Otherwise it writes
    ``loop = __M_loop._enter(<iterable>)`` and ``try:`` to *printer*,
    flags the block's last child node with ``has_loop_context``, and
    returns ``for <targets> in loop:``.

    :raises SyntaxError: if ``loop`` is used but the line does not match
      the supported ``for`` syntax.
    """
    detector = LoopVariable()
    node.accept_visitor(detector)
    if not detector.detected:
        return node.text

    node.nodes[-1].has_loop_context = True
    match = _FOR_LOOP.match(node.text)
    if not match:
        raise SyntaxError("Couldn't apply loop context: %s" % node.text)

    targets, iterable = match.group(1), match.group(2)
    printer.writelines(
        'loop = __M_loop._enter(%s)' % iterable,
        'try:'
    )
    return 'for %s in loop:' % targets
1214
1215
class LoopVariable(object):
    """A node visitor which looks for the name 'loop' within undeclared
    identifiers.

    After a node tree has been visited, ``detected`` is True if any
    visited control line, code block or expression (or one of their
    descendants) lists ``loop`` among its undeclared identifiers.
    """

    def __init__(self):
        self.detected = False

    def _loop_reference_detected(self, node):
        # Once a reference is found the answer cannot change, so skip any
        # remaining nodes instead of walking the rest of the tree.
        if self.detected:
            return
        if 'loop' in node.undeclared_identifiers():
            self.detected = True
            return
        for child in node.get_children():
            child.accept_visitor(self)
            if self.detected:
                break

    def visitControlLine(self, node):
        self._loop_reference_detected(node)

    def visitCode(self, node):
        self._loop_reference_detected(node)

    def visitExpression(self, node):
        self._loop_reference_detected(node)
1238