# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for Student t distribution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import importlib
22import math
23
24import numpy as np
25
26from tensorflow.python.eager import backprop
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import random_seed
29from tensorflow.python.framework import test_util
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import nn_ops
32from tensorflow.python.ops.distributions import student_t
33from tensorflow.python.platform import test
34from tensorflow.python.platform import tf_logging
35
36
37def try_import(name):  # pylint: disable=invalid-name
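  """Imports the named module if available, returning None otherwise."""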
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


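# SciPy is an optional dependency; tests that compare against scipy.stats
# return early when it is not available.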
stats = try_import("scipy.stats")


@test_util.run_all_in_graph_and_eager_modes
class StudentTTest(test.TestCase):

  def testStudentPDFAndLogPDF(self):
    batch_size = 6
    df = constant_op.constant([3.] * batch_size)
    mu = constant_op.constant([7.] * batch_size)
    sigma = constant_op.constant([8.] * batch_size)
    df_v = 3.
    mu_v = 7.
    sigma_v = 8.
    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
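    # A negative scale is intentional: the density depends on the scale only
    # through its magnitude, so the results below are still compared against
    # scipy evaluated at the positive sigma_v.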
    student = student_t.StudentT(df, loc=mu, scale=-sigma)

    log_pdf = student.log_prob(t)
    self.assertEqual(log_pdf.get_shape(), (6,))
    log_pdf_values = self.evaluate(log_pdf)
    pdf = student.prob(t)
    self.assertEqual(pdf.get_shape(), (6,))
    pdf_values = self.evaluate(pdf)

    if not stats:
      return

    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
    self.assertAllClose(expected_pdf, pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  def testStudentLogPDFMultidimensional(self):
    batch_size = 6
    df = constant_op.constant([[1.5, 7.2]] * batch_size)
    mu = constant_op.constant([[3., -3.]] * batch_size)
    sigma = constant_op.constant(
        [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size)
    df_v = np.array([1.5, 7.2])
    mu_v = np.array([3., -3.])
    sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
    t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
    student = student_t.StudentT(df, loc=mu, scale=sigma)
    log_pdf = student.log_prob(t)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    pdf = student.prob(t)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))

    if not stats:
      return
    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
    self.assertAllClose(expected_pdf, pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  def testStudentCDFAndLogCDF(self):
    batch_size = 6
    df = constant_op.constant([3.] * batch_size)
    mu = constant_op.constant([7.] * batch_size)
    sigma = constant_op.constant([-8.] * batch_size)
    df_v = 3.
    mu_v = 7.
    sigma_v = 8.
    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
    student = student_t.StudentT(df, loc=mu, scale=sigma)

    log_cdf = student.log_cdf(t)
    self.assertEqual(log_cdf.get_shape(), (6,))
    log_cdf_values = self.evaluate(log_cdf)
    cdf = student.cdf(t)
    self.assertEqual(cdf.get_shape(), (6,))
    cdf_values = self.evaluate(cdf)

    if not stats:
      return
    expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(
        np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(
        np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)

  def testStudentEntropy(self):
    df_v = np.array([[2., 3., 7.]])  # 1x3
    mu_v = np.array([[1., -1, 0]])  # 1x3
    sigma_v = np.array([[1., -2., 3.]]).T  # transposed => 3x1
    student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
    ent = student.entropy()
    ent_values = self.evaluate(ent)

    # Help scipy broadcast to 3x3
    ones = np.array([[1, 1, 1]])
    sigma_bc = np.abs(sigma_v) * ones
    mu_bc = ones.T * mu_v
    df_bc = ones.T * df_v
    if not stats:
      return
    expected_entropy = stats.t.entropy(
        np.reshape(df_bc, [-1]),
        loc=np.reshape(mu_bc, [-1]),
        scale=np.reshape(sigma_bc, [-1]))
    expected_entropy = np.reshape(expected_entropy, df_bc.shape)
    self.assertAllClose(expected_entropy, ent_values)

  def testStudentSample(self):
    df = constant_op.constant(4.)
    mu = constant_op.constant(3.)
    sigma = constant_op.constant(-math.sqrt(10.))
    df_v = 4.
    mu_v = 3.
    sigma_v = np.sqrt(10.)
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    n_val = 200000
    self.assertEqual(sample_values.shape, (n_val,))
    self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0)
    self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)

  # Test that sampling with the same seed twice gives the same results.
  def testStudentSampleMultipleTimes(self):
    df = constant_op.constant(4.)
    mu = constant_op.constant(3.)
    sigma = constant_op.constant(math.sqrt(10.))
    n = constant_op.constant(100)

    random_seed.set_random_seed(654321)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
    samples1 = self.evaluate(student.sample(n, seed=123456))

    random_seed.set_random_seed(654321)
    student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
    samples2 = self.evaluate(student2.sample(n, seed=123456))

    self.assertAllClose(samples1, samples2)

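  # Sampling with very small df should be numerically stable (no NaNs).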
  def testStudentSampleSmallDfNoNan(self):
    df_v = [1e-1, 1e-5, 1e-10, 1e-20]
    df = constant_op.constant(df_v)
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=1., scale=1.)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    n_val = 200000
    self.assertEqual(sample_values.shape, (n_val, 4))
    self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))

  def testStudentSampleMultiDimensional(self):
    batch_size = 7
    df = constant_op.constant([[5., 7.]] * batch_size)
    mu = constant_op.constant([[3., -3.]] * batch_size)
    sigma = constant_op.constant(
        [[math.sqrt(10.), math.sqrt(15.)]] * batch_size)
    df_v = [5., 7.]
    mu_v = [3., -3.]
    sigma_v = [np.sqrt(10.), np.sqrt(15.)]
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
    self.assertAllClose(
        sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values[:, 0, 0].var(),
        sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
        rtol=0.2,
        atol=0)
    self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
    self.assertAllClose(
        sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values[:, 0, 1].var(),
        sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
        rtol=0.2,
        atol=0)
    self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])

  def _checkKLApprox(self, df, mu, sigma, samples):
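    """Checks samples against scipy via an approximate KL divergence.

    Both sample sets are histogrammed over the central 99% interval of the
    distribution, and the histogram-based KL estimate must be less than 1.
    """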
    n = samples.size
    np.random.seed(137)
    if not stats:
      return
    sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n)
    covg = 0.99
    r = stats.t.interval(covg, df, loc=mu, scale=sigma)
    bins = 100
    hist, _ = np.histogram(samples, bins=bins, range=r)
    hist_scipy, _ = np.histogram(sample_scipy, bins=bins, range=r)
    self.assertGreater(hist.sum(), n * (covg - .01))
    self.assertGreater(hist_scipy.sum(), n * (covg - .01))
    hist_min1 = hist + 1.  # put at least one item in each bucket
    hist_norm = hist_min1 / hist_min1.sum()
    hist_scipy_min1 = hist_scipy + 1.  # put at least one item in each bucket
    hist_scipy_norm = hist_scipy_min1 / hist_scipy_min1.sum()
    kl_appx = np.sum(np.log(hist_scipy_norm / hist_norm) * hist_scipy_norm)
    self.assertLess(kl_appx, 1)

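  # df, loc, and scale should broadcast against one another to a common
  # batch shape of (3,).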
  def testBroadcastingParams(self):

    def _check(student):
      self.assertEqual(student.mean().get_shape(), (3,))
      self.assertEqual(student.variance().get_shape(), (3,))
      self.assertEqual(student.entropy().get_shape(), (3,))
      self.assertEqual(student.log_prob(2.).get_shape(), (3,))
      self.assertEqual(student.prob(2.).get_shape(), (3,))
      self.assertEqual(student.sample(37).get_shape(), (37, 3,))

    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))

  def testBroadcastingPdfArgs(self):

    def _assert_shape(student, arg, shape):
      self.assertEqual(student.log_prob(arg).get_shape(), shape)
      self.assertEqual(student.prob(arg).get_shape(), shape)

    def _check(student):
      _assert_shape(student, 2., (3,))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (3,))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))

    def _check2d(student):
      _assert_shape(student, 2., (1, 3))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (1, 3))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check2d(student_t.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.))
    _check2d(student_t.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.))
    _check2d(student_t.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]]))

    def _check2d_rows(student):
      _assert_shape(student, 2., (3, 1))
      xs = np.array([2., 3., 4.], dtype=np.float32)  # (3,)
      _assert_shape(student, xs, (3, 3))
      xs = np.array([xs])  # (1,3)
      _assert_shape(student, xs, (3, 3))
      xs = xs.T  # (3,1)
      _assert_shape(student, xs, (3, 1))

    _check2d_rows(student_t.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.))
    _check2d_rows(student_t.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.))
    _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))

  def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
    mu = [1., 3.3, 4.4]
    student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
    mean = self.evaluate(student.mean())
    self.assertAllClose([1., 3.3, 4.4], mean)

  def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
    mu = [1., 3.3, 4.4]
    student = student_t.StudentT(
        df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.mean())

  def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
    mu = [-2, 0., 1., 3.3, 4.4]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(
        df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True)
    mean = self.evaluate(student.mean())
    self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)

  def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
    # df = 0.5 ==> undefined mean ==> undefined variance.
    # df = 1.5 ==> infinite variance.
    df = [0.5, 1.5, 3., 5., 7.]
    mu = [-2, 0., 1., 3.3, 4.4]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(
        df=df, loc=mu, scale=sigma, allow_nan_stats=True)
    var = self.evaluate(student.variance())

    if not stats:
      return
    expected_var = [
        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    # Slicing off first element due to nan/inf mismatch in different SciPy
    # versions.
    self.assertAllClose(expected_var[1:], var[1:])

  def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
      self):
    # df = 1.5 ==> infinite variance.
    df = [1.5, 3., 5., 7.]
    mu = [0., 1., 3.3, 4.4]
    sigma = [4., 3., 2., 1.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    var = self.evaluate(student.variance())

    if not stats:
      return
    expected_var = [
        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    self.assertAllClose(expected_var, var)

  def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
    # df <= 1 ==> variance not defined
    student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.variance())

    # df <= 1 ==> variance not defined
    student = student_t.StudentT(
        df=0.5, loc=0., scale=1., allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.variance())

  def testStd(self):
    # Defined for all batch members.
    df = [3.5, 5., 3., 5., 7.]
    mu = [-2.2]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    # Test broadcast of mu across shape of df/sigma
    stddev = self.evaluate(student.stddev())
    mu *= len(df)

    if not stats:
      return
    expected_stddev = [
        stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    self.assertAllClose(expected_stddev, stddev)

  def testMode(self):
    df = [0.5, 1., 3]
    mu = [-1, 0., 1]
    sigma = [5., 4., 3.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    # Test broadcast of mu across shape of df/sigma
    mode = self.evaluate(student.mode())
    self.assertAllClose([-1., 0, 1], mode)

  def testPdfOfSample(self):
    student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
    num = 20000
    samples = student.sample(num, seed=123456)
    pdfs = student.prob(samples)
    mean = student.mean()
    mean_pdf = student.prob(student.mean())
    sample_vals, pdf_vals, mean_val, mean_pdf_val = self.evaluate(
        [samples, pdfs, student.mean(), mean_pdf])
    self.assertEqual(samples.get_shape(), (num,))
    self.assertEqual(pdfs.get_shape(), (num,))
    self.assertEqual(mean.get_shape(), ())
    self.assertNear(np.pi, np.mean(sample_vals), err=0.1)
    self.assertNear(np.pi, mean_val, err=1e-6)
    # Verify integral over sample*pdf ~= 1.
    # Tolerance increased since eager was getting a value of 1.002041.
    self._assertIntegral(sample_vals, pdf_vals, err=5e-2)
    if not stats:
      return
    self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)

  def testFullyReparameterized(self):
    df = constant_op.constant(2.0)
    mu = constant_op.constant(1.0)
    sigma = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(df)
      tape.watch(mu)
      tape.watch(sigma)
      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
      samples = student.sample(100)
    grad_df, grad_mu, grad_sigma = tape.gradient(samples, [df, mu, sigma])
    self.assertIsNotNone(grad_df)
    self.assertIsNotNone(grad_mu)
    self.assertIsNotNone(grad_sigma)

  def testPdfOfSampleMultiDims(self):
    student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
    self.assertAllEqual([], student.event_shape)
    self.assertAllEqual([], self.evaluate(student.event_shape_tensor()))
    self.assertAllEqual([2, 2], student.batch_shape)
    self.assertAllEqual([2, 2], self.evaluate(student.batch_shape_tensor()))
    num = 50000
    samples = student.sample(num, seed=123456)
    pdfs = student.prob(samples)
    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
    self.assertEqual(samples.get_shape(), (num, 2, 2))
    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
    self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=0.1)
    self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=0.1)
    self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.05)
    self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.05)
    self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.05)
    self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.05)
    if not stats:
      return
    self.assertNear(
        stats.t.var(7., loc=0., scale=3.),  # loc does not affect var
        np.var(sample_vals[:, :, 0]),
        err=1.0)
    self.assertNear(
        stats.t.var(11., loc=0., scale=3.),  # loc does not affect var
        np.var(sample_vals[:, :, 1]),
        err=1.0)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3):
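    """Trapezoidal estimate of the pdf's integral over the sampled support.

    The estimate should be close to 1 within `err`.
    """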
    s_p = zip(sample_vals, pdf_vals)
    prev = (sample_vals.min() - 1000, 0)
    total = 0
    for k in sorted(s_p, key=lambda x: x[0]):
      pair_pdf = (k[1] + prev[1]) / 2
      total += (k[0] - prev[0]) * pair_pdf
      prev = k
    self.assertNear(1., total, err=err)

  def testNegativeDofFails(self):
    with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
      student = student_t.StudentT(
          df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
      self.evaluate(student.mean())

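  # The wrapper is expected to transform df -> floor(abs(df)) and
  # scale -> softplus(scale); the assertions below check exactly that.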
  def testStudentTWithAbsDfSoftplusScale(self):
    df = constant_op.constant([-3.2, -4.6])
    mu = constant_op.constant([-4.2, 3.4])
    sigma = constant_op.constant([-6.4, -8.8])
    student = student_t.StudentTWithAbsDfSoftplusScale(
        df=df, loc=mu, scale=sigma)
    self.assertAllClose(
        math_ops.floor(self.evaluate(math_ops.abs(df))),
        self.evaluate(student.df))
    self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))


if __name__ == "__main__":
  test.main()