• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: onnx"]
2"""Unit tests on `torch.onnx.symbolic_helper`."""
3
4import torch
5from torch.onnx import symbolic_helper
6from torch.onnx._globals import GLOBALS
7from torch.testing._internal import common_utils
8
9
10class TestHelperFunctions(common_utils.TestCase):
11    def setUp(self):
12        super().setUp()
13        self._initial_training_mode = GLOBALS.training_mode
14
15    def tearDown(self):
16        GLOBALS.training_mode = self._initial_training_mode
17
18    @common_utils.parametrize(
19        "op_train_mode,export_mode",
20        [
21            common_utils.subtest(
22                [1, torch.onnx.TrainingMode.PRESERVE], name="export_mode_is_preserve"
23            ),
24            common_utils.subtest(
25                [0, torch.onnx.TrainingMode.EVAL],
26                name="modes_match_op_train_mode_0_export_mode_eval",
27            ),
28            common_utils.subtest(
29                [1, torch.onnx.TrainingMode.TRAINING],
30                name="modes_match_op_train_mode_1_export_mode_training",
31            ),
32        ],
33    )
34    def test_check_training_mode_does_not_warn_when(
35        self, op_train_mode: int, export_mode: torch.onnx.TrainingMode
36    ):
37        GLOBALS.training_mode = export_mode
38        self.assertNotWarn(
39            lambda: symbolic_helper.check_training_mode(op_train_mode, "testop")
40        )
41
42    @common_utils.parametrize(
43        "op_train_mode,export_mode",
44        [
45            common_utils.subtest(
46                [0, torch.onnx.TrainingMode.TRAINING],
47                name="modes_do_not_match_op_train_mode_0_export_mode_training",
48            ),
49            common_utils.subtest(
50                [1, torch.onnx.TrainingMode.EVAL],
51                name="modes_do_not_match_op_train_mode_1_export_mode_eval",
52            ),
53        ],
54    )
55    def test_check_training_mode_warns_when(
56        self,
57        op_train_mode: int,
58        export_mode: torch.onnx.TrainingMode,
59    ):
60        with self.assertWarnsRegex(
61            UserWarning, f"ONNX export mode is set to {export_mode}"
62        ):
63            GLOBALS.training_mode = export_mode
64            symbolic_helper.check_training_mode(op_train_mode, "testop")
65
66
# Hand the class to the parametrization helper so the `subtest` entries
# declared with `@common_utils.parametrize` above are registered on it.
common_utils.instantiate_parametrized_tests(TestHelperFunctions)


# When executed directly as a script, delegate to the shared test runner.
if __name__ == "__main__":
    common_utils.run_tests()
72