1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6import random
7
8import executorch.exir.control_flow as control_flow
9import torch
10from functorch.experimental.control_flow import cond
11from torch import nn
12
13
class LayerNormModule(torch.nn.Module):
    """Applies LayerNorm over the trailing (5, 10, 10) dimensions of the input."""

    def __init__(self) -> None:
        super().__init__()
        # Normalized shape matches the last three dims of the example input.
        self.norm = torch.nn.LayerNorm([5, 10, 10])

    def forward(self, arg):
        """Return the layer-normalized input tensor (same shape as ``arg``)."""
        return self.norm(arg)

    @staticmethod
    def get_example_inputs():
        """One example batch: 20 samples of shape (5, 10, 10)."""
        return (torch.randn(20, 5, 10, 10),)
25
26
class Conv2DModule(torch.nn.Module):
    """Single 3x3 convolution mapping 1 input channel to 3 output channels."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=1, out_channels=3, kernel_size=3, stride=1
        )

    def forward(self, arg):
        """Return the convolved input."""
        return self.conv(arg)

    @staticmethod
    def get_example_inputs():
        # A single 1-channel 3x3 image; a 3x3 kernel yields a 1x1 spatial output.
        return (torch.randn(1, 1, 3, 3),)
38
39
class ModuleBasic(torch.nn.Module):
    """Minimal model: elementwise sine followed by a global max reduction."""

    def forward(self, x):
        """Return the scalar maximum of ``sin(x)``."""
        return torch.sin(x).max()

    def get_random_inputs(self):
        """A single random 1-D tensor of 100 elements."""
        return (torch.randn(100),)
49
50
class ModuleOpsReturnMulti(torch.nn.Module):
    """Exercises an op (topk) whose output is multiple tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        """Return twice the top-3 values of ``a`` plus ``b`` (elementwise)."""
        values, _indices = torch.topk(a, 3)
        return values * 2 + b

    def get_random_inputs(self):
        # b must match the top-3 result's length.
        return (torch.randn(10), torch.randn(3))
61
62
class ModuleAdd(torch.nn.Module):
    """Elementwise addition of two same-shaped tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return ``x + y``."""
        return x + y

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2))
72
73
class ModuleFloatAddWithAlpha(torch.nn.Module):
    """Addition with a float scale on the second operand: ``x + c * y``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: float):
        # torch.add with alpha computes x + alpha * y in a single op.
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        # Float tensors plus a random float scale in [0, 1).
        return (torch.randn(2, 2), torch.randn(2, 2), random.random())
83
84
class ModuleIntAddWithAlpha(torch.nn.Module):
    """Addition with an int scale on the second operand: ``x + c * y``."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, c: int):
        # torch.add with alpha computes x + alpha * y in a single op.
        return torch.add(x, y, alpha=c)

    def get_random_inputs(self):
        # Integer tensors plus a random int scale in [0, 10].
        return (
            torch.randint(0, 10, (2, 2)),
            torch.randint(0, 10, (2, 2)),
            random.randint(0, 10),
        )
98
99
class ModuleContainers(torch.nn.Module):
    """Takes a dict input and returns a nested container (dict/tuple) output."""

    def __init__(self):
        super().__init__()

    def forward(self, d):
        """Echo the two inputs and also return their sum, all inside a dict."""
        a, b = d["a"], d["b"]
        return {"inputs": (a, b), "c": torch.add(a, b)}

    def get_random_inputs(self):
        return ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)
111
112
class ToyModelForMemPlanning(torch.nn.Module):
    """Chain of mul/add ops producing several intermediates for memory planning."""

    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        """Apply ``out = out * a + b`` three times, starting from ``out = a``."""
        out = a
        for _ in range(3):
            out = out * a + b
        return out

    def get_random_inputs(self):
        return (
            torch.randn(10),
            torch.randn(10),
        )
129
130
class MemPlanningWithScratchTensor(torch.nn.Module):
    """Two linear layers whose outputs are summed; linears allocate scratch tensors."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 2)
        self.linear2 = torch.nn.Linear(4, 2)

    def forward(self, a, b):
        """Return ``linear1(a) + linear2(b)``."""
        return self.linear1(a) + self.linear2(b)

    def get_random_inputs(self):
        return (
            torch.randn(10, 4),
            torch.randn(10, 4),
        )
147
148
class ModuleOpsReturnTensorList(torch.nn.Module):
    """Exercises an op (tensor_split) whose output is a list of tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Split ``x`` into 3 sections and return the first one."""
        chunks = torch.ops.aten.tensor_split.sections(x, 3)
        return chunks[0]

    def get_random_inputs(self):
        return (torch.randn(100),)
159
160
class ModuleReturnInput(torch.nn.Module):
    """Returns its input replicated inside several container types (tuple/dict/list)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Echo ``x`` directly and nested inside a dict and a list."""
        return (x, x, {"x": x, "y": x}, [x, x, x])

    def get_random_inputs(self):
        return (torch.randn(1),)
170
171
class ModuleIfElse(torch.nn.Module):
    """Branches via ``cond``: true branch sums 3 copies of x*x, false sums 4."""

    def __init__(self):
        super().__init__()

    def forward(self, c, x):
        x = x * x

        def repeat_add(t, n):
            # Sum n copies of t using only additions.
            acc = t
            for _ in range(n - 1):
                acc = acc + t
            return acc

        def true_branch(c, x):
            return repeat_add(x, 3)

        def false_branch(c, x):
            return repeat_add(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))
        return y * y

    def get_random_inputs(self):
        # Boolean predicate tensor plus a random input vector.
        return (torch.randint(2, [1]) == 0, torch.randn(10))
196
197
class ModuleIfElseWithBoolInput(torch.nn.Module):
    """Like ModuleIfElse, but the ``cond`` predicate is a Python bool, not a tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, c: bool, x: torch.Tensor):
        x = x * x

        def repeat_add(t, n):
            # Sum n copies of t using only additions.
            acc = t
            for _ in range(n - 1):
                acc = acc + t
            return acc

        def true_branch(c, x):
            return repeat_add(x, 3)

        def false_branch(c, x):
            return repeat_add(x, 4)

        y = cond(c, true_branch, false_branch, (c, x))

        return y * y

    def get_random_inputs(self):
        # Plain Python bool predicate plus a random input vector.
        return (random.randint(0, 1) == 0, torch.randn(10))
223
224
class ModuleWhileIf(torch.nn.Module):
    """Nested control flow: a while_loop whose body contains a cond.

    Counts ``cnt`` down to zero while accumulating into ``accum``. The inner
    cond's predicate is a constant True tensor, so its true branch (which
    returns ``cnt``) is always taken and each iteration adds ``cnt`` to
    ``accum``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        # tracing_context supplies example inputs so each nested function can
        # be traced independently with concrete shapes/dtypes.
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_cond(accum, cnt):
            # Keep looping until cnt reaches zero.
            return cnt != torch.zeros([1]).to(dtype=torch.long)

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def loop_body(accum, cnt):
            # return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)
            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def true_branch(cnt):
                return cnt

            @control_flow.tracing_context(
                inputs=(torch.zeros([1]).to(dtype=torch.long),)
            )
            def false_branch(cnt):
                return torch.zeros([1], dtype=torch.long)

            # Predicate is a constant True tensor, so true_branch always runs.
            accum = accum + cond(
                torch.BoolTensor([True]), true_branch, false_branch, (cnt,)
            )
            # 'cnt - 1' does not work yet since the runtime does not expect
            # tensor to be mixed with scalar for sub op.
            return accum, cnt - torch.ones([1]).to(dtype=torch.long)

        # Only the accumulator is returned; the final counter is discarded.
        y, _ = control_flow.while_loop(
            loop_cond,
            loop_body,
            (accum, cnt),
        )
        return y

    def get_random_inputs(self):
        # Zero accumulator plus a random starting count in [10, 100).
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
269
270
class ModuleIfWhile(torch.nn.Module):
    """Nested control flow: a cond whose true branch runs a while_loop.

    The cond's predicate is a constant True tensor, so the while_loop always
    executes: it counts ``cnt`` down to zero while accumulating into ``accum``.
    Only the accumulated value (index 0 of the cond result) is returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, accum, cnt):
        # tracing_context supplies example inputs so each nested function can
        # be traced independently with concrete shapes/dtypes.
        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def true_branch(accum, cnt):
            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_cond(accum, cnt):
                # Keep looping until cnt reaches zero.
                return cnt != torch.zeros([1]).to(dtype=torch.long)

            @control_flow.tracing_context(
                inputs=(
                    torch.zeros([1]).to(dtype=torch.long),
                    torch.randint(10, 100, [1]),
                )
            )
            def loop_body(accum, cnt):
                # Accumulate, then decrement via tensor (not scalar) subtraction.
                return accum + cnt, cnt - torch.ones([1]).to(dtype=torch.long)

            return control_flow.while_loop(loop_cond, loop_body, (accum, cnt))

        @control_flow.tracing_context(
            inputs=(torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
        )
        def false_branch(accum, cnt):
            # Pass-through branch: leaves accum and cnt untouched.
            return accum, cnt

        # Predicate is a constant True tensor, so true_branch is always taken;
        # keep only the accumulator from the (accum, cnt) pair.
        return cond(torch.BoolTensor([True]), true_branch, false_branch, (accum, cnt))[
            0
        ]

    def get_random_inputs(self):
        # Zero accumulator plus a random starting count in [10, 100).
        return (torch.zeros([1]).to(dtype=torch.long), torch.randint(10, 100, [1]))
312
313
class ModuleContiguousTensor(torch.nn.Module):
    """Single linear layer; exercises contiguous weight tensor layout."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(in_features=8, out_features=32)

    def forward(self, arg):
        """Return the linear projection of ``arg``."""
        return self.linear(arg)

    def get_random_inputs(self):
        return (torch.randn(3, 8),)
324
325
class ModuleInputDynamicShape(torch.nn.Module):
    """Model whose 1-D input length varies (1..10) across calls."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply ``x = (x + x) ** 2`` four times, elementwise."""
        for _ in range(4):
            doubled = x + x
            x = doubled * doubled
        return x

    def get_upper_bound_inputs(self):
        # The largest shape the model will ever see.
        return (torch.randn(10),)

    def get_random_inputs(self):
        # A random length between 1 and the upper bound of 10.
        return (torch.randn(random.randint(1, 10)),)
342
343
class ModuleIntermediateDynamicShape(torch.nn.Module):
    """Produces a data-dependent (dynamic) intermediate shape via nonzero."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        squared = x * x

        # We should use x[torch.nonzero(x)] ideally, but index op is not supported
        # in the runtime so far.
        indices = torch.nonzero(squared)
        return indices + indices

    def get_random_inputs(self):
        # Roughly half the entries are zero, so nonzero's output size varies.
        return (torch.randint(0, 2, (10,), dtype=torch.float),)
358
359
def _static_example_entry(module_cls):
    # Pair a fresh module with its class-level (static) example inputs.
    return (module_cls(), module_cls.get_example_inputs())


def _random_entry(module_cls):
    # Pair a fresh module with inputs drawn by a second, throwaway instance
    # (two constructions, mirroring the original evaluation order).
    return (module_cls(), module_cls().get_random_inputs())


# Registry mapping model names to zero-arg factories that produce a
# (module, example_inputs) pair. Factories are lazy: nothing is instantiated
# until an entry is actually invoked.
MPS_MODEL_NAME_TO_MODEL = {
    "conv2D": lambda: _static_example_entry(Conv2DModule),
    "norm": lambda: _static_example_entry(LayerNormModule),
    "module_basic": lambda: _random_entry(ModuleBasic),
    "module_ops_return_multi": lambda: _random_entry(ModuleOpsReturnMulti),
    "module_add": lambda: _random_entry(ModuleAdd),
    "module_float_add_with_alpha": lambda: _random_entry(ModuleFloatAddWithAlpha),
    "module_int_add_with_alpha": lambda: _random_entry(ModuleIntAddWithAlpha),
    "module_containers": lambda: _random_entry(ModuleContainers),
    "toy_model_for_mem_planning": lambda: _random_entry(ToyModelForMemPlanning),
    "mem_planning_with_scratch_tensor": lambda: _random_entry(
        MemPlanningWithScratchTensor
    ),
    "module_ops_return_tensor_list": lambda: _random_entry(ModuleOpsReturnTensorList),
    "module_return_input": lambda: _random_entry(ModuleReturnInput),
    "module_if_else": lambda: _random_entry(ModuleIfElse),
    "module_if_else_with_bool_input": lambda: _random_entry(ModuleIfElseWithBoolInput),
    "module_while_if": lambda: _random_entry(ModuleWhileIf),
    "module_if_while": lambda: _random_entry(ModuleIfWhile),
    "module_contiguous_tensor": lambda: _random_entry(ModuleContiguousTensor),
    "module_input_dynamic_shape": lambda: _random_entry(ModuleInputDynamicShape),
    "module_intermediate_dynamic_shape": lambda: _random_entry(
        ModuleIntermediateDynamicShape
    ),
}
417