1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import logging 8from typing import cast, List, Optional 9 10import torch 11from executorch.backends.xnnpack.partition.config.xnnpack_config import ( 12 ConfigPrecisionType, 13 XNNPartitionerConfig, 14) 15from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant 16from executorch.backends.xnnpack.utils.utils import get_input_node 17from executorch.exir.backend.canonical_partitioners.config_partitioner import ( 18 format_target_name, 19) 20from executorch.exir.backend.utils import WhyNoPartition 21from torch.export import ExportedProgram 22 23logger = logging.getLogger(__name__) 24why = WhyNoPartition(logger=logger) 25 26 27class GenericNodePartitionerConfig(XNNPartitionerConfig): 28 def __init__(self, fused_act: Optional[List[str]] = None, **kwargs): 29 """ 30 fused_act is a list of node target names that can be fused with this 31 node under quantization 32 """ 33 self.fused_acts = fused_act or [] 34 super().__init__(**kwargs) 35 36 def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: 37 return self.check_common_constraints(node, ep) 38 39 def get_node_and_deps( 40 self, node: torch.fx.Node, ep: ExportedProgram 41 ) -> List[torch.fx.Node]: 42 deps = [node] 43 quantized_deps = [] 44 if ConfigPrecisionType.STATIC_QUANT in self.enabled_precision_types: 45 # try to partition dequant inputs and quant outputs if static quant is enabled 46 if [(is_dequant(dq_input)) for dq_input in node.all_input_nodes].count( 47 False 48 ): 49 # if not all inputs are dequant nodes then it isn't quantized 50 return deps 51 52 quantized_deps.extend(node.all_input_nodes) 53 54 # check if quantized pattern has fused activation 55 if len(node.users) != 1: 56 return deps 57 58 node_output = list(node.users)[0] 59 if ( 60 node_output.op == "call_function" 61 and format_target_name(node_output.target.__name__) in self.fused_acts 62 ): 63 quantized_deps.append(node_output) 64 fused_out_users = list(node_output.users.keys()) 65 if len(fused_out_users) == 1: 66 node_output = fused_out_users[0] 67 68 if not is_quant(node_output): 69 # Expected node --> fused_act (optional) --> dequant 70 return deps 71 72 quantized_deps.append(node_output) 73 74 return deps + quantized_deps 75 76 77class QuantizedPerTensorConfig(GenericNodePartitionerConfig): 78 target_name = "quantize_per_tensor.default" 79 80 def supported_precision_types(self) -> List[ConfigPrecisionType]: 81 return [ConfigPrecisionType.STATIC_QUANT] 82 83 84class DeQuantizedPerTensorConfig(GenericNodePartitionerConfig): 85 target_name = "dequantize_per_tensor.default" 86 87 def supported_precision_types(self) -> List[ConfigPrecisionType]: 88 return [ConfigPrecisionType.STATIC_QUANT] 89 90 91class HardtanhConfig(GenericNodePartitionerConfig): 92 target_name = "hardtanh.default" 93 94 def supported_precision_types(self) -> List[ConfigPrecisionType]: 95 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 96 97 98class AddConfig(GenericNodePartitionerConfig): 99 target_name = "add.Tensor" 100 101 def __init__(self, **kwargs): 102 super().__init__(fused_act=["relu.default"], **kwargs) 103 104 def supported_precision_types(self) -> List[ConfigPrecisionType]: 105 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 106 107 108class ReLUConfig(GenericNodePartitionerConfig): 109 target_name = "relu.default" 110 111 def supported_precision_types(self) -> List[ConfigPrecisionType]: 112 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 113 114 115class AbsConfig(GenericNodePartitionerConfig): 116 target_name = "abs.default" 117 118 def supported_precision_types(self) -> List[ConfigPrecisionType]: 119 return [ConfigPrecisionType.FP32] 120 121 122class AvgPoolingConfig(GenericNodePartitionerConfig): 123 target_name = "avg_pool2d.default" 124 125 def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: 126 """ 127 XNNPACK does not support ceil_mode = True and count_include_pad = True 128 Additionally, we only support divisor_override if divisor_override = pooling region 129 """ 130 if not self.check_common_constraints(node, ep): 131 return False 132 133 args = node.args 134 135 ceil_mode = False # default is False 136 if len(args) >= 5: 137 ceil_mode = cast(bool, args[4]) 138 139 count_include_pad = True # default is True 140 if len(args) >= 6: 141 count_include_pad = cast(bool, args[5]) 142 143 kernel_size = cast(List[int], args[1]) 144 pooling_region = kernel_size[0] * kernel_size[1] 145 divisor_override = pooling_region # Default divisor is pooling_region 146 if len(args) >= 7: 147 divisor_override = cast(int, args[6]) 148 149 if ceil_mode: 150 why(node, reason="ceil mode is not supported") 151 return False 152 153 if count_include_pad: 154 why( 155 node, 156 reason="zero-padding in the averaging calculation is not supported", 157 ) 158 return False 159 160 if divisor_override != pooling_region: 161 why(node, reason="divisor override is not supported") 162 return False 163 164 return True 165 166 def supported_precision_types(self) -> List[ConfigPrecisionType]: 167 return [ConfigPrecisionType.FP32] 168 169 170class CatConfig(GenericNodePartitionerConfig): 171 target_name = "cat.default" 172 173 def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: 174 """ 175 Only support concatenation of 2 - 4 tensors 176 """ 177 if not self.check_common_constraints(node, ep): 178 return False 179 180 num_tensors = len(node.all_input_nodes) 181 182 if not (num_tensors >= 2 and num_tensors <= 4): 183 why( 184 node, 185 reason=f"only support concatenation of 2 - 4 tensors, got {num_tensors} tensors", 186 ) 187 return False 188 189 return True 190 191 def supported_precision_types(self) -> List[ConfigPrecisionType]: 192 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 193 194 195class CeilConfig(GenericNodePartitionerConfig): 196 target_name = "ceil.default" 197 198 def supported_precision_types(self) -> List[ConfigPrecisionType]: 199 return [ConfigPrecisionType.FP32] 200 201 202class ClampConfig(GenericNodePartitionerConfig): 203 target_name = "clamp.default" 204 205 def supported_precision_types(self) -> List[ConfigPrecisionType]: 206 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 207 208 209class DivConfig(GenericNodePartitionerConfig): 210 target_name = "div.Tensor" 211 212 def supported_precision_types(self) -> List[ConfigPrecisionType]: 213 return [ConfigPrecisionType.FP32] 214 215 216class EluConfig(GenericNodePartitionerConfig): 217 target_name = "elu.default" 218 219 def supported_precision_types(self) -> List[ConfigPrecisionType]: 220 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 221 222 def get_original_aten(self) -> Optional[torch._ops.OpOverload]: 223 return torch.ops.aten.elu.default 224 225 226class SoftmaxConfig(GenericNodePartitionerConfig): 227 target_name = "_softmax.default" 228 229 def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: 230 """ 231 Check that dim is always the last dim 232 """ 233 if not self.check_common_constraints(node, ep): 234 return False 235 236 dim = cast(int, node.args[1]) 237 node_input = node.all_input_nodes[0] 238 tensor_dims = node_input.meta["val"].dim() 239 240 if not (dim == -1 or dim == tensor_dims - 1): 241 why( 242 node, 243 reason=f"dim must be the last dim, got dim = {dim} for tensor of rank {tensor_dims}", 244 ) 245 return False 246 return True 247 248 def supported_precision_types(self) -> List[ConfigPrecisionType]: 249 return [ConfigPrecisionType.FP32] 250 251 252class PermuteConfig(GenericNodePartitionerConfig): 253 target_name = "permute_copy.default" 254 255 def supported_precision_types(self) -> List[ConfigPrecisionType]: 256 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 257 258 259class SigmoidConfig(GenericNodePartitionerConfig): 260 target_name = "sigmoid.default" 261 262 def supported_precision_types(self) -> List[ConfigPrecisionType]: 263 return [ConfigPrecisionType.FP32] 264 265 266class MulConfig(GenericNodePartitionerConfig): 267 target_name = "mul.Tensor" 268 269 def supported_precision_types(self) -> List[ConfigPrecisionType]: 270 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 271 272 273class MaximumConfig(GenericNodePartitionerConfig): 274 target_name = "maximum.default" 275 276 def supported_precision_types(self) -> List[ConfigPrecisionType]: 277 return [ConfigPrecisionType.FP32] 278 279 280class MaxPool2dConfig(GenericNodePartitionerConfig): 281 target_name = "max_pool2d.default" 282 283 def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: 284 """ 285 XNNPACK's maxpool2d does not support ceil mode 286 """ 287 if not self.check_common_constraints(node, ep): 288 return False 289 290 is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5]) 291 if is_ceil_mode: 292 why(node, reason="ceil mode is not supported") 293 return False 294 return True 295 296 def supported_precision_types(self) -> List[ConfigPrecisionType]: 297 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 298 299 def get_original_aten(self) -> Optional[torch._ops.OpOverload]: 300 return torch.ops.aten.max_pool2d.default 301 302 303class UpsampleBilinear2dConfig(GenericNodePartitionerConfig): 304 target_name = "upsample_bilinear2d.vec" 305 306 def supported_precision_types(self) -> List[ConfigPrecisionType]: 307 return [ConfigPrecisionType.FP32] 308 309 def get_original_aten(self) -> Optional[torch._ops.OpOverload]: 310 return torch.ops.aten.upsample_bilinear2d.vec 311 312 313class FloorConfig(GenericNodePartitionerConfig): 314 target_name = "floor.default" 315 316 def supported_precision_types(self) -> List[ConfigPrecisionType]: 317 return [ConfigPrecisionType.FP32] 318 319 320class HardswishConfig(GenericNodePartitionerConfig): 321 target_name = "hardswish.default" 322 323 def supported_precision_types(self) -> List[ConfigPrecisionType]: 324 return [ConfigPrecisionType.FP32] 325 326 327class LeakyReLUConfig(GenericNodePartitionerConfig): 328 target_name = "leaky_relu.default" 329 330 def supported_precision_types(self) -> List[ConfigPrecisionType]: 331 return [ConfigPrecisionType.FP32] 332 333 334class MeanDimConfig(GenericNodePartitionerConfig): 335 target_name = "mean.dim" 336 337 def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: 338 """ 339 Mean Dim currently only supports averaging 4D tensors across the innermost 340 dimensions 341 """ 342 if not self.check_common_constraints(node, ep): 343 return False 344 345 dims = node.args[1] 346 output_dims = node.meta["val"].dim() 347 348 if dims not in ([-2, -1], [-1, -2]): 349 why( 350 node, 351 reason="mean.dim only supports averaging 4D tensors across the innermost dimensions", 352 ) 353 return False 354 355 if output_dims != 4: 356 why( 357 node, 358 reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {output_dims}", 359 ) 360 return False 361 return True 362 363 def supported_precision_types(self) -> List[ConfigPrecisionType]: 364 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 365 366 367class MinimumConfig(GenericNodePartitionerConfig): 368 target_name = "minimum.default" 369 370 def supported_precision_types(self) -> List[ConfigPrecisionType]: 371 return [ConfigPrecisionType.FP32] 372 373 374class NegConfig(GenericNodePartitionerConfig): 375 target_name = "neg.default" 376 377 def supported_precision_types(self) -> List[ConfigPrecisionType]: 378 return [ConfigPrecisionType.FP32] 379 380 381class PowConfig(GenericNodePartitionerConfig): 382 target_name = "pow.Tensor_Scalar" 383 384 def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: 385 """ 386 Only support powers of two 387 """ 388 if not self.check_common_constraints(node, ep): 389 return False 390 391 power = node.args[1] 392 393 if not isinstance(power, int): 394 why(node, reason=f"only support int powers, got {power}") 395 return False 396 397 if power != 2: 398 why(node, reason=f"only support power == 2, got {power}") 399 return False 400 return True 401 402 def supported_precision_types(self) -> List[ConfigPrecisionType]: 403 return [ConfigPrecisionType.FP32] 404 405 406class SliceCopyConfig(GenericNodePartitionerConfig): 407 target_name = "slice_copy.Tensor" 408 409 def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: 410 """ 411 Support slicing with stride = 1, no zero-dim tensors, Slice isn't supported 412 if the input or output is dynamic 413 """ 414 if not self.check_common_constraints(node, ep): 415 return False 416 417 stride = 1 418 if len(node.args) > 4: 419 stride = cast(int, node.args[4]) 420 421 if stride != 1: 422 return False 423 424 input_node = get_input_node(node, 0) 425 output_node = node 426 427 input_shape = list(input_node.meta["val"].shape) 428 output_shape = list(output_node.meta["val"].shape) 429 430 for dim in input_shape: 431 if not isinstance(dim, int) or dim == 0: 432 why( 433 node, 434 reason=f"input tensor has invalid shape, dim: {dim} of type {type(dim)}. Expecting non-zero, int values.", 435 ) 436 return False 437 438 for dim in output_shape: 439 if not isinstance(dim, int) or dim == 0: 440 why( 441 node, 442 reason=f"output tensor has invalid shape, dim: {dim} of type {type(dim)}. Expecting non-zero, int values.", 443 ) 444 return False 445 446 return True 447 448 def supported_precision_types(self) -> List[ConfigPrecisionType]: 449 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 450 451 452class SquareRootConfig(GenericNodePartitionerConfig): 453 target_name = "sqrt.default" 454 455 def supported_precision_types(self) -> List[ConfigPrecisionType]: 456 return [ConfigPrecisionType.FP32] 457 458 459class ConstantPadConfig(GenericNodePartitionerConfig): 460 target_name = "constant_pad_nd.default" 461 462 def supported_precision_types(self) -> List[ConfigPrecisionType]: 463 return [ConfigPrecisionType.FP32] 464 465 466class SubConfig(GenericNodePartitionerConfig): 467 target_name = "sub.Tensor" 468 469 def supported_precision_types(self) -> List[ConfigPrecisionType]: 470 return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] 471 472 473class BMMConfig(GenericNodePartitionerConfig): 474 """ 475 Despite being a GEMM Kernel, BMM Can be partitioned like a single node partitioner 476 because it does not perform any packing on the inputs being matrix multiplied 477 """ 478 479 target_name = "bmm.default" 480 481 def supported_precision_types(self) -> List[ConfigPrecisionType]: 482 return [ConfigPrecisionType.FP32] 483 484 485class SDPAConfig(GenericNodePartitionerConfig): 486 target_name = "scaled_dot_product_attention.default" 487 488 def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: 489 """ 490 Requires Mask to have Rank 2 491 """ 492 if not self.check_common_constraints(node, ep): 493 return False 494 495 if len(node.all_input_nodes) < 4: 496 return False 497 mask_node = node.all_input_nodes[3] 498 mask_rank = mask_node.meta["val"].dim() 499 if mask_rank != 2: 500 why( 501 node, 502 reason=f"mask must have rank 2, got mask of rank {mask_rank}", 503 ) 504 return False 505 506 return True 507 508 def get_original_aten(self) -> Optional[torch._ops.OpOverload]: 509 return torch.ops.aten.scaled_dot_product_attention.default 510 511 def supported_precision_types(self) -> List[ConfigPrecisionType]: 512 return [ConfigPrecisionType.FP32] 513