1# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 2# 3# Copyright 2020-2024 Huawei Technologies Co., Ltd 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# ============================================================================ 17"""The module of parser python object, called by c++.""" 18 19from __future__ import absolute_import 20import os 21import sys 22import ast 23import re 24import hashlib 25import inspect 26import types 27from collections import namedtuple 28from typing import NamedTuple 29from textwrap import dedent 30import builtins 31import numpy 32 33import asttokens 34import astunparse 35 36from mindspore import Tensor, CSRTensor, COOTensor, RowTensor 37from mindspore import log as logger 38from mindspore import nn 39from mindspore import ops 40from mindspore import context 41from mindspore import tensor 42from mindspore.common.api import _MindsporeFunctionExecutor 43from mindspore.common import dtype as mstype 44from mindspore.common.parameter import Parameter 45from mindspore.common import mutable 46from mindspore.common._register_for_adapter import ms_adapter_registry 47from mindspore._checkparam import is_stub_tensor 48from .namespace import Namespace, ModuleNamespace, ClosureNamespace, ClassMemberNamespace 49from .resources import parse_object_map, ops_symbol_map, convert_object_map, convert_class_to_function_map, trope_ns 50from .resources import SYMBOL_UNDEFINE, constant_fold_functions 51from ...common.api import _convert_python_data 52 53# Define resolve type 54RESOLVE_TYPE_NONE = 0 # Resolve None. 55RESOLVE_TYPE_FUNCTION = 1 # Resolve function. 56RESOLVE_TYPE_METHOD = 2 # Resolve class method. 57RESOLVE_TYPE_CLASS_TYPE = 3 # Resolve class type. 58RESOLVE_TYPE_CLASS_INSTANCE = 4 # Resolve the class instance of common class. 59RESOLVE_TYPE_NAMESPACE_INSTANCE = 5 # Resolve the namespace instance. 60RESOLVE_TYPE_NUMPY_INT_NUMBER = 6 # Resolve numpy int number. 61RESOLVE_TYPE_NUMPY_FLOAT_NUMBER = 7 # Resolve numpy float number. 62RESOLVE_TYPE_NUMPY_BOOL_NUMBER = 8 # Resolve numpy bool number. 63RESOLVE_TYPE_TUPLE = 9 # Resolve builtin tuple type. 64RESOLVE_TYPE_LIST = 10 # Resolve builtin list type. 65RESOLVE_TYPE_INVALID = 0xFF # Resolve invalid. 66 67# Define the class instance detail type 68# When the type is RESOLVE_TYPE_CLASS_INSTANCE 69CLASS_INSTANCE_TYPE_CELL = 0 # Class instance type is Cell 70CLASS_INSTANCE_TYPE_PRIMITIVE = 1 # Class instance type is Primitive 71CLASS_INSTANCE_TYPE_NUMPY_ARRAY = 2 # Class instance type is Numpy Array 72CLASS_INSTANCE_TYPE_TENSOR = 3 # Class instance type is Tensor 73CLASS_INSTANCE_TYPE_ADAPTER_TENSOR = 4 # Class instance type is Adapter Tensor 74CLASS_INSTANCE_TYPE_INVALID = 0xFF 75 76# Ast main type 77AST_MAIN_TYPE_STMT = 0 # ast.Stmt 78AST_MAIN_TYPE_EXPR = 1 # ast.Expr 79AST_MAIN_TYPE_SLICE = 2 # ast.Slice 80AST_MAIN_TYPE_UNKNOWN = 0xFF # unknown 81 82# Ast sub type 83AST_SUB_TYPE_AND = 3 # ast.And 84AST_SUB_TYPE_OR = 4 # ast.Or 85AST_SUB_TYPE_NAME = 5 # ast.Name 86AST_SUB_TYPE_TUPLE = 6 # ast.Tuple 87AST_SUB_TYPE_LIST = 7 # ast.List 88AST_SUB_TYPE_SUBSCRIPT = 8 # ast.Subscript 89AST_SUB_TYPE_STARRED = 9 # ast.Starred 90AST_SUB_TYPE_ATTRIBUTE = 10 # ast.Attribute 91AST_SUB_TYPE_DICT = 11 # ast.Dict 92AST_SUB_TYPE_UNKNOWN = 0xFF # unknown 93 94# Syntax support 95SYNTAX_SUPPORTED = 0 # Supported syntax 96SYNTAX_UNSUPPORTED_INTERNAL_TYPE = 1 # Unsupported internal type 97SYNTAX_UNSUPPORTED_EXTERNAL_TYPE = 2 # Unsupported external type 98SYNTAX_HYBRID_TYPE = 3 # Hybrid type 99SYNTAX_UNSUPPORTED_NAMESPACE = 4 # Unsupported namespace 100 101# Module source location 102MODULE_FROM_MINDSPORE = 0 103MODULE_FROM_THIRDPARTY = 1 104MODULE_FROM_USER_WORKSPACE = 2 105 106 107# Process expr statement white list 108# Add as needed, eg: "clear", "extend", "insert", "remove", "reverse" 109parse_expr_statement_white_list = ( 110 "append", "insert", "clear", "reverse", "extend", "update", 111) 112 113_builtin_function_or_method_type = type(abs) 114 115# Unsupported python builtin type in graph mode. 116_unsupported_python_builtin_type = ( 117 set, dict, slice, complex, reversed, type, 118) 119 120# Unsupported python builtin type in JIT Fallback. 121_fallback_unsupported_python_builtin_type = ( 122 compile, eval, exec 123) 124 125_modules_from_mindspore = ( 126 "mindspore", "msadapter", "mindocr", "mindyolo", "mindnlp", "mindcv", "mindspore_rec", "mindaudio", "mindone", 127 "mindspore_rl", "mindformers", "mindpet", "mindpose", "mindface", "mindsearch", "mindinsight", "mindelec", 128 "mindflow", "mindsponge", "mindearth", "sciai", "mindquantum", "mindarmour", "mindpandas", "mindvision", 129 "mindspore_gl", "mindspore_federated", "mindspore_gs", "mindspore_serving", "mindspore_xai", "mindspore_hub", 130 "ringmo_framework", "troubleshooter", "mindtorch", 131) 132 133_global_params = {} 134 135 136def _convert_map(): 137 """Get convert object map""" 138 adapter_convert_map = ms_adapter_registry.convert_map 139 return adapter_convert_map if adapter_convert_map else convert_object_map 140 141 142def create_slice_obj(start, end, step): 143 """Create slice object""" 144 return slice(start, end, step) 145 146 147def parse_cb(func, parse_method=None): 148 """Implements the function of parse.""" 149 return Parser(func, parse_method) 150 151 152def get_attr_from_object(obj, attr_name=None): 153 """ 154 Get attr from object. 155 156 Args: 157 obj(Object): Instance of class or module. 158 attr_name(str): Attribute name to check. 159 160 Returns: 161 Object, obj's attr. 162 """ 163 164 if obj is not None and attr_name is not None and hasattr(obj, attr_name): 165 return getattr(obj, attr_name) 166 return None 167 168 169def check_attr_is_property(obj, attr_name): 170 """ 171 Check if the attribute is decorated by @property. 172 173 Args: 174 obj(Object): Instance of a class. 175 attr_name(str): Attribute name to check. 176 177 Returns: 178 obj(bool): If the attribute is decorated by @property. 179 """ 180 logger.debug(f"attr_name:{attr_name}") 181 logger.debug(f"obj.__class__.__dict__.keys():{obj.__class__.__dict__.keys()}") 182 if attr_name in obj.__class__.__dict__.keys() and isinstance(obj.__class__.__dict__[attr_name], property): 183 attr_obj = obj.__class__.__dict__[attr_name] 184 if (hasattr(attr_obj, 'fget')) and hasattr(attr_obj.fget, '__code__'): 185 logger.debug(f'The attribute {attr_name} is decorated by @property.') 186 return True 187 return False 188 189 190def get_parse_method_of_class(obj, parse_method=None): 191 """ 192 Get parse method of class. 193 194 Args: 195 obj(Object): Instance of class. 196 parse_method(str): Save the method name. Cell object has default method named 'construct'. 197 198 Returns: 199 Function, obj's method. 200 """ 201 202 method_name = None 203 if parse_method is not None: 204 method_name = parse_method 205 elif isinstance(obj, nn.Cell): 206 if obj._enable_backward_hook: 207 method_name = "_backward_hook_construct" 208 else: 209 method_name = "construct" 210 211 return get_attr_from_object(obj, method_name) 212 213 214def get_bprop_method_of_class(obj, parse_method=None): 215 """ 216 Get bprop method of class. 217 218 Args: 219 obj (Object): Instance of class. 220 parse_method(str): Save the method name. Cell object has default method named 'bprop'. 221 222 Returns: 223 Function, obj's method. 224 """ 225 226 if isinstance(obj, nn.Cell): 227 method_name = "bprop" 228 return get_attr_from_object(obj, method_name) 229 return None 230 231 232def resolve_symbol(namespace, symbol): 233 """ 234 Resolve a symbol. 235 236 Note: 237 Can't get function when use closure function. So save the fn on namespace. 238 239 Args: 240 namespace (Object): Symbol's namespace. 241 symbol (str): Need resolve symbol. 242 243 Returns: 244 Object, resolve result of symbol. 245 """ 246 # All exceptions need to be caught in this function 247 try: 248 resolve_ = namespace[symbol] 249 250 # The list and dict is not hashable, it can not be key for the map, just return the result 251 if isinstance(resolve_, (tuple, list, dict)): 252 return resolve_ 253 if hasattr(resolve_, "__self__") and isinstance(resolve_.__self__, (tuple, list, dict)): 254 return resolve_ 255 if getattr(resolve_, "__hash__") is None: 256 return resolve_ 257 258 # If need trope the obj 259 convert_map = _convert_map() 260 if resolve_ in convert_map: 261 resolve_ = convert_map.get(resolve_) 262 logger.debug("Convert resolve: %r", resolve_) 263 except Exception as e: 264 if isinstance(e, NotImplementedError): 265 raise e 266 resolve_ = mstype._null 267 logger.debug("Resolve exception occurred, value: %r", e) 268 logger.debug("Resolve type is invalid, namespace: %s, symbol: %s", 269 namespace.__str__(), symbol) 270 271 if isinstance(resolve_, _MindsporeFunctionExecutor): 272 logger.debug("Resolve class _MindsporeFunctionExecutor, resolve fn instead.") 273 resolve_ = resolve_.fn 274 logger.debug(f"Found '{symbol}' in {namespace.__str__()}, resolved: {resolve_} / {type(resolve_)}") 275 return resolve_ 276 277 278def generate_scope(obj): 279 """Generate the scope for every cell object in the network.""" 280 if isinstance(obj, nn.Cell): 281 obj.generate_scope() 282 283 284def get_scope_name(obj): 285 """Returns the scope of a cell object in one network.""" 286 if isinstance(obj, nn.Cell): 287 return obj.get_scope() 288 return None 289 290 291def get_type(obj): 292 """Returns the type string of input object""" 293 return type(obj) 294 295 296def get_object_key(obj): 297 """Return the function key: module + name.""" 298 obj_key = "" 299 if hasattr(obj, "__name__"): 300 if hasattr(obj, "cell_init_args"): 301 obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args) 302 obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj)) 303 else: 304 # `<class 'xxxxxxx'>` 305 # -> `xxxxxxx` 306 tag = str(obj.__class__)[8:-2] 307 if hasattr(obj, "cell_init_args"): 308 obj_key = "%s_ID" % (tag + obj.cell_init_args) 309 obj_id = "%s_ID%d" % (tag, id(obj)) 310 logger.debug("obj_key: %s, obj_id: %s", obj_key, obj_id) 311 312 # Method has same id of different instance 313 if isinstance(obj, types.MethodType): 314 method_instance = obj.__self__ 315 instance_id = "%s_ID%d" % (str(method_instance.__class__.__name__), id(method_instance)) 316 if isinstance(method_instance, (tuple, list, dict)): 317 obj_id = instance_id + obj_id 318 else: 319 obj_id = instance_id + obj_id + str(obj.__hash__()) 320 return obj_id, obj_key 321 322 323def is_class_member_of_self(node): 324 """Check the attr is class member variable.""" 325 type_ = node.__class__.__name__ 326 if type_ == "Attribute": 327 if not hasattr(node.value, "id"): 328 return False 329 id_ = node.value.id 330 if id_ == "self": 331 return True 332 return False 333 334 335def is_class_member_recursive(node): 336 """Check the attr is class member variable resurcively.""" 337 type_ = node.__class__.__name__ 338 if type_ == "Attribute": 339 if hasattr(node.value, "value"): 340 return is_class_member_recursive(node.value) 341 if not hasattr(node.value, "id"): 342 return False 343 id_ = node.value.id 344 if id_ == "self": 345 return True 346 return False 347 348 349def get_obj_id(obj): 350 """Get the obj id.""" 351 return str(id(obj)) 352 353 354def is_lambda_function(obj): 355 """Determine whether is a lambda function.""" 356 if isinstance(obj, types.FunctionType): 357 source_code = inspect.getsource(obj) 358 return "lambda" in source_code and "<function" in str(obj) and "<lambda>" in str(obj) 359 return False 360 361 362def get_obj_type(obj): 363 """Get the obj type.""" 364 logger.debug("Get object type: %r", obj) 365 obj_type = RESOLVE_TYPE_INVALID 366 if obj is None: 367 obj_type = RESOLVE_TYPE_NONE 368 elif isinstance(obj, types.FunctionType) or type(obj).__name__ == 'cython_function_or_method': 369 obj_type = RESOLVE_TYPE_FUNCTION 370 elif isinstance(obj, types.MethodType): 371 obj_type = RESOLVE_TYPE_METHOD 372 elif isinstance(obj, type): 373 obj_type = RESOLVE_TYPE_CLASS_TYPE 374 elif isinstance(obj, Namespace): 375 obj_type = RESOLVE_TYPE_NAMESPACE_INSTANCE 376 elif isinstance(obj, tuple): 377 obj_type = RESOLVE_TYPE_TUPLE 378 elif isinstance(obj, list): 379 obj_type = RESOLVE_TYPE_LIST 380 elif _is_class_instance(obj): 381 obj_type = RESOLVE_TYPE_CLASS_INSTANCE 382 elif _is_numpy_int_number(obj): 383 obj_type = RESOLVE_TYPE_NUMPY_INT_NUMBER 384 elif _is_numpy_float_number(obj): 385 obj_type = RESOLVE_TYPE_NUMPY_FLOAT_NUMBER 386 elif _is_numpy_bool_number(obj): 387 obj_type = RESOLVE_TYPE_NUMPY_BOOL_NUMBER 388 else: 389 obj_type = RESOLVE_TYPE_INVALID 390 return obj_type 391 392 393def check_obj_bool(obj): 394 """Check if the type of the current object is bool.""" 395 logger.debug("Check if the type of the current object(%r) is bool: %r", obj, bool(obj)) 396 return bool(obj) 397 398 399def get_class_instance_type(obj): 400 """Get the class instance detail type.""" 401 # Check the obj type 402 logger.debug("Get the class type(%r)", obj) 403 if isinstance(obj, nn.Cell): 404 return CLASS_INSTANCE_TYPE_CELL 405 if isinstance(obj, ops.Primitive): 406 return CLASS_INSTANCE_TYPE_PRIMITIVE 407 if isinstance(obj, numpy.ndarray): 408 return CLASS_INSTANCE_TYPE_NUMPY_ARRAY 409 return CLASS_INSTANCE_TYPE_INVALID 410 411 412def _is_ms_class(obj): 413 """Check if obj is ms_class object.""" 414 return hasattr(obj, '__ms_class__') 415 416 417def _is_class_instance(obj): 418 """Confirm the obj is class instance.""" 419 return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_ms_class(obj) or hasattr(obj, '__parse_method__') 420 421 422def _is_numpy_int_number(obj): 423 """Confirm the obj is numpy int number.""" 424 return isinstance(obj, (numpy.int8, numpy.int16, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint64)) 425 426 427def _is_numpy_float_number(obj): 428 """Confirm the obj is numpy float number.""" 429 return isinstance(obj, (numpy.float16, numpy.float32, numpy.float64)) 430 431 432def _is_numpy_bool_number(obj): 433 """Confirm the obj is numpy bool number.""" 434 return isinstance(obj, numpy.bool_) 435 436 437def _convert_tuple_to_args_kwargs(params): 438 """Convert tuple to args and kwargs.""" 439 args = tuple() 440 kwargs = dict() 441 for param in params: 442 if isinstance(param, dict): 443 kwargs.update(param) 444 else: 445 args += (param,) 446 return (args, kwargs) 447 448 449def is_supported_create_instance_type(cls_type): 450 """Check if cls_type is a supported instance type.""" 451 return issubclass(cls_type, (nn.Cell, ops.Primitive, ops.GradOperation)) or _is_ms_class(cls_type) 452 453 454def create_instance(cls_type, params=None): 455 """Create python instance.""" 456 if not isinstance(cls_type, type): 457 logger.warning(f"create_instance(), cls_type is not a type, cls_type: {cls_type}") 458 return None 459 460 # Check the type, now only support nn.Cell and Primitive. 461 obj = None 462 if is_supported_create_instance_type(cls_type): 463 # Check arguments, only support *args or **kwargs. 464 if params is None: 465 obj = cls_type() 466 elif isinstance(params, tuple): 467 args, kwargs = _convert_tuple_to_args_kwargs(params) 468 logger.debug(f"create_instance(), args: {args}, kwargs: {kwargs}") 469 if args and kwargs: 470 obj = cls_type(*args, **kwargs) 471 elif args: 472 obj = cls_type(*args) 473 elif kwargs: 474 obj = cls_type(**kwargs) 475 # If invalid parameters. 476 if obj is None: 477 raise ValueError(f"When call 'create_instance', the parameter should be *args or **kwargs, " 478 f"but got {params.__class__.__name__}, params: {params}") 479 return obj 480 481 482def convert_class_to_function(cls_str, cls_obj): 483 """Convert class to function.""" 484 if issubclass(cls_obj, (Parameter, ops.MultitypeFuncGraph)): 485 raise ValueError(f"Failed to compile in GRAPH_MODE because creating {cls_str} instances is not " 486 f"supported in 'construct' or @jit decorated function. Try to create {cls_str} " 487 f"instances external such as initialized in the method '__init__' before assigning. " 488 f"For more details, please refer to " 489 f"https://www.mindspore.cn/docs/zh-CN/master/design/dynamic_graph_and_static_graph.html \n") 490 return convert_class_to_function_map.get(cls_str) 491 492 493def python_isinstance(x, cmp_type): 494 """Python isinstance function.""" 495 # Convert _c_expression tensor to python tensor. 496 x = _convert_python_data(x) 497 return isinstance(x, cmp_type) 498 499 500def ms_isinstance(x, cmp_type): 501 """Isinstance for ms type.""" 502 pytype_to_mstype = { 503 bool: mstype.Bool, 504 int: mstype.Int, 505 float: mstype.Float, 506 str: mstype.String, 507 list: mstype.List, 508 tuple: mstype.Tuple, 509 dict: mstype.Dict, 510 Tensor: mstype.TensorType, 511 Parameter: mstype.RefType, 512 slice: mstype.Slice, 513 } 514 if cmp_type not in pytype_to_mstype: 515 return False 516 if isinstance(x, mstype.Bool) and cmp_type == int: 517 return True 518 return isinstance(x, pytype_to_mstype.get(cmp_type)) 519 520 521def is_cell_list(obj): 522 """Check if obj is nn.CellList""" 523 return isinstance(obj, nn.CellList) 524 525 526def convert_cell_list_to_sequence(obj): 527 """Convert nn.CellList to sequence.""" 528 if not hasattr(obj, "__cell_as_list__"): 529 raise TypeError(f"Obj should be nn.CellList, but got {obj}") 530 if not hasattr(obj, "_cells"): 531 raise AttributeError(f"nn.CellList is missing _cells property.") 532 cells = getattr(obj, "_cells") 533 return list(cells.values()) 534 535 536def get_obj_from_sequence(obj, index): 537 """Implement `tuple_getitem`.""" 538 if not isinstance(obj, (tuple, list)): 539 raise TypeError(f"Should not get item from a object that not sequence type, obj: {obj}") 540 # Not check index out of range by self. 541 return obj[index] 542 543 544def get_module_namespace(obj): 545 """Get the module's namespace.""" 546 logger.debug("get module namespace, module: %r", obj) 547 mod_namespace = None 548 if isinstance(obj, types.ModuleType): 549 mod_namespace = ModuleNamespace(obj.__name__) 550 else: 551 logger.warning("Module(%r) is invalid, get namespace failure!", obj) 552 return mod_namespace 553 554 555def get_class_member_namespace_symbol(obj): 556 """Get obj class member type.""" 557 logger.debug("get class instance namespace, object: %r", obj) 558 class_namespace = ClassMemberNamespace(obj) 559 logger.debug("class namespace: %r", class_namespace) 560 return class_namespace 561 562 563def get_obj_defined_from_obj_type(obj_type): 564 """Get the class defined from object type which is in BuiltInMap.""" 565 logger.debug("get the object type: %r", obj_type) 566 567 def func(): 568 pass 569 570 obj_type_defined_map = { 571 "Tensor": Tensor, 572 "RowTensor": RowTensor, 573 "COOTensor": COOTensor, 574 "CSRTensor": CSRTensor, 575 "Parameter": Parameter, 576 "String": "", 577 "Function": func, 578 "Int": int, 579 "Float": float, 580 "UInt": int, 581 "Bool": bool, 582 "List": list, 583 "Tuple": tuple, 584 "Dictionary": dict, 585 "NamedTuple": NamedTuple, 586 } 587 588 return obj_type_defined_map.get(obj_type) 589 590 591def is_class_type(cls): 592 """Check if cls is a class type.""" 593 return isinstance(cls, type) 594 595 596def get_adapter_tensor_attr(name): 597 """Get the method or @property modified function of the class, excluding those inherited from parent class.""" 598 cls = ms_adapter_registry.tensor 599 properties = [key for key, value in vars(cls).items() if isinstance(value, property)] 600 if name in properties: 601 return getattr(cls, name).fget, True 602 methods = [key for key, value in vars(cls).items() if inspect.isfunction(value)] 603 if name in methods: 604 return getattr(cls, name), False 605 return None, False 606 607 608def is_adapter_tensor_class(cls): 609 """Check if cls is adapter tensor type.""" 610 return cls in (Tensor, ms_adapter_registry.tensor) 611 612 613def is_adapter_parameter_class(cls): 614 """Check if cls is adapter parameter type.""" 615 return cls in (Parameter, ms_adapter_registry.parameter) 616 617 618def get_ms_class_name(cls): 619 """Get the name of the class instance decorated with jit_class.""" 620 if isinstance(cls, type): 621 return cls.__name__ 622 return cls.__class__.__name__ 623 624 625def convert_to_ms_tensor(data): 626 """Convert C++ tensor to mindspore tensor.""" 627 return Tensor(data) 628 629 630def convert_to_ms_csrtensor(data): 631 """Convert C++ csrtensor to mindspore csrtensor.""" 632 return CSRTensor(csr_tensor=data) 633 634 635def convert_to_ms_cootensor(data): 636 """Convert C++ cootensor to mindspore cootensor.""" 637 return COOTensor(coo_tensor=data) 638 639 640def convert_to_namedtuple(type_name, key_sequeue, value_sequeue): 641 """Convert C++ namedtuple to python object namedtuple.""" 642 logger.debug(f"type_name: {type_name}, key_sequeue: {key_sequeue}, value_sequeue: {value_sequeue}") 643 return namedtuple(type_name, [*key_sequeue])(*value_sequeue) 644 645 646def get_object_description(obj, fname, fline): 647 """Return method or funcition description for error report, include location, class name, etc.""" 648 if isinstance(obj, types.MethodType): 649 obj_cls = obj.__self__.__class__ 650 class_name = f"{obj_cls.__module__}.{obj_cls.__qualname__}" 651 cls_fname = inspect.getfile(obj_cls) 652 _, cls_fline = inspect.getsourcelines(obj_cls) 653 class_loc = f"{cls_fname}:{cls_fline}" 654 return f"bound method '{obj.__name__}' at {fname}:{fline} of <{class_name} at {class_loc} object>" 655 if isinstance(obj, types.FunctionType): 656 return f"function '{obj.__name__}' at {fname}:{fline}" 657 if isinstance(obj, ast.FunctionDef): 658 return f"function '{obj.name}' at {fname}:{fline}" 659 if isinstance(obj, ast.Attribute): 660 return f"attribute " 661 return str(obj) 662 663 664def expand_expr_statement(node): 665 """ 666 Process the expr statement and expand it. 667 668 Returns: 669 tuple, (True, expr.value, x)/(False, None, None). 670 """ 671 if isinstance(node, ast.Expr): 672 expr_value = node.value 673 if isinstance(expr_value, ast.Call): 674 func = expr_value.func 675 if isinstance(func, ast.Attribute) and \ 676 hasattr(func, "attr") and \ 677 hasattr(func, "value"): 678 method = func.attr 679 target = func.value 680 if method in parse_expr_statement_white_list: 681 logger.debug("Expand expr, target:%s, method:%s", target, method) 682 return True, expr_value, target 683 if not isinstance(expr_value, ast.Str): 684 return True, expr_value 685 return (False,) 686 687 688def get_ast_namespace_symbol(obj): 689 """Get obj type and namespace and symbol.""" 690 # Get symbol from object map. 691 ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE) 692 logger.debug("ops info: %r", ops_info) 693 return ops_info 694 695 696def get_operation_symbol(obj): 697 """Get obj operation symbol.""" 698 ops_symbol = ops_symbol_map.get(type(obj), SYMBOL_UNDEFINE) 699 logger.debug("ops symbol: %s", ops_symbol) 700 return ops_symbol 701 702 703def get_operation_namespace_symbol(var: str): 704 """Get operation namespace and symbol.""" 705 ops_info = (trope_ns, var) 706 logger.debug("get operation ops info: %r", ops_info) 707 return ops_info 708 709def get_ast_type(node): 710 """Get the ast type.""" 711 ast_type = AST_SUB_TYPE_UNKNOWN 712 if isinstance(node, ast.And): 713 ast_type = AST_SUB_TYPE_AND 714 elif isinstance(node, ast.Or): 715 ast_type = AST_SUB_TYPE_OR 716 elif isinstance(node, ast.Name): 717 ast_type = AST_SUB_TYPE_NAME 718 elif isinstance(node, ast.Tuple): 719 ast_type = AST_SUB_TYPE_TUPLE 720 elif isinstance(node, ast.List): 721 ast_type = AST_SUB_TYPE_LIST 722 elif isinstance(node, ast.Subscript): 723 ast_type = AST_SUB_TYPE_SUBSCRIPT 724 elif isinstance(node, ast.Starred): 725 ast_type = AST_SUB_TYPE_STARRED 726 elif isinstance(node, ast.Attribute): 727 ast_type = AST_SUB_TYPE_ATTRIBUTE 728 elif isinstance(node, ast.Dict): 729 ast_type = AST_SUB_TYPE_DICT 730 else: 731 ast_type = AST_SUB_TYPE_UNKNOWN 732 return ast_type 733 734 735def get_node_type(node): 736 """Process an ast node.""" 737 method_name = f"{node.__class__.__name__}" 738 node_type = [method_name] 739 # Judge the ast main type. 740 if isinstance(node, ast.stmt): 741 node_type.append(AST_MAIN_TYPE_STMT) 742 elif isinstance(node, (ast.expr, ast.slice)) or node is None: 743 # ast.slice and ast.expr should be expr. 744 node_type.append(AST_MAIN_TYPE_EXPR) 745 else: 746 node_type.append(AST_MAIN_TYPE_UNKNOWN) 747 return node_type 748 749 750def get_args_default_values(node): 751 """ 752 Get the args'default values of parse object. 753 754 Examples: 755 - Function: 756 func(a, b, *c, d=0, **e) 757 - The ast is as below: 758 args=arguments( 759 args=[arg(a), arg(b)], vararg=arg(c), kwonlyargs=[arg(d)], kw_defaults=[Num(0)], kwarg=arg(e) 760 ) 761 762 - Function: 763 func(a, b, c=1) 764 - The ast is as below: 765 args=arguments( 766 args=[arg(a), arg(b), arg(c)], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[Num(1)] 767 ) 768 """ 769 defaults = [None] * (len(node.args.args) - len(node.args.defaults)) 770 defaults = defaults + node.args.defaults 771 if node.args.vararg: 772 defaults.append(None) 773 defaults = defaults + node.args.kw_defaults 774 if node.args.kwarg: 775 defaults.append(None) 776 return defaults 777 778 779def get_args(node): 780 """Get the arg of parse object. The order is [args, vararg, kwonlyargs, kwarg]""" 781 args = [] 782 # Process position args. 783 for arg in node.args.args: 784 args.append(arg) 785 # Process vararg: vararg is append after position. 786 if node.args.vararg: 787 args.append(node.args.vararg) 788 # Process kwonlyargs: kwonlyargs is append after vararg. 789 if node.args.kwonlyargs: 790 for kwonlyarg in node.args.kwonlyargs: 791 args.append(kwonlyarg) 792 # Process kwarg: kwarg is append after vararg. 793 if node.args.kwarg: 794 args.append(node.args.kwarg) 795 return args 796 797 798def get_arg_spec_and_default_values(func): 799 """Get the full arg specification and the default arg values of a function""" 800 arg_spec = inspect.getfullargspec(func) 801 defaults = {} 802 args_len = len(arg_spec.args) 803 if arg_spec.defaults: 804 defaults_len = len(arg_spec.defaults) 805 for i in range(defaults_len): 806 defaults[arg_spec.args[args_len - i - 1]] = arg_spec.defaults[defaults_len - i - 1] 807 if arg_spec.kwonlydefaults: 808 for k, v in arg_spec.kwonlydefaults.items(): 809 defaults[k] = v 810 return arg_spec, defaults 811 812 813def _convert_stub_tensor(data): 814 """Convert stub tensor output to tensor""" 815 if is_stub_tensor(data): 816 return data.stub_sync() 817 if isinstance(data, tuple): 818 # Handle namedtuple since its type is tuple. 819 if hasattr(data, "_fields"): 820 type_name = data.__class__.__name__ 821 data_dict = data._asdict() 822 fields = data_dict.keys() 823 return namedtuple(type_name, fields)(**_convert_stub_tensor(data_dict)) 824 return tuple(_convert_stub_tensor(x) for x in data) 825 if data.__class__ is list: 826 # Keep the list object not change. 827 for i in range(len(data)): 828 data[i] = _convert_stub_tensor(data[i]) 829 return data 830 if data.__class__ is dict: 831 # Keep the dict object not change. 832 keys = tuple(data.keys()) 833 for key in keys: 834 data[_convert_stub_tensor(key)] = _convert_stub_tensor(data.pop(key)) 835 return data 836 return data 837 838 839def eval_script(exp_str, params): 840 """Evaluate a python expression.""" 841 if not isinstance(params, tuple): 842 raise ValueError(f"eval_script(), params is not a tuple, params: {params}") 843 if len(params) != 2: 844 raise ValueError(f"eval_script(), params tuple length is wrong, params: {params}") 845 846 # Eval function parses the expression argument and evaluates it as a python expression. 847 global_params = params[0] 848 local_params = params[1] 849 try: 850 local_params = _convert_python_data(local_params) 851 res = eval(exp_str, global_params, local_params) 852 res = _convert_stub_tensor(res) 853 except Exception as e: 854 error_info = f"When eval '{exp_str}' by using JIT Fallback feature, an error occurred: " + str(e) 855 logger.debug(error_info) 856 raise e 857 858 return res 859 860 861def get_script_id_attrs(script): 862 """Get the ids for the ast of script""" 863 ast_tokens = asttokens.ASTTokens(script, parse=True) 864 ast_tree = ast_tokens.tree 865 ast_str = astunparse.dump(ast_tree) 866 ids = re.findall(r"id='(.+?)'", ast_str) 867 id_sets = set(ids) 868 pattern = r"Attribute\(\s*value.*?id='(.*?)'.*?attr='(.*?)'.*?\)" 869 matches = re.findall(pattern, ast_str, re.DOTALL) 870 id_attrs = ["{}.{}".format(match[0], match[1]) for match in matches] 871 logger.debug(f'id_attrs: {id_attrs}') 872 id_attrs_set = set(id_attrs) 873 logger.debug(f'id_attrs_set: {id_attrs_set}') 874 res = id_sets.union(id_attrs_set) 875 logger.debug(f'res: {res}') 876 return res 877 878 879def generate_lambda_object(script): 880 """Generate lambda expression object using script""" 881 return eval(script, {}, {}) 882 883 884def get_global_params(): 885 """Get the global parameter.""" 886 logger.debug(f"get global_dict: {_global_params}") 887 return _global_params 888 889 890def get_dtype(name: str): 891 """get mstype from name""" 892 return get_attr_from_object(mstype, name) 893 894 895def check_attrs(target_object, func_name: str): 896 """Check if attr is overridden.""" 897 if isinstance(target_object, Tensor): 898 return False 899 if hasattr(target_object, func_name): 900 if not hasattr(target_object.__class__.__base__, func_name): 901 if target_object.__class__.__base__ is object: 902 return False 903 return True 904 if getattr(target_object.__class__, func_name) is not getattr(target_object.__class__.__base__, func_name): 905 return True 906 return False 907 908 909def check_is_subclass(target_object, parent): 910 """Check if target_object is a subclass.""" 911 if issubclass(target_object.__class__, parent): 912 if target_object.__class__ is not parent: 913 return True 914 return False 915 916 917class ThirdPartyLibraryChecker: 918 """ 919 Check if a module or function is from third-party libraries. 920 921 Rules for detecting third-party libraries: 922 923 1. The mindspore module and its suite are not third-party libraries. 924 925 2. Python built-in modules and python standard libraries are third-party libraries. 926 927 3. Modules with module names provided by MS_JIT_IGNORE_MODULES are treated as third-party 928 libraries, but those provided by MS_JIT_MODULES are not. 929 930 4. Third-party libraries have 'site-packages' in their installation path. 931 """ 932 def __init__(self): 933 self.user_workspace_dir = self.get_top_level_module_path(os.getcwd()) 934 self.python_builtin_dir = os.path.abspath(os.path.dirname(os.__file__)) 935 936 @staticmethod 937 def get_jit_modules(): 938 """Modules in jit_modules require jit.""" 939 jit_modules = [] 940 # Get jit modules from environment variable. 941 env_modules = os.getenv('MS_JIT_MODULES') 942 if env_modules is not None: 943 jit_modules = env_modules.split(',') 944 return jit_modules 945 946 @staticmethod 947 def get_jit_ignore_modules(): 948 """Modules in jit_ignore_modules do not need jit.""" 949 jit_ignore_modules = [] 950 # Get jit ignore modules from environment variable. 951 env_modules = os.getenv('MS_JIT_IGNORE_MODULES') 952 if env_modules is not None: 953 jit_ignore_modules = env_modules.split(',') 954 # sys.builtin_module_names do not need jit. 955 jit_ignore_modules.extend(sys.builtin_module_names) 956 return jit_ignore_modules 957 958 @staticmethod 959 def is_mindspore_related_module(module): 960 """Check if module is mindspore module or its suite.""" 961 module_leftmost_name = module.__name__.split('.')[0] 962 return module_leftmost_name in _modules_from_mindspore 963 964 def get_top_level_module_path(self, module_path): 965 """Get the path of the top level package of the current working directory.""" 966 module_abspath = os.path.abspath(module_path) 967 upper_path = os.path.abspath(os.path.dirname(module_abspath)) 968 if module_abspath == upper_path: 969 return module_abspath 970 # Check whether __init__.py exists in the upper directory. 971 init_path = os.path.join(upper_path, '__init__.py') 972 # If the path does not exist or is accessed without permission, os.path.isfile returns false. 973 if os.path.isfile(init_path): 974 module_abspath = self.get_top_level_module_path(upper_path) 975 return module_abspath 976 977 def is_third_party_module(self, module): 978 """Check if module is a third-party library.""" 979 module_leftmost_name = module.__name__.split('.')[0] 980 # Modules in jit_ignore_modules are treated as third-party libraries, such as sys.builtin_module_names. 981 jit_ignore_modules = self.get_jit_ignore_modules() 982 if module_leftmost_name in jit_ignore_modules: 983 logger.debug(f"Found third-party module '{module_leftmost_name}' in jit_ignore_modules.") 984 return True 985 # Modules in jit_modules require jit and they are considered to be in user workspace. 986 jit_modules = self.get_jit_modules() 987 if module_leftmost_name in jit_modules: 988 logger.debug(f"Found user-defined module '{module_leftmost_name}' in jit_modules.") 989 return False 990 # A modules without __file__ attribute is considered to be in user workspace. 991 if not hasattr(module, '__file__'): 992 return False 993 module_path = os.path.abspath(module.__file__) 994 # Python builtin modules are treated as third-party libraries. 995 if module_path.startswith(self.python_builtin_dir): 996 logger.debug(f"Found python builtin module '{module.__name__}', which is a third-party module.") 997 return True 998 # Check if module is under user workspace directory. 999 if module_path.startswith(self.user_workspace_dir): 1000 logger.debug(f"Found module '{module.__name__}' in user_workspace_dir: {self.user_workspace_dir}") 1001 return False 1002 # Third-party modules are under site-packages. 1003 split_path = module_path.split(os.path.sep) 1004 result = "site-packages" in split_path 1005 if result: 1006 logger.debug(f"Found third-party module '{module.__name__}' in path '{module_path}'") 1007 return result 1008 1009 def get_module_source_location(self, module): 1010 """Get the source location of the module.""" 1011 if self.is_mindspore_related_module(module): 1012 return MODULE_FROM_MINDSPORE 1013 if self.is_third_party_module(module): 1014 return MODULE_FROM_THIRDPARTY 1015 return MODULE_FROM_USER_WORKSPACE 1016 1017 def is_third_party_module_or_function(self, value): 1018 """Check if value is from a third-party library.""" 1019 if inspect.ismodule(value): 1020 module = value 1021 elif (isinstance(value, types.FunctionType) and not hasattr(value, "__jit_function__")) or \ 1022 (isinstance(value, types.MethodType) and not hasattr(value.__func__, "__jit_function__")): 1023 value_hashable = True 1024 try: 1025 hash(value) 1026 except TypeError: 1027 value_hashable = False 1028 if value_hashable and value in _convert_map(): 1029 return False 1030 module = inspect.getmodule(value) 1031 if module is None: 1032 return False 1033 else: 1034 return False 1035 return self.get_module_source_location(module) == MODULE_FROM_THIRDPARTY 1036 1037 1038third_party_checker = ThirdPartyLibraryChecker() 1039 1040 1041def is_from_third_party_library(value): 1042 """Check if value is from a third-party library.""" 1043 return third_party_checker.is_third_party_module_or_function(value) 1044 1045 1046def get_const_abs(obj): 1047 """Get absolute value of const object.""" 1048 return abs(obj) 1049 1050 1051def get_const_round(obj): 1052 """Get round value of const object.""" 1053 if isinstance(obj, tuple): 1054 val = obj[0] 1055 point_num = obj[1] 1056 return round(val, point_num) 1057 return round(obj) 1058 1059 1060def get_const_len(obj): 1061 """Get the length of const object.""" 1062 return len(obj) 1063 1064 1065def get_method_info(obj): 1066 """Get the class name of the object from its method.""" 1067 if not (inspect.ismethod(obj) or 'built-in method' in repr(obj)): 1068 return None, None 1069 class_name_and_method_name = obj.__qualname__.split('.') 1070 return class_name_and_method_name[0], class_name_and_method_name[1] 1071 1072 1073def is_ms_tensor_method(obj): 1074 """Check if the obj is a method of MindSpore Tensor""" 1075 if not hasattr(obj, '__name__') or not hasattr(Tensor, obj.__name__): 1076 return False 1077 fn = inspect.unwrap(obj.__func__ if isinstance(obj, types.MethodType) else obj) 1078 return fn == getattr(Tensor, obj.__name__) 1079 1080 1081def can_constant_fold(obj): 1082 """Check if the obj is the function can be constantly folded.""" 1083 return obj in constant_fold_functions 1084 1085 1086class Parser: 1087 """ 1088 Parser python code to ast tree. 1089 1090 Args: 1091 fn(FunctionType/MethodType): Need parse object instance. 1092 parse_method(ExtendInfoOfParseObj): Extend information for parse the function. 1093 ast_cache: Dictionary for caching ast tree. 1094 """ 1095 ast_cache = {} 1096 1097 def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None: 1098 self.fn = inspect.unwrap(fn.__func__ if isinstance(fn, types.MethodType) else fn) 1099 self.parse_method = parse_method 1100 self.line_offset = 0 1101 self.filename: str = self.fn.__code__.co_filename 1102 1103 # Used to resolve the function's globals namespace. 1104 self.global_namespace = ModuleNamespace(self.fn.__module__) 1105 self.global_namespace.dicts[0]["__ms_tensor_func__"] = tensor 1106 1107 self.function_module = self.fn.__module__ 1108 # Used to resolve the function's nonlocals. 1109 self.closure_namespace = ClosureNamespace(self.fn) 1110 self.function_name = self.fn.__qualname__ 1111 self.lines = [] 1112 self.col_offset = 0 1113 1114 @staticmethod 1115 def is_unsupported_namespace(value): 1116 """To check if not supported for namespace""" 1117 unsupported = isinstance(value, _builtin_function_or_method_type) and value not in _convert_map() 1118 logger.debug(f"'{value}' unsupported: {unsupported}.") 1119 if unsupported and value in _fallback_unsupported_python_builtin_type: 1120 raise TypeError(f"'{value}' is not supported both in JIT Fallback and graph mode.") 1121 return unsupported 1122 1123 @staticmethod 1124 def is_unsupported_python_builtin_type(value): 1125 """To check if not supported for builtin type""" 1126 unsupported = value in _unsupported_python_builtin_type 1127 logger.debug(f"value: '{value}', unsupported builtin type: {unsupported}.") 1128 return unsupported 1129 1130 @staticmethod 1131 def get_tensor_class_type(value): 1132 """To check if is class Tensor type""" 1133 if value == Tensor: 1134 return CLASS_INSTANCE_TYPE_TENSOR 1135 if issubclass(value, ms_adapter_registry.tensor): 1136 return CLASS_INSTANCE_TYPE_ADAPTER_TENSOR 1137 return CLASS_INSTANCE_TYPE_INVALID 1138 1139 @staticmethod 1140 def get_adapter_convert_function(class_object): 1141 """Get convert function for adapter tensor""" 1142 class_object_name = class_object.__name__ 1143 if class_object_name in ms_adapter_registry.convert_adapter_tensor_map: 1144 return ms_adapter_registry.convert_adapter_tensor_map[class_object_name] 1145 return None 1146 1147 @staticmethod 1148 def is_unsupported_internal_type(value): 1149 """To check if not supported internal type, such as Tensor""" 1150 if not inspect.isclass(value): 1151 return False 1152 if value == Tensor: 1153 logger.debug(f"Found unsupported internal type: '{value}'.") 1154 return True 1155 if ms_adapter_registry.is_registered and issubclass(value, ms_adapter_registry.tensor): 1156 return True 1157 return False 1158 1159 @staticmethod 1160 def get_convert_object_for_mutable(value): 1161 """Get the convert object for value which don't support to be converted in C++.""" 1162 # The value may not be supported to do ConvertData such as api 'mutable', 1163 # and we get its converted object from python. 1164 if inspect.isfunction(value) and value in (mutable,): 1165 return _convert_map().get(value) 1166 return value 1167 1168 def get_syntax_support_type(self, value): 1169 """Get syntax support type.""" 1170 if is_from_third_party_library(value): 1171 logger.debug(f"value: '{value}' is from third party library.") 1172 return SYNTAX_UNSUPPORTED_NAMESPACE 1173 if inspect.isclass(value) or isinstance(value, _builtin_function_or_method_type): 1174 if self.is_unsupported_internal_type(value): 1175 return SYNTAX_UNSUPPORTED_INTERNAL_TYPE 1176 if self.is_unsupported_namespace(value): 1177 return SYNTAX_UNSUPPORTED_NAMESPACE 1178 if self.is_unsupported_python_builtin_type(value): 1179 return SYNTAX_UNSUPPORTED_EXTERNAL_TYPE 1180 return SYNTAX_SUPPORTED 1181 1182 def check_lambda(self, src): 1183 obj_type = get_obj_type(self.fn) 1184 if (obj_type != RESOLVE_TYPE_FUNCTION or src[:4] == "def ") and is_lambda_function(self.fn): 1185 logger.debug("fn is lambda: %r", self.fn) 1186 raise ValueError("An error occurred while parsing the positional information of the lambda expression. " 1187 "Please write the lambda expression on a separate line.\nFor example, " 1188 "the code 'def __init__(self, combine_fn=lambda x: x + 1):' rewritten as\n" 1189 "'def __init__(self, combine_fn=\nlambda x: x + 1\n):' will solve the problem.") 1190 1191 def parse(self): 1192 """Parse the function or method.""" 1193 logger.debug("fn: %r", self.fn) 1194 if isinstance(self.fn, (types.FunctionType, types.MethodType)) or \ 1195 type(self.fn).__name__ == 'cython_function_or_method': 1196 attr = 'source' 1197 try: 1198 source_lines = inspect.getsourcelines(self.fn) 1199 if context.get_context('support_binary') and \ 1200 '/mindspore/' not in self.filename and '\\mindspore\\' not in self.filename and \ 1201 (not hasattr(self.fn, attr) or getattr(self.fn, attr) != source_lines): 1202 if not os.access(self.filename, os.W_OK): 1203 raise PermissionError(f"Don't have the write permission on the file {self.filename}.") 1204 with open(self.filename, 'a') as f: 1205 f.write(f"\n# Set source attribute for function {self.function_name} " 1206 f"to support run so or pyc file in Graph Mode." 1207 f"\nsetattr({self.function_name}, '{attr}', {source_lines})\n") 1208 setattr(self.fn, attr, source_lines) 1209 except (OSError, TypeError) as e: 1210 if hasattr(self.fn, attr): 1211 source_lines = getattr(self.fn, attr) 1212 else: 1213 if e.__str__() == "could not get source code": 1214 raise OSError(f"Mindspore can not compile temporary source code in terminal. " 1215 f"Please write source code to a python file and run the file.") 1216 raise e 1217 self.lines, self.line_offset = source_lines 1218 original_src = ''.join(self.lines) 1219 hexstr = hashlib.sha256(original_src.encode()).hexdigest() 1220 ast_tokens_cache = Parser.ast_cache.get(hexstr) 1221 if not ast_tokens_cache: 1222 src = dedent(original_src) 1223 self.col_offset = \ 1224 len(original_src.split('\n')[0]) - len(src.split('\n')[0]) 1225 logger.debug("Get source: %s", src) 1226 self.check_lambda(src) 1227 try: 1228 ast_tokens = asttokens.ASTTokens(src, parse=True) 1229 except IndentationError as idt_err: 1230 idt_err.filename = self.filename 1231 idt_err.lineno = self.line_offset 1232 idt_err.msg = f"There are incorrect indentations in definition or comment of function: " \ 1233 f"'{self.function_name}'." 1234 raise idt_err 1235 ast_tokens_cache = (ast_tokens, self.col_offset) 1236 Parser.ast_cache[hexstr] = ast_tokens_cache 1237 else: 1238 self.col_offset = ast_tokens_cache[1] 1239 return ast_tokens_cache[0], ast_tokens_cache[0].tree 1240 1241 logger.error("Fn type is invalid") 1242 return None, None 1243 1244 def get_name_from_namespace(self, value): 1245 try: 1246 value_str = value.__name__ 1247 logger.debug( 1248 f"value: {type(value)}, '{value_str}', hasattr(__name__): {hasattr(value, '__name__')}.") 1249 except: 1250 value_str = str(value) 1251 logger.debug(f"value: {type(value)}, '{value_str}'.") 1252 return value_str 1253 1254 1255 def is_builtin_function_name(self, var): 1256 """Check if the var is builtin_function name.""" 1257 logger.debug(f"Check if the var'{var}' is builtin function.") 1258 builtin_function_names = vars(builtins).keys() 1259 if var in builtin_function_names: 1260 return True 1261 return False 1262 1263 1264 def get_namespace_symbol(self, var: str): 1265 """Get mindspore builtin namespace and symbol.""" 1266 if var in self.closure_namespace: 1267 logger.debug(f"Found '{var}' in closure_namespace {self.closure_namespace.__str__()}.") 1268 try: 1269 value = self.closure_namespace[var] 1270 return self.closure_namespace, var, value 1271 except UnboundLocalError: 1272 return self.closure_namespace, var, None 1273 if var in self.global_namespace: 1274 logger.debug(f"Found '{var}' in global_namespace {self.global_namespace.__str__()}.") 1275 value = self.global_namespace[var] 1276 self.get_name_from_namespace(value) 1277 # To check if allowed to support. 1278 value = self.get_convert_object_for_mutable(value) 1279 support_type = self.get_syntax_support_type(value) 1280 support_info = self.global_namespace, var, value, support_type 1281 return support_info 1282 1283 logger.debug(f"The name '{var}' is an undefined symbol.") 1284 return None, None, None 1285 1286 def check_third_party_library_side_effect(self, var, attr): 1287 """Check if value is from a third-party library.""" 1288 logger.debug(f"var '{var}'.") 1289 logger.debug(f"attr '{attr}'.") 1290 side_effect_attrs = { 1291 "numpy": {"load", "save", "savez", "savez_compressed", "loadtxt", "savetxt", "genfromtxt", "fromregex", 1292 "fromstring", "tofile", "memmap", "open_memmap", "open", "exists", "abspath", "DataSource", 1293 "format"}, 1294 "pandas": {"read_csv", "to_csv", "read_excel", "to_excel", "read_json", "to_json", "read_html", "to_html", 1295 "read_sql", "to_sql", "read_feather", "to_feather", "read_parquet", "to_parquet", "read_pickle", 1296 "to_pickle"}, 1297 "scipy": {"loadmat", "savemat"}, 1298 "csv": {"reader", "writer"}, 1299 "json": {"load", "loads", "dump", "dumps"}, 1300 "pickle": {"load", "loads", "dump", "dumps"}, 1301 "h5py": {"File", "Group", "Dataset"}, 1302 "os": {"listdir", "isfile", "exists", "isdir", "mkdir", "remove", "rmdir", "symlink", "rename"}, 1303 "shutil": {"copy", "copy2", "copytree", "move", "rmtree"}, 1304 "pathlib": {"Path", "mkdir", "rmdir", "unlink", "rename", "symlink_to"}, 1305 "glob": {"glob", "iglob"}, 1306 "zipfile": {"zipfile", "ZipFile", "write", "extractall"}, 1307 "troubleshooter": {"save", "load"}} 1308 if var in self.global_namespace: 1309 logger.debug(f"Found '{var}' in global_namespace {self.global_namespace.__str__()}.") 1310 value = self.global_namespace[var] 1311 value_str = self.get_name_from_namespace(value) 1312 value = self.get_convert_object_for_mutable(value) 1313 if is_from_third_party_library(value): 1314 logger.debug(f"value: '{value}' is from third party library.") 1315 # pylint: disable=get-dict-value-exception 1316 if value_str in side_effect_attrs and attr in side_effect_attrs[value_str]: 1317 return True 1318 return False 1319 1320 def analyze_super(self, class_type_node, subclass_instance): 1321 """Analyze super and return a class instance.""" 1322 sub_class = type(subclass_instance) 1323 if class_type_node is None: 1324 return super(sub_class, subclass_instance) 1325 if isinstance(class_type_node, ast.Name): 1326 class_name = getattr(class_type_node, 'id') 1327 elif isinstance(class_type_node, ast.Attribute): 1328 class_name = getattr(class_type_node, 'attr') 1329 else: 1330 raise ValueError(f"The first argument of 'super()' must be a class type, " 1331 f"but got {class_type_node.__class__.__name__}.") 1332 1333 target_father_class = None 1334 for class_element in sub_class.mro(): 1335 if class_element.__name__ == class_name: 1336 target_father_class = class_element 1337 break 1338 if target_father_class is None: 1339 raise ValueError(f"The second argument of 'super()' must be 'self', " 1340 f"but got {subclass_instance}.") 1341 return super(target_father_class, subclass_instance) 1342 1343 def get_jit_comments(self, start_lineno, end_lineno): 1344 """ 1345 Get the comments at the location, starting with '# @jit'. 1346 1347 Args: 1348 start_lineno: The start line no. 1349 end_lineno: The end line no. 1350 1351 Returns: 1352 list[str], the comment strings. 1353 """ 1354 comments = [] 1355 # Ignore if to fetch the whole lines's comments. 1356 if start_lineno == 1 and end_lineno == len(self.lines): 1357 return comments 1358 1359 # Add previous line comment. 1360 if start_lineno > 1: 1361 previous_lineno = start_lineno - 1 1362 previous_line = self.lines[previous_lineno - 1] 1363 striped_previous_line = previous_line.strip(' \t') 1364 result = re.search(r'^#\s*@jit[^\'\"]*?(?=\n|$)', striped_previous_line) 1365 if result: 1366 comments.append(result.group()) 1367 1368 # Add line ending comments. 1369 if start_lineno >= 1: 1370 while start_lineno <= end_lineno: 1371 line = self.lines[start_lineno - 1] 1372 result = re.search(r'#\s*@jit[^\'\"]*?(?=\n|$)', line) 1373 if result: 1374 comments.append(result.group()) 1375 start_lineno += 1 1376 return comments 1377 1378 def get_source_code(self, start_lineno, start_colno, end_lineno, end_colno): 1379 """ 1380 Get the script source at the location. 1381 1382 Args: 1383 start_lineno: The start line no. 1384 start_colno: The start column no. 1385 end_lineno: The end line no. 1386 end_colno: The end column no. 1387 1388 Returns: 1389 str, the source string. 1390 """ 1391 1392 if start_lineno == 0: 1393 logger.critical('start_lineno should not be 0') 1394 1395 first_line = self.lines[start_lineno - 1] 1396 if start_lineno == end_lineno: 1397 src = first_line[self.col_offset + start_colno:self.col_offset + end_colno] 1398 return src 1399 1400 src = first_line[self.col_offset + start_colno:] 1401 while start_lineno < end_lineno - 1: 1402 src += self.lines[start_lineno] 1403 start_lineno += 1 1404 last_line = self.lines[end_lineno - 1] 1405 src += last_line[:self.col_offset + end_colno] 1406 return src 1407 1408 def get_location(self, node): 1409 """ 1410 Get location of node start and end line no. 1411 1412 Args: 1413 node: AST op node or tuple or List. This is a node in the ANF diagram, 1414 here is the code location to get this node. 1415 1416 Returns: 1417 List, [fileName, linestart, colstart, lineend, colend]. 1418 """ 1419 res = [self.filename] 1420 err_exit = 0 1421 if isinstance(node, (list, tuple)): 1422 node_size = len(node) 1423 if node_size == 0: 1424 err_exit = 1 1425 else: 1426 start_node = node[0] 1427 end_node = node[-1] 1428 else: 1429 start_node = node 1430 end_node = node 1431 1432 if err_exit == 0: 1433 if hasattr(start_node, "first_token") and \ 1434 hasattr(end_node, "last_token"): 1435 start_lineno, start_colno = start_node.first_token.start 1436 end_lineno, end_colno = end_node.last_token.end 1437 expr_src = self.get_source_code(start_lineno, start_colno, end_lineno, end_colno) 1438 comments = self.get_jit_comments(start_lineno, end_lineno) 1439 start_lineno += self.line_offset - 1 1440 start_colno += self.col_offset 1441 end_lineno += self.line_offset - 1 1442 end_colno += self.col_offset 1443 res = res + [start_lineno, start_colno, end_lineno, end_colno, expr_src, comments] 1444 else: 1445 res = res + [0, 0, 0, 0, '', []] 1446 return res 1447