#include #include #include #include #include #include #include #include #include #include namespace torch::jit { namespace { struct WorkBlock : public std::pair { using pair::pair; Node* begin() { return this->first; } Node* end() { return this->second; } }; class SubgraphSlicer { public: SubgraphSlicer( Block* block, std::shared_ptr graph, size_t minSubgraphSize, AliasDb& aliasDb, std::vector& diff_nodes) : block_(block), graph_(std::move(graph)), minSubgraphSize_(minSubgraphSize), aliasDb_(aliasDb), diff_nodes_(diff_nodes) {} void run() { // We maintain alias db correctness in-place while building up the autodiff // subgraphs, however it is difficult to preserve correctness when // un-inlining autodiff subgraphs. We first recursively construct all // subgraphs and then recursively cleanup & unmerge the small subgraphs buildupSubgraphs(); GRAPH_DUMP("before unfuseAliasedOutputs", graph_); unfuseAliasedOutputs(block_); cleanupSubgraphs(); // Run CSE globally onceto eliminate duplicates that may have occurred // while inlining subgraphs. EliminateCommonSubexpression(graph_); } void cleanupSubgraphs() { auto curNode = *block_->nodes().rbegin(); while (curNode != *block_->nodes().rend()) { // Save the previous node, since we might delete `curNode` in next block auto prevNode = curNode->prev(); if (curNode->kind() == prim::DifferentiableGraph) { // Inlining nodes may cause some subexpression to come back in the // subgraphs (for example, copying constants in repeatedly will generate // redundant prim::Constants). Run CSE to clean them up. EliminateCommonSubexpression(curNode->g(attr::Subgraph)); if (!inlineIfTooSmall(curNode)) { diff_nodes_.push_back(curNode); } } curNode = prevNode; } for (Node* n : block_->nodes()) { for (Block* b : n->blocks()) { SubgraphSlicer(b, graph_, minSubgraphSize_, aliasDb_, diff_nodes_) .cleanupSubgraphs(); } } } void buildupSubgraphs() { // We need to run the slicer multiple times in order to get all merge // opportunities. This is because moveBeforeTopologicalValid may reorder // nodes to be AFTER the current iteration point. In order to properly // consider those nodes for merging, we need run the pass until no changes // have been made. // // Example: // c = f(a, b) // d = f(c) // e = f(d) <- iter is here, moving upward // After c.moveBeforeTopologicallyValid(e), we have: // c = f(a, b) // e = f(d) <- iter still here // d = f(c) <- this was node moved on the other side. // see [workblocks] auto workblocks = buildWorkBlocks(); for (auto& workblock : workblocks) { bool any_changed = true; while (any_changed) { any_changed = false; for (auto it = workblock.end()->reverseIterator(); it != workblock.begin()->reverseIterator();) { auto [tmp_it, changed] = scanNode(*it); it = tmp_it; any_changed |= changed; } } } // Construct Subgraphs Recursively for (Node* n : block_->nodes()) { for (auto subBlock : n->blocks()) { SubgraphSlicer( subBlock, graph_, minSubgraphSize_, aliasDb_, diff_nodes_) .buildupSubgraphs(); } } } private: void unfuseAliasedOutputs(Block* b) { bool any_changed = true; while (any_changed) { any_changed = false; // we walk in the reverse order, so we can skip // nodes that might get unfused after the current // prim::DifferentiableGraph for (auto n : b->nodes().reverse()) { if (n->kind() == prim::DifferentiableGraph) { // aliased outputs in DifferentiableGraphs must be unfused // since autodiff doesn't know how to handle them correctly // N.B. Note, |= since we don't want `unfuseAliasedOutputs` // to short-circuit any_changed |= SubgraphUtils::unmergeAliasedOutputs(n); any_changed |= SubgraphUtils::unmergeOutputsAlisingInputs(n); GRAPH_DEBUG( "any_changed on ", any_changed, " ", n->g(attr::Subgraph)->toString(false)); } } } for (Node* n : b->nodes()) { for (Block* ib : n->blocks()) { unfuseAliasedOutputs(ib); } } } std::vector buildWorkBlocks() { // [workblocks] // the IR has many nodes which can never be reordered around, such as a // prim::Bailout. if a node N is surrounded by two nodes which cannot be // reordered, A and B, then a differentiable subgraph that is created from N // can only contain nodes from (A, B) The nodes from A to B represent one // work block for the subgraph slicer to work on. By creating these up // front, we avoid retraversing the whole graph block any time scanNode // returns, and we can also avoid attempting to create differentiable // subgraphs in work blocks that do not contain a # of differentiable nodes // >= minSubgraphSize_ Node* end_bound_node = block_->return_node(); Node* curr = end_bound_node->prev(); std::vector worklist; size_t differentiable_nodes = 0; while (curr != block_->param_node()) { differentiable_nodes += shouldConsiderForMerge(curr); // cannot reorder around side effectful nodes if (curr->hasSideEffects()) { // not enough differentiable nodes to create a differentiable subgraph if (differentiable_nodes >= minSubgraphSize_) { worklist.emplace_back(curr, end_bound_node); } differentiable_nodes = 0; end_bound_node = curr; } curr = curr->prev(); } if (differentiable_nodes >= minSubgraphSize_) { worklist.emplace_back(curr, end_bound_node); } return worklist; } // Inline this node's group subgraph into the outer graph if it's smaller // than the specified minimum size. // // Returns true if an inlining has occurred, false otherwise. bool inlineIfTooSmall(Node* n) { AT_ASSERT(n->kind() == prim::DifferentiableGraph); auto subgraph = SubgraphUtils::getSubgraph(n); size_t i = 0; for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end(); ++it) { i += !it->notExecutedOp(); if (i >= minSubgraphSize_) { return false; } } SubgraphUtils::unmergeSubgraph(n); return true; } value_list sortReverseTopological(ArrayRef inputs) { value_list result; for (auto i : inputs) { if (i->node()->owningBlock() == block_) { result.push_back(i); } } // Sort in reverse topological order std::sort(result.begin(), result.end(), [&](Value* a, Value* b) { return a->node()->isAfter(b->node()); }); return result; } bool isViewOp(Node* n) { switch (n->kind()) { case aten::view: case aten::view_as: case aten::reshape: case aten::reshape_as: case aten::transpose: case aten::expand: case aten::expand_as: return true; } return false; } bool shouldConsiderForMerge(Node* node) { // if we're already in the process of merging if (node->kind() == prim::DifferentiableGraph) { return true; } if (node->kind() == prim::Constant) { return false; } // view ops as outputs of differentiable subgraphs can cause incorrect // differentiation for now, do not include them in the subgraph if (isViewOp(node)) { return false; } return isDifferentiable(node); } std::pair scanNode(Node* consumer) { if (shouldConsiderForMerge(consumer)) { if (consumer->kind() != prim::DifferentiableGraph) { consumer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing( consumer, prim::DifferentiableGraph, aliasDb_); } auto inputs = sortReverseTopological(consumer->inputs()); for (auto input : inputs) { if (auto group = tryMerge(consumer, input->node())) { // we successfully merged, so the new group's `inputs` may have // changed. So rescan the new group for more merging opportunities. return std::make_pair(group.value()->reverseIterator(), true); } } } return std::make_pair(++consumer->reverseIterator(), false); } // Try to merge `producer` into `consumer`. If successful, this destroys // `producer` and returns the `consumer` group. std::optional tryMerge(Node* consumer, Node* producer) { AT_ASSERT(consumer->kind() == prim::DifferentiableGraph); bool canMerge = shouldConsiderForMerge(producer) && aliasDb_.moveBeforeTopologicallyValid(producer, consumer); if (!canMerge) { return std::nullopt; } SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( producer, consumer, aliasDb_); return consumer; } Block* block_; std::shared_ptr graph_; size_t minSubgraphSize_; AliasDb& aliasDb_; std::vector& diff_nodes_; }; std::optional getProfileNodeRequiresGrad(Node* n) { TORCH_INTERNAL_ASSERT(n->kind() == prim::profile); if (!n->hasAttribute(attr::profiled_type)) { return std::nullopt; } auto& type = n->ty(attr::profiled_type); if (type->castRaw() == nullptr) { return std::nullopt; } return type->expectRef().requiresGrad(); } struct ContextMapping { std::vector ctx_stack_; std::unordered_map node_to_ctx_; void processNode(Node* n) { node_to_ctx_[n] = ctx_stack_.back(); if (n->kind() == prim::Enter) { ctx_stack_.push_back(n); } else if (n->kind() == prim::Exit) { ctx_stack_.pop_back(); } } void processBlock(Block* block) { for (Node* n : block->nodes()) { processNode(n); for (Block* b : n->blocks()) { processBlock(b); } if (n->kind() == prim::DifferentiableGraph) { const auto& subgraph = n->g(attr::Subgraph); processBlock(subgraph->block()); } } } ContextMapping(const std::shared_ptr& graph) { ctx_stack_.push_back(nullptr); processBlock(graph->block()); } const Node* get(const Node* n) const { auto it = node_to_ctx_.find(n); TORCH_INTERNAL_ASSERT( it != node_to_ctx_.end(), "Cannot find node in node-to-context mapping."); return it->second; } bool has(const Node* n) const { return node_to_ctx_.find(n) != node_to_ctx_.end(); } }; std::optional findRequiresGradForOutput( Node* diff_graph, Value* output, const ContextMapping& ctx_mapping) { for (auto& use : output->uses()) { // [Only consider profiles in the same context] // Ignore profiled uses if the use is within a different context. // For example, a profile node within a no_grad() context will record the // wrong requires_grad information. if (ctx_mapping.has(use.user) && ctx_mapping.get(use.user) != ctx_mapping.get(diff_graph)) { continue; } if (use.user->kind() == prim::profile) { auto req_grad_use = getProfileNodeRequiresGrad(use.user); if (req_grad_use.has_value()) { return req_grad_use; } } // maybe the profile node got absorbed into a differentiable graph if (use.user->kind() == prim::DifferentiableGraph) { const auto& dg = use.user->g(attr::Subgraph); // check all the uses of this graph input to look for profile nodes. Value* dg_value = dg->inputs()[use.offset]; for (auto& dg_use : dg_value->uses()) { // See [Only consider profiles in the same context] if (ctx_mapping.has(dg_use.user) && ctx_mapping.get(dg_use.user) != ctx_mapping.get(diff_graph)) { continue; } if (dg_use.user->kind() == prim::profile) { auto req_grad_use = getProfileNodeRequiresGrad(dg_use.user); if (req_grad_use.has_value()) { return req_grad_use; } } } } } return std::nullopt; } void AddRequiresGradToDifferentiableGraph( Node* diff_graph, const ContextMapping& ctx_mapping) { TORCH_INTERNAL_ASSERT(diff_graph->kind() == prim::DifferentiableGraph); const auto& subgraph = diff_graph->g(attr::Subgraph); for (auto i : c10::irange(subgraph->outputs().size())) { Value* output = subgraph->outputs()[i]; if (output->node()->kind() == prim::profile) { // already have requires_grad info from this profile node continue; } if (output->type()->castRaw() == nullptr) { // non-tensors don't get profiled. continue; } if (output->type()->expectRef().requiresGrad().has_value()) { continue; } // this node doesn't have any requires_grad info. // look at its uses to try to find a profile node. auto requires_grad = findRequiresGradForOutput( diff_graph, diff_graph->output(i), ctx_mapping); output->setType(output->type()->expectRef().withRequiresGrad( requires_grad)); } } void AddRequiresGradOnOutputNodes( Block* block, const ContextMapping& ctx_mapping) { for (Node* n : block->nodes()) { if (n->kind() == prim::DifferentiableGraph) { AddRequiresGradToDifferentiableGraph(n, ctx_mapping); } for (Block* b : n->blocks()) { AddRequiresGradOnOutputNodes(b, ctx_mapping); } } } // autodiff.cpp needs to know, for each output, whether or not it requires // grad. Sometimes a profile node will be present on the output, but sometimes // it won't be present. This might happen if there's a node with side effects // in between the definition of the output node and the profile node; in this // case the profile node and output node would be in different workblocks and // couldn't be merged into the same DifferentiableGraph. (see [workblocks]) // Or it could happen if the output is profiled twice and the profile nodes get // removed by unfusedAliasedOutputs. void AddRequiresGradOnOutputNodes(const std::shared_ptr& graph) { ContextMapping ctx_mapping(graph); AddRequiresGradOnOutputNodes(graph->block(), ctx_mapping); } } // anonymous namespace std::vector CreateAutodiffSubgraphs( const std::shared_ptr& graph, size_t threshold) { std::vector diff_nodes; AliasDb db(graph); GRAPH_DEBUG("Before creating autodiff subgraphs", *graph); SubgraphSlicer(graph->block(), graph, threshold, db, diff_nodes).run(); GRAPH_DEBUG("After creating autodiff subgraphs", *graph); AddRequiresGradOnOutputNodes(graph); GRAPH_DEBUG("diff_nodes.size() ", diff_nodes.size()); return diff_nodes; } } // namespace torch::jit