# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device-queue (TDT) round-trip check on Ascend.

Usage: python <this_script> <data_dir>

``<data_dir>`` must contain ``resnet_all_datasetSchema.json`` and
``train-00001-of-01024.data``. The script sends a preprocessed TFRecord
pipeline to a device queue, reads every batch back through ``GetNext`` followed
by ``ReLU``, and asserts that each result equals the same batch produced by a
host-side iterator over the same pipeline.
"""
import sys

import numpy as np

import mindspore.context as context
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
import mindspore.nn as nn
from mindspore.common.api import _cell_graph_executor
from mindspore.common.tensor import Tensor
from mindspore.dataset.vision import Inter
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# The data directory is the first command-line argument (see module docstring).
data_path = sys.argv[1]
SCHEMA_DIR = "{0}/resnet_all_datasetSchema.json".format(data_path)


def test_me_de_train_dataset():
    """Build the host-side pipeline that is sent to, and compared against, the device.

    Reads the ``image/encoded`` and ``image/class/label`` columns from one
    TFRecord shard under ``data_path``. It then decodes the images, resizes
    them to 224x224 and rescales them by 1/255 with shift 0. Finally it
    converts HWC to CHW, repeats once and batches by 32, dropping a trailing
    partial batch.

    Returns:
        The batched MindSpore dataset.
    """
    data_list = ["{0}/train-00001-of-01024.data".format(data_path)]
    # NOTE(review): TFRecordDataset is built with its default ``shuffle``
    # setting. main() consumes this pipeline twice (device queue and host
    # iterator) and compares them batch by batch. If that default shuffles
    # rows with an unseeded RNG, the two orders could differ - confirm.
    data_set_new = ds.TFRecordDataset(data_list, schema=SCHEMA_DIR,
                                      columns_list=["image/encoded", "image/class/label"])

    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width),
                              Inter.LINEAR)  # Bilinear as default
    rescale_op = vision.Rescale(rescale, shift)
    hwc2chw_op = vision.HWC2CHW()

    # apply map operations on images, one map per op, in this order
    data_set_new = data_set_new.map(operations=decode_op, input_columns="image/encoded")
    data_set_new = data_set_new.map(operations=resize_op, input_columns="image/encoded")
    data_set_new = data_set_new.map(operations=rescale_op, input_columns="image/encoded")
    data_set_new = data_set_new.map(operations=hwc2chw_op, input_columns="image/encoded")
    data_set_new = data_set_new.repeat(1)

    # apply batch operations
    batch_size_new = 32
    data_set_new = data_set_new.batch(batch_size_new, drop_remainder=True)
    return data_set_new


def convert_type(shapes, types):
    """Map per-column numpy dtypes to MindSpore dtypes.

    Each MindSpore dtype is obtained by wrapping a zero-filled numpy array of
    the given shape and dtype in a ``Tensor`` and reading its ``dtype``.

    Args:
        shapes: per-column shapes, paired positionally with ``types``.
        types: per-column numpy dtypes.

    Returns:
        A list of MindSpore dtypes, one per (shape, type) pair. Extra
        entries in the longer argument are ignored, as ``zip`` does.
    """
    ms_types = []
    for np_shape, np_type in zip(shapes, types):
        input_np = np.zeros(np_shape, np_type)
        tensor = Tensor(input_np)
        ms_types.append(tensor.dtype)
    return ms_types


class DataIter(nn.Cell):
    """Network that pulls one batch from a device queue and applies ReLU to its first column.

    Args:
        dataset_types: MindSpore dtypes of the queue's columns.
        dataset_shapes: shapes of the queue's columns.
        queue_name: name of the device queue to read from.
    """

    def __init__(self, dataset_types, dataset_shapes, queue_name):
        super().__init__()
        # One GetNext output per dataset column. construct() unpacks exactly
        # two outputs (image, label), matching the pipeline's two columns.
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
        self.relu = P.ReLU()

    def construct(self):
        input_, _ = self.get_next()
        return self.relu(input_)


def main():
    """Send the pipeline to the device queue and check each device batch equals the host batch."""
    data_set = test_me_de_train_dataset()
    dataset_size = data_set.get_dataset_size()
    batch_size = data_set.get_batch_size()

    dataset_shapes = data_set.output_shapes()
    np_types = data_set.output_types()
    dataset_types = convert_type(dataset_shapes, np_types)

    ds1 = data_set.device_que()
    net = DataIter(dataset_types, dataset_shapes, ds1.queue_name)
    net.set_train()

    # Size the device loop from the pipeline rather than a hard-coded batch
    # count, so it matches the number of host batches iterated below.
    _cell_graph_executor.init_dataset(ds1.queue_name, dataset_size, batch_size,
                                      dataset_types, dataset_shapes, (), 'dataset')
    ds1.send()

    # Assuming decoded pixels are non-negative, Rescale(1/255, 0) keeps them
    # non-negative. ReLU is then the identity, so each device batch should
    # equal the corresponding host batch exactly.
    for data in data_set.create_tuple_iterator(output_numpy=True, num_epochs=1):
        output = net()
        print(data[0].any())
        print(
            "****************************************************************************************************")
        d = output.asnumpy()
        print(d)
        print(
            "end+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
            d.any())

        # array_equal also rejects a shape mismatch instead of broadcasting.
        assert np.array_equal(data[0], d), "TDT test execute failed, please check current code commit"
    print(
        "+++++++++++++++++++++++++++++++++++[INFO] Success+++++++++++++++++++++++++++++++++++++++++++")


if __name__ == '__main__':
    main()