1# Owner(s): ["oncall: distributed"] 2 3import dataclasses 4 5import torch 6from torch.distributed.checkpoint._dedup_tensors import dedup_tensors 7from torch.distributed.checkpoint.planner import SavePlan, WriteItemType 8from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor 9from torch.testing._internal.common_utils import run_tests, TestCase 10 11 12# TODO: add comments for create_plan 13def create_plan(second_fqn) -> SavePlan: 14 # the first write item is for a duplicated shard (that covers the whole tensor) 15 write_item_1 = _create_write_item_for_tensor("tensor_0", torch.rand(4)) 16 write_item_1 = dataclasses.replace(write_item_1, type=WriteItemType.SHARD) 17 18 # the second write item has different keys 19 write_item_2 = _create_write_item_for_tensor(second_fqn, torch.rand(10)) 20 21 return SavePlan([write_item_1, write_item_2]) 22 23 24# TODO: add comments for TestDedupTensor 25class TestDedupTensor(TestCase): 26 def test_dedup_shards(self): 27 rank0 = create_plan("r0") 28 rank1 = create_plan("r1") 29 30 dedup_plans = dedup_tensors([rank0, rank1]) 31 32 self.assertEqual(2, len(dedup_plans[0].items)) 33 self.assertEqual(1, len(dedup_plans[1].items)) 34 35 self.assertIn("tensor_0", (item.index.fqn for item in dedup_plans[0].items)) 36 self.assertIn("r0", (item.index.fqn for item in dedup_plans[0].items)) 37 38 self.assertIn("r1", (item.index.fqn for item in dedup_plans[1].items)) 39 40 41if __name__ == "__main__": 42 run_tests() 43