#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#

from typing import cast, List

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSEmbedding,
    MPSGraph,
    MPSIndexPut,
    MPSIndexSelect,
    MPSIndexTensor,
    MPSScatter,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
from executorch.backends.transforms import get_shape
from executorch.exir.sym_util import eval_expr

25
@register_node_visitor
class IndexSelectVisitor(NodeVisitor):
    """Serializes ``aten.index_select.default`` nodes into the MPS graph."""

    target = "aten.index_select.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Emit an MPSIndexSelect node for `node` and append it to `mps_graph`.

        args layout: (input, dim, index) — dim is a plain int, index is a
        tensor-producing fx node that gets serialized separately.
        """
        index_select_node = self.create_unary_node(node, mps_graph, MPSIndexSelect)
        index_select_node.mpsnode_union.dim = cast(int, node.args[1])
        # Serialize the index tensor and record its id on the MPS node.
        index_select_node.mpsnode_union.index_id = self.define_tensor(
            get_input_node(node, 2), mps_graph
        )
        mps_graph.mps_nodes.append(index_select_node)
45
46
@register_node_visitor
class IndexTensorVisitor(NodeVisitor):
    """Serializes ``aten.index.Tensor`` nodes into the MPS graph."""

    target = "aten.index.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Emit an MPSIndexTensor node for `node` and append it to `mps_graph`.

        args layout: (input, indices) — indices is a list of tensor-producing
        fx nodes; each one is serialized and its id recorded in order.
        """
        index_node = self.create_unary_node(node, mps_graph, MPSIndexTensor)
        index_tensors = cast(List[torch.fx.Node], node.args[1])
        index_node.mpsnode_union.indices_id.extend(
            self.define_tensor(index_tensor, mps_graph)
            for index_tensor in index_tensors
        )
        mps_graph.mps_nodes.append(index_node)
67
68
@register_node_visitor
class IndexPutVisitor(NodeVisitor):
    """Serializes ``aten.index_put.default`` nodes into the MPS graph."""

    target = "aten.index_put.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def infer_sizes(self, a: List[int], b: List[int]):
        """Return the broadcast shape of `a` and `b`, aligned on trailing dims.

        Dimensions are matched from the right; a missing dimension counts as
        -1, and the size from `b` wins whenever `b` actually has that
        dimension (mirrors trailing-dim broadcast inference).
        """
        ndim = max(len(a), len(b))
        sizes = []
        # Walk from the last (rightmost) dimension towards the front.
        for offset in range(ndim):
            size_a = a[len(a) - 1 - offset] if offset < len(a) else -1
            size_b = b[len(b) - 1 - offset] if offset < len(b) else -1
            sizes.append(size_a if size_b == -1 else size_b)
        sizes.reverse()
        return sizes

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Emit an MPSIndexPut node for `node` and append it to `mps_graph`.

        args layout: (input, indices, values).
        """
        put_node = self.create_unary_node(node, mps_graph, MPSIndexPut)

        # When the values tensor is neither rank-1 nor the same rank as the
        # input, record the shape the values must broadcast to.
        input_shape = get_shape(node.args[0])
        updates_shape = get_shape(node.args[2])
        if len(updates_shape) not in (1, len(input_shape)):
            put_node.mpsnode_union.values_shape = self.infer_sizes(
                input_shape, updates_shape
            )

        # Serialize every index tensor and record its id in order.
        for index_tensor in cast(List[torch.fx.Node], node.args[1]):
            put_node.mpsnode_union.indices_id.append(
                self.define_tensor(index_tensor, mps_graph)
            )

        put_node.mpsnode_union.values_id = self.define_tensor(
            get_input_node(node, 2), mps_graph
        )
        mps_graph.mps_nodes.append(put_node)
114
115
@register_node_visitor
class SliceScatterVisitor(NodeVisitor):
    """Serializes ``aten.slice_scatter.default`` nodes into the MPS graph.

    slice_scatter(input, src, dim, start, end, step) is lowered to a scatter
    along `dim`: the slice's selected positions are materialized as an index
    tensor expanded to the full input shape, stored as a graph constant.
    """

    target = "aten.slice_scatter.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)
        # PyTorch encodes an unspecified `end` bound as INT64_MAX.
        self.invalid_val = 2**63 - 1

    # Fix: the original annotated the return as List[int], but an int is
    # returned on every path.
    def maybe_wrap_dim(self, dim: int, n: int) -> int:
        """Wrap a possibly-negative `dim` into [0, n], clamping out-of-range values."""
        if dim < 0:
            wrapped_dim = dim + n
            if wrapped_dim < 0:
                wrapped_dim = 0
            return wrapped_dim
        elif dim > n:
            return n
        return dim

    # NOTE(review): method name keeps the original (misspelled) spelling so
    # any external callers remain unaffected.
    def get_exapnded_index(self, idx, shape, dim):
        """Expand 1-D index tensor `idx` to `shape`, varying along `dim`."""
        if idx.dim() == 0:
            return idx.expand(shape)

        dim = self.maybe_wrap_dim(dim, len(shape))

        # Reshape idx to [1, ..., idx_size, ..., 1] with idx_size at `dim`,
        # then broadcast it across every other dimension of `shape`.
        idx_size = idx.size(0)
        new_index_shape = [1] * len(shape)
        new_index_shape[dim] = idx_size
        index = idx.view(new_index_shape)

        new_index_shape = list(shape)
        new_index_shape[dim] = idx_size
        index = index.expand(new_index_shape)

        return index

    def get_slice_scatter_indices(
        self, dim, start, end, step, input_shape, dtype=torch.int64
    ):
        """Build the full-shape index tensor selecting [start:end:step] along `dim`."""
        idx = torch.arange(start, end, step, dtype=dtype)
        return self.get_exapnded_index(idx, input_shape, dim)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Emit an MPSScatter node for `node` and append it to `mps_graph`.

        args layout: (input, src[, dim[, start[, end[, step]]]]); trailing
        args are optional and default to dim=0, start=0, end=dim_len, step=1.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSScatter)

        start = None
        end = None
        step = 1

        mps_node.mpsnode_union.src_id = self.define_tensor(
            get_input_node(node, 1), mps_graph
        )
        if len(node.args) >= 3:
            mps_node.mpsnode_union.dim = cast(int, node.args[2])
        if len(node.args) >= 4:
            start = cast(int, node.args[3])
        # INT64_MAX means "no end bound was given".
        if len(node.args) >= 5 and node.args[4] != self.invalid_val:
            end = cast(int, node.args[4])
        if len(node.args) >= 6:
            step = cast(int, node.args[5])

        input_shape = get_shape(get_input_node(node, 0))
        dim_len = input_shape[
            self.maybe_wrap_dim(mps_node.mpsnode_union.dim, len(input_shape))
        ]

        start_val = start if start is not None else 0
        end_val = end if end is not None else dim_len
        # Fix: normalize negative slice bounds (Python slice semantics) and
        # clamp to [0, dim_len] so torch.arange only yields in-bounds indices.
        # The original fed negative bounds straight into arange, producing
        # out-of-range scatter indices.
        if start_val < 0:
            start_val += dim_len
        if end_val < 0:
            end_val += dim_len
        start_val = min(max(start_val, 0), dim_len)
        end_val = min(max(end_val, 0), dim_len)

        scatter_indices = self.get_slice_scatter_indices(
            mps_node.mpsnode_union.dim, start_val, end_val, step, input_shape
        )
        mps_node.mpsnode_union.idx_id = self.define_constant(scatter_indices, mps_graph)
        mps_graph.mps_nodes.append(mps_node)
196
197
@register_node_visitor
class EmbeddingVisitor(NodeVisitor):
    """Serializes ``aten.embedding.default`` nodes into the MPS graph."""

    target = "aten.embedding.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Emit an MPSEmbedding node for `node` and append it to `mps_graph`.

        args layout: (weight, indices[, padding_idx[, scale_grad_by_freq
        [, sparse]]]); the trailing three are optional.
        """
        embedding_node = self.create_binary_node(
            node,
            mps_graph,
            MPSEmbedding,
        )

        args = node.args
        if len(args) >= 3:
            # padding_idx may be symbolic; evaluate it to a concrete int.
            embedding_node.mpsnode_union.padding_idx = eval_expr(
                cast(torch.SymInt, args[2])
            )
        if len(args) >= 4:
            embedding_node.mpsnode_union.scale_grad_by_freq = cast(bool, args[3])
        if len(args) >= 5:
            embedding_node.mpsnode_union.sparse = cast(bool, args[4])

        mps_graph.mps_nodes.append(embedding_node)
226