# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameter-server LeNet-5 test on MNIST.

Builds LeNet-5, marks its parameters for parameter-server training via
``set_param_ps()``, trains for one epoch, evaluates on the MNIST test split
and asserts that the reported accuracy exceeds 0.83.
"""

import os
import argparse

import mindspore.context as context
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from mindspore.common.initializer import TruncatedNormal

parser = argparse.ArgumentParser(description='test_ps_lenet')
parser.add_argument("--device_target", type=str, default="Ascend")
parser.add_argument("--dataset_path", type=str, default="/home/workspace/mindspore_dataset/mnist")
# parse_known_args: unrecognised command-line flags are ignored instead of rejected.
args, _ = parser.parse_known_args()
device_target, dataset_path = args.device_target, args.dataset_path

# Graph mode on the requested device, with parameter-server mode switched on.
context.set_context(device_target=device_target, mode=context.GRAPH_MODE)
context.set_ps_context(enable_ps=True)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Return a bias-free, "valid"-padded Conv2d using truncated-normal weights."""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
                     pad_mode="valid",
                     has_bias=False,
                     weight_init=weight_variable())


def fc_with_initialize(input_channels, out_channels):
    """Return a Dense layer whose weight and bias both use truncated-normal init."""
    w_init, b_init = weight_variable(), weight_variable()
    return nn.Dense(input_channels, out_channels, w_init, b_init)


def weight_variable():
    """Return a fresh TruncatedNormal initializer with sigma 0.02."""
    sigma = 0.02
    return TruncatedNormal(sigma)


class LeNet5(nn.Cell):
    """LeNet-5: two conv -> ReLU -> max-pool stages, then three dense layers.

    Args:
        num_class (int): number of output logits. Default: 10.
        channel (int): number of input image channels. Default: 1.
    """

    def __init__(self, num_class=10, channel=1):
        super().__init__()
        self.num_class = num_class
        # Feature extractor.
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # Classifier. 16 * 5 * 5 is the flattened conv2/pool output for a
        # 32x32 input (32 -> 28 -> 14 -> 10 -> 5 per spatial side).
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST pipeline used for both training and evaluation.

    Labels are cast to int32. Images are resized to 32x32 (bilinear),
    multiplied by 1/255, normalised with mean 0.1307 / std 0.3081 and
    transposed from HWC to CHW. The result is shuffled (buffer 10000),
    batched with the final partial batch dropped, and repeated.

    Args:
        data_path (str): directory passed to ``ds.MnistDataset``.
        batch_size (int): samples per batch. Default: 32.
        repeat_size (int): number of dataset repetitions. Default: 1.
        num_parallel_workers (int): workers for each map op. Default: 1.

    Returns:
        The transformed MindSpore dataset.
    """
    mnist_ds = ds.MnistDataset(data_path)

    mean, std = 0.1307, 0.3081
    # Each entry is (operation, column); they are mapped in exactly this order.
    transforms = (
        (C.TypeCast(mstype.int32), "label"),
        (CV.Resize((32, 32), interpolation=Inter.LINEAR), "image"),  # Bilinear mode
        (CV.Rescale(1.0 / 255.0, 0.0), "image"),
        # (x - mean) / std written as x * (1 / std) + (-mean / std)
        (CV.Rescale(1 / std, -1 * mean / std), "image"),
        (CV.HWC2CHW(), "image"),
    )
    for operation, column in transforms:
        mnist_ds = mnist_ds.map(operations=operation, input_columns=column,
                                num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.shuffle(buffer_size=10000)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    return mnist_ds.repeat(repeat_size)


def _run_ps_lenet():
    """Train LeNet-5 for one epoch in PS mode, evaluate it and check accuracy."""
    network = LeNet5(10)
    network.set_param_ps()
    model = Model(network,
                  nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean"),
                  nn.Momentum(network.trainable_params(), 0.01, 0.9),
                  metrics={"Accuracy": Accuracy()})

    train_set = create_dataset(os.path.join(dataset_path, "train"), 32, 1)
    model.train(1, train_set, callbacks=[LossMonitor()], dataset_sink_mode=False)

    eval_set = create_dataset(os.path.join(dataset_path, "test"), 32, 1)
    accuracy = model.eval(eval_set, dataset_sink_mode=False)['Accuracy']

    print("Accuracy:", accuracy)
    assert accuracy > 0.83


if __name__ == "__main__":
    _run_ps_lenet()