• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16The VAE interface can be called to construct VAE-GAN network.
17"""
18import os
19
20import mindspore.dataset as ds
21import mindspore.dataset.vision.c_transforms as CV
22import mindspore.nn as nn
23from mindspore import context
24import mindspore.ops as ops
25from mindspore.nn.probability.dpn import VAE
26from mindspore.nn.probability.infer import ELBO, SVI
27
# Run the whole script in graph mode on a GPU device.
context.set_context(device_target="GPU", mode=context.GRAPH_MODE)

# Shape the decoder reshapes its dense output to (first dimension left as -1).
IMAGE_SHAPE = (-1, 1, 32, 32)

# Directory passed to create_dataset() by test_vae_gan().
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")
31
32
class Encoder(nn.Cell):
    """
    The Encoder of the VAE: Flatten -> Dense(1024, 400) -> ReLU.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(1024, 400)
        self.relu = nn.ReLU()

    def construct(self, x):
        """Flatten `x`, project it with the dense layer and apply ReLU."""
        hidden = self.fc1(self.flatten(x))
        return self.relu(hidden)
45
46
class Decoder(nn.Cell):
    """
    The Decoder of the VAE: Dense(400, 1024) -> reshape to IMAGE_SHAPE -> Sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Dense(400, 1024)
        self.reshape = ops.Reshape()
        self.sigmoid = nn.Sigmoid()
        # Not used by construct(); kept so the cell exposes the same attributes.
        self.relu = nn.ReLU()

    def construct(self, z):
        """Project `z` with the dense layer, reshape to IMAGE_SHAPE and apply Sigmoid."""
        image = self.reshape(self.fc1(z), IMAGE_SHAPE)
        return self.sigmoid(image)
60
61
class Discriminator(nn.Cell):
    """
    The Discriminator of the GAN network.

    Layers: Flatten -> Dense(1024, 400) -> ReLU -> Dense(400, 720) -> ReLU
    -> Dense(720, 1024) -> Sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Dense layers are created in the same order as they are applied.
        self.fc1 = nn.Dense(1024, 400)
        self.fc2 = nn.Dense(400, 720)
        self.fc3 = nn.Dense(720, 1024)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def construct(self, x):
        """Flatten `x` and pass it through the three dense layers."""
        hidden = self.relu(self.fc1(self.flatten(x)))
        hidden = self.relu(self.fc2(hidden))
        return self.sigmoid(self.fc3(hidden))
85
86
class VaeGan(nn.Cell):
    """
    The VAE-GAN network: a VAE built from Encoder/Decoder plus a Discriminator.
    """

    def __init__(self):
        super().__init__()
        self.E = Encoder()
        self.G = Decoder()
        self.D = Discriminator()
        # Maps the sampled noise (20 features) to the 400 features the decoder takes.
        self.dense = nn.Dense(20, 400)
        self.vae = VAE(self.E, self.G, 400, 20)
        self.shape = ops.Shape()
        self.normal = ops.normal
        self.to_tensor = ops.ScalarToArray()

    def construct(self, x):
        """
        Run the VAE on `x`, decode an extra noise sample, and score all three
        images with the discriminator.

        Returns the tuple (ld_real, ld_fake, ld_p, recon_x, x, mu, std), where
        the first three entries are the discriminator outputs for the input,
        the reconstruction and the image decoded from sampled noise.
        """
        recon_x, x, mu, std = self.vae(x)

        # Draw noise with the same shape as `mu` (mean 0, stddev 1, fixed seed).
        latent_shape = self.shape(mu)
        noise = self.normal(latent_shape, self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
        x_p = self.G(self.dense(noise))

        ld_real = self.D(x)
        ld_fake = self.D(recon_x)
        ld_p = self.D(x_p)
        return ld_real, ld_fake, ld_p, recon_x, x, mu, std
108
109
class VaeGanLoss(ELBO):
    """
    The loss of the VAE-GAN network: three summed-MSE terms on the
    discriminator outputs plus the ELBO (reconstruction + KL) term.
    """

    def __init__(self):
        super().__init__()
        self.zeros = ops.ZerosLike()
        self.mse = nn.MSELoss(reduction='sum')

    def construct(self, data, label):
        """
        Compute the total loss from the 7-tuple produced by VaeGan.construct.

        `label` is accepted but not used.
        """
        ld_real, ld_fake, ld_p, recon_x, x, mu, std = data

        # Targets: all-ones for "real", all-zeros for "fake".
        ones_target = self.zeros(ld_real) + 1
        zeros_target = self.zeros(ld_fake)

        real_term = self.mse(ld_real, ones_target)
        sampled_term = self.mse(ld_p, zeros_target)
        recon_term = self.mse(ld_fake, ones_target)

        # ELBO part: reconstruction loss plus KL against a Normal(0, 1) prior.
        reconstruction = self.recon_loss(x, recon_x)
        kl_divergence = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu) + 1, mu, std)
        elbo_term = reconstruction + self.sum(kl_divergence)

        return real_term + recon_term + sampled_term + elbo_term
127
128
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST dataset pipeline used for training or testing.

    Images are resized to 32x32, rescaled by 1/255 with zero shift, converted
    from HWC to CHW, then batched by `batch_size` and repeated `repeat_size`
    times. `num_parallel_workers` is forwarded to every map operation.
    """
    target_size = (32, 32)

    # Image transforms, applied one map() call each, in this order.
    image_ops = (
        CV.Resize(target_size),  # Bilinear mode
        CV.Rescale(1.0 / 255.0, 0.0),
        CV.HWC2CHW(),
    )

    dataset = ds.MnistDataset(data_path)
    for image_op in image_ops:
        dataset = dataset.map(operations=image_op, input_columns="image",
                              num_parallel_workers=num_parallel_workers)

    return dataset.batch(batch_size).repeat(repeat_size)
156
157
def test_vae_gan():
    """
    Train the VAE-GAN for 5 epochs with SVI on the dataset at `image_path`.
    """
    network = VaeGan()
    loss_fn = VaeGanLoss()
    adam = nn.Adam(params=network.trainable_params(), learning_rate=0.001)
    train_data = create_dataset(image_path, batch_size=128, repeat_size=1)
    svi = SVI(net_with_loss=nn.WithLossCell(network, loss_fn), optimizer=adam)
    # The value returned by run() is not used afterwards.
    svi.run(train_dataset=train_data, epochs=5)
166