# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph-compilation tests for EmbeddingLookup followed by BatchMatMul."""
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import Tensor, context
from tests.ut.python.ops.test_math_ops import VirtualLoss


# Gradient operation shared by every GradWrap instance (built with get_all=True).
grad_all = C.GradOperation(get_all=True)


class GradWrap(nn.Cell):
    """Cell that applies ``grad_all`` to a wrapped network.

    Args:
        network: Cell taking two inputs ``(x, y)``.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y):
        """Return the result of ``grad_all(self.network)`` evaluated at ``(x, y)``."""
        return grad_all(self.network)(x, y)


class NetWithLoss(nn.Cell):
    """Cell that feeds the wrapped network's output into ``VirtualLoss``.

    Args:
        network: Cell taking two inputs ``(x, y)``.
    """

    def __init__(self, network):
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        """Run the wrapped network on ``(x, y)`` and pass the prediction to the loss."""
        predict = self.network(x, y)
        return self.loss(predict)


class Net(nn.Cell):
    """EmbeddingLookup with a fixed all-ones index, followed by BatchMatMul.

    Args:
        shape: Shape of the int32 all-ones index tensor used for the lookup.
        offset: Offset passed as the third argument of EmbeddingLookup.
        strategy1: Shard strategy for EmbeddingLookup (``None`` leaves it unset).
        strategy2: Shard strategy for BatchMatMul (``None`` leaves it unset).
        target: Value stored in the ``primitive_target`` attribute of the
            EmbeddingLookup primitive.
    """

    def __init__(self, shape, offset, strategy1=None, strategy2=None, target="Device"):
        super().__init__()
        self.index = Tensor(np.ones(shape), dtype=ms.int32)
        self.offset = offset
        self.elu = P.EmbeddingLookup().shard(strategy1).add_prim_attr("primitive_target", target)
        self.mm = P.BatchMatMul().shard(strategy2)

    def construct(self, x, y):
        """Look up ``self.index`` in ``x``, then batch-matmul the result with ``y``."""
        out = self.elu(x, self.index, self.offset)
        out = self.mm(out, y)
        return out


def _compile_net(net, x, y):
    """Flag ``net`` for auto parallel, put it in training mode and compile it on ``(x, y)``."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)


def test_embeddinglookup_reducescatter_false():
    """Compile the forward net (index shape [8, 8], offset 8) with default strategies."""
    shape = [8, 8]
    offset = 8
    net = NetWithLoss(Net(shape, offset))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
    _compile_net(net, x, y)


def test_embeddinglookup_reducescatter_true():
    """Compile the forward net (index shape [8, 8], offset 8) with default strategies.

    NOTE(review): the body is identical to
    ``test_embeddinglookup_reducescatter_false``; nothing here toggles a
    reduce-scatter option. Confirm whether a setting is missing.
    """
    shape = [8, 8]
    offset = 8
    net = NetWithLoss(Net(shape, offset))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
    _compile_net(net, x, y)


def test_embeddinglookup_reducescatter_false_grad():
    """Compile the gradient net (index shape [8, 8], offset 8) with default strategies."""
    shape = [8, 8]
    offset = 8
    net = GradWrap(NetWithLoss(Net(shape, offset)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
    _compile_net(net, x, y)


def test_embeddinglookup_reducescatter_true_grad():
    """Compile the gradient net (index shape [8, 8], offset 8) with default strategies.

    NOTE(review): the body is identical to
    ``test_embeddinglookup_reducescatter_false_grad``. Confirm whether a
    reduce-scatter setting is missing.
    """
    shape = [8, 8]
    offset = 8
    net = GradWrap(NetWithLoss(Net(shape, offset)))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
    _compile_net(net, x, y)


def test_embeddinglookup_semi_auto1():
    """Compile the gradient net in semi-auto-parallel mode on 8 devices.

    EmbeddingLookup uses strategy ((8, 1), (1, 1)) with target "CPU".
    """
    # The context must be configured before the network is built and compiled.
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 32]
    offset = 0
    strategy1 = ((8, 1), (1, 1))
    strategy2 = ((4, 1, 2), (4, 2, 1))
    net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))

    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _compile_net(net, x, y)


def test_embeddinglookup_semi_auto2():
    """Compile the gradient net in semi-auto-parallel mode on 8 devices.

    EmbeddingLookup uses strategy ((1, 8), (1, 1)) with target "CPU".
    """
    # The context must be configured before the network is built and compiled.
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    shape = [64, 32]
    offset = 0
    strategy1 = ((1, 8), (1, 1))
    strategy2 = ((4, 1, 2), (4, 2, 1))
    net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU")))

    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _compile_net(net, x, y)