#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#
from typing import cast, List

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSCat,
    MPSExpand,
    MPSGraph,
    MPSNode,
    MPSPermute,
    MPSPixelShuffle,
    MPSSelect,
    MPSSlice,
    MPSSplitWithSizes,
    MPSSqueeze,
    MPSUnsqueeze,
    MPSView,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
from executorch.backends.transforms import get_shape
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.sym_util import eval_expr, eval_shape
33
@register_node_visitor
class PermuteVisitor(NodeVisitor):
    """Serializes ``aten.permute_copy.default`` into an ``MPSPermute`` node."""

    target = "aten.permute_copy.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        mps_node = self.create_unary_node(node, mps_graph, MPSPermute)

        # args[1] carries the requested dimension reordering.
        order = cast(List[int], node.args[1])
        union = mps_node.mpsnode_union
        union.num_dims = len(order)
        union.perm = order

        mps_graph.mps_nodes.append(mps_node)
53
54
@register_node_visitor
class ViewExpandVisitor(NodeVisitor):
    """Serializes view/expand copy ops into ``MPSView`` / ``MPSExpand`` nodes."""

    target = ["aten.view_copy.default", "aten.expand_copy.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        # Both ops store the target shape in args[1]; only the serialized
        # union type differs between them.
        if node.target is exir_ops.edge.aten.view_copy.default:
            union_type = MPSView
        else:
            union_type = MPSExpand
        mps_node = self.create_unary_node(node, mps_graph, union_type)

        target_shape = cast(List[int], node.args[1])
        mps_node.mpsnode_union.num_dims = len(target_shape)
        mps_node.mpsnode_union.shape = target_shape

        mps_graph.mps_nodes.append(mps_node)
79
80
@register_node_visitor
class CatVisitor(NodeVisitor):
    """Serializes ``aten.cat.default`` into an ``MPSCat`` node."""

    target = "aten.cat.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        in_tensors = cast(List[torch.fx.Node], node.args[0])
        # Register the output first, then each input, preserving id order.
        out_id = self.define_tensor(node, mps_graph)
        in_ids: List[int] = [
            self.define_tensor(t, mps_graph) for t in in_tensors
        ]

        # The concat dimension is optional (defaults to 0); a negative value
        # is normalized against the rank of the first input tensor.
        cat_dim = 0
        if len(node.args) > 1:
            cat_dim = cast(int, node.args[1])
            if cat_dim < 0 and len(in_tensors) > 0:
                cat_dim += len(get_shape(in_tensors[0]))

        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSCat(
                    input_ids=in_ids, output_id=out_id, dim=cat_dim
                ),
            ),
        )
111
112
@register_node_visitor
class SqueezeUnsqueezeVisitor(NodeVisitor):
    """Serializes squeeze/unsqueeze copy ops into MPSSqueeze/MPSUnsqueeze nodes."""

    target = ["aten.unsqueeze_copy.default", "aten.squeeze_copy.dims"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        is_unsqueeze = node.target is exir_ops.edge.aten.unsqueeze_copy.default
        union_type = MPSUnsqueeze if is_unsqueeze else MPSSqueeze
        mps_node = self.create_unary_node(node, mps_graph, union_type)

        if is_unsqueeze:
            # Unsqueeze inserts a single size-1 dimension at args[1].
            mps_node.mpsnode_union.dim = cast(int, node.args[1])
        else:
            # Squeeze keeps only the requested dims whose size is actually 1;
            # the rest are silently dropped from the list.
            requested = cast(List[int], node.args[1])
            in_shape = get_shape(get_input_node(node, 0))
            mps_node.mpsnode_union.dims = [
                d for d in requested if in_shape[d] == 1
            ]

        mps_graph.mps_nodes.append(mps_node)
145
146
@register_node_visitor
class SelectVisitor(NodeVisitor):
    """Serializes ``aten.select_copy.int`` into an ``MPSSelect`` node."""

    target = "aten.select_copy.int"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        mps_node = self.create_unary_node(node, mps_graph, MPSSelect)
        union = mps_node.mpsnode_union
        union.dim = cast(int, node.args[1])
        # The select index may be symbolic; fold it to a concrete value.
        union.index = eval_expr(cast(torch.SymInt, node.args[2]))
        mps_graph.mps_nodes.append(mps_node)
163
164
@register_node_visitor
class PixelShuffleVisitor(NodeVisitor):
    """Serializes ``aten.pixel_shuffle.default`` into an ``MPSPixelShuffle`` node."""

    target = "aten.pixel_shuffle.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        mps_node = self.create_unary_node(node, mps_graph, MPSPixelShuffle)
        # args[1] is the integer upscale factor of the shuffle.
        factor = cast(int, node.args[1])
        mps_node.mpsnode_union.upscale_factor = factor
        mps_graph.mps_nodes.append(mps_node)
180
181
@register_node_visitor
class SliceVisitor(NodeVisitor):
    """Serializes ``aten.slice_copy.Tensor`` into an ``MPSSlice`` node."""

    target = "aten.slice_copy.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Lower a slice op.

        Args layout: ``(input, dim?, start?, end?, step?)``; a missing
        ``dim``/``step`` falls back to the defaults already present on the
        ``MPSSlice`` union.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSSlice)

        # NOTE: the original annotated the return as List[int], but a plain
        # int is returned — fixed here. Clamps a (possibly negative) index
        # into the valid range [0, n].
        def maybe_wrap_dim(dim: int, n: int) -> int:
            if dim < 0:
                wrapped_dim = dim + n
                if wrapped_dim < 0:
                    wrapped_dim = 0
                return wrapped_dim
            elif dim > n:
                return n
            return dim

        start = None
        end = None
        if len(node.args) >= 2:
            mps_node.mpsnode_union.dim = cast(int, node.args[1])
        if len(node.args) >= 4:
            start = cast(int, node.args[2])
            end = cast(int, node.args[3])
        if len(node.args) >= 5:
            mps_node.mpsnode_union.step = cast(int, node.args[4])

        # Length of the dimension being sliced, after wrapping the dim index
        # against the input rank.
        input_shape = get_shape(get_input_node(node, 0))
        dim_len = input_shape[
            maybe_wrap_dim(mps_node.mpsnode_union.dim, len(input_shape))
        ]

        # Defaults mirror Python slicing: start at 0, end at the dim length;
        # both are then clamped/wrapped against dim_len.
        start_val = start if start is not None else 0
        end_val = end if end is not None else dim_len

        mps_node.mpsnode_union.start = maybe_wrap_dim(start_val, dim_len)
        mps_node.mpsnode_union.end = maybe_wrap_dim(end_val, dim_len)
        mps_graph.mps_nodes.append(mps_node)
227
228
@register_node_visitor
class SplitWithSizesVisitor(NodeVisitor):
    """Serializes ``aten.split_with_sizes_copy.default`` into ``MPSSplitWithSizes``."""

    target = "aten.split_with_sizes_copy.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Lower split_with_sizes.

        Args layout: ``(input, split_sizes, dim)``. A negative ``dim`` is now
        wrapped against the input rank (matching ATen semantics and the cat
        visitor's normalization); previously it raised unconditionally.

        Raises:
            RuntimeError: if ``dim`` is out of range for the input tensor.
        """
        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        output_ids = self.define_tensor_list(node, mps_graph)
        # split_sizes may contain symbolic ints; evaluate to concrete values.
        split_sizes = eval_shape(cast(torch.SymInt, node.args[1]))
        dim = cast(int, node.args[2])
        input_shape = get_shape(get_input_node(node, 0))

        # Wrap a negative dim (e.g. -1 -> rank - 1) before range-checking.
        if dim < 0:
            dim += len(input_shape)

        if dim < 0 or dim >= len(input_shape):
            raise RuntimeError(
                f"split_copy: dim {dim} out of range for input tensor with {len(input_shape)} dimensions"
            )

        mps_node = MPSNode(
            mpsnode_union=MPSSplitWithSizes(
                input1_id=input1_id,
                output_ids=output_ids,
                split_sizes=split_sizes,
                dim=dim,
            )
        )
        mps_graph.mps_nodes.append(mps_node)
261