# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UT: compile/train MatMul + OneHot graphs under MindSpore auto-parallel."""

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)

# Gradient meta-op that returns gradients w.r.t. every network input.
grad_all = C.GradOperation(get_all=True)


class Dataset(MindData):
    """Finite mock dataset: yields the same (predict, label) pair `length` times."""

    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        # Emit the fixed pair until `length` items have been produced.
        if self.index < self.length:
            self.index += 1
            return self.predict, self.label
        raise StopIteration

    def reset(self):
        # Rewind so the dataset can be iterated for another epoch.
        self.index = 0


class NetWithLoss(nn.Cell):
    """Wraps a network and reduces its output through VirtualLoss."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        out = self.network(x, y, b)
        return self.loss(out)


class GradWrap(nn.Cell):
    """Wraps a network so construct() returns gradients w.r.t. all inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


def test_auto_parallel_arithmetic():
    """Compile a MatMul -> OneHot -> MatMul graph under auto_parallel (8 devices)."""

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.one_hot = P.OneHot()
            self.on_value = Tensor(1.0, ms.float32)
            self.off_value = Tensor(0.0, ms.float32)
            self.matmul2 = P.MatMul()

        def construct(self, x, y, b):
            prod = self.matmul(x, y)
            encoded = self.one_hot(b, 64, self.on_value, self.off_value)
            return self.matmul2(prod, encoded)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()

    inputs = (
        Tensor(np.ones([64, 32]), dtype=ms.float32),
        Tensor(np.ones([32, 64]), dtype=ms.float32),
        Tensor(np.ones([64]), dtype=ms.int32),
    )
    net.set_train()
    _cell_graph_executor.compile(net, *inputs)


def test_auto_parallel_arithmetic_model():
    """Train a column-sharded OneHot model for two epochs under AUTO_PARALLEL."""

    class NetOneHot(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            # Explicit strategy: split the OneHot output column-wise over 8 devices.
            self.one_hot = P.OneHot().shard(((1, 8), (), ()))
            self.on_value = Tensor(1.0, ms.float32)
            self.off_value = Tensor(0.0, ms.float32)
            self.matmul2 = P.MatMul()
            self.w = Parameter(Tensor(np.zeros([32, 64]).astype(np.float32)), "weight", requires_grad=True)

        def construct(self, x, b):
            hidden = self.matmul(x, self.w)
            encoded = self.one_hot(b, 64, self.on_value, self.off_value)
            return self.matmul2(hidden, encoded)

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
    net = NetOneHot()

    features = Tensor(np.ones([8, 32]), dtype=ms.float32)
    labels = Tensor(np.ones([8]), dtype=ms.int32)
    dataset = Dataset(features, labels, 2)

    optimizer = Momentum(net.trainable_params(), 0.1, 0.9)
    model = Model(net, optimizer=optimizer)

    model.train(2, dataset, dataset_sink_mode=False)