# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Test split: manual drivers for binary split, pattern split and estimation."""
import model
from model import model as estimate
from model import graph_split as split


def get_nodes(sp, ops):
    """Get the spliter nodes that correspond to ``ops``.

    Args:
        sp: graph spliter whose ``nodes`` list is indexed by op position in
            ``sp.graph.ops``.
        ops: sequence of ops, or sequence of output tensor names (detected by
            checking whether the first element is a ``str``).

    Returns:
        List of ``sp.nodes`` entries, one per resolved op. A name that matches
        no op output is reported on stdout and skipped. An empty ``ops``
        yields an empty list.
    """
    # Guard ``ops[0]`` so an empty input returns [] instead of IndexError.
    if ops and isinstance(ops[0], str):
        new_ops = []
        for t in ops:
            # The first op whose output tensor carries this name wins.
            for op in sp.graph.ops:
                if op.output.name == t:
                    new_ops.append(op)
                    break
            else:
                print("ERROR: not found op: ", t)
        ops = new_ops
    return [sp.nodes[sp.graph.ops.index(op)] for op in ops]


def first_connected(sp, space):
    """Check that the first part of every candidate split is connected.

    Args:
        sp: graph spliter providing ``nodes`` and ``resolve_connnected_graphs``.
        space: candidate splits; the first element of each candidate is an
            iterable of indices into ``sp.nodes``.

    Returns:
        True if every candidate's first part resolves to exactly one connected
        graph; otherwise prints the offending nodes and returns False.
    """
    for cand in space:
        nodes = [sp.nodes[i] for i in cand[0]]
        if len(sp.resolve_connnected_graphs(nodes)) != 1:
            print("connect check failed: ", nodes)
            return False
    return True


def split_format(sp, cand):
    """Render a candidate split as a readable string.

    Each part of ``cand`` (an iterable of indices into ``sp.graph.ops``) is
    shown as the comma-joined output names of its ops, and the parts are
    joined with ``|`` -- e.g. ``"a,b|c"``.
    """
    return '|'.join(
        ','.join(sp.graph.ops[i].output.name for i in ids) for ids in cand)


def graph_1():
    """ring, no succ_dep, no prev"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a = gb.tensor([10240, 16], "float32", name="a")
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("Abs", b, 'c')
        d = gb.emit("Abs", c, 'd')
        gb.emit('Add', [b, d], 'e')
    return gb.get()[0]


def graph_2():
    """ring, succ_dep, no prev"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([10240, 16], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("Abs", a, 'c')
        d = gb.emit("Abs", b, 'd')
        e = gb.emit('Add', [c, d], 'e')
        gb.emit("Abs", e, 'f')
    return gb.get()[0]


def graph_3():
    """no ring, 1 sibling node"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([10240, 16], "float32", name="a0")
        a1 = gb.tensor([10240, 16], "float32", name="a1")
        b = gb.emit("Abs", a0, 'b')
        c = gb.emit("Abs", a1, 'c')
        d = gb.emit("Abs", b, 'd')
        e = gb.emit('Add', [c, d], 'e')
        gb.emit("Abs", e, 'f')
    return gb.get()[0]


def graph_4():
    """no ring, 2 sibling nodes in 1 step"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([10240, 16], "float32", name="a0")
        a1 = gb.tensor([10240, 16], "float32", name="a1")
        b = gb.emit("Abs", a0, 'b')
        c = gb.emit("Abs", b, 'c')
        d = gb.emit("Abs", a1, 'd')
        e = gb.emit("Abs", d, 'e')
        f = gb.emit('Add', [c, e], 'f')
        gb.emit('Abs', f, 'g')
        h = gb.emit("Abs", d, 'h')
        i = gb.emit('Add', [c, h], 'i')
        gb.emit("Abs", i, 'j')
    return gb.get()[0]


def graph_5():
    """no ring, 2 sibling step"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([10240, 16], "float32", name="a0")
        a1 = gb.tensor([10240, 16], "float32", name="a1")
        a2 = gb.tensor([10240, 16], "float32", name="a2")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a1, 'b')
        c = gb.emit("Abs", b, 'c')
        d = gb.emit('Add', [a, c], 'd')
        gb.emit("Abs", d, 'e')
        f = gb.emit("Abs", a2, 'f')
        g_op = gb.emit('Add', [c, f], 'g')
        gb.emit("Abs", g_op, 'h')
    return gb.get()[0]


def graph_6():
    """no ring, tree down"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([10240, 16], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        gb.emit("Abs", b, 'd')
        gb.emit("Abs", b, 'e')
        c = gb.emit("Abs", a, 'c')
        gb.emit("Abs", c, 'f')
        gb.emit("Abs", c, 'g')
    return gb.get()[0]


def graph_pat_1():
    """split by reduce"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        d = gb.emit("Sqrt", c, 'd')
        gb.emit("Sqrt", d, 'f')
    return gb.get()[0]


def graph_pat_2():
    """multi output: two ReduceSum ops consume the same tensor b"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        gb.emit("ReduceSum", b, 'e', attrs={'reduce_axis': (1,)})
    return gb.get()[0]


def graph_pat_3():
    """two reduce"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        d = gb.emit("Abs", c, 'd')
        gb.emit("ReduceSum", d, 'e', attrs={'reduce_axis': (1,)})
    return gb.get()[0]


def graph_pat_4():
    """elementwise + broadcast"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1, 1024], "float32", name="a0")
        # NOTE(review): the leading dim 1014 differs from the 1024 used by the
        # other graphs in this file; confirm it is intentional (presumably it
        # still broadcasts against a0's [1, 1024]).
        a2 = gb.tensor([1014, 1024], "float32", name="a2")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("Abs", b, 'c')
        d = gb.emit("Abs", c, 'd')
        e = gb.emit("Abs", d, 'e')
        f = gb.emit("Abs", e, 'f')
        g0 = gb.emit("Abs", a2, 'g0')
        # Repeat the next line to lengthen the g0 branch relative to the
        # a0 -> f chain.
        g0 = gb.emit("Abs", g0, 'g0')
        g1 = gb.emit('Add', [f, g0], 'g1')
        g2 = gb.emit("Abs", g1, 'g2')
        g3 = gb.emit("Abs", g2, 'g3')
        g4 = gb.emit("Abs", g3, 'g4')
        gb.emit("Abs", g4, 'g5')
    return gb.get()[0]


def graph_pat_5():
    """reduce + reshape"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        d = gb.emit("Abs", c, 'd')
        e = gb.tensor([512, 2048], "float32", name="e")
        gb.op("Reshape", e, [d])
    return gb.get()[0]


def graph_pat_6():
    """diamond"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("Abs", a, 'c')
        gb.emit("Add", [b, c], 'd')
        gb.emit("Abs", c, 'f')  # extra consumer of c breaks the diamond
    return gb.get()[0]


def graph_pat_7():
    """buddy of control op"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a1 = gb.tensor([1024, 1024], "float32", name="a1")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a1, 'b')
        c = gb.emit("MakeTuple", [a, b], 'c')
        d = gb.tensor([1024, 1024], "float32", name="d")
        gb.op("AddN", d, [c])
        gb.emit("Abs", d, 'f')
    graph = gb.get()[0]
    estimate.AddControlBuddy().visit_graph(graph)
    return graph


def graph_pat_8():
    """reduce + elementwise op that also consumes the reduce's input"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        gb.emit("Add", [b, c], 'd')
    return gb.get()[0]


def graph_pat_9():
    """scalar"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a1 = gb.tensor([1], "float32", name="a1")
        a = gb.emit("Maximum", a1, 'a')
        b = gb.emit("Mul", [a, a1], 'b')
        gb.emit('Mul', [b, a0], 'c')
    return gb.get()[0]


def graph_mo_1():
    """multi output: a feeds two sibling Abs ops b and c"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main"):
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        gb.emit("Abs", a, 'b')
        gb.emit("Abs", a, 'c')
    return gb.get()[0]


def graph_mo_2():
    """multi output: intermediate b and its consumer c are both outputs"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main") as g:
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("Abs", b, 'c')
        g.set_output(b, c)
    return gb.get()[0]


def graph_mo_3():
    """multi output: elementwise b and ReduceSum(b) are both outputs"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main") as g:
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
        g.set_output(b, c)
    return gb.get()[0]


def graph_mo_4():
    """multi output: Abs(a) and ReduceSum(a) are siblings and both outputs"""
    gb = model.GraphBuilder()
    with gb.graph_scope("main") as g:
        a0 = gb.tensor([1024, 1024], "float32", name="a0")
        a = gb.emit("Abs", a0, 'a')
        b = gb.emit("Abs", a, 'b')
        c = gb.emit("ReduceSum", a, 'c', attrs={'reduce_axis': (1,)})
        g.set_output(b, c)
    return gb.get()[0]


def test_binary_split():
    """Test binary split"""
    def _test(graph, expected_space_size):
        print("********* test on graph : {} *************".format(graph.name))
        sp = split.GraphSpliter(graph)
        nodes = get_nodes(sp, graph.ops)
        space = sp.binary_split(nodes)
        for i, s in enumerate(space):
            print('{}: {}'.format(i, split_format(sp, s)))
        assert len(space) == expected_space_size
        assert first_connected(sp, space)
    _test(graph_1(), 3)
    _test(graph_2(), 7)
    _test(graph_3(), 4)
    _test(graph_4(), 17)
    _test(graph_5(), 11)
    _test(graph_6(), 24)


def test_resolve_connnected_graphs():
    """Test resolve connected graphs"""
    graph = graph_5()
    sp = split.GraphSpliter(graph)
    n1 = get_nodes(sp, ['a', 'd', 'b', 'c'])
    graphs = sp.resolve_connnected_graphs(n1)
    print(graphs)
    assert len(graphs) == 1
    n2 = get_nodes(sp, ['a', 'd', 'e', 'f', 'g'])
    graphs = sp.resolve_connnected_graphs(n2)
    print(graphs)
    assert len(graphs) == 2
    n3 = get_nodes(sp, ['a', 'b', 'f'])
    graphs = sp.resolve_connnected_graphs(n3)
    print(graphs)
    assert len(graphs) == 3


def test_split():
    """Test split"""
    def _print_cost(name, c):
        print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
              (name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))

    def _test(graph):
        print("********* test on graph : {} *************".format(graph.name))
        sp = split.GraphSpliter(graph)
        subgraphs = sp.split(False)
        print('----- main graph -------')
        print(graph)
        for i, g in enumerate(subgraphs):
            print(' -------- subgraph {} -------'.format(i))
            print(g)
        print("--------- cost ------------")
        cost, _ = model.estimate(graph)
        _print_cost("main graph", cost)
        fc, sub_costs = model.estimate(subgraphs)
        _print_cost("Subgraphs:", fc)
        # Distinct name so the main-graph ``cost`` above is not shadowed.
        for i, sub_cost in enumerate(sub_costs):
            _print_cost(" |_%d:\t" % (i), sub_cost)
    _test(graph_5())
    # _test(graph_4())


def test_estimate():
    """Test estimate"""
    graph = graph_5()
    e = estimate.Estimator(graph)
    e.estimate()
    print(e.iter_space)


def test_pattern_split():
    """Test pattern split"""
    def _test(graph, expect_n=0):
        print("************* main graph **************")
        print(graph)
        subgraphs = split.GraphSplitByPatternV2(graph).split()
        for i, g in enumerate(subgraphs):
            print(' -------- subgraph {} -------'.format(i))
            print(g)
        # expect_n <= 0 means "print only, no subgraph-count check".
        if expect_n > 0:
            assert len(subgraphs) == expect_n

    # Uncomment cases below to run them.
    # _test(graph_1(), 1)
    # _test(graph_pat_1(), 2)
    # _test(graph_pat_2())
    # _test(graph_pat_3())
    # _test(graph_pat_4())
    # _test(graph_pat_5())
    # _test(graph_pat_6())
    # _test(graph_pat_7())
    # _test(graph_pat_8())
    # _test(graph_pat_9())

    # _test(graph_mo_1())
    # _test(graph_mo_2())
    # _test(graph_mo_3())
    _test(graph_mo_4())


def main():
    """Run the selected manual tests; uncomment others as needed."""
    # test_binary_split()
    # test_resolve_connnected_graphs()
    # test_split()
    # test_estimate()
    test_pattern_split()


if __name__ == '__main__':
    main()