# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from executorch.backends.arm.tosa_specification import (
    Tosa_0_80,
    Tosa_1_00,
    TosaSpecification,
)

from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized

# Well-formed TOSA 0.80 version strings: a single upper-case profile (BI or
# MI), optionally followed by lower-case options such as "8k" or "u55".
test_valid_0_80_strings = [
    "TOSA-0.80.0+BI",
    "TOSA-0.80.0+MI+8k",
    "TOSA-0.80.0+BI+u55",
]

# Well-formed TOSA 1.00 version strings: one or two profiles (INT and/or FP,
# in either order) plus lower-case extensions and/or the "8k" option.
test_valid_1_00_strings = [
    "TOSA-1.00.0+INT+FP+fft",
    "TOSA-1.00.0+FP+bf16+fft",
    "TOSA-1.00.0+INT+int4+cf",
    "TOSA-1.00.0+FP+cf+bf16+8k",
    "TOSA-1.00.0+FP+INT+bf16+fft+int4+cf",
    "TOSA-1.00.0+FP+INT+fft+int4+cf+8k",
]

# Extensions permitted by each TOSA 1.00 profile. Every extension of a parsed
# spec must be permitted by at least one of that spec's profiles.
test_valid_1_00_extensions = {
    "INT": ["int16", "int4", "var", "cf"],
    "FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
}

# Strings that create_from_string must reject with ValueError: wrong letter
# case, missing/duplicated/mixed profiles, or options without a profile.
test_invalid_strings = [
    "TOSA-0.80.0+bi",
    "TOSA-0.80.0",
    "TOSA-0.80.0+8k",
    "TOSA-0.80.0+BI+MI",
    "TOSA-0.80.0+BI+U55",
    "TOSA-1.00.0+fft",
    "TOSA-1.00.0+fp+bf16+fft",
    "TOSA-1.00.0+INT+INT4+cf",
    "TOSA-1.00.0+BI",
    "TOSA-1.00.0+FP+FP+INT",
    "TOSA-1.00.0+FP+CF+bf16",
    "TOSA-1.00.0+BF16+fft+int4+cf+INT",
]

# Compile specs carrying a "tosa_version" entry whose value is the encoded
# version string; each must yield a TosaSpecification.
test_compile_specs = [
    ([CompileSpec("tosa_version", "TOSA-0.80.0+BI".encode())],),
    ([CompileSpec("tosa_version", "TOSA-0.80.0+BI+u55".encode())],),
    ([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],),
]

# Compile specs with no "tosa_version" key; create_from_compilespecs must
# raise ValueError for these.
test_compile_specs_no_version = [
    ([CompileSpec("other_key", "TOSA-0.80.0+BI".encode())],),
    ([CompileSpec("other_key", "some_value".encode())],),
]


class TestTosaSpecification(unittest.TestCase):
    """Tests parsing of TOSA version strings and compile specs into
    TosaSpecification objects."""

    # unittest assert methods are used instead of bare `assert` so the checks
    # still run under `python -O` and report descriptive failures.

    @parameterized.expand(test_valid_0_80_strings)
    def test_version_string_0_80(self, version_string: str):
        """A valid 0.80 string parses to a Tosa_0_80 with a BI or MI profile.

        Args:
            version_string: One entry of ``test_valid_0_80_strings``.
        """
        tosa_spec = TosaSpecification.create_from_string(version_string)
        self.assertIsInstance(tosa_spec, Tosa_0_80)
        self.assertIn(tosa_spec.profile, ["BI", "MI"])

    @parameterized.expand(test_valid_1_00_strings)
    def test_version_string_1_00(self, version_string: str):
        """A valid 1.00 string parses to a Tosa_1_00 whose profiles and
        extensions are all legal.

        Args:
            version_string: One entry of ``test_valid_1_00_strings``.
        """
        tosa_spec = TosaSpecification.create_from_string(version_string)
        self.assertIsInstance(tosa_spec, Tosa_1_00)

        # There must be at least one profile, and every profile must be known.
        self.assertTrue(tosa_spec.profiles, f"No profile parsed from {version_string!r}")
        for profile in tosa_spec.profiles:
            self.assertIn(profile, test_valid_1_00_extensions)

        # Each extension needs to be allowed by at least ONE of the spec's
        # profiles (e.g. "fft" is FP-only yet valid in an INT+FP spec). Checked
        # per extension: asserting a list comprehension would always pass for
        # a non-empty list.
        for extension in tosa_spec.extensions:
            self.assertTrue(
                any(
                    extension in test_valid_1_00_extensions[profile]
                    for profile in tosa_spec.profiles
                ),
                f"Extension {extension!r} not valid for profiles "
                f"{tosa_spec.profiles} in {version_string!r}",
            )

    @parameterized.expand(test_invalid_strings)
    def test_invalid_version_strings(self, version_string: str):
        """Malformed version strings are rejected with ValueError.

        Args:
            version_string: One entry of ``test_invalid_strings``.
        """
        with self.assertRaises(ValueError):
            TosaSpecification.create_from_string(version_string)

    @parameterized.expand(test_compile_specs)
    def test_create_from_compilespec(self, compile_specs: list[CompileSpec]):
        """Compile specs with a "tosa_version" entry yield a TosaSpecification.

        Args:
            compile_specs: One entry of ``test_compile_specs``.
        """
        tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)
        self.assertIsInstance(tosa_spec, TosaSpecification)

    @parameterized.expand(test_compile_specs_no_version)
    def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]):
        """Compile specs lacking "tosa_version" are rejected with ValueError.

        Args:
            compile_specs: One entry of ``test_compile_specs_no_version``.
        """
        with self.assertRaises(ValueError):
            TosaSpecification.create_from_compilespecs(compile_specs)