#include #include namespace torch::jit { namespace { class DictNodeImplBase { public: virtual ~DictNodeImplBase() = default; virtual bool contains(const IValue&) const = 0; virtual size_t size() const = 0; virtual Value* get(const IValue&) const = 0; bool canOptimize() { return !has_overlap_ && !has_non_const_key_; } protected: bool has_overlap_ = false; bool has_non_const_key_ = false; }; template class DictNodeImpl : public DictNodeImplBase { public: DictNodeImpl( std::function ivalue_converter, Node* dict_creation_node) : ivalue_converter_(std::move(ivalue_converter)) { for (size_t i = 0; i < dict_creation_node->inputs().size(); i += 2) { auto key_opt = toIValue(dict_creation_node->input(i)); // Key is not constant if we cannot convert to IValue if (key_opt == std::nullopt) { has_non_const_key_ = true; continue; } KeyType key = ivalue_converter_(*key_opt); if (dict_.find(key) == dict_.end()) { dict_.emplace(key, dict_creation_node->input(i + 1)); } else { has_overlap_ = true; } } } bool contains(const IValue& ivalue) const override { auto key = ivalue_converter_(ivalue); return dict_.find(key) != dict_.end(); } size_t size() const override { return dict_.size(); } Value* get(const IValue& ivalue) const override { auto val = ivalue_converter_(ivalue); auto loc = dict_.find(val); if (loc != dict_.end()) { return loc->second; } TORCH_CHECK(false, "Cannot get non-existent key"); } private: std::unordered_map dict_; std::function ivalue_converter_; }; class DictNode { public: explicit DictNode(Node* dict_creation_node) { auto dict_type = dict_creation_node->output()->type(); auto key_value_types = dict_type->containedTypes(); TORCH_CHECK( key_value_types.size() == 2, "Dict must have 2 contained types"); const auto& key_type = key_value_types[0]; switch (key_type->kind()) { case TypeKind::IntType: { auto ivalue_converter = [](const IValue& ival) { return ival.toInt(); }; impl_ = std::make_unique>( std::move(ivalue_converter), dict_creation_node); break; } case TypeKind::FloatType: { auto ivalue_converter = [](const IValue& ival) { return ival.toDouble(); }; impl_ = std::make_unique>( std::move(ivalue_converter), dict_creation_node); break; } case TypeKind::StringType: { auto ivalue_converter = [](const IValue& ival) { return *ival.toString(); }; impl_ = std::make_unique>( std::move(ivalue_converter), dict_creation_node); break; } default: impl_ = nullptr; } } bool canOptimize() const { if (impl_) { return impl_->canOptimize(); } return false; } size_t size() const { if (impl_) { return impl_->size(); } return 0; } std::optional getOrNullopt(const IValue& key) const { if (impl_ && impl_->contains(key)) { return impl_->get(key); } return std::nullopt; } private: std::unique_ptr impl_; }; bool isDict(Value* v) { return v->type()->castRaw() != nullptr; } class PeepholeOptimizeDictIdiomsImpl { public: explicit PeepholeOptimizeDictIdiomsImpl(std::shared_ptr graph) : graph_(std::move(graph)), aliasDb_(std::make_unique(graph_)) {} bool run() { collectMutatedDicts(graph_->block()); return runBlock(graph_->block()); } private: void checkForMutatedDicts(Value* v) { if (isDict(v) && aliasDb_->hasWriters(v)) { mutated_dicts_.insert(v); } } void collectMutatedDicts(Block* b) { for (Value* v : b->inputs()) { checkForMutatedDicts(v); } for (Node* n : b->nodes()) { for (Value* v : n->outputs()) { checkForMutatedDicts(v); } for (Block* block : n->blocks()) { collectMutatedDicts(block); } } } const DictNode& getDictNode(Node* creation_node) { auto cached = dict_cache_.find(creation_node); if (cached == dict_cache_.end()) { cached = dict_cache_.emplace(creation_node, DictNode(creation_node)).first; } return cached->second; } std::optional getValueFromDict(Node* dict_creation_node, Value* key) { const DictNode& dict_node = getDictNode(dict_creation_node); auto key_opt = toIValue(key); // Key is not constant if we cannot convert to IValue if (key_opt == std::nullopt) { return std::nullopt; } IValue key_ival = *key_opt; if (dict_node.canOptimize()) { return dict_node.getOrNullopt(key_ival); } return std::nullopt; } std::optional computeLen(Node* dict_creation_node) { const DictNode& dict_node = getDictNode(dict_creation_node); if (dict_node.canOptimize()) { return static_cast(dict_node.size()); } return std::nullopt; } bool optimizeLen(Node* len_node, Node* creation_node) { if (creation_node->kind() == prim::DictConstruct) { auto len = computeLen(creation_node); if (len != std::nullopt) { WithInsertPoint guard(len_node); len_node->output()->replaceAllUsesWith(graph_->insertConstant(len)); return true; } } return false; } bool optimizeGetItem(Node* getitem_node, Node* creation_node) { if (creation_node->kind() == prim::DictConstruct) { auto key = getitem_node->input(1); auto value = getValueFromDict(creation_node, key); if (value != std::nullopt) { getitem_node->output()->replaceAllUsesWith(*value); return true; } } return false; } bool runBlock(Block* block) { bool changed = false; for (Node* node : block->nodes()) { for (Block* b : node->blocks()) { changed |= runBlock(b); } // only optimizing dict ops if (node->inputs().empty() || !isDict(node->input(0))) { continue; } auto first_input = node->input(0); // only optimizing ops with unmutated inputs if (mutated_dicts_.count(first_input)) { continue; } if (node->kind() == aten::len) { changed |= optimizeLen(node, first_input->node()); } else if (node->kind() == aten::__getitem__) { changed |= optimizeGetItem(node, first_input->node()); } } return changed; } std::shared_ptr graph_; std::unordered_set mutated_dicts_; std::unique_ptr aliasDb_; std::unordered_map dict_cache_; }; } // namespace bool PeepholeOptimizeDictIdioms(const std::shared_ptr& graph) { PeepholeOptimizeDictIdiomsImpl opt(graph); return opt.run(); } } // namespace torch::jit