# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import unittest
from typing import Callable, Union

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized


class LiftedTensor(torch.nn.Module):
    """Apply a binary operator to a slice of a constant tensor and the input.

    The constant is stored as a plain tensor attribute (it is neither a
    parameter nor a registered buffer). The MI test below checks that it
    appears among the exported program's lifted tensor constants.
    """

    test_data = [
        # (operator, (input tensor, slice length))
        (operator.add, (torch.randn(2, 2), 2)),
        # forward() computes op(sliced, x), so x is the divisor here; an
        # all-ones input keeps the divisor away from zero.
        (operator.truediv, (torch.ones(2, 2), 2)),
        (operator.mul, (torch.randn(2, 2), 2)),
        (operator.sub, (torch.rand(2, 2), 2)),
    ]

    def __init__(self, op: Callable[..., torch.Tensor]):
        """Store the binary operator and create the 2x2 constant tensor.

        Args:
            op: Binary callable invoked as ``op(sliced_constant, x)``.
        """
        super().__init__()
        self.op = op
        self.lifted_tensor = torch.Tensor([[1, 2], [3, 4]])

    def forward(self, x: torch.Tensor, length) -> torch.Tensor:
        """Return ``op(lifted_tensor[:, :length], x)``.

        Args:
            x: Right-hand operand passed to the operator.
            length: Number of leading columns of the constant tensor to use.
        """
        sliced = self.lifted_tensor[:, :length]
        return self.op(sliced, x)


class LiftedScalarTensor(torch.nn.Module):
    """Apply a binary operator to the input and a fixed second operand."""

    test_data = [
        # (operator, (input tensor,), second operand)
        (operator.add, (torch.randn(2, 2),), 1.0),
        (operator.truediv, (torch.randn(4, 2),), 1.0),
        (operator.mul, (torch.randn(1, 2),), 2.0),
        (operator.sub, (torch.randn(3),), 1.0),
    ]

    def __init__(
        self,
        op: Callable[..., torch.Tensor],
        arg1: Union[int, float, torch.Tensor],
    ):
        """Store the operator and its fixed right-hand operand.

        Args:
            op: Binary callable invoked as ``op(x, arg1)``.
            arg1: Second operand handed to ``op`` on every forward call.
        """
        super().__init__()
        self.op = op
        self.arg1 = arg1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``op(x, arg1)``."""
        return self.op(x, self.arg1)


class TestLiftedTensor(unittest.TestCase):
    """Tests the ArmPartitioner with a placeholder of type lifted tensor."""

    @parameterized.expand(LiftedTensor.test_data)
    def test_partition_lifted_tensor_tosa_MI(self, op, data):
        # Stop after to_edge() so the graph signature can be inspected
        # before the graph is partitioned.
        tester = (
            ArmTester(
                LiftedTensor(op),
                example_inputs=data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .to_edge()
        )
        signature = tester.get_artifact().exported_program().graph_signature
        # The constant tensor must show up as a lifted tensor constant.
        self.assertGreater(len(signature.lifted_tensor_constants), 0)
        tester.partition()
        tester.to_executorch()
        tester.run_method_and_compare_outputs(data)

    @parameterized.expand(LiftedTensor.test_data)
    def test_partition_lifted_tensor_tosa_BI(self, op, data):
        tester = (
            ArmTester(
                LiftedTensor(op),
                example_inputs=data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .to_edge()
        )
        signature = tester.get_artifact().exported_program().graph_signature
        # In the quantized flow the signature is expected to hold no lifted
        # tensor constants.
        # NOTE(review): the reason is not visible in this file -- confirm
        # against the quantization/export pipeline.
        self.assertEqual(len(signature.lifted_tensor_constants), 0)
        tester.partition()
        tester.to_executorch()
        tester.run_method_and_compare_outputs(data)

    @parameterized.expand(LiftedScalarTensor.test_data)
    def test_partition_lifted_scalar_tensor_tosa_MI(self, op, data, arg1):
        # `data` is already a one-element tuple of inputs; pass it through
        # unchanged.
        (
            ArmTester(
                LiftedScalarTensor(op, arg1),
                example_inputs=data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .to_edge()
            .partition()
            .to_executorch()
            .run_method_and_compare_outputs(data)
        )

    @parameterized.expand(LiftedScalarTensor.test_data)
    def test_partition_lifted_scalar_tensor_tosa_BI(self, op, data, arg1):
        # Same pipeline as the MI variant, with a quantize() stage first.
        (
            ArmTester(
                LiftedScalarTensor(op, arg1),
                example_inputs=data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .to_edge()
            .partition()
            .to_executorch()
            .run_method_and_compare_outputs(data)
        )