# Owner(s): ["oncall: distributed"]

import dataclasses

import torch
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint.planner import SavePlan, WriteItemType
from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor
from torch.testing._internal.common_utils import run_tests, TestCase


# Builds a SavePlan containing a write item for the duplicated "tensor_0" shard
# and a second write item keyed by `second_fqn`.
def create_plan(second_fqn) -> SavePlan:
    # the first write item is for a duplicated shard (it covers the whole tensor
    # and is the same on every rank)
    write_item_1 = _create_write_item_for_tensor("tensor_0", torch.rand(4))
    write_item_1 = dataclasses.replace(write_item_1, type=WriteItemType.SHARD)

    # the second write item has a different key on each rank, so it is never a duplicate
    write_item_2 = _create_write_item_for_tensor(second_fqn, torch.rand(10))

    return SavePlan([write_item_1, write_item_2])


# Verifies that dedup_tensors keeps only one copy of a write item that is
# duplicated across ranks while leaving rank-unique items untouched.
class TestDedupTensor(TestCase):
    def test_dedup_shards(self):
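        # both plans contain the same "tensor_0" shard plus one rank-specific item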
        rank0 = create_plan("r0")
        rank1 = create_plan("r1")

        dedup_plans = dedup_tensors([rank0, rank1])
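
        # the duplicated "tensor_0" item should survive in only one plan, so rank 0
        # keeps two items while rank 1 keeps just its own "r1" item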
        self.assertEqual(2, len(dedup_plans[0].items))
        self.assertEqual(1, len(dedup_plans[1].items))

        self.assertIn("tensor_0", (item.index.fqn for item in dedup_plans[0].items))
        self.assertIn("r0", (item.index.fqn for item in dedup_plans[0].items))

        self.assertIn("r1", (item.index.fqn for item in dedup_plans[1].items))


if __name__ == "__main__":
    run_tests()