• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ===========================================================================
15"""Test split"""
16import model
17from model import model as estimate
18from model import graph_split as split
19
20
def get_nodes(sp, ops):
    """Map graph ops, or their output tensor names, to the splitter's nodes.

    Args:
        sp: splitter exposing ``graph.ops`` and ``nodes``; ``nodes`` is indexed
            by an op's position in ``graph.ops``.
        ops: either ops taken from ``sp.graph.ops`` or the output tensor names
            of such ops (detected by checking whether the first item is a str).

    Returns:
        The list of ``sp.nodes`` entries corresponding to ``ops``, in order.
        Names that match no op are reported with an error print and skipped.
    """
    if not ops:
        # Empty selection: nothing to look up (indexing ops[0] would raise).
        return []
    if isinstance(ops[0], str):
        # Build the name -> op table once instead of rescanning graph.ops for
        # every name; setdefault keeps the FIRST op carrying a given name,
        # matching the original linear search that broke on the first hit.
        by_name = {}
        for op in sp.graph.ops:
            by_name.setdefault(op.output.name, op)
        new_ops = []
        for t in ops:
            if t in by_name:
                new_ops.append(by_name[t])
            else:
                print("ERROR: not found op: ", t)
        ops = new_ops
    return [sp.nodes[sp.graph.ops.index(op)] for op in ops]
34
35
def first_connected(sp, space):
    """Check that the first part of every split candidate is connected.

    Args:
        sp: splitter providing ``nodes`` and ``resolve_connnected_graphs``.
        space: iterable of candidates; ``candidate[0]`` is a sequence of
            indices into ``sp.nodes``.

    Returns:
        True when every first part resolves to exactly one connected graph.
        Otherwise prints the offending nodes for the first failing candidate
        and returns False immediately.
    """
    for candidate in space:
        first_part = [sp.nodes[idx] for idx in candidate[0]]
        if len(sp.resolve_connnected_graphs(first_part)) == 1:
            continue
        print("connect check failed: ", first_part)
        return False
    return True
44
45
def split_format(sp, cand):
    """Render a split candidate as a compact string.

    Each part of ``cand`` is a sequence of indices into ``sp.graph.ops``.
    A part is rendered as the comma-separated output tensor names of its ops,
    and the parts are separated by '|', e.g. ``"a,b|c"``.
    """
    return '|'.join(
        ','.join(sp.graph.ops[idx].output.name for idx in part)
        for part in cand
    )
54
55
def graph_1():
    """Build the 'ring, no succ_dep, no prev' case.

    Topology: a -> b -> c -> d (Abs chain) and e = Add(b, d), so b reaches e
    both directly and through c -> d, forming a ring. Input a is float32
    with shape [10240, 16].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a = gb.tensor([10240, 16], "float32", name="a")
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("Abs", b, 'c')
        d = gb.emit("Abs", c, 'd')
        # Closing edge of the ring: b is consumed by both c and e.
        gb.emit('Add', [b, d], 'e')
    return gb.get()[0]
66
67
def graph_2():
    """Build the 'ring, succ_dep, no prev' case.

    Topology: a = Abs(a0) fans out to b and c; b -> d; e = Add(c, d) closes
    the ring a -> {c, b -> d} -> e, and f = Abs(e) is a successor after the
    ring. Input a0 is float32 with shape [10240, 16].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([10240, 16], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("Abs", a, 'c')
        d = gb.emit("Abs", b, 'd')
        e = gb.emit('Add', [c, d], 'e')
        # Op that depends on the ring's output.
        gb.emit("Abs", e, 'f')
    return gb.get()[0]
80
81
def graph_3():
    """Build the 'no ring, 1 sibling node' case.

    Topology: b = Abs(a0) -> d; c = Abs(a1) is an independent sibling branch;
    both meet in e = Add(c, d), followed by f = Abs(e). Inputs a0 and a1 are
    float32 with shape [10240, 16].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([10240, 16], "float32", name="a0")
        a1 = gb.tensor([10240, 16], "float32", name="a1")
        b = gb.emit("Abs", a0, 'b')
        c = gb.emit("Abs", a1, 'c')
        d = gb.emit("Abs", b, 'd')
        e = gb.emit('Add', [c, d], 'e')
        gb.emit("Abs", e, 'f')
    return gb.get()[0]
94
95
def graph_4():
    """Build the 'no ring, 2 sibling nodes in 1 step' case.

    Topology: two input chains a0 -> b -> c and a1 -> d -> e. c joins e in
    f = Add(c, e) -> g, and c also joins h = Abs(d) in i = Add(c, h) -> j.
    Both c and d therefore feed two consumers. Inputs are float32 with shape
    [10240, 16].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([10240, 16], "float32", name="a0")
        a1 = gb.tensor([10240, 16], "float32", name="a1")
        b = gb.emit("Abs", a0, 'b')
        c = gb.emit("Abs", b, 'c')
        d = gb.emit("Abs", a1, 'd')
        e = gb.emit("Abs", d, 'e')
        f = gb.emit('Add', [c, e], 'f')
        gb.emit('Abs', f, 'g')
        # Second join reusing c and d (via h).
        h = gb.emit("Abs", d, 'h')
        i = gb.emit('Add', [c, h], 'i')
        gb.emit("Abs", i, 'j')
    return gb.get()[0]
112
113
def graph_5():
    """Build the 'no ring, 2 sibling step' case.

    Topology: c = Abs(Abs(a1)) is shared by two joins:
    d = Add(a, c) -> e with a = Abs(a0), and g = Add(c, f) -> h with
    f = Abs(a2). Inputs a0, a1 and a2 are float32 with shape [10240, 16].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    # The scope object is not used. Binding it as ``g`` only to shadow it with
    # the 'g' tensor below was misleading, so it is no longer bound.
    with gb.graph_scope("main"):
        a0 = gb.tensor([10240, 16], "float32", name="a0")
        a1 = gb.tensor([10240, 16], "float32", name="a1")
        a2 = gb.tensor([10240, 16], "float32", name="a2")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a1, 'b')
        c = gb.emit("Abs", b, 'c')
        d = gb.emit('Add', [a, c], 'd')
        gb.emit("Abs", d, 'e')
        f = gb.emit("Abs", a2, 'f')
        g = gb.emit('Add', [c, f], 'g')
        gb.emit("Abs", g, 'h')
    return gb.get()[0]
130
131
def graph_6():
    """Build the 'no ring, tree down' case.

    Topology: a = Abs(a0) fans out to b and c. b fans out to d and e, and
    c fans out to f and g (all Abs), giving a fan-out tree with four leaves.
    Input a0 is float32 with shape [10240, 16].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([10240, 16], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        gb.emit("Abs", b, 'd')
        gb.emit("Abs", b, 'e')
        c = gb.emit("Abs", a, 'c')
        gb.emit("Abs", c, 'f')
        gb.emit("Abs", c, 'g')
    return gb.get()[0]
145
146
def graph_pat_1():
    """Build the 'split by reduce' pattern case.

    Topology: a0 [1024, 1024] -> a -> b (Abs) -> c = ReduceSum(b) over
    reduce_axis (1,) -> d = Sqrt(c) -> f = Sqrt(d), i.e. elementwise ops
    before and after a single reduce.

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        d = gb.emit("Sqrt", c, 'd')
        gb.emit("Sqrt", d, 'f')
    return gb.get()[0]
158
159
def graph_pat_2():
    """Build the 'multi output' pattern case.

    Topology: b = Abs(Abs(a0)) is consumed by two ReduceSum ops (c and e),
    both over reduce_axis (1,), so the elementwise part has two reduce
    consumers. Input a0 is float32 with shape [1024, 1024].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        gb.emit("ReduceSum", b, 'e', attrs={'reduce_axis': (1,)})
    return gb.get()[0]
170
171
def graph_pat_3():
    """Build the 'two reduce' pattern case.

    Topology: b = Abs(Abs(a0)) -> c = ReduceSum(b) -> d = Abs(c) ->
    e = ReduceSum(d); both reduces use reduce_axis (1,). Input a0 is float32
    with shape [1024, 1024].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        d = gb.emit("Abs", c, 'd')
        gb.emit("ReduceSum", d, 'e', attrs={'reduce_axis': (1,)})
    return gb.get()[0]
183
184
def graph_pat_4():
    """Build the 'elewise + broadcast' pattern case.

    Branch 1: a0 [1, 1024] -> a -> b -> c -> d -> e -> f (six Abs ops).
    Branch 2: a2 [1014, 1024] -> g0 -> g0_1 (two Abs ops).
    The branches join in g1 = Add(f, g0_1). The leading dimensions (1 vs 1014)
    differ, so this Add presumably broadcasts f. The join is followed by the
    Abs chain g2 -> g3 -> g4 -> g5.

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1, 1024], "float32", name="a0")
        # NOTE(review): 1014 may be a typo for 1024. Either value broadcasts
        # against the leading 1 of a0, so it is kept as-is.
        a2 = gb.tensor([1014, 1024], "float32", name="a2")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("Abs", b, 'c')
        d = gb.emit("Abs", c, 'd')
        e = gb.emit("Abs", d, 'e')
        f = gb.emit("Abs", e, 'f')
        g0 = gb.emit("Abs", a2, 'g0')
        # The second Abs gets its own tensor name. Reusing 'g0' put two
        # distinct tensors with the same name into one graph, which makes
        # printed graphs and name-based lookups ambiguous.
        g0_1 = gb.emit("Abs", g0, 'g0_1')
        g1 = gb.emit('Add', [f, g0_1], 'g1')
        g2 = gb.emit("Abs", g1, 'g2')
        g3 = gb.emit("Abs", g2, 'g3')
        g4 = gb.emit("Abs", g3, 'g4')
        gb.emit("Abs", g4, 'g5')
    return gb.get()[0]
211
212
def graph_pat_5():
    """Build the 'reduce + reshape' pattern case.

    Topology: a0 [1024, 1024] -> a -> b (Abs) -> c = ReduceSum(b) over
    reduce_axis (1,) -> d = Abs(c). A Reshape op then writes d into the
    pre-declared tensor e.

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        d = gb.emit("Abs", c, 'd')
        # Reducing axis 1 of a [1024, 1024] tensor leaves 1024 elements, so a
        # valid reshape target must also hold 1024 elements. The previous
        # [512, 2048] target (1,048,576 elements) was inconsistent with d.
        e = gb.tensor([32, 32], "float32", name="e")
        gb.op("Reshape", e, [d])
    return gb.get()[0]
225
226
def graph_pat_6():
    """Build the 'diamond' pattern case.

    Topology: a = Abs(a0) fans out to b and c, which rejoin in d = Add(b, c)
    (the diamond). An extra consumer f = Abs(c) gives c a second use outside
    the diamond. Input a0 is float32 with shape [1024, 1024].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("Abs", a, 'c')
        gb.emit("Add", [b, c], 'd')
        gb.emit("Abs", c, 'f')  # second consumer of c: breaks the closed diamond
    return gb.get()[0]
238
239
def graph_pat_7():
    """Build the 'buddy of control op' pattern case.

    Topology: a = Abs(a0) and b = Abs(a1) are packed by MakeTuple into c. An
    AddN op consumes c and writes the pre-declared tensor d, followed by
    f = Abs(d). ``estimate.AddControlBuddy`` is then run over the built graph
    before it is returned. It presumably pairs the MakeTuple control op with a
    neighbouring op; see model.model for the exact rule.

    Returns:
        The first graph returned by ``gb.get()``, after AddControlBuddy has
        visited it.
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a1 = gb.tensor([1024, 1024], "float32", name="a1")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a1, 'b')
        c = gb.emit("MakeTuple", [a, b], 'c')
        # d is declared explicitly and filled by the AddN op below.
        d = gb.tensor([1024, 1024], "float32", name="d")
        gb.op("AddN", d, [c])
        gb.emit("Abs", d, 'f')
    graph = gb.get()[0]
    estimate.AddControlBuddy().visit_graph(graph)
    return graph
255
256
def graph_pat_8():
    """Build a 'reduce feeding back into elementwise' pattern case.

    Topology: b = Abs(Abs(a0)) -> c = ReduceSum(b) over reduce_axis (1,),
    then d = Add(b, c) combines the reduce result with its own input.
    Input a0 is float32 with shape [1024, 1024].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        # b is used both by the reduce and by the Add after it.
        gb.emit("Add", [b, c], 'd')
    return gb.get()[0]
268
269
def graph_pat_9():
    """Build the 'scalar' pattern case.

    Topology: a1 is a one-element [1] tensor combined with the [1024, 1024]
    tensor a0: a = Maximum(a1), b = Mul(a, a1), c = Mul(b, a0).

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a1 = gb.tensor([1], "float32", name="a1")
        # NOTE(review): Maximum is given a single input here. Confirm whether
        # the builder accepts a unary Maximum or a second operand is missing.
        a = gb.emit("Maximum", a1, 'a')
        b = gb.emit("Mul", [a, a1], 'b')
        gb.emit('Mul', [b, a0], 'c')
    return gb.get()[0]
280
281
def graph_mo_1():
    """Build a multi-output case from a single fan-out.

    Topology: a = Abs(a0) feeds two Abs consumers b and c, which are both
    leaves. No explicit outputs are set. Input a0 is float32 with shape
    [1024, 1024].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        gb.emit("Abs", a, 'b')
        gb.emit("Abs", a, 'c')
    return gb.get()[0]
290
291
def graph_mo_2():
    """Build a multi-output case where an output also has an internal consumer.

    Topology: chain a0 -> a -> b -> c (Abs). Both b and c are declared graph
    outputs via ``g.set_output(b, c)``, so b is an output that also feeds c.
    Input a0 is float32 with shape [1024, 1024].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main") as g:
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("Abs", b, 'c')
        g.set_output(b, c)
    return gb.get()[0]
301
302
def graph_mo_3():
    """Build a multi-output case: an elementwise result and its reduction.

    Topology: b = Abs(Abs(a0)) and c = ReduceSum(b) over reduce_axis (1,).
    Both b and c are declared outputs, so the single reduce's input is itself
    a graph output. Input a0 is float32 with shape [1024, 1024].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main") as g:
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        g.set_output(b, c)
    return gb.get()[0]
313
314
def graph_mo_4():
    """Build a multi-output case: sibling elementwise op and reduce.

    Topology: a = Abs(a0) feeds both b = Abs(a) and c = ReduceSum(a) over
    reduce_axis (1,). b and c are declared outputs, making an elementwise op
    and a reduce siblings that share one input. Input a0 is float32 with
    shape [1024, 1024].

    Returns:
        The first graph returned by ``gb.get()`` (built in the "main" scope).
    """
    gb = model.GraphBuilder()
    with gb.graph_scope("main") as g:
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", a, 'c', attrs={'reduce_axis': (1,)})
        g.set_output(b, c)
    return gb.get()[0]
325
326
def test_binary_split():
    """Test GraphSpliter.binary_split on the synthetic graphs.

    For every graph, the number of candidate splits must equal the expected
    count, and the first part of every candidate must form one connected graph.
    """
    def _check(graph, want_size):
        print("********* test on graph : {} *************".format(graph.name))
        spliter = split.GraphSpliter(graph)
        space = spliter.binary_split(get_nodes(spliter, graph.ops))
        for idx, cand in enumerate(space):
            print('{}: {}'.format(idx, split_format(spliter, cand)))
        assert len(space) == want_size
        assert first_connected(spliter, space)

    # (graph builder, expected number of candidates); each graph is built
    # right before it is checked.
    cases = (
        (graph_1, 3),
        (graph_2, 7),
        (graph_3, 4),
        (graph_4, 17),
        (graph_5, 11),
        (graph_6, 24),
    )
    for build, want_size in cases:
        _check(build(), want_size)
344
345
def test_resolve_connnected_graphs():
    """Test connected-component resolution on subsets of graph_5's ops.

    Each case lists op output names and the number of connected graphs
    ``resolve_connnected_graphs`` must report for them.
    """
    sp = split.GraphSpliter(graph_5())
    cases = (
        (['a', 'd', 'b', 'c'], 1),
        (['a', 'd', 'e', 'f', 'g'], 2),
        (['a', 'b', 'f'], 3),
    )
    for names, want_count in cases:
        graphs = sp.resolve_connnected_graphs(get_nodes(sp, names))
        print(graphs)
        assert len(graphs) == want_count
362
363
def test_split():
    """Split a graph with GraphSpliter and print the cost estimates.

    Prints the main graph, every subgraph, and the estimated cost of the whole
    graph, of the subgraphs combined, and of each subgraph. It makes no
    assertions.
    """
    cost_fmt = "%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s"

    def _print_cost(name, c):
        fields = (name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type())
        print(cost_fmt % fields)

    def _test(graph):
        print("********* test on graph : {} *************".format(graph.name))
        subgraphs = split.GraphSpliter(graph).split(False)
        print('----- main graph -------')
        print(graph)
        for idx, sub in enumerate(subgraphs):
            print(' -------- subgraph {} -------'.format(idx))
            print(sub)
        print("--------- cost ------------")
        main_cost, _ = model.estimate(graph)
        _print_cost("main graph", main_cost)
        total_cost, sub_costs = model.estimate(subgraphs)
        _print_cost("Subgraphs:", total_cost)
        for idx, sub_cost in enumerate(sub_costs):
            _print_cost(" |_%d:\t" % idx, sub_cost)

    _test(graph_5())
    # _test(graph_4())
389
def test_estimate():
    """Run the Estimator on graph_5 and print its iteration space."""
    estimator = estimate.Estimator(graph_5())
    estimator.estimate()
    print(estimator.iter_space)
396
397
def test_pattern_split():
    """Split graphs with GraphSplitByPatternV2 and print the subgraphs.

    When ``expect_n`` is positive, the number of subgraphs is also asserted.
    """
    def _test(graph, expect_n=0):
        print("************* main graph **************")
        print(graph)
        subgraphs = split.GraphSplitByPatternV2(graph).split()
        for idx, sub in enumerate(subgraphs):
            print(' -------- subgraph {} -------'.format(idx))
            print(sub)
        if expect_n <= 0:
            return
        assert len(subgraphs) == expect_n

    # Disabled cases, kept as toggles for manual runs:
    # _test(graph_1(), 1)
    # _test(graph_pat_1(), 2)
    # _test(graph_pat_2())
    # _test(graph_pat_3())
    # _test(graph_pat_4())
    # _test(graph_pat_5())
    # _test(graph_pat_6())
    # _test(graph_pat_7())
    # _test(graph_pat_8())
    # _test(graph_pat_9())

    # _test(graph_mo_1())
    # _test(graph_mo_2())
    # _test(graph_mo_3())
    _test(graph_mo_4())
425
426
def main():
    """Script entry point: run the currently enabled test (pattern split).

    The other tests are disabled below; uncomment them to run them as well.
    """
    # test_binary_split()
    # test_resolve_connnected_graphs()
    # test_split()
    # test_estimate()
    test_pattern_split()
433
434
if __name__ == '__main__':
    # Allow running the module directly as a script, outside pytest.
    main()
437