• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import abc
10import unittest.mock as mock
11
12from torch.distributed.elastic.metrics.api import (
13    _get_metric_name,
14    MetricData,
15    MetricHandler,
16    MetricStream,
17    prof,
18)
19from torch.testing._internal.common_utils import run_tests, TestCase
20
21
def foo_1():
    """Module-level no-op used as a fixture for ``_get_metric_name`` tests."""
24
25
class TestMetricsHandler(MetricHandler):
    """In-memory MetricHandler that keeps the latest MetricData per metric name."""

    def __init__(self) -> None:
        # Maps metric name -> most recently emitted MetricData for that name.
        self.metric_data: dict = {}

    def emit(self, metric_data: MetricData):
        """Record ``metric_data`` under its name, overwriting any earlier entry."""
        key = metric_data.name
        self.metric_data[key] = metric_data
32
33
class Parent(abc.ABC):
    """Abstract base whose concrete ``func`` hook is invoked through ``base_func``."""

    @abc.abstractmethod
    def func(self):
        """Hook that concrete subclasses must override."""
        raise NotImplementedError

    def base_func(self):
        """Call ``func`` as resolved on the instance (i.e. the subclass override)."""
        hook = self.func
        hook()
41
42
class Child(Parent):
    """Concrete Parent whose ``func`` override is instrumented with ``@prof``."""

    # Decorate the concrete override, not the abstract method on Parent: the
    # override replaces Parent.func entirely, so only a decorator placed here
    # is active when base_func dispatches to self.func().
    @prof
    def func(self):
        """No-op body; only the ``prof`` instrumentation matters in tests."""
48
49
class MetricsApiTest(TestCase):
    """Tests for metric-name derivation and the ``@prof`` decorator."""

    def foo_2(self):
        """Undecorated method used to check metric names derived from methods."""

    @prof
    def bar(self):
        """Profiled method that returns normally."""

    @prof
    def throw(self):
        """Profiled method that always raises RuntimeError."""
        raise RuntimeError

    @prof(group="torchelastic")
    def bar2(self):
        """Profiled method with an explicitly specified metric group."""

    def test_get_metric_name(self):
        # Note: since pytorch uses main method to launch tests,
        # the module will be different between fb and oss, this
        # allows keeping the module name consistent.
        # Register restoration of the original value first so the mutation of
        # the shared module-level function does not leak into other tests.
        self.addCleanup(setattr, foo_1, "__module__", foo_1.__module__)
        foo_1.__module__ = "api_test"
        self.assertEqual("api_test.foo_1", _get_metric_name(foo_1))
        self.assertEqual("MetricsApiTest.foo_2", _get_metric_name(self.foo_2))

    def test_profile(self):
        handler = TestMetricsHandler()
        stream = MetricStream("torchelastic", handler)
        # patch instead of configure to avoid conflicts when running tests in parallel
        with mock.patch(
            "torch.distributed.elastic.metrics.api.getStream", return_value=stream
        ):
            # Successful call: success + duration recorded, no failure.
            self.bar()

            self.assertEqual(1, handler.metric_data["MetricsApiTest.bar.success"].value)
            self.assertNotIn("MetricsApiTest.bar.failure", handler.metric_data)
            self.assertIn("MetricsApiTest.bar.duration.ms", handler.metric_data)

            # Raising call: the exception propagates; failure + duration are
            # recorded and, crucially, no success metric for this method.
            with self.assertRaises(RuntimeError):
                self.throw()

            self.assertEqual(
                1, handler.metric_data["MetricsApiTest.throw.failure"].value
            )
            self.assertNotIn("MetricsApiTest.throw.success", handler.metric_data)
            self.assertIn("MetricsApiTest.throw.duration.ms", handler.metric_data)

            # An explicit group passed to @prof is carried on the emitted metric.
            self.bar2()
            self.assertEqual(
                "torchelastic",
                handler.metric_data["MetricsApiTest.bar2.success"].group_name,
            )

    def test_inheritance(self):
        handler = TestMetricsHandler()
        stream = MetricStream("torchelastic", handler)
        # patch instead of configure to avoid conflicts when running tests in parallel
        with mock.patch(
            "torch.distributed.elastic.metrics.api.getStream", return_value=stream
        ):
            # Calling through the base-class method must still hit the
            # decorated override, so metrics are named after Child.func.
            c = Child()
            c.base_func()

            self.assertEqual(1, handler.metric_data["Child.func.success"].value)
            self.assertIn("Child.func.duration.ms", handler.metric_data)
115
if __name__ == "__main__":
    # Only when executed as a script (not when imported): hand off to
    # PyTorch's shared ``run_tests`` entry point.
    run_tests()
118