1 // © 2021 and later: Unicode, Inc. and others.
2 // License & terms of use: http://www.unicode.org/copyright.html
3
4 #include <complex>
5 #include <utility>
6
7 #include "unicode/utypes.h"
8
9 #if !UCONFIG_NO_BREAK_ITERATION
10
11 #include "brkeng.h"
12 #include "charstr.h"
13 #include "cmemory.h"
14 #include "lstmbe.h"
15 #include "putilimp.h"
16 #include "uassert.h"
17 #include "ubrkimpl.h"
18 #include "uresimp.h"
19 #include "uvectr32.h"
20 #include "uvector.h"
21
22 #include "unicode/brkiter.h"
23 #include "unicode/resbund.h"
24 #include "unicode/ubrk.h"
25 #include "unicode/uniset.h"
26 #include "unicode/ustring.h"
27 #include "unicode/utf.h"
28
29 U_NAMESPACE_BEGIN
30
31 // Uncomment the following #define to debug.
32 // #define LSTM_DEBUG 1
33 // #define LSTM_VECTORIZER_DEBUG 1
34
/**
 * Read-only interface to a one-dimensional array of floats.
 *
 * Implementations report their length through d1() and expose each element
 * through get().
 */
class ReadArray1D {
public:
    virtual ~ReadArray1D();

    // Number of elements in the array.
    virtual int32_t d1() const = 0;

    // Value of the element at index i.
    virtual float get(int32_t i) const = 0;

#ifdef LSTM_DEBUG
    // Debug helper: write all elements to stdout, breaking the line after
    // every fourth value.
    void print() const {
        printf("\n[");
        for (int32_t idx = 0; idx < d1(); ++idx) {
            printf("%0.8e ", get(idx));
            const bool endOfGroup = (idx % 4 == 3);
            if (endOfGroup) {
                printf("\n");
            }
        }
        printf("]\n");
    }
#endif
};

// Defined out of line so the class has a single home for its destructor.
ReadArray1D::~ReadArray1D() = default;
59
/**
 * Read-only interface to a two-dimensional array of floats.
 *
 * d1() is the size of the first index, d2() the size of the second, and
 * get(i, j) returns the element addressed by the pair (i, j).
 */
class ReadArray2D {
public:
    virtual ~ReadArray2D();

    // Extent of the first index.
    virtual int32_t d1() const = 0;

    // Extent of the second index.
    virtual int32_t d2() const = 0;

    // Value of the element at (i, j).
    virtual float get(int32_t i, int32_t j) const = 0;
};

// Defined out of line so the class has a single home for its destructor.
ReadArray2D::~ReadArray2D() = default;
74
75 /**
76 * A class to index a float array as a 1D Array without owning the pointer or
77 * copy the data.
78 */
79 class ConstArray1D : public ReadArray1D {
80 public:
ConstArray1D()81 ConstArray1D() : data_(nullptr), d1_(0) {}
82
ConstArray1D(const float * data,int32_t d1)83 ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {}
84
85 virtual ~ConstArray1D();
86
87 // Init the object, the object does not own the data nor copy.
88 // It is designed to directly use data from memory mapped resources.
init(const int32_t * data,int32_t d1)89 void init(const int32_t* data, int32_t d1) {
90 U_ASSERT(IEEE_754 == 1);
91 data_ = reinterpret_cast<const float*>(data);
92 d1_ = d1;
93 }
94
95 // ReadArray1D methods.
d1() const96 virtual int32_t d1() const override { return d1_; }
get(int32_t i) const97 virtual float get(int32_t i) const override {
98 U_ASSERT(i < d1_);
99 return data_[i];
100 }
101
102 private:
103 const float* data_;
104 int32_t d1_;
105 };
106
~ConstArray1D()107 ConstArray1D::~ConstArray1D()
108 {
109 }
110
111 /**
112 * A class to index a float array as a 2D Array without owning the pointer or
113 * copy the data.
114 */
115 class ConstArray2D : public ReadArray2D {
116 public:
ConstArray2D()117 ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {}
118
ConstArray2D(const float * data,int32_t d1,int32_t d2)119 ConstArray2D(const float* data, int32_t d1, int32_t d2)
120 : data_(data), d1_(d1), d2_(d2) {}
121
122 virtual ~ConstArray2D();
123
124 // Init the object, the object does not own the data nor copy.
125 // It is designed to directly use data from memory mapped resources.
init(const int32_t * data,int32_t d1,int32_t d2)126 void init(const int32_t* data, int32_t d1, int32_t d2) {
127 U_ASSERT(IEEE_754 == 1);
128 data_ = reinterpret_cast<const float*>(data);
129 d1_ = d1;
130 d2_ = d2;
131 }
132
133 // ReadArray2D methods.
d1() const134 inline int32_t d1() const override { return d1_; }
d2() const135 inline int32_t d2() const override { return d2_; }
get(int32_t i,int32_t j) const136 float get(int32_t i, int32_t j) const override {
137 U_ASSERT(i < d1_);
138 U_ASSERT(j < d2_);
139 return data_[i * d2_ + j];
140 }
141
142 // Expose the ith row as a ConstArray1D
row(int32_t i) const143 inline ConstArray1D row(int32_t i) const {
144 U_ASSERT(i < d1_);
145 return ConstArray1D(data_ + i * d2_, d2_);
146 }
147
148 private:
149 const float* data_;
150 int32_t d1_;
151 int32_t d2_;
152 };
153
~ConstArray2D()154 ConstArray2D::~ConstArray2D()
155 {
156 }
157
158 /**
159 * A class to allocate data as a writable 1D array.
160 * This is the main class implement matrix operation.
161 */
162 class Array1D : public ReadArray1D {
163 public:
Array1D()164 Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
Array1D(int32_t d1,UErrorCode & status)165 Array1D(int32_t d1, UErrorCode &status)
166 : memory_(uprv_malloc(d1 * sizeof(float))),
167 data_((float*)memory_), d1_(d1) {
168 if (U_SUCCESS(status)) {
169 if (memory_ == nullptr) {
170 status = U_MEMORY_ALLOCATION_ERROR;
171 return;
172 }
173 clear();
174 }
175 }
176
177 virtual ~Array1D();
178
179 // A special constructor which does not own the memory but writeable
180 // as a slice of an array.
Array1D(float * data,int32_t d1)181 Array1D(float* data, int32_t d1)
182 : memory_(nullptr), data_(data), d1_(d1) {}
183
184 // ReadArray1D methods.
d1() const185 virtual int32_t d1() const override { return d1_; }
get(int32_t i) const186 virtual float get(int32_t i) const override {
187 U_ASSERT(i < d1_);
188 return data_[i];
189 }
190
191 // Return the index which point to the max data in the array.
maxIndex() const192 inline int32_t maxIndex() const {
193 int32_t index = 0;
194 float max = data_[0];
195 for (int32_t i = 1; i < d1_; i++) {
196 if (data_[i] > max) {
197 max = data_[i];
198 index = i;
199 }
200 }
201 return index;
202 }
203
204 // Slice part of the array to a new one.
slice(int32_t from,int32_t size) const205 inline Array1D slice(int32_t from, int32_t size) const {
206 U_ASSERT(from >= 0);
207 U_ASSERT(from < d1_);
208 U_ASSERT(from + size <= d1_);
209 return Array1D(data_ + from, size);
210 }
211
212 // Add dot product of a 1D array and a 2D array into this one.
addDotProduct(const ReadArray1D & a,const ReadArray2D & b)213 inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) {
214 U_ASSERT(a.d1() == b.d1());
215 U_ASSERT(b.d2() == d1());
216 for (int32_t i = 0; i < d1(); i++) {
217 for (int32_t j = 0; j < a.d1(); j++) {
218 data_[i] += a.get(j) * b.get(j, i);
219 }
220 }
221 return *this;
222 }
223
224 // Hadamard Product the values of another array of the same size into this one.
hadamardProduct(const ReadArray1D & a)225 inline Array1D& hadamardProduct(const ReadArray1D& a) {
226 U_ASSERT(a.d1() == d1());
227 for (int32_t i = 0; i < d1(); i++) {
228 data_[i] *= a.get(i);
229 }
230 return *this;
231 }
232
233 // Add the Hadamard Product of two arrays of the same size into this one.
addHadamardProduct(const ReadArray1D & a,const ReadArray1D & b)234 inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) {
235 U_ASSERT(a.d1() == d1());
236 U_ASSERT(b.d1() == d1());
237 for (int32_t i = 0; i < d1(); i++) {
238 data_[i] += a.get(i) * b.get(i);
239 }
240 return *this;
241 }
242
243 // Add the values of another array of the same size into this one.
add(const ReadArray1D & a)244 inline Array1D& add(const ReadArray1D& a) {
245 U_ASSERT(a.d1() == d1());
246 for (int32_t i = 0; i < d1(); i++) {
247 data_[i] += a.get(i);
248 }
249 return *this;
250 }
251
252 // Assign the values of another array of the same size into this one.
assign(const ReadArray1D & a)253 inline Array1D& assign(const ReadArray1D& a) {
254 U_ASSERT(a.d1() == d1());
255 for (int32_t i = 0; i < d1(); i++) {
256 data_[i] = a.get(i);
257 }
258 return *this;
259 }
260
261 // Apply tanh to all the elements in the array.
tanh()262 inline Array1D& tanh() {
263 return tanh(*this);
264 }
265
266 // Apply tanh of a and store into this array.
tanh(const Array1D & a)267 inline Array1D& tanh(const Array1D& a) {
268 U_ASSERT(a.d1() == d1());
269 for (int32_t i = 0; i < d1_; i++) {
270 data_[i] = std::tanh(a.get(i));
271 }
272 return *this;
273 }
274
275 // Apply sigmoid to all the elements in the array.
sigmoid()276 inline Array1D& sigmoid() {
277 for (int32_t i = 0; i < d1_; i++) {
278 data_[i] = 1.0f/(1.0f + expf(-data_[i]));
279 }
280 return *this;
281 }
282
clear()283 inline Array1D& clear() {
284 uprv_memset(data_, 0, d1_ * sizeof(float));
285 return *this;
286 }
287
288 private:
289 void* memory_;
290 float* data_;
291 int32_t d1_;
292 };
293
~Array1D()294 Array1D::~Array1D()
295 {
296 uprv_free(memory_);
297 }
298
299 class Array2D : public ReadArray2D {
300 public:
Array2D()301 Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
Array2D(int32_t d1,int32_t d2,UErrorCode & status)302 Array2D(int32_t d1, int32_t d2, UErrorCode &status)
303 : memory_(uprv_malloc(d1 * d2 * sizeof(float))),
304 data_((float*)memory_), d1_(d1), d2_(d2) {
305 if (U_SUCCESS(status)) {
306 if (memory_ == nullptr) {
307 status = U_MEMORY_ALLOCATION_ERROR;
308 return;
309 }
310 clear();
311 }
312 }
313 virtual ~Array2D();
314
315 // ReadArray2D methods.
d1() const316 virtual int32_t d1() const override { return d1_; }
d2() const317 virtual int32_t d2() const override { return d2_; }
get(int32_t i,int32_t j) const318 virtual float get(int32_t i, int32_t j) const override {
319 U_ASSERT(i < d1_);
320 U_ASSERT(j < d2_);
321 return data_[i * d2_ + j];
322 }
323
row(int32_t i) const324 inline Array1D row(int32_t i) const {
325 U_ASSERT(i < d1_);
326 return Array1D(data_ + i * d2_, d2_);
327 }
328
clear()329 inline Array2D& clear() {
330 uprv_memset(data_, 0, d1_ * d2_ * sizeof(float));
331 return *this;
332 }
333
334 private:
335 void* memory_;
336 float* data_;
337 int32_t d1_;
338 int32_t d2_;
339 };
340
~Array2D()341 Array2D::~Array2D()
342 {
343 uprv_free(memory_);
344 }
345
// Label assigned to each input element. The numeric values are significant:
// a label is obtained by casting the index of the largest of the four
// output-layer values.
enum LSTMClass {
    BEGIN,
    INSIDE,
    END,
    SINGLE
};
352
// How the input text is cut into elements before lookup in the model's
// dictionary: one element per code point or one per grapheme cluster.
// UNKNOWN is the value used until a recognized type has been read.
enum EmbeddingType {
    UNKNOWN,
    CODE_POINTS,
    GRAPHEME_CLUSTER
};
358
359 struct LSTMData : public UMemory {
360 LSTMData(UResourceBundle* rb, UErrorCode &status);
361 ~LSTMData();
362 UHashtable* fDict;
363 EmbeddingType fType;
364 const UChar* fName;
365 ConstArray2D fEmbedding;
366 ConstArray2D fForwardW;
367 ConstArray2D fForwardU;
368 ConstArray1D fForwardB;
369 ConstArray2D fBackwardW;
370 ConstArray2D fBackwardU;
371 ConstArray1D fBackwardB;
372 ConstArray2D fOutputW;
373 ConstArray1D fOutputB;
374
375 private:
376 UResourceBundle* fBundle;
377 };
378
LSTMData(UResourceBundle * rb,UErrorCode & status)379 LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
380 : fDict(nullptr), fType(UNKNOWN), fName(nullptr),
381 fBundle(rb)
382 {
383 if (U_FAILURE(status)) {
384 return;
385 }
386 if (IEEE_754 != 1) {
387 status = U_UNSUPPORTED_ERROR;
388 return;
389 }
390 LocalUResourceBundlePointer embeddings_res(
391 ures_getByKey(rb, "embeddings", nullptr, &status));
392 int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
393 LocalUResourceBundlePointer hunits_res(
394 ures_getByKey(rb, "hunits", nullptr, &status));
395 if (U_FAILURE(status)) return;
396 int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
397 const UChar* type = ures_getStringByKey(rb, "type", nullptr, &status);
398 if (U_FAILURE(status)) return;
399 if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
400 fType = CODE_POINTS;
401 } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
402 fType = GRAPHEME_CLUSTER;
403 }
404 fName = ures_getStringByKey(rb, "model", nullptr, &status);
405 LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));
406 if (U_FAILURE(status)) return;
407 int32_t data_len = 0;
408 const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);
409 fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);
410
411 StackUResourceBundle stackTempBundle;
412 ResourceDataValue value;
413 ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);
414 ResourceArray stringArray = value.getArray(status);
415 int32_t num_index = stringArray.getSize();
416 if (U_FAILURE(status)) { return; }
417
418 // put dict into hash
419 int32_t stringLength;
420 for (int32_t idx = 0; idx < num_index; idx++) {
421 stringArray.getValue(idx, value);
422 const UChar* str = value.getString(stringLength, status);
423 uhash_putiAllowZero(fDict, (void*)str, idx, &status);
424 if (U_FAILURE(status)) return;
425 #ifdef LSTM_VECTORIZER_DEBUG
426 printf("Assign [");
427 while (*str != 0x0000) {
428 printf("U+%04x ", *str);
429 str++;
430 }
431 printf("] map to %d\n", idx-1);
432 #endif
433 }
434 int32_t mat1_size = (num_index + 1) * embedding_size;
435 int32_t mat2_size = embedding_size * 4 * hunits;
436 int32_t mat3_size = hunits * 4 * hunits;
437 int32_t mat4_size = 4 * hunits;
438 int32_t mat5_size = mat2_size;
439 int32_t mat6_size = mat3_size;
440 int32_t mat7_size = mat4_size;
441 int32_t mat8_size = 2 * hunits * 4;
442 #if U_DEBUG
443 int32_t mat9_size = 4;
444 U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size +
445 mat6_size + mat7_size + mat8_size + mat9_size);
446 #endif
447
448 fEmbedding.init(data, (num_index + 1), embedding_size);
449 data += mat1_size;
450 fForwardW.init(data, embedding_size, 4 * hunits);
451 data += mat2_size;
452 fForwardU.init(data, hunits, 4 * hunits);
453 data += mat3_size;
454 fForwardB.init(data, 4 * hunits);
455 data += mat4_size;
456 fBackwardW.init(data, embedding_size, 4 * hunits);
457 data += mat5_size;
458 fBackwardU.init(data, hunits, 4 * hunits);
459 data += mat6_size;
460 fBackwardB.init(data, 4 * hunits);
461 data += mat7_size;
462 fOutputW.init(data, 2 * hunits, 4);
463 data += mat8_size;
464 fOutputB.init(data, 4);
465 }
466
~LSTMData()467 LSTMData::~LSTMData() {
468 uhash_close(fDict);
469 ures_close(fBundle);
470 }
471
472 class Vectorizer : public UMemory {
473 public:
Vectorizer(UHashtable * dict)474 Vectorizer(UHashtable* dict) : fDict(dict) {}
475 virtual ~Vectorizer();
476 virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
477 UVector32 &offsets, UVector32 &indices,
478 UErrorCode &status) const = 0;
479 protected:
stringToIndex(const UChar * str) const480 int32_t stringToIndex(const UChar* str) const {
481 UBool found = false;
482 int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);
483 if (!found) {
484 ret = fDict->count;
485 }
486 #ifdef LSTM_VECTORIZER_DEBUG
487 printf("[");
488 while (*str != 0x0000) {
489 printf("U+%04x ", *str);
490 str++;
491 }
492 printf("] map to %d\n", ret);
493 #endif
494 return ret;
495 }
496
497 private:
498 UHashtable* fDict;
499 };
500
~Vectorizer()501 Vectorizer::~Vectorizer()
502 {
503 }
504
505 class CodePointsVectorizer : public Vectorizer {
506 public:
CodePointsVectorizer(UHashtable * dict)507 CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {}
508 virtual ~CodePointsVectorizer();
509 virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
510 UVector32 &offsets, UVector32 &indices,
511 UErrorCode &status) const override;
512 };
513
~CodePointsVectorizer()514 CodePointsVectorizer::~CodePointsVectorizer()
515 {
516 }
517
vectorize(UText * text,int32_t startPos,int32_t endPos,UVector32 & offsets,UVector32 & indices,UErrorCode & status) const518 void CodePointsVectorizer::vectorize(
519 UText *text, int32_t startPos, int32_t endPos,
520 UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
521 {
522 if (offsets.ensureCapacity(endPos - startPos, status) &&
523 indices.ensureCapacity(endPos - startPos, status)) {
524 if (U_FAILURE(status)) return;
525 utext_setNativeIndex(text, startPos);
526 int32_t current;
527 UChar str[2] = {0, 0};
528 while (U_SUCCESS(status) &&
529 (current = (int32_t)utext_getNativeIndex(text)) < endPos) {
530 // Since the LSTMBreakEngine is currently only accept chars in BMP,
531 // we can ignore the possibility of hitting supplementary code
532 // point.
533 str[0] = (UChar) utext_next32(text);
534 U_ASSERT(!U_IS_SURROGATE(str[0]));
535 offsets.addElement(current, status);
536 indices.addElement(stringToIndex(str), status);
537 }
538 }
539 }
540
541 class GraphemeClusterVectorizer : public Vectorizer {
542 public:
GraphemeClusterVectorizer(UHashtable * dict)543 GraphemeClusterVectorizer(UHashtable* dict)
544 : Vectorizer(dict)
545 {
546 }
547 virtual ~GraphemeClusterVectorizer();
548 virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
549 UVector32 &offsets, UVector32 &indices,
550 UErrorCode &status) const override;
551 };
552
~GraphemeClusterVectorizer()553 GraphemeClusterVectorizer::~GraphemeClusterVectorizer()
554 {
555 }
556
557 constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;
558
vectorize(UText * text,int32_t startPos,int32_t endPos,UVector32 & offsets,UVector32 & indices,UErrorCode & status) const559 void GraphemeClusterVectorizer::vectorize(
560 UText *text, int32_t startPos, int32_t endPos,
561 UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
562 {
563 if (U_FAILURE(status)) return;
564 if (!offsets.ensureCapacity(endPos - startPos, status) ||
565 !indices.ensureCapacity(endPos - startPos, status)) {
566 return;
567 }
568 if (U_FAILURE(status)) return;
569 LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
570 if (U_FAILURE(status)) return;
571 graphemeIter->setText(text, status);
572 if (U_FAILURE(status)) return;
573
574 if (startPos != 0) {
575 graphemeIter->preceding(startPos);
576 }
577 int32_t last = startPos;
578 int32_t current = startPos;
579 UChar str[MAX_GRAPHEME_CLSTER_LENGTH];
580 while ((current = graphemeIter->next()) != BreakIterator::DONE) {
581 if (current >= endPos) {
582 break;
583 }
584 if (current > startPos) {
585 utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
586 if (U_FAILURE(status)) return;
587 offsets.addElement(last, status);
588 indices.addElement(stringToIndex(str), status);
589 if (U_FAILURE(status)) return;
590 }
591 last = current;
592 }
593 if (U_FAILURE(status) || last >= endPos) {
594 return;
595 }
596 utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
597 if (U_SUCCESS(status)) {
598 offsets.addElement(last, status);
599 indices.addElement(stringToIndex(str), status);
600 }
601 }
602
603 // Computing LSTM as stated in
604 // https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
605 // ifco is temp array allocate outside which does not need to be
606 // input/output value but could avoid unnecessary memory alloc/free if passing
607 // in.
compute(int32_t hunits,const ReadArray2D & W,const ReadArray2D & U,const ReadArray1D & b,const ReadArray1D & x,Array1D & h,Array1D & c,Array1D & ifco)608 void compute(
609 int32_t hunits,
610 const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
611 const ReadArray1D& x, Array1D& h, Array1D& c,
612 Array1D& ifco)
613 {
614 // ifco = x * W + h * U + b
615 ifco.assign(b)
616 .addDotProduct(x, W)
617 .addDotProduct(h, U);
618
619 ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod
620 ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid
621 ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
622 ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
623
624 c.hadamardProduct(ifco.slice(hunits, hunits))
625 .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));
626
627 h.tanh(c)
628 .hadamardProduct(ifco.slice(3*hunits, hunits));
629 }
630
// Minimum word size
static constexpr int32_t MIN_WORD = 2;

// Minimum number of characters for two words
static constexpr int32_t MIN_WORD_SPAN = 2 * MIN_WORD;
636
637 int32_t
divideUpDictionaryRange(UText * text,int32_t startPos,int32_t endPos,UVector32 & foundBreaks,UBool,UErrorCode & status) const638 LSTMBreakEngine::divideUpDictionaryRange( UText *text,
639 int32_t startPos,
640 int32_t endPos,
641 UVector32 &foundBreaks,
642 UBool /* isPhraseBreaking */,
643 UErrorCode& status) const {
644 if (U_FAILURE(status)) return 0;
645 int32_t beginFoundBreakSize = foundBreaks.size();
646 utext_setNativeIndex(text, startPos);
647 utext_moveIndex32(text, MIN_WORD_SPAN);
648 if (utext_getNativeIndex(text) >= endPos) {
649 return 0; // Not enough characters for two words
650 }
651 utext_setNativeIndex(text, startPos);
652
653 UVector32 offsets(status);
654 UVector32 indices(status);
655 if (U_FAILURE(status)) return 0;
656 fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
657 if (U_FAILURE(status)) return 0;
658 int32_t* offsetsBuf = offsets.getBuffer();
659 int32_t* indicesBuf = indices.getBuffer();
660
661 int32_t input_seq_len = indices.size();
662 int32_t hunits = fData->fForwardU.d1();
663
664 // ----- Begin of all the Array memory allocation needed for this function
665 // Allocate temp array used inside compute()
666 Array1D ifco(4 * hunits, status);
667
668 Array1D c(hunits, status);
669 Array1D logp(4, status);
670
671 // TODO: limit size of hBackward. If input_seq_len is too big, we could
672 // run out of memory.
673 // Backward LSTM
674 Array2D hBackward(input_seq_len, hunits, status);
675
676 // Allocate fbRow and slice the internal array in two.
677 Array1D fbRow(2 * hunits, status);
678
679 // ----- End of all the Array memory allocation needed for this function
680 if (U_FAILURE(status)) return 0;
681
682 // To save the needed memory usage, the following is different from the
683 // Python or ICU4X implementation. We first perform the Backward LSTM
684 // and then merge the iteration of the forward LSTM and the output layer
685 // together because we only neetdto remember the h[t-1] for Forward LSTM.
686 for (int32_t i = input_seq_len - 1; i >= 0; i--) {
687 Array1D hRow = hBackward.row(i);
688 if (i != input_seq_len - 1) {
689 hRow.assign(hBackward.row(i+1));
690 }
691 #ifdef LSTM_DEBUG
692 printf("hRow %d\n", i);
693 hRow.print();
694 printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]);
695 printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i);
696 fData->fEmbedding.row(indicesBuf[i]).print();
697 #endif // LSTM_DEBUG
698 compute(hunits,
699 fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
700 fData->fEmbedding.row(indicesBuf[i]),
701 hRow, c, ifco);
702 }
703
704
705 Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow.
706 Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data n fbRow.
707
708 // The following iteration merge the forward LSTM and the output layer
709 // together.
710 c.clear(); // reuse c since it is the same size.
711 for (int32_t i = 0; i < input_seq_len; i++) {
712 #ifdef LSTM_DEBUG
713 printf("forwardRow %d\n", i);
714 forwardRow.print();
715 #endif // LSTM_DEBUG
716 // Forward LSTM
717 // Calculate the result into forwardRow, which point to the data in the first half
718 // of fbRow.
719 compute(hunits,
720 fData->fForwardW, fData->fForwardU, fData->fForwardB,
721 fData->fEmbedding.row(indicesBuf[i]),
722 forwardRow, c, ifco);
723
724 // assign the data from hBackward.row(i) to second half of fbRowa.
725 backwardRow.assign(hBackward.row(i));
726
727 logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW);
728 #ifdef LSTM_DEBUG
729 printf("backwardRow %d\n", i);
730 backwardRow.print();
731 printf("logp %d\n", i);
732 logp.print();
733 #endif // LSTM_DEBUG
734
735 // current = argmax(logp)
736 LSTMClass current = (LSTMClass)logp.maxIndex();
737 // BIES logic.
738 if (current == BEGIN || current == SINGLE) {
739 if (i != 0) {
740 foundBreaks.addElement(offsetsBuf[i], status);
741 if (U_FAILURE(status)) return 0;
742 }
743 }
744 }
745 return foundBreaks.size() - beginFoundBreakSize;
746 }
747
createVectorizer(const LSTMData * data,UErrorCode & status)748 Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) {
749 if (U_FAILURE(status)) {
750 return nullptr;
751 }
752 switch (data->fType) {
753 case CODE_POINTS:
754 return new CodePointsVectorizer(data->fDict);
755 break;
756 case GRAPHEME_CLUSTER:
757 return new GraphemeClusterVectorizer(data->fDict);
758 break;
759 default:
760 break;
761 }
762 UPRV_UNREACHABLE_EXIT;
763 }
764
LSTMBreakEngine(const LSTMData * data,const UnicodeSet & set,UErrorCode & status)765 LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)
766 : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))
767 {
768 if (U_FAILURE(status)) {
769 fData = nullptr; // If failure, we should not delete fData in destructor because the caller will do so.
770 return;
771 }
772 setCharacters(set);
773 }
774
~LSTMBreakEngine()775 LSTMBreakEngine::~LSTMBreakEngine() {
776 delete fData;
777 delete fVectorizer;
778 }
779
name() const780 const UChar* LSTMBreakEngine::name() const {
781 return fData->fName;
782 }
783
defaultLSTM(UScriptCode script,UErrorCode & status)784 UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) {
785 // open root from brkitr tree.
786 UResourceBundle *b = ures_open(U_ICUDATA_BRKITR, "", &status);
787 b = ures_getByKeyWithFallback(b, "lstm", b, &status);
788 UnicodeString result = ures_getUnicodeStringByKey(b, uscript_getShortName(script), &status);
789 ures_close(b);
790 return result;
791 }
792
CreateLSTMDataForScript(UScriptCode script,UErrorCode & status)793 U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)
794 {
795 if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) {
796 return nullptr;
797 }
798 UnicodeString name = defaultLSTM(script, status);
799 if (U_FAILURE(status)) return nullptr;
800 CharString namebuf;
801 namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.'));
802
803 LocalUResourceBundlePointer rb(
804 ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
805 if (U_FAILURE(status)) return nullptr;
806
807 return CreateLSTMData(rb.orphan(), status);
808 }
809
CreateLSTMData(UResourceBundle * rb,UErrorCode & status)810 U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)
811 {
812 return new LSTMData(rb, status);
813 }
814
815 U_CAPI const LanguageBreakEngine* U_EXPORT2
CreateLSTMBreakEngine(UScriptCode script,const LSTMData * data,UErrorCode & status)816 CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)
817 {
818 UnicodeString unicodeSetString;
819 switch(script) {
820 case USCRIPT_THAI:
821 unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]");
822 break;
823 case USCRIPT_MYANMAR:
824 unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]");
825 break;
826 default:
827 delete data;
828 return nullptr;
829 }
830 UnicodeSet unicodeSet;
831 unicodeSet.applyPattern(unicodeSetString, status);
832 const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);
833 if (U_FAILURE(status) || engine == nullptr) {
834 if (engine != nullptr) {
835 delete engine;
836 } else {
837 status = U_MEMORY_ALLOCATION_ERROR;
838 }
839 return nullptr;
840 }
841 return engine;
842 }
843
DeleteLSTMData(const LSTMData * data)844 U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data)
845 {
846 delete data;
847 }
848
LSTMDataName(const LSTMData * data)849 U_CAPI const UChar* U_EXPORT2 LSTMDataName(const LSTMData* data)
850 {
851 return data->fName;
852 }
853
854 U_NAMESPACE_END
855
856 #endif /* #if !UCONFIG_NO_BREAK_ITERATION */
857