# Owner(s): ["oncall: export"] from collections import OrderedDict import torch from torch._dynamo.test_case import TestCase from torch.export._tree_utils import is_equivalent, reorder_kwargs from torch.testing._internal.common_utils import run_tests from torch.utils._pytree import tree_structure class TestTreeUtils(TestCase): def test_reorder_kwargs(self): original_kwargs = {"a": torch.tensor(0), "b": torch.tensor(1)} user_kwargs = {"b": torch.tensor(2), "a": torch.tensor(3)} orig_spec = tree_structure(((), original_kwargs)) reordered_kwargs = reorder_kwargs(user_kwargs, orig_spec) # Key ordering should be the same self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]), self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]), def test_equivalence_check(self): tree1 = {"a": torch.tensor(0), "b": torch.tensor(1), "c": None} tree2 = OrderedDict(a=torch.tensor(0), b=torch.tensor(1), c=None) spec1 = tree_structure(tree1) spec2 = tree_structure(tree2) def dict_ordered_dict_eq(type1, context1, type2, context2): if type1 is None or type2 is None: return type1 is type2 and context1 == context2 if issubclass(type1, (dict, OrderedDict)) and issubclass( type2, (dict, OrderedDict) ): return context1 == context2 return type1 is type2 and context1 == context2 self.assertTrue(is_equivalent(spec1, spec2, dict_ordered_dict_eq)) # Wrong ordering should still fail tree3 = OrderedDict(b=torch.tensor(1), a=torch.tensor(0)) spec3 = tree_structure(tree3) self.assertFalse(is_equivalent(spec1, spec3, dict_ordered_dict_eq)) if __name__ == "__main__": run_tests()