# Owner(s): ["oncall: distributed"] from collections import OrderedDict from typing import TYPE_CHECKING import torch import torch.distributed.checkpoint._traverse as _traverse from torch.testing._internal.common_utils import run_tests, TestCase if TYPE_CHECKING: from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE # TODO: add comments for TestTraverse class TestTraverse(TestCase): def test_traverse_shallow(self) -> None: state_dict = { "key0": 1, "key1": [1, 2], "key2": {1: 2, 2: 3}, "key3": torch.tensor([1]), } data = {} def collect_data(path, value): nonlocal data data[path] = value _traverse.traverse_state_dict(state_dict, collect_data) self.assertIn(("key0",), data) self.assertEqual(data[("key0",)], 1) self.assertIn(("key1",), data) self.assertEqual(data[("key1",)], [1, 2]) self.assertIn(("key2", "1"), data) self.assertEqual(data[("key2", "1")], 2) self.assertIn(("key2", "2"), data) self.assertEqual(data[("key2", "2")], 3) self.assertIn(("key3",), data) self.assertEqual(data[("key3",)], torch.tensor([1])) def test_traverse_nested_list(self) -> None: state_dict = { "key1": [ torch.tensor([1]), [33, torch.tensor([2]), [44, 55]], [66, 77], ], } data = {} def collect_data(path, value): nonlocal data data[path] = value _traverse.traverse_state_dict(state_dict, collect_data) self.assertNotIn(("key1"), data) self.assertIn(("key1", 0), data) self.assertEqual(data[("key1", 0)], torch.tensor([1])) self.assertIn(("key1", 1, 0), data) self.assertEqual(data[("key1", 1, 0)], 33) self.assertIn(("key1", 1, 1), data) self.assertEqual(data[("key1", 1, 1)], torch.tensor([2])) self.assertIn(("key1", 1, 2), data) self.assertEqual(data[("key1", 1, 2)], [44, 55]) self.assertNotIn(("key1", 1, 2, 0), data) self.assertIn(("key1", 2), data) self.assertEqual(data[("key1", 2)], [66, 77]) def test_traverse_nested_dict(self) -> None: state_dict = { "key0": {"key1": 99, "key2": torch.tensor([1])}, } data = {} def collect_data(path, value): nonlocal data data[path] = value _traverse.traverse_state_dict(state_dict, collect_data) self.assertNotIn(("key0",), data) self.assertIn(("key0", "key1"), data) self.assertEqual(data[("key0", "key1")], 99) self.assertIn(("key0", "key2"), data) self.assertEqual(data[("key0", "key2")], torch.tensor([1])) def test_traverse_doesnt_ignore_intermediate_collections(self) -> None: state_dict: STATE_DICT_TYPE = {"key0": [{"key1": {"key2": torch.tensor([1])}}]} data = {} def collect_data(path, value): nonlocal data data[path] = value _traverse.traverse_state_dict(state_dict, collect_data) self.assertIn(("key0", 0, "key1", "key2"), data) self.assertEqual( data[("key0", 0, "key1", "key2")], torch.tensor([1]), ) def test_traverse_with_ordered_dict(self) -> None: state_dict = OrderedDict( { "key0": [ 99, torch.tensor([3]), ] } ) data = {} def collect_data(path, value): nonlocal data data[path] = value _traverse.traverse_state_dict(state_dict, collect_data) self.assertIn(("key0", 0), data) self.assertEqual(data[("key0", 0)], 99) self.assertIn(("key0", 1), data) self.assertEqual(data[("key0", 1)], torch.tensor([3])) def test_set_element(self) -> None: state_dict: STATE_DICT_TYPE = {} _traverse.set_element(state_dict, ("k",), 10) self.assertEqual(state_dict["k"], 10) _traverse.set_element(state_dict, ("k1", 2), 1) self.assertEqual(state_dict["k1"], [None, None, 1]) _traverse.set_element(state_dict, ("k1", 1), 99) self.assertEqual(state_dict["k1"], [None, 99, 1]) _traverse.set_element(state_dict, ("k1", 3), 88) self.assertEqual(state_dict["k1"], [None, 99, 1, 88]) _traverse.set_element(state_dict, ("k2", "k3"), 3) self.assertEqual(state_dict["k2"], {"k3": 3}) _traverse.set_element(state_dict, ("k2", "k4", 0, 0), 99) self.assertEqual(state_dict["k2"]["k4"][0], [99]) def test_get_element(self) -> None: state_dict = {"a": [0, 1], "b": [2, {"c": "d"}]} self.assertEqual(_traverse.get_element(state_dict, ("a",)), [0, 1]) self.assertEqual(_traverse.get_element(state_dict, ("b", 0)), 2) self.assertEqual(_traverse.get_element(state_dict, ("b", 1, "c")), "d") self.assertIsNone(_traverse.get_element(state_dict, ("c",))) self.assertIsNone(_traverse.get_element(state_dict, ("a", 33))) self.assertIsNone(_traverse.get_element(state_dict, ("b", 88))) self.assertIsNone(_traverse.get_element(state_dict, ("b", 0, 2))) self.assertIsNone(_traverse.get_element(state_dict, ("b", 1, 2))) self.assertIsNone(_traverse.get_element(state_dict, ("b", 1, "d"))) if __name__ == "__main__": run_tests()