• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ===========================================================================
15"""graph kernel split"""
16import json
17import getopt
18import sys
19import model
20
21
def print_usage():
    """Write the command-line help text to stdout, one line per entry."""
    help_lines = (
        'Usage: graph_kernel_split.py [OPTION] <JSON_FILE>',
        'Options:',
        '  -s <config/auto>\tsplit graph with config',
        '  -e \t\testimate graph',
        '  -i \t\tnaive estimate',
        '  -o <prefix>\toutput split graphs',
        '  -v \t\tverbose mode',
        '  -h \t\tprint this help',
    )
    for text in help_lines:
        print(text)
31
32
class Option:
    """Command-line options of the graph kernel split tool."""

    def __init__(self):
        self.split = None             # value of -s: 'auto' or a manual split config
        self.estimate = False         # -e
        self.estimate_naive = False   # -i
        self.output = None            # value of -o: prefix for dumped sub-graph files
        self.verbose = False          # -v
        self.help = False             # -h

    def parse(self, options):
        """Apply ``(name, value)`` pairs, as returned by ``getopt.getopt``.

        Flags without an argument switch their attribute to True; ``-o`` and
        ``-s`` store their argument. Unrecognised names are ignored, and a
        flag repeated later overrides an earlier occurrence.
        """
        switch_flags = {
            '-h': 'help',
            '-v': 'verbose',
            '-e': 'estimate',
            '-i': 'estimate_naive',
        }
        valued_flags = {
            '-o': 'output',
            '-s': 'split',
        }
        for flag, arg in options:
            if flag in switch_flags:
                setattr(self, switch_flags[flag], True)
            elif flag in valued_flags:
                setattr(self, valued_flags[flag], arg)


# Module-wide options instance, filled in by main() and read by estimate().
opt = Option()
62
63
def estimate(graph_in, parts_in, naive):
    """Print cost estimates for a graph and, optionally, its split parts.

    Args:
        graph_in: the whole graph, passed to ``model.estimate``.
        parts_in: sub-graphs (e.g. from ``split_graph``); when empty or None
            only the whole graph is reported.
        naive: forwarded unchanged to ``model.estimate``.

    Per-sub-graph lines are printed only when the module-level
    ``opt.verbose`` flag is set.
    """
    def _report(label, cost):
        # One line per cost object: the label followed by four metrics.
        metrics = (cost.dma_ratio(), cost.saturation(),
                   cost.mix_saturation(), cost.cost_type())
        print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s"
              % ((label,) + metrics))

    main_cost, _ = model.estimate(graph_in, naive)
    if not parts_in:
        _report("MainGraph:", main_cost)
        return
    # Both estimates are computed before any line is printed.
    split_cost, sub_costs = model.estimate(parts_in, naive)
    _report("MainGraph:", main_cost)
    _report("Subgraphs:", split_cost)
    if not opt.verbose:
        return
    for idx, sub_cost in enumerate(sub_costs):
        _report(" |_%d:\t" % idx, sub_cost)
77
78
def split_graph(graph_in, config):
    """Split ``graph_in`` into sub-graphs.

    Args:
        graph_in: the graph to split; its ``name``, ``ops`` and
            ``extract_subgraph`` are used.
        config: either ``'auto'`` (delegate to ``model.split``) or a manual
            layout: parts separated by ``|``, each a comma-separated list of
            tensor names, e.g. ``"t0,t1|t2"``.

    Returns:
        With ``'auto'``, whatever ``model.split`` returns. Otherwise a list of
        sub-graphs named ``<graph name>_<index>``, one per config part. If
        fewer names were listed in total than ``graph_in`` has ops, one extra
        sub-graph is appended, obtained by calling ``extract_subgraph`` with
        all listed names and ``True``. Presumably this collects the
        remaining, unlisted ops; confirm against ``model``.

    Raises:
        ValueError: if the sub-graph extracted for a manual part does not
            contain exactly as many ops as tensor names were listed in that
            part (for example, a misspelled tensor name).
    """
    if config == 'auto':
        return model.split(graph_in)
    subgraphs = []
    all_tensors = []
    for idx, part in enumerate(config.split('|')):
        tensor_names = part.split(',')
        graph_name = "%s_%d" % (graph_in.name, idx)
        g = graph_in.extract_subgraph(graph_name, tensor_names)
        # Explicit check instead of ``assert``: the config comes from the
        # command line, and asserts are stripped under ``python -O``.
        if len(g.ops) != len(tensor_names):
            raise ValueError("split config part '%s' matched %d op(s), expected %d"
                             % (part, len(g.ops), len(tensor_names)))
        subgraphs.append(g)
        all_tensors += tensor_names
    if len(all_tensors) < len(graph_in.ops):
        # The extra sub-graph takes the next index after the manual parts.
        graph_name = "%s_%d" % (graph_in.name, len(subgraphs))
        subgraphs.append(graph_in.extract_subgraph(graph_name, all_tensors, True))
    return subgraphs
100
101
def main():
    """Command-line entry point.

    Parses ``sys.argv`` into the module-level ``opt``, then loads the
    composite JSON given as the single positional argument. It then
    optionally splits it (-s), prints the graphs (-v), estimates costs
    (-e / -i), and writes each sub-graph to ``<prefix>_<name>.json`` (-o).

    Exit status: 0 after printing usage for -h or a wrong number of
    positional arguments; 2 when getopt rejects the command line.
    """
    try:
        opts, args = getopt.getopt(sys.argv[1:], 'heivo:s:')
    except getopt.GetoptError as err:
        # Unknown option or missing option argument: report it instead of
        # crashing with a traceback.
        print(err, file=sys.stderr)
        print_usage()
        sys.exit(2)
    opt.parse(opts)
    if len(args) != 1 or opt.help:
        print_usage()
        sys.exit(0)
    in_file = args[0]
    # Only loading needs the file handle. JSON is decoded as UTF-8, matching
    # the encoding used for the output files below.
    with open(in_file, 'r', encoding='utf-8') as f:
        desc = json.load(f)
    comp = model.load_composite(desc)
    graph = comp.graph
    parts = []
    # 1. split sub-graphs
    if opt.split is not None:
        parts = split_graph(graph, opt.split)
    if opt.verbose:
        print('----------- main graph --------------')
        print(graph)
        for i, part in enumerate(parts):
            print('---------------- sub graph %d ---------------' % (i))
            print(part)
    # 2. estimate cost
    if opt.estimate:
        print('------------- cost --------------')
        estimate(graph, parts, False)
    if opt.estimate_naive:
        print('------------- naive cost --------------')
        estimate(graph, parts, True)
    # 3. output parts (nothing is written when no split was requested)
    if opt.output is not None:
        for graph_part in parts:
            part_desc = comp.dump(graph_part)
            fname = "%s_%s.json" % (opt.output, graph_part.name)
            with open(fname, 'w', encoding='utf-8') as of:
                json.dump(part_desc, of)


if __name__ == '__main__':
    main()