// © 2021 and later: Unicode, Inc. and others.
// License & terms of use: http://www.unicode.org/copyright.html
//
/**
 * A LSTMBreakEngine
 */
package com.ibm.icu.impl.breakiter;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.text.CharacterIterator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.ibm.icu.impl.ICUData;
import com.ibm.icu.impl.ICUResourceBundle;
import com.ibm.icu.lang.UCharacter;
import com.ibm.icu.lang.UProperty;
import com.ibm.icu.lang.UScript;
import com.ibm.icu.text.BreakIterator;
import com.ibm.icu.text.UnicodeSet;
import com.ibm.icu.util.UResourceBundle;

/**
 * @internal
 */
public class LSTMBreakEngine extends DictionaryBreakEngine {
    public enum EmbeddingType {
        UNKNOWN,
        CODE_POINTS,
        GRAPHEME_CLUSTER
    }

    public enum LSTMClass {
        BEGIN,
        INSIDE,
        END,
        SINGLE,
    }

    private static float[][] make2DArray(int[] data, int start, int d1, int d2) {
        byte[] bytes = new byte[4];
        float[][] result = new float[d1][d2];
        for (int i = 0; i < d1; i++) {
            for (int j = 0; j < d2; j++) {
                int d = data[start++];
                bytes[0] = (byte) (d >> 24);
                bytes[1] = (byte) (d >> 16);
                bytes[2] = (byte) (d >> 8);
                bytes[3] = (byte) (d /*>> 0*/);
                result[i][j] = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).getFloat();
            }
        }
        return result;
    }

    private static float[] make1DArray(int[] data, int start, int d1) {
        byte[] bytes = new byte[4];
        float[] result = new float[d1];
        for (int i = 0; i < d1; i++) {
            int d = data[start++];
            bytes[0] = (byte) (d >> 24);
            bytes[1] = (byte) (d >> 16);
            bytes[2] = (byte) (d >> 8);
            bytes[3] = (byte) (d /*>> 0*/);
            result[i] = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).getFloat();
        }
        return result;
    }

    /** @internal */
    public static class LSTMData {
        private LSTMData() {
        }

        public LSTMData(UResourceBundle rb) {
            int embeddings = rb.get("embeddings").getInt();
            int hunits = rb.get("hunits").getInt();
            this.fType = EmbeddingType.UNKNOWN;
            this.fName = rb.get("model").getString();
            String typeString = rb.get("type").getString();
            if (typeString.equals("codepoints")) {
                this.fType = EmbeddingType.CODE_POINTS;
            } else if (typeString.equals("graphclust")) {
                this.fType = EmbeddingType.GRAPHEME_CLUSTER;
            }
            String[] dict = rb.get("dict").getStringArray();
            int[] data = rb.get("data").getIntVector();
            int dataLen = data.length;
            int numIndex = dict.length;
            fDict = new HashMap<String, Integer>(numIndex + 1);
            int idx = 0;
            for (String embedding : dict) {
                fDict.put(embedding, idx++);
            }
            int mat1Size = (numIndex + 1) * embeddings;
            int mat2Size = embeddings * 4 * hunits;
            int mat3Size = hunits * 4 * hunits;
            int mat4Size = 4 * hunits;
            int mat5Size = mat2Size;
            int mat6Size = mat3Size;
            int mat7Size = mat4Size;
            int mat8Size = 2 * hunits * 4;
            int mat9Size = 4;
            assert dataLen == mat1Size + mat2Size + mat3Size + mat4Size + mat5Size
                    + mat6Size + mat7Size + mat8Size + mat9Size;
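            // The "data" int vector packs nine matrices back to back, each float
            // stored as its IEEE 754 bit pattern in a big-endian int:
            //   1. embedding table           (numIndex + 1) x embeddings
            //   2. forward LSTM W            embeddings x (4 * hunits)
            //   3. forward LSTM U            hunits x (4 * hunits)
            //   4. forward LSTM bias         4 * hunits
            //   5.-7. backward LSTM W, U, bias (same shapes as 2.-4.)
            //   8. output layer W            (2 * hunits) x 4
            //   9. output layer bias         4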
            int start = 0;
            this.fEmbedding = make2DArray(data, start, (numIndex + 1), embeddings);
            start += mat1Size;
            this.fForwardW = make2DArray(data, start, embeddings, 4 * hunits);
            start += mat2Size;
            this.fForwardU = make2DArray(data, start, hunits, 4 * hunits);
            start += mat3Size;
            this.fForwardB = make1DArray(data, start, 4 * hunits);
            start += mat4Size;
            this.fBackwardW = make2DArray(data, start, embeddings, 4 * hunits);
            start += mat5Size;
            this.fBackwardU = make2DArray(data, start, hunits, 4 * hunits);
            start += mat6Size;
            this.fBackwardB = make1DArray(data, start, 4 * hunits);
            start += mat7Size;
            this.fOutputW = make2DArray(data, start, 2 * hunits, 4);
            start += mat8Size;
            this.fOutputB = make1DArray(data, start, 4);
        }

        public EmbeddingType fType;
        public String fName;
        public Map<String, Integer> fDict;
        public float fEmbedding[][];
        public float fForwardW[][];
        public float fForwardU[][];
        public float fForwardB[];
        public float fBackwardW[][];
        public float fBackwardU[][];
        public float fBackwardB[];
        public float fOutputW[][];
        public float fOutputB[];
    }

    // Minimum word size
    private static final byte MIN_WORD = 2;

    // Minimum number of characters for two words
    private static final byte MIN_WORD_SPAN = MIN_WORD * 2;

    abstract class Vectorizer {
        public Vectorizer(Map<String, Integer> dict) {
            this.fDict = dict;
        }

        abstract public void vectorize(CharacterIterator fIter, int rangeStart, int rangeEnd,
                List<Integer> offsets, List<Integer> indicies);

        protected int getIndex(String token) {
            Integer res = fDict.get(token);
            return (res == null) ? fDict.size() : res;
        }

        private Map<String, Integer> fDict;
    }

    class CodePointsVectorizer extends Vectorizer {
        public CodePointsVectorizer(Map<String, Integer> dict) {
            super(dict);
        }

        public void vectorize(CharacterIterator fIter, int rangeStart, int rangeEnd,
                List<Integer> offsets, List<Integer> indicies) {
            fIter.setIndex(rangeStart);
            for (char c = fIter.current();
                    c != CharacterIterator.DONE && fIter.getIndex() < rangeEnd;
                    c = fIter.next()) {
                offsets.add(fIter.getIndex());
                indicies.add(getIndex(String.valueOf(c)));
            }
        }
    }

    class GraphemeClusterVectorizer extends Vectorizer {
        public GraphemeClusterVectorizer(Map<String, Integer> dict) {
            super(dict);
        }

        private String substring(CharacterIterator text, int startPos, int endPos) {
            int saved = text.getIndex();
            text.setIndex(startPos);
            StringBuilder sb = new StringBuilder();
            for (char c = text.current();
                    c != CharacterIterator.DONE && text.getIndex() < endPos;
                    c = text.next()) {
                sb.append(c);
            }
            text.setIndex(saved);
            return sb.toString();
        }

        public void vectorize(CharacterIterator text, int startPos, int endPos,
                List<Integer> offsets, List<Integer> indicies) {
            BreakIterator iter = BreakIterator.getCharacterInstance();
            iter.setText(text);
            int last = iter.next(startPos);
            for (int curr = iter.next(); curr != BreakIterator.DONE && curr <= endPos; curr = iter.next()) {
                offsets.add(last);
                String segment = substring(text, last, curr);
                int index = getIndex(segment);
                indicies.add(index);
                last = curr;
            }
        }
    }

    private final LSTMData fData;
    private int fScript;
    private final Vectorizer fVectorizer;

    private Vectorizer makeVectorizer(LSTMData data) {
        switch (data.fType) {
        case CODE_POINTS:
            return new CodePointsVectorizer(data.fDict);
        case GRAPHEME_CLUSTER:
            return new GraphemeClusterVectorizer(data.fDict);
        default:
            return null;
        }
    }

    public LSTMBreakEngine(int script, UnicodeSet set, LSTMData data) {
        setCharacters(set);
        this.fScript = script;
        this.fData = data;
        this.fVectorizer = makeVectorizer(this.fData);
    }

    @Override
    public int hashCode() {
        return getClass().hashCode();
    }

    @Override
    public boolean handles(int c) {
        return fScript == UCharacter.getIntPropertyValue(c, UProperty.SCRIPT);
    }

    static private void addDotProductTo(final float[] a, final float[][] b, float[] result) {
        assert a.length == b.length;
        assert b[0].length == result.length;
        for (int i = 0; i < result.length; i++) {
            for (int j = 0; j < a.length; j++) {
                result[i] += a[j] * b[j][i];
            }
        }
    }

    static private void addTo(final float[] a, float[] result) {
        assert a.length == result.length;
        for (int i = 0; i < result.length; i++) {
            result[i] += a[i];
        }
    }

    static private void hadamardProductTo(final float[] a, float[] result) {
        assert a.length == result.length;
        for (int i = 0; i < result.length; i++) {
            result[i] *= a[i];
        }
    }

    static private void addHadamardProductTo(final float[] a, final float[] b, float[] result) {
        assert a.length == result.length;
        assert b.length == result.length;
        for (int i = 0; i < result.length; i++) {
            result[i] += a[i] * b[i];
        }
    }

    static private void sigmoid(float[] result, int start, int length) {
        assert start < result.length;
        assert start + length <= result.length;
        for (int i = start; i < start + length; i++) {
            result[i] = (float) (1.0 / (1.0 + Math.exp(-result[i])));
        }
    }

    static private void tanh(float[] result, int start, int length) {
        assert start < result.length;
        assert start + length <= result.length;
        for (int i = start; i < start + length; i++) {
            result[i] = (float) Math.tanh(result[i]);
        }
    }

    static private int maxIndex(float[] data) {
        int index = 0;
        float max = data[0];
        for (int i = 1; i < data.length; i++) {
            if (data[i] > max) {
                max = data[i];
                index = i;
            }
        }
        return index;
    }

    /*
    static private void print(float[] data) {
        for (int i = 0; i < data.length; i++) {
            System.out.format(" %e", data[i]);
            if (i % 4 == 3) {
                System.out.println();
            }
        }
        System.out.println();
    }
    */

    private float[] compute(final float[][] W, final float[][] U, final float[] B,
            final float[] x, float[] h, float[] c) {
        // ifco = x * W + h * U + b
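        // Standard LSTM cell. "ifco" holds the pre-activations of the four
        // gates (input, forget, cell candidate, output), each of length hunits,
        // laid out consecutively. They are activated in place below and then
        // combined as
        //   c = f * c + i * c_      (new cell state, updated in place)
        //   h = o * tanh(c)         (new hidden state, the return value)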
        float[] ifco = Arrays.copyOf(B, B.length);
        addDotProductTo(x, W, ifco);
        addDotProductTo(h, U, ifco);

        int hunits = B.length / 4;
        sigmoid(ifco, 0 * hunits, hunits); // i
        sigmoid(ifco, 1 * hunits, hunits); // f
        tanh(ifco, 2 * hunits, hunits);    // c_
        sigmoid(ifco, 3 * hunits, hunits); // o

        hadamardProductTo(Arrays.copyOfRange(ifco, hunits, 2 * hunits), c);
        addHadamardProductTo(Arrays.copyOf(ifco, hunits),
                Arrays.copyOfRange(ifco, 2 * hunits, 3 * hunits), c);

        h = Arrays.copyOf(c, c.length);
        tanh(h, 0, h.length);
        hadamardProductTo(Arrays.copyOfRange(ifco, 3 * hunits, 4 * hunits), h);
        // System.out.println("c");
        // print(c);
        // System.out.println("h");
        // print(h);
        return h;
    }

    @Override
    public int divideUpDictionaryRange(CharacterIterator fIter, int rangeStart, int rangeEnd,
            DequeI foundBreaks, boolean isPhraseBreaking) {
        int beginSize = foundBreaks.size();

        if ((rangeEnd - rangeStart) < MIN_WORD_SPAN) {
            return 0; // Not enough characters for at least two words.
        }
        List<Integer> offsets = new ArrayList<Integer>(rangeEnd - rangeStart);
        List<Integer> indicies = new ArrayList<Integer>(rangeEnd - rangeStart);

        fVectorizer.vectorize(fIter, rangeStart, rangeEnd, offsets, indicies);

        // To reduce memory usage, the following differs from the Python and
        // ICU4X implementations: we run the backward LSTM first and then merge
        // the forward LSTM iteration with the output layer, because only h[t-1]
        // of the forward LSTM needs to be remembered.
        int inputSeqLength = indicies.size();
        int hunits = this.fData.fForwardU.length;
        float c[] = new float[hunits];

        // TODO: limit size of hBackward. If input_seq_len is too big, we could
        // run out of memory.
        // Backward LSTM
        float hBackward[][] = new float[inputSeqLength][hunits];
        for (int i = inputSeqLength - 1; i >= 0; i--) {
            if (i != inputSeqLength - 1) {
                hBackward[i] = Arrays.copyOf(hBackward[i + 1], hunits);
            }
            // System.out.println("Backward LSTM " + i);
            hBackward[i] = compute(this.fData.fBackwardW, this.fData.fBackwardU, this.fData.fBackwardB,
                    this.fData.fEmbedding[indicies.get(i)],
                    hBackward[i], c);
        }

        c = new float[hunits];
        float forwardH[] = new float[hunits];
        float both[] = new float[2 * hunits];

        // The following loop merges the forward LSTM and the output layer.
        for (int i = 0; i < inputSeqLength; i++) {
            // Forward LSTM
            forwardH = compute(this.fData.fForwardW, this.fData.fForwardU, this.fData.fForwardB,
                    this.fData.fEmbedding[indicies.get(i)],
                    forwardH, c);

            System.arraycopy(forwardH, 0, both, 0, hunits);
            System.arraycopy(hBackward[i], 0, both, hunits, hunits);

            // System.out.println("Merged " + i);
            // print(both);

            // Output layer
            // logp = fbRow * fOutputW + fOutputB
            float logp[] = Arrays.copyOf(this.fData.fOutputB, this.fData.fOutputB.length);
            addDotProductTo(both, this.fData.fOutputW, logp);

            int current = maxIndex(logp);

            // BIES logic.
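            // Each position is classified as Begin, Inside, End, or Single
            // (see LSTMClass). A word boundary precedes every BEGIN and SINGLE
            // character, so the offset is pushed unless it is the start of the
            // range, which is already a boundary.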
            if (current == LSTMClass.BEGIN.ordinal() ||
                    current == LSTMClass.SINGLE.ordinal()) {
                if (i != 0) {
                    foundBreaks.push(offsets.get(i));
                }
            }
        }

        return foundBreaks.size() - beginSize;
    }

    public static LSTMData createData(UResourceBundle bundle) {
        return new LSTMData(bundle);
    }

    private static String defaultLSTM(int script) {
        ICUResourceBundle rb = (ICUResourceBundle) UResourceBundle.getBundleInstance(ICUData.ICU_BRKITR_BASE_NAME);
        return rb.getStringWithFallback("lstm/" + UScript.getShortName(script));
    }

    public static LSTMData createData(int script) {
        if (script != UScript.KHMER && script != UScript.LAO && script != UScript.MYANMAR && script != UScript.THAI) {
            return null;
        }
        String name = defaultLSTM(script);
        name = name.substring(0, name.indexOf("."));

        UResourceBundle rb = UResourceBundle.getBundleInstance(
                ICUData.ICU_BRKITR_BASE_NAME, name,
                ICUResourceBundle.ICU_DATA_CLASS_LOADER);
        return createData(rb);
    }

    public static LSTMBreakEngine create(int script, LSTMData data) {
        String setExpr = "[[:" + UScript.getShortName(script) + ":]&[:LineBreak=SA:]]";
        UnicodeSet set = new UnicodeSet();
        set.applyPattern(setExpr);
        set.compact();
        return new LSTMBreakEngine(script, set, data);
    }
}
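// Illustrative usage (a minimal sketch, not part of the class; real callers are
// the dictionary break-iterator internals):
//
//   LSTMData data = LSTMBreakEngine.createData(UScript.THAI); // null if no bundled model
//   if (data != null) {
//       LSTMBreakEngine engine = LSTMBreakEngine.create(UScript.THAI, data);
//       // engine.divideUpDictionaryRange(...) then proposes word boundaries
//       // for runs of text that engine.handles(c) accepts.
//   }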