• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16import importlib
17
18import numpy as np
19
20from tensorflow.python.eager import backprop
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import nn_ops
25from tensorflow.python.ops.distributions import laplace as laplace_lib
26from tensorflow.python.platform import test
27
28from tensorflow.python.platform import tf_logging
29
30
def try_import(name):  # pylint: disable=invalid-name
  """Imports and returns the module `name`, or None if it cannot be imported.

  Lets optional dependencies (such as scipy) be soft requirements: callers
  check the result for None before using it.

  Args:
    name: Fully qualified module name, e.g. "scipy.stats".

  Returns:
    The imported module, or None if `importlib.import_module` raised
    ImportError (a warning is logged in that case).
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Hand the arguments to the logger instead of pre-formatting the message,
    # so formatting is deferred to the logging framework.
    tf_logging.warning("Could not import %s: %s", name, e)
    return None
38
39
# Optional dependency: `stats` is None when scipy cannot be imported, in which
# case tests skip their comparisons against scipy reference values.
stats = try_import(name="scipy.stats")
41
42
@test_util.run_all_in_graph_and_eager_modes
class LaplaceTest(test.TestCase):
  """Tests for the Laplace distribution in `laplace_lib`.

  When scipy is importable, values are compared against `scipy.stats.laplace`.
  Otherwise each scipy-dependent test returns early after its shape checks;
  tests that do not need scipy (mode, integrals, softplus) always run fully.
  """

  def testLaplaceShape(self):
    loc = constant_op.constant([3.0] * 5)
    scale = constant_op.constant(11.0)
    laplace = laplace_lib.Laplace(loc=loc, scale=scale)

    # Compare element-wise (as for `event_shape_tensor` below) rather than
    # relying on the truthiness of an array-vs-tuple `==`.
    self.assertAllEqual(self.evaluate(laplace.batch_shape_tensor()), (5,))
    self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
    self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
    self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))

  def testLaplaceLogPDF(self):
    batch_size = 6
    loc = constant_op.constant([2.0] * batch_size)
    scale = constant_op.constant([3.0] * batch_size)
    loc_v = 2.0
    scale_v = 3.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
    log_pdf = laplace.log_prob(x)
    self.assertEqual(log_pdf.get_shape(), (batch_size,))
    # The shape of `prob` does not depend on scipy, so check it before the
    # scipy guard.
    pdf = laplace.prob(x)
    self.assertEqual(pdf.get_shape(), (batch_size,))
    if not stats:
      return
    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))

  def testLaplaceLogPDFMultidimensional(self):
    batch_size = 6
    loc = constant_op.constant([[2.0, 4.0]] * batch_size)
    scale = constant_op.constant([[3.0, 4.0]] * batch_size)
    loc_v = np.array([2.0, 4.0])
    scale_v = np.array([3.0, 4.0])
    # x has shape (batch_size, 1) and broadcasts against the (batch_size, 2)
    # batch of distributions.
    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
    log_pdf = laplace.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (batch_size, 2))

    pdf = laplace.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (batch_size, 2))
    if not stats:
      return
    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
    self.assertAllClose(log_pdf_values, expected_log_pdf)
    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))

  def testLaplaceLogPDFMultidimensionalBroadcasting(self):
    batch_size = 6
    loc = constant_op.constant([[2.0, 4.0]] * batch_size)
    # A scalar scale broadcasts against the (batch_size, 2) loc.
    scale = constant_op.constant(3.0)
    loc_v = np.array([2.0, 4.0])
    scale_v = 3.0
    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
    log_pdf = laplace.log_prob(x)
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (batch_size, 2))

    pdf = laplace.prob(x)
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (batch_size, 2))
    if not stats:
      return
    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
    self.assertAllClose(log_pdf_values, expected_log_pdf)
    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))

  def testLaplaceCDF(self):
    batch_size = 6
    loc = constant_op.constant([2.0] * batch_size)
    scale = constant_op.constant([3.0] * batch_size)
    loc_v = 2.0
    scale_v = 3.0
    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

    laplace = laplace_lib.Laplace(loc=loc, scale=scale)

    cdf = laplace.cdf(x)
    self.assertEqual(cdf.get_shape(), (batch_size,))
    if not stats:
      return
    expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(cdf), expected_cdf)

  def testLaplaceLogCDF(self):
    batch_size = 6
    loc = constant_op.constant([2.0] * batch_size)
    scale = constant_op.constant([3.0] * batch_size)
    loc_v = 2.0
    scale_v = 3.0
    x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)

    laplace = laplace_lib.Laplace(loc=loc, scale=scale)

    cdf = laplace.log_cdf(x)
    self.assertEqual(cdf.get_shape(), (batch_size,))
    if not stats:
      return
    expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(cdf), expected_cdf)

  def testLaplaceLogSurvivalFunction(self):
    batch_size = 6
    loc = constant_op.constant([2.0] * batch_size)
    scale = constant_op.constant([3.0] * batch_size)
    loc_v = 2.0
    scale_v = 3.0
    x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)

    laplace = laplace_lib.Laplace(loc=loc, scale=scale)

    sf = laplace.log_survival_function(x)
    self.assertEqual(sf.get_shape(), (batch_size,))
    if not stats:
      return
    expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(sf), expected_sf)

  def testLaplaceMean(self):
    loc_v = np.array([1.0, 3.0, 2.5])
    scale_v = np.array([1.0, 4.0, 5.0])
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    self.assertEqual(laplace.mean().get_shape(), (3,))
    if not stats:
      return
    expected_means = stats.laplace.mean(loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(laplace.mean()), expected_means)

  def testLaplaceMode(self):
    loc_v = np.array([0.5, 3.0, 2.5])
    scale_v = np.array([1.0, 4.0, 5.0])
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    self.assertEqual(laplace.mode().get_shape(), (3,))
    # The mode equals `loc`, so no scipy reference is needed.
    self.assertAllClose(self.evaluate(laplace.mode()), loc_v)

  def testLaplaceVariance(self):
    loc_v = np.array([1.0, 3.0, 2.5])
    scale_v = np.array([1.0, 4.0, 5.0])
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    self.assertEqual(laplace.variance().get_shape(), (3,))
    if not stats:
      return
    expected_variances = stats.laplace.var(loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)

  def testLaplaceStd(self):
    loc_v = np.array([1.0, 3.0, 2.5])
    scale_v = np.array([1.0, 4.0, 5.0])
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    self.assertEqual(laplace.stddev().get_shape(), (3,))
    if not stats:
      return
    expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)

  def testLaplaceEntropy(self):
    loc_v = np.array([1.0, 3.0, 2.5])
    scale_v = np.array([1.0, 4.0, 5.0])
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    self.assertEqual(laplace.entropy().get_shape(), (3,))
    if not stats:
      return
    expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
    self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)

  def testLaplaceSample(self):
    loc_v = 4.0
    scale_v = 3.0
    loc = constant_op.constant(loc_v)
    scale = constant_op.constant(scale_v)
    n = 100000
    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
    samples = laplace.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (n,))
    self.assertEqual(sample_values.shape, (n,))
    if not stats:
      return
    # Sample moments must be within 5% (relative) of the true moments.
    self.assertAllClose(
        sample_values.mean(),
        stats.laplace.mean(loc_v, scale=scale_v),
        rtol=0.05,
        atol=0.)
    self.assertAllClose(
        sample_values.var(),
        stats.laplace.var(loc_v, scale=scale_v),
        rtol=0.05,
        atol=0.)
    self.assertTrue(self._kstest(loc_v, scale_v, sample_values))

  def testLaplaceFullyReparameterized(self):
    loc = constant_op.constant(4.0)
    scale = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(loc)
      tape.watch(scale)
      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
      samples = laplace.sample(100)
    # Gradients must flow from the samples back to both parameters.
    grad_loc, grad_scale = tape.gradient(samples, [loc, scale])
    self.assertIsNotNone(grad_loc)
    self.assertIsNotNone(grad_scale)

  def testLaplaceSampleMultiDimensional(self):
    loc_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
    scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
    # loc and scale broadcast to a 10 x 100 batch of distributions.
    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
    n = 10000
    samples = laplace.sample(n, seed=137)
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (n, 10, 100))
    self.assertEqual(sample_values.shape, (n, 10, 100))
    zeros = np.zeros_like(loc_v + scale_v)  # 10 x 100
    loc_bc = loc_v + zeros
    scale_bc = scale_v + zeros
    if not stats:
      return
    self.assertAllClose(
        sample_values.mean(axis=0),
        stats.laplace.mean(loc_bc, scale=scale_bc),
        rtol=0.35,
        atol=0.)
    self.assertAllClose(
        sample_values.var(axis=0),
        stats.laplace.var(loc_bc, scale=scale_bc),
        rtol=0.105,
        atol=0.0)
    # Run a KS test for every batch member; allow fewer than 3% to fail.
    fails = 0
    trials = 0
    for ai, a in enumerate(np.reshape(loc_v, [-1])):
      for bi, b in enumerate(np.reshape(scale_v, [-1])):
        s = sample_values[:, bi, ai]
        trials += 1
        fails += 0 if self._kstest(a, b, s) else 1
    self.assertLess(fails, trials * 0.03)

  def _kstest(self, loc, scale, samples):
    """Returns True if `samples` pass a Kolmogorov-Smirnov goodness-of-fit test.

    The KS statistic against `scipy.stats.laplace(loc, scale=scale)` must be
    below 0.02. Returns True unconditionally when scipy is unavailable.
    """
    if not stats:
      return True  # If scipy isn't available, return "True" for passing
    ks, _ = stats.kstest(samples, stats.laplace(loc, scale=scale).cdf)
    return ks < 0.02

  def testLaplacePdfOfSampleMultiDims(self):
    # loc has shape (2,) and scale (2, 1); they broadcast to a 2 x 2 batch.
    laplace = laplace_lib.Laplace(loc=[7., 11.], scale=[[5.], [6.]])
    num = 50000
    samples = laplace.sample(num, seed=137)
    pdfs = laplace.prob(samples)
    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
    self.assertEqual(samples.get_shape(), (num, 2, 2))
    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
    # The pdf of each batch member must integrate to ~1 over its samples.
    for i, j in np.ndindex(*sample_vals.shape[1:]):
      self._assertIntegral(sample_vals[:, i, j], pdf_vals[:, i, j], err=0.02)
    if not stats:
      return
    self.assertAllClose(
        stats.laplace.mean(
            [[7., 11.], [7., 11.]], scale=np.array([[5., 5.], [6., 6.]])),
        sample_vals.mean(axis=0),
        rtol=0.05,
        atol=0.)
    self.assertAllClose(
        stats.laplace.var([[7., 11.], [7., 11.]],
                          scale=np.array([[5., 5.], [6., 6.]])),
        sample_vals.var(axis=0),
        rtol=0.05,
        atol=0.)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
    """Asserts that the trapezoid-rule integral of the pdf is close to 1.

    Args:
      sample_vals: 1-D array of sample locations (any order).
      pdf_vals: 1-D array of pdf values at `sample_vals`, same length.
      err: Absolute tolerance on the difference between the integral and 1.
    """
    sample_vals = np.asarray(sample_vals, dtype=np.float64)
    pdf_vals = np.asarray(pdf_vals, dtype=np.float64)
    order = np.argsort(sample_vals, kind="stable")
    xs = sample_vals[order]
    ys = pdf_vals[order]
    # Integrate only over the observed sample range. Anchoring the first
    # trapezoid at (0, 0) would add a spurious segment from the origin to the
    # smallest sample, whose area grows with the data's distance from 0.
    total = np.sum(np.diff(xs) * (ys[1:] + ys[:-1]) / 2.)
    self.assertNear(1., total, err=err)

  def testLaplaceNonPositiveInitializationParamsRaises(self):
    # With validate_args=True, a negative or zero scale must be rejected.
    for loc, scale in [(0.0, -1.0), (1.0, 0.0)]:
      loc_v = constant_op.constant(loc, name="loc")
      scale_v = constant_op.constant(scale, name="scale")
      with self.assertRaisesOpError(
          "Condition x > 0 did not hold element-wise"):
        laplace = laplace_lib.Laplace(
            loc=loc_v, scale=scale_v, validate_args=True)
        self.evaluate(laplace.mean())

  def testLaplaceWithSoftplusScale(self):
    loc_v = constant_op.constant([0.0, 1.0], name="loc")
    scale_v = constant_op.constant([-1.0, 2.0], name="scale")
    laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
    # The stored scale is softplus of the raw (possibly negative) input; loc is
    # passed through unchanged.
    self.assertAllClose(
        self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
    self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))
351
352
if __name__ == "__main__":
  # Run the tests in this module through TensorFlow's test entry point.
  test.main()
355