1from Cython.Compiler import TypeSlots 2from Cython.Compiler.ExprNodes import not_a_constant 3import cython 4cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object, 5 Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object, 6 UtilNodes=object, Naming=object) 7 8import Nodes 9import ExprNodes 10import PyrexTypes 11import Visitor 12import Builtin 13import UtilNodes 14import Options 15import Naming 16 17from Code import UtilityCode 18from StringEncoding import EncodedString, BytesLiteral 19from Errors import error 20from ParseTreeTransforms import SkipDeclarations 21 22import copy 23import codecs 24 25try: 26 from __builtin__ import reduce 27except ImportError: 28 from functools import reduce 29 30try: 31 from __builtin__ import basestring 32except ImportError: 33 basestring = str # Python 3 34 35def load_c_utility(name): 36 return UtilityCode.load_cached(name, "Optimize.c") 37 38def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)): 39 if isinstance(node, coercion_nodes): 40 return node.arg 41 return node 42 43def unwrap_node(node): 44 while isinstance(node, UtilNodes.ResultRefNode): 45 node = node.expression 46 return node 47 48def is_common_value(a, b): 49 a = unwrap_node(a) 50 b = unwrap_node(b) 51 if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode): 52 return a.name == b.name 53 if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode): 54 return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute 55 return False 56 57def filter_none_node(node): 58 if node is not None and node.constant_result is None: 59 return None 60 return node 61 62class IterationTransform(Visitor.EnvTransform): 63 """Transform some common for-in loop patterns into efficient C loops: 64 65 - for-in-dict loop becomes a while loop calling PyDict_Next() 66 - for-in-enumerate is replaced by an external counter variable 67 - 
for-in-range loop becomes a plain C for loop 68 """ 69 def visit_PrimaryCmpNode(self, node): 70 if node.is_ptr_contains(): 71 72 # for t in operand2: 73 # if operand1 == t: 74 # res = True 75 # break 76 # else: 77 # res = False 78 79 pos = node.pos 80 result_ref = UtilNodes.ResultRefNode(node) 81 if isinstance(node.operand2, ExprNodes.IndexNode): 82 base_type = node.operand2.base.type.base_type 83 else: 84 base_type = node.operand2.type.base_type 85 target_handle = UtilNodes.TempHandle(base_type) 86 target = target_handle.ref(pos) 87 cmp_node = ExprNodes.PrimaryCmpNode( 88 pos, operator=u'==', operand1=node.operand1, operand2=target) 89 if_body = Nodes.StatListNode( 90 pos, 91 stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)), 92 Nodes.BreakStatNode(pos)]) 93 if_node = Nodes.IfStatNode( 94 pos, 95 if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)], 96 else_clause=None) 97 for_loop = UtilNodes.TempsBlockNode( 98 pos, 99 temps = [target_handle], 100 body = Nodes.ForInStatNode( 101 pos, 102 target=target, 103 iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2), 104 body=if_node, 105 else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0)))) 106 for_loop = for_loop.analyse_expressions(self.current_env()) 107 for_loop = self.visit(for_loop) 108 new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop) 109 110 if node.operator == 'not_in': 111 new_node = ExprNodes.NotNode(pos, operand=new_node) 112 return new_node 113 114 else: 115 self.visitchildren(node) 116 return node 117 118 def visit_ForInStatNode(self, node): 119 self.visitchildren(node) 120 return self._optimise_for_loop(node, node.iterator.sequence) 121 122 def _optimise_for_loop(self, node, iterator, reversed=False): 123 if iterator.type is Builtin.dict_type: 124 # like iterating over dict.keys() 125 if reversed: 126 # CPython raises an error here: not a sequence 127 return 
node 128 return self._transform_dict_iteration( 129 node, dict_obj=iterator, method=None, keys=True, values=False) 130 131 # C array (slice) iteration? 132 if iterator.type.is_ptr or iterator.type.is_array: 133 return self._transform_carray_iteration(node, iterator, reversed=reversed) 134 if iterator.type is Builtin.bytes_type: 135 return self._transform_bytes_iteration(node, iterator, reversed=reversed) 136 if iterator.type is Builtin.unicode_type: 137 return self._transform_unicode_iteration(node, iterator, reversed=reversed) 138 139 # the rest is based on function calls 140 if not isinstance(iterator, ExprNodes.SimpleCallNode): 141 return node 142 143 if iterator.args is None: 144 arg_count = iterator.arg_tuple and len(iterator.arg_tuple.args) or 0 145 else: 146 arg_count = len(iterator.args) 147 if arg_count and iterator.self is not None: 148 arg_count -= 1 149 150 function = iterator.function 151 # dict iteration? 152 if function.is_attribute and not reversed and not arg_count: 153 base_obj = iterator.self or function.obj 154 method = function.attribute 155 # in Py3, items() is equivalent to Py2's iteritems() 156 is_safe_iter = self.global_scope().context.language_level >= 3 157 158 if not is_safe_iter and method in ('keys', 'values', 'items'): 159 # try to reduce this to the corresponding .iter*() methods 160 if isinstance(base_obj, ExprNodes.SimpleCallNode): 161 inner_function = base_obj.function 162 if (inner_function.is_name and inner_function.name == 'dict' 163 and inner_function.entry 164 and inner_function.entry.is_builtin): 165 # e.g. 
dict(something).items() => safe to use .iter*() 166 is_safe_iter = True 167 168 keys = values = False 169 if method == 'iterkeys' or (is_safe_iter and method == 'keys'): 170 keys = True 171 elif method == 'itervalues' or (is_safe_iter and method == 'values'): 172 values = True 173 elif method == 'iteritems' or (is_safe_iter and method == 'items'): 174 keys = values = True 175 176 if keys or values: 177 return self._transform_dict_iteration( 178 node, base_obj, method, keys, values) 179 180 # enumerate/reversed ? 181 if iterator.self is None and function.is_name and \ 182 function.entry and function.entry.is_builtin: 183 if function.name == 'enumerate': 184 if reversed: 185 # CPython raises an error here: not a sequence 186 return node 187 return self._transform_enumerate_iteration(node, iterator) 188 elif function.name == 'reversed': 189 if reversed: 190 # CPython raises an error here: not a sequence 191 return node 192 return self._transform_reversed_iteration(node, iterator) 193 194 # range() iteration? 195 if Options.convert_range and node.target.type.is_int: 196 if iterator.self is None and function.is_name and \ 197 function.entry and function.entry.is_builtin and \ 198 function.name in ('range', 'xrange'): 199 return self._transform_range_iteration(node, iterator, reversed=reversed) 200 201 return node 202 203 def _transform_reversed_iteration(self, node, reversed_function): 204 args = reversed_function.arg_tuple.args 205 if len(args) == 0: 206 error(reversed_function.pos, 207 "reversed() requires an iterable argument") 208 return node 209 elif len(args) > 1: 210 error(reversed_function.pos, 211 "reversed() takes exactly 1 argument") 212 return node 213 arg = args[0] 214 215 # reversed(list/tuple) ? 
216 if arg.type in (Builtin.tuple_type, Builtin.list_type): 217 node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable") 218 node.iterator.reversed = True 219 return node 220 221 return self._optimise_for_loop(node, arg, reversed=True) 222 223 PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType( 224 PyrexTypes.c_char_ptr_type, [ 225 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) 226 ]) 227 228 PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType( 229 PyrexTypes.c_py_ssize_t_type, [ 230 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) 231 ]) 232 233 def _transform_bytes_iteration(self, node, slice_node, reversed=False): 234 target_type = node.target.type 235 if not target_type.is_int and target_type is not Builtin.bytes_type: 236 # bytes iteration returns bytes objects in Py2, but 237 # integers in Py3 238 return node 239 240 unpack_temp_node = UtilNodes.LetRefNode( 241 slice_node.as_none_safe_node("'NoneType' is not iterable")) 242 243 slice_base_node = ExprNodes.PythonCapiCallNode( 244 slice_node.pos, "PyBytes_AS_STRING", 245 self.PyBytes_AS_STRING_func_type, 246 args = [unpack_temp_node], 247 is_temp = 0, 248 ) 249 len_node = ExprNodes.PythonCapiCallNode( 250 slice_node.pos, "PyBytes_GET_SIZE", 251 self.PyBytes_GET_SIZE_func_type, 252 args = [unpack_temp_node], 253 is_temp = 0, 254 ) 255 256 return UtilNodes.LetNode( 257 unpack_temp_node, 258 self._transform_carray_iteration( 259 node, 260 ExprNodes.SliceIndexNode( 261 slice_node.pos, 262 base = slice_base_node, 263 start = None, 264 step = None, 265 stop = len_node, 266 type = slice_base_node.type, 267 is_temp = 1, 268 ), 269 reversed = reversed)) 270 271 PyUnicode_READ_func_type = PyrexTypes.CFuncType( 272 PyrexTypes.c_py_ucs4_type, [ 273 PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None), 274 PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None), 275 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None) 276 ]) 277 278 
init_unicode_iteration_func_type = PyrexTypes.CFuncType( 279 PyrexTypes.c_int_type, [ 280 PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None), 281 PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None), 282 PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None), 283 PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None) 284 ], 285 exception_value = '-1') 286 287 def _transform_unicode_iteration(self, node, slice_node, reversed=False): 288 if slice_node.is_literal: 289 # try to reduce to byte iteration for plain Latin-1 strings 290 try: 291 bytes_value = BytesLiteral(slice_node.value.encode('latin1')) 292 except UnicodeEncodeError: 293 pass 294 else: 295 bytes_slice = ExprNodes.SliceIndexNode( 296 slice_node.pos, 297 base=ExprNodes.BytesNode( 298 slice_node.pos, value=bytes_value, 299 constant_result=bytes_value, 300 type=PyrexTypes.c_char_ptr_type).coerce_to( 301 PyrexTypes.c_uchar_ptr_type, self.current_env()), 302 start=None, 303 stop=ExprNodes.IntNode( 304 slice_node.pos, value=str(len(bytes_value)), 305 constant_result=len(bytes_value), 306 type=PyrexTypes.c_py_ssize_t_type), 307 type=Builtin.unicode_type, # hint for Python conversion 308 ) 309 return self._transform_carray_iteration(node, bytes_slice, reversed) 310 311 unpack_temp_node = UtilNodes.LetRefNode( 312 slice_node.as_none_safe_node("'NoneType' is not iterable")) 313 314 start_node = ExprNodes.IntNode( 315 node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type) 316 length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) 317 end_node = length_temp.ref(node.pos) 318 if reversed: 319 relation1, relation2 = '>', '>=' 320 start_node, end_node = end_node, start_node 321 else: 322 relation1, relation2 = '<=', '<' 323 324 kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type) 325 data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type) 326 counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) 327 328 
target_value = ExprNodes.PythonCapiCallNode( 329 slice_node.pos, "__Pyx_PyUnicode_READ", 330 self.PyUnicode_READ_func_type, 331 args = [kind_temp.ref(slice_node.pos), 332 data_temp.ref(slice_node.pos), 333 counter_temp.ref(node.target.pos)], 334 is_temp = False, 335 ) 336 if target_value.type != node.target.type: 337 target_value = target_value.coerce_to(node.target.type, 338 self.current_env()) 339 target_assign = Nodes.SingleAssignmentNode( 340 pos = node.target.pos, 341 lhs = node.target, 342 rhs = target_value) 343 body = Nodes.StatListNode( 344 node.pos, 345 stats = [target_assign, node.body]) 346 347 loop_node = Nodes.ForFromStatNode( 348 node.pos, 349 bound1=start_node, relation1=relation1, 350 target=counter_temp.ref(node.target.pos), 351 relation2=relation2, bound2=end_node, 352 step=None, body=body, 353 else_clause=node.else_clause, 354 from_range=True) 355 356 setup_node = Nodes.ExprStatNode( 357 node.pos, 358 expr = ExprNodes.PythonCapiCallNode( 359 slice_node.pos, "__Pyx_init_unicode_iteration", 360 self.init_unicode_iteration_func_type, 361 args = [unpack_temp_node, 362 ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos), 363 type=PyrexTypes.c_py_ssize_t_ptr_type), 364 ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos), 365 type=PyrexTypes.c_void_ptr_ptr_type), 366 ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos), 367 type=PyrexTypes.c_int_ptr_type), 368 ], 369 is_temp = True, 370 result_is_used = False, 371 utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"), 372 )) 373 return UtilNodes.LetNode( 374 unpack_temp_node, 375 UtilNodes.TempsBlockNode( 376 node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp], 377 body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node]))) 378 379 def _transform_carray_iteration(self, node, slice_node, reversed=False): 380 neg_step = False 381 if isinstance(slice_node, ExprNodes.SliceIndexNode): 382 
slice_base = slice_node.base 383 start = filter_none_node(slice_node.start) 384 stop = filter_none_node(slice_node.stop) 385 step = None 386 if not stop: 387 if not slice_base.type.is_pyobject: 388 error(slice_node.pos, "C array iteration requires known end index") 389 return node 390 391 elif isinstance(slice_node, ExprNodes.IndexNode): 392 assert isinstance(slice_node.index, ExprNodes.SliceNode) 393 slice_base = slice_node.base 394 index = slice_node.index 395 start = filter_none_node(index.start) 396 stop = filter_none_node(index.stop) 397 step = filter_none_node(index.step) 398 if step: 399 if not isinstance(step.constant_result, (int,long)) \ 400 or step.constant_result == 0 \ 401 or step.constant_result > 0 and not stop \ 402 or step.constant_result < 0 and not start: 403 if not slice_base.type.is_pyobject: 404 error(step.pos, "C array iteration requires known step size and end index") 405 return node 406 else: 407 # step sign is handled internally by ForFromStatNode 408 step_value = step.constant_result 409 if reversed: 410 step_value = -step_value 411 neg_step = step_value < 0 412 step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type, 413 value=str(abs(step_value)), 414 constant_result=abs(step_value)) 415 416 elif slice_node.type.is_array: 417 if slice_node.type.size is None: 418 error(slice_node.pos, "C array iteration requires known end index") 419 return node 420 slice_base = slice_node 421 start = None 422 stop = ExprNodes.IntNode( 423 slice_node.pos, value=str(slice_node.type.size), 424 type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size) 425 step = None 426 427 else: 428 if not slice_node.type.is_pyobject: 429 error(slice_node.pos, "C array iteration requires known end index") 430 return node 431 432 if start: 433 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) 434 if stop: 435 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) 436 if stop is None: 437 if neg_step: 438 
stop = ExprNodes.IntNode( 439 slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1) 440 else: 441 error(slice_node.pos, "C array iteration requires known step size and end index") 442 return node 443 444 if reversed: 445 if not start: 446 start = ExprNodes.IntNode(slice_node.pos, value="0", constant_result=0, 447 type=PyrexTypes.c_py_ssize_t_type) 448 # if step was provided, it was already negated above 449 start, stop = stop, start 450 451 ptr_type = slice_base.type 452 if ptr_type.is_array: 453 ptr_type = ptr_type.element_ptr_type() 454 carray_ptr = slice_base.coerce_to_simple(self.current_env()) 455 456 if start and start.constant_result != 0: 457 start_ptr_node = ExprNodes.AddNode( 458 start.pos, 459 operand1=carray_ptr, 460 operator='+', 461 operand2=start, 462 type=ptr_type) 463 else: 464 start_ptr_node = carray_ptr 465 466 if stop and stop.constant_result != 0: 467 stop_ptr_node = ExprNodes.AddNode( 468 stop.pos, 469 operand1=ExprNodes.CloneNode(carray_ptr), 470 operator='+', 471 operand2=stop, 472 type=ptr_type 473 ).coerce_to_simple(self.current_env()) 474 else: 475 stop_ptr_node = ExprNodes.CloneNode(carray_ptr) 476 477 counter = UtilNodes.TempHandle(ptr_type) 478 counter_temp = counter.ref(node.target.pos) 479 480 if slice_base.type.is_string and node.target.type.is_pyobject: 481 # special case: char* -> bytes/unicode 482 if slice_node.type is Builtin.unicode_type: 483 target_value = ExprNodes.CastNode( 484 ExprNodes.DereferenceNode( 485 node.target.pos, operand=counter_temp, 486 type=ptr_type.base_type), 487 PyrexTypes.c_py_ucs4_type).coerce_to( 488 node.target.type, self.current_env()) 489 else: 490 # char* -> bytes coercion requires slicing, not indexing 491 target_value = ExprNodes.SliceIndexNode( 492 node.target.pos, 493 start=ExprNodes.IntNode(node.target.pos, value='0', 494 constant_result=0, 495 type=PyrexTypes.c_int_type), 496 stop=ExprNodes.IntNode(node.target.pos, value='1', 497 constant_result=1, 498 
type=PyrexTypes.c_int_type), 499 base=counter_temp, 500 type=Builtin.bytes_type, 501 is_temp=1) 502 elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type): 503 # Allow iteration with pointer target to avoid copy. 504 target_value = counter_temp 505 else: 506 # TODO: can this safely be replaced with DereferenceNode() as above? 507 target_value = ExprNodes.IndexNode( 508 node.target.pos, 509 index=ExprNodes.IntNode(node.target.pos, value='0', 510 constant_result=0, 511 type=PyrexTypes.c_int_type), 512 base=counter_temp, 513 is_buffer_access=False, 514 type=ptr_type.base_type) 515 516 if target_value.type != node.target.type: 517 target_value = target_value.coerce_to(node.target.type, 518 self.current_env()) 519 520 target_assign = Nodes.SingleAssignmentNode( 521 pos = node.target.pos, 522 lhs = node.target, 523 rhs = target_value) 524 525 body = Nodes.StatListNode( 526 node.pos, 527 stats = [target_assign, node.body]) 528 529 relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed) 530 531 for_node = Nodes.ForFromStatNode( 532 node.pos, 533 bound1=start_ptr_node, relation1=relation1, 534 target=counter_temp, 535 relation2=relation2, bound2=stop_ptr_node, 536 step=step, body=body, 537 else_clause=node.else_clause, 538 from_range=True) 539 540 return UtilNodes.TempsBlockNode( 541 node.pos, temps=[counter], 542 body=for_node) 543 544 def _transform_enumerate_iteration(self, node, enumerate_function): 545 args = enumerate_function.arg_tuple.args 546 if len(args) == 0: 547 error(enumerate_function.pos, 548 "enumerate() requires an iterable argument") 549 return node 550 elif len(args) > 2: 551 error(enumerate_function.pos, 552 "enumerate() takes at most 2 arguments") 553 return node 554 555 if not node.target.is_sequence_constructor: 556 # leave this untouched for now 557 return node 558 targets = node.target.args 559 if len(targets) != 2: 560 # leave this untouched for now 561 return node 562 563 enumerate_target, 
iterable_target = targets 564 counter_type = enumerate_target.type 565 566 if not counter_type.is_pyobject and not counter_type.is_int: 567 # nothing we can do here, I guess 568 return node 569 570 if len(args) == 2: 571 start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env()) 572 else: 573 start = ExprNodes.IntNode(enumerate_function.pos, 574 value='0', 575 type=counter_type, 576 constant_result=0) 577 temp = UtilNodes.LetRefNode(start) 578 579 inc_expression = ExprNodes.AddNode( 580 enumerate_function.pos, 581 operand1 = temp, 582 operand2 = ExprNodes.IntNode(node.pos, value='1', 583 type=counter_type, 584 constant_result=1), 585 operator = '+', 586 type = counter_type, 587 #inplace = True, # not worth using in-place operation for Py ints 588 is_temp = counter_type.is_pyobject 589 ) 590 591 loop_body = [ 592 Nodes.SingleAssignmentNode( 593 pos = enumerate_target.pos, 594 lhs = enumerate_target, 595 rhs = temp), 596 Nodes.SingleAssignmentNode( 597 pos = enumerate_target.pos, 598 lhs = temp, 599 rhs = inc_expression) 600 ] 601 602 if isinstance(node.body, Nodes.StatListNode): 603 node.body.stats = loop_body + node.body.stats 604 else: 605 loop_body.append(node.body) 606 node.body = Nodes.StatListNode( 607 node.body.pos, 608 stats = loop_body) 609 610 node.target = iterable_target 611 node.item = node.item.coerce_to(iterable_target.type, self.current_env()) 612 node.iterator.sequence = args[0] 613 614 # recurse into loop to check for further optimisations 615 return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence)) 616 617 def _find_for_from_node_relations(self, neg_step_value, reversed): 618 if reversed: 619 if neg_step_value: 620 return '<', '<=' 621 else: 622 return '>', '>=' 623 else: 624 if neg_step_value: 625 return '>=', '>' 626 else: 627 return '<=', '<' 628 629 def _transform_range_iteration(self, node, range_function, reversed=False): 630 args = range_function.arg_tuple.args 631 if len(args) < 3: 632 
step_pos = range_function.pos 633 step_value = 1 634 step = ExprNodes.IntNode(step_pos, value='1', 635 constant_result=1) 636 else: 637 step = args[2] 638 step_pos = step.pos 639 if not isinstance(step.constant_result, (int, long)): 640 # cannot determine step direction 641 return node 642 step_value = step.constant_result 643 if step_value == 0: 644 # will lead to an error elsewhere 645 return node 646 if reversed and step_value not in (1, -1): 647 # FIXME: currently broken - requires calculation of the correct bounds 648 return node 649 if not isinstance(step, ExprNodes.IntNode): 650 step = ExprNodes.IntNode(step_pos, value=str(step_value), 651 constant_result=step_value) 652 653 if len(args) == 1: 654 bound1 = ExprNodes.IntNode(range_function.pos, value='0', 655 constant_result=0) 656 bound2 = args[0].coerce_to_integer(self.current_env()) 657 else: 658 bound1 = args[0].coerce_to_integer(self.current_env()) 659 bound2 = args[1].coerce_to_integer(self.current_env()) 660 661 relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed) 662 663 if reversed: 664 bound1, bound2 = bound2, bound1 665 if step_value < 0: 666 step_value = -step_value 667 else: 668 if step_value < 0: 669 step_value = -step_value 670 671 step.value = str(step_value) 672 step.constant_result = step_value 673 step = step.coerce_to_integer(self.current_env()) 674 675 if not bound2.is_literal: 676 # stop bound must be immutable => keep it in a temp var 677 bound2_is_temp = True 678 bound2 = UtilNodes.LetRefNode(bound2) 679 else: 680 bound2_is_temp = False 681 682 for_node = Nodes.ForFromStatNode( 683 node.pos, 684 target=node.target, 685 bound1=bound1, relation1=relation1, 686 relation2=relation2, bound2=bound2, 687 step=step, body=node.body, 688 else_clause=node.else_clause, 689 from_range=True) 690 691 if bound2_is_temp: 692 for_node = UtilNodes.LetNode(bound2, for_node) 693 694 return for_node 695 696 def _transform_dict_iteration(self, node, dict_obj, method, keys, 
values): 697 temps = [] 698 temp = UtilNodes.TempHandle(PyrexTypes.py_object_type) 699 temps.append(temp) 700 dict_temp = temp.ref(dict_obj.pos) 701 temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) 702 temps.append(temp) 703 pos_temp = temp.ref(node.pos) 704 705 key_target = value_target = tuple_target = None 706 if keys and values: 707 if node.target.is_sequence_constructor: 708 if len(node.target.args) == 2: 709 key_target, value_target = node.target.args 710 else: 711 # unusual case that may or may not lead to an error 712 return node 713 else: 714 tuple_target = node.target 715 elif keys: 716 key_target = node.target 717 else: 718 value_target = node.target 719 720 if isinstance(node.body, Nodes.StatListNode): 721 body = node.body 722 else: 723 body = Nodes.StatListNode(pos = node.body.pos, 724 stats = [node.body]) 725 726 # keep original length to guard against dict modification 727 dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) 728 temps.append(dict_len_temp) 729 dict_len_temp_addr = ExprNodes.AmpersandNode( 730 node.pos, operand=dict_len_temp.ref(dict_obj.pos), 731 type=PyrexTypes.c_ptr_type(dict_len_temp.type)) 732 temp = UtilNodes.TempHandle(PyrexTypes.c_int_type) 733 temps.append(temp) 734 is_dict_temp = temp.ref(node.pos) 735 is_dict_temp_addr = ExprNodes.AmpersandNode( 736 node.pos, operand=is_dict_temp, 737 type=PyrexTypes.c_ptr_type(temp.type)) 738 739 iter_next_node = Nodes.DictIterationNextNode( 740 dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp, 741 key_target, value_target, tuple_target, 742 is_dict_temp) 743 iter_next_node = iter_next_node.analyse_expressions(self.current_env()) 744 body.stats[0:0] = [iter_next_node] 745 746 if method: 747 method_node = ExprNodes.StringNode( 748 dict_obj.pos, is_identifier=True, value=method) 749 dict_obj = dict_obj.as_none_safe_node( 750 "'NoneType' object has no attribute '%s'", 751 error = "PyExc_AttributeError", 752 format_args = [method]) 753 else: 754 method_node = 
ExprNodes.NullNode(dict_obj.pos) 755 dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable") 756 757 def flag_node(value): 758 value = value and 1 or 0 759 return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value) 760 761 result_code = [ 762 Nodes.SingleAssignmentNode( 763 node.pos, 764 lhs = pos_temp, 765 rhs = ExprNodes.IntNode(node.pos, value='0', 766 constant_result=0)), 767 Nodes.SingleAssignmentNode( 768 dict_obj.pos, 769 lhs = dict_temp, 770 rhs = ExprNodes.PythonCapiCallNode( 771 dict_obj.pos, 772 "__Pyx_dict_iterator", 773 self.PyDict_Iterator_func_type, 774 utility_code = UtilityCode.load_cached("dict_iter", "Optimize.c"), 775 args = [dict_obj, flag_node(dict_obj.type is Builtin.dict_type), 776 method_node, dict_len_temp_addr, is_dict_temp_addr, 777 ], 778 is_temp=True, 779 )), 780 Nodes.WhileStatNode( 781 node.pos, 782 condition = None, 783 body = body, 784 else_clause = node.else_clause 785 ) 786 ] 787 788 return UtilNodes.TempsBlockNode( 789 node.pos, temps=temps, 790 body=Nodes.StatListNode( 791 node.pos, 792 stats = result_code 793 )) 794 795 PyDict_Iterator_func_type = PyrexTypes.CFuncType( 796 PyrexTypes.py_object_type, [ 797 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), 798 PyrexTypes.CFuncTypeArg("is_dict", PyrexTypes.c_int_type, None), 799 PyrexTypes.CFuncTypeArg("method_name", PyrexTypes.py_object_type, None), 800 PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None), 801 PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, None), 802 ]) 803 804 805class SwitchTransform(Visitor.VisitorTransform): 806 """ 807 This transformation tries to turn long if statements into C switch statements. 808 The requirement is that every clause be an (or of) var == value, where the var 809 is common among all clauses and both var and value are ints. 
810 """ 811 NO_MATCH = (None, None, None) 812 813 def extract_conditions(self, cond, allow_not_in): 814 while True: 815 if isinstance(cond, (ExprNodes.CoerceToTempNode, 816 ExprNodes.CoerceToBooleanNode)): 817 cond = cond.arg 818 elif isinstance(cond, UtilNodes.EvalWithTempExprNode): 819 # this is what we get from the FlattenInListTransform 820 cond = cond.subexpression 821 elif isinstance(cond, ExprNodes.TypecastNode): 822 cond = cond.operand 823 else: 824 break 825 826 if isinstance(cond, ExprNodes.PrimaryCmpNode): 827 if cond.cascade is not None: 828 return self.NO_MATCH 829 elif cond.is_c_string_contains() and \ 830 isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)): 831 not_in = cond.operator == 'not_in' 832 if not_in and not allow_not_in: 833 return self.NO_MATCH 834 if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \ 835 cond.operand2.contains_surrogates(): 836 # dealing with surrogates leads to different 837 # behaviour on wide and narrow Unicode 838 # platforms => refuse to optimise this case 839 return self.NO_MATCH 840 return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2) 841 elif not cond.is_python_comparison(): 842 if cond.operator == '==': 843 not_in = False 844 elif allow_not_in and cond.operator == '!=': 845 not_in = True 846 else: 847 return self.NO_MATCH 848 # this looks somewhat silly, but it does the right 849 # checks for NameNode and AttributeNode 850 if is_common_value(cond.operand1, cond.operand1): 851 if cond.operand2.is_literal: 852 return not_in, cond.operand1, [cond.operand2] 853 elif getattr(cond.operand2, 'entry', None) \ 854 and cond.operand2.entry.is_const: 855 return not_in, cond.operand1, [cond.operand2] 856 if is_common_value(cond.operand2, cond.operand2): 857 if cond.operand1.is_literal: 858 return not_in, cond.operand2, [cond.operand1] 859 elif getattr(cond.operand1, 'entry', None) \ 860 and cond.operand1.entry.is_const: 861 return not_in, cond.operand2, [cond.operand1] 862 
elif isinstance(cond, ExprNodes.BoolBinopNode): 863 if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'): 864 allow_not_in = (cond.operator == 'and') 865 not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in) 866 not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in) 867 if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2): 868 if (not not_in_1) or allow_not_in: 869 return not_in_1, t1, c1+c2 870 return self.NO_MATCH 871 872 def extract_in_string_conditions(self, string_literal): 873 if isinstance(string_literal, ExprNodes.UnicodeNode): 874 charvals = list(map(ord, set(string_literal.value))) 875 charvals.sort() 876 return [ ExprNodes.IntNode(string_literal.pos, value=str(charval), 877 constant_result=charval) 878 for charval in charvals ] 879 else: 880 # this is a bit tricky as Py3's bytes type returns 881 # integers on iteration, whereas Py2 returns 1-char byte 882 # strings 883 characters = string_literal.value 884 characters = list(set([ characters[i:i+1] for i in range(len(characters)) ])) 885 characters.sort() 886 return [ ExprNodes.CharNode(string_literal.pos, value=charval, 887 constant_result=charval) 888 for charval in characters ] 889 890 def extract_common_conditions(self, common_var, condition, allow_not_in): 891 not_in, var, conditions = self.extract_conditions(condition, allow_not_in) 892 if var is None: 893 return self.NO_MATCH 894 elif common_var is not None and not is_common_value(var, common_var): 895 return self.NO_MATCH 896 elif not (var.type.is_int or var.type.is_enum) or sum([not (cond.type.is_int or cond.type.is_enum) for cond in conditions]): 897 return self.NO_MATCH 898 return not_in, var, conditions 899 900 def has_duplicate_values(self, condition_values): 901 # duplicated values don't work in a switch statement 902 seen = set() 903 for value in condition_values: 904 if value.has_constant_result(): 905 if value.constant_result in seen: 906 return True 907 
seen.add(value.constant_result) 908 else: 909 # this isn't completely safe as we don't know the 910 # final C value, but this is about the best we can do 911 try: 912 if value.entry.cname in seen: 913 return True 914 except AttributeError: 915 return True # play safe 916 seen.add(value.entry.cname) 917 return False 918 919 def visit_IfStatNode(self, node): 920 common_var = None 921 cases = [] 922 for if_clause in node.if_clauses: 923 _, common_var, conditions = self.extract_common_conditions( 924 common_var, if_clause.condition, False) 925 if common_var is None: 926 self.visitchildren(node) 927 return node 928 cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos, 929 conditions = conditions, 930 body = if_clause.body)) 931 932 condition_values = [ 933 cond for case in cases for cond in case.conditions] 934 if len(condition_values) < 2: 935 self.visitchildren(node) 936 return node 937 if self.has_duplicate_values(condition_values): 938 self.visitchildren(node) 939 return node 940 941 common_var = unwrap_node(common_var) 942 switch_node = Nodes.SwitchStatNode(pos = node.pos, 943 test = common_var, 944 cases = cases, 945 else_clause = node.else_clause) 946 return switch_node 947 948 def visit_CondExprNode(self, node): 949 not_in, common_var, conditions = self.extract_common_conditions( 950 None, node.test, True) 951 if common_var is None \ 952 or len(conditions) < 2 \ 953 or self.has_duplicate_values(conditions): 954 self.visitchildren(node) 955 return node 956 return self.build_simple_switch_statement( 957 node, common_var, conditions, not_in, 958 node.true_val, node.false_val) 959 960 def visit_BoolBinopNode(self, node): 961 not_in, common_var, conditions = self.extract_common_conditions( 962 None, node, True) 963 if common_var is None \ 964 or len(conditions) < 2 \ 965 or self.has_duplicate_values(conditions): 966 self.visitchildren(node) 967 return node 968 969 return self.build_simple_switch_statement( 970 node, common_var, conditions, not_in, 971 
def build_simple_switch_statement(self, node, common_var, conditions,
                                  not_in, true_val, false_val):
    """Build a single-case switch that computes the value of 'node'.

    The switch tests the unwrapped common_var against all conditions
    in one case.  For a normal test, a match assigns true_val to the
    result and the default branch assigns false_val; with not_in set,
    the two assignments trade places.

    Returns a TempResultFromStatNode that evaluates the switch and
    provides the assigned result.
    """
    pos = node.pos
    result_ref = UtilNodes.ResultRefNode(node)
    assign_true = Nodes.SingleAssignmentNode(
        pos, lhs=result_ref, rhs=true_val, first=True)
    assign_false = Nodes.SingleAssignmentNode(
        pos, lhs=result_ref, rhs=false_val, first=True)

    if not_in:
        # inverted test: a matching value means "false"
        matched_body, default_body = assign_false, assign_true
    else:
        matched_body, default_body = assign_true, assign_false

    switch_case = Nodes.SwitchCaseNode(
        pos=pos, conditions=conditions, body=matched_body)
    switch_stat = Nodes.SwitchStatNode(
        pos=pos,
        test=unwrap_node(common_var),
        cases=[switch_case],
        else_clause=default_body)
    return UtilNodes.TempResultFromStatNode(result_ref, switch_stat)
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    Flatten membership tests against literal sequences into sequential
    comparisons: "x in [val1, ..., valn]" becomes a chain of "x == val"
    tests joined by "or", and "x not in [...]" becomes a chain of
    "x != val" tests joined by "and".
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite a non-cascaded 'in' / 'not_in' test on a tuple, list
        or set literal; every other comparison is returned unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node

        membership_op = node.operator
        if membership_op == 'in':
            conjunction, eq_or_neq = 'or', '=='
        elif membership_op == 'not_in':
            conjunction, eq_or_neq = 'and', '!='
        else:
            return node

        sequence = node.operand2
        literal_types = (ExprNodes.TupleNode,
                         ExprNodes.ListNode,
                         ExprNodes.SetNode)
        if not isinstance(sequence, literal_types):
            return node

        items = sequence.args
        if len(items) == 0:
            # note: lhs may have side effects
            return node

        lhs = UtilNodes.ResultRefNode(node.operand1)

        comparisons = []
        let_refs = []
        for item in items:
            try:
                # Trial optimisation to avoid redundant temp
                # assignments.  is_simple() is meant to be called
                # after type analysis, so any error is ignored and
                # the item is treated as non-simple to play safe.
                item_is_simple = item.is_simple()
            except Exception:
                item_is_simple = False
            if not item_is_simple:
                # must evaluate all non-simple RHS before doing the comparisons
                item = UtilNodes.LetRefNode(item)
                let_refs.append(item)
            comparison = ExprNodes.PrimaryCmpNode(
                pos=node.pos,
                operand1=lhs,
                operator=eq_or_neq,
                operand2=item,
                cascade=None)
            comparisons.append(ExprNodes.TypecastNode(
                pos=node.pos,
                operand=comparison,
                type=PyrexTypes.c_bint_type))

        # left fold of all comparisons into one boolean expression
        condition = comparisons[0]
        for comparison in comparisons[1:]:
            condition = ExprNodes.BoolBinopNode(
                pos=node.pos,
                operator=conjunction,
                operand1=condition,
                operand2=comparison)

        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # wrap from the last temp outwards so that the first one ends
        # up as the outermost node
        for let_ref in reversed(let_refs):
            new_node = UtilNodes.EvalWithTempExprNode(let_ref, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
def _extract_operand(self, node, names, indices, temps):
    """Classify one side of an assignment and record it in the given lists.

    The node is unwrapped first.  A CoerceToTempNode is appended to
    'temps' and replaced by its argument.  Then:

    - a plain name, or a chain of non-Python attribute accesses that
      ends in a name, is appended to 'names' as a
      (dotted path, node) pair
    - an integer index into a list-typed name is appended to 'indices'

    Returns True if the operand was recorded, False if it is not a
    Python object or has any other shape.
    """
    operand = unwrap_node(node)
    if not operand.type.is_pyobject:
        return False
    if isinstance(operand, ExprNodes.CoerceToTempNode):
        temps.append(operand)
        operand = operand.arg

    # collect the attribute path from the outermost attribute inwards
    path = []
    current = operand
    while isinstance(current, ExprNodes.AttributeNode):
        if current.is_py_attr:
            return False
        path.append(current.member)
        current = current.obj

    if isinstance(current, ExprNodes.NameNode):
        path.append(current.name)
        path.reverse()
        names.append(('.'.join(path), operand))
        return True

    if not isinstance(operand, ExprNodes.IndexNode):
        return False
    if (operand.base.type != Builtin.list_type
            or not operand.index.type.is_int
            or not isinstance(operand.base, ExprNodes.NameNode)):
        return False
    indices.append(operand)
    return True
1230 """ 1231 # only intercept on call nodes 1232 visit_Node = Visitor.VisitorTransform.recurse_to_children 1233 1234 def visit_SimpleCallNode(self, node): 1235 self.visitchildren(node) 1236 function = node.function 1237 if not self._function_is_builtin_name(function): 1238 return node 1239 return self._dispatch_to_handler(node, function, node.args) 1240 1241 def visit_GeneralCallNode(self, node): 1242 self.visitchildren(node) 1243 function = node.function 1244 if not self._function_is_builtin_name(function): 1245 return node 1246 arg_tuple = node.positional_args 1247 if not isinstance(arg_tuple, ExprNodes.TupleNode): 1248 return node 1249 args = arg_tuple.args 1250 return self._dispatch_to_handler( 1251 node, function, args, node.keyword_args) 1252 1253 def _function_is_builtin_name(self, function): 1254 if not function.is_name: 1255 return False 1256 env = self.current_env() 1257 entry = env.lookup(function.name) 1258 if entry is not env.builtin_scope().lookup_here(function.name): 1259 return False 1260 # if entry is None, it's at least an undeclared name, so likely builtin 1261 return True 1262 1263 def _dispatch_to_handler(self, node, function, args, kwargs=None): 1264 if kwargs is None: 1265 handler_name = '_handle_simple_function_%s' % function.name 1266 else: 1267 handler_name = '_handle_general_function_%s' % function.name 1268 handle_call = getattr(self, handler_name, None) 1269 if handle_call is not None: 1270 if kwargs is None: 1271 return handle_call(node, args) 1272 else: 1273 return handle_call(node, args, kwargs) 1274 return node 1275 1276 def _inject_capi_function(self, node, cname, func_type, utility_code=None): 1277 node.function = ExprNodes.PythonCapiFunctionNode( 1278 node.function.pos, node.function.name, cname, func_type, 1279 utility_code = utility_code) 1280 1281 def _error_wrong_arg_count(self, function_name, node, args, expected=None): 1282 if not expected: # None or 0 1283 arg_str = '' 1284 elif isinstance(expected, basestring) or 
expected > 1: 1285 arg_str = '...' 1286 elif expected == 1: 1287 arg_str = 'x' 1288 else: 1289 arg_str = '' 1290 if expected is not None: 1291 expected_str = 'expected %s, ' % expected 1292 else: 1293 expected_str = '' 1294 error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % ( 1295 function_name, arg_str, expected_str, len(args))) 1296 1297 # specific handlers for simple call nodes 1298 1299 def _handle_simple_function_float(self, node, pos_args): 1300 if not pos_args: 1301 return ExprNodes.FloatNode(node.pos, value='0.0') 1302 if len(pos_args) > 1: 1303 self._error_wrong_arg_count('float', node, pos_args, 1) 1304 arg_type = getattr(pos_args[0], 'type', None) 1305 if arg_type in (PyrexTypes.c_double_type, Builtin.float_type): 1306 return pos_args[0] 1307 return node 1308 1309 class YieldNodeCollector(Visitor.TreeVisitor): 1310 def __init__(self): 1311 Visitor.TreeVisitor.__init__(self) 1312 self.yield_stat_nodes = {} 1313 self.yield_nodes = [] 1314 1315 visit_Node = Visitor.TreeVisitor.visitchildren 1316 # XXX: disable inlining while it's not back supported 1317 def __visit_YieldExprNode(self, node): 1318 self.yield_nodes.append(node) 1319 self.visitchildren(node) 1320 1321 def __visit_ExprStatNode(self, node): 1322 self.visitchildren(node) 1323 if node.expr in self.yield_nodes: 1324 self.yield_stat_nodes[node.expr] = node 1325 1326 def __visit_GeneratorExpressionNode(self, node): 1327 # enable when we support generic generator expressions 1328 # 1329 # everything below this node is out of scope 1330 pass 1331 1332 def _find_single_yield_expression(self, node): 1333 collector = self.YieldNodeCollector() 1334 collector.visitchildren(node) 1335 if len(collector.yield_nodes) != 1: 1336 return None, None 1337 yield_node = collector.yield_nodes[0] 1338 try: 1339 return (yield_node.arg, collector.yield_stat_nodes[yield_node]) 1340 except KeyError: 1341 return None, None 1342 1343 def _handle_simple_function_all(self, node, pos_args): 1344 """Transform 
1345 1346 _result = all(x for L in LL for x in L) 1347 1348 into 1349 1350 for L in LL: 1351 for x in L: 1352 if not x: 1353 _result = False 1354 break 1355 else: 1356 continue 1357 break 1358 else: 1359 _result = True 1360 """ 1361 return self._transform_any_all(node, pos_args, False) 1362 1363 def _handle_simple_function_any(self, node, pos_args): 1364 """Transform 1365 1366 _result = any(x for L in LL for x in L) 1367 1368 into 1369 1370 for L in LL: 1371 for x in L: 1372 if x: 1373 _result = True 1374 break 1375 else: 1376 continue 1377 break 1378 else: 1379 _result = False 1380 """ 1381 return self._transform_any_all(node, pos_args, True) 1382 1383 def _transform_any_all(self, node, pos_args, is_any): 1384 if len(pos_args) != 1: 1385 return node 1386 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): 1387 return node 1388 gen_expr_node = pos_args[0] 1389 loop_node = gen_expr_node.loop 1390 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) 1391 if yield_expression is None: 1392 return node 1393 1394 if is_any: 1395 condition = yield_expression 1396 else: 1397 condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression) 1398 1399 result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type) 1400 test_node = Nodes.IfStatNode( 1401 yield_expression.pos, 1402 else_clause = None, 1403 if_clauses = [ Nodes.IfClauseNode( 1404 yield_expression.pos, 1405 condition = condition, 1406 body = Nodes.StatListNode( 1407 node.pos, 1408 stats = [ 1409 Nodes.SingleAssignmentNode( 1410 node.pos, 1411 lhs = result_ref, 1412 rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any, 1413 constant_result = is_any)), 1414 Nodes.BreakStatNode(node.pos) 1415 ])) ] 1416 ) 1417 loop = loop_node 1418 while isinstance(loop.body, Nodes.LoopNode): 1419 next_loop = loop.body 1420 loop.body = Nodes.StatListNode(loop.body.pos, stats = [ 1421 loop.body, 1422 Nodes.BreakStatNode(yield_expression.pos) 1423 ]) 
1424 next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos) 1425 loop = next_loop 1426 loop_node.else_clause = Nodes.SingleAssignmentNode( 1427 node.pos, 1428 lhs = result_ref, 1429 rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any, 1430 constant_result = not is_any)) 1431 1432 Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node) 1433 1434 return ExprNodes.InlinedGeneratorExpressionNode( 1435 gen_expr_node.pos, loop = loop_node, result_node = result_ref, 1436 expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all') 1437 1438 def _handle_simple_function_sorted(self, node, pos_args): 1439 """Transform sorted(genexpr) and sorted([listcomp]) into 1440 [listcomp].sort(). CPython just reads the iterable into a 1441 list and calls .sort() on it. Expanding the iterable in a 1442 listcomp is still faster and the result can be sorted in 1443 place. 1444 """ 1445 if len(pos_args) != 1: 1446 return node 1447 if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \ 1448 and pos_args[0].type is Builtin.list_type: 1449 listcomp_node = pos_args[0] 1450 loop_node = listcomp_node.loop 1451 elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): 1452 gen_expr_node = pos_args[0] 1453 loop_node = gen_expr_node.loop 1454 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) 1455 if yield_expression is None: 1456 return node 1457 1458 append_node = ExprNodes.ComprehensionAppendNode( 1459 yield_expression.pos, expr = yield_expression) 1460 1461 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node) 1462 1463 listcomp_node = ExprNodes.ComprehensionNode( 1464 gen_expr_node.pos, loop = loop_node, 1465 append = append_node, type = Builtin.list_type, 1466 expr_scope = gen_expr_node.expr_scope, 1467 has_local_scope = True) 1468 append_node.target = listcomp_node 1469 else: 1470 return node 1471 1472 result_node = UtilNodes.ResultRefNode( 1473 pos = loop_node.pos, type = 
Builtin.list_type, may_hold_none=False) 1474 listcomp_assign_node = Nodes.SingleAssignmentNode( 1475 node.pos, lhs = result_node, rhs = listcomp_node, first = True) 1476 1477 sort_method = ExprNodes.AttributeNode( 1478 node.pos, obj = result_node, attribute = EncodedString('sort'), 1479 # entry ? type ? 1480 needs_none_check = False) 1481 sort_node = Nodes.ExprStatNode( 1482 node.pos, expr = ExprNodes.SimpleCallNode( 1483 node.pos, function = sort_method, args = [])) 1484 1485 sort_node.analyse_declarations(self.current_env()) 1486 1487 return UtilNodes.TempResultFromStatNode( 1488 result_node, 1489 Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ])) 1490 1491 def _handle_simple_function_sum(self, node, pos_args): 1492 """Transform sum(genexpr) into an equivalent inlined aggregation loop. 1493 """ 1494 if len(pos_args) not in (1,2): 1495 return node 1496 if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode, 1497 ExprNodes.ComprehensionNode)): 1498 return node 1499 gen_expr_node = pos_args[0] 1500 loop_node = gen_expr_node.loop 1501 1502 if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode): 1503 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) 1504 if yield_expression is None: 1505 return node 1506 else: # ComprehensionNode 1507 yield_stat_node = gen_expr_node.append 1508 yield_expression = yield_stat_node.expr 1509 try: 1510 if not yield_expression.is_literal or not yield_expression.type.is_int: 1511 return node 1512 except AttributeError: 1513 return node # in case we don't have a type yet 1514 # special case: old Py2 backwards compatible "sum([int_const for ...])" 1515 # can safely be unpacked into a genexpr 1516 1517 if len(pos_args) == 1: 1518 start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) 1519 else: 1520 start = pos_args[1] 1521 1522 result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type) 1523 add_node = Nodes.SingleAssignmentNode( 1524 
yield_expression.pos, 1525 lhs = result_ref, 1526 rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression) 1527 ) 1528 1529 Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node) 1530 1531 exec_code = Nodes.StatListNode( 1532 node.pos, 1533 stats = [ 1534 Nodes.SingleAssignmentNode( 1535 start.pos, 1536 lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref), 1537 rhs = start, 1538 first = True), 1539 loop_node 1540 ]) 1541 1542 return ExprNodes.InlinedGeneratorExpressionNode( 1543 gen_expr_node.pos, loop = exec_code, result_node = result_ref, 1544 expr_scope = gen_expr_node.expr_scope, orig_func = 'sum', 1545 has_local_scope = gen_expr_node.has_local_scope) 1546 1547 def _handle_simple_function_min(self, node, pos_args): 1548 return self._optimise_min_max(node, pos_args, '<') 1549 1550 def _handle_simple_function_max(self, node, pos_args): 1551 return self._optimise_min_max(node, pos_args, '>') 1552 1553 def _optimise_min_max(self, node, args, operator): 1554 """Replace min(a,b,...) and max(a,b,...) by explicit comparison code. 
1555 """ 1556 if len(args) <= 1: 1557 if len(args) == 1 and args[0].is_sequence_constructor: 1558 args = args[0].args 1559 else: 1560 # leave this to Python 1561 return node 1562 1563 cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:])) 1564 1565 last_result = args[0] 1566 for arg_node in cascaded_nodes: 1567 result_ref = UtilNodes.ResultRefNode(last_result) 1568 last_result = ExprNodes.CondExprNode( 1569 arg_node.pos, 1570 true_val = arg_node, 1571 false_val = result_ref, 1572 test = ExprNodes.PrimaryCmpNode( 1573 arg_node.pos, 1574 operand1 = arg_node, 1575 operator = operator, 1576 operand2 = result_ref, 1577 ) 1578 ) 1579 last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result) 1580 1581 for ref_node in cascaded_nodes[::-1]: 1582 last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result) 1583 1584 return last_result 1585 1586 def _DISABLED_handle_simple_function_tuple(self, node, pos_args): 1587 if not pos_args: 1588 return ExprNodes.TupleNode(node.pos, args=[], constant_result=()) 1589 # This is a bit special - for iterables (including genexps), 1590 # Python actually overallocates and resizes a newly created 1591 # tuple incrementally while reading items, which we can't 1592 # easily do without explicit node support. Instead, we read 1593 # the items into a list and then copy them into a tuple of the 1594 # final size. This takes up to twice as much memory, but will 1595 # have to do until we have real support for genexps. 1596 result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type) 1597 if result is not node: 1598 return ExprNodes.AsTupleNode(node.pos, arg=result) 1599 return node 1600 1601 def _handle_simple_function_frozenset(self, node, pos_args): 1602 """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient. 
1603 """ 1604 if len(pos_args) != 1: 1605 return node 1606 if pos_args[0].is_sequence_constructor and not pos_args[0].args: 1607 del pos_args[0] 1608 elif isinstance(pos_args[0], ExprNodes.ListNode): 1609 pos_args[0] = pos_args[0].as_tuple() 1610 return node 1611 1612 def _handle_simple_function_list(self, node, pos_args): 1613 if not pos_args: 1614 return ExprNodes.ListNode(node.pos, args=[], constant_result=[]) 1615 return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type) 1616 1617 def _handle_simple_function_set(self, node, pos_args): 1618 if not pos_args: 1619 return ExprNodes.SetNode(node.pos, args=[], constant_result=set()) 1620 return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type) 1621 1622 def _transform_list_set_genexpr(self, node, pos_args, target_type): 1623 """Replace set(genexpr) and list(genexpr) by a literal comprehension. 1624 """ 1625 if len(pos_args) > 1: 1626 return node 1627 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): 1628 return node 1629 gen_expr_node = pos_args[0] 1630 loop_node = gen_expr_node.loop 1631 1632 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) 1633 if yield_expression is None: 1634 return node 1635 1636 append_node = ExprNodes.ComprehensionAppendNode( 1637 yield_expression.pos, 1638 expr = yield_expression) 1639 1640 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node) 1641 1642 comp = ExprNodes.ComprehensionNode( 1643 node.pos, 1644 has_local_scope = True, 1645 expr_scope = gen_expr_node.expr_scope, 1646 loop = loop_node, 1647 append = append_node, 1648 type = target_type) 1649 append_node.target = comp 1650 return comp 1651 1652 def _handle_simple_function_dict(self, node, pos_args): 1653 """Replace dict( (a,b) for ... ) by a literal { a:b for ... }. 
1654 """ 1655 if len(pos_args) == 0: 1656 return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={}) 1657 if len(pos_args) > 1: 1658 return node 1659 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): 1660 return node 1661 gen_expr_node = pos_args[0] 1662 loop_node = gen_expr_node.loop 1663 1664 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) 1665 if yield_expression is None: 1666 return node 1667 1668 if not isinstance(yield_expression, ExprNodes.TupleNode): 1669 return node 1670 if len(yield_expression.args) != 2: 1671 return node 1672 1673 append_node = ExprNodes.DictComprehensionAppendNode( 1674 yield_expression.pos, 1675 key_expr = yield_expression.args[0], 1676 value_expr = yield_expression.args[1]) 1677 1678 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node) 1679 1680 dictcomp = ExprNodes.ComprehensionNode( 1681 node.pos, 1682 has_local_scope = True, 1683 expr_scope = gen_expr_node.expr_scope, 1684 loop = loop_node, 1685 append = append_node, 1686 type = Builtin.dict_type) 1687 append_node.target = dictcomp 1688 return dictcomp 1689 1690 # specific handlers for general call nodes 1691 1692 def _handle_general_function_dict(self, node, pos_args, kwargs): 1693 """Replace dict(a=b,c=d,...) by the underlying keyword dict 1694 construction which is done anyway. 
class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
    """Replace calls to names that are bound exactly once to a Python
    function by InlinedDefNodeCallNode, if the
    'optimize.inline_defnode_calls' directive is set.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def get_constant_value_node(self, name_node):
        """Return the right-hand side of the only assignment to the
        name, or None if the name has no control flow state, may be
        unset, cannot be looked up, or does not have exactly one
        assignment.
        """
        cf_state = name_node.cf_state
        if cf_state is None or cf_state.cf_is_null:
            return None
        entry = self.current_env().lookup(name_node.name)
        if not entry:
            return None
        assignments = entry.cf_assignments
        if not assignments or len(assignments) != 1:
            # not just a single assignment in all closures
            return None
        return assignments[0].rhs

    def visit_SimpleCallNode(self, node):
        """Swap the call for an inlined def node call where possible,
        otherwise return it unchanged.
        """
        self.visitchildren(node)
        if not self.current_directives.get('optimize.inline_defnode_calls'):
            return node
        callee = node.function
        if not callee.is_name:
            return node
        assigned_value = self.get_constant_value_node(callee)
        if not isinstance(assigned_value, ExprNodes.PyCFunctionNode):
            return node
        inlined = ExprNodes.InlinedDefNodeCallNode(
            node.pos, function_name=callee,
            function=assigned_value, args=node.args)
        if not inlined.can_be_inlined():
            return node
        return self.replace(node, inlined)
def visit_CoerceToBooleanNode(self, node):
    """Drop redundant conversion nodes after tree changes.

    If the argument (looking through a PyTypeTestNode) is a coercion
    to a generic Python object or to bool, the coerced value is
    converted to a boolean directly instead.  Otherwise the node is
    returned as is.
    """
    self.visitchildren(node)
    inner = node.arg
    if isinstance(inner, ExprNodes.PyTypeTestNode):
        inner = inner.arg
    if not isinstance(inner, ExprNodes.CoerceToPyTypeNode):
        return node
    if inner.type not in (PyrexTypes.py_object_type, Builtin.bool_type):
        return node
    return inner.arg.coerce_to_boolean(self.current_env())
1795 """ 1796 self.visitchildren(node) 1797 arg = node.arg 1798 if not arg.type.is_pyobject: 1799 # no Python conversion left at all, just do a C coercion instead 1800 if node.type == arg.type: 1801 return arg 1802 else: 1803 return arg.coerce_to(node.type, self.current_env()) 1804 if isinstance(arg, ExprNodes.PyTypeTestNode): 1805 arg = arg.arg 1806 if arg.is_literal: 1807 if (node.type.is_int and isinstance(arg, ExprNodes.IntNode) or 1808 node.type.is_float and isinstance(arg, ExprNodes.FloatNode) or 1809 node.type.is_int and isinstance(arg, ExprNodes.BoolNode)): 1810 return arg.coerce_to(node.type, self.current_env()) 1811 elif isinstance(arg, ExprNodes.CoerceToPyTypeNode): 1812 if arg.type is PyrexTypes.py_object_type: 1813 if node.type.assignable_from(arg.arg.type): 1814 # completely redundant C->Py->C coercion 1815 return arg.arg.coerce_to(node.type, self.current_env()) 1816 elif isinstance(arg, ExprNodes.SimpleCallNode): 1817 if node.type.is_int or node.type.is_float: 1818 return self._optimise_numeric_cast_call(node, arg) 1819 elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access: 1820 index_node = arg.index 1821 if isinstance(index_node, ExprNodes.CoerceToPyTypeNode): 1822 index_node = index_node.arg 1823 if index_node.type.is_int: 1824 return self._optimise_int_indexing(node, arg, index_node) 1825 return node 1826 1827 PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType( 1828 PyrexTypes.c_char_type, [ 1829 PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None), 1830 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None), 1831 PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None), 1832 ], 1833 exception_value = "((char)-1)", 1834 exception_check = True) 1835 1836 def _optimise_int_indexing(self, coerce_node, arg, index_node): 1837 env = self.current_env() 1838 bound_check_bool = env.directives['boundscheck'] and 1 or 0 1839 if arg.base.type is Builtin.bytes_type: 1840 if coerce_node.type in 
def _optimise_numeric_cast_call(self, node, arg):
    """Drop a call to the builtin int() or float() whose result gets
    coerced to a C type anyway.

    'node' is the coercion node and 'arg' the call it coerces.  The
    call is only touched if it calls a builtin-typed name with a
    single argument that is not a Python object (looking through a
    coercion to Python), and if the argument type or the target type
    is of the kind the function produces (int for 'int', float for
    'float').

    Returns the bare argument if it already has the target type, a
    type cast of the argument if the target type is assignable from
    it or the argument is a float, and the unchanged node otherwise.
    """
    function = arg.function
    if not (isinstance(function, ExprNodes.NameNode)
            and function.type.is_builtin_type
            and isinstance(arg.arg_tuple, ExprNodes.TupleNode)):
        return node
    call_args = arg.arg_tuple.args
    if len(call_args) != 1:
        return node

    func_arg = call_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        func_arg = func_arg.arg
    elif func_arg.type.is_pyobject:
        # play safe: Python conversion might work on all sorts of things
        return node

    if function.name == 'int':
        kind_matches = func_arg.type.is_int or node.type.is_int
    elif function.name == 'float':
        kind_matches = func_arg.type.is_float or node.type.is_float
    else:
        return node
    if not kind_matches:
        return node

    if func_arg.type == node.type:
        return func_arg
    if node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
        return ExprNodes.TypecastNode(
            node.pos, operand=func_arg, type=node.type)
    return node
_error_wrong_arg_count(self, function_name, node, args, expected=None): 1893 if not expected: # None or 0 1894 arg_str = '' 1895 elif isinstance(expected, basestring) or expected > 1: 1896 arg_str = '...' 1897 elif expected == 1: 1898 arg_str = 'x' 1899 else: 1900 arg_str = '' 1901 if expected is not None: 1902 expected_str = 'expected %s, ' % expected 1903 else: 1904 expected_str = '' 1905 error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % ( 1906 function_name, arg_str, expected_str, len(args))) 1907 1908 ### generic fallbacks 1909 1910 def _handle_function(self, node, function_name, function, arg_list, kwargs): 1911 return node 1912 1913 def _handle_method(self, node, type_name, attr_name, function, 1914 arg_list, is_unbound_method, kwargs): 1915 """ 1916 Try to inject C-API calls for unbound method calls to builtin types. 1917 While the method declarations in Builtin.py already handle this, we 1918 can additionally resolve bound and unbound methods here that were 1919 assigned to variables ahead of time. 
1920 """ 1921 if kwargs: 1922 return node 1923 if not function or not function.is_attribute or not function.obj.is_name: 1924 # cannot track unbound method calls over more than one indirection as 1925 # the names might have been reassigned in the meantime 1926 return node 1927 type_entry = self.current_env().lookup(type_name) 1928 if not type_entry: 1929 return node 1930 method = ExprNodes.AttributeNode( 1931 node.function.pos, 1932 obj=ExprNodes.NameNode( 1933 function.pos, 1934 name=type_name, 1935 entry=type_entry, 1936 type=type_entry.type), 1937 attribute=attr_name, 1938 is_called=True).analyse_as_unbound_cmethod_node(self.current_env()) 1939 if method is None: 1940 return node 1941 args = node.args 1942 if args is None and node.arg_tuple: 1943 args = node.arg_tuple.args 1944 call_node = ExprNodes.SimpleCallNode( 1945 node.pos, 1946 function=method, 1947 args=args) 1948 if not is_unbound_method: 1949 call_node.self = function.obj 1950 call_node.analyse_c_function_call(self.current_env()) 1951 call_node.analysed = True 1952 return call_node.coerce_to(node.type, self.current_env()) 1953 1954 ### builtin types 1955 1956 PyDict_Copy_func_type = PyrexTypes.CFuncType( 1957 Builtin.dict_type, [ 1958 PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None) 1959 ]) 1960 1961 def _handle_simple_function_dict(self, node, function, pos_args): 1962 """Replace dict(some_dict) by PyDict_Copy(some_dict). 
1963 """ 1964 if len(pos_args) != 1: 1965 return node 1966 arg = pos_args[0] 1967 if arg.type is Builtin.dict_type: 1968 arg = arg.as_none_safe_node("'NoneType' is not iterable") 1969 return ExprNodes.PythonCapiCallNode( 1970 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type, 1971 args = [arg], 1972 is_temp = node.is_temp 1973 ) 1974 return node 1975 1976 PyList_AsTuple_func_type = PyrexTypes.CFuncType( 1977 Builtin.tuple_type, [ 1978 PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None) 1979 ]) 1980 1981 def _handle_simple_function_tuple(self, node, function, pos_args): 1982 """Replace tuple([...]) by a call to PyList_AsTuple. 1983 """ 1984 if len(pos_args) != 1: 1985 return node 1986 arg = pos_args[0] 1987 if arg.type is Builtin.tuple_type and not arg.may_be_none(): 1988 return arg 1989 if arg.type is not Builtin.list_type: 1990 return node 1991 pos_args[0] = arg.as_none_safe_node( 1992 "'NoneType' object is not iterable") 1993 1994 return ExprNodes.PythonCapiCallNode( 1995 node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type, 1996 args = pos_args, 1997 is_temp = node.is_temp 1998 ) 1999 2000 PySet_New_func_type = PyrexTypes.CFuncType( 2001 Builtin.set_type, [ 2002 PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None) 2003 ]) 2004 2005 def _handle_simple_function_set(self, node, function, pos_args): 2006 if len(pos_args) != 1: 2007 return node 2008 if pos_args[0].is_sequence_constructor: 2009 # We can optimise set([x,y,z]) safely into a set literal, 2010 # but only if we create all items before adding them - 2011 # adding an item may raise an exception if it is not 2012 # hashable, but creating the later items may have 2013 # side-effects. 
def _handle_simple_function_float(self, node, function, pos_args):
    """Transform float() into either a C type cast or a faster C
    function call.

    - ``float()`` becomes the constant 0.0, coerced to a Python float.
    - more than one argument reports an error and keeps ``node``.
    - a C double argument is returned directly; other arguments that
      are assignable to ``node.type`` or numeric become a type cast.
    - everything else calls ``__Pyx_PyObject_AsDouble``.
    """
    # Note: this requires the float() function to be typed as
    # returning a C 'double'
    if len(pos_args) == 0:
        # node constructors take the source position, not the node itself
        return ExprNodes.FloatNode(
            node.pos, value="0.0", constant_result=0.0
        ).coerce_to(Builtin.float_type, self.current_env())
    elif len(pos_args) != 1:
        self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
        return node
    func_arg = pos_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        # look through the coercion at the underlying C value
        func_arg = func_arg.arg
    if func_arg.type is PyrexTypes.c_double_type:
        return func_arg
    elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
        return ExprNodes.TypecastNode(
            node.pos, operand=func_arg, type=node.type)
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_AsDouble",
        self.PyObject_AsDouble_func_type,
        args=pos_args,
        is_temp=node.is_temp,
        utility_code=load_c_utility('pyobject_as_double'),
        py_name="float")

PyNumber_Int_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)
    ])

def _handle_simple_function_int(self, node, function, pos_args):
    """Transform int() into a faster C function call.

    ``int()`` becomes the constant 0 typed as a Python object.  A
    single Python object argument with a Python object result is
    converted by ``PyNumber_Int``.  All other calls are left alone.
    """
    if len(pos_args) == 0:
        # node constructors take the source position, not the node itself
        return ExprNodes.IntNode(node.pos, value="0", constant_result=0,
                                 type=PyrexTypes.py_object_type)
    elif len(pos_args) != 1:
        return node  # int(x, base)
    func_arg = pos_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        return node  # handled in visit_CoerceFromPyTypeNode()
    if func_arg.type.is_pyobject and node.type.is_pyobject:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyNumber_Int", self.PyNumber_Int_func_type,
            args=pos_args, is_temp=True)
    return node
2116 """ 2117 if len(pos_args) == 0: 2118 return ExprNodes.BoolNode( 2119 node.pos, value=False, constant_result=False 2120 ).coerce_to(Builtin.bool_type, self.current_env()) 2121 elif len(pos_args) != 1: 2122 self._error_wrong_arg_count('bool', node, pos_args, '0 or 1') 2123 return node 2124 else: 2125 # => !!<bint>(x) to make sure it's exactly 0 or 1 2126 operand = pos_args[0].coerce_to_boolean(self.current_env()) 2127 operand = ExprNodes.NotNode(node.pos, operand = operand) 2128 operand = ExprNodes.NotNode(node.pos, operand = operand) 2129 # coerce back to Python object as that's the result we are expecting 2130 return operand.coerce_to_pyobject(self.current_env()) 2131 2132 ### builtin functions 2133 2134 Pyx_strlen_func_type = PyrexTypes.CFuncType( 2135 PyrexTypes.c_size_t_type, [ 2136 PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None) 2137 ]) 2138 2139 Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType( 2140 PyrexTypes.c_size_t_type, [ 2141 PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_py_unicode_ptr_type, None) 2142 ]) 2143 2144 PyObject_Size_func_type = PyrexTypes.CFuncType( 2145 PyrexTypes.c_py_ssize_t_type, [ 2146 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None) 2147 ], 2148 exception_value="-1") 2149 2150 _map_to_capi_len_function = { 2151 Builtin.unicode_type : "__Pyx_PyUnicode_GET_LENGTH", 2152 Builtin.bytes_type : "PyBytes_GET_SIZE", 2153 Builtin.list_type : "PyList_GET_SIZE", 2154 Builtin.tuple_type : "PyTuple_GET_SIZE", 2155 Builtin.dict_type : "PyDict_Size", 2156 Builtin.set_type : "PySet_Size", 2157 Builtin.frozenset_type : "PySet_Size", 2158 }.get 2159 2160 _ext_types_with_pysize = set(["cpython.array.array"]) 2161 2162 def _handle_simple_function_len(self, node, function, pos_args): 2163 """Replace len(char*) by the equivalent call to strlen(), 2164 len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and 2165 len(known_builtin_type) by an equivalent C-API call. 
2166 """ 2167 if len(pos_args) != 1: 2168 self._error_wrong_arg_count('len', node, pos_args, 1) 2169 return node 2170 arg = pos_args[0] 2171 if isinstance(arg, ExprNodes.CoerceToPyTypeNode): 2172 arg = arg.arg 2173 if arg.type.is_string: 2174 new_node = ExprNodes.PythonCapiCallNode( 2175 node.pos, "strlen", self.Pyx_strlen_func_type, 2176 args = [arg], 2177 is_temp = node.is_temp, 2178 utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c")) 2179 elif arg.type.is_pyunicode_ptr: 2180 new_node = ExprNodes.PythonCapiCallNode( 2181 node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type, 2182 args = [arg], 2183 is_temp = node.is_temp) 2184 elif arg.type.is_pyobject: 2185 cfunc_name = self._map_to_capi_len_function(arg.type) 2186 if cfunc_name is None: 2187 arg_type = arg.type 2188 if ((arg_type.is_extension_type or arg_type.is_builtin_type) 2189 and arg_type.entry.qualified_name in self._ext_types_with_pysize): 2190 cfunc_name = 'Py_SIZE' 2191 else: 2192 return node 2193 arg = arg.as_none_safe_node( 2194 "object of type 'NoneType' has no len()") 2195 new_node = ExprNodes.PythonCapiCallNode( 2196 node.pos, cfunc_name, self.PyObject_Size_func_type, 2197 args = [arg], 2198 is_temp = node.is_temp) 2199 elif arg.type.is_unicode_char: 2200 return ExprNodes.IntNode(node.pos, value='1', constant_result=1, 2201 type=node.type) 2202 else: 2203 return node 2204 if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type): 2205 new_node = new_node.coerce_to(node.type, self.current_env()) 2206 return new_node 2207 2208 Pyx_Type_func_type = PyrexTypes.CFuncType( 2209 Builtin.type_type, [ 2210 PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None) 2211 ]) 2212 2213 def _handle_simple_function_type(self, node, function, pos_args): 2214 """Replace type(o) by a macro call to Py_TYPE(o). 
2215 """ 2216 if len(pos_args) != 1: 2217 return node 2218 node = ExprNodes.PythonCapiCallNode( 2219 node.pos, "Py_TYPE", self.Pyx_Type_func_type, 2220 args = pos_args, 2221 is_temp = False) 2222 return ExprNodes.CastNode(node, PyrexTypes.py_object_type) 2223 2224 Py_type_check_func_type = PyrexTypes.CFuncType( 2225 PyrexTypes.c_bint_type, [ 2226 PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None) 2227 ]) 2228 2229 def _handle_simple_function_isinstance(self, node, function, pos_args): 2230 """Replace isinstance() checks against builtin types by the 2231 corresponding C-API call. 2232 """ 2233 if len(pos_args) != 2: 2234 return node 2235 arg, types = pos_args 2236 temp = None 2237 if isinstance(types, ExprNodes.TupleNode): 2238 types = types.args 2239 if arg.is_attribute or not arg.is_simple(): 2240 arg = temp = UtilNodes.ResultRefNode(arg) 2241 elif types.type is Builtin.type_type: 2242 types = [types] 2243 else: 2244 return node 2245 2246 tests = [] 2247 test_nodes = [] 2248 env = self.current_env() 2249 for test_type_node in types: 2250 builtin_type = None 2251 if test_type_node.is_name: 2252 if test_type_node.entry: 2253 entry = env.lookup(test_type_node.entry.name) 2254 if entry and entry.type and entry.type.is_builtin_type: 2255 builtin_type = entry.type 2256 if builtin_type is Builtin.type_type: 2257 # all types have type "type", but there's only one 'type' 2258 if entry.name != 'type' or not ( 2259 entry.scope and entry.scope.is_builtin_scope): 2260 builtin_type = None 2261 if builtin_type is not None: 2262 type_check_function = entry.type.type_check_function(exact=False) 2263 if type_check_function in tests: 2264 continue 2265 tests.append(type_check_function) 2266 type_check_args = [arg] 2267 elif test_type_node.type is Builtin.type_type: 2268 type_check_function = '__Pyx_TypeCheck' 2269 type_check_args = [arg, test_type_node] 2270 else: 2271 return node 2272 test_nodes.append( 2273 ExprNodes.PythonCapiCallNode( 2274 test_type_node.pos, 
type_check_function, self.Py_type_check_func_type, 2275 args = type_check_args, 2276 is_temp = True, 2277 )) 2278 2279 def join_with_or(a,b, make_binop_node=ExprNodes.binop_node): 2280 or_node = make_binop_node(node.pos, 'or', a, b) 2281 or_node.type = PyrexTypes.c_bint_type 2282 or_node.is_temp = True 2283 return or_node 2284 2285 test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env) 2286 if temp is not None: 2287 test_node = UtilNodes.EvalWithTempExprNode(temp, test_node) 2288 return test_node 2289 2290 def _handle_simple_function_ord(self, node, function, pos_args): 2291 """Unpack ord(Py_UNICODE) and ord('X'). 2292 """ 2293 if len(pos_args) != 1: 2294 return node 2295 arg = pos_args[0] 2296 if isinstance(arg, ExprNodes.CoerceToPyTypeNode): 2297 if arg.arg.type.is_unicode_char: 2298 return ExprNodes.TypecastNode( 2299 arg.pos, operand=arg.arg, type=PyrexTypes.c_int_type 2300 ).coerce_to(node.type, self.current_env()) 2301 elif isinstance(arg, ExprNodes.UnicodeNode): 2302 if len(arg.value) == 1: 2303 return ExprNodes.IntNode( 2304 arg.pos, type=PyrexTypes.c_int_type, 2305 value=str(ord(arg.value)), 2306 constant_result=ord(arg.value) 2307 ).coerce_to(node.type, self.current_env()) 2308 elif isinstance(arg, ExprNodes.StringNode): 2309 if arg.unicode_value and len(arg.unicode_value) == 1 \ 2310 and ord(arg.unicode_value) <= 255: # Py2/3 portability 2311 return ExprNodes.IntNode( 2312 arg.pos, type=PyrexTypes.c_int_type, 2313 value=str(ord(arg.unicode_value)), 2314 constant_result=ord(arg.unicode_value) 2315 ).coerce_to(node.type, self.current_env()) 2316 return node 2317 2318 ### special methods 2319 2320 Pyx_tp_new_func_type = PyrexTypes.CFuncType( 2321 PyrexTypes.py_object_type, [ 2322 PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None), 2323 PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None), 2324 ]) 2325 2326 Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType( 2327 PyrexTypes.py_object_type, [ 2328 
PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None), 2329 PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None), 2330 PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None), 2331 ]) 2332 2333 def _handle_any_slot__new__(self, node, function, args, 2334 is_unbound_method, kwargs=None): 2335 """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new() 2336 """ 2337 obj = function.obj 2338 if not is_unbound_method or len(args) < 1: 2339 return node 2340 type_arg = args[0] 2341 if not obj.is_name or not type_arg.is_name: 2342 # play safe 2343 return node 2344 if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type: 2345 # not a known type, play safe 2346 return node 2347 if not type_arg.type_entry or not obj.type_entry: 2348 if obj.name != type_arg.name: 2349 return node 2350 # otherwise, we know it's a type and we know it's the same 2351 # type for both - that should do 2352 elif type_arg.type_entry != obj.type_entry: 2353 # different types - may or may not lead to an error at runtime 2354 return node 2355 2356 args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:]) 2357 args_tuple = args_tuple.analyse_types( 2358 self.current_env(), skip_children=True) 2359 2360 if type_arg.type_entry: 2361 ext_type = type_arg.type_entry.type 2362 if (ext_type.is_extension_type and ext_type.typeobj_cname and 2363 ext_type.scope.global_scope() == self.current_env().global_scope()): 2364 # known type in current module 2365 tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__') 2366 slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot) 2367 if slot_func_cname: 2368 cython_scope = self.context.cython_scope 2369 PyTypeObjectPtr = PyrexTypes.CPtrType( 2370 cython_scope.lookup('PyTypeObject').type) 2371 pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType( 2372 PyrexTypes.py_object_type, [ 2373 PyrexTypes.CFuncTypeArg("type", PyTypeObjectPtr, None), 2374 PyrexTypes.CFuncTypeArg("args", PyrexTypes.py_object_type, None), 2375 
PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None), 2376 ]) 2377 2378 type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr) 2379 if not kwargs: 2380 kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type) # hack? 2381 return ExprNodes.PythonCapiCallNode( 2382 node.pos, slot_func_cname, 2383 pyx_tp_new_kwargs_func_type, 2384 args=[type_arg, args_tuple, kwargs], 2385 is_temp=True) 2386 else: 2387 # arbitrary variable, needs a None check for safety 2388 type_arg = type_arg.as_none_safe_node( 2389 "object.__new__(X): X is not a type object (NoneType)") 2390 2391 utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c') 2392 if kwargs: 2393 return ExprNodes.PythonCapiCallNode( 2394 node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type, 2395 args=[type_arg, args_tuple, kwargs], 2396 utility_code=utility_code, 2397 is_temp=node.is_temp 2398 ) 2399 else: 2400 return ExprNodes.PythonCapiCallNode( 2401 node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type, 2402 args=[type_arg, args_tuple], 2403 utility_code=utility_code, 2404 is_temp=node.is_temp 2405 ) 2406 2407 ### methods of builtin types 2408 2409 PyObject_Append_func_type = PyrexTypes.CFuncType( 2410 PyrexTypes.c_returncode_type, [ 2411 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), 2412 PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None), 2413 ], 2414 exception_value="-1") 2415 2416 def _handle_simple_method_object_append(self, node, function, args, is_unbound_method): 2417 """Optimistic optimisation as X.append() is almost always 2418 referring to a list. 
2419 """ 2420 if len(args) != 2 or node.result_is_used: 2421 return node 2422 2423 return ExprNodes.PythonCapiCallNode( 2424 node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type, 2425 args=args, 2426 may_return_none=False, 2427 is_temp=node.is_temp, 2428 result_is_used=False, 2429 utility_code=load_c_utility('append') 2430 ) 2431 2432 PyByteArray_Append_func_type = PyrexTypes.CFuncType( 2433 PyrexTypes.c_returncode_type, [ 2434 PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None), 2435 PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None), 2436 ], 2437 exception_value="-1") 2438 2439 PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType( 2440 PyrexTypes.c_returncode_type, [ 2441 PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None), 2442 PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None), 2443 ], 2444 exception_value="-1") 2445 2446 def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method): 2447 if len(args) != 2: 2448 return node 2449 func_name = "__Pyx_PyByteArray_Append" 2450 func_type = self.PyByteArray_Append_func_type 2451 2452 value = unwrap_coerced_node(args[1]) 2453 if value.type.is_int or isinstance(value, ExprNodes.IntNode): 2454 value = value.coerce_to(PyrexTypes.c_int_type, self.current_env()) 2455 utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c") 2456 elif value.is_string_literal: 2457 if not value.can_coerce_to_char_literal(): 2458 return node 2459 value = value.coerce_to(PyrexTypes.c_char_type, self.current_env()) 2460 utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c") 2461 elif value.type.is_pyobject: 2462 func_name = "__Pyx_PyByteArray_AppendObject" 2463 func_type = self.PyByteArray_AppendObject_func_type 2464 utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c") 2465 else: 2466 return node 2467 2468 new_node = ExprNodes.PythonCapiCallNode( 2469 node.pos, func_name, 
func_type, 2470 args=[args[0], value], 2471 may_return_none=False, 2472 is_temp=node.is_temp, 2473 utility_code=utility_code, 2474 ) 2475 if node.result_is_used: 2476 new_node = new_node.coerce_to(node.type, self.current_env()) 2477 return new_node 2478 2479 PyObject_Pop_func_type = PyrexTypes.CFuncType( 2480 PyrexTypes.py_object_type, [ 2481 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), 2482 ]) 2483 2484 PyObject_PopIndex_func_type = PyrexTypes.CFuncType( 2485 PyrexTypes.py_object_type, [ 2486 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), 2487 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None), 2488 ]) 2489 2490 def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method): 2491 return self._handle_simple_method_object_pop( 2492 node, function, args, is_unbound_method, is_list=True) 2493 2494 def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False): 2495 """Optimistic optimisation as X.pop([n]) is almost always 2496 referring to a list. 
2497 """ 2498 if not args: 2499 return node 2500 args = args[:] 2501 if is_list: 2502 type_name = 'List' 2503 args[0] = args[0].as_none_safe_node( 2504 "'NoneType' object has no attribute '%s'", 2505 error="PyExc_AttributeError", 2506 format_args=['pop']) 2507 else: 2508 type_name = 'Object' 2509 if len(args) == 1: 2510 return ExprNodes.PythonCapiCallNode( 2511 node.pos, "__Pyx_Py%s_Pop" % type_name, 2512 self.PyObject_Pop_func_type, 2513 args=args, 2514 may_return_none=True, 2515 is_temp=node.is_temp, 2516 utility_code=load_c_utility('pop'), 2517 ) 2518 elif len(args) == 2: 2519 index = unwrap_coerced_node(args[1]) 2520 if is_list or isinstance(index, ExprNodes.IntNode): 2521 index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) 2522 if index.type.is_int: 2523 widest = PyrexTypes.widest_numeric_type( 2524 index.type, PyrexTypes.c_py_ssize_t_type) 2525 if widest == PyrexTypes.c_py_ssize_t_type: 2526 args[1] = index 2527 return ExprNodes.PythonCapiCallNode( 2528 node.pos, "__Pyx_Py%s_PopIndex" % type_name, 2529 self.PyObject_PopIndex_func_type, 2530 args=args, 2531 may_return_none=True, 2532 is_temp=node.is_temp, 2533 utility_code=load_c_utility("pop_index"), 2534 ) 2535 2536 return node 2537 2538 single_param_func_type = PyrexTypes.CFuncType( 2539 PyrexTypes.c_returncode_type, [ 2540 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None), 2541 ], 2542 exception_value = "-1") 2543 2544 def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method): 2545 """Call PyList_Sort() instead of the 0-argument l.sort(). 
2546 """ 2547 if len(args) != 1: 2548 return node 2549 return self._substitute_method_call( 2550 node, function, "PyList_Sort", self.single_param_func_type, 2551 'sort', is_unbound_method, args).coerce_to(node.type, self.current_env) 2552 2553 Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType( 2554 PyrexTypes.py_object_type, [ 2555 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), 2556 PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), 2557 PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), 2558 ]) 2559 2560 def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method): 2561 """Replace dict.get() by a call to PyDict_GetItem(). 2562 """ 2563 if len(args) == 2: 2564 args.append(ExprNodes.NoneNode(node.pos)) 2565 elif len(args) != 3: 2566 self._error_wrong_arg_count('dict.get', node, args, "2 or 3") 2567 return node 2568 2569 return self._substitute_method_call( 2570 node, function, 2571 "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type, 2572 'get', is_unbound_method, args, 2573 may_return_none = True, 2574 utility_code = load_c_utility("dict_getitem_default")) 2575 2576 Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType( 2577 PyrexTypes.py_object_type, [ 2578 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), 2579 PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), 2580 PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), 2581 PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None), 2582 ]) 2583 2584 def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method): 2585 """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem(). 
def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
    """Replace a unicode ``is...()`` method call on a single character
    by the matching ``Py_UNICODE_IS...()`` C call.

    Only bound calls are handled whose object is a unicode character
    value that was coerced to a Python object; everything else returns
    ``node`` unchanged.  The C name is derived from the method name,
    except for istitle(), which uses a utility function.
    """
    if is_unbound_method or len(args) != 1:
        return node
    ustring = args[0]
    if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
            not ustring.arg.type.is_unicode_char:
        return node
    uchar = ustring.arg
    method_name = function.attribute
    if method_name == 'istitle':
        # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
        utility_code = UtilityCode.load_cached(
            "py_unicode_istitle", "StringTools.c")
        function_name = '__Pyx_Py_UNICODE_ISTITLE'
    else:
        utility_code = None
        function_name = 'Py_UNICODE_%s' % method_name.upper()
    func_call = self._substitute_method_call(
        node, function,
        function_name, self.PyUnicode_uchar_predicate_func_type,
        method_name, is_unbound_method, [uchar],
        utility_code=utility_code)
    if node.type.is_pyobject:
        # the coercion needs the environment itself: call current_env()
        func_call = func_call.coerce_to_pyobject(self.current_env())
    return func_call

_handle_simple_method_unicode_isalnum = _inject_unicode_predicate
_handle_simple_method_unicode_isalpha = _inject_unicode_predicate
_handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
_handle_simple_method_unicode_isdigit = _inject_unicode_predicate
_handle_simple_method_unicode_islower = _inject_unicode_predicate
_handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
_handle_simple_method_unicode_isspace = _inject_unicode_predicate
_handle_simple_method_unicode_istitle = _inject_unicode_predicate
_handle_simple_method_unicode_isupper = _inject_unicode_predicate

PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ucs4_type, [
        PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
    ])

def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
    """Replace lower()/upper()/title() on a single unicode character
    by the matching ``Py_UNICODE_TO...()`` C call.

    The same restrictions apply as for the character predicates: only
    bound calls on a unicode character value that was coerced to a
    Python object are replaced.
    """
    if is_unbound_method or len(args) != 1:
        return node
    ustring = args[0]
    if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
            not ustring.arg.type.is_unicode_char:
        return node
    uchar = ustring.arg
    method_name = function.attribute
    function_name = 'Py_UNICODE_TO%s' % method_name.upper()
    func_call = self._substitute_method_call(
        node, function,
        function_name, self.PyUnicode_uchar_conversion_func_type,
        method_name, is_unbound_method, [uchar])
    if node.type.is_pyobject:
        # the coercion needs the environment itself: call current_env()
        func_call = func_call.coerce_to_pyobject(self.current_env())
    return func_call

_handle_simple_method_unicode_lower = _inject_unicode_character_conversion
_handle_simple_method_unicode_upper = _inject_unicode_character_conversion
_handle_simple_method_unicode_title = _inject_unicode_character_conversion
Builtin.list_type, [ 2683 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 2684 PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None), 2685 ]) 2686 2687 def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method): 2688 """Replace unicode.splitlines(...) by a direct call to the 2689 corresponding C-API function. 2690 """ 2691 if len(args) not in (1,2): 2692 self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2") 2693 return node 2694 self._inject_bint_default_argument(node, args, 1, False) 2695 2696 return self._substitute_method_call( 2697 node, function, 2698 "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type, 2699 'splitlines', is_unbound_method, args) 2700 2701 PyUnicode_Split_func_type = PyrexTypes.CFuncType( 2702 Builtin.list_type, [ 2703 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 2704 PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None), 2705 PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None), 2706 ] 2707 ) 2708 2709 def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method): 2710 """Replace unicode.split(...) by a direct call to the 2711 corresponding C-API function. 
2712 """ 2713 if len(args) not in (1,2,3): 2714 self._error_wrong_arg_count('unicode.split', node, args, "1-3") 2715 return node 2716 if len(args) < 2: 2717 args.append(ExprNodes.NullNode(node.pos)) 2718 self._inject_int_default_argument( 2719 node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1") 2720 2721 return self._substitute_method_call( 2722 node, function, 2723 "PyUnicode_Split", self.PyUnicode_Split_func_type, 2724 'split', is_unbound_method, args) 2725 2726 PyString_Tailmatch_func_type = PyrexTypes.CFuncType( 2727 PyrexTypes.c_bint_type, [ 2728 PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None), # bytes/str/unicode 2729 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), 2730 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), 2731 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), 2732 PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None), 2733 ], 2734 exception_value = '-1') 2735 2736 def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method): 2737 return self._inject_tailmatch( 2738 node, function, args, is_unbound_method, 'unicode', 'endswith', 2739 unicode_tailmatch_utility_code, +1) 2740 2741 def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method): 2742 return self._inject_tailmatch( 2743 node, function, args, is_unbound_method, 'unicode', 'startswith', 2744 unicode_tailmatch_utility_code, -1) 2745 2746 def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name, 2747 method_name, utility_code, direction): 2748 """Replace unicode.startswith(...) and unicode.endswith(...) 2749 by a direct call to the corresponding C-API function. 
2750 """ 2751 if len(args) not in (2,3,4): 2752 self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4") 2753 return node 2754 self._inject_int_default_argument( 2755 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") 2756 self._inject_int_default_argument( 2757 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") 2758 args.append(ExprNodes.IntNode( 2759 node.pos, value=str(direction), type=PyrexTypes.c_int_type)) 2760 2761 method_call = self._substitute_method_call( 2762 node, function, 2763 "__Pyx_Py%s_Tailmatch" % type_name.capitalize(), 2764 self.PyString_Tailmatch_func_type, 2765 method_name, is_unbound_method, args, 2766 utility_code = utility_code) 2767 return method_call.coerce_to(Builtin.bool_type, self.current_env()) 2768 2769 PyUnicode_Find_func_type = PyrexTypes.CFuncType( 2770 PyrexTypes.c_py_ssize_t_type, [ 2771 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 2772 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), 2773 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), 2774 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), 2775 PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None), 2776 ], 2777 exception_value = '-2') 2778 2779 def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method): 2780 return self._inject_unicode_find( 2781 node, function, args, is_unbound_method, 'find', +1) 2782 2783 def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method): 2784 return self._inject_unicode_find( 2785 node, function, args, is_unbound_method, 'rfind', -1) 2786 2787 def _inject_unicode_find(self, node, function, args, is_unbound_method, 2788 method_name, direction): 2789 """Replace unicode.find(...) and unicode.rfind(...) by a 2790 direct call to the corresponding C-API function. 
2791 """ 2792 if len(args) not in (2,3,4): 2793 self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4") 2794 return node 2795 self._inject_int_default_argument( 2796 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") 2797 self._inject_int_default_argument( 2798 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") 2799 args.append(ExprNodes.IntNode( 2800 node.pos, value=str(direction), type=PyrexTypes.c_int_type)) 2801 2802 method_call = self._substitute_method_call( 2803 node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type, 2804 method_name, is_unbound_method, args) 2805 return method_call.coerce_to_pyobject(self.current_env()) 2806 2807 PyUnicode_Count_func_type = PyrexTypes.CFuncType( 2808 PyrexTypes.c_py_ssize_t_type, [ 2809 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 2810 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), 2811 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), 2812 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), 2813 ], 2814 exception_value = '-1') 2815 2816 def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method): 2817 """Replace unicode.count(...) by a direct call to the 2818 corresponding C-API function. 
2819 """ 2820 if len(args) not in (2,3,4): 2821 self._error_wrong_arg_count('unicode.count', node, args, "2-4") 2822 return node 2823 self._inject_int_default_argument( 2824 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") 2825 self._inject_int_default_argument( 2826 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") 2827 2828 method_call = self._substitute_method_call( 2829 node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type, 2830 'count', is_unbound_method, args) 2831 return method_call.coerce_to_pyobject(self.current_env()) 2832 2833 PyUnicode_Replace_func_type = PyrexTypes.CFuncType( 2834 Builtin.unicode_type, [ 2835 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), 2836 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), 2837 PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None), 2838 PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None), 2839 ]) 2840 2841 def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method): 2842 """Replace unicode.replace(...) by a direct call to the 2843 corresponding C-API function. 
2844 """ 2845 if len(args) not in (3,4): 2846 self._error_wrong_arg_count('unicode.replace', node, args, "3-4") 2847 return node 2848 self._inject_int_default_argument( 2849 node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1") 2850 2851 return self._substitute_method_call( 2852 node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type, 2853 'replace', is_unbound_method, args) 2854 2855 PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType( 2856 Builtin.bytes_type, [ 2857 PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None), 2858 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None), 2859 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None), 2860 ]) 2861 2862 PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType( 2863 Builtin.bytes_type, [ 2864 PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None), 2865 ]) 2866 2867 _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII', 2868 'unicode_escape', 'raw_unicode_escape'] 2869 2870 _special_codecs = [ (name, codecs.getencoder(name)) 2871 for name in _special_encodings ] 2872 2873 def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method): 2874 """Replace unicode.encode(...) by a direct C-API call to the 2875 corresponding codec. 
2876 """ 2877 if len(args) < 1 or len(args) > 3: 2878 self._error_wrong_arg_count('unicode.encode', node, args, '1-3') 2879 return node 2880 2881 string_node = args[0] 2882 2883 if len(args) == 1: 2884 null_node = ExprNodes.NullNode(node.pos) 2885 return self._substitute_method_call( 2886 node, function, "PyUnicode_AsEncodedString", 2887 self.PyUnicode_AsEncodedString_func_type, 2888 'encode', is_unbound_method, [string_node, null_node, null_node]) 2889 2890 parameters = self._unpack_encoding_and_error_mode(node.pos, args) 2891 if parameters is None: 2892 return node 2893 encoding, encoding_node, error_handling, error_handling_node = parameters 2894 2895 if encoding and isinstance(string_node, ExprNodes.UnicodeNode): 2896 # constant, so try to do the encoding at compile time 2897 try: 2898 value = string_node.value.encode(encoding, error_handling) 2899 except: 2900 # well, looks like we can't 2901 pass 2902 else: 2903 value = BytesLiteral(value) 2904 value.encoding = encoding 2905 return ExprNodes.BytesNode( 2906 string_node.pos, value=value, type=Builtin.bytes_type) 2907 2908 if encoding and error_handling == 'strict': 2909 # try to find a specific encoder function 2910 codec_name = self._find_special_codec_name(encoding) 2911 if codec_name is not None: 2912 encode_function = "PyUnicode_As%sString" % codec_name 2913 return self._substitute_method_call( 2914 node, function, encode_function, 2915 self.PyUnicode_AsXyzString_func_type, 2916 'encode', is_unbound_method, [string_node]) 2917 2918 return self._substitute_method_call( 2919 node, function, "PyUnicode_AsEncodedString", 2920 self.PyUnicode_AsEncodedString_func_type, 2921 'encode', is_unbound_method, 2922 [string_node, encoding_node, error_handling_node]) 2923 2924 PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType( 2925 Builtin.unicode_type, [ 2926 PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None), 2927 PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, 
None), 2928 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None), 2929 ])) 2930 2931 _decode_c_string_func_type = PyrexTypes.CFuncType( 2932 Builtin.unicode_type, [ 2933 PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None), 2934 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), 2935 PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), 2936 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None), 2937 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None), 2938 PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None), 2939 ]) 2940 2941 _decode_bytes_func_type = PyrexTypes.CFuncType( 2942 Builtin.unicode_type, [ 2943 PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None), 2944 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), 2945 PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), 2946 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None), 2947 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None), 2948 PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None), 2949 ]) 2950 2951 _decode_cpp_string_func_type = None # lazy init 2952 2953 def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method): 2954 """Replace char*.decode() by a direct C-API call to the 2955 corresponding codec, possibly resolving a slice on the char*. 
2956 """ 2957 if not (1 <= len(args) <= 3): 2958 self._error_wrong_arg_count('bytes.decode', node, args, '1-3') 2959 return node 2960 2961 # normalise input nodes 2962 string_node = args[0] 2963 start = stop = None 2964 if isinstance(string_node, ExprNodes.SliceIndexNode): 2965 index_node = string_node 2966 string_node = index_node.base 2967 start, stop = index_node.start, index_node.stop 2968 if not start or start.constant_result == 0: 2969 start = None 2970 if isinstance(string_node, ExprNodes.CoerceToPyTypeNode): 2971 string_node = string_node.arg 2972 2973 string_type = string_node.type 2974 if string_type in (Builtin.bytes_type, Builtin.bytearray_type): 2975 if is_unbound_method: 2976 string_node = string_node.as_none_safe_node( 2977 "descriptor '%s' requires a '%s' object but received a 'NoneType'", 2978 format_args=['decode', string_type.name]) 2979 else: 2980 string_node = string_node.as_none_safe_node( 2981 "'NoneType' object has no attribute '%s'", 2982 error="PyExc_AttributeError", 2983 format_args=['decode']) 2984 elif not string_type.is_string and not string_type.is_cpp_string: 2985 # nothing to optimise here 2986 return node 2987 2988 parameters = self._unpack_encoding_and_error_mode(node.pos, args) 2989 if parameters is None: 2990 return node 2991 encoding, encoding_node, error_handling, error_handling_node = parameters 2992 2993 if not start: 2994 start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) 2995 elif not start.type.is_int: 2996 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) 2997 if stop and not stop.type.is_int: 2998 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) 2999 3000 # try to find a specific encoder function 3001 codec_name = None 3002 if encoding is not None: 3003 codec_name = self._find_special_codec_name(encoding) 3004 if codec_name is not None: 3005 decode_function = ExprNodes.RawCNameExprNode( 3006 node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type, 3007 
cname="PyUnicode_Decode%s" % codec_name) 3008 encoding_node = ExprNodes.NullNode(node.pos) 3009 else: 3010 decode_function = ExprNodes.NullNode(node.pos) 3011 3012 # build the helper function call 3013 temps = [] 3014 if string_type.is_string: 3015 # C string 3016 if not stop: 3017 # use strlen() to find the string length, just as CPython would 3018 if not string_node.is_name: 3019 string_node = UtilNodes.LetRefNode(string_node) # used twice 3020 temps.append(string_node) 3021 stop = ExprNodes.PythonCapiCallNode( 3022 string_node.pos, "strlen", self.Pyx_strlen_func_type, 3023 args=[string_node], 3024 is_temp=False, 3025 utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"), 3026 ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) 3027 helper_func_type = self._decode_c_string_func_type 3028 utility_code_name = 'decode_c_string' 3029 elif string_type.is_cpp_string: 3030 # C++ std::string 3031 if not stop: 3032 stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX', 3033 constant_result=ExprNodes.not_a_constant) 3034 if self._decode_cpp_string_func_type is None: 3035 # lazy init to reuse the C++ string type 3036 self._decode_cpp_string_func_type = PyrexTypes.CFuncType( 3037 Builtin.unicode_type, [ 3038 PyrexTypes.CFuncTypeArg("string", string_type, None), 3039 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), 3040 PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), 3041 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None), 3042 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None), 3043 PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None), 3044 ]) 3045 helper_func_type = self._decode_cpp_string_func_type 3046 utility_code_name = 'decode_cpp_string' 3047 else: 3048 # Python bytes/bytearray object 3049 if not stop: 3050 stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX', 3051 constant_result=ExprNodes.not_a_constant) 3052 
helper_func_type = self._decode_bytes_func_type 3053 if string_type is Builtin.bytes_type: 3054 utility_code_name = 'decode_bytes' 3055 else: 3056 utility_code_name = 'decode_bytearray' 3057 3058 node = ExprNodes.PythonCapiCallNode( 3059 node.pos, '__Pyx_%s' % utility_code_name, helper_func_type, 3060 args=[string_node, start, stop, encoding_node, error_handling_node, decode_function], 3061 is_temp=node.is_temp, 3062 utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'), 3063 ) 3064 3065 for temp in temps[::-1]: 3066 node = UtilNodes.EvalWithTempExprNode(temp, node) 3067 return node 3068 3069 _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode 3070 3071 def _find_special_codec_name(self, encoding): 3072 try: 3073 requested_codec = codecs.getencoder(encoding) 3074 except LookupError: 3075 return None 3076 for name, codec in self._special_codecs: 3077 if codec == requested_codec: 3078 if '_' in name: 3079 name = ''.join([s.capitalize() 3080 for s in name.split('_')]) 3081 return name 3082 return None 3083 3084 def _unpack_encoding_and_error_mode(self, pos, args): 3085 null_node = ExprNodes.NullNode(pos) 3086 3087 if len(args) >= 2: 3088 encoding, encoding_node = self._unpack_string_and_cstring_node(args[1]) 3089 if encoding_node is None: 3090 return None 3091 else: 3092 encoding = None 3093 encoding_node = null_node 3094 3095 if len(args) == 3: 3096 error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2]) 3097 if error_handling_node is None: 3098 return None 3099 if error_handling == 'strict': 3100 error_handling_node = null_node 3101 else: 3102 error_handling = 'strict' 3103 error_handling_node = null_node 3104 3105 return (encoding, encoding_node, error_handling, error_handling_node) 3106 3107 def _unpack_string_and_cstring_node(self, node): 3108 if isinstance(node, ExprNodes.CoerceToPyTypeNode): 3109 node = node.arg 3110 if isinstance(node, ExprNodes.UnicodeNode): 3111 encoding = node.value 3112 
node = ExprNodes.BytesNode( 3113 node.pos, value=BytesLiteral(encoding.utf8encode()), 3114 type=PyrexTypes.c_char_ptr_type) 3115 elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)): 3116 encoding = node.value.decode('ISO-8859-1') 3117 node = ExprNodes.BytesNode( 3118 node.pos, value=node.value, type=PyrexTypes.c_char_ptr_type) 3119 elif node.type is Builtin.bytes_type: 3120 encoding = None 3121 node = node.coerce_to(PyrexTypes.c_char_ptr_type, self.current_env()) 3122 elif node.type.is_string: 3123 encoding = None 3124 else: 3125 encoding = node = None 3126 return encoding, node 3127 3128 def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method): 3129 return self._inject_tailmatch( 3130 node, function, args, is_unbound_method, 'str', 'endswith', 3131 str_tailmatch_utility_code, +1) 3132 3133 def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method): 3134 return self._inject_tailmatch( 3135 node, function, args, is_unbound_method, 'str', 'startswith', 3136 str_tailmatch_utility_code, -1) 3137 3138 def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method): 3139 return self._inject_tailmatch( 3140 node, function, args, is_unbound_method, 'bytes', 'endswith', 3141 bytes_tailmatch_utility_code, +1) 3142 3143 def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method): 3144 return self._inject_tailmatch( 3145 node, function, args, is_unbound_method, 'bytes', 'startswith', 3146 bytes_tailmatch_utility_code, -1) 3147 3148 ''' # disabled for now, enable when we consider it worth it (see StringTools.c) 3149 def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method): 3150 return self._inject_tailmatch( 3151 node, function, args, is_unbound_method, 'bytearray', 'endswith', 3152 bytes_tailmatch_utility_code, +1) 3153 3154 def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method): 
3155 return self._inject_tailmatch( 3156 node, function, args, is_unbound_method, 'bytearray', 'startswith', 3157 bytes_tailmatch_utility_code, -1) 3158 ''' 3159 3160 ### helpers 3161 3162 def _substitute_method_call(self, node, function, name, func_type, 3163 attr_name, is_unbound_method, args=(), 3164 utility_code=None, is_temp=None, 3165 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none): 3166 args = list(args) 3167 if args and not args[0].is_literal: 3168 self_arg = args[0] 3169 if is_unbound_method: 3170 self_arg = self_arg.as_none_safe_node( 3171 "descriptor '%s' requires a '%s' object but received a 'NoneType'", 3172 format_args=[attr_name, function.obj.name]) 3173 else: 3174 self_arg = self_arg.as_none_safe_node( 3175 "'NoneType' object has no attribute '%s'", 3176 error = "PyExc_AttributeError", 3177 format_args = [attr_name]) 3178 args[0] = self_arg 3179 if is_temp is None: 3180 is_temp = node.is_temp 3181 return ExprNodes.PythonCapiCallNode( 3182 node.pos, name, func_type, 3183 args = args, 3184 is_temp = is_temp, 3185 utility_code = utility_code, 3186 may_return_none = may_return_none, 3187 result_is_used = node.result_is_used, 3188 ) 3189 3190 def _inject_int_default_argument(self, node, args, arg_index, type, default_value): 3191 assert len(args) >= arg_index 3192 if len(args) == arg_index: 3193 args.append(ExprNodes.IntNode(node.pos, value=str(default_value), 3194 type=type, constant_result=default_value)) 3195 else: 3196 args[arg_index] = args[arg_index].coerce_to(type, self.current_env()) 3197 3198 def _inject_bint_default_argument(self, node, args, arg_index, default_value): 3199 assert len(args) >= arg_index 3200 if len(args) == arg_index: 3201 default_value = bool(default_value) 3202 args.append(ExprNodes.BoolNode(node.pos, value=default_value, 3203 constant_result=default_value)) 3204 else: 3205 args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env()) 3206 3207 3208unicode_tailmatch_utility_code = 
UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c') 3209bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c') 3210str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c') 3211 3212 3213class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): 3214 """Calculate the result of constant expressions to store it in 3215 ``expr_node.constant_result``, and replace trivial cases by their 3216 constant result. 3217 3218 General rules: 3219 3220 - We calculate float constants to make them available to the 3221 compiler, but we do not aggregate them into a single literal 3222 node to prevent any loss of precision. 3223 3224 - We recursively calculate constants from non-literal nodes to 3225 make them available to the compiler, but we only aggregate 3226 literal nodes at each step. Non-literal nodes are never merged 3227 into a single node. 3228 """ 3229 3230 def __init__(self, reevaluate=False): 3231 """ 3232 The reevaluate argument specifies whether constant values that were 3233 previously computed should be recomputed. 
3234 """ 3235 super(ConstantFolding, self).__init__() 3236 self.reevaluate = reevaluate 3237 3238 def _calculate_const(self, node): 3239 if (not self.reevaluate and 3240 node.constant_result is not ExprNodes.constant_value_not_set): 3241 return 3242 3243 # make sure we always set the value 3244 not_a_constant = ExprNodes.not_a_constant 3245 node.constant_result = not_a_constant 3246 3247 # check if all children are constant 3248 children = self.visitchildren(node) 3249 for child_result in children.values(): 3250 if type(child_result) is list: 3251 for child in child_result: 3252 if getattr(child, 'constant_result', not_a_constant) is not_a_constant: 3253 return 3254 elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant: 3255 return 3256 3257 # now try to calculate the real constant value 3258 try: 3259 node.calculate_constant_result() 3260# if node.constant_result is not ExprNodes.not_a_constant: 3261# print node.__class__.__name__, node.constant_result 3262 except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError): 3263 # ignore all 'normal' errors here => no constant result 3264 pass 3265 except Exception: 3266 # this looks like a real error 3267 import traceback, sys 3268 traceback.print_exc(file=sys.stdout) 3269 3270 NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode, 3271 ExprNodes.IntNode, ExprNodes.FloatNode] 3272 3273 def _widest_node_class(self, *nodes): 3274 try: 3275 return self.NODE_TYPE_ORDER[ 3276 max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))] 3277 except ValueError: 3278 return None 3279 3280 def _bool_node(self, node, value): 3281 value = bool(value) 3282 return ExprNodes.BoolNode(node.pos, value=value, constant_result=value) 3283 3284 def visit_ExprNode(self, node): 3285 self._calculate_const(node) 3286 return node 3287 3288 def visit_UnopNode(self, node): 3289 self._calculate_const(node) 3290 if not node.has_constant_result(): 3291 if node.operator == '!': 3292 return 
self._handle_NotNode(node) 3293 return node 3294 if not node.operand.is_literal: 3295 return node 3296 if node.operator == '!': 3297 return self._bool_node(node, node.constant_result) 3298 elif isinstance(node.operand, ExprNodes.BoolNode): 3299 return ExprNodes.IntNode(node.pos, value=str(int(node.constant_result)), 3300 type=PyrexTypes.c_int_type, 3301 constant_result=int(node.constant_result)) 3302 elif node.operator == '+': 3303 return self._handle_UnaryPlusNode(node) 3304 elif node.operator == '-': 3305 return self._handle_UnaryMinusNode(node) 3306 return node 3307 3308 _negate_operator = { 3309 'in': 'not_in', 3310 'not_in': 'in', 3311 'is': 'is_not', 3312 'is_not': 'is' 3313 }.get 3314 3315 def _handle_NotNode(self, node): 3316 operand = node.operand 3317 if isinstance(operand, ExprNodes.PrimaryCmpNode): 3318 operator = self._negate_operator(operand.operator) 3319 if operator: 3320 node = copy.copy(operand) 3321 node.operator = operator 3322 node = self.visit_PrimaryCmpNode(node) 3323 return node 3324 3325 def _handle_UnaryMinusNode(self, node): 3326 def _negate(value): 3327 if value.startswith('-'): 3328 value = value[1:] 3329 else: 3330 value = '-' + value 3331 return value 3332 3333 node_type = node.operand.type 3334 if isinstance(node.operand, ExprNodes.FloatNode): 3335 # this is a safe operation 3336 return ExprNodes.FloatNode(node.pos, value=_negate(node.operand.value), 3337 type=node_type, 3338 constant_result=node.constant_result) 3339 if node_type.is_int and node_type.signed or \ 3340 isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject: 3341 return ExprNodes.IntNode(node.pos, value=_negate(node.operand.value), 3342 type=node_type, 3343 longness=node.operand.longness, 3344 constant_result=node.constant_result) 3345 return node 3346 3347 def _handle_UnaryPlusNode(self, node): 3348 if (node.operand.has_constant_result() and 3349 node.constant_result == node.operand.constant_result): 3350 return node.operand 3351 return node 3352 3353 
def visit_BoolBinopNode(self, node): 3354 self._calculate_const(node) 3355 if not node.operand1.has_constant_result(): 3356 return node 3357 if node.operand1.constant_result: 3358 if node.operator == 'and': 3359 return node.operand2 3360 else: 3361 return node.operand1 3362 else: 3363 if node.operator == 'and': 3364 return node.operand1 3365 else: 3366 return node.operand2 3367 3368 def visit_BinopNode(self, node): 3369 self._calculate_const(node) 3370 if node.constant_result is ExprNodes.not_a_constant: 3371 return node 3372 if isinstance(node.constant_result, float): 3373 return node 3374 operand1, operand2 = node.operand1, node.operand2 3375 if not operand1.is_literal or not operand2.is_literal: 3376 return node 3377 3378 # now inject a new constant node with the calculated value 3379 try: 3380 type1, type2 = operand1.type, operand2.type 3381 if type1 is None or type2 is None: 3382 return node 3383 except AttributeError: 3384 return node 3385 3386 if type1.is_numeric and type2.is_numeric: 3387 widest_type = PyrexTypes.widest_numeric_type(type1, type2) 3388 else: 3389 widest_type = PyrexTypes.py_object_type 3390 3391 target_class = self._widest_node_class(operand1, operand2) 3392 if target_class is None: 3393 return node 3394 elif target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>': 3395 # C arithmetic results in at least an int type 3396 target_class = ExprNodes.IntNode 3397 elif target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^': 3398 # C arithmetic results in at least an int type 3399 target_class = ExprNodes.IntNode 3400 3401 if target_class is ExprNodes.IntNode: 3402 unsigned = getattr(operand1, 'unsigned', '') and \ 3403 getattr(operand2, 'unsigned', '') 3404 longness = "LL"[:max(len(getattr(operand1, 'longness', '')), 3405 len(getattr(operand2, 'longness', '')))] 3406 new_node = ExprNodes.IntNode(pos=node.pos, 3407 unsigned=unsigned, longness=longness, 3408 value=str(int(node.constant_result)), 3409 
constant_result=int(node.constant_result)) 3410 # IntNode is smart about the type it chooses, so we just 3411 # make sure we were not smarter this time 3412 if widest_type.is_pyobject or new_node.type.is_pyobject: 3413 new_node.type = PyrexTypes.py_object_type 3414 else: 3415 new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type) 3416 else: 3417 if target_class is ExprNodes.BoolNode: 3418 node_value = node.constant_result 3419 else: 3420 node_value = str(node.constant_result) 3421 new_node = target_class(pos=node.pos, type = widest_type, 3422 value = node_value, 3423 constant_result = node.constant_result) 3424 return new_node 3425 3426 def visit_MulNode(self, node): 3427 self._calculate_const(node) 3428 if node.operand1.is_sequence_constructor: 3429 return self._calculate_constant_seq(node, node.operand1, node.operand2) 3430 if isinstance(node.operand1, ExprNodes.IntNode) and \ 3431 node.operand2.is_sequence_constructor: 3432 return self._calculate_constant_seq(node, node.operand2, node.operand1) 3433 return self.visit_BinopNode(node) 3434 3435 def _calculate_constant_seq(self, node, sequence_node, factor): 3436 if factor.constant_result != 1 and sequence_node.args: 3437 if isinstance(factor.constant_result, (int, long)) and factor.constant_result <= 0: 3438 del sequence_node.args[:] 3439 sequence_node.mult_factor = None 3440 elif sequence_node.mult_factor is not None: 3441 if (isinstance(factor.constant_result, (int, long)) and 3442 isinstance(sequence_node.mult_factor.constant_result, (int, long))): 3443 value = sequence_node.mult_factor.constant_result * factor.constant_result 3444 sequence_node.mult_factor = ExprNodes.IntNode( 3445 sequence_node.mult_factor.pos, 3446 value=str(value), constant_result=value) 3447 else: 3448 # don't know if we can combine the factors, so don't 3449 return self.visit_BinopNode(node) 3450 else: 3451 sequence_node.mult_factor = factor 3452 return sequence_node 3453 3454 def visit_PrimaryCmpNode(self, node): 3455 
# calculate constant partial results in the comparison cascade 3456 self.visitchildren(node, ['operand1']) 3457 left_node = node.operand1 3458 cmp_node = node 3459 while cmp_node is not None: 3460 self.visitchildren(cmp_node, ['operand2']) 3461 right_node = cmp_node.operand2 3462 cmp_node.constant_result = not_a_constant 3463 if left_node.has_constant_result() and right_node.has_constant_result(): 3464 try: 3465 cmp_node.calculate_cascaded_constant_result(left_node.constant_result) 3466 except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError): 3467 pass # ignore all 'normal' errors here => no constant result 3468 left_node = right_node 3469 cmp_node = cmp_node.cascade 3470 3471 if not node.cascade: 3472 if node.has_constant_result(): 3473 return self._bool_node(node, node.constant_result) 3474 return node 3475 3476 # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...] 3477 cascades = [[node.operand1]] 3478 final_false_result = [] 3479 3480 def split_cascades(cmp_node): 3481 if cmp_node.has_constant_result(): 3482 if not cmp_node.constant_result: 3483 # False => short-circuit 3484 final_false_result.append(self._bool_node(cmp_node, False)) 3485 return 3486 else: 3487 # True => discard and start new cascade 3488 cascades.append([cmp_node.operand2]) 3489 else: 3490 # not constant => append to current cascade 3491 cascades[-1].append(cmp_node) 3492 if cmp_node.cascade: 3493 split_cascades(cmp_node.cascade) 3494 3495 split_cascades(node) 3496 3497 cmp_nodes = [] 3498 for cascade in cascades: 3499 if len(cascade) < 2: 3500 continue 3501 cmp_node = cascade[1] 3502 pcmp_node = ExprNodes.PrimaryCmpNode( 3503 cmp_node.pos, 3504 operand1=cascade[0], 3505 operator=cmp_node.operator, 3506 operand2=cmp_node.operand2, 3507 constant_result=not_a_constant) 3508 cmp_nodes.append(pcmp_node) 3509 3510 last_cmp_node = pcmp_node 3511 for cmp_node in cascade[2:]: 3512 last_cmp_node.cascade = cmp_node 3513 last_cmp_node = cmp_node 3514 
last_cmp_node.cascade = None 3515 3516 if final_false_result: 3517 # last cascade was constant False 3518 cmp_nodes.append(final_false_result[0]) 3519 elif not cmp_nodes: 3520 # only constants, but no False result 3521 return self._bool_node(node, True) 3522 node = cmp_nodes[0] 3523 if len(cmp_nodes) == 1: 3524 if node.has_constant_result(): 3525 return self._bool_node(node, node.constant_result) 3526 else: 3527 for cmp_node in cmp_nodes[1:]: 3528 node = ExprNodes.BoolBinopNode( 3529 node.pos, 3530 operand1=node, 3531 operator='and', 3532 operand2=cmp_node, 3533 constant_result=not_a_constant) 3534 return node 3535 3536 def visit_CondExprNode(self, node): 3537 self._calculate_const(node) 3538 if not node.test.has_constant_result(): 3539 return node 3540 if node.test.constant_result: 3541 return node.true_val 3542 else: 3543 return node.false_val 3544 3545 def visit_IfStatNode(self, node): 3546 self.visitchildren(node) 3547 # eliminate dead code based on constant condition results 3548 if_clauses = [] 3549 for if_clause in node.if_clauses: 3550 condition = if_clause.condition 3551 if condition.has_constant_result(): 3552 if condition.constant_result: 3553 # always true => subsequent clauses can safely be dropped 3554 node.else_clause = if_clause.body 3555 break 3556 # else: false => drop clause 3557 else: 3558 # unknown result => normal runtime evaluation 3559 if_clauses.append(if_clause) 3560 if if_clauses: 3561 node.if_clauses = if_clauses 3562 return node 3563 elif node.else_clause: 3564 return node.else_clause 3565 else: 3566 return Nodes.StatListNode(node.pos, stats=[]) 3567 3568 def visit_SliceIndexNode(self, node): 3569 self._calculate_const(node) 3570 # normalise start/stop values 3571 if node.start is None or node.start.constant_result is None: 3572 start = node.start = None 3573 else: 3574 start = node.start.constant_result 3575 if node.stop is None or node.stop.constant_result is None: 3576 stop = node.stop = None 3577 else: 3578 stop = 
node.stop.constant_result
        # cut down sliced constant sequences
        # (only when the slice expression as a whole has a known constant
        # result; 'start' and 'stop' were normalised earlier in this method)
        if node.constant_result is not not_a_constant:
            base = node.base
            if base.is_sequence_constructor and base.mult_factor is None:
                # slice the literal's argument list in place and let the
                # literal itself stand in for the slice expression
                base.args = base.args[start:stop]
                return base
            elif base.is_string_literal:
                base = base.as_sliced_node(start, stop)
                # as_sliced_node() may return None; in that case the
                # original slice node is kept
                if base is not None:
                    return base
        return node

    def visit_ComprehensionNode(self, node):
        """Replace a comprehension whose loop was pruned by an empty literal.

        After the children have been visited, a loop that is an empty
        StatListNode means nothing is ever added to the result, so the
        comprehension is replaced by an empty ListNode, SetNode or DictNode
        according to node.type.  Any other node is returned unchanged.
        """
        self.visitchildren(node)
        if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
            # loop was pruned already => transform into literal
            if node.type is Builtin.list_type:
                return ExprNodes.ListNode(
                    node.pos, args=[], constant_result=[])
            elif node.type is Builtin.set_type:
                return ExprNodes.SetNode(
                    node.pos, args=[], constant_result=set())
            elif node.type is Builtin.dict_type:
                return ExprNodes.DictNode(
                    node.pos, key_value_pairs=[], constant_result={})
        return node

    def visit_ForInStatNode(self, node):
        """Simplify for-in loops that iterate over a literal sequence.

        - A loop over an empty literal sequence never runs its body, so it
          is replaced by its else clause if it has one, and by an empty
          StatListNode otherwise.
        - A loop over a non-empty list literal is switched to iterate over
          the equivalent tuple (sequence.as_tuple()).

        Returns the replacement node, or the (possibly modified) loop node.
        """
        self.visitchildren(node)
        sequence = node.iterator.sequence
        if isinstance(sequence, ExprNodes.SequenceNode):
            if not sequence.args:
                if node.else_clause:
                    return node.else_clause
                else:
                    # don't break list comprehensions
                    # (an empty statement list is returned instead of None;
                    # visit_ComprehensionNode() above checks for exactly
                    # this shape)
                    return Nodes.StatListNode(node.pos, stats=[])
            # iterating over a list literal? => tuples are more efficient
            if isinstance(sequence, ExprNodes.ListNode):
                node.iterator.sequence = sequence.as_tuple()
        return node

    def visit_WhileStatNode(self, node):
        """Simplify while loops whose condition has a constant result.

        - Constant true condition: the condition is dropped (set to None)
          and so is the else clause, which only runs when the condition
          evaluates false.
        - Constant false condition: the body never runs, so the loop is
          replaced by its else clause.  NOTE(review): else_clause may be
          None here, in which case None is returned - presumably the
          visitor framework treats that as "remove the statement"; confirm.
        """
        self.visitchildren(node)
        if node.condition and node.condition.has_constant_result():
            if node.condition.constant_result:
                node.condition = None
                node.else_clause = None
            else:
                return node.else_clause
        return node

    def visit_ExprStatNode(self, node):
        """Drop expression statements whose expression is a known constant.

        Returns None for such statements and the node itself otherwise.
        """
        self.visitchildren(node)
        if not isinstance(node.expr, ExprNodes.ExprNode):
            # ParallelRangeTransform does this ...
            # (i.e. node.expr is not an ExprNode, so it cannot be asked for
            # a constant result below)
            return node
        # drop unused constant expressions
        if node.expr.has_constant_result():
            return None
        return node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    # fallback for every node type without a dedicated visit_*() method
    visit_Node = Visitor.VisitorTransform.recurse_to_children


class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
        - eliminate checks for None and/or types that became redundant after tree changes
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            # flag the assignment target; the flag is only set here, it is
            # read elsewhere (presumably during code generation - confirm)
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        Applies only when the called function is a C function referenced by
        the plain name 'isinstance', the call has exactly two arguments and
        the second argument's type is the builtin type named 'type'.  The
        call node is modified in place and returned.
        """
        self.visitchildren(node)
        if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
            if node.function.name == 'isinstance' and len(node.args) == 2:
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    cython_scope = self.context.cython_scope
                    # redirect the call to PyObject_TypeCheck() as declared
                    # in the cython scope
                    node.function.entry = cython_scope.lookup('PyObject_TypeCheck')
                    node.function.type = node.function.entry.type
                    # the second argument is cast to a PyTypeObject pointer
                    PyTypeObjectPtr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node

    def visit_PyTypeTestNode(self, node):
        """Remove tests for alternatively allowed None values from
        type tests when we know that the argument cannot be None
        anyway.
        """
        self.visitchildren(node)
        if not node.notnone:
            # the type test currently also accepts None; tighten it when
            # the argument reports that it can never be None
            if not node.arg.may_be_none():
                node.notnone = True
        return node

    def visit_NoneCheckNode(self, node):
        """Remove None checks from expressions that definitely do not
        carry a None value.
        """
        self.visitchildren(node)
        if not node.arg.may_be_none():
            # the check is redundant: replace it by the checked expression
            return node.arg
        return node

class ConsolidateOverflowCheck(Visitor.CythonTransform):
    """
    This class facilitates the sharing of overflow checking among all nodes
    of a nested arithmetic expression.  For example, given the expression
    a*b + c, where a, b, and c are all possibly overflowing ints, the entire
    sequence will be evaluated and the overflow bit checked only at the end.
    """
    # The top-most NumBinopNode of the arithmetic expression currently being
    # traversed, or None while outside of any such expression.
    overflow_bit_node = None

    def visit_Node(self, node):
        """Visit the children of a node that is not handled as NumBinopNode.

        While the children are visited, overflow_bit_node is reset to None
        so that arithmetic below this node starts an expression of its own;
        the previous value is restored afterwards.
        """
        if self.overflow_bit_node is not None:
            saved = self.overflow_bit_node
            self.overflow_bit_node = None
            self.visitchildren(node)
            self.overflow_bit_node = saved
        else:
            self.visitchildren(node)
        return node

    def visit_NumBinopNode(self, node):
        """Let nested binary operations share the overflow check of the
        outermost one.

        For a node with both overflow_check and overflow_fold set:
        - if no expression is being tracked, this node becomes the tracked
          top-level node and keeps its own overflow check;
        - otherwise the node is pointed at the tracked top-level node via
          node.overflow_bit_node and its own overflow_check is switched off.
        Other nodes just have their children visited, leaving the tracked
        node untouched.
        """
        if node.overflow_check and node.overflow_fold:
            top_level_overflow = self.overflow_bit_node is None
            if top_level_overflow:
                self.overflow_bit_node = node
            else:
                node.overflow_bit_node = self.overflow_bit_node
                node.overflow_check = False
            self.visitchildren(node)
            if top_level_overflow:
                # leaving the expression that this node started
                self.overflow_bit_node = None
        else:
            self.visitchildren(node)
        return node