# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell

context.set_context(mode=context.GRAPH_MODE)


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Wraps a network with OneHot label encoding and a softmax cross-entropy loss."""

    def __init__(self, network, strategy3, strategy4, axis):
        super(NetWithLoss, self).__init__()
        self.one_hot = P.OneHot(axis=axis).shard(strategy3)
        self.on_value = Tensor(2.0, ms.float32)
        self.off_value = Tensor(1.0, ms.float32)
        self.loss = P.SoftmaxCrossEntropyWithLogits().shard(strategy4)
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y)
        label = self.one_hot(b, 64, self.on_value, self.off_value)
        return self.loss(predict, label)[0]


class GradWrap(nn.Cell):
    """Computes gradients of the wrapped network with respect to all inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


class Net(nn.Cell):
    """MatMul followed by GeLU, each with its own shard strategy."""

    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.matmul = P.MatMul().shard(strategy1)
        self.gelu = P.GeLU().shard(strategy2)

    def construct(self, x, y):
        out = self.matmul(x, y)
        out = self.gelu(out)
        return out


def compile_graph(strategy1, strategy2, strategy3, strategy4, auto=False, onehot_axis=-1):
    """Builds the wrapped network and compiles it under the requested parallel mode."""
    net = GradWrap(_VirtualDatasetCell(NetWithLoss(Net(strategy1, strategy2), strategy3, strategy4, axis=onehot_axis)))
    net.set_auto_parallel()
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.int32)
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)


def test_onehot_model_parallel():
    """OneHot sharded along the class dimension (model parallel)."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((1, 16), (), ())
    strategy4 = ((16, 1), (16, 1))
    compile_graph(strategy1, strategy2, strategy3, strategy4)


def test_onehot_batch_parallel():
    """OneHot sharded along the batch dimension (batch parallel)."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((16, 1), (), ())
    strategy4 = ((16, 1), (16, 1))
    compile_graph(strategy1, strategy2, strategy3, strategy4)
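
# A minimal sketch (not used by the tests; the helper name is an illustrative
# assumption, not part of the original suite) of how a OneHot shard strategy
# lines up with the operator's inputs: the first tuple partitions the one-hot
# output (for the default axis=-1, that is (batch, classes)), while the scalar
# on/off values take empty strategies.
def _onehot_strategy(batch_shards, class_shards):
    return ((batch_shards, class_shards), (), ())

# For example, _onehot_strategy(1, 16) reproduces the model-parallel strategy3
# above, and _onehot_strategy(16, 1) the batch-parallel one.
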

def test_onehot_batch_parallel_invalid_strategy():
    """An incomplete OneHot strategy (missing the class dimension) is expected to be rejected."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((16,), (), ())
    strategy4 = ((16, 1), (16, 1))
    try:
        compile_graph(strategy1, strategy2, strategy3, strategy4)
    except (ValueError, TypeError, RuntimeError):
        pass


def test_onehot_repeated_calculation():
    """OneHot strategy uses only 4 of 16 devices, so its calculation is repeated."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((4, 1), (), ())
    strategy4 = ((16, 1), (16, 1))
    compile_graph(strategy1, strategy2, strategy3, strategy4)


def test_onehot_auto():
    """Strategies are left to the auto-parallel cost model."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    compile_graph(None, None, None, None, auto=True)


def test_onehot_batch_parallel_axis0():
    """Batch-parallel OneHot with axis=0."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((16, 1), (), ())
    strategy4 = ((16, 1), (16, 1))
    compile_graph(strategy1, strategy2, strategy3, strategy4, onehot_axis=0)


# Auto parallel for OneHot with axis 0 is not supported yet.
def test_onehot_batch_parallel_invalid_strategy_axis0():
    """With axis=0 and no explicit OneHot strategy, compilation is expected to fail."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = None
    strategy4 = ((16, 1), (16, 1))
    try:
        compile_graph(strategy1, strategy2, strategy3, strategy4, onehot_axis=0)
    except (ValueError, TypeError, RuntimeError):
        pass


def test_onehot_repeated_calculation_axis0():
    """Repeated calculation with axis=0: OneHot uses only 4 of 16 devices."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((4, 1), (), ())
    strategy4 = ((16, 1), (16, 1))
    compile_graph(strategy1, strategy2, strategy3, strategy4, onehot_axis=0)


def test_onehot_auto_axis0():
    """Auto parallel with axis=0, compiled from a nonzero global rank."""
    context.set_auto_parallel_context(device_num=16, global_rank=14)
    compile_graph(None, None, None, None, auto=True, onehot_axis=0)
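
# Illustrative usage sketch (an assumption, not part of the original suite):
# the same harness can be driven directly, e.g. when reproducing a single case.
# Strategy values mirror test_onehot_repeated_calculation; the nonzero
# global_rank is chosen arbitrarily to illustrate the parameter.
if __name__ == "__main__":
    context.set_auto_parallel_context(device_num=16, global_rank=7)
    compile_graph(((2, 4), (4, 2)), ((2, 8),), ((4, 1), (), ()), ((16, 1), (16, 1)))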