1 /* 2 * Copyright (C) 2021 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.internal.graphics.palette; 18 19 import static java.lang.System.arraycopy; 20 21 import android.annotation.NonNull; 22 import android.annotation.Nullable; 23 import android.graphics.Color; 24 25 import java.util.ArrayList; 26 import java.util.List; 27 import java.util.Map; 28 import java.util.Set; 29 30 31 /** 32 * Wu's quantization algorithm is a box-cut quantizer that minimizes variance. It takes longer to 33 * run than, say, median color cut, but provides the highest quality results currently known. 34 * 35 * Prefer `QuantizerCelebi`: coupled with Kmeans, this provides the best-known results for image 36 * quantization. 37 * 38 * Seemingly all Wu implementations are based off of one C code snippet that cites a book from 1992 39 * Graphics Gems vol. II, pp. 126-133. As a result, it is very hard to understand the mechanics of 40 * the algorithm, beyond the commentary provided in the C code. Comments on the methods of this 41 * class are avoided in favor of finding another implementation and reading the commentary there, 42 * avoiding perpetuating the same incomplete and somewhat confusing commentary here. 43 */ 44 public final class WuQuantizer implements Quantizer { 45 // A histogram of all the input colors is constructed. It has the shape of a 46 // cube. The cube would be too large if it contained all 16 million colors: 47 // historical best practice is to use 5 bits of the 8 in each channel, 48 // reducing the histogram to a volume of ~32,000. 49 private static final int BITS = 5; 50 private static final int MAX_INDEX = 32; 51 private static final int SIDE_LENGTH = 33; 52 private static final int TOTAL_SIZE = 35937; 53 54 private int[] mWeights; 55 private int[] mMomentsR; 56 private int[] mMomentsG; 57 private int[] mMomentsB; 58 private double[] mMoments; 59 private Box[] mCubes; 60 private Palette mPalette; 61 private int[] mColors; 62 private Map<Integer, Integer> mInputPixelToCount; 63 64 @Override getQuantizedColors()65 public List<Palette.Swatch> getQuantizedColors() { 66 return mPalette.getSwatches(); 67 } 68 69 @Override quantize(@onNull int[] pixels, int colorCount)70 public void quantize(@NonNull int[] pixels, int colorCount) { 71 assert (pixels.length > 0); 72 73 QuantizerMap quantizerMap = new QuantizerMap(); 74 quantizerMap.quantize(pixels, colorCount); 75 mInputPixelToCount = quantizerMap.getColorToCount(); 76 // Extraction should not be run on using a color count higher than the number of colors 77 // in the pixels. The algorithm doesn't expect that to be the case, unexpected results and 78 // exceptions may occur. 79 Set<Integer> uniqueColors = mInputPixelToCount.keySet(); 80 if (uniqueColors.size() <= colorCount) { 81 mColors = new int[mInputPixelToCount.keySet().size()]; 82 int index = 0; 83 for (int color : uniqueColors) { 84 mColors[index++] = color; 85 } 86 } else { 87 constructHistogram(mInputPixelToCount); 88 createMoments(); 89 CreateBoxesResult createBoxesResult = createBoxes(colorCount); 90 mColors = createResult(createBoxesResult.mResultCount); 91 } 92 93 List<Palette.Swatch> swatches = new ArrayList<>(); 94 for (int color : mColors) { 95 swatches.add(new Palette.Swatch(color, 0)); 96 } 97 mPalette = Palette.from(swatches); 98 } 99 100 @Nullable getColors()101 public int[] getColors() { 102 return mColors; 103 } 104 105 /** Keys are color ints, values are the number of pixels in the image matching that color int */ 106 @Nullable inputPixelToCount()107 public Map<Integer, Integer> inputPixelToCount() { 108 return mInputPixelToCount; 109 } 110 getIndex(int r, int g, int b)111 private static int getIndex(int r, int g, int b) { 112 return (r << 10) + (r << 6) + (g << 5) + r + g + b; 113 } 114 constructHistogram(Map<Integer, Integer> pixels)115 private void constructHistogram(Map<Integer, Integer> pixels) { 116 mWeights = new int[TOTAL_SIZE]; 117 mMomentsR = new int[TOTAL_SIZE]; 118 mMomentsG = new int[TOTAL_SIZE]; 119 mMomentsB = new int[TOTAL_SIZE]; 120 mMoments = new double[TOTAL_SIZE]; 121 122 for (Map.Entry<Integer, Integer> pair : pixels.entrySet()) { 123 int pixel = pair.getKey(); 124 int count = pair.getValue(); 125 int red = Color.red(pixel); 126 int green = Color.green(pixel); 127 int blue = Color.blue(pixel); 128 int bitsToRemove = 8 - BITS; 129 int iR = (red >> bitsToRemove) + 1; 130 int iG = (green >> bitsToRemove) + 1; 131 int iB = (blue >> bitsToRemove) + 1; 132 int index = getIndex(iR, iG, iB); 133 mWeights[index] += count; 134 mMomentsR[index] += (red * count); 135 mMomentsG[index] += (green * count); 136 mMomentsB[index] += (blue * count); 137 mMoments[index] += (count * ((red * red) + (green * green) + (blue * blue))); 138 } 139 } 140 createMoments()141 private void createMoments() { 142 for (int r = 1; r < SIDE_LENGTH; ++r) { 143 int[] area = new int[SIDE_LENGTH]; 144 int[] areaR = new int[SIDE_LENGTH]; 145 int[] areaG = new int[SIDE_LENGTH]; 146 int[] areaB = new int[SIDE_LENGTH]; 147 double[] area2 = new double[SIDE_LENGTH]; 148 149 for (int g = 1; g < SIDE_LENGTH; ++g) { 150 int line = 0; 151 int lineR = 0; 152 int lineG = 0; 153 int lineB = 0; 154 155 double line2 = 0.0; 156 for (int b = 1; b < SIDE_LENGTH; ++b) { 157 int index = getIndex(r, g, b); 158 line += mWeights[index]; 159 lineR += mMomentsR[index]; 160 lineG += mMomentsG[index]; 161 lineB += mMomentsB[index]; 162 line2 += mMoments[index]; 163 164 area[b] += line; 165 areaR[b] += lineR; 166 areaG[b] += lineG; 167 areaB[b] += lineB; 168 area2[b] += line2; 169 170 int previousIndex = getIndex(r - 1, g, b); 171 mWeights[index] = mWeights[previousIndex] + area[b]; 172 mMomentsR[index] = mMomentsR[previousIndex] + areaR[b]; 173 mMomentsG[index] = mMomentsG[previousIndex] + areaG[b]; 174 mMomentsB[index] = mMomentsB[previousIndex] + areaB[b]; 175 mMoments[index] = mMoments[previousIndex] + area2[b]; 176 } 177 } 178 } 179 } 180 createBoxes(int maxColorCount)181 private CreateBoxesResult createBoxes(int maxColorCount) { 182 mCubes = new Box[maxColorCount]; 183 for (int i = 0; i < maxColorCount; i++) { 184 mCubes[i] = new Box(); 185 } 186 double[] volumeVariance = new double[maxColorCount]; 187 Box firstBox = mCubes[0]; 188 firstBox.r1 = MAX_INDEX; 189 firstBox.g1 = MAX_INDEX; 190 firstBox.b1 = MAX_INDEX; 191 192 int generatedColorCount = 0; 193 int next = 0; 194 195 for (int i = 1; i < maxColorCount; i++) { 196 if (cut(mCubes[next], mCubes[i])) { 197 volumeVariance[next] = (mCubes[next].vol > 1) ? variance(mCubes[next]) : 0.0; 198 volumeVariance[i] = (mCubes[i].vol > 1) ? variance(mCubes[i]) : 0.0; 199 } else { 200 volumeVariance[next] = 0.0; 201 i--; 202 } 203 204 next = 0; 205 206 double temp = volumeVariance[0]; 207 for (int k = 1; k <= i; k++) { 208 if (volumeVariance[k] > temp) { 209 temp = volumeVariance[k]; 210 next = k; 211 } 212 } 213 generatedColorCount = i + 1; 214 if (temp <= 0.0) { 215 break; 216 } 217 } 218 219 return new CreateBoxesResult(maxColorCount, generatedColorCount); 220 } 221 createResult(int colorCount)222 private int[] createResult(int colorCount) { 223 int[] colors = new int[colorCount]; 224 int nextAvailableIndex = 0; 225 for (int i = 0; i < colorCount; ++i) { 226 Box cube = mCubes[i]; 227 int weight = volume(cube, mWeights); 228 if (weight > 0) { 229 int r = (volume(cube, mMomentsR) / weight); 230 int g = (volume(cube, mMomentsG) / weight); 231 int b = (volume(cube, mMomentsB) / weight); 232 int color = Color.rgb(r, g, b); 233 colors[nextAvailableIndex++] = color; 234 } 235 } 236 int[] resultArray = new int[nextAvailableIndex]; 237 arraycopy(colors, 0, resultArray, 0, nextAvailableIndex); 238 return resultArray; 239 } 240 variance(Box cube)241 private double variance(Box cube) { 242 int dr = volume(cube, mMomentsR); 243 int dg = volume(cube, mMomentsG); 244 int db = volume(cube, mMomentsB); 245 double xx = 246 mMoments[getIndex(cube.r1, cube.g1, cube.b1)] 247 - mMoments[getIndex(cube.r1, cube.g1, cube.b0)] 248 - mMoments[getIndex(cube.r1, cube.g0, cube.b1)] 249 + mMoments[getIndex(cube.r1, cube.g0, cube.b0)] 250 - mMoments[getIndex(cube.r0, cube.g1, cube.b1)] 251 + mMoments[getIndex(cube.r0, cube.g1, cube.b0)] 252 + mMoments[getIndex(cube.r0, cube.g0, cube.b1)] 253 - mMoments[getIndex(cube.r0, cube.g0, cube.b0)]; 254 255 int hypotenuse = (dr * dr + dg * dg + db * db); 256 int volume2 = volume(cube, mWeights); 257 double variance2 = xx - ((double) hypotenuse / (double) volume2); 258 return variance2; 259 } 260 cut(Box one, Box two)261 private boolean cut(Box one, Box two) { 262 int wholeR = volume(one, mMomentsR); 263 int wholeG = volume(one, mMomentsG); 264 int wholeB = volume(one, mMomentsB); 265 int wholeW = volume(one, mWeights); 266 267 MaximizeResult maxRResult = 268 maximize(one, Direction.RED, one.r0 + 1, one.r1, wholeR, wholeG, wholeB, wholeW); 269 MaximizeResult maxGResult = 270 maximize(one, Direction.GREEN, one.g0 + 1, one.g1, wholeR, wholeG, wholeB, wholeW); 271 MaximizeResult maxBResult = 272 maximize(one, Direction.BLUE, one.b0 + 1, one.b1, wholeR, wholeG, wholeB, wholeW); 273 Direction cutDirection; 274 double maxR = maxRResult.mMaximum; 275 double maxG = maxGResult.mMaximum; 276 double maxB = maxBResult.mMaximum; 277 if (maxR >= maxG && maxR >= maxB) { 278 if (maxRResult.mCutLocation < 0) { 279 return false; 280 } 281 cutDirection = Direction.RED; 282 } else if (maxG >= maxR && maxG >= maxB) { 283 cutDirection = Direction.GREEN; 284 } else { 285 cutDirection = Direction.BLUE; 286 } 287 288 two.r1 = one.r1; 289 two.g1 = one.g1; 290 two.b1 = one.b1; 291 292 switch (cutDirection) { 293 case RED: 294 one.r1 = maxRResult.mCutLocation; 295 two.r0 = one.r1; 296 two.g0 = one.g0; 297 two.b0 = one.b0; 298 break; 299 case GREEN: 300 one.g1 = maxGResult.mCutLocation; 301 two.r0 = one.r0; 302 two.g0 = one.g1; 303 two.b0 = one.b0; 304 break; 305 case BLUE: 306 one.b1 = maxBResult.mCutLocation; 307 two.r0 = one.r0; 308 two.g0 = one.g0; 309 two.b0 = one.b1; 310 break; 311 default: 312 throw new IllegalArgumentException("unexpected direction " + cutDirection); 313 } 314 315 one.vol = (one.r1 - one.r0) * (one.g1 - one.g0) * (one.b1 - one.b0); 316 two.vol = (two.r1 - two.r0) * (two.g1 - two.g0) * (two.b1 - two.b0); 317 318 return true; 319 } 320 maximize( Box cube, Direction direction, int first, int last, int wholeR, int wholeG, int wholeB, int wholeW)321 private MaximizeResult maximize( 322 Box cube, 323 Direction direction, 324 int first, 325 int last, 326 int wholeR, 327 int wholeG, 328 int wholeB, 329 int wholeW) { 330 int baseR = bottom(cube, direction, mMomentsR); 331 int baseG = bottom(cube, direction, mMomentsG); 332 int baseB = bottom(cube, direction, mMomentsB); 333 int baseW = bottom(cube, direction, mWeights); 334 335 double max = 0.0; 336 int cut = -1; 337 for (int i = first; i < last; i++) { 338 int halfR = baseR + top(cube, direction, i, mMomentsR); 339 int halfG = baseG + top(cube, direction, i, mMomentsG); 340 int halfB = baseB + top(cube, direction, i, mMomentsB); 341 int halfW = baseW + top(cube, direction, i, mWeights); 342 343 if (halfW == 0) { 344 continue; 345 } 346 double tempNumerator = halfR * halfR + halfG * halfG + halfB * halfB; 347 double tempDenominator = halfW; 348 double temp = tempNumerator / tempDenominator; 349 350 halfR = wholeR - halfR; 351 halfG = wholeG - halfG; 352 halfB = wholeB - halfB; 353 halfW = wholeW - halfW; 354 if (halfW == 0) { 355 continue; 356 } 357 358 tempNumerator = halfR * halfR + halfG * halfG + halfB * halfB; 359 tempDenominator = halfW; 360 temp += (tempNumerator / tempDenominator); 361 if (temp > max) { 362 max = temp; 363 cut = i; 364 } 365 } 366 return new MaximizeResult(cut, max); 367 } 368 volume(Box cube, int[] moment)369 private static int volume(Box cube, int[] moment) { 370 return (moment[getIndex(cube.r1, cube.g1, cube.b1)] 371 - moment[getIndex(cube.r1, cube.g1, cube.b0)] 372 - moment[getIndex(cube.r1, cube.g0, cube.b1)] 373 + moment[getIndex(cube.r1, cube.g0, cube.b0)] 374 - moment[getIndex(cube.r0, cube.g1, cube.b1)] 375 + moment[getIndex(cube.r0, cube.g1, cube.b0)] 376 + moment[getIndex(cube.r0, cube.g0, cube.b1)] 377 - moment[getIndex(cube.r0, cube.g0, cube.b0)]); 378 } 379 bottom(Box cube, Direction direction, int[] moment)380 private static int bottom(Box cube, Direction direction, int[] moment) { 381 switch (direction) { 382 case RED: 383 return -moment[getIndex(cube.r0, cube.g1, cube.b1)] 384 + moment[getIndex(cube.r0, cube.g1, cube.b0)] 385 + moment[getIndex(cube.r0, cube.g0, cube.b1)] 386 - moment[getIndex(cube.r0, cube.g0, cube.b0)]; 387 case GREEN: 388 return -moment[getIndex(cube.r1, cube.g0, cube.b1)] 389 + moment[getIndex(cube.r1, cube.g0, cube.b0)] 390 + moment[getIndex(cube.r0, cube.g0, cube.b1)] 391 - moment[getIndex(cube.r0, cube.g0, cube.b0)]; 392 case BLUE: 393 return -moment[getIndex(cube.r1, cube.g1, cube.b0)] 394 + moment[getIndex(cube.r1, cube.g0, cube.b0)] 395 + moment[getIndex(cube.r0, cube.g1, cube.b0)] 396 - moment[getIndex(cube.r0, cube.g0, cube.b0)]; 397 default: 398 throw new IllegalArgumentException("unexpected direction " + direction); 399 } 400 } 401 top(Box cube, Direction direction, int position, int[] moment)402 private static int top(Box cube, Direction direction, int position, int[] moment) { 403 switch (direction) { 404 case RED: 405 return (moment[getIndex(position, cube.g1, cube.b1)] 406 - moment[getIndex(position, cube.g1, cube.b0)] 407 - moment[getIndex(position, cube.g0, cube.b1)] 408 + moment[getIndex(position, cube.g0, cube.b0)]); 409 case GREEN: 410 return (moment[getIndex(cube.r1, position, cube.b1)] 411 - moment[getIndex(cube.r1, position, cube.b0)] 412 - moment[getIndex(cube.r0, position, cube.b1)] 413 + moment[getIndex(cube.r0, position, cube.b0)]); 414 case BLUE: 415 return (moment[getIndex(cube.r1, cube.g1, position)] 416 - moment[getIndex(cube.r1, cube.g0, position)] 417 - moment[getIndex(cube.r0, cube.g1, position)] 418 + moment[getIndex(cube.r0, cube.g0, position)]); 419 default: 420 throw new IllegalArgumentException("unexpected direction " + direction); 421 } 422 } 423 424 private enum Direction { 425 RED, 426 GREEN, 427 BLUE 428 } 429 430 private static class MaximizeResult { 431 // < 0 if cut impossible 432 final int mCutLocation; 433 final double mMaximum; 434 MaximizeResult(int cut, double max)435 MaximizeResult(int cut, double max) { 436 mCutLocation = cut; 437 mMaximum = max; 438 } 439 } 440 441 private static class CreateBoxesResult { 442 final int mRequestedCount; 443 final int mResultCount; 444 CreateBoxesResult(int requestedCount, int resultCount)445 CreateBoxesResult(int requestedCount, int resultCount) { 446 mRequestedCount = requestedCount; 447 mResultCount = resultCount; 448 } 449 } 450 451 private static class Box { 452 public int r0 = 0; 453 public int r1 = 0; 454 public int g0 = 0; 455 public int g1 = 0; 456 public int b0 = 0; 457 public int b1 = 0; 458 public int vol = 0; 459 } 460 } 461 462 463