• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for testing random variables."""
16
17import math
18
19import numpy as np
20
21from tensorflow.python.ops.distributions import special_math
22
23
def test_moment_matching(
    samples,
    number_moments,
    dist,
    stride=0):
  """Return z-test scores for sample moments to match analytic moments.

  Given `samples`, check that the first `number_moments` raw sample moments
  match the moments of `dist` by doing a z-test. The k-th moment is compared
  using the variance of the sample mean of X**k,
  (E[X**(2k)] - E[X**k]**2) / m, where m is the number of samples used for
  that moment, plus a small allowance for floating-point round-off.

  This is a helper called by tests, not a test itself. It is marked with
  `__test__ = False` so that pytest/nose do not try to collect it (and fail on
  its required arguments) when a test module imports it by name.

  Args:
    samples: NumPy array of samples from the target distribution. A 2-D
      array is flattened before striding; any other rank is strided along
      axis 0 and moments are averaged along that axis. Must have a floating
      dtype, since it is passed to `np.finfo`.
    number_moments: Python `int` describing how many sample moments to check.
    dist: SciPy distribution object that provides analytic moments via
      `dist.moment(k)`.
    stride: Distance between samples to check for statistical properties.
      The k-th moment uses every `((k - 1) * stride + 1)`-th sample, so a
      stride of 0 means to use all samples, while other strides test for
      spatial correlation.
  Returns:
    List of absolute z-test scores, one per moment order 1..`number_moments`
    (each an array when `samples` has more than two dimensions).
  """
  sample_moments = []
  expected_moments = []
  variance_sample_moments = []
  for i in range(1, number_moments + 1):
    # The i-th moment reads every ((i - 1) * stride + 1)-th sample.
    step = (i - 1) * stride + 1
    if len(samples.shape) == 2:
      strided_range = samples.flat[::step]
    else:
      strided_range = samples[::step, ...]
    sample_moments.append(np.mean(strided_range**i, axis=0))
    expected_moments.append(dist.moment(i))
    # Variance of the sample mean of X**i over len(strided_range) draws.
    variance_sample_moments.append(
        (dist.moment(2 * i) - dist.moment(i) ** 2) / len(strided_range))

  z_test_scores = []
  for i in range(1, number_moments + 1):
    # Assume every operation has a small numerical error.
    # It takes i multiplications to calculate one i-th moment.
    total_variance = (
        variance_sample_moments[i - 1] +
        i * np.finfo(samples.dtype).eps)
    tiny = np.finfo(samples.dtype).tiny
    assert np.all(total_variance > 0)
    # Raise positive-but-subnormal variances to the smallest normal float.
    total_variance = np.where(total_variance < tiny, tiny, total_variance)
    # z_test is approximately a unit normal distribution.
    z_test_scores.append(abs(
        (sample_moments[i - 1] - expected_moments[i - 1]) / np.sqrt(
            total_variance)))
  return z_test_scores


# Keep pytest/nose from collecting this helper as a test case.
test_moment_matching.__test__ = False
73
74
def chi_squared(x, bins):
  """Pearson's Chi-squared test.

  Computes Pearson's chi-squared statistic of the flattened samples `x`
  against equal expected counts in `bins` equal-width bins covering [0, 1],
  i.e. against a uniform distribution on that interval.

  Args:
    x: Array-like of samples; flattened before binning.
    bins: Number of bins.
  Returns:
    sum((observed - expected)**2 / expected), with expected = len(x) / bins.
    Samples outside [0, 1] fall in no bin but still count toward len(x).
  """
  flat = np.ravel(x)
  counts = np.histogram(flat, bins=bins, range=(0, 1))[0]
  per_bin = len(flat) / float(bins)
  return np.sum(np.square(counts - per_bin) / per_bin)
82
83
def normal_cdf(x):
  """Cumulative distribution function for a standard normal distribution.

  Args:
    x: Scalar or array-like (including plain Python sequences and empty
      arrays) of points at which to evaluate the CDF.
  Returns:
    Phi(x) = 0.5 + 0.5 * erf(x / sqrt(2)), evaluated elementwise as float64,
    with the same shape as `x`.
  """
  # `otypes` fixes the output dtype up front: without it `np.vectorize` calls
  # `math.erf` an extra time on the first element to infer the dtype, and it
  # raises ValueError on size-0 input.
  erf = np.vectorize(math.erf, otypes=[np.float64])
  # np.asarray lets plain lists be divided elementwise; arrays pass through
  # unchanged (keeping their dtype for the division, as before).
  return 0.5 + 0.5 * erf(np.asarray(x) / math.sqrt(2))
87
88
def anderson_darling(x):
  """Anderson-Darling test for a standard normal distribution.

  Computes the statistic

    A^2 = -n - (1/n) * sum_i [(2i - 1) * log(Phi(x_(i)))
                              + (2(n - i) + 1) * log(1 - Phi(x_(i)))]

  where x_(1) <= ... <= x_(n) are the sorted, flattened samples and Phi is
  the standard normal CDF.

  Args:
    x: Array-like of samples; flattened before sorting.
  Returns:
    The A^2 statistic.
  """
  x = np.sort(np.ravel(x))
  n = len(x)
  i = np.linspace(1, n, n)
  # Evaluate Phi(x) = erfc(-x/sqrt2)/2 and 1 - Phi(x) = erfc(x/sqrt2)/2.
  # Writing them as 0.5 + 0.5 * erf(...) rounds to exactly 0 or 1 beyond
  # about 8.3 standard deviations, turning one far-tail sample into
  # log(0) = -inf; erfc stays accurate in both tails.
  erfc = np.vectorize(math.erfc, otypes=[np.float64])
  scaled = x / math.sqrt(2)
  log_cdf = np.log(0.5 * erfc(-scaled))
  log_sf = np.log(0.5 * erfc(scaled))
  z = np.sum((2 * i - 1) * log_cdf + (2 * (n - i) + 1) * log_sf)
  return -n - z / n
97
98
def test_truncated_normal(assert_equal,
                          assert_all_close,
                          n,
                          y,
                          means=None,
                          stddevs=None,
                          minvals=None,
                          maxvals=None,
                          mean_atol=5e-4,
                          median_atol=8e-4,
                          variance_rtol=1e-3):
  """Tests truncated normal distribution's statistics.

  Checks that `y` respects the truncation bounds, and that its sample mean,
  median and variance match the analytic values for a normal distribution
  with mean `means` and standard deviation `stddevs`, truncated to
  [`minvals`, `maxvals`].

  This is a helper called by tests, not a test itself. It is marked with
  `__test__ = False` so that pytest/nose do not try to collect it when a test
  module imports it by name.

  Args:
    assert_equal: Callable invoked as `assert_equal(actual, expected)` for
      the bound checks.
    assert_all_close: Callable invoked as
      `assert_all_close(actual, expected, atol=...)` or `(..., rtol=...)`
      for the statistical checks.
    n: Number of elements of `y` expected to satisfy each bound.
    y: NumPy array of samples.
    means: Scalar mean of the untruncated normal. Defaults to 0.
    stddevs: Scalar standard deviation of the untruncated normal.
      Defaults to 1.
    minvals: Scalar lower truncation bound; may be `-inf`. Defaults to -2.
    maxvals: Scalar upper truncation bound; may be `+inf`. Defaults to 2.
    mean_atol: Absolute tolerance for the sample mean.
    median_atol: Absolute tolerance for the sample median.
    variance_rtol: Relative tolerance for the sample variance.
  """
  def _normal_cdf(x):
    return .5 * math.erfc(-x / math.sqrt(2))

  def normal_pdf(x):
    return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)

  def x_normal_pdf(x):
    # x * pdf(x) tends to 0 as |x| -> inf, but inf * 0.0 evaluates to nan,
    # so an infinite truncation bound needs the limit spelled out.
    return 0. if math.isinf(x) else x * normal_pdf(x)

  def probit(x):
    return special_math.ndtri(x)

  # Defaults describe a standard normal truncated to [-2, 2].
  a = -2. if minvals is None else minvals
  b = 2. if maxvals is None else maxvals
  mu = 0. if means is None else means
  sigma = 1. if stddevs is None else stddevs

  # Standardized bounds and the probability mass between them.
  alpha = (a - mu) / sigma
  beta = (b - mu) / sigma
  z = _normal_cdf(beta) - _normal_cdf(alpha)

  # Exactly `n` samples must satisfy each bound.
  assert_equal((y >= a).sum(), n)
  assert_equal((y <= b).sum(), n)

  # For more information on these calculations, see:
  # Burkardt, John. "The Truncated Normal Distribution".
  # Department of Scientific Computing website. Florida State University.
  expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
  y = y.astype(float)
  actual_mean = np.mean(y)
  assert_all_close(actual_mean, expected_mean, atol=mean_atol)

  # The median lies halfway, in CDF terms, between the two bounds.
  expected_median = mu + probit(
      (_normal_cdf(alpha) + _normal_cdf(beta)) / 2.) * sigma
  actual_median = np.median(y)
  assert_all_close(actual_median, expected_median, atol=median_atol)

  expected_variance = sigma**2 * (1 + (
      (x_normal_pdf(alpha) - x_normal_pdf(beta)) / z) - (
          (normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
  actual_variance = np.var(y)
  assert_all_close(
      actual_variance,
      expected_variance,
      rtol=variance_rtol)


# Keep pytest/nose from collecting this helper as a test case.
test_truncated_normal.__test__ = False
165