/* * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "fcp/dictionary/dictionary.h" #include #include #include #include #include #include #include "fcp/base/monitoring.h" #include "fcp/dictionary/dictionary.pb.h" #include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" namespace fcp { namespace dictionary { // Bidirectional map defined as hash_map from strings to int32_t paired with // a vector of those keys for reverse lookup. typedef std::pair, std::vector> HashVectorBimap; namespace { // Map a string to an ID, using a bidirectional map (an std::pair containing // two data structures for string -> int and for int -> string lookups). int32_t MapLookup(const HashVectorBimap& bimap, const std::string& tag) { auto map_idx = bimap.first.find(tag); return map_idx == bimap.first.end() ? Dictionary::kNotFound : map_idx->second; } // Lookup a token given its ID. std::string MapReverseLookup(const HashVectorBimap& bimap, int32_t id) { if (id < 0 || id >= bimap.second.size()) { return ""; } return bimap.second[id]; } // Return the size of an stl-like data structure. int32_t GetSize(const HashVectorBimap& bimap) { return static_cast(bimap.first.size()); } int32_t GetMaxSpecialId(const DictionaryDescription::SpecialIds& special_ids) { int32_t max_special_id = -1; max_special_id = std::max(max_special_id, special_ids.bos()); max_special_id = std::max(max_special_id, special_ids.eos()); max_special_id = std::max(max_special_id, special_ids.unk()); return max_special_id; } // Dictionary implementation powered by templated utility functions above. template class DictionaryImpl : public Dictionary { public: DictionaryImpl( std::unique_ptr bimap, const DictionaryDescription::SpecialIds& special_ids, const DictionaryDescription::OutputBlocklistIds& output_blocklist_ids) : bimap_(std::move(bimap)), special_ids_(special_ids), max_special_id_(GetMaxSpecialId(special_ids)) { // Validate special ids. FCP_CHECK(special_ids.has_bos() == (special_ids.bos() >= 0)); FCP_CHECK(special_ids.has_eos() == (special_ids.eos() >= 0)); FCP_CHECK(special_ids.has_unk() == (special_ids.unk() >= 0)); // Token numbering starts at max(special_ids) + 1. output_blocklist_ids_.reserve(max_special_id_ + 1 + output_blocklist_ids.id_size()); for (int32_t id = 0; id <= max_special_id_; ++id) { output_blocklist_ids_.push_back(id); } for (int32_t id : output_blocklist_ids.id()) { output_blocklist_ids_.push_back(id); } } int32_t Size() const override { return GetSize(*bimap_) + max_special_id_ + 1; } int32_t TokenToId(const std::string& tag) const override { int32_t id = MapLookup(*bimap_, tag); if (id == kNotFound) { return special_ids_.unk(); } else { return id + max_special_id_ + 1; } } std::string IdToToken(int32_t id) const override { return MapReverseLookup(*bimap_, id - (max_special_id_ + 1)); } bool IsSpecialId(int32_t token_id) const override { return token_id <= max_special_id_; } const std::vector& GetSortedOutputBlocklistIds() const override { return output_blocklist_ids_; } const DictionaryDescription::SpecialIds& GetSpecialIds() const override { return special_ids_; } private: const std::unique_ptr bimap_; const DictionaryDescription::SpecialIds special_ids_; int32_t max_special_id_; std::vector output_blocklist_ids_; }; absl::Status IsOutputBlocklistIdsSortedAndUnique( const DictionaryDescription& description) { // All blocklist ids must be greater than max_special_id. const int32_t max_special_id = GetMaxSpecialId(description.special_ids()); // Make sure output blocklist IDs are sorted in ascending order and unique. if (description.has_output_blocklist_ids()) { for (int i = 0; i < description.output_blocklist_ids().id_size(); i++) { if (description.output_blocklist_ids().id(i) <= max_special_id) { return absl::InvalidArgumentError( "output_blocklist_ids should not overlap with special ids"); } if (!(i == 0 || description.output_blocklist_ids().id(i) > description.output_blocklist_ids().id(i - 1))) { return absl::InvalidArgumentError( "output_blocklist_ids not unique or sorted"); } } } return absl::OkStatus(); } } // anonymous namespace absl::StatusOr> Dictionary::Create( const DictionaryDescription& description) { if (!description.has_vocabulary()) { return absl::InvalidArgumentError( "Cannot create a dictionary that does not have vocabulary set"); } // Make sure output blocklist IDs are sorted in ascending order and unique. FCP_RETURN_IF_ERROR(IsOutputBlocklistIdsSortedAndUnique(description)); if (description.vocabulary().has_index()) { auto bimap = std::make_unique(); int i = 0; bimap->second.reserve(description.vocabulary().index().token_size()); for (const std::string& token : description.vocabulary().index().token()) { FCP_CHECK(!token.empty()); bimap->first[token] = i++; bimap->second.push_back(token); } return std::unique_ptr(new DictionaryImpl( std::move(bimap), description.special_ids(), description.output_blocklist_ids())); } else { return absl::InvalidArgumentError( "Invalid DictionaryDescription: no vocabulary specified."); } } } // namespace dictionary } // namespace fcp