# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The Evidence Lower Bound (ELBO)."""
from mindspore.ops import operations as P
from ...distribution.normal import Normal
from ....cell import Cell
from ....loss.loss import MSELoss


class ELBO(Cell):
    r"""
    The Evidence Lower Bound (ELBO).

    Variational inference minimizes the Kullback-Leibler (KL) divergence from the variational distribution to
    the posterior distribution. It maximizes the ELBO, a lower bound on the logarithm of
    the marginal probability of the observations log p(x). The ELBO is equal to the negative KL divergence up to
    an additive constant.
    For more details, refer to `Variational Inference: A Review for Statisticians <https://arxiv.org/abs/1601.00670>`_.

    Args:
        latent_prior(str): The prior distribution of latent space. Default: Normal.

            - Normal: The prior distribution of latent space is Normal.

        output_prior(str): The distribution of output data. Default: Normal.

            - Normal: If the distribution of output data is Normal, the reconstruct loss is MSELoss.

    Inputs:
        - **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
        - **target_data** (Tensor) - the target tensor of shape :math:`(N,)`.

    Outputs:
        Tensor, loss float tensor.

    Raises:
        ValueError: If `latent_prior` or `output_prior` is not 'Normal'.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self, latent_prior='Normal', output_prior='Normal'):
        super(ELBO, self).__init__()
        # ReduceSum collapses the per-element KL term to a scalar; ZerosLike
        # builds the standard-normal prior parameters matching mu's shape.
        self.sum = P.ReduceSum()
        self.zeros = P.ZerosLike()
        if latent_prior == 'Normal':
            self.posterior = Normal()
        else:
            raise ValueError('The values of latent_prior now only support Normal')
        if output_prior == 'Normal':
            # Gaussian output likelihood <=> (summed) MSE reconstruction loss.
            self.recon_loss = MSELoss(reduction='sum')
        else:
            # Fix: the message previously referred to 'output_dis', which is not
            # the name of any parameter; it validates `output_prior`.
            raise ValueError('The values of output_prior now only support Normal')

    def construct(self, data, label):
        """Compute the (negative) ELBO loss from (recon_x, x, mu, std).

        `label` is part of the loss-cell interface but is unused here: the
        reconstruction target `x` is carried inside `data`.
        """
        recon_x, x, mu, std = data
        # MSE is symmetric in its arguments, so (x, recon_x) ordering is safe.
        reconstruct_loss = self.recon_loss(x, recon_x)
        # KL(q(z|x) || N(0, 1)): prior has mean zeros(mu) and std zeros(mu)+1.
        kl_loss = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu) + 1, mu, std)
        # Returned value is reconstruction loss plus KL, i.e. the negative ELBO,
        # suitable for direct minimization by an optimizer.
        elbo = reconstruct_loss + self.sum(kl_loss)
        return elbo