• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // © 2021 and later: Unicode, Inc. and others.
2 // License & terms of use: http://www.unicode.org/copyright.html
3 
4 #include <utility>
5 #include <ctgmath>
6 
7 #include "unicode/utypes.h"
8 
9 #if !UCONFIG_NO_BREAK_ITERATION
10 
11 #include "brkeng.h"
12 #include "charstr.h"
13 #include "cmemory.h"
14 #include "lstmbe.h"
15 #include "putilimp.h"
16 #include "uassert.h"
17 #include "ubrkimpl.h"
18 #include "uresimp.h"
19 #include "uvectr32.h"
20 #include "uvector.h"
21 
22 #include "unicode/brkiter.h"
23 #include "unicode/resbund.h"
24 #include "unicode/ubrk.h"
25 #include "unicode/uniset.h"
26 #include "unicode/ustring.h"
27 #include "unicode/utf.h"
28 
29 U_NAMESPACE_BEGIN
30 
31 // Uncomment the following #define to debug.
32 // #define LSTM_DEBUG 1
33 // #define LSTM_VECTORIZER_DEBUG 1
34 
/**
 * Read-only view of a one-dimensional sequence of floats.
 * Implementations report their length through d1() and hand out single
 * elements through get().
 */
class ReadArray1D {
public:
    virtual ~ReadArray1D();

    // Number of elements.
    virtual int32_t d1() const = 0;

    // Element at position i.
    virtual float get(int32_t i) const = 0;

#ifdef LSTM_DEBUG
    // Debug helper: dump every element, four values per output line.
    void print() const {
        printf("\n[");
        int32_t pos = 0;
        while (pos < d1()) {
            printf("%0.8e ", get(pos));
            const bool endOfGroup = (pos % 4 == 3);
            if (endOfGroup) {
                printf("\n");
            }
            ++pos;
        }
        printf("]\n");
    }
#endif
};

ReadArray1D::~ReadArray1D() = default;
59 
/**
 * Read-only view of a two-dimensional grid of floats.
 * d1() and d2() report the two extents; get(i, j) returns one element.
 */
class ReadArray2D {
public:
    virtual ~ReadArray2D();

    // Extent of the first index.
    virtual int32_t d1() const = 0;

    // Extent of the second index.
    virtual int32_t d2() const = 0;

    // Element addressed by first index i and second index j.
    virtual float get(int32_t i, int32_t j) const = 0;
};

ReadArray2D::~ReadArray2D() = default;
74 
75 /**
76  * A class to index a float array as a 1D Array without owning the pointer or
77  * copy the data.
78  */
79 class ConstArray1D : public ReadArray1D {
80 public:
ConstArray1D()81     ConstArray1D() : data_(nullptr), d1_(0) {}
82 
ConstArray1D(const float * data,int32_t d1)83     ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {}
84 
85     virtual ~ConstArray1D();
86 
87     // Init the object, the object does not own the data nor copy.
88     // It is designed to directly use data from memory mapped resources.
init(const int32_t * data,int32_t d1)89     void init(const int32_t* data, int32_t d1) {
90         U_ASSERT(IEEE_754 == 1);
91         data_ = reinterpret_cast<const float*>(data);
92         d1_ = d1;
93     }
94 
95     // ReadArray1D methods.
d1() const96     virtual int32_t d1() const override { return d1_; }
get(int32_t i) const97     virtual float get(int32_t i) const override {
98         U_ASSERT(i < d1_);
99         return data_[i];
100     }
101 
102 private:
103     const float* data_;
104     int32_t d1_;
105 };
106 
~ConstArray1D()107 ConstArray1D::~ConstArray1D()
108 {
109 }
110 
111 /**
112  * A class to index a float array as a 2D Array without owning the pointer or
113  * copy the data.
114  */
115 class ConstArray2D : public ReadArray2D {
116 public:
ConstArray2D()117     ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {}
118 
ConstArray2D(const float * data,int32_t d1,int32_t d2)119     ConstArray2D(const float* data, int32_t d1, int32_t d2)
120         : data_(data), d1_(d1), d2_(d2) {}
121 
122     virtual ~ConstArray2D();
123 
124     // Init the object, the object does not own the data nor copy.
125     // It is designed to directly use data from memory mapped resources.
init(const int32_t * data,int32_t d1,int32_t d2)126     void init(const int32_t* data, int32_t d1, int32_t d2) {
127         U_ASSERT(IEEE_754 == 1);
128         data_ = reinterpret_cast<const float*>(data);
129         d1_ = d1;
130         d2_ = d2;
131     }
132 
133     // ReadArray2D methods.
d1() const134     inline int32_t d1() const override { return d1_; }
d2() const135     inline int32_t d2() const override { return d2_; }
get(int32_t i,int32_t j) const136     float get(int32_t i, int32_t j) const override {
137         U_ASSERT(i < d1_);
138         U_ASSERT(j < d2_);
139         return data_[i * d2_ + j];
140     }
141 
142     // Expose the ith row as a ConstArray1D
row(int32_t i) const143     inline ConstArray1D row(int32_t i) const {
144         U_ASSERT(i < d1_);
145         return ConstArray1D(data_ + i * d2_, d2_);
146     }
147 
148 private:
149     const float* data_;
150     int32_t d1_;
151     int32_t d2_;
152 };
153 
~ConstArray2D()154 ConstArray2D::~ConstArray2D()
155 {
156 }
157 
158 /**
159  * A class to allocate data as a writable 1D array.
160  * This is the main class implement matrix operation.
161  */
162 class Array1D : public ReadArray1D {
163 public:
Array1D()164     Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
Array1D(int32_t d1,UErrorCode & status)165     Array1D(int32_t d1, UErrorCode &status)
166         : memory_(uprv_malloc(d1 * sizeof(float))),
167           data_((float*)memory_), d1_(d1) {
168         if (U_SUCCESS(status)) {
169             if (memory_ == nullptr) {
170                 status = U_MEMORY_ALLOCATION_ERROR;
171                 return;
172             }
173             clear();
174         }
175     }
176 
177     virtual ~Array1D();
178 
179     // A special constructor which does not own the memory but writeable
180     // as a slice of an array.
Array1D(float * data,int32_t d1)181     Array1D(float* data, int32_t d1)
182         : memory_(nullptr), data_(data), d1_(d1) {}
183 
184     // ReadArray1D methods.
d1() const185     virtual int32_t d1() const override { return d1_; }
get(int32_t i) const186     virtual float get(int32_t i) const override {
187         U_ASSERT(i < d1_);
188         return data_[i];
189     }
190 
191     // Return the index which point to the max data in the array.
maxIndex() const192     inline int32_t maxIndex() const {
193         int32_t index = 0;
194         float max = data_[0];
195         for (int32_t i = 1; i < d1_; i++) {
196             if (data_[i] > max) {
197                 max = data_[i];
198                 index = i;
199             }
200         }
201         return index;
202     }
203 
204     // Slice part of the array to a new one.
slice(int32_t from,int32_t size) const205     inline Array1D slice(int32_t from, int32_t size) const {
206         U_ASSERT(from >= 0);
207         U_ASSERT(from < d1_);
208         U_ASSERT(from + size <= d1_);
209         return Array1D(data_ + from, size);
210     }
211 
212     // Add dot product of a 1D array and a 2D array into this one.
addDotProduct(const ReadArray1D & a,const ReadArray2D & b)213     inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) {
214         U_ASSERT(a.d1() == b.d1());
215         U_ASSERT(b.d2() == d1());
216         for (int32_t i = 0; i < d1(); i++) {
217             for (int32_t j = 0; j < a.d1(); j++) {
218                 data_[i] += a.get(j) * b.get(j, i);
219             }
220         }
221         return *this;
222     }
223 
224     // Hadamard Product the values of another array of the same size into this one.
hadamardProduct(const ReadArray1D & a)225     inline Array1D& hadamardProduct(const ReadArray1D& a) {
226         U_ASSERT(a.d1() == d1());
227         for (int32_t i = 0; i < d1(); i++) {
228             data_[i] *= a.get(i);
229         }
230         return *this;
231     }
232 
233     // Add the Hadamard Product of two arrays of the same size into this one.
addHadamardProduct(const ReadArray1D & a,const ReadArray1D & b)234     inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) {
235         U_ASSERT(a.d1() == d1());
236         U_ASSERT(b.d1() == d1());
237         for (int32_t i = 0; i < d1(); i++) {
238             data_[i] += a.get(i) * b.get(i);
239         }
240         return *this;
241     }
242 
243     // Add the values of another array of the same size into this one.
add(const ReadArray1D & a)244     inline Array1D& add(const ReadArray1D& a) {
245         U_ASSERT(a.d1() == d1());
246         for (int32_t i = 0; i < d1(); i++) {
247             data_[i] += a.get(i);
248         }
249         return *this;
250     }
251 
252     // Assign the values of another array of the same size into this one.
assign(const ReadArray1D & a)253     inline Array1D& assign(const ReadArray1D& a) {
254         U_ASSERT(a.d1() == d1());
255         for (int32_t i = 0; i < d1(); i++) {
256             data_[i] = a.get(i);
257         }
258         return *this;
259     }
260 
261     // Apply tanh to all the elements in the array.
tanh()262     inline Array1D& tanh() {
263         return tanh(*this);
264     }
265 
266     // Apply tanh of a and store into this array.
tanh(const Array1D & a)267     inline Array1D& tanh(const Array1D& a) {
268         U_ASSERT(a.d1() == d1());
269         for (int32_t i = 0; i < d1_; i++) {
270             data_[i] = std::tanh(a.get(i));
271         }
272         return *this;
273     }
274 
275     // Apply sigmoid to all the elements in the array.
sigmoid()276     inline Array1D& sigmoid() {
277         for (int32_t i = 0; i < d1_; i++) {
278             data_[i] = 1.0f/(1.0f + expf(-data_[i]));
279         }
280         return *this;
281     }
282 
clear()283     inline Array1D& clear() {
284         uprv_memset(data_, 0, d1_ * sizeof(float));
285         return *this;
286     }
287 
288 private:
289     void* memory_;
290     float* data_;
291     int32_t d1_;
292 };
293 
~Array1D()294 Array1D::~Array1D()
295 {
296     uprv_free(memory_);
297 }
298 
299 class Array2D : public ReadArray2D {
300 public:
Array2D()301     Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
Array2D(int32_t d1,int32_t d2,UErrorCode & status)302     Array2D(int32_t d1, int32_t d2, UErrorCode &status)
303         : memory_(uprv_malloc(d1 * d2 * sizeof(float))),
304           data_((float*)memory_), d1_(d1), d2_(d2) {
305         if (U_SUCCESS(status)) {
306             if (memory_ == nullptr) {
307                 status = U_MEMORY_ALLOCATION_ERROR;
308                 return;
309             }
310             clear();
311         }
312     }
313     virtual ~Array2D();
314 
315     // ReadArray2D methods.
d1() const316     virtual int32_t d1() const override { return d1_; }
d2() const317     virtual int32_t d2() const override { return d2_; }
get(int32_t i,int32_t j) const318     virtual float get(int32_t i, int32_t j) const override {
319         U_ASSERT(i < d1_);
320         U_ASSERT(j < d2_);
321         return data_[i * d2_ + j];
322     }
323 
row(int32_t i) const324     inline Array1D row(int32_t i) const {
325         U_ASSERT(i < d1_);
326         return Array1D(data_ + i * d2_, d2_);
327     }
328 
clear()329     inline Array2D& clear() {
330         uprv_memset(data_, 0, d1_ * d2_ * sizeof(float));
331         return *this;
332     }
333 
334 private:
335     void* memory_;
336     float* data_;
337     int32_t d1_;
338     int32_t d2_;
339 };
340 
~Array2D()341 Array2D::~Array2D()
342 {
343     uprv_free(memory_);
344 }
345 
// Output classes of the model. The numeric values are the positions in the
// 4-element output vector whose maximum selects the class.
enum LSTMClass {
    BEGIN = 0,
    INSIDE = 1,
    END = 2,
    SINGLE = 3
};
352 
// How the input text is cut into units before each unit is looked up in the
// model dictionary.
enum EmbeddingType {
    UNKNOWN = 0,
    CODE_POINTS = 1,
    GRAPHEME_CLUSTER = 2,
};
358 
359 struct LSTMData : public UMemory {
360     LSTMData(UResourceBundle* rb, UErrorCode &status);
361     ~LSTMData();
362     UHashtable* fDict;
363     EmbeddingType fType;
364     const UChar* fName;
365     ConstArray2D fEmbedding;
366     ConstArray2D fForwardW;
367     ConstArray2D fForwardU;
368     ConstArray1D fForwardB;
369     ConstArray2D fBackwardW;
370     ConstArray2D fBackwardU;
371     ConstArray1D fBackwardB;
372     ConstArray2D fOutputW;
373     ConstArray1D fOutputB;
374 
375 private:
376     UResourceBundle* fBundle;
377 };
378 
LSTMData(UResourceBundle * rb,UErrorCode & status)379 LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
380     : fDict(nullptr), fType(UNKNOWN), fName(nullptr),
381       fBundle(rb)
382 {
383     if (U_FAILURE(status)) {
384         return;
385     }
386     if (IEEE_754 != 1) {
387         status = U_UNSUPPORTED_ERROR;
388         return;
389     }
390     LocalUResourceBundlePointer embeddings_res(
391         ures_getByKey(rb, "embeddings", nullptr, &status));
392     int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
393     LocalUResourceBundlePointer hunits_res(
394         ures_getByKey(rb, "hunits", nullptr, &status));
395     if (U_FAILURE(status)) return;
396     int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
397     const UChar* type = ures_getStringByKey(rb, "type", nullptr, &status);
398     if (U_FAILURE(status)) return;
399     if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
400         fType = CODE_POINTS;
401     } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
402         fType = GRAPHEME_CLUSTER;
403     }
404     fName = ures_getStringByKey(rb, "model", nullptr, &status);
405     LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));
406     if (U_FAILURE(status)) return;
407     int32_t data_len = 0;
408     const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);
409     fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);
410 
411     StackUResourceBundle stackTempBundle;
412     ResourceDataValue value;
413     ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);
414     ResourceArray stringArray = value.getArray(status);
415     int32_t num_index = stringArray.getSize();
416     if (U_FAILURE(status)) { return; }
417 
418     // put dict into hash
419     int32_t stringLength;
420     for (int32_t idx = 0; idx < num_index; idx++) {
421         stringArray.getValue(idx, value);
422         const UChar* str = value.getString(stringLength, status);
423         uhash_putiAllowZero(fDict, (void*)str, idx, &status);
424         if (U_FAILURE(status)) return;
425 #ifdef LSTM_VECTORIZER_DEBUG
426         printf("Assign [");
427         while (*str != 0x0000) {
428             printf("U+%04x ", *str);
429             str++;
430         }
431         printf("] map to %d\n", idx-1);
432 #endif
433     }
434     int32_t mat1_size = (num_index + 1) * embedding_size;
435     int32_t mat2_size = embedding_size * 4 * hunits;
436     int32_t mat3_size = hunits * 4 * hunits;
437     int32_t mat4_size = 4 * hunits;
438     int32_t mat5_size = mat2_size;
439     int32_t mat6_size = mat3_size;
440     int32_t mat7_size = mat4_size;
441     int32_t mat8_size = 2 * hunits * 4;
442 #if U_DEBUG
443     int32_t mat9_size = 4;
444     U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size +
445         mat6_size + mat7_size + mat8_size + mat9_size);
446 #endif
447 
448     fEmbedding.init(data, (num_index + 1), embedding_size);
449     data += mat1_size;
450     fForwardW.init(data, embedding_size, 4 * hunits);
451     data += mat2_size;
452     fForwardU.init(data, hunits, 4 * hunits);
453     data += mat3_size;
454     fForwardB.init(data, 4 * hunits);
455     data += mat4_size;
456     fBackwardW.init(data, embedding_size, 4 * hunits);
457     data += mat5_size;
458     fBackwardU.init(data, hunits, 4 * hunits);
459     data += mat6_size;
460     fBackwardB.init(data, 4 * hunits);
461     data += mat7_size;
462     fOutputW.init(data, 2 * hunits, 4);
463     data += mat8_size;
464     fOutputB.init(data, 4);
465 }
466 
~LSTMData()467 LSTMData::~LSTMData() {
468     uhash_close(fDict);
469     ures_close(fBundle);
470 }
471 
472 class Vectorizer : public UMemory {
473 public:
Vectorizer(UHashtable * dict)474     Vectorizer(UHashtable* dict) : fDict(dict) {}
475     virtual ~Vectorizer();
476     virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
477                            UVector32 &offsets, UVector32 &indices,
478                            UErrorCode &status) const = 0;
479 protected:
stringToIndex(const UChar * str) const480     int32_t stringToIndex(const UChar* str) const {
481         UBool found = false;
482         int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);
483         if (!found) {
484             ret = fDict->count;
485         }
486 #ifdef LSTM_VECTORIZER_DEBUG
487         printf("[");
488         while (*str != 0x0000) {
489             printf("U+%04x ", *str);
490             str++;
491         }
492         printf("] map to %d\n", ret);
493 #endif
494         return ret;
495     }
496 
497 private:
498     UHashtable* fDict;
499 };
500 
~Vectorizer()501 Vectorizer::~Vectorizer()
502 {
503 }
504 
505 class CodePointsVectorizer : public Vectorizer {
506 public:
CodePointsVectorizer(UHashtable * dict)507     CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {}
508     virtual ~CodePointsVectorizer();
509     virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
510                            UVector32 &offsets, UVector32 &indices,
511                            UErrorCode &status) const override;
512 };
513 
~CodePointsVectorizer()514 CodePointsVectorizer::~CodePointsVectorizer()
515 {
516 }
517 
vectorize(UText * text,int32_t startPos,int32_t endPos,UVector32 & offsets,UVector32 & indices,UErrorCode & status) const518 void CodePointsVectorizer::vectorize(
519     UText *text, int32_t startPos, int32_t endPos,
520     UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
521 {
522     if (offsets.ensureCapacity(endPos - startPos, status) &&
523             indices.ensureCapacity(endPos - startPos, status)) {
524         if (U_FAILURE(status)) return;
525         utext_setNativeIndex(text, startPos);
526         int32_t current;
527         UChar str[2] = {0, 0};
528         while (U_SUCCESS(status) &&
529                (current = (int32_t)utext_getNativeIndex(text)) < endPos) {
530             // Since the LSTMBreakEngine is currently only accept chars in BMP,
531             // we can ignore the possibility of hitting supplementary code
532             // point.
533             str[0] = (UChar) utext_next32(text);
534             U_ASSERT(!U_IS_SURROGATE(str[0]));
535             offsets.addElement(current, status);
536             indices.addElement(stringToIndex(str), status);
537         }
538     }
539 }
540 
541 class GraphemeClusterVectorizer : public Vectorizer {
542 public:
GraphemeClusterVectorizer(UHashtable * dict)543     GraphemeClusterVectorizer(UHashtable* dict)
544         : Vectorizer(dict)
545     {
546     }
547     virtual ~GraphemeClusterVectorizer();
548     virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
549                            UVector32 &offsets, UVector32 &indices,
550                            UErrorCode &status) const override;
551 };
552 
~GraphemeClusterVectorizer()553 GraphemeClusterVectorizer::~GraphemeClusterVectorizer()
554 {
555 }
556 
// Capacity, in UChars, handed to utext_extract() when a grapheme cluster is
// copied out of the text for the dictionary lookup.
static constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;
558 
vectorize(UText * text,int32_t startPos,int32_t endPos,UVector32 & offsets,UVector32 & indices,UErrorCode & status) const559 void GraphemeClusterVectorizer::vectorize(
560     UText *text, int32_t startPos, int32_t endPos,
561     UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
562 {
563     if (U_FAILURE(status)) return;
564     if (!offsets.ensureCapacity(endPos - startPos, status) ||
565             !indices.ensureCapacity(endPos - startPos, status)) {
566         return;
567     }
568     if (U_FAILURE(status)) return;
569     LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
570     if (U_FAILURE(status)) return;
571     graphemeIter->setText(text, status);
572     if (U_FAILURE(status)) return;
573 
574     if (startPos != 0) {
575         graphemeIter->preceding(startPos);
576     }
577     int32_t last = startPos;
578     int32_t current = startPos;
579     UChar str[MAX_GRAPHEME_CLSTER_LENGTH];
580     while ((current = graphemeIter->next()) != BreakIterator::DONE) {
581         if (current >= endPos) {
582             break;
583         }
584         if (current > startPos) {
585             utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
586             if (U_FAILURE(status)) return;
587             offsets.addElement(last, status);
588             indices.addElement(stringToIndex(str), status);
589             if (U_FAILURE(status)) return;
590         }
591         last = current;
592     }
593     if (U_FAILURE(status) || last >= endPos) {
594         return;
595     }
596     utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
597     if (U_SUCCESS(status)) {
598         offsets.addElement(last, status);
599         indices.addElement(stringToIndex(str), status);
600     }
601 }
602 
603 // Computing LSTM as stated in
604 // https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
605 // ifco is temp array allocate outside which does not need to be
606 // input/output value but could avoid unnecessary memory alloc/free if passing
607 // in.
compute(int32_t hunits,const ReadArray2D & W,const ReadArray2D & U,const ReadArray1D & b,const ReadArray1D & x,Array1D & h,Array1D & c,Array1D & ifco)608 void compute(
609     int32_t hunits,
610     const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
611     const ReadArray1D& x, Array1D& h, Array1D& c,
612     Array1D& ifco)
613 {
614     // ifco = x * W + h * U + b
615     ifco.assign(b)
616         .addDotProduct(x, W)
617         .addDotProduct(h, U);
618 
619     ifco.slice(0*hunits, hunits).sigmoid();  // i: sigmod
620     ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid
621     ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
622     ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
623 
624     c.hadamardProduct(ifco.slice(hunits, hunits))
625         .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));
626 
627     h.tanh(c)
628         .hadamardProduct(ifco.slice(3*hunits, hunits));
629 }
630 
// Smallest size of a word.
static constexpr int32_t MIN_WORD = 2;

// Smallest number of characters that can hold two words.
static constexpr int32_t MIN_WORD_SPAN = 2 * MIN_WORD;
636 
637 int32_t
divideUpDictionaryRange(UText * text,int32_t startPos,int32_t endPos,UVector32 & foundBreaks,UErrorCode & status) const638 LSTMBreakEngine::divideUpDictionaryRange( UText *text,
639                                                 int32_t startPos,
640                                                 int32_t endPos,
641                                                 UVector32 &foundBreaks,
642                                                 UErrorCode& status) const {
643     if (U_FAILURE(status)) return 0;
644     int32_t beginFoundBreakSize = foundBreaks.size();
645     utext_setNativeIndex(text, startPos);
646     utext_moveIndex32(text, MIN_WORD_SPAN);
647     if (utext_getNativeIndex(text) >= endPos) {
648         return 0;       // Not enough characters for two words
649     }
650     utext_setNativeIndex(text, startPos);
651 
652     UVector32 offsets(status);
653     UVector32 indices(status);
654     if (U_FAILURE(status)) return 0;
655     fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
656     if (U_FAILURE(status)) return 0;
657     int32_t* offsetsBuf = offsets.getBuffer();
658     int32_t* indicesBuf = indices.getBuffer();
659 
660     int32_t input_seq_len = indices.size();
661     int32_t hunits = fData->fForwardU.d1();
662 
663     // ----- Begin of all the Array memory allocation needed for this function
664     // Allocate temp array used inside compute()
665     Array1D ifco(4 * hunits, status);
666 
667     Array1D c(hunits, status);
668     Array1D logp(4, status);
669 
670     // TODO: limit size of hBackward. If input_seq_len is too big, we could
671     // run out of memory.
672     // Backward LSTM
673     Array2D hBackward(input_seq_len, hunits, status);
674 
675     // Allocate fbRow and slice the internal array in two.
676     Array1D fbRow(2 * hunits, status);
677 
678     // ----- End of all the Array memory allocation needed for this function
679     if (U_FAILURE(status)) return 0;
680 
681     // To save the needed memory usage, the following is different from the
682     // Python or ICU4X implementation. We first perform the Backward LSTM
683     // and then merge the iteration of the forward LSTM and the output layer
684     // together because we only neetdto remember the h[t-1] for Forward LSTM.
685     for (int32_t i = input_seq_len - 1; i >= 0; i--) {
686         Array1D hRow = hBackward.row(i);
687         if (i != input_seq_len - 1) {
688             hRow.assign(hBackward.row(i+1));
689         }
690 #ifdef LSTM_DEBUG
691         printf("hRow %d\n", i);
692         hRow.print();
693         printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]);
694         printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i);
695         fData->fEmbedding.row(indicesBuf[i]).print();
696 #endif  // LSTM_DEBUG
697         compute(hunits,
698                 fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
699                 fData->fEmbedding.row(indicesBuf[i]),
700                 hRow, c, ifco);
701     }
702 
703 
704     Array1D forwardRow = fbRow.slice(0, hunits);  // point to first half of data in fbRow.
705     Array1D backwardRow = fbRow.slice(hunits, hunits);  // point to second half of data n fbRow.
706 
707     // The following iteration merge the forward LSTM and the output layer
708     // together.
709     c.clear();  // reuse c since it is the same size.
710     for (int32_t i = 0; i < input_seq_len; i++) {
711 #ifdef LSTM_DEBUG
712         printf("forwardRow %d\n", i);
713         forwardRow.print();
714 #endif  // LSTM_DEBUG
715         // Forward LSTM
716         // Calculate the result into forwardRow, which point to the data in the first half
717         // of fbRow.
718         compute(hunits,
719                 fData->fForwardW, fData->fForwardU, fData->fForwardB,
720                 fData->fEmbedding.row(indicesBuf[i]),
721                 forwardRow, c, ifco);
722 
723         // assign the data from hBackward.row(i) to second half of fbRowa.
724         backwardRow.assign(hBackward.row(i));
725 
726         logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW);
727 #ifdef LSTM_DEBUG
728         printf("backwardRow %d\n", i);
729         backwardRow.print();
730         printf("logp %d\n", i);
731         logp.print();
732 #endif  // LSTM_DEBUG
733 
734         // current = argmax(logp)
735         LSTMClass current = (LSTMClass)logp.maxIndex();
736         // BIES logic.
737         if (current == BEGIN || current == SINGLE) {
738             if (i != 0) {
739                 foundBreaks.addElement(offsetsBuf[i], status);
740                 if (U_FAILURE(status)) return 0;
741             }
742         }
743     }
744     return foundBreaks.size() - beginFoundBreakSize;
745 }
746 
createVectorizer(const LSTMData * data,UErrorCode & status)747 Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) {
748     if (U_FAILURE(status)) {
749         return nullptr;
750     }
751     switch (data->fType) {
752         case CODE_POINTS:
753             return new CodePointsVectorizer(data->fDict);
754             break;
755         case GRAPHEME_CLUSTER:
756             return new GraphemeClusterVectorizer(data->fDict);
757             break;
758         default:
759             break;
760     }
761     UPRV_UNREACHABLE_EXIT;
762 }
763 
LSTMBreakEngine(const LSTMData * data,const UnicodeSet & set,UErrorCode & status)764 LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)
765     : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))
766 {
767     if (U_FAILURE(status)) {
768       fData = nullptr;  // If failure, we should not delete fData in destructor because the caller will do so.
769       return;
770     }
771     setCharacters(set);
772 }
773 
~LSTMBreakEngine()774 LSTMBreakEngine::~LSTMBreakEngine() {
775     delete fData;
776     delete fVectorizer;
777 }
778 
name() const779 const UChar* LSTMBreakEngine::name() const {
780     return fData->fName;
781 }
782 
defaultLSTM(UScriptCode script,UErrorCode & status)783 UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) {
784     // open root from brkitr tree.
785     UResourceBundle *b = ures_open(U_ICUDATA_BRKITR, "", &status);
786     b = ures_getByKeyWithFallback(b, "lstm", b, &status);
787     UnicodeString result = ures_getUnicodeStringByKey(b, uscript_getShortName(script), &status);
788     ures_close(b);
789     return result;
790 }
791 
CreateLSTMDataForScript(UScriptCode script,UErrorCode & status)792 U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)
793 {
794     if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) {
795         return nullptr;
796     }
797     UnicodeString name = defaultLSTM(script, status);
798     if (U_FAILURE(status)) return nullptr;
799     CharString namebuf;
800     namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.'));
801 
802     LocalUResourceBundlePointer rb(
803         ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
804     if (U_FAILURE(status)) return nullptr;
805 
806     return CreateLSTMData(rb.orphan(), status);
807 }
808 
CreateLSTMData(UResourceBundle * rb,UErrorCode & status)809 U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)
810 {
811     return new LSTMData(rb, status);
812 }
813 
814 U_CAPI const LanguageBreakEngine* U_EXPORT2
CreateLSTMBreakEngine(UScriptCode script,const LSTMData * data,UErrorCode & status)815 CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)
816 {
817     UnicodeString unicodeSetString;
818     switch(script) {
819         case USCRIPT_THAI:
820             unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]");
821             break;
822         case USCRIPT_MYANMAR:
823             unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]");
824             break;
825         default:
826             delete data;
827             return nullptr;
828     }
829     UnicodeSet unicodeSet;
830     unicodeSet.applyPattern(unicodeSetString, status);
831     const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);
832     if (U_FAILURE(status) || engine == nullptr) {
833         if (engine != nullptr) {
834             delete engine;
835         } else {
836             status = U_MEMORY_ALLOCATION_ERROR;
837         }
838         return nullptr;
839     }
840     return engine;
841 }
842 
DeleteLSTMData(const LSTMData * data)843 U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data)
844 {
845     delete data;
846 }
847 
LSTMDataName(const LSTMData * data)848 U_CAPI const UChar* U_EXPORT2 LSTMDataName(const LSTMData* data)
849 {
850     return data->fName;
851 }
852 
853 U_NAMESPACE_END
854 
855 #endif /* #if !UCONFIG_NO_BREAK_ITERATION */
856