• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from Cython.Compiler import TypeSlots
2from Cython.Compiler.ExprNodes import not_a_constant
3import cython
4cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
5               Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
6               UtilNodes=object, Naming=object)
7
8import Nodes
9import ExprNodes
10import PyrexTypes
11import Visitor
12import Builtin
13import UtilNodes
14import Options
15import Naming
16
17from Code import UtilityCode
18from StringEncoding import EncodedString, BytesLiteral
19from Errors import error
20from ParseTreeTransforms import SkipDeclarations
21
22import copy
23import codecs
24
25try:
26    from __builtin__ import reduce
27except ImportError:
28    from functools import reduce
29
30try:
31    from __builtin__ import basestring
32except ImportError:
33    basestring = str # Python 3
34
def load_c_utility(name):
    """Return the (cached) utility code section *name* from ``Optimize.c``."""
    source_file = "Optimize.c"
    return UtilityCode.load_cached(name, source_file)
37
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip one coercion wrapper from *node*.

    If *node* is an instance of one of *coercion_nodes* (by default the
    C->Python and Python->C coercion nodes), return the wrapped ``arg``;
    otherwise return *node* unchanged.
    """
    return node.arg if isinstance(node, coercion_nodes) else node
42
def unwrap_node(node):
    """Follow a chain of ResultRefNodes down to the expression they refer to."""
    current = node
    while isinstance(current, UtilNodes.ResultRefNode):
        current = current.expression
    return current
47
def is_common_value(a, b):
    """Return True if *a* and *b* are known to denote the same value.

    Both nodes are first unwrapped from ResultRefNodes.  Two names match if
    their names are equal; two attribute accesses match if the first is not a
    Python attribute lookup, their objects are (recursively) common values and
    the attribute names are equal.  Everything else yields False.
    """
    lhs = unwrap_node(a)
    rhs = unwrap_node(b)
    if isinstance(lhs, ExprNodes.NameNode):
        return isinstance(rhs, ExprNodes.NameNode) and lhs.name == rhs.name
    if isinstance(lhs, ExprNodes.AttributeNode):
        return (isinstance(rhs, ExprNodes.AttributeNode)
                and not lhs.is_py_attr
                and is_common_value(lhs.obj, rhs.obj)
                and lhs.attribute == rhs.attribute)
    return False
56
def filter_none_node(node):
    """Treat a node whose constant value is ``None`` as if it were absent.

    Returns None for a missing node or a node with ``constant_result is None``
    (e.g. an explicit ``None`` slice bound); any other node is returned as is.
    """
    if node is None or node.constant_result is not None:
        return node
    return None
61
62class IterationTransform(Visitor.EnvTransform):
63    """Transform some common for-in loop patterns into efficient C loops:
64
65    - for-in-dict loop becomes a while loop calling PyDict_Next()
66    - for-in-enumerate is replaced by an external counter variable
67    - for-in-range loop becomes a plain C for loop
68    """
69    def visit_PrimaryCmpNode(self, node):
70        if node.is_ptr_contains():
71
72            # for t in operand2:
73            #     if operand1 == t:
74            #         res = True
75            #         break
76            # else:
77            #     res = False
78
79            pos = node.pos
80            result_ref = UtilNodes.ResultRefNode(node)
81            if isinstance(node.operand2, ExprNodes.IndexNode):
82                base_type = node.operand2.base.type.base_type
83            else:
84                base_type = node.operand2.type.base_type
85            target_handle = UtilNodes.TempHandle(base_type)
86            target = target_handle.ref(pos)
87            cmp_node = ExprNodes.PrimaryCmpNode(
88                pos, operator=u'==', operand1=node.operand1, operand2=target)
89            if_body = Nodes.StatListNode(
90                pos,
91                stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
92                         Nodes.BreakStatNode(pos)])
93            if_node = Nodes.IfStatNode(
94                pos,
95                if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
96                else_clause=None)
97            for_loop = UtilNodes.TempsBlockNode(
98                pos,
99                temps = [target_handle],
100                body = Nodes.ForInStatNode(
101                    pos,
102                    target=target,
103                    iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
104                    body=if_node,
105                    else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
106            for_loop = for_loop.analyse_expressions(self.current_env())
107            for_loop = self.visit(for_loop)
108            new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
109
110            if node.operator == 'not_in':
111                new_node = ExprNodes.NotNode(pos, operand=new_node)
112            return new_node
113
114        else:
115            self.visitchildren(node)
116            return node
117
118    def visit_ForInStatNode(self, node):
119        self.visitchildren(node)
120        return self._optimise_for_loop(node, node.iterator.sequence)
121
122    def _optimise_for_loop(self, node, iterator, reversed=False):
123        if iterator.type is Builtin.dict_type:
124            # like iterating over dict.keys()
125            if reversed:
126                # CPython raises an error here: not a sequence
127                return node
128            return self._transform_dict_iteration(
129                node, dict_obj=iterator, method=None, keys=True, values=False)
130
131        # C array (slice) iteration?
132        if iterator.type.is_ptr or iterator.type.is_array:
133            return self._transform_carray_iteration(node, iterator, reversed=reversed)
134        if iterator.type is Builtin.bytes_type:
135            return self._transform_bytes_iteration(node, iterator, reversed=reversed)
136        if iterator.type is Builtin.unicode_type:
137            return self._transform_unicode_iteration(node, iterator, reversed=reversed)
138
139        # the rest is based on function calls
140        if not isinstance(iterator, ExprNodes.SimpleCallNode):
141            return node
142
143        if iterator.args is None:
144            arg_count = iterator.arg_tuple and len(iterator.arg_tuple.args) or 0
145        else:
146            arg_count = len(iterator.args)
147            if arg_count and iterator.self is not None:
148                arg_count -= 1
149
150        function = iterator.function
151        # dict iteration?
152        if function.is_attribute and not reversed and not arg_count:
153            base_obj = iterator.self or function.obj
154            method = function.attribute
155            # in Py3, items() is equivalent to Py2's iteritems()
156            is_safe_iter = self.global_scope().context.language_level >= 3
157
158            if not is_safe_iter and method in ('keys', 'values', 'items'):
159                # try to reduce this to the corresponding .iter*() methods
160                if isinstance(base_obj, ExprNodes.SimpleCallNode):
161                    inner_function = base_obj.function
162                    if (inner_function.is_name and inner_function.name == 'dict'
163                            and inner_function.entry
164                            and inner_function.entry.is_builtin):
165                        # e.g. dict(something).items() => safe to use .iter*()
166                        is_safe_iter = True
167
168            keys = values = False
169            if method == 'iterkeys' or (is_safe_iter and method == 'keys'):
170                keys = True
171            elif method == 'itervalues' or (is_safe_iter and method == 'values'):
172                values = True
173            elif method == 'iteritems' or (is_safe_iter and method == 'items'):
174                keys = values = True
175
176            if keys or values:
177                return self._transform_dict_iteration(
178                    node, base_obj, method, keys, values)
179
180        # enumerate/reversed ?
181        if iterator.self is None and function.is_name and \
182               function.entry and function.entry.is_builtin:
183            if function.name == 'enumerate':
184                if reversed:
185                    # CPython raises an error here: not a sequence
186                    return node
187                return self._transform_enumerate_iteration(node, iterator)
188            elif function.name == 'reversed':
189                if reversed:
190                    # CPython raises an error here: not a sequence
191                    return node
192                return self._transform_reversed_iteration(node, iterator)
193
194        # range() iteration?
195        if Options.convert_range and node.target.type.is_int:
196            if iterator.self is None and function.is_name and \
197                   function.entry and function.entry.is_builtin and \
198                   function.name in ('range', 'xrange'):
199                return self._transform_range_iteration(node, iterator, reversed=reversed)
200
201        return node
202
203    def _transform_reversed_iteration(self, node, reversed_function):
204        args = reversed_function.arg_tuple.args
205        if len(args) == 0:
206            error(reversed_function.pos,
207                  "reversed() requires an iterable argument")
208            return node
209        elif len(args) > 1:
210            error(reversed_function.pos,
211                  "reversed() takes exactly 1 argument")
212            return node
213        arg = args[0]
214
215        # reversed(list/tuple) ?
216        if arg.type in (Builtin.tuple_type, Builtin.list_type):
217            node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable")
218            node.iterator.reversed = True
219            return node
220
221        return self._optimise_for_loop(node, arg, reversed=True)
222
    # C signature used by _transform_bytes_iteration() to call PyBytes_AS_STRING():
    # takes a bytes object 's' and returns its buffer as char*.
    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    # C signature used by _transform_bytes_iteration() to call PyBytes_GET_SIZE():
    # takes a bytes object 's' and returns its length as Py_ssize_t.
    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])
232
233    def _transform_bytes_iteration(self, node, slice_node, reversed=False):
234        target_type = node.target.type
235        if not target_type.is_int and target_type is not Builtin.bytes_type:
236            # bytes iteration returns bytes objects in Py2, but
237            # integers in Py3
238            return node
239
240        unpack_temp_node = UtilNodes.LetRefNode(
241            slice_node.as_none_safe_node("'NoneType' is not iterable"))
242
243        slice_base_node = ExprNodes.PythonCapiCallNode(
244            slice_node.pos, "PyBytes_AS_STRING",
245            self.PyBytes_AS_STRING_func_type,
246            args = [unpack_temp_node],
247            is_temp = 0,
248            )
249        len_node = ExprNodes.PythonCapiCallNode(
250            slice_node.pos, "PyBytes_GET_SIZE",
251            self.PyBytes_GET_SIZE_func_type,
252            args = [unpack_temp_node],
253            is_temp = 0,
254            )
255
256        return UtilNodes.LetNode(
257            unpack_temp_node,
258            self._transform_carray_iteration(
259                node,
260                ExprNodes.SliceIndexNode(
261                    slice_node.pos,
262                    base = slice_base_node,
263                    start = None,
264                    step = None,
265                    stop = len_node,
266                    type = slice_base_node.type,
267                    is_temp = 1,
268                    ),
269                reversed = reversed))
270
    # C signature of the __Pyx_PyUnicode_READ() helper called per loop iteration
    # by _transform_unicode_iteration(): returns a Py_UCS4 value given the
    # 'kind', the 'data' pointer and the character 'index'.
    PyUnicode_READ_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type, [
            PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None),
            PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None)
        ])

    # C signature of __Pyx_init_unicode_iteration() (utility "unicode_iter" in
    # Optimize.c): takes the unicode object 's' and writes its length, data
    # pointer and kind through the three out-pointers.  A return value of -1
    # signals an exception.
    init_unicode_iteration_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None),
            PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None)
        ],
        exception_value = '-1')
286
287    def _transform_unicode_iteration(self, node, slice_node, reversed=False):
288        if slice_node.is_literal:
289            # try to reduce to byte iteration for plain Latin-1 strings
290            try:
291                bytes_value = BytesLiteral(slice_node.value.encode('latin1'))
292            except UnicodeEncodeError:
293                pass
294            else:
295                bytes_slice = ExprNodes.SliceIndexNode(
296                    slice_node.pos,
297                    base=ExprNodes.BytesNode(
298                        slice_node.pos, value=bytes_value,
299                        constant_result=bytes_value,
300                        type=PyrexTypes.c_char_ptr_type).coerce_to(
301                            PyrexTypes.c_uchar_ptr_type, self.current_env()),
302                    start=None,
303                    stop=ExprNodes.IntNode(
304                        slice_node.pos, value=str(len(bytes_value)),
305                        constant_result=len(bytes_value),
306                        type=PyrexTypes.c_py_ssize_t_type),
307                    type=Builtin.unicode_type,  # hint for Python conversion
308                )
309                return self._transform_carray_iteration(node, bytes_slice, reversed)
310
311        unpack_temp_node = UtilNodes.LetRefNode(
312            slice_node.as_none_safe_node("'NoneType' is not iterable"))
313
314        start_node = ExprNodes.IntNode(
315            node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
316        length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
317        end_node = length_temp.ref(node.pos)
318        if reversed:
319            relation1, relation2 = '>', '>='
320            start_node, end_node = end_node, start_node
321        else:
322            relation1, relation2 = '<=', '<'
323
324        kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
325        data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type)
326        counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
327
328        target_value = ExprNodes.PythonCapiCallNode(
329            slice_node.pos, "__Pyx_PyUnicode_READ",
330            self.PyUnicode_READ_func_type,
331            args = [kind_temp.ref(slice_node.pos),
332                    data_temp.ref(slice_node.pos),
333                    counter_temp.ref(node.target.pos)],
334            is_temp = False,
335            )
336        if target_value.type != node.target.type:
337            target_value = target_value.coerce_to(node.target.type,
338                                                  self.current_env())
339        target_assign = Nodes.SingleAssignmentNode(
340            pos = node.target.pos,
341            lhs = node.target,
342            rhs = target_value)
343        body = Nodes.StatListNode(
344            node.pos,
345            stats = [target_assign, node.body])
346
347        loop_node = Nodes.ForFromStatNode(
348            node.pos,
349            bound1=start_node, relation1=relation1,
350            target=counter_temp.ref(node.target.pos),
351            relation2=relation2, bound2=end_node,
352            step=None, body=body,
353            else_clause=node.else_clause,
354            from_range=True)
355
356        setup_node = Nodes.ExprStatNode(
357            node.pos,
358            expr = ExprNodes.PythonCapiCallNode(
359                slice_node.pos, "__Pyx_init_unicode_iteration",
360                self.init_unicode_iteration_func_type,
361                args = [unpack_temp_node,
362                        ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos),
363                                                type=PyrexTypes.c_py_ssize_t_ptr_type),
364                        ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos),
365                                                type=PyrexTypes.c_void_ptr_ptr_type),
366                        ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos),
367                                                type=PyrexTypes.c_int_ptr_type),
368                        ],
369                is_temp = True,
370                result_is_used = False,
371                utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"),
372                ))
373        return UtilNodes.LetNode(
374            unpack_temp_node,
375            UtilNodes.TempsBlockNode(
376                node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp],
377                body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node])))
378
379    def _transform_carray_iteration(self, node, slice_node, reversed=False):
380        neg_step = False
381        if isinstance(slice_node, ExprNodes.SliceIndexNode):
382            slice_base = slice_node.base
383            start = filter_none_node(slice_node.start)
384            stop = filter_none_node(slice_node.stop)
385            step = None
386            if not stop:
387                if not slice_base.type.is_pyobject:
388                    error(slice_node.pos, "C array iteration requires known end index")
389                return node
390
391        elif isinstance(slice_node, ExprNodes.IndexNode):
392            assert isinstance(slice_node.index, ExprNodes.SliceNode)
393            slice_base = slice_node.base
394            index = slice_node.index
395            start = filter_none_node(index.start)
396            stop = filter_none_node(index.stop)
397            step = filter_none_node(index.step)
398            if step:
399                if not isinstance(step.constant_result, (int,long)) \
400                       or step.constant_result == 0 \
401                       or step.constant_result > 0 and not stop \
402                       or step.constant_result < 0 and not start:
403                    if not slice_base.type.is_pyobject:
404                        error(step.pos, "C array iteration requires known step size and end index")
405                    return node
406                else:
407                    # step sign is handled internally by ForFromStatNode
408                    step_value = step.constant_result
409                    if reversed:
410                        step_value = -step_value
411                    neg_step = step_value < 0
412                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
413                                             value=str(abs(step_value)),
414                                             constant_result=abs(step_value))
415
416        elif slice_node.type.is_array:
417            if slice_node.type.size is None:
418                error(slice_node.pos, "C array iteration requires known end index")
419                return node
420            slice_base = slice_node
421            start = None
422            stop = ExprNodes.IntNode(
423                slice_node.pos, value=str(slice_node.type.size),
424                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
425            step = None
426
427        else:
428            if not slice_node.type.is_pyobject:
429                error(slice_node.pos, "C array iteration requires known end index")
430            return node
431
432        if start:
433            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
434        if stop:
435            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
436        if stop is None:
437            if neg_step:
438                stop = ExprNodes.IntNode(
439                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
440            else:
441                error(slice_node.pos, "C array iteration requires known step size and end index")
442                return node
443
444        if reversed:
445            if not start:
446                start = ExprNodes.IntNode(slice_node.pos, value="0",  constant_result=0,
447                                          type=PyrexTypes.c_py_ssize_t_type)
448            # if step was provided, it was already negated above
449            start, stop = stop, start
450
451        ptr_type = slice_base.type
452        if ptr_type.is_array:
453            ptr_type = ptr_type.element_ptr_type()
454        carray_ptr = slice_base.coerce_to_simple(self.current_env())
455
456        if start and start.constant_result != 0:
457            start_ptr_node = ExprNodes.AddNode(
458                start.pos,
459                operand1=carray_ptr,
460                operator='+',
461                operand2=start,
462                type=ptr_type)
463        else:
464            start_ptr_node = carray_ptr
465
466        if stop and stop.constant_result != 0:
467            stop_ptr_node = ExprNodes.AddNode(
468                stop.pos,
469                operand1=ExprNodes.CloneNode(carray_ptr),
470                operator='+',
471                operand2=stop,
472                type=ptr_type
473                ).coerce_to_simple(self.current_env())
474        else:
475            stop_ptr_node = ExprNodes.CloneNode(carray_ptr)
476
477        counter = UtilNodes.TempHandle(ptr_type)
478        counter_temp = counter.ref(node.target.pos)
479
480        if slice_base.type.is_string and node.target.type.is_pyobject:
481            # special case: char* -> bytes/unicode
482            if slice_node.type is Builtin.unicode_type:
483                target_value = ExprNodes.CastNode(
484                    ExprNodes.DereferenceNode(
485                        node.target.pos, operand=counter_temp,
486                        type=ptr_type.base_type),
487                    PyrexTypes.c_py_ucs4_type).coerce_to(
488                        node.target.type, self.current_env())
489            else:
490                # char* -> bytes coercion requires slicing, not indexing
491                target_value = ExprNodes.SliceIndexNode(
492                    node.target.pos,
493                    start=ExprNodes.IntNode(node.target.pos, value='0',
494                                            constant_result=0,
495                                            type=PyrexTypes.c_int_type),
496                    stop=ExprNodes.IntNode(node.target.pos, value='1',
497                                           constant_result=1,
498                                           type=PyrexTypes.c_int_type),
499                    base=counter_temp,
500                    type=Builtin.bytes_type,
501                    is_temp=1)
502        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
503            # Allow iteration with pointer target to avoid copy.
504            target_value = counter_temp
505        else:
506            # TODO: can this safely be replaced with DereferenceNode() as above?
507            target_value = ExprNodes.IndexNode(
508                node.target.pos,
509                index=ExprNodes.IntNode(node.target.pos, value='0',
510                                        constant_result=0,
511                                        type=PyrexTypes.c_int_type),
512                base=counter_temp,
513                is_buffer_access=False,
514                type=ptr_type.base_type)
515
516        if target_value.type != node.target.type:
517            target_value = target_value.coerce_to(node.target.type,
518                                                  self.current_env())
519
520        target_assign = Nodes.SingleAssignmentNode(
521            pos = node.target.pos,
522            lhs = node.target,
523            rhs = target_value)
524
525        body = Nodes.StatListNode(
526            node.pos,
527            stats = [target_assign, node.body])
528
529        relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)
530
531        for_node = Nodes.ForFromStatNode(
532            node.pos,
533            bound1=start_ptr_node, relation1=relation1,
534            target=counter_temp,
535            relation2=relation2, bound2=stop_ptr_node,
536            step=step, body=body,
537            else_clause=node.else_clause,
538            from_range=True)
539
540        return UtilNodes.TempsBlockNode(
541            node.pos, temps=[counter],
542            body=for_node)
543
544    def _transform_enumerate_iteration(self, node, enumerate_function):
545        args = enumerate_function.arg_tuple.args
546        if len(args) == 0:
547            error(enumerate_function.pos,
548                  "enumerate() requires an iterable argument")
549            return node
550        elif len(args) > 2:
551            error(enumerate_function.pos,
552                  "enumerate() takes at most 2 arguments")
553            return node
554
555        if not node.target.is_sequence_constructor:
556            # leave this untouched for now
557            return node
558        targets = node.target.args
559        if len(targets) != 2:
560            # leave this untouched for now
561            return node
562
563        enumerate_target, iterable_target = targets
564        counter_type = enumerate_target.type
565
566        if not counter_type.is_pyobject and not counter_type.is_int:
567            # nothing we can do here, I guess
568            return node
569
570        if len(args) == 2:
571            start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
572        else:
573            start = ExprNodes.IntNode(enumerate_function.pos,
574                                      value='0',
575                                      type=counter_type,
576                                      constant_result=0)
577        temp = UtilNodes.LetRefNode(start)
578
579        inc_expression = ExprNodes.AddNode(
580            enumerate_function.pos,
581            operand1 = temp,
582            operand2 = ExprNodes.IntNode(node.pos, value='1',
583                                         type=counter_type,
584                                         constant_result=1),
585            operator = '+',
586            type = counter_type,
587            #inplace = True,   # not worth using in-place operation for Py ints
588            is_temp = counter_type.is_pyobject
589            )
590
591        loop_body = [
592            Nodes.SingleAssignmentNode(
593                pos = enumerate_target.pos,
594                lhs = enumerate_target,
595                rhs = temp),
596            Nodes.SingleAssignmentNode(
597                pos = enumerate_target.pos,
598                lhs = temp,
599                rhs = inc_expression)
600            ]
601
602        if isinstance(node.body, Nodes.StatListNode):
603            node.body.stats = loop_body + node.body.stats
604        else:
605            loop_body.append(node.body)
606            node.body = Nodes.StatListNode(
607                node.body.pos,
608                stats = loop_body)
609
610        node.target = iterable_target
611        node.item = node.item.coerce_to(iterable_target.type, self.current_env())
612        node.iterator.sequence = args[0]
613
614        # recurse into loop to check for further optimisations
615        return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))
616
617    def _find_for_from_node_relations(self, neg_step_value, reversed):
618        if reversed:
619            if neg_step_value:
620                return '<', '<='
621            else:
622                return '>', '>='
623        else:
624            if neg_step_value:
625                return '>=', '>'
626            else:
627                return '<=', '<'
628
629    def _transform_range_iteration(self, node, range_function, reversed=False):
630        args = range_function.arg_tuple.args
631        if len(args) < 3:
632            step_pos = range_function.pos
633            step_value = 1
634            step = ExprNodes.IntNode(step_pos, value='1',
635                                     constant_result=1)
636        else:
637            step = args[2]
638            step_pos = step.pos
639            if not isinstance(step.constant_result, (int, long)):
640                # cannot determine step direction
641                return node
642            step_value = step.constant_result
643            if step_value == 0:
644                # will lead to an error elsewhere
645                return node
646            if reversed and step_value not in (1, -1):
647                # FIXME: currently broken - requires calculation of the correct bounds
648                return node
649            if not isinstance(step, ExprNodes.IntNode):
650                step = ExprNodes.IntNode(step_pos, value=str(step_value),
651                                         constant_result=step_value)
652
653        if len(args) == 1:
654            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
655                                       constant_result=0)
656            bound2 = args[0].coerce_to_integer(self.current_env())
657        else:
658            bound1 = args[0].coerce_to_integer(self.current_env())
659            bound2 = args[1].coerce_to_integer(self.current_env())
660
661        relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)
662
663        if reversed:
664            bound1, bound2 = bound2, bound1
665            if step_value < 0:
666                step_value = -step_value
667        else:
668            if step_value < 0:
669                step_value = -step_value
670
671        step.value = str(step_value)
672        step.constant_result = step_value
673        step = step.coerce_to_integer(self.current_env())
674
675        if not bound2.is_literal:
676            # stop bound must be immutable => keep it in a temp var
677            bound2_is_temp = True
678            bound2 = UtilNodes.LetRefNode(bound2)
679        else:
680            bound2_is_temp = False
681
682        for_node = Nodes.ForFromStatNode(
683            node.pos,
684            target=node.target,
685            bound1=bound1, relation1=relation1,
686            relation2=relation2, bound2=bound2,
687            step=step, body=node.body,
688            else_clause=node.else_clause,
689            from_range=True)
690
691        if bound2_is_temp:
692            for_node = UtilNodes.LetNode(bound2, for_node)
693
694        return for_node
695
696    def _transform_dict_iteration(self, node, dict_obj, method, keys, values):
697        temps = []
698        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
699        temps.append(temp)
700        dict_temp = temp.ref(dict_obj.pos)
701        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
702        temps.append(temp)
703        pos_temp = temp.ref(node.pos)
704
705        key_target = value_target = tuple_target = None
706        if keys and values:
707            if node.target.is_sequence_constructor:
708                if len(node.target.args) == 2:
709                    key_target, value_target = node.target.args
710                else:
711                    # unusual case that may or may not lead to an error
712                    return node
713            else:
714                tuple_target = node.target
715        elif keys:
716            key_target = node.target
717        else:
718            value_target = node.target
719
720        if isinstance(node.body, Nodes.StatListNode):
721            body = node.body
722        else:
723            body = Nodes.StatListNode(pos = node.body.pos,
724                                      stats = [node.body])
725
726        # keep original length to guard against dict modification
727        dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
728        temps.append(dict_len_temp)
729        dict_len_temp_addr = ExprNodes.AmpersandNode(
730            node.pos, operand=dict_len_temp.ref(dict_obj.pos),
731            type=PyrexTypes.c_ptr_type(dict_len_temp.type))
732        temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
733        temps.append(temp)
734        is_dict_temp = temp.ref(node.pos)
735        is_dict_temp_addr = ExprNodes.AmpersandNode(
736            node.pos, operand=is_dict_temp,
737            type=PyrexTypes.c_ptr_type(temp.type))
738
739        iter_next_node = Nodes.DictIterationNextNode(
740            dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
741            key_target, value_target, tuple_target,
742            is_dict_temp)
743        iter_next_node = iter_next_node.analyse_expressions(self.current_env())
744        body.stats[0:0] = [iter_next_node]
745
746        if method:
747            method_node = ExprNodes.StringNode(
748                dict_obj.pos, is_identifier=True, value=method)
749            dict_obj = dict_obj.as_none_safe_node(
750                "'NoneType' object has no attribute '%s'",
751                error = "PyExc_AttributeError",
752                format_args = [method])
753        else:
754            method_node = ExprNodes.NullNode(dict_obj.pos)
755            dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable")
756
757        def flag_node(value):
758            value = value and 1 or 0
759            return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value)
760
761        result_code = [
762            Nodes.SingleAssignmentNode(
763                node.pos,
764                lhs = pos_temp,
765                rhs = ExprNodes.IntNode(node.pos, value='0',
766                                        constant_result=0)),
767            Nodes.SingleAssignmentNode(
768                dict_obj.pos,
769                lhs = dict_temp,
770                rhs = ExprNodes.PythonCapiCallNode(
771                    dict_obj.pos,
772                    "__Pyx_dict_iterator",
773                    self.PyDict_Iterator_func_type,
774                    utility_code = UtilityCode.load_cached("dict_iter", "Optimize.c"),
775                    args = [dict_obj, flag_node(dict_obj.type is Builtin.dict_type),
776                            method_node, dict_len_temp_addr, is_dict_temp_addr,
777                            ],
778                    is_temp=True,
779                )),
780            Nodes.WhileStatNode(
781                node.pos,
782                condition = None,
783                body = body,
784                else_clause = node.else_clause
785                )
786            ]
787
788        return UtilNodes.TempsBlockNode(
789            node.pos, temps=temps,
790            body=Nodes.StatListNode(
791                node.pos,
792                stats = result_code
793                ))
794
795    PyDict_Iterator_func_type = PyrexTypes.CFuncType(
796        PyrexTypes.py_object_type, [
797            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
798            PyrexTypes.CFuncTypeArg("is_dict",  PyrexTypes.c_int_type, None),
799            PyrexTypes.CFuncTypeArg("method_name",  PyrexTypes.py_object_type, None),
800            PyrexTypes.CFuncTypeArg("p_orig_length",  PyrexTypes.c_py_ssize_t_ptr_type, None),
801            PyrexTypes.CFuncTypeArg("p_is_dict",  PyrexTypes.c_int_ptr_type, None),
802            ])
803
804
805class SwitchTransform(Visitor.VisitorTransform):
806    """
807    This transformation tries to turn long if statements into C switch statements.
808    The requirement is that every clause be an (or of) var == value, where the var
809    is common among all clauses and both var and value are ints.
810    """
811    NO_MATCH = (None, None, None)
812
813    def extract_conditions(self, cond, allow_not_in):
814        while True:
815            if isinstance(cond, (ExprNodes.CoerceToTempNode,
816                                 ExprNodes.CoerceToBooleanNode)):
817                cond = cond.arg
818            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
819                # this is what we get from the FlattenInListTransform
820                cond = cond.subexpression
821            elif isinstance(cond, ExprNodes.TypecastNode):
822                cond = cond.operand
823            else:
824                break
825
826        if isinstance(cond, ExprNodes.PrimaryCmpNode):
827            if cond.cascade is not None:
828                return self.NO_MATCH
829            elif cond.is_c_string_contains() and \
830                   isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
831                not_in = cond.operator == 'not_in'
832                if not_in and not allow_not_in:
833                    return self.NO_MATCH
834                if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
835                       cond.operand2.contains_surrogates():
836                    # dealing with surrogates leads to different
837                    # behaviour on wide and narrow Unicode
838                    # platforms => refuse to optimise this case
839                    return self.NO_MATCH
840                return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
841            elif not cond.is_python_comparison():
842                if cond.operator == '==':
843                    not_in = False
844                elif allow_not_in and cond.operator == '!=':
845                    not_in = True
846                else:
847                    return self.NO_MATCH
848                # this looks somewhat silly, but it does the right
849                # checks for NameNode and AttributeNode
850                if is_common_value(cond.operand1, cond.operand1):
851                    if cond.operand2.is_literal:
852                        return not_in, cond.operand1, [cond.operand2]
853                    elif getattr(cond.operand2, 'entry', None) \
854                             and cond.operand2.entry.is_const:
855                        return not_in, cond.operand1, [cond.operand2]
856                if is_common_value(cond.operand2, cond.operand2):
857                    if cond.operand1.is_literal:
858                        return not_in, cond.operand2, [cond.operand1]
859                    elif getattr(cond.operand1, 'entry', None) \
860                             and cond.operand1.entry.is_const:
861                        return not_in, cond.operand2, [cond.operand1]
862        elif isinstance(cond, ExprNodes.BoolBinopNode):
863            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
864                allow_not_in = (cond.operator == 'and')
865                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
866                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
867                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
868                    if (not not_in_1) or allow_not_in:
869                        return not_in_1, t1, c1+c2
870        return self.NO_MATCH
871
872    def extract_in_string_conditions(self, string_literal):
873        if isinstance(string_literal, ExprNodes.UnicodeNode):
874            charvals = list(map(ord, set(string_literal.value)))
875            charvals.sort()
876            return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
877                                       constant_result=charval)
878                     for charval in charvals ]
879        else:
880            # this is a bit tricky as Py3's bytes type returns
881            # integers on iteration, whereas Py2 returns 1-char byte
882            # strings
883            characters = string_literal.value
884            characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
885            characters.sort()
886            return [ ExprNodes.CharNode(string_literal.pos, value=charval,
887                                        constant_result=charval)
888                     for charval in characters ]
889
890    def extract_common_conditions(self, common_var, condition, allow_not_in):
891        not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
892        if var is None:
893            return self.NO_MATCH
894        elif common_var is not None and not is_common_value(var, common_var):
895            return self.NO_MATCH
896        elif not (var.type.is_int or var.type.is_enum) or sum([not (cond.type.is_int or cond.type.is_enum) for cond in conditions]):
897            return self.NO_MATCH
898        return not_in, var, conditions
899
900    def has_duplicate_values(self, condition_values):
901        # duplicated values don't work in a switch statement
902        seen = set()
903        for value in condition_values:
904            if value.has_constant_result():
905                if value.constant_result in seen:
906                    return True
907                seen.add(value.constant_result)
908            else:
909                # this isn't completely safe as we don't know the
910                # final C value, but this is about the best we can do
911                try:
912                    if value.entry.cname in seen:
913                        return True
914                except AttributeError:
915                    return True  # play safe
916                seen.add(value.entry.cname)
917        return False
918
919    def visit_IfStatNode(self, node):
920        common_var = None
921        cases = []
922        for if_clause in node.if_clauses:
923            _, common_var, conditions = self.extract_common_conditions(
924                common_var, if_clause.condition, False)
925            if common_var is None:
926                self.visitchildren(node)
927                return node
928            cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
929                                              conditions = conditions,
930                                              body = if_clause.body))
931
932        condition_values = [
933            cond for case in cases for cond in case.conditions]
934        if len(condition_values) < 2:
935            self.visitchildren(node)
936            return node
937        if self.has_duplicate_values(condition_values):
938            self.visitchildren(node)
939            return node
940
941        common_var = unwrap_node(common_var)
942        switch_node = Nodes.SwitchStatNode(pos = node.pos,
943                                           test = common_var,
944                                           cases = cases,
945                                           else_clause = node.else_clause)
946        return switch_node
947
948    def visit_CondExprNode(self, node):
949        not_in, common_var, conditions = self.extract_common_conditions(
950            None, node.test, True)
951        if common_var is None \
952               or len(conditions) < 2 \
953               or self.has_duplicate_values(conditions):
954            self.visitchildren(node)
955            return node
956        return self.build_simple_switch_statement(
957            node, common_var, conditions, not_in,
958            node.true_val, node.false_val)
959
960    def visit_BoolBinopNode(self, node):
961        not_in, common_var, conditions = self.extract_common_conditions(
962            None, node, True)
963        if common_var is None \
964               or len(conditions) < 2 \
965               or self.has_duplicate_values(conditions):
966            self.visitchildren(node)
967            return node
968
969        return self.build_simple_switch_statement(
970            node, common_var, conditions, not_in,
971            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
972            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
973
974    def visit_PrimaryCmpNode(self, node):
975        not_in, common_var, conditions = self.extract_common_conditions(
976            None, node, True)
977        if common_var is None \
978               or len(conditions) < 2 \
979               or self.has_duplicate_values(conditions):
980            self.visitchildren(node)
981            return node
982
983        return self.build_simple_switch_statement(
984            node, common_var, conditions, not_in,
985            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
986            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
987
988    def build_simple_switch_statement(self, node, common_var, conditions,
989                                      not_in, true_val, false_val):
990        result_ref = UtilNodes.ResultRefNode(node)
991        true_body = Nodes.SingleAssignmentNode(
992            node.pos,
993            lhs = result_ref,
994            rhs = true_val,
995            first = True)
996        false_body = Nodes.SingleAssignmentNode(
997            node.pos,
998            lhs = result_ref,
999            rhs = false_val,
1000            first = True)
1001
1002        if not_in:
1003            true_body, false_body = false_body, true_body
1004
1005        cases = [Nodes.SwitchCaseNode(pos = node.pos,
1006                                      conditions = conditions,
1007                                      body = true_body)]
1008
1009        common_var = unwrap_node(common_var)
1010        switch_node = Nodes.SwitchStatNode(pos = node.pos,
1011                                           test = common_var,
1012                                           cases = cases,
1013                                           else_clause = false_body)
1014        replacement = UtilNodes.TempResultFromStatNode(result_ref, switch_node)
1015        return replacement
1016
1017    def visit_EvalWithTempExprNode(self, node):
1018        # drop unused expression temp from FlattenInListTransform
1019        orig_expr = node.subexpression
1020        temp_ref = node.lazy_temp
1021        self.visitchildren(node)
1022        if node.subexpression is not orig_expr:
1023            # node was restructured => check if temp is still used
1024            if not Visitor.tree_contains(node.subexpression, temp_ref):
1025                return node.subexpression
1026        return node
1027
1028    visit_Node = Visitor.VisitorTransform.recurse_to_children
1029
1030
1031class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
1032    """
1033    This transformation flattens "x in [val1, ..., valn]" into a sequential list
1034    of comparisons.
1035    """
1036
1037    def visit_PrimaryCmpNode(self, node):
1038        self.visitchildren(node)
1039        if node.cascade is not None:
1040            return node
1041        elif node.operator == 'in':
1042            conjunction = 'or'
1043            eq_or_neq = '=='
1044        elif node.operator == 'not_in':
1045            conjunction = 'and'
1046            eq_or_neq = '!='
1047        else:
1048            return node
1049
1050        if not isinstance(node.operand2, (ExprNodes.TupleNode,
1051                                          ExprNodes.ListNode,
1052                                          ExprNodes.SetNode)):
1053            return node
1054
1055        args = node.operand2.args
1056        if len(args) == 0:
1057            # note: lhs may have side effects
1058            return node
1059
1060        lhs = UtilNodes.ResultRefNode(node.operand1)
1061
1062        conds = []
1063        temps = []
1064        for arg in args:
1065            try:
1066                # Trial optimisation to avoid redundant temp
1067                # assignments.  However, since is_simple() is meant to
1068                # be called after type analysis, we ignore any errors
1069                # and just play safe in that case.
1070                is_simple_arg = arg.is_simple()
1071            except Exception:
1072                is_simple_arg = False
1073            if not is_simple_arg:
1074                # must evaluate all non-simple RHS before doing the comparisons
1075                arg = UtilNodes.LetRefNode(arg)
1076                temps.append(arg)
1077            cond = ExprNodes.PrimaryCmpNode(
1078                                pos = node.pos,
1079                                operand1 = lhs,
1080                                operator = eq_or_neq,
1081                                operand2 = arg,
1082                                cascade = None)
1083            conds.append(ExprNodes.TypecastNode(
1084                                pos = node.pos,
1085                                operand = cond,
1086                                type = PyrexTypes.c_bint_type))
1087        def concat(left, right):
1088            return ExprNodes.BoolBinopNode(
1089                                pos = node.pos,
1090                                operator = conjunction,
1091                                operand1 = left,
1092                                operand2 = right)
1093
1094        condition = reduce(concat, conds)
1095        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
1096        for temp in temps[::-1]:
1097            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
1098        return new_node
1099
1100    visit_Node = Visitor.VisitorTransform.recurse_to_children
1101
1102
1103class DropRefcountingTransform(Visitor.VisitorTransform):
1104    """Drop ref-counting in safe places.
1105    """
1106    visit_Node = Visitor.VisitorTransform.recurse_to_children
1107
1108    def visit_ParallelAssignmentNode(self, node):
1109        """
1110        Parallel swap assignments like 'a,b = b,a' are safe.
1111        """
1112        left_names, right_names = [], []
1113        left_indices, right_indices = [], []
1114        temps = []
1115
1116        for stat in node.stats:
1117            if isinstance(stat, Nodes.SingleAssignmentNode):
1118                if not self._extract_operand(stat.lhs, left_names,
1119                                             left_indices, temps):
1120                    return node
1121                if not self._extract_operand(stat.rhs, right_names,
1122                                             right_indices, temps):
1123                    return node
1124            elif isinstance(stat, Nodes.CascadedAssignmentNode):
1125                # FIXME
1126                return node
1127            else:
1128                return node
1129
1130        if left_names or right_names:
1131            # lhs/rhs names must be a non-redundant permutation
1132            lnames = [ path for path, n in left_names ]
1133            rnames = [ path for path, n in right_names ]
1134            if set(lnames) != set(rnames):
1135                return node
1136            if len(set(lnames)) != len(right_names):
1137                return node
1138
1139        if left_indices or right_indices:
1140            # base name and index of index nodes must be a
1141            # non-redundant permutation
1142            lindices = []
1143            for lhs_node in left_indices:
1144                index_id = self._extract_index_id(lhs_node)
1145                if not index_id:
1146                    return node
1147                lindices.append(index_id)
1148            rindices = []
1149            for rhs_node in right_indices:
1150                index_id = self._extract_index_id(rhs_node)
1151                if not index_id:
1152                    return node
1153                rindices.append(index_id)
1154
1155            if set(lindices) != set(rindices):
1156                return node
1157            if len(set(lindices)) != len(right_indices):
1158                return node
1159
1160            # really supporting IndexNode requires support in
1161            # __Pyx_GetItemInt(), so let's stop short for now
1162            return node
1163
1164        temp_args = [t.arg for t in temps]
1165        for temp in temps:
1166            temp.use_managed_ref = False
1167
1168        for _, name_node in left_names + right_names:
1169            if name_node not in temp_args:
1170                name_node.use_managed_ref = False
1171
1172        for index_node in left_indices + right_indices:
1173            index_node.use_managed_ref = False
1174
1175        return node
1176
1177    def _extract_operand(self, node, names, indices, temps):
1178        node = unwrap_node(node)
1179        if not node.type.is_pyobject:
1180            return False
1181        if isinstance(node, ExprNodes.CoerceToTempNode):
1182            temps.append(node)
1183            node = node.arg
1184        name_path = []
1185        obj_node = node
1186        while isinstance(obj_node, ExprNodes.AttributeNode):
1187            if obj_node.is_py_attr:
1188                return False
1189            name_path.append(obj_node.member)
1190            obj_node = obj_node.obj
1191        if isinstance(obj_node, ExprNodes.NameNode):
1192            name_path.append(obj_node.name)
1193            names.append( ('.'.join(name_path[::-1]), node) )
1194        elif isinstance(node, ExprNodes.IndexNode):
1195            if node.base.type != Builtin.list_type:
1196                return False
1197            if not node.index.type.is_int:
1198                return False
1199            if not isinstance(node.base, ExprNodes.NameNode):
1200                return False
1201            indices.append(node)
1202        else:
1203            return False
1204        return True
1205
1206    def _extract_index_id(self, index_node):
1207        base = index_node.base
1208        index = index_node.index
1209        if isinstance(index, ExprNodes.NameNode):
1210            index_val = index.name
1211        elif isinstance(index, ExprNodes.ConstNode):
1212            # FIXME:
1213            return None
1214        else:
1215            return None
1216        return (base.name, index_val)
1217
1218
1219class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1220    """Optimize some common calls to builtin types *before* the type
1221    analysis phase and *after* the declarations analysis phase.
1222
1223    This transform cannot make use of any argument types, but it can
1224    restructure the tree in a way that the type analysis phase can
1225    respond to.
1226
1227    Introducing C function calls here may not be a good idea.  Move
1228    them to the OptimizeBuiltinCalls transform instead, which runs
1229    after type analysis.
1230    """
1231    # only intercept on call nodes
1232    visit_Node = Visitor.VisitorTransform.recurse_to_children
1233
1234    def visit_SimpleCallNode(self, node):
1235        self.visitchildren(node)
1236        function = node.function
1237        if not self._function_is_builtin_name(function):
1238            return node
1239        return self._dispatch_to_handler(node, function, node.args)
1240
1241    def visit_GeneralCallNode(self, node):
1242        self.visitchildren(node)
1243        function = node.function
1244        if not self._function_is_builtin_name(function):
1245            return node
1246        arg_tuple = node.positional_args
1247        if not isinstance(arg_tuple, ExprNodes.TupleNode):
1248            return node
1249        args = arg_tuple.args
1250        return self._dispatch_to_handler(
1251            node, function, args, node.keyword_args)
1252
1253    def _function_is_builtin_name(self, function):
1254        if not function.is_name:
1255            return False
1256        env = self.current_env()
1257        entry = env.lookup(function.name)
1258        if entry is not env.builtin_scope().lookup_here(function.name):
1259            return False
1260        # if entry is None, it's at least an undeclared name, so likely builtin
1261        return True
1262
1263    def _dispatch_to_handler(self, node, function, args, kwargs=None):
1264        if kwargs is None:
1265            handler_name = '_handle_simple_function_%s' % function.name
1266        else:
1267            handler_name = '_handle_general_function_%s' % function.name
1268        handle_call = getattr(self, handler_name, None)
1269        if handle_call is not None:
1270            if kwargs is None:
1271                return handle_call(node, args)
1272            else:
1273                return handle_call(node, args, kwargs)
1274        return node
1275
1276    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
1277        node.function = ExprNodes.PythonCapiFunctionNode(
1278            node.function.pos, node.function.name, cname, func_type,
1279            utility_code = utility_code)
1280
1281    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1282        if not expected: # None or 0
1283            arg_str = ''
1284        elif isinstance(expected, basestring) or expected > 1:
1285            arg_str = '...'
1286        elif expected == 1:
1287            arg_str = 'x'
1288        else:
1289            arg_str = ''
1290        if expected is not None:
1291            expected_str = 'expected %s, ' % expected
1292        else:
1293            expected_str = ''
1294        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1295            function_name, arg_str, expected_str, len(args)))
1296
1297    # specific handlers for simple call nodes
1298
1299    def _handle_simple_function_float(self, node, pos_args):
1300        if not pos_args:
1301            return ExprNodes.FloatNode(node.pos, value='0.0')
1302        if len(pos_args) > 1:
1303            self._error_wrong_arg_count('float', node, pos_args, 1)
1304        arg_type = getattr(pos_args[0], 'type', None)
1305        if arg_type in (PyrexTypes.c_double_type, Builtin.float_type):
1306            return pos_args[0]
1307        return node
1308
1309    class YieldNodeCollector(Visitor.TreeVisitor):
1310        def __init__(self):
1311            Visitor.TreeVisitor.__init__(self)
1312            self.yield_stat_nodes = {}
1313            self.yield_nodes = []
1314
1315        visit_Node = Visitor.TreeVisitor.visitchildren
1316        # XXX: disable inlining while it's not back supported
1317        def __visit_YieldExprNode(self, node):
1318            self.yield_nodes.append(node)
1319            self.visitchildren(node)
1320
1321        def __visit_ExprStatNode(self, node):
1322            self.visitchildren(node)
1323            if node.expr in self.yield_nodes:
1324                self.yield_stat_nodes[node.expr] = node
1325
1326        def __visit_GeneratorExpressionNode(self, node):
1327            # enable when we support generic generator expressions
1328            #
1329            # everything below this node is out of scope
1330            pass
1331
1332    def _find_single_yield_expression(self, node):
1333        collector = self.YieldNodeCollector()
1334        collector.visitchildren(node)
1335        if len(collector.yield_nodes) != 1:
1336            return None, None
1337        yield_node = collector.yield_nodes[0]
1338        try:
1339            return (yield_node.arg, collector.yield_stat_nodes[yield_node])
1340        except KeyError:
1341            return None, None
1342
1343    def _handle_simple_function_all(self, node, pos_args):
1344        """Transform
1345
1346        _result = all(x for L in LL for x in L)
1347
1348        into
1349
1350        for L in LL:
1351            for x in L:
1352                if not x:
1353                    _result = False
1354                    break
1355            else:
1356                continue
1357            break
1358        else:
1359            _result = True
1360        """
1361        return self._transform_any_all(node, pos_args, False)
1362
1363    def _handle_simple_function_any(self, node, pos_args):
1364        """Transform
1365
1366        _result = any(x for L in LL for x in L)
1367
1368        into
1369
1370        for L in LL:
1371            for x in L:
1372                if x:
1373                    _result = True
1374                    break
1375            else:
1376                continue
1377            break
1378        else:
1379            _result = False
1380        """
1381        return self._transform_any_all(node, pos_args, True)
1382
1383    def _transform_any_all(self, node, pos_args, is_any):
1384        if len(pos_args) != 1:
1385            return node
1386        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1387            return node
1388        gen_expr_node = pos_args[0]
1389        loop_node = gen_expr_node.loop
1390        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1391        if yield_expression is None:
1392            return node
1393
1394        if is_any:
1395            condition = yield_expression
1396        else:
1397            condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
1398
1399        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
1400        test_node = Nodes.IfStatNode(
1401            yield_expression.pos,
1402            else_clause = None,
1403            if_clauses = [ Nodes.IfClauseNode(
1404                yield_expression.pos,
1405                condition = condition,
1406                body = Nodes.StatListNode(
1407                    node.pos,
1408                    stats = [
1409                        Nodes.SingleAssignmentNode(
1410                            node.pos,
1411                            lhs = result_ref,
1412                            rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
1413                                                     constant_result = is_any)),
1414                        Nodes.BreakStatNode(node.pos)
1415                        ])) ]
1416            )
1417        loop = loop_node
1418        while isinstance(loop.body, Nodes.LoopNode):
1419            next_loop = loop.body
1420            loop.body = Nodes.StatListNode(loop.body.pos, stats = [
1421                loop.body,
1422                Nodes.BreakStatNode(yield_expression.pos)
1423                ])
1424            next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
1425            loop = next_loop
1426        loop_node.else_clause = Nodes.SingleAssignmentNode(
1427            node.pos,
1428            lhs = result_ref,
1429            rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
1430                                     constant_result = not is_any))
1431
1432        Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
1433
1434        return ExprNodes.InlinedGeneratorExpressionNode(
1435            gen_expr_node.pos, loop = loop_node, result_node = result_ref,
1436            expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1437
1438    def _handle_simple_function_sorted(self, node, pos_args):
1439        """Transform sorted(genexpr) and sorted([listcomp]) into
1440        [listcomp].sort().  CPython just reads the iterable into a
1441        list and calls .sort() on it.  Expanding the iterable in a
1442        listcomp is still faster and the result can be sorted in
1443        place.
1444        """
1445        if len(pos_args) != 1:
1446            return node
1447        if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \
1448               and pos_args[0].type is Builtin.list_type:
1449            listcomp_node = pos_args[0]
1450            loop_node = listcomp_node.loop
1451        elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1452            gen_expr_node = pos_args[0]
1453            loop_node = gen_expr_node.loop
1454            yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1455            if yield_expression is None:
1456                return node
1457
1458            append_node = ExprNodes.ComprehensionAppendNode(
1459                yield_expression.pos, expr = yield_expression)
1460
1461            Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1462
1463            listcomp_node = ExprNodes.ComprehensionNode(
1464                gen_expr_node.pos, loop = loop_node,
1465                append = append_node, type = Builtin.list_type,
1466                expr_scope = gen_expr_node.expr_scope,
1467                has_local_scope = True)
1468            append_node.target = listcomp_node
1469        else:
1470            return node
1471
1472        result_node = UtilNodes.ResultRefNode(
1473            pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
1474        listcomp_assign_node = Nodes.SingleAssignmentNode(
1475            node.pos, lhs = result_node, rhs = listcomp_node, first = True)
1476
1477        sort_method = ExprNodes.AttributeNode(
1478            node.pos, obj = result_node, attribute = EncodedString('sort'),
1479            # entry ? type ?
1480            needs_none_check = False)
1481        sort_node = Nodes.ExprStatNode(
1482            node.pos, expr = ExprNodes.SimpleCallNode(
1483                node.pos, function = sort_method, args = []))
1484
1485        sort_node.analyse_declarations(self.current_env())
1486
1487        return UtilNodes.TempResultFromStatNode(
1488            result_node,
1489            Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
1490
1491    def _handle_simple_function_sum(self, node, pos_args):
1492        """Transform sum(genexpr) into an equivalent inlined aggregation loop.
1493        """
1494        if len(pos_args) not in (1,2):
1495            return node
1496        if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
1497                                        ExprNodes.ComprehensionNode)):
1498            return node
1499        gen_expr_node = pos_args[0]
1500        loop_node = gen_expr_node.loop
1501
1502        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
1503            yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1504            if yield_expression is None:
1505                return node
1506        else: # ComprehensionNode
1507            yield_stat_node = gen_expr_node.append
1508            yield_expression = yield_stat_node.expr
1509            try:
1510                if not yield_expression.is_literal or not yield_expression.type.is_int:
1511                    return node
1512            except AttributeError:
1513                return node # in case we don't have a type yet
1514            # special case: old Py2 backwards compatible "sum([int_const for ...])"
1515            # can safely be unpacked into a genexpr
1516
1517        if len(pos_args) == 1:
1518            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
1519        else:
1520            start = pos_args[1]
1521
1522        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
1523        add_node = Nodes.SingleAssignmentNode(
1524            yield_expression.pos,
1525            lhs = result_ref,
1526            rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
1527            )
1528
1529        Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1530
1531        exec_code = Nodes.StatListNode(
1532            node.pos,
1533            stats = [
1534                Nodes.SingleAssignmentNode(
1535                    start.pos,
1536                    lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
1537                    rhs = start,
1538                    first = True),
1539                loop_node
1540                ])
1541
1542        return ExprNodes.InlinedGeneratorExpressionNode(
1543            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1544            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
1545            has_local_scope = gen_expr_node.has_local_scope)
1546
1547    def _handle_simple_function_min(self, node, pos_args):
1548        return self._optimise_min_max(node, pos_args, '<')
1549
1550    def _handle_simple_function_max(self, node, pos_args):
1551        return self._optimise_min_max(node, pos_args, '>')
1552
1553    def _optimise_min_max(self, node, args, operator):
1554        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1555        """
1556        if len(args) <= 1:
1557            if len(args) == 1 and args[0].is_sequence_constructor:
1558                args = args[0].args
1559            else:
1560                # leave this to Python
1561                return node
1562
1563        cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
1564
1565        last_result = args[0]
1566        for arg_node in cascaded_nodes:
1567            result_ref = UtilNodes.ResultRefNode(last_result)
1568            last_result = ExprNodes.CondExprNode(
1569                arg_node.pos,
1570                true_val = arg_node,
1571                false_val = result_ref,
1572                test = ExprNodes.PrimaryCmpNode(
1573                    arg_node.pos,
1574                    operand1 = arg_node,
1575                    operator = operator,
1576                    operand2 = result_ref,
1577                    )
1578                )
1579            last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1580
1581        for ref_node in cascaded_nodes[::-1]:
1582            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1583
1584        return last_result
1585
1586    def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
1587        if not pos_args:
1588            return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
1589        # This is a bit special - for iterables (including genexps),
1590        # Python actually overallocates and resizes a newly created
1591        # tuple incrementally while reading items, which we can't
1592        # easily do without explicit node support. Instead, we read
1593        # the items into a list and then copy them into a tuple of the
1594        # final size.  This takes up to twice as much memory, but will
1595        # have to do until we have real support for genexps.
1596        result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
1597        if result is not node:
1598            return ExprNodes.AsTupleNode(node.pos, arg=result)
1599        return node
1600
1601    def _handle_simple_function_frozenset(self, node, pos_args):
1602        """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient.
1603        """
1604        if len(pos_args) != 1:
1605            return node
1606        if pos_args[0].is_sequence_constructor and not pos_args[0].args:
1607            del pos_args[0]
1608        elif isinstance(pos_args[0], ExprNodes.ListNode):
1609            pos_args[0] = pos_args[0].as_tuple()
1610        return node
1611
1612    def _handle_simple_function_list(self, node, pos_args):
1613        if not pos_args:
1614            return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
1615        return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
1616
1617    def _handle_simple_function_set(self, node, pos_args):
1618        if not pos_args:
1619            return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1620        return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
1621
1622    def _transform_list_set_genexpr(self, node, pos_args, target_type):
1623        """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1624        """
1625        if len(pos_args) > 1:
1626            return node
1627        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1628            return node
1629        gen_expr_node = pos_args[0]
1630        loop_node = gen_expr_node.loop
1631
1632        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1633        if yield_expression is None:
1634            return node
1635
1636        append_node = ExprNodes.ComprehensionAppendNode(
1637            yield_expression.pos,
1638            expr = yield_expression)
1639
1640        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1641
1642        comp = ExprNodes.ComprehensionNode(
1643            node.pos,
1644            has_local_scope = True,
1645            expr_scope = gen_expr_node.expr_scope,
1646            loop = loop_node,
1647            append = append_node,
1648            type = target_type)
1649        append_node.target = comp
1650        return comp
1651
    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.

        ``dict()`` without arguments becomes an empty dict literal.  A single
        generator expression argument is only rewritten when its single
        yield expression is a literal 2-tuple; every other call is
        returned unchanged.
        """
        if len(pos_args) == 0:
            return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        # Only "(key, value)" tuple literals can be split into key/value parts.
        if not isinstance(yield_expression, ExprNodes.TupleNode):
            return node
        if len(yield_expression.args) != 2:
            return node

        # The yield statement becomes a "result[key] = value" style append.
        append_node = ExprNodes.DictComprehensionAppendNode(
            yield_expression.pos,
            key_expr = yield_expression.args[0],
            value_expr = yield_expression.args[1])

        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)

        dictcomp = ExprNodes.ComprehensionNode(
            node.pos,
            has_local_scope = True,
            expr_scope = gen_expr_node.expr_scope,
            loop = loop_node,
            append = append_node,
            type = Builtin.dict_type)
        append_node.target = dictcomp
        return dictcomp
1689
1690    # specific handlers for general call nodes
1691
1692    def _handle_general_function_dict(self, node, pos_args, kwargs):
1693        """Replace dict(a=b,c=d,...) by the underlying keyword dict
1694        construction which is done anyway.
1695        """
1696        if len(pos_args) > 0:
1697            return node
1698        if not isinstance(kwargs, ExprNodes.DictNode):
1699            return node
1700        return kwargs
1701
1702
class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
    """Replace simple calls of a name by an InlinedDefNodeCallNode.

    A call is replaced when the 'optimize.inline_defnode_calls' directive is
    enabled, the called name has exactly one assignment, that assignment's
    value is a PyCFunctionNode, and the inlined node reports that it can
    be inlined.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def get_constant_value_node(self, name_node):
        """Return the right-hand side of the only assignment to the entry
        behind ``name_node``, or None when there is no such unique value.
        """
        state = name_node.cf_state
        if state is None or state.cf_is_null:
            return None
        entry = self.current_env().lookup(name_node.name)
        if not entry:
            return None
        assignments = entry.cf_assignments
        if not assignments or len(assignments) != 1:
            # not just a single assignment in all closures
            return None
        return assignments[0].rhs

    def visit_SimpleCallNode(self, node):
        """Inline ``node`` when it calls a name whose single assigned value
        is a PyCFunctionNode and inlining is possible.
        """
        self.visitchildren(node)
        if not self.current_directives.get('optimize.inline_defnode_calls'):
            return node
        callee_name = node.function
        if not callee_name.is_name:
            return node
        callee = self.get_constant_value_node(callee_name)
        if not isinstance(callee, ExprNodes.PyCFunctionNode):
            return node
        candidate = ExprNodes.InlinedDefNodeCallNode(
            node.pos, function_name=callee_name,
            function=callee, args=node.args)
        if not candidate.can_be_inlined():
            return node
        return self.replace(node, candidate)
1734
1735
1736class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
1737    """Optimize some common methods calls and instantiation patterns
1738    for builtin types *after* the type analysis phase.
1739
1740    Running after type analysis, this transform can only perform
1741    function replacements that do not alter the function return type
1742    in a way that was not anticipated by the type analysis.
1743    """
1744    ### cleanup to avoid redundant coercions to/from Python types
1745
1746    def _visit_PyTypeTestNode(self, node):
1747        # disabled - appears to break assignments in some cases, and
1748        # also drops a None check, which might still be required
1749        """Flatten redundant type checks after tree changes.
1750        """
1751        old_arg = node.arg
1752        self.visitchildren(node)
1753        if old_arg is node.arg or node.arg.type != node.type:
1754            return node
1755        return node.arg
1756
1757    def _visit_TypecastNode(self, node):
1758        # disabled - the user may have had a reason to put a type
1759        # cast, even if it looks redundant to Cython
1760        """
1761        Drop redundant type casts.
1762        """
1763        self.visitchildren(node)
1764        if node.type == node.operand.type:
1765            return node.operand
1766        return node
1767
1768    def visit_ExprStatNode(self, node):
1769        """
1770        Drop useless coercions.
1771        """
1772        self.visitchildren(node)
1773        if isinstance(node.expr, ExprNodes.CoerceToPyTypeNode):
1774            node.expr = node.expr.arg
1775        return node
1776
1777    def visit_CoerceToBooleanNode(self, node):
1778        """Drop redundant conversion nodes after tree changes.
1779        """
1780        self.visitchildren(node)
1781        arg = node.arg
1782        if isinstance(arg, ExprNodes.PyTypeTestNode):
1783            arg = arg.arg
1784        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1785            if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1786                return arg.arg.coerce_to_boolean(self.current_env())
1787        return node
1788
    def visit_CoerceFromPyTypeNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        Also, optimise away calls to Python's builtin int() and
        float() if the result is going to be coerced back into a C
        type anyway.

        Cases handled, after the children have been visited:
        * argument no longer a Python object -> plain C coercion;
        * int/float/bool literal -> coerce the literal directly;
        * C->Py->C round trip through a generic object -> single C coercion;
        * int()/float() call coerced to a C number -> see
          _optimise_numeric_cast_call();
        * non-buffer indexing with a C integer index -> see
          _optimise_int_indexing().
        Otherwise the node is returned unchanged.
        """
        self.visitchildren(node)
        arg = node.arg
        if not arg.type.is_pyobject:
            # no Python conversion left at all, just do a C coercion instead
            if node.type == arg.type:
                return arg
            else:
                return arg.coerce_to(node.type, self.current_env())
        # Look through a type test to the value being tested.
        if isinstance(arg, ExprNodes.PyTypeTestNode):
            arg = arg.arg
        if arg.is_literal:
            # numeric literals can be emitted directly as C values
            if (node.type.is_int and isinstance(arg, ExprNodes.IntNode) or
                    node.type.is_float and isinstance(arg, ExprNodes.FloatNode) or
                    node.type.is_int and isinstance(arg, ExprNodes.BoolNode)):
                return arg.coerce_to(node.type, self.current_env())
        elif isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            if arg.type is PyrexTypes.py_object_type:
                if node.type.assignable_from(arg.arg.type):
                    # completely redundant C->Py->C coercion
                    return arg.arg.coerce_to(node.type, self.current_env())
        elif isinstance(arg, ExprNodes.SimpleCallNode):
            if node.type.is_int or node.type.is_float:
                return self._optimise_numeric_cast_call(node, arg)
        elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
            # strip a C->Py coercion of the index to see its C type
            index_node = arg.index
            if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
                index_node = index_node.arg
            if index_node.type.is_int:
                return self._optimise_int_indexing(node, arg, index_node)
        return node
1826
1827    PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
1828        PyrexTypes.c_char_type, [
1829            PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
1830            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
1831            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
1832            ],
1833        exception_value = "((char)-1)",
1834        exception_check = True)
1835
1836    def _optimise_int_indexing(self, coerce_node, arg, index_node):
1837        env = self.current_env()
1838        bound_check_bool = env.directives['boundscheck'] and 1 or 0
1839        if arg.base.type is Builtin.bytes_type:
1840            if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1841                # bytes[index] -> char
1842                bound_check_node = ExprNodes.IntNode(
1843                    coerce_node.pos, value=str(bound_check_bool),
1844                    constant_result=bound_check_bool)
1845                node = ExprNodes.PythonCapiCallNode(
1846                    coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1847                    self.PyBytes_GetItemInt_func_type,
1848                    args=[
1849                        arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1850                        index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1851                        bound_check_node,
1852                        ],
1853                    is_temp=True,
1854                    utility_code=UtilityCode.load_cached(
1855                        'bytes_index', 'StringTools.c'))
1856                if coerce_node.type is not PyrexTypes.c_char_type:
1857                    node = node.coerce_to(coerce_node.type, env)
1858                return node
1859        return coerce_node
1860
    def _optimise_numeric_cast_call(self, node, arg):
        """Optimise ``<C number>int(x)`` / ``<C number>float(x)``.

        ``node`` is the Py->C coercion of the call result and ``arg`` the
        SimpleCallNode.  When the single argument is a C value (possibly
        wrapped in a C->Py coercion), the call is replaced by that value or
        by a C type cast to ``node.type``.  Python-object arguments and any
        other call shapes return ``node`` unchanged.
        """
        function = arg.function
        # only plain calls of a builtin name with a literal argument tuple
        if not isinstance(function, ExprNodes.NameNode) \
               or not function.type.is_builtin_type \
               or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
            return node
        args = arg.arg_tuple.args
        if len(args) != 1:
            return node
        func_arg = args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            # look through the C->Py coercion to the original C value
            func_arg = func_arg.arg
        elif func_arg.type.is_pyobject:
            # play safe: Python conversion might work on all sorts of things
            return node
        if function.name == 'int':
            if func_arg.type.is_int or node.type.is_int:
                if func_arg.type == node.type:
                    return func_arg
                elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
                    # C cast; for C float arguments this truncates like int()
                    return ExprNodes.TypecastNode(
                        node.pos, operand=func_arg, type=node.type)
        elif function.name == 'float':
            if func_arg.type.is_float or node.type.is_float:
                if func_arg.type == node.type:
                    return func_arg
                elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
                    # NOTE(review): with a C float argument and a C integer
                    # target this emits a truncating C cast - confirm that
                    # matches the Py->C coercion it replaces.
                    return ExprNodes.TypecastNode(
                        node.pos, operand=func_arg, type=node.type)
        return node
1891
1892    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1893        if not expected: # None or 0
1894            arg_str = ''
1895        elif isinstance(expected, basestring) or expected > 1:
1896            arg_str = '...'
1897        elif expected == 1:
1898            arg_str = 'x'
1899        else:
1900            arg_str = ''
1901        if expected is not None:
1902            expected_str = 'expected %s, ' % expected
1903        else:
1904            expected_str = ''
1905        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1906            function_name, arg_str, expected_str, len(args)))
1907
1908    ### generic fallbacks
1909
1910    def _handle_function(self, node, function_name, function, arg_list, kwargs):
1911        return node
1912
    def _handle_method(self, node, type_name, attr_name, function,
                       arg_list, is_unbound_method, kwargs):
        """
        Try to inject C-API calls for unbound method calls to builtin types.
        While the method declarations in Builtin.py already handle this, we
        can additionally resolve bound and unbound methods here that were
        assigned to variables ahead of time.

        ``type_name``/``attr_name`` name the type and method being called,
        ``function`` is the called attribute node and ``is_unbound_method``
        tells whether ``function.obj`` must be passed as ``self``.  Calls
        with keyword arguments, or that cannot be resolved to a C method,
        return ``node`` unchanged; otherwise a new analysed SimpleCallNode
        coerced to ``node.type`` is returned.
        """
        if kwargs:
            return node
        if not function or not function.is_attribute or not function.obj.is_name:
            # cannot track unbound method calls over more than one indirection as
            # the names might have been reassigned in the meantime
            return node
        type_entry = self.current_env().lookup(type_name)
        if not type_entry:
            return node
        # Build "TypeName.attr_name" and try to resolve it as a C method.
        method = ExprNodes.AttributeNode(
            node.function.pos,
            obj=ExprNodes.NameNode(
                function.pos,
                name=type_name,
                entry=type_entry,
                type=type_entry.type),
            attribute=attr_name,
            is_called=True).analyse_as_unbound_cmethod_node(self.current_env())
        if method is None:
            return node
        # Positional arguments come either from node.args or from a literal
        # argument tuple.
        args = node.args
        if args is None and node.arg_tuple:
            args = node.arg_tuple.args
        call_node = ExprNodes.SimpleCallNode(
            node.pos,
            function=method,
            args=args)
        if not is_unbound_method:
            # bound method: the object it was taken from becomes 'self'
            call_node.self = function.obj
        call_node.analyse_c_function_call(self.current_env())
        call_node.analysed = True
        # keep the result type that the original call node had
        return call_node.coerce_to(node.type, self.current_env())
1953
1954    ### builtin types
1955
1956    PyDict_Copy_func_type = PyrexTypes.CFuncType(
1957        Builtin.dict_type, [
1958            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
1959            ])
1960
1961    def _handle_simple_function_dict(self, node, function, pos_args):
1962        """Replace dict(some_dict) by PyDict_Copy(some_dict).
1963        """
1964        if len(pos_args) != 1:
1965            return node
1966        arg = pos_args[0]
1967        if arg.type is Builtin.dict_type:
1968            arg = arg.as_none_safe_node("'NoneType' is not iterable")
1969            return ExprNodes.PythonCapiCallNode(
1970                node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1971                args = [arg],
1972                is_temp = node.is_temp
1973                )
1974        return node
1975
1976    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
1977        Builtin.tuple_type, [
1978            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
1979            ])
1980
1981    def _handle_simple_function_tuple(self, node, function, pos_args):
1982        """Replace tuple([...]) by a call to PyList_AsTuple.
1983        """
1984        if len(pos_args) != 1:
1985            return node
1986        arg = pos_args[0]
1987        if arg.type is Builtin.tuple_type and not arg.may_be_none():
1988            return arg
1989        if arg.type is not Builtin.list_type:
1990            return node
1991        pos_args[0] = arg.as_none_safe_node(
1992            "'NoneType' object is not iterable")
1993
1994        return ExprNodes.PythonCapiCallNode(
1995            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1996            args = pos_args,
1997            is_temp = node.is_temp
1998            )
1999
2000    PySet_New_func_type = PyrexTypes.CFuncType(
2001        Builtin.set_type, [
2002            PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)
2003        ])
2004
2005    def _handle_simple_function_set(self, node, function, pos_args):
2006        if len(pos_args) != 1:
2007            return node
2008        if pos_args[0].is_sequence_constructor:
2009            # We can optimise set([x,y,z]) safely into a set literal,
2010            # but only if we create all items before adding them -
2011            # adding an item may raise an exception if it is not
2012            # hashable, but creating the later items may have
2013            # side-effects.
2014            args = []
2015            temps = []
2016            for arg in pos_args[0].args:
2017                if not arg.is_simple():
2018                    arg = UtilNodes.LetRefNode(arg)
2019                    temps.append(arg)
2020                args.append(arg)
2021            result = ExprNodes.SetNode(node.pos, is_temp=1, args=args)
2022            for temp in temps[::-1]:
2023                result = UtilNodes.EvalWithTempExprNode(temp, result)
2024            return result
2025        else:
2026            # PySet_New(it) is better than a generic Python call to set(it)
2027            return ExprNodes.PythonCapiCallNode(
2028                node.pos, "PySet_New",
2029                self.PySet_New_func_type,
2030                args=pos_args,
2031                is_temp=node.is_temp,
2032                utility_code=UtilityCode.load_cached('pyset_compat', 'Builtins.c'),
2033                py_name="set")
2034
2035    PyFrozenSet_New_func_type = PyrexTypes.CFuncType(
2036        Builtin.frozenset_type, [
2037            PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)
2038        ])
2039
2040    def _handle_simple_function_frozenset(self, node, function, pos_args):
2041        if not pos_args:
2042            pos_args = [ExprNodes.NullNode(node.pos)]
2043        elif len(pos_args) > 1:
2044            return node
2045        elif pos_args[0].type is Builtin.frozenset_type and not pos_args[0].may_be_none():
2046            return pos_args[0]
2047        # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it)
2048        return ExprNodes.PythonCapiCallNode(
2049            node.pos, "__Pyx_PyFrozenSet_New",
2050            self.PyFrozenSet_New_func_type,
2051            args=pos_args,
2052            is_temp=node.is_temp,
2053            utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'),
2054            py_name="frozenset")
2055
2056    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
2057        PyrexTypes.c_double_type, [
2058            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
2059            ],
2060        exception_value = "((double)-1)",
2061        exception_check = True)
2062
2063    def _handle_simple_function_float(self, node, function, pos_args):
2064        """Transform float() into either a C type cast or a faster C
2065        function call.
2066        """
2067        # Note: this requires the float() function to be typed as
2068        # returning a C 'double'
2069        if len(pos_args) == 0:
2070            return ExprNodes.FloatNode(
2071                node, value="0.0", constant_result=0.0
2072                ).coerce_to(Builtin.float_type, self.current_env())
2073        elif len(pos_args) != 1:
2074            self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
2075            return node
2076        func_arg = pos_args[0]
2077        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
2078            func_arg = func_arg.arg
2079        if func_arg.type is PyrexTypes.c_double_type:
2080            return func_arg
2081        elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
2082            return ExprNodes.TypecastNode(
2083                node.pos, operand=func_arg, type=node.type)
2084        return ExprNodes.PythonCapiCallNode(
2085            node.pos, "__Pyx_PyObject_AsDouble",
2086            self.PyObject_AsDouble_func_type,
2087            args = pos_args,
2088            is_temp = node.is_temp,
2089            utility_code = load_c_utility('pyobject_as_double'),
2090            py_name = "float")
2091
2092    PyNumber_Int_func_type = PyrexTypes.CFuncType(
2093        PyrexTypes.py_object_type, [
2094            PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)
2095            ])
2096
2097    def _handle_simple_function_int(self, node, function, pos_args):
2098        """Transform int() into a faster C function call.
2099        """
2100        if len(pos_args) == 0:
2101            return ExprNodes.IntNode(node, value="0", constant_result=0,
2102                                     type=PyrexTypes.py_object_type)
2103        elif len(pos_args) != 1:
2104            return node  # int(x, base)
2105        func_arg = pos_args[0]
2106        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
2107            return node  # handled in visit_CoerceFromPyTypeNode()
2108        if func_arg.type.is_pyobject and node.type.is_pyobject:
2109            return ExprNodes.PythonCapiCallNode(
2110                node.pos, "PyNumber_Int", self.PyNumber_Int_func_type,
2111                args=pos_args, is_temp=True)
2112        return node
2113
2114    def _handle_simple_function_bool(self, node, function, pos_args):
2115        """Transform bool(x) into a type coercion to a boolean.
2116        """
2117        if len(pos_args) == 0:
2118            return ExprNodes.BoolNode(
2119                node.pos, value=False, constant_result=False
2120                ).coerce_to(Builtin.bool_type, self.current_env())
2121        elif len(pos_args) != 1:
2122            self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
2123            return node
2124        else:
2125            # => !!<bint>(x)  to make sure it's exactly 0 or 1
2126            operand = pos_args[0].coerce_to_boolean(self.current_env())
2127            operand = ExprNodes.NotNode(node.pos, operand = operand)
2128            operand = ExprNodes.NotNode(node.pos, operand = operand)
2129            # coerce back to Python object as that's the result we are expecting
2130            return operand.coerce_to_pyobject(self.current_env())
2131
2132    ### builtin functions
2133
    # Signature used when len(char*) is replaced by strlen():
    # size_t strlen(char* bytes)
    Pyx_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type, [
            PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
            ])

    # Signature used when len(Py_UNICODE*) is replaced by
    # __Pyx_Py_UNICODE_strlen(): size_t f(Py_UNICODE* unicode)
    Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type, [
            PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_py_unicode_ptr_type, None)
            ])

    # Shared signature for the C-API length functions below:
    # Py_ssize_t f(object obj), with -1 signalling an error.
    PyObject_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
            ],
        exception_value="-1")

    # Lookup function (a bound dict.get): maps a builtin type to the name of
    # the C-API function/macro that computes len() for it, or None if there
    # is no direct equivalent.
    _map_to_capi_len_function = {
        Builtin.unicode_type   : "__Pyx_PyUnicode_GET_LENGTH",
        Builtin.bytes_type     : "PyBytes_GET_SIZE",
        Builtin.list_type      : "PyList_GET_SIZE",
        Builtin.tuple_type     : "PyTuple_GET_SIZE",
        Builtin.dict_type      : "PyDict_Size",
        Builtin.set_type       : "PySet_Size",
        Builtin.frozenset_type : "PySet_Size",
        }.get

    # Qualified names of extension/builtin types whose len() is read
    # directly via Py_SIZE().
    _ext_types_with_pysize = set(["cpython.array.array"])
2161
2162    def _handle_simple_function_len(self, node, function, pos_args):
2163        """Replace len(char*) by the equivalent call to strlen(),
2164        len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and
2165        len(known_builtin_type) by an equivalent C-API call.
2166        """
2167        if len(pos_args) != 1:
2168            self._error_wrong_arg_count('len', node, pos_args, 1)
2169            return node
2170        arg = pos_args[0]
2171        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2172            arg = arg.arg
2173        if arg.type.is_string:
2174            new_node = ExprNodes.PythonCapiCallNode(
2175                node.pos, "strlen", self.Pyx_strlen_func_type,
2176                args = [arg],
2177                is_temp = node.is_temp,
2178                utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c"))
2179        elif arg.type.is_pyunicode_ptr:
2180            new_node = ExprNodes.PythonCapiCallNode(
2181                node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
2182                args = [arg],
2183                is_temp = node.is_temp)
2184        elif arg.type.is_pyobject:
2185            cfunc_name = self._map_to_capi_len_function(arg.type)
2186            if cfunc_name is None:
2187                arg_type = arg.type
2188                if ((arg_type.is_extension_type or arg_type.is_builtin_type)
2189                    and arg_type.entry.qualified_name in self._ext_types_with_pysize):
2190                    cfunc_name = 'Py_SIZE'
2191                else:
2192                    return node
2193            arg = arg.as_none_safe_node(
2194                "object of type 'NoneType' has no len()")
2195            new_node = ExprNodes.PythonCapiCallNode(
2196                node.pos, cfunc_name, self.PyObject_Size_func_type,
2197                args = [arg],
2198                is_temp = node.is_temp)
2199        elif arg.type.is_unicode_char:
2200            return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
2201                                     type=node.type)
2202        else:
2203            return node
2204        if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
2205            new_node = new_node.coerce_to(node.type, self.current_env())
2206        return new_node
2207
    # C signature used to call Py_TYPE(): one object argument, declared as
    # returning the builtin 'type' type.
    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
            ])
2212
2213    def _handle_simple_function_type(self, node, function, pos_args):
2214        """Replace type(o) by a macro call to Py_TYPE(o).
2215        """
2216        if len(pos_args) != 1:
2217            return node
2218        node = ExprNodes.PythonCapiCallNode(
2219            node.pos, "Py_TYPE", self.Pyx_Type_func_type,
2220            args = pos_args,
2221            is_temp = False)
2222        return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
2223
    # C signature of a type check call: one declared object argument, C bint
    # result.  (__Pyx_TypeCheck calls also pass the type as an extra argument.)
    Py_type_check_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
            ])
2228
2229    def _handle_simple_function_isinstance(self, node, function, pos_args):
2230        """Replace isinstance() checks against builtin types by the
2231        corresponding C-API call.
2232        """
2233        if len(pos_args) != 2:
2234            return node
2235        arg, types = pos_args
2236        temp = None
2237        if isinstance(types, ExprNodes.TupleNode):
2238            types = types.args
2239            if arg.is_attribute or not arg.is_simple():
2240                arg = temp = UtilNodes.ResultRefNode(arg)
2241        elif types.type is Builtin.type_type:
2242            types = [types]
2243        else:
2244            return node
2245
2246        tests = []
2247        test_nodes = []
2248        env = self.current_env()
2249        for test_type_node in types:
2250            builtin_type = None
2251            if test_type_node.is_name:
2252                if test_type_node.entry:
2253                    entry = env.lookup(test_type_node.entry.name)
2254                    if entry and entry.type and entry.type.is_builtin_type:
2255                        builtin_type = entry.type
2256            if builtin_type is Builtin.type_type:
2257                # all types have type "type", but there's only one 'type'
2258                if entry.name != 'type' or not (
2259                        entry.scope and entry.scope.is_builtin_scope):
2260                    builtin_type = None
2261            if builtin_type is not None:
2262                type_check_function = entry.type.type_check_function(exact=False)
2263                if type_check_function in tests:
2264                    continue
2265                tests.append(type_check_function)
2266                type_check_args = [arg]
2267            elif test_type_node.type is Builtin.type_type:
2268                type_check_function = '__Pyx_TypeCheck'
2269                type_check_args = [arg, test_type_node]
2270            else:
2271                return node
2272            test_nodes.append(
2273                ExprNodes.PythonCapiCallNode(
2274                    test_type_node.pos, type_check_function, self.Py_type_check_func_type,
2275                    args = type_check_args,
2276                    is_temp = True,
2277                    ))
2278
2279        def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
2280            or_node = make_binop_node(node.pos, 'or', a, b)
2281            or_node.type = PyrexTypes.c_bint_type
2282            or_node.is_temp = True
2283            return or_node
2284
2285        test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
2286        if temp is not None:
2287            test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
2288        return test_node
2289
2290    def _handle_simple_function_ord(self, node, function, pos_args):
2291        """Unpack ord(Py_UNICODE) and ord('X').
2292        """
2293        if len(pos_args) != 1:
2294            return node
2295        arg = pos_args[0]
2296        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2297            if arg.arg.type.is_unicode_char:
2298                return ExprNodes.TypecastNode(
2299                    arg.pos, operand=arg.arg, type=PyrexTypes.c_int_type
2300                    ).coerce_to(node.type, self.current_env())
2301        elif isinstance(arg, ExprNodes.UnicodeNode):
2302            if len(arg.value) == 1:
2303                return ExprNodes.IntNode(
2304                    arg.pos, type=PyrexTypes.c_int_type,
2305                    value=str(ord(arg.value)),
2306                    constant_result=ord(arg.value)
2307                    ).coerce_to(node.type, self.current_env())
2308        elif isinstance(arg, ExprNodes.StringNode):
2309            if arg.unicode_value and len(arg.unicode_value) == 1 \
2310                    and ord(arg.unicode_value) <= 255:  # Py2/3 portability
2311                return ExprNodes.IntNode(
2312                    arg.pos, type=PyrexTypes.c_int_type,
2313                    value=str(ord(arg.unicode_value)),
2314                    constant_result=ord(arg.unicode_value)
2315                    ).coerce_to(node.type, self.current_env())
2316        return node
2317
    ### special methods

    # Signature of __Pyx_tp_new(type, args), used to replace a
    # 'T.__new__(T, ...)' call that has no keyword arguments.
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("type",   PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("args",   Builtin.tuple_type, None),
            ])

    # Signature of __Pyx_tp_new_kwargs(type, args, kwargs), used to replace a
    # 'T.__new__(T, ...)' call that also passes keyword arguments.
    Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("type",   PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("args",   Builtin.tuple_type, None),
            PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None),
        ])
2332
    def _handle_any_slot__new__(self, node, function, args,
                                is_unbound_method, kwargs=None):
        """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()

        Only unbound calls of the form ``T.__new__(T, *args[, **kwargs])`` are
        handled.  Both ``T`` occurrences must be names typed as 'type' and must
        refer to the same type: the same type entry, or the same name when an
        entry is missing.

        For an extension type of the current module whose tp_new slot function
        is known, the slot function is called directly.  Otherwise the call
        goes through the ``__Pyx_tp_new`` / ``__Pyx_tp_new_kwargs`` helpers.
        Returns the original ``node`` whenever the pattern does not match.
        """
        obj = function.obj
        if not is_unbound_method or len(args) < 1:
            return node
        type_arg = args[0]
        if not obj.is_name or not type_arg.is_name:
            # play safe
            return node
        if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
            # not a known type, play safe
            return node
        if not type_arg.type_entry or not obj.type_entry:
            if obj.name != type_arg.name:
                return node
            # otherwise, we know it's a type and we know it's the same
            # type for both - that should do
        elif type_arg.type_entry != obj.type_entry:
            # different types - may or may not lead to an error at runtime
            return node

        # remaining positional arguments are packed into a tuple node
        args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
        args_tuple = args_tuple.analyse_types(
            self.current_env(), skip_children=True)

        if type_arg.type_entry:
            ext_type = type_arg.type_entry.type
            if (ext_type.is_extension_type and ext_type.typeobj_cname and
                    ext_type.scope.global_scope() == self.current_env().global_scope()):
                # known type in current module
                tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
                slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
                if slot_func_cname:
                    # call the slot function directly; its first argument is
                    # a PyTypeObject*, so the type argument is cast below
                    cython_scope = self.context.cython_scope
                    PyTypeObjectPtr = PyrexTypes.CPtrType(
                        cython_scope.lookup('PyTypeObject').type)
                    pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
                        PyrexTypes.py_object_type, [
                            PyrexTypes.CFuncTypeArg("type",   PyTypeObjectPtr, None),
                            PyrexTypes.CFuncTypeArg("args",   PyrexTypes.py_object_type, None),
                            PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
                            ])

                    type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
                    if not kwargs:
                        # no keyword arguments: pass a NULL kwargs pointer
                        kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type)  # hack?
                    return ExprNodes.PythonCapiCallNode(
                        node.pos, slot_func_cname,
                        pyx_tp_new_kwargs_func_type,
                        args=[type_arg, args_tuple, kwargs],
                        is_temp=True)
            # otherwise fall through to the generic helpers below
        else:
            # arbitrary variable, needs a None check for safety
            type_arg = type_arg.as_none_safe_node(
                "object.__new__(X): X is not a type object (NoneType)")

        utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
        if kwargs:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
                args=[type_arg, args_tuple, kwargs],
                utility_code=utility_code,
                is_temp=node.is_temp
                )
        else:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
                args=[type_arg, args_tuple],
                utility_code=utility_code,
                is_temp=node.is_temp
            )
2406
    ### methods of builtin types

    # Signature of __Pyx_PyObject_Append(list, item): C status code result,
    # -1 signalling an exception.
    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ],
        exception_value="-1")
2415
2416    def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
2417        """Optimistic optimisation as X.append() is almost always
2418        referring to a list.
2419        """
2420        if len(args) != 2 or node.result_is_used:
2421            return node
2422
2423        return ExprNodes.PythonCapiCallNode(
2424            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2425            args=args,
2426            may_return_none=False,
2427            is_temp=node.is_temp,
2428            result_is_used=False,
2429            utility_code=load_c_utility('append')
2430        )
2431
    # Signature of __Pyx_PyByteArray_Append(bytearray, value) taking a C int
    # value; C status code result, -1 signalling an exception.
    PyByteArray_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type, [
            PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None),
            ],
        exception_value="-1")

    # Signature of __Pyx_PyByteArray_AppendObject(bytearray, value) taking a
    # Python object value; C status code result, -1 signalling an exception.
    PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type, [
            PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None),
            ],
        exception_value="-1")
2445
2446    def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
2447        if len(args) != 2:
2448            return node
2449        func_name = "__Pyx_PyByteArray_Append"
2450        func_type = self.PyByteArray_Append_func_type
2451
2452        value = unwrap_coerced_node(args[1])
2453        if value.type.is_int or isinstance(value, ExprNodes.IntNode):
2454            value = value.coerce_to(PyrexTypes.c_int_type, self.current_env())
2455            utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
2456        elif value.is_string_literal:
2457            if not value.can_coerce_to_char_literal():
2458                return node
2459            value = value.coerce_to(PyrexTypes.c_char_type, self.current_env())
2460            utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
2461        elif value.type.is_pyobject:
2462            func_name = "__Pyx_PyByteArray_AppendObject"
2463            func_type = self.PyByteArray_AppendObject_func_type
2464            utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c")
2465        else:
2466            return node
2467
2468        new_node = ExprNodes.PythonCapiCallNode(
2469            node.pos, func_name, func_type,
2470            args=[args[0], value],
2471            may_return_none=False,
2472            is_temp=node.is_temp,
2473            utility_code=utility_code,
2474        )
2475        if node.result_is_used:
2476            new_node = new_node.coerce_to(node.type, self.current_env())
2477        return new_node
2478
2479    PyObject_Pop_func_type = PyrexTypes.CFuncType(
2480        PyrexTypes.py_object_type, [
2481            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2482            ])
2483
2484    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
2485        PyrexTypes.py_object_type, [
2486            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2487            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
2488            ])
2489
2490    def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
2491        return self._handle_simple_method_object_pop(
2492            node, function, args, is_unbound_method, is_list=True)
2493
2494    def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
2495        """Optimistic optimisation as X.pop([n]) is almost always
2496        referring to a list.
2497        """
2498        if not args:
2499            return node
2500        args = args[:]
2501        if is_list:
2502            type_name = 'List'
2503            args[0] = args[0].as_none_safe_node(
2504                "'NoneType' object has no attribute '%s'",
2505                error="PyExc_AttributeError",
2506                format_args=['pop'])
2507        else:
2508            type_name = 'Object'
2509        if len(args) == 1:
2510            return ExprNodes.PythonCapiCallNode(
2511                node.pos, "__Pyx_Py%s_Pop" % type_name,
2512                self.PyObject_Pop_func_type,
2513                args=args,
2514                may_return_none=True,
2515                is_temp=node.is_temp,
2516                utility_code=load_c_utility('pop'),
2517            )
2518        elif len(args) == 2:
2519            index = unwrap_coerced_node(args[1])
2520            if is_list or isinstance(index, ExprNodes.IntNode):
2521                index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2522            if index.type.is_int:
2523                widest = PyrexTypes.widest_numeric_type(
2524                    index.type, PyrexTypes.c_py_ssize_t_type)
2525                if widest == PyrexTypes.c_py_ssize_t_type:
2526                    args[1] = index
2527                    return ExprNodes.PythonCapiCallNode(
2528                        node.pos, "__Pyx_Py%s_PopIndex" % type_name,
2529                        self.PyObject_PopIndex_func_type,
2530                        args=args,
2531                        may_return_none=True,
2532                        is_temp=node.is_temp,
2533                        utility_code=load_c_utility("pop_index"),
2534                    )
2535
2536        return node
2537
    # Generic C signature: one object argument, C status code result with -1
    # signalling an exception (used here for PyList_Sort()).
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")
2543
2544    def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
2545        """Call PyList_Sort() instead of the 0-argument l.sort().
2546        """
2547        if len(args) != 1:
2548            return node
2549        return self._substitute_method_call(
2550            node, function, "PyList_Sort", self.single_param_func_type,
2551            'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
2552
    # Signature of __Pyx_PyDict_GetItemDefault(dict, key, default).
    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])
2559
2560    def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
2561        """Replace dict.get() by a call to PyDict_GetItem().
2562        """
2563        if len(args) == 2:
2564            args.append(ExprNodes.NoneNode(node.pos))
2565        elif len(args) != 3:
2566            self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2567            return node
2568
2569        return self._substitute_method_call(
2570            node, function,
2571            "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2572            'get', is_unbound_method, args,
2573            may_return_none = True,
2574            utility_code = load_c_utility("dict_getitem_default"))
2575
    # Signature of __Pyx_PyDict_SetDefault(dict, key, default, is_safe_type).
    # The caller passes is_safe_type as 1 / 0 / -1 for a key type known to be
    # safe / known not to be / unknown.
    Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
            ])
2583
2584    def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
2585        """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem().
2586        """
2587        if len(args) == 2:
2588            args.append(ExprNodes.NoneNode(node.pos))
2589        elif len(args) != 3:
2590            self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
2591            return node
2592        key_type = args[1].type
2593        if key_type.is_builtin_type:
2594            is_safe_type = int(key_type.name in
2595                               'str bytes unicode float int long bool')
2596        elif key_type is PyrexTypes.py_object_type:
2597            is_safe_type = -1  # don't know
2598        else:
2599            is_safe_type = 0   # definitely not
2600        args.append(ExprNodes.IntNode(
2601            node.pos, value=str(is_safe_type), constant_result=is_safe_type))
2602
2603        return self._substitute_method_call(
2604            node, function,
2605            "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
2606            'setdefault', is_unbound_method, args,
2607            may_return_none=True,
2608            utility_code=load_c_utility('dict_setdefault'))
2609
2610
    ### unicode type methods

    # Signature of the unicode character predicates (Py_UNICODE_IS*() and
    # __Pyx_Py_UNICODE_ISTITLE): Py_UCS4 argument, C bint result.
    PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
            ])
2617
2618    def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
2619        if is_unbound_method or len(args) != 1:
2620            return node
2621        ustring = args[0]
2622        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2623               not ustring.arg.type.is_unicode_char:
2624            return node
2625        uchar = ustring.arg
2626        method_name = function.attribute
2627        if method_name == 'istitle':
2628            # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2629            utility_code = UtilityCode.load_cached(
2630                "py_unicode_istitle", "StringTools.c")
2631            function_name = '__Pyx_Py_UNICODE_ISTITLE'
2632        else:
2633            utility_code = None
2634            function_name = 'Py_UNICODE_%s' % method_name.upper()
2635        func_call = self._substitute_method_call(
2636            node, function,
2637            function_name, self.PyUnicode_uchar_predicate_func_type,
2638            method_name, is_unbound_method, [uchar],
2639            utility_code = utility_code)
2640        if node.type.is_pyobject:
2641            func_call = func_call.coerce_to_pyobject(self.current_env)
2642        return func_call
2643
2644    _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
2645    _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
2646    _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
2647    _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
2648    _handle_simple_method_unicode_islower   = _inject_unicode_predicate
2649    _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
2650    _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
2651    _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
2652    _handle_simple_method_unicode_isupper   = _inject_unicode_predicate
2653
    # Signature of the Py_UNICODE_TO*() character conversions: Py_UCS4 in and out.
    PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type, [
            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
            ])
2658
2659    def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
2660        if is_unbound_method or len(args) != 1:
2661            return node
2662        ustring = args[0]
2663        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2664               not ustring.arg.type.is_unicode_char:
2665            return node
2666        uchar = ustring.arg
2667        method_name = function.attribute
2668        function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2669        func_call = self._substitute_method_call(
2670            node, function,
2671            function_name, self.PyUnicode_uchar_conversion_func_type,
2672            method_name, is_unbound_method, [uchar])
2673        if node.type.is_pyobject:
2674            func_call = func_call.coerce_to_pyobject(self.current_env)
2675        return func_call
2676
2677    _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
2678    _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
2679    _handle_simple_method_unicode_title = _inject_unicode_character_conversion
2680
    # Signature of PyUnicode_Splitlines(str, keepends), returning a list.
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
            ])
2686
2687    def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
2688        """Replace unicode.splitlines(...) by a direct call to the
2689        corresponding C-API function.
2690        """
2691        if len(args) not in (1,2):
2692            self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2693            return node
2694        self._inject_bint_default_argument(node, args, 1, False)
2695
2696        return self._substitute_method_call(
2697            node, function,
2698            "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2699            'splitlines', is_unbound_method, args)
2700
    # Signature of PyUnicode_Split(str, sep, maxsplit), returning a list.
    # The caller passes a NULL 'sep' when no separator was given.
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
            ]
        )
2708
2709    def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
2710        """Replace unicode.split(...) by a direct call to the
2711        corresponding C-API function.
2712        """
2713        if len(args) not in (1,2,3):
2714            self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2715            return node
2716        if len(args) < 2:
2717            args.append(ExprNodes.NullNode(node.pos))
2718        self._inject_int_default_argument(
2719            node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2720
2721        return self._substitute_method_call(
2722            node, function,
2723            "PyUnicode_Split", self.PyUnicode_Split_func_type,
2724            'split', is_unbound_method, args)
2725
    # Signature of the __Pyx_Py<Type>_Tailmatch() helpers:
    # (str, substring, start, end, direction) -> C bint, -1 on exception.
    # The callers below pass direction -1 for startswith and +1 for endswith.
    PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None),  # bytes/str/unicode
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-1')
2735
2736    def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
2737        return self._inject_tailmatch(
2738            node, function, args, is_unbound_method, 'unicode', 'endswith',
2739            unicode_tailmatch_utility_code, +1)
2740
2741    def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
2742        return self._inject_tailmatch(
2743            node, function, args, is_unbound_method, 'unicode', 'startswith',
2744            unicode_tailmatch_utility_code, -1)
2745
2746    def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
2747                          method_name, utility_code, direction):
2748        """Replace unicode.startswith(...) and unicode.endswith(...)
2749        by a direct call to the corresponding C-API function.
2750        """
2751        if len(args) not in (2,3,4):
2752            self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4")
2753            return node
2754        self._inject_int_default_argument(
2755            node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2756        self._inject_int_default_argument(
2757            node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2758        args.append(ExprNodes.IntNode(
2759            node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2760
2761        method_call = self._substitute_method_call(
2762            node, function,
2763            "__Pyx_Py%s_Tailmatch" % type_name.capitalize(),
2764            self.PyString_Tailmatch_func_type,
2765            method_name, is_unbound_method, args,
2766            utility_code = utility_code)
2767        return method_call.coerce_to(Builtin.bool_type, self.current_env())
2768
    # Signature of PyUnicode_Find(str, substring, start, end, direction):
    # Py_SSIZE_T result, with -2 signalling an exception.
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-2')
2778
2779    def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
2780        return self._inject_unicode_find(
2781            node, function, args, is_unbound_method, 'find', +1)
2782
2783    def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
2784        return self._inject_unicode_find(
2785            node, function, args, is_unbound_method, 'rfind', -1)
2786
2787    def _inject_unicode_find(self, node, function, args, is_unbound_method,
2788                             method_name, direction):
2789        """Replace unicode.find(...) and unicode.rfind(...) by a
2790        direct call to the corresponding C-API function.
2791        """
2792        if len(args) not in (2,3,4):
2793            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2794            return node
2795        self._inject_int_default_argument(
2796            node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2797        self._inject_int_default_argument(
2798            node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2799        args.append(ExprNodes.IntNode(
2800            node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2801
2802        method_call = self._substitute_method_call(
2803            node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
2804            method_name, is_unbound_method, args)
2805        return method_call.coerce_to_pyobject(self.current_env())
2806
    # Signature of PyUnicode_Count(str, substring, start, end): Py_ssize_t
    # result, with -1 signalling an exception.
    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            ],
        exception_value = '-1')
2815
2816    def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
2817        """Replace unicode.count(...) by a direct call to the
2818        corresponding C-API function.
2819        """
2820        if len(args) not in (2,3,4):
2821            self._error_wrong_arg_count('unicode.count', node, args, "2-4")
2822            return node
2823        self._inject_int_default_argument(
2824            node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2825        self._inject_int_default_argument(
2826            node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2827
2828        method_call = self._substitute_method_call(
2829            node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
2830            'count', is_unbound_method, args)
2831        return method_call.coerce_to_pyobject(self.current_env())
2832
    # C signature of PyUnicode_Replace(str, substring, replstr, maxcount) -> unicode,
    # used by _handle_simple_method_unicode_replace().
    PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
            ])
2840
2841    def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
2842        """Replace unicode.replace(...) by a direct call to the
2843        corresponding C-API function.
2844        """
2845        if len(args) not in (3,4):
2846            self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
2847            return node
2848        self._inject_int_default_argument(
2849            node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
2850
2851        return self._substitute_method_call(
2852            node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
2853            'replace', is_unbound_method, args)
2854
    # C signature of PyUnicode_AsEncodedString(obj, encoding, errors) -> bytes;
    # NULL pointers are passed for omitted encoding/errors arguments.
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # Shared C signature of the codec-specific PyUnicode_As<Name>String(obj)
    # functions chosen via _find_special_codec_name().
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codecs that have dedicated C-API encode/decode functions.  The names are
    # turned into C function name suffixes by _find_special_codec_name()
    # (e.g. 'unicode_escape' -> 'UnicodeEscape').
    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder function) pairs, looked up once at class creation so that
    # any spelling/alias of these codecs can be matched by comparing encoders.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
2872
2873    def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
2874        """Replace unicode.encode(...) by a direct C-API call to the
2875        corresponding codec.
2876        """
2877        if len(args) < 1 or len(args) > 3:
2878            self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2879            return node
2880
2881        string_node = args[0]
2882
2883        if len(args) == 1:
2884            null_node = ExprNodes.NullNode(node.pos)
2885            return self._substitute_method_call(
2886                node, function, "PyUnicode_AsEncodedString",
2887                self.PyUnicode_AsEncodedString_func_type,
2888                'encode', is_unbound_method, [string_node, null_node, null_node])
2889
2890        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2891        if parameters is None:
2892            return node
2893        encoding, encoding_node, error_handling, error_handling_node = parameters
2894
2895        if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
2896            # constant, so try to do the encoding at compile time
2897            try:
2898                value = string_node.value.encode(encoding, error_handling)
2899            except:
2900                # well, looks like we can't
2901                pass
2902            else:
2903                value = BytesLiteral(value)
2904                value.encoding = encoding
2905                return ExprNodes.BytesNode(
2906                    string_node.pos, value=value, type=Builtin.bytes_type)
2907
2908        if encoding and error_handling == 'strict':
2909            # try to find a specific encoder function
2910            codec_name = self._find_special_codec_name(encoding)
2911            if codec_name is not None:
2912                encode_function = "PyUnicode_As%sString" % codec_name
2913                return self._substitute_method_call(
2914                    node, function, encode_function,
2915                    self.PyUnicode_AsXyzString_func_type,
2916                    'encode', is_unbound_method, [string_node])
2917
2918        return self._substitute_method_call(
2919            node, function, "PyUnicode_AsEncodedString",
2920            self.PyUnicode_AsEncodedString_func_type,
2921            'encode', is_unbound_method,
2922            [string_node, encoding_node, error_handling_node])
2923
    # Pointer to a codec-specific PyUnicode_Decode<Name>(string, size, errors)
    # function; passed as the 'decode_func' argument of the decode helpers below
    # (NULL when no dedicated codec function is used).
    PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ]))

    # Signature of the __Pyx_decode_c_string helper (StringTools.c) for char*.
    _decode_c_string_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
            ])

    # Signature shared by the __Pyx_decode_bytes and __Pyx_decode_bytearray
    # helpers (StringTools.c), which take a Python object as input.
    _decode_bytes_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
            ])

    # Signature of __Pyx_decode_cpp_string; built on first use in
    # _handle_simple_method_bytes_decode() because it needs the C++ string type.
    _decode_cpp_string_func_type = None  # lazy init
2952
    def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Handles bytes, bytearray, C ``char*`` strings and C++ strings as the
        receiver (``args`` is ``[string, encoding?, errors?]``).  A slice
        ``s[start:stop]`` as receiver is unwrapped, and its bounds are passed
        to one of the ``__Pyx_decode_*`` helpers from StringTools.c.  If the
        encoding is one of ``_special_codecs``, the matching
        ``PyUnicode_Decode<Name>`` function pointer is passed and the
        encoding argument becomes NULL.  Returns ``node`` unchanged when the
        call cannot be optimised.
        """
        if not (1 <= len(args) <= 3):
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node

        # normalise input nodes
        string_node = args[0]
        start = stop = None
        if isinstance(string_node, ExprNodes.SliceIndexNode):
            index_node = string_node
            string_node = index_node.base
            start, stop = index_node.start, index_node.stop
            # a missing or constant-zero start is replaced by a literal 0 below
            if not start or start.constant_result == 0:
                start = None
        # decode the underlying C value rather than its Python coercion
        if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
            string_node = string_node.arg

        string_type = string_node.type
        if string_type in (Builtin.bytes_type, Builtin.bytearray_type):
            # raise the same error as Python would for a None receiver
            if is_unbound_method:
                string_node = string_node.as_none_safe_node(
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                    format_args=['decode', string_type.name])
            else:
                string_node = string_node.as_none_safe_node(
                    "'NoneType' object has no attribute '%s'",
                    error="PyExc_AttributeError",
                    format_args=['decode'])
        elif not string_type.is_string and not string_type.is_cpp_string:
            # nothing to optimise here
            return node

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # the helpers take Py_ssize_t slice bounds
        if not start:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        elif not start.type.is_int:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop and not stop.type.is_int:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())

        # try to find a specific encoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            decode_function = ExprNodes.RawCNameExprNode(
                node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type,
                cname="PyUnicode_Decode%s" % codec_name)
            encoding_node = ExprNodes.NullNode(node.pos)
        else:
            decode_function = ExprNodes.NullNode(node.pos)

        # build the helper function call
        temps = []  # LetRefNodes that must be bound around the final call
        if string_type.is_string:
            # C string
            if not stop:
                # use strlen() to find the string length, just as CPython would
                if not string_node.is_name:
                    string_node = UtilNodes.LetRefNode(string_node) # used twice
                    temps.append(string_node)
                stop = ExprNodes.PythonCapiCallNode(
                    string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args=[string_node],
                    is_temp=False,
                    utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
                ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
            helper_func_type = self._decode_c_string_func_type
            utility_code_name = 'decode_c_string'
        elif string_type.is_cpp_string:
            # C++ std::string
            if not stop:
                # NOTE(review): presumably the helper clamps PY_SSIZE_T_MAX to
                # the string length - confirm in StringTools.c
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            if self._decode_cpp_string_func_type is None:
                # lazy init to reuse the C++ string type
                self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
                    Builtin.unicode_type, [
                        PyrexTypes.CFuncTypeArg("string", string_type, None),
                        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
                    ])
            helper_func_type = self._decode_cpp_string_func_type
            utility_code_name = 'decode_cpp_string'
        else:
            # Python bytes/bytearray object
            if not stop:
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            helper_func_type = self._decode_bytes_func_type
            if string_type is Builtin.bytes_type:
                utility_code_name = 'decode_bytes'
            else:
                utility_code_name = 'decode_bytearray'

        node = ExprNodes.PythonCapiCallNode(
            node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
            args=[string_node, start, stop, encoding_node, error_handling_node, decode_function],
            is_temp=node.is_temp,
            utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'),
        )

        # bind the temporaries around the call (first temp ends up outermost)
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node
3068
    # bytearray.decode() uses the same handler; it selects the
    # __Pyx_decode_bytearray helper based on the receiver's type.
    _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode
3070
3071    def _find_special_codec_name(self, encoding):
3072        try:
3073            requested_codec = codecs.getencoder(encoding)
3074        except LookupError:
3075            return None
3076        for name, codec in self._special_codecs:
3077            if codec == requested_codec:
3078                if '_' in name:
3079                    name = ''.join([s.capitalize()
3080                                    for s in name.split('_')])
3081                return name
3082        return None
3083
3084    def _unpack_encoding_and_error_mode(self, pos, args):
3085        null_node = ExprNodes.NullNode(pos)
3086
3087        if len(args) >= 2:
3088            encoding, encoding_node = self._unpack_string_and_cstring_node(args[1])
3089            if encoding_node is None:
3090                return None
3091        else:
3092            encoding = None
3093            encoding_node = null_node
3094
3095        if len(args) == 3:
3096            error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2])
3097            if error_handling_node is None:
3098                return None
3099            if error_handling == 'strict':
3100                error_handling_node = null_node
3101        else:
3102            error_handling = 'strict'
3103            error_handling_node = null_node
3104
3105        return (encoding, encoding_node, error_handling, error_handling_node)
3106
3107    def _unpack_string_and_cstring_node(self, node):
3108        if isinstance(node, ExprNodes.CoerceToPyTypeNode):
3109            node = node.arg
3110        if isinstance(node, ExprNodes.UnicodeNode):
3111            encoding = node.value
3112            node = ExprNodes.BytesNode(
3113                node.pos, value=BytesLiteral(encoding.utf8encode()),
3114                type=PyrexTypes.c_char_ptr_type)
3115        elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
3116            encoding = node.value.decode('ISO-8859-1')
3117            node = ExprNodes.BytesNode(
3118                node.pos, value=node.value, type=PyrexTypes.c_char_ptr_type)
3119        elif node.type is Builtin.bytes_type:
3120            encoding = None
3121            node = node.coerce_to(PyrexTypes.c_char_ptr_type, self.current_env())
3122        elif node.type.is_string:
3123            encoding = None
3124        else:
3125            encoding = node = None
3126        return encoding, node
3127
3128    def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method):
3129        return self._inject_tailmatch(
3130            node, function, args, is_unbound_method, 'str', 'endswith',
3131            str_tailmatch_utility_code, +1)
3132
3133    def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method):
3134        return self._inject_tailmatch(
3135            node, function, args, is_unbound_method, 'str', 'startswith',
3136            str_tailmatch_utility_code, -1)
3137
3138    def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
3139        return self._inject_tailmatch(
3140            node, function, args, is_unbound_method, 'bytes', 'endswith',
3141            bytes_tailmatch_utility_code, +1)
3142
3143    def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
3144        return self._inject_tailmatch(
3145            node, function, args, is_unbound_method, 'bytes', 'startswith',
3146            bytes_tailmatch_utility_code, -1)
3147
3148    '''   # disabled for now, enable when we consider it worth it (see StringTools.c)
3149    def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method):
3150        return self._inject_tailmatch(
3151            node, function, args, is_unbound_method, 'bytearray', 'endswith',
3152            bytes_tailmatch_utility_code, +1)
3153
3154    def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method):
3155        return self._inject_tailmatch(
3156            node, function, args, is_unbound_method, 'bytearray', 'startswith',
3157            bytes_tailmatch_utility_code, -1)
3158    '''
3159
3160    ### helpers
3161
3162    def _substitute_method_call(self, node, function, name, func_type,
3163                                attr_name, is_unbound_method, args=(),
3164                                utility_code=None, is_temp=None,
3165                                may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
3166        args = list(args)
3167        if args and not args[0].is_literal:
3168            self_arg = args[0]
3169            if is_unbound_method:
3170                self_arg = self_arg.as_none_safe_node(
3171                    "descriptor '%s' requires a '%s' object but received a 'NoneType'",
3172                    format_args=[attr_name, function.obj.name])
3173            else:
3174                self_arg = self_arg.as_none_safe_node(
3175                    "'NoneType' object has no attribute '%s'",
3176                    error = "PyExc_AttributeError",
3177                    format_args = [attr_name])
3178            args[0] = self_arg
3179        if is_temp is None:
3180            is_temp = node.is_temp
3181        return ExprNodes.PythonCapiCallNode(
3182            node.pos, name, func_type,
3183            args = args,
3184            is_temp = is_temp,
3185            utility_code = utility_code,
3186            may_return_none = may_return_none,
3187            result_is_used = node.result_is_used,
3188            )
3189
3190    def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
3191        assert len(args) >= arg_index
3192        if len(args) == arg_index:
3193            args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
3194                                          type=type, constant_result=default_value))
3195        else:
3196            args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
3197
3198    def _inject_bint_default_argument(self, node, args, arg_index, default_value):
3199        assert len(args) >= arg_index
3200        if len(args) == arg_index:
3201            default_value = bool(default_value)
3202            args.append(ExprNodes.BoolNode(node.pos, value=default_value,
3203                                           constant_result=default_value))
3204        else:
3205            args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
3206
3207
# C helpers from StringTools.c passed to _inject_tailmatch() by the
# startswith()/endswith() handlers above (the unicode variant is presumably
# used by handlers outside this chunk).
unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c')
bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c')
str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c')
3211
3212
3213class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
3214    """Calculate the result of constant expressions to store it in
3215    ``expr_node.constant_result``, and replace trivial cases by their
3216    constant result.
3217
3218    General rules:
3219
3220    - We calculate float constants to make them available to the
3221      compiler, but we do not aggregate them into a single literal
3222      node to prevent any loss of precision.
3223
3224    - We recursively calculate constants from non-literal nodes to
3225      make them available to the compiler, but we only aggregate
3226      literal nodes at each step.  Non-literal nodes are never merged
3227      into a single node.
3228    """
3229
3230    def __init__(self, reevaluate=False):
3231        """
3232        The reevaluate argument specifies whether constant values that were
3233        previously computed should be recomputed.
3234        """
3235        super(ConstantFolding, self).__init__()
3236        self.reevaluate = reevaluate
3237
3238    def _calculate_const(self, node):
3239        if (not self.reevaluate and
3240                node.constant_result is not ExprNodes.constant_value_not_set):
3241            return
3242
3243        # make sure we always set the value
3244        not_a_constant = ExprNodes.not_a_constant
3245        node.constant_result = not_a_constant
3246
3247        # check if all children are constant
3248        children = self.visitchildren(node)
3249        for child_result in children.values():
3250            if type(child_result) is list:
3251                for child in child_result:
3252                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
3253                        return
3254            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
3255                return
3256
3257        # now try to calculate the real constant value
3258        try:
3259            node.calculate_constant_result()
3260#            if node.constant_result is not ExprNodes.not_a_constant:
3261#                print node.__class__.__name__, node.constant_result
3262        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
3263            # ignore all 'normal' errors here => no constant result
3264            pass
3265        except Exception:
3266            # this looks like a real error
3267            import traceback, sys
3268            traceback.print_exc(file=sys.stdout)
3269
    # Literal node classes ordered from narrowest to widest;
    # _widest_node_class() picks the widest class among the operands.
    NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode,
                       ExprNodes.IntNode, ExprNodes.FloatNode]
3272
3273    def _widest_node_class(self, *nodes):
3274        try:
3275            return self.NODE_TYPE_ORDER[
3276                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
3277        except ValueError:
3278            return None
3279
3280    def _bool_node(self, node, value):
3281        value = bool(value)
3282        return ExprNodes.BoolNode(node.pos, value=value, constant_result=value)
3283
3284    def visit_ExprNode(self, node):
3285        self._calculate_const(node)
3286        return node
3287
3288    def visit_UnopNode(self, node):
3289        self._calculate_const(node)
3290        if not node.has_constant_result():
3291            if node.operator == '!':
3292                return self._handle_NotNode(node)
3293            return node
3294        if not node.operand.is_literal:
3295            return node
3296        if node.operator == '!':
3297            return self._bool_node(node, node.constant_result)
3298        elif isinstance(node.operand, ExprNodes.BoolNode):
3299            return ExprNodes.IntNode(node.pos, value=str(int(node.constant_result)),
3300                                     type=PyrexTypes.c_int_type,
3301                                     constant_result=int(node.constant_result))
3302        elif node.operator == '+':
3303            return self._handle_UnaryPlusNode(node)
3304        elif node.operator == '-':
3305            return self._handle_UnaryMinusNode(node)
3306        return node
3307
    # Maps a membership/identity comparison operator to its negation.  This is a
    # bound dict.get (not a method), so unknown operators map to None; used by
    # _handle_NotNode().
    _negate_operator = {
        'in': 'not_in',
        'not_in': 'in',
        'is': 'is_not',
        'is_not': 'is'
    }.get
3314
3315    def _handle_NotNode(self, node):
3316        operand = node.operand
3317        if isinstance(operand, ExprNodes.PrimaryCmpNode):
3318            operator = self._negate_operator(operand.operator)
3319            if operator:
3320                node = copy.copy(operand)
3321                node.operator = operator
3322                node = self.visit_PrimaryCmpNode(node)
3323        return node
3324
3325    def _handle_UnaryMinusNode(self, node):
3326        def _negate(value):
3327            if value.startswith('-'):
3328                value = value[1:]
3329            else:
3330                value = '-' + value
3331            return value
3332
3333        node_type = node.operand.type
3334        if isinstance(node.operand, ExprNodes.FloatNode):
3335            # this is a safe operation
3336            return ExprNodes.FloatNode(node.pos, value=_negate(node.operand.value),
3337                                       type=node_type,
3338                                       constant_result=node.constant_result)
3339        if node_type.is_int and node_type.signed or \
3340                isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
3341            return ExprNodes.IntNode(node.pos, value=_negate(node.operand.value),
3342                                     type=node_type,
3343                                     longness=node.operand.longness,
3344                                     constant_result=node.constant_result)
3345        return node
3346
3347    def _handle_UnaryPlusNode(self, node):
3348        if (node.operand.has_constant_result() and
3349                    node.constant_result == node.operand.constant_result):
3350            return node.operand
3351        return node
3352
3353    def visit_BoolBinopNode(self, node):
3354        self._calculate_const(node)
3355        if not node.operand1.has_constant_result():
3356            return node
3357        if node.operand1.constant_result:
3358            if node.operator == 'and':
3359                return node.operand2
3360            else:
3361                return node.operand1
3362        else:
3363            if node.operator == 'and':
3364                return node.operand1
3365            else:
3366                return node.operand2
3367
    def visit_BinopNode(self, node):
        """Replace a binary operation on two literal operands by a literal
        holding its calculated constant result.

        Float results are never aggregated (see class docstring).  The new
        literal's class is the widest of the operand classes
        (NODE_TYPE_ORDER), promoted to IntNode where C arithmetic on
        bool/char yields an int.  Returns ``node`` unchanged if folding is
        not possible.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type

        target_class = self._widest_node_class(operand1, operand2)
        # note: the 'in' checks below are substring tests against a string of
        # operator spellings
        if target_class is None:
            return node
        elif target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>':
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode
        elif target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^':
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode

        if target_class is ExprNodes.IntNode:
            # the result carries an 'unsigned' suffix only if both operands do,
            # and the longer of the two 'longness' suffixes (at most "LL")
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned=unsigned, longness=longness,
                                         value=str(int(node.constant_result)),
                                         constant_result=int(node.constant_result))
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            # BoolNode keeps the raw value; other literal classes take source text
            if target_class is ExprNodes.BoolNode:
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = node_value,
                                    constant_result = node.constant_result)
        return new_node
3425
3426    def visit_MulNode(self, node):
3427        self._calculate_const(node)
3428        if node.operand1.is_sequence_constructor:
3429            return self._calculate_constant_seq(node, node.operand1, node.operand2)
3430        if isinstance(node.operand1, ExprNodes.IntNode) and \
3431                node.operand2.is_sequence_constructor:
3432            return self._calculate_constant_seq(node, node.operand2, node.operand1)
3433        return self.visit_BinopNode(node)
3434
    def _calculate_constant_seq(self, node, sequence_node, factor):
        """Merge the multiplication ``sequence_node * factor`` into the sequence node.

        - A factor equal to 1, or an empty sequence, leaves the sequence as is.
        - An integer factor <= 0 empties the sequence.
        - An existing integer ``mult_factor`` is multiplied by an integer factor.
        - Otherwise the factor becomes the sequence's ``mult_factor``.

        If existing and new factors cannot be combined, the whole ``node`` is
        folded as a normal binop instead.  Returns the resulting node.
        """
        # NOTE(review): for an empty sequence the factor expression is dropped
        # entirely; if it has side effects they are lost - confirm intended.
        # 'long' is the Python 2 integer type.
        if factor.constant_result != 1 and sequence_node.args:
            if isinstance(factor.constant_result, (int, long)) and factor.constant_result <= 0:
                del sequence_node.args[:]
                sequence_node.mult_factor = None
            elif sequence_node.mult_factor is not None:
                if (isinstance(factor.constant_result, (int, long)) and
                        isinstance(sequence_node.mult_factor.constant_result, (int, long))):
                    value = sequence_node.mult_factor.constant_result * factor.constant_result
                    sequence_node.mult_factor = ExprNodes.IntNode(
                        sequence_node.mult_factor.pos,
                        value=str(value), constant_result=value)
                else:
                    # don't know if we can combine the factors, so don't
                    return self.visit_BinopNode(node)
            else:
                sequence_node.mult_factor = factor
        return sequence_node
3453
    def visit_PrimaryCmpNode(self, node):
        """Fold the constant parts of a (possibly cascaded) comparison.

        First, every step of the comparison chain whose two operands both
        have constant results gets its own constant result computed
        (evaluation errors simply leave it non-constant). Then:

        - a plain, non-cascaded comparison with a constant result is replaced
          via ``self._bool_node``;
        - in a cascade (``a < b < c ...``), constant-True steps are removed,
          splitting the chain into independent partial cascades that are
          re-joined with ``and`` (BoolBinopNode);
        - the first constant-False step ends the chain and is appended as a
          final constant False operand (steps after it are dropped);
        - if every step was constant True, the result is ``True``.

        Returns the replacement node, or ``node`` itself.
        """
        # calculate constant partial results in the comparison cascade
        self.visitchildren(node, ['operand1'])
        left_node = node.operand1
        cmp_node = node
        while cmp_node is not None:
            self.visitchildren(cmp_node, ['operand2'])
            right_node = cmp_node.operand2
            # reset first: stays not_a_constant unless both operands are
            # constant and the calculation below succeeds
            cmp_node.constant_result = not_a_constant
            if left_node.has_constant_result() and right_node.has_constant_result():
                try:
                    cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
                except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
                    pass  # ignore all 'normal' errors here => no constant result
            # in a cascade, each step's right operand is the next step's left operand
            left_node = right_node
            cmp_node = cmp_node.cascade

        if not node.cascade:
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
            return node

        # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
        cascades = [[node.operand1]]
        # one-element list used as a mutable cell that the nested function can
        # fill (presumably because this Python 2 era code cannot use 'nonlocal')
        final_false_result = []

        def split_cascades(cmp_node):
            # Walk the chain recursively (one call per step), distributing the
            # steps into 'cascades'; stops at the first constant-False step.
            if cmp_node.has_constant_result():
                if not cmp_node.constant_result:
                    # False => short-circuit
                    final_false_result.append(self._bool_node(cmp_node, False))
                    return
                else:
                    # True => discard and start new cascade.
                    # Its right operand is reused as the new cascade's left
                    # value; that is safe because a step only has a constant
                    # result when both of its operands are constant.
                    cascades.append([cmp_node.operand2])
            else:
                # not constant => append to current cascade
                cascades[-1].append(cmp_node)
            if cmp_node.cascade:
                split_cascades(cmp_node.cascade)

        split_cascades(node)

        cmp_nodes = []
        for cascade in cascades:
            if len(cascade) < 2:
                # only the leading value, no comparison left => nothing to emit
                continue
            # the first step becomes a fresh PrimaryCmpNode whose left operand
            # is the cascade's leading value ...
            cmp_node = cascade[1]
            pcmp_node = ExprNodes.PrimaryCmpNode(
                cmp_node.pos,
                operand1=cascade[0],
                operator=cmp_node.operator,
                operand2=cmp_node.operand2,
                constant_result=not_a_constant)
            cmp_nodes.append(pcmp_node)

            # ... and the remaining steps are re-linked as its cascade,
            # cutting off whatever followed them in the original chain
            last_cmp_node = pcmp_node
            for cmp_node in cascade[2:]:
                last_cmp_node.cascade = cmp_node
                last_cmp_node = cmp_node
            last_cmp_node.cascade = None

        if final_false_result:
            # last cascade was constant False
            cmp_nodes.append(final_false_result[0])
        elif not cmp_nodes:
            # only constants, but no False result
            return self._bool_node(node, True)
        node = cmp_nodes[0]
        if len(cmp_nodes) == 1:
            # a single remaining operand: only the constant False case folds
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
        else:
            # join the partial cascades left to right: ((c1 and c2) and c3) ...
            for cmp_node in cmp_nodes[1:]:
                node = ExprNodes.BoolBinopNode(
                    node.pos,
                    operand1=node,
                    operator='and',
                    operand2=cmp_node,
                    constant_result=not_a_constant)
        return node
3535
3536    def visit_CondExprNode(self, node):
3537        self._calculate_const(node)
3538        if not node.test.has_constant_result():
3539            return node
3540        if node.test.constant_result:
3541            return node.true_val
3542        else:
3543            return node.false_val
3544
3545    def visit_IfStatNode(self, node):
3546        self.visitchildren(node)
3547        # eliminate dead code based on constant condition results
3548        if_clauses = []
3549        for if_clause in node.if_clauses:
3550            condition = if_clause.condition
3551            if condition.has_constant_result():
3552                if condition.constant_result:
3553                    # always true => subsequent clauses can safely be dropped
3554                    node.else_clause = if_clause.body
3555                    break
3556                # else: false => drop clause
3557            else:
3558                # unknown result => normal runtime evaluation
3559                if_clauses.append(if_clause)
3560        if if_clauses:
3561            node.if_clauses = if_clauses
3562            return node
3563        elif node.else_clause:
3564            return node.else_clause
3565        else:
3566            return Nodes.StatListNode(node.pos, stats=[])
3567
3568    def visit_SliceIndexNode(self, node):
3569        self._calculate_const(node)
3570        # normalise start/stop values
3571        if node.start is None or node.start.constant_result is None:
3572            start = node.start = None
3573        else:
3574            start = node.start.constant_result
3575        if node.stop is None or node.stop.constant_result is None:
3576            stop = node.stop = None
3577        else:
3578            stop = node.stop.constant_result
3579        # cut down sliced constant sequences
3580        if node.constant_result is not not_a_constant:
3581            base = node.base
3582            if base.is_sequence_constructor and base.mult_factor is None:
3583                base.args = base.args[start:stop]
3584                return base
3585            elif base.is_string_literal:
3586                base = base.as_sliced_node(start, stop)
3587                if base is not None:
3588                    return base
3589        return node
3590
3591    def visit_ComprehensionNode(self, node):
3592        self.visitchildren(node)
3593        if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
3594            # loop was pruned already => transform into literal
3595            if node.type is Builtin.list_type:
3596                return ExprNodes.ListNode(
3597                    node.pos, args=[], constant_result=[])
3598            elif node.type is Builtin.set_type:
3599                return ExprNodes.SetNode(
3600                    node.pos, args=[], constant_result=set())
3601            elif node.type is Builtin.dict_type:
3602                return ExprNodes.DictNode(
3603                    node.pos, key_value_pairs=[], constant_result={})
3604        return node
3605
3606    def visit_ForInStatNode(self, node):
3607        self.visitchildren(node)
3608        sequence = node.iterator.sequence
3609        if isinstance(sequence, ExprNodes.SequenceNode):
3610            if not sequence.args:
3611                if node.else_clause:
3612                    return node.else_clause
3613                else:
3614                    # don't break list comprehensions
3615                    return Nodes.StatListNode(node.pos, stats=[])
3616            # iterating over a list literal? => tuples are more efficient
3617            if isinstance(sequence, ExprNodes.ListNode):
3618                node.iterator.sequence = sequence.as_tuple()
3619        return node
3620
3621    def visit_WhileStatNode(self, node):
3622        self.visitchildren(node)
3623        if node.condition and node.condition.has_constant_result():
3624            if node.condition.constant_result:
3625                node.condition = None
3626                node.else_clause = None
3627            else:
3628                return node.else_clause
3629        return node
3630
3631    def visit_ExprStatNode(self, node):
3632        self.visitchildren(node)
3633        if not isinstance(node.expr, ExprNodes.ExprNode):
3634            # ParallelRangeTransform does this ...
3635            return node
3636        # drop unused constant expressions
3637        if node.expr.has_constant_result():
3638            return None
3639        return node
3640
    # In the future, other nodes can have their own handler method here
    # that can replace them with a constant result node.

    # Fallback for every node type without a dedicated visit_* handler in
    # this transform. Presumably recurse_to_children just visits the
    # children and returns the node unchanged; confirm in
    # Visitor.VisitorTransform.
    visit_Node = Visitor.VisitorTransform.recurse_to_children
3645
3646
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    Late, commuting optimisations that run right before C code generation.

    Currently implemented:
        - flag the target of a first assignment, so that None
          initialisation and refcounting can be eliminated for it
        - rewrite isinstance(x, t), where t is statically typed as the
          builtin 'type', into a PyObject_TypeCheck() C call
        - drop None checks and None-tolerant type tests that earlier
          tree changes have made redundant
    """
    def visit_SingleAssignmentNode(self, node):
        """Mark the left-hand side of a first assignment so that redundant
        initialisation of the local variable can be avoided.
        """
        self.visitchildren(node)
        if node.first:
            node.lhs.lhs_of_first_assignment = True
        return node

    def visit_SimpleCallNode(self, node):
        """Rewrite a call of the C function ``isinstance`` with two
        arguments, whose second argument is typed as the builtin ``type``,
        into a direct ``PyObject_TypeCheck()`` call.
        """
        self.visitchildren(node)
        func = node.function
        if not (func.type.is_cfunction and isinstance(func, ExprNodes.NameNode)):
            return node
        if func.name != 'isinstance' or len(node.args) != 2:
            return node
        type_arg = node.args[1]
        if not (type_arg.type.is_builtin_type and type_arg.type.name == 'type'):
            return node

        scope = self.context.cython_scope
        func.entry = scope.lookup('PyObject_TypeCheck')
        func.type = func.entry.type
        # PyObject_TypeCheck() takes a PyTypeObject*, so cast the type argument
        type_obj_ptr = PyrexTypes.CPtrType(scope.lookup('PyTypeObject').type)
        node.args[1] = ExprNodes.CastNode(type_arg, type_obj_ptr)
        return node

    def visit_PyTypeTestNode(self, node):
        """Stop a type test from also accepting None when its argument is
        known never to be None.
        """
        self.visitchildren(node)
        if not node.notnone and not node.arg.may_be_none():
            node.notnone = True
        return node

    def visit_NoneCheckNode(self, node):
        """Replace a None check by its argument when that argument can
        never be None.
        """
        self.visitchildren(node)
        checked = node.arg
        return node if checked.may_be_none() else checked
3702
class ConsolidateOverflowCheck(Visitor.CythonTransform):
    """
    This class facilitates the sharing of overflow checking among all nodes
    of a nested arithmetic expression.  For example, given the expression
    a*b + c, where a, b, and c are all possibly overflowing ints, the entire
    sequence will be evaluated and the overflow bit checked only at the end.

    The outermost NumBinopNode that has both ``overflow_check`` and
    ``overflow_fold`` set owns the shared overflow bit. Every such node
    nested below it records that owner in ``node.overflow_bit_node`` and
    turns off its own ``overflow_check``. The chain passes through any
    NumBinopNode, even one that does not check overflow itself. Any other
    node type in between starts a fresh chain for its children.
    """
    # The outermost overflow-folding NumBinopNode currently being visited,
    # or None when the traversal is not inside such an expression.
    overflow_bit_node = None

    def visit_Node(self, node):
        """Visit a non-arithmetic node: its children start their own chains."""
        saved = self.overflow_bit_node
        self.overflow_bit_node = None
        try:
            self.visitchildren(node)
        finally:
            # restore the enclosing chain even if visiting the children fails
            self.overflow_bit_node = saved
        return node

    def visit_NumBinopNode(self, node):
        """Attach an overflow-folding arithmetic node to the current chain,
        or start a new chain if there is none.
        """
        if not (node.overflow_check and node.overflow_fold):
            # not folding itself, but nested nodes may still join the chain
            self.visitchildren(node)
            return node

        if self.overflow_bit_node is not None:
            # inner node: report into the outer node's overflow bit and let
            # the outermost node perform the single check
            node.overflow_bit_node = self.overflow_bit_node
            node.overflow_check = False
            self.visitchildren(node)
            return node

        # outermost node of a new chain: its children share its overflow bit
        self.overflow_bit_node = node
        try:
            self.visitchildren(node)
        finally:
            self.overflow_bit_node = None
        return node
3736