# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the full op which creates a tensor of a given shape filled with a given value.
# The shape and value are set at compile time, i.e. can't be set by a tensor input.
#

import unittest

from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


class TestFull(unittest.TestCase):
    """Tests the full op which creates a tensor of a given shape filled with a given value."""

    class Full(torch.nn.Module):
        """A single full op: returns a (3, 3) tensor filled with 4.5 and takes no inputs."""

        def forward(self):
            return torch.full((3, 3), 4.5)

    class AddConstFull(torch.nn.Module):
        """Input + a float32 full with constant shape (2, 2, 3, 3) and constant value 4.5."""

        def forward(self, x: torch.Tensor):
            return torch.full((2, 2, 3, 3), 4.5, dtype=torch.float32) + x

    class AddVariableFull(torch.nn.Module):
        """Input + a full whose shape is taken from the input and whose value is 'y'."""

        # Input shapes used by the parameterized tests. The trailing comma in
        # (5,) matters: a bare (5) is the int 5, not a one-element shape tuple.
        sizes = [
            (5,),
            (5, 5),
            (5, 5, 5),
            (1, 5, 5, 5),
        ]
        # Each entry is a 1-tuple wrapping an (input_tensor, fill_value) pair,
        # so parameterized.expand passes the pair as a single test argument.
        # The tensors are drawn once, when this class body is executed.
        test_parameters = [((torch.randn(n) * 10 - 5, 3.2),) for n in sizes]

        def forward(self, x: torch.Tensor, y):
            # Input + a full with the shape from the input and a given value 'y'.
            return x + torch.full(x.shape, y)

    def _test_full_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        example_data: Tuple,
        test_data: Tuple | None = None,
    ):
        """Lower `module` with the TOSA MI compile spec and compare outputs.

        Exports the module and checks that it contains exactly one
        aten.full.default. After partitioning, it checks that no edge full op
        remains and that exactly one delegate call was created. It then
        converts to ExecuTorch and compares outputs on `test_data`.

        Args:
            module: The module under test.
            example_data: Inputs used for export/lowering.
            test_data: Inputs used when running and comparing outputs.
                Defaults to `example_data` when None.
        """
        if test_data is None:
            test_data = example_data
        (
            ArmTester(
                module,
                example_inputs=example_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_full_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple,
        permute_memory_to_nhwc: bool,
    ):
        """Quantize and lower `module` with the TOSA BI compile spec, then compare outputs.

        Runs the same checks as the MI pipeline, with a quantize stage before
        export. The same `test_data` is used both for lowering and for the
        output comparison.

        Args:
            module: The module under test.
            test_data: Inputs used for lowering and for running the method.
            permute_memory_to_nhwc: Forwarded to the TOSA compile spec.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute_memory_to_nhwc
                ),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_full_tosa_ethos_pipeline(
        self, compile_spec: list[CompileSpec], module: torch.nn.Module, test_data: Tuple
    ):
        """Quantize and lower `module` for an Ethos target given by `compile_spec`.

        Performs the same partitioning checks as the TOSA pipelines. It stops
        after `to_executorch()` and does not run the method or compare outputs.
        """
        (
            ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_full_tosa_u55_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Run the Ethos lowering pipeline with the U55 compile spec."""
        self._test_full_tosa_ethos_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_full_tosa_u85_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Run the Ethos lowering pipeline with the U85 compile spec."""
        self._test_full_tosa_ethos_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    def test_only_full_tosa_MI(self):
        self._test_full_tosa_MI_pipeline(self.Full(), ())

    def test_const_full_tosa_MI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_MI_pipeline(self.AddConstFull(), (_input,))

    def test_const_full_nhwc_tosa_BI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_BI_pipeline(self.AddConstFull(), (_input,), True)

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_MI(self, test_tensor: Tuple):
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=test_tensor
        )

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_BI(self, test_tensor: Tuple):
        self._test_full_tosa_BI_pipeline(self.AddVariableFull(), test_tensor, False)

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_u55_BI(self, test_tensor: Tuple):
        self._test_full_tosa_u55_pipeline(
            self.AddVariableFull(),
            test_tensor,
        )

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_u85_BI(self, test_tensor: Tuple):
        self._test_full_tosa_u85_pipeline(
            self.AddVariableFull(),
            test_tensor,
        )

    # This fails because full outputs int64 by default when 'fill_value' is an
    # integer, which our backend doesn't support.
    @unittest.expectedFailure
    def test_integer_value(self):
        _input = torch.ones((2, 2))
        integer_fill_value = 1
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=(_input, integer_fill_value)
        )

    # This fails because the fill value in the full tensor is set at compile
    # time by the example data (1.). The test data tries to set it again at
    # runtime (to 2.), but that has no effect on the lowered program. In eager
    # mode the fill value can be changed at runtime, so the outputs don't match.
    @unittest.expectedFailure
    def test_set_value_at_runtime(self):
        _input = torch.ones((2, 2))
        example_fill_value = 1.0
        test_fill_value = 2.0
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(),
            example_data=(_input, example_fill_value),
            test_data=(_input, test_fill_value),
        )