• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.ux.material.libmonet.quantize;
18 
19 import static java.lang.Math.min;
20 
21 import java.util.Arrays;
22 import java.util.LinkedHashMap;
23 import java.util.Map;
24 import java.util.Random;
25 
26 /**
27  * An image quantizer that improves on the speed of a standard K-Means algorithm by implementing
28  * several optimizations, including deduping identical pixels and a triangle inequality rule that
29  * reduces the number of comparisons needed to identify which cluster a point should be moved to.
30  *
31  * <p>Wsmeans stands for Weighted Square Means.
32  *
33  * <p>This algorithm was designed by M. Emre Celebi, and was found in their 2011 paper, Improving
34  * the Performance of K-Means for Color Quantization. https://arxiv.org/abs/1101.0395
35  */
36 public final class QuantizerWsmeans {
QuantizerWsmeans()37   private QuantizerWsmeans() {}
38 
39   private static final class Distance implements Comparable<Distance> {
40     int index;
41     double distance;
42 
Distance()43     Distance() {
44       this.index = -1;
45       this.distance = -1;
46     }
47 
48     @Override
compareTo(Distance other)49     public int compareTo(Distance other) {
50       return ((Double) this.distance).compareTo(other.distance);
51     }
52   }
53 
54   private static final int MAX_ITERATIONS = 10;
55   private static final double MIN_MOVEMENT_DISTANCE = 3.0;
56 
57   /**
58    * Reduce the number of colors needed to represented the input, minimizing the difference between
59    * the original image and the recolored image.
60    *
61    * @param inputPixels Colors in ARGB format.
62    * @param startingClusters Defines the initial state of the quantizer. Passing an empty array is
63    *     fine, the implementation will create its own initial state that leads to reproducible
64    *     results for the same inputs. Passing an array that is the result of Wu quantization leads
65    *     to higher quality results.
66    * @param maxColors The number of colors to divide the image into. A lower number of colors may be
67    *     returned.
68    * @return Map with keys of colors in ARGB format, values of how many of the input pixels belong
69    *     to the color.
70    */
quantize( int[] inputPixels, int[] startingClusters, int maxColors)71   public static Map<Integer, Integer> quantize(
72       int[] inputPixels, int[] startingClusters, int maxColors) {
73     // Uses a seeded random number generator to ensure consistent results.
74     Random random = new Random(0x42688);
75 
76     Map<Integer, Integer> pixelToCount = new LinkedHashMap<>();
77     double[][] points = new double[inputPixels.length][];
78     int[] pixels = new int[inputPixels.length];
79     PointProvider pointProvider = new PointProviderLab();
80 
81     int pointCount = 0;
82     for (int i = 0; i < inputPixels.length; i++) {
83       int inputPixel = inputPixels[i];
84       Integer pixelCount = pixelToCount.get(inputPixel);
85       if (pixelCount == null) {
86         points[pointCount] = pointProvider.fromInt(inputPixel);
87         pixels[pointCount] = inputPixel;
88         pointCount++;
89 
90         pixelToCount.put(inputPixel, 1);
91       } else {
92         pixelToCount.put(inputPixel, pixelCount + 1);
93       }
94     }
95 
96     int[] counts = new int[pointCount];
97     for (int i = 0; i < pointCount; i++) {
98       int pixel = pixels[i];
99       int count = pixelToCount.get(pixel);
100       counts[i] = count;
101     }
102 
103     int clusterCount = min(maxColors, pointCount);
104     if (startingClusters.length != 0) {
105       clusterCount = min(clusterCount, startingClusters.length);
106     }
107 
108     double[][] clusters = new double[clusterCount][];
109     int clustersCreated = 0;
110     for (int i = 0; i < startingClusters.length; i++) {
111       clusters[i] = pointProvider.fromInt(startingClusters[i]);
112       clustersCreated++;
113     }
114 
115     int additionalClustersNeeded = clusterCount - clustersCreated;
116     if (additionalClustersNeeded > 0) {
117       for (int i = 0; i < additionalClustersNeeded; i++) {}
118     }
119 
120     int[] clusterIndices = new int[pointCount];
121     for (int i = 0; i < pointCount; i++) {
122       clusterIndices[i] = random.nextInt(clusterCount);
123     }
124 
125     int[][] indexMatrix = new int[clusterCount][];
126     for (int i = 0; i < clusterCount; i++) {
127       indexMatrix[i] = new int[clusterCount];
128     }
129 
130     Distance[][] distanceToIndexMatrix = new Distance[clusterCount][];
131     for (int i = 0; i < clusterCount; i++) {
132       distanceToIndexMatrix[i] = new Distance[clusterCount];
133       for (int j = 0; j < clusterCount; j++) {
134         distanceToIndexMatrix[i][j] = new Distance();
135       }
136     }
137 
138     int[] pixelCountSums = new int[clusterCount];
139     for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
140       for (int i = 0; i < clusterCount; i++) {
141         for (int j = i + 1; j < clusterCount; j++) {
142           double distance = pointProvider.distance(clusters[i], clusters[j]);
143           distanceToIndexMatrix[j][i].distance = distance;
144           distanceToIndexMatrix[j][i].index = i;
145           distanceToIndexMatrix[i][j].distance = distance;
146           distanceToIndexMatrix[i][j].index = j;
147         }
148         Arrays.sort(distanceToIndexMatrix[i]);
149         for (int j = 0; j < clusterCount; j++) {
150           indexMatrix[i][j] = distanceToIndexMatrix[i][j].index;
151         }
152       }
153 
154       int pointsMoved = 0;
155       for (int i = 0; i < pointCount; i++) {
156         double[] point = points[i];
157         int previousClusterIndex = clusterIndices[i];
158         double[] previousCluster = clusters[previousClusterIndex];
159         double previousDistance = pointProvider.distance(point, previousCluster);
160 
161         double minimumDistance = previousDistance;
162         int newClusterIndex = -1;
163         for (int j = 0; j < clusterCount; j++) {
164           if (distanceToIndexMatrix[previousClusterIndex][j].distance >= 4 * previousDistance) {
165             continue;
166           }
167           double distance = pointProvider.distance(point, clusters[j]);
168           if (distance < minimumDistance) {
169             minimumDistance = distance;
170             newClusterIndex = j;
171           }
172         }
173         if (newClusterIndex != -1) {
174           double distanceChange =
175               Math.abs(Math.sqrt(minimumDistance) - Math.sqrt(previousDistance));
176           if (distanceChange > MIN_MOVEMENT_DISTANCE) {
177             pointsMoved++;
178             clusterIndices[i] = newClusterIndex;
179           }
180         }
181       }
182 
183       if (pointsMoved == 0 && iteration != 0) {
184         break;
185       }
186 
187       double[] componentASums = new double[clusterCount];
188       double[] componentBSums = new double[clusterCount];
189       double[] componentCSums = new double[clusterCount];
190       Arrays.fill(pixelCountSums, 0);
191       for (int i = 0; i < pointCount; i++) {
192         int clusterIndex = clusterIndices[i];
193         double[] point = points[i];
194         int count = counts[i];
195         pixelCountSums[clusterIndex] += count;
196         componentASums[clusterIndex] += (point[0] * count);
197         componentBSums[clusterIndex] += (point[1] * count);
198         componentCSums[clusterIndex] += (point[2] * count);
199       }
200 
201       for (int i = 0; i < clusterCount; i++) {
202         int count = pixelCountSums[i];
203         if (count == 0) {
204           clusters[i] = new double[] {0., 0., 0.};
205           continue;
206         }
207         double a = componentASums[i] / count;
208         double b = componentBSums[i] / count;
209         double c = componentCSums[i] / count;
210         clusters[i][0] = a;
211         clusters[i][1] = b;
212         clusters[i][2] = c;
213       }
214     }
215 
216     Map<Integer, Integer> argbToPopulation = new LinkedHashMap<>();
217     for (int i = 0; i < clusterCount; i++) {
218       int count = pixelCountSums[i];
219       if (count == 0) {
220         continue;
221       }
222 
223       int possibleNewCluster = pointProvider.toInt(clusters[i]);
224       if (argbToPopulation.containsKey(possibleNewCluster)) {
225         continue;
226       }
227 
228       argbToPopulation.put(possibleNewCluster, count);
229     }
230 
231     return argbToPopulation;
232   }
233 }
234