1 // © 2021 and later: Unicode, Inc. and others.
2 // License & terms of use: http://www.unicode.org/copyright.html
3
4 #include "unicode/utypes.h"
5
6 #if !UCONFIG_NO_BREAK_ITERATION
7
8 #include "lstmbetst.h"
9 #include "lstmbe.h"
10
#include <algorithm>
#include <memory>
#include <sstream>
#include <vector>
14
15 #include "charstr.h"
16
17 //---------------------------------------------
18 // runIndexedTest
19 //---------------------------------------------
20
21
runIndexedTest(int32_t index,UBool exec,const char * & name,char * params)22 void LSTMBETest::runIndexedTest( int32_t index, UBool exec, const char* &name, char* params )
23 {
24 fTestParams = params;
25
26 TESTCASE_AUTO_BEGIN;
27
28 TESTCASE_AUTO(TestThaiGraphclust);
29 TESTCASE_AUTO(TestThaiCodepoints);
30 TESTCASE_AUTO(TestBurmeseGraphclust);
31 TESTCASE_AUTO(TestThaiGraphclustWithLargeMemory);
32 TESTCASE_AUTO(TestThaiCodepointsWithLargeMemory);
33
34 TESTCASE_AUTO_END;
35 }
36
37
38 //--------------------------------------------------------------------------------------
39 //
40 // LSTMBETest constructor and destructor
41 //
42 //--------------------------------------------------------------------------------------
43
LSTMBETest()44 LSTMBETest::LSTMBETest() {
45 fTestParams = NULL;
46 }
47
48
~LSTMBETest()49 LSTMBETest::~LSTMBETest() {
50 }
51
getScriptFromModelName(const std::string & modelName)52 UScriptCode getScriptFromModelName(const std::string& modelName) {
53 if (modelName.find("Thai") == 0) {
54 return USCRIPT_THAI;
55 } else if (modelName.find("Burmese") == 0) {
56 return USCRIPT_MYANMAR;
57 }
58 // Add for other script codes.
59 UPRV_UNREACHABLE_EXIT;
60 }
61
62 // Read file generated by
63 // https://github.com/unicode-org/lstm_word_segmentation/blob/master/segment_text.py
64 // as test cases and compare the Output.
65 // Format of the file
66 // Model:\t[Model Name (such as 'Thai_graphclust_model4_heavy')]
67 // Embedding:\t[Embedding type (such as 'grapheme_clusters_tf')]
68 // Input:\t[source text]
69 // Output:\t[expected output separated by | ]
70 // Input: ...
71 // Output: ...
72 // The test will ensure the Input contains only the characters can be handled by
73 // the model. Since by default the LSTM models are not included, all the tested
74 // models need to be included under source/test/testdata.
75
runTestFromFile(const char * filename)76 void LSTMBETest::runTestFromFile(const char* filename) {
77 UErrorCode status = U_ZERO_ERROR;
78 LocalPointer<const LanguageBreakEngine> engine;
79 // Open and read the test data file.
80 const char *testDataDirectory = IntlTest::getSourceTestData(status);
81 CharString testFileName(testDataDirectory, -1, status);
82 testFileName.append(filename, -1, status);
83
84 int len;
85 UChar *testFile = ReadAndConvertFile(testFileName.data(), len, "UTF-8", status);
86 if (U_FAILURE(status)) {
87 errln("%s:%d Error %s opening test file %s", __FILE__, __LINE__, u_errorName(status), filename);
88 return;
89 }
90
91 // Put the test data into a UnicodeString
92 UnicodeString testString(false, testFile, len);
93
94 int32_t start = 0;
95
96 UnicodeString line;
97 int32_t end;
98 std::string actual_sep_str;
99 int32_t caseNum = 0;
100 // Iterate through all the lines in the test file.
101 do {
102 int32_t cr = testString.indexOf(u'\r', start);
103 int32_t lf = testString.indexOf(u'\n', start);
104 end = cr >= 0 ? (lf >= 0 ? std::min(cr, lf) : cr) : lf;
105 line = testString.tempSubString(start, end < 0 ? INT32_MAX : end - start);
106 if (line.length() > 0) {
107 // Separate each line to key and value by TAB.
108 int32_t tab = line.indexOf(u'\t');
109 UnicodeString key = line.tempSubString(0, tab);
110 const UnicodeString value = line.tempSubString(tab+1);
111
112 if (key == "Model:") {
113 std::string modelName;
114 value.toUTF8String<std::string>(modelName);
115 engine.adoptInstead(createEngineFromTestData(modelName.c_str(), getScriptFromModelName(modelName), status));
116 if (U_FAILURE(status)) {
117 dataerrln("Could not CreateLSTMBreakEngine for " + line + UnicodeString(u_errorName(status)));
118 return;
119 }
120 } else if (key == "Input:") {
121 // First, we ensure all the char in the Input lines are accepted
122 // by the engine before we test them.
123 caseNum++;
124 bool canHandleAllChars = true;
125 for (int32_t i = 0; i < value.length(); i++) {
126 if (!engine->handles(value.charAt(i))) {
127 errln(UnicodeString("Test Case#") + caseNum + " contains char '" +
128 UnicodeString(value.charAt(i)) +
129 "' cannot be handled by the engine in offset " + i + "\n" + line);
130 canHandleAllChars = false;
131 break;
132 }
133 }
134 if (! canHandleAllChars) {
135 return;
136 }
137
138 // If the engine can handle all the chars in the Input line, we
139 // then find the break points by calling the engine.
140 std::stringstream ss;
141
142 // Construct the UText which is expected by the the engine as
143 // input from the UnicodeString.
144 UText ut = UTEXT_INITIALIZER;
145 utext_openConstUnicodeString(&ut, &value, &status);
146 if (U_FAILURE(status)) {
147 dataerrln("Could not utext_openConstUnicodeString for " + value + UnicodeString(u_errorName(status)));
148 return;
149 }
150
151 UVector32 actual(status);
152 if (U_FAILURE(status)) {
153 dataerrln("%s:%d Error %s Could not allocate UVextor32", __FILE__, __LINE__, u_errorName(status));
154 return;
155 }
156 engine->findBreaks(&ut, 0, value.length(), actual, false, status);
157 if (U_FAILURE(status)) {
158 dataerrln("%s:%d Error %s findBreaks failed", __FILE__, __LINE__, u_errorName(status));
159 return;
160 }
161 utext_close(&ut);
162 for (int32_t i = 0; i < actual.size(); i++) {
163 ss << actual.elementAti(i) << ", ";
164 }
165 ss << value.length();
166 // Turn the break points into a string for easy comparison
167 // output.
168 actual_sep_str = "{" + ss.str() + "}";
169 } else if (key == "Output:" && !actual_sep_str.empty()) {
170 std::string d;
171 int32_t sep;
172 int32_t start = 0;
173 int32_t curr = 0;
174 std::stringstream ss;
175 while ((sep = value.indexOf(u'|', start)) >= 0) {
176 int32_t len = sep - start;
177 if (len > 0) {
178 if (curr > 0) {
179 ss << ", ";
180 }
181 curr += len;
182 ss << curr;
183 }
184 start = sep + 1;
185 }
186 // Turn the break points into a string for easy comparison
187 // output.
188 std::string expected = "{" + ss.str() + "}";
189 std::string utf8;
190
191 assertEquals((value + " Test Case#" + caseNum).toUTF8String<std::string>(utf8).c_str(),
192 expected.c_str(), actual_sep_str.c_str());
193 actual_sep_str.clear();
194 }
195 }
196 start = std::max(cr, lf) + 1;
197 } while (end >= 0);
198
199 delete [] testFile;
200 }
201
TestThaiGraphclust()202 void LSTMBETest::TestThaiGraphclust() {
203 runTestFromFile("Thai_graphclust_model4_heavy_Test.txt");
204 }
205
TestThaiCodepoints()206 void LSTMBETest::TestThaiCodepoints() {
207 runTestFromFile("Thai_codepoints_exclusive_model5_heavy_Test.txt");
208 }
209
TestBurmeseGraphclust()210 void LSTMBETest::TestBurmeseGraphclust() {
211 runTestFromFile("Burmese_graphclust_model5_heavy_Test.txt");
212 }
213
createEngineFromTestData(const char * model,UScriptCode script,UErrorCode & status)214 const LanguageBreakEngine* LSTMBETest::createEngineFromTestData(
215 const char* model, UScriptCode script, UErrorCode& status) {
216 const char* testdatapath=loadTestData(status);
217 if(U_FAILURE(status))
218 {
219 dataerrln("Could not load testdata.dat " + UnicodeString(testdatapath) + ", " +
220 UnicodeString(u_errorName(status)));
221 return nullptr;
222 }
223
224 LocalUResourceBundlePointer rb(
225 ures_openDirect(testdatapath, model, &status));
226 if (U_FAILURE(status)) {
227 dataerrln("Could not open " + UnicodeString(model) + " under " + UnicodeString(testdatapath) + ", " +
228 UnicodeString(u_errorName(status)));
229 return nullptr;
230 }
231
232 const LSTMData* data = CreateLSTMData(rb.orphan(), status);
233 if (U_FAILURE(status)) {
234 dataerrln("Could not CreateLSTMData " + UnicodeString(model) + " under " + UnicodeString(testdatapath) + ", " +
235 UnicodeString(u_errorName(status)));
236 return nullptr;
237 }
238 if (data == nullptr) {
239 return nullptr;
240 }
241
242 LocalPointer<const LanguageBreakEngine> engine(CreateLSTMBreakEngine(script, data, status));
243 if (U_FAILURE(status) || engine.getAlias() == nullptr) {
244 dataerrln("Could not CreateLSTMBreakEngine " + UnicodeString(testdatapath) + ", " +
245 UnicodeString(u_errorName(status)));
246 DeleteLSTMData(data);
247 return nullptr;
248 }
249 return engine.orphan();
250 }
251
252
TestThaiGraphclustWithLargeMemory()253 void LSTMBETest::TestThaiGraphclustWithLargeMemory() {
254 runTestWithLargeMemory("Thai_graphclust_model4_heavy", USCRIPT_THAI);
255
256 }
257
TestThaiCodepointsWithLargeMemory()258 void LSTMBETest::TestThaiCodepointsWithLargeMemory() {
259 runTestWithLargeMemory("Thai_codepoints_exclusive_model5_heavy", USCRIPT_THAI);
260 }
261
// Length limits (in UTF-16 code units) for the text built by
// runTestWithLargeMemory; the short limit is used in quick mode.
constexpr int32_t MEMORY_TEST_THESHOLD_SHORT = 2048;   // 2 * 1024
constexpr int32_t MEMORY_TEST_THESHOLD = 32768;        // 32 * 1024
264
265 // Test with very long unicode string.
runTestWithLargeMemory(const char * model,UScriptCode script)266 void LSTMBETest::runTestWithLargeMemory( const char* model, UScriptCode script) {
267 UErrorCode status = U_ZERO_ERROR;
268 int32_t test_threshold = quick ? MEMORY_TEST_THESHOLD_SHORT : MEMORY_TEST_THESHOLD;
269 LocalPointer<const LanguageBreakEngine> engine(
270 createEngineFromTestData(model, script, status));
271 if (U_FAILURE(status)) {
272 dataerrln("Could not CreateLSTMBreakEngine for " + UnicodeString(model) + UnicodeString(u_errorName(status)));
273 return;
274 }
275 UnicodeString text(u"อ"); // start with a single Thai char.
276 UVector32 actual(status);
277 if (U_FAILURE(status)) {
278 dataerrln("%s:%d Error %s Could not allocate UVextor32", __FILE__, __LINE__, u_errorName(status));
279 return;
280 }
281 while (U_SUCCESS(status) && text.length() <= test_threshold) {
282 // Construct the UText which is expected by the the engine as
283 // input from the UnicodeString.
284 UText ut = UTEXT_INITIALIZER;
285 utext_openConstUnicodeString(&ut, &text, &status);
286 if (U_FAILURE(status)) {
287 dataerrln("Could not utext_openConstUnicodeString for " + text + UnicodeString(u_errorName(status)));
288 return;
289 }
290
291 engine->findBreaks(&ut, 0, text.length(), actual, false, status);
292 utext_close(&ut);
293 text += text;
294 }
295 }
296 #endif // #if !UCONFIG_NO_BREAK_ITERATION
297