#include <torch/csrc/jit/tensorexpr/codegen.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>

#include <sstream>
#include <utility>

namespace torch::jit::tensorexpr {

CodeGen::CodeGen(
    StmtPtr stmt,
    std::vector<BufferArg> buffer_args,
    at::Device device,
    std::string kernel_func_name)
    : stmt_(std::move(stmt)),
      buffer_args_(std::move(buffer_args)),
      device_(device),
      kernel_func_name_(std::move(kernel_func_name)) {
  ExtCallMemoryReuse extCallMemoryReuse(buffer_args_);
  apply_mutator(&extCallMemoryReuse);
  allocIntermediateBufs();
}

RegisterCodeGenList& RegisterCodeGenList::GetInstance() {
  static RegisterCodeGenList codegen_list;
  return codegen_list;
}

RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList::
    FindStmtFactoryMethod(const std::string& name) {
  auto iter = stmt_factory_methods_.find(name);
  if (iter == stmt_factory_methods_.end()) {
    std::ostringstream oss;
    oss << "Invalid stmt codegen name: " << name << ". ";
    oss << "Existing codegen names: [";
    int index = 0;
    for (auto& entry : stmt_factory_methods_) {
      if (index != 0) {
        oss << ", ";
      }
      oss << entry.first;
      index++;
    }
    oss << "]";
    throw std::runtime_error(oss.str());
  }
  return iter->second;
}

void RegisterCodeGenList::AddStmtFactoryMethod(
    const std::string& name,
    const StmtFactoryMethod& stmt_factory_method) {
  stmt_factory_methods_[name] = stmt_factory_method;
}

std::unique_ptr<CodeGen> CreateCodeGen(
    const std::string& name,
    StmtPtr stmt,
    const std::vector<CodeGen::BufferArg>& params,
    at::Device device,
    const std::string& kernel_func_name) {
  RegisterCodeGenList::StmtFactoryMethod method =
      RegisterCodeGenList::GetInstance().FindStmtFactoryMethod(name);
  return method(std::move(stmt), params, device, kernel_func_name);
}

ExprPtr GenericIntrinsicsExpander::mutate(const IntrinsicsPtr& v) {
  if (v->op_type() == kSigmoid) {
    auto x = v->param(0)->accept_mutator(this);
    auto one = expr_to_vec(
        ExprHandle(getImmediateByType(v->dtype(), 1.0)), v->dtype().lanes());
    auto zero = expr_to_vec(
        ExprHandle(getImmediateByType(v->dtype(), 0.0)), v->dtype().lanes());
    ExprHandle y = one / (one + exp(zero - ExprHandle(x)));
    return y.node();
  }
  return IRMutator::mutate(v);
}

void* CodeGen::argToPtr(const BufferArg& bufferArg, const CallArg& callArg) {
  if (!bufferArg.isVar()) {
    return callArg.data();
  }

  switch (bufferArg.dtype().scalar_type()) {
#define TYPE_CASE(_1, Name) \
  case ScalarType::Name:    \
    return callArg.Name##Ptr();

    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE

    default:
      throw unsupported_dtype();
  }
  return nullptr;
}

void CodeGen::call_with_numel(void** args, int64_t numel) {
  TORCH_INTERNAL_ASSERT(
      false, "This codegen backend does not implement call_with_numel");
}

static std::optional<size_t> bufSize(const BufPtr& buf) {
  size_t size = elementSize(buf->dtype().scalar_type()) * buf->dtype().lanes();
  for (auto& d : buf->dims()) {
    if (!d->isConstant()) {
      return std::nullopt;
    }
    size = size * (*intValue(d));
  }
  return size;
}

// This algorithm takes the list of intermediate buffers and their liveness
// ranges, and returns the allocations of these buffers. A buffer 'A' can be
// allocated its own memory (appears as a pair of 'A's in the allocation
// results) or reuse the memory of another buffer such as 'B' (appears as
// ('A', 'B')). Specifically, we linearly scan the intermediate buffers in the
// order they appear, and try to assign each one an existing non-occupied
// memory allocation. If there are no such allocations available, we'll create
// memory for it. Once we are beyond the liveness range of a buffer, we'll
// mark its corresponding memory allocation as "up for grabs" for future
// reuse.
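//
// Illustrative example (hypothetical buffers, not from any particular
// kernel): given liveness ranges A:[0, 3], B:[1, 2], and C:[4, 6], all with
// constant sizes and size(A) >= size(C), the scan gives A and B their own
// allocations because their ranges overlap. By the time C appears, both A
// and B are dead; the most recently used allocation that is large enough
// (A's) is handed to C, so the result is {(A, A), (B, B), (C, A)}.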
static std::vector<std::pair<BufPtr, BufPtr>> AllocBufsWithMemReuse(
    const std::unordered_set<BufPtr>& bufs,
    const std::unordered_map<BufPtr, std::tuple<int32_t, int32_t>>& buf_ranges,
    const std::unordered_set<BufPtr>& bufs_external_allocs) {
  // Sort buffers by the time they appear.
  std::vector<BufPtr> bufs_sorted(bufs.begin(), bufs.end());
  auto sorting_function_by_start_time =
      [&buf_ranges](const BufPtr& b1, const BufPtr& b2) -> bool {
    return std::get<0>(buf_ranges.at(b1)) < std::get<0>(buf_ranges.at(b2));
  };
  std::sort(
      bufs_sorted.begin(), bufs_sorted.end(), sorting_function_by_start_time);

  // Map intermediate buffers to the most recently used memory if any.
  std::list<BufPtr> mem_up_for_grabs;
  std::unordered_map<BufPtr, BufPtr> buf_mem_map;
  std::vector<std::pair<BufPtr, BufPtr>> buf_allocs;

  auto sorting_function_by_end_time =
      [&buf_ranges](const BufPtr& b1, const BufPtr& b2) -> bool {
    return std::get<1>(buf_ranges.at(b1)) < std::get<1>(buf_ranges.at(b2));
  };
  for (const auto& buf : bufs_sorted) {
    // If the buf has dynamic shapes, we'll skip it (i.e., allocate memory for
    // it, and there are no future reuses on its memory).
    // TODO: reuse memory for bufs with dynamic shapes
    if (!bufSize(buf)) {
      buf_allocs.emplace_back(buf, buf);
      continue;
    }

    auto start = std::get<0>(buf_ranges.at(buf));

    // Release memory for buffers whose liveness range ends before the
    // creation time of this buf.
    // TODO: optimize in-place operations and copy operations
    std::vector<BufPtr> buf_to_release;
    for (auto& mapped : buf_mem_map) {
      auto buf_mapped = mapped.first;
      auto end_buf_mapped = std::get<1>(buf_ranges.at(buf_mapped));
      if (end_buf_mapped < start) {
        buf_to_release.push_back(buf_mapped);
      }
    }

    // Sort the buffers in the order of used time so the head of the release
    // list contains the most recently used buf.
    std::sort(
        buf_to_release.begin(),
        buf_to_release.end(),
        sorting_function_by_end_time);
    for (auto& buf_rl : buf_to_release) {
      mem_up_for_grabs.push_front(buf_mem_map.at(buf_rl));
      buf_mem_map.erase(buf_rl);
    }

    bool allocated = false;
    if (bufs_external_allocs.find(buf) == bufs_external_allocs.end()) {
      // Check whether there are free memories that this buf can reuse.
      for (auto it = mem_up_for_grabs.begin(); it != mem_up_for_grabs.end();
           it++) {
        auto m = *it;
        if (bufSize(m) >= bufSize(buf)) {
          buf_mem_map[buf] = m;
          buf_allocs.emplace_back(buf, m);
          allocated = true;
          mem_up_for_grabs.erase(it);
          break;
        }
      }
    }

    // If there is no memory to reuse, we'll have to allocate new memory for
    // it.
    if (!allocated) {
      buf_mem_map[buf] = buf;
      buf_allocs.emplace_back(buf, buf);
    }
  }

  return buf_allocs;
}

static StmtPtr insertAllocFree(
    std::vector<std::pair<BufPtr, BufPtr>>& buf_allocs,
    const std::unordered_set<BufPtr>& bufs_external_allocs,
    const StmtPtr& stmt) {
  BlockPtr b = to<Block>(stmt);
  if (!b) {
    b = alloc<Block>(std::vector<StmtPtr>({stmt}));
  }

  std::vector<BufPtr> bufs_ext_to_free;
  // Insert allocations and frees for temporary buffers at global scope.
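  // Note: buf_allocs is ordered by buffer start time, so walking it in
  // reverse and prepending keeps the Allocate/PlacementAllocate stmts at the
  // top of the block in allocation order, while the matching Free stmts are
  // appended after the original body.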
  for (auto rit = buf_allocs.rbegin(); rit != buf_allocs.rend(); ++rit) {
    if (rit->first == rit->second) {
      BufPtr buf = rit->first;
      if (bufs_external_allocs.find(buf) == bufs_external_allocs.end()) {
        b->prepend_stmt(alloc<Allocate>(buf));
        b->append_stmt(alloc<Free>(buf));
      } else {
        bufs_ext_to_free.push_back(buf);
      }
    } else {
      b->prepend_stmt(alloc<PlacementAllocate>(rit->first, rit->second));
    }
  }

  b->append_stmt(alloc<FreeExt>(bufs_ext_to_free));
  return b;
}

std::unordered_map<std::string, std::string> ExtCallMemoryReuse::
    makeExtCallFuncNameMap() {
  return {
      {"nnc_aten_quantize_per_tensor", "nnc_aten_quantize_per_tensor_out"},
      {"nnc_aten_dequantize", "nnc_aten_dequantize_out"},
      {"nnc_aten_quantized_mul", "nnc_aten_quantized_mul_out"},
      {"nnc_aten_quantized_conv2d", "nnc_aten_quantized_conv2d_out"},
      {"nnc_aten_quantized_conv2d_relu", "nnc_aten_quantized_conv2d_relu_out"},
      {"nnc_aten_quantized_sigmoid", "nnc_aten_quantized_sigmoid_out"},
      {"nnc_aten_upsample_nearest2d", "nnc_aten_upsample_nearest2d_out"},
      {"nnc_aten_quantized_linear", "nnc_aten_quantized_linear_out"},
      {"nnc_aten_quantized_conv1d", "nnc_aten_quantized_conv1d_out"},
      {"nnc_aten_quantized_mul_scalar", "nnc_aten_quantized_mul_scalar_out"},
      {"nnc_aten_max_red", "nnc_aten_max_red_out"},
      {"nnc_aten_conv1d", "nnc_aten_conv1d_out"},
  };
}

const std::unordered_map<std::string, std::string>
    ExtCallMemoryReuse::extCallFuncNameMap_ = makeExtCallFuncNameMap();

ExtCallMemoryReuse::ExtCallMemoryReuse(
    const std::vector<CodeGen::BufferArg>& bufferArgs) {
  for (const auto& ba : bufferArgs) {
    if (ba.buf()) {
      bufferArgs_.insert(ba.buf());
    }
  }
}

StmtPtr ExtCallMemoryReuse::mutate(const ExternalCallPtr& v) {
  if (extCallFuncNameMap_.count(v->func_name()) &&
      bufferArgs_.count(v->buf()) == 0) {
    std::vector<BufPtr> buf_out_args = {v->buf()};
    return alloc<ExternalCallWithAlloc>(
        extCallFuncNameMap_.at(v->func_name()),
        buf_out_args,
        v->buf_args(),
        v->args());
  }
  return v;
}

// We allocate intermediate buffers by inserting Allocate/Free or
// PlacementAllocate stmts. Allocate/Free stmts will allocate memory at
// runtime, and a PlacementAllocate stmt reuses the memory of one buffer for
// another buffer. In the current implementation, we use a linear scan for
// memory reuse.
// TODO: try more memory reuse algorithms and compare their memory efficiency.
void CodeGen::allocIntermediateBufs() {
  // Identify intermediate buffers that are not allocated yet.
  auto bufs = NodeFinder<Buf>::find(stmt_);
  std::unordered_set<BufPtr> bufs_allocated;
  for (const auto& b : buffer_args_) {
    bufs_allocated.insert(b.buf());
  }
  auto allocs = NodeFinder<Allocate>::find(stmt_);
  for (const auto& a : allocs) {
    bufs_allocated.insert(a->buf());
  }

  std::unordered_set<BufPtr> interm_bufs;
  std::unordered_map<BufPtr, std::tuple<int32_t, int32_t>> interm_buf_ranges;
  for (const auto& buf : bufs) {
    if (!bufs_allocated.count(buf) && !interm_bufs.count(buf)) {
      interm_bufs.insert(buf);

      // Identify the access stmts to each unallocated intermediate buffer.
      auto range = BufLiveRange::liveRange(stmt_, buf);
      interm_buf_ranges.emplace(buf, range);
    }
  }

  const auto bufs_external_allocs = ExternalAllocBufFinder::find(stmt_);

  // For each intermediate buffer, we reuse the memory of an old buffer whose
  // liveness range does not overlap with the current buffer, or allocate
  // memory if reusing a buffer is impossible.
  auto buf_allocs = AllocBufsWithMemReuse(
      interm_bufs, interm_buf_ranges, bufs_external_allocs);

  // Insert memory allocation/mapping nodes.
  if (!buf_allocs.empty()) {
    auto stmt_new = insertAllocFree(buf_allocs, bufs_external_allocs, stmt_);
    set_stmt(stmt_new);
  }

  GRAPH_DEBUG("\nMemory Allocation:\n\n", *stmt(), "\n");
}

} // namespace torch::jit::tensorexpr