• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import os
16
17from mindspore import dtype as mstype
18import mindspore.dataset as ds
19import mindspore.dataset.vision.c_transforms as CV
20import mindspore.nn as nn
21from mindspore import context, Tensor
22import mindspore.ops as ops
23from mindspore.nn.probability.dpn import VAE
24from mindspore.nn.probability.infer import ELBO, SVI
25
# Compile and run every cell in GRAPH_MODE on the GPU backend.
context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
# Layout the decoder reshapes its output to (and that generated samples use):
# batch dimension inferred (-1), 1 channel, 32x32 pixels -- matching the
# CHW images produced by create_dataset after Resize((32, 32)) + HWC2CHW.
IMAGE_SHAPE = (-1, 1, 32, 32)
# Hard-coded location of the MNIST training split read by the test.
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")
29
30
class Encoder(nn.Cell):
    """Two-layer MLP encoder used as the VAE's feature extractor.

    The input is flattened per sample, so each sample must contain exactly
    1024 elements (for example a 1x32x32 image). The output has shape
    (batch, 400) and has already been passed through a ReLU.
    """

    def __init__(self):
        super().__init__()
        # Registration order (fc1, fc2, relu, flatten) is kept as-is.
        self.fc1 = nn.Dense(1024, 800)
        self.fc2 = nn.Dense(800, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def construct(self, x):
        # flatten -> Dense(1024->800) -> ReLU -> Dense(800->400) -> ReLU
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.relu(self.fc2(hidden))
46
47
class Decoder(nn.Cell):
    """Single-layer decoder that turns 400-dim features into images.

    A Dense(400 -> 1024) projection is reshaped to the module-level
    IMAGE_SHAPE and squashed into (0, 1) with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Dense(400, 1024)
        self.sigmoid = nn.Sigmoid()
        self.reshape = ops.Reshape()

    def construct(self, z):
        projected = self.fc1(z)
        # 1024 features per sample == 1 * 32 * 32, so the batch dim (-1) is inferred.
        image = self.reshape(projected, IMAGE_SHAPE)
        return self.sigmoid(image)
60
61
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST input pipeline for training or evaluation.

    Every image is resized to 32x32 (Resize's default bilinear mode),
    multiplied by 1/255 with no shift, and converted from HWC to CHW layout.
    The dataset is then batched and repeated.

    Args:
        data_path: directory handed to ``ds.MnistDataset``.
        batch_size: number of samples per batch.
        repeat_size: how many times the batched dataset is repeated.
        num_parallel_workers: worker count for each image ``map`` step.

    Returns:
        The batched and repeated MindSpore dataset.
    """
    image_ops = (
        CV.Resize((32, 32)),            # Bilinear mode
        CV.Rescale(1.0 / 255.0, 0.0),
        CV.HWC2CHW(),
    )

    pipeline = ds.MnistDataset(data_path)
    # One map stage per op, applied in order, all on the "image" column.
    for image_op in image_ops:
        pipeline = pipeline.map(operations=image_op, input_columns="image",
                                num_parallel_workers=num_parallel_workers)

    return pipeline.batch(batch_size).repeat(repeat_size)
89
90
def test_svi_vae():
    """End-to-end check of SVI training for a VAE on MNIST.

    Builds a VAE from the Encoder/Decoder above (hidden_size=400,
    latent_size=20), trains it for 5 epochs with an ELBO loss and Adam via
    SVI, then exercises ``generate_sample`` and ``reconstruct_sample`` on the
    trained network and asserts the shapes of their outputs.
    """
    # define the encoder, decoder and the vae model
    encoder = Encoder()
    decoder = Decoder()
    vae = VAE(encoder, decoder, hidden_size=400, latent_size=20)
    # define the loss function and the optimizer
    net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
    optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)
    # define the training dataset
    ds_train = create_dataset(image_path, 128, 1)
    net_with_loss = nn.WithLossCell(vae, net_loss)
    # define the variational inference; run() returns the trained network
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    vae = vi.run(train_dataset=ds_train, epochs=5)
    trained_loss = vi.get_train_loss()

    # test function: generate_sample
    num_generated = 64
    generated_sample = vae.generate_sample(num_generated, IMAGE_SHAPE)
    assert generated_sample.shape == (num_generated,) + IMAGE_SHAPE[1:]

    # test function: reconstruct_sample
    # A single batch is enough to exercise it; reconstructing every batch and
    # keeping only the last result would just repeat discarded work.
    sample = next(ds_train.create_dict_iterator(output_numpy=True, num_epochs=1), None)
    assert sample is not None, "the training dataset yielded no batches"
    sample_x = Tensor(sample['image'], dtype=mstype.float32)
    reconstructed_sample = vae.reconstruct_sample(sample_x)
    # The decoder reshapes to IMAGE_SHAPE, so the reconstruction must match the input batch.
    assert reconstructed_sample.shape == sample_x.shape

    print('The loss of the trained network is ', trained_loss)
    print('The shape of the generated sample is ', generated_sample.shape)
    print('The shape of the reconstructed sample is ', reconstructed_sample.shape)
119