• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.adservices.service.measurement.noising;
18 
19 import com.android.adservices.LogUtil;
20 import com.android.adservices.service.measurement.PrivacyParams;
21 
22 import com.google.common.math.DoubleMath;
23 
24 import java.util.ArrayList;
25 import java.util.Arrays;
26 import java.util.HashMap;
27 import java.util.List;
28 import java.util.Map;
29 import java.util.Objects;
30 
31 /**
32  * Combinatorics utilities used for randomization.
33  */
34 public class Combinatorics {
35 
36     /**
37      * Computes the binomial coefficient aka {@code n} choose {@code k}.
38      * https://en.wikipedia.org/wiki/Binomial_coefficient
39      *
40      * @return binomial coefficient for (n choose k)
41      * @throws ArithmeticException if the result overflows an int
42      */
getBinomialCoefficient(int n, int k)43     static int getBinomialCoefficient(int n, int k) {
44         if (k > n) {
45             return 0;
46         }
47         if (k == n || n == 0) {
48             return 1;
49         }
50         // getBinomialCoefficient(n, k) == getBinomialCoefficient(n, n - k).
51         if (k > n - k) {
52             k = n - k;
53         }
54         // (n choose k) = n (n -1) ... (n - (k - 1)) / k!
55         // = mul((n + 1 - i) / i), i from 1 -> k.
56         //
57         // You might be surprised that this algorithm works just fine with integer
58         // division (i.e. division occurs cleanly with no remainder). However, this is
59         // true for a very simple reason. Imagine a value of `i` causes division with
60         // remainder in the below algorithm. This immediately implies that
61         // (n choose i) is fractional, which we know is not the case.
62         int result = 1;
63         for (int i = 1; i <= k; i++) {
64             result = Math.multiplyExact(result, n + 1 - i);
65             result = result / i;
66         }
67         return result;
68     }
69 
70     /**
71      * Returns the k-combination associated with the number {@code combinationIndex}. In
72      * other words, returns the combination of {@code k} integers uniquely indexed by
73      * {@code combinationIndex} in the combinatorial number system.
74      * https://en.wikipedia.org/wiki/Combinatorial_number_system
75      *
76      * @return combinationIndex-th lexicographically smallest k-combination.
77      * @throws ArithmeticException in case of int overflow
78      */
getKCombinationAtIndex(int combinationIndex, int k)79     static int[] getKCombinationAtIndex(int combinationIndex, int k) {
80         // Computes the combinationIndex-th lexicographically smallest k-combination.
81         // https://en.wikipedia.org/wiki/Combinatorial_number_system
82         //
83         // A k-combination is a sequence of k non-negative integers in decreasing order.
84         // a_k > a_{k-1} > ... > a_2 > a_1 >= 0.
85         // k-combinations can be ordered lexicographically, with the smallest
86         // k-combination being a_k=k-1, a_{k-1}=k-2, .., a_1=0. Given an index
87         // combinationIndex>=0, and an order k, this method returns the
88         // combinationIndex-th smallest k-combination.
89         //
90         // Given an index combinationIndex, the combinationIndex-th k-combination
91         // is the unique set of k non-negative integers
92         // a_k > a_{k-1} > ... > a_2 > a_1 >= 0
93         // such that combinationIndex = \sum_{i=1}^k {a_i}\choose{i}
94         //
95         // We find this set via a simple greedy algorithm.
96         // http://math0.wvstateu.edu/~baker/cs405/code/Combinadics.html
97         int[] result = new int[k];
98         if (k == 0) {
99             return result;
100         }
101         // To find a_k, iterate candidates upwards from 0 until we've found the
102         // maximum a such that (a choose k) <= combinationIndex. Let a_k = a. Use
103         // the previous binomial coefficient to compute the next one. Note: possible
104         // to speed this up via something other than incremental search.
105         int target = combinationIndex;
106         int candidate = k - 1;
107         int binomialCoefficient = 0;
108         int nextBinomialCoefficient = 1;
109         while (nextBinomialCoefficient <= target) {
110             candidate++;
111             binomialCoefficient = nextBinomialCoefficient;
112             // (n + 1 choose k) = (n choose k) * (n + 1) / (n + 1 - k)
113             nextBinomialCoefficient = Math.multiplyExact(binomialCoefficient, (candidate + 1));
114             nextBinomialCoefficient /= candidate + 1 - k;
115         }
116         // We know from the k-combination definition, all subsequent values will be
117         // strictly decreasing. Find them all by decrementing candidate.
118         // Use the previous binomial coefficient to compute the next one.
119         int currentK = k;
120         int currentIndex = 0;
121         while (true) {
122             if (binomialCoefficient <= target) {
123                 result[currentIndex] = candidate;
124                 currentIndex++;
125                 target -= binomialCoefficient;
126                 if (currentIndex == k) {
127                     return result;
128                 }
129                 // (n - 1 choose k - 1) = (n choose k) * k / n
130                 binomialCoefficient = binomialCoefficient * currentK / candidate;
131                 currentK--;
132             } else {
133                 // (n - 1 choose k) = (n choose k) * (n - k) / n
134                 binomialCoefficient = binomialCoefficient * (candidate - currentK) / candidate;
135             }
136             candidate--;
137         }
138     }
139 
140     /**
141      * Returns the number of possible sequences of "stars and bars" sequences
142      * https://en.wikipedia.org/wiki/Stars_and_bars_(combinatorics),
143      * which is equivalent to (numStars + numBars choose numStars).
144      *
145      * @param numStars number of stars
146      * @param numBars  number of bars
147      * @return number of possible sequences
148      */
getNumberOfStarsAndBarsSequences(int numStars, int numBars)149     public static int getNumberOfStarsAndBarsSequences(int numStars, int numBars) {
150         return getBinomialCoefficient(numStars + numBars, numStars);
151     }
152 
153     /**
154      * Returns an array of the indices of every star in the stars and bars sequence indexed by
155      * {@code sequenceIndex}.
156      *
157      * @param numStars number of stars in the sequence
158      * @param sequenceIndex index of the sequence
159      * @return list of indices of every star in stars & bars sequence
160      */
getStarIndices(int numStars, int sequenceIndex)161     public static int[] getStarIndices(int numStars, int sequenceIndex) {
162         return getKCombinationAtIndex(sequenceIndex, numStars);
163     }
164 
165     /**
166      * From an array with the index of every star in a stars and bars sequence, returns an array
167      * which, for every star, counts the number of bars preceding it.
168      *
169      * @param starIndices indices of the stars in descending order
170      * @return count of bars preceding every star
171      */
getBarsPrecedingEachStar(int[] starIndices)172     public static int[] getBarsPrecedingEachStar(int[] starIndices) {
173         for (int i = 0; i < starIndices.length; i++) {
174             int starIndex = starIndices[i];
175             // There are {@code starIndex} prior positions in the sequence, and `i` prior
176             // stars, so there are {@code starIndex - i} prior bars.
177             starIndices[i] = starIndex - (starIndices.length - 1 - i);
178         }
179         return starIndices;
180     }
181 
182     /**
183      * Compute number of states from the trigger specification
184      *
185      * @param numBucketIncrements number of bucket increments (equivalent to number of triggers)
186      * @param numTriggerData number of trigger data. (equivalent to number of metadata)
187      * @param numWindows number of reporting windows
188      * @return number of states
189      */
getNumStatesArithmetic( int numBucketIncrements, int numTriggerData, int numWindows)190     public static int getNumStatesArithmetic(
191             int numBucketIncrements, int numTriggerData, int numWindows) {
192         int numStars = numBucketIncrements;
193         int numBars = Math.multiplyExact(numTriggerData, numWindows);
194         return getNumberOfStarsAndBarsSequences(numStars, numBars);
195     }
196 
197     /**
198      * Using dynamic programming to compute number of states. Assuming the parameter validation has
199      * been checked to avoid overflow or out of memory error
200      *
201      * @param totalCap total incremental cap
202      * @param perTypeNumWindowList reporting window per trigger data
203      * @param perTypeCapList cap per trigger data
204      * @return number of states
205      */
getNumStatesRecursive( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList)206     private static int getNumStatesRecursive(
207             int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList) {
208         int index = perTypeNumWindowList.length - 1;
209         return getNumStatesRecursive(
210                 totalCap,
211                 index,
212                 perTypeNumWindowList[index],
213                 perTypeCapList[index],
214                 perTypeNumWindowList,
215                 perTypeCapList,
216                 new HashMap<>());
217     }
218 
getNumStatesRecursive( int totalCap, int index, int winVal, int capVal, int[] perTypeNumWindowList, int[] perTypeCapList, Map<List<Integer>, Integer> dp)219     private static int getNumStatesRecursive(
220             int totalCap,
221             int index,
222             int winVal,
223             int capVal,
224             int[] perTypeNumWindowList,
225             int[] perTypeCapList,
226             Map<List<Integer>, Integer> dp) {
227         List<Integer> key = List.of(totalCap, index, winVal, capVal);
228         if (!dp.containsKey(key)) {
229             if (winVal == 0 && index == 0) {
230                 dp.put(key, 1);
231             } else if (winVal == 0) {
232                 dp.put(key, getNumStatesRecursive(
233                         totalCap,
234                         index - 1,
235                         perTypeNumWindowList[index - 1],
236                         perTypeCapList[index - 1],
237                         perTypeNumWindowList,
238                         perTypeCapList,
239                         dp));
240             } else {
241                 int result = 0;
242                 for (int i = 0; i <= Math.min(totalCap, capVal); i++) {
243                     result = Math.addExact(result, getNumStatesRecursive(
244                             totalCap - i,
245                             index,
246                             winVal - 1,
247                             capVal - i,
248                             perTypeNumWindowList,
249                             perTypeCapList,
250                             dp));
251                 }
252                 dp.put(key, result);
253             }
254         }
255         return dp.get(key);
256     }
257 
258     /**
259      * Compute number of states for flexible event report API
260      *
261      * @param totalCap number of total increments
262      * @param perTypeNumWindowList reporting window for each trigger data
263      * @param perTypeCapList limit of the increment of each trigger data
264      * @return number of states
265      */
getNumStatesFlexAPI( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList)266     public static int getNumStatesFlexAPI(
267             int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList) {
268         if (!validateInputReportingPara(totalCap, perTypeNumWindowList, perTypeCapList)) {
269             LogUtil.e("Input parameters are out of range");
270             return -1;
271         }
272         boolean canComputeArithmetic = true;
273         for (int i = 1; i < perTypeNumWindowList.length; i++) {
274             if (perTypeNumWindowList[i] != perTypeNumWindowList[i - 1]) {
275                 canComputeArithmetic = false;
276                 break;
277             }
278         }
279         for (int n : perTypeCapList) {
280             if (n < totalCap) {
281                 canComputeArithmetic = false;
282                 break;
283             }
284         }
285         if (canComputeArithmetic) {
286             return getNumStatesArithmetic(totalCap, perTypeCapList.length, perTypeNumWindowList[0]);
287         }
288 
289         return getNumStatesRecursive(totalCap, perTypeNumWindowList, perTypeCapList);
290     }
291 
validateInputReportingPara( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList)292     private static boolean validateInputReportingPara(
293             int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList) {
294         for (int n : perTypeNumWindowList) {
295             if (n > PrivacyParams.getMaxFlexibleEventReportingWindows()) return false;
296         }
297         return PrivacyParams.getMaxFlexibleEventReports()
298                         >= Math.min(totalCap, Arrays.stream(perTypeCapList).sum())
299                 && perTypeNumWindowList.length
300                         <= PrivacyParams.getMaxFlexibleEventTriggerDataCardinality();
301     }
302 
303     /**
304      * @param numOfStates Number of States
305      * @return the probability to use fake reports
306      */
getFlipProbability(int numOfStates)307     public static double getFlipProbability(int numOfStates) {
308         int epsilon = PrivacyParams.getPrivacyEpsilon();
309         return numOfStates / (numOfStates + Math.exp(epsilon) - 1);
310     }
311 
getBinaryEntropy(double x)312     private static double getBinaryEntropy(double x) {
313         if (DoubleMath.fuzzyEquals(x, 0.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD)
314                 || DoubleMath.fuzzyEquals(x, 1.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD)) {
315             return 0;
316         }
317         return (-1.0) * x * DoubleMath.log2(x) - (1 - x) * DoubleMath.log2(1 - x);
318     }
319 
320     /**
321      * @param numOfStates Number of States
322      * @param flipProbability Flip Probability
323      * @return the information gain
324      */
getInformationGain(int numOfStates, double flipProbability)325     public static double getInformationGain(int numOfStates, double flipProbability) {
326         double log2Q = DoubleMath.log2(numOfStates);
327         double fakeProbability = flipProbability * (numOfStates - 1) / numOfStates;
328         return log2Q
329                 - getBinaryEntropy(fakeProbability)
330                 - fakeProbability * DoubleMath.log2(numOfStates - 1);
331     }
332 
333     /**
334      * Generate fake report set given a report specification and the rank order number
335      *
336      * @param totalCap total_cap
337      * @param perTypeNumWindowList per type number of window list
338      * @param perTypeCapList per type cap list
339      * @return a report set based on the input rank
340      */
getReportSetBasedOnRank( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, int rank, Map<List<Integer>, Integer> dp)341     public static List<AtomReportState> getReportSetBasedOnRank(
342             int totalCap,
343             int[] perTypeNumWindowList,
344             int[] perTypeCapList,
345             int rank,
346             Map<List<Integer>, Integer> dp) {
347         int triggerTypeIndex = perTypeNumWindowList.length - 1;
348 
349         return getReportSetBasedOnRankRecursive(
350                 totalCap,
351                 triggerTypeIndex,
352                 perTypeNumWindowList[triggerTypeIndex],
353                 perTypeCapList[triggerTypeIndex],
354                 rank,
355                 perTypeNumWindowList,
356                 perTypeCapList,
357                 dp);
358     }
359 
getReportSetBasedOnRankRecursive( int totalCap, int triggerTypeIndex, int winVal, int capVal, int rank, int[] perTypeNumWindowList, int[] perTypeCapList, Map<List<Integer>, Integer> numStatesLookupTable)360     private static List<AtomReportState> getReportSetBasedOnRankRecursive(
361             int totalCap,
362             int triggerTypeIndex,
363             int winVal,
364             int capVal,
365             int rank,
366             int[] perTypeNumWindowList,
367             int[] perTypeCapList,
368             Map<List<Integer>, Integer> numStatesLookupTable) {
369 
370         if (winVal == 0 && triggerTypeIndex == 0) {
371             return new ArrayList<>();
372         } else if (winVal == 0) {
373             return getReportSetBasedOnRankRecursive(
374                     totalCap,
375                     triggerTypeIndex - 1,
376                     perTypeNumWindowList[triggerTypeIndex - 1],
377                     perTypeCapList[triggerTypeIndex - 1],
378                     rank,
379                     perTypeNumWindowList,
380                     perTypeCapList,
381                     numStatesLookupTable);
382         }
383         for (int i = 0; i <= Math.min(totalCap, capVal); i++) {
384             int currentNumStates =
385                     getNumStatesRecursive(
386                             totalCap - i,
387                             triggerTypeIndex,
388                             winVal - 1,
389                             capVal - i,
390                             perTypeNumWindowList,
391                             perTypeCapList,
392                             numStatesLookupTable);
393             if (currentNumStates > rank) {
394                 // The triggers to be appended.
395                 List<AtomReportState> toAppend = new ArrayList<>();
396                 for (int k = 0; k < i; k++) {
397                     toAppend.add(new AtomReportState(triggerTypeIndex, winVal - 1));
398                 }
399                 List<AtomReportState> otherReports =
400                         getReportSetBasedOnRankRecursive(
401                                 totalCap - i,
402                                 triggerTypeIndex,
403                                 winVal - 1,
404                                 capVal - i,
405                                 rank,
406                                 perTypeNumWindowList,
407                                 perTypeCapList,
408                                 numStatesLookupTable);
409                 toAppend.addAll(otherReports);
410                 return toAppend;
411             } else {
412                 rank -= currentNumStates;
413             }
414         }
415         // will not reach here
416         return new ArrayList<>();
417     }
418 
419     /** A single report including triggerDataType and window index for the fake report generation */
420     public static class AtomReportState {
421         private final int mTriggerDataType;
422         private final int mWindowIndex;
423 
AtomReportState(int triggerDataType, int windowIndex)424         public AtomReportState(int triggerDataType, int windowIndex) {
425             this.mTriggerDataType = triggerDataType;
426             this.mWindowIndex = windowIndex;
427         }
428 
getTriggerDataType()429         public int getTriggerDataType() {
430             return mTriggerDataType;
431         }
432         ;
433 
getWindowIndex()434         public final int getWindowIndex() {
435             return mWindowIndex;
436         }
437         ;
438 
439         @Override
equals(Object obj)440         public boolean equals(Object obj) {
441             if (!(obj instanceof AtomReportState)) {
442                 return false;
443             }
444             AtomReportState t = (AtomReportState) obj;
445             return mTriggerDataType == t.mTriggerDataType && mWindowIndex == t.mWindowIndex;
446         }
447 
448         @Override
hashCode()449         public int hashCode() {
450             return Objects.hash(mWindowIndex, mTriggerDataType);
451         }
452     }
453 }
454