# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Watchpoints test script for offline debugger APIs.
"""

import mindspore.offline_debug.dbg_services as d
import pytest
from dump_test_utils import compare_actual_with_expected
from tests.security_utils import security_off_wrap

GENERATE_GOLDEN = False
test_name = "sync_trans_false_watchpoints"


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.skip(reason="needs updating")
@security_off_wrap
def test_sync_trans_false_watchpoints():
    """Exercise add/check/remove of watchpoints on an offline AlexNet dump.

    The results are written to ``<test_name>.actual``. When GENERATE_GOLDEN
    is set, they are written to ``<test_name>.expected`` instead. Outside
    golden-generation mode, the written file is then compared against the
    golden file with ``compare_actual_with_expected``.
    """
    out_path = test_name + (".expected" if GENERATE_GOLDEN else ".actual")

    # The context manager guarantees the report file is closed even if a
    # debugger call raises, and that it is fully flushed before the
    # golden-file comparison below reads it back.
    with open(out_path, "w", encoding="utf-8") as f_write:
        debugger_backend = d.DbgServices(
            dump_file_path="/home/workspace/mindspore_dataset/dumps/sync_trans_false/alexnet/")

        _ = debugger_backend.initialize(
            net_name="Network Name goes here!", is_sync_mode=True)

        # NOTES:
        # -> watch_condition=6 is MIN_LT
        # -> watch_condition=18 is CHANGE_TOO_LARGE

        # test 1: watchpoint set and hit (watch_condition=6)
        param1 = d.Parameter(name="param", disabled=False, value=0.0)
        _ = debugger_backend.add_watchpoint(watchpoint_id=1, watch_condition=6,
                                            check_node_list={"Default/network-WithLossCell/_backbone-AlexNet/conv3-Conv2d/"
                                                             "Conv2D-op168":
                                                                 {"device_id": [0], "root_graph_id": [0],
                                                                  "is_parameter": False
                                                                  }}, parameter_list=[param1])

        watchpoint_hits_test_1 = debugger_backend.check_watchpoints(iteration=2)
        if len(watchpoint_hits_test_1) != 1:
            f_write.write("ERROR -> test 1: watchpoint set but not hit just once\n")
        print_watchpoint_hits(watchpoint_hits_test_1, 1, f_write)

        # test 2: watchpoint remove and ensure it's not hit
        _ = debugger_backend.remove_watchpoint(watchpoint_id=1)
        watchpoint_hits_test_2 = debugger_backend.check_watchpoints(iteration=2)
        if watchpoint_hits_test_2:
            f_write.write("ERROR -> test 2: watchpoint removed but hit\n")

        # test 3: watchpoint set and not hit, then remove
        param2 = d.Parameter(name="param", disabled=False, value=-1000.0)
        _ = debugger_backend.add_watchpoint(watchpoint_id=2, watch_condition=6,
                                            check_node_list={"Default/network-WithLossCell/_backbone-AlexNet/conv3-Conv2d/"
                                                             "Conv2D-op308":
                                                                 {"device_id": [0], "root_graph_id": [0],
                                                                  "is_parameter": False
                                                                  }}, parameter_list=[param2])

        watchpoint_hits_test_3 = debugger_backend.check_watchpoints(iteration=2)
        if watchpoint_hits_test_3:
            f_write.write("ERROR -> test 3: watchpoint set but not supposed to be hit\n")
        _ = debugger_backend.remove_watchpoint(watchpoint_id=2)

        # test 4: weight change watchpoint set and hit
        param_abs_mean_update_ratio_gt = d.Parameter(
            name="abs_mean_update_ratio_gt", disabled=False, value=0.0)
        param_epsilon = d.Parameter(name="epsilon", disabled=True, value=0.0)
        _ = debugger_backend.add_watchpoint(watchpoint_id=3, watch_condition=18,
                                            check_node_list={"Default/network-WithLossCell/_backbone-AlexNet/fc3-Dense/"
                                                             "Parameter[6]_11/fc3.bias":
                                                                 {"device_id": [0], "root_graph_id": [0],
                                                                  "is_parameter": True
                                                                  }}, parameter_list=[param_abs_mean_update_ratio_gt,
                                                                                      param_epsilon])

        watchpoint_hits_test_4 = debugger_backend.check_watchpoints(iteration=3)
        if len(watchpoint_hits_test_4) != 1:
            f_write.write("ERROR -> test 4: watchpoint weight change set but not hit just once\n")
        print_watchpoint_hits(watchpoint_hits_test_4, 4, f_write)

    if not GENERATE_GOLDEN:
        assert compare_actual_with_expected(test_name)


def print_watchpoint_hits(watchpoint_hits, test_id, f_write):
    """Print watchpoint hits.

    Writes one section per hit to ``f_write``. Each section holds the hit's
    name, slot, condition, watchpoint id, every parameter's
    name/disabled/value/hit/actual_value, the error code, device id and
    root graph id. The text layout must stay stable because it is compared
    against golden ``.expected`` files.

    Args:
        watchpoint_hits: Iterable of hits as returned by
            ``DbgServices.check_watchpoints``.
        test_id: Sub-test number, shown in each section header.
        f_write: Writable text file object the report is appended to.
    """
    for hit in watchpoint_hits:
        f_write.write("-----------------------------------------------------------\n")
        f_write.write(f"watchpoint_hit for test_{test_id} attributes:\n")
        f_write.write(f"name = {hit.name}\n")
        f_write.write(f"slot = {hit.slot}\n")
        f_write.write(f"condition = {hit.condition}\n")
        f_write.write(f"watchpoint_id = {hit.watchpoint_id}\n")
        for p, param in enumerate(hit.parameters):
            f_write.write(f"parameter {p} name = {param.name}\n")
            f_write.write(f"parameter {p} disabled = {param.disabled}\n")
            f_write.write(f"parameter {p} value = {param.value}\n")
            f_write.write(f"parameter {p} hit = {param.hit}\n")
            f_write.write(f"parameter {p} actual_value = {param.actual_value}\n")
        f_write.write(f"error code = {hit.error_code}\n")
        f_write.write(f"device_id = {hit.device_id}\n")
        f_write.write(f"root_graph_id = {hit.root_graph_id}\n")