# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer_utils import is_annotated
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class SingleOpModel(torch.nn.Module):
    """Wrap a single callable ``op`` as a module for annotation tests.

    ``forward(x)`` returns ``op(x, **op_kwargs)``, so the model consists of a
    single call to ``op`` applied to its input.
    """

    def __init__(self, op, example_input, **op_kwargs) -> None:
        """
        Args:
            op: Callable applied to the input, e.g. ``torch.squeeze``.
            example_input: Tuple of example inputs handed to the tester.
            **op_kwargs: Extra keyword arguments forwarded to ``op``.
        """
        super().__init__()
        self.op = op
        self._example_input = example_input
        self.op_kwargs = op_kwargs

    def forward(self, x):
        return self.op(x, **self.op_kwargs)

    def example_inputs(self):
        """Return the example-input tuple given at construction."""
        return self._example_input


class TestGenericAnnotator(unittest.TestCase):
    """Check that the Arm quantizer annotates single shape/data-movement ops.

    Each test quantizes a ``SingleOpModel`` with the ``TOSA-0.80.0+BI``
    compile spec and verifies that every node in the op's source partition
    is annotated.
    """

    def check_annotation(self, model):
        """Quantize ``model`` and assert its single op partition is annotated.

        Uses ``self.assert*`` instead of bare ``assert`` so the checks are not
        stripped when Python runs with ``-O``, and so failures report which
        op / nodes were at fault.

        Args:
            model: A ``SingleOpModel`` whose ``op`` is looked up as the source
                partition key in the quantized graph.
        """
        tester = ArmTester(
            model,
            model.example_inputs(),
            common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        )
        quant_model = tester.quantize().get_artifact()
        partitions = get_source_partitions(quant_model.graph, [model.op])
        # Flatten {op: [SourcePartition, ...]} into one list of partitions.
        partitions = list(itertools.chain.from_iterable(partitions.values()))

        self.assertEqual(
            len(partitions),
            1,
            f"Expected exactly one source partition for {model.op}, "
            f"found {len(partitions)}",
        )
        unannotated = [
            node for node in partitions[0].nodes if not is_annotated(node)
        ]
        self.assertFalse(
            unannotated,
            f"Nodes not annotated for {model.op}: {unannotated}",
        )

    def test_squeeze(self):
        for op in (torch.squeeze, torch.squeeze_copy):
            with self.subTest(op=op):
                self.check_annotation(SingleOpModel(op, (torch.rand(8, 8, 1),)))

    def test_unsqueeze(self):
        for op in (torch.unsqueeze, torch.unsqueeze_copy):
            with self.subTest(op=op):
                self.check_annotation(
                    SingleOpModel(op, (torch.rand(8, 8),), dim=0)
                )

    def test_reshape(self):
        self.check_annotation(
            SingleOpModel(torch.reshape, (torch.randn(8, 8),), shape=(64,)),
        )

    def test_view(self):
        self.check_annotation(
            SingleOpModel(torch.view_copy, (torch.randn(4, 4),), size=(2, 8)),
        )

    def test_slice(self):
        self.check_annotation(
            SingleOpModel(torch.slice_copy, (torch.randn(3, 4),)),
        )

    def test_transpose(self):
        for op in (torch.transpose, torch.transpose_copy):
            with self.subTest(op=op):
                self.check_annotation(
                    SingleOpModel(op, (torch.randn(2, 3),), dim0=0, dim1=1),
                )

    def test_tile(self):
        self.check_annotation(
            SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),
        )

    def test_flip(self):
        self.check_annotation(
            SingleOpModel(torch.flip, (torch.randn(2, 4),), dims=(0, 1)),
        )