#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch::jit::tensorexpr { class InterpValue { public: InterpValue() : dtype_(kInt) { Intvalues.push_back(0); } template InterpValue(Dtype dtype, T v) : dtype_(dtype) { #define TYPE_CASE(Type, Name) \ if (dtype == k##Name) { \ Name##values.push_back(v); \ return; \ } AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE throw unsupported_dtype(); } #define VALUE_CTOR(Type, Name) \ InterpValue(Type v) : dtype_(k##Name) { \ Name##values.push_back(v); \ } AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_CTOR); #undef VALUE_CTOR explicit InterpValue(c10::quint8 v) : dtype_(kQUInt8) { QUInt8values.emplace_back(v.val_); } explicit InterpValue(c10::qint8 v) : dtype_(kQInt8) { QInt8values.emplace_back(v.val_); } #define VALUE_VEC_CTOR(Type, Name) \ InterpValue(const std::vector& v) \ : dtype_(Dtype(k##Name, v.size())), Name##values(v) {} AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_VEC_CTOR); VALUE_VEC_CTOR(c10::quint8, QUInt8) VALUE_VEC_CTOR(c10::qint8, QInt8) #undef VALUE_VEC_CTOR template T as() const; template const std::vector& as_vec() const; int64_t intValue() const; Dtype dtype() const { return dtype_; } private: Dtype dtype_; #define VALUE_STORAGE(Type, Name) std::vector Name##values; AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_STORAGE); VALUE_STORAGE(c10::qint8, QInt8); VALUE_STORAGE(c10::quint8, QUInt8); #undef VALUE_STORAGE void* ptr{nullptr}; }; #define VALUE_AS_DISPATCH(Type, Name) \ template <> \ inline Type InterpValue::as() const { \ if (dtype_ != k##Name) { \ throw unsupported_dtype(); \ } \ return Name##values[0]; \ } AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH); VALUE_AS_DISPATCH(c10::quint8, QUInt8); VALUE_AS_DISPATCH(c10::qint8, QInt8); #undef VALUE_AS_DISPATCH #define VALUE_AS_VEC_DISPATCH(Type, Name) \ template <> \ inline const 
std::vector& InterpValue::as_vec() const { \ if (dtype_.scalar_type() != ScalarType::Name) { \ throw unsupported_dtype(); \ } \ return Name##values; \ } AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_VEC_DISPATCH); VALUE_AS_VEC_DISPATCH(c10::quint8, QUInt8); VALUE_AS_VEC_DISPATCH(c10::qint8, QInt8); #undef VALUE_AS_VEC_DISPATCH template auto underlyingValue(Type x) { return x; } template <> inline auto underlyingValue(c10::quint8 x) { return x.val_; } template <> inline auto underlyingValue(c10::qint8 x) { return x.val_; } template To raw_bitcast(const From& src) { TORCH_CHECK(sizeof(To) == sizeof(From), "Invalid bitcast invocation"); To storage; std::memcpy(&storage, &src, sizeof(To)); return reinterpret_cast(storage); } class SimpleIREvaluatorImpl; class TORCH_API SimpleIREvaluator : public CodeGen { public: SimpleIREvaluator( StmtPtr stmt, const std::vector& buffer_args, at::Device device = at::kCPU, const std::string& kernel_func_name = "func"); ~SimpleIREvaluator() override; void call(const std::vector& args) override; void call_raw(const std::vector& args) override; template void operator()(const Ts&... ts) { std::vector args({CallArg(ts)...}); call(args); } void bindVar(const VarPtr& v, const ExprPtr& e); InterpValue value() const; private: void bindArg(const BufferArg& buf, void* data); void expand_intrinsics() { GenericIntrinsicsExpander intrinsics_expander; apply_mutator(&intrinsics_expander); } std::unique_ptr impl_; }; template class ExprEval { public: using BufferArg = CodeGen::BufferArg; using CallArg = CodeGen::CallArg; template ExprEval(const ExprHandle& expr, Ts... 
ts) : ExprEval(expr, {BufferArg(ts)...}) {} ExprEval(const ExprHandle& expr, const std::vector& buffer_args) : dtype_(expr.dtype()) { std::vector buffer_args_extended = buffer_args; BufHandle ret_buf("ret_val", {1}, dtype_); std::vector indices; ExprHandle zero = IntImm::make(0); for (size_t i = 0; i < ret_buf.ndim(); i++) { indices.push_back(zero); } StmtPtr store_stmt = Store::make(ret_buf, indices, expr); buffer_args_extended.emplace_back(ret_buf); codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended)); } template void operator()(Ts... ts) { call(ts...); } void operator()(const std::vector& call_args) { call(call_args); } void bindVar(VarPtr v, ExprPtr e) { codegen_->bindVar(v, e); } void bindVar(const VarHandle& v, const ExprHandle& e) { codegen_->bindVar(v.node(), e.node()); } template void call(Ts... ts) { call({CallArg(ts)...}); } void call(const std::vector& call_args) { std::vector call_args_extended = call_args; switch (dtype_.scalar_type()) { #define TYPE_CASE(Type, Name) \ case ScalarType::Name: { \ std::vector ret_val_arg(1); \ call_args_extended.emplace_back(ret_val_arg); \ codegen_->call(call_args_extended); \ ret_value_ = InterpValue(ret_val_arg[0]); \ } break; AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE); TYPE_CASE(c10::quint8, QUInt8); TYPE_CASE(c10::qint8, QInt8); #undef TYPE_CASE case ScalarType::Bool: { std::vector ret_val_arg(1); call_args_extended.emplace_back(ret_val_arg.data()); codegen_->call(call_args_extended); ret_value_ = InterpValue((bool)ret_val_arg[0]); } break; default: throw unsupported_dtype(); } } void call_raw(const std::vector& args) { std::vector args_extended = args; switch (dtype_.scalar_type()) { #define TYPE_CASE(Type, Name) \ case ScalarType::Name: { \ std::vector ret_val_arg(1); \ args_extended.push_back(ret_val_arg.data()); \ codegen_->call_raw(args_extended); \ ret_value_ = InterpValue(ret_val_arg[0]); \ } break; AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE); TYPE_CASE(c10::quint8, 
QUInt8); TYPE_CASE(c10::qint8, QInt8); #undef TYPE_CASE case ScalarType::Bool: { std::vector ret_val_arg(1); args_extended.push_back(ret_val_arg.data()); codegen_->call_raw(args_extended); ret_value_ = InterpValue((bool)ret_val_arg[0]); } break; default: throw unsupported_dtype(); } } template T value(const std::vector& args) { call_raw(args); return ret_value_.as(); } template T value(Ts... ts) { call(std::forward(ts)...); return ret_value_.as(); } Dtype dtype() { return dtype_; } private: Dtype dtype_; std::unique_ptr codegen_; InterpValue ret_value_; }; // Evaluates the given expression and returns an int64_t value if the result of // the given expression is int64_t. std::optional evalInt(ExprPtr e); // Substitutes the given vars with their corresponding expressions in the input // expression. inline ExprPtr Substitute(const ExprPtr& expr, const VarMapping& var_mapping) { VarSubMutator var_sub(var_mapping); return expr->accept_mutator(&var_sub); } // Substitutes the given vars with their corresponding expressions in the input // statement. inline StmtPtr Substitute(const StmtPtr& stmt, const VarMapping& var_mapping) { VarSubMutator var_sub(var_mapping); return stmt->accept_mutator(&var_sub); } // Creates a clone of the input expression and substitutes the given vars with // their corresponding expressions in the clone. // NOTE: This works because cloning reuses variables and does not create new // ones, and `VarMapping` input has variables as the key. inline ExprPtr SubstituteInClone( const ExprPtr& expr, const VarMapping& var_mapping) { VarSubMutator var_sub(var_mapping); return Expr::clone(expr)->accept_mutator(&var_sub); } // Creates a clone of the input statement and substitutes the given vars with // their corresponding expressions in the clone. // NOTE: This works because cloning reuses variables and does not create new // ones, and `VarMapping` input has variables as the key. 
inline StmtPtr SubstituteInClone(
    const StmtPtr& stmt,
    const VarMapping& var_mapping) {
  // Build the substituting mutator first, then run it over a clone so the
  // rewrite is applied to the copy rather than to `stmt` itself.
  VarSubMutator substitutor(var_mapping);
  StmtPtr cloned_stmt = Stmt::clone(stmt);
  return cloned_stmt->accept_mutator(&substitutor);
}

} // namespace torch::jit::tensorexpr