# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Validator Functions for Offline Debugger APIs.
"""
from functools import wraps

import mindspore.offline_debug.dbg_services as cds
from mindspore.offline_debug.mi_validator_helpers import parse_user_args, type_check, \
    type_check_list, check_dir, check_uint32, check_uint64, check_iteration, check_param_id


def _check_indexed_list(values, types, prefix):
    """
    Type-check every element of `values` via `type_check_list`.

    Element `i` is reported under the name "<prefix>_<i>" (e.g. "shape_0").

    Args:
        values: Sized collection whose elements are checked; `len(values)` is
            evaluated before any type check happens.
        types (tuple): Accepted element types, forwarded to `type_check_list`.
        prefix (str): Name prefix used to label each element.
    """
    names = ["{0}_{1}".format(prefix, i) for i in range(len(values))]
    type_check_list(values, types, names)


def check_init(method):
    """Wrapper method to check the parameters of DbgServices init."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # The wrapped method is expected to take exactly one user argument.
        [dump_file_path], _ = parse_user_args(method, *args, **kwargs)

        type_check(dump_file_path, (str,), "dump_file_path")
        check_dir(dump_file_path)

        return method(self, *args, **kwargs)

    return new_method


def check_initialize(method):
    """Wrapper method to check the parameters of DbgServices Initialize method."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [net_name, is_sync_mode, max_mem_usage], _ = parse_user_args(method, *args, **kwargs)

        type_check(net_name, (str,), "net_name")
        type_check(is_sync_mode, (bool,), "is_sync_mode")
        check_uint32(max_mem_usage, "max_mem_usage")

        return method(self, *args, **kwargs)

    return new_method


def check_add_watchpoint(method):
    """
    Wrapper method to check the parameters of DbgServices AddWatchpoint.

    `check_node_list` must be a dict mapping node names (str) to dicts of node
    parameters. The only accepted node parameter names are "rank_id",
    "root_graph_id" and "is_output".

    Raises:
        ValueError: If a node parameter name is not one of the accepted names.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [id_value, watch_condition, check_node_list, parameter_list], _ = \
            parse_user_args(method, *args, **kwargs)

        check_uint32(id_value, "id")
        check_uint32(watch_condition, "watch_condition")
        type_check(check_node_list, (dict,), "check_node_list")
        for node_name, node_info in check_node_list.items():
            type_check(node_name, (str,), "node_name")
            type_check(node_info, (dict,), "node_info")
            for info_name, info_param in node_info.items():
                type_check(info_name, (str,), "node parameter name")
                if info_name in ("rank_id", "root_graph_id"):
                    # Both id-style parameters go through the same helper,
                    # labelled with their own name.
                    check_param_id(info_param, info_name=info_name)
                elif info_name == "is_output":
                    type_check(info_param, (bool,), "is_output")
                else:
                    raise ValueError("Node parameter {} is not defined.".format(info_name))
        _check_indexed_list(parameter_list, (cds.Parameter,), "param")

        return method(self, *args, **kwargs)

    return new_method


def check_remove_watchpoint(method):
    """Wrapper method to check the parameters of DbgServices RemoveWatchpoint."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [id_value], _ = parse_user_args(method, *args, **kwargs)

        check_uint32(id_value, "id")

        return method(self, *args, **kwargs)

    return new_method


def check_check_watchpoints(method):
    """Wrapper method to check the parameters of DbgServices CheckWatchpoint."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [iteration], _ = parse_user_args(method, *args, **kwargs)

        check_iteration(iteration, "iteration")

        return method(self, *args, **kwargs)

    return new_method


def check_read_tensor_info(method):
    """Wrapper method to check the parameters of DbgServices ReadTensors."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [info_list], _ = parse_user_args(method, *args, **kwargs)

        # Every entry must be a cds.TensorInfo; entries are labelled info_<i>.
        _check_indexed_list(info_list, (cds.TensorInfo,), "info")

        return method(self, *args, **kwargs)

    return new_method


def check_initialize_done(method):
    """
    Wrapper method to check if Initialize is done for DbgServices.

    Raises:
        RuntimeError: If `self.initialized` is falsy when the wrapped method is called.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        if not self.initialized:
            raise RuntimeError("Initialize should be called before any other methods of DbgServices!")
        return method(self, *args, **kwargs)

    return new_method


def check_tensor_info_init(method):
    """Wrapper method to check the parameters of DbgServices TensorInfo init."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_name, slot, iteration, rank_id, root_graph_id,
         is_output], _ = parse_user_args(method, *args, **kwargs)

        type_check(node_name, (str,), "node_name")
        check_uint32(slot, "slot")
        check_iteration(iteration, "iteration")
        check_uint32(rank_id, "rank_id")
        check_uint32(root_graph_id, "root_graph_id")
        type_check(is_output, (bool,), "is_output")

        return method(self, *args, **kwargs)

    return new_method


def check_tensor_data_init(method):
    """
    Wrapper method to check the parameters of DbgServices TensorData init.

    Raises:
        ValueError: If `len(data_ptr)` differs from `data_size`.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [data_ptr, data_size, dtype, shape], _ = parse_user_args(method, *args, **kwargs)

        type_check(data_ptr, (bytes,), "data_ptr")
        check_uint64(data_size, "data_size")
        type_check(dtype, (int,), "dtype")
        _check_indexed_list(shape, (int,), "shape")

        # The declared size must match the actual number of bytes supplied.
        if len(data_ptr) != data_size:
            raise ValueError("data_ptr length ({0}) is not equal to data_size ({1}).".format(len(data_ptr), data_size))

        return method(self, *args, **kwargs)

    return new_method


def check_tensor_base_data_init(method):
    """Wrapper method to check the parameters of DbgServices TensorBaseData init."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [data_size, dtype, shape], _ = parse_user_args(method, *args, **kwargs)

        check_uint64(data_size, "data_size")
        type_check(dtype, (int,), "dtype")
        _check_indexed_list(shape, (int,), "shape")

        return method(self, *args, **kwargs)

    return new_method


def check_tensor_stat_data_init(method):
    """Wrapper method to check the parameters of DbgServices TensorStatData init."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [data_size, dtype, shape, is_bool, max_value, min_value,
         avg_value, count, neg_zero_count, pos_zero_count,
         nan_count, neg_inf_count, pos_inf_count,
         zero_count], _ = parse_user_args(method, *args, **kwargs)

        check_uint64(data_size, "data_size")
        type_check(dtype, (int,), "dtype")
        _check_indexed_list(shape, (int,), "shape")
        type_check(is_bool, (bool,), "is_bool")
        type_check(max_value, (float,), "max_value")
        type_check(min_value, (float,), "min_value")
        type_check(avg_value, (float,), "avg_value")
        type_check(count, (int,), "count")
        type_check(neg_zero_count, (int,), "neg_zero_count")
        type_check(pos_zero_count, (int,), "pos_zero_count")
        type_check(nan_count, (int,), "nan_count")
        type_check(neg_inf_count, (int,), "neg_inf_count")
        type_check(pos_inf_count, (int,), "pos_inf_count")
        type_check(zero_count, (int,), "zero_count")

        return method(self, *args, **kwargs)

    return new_method


def check_watchpoint_hit_init(method):
    """Wrapper method to check the parameters of DbgServices WatchpointHit init."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [name, slot, condition, watchpoint_id,
         parameters, error_code, rank_id, root_graph_id], _ = parse_user_args(method, *args, **kwargs)

        type_check(name, (str,), "name")
        check_uint32(slot, "slot")
        type_check(condition, (int,), "condition")
        check_uint32(watchpoint_id, "watchpoint_id")
        _check_indexed_list(parameters, (cds.Parameter,), "param")
        type_check(error_code, (int,), "error_code")
        check_uint32(rank_id, "rank_id")
        check_uint32(root_graph_id, "root_graph_id")

        return method(self, *args, **kwargs)

    return new_method


def check_parameter_init(method):
    """Wrapper method to check the parameters of DbgServices Parameter init."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [name, disabled, value, hit, actual_value], _ = parse_user_args(method, *args, **kwargs)

        type_check(name, (str,), "name")
        type_check(disabled, (bool,), "disabled")
        type_check(value, (float,), "value")
        type_check(hit, (bool,), "hit")
        type_check(actual_value, (float,), "actual_value")

        return method(self, *args, **kwargs)

    return new_method