# mypy: allow-untyped-defs
import torch
from torch.distributions import constraints
from torch.distributions.categorical import Categorical
from torch.distributions.distribution import Distribution
from torch.types import _size


__all__ = ["OneHotCategorical", "OneHotCategoricalStraightThrough"]


class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
              will return this normalized value.
              The `logits` argument will be interpreted as unnormalized log probabilities
              and can therefore be any real number. It will likewise be normalized so that
              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
              will return this normalized value.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0.,  0.,  0.,  1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
        validate_args (bool, optional): whether to validate arguments; forwarded
            both to the wrapped :class:`Categorical` and to the base
            :class:`Distribution`. ``None`` keeps the library default.
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.one_hot
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        # Forward ``validate_args`` so that ``validate_args=False`` also
        # disables validation inside the wrapped Categorical; otherwise the
        # inner distribution would still validate under the global default.
        self._categorical = Categorical(probs, logits, validate_args=validate_args)
        batch_shape = self._categorical.batch_shape
        # The event is the trailing (category) dimension of the parameter.
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new distribution with its batch dimensions expanded to ``batch_shape``.

        Args:
            batch_shape: the desired expanded batch shape.
            _instance: optional pre-allocated instance to fill in (used by
                subclasses that override ``expand``).

        Returns:
            A distribution instance wrapping the expanded categorical.
        """
        new = self._get_checked_instance(OneHotCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new._categorical = self._categorical.expand(batch_shape)
        # Skip validation while initialising the base class, then carry over
        # this instance's validation flag.
        super(OneHotCategorical, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Delegate tensor creation to the wrapped categorical's ``_new``."""
        return self._categorical._new(*args, **kwargs)

    @property
    def _param(self):
        """The wrapped categorical's ``_param`` tensor."""
        return self._categorical._param

    @property
    def probs(self):
        """Event probabilities, as exposed by the wrapped categorical."""
        return self._categorical.probs

    @property
    def logits(self):
        """Event log probabilities, as exposed by the wrapped categorical."""
        return self._categorical.logits

    @property
    def mean(self):
        """Mean of the one-hot vector, i.e. the event probabilities."""
        return self._categorical.probs

    @property
    def mode(self):
        """One-hot encoding of the most probable category, cast to match ``probs``."""
        probs = self._categorical.probs
        mode = probs.argmax(dim=-1)
        return torch.nn.functional.one_hot(mode, num_classes=probs.shape[-1]).to(probs)

    @property
    def variance(self):
        """Element-wise variance ``probs * (1 - probs)``."""
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        """Shape of the wrapped categorical's parameter tensor."""
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Draw one-hot samples.

        Args:
            sample_shape: leading shape of the samples to draw.

        Returns:
            One-hot tensor built from category indices sampled by the wrapped
            categorical, cast to match ``probs``.
        """
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        num_events = self._categorical._num_events
        indices = self._categorical.sample(sample_shape)
        # ``one_hot`` yields an integer tensor; ``.to(probs)`` converts it to
        # the dtype/device of ``probs``.
        return torch.nn.functional.one_hot(indices, num_events).to(probs)

    def log_prob(self, value):
        """Return the log probability of ``value``.

        Args:
            value: tensor whose last dimension indexes categories; validated
                against the support when argument validation is enabled.

        Returns:
            Log probability of the category at the position of the maximum
            along the last dimension of ``value``.
        """
        if self._validate_args:
            self._validate_sample(value)
        # Position of the maximum along the last dim (the ``1`` for one-hot input).
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        """Return the entropy of the wrapped categorical."""
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        """Return every one-hot vector in the support.

        Args:
            expand: if true, expand the result over the batch dimensions.

        Returns:
            Tensor of shape ``(n, *batch_shape, n)`` when ``expand`` is true,
            otherwise ``(n, 1, ..., 1, n)`` with one singleton per batch
            dimension, where ``n = event_shape[0]``.
        """
        n = self.event_shape[0]
        values = torch.eye(n, dtype=self._param.dtype, device=self._param.device)
        # Insert singleton batch dimensions between the enumeration dimension
        # and the event dimension.
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        if expand:
            values = values.expand((n,) + self.batch_shape + (n,))
        return values


class OneHotCategoricalStraightThrough(OneHotCategorical):
    r"""
    Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight-
    through gradient estimator from [1].

    [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
    (Bengio et al., 2013)
    """

    has_rsample = True

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw one-hot samples that carry a gradient with respect to ``probs``.

        Args:
            sample_shape: leading shape of the samples to draw.

        Returns:
            ``sample(sample_shape)`` plus ``probs - probs.detach()``: the added
            term is zero in value, so the result equals the sample while
            gradients flow through ``probs``.
        """
        samples = self.sample(sample_shape)
        probs = self._categorical.probs  # cached via @lazy_property
        return samples + (probs - probs.detach())