#pragma once #include #include #include #include #include #include #include namespace c10 { // Unlike other SymNodeImpl, this cannot be "dispatched" conventionally, // as it typically needs to defer to another SymNodeImpl // // Can either represent a bool, int (don't support float yet) this is useful // for representing otherwise unrepresentable large negative integer constant. template class C10_API ConstantSymNodeImpl : public SymNodeImpl { static_assert( ::std::is_same_v || ::std::is_same_v, "ConstantSymNodeImpl can only accept int64_t or bool types"); public: ConstantSymNodeImpl(T val) : value_(val) {} bool is_int() override { return is_int_(); } bool is_bool() override { return is_bool_(); } bool is_float() override { return false; } int64_t guard_int( const char* file [[maybe_unused]], int64_t line [[maybe_unused]]) override { TORCH_CHECK(is_int(), "not an int"); return int_(); } bool guard_bool( const char* file [[maybe_unused]], int64_t line [[maybe_unused]]) override { TORCH_CHECK(is_bool(), "not a bool"); return bool_(); } double guard_float( const char* file [[maybe_unused]], int64_t line [[maybe_unused]]) override { TORCH_CHECK(false, "not a float"); } int64_t int_() override { TORCH_CHECK(is_int(), "not an int"); return ::std::get(value_); } bool bool_() override { TORCH_CHECK(is_bool(), "not a bool"); return ::std::get(value_); } bool has_hint() override { return true; } c10::SymNode eq(const c10::SymNode& other) override; c10::SymNode ne(const c10::SymNode& other) override; c10::SymNode ge(const c10::SymNode& other) override; c10::SymNode le(const c10::SymNode& other) override; c10::SymNode lt(const c10::SymNode& other) override; c10::SymNode gt(const c10::SymNode& other) override; c10::SymNode mul(const c10::SymNode& other) override; ::std::string str() override { if constexpr (is_int_()) { return ::std::to_string(::std::get(value_)); } else { return ::std::get(value_) ? "true" : "false"; } } std::optional constant_int() override { if constexpr (is_int_()) { return ::std::get(value_); } else { return std::nullopt; } } std::optional constant_bool() override { if constexpr (is_bool_()) { return ::std::get(value_); } else { return std::nullopt; } } bool is_constant() override { return true; } bool is_symbolic() override { return false; } private: ::std::variant value_; static constexpr bool is_int_() { return ::std::is_same_v; } static constexpr bool is_bool_() { return ::std::is_same_v; } }; } // namespace c10