#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_RPC
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <ATen/quantized/Quantizer.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/csrc/utils/byte_order.h>

#include <string>
#include <utility>

namespace torch::jit {

using ::c10::IValue;

static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
  if (root.isObject()) {
    restoreAccurateTypeTags(root, root.type());
  }
}

// Pickled objects are stored in a form compatible with Python pickling.
// In torchscript List[T]/Dict[K, V] are statically typed and contain
// dynamic type tags that allow T, K, and V to be recovered. But this
// info is not stored in the Python pickling information. However, we
// can recover this information from the static type of the top-level
// object being unpickled, because we have a record of the type of the
// objects it contains as attributes.
// `IfPossible` - we can only do this recovery when we have an object as
// the top-level unpickled thing (which is guaranteed for Modules, but
// not for torch.load/torch.save). Otherwise we do not know the types
// of the contained objects and cannot restore the tags.
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
  struct Work {
    TypePtr type;
    IValue value;
  };
  std::vector<Work> to_process = {{type_tag, root}};
  std::unordered_set<const void*> scanned;
  while (!to_process.empty()) {
    Work w = std::move(to_process.back());
    to_process.pop_back();
    // ensure we only scan each pointer value once, otherwise this
    // can become exponential (and if we allow recursive data in the future,
    // it would not terminate).
    if (w.value.isPtrType()) {
      const void* key = w.value.internalToPointer();
      auto it = scanned.find(key);
      if (it != scanned.end()) {
        continue;
      }
      scanned.emplace_hint(it, key);
    }
    auto kind = w.type->kind();
    if (auto dyn = w.type->castRaw<DynamicType>()) {
      kind = dyn->dynamicKind();
    }
    switch (kind) {
      case TensorType::Kind:
      case StorageType::Kind:
      case NumberType::Kind:
      case FloatType::Kind:
      case ComplexType::Kind:
      case IntType::Kind:
      case NoneType::Kind:
      case GeneratorType::Kind:
      case QuantizerType::Kind:
      case BoolType::Kind:
      case VarType::Kind:
      case CapsuleType::Kind:
      case PyObjectType::Kind:
      case StringType::Kind:
      case FunctionType::Kind:
      case DeviceObjType::Kind:
      case StreamObjType::Kind:
      case QSchemeType::Kind:
      case LayoutType::Kind:
      case MemoryFormatType::Kind:
      case ScalarTypeType::Kind:
      case RRefType::Kind:
      case AnyType::Kind:
      case AnyListType::Kind:
      case AnyTupleType::Kind:
      case AnyClassType::Kind:
      case AnyEnumType::Kind:
        // no op, there is nothing to tag
        break;
      case c10::SymIntType::Kind:
        // TODO: Can this really show up though? :think:
        TORCH_CHECK(!w.value.toSymInt().is_heap_allocated());
        // no op, there is nothing to tag
        break;
      case c10::SymFloatType::Kind:
        TORCH_CHECK(!w.value.toSymFloat().is_symbolic());
        // no op, there is nothing to tag
        break;
      case c10::SymBoolType::Kind:
        TORCH_CHECK(!w.value.toSymBool().is_heap_allocated());
        // no op, there is nothing to tag
        break;
      case DynamicType::Kind:
      case UnionType::Kind:
      case EnumType::Kind:
        // TODO(gmagogsfm): Implement serialization/deserialization of Enum.
        TORCH_INTERNAL_ASSERT(false);
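      // Container types: re-tag the container itself where needed and queue
      // each contained value together with its statically known element type,
      // so nested containers are processed on later iterations of the loop.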
      case TupleType::Kind: {
        auto t = w.value.toTuple();
        for (size_t i = 0; i < w.type->containedTypeSize(); ++i) {
          Work elem = {w.type->containedType(i), t->elements().at(i)};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case FutureType::Kind: {
        auto f = w.value.toFuture();
        if (f->completed()) {
          Work elem = {w.type->containedType(0), f->value()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case AwaitType::Kind: {
        auto aw = w.value.toAwait();
        if (aw->completed()) {
          Work elem = {w.type->containedType(0), aw->wait()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case OptionalType::Kind: {
        if (!w.value.isNone()) {
          Work elem = {w.type->containedType(0), w.value};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case ListType::Kind: {
        // specialized lists do not need their type refined, so we can exit
        // early here
        if (!w.value.isList()) {
          break;
        }
        auto elem_type = w.type->containedType(0);
        auto lst = w.value.toList();
        lst.unsafeSetElementType(elem_type);
        for (const IValue& item : lst) {
          Work elem = {elem_type, item};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case DictType::Kind: {
        auto d = w.value.toGenericDict();
        auto keyType = w.type->containedType(0);
        auto valType = w.type->containedType(1);
        d.unsafeSetKeyType(keyType);
        d.unsafeSetValueType(valType);
        for (const auto& item : d) {
          Work kelem = {keyType, item.key()};
          Work velem = {valType, item.value()};
          to_process.emplace_back(std::move(kelem));
          to_process.emplace_back(std::move(velem));
        }
      } break;
      // in both cases the dynamic type is a class, and we are going to tag
      // with the dynamic type
      case InterfaceType::Kind:
      case ClassType::Kind: {
        auto obj = w.value.toObject();
        auto typ = obj->type(); // note: intentionally using the dynamic type,
                                // the static type is potentially less accurate
        for (size_t i = 0; i < typ->numAttributes(); ++i) {
          Work elem = {typ->getAttribute(i), obj->getSlot(i)};
          to_process.emplace_back(std::move(elem));
        }
      };
    }
  }
}

namespace {
template <typename T>
bool is(const Type& type) {
  if (type.kind() == T::Kind) {
    return true;
  }
  if (auto dyn = type.castRaw<DynamicType>()) {
    return dyn->tag() == c10::DynamicTypeTrait<T>::tagValue();
  }
  return false;
}
} // namespace

static void restoreContainerTypeTags(
    const IValue& ivalue,
    const TypePtr& type) {
  if (is<DictType>(*type)) {
    auto dict = ivalue.toGenericDict();
    dict.unsafeSetKeyType(type->containedType(0));
    dict.unsafeSetValueType(type->containedType(1));
  } else if (is<ListType>(*type)) {
    ivalue.toList().unsafeSetElementType(type->containedType(0));
  } else {
    AT_ERROR("Unknown type for tag restoration: " + type->annotation_str());
  }
}

IValue Unpickler::parse_ivalue() {
  run();
  TORCH_CHECK(
      stack_.size() == 1,
      "Unpickler expected 1 element on the stack, but found ",
      stack_.size());
  if (version_ <= 2) {
    // See [type tag serialization]
    restoreAccurateTypeTagsIfPossible(stack_[0]);
  }
  return stack_[0];
}

double Unpickler::readFloat() {
  AT_ASSERT(sizeof(double) == 8);
  double big_endian = read<double>();
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
  double little_endian = 0;
  // Pickle floats are big endian, so reverse the bytes
  auto big_endian_ptr = reinterpret_cast<const char*>(&big_endian);
  std::reverse_copy(
      big_endian_ptr,
      big_endian_ptr + sizeof(big_endian),
      reinterpret_cast<char*>(&little_endian));
  return little_endian;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  return big_endian;
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
}
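// Top-level driver: verifies the PROTO opcode and protocol version, then
// decodes one instruction at a time until a STOP opcode is reached.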
void Unpickler::run() {
  // Expect a PROTO opcode and protocol number at the start of blob
  auto opcode = readOpCode();
  TORCH_CHECK(
      opcode == PickleOpCode::PROTO,
      "Expected PROTO opcode at the start"
      " of pickle archive, found ",
      int(static_cast<uint8_t>(opcode)));
  uint8_t protocol = read<uint8_t>();
  TORCH_CHECK(
      protocol == 2,
      "Only Pickle protocol 2 is supported, found protocol = ",
      protocol);

  while (true) {
    PickleOpCode opcode = readInstruction();
    if (opcode == PickleOpCode::STOP) {
      return;
    }
  }
}

void Unpickler::setInput(size_t memo_id) {
  AT_ASSERT(!stack_.empty());
  if (memo_id >= memo_table_.size()) {
    memo_table_.insert(
        memo_table_.end(), memo_id - memo_table_.size(), IValue());
    memo_table_.push_back(stack_.back());
  } else {
    memo_table_[memo_id] = stack_.back();
  }
}

// emplace_back on bool vectors does not exist on some systems
// avoid it by calling push_back for bool
template <typename T>
inline void append(std::vector<T>& a, T&& e) {
  a.emplace_back(std::forward<T>(e));
}
template <>
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
inline void append<bool>(std::vector<bool>& a, bool&& e) {
  a.push_back(e);
}

static std::vector<int64_t> tupleToIntList(const IValue& v) {
  return fmap(v.toTupleRef().elements(), [](const IValue& v) -> int64_t {
    return v.toInt();
  });
}

// note we cannot use toIntList, toDoubleList because during unpickling the
// lists are not yet tagged
template <typename T>
static std::vector<T> convertList(const IValue& v) {
  return fmap(v.toListRef(), [](const IValue& elem) { return elem.to<T>(); });
}

PickleOpCode Unpickler::readInstruction() {
  auto opcode = readOpCode();
  switch (opcode) {
    case PickleOpCode::EMPTY_LIST: {
      stack_.emplace_back(c10::impl::GenericList(AnyType::get()));
    } break;
    case PickleOpCode::EMPTY_TUPLE: {
      if (empty_tuple_.isNone()) {
        // we only need one object, since tuples are not mutable.
        empty_tuple_ = c10::ivalue::Tuple::create(std::vector<IValue>());
      }
      stack_.emplace_back(empty_tuple_);
    } break;
    case PickleOpCode::BINPUT: {
      size_t memo_id = read<uint8_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::LONG_BINPUT: {
      TORCH_CHECK(
          std::numeric_limits<size_t>::max() >=
              std::numeric_limits<uint32_t>::max(),
          "Found a LONG_BINPUT opcode, but size_t on this system is "
          "not big enough to decode it");
      size_t memo_id = read<uint32_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::MARK: {
      // Mark location of the container ivalue in the stack
      marks_.push_back(stack_.size());
    } break;
    case PickleOpCode::NEWTRUE: {
      stack_.emplace_back(true);
    } break;
    case PickleOpCode::NEWFALSE: {
      stack_.emplace_back(false);
    } break;
    case PickleOpCode::NONE: {
      stack_.emplace_back();
    } break;
    case PickleOpCode::BININT1: {
      uint8_t value = read<uint8_t>();
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT2: {
      uint16_t value = from_le16(read<uint16_t>());
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT: {
      int32_t value = from_le32(read<int32_t>());
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::LONG1: {
      // Only read LONG1s with 8 as the length
      uint8_t length = read<uint8_t>();
      TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
      stack_.emplace_back(int64_t(from_le64(read<int64_t>())));
    } break;
    case PickleOpCode::BINUNICODE: {
      uint32_t length = from_le32(read<uint32_t>());
      stack_.emplace_back(readBytes(length));
    } break;
    case PickleOpCode::BINUNICODE8: {
      int64_t length = from_le64(read<int64_t>());
      stack_.emplace_back(readBytes(length));
    } break;
    case PickleOpCode::BINFLOAT:
      stack_.emplace_back(readFloat());
      break;
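    // TUPLE collects every value pushed since the most recent MARK into a
    // single tuple; sizes 1-3 are handled without the generic element copy.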
    case PickleOpCode::TUPLE: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      std::vector<IValue> elements;
      TORCH_CHECK(
          stack_.size() >= start,
          "Parsing error: wrong start index ",
          start,
          " for stack_ of size ",
          stack_.size());
      const auto tupleSize = stack_.size() - start;
      switch (tupleSize) {
        case 3: {
          auto e3 = pop(stack_);
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(c10::ivalue::Tuple::create(
              std::move(e1), std::move(e2), std::move(e3)));
          break;
        }
        case 2: {
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(
              c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
          break;
        }
        case 1:
          stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
          break;
        default: {
          elements.reserve(stack_.size() - start);
          auto start_it = stack_.begin() + static_cast<std::ptrdiff_t>(start);
          for (auto it = start_it; it != stack_.end(); ++it) {
            elements.emplace_back(std::move(*it));
          }
          stack_.erase(start_it, stack_.end());
          stack_.emplace_back(c10::ivalue::Tuple::create(std::move(elements)));
          break;
        }
      }
    } break;
    case PickleOpCode::TUPLE1: {
      TORCH_CHECK(
          !stack_.empty(),
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 1 expected");
      stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
    } break;
    case PickleOpCode::TUPLE2: {
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(
          c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
    } break;
    case PickleOpCode::TUPLE3: {
      TORCH_CHECK(
          stack_.size() > 2,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 3 expected");
      auto e3 = pop(stack_);
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(c10::ivalue::Tuple::create(
          std::move(e1), std::move(e2), std::move(e3)));
    } break;
    case PickleOpCode::EMPTY_DICT:
      stack_.emplace_back(
          c10::impl::GenericDict(AnyType::get(), AnyType::get()));
      break;
    case PickleOpCode::APPENDS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index ",
          start,
          " for stack_ of size ",
          stack_.size());
      auto list_ivalue = stack_.at(start - 1);
      readList(list_ivalue);
    } break;
    case PickleOpCode::APPEND: {
      TORCH_CHECK(
          stack_.size() >= 2, "Parsing error: missing elements in stack_.");
      auto list_ivalue = stack_.at(stack_.size() - 2);
      readListElements(list_ivalue, stack_.size() - 1);
    } break;
    case PickleOpCode::LIST: {
      IValue list_ivalue = c10::impl::GenericList(AnyType::get());
      readList(list_ivalue);
      stack_.push_back(std::move(list_ivalue));
    } break;
    case PickleOpCode::DICT: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      TORCH_CHECK(
          stack_.size() > start,
          "Parsing error: wrong start index ",
          start,
          " for stack_ which is of size ",
          stack_.size());
      auto dict = c10::impl::GenericDict(AnyType::get(), AnyType::get());
      TORCH_CHECK(
          (stack_.size() - start) % 2 == 0,
          "Parsing error: stack_ is of size ",
          stack_.size(),
          " and start index is ",
          start,
          ", but stack_ is iterated by two elements at a time");
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(
          stack_.begin() + static_cast<std::ptrdiff_t>(start), stack_.end());
      stack_.emplace_back(std::move(dict));
    } break;
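    // SETITEMS pops key/value pairs pushed since the most recent MARK and
    // inserts them into the dict that sits just below the mark.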
    case PickleOpCode::SETITEMS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      auto dict = stack_.at(start - 1).toGenericDict();
      TORCH_CHECK(
          (stack_.size() - start) % 2 == 0,
          "Parsing error: stack_ is of size ",
          stack_.size(),
          " and start index is ",
          start,
          ", but stack_ is iterated by two elements at a time");
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(
          stack_.begin() + static_cast<std::ptrdiff_t>(start), stack_.end());
    } break;
    case PickleOpCode::BINGET: {
      auto pos = read<uint8_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::LONG_BINGET: {
      auto pos = read<uint32_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::STOP:
      break;
    case PickleOpCode::GLOBAL: {
      // Module name, it's not needed for anything
      auto module_name = readString();
      auto class_name = readString();
      readGlobal(module_name, class_name);
    } break;
    case PickleOpCode::NEWOBJ: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      // pop empty tuple, the actual action is stored in the globals_stack_
      stack_.pop_back();
    } break;
    // because we have NEWOBJ do nothing, BUILD and REDUCE end up doing
    // the same thing
    case PickleOpCode::BUILD:
    case PickleOpCode::REDUCE: {
      // stack is: <functor_idx> <functor_arg>
      // extract <functor_idx> and remove it from the stack:
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      std::swap(*(stack_.end() - 2), *(stack_.end() - 1));
      size_t idx = stack_.back().toInt();
      stack_.pop_back();
      // stack is: <functor_arg>
      TORCH_CHECK(
          idx < globals_.size(),
          "Parsing error: out of bounds access to globals_");
      globals_.at(idx)();
    } break;
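    // BINPERSID rebuilds a tensor storage from a persistent-id tuple laid out
    // as ("storage", scalar type, record key, device string, numel). The raw
    // bytes are fetched through read_record_ unless a storage_context_ has
    // already materialized the storage for this key.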
    case PickleOpCode::BINPERSID: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      auto tuple = pop(stack_).toTuple();
      const auto& args = tuple->elements();
      AT_ASSERT(
          args.at(0).toStringRef() == "storage",
          "unknown PERSID key ",
          args.at(0).toStringRef());
      at::ScalarType type = args.at(1).toScalarType();
      const std::string& key = args.at(2).toStringRef();

      at::Device device(args.at(3).toStringRef());
      // remap device location if it's not meta
      if (device_ && !device.is_meta()) {
        device = *device_;
      }

      at::Storage storage;
      if (storage_context_ != nullptr && storage_context_->hasStorage(key)) {
        // for torch.package logic where storage may be loaded already
        storage = storage_context_->getStorage(key);
      } else {
        int64_t numel = args.at(4).toInt();
        caffe2::TypeMeta dtype = at::CPU(type).typeMeta();

        at::DataPtr storage_ptr;
        if (numel > 0) {
          // If there are no elements in the tensor, there's no point in
          // reading a zero (0) byte file from the input stream and paying
          // that cost.
          storage_ptr = read_record_(key);
        }

        storage = at::Storage(
            c10::Storage::use_byte_size_t(),
            numel * dtype.itemsize(),
            std::move(storage_ptr),
            /*allocator=*/nullptr,
            /*resizable=*/false); // NB: we didn't set any allocator for the
                                  // tensor
        if (storage_context_ != nullptr) {
          storage_context_->addStorage(key, storage);
        }
      }

      auto options = at::CPU(type).options();
      if (use_storage_device_) {
        options = options.device(storage.device());
        device = storage.device();
      }

      at::Tensor tensor;
      if (options.backend() == c10::Backend::QuantizedCPU) {
        tensor = at::_empty_affine_quantized({}, options, 0, 0)
                     .set_(storage, 0, {}, {});
      } else {
        tensor = at::empty({0}, options).set_(storage);
      }

      if (device.is_cuda() || device.is_xpu() || device.is_meta() ||
          device.is_hpu() || device.is_mps() || device.is_privateuseone()) {
        tensor = tensor.to(device, tensor.scalar_type());
      } else if (device.type() != DeviceType::CPU) {
        AT_ERROR(
            "supported devices include CPU, CUDA, HPU and ",
            c10::get_privateuse1_backend(),
            " however got ",
            DeviceTypeName(device.type(), false));
      }
      stack_.emplace_back(std::move(tensor));
    } break;
    case PickleOpCode::SETITEM: {
      // At this OpCode, stack looks like
      //  | Stack Bottom |
      //  | ......       |
      //  | Dict         | -> (stack_size - 3)
      //  | Key          | -> (stack_size - 2)
      //  | Value        | -> (stack_size - 1)
      TORCH_CHECK(
          stack_.size() >= 3,
          "Parsing error: stack doesn't have enough elements");

      auto stack_size = stack_.size();
      auto dict_pos = stack_size - 3;
      auto key_pos = stack_size - 2;
      auto val_pos = stack_size - 1;

      TORCH_CHECK(
          (dict_pos < stack_size) && (key_pos < stack_size) &&
              (val_pos < stack_size),
          "Parsing error: attempted out-of-bounds access while processing SETITEM opcode");

      auto dict = stack_.at(dict_pos).toGenericDict();
      dict.insert_or_assign(stack_.at(key_pos), stack_.at(val_pos));
      stack_.erase(
          stack_.begin() + static_cast<std::ptrdiff_t>(key_pos), stack_.end());
    } break;
    default: {
      AT_ERROR(
          "Unknown opcode for unpickling at ",
          // NOLINTNEXTLINE(performance-no-int-to-ptr)
          reinterpret_cast<void*>(opcode),
          ": ",
          int(static_cast<uint8_t>(opcode)));
    } break;
  }
  return opcode;
}
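// Resolve a GLOBAL reference of the form `module_name.class_name`. In most
// branches this does not construct a value immediately; instead it appends a
// closure to globals_ and leaves the closure's index on the stack, so that a
// later REDUCE or BUILD opcode pops the index and invokes the closure on the
// reduce argument that remains on the stack.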
void Unpickler::readGlobal(
    const std::string& module_name,
    const std::string& class_name) {
  if (this->skip_next_read_global) {
    // See [NOTE] skip_next_read_global
    this->skip_next_read_global--;
    if (this->skip_next_read_global == 1) {
      // Pass through to the correct handler
    } else if (this->skip_next_read_global == 0) {
      // Corresponds to the type of `Tensor` being unpickled
      if (module_name != "torch" || class_name != "Tensor") {
        TORCH_WARN(
            "Trying to load a Subclassed Tensor, it will be converted to at::Tensor in C++");
      }
      stack_.emplace_back(int64_t(globals_.size() - 1));
      return;
    } else {
      TORCH_CHECK(false, "INVALID VALUES");
    }
  }
  // TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this
  // is only here for bc-compatibility reasons
  if (module_name == "__main__") {
    if (class_name == "TensorID") {
      globals_.emplace_back([this] {
        auto setitem_data = stack_.back();
        stack_.pop_back();
        TORCH_INTERNAL_ASSERT(
            !tensor_table_.empty(),
            "Pickler tried to write a tensor but had no tensor table to write to");
        stack_.emplace_back(tensor_table_.at(setitem_data.toInt()));
      });
    } else if (class_name == "IntList") {
      globals_.emplace_back([this] {
        stack_.back().toList().unsafeSetElementType(IntType::get());
      });
    } else {
      AT_ERROR("Unknown pickler class id ", class_name);
    }
  } else if (module_name == "torch.jit._pickle") {
    if (class_name == "build_tensor_from_id") {
      globals_.emplace_back([this] {
        // Pop reduce arg off the stack
        auto data = stack_.back().toTupleRef().elements().at(0);
        stack_.pop_back();
        TORCH_CHECK(
            !tensor_table_.empty(),
            "Found a tensor table reference but Unpickler"
            " has no tensor table\n");
        stack_.emplace_back(tensor_table_.at(data.toInt()));
      });
    } else if (class_name == "restore_type_tag") {
      globals_.emplace_back([this] {
        auto tuple = stack_.back().toTuple();
        const auto& data = tuple->elements();
        auto type_str = data.at(1).toStringRef();
        stack_.pop_back();
        TypePtr type = nullptr;
        auto entry = type_cache_.find(type_str);
        if (entry != type_cache_.end()) {
          type = entry->second;
        } else {
          if (type_resolver_ == nullptr) {
            // If we haven't injected a custom way of retrieving types from
            // names, use a barebones type parser.
            type = type_parser_(type_str);
          } else {
            type = type_resolver_(type_str).type_;
          }
          type_cache_[type_str] = type;
        }
        // TODO: Use lookahead to avoid creating the tuple and immediately
        // destroying it here
        restoreContainerTypeTags(data.at(0), type);
        stack_.emplace_back(data.at(0));
      });
    } else {
      TypePtr elem_type = nullptr;
      if (class_name == "build_intlist") {
        elem_type = IntType::get();
      } else if (class_name == "build_tensorlist") {
        elem_type = TensorType::get();
      } else if (class_name == "build_doublelist") {
        elem_type = FloatType::get();
      } else if (class_name == "build_boollist") {
        elem_type = BoolType::get();
      } else {
        AT_ERROR("Unknown pickler class id ", class_name);
      }
      // Unpickle a list specialization (e.g. List[Tensor], List[int], ...)
      globals_.emplace_back([this, elem_type] {
        // Pop reduce arg off the stack
        auto data = stack_.back().toTupleRef().elements().at(0).toList();
        stack_.pop_back();
        data.unsafeSetElementType(elem_type);
        stack_.emplace_back(std::move(data));
      });
    }
  } else if (
      module_name == "torch._utils" &&
      (class_name == "_rebuild_tensor_v2" ||
       class_name == "_rebuild_qtensor")) {
    // Unpickle a tensor
    bool quantized = class_name == "_rebuild_qtensor";
    rebuildTensor(quantized);
  } else if (
      module_name == "torch._tensor" &&
      (class_name == "_rebuild_from_type_v2")) {
    // Unpickle a Tensor with Python attributes or
    // a Subclassed Tensor.
    rebuildTensorFromTypeV2();
  } else if (
      module_name == "torch._utils" &&
      class_name == "_rebuild_sparse_tensor") {
    rebuildSparseTensor();
  } else if (module_name == "builtins" && class_name == "complex") {
    globals_.emplace_back([this] {
      auto tuple = pop(stack_).toTuple();
      const auto& elems = tuple->elements();
      AT_ASSERT(elems.size() == 2);
      auto complex =
          c10::complex<double>(elems.at(0).toDouble(), elems.at(1).toDouble());
      stack_.emplace_back(complex);
    });
  } else if (module_name == "collections" && class_name == "OrderedDict") {
    // collections.OrderedDict is used in tensor serialization for a tensor's
    // backward hooks (but they are not actually saved with this Pickler)
    globals_.emplace_back([this] {
      // drop the Tuple that was argument to OrderedDict, and replace it
      // with None. OrderedDicts only appear in tensor deserialization and
      // their value is never used.
      stack_.back() = IValue();
    });
  } else if (module_name == "torch" && class_name == "device") {
    globals_.emplace_back([this] {
      auto device_string = stack_.back().toTupleRef().elements().at(0);
      stack_.pop_back();
      stack_.emplace_back(c10::Device(device_string.toStringRef()));
    });
    stack_.emplace_back(int64_t(globals_.size() - 1));
    return;
  } else if (module_name == "torch.distributed.rpc" && class_name == "rref") {
#ifdef USE_RPC
    return rebuildRRef();
#else
    TORCH_INTERNAL_ASSERT(
        false,
        "RRef unpickling is only supported with the distributed package");
#endif
  } else if (module_name == "torch") {
    // Try to manually resolve several global enums
    // NOTE: this does not put a global into the global table,
    // like the other branches here because no REDUCE or BUILD will
    // be called on this value. Instead, we just put it on the stack
    // and return early
    std::optional<c10::ScalarType> scalar_type;
#define CHECK_SCALAR(_, name)          \
  if (class_name == #name "Storage") { \
    scalar_type = c10::k##name;        \
  }
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CHECK_SCALAR)
#undef CHECK_SCALAR
    if (scalar_type.has_value()) {
      stack_.emplace_back(int64_t(*scalar_type));
      return;
    }

    std::optional<at::QScheme> qscheme;
    for (int i = 0; i < at::COMPILE_TIME_NUM_QSCHEMES; ++i) {
      if (class_name == toString(static_cast<at::QScheme>(i))) {
        qscheme = static_cast<at::QScheme>(i);
      }
    }
    if (qscheme.has_value()) {
      stack_.emplace_back(int64_t(*qscheme));
      return;
    }
    TORCH_CHECK(
        false,
        "Unpickler found unknown torch global, 'torch.",
        class_name,
        "'");
  } else {
    TORCH_CHECK(
        type_resolver_,
        "Unpickler found unknown type ",
        module_name,
        ".",
        class_name);
    at::StrongTypePtr type =
        type_resolver_(c10::QualifiedName(module_name, class_name));
    if (auto enum_type = type.type_->cast<EnumType>()) {
      globals_.emplace_back([this, enum_type] {
        auto val = stack_.back();
        stack_.pop_back();
        for (const auto& p : enum_type->enumNamesValues()) {
          if (p.second == val) {
            auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
                enum_type, p.first, p.second);
            stack_.emplace_back(std::move(enum_holder));
            return;
          }
        }
      });
    } else {
      // Otherwise, global is a class/object type.
      globals_.emplace_back([this, type] {
        auto val = stack_.back();
        stack_.pop_back();
        auto obj = obj_loader_(type, val);
        stack_.emplace_back(std::move(obj));
      });
    }
  }
  stack_.emplace_back(int64_t(globals_.size() - 1));
}
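// Registers the closure for torch._utils._rebuild_sparse_tensor. The reduce
// argument is a tuple whose first element selects the layout; for sparse COO
// it is followed by (size, requires_grad, indices, values) and for sparse CSR
// by (size, requires_grad, crow_indices, col_indices, values).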
void Unpickler::rebuildSparseTensor() {
  globals_.emplace_back([this] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    auto layout = elements.at(idx++).toInt();
    at::Tensor result;
    switch (layout) {
      case static_cast<int64_t>(c10::Layout::Sparse): {
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& indices_tensor = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::Sparse)
                           .requires_grad(requires_grad);
        result = at::_sparse_coo_tensor_unsafe(
            indices_tensor, values_tensor, size, options);
        result = autograd::make_variable(result, options.requires_grad());
        break;
      }
      case static_cast<int64_t>(c10::Layout::SparseCsr): {
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& crow_indices = elements.at(idx++).toTensor();
        auto& col_indices = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::SparseCsr)
                           .requires_grad(requires_grad);
        result = at::_sparse_csr_tensor_unsafe(
            crow_indices, col_indices, values_tensor, size, options);
        result = autograd::make_variable(
            std::move(result), options.requires_grad());
        break;
      }
      default:
        TORCH_CHECK(
            false,
            "Unsupported sparse tensor layout type in serialization ",
            static_cast<int64_t>(layout));
        break;
    }
    stack_.emplace_back(std::move(result));
  });
}
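// Registers the closure for torch._utils._rebuild_tensor_v2 (and
// _rebuild_qtensor when `quantized` is true). The reduce argument is laid out
// as (storage tensor, storage offset, size, stride, [quantizer params,]
// requires_grad, backward hooks[, metadata dict]); the closure points a
// freshly created tensor at the deserialized storage and restores its view
// metadata.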
void Unpickler::rebuildTensor(bool quantized) {
  globals_.emplace_back([this, quantized] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    auto& storage_tensor = elements.at(idx++).toTensor();
    int64_t storage_offset = elements.at(idx++).toInt();
    std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
    std::vector<int64_t> stride = tupleToIntList(elements.at(idx++));
    at::Tensor result;
    if (quantized) {
      auto qparams_tuple = elements.at(idx++).toTuple();
      const auto& qparams = qparams_tuple->elements();
      auto qscheme = static_cast<at::QScheme>(qparams.at(0).toInt());
      switch (qscheme) {
        case at::kPerTensorAffine: {
          double q_scale = qparams.at(1).toDouble();
          int64_t q_zero_point = qparams.at(2).toInt();
          result = at::_empty_affine_quantized(
              {0}, storage_tensor.options(), q_scale, q_zero_point);
        } break;
        case at::kPerChannelAffineFloatQParams:
        case at::kPerChannelAffine: {
          const auto& scales = qparams.at(1).toTensor();
          const auto& zero_points = qparams.at(2).toTensor();
          int64_t axis = qparams.at(3).toInt();
          result = at::_empty_per_channel_affine_quantized(
              {0}, scales, zero_points, axis, storage_tensor.options());
        } break;
        default:
          TORCH_CHECK(
              false,
              "Unsupported tensor quantization type in serialization ",
              toString(qscheme));
          break;
      }
    } else {
      result = at::empty({0}, storage_tensor.options());
    }
    bool requires_grad = elements.at(idx++).toBool();
    idx++; // backwards hooks is empty
    at::TensorImpl* impl = result.unsafeGetTensorImpl();
    impl->set_storage_keep_dtype(storage_tensor.storage());
    impl->set_storage_offset(storage_offset);
    impl->set_sizes_and_strides(size, stride);
    result = autograd::make_variable(result, requires_grad);

    // Handle if math_bits were pickled.
    // See `args` of _reduce_ex_internal
    // for a regular tensor (final else case).
    // Tensors pickled before this patch didn't
    // have this argument for storing MathBits,
    // in that case, we do nothing.
    // NOTE: `math_bits` is the 7th arg.
    // NOTE: This is only meant for regular tensor and not quantized
    // which also has 7 args serialized.
    if (!quantized && elements.size() == 7) {
      auto math_bits = elements.at(idx++).toGenericDict();
      torch::jit::setTensorMetadata(result, math_bits);
    }
    stack_.emplace_back(std::move(result));
  });
}

void Unpickler::rebuildTensorFromTypeV2() {
  // [NOTE] skip_next_read_global
  // When rebuilding Tensor with Python Attr or Subclassed Tensor,
  // we receive `(func, type(self), args, state)` on stack for
  // `rebuildTensorFromTypeV2`.
  // Thus next call to readGlobal corresponds to `func` which is
  // the function to rebuild the base tensor.
  // The call after `func` to readGlobal corresponds to `type` of the
  // Tensor where we raise warning if the type is not `torch.Tensor`.
  this->skip_next_read_global = 2;
  auto curr_globals_idx = globals_.size();
  globals_.emplace_back([this, curr_globals_idx] {
    // args is a tuple with following data
    //  (function to rebuild base tensor, type of tensor,
    //   arguments to construct base tensor, Python State (as dict))
    auto args = pop(stack_).toTuple();
    size_t tup_idx = 0;
    const auto args_elems = args->elements();
    auto base_tensor_args = args_elems.at(tup_idx + 2).toTuple();
    auto py_state = args_elems.at(tup_idx + 3).toGenericDict();
    if (!py_state.empty()) {
      TORCH_WARN(
          "Loading Tensor with Python attributes will return at::Tensor with Python attributes being discarded");
    }
    // This calls the function to rebuild the
    // base tensor.
    // Eg. `rebuildTensor`, `rebuildSparseTensor`.
    stack_.emplace_back(base_tensor_args);
    globals_[curr_globals_idx + 1]();
    stack_.emplace_back(pop(stack_));
  });
}

#ifdef USE_RPC
void Unpickler::rebuildRRef() {
  globals_.emplace_back([this] {
    // It is the same as how rref is unpickled in python,
    // see PyRRef::unpickle
    auto tuple = std::move(stack_.back()).toTuple();
    const auto& args = tuple->elements();
    stack_.pop_back();
    TORCH_INTERNAL_ASSERT(
        args.size() == distributed::rpc::RFD_TUPLE_SIZE,
        "Pickled RRefForkData must contain 7 numbers.");
    auto ownerId =
        static_cast<int16_t>(args.at(distributed::rpc::OWNER_IDX).toInt());
    // const reference will extend the lifetime of the temporary variable
    const auto& rrefId = distributed::rpc::RRefId(
        static_cast<int16_t>(args.at(distributed::rpc::RREFID_ON_IDX).toInt()),
        static_cast<int64_t>(args.at(distributed::rpc::RREFID_ID_IDX).toInt()));
    const auto& forkId = distributed::rpc::RRefId(
        static_cast<int16_t>(args.at(distributed::rpc::FORKID_ON_IDX).toInt()),
        static_cast<int64_t>(args.at(distributed::rpc::FORKID_ID_IDX).toInt()));
    auto parent =
        static_cast<int16_t>(args.at(distributed::rpc::PARENT_IDX).toInt());
    const auto& typeStr = static_cast<std::string>(
        args.at(distributed::rpc::TYPE_IDX).toStringRef());
    auto rrefForkData = distributed::rpc::RRefForkData(
        ownerId, rrefId, forkId, parent, typeStr);
    auto& ctx = distributed::rpc::RRefContext::getInstance();
    c10::intrusive_ptr<distributed::rpc::RRef> rref;
    TORCH_INTERNAL_ASSERT(
        type_resolver_ != nullptr, "type_resolver_ is nullptr.");
    at::StrongTypePtr type = type_resolver_(c10::QualifiedName(typeStr));
    rref = ctx.getOrCreateRRef(rrefForkData, type.type_);
    ctx.notifyOwnerAndParentOfFork(
        rrefForkData.forkId_, rrefForkData.parent_, rref);
    stack_.emplace_back(
        c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref));
  });
  stack_.emplace_back(int64_t(globals_.size() - 1));
  return;
}
#endif

void Unpickler::readSlowWithBuffer(char* dest, size_t sz) {
  // First, read any partial from buffer (may be 0).
  // We explicitly assume that sz > buffer_remaining_,
  // and that sz is never bigger than buffer_.size().
  AT_ASSERT(sz > buffer_remaining_);
  const size_t from_old_buf = buffer_remaining_;
  if (from_old_buf != 0) {
    memcpy(dest, buffer_.data() + buffer_pos_, from_old_buf);
  }
  const size_t needed = sz - from_old_buf;
  // Full read into the buffer. The calls here all explicitly
  // assume that one buffer will be enough for any sz.
  AT_ASSERT(sz <= buffer_.size());
  buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
  if (buffer_remaining_ < needed) {
    AT_ERROR("Unexpected end of pickler archive.");
  }
  memcpy(dest + from_old_buf, buffer_.data(), needed);
  buffer_pos_ = needed; // assignment (0'ed from read)
  buffer_remaining_ -= needed;
}
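// Reads `length` bytes, preferring the internal buffer: if the data is fully
// buffered it is copied directly; small reads go through readSlowWithBuffer
// so the buffer gets refilled; large reads drain the buffer and then read the
// remainder straight into the destination string.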
// Read a number of bytes from the input stream
std::string Unpickler::readBytes(size_t length) {
  std::string data;
  static const size_t kSmallString = 64;
  TORCH_CHECK(
      length <= data.max_size(),
      "Parsing error: can't read ",
      length,
      " bytes to a string");
  if (length <= buffer_remaining_) {
    // Fast-path: entirely in buffer.
    data.assign(buffer_.data() + buffer_pos_, length);
    buffer_pos_ += length;
    buffer_remaining_ -= length;
  } else if (length <= kSmallString) {
    // If the string is smallish, do a full buffer read,
    // and read out of that buffer.
    data.resize(length);
    readSlowWithBuffer(&data[0], length);
  } else {
    // Otherwise, for larger strings, read what we can from
    // the buffer, and then read directly to the destination.
    const size_t from_old_buf = buffer_remaining_;
    if (from_old_buf != 0) {
      data.reserve(length);
      data.append(buffer_.data() + buffer_pos_, from_old_buf);
    }
    data.resize(length);
    const size_t needed = length - from_old_buf;
    size_t nread = reader_(&data[from_old_buf], needed);
    if (nread != needed) {
      AT_ERROR("Unexpected end of pickler archive.");
    }
    buffer_remaining_ = 0;
    // buffer_pos_ has no meaning with buffer_remaining_ == 0.
  }
  return data;
}

void Unpickler::readListElements(IValue list_ivalue, size_t start) {
  auto num_elements = stack_.size() - start;
  auto elements = c10::ArrayRef<IValue>(stack_).slice(start);
  if (list_ivalue.isIntList()) {
    auto list = std::move(list_ivalue).toIntList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toInt());
    }
  } else if (list_ivalue.isTensorList()) {
    auto list = std::move(list_ivalue).toTensorList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toTensor());
    }
  } else if (list_ivalue.isDoubleList()) {
    auto list = std::move(list_ivalue).toDoubleList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toDouble());
    }
  } else if (list_ivalue.isBoolList()) {
    auto list = std::move(list_ivalue).toBoolList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.push_back(elem.toBool());
    }
  } else if (list_ivalue.isList()) {
    auto list = std::move(list_ivalue).toList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem);
    }
  } else {
    AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind());
  }
  stack_.erase(
      stack_.begin() + static_cast<std::ptrdiff_t>(start), stack_.end());
}

// Pop all the list items off of the stack and append them to the list at
// the corresponding MARK
void Unpickler::readList(IValue list_ivalue) {
  TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
  size_t start = marks_.back();
  marks_.pop_back();
  readListElements(std::move(list_ivalue), start);
}
inline bool is_valid_python_id_char(char c) {
  return c == '_' || c == '.' || (c >= '0' && c <= '9') ||
      (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
}

// Read a newline terminated string
std::string Unpickler::readString() {
  std::string ss;
  while (true) {
    auto* const bufferStart = buffer_.data() + buffer_pos_;
    const auto bufferLeft = buffer_.size() - buffer_pos_;
    char* const newlinePtr =
        static_cast<char*>(memchr(bufferStart, '\n', bufferLeft));
    if (newlinePtr) {
      // read up to newline and we are done.
      auto const charsRead = newlinePtr - bufferStart;
      ss.append(bufferStart, charsRead);
      buffer_remaining_ -= charsRead + 1;
      buffer_pos_ += charsRead + 1;
      break;
    } else {
      // read whole buffer, refill
      for (const char* p = bufferStart; p < bufferStart + bufferLeft; ++p) {
        // Simple check just in case there is no terminating '\n'
        TORCH_CHECK(
            is_valid_python_id_char(*p),
            "Found character '",
            int(uint8_t(*p)),
            "' in string, ",
            "strings must be qualified Python identifiers");
      }
      ss.append(bufferStart, bufferLeft);
      buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
      buffer_pos_ = 0;
    }
  }
  return ss;
}

} // namespace torch::jit