# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test sparse feature bprop """
import pytest
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, Adam


# Gradient meta-operation: with get_all=True it returns the gradients with
# respect to every input of the wrapped network.
grad_all = C.GradOperation(get_all=True)


@pytest.fixture(name="test_context")
def _test_context():
    """Enable sparse mode for the duration of one test, then restore state.

    Each test below configures its own auto-parallel context, so teardown
    also resets the auto-parallel context to avoid leaking configuration
    into other tests.
    """
    context.set_context(enable_sparse=True)
    yield
    context.set_context(enable_sparse=False)
    context.reset_auto_parallel_context()


class GradWrap(nn.Cell):
    """Wraps a network so that construct() returns its input gradients."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        # Differentiate the wrapped network w.r.t. all of its inputs.
        return grad_all(self.network)(x)


def test_bprop_with_sparse_feature_allreduce(test_context):
    """Backprop through SparseGatherV2 fed by an AllReduce compiles under
    hybrid-parallel mode.

    Compile-only check: _cell_graph_executor.compile raising is the only
    failure mode; no numeric results are asserted.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel")

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            if shape is None:
                shape = [8, 8]
            self.all_reduce = AllReduce()
            self.gatherv2 = P.SparseGatherV2()
            # Indices are all ones; only their shape/dtype matter for
            # graph compilation.
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis

        def construct(self, x):
            out = self.all_reduce(x)
            out = self.gatherv2(out, self.index, self.axis)

            return out

    net = GradWrap(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net.set_train()
    _cell_graph_executor.compile(net, x)


def test_bprop_with_sparse_feature_mirror(test_context):
    """Backprop through a sharded EmbeddingLookup compiles under
    semi-auto-parallel mode (compile-only check, via a full train step)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")

    class Net(nn.Cell):
        def __init__(self, shape=None):
            super(Net, self).__init__()
            if shape is None:
                shape = [8, 8]
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.embeddinglookup = nn.EmbeddingLookup(64, 64, param_init='ones')
            # NOTE(review): strategy ((1, 1), (8, 1)) presumably splits the
            # lookup's second input along dim 0 across the 8 devices —
            # verify against EmbeddingLookup shard semantics.
            self.embeddinglookup.embeddinglookup.shard(((1, 1), (8, 1)))

        def construct(self, x, b):
            # x and b are placeholders so the cell matches the
            # (data, label) signature expected by TrainOneStepCell; the
            # lookup always uses the constant self.index.
            out = self.embeddinglookup(self.index)

            return out

    _x = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
    _b = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)

    def compile_net(net):
        # Wrap in an optimizer/train step so the backward graph (and the
        # sparse gradient path) is built, then compile.
        optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x, _b)

    net = Net()
    compile_net(net)


def test_bprop_with_sparse_feature_dataparallel(test_context):
    """Backprop through SparseGatherV2 on a trainable weight compiles under
    data-parallel mode (compile-only check, via a full train step)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="data_parallel")

    class Net(nn.Cell):
        def __init__(self, axis=0, shape=None):
            super(Net, self).__init__()
            if shape is None:
                shape = [8, 8]
            weight = Tensor(np.ones([64, 64]), dtype=ms.float32)
            self.weight = Parameter(weight, "w")
            self.index = Tensor(np.ones(shape), dtype=ms.int32)
            self.axis = axis
            self.gatherv2 = P.SparseGatherV2()

        def construct(self, x, b):
            # x and b are unused placeholders (see mirror test); the
            # gradient of interest flows to self.weight through the
            # sparse gather.
            out = self.gatherv2(self.weight, self.index, self.axis)

            return out

    _x = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)
    _b = Tensor(np.ones([126, 64, 32]), dtype=ms.float32)

    def compile_net(net):
        # Build the full train step so the sparse backward graph is
        # generated, then compile.
        optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x, _b)

    net = Net()
    compile_net(net)