• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utilities for testing random variables."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23
def test_moment_matching(
    samples,
    number_moments,
    dist,
    stride=0):
  """Return z-test scores for sample moments to match analytic moments.

  Given `samples`, check that the first `number_moments` raw sample moments
  (the means of `samples ** 1`, ..., `samples ** number_moments`) match the
  corresponding analytic moments `dist.moment(i)` by doing a z-test on each.

  Args:
    samples: Samples from target distribution. Anything `np.asanyarray`
      accepts; it is flattened before use. Non-floating-point (e.g. integer
      or boolean) samples are converted to `float64`, because the
      numerical-error term below needs a floating-point dtype.
    number_moments: Python `int` describing how many sample moments to check.
    dist: SciPy distribution object that provides analytic moments through
      `dist.moment(n)`. Array-valued moments (e.g. a frozen distribution with
      array parameters) are supported and yield array-valued scores.
    stride: Distance between samples to check for statistical properties.
      A stride of 0 means to use all samples, while other strides test for
      spatial correlation. The i-th moment is estimated from every
      `((i - 1) * stride + 1)`-th flattened sample.
  Returns:
    List of `number_moments` absolute z-test scores; the entry at index `k`
    belongs to moment `k + 1`.
  """
  samples = np.asanyarray(samples)
  if not np.issubdtype(samples.dtype, np.inexact):
    # `np.finfo` rejects integer/bool dtypes; computing in float64 also avoids
    # silent integer overflow when raising samples to a power.
    samples = samples.astype(np.float64)
  # Machine constants of the dtype the sample moments are computed in; they
  # do not depend on the moment order, so compute them once.
  eps = np.finfo(samples.dtype).eps
  tiny = np.finfo(samples.dtype).tiny
  x = samples.flat

  z_test_scores = []
  for i in range(1, number_moments + 1):
    strided_range = x[::(i - 1) * stride + 1]
    sample_moment = np.mean(strided_range ** i)
    expected_moment = dist.moment(i)
    # Var[X**i] / n: the variance of the sample mean of X**i.
    variance_sample_moment = (
        (dist.moment(2 * i) - expected_moment ** 2) / len(strided_range))
    # Assume every operation has a small numerical error.
    # It takes i multiplications to calculate one i-th moment.
    total_variance = variance_sample_moment + i * eps
    assert np.all(total_variance > 0)
    # Keep the square root below away from subnormal values. `np.maximum`
    # (rather than a scalar `if`) so array-valued moments work too.
    total_variance = np.maximum(total_variance, tiny)
    # z_test is approximately a unit normal distribution.
    z_test_scores.append(
        abs((sample_moment - expected_moment) / np.sqrt(total_variance)))
  return z_test_scores


# A helper, not a test: stop pytest/nose from collecting it (because of its
# `test_` prefix) when it is imported into a test module.
test_moment_matching.__test__ = False
72
73