# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Student t distribution."""

import importlib
import math

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import student_t
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


stats = try_import("scipy.stats")
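# `stats` is None when SciPy is unavailable; the tests below that compare
# against scipy.stats simply return early in that case.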


@test_util.run_all_in_graph_and_eager_modes
class StudentTTest(test.TestCase):

  def testStudentPDFAndLogPDF(self):
    batch_size = 6
    df = constant_op.constant([3.] * batch_size)
    mu = constant_op.constant([7.] * batch_size)
    sigma = constant_op.constant([8.] * batch_size)
    df_v = 3.
    mu_v = 7.
    sigma_v = 8.
    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
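    # The scale is deliberately negated below; since the Student t is
    # symmetric, the distribution should depend on the scale only through its
    # magnitude, so results are still compared against SciPy with the
    # positive sigma_v.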
    student = student_t.StudentT(df, loc=mu, scale=-sigma)  # pylint: disable=invalid-unary-operand-type

    log_pdf = student.log_prob(t)
    self.assertEqual(log_pdf.get_shape(), (6,))
    log_pdf_values = self.evaluate(log_pdf)
    pdf = student.prob(t)
    self.assertEqual(pdf.get_shape(), (6,))
    pdf_values = self.evaluate(pdf)

    if not stats:
      return

    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
    self.assertAllClose(expected_pdf, pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  def testStudentLogPDFMultidimensional(self):
    batch_size = 6
    df = constant_op.constant([[1.5, 7.2]] * batch_size)
    mu = constant_op.constant([[3., -3.]] * batch_size)
    sigma = constant_op.constant(
        [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size)
    df_v = np.array([1.5, 7.2])
    mu_v = np.array([3., -3.])
    sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
    t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
    student = student_t.StudentT(df, loc=mu, scale=sigma)
    log_pdf = student.log_prob(t)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    pdf = student.prob(t)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))

    if not stats:
      return
    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_pdf, log_pdf_values)
    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
    self.assertAllClose(expected_pdf, pdf_values)
    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  def testStudentCDFAndLogCDF(self):
    batch_size = 6
    df = constant_op.constant([3.] * batch_size)
    mu = constant_op.constant([7.] * batch_size)
    sigma = constant_op.constant([-8.] * batch_size)
    df_v = 3.
    mu_v = 7.
    sigma_v = 8.
    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
    student = student_t.StudentT(df, loc=mu, scale=sigma)

    log_cdf = student.log_cdf(t)
    self.assertEqual(log_cdf.get_shape(), (6,))
    log_cdf_values = self.evaluate(log_cdf)
    cdf = student.cdf(t)
    self.assertEqual(cdf.get_shape(), (6,))
    cdf_values = self.evaluate(cdf)

    if not stats:
      return
    expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
    expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
    self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(
        np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
    self.assertAllClose(
        np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)

  def testStudentEntropy(self):
    df_v = np.array([[2., 3., 7.]])  # 1x3
    mu_v = np.array([[1., -1, 0]])  # 1x3
    sigma_v = np.array([[1., -2., 3.]]).T  # transposed => 3x1
    student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
    ent = student.entropy()
    ent_values = self.evaluate(ent)

    # Help scipy broadcast to 3x3
    ones = np.array([[1, 1, 1]])
    sigma_bc = np.abs(sigma_v) * ones
    mu_bc = ones.T * mu_v
    df_bc = ones.T * df_v
    if not stats:
      return
    expected_entropy = stats.t.entropy(
        np.reshape(df_bc, [-1]),
        loc=np.reshape(mu_bc, [-1]),
        scale=np.reshape(sigma_bc, [-1]))
    expected_entropy = np.reshape(expected_entropy, df_bc.shape)
    self.assertAllClose(expected_entropy, ent_values)

  def testStudentSample(self):
    df = constant_op.constant(4.)
    mu = constant_op.constant(3.)
    sigma = constant_op.constant(-math.sqrt(10.))
    df_v = 4.
    mu_v = 3.
    sigma_v = np.sqrt(10.)
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    n_val = 200000
    self.assertEqual(sample_values.shape, (n_val,))
    self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0)
    self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)

  # Test that sampling with the same seed twice gives the same results.
  def testStudentSampleMultipleTimes(self):
    df = constant_op.constant(4.)
    mu = constant_op.constant(3.)
    sigma = constant_op.constant(math.sqrt(10.))
    n = constant_op.constant(100)

    random_seed.set_random_seed(654321)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
    samples1 = self.evaluate(student.sample(n, seed=123456))

    random_seed.set_random_seed(654321)
    student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
    samples2 = self.evaluate(student2.sample(n, seed=123456))

    self.assertAllClose(samples1, samples2)

  def testStudentSampleSmallDfNoNan(self):
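    # Tiny df values exercise a numerically delicate regime of the sampler;
    # the draws may be extreme, but they should never be NaN.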
    df_v = [1e-1, 1e-5, 1e-10, 1e-20]
    df = constant_op.constant(df_v)
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=1., scale=1.)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    n_val = 200000
    self.assertEqual(sample_values.shape, (n_val, 4))
    self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))

  def testStudentSampleMultiDimensional(self):
    batch_size = 7
    df = constant_op.constant([[5., 7.]] * batch_size)
    mu = constant_op.constant([[3., -3.]] * batch_size)
    sigma = constant_op.constant(
        [[math.sqrt(10.), math.sqrt(15.)]] * batch_size)
    df_v = [5., 7.]
    mu_v = [3., -3.]
    sigma_v = [np.sqrt(10.), np.sqrt(15.)]
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    samples = student.sample(n, seed=123456)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
    self.assertAllClose(
        sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values[:, 0, 0].var(),
        sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
        rtol=0.2,
        atol=0)
    self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
    self.assertAllClose(
        sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
    self.assertAllClose(
        sample_values[:, 0, 1].var(),
        sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
        rtol=0.2,
        atol=0)
    self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])

  def _checkKLApprox(self, df, mu, sigma, samples):
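    """Sanity-checks samples against SciPy via an approximate KL divergence.

    Both sample sets are histogrammed over the central 99% interval of the
    reference SciPy distribution; the histograms are smoothed and normalized,
    and the approximate KL divergence between them is asserted to be small.
    """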
    n = samples.size
    np.random.seed(137)
    if not stats:
      return
    sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n)
    covg = 0.99
    r = stats.t.interval(covg, df, loc=mu, scale=sigma)
    bins = 100
    hist, _ = np.histogram(samples, bins=bins, range=r)
    hist_scipy, _ = np.histogram(sample_scipy, bins=bins, range=r)
    self.assertGreater(hist.sum(), n * (covg - .01))
    self.assertGreater(hist_scipy.sum(), n * (covg - .01))
    hist_min1 = hist + 1.  # put at least one item in each bucket
    hist_norm = hist_min1 / hist_min1.sum()
    hist_scipy_min1 = hist_scipy + 1.  # put at least one item in each bucket
    hist_scipy_norm = hist_scipy_min1 / hist_scipy_min1.sum()
    kl_appx = np.sum(np.log(hist_scipy_norm / hist_norm) * hist_scipy_norm)
    self.assertLess(kl_appx, 1)

  def testBroadcastingParams(self):

    def _check(student):
      self.assertEqual(student.mean().get_shape(), (3,))
      self.assertEqual(student.variance().get_shape(), (3,))
      self.assertEqual(student.entropy().get_shape(), (3,))
      self.assertEqual(student.log_prob(2.).get_shape(), (3,))
      self.assertEqual(student.prob(2.).get_shape(), (3,))
      self.assertEqual(student.sample(37).get_shape(), (37, 3,))

    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))

  def testBroadcastingPdfArgs(self):

    def _assert_shape(student, arg, shape):
      self.assertEqual(student.log_prob(arg).get_shape(), shape)
      self.assertEqual(student.prob(arg).get_shape(), shape)

    def _check(student):
      _assert_shape(student, 2., (3,))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (3,))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))

    def _check2d(student):
      _assert_shape(student, 2., (1, 3))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (1, 3))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check2d(student_t.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.))
    _check2d(student_t.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.))
    _check2d(student_t.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]]))

    def _check2d_rows(student):
      _assert_shape(student, 2., (3, 1))
      xs = np.array([2., 3., 4.], dtype=np.float32)  # (3,)
      _assert_shape(student, xs, (3, 3))
      xs = np.array([xs])  # (1,3)
      _assert_shape(student, xs, (3, 3))
      xs = xs.T  # (3,1)
      _assert_shape(student, xs, (3, 1))

    _check2d_rows(student_t.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.))
    _check2d_rows(student_t.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.))
    _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))

  def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
    mu = [1., 3.3, 4.4]
    student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
    mean = self.evaluate(student.mean())
    self.assertAllClose([1., 3.3, 4.4], mean)

  def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
    mu = [1., 3.3, 4.4]
    student = student_t.StudentT(
        df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.mean())

  def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
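    # The mean of a Student t is undefined for df <= 1, so the first two
    # batch members (df=0.5 and df=1.) are expected to come back as NaN.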
    mu = [-2, 0., 1., 3.3, 4.4]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(
        df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True)
    mean = self.evaluate(student.mean())
    self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)

  def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
    # df = 0.5 ==> undefined mean ==> undefined variance.
    # df = 1.5 ==> infinite variance.
    df = [0.5, 1.5, 3., 5., 7.]
    mu = [-2, 0., 1., 3.3, 4.4]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(
        df=df, loc=mu, scale=sigma, allow_nan_stats=True)
    var = self.evaluate(student.variance())

    if not stats:
      return
    expected_var = [
        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    # Slicing off first element due to nan/inf mismatch in different SciPy
    # versions.
    self.assertAllClose(expected_var[1:], var[1:])

  def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
      self):
    # df = 1.5 ==> infinite variance.
    df = [1.5, 3., 5., 7.]
    mu = [0., 1., 3.3, 4.4]
    sigma = [4., 3., 2., 1.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    var = self.evaluate(student.variance())

    if not stats:
      return
    expected_var = [
        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    self.assertAllClose(expected_var, var)

  def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
    # df <= 1 ==> variance not defined
    student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.variance())

    # df <= 1 ==> variance not defined
    student = student_t.StudentT(
        df=0.5, loc=0., scale=1., allow_nan_stats=False)
    with self.assertRaisesOpError("x < y"):
      self.evaluate(student.variance())

  def testStd(self):
    # Defined for all batch members.
    df = [3.5, 5., 3., 5., 7.]
    mu = [-2.2]
    sigma = [5., 4., 3., 2., 1.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    # Test broadcast of mu across shape of df/sigma
    stddev = self.evaluate(student.stddev())
    mu *= len(df)

    if not stats:
      return
    expected_stddev = [
        stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    ]
    self.assertAllClose(expected_stddev, stddev)

  def testMode(self):
    df = [0.5, 1., 3]
    mu = [-1, 0., 1]
    sigma = [5., 4., 3.]
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    # The mode of a Student t equals loc, regardless of df or scale.
    mode = self.evaluate(student.mode())
    self.assertAllClose([-1., 0, 1], mode)

  def testPdfOfSample(self):
    student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
    num = 20000
    samples = student.sample(num, seed=123456)
    pdfs = student.prob(samples)
    mean = student.mean()
    mean_pdf = student.prob(student.mean())
    sample_vals, pdf_vals, mean_val, mean_pdf_val = self.evaluate(
        [samples, pdfs, student.mean(), mean_pdf])
    self.assertEqual(samples.get_shape(), (num,))
    self.assertEqual(pdfs.get_shape(), (num,))
    self.assertEqual(mean.get_shape(), ())
    self.assertNear(np.pi, np.mean(sample_vals), err=0.1)
    self.assertNear(np.pi, mean_val, err=1e-6)
    # Verify integral over sample*pdf ~= 1.
    # Tolerance increased since eager was getting a value of 1.002041.
    self._assertIntegral(sample_vals, pdf_vals, err=5e-2)
    if not stats:
      return
    self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)

  def testFullyReparameterized(self):
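    # Sampling should be fully reparameterized: gradients of the samples with
    # respect to df, loc, and scale are all expected to exist.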
    df = constant_op.constant(2.0)
    mu = constant_op.constant(1.0)
    sigma = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(df)
      tape.watch(mu)
      tape.watch(sigma)
      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
      samples = student.sample(100)
    grad_df, grad_mu, grad_sigma = tape.gradient(samples, [df, mu, sigma])
    self.assertIsNotNone(grad_df)
    self.assertIsNotNone(grad_mu)
    self.assertIsNotNone(grad_sigma)

  def testPdfOfSampleMultiDims(self):
    student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
    self.assertAllEqual([], student.event_shape)
    self.assertAllEqual([], self.evaluate(student.event_shape_tensor()))
    self.assertAllEqual([2, 2], student.batch_shape)
    self.assertAllEqual([2, 2], self.evaluate(student.batch_shape_tensor()))
    num = 50000
    samples = student.sample(num, seed=123456)
    pdfs = student.prob(samples)
    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
    self.assertEqual(samples.get_shape(), (num, 2, 2))
    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
    self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=0.1)
    self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=0.1)
    self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.05)
    self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.05)
    self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.05)
    self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.05)
    if not stats:
      return
    self.assertNear(
        stats.t.var(7., loc=0., scale=3.),  # loc does not affect var
        np.var(sample_vals[:, :, 0]),
        err=1.0)
    self.assertNear(
        stats.t.var(11., loc=0., scale=3.),  # loc does not affect var
        np.var(sample_vals[:, :, 1]),
        err=1.0)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3):
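    """Trapezoid-rule integral of pdf over sorted samples; should be ~1."""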
    s_p = zip(sample_vals, pdf_vals)
    prev = (sample_vals.min() - 1000, 0)
    total = 0
    for k in sorted(s_p, key=lambda x: x[0]):
      pair_pdf = (k[1] + prev[1]) / 2
      total += (k[0] - prev[0]) * pair_pdf
      prev = k
    self.assertNear(1., total, err=err)

  def testNegativeDofFails(self):
    with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
      student = student_t.StudentT(
          df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
      self.evaluate(student.mean())

  def testStudentTWithAbsDfSoftplusScale(self):
    df = constant_op.constant([-3.2, -4.6])
    mu = constant_op.constant([-4.2, 3.4])
    sigma = constant_op.constant([-6.4, -8.8])
    student = student_t.StudentTWithAbsDfSoftplusScale(
        df=df, loc=mu, scale=sigma)
    self.assertAllClose(
        math_ops.floor(self.evaluate(math_ops.abs(df))),
        self.evaluate(student.df))
    self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))


if __name__ == "__main__":
  test.main()