# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Conditional Variational auto-encoder (CVAE)."""
16from mindspore.ops import composite as C
17from mindspore.ops import operations as P
18from mindspore._checkparam import Validator
19from ....cell import Cell
20from ....layer.basic import Dense, OneHot
21
22
23class ConditionalVAE(Cell):
    r"""
    Conditional Variational Auto-Encoder (CVAE).

    The difference from VAE is that CVAE uses label information.
    For more details, refer to `Learning Structured Output Representation using Deep Conditional Generative Models
    <http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-
    generative-models>`_.

    Note:
        When the encoder and decoder are defined, the shape of the encoder's output tensor and the decoder's input
        tensor must be :math:`(N, hidden\_size)`.
        The latent_size must be less than or equal to the hidden_size.

    Args:
        encoder (Cell): The Deep Neural Network (DNN) model defined as the encoder.
        decoder (Cell): The DNN model defined as the decoder.
        hidden_size (int): The size of the encoder's output tensor.
        latent_size (int): The size of the latent space.
        num_classes (int): The number of classes.

    Inputs:
        - **input_x** (Tensor) - The shape of the input tensor is :math:`(N, C, H, W)`, which is the same as the
          input of the encoder.
        - **input_y** (Tensor) - The tensor of the target data, the shape is :math:`(N,)`.

    Outputs:
        - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).

    Supported Platforms:
        ``Ascend`` ``GPU``
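
    Examples:
        A minimal usage sketch; the encoder, decoder, and import path below are
        illustrative assumptions, not a prescribed setup:

        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore.nn.probability.dnn import ConditionalVAE
        >>> class Encoder(nn.Cell):
        ...     def __init__(self):
        ...         super(Encoder, self).__init__()
        ...         self.flatten = nn.Flatten()
        ...         self.fc = nn.Dense(32 * 32, 400)
        ...         self.relu = nn.ReLU()
        ...     def construct(self, x, y):
        ...         return self.relu(self.fc(self.flatten(x)))
        >>> class Decoder(nn.Cell):
        ...     def __init__(self):
        ...         super(Decoder, self).__init__()
        ...         self.fc = nn.Dense(400, 32 * 32)
        ...         self.sigmoid = nn.Sigmoid()
        ...     def construct(self, z):
        ...         return self.sigmoid(self.fc(z))
        >>> cvae = ConditionalVAE(Encoder(), Decoder(), hidden_size=400, latent_size=20, num_classes=10)
        >>> x = Tensor(np.random.rand(4, 1, 32, 32), mindspore.float32)
        >>> y = Tensor(np.random.randint(0, 10, (4,)), mindspore.int32)
        >>> recon_x, _, mu, std = cvae(x, y)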
    """

    def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
            raise TypeError('The encoder and decoder should be Cell type.')
        self.hidden_size = Validator.check_positive_int(hidden_size)
        self.latent_size = Validator.check_positive_int(latent_size)
        if hidden_size < latent_size:
            raise ValueError('The latent_size should be less than or equal to the hidden_size.')
        self.num_classes = Validator.check_positive_int(num_classes)
        self.normal = C.normal
        self.exp = P.Exp()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.concat = P.Concat(axis=1)
        self.to_tensor = P.ScalarToArray()
        self.one_hot = OneHot(depth=num_classes)
        # dense1 and dense2 map the encoder output to the mean and log-variance of the
        # latent distribution; dense3 maps the label-conditioned latent code back to the
        # hidden size expected by the decoder.
        self.dense1 = Dense(self.hidden_size, self.latent_size)
        self.dense2 = Dense(self.hidden_size, self.latent_size)
        self.dense3 = Dense(self.latent_size + self.num_classes, self.hidden_size)

    def _encode(self, x, y):
        # The encoder consumes both the data and its label; its output is projected to
        # the latent mean and log-variance.
        en_x = self.encoder(x, y)
        mu = self.dense1(en_x)
        log_var = self.dense2(en_x)
        return mu, log_var

    def _decode(self, z):
        z = self.dense3(z)
        recon_x = self.decoder(z)
        return recon_x

    def construct(self, x, y):
        """
        The inputs are x and y, so the standard WithLossCell wrapper needs to be rewritten (to pass both inputs)
        when using the CVAE interface.
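
        A hedged sketch of such a loss wrapper (the class name and the loss_fn signature
        are illustrative assumptions, not part of this module):

        >>> class CVAEWithLossCell(nn.Cell):
        ...     def __init__(self, backbone, loss_fn):
        ...         super(CVAEWithLossCell, self).__init__()
        ...         self.backbone = backbone
        ...         self.loss_fn = loss_fn
        ...     def construct(self, x, y):
        ...         recon_x, x, mu, std = self.backbone(x, y)
        ...         return self.loss_fn(recon_x, x, mu, std)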
        """
        mu, log_var = self._encode(x, y)
        # reparameterization trick: z = mu + std * eps with eps ~ N(0, I)
        std = self.exp(0.5 * log_var)
        z = self.normal(self.shape(mu), mu, std, seed=0)
        # condition the latent code on the one-hot encoded label before decoding
        y = self.one_hot(y)
        z_c = self.concat((z, y))
        recon_x = self._decode(z_c)
        return recon_x, x, mu, std

    def generate_sample(self, sample_y, generate_nums, shape):
        """
        Randomly sample from the latent space to generate samples.

        Args:
            sample_y (Tensor): The labels of the samples to generate, with shape (generate_nums,) and type
                mindspore.int32.
            generate_nums (int): The number of samples to generate.
            shape (tuple): The shape of the samples, which must be (generate_nums, C, H, W) or (-1, C, H, W).

        Returns:
            Tensor, the generated samples.
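
        A hedged example, assuming the ``cvae`` instance and imports from the class-level
        example above (labels and shapes are illustrative):

        >>> sample_y = Tensor(np.random.randint(0, 10, (64,)), mindspore.int32)
        >>> samples = cvae.generate_sample(sample_y, 64, (64, 1, 32, 32))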
        """
        generate_nums = Validator.check_positive_int(generate_nums)
        if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
            raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
        # draw z from the standard normal prior, then condition it on the one-hot encoded label
        sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
        sample_y = self.one_hot(sample_y)
        sample_c = self.concat((sample_z, sample_y))
        sample = self._decode(sample_c)
        sample = self.reshape(sample, shape)
        return sample

    def reconstruct_sample(self, x, y):
        """
        Reconstruct samples from original data.

        Args:
            x (Tensor): The input tensor to be reconstructed, with shape (N, C, H, W).
            y (Tensor): The label of the input tensor, with shape (N,).

        Returns:
            Tensor, the reconstructed sample.
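
        A hedged example, reusing the ``cvae`` instance and the batch ``x`` with labels ``y``
        from the class-level example above:

        >>> recon_x = cvae.reconstruct_sample(x, y)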
        """
        mu, log_var = self._encode(x, y)
        std = self.exp(0.5 * log_var)
        z = self.normal(mu.shape, mu, std, seed=0)
        y = self.one_hot(y)
        z_c = self.concat((z, y))
        recon_x = self._decode(z_c)
        return recon_x