• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: inductor"]
2
3import functools
4import logging
5
6import torch
7from torch._inductor.runtime.benchmarking import benchmarker
8from torch._inductor.test_case import run_tests, TestCase
9from torch._inductor.utils import do_bench_using_profiling
10
11
12log = logging.getLogger(__name__)
13
14
15class TestBench(TestCase):
16    @classmethod
17    def setUpClass(cls):
18        super().setUpClass()
19        x = torch.rand(1024, 10).cuda().half()
20        w = torch.rand(512, 10).cuda().half()
21        cls._bench_fn = functools.partial(torch.nn.functional.linear, x, w)
22
23    def test_benchmarker(self):
24        res = benchmarker.benchmark_gpu(self._bench_fn)
25        log.warning("do_bench result: %s", res)
26        self.assertGreater(res, 0)
27
28    def test_do_bench_using_profiling(self):
29        res = do_bench_using_profiling(self._bench_fn)
30        log.warning("do_bench_using_profiling result: %s", res)
31        self.assertGreater(res, 0)
32
33
34if __name__ == "__main__":
35    run_tests("cuda")
36