• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Validator Functions for Offline Debugger APIs.
17"""
18from functools import wraps
19
20import mindspore.offline_debug.dbg_services as cds
21from mindspore.offline_debug.mi_validator_helpers import parse_user_args, type_check, \
22    type_check_list, check_dir, check_uint32, check_uint64, check_iteration, check_param_id
23
24
def check_init(method):
    """Decorator validating the argument of DbgServices init.

    The only user argument, ``dump_file_path``, is type-checked against ``str`` and then
    handed to ``check_dir``; afterwards ``method`` is called with the original arguments.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        (dump_file_path,) = user_args

        type_check(dump_file_path, (str,), "dump_file_path")
        check_dir(dump_file_path)

        return method(self, *args, **kwargs)

    return validated
38
39
def check_initialize(method):
    """Decorator validating the arguments of the DbgServices Initialize method.

    ``net_name`` must pass the ``str`` type check, ``is_sync_mode`` the ``bool`` type check,
    and ``max_mem_usage`` is validated with ``check_uint32`` before ``method`` is invoked.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        net_name, is_sync_mode, max_mem_usage = user_args

        type_check(net_name, (str,), "net_name")
        type_check(is_sync_mode, (bool,), "is_sync_mode")
        check_uint32(max_mem_usage, "max_mem_usage")

        return method(self, *args, **kwargs)

    return validated
54
55
def check_add_watchpoint(method):
    """Decorator validating the arguments of DbgServices AddWatchpoint.

    ``id_value`` and ``watch_condition`` go through ``check_uint32``. ``check_node_list`` must be
    a dict mapping node names (``str``) to per-node dicts whose keys are limited to ``rank_id``,
    ``root_graph_id`` and ``is_output``; any other key raises ``ValueError``. Every entry of
    ``parameter_list`` is type-checked against ``cds.Parameter``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        id_value, watch_condition, check_node_list, parameter_list = user_args

        check_uint32(id_value, "id")
        check_uint32(watch_condition, "watch_condition")
        type_check(check_node_list, (dict,), "check_node_list")

        for node_name, node_info in check_node_list.items():
            type_check(node_name, (str,), "node_name")
            type_check(node_info, (dict,), "node_info")

            for info_name, info_param in node_info.items():
                type_check(info_name, (str,), "node parameter name")

                if info_name == "rank_id":
                    check_param_id(info_param, info_name="rank_id")
                    continue
                if info_name == "root_graph_id":
                    check_param_id(info_param, info_name="root_graph_id")
                    continue
                if info_name == "is_output":
                    type_check(info_param, (bool,), "is_output")
                    continue

                # Only the three keys handled above are recognised node parameters.
                raise ValueError("Node parameter {} is not defined.".format(info_name))

        # One label per element so a failing entry can be reported by position.
        param_labels = list(map("param_{0}".format, range(len(parameter_list))))
        type_check_list(parameter_list, (cds.Parameter,), param_labels)

        return method(self, *args, **kwargs)

    return validated
85
86
def check_remove_watchpoint(method):
    """Decorator validating the argument of DbgServices RemoveWatchpoint.

    The single user argument ``id_value`` is validated with ``check_uint32`` (reported under
    the name ``id``) before ``method`` is called with the original arguments.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        (id_value,) = user_args

        check_uint32(id_value, "id")

        return method(self, *args, **kwargs)

    return validated
99
100
def check_check_watchpoints(method):
    """Decorator validating the argument of DbgServices CheckWatchpoint.

    The single user argument ``iteration`` is validated with ``check_iteration`` before
    ``method`` is called with the original arguments.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        (iteration,) = user_args

        check_iteration(iteration, "iteration")

        return method(self, *args, **kwargs)

    return validated
113
114
def check_read_tensor_info(method):
    """Decorator validating the argument of DbgServices ReadTensors.

    Every element of the single user argument ``info_list`` is type-checked against
    ``cds.TensorInfo`` before ``method`` is called with the original arguments.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        (info_list,) = user_args

        # One label per element so a failing entry can be reported by position.
        info_labels = list(map("info_{0}".format, range(len(info_list))))
        type_check_list(info_list, (cds.TensorInfo,), info_labels)

        return method(self, *args, **kwargs)

    return validated
128
129
def check_initialize_done(method):
    """Wrapper method to check if initialize is done for DbgServices.

    The returned wrapper reads ``self.initialized`` on every call. When it is falsy a
    ``RuntimeError`` is raised and ``method`` is not invoked; otherwise the call is forwarded
    with the original arguments and the result of ``method`` is returned.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        if not self.initialized:
            # Spell the required call correctly so the message tells the user what to do.
            raise RuntimeError("Initialize should be called before any other methods of DbgServices!")
        return method(self, *args, **kwargs)

    return new_method
141
142
def check_tensor_info_init(method):
    """Decorator validating the arguments of DbgServices TensorInfo init.

    ``node_name`` is type-checked as ``str`` and ``is_output`` as ``bool``; ``slot``, ``rank_id``
    and ``root_graph_id`` go through ``check_uint32`` and ``iteration`` through ``check_iteration``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        node_name, slot, iteration, rank_id, root_graph_id, is_output = user_args

        type_check(node_name, (str,), "node_name")
        check_uint32(slot, "slot")
        check_iteration(iteration, "iteration")
        for value, label in ((rank_id, "rank_id"), (root_graph_id, "root_graph_id")):
            check_uint32(value, label)
        type_check(is_output, (bool,), "is_output")

        return method(self, *args, **kwargs)

    return validated
161
162
def check_tensor_data_init(method):
    """Decorator validating the arguments of DbgServices TensorData init.

    ``data_ptr`` is type-checked as ``bytes``, ``data_size`` goes through ``check_uint64``,
    ``dtype`` is type-checked as ``int`` and every element of ``shape`` as ``int``. A
    ``ValueError`` is raised when the length of ``data_ptr`` differs from ``data_size``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        data_ptr, data_size, dtype, shape = user_args

        type_check(data_ptr, (bytes,), "data_ptr")
        check_uint64(data_size, "data_size")
        type_check(dtype, (int,), "dtype")
        # One label per dimension so a failing entry can be reported by position.
        shape_labels = list(map("shape_{0}".format, range(len(shape))))
        type_check_list(shape, (int,), shape_labels)

        # The declared size has to agree with the buffer that was actually supplied.
        actual_size = len(data_ptr)
        if actual_size != data_size:
            raise ValueError("data_ptr length ({0}) is not equal to data_size ({1}).".format(actual_size, data_size))

        return method(self, *args, **kwargs)

    return validated
182
183
def check_tensor_base_data_init(method):
    """Decorator validating the arguments of DbgServices TensorBaseData init.

    ``data_size`` goes through ``check_uint64``, ``dtype`` is type-checked as ``int`` and every
    element of ``shape`` is type-checked as ``int`` before ``method`` is invoked.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        data_size, dtype, shape = user_args

        check_uint64(data_size, "data_size")
        type_check(dtype, (int,), "dtype")
        # One label per dimension so a failing entry can be reported by position.
        shape_labels = list(map("shape_{0}".format, range(len(shape))))
        type_check_list(shape, (int,), shape_labels)

        return method(self, *args, **kwargs)

    return validated
199
200
def check_tensor_stat_data_init(method):
    """Decorator validating the arguments of DbgServices TensorStatData init.

    ``data_size`` goes through ``check_uint64``; ``dtype`` and every ``shape`` element are
    type-checked as ``int``; ``is_bool`` as ``bool``; the max/min/avg values as ``float``; and
    all counters as ``int``. Checks run in the order of the parameters.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        (data_size, dtype, shape, is_bool,
         max_value, min_value, avg_value,
         count, neg_zero_count, pos_zero_count,
         nan_count, neg_inf_count, pos_inf_count, zero_count) = user_args

        check_uint64(data_size, "data_size")
        type_check(dtype, (int,), "dtype")
        # One label per dimension so a failing entry can be reported by position.
        shape_labels = list(map("shape_{0}".format, range(len(shape))))
        type_check_list(shape, (int,), shape_labels)
        type_check(is_bool, (bool,), "is_bool")

        float_stats = (
            (max_value, "max_value"),
            (min_value, "min_value"),
            (avg_value, "avg_value"),
        )
        for value, label in float_stats:
            type_check(value, (float,), label)

        int_counters = (
            (count, "count"),
            (neg_zero_count, "neg_zero_count"),
            (pos_zero_count, "pos_zero_count"),
            (nan_count, "nan_count"),
            (neg_inf_count, "neg_inf_count"),
            (pos_inf_count, "pos_inf_count"),
            (zero_count, "zero_count"),
        )
        for value, label in int_counters:
            type_check(value, (int,), label)

        return method(self, *args, **kwargs)

    return validated
231
232
def check_watchpoint_hit_init(method):
    """Decorator validating the arguments of DbgServices WatchpointHit init.

    ``name`` is type-checked as ``str``; ``condition`` and ``error_code`` as ``int``; ``slot``,
    ``watchpoint_id``, ``rank_id`` and ``root_graph_id`` go through ``check_uint32``; every entry
    of ``parameters`` is type-checked against ``cds.Parameter``.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        (name, slot, condition, watchpoint_id,
         parameters, error_code, rank_id, root_graph_id) = user_args

        type_check(name, (str,), "name")
        check_uint32(slot, "slot")
        type_check(condition, (int,), "condition")
        check_uint32(watchpoint_id, "watchpoint_id")

        # One label per element so a failing entry can be reported by position.
        param_labels = list(map("param_{0}".format, range(len(parameters))))
        type_check_list(parameters, (cds.Parameter,), param_labels)

        type_check(error_code, (int,), "error_code")
        for value, label in ((rank_id, "rank_id"), (root_graph_id, "root_graph_id")):
            check_uint32(value, label)

        return method(self, *args, **kwargs)

    return validated
254
255
def check_parameter_init(method):
    """Decorator validating the arguments of DbgServices Parameter init.

    ``name`` is type-checked as ``str``, ``disabled`` and ``hit`` as ``bool``, and ``value`` and
    ``actual_value`` as ``float``. Checks run in the order of the parameters.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        name, disabled, value, hit, actual_value = user_args

        expectations = (
            (name, (str,), "name"),
            (disabled, (bool,), "disabled"),
            (value, (float,), "value"),
            (hit, (bool,), "hit"),
            (actual_value, (float,), "actual_value"),
        )
        for argument, allowed_types, label in expectations:
            type_check(argument, allowed_types, label)

        return method(self, *args, **kwargs)

    return validated
272