# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator

# Uniformly random floats leave little redundancy for a compressor to exploit,
# which the compression-ratio test below relies on.
COMPRESSION_RATIO_TEST = torch.rand([1024, 1024])


def mocked_model_1(input: torch.Tensor) -> torch.Tensor:
    """Fake model that ignores its input and always returns [1, 2, 3, 4]."""
    output = [1.0, 2.0, 3.0, 4.0]
    return torch.tensor(output)


def mocked_model_2(input: torch.Tensor) -> torch.Tensor:
    """Fake model matching ``mocked_model_1`` except the last value is 3, not 4."""
    output = [1.0, 2.0, 3.0, 3.0]
    return torch.tensor(output)


class TestGenericModelEvaluator(unittest.TestCase):
    """Unit tests for ``GenericModelEvaluator``."""

    def _build_evaluator(self, tosa_path: str) -> GenericModelEvaluator:
        """Construct an evaluator over the two mocked models.

        A fresh example input tensor is created on every call, so no two tests
        share the same input object.
        """
        sample = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
        return GenericModelEvaluator(
            "dummy_model",
            mocked_model_1,
            mocked_model_2,
            sample,
            tosa_path,
        )

    def test_get_model_error(self):
        """The two mocks differ by 1.0 in exactly one of four elements.

        That gives a max / max-absolute error of 1.0, a 25% error relative to
        the value 4.0, and a mean absolute error of 1.0 / 4 = 0.25.
        """
        # NOTE(review): nothing is written to this path in this test;
        # presumably get_model_error never reads it -- confirm in the evaluator.
        errors = self._build_evaluator("tmp/output_tag0.tosa").get_model_error()

        expected = {
            "max_error": [1.0],
            "max_absolute_error": [1.0],
            "max_percentage_error": [25.0],
            "mean_absolute_error": [0.25],
        }
        # Checked in insertion order; the first mismatch fails the test.
        for metric, value in expected.items():
            self.assertEqual(errors[metric], value)

    def test_get_compression_ratio(self):
        """Compressing a serialized random tensor should give a ratio near 1.1."""
        with tempfile.NamedTemporaryFile(delete=True) as tosa_file:
            # The serialized random tensor takes the place of the TOSA output.
            torch.save(COMPRESSION_RATIO_TEST, tosa_file)

            ratio = self._build_evaluator(tosa_file.name).get_compression_ratio()
            # places=1: passes when round(ratio - 1.1, 1) == 0.
            self.assertAlmostEqual(ratio, 1.1, places=1)