• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import os
17import argparse
18
19import mindspore.context as context
20import mindspore.dataset as ds
21import mindspore.dataset.transforms as C
22import mindspore.dataset.vision as CV
23import mindspore.nn as nn
24from mindspore.common import dtype as mstype
25from mindspore.dataset.vision import Inter
26from mindspore.train import Model, LossMonitor, Accuracy
27from mindspore.common.initializer import TruncatedNormal
28from mindspore.communication import init, get_rank, get_group_size
29
parser = argparse.ArgumentParser(description='test_ps_lenet')
parser.add_argument("--device_target", type=str, default="GPU")
parser.add_argument("--dataset_path", type=str,
                    default="/home/workspace/mindspore_dataset/mnist")
# parse_known_args (not parse_args): unrecognised command-line flags are
# ignored instead of aborting the script.
args, _ = parser.parse_known_args()
device_target, dataset_path = args.device_target, args.dataset_path
# Executed at import time: select graph mode on the requested backend.
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
37
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free ``nn.Conv2d`` with truncated-normal weights.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride. Default: 1.
        padding: Explicit zero padding. With the default of 0 the layer uses
            ``pad_mode="valid"`` (unchanged behaviour); any other value
            switches to ``pad_mode="pad"`` so the padding is actually applied.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    weight = weight_variable()
    # MindSpore only honours ``padding`` in "pad" mode and requires padding
    # to be 0 in "valid" mode, so derive the mode from the argument instead
    # of hard-coding "valid" (which made every non-zero padding an error).
    pad_mode = "valid" if padding == 0 else "pad"
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode=pad_mode)
44
45
def fc_with_initialize(input_channels, out_channels):
    """Create an ``nn.Dense`` layer whose weight and bias both start from a
    fresh truncated-normal initializer produced by ``weight_variable``."""
    return nn.Dense(input_channels, out_channels,
                    weight_init=weight_variable(),
                    bias_init=weight_variable())
51
52
def weight_variable():
    """Return a new truncated-normal initializer with sigma 0.02."""
    sigma = 0.02
    return TruncatedNormal(sigma)
56
57
class LeNet5(nn.Cell):
    """LeNet-5 classifier.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed
    three fully connected layers. ``fc1`` expects ``16 * 5 * 5`` features,
    which matches 32x32 inputs through the valid-padded convolutions
    (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        num_class: Number of output logits. Default: 10.
        channel: Number of input image channels. Default: 1.
    """

    def __init__(self, num_class=10, channel=1):
        super().__init__()
        self.num_class = num_class
        # Convolutional feature extractor.
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # Fully connected head.
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, num_class)
        # Parameter-free layers reused at several stages.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # Two conv -> ReLU -> pool stages.
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        # Dense head; fc3 output is returned without an activation.
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
85
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the sharded MNIST pipeline used for both training and evaluation.

    Each process loads only shard ``get_rank()`` of ``get_group_size()``;
    the ``__main__`` block calls ``init()`` before using this function.

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Samples per batch; a trailing partial batch is dropped.
        repeat_size: Number of times the batched dataset is repeated.
        num_parallel_workers: Worker count for each ``map`` stage.

    Returns:
        The shuffled, batched and repeated dataset.
    """
    mnist_ds = ds.MnistDataset(data_path, num_shards=get_group_size(),
                               shard_id=get_rank())

    # Labels: cast to int32.
    mnist_ds = mnist_ds.map(operations=C.TypeCast(mstype.int32), input_columns="label",
                            num_parallel_workers=num_parallel_workers)

    # Images: applied in list order.
    image_pipeline = [
        CV.Resize((32, 32), interpolation=Inter.LINEAR),  # Bilinear mode
        CV.Rescale(1.0 / 255.0, 0.0),                     # scale pixels by 1/255
        CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081),     # (x - 0.1307) / 0.3081
        CV.HWC2CHW(),
    ]
    mnist_ds = mnist_ds.map(operations=image_pipeline, input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    # 10000-sample shuffle buffer, as in LeNet train script.
    return (mnist_ds.shuffle(buffer_size=10000)
            .batch(batch_size, drop_remainder=True)
            .repeat(repeat_size))
122
if __name__ == "__main__":
    # Initialise communication before anything queries group size / rank
    # (the parallel context below and create_dataset both do).
    init()
    context.set_auto_parallel_context(parallel_mode="data_parallel",
                                      gradients_mean=True,
                                      device_num=get_group_size())

    network = LeNet5(num_class=10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
    model = Model(network, loss_fn=net_loss, optimizer=net_opt,
                  metrics={"Accuracy": Accuracy()})

    # Train for 5 epochs with batch size 32.
    train_set = create_dataset(os.path.join(dataset_path, "train"), batch_size=32, repeat_size=1)
    model.train(5, train_set, callbacks=[LossMonitor()], dataset_sink_mode=True)

    eval_set = create_dataset(os.path.join(dataset_path, "test"), batch_size=32, repeat_size=1)
    acc = model.eval(eval_set, dataset_sink_mode=True)

    print("=====Accuracy=====")
    print(acc['Accuracy'])
139