• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import tempfile
8import unittest
9
10import torch
11from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator
12
# Create an input that is hard to compress: uniformly random float32 values.
# A fixed-seed private generator keeps this tensor (and therefore the measured
# compression ratio) reproducible across runs, and avoids consuming the global
# RNG state as a side effect of importing this module.
COMPRESSION_RATIO_TEST = torch.rand(
    [1024, 1024], generator=torch.Generator().manual_seed(0)
)
15
16
def mocked_model_1(input: torch.Tensor) -> torch.Tensor:
    """Fake model for the evaluator tests: always return ``[1.0, 2.0, 3.0, 4.0]``.

    ``input`` is accepted so the function can be called like a model, but it
    is ignored.
    """
    fixed_output = [1.0, 2.0, 3.0, 4.0]
    return torch.tensor(fixed_output)
19
20
def mocked_model_2(input: torch.Tensor) -> torch.Tensor:
    """Fake model for the evaluator tests: always return ``[1.0, 2.0, 3.0, 3.0]``.

    ``input`` is accepted so the function can be called like a model, but it
    is ignored.
    """
    fixed_output = (1.0, 2.0, 3.0, 3.0)
    return torch.tensor(fixed_output)
23
24
class TestGenericModelEvaluator(unittest.TestCase):
    """Tests the GenericModelEvaluator class."""

    @staticmethod
    def _make_evaluator(tosa_path: str) -> GenericModelEvaluator:
        """Build an evaluator over the two mocked models for ``tosa_path``.

        A fresh example input is created on every call so that tests cannot
        influence each other through a shared tensor.
        """
        example_input = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
        return GenericModelEvaluator(
            "dummy_model",
            mocked_model_1,
            mocked_model_2,
            example_input,
            tosa_path,
        )

    def test_get_model_error(self):
        # Only the output comparison is exercised here; the TOSA path is a
        # placeholder. NOTE(review): assumed not to be read by
        # get_model_error — confirm against the evaluator.
        evaluator = self._make_evaluator("tmp/output_tag0.tosa")

        model_error_dict = evaluator.get_model_error()

        # The mocked outputs differ only in the last element (4.0 vs 3.0):
        # max error 1.0; 1.0 relative to 4.0 is 25 %; mean |diff| is 1.0 / 4.
        self.assertEqual(model_error_dict["max_error"], [1.0])
        self.assertEqual(model_error_dict["max_absolute_error"], [1.0])
        self.assertEqual(model_error_dict["max_percentage_error"], [25.0])
        self.assertEqual(model_error_dict["mean_absolute_error"], [0.25])

    def test_get_compression_ratio(self):
        import os

        # Write the payload to a real path inside a temporary directory rather
        # than reopening a still-open NamedTemporaryFile by name, which is not
        # supported on Windows. The directory and its contents are removed
        # when the `with` block exits.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tosa_path = os.path.join(tmp_dir, "output_tag0.tosa")
            with open(tosa_path, "wb") as tosa_file:
                torch.save(COMPRESSION_RATIO_TEST, tosa_file)

            evaluator = self._make_evaluator(tosa_path)

            ratio = evaluator.get_compression_ratio()
            # Random data barely compresses, so the ratio stays close to 1.
            self.assertAlmostEqual(ratio, 1.1, places=1)
60