#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/restore_mutation.h>

namespace torch::jit {

bool MutationRemover::removeListMutation() {
  return RemoveListMutation(graph_->block());
}

bool MutationRemover::removeTensorMutation() {
  return RemoveTensorMutation(graph_->block());
}

bool MutationRemover::hasSideEffectOrAlias(Value* v, AliasDb* aliasDb) {
  // bail on nodes with side effects, blocks, or graph / graph inputs
  Node* n = v->node();
  bool unhandled_node = !n->blocks().empty() ||
      n->hasAttribute(attr::Subgraph) || n->hasSideEffects() ||
      (v->node()->kind() == prim::Param);

  // if the output isn't contained in or aliased by its node's inputs, it's
  // unique. No need to check for aliasing if the node is a ListConstruct.
  bool mayAliasInputs = (v->node()->kind() != prim::ListConstruct) &&
      aliasDb->mayContainAlias(v->node()->inputs(), v);
  return unhandled_node || mayAliasInputs || (v->node()->kind() == prim::Param);
}

Node* MutationRemover::createSpecialMappedOp(Node* n) {
  WithInsertPoint guard(n);
  auto inputs = n->inputs();
  Node* new_node = nullptr;
  if (n->matches(
          "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)")) {
    auto dtype = graph_->insert(prim::dtype, {inputs.at(0)});
    new_node = graph_
                   ->insert(
                       aten::full_like,
                       {inputs.at(0), inputs.at(1)},
                       {NamedValue("dtype", dtype)})
                   ->node();
    new_node->copyMetadata(n);
    new_node->output()->setType(n->output()->type());
  } else if (n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)")) {
    new_node = graph_->insert(aten::zeros_like, {n->inputs().at(0)})->node();
  } else if (
      n->matches(
          "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)")) {
    // TODO: we should have a normal_like operator
    // normal(float mean, float std, int[] size, *, Generator? generator=None,
    // ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
    // pin_memory=None) -> Tensor
    auto size = graph_->insert(aten::size, {n->inputs().at(0)});
    auto dtype = graph_->insert(prim::dtype, {n->inputs().at(0)});
    auto layout = graph_->insert(prim::layout, {n->inputs().at(0)});
    auto device = graph_->insert(prim::device, {n->inputs().at(0)});
    auto pin_memory = graph_->insert(aten::is_pinned, {n->inputs().at(0)});
    auto generator = graph_->insertConstant(IValue());
    new_node = graph_->insertNode(graph_->create(
        aten::normal,
        {n->inputs().at(1),
         n->inputs().at(2),
         size,
         generator,
         dtype,
         layout,
         device,
         pin_memory}));
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
  new_node->copyMetadata(n);
  new_node->output()->setType(n->output()->type());
  return new_node;
}
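// Illustrative sketch of the special mapping above (the IR is paraphrased
// for readability, not verbatim compiler output):
//
//   %y : Tensor = aten::fill_(%x, %v)
//
// is replaced with the functional form
//
//   %d : int = prim::dtype(%x)
//   %y : Tensor = aten::full_like(%x, %v, dtype=%d)
//
// which leaves %x unmutated so the caller can erase the in-place op.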
static bool removableSetItem(Node* n) {
  if (n->kind() != aten::_set_item ||
      n->input(1)->node()->kind() != prim::Constant) {
    return false;
  }
  if (n->inputs().at(0)->node()->kind() != prim::ListConstruct) {
    return false;
  }
  auto li_node = n->inputs().at(0)->node();
  int64_t index = *constant_as<int64_t>(n->input(1));
  if (index < 0) {
    index += li_node->inputs().size();
  }
  auto li_len = static_cast<int64_t>(li_node->inputs().size());
  return index < li_len && index >= 0;
}

bool MutationRemover::listMutationFollowingListConstruct(Node* n) {
  return (
      (n->kind() == aten::append ||
       (n->kind() == aten::insert &&
        n->inputs().at(1)->node()->kind() == prim::Constant) ||
       (removableSetItem(n))) &&
      n->inputs().at(0)->node()->kind() == prim::ListConstruct);
}

bool MutationRemover::tryMakeCreationAndMutationAtomic(
    Value* mutated_value,
    Node* mutating_op) {
  // We can only remove mutation to values that are unique aliases in the
  // graph. If x = y[0] or y = self.y, then removing the mutation could
  // change observable semantics.
  if (hasSideEffectOrAlias(mutated_value, getOrCreateAliasDb())) {
    return false;
  }

  // In order to safely remove a mutation, the creation of a tensor and its
  // subsequent mutation need to be one atomic operation.
  return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
      mutated_value->node(), mutating_op);
}

bool MutationRemover::tryMakeUnaliasedIfOutputAndMutationAtomic(
    Value* mutated_value,
    Node* mutating_op) {
  // if cond:
  //    x = op()
  // else:
  //    x = op()
  // x.add_(1)
  // If x in both blocks has no other uses and is unaliased in the graph,
  // and we make the if node and the mutation atomic,
  // then removing the mutation add_ does not change observable semantics.
  if (mutated_value->node()->kind() != prim::If) {
    return false;
  }

  auto if_node = mutated_value->node();
  auto offset = mutated_value->offset();
  auto true_value = if_node->blocks().at(0)->outputs().at(offset);
  auto false_value = if_node->blocks().at(1)->outputs().at(offset);

  if (true_value->uses().size() > 1 || false_value->uses().size() > 1) {
    return false;
  }

  if (hasSideEffectOrAlias(true_value, getOrCreateAliasDb()) ||
      hasSideEffectOrAlias(false_value, getOrCreateAliasDb())) {
    return false;
  }

  return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
      if_node, mutating_op);
}
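// A sketch of the aliasing hazard the checks above guard against (the
// TorchScript below is an assumed example, not taken from this file):
//
//   def f(y: List[Tensor]):
//       x = y[0]   # x aliases an element of y
//       x.add_(1)  # rewriting this to x = x.add(1) would leave y unchanged
//
// hasSideEffectOrAlias reports that x may alias its node's inputs, so the
// mutation is kept. A freshly created value such as x = torch.zeros(...)
// carries no such alias and can be rewritten safely.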
bool MutationRemover::RemoveListMutation(Block* block) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    it++;

    for (Block* sub_block : node->blocks()) {
      changed |= RemoveListMutation(sub_block);
    }

    if (!listMutationFollowingListConstruct(node)) {
      continue;
    }

    Value* mutated_value = node->inputs().at(0);
    if (!tryMakeCreationAndMutationAtomic(mutated_value, node)) {
      continue;
    }

    changed = true;

    // We rewrite something like:
    // x = {v0}
    // x.append(v1) (or x.insert(0, v1))
    // to:
    // x = {v0, v1} (or x = {v1, v0})
    // We can remove x.append from the alias db's list of writes.
    // All other aliasing properties remain valid.
    Node* list_construct = mutated_value->node();
    switch (node->kind()) {
      case aten::append:
        list_construct->addInput(node->inputs().at(1));
        break;
      case aten::insert: {
        int pos = toIValue(node->inputs().at(1))->toInt();
        int size = list_construct->inputs().size();
        // insert to a negative position equals insert to std::max(pos + size, 0)
        if (pos < 0) {
          pos = std::max(pos + size, 0);
        }
        // insert beyond the current list length is the same as append
        pos = std::min(pos, size);
        list_construct->insertInput(pos, node->inputs().at(2));
        break;
      }
      case aten::_set_item: {
        int pos = toIValue(node->inputs().at(1))->toInt();
        int size = list_construct->inputs().size();
        if (pos < 0) {
          pos = std::max(pos + size, 0);
        }
        list_construct->replaceInput(pos, node->input(2));
        break;
      }
      default:
        TORCH_INTERNAL_ASSERT(false);
    }

    // process the use-chain and aliasing of the node output
    bool has_output = (!node->outputs().empty());
    if (has_output) {
      node->output()->replaceAllUsesWith(mutated_value);
      getOrCreateAliasDb()->writeIndex_->erase(node);
    }

    node->destroy();

    // TODO: don't strictly need to reset the write cache, evaluate on models
    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
  }
  return changed;
}
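// Sketch of the list rewrite performed above (paraphrased IR, shown only
// for illustration):
//
//   %li : int[] = prim::ListConstruct(%a)
//   %li2 : int[] = aten::append(%li, %b)
//
// becomes
//
//   %li : int[] = prim::ListConstruct(%a, %b)
//
// with all uses of %li2 redirected to %li. aten::insert and aten::_set_item
// with constant indices are folded into the ListConstruct the same way.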
bool MutationRemover::RemoveTensorMutation(Block* block) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    it++;

    for (Block* sub_block : node->blocks()) {
      changed |= RemoveTensorMutation(sub_block);
    }

    if (mutation_filter_) {
      const auto& mutation_filter = *mutation_filter_;
      if (!mutation_filter(node)) {
        continue;
      }
    }

    // TODO: out op variants
    if (!inplaceOpVariant(node)) {
      continue;
    }

    Value* mutated_value = node->inputs().at(0);
    if (!tryMakeCreationAndMutationAtomic(mutated_value, node) &&
        !tryMakeUnaliasedIfOutputAndMutationAtomic(mutated_value, node)) {
      continue;
    }

    Node* new_node = nullptr;
    if (isSpecialMappedOp(node)) {
      new_node = createSpecialMappedOp(node);
    } else {
      auto schema_name = node->schema().name();
      auto new_schema = schema_name.substr(0, schema_name.size() - 1);
      new_node = graph_->create(Symbol::fromQualString(new_schema), 1);
      new_node->copyMetadata(node);
      new_node->insertBefore(node);
      for (Value* input : node->inputs()) {
        new_node->addInput(input);
      }
      new_node->output()->setType(node->output()->type());

      // weird case where there is an inplace op and an equivalent functional
      // op of the same symbol, but they have different schemas
      if (!new_node->maybeOperator()) {
        new_node->destroy();
        continue;
      }
    }

    changed = true;
    mutated_value->replaceAllUsesAfterNodeWith(node, new_node->output());
    node->output()->replaceAllUsesWith(new_node->output());

    // We rewrite something like:
    // x = torch.zeros()
    // x.add_(1)
    // x.add_(2)
    // to:
    // x = torch.zeros()
    // x0 = x.add(1)
    // x0.add_(2)
    // For the remainder of the function, x0 will have the
    // same aliasing relationships as the original x.
    // To avoid rebuilding the entire alias db, we can replace
    // the memory DAG element of x with x0.
    getOrCreateAliasDb()->replaceWithNewValue(
        mutated_value, new_node->output());

    // It is an invariant that all mutable types have an element in the
    // memory DAG, so we must give x a fresh alias db element. We have
    // already verified that the mutated value is a fresh alias with a
    // single use.
    getOrCreateAliasDb()->createValue(mutated_value);

    // We must erase the destroyed node from the AliasDb's list of writes
    getOrCreateAliasDb()->writeIndex_->erase(node);
    node->destroy();

    // now that we have removed a mutating op, the write cache is stale
    // TODO: don't strictly need to reset the write cache, evaluate on models
    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
  }

  return changed;
}

bool MutationRemover::inplaceOpVariant(Node* n) {
  if (!n->kind().is_aten()) {
    return false;
  }

  if (isSpecialMappedOp(n)) {
    return true;
  }

  auto name = n->schema().name();
  bool inplace_op = name.at(name.size() - 1) == '_';
  if (!inplace_op) {
    return false;
  }

  // needs to have alias analysis by schema
  auto op = n->maybeOperator();
  if (!op) {
    return false;
  }
  if (op->aliasAnalysisKind() != AliasAnalysisKind::FROM_SCHEMA) {
    return false;
  }

  // all inplace ops at the time of writing have a single input that is
  // mutated and returned. Check that this is true; anything else could have
  // strange semantics.
  if (n->outputs().size() != 1 || n->inputs().empty()) {
    return false;
  }
  auto inputs = n->inputs();
  if (!getOrCreateAliasDb()->writesToAlias(n, {inputs.at(0)}) ||
      getOrCreateAliasDb()->writesToAlias(
          n, {inputs.slice(1).begin(), inputs.slice(1).end()})) {
    return false;
  }

  auto new_schema = name.substr(0, name.size() - 1);
  return !getAllOperatorsFor(Symbol::fromQualString(new_schema)).empty();
}

bool RemoveListMutation(const std::shared_ptr<Graph>& graph) {
  MutationRemover mr(graph);
  return mr.removeListMutation();
}

bool RemoveTensorMutation(
    const std::shared_ptr<Graph>& graph,
    std::optional<std::function<bool(Node*)>> mutation_filter) {
  MutationRemover mr(graph, std::move(mutation_filter));
  return mr.removeTensorMutation();
}

static const std::unordered_set<Symbol> activation_ops = []() {
  std::unordered_set<Symbol> target_ops;
  for (const auto& iter : activation_type_promotion_mapping) {
    std::string name = std::string(iter.first.toQualString()) + "_";
    target_ops.insert(Symbol::fromQualString(name));
  }
  return target_ops;
}();

bool InplaceToFunctionalActivation(const std::shared_ptr<Graph>& graph) {
  return RemoveTensorMutation(graph, [](Node* node) {
    return activation_ops.count(node->kind()) != 0;
  });
}

} // namespace torch::jit
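// Usage sketch for the public entry points (hypothetical driver code, not
// part of this file; assumes a Graph populated via torch::jit::parseIR):
//
//   auto graph = std::make_shared<Graph>();
//   parseIR(ir_text, graph.get());
//   RemoveListMutation(graph);            // fold list mutation into construction
//   RemoveTensorMutation(graph);          // rewrite in-place tensor ops
//   InplaceToFunctionalActivation(graph); // only activation ops, e.g. relu_
//
// Each pass returns true if it changed the graph.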