1 /* 2 * Copyright 2021 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.ux.material.libmonet.quantize; 18 19 import static java.lang.Math.min; 20 21 import java.util.Arrays; 22 import java.util.LinkedHashMap; 23 import java.util.Map; 24 import java.util.Random; 25 26 /** 27 * An image quantizer that improves on the speed of a standard K-Means algorithm by implementing 28 * several optimizations, including deduping identical pixels and a triangle inequality rule that 29 * reduces the number of comparisons needed to identify which cluster a point should be moved to. 30 * 31 * <p>Wsmeans stands for Weighted Square Means. 32 * 33 * <p>This algorithm was designed by M. Emre Celebi, and was found in their 2011 paper, Improving 34 * the Performance of K-Means for Color Quantization. https://arxiv.org/abs/1101.0395 35 */ 36 public final class QuantizerWsmeans { QuantizerWsmeans()37 private QuantizerWsmeans() {} 38 39 private static final class Distance implements Comparable<Distance> { 40 int index; 41 double distance; 42 Distance()43 Distance() { 44 this.index = -1; 45 this.distance = -1; 46 } 47 48 @Override compareTo(Distance other)49 public int compareTo(Distance other) { 50 return ((Double) this.distance).compareTo(other.distance); 51 } 52 } 53 54 private static final int MAX_ITERATIONS = 10; 55 private static final double MIN_MOVEMENT_DISTANCE = 3.0; 56 57 /** 58 * Reduce the number of colors needed to represented the input, minimizing the difference between 59 * the original image and the recolored image. 60 * 61 * @param inputPixels Colors in ARGB format. 62 * @param startingClusters Defines the initial state of the quantizer. Passing an empty array is 63 * fine, the implementation will create its own initial state that leads to reproducible 64 * results for the same inputs. Passing an array that is the result of Wu quantization leads 65 * to higher quality results. 66 * @param maxColors The number of colors to divide the image into. A lower number of colors may be 67 * returned. 68 * @return Map with keys of colors in ARGB format, values of how many of the input pixels belong 69 * to the color. 70 */ quantize( int[] inputPixels, int[] startingClusters, int maxColors)71 public static Map<Integer, Integer> quantize( 72 int[] inputPixels, int[] startingClusters, int maxColors) { 73 // Uses a seeded random number generator to ensure consistent results. 74 Random random = new Random(0x42688); 75 76 Map<Integer, Integer> pixelToCount = new LinkedHashMap<>(); 77 double[][] points = new double[inputPixels.length][]; 78 int[] pixels = new int[inputPixels.length]; 79 PointProvider pointProvider = new PointProviderLab(); 80 81 int pointCount = 0; 82 for (int i = 0; i < inputPixels.length; i++) { 83 int inputPixel = inputPixels[i]; 84 Integer pixelCount = pixelToCount.get(inputPixel); 85 if (pixelCount == null) { 86 points[pointCount] = pointProvider.fromInt(inputPixel); 87 pixels[pointCount] = inputPixel; 88 pointCount++; 89 90 pixelToCount.put(inputPixel, 1); 91 } else { 92 pixelToCount.put(inputPixel, pixelCount + 1); 93 } 94 } 95 96 int[] counts = new int[pointCount]; 97 for (int i = 0; i < pointCount; i++) { 98 int pixel = pixels[i]; 99 int count = pixelToCount.get(pixel); 100 counts[i] = count; 101 } 102 103 int clusterCount = min(maxColors, pointCount); 104 if (startingClusters.length != 0) { 105 clusterCount = min(clusterCount, startingClusters.length); 106 } 107 108 double[][] clusters = new double[clusterCount][]; 109 int clustersCreated = 0; 110 for (int i = 0; i < startingClusters.length; i++) { 111 clusters[i] = pointProvider.fromInt(startingClusters[i]); 112 clustersCreated++; 113 } 114 115 int additionalClustersNeeded = clusterCount - clustersCreated; 116 if (additionalClustersNeeded > 0) { 117 for (int i = 0; i < additionalClustersNeeded; i++) {} 118 } 119 120 int[] clusterIndices = new int[pointCount]; 121 for (int i = 0; i < pointCount; i++) { 122 clusterIndices[i] = random.nextInt(clusterCount); 123 } 124 125 int[][] indexMatrix = new int[clusterCount][]; 126 for (int i = 0; i < clusterCount; i++) { 127 indexMatrix[i] = new int[clusterCount]; 128 } 129 130 Distance[][] distanceToIndexMatrix = new Distance[clusterCount][]; 131 for (int i = 0; i < clusterCount; i++) { 132 distanceToIndexMatrix[i] = new Distance[clusterCount]; 133 for (int j = 0; j < clusterCount; j++) { 134 distanceToIndexMatrix[i][j] = new Distance(); 135 } 136 } 137 138 int[] pixelCountSums = new int[clusterCount]; 139 for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) { 140 for (int i = 0; i < clusterCount; i++) { 141 for (int j = i + 1; j < clusterCount; j++) { 142 double distance = pointProvider.distance(clusters[i], clusters[j]); 143 distanceToIndexMatrix[j][i].distance = distance; 144 distanceToIndexMatrix[j][i].index = i; 145 distanceToIndexMatrix[i][j].distance = distance; 146 distanceToIndexMatrix[i][j].index = j; 147 } 148 Arrays.sort(distanceToIndexMatrix[i]); 149 for (int j = 0; j < clusterCount; j++) { 150 indexMatrix[i][j] = distanceToIndexMatrix[i][j].index; 151 } 152 } 153 154 int pointsMoved = 0; 155 for (int i = 0; i < pointCount; i++) { 156 double[] point = points[i]; 157 int previousClusterIndex = clusterIndices[i]; 158 double[] previousCluster = clusters[previousClusterIndex]; 159 double previousDistance = pointProvider.distance(point, previousCluster); 160 161 double minimumDistance = previousDistance; 162 int newClusterIndex = -1; 163 for (int j = 0; j < clusterCount; j++) { 164 if (distanceToIndexMatrix[previousClusterIndex][j].distance >= 4 * previousDistance) { 165 continue; 166 } 167 double distance = pointProvider.distance(point, clusters[j]); 168 if (distance < minimumDistance) { 169 minimumDistance = distance; 170 newClusterIndex = j; 171 } 172 } 173 if (newClusterIndex != -1) { 174 double distanceChange = 175 Math.abs(Math.sqrt(minimumDistance) - Math.sqrt(previousDistance)); 176 if (distanceChange > MIN_MOVEMENT_DISTANCE) { 177 pointsMoved++; 178 clusterIndices[i] = newClusterIndex; 179 } 180 } 181 } 182 183 if (pointsMoved == 0 && iteration != 0) { 184 break; 185 } 186 187 double[] componentASums = new double[clusterCount]; 188 double[] componentBSums = new double[clusterCount]; 189 double[] componentCSums = new double[clusterCount]; 190 Arrays.fill(pixelCountSums, 0); 191 for (int i = 0; i < pointCount; i++) { 192 int clusterIndex = clusterIndices[i]; 193 double[] point = points[i]; 194 int count = counts[i]; 195 pixelCountSums[clusterIndex] += count; 196 componentASums[clusterIndex] += (point[0] * count); 197 componentBSums[clusterIndex] += (point[1] * count); 198 componentCSums[clusterIndex] += (point[2] * count); 199 } 200 201 for (int i = 0; i < clusterCount; i++) { 202 int count = pixelCountSums[i]; 203 if (count == 0) { 204 clusters[i] = new double[] {0., 0., 0.}; 205 continue; 206 } 207 double a = componentASums[i] / count; 208 double b = componentBSums[i] / count; 209 double c = componentCSums[i] / count; 210 clusters[i][0] = a; 211 clusters[i][1] = b; 212 clusters[i][2] = c; 213 } 214 } 215 216 Map<Integer, Integer> argbToPopulation = new LinkedHashMap<>(); 217 for (int i = 0; i < clusterCount; i++) { 218 int count = pixelCountSums[i]; 219 if (count == 0) { 220 continue; 221 } 222 223 int possibleNewCluster = pointProvider.toInt(clusters[i]); 224 if (argbToPopulation.containsKey(possibleNewCluster)) { 225 continue; 226 } 227 228 argbToPopulation.put(possibleNewCluster, count); 229 } 230 231 return argbToPopulation; 232 } 233 } 234