# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from typing import List

import executorch.devtools.bundled_program.schema as bp_schema

import torch
from executorch.devtools.bundled_program.config import ConfigValue
from executorch.devtools.bundled_program.core import BundledProgram
from executorch.devtools.bundled_program.util.test_util import (
    get_common_executorch_program,
)
from executorch.exir._serialize import _serialize_pte_binary


class TestBundle(unittest.TestCase):
    """Tests for building a ``BundledProgram`` from a program and its test suites."""

    def assertIOsetDataEqual(
        self,
        program_ioset_data: List[bp_schema.Value],
        config_ioset_data: List[ConfigValue],
    ) -> None:
        """Assert that bundled schema values match the config values they came from.

        Args:
            program_ioset_data: Values read back from the serialized bundled
                program schema (inputs or expected outputs of one test case).
            config_ioset_data: The corresponding values from the method test
                case that was handed to ``BundledProgram``.

        Tensors are compared by shape only; ints, doubles and bools are
        compared by value. A schema value of any other kind fails the test
        instead of being silently skipped.
        """
        self.assertEqual(len(program_ioset_data), len(config_ioset_data))
        for program_element, config_element in zip(
            program_ioset_data, config_ioset_data
        ):
            schema_val = program_element.val
            if isinstance(schema_val, bp_schema.Tensor):
                # TODO: Update to check the bundled input share the same type with the config input after supporting multiple types.
                self.assertTrue(isinstance(config_element, torch.Tensor))
                self.assertEqual(schema_val.sizes, list(config_element.size()))
                # TODO(gasoonjia): Check the inner data.
            elif isinstance(schema_val, bp_schema.Int):
                self.assertEqual(schema_val.int_val, config_element)
            elif isinstance(schema_val, bp_schema.Double):
                self.assertEqual(schema_val.double_val, config_element)
            elif isinstance(schema_val, bp_schema.Bool):
                self.assertEqual(schema_val.bool_val, config_element)
            else:
                # Without this branch an unrecognized value kind would make
                # the comparison pass without checking anything.
                self.fail(
                    f"Unsupported bundled value type: {type(schema_val).__name__}"
                )

    def test_bundled_program(self) -> None:
        """The serialized schema mirrors the test suites and embeds the program bytes."""
        executorch_program, method_test_suites = get_common_executorch_program()

        bundled_program = BundledProgram(executorch_program, method_test_suites)

        # The comparison below indexes both suite lists by plan id, so order
        # the config-side suites by method name first.
        method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name)

        # Serialize once and reuse the result for every check below.
        bundled_schema = bundled_program.serialize_to_schema()

        for plan_id in range(len(executorch_program.executorch_program.execution_plan)):
            bundled_plan_test = bundled_schema.method_test_suites[plan_id]
            method_test_suite = method_test_suites[plan_id]

            self.assertEqual(
                len(bundled_plan_test.test_cases), len(method_test_suite.test_cases)
            )
            for bundled_program_ioset, method_test_case in zip(
                bundled_plan_test.test_cases, method_test_suite.test_cases
            ):
                self.assertIOsetDataEqual(
                    bundled_program_ioset.inputs, method_test_case.inputs
                )
                self.assertIOsetDataEqual(
                    bundled_program_ioset.expected_outputs,
                    method_test_case.expected_outputs,
                )

        self.assertEqual(
            bundled_schema.program,
            bytes(_serialize_pte_binary(executorch_program.executorch_program)),
        )

    def test_bundled_miss_methods(self) -> None:
        """Constructing a bundle with suites for only the first method does not raise."""
        executorch_program, method_test_suites = get_common_executorch_program()

        # only keep the testcases for the first method to mimic the case that user only creates testcases for the first method.
        method_test_suites = method_test_suites[:1]

        _ = BundledProgram(executorch_program, method_test_suites)

    def test_bundled_wrong_method_name(self) -> None:
        """A test suite naming a non-existent method is rejected with AssertionError."""
        executorch_program, method_test_suites = get_common_executorch_program()

        method_test_suites[-1].method_name = "wrong_method_name"
        with self.assertRaises(AssertionError):
            BundledProgram(executorch_program, method_test_suites)

    def test_bundle_wrong_input_type(self) -> None:
        """A test case whose inputs have an unsupported type is rejected with AssertionError."""
        executorch_program, method_test_suites = get_common_executorch_program()

        # pyre-ignore[8]: Use a wrong type on purpose. Should raise an error when creating a bundled program using method_test_suites.
        method_test_suites[0].test_cases[-1].inputs = ["WRONG INPUT TYPE"]
        with self.assertRaises(AssertionError):
            BundledProgram(executorch_program, method_test_suites)

    def test_bundle_wrong_output_type(self) -> None:
        """Replacing expected outputs with ``[0, 0.0]`` is rejected with AssertionError."""
        executorch_program, method_test_suites = get_common_executorch_program()

        method_test_suites[0].test_cases[-1].expected_outputs = [
            0,
            0.0,
        ]
        with self.assertRaises(AssertionError):
            BundledProgram(executorch_program, method_test_suites)