• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""
This module defines the subclasses of DebugHook for the minddata pipeline debugger.
All these classes are predefined for users for basic debugging purposes.
18"""
19
20import collections
21import numpy as np
22from PIL import Image
23from mindspore.dataset.debug.debug_hook import DebugHook
24
25
class PrintMetaDataHook(DebugHook):
    """
    Debug hook used for MindData debug mode to print type and shape of data.

    For every column one line is printed: the dtype and shape of a NumPy
    array, the type and ``size`` of a PIL image, the type and length of any
    other sized object, or just the type of anything else (for example a
    Python or NumPy scalar).
    """
    def __init__(self):
        super().__init__()

    def compute(self, *args):
        """
        Print the metadata of each column and pass the input through.

        Args:
            *args (Any): Either a single sequence holding all columns, or the
                columns passed as separate positional arguments.

        Returns:
            tuple, the positional arguments, unchanged.
        """
        # A single positional argument is treated as the sequence of columns
        # (the only form the former ``enumerate(*args)`` accepted). Any other
        # arity is treated as one column per argument instead of making
        # ``enumerate`` raise a confusing TypeError.
        columns = args[0] if len(args) == 1 else args
        for col_idx, col in enumerate(columns):
            log_message = "Column {}. ".format(col_idx)

            # log shape/size
            if isinstance(col, np.ndarray):
                log_message += "The dtype is [{}].".format(col.dtype)
                log_message += " The shape is [{}].".format(col.shape)
            elif isinstance(col, Image.Image):
                # PIL exposes its dimensions as ``size`` == (width, height).
                log_message += "The type is [{}].".format(type(col))
                log_message += " The shape is [{}].".format(col.size)
            elif isinstance(col, collections.abc.Sized):
                log_message += "The type is [{}].".format(type(col))
                log_message += " The size is [{}].".format(len(col))
            else:
                # Unsized values (e.g. scalars) previously printed a bare
                # "Column N. " line; report at least their type.
                log_message += "The type is [{}].".format(type(col))
            print(log_message, flush=True)
        return args
49
50
class PrintDataHook(DebugHook):
    """
    Debug hook used for MindData debug mode to print data.

    PIL images are converted with ``np.asarray`` before printing, so their
    pixel values are shown instead of the image object's own text form.
    """
    def __init__(self):
        super().__init__()

    def compute(self, *args):
        """
        Print the content of each column and pass the input through.

        Args:
            *args (Any): Either a single sequence holding all columns, or the
                columns passed as separate positional arguments.

        Returns:
            tuple, the positional arguments, unchanged.
        """
        # A single positional argument is treated as the sequence of columns
        # (the only form the former ``enumerate(*args)`` accepted). Any other
        # arity is treated as one column per argument instead of making
        # ``enumerate`` raise a confusing TypeError.
        columns = args[0] if len(args) == 1 else args
        for col_idx, col in enumerate(columns):
            data = np.asarray(col) if isinstance(col, Image.Image) else col
            print("Column {}. The data is [{}].".format(col_idx, data), flush=True)
        return args
68