• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Produce the dataset
17"""
18
19import mindspore.dataset as ds
20import mindspore.dataset.vision.c_transforms as CV
21import mindspore.dataset.transforms.c_transforms as C
22from mindspore.dataset.vision import Inter
23from mindspore import dtype as mstype
24
25
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST input pipeline for training or testing.

    Args:
        data_path: Location passed straight to ``ds.MnistDataset``.
        batch_size: Samples per batch; an incomplete final batch is dropped.
        repeat_size: Count handed to ``repeat`` after batching.
        num_parallel_workers: Worker count forwarded to every ``map`` call.

    Returns:
        The dataset object produced after the map, shuffle, batch and
        repeat stages have been attached.
    """
    target_size = (32, 32)

    # Two successive Rescale steps. Assuming Rescale computes
    # x * scale + shift (confirm against the MindSpore docs), the first maps
    # pixels by 1/255 and the second is equivalent to (x - 0.1307) / 0.3081.
    # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean and std.
    pixel_scale, pixel_shift = 1.0 / 255.0, 0.0
    norm_scale = 1 / 0.3081
    norm_shift = -1 * 0.1307 / 0.3081

    label_cast = C.TypeCast(mstype.int32)

    # Image transforms, listed in the order they are applied.
    image_ops = (
        CV.Resize(target_size, interpolation=Inter.LINEAR),  # Bilinear mode
        CV.Rescale(pixel_scale, pixel_shift),
        CV.Rescale(norm_scale, norm_shift),
        CV.HWC2CHW(),
    )

    pipeline = ds.MnistDataset(data_path)

    # Labels are cast first, then each image op gets its own map stage.
    pipeline = pipeline.map(operations=label_cast, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    for image_op in image_ops:
        pipeline = pipeline.map(operations=image_op, input_columns="image",
                                num_parallel_workers=num_parallel_workers)

    # Dataset-level stages: shuffle (buffer of 10000 as in the LeNet train
    # script), batch, then repeat.
    shuffle_buffer = 10000
    return (
        pipeline
        .shuffle(buffer_size=shuffle_buffer)
        .batch(batch_size, drop_remainder=True)
        .repeat(repeat_size)
    )
61