#include #include #include #include #include #include #include namespace c10 { // Precondition: data_ has a large negative number that should be // treated as a constant. It is NOT a valid pointer. In other words, // SymInt has temporarily violated invariants // Postcondition: invariants on SymInt are fixed void SymInt::promote_to_negative() { auto s = SymInt(SymNode(c10::make_intrusive>(data_))); // Similar to move operator=, but do NOT release data_ data_ = s.data_; s.data_ = 0; } SymNode SymInt::toSymNode() const { TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE( is_heap_allocated(), "SymInt::toSymNode is_heap_allocated"); return SymNode::reclaim_copy(toSymNodeImplUnowned()); } SymInt::SymInt(SymNode sin_sp) { TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE( sin_sp->is_int(), "SymInt::SymInt sin_sp->is_int()"); auto ptr = static_cast( reinterpret_cast(static_cast(sin_sp.release()))); auto rep = (ptr & ~MASK) | IS_SYM; data_ = static_cast(rep); } bool SymInt::has_hint() const { if (!is_heap_allocated()) { return true; } return toSymNodeImplUnowned()->has_hint(); } #define DEFINE_BINARY(API, OP, METHOD, RET) \ RET SymInt::API(const SymInt& sci) const { \ if (auto ma = maybe_as_int()) { \ if (auto mb = sci.maybe_as_int()) { \ return RET(OP(*ma, *mb)); \ } else { \ auto b = sci.toSymNode(); \ return RET(b->wrap_int(*ma)->METHOD(b)); \ } \ } else { \ if (auto mb = sci.maybe_as_int()) { \ auto a = toSymNodeImplUnowned(); \ return RET(a->METHOD(a->wrap_int(*mb))); \ } else { \ return RET(toSymNodeImplUnowned()->METHOD(sci.toSymNode())); \ } \ } \ } // clang-format off DEFINE_BINARY(operator+, std::plus<>(), add, SymInt) DEFINE_BINARY(operator-, std::minus<>(), sub, SymInt) DEFINE_BINARY(operator*, std::multiplies<>(), mul, SymInt) DEFINE_BINARY(operator/, std::divides<>(), floordiv, SymInt) DEFINE_BINARY(operator%, std::modulus<>(), mod, SymInt) DEFINE_BINARY(sym_eq, std::equal_to<>(), eq, SymBool) DEFINE_BINARY(sym_ne, std::not_equal_to<>(), ne, SymBool) DEFINE_BINARY(sym_lt, std::less<>(), lt, SymBool) DEFINE_BINARY(sym_le, std::less_equal<>(), le, SymBool) DEFINE_BINARY(sym_gt, std::greater<>(), gt, SymBool) DEFINE_BINARY(sym_ge, std::greater_equal<>(), ge, SymBool) DEFINE_BINARY(min, std::min, sym_min, SymInt) DEFINE_BINARY(max, std::max, sym_max, SymInt) // clang-format on SymInt::operator SymFloat() const { if (auto ma = maybe_as_int()) { return SymFloat(double(*ma)); } else { return SymFloat(toSymNodeImplUnowned()->sym_float()); } } bool SymInt::is_same(const SymInt& other) const { if (is_heap_allocated() != other.is_heap_allocated()) { return false; } // Both not heap allocated if (!is_heap_allocated() && this->operator!=(other)) { return false; } // Both heap allocated if (is_heap_allocated() && toSymNodeImplUnowned() != other.toSymNodeImplUnowned()) { return false; } return true; } SymNode SymInt::wrap_node(const SymNode& base) const { if (auto ma = maybe_as_int()) { return base->wrap_int(*ma); } else { return toSymNode(); } } SymInt SymInt::clone() const { if (auto ma = maybe_as_int()) { return SymInt(*ma); } else { return SymInt(toSymNodeImplUnowned()->clone()); } } int64_t SymInt::guard_int(const char* file, int64_t line) const { if (auto ma = maybe_as_int()) { return *ma; } else { return toSymNodeImplUnowned()->guard_int(file, line); } } bool SymInt::expect_size(const char* file, int64_t line) const { if (auto ma = maybe_as_int()) { return *ma >= 0; } else { return toSymNodeImplUnowned()->expect_size(file, line); } } SymInt operator-(const SymInt& s) { if (auto ma = s.maybe_as_int()) { const auto val = *ma; // Note: Result of `-std::numeric_limits::min()` is undefined // But on many platforms it equals to self + setting Carry/Overflow flags // Which in opimized code affects results of `check_range` condition // Workaround by using ternary that avoids alterning the flags #if C10_HAS_BUILTIN_OVERFLOW() std::decay_t out = 0; if (C10_UNLIKELY(__builtin_sub_overflow(out, val, &out))) { return SymInt(val); } return SymInt(out); #else constexpr auto val_min = std::numeric_limits::min(); return SymInt(val != val_min ? -val : val_min); #endif } else { return SymInt(s.toSymNodeImplUnowned()->neg()); } } void SymInt::operator*=(const SymInt& sci) { *this = *this * sci; } void SymInt::operator/=(const SymInt& sci) { *this = *this / sci; } void SymInt::operator+=(const SymInt& sci) { *this = *this + sci; } std::ostream& operator<<(std::ostream& os, const SymInt& s) { if (s.is_heap_allocated()) { os << s.toSymNodeImplUnowned()->str(); } else { os << s.as_int_unchecked(); } return os; } // This template lets us not do a refcount bump when we do an // identity conversion template struct Convert {}; template <> struct Convert { const SymInt& operator()(const SymInt& a) { return a; } }; template <> struct Convert { SymFloat operator()(const SymInt& a) { return a; } }; #define DEFINE_SYMINT_OP_INTONLY(scalar_t, RetTy) \ RetTy operator%(const SymInt& a, scalar_t b) { \ return Convert()(a) % RetTy(b); \ }; \ RetTy operator%(scalar_t a, const SymInt& b) { \ return RetTy(a) % Convert()(b); \ }; #define DEFINE_SYMINT_OP(scalar_t, RetTy) \ RetTy operator+(const SymInt& a, scalar_t b) { \ return Convert()(a) + RetTy(b); \ }; \ RetTy operator-(const SymInt& a, scalar_t b) { \ return Convert()(a) - RetTy(b); \ }; \ RetTy operator*(const SymInt& a, scalar_t b) { \ return Convert()(a) * RetTy(b); \ }; \ RetTy operator/(const SymInt& a, scalar_t b) { \ return Convert()(a) / RetTy(b); \ }; \ RetTy operator+(scalar_t a, const SymInt& b) { \ return RetTy(a) + Convert()(b); \ }; \ RetTy operator-(scalar_t a, const SymInt& b) { \ return RetTy(a) - Convert()(b); \ }; \ RetTy operator*(scalar_t a, const SymInt& b) { \ return RetTy(a) * Convert()(b); \ }; \ RetTy operator/(scalar_t a, const SymInt& b) { \ return RetTy(a) / Convert()(b); \ }; \ bool operator==(const SymInt& a, scalar_t b) { \ return Convert()(a) == RetTy(b); \ }; \ bool operator!=(const SymInt& a, scalar_t b) { \ return Convert()(a) != RetTy(b); \ }; \ bool operator<(const SymInt& a, scalar_t b) { \ return Convert()(a) < RetTy(b); \ }; \ bool operator<=(const SymInt& a, scalar_t b) { \ return Convert()(a) <= RetTy(b); \ }; \ bool operator>(const SymInt& a, scalar_t b) { \ return Convert()(a) > RetTy(b); \ }; \ bool operator>=(const SymInt& a, scalar_t b) { \ return Convert()(a) >= RetTy(b); \ }; \ bool operator==(scalar_t a, const SymInt& b) { \ return RetTy(a) == Convert()(b); \ }; \ bool operator!=(scalar_t a, const SymInt& b) { \ return RetTy(a) != Convert()(b); \ }; \ bool operator<(scalar_t a, const SymInt& b) { \ return RetTy(a) < Convert()(b); \ }; \ bool operator<=(scalar_t a, const SymInt& b) { \ return RetTy(a) <= Convert()(b); \ }; \ bool operator>(scalar_t a, const SymInt& b) { \ return RetTy(a) > Convert()(b); \ }; \ bool operator>=(scalar_t a, const SymInt& b) { \ return RetTy(a) >= Convert()(b); \ }; DEFINE_SYMINT_OP_INTONLY(int64_t, SymInt) DEFINE_SYMINT_OP_INTONLY(int32_t, SymInt) DEFINE_SYMINT_OP_INTONLY(uint64_t, SymInt) DEFINE_SYMINT_OP_INTONLY(uint32_t, SymInt) DEFINE_SYMINT_OP(int64_t, SymInt) DEFINE_SYMINT_OP(int32_t, SymInt) // make sure constants work DEFINE_SYMINT_OP(uint64_t, SymInt) DEFINE_SYMINT_OP(uint32_t, SymInt) DEFINE_SYMINT_OP(double, SymFloat) DEFINE_SYMINT_OP(float, SymFloat) // just for completeness #if defined(__APPLE__) DEFINE_SYMINT_OP_INTONLY(size_t, SymInt) // needed for osx DEFINE_SYMINT_OP(size_t, SymInt) // needed for osx #endif } // namespace c10