# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing random variables."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def test_moment_matching(
    samples,
    number_moments,
    dist,
    stride=0):
  """Return z-test scores for sample moments to match analytic moments.

  Given `samples`, check that the first `number_moments` raw sample moments
  match the corresponding moments of `dist` by doing a z-test.

  Args:
    samples: Samples from the target distribution. Anything `np.asarray`
      accepts; the values are flattened before use. Integer and boolean
      samples are promoted to `float64` so that powers do not overflow and a
      floating-point epsilon is available.
    number_moments: Python `int` describing how many sample moments to check.
    dist: Object providing analytic raw moments through `dist.moment(k)`
      (for example a SciPy distribution).
    stride: Distance between samples to check for statistical properties.
      The i-th moment is computed from every `(i - 1) * stride + 1`-th
      sample, so a stride of 0 means to use all samples, while other strides
      test for spatial correlation.

  Returns:
    Python list with one absolute z-test score per moment, ordered from the
    first moment to the `number_moments`-th.

  Raises:
    ValueError: If `samples` is empty.
    AssertionError: If the estimated variance of a sample moment is not
      strictly positive.
  """
  samples = np.asarray(samples)
  if not np.issubdtype(samples.dtype, np.inexact):
    # np.finfo rejects integer/bool dtypes and integer powers can overflow.
    samples = samples.astype(np.float64)
  flat_samples = samples.reshape(-1)
  if flat_samples.size == 0:
    raise ValueError("`samples` must contain at least one element.")

  float_info = np.finfo(samples.dtype)
  z_test_scores = []
  for order in range(1, number_moments + 1):
    strided_range = flat_samples[::(order - 1) * stride + 1]
    sample_moment = np.mean(strided_range ** order)
    expected_moment = dist.moment(order)
    # Variance of the mean of `x ** order` over the strided samples.
    moment_variance = (
        dist.moment(2 * order) - expected_moment ** 2) / len(strided_range)
    # Assume every operation has a small numerical error.
    # It takes `order` multiplications to calculate one `order`-th moment.
    total_variance = moment_variance + order * float_info.eps
    # Explicit raise (rather than `assert`) so the check survives `python -O`.
    if not np.all(total_variance > 0):
      raise AssertionError(
          "Variance of sample moment %d is not positive: %r" %
          (order, total_variance))
    # Element-wise clamp; also works when `dist` yields array-valued moments.
    total_variance = np.maximum(total_variance, float_info.tiny)
    # z_test is approximately a unit normal distribution.
    z_test_scores.append(abs(
        (sample_moment - expected_moment) / np.sqrt(total_variance)))
  return z_test_scores


# The `test_` prefix would otherwise make pytest collect this helper as a test
# case whenever it is imported into a test module.
test_moment_matching.__test__ = False