# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data-parallel LeNet-5 train/eval test on MNIST for MindSpore.

The ``__main__`` entry point calls ``init()``, enables ``data_parallel`` mode
and shards the MNIST dataset across ``get_group_size()`` ranks, so it is
presumably meant to be started by a distributed launcher (TODO confirm).
"""

import os
import argparse

import mindspore.context as context
import mindspore.dataset as ds
import mindspore.dataset.transforms as C
import mindspore.dataset.vision as CV
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.train import Model, LossMonitor, Accuracy
from mindspore.common.initializer import TruncatedNormal
from mindspore.communication import init, get_rank, get_group_size

parser = argparse.ArgumentParser(description='test_ps_lenet')
parser.add_argument("--device_target", type=str, default="GPU")
parser.add_argument("--dataset_path", type=str, default="/home/workspace/mindspore_dataset/mnist")
# parse_known_args ignores unrecognized command-line flags instead of erroring.
args, _ = parser.parse_known_args()
device_target = args.device_target
dataset_path = args.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)


def weight_variable():
    """Return a new ``TruncatedNormal(0.02)`` initializer.

    Used for every conv weight and every dense weight/bias in this file.
    """
    return TruncatedNormal(0.02)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free ``nn.Conv2d`` layer with ``pad_mode="valid"``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Forwarded to ``nn.Conv2d``. Because ``pad_mode`` is fixed to
            ``"valid"``, this should stay 0; MindSpore only accepts a non-zero
            padding together with ``pad_mode="pad"``.

    Returns:
        The configured ``nn.Conv2d`` with weights from :func:`weight_variable`.
    """
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Build an ``nn.Dense`` layer with weight and bias from :func:`weight_variable`.

    Args:
        input_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.

    Returns:
        The configured ``nn.Dense`` layer.
    """
    weight = weight_variable()
    bias = weight_variable()
    # Pass the initializers by keyword (as conv() does for Conv2d) so they
    # cannot be bound to the wrong parameter if Dense's positional order differs.
    return nn.Dense(input_channels, out_channels, weight_init=weight, bias_init=bias)


class LeNet5(nn.Cell):
    """Classic LeNet-5: two conv/ReLU/max-pool stages followed by three dense layers.

    The ``16 * 5 * 5`` input size of ``fc1`` implies 32x32 spatial input:
    5x5 valid conv -> 28, 2x2 pool -> 14, 5x5 valid conv -> 10, 2x2 pool -> 5.

    Args:
        num_class: Number of output logits.
        channel: Number of input image channels.
    """

    def __init__(self, num_class=10, channel=1):
        super().__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Return unnormalized class logits for the image batch ``x``."""
        # Feature extractor: (conv -> ReLU -> pool) x 2.
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Classifier head; fc3 has no activation so the loss receives raw logits.
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Create this rank's shard of the MNIST dataset for training or evaluation.

    Pipeline:
        1. Cast labels to int32.
        2. Resize images to 32x32 (bilinear).
        3. Scale pixels by 1/255.
        4. Normalize to (x - 0.1307) / 0.3081.
        5. Convert HWC -> CHW.
        6. Shuffle, batch (dropping the last partial batch) and repeat.

    Must be called after ``init()``, because it queries ``get_group_size()``
    and ``get_rank()`` to pick the shard.

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Samples per batch.
        repeat_size: Number of times the dataset is repeated.
        num_parallel_workers: Worker count for each ``map`` operation.

    Returns:
        The processed MindSpore dataset.
    """
    # Each rank reads only its own shard of the data.
    mnist_ds = ds.MnistDataset(data_path, num_shards=get_group_size(),
                               shard_id=get_rank())

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Rescale computes x * rescale + shift, so these two constants
    # implement (x - 0.1307) / 0.3081.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply map operations; the order matters (scale to [0, 1] before normalizing).
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


if __name__ == "__main__":
    # Communication must be initialized before querying the group size/rank.
    init()
    # gradients_mean=True averages gradients across ranks.
    context.set_auto_parallel_context(parallel_mode="data_parallel", gradients_mean=True, device_num=get_group_size())
    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # Train for 5 epochs.
    ds_train = create_dataset(os.path.join(dataset_path, "train"), batch_size=32, repeat_size=1)
    model.train(5, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=True)

    ds_eval = create_dataset(os.path.join(dataset_path, "test"), batch_size=32, repeat_size=1)
    acc = model.eval(ds_eval, dataset_sink_mode=True)

    print("=====Accuracy=====")
    print(acc['Accuracy'])