# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
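"""Test feeding a CIFAR-10 pipeline to the Ascend device queue (TDT) and check the
behavior when the number of steps consumed by the network differs from the number
of batches the dataset produces."""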
import time
import numpy as np
import pytest
from mindspore import context, nn, Tensor
from mindspore import log as logger
from mindspore.common.api import _cell_graph_executor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
import mindspore.dataset as de
from mindspore.dataset.vision import c_transforms as c_vision
from mindspore.dataset.transforms import c_transforms as c_trans


DATA_DIR = "/home/workspace/mindspore_dataset/cifar-10-verify-bin"

def dataset_cifar(dataset_path=None, batch_size=32, repeat_num=1, num_rows=9600, distribution_num=None, shard_id=None,
                  drop_remainder=True, usage=None, shuffle=False, num_workers=8, resize_size=32, pad_info=None):
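    """Build a CIFAR-10 pipeline: cast labels to int32, resize/rescale/normalize the
    images, convert them to CHW layout, then batch and repeat."""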
    if dataset_path is None:
        dataset_path = DATA_DIR

    ds = de.Cifar10Dataset(dataset_path, num_samples=num_rows, num_shards=distribution_num, shard_id=shard_id,
                           shuffle=shuffle, usage=usage, num_parallel_workers=num_workers)

    typecast_op = c_trans.TypeCast(mstype.int32)
    ds = ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=num_workers)

    image_op_list = [c_vision.Resize(resize_size),
                     c_vision.Rescale(1.0 / 255.0, 0.0),
                     c_vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                     c_vision.HWC2CHW()]
    ds = ds.map(input_columns="image", operations=image_op_list, num_parallel_workers=num_workers)

    ds = ds.batch(batch_size, drop_remainder=drop_remainder, num_parallel_workers=num_workers, pad_info=pad_info)
    ds = ds.repeat(repeat_num)

    return ds


def op_network_with_epoch(network, step_num):
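    """Run the network for step_num iterations and return the number of completed steps."""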
    iter_num = 0
    network.set_train()
    for _ in range(step_num):
        op_return = network()
        op_return = op_return.asnumpy()
        logger.info("Op_return is : %s", op_return)
        iter_num += 1
        logger.info("Iter Num : %s", iter_num)

    return iter_num


def convert_type(shapes, types):
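    """Map numpy shapes/dtypes to the corresponding MindSpore dtypes via a temporary Tensor."""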
    ms_types = []
    for np_shape, np_type in zip(shapes, types):
        input_np = np.zeros(np_shape, np_type)
        tensor = Tensor(input_np)
        ms_types.append(tensor.dtype)
    return ms_types


def get_dataset_base_value(dataset):
    dataset_size = dataset.get_dataset_size()
    batch_size = dataset.get_batch_size()
    return dataset_size, batch_size


def dataset_send_tdt(dataset):
    time.sleep(1)
    dataset.send(1)


def get_dataset_shapes_and_types(dataset):
    dataset_shapes = dataset.output_shapes()
    np_types = dataset.output_types()
    dataset_types = convert_type(dataset_shapes, np_types)
    return dataset_shapes, dataset_types


class SingleOpNetwork(nn.Cell):
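    """A single-operator network that reshapes its input to the first dataset output shape."""
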
    def __init__(self, shapes):
        super(SingleOpNetwork, self).__init__()
        self.shapes = tuple(shapes[0])
        self.Op_Reshape_network = P.Reshape()

    def construct(self, network_input):
        return self.Op_Reshape_network(network_input, self.shapes)


class NetWithTDT(nn.Cell):
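    """Wrap a network with a GetNext operator so each call pulls one batch from the device queue."""
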
    def __init__(self, network, dataset_types, dataset_shapes, shared_name=''):
        super(NetWithTDT, self).__init__()
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_shapes), shared_name)
        self.Op_network = network

    def construct(self):
        next_input, _ = self.get_next()
        return self.Op_network(next_input)


def op_network_with_step_num(dataset, step_num):
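    """Bind the dataset to the device queue, build a GetNext-fed network, start sending data,
    and run the network for step_num steps."""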
    dataset_shapes, dataset_types = get_dataset_shapes_and_types(dataset)
    _, batch_size = get_dataset_base_value(dataset)
    dataset = dataset.device_que()
    queue_name = dataset.queue_name

    net = SingleOpNetwork(dataset_shapes)
    net_with_dataset = NetWithTDT(net, dataset_types, dataset_shapes, queue_name)
    # when the device type is Davinci, the net should have a get_next operation before calling init_dataset
    _cell_graph_executor.init_dataset(dataset.queue_name, 1, batch_size, dataset_types, dataset_shapes, (), "")
    dataset_send_tdt(dataset)
    return op_network_with_epoch(net_with_dataset, step_num)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tdt_consume_beyond_produce():
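    """Consuming more steps (1000) than the dataset produces (640 rows / batch 64 = 10 batches)
    should raise a RuntimeError."""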
    context.set_context(mode=context.GRAPH_MODE)

    batch_size = 64
    repeat_num = 1
    num_rows = 640
    beyond_step_num = 1000
    ds = dataset_cifar(batch_size=batch_size, repeat_num=repeat_num, num_rows=num_rows)

    try:
        iter_num = op_network_with_step_num(ds, step_num=beyond_step_num)
        logger.info("out_iter_num:%s", iter_num)
        assert False
    except RuntimeError as e:
        logger.info("when the dataset batch num is less than the train loop, the error msg is %s", e)
        assert True


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tdt_produce_beyond_consume():
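    """Producing more batches (6400 rows / batch 64 = 100 batches) than the 10 steps consumed
    should finish normally with exactly 10 iterations."""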
    context.set_context(mode=context.GRAPH_MODE)

    batch_size = 64
    repeat_num = 1
    num_rows = 6400
    beyond_step_num = 10
    ds = dataset_cifar(batch_size=batch_size, repeat_num=repeat_num, num_rows=num_rows)

    iter_num = op_network_with_step_num(ds, step_num=beyond_step_num)
    logger.info("out_iter_num:%s", iter_num)
    assert iter_num == beyond_step_num