/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.adservices.service.measurement.noising;

import com.android.adservices.LogUtil;
import com.android.adservices.service.measurement.PrivacyParams;

import com.google.common.math.DoubleMath;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Combinatorics utilities used for randomization.
 */
public class Combinatorics {

    /**
     * Computes the binomial coefficient aka {@code n} choose {@code k}.
     * https://en.wikipedia.org/wiki/Binomial_coefficient
     *
     * <p>Returns 0 when {@code k} is negative or greater than {@code n}.
     *
     * @return binomial coefficient for (n choose k)
     * @throws ArithmeticException if the result overflows an int
     */
    static int getBinomialCoefficient(int n, int k) {
        if (k < 0 || k > n) {
            return 0;
        }
        // From here on 0 <= k <= n.
        if (k == n) {
            return 1;
        }
        // getBinomialCoefficient(n, k) == getBinomialCoefficient(n, n - k).
        if (k > n - k) {
            k = n - k;
        }
        // (n choose k) = n (n - 1) ... (n - (k - 1)) / k!
        //              = mul((n + 1 - i) / i), i from 1 -> k.
        //
        // After step i the running value is exactly (n choose i), so every division below is
        // exact: a remainder would imply that (n choose i) is fractional, which it never is.
        int result = 1;
        for (int i = 1; i <= k; i++) {
            // Multiply in 64 bits so that the product taken *before* the division cannot
            // overflow when the quotient itself still fits in an int. Because k <= n / 2 the
            // intermediate values (n choose i) never exceed the final one, so toIntExact only
            // throws when the final result genuinely overflows an int.
            result = Math.toIntExact((long) result * (n - i + 1) / i);
        }
        return result;
    }

    /**
     * Returns the k-combination associated with the number {@code combinationIndex}. In
     * other words, returns the combination of {@code k} integers uniquely indexed by
     * {@code combinationIndex} in the combinatorial number system.
     * https://en.wikipedia.org/wiki/Combinatorial_number_system
     *
     * @return combinationIndex-th lexicographically smallest k-combination, in strictly
     *     decreasing order.
     * @throws ArithmeticException if {@code combinationIndex} is negative while {@code k} is
     *     positive (the downward search then ends in a division by zero)
     */
    static int[] getKCombinationAtIndex(int combinationIndex, int k) {
        // Computes the combinationIndex-th lexicographically smallest k-combination.
        // https://en.wikipedia.org/wiki/Combinatorial_number_system
        //
        // A k-combination is a sequence of k non-negative integers in decreasing order.
        // a_k > a_{k-1} > ... > a_2 > a_1 >= 0.
        // k-combinations can be ordered lexicographically, with the smallest
        // k-combination being a_k=k-1, a_{k-1}=k-2, .., a_1=0. Given an index
        // combinationIndex>=0, and an order k, this method returns the
        // combinationIndex-th smallest k-combination.
        //
        // Given an index combinationIndex, the combinationIndex-th k-combination
        // is the unique set of k non-negative integers
        // a_k > a_{k-1} > ... > a_2 > a_1 >= 0
        // such that combinationIndex = \sum_{i=1}^k {a_i}\choose{i}
        //
        // We find this set via a simple greedy algorithm.
        // http://math0.wvstateu.edu/~baker/cs405/code/Combinadics.html
        int[] result = new int[k];
        if (k == 0) {
            return result;
        }
        // To find a_k, iterate candidates upwards from 0 until we've found the
        // maximum a such that (a choose k) <= combinationIndex. Let a_k = a. Use
        // the previous binomial coefficient to compute the next one. Note: possible
        // to speed this up via something other than incremental search.
        int target = combinationIndex;
        int candidate = k - 1;
        int binomialCoefficient = 0;
        // Kept as a long: the first coefficient that exceeds the target may not fit in an int,
        // and the product taken before the division may be larger still.
        long nextBinomialCoefficient = 1;
        while (nextBinomialCoefficient <= target) {
            candidate++;
            // Safe narrowing: the loop condition guarantees the value is <= target.
            binomialCoefficient = (int) nextBinomialCoefficient;
            // (n + 1 choose k) = (n choose k) * (n + 1) / (n + 1 - k)
            nextBinomialCoefficient =
                    (long) binomialCoefficient * (candidate + 1L) / (candidate + 1L - k);
        }
        // We know from the k-combination definition, all subsequent values will be
        // strictly decreasing. Find them all by decrementing candidate.
        // Use the previous binomial coefficient to compute the next one. Both recurrences
        // below only ever shrink the coefficient, so narrowing the quotient back to an int is
        // safe; the product is taken in 64 bits because it may exceed an int.
        int currentK = k;
        int currentIndex = 0;
        while (true) {
            if (binomialCoefficient <= target) {
                result[currentIndex] = candidate;
                currentIndex++;
                target -= binomialCoefficient;
                if (currentIndex == k) {
                    return result;
                }
                // (n - 1 choose k - 1) = (n choose k) * k / n
                binomialCoefficient = (int) ((long) binomialCoefficient * currentK / candidate);
                currentK--;
            } else {
                // (n - 1 choose k) = (n choose k) * (n - k) / n
                binomialCoefficient =
                        (int) ((long) binomialCoefficient * (candidate - currentK) / candidate);
            }
            candidate--;
        }
    }

    /**
     * Returns the number of possible sequences of "stars and bars" sequences
     * https://en.wikipedia.org/wiki/Stars_and_bars_(combinatorics),
     * which is equivalent to (numStars + numBars choose numStars).
     *
     * @param numStars number of stars
     * @param numBars number of bars
     * @return number of possible sequences
     * @throws ArithmeticException if {@code numStars + numBars} or the resulting count
     *     overflows an int
     */
    public static int getNumberOfStarsAndBarsSequences(int numStars, int numBars) {
        // addExact: a silently wrapped sum would yield a plausible-looking but wrong count.
        return getBinomialCoefficient(Math.addExact(numStars, numBars), numStars);
    }

    /**
     * Returns an array of the indices of every star in the stars and bars sequence indexed by
     * {@code sequenceIndex}.
     *
     * @param numStars number of stars in the sequence
     * @param sequenceIndex index of the sequence
     * @return list of indices of every star in stars & bars sequence
     */
    public static int[] getStarIndices(int numStars, int sequenceIndex) {
        return getKCombinationAtIndex(sequenceIndex, numStars);
    }

    /**
     * From an array with the index of every star in a stars and bars sequence, returns an array
     * which, for every star, counts the number of bars preceding it.
     *
     * <p>The conversion is done in place: {@code starIndices} itself is overwritten and
     * returned.
     *
     * @param starIndices indices of the stars in descending order
     * @return count of bars preceding every star (the same array instance that was passed in)
     */
    public static int[] getBarsPrecedingEachStar(int[] starIndices) {
        for (int i = 0; i < starIndices.length; i++) {
            int starIndex = starIndices[i];
            // There are {@code starIndex} prior positions in the sequence. The indices are in
            // descending order, so the stars stored after position `i` are the ones that come
            // earlier in the sequence: {@code starIndices.length - 1 - i} prior stars. The
            // remaining prior positions are bars.
            starIndices[i] = starIndex - (starIndices.length - 1 - i);
        }
        return starIndices;
    }

    /**
     * Compute number of states from the trigger specification
     *
     * @param numBucketIncrements number of bucket increments (equivalent to number of triggers)
     * @param numTriggerData number of trigger data. (equivalent to number of metadata)
     * @param numWindows number of reporting windows
     * @return number of states
     * @throws ArithmeticException in case of int overflow
     */
    public static int getNumStatesArithmetic(
            int numBucketIncrements, int numTriggerData, int numWindows) {
        int numStars = numBucketIncrements;
        int numBars = Math.multiplyExact(numTriggerData, numWindows);
        return getNumberOfStarsAndBarsSequences(numStars, numBars);
    }

    /**
     * Using dynamic programming to compute number of states. Assuming the parameter validation has
     * been checked to avoid overflow or out of memory error
     *
     * @param totalCap total incremental cap
     * @param perTypeNumWindowList reporting window per trigger data
     * @param perTypeCapList cap per trigger data
     * @return number of states
     */
    private static int getNumStatesRecursive(
            int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList) {
        // Start from the last trigger data type and work down to index 0.
        int index = perTypeNumWindowList.length - 1;
        return getNumStatesRecursive(
                totalCap,
                index,
                perTypeNumWindowList[index],
                perTypeCapList[index],
                perTypeNumWindowList,
                perTypeCapList,
                new HashMap<>());
    }

    /**
     * Memoized recursion behind {@link #getNumStatesRecursive(int, int[], int[])}.
     *
     * @param totalCap remaining total cap
     * @param index trigger data type currently being processed
     * @param winVal number of windows still to process for {@code index}
     * @param capVal remaining cap for {@code index}
     * @param perTypeNumWindowList reporting window per trigger data
     * @param perTypeCapList cap per trigger data
     * @param dp memo table keyed by {@code (totalCap, index, winVal, capVal)}; filled in as a
     *     side effect
     * @return number of states reachable from the given position
     * @throws ArithmeticException if a partial sum overflows an int
     */
    private static int getNumStatesRecursive(
            int totalCap,
            int index,
            int winVal,
            int capVal,
            int[] perTypeNumWindowList,
            int[] perTypeCapList,
            Map<List<Integer>, Integer> dp) {
        List<Integer> key = List.of(totalCap, index, winVal, capVal);
        if (!dp.containsKey(key)) {
            if (winVal == 0 && index == 0) {
                // Nothing left to assign: exactly one (empty) continuation.
                dp.put(key, 1);
            } else if (winVal == 0) {
                // Windows of this type are exhausted; move on to the previous type with its
                // full window count and cap.
                dp.put(key, getNumStatesRecursive(
                        totalCap,
                        index - 1,
                        perTypeNumWindowList[index - 1],
                        perTypeCapList[index - 1],
                        perTypeNumWindowList,
                        perTypeCapList,
                        dp));
            } else {
                // Put i increments into the current window, for every i both caps allow.
                int result = 0;
                for (int i = 0; i <= Math.min(totalCap, capVal); i++) {
                    result = Math.addExact(result, getNumStatesRecursive(
                            totalCap - i,
                            index,
                            winVal - 1,
                            capVal - i,
                            perTypeNumWindowList,
                            perTypeCapList,
                            dp));
                }
                dp.put(key, result);
            }
        }
        return dp.get(key);
    }

    /**
     * Compute number of states for flexible event report API
     *
     * @param totalCap number of total increments
     * @param perTypeNumWindowList reporting window for each trigger data
     * @param perTypeCapList limit of the increment of each trigger data
     * @return number of states, or -1 if the input parameters fail validation
     */
    public static int getNumStatesFlexAPI(
            int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList) {
        if (!validateInputReportingPara(totalCap, perTypeNumWindowList, perTypeCapList)) {
            LogUtil.e("Input parameters are out of range");
            return -1;
        }
        // The closed-form count applies only when every trigger data type has the same number
        // of windows and no per-type cap is tighter than the total cap.
        boolean canComputeArithmetic = true;
        for (int i = 1; i < perTypeNumWindowList.length; i++) {
            if (perTypeNumWindowList[i] != perTypeNumWindowList[i - 1]) {
                canComputeArithmetic = false;
                break;
            }
        }
        for (int n : perTypeCapList) {
            if (n < totalCap) {
                canComputeArithmetic = false;
                break;
            }
        }
        if (canComputeArithmetic) {
            return getNumStatesArithmetic(totalCap, perTypeCapList.length, perTypeNumWindowList[0]);
        }

        return getNumStatesRecursive(totalCap, perTypeNumWindowList, perTypeCapList);
    }

    /**
     * Checks the flexible event reporting parameters against the limits in
     * {@link PrivacyParams}.
     *
     * @return false if the arrays are empty or differ in length, if any window count exceeds
     *     the maximum, if {@code min(totalCap, sum of caps)} exceeds the maximum number of
     *     reports, or if there are more trigger data types than allowed; true otherwise
     */
    private static boolean validateInputReportingPara(
            int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList) {
        // Both arrays are indexed by trigger data type; the counting code reads element 0 and
        // walks them in lockstep, so empty or mismatched arrays cannot be processed.
        if (perTypeNumWindowList.length == 0
                || perTypeNumWindowList.length != perTypeCapList.length) {
            return false;
        }
        for (int n : perTypeNumWindowList) {
            if (n > PrivacyParams.getMaxFlexibleEventReportingWindows()) {
                return false;
            }
        }
        // Sum in 64 bits so that large caps cannot wrap around to a small or negative total.
        long capSum = Arrays.stream(perTypeCapList).asLongStream().sum();
        return PrivacyParams.getMaxFlexibleEventReports() >= Math.min(totalCap, capSum)
                && perTypeNumWindowList.length
                        <= PrivacyParams.getMaxFlexibleEventTriggerDataCardinality();
    }

    /**
     * @param numOfStates Number of States
     * @return the probability to use fake reports
     */
    public static double getFlipProbability(int numOfStates) {
        int epsilon = PrivacyParams.getPrivacyEpsilon();
        return numOfStates / (numOfStates + Math.exp(epsilon) - 1);
    }

    /**
     * Binary entropy {@code -x * log2(x) - (1 - x) * log2(1 - x)}, with values of {@code x}
     * within {@link PrivacyParams#NUMBER_EQUAL_THRESHOLD} of 0 or 1 mapped to 0.
     */
    private static double getBinaryEntropy(double x) {
        if (DoubleMath.fuzzyEquals(x, 0.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD)
                || DoubleMath.fuzzyEquals(x, 1.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD)) {
            return 0;
        }
        return (-1.0) * x * DoubleMath.log2(x) - (1 - x) * DoubleMath.log2(1 - x);
    }

    /**
     * @param numOfStates Number of States
     * @param flipProbability Flip Probability
     * @return the information gain; 0 when there is exactly one state
     */
    public static double getInformationGain(int numOfStates, double flipProbability) {
        if (numOfStates == 1) {
            // With a single state the formula below evaluates 0 * log2(0) = 0 * -Infinity,
            // which is NaN. A single possible output carries no information.
            return 0.0d;
        }
        double log2Q = DoubleMath.log2(numOfStates);
        double fakeProbability = flipProbability * (numOfStates - 1) / numOfStates;
        return log2Q
                - getBinaryEntropy(fakeProbability)
                - fakeProbability * DoubleMath.log2(numOfStates - 1);
    }

    /**
     * Generate fake report set given a report specification and the rank order number
     *
     * @param totalCap total_cap
     * @param perTypeNumWindowList per type number of window list
     * @param perTypeCapList per type cap list
     * @param rank rank order number of the report set to generate
     * @param dp memo table for the number-of-states computation; entries are added to it as a
     *     side effect
     * @return a report set based on the input rank
     */
    public static List<AtomReportState> getReportSetBasedOnRank(
            int totalCap,
            int[] perTypeNumWindowList,
            int[] perTypeCapList,
            int rank,
            Map<List<Integer>, Integer> dp) {
        int triggerTypeIndex = perTypeNumWindowList.length - 1;

        return getReportSetBasedOnRankRecursive(
                totalCap,
                triggerTypeIndex,
                perTypeNumWindowList[triggerTypeIndex],
                perTypeCapList[triggerTypeIndex],
                rank,
                perTypeNumWindowList,
                perTypeCapList,
                dp);
    }

    /**
     * Walks the same decision tree as the number-of-states recursion, using the state counts
     * of each branch to locate the branch that contains {@code rank}.
     */
    private static List<AtomReportState> getReportSetBasedOnRankRecursive(
            int totalCap,
            int triggerTypeIndex,
            int winVal,
            int capVal,
            int rank,
            int[] perTypeNumWindowList,
            int[] perTypeCapList,
            Map<List<Integer>, Integer> numStatesLookupTable) {

        if (winVal == 0 && triggerTypeIndex == 0) {
            return new ArrayList<>();
        } else if (winVal == 0) {
            // Windows of this type are exhausted; continue with the previous type.
            return getReportSetBasedOnRankRecursive(
                    totalCap,
                    triggerTypeIndex - 1,
                    perTypeNumWindowList[triggerTypeIndex - 1],
                    perTypeCapList[triggerTypeIndex - 1],
                    rank,
                    perTypeNumWindowList,
                    perTypeCapList,
                    numStatesLookupTable);
        }
        for (int i = 0; i <= Math.min(totalCap, capVal); i++) {
            int currentNumStates =
                    getNumStatesRecursive(
                            totalCap - i,
                            triggerTypeIndex,
                            winVal - 1,
                            capVal - i,
                            perTypeNumWindowList,
                            perTypeCapList,
                            numStatesLookupTable);
            if (currentNumStates > rank) {
                // The rank falls inside the branch with i reports in this window.
                // The triggers to be appended.
                List<AtomReportState> toAppend = new ArrayList<>();
                for (int k = 0; k < i; k++) {
                    toAppend.add(new AtomReportState(triggerTypeIndex, winVal - 1));
                }
                List<AtomReportState> otherReports =
                        getReportSetBasedOnRankRecursive(
                                totalCap - i,
                                triggerTypeIndex,
                                winVal - 1,
                                capVal - i,
                                rank,
                                perTypeNumWindowList,
                                perTypeCapList,
                                numStatesLookupTable);
                toAppend.addAll(otherReports);
                return toAppend;
            } else {
                // Skip this whole branch and keep the rank relative to the next one.
                rank -= currentNumStates;
            }
        }
        // will not reach here
        return new ArrayList<>();
    }

    /** A single report including triggerDataType and window index for the fake report generation */
    public static class AtomReportState {
        private final int mTriggerDataType;
        private final int mWindowIndex;

        public AtomReportState(int triggerDataType, int windowIndex) {
            this.mTriggerDataType = triggerDataType;
            this.mWindowIndex = windowIndex;
        }

        public int getTriggerDataType() {
            return mTriggerDataType;
        }

        public final int getWindowIndex() {
            return mWindowIndex;
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof AtomReportState)) {
                return false;
            }
            AtomReportState t = (AtomReportState) obj;
            return mTriggerDataType == t.mTriggerDataType && mWindowIndex == t.mWindowIndex;
        }

        @Override
        public int hashCode() {
            return Objects.hash(mWindowIndex, mTriggerDataType);
        }
    }
}