• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import time
16import operator
17
18
class Register:
    """Registry of case functions that are selected and run by target and level.

    Functions are registered through the decorator-style methods below
    (``target_*``, ``level*`` and ``skip``). Every one of them returns the
    function unchanged, so they can be stacked on the same function.

    Attributes:
        case_targets: Maps each registered function to the set of target
            names ("Ascend", "GPU", "CPU") it has been tagged with.
        case_levels: Maps each registered function to its level (0 or 1).
            Tagging a function again overwrites its previous level.
        skip_cases: Maps each function marked via ``skip`` to the reason
            string that was supplied.
    """

    def __init__(self):
        self.case_targets = dict()
        self.case_levels = dict()
        self.skip_cases = dict()

    def target_ascend(self, fn):
        """Tag ``fn`` with the "Ascend" target and return it unchanged."""
        self._add_target(fn, "Ascend")
        return fn

    def target_gpu(self, fn):
        """Tag ``fn`` with the "GPU" target and return it unchanged."""
        self._add_target(fn, "GPU")
        return fn

    def target_cpu(self, fn):
        """Tag ``fn`` with the "CPU" target and return it unchanged."""
        self._add_target(fn, "CPU")
        return fn

    def level0(self, fn):
        """Record ``fn`` as a level 0 case and return it unchanged."""
        self._add_level(fn, 0)
        return fn

    def level1(self, fn):
        """Record ``fn`` as a level 1 case and return it unchanged."""
        self._add_level(fn, 1)
        return fn

    def skip(self, reason):
        """Return a decorator that marks a function as skipped.

        Args:
            reason: Value stored in ``skip_cases`` for the decorated function.

        Returns:
            A decorator that records the function and returns it unchanged.
        """
        def deco(fn):
            self.skip_cases[fn] = reason
            return fn

        return deco

    def _add_target(self, fn, target):
        """Add ``target`` to the set of targets registered for ``fn``."""
        self.case_targets.setdefault(fn, set()).add(target)

    def _add_level(self, fn, level):
        """Store ``level`` for ``fn``, replacing any previously stored level."""
        self.case_levels[fn] = level

    def check_and_run(self, target, level):
        """Run every registered case that matches ``target`` and ``level``.

        A case is executed only when it is not in ``skip_cases``, has been
        tagged with ``target``, and has a registered level equal to ``level``.
        Cases are visited in the iteration order of ``case_targets``. After
        all matching cases have run, the elapsed time of each one is printed
        (slowest first), followed by the total.

        An exception raised by a case is not caught here: it propagates to
        the caller and the timing summary is not printed.

        Args:
            target: Target name to select, e.g. "Ascend", "GPU" or "CPU".
            level: Level to select; compared with ``!=`` against the stored
                level.
        """
        time_cost = dict()
        for fn, targets in self.case_targets.items():
            if fn in self.skip_cases:
                continue
            if target not in targets:
                continue
            # Cases without a level, or with a different level, are ignored.
            if self.case_levels.get(fn, None) is None and fn not in self.case_levels:
                continue
            if self.case_levels[fn] != level:
                continue
            print(f"\nexceute fn:{fn}, level:{level}, target:{target}")
            # perf_counter is monotonic, so the measured interval cannot be
            # distorted (or go negative) when the system clock is adjusted.
            start_time = time.perf_counter()
            fn()
            time_cost[fn] = time.perf_counter() - start_time

        sorted_time_cost = sorted(time_cost.items(), key=operator.itemgetter(1), reverse=True)
        for case, elapsed in sorted_time_cost:
            print("Time:", elapsed, ", fn:", case, "\n")
        total_cost_time = sum(elapsed for _, elapsed in sorted_time_cost)
        print("Total cost time:", total_cost_time)
83
84
# Module-level Register instance created at import time.
# NOTE(review): presumably the shared registry that other modules import to
# decorate and run their cases — confirm against callers.
case_register = Register()
86