• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.adservices.service.measurement.noising;
18 
19 import com.android.adservices.service.measurement.PrivacyParams;
20 
21 import com.google.common.math.DoubleMath;
22 import com.google.common.math.LongMath;
23 
24 import java.math.BigInteger;
25 import java.util.ArrayList;
26 import java.util.Arrays;
27 import java.util.List;
28 import java.util.Objects;
29 
30 /**
31  * Combinatorics utilities used for randomization.
32  */
33 public class Combinatorics {
34 
35     /**
36      * Returns the k-combination associated with the number {@code combinationIndex}. In
37      * other words, returns the combination of {@code k} integers uniquely indexed by
38      * {@code combinationIndex} in the combinatorial number system.
39      * https://en.wikipedia.org/wiki/Combinatorial_number_system
40      *
41      * @return combinationIndex-th lexicographically smallest k-combination.
42      * @throws ArithmeticException in case of int overflow
43      */
getKCombinationAtIndex(long combinationIndex, int k)44     static long[] getKCombinationAtIndex(long combinationIndex, int k) {
45         // Computes the combinationIndex-th lexicographically smallest k-combination.
46         // https://en.wikipedia.org/wiki/Combinatorial_number_system
47         //
48         // A k-combination is a sequence of k non-negative integers in decreasing order.
49         // a_k > a_{k-1} > ... > a_2 > a_1 >= 0.
50         // k-combinations can be ordered lexicographically, with the smallest
51         // k-combination being a_k=k-1, a_{k-1}=k-2, .., a_1=0. Given an index
52         // combinationIndex>=0, and an order k, this method returns the
53         // combinationIndex-th smallest k-combination.
54         //
55         // Given an index combinationIndex, the combinationIndex-th k-combination
56         // is the unique set of k non-negative integers
57         // a_k > a_{k-1} > ... > a_2 > a_1 >= 0
58         // such that combinationIndex = \sum_{i=1}^k {a_i}\choose{i}
59         //
60         // We find this set via a simple greedy algorithm.
61         // http://math0.wvstateu.edu/~baker/cs405/code/Combinadics.html
62         long[] result = new long[k];
63         if (k == 0) {
64             return result;
65         }
66         // To find a_k, iterate candidates upwards from 0 until we've found the
67         // maximum a such that (a choose k) <= combinationIndex. Let a_k = a. Use
68         // the previous binomial coefficient to compute the next one. Note: possible
69         // to speed this up via something other than incremental search.
70         long target = combinationIndex;
71         long candidate = (long) k - 1L;
72         long binomialCoefficient = 0L;
73         long nextBinomialCoefficient = 1L;
74         while (nextBinomialCoefficient <= target) {
75             candidate++;
76             binomialCoefficient = nextBinomialCoefficient;
77             // (n + 1 choose k) = (n choose k) * (n + 1) / (n + 1 - k)
78             nextBinomialCoefficient = Math.multiplyExact(binomialCoefficient, (candidate + 1));
79             nextBinomialCoefficient /= candidate + 1 - k;
80         }
81         // We know from the k-combination definition, all subsequent values will be
82         // strictly decreasing. Find them all by decrementing candidate.
83         // Use the previous binomial coefficient to compute the next one.
84         long currentK = (long) k;
85         int currentIndex = 0;
86         while (true) {
87             if (binomialCoefficient <= target) {
88                 result[currentIndex] = candidate;
89                 currentIndex++;
90                 target -= binomialCoefficient;
91                 if (currentIndex == k) {
92                     return result;
93                 }
94                 // (n - 1 choose k - 1) = (n choose k) * k / n
95                 binomialCoefficient = binomialCoefficient * currentK / candidate;
96                 currentK--;
97             } else {
98                 // (n - 1 choose k) = (n choose k) * (n - k) / n
99                 binomialCoefficient = binomialCoefficient * (candidate - currentK) / candidate;
100             }
101             candidate--;
102         }
103     }
104 
105     /**
106      * Returns the number of possible sequences of "stars and bars" sequences
107      * https://en.wikipedia.org/wiki/Stars_and_bars_(combinatorics),
108      * which is equivalent to (numStars + numBars choose numStars).
109      *
110      * @param numStars number of stars
111      * @param numBars  number of bars
112      * @return number of possible sequences
113      */
getNumberOfStarsAndBarsSequences(int numStars, int numBars)114     public static long getNumberOfStarsAndBarsSequences(int numStars, int numBars) {
115         // Note, LongMath::binomial returns Long.MAX_VALUE rather than overflow.
116         return LongMath.binomial(numStars + numBars, numStars);
117     }
118 
119     /**
120      * Returns an array of the indices of every star in the stars and bars sequence indexed by
121      * {@code sequenceIndex}.
122      *
123      * @param numStars number of stars in the sequence
124      * @param sequenceIndex index of the sequence
125      * @return list of indices of every star in stars & bars sequence
126      */
getStarIndices(int numStars, long sequenceIndex)127     public static long[] getStarIndices(int numStars, long sequenceIndex) {
128         return getKCombinationAtIndex(sequenceIndex, numStars);
129     }
130 
131     /**
132      * From an array with the index of every star in a stars and bars sequence, returns an array
133      * which, for every star, counts the number of bars preceding it.
134      *
135      * @param starIndices indices of the stars in descending order
136      * @return count of bars preceding every star
137      */
getBarsPrecedingEachStar(long[] starIndices)138     public static long[] getBarsPrecedingEachStar(long[] starIndices) {
139         for (int i = 0; i < starIndices.length; i++) {
140             long starIndex = starIndices[i];
141             // There are {@code starIndex} prior positions in the sequence, and `i` prior
142             // stars, so there are {@code starIndex - i} prior bars.
143             starIndices[i] = starIndex - ((long) starIndices.length - 1L - (long) i);
144         }
145         return starIndices;
146     }
147 
148     /**
149      * Compute number of states from the trigger specification
150      *
151      * @param numBucketIncrements number of bucket increments (equivalent to number of triggers)
152      * @param numTriggerData number of trigger data. (equivalent to number of metadata)
153      * @param numWindows number of reporting windows
154      * @return number of states
155      */
getNumStatesArithmetic( int numBucketIncrements, int numTriggerData, int numWindows)156     public static long getNumStatesArithmetic(
157             int numBucketIncrements, int numTriggerData, int numWindows) {
158         int numStars = numBucketIncrements;
159         int numBars = Math.multiplyExact(numTriggerData, numWindows);
160         return getNumberOfStarsAndBarsSequences(numStars, numBars);
161     }
162 
163     /**
164      * Using dynamic programming to compute number of states. Returns Long.MAX_VALUE if the result
165      * is greater than {@code bound}.
166      *
167      * @param totalCap total incremental cap
168      * @param perTypeNumWindowList reporting window per trigger data
169      * @param perTypeCapList cap per trigger data
170      * @param bound the highest state count allowed
171      * @return number of states
172      * @throws ArithmeticException in case of long overflow
173      */
getNumStatesIterative( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound)174     private static long getNumStatesIterative(
175             int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound) {
176         // Assumes perTypeCapList cannot sum to more than int value. Overflowing int here can lead
177         // to an exception when declaring the array size later, based on the min value.
178         int sum = 0;
179         for (int cap : perTypeCapList) {
180             sum += cap;
181         }
182         int leastTotalCap = Math.min(totalCap, sum);
183         long[][] dp = new long[2][leastTotalCap + 1];
184         int prev = 0;
185         int curr = 1;
186 
187         dp[prev][0] = 1L;
188         long result = 0L;
189 
190         for (int i = 0; i < perTypeNumWindowList.length && perTypeNumWindowList[i] > 0; i++) {
191             int winCount = perTypeNumWindowList[i];
192             int capCount = perTypeCapList[i];
193             result = 0L;
194 
195             for (int cap = 0; cap < leastTotalCap + 1; cap++) {
196                 dp[curr][cap] = 0L;
197 
198                 for (int capVal = 0; capVal < Math.min(cap, capCount) + 1; capVal++) {
199                     dp[curr][cap] = Math.addExact(
200                             dp[curr][cap],
201                             Math.multiplyExact(
202                                     dp[prev][cap - capVal],
203                                     getNumberOfStarsAndBarsSequences(capVal, winCount - 1)));
204                 }
205 
206                 result = Math.addExact(result, dp[curr][cap]);
207 
208                 if (result > bound) {
209                     return Long.MAX_VALUE;
210                 }
211             }
212 
213             curr ^= 1;
214             prev ^= 1;
215         }
216 
217         return Math.max(result, 1L);
218     }
219 
220     /**
221      * Compute number of states for flexible event report API. Returns Long.MAX_VALUE if the result
222      * exceeds {@code bound}.
223      *
224      * @param totalCap number of total increments
225      * @param perTypeNumWindowList reporting window for each trigger data
226      * @param perTypeCapList limit of the increment of each trigger data
227      * @param bound the highest state count allowed
228      * @return number of states
229      * @throws ArithmeticException in case of long overflow during the iterative procedure
230      */
getNumStatesFlexApi( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound)231     public static long getNumStatesFlexApi(
232             int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound) {
233         if (perTypeNumWindowList.length == 0 || perTypeCapList.length == 0) {
234             return 1;
235         }
236         for (int i = 1; i < perTypeNumWindowList.length; i++) {
237             if (perTypeNumWindowList[i] != perTypeNumWindowList[i - 1]) {
238                 return getNumStatesIterative(totalCap, perTypeNumWindowList, perTypeCapList, bound);
239             }
240         }
241         for (int n : perTypeCapList) {
242             if (n < totalCap) {
243                 return getNumStatesIterative(totalCap, perTypeNumWindowList, perTypeCapList, bound);
244             }
245         }
246 
247         long result = getNumStatesArithmetic(
248                 totalCap, perTypeCapList.length, perTypeNumWindowList[0]);
249 
250         return result > bound ? Long.MAX_VALUE : result;
251     }
252 
253     /**
254      * @param numOfStates Number of States
255      * @return the probability to use fake reports
256      */
getFlipProbability(long numOfStates, double privacyEpsilon)257     public static double getFlipProbability(long numOfStates, double privacyEpsilon) {
258         return (numOfStates) / (numOfStates + Math.exp(privacyEpsilon) - 1D);
259     }
260 
getBinaryEntropy(double x)261     private static double getBinaryEntropy(double x) {
262         if (DoubleMath.fuzzyEquals(x, 0.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD)
263                 || DoubleMath.fuzzyEquals(x, 1.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD)) {
264             return 0.0D;
265         }
266         return (-1.0D) * x * DoubleMath.log2(x) - (1 - x) * DoubleMath.log2(1 - x);
267     }
268 
269     /**
270      * @param numOfStates Number of States
271      * @param flipProbability Flip Probability
272      * @return the information gain
273      */
getInformationGain(long numOfStates, double flipProbability)274     public static double getInformationGain(long numOfStates, double flipProbability) {
275         if (numOfStates <= 1L) {
276             return 0d;
277         }
278         double log2Q = DoubleMath.log2(numOfStates);
279         double fakeProbability = flipProbability * (numOfStates - 1L) / numOfStates;
280         return log2Q
281                 - getBinaryEntropy(fakeProbability)
282                 - fakeProbability * DoubleMath.log2(numOfStates - 1);
283     }
284 
285     /**
286      * Returns the max information gain given the num of trigger states, attribution scope limit and
287      * max num event states.
288      *
289      * @param numTriggerStates The number of trigger states.
290      * @param attributionScopeLimit The attribution scope limit.
291      * @param maxEventStates The maximum number of event states (expected to be positive).
292      * @return The max information gain.
293      */
getMaxInformationGainWithAttributionScope( long numTriggerStates, long attributionScopeLimit, long maxEventStates)294     public static double getMaxInformationGainWithAttributionScope(
295             long numTriggerStates, long attributionScopeLimit, long maxEventStates) {
296         if (numTriggerStates <= 0 || maxEventStates <= 0) {
297             throw new IllegalArgumentException(
298                     "numTriggerStates and maxEventStates must be positive");
299         }
300         BigInteger totalNumStates =
301                 BigInteger.valueOf(numTriggerStates)
302                         .add(
303                                 BigInteger.valueOf(maxEventStates)
304                                         .multiply(BigInteger.valueOf(attributionScopeLimit - 1)));
305         return DoubleMath.log2(totalNumStates.doubleValue());
306     }
307 
308     /**
309      * Generate fake report set given a trigger specification and the rank order number
310      *
311      * @param totalCap total_cap
312      * @param perTypeNumWindowList per type number of window list
313      * @param perTypeCapList per type cap list
314      * @param rank the rank of the report state within all the report states
315      * @return a report set based on the input rank
316      */
getReportSetBasedOnRank( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long rank)317     public static List<AtomReportState> getReportSetBasedOnRank(
318             int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long rank) {
319         int triggerTypeIndex = perTypeNumWindowList.length - 1;
320         return getReportSetBasedOnRankRecursive(
321                 totalCap,
322                 triggerTypeIndex,
323                 perTypeNumWindowList[triggerTypeIndex],
324                 perTypeCapList[triggerTypeIndex],
325                 rank,
326                 perTypeNumWindowList,
327                 perTypeCapList);
328     }
329 
getReportSetBasedOnRankRecursive( int totalCap, int triggerTypeIndex, int winVal, int capVal, long rank, int[] perTypeNumWindowList, int[] perTypeCapList)330     private static List<AtomReportState> getReportSetBasedOnRankRecursive(
331             int totalCap,
332             int triggerTypeIndex,
333             int winVal,
334             int capVal,
335             long rank,
336             int[] perTypeNumWindowList,
337             int[] perTypeCapList) {
338 
339         if (winVal == 0 && triggerTypeIndex == 0) {
340             return new ArrayList<>();
341         } else if (winVal == 0) {
342             return getReportSetBasedOnRankRecursive(
343                     totalCap,
344                     triggerTypeIndex - 1,
345                     perTypeNumWindowList[triggerTypeIndex - 1],
346                     perTypeCapList[triggerTypeIndex - 1],
347                     rank,
348                     perTypeNumWindowList,
349                     perTypeCapList);
350         }
351         for (int i = 0; i <= Math.min(totalCap, capVal); i++) {
352             int[] perTypeNumWindowListClone = Arrays.copyOfRange(
353                     perTypeNumWindowList, 0, triggerTypeIndex + 1);
354             perTypeNumWindowListClone[triggerTypeIndex] = winVal - 1;
355             int[] perTypeCapListClone = Arrays.copyOfRange(
356                     perTypeCapList, 0, triggerTypeIndex + 1);
357             perTypeCapListClone[triggerTypeIndex] = capVal - i;
358             long currentNumStates =
359                     getNumStatesIterative(
360                             totalCap - i,
361                             perTypeNumWindowListClone,
362                             perTypeCapListClone,
363                             Long.MAX_VALUE);
364             if (currentNumStates > rank) {
365                 // The triggers to be appended.
366                 List<AtomReportState> toAppend = new ArrayList<>();
367                 for (int k = 0; k < i; k++) {
368                     toAppend.add(new AtomReportState(triggerTypeIndex, winVal - 1));
369                 }
370                 List<AtomReportState> otherReports =
371                         getReportSetBasedOnRankRecursive(
372                                 totalCap - i,
373                                 triggerTypeIndex,
374                                 winVal - 1,
375                                 capVal - i,
376                                 rank,
377                                 perTypeNumWindowList,
378                                 perTypeCapList);
379                 toAppend.addAll(otherReports);
380                 return toAppend;
381             } else {
382                 rank -= currentNumStates;
383             }
384         }
385         // will not reach here
386         return new ArrayList<>();
387     }
388 
389     /** A single report including triggerDataType and window index for the fake report generation */
390     public static class AtomReportState {
391         private final int mTriggerDataType;
392         private final int mWindowIndex;
393 
AtomReportState(int triggerDataType, int windowIndex)394         public AtomReportState(int triggerDataType, int windowIndex) {
395             this.mTriggerDataType = triggerDataType;
396             this.mWindowIndex = windowIndex;
397         }
398 
getTriggerDataType()399         public int getTriggerDataType() {
400             return mTriggerDataType;
401         }
402         ;
403 
getWindowIndex()404         public final int getWindowIndex() {
405             return mWindowIndex;
406         }
407         ;
408 
409         @Override
equals(Object obj)410         public boolean equals(Object obj) {
411             if (!(obj instanceof AtomReportState)) {
412                 return false;
413             }
414             AtomReportState t = (AtomReportState) obj;
415             return mTriggerDataType == t.mTriggerDataType && mWindowIndex == t.mWindowIndex;
416         }
417 
418         @Override
hashCode()419         public int hashCode() {
420             return Objects.hash(mWindowIndex, mTriggerDataType);
421         }
422     }
423 }
424