# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset base and LeNet."""
import os

from mindspore import dataset as ds
from mindspore.common import dtype as mstype
import mindspore.dataset.transforms as C
from mindspore.dataset.vision import Inter
import mindspore.dataset.vision as CV
from mindspore import nn, Tensor
from mindspore.common.initializer import Normal
from mindspore.ops import operations as P

# Dataset root used when the caller does not supply one.
_DEFAULT_MNIST_PATH = '/home/workspace/mindspore_dataset/mnist'


def create_mnist_dataset(mode='train', num_samples=2, batch_size=2, mnist_path=None):
    """
    Create a preprocessed, batched MNIST dataset for train or test.

    Args:
        mode (str): Sub-directory of the dataset root to read from. Default: 'train'.
        num_samples (int): Number of samples to read. Default: 2.
        batch_size (int): Number of samples per batch. Default: 2.
        mnist_path (str): Root directory of the MNIST dataset. When None, the
            built-in default location is used. Default: None.

    Returns:
        The dataset after label casting, image preprocessing and batching.
    """
    if mnist_path is None:
        mnist_path = _DEFAULT_MNIST_PATH
    num_parallel_workers = 1

    # define dataset (shuffle disabled, so samples are read in a fixed order)
    mnist_ds = ds.MnistDataset(os.path.join(mnist_path, mode), num_samples=num_samples, shuffle=False)

    resize_height, resize_width = 32, 32

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    # NOTE(review): 0.1307 / 0.3081 look like the usual MNIST mean / std - confirm.
    rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
    rescale_op = CV.Rescale(1.0 / 255.0, shift=0.0)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on labels
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)

    # apply map operations on images; the order matters:
    # resize -> scale by 1/255 -> normalize -> HWC to CHW
    for image_op in (resize_op, rescale_op, rescale_nml_op, hwc2chw_op):
        mnist_ds = mnist_ds.map(operations=image_op, input_columns="image",
                                num_parallel_workers=num_parallel_workers)

    # apply DatasetOps (incomplete trailing batch is dropped)
    mnist_ds = mnist_ds.batch(batch_size=batch_size, drop_remainder=True)

    return mnist_ds


class LeNet5(nn.Cell):
    """
    Lenet network that also records summary data while running.

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of channels. Default: 1.
        include_top (bool): Whether to build and apply the flatten and dense
            layers. When False, the output of the second pooling layer is
            returned. Default: True.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)

    """

    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            # 16 * 5 * 5 matches the feature map produced from a 32x32 input
            # (32 -> 28 -> 14 -> 10 -> 5 through the conv / pool stack).
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

        # Summary operators are used by construct() regardless of include_top,
        # so they are always created.
        self.scalar_summary = P.ScalarSummary()
        self.image_summary = P.ImageSummary()
        self.histogram_summary = P.HistogramSummary()
        self.tensor_summary = P.TensorSummary()
        self.channel = Tensor(num_channel)

    def construct(self, x):
        """Run the forward pass and record image/histogram/tensor/scalar summaries."""
        self.image_summary('image', x)
        x = self.conv1(x)
        self.histogram_summary('histogram', x)
        x = self.relu(x)
        self.tensor_summary('tensor', x)
        # NOTE(review): this second ReLU is numerically a no-op (ReLU is
        # idempotent); kept so the computation graph stays unchanged.
        x = self.relu(x)
        x = self.max_pool2d(x)
        self.scalar_summary('scalar', self.channel)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x