• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ===========================================================================
15"""GraphKernel model builder"""
16
17import copy
18from . import op_infer
19from .model import Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
20
21
class GraphBuilder:
    """Builder that constructs GraphKernel `Graph` objects op by op."""

    class GraphWrapper:
        """Wraps a `Graph` and manages its input/output parameter lists."""

        def __init__(self, name):
            self.graph = Graph(name, [])

        def set_input(self, *para):
            """Mark the given tensors as graph inputs and register them."""
            for t in para:
                t.para_type = Tensor.PARA_INPUT
                self.graph.inputs.append(t)

        def set_output(self, *para):
            """Mark the given tensors as graph outputs and register them."""
            for t in para:
                t.para_type = Tensor.PARA_OUTPUT
                self.graph.outputs.append(t)

    def __init__(self):
        self.graphs = []     # finished graphs, in creation order
        self.current = None  # GraphWrapper being built while inside a graph_scope
        self.name_id = 0     # monotonically increasing id for auto tensor names

    def _alloc_tensor_name(self):
        """Return a fresh, unique tensor name of the form 't<N>'."""
        tid = self.name_id
        self.name_id += 1
        return f"t{tid}"

    def graph_scope(self, name):
        """Return a context manager that opens a new graph named `name`.

        On scope exit the finished graph is appended to `self.graphs` and the
        current wrapper is cleared. Scopes must not be nested.

        Raises:
            ValueError: if a graph scope is already open.
        """
        class GraphScope:
            """Context manager for one graph-construction scope."""

            def __init__(self, gb):
                self.gb = gb

            def __enter__(self):
                return self.gb.current

            def __exit__(self, ptype, value, trace):
                # NOTE(review): the graph is collected even when an exception
                # escapes the scope; preserved as-is to keep behavior unchanged.
                self.gb.graphs.append(self.gb.current.graph)
                self.gb.current = None

        if self.current is not None:
            raise ValueError("self.current wrong value")
        self.current = self.GraphWrapper(name)
        return GraphScope(self)

    def tensor(self, shape, dtype, data_format="DefaultFormat", name=None, para_type=Tensor.PARA_NONE):
        """Create a new Tensor; an empty/falsy shape is normalized to [1]."""
        if name in (None, ''):
            name = self._alloc_tensor_name()
        if not shape:
            shape = [1]
        return Tensor(name, shape, dtype, data_format, para_type=para_type)

    def value(self, dtype, value, name=None):
        """Create a new constant Value node."""
        if name in (None, ''):
            name = self._alloc_tensor_name()
        return Value(name, dtype, value)

    def op(self, prim, output, inputs, attrs=None):
        """Insert an operator node into the current graph.

        `inputs` may contain non-Tensor entries; only Tensors become the
        node's formal inputs, but the full list is kept on `all_inputs`.
        """
        if attrs is None:
            attrs = {}
        if isinstance(inputs, Tensor):
            inputs = [inputs]
        tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
        node = Operator(prim, tensor_inputs, output, attrs)
        node.all_inputs = inputs
        self.current.graph.add(node)

    def emit(self, prim, inputs, name=None, attrs=None):
        """Emit a new operation, inferring its output tensor, and return it."""
        if attrs is None:
            attrs = {}
        if isinstance(inputs, (Tensor, Value)):
            inputs = [inputs]
        tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
        out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
        output = self.tensor(out_shape, out_dtype, out_format, name)
        self.op(prim, output, inputs, attrs)
        return output

    def get(self):
        """Return all graphs built so far."""
        return self.graphs
113
114
class CompositeGraph:
    """Bridges the json kernel description and the `Graph` model.

    `load` builds a Graph (plus a name -> Tensor map) from a json desc;
    `dump` re-serializes a (possibly split) subgraph back to a json desc.
    """

    def __init__(self):
        self.graph = None    # Graph produced by load()
        self.desc = None     # original json desc, reused as a template by dump()
        self.tensors = {}    # tensor name -> Tensor

    def refine(self):
        """Refine Graph: align op shapes and attach control-dependency buddies."""
        AlignShape().visit_graph(self.graph)
        AddControlBuddy().visit_graph(self.graph)

    def load(self, desc):
        """Load Graph from a json desc dict."""

        def _attr_of(op):
            """Extract op attrs; rename 'axis' to 'reduce_axis' for reduce ops."""
            if not op['attr']:
                return {}
            attr = {}
            for a in op['attr']:
                if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
                    attr['reduce_axis'] = a['value']
                else:
                    attr[a['name']] = a['value']
            return attr

        builder = GraphBuilder()
        with builder.graph_scope(desc['op']):
            for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []:
                d = in_desc[0]
                name = d['tensor_name']
                self.tensors[name] = builder.tensor(
                    d['shape'], d['data_type'], d['format'], name=name, para_type=Tensor.PARA_INPUT)
            for out_desc in desc['output_desc']:
                name = out_desc['tensor_name']
                self.tensors[name] = builder.tensor(
                    out_desc['shape'], out_desc['data_type'], out_desc['format'],
                    name=name, para_type=Tensor.PARA_OUTPUT)
            for op in desc['op_desc']:
                # Constant (value-carrying) inputs are skipped here; ops only
                # connect through named tensors.
                inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
                out_d = op['output_desc'][0]
                name = out_d['tensor_name']
                if op['name'] == 'InplaceAssign':
                    # InplaceAssign(x, y, z): y is written in place of x (buddy
                    # link), and the op's output aliases z instead of creating
                    # a new tensor.
                    inputs[0].add_buddy(inputs[1])
                    inputs[1].para_type = Tensor.PARA_OUTPUT
                    self.tensors[name] = inputs[2]
                    continue
                output = self.tensors.get(name, None)
                # Was `if not output:` — tensors.get only yields None or a
                # Tensor, so test identity explicitly instead of truthiness.
                if output is None:
                    output = builder.tensor(out_d['shape'], out_d['data_type'], out_d['format'], name=name)
                    self.tensors[name] = output
                builder.op(op['name'], output, inputs, attrs=_attr_of(op))
        self.graph = builder.get()[0]
        self.desc = desc

    def add_stitch_info(self, subgraph, desc):
        """Copy buffer-stitch info from `subgraph` into `desc`, if any. Returns `desc`."""
        if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
            buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
            if subgraph.stitch_info.stitch_atomic_ops:
                buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops)
            desc['buffer_stitch'] = buffer_stitch
        return desc

    def add_recompute_ops(self, subgraph, desc):
        """Record the output names of recomputed ops in `desc`, if any. Returns `desc`."""
        if subgraph.recompute_ops:
            desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
        return desc

    def _pre_dump(self, outputs):
        """Restore InplaceAssign tensor names to their pre-load form.

        Returns a pair `(inplace_assign, inplace_assign_z)`:
        `inplace_assign` maps each assigned tensor's y-name to the original op
        output name; `inplace_assign_z` is some graph output not produced by an
        InplaceAssign (None when there is none, or no InplaceAssign at all).
        """
        inplace_assign = {}  # y_name -> original output name
        inplace_assign_z = None
        for op in self.desc['op_desc']:
            if op['name'] == 'InplaceAssign':
                inplace_assign[op['input_desc'][1][0]['tensor_name']] = op['output_desc'][0]['tensor_name']
        if inplace_assign:
            for t in outputs:
                if t.name not in inplace_assign:
                    inplace_assign_z = t
        return inplace_assign, inplace_assign_z

    def dump(self, subgraph):
        """Dump `subgraph` to a json desc, restoring InplaceAssign naming."""
        desc = {}
        inputs, outputs = subgraph.deduce_parameters()
        graph_ops = set(subgraph.ops)
        inplace_assign, inplace_assign_z = self._pre_dump(outputs)

        def dump_output(t):
            """Serialize one output tensor, renaming InplaceAssign results."""
            if t.name in inplace_assign:
                z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
                return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]}
            return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}

        def dump_op_desc(d):
            """Serialize one op desc, or None when the op lies outside `subgraph`."""
            if d['name'] == 'InplaceAssign':
                y = d['input_desc'][1][0]['tensor_name']
                if self.tensors[y].op in graph_ops:
                    # When no separate z output exists, y stands in for z and
                    # the extra output is marked as fake.
                    z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (self.tensors[y], True)
                    inplace_desc = copy.deepcopy(d)
                    inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
                    z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
                    z_desc['shape'] = z.shape
                    z_desc['data_type'] = z.dtype
                    z_desc['tensor_name'] = z.name
                    out_desc['shape'] = z.shape
                    out_desc['data_type'] = z.dtype
                    return inplace_desc
            op = self.tensors[d['output_desc'][0]['tensor_name']].op
            if op in graph_ops or op in subgraph.recompute_ops:
                return d
            return None

        for key in self.desc.keys():
            if key == 'input_desc':
                desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
            elif key == 'output_desc':
                desc[key] = list(map(dump_output, outputs))
            elif key == 'op_desc':
                op_desc = map(dump_op_desc, self.desc[key])
                desc[key] = [d for d in op_desc if d is not None]
            elif key == 'op':
                desc[key] = subgraph.name
            else:
                desc[key] = self.desc[key]

        desc = self.add_stitch_info(subgraph, desc)
        desc = self.add_recompute_ops(subgraph, desc)
        return desc
248
249
def load_composite(desc):
    """Build and refine a CompositeGraph from a composite-kernel json desc."""
    result = CompositeGraph()
    result.load(desc)
    result.refine()
    return result
256