#include <ATen/core/ivalue.h>
#include <ATen/core/Dict.h>
#include <ATen/core/Formatting.h>
#include <ATen/core/class_type.h>
#include <ATen/core/enum_type.h>
#include <ATen/core/function.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <ATen/core/type_factory.h>
#include <c10/util/StringUtil.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <cmath>
#include <iostream>
#include <utility>

namespace c10 {

bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs) {
  if (lhs.is(rhs)) {
    // Like Python, for containers we consider identity equality to be
    // sufficient but not necessary for value equality
    return true;
  }
  return lhs == rhs;
}

namespace ivalue {

// This is in ivalue.cpp because we need to access Type::annotation_str, which
// is declared in jit_type.h
void checkCustomClassType(
    const ClassType* expected_type,
    const Type* actual_type) {
  // NB: doing pointer comparison here
  // If in the future there ever arises a need to call operator== on custom
  // class Type's, this needs to be changed!
  TORCH_CHECK(
      actual_type == static_cast<const Type*>(expected_type),
      "Tried to convert an IValue of type ",
      actual_type ? actual_type->repr_str() : std::string("*NULL*"),
      " to custom class type ",
      expected_type ? expected_type->repr_str() : std::string("*NULL*"));
}

TORCH_API c10::intrusive_ptr<ConstantString> ConstantString::create(
    std::string str_) {
  return c10::make_intrusive<ConstantString>(std::move(str_));
}

TORCH_API c10::intrusive_ptr<ConstantString> ConstantString::create(
    c10::string_view str_) {
  return c10::make_intrusive<ConstantString>(std::string(str_));
}

TORCH_API c10::intrusive_ptr<ConstantString> ConstantString::create(
    const char* str_) {
  return c10::make_intrusive<ConstantString>(std::string(str_));
}

bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
  return lhs.size() == rhs.size() &&
      // see [container equality]
      std::equal(
             lhs.elements().cbegin(),
             lhs.elements().cend(),
             rhs.elements().cbegin(),
             _fastEqualsForContainer);
}

std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
  out << v.qualifiedClassName() << "." << v.name();
  return out;
}

bool operator==(const ivalue::EnumHolder& lhs, const ivalue::EnumHolder& rhs) {
  return lhs.name() == rhs.name() && *rhs.type() == *lhs.type();
}

const std::string& ivalue::EnumHolder::qualifiedClassName() const {
  return type_->qualifiedClassName().qualifiedName();
}

const std::string& ivalue::EnumHolder::unqualifiedClassName() const {
  return type_->qualifiedClassName().name();
}

} // namespace ivalue

c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
  switch (v.tag) {
    case Tag::None:
      return NoneType::get();
    case Tag::Tensor:
      return TensorType::create(v.toTensor());
    case Tag::Storage:
      return StorageType::get();
    case Tag::Double:
      return FloatType::get();
    case Tag::ComplexDouble:
      return ComplexType::get();
    case Tag::Int:
      return IntType::get();
    case Tag::SymInt:
      return c10::SymIntType::get();
    case Tag::SymFloat:
      return c10::SymFloatType::get();
    case Tag::SymBool:
      return c10::SymBoolType::get();
    case Tag::Bool:
      return BoolType::get();
    case Tag::String:
      return StringType::get();
    case Tag::Blob:
      return AnyType::get();
    case Tag::GenericDict: {
      auto d = v.toGenericDict();
      return DictType::create(d.keyType(), d.valueType());
    }
    case Tag::GenericList:
      return ListType::create(v.toList().elementType());
    case Tag::Await:
      return AwaitType::create(v.toAwait()->elementType());
    case Tag::Future:
      return FutureType::create(v.toFuture()->elementType());
    case Tag::RRef:
      return RRefType::create(v.toRRef()->type());
    case Tag::Device:
      return DeviceObjType::get();
    case Tag::Stream:
      return StreamObjType::get();
    case Tag::Object:
      return v.toObjectRef().type();
    case Tag::PyObject:
      return PyObjectType::get();
    case Tag::Uninitialized:
      return AnyType::get();
    case Tag::Capsule:
      return CapsuleType::get();
    case Tag::Tuple:
      return v.toTupleRef().type();
    case Tag::Generator:
      return GeneratorType::get();
    case Tag::Quantizer:
      return QuantizerType::get();
    case Tag::Enum:
      return v.toEnumHolder()->type();
  }
  // switch above is complete but this silences compiler warnings
  TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()");

  // This static_assert has to go into some IValue member function; I
  // chose this one. It's not in the class body because that's in
  // ivalue.h, which is a very high-fanout header file and we want to
  // minimize build time.
  static_assert(
      kNumTags <= 32,
      "IValue::isIntrusivePtr needs to be updated because it assumes there are at most 32 tags");
}
void IValue::visit(const std::function<bool(const IValue&)>& visitor) const {
  if (visitor(*this)) {
    // Shortcut
    return;
  }
  switch (this->tag) {
    case Tag::Tuple:
    case Tag::GenericList: {
      c10::ArrayRef<IValue> elems;
      if (isTuple()) {
        elems = this->toTupleRef().elements();
      } else {
        elems = this->toListRef();
      }
      for (auto& elem : elems) {
        elem.visit(visitor);
      }
      break;
    }
    case Tag::GenericDict:
      for (const auto& pair : this->toGenericDict()) {
        pair.value().visit(visitor);
        pair.key().visit(visitor);
      }
      break;
    case Tag::Object: {
      auto obj_type = type()->expect<ClassType>();
      auto obj_value = toObject();
      auto attributes = obj_type->getAttributes();
      for (const auto& attr : attributes) {
        auto attribute = obj_value->getAttr(attr.getName());
        attribute.visit(visitor);
      }
      break;
    }
    case Tag::PyObject: {
      c10::intrusive_ptr<ivalue::PyObjectHolder> py_obj = toPyObjectHolder();
      auto match = py_obj->tryToInferType();
      if (match.success()) {
        auto contained_value = py_obj->toIValue(match.type());
        contained_value.visit(visitor);
      }
      break;
    }
    default:
      break;
  }
}

void IValue::getSubValues(HashAliasedIValues& subValues) const {
  switch (this->tag) {
    case Tag::Tensor:
      subValues.insert(*this);
      return;
    case Tag::Tuple:
    case Tag::GenericList: {
      subValues.insert(*this);
      c10::ArrayRef<IValue> elems;
      if (isTuple()) {
        elems = this->toTupleRef().elements();
      } else {
        elems = this->toListRef();
      }
      for (auto& elem : elems) {
        elem.getSubValues(subValues);
      }
      break;
    }
    case Tag::GenericDict:
      subValues.insert(*this);
      for (const auto& pair : this->toGenericDict()) {
        pair.value().getSubValues(subValues);
        pair.key().getSubValues(subValues);
      }
      break;
    case Tag::Object: {
      // Record Object IValue and its attributes.
      subValues.insert(*this);
      auto obj_type = type()->expect<ClassType>();
      auto obj_value = toObject();
      auto attributes = obj_type->getAttributes();
      for (const auto& attr : attributes) {
        auto attribute = obj_value->getAttr(attr.getName());
        attribute.getSubValues(subValues);
      }
      break;
    }
    case Tag::PyObject: {
      subValues.insert(*this);
      c10::intrusive_ptr<ivalue::PyObjectHolder> py_obj = toPyObjectHolder();
      auto match = py_obj->tryToInferType();
      TORCH_CHECK_TYPE(
          match.success(),
          "Cannot infer type of ", py_obj->toStr(), ": ", match.reason());
      auto contained_value = py_obj->toIValue(match.type());
      contained_value.getSubValues(subValues);
      break;
    }
    case Tag::Future:
    case Tag::Await:
    case Tag::Device:
    case Tag::Uninitialized:
    case Tag::Capsule:
      TORCH_CHECK_TYPE(
          false, "Cannot inspect value of type ", this->tagKind());
    default:
      // don't record scalars.
      break;
  }
}
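// A minimal usage sketch for visit() (illustrative, not part of the upstream
// source): the visitor runs on every nested value, and returning true prunes
// traversal below that value. For example, counting tensors inside an
// arbitrarily nested IValue:
//
//   size_t num_tensors = 0;
//   value.visit([&](const IValue& v) {
//     if (v.isTensor()) {
//       ++num_tensors;
//     }
//     return false; // keep descending into children
//   });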
bool IValue::overlaps(const IValue& rhs) const {
  HashAliasedIValues rhsSubValues, thisSubValues;
  rhs.getSubValues(rhsSubValues);
  getSubValues(thisSubValues);
  for (auto& sub : thisSubValues) {
    if (rhsSubValues.count(sub)) {
      return true;
    }
  }
  return false;
}

bool operator!=(const IValue& lhs, const IValue& rhs) {
  return !(lhs == rhs);
}

bool operator==(const IValue& lhs, const IValue& rhs) {
  IValue eq = lhs.equals(rhs);
  if (eq.isBool()) {
    return eq.toBool();
  }
  // The only case we don't return bool is for tensor comparison. In Python,
  // `bool()` is called on the return value of `__eq__` if the return value is
  // not a boolean. Mimic that behavior here.
  TORCH_INTERNAL_ASSERT(eq.isTensor());
  return eq.toTensor().is_nonzero();
}

bool IValue::ptrEqual(const IValue& lhs, const IValue& rhs) {
  TORCH_INTERNAL_ASSERT(lhs.isIntrusivePtr());
  TORCH_INTERNAL_ASSERT(rhs.isIntrusivePtr());
  return lhs.tag == rhs.tag &&
      lhs.payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
}

IValue IValue::equals(const IValue& rhs) const {
  const IValue& lhs = *this;
  switch (lhs.tag) {
    case Tag::None:
      // In Python you're not supposed to do this comparison apparently. Not
      // sure if we should warn here or what
      return rhs.isNone();
    case Tag::Tensor: {
      if (!rhs.isTensor()) {
        return false;
      }
      return lhs.toTensor().eq(rhs.toTensor());
    }
    case Tag::Storage:
      return rhs.isStorage() &&
          lhs.toStorage().unsafeGetStorageImpl() ==
          rhs.toStorage().unsafeGetStorageImpl();
    case Tag::Double:
      return rhs.isDouble() && lhs.toDouble() == rhs.toDouble();
    case Tag::ComplexDouble:
      return rhs.isComplexDouble() &&
          lhs.toComplexDouble() == rhs.toComplexDouble();
    case Tag::Int:
      return rhs.isInt() && lhs.toInt() == rhs.toInt();
    case Tag::SymInt:
      return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt();
    case Tag::SymFloat:
      return rhs.isSymFloat() && lhs.toSymFloat() == rhs.toSymFloat();
    case Tag::SymBool:
      return rhs.isSymBool() && lhs.toSymBool() == rhs.toSymBool();
    case Tag::Bool:
      return rhs.isBool() && lhs.toBool() == rhs.toBool();
    case Tag::String:
      return rhs.isString() && lhs.toStringRef() == rhs.toStringRef();
    case Tag::GenericDict:
      return rhs.isGenericDict() && lhs.toGenericDict() == rhs.toGenericDict();
    case Tag::Tuple:
      return rhs.isTuple() && *lhs.toTuple() == *rhs.toTuple();
    case Tag::Stream:
      return rhs.isStream() && lhs.toStream() == rhs.toStream();
    case Tag::Device:
      return rhs.isDevice() && lhs.toDevice() == rhs.toDevice();
    case Tag::GenericList:
      return rhs.isList() && lhs.toList() == rhs.toList();
    case Tag::Blob:
    case Tag::Future:
    case Tag::Await:
    case Tag::RRef:
    case Tag::Object:
    case Tag::PyObject:
    case Tag::Capsule:
    case Tag::Generator:
    case Tag::Quantizer:
      return ptrEqual(lhs, rhs);
    case Tag::Enum:
      return lhs.toEnumHolder()->is(*rhs.toEnumHolder());
    case Tag::Uninitialized:
      // Uninitialized ivalues show up in no-ops when the compiler can prove a
      // value will never be used. Just return false on any equality
      // comparison.
      return false;
  }
  // the above switch should be exhaustive
  TORCH_INTERNAL_ASSERT(false, "we should never reach here");
}
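// To summarize the three equality notions above (a hedged gloss of the Python
// semantics this file emulates):
//   lhs.is(rhs)     -- identity, like Python `is` (see IValue::is below;
//                      undefined tensors additionally alias None)
//   lhs.equals(rhs) -- value equality, returned as an IValue because tensor
//                      equality produces an element-wise Tensor
//   lhs == rhs      -- plain bool; calls equals() and, for tensors, collapses
//                      the result via is_nonzero(), as Python's bool() would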
size_t IValue::hash(const IValue& v) {
  switch (v.tag) {
    case Tag::None:
      return 0;
    case Tag::Bool:
      return c10::get_hash(v.payload.u.as_bool);
    case Tag::Double:
      return c10::get_hash(v.payload.u.as_double);
    case Tag::Tensor:
      // Tensor __hash__ is equivalent to `id()`, so take the pointer value of
      // the tensor to emulate it
      return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl());
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case Tag::Storage:
      return c10::get_hash(v.payload.u.as_int);
    case Tag::Int:
      return c10::get_hash(v.payload.u.as_int);
    // NB: these are technically strict aliasing violations
    case Tag::SymInt:
      return c10::get_hash(v.payload.u.as_int);
    case Tag::SymFloat:
      return c10::get_hash(v.payload.u.as_int);
    case Tag::SymBool:
      return c10::get_hash(v.payload.u.as_int);
    case Tag::String:
      return c10::get_hash(v.toStringRef());
    case Tag::Tuple:
      return c10::get_hash(*v.toTuple());
    case Tag::Device:
      return c10::get_hash(v.toDevice());
    case Tag::GenericDict:
    case Tag::GenericList:
    case Tag::Blob:
    case Tag::Future:
    case Tag::Await:
    case Tag::RRef:
    case Tag::Object:
    case Tag::PyObject:
    case Tag::Capsule:
    case Tag::Generator:
    case Tag::Quantizer:
    case Tag::ComplexDouble:
    case Tag::Enum:
    case Tag::Stream:
    case Tag::Uninitialized:
      throw std::runtime_error(
          "unhashable type: '" + v.type()->repr_str() + "'");
  }
  // the above switch should be exhaustive
  TORCH_INTERNAL_ASSERT(false, "we should never reach here");
}

static bool isUndefinedTensor(const IValue& iv) {
  return iv.isTensor() && !iv.toTensor().defined();
}

bool IValue::is(const IValue& rhs) const {
  const IValue& lhs = *this;
  // Special handling for undefined tensors:
  // 1. Undefined_tensor is None and vice versa.
  if ((isUndefinedTensor(lhs) && rhs.isNone()) ||
      (lhs.isNone() && isUndefinedTensor(rhs))) {
    return true;
  }
  // 2. Undefined_tensor is Undefined_tensor.
  if (isUndefinedTensor(lhs) && isUndefinedTensor(rhs)) {
    return true;
  }

  if (lhs.isTensor()) {
    // Use the standard way of comparing two tensors for identity
    return rhs.isTensor() && lhs.toTensor().is_same(rhs.toTensor());
  }

  if (lhs.isIntrusivePtr()) {
    return rhs.isIntrusivePtr() && ptrEqual(lhs, rhs);
  }
  return lhs == rhs;
}
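// Note that hash() and equals() deliberately disagree for tensors, just as
// id()-based hashing does in Python: two tensors with equal contents but
// distinct TensorImpls hash differently. An illustrative (assumed) example:
//
//   at::Tensor t = at::ones({2, 2});
//   IValue a(t);
//   IValue b(t.clone());
//   // IValue::hash(a) != IValue::hash(b) in general, even though
//   // (a == b) compares contents element-wise and yields true here.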
template <typename T>
inline bool IValue::isListOf() const {
  // note: avoids calling type() to avoid extra reference counting for the
  // returned type.
  if (!isList()) {
    return false;
  }
  const auto& ty =
      static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
  if (ty->kind() == T::Kind) {
    return true;
  }
  return *ty == *TypeFactory::get<T>();
}

bool IValue::isDoubleList() const {
  return isListOf<c10::FloatType>();
}

bool IValue::isComplexDoubleList() const {
  return isListOf<c10::ComplexType>();
}

bool IValue::isTensorList() const {
  return isListOf<c10::TensorType>();
}

bool IValue::isOptionalTensorList() const {
  if (!isList()) {
    return false;
  }
  const auto& ty =
      static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
  const auto& expected_ty = c10::getTypePtr<std::optional<at::Tensor>>();
  return expected_ty == ty;
}

bool IValue::isIntList() const {
  return isListOf<c10::IntType>();
}

bool IValue::isSymIntList() const {
  return isListOf<c10::SymIntType>();
}

bool IValue::isBoolList() const {
  return isListOf<c10::BoolType>();
}

namespace {

using IValueFormatter = std::function<void(std::ostream&, const IValue&)>;

template <class T>
std::ostream& printList(
    std::ostream& out,
    const T& list,
    const std::string& start,
    const std::string& finish,
    const IValueFormatter& formatter) {
  out << start;
  for (const auto i : c10::irange(list.size())) {
    if (i > 0) {
      out << ", ";
    }
    formatter(out, IValue(list[i]));
  }
  out << finish;
  return out;
}

// Properly disambiguate the type of an empty list
std::ostream& printMaybeAnnotatedList(
    std::ostream& out,
    const IValue& the_list,
    const IValueFormatter& formatter) {
  auto list_elem_type = the_list.type()->containedType(0);
  if (the_list.toListRef().empty() ||
      !elementTypeCanBeInferredFromMembers(list_elem_type)) {
    out << "annotate(" << the_list.type()->annotation_str() << ", ";
    printList(out, the_list.toListRef(), "[", "]", formatter);
    out << ")";
    return out;
  } else {
    return printList(out, the_list.toListRef(), "[", "]", formatter);
  }
}

template <typename Dict>
std::ostream& printDict(
    std::ostream& out,
    const Dict& v,
    const IValueFormatter& formatter) {
  out << "{";
  bool first = true;
  for (const auto& pair : v) {
    if (!first) {
      out << ", ";
    }
    formatter(out, pair.key());
    out << ": ";
    formatter(out, pair.value());
    first = false;
  }
  out << "}";
  return out;
}

} // namespace

// Properly disambiguate the type of an empty dict
static std::ostream& printMaybeAnnotatedDict(
    std::ostream& out,
    const IValue& the_dict,
    const IValueFormatter& formatter) {
  auto value_type = the_dict.type()->castRaw<DictType>()->getValueType();
  if (the_dict.toGenericDict().empty() ||
      !elementTypeCanBeInferredFromMembers(value_type)) {
    out << "annotate(" << the_dict.type()->annotation_str() << ",";
    printDict(out, the_dict.toGenericDict(), formatter) << ")";
  } else {
    return printDict(out, the_dict.toGenericDict(), formatter);
  }
  return out;
}
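// For example (illustrative): printing an empty List[int] through the
// annotated printers above yields "annotate(List[int], [])" rather than "[]",
// so the element type survives a round trip through the printed form, while a
// non-empty list whose element type is inferable prints as plain "[1, 2]".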
static std::ostream& printComplex(std::ostream& out, const IValue& v) {
  c10::complex<double> d = v.toComplexDouble();
  IValue real(d.real()), imag(std::abs(d.imag()));
  auto sign = d.imag() >= 0 ? "+" : "-";
  return out << real << sign << imag << "j";
}

std::ostream& IValue::repr(
    std::ostream& out,
    std::function<bool(std::ostream&, const IValue&)> customFormatter) const {
  // First check if the caller has provided a custom formatter. Use that if
  // possible.
  if (customFormatter(out, *this)) {
    return out;
  }

  const IValue& v = *this;
  // continue to use custom formatter in recursion
  auto formatter = [&](std::ostream& out, const IValue& input) {
    input.repr(out, customFormatter);
  };
  switch (v.tag) {
    case IValue::Tag::None:
      return out << v.toNone();
    case IValue::Tag::Double: {
      double d = v.toDouble();
      int c = std::fpclassify(d);
      if ((c == FP_NORMAL || c == FP_ZERO) && std::abs(d) < 1e10) {
        int64_t i = int64_t(d);
        if (double(i) == d) {
          // -0.0 (signed zero) needs to be parsed as -0.
          if (i == 0 && std::signbit(d)) {
            return out << "-" << i << ".";
          }
          return out << i << ".";
        }
      }
      auto orig_prec = out.precision();
      return out << std::setprecision(
                        std::numeric_limits<double>::max_digits10)
                 << d << std::setprecision(static_cast<int>(orig_prec));
    }
    case IValue::Tag::ComplexDouble: {
      return printComplex(out, v);
    }
    case IValue::Tag::Int:
      return out << v.toInt();
    case IValue::Tag::SymInt:
      return out << v.toSymInt();
    case IValue::Tag::SymFloat:
      return out << v.toSymFloat();
    case IValue::Tag::SymBool:
      return out << v.toSymBool();
    case IValue::Tag::Bool:
      return out << (v.toBool() ? "True" : "False");
    case IValue::Tag::Tuple: {
      const auto& elements = v.toTupleRef().elements();
      const auto& finish = elements.size() == 1 ? ",)" : ")";
      return printList(out, elements, "(", finish, formatter);
    }
    case IValue::Tag::String:
      c10::printQuotedString(out, v.toStringRef());
      return out;
    case IValue::Tag::GenericList: {
      return printMaybeAnnotatedList(out, *this, formatter);
    }
    case IValue::Tag::Device: {
      std::stringstream device_stream;
      device_stream << v.toDevice();
      out << "torch.device(";
      c10::printQuotedString(out, device_stream.str());
      return out << ")";
    }
    case IValue::Tag::Generator: {
      auto generator = v.toGenerator();
      out << "torch.Generator(device=";
      c10::printQuotedString(out, generator.device().str());
      out << ", seed=" << generator.current_seed() << ")";
      return out;
    }
    case IValue::Tag::GenericDict:
      return printMaybeAnnotatedDict(out, v, formatter);
    case IValue::Tag::Enum: {
      auto enum_holder = v.toEnumHolder();
      return out << enum_holder->qualifiedClassName() << "."
                 << enum_holder->name();
    }
    case IValue::Tag::Object: {
      TORCH_INTERNAL_ASSERT(
          false,
          "repr() not defined on: ",
          v.tagKind(),
          ". Perhaps you've frozen a module with custom classes?");
    }
    default:
      TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind());
  }
}

static bool simpleClassTypeArg(const Argument& arg, const ClassTypePtr& type) {
  return arg.type() == type && !arg.kwarg_only() && !arg.default_value();
}

torch::jit::Function* checkObjectSortSchema(
    const c10::ClassTypePtr& t,
    std::stringstream& why_not) {
  if (auto method = t->findMethod("__lt__")) {
    const auto& lt_schema = method->getSchema();
    const auto& schema_args = lt_schema.arguments();
    bool error =
        (schema_args.size() != 2 ||
         !simpleClassTypeArg(schema_args[0], t) ||
         !simpleClassTypeArg(schema_args[1], t) ||
         lt_schema.returns().size() != 1 ||
         lt_schema.returns()[0].type() != BoolType::get());
    if (!error) {
      return method;
    }
  }
  why_not << "To sort a list of " << t->repr_str() << " it must define a "
          << "__lt__ method with two inputs of type " << t->repr_str()
          << " that returns a bool";
  return nullptr;
}
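// For reference, a class accepted by checkObjectSortSchema defines, in
// TorchScript terms, a method shaped like (illustrative only):
//
//   def __lt__(self: MyClass, other: MyClass) -> bool: ...
//
// i.e. exactly two non-keyword, non-defaulted arguments of the class's own
// type and a single bool return.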
IValueComparator getLessThanComparator(const IValue& v) {
  if (v.isTensor()) {
    return [](const IValue& a, const IValue& b) {
      return a.toTensor().lt(b.toTensor()).is_nonzero();
    };
  }

  if (v.isDouble()) {
    return [](const IValue& a, const IValue& b) {
      return a.toDouble() < b.toDouble();
    };
  }

  if (v.isInt()) {
    return [](const IValue& a, const IValue& b) {
      return a.toInt() < b.toInt();
    };
  }

  if (v.isBool()) {
    return [](const IValue& a, const IValue& b) {
      return a.toBool() == false && b.toBool() == true;
    };
  }

  if (v.isString()) {
    return [](const IValue& a, const IValue& b) {
      return a.toStringRef() < b.toStringRef();
    };
  }

  if (v.isTuple()) {
    const auto& elements = v.toTupleRef().elements();
    size_t n = elements.size();

    std::vector<IValueComparator> elements_lts;
    elements_lts.reserve(n);
    for (const auto i : c10::irange(n)) {
      elements_lts.push_back(getLessThanComparator(elements[i]));
    }

    return [elements_lts = std::move(elements_lts), n](
               const IValue& a, const IValue& b) {
      const auto& a_elements = a.toTupleRef().elements();
      const auto& b_elements = b.toTupleRef().elements();

      for (const auto i : c10::irange(n)) {
        if (elements_lts[i](a_elements[i], b_elements[i])) {
          return true;
        }
        if (a_elements[i] == b_elements[i]) {
          continue;
        }
        return false;
      }
      // Reaching here means the two tuples are equal.
      return false;
    };
  }

  if (v.isObject()) {
    std::stringstream why_not;
    torch::jit::Function* lt_func =
        checkObjectSortSchema(v.type()->expect<ClassType>(), why_not);
    if (!lt_func) {
      AT_ERROR(why_not.str());
    }

    return [lt_func](const IValue& a, const IValue& b) {
      // Quick pass to satisfy "strict weak ordering" requirement
      if (a.is(b)) {
        return false;
      }
      torch::jit::Stack sort_stack;
      sort_stack.push_back(a);
      sort_stack.push_back(b);
      lt_func->run(sort_stack);
      return torch::jit::pop(sort_stack).toBool();
    };
  }

  AT_ERROR("IValues of type: ", v.tagKind(), " are not comparable");
}

IValueComparator getGreaterThanComparator(const IValue& v) {
  auto lt = getLessThanComparator(v);
  return [lt = std::move(lt)](const IValue& a, const IValue& b) {
    return lt(b, a); // gt(a, b) === lt(b, a)
  };
}
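// Usage sketch (an assumption about typical call sites, not taken from this
// file): these comparators plug directly into standard algorithms, e.g. to
// sort a homogeneous vector of IValues in ascending order:
//
//   std::vector<IValue> vals = ...;
//   if (!vals.empty()) {
//     std::sort(vals.begin(), vals.end(), getLessThanComparator(vals[0]));
//   }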
std::ostream& operator<<(std::ostream& out, const IValue& v) {
  auto formatter = [&](std::ostream& out, const IValue& v) {
    out << v;
  };
  switch (v.tag) {
    case IValue::Tag::None:
      return out << v.toNone();
    case IValue::Tag::Tensor:
      return out << v.toTensor();
    case IValue::Tag::Storage:
      return out << v.toStorage().unsafeGetStorageImpl();
    case IValue::Tag::Double: {
      double d = v.toDouble();
      int c = std::fpclassify(d);
      if (c == FP_NORMAL || c == FP_ZERO) {
        int64_t i = int64_t(d);
        if (double(i) == d) {
          return out << i << ".";
        }
      }
      auto orig_prec = out.precision();
      return out << std::setprecision(
                        std::numeric_limits<double>::max_digits10)
                 << v.toDouble()
                 << std::setprecision(static_cast<int>(orig_prec));
    }
    case IValue::Tag::ComplexDouble: {
      return printComplex(out, v);
    }
    case IValue::Tag::Int:
      return out << v.toInt();
    case IValue::Tag::SymInt:
      return out << v.toSymInt();
    case IValue::Tag::SymFloat:
      return out << v.toSymFloat();
    case IValue::Tag::SymBool:
      return out << v.toSymBool();
    case IValue::Tag::Bool:
      return out << (v.toBool() ? "True" : "False");
    case IValue::Tag::Tuple: {
      const auto& elements = v.toTupleRef().elements();
      const auto& finish = elements.size() == 1 ? ",)" : ")";
      return printList(out, elements, "(", finish, formatter);
    }
    case IValue::Tag::String:
      return out << v.toStringRef();
    case IValue::Tag::Blob:
      return out << *v.toBlob();
    case IValue::Tag::Capsule:
      return out << "Capsule";
    case IValue::Tag::GenericList:
      return printList(out, v.toList(), "[", "]", formatter);
    case IValue::Tag::RRef:
      return out << "RRef";
    case IValue::Tag::Future:
      return out << "Future";
    case IValue::Tag::Await:
      return out << "Await";
    case IValue::Tag::Uninitialized:
      return out << "Uninitialized";
    case IValue::Tag::Device:
      return out << v.toDevice();
    case IValue::Tag::Stream:
      return out << v.toStream();
    case IValue::Tag::GenericDict:
      return printDict(out, v.toGenericDict(), formatter);
    case IValue::Tag::PyObject: {
      auto py_obj = v.toPyObject();
      return out << "<PyObject at " << py_obj << ">";
    }
    case IValue::Tag::Generator:
      return out << "Generator";
    case IValue::Tag::Quantizer:
      return out << "Quantizer";
    case IValue::Tag::Object: {
      // TODO we should attempt to call __str__ if the object defines it.
      auto obj = v.toObject();
      // print this out the way python would do it
      return out << "<" << obj->name() << " object at " << obj.get() << ">";
    }
    case IValue::Tag::Enum: {
      auto enum_holder = v.toEnumHolder();
      return out << "Enum<" << enum_holder->unqualifiedClassName() << "."
                 << enum_holder->name() << ">";
    }
  }
  return out << "<Invalid IValue tag="
             << std::to_string(static_cast<uint32_t>(v.tag)) << ">";
}

#undef TORCH_FORALL_TAGS

void IValue::dump() const {
  std::cout << *this << "\n";
}

std::shared_ptr<ClassType> ivalue::Object::type() const {
  return type_.type_->expect<ClassType>();
}

c10::intrusive_ptr<ivalue::Object> ivalue::Object::create(
    ClassTypePtr classType,
    size_t numSlots) {
  return ivalue::Object::create(
      StrongTypePtr(nullptr, std::move(classType)), numSlots);
}

IValue IValue::deepcopy(std::optional<at::Device> device) const {
  IValue::HashIdentityIValueMap memo;
  return deepcopy(memo, device);
}

IValue IValue::deepcopy(
    IValue::HashIdentityIValueMap& memo,
    std::optional<at::Device> device) const {
  if (memo.count(*this)) {
    return memo.at(*this);
  }
  IValue copy;
  switch (tag) {
    case IValue::Tag::Tensor: {
      const at::Tensor& src_tensor = toTensor();
      copy = device.has_value() && !src_tensor.device().is_meta()
          ? IValue(src_tensor.to(*device))
          : IValue(src_tensor.clone());
    } break;
    case IValue::Tag::Tuple: {
      std::vector<IValue> copied_tuple;
      for (const auto& e : toTupleRef().elements()) {
        copied_tuple.emplace_back(e.deepcopy(memo, device));
      }
      copy = IValue(ivalue::Tuple::create(std::move(copied_tuple)));
    } break;
    case IValue::Tag::GenericList: {
      auto list = toList();
      auto copied_list = c10::impl::GenericList(list.elementType());
      for (IValue v : list) {
        copied_list.push_back(v.deepcopy(memo, device));
      }
      copy = IValue(copied_list);
    } break;
    case IValue::Tag::GenericDict: {
      auto dict = toGenericDict();
      auto copied_dict =
          c10::impl::GenericDict(dict.keyType(), dict.valueType());
      for (const auto& entry : dict) {
        copied_dict.insert(
            entry.key().deepcopy(memo, device),
            entry.value().deepcopy(memo, device));
      }
      copy = IValue(copied_dict);
    } break;
    case IValue::Tag::Object: {
      auto class_type = type()->expect<ClassType>();
      if (class_type->hasMethod("__getstate__") &&
          class_type->hasMethod("__setstate__")) {
        copy = ivalue::Object::create(
            c10::StrongTypePtr(class_type->compilation_unit(), type()),
            class_type->numAttributes());
        auto state = class_type->getMethod("__getstate__")({*this});
        class_type->getMethod("__setstate__")({copy, std::move(state)});
      } else {
        copy = IValue(toObject()->deepcopy(memo, device));
      }
    } break;
    case IValue::Tag::Enum: {
      auto enum_holder = toEnumHolder();
      copy = IValue(c10::make_intrusive<ivalue::EnumHolder>(
          enum_holder->type(),
          enum_holder->name(),
          enum_holder->value().deepcopy(memo, device)));
    } break;
    case IValue::Tag::String:
    case IValue::Tag::None:
    case IValue::Tag::Double:
    case IValue::Tag::Int:
    case IValue::Tag::SymInt:
    case IValue::Tag::SymFloat:
    case IValue::Tag::SymBool:
    case IValue::Tag::Bool:
    case IValue::Tag::Device:
    case IValue::Tag::Generator:
    case IValue::Tag::Uninitialized: {
      copy = *this;
    } break;
    default: {
      AT_ERROR("Can't deepcopy IValue with tag: ", tagKind());
    }
  }
  // NB: this doesn't work if an object contains itself. That may come up in
  // the future when we expand the object system; we will have a follow-up PR
  // to fix this when it becomes an issue.
  if (!isAliasOf(copy)) {
    memo[*this] = copy;
  }
  return copy;
}
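// The memo table threaded through deepcopy() preserves aliasing, much like
// the memo dict in Python's copy.deepcopy: if the same mutable list appears
// twice inside a tuple, both slots of the copy refer to a single copied list
// rather than two independent ones.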
void IValue::reportToTensorTypeError() const {
  TORCH_CHECK(false, "Expected Tensor but got ", tagKind());
}

std::string ivalue::Object::name() const {
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  return type()->name()->qualifiedName();
}

IValue ivalue::Object::getAttr(const std::string& name) const {
  const size_t slot = type()->getAttributeSlot(name);
  return getSlot(slot);
}

void ivalue::Object::setAttr(const std::string& name, IValue v) {
  const size_t slot = type()->getAttributeSlot(name);
  setSlot(slot, std::move(v));
}

void ivalue::Object::unsafeRemoveAttr(const std::string& name) {
  const size_t slot = type()->getAttributeSlot(name);
  unsafeRemoveSlot(slot);
}

void ivalue::Object::resizeObject(size_t slot) {
  AT_ASSERT(slot < type()->numAttributes());
  slots_.resize(type()->numAttributes());
}

c10::intrusive_ptr<ivalue::Object> ivalue::Object::copy() const {
  auto object = ivalue::Object::create(type_, type()->numAttributes());
  for (const auto i : c10::irange(slots_.size())) {
    object->setSlot(i, slots_[i]);
  }
  return object;
}

c10::intrusive_ptr<ivalue::Object> ivalue::Object::
    copy_to_weak_compilation_ref() const {
  auto object = ivalue::Object::create(
      WeakOrStrongTypePtr(type_.asWeakTypePtr()), type()->numAttributes());
  for (const auto i : c10::irange(slots_.size())) {
    object->setSlot(i, slots_[i]);
  }
  return object;
}

c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
    std::optional<at::Device> device) const {
  IValue::HashIdentityIValueMap memo;
  return deepcopy(memo, device);
}

c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
    IValue::HashIdentityIValueMap& memo,
    std::optional<at::Device> device) const {
  auto cu = type_.cu_;
  auto object = ivalue::Object::create(
      WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());
  for (const auto i : c10::irange(slots_.size())) {
    if (*slots_[i].type() == *c10::TypeFactory::get<CapsuleType>()) {
      // If we've gotten here, it means that we have *not* copied this
      // class via __getstate__ and __setstate__. That fact and the
      // fact that we have a Capsule attribute mean that this is a
      // custom C++ class without serialization methods defined.
      std::stringstream err;
      err << "Cannot serialize custom bound C++ class";
      if (auto qualname = type()->name()) {
        err << " " << qualname->qualifiedName();
      }
      err << ". Please define serialization methods via def_pickle() for "
             "this class.";
      AT_ERROR(err.str());
    }
    object->setSlot(i, slots_[i].deepcopy(memo, device));
  }
  return object;
}

StrongTypePtr::StrongTypePtr(
    std::shared_ptr<torch::jit::CompilationUnit> cu,
    TypePtr type)
    : cu_(std::move(cu)), type_(std::move(type)) {
  TORCH_INTERNAL_ASSERT(type_);
}

WeakTypePtr::WeakTypePtr(
    std::weak_ptr<torch::jit::CompilationUnit> cu,
    TypePtr type)
    : cu_(std::move(cu)), type_(std::move(type)) {}

WeakTypePtr WeakOrStrongTypePtr::asWeakTypePtr() const {
  if (!holds_strong_ref()) {
    return WeakTypePtr(cu_.getWeakRefOrThrow(), type_);
  } else {
    std::weak_ptr<torch::jit::CompilationUnit> weak_cu =
        cu_.getStrongRefOrThrow();
    return WeakTypePtr(std::move(weak_cu), type_);
  }
}
// Needs to be in this .cpp file to access the full definition of
// PyObjectHolder
std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>>
ivalue::Future::extractStorages(const at::IValue& value) {
  std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> weakStorageImpls;
  // getSubValues works poorly on Python objects: it only works if they can be
  // converted to a "regular" IValue type; hence, for example, it doesn't
  // support custom subclasses. Thus, instead, we extract the tensors through
  // pickling.
  if (value.isPyObject()) {
    std::vector<at::Tensor> tensors =
        value.toPyObjectHolder()->extractTensors();
    size_t num_storages = 0;
    for (const at::Tensor& tensor : tensors) {
      if (tensor.is_sparse()) {
        // Sparse tensor is indices and values. Both are tensors
        // and contain storage. Therefore num_storages needs to be
        // incremented by 2.
        num_storages += 2;
      } else {
        // A dense/strided tensor contains 1 storage.
        num_storages += 1;
      }
    }
    weakStorageImpls.reserve(num_storages);
    for (const at::Tensor& tensor : tensors) {
      if (tensor.is_sparse()) {
        // Sparse tensor is indices and values. Both are tensors
        // and contain storage.
        // TODO (rohan-varma): for tensors created with at::sparse_coo_tensor
        // held in a python object, this might need a coalesce().
        weakStorageImpls.emplace_back(
            tensor.indices().storage().getWeakStorageImpl());
        weakStorageImpls.emplace_back(
            tensor.values().storage().getWeakStorageImpl());
      } else {
        // A dense/strided tensor contains 1 storage
        weakStorageImpls.emplace_back(
            tensor.storage().getWeakStorageImpl());
      }
    }
  } else {
    at::IValue::HashAliasedIValues sub_values;
    // Prefer getSubValues() over visit() as the latter is a silent no-op for
    // some unsupported types, whereas the former at least fails loudly.
    value.getSubValues(sub_values);
    for (const at::IValue& sub_value : sub_values) {
      if (sub_value.isTensor()) {
        const auto& tens = sub_value.toTensor();
        if (tens.is_sparse()) {
          // sparse tensors have 2 storages: one for indices, one for values
          auto coalesced = tens.coalesce();
          weakStorageImpls.emplace_back(
              coalesced.indices().storage().getWeakStorageImpl());
          weakStorageImpls.emplace_back(
              coalesced.values().storage().getWeakStorageImpl());
        } else {
          weakStorageImpls.emplace_back(
              tens.storage().getWeakStorageImpl());
        }
      }
    }
  }
  return weakStorageImpls;
}

TORCH_API intrusive_ptr<ivalue::Future> collectAll(
    const List<intrusive_ptr<ivalue::Future>>& srcs) {
  struct Ctx {
    explicit Ctx(const List<intrusive_ptr<ivalue::Future>>& srcs)
        : remaining(srcs.size()),
          srcFutures(srcs),
          asIvalue(srcFutures),
          // No need to pass devices, because dstFuture won't directly contain
          // the value, it will contain the srcFutures (which have no
          // DataPtrs).
          dstFuture(make_intrusive<ivalue::Future>(asIvalue.type())) {}
    std::atomic<size_t> remaining{0};
    List<intrusive_ptr<ivalue::Future>> srcFutures;
    IValue asIvalue;
    intrusive_ptr<ivalue::Future> dstFuture;
  };

  auto ctx = std::make_shared<Ctx>(srcs);
  if (ctx->srcFutures.empty()) {
    ctx->dstFuture->markCompleted(ctx->asIvalue);
  } else {
    for (const auto i : c10::irange(ctx->srcFutures.size())) {
      std::function<void(ivalue::Future&)> func = [ctx](ivalue::Future& fut) {
        // Set error and exit early if encountered.
        if (fut.hasError()) {
          ctx->dstFuture->setErrorIfNeeded(fut.exception_ptr());
          return;
        }
        if (--ctx->remaining == 0 && !ctx->dstFuture->completed()) {
          // No need to pass DataPtrs, because dstFuture won't directly
          // contain the value, it will contain the srcFutures (which have no
          // DataPtrs).
          ctx->dstFuture->markCompleted(ctx->asIvalue);
        }
      };
      ctx->srcFutures.get(i)->addCallback(func);
    }
  }
  return ctx->dstFuture;
}
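// Usage sketch for collectAll (illustrative): the returned future completes
// with the original List of futures once all of them complete, or fails with
// the first error encountered, so callers typically re-inspect each element:
//
//   auto all = c10::collectAll(futures);
//   all->addCallback([](c10::ivalue::Future& f) {
//     if (!f.hasError()) {
//       auto done = f.constValue().toList(); // the source futures themselves
//     }
//   });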
namespace {

#ifndef STRIP_ERROR_MESSAGES
std::string formatSetOfDevices(const std::vector<c10::Device>& devices) {
  std::ostringstream oss;
  std::copy(
      devices.begin(),
      devices.end(),
      std::ostream_iterator<c10::Device>(oss, ", "));
  return oss.str();
}
#endif

} // namespace

TORCH_API intrusive_ptr<ivalue::Future> collectAny(
    const List<intrusive_ptr<ivalue::Future>>& srcs) {
  if (srcs.empty()) {
    auto res = make_intrusive<ivalue::Future>(NoneType::get());
    res->markCompleted();
    return res;
  }
  const TypePtr& typePtr = srcs.get(0)->elementType();
  const std::vector<c10::Device>& devices = srcs.get(0)->devices();
  for (const auto i : c10::irange(srcs.size())) {
    if (srcs.get(i)->completed()) {
      return srcs.get(i);
    }
    TORCH_CHECK_TYPE(
        i == 0 || (*typePtr == *srcs.get(i)->elementType()),
        "Expected all futures to have the same type, but found ",
        *typePtr,
        " in position 0 and ",
        *srcs.get(i)->elementType(),
        " in position ",
        i);
    TORCH_CHECK_VALUE(
        i == 0 || (devices == srcs.get(i)->devices()),
        "Expected all futures to have the same devices, but found ",
        formatSetOfDevices(devices),
        " in position 0 and ",
        formatSetOfDevices(srcs.get(i)->devices()),
        " in position ",
        i);
  }

  struct Ctx {
    explicit Ctx(
        const List<intrusive_ptr<ivalue::Future>>& srcs,
        TypePtr typePtr,
        std::vector<c10::Device> devices)
        : srcFutures(srcs),
          dstFuture(make_intrusive<ivalue::Future>(
              std::move(typePtr),
              std::move(devices))) {}
    std::atomic<bool> done{false};
    List<intrusive_ptr<ivalue::Future>> srcFutures;
    intrusive_ptr<ivalue::Future> dstFuture;
  };

  auto ctx = std::make_shared<Ctx>(srcs, typePtr, devices);
  std::function<void(ivalue::Future&)> func = [ctx](ivalue::Future& src) {
    if (!ctx->done.exchange(true)) {
      intrusive_ptr<ivalue::Future> dst = ctx->dstFuture;
      ctx->dstFuture.reset(); // Once future is satisfied, remove refs.
      ctx->srcFutures =
          List<intrusive_ptr<ivalue::Future>>(ctx->srcFutures.elementType());
      if (src.hasError()) {
        dst->setError(src.exception_ptr());
      } else {
        dst->markCompleted(src.constValue(), src.storages());
      }
    }
  };
  for (const auto i : c10::irange(ctx->srcFutures.size())) {
    ctx->srcFutures.get(i)->addCallback(func);
  }
  return ctx->dstFuture;
}

} // namespace c10