# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module defines the class for minddata pipeline debugger.
class DebugHook is not exposed to users as an external API.
"""

from abc import ABC, abstractmethod


class DebugHook(ABC):
    """
    The base class for Dataset Pipeline Python Debugger hook. All user defined hook behaviors
    must inherit this base class.

    To debug the input and output data of `map` operation in dataset pipeline, users can add
    a breakpoint in the `compute` method, or print types and shapes of the data.

    Args:
        prev_op_name (str, optional): name of the operation before current debugging point. Default: ``None``.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import mindspore.dataset.debug as debug
        >>> import mindspore.dataset.vision as vision
        >>>
        >>> class CustomizedHook(debug.DebugHook):
        ...     def __init__(self):
        ...         super().__init__()
        ...
        ...     def compute(self, *args):
        ...         import pdb
        ...         pdb.set_trace()
        ...         print("Data after decode", *args)
        ...         return args
        >>>
        >>> # Enable debug mode
        >>> ds.config.set_debug_mode(True, debug_hook_list=[CustomizedHook()])
        >>>
        >>> # Define dataset pipeline
        >>> dataset = ds.ImageFolderDataset(dataset_dir="/path/to/image_folder_dataset_directory")
        >>> # Insert debug hook after `Decode` operation.
        >>> dataset = dataset.map([vision.Decode(), CustomizedHook(), vision.CenterCrop(100)])
    """

    def __init__(self, prev_op_name=None):
        # Name of the operation this hook reports on; falsy means "no header line".
        self.prev_op_name = prev_op_name
        # Truthy -> the data seen is the INPUT of `prev_op_name`; otherwise its OUTPUT.
        self.is_first_op = None

    def __call__(self, *args):
        """
        Print which operation's data is being inspected, pass the data to `compute`,
        and return the data unchanged so the hook is transparent in a `map` operation list.

        Args:
            *args (Any): The data flowing through this debugging point.

        Returns:
            tuple, the received `args`, unmodified.
        """
        # If the hook is inserted directly into a map operation list, like
        # [Decode(), debug_fun(), Resize()], it has no `prev_op_name`,
        # so the common header print is skipped.
        if self.prev_op_name:
            if self.is_first_op:
                log_message = "[Dataset debugger] Print the [INPUT] of the operation [{}].".format(self.prev_op_name)
            else:
                log_message = "[Dataset debugger] Print the [OUTPUT] of the operation [{}].".format(self.prev_op_name)
            print(log_message, flush=True)

        ######################## NOTE ########################
        # Add a breakpoint to the following line to inspect
        # input and output of each transform.
        ######################################################
        # The whole data tuple is deliberately passed as ONE positional argument,
        # so inside `compute(self, *args)` the data is `args[0]`. Keep this calling
        # convention: subclasses are written against it. The return value of
        # `compute` is ignored; the original `args` are passed through.
        self.compute(args)
        return args

    @abstractmethod
    def compute(self, *args):
        """
        Defines the debug behaviour to be performed. This method must be overridden by all subclasses.
        Refer to the example above to define a customized hook.

        Args:
            *args (Any): The input/output of the operation. `__call__` passes the whole data
                tuple as a single positional argument, so the data is available as `args[0]`.
                Any return value is ignored by `__call__`.

        Raises:
            RuntimeError: If a subclass calls this base implementation.
        """
        raise RuntimeError("compute() is not overridden in subclass of class DebugHook.")

    def set_previous_op_name(self, prev_op_name):
        """Set the name of the operation this hook reports on (shown in the printed header)."""
        self.prev_op_name = prev_op_name

    def set_is_first(self, is_first_op):
        """Mark whether this hook inspects the INPUT (truthy) or the OUTPUT (falsy) of the operation."""
        self.is_first_op = is_first_op