#pragma once #include #include #include #include #include #include #include #include #include #include #if defined(HAVE_MMAP) #include #include #include #include #include #endif /** * @file * * Helpers for identifying file formats when reading serialized data. * * Note that these functions are declared inline because they will typically * only be called from one or two locations per binary. */ namespace torch::jit { /** * The format of a file or data stream. */ enum class FileFormat { UnknownFileFormat = 0, FlatbufferFileFormat, ZipFileFormat, }; /// The size of the buffer to pass to #getFileFormat(), in bytes. constexpr size_t kFileFormatHeaderSize = 8; constexpr size_t kMaxAlignment = 16; /** * Returns the likely file format based on the magic header bytes in @p header, * which should contain the first bytes of a file or data stream. */ // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) static inline FileFormat getFileFormat(const char* data) { // The size of magic strings to look for in the buffer. static constexpr size_t kMagicSize = 4; // Bytes 4..7 of a Flatbuffer-encoded file produced by // `flatbuffer_serializer.h`. (The first four bytes contain an offset to the // actual Flatbuffer data.) static constexpr std::array kFlatbufferMagicString = { 'P', 'T', 'M', 'F'}; static constexpr size_t kFlatbufferMagicOffset = 4; // The first four bytes of a ZIP file. static constexpr std::array kZipMagicString = { 'P', 'K', '\x03', '\x04'}; // Note that we check for Flatbuffer magic first. Since the first four bytes // of flatbuffer data contain an offset to the root struct, it's theoretically // possible to construct a file whose offset looks like the ZIP magic. On the // other hand, bytes 4-7 of ZIP files are constrained to a small set of values // that do not typically cross into the printable ASCII range, so a ZIP file // should never have a header that looks like a Flatbuffer file. if (std::memcmp( data + kFlatbufferMagicOffset, kFlatbufferMagicString.data(), kMagicSize) == 0) { // Magic header for a binary file containing a Flatbuffer-serialized mobile // Module. return FileFormat::FlatbufferFileFormat; } else if (std::memcmp(data, kZipMagicString.data(), kMagicSize) == 0) { // Magic header for a zip file, which we use to store pickled sub-files. return FileFormat::ZipFileFormat; } return FileFormat::UnknownFileFormat; } /** * Returns the likely file format based on the magic header bytes of @p data. * If the stream position changes while inspecting the data, this function will * restore the stream position to its original offset before returning. */ // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) static inline FileFormat getFileFormat(std::istream& data) { FileFormat format = FileFormat::UnknownFileFormat; std::streampos orig_pos = data.tellg(); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) std::array header; data.read(header.data(), header.size()); if (data.good()) { format = getFileFormat(header.data()); } data.seekg(orig_pos, data.beg); return format; } /** * Returns the likely file format based on the magic header bytes of the file * named @p filename. */ // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) static inline FileFormat getFileFormat(const std::string& filename) { std::ifstream data(filename, std::ifstream::binary); return getFileFormat(data); } // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) static void file_not_found_error() { std::stringstream message; message << "Error while opening file: "; if (errno == ENOENT) { message << "no such file or directory" << '\n'; } else { message << "error no is: " << errno << '\n'; } TORCH_CHECK(false, message.str()); } // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) static inline std::tuple, size_t> get_file_content( const char* filename) { #if defined(HAVE_MMAP) int fd = open(filename, O_RDONLY); if (fd < 0) { // failed to open file, chances are it's no such file or directory. file_not_found_error(); } struct stat statbuf {}; fstat(fd, &statbuf); size_t size = statbuf.st_size; void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0); close(fd); auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); }; std::shared_ptr data(reinterpret_cast(ptr), deleter); #else FILE* f = fopen(filename, "rb"); if (f == nullptr) { file_not_found_error(); } fseek(f, 0, SEEK_END); size_t size = ftell(f); fseek(f, 0, SEEK_SET); // make sure buffer size is multiple of alignment size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment; std::shared_ptr data( static_cast(c10::alloc_cpu(buffer_size)), c10::free_cpu); fread(data.get(), size, 1, f); fclose(f); #endif return std::make_tuple(data, size); } // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) static inline std::tuple, size_t> get_stream_content( std::istream& in) { // get size of the stream and reset to orig std::streampos orig_pos = in.tellg(); in.seekg(orig_pos, std::ios::end); const long size = in.tellg(); in.seekg(orig_pos, in.beg); // read stream // NOLINT make sure buffer size is multiple of alignment size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment; std::shared_ptr data( static_cast(c10::alloc_cpu(buffer_size)), c10::free_cpu); in.read(data.get(), size); // reset stream to original position in.seekg(orig_pos, in.beg); return std::make_tuple(data, size); } // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) static inline std::tuple, size_t> get_rai_content( caffe2::serialize::ReadAdapterInterface* rai) { size_t buffer_size = (rai->size() / kMaxAlignment + 1) * kMaxAlignment; std::shared_ptr data( static_cast(c10::alloc_cpu(buffer_size)), c10::free_cpu); rai->read( 0, data.get(), rai->size(), "Loading ReadAdapterInterface to bytes"); return std::make_tuple(data, buffer_size); } } // namespace torch::jit