# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train LeNet5 for one epoch on a small MNIST subset.

The script never constructs a profiler itself. Presumably profiling is enabled
externally through the environment (the argparser description reads
"test env enable profiler"); confirm against the test that launches it.

Usage::

    python <this_file> --target <device_target> --mode <context_mode>
"""
import os
from argparse import ArgumentParser

from mindspore import dataset as ds
from mindspore import nn, Tensor, context
from mindspore.train import Accuracy
from mindspore.nn.optim import Momentum
from mindspore.dataset.transforms import transforms as C
from mindspore.dataset.vision import transforms as CV
from mindspore.dataset.vision import Inter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.train import Model


def weight_variable():
    """Return the ``TruncatedNormal(0.02)`` initializer used by every layer."""
    return TruncatedNormal(0.02)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free Conv2d with "valid" padding and truncated-normal weights.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Convolution kernel size.
        stride (int): Convolution stride. Default: 1.
        padding (int): Forwarded to ``nn.Conv2d``. Default: 0.
            NOTE(review): MindSpore appears to accept a non-zero padding only
            with ``pad_mode="pad"``; confirm before passing padding != 0 here.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Build a Dense layer whose weight and bias both use ``weight_variable()``.

    Args:
        input_channels (int): Input feature size.
        out_channels (int): Output feature size.

    Returns:
        nn.Dense: The configured fully-connected layer.
    """
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


class LeNet5(nn.Cell):
    """LeNet5 network: two conv/ReLU/max-pool stages followed by three Dense layers.

    Args:
        num_class (int): Number of output logits. Default: 10.
        channel (int): Number of input image channels. Default: 1.
    """

    def __init__(self, num_class=10, channel=1):
        """Create the layers of the network."""
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # 16 * 5 * 5 assumes 32x32 inputs, as produced by create_dataset():
        # 32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5 (all "valid").
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # Not used by construct(); kept so the Cell's attributes are unchanged.
        self.channel = Tensor(channel)

    def construct(self, data):
        """Run the forward pass and return raw (pre-softmax) logits."""
        output = self.conv1(data)
        output = self.relu(output)
        output = self.max_pool2d(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.max_pool2d(output)
        output = self.flatten(output)
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """Create the MNIST training pipeline.

    Only ``batch_size * 10`` samples are read, so the pipeline stays small.

    Args:
        data_path (str): Directory passed to ``ds.MnistDataset``.
        batch_size (int): Batch size; incomplete final batches are dropped. Default: 32.
        repeat_size (int): Number of times the batched dataset is repeated. Default: 1.
        num_parallel_workers (int): Workers used by each ``map`` stage. Default: 1.

    Returns:
        The batched and repeated MindSpore dataset.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path, num_samples=batch_size * 10)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    # Together these compute (x - 0.1307) / 0.3081, i.e. normalize with
    # mean 0.1307 and std 0.3081 (the commonly quoted MNIST statistics).
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift=0.0)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images (order matters: scale to [0, 1] before normalizing)
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


def train_with_profiler():
    """Train LeNet5 for one epoch in dataset-sink mode.

    Reads ``target`` and ``mode`` from the module-level ``args`` set by the
    ``__main__`` block below, so this function is meant to run from there.

    Raises:
        ValueError: If the training dataset yields no batches.
    """
    target = args.target
    mode = args.mode
    mnist_path = '/home/workspace/mindspore_dataset/mnist'
    context.set_context(mode=mode, device_target=target)
    ds_train = create_dataset(os.path.join(mnist_path, "train"))
    if ds_train.get_dataset_size() == 0:
        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")

    lenet = LeNet5()
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'acc': Accuracy()})
    model.train(1, ds_train, dataset_sink_mode=True)


# Guard the entry point so that importing this module (e.g. during test
# collection) neither parses foreign argv nor starts a training run.
if __name__ == "__main__":
    parser = ArgumentParser(description='test env enable profiler')
    parser.add_argument('--target', type=str)
    parser.add_argument('--mode', type=int)
    args = parser.parse_args()
    train_with_profiler()