# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""VAE: a variational autoencoder on MNIST, built with ZhuSuan on MindSpore."""

import os
import numpy as np

from utils import create_dataset, save_img

import mindspore.nn as nn

from mindspore import context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

import zhusuan as zs


class ReduceMeanLoss(nn.L1Loss):
    """Pass-through loss: the ELBO network already outputs the loss value."""
    def construct(self, base, target):
        return base


class Generator(zs.BayesianNet):
    """Generative model p(x, z): a standard-normal prior on z and a
    Bernoulli likelihood on x parameterized by an MLP decoder."""
    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size

        self.fc1 = nn.Dense(z_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, x_dim)
        self.fill = P.Fill()
        self.sigmoid = P.Sigmoid()
        self.reshape_op = P.Reshape()

    def ones(self, shape):
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """Score (or sample) z under the prior and x under the likelihood."""
        assert y is None  # no conditional information is used

        if x is not None:
            x = self.reshape_op(x, (self.batch_size, self.x_dim))

        # Prior: z ~ N(0, I). If z is observed it is scored, otherwise sampled.
        z_mean = self.zeros((self.batch_size, self.z_dim))
        z_std = self.ones((self.batch_size, self.z_dim))
        z, log_prob_z = self.normal('latent', observation=z, mean=z_mean, std=z_std,
                                    shape=(), reparameterize=False)

        # Likelihood: x ~ Bernoulli(decoder(z)).
        x_mean = self.sigmoid(self.fc3(self.act2(self.fc2(self.act1(self.fc1(z))))))
        if x is None:
            x = x_mean
        x, log_prob_x = self.bernoulli('data', observation=x, shape=(), probs=x_mean)

        return x, log_prob_x, z, log_prob_z


class Variational(zs.BayesianNet):
    """Variational posterior q(z|x): an MLP encoder producing the mean and
    standard deviation of a diagonal Gaussian over z."""
    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.reshape_op = P.Reshape()

        self.fc1 = nn.Dense(x_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, z_dim)
        self.fc4 = nn.Dense(500, z_dim)
        self.fill = P.Fill()
        self.exp = P.Exp()

    def ones(self, shape):
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """Encode x and sample z from q(z|x) with reparameterization."""
        assert y is None  # no conditional information is used

        x = self.reshape_op(x, (self.batch_size, self.x_dim))
        hidden = self.act2(self.fc2(self.act1(self.fc1(x))))
        z_mean = self.fc3(hidden)
        z_std = self.exp(self.fc4(hidden))
        z, log_prob_z = self.normal('latent', observation=z, mean=z_mean, std=z_std,
                                    shape=(), reparameterize=True)
        return z, log_prob_z


def main():
    # We currently support PyNative mode with device GPU.
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    epoch_size = 1
    batch_size = 32
    mnist_path = "/data/chengzi/zhusuan-mindspore/data/MNIST"
    repeat_size = 1

    # Model dimensions: images are treated as 32*32 = 1024-d vectors,
    # the latent code is 40-d.
    z_dim = 40
    x_dim = 32 * 32

    # Create the networks; the ELBO wrapper combines them into a trainable loss.
    generator = Generator(x_dim, z_dim, batch_size)
    variational = Variational(x_dim, z_dim, batch_size)
    network = zs.variational.ELBO(generator, variational)

    # The ELBO network already returns the loss, so the loss cell is a pass-through.
    lr = 0.001
    net_loss = ReduceMeanLoss()

    # Print the first parameter before and after training so the update is visible.
    print(network.trainable_params()[0])
    net_opt = nn.Adam(network.trainable_params(), lr)

    model = Model(network, net_loss, net_opt)

    ds_train = create_dataset(os.path.join(mnist_path, "train"), batch_size, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)

    print(network.trainable_params()[0])

    # Reconstruction: encode one batch, then decode the resulting latent codes.
    iterator = ds_train.create_tuple_iterator()
    for item in iterator:
        batch_x = item[0].reshape(batch_size, x_dim)
        break
    z, _ = network.variational(Tensor(batch_x), None, None)
    sample, _, _, _ = network.generator(None, z, None)
    sample = sample.asnumpy()
    save_img(batch_x, 'result/origin_x.png')
    save_img(sample, 'result/reconstruct_x.png')

    # Generation: sample four batches from the prior and tile them into one image.
    samples = None
    for i in range(4):
        sample, _, _, _ = network.generator(None, None, None)
        sample = sample.asnumpy()
        samples = sample if i == 0 else np.concatenate([samples, sample], axis=0)
    save_img(samples, 'result/sample_x.png', num=4 * batch_size)


if __name__ == '__main__':
    main()