• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // © 2021 and later: Unicode, Inc. and others.
2 // License & terms of use: http://www.unicode.org/copyright.html
3 
4 #include "unicode/utypes.h"
5 
6 #if !UCONFIG_NO_BREAK_ITERATION
7 
8 #include "lstmbetst.h"
9 #include "lstmbe.h"
10 
11 #include <algorithm>
12 #include <sstream>
13 #include <vector>
14 
15 #include "charstr.h"
16 
17 //---------------------------------------------
18 // runIndexedTest
19 //---------------------------------------------
20 
21 
runIndexedTest(int32_t index,UBool exec,const char * & name,char * params)22 void LSTMBETest::runIndexedTest( int32_t index, UBool exec, const char* &name, char* params )
23 {
24     fTestParams = params;
25 
26     TESTCASE_AUTO_BEGIN;
27 
28     TESTCASE_AUTO(TestThaiGraphclust);
29     TESTCASE_AUTO(TestThaiCodepoints);
30     TESTCASE_AUTO(TestBurmeseGraphclust);
31     TESTCASE_AUTO(TestThaiGraphclustWithLargeMemory);
32     TESTCASE_AUTO(TestThaiCodepointsWithLargeMemory);
33 
34     TESTCASE_AUTO_END;
35 }
36 
37 
38 //--------------------------------------------------------------------------------------
39 //
40 //    LSTMBETest    constructor and destructor
41 //
42 //--------------------------------------------------------------------------------------
43 
LSTMBETest()44 LSTMBETest::LSTMBETest() {
45     fTestParams = NULL;
46 }
47 
48 
~LSTMBETest()49 LSTMBETest::~LSTMBETest() {
50 }
51 
getScriptFromModelName(const std::string & modelName)52 UScriptCode getScriptFromModelName(const std::string& modelName) {
53     if (modelName.find("Thai") == 0) {
54         return USCRIPT_THAI;
55     } else if (modelName.find("Burmese") == 0) {
56         return USCRIPT_MYANMAR;
57     }
58     // Add for other script codes.
59     UPRV_UNREACHABLE_EXIT;
60 }
61 
62 // Read file generated by
63 // https://github.com/unicode-org/lstm_word_segmentation/blob/master/segment_text.py
64 // as test cases and compare the Output.
65 // Format of the file
66 //   Model:\t[Model Name (such as 'Thai_graphclust_model4_heavy')]
67 //   Embedding:\t[Embedding type (such as 'grapheme_clusters_tf')]
68 //   Input:\t[source text]
69 //   Output:\t[expected output separated by | ]
70 //   Input: ...
71 //   Output: ...
// The test ensures that the Input contains only characters that can be handled
// by the model. Since the LSTM models are not included by default, all the
// tested models need to be included under source/test/testdata.
75 
runTestFromFile(const char * filename)76 void LSTMBETest::runTestFromFile(const char* filename) {
77     UErrorCode   status = U_ZERO_ERROR;
78     LocalPointer<const LanguageBreakEngine> engine;
79     //  Open and read the test data file.
80     const char *testDataDirectory = IntlTest::getSourceTestData(status);
81     CharString testFileName(testDataDirectory, -1, status);
82     testFileName.append(filename, -1, status);
83 
84     int len;
85     UChar *testFile = ReadAndConvertFile(testFileName.data(), len, "UTF-8", status);
86     if (U_FAILURE(status)) {
87         errln("%s:%d Error %s opening test file %s", __FILE__, __LINE__, u_errorName(status), filename);
88         return;
89     }
90 
91     //  Put the test data into a UnicodeString
92     UnicodeString testString(false, testFile, len);
93 
94     int32_t start = 0;
95 
96     UnicodeString line;
97     int32_t end;
98     std::string actual_sep_str;
99     int32_t caseNum = 0;
100     // Iterate through all the lines in the test file.
101     do {
102         int32_t cr = testString.indexOf(u'\r', start);
103         int32_t lf = testString.indexOf(u'\n', start);
104         end = cr >= 0 ? (lf >= 0 ? std::min(cr, lf) : cr) : lf;
105         line = testString.tempSubString(start, end < 0 ? INT32_MAX : end - start);
106         if (line.length() > 0) {
107             // Separate each line to key and value by TAB.
108             int32_t tab = line.indexOf(u'\t');
109             UnicodeString key = line.tempSubString(0, tab);
110             const UnicodeString value = line.tempSubString(tab+1);
111 
112             if (key == "Model:") {
113                 std::string modelName;
114                 value.toUTF8String<std::string>(modelName);
115                 engine.adoptInstead(createEngineFromTestData(modelName.c_str(), getScriptFromModelName(modelName), status));
116                 if (U_FAILURE(status)) {
117                     dataerrln("Could not CreateLSTMBreakEngine for " + line + UnicodeString(u_errorName(status)));
118                     return;
119                 }
120             } else if (key == "Input:") {
121                 // First, we ensure all the char in the Input lines are accepted
122                 // by the engine before we test them.
123                 caseNum++;
124                 bool canHandleAllChars = true;
125                 for (int32_t i = 0; i < value.length(); i++) {
126                     if (!engine->handles(value.charAt(i))) {
127                         errln(UnicodeString("Test Case#") + caseNum + " contains char '" +
128                                   UnicodeString(value.charAt(i)) +
129                                   "' cannot be handled by the engine in offset " + i + "\n" + line);
130                         canHandleAllChars = false;
131                         break;
132                     }
133                 }
134                 if (! canHandleAllChars) {
135                     return;
136                 }
137 
138                 // If the engine can handle all the chars in the Input line, we
139                 // then find the break points by calling the engine.
140                 std::stringstream ss;
141 
142                 // Construct the UText which is expected by the the engine as
143                 // input from the UnicodeString.
144                 UText ut = UTEXT_INITIALIZER;
145                 utext_openConstUnicodeString(&ut, &value, &status);
146                 if (U_FAILURE(status)) {
147                     dataerrln("Could not utext_openConstUnicodeString for " + value + UnicodeString(u_errorName(status)));
148                     return;
149                 }
150 
151                 UVector32 actual(status);
152                 if (U_FAILURE(status)) {
153                     dataerrln("%s:%d Error %s Could not allocate UVextor32", __FILE__, __LINE__, u_errorName(status));
154                     return;
155                 }
156                 engine->findBreaks(&ut, 0, value.length(), actual, false, status);
157                 if (U_FAILURE(status)) {
158                     dataerrln("%s:%d Error %s findBreaks failed", __FILE__, __LINE__, u_errorName(status));
159                     return;
160                 }
161                 utext_close(&ut);
162                 for (int32_t i = 0; i < actual.size(); i++) {
163                     ss << actual.elementAti(i) << ", ";
164                 }
165                 ss << value.length();
166                 // Turn the break points into a string for easy comparison
167                 // output.
168                 actual_sep_str = "{" + ss.str() + "}";
169             } else if (key == "Output:" && !actual_sep_str.empty()) {
170                 std::string d;
171                 int32_t sep;
172                 int32_t start = 0;
173                 int32_t curr = 0;
174                 std::stringstream ss;
175                 while ((sep = value.indexOf(u'|', start)) >= 0) {
176                     int32_t len = sep - start;
177                     if (len > 0) {
178                         if (curr > 0) {
179                             ss << ", ";
180                         }
181                         curr += len;
182                         ss << curr;
183                     }
184                     start = sep + 1;
185                 }
186                 // Turn the break points into a string for easy comparison
187                 // output.
188                 std::string expected = "{" + ss.str() + "}";
189                 std::string utf8;
190 
191                 assertEquals((value + " Test Case#" + caseNum).toUTF8String<std::string>(utf8).c_str(),
192                              expected.c_str(), actual_sep_str.c_str());
193                 actual_sep_str.clear();
194             }
195         }
196         start = std::max(cr, lf) + 1;
197     } while (end >= 0);
198 
199     delete [] testFile;
200 }
201 
TestThaiGraphclust()202 void LSTMBETest::TestThaiGraphclust() {
203     runTestFromFile("Thai_graphclust_model4_heavy_Test.txt");
204 }
205 
TestThaiCodepoints()206 void LSTMBETest::TestThaiCodepoints() {
207     runTestFromFile("Thai_codepoints_exclusive_model5_heavy_Test.txt");
208 }
209 
TestBurmeseGraphclust()210 void LSTMBETest::TestBurmeseGraphclust() {
211     runTestFromFile("Burmese_graphclust_model5_heavy_Test.txt");
212 }
213 
createEngineFromTestData(const char * model,UScriptCode script,UErrorCode & status)214 const LanguageBreakEngine* LSTMBETest::createEngineFromTestData(
215         const char* model, UScriptCode script, UErrorCode& status) {
216     const char* testdatapath=loadTestData(status);
217     if(U_FAILURE(status))
218     {
219         dataerrln("Could not load testdata.dat " + UnicodeString(testdatapath) +  ", " +
220                   UnicodeString(u_errorName(status)));
221         return nullptr;
222     }
223 
224     LocalUResourceBundlePointer rb(
225         ures_openDirect(testdatapath, model, &status));
226     if (U_FAILURE(status)) {
227         dataerrln("Could not open " + UnicodeString(model) + " under " +  UnicodeString(testdatapath) +  ", " +
228                   UnicodeString(u_errorName(status)));
229         return nullptr;
230     }
231 
232     const LSTMData* data = CreateLSTMData(rb.orphan(), status);
233     if (U_FAILURE(status)) {
234         dataerrln("Could not CreateLSTMData " + UnicodeString(model) + " under " +  UnicodeString(testdatapath) +  ", " +
235                   UnicodeString(u_errorName(status)));
236         return nullptr;
237     }
238     if (data == nullptr) {
239         return nullptr;
240     }
241 
242     LocalPointer<const LanguageBreakEngine> engine(CreateLSTMBreakEngine(script, data, status));
243     if (U_FAILURE(status) || engine.getAlias() == nullptr) {
244         dataerrln("Could not CreateLSTMBreakEngine " + UnicodeString(testdatapath) +  ", " +
245                   UnicodeString(u_errorName(status)));
246         DeleteLSTMData(data);
247         return nullptr;
248     }
249     return engine.orphan();
250 }
251 
252 
TestThaiGraphclustWithLargeMemory()253 void LSTMBETest::TestThaiGraphclustWithLargeMemory() {
254     runTestWithLargeMemory("Thai_graphclust_model4_heavy", USCRIPT_THAI);
255 
256 }
257 
TestThaiCodepointsWithLargeMemory()258 void LSTMBETest::TestThaiCodepointsWithLargeMemory() {
259     runTestWithLargeMemory("Thai_codepoints_exclusive_model5_heavy", USCRIPT_THAI);
260 }
261 
262 constexpr int32_t MEMORY_TEST_THESHOLD_SHORT = 2 * 1024; // 2 K Unicode Chars.
263 constexpr int32_t MEMORY_TEST_THESHOLD = 32 * 1024; // 32 K Unicode Chars.
264 
265 // Test with very long unicode string.
runTestWithLargeMemory(const char * model,UScriptCode script)266 void LSTMBETest::runTestWithLargeMemory( const char* model, UScriptCode script) {
267     UErrorCode   status = U_ZERO_ERROR;
268     int32_t test_threshold = quick ? MEMORY_TEST_THESHOLD_SHORT : MEMORY_TEST_THESHOLD;
269     LocalPointer<const LanguageBreakEngine> engine(
270         createEngineFromTestData(model, script, status));
271     if (U_FAILURE(status)) {
272         dataerrln("Could not CreateLSTMBreakEngine for " + UnicodeString(model) + UnicodeString(u_errorName(status)));
273         return;
274     }
275     UnicodeString text(u"อ");  // start with a single Thai char.
276     UVector32 actual(status);
277     if (U_FAILURE(status)) {
278         dataerrln("%s:%d Error %s Could not allocate UVextor32", __FILE__, __LINE__, u_errorName(status));
279         return;
280     }
281     while (U_SUCCESS(status) && text.length() <= test_threshold) {
282         // Construct the UText which is expected by the the engine as
283         // input from the UnicodeString.
284         UText ut = UTEXT_INITIALIZER;
285         utext_openConstUnicodeString(&ut, &text, &status);
286         if (U_FAILURE(status)) {
287             dataerrln("Could not utext_openConstUnicodeString for " + text + UnicodeString(u_errorName(status)));
288             return;
289         }
290 
291         engine->findBreaks(&ut, 0, text.length(), actual, false, status);
292         utext_close(&ut);
293         text += text;
294     }
295 }
296 #endif // #if !UCONFIG_NO_BREAK_ITERATION
297