# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import numpy as np
import pytest
from mindspore import context, nn, Tensor
from mindspore import log as logger
from mindspore.common.api import _cell_graph_executor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
import mindspore.dataset as de
from mindspore.dataset.vision import c_transforms as c_vision
from mindspore.dataset.transforms import c_transforms as c_trans


DATA_DIR = "/home/workspace/mindspore_dataset/cifar-10-verify-bin"


def dataset_cifar(dataset_path=None, batch_size=32, repeat_num=1, num_rows=9600, distribution_num=None, shard_id=None,
                  drop_remainder=True, usage=None, shuffle=False, num_workers=8, resize_size=32, pad_info=None):
    """Build a CIFAR-10 pipeline: typecast the label, resize/rescale/normalize the image, then batch and repeat."""
    if dataset_path is None:
        dataset_path = DATA_DIR

    ds = de.Cifar10Dataset(dataset_path, num_samples=num_rows, num_shards=distribution_num, shard_id=shard_id,
                           shuffle=shuffle, usage=usage, num_parallel_workers=num_workers)

    typecast_op = c_trans.TypeCast(mstype.int32)
    ds = ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=num_workers)

    image_op_list = [c_vision.Resize(resize_size),
                     c_vision.Rescale(1.0 / 255.0, 0.0),
                     c_vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                     c_vision.HWC2CHW()]
    ds = ds.map(input_columns="image", operations=image_op_list, num_parallel_workers=num_workers)

    ds = ds.batch(batch_size, drop_remainder=drop_remainder, num_parallel_workers=num_workers, pad_info=pad_info)
    ds = ds.repeat(repeat_num)

    return ds


def op_network_with_epoch(network, step_num):
    """Run the network step_num times and return the number of completed iterations."""
    iter_num = 0
    network.set_train()
    for _ in range(step_num):
        op_return = network()
        op_return = op_return.asnumpy()
        logger.info("Op_return is: %s", op_return)
        iter_num += 1
        logger.info("Iter Num: %s", iter_num)

    return iter_num


def convert_type(shapes, types):
    """Convert numpy dtypes to MindSpore dtypes by round-tripping a zero array through Tensor."""
    ms_types = []
    for np_shape, np_type in zip(shapes, types):
        input_np = np.zeros(np_shape, np_type)
        tensor = Tensor(input_np)
        ms_types.append(tensor.dtype)
    return ms_types


def get_dataset_base_value(dataset):
    dataset_size = dataset.get_dataset_size()
    batch_size = dataset.get_batch_size()
    return dataset_size, batch_size


def dataset_send_tdt(dataset):
    # Give the consumer a moment to initialize, then push one epoch into the device queue.
    time.sleep(1)
    dataset.send(1)


def get_dataset_shapes_and_types(dataset):
    dataset_shapes = dataset.output_shapes()
    np_types = dataset.output_types()
    dataset_types = convert_type(dataset_shapes, np_types)
    return dataset_shapes, dataset_types


class SingleOpNetwork(nn.Cell):
    """A one-op network that reshapes its input to the first dataset output shape."""

    def __init__(self, shapes):
        super(SingleOpNetwork, self).__init__()
        self.shapes = tuple(shapes[0])
        self.Op_Reshape_network = P.Reshape()

    def construct(self, network_input):
        return self.Op_Reshape_network(network_input, self.shapes)


class NetWithTDT(nn.Cell):
    """Wrap a network with a GetNext op so it pulls its input from the TDT device queue."""

    def __init__(self, network, dataset_types, dataset_shapes, shared_name=''):
        super(NetWithTDT, self).__init__()
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_shapes), shared_name)
        self.Op_network = network

    def construct(self):
        next_input, _ = self.get_next()
        return self.Op_network(next_input)


def op_network_with_step_num(dataset, step_num):
    dataset_shapes, dataset_types = get_dataset_shapes_and_types(dataset)
    _, batch_size = get_dataset_base_value(dataset)
    dataset = dataset.device_que()
    queue_name = dataset.queue_name

    net = SingleOpNetwork(dataset_shapes)
    net_with_dataset = NetWithTDT(net, dataset_types, dataset_shapes, queue_name)
    # When the device type is Davinci, the net must contain a GetNext operation before init_dataset is called.
    _cell_graph_executor.init_dataset(dataset.queue_name, 1, batch_size, dataset_types, dataset_shapes, (), "")
    dataset_send_tdt(dataset)
    return op_network_with_epoch(net_with_dataset, step_num)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tdt_consume_beyond_produce():
    # The dataset yields 640 / 64 = 10 batches, but the consumer asks for 1000 steps,
    # so GetNext must eventually fail with a RuntimeError.
    context.set_context(mode=context.GRAPH_MODE)

    batch_size = 64
    repeat_num = 1
    num_rows = 640
    beyond_step_num = 1000
    ds = dataset_cifar(batch_size=batch_size, repeat_num=repeat_num, num_rows=num_rows)

    try:
        iter_num = op_network_with_step_num(ds, step_num=beyond_step_num)
        logger.info("out_iter_num:%s", iter_num)
        assert False
    except RuntimeError as e:
        logger.info("when the dataset batch num is less than the train loop count, error msg is %s", e)
        assert True


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tdt_produce_beyond_consume():
    # The dataset yields 6400 / 64 = 100 batches while the consumer only takes 10 steps,
    # so all requested iterations must succeed.
    context.set_context(mode=context.GRAPH_MODE)

    batch_size = 64
    repeat_num = 1
    num_rows = 6400
    beyond_step_num = 10
    ds = dataset_cifar(batch_size=batch_size, repeat_num=repeat_num, num_rows=num_rows)

    iter_num = op_network_with_step_num(ds, step_num=beyond_step_num)
    logger.info("out_iter_num:%s", iter_num)
    assert iter_num == 10