#include #include #include #include namespace torch::jit::tensorexpr { StmtPtr Tensor::constructStmt( const std::vector& args, const ExprPtr& body, const std::vector& reduce_dims, const std::vector& reduce_args) const { std::vector indices(args.begin(), args.end()); size_t ndim = buf()->ndim(); size_t reduce_ndim = reduce_dims.size(); auto reduce_op = to(body); auto acc_buf = reduce_ndim > 0 ? reduce_op->getAccBuf() : nullptr; StmtPtr s = alloc(buf_, indices, body); if (reduce_ndim > 0) { TORCH_INTERNAL_ASSERT(reduce_op != nullptr); if (acc_buf != nullptr) { auto reducer = reduce_op->reducer(); std::vector output_args(args.begin(), args.end()); ExprPtr new_reduce_op = reducer( to(acc_buf), alloc(acc_buf->dtype(), reduce_op->getRiOperand()), output_args, reduce_args); new_reduce_op->set_dtype(acc_buf->dtype()); s = alloc(to(acc_buf), indices, new_reduce_op); } } if (ndim == 0 && reduce_ndim == 0) { return s; } if (reduce_ndim > 0) { TORCH_INTERNAL_ASSERT(reduce_op != nullptr); for (const auto i : c10::irange(reduce_ndim)) { // Going in reverse order: from innermost loop to the outermost size_t dim_index = reduce_ndim - i - 1; auto const& dim = reduce_dims[dim_index]; s = alloc(reduce_args[dim_index], immLike(dim, 0), dim, s); } s = alloc(std::vector({s})); BufPtr init_buf = acc_buf ? to(acc_buf) : buf(); ExprPtr init_expr = acc_buf ? to(acc_buf)->initializer() : buf()->initializer(); if (init_expr) { StorePtr init_stmt = alloc(init_buf, indices, init_expr); to(s)->prepend_stmt(init_stmt); } if (acc_buf != nullptr) { LoadPtr load_acc = alloc(acc_buf, indices); auto cast = alloc(buf()->dtype(), load_acc); StorePtr post_stmt = alloc(buf(), indices, cast); to(s)->append_stmt(post_stmt); } } TORCH_INTERNAL_ASSERT_DEBUG_ONLY( buf_->is_contiguous() || buf_->is_contiguous(at::MemoryFormat::ChannelsLast) || buf_->is_contiguous(at::MemoryFormat::ChannelsLast3d) || buf_->is_channels_last_1d_contiguous()); auto loop_order_fn = [&]() { std::vector loop_order; if (buf_->is_contiguous()) { for (int32_t i = args.size() - 1; i >= 0; i--) { loop_order.push_back(i); } } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast)) { loop_order = {1, 3, 2, 0}; } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast3d)) { loop_order = {1, 4, 3, 2, 0}; } else { loop_order = {1, 2, 0}; } return loop_order; }; auto loop_order = loop_order_fn(); for (auto dim_index : loop_order) { auto const& dim = buf()->dim(dim_index); s = alloc(args[dim_index], immLike(dim, 0), dim, s); } return s; } Tensor Compute( const std::string& name, const std::vector& dims, const std::optional>& strides, const std::function&)>& body_func) { std::vector args = create_index_vars(dims); ExprHandle body = body_func(args); BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( const std::string& name, const std::vector& dims, const std::function&)>& body_func) { return Compute(name, dims, std::nullopt, body_func); } Tensor Compute( const std::string& name, const std::vector& dims, const std::optional>& strides, const std::function& body_func) { if (dims.size() != 1) { throw malformed_input("mismatch between body and arg size (1)"); } std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0]); BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( const std::string& name, const std::vector& dims, const std::function& body_func) { return Compute(name, dims, std::nullopt, body_func); } Tensor Compute( const std::string& name, const std::vector& dims, const std::optional>& strides, const std::function& body_func) { if (dims.size() != 2) { throw malformed_input("mismatch between body and arg size (2)"); } std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0], args[1]); BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( const std::string& name, const std::vector& dims, const std::function& body_func) { return Compute(name, dims, std::nullopt, body_func); } Tensor Compute( const std::string& name, const std::vector& dims, const std::optional>& strides, const std::function< ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& body_func) { if (dims.size() != 3) { throw malformed_input("mismatch between body and arg size (3)"); } std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0], args[1], args[2]); BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( const std::string& name, const std::vector& dims, const std::function< ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& body_func) { return Compute(name, dims, std::nullopt, body_func); } Tensor Compute( const std::string& name, const std::vector& dims, const std::optional>& strides, const std::function& body_func) { if (dims.size() != 4) { throw malformed_input("mismatch between body and arg size (4)"); } std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0], args[1], args[2], args[3]); BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( const std::string& name, const std::vector& dims, const std::function& body_func) { return Compute(name, dims, std::nullopt, body_func); } Tensor Reduce( const std::string& name, const std::vector& dims, const std::optional>& strides, const Reducer& reducer, const BufHandle& buffer, const std::vector& reduce_dims) { return Reduce( name, dims, strides, reducer, [&](ParameterList& p) { return buffer.load(p); }, reduce_dims); } Tensor Reduce( const std::string& name, const std::vector& dims, const Reducer& reducer, const BufHandle& buffer, const std::vector& reduce_dims) { return Reduce(name, dims, std::nullopt, reducer, buffer, reduce_dims); } Tensor Reduce( const std::string& name, const std::vector& dims, const std::optional>& strides, const Reducer& reducer, const Tensor& tensor, const std::vector& reduce_dims) { return Reduce( name, dims, strides, reducer, [&](ParameterList& p) { return tensor.load(p); }, reduce_dims); } Tensor Reduce( const std::string& name, const std::vector& dims, const Reducer& reducer, const Tensor& tensor, const std::vector& reduce_dims) { return Reduce(name, dims, std::nullopt, reducer, tensor, reduce_dims); } } // namespace torch::jit::tensorexpr