# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""test bnn layers"""
16
import numpy as np

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn import TrainOneStepCell
from mindspore.nn.probability import bnn_layers

from dataset import create_dataset

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv layer with TruncatedNormal weight initialization and no bias."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Fully connected layer with TruncatedNormal weight and bias initialization."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """TruncatedNormal initializer with sigma=0.02."""
    return TruncatedNormal(0.02)

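# For reference, the helpers above are shorthand for passing an explicit
# initializer to the layer constructors, e.g. (a sketch, equivalent to
# fc_with_initialize(120, 84)):
#
#   layer = nn.Dense(120, 84,
#                    weight_init=TruncatedNormal(0.02),
#                    bias_init=TruncatedNormal(0.02))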

class BNNLeNet5(nn.Cell):
    """
    Bayesian LeNet-5 network.

    The first convolution (conv1) and the first fully connected layer (fc1)
    are Bayesian reparameterized layers; the remaining layers are deterministic.

    Args:
        num_class (int): Number of classes. Default: 10.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> BNNLeNet5(num_class=10)
    """
    def __init__(self, num_class=10):
        super(BNNLeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = bnn_layers.ConvReparam(1, 6, 5, stride=1, padding=0, has_bias=False, pad_mode="valid")
        self.conv2 = conv(6, 16, 5)
        self.fc1 = bnn_layers.DenseReparam(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.reshape = ops.Reshape()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

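# Shape sanity check (sketch, not part of the test): given the conv/pool
# arithmetic above, the network expects 1x32x32 inputs, so that 16 * 5 * 5
# features reach fc1:
#
#   net = BNNLeNet5(num_class=10)
#   logits = net(Tensor(np.ones((2, 1, 32, 32), dtype=np.float32)))
#   assert logits.shape == (2, 10)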

def train_model(train_net, net, dataset):
    """Train for one epoch; return the mean loss and mean batch accuracy."""
    accs = []
    loss_sum = 0
    for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
        train_x = Tensor(data['image'].astype(np.float32))
        label = Tensor(data['label'].astype(np.int32))
        loss = train_net(train_x, label)
        # Separate forward pass to measure accuracy on the same batch.
        output = net(train_x)
        log_output = ops.LogSoftmax(axis=1)(output)
        acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
        accs.append(acc)
        loss_sum += loss.asnumpy()

    loss_sum = loss_sum / len(accs)
    acc_mean = np.mean(accs)
    return loss_sum, acc_mean

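# Note: the Bayesian layers (ConvReparam, DenseReparam) draw a fresh weight
# sample on every forward pass, so the accuracy pass in train_model uses a
# different weight sample than the training step; small run-to-run variation
# in the reported metrics is expected.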

def validate_model(net, dataset):
    """Return the mean batch accuracy over the given dataset."""
    accs = []
    for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
        test_x = Tensor(data['image'].astype(np.float32))
        label = Tensor(data['label'].astype(np.int32))
        output = net(test_x)
        log_output = ops.LogSoftmax(axis=1)(output)
        acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
        accs.append(acc)

    acc_mean = np.mean(accs)
    return acc_mean

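# Because the Bayesian layers resample weights on each pass, predictive
# uncertainty can be estimated by averaging several stochastic forward passes.
# A sketch, not used by the test below; predict_mean is a hypothetical helper:
#
#   def predict_mean(net, x, n_samples=10):
#       probs = [ops.Softmax(axis=1)(net(x)).asnumpy() for _ in range(n_samples)]
#       return np.mean(probs, axis=0)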

if __name__ == "__main__":
    network = BNNLeNet5()

    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optimizer = nn.AdamWeightDecay(params=network.trainable_params(), learning_rate=0.0001)

    # The BNN loss cell combines the task loss with the KL divergence between
    # the variational posterior and the prior of the Bayesian layers, each
    # scaled by the factors given here (60000 and 1e-6).
    net_with_loss = bnn_layers.WithBNNLossCell(network, criterion, 60000, 0.000001)
    train_bnn_network = TrainOneStepCell(net_with_loss, optimizer)
    train_bnn_network.set_train()

    train_set = create_dataset('/home/workspace/mindspore_dataset/mnist_data/train', 64, 1)
    test_set = create_dataset('/home/workspace/mindspore_dataset/mnist_data/test', 64, 1)

    epoch = 100

    for i in range(epoch):
        train_loss, train_acc = train_model(train_bnn_network, network, train_set)

        valid_acc = validate_model(network, test_set)

        print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Accuracy: {:.4f}'.format(
            i, train_loss, train_acc, valid_acc))