• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Watchpoints test script for offline debugger APIs.
"""
18
19import os
20import json
21import shutil
22import numpy as np
23import mindspore.offline_debug.dbg_services as d
24from dump_test_utils import build_dump_structure
25from tests.security_utils import security_off_wrap
26
27
class TestOfflineWatchpoints:
    """Test watchpoints for the offline debugger.

    ``setup_class`` writes a small dump structure to disk; every test then
    starts a ``DbgServices`` backend on that directory, registers watchpoints
    and checks the reported hits against the golden file
    ``watchpoints_expected.json``.

    When ``GENERATE_GOLDEN`` is True the sync tests record their hits in
    ``watchpoint_hits_json`` instead of comparing them, and the weight-change
    test writes the accumulated list to ``watchpoints_expected.json`` in the
    current working directory.
    """
    GENERATE_GOLDEN = False
    test_name = "watchpoints"
    # Shared on the class on purpose: hits recorded by different tests accumulate here
    # until print_watchpoint_hits() is asked to write them out.
    watchpoint_hits_json = []
    temp_dir = ''

    # Watch-condition codes passed to add_watchpoint().
    MIN_LT = 6
    CHANGE_TOO_LARGE = 18

    # Node names of the tensors placed in the dump structure.
    CONV_NODE = "Default/network-WithLossCell/_backbone-AlexNet/conv1-Conv2d/Conv2D-op369"
    BIAS_NODE = "Default/network-WithLossCell/_backbone-AlexNet/fc3-Dense/Parameter[6]_11/fc2.bias"
    UNIFORM_NODE = "Default/CudnnUniformReal-op391"

    # Attributes compared between an actual hit / hit parameter and its golden entry.
    _HIT_FIELDS = ("name", "slot", "condition", "watchpoint_id", "error_code", "rank_id", "root_graph_id")
    _PARAMETER_FIELDS = ("name", "disabled", "value", "hit", "actual_value")

    @classmethod
    def setup_class(cls):
        """Build the dump structure used by all tests and store its path in ``cls.temp_dir``."""
        name1 = "Conv2D.Conv2D-op369.0.0.1"
        tensor1 = np.array([[[-1.2808e-03, 7.7629e-03, 1.9241e-02],
                             [-1.3931e-02, 8.9359e-04, -1.1520e-02],
                             [-6.3248e-03, 1.8749e-03, 1.0132e-02]],
                            [[-2.5520e-03, -6.0005e-03, -5.1918e-03],
                             [-2.7866e-03, 2.5487e-04, 8.4782e-04],
                             [-4.6310e-03, -8.9111e-03, -8.1778e-05]],
                            [[1.3914e-03, 6.0844e-04, 1.0643e-03],
                             [-2.0966e-02, -1.2865e-03, -1.8692e-03],
                             [-1.6647e-02, 1.0233e-03, -4.1313e-03]]], np.float32)
        info1 = d.TensorInfo(node_name=cls.CONV_NODE,
                             slot=1, iteration=2, rank_id=0, root_graph_id=0, is_output=False)

        # The same bias parameter is dumped for iterations 2 and 3 so that the
        # weight-change watchpoint has two consecutive values to look at.
        name2 = "Parameter.fc2.bias.0.0.2"
        tensor2 = np.array([-5.0167350e-06, 1.2509107e-05, -4.3148934e-06, 8.1415592e-06,
                            2.1177532e-07, 2.9952851e-06], np.float32)
        info2 = d.TensorInfo(node_name=cls.BIAS_NODE,
                             slot=0, iteration=2, rank_id=0, root_graph_id=0, is_output=True)

        tensor3 = np.array([2.9060817e-07, -5.1009415e-06, -2.8662325e-06, 2.6036503e-06,
                            -5.1546101e-07, 6.0798648e-06], np.float32)
        info3 = d.TensorInfo(node_name=cls.BIAS_NODE,
                             slot=0, iteration=3, rank_id=0, root_graph_id=0, is_output=True)

        name3 = "CudnnUniformReal.CudnnUniformReal-op391.0.0.3"
        tensor4 = np.array([-32.0, -4096.0], np.float32)
        info4 = d.TensorInfo(node_name=cls.UNIFORM_NODE,
                             slot=0, iteration=2, rank_id=0, root_graph_id=0, is_output=False)

        tensor_info = [info1, info2, info3, info4]
        tensor_name = [name1, name2, name2, name3]
        tensor_list = [tensor1, tensor2, tensor3, tensor4]
        cls.temp_dir = build_dump_structure(tensor_name, tensor_list, "Test", tensor_info)

    @classmethod
    def teardown_class(cls):
        """Delete the dump structure created by ``setup_class``."""
        shutil.rmtree(cls.temp_dir)

    def _start_backend(self, is_sync_mode):
        """Return a ``DbgServices`` instance initialized on the test dump directory."""
        backend = d.DbgServices(dump_file_path=self.temp_dir)
        backend.initialize(net_name="Test", is_sync_mode=is_sync_mode)
        return backend

    @staticmethod
    def _node_filter(node_name, is_output):
        """Return the ``check_node_list`` dict selecting ``node_name`` on rank 0, root graph 0."""
        return {node_name: {"rank_id": [0], "root_graph_id": [0], "is_output": is_output}}

    def _add_min_lt_watchpoint(self, backend, watchpoint_id, node_name, value):
        """Register a MIN_LT watchpoint on the input of ``node_name`` with parameter ``value``."""
        param = d.Parameter(name="param", disabled=False, value=value)
        backend.add_watchpoint(watchpoint_id=watchpoint_id, watch_condition=self.MIN_LT,
                               check_node_list=self._node_filter(node_name, is_output=False),
                               parameter_list=[param])

    def _check_min_lt_not_hit(self, is_sync_mode):
        """Set a MIN_LT watchpoint that must not be hit, verify it, then remove it."""
        backend = self._start_backend(is_sync_mode)
        self._add_min_lt_watchpoint(backend, watchpoint_id=2, node_name=self.CONV_NODE, value=-1000.0)

        watchpoint_hits_test = backend.check_watchpoints(iteration=2)
        assert not watchpoint_hits_test
        backend.remove_watchpoint(watchpoint_id=2)

    @security_off_wrap
    def test_sync_add_remove_watchpoints_hit(self):
        """Sync mode: two MIN_LT watchpoints are hit; removing one leaves a single hit."""
        backend = self._start_backend(is_sync_mode=True)
        self._add_min_lt_watchpoint(backend, watchpoint_id=1, node_name=self.CONV_NODE, value=0.0)
        # add second watchpoint to check the watchpoint hit in correct order
        self._add_min_lt_watchpoint(backend, watchpoint_id=2, node_name=self.UNIFORM_NODE, value=10.0)

        watchpoint_hits_test = backend.check_watchpoints(iteration=2)
        assert len(watchpoint_hits_test) == 2
        if self.GENERATE_GOLDEN:
            # Only record the hits; the golden file is written by the weight-change test.
            self.print_watchpoint_hits(watchpoint_hits_test, 0, False)
        else:
            self.compare_expect_actual_result(watchpoint_hits_test, 0)

        backend.remove_watchpoint(watchpoint_id=1)
        watchpoint_hits_test_1 = backend.check_watchpoints(iteration=2)
        assert len(watchpoint_hits_test_1) == 1

    @security_off_wrap
    def test_sync_add_remove_watchpoints_not_hit(self):
        """Sync mode: a MIN_LT watchpoint with value -1000.0 is set, not hit, then removed."""
        self._check_min_lt_not_hit(is_sync_mode=True)

    @security_off_wrap
    def test_sync_weight_change_watchpoints_hit(self):
        """Sync mode: a CHANGE_TOO_LARGE watchpoint on the fc2.bias parameter is hit at iteration 3."""
        backend = self._start_backend(is_sync_mode=True)
        param_abs_mean_update_ratio_gt = d.Parameter(
            name="abs_mean_update_ratio_gt", disabled=False, value=0.0)
        param_epsilon = d.Parameter(name="epsilon", disabled=True, value=0.0)
        backend.add_watchpoint(watchpoint_id=3, watch_condition=self.CHANGE_TOO_LARGE,
                               check_node_list=self._node_filter(self.BIAS_NODE, is_output=True),
                               parameter_list=[param_abs_mean_update_ratio_gt, param_epsilon])

        watchpoint_hits_test = backend.check_watchpoints(iteration=3)
        assert len(watchpoint_hits_test) == 1
        if self.GENERATE_GOLDEN:
            # Index 2 places this hit after the two recorded by
            # test_sync_add_remove_watchpoints_hit; True writes the golden file.
            self.print_watchpoint_hits(watchpoint_hits_test, 2, True)
        else:
            self.compare_expect_actual_result(watchpoint_hits_test, 2)

    @security_off_wrap
    def test_async_add_remove_watchpoint_hit(self):
        """Async mode: a MIN_LT watchpoint is hit, and no hit remains after removing it."""
        backend = self._start_backend(is_sync_mode=False)
        self._add_min_lt_watchpoint(backend, watchpoint_id=1, node_name=self.CONV_NODE, value=0.0)

        watchpoint_hits_test = backend.check_watchpoints(iteration=2)
        assert len(watchpoint_hits_test) == 1
        # The hit is compared with golden entry 0, the same entry the sync test uses,
        # so nothing is recorded here when generating golden data.
        if not self.GENERATE_GOLDEN:
            self.compare_expect_actual_result(watchpoint_hits_test, 0)

        backend.remove_watchpoint(watchpoint_id=1)
        watchpoint_hits_test_1 = backend.check_watchpoints(iteration=2)
        assert not watchpoint_hits_test_1

    @security_off_wrap
    def test_async_add_remove_watchpoints_not_hit(self):
        """Async mode: a MIN_LT watchpoint with value -1000.0 is set, not hit, then removed."""
        self._check_min_lt_not_hit(is_sync_mode=False)

    def compare_expect_actual_result(self, watchpoint_hits_list, test_index):
        """Compare actual watchpoint hits with the golden file.

        Args:
            watchpoint_hits_list: hits returned by ``check_watchpoints``.
            test_index: position in the golden list of the entry expected for the
                first hit; hit ``x`` is compared with entry ``test_index + x``,
                stored under the key ``"watchpoint_hit<test_index + x + 1>"``.

        Raises:
            AssertionError: if any compared attribute differs from the golden value.
        """
        # The path is relative to the current working directory of the test run.
        golden_file = os.path.realpath(os.path.join("../data/dump/gpu_dumps/golden/",
                                                    self.test_name + "_expected.json"))
        # Load everything first so the file is closed before any assertion can fail.
        with open(golden_file, encoding="utf-8") as f:
            expected_list = json.load(f)

        for x, watchpoint_hit in enumerate(watchpoint_hits_list):
            test_id = "watchpoint_hit" + str(test_index + x + 1)
            info = expected_list[x + test_index][test_id]
            for field in self._HIT_FIELDS:
                assert getattr(watchpoint_hit, field) == info[field], \
                    "%s: '%s' differs from golden value" % (test_id, field)
            for p, actual_param in enumerate(watchpoint_hit.parameters):
                parameter = "parameter" + str(p)
                # 'paremeter' is the (misspelled) key used in the golden file; keep it as is.
                expected_param = info['paremeter'][p][parameter]
                for field in self._PARAMETER_FIELDS:
                    assert getattr(actual_param, field) == expected_param[field], \
                        "%s: %s '%s' differs from golden value" % (test_id, parameter, field)

    def print_watchpoint_hits(self, watchpoint_hits_list, test_index, is_print):
        """Record watchpoint hits in golden-file format and optionally write the file.

        Args:
            watchpoint_hits_list: hits returned by ``check_watchpoints``.
            test_index: offset used to number the entries; hit ``x`` is stored under
                the key ``"watchpoint_hit<test_index + x + 1>"``.
            is_print: when True, dump everything accumulated in
                ``watchpoint_hits_json`` to ``<test_name>_expected.json`` in the
                current working directory.
        """
        for x, watchpoint_hit in enumerate(watchpoint_hits_list):
            parameter_json = [
                {
                    "parameter" + str(p): {
                        'name': hit_param.name,
                        'disabled': hit_param.disabled,
                        'value': hit_param.value,
                        'hit': hit_param.hit,
                        'actual_value': hit_param.actual_value
                    }
                }
                for p, hit_param in enumerate(watchpoint_hit.parameters)
            ]
            hit_key = "watchpoint_hit" + str(test_index + x + 1)
            self.watchpoint_hits_json.append({
                hit_key: {
                    'name': watchpoint_hit.name,
                    'slot': watchpoint_hit.slot,
                    'condition': watchpoint_hit.condition,
                    'watchpoint_id': watchpoint_hit.watchpoint_id,
                    # Key spelling matches what compare_expect_actual_result() reads.
                    'paremeter': parameter_json,
                    'error_code': watchpoint_hit.error_code,
                    'rank_id': watchpoint_hit.rank_id,
                    'root_graph_id': watchpoint_hit.root_graph_id
                }
            })
        if is_print:
            with open(self.test_name + "_expected.json", "w", encoding="utf-8") as dump_f:
                json.dump(self.watchpoint_hits_json, dump_f, indent=4, separators=(',', ': '))
239