#include #include #include #include #include #include #include namespace torch::jit { namespace { c10::StrongTypePtr customClassResolver(const c10::QualifiedName& qn) { at::TypePtr type = nullptr; if (c10::QualifiedName("__torch__").isPrefixOf(qn)) { type = torch::getCustomClass(qn.qualifiedName()); } else { // This is a regular type, fall back to the default type parser torch::jit::ScriptTypeParser parser; type = parser.parseType(qn.qualifiedName()); return c10::StrongTypePtr(nullptr, std::move(type)); } if (type == nullptr) { TORCH_CHECK( false, "Couldn't resolve type '{}', did you forget to add its build dependency?", qn.qualifiedName()); } // Passing nullptr is a little bit sus, but should be fine: // 1. The lifetime of the class type is not tied to a specific // CompilationUnit // but rather the global custom class registry. // 2. We will not access the `cu_` field and immediately discard this // StrongTypePtr post-deserialization. return c10::StrongTypePtr(nullptr, std::move(type)); } } // namespace void pickle( std::function writer, const IValue& ivalue, std::vector* tensor_table) { Pickler pickler(std::move(writer), tensor_table, nullptr, nullptr); pickler.protocol(); pickler.pushIValue(ivalue); pickler.stop(); } std::vector pickle( const IValue& ivalue, std::vector* tensor_table) { std::vector data; pickle( [&](const char* bytes, size_t len) { data.insert(data.end(), bytes, bytes + len); }, ivalue, tensor_table); return data; } // This has to live here instead of the C++ API to mirror torch.save since the // mobile build excludes the C++ API std::vector pickle_save(const at::IValue& ivalue) { #ifndef C10_MOBILE // Pickle the IValue into an array of bytes std::vector pickle_data; Pickler pickler([&](const char* buf, size_t size) { pickle_data.insert(pickle_data.end(), buf, buf + size); }); pickler.protocol(); pickler.pushIValue(ivalue); pickler.stop(); std::vector container_data; container_data.reserve(pickle_data.size()); caffe2::serialize::PyTorchStreamWriter writer( [&](const void* void_bytes, size_t len) { const char* bytes = reinterpret_cast(void_bytes); container_data.insert(container_data.end(), bytes, bytes + len); return len; }); // Write the generated bytes and the associated tensors into a data.pkl file // and data/0, data/1, data/2... files for each of the tensors writeArchiveAndTensors( "data", pickle_data.data(), pickle_data.size(), pickler.tensorData(), writer); return container_data; #else AT_ERROR( "pickle_save not supported on mobile " "(see https://github.com/pytorch/pytorch/pull/30108)"); #endif } #ifndef C10_MOBILE size_t VectorReader::read(uint64_t pos, void* buf, size_t n, const char* what) const { std::copy( data_.data() + pos, data_.data() + pos + n, reinterpret_cast(buf)); return n; } size_t StringViewReader::read( uint64_t pos, void* buf, size_t n, const char* what) const { std::copy( data_.data() + pos, data_.data() + pos + n, reinterpret_cast(buf)); return n; } #endif IValue pickle_load(const std::vector& data) { // Read in the pickle data #ifndef C10_MOBILE caffe2::serialize::PyTorchStreamReader reader( std::make_unique(data)); return readArchiveAndTensors( "data", /*pickle_prefix=*/"", /*tensor_prefix=*/"", /*type_resolver=*/std::nullopt, /*obj_loader=*/std::nullopt, /*device=*/std::nullopt, reader); #else AT_ERROR( "pickle_load not supported on mobile " "(see https://github.com/pytorch/pytorch/pull/30108)"); #endif }; // A specialized version of pickle_load that can load custom objects. c10::IValue pickle_load_obj(std::string_view data) { #ifndef C10_MOBILE caffe2::serialize::PyTorchStreamReader reader( std::make_unique(data)); return torch::jit::readArchiveAndTensors( "data", /*pickle_prefix=*/"", /*tensor_prefix=*/"", /*type_resolver=*/customClassResolver, /*obj_loader=*/torch::jit::ObjLoaderFunc, /*device=*/c10::nullopt, reader); #else AT_ERROR( "pickle_load not supported on mobile " "(see https://github.com/pytorch/pytorch/pull/30108)"); #endif } IValue unpickle( std::function reader, TypeResolver type_resolver, c10::ArrayRef tensor_table, c10::TypePtr (*type_parser)(const std::string&), ObjLoader obj_loader) { Unpickler unpickler( std::move(reader), std::move(type_resolver), tensor_table, std::move(obj_loader), type_parser); return unpickler.parse_ivalue(); } IValue unpickle( const char* data, size_t size, TypeResolver type_resolver, c10::ArrayRef tensor_table, c10::TypePtr (*type_parser)(const std::string&)) { return unpickle( data, size, nullptr, std::move(type_resolver), tensor_table, type_parser); } IValue unpickle( const char* data, size_t size, ObjLoader obj_loader, TypeResolver type_resolver, c10::ArrayRef tensor_table, c10::TypePtr (*type_parser)(const std::string&)) { size_t bytes_read = 0; return unpickle( [&](char* buffer, size_t len) -> size_t { if (bytes_read >= size) { return 0; } len = std::min(size - bytes_read, len); // Copy len bytes into buffer const char* start = data + bytes_read; std::memcpy(buffer, start, len); bytes_read += len; return len; }, std::move(type_resolver), tensor_table, type_parser, std::move(obj_loader)); } } // namespace torch::jit