# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

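"""Per-operator partitioner configs for the XNNPACK delegate.

Each config binds an ATen target name to the precision types it supports
(FP32 and/or static quantization) and to any operator-specific constraints
that must hold before the node can be partitioned to XNNPACK.
"""
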
import logging
from typing import cast, List, Optional

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
    XNNPartitionerConfig,
)
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.backends.xnnpack.utils.utils import get_input_node
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
)
from executorch.exir.backend.utils import WhyNoPartition
from torch.export import ExportedProgram

logger = logging.getLogger(__name__)
why = WhyNoPartition(logger=logger)


class GenericNodePartitionerConfig(XNNPartitionerConfig):
    def __init__(self, fused_act: Optional[List[str]] = None, **kwargs):
        """
        fused_act is a list of node target names that can be fused with this
        node under quantization
        """
        self.fused_acts = fused_act or []
        super().__init__(**kwargs)

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        return self.check_common_constraints(node, ep)

    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        deps = [node]
        quantized_deps = []
        if ConfigPrecisionType.STATIC_QUANT in self.enabled_precision_types:
            # try to partition dequant inputs and quant outputs if static quant is enabled
            if not all(is_dequant(dq_input) for dq_input in node.all_input_nodes):
                # if not all inputs are dequant nodes then the pattern isn't quantized
                return deps

            quantized_deps.extend(node.all_input_nodes)

            # check if quantized pattern has fused activation
            if len(node.users) != 1:
                return deps

            node_output = list(node.users)[0]
            if (
                node_output.op == "call_function"
                and format_target_name(node_output.target.__name__) in self.fused_acts
            ):
                quantized_deps.append(node_output)
                fused_out_users = list(node_output.users.keys())
                if len(fused_out_users) == 1:
                    node_output = fused_out_users[0]

            if not is_quant(node_output):
                # Expected pattern: node --> fused_act (optional) --> quant
                return deps

            quantized_deps.append(node_output)

        return deps + quantized_deps


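# Illustration (comments only, not executed): with static quantization enabled,
# get_node_and_deps above tries to claim the surrounding quantize/dequantize
# pattern so the whole group is lowered to XNNPACK together, e.g. for a
# quantized add with a fused relu:
#
#     dequantize_per_tensor    dequantize_per_tensor
#                \                  /
#                     add.Tensor
#                         |
#                    relu.default      (optional fused activation)
#                         |
#                  quantize_per_tensor
#
# If any input is not a dequantize node, or the chain does not end in a
# quantize node, only the op itself is returned.

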
class QuantizedPerTensorConfig(GenericNodePartitionerConfig):
    target_name = "quantize_per_tensor.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.STATIC_QUANT]


class DeQuantizedPerTensorConfig(GenericNodePartitionerConfig):
    target_name = "dequantize_per_tensor.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.STATIC_QUANT]


class HardtanhConfig(GenericNodePartitionerConfig):
    target_name = "hardtanh.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class AddConfig(GenericNodePartitionerConfig):
    target_name = "add.Tensor"

    def __init__(self, **kwargs):
        super().__init__(fused_act=["relu.default"], **kwargs)

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class ReLUConfig(GenericNodePartitionerConfig):
    target_name = "relu.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class AbsConfig(GenericNodePartitionerConfig):
    target_name = "abs.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class AvgPoolingConfig(GenericNodePartitionerConfig):
    target_name = "avg_pool2d.default"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        XNNPACK does not support ceil_mode = True or count_include_pad = True.
        Additionally, divisor_override is only supported when it equals the
        pooling region.
        """
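        # Positional schema assumed here (matching aten.avg_pool2d.default):
        # (input, kernel_size, stride, padding, ceil_mode, count_include_pad,
        # divisor_override), so args[4], args[5] and args[6] are read below.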
        if not self.check_common_constraints(node, ep):
            return False

        args = node.args

        ceil_mode = False  # default is False
        if len(args) >= 5:
            ceil_mode = cast(bool, args[4])

        count_include_pad = True  # default is True
        if len(args) >= 6:
            count_include_pad = cast(bool, args[5])

        kernel_size = cast(List[int], args[1])
        pooling_region = kernel_size[0] * kernel_size[1]
        divisor_override = pooling_region  # Default divisor is pooling_region
        if len(args) >= 7:
            divisor_override = cast(int, args[6])

        if ceil_mode:
            why(node, reason="ceil mode is not supported")
            return False

        if count_include_pad:
            why(
                node,
                reason="zero-padding in the averaging calculation is not supported",
            )
            return False

        if divisor_override != pooling_region:
            why(node, reason="divisor override is not supported")
            return False

        return True

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class CatConfig(GenericNodePartitionerConfig):
    target_name = "cat.default"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Only support concatenation of 2 - 4 tensors
        """
        if not self.check_common_constraints(node, ep):
            return False

        num_tensors = len(node.all_input_nodes)

        if not (num_tensors >= 2 and num_tensors <= 4):
            why(
                node,
                reason=f"only support concatenation of 2 - 4 tensors, got {num_tensors} tensors",
            )
            return False

        return True

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class CeilConfig(GenericNodePartitionerConfig):
    target_name = "ceil.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class ClampConfig(GenericNodePartitionerConfig):
    target_name = "clamp.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class DivConfig(GenericNodePartitionerConfig):
    target_name = "div.Tensor"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class EluConfig(GenericNodePartitionerConfig):
    target_name = "elu.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return torch.ops.aten.elu.default


class SoftmaxConfig(GenericNodePartitionerConfig):
    target_name = "_softmax.default"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Check that dim is always the last dim
        """
        if not self.check_common_constraints(node, ep):
            return False

        dim = cast(int, node.args[1])
        node_input = node.all_input_nodes[0]
        tensor_dims = node_input.meta["val"].dim()

        if not (dim == -1 or dim == tensor_dims - 1):
            why(
                node,
                reason=f"dim must be the last dim, got dim = {dim} for tensor of rank {tensor_dims}",
            )
            return False
        return True

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class PermuteConfig(GenericNodePartitionerConfig):
    target_name = "permute_copy.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class SigmoidConfig(GenericNodePartitionerConfig):
    target_name = "sigmoid.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class MulConfig(GenericNodePartitionerConfig):
    target_name = "mul.Tensor"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class MaximumConfig(GenericNodePartitionerConfig):
    target_name = "maximum.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class MaxPool2dConfig(GenericNodePartitionerConfig):
    target_name = "max_pool2d.default"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        XNNPACK's maxpool2d does not support ceil mode
        """
        if not self.check_common_constraints(node, ep):
            return False

        is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5])
        if is_ceil_mode:
            why(node, reason="ceil mode is not supported")
            return False
        return True

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return torch.ops.aten.max_pool2d.default


class UpsampleBilinear2dConfig(GenericNodePartitionerConfig):
    target_name = "upsample_bilinear2d.vec"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return torch.ops.aten.upsample_bilinear2d.vec


class FloorConfig(GenericNodePartitionerConfig):
    target_name = "floor.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class HardswishConfig(GenericNodePartitionerConfig):
    target_name = "hardswish.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class LeakyReLUConfig(GenericNodePartitionerConfig):
    target_name = "leaky_relu.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class MeanDimConfig(GenericNodePartitionerConfig):
    target_name = "mean.dim"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Mean Dim currently only supports averaging 4D tensors across the innermost
        dimensions
        """
        if not self.check_common_constraints(node, ep):
            return False

        dims = node.args[1]
        output_dims = node.meta["val"].dim()

        if dims not in ([-2, -1], [-1, -2]):
            why(
                node,
                reason="mean.dim only supports averaging 4D tensors across the innermost dimensions",
            )
            return False

        if output_dims != 4:
            why(
                node,
                reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {output_dims}",
            )
            return False
        return True

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class MinimumConfig(GenericNodePartitionerConfig):
    target_name = "minimum.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class NegConfig(GenericNodePartitionerConfig):
    target_name = "neg.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class PowConfig(GenericNodePartitionerConfig):
    target_name = "pow.Tensor_Scalar"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Only support an integer exponent equal to 2 (i.e. squaring)
        """
        if not self.check_common_constraints(node, ep):
            return False

        power = node.args[1]

        if not isinstance(power, int):
            why(node, reason=f"only support int powers, got {power}")
            return False

        if power != 2:
            why(node, reason=f"only support power == 2, got {power}")
            return False
        return True

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class SliceCopyConfig(GenericNodePartitionerConfig):
    target_name = "slice_copy.Tensor"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Support slicing with stride = 1 and no zero-dim tensors. Slice is not
        supported if the input or output shape is dynamic.
        """
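        # Positional schema assumed here (matching aten.slice_copy.Tensor):
        # (input, dim, start, end, step); args[4] is the slice step, referred
        # to below as the stride.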
        if not self.check_common_constraints(node, ep):
            return False

        stride = 1
        if len(node.args) > 4:
            stride = cast(int, node.args[4])

        if stride != 1:
            why(node, reason=f"only support stride == 1, got stride = {stride}")
            return False

        input_node = get_input_node(node, 0)
        output_node = node

        input_shape = list(input_node.meta["val"].shape)
        output_shape = list(output_node.meta["val"].shape)

        for dim in input_shape:
            if not isinstance(dim, int) or dim == 0:
                why(
                    node,
                    reason=f"input tensor has invalid shape, dim: {dim} of type {type(dim)}. Expecting non-zero, int values.",
                )
                return False

        for dim in output_shape:
            if not isinstance(dim, int) or dim == 0:
                why(
                    node,
                    reason=f"output tensor has invalid shape, dim: {dim} of type {type(dim)}. Expecting non-zero, int values.",
                )
                return False

        return True

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class SquareRootConfig(GenericNodePartitionerConfig):
    target_name = "sqrt.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class ConstantPadConfig(GenericNodePartitionerConfig):
    target_name = "constant_pad_nd.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class SubConfig(GenericNodePartitionerConfig):
    target_name = "sub.Tensor"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class BMMConfig(GenericNodePartitionerConfig):
    """
    Despite being a GEMM kernel, BMM can be partitioned like a single-node
    partitioner because it does not perform any packing on the inputs being
    matrix multiplied
    """

    target_name = "bmm.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class SDPAConfig(GenericNodePartitionerConfig):
    target_name = "scaled_dot_product_attention.default"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Requires the attention mask to have rank 2
        """
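        # scaled_dot_product_attention takes (query, key, value, attn_mask, ...);
        # all_input_nodes[3] is assumed to be the attention mask when it is present.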
        if not self.check_common_constraints(node, ep):
            return False

        if len(node.all_input_nodes) < 4:
            return False
        mask_node = node.all_input_nodes[3]
        mask_rank = mask_node.meta["val"].dim()
        if mask_rank != 2:
            why(
                node,
                reason=f"mask must have rank 2, got mask of rank {mask_rank}",
            )
            return False

        return True

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return torch.ops.aten.scaled_dot_product_attention.default

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]

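# Usage sketch (illustrative comments only, not executed): these configs are
# consumed by the XNNPACK partitioner rather than used directly. Assuming the
# standard ExecuTorch export flow, where `model` and `example_inputs` are
# placeholders for your own module and inputs, lowering looks roughly like:
#
#     import torch
#     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
#         XnnpackPartitioner,
#     )
#     from executorch.exir import to_edge_transform_and_lower
#
#     exported = torch.export.export(model, example_inputs)
#     edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
#     et_program = edge.to_executorch()
#
# Roughly, nodes are matched against the configs above via target_name,
# filtered by supported_precision_types(), validated by check_constraints(),
# and grouped with the dependencies returned by get_node_and_deps() so that
# statically quantized patterns are lowered as a unit.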