#include #include #ifdef USE_RPC #include #endif #include #include #include #include #include #include #include namespace torch::jit { using ::c10::IValue; // Protocol 2 is the highest that can be decoded by Python 2 // See https://docs.python.org/3/library/pickle.html#data-stream-format constexpr static uint8_t PROTOCOL_VERSION = 2; // NOLINTNEXTLINE(bugprone-exception-escape) Pickler::~Pickler() { flush(); } void Pickler::protocol() { push(PickleOpCode::PROTO); push(PROTOCOL_VERSION); } void Pickler::startTuple() { // All attributes get pushed into a tuple and their indices saved in the // module def push(PickleOpCode::MARK); } void Pickler::endTuple() { push(PickleOpCode::TUPLE); } void Pickler::stop() { push(PickleOpCode::STOP); flush(); } // unmemoized version called by pushIValue void Pickler::pushIValueImpl(const IValue& ivalue) { if (ivalue.isTensor()) { pushTensor(ivalue); } else if (ivalue.isTuple()) { pushTuple(ivalue); } else if (ivalue.isDouble()) { pushDouble(ivalue.toDouble()); } else if (ivalue.isComplexDouble()) { pushComplexDouble(ivalue); } else if (ivalue.isInt()) { pushInt(ivalue.toInt()); } else if (ivalue.isBool()) { pushBool(ivalue.toBool()); } else if (ivalue.isString()) { pushString(ivalue.toStringRef()); } else if (ivalue.isGenericDict()) { pushDict(ivalue); } else if (ivalue.isNone()) { push(PickleOpCode::NONE); } else if (ivalue.isIntList()) { pushSpecializedList(ivalue, "build_intlist", [this](const IValue& ivalue) { for (const int64_t item : ivalue.toIntVector()) { pushInt(item); } }); } else if (ivalue.isTensorList()) { pushSpecializedList( ivalue, "build_tensorlist", [this](const IValue& ivalue) { for (const at::Tensor& item : ivalue.toTensorVector()) { pushIValue(item); } }); } else if (ivalue.isDoubleList()) { pushSpecializedList( ivalue, "build_doublelist", [this](const IValue& ivalue) { for (double item : ivalue.toDoubleVector()) { pushDouble(item); } }); } else if (ivalue.isBoolList()) { pushSpecializedList(ivalue, "build_boollist", [this](const IValue& ivalue) { for (bool item : ivalue.toBoolList()) { pushBool(item); } }); // note: isList must be after isIntList and friends because // isList is true for all lists. } else if (ivalue.isList()) { pushGenericList(ivalue); } else if (ivalue.isObject()) { auto obj = ivalue.toObject(); auto type = obj->type(); if (memoized_class_types_ != nullptr) { // memoize every class type the Pickler encountered // This is used to make sure we capture all the run-time types // and serialize them properly for class/interface polymorphism memoized_class_types_->emplace_back(type); } auto type_name = type->name().value(); if (type_renamer_) { type_name = type_renamer_(type); } pushGlobal(type_name.prefix(), type_name.name()); push(PickleOpCode::EMPTY_TUPLE); push(PickleOpCode::NEWOBJ); if (checkHasValidSetGetState(type)) { Function& getstate = type->getMethod("__getstate__"); pushIValue(getstate({obj})); } else { push(PickleOpCode::EMPTY_DICT); push(PickleOpCode::MARK); for (size_t i = 0, n = type->numAttributes(); i < n; ++i) { pushString(type->getAttributeName(i)); pushIValue(obj->getSlot(i)); } push(PickleOpCode::SETITEMS); } push(PickleOpCode::BUILD); } else if (ivalue.isDevice()) { pushDevice(ivalue); } else if (ivalue.isCapsule()) { std::stringstream err; err << "Cannot serialize custom bound C++ class"; if (memoized_class_types_ && !memoized_class_types_->empty()) { if (auto qualname = memoized_class_types_->back()->name()) { err << " " << qualname->qualifiedName(); } } err << ". Please define serialization methods via def_pickle() for " "this class."; AT_ERROR(err.str()); } else if (ivalue.isRRef()) { #ifdef USE_RPC TORCH_CHECK( torch::distributed::rpc::getAllowJitRRefPickle() == true, "RRef jit pickling is only allowed inside RPC calls."); pushRRef(ivalue); #else TORCH_CHECK( false, "RRef pickling is only supported with the distributed package"); #endif } else if (ivalue.isEnum()) { auto enum_holder = ivalue.toEnumHolder(); const auto& qualified_class_name = enum_holder->type()->qualifiedClassName(); pushGlobal(qualified_class_name.prefix(), qualified_class_name.name()); pushIValue(enum_holder->value()); push(PickleOpCode::REDUCE); } else { AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind()); } } void Pickler::pushDevice(const IValue& ivalue) { auto device = ivalue.toDevice(); auto deviceStr = device.str(); auto it = memoized_devices_map_.find(deviceStr); if (it == memoized_devices_map_.end()) { pushGlobal("torch", "device"); pushString(deviceStr); push(PickleOpCode::TUPLE1); push(PickleOpCode::REDUCE); memoized_devices_map_[deviceStr] = pushNextBinPut(); } else { pushBinGet(it->second); } } #ifdef USE_RPC void Pickler::pushRRef(const IValue& ivalue) { // It is the same as how rref is pickled in python, see PyRRef::pickle auto rrefInterface = ivalue.toRRef(); auto rref = c10::static_intrusive_pointer_cast(rrefInterface); pushGlobal("torch.distributed.rpc", "rref"); auto& ctx = distributed::rpc::RRefContext::getInstance(); auto rrefForkData = ctx.prepareChildFork(rref); push(PickleOpCode::MARK); pushInt(rrefForkData.ownerId_); pushInt(rrefForkData.rrefId_.createdOn_); pushInt(rrefForkData.rrefId_.localId_); pushInt(rrefForkData.forkId_.createdOn_); pushInt(rrefForkData.forkId_.localId_); pushInt(rrefForkData.parent_); pushString(rrefForkData.typeStr_); push(PickleOpCode::TUPLE); push(PickleOpCode::REDUCE); } #endif void Pickler::pushIValue(const IValue& ivalue) { bool shouldMemoizeByPointer = ivalue.isPtrType() && !ivalue.isString() && ivalue.use_count() > 1; // Mutable ivalues are memoized by pointer equality, which we handle at this // outer granularity. Immutable ivalues are memoized by value equality which // is handled in the type-specific handlers inside pushIValueImpl. if (shouldMemoizeByPointer) { const void* ptr = ivalue.internalToPointer(); TORCH_CHECK( ptr != nullptr, "Pickler cannot memoize ", ivalue.tagKind(), " IValue ", ivalue); auto memo_entry = memoized_ivalue_map_.find(ptr); if (memo_entry != memoized_ivalue_map_.end()) { // This value has already been pushed, just do a BINGET pushBinGet(memo_entry->second); return; } pushIValueImpl(ivalue); memoized_ivalues_.push_back(ivalue); memoized_ivalue_map_[ptr] = pushNextBinPut(); } else { pushIValueImpl(ivalue); } } void Pickler::pushInt(int64_t n) { if (n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { push(PickleOpCode::BININT1); push(n); } else if ( n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { push(PickleOpCode::BININT2); push(to_le16(n)); } else if ( n >= std::numeric_limits::min() && n <= std::numeric_limits::max()) { push(PickleOpCode::BININT); push(to_le32(n)); } else { // Push 8 byte integer push(PickleOpCode::LONG1); push(8); push(to_le64(n)); } } void Pickler::pushBool(bool value) { push(value ? PickleOpCode::NEWTRUE : PickleOpCode::NEWFALSE); } void Pickler::pushBinGet(uint32_t memo_id) { if (memo_id <= std::numeric_limits::max()) { push(PickleOpCode::BINGET); push(memo_id); } else { // Memoized too many items, issue a LONG_BINGET instead push(PickleOpCode::LONG_BINGET); push(memo_id); } } // unmemoized encoding of a string void Pickler::pushStringImpl(const std::string& string) { if (string.size() <= UINT_MAX) { push(PickleOpCode::BINUNICODE); push(to_le32(string.size())); pushBytes(string); } else { push(PickleOpCode::BINUNICODE8); push(to_le64(string.size())); pushBytes(string); } } void Pickler::pushString(const std::string& string) { auto it = memoized_strings_map_.find(string); if (it == memoized_strings_map_.end()) { pushStringImpl(string); memoized_strings_map_[string] = pushNextBinPut(); } else { pushBinGet(it->second); } } void Pickler::pushStorageOfTensor(const at::Tensor& tensor) { const at::Storage& storage = tensor.storage(); void* addr = storage.unsafeGetStorageImpl(); auto it = memoized_storage_map_.find(addr); if (it != memoized_storage_map_.end()) { pushBinGet(it->second); return; } // Tuple for persistent_load push(PickleOpCode::MARK); // typename pushString("storage"); // data_type std::string data_type = std::string(toString(tensor.scalar_type())).append("Storage"); pushGlobal("torch", data_type); // root_key std::string root_key = get_tensor_id_ != nullptr ? get_tensor_id_(tensor) : std::to_string(tensor_data_.size()); pushString(root_key); // location pushString(tensor.device().str()); // size pushInt( static_cast(tensor.storage().nbytes() / tensor.element_size())); push(PickleOpCode::TUPLE); push(PickleOpCode::BINPERSID); // TODO: Skip this if not writing tensors memoized_storage_map_[addr] = pushNextBinPut(); tensor_data_.push_back(tensor); } void Pickler::pushBytes(const std::string& string) { static const size_t kSmallStr = 32; if (string.size() <= kSmallStr && bufferPos_ + string.size() <= buffer_.size()) { // Small string that fits: buffer the data. memcpy(buffer_.data() + bufferPos_, string.data(), string.size()); bufferPos_ += string.size(); } else { // Otherwise, first flush, then write directly. flush(); writer_(string.data(), string.size()); } } void Pickler::pushGlobal( c10::string_view module_name, c10::string_view class_name) { std::string key; key.reserve(module_name.size() + class_name.size() + 2); key.append(module_name.data(), module_name.size()); key.push_back('\n'); key.append(class_name.data(), class_name.size()); key.push_back('\n'); const auto memo_entry = memoized_globals_map_.find(key); if (memo_entry == memoized_globals_map_.end()) { push(PickleOpCode::GLOBAL); pushBytes(key); // Push BINPUT without adding anything to the memoized_ivalues_ size_t memo_id = pushNextBinPut(); memoized_globals_map_.insert({key, memo_id}); } else { pushBinGet(memo_entry->second); } } void Pickler::pushTensor(const IValue& ivalue) { if (tensor_table_ == nullptr) { pushLiteralTensor(ivalue); } else { pushTensorReference(ivalue); } } void Pickler::pushLiteralSparseTensor(const at::Tensor& tensor) { pushGlobal("torch._utils", "_rebuild_sparse_tensor"); push(PickleOpCode::MARK); // layout auto layout = tensor.layout(); pushInt(static_cast(layout)); switch (layout) { case c10::Layout::Sparse: // size push(PickleOpCode::MARK); for (auto size : tensor.sizes()) { pushInt(size); } push(PickleOpCode::TUPLE); // requires grad pushIValue(tensor.requires_grad()); // indices pushTensor(tensor._indices()); // values pushTensor(tensor._values()); break; case c10::Layout::SparseCsr: push(PickleOpCode::MARK); for (auto size : tensor.sizes()) { pushInt(size); } push(PickleOpCode::TUPLE); pushIValue(tensor.requires_grad()); pushTensor(tensor.crow_indices()); pushTensor(tensor.col_indices()); pushTensor(tensor.values()); break; default: TORCH_CHECK( false, "Unsupported sparse tensor layout type in serialization ", layout); break; } // backward_hooks pushGlobal("collections", "OrderedDict"); push(PickleOpCode::EMPTY_TUPLE); // Construct the collections.OrderedDict for the backward_hooks push(PickleOpCode::REDUCE); push(PickleOpCode::TUPLE); // Call torch._utils._rebuild_sparse_coo_tensor push(PickleOpCode::REDUCE); } void Pickler::pushLiteralTensor(const IValue& ivalue) { // In contrast to tensor references, literal tensors are included in the // pickle program binary blob. They are written to the file after the STOP // opcode. They can't be included in the pickle program itself without a bunch // of extra machinery since byte strings are limited to 4 GB. // // The format here is the same one used by `torch.save()`. The code for the // format can be found in `torch/serialization.py`. auto& tensor = ivalue.toTensor(); if (tensor.is_sparse() || tensor.is_sparse_csr()) { pushLiteralSparseTensor(tensor); return; } bool quantized = tensor.is_quantized(); // The arguments to this function are: // storage, storage_offset, size, stride, requires_grad, backward_hooks pushGlobal( "torch._utils", quantized ? "_rebuild_qtensor" : "_rebuild_tensor_v2"); push(PickleOpCode::MARK); pushStorageOfTensor(tensor); // storage offset pushInt(tensor.storage_offset()); // size push(PickleOpCode::MARK); for (auto size : tensor.sizes()) { pushInt(size); } push(PickleOpCode::TUPLE); // stride push(PickleOpCode::MARK); for (auto stride : tensor.strides()) { pushInt(stride); } push(PickleOpCode::TUPLE); if (quantized) { push(PickleOpCode::MARK); pushGlobal("torch", toString(tensor.qscheme())); // tuple of (qscheme, scale, zp) or (qscheme, scales, zps, axis) switch (tensor.qscheme()) { case at::kPerTensorAffine: pushDouble(tensor.q_scale()); pushInt(tensor.q_zero_point()); break; case at::kPerChannelAffineFloatQParams: case at::kPerChannelAffine: { pushTensor(tensor.q_per_channel_scales()); pushTensor(tensor.q_per_channel_zero_points()); pushInt(tensor.q_per_channel_axis()); } break; default: TORCH_CHECK( false, "Unsupported tensor quantization type in serialization ", toString(tensor.qscheme())); break; } push(PickleOpCode::TUPLE); } // requires_grad pushIValue(tensor.requires_grad()); // backward_hooks pushGlobal("collections", "OrderedDict"); push(PickleOpCode::EMPTY_TUPLE); // Construct the collections.OrderedDict for the backward_hooks push(PickleOpCode::REDUCE); if (!quantized) { // Only push it for regular tensor if the dictionary is not empty. auto metadata = torch::jit::getTensorMetadata(tensor); if (!metadata.empty()) { // IValues based on std::unordered_map are slow and deprecated. // Thus, pass a c10::Dict to pushDict. c10::Dict math_bits_; for (const auto& pair : metadata) { math_bits_.insert(pair.first, pair.second); } pushDict(math_bits_); } } push(PickleOpCode::TUPLE); // Call torch._utils._rebuild_tensor_v2 push(PickleOpCode::REDUCE); } void Pickler::pushSpecializedList( const IValue& ivalue, const char* list_name, const std::function& item_pusher) { pushGlobal("torch.jit._pickle", list_name); // Reduce arguments are spread (e.g. `*args`) before calling the global, // so wrap in a tuple push(PickleOpCode::MARK); push(PickleOpCode::EMPTY_LIST); // Mark list push(PickleOpCode::MARK); // Add all items item_pusher(ivalue); // Finish list push(PickleOpCode::APPENDS); // Finish tuple push(PickleOpCode::TUPLE); // Call reduce push(PickleOpCode::REDUCE); } static inline double swapDouble(double value) { const char* bytes = reinterpret_cast(&value); double flipped = 0; char* out_bytes = reinterpret_cast(&flipped); for (const auto i : c10::irange(sizeof(double))) { out_bytes[i] = bytes[sizeof(double) - i - 1]; } return *reinterpret_cast(out_bytes); } void Pickler::pushDouble(double value) { push(PickleOpCode::BINFLOAT); #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ // Python pickle format is big endian, swap. push(swapDouble(value)); #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ push(value); #else #error Unexpected or undefined __BYTE_ORDER__ #endif } void Pickler::pushComplexDouble(const IValue& value) { c10::complex d = value.toComplexDouble(); pushGlobal("builtins", "complex"); pushIValue(d.real()); pushIValue(d.imag()); push(PickleOpCode::TUPLE2); push(PickleOpCode::REDUCE); } void Pickler::pushLong(const std::string& data) { uint64_t size = data.size(); TORCH_INTERNAL_ASSERT( size <= std::numeric_limits::max(), "Cannot pickle a long larger than 255 bytes"); push(PickleOpCode::LONG1); push(size); pushBytes(data); } void Pickler::pushTensorReference(const IValue& ivalue) { pushGlobal("torch.jit._pickle", "build_tensor_from_id"); tensor_table_->push_back(ivalue.toTensor()); auto tensor_id = tensor_table_->size() - 1; // Reduce arguments are spread (e.g. `*args`) before calling the global, // so wrap in a tuple push(PickleOpCode::MARK); pushIValue(static_cast(tensor_id)); push(PickleOpCode::TUPLE); push(PickleOpCode::REDUCE); } // startTypeTag() and endTypeTag() must be called in a pair, with 1 argument // pushed on the stack in between them. They will add the type of a container // ivalue to the stack as a string so we can preserve type tags across // serialization void Pickler::startTypeTag() { if (tag_aggregates_) { pushGlobal("torch.jit._pickle", "restore_type_tag"); } } namespace { std::optional type_printer(const c10::Type& type) { if (auto dyn = type.castRaw()) { return dyn->fallback()->annotation_str(type_printer); } return std::nullopt; } } // namespace // See startTypeTag void Pickler::endTypeTag(const IValue& ivalue) { if (!tag_aggregates_) { return; } TORCH_INTERNAL_ASSERT(ivalue.isGenericDict() || ivalue.isList()); // Push the dict type auto type = ivalue.type(); TORCH_INTERNAL_ASSERT(type); auto annot_str = type->annotation_str(type_printer); pushString(annot_str); // Pop the dict and type into a tuple push(PickleOpCode::TUPLE2); // Call function via reduce push(PickleOpCode::REDUCE); } void Pickler::pushDict(const IValue& ivalue) { auto dict = ivalue.toGenericDict(); startTypeTag(); push(PickleOpCode::EMPTY_DICT); static_assert( std::is_unsigned_v, "Expected size to be non-negative."); push(PickleOpCode::MARK); // Sort the dict for deterministic keys for (const auto& entry : dict) { pushIValue(entry.key()); pushIValue(entry.value()); } push(PickleOpCode::SETITEMS); endTypeTag(ivalue); } size_t Pickler::pushNextBinPut() { if (memo_id_ <= std::numeric_limits::max()) { push(PickleOpCode::BINPUT); push(memo_id_); } else { // Memoized too many items, issue a LONG_BINPUT instead push(PickleOpCode::LONG_BINPUT); push(memo_id_); } AT_ASSERT(memo_id_ <= std::numeric_limits::max()); ++memo_id_; return memo_id_ - 1; } void Pickler::pushGenericList(const IValue& ivalue) { auto list = ivalue.toListRef(); startTypeTag(); // Push the list items push(PickleOpCode::EMPTY_LIST); push(PickleOpCode::MARK); for (const IValue& item : list) { pushIValue(item); } push(PickleOpCode::APPENDS); endTypeTag(ivalue); } void Pickler::pushTuple(const IValue& ivalue) { auto tuple = ivalue.toTuple(); auto tuple_size = tuple->elements().size(); switch (tuple_size) { case 0: { push(PickleOpCode::EMPTY_TUPLE); } break; case 1: { pushIValue(tuple->elements()[0]); push(PickleOpCode::TUPLE1); } break; case 2: { pushIValue(tuple->elements()[0]); pushIValue(tuple->elements()[1]); push(PickleOpCode::TUPLE2); } break; case 3: { pushIValue(tuple->elements()[0]); pushIValue(tuple->elements()[1]); pushIValue(tuple->elements()[2]); push(PickleOpCode::TUPLE3); } break; default: { push(PickleOpCode::MARK); for (const IValue& item : tuple->elements()) { pushIValue(item); } push(PickleOpCode::TUPLE); } break; } } WriteableTensorData getWriteableTensorData( const at::Tensor& tensor, bool to_cpu) { WriteableTensorData result; result.tensor_ = tensor; result.size_ = tensor.storage().nbytes(); // TODO HIP support if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) { // NB: This new tensor is created to support cuda tensors. // Storages can be mutated when converting tensors from cuda to cpu, // and we need a cpu tensor to copy data from. result.tensor_ = at::empty({0}, tensor.options()) .set_( tensor.storage(), /* storage_offset = */ 0, /* size = */ {static_cast( tensor.storage().nbytes() / tensor.element_size())}, /* stride = */ {1}) .cpu(); TORCH_CHECK( result.tensor_.storage().nbytes() == result.size_, "Storage tensor size did not match record size"); } return result; } bool checkHasValidSetGetState(const std::shared_ptr& cls) { // Check that the schemas for __getstate__ and __setstate__ are correct auto getstate = cls->findMethod("__getstate__"); if (getstate == nullptr) { return false; } auto get_schema = getstate->getSchema(); // Check __getstate__ // __getstate__ is expected to be (self) -> T TORCH_CHECK( get_schema.arguments().size() == 1, "'__getstate__' must have 'self' as its only argument, but found ", get_schema.arguments().size(), " arguments"); TORCH_CHECK( get_schema.returns().size() == 1, "'__getstate__' must return 1 value, but found ", get_schema.returns().size()); // Check __setstate__ if the method exists // __setstate__ is expected to be (self, T) -> None auto setstate = cls->findMethod("__setstate__"); if (!setstate) { return false; } auto set_schema = setstate->getSchema(); TORCH_CHECK( set_schema.arguments().size() == 2, "'__setstate__' must have 'self' and the state as its " "only arguments, but found ", set_schema.arguments().size(), " arguments"); TORCH_CHECK( set_schema.returns().size() == 1, "'__setstate__' must return None, but found ", set_schema.returns().size(), " return values"); TORCH_CHECK( set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()), "'__setstate__' must return None, but found value of type", set_schema.returns().at(0).type()->annotation_str()); // Check that the return type of __getstate__ matches the input to // __setstate__ auto get_type = get_schema.returns().at(0).type(); auto set_type = set_schema.arguments().at(1).type(); TORCH_CHECK( get_type->isSubtypeOf(*set_type), "'__getstate__'s return type (", get_type->annotation_str(), ") does not match '__setstate__'s argument type (", set_type->annotation_str(), ")"); return true; } } // namespace torch::jit