1# Owner(s): ["module: onnx"] 2"""Unit tests on `torch.onnx.symbolic_helper`.""" 3 4import torch 5from torch.onnx import symbolic_helper 6from torch.onnx._globals import GLOBALS 7from torch.testing._internal import common_utils 8 9 10class TestHelperFunctions(common_utils.TestCase): 11 def setUp(self): 12 super().setUp() 13 self._initial_training_mode = GLOBALS.training_mode 14 15 def tearDown(self): 16 GLOBALS.training_mode = self._initial_training_mode 17 18 @common_utils.parametrize( 19 "op_train_mode,export_mode", 20 [ 21 common_utils.subtest( 22 [1, torch.onnx.TrainingMode.PRESERVE], name="export_mode_is_preserve" 23 ), 24 common_utils.subtest( 25 [0, torch.onnx.TrainingMode.EVAL], 26 name="modes_match_op_train_mode_0_export_mode_eval", 27 ), 28 common_utils.subtest( 29 [1, torch.onnx.TrainingMode.TRAINING], 30 name="modes_match_op_train_mode_1_export_mode_training", 31 ), 32 ], 33 ) 34 def test_check_training_mode_does_not_warn_when( 35 self, op_train_mode: int, export_mode: torch.onnx.TrainingMode 36 ): 37 GLOBALS.training_mode = export_mode 38 self.assertNotWarn( 39 lambda: symbolic_helper.check_training_mode(op_train_mode, "testop") 40 ) 41 42 @common_utils.parametrize( 43 "op_train_mode,export_mode", 44 [ 45 common_utils.subtest( 46 [0, torch.onnx.TrainingMode.TRAINING], 47 name="modes_do_not_match_op_train_mode_0_export_mode_training", 48 ), 49 common_utils.subtest( 50 [1, torch.onnx.TrainingMode.EVAL], 51 name="modes_do_not_match_op_train_mode_1_export_mode_eval", 52 ), 53 ], 54 ) 55 def test_check_training_mode_warns_when( 56 self, 57 op_train_mode: int, 58 export_mode: torch.onnx.TrainingMode, 59 ): 60 with self.assertWarnsRegex( 61 UserWarning, f"ONNX export mode is set to {export_mode}" 62 ): 63 GLOBALS.training_mode = export_mode 64 symbolic_helper.check_training_mode(op_train_mode, "testop") 65 66 67common_utils.instantiate_parametrized_tests(TestHelperFunctions) 68 69 70if __name__ == "__main__": 71 common_utils.run_tests() 72