#!/usr/bin/env python3 # Owner(s): ["oncall: r2p"] # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree.abs import abc import unittest.mock as mock from torch.distributed.elastic.metrics.api import ( _get_metric_name, MetricData, MetricHandler, MetricStream, prof, ) from torch.testing._internal.common_utils import run_tests, TestCase def foo_1(): pass class TestMetricsHandler(MetricHandler): def __init__(self) -> None: self.metric_data = {} def emit(self, metric_data: MetricData): self.metric_data[metric_data.name] = metric_data class Parent(abc.ABC): @abc.abstractmethod def func(self): raise NotImplementedError def base_func(self): self.func() class Child(Parent): # need to decorate the implementation not the abstract method! @prof def func(self): pass class MetricsApiTest(TestCase): def foo_2(self): pass @prof def bar(self): pass @prof def throw(self): raise RuntimeError @prof(group="torchelastic") def bar2(self): pass def test_get_metric_name(self): # Note: since pytorch uses main method to launch tests, # the module will be different between fb and oss, this # allows keeping the module name consistent. foo_1.__module__ = "api_test" self.assertEqual("api_test.foo_1", _get_metric_name(foo_1)) self.assertEqual("MetricsApiTest.foo_2", _get_metric_name(self.foo_2)) def test_profile(self): handler = TestMetricsHandler() stream = MetricStream("torchelastic", handler) # patch instead of configure to avoid conflicts when running tests in parallel with mock.patch( "torch.distributed.elastic.metrics.api.getStream", return_value=stream ): self.bar() self.assertEqual(1, handler.metric_data["MetricsApiTest.bar.success"].value) self.assertNotIn("MetricsApiTest.bar.failure", handler.metric_data) self.assertIn("MetricsApiTest.bar.duration.ms", handler.metric_data) with self.assertRaises(RuntimeError): self.throw() self.assertEqual( 1, handler.metric_data["MetricsApiTest.throw.failure"].value ) self.assertNotIn("MetricsApiTest.bar_raise.success", handler.metric_data) self.assertIn("MetricsApiTest.throw.duration.ms", handler.metric_data) self.bar2() self.assertEqual( "torchelastic", handler.metric_data["MetricsApiTest.bar2.success"].group_name, ) def test_inheritance(self): handler = TestMetricsHandler() stream = MetricStream("torchelastic", handler) # patch instead of configure to avoid conflicts when running tests in parallel with mock.patch( "torch.distributed.elastic.metrics.api.getStream", return_value=stream ): c = Child() c.base_func() self.assertEqual(1, handler.metric_data["Child.func.success"].value) self.assertIn("Child.func.duration.ms", handler.metric_data) if __name__ == "__main__": run_tests()