/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>

#include <vector>

#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

using namespace ::testing;

using ::example::Version;
using ::executorch::extension::llm::Tokenizer;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

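// Fixture that builds the multimodal (v5) tiktoken variant of the llama
// tokenizer and resolves the test model file from RESOURCES_PATH.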
class MultimodalTiktokenV5ExtensionTest : public Test {
 public:
  void SetUp() override {
    executorch::runtime::runtime_init();
    tokenizer_ = get_tiktoken_for_llama(Version::Multimodal);
    modelPath_ = std::getenv("RESOURCES_PATH") +
        std::string("/test_tiktoken_tokenizer.model");
  }

  std::unique_ptr<Tokenizer> tokenizer_;
  std::string modelPath_;
};

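// Loads the tokenizer model and checks that the vocabulary size and the
// BOS/EOS token ids match the expected multimodal llama values.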
TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerVocabSizeIsExpected) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
  EXPECT_EQ(tokenizer_->vocab_size(), 128256);
  EXPECT_EQ(tokenizer_->bos_tok(), 128000);
  EXPECT_EQ(tokenizer_->eos_tok(), 128001);
}

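// Encodes a chat-formatted multimodal prompt (including <|image|>, header,
// and end-of-message special tokens) and compares the output against the
// expected token ids, element by element.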
TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerEncodeCorrectly) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
  Result<std::vector<uint64_t>> out = tokenizer_->encode(
      "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|>What do you think is going on in this snapshot?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nAmidst a scenic garden backdrop, a man dressed in a suit with a distinct button on its lower portion stands prominently.<|eom_id|>",
      0,
      0);
  EXPECT_EQ(out.error(), Error::Ok);
  EXPECT_EQ(out.get().size(), 48);
  std::vector<uint64_t> expected_out = {
      128000, 128006, 882,    128007, 271,    128010, 3923,  656,
      499,    1781,   374,    2133,   389,    304,    420,   16694,
      30,     128009, 128006, 78191,  128007, 271,    6219,  307,
      267,    264,    62081,  13863,  39577,  11,     264,   893,
      26435,  304,    264,    7937,   449,    264,    12742, 3215,
      389,    1202,   4827,   13651,  13656,  74088,  13,    128008};
  for (size_t i = 0; i < expected_out.size(); ++i) {
    EXPECT_EQ(expected_out[i], out.get()[i]);
  }
}

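// Decodes token ids one at a time and checks that both special tokens and
// regular text are mapped back to the expected strings.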
TEST_F(MultimodalTiktokenV5ExtensionTest, TokenizerDecodeCorrectly) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
  std::vector<std::string> expected = {
      "<|begin_of_text|>",
      "<|start_header_id|>",
      "user",
      "<|end_header_id|>",
      "<|image|>",
      "<|image|>",
      "hello",
      "<|image|>",
      "<|eom_id|>"};
  std::vector<uint64_t> tokens = {
      128000, 128006, 882, 128007, 128010, 128010, 15339, 128010, 128008};
  for (size_t i = 0; i < tokens.size(); i++) {
    Result<std::string> out = tokenizer_->decode(0, tokens[i]);
    EXPECT_EQ(out.error(), Error::Ok);
    EXPECT_EQ(out.get(), expected[i]);
  }
}