# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import operator


class Register:
    """Registry of callables ("cases") tagged with targets, a level and skips.

    Cases are registered through the decorator-style methods below and later
    executed in bulk by :meth:`check_and_run`, which filters them by target
    and level and prints how long each one took.

    Attributes:
        case_targets: maps each registered callable to the set of target
            names ("Ascend", "GPU", "CPU") it was tagged with.
        case_levels: maps each registered callable to its level (0 or 1 when
            set through :meth:`level0` / :meth:`level1`).
        skip_cases: maps each callable marked with :meth:`skip` to the reason
            string that was given.
    """

    def __init__(self):
        self.case_targets = {}
        self.case_levels = {}
        self.skip_cases = {}

    def target_ascend(self, fn):
        """Tag ``fn`` with the "Ascend" target and return ``fn`` unchanged."""
        self._add_target(fn, "Ascend")
        return fn

    def target_gpu(self, fn):
        """Tag ``fn`` with the "GPU" target and return ``fn`` unchanged."""
        self._add_target(fn, "GPU")
        return fn

    def target_cpu(self, fn):
        """Tag ``fn`` with the "CPU" target and return ``fn`` unchanged."""
        self._add_target(fn, "CPU")
        return fn

    def level0(self, fn):
        """Record level 0 for ``fn`` and return ``fn`` unchanged."""
        self._add_level(fn, 0)
        return fn

    def level1(self, fn):
        """Record level 1 for ``fn`` and return ``fn`` unchanged."""
        self._add_level(fn, 1)
        return fn

    def skip(self, reason):
        """Return a decorator that marks a case as skipped.

        Args:
            reason: value stored in ``skip_cases`` for the decorated callable.

        Returns:
            A decorator that records the callable and returns it unchanged.
        """

        def deco(fn):
            self.skip_cases[fn] = reason
            return fn

        return deco

    def _add_target(self, fn, target):
        """Add ``target`` to the set of targets recorded for ``fn``."""
        self.case_targets.setdefault(fn, set()).add(target)

    def _add_level(self, fn, level):
        """Record ``level`` for ``fn``, replacing any earlier level."""
        self.case_levels[fn] = level

    def check_and_run(self, target, level):
        """Run every registered case matching ``target`` and ``level``.

        A case is executed only when it has at least one target registered,
        is not in ``skip_cases``, lists ``target`` among its targets and has
        a recorded level equal to ``level``. After all matching cases have
        run, the per-case durations are printed in descending order followed
        by their total.

        An exception raised by a case is not caught here: it propagates to
        the caller and no timing summary is printed.

        Args:
            target: target name to select, e.g. "Ascend", "GPU" or "CPU".
            level: level to select.
        """
        time_cost = {}
        for fn, targets in self.case_targets.items():
            if fn in self.skip_cases:
                continue
            if target not in targets:
                continue
            if fn not in self.case_levels or self.case_levels[fn] != level:
                continue
            print(f"\nexceute fn:{fn}, level:{level}, target:{target}")
            # perf_counter() is monotonic, so the measured duration cannot be
            # distorted (or go negative) when the system clock is adjusted
            # while a case is running, unlike time.time().
            start_time = time.perf_counter()
            fn()
            time_cost[fn] = time.perf_counter() - start_time

        # Report the slowest cases first.
        sorted_time_cost = sorted(
            time_cost.items(), key=operator.itemgetter(1), reverse=True
        )
        total_cost_time = 0
        for fn, cost in sorted_time_cost:
            total_cost_time += cost
            print("Time:", cost, ", fn:", fn, "\n")
        print("Total cost time:", total_cost_time)


# Module-level registry instance, created at import time.
case_register = Register()