# Owner(s): ["module: autograd"]

import logging

import torch
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


class TestAutogradLogging(LoggingTestCase):
    @make_logging_test(autograd=logging.DEBUG)
    def test_logging(self, records):
        # Build a small graph; backward() on (b, c) executes one autograd
        # node per captured log record.
        a = torch.rand(10, requires_grad=True)
        b = a.mul(2).div(3).sum()
        c = b.clone()
        torch.autograd.backward((b, c))

        # One DEBUG record per executed node, in execution order: the clone
        # feeds b's gradient first, then the chain sum -> div -> mul, and
        # finally the leaf accumulation into a.grad.
        self.assertEqual(len(records), 5)
        expected = [
            "CloneBackward0",
            "SumBackward0",
            "DivBackward0",
            "MulBackward0",
            "AccumulateGrad",
        ]

        for i, record in enumerate(records):
            self.assertIn(expected[i], record.getMessage())

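
# Illustrative sketch, not part of the test above: assuming the autograd
# engine emits these records through the standard "torch.autograd" Python
# logger (the log that make_logging_test(autograd=...) captures), the same
# messages can be surfaced in a plain script with stdlib logging alone.
def _demo_standalone_autograd_logging():
    # Route DEBUG output from the assumed "torch.autograd" logger to stderr.
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger("torch.autograd").setLevel(logging.DEBUG)

    x = torch.rand(3, requires_grad=True)
    # Expected to execute SumBackward0 and then AccumulateGrad, each
    # surfacing as one DEBUG record as the engine runs the node.
    x.sum().backward()
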

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()