1 // © 2022 and later: Unicode, Inc. and others.
2 // License & terms of use: http://www.unicode.org/copyright.html
3
4 #include "unicode/utypes.h"
5
6 #if !UCONFIG_NO_BREAK_ITERATION
7
8 #include "cmemory.h"
9 #include "mlbe.h"
10 #include "uassert.h"
11 #include "ubrkimpl.h"
12 #include "unicode/resbund.h"
13 #include "unicode/udata.h"
14 #include "unicode/utf16.h"
15 #include "uresimp.h"
16 #include "util.h"
17 #include "uvectr32.h"
18
19 U_NAMESPACE_BEGIN
20
// Offsets into the fModel table array at which each n-gram family begins.
enum class ModelIndex {
    kUWStart = 0,  // first of the six unigram tables (UW1 ~ UW6)
    kBWStart = 6,  // first of the three bigram tables (BW1 ~ BW3)
    kTWStart = 9   // first of the four trigram tables (TW1 ~ TW4)
};
22
MlBreakEngine(const UnicodeSet & digitOrOpenPunctuationOrAlphabetSet,const UnicodeSet & closePunctuationSet,UErrorCode & status)23 MlBreakEngine::MlBreakEngine(const UnicodeSet &digitOrOpenPunctuationOrAlphabetSet,
24 const UnicodeSet &closePunctuationSet, UErrorCode &status)
25 : fDigitOrOpenPunctuationOrAlphabetSet(digitOrOpenPunctuationOrAlphabetSet),
26 fClosePunctuationSet(closePunctuationSet),
27 fNegativeSum(0) {
28 if (U_FAILURE(status)) {
29 return;
30 }
31 loadMLModel(status);
32 }
33
~MlBreakEngine()34 MlBreakEngine::~MlBreakEngine() {}
35
divideUpRange(UText * inText,int32_t rangeStart,int32_t rangeEnd,UVector32 & foundBreaks,const UnicodeString & inString,const LocalPointer<UVector32> & inputMap,UErrorCode & status) const36 int32_t MlBreakEngine::divideUpRange(UText *inText, int32_t rangeStart, int32_t rangeEnd,
37 UVector32 &foundBreaks, const UnicodeString &inString,
38 const LocalPointer<UVector32> &inputMap,
39 UErrorCode &status) const {
40 if (U_FAILURE(status)) {
41 return 0;
42 }
43 if (rangeStart >= rangeEnd) {
44 status = U_ILLEGAL_ARGUMENT_ERROR;
45 return 0;
46 }
47
48 UVector32 boundary(inString.countChar32() + 1, status);
49 if (U_FAILURE(status)) {
50 return 0;
51 }
52 int32_t numBreaks = 0;
53 int32_t codePointLength = inString.countChar32();
54 // The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint.
55 // In each iteration, it evaluates the 4th char and then moves forward one char like a sliding
56 // window. Initially, the first six values in the indexList are [-1, -1, 0, 1, 2, 3]. After
57 // moving forward, finally the last six values in the indexList are
58 // [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra "-1".
59 int32_t indexSize = codePointLength + 4;
60 LocalMemory<int32_t> indexList(static_cast<int32_t*>(uprv_malloc(indexSize * sizeof(int32_t))));
61 if (indexList.isNull()) {
62 status = U_MEMORY_ALLOCATION_ERROR;
63 return 0;
64 }
65 int32_t numCodeUnits = initIndexList(inString, indexList.getAlias(), status);
66
67 // Add a break for the start.
68 boundary.addElement(0, status);
69 numBreaks++;
70 if (U_FAILURE(status)) return 0;
71
72 for (int32_t idx = 0; idx + 1 < codePointLength && U_SUCCESS(status); idx++) {
73 numBreaks =
74 evaluateBreakpoint(inString, indexList.getAlias(), idx, numCodeUnits, numBreaks, boundary, status);
75 if (idx + 4 < codePointLength) {
76 indexList[idx + 6] = numCodeUnits;
77 numCodeUnits += U16_LENGTH(inString.char32At(indexList[idx + 6]));
78 }
79 }
80
81 if (U_FAILURE(status)) return 0;
82
83 // Add a break for the end if there is not one there already.
84 if (boundary.lastElementi() != inString.countChar32()) {
85 boundary.addElement(inString.countChar32(), status);
86 numBreaks++;
87 }
88
89 int32_t prevCPPos = -1;
90 int32_t prevUTextPos = -1;
91 int32_t correctedNumBreaks = 0;
92 for (int32_t i = 0; i < numBreaks; i++) {
93 int32_t cpPos = boundary.elementAti(i);
94 int32_t utextPos = inputMap.isValid() ? inputMap->elementAti(cpPos) : cpPos + rangeStart;
95 U_ASSERT(cpPos > prevCPPos);
96 U_ASSERT(utextPos >= prevUTextPos);
97
98 if (utextPos > prevUTextPos) {
99 if (utextPos != rangeStart ||
100 (utextPos > 0 &&
101 fClosePunctuationSet.contains(utext_char32At(inText, utextPos - 1)))) {
102 foundBreaks.push(utextPos, status);
103 correctedNumBreaks++;
104 }
105 } else {
106 // Normalization expanded the input text, the dictionary found a boundary
107 // within the expansion, giving two boundaries with the same index in the
108 // original text. Ignore the second. See ticket #12918.
109 --numBreaks;
110 }
111 prevCPPos = cpPos;
112 prevUTextPos = utextPos;
113 }
114 (void)prevCPPos; // suppress compiler warnings about unused variable
115
116 UChar32 nextChar = utext_char32At(inText, rangeEnd);
117 if (!foundBreaks.isEmpty() && foundBreaks.peeki() == rangeEnd) {
118 // In phrase breaking, there has to be a breakpoint between Cj character and
119 // the number/open punctuation.
120 // E.g. る文字「そうだ、京都」->る▁文字▁「そうだ、▁京都」-> breakpoint between 字 and「
121 // E.g. 乗車率90%程度だろうか -> 乗車▁率▁90%▁程度だろうか -> breakpoint between 率 and 9
122 // E.g. しかもロゴがUnicode! -> しかも▁ロゴが▁Unicode!-> breakpoint between が and U
123 if (!fDigitOrOpenPunctuationOrAlphabetSet.contains(nextChar)) {
124 foundBreaks.popi();
125 correctedNumBreaks--;
126 }
127 }
128
129 return correctedNumBreaks;
130 }
131
evaluateBreakpoint(const UnicodeString & inString,int32_t * indexList,int32_t startIdx,int32_t numCodeUnits,int32_t numBreaks,UVector32 & boundary,UErrorCode & status) const132 int32_t MlBreakEngine::evaluateBreakpoint(const UnicodeString &inString, int32_t *indexList,
133 int32_t startIdx, int32_t numCodeUnits, int32_t numBreaks,
134 UVector32 &boundary, UErrorCode &status) const {
135 if (U_FAILURE(status)) {
136 return numBreaks;
137 }
138 int32_t start = 0, end = 0;
139 int32_t score = fNegativeSum;
140
141 for (int i = 0; i < 6; i++) {
142 // UW1 ~ UW6
143 start = startIdx + i;
144 if (indexList[start] != -1) {
145 end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits;
146 score += fModel[static_cast<int32_t>(ModelIndex::kUWStart) + i].geti(
147 inString.tempSubString(indexList[start], end - indexList[start]));
148 }
149 }
150 for (int i = 0; i < 3; i++) {
151 // BW1 ~ BW3
152 start = startIdx + i + 1;
153 if (indexList[start] != -1 && indexList[start + 1] != -1) {
154 end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits;
155 score += fModel[static_cast<int32_t>(ModelIndex::kBWStart) + i].geti(
156 inString.tempSubString(indexList[start], end - indexList[start]));
157 }
158 }
159 for (int i = 0; i < 4; i++) {
160 // TW1 ~ TW4
161 start = startIdx + i;
162 if (indexList[start] != -1 && indexList[start + 1] != -1 && indexList[start + 2] != -1) {
163 end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits;
164 score += fModel[static_cast<int32_t>(ModelIndex::kTWStart) + i].geti(
165 inString.tempSubString(indexList[start], end - indexList[start]));
166 }
167 }
168
169 if (score > 0) {
170 boundary.addElement(startIdx + 1, status);
171 numBreaks++;
172 }
173 return numBreaks;
174 }
175
initIndexList(const UnicodeString & inString,int32_t * indexList,UErrorCode & status) const176 int32_t MlBreakEngine::initIndexList(const UnicodeString &inString, int32_t *indexList,
177 UErrorCode &status) const {
178 if (U_FAILURE(status)) {
179 return 0;
180 }
181 int32_t index = 0;
182 int32_t length = inString.countChar32();
183 // Set all (lenght+4) items inside indexLength to -1 presuming -1 is 4 bytes of 0xff.
184 uprv_memset(indexList, 0xff, (length + 4) * sizeof(int32_t));
185 if (length > 0) {
186 indexList[2] = 0;
187 index = U16_LENGTH(inString.char32At(0));
188 if (length > 1) {
189 indexList[3] = index;
190 index += U16_LENGTH(inString.char32At(index));
191 if (length > 2) {
192 indexList[4] = index;
193 index += U16_LENGTH(inString.char32At(index));
194 if (length > 3) {
195 indexList[5] = index;
196 index += U16_LENGTH(inString.char32At(index));
197 }
198 }
199 }
200 }
201 return index;
202 }
203
loadMLModel(UErrorCode & error)204 void MlBreakEngine::loadMLModel(UErrorCode &error) {
205 // BudouX's model consists of thirteen categories, each of which is make up of pairs of the
206 // feature and its score. As integrating it into jaml.txt, we define thirteen kinds of key and
207 // value to represent the feature and the corresponding score respectively.
208
209 if (U_FAILURE(error)) return;
210
211 UnicodeString key;
212 StackUResourceBundle stackTempBundle;
213 ResourceDataValue modelKey;
214
215 LocalUResourceBundlePointer rbp(ures_openDirect(U_ICUDATA_BRKITR, "jaml", &error));
216 UResourceBundle *rb = rbp.getAlias();
217 if (U_FAILURE(error)) return;
218
219 int32_t index = 0;
220 initKeyValue(rb, "UW1Keys", "UW1Values", fModel[index++], error);
221 initKeyValue(rb, "UW2Keys", "UW2Values", fModel[index++], error);
222 initKeyValue(rb, "UW3Keys", "UW3Values", fModel[index++], error);
223 initKeyValue(rb, "UW4Keys", "UW4Values", fModel[index++], error);
224 initKeyValue(rb, "UW5Keys", "UW5Values", fModel[index++], error);
225 initKeyValue(rb, "UW6Keys", "UW6Values", fModel[index++], error);
226 initKeyValue(rb, "BW1Keys", "BW1Values", fModel[index++], error);
227 initKeyValue(rb, "BW2Keys", "BW2Values", fModel[index++], error);
228 initKeyValue(rb, "BW3Keys", "BW3Values", fModel[index++], error);
229 initKeyValue(rb, "TW1Keys", "TW1Values", fModel[index++], error);
230 initKeyValue(rb, "TW2Keys", "TW2Values", fModel[index++], error);
231 initKeyValue(rb, "TW3Keys", "TW3Values", fModel[index++], error);
232 initKeyValue(rb, "TW4Keys", "TW4Values", fModel[index++], error);
233 fNegativeSum /= 2;
234 }
235
initKeyValue(UResourceBundle * rb,const char * keyName,const char * valueName,Hashtable & model,UErrorCode & error)236 void MlBreakEngine::initKeyValue(UResourceBundle *rb, const char *keyName, const char *valueName,
237 Hashtable &model, UErrorCode &error) {
238 int32_t keySize = 0;
239 int32_t valueSize = 0;
240 int32_t stringLength = 0;
241 UnicodeString key;
242 StackUResourceBundle stackTempBundle;
243 ResourceDataValue modelKey;
244
245 // get modelValues
246 LocalUResourceBundlePointer modelValue(ures_getByKey(rb, valueName, nullptr, &error));
247 const int32_t *value = ures_getIntVector(modelValue.getAlias(), &valueSize, &error);
248 if (U_FAILURE(error)) return;
249
250 // get modelKeys
251 ures_getValueWithFallback(rb, keyName, stackTempBundle.getAlias(), modelKey, error);
252 ResourceArray stringArray = modelKey.getArray(error);
253 keySize = stringArray.getSize();
254 if (U_FAILURE(error)) return;
255
256 for (int32_t idx = 0; idx < keySize; idx++) {
257 stringArray.getValue(idx, modelKey);
258 key = UnicodeString(modelKey.getString(stringLength, error));
259 if (U_SUCCESS(error)) {
260 U_ASSERT(idx < valueSize);
261 fNegativeSum -= value[idx];
262 model.puti(key, value[idx], error);
263 }
264 }
265 }
266
267 U_NAMESPACE_END
268
269 #endif /* #if !UCONFIG_NO_BREAK_ITERATION */
270