#include <torch/csrc/jit/passes/batch_mm.h>

#include <ATen/ATen.h>
#include <ATen/core/functional.h>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace torch::jit {

namespace {
c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
  return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
}
} // namespace

// This pass looks for trees in the graph, where leaves are mm ops and the
// inner vertices are add nodes. Once we find such a tree, it can be reduced to
// two concats and a single mm (basically a single multiply of a wide matrix
// with a tall matrix). Such patterns show up mostly in the backward of RNNs,
// since the derivative of many uses of matrix multiplies with the same weights
// forms exactly such a tree (note that it's usually also highly imbalanced,
// i.e. it has O(n) depth).
//
// This (or any tree of adds of MMs):
//
// +------+ +------+   +------+ +------+   +------+
// |      | |      |   |      | |      |   |      |
// |  L1  | |  R1  | + |  L2  | |  R2  | = |  O   |
// |      | |      |   |      | |      |   |      |
// +------+ +------+   +------+ +------+   +------+
//
// can be basically transformed into a single MM which looks like this
// (we concat all lhs operands, concat rhs operands, do mm):
//
//                 +------+
//                 |      |
//                 |  R1  |
//                 |      |
//                 +------+
//                 |      |
//                 |  R2  |
//                 |      |
//                 +------+
// +------+------+ +------+
// |      |      | |      |
// |  L1  |  L2  | |  O   |
// |      |      | |      |
// +------+------+ +------+

// Note [Further optimizations]
// It would be straightforward to extend the TreeToken class to also detect if
// all MMs had the same lhs/rhs. In such a case it's more efficient to expand
// the lhs and use bmm + sum instead of repeating it in memory via concat.

// Note [Overlapping trees]
// Additionally, it wouldn't be too hard to add support for partially
// overlapping trees. Right now it's forbidden in the algorithm (only a single
// tree will be allowed), so theoretically we might miss some optimization
// opportunities, especially since the rejected tree could be much larger. I
// didn't implement that because it's not necessary for the simple RNN cases I
// saw, so I decided to keep stuff simple. If we ever get around to
// implementing this, the right solution is probably to fuse MMs for the common
// part, and assume it's an input leaf for the outer two parts (I don't think
// it's beneficial to recompute, unless the subtree is super small, but let's
// not get into such details).

// The algorithm we're using is simple. We iterate through the graph in
// topological order and label nodes with TreeTokens. Then, we look for roots
// of the trees we formed and fuse them.
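// As a concrete sketch of the equivalence above (shape comments are
// illustrative; assume L_i is [n, k_i] and R_i is [k_i, m]):
//
//   // L1.mm(R1) + L2.mm(R2) computes the same result as:
//   at::Tensor lhs = at::cat({L1, L2}, /*dim=*/1); // [n, k1 + k2]
//   at::Tensor rhs = at::cat({R1, R2}, /*dim=*/0); // [k1 + k2, m]
//   at::Tensor O = at::mm(lhs, rhs);               // [n, m]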
// Tunable parameter. Set to something larger if it turns out to be better.
static constexpr size_t min_fusion_size = 4;

static bool have_same_shape(at::TensorList inputs) {
  auto expected_sizes = inputs[0].sizes();
  return (std::all_of(
      inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
        return t.sizes() == expected_sizes;
      }));
}

static bool should_be_transposed(at::TensorList inputs) {
  return (std::all_of(inputs.begin(), inputs.end(), [](const at::Tensor& t) {
    return t.stride(0) == 1 && t.stride(1) == t.size(0);
  }));
}

static std::vector<at::Tensor> transpose_inputs(at::TensorList inputs) {
  return fmap(inputs, [](const at::Tensor& i) { return i.t(); });
}

static bool shape_is_fast_for_reduce(
    const at::Tensor& lhs,
    const at::Tensor& rhs) {
  size_t l = lhs.size(0);
  size_t m = lhs.size(1);
  size_t r = rhs.size(1);
  // Numbers obtained by some simple benchmarks of fp32 gemms on a TITAN V
  return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
}

RegisterOperators mm_tree_reduction_reg({Operator(
    "prim::MMTreeReduce(...) -> Tensor",
    [](Stack& stack) {
      auto num_inputs = pop(stack).toInt();
      std::vector<at::Tensor> inputs;
      inputs.reserve(num_inputs);
      for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
        inputs.push_back(std::move(*it).toTensor());
      }
      drop(stack, num_inputs);

      AT_ASSERT(!inputs.empty());
      AT_ASSERT(inputs.size() % 2 == 0);
      size_t side_num_elems = inputs.size() / 2;
      auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
      auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
      // TODO: checking this is not free, so we should stop if this keeps
      // failing
      if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
          shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
        // Sometimes lhs_inputs or rhs_inputs are not contiguous, which makes
        // at::cat go through the slow path. View them as contiguous if
        // possible by transposing.
        bool lhs_input_transposed = should_be_transposed(lhs_inputs);
        bool rhs_input_transposed = should_be_transposed(rhs_inputs);
        at::Tensor lhs, rhs;
        if (lhs_input_transposed) {
          std::vector<at::Tensor> lhs_contig_inputs =
              transpose_inputs(lhs_inputs);
          lhs = at::cat(lhs_contig_inputs, /*dim=*/0);
          lhs = lhs.t();
        } else {
          lhs = at::cat(lhs_inputs, /*dim=*/1);
        }
        if (rhs_input_transposed) {
          std::vector<at::Tensor> rhs_contig_inputs =
              transpose_inputs(rhs_inputs);
          rhs = at::cat(rhs_contig_inputs, /*dim=*/1);
          rhs = rhs.t();
        } else {
          rhs = at::cat(rhs_inputs, /*dim=*/0);
        }
        push(stack, at::mm(lhs, rhs));
      } else {
        auto acc = at::mm(inputs[0], inputs[side_num_elems]);
        for (const auto i : c10::irange(1, side_num_elems)) {
          acc.add_(at::mm(inputs[i], inputs[side_num_elems + i]));
        }
        push(stack, std::move(acc));
      }
    },
    aliasAnalysisIsSpecialCase())});
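// After BatchMMTreeReduce (below) fires, a fused tree appears in the IR
// roughly as follows (a sketch; value names and the number of operands are
// illustrative):
//
//   %O : Tensor = prim::MMTreeReduce(%L1, %L2, %L3, %L4, %R1, %R2, %R3, %R4)
//
// i.e. all lhs operands come first, followed by all rhs operands, which is
// the input order constructed in BatchMMTreeReduce and consumed by the
// operator above.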
// TreeTokens will be used to label nodes of the graph, if the nodes will fit
// our mm/add tree pattern. Basically we do dynamic programming on DAGs, where
// when we reach node N with inputs A and B, then A and B have already been
// processed, and we can try to unify their TreeTokens (if they have them)
// and build a larger tree.
struct TreeToken {
  uint64_t tree_size = 0; // NOTE: measured in number of leaves, i.e. mm ops
  Node* node = nullptr;
  bool is_root = false;

  static TreeToken mm(Node* mm) {
    TreeToken token;
    token.tree_size = 1;
    token.node = mm;
    token.is_root = true;
    return token;
  }

  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken transpose(Node* t, TreeToken& inp_token) {
    TreeToken token;
    if (!inp_token.node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return token;
    }
    token.tree_size = 1;
    token.node = t;
    token.is_root = true;
    inp_token.is_root = false;
    return token;
  }

  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
    TreeToken token;
    // See Note [Overlapping trees]
    if (&l == &r || !l.is_root || !r.is_root)
      return token;

    token.tree_size = l.tree_size + r.tree_size;
    token.node = add;
    token.is_root = true;
    // Reserve the subtrees so they can't be used again.
    l.is_root = r.is_root = false;
    return token;
  }

  explicit operator bool() {
    return is_root;
  }

  std::vector<Node*> removeTransposesAndGatherMatmuls() {
    std::vector<Node*> matmuls;
    std::vector<Node*> queue{node};
    Graph* graph = node->owningGraph();
    while (!queue.empty()) {
      auto n = queue.back();
      queue.pop_back();
      if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
        matmuls.push_back(n);
      } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
        Node* input_node = n->input()->node();
        AT_ASSERT(input_node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
        // (AB)^T == B^TA^T
        WithInsertPoint insert_guard{input_node};
        Value* A = input_node->inputs()[0];
        Value* B = input_node->inputs()[1];
        Value* AT = graph->insert(aten::t, {A});
        Value* BT = graph->insert(aten::t, {B});
        Value* BTAT = graph->insert(aten::mm, {BT, AT});
        n->output()->replaceAllUsesWith(BTAT);
        matmuls.push_back(BTAT->node());
      } else if (
          n->matches(
              "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
        queue.push_back(n->inputs()[0]->node());
        queue.push_back(n->inputs()[1]->node());
      } else {
        AT_ASSERTM(false, "Unsupported node found in a BatchMM tree!");
      }
    }
    return matmuls;
  }
};

enum class Side { LHS, RHS };
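// The (AB)^T == B^TA^T rewrite in removeTransposesAndGatherMatmuls above turns
// IR like the following (a sketch; value names are illustrative):
//
//   %C = aten::mm(%A, %B)
//   %T = aten::t(%C)
//
// into
//
//   %AT = aten::t(%A)
//   %BT = aten::t(%B)
//   %T2 = aten::mm(%BT, %AT)
//
// so that every leaf handed to prim::MMTreeReduce is a plain mm.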
static void BatchMMTreeReduce(Block* block, AliasDb& alias_db) {
  auto graph = block->owningGraph();

  // Look for trees in the block
  std::unordered_map<Node*, TreeToken> tokens;
  for (auto node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      tokens[node] = TreeToken::mm(node);
    } else if (
        node->matches("aten::t(Tensor self) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      auto input_it = tokens.find(node->input()->node());
      if (input_it != tokens.end()) {
        tokens[node] = TreeToken::transpose(node, input_it->second);
      }
    } else if (
        node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      Node* lhs = node->inputs()[0]->node();
      Node* rhs = node->inputs()[1]->node();
      auto lhs_it = tokens.find(lhs);
      auto rhs_it = tokens.find(rhs);
      // See Note [Overlapping trees] (regarding the uses().size() == 1 check)
      // We could treat a subtree with multiple uses as if it was overlapping.
      // XXX: uses().size() == 1 is also something that guarantees that this
      // transform is valid, because we know for sure that none of these
      // operands depends on the result of the other. If we were to remove
      // this, we would need to compute a transitive closure and actually check
      // the dependencies.
      if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
          lhs->output()->uses().size() == 1 &&
          rhs->output()->uses().size() == 1) {
        if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
          tokens[node] = token;
        }
      }
    } else {
      for (auto block : node->blocks()) {
        BatchMMTreeReduce(block, alias_db);
      }
    }
  }

  // Merge trees we've found
  for (auto& item : tokens) {
    auto& root = item.second;
    if (!root || root.tree_size < min_fusion_size)
      continue;
    auto matmuls = root.removeTransposesAndGatherMatmuls();
    WithInsertPoint insert_guard{root.node};
    Node* tree_reduce =
        graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(0));
    }
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(1));
    }
    root.node->output()->replaceAllUsesWith(tree_reduce->output());

    // NB: don't bother with cleaning up after yourself. We'll use DCE for
    // that.
  }
}

static bool shape_is_fast_for_side(const at::Tensor& other_side_input) {
  // Cutoff chosen by benchmarking on a TITAN V
  return other_side_input.numel() <= 1024 * 2048;
}

RegisterOperators mm_batch_side_reg({Operator(
    prim::MMBatchSide,
    [](const Node* node) -> Operation {
      size_t num_other_side_inputs = node->inputs().size() - 1;
      Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
      return [num_other_side_inputs, single_side](Stack& stack) {
        at::Tensor side_input;
        std::vector<at::Tensor> other_side_inputs;
        other_side_inputs.reserve(num_other_side_inputs);
        for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
             ++it) {
          other_side_inputs.push_back(std::move(*it).toTensor());
        }
        drop(stack, num_other_side_inputs);
        pop(stack, side_input);

        auto any_other_input = other_side_inputs[0];
        if (have_same_shape(other_side_inputs) &&
            shape_is_fast_for_side(other_side_inputs[0])) {
          auto other_side_input =
              at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
          auto mm_out = single_side == Side::LHS
              ? side_input.mm(other_side_input)
              : other_side_input.mm(side_input);
          auto outputs = at::chunk(
              mm_out,
              num_other_side_inputs,
              /*dim=*/single_side == Side::LHS ? 1 : 0);
          stack.insert(
              stack.end(),
              std::make_move_iterator(outputs.begin()),
              std::make_move_iterator(outputs.end()));
        } else {
          if (single_side == Side::LHS) {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(side_input.mm(other));
            }
          } else {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(other.mm(side_input));
            }
          }
        }
      };
    },
    aliasAnalysisIsSpecialCase())});
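// A fused side batch shows up in the IR roughly as (a sketch; value names and
// the number of outputs are illustrative):
//
//   %o1, %o2, ..., %o8 = prim::MMBatchSide[side=0](%W, %x1, %x2, ..., %x8)
//
// where %W is the shared operand, side=0 corresponds to Side::LHS (each
// original node was %W.mm(%xi)), and side=1 to Side::RHS (%xi.mm(%W)).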
static std::pair<std::vector<Node*>, std::vector<Node*>>
gatherIndependentMMUses(Value* value, AliasDb& alias_db) {
  const auto postprocess = [&](std::vector<Node*> mms) {
    if (mms.empty()) {
      return mms;
    }
    std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) {
      return n->isBefore(m);
    });
    // Filter out dependent MMs. This algorithm might do very badly if e.g. you
    // have a lot of independent MMs that depend on the first one, but I doubt
    // this will be a common scenario.
    for (const auto i : c10::irange(mms.size())) {
      if (mms[i] == nullptr)
        continue;
      for (size_t j = i + 1; j < mms.size(); ++j) {
        if (mms[j] == nullptr)
          continue;
        if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) {
          mms[j] = nullptr;
        }
      }
    }
    return c10::filter(mms, [](Node* n) { return n != nullptr; });
  };

  Block* block = value->node()->owningBlock();
  std::vector<Node*> lhses; // Will contain nodes where value is used as an lhs
  std::vector<Node*> rhses; // Like above, but rhs
  for (Use u : value->uses()) {
    if (u.user->owningBlock() == block &&
        u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(u.user)) {
      if (u.offset == 0 && u.user->inputs()[1] != value) {
        lhses.push_back(u.user);
      } else if (u.offset == 1 && u.user->inputs()[0] != value) {
        rhses.push_back(u.user);
      }
    }
  }
  return std::make_pair(
      postprocess(std::move(lhses)), postprocess(std::move(rhses)));
}

static void BatchMMSide(Block* block, AliasDb& alias_db) {
  // NB: 8 is the current loop unrolling factor
  static constexpr size_t how_many_is_many = 8;
  const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
    AT_ASSERT(!mms.empty());
    for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
      bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
      AT_ASSERT(move_ok);
    }
    WithInsertPoint insert_guard{mms[0]};
    Graph* graph = mms[0]->owningGraph();
    Node* batch_mm = graph->create(
        prim::MMBatchSide,
        /*inputs=*/{},
        /*num_outputs=*/mms.size());
    graph->insertNode(batch_mm);
    batch_mm->i_(Symbol::attr("side"), static_cast<int64_t>(side));
    Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
    batch_mm->addInput(const_side);
    for (const auto i : c10::irange(mms.size())) {
      batch_mm->addInput(mms[i]->inputs().at(side == Side::LHS ? 1 : 0));
      mms[i]->output()->replaceAllUsesWith(batch_mm->outputs().at(i));
    }
  };

  std::unordered_set<Value*> considered_values;
  for (Node* node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      for (Value* input : node->inputs()) {
        if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
          continue;
        }
        auto uses_with_many = gatherIndependentMMUses(input, alias_db);
        if (uses_with_many.first.size() >= how_many_is_many) {
          batch_side(uses_with_many.first, Side::LHS);
        }
        if (uses_with_many.second.size() >= how_many_is_many) {
          batch_side(uses_with_many.second, Side::RHS);
        }
      }
    } else {
      for (Block* subblock : node->blocks()) {
        BatchMMSide(subblock, alias_db);
      }
    }
  }
}

static bool hasMMOperators(std::shared_ptr<Graph>& graph) {
  DepthFirstGraphNodeIterator it(graph);
  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return true;
    }
  }
  return false;
}

void BatchMM(std::shared_ptr<Graph>& graph) {
  if (!hasMMOperators(graph)) {
    return;
  }
  AliasDb alias_db(graph);
  BatchMMTreeReduce(graph->block(), alias_db);
  BatchMMSide(graph->block(), alias_db);
  EliminateDeadCode(graph);
  // It's possible that transpose rearrangements have created sequences of
  // consecutive transposes that didn't exist before.
  // Tensor type properties are not guaranteed to be correct.
  PeepholeOptimize(graph, /*disable_shape_peepholes=*/true);
}

} // namespace torch::jit
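// Usage sketch (illustrative only; assumes a graph obtained elsewhere, e.g.
// from a scripted torch::jit::Module named `module`):
//
//   std::shared_ptr<torch::jit::Graph> graph =
//       module.get_method("forward").graph();
//   torch::jit::BatchMM(graph); // fuse mm/add trees and shared-operand mms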