• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
8
9import unittest
10from typing import List
11
12import executorch.devtools.bundled_program.schema as bp_schema
13
14import torch
15from executorch.devtools.bundled_program.config import ConfigValue
16from executorch.devtools.bundled_program.core import BundledProgram
17from executorch.devtools.bundled_program.util.test_util import (
18    get_common_executorch_program,
19)
20from executorch.exir._serialize import _serialize_pte_binary
21
22
class TestBundle(unittest.TestCase):
    """Tests for building a ``BundledProgram`` from an ExecuTorch program plus
    its method test suites, and for the errors raised on invalid test suites.

    NOTE: the negative tests rely on ``BundledProgram`` signalling invalid
    input by raising ``AssertionError``.
    """

    def assertIOsetDataEqual(
        self,
        program_ioset_data: List[bp_schema.Value],
        config_ioset_data: List[ConfigValue],
    ) -> None:
        """Assert that serialized bundled values match the user-provided values.

        Args:
            program_ioset_data: Values as stored in the serialized bundled
                program schema. Handled kinds are ``Tensor``, ``Int``,
                ``Double`` and ``Bool``; any other kind fails the test.
            config_ioset_data: The values the test case was configured with,
                in the same order.
        """
        self.assertEqual(len(program_ioset_data), len(config_ioset_data))
        for program_element, config_element in zip(
            program_ioset_data, config_ioset_data
        ):
            value = program_element.val
            if isinstance(value, bp_schema.Tensor):
                # TODO: Update to check the bundled input share the same type
                # with the config input after supporting multiple types.
                self.assertIsInstance(config_element, torch.Tensor)
                self.assertEqual(value.sizes, list(config_element.size()))
                # TODO(gasoonjia): Check the inner data.
            elif isinstance(value, bp_schema.Int):
                self.assertEqual(value.int_val, config_element)
            elif isinstance(value, bp_schema.Double):
                self.assertEqual(value.double_val, config_element)
            elif isinstance(value, bp_schema.Bool):
                self.assertEqual(value.bool_val, config_element)
            else:
                # Fail loudly rather than letting an unrecognised value kind
                # pass the comparison without being checked at all.
                self.fail(
                    f"Unsupported bundled value type: {type(value).__name__}"
                )

    def test_bundled_program(self) -> None:
        """A bundled program mirrors its test suites and embeds the program."""
        executorch_program, method_test_suites = get_common_executorch_program()

        bundled_program = BundledProgram(executorch_program, method_test_suites)
        # Serialize once; the same schema is used for every check below.
        bundled_schema = bundled_program.serialize_to_schema()

        # Suites are compared pairwise with the schema's suites, which this test
        # expects to be ordered by method name, so sort the reference the same way.
        method_test_suites = sorted(method_test_suites, key=lambda t: t.method_name)

        # Pairing is one-to-one with the execution plans: check the counts up
        # front instead of relying on an IndexError or a silently short zip.
        num_plans = len(executorch_program.executorch_program.execution_plan)
        self.assertEqual(len(bundled_schema.method_test_suites), num_plans)
        self.assertEqual(len(method_test_suites), num_plans)

        for bundled_plan_test, method_test_suite in zip(
            bundled_schema.method_test_suites, method_test_suites
        ):
            # Guards against the two orderings drifting apart.
            self.assertEqual(
                bundled_plan_test.method_name, method_test_suite.method_name
            )
            self.assertEqual(
                len(bundled_plan_test.test_cases), len(method_test_suite.test_cases)
            )
            for bundled_program_ioset, method_test_case in zip(
                bundled_plan_test.test_cases, method_test_suite.test_cases
            ):
                self.assertIOsetDataEqual(
                    bundled_program_ioset.inputs, method_test_case.inputs
                )
                self.assertIOsetDataEqual(
                    bundled_program_ioset.expected_outputs,
                    method_test_case.expected_outputs,
                )

        self.assertEqual(
            bundled_schema.program,
            bytes(_serialize_pte_binary(executorch_program.executorch_program)),
        )

    def test_bundled_miss_methods(self) -> None:
        """Test suites need not cover every method of the program."""
        executorch_program, method_test_suites = get_common_executorch_program()

        # Only keep the test cases for the first method, mimicking a user who
        # only creates test cases for that method.
        method_test_suites = method_test_suites[:1]

        _ = BundledProgram(executorch_program, method_test_suites)

    def test_bundled_wrong_method_name(self) -> None:
        """A suite naming a method the program does not have is rejected."""
        executorch_program, method_test_suites = get_common_executorch_program()

        method_test_suites[-1].method_name = "wrong_method_name"
        with self.assertRaises(AssertionError):
            BundledProgram(executorch_program, method_test_suites)

    def test_bundle_wrong_input_type(self) -> None:
        """An input of an unsupported type is rejected."""
        executorch_program, method_test_suites = get_common_executorch_program()

        # pyre-ignore[8]: Use a wrong type on purpose. Should raise an error when creating a bundled program using method_test_suites.
        method_test_suites[0].test_cases[-1].inputs = ["WRONG INPUT TYPE"]
        with self.assertRaises(AssertionError):
            BundledProgram(executorch_program, method_test_suites)

    def test_bundle_wrong_output_type(self) -> None:
        """Expected outputs that do not fit the method are rejected."""
        executorch_program, method_test_suites = get_common_executorch_program()

        # Replace the expected outputs with plain scalars, which BundledProgram
        # is expected to reject for this method (NOTE(review): confirm whether
        # the check fires on the value types or on the output count).
        method_test_suites[0].test_cases[-1].expected_outputs = [
            0,
            0.0,
        ]
        with self.assertRaises(AssertionError):
            BundledProgram(executorch_program, method_test_suites)
121