# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from argparse import ArgumentParser

from mindspore import dataset as ds
from mindspore import nn, Tensor, context
from mindspore.train import Accuracy
from mindspore.nn.optim import Momentum
from mindspore.dataset.transforms import transforms as C
from mindspore.dataset.vision import transforms as CV
from mindspore.dataset.vision import Inter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.train import Model


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv layer with TruncatedNormal weight initialization."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """Fully connected layer with TruncatedNormal weight and bias initialization."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """Return a TruncatedNormal initializer with sigma 0.02."""
    return TruncatedNormal(0.02)


class LeNet5(nn.Cell):
    """Define LeNet5 network."""

    def __init__(self, num_class=10, channel=1):
        """Net init."""
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
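        # With 32x32 inputs: conv1 (5x5, valid) -> 28x28, pool -> 14x14,
        # conv2 (5x5, valid) -> 10x10, pool -> 5x5, so fc1 takes 16 * 5 * 5 features.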
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.channel = Tensor(channel)

    def construct(self, data):
        """Define the forward pass."""
        output = self.conv1(data)
        output = self.relu(output)
        output = self.max_pool2d(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.max_pool2d(output)
        output = self.flatten(output)
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """Create the MNIST training dataset."""
    # define dataset (only batch_size * 10 samples are loaded)
    mnist_ds = ds.MnistDataset(data_path, num_samples=batch_size * 10)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
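    # (x / 255 - 0.1307) / 0.3081: rescale pixels to [0, 1], then normalize with
    # the commonly used MNIST mean (0.1307) and standard deviation (0.3081)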
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift=0.0)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
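    # images are now normalized CHW tensors, matching the NCHW layout expected by nn.Conv2d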

    # apply DatasetOps
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


def train_with_profiler():
    """Train Net with profiling."""
    target = args.target
    mode = args.mode
    mnist_path = '/home/workspace/mindspore_dataset/mnist'
    context.set_context(mode=mode, device_target=target)
    ds_train = create_dataset(os.path.join(mnist_path, "train"))
    if ds_train.get_dataset_size() == 0:
        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")

    lenet = LeNet5()
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'acc': Accuracy()})
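    # dataset_sink_mode=True streams batches to the device through the data sink channel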
    model.train(1, ds_train, dataset_sink_mode=True)


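# Example invocation (device target and script name are placeholders):
#   python train_lenet.py --target GPU --mode 0
# --mode 0 selects GRAPH_MODE and --mode 1 selects PYNATIVE_MODE.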
if __name__ == '__main__':
    parser = ArgumentParser(description='test env enable profiler')
    parser.add_argument('--target', type=str)
    parser.add_argument('--mode', type=int)
    args = parser.parse_args()
    train_with_profiler()