# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPU test: train a ConditionalVAE on MNIST with SVI, then sample from it."""
import os

from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Tensor
import mindspore.ops as ops
from mindspore.nn.probability.dpn import ConditionalVAE
from mindspore.nn.probability.infer import ELBO, SVI

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# Shape of one image batch: (batch, channels=1, height=32, width=32);
# -1 lets reshape infer the batch dimension.
IMAGE_SHAPE = (-1, 1, 32, 32)
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")


class Encoder(nn.Cell):
    """Encode an image together with its class label.

    The image is flattened, the label is one-hot encoded to ``num_classes``
    columns, both are concatenated along axis 1, and the result goes through
    a Dense(1024 + num_classes -> 400) layer followed by ReLU.

    Args:
        num_classes (int): number of label classes (depth of the one-hot code).
    """

    def __init__(self, num_classes):
        super().__init__()
        # 1024 = 1 * 32 * 32, the flattened size of one image (see IMAGE_SHAPE).
        self.fc1 = nn.Dense(1024 + num_classes, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.concat = ops.Concat(axis=1)
        self.one_hot = nn.OneHot(depth=num_classes)

    def construct(self, x, y):
        x = self.flatten(x)
        y = self.one_hot(y)
        input_x = self.concat((x, y))
        input_x = self.fc1(input_x)
        input_x = self.relu(input_x)
        return input_x


class Decoder(nn.Cell):
    """Map a 400-dim hidden vector back to an image batch.

    Applies Dense(400 -> 1024), reshapes to ``IMAGE_SHAPE`` and squashes the
    values with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.fc2 = nn.Dense(400, 1024)
        self.sigmoid = nn.Sigmoid()
        self.reshape = ops.Reshape()

    def construct(self, z):
        z = self.fc2(z)
        z = self.reshape(z, IMAGE_SHAPE)
        z = self.sigmoid(z)
        return z


class CVAEWithLossCell(nn.WithLossCell):
    """WithLossCell variant for a conditional VAE.

    The backbone is called with both ``data`` and ``label``, because the
    conditional model takes the label as an input. Its output is then passed
    to the loss function together with ``label``. (NOTE(review): the stock
    WithLossCell presumably calls the backbone with ``data`` only; confirm
    against the MindSpore version in use.)
    """

    def construct(self, data, label):
        out = self._backbone(data, label)
        return self._loss_fn(out, label)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST input pipeline for training and reconstruction.

    Images are resized to 32x32, rescaled by 1/255 and converted from HWC to
    CHW layout. The dataset is then batched and repeated.

    Args:
        data_path (str): directory passed to ``ds.MnistDataset``.
        batch_size (int): number of samples per batch.
        repeat_size (int): number of times the batched dataset is repeated.
        num_parallel_workers (int): parallel workers for the image map step.

    Returns:
        The batched and repeated MindSpore dataset.
    """
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0

    # Applied in list order to the "image" column.
    image_ops = [
        CV.Resize((resize_height, resize_width)),  # Bilinear mode
        CV.Rescale(rescale, shift),
        CV.HWC2CHW(),
    ]
    mnist_ds = mnist_ds.map(operations=image_ops, input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.batch(batch_size)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


def test_svi_cvae():
    """Train a ConditionalVAE with SVI, then exercise sampling and reconstruction."""
    # define the encoder and decoder
    encoder = Encoder(num_classes=10)
    decoder = Decoder()
    # define the cvae model
    cvae = ConditionalVAE(encoder, decoder, hidden_size=400, latent_size=20, num_classes=10)
    # define the loss function
    net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
    # define the optimizer
    optimizer = nn.Adam(params=cvae.trainable_params(), learning_rate=0.001)
    # define the training dataset
    ds_train = create_dataset(image_path, 128, 1)
    # define the WithLossCell modified to feed the label to the backbone
    net_with_loss = CVAEWithLossCell(cvae, net_loss)
    # define the variational inference
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    # run the vi to return the trained network
    cvae = vi.run(train_dataset=ds_train, epochs=5)
    # get the trained loss
    trained_loss = vi.get_train_loss()

    # test function: generate_sample -- 64 labels cycling through classes 0..7
    num_generated = 64
    sample_label = Tensor(list(range(8)) * 8, dtype=mstype.int32)
    generated_sample = cvae.generate_sample(sample_label, num_generated, IMAGE_SHAPE)

    # test function: reconstruct_sample -- a single batch is enough to
    # exercise it; running it over every batch would repeat the same check.
    # next() raises StopIteration (instead of a later NameError) if empty.
    sample = next(ds_train.create_dict_iterator(output_numpy=True, num_epochs=1))
    sample_x = Tensor(sample['image'], dtype=mstype.float32)
    sample_y = Tensor(sample['label'], dtype=mstype.int32)
    reconstructed_sample = cvae.reconstruct_sample(sample_x, sample_y)

    print('The loss of the trained network is ', trained_loss)
    print('The shape of the generated sample is ', generated_sample.shape)
    print('The shape of the reconstructed sample is ', reconstructed_sample.shape)

    # Decoder reshapes to IMAGE_SHAPE, so both outputs are (N, 1, 32, 32).
    assert tuple(generated_sample.shape) == (num_generated,) + IMAGE_SHAPE[1:]
    assert tuple(reconstructed_sample.shape) == tuple(sample_x.shape)