/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include namespace example { using ::executorch::extension::llm::Tiktoken; namespace { static constexpr int32_t kSpecialTokensSize = 256; static constexpr size_t kBOSTokenIndex = 0; static constexpr size_t kEOSTokenIndex = 1; static inline std::unique_ptr> _get_default_special_tokens() { auto special_tokens = std::make_unique>(std::vector{ "<|begin_of_text|>", "<|end_of_text|>", "<|reserved_special_token_0|>", "<|reserved_special_token_1|>", "<|finetune_right_pad_id|>", "<|step_id|>", "<|start_header_id|>", "<|end_header_id|>", "<|eom_id|>", "<|eot_id|>", "<|python_tag|>"}); // pad the rest of the special tokens with reserved tokens ssize_t reserved_special_token_num = 2; while (special_tokens->size() < kSpecialTokensSize) { special_tokens->emplace_back( "<|reserved_special_token_" + std::to_string(reserved_special_token_num++) + "|>"); } return special_tokens; } static inline std::unique_ptr> _get_multimodal_special_tokens() { auto special_tokens = std::make_unique>(std::vector{ "<|begin_of_text|>", "<|end_of_text|>", "<|reserved_special_token_0|>", "<|reserved_special_token_1|>", "<|reserved_special_token_2|>", "<|reserved_special_token_3|>", "<|start_header_id|>", "<|end_header_id|>", "<|eom_id|>", "<|eot_id|>", "<|image|>"}); // pad the rest of the special tokens with reserved tokens except the last // one ssize_t reserved_special_token_num = 4; while (special_tokens->size() < kSpecialTokensSize - 1) { special_tokens->emplace_back( "<|reserved_special_token_" + std::to_string(reserved_special_token_num++) + "|>"); } special_tokens->emplace_back("<|python_tag|>"); return special_tokens; } std::unique_ptr> _get_special_tokens(Version version) { switch (version) { case Version::Multimodal: return _get_multimodal_special_tokens(); default: return _get_default_special_tokens(); } } } // namespace std::unique_ptr get_tiktoken_for_llama(Version version) { return std::make_unique( _get_special_tokens(version), kBOSTokenIndex, kEOSTokenIndex); } } // namespace example