1# Copyright 2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ========================================================================== 15"""test graph parallel case""" 16import model 17 18def injective_graph(shape): 19 gb = model.GraphBuilder() 20 with gb.graph_scope('injective') as _: 21 a1 = gb.tensor(shape, 'float32') 22 a2 = gb.emit('Abs', a1) 23 a3 = gb.emit('Abs', a2) 24 gb.emit('Abs', a3) 25 return gb.get()[0] 26 27def reduce_graph(shape, reduce_axis): 28 gb = model.GraphBuilder() 29 with gb.graph_scope('reduce') as _: 30 a1 = gb.tensor(shape, 'float32') 31 a2 = gb.emit('Abs', a1) 32 a3 = gb.emit('Abs', a2) 33 gb.emit('ReduceSum', a3, 'C', attrs={'reduce_axis': reduce_axis}) 34 return gb.get()[0] 35 36def block_fusion(graphs): 37 gain = model.parallel_estimate(graphs) 38 print("fusion = {}, bottleneck = {}, gain = {}".format(gain.fusion_type, gain.bottleneck, gain.gain)) 39 return gain.fusion_type == "block_fusion" and gain.gain > 0 40 41if __name__ == "__main__": 42 assert block_fusion([injective_graph([40, 1024]), injective_graph([40, 1024])]) 43 assert block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([24, 1024])]) 44 assert not block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([50, 1024])]) 45 assert not block_fusion([reduce_graph([1024, 1024], [0, 1]), injective_graph([1024, 1024])]) 46