# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import sys
import numpy as np

import mindspore.context as context
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
import mindspore.nn as nn
from mindspore.common.api import _cell_graph_executor
from mindspore.common.tensor import Tensor
from mindspore.dataset.vision import Inter
from mindspore.ops import operations as P

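# run in graph mode on Ascend; the dataset directory is passed as the first CLI argument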
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
data_path = sys.argv[1]
SCHEMA_DIR = "{0}/resnet_all_datasetSchema.json".format(data_path)


def test_me_de_train_dataset():
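    """Build the host-side pipeline: TFRecord read, decode, resize, rescale, HWC2CHW, batch."""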
    data_list = ["{0}/train-00001-of-01024.data".format(data_path)]
    data_set_new = ds.TFRecordDataset(data_list, schema=SCHEMA_DIR,
                                      columns_list=["image/encoded", "image/class/label"])

    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width),
                              Inter.LINEAR)  # bilinear interpolation is the default
    rescale_op = vision.Rescale(rescale, shift)

    # apply map operations on images
    data_set_new = data_set_new.map(operations=decode_op, input_columns="image/encoded")
    data_set_new = data_set_new.map(operations=resize_op, input_columns="image/encoded")
    data_set_new = data_set_new.map(operations=rescale_op, input_columns="image/encoded")
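    # transpose decoded images from HWC to the CHW layout expected by the network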
    hwc2chw_op = vision.HWC2CHW()
    data_set_new = data_set_new.map(operations=hwc2chw_op, input_columns="image/encoded")
    data_set_new = data_set_new.repeat(1)
    # apply batch operations
    batch_size_new = 32
    data_set_new = data_set_new.batch(batch_size_new, drop_remainder=True)
    return data_set_new


def convert_type(shapes, types):
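    """Map numpy (shape, dtype) pairs to the corresponding MindSpore dtypes via zero-filled Tensors."""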
    ms_types = []
    for np_shape, np_type in zip(shapes, types):
        input_np = np.zeros(np_shape, np_type)
        tensor = Tensor(input_np)
        ms_types.append(tensor.dtype)
    return ms_types


if __name__ == '__main__':
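    # build the host-side dataset and derive the shapes/dtypes the device queue will carry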
    data_set = test_me_de_train_dataset()
    dataset_size = data_set.get_dataset_size()
    batch_size = data_set.get_batch_size()

    dataset_shapes = data_set.output_shapes()
    np_types = data_set.output_types()
    dataset_types = convert_type(dataset_shapes, np_types)

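    # create the device queue; GetNext reads (image, label) batches from it inside the graph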
    ds1 = data_set.device_que()
    get_next = P.GetNext(dataset_types, dataset_shapes, 2, ds1.queue_name)
    tadd = P.ReLU()


    class dataiter(nn.Cell):
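        """Fetch one batch from the device queue and apply ReLU to the image tensor."""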

        def construct(self):
            input_, _ = get_next()
            return tadd(input_)


    net = dataiter()
    net.set_train()

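    # bind the queue to the graph executor and start pushing batches to the device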
    _cell_graph_executor.init_dataset(ds1.queue_name, 39, batch_size,
                                      dataset_types, dataset_shapes, (), 'dataset')
    ds1.send()

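    # check that each batch read through the device queue matches the host-side batch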
    for data in data_set.create_tuple_iterator(output_numpy=True, num_epochs=1):
        output = net()
        print(data[0].any())
        print(
            "****************************************************************************************************")
        d = output.asnumpy()
        print(d)
        print(
            "end+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
            d.any())

        assert (
            (data[0] == d).all()), "TDT test execution failed, please check the current code commit"
    print(
        "+++++++++++++++++++++++++++++++++++[INFO] Success+++++++++++++++++++++++++++++++++++++++++++")