# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The VAE interface can be called to construct VAE-GAN network.
"""
import os

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context
import mindspore.ops as ops
from mindspore.nn.probability.dpn import VAE
from mindspore.nn.probability.infer import ELBO, SVI

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# Target shape for decoder outputs; 1 * 32 * 32 == 1024 matches the Dense widths
# used below, and -1 leaves the leading dimension to be inferred by Reshape.
IMAGE_SHAPE = (-1, 1, 32, 32)
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")


class Encoder(nn.Cell):
    """
    Flatten the input and project it from 1024 to 400 features with ReLU.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Dense(1024, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Return relu(fc1(flatten(x)))."""
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        return x


class Decoder(nn.Cell):
    """
    Project a 400-feature input to 1024 features, reshape to IMAGE_SHAPE and
    apply a sigmoid.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Dense(400, 1024)
        # Not used in construct; retained so the cell's attribute set is unchanged.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.reshape = ops.Reshape()

    def construct(self, z):
        """Return sigmoid(reshape(fc1(z), IMAGE_SHAPE))."""
        z = self.fc1(z)
        z = self.reshape(z, IMAGE_SHAPE)
        z = self.sigmoid(z)
        return z


class Discriminator(nn.Cell):
    """
    The Discriminator of the GAN network.

    Flattens its input and applies three Dense layers (1024 -> 400 -> 720 ->
    1024) with ReLU between them and a sigmoid on the final output.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Dense(1024, 400)
        self.fc2 = nn.Dense(400, 720)
        self.fc3 = nn.Dense(720, 1024)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Return the sigmoid-activated output of the three-layer stack."""
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        return x


class VaeGan(nn.Cell):
    """
    Combine a VAE (built from Encoder and Decoder) with a Discriminator.

    The VAE is constructed as ``VAE(self.E, self.G, 400, 20)``; an extra
    ``Dense(20, 400)`` layer maps freshly sampled noise to the decoder's
    input width.
    """

    def __init__(self):
        super(VaeGan, self).__init__()
        self.E = Encoder()
        self.G = Decoder()
        self.D = Discriminator()
        self.dense = nn.Dense(20, 400)
        self.vae = VAE(self.E, self.G, 400, 20)
        self.shape = ops.Shape()
        self.normal = ops.normal
        self.to_tensor = ops.ScalarToArray()

    def construct(self, x):
        """
        Run the VAE and score three images with the discriminator.

        Returns:
            tuple: ``(ld_real, ld_fake, ld_p, recon_x, x, mu, std)`` where
            ``ld_real``, ``ld_fake`` and ``ld_p`` are the discriminator
            outputs for the input, the VAE reconstruction and the image
            decoded from sampled noise, and the remaining four values are
            what ``self.vae`` returned.
        """
        recon_x, x, mu, std = self.vae(x)
        # Sample noise with the same shape as mu (mean 0.0, stddev 1.0, fixed seed).
        z_p = self.normal(self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
        # Map the sample to 400 features and feed it straight to the decoder.
        z_p = self.dense(z_p)
        x_p = self.G(z_p)
        ld_real = self.D(x)
        ld_fake = self.D(recon_x)
        ld_p = self.D(x_p)
        return ld_real, ld_fake, ld_p, recon_x, x, mu, std


class VaeGanLoss(ELBO):
    """
    Loss for VaeGan: three sum-reduced MSE discriminator terms plus an ELBO term.
    """

    def __init__(self):
        super(VaeGanLoss, self).__init__()
        self.zeros = ops.ZerosLike()
        self.mse = nn.MSELoss(reduction='sum')

    def construct(self, data, label):
        """
        Compute the total loss from the tuple returned by ``VaeGan.construct``.

        Args:
            data: The 7-tuple ``(ld_real, ld_fake, ld_p, recon_x, x, mu, std)``.
            label: Not used by this method.

        Returns:
            ``loss_D + loss_G + loss_GD + elbo_loss``.
        """
        ld_real, ld_fake, ld_p, recon_x, x, mu, std = data
        # Targets: all ones for "real", all zeros for "fake".
        y_real = self.zeros(ld_real) + 1
        y_fake = self.zeros(ld_fake)
        loss_D = self.mse(ld_real, y_real)
        loss_GD = self.mse(ld_p, y_fake)
        loss_G = self.mse(ld_fake, y_real)
        reconstruct_loss = self.recon_loss(x, recon_x)
        # NOTE(review): presumably the KL divergence between Normal(mu, std) and
        # a Normal with zero mean and unit scale -- confirm against ELBO.posterior.
        kl_loss = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu) + 1, mu, std)
        elbo_loss = reconstruct_loss + self.sum(kl_loss)
        return loss_D + loss_G + loss_GD + elbo_loss


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Batch size passed to ``batch``. Default: 32.
        repeat_size: Repeat count passed to ``repeat``. Default: 1.
        num_parallel_workers: Worker count used for every ``map`` call.
            Default: 1.

    Returns:
        The dataset after resize, rescale, HWC->CHW, batch and repeat.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width))  # Bilinear mode
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()

    # apply map operations on images, one map pass per op, in this order
    for image_op in (resize_op, rescale_op, hwc2chw_op):
        mnist_ds = mnist_ds.map(operations=image_op, input_columns="image",
                                num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    mnist_ds = mnist_ds.batch(batch_size)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


def test_vae_gan():
    """
    Train VaeGan with SVI for 5 epochs on the dataset at ``image_path``.

    Uses Adam (learning rate 0.001) and batches of 128; the test passes if
    training completes without raising.
    """
    vae_gan = VaeGan()
    net_loss = VaeGanLoss()
    optimizer = nn.Adam(params=vae_gan.trainable_params(), learning_rate=0.001)
    ds_train = create_dataset(image_path, 128, 1)
    net_with_loss = nn.WithLossCell(vae_gan, net_loss)
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    vae_gan = vi.run(train_dataset=ds_train, epochs=5)