# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import sys
import unittest

import torch
from executorch.exir.program._fake_program import (
    get_fake_program,
    update_to_real_program,
)
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import export, ExportedProgram


def get_exported_program() -> ExportedProgram:
    """Export and decompose a small module that has parameters and a buffer.

    The ``nn.Linear`` contributes weight/bias entries to the state dict. The
    extra buffer is registered with ``persistent=False``; NOTE(review): this
    relies on torch.export tracking non-persistent buffers in
    ``ExportedProgram.constants`` rather than ``state_dict``, which gives the
    constants-sharing check in the test something to look at.
    """

    class Linear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)
            self.register_buffer("buf", torch.randn(10, 10), persistent=False)

        def forward(self, arg: torch.Tensor) -> torch.Tensor:
            return self.linear(arg) + self.buf

    linear = Linear()
    exported_program = export(
        linear,
        args=(torch.randn(10, 10),),
    ).run_decompositions()
    return exported_program


class TestFakeProgram(unittest.TestCase):
    """Checks what ``get_fake_program`` copies, shares, or fakes, and that
    ``update_to_real_program`` restores the real state dict."""

    def test_fake_program(self) -> None:
        exported_program = get_exported_program()
        fake_program = get_fake_program(exported_program)

        # Fake program deep copies attributes besides verifier, state_dict and constants.
        self.assertEqual(exported_program.graph_signature, fake_program.graph_signature)
        self.assertNotEqual(
            id(exported_program.graph_signature), id(fake_program.graph_signature)
        )
        self.assertEqual(
            exported_program.module_call_graph, fake_program.module_call_graph
        )
        self.assertNotEqual(
            id(exported_program.module_call_graph), id(fake_program.module_call_graph)
        )

        # Verifier is static.
        self.assertEqual(exported_program.verifier, fake_program.verifier)
        self.assertEqual(id(exported_program.verifier), id(fake_program.verifier))

        # Fake program uses fake tensors for the state dict: same keys, each value
        # a FakeTensor with the real tensor's dtype (no real data is held).
        self.assertEqual(
            exported_program.state_dict.keys(), fake_program.state_dict.keys()
        )
        for name, real_tensor in exported_program.state_dict.items():
            fake_tensor = fake_program.state_dict[name]
            self.assertIsInstance(fake_tensor, FakeTensor)
            self.assertEqual(fake_tensor.dtype, real_tensor.dtype)

        # sys.getsizeof is shallow: it measures only the dict container, not the
        # tensors it references, so this is just a sanity bound on the container.
        # The FakeTensor checks above are what actually verify the fake-ification.
        self.assertLessEqual(
            sys.getsizeof(fake_program.state_dict),
            sys.getsizeof(exported_program.state_dict),
        )

        # Do not copy constants.
        self.assertEqual(exported_program.constants, fake_program.constants)
        self.assertEqual(id(exported_program.constants), id(fake_program.constants))

        update_to_real_program(fake_program, exported_program)
        self.assertEqual(exported_program.state_dict, fake_program.state_dict)
        self.assertEqual(
            exported_program.state_dict.keys(), fake_program.state_dict.keys()
        )
        # After the update the fake program must reference the very same real
        # tensor objects, not copies (dict equality above only holds via identity,
        # since multi-element tensor == cannot be reduced to a bool).
        for name, real_tensor in exported_program.state_dict.items():
            self.assertIs(fake_program.state_dict[name], real_tensor)