• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""The Evidence Lower Bound (ELBO)."""
16from mindspore.ops import operations as P
17from ...distribution.normal import Normal
18from ....cell import Cell
19from ....loss.loss import MSELoss
20
21
class ELBO(Cell):
    r"""
    The Evidence Lower Bound (ELBO).

    Variational inference minimizes the Kullback-Leibler (KL) divergence from the variational distribution to
    the posterior distribution. It maximizes the ELBO, a lower bound on the logarithm of
    the marginal probability of the observations log p(x). The ELBO is equal to the negative KL divergence up to
    an additive constant.
    For more details, refer to `Variational Inference: A Review for Statisticians <https://arxiv.org/abs/1601.00670>`_.

    The value returned by this cell is the sum of a reconstruction loss and a summed KL term, i.e. a quantity
    intended to be used as a loss (presumably the negative ELBO up to constants; confirm against the training code).

    Args:
        latent_prior(str): The prior distribution of latent space. Default: Normal.

            - Normal: The prior distribution of latent space is Normal.

        output_prior(str): The distribution of output data. Default: Normal.

            - Normal: If the distribution of output data is Normal, the reconstruction loss is MSELoss
              with ``reduction='sum'``.

    Inputs:
        - **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
          Passed to `construct` as `data`.
        - **target_data** (Tensor) - the target tensor of shape :math:`(N,)`.
          Passed to `construct` as `label`; it is accepted but not used in the computation.

    Outputs:
        Tensor, loss float tensor.

    Raises:
        ValueError: If `latent_prior` is not 'Normal'.
        ValueError: If `output_prior` is not 'Normal'.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self, latent_prior='Normal', output_prior='Normal'):
        super(ELBO, self).__init__()
        self.sum = P.ReduceSum()
        self.zeros = P.ZerosLike()
        # Only the Normal family is implemented for both priors; anything else is rejected up front.
        if latent_prior == 'Normal':
            self.posterior = Normal()
        else:
            raise ValueError('The values of latent_prior now only support Normal')
        if output_prior == 'Normal':
            self.recon_loss = MSELoss(reduction='sum')
        else:
            # The message names the actual parameter (`output_prior`) so the caller knows what to fix.
            raise ValueError('The values of output_prior now only support Normal')

    def construct(self, data, label):
        """Compute the loss from (recon_x, x, mu, std); `label` is accepted but unused."""
        recon_x, x, mu, std = data
        reconstruct_loss = self.recon_loss(x, recon_x)
        # Zeros and ones shaped like `mu` are passed together with (mu, std) to the Normal 'kl_loss' entry;
        # presumably they describe a standard Normal(0, 1) reference distribution (confirm against Normal).
        prior_mean = self.zeros(mu)
        prior_std = self.zeros(mu) + 1
        kl_loss = self.posterior('kl_loss', 'Normal', prior_mean, prior_std, mu, std)
        # The KL term is reduced to a scalar before being added to the summed reconstruction loss.
        elbo = reconstruct_loss + self.sum(kl_loss)
        return elbo
71