# Owner(s): ["oncall: distributed"] # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import unittest import torch import torch.nn as nn from torch.distributed.optim import _NamedOptimizer def _run_model_training(model_optim_lists): for _ in range(2): x = torch.rand(5, 8) for model_optim_list in model_optim_lists: model = model_optim_list[0] optim_list = model_optim_list[1] y = model(x) y.sum().backward() for optim in optim_list: optim.step() class TestDummyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() torch.manual_seed(0) self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU()) self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU()) self.net3 = nn.Linear(32, 64) self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8)) def forward(self, x): return self.net4(self.net3(self.net2(self.net1(x)))) class NamedOptimizerTest(unittest.TestCase): def _compare_state_dict_group(self, group, named_group, assert_equal=True): for key, val in group.items(): if key != "params": self.assertTrue( key in named_group, f"{key} not in named optimizer state dict" ) err_msg = ( f"{key} state not equal" if assert_equal else f"{key} state equal" ) if isinstance(val, torch.Tensor): fn = self.assertTrue if assert_equal else self.assertFalse fn(torch.allclose(val, named_group[key]), err_msg) else: fn = self.assertEqual if assert_equal else self.assertNotEqual fn(val, named_group[key], err_msg) def _compare_param_groups(self, param_groups_1, param_groups_2): self.assertTrue(isinstance(param_groups_1, list)) self.assertTrue(isinstance(param_groups_2, list)) for groups in zip(param_groups_1, param_groups_2): self._compare_param_group(groups[0], groups[1]) def _compare_param_group(self, group_1, group_2): self.assertTrue(isinstance(group_1, dict)) self.assertTrue(isinstance(group_2, dict)) for key, val in group_1.items(): self.assertTrue(key in group_2) if key != "params": self.assertEqual(val, group_2[key]) else: for tensors in zip(val, group_2[key]): self.assertTrue(torch.allclose(tensors[0], tensors[1])) def test_state_dict(self): """Check that NamedOptimizer exposes the expected state dict interface.""" m = TestDummyModel() m_dup = TestDummyModel() optim = torch.optim.SGD( m.parameters(), lr=1e-2, momentum=0.9, ) named_optim = _NamedOptimizer( m_dup.named_parameters(), torch.optim.SGD, lr=1e-2, momentum=0.9, ) self._compare_param_groups(optim.param_groups, named_optim.param_groups) _run_model_training([(m, [optim]), (m_dup, [named_optim])]) self._compare_param_groups(optim.param_groups, named_optim.param_groups) sd = optim.state_dict() named_sd = named_optim.state_dict() # Compare "state" in optim state dict self._compare_state_dict_group( sd["state"][0], named_sd["state"]["net1.0.weight"], assert_equal=True, ) self._compare_state_dict_group( sd["state"][3], named_sd["state"]["net2.0.bias"], assert_equal=True, ) self._compare_state_dict_group( sd["state"][4], named_sd["state"]["net3.weight"], assert_equal=True, ) self._compare_state_dict_group( sd["state"][7], named_sd["state"]["net4.1.bias"], assert_equal=True, ) def test_state_dict_multi_param_group(self): """Check that NamedOptimizer exposes the expected state dict interface when multiple param groups are specified.""" m = TestDummyModel() m_dup = TestDummyModel() optim_1 = torch.optim.SGD( [ {"params": m.net1.parameters()}, {"params": m.net3.parameters(), "lr": 1e-3}, ], lr=1e-2, momentum=0.9, ) optim_2 = torch.optim.Adam( [ {"params": m.net2.parameters()}, {"params": m.net4.parameters(), "lr": 1e-5}, ] ) named_optim_1 = _NamedOptimizer( m_dup.named_parameters(), torch.optim.SGD, [ {"params": m_dup.net1.parameters()}, {"params": m_dup.net3.parameters(), "lr": 1e-3}, ], lr=1e-2, momentum=0.9, ) named_optim_2 = _NamedOptimizer( m_dup.named_parameters(), torch.optim.Adam, [ {"params": m_dup.net2.parameters()}, {"params": m_dup.net4.parameters(), "lr": 1e-5}, ], ) self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups) self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups) _run_model_training( [(m, [optim_1, optim_2]), (m_dup, [named_optim_1, named_optim_2])] ) self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups) self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups) sd_1 = optim_1.state_dict() sd_2 = optim_2.state_dict() named_sd_1 = named_optim_1.state_dict() named_sd_2 = named_optim_2.state_dict() # Compare "state" in optim state dict self._compare_state_dict_group( sd_1["state"][0], named_sd_1["state"]["net1.0.weight"], assert_equal=True, ) self._compare_state_dict_group( sd_2["state"][1], named_sd_2["state"]["net2.0.bias"], assert_equal=True, ) self._compare_state_dict_group( sd_1["state"][2], named_sd_1["state"]["net3.weight"], assert_equal=True, ) self._compare_state_dict_group( sd_2["state"][3], named_sd_2["state"]["net4.1.bias"], assert_equal=True, ) # Compare "param_groups" in optim state dict self._compare_state_dict_group( sd_1["param_groups"][0], named_sd_1["param_groups"][0], assert_equal=True, ) self._compare_state_dict_group( sd_2["param_groups"][1], named_sd_2["param_groups"][1], assert_equal=True ) def test_load_state_dict(self): """Check that NamedOptimizer's load_state_dict works as expected.""" m = TestDummyModel() named_optim_1 = _NamedOptimizer( m.named_parameters(), torch.optim.SGD, lr=1e-2, momentum=0.9, ) _run_model_training([(m, [named_optim_1])]) state_dict_to_load = named_optim_1.state_dict() named_optim_2 = _NamedOptimizer( m.named_parameters(), torch.optim.SGD, lr=1e-2, momentum=0.6, ) _run_model_training([(m, [named_optim_2])]) state_dict_before_load = named_optim_2.state_dict() # Compare "state" in optim state dict self._compare_state_dict_group( state_dict_to_load["state"]["net1.0.weight"], state_dict_before_load["state"]["net1.0.weight"], assert_equal=False, ) self._compare_state_dict_group( state_dict_to_load["state"]["net2.0.bias"], state_dict_before_load["state"]["net2.0.bias"], assert_equal=False, ) self._compare_state_dict_group( state_dict_to_load["state"]["net3.weight"], state_dict_before_load["state"]["net3.weight"], assert_equal=False, ) self._compare_state_dict_group( state_dict_to_load["state"]["net4.1.bias"], state_dict_before_load["state"]["net4.1.bias"], assert_equal=False, ) named_optim_2.load_state_dict(state_dict_to_load) state_dict_after_load = named_optim_2.state_dict() # Compare "state" in optim state dict self._compare_state_dict_group( state_dict_to_load["state"]["net1.0.weight"], state_dict_after_load["state"]["net1.0.weight"], assert_equal=True, ) self._compare_state_dict_group( state_dict_to_load["state"]["net2.0.bias"], state_dict_after_load["state"]["net2.0.bias"], assert_equal=True, ) self._compare_state_dict_group( state_dict_to_load["state"]["net3.weight"], state_dict_after_load["state"]["net3.weight"], assert_equal=True, ) self._compare_state_dict_group( state_dict_to_load["state"]["net4.1.bias"], state_dict_after_load["state"]["net4.1.bias"], assert_equal=True, ) def test_load_state_dict_conditional_training(self): """Check that NamedOptimizer load_state_dict works under conditional training case.""" m = TestDummyModel() named_optim_1 = _NamedOptimizer( m.named_parameters(), torch.optim.SGD, [ {"params": m.net1.parameters()}, {"params": m.net3.parameters(), "lr": 1e-3}, ], lr=1e-2, momentum=0.9, ) _run_model_training([(m, [named_optim_1])]) state_dict_to_load = named_optim_1.state_dict() named_optim_2 = _NamedOptimizer( m.named_parameters(), torch.optim.SGD, lr=1e-2, momentum=0.6, ) _run_model_training([(m, [named_optim_2])]) named_optim_2.load_state_dict(state_dict_to_load) state_dict_after_load = named_optim_2.state_dict() # Compare "state" in optim state dict self._compare_state_dict_group( state_dict_to_load["state"]["net1.0.weight"], state_dict_after_load["state"]["net1.0.weight"], assert_equal=True, ) self._compare_state_dict_group( state_dict_to_load["state"]["net3.weight"], state_dict_after_load["state"]["net3.weight"], assert_equal=True, ) def test_load_state_dict_error(self): m = TestDummyModel() named_optim_1 = _NamedOptimizer( m.named_parameters(), torch.optim.SGD, lr=1e-2, momentum=0.9, ) _run_model_training([(m, [named_optim_1])]) state_dict_to_load = named_optim_1.state_dict() named_optim_2 = _NamedOptimizer( m.named_parameters(), torch.optim.SGD, lr=1e-2, momentum=0.6, ) err_msg = ( "Expects the optim to be initialized before load but found not initialized" ) with self.assertRaisesRegex(ValueError, err_msg): named_optim_2.load_state_dict(state_dict_to_load) def test_add_param_group(self): m = TestDummyModel() m_dup = TestDummyModel() optim = torch.optim.SGD( [ {"params": m.net1.parameters()}, {"params": m.net3.parameters(), "lr": 1e-3}, ], lr=1e-2, momentum=0.9, ) named_optim = _NamedOptimizer( m_dup.named_parameters(), torch.optim.SGD, [ {"params": m_dup.net1.parameters()}, {"params": m_dup.net3.parameters(), "lr": 1e-3}, ], lr=1e-2, momentum=0.9, ) _run_model_training([(m, [optim]), (m_dup, [named_optim])]) self._compare_param_groups(optim.param_groups, named_optim.param_groups) optim.add_param_group({"params": m.net2.parameters(), "lr": 1e-5}) named_optim.add_param_group({"params": m_dup.net2.parameters(), "lr": 1e-5}) _run_model_training([(m, [optim]), (m_dup, [named_optim])]) self._compare_param_groups(optim.param_groups, named_optim.param_groups) optim.add_param_group({"params": m.net4[1].weight, "lr": 1e-3}) named_optim.add_param_group({"params": m_dup.net4[1].weight, "lr": 1e-3}) _run_model_training([(m, [optim]), (m_dup, [named_optim])]) self._compare_param_groups(optim.param_groups, named_optim.param_groups) def test_add_param_group_error(self): m = TestDummyModel() named_optim = _NamedOptimizer( m.named_parameters(), torch.optim.SGD, [ {"params": m.net1.parameters()}, {"params": m.net3.parameters(), "lr": 1e-3}, ], lr=1e-2, momentum=0.9, ) err_msg = "some parameters are not in the module" with self.assertRaisesRegex(ValueError, err_msg): named_optim.add_param_group({"params": [torch.ones(8, 1)], "lr": 1e-5}) def test_init_state(self): m = TestDummyModel() named_optim = _NamedOptimizer( m.named_parameters(), torch.optim.SGD, [ {"params": m.net1.parameters()}, {"params": m.net3.parameters(), "lr": 1e-3}, ], lr=1e-2, momentum=0.9, ) named_sd = named_optim.state_dict() self.assertTrue(m.net1[0].weight.grad is None) self.assertTrue(len(named_sd["state"]) == 0) named_optim.init_state() named_sd = named_optim.state_dict() self.assertTrue(m.net1[0].weight.grad is not None) self.assertTrue("momentum_buffer" in named_sd["state"]["net1.0.weight"]) self.assertFalse( torch.all(named_sd["state"]["net1.0.weight"]["momentum_buffer"]).item() ) self.assertFalse( torch.all(named_sd["state"]["net1.0.bias"]["momentum_buffer"]).item() ) self.assertTrue(m.net3.bias.grad is not None) self.assertTrue("momentum_buffer" in named_sd["state"]["net3.bias"]) self.assertFalse( torch.all(named_sd["state"]["net3.bias"]["momentum_buffer"]).item() ) self.assertFalse( torch.all(named_sd["state"]["net3.weight"]["momentum_buffer"]).item() )