# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train and test a conditional VAE (CVAE) on MNIST with stochastic variational inference (SVI)."""
import os

from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Tensor
import mindspore.ops as ops
from mindspore.nn.probability.dpn import ConditionalVAE
from mindspore.nn.probability.infer import ELBO, SVI

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
IMAGE_SHAPE = (-1, 1, 32, 32)
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")


class Encoder(nn.Cell):
    """Encoder: flatten the image, concatenate its one-hot label, and map the result to a 400-dim feature."""
    def __init__(self, num_classes):
        super(Encoder, self).__init__()
        self.fc1 = nn.Dense(1024 + num_classes, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.concat = ops.Concat(axis=1)
        self.one_hot = nn.OneHot(depth=num_classes)

    def construct(self, x, y):
        x = self.flatten(x)
        y = self.one_hot(y)
        input_x = self.concat((x, y))
        input_x = self.fc1(input_x)
        input_x = self.relu(input_x)
        return input_x


class Decoder(nn.Cell):
    """Decoder: map the 400-dim hidden feature back to a (1, 32, 32) image with a sigmoid activation."""
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc2 = nn.Dense(400, 1024)
        self.sigmoid = nn.Sigmoid()
        self.reshape = ops.Reshape()

    def construct(self, z):
        z = self.fc2(z)
        z = self.reshape(z, IMAGE_SHAPE)
        z = self.sigmoid(z)
        return z


class CVAEWithLossCell(nn.WithLossCell):
    """
    WithLossCell rewritten for the CVAE: the backbone takes both the image and its label.
    """
    def construct(self, data, label):
        out = self._backbone(data, label)
        return self._loss_fn(out, label)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Create the MNIST dataset for training or testing.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width))  # Bilinear mode
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply dataset operations
    mnist_ds = mnist_ds.batch(batch_size)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


def test_svi_cvae():
    # define the encoder and decoder
    encoder = Encoder(num_classes=10)
    decoder = Decoder()
    # define the cvae model
    cvae = ConditionalVAE(encoder, decoder, hidden_size=400, latent_size=20, num_classes=10)
    # define the loss function
    net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
    # define the optimizer
    optimizer = nn.Adam(params=cvae.trainable_params(), learning_rate=0.001)
    # define the training dataset
    ds_train = create_dataset(image_path, 128, 1)
    # define the modified WithLossCell
    net_with_loss = CVAEWithLossCell(cvae, net_loss)
    # define the variational inference
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    # run the vi to return the trained network
    cvae = vi.run(train_dataset=ds_train, epochs=5)
    # get the trained loss
    trained_loss = vi.get_train_loss()
    # test function: generate_sample
    sample_label = Tensor([i for i in range(0, 8)] * 8, dtype=mstype.int32)
    generated_sample = cvae.generate_sample(sample_label, 64, IMAGE_SHAPE)
    # test function: reconstruct_sample
    for sample in ds_train.create_dict_iterator(output_numpy=True, num_epochs=1):
        sample_x = Tensor(sample['image'], dtype=mstype.float32)
        sample_y = Tensor(sample['label'], dtype=mstype.int32)
        reconstructed_sample = cvae.reconstruct_sample(sample_x, sample_y)
    print('The loss of the trained network is ', trained_loss)
    print('The shape of the generated sample is ', generated_sample.shape)
    print('The shape of the reconstructed sample is ', reconstructed_sample.shape)
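

# A minimal entry point, added as a hedged sketch: it assumes the MNIST data is
# available at image_path and simply runs the test above when this file is
# executed directly rather than through a test runner.
if __name__ == "__main__":
    test_svi_cvae()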