• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9from executorch.backends.arm.tosa_specification import (
10    Tosa_0_80,
11    Tosa_1_00,
12    TosaSpecification,
13)
14
15from executorch.exir.backend.compile_spec_schema import CompileSpec
16from parameterized import parameterized
17
# Well-formed version strings, one list per supported TOSA release. Every
# entry is "TOSA-<version>+" followed by "+"-separated profile/extension tokens.
test_valid_0_80_strings = [
    "TOSA-0.80.0+" + tokens
    for tokens in (
        "BI",
        "MI+8k",
        "BI+u55",
    )
]
test_valid_1_00_strings = [
    "TOSA-1.00.0+" + tokens
    for tokens in (
        "INT+FP+fft",
        "FP+bf16+fft",
        "INT+int4+cf",
        "FP+cf+bf16+8k",
        "FP+INT+bf16+fft+int4+cf",
        "FP+INT+fft+int4+cf+8k",
    )
]

# Extension tokens the tests accept for each TOSA 1.00 profile.
test_valid_1_00_extensions = dict(
    INT=["int16", "int4", "var", "cf"],
    FP=["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
)
36
# Version strings that create_from_string must reject with ValueError,
# grouped by the release they claim to be.
test_invalid_strings = [
    # TOSA 0.80
    "TOSA-0.80.0+bi",  # lowercase profile token
    "TOSA-0.80.0",  # no profile token at all
    "TOSA-0.80.0+8k",  # level token but no BI/MI profile
    "TOSA-0.80.0+BI+MI",  # two profiles at once
    "TOSA-0.80.0+BI+U55",  # uppercase U55 (valid spelling is u55)
] + [
    # TOSA 1.00
    "TOSA-1.00.0+fft",  # extension without an INT/FP profile
    "TOSA-1.00.0+fp+bf16+fft",  # lowercase profile token
    "TOSA-1.00.0+INT+INT4+cf",  # uppercase INT4 (valid spelling is int4)
    "TOSA-1.00.0+BI",  # 0.80-style profile in a 1.00 string
    "TOSA-1.00.0+FP+FP+INT",  # FP profile listed twice
    "TOSA-1.00.0+FP+CF+bf16",  # uppercase CF (valid spelling is cf)
    "TOSA-1.00.0+BF16+fft+int4+cf+INT",  # uppercase BF16
]
51
# Parameter sets for the compile-spec tests. Each entry is a 1-tuple wrapping
# the list of CompileSpec objects handed to the test method.

# Lists that carry a "tosa_version" entry with a parseable version string.
test_compile_specs = [
    ([CompileSpec("tosa_version", version.encode())],)
    for version in (
        "TOSA-0.80.0+BI",
        "TOSA-0.80.0+BI+u55",
        "TOSA-1.00.0+INT",
    )
]

# Lists with no "tosa_version" key at all, even when the value looks valid.
test_compile_specs_no_version = [
    ([CompileSpec("other_key", value.encode())],)
    for value in (
        "TOSA-0.80.0+BI",
        "some_value",
    )
]
62
63
class TestTosaSpecification(unittest.TestCase):
    """Tests the TOSA specification class."""

    @parameterized.expand(test_valid_0_80_strings)
    def test_version_string_0_80(self, version_string: str):
        """A valid TOSA 0.80 string parses to a Tosa_0_80 with a BI or MI profile."""
        tosa_spec = TosaSpecification.create_from_string(version_string)
        self.assertIsInstance(tosa_spec, Tosa_0_80)
        self.assertIn(tosa_spec.profile, ["BI", "MI"])

    @parameterized.expand(test_valid_1_00_strings)
    def test_version_string_1_00(self, version_string: str):
        """A valid TOSA 1.00 string parses to a Tosa_1_00 whose profiles and
        extensions are all drawn from the allowed sets."""
        tosa_spec = TosaSpecification.create_from_string(version_string)
        self.assertIsInstance(tosa_spec, Tosa_1_00)

        # At least one profile must be present, and every profile must be a
        # known one (INT or FP).
        self.assertTrue(tosa_spec.profiles)
        for profile in tosa_spec.profiles:
            self.assertIn(profile, test_valid_1_00_extensions)

        # An extension is valid when ANY selected profile allows it: "fft" is
        # FP-only, yet "INT+FP+fft" is a legal combination. Asserting a list
        # literal (as before) was always truthy and checked nothing.
        # NOTE(review): test_valid_1_00_extensions has no "8k" entry, so this
        # assumes Tosa_1_00 keeps the 8k level out of `extensions` — confirm.
        allowed_extensions = {
            extension
            for profile in tosa_spec.profiles
            for extension in test_valid_1_00_extensions[profile]
        }
        for extension in tosa_spec.extensions:
            self.assertIn(extension, allowed_extensions)

    @parameterized.expand(test_invalid_strings)
    def test_invalid_version_strings(self, version_string: str):
        """Malformed or inconsistent version strings are rejected with ValueError."""
        with self.assertRaises(ValueError):
            TosaSpecification.create_from_string(version_string)

    @parameterized.expand(test_compile_specs)
    def test_create_from_compilespec(self, compile_specs: list[CompileSpec]):
        """A compile-spec list with a "tosa_version" entry yields a specification."""
        tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)
        self.assertIsInstance(tosa_spec, TosaSpecification)

    @parameterized.expand(test_compile_specs_no_version)
    def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]):
        """A compile-spec list without a "tosa_version" entry raises ValueError."""
        with self.assertRaises(ValueError):
            TosaSpecification.create_from_compilespecs(compile_specs)
106