// © 2021 and later: Unicode, Inc. and others.
// License & terms of use: http://www.unicode.org/copyright.html
//
/**
 * An LSTMBreakEngine
 */
7 package com.ibm.icu.impl.breakiter;
8 
9 import java.nio.ByteBuffer;
10 import java.nio.ByteOrder;
11 import java.text.CharacterIterator;
12 import java.util.ArrayList;
13 import java.util.Arrays;
14 import java.util.HashMap;
15 import java.util.List;
16 import java.util.Map;
17 
18 import com.ibm.icu.impl.ICUData;
19 import com.ibm.icu.impl.ICUResourceBundle;
20 import com.ibm.icu.lang.UCharacter;
21 import com.ibm.icu.lang.UProperty;
22 import com.ibm.icu.lang.UScript;
23 import com.ibm.icu.text.BreakIterator;
24 import com.ibm.icu.text.UnicodeSet;
25 import com.ibm.icu.util.UResourceBundle;
26 
27 /**
28  * @internal
29  */
30 public class LSTMBreakEngine extends DictionaryBreakEngine {
31     public enum EmbeddingType {
32       UNKNOWN,
33       CODE_POINTS,
34       GRAPHEME_CLUSTER
35     }
36 
37     public enum LSTMClass {
38       BEGIN,
39       INSIDE,
40       END,
41       SINGLE,
42     }
43 
make2DArray(int[] data, int start, int d1, int d2)44     private static float[][] make2DArray(int[] data, int start, int d1, int d2) {
45         byte[] bytes = new byte[4];
46         float [][] result = new float[d1][d2];
47         for (int i = 0; i < d1 ; i++) {
48             for (int j = 0; j < d2 ; j++) {
49                 int d = data[start++];
50                 bytes[0] = (byte) (d >> 24);
51                 bytes[1] = (byte) (d >> 16);
52                 bytes[2] = (byte) (d >> 8);
53                 bytes[3] = (byte) (d /*>> 0*/);
54                 result[i][j] = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).getFloat();
55             }
56         }
57         return result;
58     }
59 
make1DArray(int[] data, int start, int d1)60     private static float[] make1DArray(int[] data, int start, int d1) {
61         byte[] bytes = new byte[4];
62         float [] result = new float[d1];
63         for (int i = 0; i < d1 ; i++) {
64             int d = data[start++];
65             bytes[0] = (byte) (d >> 24);
66             bytes[1] = (byte) (d >> 16);
67             bytes[2] = (byte) (d >> 8);
68             bytes[3] = (byte) (d /*>> 0*/);
69             result[i] = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).getFloat();
70         }
71         return result;
72     }
73 
74     /** @internal */
75     public static class LSTMData {
LSTMData()76         private LSTMData() {
77         }
78 
LSTMData(UResourceBundle rb)79         public LSTMData(UResourceBundle rb) {
80             int embeddings = rb.get("embeddings").getInt();
81             int hunits = rb.get("hunits").getInt();
82             this.fType = EmbeddingType.UNKNOWN;
83             this.fName = rb.get("model").getString();
84             String typeString = rb.get("type").getString();
85             if (typeString.equals("codepoints")) {
86                 this.fType = EmbeddingType.CODE_POINTS;
87             } else if (typeString.equals("graphclust")) {
88                 this.fType = EmbeddingType.GRAPHEME_CLUSTER;
89             }
90             String[] dict = rb.get("dict").getStringArray();
91             int[] data = rb.get("data").getIntVector();
92             int dataLen = data.length;
93             int numIndex = dict.length;
94             fDict = new HashMap<String, Integer>(numIndex + 1);
95             int idx = 0;
96             for (String embedding : dict){
97                 fDict.put(embedding, idx++);
98             }
99             int mat1Size = (numIndex + 1) * embeddings;
100             int mat2Size = embeddings * 4 * hunits;
101             int mat3Size = hunits * 4 * hunits;
102             int mat4Size = 4 * hunits;
103             int mat5Size = mat2Size;
104             int mat6Size = mat3Size;
105             int mat7Size = mat4Size;
106             int mat8Size = 2 * hunits * 4;
107             int mat9Size = 4;
108             assert dataLen == mat1Size + mat2Size + mat3Size + mat4Size + mat5Size + mat6Size + mat7Size + mat8Size + mat9Size;
109             int start = 0;
110             this.fEmbedding = make2DArray(data, start, (numIndex+1), embeddings);
111             start += mat1Size;
112             this.fForwardW = make2DArray(data, start, embeddings, 4 * hunits);
113             start += mat2Size;
114             this.fForwardU = make2DArray(data, start, hunits, 4 * hunits);
115             start += mat3Size;
116             this.fForwardB = make1DArray(data, start, 4 * hunits);
117             start += mat4Size;
118             this.fBackwardW = make2DArray(data, start, embeddings, 4 * hunits);
119             start += mat5Size;
120             this.fBackwardU = make2DArray(data, start, hunits, 4 * hunits);
121             start += mat6Size;
122             this.fBackwardB = make1DArray(data, start, 4 * hunits);
123             start += mat7Size;
124             this.fOutputW = make2DArray(data, start, 2 * hunits, 4);
125             start += mat8Size;
126             this.fOutputB = make1DArray(data, start, 4);
127         }
128 
129         public EmbeddingType fType;
130         public String fName;
131         public Map<String, Integer> fDict;
132         public float fEmbedding[][];
133         public float fForwardW[][];
134         public float fForwardU[][];
135         public float fForwardB[];
136         public float fBackwardW[][];
137         public float fBackwardU[][];
138         public float fBackwardB[];
139         public float fOutputW[][];
140         public float fOutputB[];
141     }
142 
143     // Minimum word size
144     private static final byte MIN_WORD = 2;
145 
146     // Minimum number of characters for two words
147     private static final byte MIN_WORD_SPAN = MIN_WORD * 2;
148 
149     abstract class Vectorizer {
Vectorizer(Map<String, Integer> dict)150         public Vectorizer(Map<String, Integer> dict) {
151             this.fDict = dict;
152         }
vectorize(CharacterIterator fIter, int rangeStart, int rangeEnd, List<Integer> offsets, List<Integer> indicies)153         abstract public void vectorize(CharacterIterator fIter, int rangeStart, int rangeEnd,
154                               List<Integer> offsets, List<Integer> indicies);
getIndex(String token)155         protected int getIndex(String token) {
156             Integer res = fDict.get(token);
157             return (res == null) ? fDict.size() : res;
158         }
159         private Map<String, Integer> fDict;
160     }
161 
162     class CodePointsVectorizer extends Vectorizer {
CodePointsVectorizer(Map<String, Integer> dict)163         public CodePointsVectorizer(Map<String, Integer> dict) {
164             super(dict);
165         }
166 
vectorize(CharacterIterator fIter, int rangeStart, int rangeEnd, List<Integer> offsets, List<Integer> indicies)167         public void vectorize(CharacterIterator fIter, int rangeStart, int rangeEnd,
168                               List<Integer> offsets, List<Integer> indicies) {
169             fIter.setIndex(rangeStart);
170             for (char c = fIter.current();
171                  c != CharacterIterator.DONE && fIter.getIndex() < rangeEnd;
172                  c = fIter.next()) {
173                 offsets.add(fIter.getIndex());
174                 indicies.add(getIndex(String.valueOf(c)));
175             }
176         }
177     }
178 
179     class GraphemeClusterVectorizer extends Vectorizer {
GraphemeClusterVectorizer(Map<String, Integer> dict)180         public GraphemeClusterVectorizer(Map<String, Integer> dict) {
181             super(dict);
182         }
183 
substring(CharacterIterator text, int startPos, int endPos)184         private String substring(CharacterIterator text, int startPos, int endPos) {
185             int saved = text.getIndex();
186             text.setIndex(startPos);
187             StringBuilder sb = new StringBuilder();
188             for (char c = text.current();
189                  c != CharacterIterator.DONE && text.getIndex() < endPos;
190                  c = text.next()) {
191                 sb.append(c);
192             }
193             text.setIndex(saved);
194             return sb.toString();
195         }
196 
vectorize(CharacterIterator text, int startPos, int endPos, List<Integer> offsets, List<Integer> indicies)197         public void vectorize(CharacterIterator text, int startPos, int endPos,
198                               List<Integer> offsets, List<Integer> indicies) {
199             BreakIterator iter = BreakIterator.getCharacterInstance();
200             iter.setText(text);
201             int last = iter.next(startPos);
202             for (int curr = iter.next(); curr != BreakIterator.DONE && curr <= endPos; curr = iter.next()) {
203                 offsets.add(last);
204                 String segment = substring(text, last, curr);
205                 int index = getIndex(segment);
206                 indicies.add(index);
207                 last = curr;
208             }
209         }
210     }
211 
212     private final LSTMData fData;
213     private int fScript;
214     private final Vectorizer fVectorizer;
215 
makeVectorizer(LSTMData data)216     private Vectorizer makeVectorizer(LSTMData data) {
217         switch(data.fType) {
218             case CODE_POINTS:
219                 return new CodePointsVectorizer(data.fDict);
220             case GRAPHEME_CLUSTER:
221                 return new GraphemeClusterVectorizer(data.fDict);
222             default:
223                 return null;
224         }
225     }
226 
LSTMBreakEngine(int script, UnicodeSet set, LSTMData data)227     public LSTMBreakEngine(int script, UnicodeSet set, LSTMData data) {
228         setCharacters(set);
229         this.fScript = script;
230         this.fData = data;
231         this.fVectorizer = makeVectorizer(this.fData);
232     }
233 
234     @Override
hashCode()235     public int hashCode() {
236         return getClass().hashCode();
237     }
238 
239     @Override
handles(int c)240     public boolean handles(int c) {
241         return fScript == UCharacter.getIntPropertyValue(c, UProperty.SCRIPT);
242     }
243 
addDotProductTo(final float [] a, final float[][] b, float[] result)244     static private void addDotProductTo(final float [] a, final float[][] b, float[] result) {
245         assert a.length == b.length;
246         assert b[0].length == result.length;
247         for (int i = 0; i < result.length; i++) {
248             for (int j = 0; j < a.length; j++) {
249                 result[i] += a[j] * b[j][i];
250             }
251         }
252     }
253 
addTo(final float [] a, float[] result)254     static private void addTo(final float [] a, float[] result) {
255         assert a.length == result.length;
256         for (int i = 0; i < result.length; i++) {
257             result[i] += a[i];
258         }
259     }
260 
hadamardProductTo(final float [] a, float[] result)261     static private void hadamardProductTo(final float [] a, float[] result) {
262         assert a.length == result.length;
263         for (int i = 0; i < result.length; i++) {
264             result[i] *= a[i];
265         }
266     }
267 
addHadamardProductTo(final float [] a, final float [] b, float[] result)268     static private void addHadamardProductTo(final float [] a, final float [] b, float[] result) {
269         assert a.length == result.length;
270         assert b.length == result.length;
271         for (int i = 0; i < result.length; i++) {
272             result[i] += a[i] * b[i];
273         }
274     }
275 
sigmoid(float [] result, int start, int length)276     static private void sigmoid(float [] result, int start, int length) {
277         assert start < result.length;
278         assert start + length <= result.length;
279         for (int i = start; i < start + length; i++) {
280             result[i] = (float)(1.0/(1.0 + Math.exp(-result[i])));
281         }
282     }
283 
284     static private void tanh(float [] result, int start, int length) {
285         assert start < result.length;
286         assert start + length <= result.length;
287         for (int i = start; i < start + length; i++) {
288             result[i] = (float)Math.tanh(result[i]);
289         }
290     }
291 
292     static private int maxIndex(float [] data) {
293         int index = 0;
294         float max = data[0];
295         for (int i = 1; i < data.length; i++) {
296             if (data[i] > max) {
297                 max = data[i];
298                 index = i;
299             }
300         }
301         return index;
302     }
303 
304     /*
305     static private void print(float [] data) {
306         for (int i=0; i < data.length; i++) {
307             System.out.format("  %e", data[i]);
308             if (i % 4 == 3) {
309               System.out.println();
310             }
311         }
312         System.out.println();
313     }
314     */
315 
316     private float[] compute(final float[][] W, final float[][] U, final float[] B,
317                             final float[] x, float[] h, float[] c) {
318         // ifco = x * W + h * U + b
319         float[] ifco = Arrays.copyOf(B, B.length);
320         addDotProductTo(x, W, ifco);
321         float[] hU = new float[B.length];
322         addDotProductTo(h, U, ifco);
323 
324         int hunits = B.length / 4;
325         sigmoid(ifco, 0*hunits, hunits);  // i
326         sigmoid(ifco, 1*hunits, hunits);  // f
327         tanh(ifco, 2*hunits, hunits);  // c_
328         sigmoid(ifco, 3*hunits, hunits);  // o
329 
330         hadamardProductTo(Arrays.copyOfRange(ifco, hunits, 2*hunits), c);
331         addHadamardProductTo(Arrays.copyOf(ifco, hunits),
332             Arrays.copyOfRange(ifco, 2*hunits, 3*hunits), c);
333 
334         h = Arrays.copyOf(c, c.length);
335         tanh(h, 0, h.length);
336         hadamardProductTo(Arrays.copyOfRange(ifco, 3*hunits, 4*hunits), h);
337         // System.out.println("c");
338         // print(c);
339         // System.out.println("h");
340         // print(h);
341         return h;
342     }
343 
344     @Override
345     public int divideUpDictionaryRange(CharacterIterator fIter, int rangeStart, int rangeEnd,
346             DequeI foundBreaks, boolean isPhraseBreaking) {
347         int beginSize = foundBreaks.size();
348 
349         if ((rangeEnd - rangeStart) < MIN_WORD_SPAN) {
350             return 0;  // Not enough characters for word
351         }
352         List<Integer> offsets = new ArrayList<Integer>(rangeEnd - rangeStart);
353         List<Integer> indicies = new ArrayList<Integer>(rangeEnd - rangeStart);
354 
355         fVectorizer.vectorize(fIter, rangeStart, rangeEnd, offsets, indicies);
356 
357         // To save the needed memory usage, the following is different from the
358         // Python or ICU4X implementation. We first perform the Backward LSTM
359         // and then merge the iteration of the forward LSTM and the output layer
360         // together because we only need to remember the h[t-1] for Forward LSTM.
361         int inputSeqLength = indicies.size();
362         int hunits = this.fData.fForwardU.length;
363         float c[] = new float[hunits];
364 
365         // TODO: limit size of hBackward. If input_seq_len is too big, we could
366         // run out of memory.
367         // Backward LSTM
368         float hBackward[][] = new float[inputSeqLength][hunits];
369         for (int i = inputSeqLength - 1; i >= 0;  i--) {
370             if (i != inputSeqLength - 1) {
371                 hBackward[i] = Arrays.copyOf(hBackward[i+1], hunits);
372             }
373             // System.out.println("Backward LSTM " + i);
374             hBackward[i] = compute(this.fData.fBackwardW, this.fData.fBackwardU, this.fData.fBackwardB,
375                     this.fData.fEmbedding[indicies.get(i)],
376                     hBackward[i], c);
377         }
378 
379         c = new float[hunits];
380         float forwardH[] = new float[hunits];
381         float both[] = new float[2*hunits];
382 
383         // The following iteration merge the forward LSTM and the output layer
384         // together.
385         for (int i = 0 ; i < inputSeqLength; i++) {
386             // Forward LSTM
387             forwardH = compute(this.fData.fForwardW, this.fData.fForwardU, this.fData.fForwardB,
388                     this.fData.fEmbedding[indicies.get(i)],
389                     forwardH, c);
390 
391             System.arraycopy(forwardH, 0, both, 0, hunits);
392             System.arraycopy(hBackward[i], 0, both, hunits, hunits);
393 
394             //System.out.println("Merged " + i);
395             //print(both);
396 
397             // Output layer
398             // logp = fbRow * fOutputW + fOutputB
399             float logp[] = Arrays.copyOf(this.fData.fOutputB, this.fData.fOutputB.length);
400             addDotProductTo(both, this.fData.fOutputW, logp);
401 
402             int current = maxIndex(logp);
403 
404             // BIES logic.
405             if (current == LSTMClass.BEGIN.ordinal() ||
406                 current == LSTMClass.SINGLE.ordinal()) {
407                 if (i != 0) {
408                     foundBreaks.push(offsets.get(i));
409                 }
410             }
411         }
412 
413         return foundBreaks.size() - beginSize;
414     }
415 
416     public static LSTMData createData(UResourceBundle bundle) {
417         return new LSTMData(bundle);
418     }
419 
420     private static String defaultLSTM(int script) {
421         ICUResourceBundle rb = (ICUResourceBundle)UResourceBundle.getBundleInstance(ICUData.ICU_BRKITR_BASE_NAME);
422         return rb.getStringWithFallback("lstm/" + UScript.getShortName(script));
423     }
424 
425     public static LSTMData createData(int script) {
426         if (script != UScript.KHMER && script != UScript.LAO && script != UScript.MYANMAR && script != UScript.THAI) {
427             return null;
428         }
429         String name = defaultLSTM(script);
430         name = name.substring(0, name.indexOf("."));
431 
432         UResourceBundle rb = UResourceBundle.getBundleInstance(
433              ICUData.ICU_BRKITR_BASE_NAME, name,
434              ICUResourceBundle.ICU_DATA_CLASS_LOADER);
435         return createData(rb);
436     }
437 
438     public static LSTMBreakEngine create(int script, LSTMData data) {
439         String setExpr = "[[:" + UScript.getShortName(script) + ":]&[:LineBreak=SA:]]";
440         UnicodeSet set = new UnicodeSet();
441         set.applyPattern(setExpr);
442         set.compact();
443         return new LSTMBreakEngine(script, set, data);
444     }
445 }
446