# mypy: allow-untyped-defs
"""Dispatcher for AtenLib functions from onnx-script."""

from __future__ import annotations

import logging
import operator
import types
from typing import Any, Callable, Sequence, TYPE_CHECKING

import torch
import torch._ops
import torch.fx
from torch.onnx._internal.fx import (
    diagnostics,
    registration,
    type_utils as fx_type_utils,
)


if TYPE_CHECKING:
    import onnxscript  # type: ignore[import]
    from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
        graph_building as onnxscript_graph_building,
    )

    from torch.onnx import OnnxRegistry


def _find_opschema_matched_symbolic_function_diagnostic_message_formatter(
    fn: Callable,
    self,
    node: torch.fx.Node,
    default_and_custom_functions: list[registration.ONNXFunction],
    *args,
    **kwargs,
) -> str:
    """Format the diagnostic message for the nearest match warning."""
    all_function_overload_names = ""
    for symbolic_func in default_and_custom_functions:
        overload_func = symbolic_func.onnx_function
        all_function_overload_names += f"ONNX Node: {overload_func.name}[opset={overload_func.opset};is_custom={symbolic_func.is_custom}]. \n"  # noqa: B950
    return f"FX Node: {node.target}. \n{all_function_overload_names}"


def _find_operator_overloads_in_onnx_registry_diagnostic_message_formatter(
    fn: Callable,
    self,
    node: torch.fx.Node,
    *args,
    **kwargs,
) -> str:
    """Format the diagnostic message for the operator overload search."""
    return f"Searching operator overload: '{node.target}' in onnx registry...\n"


class OnnxFunctionDispatcher:
    """A dispatcher that finds the best ONNX Function for ATen/Custom operators.

    It uses the `torch.ops` name to find the function. If the exact overload is not
    found, it falls back to the default overload. The best match is then selected
    among all function overloads; an exact match takes precedence over the nearest
    ones.

    Below is a breakdown of how the dispatch mechanism works:

    1. Use the torch.ops name to find the function:
        a. Check if the ATen overload exists in the registry.
        b. If not, check if the default overload exists in the registry.

    2. Find the nearest match among all overloaded functions:
        a. If the types match perfectly, select the function.
        b. Otherwise, select the nearest one with the highest matching score.
           Because dtypes may be annotated incorrectly and attributes may match
           loosely, we use the nearest match to find the best function once the
           aten name is targeted.

    3. Tie-breaker: If there are multiple nearest matches with the same score,
       prefer custom functions over default ones, and among custom ones, the one
       registered last.

    NOTE: The nearest match `doesn't guarantee` a correct match, and a warning
    message is logged.
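
    A minimal usage sketch (hedged; assumes an already-populated registry, an
    inflight diagnostic context, and an FX ``call_function`` node — the node and
    arguments shown are illustrative):

    ```python
    dispatcher = OnnxFunctionDispatcher(onnx_registry, diagnostic_context)
    onnx_function = dispatcher.dispatch(
        node=fx_node,  # e.g. an aten::add.Tensor call_function node
        onnx_args=(tensor_x, tensor_y),
        onnx_kwargs={"alpha": 1.0},
        diagnostic_context=diagnostic_context,
    )
    ```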
    """

    def __init__(
        self,
        onnx_registry: OnnxRegistry,
        diagnostic_context: diagnostics.DiagnosticContext,
    ):
        """Initialize the ONNX Function dispatcher.

        Args:
            onnx_registry: The ONNX registry.
            diagnostic_context: The diagnostic context to use for reporting errors.
        """
        self.onnx_registry = onnx_registry
        self.diagnostic_context = diagnostic_context

    def dispatch(
        self,
        node: torch.fx.Node,
        onnx_args: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        onnx_kwargs: dict[str, fx_type_utils.Argument],
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction:
        """Dispatches an ONNX function based on the given FX node, arguments, and keyword arguments.

        Args:
            node: The TorchFX node to dispatch the function for.
            onnx_args: The arguments of the ONNX function.
            onnx_kwargs: The keyword arguments of the ONNX function.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm.

        Raises:
            RuntimeError: If there are no overloaded functions available for the given FX node.
        """
        # If there are no overloaded functions available for the given FX node,
        # raise an unsupported error.
        default_and_custom_functions = self.get_function_overloads(
            node, diagnostic_context
        )

        # If there are overloaded functions available, find the one that perfectly
        # matches, or most nearly matches, the given arguments and keyword arguments.
        return self._find_the_perfect_or_nearest_match_onnxfunction(
            node,
            default_and_custom_functions,
            onnx_args,
            onnx_kwargs,
            diagnostic_context,
        )

    def _filter_or_keep_complex(
        self,
        node: torch.fx.Node,
        default_and_custom_functions: list[registration.ONNXFunction],
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> list[registration.ONNXFunction]:
        """Keep only the complex functions if the input has a complex dtype; otherwise filter them out."""

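        # Illustrative sketch (hypothetical op): for a node whose input is a
        # complex64 tensor, only the overloads registered with is_complex=True
        # are kept, since those operate on the real-pair representation of
        # complex tensors; for real inputs, the complex overloads are dropped.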
        args_with_complex_dtype = [_is_arg_with_complex_dtype(arg) for arg in node.args]
        if any(args_with_complex_dtype):
            default_and_custom_functions = [
                func for func in default_and_custom_functions if func.is_complex
            ]
            # If we can't find the complex function group, raise error.
            if not default_and_custom_functions:
                op_full_name = self._get_aten_name(
                    node, diagnostic_context
                ).qualified_name()
                diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
                    diagnostics.rules.no_symbolic_function_for_call_function,
                    diagnostics.levels.ERROR,
                    f"Cannot find any COMPLEX symbolic function for {op_full_name}, "
                    f"which should be registered under {node.target}.",
                    unsupported_fx_node=node,
                )
                diagnostic_context.log(diagnostic)
                raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
        else:
            default_and_custom_functions = [
                func for func in default_and_custom_functions if not func.is_complex
            ]
            # If we can't find the non-complex function group, raise error.
            if not default_and_custom_functions:
                op_full_name = self._get_aten_name(
                    node, diagnostic_context
                ).qualified_name()
                diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
                    diagnostics.rules.no_symbolic_function_for_call_function,
                    diagnostics.levels.ERROR,
                    f"Can ONLY find COMPLEX symbolic function for {op_full_name}, "
                    f"which should be registered under {node.target}.",
                    unsupported_fx_node=node,
                )
                diagnostic_context.log(diagnostic)
                raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
        return default_and_custom_functions

    @diagnostics.diagnose_call(
        diagnostics.rules.find_opschema_matched_symbolic_function,
        diagnostic_message_formatter=_find_opschema_matched_symbolic_function_diagnostic_message_formatter,
    )
    def _find_the_perfect_or_nearest_match_onnxfunction(
        self,
        node: torch.fx.Node,  # this is used in diagnostic_message_formatter
        default_and_custom_functions: list[registration.ONNXFunction],
        onnx_args: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        onnx_kwargs: dict[str, fx_type_utils.Argument],
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction:
        """Find the perfectly or most nearly matched OnnxFunction for the given FX node, arguments, and keyword arguments.

        Args:
            node: The TorchFX node to dispatch the function for.
            default_and_custom_functions: The list of overloaded functions, with
                custom ones appearing after the default ones.
            onnx_args: Arguments organized in the PyTorch inputs way.
            onnx_kwargs: Keyword arguments organized in the PyTorch inputs way.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm.

        Raises:
            RuntimeError: If there are no overloaded functions available for the given FX node.
        """
        overload_match_ranking: dict[registration.ONNXFunction, int | None] = {}
        diagnostic = diagnostic_context.inflight_diagnostic()

        # Iterate the overloaded functions in reverse order to prioritize the
        # custom ones over the default ones, and find the perfect match.
        for symbolic_function in reversed(default_and_custom_functions):
            function_opschema = _OnnxSchemaChecker(symbolic_function.onnx_function)

            # NOTE: 1. If the perfect match is found, return the function.
            if function_opschema.perfect_match_inputs(
                diagnostic, onnx_args, onnx_kwargs
            ):
                return symbolic_function.onnx_function
            # Record the match score for the nearest match if it's not the perfect match.
            overload_match_ranking[symbolic_function] = function_opschema.match_score

        # NOTE: 2. If there is no perfect match, find the nearest match among the
        # nearest match candidates. If there is no nearest match either, raise an error.
        overload_match_ranking = {
            k: v for k, v in overload_match_ranking.items() if v is not None
        }
        if not overload_match_ranking:
            # If there are no overloaded functions available for the given FX node,
            # raise an unsupported error.
            op_full_name = self._get_aten_name(
                node, diagnostic_context
            ).qualified_name()
            diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
                diagnostics.rules.no_symbolic_function_for_call_function,
                diagnostics.levels.ERROR,
                f"Cannot find any perfect/nearest match of symbolic function for {op_full_name}, "
                f"which should be registered under {node.target}.",
                unsupported_fx_node=node,
            )
            diagnostic_context.log(diagnostic)
            raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)

        diagnostic.warning(
            "### Exact match is not found!\n"
            "Cannot find a perfect match of symbolic overload; "
            "a nearest match is found instead. Please check the ONNX output carefully. \n",
        )
        diagnostic.level = diagnostics.levels.WARNING
        # NOTE: 3. Tie-breaker: if there are multiple nearest matches, choose the
        # custom one first. If there are multiple custom ones, choose the one
        # added last to the list.
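        # Worked example (hypothetical scores): for candidates
        #   default_f (score=2, index=0), custom_g (score=2, index=1),
        #   custom_h (score=1, index=2),
        # the sort keys below are (2, False, 0), (2, True, 1), and (1, True, 2).
        # With reverse=True, custom_g wins: it ties with default_f on score,
        # but is_custom=True breaks the tie in its favor.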
        symbolic_function_list: list[registration.ONNXFunction] = sorted(
            overload_match_ranking,
            key=lambda k: (
                overload_match_ranking[k],
                k.is_custom,
                default_and_custom_functions.index(k),
            ),
            reverse=True,
        )
        return symbolic_function_list[0].onnx_function

    def _get_aten_name(
        self, node: torch.fx.Node, diagnostic_context: diagnostics.DiagnosticContext
    ) -> registration.OpName:
        """Get the OpName from the target.

        Args:
            node: The TorchFX node to get the aten name for.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            The internal op name as a registration.OpName dataclass.
        """
        if node.target == operator.getitem:
            return registration.OpName.from_name_parts(
                namespace="aten", op_name="getitem"
            )
        if isinstance(node.target, torch._ops.OpOverloadPacket):
            # aten::sym_size is the only OverloadPacket that we support.
            # schema: aten::sym_size(Tensor self, int dim) -> Tensor
            if node.target != torch.ops.aten.sym_size:
                diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
                    diagnostics.rules.no_symbolic_function_for_call_function,
                    diagnostics.levels.ERROR,
                    f"Unsupported OverloadPacket: {node.target}, aten.sym_size is the only allowed OverloadPacket!",
                    unsupported_fx_node=node,
                )
                diagnostic_context.log(diagnostic)
                raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
            # TODO(titaiwang): aten::sym_size has an overload, but the fx graph is
            # using the overloadpacket for some reason.
            # https://github.com/pytorch/pytorch/issues/97201
            aten_op_default = node.target.default
            return registration.OpName.from_op_overload(op_overload=aten_op_default)  # type: ignore[no-any-return]

        if isinstance(node.target, types.BuiltinFunctionType):
            # Make sure it's a builtin op that consumes SymInt/SymFloat arguments.
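            # Illustrative: operator.add(symint_node, 1) is accepted here, while
            # operator.add(tensor_node, 1) is rejected below, since tensor math
            # must come through an aten op rather than a Python builtin.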
            for node_arg in node.args:
                if (not isinstance(node_arg, (torch.fx.Node, int, float))) or (
                    isinstance(node_arg, torch.fx.Node)
                    and not fx_type_utils.is_torch_symbolic_type(node_arg.meta["val"])
                ):
                    diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
                        diagnostics.rules.no_symbolic_function_for_call_function,
                        diagnostics.levels.ERROR,
                        f"Unsupported node arg: {node_arg} (type {type(node_arg)}) with builtin function: {node.target},"
                        " only int/float/SymInt/SymFloat is supported with built-in ops!",
                        unsupported_fx_node=node,
                    )
                    diagnostic_context.log(diagnostic)
                    raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
            return registration.OpName.from_builtin_function(node.target)

        if isinstance(node.target, torch._ops.OpOverload):
            return registration.OpName.from_op_overload(op_overload=node.target)

        # Unexpected target, raise error.
        diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
            diagnostics.rules.no_symbolic_function_for_call_function,
            diagnostics.levels.ERROR,
            f"Unknown call_function target: {node.target}",
            unsupported_fx_node=node,
        )
        diagnostic_context.log(diagnostic)
        raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)

    @diagnostics.diagnose_call(
        diagnostics.rules.find_operator_overloads_in_onnx_registry,
        diagnostic_message_formatter=_find_operator_overloads_in_onnx_registry_diagnostic_message_formatter,
    )
    def get_function_overloads(
        self,
        node: torch.fx.Node,
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> list[registration.ONNXFunction]:
        """Get the function overloads from the registry.

        Args:
            node: The node to get the function overloads for.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            A list of ONNXFunctions, starting with the default ones and followed
            by any custom ones.
        """

        internal_opname: registration.OpName = self._get_aten_name(
            node=node, diagnostic_context=diagnostic_context
        )

        # If the ATen/Custom operators are not registered, the group will be None,
        # and non-registered ATen/Custom operators will trigger an error in the
        # next step.
        function_group: list[registration.ONNXFunction] | None = None

        function_group = self.onnx_registry.get_op_functions(
            namespace=internal_opname.namespace,
            op_name=internal_opname.op_name,
            overload=internal_opname.overload,
        )

        # NOTE: Fall back to the default overload if the ONNX registry doesn't
        # have the exact overload.
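        # Illustrative: if aten::add.Tensor has no registry entry of its own,
        # look up the default aten::add entry (overload=None) instead and warn
        # that the default overload is being used in its place.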
        if function_group is None:
            function_group = self.onnx_registry.get_op_functions(
                namespace=internal_opname.namespace,
                op_name=internal_opname.op_name,
                overload=None,
            )
            if function_group is not None:
                diagnostic = diagnostic_context.inflight_diagnostic()
                diagnostic.warning(
                    "### The operator overload is not found in onnx registry!\n"
                    "Cannot find the operator overload in onnx registry, but "
                    "the default overload is found. Please check the ONNX output carefully. \n",
                )
                diagnostic.level = diagnostics.levels.WARNING

        if function_group is not None:
            # NOTE: If the input has a complex dtype, we will only dispatch to the
            # complex functions.
            function_group = self._filter_or_keep_complex(
                node, function_group, diagnostic_context
            )
            return function_group  # type: ignore[return-value]

        op_full_name = internal_opname.qualified_name()
        diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
            diagnostics.rules.no_symbolic_function_for_call_function,
            diagnostics.levels.ERROR,
            f"Cannot find symbolic function for {op_full_name}, "
            f"which should be registered under {node.target}.",
            unsupported_fx_node=node,
        )
        diagnostic_context.log(diagnostic)
        raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)


class _OnnxSchemaChecker:
    """A checker for ONNX OpSchema and param schema.

    It provides methods to check input compatibility based on the OpSchema. It
    also provides a matching score to indicate how well the OpSchema matches the
    input and kwargs types. A function is evaluated as a perfect match, a nearest
    match candidate, or no match.

    Here are some common examples, by category:

    1. [NOTE: Perfect match]: The number of inputs and attributes is exactly the
       same as the OpSchema, and the types of inputs and attributes are exactly
       the same as the OpSchema.

        ```python
        inputs = (Tensor[2, 3], Tensor[2, 3])
        attributes = {"alpha": 1.0}


        @torch_op("aten::op")
        def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ...
        ```
        Result: Perfect match.

    2. [NOTE: Optional input]: The dispatcher recognizes optional inputs. However,
       the input can't be ignored; None must be provided.

        ```python
        inputs = (Tensor([2, 3]), None)
        attributes = {}

        aten_op(X: TTensor, Y: Optional[INT64]):
            ...
        ```
        Result: Perfect match.
        Real example: `aten::convolution`.

    3. [NOTE: Different attributes]: If an attribute is provided with a value, it
       must match an attribute in the function signature.

        ```python
        inputs = (Tensor([2, 3]),)
        attributes = {"a": 1, "b": 2}

        aten_op(X: TTensor, a: int):
            ...
        ```
        Result: No match.
        Real example: `aten::div` vs `aten::div.Tensor_mode`.

    4. [NOTE: Default attributes]: A default attribute fills in its value for
       missing inputs/attributes.

        ```python
        inputs = (Tensor([2, 3]),)
        attributes = {}

        aten_op(X: TTensor, a: int = 3):
            ...
        ```
        Result: Perfect match.
        Real example: `aten::clone`.

    5. [NOTE: Ignore attribute with None value]: Attributes with a None value are
       ignored in matching.

        ```python
        inputs = (Tensor([2, 3]),)
        attributes = {"a": None}

        aten_op(X: TTensor):
            ...
        ```
        Result: Perfect match.

        ```python
        inputs = (Tensor([2, 3]),)
        attributes = {"a": None}

        aten_op(X: TTensor, a: int = 3):
            ...
        ```
        Result: Nearest match eligible.
        Real example: `aten::div` vs `aten::div.Tensor_mode`.

    Attributes:
        onnxfunction: The OnnxFunction.
        param_schema: The parameter schemas defined in the OnnxFunction.
        op_schema: The ONNX OpSchema.
        type_constraints: The type constraints defined in the OpSchema.
        attributes: The attributes defined in the OpSchema.
        _matching_score: The matching score of the OnnxSchemaChecker.
    """

    def __init__(
        self,
        onnxfunction: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
    ):
        """Initialize the OnnxSchemaChecker.

        Args:
            onnxfunction: The OnnxFunction.
        """
        self.onnxfunction = onnxfunction
        self.param_schema = self.onnxfunction.param_schemas()
        op_schema = self.onnxfunction.op_schema
        # Both `OnnxFunction` and `TracedOnnxFunction` never return None for
        # `op_schema`, but their base class may, hence the Optional[OpSchema]
        # return type annotation.
        assert op_schema is not None
        self.op_schema = op_schema
        self.type_constraints = {
            # e.g. "T": {"tensor(int64)"}
            constraint.type_param_str: set(constraint.allowed_type_strs)
            for constraint in self.op_schema.type_constraints
        }
        self.attributes = self.op_schema.attributes
        self._matching_score: int | None = None

    @property
    def match_score(self) -> int | None:
        """The matching score of the OnnxSchemaChecker.

        If this remains None, the matching score has not been calculated and the
        function is not a nearest match candidate.

        Returns:
            The matching score of the OnnxSchemaChecker.
        """
        return self._matching_score

    def perfect_match_inputs(
        self,
        diagnostic: diagnostics.Diagnostic,
        args: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        kwargs: dict[str, fx_type_utils.Argument],
    ) -> bool:
        """Check if the inputs perfectly match the OpSchema requirements.

        The definition of a perfect match is that the input types are all in the
        type constraints and the number of inputs matches the number of inputs in
        the OpSchema.

        Checking steps:
        1. The function signature matches the number of inputs and the attribute names.
        2. The input/attribute types are all in the type constraints.

        A function must at least pass the first step to be eligible for nearest
        matching.

        Args:
            diagnostic: The diagnostic to use for logging detailed info.
            args: The input arguments organized in the PyTorch inputs way.
            kwargs: The input keyword arguments organized in the PyTorch inputs way.

        Returns:
            True if the inputs match the requirements, False otherwise.
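
        Example (hedged; the overload is illustrative):
            For ``aten_op(self: TReal, other: TReal, alpha: float = 1)`` called
            with two int64 tensors and ``kwargs={"alpha": 1.0}``, step 1 passes
            (two inputs, attribute names match), and step 2 passes if
            ``tensor(int64)`` is among the types allowed for ``TReal``, so the
            overload is a perfect match.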
        """

        # NOTE: OnnxFunction does not have the same function signature as the
        # original PyTorch operator. We need to separate the inputs/attributes
        # from the arguments.
        (
            function_inputs,
            function_attributes,
        ) = self._separate_input_attributes_from_arguments(
            self.param_schema,
            args,
            kwargs,
            fill_defaults=True,  # fill defaults for optional arguments to match
        )
        with diagnostic.log_section(logging.INFO, "Checking perfect match..."):
            diagnostic.info(
                "%s",
                diagnostics.LazyString(diagnostics.format_argument, self.onnxfunction),
            )
            # NOTE: 1. Check whether the input number and attribute names match
            # the OpSchema. If they don't, the function is eligible for neither a
            # perfect match nor a nearest match.
            # We use is_perfect_match to postpone the return to the end of the
            # function, so that all mismatch info gets logged.
            is_perfect_match = True
            if len(function_inputs) != len(self.op_schema.inputs):
                with diagnostic.log_section(
                    logging.INFO, "Failed: input number mismatch!"
                ):
                    diagnostic.info(
                        "Actual %d vs expected %d",
                        len(function_inputs),
                        len(self.op_schema.inputs),
                    )
                    diagnostic.info("The function is not a nearest match candidate.")
                is_perfect_match = False

            if set(function_attributes) != set(self.attributes):
                with diagnostic.log_section(
                    logging.INFO, "Failed: attribute mismatch!"
                ):
                    diagnostic.info(
                        "%s",
                        diagnostics.LazyString(
                            lambda: f"Actual {set(function_attributes)} vs expected {set(self.attributes)}",
                        ),
                    )
                    diagnostic.info("The function is not a nearest match candidate.")
                is_perfect_match = False

            # If it's already not a perfect match, return False directly. Further
            # checking is only for functions that are eligible for nearest match.
            if not is_perfect_match:
                return False

            # NOTE: 2. The dtypes of inputs and attributes should be in the type
            # constraints of the OpSchema. If they are not, the function cannot be
            # a perfect match, but it can still be a nearest match candidate.
            for schema_input, torch_input in zip(
                self.op_schema.inputs, function_inputs
            ):
                torch_input_compatible_types = _find_onnx_data_type(torch_input)
                allowed_types = self.type_constraints[schema_input.type_str]
                if not allowed_types.intersection(
                    torch_input_compatible_types
                ) and not any(
                    fx_type_utils.is_optional_onnx_dtype_str(onnx_type_str)
                    for onnx_type_str in allowed_types
                ):
                    # If torch_input_compatible_types isn't in the allowed_types
                    # of this input defined in the OpSchema, the function and the
                    # input are not compatible.
                    with diagnostic.log_section(
                        logging.INFO,
                        "Failed: input type mismatch for input '%s'!",
                        schema_input.name,
                    ):
                        diagnostic.info(
                            "Actual %s vs\nExpected %s",
                            torch_input_compatible_types,
                            allowed_types,
                        )
                    is_perfect_match = False

            for attribute_name, attribute in function_attributes.items():
                if not self._match_onnx_attribute_type(attribute_name, attribute):
                    # If the attribute type of the OpSchema and the attribute type
                    # don't match, the function and the input are not compatible.
                    with diagnostic.log_section(
                        logging.INFO,
                        "Failed: attribute '%s' type mismatch!",
                        attribute_name,
                    ):
                        diagnostic.info(
                            "Actual %s vs\nExpected %s",
                            type(attribute),
                            self.attributes[attribute_name].type,
                        )
                    is_perfect_match = False

            # NOTE: This is still a candidate for nearest match, as it only
            # mismatches inputs/attributes on dtype.
            self._record_matching_score(function_inputs, function_attributes)
            diagnostic.info("match score: %d", self.match_score)
            return is_perfect_match

    def _match_onnx_attribute_type(
        self,
        attribute_name: str,
        attribute: fx_type_utils.Argument | onnxscript_graph_building.TorchScriptTensor,
        is_sequence: bool = False,
    ) -> bool:
        if isinstance(attribute, (int, float, bool, str)):
            attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type(
                type(attribute), is_sequence=is_sequence
            )
            if attribute_onnx_type != self.attributes[attribute_name].type:
                return False
        # If the attribute is an empty list, we don't know the type of the list
        # items, so it's a mismatch.
        elif isinstance(attribute, (list, tuple)) and attribute:
            return self._match_onnx_attribute_type(
                attribute_name, attribute[0], is_sequence=True
            )
        else:
            # NOTE: Unrecognized attribute type.
            return False
        return True

    def _record_matching_score(
        self,
        inputs: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        attributes: dict[str, fx_type_utils.Argument],
    ):
        """Calculate the match score against the OpSchema requirements to find the nearest match.

        Only functions with the same number of inputs and attributes as the
        OpSchema are eligible to be nearest match candidates, so the lengths of
        inputs and attributes are not checked here; only their types are.

        How the matching score is calculated:
            score += 1 for each input/attribute type that is in the type constraints.
            score -= 1 for each attribute whose ONNX type mismatches the OpSchema.

        Limitations:
            None/NoneType/[] can result in zero matches, and in identical scores
            across overloads, which will be recorded in SARIF.

        Args:
            inputs: The input arguments.
            attributes: The input keyword arguments.
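
        Worked example (hypothetical): with two inputs whose types are both in
        the constraints (+1 each) and one attribute whose ONNX type mismatches
        the OpSchema (-1), the recorded score is 2 - 1 = 1.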
        """
        self._matching_score = 0
        # Functions whose argument count differs from the OpSchema naturally
        # score lower than those with the matching count, since zip stops at the
        # shorter sequence.
        for schema_input, torch_input in zip(self.op_schema.inputs, inputs):
            torch_input_compatible_types = _find_onnx_data_type(torch_input)
            allowed_types = self.type_constraints[schema_input.type_str]
            if allowed_types.intersection(torch_input_compatible_types):
                # If torch_input_compatible_types is in the allowed_types of this
                # input defined in the OpSchema, the function and the input are
                # compatible.
                self._matching_score += 1
        # NOTE: A penalty is applied to functions whose attributes don't match.
        for attribute_name, attribute_proto in self.attributes.items():
            attribute = attributes[attribute_name]
            attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type(
                type(attribute)
            )
            if attribute_onnx_type != attribute_proto.type:
                # If the attribute type of the OpSchema and the attribute type
                # don't match, the function and the input are not compatible.
                self._matching_score -= 1

    # NOTE: Adapted from an onnxscript internal function. Importing that function
    # directly would make the code less robust, as it is not a public API.
    def _separate_input_attributes_from_arguments(
        self,
        param_schemas: Sequence[onnxscript.values.ParamSchema],
        args: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        kwargs: dict[str, fx_type_utils.Argument],
        fill_defaults: bool = True,
    ) -> tuple[list[Any], dict[str, Any]]:
        """Separate Python args and kwargs into ONNX inputs and attributes.

        Extra kwargs are ignored if their values are None. For example, if the
        OpSchema has an attribute "rounding_mode" and the caller provides
        "rounding_mode=None", the attribute "rounding_mode" will not be included
        in the returned attributes when the OnnxFunction signature doesn't have
        "rounding_mode" as an attribute.

        Args:
            param_schemas: The parameter schemas of an Op or an OnnxFunction.
            args: The Python positional arguments supplied by the caller.
            kwargs: The Python keyword arguments supplied by the caller.
            fill_defaults: Whether to fill in the default values for attributes.

        Returns:
            A tuple of two elements:
            - A list of ONNX inputs.
            - A dictionary of ONNX attribute names and values.

        Raises:
            TypeError: When there are unknown kwargs and extra kwargs are not allowed.
            TypeError: When a required input is not provided.
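
        Example (hedged; the schema is illustrative):
            For a schema with input ``X`` and attribute ``a`` (default ``3``),
            calling with ``args=(x,)`` and ``kwargs={"b": None}`` yields
            ``([x], {"a": 3})`` when ``fill_defaults=True``: ``a`` is filled from
            its default, and the None-valued extra kwarg ``b`` is dropped.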
        """
        # args, kwargs and param_schemas should all be in order;
        # the user may not specify all inputs or attributes.

        import onnx

        onnx_inputs: list[Any] = []
        onnx_attributes: dict[str, Any] = {}
        # NOTE: We need to copy kwargs because we will mutate it.
        copy_kwargs = kwargs.copy()
        for i, param in enumerate(param_schemas):
            if param.is_variadic_input:
                # Exhaust all remaining args
                onnx_inputs.extend(args[i:])
                args = []
                continue
            if i < len(args):
                if param.is_input:
                    onnx_inputs.append(args[i])
                else:
                    onnx_attributes[param.name] = args[i]
            elif param.name in copy_kwargs:
                if param.is_input:
                    # Move the input from kwargs to inputs
                    onnx_inputs.append(copy_kwargs[param.name])
                    copy_kwargs.pop(param.name)
                else:
                    onnx_attributes[param.name] = copy_kwargs[param.name]
            elif (
                param.is_attribute
                and self.attributes[param.name].default_value.type
                != onnx.AttributeProto.UNDEFINED  # type: ignore[attr-defined]
            ):
                # The user did not provide the attribute.
                if fill_defaults:
                    onnx_attributes[param.name] = param.default
            # Optional input that was not provided.
            elif param.is_input:
                if fill_defaults:
                    onnx_inputs.append(None)

        # NOTE: Pick up extra kwargs if they are not None. None is not expected
        # as an attribute value in torchlib.
        for k, v in copy_kwargs.items():
            if k not in onnx_attributes and v is not None:
                onnx_attributes[k] = v
        return onnx_inputs, onnx_attributes


def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool:
    """Check recursively whether the arg has a complex dtype."""
    if (
        isinstance(arg, torch.fx.Node)
        and "val" in arg.meta
        and isinstance(arg.meta["val"], torch.Tensor)
        and torch.is_complex(arg.meta["val"])
    ):
        return True
    elif isinstance(arg, list):
        # Check every item, not just the first one.
        return any(_is_arg_with_complex_dtype(item) for item in arg)
    return False


def _find_onnx_data_type(
    torch_input: fx_type_utils.TensorLike
    | str
    | int
    | float
    | bool
    | list
    | tuple
    | complex
    | None,
) -> set[str]:
    """Convert an input's torch-acceptable dtype to the compatible set of onnx dtype strings."""
    if (
        isinstance(torch_input, fx_type_utils.TensorLike)
        and torch_input.dtype is not None
    ):
        return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(torch_input.dtype)
    if isinstance(torch_input, (int, float, bool, str, complex)):
        return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(type(torch_input))
    if isinstance(torch_input, (list, tuple)) and torch_input:  # [Tensor, Tensor]
        the_first_non_none_item = next(
            (item for item in torch_input if item is not None), None
        )
        set_dtype = _find_onnx_data_type(the_first_non_none_item)
        if any(isinstance(item, fx_type_utils.TensorLike) for item in torch_input):
            # NOTE: Any tensor involved in a list makes it a seq(tensor(onnx_type)).
            return {f"seq({dtype})" for dtype in set_dtype}
        else:
            # A constant list of non-tensor types.
            return set_dtype
    if (
        torch_input is None
        or (
            isinstance(torch_input, fx_type_utils.TensorLike)
            and torch_input.dtype is None
        )
        or (isinstance(torch_input, (list, tuple)) and not torch_input)
    ):
        # NOTE: None, an unknown dtype, and an empty list are edge cases; we allow
        # them to match any type to relax the type check. seq(tensor) also lands
        # here, as it is not supported in torchscript and would be None in this case.
        return set()

    raise RuntimeError(f"Unknown input type from input: {torch_input}")