• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" VAE """
16
17import os
18import numpy as np
19
20from utils import create_dataset, save_img
21
22import mindspore.nn as nn
23
24from mindspore import context
25from mindspore import Tensor
26from mindspore.train import Model
27from mindspore.train.callback import LossMonitor
28from mindspore.ops import operations as P
29from mindspore.common import dtype as mstype
30
31import zhusuan as zs
32
class ReduceMeanLoss(nn.L1Loss):
    """Pass-through loss cell.

    ``construct`` hands its first argument back untouched and never looks
    at ``target``; no reduction is applied here.
    NOTE(review): presumably the wrapped network already outputs the scalar
    training objective, so this cell only exists to satisfy the
    ``Model(network, loss, optimizer)`` interface -- confirm against the
    ELBO implementation.
    """

    def construct(self, base, target):
        """Return ``base`` unchanged; ``target`` is ignored."""
        loss = base
        return loss
37
class Generator(zs.BayesianNet):
    """Generative half of the VAE.

    Draws (or accepts) a latent code ``z`` under a zero-mean, unit-std
    normal prior, decodes it with a three-layer MLP followed by a sigmoid,
    and scores the data under a Bernoulli with those probabilities.

    Args:
        x_dim: Flattened size of one data point.
        z_dim: Size of the latent code.
        batch_size: Number of data points per batch; used to build the
            prior parameters and to flatten the input.
    """
    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size

        # Decoder MLP: z_dim -> 500 -> 500 -> x_dim.
        self.fc1 = nn.Dense(z_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, x_dim)
        self.fill = P.Fill()
        self.sigmoid = P.Sigmoid()
        self.reshape_op = P.Reshape()

    def ones(self, shape):
        """Return a float32 tensor of the given shape filled with 1."""
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        """Return a float32 tensor of the given shape filled with 0."""
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """Run the generative model.

        Args:
            x: Observed data, or ``None`` when generating. When given it is
                flattened to ``(batch_size, x_dim)``.
            z: Latent code passed as the observation of the ``'latent'``
                node, or ``None``.
                NOTE(review): ``None`` presumably makes ``self.normal``
                sample from the prior -- confirm in zhusuan.
            y: Conditional information; must be ``None`` (unsupported).

        Returns:
            Tuple ``(x, log_prob_x, z, log_prob_z)`` as returned by the
            ``'data'`` and ``'latent'`` nodes.
        """
        assert y is None ## we have no conditional information

        if x is not None:
            # Flatten using the configured sizes rather than literals so the
            # model also works for batch sizes / image sizes other than 32.
            x = self.reshape_op(x, (self.batch_size, self.x_dim))

        # Prior over the latent code: mean 0, std 1.
        z_mean = self.zeros((self.batch_size, self.z_dim))
        z_std = self.ones((self.batch_size, self.z_dim))
        z, log_prob_z = self.normal('latent', observation=z, mean=z_mean, std=z_std, shape=(), reparameterize=False)

        # Decode z into per-pixel Bernoulli probabilities in (0, 1).
        x_mean = self.sigmoid(self.fc3(self.act2(self.fc2(self.act1(self.fc1(z))))))
        if x is None:
            # No data given: use the Bernoulli mean itself as the "observation"
            # instead of drawing a binary sample.
            x = x_mean
        x, log_prob_x = self.bernoulli('data', observation=x, shape=(), probs=x_mean)

        return x, log_prob_x, z, log_prob_z
79
class Variational(zs.BayesianNet):
    """Inference (encoder) half of the VAE.

    Encodes a flattened data batch with a two-layer ReLU MLP and produces
    the mean and standard deviation of a normal distribution over the
    latent code.

    Args:
        x_dim: Flattened size of one data point.
        z_dim: Size of the latent code.
        batch_size: Number of data points per batch; used to flatten the
            input.
    """
    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.reshape_op = P.Reshape()

        # Shared trunk: x_dim -> 500 -> 500; two heads of size z_dim.
        self.fc1 = nn.Dense(x_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, z_dim)  # head for the mean
        self.fc4 = nn.Dense(500, z_dim)  # head for the log of the std
        self.fill = P.Fill()
        self.exp = P.Exp()

    def ones(self, shape):
        """Return a float32 tensor of the given shape filled with 1."""
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        """Return a float32 tensor of the given shape filled with 0."""
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """Run the inference model.

        Args:
            x: Data batch; flattened to ``(batch_size, x_dim)``.
            z: Latent code passed as the observation of the ``'latent'``
                node, or ``None``.
                NOTE(review): ``None`` presumably makes ``self.normal``
                draw a reparameterized sample -- confirm in zhusuan.
            y: Conditional information; must be ``None`` (unsupported).

        Returns:
            Tuple ``(z, log_prob_z)`` as returned by the ``'latent'`` node.
        """
        assert y is None ## we have no conditional information
        # Flatten using the configured sizes rather than literals so the
        # model also works for batch sizes / image sizes other than 32.
        x = self.reshape_op(x, (self.batch_size, self.x_dim))
        z_logit = self.act2(self.fc2(self.act1(self.fc1(x))))
        z_mean = self.fc3(z_logit)
        # exp keeps the standard deviation strictly positive.
        z_std = self.exp(self.fc4(z_logit))
        z, log_prob_z = self.normal('latent', observation=z, mean=z_mean, std=z_std, shape=(), reparameterize=True)
        return z, log_prob_z
114
def main():
    """Train the VAE for one epoch, then save reconstructions and samples.

    Steps: configure MindSpore, build the generator / variational networks
    and wrap them in an ELBO cell, train with Adam, reconstruct the first
    training batch, and finally decode four batches generated with no
    latent observation. Images are written via ``save_img`` to
    ``result/origin_x.png``, ``result/reconstruct_x.png`` and
    ``result/sample_x.png``.

    Raises:
        RuntimeError: If the training dataset yields no batch to
            reconstruct.
    """
    # We currently support pynative mode with device GPU
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    epoch_size = 1
    batch_size = 32
    mnist_path = "/data/chengzi/zhusuan-mindspore/data/MNIST"
    repeat_size = 1

    # Define model parameters
    z_dim = 40
    x_dim = 32*32

    # create the network
    generator = Generator(x_dim, z_dim, batch_size)
    variational = Variational(x_dim, z_dim, batch_size)
    network = zs.variational.ELBO(generator, variational)

    # define loss
    # learning rate setting
    lr = 0.001
    net_loss = ReduceMeanLoss()

    # define the optimizer
    # Printed before and after training to eyeball that parameters changed.
    print(network.trainable_params()[0])
    net_opt = nn.Adam(network.trainable_params(), lr)

    model = Model(network, net_loss, net_opt)

    ds_train = create_dataset(os.path.join(mnist_path, "train"), batch_size, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)

    print(network.trainable_params()[0])

    # Take only the first batch; flatten with the configured sizes so this
    # stays in sync with batch_size / x_dim above.
    iterator = ds_train.create_tuple_iterator()
    for item in iterator:
        batch_x = item[0].reshape(batch_size, x_dim)
        break
    else:
        # Without this, an empty dataset surfaces later as an opaque NameError.
        raise RuntimeError("training dataset is empty; no batch to reconstruct")

    # Reconstruction: encode the batch, then decode the inferred latent code.
    z, _ = network.variational(Tensor(batch_x), None, None)
    sample, _, _, _ = network.generator(None, z, None)
    sample = sample.asnumpy()
    save_img(batch_x, 'result/origin_x.png')
    save_img(sample, 'result/reconstruct_x.png')

    # Generation: run the generator with no data and no latent observation.
    # Collect the batches and concatenate once instead of growing an array.
    num_sample_batches = 4
    generated = []
    for _ in range(num_sample_batches):
        outputs = network.generator(None, None, None)
        generated.append(outputs[0].asnumpy())
    samples = np.concatenate(generated, axis=0)
    save_img(samples, 'result/sample_x.png', num=num_sample_batches*batch_size)


# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()
166