# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GPU test: train a dense VAE on MNIST with SVI, then sample from and reconstruct with it."""
import os

from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Tensor
import mindspore.ops as ops
from mindspore.nn.probability.dpn import VAE
from mindspore.nn.probability.infer import ELBO, SVI

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# Target shape for decoder output and generated samples: (batch, channel, height, width),
# with the batch dimension inferred (-1). 1 * 32 * 32 = 1024, which is the width of the
# Encoder input and Decoder output below.
IMAGE_SHAPE = (-1, 1, 32, 32)
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")


class Encoder(nn.Cell):
    """Flatten an image batch and map it through two ReLU-activated dense layers (1024 -> 800 -> 400)."""

    def __init__(self):
        super(Encoder, self).__init__()
        # 1024 input features = 1 * 32 * 32, the flattened size of one image in IMAGE_SHAPE.
        self.fc1 = nn.Dense(1024, 800)
        self.fc2 = nn.Dense(800, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        return x


class Decoder(nn.Cell):
    """Map a 400-feature vector to a 1024-feature image, reshape it to IMAGE_SHAPE and apply a sigmoid."""

    def __init__(self):
        super(Decoder, self).__init__()
        # 400 input features matches the hidden_size passed to VAE in test_svi_vae.
        self.fc1 = nn.Dense(400, 1024)
        self.sigmoid = nn.Sigmoid()
        self.reshape = ops.Reshape()

    def construct(self, z):
        z = self.fc1(z)
        z = self.reshape(z, IMAGE_SHAPE)
        # Sigmoid keeps the reconstructed pixel values in (0, 1).
        z = self.sigmoid(z)
        return z


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST input pipeline used for training or testing.

    Args:
        data_path (str): Directory passed to ``ds.MnistDataset``.
        batch_size (int): Number of images per batch. Default: 32.
        repeat_size (int): Number of times the batched dataset is repeated. Default: 1.
        num_parallel_workers (int): Workers used by each ``map`` operation. Default: 1.

    Returns:
        The batched, repeated dataset. Its "image" column is resized to 32x32,
        multiplied by 1/255 and converted from HWC to CHW layout.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    # Scale by 1/255 with no shift; presumably maps 0-255 pixel values into [0, 1].
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width))  # Bilinear mode
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    mnist_ds = mnist_ds.batch(batch_size)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


def test_svi_vae():
    """
    Train a VAE on MNIST with stochastic variational inference (SVI), then
    exercise ``generate_sample`` and ``reconstruct_sample`` on the trained network.
    """
    # define the encoder and decoder
    encoder = Encoder()
    decoder = Decoder()
    # define the vae model; hidden_size=400 matches the Encoder output width
    vae = VAE(encoder, decoder, hidden_size=400, latent_size=20)
    # define the loss function
    net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
    # define the optimizer
    optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)
    # define the training dataset
    ds_train = create_dataset(image_path, 128, 1)
    net_with_loss = nn.WithLossCell(vae, net_loss)
    # define the variational inference
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    # run the vi to return the trained network.
    vae = vi.run(train_dataset=ds_train, epochs=5)
    # get the trained loss
    trained_loss = vi.get_train_loss()
    # test function: generate_sample
    generated_sample = vae.generate_sample(64, IMAGE_SHAPE)
    # test function: reconstruct_sample (every batch is reconstructed; the last result is kept)
    reconstructed_sample = None
    for sample in ds_train.create_dict_iterator(output_numpy=True, num_epochs=1):
        sample_x = Tensor(sample['image'], dtype=mstype.float32)
        reconstructed_sample = vae.reconstruct_sample(sample_x)
    # Fail with a clear message instead of a NameError if the dataset yielded no batches.
    assert reconstructed_sample is not None, "the training dataset yielded no batches"
    print('The loss of the trained network is ', trained_loss)
    print('The shape of the generated sample is ', generated_sample.shape)
    print('The shape of the reconstructed sample is ', reconstructed_sample.shape)