• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6from typing import cast, List
7
8import torch
9from executorch.backends.apple.mps.operators.node_visitor import (
10    NodeVisitor,
11    register_node_visitor,
12)
13from executorch.backends.apple.mps.serialization.mps_graph_schema import (
14    MPSConv2D,
15    MPSDepthwiseConv2D,
16    MPSGraph,
17)
18from executorch.backends.apple.mps.utils.mps_utils import get_input_node
19from executorch.backends.transforms import get_shape
20
21
def _to_2d(values: List[int], fill: int) -> List[int]:
    """Expand a possibly 1-D convolution parameter to the 2-D ``[y, x]`` layout.

    A single-element list is taken as the x (width) component and ``fill`` is
    used for the y (height) component. Callers pass identity values (1 for
    stride/dilation, 0 for padding), so the y axis is a no-op. Presumably the
    runtime treats a 1-D input as a height-1 2-D input -- TODO confirm.
    Lists of any other length are returned as a copy, unchanged.
    """
    if len(values) == 1:
        return [fill, values[0]]
    return list(values)


@register_node_visitor
class Conv2D(NodeVisitor):
    """Lower ``aten.convolution.default`` to an MPS 2-D convolution node.

    Emits ``MPSDepthwiseConv2D`` for depthwise 2-D convolutions and
    ``MPSConv2D`` otherwise. 1-D convolution parameters (single-element
    stride / padding / dilation lists) are expanded to their 2-D form.
    """

    target = "aten.convolution.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` into ``mps_graph`` as a (depthwise) conv node.

        ``node.args`` follows the ``aten.convolution.default`` schema:
        ``(input, weight, bias, stride, padding, dilation, transposed,
        output_padding, groups)``.

        Args:
            node: The FX node for the convolution call.
            mps_graph: Graph that the new MPS node is appended to.

        Raises:
            NotImplementedError: If the convolution is transposed or has more
                than two spatial dimensions. This visitor forwards only 2-D
                stride/padding/dilation/groups and never forwards
                ``transposed`` or ``output_padding``, so lowering such a node
                would silently compute a different operation.
        """
        if node.args[6]:
            raise NotImplementedError(
                "MPS backend: transposed convolution is not supported "
                f"(node {node.name})"
            )

        input_shape = get_shape(get_input_node(node, 0))
        weight_shape = get_shape(get_input_node(node, 1))
        groups = cast(int, node.args[8])

        # Weight rank is 2 + number of spatial dims; only 1-D (rank 3) and
        # 2-D (rank 4) convolutions map onto the 2-D MPS conv node.
        if len(weight_shape) > 4:
            raise NotImplementedError(
                "MPS backend: only 1-D and 2-D convolutions are supported, got "
                f"weight of rank {len(weight_shape)} (node {node.name})"
            )

        # Convolution is depthwise if groups == input channels (output
        # channels is then a positive multiple of input channels). PyTorch
        # conv weights are laid out as (out_channels, in_channels / groups,
        # *kernel), so for groups > 1 this is weight_shape[1] == 1. Rank is
        # checked first so weight_shape[1] is only read on a 2-D conv weight.
        is_depthwise_conv = (
            len(input_shape) >= 4
            and len(weight_shape) >= 4
            and groups > 1
            and weight_shape[1] == 1
        )

        mps_node = self.create_tertiary_node(
            node, mps_graph, MPSDepthwiseConv2D if is_depthwise_conv else MPSConv2D
        )

        # 1-D convolutions pass single-element lists; expand them to [y, x]
        # with identity values for the y axis.
        stride = _to_2d(cast(List[int], node.args[3]), fill=1)
        padding = _to_2d(cast(List[int], node.args[4]), fill=0)
        dilation = _to_2d(cast(List[int], node.args[5]), fill=1)

        conv = mps_node.mpsnode_union
        conv.stride_y = stride[0]
        conv.stride_x = stride[1]
        conv.dilation_y = dilation[0]
        conv.dilation_x = dilation[1]
        conv.groups = groups
        # aten padding is one value per spatial axis, applied to both sides.
        conv.padding_top = padding[0]
        conv.padding_bottom = padding[0]
        conv.padding_right = padding[1]
        conv.padding_left = padding[1]

        mps_graph.mps_nodes.append(mps_node)
71