# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing random variables."""

import math

import numpy as np

from tensorflow.python.ops.distributions import special_math


def test_moment_matching(
    samples,
    number_moments,
    dist,
    stride=0):
  """Return z-test scores for sample moments to match analytic moments.

  Given `samples`, check that the first `number_moments` sample moments match
  the moments of the given `dist` by doing a z-test.

  Args:
    samples: NumPy array of samples from the target distribution. Its dtype
      must be a floating-point type accepted by `np.finfo`.
    number_moments: Python `int` describing how many sample moments to check.
    dist: SciPy distribution object that provides analytic moments through
      `dist.moment(k)`.
    stride: Distance between samples to check for statistical properties.
      A stride of 0 means to use all samples, while other strides test for
      spatial correlation. The i-th moment is computed on every
      `(i - 1) * stride + 1`-th sample.

  Returns:
    Python list with one entry per moment (first moment first), each entry
    being the absolute value of the z-test score for that moment.
  """
  # Both constants depend only on the sample dtype, so look them up once.
  float_info = np.finfo(samples.dtype)
  eps = float_info.eps
  tiny = float_info.tiny

  z_test_scores = []
  for i in range(1, number_moments + 1):
    step = (i - 1) * stride + 1
    if len(samples.shape) == 2:
      # Rank-2 input is flattened before striding.
      strided_range = samples.flat[::step]
    else:
      # Any other rank is strided along the leading axis only.
      strided_range = samples[::step, ...]

    sample_moment = np.mean(strided_range**i, axis=0)
    expected_moment = dist.moment(i)
    # Variance of the i-th sample moment: Var[X^i] / (number of samples used).
    variance_sample_moment = (
        (dist.moment(2 * i) - expected_moment**2) / len(strided_range))

    # Assume every operation has a small numerical error.
    # It takes i multiplications to calculate one i-th moment.
    total_variance = variance_sample_moment + i * eps
    assert np.all(total_variance > 0)
    # Clamp to the smallest positive normal value to keep the division finite.
    total_variance = np.where(total_variance < tiny, tiny, total_variance)

    # z_test is approximately a unit normal distribution.
    z_test_scores.append(abs(
        (sample_moment - expected_moment) / np.sqrt(total_variance)))
  return z_test_scores


def chi_squared(x, bins):
  """Pearson's Chi-squared test statistic against a uniform [0, 1] histogram.

  Args:
    x: Array-like of samples; it is flattened before use.
    bins: Number of equal-width bins covering the interval [0, 1].

  Returns:
    The sum over bins of `(observed - expected)**2 / expected`, where
    `expected` is `len(x) / bins`. Samples outside [0, 1] are not counted in
    any bin but still contribute to `len(x)`.
  """
  x = np.ravel(x)
  n = len(x)
  histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
  expected = n / float(bins)
  return np.sum(np.square(histogram - expected) / expected)


def normal_cdf(x):
  """Cumulative distribution function for a standard normal distribution.

  Args:
    x: Scalar or NumPy array of evaluation points.

  Returns:
    `0.5 + 0.5 * erf(x / sqrt(2))`, evaluated element-wise.
  """
  return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))


def anderson_darling(x):
  """Anderson-Darling test for a standard normal distribution.

  Args:
    x: Array-like of samples; it is flattened and sorted before use.

  Returns:
    The Anderson-Darling statistic of `x` against the standard normal CDF.
  """
  x = np.sort(np.ravel(x))
  n = len(x)
  i = np.linspace(1, n, n)
  # normal_cdf runs a Python-level loop over x, so evaluate it only once.
  cdf = normal_cdf(x)
  z = np.sum((2 * i - 1) * np.log(cdf) +
             (2 * (n - i) + 1) * np.log(1 - cdf))
  return -n - z / n


def test_truncated_normal(assert_equal,
                          assert_all_close,
                          n,
                          y,
                          means=None,
                          stddevs=None,
                          minvals=None,
                          maxvals=None,
                          mean_atol=5e-4,
                          median_atol=8e-4,
                          variance_rtol=1e-3):
  """Tests truncated normal distribution's statistics.

  Checks that all samples fall inside the truncation bounds and that the
  sample mean, median and variance match the analytic values of a truncated
  normal distribution.

  Args:
    assert_equal: Callable `assert_equal(actual, expected)` used for the
      bound checks.
    assert_all_close: Callable `assert_all_close(actual, expected, atol=...)`
      / `assert_all_close(actual, expected, rtol=...)` used for the
      statistics checks.
    n: Number of entries of `y` expected to satisfy each bound (i.e. the
      number of samples when all of them are in range).
    y: NumPy array of samples.
    means: Mean of the untruncated normal; defaults to 0.
    stddevs: Standard deviation of the untruncated normal; defaults to 1.
    minvals: Lower truncation bound; defaults to -2.
    maxvals: Upper truncation bound; defaults to 2.
    mean_atol: Absolute tolerance for the mean check.
    median_atol: Absolute tolerance for the median check.
    variance_rtol: Relative tolerance for the variance check.

  Note: the bounds, mean and standard deviation are fed through `math.erfc`
  and `math.exp`, so they are expected to be scalars.
  """
  def _normal_cdf(x):
    return .5 * math.erfc(-x / math.sqrt(2))

  def normal_pdf(x):
    return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)

  def probit(x):
    return special_math.ndtri(x)

  a = -2. if minvals is None else minvals
  b = 2. if maxvals is None else maxvals
  mu = 0. if means is None else means
  sigma = 1. if stddevs is None else stddevs

  # Standardized truncation bounds and the probability mass between them.
  alpha = (a - mu) / sigma
  beta = (b - mu) / sigma
  cdf_alpha = _normal_cdf(alpha)
  cdf_beta = _normal_cdf(beta)
  z = cdf_beta - cdf_alpha

  assert_equal((y >= a).sum(), n)
  assert_equal((y <= b).sum(), n)

  pdf_alpha = normal_pdf(alpha)
  pdf_beta = normal_pdf(beta)

  # For more information on these calculations, see:
  # Burkardt, John. "The Truncated Normal Distribution".
  # Department of Scientific Computing website. Florida State University.
  expected_mean = mu + (pdf_alpha - pdf_beta) / z * sigma
  y = y.astype(float)
  actual_mean = np.mean(y)
  assert_all_close(actual_mean, expected_mean, atol=mean_atol)

  expected_median = mu + probit((cdf_alpha + cdf_beta) / 2.) * sigma
  actual_median = np.median(y)
  assert_all_close(actual_median, expected_median, atol=median_atol)

  expected_variance = sigma**2 * (1 + (
      (alpha * pdf_alpha - beta * pdf_beta) / z) - (
          (pdf_alpha - pdf_beta) / z)**2)
  actual_variance = np.var(y)
  assert_all_close(
      actual_variance,
      expected_variance,
      rtol=variance_rtol)