# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger.

Provides a process-wide ``Logger`` singleton (see ``get_logger``) together with
module-level helpers that attach console/file handlers to it and forward messages.
"""
import os
import sys
import logging
from datetime import datetime

__all__ = ["get_logger"]

GLOBAL_LOGGER = None

# Level names accepted by setup_logging; handlers receive them verbatim via setLevel().
_VALID_LOG_LEVELS = ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG")


class Logger(logging.Logger):
    """
    Logger classes and functions, support print information on console and files.

    Handlers are attached through ``setup_logging`` (console) and
    ``setup_logging_file`` (file).

    Args:
        logger_name(str): The name of Logger. In most cases, it can be the name of the network.
    """

    def __init__(self, logger_name="fasterrcnn"):
        super().__init__(logger_name)
        # Level name applied to every handler created by the setup_* helpers.
        self.log_level = "INFO"
        # Initially read from the RANK_ID / OMPI_COMM_WORLD_RANK environment variables.
        self.rank_id = _get_rank_id()
        # Only ranks with rank_id % device_per_servers == 0 get a console handler.
        self.device_per_servers = 8
        self.formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
        # Console handler installed by setup_logging. It is tracked so that a repeated
        # call replaces it instead of stacking duplicates (which would print every
        # message several times).
        self._console_handler = None


def setup_logging(logger_name="fasterrcnn", log_level="INFO", rank_id=None, device_per_servers=8):
    """Configure the global logger and (on the console rank) attach a stdout handler.

    Calling this again replaces the console handler installed by the previous call.

    Args:
        logger_name (str): Name given to the global logger; also used in log file names.
        log_level (str): One of 'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'.
        rank_id (int, optional): Rank of this process. When None, the logger's current
            rank_id is kept (initially read from RANK_ID / OMPI_COMM_WORLD_RANK).
        device_per_servers (int): Only ranks with ``rank_id % device_per_servers == 0``
            print to the console.

    Raises:
        ValueError: If ``log_level`` is not supported. The logger is left untouched in that case.
    """
    # Validate before mutating the logger so an invalid level can never be stored
    # and later handed to Handler.setLevel() by setup_logging_file().
    if log_level not in _VALID_LOG_LEVELS:
        raise ValueError(
            f"Not support log_level: {log_level}, "
            f"the log_level should be in ['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG']"
        )

    logger = get_logger()
    logger.name = logger_name
    logger.log_level = log_level
    if rank_id is not None:
        logger.rank_id = rank_id
    logger.device_per_servers = device_per_servers

    # Remove the console handler from a previous call to avoid duplicated console output.
    previous = getattr(logger, "_console_handler", None)
    if previous is not None:
        logger.removeHandler(previous)
        previous.close()
        logger._console_handler = None

    # In the distributed scenario, only one card is printed on the console.
    if logger.rank_id % logger.device_per_servers == 0:
        console = logging.StreamHandler(sys.stdout)
        console.setLevel(logger.log_level)
        console.setFormatter(logger.formatter)
        logger.addHandler(console)
        logger._console_handler = console


def setup_logging_file(log_dir="./logs"):
    """Attach a file handler to the global logger.

    The file is created inside ``log_dir`` (created if missing). Its name is built
    from the logger name, the current local time and the rank id, e.g.
    ``fasterrcnn_2024-01-01_time_12_00_00_rank_0.log``.

    Args:
        log_dir (str): Directory that receives the log file.
    """
    logger = get_logger()
    os.makedirs(log_dir, exist_ok=True)

    # Generate a file stream based on the log generation time and rank_id
    timestamp = datetime.now().strftime("%Y-%m-%d_time_%H_%M_%S")
    log_name = f"{logger.name}_{timestamp}_rank_{logger.rank_id}.log"
    log_path = os.path.join(log_dir, log_name)
    # Explicit UTF-8 so non-ASCII messages are written identically on every platform locale.
    file_handler = logging.FileHandler(log_path, encoding="utf-8")
    file_handler.setLevel(logger.log_level)
    file_handler.setFormatter(logger.formatter)
    logger.addHandler(file_handler)


def print_args(args):
    """Log every attribute of ``args`` (anything accepted by ``vars()``) as a hyper-parameter."""
    logger = get_logger()
    logger.info("Args:")
    for key, value in vars(args).items():
        logger.info("--> %s: %s", key, value)
    logger.info("")


def important_info(msg, *args, **kwargs):
    """Log ``msg`` at INFO level, framed by a banner of '*' so it stands out.

    ``msg`` must be a str; ``args``/``kwargs`` are forwarded to ``Logger.info``.
    """
    line_width = 2
    border = ("*" * 70 + "\n") * line_width
    margin = ("*" * line_width + "\n") * 2
    important_msg = "\n" + border + margin + "*" * line_width + " " * 8 + msg + "\n" + margin + border
    get_logger().info(important_msg, *args, **kwargs)


def info(msg, *args, **kwargs):
    """
    Log a message with severity 'INFO' on the logger.

    Examples:
        >>> logger.setup_logging(logger_name="fasterrcnn", log_level="INFO", rank_id=0, device_per_servers=8)
        >>> logger.setup_logging_file(log_dir="./logs")
        >>> logger.info("test info")
    """
    get_logger().info(msg, *args, **kwargs)


def debug(msg, *args, **kwargs):
    """Log a message with severity 'DEBUG' on the logger."""
    get_logger().debug(msg, *args, **kwargs)


def error(msg, *args, **kwargs):
    """Log a message with severity 'ERROR' on the logger."""
    get_logger().error(msg, *args, **kwargs)


def warning(msg, *args, **kwargs):
    """Log a message with severity 'WARNING' on the logger."""
    get_logger().warning(msg, *args, **kwargs)


def critical(msg, *args, **kwargs):
    """Log a message with severity 'CRITICAL' on the logger."""
    get_logger().critical(msg, *args, **kwargs)


def get_level():
    """
    Get the logger level.

    Returns:
        str, the Log level includes 'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'.
    """
    return get_logger().log_level


def _get_rank_id():
    """Return this process's rank id, read from the environment.

    RANK_ID takes precedence over OMPI_COMM_WORLD_RANK; a notice is printed when both
    are set to different values. Returns 0 when neither is set (or both are empty).
    """
    rank_id = os.getenv("RANK_ID")
    gpu_rank_id = os.getenv("OMPI_COMM_WORLD_RANK")
    if rank_id and gpu_rank_id and rank_id != gpu_rank_id:
        print(
            f"Environment variables RANK_ID and OMPI_COMM_WORLD_RANK set by different values, RANK_ID={rank_id}, "
            f"OMPI_COMM_WORLD_RANK={gpu_rank_id}. We will use RANK_ID to get rank id by default.",
            flush=True,
        )
    return int(rank_id or gpu_rank_id or "0")


def get_logger():
    """Return the process-wide Logger, creating it on first use."""
    global GLOBAL_LOGGER
    if GLOBAL_LOGGER is None:
        GLOBAL_LOGGER = Logger()
    return GLOBAL_LOGGER