# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""GraphKernel model builder"""

import copy
from . import op_infer
from .model import Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy


class GraphBuilder:
    """Build Graph objects operator by operator.

    Each graph is built inside ``graph_scope``; when the scope closes, the
    graph is appended to ``self.graphs``, which ``get`` returns.
    """

    class GraphWrapper:
        """Holder for the Graph currently being built."""

        def __init__(self, name):
            self.graph = Graph(name, [])

        def set_input(self, *para):
            """Mark the given tensors as graph inputs and append them to graph.inputs."""
            for t in para:
                t.para_type = Tensor.PARA_INPUT
                self.graph.inputs.append(t)

        def set_output(self, *para):
            """Mark the given tensors as graph outputs and append them to graph.outputs."""
            for t in para:
                t.para_type = Tensor.PARA_OUTPUT
                self.graph.outputs.append(t)

    def __init__(self):
        self.graphs = []     # finished graphs, in the order their scopes closed
        self.current = None  # GraphWrapper of the open scope, or None when no scope is open
        self.name_id = 0     # counter used by _alloc_tensor_name

    def _alloc_tensor_name(self):
        """Return a fresh tensor name of the form "t<N>"."""
        tid = self.name_id
        self.name_id += 1
        return "t%d" % (tid)

    def graph_scope(self, name):
        """Open a scope in which a new graph called ``name`` is built.

        Returns a context manager whose ``__enter__`` yields the current
        GraphWrapper. On exit, including exit caused by an exception, the
        wrapped graph is appended to ``self.graphs`` and the scope is closed.

        Raises:
            ValueError: if a graph scope is already open (scopes do not nest).
        """
        class GraphScope:
            """Context manager returned by graph_scope."""

            def __init__(self, gb):
                self.gb = gb

            def __enter__(self):
                return self.gb.current

            def __exit__(self, ptype, value, trace):
                self.gb.graphs.append(self.gb.current.graph)
                self.gb.current = None

        if self.current is not None:
            raise ValueError("self.current wrong value")
        self.current = self.GraphWrapper(name)
        return GraphScope(self)

    def tensor(self, shape, dtype, data_format="DefaultFormat", name=None, para_type=Tensor.PARA_NONE):
        """Create a new Tensor; this method does not add it to any graph.

        A name is generated when ``name`` is None or empty, and an empty
        (falsy) ``shape`` is normalized to ``[1]``.
        """
        if name in (None, ''):
            name = self._alloc_tensor_name()
        if not shape:
            shape = [1]
        return Tensor(name, shape, dtype, data_format, para_type=para_type)

    def value(self, dtype, value, name=None):
        """Create a new Value, generating a name when ``name`` is None or empty."""
        if name in (None, ''):
            name = self._alloc_tensor_name()
        v = Value(name, dtype, value)
        return v

    def op(self, prim, output, inputs, attrs=None):
        """Insert an operator producing ``output`` into the current graph.

        Must be called while a ``graph_scope`` is open.

        Args:
            prim: primitive (operator) name.
            output: the Tensor produced by this operator.
            inputs: a single Tensor/Value, or a list of them. Only the Tensor
                entries become the Operator's inputs; the full list, Values
                included, is kept on ``node.all_inputs``.
            attrs: optional attribute dict; a fresh empty dict by default.
        """
        if attrs is None:
            attrs = {}
        # Accept a lone Value as well as a lone Tensor, consistent with emit().
        if isinstance(inputs, (Tensor, Value)):
            inputs = [inputs]
        tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
        node = Operator(prim, tensor_inputs, output, attrs)
        node.all_inputs = inputs
        self.current.graph.add(node)

    def emit(self, prim, inputs, name=None, attrs=None):
        """Create the output tensor of ``prim``, insert the operator and return the output.

        The output shape, dtype and format are obtained from
        ``op_infer.infer(prim, tensor_inputs, attrs)``, where
        ``tensor_inputs`` keeps both Tensor and Value inputs.
        """
        if attrs is None:
            attrs = {}
        if isinstance(inputs, (Tensor, Value)):
            inputs = [inputs]
        tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
        out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
        output = self.tensor(out_shape, out_dtype, out_format, name)
        self.op(prim, output, inputs, attrs)
        return output

    def get(self):
        """Return the list of graphs finished so far."""
        return self.graphs


class CompositeGraph:
    """A Graph loaded from a composite-kernel json description, plus that description."""

    def __init__(self):
        self.graph = None  # Graph built by load()
        self.desc = None   # the json dict passed to load(); reused by dump()
        self.tensors = {}  # name : Tensor

    def refine(self):
        """Run the AlignShape and AddControlBuddy visitors over ``self.graph``."""
        AlignShape().visit_graph(self.graph)
        AddControlBuddy().visit_graph(self.graph)

    def load(self, desc):
        """Build ``self.graph`` from a json description and remember ``desc``.

        Keys read from ``desc``: 'op' (graph name), 'input_desc' (may be
        None), 'output_desc' and 'op_desc'. Graph inputs and outputs are
        registered in ``self.tensors`` with PARA_INPUT / PARA_OUTPUT. Each
        op then becomes an Operator, except InplaceAssign, which is folded
        into tensor metadata instead.
        """
        def _attr_of(op):
            """Turn op['attr'] (a sequence of {'name', 'value'} dicts) into a plain dict."""
            if not op['attr']:
                return {}
            attr = {}
            for a in op['attr']:
                # The model names the reduction axes 'reduce_axis' rather than 'axis'.
                if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
                    attr['reduce_axis'] = a['value']
                else:
                    attr[a['name']] = a['value']
            return attr

        builder = GraphBuilder()
        with builder.graph_scope(desc['op']):
            input_descs = desc['input_desc'] if desc['input_desc'] is not None else []
            for in_desc in input_descs:
                # Each graph input entry is a list; only its first element is used.
                name, shape, dtype, data_format = in_desc[0]['tensor_name'], in_desc[
                    0]['shape'], in_desc[0]['data_type'], in_desc[0]['format']
                self.tensors[name] = builder.tensor(
                    shape, dtype, data_format, name=name, para_type=Tensor.PARA_INPUT)
            for out_desc in desc['output_desc']:
                name, shape, dtype, data_format = out_desc['tensor_name'], out_desc[
                    'shape'], out_desc['data_type'], out_desc['format']
                self.tensors[name] = builder.tensor(
                    shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
            for op in desc['op_desc']:
                # Input entries that carry a 'value' key are skipped; only tensor
                # references are looked up in self.tensors.
                inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
                out_desc = op['output_desc']
                name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[
                    0]['shape'], out_desc[0]['data_type'], out_desc[0]['format']
                if op['name'] == 'InplaceAssign':
                    # No Operator is created. inputs[1] is registered as a buddy of
                    # inputs[0] and becomes a graph output. The op's output name is
                    # aliased to inputs[2], which dump() later undoes.
                    inputs[0].add_buddy(inputs[1])
                    inputs[1].para_type = Tensor.PARA_OUTPUT
                    output = inputs[2]
                    self.tensors[name] = output
                    continue
                # Reuse a tensor already registered under this name (e.g. a graph
                # output). Compare against None explicitly so that the project
                # Tensor's truthiness can never cause it to be replaced.
                output = self.tensors.get(name)
                if output is None:
                    output = builder.tensor(shape, dtype, data_format, name=name)
                    self.tensors[name] = output
                builder.op(op['name'], output, inputs, attrs=_attr_of(op))
        self.graph = builder.get()[0]
        self.desc = desc

    def add_stitch_info(self, subgraph, desc):
        """Add a 'buffer_stitch' entry to ``desc`` when ``subgraph`` has stitch ops; return ``desc``.

        The entry holds 'stitch_op' and, when present, 'stitch_atomic_op'.
        ``desc`` is modified in place.
        """
        if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
            buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
            if subgraph.stitch_info.stitch_atomic_ops:
                buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops)
            desc['buffer_stitch'] = buffer_stitch
        return desc

    def add_recompute_ops(self, subgraph, desc):
        """Record the output names of ``subgraph.recompute_ops`` under 'recompute_ops'; return ``desc``.

        ``desc`` is modified in place. Nothing is added when there are no recompute ops.
        """
        if subgraph.recompute_ops:
            desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
        return desc

    def _pre_dump(self, outputs):
        """Collect what dump() needs to undo the InplaceAssign aliasing done by load().

        Returns:
            (inplace_assign, inplace_assign_z):
            - ``inplace_assign`` maps each InplaceAssign's second input name (y)
              to that op's original output name.
            - ``inplace_assign_z`` is the last tensor in ``outputs`` whose name is
              not a key of ``inplace_assign``. It is None when there is no
              InplaceAssign or when no such output exists.
        """
        inplace_assign = {}  # y_name, output_name
        inplace_assign_z = None
        for op in self.desc['op_desc']:
            if op['name'] == 'InplaceAssign':
                inplace_assign[op['input_desc'][1][0]['tensor_name']] = op['output_desc'][0]['tensor_name']
        if inplace_assign:
            for t in outputs:
                if t.name not in inplace_assign:
                    inplace_assign_z = t
        return inplace_assign, inplace_assign_z

    def dump(self, subgraph):
        """Build a json description for ``subgraph``, modelled on the loaded ``self.desc``.

        Keys are emitted in the order of ``self.desc``:
        - 'input_desc' / 'output_desc' are regenerated from the subgraph's
          deduced parameters.
        - 'op_desc' keeps only the ops belonging to (or recomputed in) the subgraph.
        - 'op' becomes ``subgraph.name``.
        - Every other key is copied by reference.
        Stitch and recompute information is then added.
        """
        desc = {}
        inputs, outputs = subgraph.deduce_parameters()
        graph_ops = set(subgraph.ops)
        inplace_assign, inplace_assign_z = self._pre_dump(outputs)

        def dump_output(t):
            # An output that is an InplaceAssign's y is reported under the op's
            # original output name, using z's shape/dtype (or its own tensor when
            # there is no z).
            if t.name in inplace_assign:
                z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
                return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]}
            return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}

        def dump_op_desc(d):
            if d['name'] == 'InplaceAssign':
                y = d['input_desc'][1][0]['tensor_name']
                if self.tensors[y].op in graph_ops:
                    # When no suitable z output exists, y itself stands in for z
                    # and the op is flagged with fake_output=True.
                    z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (self.tensors[y], True)
                    inplace_desc = copy.deepcopy(d)
                    # NOTE(review): load()'s _attr_of iterates 'attr' as a sequence of
                    # {'name', 'value'} dicts, but a single dict is written here.
                    # Confirm the downstream consumer expects this shape.
                    inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
                    z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
                    z_desc['shape'] = z.shape
                    z_desc['data_type'] = z.dtype
                    z_desc['tensor_name'] = z.name
                    out_desc['shape'] = z.shape
                    out_desc['data_type'] = z.dtype
                    return inplace_desc
            # Keep an op desc only if the op producing its output is part of
            # this subgraph or is recomputed in it.
            op = self.tensors[d['output_desc'][0]['tensor_name']].op
            if op in graph_ops or op in subgraph.recompute_ops:
                return d
            return None

        for key in self.desc:
            if key == 'input_desc':
                desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
            elif key == 'output_desc':
                desc[key] = list(map(dump_output, outputs))
            elif key == 'op_desc':
                op_desc = map(dump_op_desc, self.desc[key])
                desc[key] = [d for d in op_desc if d is not None]
            elif key == 'op':
                desc[key] = subgraph.name
            else:
                desc[key] = self.desc[key]

        desc = self.add_stitch_info(subgraph, desc)
        desc = self.add_recompute_ops(subgraph, desc)
        return desc


def load_composite(desc):
    """Load a composite kernel from its json description and refine the resulting graph."""
    composite = CompositeGraph()
    composite.load(desc)
    composite.refine()
    return composite