# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import argparse

import mindspore.context as context
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from mindspore.common.initializer import TruncatedNormal

# Parse the target device and the MNIST dataset path from the command line.
parser = argparse.ArgumentParser(description='test_ps_lenet')
parser.add_argument("--device_target", type=str, default="Ascend")
parser.add_argument("--dataset_path", type=str, default="/home/workspace/mindspore_dataset/mnist")
args, _ = parser.parse_known_args()
device_target = args.device_target
dataset_path = args.dataset_path

# Run in graph mode and enable parameter server training mode.
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
context.set_ps_context(enable_ps=True)

def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv layer with truncated-normal weight initialization."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Fully connected layer with truncated-normal weight and bias initialization."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """Truncated-normal weight initializer."""
    return TruncatedNormal(0.02)


class LeNet5(nn.Cell):
    """LeNet-5: two convolution/pooling stages followed by three fully connected layers."""

    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds

if __name__ == "__main__":
    network = LeNet5(10)
    # Mark all trainable parameters so they are stored and updated on the parameter server.
    network.set_param_ps()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # Train one epoch on the MNIST training split.
    ds_train = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
    model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)

    # Evaluate on the test split; the run is considered passing above 83% accuracy.
    ds_eval = create_dataset(os.path.join(dataset_path, "test"), 32, 1)
    acc = model.eval(ds_eval, dataset_sink_mode=False)

    print("Accuracy:", acc['Accuracy'])
    assert acc['Accuracy'] > 0.83