• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
import operator
import unittest
from typing import Callable, Union

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized
14
15
class LiftedTensor(torch.nn.Module):
    """Module whose forward combines its input with a constant tensor attribute.

    ``lifted_tensor`` is assigned as a plain attribute rather than registered
    as a parameter or buffer. The tests in this file use that to check how the
    exported graph signature reports lifted tensor constants.
    """

    test_data = [
        # (operator, (x, length)) -- the inner tuple is the forward() input.
        (operator.add, (torch.randn(2, 2), 2)),
        # x is the divisor here; ones() presumably avoids near-zero divisors.
        (operator.truediv, (torch.ones(2, 2), 2)),
        (operator.mul, (torch.randn(2, 2), 2)),
        (operator.sub, (torch.rand(2, 2), 2)),
    ]

    def __init__(self, op: Callable[..., torch.Tensor]):
        """Store the binary operator used by ``forward``.

        Args:
            op: Binary callable such as ``operator.add``; it is invoked as
                ``op(sliced_constant, x)``.
        """
        super().__init__()
        self.op = op
        # Plain tensor attribute (not a Parameter or registered buffer).
        self.lifted_tensor = torch.Tensor([[1, 2], [3, 4]])

    def forward(self, x: torch.Tensor, length: int) -> torch.Tensor:
        """Apply ``op`` to the leading columns of the constant and ``x``.

        Args:
            x: Right-hand operand. It must broadcast against the sliced
                constant, whose shape is ``(2, length)`` for ``length <= 2``.
            length: Number of leading columns of the 2x2 constant to keep.

        Returns:
            ``op(lifted_tensor[:, :length], x)``.
        """
        sliced = self.lifted_tensor[:, :length]
        return self.op(sliced, x)
34
35
class LiftedScalarTensor(torch.nn.Module):
    """Module that applies a binary operator between its input and a stored operand."""

    test_data = [
        # (operator, inputs, scalar) -- inputs is the forward() argument tuple,
        # scalar is passed to the constructor as the right-hand operand.
        (operator.add, (torch.randn(2, 2),), 1.0),
        (operator.truediv, (torch.randn(4, 2),), 1.0),
        (operator.mul, (torch.randn(1, 2),), 2.0),
        (operator.sub, (torch.randn(3),), 1.0),
    ]

    def __init__(
        self,
        op: Callable[..., torch.Tensor],
        arg1: Union[int, float, torch.Tensor],
    ):
        """Store the operator and its constant right-hand operand.

        Args:
            op: Binary callable such as ``operator.mul``; invoked as
                ``op(x, arg1)``.
            arg1: Constant right-hand operand (Python scalar or tensor).
        """
        super().__init__()
        self.op = op
        self.arg1 = arg1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``op(x, arg1)``."""
        return self.op(x, self.arg1)
52
53
class TestLiftedTensor(unittest.TestCase):
    """Tests the ArmPartitioner with a placeholder of type lifted tensor."""

    @parameterized.expand(LiftedTensor.test_data)
    def test_partition_lifted_tensor_tosa_MI(self, op, data):
        """Float (MI) flow: the constant must appear as a lifted tensor constant."""
        tester = (
            ArmTester(
                LiftedTensor(op),
                example_inputs=data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .to_edge()
        )
        signature = tester.get_artifact().exported_program().graph_signature
        # unittest assertion instead of bare assert: not stripped under -O and
        # reports the actual value on failure.
        self.assertGreater(len(signature.lifted_tensor_constants), 0)
        tester.partition()
        tester.to_executorch()
        tester.run_method_and_compare_outputs(data)

    @parameterized.expand(LiftedTensor.test_data)
    def test_partition_lifted_tensor_tosa_BI(self, op, data):
        """Quantized (BI) flow: no lifted tensor constants are expected after to_edge."""
        tester = (
            ArmTester(
                LiftedTensor(op),
                example_inputs=data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .to_edge()
        )
        signature = tester.get_artifact().exported_program().graph_signature
        # NOTE(review): presumably the quantization flow folds the constant so it
        # is no longer lifted -- confirm against the quantizer.
        self.assertEqual(len(signature.lifted_tensor_constants), 0)
        tester.partition()
        tester.to_executorch()
        tester.run_method_and_compare_outputs(data)

    @parameterized.expand(LiftedScalarTensor.test_data)
    def test_partition_lifted_scalar_tensor_tosa_MI(self, op, data, arg1):
        """Float (MI) flow for a module combining its input with a stored scalar."""
        (
            ArmTester(
                LiftedScalarTensor(op, arg1),
                # ``data`` is already a tuple of forward() inputs.
                example_inputs=data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .to_edge()
            .partition()
            .to_executorch()
            .run_method_and_compare_outputs(data)
        )

    @parameterized.expand(LiftedScalarTensor.test_data)
    def test_partition_lifted_scalar_tensor_tosa_BI(self, op, data, arg1):
        """Quantized (BI) flow for a module combining its input with a stored scalar."""
        (
            ArmTester(
                LiftedScalarTensor(op, arg1),
                # ``data`` is already a tuple of forward() inputs.
                example_inputs=data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .to_edge()
            .partition()
            .to_executorch()
            .run_method_and_compare_outputs(data)
        )
122