# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)


grad_by_list = C.GradOperation(get_by_list=True)


class NetWithLoss(nn.Cell):
    """Wraps a network with GetNext input, one-hot label encoding and softmax cross-entropy loss."""

    def __init__(self, network, types, shapes, output_num, strategy3=None, strategy4=None, axis=-1):
        super(NetWithLoss, self).__init__()
        self.get_next = P.GetNext(types, shapes, output_num, "")
        self.one_hot = P.OneHot(axis=axis).shard(strategy3)
        self.on_value = Tensor(1.0, ms.float32)
        self.off_value = Tensor(0.0, ms.float32)
        self.loss = P.SoftmaxCrossEntropyWithLogits().shard(strategy4)
        self.network = network

    def construct(self):
        data, label = self.get_next()
        predict = self.network(data)
        label = self.one_hot(label, 64, self.on_value, self.off_value)
        return self.loss(predict, label)[0]


class GradWrap(nn.Cell):
    """Computes gradients of the wrapped network with respect to its trainable parameters."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())

    def construct(self):
        return grad_by_list(self.network, self.weights)()


def compile_net(net):
    net.set_auto_parallel()
    _cell_graph_executor.compile(net)


def test_get_next_single():
    """Compile a GetNext-fed network in stand-alone mode (no parallel context)."""
    class Net(nn.Cell):
        def __init__(self, channel=1, w=0.25):
            super().__init__()
            self.norm = P.L2Normalize(axis=1)
            self.prelu = P.PReLU()
            self.w = Parameter(initializer(w, [channel,]), name='w')

        def construct(self, data):
            x = self.norm(data)
            x = self.prelu(x, self.w)
            return x

    net = GradWrap(NetWithLoss(Net(), [ms.float32, ms.int32], [[32, 64], [32]], 2))
    _cell_graph_executor.compile(net)


def test_get_next_semi_auto_parallel():
    """Compile a GetNext-fed network in semi_auto_parallel mode with OneHot sharded along the batch dimension."""
    class Net(nn.Cell):
        def __init__(self, channel=1, w=0.25, strategy1=None, strategy2=None):
            super().__init__()
            self.norm = P.L2Normalize().shard(strategy1)
            self.prelu = P.PReLU().shard(strategy2)
            self.w = Parameter(initializer(w, [channel,]), name='w')

        def construct(self, data):
            x = self.norm(data)
            x = self.prelu(x, self.w)
            return x

    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    network = Net(strategy1=((1, 4),), strategy2=((4, 1), (1,)))
    strategy3 = ((4, 1), (), ())
    strategy4 = ((4, 1), (4, 1))
    net_with_loss = NetWithLoss(network, [ms.float32, ms.int32], [[32, 64], [32]], 2, strategy3=strategy3,
                                strategy4=strategy4)
    net = GradWrap(net_with_loss)
    compile_net(net)


def test_get_next_semi_auto_parallel1():
    """Same as test_get_next_semi_auto_parallel, but with OneHot sharded along the class dimension."""
    class Net(nn.Cell):
        def __init__(self, channel=1, w=0.25, strategy1=None, strategy2=None):
            super().__init__()
            self.norm = P.L2Normalize().shard(strategy1)
            self.prelu = P.PReLU().shard(strategy2)
            self.w = Parameter(initializer(w, [channel,]), name='w')

        def construct(self, data):
            x = self.norm(data)
            x = self.prelu(x, self.w)
            return x

    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    network = Net(strategy1=((1, 4),), strategy2=((4, 1), (1,)))
    strategy3 = ((1, 4), (), ())
    strategy4 = ((4, 1), (4, 1))
    net_with_loss = NetWithLoss(network, [ms.float32, ms.int32], [[32, 64], [32]], 2, strategy3=strategy3,
                                strategy4=strategy4)
    net = GradWrap(net_with_loss)
    compile_net(net)


def test_get_next_auto_parallel():
    """Compile a GetNext-fed network in auto_parallel mode, letting the framework search sharding strategies."""
    class Net(nn.Cell):
        def __init__(self, channel=1, w=0.25, strategy1=None, strategy2=None):
            super().__init__()
            self.norm = P.L2Normalize().shard(strategy1)
            self.prelu = P.PReLU().shard(strategy2)
            self.w = Parameter(initializer(w, [channel,]), name='w')

        def construct(self, data):
            x = self.norm(data)
            x = self.prelu(x, self.w)
            return x

    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    network = Net()
    net_with_loss = NetWithLoss(network, [ms.float32, ms.int32], [[32, 64], [32]], 2)
    net = GradWrap(net_with_loss)
    compile_net(net)


def test_only_one_get_next():
    """Compile a network whose only operator is GetNext under semi_auto_parallel."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.get_next = P.GetNext([ms.float32, ms.int32], [[32, 64], [32]], 2, "")

        def construct(self):
            return self.get_next()

    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net = Net()
    net.set_train()
    compile_net(net)