# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore.nn import TrainOneStepCell, Momentum
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Feed the wrapped network's output into a ``VirtualLoss``."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


class GradWrap(nn.Cell):
    """Return ``grad_all(network)(x)`` for the wrapped network."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class _UniqueGatherNet(nn.Cell):
    """Network shared by the Unique tests: Unique -> Gather -> Gather -> Reshape -> Mul.

    The indices are flattened and de-duplicated, the unique ids are looked up
    in a [2000, 128] embedding table, the lookup result is expanded again with
    the index output of Unique, reshaped to (32, 64, 128) and multiplied by a
    [32, 64, 1] weight.

    Args:
        lookup_strategy: shard strategy passed to the Gather that reads
            ``embedding_table`` with the unique ids.
        restore_strategy: shard strategy passed to the Gather that expands the
            looked-up rows with the Unique index output.
    """

    def __init__(self, lookup_strategy, restore_strategy):
        super().__init__()
        self.unique = P.Unique().shard(((1,),))
        self.mul = P.Mul()
        self.embedding_lookup = P.Gather().shard(lookup_strategy)
        self.embedding_table = Parameter(initializer('normal', [2000, 128]),
                                         name='embedding_table')
        self.gatherv2 = P.Gather().shard(restore_strategy)
        self.reshape = P.Reshape()
        self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight")

    def construct(self, indices):
        indices_flatten = self.reshape(indices, (-1,))
        unique_id, unique_idx = self.unique(indices_flatten)
        unique_id_weight = self.embedding_lookup(self.embedding_table, unique_id, 0)
        weight_flatten = self.gatherv2(unique_id_weight, unique_idx, 0)
        weight = self.reshape(weight_flatten, (32, 64, 128))
        vx = self.mul(weight, self.mul_weight)
        return vx


def _compile_unique_net(parallel_mode, lookup_strategy, restore_strategy, device_num=8):
    """Compile a training step of ``_UniqueGatherNet`` under the given parallel mode.

    Args:
        parallel_mode: value forwarded to ``context.set_auto_parallel_context``.
        lookup_strategy: shard strategy for the embedding-table Gather.
        restore_strategy: shard strategy for the second Gather.
        device_num: number of devices configured in the parallel context.

    Raises:
        Whatever ``_cell_graph_executor.compile`` raises; a test calling this
        helper passes only when compilation finishes without an exception.
    """
    context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode=parallel_mode)
    try:
        x = Tensor(np.ones([32, 64]), dtype=ms.int32)
        net = _UniqueGatherNet(lookup_strategy, restore_strategy)
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, x)
    finally:
        # The parallel context is process-global; restore the defaults so the
        # settings made here do not leak into tests that run afterwards.
        context.reset_auto_parallel_context()


def test_unique_column_split():
    """
    Feature: Unique followed by Gather in auto_parallel mode.
    Description: both Gather ops use strategy ((1, 8), (1,)), i.e. the second
        dimension of their first input is split across the 8 devices.
    Expectation: the training graph compiles without raising.
    """
    _compile_unique_net(parallel_mode="auto_parallel",
                        lookup_strategy=((1, 8), (1,)),
                        restore_strategy=((1, 8), (1,)))


def test_unique_row_split():
    """
    Feature: Unique followed by Gather in semi_auto_parallel mode.
    Description: the embedding-table Gather uses strategy ((8, 1), (1,)), i.e.
        the first dimension of the table is split across the 8 devices; the
        second Gather uses ((1, 1), (1,)) and is not split.
    Expectation: the training graph compiles without raising.
    """
    _compile_unique_net(parallel_mode="semi_auto_parallel",
                        lookup_strategy=((8, 1), (1,)),
                        restore_strategy=((1, 1), (1,)))