• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Custom Logger."""
16import os
17import sys
18import logging
19from datetime import datetime
20
# Names exported by ``from <this module> import *``.
__all__ = [
    "get_logger",
]

# Shared Logger instance; stays None until get_logger() first creates it.
GLOBAL_LOGGER = None
24
25
class Logger(logging.Logger):
    """
    Logger that can print records to the console and write them to files.

    Args:
         logger_name(str): Name given to the logger, usually the name of the network. Default: "fasterrcnn".
    """

    def __init__(self, logger_name="fasterrcnn"):
        super().__init__(logger_name)
        # Rank of this process, taken from the environment (see _get_rank_id).
        self.rank_id = _get_rank_id()
        # Defaults; setup_logging() may override log_level, rank_id and device_per_servers.
        self.log_level = "INFO"
        self.device_per_servers = 8
        # Shared by the console and file handlers added later.
        self.formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
40
41
def setup_logging(logger_name="fasterrcnn", log_level="INFO", rank_id=None, device_per_servers=8):
    """
    Configure the global logger and attach a console handler on the printing rank.

    Calling this function again replaces the console handler installed by the previous call
    instead of adding a second one.

    Args:
        logger_name (str): Name assigned to the global logger. Default: "fasterrcnn".
        log_level (str): One of 'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'. Default: "INFO".
        rank_id (int, optional): Rank of this process. When None, the logger's current rank_id
            (initially read from the environment) is kept. Default: None.
        device_per_servers (int): Number of devices per server. Only ranks divisible by this
            value print to the console. Must be positive. Default: 8.

    Raises:
        ValueError: If `log_level` is not supported or `device_per_servers` is not positive.
            The logger is left unchanged in that case.
    """
    # Validate before touching the shared logger so a bad call cannot leave it half-configured
    # (e.g. with an unsupported level that later breaks setup_logging_file()).
    if log_level not in ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]:
        raise ValueError(
            f"Not support log_level: {log_level}, "
            f"the log_level should be in ['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG']"
        )
    if device_per_servers <= 0:
        raise ValueError(f"device_per_servers should be a positive integer, but got {device_per_servers}")

    logger = get_logger()
    logger.name = logger_name
    logger.log_level = log_level
    if rank_id is not None:
        logger.rank_id = rank_id
    logger.device_per_servers = device_per_servers

    # Remove the console handler added by an earlier call; otherwise each repeated
    # setup would print every message one extra time.
    previous_console = getattr(logger, "_console_handler", None)
    if previous_console is not None:
        logger.removeHandler(previous_console)
        logger._console_handler = None

    # In the distributed scenario, only one card per server is printed on the console.
    if logger.rank_id % logger.device_per_servers == 0:
        console = logging.StreamHandler(sys.stdout)
        console.setLevel(logger.log_level)
        console.setFormatter(logger.formatter)
        logger.addHandler(console)
        logger._console_handler = console
63
64
def setup_logging_file(log_dir="./logs"):
    """
    Attach a file handler to the global logger.

    The file is named ``<logger name>_<YYYY-MM-DD_time_HH_MM_SS>_rank_<rank_id>.log`` and is
    created inside `log_dir`, which is created first when missing.

    Args:
        log_dir (str): Directory for the log file. An empty string means the current
            working directory. Default: "./logs".

    Returns:
        str, the path of the log file.
    """
    logger = get_logger()
    # exist_ok=True already tolerates an existing directory, so no separate exists() check
    # (and its check-then-create race) is needed. os.makedirs("") raises, so skip it for
    # the current directory.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Generate a file stream based on the log generation time and rank_id
    timestamp = datetime.now().strftime("%Y-%m-%d_time_%H_%M_%S")
    log_name = f"{logger.name}_{timestamp}_rank_{logger.rank_id}.log"
    log_path = os.path.join(log_dir, log_name)
    # Explicit UTF-8 so non-ASCII messages do not depend on the platform's locale encoding.
    file_handler = logging.FileHandler(log_path, encoding="utf-8")
    file_handler.setLevel(logger.log_level)
    file_handler.setFormatter(logger.formatter)
    logger.addHandler(file_handler)
    return log_path
78
79
def print_args(args):
    """
    Log every attribute of `args` as ``--> name: value`` under an "Args:" header.

    A blank line is logged afterwards to separate the block from later output.

    Args:
        args: Any object accepted by ``vars()``, e.g. an ``argparse.Namespace``.
    """
    logger = get_logger()
    logger.info("Args:")
    for key, value in vars(args).items():
        logger.info("--> %s: %s", key, value)
    logger.info("")
87
88
def _build_important_msg(msg):
    """Return `msg` framed by the asterisk banner used by important_info()."""
    border_rows = 2
    side = "*" * 2
    banner = ("*" * 70 + "\n") * border_rows
    padding = (side + "\n") * 2
    # str() so non-string messages (numbers, exceptions) work, as they do with logger.info.
    return "\n" + banner + padding + side + " " * 8 + str(msg) + "\n" + padding + banner


def important_info(msg, *args, **kwargs):
    """
    Log `msg` with severity 'INFO', surrounded by an asterisk banner so it stands out.

    Args:
        msg: Message to log; converted with ``str()``. May contain %-style placeholders.
        *args: Values merged into `msg` with %-formatting by the logging module.
        **kwargs: Forwarded unchanged to ``Logger.info``.
    """
    get_logger().info(_build_important_msg(msg), *args, **kwargs)
99
100
def info(msg, *args, **kwargs):
    """
    Emit `msg` at 'INFO' severity through the shared global logger.

    Examples:
        >>> logger.setup_logging(logger_name="fasterrcnn", log_level="INFO", rank_id=0, device_per_servers=8)
        >>> logger.setup_logging_file(log_dir="./logs")
        >>> logger.info("test info")
    """
    target = get_logger()
    target.info(msg, *args, **kwargs)
111
112
def debug(msg, *args, **kwargs):
    """Emit `msg` at 'DEBUG' severity through the shared global logger."""
    target = get_logger()
    target.debug(msg, *args, **kwargs)
116
117
def error(msg, *args, **kwargs):
    """Emit `msg` at 'ERROR' severity through the shared global logger."""
    target = get_logger()
    target.error(msg, *args, **kwargs)
121
122
def warning(msg, *args, **kwargs):
    """Emit `msg` at 'WARNING' severity through the shared global logger."""
    target = get_logger()
    target.warning(msg, *args, **kwargs)
126
127
def critical(msg, *args, **kwargs):
    """Emit `msg` at 'CRITICAL' severity through the shared global logger."""
    target = get_logger()
    target.critical(msg, *args, **kwargs)
131
132
def get_level():
    """
    Return the level name stored on the global logger.

    Returns:
        str, the logger's ``log_level`` ('INFO' unless changed by setup_logging()).
    """
    current_logger = get_logger()
    return current_logger.log_level
143
144
def _get_rank_id():
    """
    Read this process's rank from the environment.

    RANK_ID takes precedence over OMPI_COMM_WORLD_RANK, and a notice is printed when both
    are set to different values. Falls back to 0 when neither is set (or both are empty).

    Returns:
        int, the rank id.

    Raises:
        ValueError: If the selected environment value is not an integer string.
    """
    env_rank = os.getenv("RANK_ID")
    ompi_rank = os.getenv("OMPI_COMM_WORLD_RANK")
    if env_rank and ompi_rank and env_rank != ompi_rank:
        print(
            f"Environment variables RANK_ID and OMPI_COMM_WORLD_RANK set by different values, RANK_ID={env_rank}, "
            f"OMPI_COMM_WORLD_RANK={ompi_rank}. We will use RANK_ID to get rank id by default.",
            flush=True,
        )
    # Empty strings are treated as unset, exactly like the unset case.
    return int(env_rank or ompi_rank or "0")
161
162
def get_logger():
    """
    Return the process-wide Logger, creating it on first use.

    Returns:
        Logger, the shared logger instance stored in GLOBAL_LOGGER.
    """
    global GLOBAL_LOGGER
    if GLOBAL_LOGGER is None:
        GLOBAL_LOGGER = Logger()
    return GLOBAL_LOGGER
170