# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""graph kernel split

Command-line tool that loads a graph-kernel composite from a JSON file and
optionally splits it into sub-graphs, prints cost estimates, and dumps the
resulting sub-graphs back to JSON files.
"""
import json
import getopt
import sys
import model


def print_usage():
    """Print the command-line usage text to stdout."""
    print('Usage: graph_kernel_split.py [OPTION] <JSON_FILE>')
    print('Options:')
    print(' -s <config/auto>\tsplit graph with config')
    print(' -e \t\testimate graph')
    print(' -i \t\tnaive estimate')
    print(' -o <prefix>\toutput split graphs')
    print(' -v \t\tverbose mode')
    print(' -h \t\tprint this help')


class Option:
    """Options"""

    def __init__(self):
        self.split = None            # value of -s: 'auto' or a manual split config
        self.estimate = False        # -e: print model cost estimate
        self.estimate_naive = False  # -i: print naive cost estimate
        self.output = None           # value of -o: file-name prefix for dumped sub-graphs
        self.verbose = False         # -v: print graphs and per-sub-graph costs
        self.help = False            # -h: print usage and exit

    def parse(self, options):
        """parse options

        Args:
            options: the ``(name, value)`` pairs returned by ``getopt.getopt``.
        """
        for name, val in options:
            if name == '-h':
                self.help = True
            elif name == '-v':
                self.verbose = True
            elif name == '-o':
                self.output = val
            elif name == '-e':
                self.estimate = True
            elif name == '-s':
                self.split = val
            elif name == '-i':
                self.estimate_naive = True


# Module-level options, filled in by main() and read by estimate() (-v).
opt = Option()


def estimate(graph_in, parts_in, naive):
    """estimate graphs costs

    Prints the cost of the main graph and, when ``parts_in`` is non-empty,
    the combined cost of the sub-graphs (plus each sub-graph's cost in
    verbose mode).

    Args:
        graph_in: the main graph.
        parts_in: list of sub-graphs; may be empty.
        naive: forwarded to ``model.estimate`` to select the naive estimator.
    """
    def _print_cost(name, c):
        print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
              (name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
    main_cost, _ = model.estimate(graph_in, naive)
    split_cost, sub_costs = model.estimate(parts_in, naive) if parts_in else (None, None)
    _print_cost("MainGraph:", main_cost)
    if parts_in:
        _print_cost("Subgraphs:", split_cost)
        if opt.verbose:
            for i, sub_cost in enumerate(sub_costs):
                _print_cost(" |_%d:\t" % (i), sub_cost)


def split_graph(graph_in, config):
    """split graph

    Args:
        graph_in: the main graph to split.
        config: either ``'auto'``, which delegates to ``model.split``, or a
            manual description such as ``"t0,t1|t2"``. ``|`` separates
            sub-graphs and ``,`` separates the tensor names of one sub-graph.

    Returns:
        The list of sub-graphs, named ``<graph name>_<index>``. In manual mode,
        if the listed tensors number fewer than the graph's ops, one extra
        trailing sub-graph is extracted from the remaining ones.

    Raises:
        ValueError: if a manual part does not yield exactly one op per listed
            tensor name.
    """
    if config == 'auto':
        return model.split(graph_in)
    subgraphs = []
    all_tensors = []
    for subgraph_idx, part in enumerate(config.split('|')):
        tensor_names = part.split(',')
        graph_name = "%s_%d" % (graph_in.name, subgraph_idx)
        g = graph_in.extract_subgraph(graph_name, tensor_names)
        # Validate explicitly: an `assert` would be stripped under `python -O`
        # and a bad user config would then pass silently.
        if len(g.ops) != len(tensor_names):
            raise ValueError("split config part '%s' produced %d ops for %d tensor names"
                             % (part, len(g.ops), len(tensor_names)))
        subgraphs.append(g)
        all_tensors += tensor_names
    if len(all_tensors) < len(graph_in.ops):
        # The index continues after the explicit parts.
        graph_name = "%s_%d" % (graph_in.name, len(subgraphs))
        # NOTE(review): the trailing `True` presumably asks extract_subgraph for
        # the complement of `all_tensors` -- confirm against model.
        g = graph_in.extract_subgraph(graph_name, all_tensors, True)
        subgraphs.append(g)
    return subgraphs


def main():
    """Parse the command line, then split / estimate / dump the given graph.

    Exits with status 0 for ``-h`` and 2 for command-line usage errors.
    """
    try:
        opts, args = getopt.getopt(sys.argv[1:], 'heivo:s:')
    except getopt.GetoptError as err:
        # Unknown option or missing option argument: show usage, not a traceback.
        print(err, file=sys.stderr)
        print_usage()
        sys.exit(2)
    opt.parse(opts)
    if opt.help:
        print_usage()
        sys.exit(0)
    if len(args) != 1:
        # Wrong positional-argument count is an error, so exit non-zero.
        print_usage()
        sys.exit(2)
    in_file = args[0]
    with open(in_file, 'r', encoding='utf-8') as f:
        desc = json.load(f)
    comp = model.load_composite(desc)
    graph = comp.graph
    parts = []
    # 1. split sub-graphs
    if opt.split is not None:
        parts = split_graph(graph, opt.split)
        if opt.verbose:
            print('----------- main graph --------------')
            print(graph)
            for i, part in enumerate(parts):
                print('---------------- sub graph %d ---------------' % (i))
                print(part)
    # 2. estimate cost
    if opt.estimate:
        print('------------- cost --------------')
        estimate(graph, parts, False)
    if opt.estimate_naive:
        print('------------- naive cost --------------')
        estimate(graph, parts, True)
    # 3. output parts
    if opt.output is not None:
        for graph_part in parts:
            desc = comp.dump(graph_part)
            s_desc = json.dumps(desc)
            fname = "%s_%s.json" % (opt.output, graph_part.name)
            with open(fname, 'w', encoding='utf-8') as of:
                of.write(s_desc)


if __name__ == '__main__':
    main()