# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple

import executorch.backends.vulkan.utils as utils

import torch

from executorch.backends.vulkan.op_registry import (
    get_op_features,
    has_impl,
    OpFeatures,
    vulkan_supported_ops,
)

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.dialects._ops import ops as exir_ops

from torch.export.exported_program import ExportedProgram

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

# pyre-ignore
ops_not_to_decompose = [
    torch.ops.aten.upsample_nearest2d.vec,
]

logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)


class VulkanSupportedOperators(OperatorSupportBase):
    def __init__(
        self,
        texture_limits: utils.ImageExtents,
        buffer_limit: int,
        require_dynamic_shape: bool = False,
    ) -> None:
        super().__init__()
        self.texture_limits: utils.ImageExtents = texture_limits
        self.buffer_limit = buffer_limit
        self.require_dynamic_shapes = require_dynamic_shape
        # The tensor dim limit is to guard against tensors with one or more
        # large dimensions, which cannot be represented by an image texture due
        # to the texture axis limits.
        self.tensor_dim_limit = 16384

    def op_node_is_compatible(
        self, node: torch.fx.Node, features: Optional[OpFeatures] = None
    ) -> Tuple[bool, str]:
        """
        Check if a given node is compatible with the Vulkan delegate's
        implementation of the operator called by the node. Each tensor argument
        participating in the operator call must be able to be represented with a
        (storage type, memory layout) combination that is supported by the
        operator implementation.
        """
        target = node.target
        # Account for custom operators
        if node.target == torch.ops.higher_order.auto_functionalized:
            first_arg = node.args[0]
            assert isinstance(first_arg, torch._ops.OpOverload)
            target = first_arg.name()

        # Extract the features for the node's operator, if no override was provided
        if features is None:
            if not has_impl(target):
                return False, "no operator implementation"
            features = get_op_features(target)

        valid_texture_layouts = utils.possible_node_memory_layouts(
            node, self.texture_limits
        )
        can_use_buffers = utils.within_buffer_limit(node, self.buffer_limit)
        for i, arg in enumerate(node.args):
            if (
                isinstance(arg, torch.fx.Node)
                and utils.is_tensor_node(arg)
                and i not in features.skip_limits_check
            ):
                arg_texture_layouts = utils.possible_node_memory_layouts(
                    arg, self.texture_limits
                )
                valid_texture_layouts = valid_texture_layouts.intersection(
                    arg_texture_layouts
                )
                can_use_buffers = can_use_buffers and utils.within_buffer_limit(
                    arg, self.buffer_limit
                )

        # If there are no valid texture memory layouts, then buffer storage must be
        # supported by the operator implementation.
        if len(valid_texture_layouts) == 0:
            if not can_use_buffers:
                return (
                    False,
                    f"op requires buffers that exceed the buffer limit ({self.buffer_limit})",
                )

            compatible = VkStorageType.BUFFER in features.supported_storage_types()
            reason = "op is compatible"
            if not compatible:
                reason = "op requires buffers which is not supported by op impl"

            return compatible, reason

        op_available_layouts = features.supported_memory_layouts(
            VkStorageType.TEXTURE_3D
        )

        is_compatible = any(
            layout in op_available_layouts for layout in valid_texture_layouts
        )
        if not is_compatible:
            return False, "required texture memory layout not supported"

        return is_compatible, "op is compatible"

    def node_is_compatible(
        self, node: torch.fx.Node, features: Optional[OpFeatures] = None
    ) -> Tuple[bool, str]:
        # TODO(ssjia) support symbolic ints
        if utils.is_symint_node(node):
            return False, "symint node not supported yet"
        elif utils.is_tensor_node(node):
            return self.op_node_is_compatible(node, features=features)

        return False, f"Unsupported node type: {node.format_node()}"

    def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]:
        """
        Detect if a node is a permute/transpose that precedes a call to a `mm`
        or `addmm` operator. Such a node can be fused with the `mm` or `addmm`
        to produce a `linear` operator.

        This function returns two bool values:
        1. The first indicates if this node can be fused into a linear node
        2. The second indicates if the overall linear op can be executed with Vulkan

        The node will be partitioned only if both are true.
        """
        if node.target not in [
            exir_ops.edge.aten.t_copy.default,
            exir_ops.edge.aten.permute_copy.default,
        ]:
            return False, False

        if len(node.users) != 1:
            return False, False

        first_user = list(node.users.keys())[0]
        if first_user.target in [
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.addmm.default,
        ]:
            # Only mark this node if the target linear op is valid
            if self.node_is_compatible(first_user)[0]:
                return True, True
            else:
                return True, False

        return False, False

    def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]:
        """
        Scalar tensors are usually converted to scalar values in the graph via
        `scalar_tensor[0].item()` in Python, which translates to a chain of
        `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph.
        This function marks the entire chain as supported by the Vulkan delegate.

        Later, within vulkan_preprocess there will be a graph transform which
        replaces the chain with passing in the scalar tensor directly.

        Similar to the `is_linear_permute` function, this function has 2 return
        values.
""" if node.target == exir_ops.edge.aten.select_copy.int: if len(node.users) != 1: return False, False # pyre-ignore if node.args[0].meta["val"].numel() != 1: return False, False local_scalar_dense = list(node.users.keys())[0] if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default: return False, False return self.is_in_local_scalar_dense_chain(local_scalar_dense) if node.target == torch.ops.aten._local_scalar_dense.default: return True, all(self.node_is_compatible(user)[0] for user in node.users) return False, False def log_skip(self, node: torch.fx.Node, reason: str) -> None: if node.op == "call_function": logger.info( f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}" ) def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: r = self._is_node_supported(node) return r def _is_node_supported(self, node: torch.fx.Node) -> bool: target = node.target if node.target == torch.ops.higher_order.auto_functionalized: first_arg = node.args[0] assert isinstance(first_arg, torch._ops.OpOverload) target = first_arg.name() is_linear_permute, target_linear_is_compatible = self.is_linear_permute(node) if is_linear_permute and target_linear_is_compatible: return True elif is_linear_permute: # Skip so that the permute can be fused into a linear by another backend self.log_skip(node, "permute node of non compatible linear node") return False is_in_local_scalar_dense_chain, dst_node_is_compatible = ( self.is_in_local_scalar_dense_chain(node) ) if is_in_local_scalar_dense_chain and dst_node_is_compatible: return True elif is_in_local_scalar_dense_chain: self.log_skip(node, "local scalar dense of incompatible op node") return False if target not in vulkan_supported_ops: self.log_skip(node, "no operator implementation") return False features = vulkan_supported_ops[target] if not features.check_node_fn(node): self.log_skip(node, "op args not supported") return False if self.require_dynamic_shapes and not features.resize_fn: self.log_skip(node, "no dynamic shape support") return False is_compatible, reason = self.node_is_compatible(node, features=features) if not is_compatible: self.log_skip(node, reason) return is_compatible def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: compile_specs = [] for key, value in compile_options.items(): if isinstance(value, (VkStorageType, VkMemoryLayout)): value_bytes = int(value).to_bytes(4, byteorder="little") compile_specs.append(CompileSpec(key, value_bytes)) if isinstance(value, bool): value_bytes = value.to_bytes(1, byteorder="little") compile_specs.append(CompileSpec(key, value_bytes)) if key == "texture_limits": compile_specs.append( CompileSpec( "texture_limits_x", int(value[0]).to_bytes(4, byteorder="little") ) ) compile_specs.append( CompileSpec( "texture_limits_y", int(value[1]).to_bytes(4, byteorder="little") ) ) compile_specs.append( CompileSpec( "texture_limits_z", int(value[2]).to_bytes(4, byteorder="little") ) ) # Unhandled options are ignored return compile_specs @final class VulkanPartitioner(Partitioner): def __init__( self, compile_options: Optional[Dict[str, Any]] = None, ) -> None: self.options: Dict[str, Any] = {} if compile_options is not None: self.options = compile_options compile_spec = parse_compile_options(self.options) self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec) def ops_to_not_decompose( self, ep: ExportedProgram ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: 
        return (ops_not_to_decompose, None)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # Run the CapabilityBasedPartitioner to return the largest possible
        # subgraphs containing the nodes with the tags
        partition_tags = {}

        texture_limits: utils.ImageExtents = self.options.get(
            "texture_limits", utils.DEFAULT_TEXTURE_LIMITS
        )
        buffer_limit: int = self.options.get("buffer_limit", utils.DEFAULT_BUFFER_LIMIT)
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            VulkanSupportedOperators(
                texture_limits,
                buffer_limit,
                require_dynamic_shape=self.options.get(
                    "require_dynamic_shapes", False
                ),
            ),
            allows_single_node_partition=True,
        )
        partition_list = capability_partitioner.propose_partitions()
        for partition in partition_list:
            for node in partition.nodes:
                tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = tag
                partition_tags[tag] = self.delegation_spec

        pl = len(partition_list)
        if pl == 0:
            logger.warning("No Vulkan subgraphs can be partitioned!")
        else:
            logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.")

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
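
# Example usage (illustrative sketch only, not part of this module's API): the
# partitioner is typically handed to ExecuTorch's lowering flow so that
# Vulkan-compatible subgraphs are delegated to the Vulkan backend. The toy
# module, example inputs, and option values below are placeholders; the option
# keys mirror the ones read in VulkanPartitioner.partition above.
if __name__ == "__main__":
    from executorch.exir import to_edge_transform_and_lower

    class _ToyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    example_inputs = (torch.randn(1, 16),)
    exported = torch.export.export(_ToyModule(), example_inputs)

    partitioner = VulkanPartitioner(
        {
            "texture_limits": (16384, 16384, 2048),  # placeholder limits
            "require_dynamic_shapes": False,
        }
    )

    # Lower Vulkan-compatible subgraphs to the Vulkan delegate; anything the
    # partitioner rejects stays in the default ExecuTorch flow.
    edge_manager = to_edge_transform_and_lower(exported, partitioner=[partitioner])
    executorch_program = edge_manager.to_executorch()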