# Owner(s): ["module: distributions"] import io from numbers import Number import pytest import torch from torch.autograd import grad from torch.autograd.functional import jacobian from torch.distributions import ( constraints, Dirichlet, Independent, Normal, TransformedDistribution, ) from torch.distributions.transforms import ( _InverseTransform, AbsTransform, AffineTransform, ComposeTransform, CorrCholeskyTransform, CumulativeDistributionTransform, ExpTransform, identity_transform, IndependentTransform, LowerCholeskyTransform, PositiveDefiniteTransform, PowerTransform, ReshapeTransform, SigmoidTransform, SoftmaxTransform, SoftplusTransform, StickBreakingTransform, TanhTransform, Transform, ) from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix from torch.testing._internal.common_utils import run_tests def get_transforms(cache_size): transforms = [ AbsTransform(cache_size=cache_size), ExpTransform(cache_size=cache_size), PowerTransform(exponent=2, cache_size=cache_size), PowerTransform(exponent=-2, cache_size=cache_size), PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size), PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size), SigmoidTransform(cache_size=cache_size), TanhTransform(cache_size=cache_size), AffineTransform(0, 1, cache_size=cache_size), AffineTransform(1, -2, cache_size=cache_size), AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size), AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size), SoftmaxTransform(cache_size=cache_size), SoftplusTransform(cache_size=cache_size), StickBreakingTransform(cache_size=cache_size), LowerCholeskyTransform(cache_size=cache_size), CorrCholeskyTransform(cache_size=cache_size), PositiveDefiniteTransform(cache_size=cache_size), ComposeTransform( [ AffineTransform( torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size ), ] ), ComposeTransform( [ AffineTransform( torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size ), ExpTransform(cache_size=cache_size), ] ), ComposeTransform( [ AffineTransform(0, 1, cache_size=cache_size), AffineTransform( torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size ), AffineTransform(1, -2, cache_size=cache_size), AffineTransform( torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size ), ] ), ReshapeTransform((4, 5), (2, 5, 2)), IndependentTransform( AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size), 1 ), CumulativeDistributionTransform(Normal(0, 1)), ] transforms += [t.inv for t in transforms] return transforms def reshape_transform(transform, shape): # Needed to squash batch dims for testing jacobian if isinstance(transform, AffineTransform): if isinstance(transform.loc, Number): return transform try: return AffineTransform( transform.loc.expand(shape), transform.scale.expand(shape), cache_size=transform._cache_size, ) except RuntimeError: return AffineTransform( transform.loc.reshape(shape), transform.scale.reshape(shape), cache_size=transform._cache_size, ) if isinstance(transform, ComposeTransform): reshaped_parts = [] for p in transform.parts: reshaped_parts.append(reshape_transform(p, shape)) return ComposeTransform(reshaped_parts, cache_size=transform._cache_size) if isinstance(transform.inv, AffineTransform): return reshape_transform(transform.inv, shape).inv if isinstance(transform.inv, ComposeTransform): return reshape_transform(transform.inv, shape).inv return transform # Generate pytest ids def transform_id(x): assert isinstance(x, Transform) name = ( 
f"Inv({type(x._inv).__name__})" if isinstance(x, _InverseTransform) else f"{type(x).__name__}" ) return f"{name}(cache_size={x._cache_size})" def generate_data(transform): torch.manual_seed(1) while isinstance(transform, IndependentTransform): transform = transform.base_transform if isinstance(transform, ReshapeTransform): return torch.randn(transform.in_shape) if isinstance(transform.inv, ReshapeTransform): return torch.randn(transform.inv.out_shape) domain = transform.domain while ( isinstance(domain, constraints.independent) and domain is not constraints.real_vector ): domain = domain.base_constraint codomain = transform.codomain x = torch.empty(4, 5) positive_definite_constraints = [ constraints.lower_cholesky, constraints.positive_definite, ] if domain in positive_definite_constraints: x = torch.randn(6, 6) x = x.tril(-1) + x.diag().exp().diag_embed() if domain is constraints.positive_definite: return x @ x.T return x elif codomain in positive_definite_constraints: return torch.randn(6, 6) elif domain is constraints.real: return x.normal_() elif domain is constraints.real_vector: # For corr_cholesky the last dim in the vector # must be of size (dim * dim) // 2 x = torch.empty(3, 6) x = x.normal_() return x elif domain is constraints.positive: return x.normal_().exp() elif domain is constraints.unit_interval: return x.uniform_() elif isinstance(domain, constraints.interval): x = x.uniform_() x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound) return x elif domain is constraints.simplex: x = x.normal_().exp() x /= x.sum(-1, True) return x elif domain is constraints.corr_cholesky: x = torch.empty(4, 5, 5) x = x.normal_().tril() x /= x.norm(dim=-1, keepdim=True) x.diagonal(dim1=-1).copy_(x.diagonal(dim1=-1).abs()) return x raise ValueError(f"Unsupported domain: {domain}") TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1) TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0) ALL_TRANSFORMS = ( TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform] ) @pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id) def test_inv_inv(transform, ids=transform_id): assert transform.inv.inv is transform @pytest.mark.parametrize("x", TRANSFORMS_CACHE_INACTIVE, ids=transform_id) @pytest.mark.parametrize("y", TRANSFORMS_CACHE_INACTIVE, ids=transform_id) def test_equality(x, y): if x is y: assert x == y else: assert x != y assert identity_transform == identity_transform.inv @pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id) def test_with_cache(transform): if transform._cache_size == 0: transform = transform.with_cache(1) assert transform._cache_size == 1 x = generate_data(transform).requires_grad_() try: y = transform(x) except NotImplementedError: pytest.skip("Not implemented.") y2 = transform(x) assert y2 is y @pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id) @pytest.mark.parametrize("test_cached", [True, False]) def test_forward_inverse(transform, test_cached): x = generate_data(transform).requires_grad_() assert transform.domain.check(x).all() # verify that the input data are valid try: y = transform(x) except NotImplementedError: pytest.skip("Not implemented.") assert y.shape == transform.forward_shape(x.shape) if test_cached: x2 = transform.inv(y) # should be implemented at least by caching else: try: x2 = transform.inv(y.clone()) # bypass cache except NotImplementedError: pytest.skip("Not implemented.") assert x2.shape == transform.inverse_shape(y.shape) y2 = transform(x2) if 
    if transform.bijective:
        # verify function inverse
        assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), "\n".join(
            [
                f"{transform} t.inv(t(-)) error",
                f"x = {x}",
                f"y = t(x) = {y}",
                f"x2 = t.inv(y) = {x2}",
            ]
        )
    else:
        # verify weaker function pseudo-inverse
        assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), "\n".join(
            [
                f"{transform} t(t.inv(t(-))) error",
                f"x = {x}",
                f"y = t(x) = {y}",
                f"x2 = t.inv(y) = {x2}",
                f"y2 = t(x2) = {y2}",
            ]
        )


def test_compose_transform_shapes():
    transform0 = ExpTransform()
    transform1 = SoftmaxTransform()
    transform2 = LowerCholeskyTransform()

    assert transform0.event_dim == 0
    assert transform1.event_dim == 1
    assert transform2.event_dim == 2
    assert ComposeTransform([transform0, transform1]).event_dim == 1
    assert ComposeTransform([transform0, transform2]).event_dim == 2
    assert ComposeTransform([transform1, transform2]).event_dim == 2


transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4))
base_dist1 = Dirichlet(torch.ones(4, 4))
base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))


@pytest.mark.parametrize(
    ("batch_shape", "event_shape", "dist"),
    [
        ((4, 4), (), base_dist0),
        ((4,), (4,), base_dist1),
        ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
        ((3, 4, 4), (), base_dist2),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
    ],
)
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == event_shape
    x = dist.rsample()
    try:
        dist.log_prob(x)  # this should not crash
    except NotImplementedError:
        pytest.skip("Not implemented.")


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_fwd(transform):
    x = generate_data(transform).requires_grad_()

    def f(x):
        return transform(x)

    try:
        traced_f = torch.jit.trace(f, (x,))
    except NotImplementedError:
        pytest.skip("Not implemented.")
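    # Re-evaluate on fresh data so the comparison catches traces that baked in
    # values from the example input.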
    # check on different inputs
    x = generate_data(transform).requires_grad_()
    assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_inv(transform):
    y = generate_data(transform.inv).requires_grad_()

    def f(y):
        return transform.inv(y)

    try:
        traced_f = torch.jit.trace(f, (y,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # check on different inputs
    y = generate_data(transform.inv).requires_grad_()
    assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_jacobian(transform):
    x = generate_data(transform).requires_grad_()

    def f(x):
        y = transform(x)
        return transform.log_abs_det_jacobian(x, y)

    try:
        traced_f = torch.jit.trace(f, (x,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # check on different inputs
    x = generate_data(transform).requires_grad_()
    assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_jacobian(transform):
    x = generate_data(transform)
    try:
        y = transform(x)
        actual = transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        pytest.skip("Not implemented.")
    # Test shape
    target_shape = x.shape[: x.dim() - transform.domain.event_dim]
    assert actual.shape == target_shape

    # Expand if required
    transform = reshape_transform(transform, x.shape)
    ndims = len(x.shape)
    event_dim = ndims - transform.domain.event_dim
    x_ = x.view((-1,) + x.shape[event_dim:])
    n = x_.shape[0]
    # Reshape to squash batch dims to a single batch dim
    transform = reshape_transform(transform, x_.shape)

    # 1. Transforms with unit jacobian
    if isinstance(transform, ReshapeTransform) or isinstance(
        transform.inv, ReshapeTransform
    ):
        expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
    # 2. Transforms with 0 off-diagonal elements
    elif transform.domain.event_dim == 0:
        jac = jacobian(transform, x_)
        # assert off-diagonal elements are zero
        assert torch.allclose(jac, jac.diagonal().diag_embed())
        expected = jac.diagonal().abs().log().reshape(x.shape)
    # 3. Transforms with non-0 off-diagonal elements
    else:
        if isinstance(transform, CorrCholeskyTransform):
            jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
        elif isinstance(transform.inv, CorrCholeskyTransform):
            jac = jacobian(
                lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
                tril_matrix_to_vec(x_, diag=-1),
            )
        elif isinstance(transform, StickBreakingTransform):
            jac = jacobian(lambda x: transform(x)[..., :-1], x_)
        else:
            jac = jacobian(transform, x_)

        # Note that jacobian will have shape (batch_dims, y_event_dims, batch_dims, x_event_dims)
        # However, batches are independent so this can be converted into a
        # (batch_dims, event_dims, event_dims) after reshaping the event dims (see above)
        # to give a batched square matrix whose determinant can be computed.
        gather_idx_shape = list(jac.shape)
        gather_idx_shape[-2] = 1
        gather_idxs = (
            torch.arange(n)
            .reshape((n,) + (1,) * (len(jac.shape) - 1))
            .expand(gather_idx_shape)
        )
        jac = jac.gather(-2, gather_idxs).squeeze(-2)
        out_ndims = jac.shape[-2]
        jac = jac[
            ..., :out_ndims
        ]  # Remove extra zero-valued dims (for inverse stick-breaking).
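        # jac is now a batch of square matrices (one per squashed batch element),
        # so log|det J| can be read off directly with slogdet.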
        expected = torch.slogdet(jac).logabsdet

    assert torch.allclose(actual, expected, atol=1e-5)


@pytest.mark.parametrize(
    "event_dims", [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)], ids=str
)
def test_compose_affine(event_dims):
    transforms = [
        AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims
    ]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, *event_dims)


@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
    transforms = [
        ReshapeTransform((), ()),
        ReshapeTransform((2,), (1, 2)),
        ReshapeTransform((3, 1, 2), (6,)),
        ReshapeTransform((6,), (2, 3)),
    ]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == 2
    assert transform.domain.event_dim == 2
    data = torch.randn(batch_shape + (3, 2))
    assert transform(data).shape == batch_shape + (2, 3)

    dist = TransformedDistribution(Normal(data, 1), transforms)
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == (2, 3)
    assert dist.support.event_dim == 2


@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
@pytest.mark.parametrize("transform_dim", [0, 1, 2])
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(
    base_batch_dim, base_event_dim, transform_dim, num_transforms, sample_shape
):
    shape = torch.Size([2, 3, 4, 5])
    base_dist = Normal(0, 1)
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim :])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [
        AffineTransform(torch.zeros(shape[4 - transform_dim :]), 1),
        ReshapeTransform((4, 5), (20,)),
        ReshapeTransform((3, 20), (6, 10)),
    ]
    transforms = transforms[:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_batch_dim + base_event_dim < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    num_unique = len(set(x.reshape(-1).tolist()))
    assert num_unique >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples.
    y = x
    while y.dim() > len(d.event_shape):
        y = y[0]
    log_prob = d.log_prob(y)
    assert log_prob.shape == d.batch_shape

def test_save_load_transform():
    # Evaluating `log_prob` will create a weakref `_inv` which cannot be pickled. Here, we check
    # that `__getstate__` correctly handles the weakref, and that we can evaluate the density after.
    dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
    x = torch.linspace(0, 1, 10)
    log_prob = dist.log_prob(x)

    stream = io.BytesIO()
    torch.save(dist, stream)
    stream.seek(0)
    other = torch.load(stream)

    assert torch.allclose(log_prob, other.log_prob(x))


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_transform_sign(transform: Transform):
    try:
        sign = transform.sign
    except NotImplementedError:
        pytest.skip("Not implemented.")

    x = generate_data(transform).requires_grad_()
    y = transform(x).sum()
    (derivatives,) = grad(y, [x])
    assert torch.less(torch.as_tensor(0.0), derivatives * sign).all()


if __name__ == "__main__":
    run_tests()