# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test dataset profiling."""
import os
import tempfile
import glob
import pytest

import mindspore.dataset as ds
from mindspore.dataset import DSCallback
from mindspore import dtype as mstype
import mindspore.log as logger
import mindspore.dataset.transforms as transforms
import mindspore as ms
from mindspore.profiler import Profiler
from tests.security_utils import security_off_wrap

MNIST_DIR = "/home/workspace/mindspore_dataset/mnist/"
CIFAR10_DIR = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"


def create_dict_iterator(datasets):
    """Drain the dataset once and return the number of rows produced."""
    count = 0
    for _ in datasets.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    return count


class PrintInfo(DSCallback):
    """Dataset callback that logs pipeline, epoch, and step events."""

    @staticmethod
    def ds_begin(ds_run_context):
        """ds_begin"""
        logger.info("callback: start dataset pipeline, epoch %d", ds_run_context.cur_epoch_num)

    @staticmethod
    def ds_epoch_begin(ds_run_context):
        """ds_epoch_begin"""
        logger.info("callback: epoch begin, we are in epoch %d", ds_run_context.cur_epoch_num)

    @staticmethod
    def ds_epoch_end(ds_run_context):
        """ds_epoch_end"""
        logger.info("callback: epoch end, we are in epoch %d", ds_run_context.cur_epoch_num)

    @staticmethod
    def ds_step_begin(ds_run_context):
        """ds_step_begin"""
        logger.info("callback: step begin, we are in step %d", ds_run_context.cur_step_num)

    @staticmethod
    def ds_step_end(ds_run_context):
        """ds_step_end"""
        logger.info("callback: step end, we are in step %d", ds_run_context.cur_step_num)


def add_one_by_epoch(batchinfo):
    """Batch-size function: the batch size equals the current epoch number plus one."""
    return batchinfo.get_epoch_num() + 1


def other_method_dataset():
    """Exercise assorted dataset APIs so the profiler has pipeline activity to record."""
    path_base = os.path.split(os.path.realpath(__file__))[0]

    # Dynamic batch size computed per epoch by a batch-size function.
    data = list(range(10))
    dataset = ds.GeneratorDataset(data, "column1")
    dataset = dataset.batch(batch_size=add_one_by_epoch)
    create_dict_iterator(dataset)

    # Map operation with a DSCallback attached.
    dataset = ds.GeneratorDataset([1, 2], "col1", shuffle=False, num_parallel_workers=1)
    dataset = dataset.map(operations=lambda x: x, callbacks=PrintInfo())
    create_dict_iterator(dataset)

    # Schema construction and column parsing.
    schema = ds.Schema()
    schema.add_column(name='col1', de_type=mstype.int64, shape=[2])
    columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
                {'name': 'label', 'type': 'int8', 'shape': [1]}]
    schema.parse_columns(columns1)

    # Compare two dataset pipelines.
    pipeline1 = ds.MnistDataset(MNIST_DIR, num_samples=100)
    pipeline2 = ds.Cifar10Dataset(CIFAR10_DIR, num_samples=100)
    ds.compare(pipeline1, pipeline2)

    # Serialize, show, and deserialize an MNIST pipeline.
    dataset = ds.MnistDataset(MNIST_DIR, num_samples=100)
    one_hot_encode = transforms.OneHot(10)
    dataset = dataset.map(operations=one_hot_encode, input_columns="label")
    dataset = dataset.batch(batch_size=10, drop_remainder=True)
    ds.serialize(dataset, json_filepath=os.path.join(path_base, "mnist_dataset_pipeline.json"))
    ds.show(dataset)
    serialized_data = ds.serialize(dataset)
    ds.deserialize(input_dict=serialized_data)
    return dataset


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@security_off_wrap
def test_ascend_dataset_profiler():
    """
    Feature: Test the dataset profiling.
    Description: Traverse the dataset data, perform data preprocessing, and then verify the collected profiling data.
    Expectation: Exactly one dataset_iterator_profiling file is generated.
    """
    ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
    with tempfile.TemporaryDirectory() as tmpdir:
        profiler = Profiler(output_path=tmpdir)
        other_method_dataset()
        profiler.analyse()
        assert len(glob.glob(f"{tmpdir}/profiler*/dataset_iterator_profiling_*.txt")) == 1