/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

using namespace ::testing;

using ::example::Version;
using ::executorch::extension::llm::Tokenizer;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;
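// Test fixture that initializes the ExecuTorch runtime and constructs the
// multimodal Llama tiktoken tokenizer, pointing it at the tokenizer model
// shipped in the test resources directory.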
class MultimodalTiktokenV5ExtensionTest : public Test {
 public:
  void SetUp() override {
    executorch::runtime::runtime_init();
    tokenizer_ = get_tiktoken_for_llama(Version::Multimodal);
    modelPath_ = std::getenv("RESOURCES_PATH") +
        std::string("/test_tiktoken_tokenizer.model");
  }

  std::unique_ptr<Tokenizer> tokenizer_;
  std::string modelPath_;
};

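// Verifies that the loaded tokenizer reports the expected vocabulary size and
// the expected BOS/EOS special token IDs.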
TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerVocabSizeIsExpected) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
  EXPECT_EQ(tokenizer_->vocab_size(), 128256);
  EXPECT_EQ(tokenizer_->bos_tok(), 128000);
  EXPECT_EQ(tokenizer_->eos_tok(), 128001);
}

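// Encodes a multimodal chat-formatted prompt (including <|image|> and header
// special tokens) and compares the resulting token IDs against a known-good
// reference sequence.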
TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerEncodeCorrectly) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
  Result<std::vector<uint64_t>> out = tokenizer_->encode(
      "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|>What do you think is going on in this snapshot?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nAmidst a scenic garden backdrop, a man dressed in a suit with a distinct button on its lower portion stands prominently.<|eom_id|>",
      0,
      0);
  EXPECT_EQ(out.error(), Error::Ok);
  EXPECT_EQ(out.get().size(), 48);
  std::vector<uint64_t> expected_out = {
      128000, 128006, 882, 128007, 271, 128010, 3923, 656,
      499, 1781, 374, 2133, 389, 304, 420, 16694,
      30, 128009, 128006, 78191, 128007, 271, 6219, 307,
      267, 264, 62081, 13863, 39577, 11, 264, 893,
      26435, 304, 264, 7937, 449, 264, 12742, 3215,
      389, 1202, 4827, 13651, 13656, 74088, 13, 128008};
  for (size_t i = 0; i < expected_out.size(); ++i) {
    EXPECT_EQ(expected_out[i], out.get()[i]);
  }
}

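// Decodes individual token IDs one at a time (including multimodal special
// tokens such as <|image|> and <|eom_id|>) and checks each decoded piece
// against its expected string.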
TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerDecodeCorrectly) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
  std::vector<std::string> expected = {
      "<|begin_of_text|>",
      "<|start_header_id|>",
      "user",
      "<|end_header_id|>",
      "<|image|>",
      "<|image|>",
      "hello",
      "<|image|>",
      "<|eom_id|>"};
  std::vector<uint64_t> tokens = {
      128000, 128006, 882, 128007, 128010, 128010, 15339, 128010, 128008};
  for (size_t i = 0; i < tokens.size(); i++) {
    Result<std::string> out = tokenizer_->decode(0, tokens[i]);
    EXPECT_EQ(out.error(), Error::Ok);
    EXPECT_EQ(out.get(), expected[i]);
  }
}