#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast, List

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSConv2D,
    MPSDepthwiseConv2D,
    MPSGraph,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
from executorch.backends.transforms import get_shape


def _expand_to_2d(values: List[int], lead: int) -> List[int]:
    """Return ``[lead, values[0]]`` for a one-element list, else ``values`` as is."""
    if len(values) == 1:
        return [lead, values[0]]
    return values


@register_node_visitor
class Conv2D(NodeVisitor):
    """Node visitor that serializes ``aten.convolution.default`` nodes.

    Each visited node is appended to the MPS graph as either an
    ``MPSConv2D`` or an ``MPSDepthwiseConv2D`` node.
    """

    target = "aten.convolution.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Build the MPS convolution node for ``node`` and append it to ``mps_graph``.

        Reads stride, padding, dilation and groups from ``node.args`` at
        positions 3, 4, 5 and 8 respectively, copies them onto the serialized
        node and adds that node to ``mps_graph.mps_nodes``.

        Args:
            node: The FX convolution node being lowered.
            mps_graph: The graph that receives the serialized node.
        """
        input_shape = get_shape(get_input_node(node, 0))
        weight_shape = get_shape(get_input_node(node, 1))
        groups = cast(int, node.args[8])

        # Convolution is depthwise if groups = input channels and output channel
        # is a positive multiple of input channels. Both the input and the
        # weight must additionally have at least 4 dimensions.
        is_depthwise_conv = (
            groups > 1
            and weight_shape[1] == 1
            and len(input_shape) >= 4
            and len(weight_shape) >= 4
        )

        if is_depthwise_conv:
            node_cls = MPSDepthwiseConv2D
        else:
            node_cls = MPSConv2D
        mps_node = self.create_tertiary_node(node, mps_graph, node_cls)

        # Single-element parameter lists (presumably a 1-D convolution -- TODO
        # confirm) get a leading entry so that index 0 is always the Y value
        # and index 1 the X value: 1 for stride/dilation, 0 for padding.
        stride = _expand_to_2d(cast(List[int], node.args[3]), 1)
        padding = _expand_to_2d(cast(List[int], node.args[4]), 0)
        dilation = _expand_to_2d(cast(List[int], node.args[5]), 1)

        conv = mps_node.mpsnode_union
        conv.stride_y = stride[0]
        conv.stride_x = stride[1]
        conv.dilation_y = dilation[0]
        conv.dilation_x = dilation[1]
        conv.groups = groups
        # Padding is symmetric: the same amount on both sides of each axis.
        conv.padding_top = padding[0]
        conv.padding_bottom = padding[0]
        conv.padding_right = padding[1]
        conv.padding_left = padding[1]

        mps_graph.mps_nodes.append(mps_node)