1 // © 2021 and later: Unicode, Inc. and others.
2 // License & terms of use: http://www.unicode.org/copyright.html
3
4 #include <utility>
5 #include <ctgmath>
6
7 #include "unicode/utypes.h"
8
9 #if !UCONFIG_NO_BREAK_ITERATION
10
11 #include "brkeng.h"
12 #include "charstr.h"
13 #include "cmemory.h"
14 #include "lstmbe.h"
15 #include "putilimp.h"
16 #include "uassert.h"
17 #include "ubrkimpl.h"
18 #include "uresimp.h"
19 #include "uvectr32.h"
20 #include "uvector.h"
21
22 #include "unicode/brkiter.h"
23 #include "unicode/resbund.h"
24 #include "unicode/ubrk.h"
25 #include "unicode/uniset.h"
26 #include "unicode/ustring.h"
27 #include "unicode/utf.h"
28
29 U_NAMESPACE_BEGIN
30
31 // Uncomment the following #define to debug.
32 // #define LSTM_DEBUG 1
33 // #define LSTM_VECTORIZER_DEBUG 1
34
/**
 * Abstract read-only view of a one-dimensional sequence of floats.
 * Implementations report their length through d1() and expose their
 * elements through get().
 */
class ReadArray1D {
public:
    virtual ~ReadArray1D();

    // Number of elements in the array.
    virtual int32_t d1() const = 0;

    // Value of the element at index i.
    virtual float get(int32_t i) const = 0;

#ifdef LSTM_DEBUG
    // Debug aid: writes every element to stdout between brackets,
    // starting a new line after each group of four values.
    void print() const {
        printf("\n[");
        const int32_t count = d1();
        for (int32_t pos = 0; pos < count; ++pos) {
            printf("%0.8e ", get(pos));
            const bool endOfGroup = (pos % 4 == 3);
            if (endOfGroup) {
                printf("\n");
            }
        }
        printf("]\n");
    }
#endif
};

// Defined out of line; the interface holds no state to release.
ReadArray1D::~ReadArray1D() = default;
59
/**
 * Abstract read-only view of a two-dimensional array of floats.
 * d1() is the extent of the first index and d2() the extent of the
 * second index accepted by get().
 */
class ReadArray2D {
public:
    virtual ~ReadArray2D();

    // Extent of the first index.
    virtual int32_t d1() const = 0;

    // Extent of the second index.
    virtual int32_t d2() const = 0;

    // Value of the element at position (i, j).
    virtual float get(int32_t i, int32_t j) const = 0;
};

// Defined out of line; the interface holds no state to release.
ReadArray2D::~ReadArray2D() = default;
74
75 /**
76 * A class to index a float array as a 1D Array without owning the pointer or
77 * copy the data.
78 */
79 class ConstArray1D : public ReadArray1D {
80 public:
ConstArray1D()81 ConstArray1D() : data_(nullptr), d1_(0) {}
82
ConstArray1D(const float * data,int32_t d1)83 ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {}
84
85 virtual ~ConstArray1D();
86
87 // Init the object, the object does not own the data nor copy.
88 // It is designed to directly use data from memory mapped resources.
init(const int32_t * data,int32_t d1)89 void init(const int32_t* data, int32_t d1) {
90 U_ASSERT(IEEE_754 == 1);
91 data_ = reinterpret_cast<const float*>(data);
92 d1_ = d1;
93 }
94
95 // ReadArray1D methods.
d1() const96 virtual int32_t d1() const override { return d1_; }
get(int32_t i) const97 virtual float get(int32_t i) const override {
98 U_ASSERT(i < d1_);
99 return data_[i];
100 }
101
102 private:
103 const float* data_;
104 int32_t d1_;
105 };
106
~ConstArray1D()107 ConstArray1D::~ConstArray1D()
108 {
109 }
110
111 /**
112 * A class to index a float array as a 2D Array without owning the pointer or
113 * copy the data.
114 */
115 class ConstArray2D : public ReadArray2D {
116 public:
ConstArray2D()117 ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {}
118
ConstArray2D(const float * data,int32_t d1,int32_t d2)119 ConstArray2D(const float* data, int32_t d1, int32_t d2)
120 : data_(data), d1_(d1), d2_(d2) {}
121
122 virtual ~ConstArray2D();
123
124 // Init the object, the object does not own the data nor copy.
125 // It is designed to directly use data from memory mapped resources.
init(const int32_t * data,int32_t d1,int32_t d2)126 void init(const int32_t* data, int32_t d1, int32_t d2) {
127 U_ASSERT(IEEE_754 == 1);
128 data_ = reinterpret_cast<const float*>(data);
129 d1_ = d1;
130 d2_ = d2;
131 }
132
133 // ReadArray2D methods.
d1() const134 inline int32_t d1() const override { return d1_; }
d2() const135 inline int32_t d2() const override { return d2_; }
get(int32_t i,int32_t j) const136 float get(int32_t i, int32_t j) const override {
137 U_ASSERT(i < d1_);
138 U_ASSERT(j < d2_);
139 return data_[i * d2_ + j];
140 }
141
142 // Expose the ith row as a ConstArray1D
row(int32_t i) const143 inline ConstArray1D row(int32_t i) const {
144 U_ASSERT(i < d1_);
145 return ConstArray1D(data_ + i * d2_, d2_);
146 }
147
148 private:
149 const float* data_;
150 int32_t d1_;
151 int32_t d2_;
152 };
153
~ConstArray2D()154 ConstArray2D::~ConstArray2D()
155 {
156 }
157
158 /**
159 * A class to allocate data as a writable 1D array.
160 * This is the main class implement matrix operation.
161 */
162 class Array1D : public ReadArray1D {
163 public:
Array1D()164 Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
Array1D(int32_t d1,UErrorCode & status)165 Array1D(int32_t d1, UErrorCode &status)
166 : memory_(uprv_malloc(d1 * sizeof(float))),
167 data_((float*)memory_), d1_(d1) {
168 if (U_SUCCESS(status)) {
169 if (memory_ == nullptr) {
170 status = U_MEMORY_ALLOCATION_ERROR;
171 return;
172 }
173 clear();
174 }
175 }
176
177 virtual ~Array1D();
178
179 // A special constructor which does not own the memory but writeable
180 // as a slice of an array.
Array1D(float * data,int32_t d1)181 Array1D(float* data, int32_t d1)
182 : memory_(nullptr), data_(data), d1_(d1) {}
183
184 // ReadArray1D methods.
d1() const185 virtual int32_t d1() const override { return d1_; }
get(int32_t i) const186 virtual float get(int32_t i) const override {
187 U_ASSERT(i < d1_);
188 return data_[i];
189 }
190
191 // Return the index which point to the max data in the array.
maxIndex() const192 inline int32_t maxIndex() const {
193 int32_t index = 0;
194 float max = data_[0];
195 for (int32_t i = 1; i < d1_; i++) {
196 if (data_[i] > max) {
197 max = data_[i];
198 index = i;
199 }
200 }
201 return index;
202 }
203
204 // Slice part of the array to a new one.
slice(int32_t from,int32_t size) const205 inline Array1D slice(int32_t from, int32_t size) const {
206 U_ASSERT(from >= 0);
207 U_ASSERT(from < d1_);
208 U_ASSERT(from + size <= d1_);
209 return Array1D(data_ + from, size);
210 }
211
212 // Add dot product of a 1D array and a 2D array into this one.
addDotProduct(const ReadArray1D & a,const ReadArray2D & b)213 inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) {
214 U_ASSERT(a.d1() == b.d1());
215 U_ASSERT(b.d2() == d1());
216 for (int32_t i = 0; i < d1(); i++) {
217 for (int32_t j = 0; j < a.d1(); j++) {
218 data_[i] += a.get(j) * b.get(j, i);
219 }
220 }
221 return *this;
222 }
223
224 // Hadamard Product the values of another array of the same size into this one.
hadamardProduct(const ReadArray1D & a)225 inline Array1D& hadamardProduct(const ReadArray1D& a) {
226 U_ASSERT(a.d1() == d1());
227 for (int32_t i = 0; i < d1(); i++) {
228 data_[i] *= a.get(i);
229 }
230 return *this;
231 }
232
233 // Add the Hadamard Product of two arrays of the same size into this one.
addHadamardProduct(const ReadArray1D & a,const ReadArray1D & b)234 inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) {
235 U_ASSERT(a.d1() == d1());
236 U_ASSERT(b.d1() == d1());
237 for (int32_t i = 0; i < d1(); i++) {
238 data_[i] += a.get(i) * b.get(i);
239 }
240 return *this;
241 }
242
243 // Add the values of another array of the same size into this one.
add(const ReadArray1D & a)244 inline Array1D& add(const ReadArray1D& a) {
245 U_ASSERT(a.d1() == d1());
246 for (int32_t i = 0; i < d1(); i++) {
247 data_[i] += a.get(i);
248 }
249 return *this;
250 }
251
252 // Assign the values of another array of the same size into this one.
assign(const ReadArray1D & a)253 inline Array1D& assign(const ReadArray1D& a) {
254 U_ASSERT(a.d1() == d1());
255 for (int32_t i = 0; i < d1(); i++) {
256 data_[i] = a.get(i);
257 }
258 return *this;
259 }
260
261 // Apply tanh to all the elements in the array.
tanh()262 inline Array1D& tanh() {
263 return tanh(*this);
264 }
265
266 // Apply tanh of a and store into this array.
tanh(const Array1D & a)267 inline Array1D& tanh(const Array1D& a) {
268 U_ASSERT(a.d1() == d1());
269 for (int32_t i = 0; i < d1_; i++) {
270 data_[i] = std::tanh(a.get(i));
271 }
272 return *this;
273 }
274
275 // Apply sigmoid to all the elements in the array.
sigmoid()276 inline Array1D& sigmoid() {
277 for (int32_t i = 0; i < d1_; i++) {
278 data_[i] = 1.0f/(1.0f + expf(-data_[i]));
279 }
280 return *this;
281 }
282
clear()283 inline Array1D& clear() {
284 uprv_memset(data_, 0, d1_ * sizeof(float));
285 return *this;
286 }
287
288 private:
289 void* memory_;
290 float* data_;
291 int32_t d1_;
292 };
293
~Array1D()294 Array1D::~Array1D()
295 {
296 uprv_free(memory_);
297 }
298
299 class Array2D : public ReadArray2D {
300 public:
Array2D()301 Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
Array2D(int32_t d1,int32_t d2,UErrorCode & status)302 Array2D(int32_t d1, int32_t d2, UErrorCode &status)
303 : memory_(uprv_malloc(d1 * d2 * sizeof(float))),
304 data_((float*)memory_), d1_(d1), d2_(d2) {
305 if (U_SUCCESS(status)) {
306 if (memory_ == nullptr) {
307 status = U_MEMORY_ALLOCATION_ERROR;
308 return;
309 }
310 clear();
311 }
312 }
313 virtual ~Array2D();
314
315 // ReadArray2D methods.
d1() const316 virtual int32_t d1() const override { return d1_; }
d2() const317 virtual int32_t d2() const override { return d2_; }
get(int32_t i,int32_t j) const318 virtual float get(int32_t i, int32_t j) const override {
319 U_ASSERT(i < d1_);
320 U_ASSERT(j < d2_);
321 return data_[i * d2_ + j];
322 }
323
row(int32_t i) const324 inline Array1D row(int32_t i) const {
325 U_ASSERT(i < d1_);
326 return Array1D(data_ + i * d2_, d2_);
327 }
328
clear()329 inline Array2D& clear() {
330 uprv_memset(data_, 0, d1_ * d2_ * sizeof(float));
331 return *this;
332 }
333
334 private:
335 void* memory_;
336 float* data_;
337 int32_t d1_;
338 int32_t d2_;
339 };
340
~Array2D()341 Array2D::~Array2D()
342 {
343 uprv_free(memory_);
344 }
345
// Output classes of the model. The numeric values are significant: the
// break engine casts the argmax index of the output layer to this type.
enum LSTMClass {
    BEGIN = 0,
    INSIDE = 1,
    END = 2,
    SINGLE = 3
};
352
// Unit of text that the model's dictionary is keyed by, as declared by the
// "type" string of the model resource ("codepoints" or "graphclust").
// UNKNOWN is the value used when neither string matches.
enum EmbeddingType {
    UNKNOWN = 0,
    CODE_POINTS = 1,
    GRAPHEME_CLUSTER = 2,
};
358
359 struct LSTMData : public UMemory {
360 LSTMData(UResourceBundle* rb, UErrorCode &status);
361 ~LSTMData();
362 UHashtable* fDict;
363 EmbeddingType fType;
364 const UChar* fName;
365 ConstArray2D fEmbedding;
366 ConstArray2D fForwardW;
367 ConstArray2D fForwardU;
368 ConstArray1D fForwardB;
369 ConstArray2D fBackwardW;
370 ConstArray2D fBackwardU;
371 ConstArray1D fBackwardB;
372 ConstArray2D fOutputW;
373 ConstArray1D fOutputB;
374
375 private:
376 UResourceBundle* fBundle;
377 };
378
LSTMData(UResourceBundle * rb,UErrorCode & status)379 LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
380 : fDict(nullptr), fType(UNKNOWN), fName(nullptr),
381 fBundle(rb)
382 {
383 if (U_FAILURE(status)) {
384 return;
385 }
386 if (IEEE_754 != 1) {
387 status = U_UNSUPPORTED_ERROR;
388 return;
389 }
390 LocalUResourceBundlePointer embeddings_res(
391 ures_getByKey(rb, "embeddings", nullptr, &status));
392 int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
393 LocalUResourceBundlePointer hunits_res(
394 ures_getByKey(rb, "hunits", nullptr, &status));
395 if (U_FAILURE(status)) return;
396 int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
397 const UChar* type = ures_getStringByKey(rb, "type", nullptr, &status);
398 if (U_FAILURE(status)) return;
399 if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
400 fType = CODE_POINTS;
401 } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
402 fType = GRAPHEME_CLUSTER;
403 }
404 fName = ures_getStringByKey(rb, "model", nullptr, &status);
405 LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));
406 if (U_FAILURE(status)) return;
407 int32_t data_len = 0;
408 const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);
409 fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);
410
411 StackUResourceBundle stackTempBundle;
412 ResourceDataValue value;
413 ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);
414 ResourceArray stringArray = value.getArray(status);
415 int32_t num_index = stringArray.getSize();
416 if (U_FAILURE(status)) { return; }
417
418 // put dict into hash
419 int32_t stringLength;
420 for (int32_t idx = 0; idx < num_index; idx++) {
421 stringArray.getValue(idx, value);
422 const UChar* str = value.getString(stringLength, status);
423 uhash_putiAllowZero(fDict, (void*)str, idx, &status);
424 if (U_FAILURE(status)) return;
425 #ifdef LSTM_VECTORIZER_DEBUG
426 printf("Assign [");
427 while (*str != 0x0000) {
428 printf("U+%04x ", *str);
429 str++;
430 }
431 printf("] map to %d\n", idx-1);
432 #endif
433 }
434 int32_t mat1_size = (num_index + 1) * embedding_size;
435 int32_t mat2_size = embedding_size * 4 * hunits;
436 int32_t mat3_size = hunits * 4 * hunits;
437 int32_t mat4_size = 4 * hunits;
438 int32_t mat5_size = mat2_size;
439 int32_t mat6_size = mat3_size;
440 int32_t mat7_size = mat4_size;
441 int32_t mat8_size = 2 * hunits * 4;
442 #if U_DEBUG
443 int32_t mat9_size = 4;
444 U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size +
445 mat6_size + mat7_size + mat8_size + mat9_size);
446 #endif
447
448 fEmbedding.init(data, (num_index + 1), embedding_size);
449 data += mat1_size;
450 fForwardW.init(data, embedding_size, 4 * hunits);
451 data += mat2_size;
452 fForwardU.init(data, hunits, 4 * hunits);
453 data += mat3_size;
454 fForwardB.init(data, 4 * hunits);
455 data += mat4_size;
456 fBackwardW.init(data, embedding_size, 4 * hunits);
457 data += mat5_size;
458 fBackwardU.init(data, hunits, 4 * hunits);
459 data += mat6_size;
460 fBackwardB.init(data, 4 * hunits);
461 data += mat7_size;
462 fOutputW.init(data, 2 * hunits, 4);
463 data += mat8_size;
464 fOutputB.init(data, 4);
465 }
466
~LSTMData()467 LSTMData::~LSTMData() {
468 uhash_close(fDict);
469 ures_close(fBundle);
470 }
471
472 class Vectorizer : public UMemory {
473 public:
Vectorizer(UHashtable * dict)474 Vectorizer(UHashtable* dict) : fDict(dict) {}
475 virtual ~Vectorizer();
476 virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
477 UVector32 &offsets, UVector32 &indices,
478 UErrorCode &status) const = 0;
479 protected:
stringToIndex(const UChar * str) const480 int32_t stringToIndex(const UChar* str) const {
481 UBool found = false;
482 int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);
483 if (!found) {
484 ret = fDict->count;
485 }
486 #ifdef LSTM_VECTORIZER_DEBUG
487 printf("[");
488 while (*str != 0x0000) {
489 printf("U+%04x ", *str);
490 str++;
491 }
492 printf("] map to %d\n", ret);
493 #endif
494 return ret;
495 }
496
497 private:
498 UHashtable* fDict;
499 };
500
~Vectorizer()501 Vectorizer::~Vectorizer()
502 {
503 }
504
505 class CodePointsVectorizer : public Vectorizer {
506 public:
CodePointsVectorizer(UHashtable * dict)507 CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {}
508 virtual ~CodePointsVectorizer();
509 virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
510 UVector32 &offsets, UVector32 &indices,
511 UErrorCode &status) const override;
512 };
513
~CodePointsVectorizer()514 CodePointsVectorizer::~CodePointsVectorizer()
515 {
516 }
517
vectorize(UText * text,int32_t startPos,int32_t endPos,UVector32 & offsets,UVector32 & indices,UErrorCode & status) const518 void CodePointsVectorizer::vectorize(
519 UText *text, int32_t startPos, int32_t endPos,
520 UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
521 {
522 if (offsets.ensureCapacity(endPos - startPos, status) &&
523 indices.ensureCapacity(endPos - startPos, status)) {
524 if (U_FAILURE(status)) return;
525 utext_setNativeIndex(text, startPos);
526 int32_t current;
527 UChar str[2] = {0, 0};
528 while (U_SUCCESS(status) &&
529 (current = (int32_t)utext_getNativeIndex(text)) < endPos) {
530 // Since the LSTMBreakEngine is currently only accept chars in BMP,
531 // we can ignore the possibility of hitting supplementary code
532 // point.
533 str[0] = (UChar) utext_next32(text);
534 U_ASSERT(!U_IS_SURROGATE(str[0]));
535 offsets.addElement(current, status);
536 indices.addElement(stringToIndex(str), status);
537 }
538 }
539 }
540
541 class GraphemeClusterVectorizer : public Vectorizer {
542 public:
GraphemeClusterVectorizer(UHashtable * dict)543 GraphemeClusterVectorizer(UHashtable* dict)
544 : Vectorizer(dict)
545 {
546 }
547 virtual ~GraphemeClusterVectorizer();
548 virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
549 UVector32 &offsets, UVector32 &indices,
550 UErrorCode &status) const override;
551 };
552
~GraphemeClusterVectorizer()553 GraphemeClusterVectorizer::~GraphemeClusterVectorizer()
554 {
555 }
556
557 constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;
558
vectorize(UText * text,int32_t startPos,int32_t endPos,UVector32 & offsets,UVector32 & indices,UErrorCode & status) const559 void GraphemeClusterVectorizer::vectorize(
560 UText *text, int32_t startPos, int32_t endPos,
561 UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
562 {
563 if (U_FAILURE(status)) return;
564 if (!offsets.ensureCapacity(endPos - startPos, status) ||
565 !indices.ensureCapacity(endPos - startPos, status)) {
566 return;
567 }
568 if (U_FAILURE(status)) return;
569 LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
570 if (U_FAILURE(status)) return;
571 graphemeIter->setText(text, status);
572 if (U_FAILURE(status)) return;
573
574 if (startPos != 0) {
575 graphemeIter->preceding(startPos);
576 }
577 int32_t last = startPos;
578 int32_t current = startPos;
579 UChar str[MAX_GRAPHEME_CLSTER_LENGTH];
580 while ((current = graphemeIter->next()) != BreakIterator::DONE) {
581 if (current >= endPos) {
582 break;
583 }
584 if (current > startPos) {
585 utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
586 if (U_FAILURE(status)) return;
587 offsets.addElement(last, status);
588 indices.addElement(stringToIndex(str), status);
589 if (U_FAILURE(status)) return;
590 }
591 last = current;
592 }
593 if (U_FAILURE(status) || last >= endPos) {
594 return;
595 }
596 utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
597 if (U_SUCCESS(status)) {
598 offsets.addElement(last, status);
599 indices.addElement(stringToIndex(str), status);
600 }
601 }
602
603 // Computing LSTM as stated in
604 // https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
605 // ifco is temp array allocate outside which does not need to be
606 // input/output value but could avoid unnecessary memory alloc/free if passing
607 // in.
compute(int32_t hunits,const ReadArray2D & W,const ReadArray2D & U,const ReadArray1D & b,const ReadArray1D & x,Array1D & h,Array1D & c,Array1D & ifco)608 void compute(
609 int32_t hunits,
610 const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
611 const ReadArray1D& x, Array1D& h, Array1D& c,
612 Array1D& ifco)
613 {
614 // ifco = x * W + h * U + b
615 ifco.assign(b)
616 .addDotProduct(x, W)
617 .addDotProduct(h, U);
618
619 ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod
620 ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid
621 ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
622 ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
623
624 c.hadamardProduct(ifco.slice(hunits, hunits))
625 .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));
626
627 h.tanh(c)
628 .hadamardProduct(ifco.slice(3*hunits, hunits));
629 }
630
631 // Minimum word size
632 static const int32_t MIN_WORD = 2;
633
634 // Minimum number of characters for two words
635 static const int32_t MIN_WORD_SPAN = MIN_WORD * 2;
636
637 int32_t
divideUpDictionaryRange(UText * text,int32_t startPos,int32_t endPos,UVector32 & foundBreaks,UErrorCode & status) const638 LSTMBreakEngine::divideUpDictionaryRange( UText *text,
639 int32_t startPos,
640 int32_t endPos,
641 UVector32 &foundBreaks,
642 UErrorCode& status) const {
643 if (U_FAILURE(status)) return 0;
644 int32_t beginFoundBreakSize = foundBreaks.size();
645 utext_setNativeIndex(text, startPos);
646 utext_moveIndex32(text, MIN_WORD_SPAN);
647 if (utext_getNativeIndex(text) >= endPos) {
648 return 0; // Not enough characters for two words
649 }
650 utext_setNativeIndex(text, startPos);
651
652 UVector32 offsets(status);
653 UVector32 indices(status);
654 if (U_FAILURE(status)) return 0;
655 fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
656 if (U_FAILURE(status)) return 0;
657 int32_t* offsetsBuf = offsets.getBuffer();
658 int32_t* indicesBuf = indices.getBuffer();
659
660 int32_t input_seq_len = indices.size();
661 int32_t hunits = fData->fForwardU.d1();
662
663 // ----- Begin of all the Array memory allocation needed for this function
664 // Allocate temp array used inside compute()
665 Array1D ifco(4 * hunits, status);
666
667 Array1D c(hunits, status);
668 Array1D logp(4, status);
669
670 // TODO: limit size of hBackward. If input_seq_len is too big, we could
671 // run out of memory.
672 // Backward LSTM
673 Array2D hBackward(input_seq_len, hunits, status);
674
675 // Allocate fbRow and slice the internal array in two.
676 Array1D fbRow(2 * hunits, status);
677
678 // ----- End of all the Array memory allocation needed for this function
679 if (U_FAILURE(status)) return 0;
680
681 // To save the needed memory usage, the following is different from the
682 // Python or ICU4X implementation. We first perform the Backward LSTM
683 // and then merge the iteration of the forward LSTM and the output layer
684 // together because we only neetdto remember the h[t-1] for Forward LSTM.
685 for (int32_t i = input_seq_len - 1; i >= 0; i--) {
686 Array1D hRow = hBackward.row(i);
687 if (i != input_seq_len - 1) {
688 hRow.assign(hBackward.row(i+1));
689 }
690 #ifdef LSTM_DEBUG
691 printf("hRow %d\n", i);
692 hRow.print();
693 printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]);
694 printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i);
695 fData->fEmbedding.row(indicesBuf[i]).print();
696 #endif // LSTM_DEBUG
697 compute(hunits,
698 fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
699 fData->fEmbedding.row(indicesBuf[i]),
700 hRow, c, ifco);
701 }
702
703
704 Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow.
705 Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data n fbRow.
706
707 // The following iteration merge the forward LSTM and the output layer
708 // together.
709 c.clear(); // reuse c since it is the same size.
710 for (int32_t i = 0; i < input_seq_len; i++) {
711 #ifdef LSTM_DEBUG
712 printf("forwardRow %d\n", i);
713 forwardRow.print();
714 #endif // LSTM_DEBUG
715 // Forward LSTM
716 // Calculate the result into forwardRow, which point to the data in the first half
717 // of fbRow.
718 compute(hunits,
719 fData->fForwardW, fData->fForwardU, fData->fForwardB,
720 fData->fEmbedding.row(indicesBuf[i]),
721 forwardRow, c, ifco);
722
723 // assign the data from hBackward.row(i) to second half of fbRowa.
724 backwardRow.assign(hBackward.row(i));
725
726 logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW);
727 #ifdef LSTM_DEBUG
728 printf("backwardRow %d\n", i);
729 backwardRow.print();
730 printf("logp %d\n", i);
731 logp.print();
732 #endif // LSTM_DEBUG
733
734 // current = argmax(logp)
735 LSTMClass current = (LSTMClass)logp.maxIndex();
736 // BIES logic.
737 if (current == BEGIN || current == SINGLE) {
738 if (i != 0) {
739 foundBreaks.addElement(offsetsBuf[i], status);
740 if (U_FAILURE(status)) return 0;
741 }
742 }
743 }
744 return foundBreaks.size() - beginFoundBreakSize;
745 }
746
createVectorizer(const LSTMData * data,UErrorCode & status)747 Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) {
748 if (U_FAILURE(status)) {
749 return nullptr;
750 }
751 switch (data->fType) {
752 case CODE_POINTS:
753 return new CodePointsVectorizer(data->fDict);
754 break;
755 case GRAPHEME_CLUSTER:
756 return new GraphemeClusterVectorizer(data->fDict);
757 break;
758 default:
759 break;
760 }
761 UPRV_UNREACHABLE_EXIT;
762 }
763
LSTMBreakEngine(const LSTMData * data,const UnicodeSet & set,UErrorCode & status)764 LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)
765 : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))
766 {
767 if (U_FAILURE(status)) {
768 fData = nullptr; // If failure, we should not delete fData in destructor because the caller will do so.
769 return;
770 }
771 setCharacters(set);
772 }
773
~LSTMBreakEngine()774 LSTMBreakEngine::~LSTMBreakEngine() {
775 delete fData;
776 delete fVectorizer;
777 }
778
name() const779 const UChar* LSTMBreakEngine::name() const {
780 return fData->fName;
781 }
782
defaultLSTM(UScriptCode script,UErrorCode & status)783 UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) {
784 // open root from brkitr tree.
785 UResourceBundle *b = ures_open(U_ICUDATA_BRKITR, "", &status);
786 b = ures_getByKeyWithFallback(b, "lstm", b, &status);
787 UnicodeString result = ures_getUnicodeStringByKey(b, uscript_getShortName(script), &status);
788 ures_close(b);
789 return result;
790 }
791
CreateLSTMDataForScript(UScriptCode script,UErrorCode & status)792 U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)
793 {
794 if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) {
795 return nullptr;
796 }
797 UnicodeString name = defaultLSTM(script, status);
798 if (U_FAILURE(status)) return nullptr;
799 CharString namebuf;
800 namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.'));
801
802 LocalUResourceBundlePointer rb(
803 ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
804 if (U_FAILURE(status)) return nullptr;
805
806 return CreateLSTMData(rb.orphan(), status);
807 }
808
CreateLSTMData(UResourceBundle * rb,UErrorCode & status)809 U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)
810 {
811 return new LSTMData(rb, status);
812 }
813
814 U_CAPI const LanguageBreakEngine* U_EXPORT2
CreateLSTMBreakEngine(UScriptCode script,const LSTMData * data,UErrorCode & status)815 CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)
816 {
817 UnicodeString unicodeSetString;
818 switch(script) {
819 case USCRIPT_THAI:
820 unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]");
821 break;
822 case USCRIPT_MYANMAR:
823 unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]");
824 break;
825 default:
826 delete data;
827 return nullptr;
828 }
829 UnicodeSet unicodeSet;
830 unicodeSet.applyPattern(unicodeSetString, status);
831 const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);
832 if (U_FAILURE(status) || engine == nullptr) {
833 if (engine != nullptr) {
834 delete engine;
835 } else {
836 status = U_MEMORY_ALLOCATION_ERROR;
837 }
838 return nullptr;
839 }
840 return engine;
841 }
842
DeleteLSTMData(const LSTMData * data)843 U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data)
844 {
845 delete data;
846 }
847
LSTMDataName(const LSTMData * data)848 U_CAPI const UChar* U_EXPORT2 LSTMDataName(const LSTMData* data)
849 {
850 return data->fName;
851 }
852
853 U_NAMESPACE_END
854
855 #endif /* #if !UCONFIG_NO_BREAK_ITERATION */
856