# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Conditional Variational auto-encoder (CVAE)."""
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore._checkparam import Validator
from ....cell import Cell
from ....layer.basic import Dense, OneHot


class ConditionalVAE(Cell):
    r"""
    Conditional Variational Auto-Encoder (CVAE).

    The difference with VAE is that CVAE uses labels information.
    For more details, refer to `Learning Structured Output Representation using Deep Conditional Generative Models
    <http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-
    generative-models>`_.

    Note:
        When encoder and decoder are defined, the shape of the encoder's output tensor and decoder's input tensor
        must be :math:`(N, hidden\_size)`.
        The latent_size must be less than or equal to the hidden_size.

    Args:
        encoder(Cell): The Deep Neural Network (DNN) model defined as encoder.
        decoder(Cell): The DNN model defined as decoder.
        hidden_size(int): The size of encoder's output tensor.
        latent_size(int): The size of the latent space.
        num_classes(int): The number of classes.

    Inputs:
        - **input_x** (Tensor) - The shape of input tensor is :math:`(N, C, H, W)`, which is the same as the input of
          encoder.

        - **input_y** (Tensor) - The tensor of the target data, the shape is :math:`(N,)`.

    Outputs:
        - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).

    Raises:
        TypeError: If `encoder` or `decoder` is not a Cell.
        ValueError: If `hidden_size` is less than `latent_size`.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes):
        super(ConditionalVAE, self).__init__()
        # Validate before storing anything, so a bad argument cannot leave a
        # partially initialized object behind.
        if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
            raise TypeError('The encoder and decoder should be Cell type.')
        self.encoder = encoder
        self.decoder = decoder
        self.hidden_size = Validator.check_positive_int(hidden_size)
        self.latent_size = Validator.check_positive_int(latent_size)
        if hidden_size < latent_size:
            raise ValueError('The latent_size should be less than or equal to the hidden_size.')
        self.num_classes = Validator.check_positive_int(num_classes)
        self.normal = C.normal
        self.exp = P.Exp()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        # Latent code and one-hot label are joined along the feature axis.
        self.concat = P.Concat(axis=1)
        self.to_tensor = P.ScalarToArray()
        self.one_hot = OneHot(depth=num_classes)
        # dense1/dense2 project the encoder output to the posterior mean and
        # log-variance; dense3 maps the label-conditioned latent code back to
        # the decoder's input space.
        self.dense1 = Dense(self.hidden_size, self.latent_size)
        self.dense2 = Dense(self.hidden_size, self.latent_size)
        self.dense3 = Dense(self.latent_size + self.num_classes, self.hidden_size)

    def _encode(self, x, y):
        """Encode (x, y) into the Gaussian posterior parameters (mu, log_var)."""
        en_x = self.encoder(x, y)
        mu = self.dense1(en_x)
        log_var = self.dense2(en_x)
        return mu, log_var

    def _decode(self, z):
        """Decode a label-conditioned latent code z into a reconstruction."""
        z = self.dense3(z)
        recon_x = self.decoder(z)
        return recon_x

    def construct(self, x, y):
        """
        The input are x and y, so the WithLossCell method needs to be rewritten when using cvae interface.
        """
        mu, log_var = self._encode(x, y)
        # Reparameterization trick: std = exp(log_var / 2), z = N(mu, std).
        std = self.exp(0.5 * log_var)
        z = self.normal(self.shape(mu), mu, std, seed=0)
        # Condition the latent code on the one-hot encoded label.
        y = self.one_hot(y)
        z_c = self.concat((z, y))
        recon_x = self._decode(z_c)
        return recon_x, x, mu, std

    def generate_sample(self, sample_y, generate_nums, shape):
        """
        Randomly sample from the latent space to generate samples.

        Args:
            sample_y (Tensor): Define the label of samples. Tensor of shape (generate_nums, ) and type mindspore.int32.
            generate_nums (int): The number of samples to generate.
            shape(tuple): The shape of sample, which must be the format of (generate_nums, C, H, W) or (-1, C, H, W).

        Returns:
            Tensor, the generated samples.
        """
        generate_nums = Validator.check_positive_int(generate_nums)
        if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
            raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
        # Draw from the standard-normal prior N(0, 1) over the latent space.
        sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
        sample_y = self.one_hot(sample_y)
        sample_c = self.concat((sample_z, sample_y))
        sample = self._decode(sample_c)
        sample = self.reshape(sample, shape)
        return sample

    def reconstruct_sample(self, x, y):
        """
        Reconstruct samples from original data.

        Args:
            x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).
            y (Tensor): The label of the input tensor, the shape is (N,).

        Returns:
            Tensor, the reconstructed sample.
        """
        mu, log_var = self._encode(x, y)
        std = self.exp(0.5 * log_var)
        # Use the P.Shape() primitive (graph-mode safe), consistent with construct().
        z = self.normal(self.shape(mu), mu, std, seed=0)
        y = self.one_hot(y)
        z_c = self.concat((z, y))
        recon_x = self._decode(z_c)
        return recon_x