# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Test dataset profiling."""
16import os
17import tempfile
18import glob
19import pytest
20
21import mindspore.dataset as ds
22from mindspore.dataset import DSCallback
23from mindspore import dtype as mstype
24import mindspore.log as logger
25import mindspore.dataset.transforms as transforms
26import mindspore as ms
27from mindspore.profiler import Profiler
28from tests.security_utils import security_off_wrap
29
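# The test environment is expected to provide MNIST and CIFAR-10 at these fixed paths.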
MNIST_DIR = "/home/workspace/mindspore_dataset/mnist/"
CIFAR10_DIR = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"


def create_dict_iterator(datasets):
    """Drain the dataset through a dict iterator so the pipeline executes; return the row count."""
    count = 0
    for _ in datasets.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    return count


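# Dataset pipeline callback: each hook below fires at a pipeline, epoch, or step boundary.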
class PrintInfo(DSCallback):
    """Log a message at each dataset callback point."""

    @staticmethod
    def ds_begin(ds_run_context):
        """ds_begin"""
        logger.info("callback: start dataset pipeline, epoch: %d", ds_run_context.cur_epoch_num)

    @staticmethod
    def ds_epoch_begin(ds_run_context):
        """ds_epoch_begin"""
        logger.info("callback: epoch begin, we are in epoch %d", ds_run_context.cur_epoch_num)

    @staticmethod
    def ds_epoch_end(ds_run_context):
        """ds_epoch_end"""
        logger.info("callback: epoch end, we are in epoch %d", ds_run_context.cur_epoch_num)

    @staticmethod
    def ds_step_begin(ds_run_context):
        """ds_step_begin"""
        logger.info("callback: step begin, we are in step %d", ds_run_context.cur_step_num)

    @staticmethod
    def ds_step_end(ds_run_context):
        """ds_step_end"""
        logger.info("callback: step end, we are in step %d", ds_run_context.cur_step_num)


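# Used below as a callable batch_size: dataset.batch(batch_size=add_one_by_epoch).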
def add_one_by_epoch(batchinfo):
    """Return the batch size for the current epoch: epoch number + 1."""
    return batchinfo.get_epoch_num() + 1


def other_method_dataset():
    """Run assorted dataset APIs so the profiler has pipelines to record."""
    path_base = os.path.split(os.path.realpath(__file__))[0]
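    # Dynamic batching: the batch_size callable is evaluated per batch from BatchInfo.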
    data = list(range(10))
    dataset = ds.GeneratorDataset(data, "column1")
    dataset = dataset.batch(batch_size=add_one_by_epoch)
    create_dict_iterator(dataset)

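    # Attach the DSCallback above to a map operation; its hooks log pipeline progress.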
    dataset = ds.GeneratorDataset([1, 2], "col1", shuffle=False, num_parallel_workers=1)
    dataset = dataset.map(operations=lambda x: x, callbacks=PrintInfo())
    create_dict_iterator(dataset)

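    # Build a schema programmatically, then parse a column description list into it.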
    schema = ds.Schema()
    schema.add_column(name='col1', de_type=mstype.int64, shape=[2])
    columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
                {'name': 'label', 'type': 'int8', 'shape': [1]}]
    schema.parse_columns(columns1)

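    # Compare two dataset pipelines for equality.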
    pipeline1 = ds.MnistDataset(MNIST_DIR, num_samples=100)
    pipeline2 = ds.Cifar10Dataset(CIFAR10_DIR, num_samples=100)
    ds.compare(pipeline1, pipeline2)

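    # Serialize an MNIST pipeline to JSON, show its structure, and round-trip it through deserialize.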
    dataset = ds.MnistDataset(MNIST_DIR, num_samples=100)
    one_hot_encode = transforms.OneHot(10)
    dataset = dataset.map(operations=one_hot_encode, input_columns="label")
    dataset = dataset.batch(batch_size=10, drop_remainder=True)
    ds.serialize(dataset, json_filepath=os.path.join(path_base, "mnist_dataset_pipeline.json"))
    ds.show(dataset)
    serialized_data = ds.serialize(dataset)
    ds.deserialize(input_dict=serialized_data)
    return dataset


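# Ascend-only, single-card test; per its name, security_off_wrap guards it on security-enabled builds.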
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@security_off_wrap
def test_ascend_dataset_profiler():
    """
    Feature: Dataset profiling.
    Description: Traverse the dataset, perform data preprocessing, and then verify the collected profiling data.
    Expectation: Exactly one dataset_iterator_profiling file is generated.
    """
    ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
    with tempfile.TemporaryDirectory() as tmpdir:
        profiler = Profiler(output_path=tmpdir)
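        # Run the dataset pipelines while the profiler is collecting, then flush results to tmpdir.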
        other_method_dataset()
        profiler.analyse()
        assert len(glob.glob(f"{tmpdir}/profiler*/dataset_iterator_profiling_*.txt")) == 1
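
# For hands-on debugging, the collected file could be dumped inside the `with`
# block above, before tmpdir is cleaned up. A minimal sketch (not part of the test):
#     for txt in glob.glob(f"{tmpdir}/profiler*/dataset_iterator_profiling_*.txt"):
#         with open(txt) as f:
#             logger.info(f.read())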