1 /* 2 * Copyright (C) 2022 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.adservices.service.measurement.noising; 18 19 import com.android.adservices.service.measurement.PrivacyParams; 20 21 import com.google.common.math.DoubleMath; 22 import com.google.common.math.LongMath; 23 24 import java.math.BigInteger; 25 import java.util.ArrayList; 26 import java.util.Arrays; 27 import java.util.List; 28 import java.util.Objects; 29 30 /** 31 * Combinatorics utilities used for randomization. 32 */ 33 public class Combinatorics { 34 35 /** 36 * Returns the k-combination associated with the number {@code combinationIndex}. In 37 * other words, returns the combination of {@code k} integers uniquely indexed by 38 * {@code combinationIndex} in the combinatorial number system. 39 * https://en.wikipedia.org/wiki/Combinatorial_number_system 40 * 41 * @return combinationIndex-th lexicographically smallest k-combination. 42 * @throws ArithmeticException in case of int overflow 43 */ getKCombinationAtIndex(long combinationIndex, int k)44 static long[] getKCombinationAtIndex(long combinationIndex, int k) { 45 // Computes the combinationIndex-th lexicographically smallest k-combination. 46 // https://en.wikipedia.org/wiki/Combinatorial_number_system 47 // 48 // A k-combination is a sequence of k non-negative integers in decreasing order. 49 // a_k > a_{k-1} > ... > a_2 > a_1 >= 0. 50 // k-combinations can be ordered lexicographically, with the smallest 51 // k-combination being a_k=k-1, a_{k-1}=k-2, .., a_1=0. Given an index 52 // combinationIndex>=0, and an order k, this method returns the 53 // combinationIndex-th smallest k-combination. 54 // 55 // Given an index combinationIndex, the combinationIndex-th k-combination 56 // is the unique set of k non-negative integers 57 // a_k > a_{k-1} > ... > a_2 > a_1 >= 0 58 // such that combinationIndex = \sum_{i=1}^k {a_i}\choose{i} 59 // 60 // We find this set via a simple greedy algorithm. 61 // http://math0.wvstateu.edu/~baker/cs405/code/Combinadics.html 62 long[] result = new long[k]; 63 if (k == 0) { 64 return result; 65 } 66 // To find a_k, iterate candidates upwards from 0 until we've found the 67 // maximum a such that (a choose k) <= combinationIndex. Let a_k = a. Use 68 // the previous binomial coefficient to compute the next one. Note: possible 69 // to speed this up via something other than incremental search. 70 long target = combinationIndex; 71 long candidate = (long) k - 1L; 72 long binomialCoefficient = 0L; 73 long nextBinomialCoefficient = 1L; 74 while (nextBinomialCoefficient <= target) { 75 candidate++; 76 binomialCoefficient = nextBinomialCoefficient; 77 // (n + 1 choose k) = (n choose k) * (n + 1) / (n + 1 - k) 78 nextBinomialCoefficient = Math.multiplyExact(binomialCoefficient, (candidate + 1)); 79 nextBinomialCoefficient /= candidate + 1 - k; 80 } 81 // We know from the k-combination definition, all subsequent values will be 82 // strictly decreasing. Find them all by decrementing candidate. 83 // Use the previous binomial coefficient to compute the next one. 84 long currentK = (long) k; 85 int currentIndex = 0; 86 while (true) { 87 if (binomialCoefficient <= target) { 88 result[currentIndex] = candidate; 89 currentIndex++; 90 target -= binomialCoefficient; 91 if (currentIndex == k) { 92 return result; 93 } 94 // (n - 1 choose k - 1) = (n choose k) * k / n 95 binomialCoefficient = binomialCoefficient * currentK / candidate; 96 currentK--; 97 } else { 98 // (n - 1 choose k) = (n choose k) * (n - k) / n 99 binomialCoefficient = binomialCoefficient * (candidate - currentK) / candidate; 100 } 101 candidate--; 102 } 103 } 104 105 /** 106 * Returns the number of possible sequences of "stars and bars" sequences 107 * https://en.wikipedia.org/wiki/Stars_and_bars_(combinatorics), 108 * which is equivalent to (numStars + numBars choose numStars). 109 * 110 * @param numStars number of stars 111 * @param numBars number of bars 112 * @return number of possible sequences 113 */ getNumberOfStarsAndBarsSequences(int numStars, int numBars)114 public static long getNumberOfStarsAndBarsSequences(int numStars, int numBars) { 115 // Note, LongMath::binomial returns Long.MAX_VALUE rather than overflow. 116 return LongMath.binomial(numStars + numBars, numStars); 117 } 118 119 /** 120 * Returns an array of the indices of every star in the stars and bars sequence indexed by 121 * {@code sequenceIndex}. 122 * 123 * @param numStars number of stars in the sequence 124 * @param sequenceIndex index of the sequence 125 * @return list of indices of every star in stars & bars sequence 126 */ getStarIndices(int numStars, long sequenceIndex)127 public static long[] getStarIndices(int numStars, long sequenceIndex) { 128 return getKCombinationAtIndex(sequenceIndex, numStars); 129 } 130 131 /** 132 * From an array with the index of every star in a stars and bars sequence, returns an array 133 * which, for every star, counts the number of bars preceding it. 134 * 135 * @param starIndices indices of the stars in descending order 136 * @return count of bars preceding every star 137 */ getBarsPrecedingEachStar(long[] starIndices)138 public static long[] getBarsPrecedingEachStar(long[] starIndices) { 139 for (int i = 0; i < starIndices.length; i++) { 140 long starIndex = starIndices[i]; 141 // There are {@code starIndex} prior positions in the sequence, and `i` prior 142 // stars, so there are {@code starIndex - i} prior bars. 143 starIndices[i] = starIndex - ((long) starIndices.length - 1L - (long) i); 144 } 145 return starIndices; 146 } 147 148 /** 149 * Compute number of states from the trigger specification 150 * 151 * @param numBucketIncrements number of bucket increments (equivalent to number of triggers) 152 * @param numTriggerData number of trigger data. (equivalent to number of metadata) 153 * @param numWindows number of reporting windows 154 * @return number of states 155 */ getNumStatesArithmetic( int numBucketIncrements, int numTriggerData, int numWindows)156 public static long getNumStatesArithmetic( 157 int numBucketIncrements, int numTriggerData, int numWindows) { 158 int numStars = numBucketIncrements; 159 int numBars = Math.multiplyExact(numTriggerData, numWindows); 160 return getNumberOfStarsAndBarsSequences(numStars, numBars); 161 } 162 163 /** 164 * Using dynamic programming to compute number of states. Returns Long.MAX_VALUE if the result 165 * is greater than {@code bound}. 166 * 167 * @param totalCap total incremental cap 168 * @param perTypeNumWindowList reporting window per trigger data 169 * @param perTypeCapList cap per trigger data 170 * @param bound the highest state count allowed 171 * @return number of states 172 * @throws ArithmeticException in case of long overflow 173 */ getNumStatesIterative( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound)174 private static long getNumStatesIterative( 175 int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound) { 176 // Assumes perTypeCapList cannot sum to more than int value. Overflowing int here can lead 177 // to an exception when declaring the array size later, based on the min value. 178 int sum = 0; 179 for (int cap : perTypeCapList) { 180 sum += cap; 181 } 182 int leastTotalCap = Math.min(totalCap, sum); 183 long[][] dp = new long[2][leastTotalCap + 1]; 184 int prev = 0; 185 int curr = 1; 186 187 dp[prev][0] = 1L; 188 long result = 0L; 189 190 for (int i = 0; i < perTypeNumWindowList.length && perTypeNumWindowList[i] > 0; i++) { 191 int winCount = perTypeNumWindowList[i]; 192 int capCount = perTypeCapList[i]; 193 result = 0L; 194 195 for (int cap = 0; cap < leastTotalCap + 1; cap++) { 196 dp[curr][cap] = 0L; 197 198 for (int capVal = 0; capVal < Math.min(cap, capCount) + 1; capVal++) { 199 dp[curr][cap] = Math.addExact( 200 dp[curr][cap], 201 Math.multiplyExact( 202 dp[prev][cap - capVal], 203 getNumberOfStarsAndBarsSequences(capVal, winCount - 1))); 204 } 205 206 result = Math.addExact(result, dp[curr][cap]); 207 208 if (result > bound) { 209 return Long.MAX_VALUE; 210 } 211 } 212 213 curr ^= 1; 214 prev ^= 1; 215 } 216 217 return Math.max(result, 1L); 218 } 219 220 /** 221 * Compute number of states for flexible event report API. Returns Long.MAX_VALUE if the result 222 * exceeds {@code bound}. 223 * 224 * @param totalCap number of total increments 225 * @param perTypeNumWindowList reporting window for each trigger data 226 * @param perTypeCapList limit of the increment of each trigger data 227 * @param bound the highest state count allowed 228 * @return number of states 229 * @throws ArithmeticException in case of long overflow during the iterative procedure 230 */ getNumStatesFlexApi( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound)231 public static long getNumStatesFlexApi( 232 int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound) { 233 if (perTypeNumWindowList.length == 0 || perTypeCapList.length == 0) { 234 return 1; 235 } 236 for (int i = 1; i < perTypeNumWindowList.length; i++) { 237 if (perTypeNumWindowList[i] != perTypeNumWindowList[i - 1]) { 238 return getNumStatesIterative(totalCap, perTypeNumWindowList, perTypeCapList, bound); 239 } 240 } 241 for (int n : perTypeCapList) { 242 if (n < totalCap) { 243 return getNumStatesIterative(totalCap, perTypeNumWindowList, perTypeCapList, bound); 244 } 245 } 246 247 long result = getNumStatesArithmetic( 248 totalCap, perTypeCapList.length, perTypeNumWindowList[0]); 249 250 return result > bound ? Long.MAX_VALUE : result; 251 } 252 253 /** 254 * @param numOfStates Number of States 255 * @return the probability to use fake reports 256 */ getFlipProbability(long numOfStates, double privacyEpsilon)257 public static double getFlipProbability(long numOfStates, double privacyEpsilon) { 258 return (numOfStates) / (numOfStates + Math.exp(privacyEpsilon) - 1D); 259 } 260 getBinaryEntropy(double x)261 private static double getBinaryEntropy(double x) { 262 if (DoubleMath.fuzzyEquals(x, 0.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD) 263 || DoubleMath.fuzzyEquals(x, 1.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD)) { 264 return 0.0D; 265 } 266 return (-1.0D) * x * DoubleMath.log2(x) - (1 - x) * DoubleMath.log2(1 - x); 267 } 268 269 /** 270 * @param numOfStates Number of States 271 * @param flipProbability Flip Probability 272 * @return the information gain 273 */ getInformationGain(long numOfStates, double flipProbability)274 public static double getInformationGain(long numOfStates, double flipProbability) { 275 if (numOfStates <= 1L) { 276 return 0d; 277 } 278 double log2Q = DoubleMath.log2(numOfStates); 279 double fakeProbability = flipProbability * (numOfStates - 1L) / numOfStates; 280 return log2Q 281 - getBinaryEntropy(fakeProbability) 282 - fakeProbability * DoubleMath.log2(numOfStates - 1); 283 } 284 285 /** 286 * Returns the max information gain given the num of trigger states, attribution scope limit and 287 * max num event states. 288 * 289 * @param numTriggerStates The number of trigger states. 290 * @param attributionScopeLimit The attribution scope limit. 291 * @param maxEventStates The maximum number of event states (expected to be positive). 292 * @return The max information gain. 293 */ getMaxInformationGainWithAttributionScope( long numTriggerStates, long attributionScopeLimit, long maxEventStates)294 public static double getMaxInformationGainWithAttributionScope( 295 long numTriggerStates, long attributionScopeLimit, long maxEventStates) { 296 if (numTriggerStates <= 0 || maxEventStates <= 0) { 297 throw new IllegalArgumentException( 298 "numTriggerStates and maxEventStates must be positive"); 299 } 300 BigInteger totalNumStates = 301 BigInteger.valueOf(numTriggerStates) 302 .add( 303 BigInteger.valueOf(maxEventStates) 304 .multiply(BigInteger.valueOf(attributionScopeLimit - 1))); 305 return DoubleMath.log2(totalNumStates.doubleValue()); 306 } 307 308 /** 309 * Generate fake report set given a trigger specification and the rank order number 310 * 311 * @param totalCap total_cap 312 * @param perTypeNumWindowList per type number of window list 313 * @param perTypeCapList per type cap list 314 * @param rank the rank of the report state within all the report states 315 * @return a report set based on the input rank 316 */ getReportSetBasedOnRank( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long rank)317 public static List<AtomReportState> getReportSetBasedOnRank( 318 int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long rank) { 319 int triggerTypeIndex = perTypeNumWindowList.length - 1; 320 return getReportSetBasedOnRankRecursive( 321 totalCap, 322 triggerTypeIndex, 323 perTypeNumWindowList[triggerTypeIndex], 324 perTypeCapList[triggerTypeIndex], 325 rank, 326 perTypeNumWindowList, 327 perTypeCapList); 328 } 329 getReportSetBasedOnRankRecursive( int totalCap, int triggerTypeIndex, int winVal, int capVal, long rank, int[] perTypeNumWindowList, int[] perTypeCapList)330 private static List<AtomReportState> getReportSetBasedOnRankRecursive( 331 int totalCap, 332 int triggerTypeIndex, 333 int winVal, 334 int capVal, 335 long rank, 336 int[] perTypeNumWindowList, 337 int[] perTypeCapList) { 338 339 if (winVal == 0 && triggerTypeIndex == 0) { 340 return new ArrayList<>(); 341 } else if (winVal == 0) { 342 return getReportSetBasedOnRankRecursive( 343 totalCap, 344 triggerTypeIndex - 1, 345 perTypeNumWindowList[triggerTypeIndex - 1], 346 perTypeCapList[triggerTypeIndex - 1], 347 rank, 348 perTypeNumWindowList, 349 perTypeCapList); 350 } 351 for (int i = 0; i <= Math.min(totalCap, capVal); i++) { 352 int[] perTypeNumWindowListClone = Arrays.copyOfRange( 353 perTypeNumWindowList, 0, triggerTypeIndex + 1); 354 perTypeNumWindowListClone[triggerTypeIndex] = winVal - 1; 355 int[] perTypeCapListClone = Arrays.copyOfRange( 356 perTypeCapList, 0, triggerTypeIndex + 1); 357 perTypeCapListClone[triggerTypeIndex] = capVal - i; 358 long currentNumStates = 359 getNumStatesIterative( 360 totalCap - i, 361 perTypeNumWindowListClone, 362 perTypeCapListClone, 363 Long.MAX_VALUE); 364 if (currentNumStates > rank) { 365 // The triggers to be appended. 366 List<AtomReportState> toAppend = new ArrayList<>(); 367 for (int k = 0; k < i; k++) { 368 toAppend.add(new AtomReportState(triggerTypeIndex, winVal - 1)); 369 } 370 List<AtomReportState> otherReports = 371 getReportSetBasedOnRankRecursive( 372 totalCap - i, 373 triggerTypeIndex, 374 winVal - 1, 375 capVal - i, 376 rank, 377 perTypeNumWindowList, 378 perTypeCapList); 379 toAppend.addAll(otherReports); 380 return toAppend; 381 } else { 382 rank -= currentNumStates; 383 } 384 } 385 // will not reach here 386 return new ArrayList<>(); 387 } 388 389 /** A single report including triggerDataType and window index for the fake report generation */ 390 public static class AtomReportState { 391 private final int mTriggerDataType; 392 private final int mWindowIndex; 393 AtomReportState(int triggerDataType, int windowIndex)394 public AtomReportState(int triggerDataType, int windowIndex) { 395 this.mTriggerDataType = triggerDataType; 396 this.mWindowIndex = windowIndex; 397 } 398 getTriggerDataType()399 public int getTriggerDataType() { 400 return mTriggerDataType; 401 } 402 ; 403 getWindowIndex()404 public final int getWindowIndex() { 405 return mWindowIndex; 406 } 407 ; 408 409 @Override equals(Object obj)410 public boolean equals(Object obj) { 411 if (!(obj instanceof AtomReportState)) { 412 return false; 413 } 414 AtomReportState t = (AtomReportState) obj; 415 return mTriggerDataType == t.mTriggerDataType && mWindowIndex == t.mWindowIndex; 416 } 417 418 @Override hashCode()419 public int hashCode() { 420 return Objects.hash(mWindowIndex, mTriggerDataType); 421 } 422 } 423 } 424