• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env Rscript
2#
3# Copyright 2015 Google Inc. All rights reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17source('analysis/R/read_input.R')
18
RandomPartition <- function(total, weights) {
  # Outputs a random partition according to a specified distribution.
  # Args:
  #   total - number of samples to distribute
  #   weights - non-negative weights proportional to the probability mass
  #             function of the target distribution
  # Returns:
  #   A numeric histogram with the same length and names as weights, sampled
  #   according to the normalized weights. Its entries sum to total whenever
  #   total > 0.
  # Example (one possible output):
  #   > RandomPartition(100, c(3, 2, 1, 0, 1))
  #   [1] 47 24 15  0 14
  if (anyNA(weights)) {
    stop("Probabilities cannot be NA")
  }

  if (any(weights < 0)) {
    stop("Probabilities cannot be negative")
  }

  if (sum(weights) == 0) {
    stop("Probabilities cannot sum up to 0")
  }

  bins <- length(weights)
  result <- rep(0, bins)

  # Idiomatic way:
  #   rnd_list <- sample(strs, total, replace = TRUE, weights)
  #   apply(as.array(strs), 1, function(x) length(rnd_list[rnd_list == x]))
  #
  # The following is much faster for larger totals: each bin receives a
  # binomial draw of the samples that are still unallocated, with probability
  # equal to the bin's share of the remaining weight. (Tail recursion would
  # also work, but R chokes on recursion depth > 850.)

  # remaining[i] == sum(weights[i:bins]), computed in O(n). A reverse
  # cumulative sum (instead of repeatedly subtracting from a running total)
  # makes remaining[i] equal weights[i] exactly for the last bin with positive
  # weight, so p is exactly 1 there and no sample is lost to rounding drift.
  remaining <- rev(cumsum(rev(weights)))

  for (i in seq_len(bins)) {
    if (total <= 0) {
      break  # nothing left to allocate; the remaining bins stay at 0
    }
    # Clamp p to [0, 1] as a guard against rounding.
    p <- min(max(weights[i] / remaining[i], 0), 1)
    # Draw the number of samples falling into the current bin.
    drawn <- rbinom(n = 1, size = total, prob = p)
    result[i] <- drawn
    total <- total - drawn
  }

  names(result) <- names(weights)

  result
}
67
GenerateCounts <- function(params, true_map, partition, reports_per_client) {
  # Fast simulation of the marginal table for RAPPOR reports.
  # Args:
  #   params - parameters of the RAPPOR reporting process; the fields m
  #            (cohorts), k (bits per cohort), f, p and q are read here
  #   true_map - hashed true inputs; must have m * k rows (checked below) and
  #              one column per entry of partition
  #   partition - number of clients reporting each true value
  #   reports_per_client - number of reports (IRRs) per client
  # Returns:
  #   An m x (k + 1) matrix. Column 1 is the number of reports in each
  #   cohort; columns 2..(k + 1) are the number of reports with each bit set.
  m <- params$m
  k <- params$k

  if (nrow(true_map) != m * k) {
    # paste (not cat) so the details become the error message instead of
    # being printed to stdout.
    stop(paste("Map does not match the params file!",
               "mk =", m * k,
               "nrow(map):", nrow(true_map),
               sep = " "))
  }

  # For each true value, allocate its clients uniformly at random across the
  # m cohorts. The result is an m x strs matrix; vapply + matrix(nrow = m)
  # keeps that orientation even when m == 1 (apply would return a plain
  # vector there, which as.matrix turns into a strs x 1 column).
  cohorts <- matrix(
    vapply(partition,
           function(count) RandomPartition(count, rep(1, m)),
           numeric(m)),
    nrow = m
  )

  # Expand to an (m * k) x strs matrix where each cohort's row is repeated k
  # times, one copy per bit in the aggregate Bloom filter.
  expanded <- cohorts[rep(seq_len(m), each = k), , drop = FALSE]

  # For each bit, the number of clients reporting this bit.
  clients_per_bit <- rep(rowSums(cohorts), each = k)

  # The true number of bits set to one BEFORE the PRR.
  true_ones <- apply(expanded * true_map, 1, sum)

  num_bits <- length(true_ones)

  # PRR: set bits report one with probability 1 - f/2; unset bits (clients
  # where the bit is 0) report one with probability f/2.
  ones_in_prr <-
    rbinom(n = num_bits, size = true_ones, prob = 1 - params$f / 2) +
    rbinom(n = num_bits, size = clients_per_bit - true_ones,
           prob = params$f / 2)

  # Number of IRRs where each bit is reported (either as 0 or as 1).
  reports_per_bit <- clients_per_bit * reports_per_client

  ones_before_irr <- ones_in_prr * reports_per_client

  # IRR: a one is reported as one with probability q, a zero with
  # probability p.
  ones_after_irr <-
    rbinom(n = num_bits, size = ones_before_irr, prob = params$q) +
    rbinom(n = num_bits, size = reports_per_bit - ones_before_irr,
           prob = params$p)

  counts <- cbind(rowSums(cohorts) * reports_per_client,
                  matrix(ones_after_irr, nrow = m, ncol = k, byrow = TRUE))

  # Integer products/draws that exceed .Machine$integer.max become NA.
  if (anyNA(counts)) {
    stop("Failed to generate bit counts. Likely due to integer overflow.")
  }

  counts
}
123
ComputePdf <- function(distr, range) {
  # Outputs a discrete probability mass function over range support points.
  # Args:
  #   distr - name of the distribution: 'exp', 'gauss', 'unif', 'zipf1' or
  #           'zipf1.5' (the five distributions in gen_sim_input.py)
  #   range - number of support points
  # Returns:
  #   A numeric vector of length range, normalized to sum to 1 (empty when
  #   range is 0). Stops on an unknown distribution name.

  # 1..range; seq_len (unlike 1:range) is empty rather than c(1, 0) when
  # range == 0.
  support <- seq_len(range)

  pdf <- switch(
    as.character(distr),
    exp = dexp(support, rate = 5 / range),
    # Centered on 0: support points run from -range/2 + 1 to range/2.
    gauss = dnorm(support - range / 2, sd = range / 6),
    # e.g. for N = 4, weights are [0.25, 0.25, 0.25, 0.25]
    unif = dunif(support, max = range),
    # Since the distribution is defined over a finite set, the Zipf
    # parameter is allowed to be 1.
    zipf1 = 1 / support,
    zipf1.5 = 1 / support^1.5,
    stop(sprintf("Invalid distribution '%s'", distr))
  )

  pdf / sum(pdf)  # normalize
}
152
# Usage:
#
# $ ./gen_counts.R exp 10000 1 foo_params.csv foo_true_map.csv foo
#
# Inputs:
#   distribution name (one of exp, gauss, unif, zipf1, zipf1.5)
#   number of clients
#   reports per client
#   parameters file
#   map file
#   prefix for output files
# Outputs (for the prefix "foo" above):
#   foo_counts.csv
#   foo_hist.csv
#
# Warning: the number of reports in any cohort must be less than
#          .Machine$integer.max
main <- function(argv) {
  # Command-line entry point: simulates RAPPOR counts for a synthetic
  # distribution of true values and writes two CSV files.
  # Args:
  #   argv - character vector: distribution name, number of clients, reports
  #          per client, parameters file, map file, output file prefix
  # Side effects:
  #   Writes <prefix>_counts.csv and <prefix>_hist.csv and prints their paths.
  if (length(argv) < 6) {
    stop("Usage: gen_counts.R <distribution> <num_clients> ",
         "<reports_per_client> <params_file> <map_file> <out_prefix>",
         call. = FALSE)
  }

  # Parses a non-negative integer argument, failing with a readable message
  # instead of letting an NA surface later as an obscure error.
  parse_count <- function(value, what) {
    parsed <- suppressWarnings(as.integer(value))
    if (is.na(parsed) || parsed < 0) {
      stop(sprintf("Invalid %s: '%s'", what, value), call. = FALSE)
    }
    parsed
  }

  distr <- argv[[1]]
  num_clients <- parse_count(argv[[2]], "number of clients")
  reports_per_client <- parse_count(argv[[3]], "number of reports per client")
  params_file <- argv[[4]]
  true_map_file <- argv[[5]]
  out_prefix <- argv[[6]]

  params <- ReadParameterFile(params_file)

  true_map <- ReadMapFile(true_map_file, params)

  num_unique_values <- length(true_map$strs)

  pdf <- ComputePdf(distr, num_unique_values)

  # Computes the number of clients reporting each string
  # according to the pre-specified distribution.
  partition <- RandomPartition(num_clients, pdf)

  # Histogram of the true values.
  true_hist <- data.frame(string = true_map$strs, count = partition)

  counts <- GenerateCounts(params, true_map$map, partition, reports_per_client)

  # Write the counts as a header-less CSV (the opposite of ReadCountsFile in
  # read_input.R).
  # http://stackoverflow.com/questions/6750546/export-csv-without-col-names
  counts_path <- paste0(out_prefix, '_counts.csv')
  write.table(counts, file = counts_path,
              row.names = FALSE, col.names = FALSE, sep = ',')
  cat(sprintf('Wrote %s\n', counts_path))

  # TODO: Don't write strings that appear 0 times?
  hist_path <- paste0(out_prefix, '_hist.csv')
  write.csv(true_hist, file = hist_path, row.names = FALSE)
  cat(sprintf('Wrote %s\n', hist_path))
}
210
# Invoke main() only when this file is evaluated with an empty call stack
# (e.g. run directly via Rscript). When it is source()d from other R code,
# sys.frames() is non-empty and main() is not called.
if (identical(length(sys.frames()), 0L)) {
  main(commandArgs(trailingOnly = TRUE))
}
214