# Owner(s): ["oncall: package/deploy"]

import pickle
from io import BytesIO
from pathlib import Path
from textwrap import dedent

from torch.package import PackageExporter, PackageImporter, sys_importer
from torch.testing._internal.common_utils import run_tests

try:
    from .common import PackageTestCase
except ImportError:
    # Fall back to a plain import when this file is executed directly
    # instead of being imported as part of the test package.
    from common import PackageTestCase


# Directory holding this file; the fixture modules referenced by path
# (module_a.py, package_a) are resolved relative to it.
packaging_directory = Path(__file__).parent


class TestSaveLoad(PackageTestCase):
    """Core save_* and loading API tests."""

    def test_saving_source(self):
        """Save one source file and one source directory, then import both."""
        archive = BytesIO()
        module_path = str(packaging_directory / "module_a.py")
        package_path = str(packaging_directory / "package_a")
        with PackageExporter(archive) as exporter:
            exporter.save_source_file("foo", module_path)
            exporter.save_source_file("foodir", package_path)

        archive.seek(0)
        importer = PackageImporter(archive)
        foo = importer.import_module("foo")
        subpackage = importer.import_module("foodir.subpackage")
        self.assertEqual(foo.result, "module_a")
        self.assertEqual(subpackage.result, "package_a.subpackage")

    def test_saving_string(self):
        """Save a module from a source string and check `math` is shared.

        Both `math` imported through the package importer and the `math`
        attribute of the packaged module must be the interpreter's own
        `math` module object.
        """
        source = dedent(
            """\
            import math
            the_math = math
            """
        )
        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            exporter.save_source_string("my_mod", source)

        archive.seek(0)
        importer = PackageImporter(archive)
        packaged_math = importer.import_module("math")
        import math

        self.assertIs(packaged_math, math)
        my_mod = importer.import_module("my_mod")
        self.assertIs(my_mod.math, math)

    def test_save_module(self):
        """Round-trip `module_a` and `package_a` through `save_module`.

        The re-imported modules must expose the expected `result` values and
        must be different objects from the ones imported normally.
        """
        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            import module_a
            import package_a

            exporter.save_module(module_a.__name__)
            exporter.save_module(package_a.__name__)

        archive.seek(0)
        importer = PackageImporter(archive)

        packaged_module_a = importer.import_module("module_a")
        self.assertEqual(packaged_module_a.result, "module_a")
        self.assertIsNot(module_a, packaged_module_a)

        packaged_package_a = importer.import_module("package_a")
        self.assertEqual(packaged_package_a.result, "package_a")
        self.assertIsNot(packaged_package_a, package_a)

    def test_dunder_imports(self):
        """Pickle `package_b.PackageBObject` and check what became importable.

        After interning everything, the resulting package must be able to
        import `package_b`, several of its subpackages, `math`, and
        `xml.sax.xmlreader`.
        NOTE(review): presumably `package_b` reaches these through
        `__import__` calls, as the test name suggests -- the fixture is not
        visible from this file.
        """
        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            import package_b

            pickled_class = package_b.PackageBObject
            exporter.intern("**")
            exporter.save_pickle("res", "obj.pkl", pickled_class)

        archive.seek(0)
        importer = PackageImporter(archive)
        # The result is not inspected; loading is only required to succeed.
        loaded_obj = importer.load_pickle("res", "obj.pkl")

        packaged_b = importer.import_module("package_b")
        self.assertEqual(packaged_b.result, "package_b")

        packaged_math = importer.import_module("math")
        self.assertEqual(packaged_math.__name__, "math")

        xmlreader = importer.import_module("xml.sax.xmlreader")
        self.assertEqual(xmlreader.__name__, "xml.sax.xmlreader")

        sub_1 = importer.import_module("package_b.subpackage_1")
        self.assertEqual(sub_1.result, "subpackage_1")

        sub_2 = importer.import_module("package_b.subpackage_2")
        self.assertEqual(sub_2.result, "subpackage_2")

        sub_sub_0 = importer.import_module("package_b.subpackage_0.subsubpackage_0")
        self.assertEqual(sub_sub_0.result, "subsubpackage_0")

    def test_bad_dunder_imports(self):
        """Saving source with an unresolvable `__import__` call must not raise."""
        buffer = BytesIO()
        unresolvable_source = (
            '__import__(these, unresolvable, "things", wont, crash, me)'
        )
        with PackageExporter(buffer) as exporter:
            exporter.save_source_string("m", unresolvable_source)

    def test_save_module_binary(self):
        """Round-trip `module_a` and `package_a` through an in-memory binary buffer.

        Performs the same checks as `test_save_module`.
        """
        binary_file = BytesIO()
        with PackageExporter(binary_file) as exporter:
            import module_a
            import package_a

            exporter.save_module(module_a.__name__)
            exporter.save_module(package_a.__name__)

        binary_file.seek(0)
        importer = PackageImporter(binary_file)

        packaged_module_a = importer.import_module("module_a")
        self.assertEqual(packaged_module_a.result, "module_a")
        self.assertIsNot(module_a, packaged_module_a)

        packaged_package_a = importer.import_module("package_a")
        self.assertEqual(packaged_package_a.result, "package_a")
        self.assertIsNot(packaged_package_a, package_a)

    def test_pickle(self):
        """Pickle a nested object and verify only its dependencies are packaged."""
        import package_a.subpackage

        inner = package_a.subpackage.PackageASubpackageObject()
        outer = package_a.PackageAObject(inner)

        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            exporter.intern("**")
            exporter.save_pickle("obj", "obj.pkl", outer)

        archive.seek(0)
        importer = PackageImporter(archive)

        # The module defining the inner object's class must have been packaged.
        packaged_subpackage = importer.import_module("package_a.subpackage")
        # An unrelated module must not have been pulled in.
        with self.assertRaises(ImportError):
            importer.import_module("module_a")

        restored = importer.load_pickle("obj", "obj.pkl")
        self.assertIsNot(outer, restored)
        self.assertIsInstance(
            restored.obj, packaged_subpackage.PackageASubpackageObject
        )
        # The packaged class is a distinct object from the normally imported one.
        self.assertIsNot(
            package_a.subpackage.PackageASubpackageObject,
            packaged_subpackage.PackageASubpackageObject,
        )

    def test_pickle_long_name_with_protocol_4(self):
        """Pickle a long-named function with protocol 4 and call it after loading."""
        import package_a.long_name

        functions = []

        # Let the fixture append its function to the list so that the
        # 256-character name does not have to be spelled out in this test.
        package_a.long_name.add_function(functions)

        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            exporter.intern("**")
            exporter.save_pickle(
                "container", "container.pkl", functions, pickle_protocol=4
            )

        archive.seek(0)
        importer = PackageImporter(archive)
        restored = importer.load_pickle("container", "container.pkl")
        self.assertIsNot(functions, restored)
        self.assertEqual(len(restored), 1)
        self.assertEqual(functions[0](), restored[0]())

    def test_exporting_mismatched_code(self):
        """
        Re-saving an object must fail when the exporter would use source code
        from a different package than the one the object was loaded from,
        even though the qualified names are identical.
        """
        import package_a.subpackage

        inner = package_a.subpackage.PackageASubpackageObject()
        original = package_a.PackageAObject(inner)

        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            exporter.intern("**")
            exporter.save_pickle("obj", "obj.pkl", original)

        # Load the same archive through two independent importers.
        archive.seek(0)
        importer1 = PackageImporter(archive)
        loaded1 = importer1.load_pickle("obj", "obj.pkl")

        archive.seek(0)
        importer2 = PackageImporter(archive)
        loaded2 = importer2.load_pickle("obj", "obj.pkl")

        def make_exporter():
            # 'importer1' is listed first so that 'PackageAObject' is
            # resolved from it before 'sys_importer' is consulted.
            return PackageExporter(BytesIO(), importer=[importer1, sys_importer])

        # Must fail: the 'PackageAObject' type from 'importer1' is not
        # necessarily the same as the type of the normally imported object.
        exporter = make_exporter()
        with self.assertRaises(pickle.PicklingError):
            exporter.save_pickle("obj", "obj.pkl", original)

        # Must also fail: the 'PackageAObject' type from 'importer1' is not
        # necessarily the same as the one defined by 'importer2'.
        exporter = make_exporter()
        with self.assertRaises(pickle.PicklingError):
            exporter.save_pickle("obj", "obj.pkl", loaded2)

        # Must succeed: 'loaded1' uses exactly the 'PackageAObject' type
        # defined by 'importer1'.
        exporter = make_exporter()
        exporter.save_pickle("obj", "obj.pkl", loaded1)

    def test_save_imported_module(self):
        """Saving a module that came from another PackageImporter should work."""
        import package_a.subpackage

        inner = package_a.subpackage.PackageASubpackageObject()
        outer = package_a.PackageAObject(inner)

        first_archive = BytesIO()
        with PackageExporter(first_archive) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", outer)

        first_archive.seek(0)
        importer = PackageImporter(first_archive)
        restored = importer.load_pickle("model", "model.pkl")
        restored_module_name = restored.__class__.__module__

        # Re-exporting the module by name must complete without raising.
        second_archive = BytesIO()
        with PackageExporter(
            second_archive, importer=(importer, sys_importer)
        ) as exporter:
            exporter.intern("**")
            exporter.save_module(restored_module_name)

    def test_save_imported_module_using_package_importer(self):
        """Exercise a corner case: re-packaging a module that uses `torch_package_importer`"""
        import package_a.use_torch_package_importer  # noqa: F401

        module_name = "package_a.use_torch_package_importer"

        first_archive = BytesIO()
        with PackageExporter(first_archive) as exporter:
            exporter.intern("**")
            exporter.save_module(module_name)

        first_archive.seek(0)
        importer = PackageImporter(first_archive)

        # Re-exporting through the package importer must not raise.
        second_archive = BytesIO()
        with PackageExporter(
            second_archive, importer=(importer, sys_importer)
        ) as exporter:
            exporter.intern("**")
            exporter.save_module(module_name)


if __name__ == "__main__":
    run_tests()