#pragma once #include #include #include #include #include #include namespace torch::jit { /// Pickle an IValue by calling a function to handle writing the data. /// /// `writer` is a function that takes in a pointer to a chunk of memory and its /// size and consumes it. /// /// See `jit::pickle` for more details. TORCH_API void pickle( std::function writer, const IValue& ivalue, std::vector* tensor_table = nullptr); /// Save a `torch::IValue` in a format compatible with Python's `pickle` module /// /// If present, `tensor_table` is a pointer to a table in which tensors that /// are contained within `ivalue` are stored, and the bytes returned by the /// pickler will only include references to these tensors in the table. This can /// be used to keep the binary blob size small. /// If not provided, tensors are stored in the same byte stream as the pickle /// data, similar to `torch.save()` in eager Python. /// /// Pickled values can be loaded in Python and C++: /// \rst /// .. code-block:: cpp /// /// torch::IValue float_value(2.3); /// /// // TODO: when tensors are stored in the pickle, delete this /// std::vector tensor_table; /// auto data = torch::jit::pickle(float_value, &tensor_table); /// /// std::vector ivalues = /// torch::jit::unpickle(data.data(), data.size()); /// /// .. code-block:: python /// /// values = torch.load('data.pkl') /// print(values) /// /// \endrst TORCH_API std::vector pickle( const IValue& ivalue, std::vector* tensor_table = nullptr); /// Save a `torch::IValue` in a format that can be loaded by both /// `torch::pickle_load` in C++ and `torch.load` in Python. TORCH_API std::vector pickle_save(const IValue& ivalue); /// Deserialize a `torch::IValue` from bytes produced by either /// `torch::pickle_save` in C++ or `torch.save` in Python TORCH_API IValue pickle_load(const std::vector& data); /// Deserialize a `torch::IValue` from bytes produced by either /// `torch::pickle_save` in C++ or `torch.save` in Python with custom object. TORCH_API IValue pickle_load_obj(std::string_view data); /// `reader` is a function that takes in a size to read from some pickled /// binary. `reader` should remember where it last read, and return /// the number of bytes read. /// See `torch::pickle` for details. /// type_resolver is used to resolve any JIT type based on type str TORCH_API IValue unpickle( std::function reader, TypeResolver type_resolver, c10::ArrayRef tensor_table, c10::TypePtr (*type_parser)(const std::string&) = Unpickler::defaultTypeParser, ObjLoader obj_loader = nullptr); /// Decode a chunk of memory containing pickled data into its `torch::IValue`s. /// /// If any `torch::IValue`s in the pickled data are `Object`s, then a /// `class_resolver` function must be provided. /// /// See `torch::pickle` for details. TORCH_API IValue unpickle( const char* data, size_t size, TypeResolver type_resolver = nullptr, c10::ArrayRef tensor_table = {}, c10::TypePtr (*type_parser)(const std::string&) = Unpickler::defaultTypeParser); /// Decode a chunk of memory containing pickled data into its `torch::IValue`s. /// /// If any `torch::IValue`s in the pickled data are `Object`s, then a /// `class_resolver` function must be provided. /// /// See `torch::pickle` for details. TORCH_API IValue unpickle( const char* data, size_t size, ObjLoader obj_loader, TypeResolver type_resolver = nullptr, c10::ArrayRef tensor_table = {}, c10::TypePtr (*type_parser)(const std::string&) = Unpickler::defaultTypeParser); #ifndef C10_MOBILE class VectorReader : public caffe2::serialize::ReadAdapterInterface { public: VectorReader(std::vector data) : data_(std::move(data)) {} size_t size() const override { return data_.size(); } size_t read(uint64_t pos, void* buf, size_t n, const char* what) const override; private: std::vector data_; }; class StringViewReader : public caffe2::serialize::ReadAdapterInterface { public: StringViewReader(std::string_view data) : data_(data) {} size_t size() const override { return data_.size(); } size_t read(uint64_t pos, void* buf, size_t n, const char* what) const override; private: std::string_view data_; }; #endif } // namespace torch::jit